forked from sorbet/sorbet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathInitializer.cc
108 lines (93 loc) · 3.53 KB
/
Initializer.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#include "rewriter/Initializer.h"
#include "ast/Helpers.h"
#include "ast/ast.h"
#include "core/core.h"
#include "rewriter/Util.h"
using namespace std;
namespace sorbet::rewriter {
namespace {

// Sig types written as `T.type_parameter(...)` cannot currently be used inside a method body, so they must
// not be copied into a `T.let`. Returns false exactly when `typeExpr` is a send of `type_parameter`.
//
// TODO: remove once https://github.com/sorbet/sorbet/issues/1715 is fixed
bool isCopyableType(const ast::ExpressionPtr &typeExpr) {
    const auto *maybeSend = ast::cast_tree<ast::Send>(typeExpr);
    return maybeSend == nullptr || maybeSend->fun != core::Names::typeParameter();
}

// If `expr` is an assignment `@ivar = local` whose `local` has an entry in `argTypeMap` (and that type is
// copyable), rewrite it in place to `@ivar = T.let(local, <type>)`. Any other statement is left untouched.
//
// `ctx` is not used by this helper.
// `argTypeMap` maps a local (parameter) name to a pointer at its type expression; the type is deep-copied,
// never moved, so the map's targets are not modified.
void maybeAddLet(core::MutableContext ctx, ast::ExpressionPtr &expr,
                 const UnorderedMap<core::NameRef, const ast::ExpressionPtr *> &argTypeMap) {
    auto *assign = ast::cast_tree<ast::Assign>(expr);
    if (assign == nullptr) {
        return;
    }

    // The target must be an instance variable...
    auto *target = ast::cast_tree<ast::UnresolvedIdent>(assign->lhs);
    if (target == nullptr || target->kind != ast::UnresolvedIdent::Kind::Instance) {
        return;
    }

    // ...and the value must be a bare local variable.
    auto *source = ast::cast_tree<ast::UnresolvedIdent>(assign->rhs);
    if (source == nullptr || source->kind != ast::UnresolvedIdent::Kind::Local) {
        return;
    }

    auto found = argTypeMap.find(source->name);
    if (found == argTypeMap.end()) {
        return;
    }
    const ast::ExpressionPtr &sigType = *found->second;
    if (!isCopyableType(sigType)) {
        return;
    }

    // Read the location before `assign->rhs` is moved out, because `source` points into that node.
    auto letLoc = source->loc;
    auto wrapped = ast::MK::Let(letLoc, move(assign->rhs), sigType.deepCopy());
    assign->rhs = move(wrapped);
}

// Follow the receiver chain of sends in a `sig` block body (e.g. `params(...).returns(...)`), starting at
// `send`, and return the first send named `params`. Returns nullptr if `send` is null or no such send exists.
const ast::Send *findParams(const ast::Send *send) {
    for (auto *cur = send; cur != nullptr; cur = ast::cast_tree<ast::Send>(cur->recv)) {
        if (cur->fun == core::Names::params()) {
            return cur;
        }
    }
    return nullptr;
}

} // namespace
// Rewrites an `initialize` method that has a `sig`.
//
// Every top-level statement of the form `@ivar = local`, where `local` names a keyword entry of the sig's
// `params(...)`, becomes `@ivar = T.let(local, <that entry's type>)`. Type parameters are excluded (see
// isCopyableType). Non-`initialize` methods, and methods without a recognizable sig, are left unchanged.
//
// @param ctx       context for the file being rewritten; used to read symbol literals in the sig
// @param methodDef the method definition; its body (`rhs`) is modified in place
// @param prevStat  the statement the caller considers to precede `methodDef`; may be null. Nothing happens
//                  unless it is a `sig` with a block.
void Initializer::run(core::MutableContext ctx, ast::MethodDef *methodDef, ast::ExpressionPtr *prevStat) {
    // this should only run in an `initialize` that has a sig
    if (methodDef->name != core::Names::initialize()) {
        return;
    }
    if (prevStat == nullptr) {
        return;
    }

    // make sure that the `sig` block looks like a valid sig block
    auto *sig = ASTUtil::castSig(*prevStat);
    if (sig == nullptr) {
        return;
    }
    auto *block = ast::cast_tree<ast::Block>(sig->block);
    if (block == nullptr) {
        return;
    }

    // walk the send chain in the sig body to find the `params(...)` invocation
    auto *params = findParams(ast::cast_tree<ast::Send>(block->body));
    if (params == nullptr) {
        return;
    }

    // Build a lookup table mapping parameter names to their sig types. Keyword arguments are stored flattened
    // in `params->args` as alternating key/value entries within [kwStart, kwEnd).
    auto [kwStart, kwEnd] = params->kwArgsRange();
    UnorderedMap<core::NameRef, const ast::ExpressionPtr *> argTypeMap;
    for (int i = kwStart; i < kwEnd; i += 2) {
        // cast_tree yields nullptr when the key is not a Literal. Skip such keys instead of dereferencing null.
        auto *argName = ast::cast_tree<ast::Literal>(params->args[i]);
        if (argName == nullptr || !argName->isSymbol(ctx)) {
            continue;
        }
        // Store a pointer to the type expression. It lives in the sig, which this pass does not modify.
        argTypeMap[argName->asSymbol(ctx)] = &params->args[i + 1];
    }

    // look through the body's top-level statements for ones of the form `@var = local`
    if (auto stmts = ast::cast_tree<ast::InsSeq>(methodDef->rhs)) {
        for (auto &s : stmts->stats) {
            maybeAddLet(ctx, s, argTypeMap);
        }
        maybeAddLet(ctx, stmts->expr, argTypeMap);
    } else {
        maybeAddLet(ctx, methodDef->rhs, argTypeMap);
    }
}
} // namespace sorbet::rewriter