Skip to content

Commit

Permalink
Better and more consistent documentation.
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert committed Jan 28, 2025
1 parent 8c7e75b commit 1afb089
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 70 deletions.
75 changes: 38 additions & 37 deletions gtsam/discrete/DiscreteSearch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
* -------------------------------------------------------------------------- */

/*
/**
* DiscreteSearch.cpp
*
* @date January, 2025
Expand All @@ -25,22 +25,19 @@ namespace gtsam {
using Slot = DiscreteSearch::Slot;
using Solution = DiscreteSearch::Solution;

/**
* @brief Represents a node in the search tree for discrete search algorithms.
*
* @details Each SearchNode contains a partial assignment of discrete variables,
* the current error, a bound on the final error, and the index of the next
* conditional to be assigned.
/*
* A SearchNode represents a node in the search tree for the search algorithm.
* Each SearchNode contains a partial assignment of discrete variables, the
* current error, a bound on the final error, and the index of the next
* slot to be assigned.
*/
struct SearchNode {
DiscreteValues assignment; ///< Partial assignment of discrete variables.
double error; ///< Current error for the partial assignment.
double bound; ///< Lower bound on the final error
std::optional<size_t> next; ///< Index of the next factor to be assigned.

/**
* @brief Construct the root node for the search.
*/
DiscreteValues assignment; // Partial assignment of discrete variables.
double error; // Current error for the partial assignment.
double bound; // Lower bound on the final error
std::optional<size_t> next; // Index of the next slot to be assigned.

// Construct the root node for the search.
static SearchNode Root(size_t numSlots, double bound) {
return {DiscreteValues(), 0.0, bound, 0};
}
Expand All @@ -51,10 +48,10 @@ struct SearchNode {
}
};

/// Checks if the node represents a complete assignment.
// Checks if the node represents a complete assignment.
inline bool isComplete() const { return !next; }

/// Expands the node by assigning the next variable(s).
// Expands the node by assigning the next variable(s).
SearchNode expand(const DiscreteValues& fa, const Slot& slot,
std::optional<size_t> nextSlot) const {
// Combine the new frontal assignment with the current partial assignment
Expand All @@ -66,7 +63,7 @@ struct SearchNode {
return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, nextSlot};
}

/// Prints the SearchNode to an output stream.
// Prints the SearchNode to an output stream.
friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) {
os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")";
return os;
Expand All @@ -79,17 +76,20 @@ struct CompareSolution {
}
};

