diff --git a/include/core/allocator.h b/include/core/allocator.h
index 002601d..898b0dc 100644
--- a/include/core/allocator.h
+++ b/include/core/allocator.h
@@ -27,6 +27,9 @@ namespace infini
         // TODO: a data structure may be needed to store the free blocks, for ease of management and merging
        // HINT: a map can store the free blocks, keyed by each block's start/end address, with the block size as the value
        // =================================== HOMEWORK ===================================
+        std::map<size_t, size_t> freeByAddr;        // addr -> size
+        std::multimap<size_t, size_t> freeBySize;   // size -> addr
+        std::unordered_map<size_t, size_t> live;    // addr -> size

     public:
         Allocator(Runtime runtime);
@@ -51,6 +54,18 @@ namespace infini

         void info();

+        void insertFreeBlock(size_t addr, size_t size);
+
+        void eraseFreeBlockByAddr(std::map<size_t, size_t>::iterator itAddr);
+
+        void eraseFreeBlockBySize(std::multimap<size_t, size_t>::iterator itSize);
+
+        // Overflow-checked a + b: returns false instead of wrapping around.
+        static inline bool checked_add(size_t a, size_t b, size_t &out) {
+            if (b > std::numeric_limits<size_t>::max() - a) return false;
+            out = a + b;
+            return true;
+        }
+
     private:
         // function: memory alignment, rounded up
         // return: size of the aligned memory block
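The three new members form one data structure: every free block must appear in freeByAddr (neighbor lookup when coalescing) and in freeBySize (best-fit search) at the same time, or the allocator's two views disagree. A minimal standalone sketch of that invariant; the FreeList type and its names are illustrative, not part of this patch:

    #include <cassert>
    #include <cstddef>
    #include <map>

    // Two synchronized indices over the same set of free blocks:
    // byAddr answers "who are my neighbors?", bySize answers "what fits best?".
    struct FreeList {
        std::map<std::size_t, std::size_t> byAddr;      // addr -> size
        std::multimap<std::size_t, std::size_t> bySize; // size -> addr

        void insert(std::size_t addr, std::size_t size) {
            byAddr.emplace(addr, size);
            bySize.emplace(size, addr);
        }
        void eraseByAddr(std::size_t addr) {
            auto it = byAddr.find(addr);
            assert(it != byAddr.end());
            // Several blocks may share a size, so match on the address.
            auto range = bySize.equal_range(it->second);
            for (auto s = range.first; s != range.second; ++s)
                if (s->second == addr) { bySize.erase(s); break; }
            byAddr.erase(it);
        }
    };

    int main() {
        FreeList fl;
        fl.insert(0, 64);
        fl.insert(128, 64); // same size, different address
        fl.eraseByAddr(0);
        assert(fl.bySize.size() == 1 && fl.bySize.begin()->second == 128);
    }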
diff --git a/include/core/graph.h b/include/core/graph.h
index c45580c..4941908 100644
--- a/include/core/graph.h
+++ b/include/core/graph.h
@@ -25,19 +25,8 @@ namespace infini
         Tensor addTensor(Shape dim, DataType dtype = DataType::Float32);
         Tensor addTensor(const Tensor &tensor);
         TensorVec addTensor(const TensorVec &tensors);
-        void removeOperator(Operator op)
-        {
-            auto it = std::find(ops.begin(), ops.end(), op);
-            if (it != ops.end())
-                ops.erase(it);
-        }
-
-        void removeTensor(Tensor tensor)
-        {
-            auto it = std::find(tensors.begin(), tensors.end(), tensor);
-            if (it != tensors.end())
-                tensors.erase(it);
-        }
+        void removeOperator(Operator op);
+        void removeTensor(Tensor tensor);

         const TensorVec &getTensors() const { return tensors; }
         const OpVec &getOperators() const { return ops; }
@@ -112,6 +101,14 @@ namespace infini
          */
         void addOperatorAndConnect(const Operator &op);

+        // Optimization helpers (implemented in src/core/graph.cc)
+        bool fuseConsecutiveTransposes();
+        bool foldTransposeIntoMatmul();
+        bool deadCodeEliminate();
+        void replaceInputTensor(const Operator &op, const Tensor &oldT, const Tensor &newT);
+        void rewirePredToSucc(const Operator &pred, const Operator &succ);
+        bool hasSingleUse(const Tensor &t) const;
+
         /**
          * @brief Whether the nodes are sorted in topological order.
         */
diff --git a/include/core/ref.h b/include/core/ref.h
index 3393f6e..94e79c9 100644
--- a/include/core/ref.h
+++ b/include/core/ref.h
@@ -35,8 +35,11 @@ std::vector<WRef<T>> refs_to_wrefs(const std::vector<Ref<T>> &refs) {
 template <class T>
 std::vector<Ref<T>> wrefs_to_refs(const std::vector<WRef<T>> &wrefs) {
     std::vector<Ref<T>> refs;
-    for (const auto &wref : wrefs)
-        refs.emplace_back(wref);
+    refs.reserve(wrefs.size());
+    for (const auto &wref : wrefs) {
+        if (auto p = wref.lock())
+            refs.emplace_back(std::move(p));
+    }
     return refs;
 }
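The wrefs_to_refs rewrite is a behavioral fix, not a style change: constructing a Ref directly from an expired WRef throws std::bad_weak_ptr, while lock() yields an empty pointer, so the new loop silently skips expired entries and the result may be shorter than the input. A standalone demonstration using std::shared_ptr/std::weak_ptr, which Ref/WRef alias:

    #include <cassert>
    #include <memory>
    #include <vector>

    int main() {
        std::vector<std::weak_ptr<int>> wrefs;
        auto keep = std::make_shared<int>(1);
        wrefs.push_back(keep);
        {
            auto dead = std::make_shared<int>(2);
            wrefs.push_back(dead); // expires at the end of this scope
        }
        // Old loop: std::shared_ptr<int>{wrefs[1]} would throw std::bad_weak_ptr.
        // New loop: expired entries are simply dropped.
        std::vector<std::shared_ptr<int>> refs;
        refs.reserve(wrefs.size());
        for (const auto &w : wrefs)
            if (auto p = w.lock())
                refs.push_back(std::move(p));
        assert(refs.size() == 1 && *refs[0] == 1);
    }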
diff --git a/include/operators/transpose.h b/include/operators/transpose.h
index c32bbe5..4d03d68 100644
--- a/include/operators/transpose.h
+++ b/include/operators/transpose.h
@@ -27,6 +27,7 @@ namespace infini
         int numInputs() const override { return 1; }
         int numOutputs() const override { return 1; }
         std::vector<int> getPermute() const { return transposePermute; }
+        void setPermute(const std::vector<int> &p) { transposePermute = p; }

     private:
         vector<int> transposePermute;
diff --git a/src/core/allocator.cc b/src/core/allocator.cc
index ff593ae..24d548d 100644
--- a/src/core/allocator.cc
+++ b/src/core/allocator.cc
@@ -3,6 +3,35 @@
 namespace infini
 {
+    void Allocator::insertFreeBlock(size_t addr, size_t size) {
+        IT_ASSERT(size > 0);
+        // Keep both indices in step.
+        freeByAddr.emplace(addr, size);
+        freeBySize.emplace(size, addr);
+    }
+
+    void Allocator::eraseFreeBlockByAddr(std::map<size_t, size_t>::iterator itAddr) {
+        // Mirror the erase into the size index. Several blocks may share a
+        // size, so locate the entry uniquely by its address.
+        size_t addr = itAddr->first, sz = itAddr->second;
+        auto range = freeBySize.equal_range(sz);
+        for (auto it = range.first; it != range.second; ++it) {
+            if (it->second == addr) { freeBySize.erase(it); break; }
+        }
+        freeByAddr.erase(itAddr);
+    }
+
+    void Allocator::eraseFreeBlockBySize(std::multimap<size_t, size_t>::iterator itSize) {
+        // Mirror the erase into the address index.
+        size_t sz = itSize->first, addr = itSize->second;
+        auto itAddr = freeByAddr.find(addr);
+        if (itAddr != freeByAddr.end()) {
+            IT_ASSERT(itAddr->second == sz);
+            freeByAddr.erase(itAddr);
+        }
+        freeBySize.erase(itSize);
+    }
+
     Allocator::Allocator(Runtime runtime) : runtime(runtime)
     {
         used = 0;
@@ -25,25 +54,121 @@ namespace infini

     size_t Allocator::alloc(size_t size)
     {
-        IT_ASSERT(this->ptr == nullptr);
-        // pad the size to the multiple of alignment
         size = this->getAlignedSize(size);

         // =================================== HOMEWORK ===================================
         // TODO: design an algorithm to allocate memory; return the starting offset
         // =================================== HOMEWORK ===================================
-        return 0;
-    }
+        // 1) Best fit: the smallest free block that satisfies the request.
+        auto it = freeBySize.lower_bound(size);
+        if (it != freeBySize.end()) {
+            size_t blkSize = it->first;
+            size_t addr = it->second;
+
+            // Drop the old free block from both indices.
+            eraseFreeBlockBySize(it);
+
+            // 2) Split: keep the tail remainder [addr + size, addr + blkSize).
+            if (blkSize > size) {
+                size_t tailAddr;
+                [[maybe_unused]] bool ok = checked_add(addr, size, tailAddr);
+                IT_ASSERT(ok);
+                insertFreeBlock(tailAddr, blkSize - size);
+            }
+
+            // 3) Bookkeeping: live / used / peak.
+            live.emplace(addr, size);
+            used += size;
+
+            size_t high;
+            [[maybe_unused]] bool ok2 = checked_add(addr, size, high);
+            IT_ASSERT(ok2);
+            if (high > peak) peak = high; // peak is the address high-water mark, not the peak of `used`
+
+            return addr;
+        }
+        // No free block fits: extend at the high-water mark.
+        size_t addr = peak;
+        size_t newPeak;
+        bool ok = checked_add(peak, size, newPeak);
+        IT_ASSERT(ok);
+        peak = newPeak;
+
+        live.emplace(addr, size);
+        used += size;
+        return addr;
+    }

     void Allocator::free(size_t addr, size_t size)
     {
-        IT_ASSERT(this->ptr == nullptr);
         size = getAlignedSize(size);

         // =================================== HOMEWORK ===================================
         // TODO: design an algorithm to reclaim memory
         // =================================== HOMEWORK ===================================
+        // 1) Validate and retire the live allocation.
+        auto itLive = live.find(addr);
+        IT_ASSERT(itLive != live.end());   // unknown address
+        IT_ASSERT(itLive->second == size); // size must match (trusting `live` instead is also an option)
+        live.erase(itLive);
+        used -= size;
+
+        // 2) Coalesce: treat [addr, addr + size) as a candidate block and merge
+        //    it with its address-adjacent neighbors in freeByAddr.
+        auto itInsert = freeByAddr.lower_bound(addr);
+        size_t newAddr = addr, newSize = size;
+
+        // Merge with the preceding block.
+        if (itInsert != freeByAddr.begin()) {
+            auto itPrev = std::prev(itInsert);
+            size_t prevAddr = itPrev->first, prevSize = itPrev->second;
+            size_t prevEnd;
+            bool ok = checked_add(prevAddr, prevSize, prevEnd);
+            IT_ASSERT(ok);
+            if (prevEnd == addr) {
+                eraseFreeBlockByAddr(itPrev); // size index stays in sync
+                newAddr = prevAddr;
+                newSize += prevSize;
+            }
+        }
+
+        // Merge with the following block (next == itInsert).
+        if (itInsert != freeByAddr.end()) {
+            size_t nextAddr = itInsert->first, nextSize = itInsert->second;
+            size_t thisEnd;
+            bool ok = checked_add(newAddr, newSize, thisEnd);
+            IT_ASSERT(ok);
+            if (thisEnd == nextAddr) {
+                eraseFreeBlockByAddr(itInsert); // size index stays in sync
+                newSize += nextSize;
+            }
+        }
+
+        // 3) Write the merged block back into both indices.
+        insertFreeBlock(newAddr, newSize);
+
+        // 4) Shrink the high-water mark: while the highest-addressed free block
+        //    ends exactly at `peak`, pop it and pull `peak` down to its start.
+        while (!freeByAddr.empty()) {
+            auto itTail = std::prev(freeByAddr.end());
+            size_t a = itTail->first, s = itTail->second, e;
+            bool ok = checked_add(a, s, e);
+            IT_ASSERT(ok);
+            if (e != peak) break;
+            eraseFreeBlockByAddr(itTail);
+            peak = a;
+        }
     }

     void *Allocator::getPtr()
@@ -51,7 +176,6 @@ namespace infini
     {
         if (this->ptr == nullptr)
         {
             this->ptr = runtime->alloc(this->peak);
-            printf("Allocator really alloc: %p %lu bytes\n", this->ptr, peak);
         }
         return this->ptr;
     }
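A trace of the intended alloc/free behavior, written against the Allocator API above. This is a sketch, not a test from the patch: it assumes 64 and 96 are already multiples of the alignment (so getAlignedSize is the identity) and that a Runtime handle named runtime is available.

    Allocator a(runtime);
    size_t x = a.alloc(64); // x == 0:   nothing free, extend; peak = 64
    size_t y = a.alloc(64); // y == 64:  peak = 128
    size_t z = a.alloc(64); // z == 128: peak = 192
    a.free(y, 64);          // free list: [64, 128)
    a.free(x, 64);          // coalesces with its right neighbor: [0, 128)
    size_t w = a.alloc(96); // best fit splits [0, 128): w == 0, [96, 128) stays free
    a.free(z, 64);          // [96, 128) and [128, 192) merge and reach the
                            // high-water mark, so peak shrinks to 96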
diff --git a/src/core/graph.cc b/src/core/graph.cc
index 3a90637..26aefe1 100644
--- a/src/core/graph.cc
+++ b/src/core/graph.cc
@@ -1,7 +1,13 @@
 #include "core/graph.h"
+#include "core/blob.h"
+#include "core/ref.h"
+#include "core/runtime.h"
+#include "operators/transpose.h"
+#include "operators/matmul.h"
 #include <algorithm>
 #include <numeric>
 #include <queue>
+#include <unordered_set>

 namespace infini
 {
@@ -98,6 +104,208 @@ namespace infini
         return this->sorted = true;
     }

+    inline bool isTranspose(const OperatorObj &op)
+    {
+        return op.getOpType() == OpType::Transpose;
+    }
+    inline bool isMatmul(const OperatorObj &op)
+    {
+        return op.getOpType() == OpType::MatMul;
+    }
+
+    // ====== Helpers: permutation arithmetic ======
+    inline bool isIdentityPerm(const std::vector<int> &p)
+    {
+        for (int i = 0; i < (int)p.size(); ++i)
+            if (p[i] != i)
+                return false;
+        return true;
+    }
+    inline std::vector<int> composePerm(const std::vector<int> &p2,
+                                        const std::vector<int> &p1)
+    {
+        // Value-permutation composition, "p1 first, then p2": r[i] = p2[p1[i]].
+        // Only the identity test below consumes the result, and that test is
+        // order-independent (p2∘p1 == id iff p1∘p2 == id).
+        IT_ASSERT(p1.size() == p2.size());
+        std::vector<int> r(p1.size());
+        for (int i = 0; i < (int)p1.size(); ++i)
+            r[i] = p2[p1[i]];
+        return r;
+    }
+    inline bool isSwapLastTwo(const std::vector<int> &p)
+    {
+        int n = (int)p.size();
+        if (n < 2)
+            return false;
+        for (int i = 0; i < n - 2; ++i)
+            if (p[i] != i)
+                return false;
+        return p[n - 2] == n - 1 && p[n - 1] == n - 2;
+    }
+
+    // Sort-and-unique helper; currently unused by the passes below (kept for
+    // future edge-list deduplication; an uninstantiated template costs nothing).
+    template <typename T>
+    static void uniq(std::vector<T> &v)
+    {
+        std::sort(v.begin(), v.end());
+        v.erase(std::unique(v.begin(), v.end()), v.end());
+    }
+
+    void GraphObj::replaceInputTensor(const Operator &op, const Tensor &oldT,
+                                      const Tensor &newT)
+    {
+        // 1) Swap the pointer inside op's input list (op is a Ref<OperatorObj>).
+        op->replaceInput(oldT, newT);
+        // 2) Maintain the edges: drop op from oldT's targets, add it to newT's.
+        oldT->removeTarget(op);
+        newT->addTarget(op);
+    }
+
+    void GraphObj::rewirePredToSucc(const Operator &pred, const Operator &succ)
+    {
+        pred->addSuccessors(succ);
+        succ->addPredecessors(pred);
+        // Duplicates in the weak-ref lists are left to the add/remove helpers;
+        // no explicit deduplication is done here.
+    }
+
+    bool GraphObj::hasSingleUse(const Tensor &t) const
+    {
+        return t->getTargets().size() == 1;
+    }
+
+    void GraphObj::removeOperator(Operator op)
+    {
+        // 1) Detach op from its predecessors' successor lists.
+        for (auto pre : op->getPredecessors())
+            pre->removeSuccessors(op);
+        // 2) Detach op from its successors' predecessor lists.
+        for (auto suc : op->getSuccessors())
+            suc->removePredecessors(op);
+        // 3) Clean up the inputs' target lists.
+        for (auto in : op->getInputs())
+            if (in)
+                in->removeTarget(op);
+        // 4) Clear the outputs' source pointers.
+        for (auto out : op->getOutputs())
+            if (out && out->getSource() == op)
+                out->setSource(nullptr);
+        // 5) Finally erase op from the ops vector.
+        auto it = std::find(ops.begin(), ops.end(), op);
+        if (it != ops.end())
+            ops.erase(it);
+    }
+
+    void GraphObj::removeTensor(Tensor tensor)
+    {
+        // 1) Remove the tensor from its producer's outputs
+        //    (OperatorObj::outputs is accessible because GraphObj is a friend).
+        if (auto prod = tensor->getSource())
+        {
+            for (auto &o : prod->outputs)
+                if (o == tensor)
+                    o = nullptr;
+        }
+        // 2) Remove the tensor from its consumers' inputs; the consumer->producer
+        //    predecessor links are handled when the operators themselves go away.
+        for (auto consumer : tensor->getTargets())
+            consumer->replaceInput(tensor, nullptr);
+        // 3) Finally erase the tensor from the tensors vector.
+        auto it = std::find(tensors.begin(), tensors.end(), tensor);
+        if (it != tensors.end())
+            tensors.erase(it);
+    }
+
+    // ====== Rule 1: eliminate consecutive, mutually inverse Transposes ======
+    bool GraphObj::fuseConsecutiveTransposes()
+    {
+        bool changed = false;
+        auto &opsRef = this->ops;
+        for (size_t i = 0; i < opsRef.size();)
+        {
+            auto op = opsRef[i];
+            if (!isTranspose(*op)) { ++i; continue; }
+
+            auto t1_out = op->getOutputs().front();
+            if (!hasSingleUse(t1_out)) { ++i; continue; }
+
+            auto succ = t1_out->getTargets().front();
+            if (!isTranspose(*succ)) { ++i; continue; }
+
+            auto T1 = std::static_pointer_cast<TransposeObj>(op);
+            auto T2 = std::static_pointer_cast<TransposeObj>(succ);
+
+            auto perm = composePerm(T2->getPermute(), T1->getPermute());
+            if (!isIdentityPerm(perm)) { ++i; continue; } // must advance, or this loops forever
+
+            auto inT = op->getInputs().front();
+            auto outT = succ->getOutputs().front();
+
+            // Rewire every consumer of the pair's output straight to its input.
+            auto succs = outT->getTargets();
+            for (auto &s : succs)
+            {
+                replaceInputTensor(s, outT, inT);
+                if (auto pred = inT->getSource())
+                    rewirePredToSucc(pred, s);
+            }
+            removeOperator(succ);
+            removeOperator(op);
+            removeTensor(t1_out); // both endpoints are gone; outT is left for DCE
+            changed = true;
+            // Keep i where it is: opsRef[i] is now the element after the removals.
+        }
+        return changed;
+    }
+
+    // ====== Helper: drop dangling tensors (DCE) ======
+    bool GraphObj::deadCodeEliminate()
+    {
+        bool changed = false;
+        // Remove tensors that have neither a producer nor a consumer.
+        for (auto it = tensors.begin(); it != tensors.end();)
+        {
+            auto t = *it;
+            if (t->getSource() == nullptr && t->getTargets().empty())
+            {
+                it = tensors.erase(it);
+                changed = true;
+            }
+            else
+            {
+                ++it;
+            }
+        }
+        return changed;
+    }
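A note on composePerm's convention: it composes value permutations, r[i] = p2[p1[i]]. Under ONNX Transpose semantics (out.dims[i] = in.dims[perm[i]]), the shape-level permute of T2(T1(x)) would be p1[p2[i]] instead; the only question asked above, though, is whether the pair cancels, and that is order-independent since p2∘p1 = id iff p1∘p2 = id. A standalone check:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    int main() {
        // {0, 2, 1} swaps the last two axes and is its own inverse.
        std::vector<int> p1{0, 2, 1}, p2{0, 2, 1};
        std::vector<int> r(p1.size());
        for (std::size_t i = 0; i < p1.size(); ++i)
            r[i] = p2[p1[i]]; // composePerm's convention
        assert((r == std::vector<int>{0, 1, 2})); // identity: the pair is removed

        // A rotation composed with itself is not the identity; such pairs stay.
        std::vector<int> q{1, 2, 0}, s(q.size());
        for (std::size_t i = 0; i < q.size(); ++i)
            s[i] = q[q[i]];
        assert((s == std::vector<int>{2, 0, 1}));
    }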
+
+    // ====== Rule 2: fold a last-two-axes-swap Transpose into Matmul ======
+    bool GraphObj::foldTransposeIntoMatmul()
+    {
+        bool changed = false;
+        for (size_t i = 0; i < ops.size(); ++i)
+        {
+            auto op = ops[i];
+            if (!isMatmul(*op))
+                continue;
+
+            auto mm = std::static_pointer_cast<MatmulObj>(op);
+            // Input A
+            {
+                auto A = op->getInputs()[0];
+                auto ts = A->getSource();
+                if (ts && isTranspose(*ts))
+                {
+                    auto T = std::static_pointer_cast<TransposeObj>(ts);
+                    if (isSwapLastTwo(T->getPermute()) && hasSingleUse(A))
+                    {
+                        mm->setTransA(!mm->getTransA());
+                        auto Ain = ts->getInputs().front();
+                        replaceInputTensor(op, A, Ain);
+                        if (auto pred = Ain->getSource())
+                            rewirePredToSucc(pred, op);
+                        removeOperator(ts); // also detaches the ts -> op edge
+                        removeTensor(A);
+                        changed = true;
+                    }
+                }
+            }
+            // Input B (symmetric)
+            {
+                auto B = op->getInputs()[1];
+                auto ts = B->getSource();
+                if (ts && isTranspose(*ts))
+                {
+                    auto T = std::static_pointer_cast<TransposeObj>(ts);
+                    if (isSwapLastTwo(T->getPermute()) && hasSingleUse(B))
+                    {
+                        mm->setTransB(!mm->getTransB());
+                        auto Bin = ts->getInputs().front();
+                        replaceInputTensor(op, B, Bin);
+                        if (auto pred = Bin->getSource())
+                            rewirePredToSucc(pred, op);
+                        removeOperator(ts); // also detaches the ts -> op edge
+                        removeTensor(B);
+                        changed = true;
+                    }
+                }
+            }
+        }
+        return changed;
+    }
+
@@ -106,7 +314,35 @@ namespace infini
     void GraphObj::optimize()
     {
         // =================================== HOMEWORK ===================================
         // TODO: design a set of algorithms that implement the graph-optimization rules below
         // 1. Remove redundant operators (e.g., two adjacent transpose operators that perform inverse permutations can both be deleted)
         // 2. Fuse operators (e.g., matmul has transA/transB attributes; if an input comes from a transpose that swaps the last two dimensions, the transpose can be folded into the matmul attribute)
         // =================================== HOMEWORK ===================================
-    }
+        bool changed;
+        do
+        {
+            changed = false;
+            changed |= fuseConsecutiveTransposes();
+            changed |= foldTransposeIntoMatmul();
+            // A DCE pass fits here: drop tensors left with no producer and no consumer.
+            changed |= deadCodeEliminate();
+        } while (changed);
+
+        // Compact the tensor list to the tensors still referenced by operators.
+        {
+            std::vector<Tensor> newTensors;
+            newTensors.reserve(tensors.size());
+            std::unordered_set<TensorObj *> seen;
+            for (const auto &op : ops)
+            {
+                for (const auto &t : op->getInputs())
+                    if (t && seen.insert(t.get()).second)
+                        newTensors.push_back(t);
+                for (const auto &t : op->getOutputs())
+                    if (t && seen.insert(t.get()).second)
+                        newTensors.push_back(t);
+            }
+            // Also keep graph inputs: source-less tensors no operator uses yet.
+            for (const auto &t : tensors)
+                if (t && t->getSource() == nullptr && seen.insert(t.get()).second)
+                    newTensors.push_back(t);
+            tensors = std::move(newTensors);
+        }
+    }

     Tensor GraphObj::getTensor(int fuid) const
     {
@@ -152,6 +390,21 @@ namespace infini
         // TODO: use the allocator to assign memory to the computation graph
         // HINT: after obtaining the allocated pointer, call each tensor's setDataBlob to bind the memory
         // =================================== HOMEWORK ===================================
+        // Pass 1: reserve an offset for every tensor; only then is `peak` final.
+        std::unordered_map<Ref<TensorObj>, size_t> tensorToOffset;
+        for (auto tensor : tensors)
+        {
+            tensorToOffset[tensor] = allocator.alloc(tensor->getBytes());
+        }
+        // Pass 2: getPtr() performs the single real allocation of `peak` bytes;
+        // bind every tensor at base + offset.
+        for (auto tensor : tensors)
+        {
+            tensor->setDataBlob(make_ref<BlobObj>(
+                tensor->runtime,
+                static_cast<uint8_t *>(allocator.getPtr()) + tensorToOffset[tensor]));
+        }

         allocator.info();
     }
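The two passes in dataMalloc are deliberate: getPtr() performs the one real backing allocation, sized by peak, and peak is only final after every tensor's offset has been reserved. For contrast, a sketch of the broken single-pass variant (illustrative only, not from the patch):

    for (auto tensor : tensors) {
        size_t off = allocator.alloc(tensor->getBytes());
        // BUG: the first iteration freezes the backing buffer at the current
        // peak; offsets reserved later point past the end of that buffer.
        tensor->setDataBlob(make_ref<BlobObj>(
            tensor->runtime,
            static_cast<uint8_t *>(allocator.getPtr()) + off));
    }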
diff --git a/src/core/tensor.cc b/src/core/tensor.cc
index db54a2d..da72113 100644
--- a/src/core/tensor.cc
+++ b/src/core/tensor.cc
@@ -24,8 +24,10 @@ namespace infini
               ", dtype " + dtype.toString() + ", " + runtime->toString() +
               ", " + ss.str() + "\n";
         vector<UidBaseType> targetGuids;
-        for (const auto &op : targets)
-            targetGuids.emplace_back(op.lock()->getGuid());
+        for (const auto &wop : targets) {
+            // Skip expired targets instead of dereferencing a null lock().
+            if (auto op = wop.lock())
+                targetGuids.emplace_back(op->getGuid());
+        }
         if (auto o = source.lock())
             ret += ", source " + std::to_string(o->getGuid());
         else
diff --git a/src/operators/concat.cc b/src/operators/concat.cc
index d196330..c0fc864 100644
--- a/src/operators/concat.cc
+++ b/src/operators/concat.cc
@@ -18,6 +18,16 @@ optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs)
 {
     Shape dims = inputs[0]->getDims();
     auto rank = inputs[0]->getRank();
     // REF: https://onnx.ai/onnx/operators/onnx__Concat.html#concat-13
     // =================================== HOMEWORK ===================================
+    // Sum the extents along the stored concat axis `dim`; inferring the axis by
+    // diffing two inputs would break when all inputs share the same extent on it.
+    dims[dim] = 0;
+    for (size_t j = 0; j < inputs.size(); j++)
+    {
+        IT_ASSERT(inputs[j]->getRank() == rank);
+        dims[dim] += inputs[j]->getDims()[dim];
+    }
     return {{dims}};
 }
diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc
index 7a16ca2..ca51d97 100644
--- a/src/operators/matmul.cc
+++ b/src/operators/matmul.cc
@@ -1,4 +1,5 @@
 #include "operators/matmul.h"
+#include <algorithm>

 namespace infini
 {
@@ -27,7 +28,13 @@ namespace infini
         // TODO: return the shape produced by the matmul
         // REF: https://github.com/onnx/onnx/blob/main/docs/Operators.md#gemm
         // =================================== HOMEWORK ===================================
-        return std::nullopt;
+        // Assumes both inputs have the same rank (>= 2); batch dims broadcast via max.
+        Shape ret = inputs[0]->getDims();
+        for (size_t i = 0; i < inputs[1]->getRank() - 2; ++i)
+            ret[i] = std::max(ret[i], inputs[1]->getDims()[i]);
+        ret[ret.size() - 2] = transA ? inputs[0]->getDims()[ret.size() - 1]
+                                     : inputs[0]->getDims()[ret.size() - 2];
+        ret[ret.size() - 1] = transB ? inputs[1]->getDims()[ret.size() - 2]
+                                     : inputs[1]->getDims()[ret.size() - 1];
+        return {{ret}};
     }
 } // namespace infini
\ No newline at end of file
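A worked check of the matmul shape rule above: batch dims broadcast by max, then M and N are picked according to transA/transB. The helper mirrors the inferShape logic with plain vectors:

    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <vector>

    std::vector<int> matmulShape(const std::vector<int> &a, const std::vector<int> &b,
                                 bool transA, bool transB) {
        std::vector<int> ret = a;
        const std::size_t r = ret.size();
        for (std::size_t i = 0; i + 2 < r; ++i)
            ret[i] = std::max(a[i], b[i]);         // batch dims broadcast
        ret[r - 2] = transA ? a[r - 1] : a[r - 2]; // M
        ret[r - 1] = transB ? b[r - 2] : b[r - 1]; // N
        return ret;
    }

    int main() {
        // A: [2,1,3,4], B: [2,5,4,6] -> batch [2,5], M = 3, N = 6
        assert((matmulShape({2, 1, 3, 4}, {2, 5, 4, 6}, false, false) ==
                std::vector<int>{2, 5, 3, 6}));
        // B transposed: B: [5,4] holds N = 5, K = 4
        assert((matmulShape({3, 4}, {5, 4}, false, true) ==
                std::vector<int>{3, 5}));
    }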
diff --git a/src/operators/transpose.cc b/src/operators/transpose.cc
index faab2b6..9ba2e71 100644
--- a/src/operators/transpose.cc
+++ b/src/operators/transpose.cc
@@ -7,9 +7,10 @@ namespace infini
         : OperatorObj(OpType::Transpose, {input}, {output})
     {
         auto rank = input->getRank();
+        transposePermute.resize(rank); // was previously written to while still empty
         if (permute.empty())
         {
-            for (size_t i = 0; i < rank; ++i)
+            for (int i = 0; i < (int)rank; ++i)
             {
                 transposePermute[i] = i;
             }
@@ -24,17 +25,19 @@ namespace infini
     optional<vector<Shape>> TransposeObj::inferShape(const TensorVec &inputs)
     {
-        const auto A = inputs[0];
-        auto input_dim = A->getDims();
-        auto output_dim = input_dim;
-        int rank = A->getRank();
-
         // =================================== HOMEWORK ===================================
         // TODO: fill output_dim with the correct post-transpose shape
         // REF: https://onnx.ai/onnx/operators/onnx__Transpose.html#transpose-21
         // =================================== HOMEWORK ===================================
+        int rank = static_cast<int>(inputs[0]->getRank());
+        auto input_dim = inputs[0]->getDims();
+        auto output_dim = input_dim;
+        // ONNX semantics: output.dims[i] = input.dims[perm[i]].
+        for (int i = 0; i < rank; ++i)
+        {
+            output_dim[i] = input_dim[transposePermute[i]];
+        }
+        return {{output_dim}};
     }

     std::string TransposeObj::toString() const
diff --git a/src/operators/unary.cc b/src/operators/unary.cc
index 3daad36..81e0555 100644
--- a/src/operators/unary.cc
+++ b/src/operators/unary.cc
@@ -1,4 +1,5 @@
 #include "operators/unary.h"
+#include "core/tensor.h"

 namespace infini
 {
@@ -39,7 +40,7 @@ namespace infini
         // TODO: return the shape produced by the clip
         // REF: https://onnx.ai/onnx/operators/onnx__Clip.html#clip-13
         // =================================== HOMEWORK ===================================
-        return std::nullopt;
+        return {{inputs[0]->getDims()}}; // clip is elementwise: the shape is unchanged
     }

     std::string ClipObj::toString() const
@@ -66,7 +67,8 @@ namespace infini
         // REF_FILE: src/core/operator.cc
         // REF: https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21
         // =================================== HOMEWORK ===================================
-        return {};
+        auto outputDataType = getOutputDataType();
+        return vector<DataType>(numOutputs(), outputDataType);
     }

     optional<vector<Shape>> CastObj::inferShape(const TensorVec &inputs)
@@ -75,7 +77,7 @@ namespace infini
         // TODO: return the shape produced by the cast
         // REF: https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21
         // =================================== HOMEWORK ===================================
-        return std::nullopt;
+        return {{inputs[0]->getDims()}}; // cast changes the dtype only, not the shape
     }

     std::string CastObj::toString() const
diff --git a/src/utils/operator_utils.cc b/src/utils/operator_utils.cc
index edbd2c8..efe31f4 100644
--- a/src/utils/operator_utils.cc
+++ b/src/utils/operator_utils.cc
@@ -9,8 +9,16 @@ Shape infer_broadcast(const Shape &A, const Shape &B) {
    // TODO: broadcast A and B bidirectionally; return the broadcast shape
    // REF: https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
    // =================================== HOMEWORK ===================================
-
-    return {};
+    // Start from the longer shape and overwrite its trailing dims with
+    // max(a, b); for a valid broadcast (dims equal, or one of them 1),
+    // the max is the broadcast extent.
+    Shape result = A.size() > B.size() ? A : B;
+    auto rit = result.rbegin();
+    auto a = A.rbegin(), b = B.rbegin();
+    for (; a != A.rend() && b != B.rend(); ++a, ++b, ++rit)
+        *rit = std::max(*a, *b);
+    return result;
 }

 int get_real_axis(const int &axis, const int &rank) {
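A worked check of infer_broadcast's right-aligned max rule. Note that the implementation assumes its inputs are broadcast-compatible (each aligned pair of dims is equal or one of them is 1) and performs no validation; the sketch below uses std::vector<int> in place of Shape:

    #include <algorithm>
    #include <cassert>
    #include <vector>

    std::vector<int> broadcast(const std::vector<int> &A, const std::vector<int> &B) {
        std::vector<int> r = A.size() > B.size() ? A : B;
        auto rit = r.rbegin();
        for (auto a = A.rbegin(), b = B.rbegin();
             a != A.rend() && b != B.rend(); ++a, ++b, ++rit)
            *rit = std::max(*a, *b); // right-aligned; extra leading dims survive
        return r;
    }

    int main() {
        assert((broadcast({3, 1, 5}, {4, 1}) == std::vector<int>{3, 4, 5}));
        assert((broadcast({1}, {2, 3}) == std::vector<int>{2, 3}));
        assert((broadcast({7, 2, 3}, {2, 3}) == std::vector<int>{7, 2, 3}));
    }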