diff --git a/gtsam/discrete/DiscreteSearch.cpp b/gtsam/discrete/DiscreteSearch.cpp index 43f321d4ac..c046f508f9 100644 --- a/gtsam/discrete/DiscreteSearch.cpp +++ b/gtsam/discrete/DiscreteSearch.cpp @@ -9,7 +9,7 @@ * -------------------------------------------------------------------------- */ -/* +/** * DiscreteSearch.cpp * * @date January, 2025 @@ -25,22 +25,19 @@ namespace gtsam { using Slot = DiscreteSearch::Slot; using Solution = DiscreteSearch::Solution; -/** - * @brief Represents a node in the search tree for discrete search algorithms. - * - * @details Each SearchNode contains a partial assignment of discrete variables, - * the current error, a bound on the final error, and the index of the next - * conditional to be assigned. +/* + * A SearchNode represents a node in the search tree for the search algorithm. + * Each SearchNode contains a partial assignment of discrete variables, the + * current error, a bound on the final error, and the index of the next + * slot to be assigned. */ struct SearchNode { - DiscreteValues assignment; ///< Partial assignment of discrete variables. - double error; ///< Current error for the partial assignment. - double bound; ///< Lower bound on the final error - std::optional next; ///< Index of the next factor to be assigned. - - /** - * @brief Construct the root node for the search. - */ + DiscreteValues assignment; // Partial assignment of discrete variables. + double error; // Current error for the partial assignment. + double bound; // Lower bound on the final error + std::optional next; // Index of the next slot to be assigned. + + // Construct the root node for the search. static SearchNode Root(size_t numSlots, double bound) { return {DiscreteValues(), 0.0, bound, 0}; } @@ -51,10 +48,10 @@ struct SearchNode { } }; - /// Checks if the node represents a complete assignment. + // Checks if the node represents a complete assignment. inline bool isComplete() const { return !next; } - /// Expands the node by assigning the next variable(s). + // Expands the node by assigning the next variable(s). SearchNode expand(const DiscreteValues& fa, const Slot& slot, std::optional nextSlot) const { // Combine the new frontal assignment with the current partial assignment @@ -66,7 +63,7 @@ struct SearchNode { return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, nextSlot}; } - /// Prints the SearchNode to an output stream. + // Prints the SearchNode to an output stream. friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) { os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")"; return os; @@ -79,17 +76,20 @@ struct CompareSolution { } }; -// Define the Solutions class +/* + * A Solutions object maintains a priority queue of the best solutions found + * during the search. The priority queue is limited to a maximum size, and + * solutions are only added if they are better than the worst solution. + */ class Solutions { - private: - size_t maxSize_; + size_t maxSize_; // Maximum number of solutions to keep std::priority_queue, CompareSolution> pq_; public: Solutions(size_t maxSize) : maxSize_(maxSize) {} - /// Add a solution to the priority queue, possibly evicting the worst one. - /// Return true if we added the solution. + // Add a solution to the priority queue, possibly evicting the worst one. + // Return true if we added the solution. bool maybeAdd(double error, const DiscreteValues& assignment) { const bool full = pq_.size() == maxSize_; if (full && error >= pq_.top().error) return false; @@ -98,7 +98,7 @@ class Solutions { return true; } - /// Check if we have any solutions + // Check if we have any solutions bool empty() const { return pq_.empty(); } // Method to print all solutions @@ -112,9 +112,9 @@ class Solutions { return os; } - /// Check if (partial) solution with given bound can be pruned. If we have - /// room, we never prune. Otherwise, prune if lower bound on error is worse - /// than our current worst error. + // Check if (partial) solution with given bound can be pruned. If we have + // room, we never prune. Otherwise, prune if lower bound on error is worse + // than our current worst error. bool prune(double bound) const { if (pq_.size() < maxSize_) return false; return bound >= pq_.top().error; @@ -134,9 +134,9 @@ class Solutions { } }; -/// @brief Get the factor associated with a node, possibly product of factors. +// Get the factor associated with a node, possibly product of factors. template -static auto getFactor(const NodeType& node) { +static DiscreteFactor::shared_ptr getFactor(const NodeType& node) { const auto& factors = node->factors; return factors.size() == 1 ? factors.back() : DiscreteFactorGraph(factors).product(); @@ -145,7 +145,7 @@ static auto getFactor(const NodeType& node) { DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) { using NodePtr = std::shared_ptr; auto visitor = [this](const NodePtr& node, int data) { - const auto factor = getFactor(node); + const DiscreteFactor::shared_ptr factor = getFactor(node); const size_t cardinality = factor->cardinality(node->key); std::vector> pairs{{node->key, cardinality}}; const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0}; @@ -266,13 +266,14 @@ std::vector DiscreteSearch::run(size_t K) const { // Extract solutions from bestSolutions in ascending order of error return solutions.extractSolutions(); } - -// We have a number of factors, each with a max value, and we want to compute -// a lower-bound on the cost-to-go for each slot, *not* including this factor. -// For the last slot, this is 0.0, as the cost after that is zero. -// For the second-to-last slot, it is -log(max(factor[0])), because after we -// assign slot[1] we still need to assign slot[0], which will cost *at least* -// h0. We return the estimated lower bound of the cost for *all* slots. +/* + * We have a number of factors, each with a max value, and we want to compute + * a lower-bound on the cost-to-go for each slot, *not* including this factor. + * For the last slot[n-1], this is 0.0, as the cost after that is zero. + * For the second-to-last slot, it is h = -log(max(factor[n-1])), because after + * we assign slot[n-2] we still need to assign slot[n-1], which will cost *at + * least* h. We return the estimated lower bound of the cost for *all* slots. + */ double DiscreteSearch::computeHeuristic() { double error = 0.0; for (auto it = slots_.rbegin(); it != slots_.rend(); ++it) { diff --git a/gtsam/discrete/DiscreteSearch.h b/gtsam/discrete/DiscreteSearch.h index 700e41392f..b610955b29 100644 --- a/gtsam/discrete/DiscreteSearch.h +++ b/gtsam/discrete/DiscreteSearch.h @@ -9,7 +9,7 @@ * -------------------------------------------------------------------------- */ -/* +/** * @file DiscreteSearch.h * @brief Defines the DiscreteSearch class for discrete search algorithms. * @@ -28,24 +28,40 @@ namespace gtsam { /** - * DiscreteSearch: Search for the K best solutions. + * @brief DiscreteSearch: Search for the K best solutions. + * + * This class is used to search for the K best solutions in a DiscreteBayesNet. + * This is implemented with a modified A* search algorithm that uses a priority + * queue to manage the search nodes. That machinery is defined in the .cpp file. + * The heuristic we use is the sum of the log-probabilities of the + * maximum-probability assignments for each slot, for all slots to the right of + * the current slot. + * + * TODO: The heuristic could be refined by using the partial assignment in + * search node to refine the max-probability assignment for the remaining slots. + * This would incur more computation but will lead to fewer expansions. */ class GTSAM_EXPORT DiscreteSearch { public: - /// We structure the search as a set of slots, each with a factor and - /// a set of variable assignments that need to be chosen. In addition, each - /// slot has a heuristic associated with it. + /** + * We structure the search as a set of slots, each with a factor and + * a set of variable assignments that need to be chosen. In addition, each + * slot has a heuristic associated with it. + * + * Example: + * The factors in the search problem (always parents before descendents!): + * [P(A), P(B|A), P(C|A,B)] + * The assignments for each factor. + * [[A0,A1], [B0,B1], [C0,C1,C2]] + * A lower bound on the cost-to-go after each slot, e.g., + * [-log(max_B P(B|A)) -log(max_C P(C|A,B)), -log(max_C P(C|A,B)), 0.0] + * Note that these decrease as we move from right to left. + * We keep the global lower bound as lowerBound_. In the example, it is: + * -log(max_B P(B|A)) -log(max_C P(C|A,B)) -log(max_C P(C|A,B)) + */ struct Slot { - /// The factors in the search problem, - /// e.g., [P(B|A),P(A)] DiscreteFactor::shared_ptr factor; - - /// The assignments for each factor, - /// e.g., [[B0,B1] [A0,A1]] std::vector assignments; - - /// A lower bound on the cost-to-go for each slot, e.g., - /// [-log(max_B P(B|A)), -log(max_A P(A))] double heuristic; friend std::ostream& operator<<(std::ostream& os, const Slot& slot) { @@ -56,8 +72,10 @@ class GTSAM_EXPORT DiscreteSearch { } }; - /// A solution is then a set of assignments, covering all the slots. - /// as well as an associated error = -log(probability) + /** + * A solution is a set of assignments, covering all the slots. + * as well as an associated error = -log(probability) + */ struct Solution { double error; DiscreteValues assignment; @@ -89,28 +107,16 @@ class GTSAM_EXPORT DiscreteSearch { const Ordering& ordering, bool buildJunctionTree = false); - /** - * @brief Constructor from a DiscreteEliminationTree. - * - * @param etree The DiscreteEliminationTree to initialize from. - */ + /// Construct from a DiscreteEliminationTree. DiscreteSearch(const DiscreteEliminationTree& etree); - /** - * @brief Constructor from a DiscreteJunctionTree. - * - * @param junctionTree The DiscreteJunctionTree to initialize from. - */ + /// Construct from a DiscreteJunctionTree. DiscreteSearch(const DiscreteJunctionTree& junctionTree); - /** - * Construct from a DiscreteBayesNet. - */ + //// Construct from a DiscreteBayesNet. DiscreteSearch(const DiscreteBayesNet& bayesNet); - /** - * Construct from a DiscreteBayesTree. - */ + /// Construct from a DiscreteBayesTree. DiscreteSearch(const DiscreteBayesTree& bayesTree); /// @} @@ -146,8 +152,10 @@ class GTSAM_EXPORT DiscreteSearch { /// @} private: - /// Compute the cumulative lower-bound cost-to-go after each slot is filled. - /// @return the estimated lower bound of the cost for *all* slots. + /** + * Compute the cumulative lower-bound cost-to-go after each slot is filled. + * @return the estimated lower bound of the cost for *all* slots. + */ double computeHeuristic(); double lowerBound_; ///< Lower bound on the cost-to-go for the entire search.