ResultAssemblyModule.ini
// ResultAssembly.cpp
#include "ResultAssembly.h"

#include <iostream>

GlobalSolution::GlobalSolution(const std::vector<size_t>& dims) : dimensions(dims) {
    size_t total_size = 1;
    for (auto dim : dimensions) {
        total_size *= dim;
    }
    data.resize(total_size, 0.0f); // initialize all entries to 0
}

size_t GlobalSolution::GetGlobalIndex(const std::vector<size_t>& indices) const {
    // Row-major flattening: the last dimension varies fastest.
    size_t index = 0;
    size_t multiplier = 1;
    for (size_t i = dimensions.size(); i-- > 0; ) {
        index += indices[i] * multiplier;
        multiplier *= dimensions[i];
    }
    return index;
}
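
// Worked example: with dimensions {3, 4}, the element at indices {1, 2}
// flattens to 2 * 1 + 1 * 4 = 6, i.e. the row-major strides are {4, 1}.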

void GlobalSolution::SetData(const std::vector<size_t>& indices, float value) {
    size_t idx = GetGlobalIndex(indices);
    data[idx] = value;
}

float GlobalSolution::GetData(const std::vector<size_t>& indices) const {
    size_t idx = GetGlobalIndex(indices);
    return data[idx];
}

void MappingTable::AddMapping(int sub_id, const std::vector<size_t>& global_start) {
    sub_to_global_start[sub_id] = global_start;
}

std::vector<size_t> MappingTable::GetGlobalStart(int sub_id) const {
    // Returns an empty vector when sub_id has no registered mapping.
    auto it = sub_to_global_start.find(sub_id);
    if (it != sub_to_global_start.end()) {
        return it->second;
    }
    return {};
}

ResultAssembly::ResultAssembly(const std::vector<size_t>& global_dims)
    : global_solution_(global_dims) {}

void ResultAssembly::AddMapping(int sub_id, const std::vector<size_t>& global_start) {
    mapping_table_.AddMapping(sub_id, global_start);
}

void ResultAssembly::AggregateCIResults(const std::vector<Task>& ci_tasks) {
    std::lock_guard<std::mutex> lock(mutex_);
    for (const auto& task : ci_tasks) {
        // Look up the global start indices of this sub-tensor.
        std::vector<size_t> global_start = mapping_table_.GetGlobalStart(task.id);
        if (global_start.empty()) {
            std::cerr << "Task " << task.id << " has no corresponding global start indices." << std::endl;
            continue;
        }
        // Copy the sub-tensor data into the global solution.
        // Note: this assumes 2-D sub-tensors (sizes[0] x sizes[1]).
        for (size_t i = 0; i < task.sub_tensor.sizes[0]; ++i) {
            for (size_t j = 0; j < task.sub_tensor.sizes[1]; ++j) {
                std::vector<size_t> indices = {global_start[0] + i, global_start[1] + j};
                size_t sub_idx = i * task.sub_tensor.sizes[1] + j;
                global_solution_.SetData(indices, task.sub_tensor.h_data[sub_idx]);
            }
        }
    }
}

void ResultAssembly::AggregatePRRLUResults(const std::vector<Task>& prrlu_tasks) {
    // To be implemented as needed.
}

void ResultAssembly::HandleOverlaps() {
    // To be implemented according to the specific sub-tensor layout and overlap scheme.
}

void ResultAssembly::AssembleGlobalSolution(const std::vector<Task>& ci_tasks, const std::vector<Task>& prrlu_tasks) {
    AggregateCIResults(ci_tasks);
    AggregatePRRLUResults(prrlu_tasks);
    HandleOverlaps();
}

void ResultAssembly::ValidateGlobalSolution() {
    // Error checking and validation; for now, just report the sum of all entries.
    float total = 0.0f;
    for (auto val : global_solution_.data) {
        total += val;
    }
    std::cout << "Global solution sum: " << total << std::endl;
}

void ResultAssembly::PrintGlobalSolution() {
    // Note: this assumes a 2-D global solution.
    std::cout << "Global solution:" << std::endl;
    for (size_t i = 0; i < global_solution_.dimensions[0]; ++i) {
        for (size_t j = 0; j < global_solution_.dimensions[1]; ++j) {
            std::vector<size_t> indices = {i, j};
            float val = global_solution_.GetData(indices);
            std::cout << val << " ";
        }
        std::cout << std::endl;
    }
}
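
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only): how a caller might drive this module.
// It assumes Task exposes `id` and a `sub_tensor` with `sizes` and `h_data`,
// as used in AggregateCIResults above; the real definitions live in
// ResultAssembly.h and may differ. `BuildCITasks()` is a hypothetical helper.
//
//   ResultAssembly assembly({4, 4});            // 4 x 4 global solution
//   assembly.AddMapping(/*sub_id=*/0, {0, 0});  // sub-tensor 0 starts at (0, 0)
//   assembly.AddMapping(/*sub_id=*/1, {0, 2});  // sub-tensor 1 starts at (0, 2)
//   std::vector<Task> ci_tasks = BuildCITasks();
//   std::vector<Task> prrlu_tasks;              // empty in this sketch
//   assembly.AssembleGlobalSolution(ci_tasks, prrlu_tasks);
//   assembly.ValidateGlobalSolution();
//   assembly.PrintGlobalSolution();
// ---------------------------------------------------------------------------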