Skip to content

Commit

Permalink
Merge pull request #2009 from borglab/feature/search_wrapper
Browse files Browse the repository at this point in the history
Wrapper for DiscreteSearch
  • Loading branch information
dellaert authored Jan 29, 2025
2 parents 3c80a80 + 5e5a67d commit bb0c70b
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 58 deletions.
2 changes: 2 additions & 0 deletions gtsam/discrete/DiscreteSearch.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,6 @@ class GTSAM_EXPORT DiscreteSearch {
double lowerBound_; ///< Lower bound on the cost-to-go for the entire search.
std::vector<Slot> slots_; ///< The slots to fill in the search.
};

using DiscreteSearchSolution = DiscreteSearch::Solution; // for wrapping
} // namespace gtsam
25 changes: 25 additions & 0 deletions gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -464,4 +464,29 @@ class DiscreteJunctionTree {
const gtsam::DiscreteCluster& operator[](size_t i) const;
};

#include <gtsam/discrete/DiscreteSearch.h>
class DiscreteSearchSolution {
double error;
gtsam::DiscreteValues assignment;
DiscreteSearchSolution(double error, const gtsam::DiscreteValues& assignment);
};

class DiscreteSearch {
static DiscreteSearch FromFactorGraph(const gtsam::DiscreteFactorGraph& factorGraph,
const gtsam::Ordering& ordering,
bool buildJunctionTree = false);

DiscreteSearch(const gtsam::DiscreteEliminationTree& etree);
DiscreteSearch(const gtsam::DiscreteJunctionTree& junctionTree);
DiscreteSearch(const gtsam::DiscreteBayesNet& bayesNet);
DiscreteSearch(const gtsam::DiscreteBayesTree& bayesTree);

void print(string name = "DiscreteSearch: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;

double lowerBound() const;

std::vector<gtsam::DiscreteSearchSolution> run(size_t K = 1) const;
};

} // namespace gtsam
35 changes: 35 additions & 0 deletions python/gtsam/tests/dfg_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import numpy as np
from gtsam import Symbol


def make_key(character, index, cardinality):
"""
Helper function to mimic the behavior of gtbook.Variables discrete_series function.
"""
symbol = Symbol(character, index)
key = symbol.key()
return (key, cardinality)


def generate_transition_cpt(num_states, transitions=None):
"""
Generate a row-wise CPT for a transition matrix.
"""
if transitions is None:
# Default to identity matrix with slight regularization
transitions = np.eye(num_states) + 0.1 / num_states

# Ensure transitions sum to 1 if not already normalized
transitions /= np.sum(transitions, axis=1, keepdims=True)
return " ".join(["/".join(map(str, row)) for row in transitions])


def generate_observation_cpt(num_states, num_obs, desired_state):
"""
Generate a row-wise CPT for observations with contrived probabilities.
"""
obs = np.zeros((num_states, num_obs + 1))
obs[:, -1] = 1 # All states default to measurement num_obs
obs[desired_state, 0:-1] = 1
obs[desired_state, -1] = 0
return " ".join(["/".join(map(str, row)) for row in obs])
92 changes: 34 additions & 58 deletions python/gtsam/tests/test_DiscreteFactorGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,16 @@

import numpy as np
from gtsam.utils.test_case import GtsamTestCase
from dfg_utils import make_key, generate_transition_cpt, generate_observation_cpt

from gtsam import (DecisionTreeFactor, DiscreteConditional,
DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering,
Symbol)
from gtsam import (
DecisionTreeFactor,
DiscreteConditional,
DiscreteFactorGraph,
DiscreteKeys,
DiscreteValues,
Ordering,
)

OrderingType = Ordering.OrderingType

Expand Down Expand Up @@ -50,7 +56,7 @@ def test_evaluation(self):
assignment[1] = 1

# Check if graph evaluation works ( 0.3*0.6*4 )
self.assertAlmostEqual(.72, graph(assignment))
self.assertAlmostEqual(0.72, graph(assignment))

