forked from Preparation-Publication-BD2K/db_compress
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_learner_test.cpp
162 lines (147 loc) · 5.42 KB
/
model_learner_test.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#include "base.h"
#include "model.h"
#include "model_learner.h"
#include <cmath>
#include <vector>
#include <iostream>
namespace db_compress {
// Mock cost table: maps an encoded (predictor list, target) key, as built by
// GetCost(), to the cost reported for that model. Filled in by PrepareData().
std::map<int, int> model_cost;
// Configuration shared by the tests. PrepareData() sizes allowed_err and each
// test sets sort_by_attr before constructing its ModelLearner.
CompressionConfig config;
// Schema shared by the tests; assigned in PrepareData().
Schema schema;
// Returns the mock cost of modelling attribute `target` from the predictors
// listed in `pred`.
//
// The lookup key is built by appending (pred[0] + 1), (pred[1] + 1), ...,
// (target + 1) as successive decimal digits; e.g. pred = {1, 2}, target = 0
// gives the key 231. Combinations that are missing from model_cost (or stored
// with a cost of 0) yield the sentinel cost 1000000.
//
// The table is probed with find() instead of operator[] so that querying an
// unknown combination does not insert a zero entry into model_cost.
inline int GetCost(const std::vector<size_t>& pred, int target) {
    const int kUnknownCost = 1000000;
    int index = 0;
    for (size_t i = 0; i < pred.size(); ++i)
        index = static_cast<int>(index * 10 + pred[i] + 1);
    index = index * 10 + target + 1;
    const auto it = model_cost.find(index);
    if (it == model_cost.end() || it->second == 0)
        return kUnknownCost;
    return it->second;
}
// Minimal AttrValue for the tests: wraps a single int that is fixed at
// construction and read back through Value().
class MockAttr : public AttrValue {
  public:
    MockAttr(int val) : val_(val) {}

    // The integer supplied to the constructor.
    int Value() const { return val_; }

  private:
    int val_;
};
// Stub ProbTree used by MockModel. It reports one pending branch:
// HasNextBranch() is true until GenerateNextBranch() has been called.
// Branch selection is a no-op and the result attribute is always MockAttr(0).
class MockTree : public ProbTree {
  public:
    MockTree() : first_step_(true), attr_(0) {}

    // True until GenerateNextBranch() has been called.
    bool HasNextBranch() const { return first_step_; }

    // Marks the branch as generated; the flag is never set back to true.
    void GenerateNextBranch() { first_step_ = false; }

    // Always picks branch 0, whatever attribute is passed in.
    int GetNextBranch(const AttrValue* /*attr*/) const { return 0; }

    // Intentionally does nothing.
    void ChooseNextBranch(int /*branch*/) {}

    // Points at the tree's own MockAttr, which holds the value 0.
    const AttrValue* GetResultAttr() { return &attr_; }

  private:
    bool first_step_;
    MockAttr attr_;
};
// Model stub for the ModelLearner tests. It remembers the attribute values of
// the most recently fed tuple (exposed through Check()) and reports its cost
// from the model_cost table via GetCost().
class MockModel : public Model {
  private:
    MockTree tree_;
    // Values of attributes 0, 1 and 2 of the last tuple given to FeedTuple().
    // Zero-initialised in the constructor so that Check() is well defined even
    // if no tuple has been fed yet.
    int a_, b_, c_;

  public:
    MockModel(const std::vector<size_t>& pred, size_t target)
        : Model(pred, target), a_(0), b_(0), c_(0) {}

    // Always hands back the same embedded MockTree, regardless of the tuple.
    ProbTree* GetProbTree(const Tuple& tuple) { return &tree_; }

