diff --git a/src/beanmachine/graph/distribution/distribution.h b/src/beanmachine/graph/distribution/distribution.h index d1d0cedcc6..a3037d904a 100644 --- a/src/beanmachine/graph/distribution/distribution.h +++ b/src/beanmachine/graph/distribution/distribution.h @@ -112,11 +112,13 @@ class Distribution : public graph::Node { the log prob of the distribution w.r.t. the sampled value. :param value: value of the child Sample operator, a single draw from the distribution - :param back_grad: back_grad1 of the child Sample operator, to be incremented - :param adjunct: a multiplier that represents the gradient of the target - function w.r.t the log prob of this distribution. It uses the default value - 1.0 if the direct child is a StochasticOperator, but requires input if the - direct child is a mixture distribution. + :param back_grad: variable to which the gradient will be added. + :param adjunct: if we are interested in df(log_prob)/dvalue, then + adjunct must be df(log_prob)/dlog_prob. + If we are interested in dlog_prob/dvalue then the adjunct is 1, + which is the default. + For other cases (such as this distribution being a component of a + mixture distribution), the appropriate adjunct must be provided. */ virtual void backward_value( const graph::NodeValue& /* value */, @@ -134,14 +136,8 @@ class Distribution : public graph::Node { graph::DoubleMatrix& /* back_grad */, Eigen::MatrixXd& /* adjunct */) const {} /* - In backward gradient propagation, increments the back_grad1 of each parent - node w.r.t. the log prob of the distribution, evaluated at the given value. - :param value: value of the child Sample operator, a single draw from the - distribution - :param adjunct: a multiplier that represents the gradient of the - target function w.r.t the log prob of this distribution. It uses the default - value 1.0 if the direct child is a StochasticOperator, but requires input if - the direct child is a mixture distribution. + Analogous to backward_value, but computes the gradient + wrt to each parameter, and adds results to their back_grad1 field. */ virtual void backward_param( const graph::NodeValue& /* value */, diff --git a/src/beanmachine/graph/graph.cpp b/src/beanmachine/graph/graph.cpp index bfae4d11f3..052b321fbc 100644 --- a/src/beanmachine/graph/graph.cpp +++ b/src/beanmachine/graph/graph.cpp @@ -23,6 +23,9 @@ namespace beanmachine { namespace graph { +NATURAL_TYPE NATURAL_ZERO = 0ull; +NATURAL_TYPE NATURAL_ONE = 1ull; + std::string ValueType::to_string() const { std::string vtype; std::string atype; diff --git a/src/beanmachine/graph/graph.h b/src/beanmachine/graph/graph.h index 81cb0f66f9..21f2f59808 100644 --- a/src/beanmachine/graph/graph.h +++ b/src/beanmachine/graph/graph.h @@ -124,6 +124,9 @@ struct ValueType { typedef NATURAL_TYPE natural_t; +extern NATURAL_TYPE NATURAL_ZERO; +extern NATURAL_TYPE NATURAL_ONE; + class NodeValue { public: ValueType type; @@ -374,6 +377,9 @@ enum class DistributionType { LKJ_CHOLESKY }; +// TODO: do we really need DistributionType? Can't we know the type of a +// Distribution from its class alone? + enum class FactorType { UNKNOWN, EXP_PRODUCT, @@ -449,9 +455,17 @@ class Node { virtual bool needs_gradient() const { return true; } - // gradient_log_prob is also only valid for stochastic nodes - // TODO: shouldn't we then restrict them to those classes? See above. - // this function adds the gradients to the passed in gradients + // gradient_log_prob is also only valid for stochastic nodes. + // (TODO: shouldn't we then restrict them to those classes? See above.) + // It computes the first and second gradients of the log prob + // of this node with respect to a given target node and + // adds them to the passed-in gradient parameters. + // Note that for this computation to be correct, + // gradients (the grad1 and grad2 properties of nodes) + // must have been updated all the way from the + // target node to this node. + // This is because this method only performs a local computation + // and relies on the grad1 and grad2 attributes of nodes. virtual void gradient_log_prob( const graph::Node* target_node, double& /* grad1 */, diff --git a/src/beanmachine/graph/marginalization/marginalized_graph.cpp b/src/beanmachine/graph/marginalization/marginalized_graph.cpp index 993b73f417..85a3907996 100644 --- a/src/beanmachine/graph/marginalization/marginalized_graph.cpp +++ b/src/beanmachine/graph/marginalization/marginalized_graph.cpp @@ -63,13 +63,13 @@ all of the nodes required to compute the MarginalDistribution 4. the stochastic children nodes of the discrete sample 5. the parents (a node not in 1-4 that has a child in 1-4) -The original graph will contain +The original graph will be modified to contain 1. the MarginalDistribution node (to replace #1-3 from the subgraph above) -2. the children of the MarginalDistribution are the +2. the children of the MarginalDistribution, which are the stochastic children nodes of the discrete node (the same as #4 from the subgraph) -3. the parents of the MarginalDistribution are the parents +3. the parents of the MarginalDistribution, which are the parents of the subgraph (same as #5 from the subgraph) In order to keep the original graph and the subgraph completely @@ -87,7 +87,8 @@ same as the parent node in the graph. CHILDREN: The children of the MarginalDistribution are the stochastic children of the discrete sample node. -The stochastic children are needed to compute the MarginalDistribution, +The stochastic children are needed to compute the +log prob of the MarginalDistribution, so they are part of the subgraph. However, a "copy" of these children also needs to be added to the graph. This "copy" node is a SAMPLE node of MarginalDistribution whose value @@ -104,11 +105,13 @@ void marginalize_graph(Graph& graph, uint discrete_sample_node_id) { std::vector sto_node_ids; std::tie(det_node_ids, sto_node_ids) = compute_children(graph, discrete_sample->index); + // TODO: do we need to rename the above compute_affected_nodes, + // or even use Graph's methods for that instead of computing it ourselves? // create MarginalDistribution std::unique_ptr marginal_distribution_ptr = std::make_unique(std::move(subgraph_ptr)); - // TODO: support the correct sample type for multiple children + // TODO: support multiple children if (sto_node_ids.size() > 0) { // @lint-ignore marginal_distribution_ptr->sample_type = @@ -119,7 +122,6 @@ void marginalize_graph(Graph& graph, uint discrete_sample_node_id) { marginal_distribution_ptr.get(); SubGraph* subgraph = marginal_distribution->subgraph_ptr.get(); - // add nodes to subgraph add_nodes_to_subgraph( subgraph, discrete_distribution, @@ -127,9 +129,8 @@ void marginalize_graph(Graph& graph, uint discrete_sample_node_id) { det_node_ids, sto_node_ids); - // connect parents to MarginalDistribution in graph connect_parents_to_marginal_distribution(graph, marginal_distribution); - // add copy of parents to subgraph + add_copy_of_parent_nodes_to_subgraph(subgraph, marginal_distribution); // create and connect children to MarginalDistribution @@ -139,6 +140,7 @@ void marginalize_graph(Graph& graph, uint discrete_sample_node_id) { // list of all created nodes to add to `nodes` of current graph std::vector> created_graph_nodes; + // add MarginalDistribution to list of created nodes created_graph_nodes.push_back(std::move(marginal_distribution_ptr)); // add created nodes to list of created_graph_nodes @@ -154,6 +156,7 @@ void marginalize_graph(Graph& graph, uint discrete_sample_node_id) { // the created nodes should be inserted right after the largest parent index uint marginal_distribution_index = compute_largest_parent_index(marginal_distribution) + 1; + // insert created nodes into graph at "marginalized_node_index" for (uint i = 0; i < created_graph_nodes.size(); i++) { graph.nodes.insert( @@ -164,7 +167,7 @@ void marginalize_graph(Graph& graph, uint discrete_sample_node_id) { } /* -returns +returns 1. deterministic_node_ids are all of the deterministic nodes up until the 2. stochastic_node_ids children are reached */ diff --git a/src/beanmachine/graph/operator/stochasticop.cpp b/src/beanmachine/graph/operator/stochasticop.cpp index 9ae4a2a2e5..af87588374 100644 --- a/src/beanmachine/graph/operator/stochasticop.cpp +++ b/src/beanmachine/graph/operator/stochasticop.cpp @@ -42,7 +42,8 @@ void StochasticOperator::gradient_log_prob( // but it is not represented as an in-node. // // This makes the computation of this derivative less uniform - // and less directly corresponding to the typical use of the chain rule. + // and less directly corresponding to the typical use of the chain rule + // for the operations explicitly represented as nodes. // // Still, it should be possible to simply apply the chain rule // and find an expression involving the gradient of this stochastic node's diff --git a/src/beanmachine/graph/util.cpp b/src/beanmachine/graph/util.cpp index 0403c992e2..23b235abb4 100644 --- a/src/beanmachine/graph/util.cpp +++ b/src/beanmachine/graph/util.cpp @@ -76,13 +76,10 @@ double Phi_approx_inv(double z) { } double log_sum_exp(const std::vector& values) { - // find the max and subtract it out - double max = values[0]; - for (std::vector::size_type idx = 1; idx < values.size(); idx++) { - if (values[idx] > max) { - max = values[idx]; - } - } + // See "log-sum-exp trick for log-domain calculations" in + // https://en.wikipedia.org/wiki/LogSumExp + assert(values.size() != 0); + double max = *std::max_element(values.begin(), values.end()); double sum = 0; for (auto value : values) { sum += std::exp(value - max); @@ -96,6 +93,21 @@ double log_sum_exp(double a, double b) { return std::log(sum) + max_val; } +std::vector probs_given_log_potentials(std::vector log_pot) { + // p_i = pot_i/Z + // where Z is the normalization constant sum_i exp(log pot_i). + // = exp(log(pot_i/Z)) + // = exp(log pot_i - logZ) + // logZ is log(sum_i exp(log pot_i)) + auto logZ = log_sum_exp(log_pot); + std::vector probs; + probs.reserve(log_pot.size()); + for (size_t i = 0; i != log_pot.size(); i++) { + probs.push_back(std::exp(log_pot[i] - logZ)); + } + return probs; +} + double polygamma(int n, double x) { return boost::math::polygamma(n, x); } diff --git a/src/beanmachine/graph/util.h b/src/beanmachine/graph/util.h index bbf1683ecb..1e7c84dbc5 100644 --- a/src/beanmachine/graph/util.h +++ b/src/beanmachine/graph/util.h @@ -89,13 +89,23 @@ std::vector percentiles( } /* -Compute log of the sum of the exponentiation of all the values in the vector +Equivalent to log of sum of exponentiations of values, +but more numerically stable. :param values: vector of log values :returns: log sum exp of values */ double log_sum_exp(const std::vector& values); double log_sum_exp(double a, double b); +/* + Given log potentials log pot_i + where potentials pot_i are an unnormalized probability distribution, + return the normalized probability distribution p_i. + p_i = pot_i/Z + where Z is the normalization constant sum_i exp(log pot_i). +*/ +std::vector probs_given_log_potentials(std::vector log_pot); + struct BinaryLogSumExp { double operator()(double a, double b) const { return log_sum_exp(a, b);