# Create a new test with third node and adding unary and ternary factor
graph.add(P3, "0.9 0.2 0.5")
Expand Down Expand Up @@ -100,8 +106,7 @@ def test_optimize(self):
expectedValues[1] = 0
expectedValues[2] = 0
actualValues = graph.optimize()
self.assertEqual(list(actualValues.items()),
list(expectedValues.items()))
self.assertEqual(list(actualValues.items()), list(expectedValues.items()))

def test_MPE(self):
"""Test maximum probable explanation (MPE): same as optimize."""
Expand All @@ -123,13 +128,11 @@ def test_MPE(self):
# Use maxProduct
dag = graph.maxProduct(OrderingType.COLAMD)
actualMPE = dag.argmax()
self.assertEqual(list(actualMPE.items()),
list(mpe.items()))
self.assertEqual(list(actualMPE.items()), list(mpe.items()))

# All in one
actualMPE2 = graph.optimize()
self.assertEqual(list(actualMPE2.items()),
list(mpe.items()))
self.assertEqual(list(actualMPE2.items()), list(mpe.items()))

def test_sumProduct(self):
"""Test sumProduct."""
Expand All @@ -154,11 +157,17 @@ def test_sumProduct(self):
self.assertAlmostEqual(mpeProbability, 0.36) # regression

# Use sumProduct
for ordering_type in [OrderingType.COLAMD, OrderingType.METIS, OrderingType.NATURAL,
OrderingType.CUSTOM]:
for ordering_type in [
OrderingType.COLAMD,
OrderingType.METIS,
OrderingType.NATURAL,
OrderingType.CUSTOM,
]:
bayesNet = graph.sumProduct(ordering_type)
self.assertEqual(bayesNet(mpe), mpeProbability)


class TestChains(GtsamTestCase):
def test_MPE_chain(self):
"""
Test for numerical underflow in EliminateMPE on long chains.
Expand All @@ -170,54 +179,30 @@ def test_MPE_chain(self):
desired_state = 1
states = list(range(num_states))

# Helper function to mimic the behavior of gtbook.Variables discrete_series function
def make_key(character, index, cardinality):
symbol = Symbol(character, index)
key = symbol.key()
return (key, cardinality)

X = {index: make_key("X", index, len(states)) for index in range(num_obs)}
Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)}
graph = DiscreteFactorGraph()

# Mostly identity transition matrix
transitions = np.eye(num_states)

# Needed otherwise mpe is always state 0?
transitions += 0.1/(num_states)

transition_cpt = []
for i in range(0, num_states):
transition_row = "/".join([str(x) for x in transitions[i]])
transition_cpt.append(transition_row)
transition_cpt = " ".join(transition_cpt)

transition_cpt = generate_transition_cpt(num_states)
for i in reversed(range(1, num_obs)):
transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt)
transition_conditional = DiscreteConditional(
X[i], [X[i - 1]], transition_cpt
)
graph.push_back(transition_conditional)

# Contrived example such that the desired state gives measurements [0, num_obs) with equal probability
# but all other states always give measurement num_obs
obs = np.zeros((num_states, num_obs+1))
obs[:,-1] = 1
obs[desired_state,0: -1] = 1
obs[desired_state,-1] = 0
obs_cpt_list = []
for i in range(0, num_states):
obs_row = "/".join([str(z) for z in obs[i]])
obs_cpt_list.append(obs_row)
obs_cpt = " ".join(obs_cpt_list)

obs_cpt = generate_observation_cpt(num_states, num_obs, desired_state)
# Contrived example where each measurement is its own index
for i in range(0, num_obs):
for i in range(num_obs):
obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt)
factor = obs_conditional.likelihood(i)
graph.push_back(factor)

mpe = graph.optimize()
vals = [mpe[X[i][0]] for i in range(num_obs)]

self.assertEqual(vals, [desired_state]*num_obs)
self.assertEqual(vals, [desired_state] * num_obs)

def test_sumProduct_chain(self):
"""
Expand All @@ -227,15 +212,8 @@ def test_sumProduct_chain(self):
"""
num_states = 3
chain_length = 400
desired_state = 1
states = list(range(num_states))

# Helper function to mimic the behavior of gtbook.Variables discrete_series function
def make_key(character, index, cardinality):
symbol = Symbol(character, index)
key = symbol.key()
return (key, cardinality)

X = {index: make_key("X", index, len(states)) for index in range(chain_length)}
graph = DiscreteFactorGraph()

Expand All @@ -253,18 +231,15 @@ def make_key(character, index, cardinality):

# Ensure that the stationary distribution is positive and normalized
stationary_dist /= np.sum(stationary_dist)
expected = DecisionTreeFactor(X[chain_length-1], stationary_dist.flatten())
expected = DecisionTreeFactor(X[chain_length - 1], stationary_dist.ravel())

# The transition matrix parsed by DiscreteConditional is a row-wise CPT
transitions = transitions.T
transition_cpt = []
for i in range(0, num_states):
transition_row = "/".join([str(x) for x in transitions[i]])
transition_cpt.append(transition_row)
transition_cpt = " ".join(transition_cpt)
transition_cpt = generate_transition_cpt(num_states, transitions.T)

for i in reversed(range(1, chain_length)):
transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt)
transition_conditional = DiscreteConditional(
X[i], [X[i - 1]], transition_cpt
)
graph.push_back(transition_conditional)

# Run sum product using natural ordering so the resulting Bayes net has the form:
Expand All @@ -277,5 +252,6 @@ def make_key(character, index, cardinality):
# Ensure marginal probabilities are close to the stationary distribution
self.gtsamAssertEquals(expected, last_marginal)


if __name__ == "__main__":
unittest.main()
84 changes: 84 additions & 0 deletions python/gtsam/tests/test_DiscreteSearch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""
GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved
See LICENSE for the license information
Unit tests for Discrete Search.
Author: Frank Dellaert
"""