    // Records the first three attribute values of `tuple`.
    // NOTE(review): assumes the tuple has at least three attributes and that
    // each one is a MockAttr, as in the tests in this file.
    void FeedTuple(const Tuple& tuple) {
        a_ = static_cast<const MockAttr*>(tuple.attr[0])->Value();
        b_ = static_cast<const MockAttr*>(tuple.attr[1])->Value();
        c_ = static_cast<const MockAttr*>(tuple.attr[2])->Value();
    }

    // Packs the recorded values as a_ * 100 + b_ * 10 + c_ so a test can
    // compare all three with a single integer.
    int Check() const { return a_ * 100 + b_ * 10 + c_; }

    int GetModelDescriptionLength() const { return 0; }

    // Nothing is written for the mock.
    void WriteModel(ByteWriter* byte_writer, size_t block_index) const {}

    // Cost looked up from the mock table for this model's predictors/target.
    int GetModelCost() const {
        return GetCost(predictor_list_, target_var_);
    }
};
// ModelCreator stub that builds MockModel instances, but only for
// predictor/target combinations that have a real entry in the mock cost table.
class MockModelCreator : public ModelCreator {
  public:
    // Reading models back is not exercised by this test; always returns NULL.
    Model* ReadModel(ByteReader* /*byte_reader*/, const Schema& /*schema*/,
                     size_t /*index*/) {
        return NULL;
    }

