forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtest_subgraph_utils.cpp
41 lines (32 loc) · 1.15 KB
/
test_subgraph_utils.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include "test/cpp/jit/test_base.h"
#include "test/cpp/jit/test_utils.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/passes/utils/subgraph_utils.h"
namespace torch {
namespace jit {

// Round-trip test for SubgraphUtils: fold every top-level node of the graph
// built by build_lstm() into a single prim::DifferentiableGraph node, unmerge
// it again, and check that (after CSE on both sides) the graph has the same
// number of top-level nodes it started with.
void testSubgraphUtils() {
  auto graph = build_lstm();
  EliminateCommonSubexpression(graph);

  // Snapshot of the starting node list. Only its size is used below; merging
  // presumably moves/destroys these nodes, so the stored pointers must not be
  // dereferenced after the merge loop.
  std::vector<Node*> originalNodes(
      graph->nodes().begin(), graph->nodes().end());

  // Merge everything into a single subgraph, walking backwards from the last
  // node. The first node visited seeds the subgraph; every later node is
  // merged into it. After each step we resume from the node immediately
  // preceding the subgraph node, re-checking the loop bound every time so a
  // single-node graph never dereferences rend().
  Node* subgraph = nullptr;
  for (auto it = graph->nodes().rbegin(); it != graph->nodes().rend();) {
    if (subgraph == nullptr) {
      subgraph = SubgraphUtils::createSingletonSubgraph(
          *it, prim::DifferentiableGraph);
    } else {
      SubgraphUtils::mergeNodeIntoSubgraph(*it, subgraph);
    }
    it = ++subgraph->reverseIterator();
  }

  // Unmerge and compare with original node listing. An empty graph never
  // creates a subgraph, so there is nothing to unmerge in that case.
  if (subgraph != nullptr) {
    SubgraphUtils::unmergeSubgraph(subgraph);
  }
  EliminateCommonSubexpression(graph);
  std::vector<Node*> newNodes(graph->nodes().begin(), graph->nodes().end());
  ASSERT_EQ(originalNodes.size(), newNodes.size());
}

} // namespace jit
} // namespace torch