// Define the Solutions class
/*
* A Solutions object maintains a priority queue of the best solutions found
* during the search. The priority queue is limited to a maximum size, and
* solutions are only added if they are better than the worst solution.
*/
class Solutions {
private:
size_t maxSize_;
size_t maxSize_; // Maximum number of solutions to keep
std::priority_queue<Solution, std::vector<Solution>, CompareSolution> pq_;

public:
Solutions(size_t maxSize) : maxSize_(maxSize) {}

/// Add a solution to the priority queue, possibly evicting the worst one.
/// Return true if we added the solution.
// Add a solution to the priority queue, possibly evicting the worst one.
// Return true if we added the solution.
bool maybeAdd(double error, const DiscreteValues& assignment) {
const bool full = pq_.size() == maxSize_;
if (full && error >= pq_.top().error) return false;
Expand All @@ -98,7 +98,7 @@ class Solutions {
return true;
}

/// Check if we have any solutions
// Check if we have any solutions
bool empty() const { return pq_.empty(); }

// Method to print all solutions
Expand All @@ -112,9 +112,9 @@ class Solutions {
return os;
}

/// Check if (partial) solution with given bound can be pruned. If we have
/// room, we never prune. Otherwise, prune if lower bound on error is worse
/// than our current worst error.
// Check if (partial) solution with given bound can be pruned. If we have
// room, we never prune. Otherwise, prune if lower bound on error is worse
// than our current worst error.
bool prune(double bound) const {
if (pq_.size() < maxSize_) return false;
return bound >= pq_.top().error;
Expand All @@ -134,9 +134,9 @@ class Solutions {
}
};

/// @brief Get the factor associated with a node, possibly product of factors.
// Get the factor associated with a node, possibly product of factors.
template <typename NodeType>
static auto getFactor(const NodeType& node) {
static DiscreteFactor::shared_ptr getFactor(const NodeType& node) {
const auto& factors = node->factors;
return factors.size() == 1 ? factors.back()
: DiscreteFactorGraph(factors).product();
Expand All @@ -145,7 +145,7 @@ static auto getFactor(const NodeType& node) {
DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) {
using NodePtr = std::shared_ptr<DiscreteEliminationTree::Node>;
auto visitor = [this](const NodePtr& node, int data) {
const auto factor = getFactor(node);
const DiscreteFactor::shared_ptr factor = getFactor(node);
const size_t cardinality = factor->cardinality(node->key);
std::vector<std::pair<Key, size_t>> pairs{{node->key, cardinality}};
const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0};
Expand Down Expand Up @@ -266,13 +266,14 @@ std::vector<Solution> DiscreteSearch::run(size_t K) const {
// Extract solutions from bestSolutions in ascending order of error
return solutions.extractSolutions();
}

// We have a number of factors, each with a max value, and we want to compute
// a lower-bound on the cost-to-go for each slot, *not* including this factor.
// For the last slot, this is 0.0, as the cost after that is zero.
// For the second-to-last slot, it is -log(max(factor[0])), because after we
// assign slot[1] we still need to assign slot[0], which will cost *at least*
// h0. We return the estimated lower bound of the cost for *all* slots.
/*
* We have a number of factors, each with a max value, and we want to compute
* a lower-bound on the cost-to-go for each slot, *not* including this factor.
* For the last slot[n-1], this is 0.0, as the cost after that is zero.
* For the second-to-last slot, it is h = -log(max(factor[n-1])), because after
* we assign slot[n-2] we still need to assign slot[n-1], which will cost *at
* least* h. We return the estimated lower bound of the cost for *all* slots.
*/
double DiscreteSearch::computeHeuristic() {
double error = 0.0;
for (auto it = slots_.rbegin(); it != slots_.rend(); ++it) {
Expand Down
74 changes: 41 additions & 33 deletions gtsam/discrete/DiscreteSearch.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
* -------------------------------------------------------------------------- */

/*
/**
* @file DiscreteSearch.h
* @brief Defines the DiscreteSearch class for discrete search algorithms.
*
Expand All @@ -28,24 +28,40 @@
namespace gtsam {

/**
* DiscreteSearch: Search for the K best solutions.
* @brief DiscreteSearch: Search for the K best solutions.
*
* This class is used to search for the K best solutions in a DiscreteBayesNet.
* This is implemented with a modified A* search algorithm that uses a priority
* queue to manage the search nodes. That machinery is defined in the .cpp file.
* The heuristic we use is the sum of the log-probabilities of the
* maximum-probability assignments for each slot, for all slots to the right of
* the current slot.
*
* TODO: The heuristic could be refined by using the partial assignment in
* search node to refine the max-probability assignment for the remaining slots.
* This would incur more computation but will lead to fewer expansions.
*/
class GTSAM_EXPORT DiscreteSearch {
public:
/// We structure the search as a set of slots, each with a factor and
/// a set of variable assignments that need to be chosen. In addition, each
/// slot has a heuristic associated with it.
/**
* We structure the search as a set of slots, each with a factor and
* a set of variable assignments that need to be chosen. In addition, each
* slot has a heuristic associated with it.
*
* Example:
* The factors in the search problem (always parents before descendents!):
* [P(A), P(B|A), P(C|A,B)]
* The assignments for each factor.
* [[A0,A1], [B0,B1], [C0,C1,C2]]
* A lower bound on the cost-to-go after each slot, e.g.,
* [-log(max_B P(B|A)) -log(max_C P(C|A,B)), -log(max_C P(C|A,B)), 0.0]
* Note that these decrease as we move from right to left.
* We keep the global lower bound as lowerBound_. In the example, it is:
* -log(max_B P(B|A)) -log(max_C P(C|A,B)) -log(max_C P(C|A,B))
*/
struct Slot {
/// The factors in the search problem,
/// e.g., [P(B|A),P(A)]
DiscreteFactor::shared_ptr factor;

/// The assignments for each factor,
/// e.g., [[B0,B1] [A0,A1]]
std::vector<DiscreteValues> assignments;

/// A lower bound on the cost-to-go for each slot, e.g.,
/// [-log(max_B P(B|A)), -log(max_A P(A))]
double heuristic;

friend std::ostream& operator<<(std::ostream& os, const Slot& slot) {
Expand All @@ -56,8 +72,10 @@ class GTSAM_EXPORT DiscreteSearch {
}
};

/// A solution is then a set of assignments, covering all the slots.
/// as well as an associated error = -log(probability)
/**
* A solution is a set of assignments, covering all the slots.
* as well as an associated error = -log(probability)
*/
struct Solution {
double error;
DiscreteValues assignment;
Expand Down Expand Up @@ -89,28 +107,16 @@ class GTSAM_EXPORT DiscreteSearch {
const Ordering& ordering,
bool buildJunctionTree = false);

/**
* @brief Constructor from a DiscreteEliminationTree.
*
* @param etree The DiscreteEliminationTree to initialize from.
*/
/// Construct from a DiscreteEliminationTree.
DiscreteSearch(const DiscreteEliminationTree& etree);

/**
* @brief Constructor from a DiscreteJunctionTree.
*
* @param junctionTree The DiscreteJunctionTree to initialize from.
*/
/// Construct from a DiscreteJunctionTree.
DiscreteSearch(const DiscreteJunctionTree& junctionTree);

/**
* Construct from a DiscreteBayesNet.
*/
//// Construct from a DiscreteBayesNet.
DiscreteSearch(const DiscreteBayesNet& bayesNet);

/**
* Construct from a DiscreteBayesTree.
*/
/// Construct from a DiscreteBayesTree.
DiscreteSearch(const DiscreteBayesTree& bayesTree);

/// @}
Expand Down Expand Up @@ -146,8 +152,10 @@ class GTSAM_EXPORT DiscreteSearch {
/// @}

private:
/// Compute the cumulative lower-bound cost-to-go after each slot is filled.
/// @return the estimated lower bound of the cost for *all* slots.
/**
* Compute the cumulative lower-bound cost-to-go after each slot is filled.
* @return the estimated lower bound of the cost for *all* slots.
*/
double computeHeuristic();

double lowerBound_; ///< Lower bound on the cost-to-go for the entire search.
Expand Down

0 comments on commit 1afb089

Please sign in to comment.