    // Returns a newly allocated MockModel for (pred, index), or NULL when
    // GetCost() reports the 1000000 "unknown combination" sentinel.
    Model* CreateModel(const Schema& /*schema*/, const std::vector<size_t>& pred,
                       size_t index, double /*err*/) {
        const bool has_cost = GetCost(pred, index) != 1000000;
        if (has_cost)
            return new MockModel(pred, index);
        return NULL;
    }
};
// Convenience wrapper: treats `model` as a MockModel and returns its Check()
// value. The cast is unchecked, so `model` must really be a MockModel.
inline int Check(Model* model) {
    MockModel* mock = static_cast<MockModel*>(model);
    return mock->Check();
}
// Sets up the shared fixtures: fills the mock cost table, builds a schema of
// three attributes of type 0, registers MockModelCreator for attribute type 0
// and sizes config.allowed_err for three attributes.
void PrepareData() {
    // {encoded key, cost} pairs; see GetCost() for how keys are encoded.
    static const int kCosts[][2] = {
        {1, 9},    // target 0, no predictors
        {2, 8},    // target 1, no predictors
        {3, 7},    // target 2, no predictors
        {21, 4},   // target 0 from {1}
        {231, 2},  // target 0 from {1, 2}
        {12, 12},  // target 1 from {0}
        {23, 5},   // target 2 from {1}
        {32, 5},   // target 1 from {2}
    };
    const size_t num_costs = sizeof(kCosts) / sizeof(kCosts[0]);
    for (size_t i = 0; i < num_costs; ++i)
        model_cost[kCosts[i][0]] = kCosts[i][1];

    std::vector<int> attr_types(3, 0);
    schema = Schema(attr_types);
    RegisterAttrModel(0, new MockModelCreator());
    config.allowed_err.resize(3);
}
// Runs the ModelLearner with attribute 0 as the sort-by attribute and checks
// the learned attribute order and the three resulting models against the
// values expected for the mock cost table. Failures are reported on stderr.
void TestWithPrimaryAttr() {
    config.sort_by_attr = 0;
    ModelLearner learner(schema, config);
    MockAttr attr(1);
    Tuple tuple(3);
    // All three attributes share the same MockAttr holding the value 1.
    tuple.attr[0] = tuple.attr[1] = tuple.attr[2] = &attr;
    // Feed the single tuple once per pass until the learner reports that no
    // further iterations are required.
    while (1) {
        learner.FeedTuple(tuple);
        learner.EndOfData();
        if (!learner.RequireMoreIterations())
            break;
    }
    std::vector<size_t> attr_vec = learner.GetOrderOfAttributes();
    // Check the size first so a short result is reported instead of being
    // indexed out of bounds.
    if (attr_vec.size() != 3 || attr_vec[0] != 0 || attr_vec[1] != 2 || attr_vec[2] != 1)
        std::cerr << "Model Learner w/ Primary Attr Unit Test Failed!\n";
    std::unique_ptr<Model> a(learner.GetModel(0));
    std::unique_ptr<Model> b(learner.GetModel(1));
    std::unique_ptr<Model> c(learner.GetModel(2));
    // A missing model counts as a failure rather than being dereferenced.
    if (!a || a->GetTargetVar() != 0 || a->GetPredictorList().size() != 0 ||
        Check(a.get()) != 111)
        std::cerr << "Model Learner w/ Primary Attr Unit Test Failed!\n";
    if (!b || b->GetTargetVar() != 1 || b->GetPredictorList().size() != 1 ||
        b->GetPredictorList()[0] != 2 || Check(b.get()) != 10)
        std::cerr << "Model Learner w/ Primary Attr Unit Test Failed!\n";
    if (!c || c->GetTargetVar() != 2 || c->GetPredictorList().size() != 0 ||
        Check(c.get()) != 111)
        std::cerr << "Model Learner w/ Primary Attr Unit Test Failed!\n";
}
// Runs the ModelLearner with no sort-by attribute (sort_by_attr = -1) and
// checks the learned attribute order and the three resulting models against
// the values expected for the mock cost table. Failures go to stderr.
void TestWithoutPrimaryAttr() {
    config.sort_by_attr = -1;
    ModelLearner learner(schema, config);
    MockAttr attr(1);
    Tuple tuple(3);
    // All three attributes share the same MockAttr holding the value 1.
    tuple.attr[0] = tuple.attr[1] = tuple.attr[2] = &attr;
    // Feed the single tuple once per pass until the learner reports that no
    // further iterations are required.
    while (1) {
        learner.FeedTuple(tuple);
        learner.EndOfData();
        if (!learner.RequireMoreIterations())
            break;
    }
    std::vector<size_t> attr_vec = learner.GetOrderOfAttributes();
    // Check the size first so a short result is reported instead of being
    // indexed out of bounds.
    if (attr_vec.size() != 3 || attr_vec[0] != 2 || attr_vec[1] != 1 || attr_vec[2] != 0)
        std::cerr << "Model Learner w/o Primary Attr Unit Test Failed!\n";
    std::unique_ptr<Model> a(learner.GetModel(0));
    std::unique_ptr<Model> b(learner.GetModel(1));
    std::unique_ptr<Model> c(learner.GetModel(2));
    // A missing model counts as a failure rather than being dereferenced.
    if (!a || a->GetTargetVar() != 0 || a->GetPredictorList().size() != 2 ||
        a->GetPredictorList()[0] != 1 || a->GetPredictorList()[1] != 2 ||
        Check(a.get()) != 100)
        std::cerr << "Model Learner w/o Primary Attr Unit Test Failed!\n";
    if (!b || b->GetTargetVar() != 1 || b->GetPredictorList().size() != 1 ||
        b->GetPredictorList()[0] != 2 || Check(b.get()) != 110)
        std::cerr << "Model Learner w/o Primary Attr Unit Test Failed!\n";
    if (!c || c->GetTargetVar() != 2 || c->GetPredictorList().size() != 0 ||
        Check(c.get()) != 111)
        std::cerr << "Model Learner w/o Primary Attr Unit Test Failed!\n";
}
// Test driver: sets up the shared fixtures once, then runs the learner test
// with a sort-by attribute followed by the one without.
void Test() {
// Must run first: both tests below use the globals it populates.
PrepareData();
TestWithPrimaryAttr();
TestWithoutPrimaryAttr();
}
} // namespace db_compress
// Entry point: runs the ModelLearner unit tests. Failures are only reported
// on stderr by the tests themselves; the exit status is always 0.
int main() {
    db_compress::Test();
    return 0;
}