From 1f04780e5489b2f7b42de2f59af64bee1606eff4 Mon Sep 17 00:00:00 2001 From: Neal Gafter Date: Sun, 30 Oct 2022 14:50:13 -0700 Subject: [PATCH] minibmg: use NUTS from bmg. (#1794) Summary: Pull Request resolved: https://github.com/facebookresearch/beanmachine/pull/1794 Use NUTS from bmg as a library from minibmg. Currently uses (slow) reverse-mode AD for the gradients. Reviewed By: horizon-blue Differential Revision: D40356996 fbshipit-source-id: ebd1ebc054f736e57259f16b1c21d3e659950b97 --- .../graph_properties/unobserved_samples.cpp | 46 +++++ minibmg/graph_properties/unobserved_samples.h | 19 ++ minibmg/inference/global_state.cpp | 174 ++++++++++++++++++ minibmg/inference/global_state.h | 66 +++++++ minibmg/inference/hmc_world.cpp | 63 ++++--- minibmg/inference/hmc_world.h | 26 ++- minibmg/inference/mle_inference.cpp | 6 +- minibmg/tests/inference/nuts_test.cpp | 157 ++++++++++++++++ 8 files changed, 517 insertions(+), 40 deletions(-) create mode 100644 minibmg/graph_properties/unobserved_samples.cpp create mode 100644 minibmg/graph_properties/unobserved_samples.h create mode 100644 minibmg/inference/global_state.cpp create mode 100644 minibmg/inference/global_state.h create mode 100644 minibmg/tests/inference/nuts_test.cpp diff --git a/minibmg/graph_properties/unobserved_samples.cpp b/minibmg/graph_properties/unobserved_samples.cpp new file mode 100644 index 0000000000..a18926daec --- /dev/null +++ b/minibmg/graph_properties/unobserved_samples.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "beanmachine/minibmg/graph_properties/unobserved_samples.h" +#include +#include +#include +#include +#include + +namespace { + +using namespace beanmachine::minibmg; + +class unobserved_samples_property + : public Property> { + public: + std::vector* create(const Graph& g) const override { + auto result = new std::vector{}; + std::unordered_set observed_samples; + for (auto& p : g.observations) { + observed_samples.insert(p.first); + } + for (auto& node : g) { + if (std::dynamic_pointer_cast(node) && + !observed_samples.contains(node)) { + result->push_back(node); + } + } + return result; + } +}; + +} // namespace + +namespace beanmachine::minibmg { + +const std::vector& unobserved_samples(const Graph& graph) { + return *unobserved_samples_property::get(graph); +} + +} // namespace beanmachine::minibmg diff --git a/minibmg/graph_properties/unobserved_samples.h b/minibmg/graph_properties/unobserved_samples.h new file mode 100644 index 0000000000..2b7711c50a --- /dev/null +++ b/minibmg/graph_properties/unobserved_samples.h @@ -0,0 +1,19 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include "beanmachine/minibmg/graph.h" +#include "beanmachine/minibmg/node.h" + +namespace beanmachine::minibmg { + +const std::vector& unobserved_samples(const Graph& graph); + +} // namespace beanmachine::minibmg diff --git a/minibmg/inference/global_state.cpp b/minibmg/inference/global_state.cpp new file mode 100644 index 0000000000..5e75e98b88 --- /dev/null +++ b/minibmg/inference/global_state.cpp @@ -0,0 +1,174 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "beanmachine/minibmg/inference/global_state.h" +#include +#include +#include "beanmachine/graph/global/global_state.h" +#include "beanmachine/minibmg/eval.h" +#include "beanmachine/minibmg/graph_properties/unobserved_samples.h" + +namespace beanmachine::minibmg { + +MinibmgGlobalState::MinibmgGlobalState(beanmachine::minibmg::Graph& graph) + : graph{graph}, world{hmc_world_0(graph)} { + samples.clear(); + // Since we only support scalars, we count the unobserved samples by ones. + int num_unobserved_samples = -graph.observations.size(); + for (auto& node : graph) { + if (std::dynamic_pointer_cast(node)) { + num_unobserved_samples++; + } + } + flat_size = num_unobserved_samples; +} + +void MinibmgGlobalState::initialize_values( + beanmachine::graph::InitType init_type, + uint seed) { + std::mt19937 gen(31 * seed + 17); + std::vector& samples = unconstrained_values; + switch (init_type) { + case graph::InitType::PRIOR: { + // Evaluate the graph, saving samples. + auto read_variable = [](const std::string&, const unsigned) -> Real { + // there are no variables, so we don't have to read them. + throw std::logic_error("models do not contain variables"); + }; + auto my_sampler = [&samples]( + const Distribution& distribution, + std::mt19937& gen) -> SampledValue { + auto result = sample_from_distribution(distribution, gen); + // save the proposed value + samples.push_back(result.unconstrained.as_double()); + return result; + }; + auto eval_result = eval_graph( + graph, + gen, + /* read_variable= */ read_variable, + real_eval_data, + /* run_queries= */ false, + /* eval_log_prob= */ true, + /* sampler = */ my_sampler); + } break; + case graph::InitType::RANDOM: { + std::uniform_real_distribution<> uniform_real_distribution(-2, 2); + for (int i = 0; i < flat_size; i++) { + samples.push_back(uniform_real_distribution(gen)); + } + } break; + default: { + for (int i = 0; i < flat_size; i++) { + samples.push_back(0); + } + } break; + } + + // update and backup values, gradients, and log_prob + update_log_prob(); + update_backgrad(); + backup_unconstrained_values(); + backup_unconstrained_grads(); +} + +void MinibmgGlobalState::backup_unconstrained_values() { + saved_unconstrained_values = unconstrained_values; +} + +void MinibmgGlobalState::backup_unconstrained_grads() { + saved_unconstrained_grads = unconstrained_grads; +} + +void MinibmgGlobalState::revert_unconstrained_values() { + unconstrained_values = saved_unconstrained_values; +} + +void MinibmgGlobalState::revert_unconstrained_grads() { + unconstrained_grads = saved_unconstrained_grads; +} + +void MinibmgGlobalState::add_to_stochastic_unconstrained_nodes( + Eigen::VectorXd& increment) { + if (increment.size() != flat_size) { + throw std::invalid_argument( + "The size of increment is inconsistent with the values in the graph"); + } + for (int i = 0; i < flat_size; i++) { + unconstrained_values[i] += increment[i]; + } +} + +void MinibmgGlobalState::get_flattened_unconstrained_values( + Eigen::VectorXd& flattened_values) { + flattened_values.resize(flat_size); + for (int i = 0; i < flat_size; i++) { + flattened_values[i] = unconstrained_values[i]; + } +} + +void MinibmgGlobalState::set_flattened_unconstrained_values( + Eigen::VectorXd& flattened_values) { + if (flattened_values.size() != flat_size) { + throw std::invalid_argument( + "The size of flattened_values is inconsistent with the values in the graph"); + } + for (int i = 0; i < flat_size; i++) { + unconstrained_values[i] = flattened_values[i]; + } +} + +void MinibmgGlobalState::get_flattened_unconstrained_grads( + Eigen::VectorXd& flattened_grad) { + flattened_grad.resize(flat_size); + for (int i = 0; i < flat_size; i++) { + flattened_grad[i] = unconstrained_grads[i]; + } +} + +double MinibmgGlobalState::get_log_prob() { + return log_prob; +} + +void MinibmgGlobalState::update_log_prob() { + log_prob = world->log_prob(this->unconstrained_values); +} + +void MinibmgGlobalState::update_backgrad() { + unconstrained_grads = world->gradients(this->unconstrained_values); +} + +void MinibmgGlobalState::collect_sample() { + auto queries = world->queries(this->unconstrained_values); + std::vector compat_query; + for (auto v : queries) { + compat_query.emplace_back(v); + } + this->samples.emplace_back(compat_query); +} + +std::vector>& +MinibmgGlobalState::get_samples() { + return samples; +} + +void MinibmgGlobalState::set_default_transforms() { + // minibmg always uses the default transforms +} + +void MinibmgGlobalState::set_agg_type( + beanmachine::graph::AggregationType agg_type) { + if (agg_type != beanmachine::graph::AggregationType::NONE) { + throw std::logic_error("unimplemented AggregationType"); + } +} + +void MinibmgGlobalState::clear_samples() { + samples.clear(); +} + +} // namespace beanmachine::minibmg diff --git a/minibmg/inference/global_state.h b/minibmg/inference/global_state.h new file mode 100644 index 0000000000..ccc7178c39 --- /dev/null +++ b/minibmg/inference/global_state.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include "beanmachine/graph/global/global_state.h" +#include "beanmachine/graph/graph.h" +#include "beanmachine/minibmg/ad/real.h" +#include "beanmachine/minibmg/ad/reverse.h" +#include "beanmachine/minibmg/graph.h" +#include "hmc_world.h" + +namespace beanmachine::minibmg { + +// using namespace beanmachine::graph; + +// Global state, an implementation of beanmachine::graph::GlobalState which is +// needed to use the NUTS api from bmg. +class MinibmgGlobalState : public beanmachine::graph::GlobalState { + public: + explicit MinibmgGlobalState(beanmachine::minibmg::Graph& graph); + void initialize_values(beanmachine::graph::InitType init_type, uint seed) + override; + void backup_unconstrained_values() override; + void backup_unconstrained_grads() override; + void revert_unconstrained_values() override; + void revert_unconstrained_grads() override; + void add_to_stochastic_unconstrained_nodes( + Eigen::VectorXd& increment) override; + void get_flattened_unconstrained_values( + Eigen::VectorXd& flattened_values) override; + void set_flattened_unconstrained_values( + Eigen::VectorXd& flattened_values) override; + void get_flattened_unconstrained_grads( + Eigen::VectorXd& flattened_grad) override; + double get_log_prob() override; + void update_log_prob() override; + void update_backgrad() override; + void collect_sample() override; + std::vector>& get_samples() + override; + void set_default_transforms() override; + void set_agg_type(beanmachine::graph::AggregationType) override; + void clear_samples() override; + + private: + const beanmachine::minibmg::Graph& graph; + const std::unique_ptr world; + std::vector> samples; + int flat_size; + double log_prob; + std::vector unconstrained_values; + std::vector unconstrained_grads; + std::vector saved_unconstrained_values; + std::vector saved_unconstrained_grads; + + // scratchpads for evaluation + std::unordered_map> reverse_eval_data; + std::unordered_map real_eval_data; +}; + +} // namespace beanmachine::minibmg diff --git a/minibmg/inference/hmc_world.cpp b/minibmg/inference/hmc_world.cpp index 3dcca6abb9..f295ed41f5 100644 --- a/minibmg/inference/hmc_world.cpp +++ b/minibmg/inference/hmc_world.cpp @@ -15,6 +15,7 @@ #include "beanmachine/minibmg/eval.h" #include "beanmachine/minibmg/graph.h" #include "beanmachine/minibmg/graph_properties/observations_by_node.h" +#include "beanmachine/minibmg/graph_properties/unobserved_samples.h" #include "beanmachine/minibmg/node.h" namespace { @@ -24,36 +25,33 @@ using namespace beanmachine::minibmg; class HMCWorld0 : public HMCWorld { private: const Graph& graph; - std::unordered_set unobserved_samples; + std::unordered_set unobserved_sample_set; + std::unordered_map observations; public: explicit HMCWorld0(const Graph& graph); unsigned num_unobserved_samples() const override; - HMCWorldEvalResult evaluate( + double log_prob( + const std::vector& proposed_unconstrained_values) const override; + + std::vector gradients( const std::vector& proposed_unconstrained_values) const override; std::vector queries( const std::vector& proposed_unconstrained_values) const override; }; -HMCWorld0::HMCWorld0(const Graph& graph) : graph{graph} { - // we identify the set of unobserved samples by... - for (const auto& node : graph.nodes) { - // ...counting the samples... - if (dynamic_cast(node.get())) { - unobserved_samples.insert(node); - } - } - // and subtracting the observed ones. - for (const auto& p : graph.observations) { - unobserved_samples.erase(p.first); - } -} +HMCWorld0::HMCWorld0(const Graph& graph) + : graph{graph}, + unobserved_sample_set{ + unobserved_samples(graph).begin(), + unobserved_samples(graph).end()}, + observations{observations_by_node(graph)} {} unsigned HMCWorld0::num_unobserved_samples() const { - return unobserved_samples.size(); + return unobserved_sample_set.size(); } template @@ -66,8 +64,9 @@ requires Number EvalResult evaluate_internal( bool eval_log_prob) { unsigned next_sample = 0; - // Here is our function for producing an unobserved sample. We consume the - // data provided by the caller, transforming it if necessary. + // Here is our function for producing an unobserved sample by drawing from + // `proposed_unconstrained_values`. We consume that data provided by the + // caller, transforming it if necessary. std::function( const Distribution& distribution, std::mt19937& gen)> sample_from_distribution = [&](const Distribution& distribution, @@ -86,8 +85,9 @@ requires Number EvalResult evaluate_internal( // highest likelihood value. For example, with a beta(7, 5), the peak is // at X=0.6. However, when viewed in the transformed space with the // log_prob value also transformed, the peak occurs at a value - // corresponding to X=0.625. I need help understanding what to do here. - // For now we just avoid transforming the log_prob value. + // corresponding to X=0.625. I am probably misunderstanding the math. I + // need help understanding what to do here. For now we just avoid + // transforming the log_prob value. // // // logp = transform->transform_log_prob(constrained, logp); return {constrained, unconstrained, logp}; @@ -111,7 +111,25 @@ requires Number EvalResult evaluate_internal( sample_from_distribution); } -HMCWorldEvalResult HMCWorld0::evaluate( +double HMCWorld0::log_prob( + const std::vector& proposed_unconstrained_values) const { + using T = Real; + std::unordered_map data; + std::mt19937 gen; + + // evaluate the graph and its log_prob in normal mode + auto eval_result = evaluate_internal( + graph, + proposed_unconstrained_values, + data, + gen, + /* run_queries = */ false, + /* eval_log_prob = */ true); + + return eval_result.log_prob.as_double(); +} + +std::vector HMCWorld0::gradients( const std::vector& proposed_unconstrained_values) const { using T = Reverse; std::unordered_map data; @@ -144,8 +162,7 @@ HMCWorldEvalResult HMCWorld0::evaluate( } } - return HMCWorldEvalResult{ - eval_result.log_prob.as_double(), std::move(gradients)}; + return gradients; } std::vector HMCWorld0::queries( diff --git a/minibmg/inference/hmc_world.h b/minibmg/inference/hmc_world.h index 8615605741..6fc1c89fff 100644 --- a/minibmg/inference/hmc_world.h +++ b/minibmg/inference/hmc_world.h @@ -34,11 +34,19 @@ class HMCWorld { // model (transformed, if necessary, so that they are unconstrained - // supported over the real numbers), given in the same order as the // observation nodes appear in the graph, compute the log probability of the - // model with that assignment, as well as the the first derivative of the log - // probability with respect to each of the proposed values. The input vector - // is required to be of a size that is equal to the return value of + // model with that assignment. The input vector is required to be of a size + // that is equal to the return value of num_unobserved_samples. + virtual double log_prob( + const std::vector& proposed_unconstrained_values) const = 0; + + // Given proposed assigned values for all of the onobserved samples in the + // model (transformed, if necessary, so that they are unconstrained - + // supported over the real numbers), given in the same order as the + // observation nodes appear in the graph, compute the first derivative of the + // log probability with respect to each of the proposed values. The input + // vector is required to be of a size that is equal to the return value of // num_unobserved_samples. - virtual HMCWorldEvalResult evaluate( + virtual std::vector gradients( const std::vector& proposed_unconstrained_values) const = 0; // Given proposed assigned values for all of the onobserved samples in the @@ -52,16 +60,6 @@ class HMCWorld { virtual ~HMCWorld() {} }; -struct HMCWorldEvalResult { - // The computed log probability of a given assignment to the samples in a - // model. - double log_prob; - - // The derivative of the log probability with respect to each of the - // unobserved samples in a model. - std::vector gradients; -}; - // produce an abstraction of the graph for use by inference. This // implementation performs its work by evaluating the graph node by node on // demand. You can think of this as an interpreter for the graph. diff --git a/minibmg/inference/mle_inference.cpp b/minibmg/inference/mle_inference.cpp index e3be252cac..b4659b3f74 100644 --- a/minibmg/inference/mle_inference.cpp +++ b/minibmg/inference/mle_inference.cpp @@ -26,11 +26,11 @@ std::vector mle_inference_0( proposals.resize(num_samples); for (int round = 0; round < num_rounds; round++) { - auto result = abstraction->evaluate(proposals); + auto grads = abstraction->gradients(proposals); if (print_progress) { std::cout << fmt::format( "log_prob: {} inferred: {}\n", - result.log_prob, + abstraction->log_prob(proposals), abstraction->queries(proposals)[0]); } assert(!proposals.empty()); @@ -38,7 +38,7 @@ std::vector mle_inference_0( // We use + rather than - here because we want to maximize (not minimize) // the log_prob; we move in the direction of the gradient rather than // opposite to it as we would in gradient descent. - proposals[samp] += result.gradients[samp] * learning_rate; + proposals[samp] += grads[samp] * learning_rate; } } diff --git a/minibmg/tests/inference/nuts_test.cpp b/minibmg/tests/inference/nuts_test.cpp new file mode 100644 index 0000000000..f7a3c3644e --- /dev/null +++ b/minibmg/tests/inference/nuts_test.cpp @@ -0,0 +1,157 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include "beanmachine/graph/global/global_mh.h" +#include "beanmachine/graph/global/nuts.h" +#include "beanmachine/minibmg/fluid_factory.h" +#include "beanmachine/minibmg/graph.h" +#include "beanmachine/minibmg/inference/global_state.h" + +using namespace ::testing; +using namespace beanmachine::minibmg; +using beanmachine::graph::GlobalState; +using beanmachine::graph::NodeValue; +using beanmachine::graph::NUTS; + +template +requires Number T expit(const T& x) { + return 1 / (1 + exp(-x)); +} + +const int num_heads = 15; +const int num_tails = 1; +const int num_samples = 1000; +const int skip_samples = std::min(num_samples / 2, 500); +const int seed = 12345; + +// Take a familiar-looking model and runs NUTS. +TEST(nuts_test, coin_flipping) { + Graph::FluidFactory f; + + // We would like to use + // + // auto d = beta(1, 1); + // auto s = sample(d); + // + // but we don't have transformations working quite right yet, which we would + // need to use beta. So we use a distribution that doesn't require it. + + auto s = expit(sample(normal(0, 100))); + + auto bn = bernoulli(s); + for (int i = 0; i < num_heads; i++) { + f.observe(sample(bn), 1); + } + for (int i = 0; i < num_tails; i++) { + f.observe(sample(bn), 0); + } + f.query(s); + auto graph = f.build(); + + auto state = std::make_unique(graph); + auto nuts = NUTS(std::move(state)); + + auto start = std::chrono::high_resolution_clock::now(); + std::vector> infer_results = + nuts.infer(/* num_samples = */ num_samples, /* seed = */ seed); + auto finish = std::chrono::high_resolution_clock::now(); + auto time_in_microseconds = + std::chrono::duration_cast(finish - start) + .count(); + std::cout << fmt::format( + "minibmg NUTS: ran in {} s", time_in_microseconds / 1E6) + << std::endl; + + // check that the results are as expected + ASSERT_EQ(infer_results.size(), num_samples); + double sum = 0; + int count = 0; + for (int i = skip_samples; i < num_samples; i++) { + auto estimate = infer_results[i][0]._double; + sum += estimate; + count++; + } + + double average = sum / num_samples; + // the following is the actual computed value using bmg + double expected = + 0.46432190477921237; // num_heads / (num_heads + num_tails + 0.0); + ASSERT_NEAR(average, expected, 0.001); +} + +// Take a familiar-looking model and runs NUTS using bmg. +TEST(nuts_test, coin_flipping_bmg) { + using namespace beanmachine::graph; + using Graph = beanmachine::graph::Graph; + + Graph g; + + // auto s = expit(sample(normal(0, 100))); + auto k0 = g.add_constant(0.0); + auto k100 = g.add_constant_pos_real(100.0); + auto normal = g.add_distribution( + DistributionType::NORMAL, AtomicType::REAL, {k0, k100}); + auto sample_normal = g.add_operator(OperatorType::SAMPLE, {normal}); + // s = (expit =) 1 / (1 + exp(-sample_normal)) + auto neg_sample = g.add_operator(OperatorType::NEGATE, {sample_normal}); + auto exp = g.add_operator(OperatorType::EXP, {neg_sample}); + auto k1 = g.add_constant_pos_real(1.0); + auto denom = g.add_operator(OperatorType::ADD, {k1, exp}); + // At this point we would like to compute + // s = 1 / denom + // but BMG has no divide or reciprocal operation. So instead we compute + // s = exp(-log(denom)) + auto log_denom = g.add_operator(OperatorType::LOG, {denom}); + auto nld = g.add_operator(OperatorType::NEGATE, {log_denom}); + auto s0 = g.add_operator(OperatorType::EXP, {nld}); + auto s = g.add_operator(OperatorType::TO_PROBABILITY, {s0}); + + auto bn = + g.add_distribution(DistributionType::BERNOULLI, AtomicType::BOOLEAN, {s}); + + for (int i = 0; i < num_heads; i++) { + auto sample = g.add_operator(OperatorType::SAMPLE, {bn}); + g.observe(sample, true); + } + for (int i = 0; i < num_tails; i++) { + auto sample = g.add_operator(OperatorType::SAMPLE, {bn}); + g.observe(sample, false); + } + + g.query(s); + auto nuts = NUTS(g); + + auto start = std::chrono::high_resolution_clock::now(); + std::vector> infer_results = + nuts.infer(/* num_samples = */ num_samples, /* seed = */ seed); + auto finish = std::chrono::high_resolution_clock::now(); + auto time_in_microseconds = + std::chrono::duration_cast(finish - start) + .count(); + std::cout << fmt::format( + " bmg NUTS: ran in {} s", time_in_microseconds / 1E6) + << std::endl; + + // check that the results are as expected + ASSERT_EQ(infer_results.size(), num_samples); + double sum = 0; + int count = 0; + for (int i = skip_samples; i < num_samples; i++) { + auto estimate = infer_results[i][0]._double; + sum += estimate; + count++; + } + + double average = sum / num_samples; + // the following is the actual computed value using bmg + double expected = + 0.46432190477921237; // num_heads / (num_heads + num_tails + 0.0); + ASSERT_NEAR(average, expected, 0.001); +}