Skip to content

Commit c5be7c7

Browse files
authored
Port a test from LFortran (#2782)
1 parent 542300f commit c5be7c7

File tree

4 files changed

+59
-7
lines changed

4 files changed

+59
-7
lines changed

integration_tests/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,7 @@ RUN(NAME lambda_01 LABELS cpython llvm llvm_jit)
835835
RUN(NAME c_mangling LABELS cpython llvm llvm_jit c)
836836
RUN(NAME class_01 LABELS cpython llvm llvm_jit)
837837
RUN(NAME class_02 LABELS cpython llvm llvm_jit)
838+
RUN(NAME class_03 LABELS cpython llvm llvm_jit)
838839

839840
# callback_04 is to test emulation. So just run with cpython
840841
RUN(NAME callback_04 IMPORT_PATH .. LABELS cpython)

integration_tests/class_02.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ def __init__(self:"Character", name:str, health:i32, attack_power:i32):
66
self.attack_power : i32 = attack_power
77
self.is_immortal : bool = False
88

9-
def attack(self:"Character", other:"Character") -> str:
9+
def attack(self:"Character", other:"Character")->str:
1010
other.health -= self.attack_power
1111
return self.name+" attacks "+ other.name+" for "+str(self.attack_power)+" damage."
1212

@@ -41,4 +41,3 @@ def main():
4141
assert hero.is_alive() == True
4242

4343
main()
44-

integration_tests/class_03.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from lpython import f64
2+
from math import pi
3+
4+
class Circle:
5+
def __init__(self:"Circle", radius:f64):
6+
self.radius :f64 = radius
7+
8+
def circle_area(self:"Circle")->f64:
9+
return pi * self.radius ** 2.0
10+
11+
def circle_print(self:"Circle"):
12+
area : f64 = self.circle_area()
13+
print("Circle: r = ",str(self.radius)," area = ",str(area))
14+
15+
def main():
16+
c : Circle = Circle(1.0)
17+
c.circle_print()
18+
assert abs(c.circle_area() - 3.141593) <= 1e-6
19+
c.radius = 1.5
20+
c.circle_print()
21+
assert abs(c.circle_area() - 7.068583) < 1e-6
22+
23+
if __name__ == "__main__":
24+
main()

src/lpython/semantics/python_ast_to_asr.cpp

+33-5
Original file line numberDiff line numberDiff line change
@@ -1931,7 +1931,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
19311931
return ASRUtils::TYPE(ASR::make_Union_t(al, attr_annotation->base.base.loc, import_struct_member));
19321932
} else if ( AST::is_a<AST::ConstantStr_t>(annotation) ) {
19331933
AST::ConstantStr_t *n = AST::down_cast<AST::ConstantStr_t>(&annotation);
1934-
ASR::symbol_t *sym = current_scope->parent->parent->resolve_symbol(n->m_value);
1934+
ASR::symbol_t *sym = current_scope->resolve_symbol(n->m_value);
19351935
if ( sym == nullptr || !ASR::is_a<ASR::Struct_t>(*sym) ) {
19361936
throw SemanticError("Only Struct implemented for constant"
19371937
" str annotation", loc);
@@ -3300,6 +3300,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
33003300
if ( AST::is_a<AST::FunctionDef_t>(*x.m_body[i]) ) {
33013301
AST::FunctionDef_t*
33023302
f = AST::down_cast<AST::FunctionDef_t>(x.m_body[i]);
3303+
init_self_type(*f, sym, x.base.base.loc);
33033304
if ( std::string(f->m_name) == std::string("__init__") ) {
33043305
this->visit_init_body(*f);
33053306
} else {
@@ -3348,6 +3349,30 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
33483349
}
33493350
}
33503351

3352+
void init_self_type (const AST::FunctionDef_t &x,
3353+
ASR::symbol_t* class_sym, Location loc) {
3354+
SymbolTable* parent_scope = current_scope;
3355+
ASR::symbol_t *t = current_scope->get_symbol(x.m_name);
3356+
if (t==nullptr) {
3357+
throw SemanticError("Function not found in current symbol table",
3358+
x.base.base.loc);
3359+
}
3360+
if ( !ASR::is_a<ASR::Function_t>(*t) ) {
3361+
throw SemanticError("Only functions implemented in classes",
3362+
x.base.base.loc);
3363+
}
3364+
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(t);
3365+
current_scope = f->m_symtab;
3366+
if ( f->n_args==0 ) {
3367+
return;
3368+
}
3369+
std::string self_name = x.m_args.m_args[0].m_arg;
3370+
ASR::symbol_t* sym = current_scope->get_symbol(self_name);
3371+
ASR::Variable_t* self_var = ASR::down_cast<ASR::Variable_t>(sym);
3372+
self_var->m_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al,loc, class_sym));
3373+
current_scope = parent_scope;
3374+
}
3375+
33513376
virtual void visit_init_body (const AST::FunctionDef_t &/*x*/) = 0;
33523377

33533378
void add_name(const Location &loc) {
@@ -5139,7 +5164,8 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
51395164
for (const auto &rt: rt_vec) { rts.push_back(al, rt); }
51405165
f->m_body = body.p;
51415166
f->n_body = body.size();
5142-
ASR::FunctionType_t* func_type = ASR::down_cast<ASR::FunctionType_t>(f->m_function_signature);
5167+
ASR::FunctionType_t* func_type = ASR::down_cast<ASR::FunctionType_t>(
5168+
f->m_function_signature);
51435169
func_type->m_restrictions = rts.p;
51445170
func_type->n_restrictions = rts.size();
51455171
f->m_dependencies = dependencies.p;
@@ -8033,10 +8059,12 @@ we will have to use something else.
80338059
st = get_struct_member(st, call_name, loc);
80348060
} else if ( ASR::is_a<ASR::Variable_t>(*st)) {
80358061
ASR::Variable_t* var = ASR::down_cast<ASR::Variable_t>(st);
8036-
if (ASR::is_a<ASR::StructType_t>(*var->m_type)) {
8062+
if (ASR::is_a<ASR::StructType_t>(*var->m_type) ||
8063+
ASR::is_a<ASR::Class_t>(*var->m_type) ) {
8064+
//TODO: Correct Class and ClassType
80378065
// call to struct member function
80388066
// modifying args to pass the object as self
8039-
ASR::StructType_t* var_struct = ASR::down_cast<ASR::StructType_t>(var->m_type);
8067+
ASR::symbol_t* der = ASR::down_cast<ASR::StructType_t>(var->m_type)->m_derived_type;
80408068
Vec<ASR::call_arg_t> new_args; new_args.reserve(al, args.n + 1);
80418069
ASR::call_arg_t self_arg;
80428070
self_arg.loc = args[0].loc;
@@ -8045,7 +8073,7 @@ we will have to use something else.
80458073
for (size_t i=0; i<args.n; i++) {
80468074
new_args.push_back(al, args[i]);
80478075
}
8048-
st = get_struct_member(var_struct->m_derived_type, call_name, loc);
8076+
st = get_struct_member(der, call_name, loc);
80498077
tmp = make_call_helper(al, st, current_scope, new_args, call_name, loc);
80508078
return;
80518079
} else {

0 commit comments

Comments
 (0)