Skip to content

Commit

Permalink
[refactor] Move TypedConstants to taichi/ir/type (#2211)
Browse files Browse the repository at this point in the history
* [refactor] Move TypedConstants to taichi/ir/type

* [skip ci] enforce code format

* rm

* sort

Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
k-ye and taichi-gardener authored Mar 14, 2021
1 parent d28668b commit a33a71f
Show file tree
Hide file tree
Showing 12 changed files with 296 additions and 311 deletions.
76 changes: 38 additions & 38 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,42 +5,6 @@

TLANG_NAMESPACE_BEGIN

// Returns the ASTBuilder owned by the global frontend `context`.
// Assumes `context` is non-null -- TODO confirm where it is initialized.
ASTBuilder &current_ast_builder() {
return context->builder();
}

// Returns the innermost (top-of-stack) block being built,
// or nullptr when no block is open.
Block *ASTBuilder::current_block() {
if (stack.empty())
return nullptr;
else
return stack.back();
}

// Returns the last statement of the innermost block.
// Precondition: at least one block is open (asserted).
Stmt *ASTBuilder::get_last_stmt() {
TI_ASSERT(!stack.empty());
return stack.back()->back();
}

// Inserts |stmt| into the innermost block, taking ownership.
// |location| is forwarded to Block::insert (declared default is -1;
// presumably "append at end" -- confirm against Block::insert).
void ASTBuilder::insert(std::unique_ptr<Stmt> &&stmt, int location) {
TI_ASSERT(!stack.empty());
stack.back()->insert(std::move(stmt), location);
}

// Records |snode| in the innermost block's stop_gradients list.
// Precondition: at least one block is open (asserted).
void ASTBuilder::stop_gradient(SNode *snode) {
TI_ASSERT(!stack.empty());
stack.back()->stop_gradients.push_back(snode);
}

// Allocates a fresh Block into |list| (which must be empty on entry,
// asserted) and returns a ScopeGuard that pushes it on the builder's
// stack now and pops it when the guard is destroyed.
// If an enclosing block exists, the new block's parent_stmt is set to
// the enclosing block's last statement.
std::unique_ptr<ASTBuilder::ScopeGuard> ASTBuilder::create_scope(
std::unique_ptr<Block> &list) {
TI_ASSERT(list == nullptr);
list = std::make_unique<Block>();
if (!stack.empty()) {
list->parent_stmt = get_last_stmt();
}
return std::make_unique<ScopeGuard>(this, list.get());
}

FrontendSNodeOpStmt::FrontendSNodeOpStmt(SNodeOpType op_type,
SNode *snode,
const ExprGroup &indices,
Expand All @@ -63,8 +27,6 @@ IRNode *FrontendContext::root() {
return static_cast<IRNode *>(root_node.get());
}

std::unique_ptr<FrontendContext> context;

FrontendForStmt::FrontendForStmt(const ExprGroup &loop_var,
const Expr &global_var)
: global_var(global_var) {
Expand Down Expand Up @@ -401,4 +363,42 @@ void ExternalTensorShapeAlongAxisExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

// Returns the innermost (top-of-stack) block being built,
// or nullptr when no block is open.
Block *ASTBuilder::current_block() {
if (stack.empty())
return nullptr;
else
return stack.back();
}

// Returns the last statement of the innermost block.
// Precondition: at least one block is open (asserted).
Stmt *ASTBuilder::get_last_stmt() {
TI_ASSERT(!stack.empty());
return stack.back()->back();
}

// Inserts |stmt| into the innermost block, taking ownership.
// |location| is forwarded to Block::insert (declared default is -1;
// presumably "append at end" -- confirm against Block::insert).
void ASTBuilder::insert(std::unique_ptr<Stmt> &&stmt, int location) {
TI_ASSERT(!stack.empty());
stack.back()->insert(std::move(stmt), location);
}

// Records |snode| in the innermost block's stop_gradients list.
// Precondition: at least one block is open (asserted).
void ASTBuilder::stop_gradient(SNode *snode) {
TI_ASSERT(!stack.empty());
stack.back()->stop_gradients.push_back(snode);
}

// Allocates a fresh Block into |list| (which must be empty on entry,
// asserted) and returns a ScopeGuard that pushes it on the builder's
// stack now and pops it when the guard is destroyed.
// If an enclosing block exists, the new block's parent_stmt is set to
// the enclosing block's last statement.
std::unique_ptr<ASTBuilder::ScopeGuard> ASTBuilder::create_scope(
std::unique_ptr<Block> &list) {
TI_ASSERT(list == nullptr);
list = std::make_unique<Block>();
if (!stack.empty()) {
list->parent_stmt = get_last_stmt();
}
return std::make_unique<ScopeGuard>(this, list.get());
}

// Returns the ASTBuilder owned by the global frontend `context`.
// Assumes `context` is non-null -- TODO confirm where it is initialized.
ASTBuilder &current_ast_builder() {
return context->builder();
}

std::unique_ptr<FrontendContext> context;

TLANG_NAMESPACE_END
108 changes: 53 additions & 55 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,61 +10,6 @@

TLANG_NAMESPACE_BEGIN

class ASTBuilder;

ASTBuilder &current_ast_builder();

// Owns the per-program frontend AST state: the active ASTBuilder and
// the root block of the AST under construction.
class FrontendContext {
private:
std::unique_ptr<ASTBuilder> current_builder;
std::unique_ptr<Block> root_node;

public:
FrontendContext();

// Returns the active builder. Assumes current_builder is non-null
// (presumably set by the constructor, whose body is not visible here).
ASTBuilder &builder() {
return *current_builder;
}

IRNode *root();

// Transfers ownership of the root block to the caller; root_node is
// left empty afterwards.
std::unique_ptr<Block> get_root() {
return std::move(root_node);
}
};

extern std::unique_ptr<FrontendContext> context;

// Builds the frontend AST by keeping a stack of nested blocks;
// statements are inserted into the innermost (top-of-stack) block.
class ASTBuilder {
private:
std::vector<Block *> stack;

public:
// Seeds the stack with the root block; |initial| must outlive the builder.
ASTBuilder(Block *initial) {
stack.push_back(initial);
}

// Inserts |stmt| into the innermost block (location -1: see Block::insert).
void insert(std::unique_ptr<Stmt> &&stmt, int location = -1);

// RAII helper: pushes |list| on construction and pops on destruction,
// so the enclosing scope is restored automatically on scope exit.
struct ScopeGuard {
ASTBuilder *builder;
Block *list;
ScopeGuard(ASTBuilder *builder, Block *list)
: builder(builder), list(list) {
builder->stack.push_back(list);
}

~ScopeGuard() {
builder->stack.pop_back();
}
};

// Opens a new nested block in |list|; see the .cpp for the contract.
std::unique_ptr<ScopeGuard> create_scope(std::unique_ptr<Block> &list);
Block *current_block();
Stmt *get_last_stmt();
void stop_gradient(SNode *);
};

// Frontend Statements

class FrontendAllocaStmt : public Stmt {
Expand Down Expand Up @@ -631,4 +576,57 @@ class ExternalTensorShapeAlongAxisExpression : public Expression {
void flatten(FlattenContext *ctx) override;
};

// Builds the frontend AST by keeping a stack of nested blocks;
// statements are inserted into the innermost (top-of-stack) block.
class ASTBuilder {
private:
std::vector<Block *> stack;

public:
// Seeds the stack with the root block; |initial| must outlive the builder.
ASTBuilder(Block *initial) {
stack.push_back(initial);
}

// Inserts |stmt| into the innermost block (location -1: see Block::insert).
void insert(std::unique_ptr<Stmt> &&stmt, int location = -1);

// RAII helper: pushes |list| on construction and pops on destruction,
// so the enclosing scope is restored automatically on scope exit.
struct ScopeGuard {
ASTBuilder *builder;
Block *list;
ScopeGuard(ASTBuilder *builder, Block *list)
: builder(builder), list(list) {
builder->stack.push_back(list);
}

~ScopeGuard() {
builder->stack.pop_back();
}
};

// Opens a new nested block in |list|; see the .cpp for the contract.
std::unique_ptr<ScopeGuard> create_scope(std::unique_ptr<Block> &list);
Block *current_block();
Stmt *get_last_stmt();
void stop_gradient(SNode *);
};

ASTBuilder &current_ast_builder();

// Owns the per-program frontend AST state: the active ASTBuilder and
// the root block of the AST under construction.
class FrontendContext {
private:
std::unique_ptr<ASTBuilder> current_builder;
std::unique_ptr<Block> root_node;

public:
FrontendContext();

// Returns the active builder. Assumes current_builder is non-null
// (presumably set by the constructor, whose body is not visible here).
ASTBuilder &builder() {
return *current_builder;
}

IRNode *root();

// Transfers ownership of the root block to the caller; root_node is
// left empty afterwards.
std::unique_ptr<Block> get_root() {
return std::move(root_node);
}
};

extern std::unique_ptr<FrontendContext> context;

TLANG_NAMESPACE_END
6 changes: 1 addition & 5 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include <thread>
#include <unordered_map>

#include "taichi/ir/analysis.h"
// #include "taichi/ir/analysis.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"

Expand Down Expand Up @@ -514,8 +514,4 @@ LocalAddress::LocalAddress(Stmt *var, int offset) : var(var), offset(offset) {
TI_ASSERT(var->is<AllocaStmt>());
}

void Stmt::infer_type() {
irpass::type_check(this);
}

TLANG_NAMESPACE_END
3 changes: 0 additions & 3 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include "taichi/ir/ir_modified.h"
#include "taichi/ir/snode.h"
#include "taichi/ir/type_factory.h"
#include "taichi/llvm/llvm_fwd.h"
#include "taichi/util/short_name.h"

TLANG_NAMESPACE_BEGIN
Expand Down Expand Up @@ -586,8 +585,6 @@ class Stmt : public IRNode {
return make_typed<T>(std::forward<Args>(args)...);
}

void infer_type();

void set_tb(const std::string &tb) {
this->tb = tb;
}
Expand Down
21 changes: 4 additions & 17 deletions taichi/ir/statements.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
// TODO: gradually cppize statements.h
#include "taichi/ir/statements.h"
#include "taichi/program/program.h"
#include "taichi/util/bit.h"

TLANG_NAMESPACE_BEGIN
Expand Down Expand Up @@ -304,20 +303,8 @@ GetChStmt::GetChStmt(Stmt *input_ptr, int chid, bool is_bit_vectorized)
TI_STMT_REG_FIELDS;
}

// Convenience constructor: delegates to the (task_type, snode) overload
// with a null SNode.
OffloadedStmt::OffloadedStmt(OffloadedStmt::TaskType task_type)
: OffloadedStmt(task_type, nullptr) {
}

OffloadedStmt::OffloadedStmt(OffloadedStmt::TaskType task_type, SNode *snode)
: task_type(task_type), snode(snode) {
num_cpu_threads = 1;
const_begin = false;
const_end = false;
begin_value = 0;
end_value = 0;
step = 0;
reversed = false;
device = get_current_program().config.arch;
OffloadedStmt::OffloadedStmt(TaskType task_type, Arch arch)
: task_type(task_type), device(arch) {
if (has_body()) {
body = std::make_unique<Block>();
body->parent_stmt = this;
Expand Down Expand Up @@ -349,7 +336,8 @@ std::string OffloadedStmt::task_type_name(TaskType tt) {
}

std::unique_ptr<Stmt> OffloadedStmt::clone() const {
auto new_stmt = std::make_unique<OffloadedStmt>(task_type, snode);
auto new_stmt = std::make_unique<OffloadedStmt>(task_type, device);
new_stmt->snode = snode;
new_stmt->begin_offset = begin_offset;
new_stmt->end_offset = end_offset;
new_stmt->const_begin = const_begin;
Expand All @@ -361,7 +349,6 @@ std::unique_ptr<Stmt> OffloadedStmt::clone() const {
new_stmt->block_dim = block_dim;
new_stmt->reversed = reversed;
new_stmt->num_cpu_threads = num_cpu_threads;
new_stmt->device = device;
new_stmt->index_offsets = index_offsets;
if (tls_prologue) {
new_stmt->tls_prologue = tls_prologue->clone();
Expand Down
28 changes: 14 additions & 14 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -839,17 +839,19 @@ class OffloadedStmt : public Stmt {
using TaskType = OffloadedTaskType;

TaskType task_type;
SNode *snode;
std::size_t begin_offset;
std::size_t end_offset;
bool const_begin, const_end;
int32 begin_value, end_value;
int step;
Arch device;
SNode *snode{nullptr};
std::size_t begin_offset{0};
std::size_t end_offset{0};
bool const_begin{false};
bool const_end{false};
int32 begin_value{0};
int32 end_value{0};
int step{0};
int grid_dim{1};
int block_dim{1};
bool reversed;
int num_cpu_threads;
Arch device;
bool reversed{false};
int num_cpu_threads{1};

std::vector<int> index_offsets;

Expand All @@ -862,9 +864,7 @@ class OffloadedStmt : public Stmt {
std::size_t bls_size{0};
MemoryAccessOptions mem_access_opt;

OffloadedStmt(TaskType task_type);

OffloadedStmt(TaskType task_type, SNode *snode);
OffloadedStmt(TaskType task_type, Arch arch);

std::string task_name() const;

Expand All @@ -882,8 +882,9 @@ class OffloadedStmt : public Stmt {

void all_blocks_accept(IRVisitor *visitor);

TI_STMT_DEF_FIELDS(ret_type,
TI_STMT_DEF_FIELDS(ret_type /*inherited from Stmt*/,
task_type,
device,
snode,
begin_offset,
end_offset,
Expand All @@ -896,7 +897,6 @@ class OffloadedStmt : public Stmt {
block_dim,
reversed,
num_cpu_threads,
device,
index_offsets,
mem_access_opt);
TI_DEFINE_ACCEPT
Expand Down
Loading

0 comments on commit a33a71f

Please sign in to comment.