# pylint: disable=no-name-in-module, invalid-name

import unittest

from dfg_utils import generate_observation_cpt, generate_transition_cpt, make_key
from gtsam.utils.test_case import GtsamTestCase

from gtsam import (
DiscreteConditional,
DiscreteFactorGraph,
DiscreteSearch,
Ordering,
DefaultKeyFormatter,
)

OrderingType = Ordering.OrderingType


class TestDiscreteSearch(GtsamTestCase):
"""Tests for Discrete Factor Graphs."""

def test_MPE_chain(self):
"""
Test for numerical underflow in EliminateMPE on long chains.
Adapted from the toy problem of @pcl15423
Ref: https://github.com/borglab/gtsam/issues/1448
"""
num_states = 3
num_obs = 200
desired_state = 1
states = list(range(num_states))

X = {index: make_key("X", index, len(states)) for index in range(num_obs)}
Z = {index: make_key("Z", index, num_obs + 1) for index in range(num_obs)}
graph = DiscreteFactorGraph()

transition_cpt = generate_transition_cpt(num_states)
for i in reversed(range(1, num_obs)):
transition_conditional = DiscreteConditional(
X[i], [X[i - 1]], transition_cpt
)
graph.push_back(transition_conditional)

# Contrived example such that the desired state gives measurements [0, num_obs) with equal
# probability but all other states always give measurement num_obs
obs_cpt = generate_observation_cpt(num_states, num_obs, desired_state)
# Contrived example where each measurement is its own index
for i in range(num_obs):
obs_conditional = DiscreteConditional(Z[i], [X[i]], obs_cpt)
factor = obs_conditional.likelihood(i)
graph.push_back(factor)

# Check MPE
mpe = graph.optimize()
vals = [mpe[X[i][0]] for i in range(num_obs)]
self.assertEqual(vals, [desired_state] * num_obs)

# Create an ordering:
ordering = Ordering()
for i in reversed(range(num_obs)):
ordering.push_back(X[i][0])

# Now do Search
search = DiscreteSearch.FromFactorGraph(graph, ordering)
solutions = search.run(K=1)
mpe2 = solutions[0].assignment
# print({DefaultKeyFormatter(key): value for key, value in mpe2.items()})
vals = [mpe2[X[i][0]] for i in range(num_obs)]
self.assertEqual(vals, [desired_state] * num_obs)


if __name__ == "__main__":
unittest.main()

0 comments on commit bb0c70b

Please sign in to comment.