diff --git a/minibmg/ad/reverse.h b/minibmg/ad/reverse.h index 1536460bac..a7d03234f5 100644 --- a/minibmg/ad/reverse.h +++ b/minibmg/ad/reverse.h @@ -49,12 +49,14 @@ requires Number class ReverseBody { public: Underlying primal; - std::list> inputs; + std::vector> inputs; Underlying adjoint = 0; /* implicit */ ReverseBody(double primal); /* implicit */ ReverseBody(Underlying primal); - ReverseBody(Underlying primal, const std::list>& inputs); + ReverseBody( + Underlying primal, + const std::vector>& inputs); ReverseBody(const ReverseBody& other); ReverseBody& operator=(const ReverseBody& other); virtual ~ReverseBody() {} @@ -95,17 +97,17 @@ requires Number Reverse template class PrecomputedGradients : public ReverseBody { public: - const std::list computed_partial_derivatives; + const std::vector computed_partial_derivatives; PrecomputedGradients( Underlying primal, - const std::list>& inputs, - const std::list& computed_partial_derivatives) + const std::vector>& inputs, + const std::vector& computed_partial_derivatives) : ReverseBody{primal, inputs}, computed_partial_derivatives{computed_partial_derivatives} {} void backprop() override { - auto& /*std::list>*/ t = this->inputs; - typename std::list>::iterator it1 = t.begin(); - typename std::list::const_iterator it2 = + auto& /*std::vector>*/ t = this->inputs; + typename std::vector>::iterator it1 = t.begin(); + typename std::vector::const_iterator it2 = computed_partial_derivatives.begin(); for (; it1 != t.end() && it2 != computed_partial_derivatives.end(); ++it1, ++it2) { @@ -132,7 +134,7 @@ requires Number void Reverse::reverse(double initial_adjoint) { // topologically sort the set of ReverseBody pointers reachable - these are // the nodes to which we need to backprop. - std::list roots = {ptr}; + std::vector roots = {ptr}; std::function(const Bodyp&)> predecessors = [&](const Bodyp& ptr) -> std::vector { std::vector result; @@ -172,7 +174,7 @@ requires Number ReverseBody::ReverseBody( template requires Number ReverseBody::ReverseBody( Underlying primal, - const std::list>& inputs) + const std::vector>& inputs) : primal{primal}, inputs{inputs} {} template @@ -185,8 +187,8 @@ operator+(const Reverse& left, const Reverse& right) { Underlying new_primal = left.ptr->primal + right.ptr->primal; return Reverse{std::make_shared>( new_primal, - std::list>{left, right}, - std::list{1, 1})}; + std::vector>{left, right}, + std::vector{1, 1})}; } template @@ -194,8 +196,8 @@ requires Number Reverse operator-(const Reverse& left, const Reverse& right) { return Reverse{std::make_shared>( left.ptr->primal - right.ptr->primal, - std::list>{left, right}, - std::list{1, -1})}; + std::vector>{left, right}, + std::vector{1, -1})}; } template @@ -203,8 +205,8 @@ requires Number Reverse operator-(const Reverse& x) { return Reverse{std::make_shared>( -x.ptr->primal, - std::list>{x}, - std::list{-1})}; + std::vector>{x}, + std::vector{-1})}; } template @@ -212,8 +214,8 @@ requires Number Reverse operator*(const Reverse& left, const Reverse& right) { return Reverse{std::make_shared>( left.ptr->primal * right.ptr->primal, - std::list>{left, right}, - std::list{right.ptr->primal, left.ptr->primal})}; + std::vector>{left, right}, + std::vector{right.ptr->primal, left.ptr->primal})}; } template @@ -224,8 +226,8 @@ operator/(const Reverse& left, const Reverse& right) { return Reverse{std::make_shared>( new_primal, - std::list>{left, right}, - std::list{ + std::vector>{left, right}, + std::vector{ 1 / right.ptr->primal, -new_primal / right.ptr->primal})}; } @@ -255,8 +257,8 @@ requires Number Reverse pow( .derivative1; return Reverse{std::make_shared>( new_primal, - std::list>{base, exponent}, - std::list{grad1, grad2})}; + std::vector>{base, exponent}, + std::vector{grad1, grad2})}; } template @@ -265,8 +267,8 @@ requires Number Reverse exp( Underlying new_primal = exp(x.ptr->primal); return Reverse{std::make_shared>( new_primal, - std::list>{x}, - std::list{new_primal})}; + std::vector>{x}, + std::vector{new_primal})}; } template @@ -275,8 +277,8 @@ requires Number Reverse log( Underlying new_primal = log(x.ptr->primal); return Reverse{std::make_shared>( new_primal, - std::list>{x}, - std::list{1 / x.ptr->primal})}; + std::vector>{x}, + std::vector{1 / x.ptr->primal})}; } template @@ -286,8 +288,8 @@ requires Number Reverse atan( Underlying new_derivative1 = 1 / (x.ptr->primal * x.ptr->primal + 1.0f); return Reverse{std::make_shared>( new_primal, - std::list>{x}, - std::list{new_derivative1})}; + std::vector>{x}, + std::vector{new_derivative1})}; } template @@ -297,8 +299,8 @@ requires Number Reverse lgamma( Underlying new_derivative1 = polygamma(0, x.ptr->primal); return Reverse{std::make_shared>( new_primal, - std::list>{x}, - std::list{new_derivative1})}; + std::vector>{x}, + std::vector{new_derivative1})}; } template @@ -309,8 +311,8 @@ requires Number Reverse polygamma( Underlying new_derivative1 = polygamma(n + 1, x.ptr->primal); return Reverse{std::make_shared>( new_primal, - std::list>{x}, - std::list{new_derivative1})}; + std::vector>{x}, + std::vector{new_derivative1})}; } template @@ -321,8 +323,8 @@ requires Number Reverse log1p( Underlying new_derivative1 = 1.0 / (x_value + 1); return Reverse{std::make_shared>( new_primal, - std::list>{x}, - std::list{new_derivative1})}; + std::vector>{x}, + std::vector{new_derivative1})}; } template @@ -340,8 +342,8 @@ requires Number Reverse if_equal( when_not_equal.ptr->primal); return Reverse{std::make_shared>( new_primal, - std::list>{when_equal, when_not_equal}, - std::list{ + std::vector>{when_equal, when_not_equal}, + std::vector{ if_equal(value.ptr->primal, comparand.ptr->primal, 1, 0), if_equal(value.ptr->primal, comparand.ptr->primal, 0, 1)})}; } @@ -361,8 +363,8 @@ requires Number Reverse if_less( when_not_less.ptr->primal); return Reverse{std::make_shared>( new_primal, - std::list>{when_less, when_not_less}, - std::list{ + std::vector>{when_less, when_not_less}, + std::vector{ if_less(value.ptr->primal, comparand.ptr->primal, 1, 0), if_less(value.ptr->primal, comparand.ptr->primal, 0, 1)})}; } diff --git a/minibmg/dedup.h b/minibmg/dedup.h index 307afdff8d..5334c49123 100644 --- a/minibmg/dedup.h +++ b/minibmg/dedup.h @@ -8,7 +8,6 @@ #pragma once #include -#include #include #include #include diff --git a/minibmg/fluid_factory.h b/minibmg/fluid_factory.h index 62baeeebc2..a2c71331ef 100644 --- a/minibmg/fluid_factory.h +++ b/minibmg/fluid_factory.h @@ -39,7 +39,7 @@ class Graph::FluidFactory { private: std::vector queries; - std::list> observations; + std::vector> observations; }; } // namespace beanmachine::minibmg diff --git a/minibmg/graph.cpp b/minibmg/graph.cpp index a74e2b6493..3b4bb1824b 100644 --- a/minibmg/graph.cpp +++ b/minibmg/graph.cpp @@ -6,7 +6,6 @@ */ #include "beanmachine/minibmg/graph.h" -#include #include #include #include @@ -19,8 +18,8 @@ using namespace beanmachine::minibmg; const std::vector roots( const std::vector& queries, - const std::list>& observations) { - std::list roots; + const std::vector>& observations) { + std::vector roots; for (auto& n : queries) { roots.push_back(n); } @@ -28,7 +27,7 @@ const std::vector roots( if (!std::dynamic_pointer_cast(p.first)) { throw std::invalid_argument(fmt::format("can only observe a sample")); } - roots.push_front(p.first); + roots.push_back(p.first); } std::vector all_nodes; if (!topological_sort(roots, &in_nodes, all_nodes)) { @@ -40,7 +39,7 @@ const std::vector roots( struct QueriesAndObservations { std::vector queries; - std::list> observations; + std::vector> observations; ~QueriesAndObservations() {} }; @@ -65,7 +64,7 @@ class NodeRewriteAdapter { const QueriesAndObservations& qo, const std::unordered_map& map) const { NodeRewriteAdapter> h1{}; - NodeRewriteAdapter>> h2{}; + NodeRewriteAdapter>> h2{}; return QueriesAndObservations{ h1.rewrite(qo.queries, map), h2.rewrite(qo.observations, map)}; } @@ -75,7 +74,7 @@ using dynamic = folly::dynamic; Graph Graph::create( const std::vector& queries, - const std::list>& observations, + const std::vector>& observations, std::unordered_map* built_map) { for (auto& p : observations) { if (!std::dynamic_pointer_cast(p.first)) { @@ -95,7 +94,7 @@ Graph::~Graph() {} Graph::Graph( const std::vector& nodes, const std::vector& queries, - const std::list>& observations) + const std::vector>& observations) : nodes{nodes}, queries{queries}, observations{observations} {} } // namespace beanmachine::minibmg diff --git a/minibmg/graph.h b/minibmg/graph.h index 2d01ef4017..e871bf7620 100644 --- a/minibmg/graph.h +++ b/minibmg/graph.h @@ -8,7 +8,7 @@ #pragma once #include -#include +#include #include "beanmachine/minibmg/dedup.h" #include "beanmachine/minibmg/graph_properties/container.h" #include "beanmachine/minibmg/node.h" @@ -27,7 +27,7 @@ class Graph : public Container { static Graph create( const std::vector& queries, - const std::list>& observations, + const std::vector>& observations, std::unordered_map* built_map = nullptr); ~Graph(); @@ -55,7 +55,7 @@ class Graph : public Container { // Observations of the model. These are SAMPLE nodes in the model whose // values are known. - const std::list> observations; + const std::vector> observations; private: // A private constructor that forms a graph without validation. @@ -63,7 +63,7 @@ class Graph : public Container { Graph( const std::vector& nodes, const std::vector& queries, - const std::list>& observations); + const std::vector>& observations); public: // A factory for making graphs, like the bmg API used by Beanstalk @@ -99,7 +99,7 @@ class NodeRewriteAdapter { Graph rewrite(const Graph& qo, const std::unordered_map& map) const { NodeRewriteAdapter> h1{}; - NodeRewriteAdapter>> h2{}; + NodeRewriteAdapter>> h2{}; return Graph::create( h1.rewrite(qo.queries, map), h2.rewrite(qo.observations, map)); } diff --git a/minibmg/graph_factory.cpp b/minibmg/graph_factory.cpp index a03d9d4ec4..0cc8a7c057 100644 --- a/minibmg/graph_factory.cpp +++ b/minibmg/graph_factory.cpp @@ -10,7 +10,6 @@ #include #include #include "beanmachine/minibmg/node.h" -#include "node.h" namespace beanmachine::minibmg { diff --git a/minibmg/graph_factory.h b/minibmg/graph_factory.h index ce5960c3d3..ed3ec32461 100644 --- a/minibmg/graph_factory.h +++ b/minibmg/graph_factory.h @@ -8,7 +8,6 @@ #pragma once #include -#include #include #include #include @@ -123,7 +122,7 @@ class Graph::Factory { std::unordered_map identifer_to_node; std::unordered_map node_to_identifier; std::vector queries; - std::list> observations; + std::vector> observations; unsigned long next_identifier = 0; ScalarNodeId add_node(ScalarNodep node); diff --git a/minibmg/graph_properties/observations_by_node.h b/minibmg/graph_properties/observations_by_node.h index 2e19b8c0c1..39ddf0b177 100644 --- a/minibmg/graph_properties/observations_by_node.h +++ b/minibmg/graph_properties/observations_by_node.h @@ -7,7 +7,6 @@ #pragma once -#include #include #include "beanmachine/minibmg/graph.h" #include "beanmachine/minibmg/node.h" diff --git a/minibmg/graph_properties/out_nodes.cpp b/minibmg/graph_properties/out_nodes.cpp index b741a977c8..b5da96bd33 100644 --- a/minibmg/graph_properties/out_nodes.cpp +++ b/minibmg/graph_properties/out_nodes.cpp @@ -7,7 +7,6 @@ #include "beanmachine/minibmg/graph_properties/out_nodes.h" #include -#include #include namespace { @@ -17,13 +16,13 @@ using namespace beanmachine::minibmg; class OutNodesProperty : public Property< OutNodesProperty, Graph, - std::map>> { + std::map>> { public: - std::map>* create(const Graph& g) const override { - std::map>* data = - new std::map>{}; + std::map>* create(const Graph& g) const override { + std::map>* data = + new std::map>{}; for (auto node : g) { - (*data)[node] = std::list{}; + (*data)[node] = std::vector{}; for (auto in_node : in_nodes(node)) { auto& predecessor_out_set = (*data)[in_node]; predecessor_out_set.push_back(node); @@ -38,8 +37,8 @@ class OutNodesProperty : public Property< namespace beanmachine::minibmg { -const std::list& out_nodes(const Graph& graph, Nodep node) { - std::map>& node_map = *OutNodesProperty::get(graph); +const std::vector& out_nodes(const Graph& graph, Nodep node) { + std::map>& node_map = *OutNodesProperty::get(graph); auto found = node_map.find(node); if (found == node_map.end()) { throw std::invalid_argument("node not in graph"); diff --git a/minibmg/graph_properties/out_nodes.h b/minibmg/graph_properties/out_nodes.h index 7d693929e2..137463315e 100644 --- a/minibmg/graph_properties/out_nodes.h +++ b/minibmg/graph_properties/out_nodes.h @@ -7,8 +7,8 @@ #pragma once -#include #include +#include #include "beanmachine/minibmg/graph.h" #include "beanmachine/minibmg/node.h" @@ -16,6 +16,6 @@ namespace beanmachine::minibmg { // return the set of nodes that have the given node as an input in the given // graph. -const std::list& out_nodes(const Graph& graph, Nodep node); +const std::vector& out_nodes(const Graph& graph, Nodep node); } // namespace beanmachine::minibmg diff --git a/minibmg/json.cpp b/minibmg/json.cpp index 0a687dcf44..fbd9a40f15 100644 --- a/minibmg/json.cpp +++ b/minibmg/json.cpp @@ -431,7 +431,7 @@ Graph json_to_graph2(folly::dynamic d) { } } - std::list> observations; + std::vector> observations; auto observation_nodes = d["observations"]; if (observation_nodes.isArray()) { for (auto& obs : observation_nodes) { diff --git a/minibmg/rewrite_adapter.h b/minibmg/rewrite_adapter.h index 6df3ebfc13..95fdfe733c 100644 --- a/minibmg/rewrite_adapter.h +++ b/minibmg/rewrite_adapter.h @@ -7,7 +7,6 @@ #pragma once -#include #include #include #include "beanmachine/minibmg/ad/real.h" @@ -108,33 +107,6 @@ class NodeRewriteAdapter> { }; static_assert(Rewritable>); -// A list can be deduplicated -template -requires Rewritable -class NodeRewriteAdapter> { - NodeRewriteAdapter t_helper{}; - - public: - std::vector find_roots(const std::list& roots) const { - std::vector result; - for (const auto& root : roots) { - auto more_roots = t_helper.find_roots(root); - result.push_back(more_roots.begin(), more_roots.end()); - } - return result; - } - std::list rewrite( - const std::list& roots, - const std::unordered_map& map) const { - std::list result; - for (const auto& root : roots) { - result.push_back(t_helper.rewrite(root, map)); - } - return result; - } -}; -static_assert(Rewritable>); - // A pair can be deduplicated template requires Rewritable && Rewritable diff --git a/minibmg/tests/ad/num3_test.cpp b/minibmg/tests/ad/num3_test.cpp index 5d60daa4f6..60f543103c 100644 --- a/minibmg/tests/ad/num3_test.cpp +++ b/minibmg/tests/ad/num3_test.cpp @@ -9,7 +9,6 @@ #include #include #include -#include #include #include "beanmachine/minibmg/ad/num2.h" diff --git a/minibmg/tests/graph_properties/out_nodes_test.cpp b/minibmg/tests/graph_properties/out_nodes_test.cpp index c58deb4335..7243a95f49 100644 --- a/minibmg/tests/graph_properties/out_nodes_test.cpp +++ b/minibmg/tests/graph_properties/out_nodes_test.cpp @@ -6,7 +6,7 @@ */ #include -#include +#include #include "beanmachine/minibmg/graph.h" #include "beanmachine/minibmg/graph_factory.h" #include "beanmachine/minibmg/graph_properties/out_nodes.h" @@ -33,15 +33,15 @@ TEST(out_nodes_test, simple) { Nodep betan = gf[beta]; Nodep samplen = gf[sample]; - ASSERT_EQ(out_nodes(g, k12n), std::list{plusn}); - ASSERT_EQ(out_nodes(g, k34n), (std::list{plusn})); - ASSERT_EQ(out_nodes(g, plusn), std::list{betan}); - ASSERT_EQ(out_nodes(g, k56n), std::list{betan}); - ASSERT_EQ(out_nodes(g, betan), (std::list{samplen})); - ASSERT_EQ(out_nodes(g, samplen), std::list{}); + ASSERT_EQ(out_nodes(g, k12n), std::vector{plusn}); + ASSERT_EQ(out_nodes(g, k34n), (std::vector{plusn})); + ASSERT_EQ(out_nodes(g, plusn), std::vector{betan}); + ASSERT_EQ(out_nodes(g, k56n), std::vector{betan}); + ASSERT_EQ(out_nodes(g, betan), (std::vector{samplen})); + ASSERT_EQ(out_nodes(g, samplen), std::vector{}); ASSERT_EQ(g.queries, std::vector{samplen}); - std::list> expected_observations; + std::vector> expected_observations; expected_observations.push_back(std::pair{samplen, 7.8}); ASSERT_EQ(g.observations, expected_observations); } diff --git a/minibmg/tests/localopt_test.cpp b/minibmg/tests/localopt_test.cpp index 2a981c8cfc..358d744bb7 100644 --- a/minibmg/tests/localopt_test.cpp +++ b/minibmg/tests/localopt_test.cpp @@ -83,16 +83,16 @@ TEST(localopt_test, symbolic_derivatives) { printed << "d2 = "; printed << print_result.code[d2] << std::endl; - auto expected = R"(auto temp_1 = -pow(rvid.constrained, -2); + auto expected = R"(auto temp_1 = log(rvid.constrained); auto temp_2 = 1 - rvid.constrained; -auto temp_3 = -pow(temp_2, -2); +auto temp_3 = log(temp_2); auto temp_4 = 1 / rvid.constrained; auto temp_5 = -1 / temp_2; -auto temp_6 = log(rvid.constrained); -auto temp_7 = log(temp_2); -log_prob = temp_6 + temp_7 + 1.791759469228055 + temp_6 + temp_6 + temp_7 +auto temp_6 = -pow(rvid.constrained, -2); +auto temp_7 = -pow(temp_2, -2); +log_prob = temp_1 + temp_3 + 1.791759469228055 + temp_1 + temp_1 + temp_3 d1 = temp_4 + temp_5 + temp_4 + temp_4 + temp_5 -d2 = temp_1 + temp_3 + temp_1 + temp_1 + temp_3 +d2 = temp_6 + temp_7 + temp_6 + temp_6 + temp_7 )"; ASSERT_EQ(expected, printed.str()); } diff --git a/minibmg/tests/minibmg_test.cpp b/minibmg/tests/minibmg_test.cpp index d0d4540501..7979f6570f 100644 --- a/minibmg/tests/minibmg_test.cpp +++ b/minibmg/tests/minibmg_test.cpp @@ -77,10 +77,10 @@ TEST(test_minibmg, dedupable_concept) { ASSERT_FALSE(Rewritable>>); ASSERT_TRUE((Rewritable>)); ASSERT_FALSE((Rewritable>)); - ASSERT_TRUE(Rewritable>); - ASSERT_TRUE(Rewritable>>); - ASSERT_FALSE(Rewritable>); - ASSERT_FALSE(Rewritable>>); + ASSERT_TRUE(Rewritable>); + ASSERT_TRUE(Rewritable>>); + ASSERT_FALSE(Rewritable>); + ASSERT_FALSE(Rewritable>>); } TEST(test_minibmg, graph_factory_nodeid_equality) { diff --git a/minibmg/tests/topological_test.cpp b/minibmg/tests/topological_test.cpp index 23bfe6ea8c..2b89090122 100644 --- a/minibmg/tests/topological_test.cpp +++ b/minibmg/tests/topological_test.cpp @@ -80,7 +80,7 @@ TEST(topological_test, ensure_sorted) { // topologically sort them. std::vector result; auto sorted = topological_sort( - std::list{nodes.begin(), nodes.end()}, + nodes, [](Node* node) { return std::vector{ node->successors.begin(), node->successors.end()}; @@ -105,7 +105,7 @@ TEST(topological_test, ensure_sorted) { // topologically sort them. if there was a cycle, this should return false. result.clear(); sorted = topological_sort( - std::list{nodes.begin(), nodes.end()}, + nodes, [](Node* const& node) { return std::vector{ node->successors.begin(), node->successors.end()}; diff --git a/minibmg/topological.h b/minibmg/topological.h index 4722e0dcfd..52692369a0 100644 --- a/minibmg/topological.h +++ b/minibmg/topological.h @@ -8,9 +8,9 @@ #pragma once #include -#include #include #include +#include namespace { @@ -18,15 +18,24 @@ namespace { // given. template std::map count_predecessors_internal( - const std::list& root_nodes, + const std::vector& root_nodes, std::function(const T&)> successors, - std::list& nodes) { + std::vector& nodes, + bool include_roots = false) { std::map predecessor_counts; - std::list to_count; + std::vector to_count; std::set counted; for (const auto& node : root_nodes) { to_count.push_back(node); + if (include_roots) { + if (!predecessor_counts.contains(node)) { + predecessor_counts[node] = 1; + } else { + predecessor_counts[node] = predecessor_counts[node] + 1; + } + } } + std::reverse(to_count.begin(), to_count.end()); while (!to_count.empty()) { auto node = to_count.back(); @@ -64,10 +73,10 @@ template bool topological_sort_internal( std::map& predecessor_counts, std::function(const T&)> successors, - std::list& nodes, + std::vector& nodes, std::vector& result) { // initialize the ready set with those nodes that have no predecessors - std::list ready; + std::vector ready; for (auto node : nodes) { if (predecessor_counts[node] == 0) { ready.push_back(node); @@ -104,9 +113,9 @@ namespace beanmachine::minibmg { // given. template std::map count_predecessors( - const std::list& root_nodes, + const std::vector& root_nodes, std::function(const T&)> successors) { - std::list ready; + std::vector ready; return count_predecessors_internal(root_nodes, successors, ready); } @@ -116,10 +125,10 @@ std::map count_predecessors( // sorted result in the `result` parameter. template bool topological_sort( - const std::list& root_nodes, + const std::vector& root_nodes, std::function(const T&)> successors, std::vector& result) { - std::list ready; + std::vector ready; // count the predecessors of each node. std::map predecessor_counts = count_predecessors_internal(root_nodes, successors, ready);