From 061002efe863b8694ba13608d977f439d79677eb Mon Sep 17 00:00:00 2001 From: Christoph Hegemann Date: Sun, 6 Oct 2024 14:09:58 +0200 Subject: [PATCH] Inference for lambda argument and result types (#40) Closes #36 --- .../cant_reassign_captured_variable.nemo | 2 +- crates/cli/tests/check/lambda.nemo | 15 ++- crates/cli/tests/check/type_mismatch.nemo | 2 +- crates/cli/tests/lib.rs | 7 +- crates/cli/tests/run/closures.nemo | 22 ++-- crates/cli/tests/run/expr_postfix.nemo | 2 +- crates/cli/tests/run/return.nemo | 2 +- .../tests/snapshots/check@lambda.nemo.snap | 36 +++++-- .../snapshots/check@type_mismatch.nemo.snap | 28 ++++- .../tests/snapshots/run@closures.nemo-2.snap | 1 + crates/frontend/src/parser/grammar.rs | 22 ++-- crates/frontend/src/parser/lexer.rs | 4 + crates/frontend/src/scip.rs | 11 +- crates/frontend/src/syntax/nodes.rs | 9 -- crates/frontend/src/syntax/nodes.ungram | 2 +- crates/frontend/src/types/check.rs | 100 ++++++++++++++++-- design/lambdas.md | 20 ++++ 17 files changed, 218 insertions(+), 67 deletions(-) create mode 100644 design/lambdas.md diff --git a/crates/cli/tests/check/cant_reassign_captured_variable.nemo b/crates/cli/tests/check/cant_reassign_captured_variable.nemo index aa6d5ee8..f9265963 100644 --- a/crates/cli/tests/check/cant_reassign_captured_variable.nemo +++ b/crates/cli/tests/check/cant_reassign_captured_variable.nemo @@ -1,6 +1,6 @@ fn main() { let local = 10; - let assign_capture = fn () { + let assign_capture = \() { set local = 20; }; {} diff --git a/crates/cli/tests/check/lambda.nemo b/crates/cli/tests/check/lambda.nemo index bedecedf..f3ccf11a 100644 --- a/crates/cli/tests/check/lambda.nemo +++ b/crates/cli/tests/check/lambda.nemo @@ -1,6 +1,15 @@ fn main() -> i32 { - let f = fn () -> f32 { - return 10; + let f = \() -> i32 { + return 10.0; }; - return 10.0; + let wrong_return : fn (i32) -> i32 = \(x) { + x + }; + let wrong_param : fn (i32) -> i32 = \(x : f32) { + return x + 1.0; + }; + let wrong_param_count : fn (i32) -> i32 = \(x, y) { + return x + y; + }; + return 42; } diff --git a/crates/cli/tests/check/type_mismatch.nemo b/crates/cli/tests/check/type_mismatch.nemo index 22dbb2fe..27fbd578 100644 --- a/crates/cli/tests/check/type_mismatch.nemo +++ b/crates/cli/tests/check/type_mismatch.nemo @@ -1,2 +1,2 @@ global x : i32 = 10.0 -global my_fn : fn (i32) -> i32 = fn (x : f32) -> f32 { x } +global my_fn : fn (i32) -> i32 = \(x : f32) -> f32 { x } diff --git a/crates/cli/tests/lib.rs b/crates/cli/tests/lib.rs index 1df4fa2c..43fe9ac9 100644 --- a/crates/cli/tests/lib.rs +++ b/crates/cli/tests/lib.rs @@ -14,12 +14,9 @@ fn render_slash_path(path: &Utf8Path) -> String { fn compile_args(paths: &[Utf8PathBuf], out_name: &str) -> (Vec, String) { let out_path = format!("tests/build/{}.wasm", out_name); - let mut args = vec![ - "compile".to_string(), - "--output".to_string(), - out_path.clone(), - ]; + let mut args = vec!["compile".to_string()]; args.extend(paths.iter().map(|p| render_slash_path(p))); + args.extend(["--output".to_string(), out_path.clone()]); (args, out_path) } diff --git a/crates/cli/tests/run/closures.nemo b/crates/cli/tests/run/closures.nemo index 74eb2395..b415b9a4 100644 --- a/crates/cli/tests/run/closures.nemo +++ b/crates/cli/tests/run/closures.nemo @@ -1,16 +1,20 @@ import log : fn (i32) -> unit from log +fn inferred() -> i32 { + let f : fn (i32, f32) -> i32 = \(x, y) { x + 2 }; + f(1, 2.0) +} + fn main() { - let x = (fn (x : i32) -> i32 { x + 1 })(10); + let x = (\(x : i32) -> i32 { x + 1 })(10); log(x); - let twice = { - fn (f : fn (i32) -> i32) -> fn(i32) -> i32 { - fn (x : i32) -> i32 { - f(f(x)) - } - } + let twice = \(f : fn (i32) -> i32) -> fn(i32) -> i32 { + \(x : i32) -> i32 { + f(f(x)) + } }; - let add1 = fn (x : i32) -> i32 { x + 1 }; - log(twice(add1)(3)) + let add1 = \(x : i32) -> i32 { x + 1 }; + log(twice(add1)(3)); + log(inferred()); } diff --git a/crates/cli/tests/run/expr_postfix.nemo b/crates/cli/tests/run/expr_postfix.nemo index 3c077d8d..0dc823a7 100644 --- a/crates/cli/tests/run/expr_postfix.nemo +++ b/crates/cli/tests/run/expr_postfix.nemo @@ -14,7 +14,7 @@ fn f2(x : i32, y : i32) { } fn ff(x : i32) -> fn(i32) -> unit { - fn (y : i32) { + \(y : i32) { log(y) } } diff --git a/crates/cli/tests/run/return.nemo b/crates/cli/tests/run/return.nemo index 88ea2dab..ba5a9b57 100644 --- a/crates/cli/tests/run/return.nemo +++ b/crates/cli/tests/run/return.nemo @@ -24,7 +24,7 @@ fn loop_return() -> f32 { } fn lambda_return() -> i32 { - let f = fn () -> bool { + let f = \() -> bool { return true; }; if f() { diff --git a/crates/cli/tests/snapshots/check@lambda.nemo.snap b/crates/cli/tests/snapshots/check@lambda.nemo.snap index 2058c3fc..56ad9164 100644 --- a/crates/cli/tests/snapshots/check@lambda.nemo.snap +++ b/crates/cli/tests/snapshots/check@lambda.nemo.snap @@ -5,27 +5,47 @@ info: args: - check - tests/check/lambda.nemo + env: + RUST_BACKTRACE: "0" input_file: crates/cli/tests/check/lambda.nemo --- success: false exit_code: 1 ----- stdout ----- -[15] Error: Type mismatch. Expected f32, but got i32 +[15] Error: Type mismatch. Expected i32, but got f32 ╭─[tests/check/lambda.nemo:1:13] │ - 3 │       return 10; -  │ ─┬ -  │ ╰── Type mismatch. Expected f32, but got i32 + 3 │       return 10.0; +  │ ──┬─ +  │ ╰─── Type mismatch. Expected i32, but got f32 ───╯ [15] Error: Type mismatch. Expected i32, but got f32 ╭─[tests/check/lambda.nemo:1:13] │ - 5 │     return 10.0; -  │ ──┬─ -  │ ╰─── Type mismatch. Expected i32, but got f32 + 8 │     let wrong_param : fn (i32) -> i32 = \(x : f32) { +  │ ─┬─ +  │ ╰─── Type mismatch. Expected i32, but got f32 ───╯ +[15] Error: Type mismatch. Expected i32, but got f32 + ╭─[tests/check/lambda.nemo:1:13] + │ + 9 │       return x + 1.0; +  │ ───┬─── +  │ ╰───── Type mismatch. Expected i32, but got f32 +───╯ + +[11] Error: Mismatched arg count. Expected 1 argument, but got 2 + ╭─[tests/check/lambda.nemo:1:13] + │ + 11 │ ╭─▶     let wrong_param_count : fn (i32) -> i32 = \(x, y) { +  ┆ ┆ + 13 │ ├─▶ }; +  │ │ +  │ ╰──────────── Mismatched arg count. Expected 1 argument, but got 2 +────╯ + ----- stderr ----- -Error: "Check failed with 2 errors" +Error: "Check failed with 4 errors" diff --git a/crates/cli/tests/snapshots/check@type_mismatch.nemo.snap b/crates/cli/tests/snapshots/check@type_mismatch.nemo.snap index bc966d69..af339a12 100644 --- a/crates/cli/tests/snapshots/check@type_mismatch.nemo.snap +++ b/crates/cli/tests/snapshots/check@type_mismatch.nemo.snap @@ -5,6 +5,8 @@ info: args: - check - tests/check/type_mismatch.nemo + env: + RUST_BACKTRACE: "0" input_file: crates/cli/tests/check/type_mismatch.nemo --- success: false @@ -18,14 +20,30 @@ exit_code: 1  │ ╰─── Type mismatch. Expected i32, but got f32 ───╯ -[15] Error: Type mismatch. Expected fn (i32) -> i32, but got fn (f32) -> f32 +[15] Error: Type mismatch. Expected i32, but got f32 + ╭─[tests/check/type_mismatch.nemo:1:13] + │ + 2 │ global my_fn : fn (i32) -> i32 = \(x : f32) -> f32 { x } +  │ ──┬─ +  │ ╰─── Type mismatch. Expected i32, but got f32 +───╯ + +[15] Error: Type mismatch. Expected i32, but got f32 + ╭─[tests/check/type_mismatch.nemo:1:13] + │ + 2 │ global my_fn : fn (i32) -> i32 = \(x : f32) -> f32 { x } +  │ ─┬─ +  │ ╰─── Type mismatch. Expected i32, but got f32 +───╯ + +[15] Error: Type mismatch. Expected i32, but got f32 ╭─[tests/check/type_mismatch.nemo:1:13] │ - 2 │ global my_fn : fn (i32) -> i32 = fn (x : f32) -> f32 { x } -  │ ────────────┬──────────── -  │ ╰────────────── Type mismatch. Expected fn (i32) -> i32, but got fn (f32) -> f32 + 2 │ global my_fn : fn (i32) -> i32 = \(x : f32) -> f32 { x } +  │ ─┬ +  │ ╰── Type mismatch. Expected i32, but got f32 ───╯ ----- stderr ----- -Error: "Check failed with 2 errors" +Error: "Check failed with 4 errors" diff --git a/crates/cli/tests/snapshots/run@closures.nemo-2.snap b/crates/cli/tests/snapshots/run@closures.nemo-2.snap index afb35050..fee0c839 100644 --- a/crates/cli/tests/snapshots/run@closures.nemo-2.snap +++ b/crates/cli/tests/snapshots/run@closures.nemo-2.snap @@ -16,6 +16,7 @@ exit_code: 0 ----- stdout ----- 11 5 +3 0 ----- stderr ----- diff --git a/crates/frontend/src/parser/grammar.rs b/crates/frontend/src/parser/grammar.rs index 2bf73bd8..ff913d16 100644 --- a/crates/frontend/src/parser/grammar.rs +++ b/crates/frontend/src/parser/grammar.rs @@ -15,6 +15,12 @@ impl Progress { } } +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +enum TypAnnot { + Optional, + Required, +} + pub fn prog(p: &mut Parser) { while !p.at(SyntaxKind::EOF) { module(p); @@ -197,7 +203,7 @@ fn top_func(p: &mut Parser) { p.error("expected a function name") } typ_param_list(p); - param_list(p); + param_list(p, TypAnnot::Required); if p.eat(T![->]) && !typ(p).made_progress() { p.error("expected a return type") } @@ -225,7 +231,7 @@ fn typ_param_list(p: &mut Parser) -> Progress { Progress::Made } -fn param_list(p: &mut Parser) { +fn param_list(p: &mut Parser, typ_annot_opt: TypAnnot) { if !p.eat(SyntaxKind::L_PAREN) { p.error("expected a parameter list"); return; @@ -233,7 +239,7 @@ fn param_list(p: &mut Parser) { while p.at(SyntaxKind::IDENT) { let c = p.checkpoint(); p.bump(SyntaxKind::IDENT); - if !typ_annot(p).made_progress() { + if !typ_annot(p).made_progress() && typ_annot_opt == TypAnnot::Required { p.error("expected a type annotation") } p.eat(SyntaxKind::COMMA); @@ -428,8 +434,8 @@ fn expr(p: &mut Parser) -> Progress { match_expr(p); return Progress::Made; } - if p.at(T![fn]) { - func_expr(p); + if p.at(T![lambda]) { + lambda_expr(p); return Progress::Made; } if p.at(T![return]) { @@ -537,10 +543,10 @@ fn return_expr(p: &mut Parser) { p.finish_at(c, SyntaxKind::EReturn) } -fn func_expr(p: &mut Parser) { +fn lambda_expr(p: &mut Parser) { let c = p.checkpoint(); - p.bump(T![fn]); - param_list(p); + p.bump(T![lambda]); + param_list(p, TypAnnot::Optional); if p.eat(T![->]) && !typ(p).made_progress() { p.error("expected a return type") } diff --git a/crates/frontend/src/parser/lexer.rs b/crates/frontend/src/parser/lexer.rs index 5a5af6e6..220054b8 100644 --- a/crates/frontend/src/parser/lexer.rs +++ b/crates/frontend/src/parser/lexer.rs @@ -229,6 +229,9 @@ pub enum SyntaxKind { #[token("/")] SLASH, + #[token("\\")] + BACKSLASH, + #[token("=")] EQUALS, @@ -449,6 +452,7 @@ macro_rules ! T { [exports] => { SyntaxKind::EXPORTS_KW }; [use] => { SyntaxKind::USE_KW }; [fn] => { SyntaxKind::FN_KW }; + [lambda] => { SyntaxKind::BACKSLASH }; [global] => { SyntaxKind::GLOBAL_KW }; [import] => { SyntaxKind::IMPORT_KW }; [from] => { SyntaxKind::FROM_KW }; diff --git a/crates/frontend/src/scip.rs b/crates/frontend/src/scip.rs index 3b7c62ea..3e2eec82 100644 --- a/crates/frontend/src/scip.rs +++ b/crates/frontend/src/scip.rs @@ -184,10 +184,13 @@ fn index_occurrence_map( local_id: local_supply.to_string().into(), }; let symbol = sym.to_string(); - (sym, SymbolInformation { - symbol, - ..Default::default() - }) + ( + sym, + SymbolInformation { + symbol, + ..Default::default() + }, + ) }); if !sym.is_local() { symbols.push(symbol_information.clone()); diff --git a/crates/frontend/src/syntax/nodes.rs b/crates/frontend/src/syntax/nodes.rs index 405e781c..71734f72 100644 --- a/crates/frontend/src/syntax/nodes.rs +++ b/crates/frontend/src/syntax/nodes.rs @@ -632,18 +632,9 @@ pub struct ELambda { pub(crate) syntax: SyntaxNode, } impl ELambda { - pub fn fn_token(&self) -> Option { - support::token(&self.syntax, T![fn]) - } - pub fn l_paren_token(&self) -> Option { - support::token(&self.syntax, T!['(']) - } pub fn params(&self) -> AstChildren { support::children(&self.syntax) } - pub fn r_paren_token(&self) -> Option { - support::token(&self.syntax, T![')']) - } pub fn return_ty(&self) -> Option { support::child(&self.syntax) } diff --git a/crates/frontend/src/syntax/nodes.ungram b/crates/frontend/src/syntax/nodes.ungram index 4322723a..34697198 100644 --- a/crates/frontend/src/syntax/nodes.ungram +++ b/crates/frontend/src/syntax/nodes.ungram @@ -107,7 +107,7 @@ EIf = 'if' condition:Expr then_branch:EBlock 'else' else_branch:EBlock EMatch = 'match' scrutinee:Expr '{' EMatchBranch* '}' EMatchBranch = Pattern '=>' body:EBlock -ELambda = 'fn' '(' Param* ')' return_ty:Type body:Expr +ELambda = Param* return_ty:Type body:Expr EBlock = '{' Declaration* '}' EReturn = 'return' Expr diff --git a/crates/frontend/src/types/check.rs b/crates/frontend/src/types/check.rs index 8ab82930..b7a01768 100644 --- a/crates/frontend/src/types/check.rs +++ b/crates/frontend/src/types/check.rs @@ -1058,6 +1058,7 @@ impl Typechecker<'_> { }; let (name, sym) = self.name_supply.local_idx(&name_tkn); let ty = match param.ty() { + // TODO: Produce a type error. We can't infer lambdas None => Ty::Error, Some(t) => self.check_ty(errors, scope, &t), }; @@ -1678,19 +1679,13 @@ impl Typechecker<'_> { let (_, ir) = self.check_call(errors, scope, expr, Some(expected))?; ir } + (Expr::ELambda(expr), Ty::Func(func_ty)) => { + self.check_lambda(errors, scope, expr, func_ty.as_ref()) + } _ => { let (ty, ir) = self.infer_expr(errors, scope, expr); - if *expected != Ty::Error - && !matches!(ty, Ty::Error | Ty::Diverge) - && ty.ne(expected) - { - errors.report( - expr, - TypeMismatch { - expected: expected.clone(), - actual: ty, - }, - ); + if let Some(ty_err) = expect_ty(expected, &ty) { + errors.report(expr, ty_err); } return ir; } @@ -1966,6 +1961,78 @@ impl Typechecker<'_> { at: pattern.syntax().text_range(), }) } + + fn check_lambda( + &self, + errors: &mut TyErrors, + scope: &mut Scope, + expr: &ELambda, + expected: &FuncTy, + ) -> Option { + if let Some(return_ty) = expr.return_ty() { + let t = self.check_ty(errors, scope, &return_ty); + if let Some(ty_err) = expect_ty(&expected.result, &t) { + errors.report(&return_ty, ty_err); + } + } + let mut builder = LambdaBuilder::default(); + builder.return_ty(Some(expected.result.clone())); + scope.enter_block(); + let prev_return_ty = scope.set_return_type(expected.result.clone()); + // TODO: Check for duplicate parameter names + let params: Vec = expr.params().collect(); + if params.len() != expected.arguments.len() { + errors.report( + // TODO: Make a node for the arg list, so we can report it instead of the whole lambda + expr, + ArgCountMismatch(expected.arguments.len(), params.len()), + ); + } + let mut param_names = HashSet::new(); + for (param, param_ty) in params.iter().zip( + expected + .arguments + .iter() + .chain(std::iter::repeat(&Ty::Error)), + ) { + let Some(name_tkn) = param.ident_token() else { + continue; + }; + let (name, sym) = self.name_supply.local_idx(&name_tkn); + let ty = if let Some(t) = param.ty() { + let ty = self.check_ty(errors, scope, &t); + if let Some(ty_err) = expect_ty(param_ty, &ty) { + errors.report(&t, ty_err); + } + ty + } else { + param_ty.clone() + }; + builder.params(Some((name, ty.clone()))); + param_names.insert(name); + scope.add_var(sym, ty.clone(), name); + } + + if let Some(body) = expr.body() { + let body_ir = self.check_expr(errors, scope, &body, &expected.result); + if let Some(body_ir) = body_ir { + for (n, fvi) in body_ir.free_vars() { + if !param_names.contains(&n) { + if let Some(assignment) = fvi.is_assigned { + errors.report(&assignment, CantReassignCapturedVariable(n)); + } else { + builder.captures(Some((n, fvi.ty.clone()))); + } + } + } + builder.body(Some(body_ir)); + } + } + scope.leave_block(); + // Restore the previous return type + scope.restore_return_type(prev_return_ty); + builder.build() + } } fn check_op(op: &SyntaxToken, ty_left: &Ty, ty_right: &Ty) -> Option<(ir::OpData, Ty)> { @@ -2025,6 +2092,17 @@ fn infer_struct_instantiation<'a>(def: &StructDef, expected: &'a Ty) -> Option<& None } +fn expect_ty(expected: &Ty, ty: &Ty) -> Option { + if *expected != Ty::Error && !matches!(ty, Ty::Error | Ty::Diverge) && expected != ty { + Some(TypeMismatch { + expected: expected.clone(), + actual: ty.clone(), + }) + } else { + None + } +} + // `match_ty` is non-commutative unification. Only variables on the left are allowed to be solved, and we limit // the set of variables that may be solved to `Name::Gen(_)`. This is used to implement inference for type // parameters to polymorphic functions or struct literals. diff --git a/design/lambdas.md b/design/lambdas.md new file mode 100644 index 00000000..c5150494 --- /dev/null +++ b/design/lambdas.md @@ -0,0 +1,20 @@ +# Anonymous functions (Lambdas) + +Lambdas are nameless functions that can appear in expression position. They +capture variables from their environment. Value types are copied when captured +and trying to `set` them results in a type error. + +They are defined using a `\` followed by a parameter list: + +``` +\(x : i32) -> i32 { x + 1 } +``` + +When a lambda appears in checking position its argument and result types +can be ommitted and will be inferred. + +``` +let x : i32 = (\(y) { y + 1 })(2); +``` + +Lambdas are always monomorphic.