@@ -1931,7 +1931,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
1931
1931
return ASRUtils::TYPE (ASR::make_Union_t (al, attr_annotation->base .base .loc , import_struct_member));
1932
1932
} else if ( AST::is_a<AST::ConstantStr_t>(annotation) ) {
1933
1933
AST::ConstantStr_t *n = AST::down_cast<AST::ConstantStr_t>(&annotation);
1934
- ASR::symbol_t *sym = current_scope->parent -> parent -> resolve_symbol (n->m_value );
1934
+ ASR::symbol_t *sym = current_scope->resolve_symbol (n->m_value );
1935
1935
if ( sym == nullptr || !ASR::is_a<ASR::Struct_t>(*sym) ) {
1936
1936
throw SemanticError (" Only Struct implemented for constant"
1937
1937
" str annotation" , loc);
@@ -3300,6 +3300,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
3300
3300
if ( AST::is_a<AST::FunctionDef_t>(*x.m_body [i]) ) {
3301
3301
AST::FunctionDef_t*
3302
3302
f = AST::down_cast<AST::FunctionDef_t>(x.m_body [i]);
3303
+ init_self_type (*f, sym, x.base .base .loc );
3303
3304
if ( std::string (f->m_name ) == std::string (" __init__" ) ) {
3304
3305
this ->visit_init_body (*f);
3305
3306
} else {
@@ -3348,6 +3349,30 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
3348
3349
}
3349
3350
}
3350
3351
3352
+ void init_self_type (const AST::FunctionDef_t &x,
3353
+ ASR::symbol_t * class_sym, Location loc) {
3354
+ SymbolTable* parent_scope = current_scope;
3355
+ ASR::symbol_t *t = current_scope->get_symbol (x.m_name );
3356
+ if (t==nullptr ) {
3357
+ throw SemanticError (" Function not found in current symbol table" ,
3358
+ x.base .base .loc );
3359
+ }
3360
+ if ( !ASR::is_a<ASR::Function_t>(*t) ) {
3361
+ throw SemanticError (" Only functions implemented in classes" ,
3362
+ x.base .base .loc );
3363
+ }
3364
+ ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(t);
3365
+ current_scope = f->m_symtab ;
3366
+ if ( f->n_args ==0 ) {
3367
+ return ;
3368
+ }
3369
+ std::string self_name = x.m_args .m_args [0 ].m_arg ;
3370
+ ASR::symbol_t * sym = current_scope->get_symbol (self_name);
3371
+ ASR::Variable_t* self_var = ASR::down_cast<ASR::Variable_t>(sym);
3372
+ self_var->m_type = ASRUtils::TYPE (ASRUtils::make_StructType_t_util (al,loc, class_sym));
3373
+ current_scope = parent_scope;
3374
+ }
3375
+
3351
3376
virtual void visit_init_body (const AST::FunctionDef_t &/* x*/ ) = 0;
3352
3377
3353
3378
void add_name (const Location &loc) {
@@ -5139,7 +5164,8 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
5139
5164
for (const auto &rt: rt_vec) { rts.push_back (al, rt); }
5140
5165
f->m_body = body.p ;
5141
5166
f->n_body = body.size ();
5142
- ASR::FunctionType_t* func_type = ASR::down_cast<ASR::FunctionType_t>(f->m_function_signature );
5167
+ ASR::FunctionType_t* func_type = ASR::down_cast<ASR::FunctionType_t>(
5168
+ f->m_function_signature );
5143
5169
func_type->m_restrictions = rts.p ;
5144
5170
func_type->n_restrictions = rts.size ();
5145
5171
f->m_dependencies = dependencies.p ;
@@ -8033,10 +8059,12 @@ we will have to use something else.
8033
8059
st = get_struct_member (st, call_name, loc);
8034
8060
} else if ( ASR::is_a<ASR::Variable_t>(*st)) {
8035
8061
ASR::Variable_t* var = ASR::down_cast<ASR::Variable_t>(st);
8036
- if (ASR::is_a<ASR::StructType_t>(*var->m_type )) {
8062
+ if (ASR::is_a<ASR::StructType_t>(*var->m_type ) ||
8063
+ ASR::is_a<ASR::Class_t>(*var->m_type ) ) {
8064
+ // TODO: Correct Class and ClassType
8037
8065
// call to struct member function
8038
8066
// modifying args to pass the object as self
8039
- ASR::StructType_t* var_struct = ASR::down_cast<ASR::StructType_t>(var->m_type );
8067
+ ASR::symbol_t * der = ASR::down_cast<ASR::StructType_t>(var->m_type )-> m_derived_type ;
8040
8068
Vec<ASR::call_arg_t > new_args; new_args.reserve (al, args.n + 1 );
8041
8069
ASR::call_arg_t self_arg;
8042
8070
self_arg.loc = args[0 ].loc ;
@@ -8045,7 +8073,7 @@ we will have to use something else.
8045
8073
for (size_t i=0 ; i<args.n ; i++) {
8046
8074
new_args.push_back (al, args[i]);
8047
8075
}
8048
- st = get_struct_member (var_struct-> m_derived_type , call_name, loc);
8076
+ st = get_struct_member (der , call_name, loc);
8049
8077
tmp = make_call_helper (al, st, current_scope, new_args, call_name, loc);
8050
8078
return ;
8051
8079
} else {
0 commit comments