diff --git a/Makefile b/Makefile index 35ef7ef..f032aaa 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,8 @@ .PHONY : build clean format install-python test-cpp test-onnx -TYPE ?= Release + + +TYPE ?= Debug TEST ?= ON CMAKE_OPT = -DCMAKE_BUILD_TYPE=$(TYPE) diff --git a/README.md b/README.md index 7bbac1c..f62cd19 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,9 @@ # TinyInfiniTensor + + 一个简化版的 ai compiler,用于初学者快速上手学习,保留了计算图和 kernel 层的概念,能够基于 c++ 搭建计算图进行推理计算,目前只支持 cpu 平台。 [环境部署文档](docs/项目部署.md) -[训练营作业介绍文档](docs/训练营作业介绍.md) \ No newline at end of file +[训练营作业介绍文档](docs/训练营作业介绍.md) diff --git a/include/core/allocator.h b/include/core/allocator.h index 002601d..0df7e2f 100644 --- a/include/core/allocator.h +++ b/include/core/allocator.h @@ -27,7 +27,8 @@ namespace infini { // TODO:可能需要设计一个数据结构来存储free block,以便于管理和合并 // HINT: 可以使用一个 map 来存储 free block,key 为 block 的起始/结尾地址,value 为 block 的大小 // =================================== 作业 =================================== - + std::map freeBlocks; + public: Allocator(Runtime runtime); diff --git a/report.xml b/report.xml new file mode 100644 index 0000000..c584263 --- /dev/null +++ b/report.xml @@ -0,0 +1,10 @@ + + + + + + + + diff --git a/src/core/allocator.cc b/src/core/allocator.cc index ff593ae..c06a10d 100644 --- a/src/core/allocator.cc +++ b/src/core/allocator.cc @@ -32,18 +32,67 @@ namespace infini // =================================== 作业 =================================== // TODO: 设计一个算法来分配内存,返回起始地址偏移量 // =================================== 作业 =================================== - + if (freeBlocks.empty()) + { + freeBlocks[0] = 4096; // Initially, all memory is free + } + for(auto it = freeBlocks.begin(); it != freeBlocks.end(); it ++){ + auto [addr, blockSize] = *it; + if(blockSize >= size){ //blockSize 是可用空间 + if(blockSize > size){ + // Split the block if it's larger than requested size + freeBlocks[addr + size] = blockSize - size; + } + freeBlocks.erase(it); + used += size; + peak = std::max(peak, used); + return it->first; + } + } + return 0; + + + + // if (this->freeBlocks.empty()) + // this->freeBlocks[0] = 1024; + // for (auto it = this->freeBlocks.begin(); it != this->freeBlocks.end(); ++it) + // { + // if (it->second >= size) + // { + // if (it->second > size) + // this->freeBlocks[it->first + size] = it->second - size; + // auto ans = it->first; + // this->freeBlocks.erase(it); + // this->used += size; + // this->peak = (this->peak >= this->used) ? this->peak : this->used; + // return ans; + // } + // } } void Allocator::free(size_t addr, size_t size) { IT_ASSERT(this->ptr == nullptr); size = getAlignedSize(size); - // =================================== 作业 =================================== // TODO: 设计一个算法来回收内存 - // =================================== 作业 =================================== + // =================================== 作业 =================================== + freeBlocks[addr] = size; + auto it = freeBlocks.find(addr); + auto nextIt = std::next(it); + if (nextIt != freeBlocks.end() && it->first + it->second == nextIt->first) + { + it->second += nextIt->second; + freeBlocks.erase(nextIt); + } + auto prevIt = std::prev(it); + if (it != freeBlocks.begin() && prevIt->first + prevIt->second == it->first) + { + prevIt->second += it->second; + freeBlocks.erase(it); + } + used = used - size; } void *Allocator::getPtr() diff --git a/src/core/graph.cc b/src/core/graph.cc index 3a90637..4219772 100644 --- a/src/core/graph.cc +++ b/src/core/graph.cc @@ -2,7 +2,8 @@ #include #include #include - +#include "operators/matmul.h" +#include "operators/transpose.h" namespace infini { @@ -106,6 +107,163 @@ namespace infini // 1. 去除冗余的算子(例如,两个相邻的算子都是 transpose 算子,且做的是相反的操作,可以将其全部删除) // 2. 合并算子(例如,矩阵乘算子中含有属性transA、transB,如果其输入存在transpose,且对最后两个维度做交换,就可以将transpose融入到矩阵乘算子的属性中去) // =================================== 作业 =================================== + + // rule1: 删除无用的transpose算子 + for (size_t i = 0; i < ops.size(); ++i) + { + Operator op = ops[i]; + if (op->getOpType() == OpType::Transpose) + { + Tensor tensor = op->getOutput(); + if (!tensor) + continue; + auto targets = tensor->getTargets(); + if (targets.empty()) + continue; + Operator op_next = targets[0]; + if (op_next->getOpType() == OpType::Transpose) + { + TransposeObj *op1 = as(op).get(); + TransposeObj *op2 = as(op_next).get(); + auto op1_permute = op1->getPermute(); + auto op2_permute = op2->getPermute(); + if (op1_permute.size() != op2_permute.size()) + continue; + bool flag = true; + for (int j = 0; j < (int)op1_permute.size(); j++) + { + if (op1_permute[op2_permute[j]] != j) + { + flag = false; + continue; + } + } + if (!flag) //flag为false说明 无法合并 + continue; + // 获取第一个转置算子的输入张量(原始输入数据) + Tensor originalInput = op->getInputs()[0]; + + // 获取第一个转置算子的输出张量(第一次转置结果) + Tensor firstTransposeOutput = op->getOutput(); + + // 获取第二个转置算子的输出张量(最终转置结果) + Tensor secondTransposeOutput = op_next->getOutput(); + + // 获取使用最终结果的消费者算子(如矩阵乘法) + Operator consumerOp = secondTransposeOutput->getTargets()[0]; + + // 保留消费者算子的其他输入(如矩阵乘法的右矩阵) + Tensor consumerOtherInput = consumerOp->getInputs()[1]; + + // 重定向消费者算子的输入:跳过两个转置,直接使用原始输入 + consumerOp->replaceInput(consumerOp->getInputs()[0], originalInput); + + // 更新原始输入的连接关系: + originalInput->removeTarget(op); // 移除对第一个转置的引用 + originalInput->addTarget(consumerOp); // 添加对消费者算子的引用 + originalInput->setSource(nullptr); // 清除可能存在的生产者标记 + + // 清理冗余资源 + removeOperator(op); // 删除第一个转置算子 + removeOperator(op_next); // 删除第二个转置算子 + removeTensor(firstTransposeOutput); // 删除中间结果张量 + removeTensor(secondTransposeOutput); // 删除最终结果张量 + + // 更新算子间的拓扑依赖关系 + consumerOp->removePredecessors(op_next); // 移除与第二个转置的依赖 + + // 如果原始输入有生产者,建立新的依赖关系 + if (originalInput->getSource()) { + consumerOp->addPredecessors(originalInput->getSource()); + originalInput->getSource()->addSuccessors(consumerOp); + } + } + } + } + + // 遍历图中的所有算子,寻找可优化的矩阵乘法算子 + for (size_t opIndex = 0; opIndex < ops.size(); ++opIndex) { + Operator currentOp = ops[opIndex]; + + // 只处理矩阵乘法算子 + if (currentOp->getOpType() == OpType::MatMul) { + // 获取矩阵乘法的输入张量列表(左矩阵和右矩阵) + TensorVec matmulInputs = currentOp->getInputs(); + int inputIndex = 0; // 用于标识当前是左输入(0)还是右输入(1) + + // 检查每个输入张量 + for (Tensor inputTensor : matmulInputs) { + inputIndex++; + + // 检查输入张量是否有生产者算子 + if (inputTensor->getSource()) { + Operator producerOp = inputTensor->getSource(); + + // 如果生产者是转置算子 + if (producerOp->getOpType() == OpType::Transpose) { + TransposeObj *transposeOp = as(producerOp).get(); + Shape transposePerm = transposeOp->getPermute(); + bool isLastTwoDimsSwap = true; + + /* 验证转置操作是否只交换最后两个维度: + * 1. 前n-2个维度必须保持原顺序(即perm[j] == j) + * 2. 最后两个维度必须交换(即perm[-2] == rank-1 且 perm[-1] == rank-2) + */ + for (int dim = 0; dim < (int)transposePerm.size() - 2; dim++) { + if (transposePerm[dim] != dim) { + isLastTwoDimsSwap = false; + break; + } + } + if (transposePerm[transposePerm.size() - 2] != (int)transposePerm.size() - 1 || + transposePerm[transposePerm.size() - 1] != (int)transposePerm.size() - 2) { + isLastTwoDimsSwap = false; + } + + // 如果不满足条件则跳过优化 + if (!isLastTwoDimsSwap) continue; + + // 获取矩阵乘法算子(用于修改转置属性) + MatmulObj *matmulOp = as(currentOp).get(); + Tensor transposedTensor; + + // 根据输入位置设置对应的转置标志 + if (inputIndex == 1) { // 左输入 + matmulOp->setTransA(true); // 启用左矩阵转置 + transposedTensor = matmulOp->getInputs(0); + } else { // 右输入 + matmulOp->setTransB(true); // 启用右矩阵转置 + transposedTensor = matmulOp->getInputs(1); + } + + // 获取转置算子的输入(原始未转置的张量) + Operator transposeOperator = transposedTensor->getSource(); + Tensor originalTensor = transposeOperator->getInputs()[0]; + + // 重定向矩阵乘法的输入:跳过转置算子,直接使用原始张量 + matmulOp->replaceInput(transposedTensor, originalTensor); + + // 更新张量连接关系 + originalTensor->removeTarget(transposeOperator); + originalTensor->addTarget(currentOp); + + // 清理资源:删除转置算子和中间张量 + removeOperator(transposeOperator); + removeTensor(transposedTensor); + + // 更新拓扑关系:移除转置算子作为前驱 + currentOp->removePredecessors(transposeOperator); + + // 如果原始张量有生产者,建立新的依赖关系 + if (originalTensor->getSource()) { + currentOp->addPredecessors(originalTensor->getSource()); + originalTensor->getSource()->addSuccessors(currentOp); + } + } + } + } + } +} } Tensor GraphObj::getTensor(int fuid) const @@ -152,7 +310,40 @@ namespace infini // TODO:利用 allocator 给计算图分配内存 // HINT: 获取分配好的内存指针后,可以调用 tensor 的 setDataBlob 函数给 tensor 绑定内存 // =================================== 作业 =================================== + // allocator.info(); + // void* allocatorPtr = allocator.getPtr(); + // for(auto it = tensors.begin(); it != tensors.end(); it++){ + // auto tensor = *it; + // size_t size = tensor->getBytes(); + // size_t addr = allocator.alloc(size); + // char * tmpPtr = reinterpret_cast(allocatorPtr) + addr; + // Blob blob = make_ref(runtime, (void *)tmpPtr); + // tensor->setDataBlob(blob); + // } + // topological sorting first + IT_ASSERT(topo_sort() == true); + // =================================== 作业 =================================== + // TODO:利用 allocator 给计算图分配内存 + // HINT: 获取分配好的内存指针后,可以调用 tensor 的 setDataBlob 函数给 tensor 绑定内存 + // =================================== 作业 =================================== + vector offsets; + for (auto tensor : tensors) + { + size_t size = tensor->getBytes(); + size_t offset = allocator.alloc(size); + offsets.push_back(offset); + } + auto it = offsets.begin(); + void *basePtr = allocator.getPtr(); + for (auto tensor : tensors) + { + char *charPtr = reinterpret_cast(basePtr) + *it; + void *ptr = charPtr; + Blob blob = make_ref(runtime, ptr); + tensor->setDataBlob(blob); + it++; + } allocator.info(); } diff --git a/src/operators/concat.cc b/src/operators/concat.cc index d196330..0e8f85b 100644 --- a/src/operators/concat.cc +++ b/src/operators/concat.cc @@ -10,15 +10,30 @@ ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int _dim) } optional> ConcatObj::inferShape(const TensorVec &inputs) { - Shape dims = inputs[0]->getDims(); + Shape dims = inputs[0]->getDims(); // 数组的 shape auto rank = inputs[0]->getRank(); - // =================================== 作业 =================================== // TODO:修改 dims,返回正确的 concat 后的 shape // REF: https://onnx.ai/onnx/operators/onnx__Concat.html#concat-13 // =================================== 作业 =================================== - - return {{dims}}; + if(inputs.size() == 0) { + return std::nullopt; + } + for(auto input: inputs){ + if(input->getDims().size() != rank) + return std::nullopt; + } + vector res(rank, 0); + for(auto input: inputs){ + for(size_t i = 0; i < rank; i++){ + if(i == size_t(dim)){ + res[i] += input->getDims()[i]; + }else if (i != size_t(dim)){ + res[i] = input->getDims()[i]; + } + } + } + return {{res}}; } std::string ConcatObj::toString() const { diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc index 7a16ca2..7f903a4 100644 --- a/src/operators/matmul.cc +++ b/src/operators/matmul.cc @@ -27,7 +27,71 @@ namespace infini // TODO:返回经过 matmul 操作后的 shape // REF: https://github.com/onnx/onnx/blob/main/docs/Operators.md#gemm // =================================== 作业 =================================== - return std::nullopt; + // 检查输入数量 + if (inputs.size() != 2) { + return std::nullopt; + } + + const auto &A = inputs[0]; + const auto &B = inputs[1]; + const auto &dimsA = A->getDims(); + const auto &dimsB = B->getDims(); + + // 检查维度数量至少为2 + if (dimsA.size() < 2 || dimsB.size() < 2) { + return std::nullopt; + } + + // 获取最后两个维度(矩阵维度) + size_t rankA = dimsA.size(); + size_t rankB = dimsB.size(); + size_t M = dimsA[rankA - 2]; + size_t K_A = dimsA[rankA - 1]; + size_t K_B = dimsB[rankB - 2]; + size_t N = dimsB[rankB - 1]; + + // 考虑转置 + if (transA) { + std::swap(M, K_A); + } + if (transB) { + std::swap(K_B, N); + } + + // 检查K维度是否匹配 + if (K_A != K_B) { + return std::nullopt; + } + + // 计算输出形状 + vector outputShapes; + + // 处理广播维度 + size_t broadcastRank = std::max(rankA, rankB); + Shape broadcastDims(broadcastRank - 2); + + for (size_t i = 0; i < broadcastRank - 2; ++i) { + size_t dimA = (i < rankA - 2) ? dimsA[i] : 1; + size_t dimB = (i < rankB - 2) ? dimsB[i] : 1; + + if (dimA == dimB) { + broadcastDims[i] = dimA; + } else if (dimA == 1) { + broadcastDims[i] = dimB; + } else if (dimB == 1) { + broadcastDims[i] = dimA; + } else { + // 不兼容的广播维度 + return std::nullopt; + } + } + // 构建最终输出形状 + Shape outputShape = broadcastDims; + outputShape.push_back(M); + outputShape.push_back(N); + + outputShapes.push_back(outputShape); + return outputShapes; } } // namespace infini \ No newline at end of file diff --git a/src/operators/transpose.cc b/src/operators/transpose.cc index faab2b6..83dff41 100644 --- a/src/operators/transpose.cc +++ b/src/operators/transpose.cc @@ -25,7 +25,7 @@ namespace infini optional> TransposeObj::inferShape(const TensorVec &inputs) { const auto A = inputs[0]; - auto input_dim = A->getDims(); + auto input_dim = A->getDims(); // vector auto output_dim = input_dim; int rank = A->getRank(); @@ -33,8 +33,12 @@ namespace infini // TODO:修改 output_dim,返回正确的 transpose 后的 shape // REF: https://onnx.ai/onnx/operators/onnx__Transpose.html#transpose-21 // =================================== 作业 =================================== - - return std::nullopt; + for(int i = 0; i < rank; ++i) + { + output_dim[i] = input_dim[transposePermute[i]]; + } + + return vector{output_dim}; } std::string TransposeObj::toString() const diff --git a/src/operators/unary.cc b/src/operators/unary.cc index 3daad36..c60bb3f 100644 --- a/src/operators/unary.cc +++ b/src/operators/unary.cc @@ -39,7 +39,12 @@ namespace infini // TODO:返回经过 clip 操作后的 shape // REF: https://onnx.ai/onnx/operators/onnx__Clip.html#clip-13 // =================================== 作业 =================================== - return std::nullopt; + if(inputs.empty()){ + return std::nullopt; + } + const auto A = inputs[0]; + auto input_dim = A->getDims(); // vector + return {{input_dim}}; } std::string ClipObj::toString() const @@ -66,7 +71,11 @@ namespace infini // REF_FILE: src/core/operator.cc // REF: https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21 // =================================== 作业 =================================== - return {}; + if(inputs.empty()){ + return {}; + } + + return {getOutputDataType()}; } optional> CastObj::inferShape(const TensorVec &inputs) @@ -75,7 +84,10 @@ namespace infini // TODO:返回经过 cast 操作后的 shape // REF: https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21 // =================================== 作业 =================================== - return std::nullopt; + if(inputs.empty()){ + return std::nullopt; + } + return {{inputs[0]->getDims()}}; } std::string CastObj::toString() const diff --git a/src/utils/operator_utils.cc b/src/utils/operator_utils.cc index edbd2c8..f7c43b8 100644 --- a/src/utils/operator_utils.cc +++ b/src/utils/operator_utils.cc @@ -10,7 +10,32 @@ Shape infer_broadcast(const Shape &A, const Shape &B) { // REF: https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md // =================================== 作业 =================================== - return {}; + // 获取输入张量的维度 + size_t rankA = A.size(); + size_t rankB = B.size(); + size_t maxRank = std::max(rankA, rankB); + Shape result(maxRank); + + // 从右向左逐维度比较 + for (size_t i = 0; i < maxRank; ++i) { + // 获取当前维度(从右向左) + size_t dimA = (i < rankA) ? A[rankA - 1 - i] : 1; + size_t dimB = (i < rankB) ? B[rankB - 1 - i] : 1; + + // 检查兼容性 + if (dimA == dimB) { + result[maxRank - 1 - i] = dimA; + } else if (dimA == 1) { + result[maxRank - 1 - i] = dimB; + } else if (dimB == 1) { + result[maxRank - 1 - i] = dimA; + } else { + // 不兼容,返回空 Shape + return {}; + } + } + + return result; } int get_real_axis(const int &axis, const int &rank) { diff --git a/test/core/test_allocator.cc b/test/core/test_allocator.cc index 0515edc..455902f 100644 --- a/test/core/test_allocator.cc +++ b/test/core/test_allocator.cc @@ -9,8 +9,8 @@ namespace infini { TEST(Allocator, testAlloc) { - Shape shape = Shape{1, 2, 2, 3}; - Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Shape shape = Shape{1, 2, 2, 3}; // Shape <--> vector + Runtime runtime = NativeCpuRuntimeObj::getInstance(); Tensor a = make_ref(shape, DataType::Float32, runtime); Tensor b = make_ref(shape, DataType::Float32, runtime); Tensor c = make_ref(shape, DataType::Float32, runtime); @@ -35,8 +35,7 @@ namespace infini Tensor a = make_ref(shape, DataType::Float32, runtime); Tensor b = make_ref(shape, DataType::Float32, runtime); Tensor c = make_ref(shape, DataType::Float32, runtime); - Tensor d = - make_ref(Shape{2, 2, 2, 3}, DataType::Float32, runtime); + Tensor d = make_ref(Shape{2, 2, 2, 3}, DataType::Float32, runtime); Allocator allocator = Allocator(runtime); // allocate a->b->c allocator.alloc(a->getBytes()); diff --git a/test/core/test_graph.cc b/test/core/test_graph.cc index bf696dd..afed62d 100644 --- a/test/core/test_graph.cc +++ b/test/core/test_graph.cc @@ -19,7 +19,7 @@ namespace infini Tensor t3 = g->addTensor({2, 3, 5, 4}, DataType::UInt32); Tensor o = g->addTensor({2, 3, 4, 4}, DataType::UInt32); g->addOpWithOutputs(i1, t1, Shape{0, 1, 3, 2}); - g->addOpWithOutputs(t1, t2, Shape{0, 1, 3, 2}); + g->addOpWithOutputs(t1, t2, Shape{0, 1, 3, 2}); g->addOpWithOutputs(i2, t3, Shape{0, 1, 3, 2}); g->addOpWithOutputs(t2, t3, o); // 优化前