Skip to content

Commit

Permalink
Inference for lambda argument and result types (#40)
Browse files Browse the repository at this point in the history
Closes #36
  • Loading branch information
kritzcreek authored Oct 6, 2024
1 parent 1cdce7b commit 061002e
Show file tree
Hide file tree
Showing 17 changed files with 218 additions and 67 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
fn main() {
let local = 10;
let assign_capture = fn () {
let assign_capture = \() {
set local = 20;
};
{}
Expand Down
15 changes: 12 additions & 3 deletions crates/cli/tests/check/lambda.nemo
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
fn main() -> i32 {
let f = fn () -> f32 {
return 10;
let f = \() -> i32 {
return 10.0;
};
return 10.0;
let wrong_return : fn (i32) -> i32 = \(x) {
x
};
let wrong_param : fn (i32) -> i32 = \(x : f32) {
return x + 1.0;
};
let wrong_param_count : fn (i32) -> i32 = \(x, y) {
return x + y;
};
return 42;
}
2 changes: 1 addition & 1 deletion crates/cli/tests/check/type_mismatch.nemo
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
global x : i32 = 10.0
global my_fn : fn (i32) -> i32 = fn (x : f32) -> f32 { x }
global my_fn : fn (i32) -> i32 = \(x : f32) -> f32 { x }
7 changes: 2 additions & 5 deletions crates/cli/tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,9 @@ fn render_slash_path(path: &Utf8Path) -> String {

fn compile_args(paths: &[Utf8PathBuf], out_name: &str) -> (Vec<String>, String) {
let out_path = format!("tests/build/{}.wasm", out_name);
let mut args = vec![
"compile".to_string(),
"--output".to_string(),
out_path.clone(),
];
let mut args = vec!["compile".to_string()];
args.extend(paths.iter().map(|p| render_slash_path(p)));
args.extend(["--output".to_string(), out_path.clone()]);
(args, out_path)
}

Expand Down
22 changes: 13 additions & 9 deletions crates/cli/tests/run/closures.nemo
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import log : fn (i32) -> unit from log

fn inferred() -> i32 {
let f : fn (i32, f32) -> i32 = \(x, y) { x + 2 };
f(1, 2.0)
}

fn main() {
let x = (fn (x : i32) -> i32 { x + 1 })(10);
let x = (\(x : i32) -> i32 { x + 1 })(10);
log(x);

let twice = {
fn (f : fn (i32) -> i32) -> fn(i32) -> i32 {
fn (x : i32) -> i32 {
f(f(x))
}
}
let twice = \(f : fn (i32) -> i32) -> fn(i32) -> i32 {
\(x : i32) -> i32 {
f(f(x))
}
};
let add1 = fn (x : i32) -> i32 { x + 1 };
log(twice(add1)(3))
let add1 = \(x : i32) -> i32 { x + 1 };
log(twice(add1)(3));
log(inferred());
}
2 changes: 1 addition & 1 deletion crates/cli/tests/run/expr_postfix.nemo
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fn f2(x : i32, y : i32) {
}

fn ff(x : i32) -> fn(i32) -> unit {
fn (y : i32) {
\(y : i32) {
log(y)
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/cli/tests/run/return.nemo
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ fn loop_return() -> f32 {
}

fn lambda_return() -> i32 {
let f = fn () -> bool {
let f = \() -> bool {
return true;
};
if f() {
Expand Down
36 changes: 28 additions & 8 deletions crates/cli/tests/snapshots/[email protected]
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,47 @@ info:
args:
- check
- tests/check/lambda.nemo
env:
RUST_BACKTRACE: "0"
input_file: crates/cli/tests/check/lambda.nemo
---
success: false
exit_code: 1
----- stdout -----
[31m[15] Error:[0m Type mismatch. Expected f32, but got i32
[31m[15] Error:[0m Type mismatch. Expected i32, but got f32
╭─[tests/check/lambda.nemo:1:13]
│
[38;5;246m3 │[0m [38;5;249m [0m[38;5;249m [0m[38;5;249m [0m[38;5;249m [0m[38;5;249m [0m[38;5;249m [0m[38;5;249mr[0m[38;5;249me[0m[38;5;249mt[0m[38;5;249mu[0m[38;5;249mr[0m[38;5;249mn[0m[38;5;249m [0m10[38;5;249m;[0m
[38;5;240m │[0m
[38;5;240m │[0m ╰── Type mismatch. Expected f32, but got i32
[38;5;246m3 │[0m [38;5;249m [0m[38;5;249m [0m[38;5;249m [0m[38;5;249m [0m[38;5;249m [0m[38;5;249m [0m[38;5;249mr[0m[38;5;249me[0m[38;5;249mt[0m[38;5;249mu[0m[38;5;249mr[0m[38;5;249mn[0m[38;5;249m [0m10.0[38;5;249m;[0m
[38;5;240m │[0m ──┬─
[38;5;240m │[0m ╰── Type mismatch. Expected i32, but got f32
───╯

[15] Error: Type mismatch. Expected i32, but got f32
╭─[tests/check/lambda.nemo:1:13]
│
[38;5;246m5 │[0m [38;5;249m [0m[38;5;249m [0m[38;5;249m [0m[38;5;249m [0m[38;5;249mr[0m[38;5;249me[0m[38;5;249mt[0m[38;5;249mu[0m[38;5;249mr[0m[38;5;249mn[0m[38;5;249m [0m10.0[38;5;249m;[0m
[38;5;240m │[0m ─┬─
[38;5;240m │[0m ╰─── Type mismatch. Expected i32, but got f32
[38;5;246m8 │[0m [38;5;249m [0m[38;5;249m [0m[38;5;249m [0m[38;5;249m [0m[38;5;249ml[0m[38;5;249me[0m[38;5;249mt[0m[38;5;249m [0m[38;5;249mw[0m[38;5;249mr[0m[38;5;249mo[0m[38;5;249mn[0m[38;5;249mg[0m[38;5;249m_[0m[38;5;249mp[0m[38;5;249ma[0m[38;5;249mr[0m[38;5;249ma[0m[38;5;249mm[0m[38;5;249m [0m[38;5;249m:[0m[38;5;249m [0m[38;5;249mf[0m[38;5;249mn[0m[38;5;249m [0m[38;5;249m([0m[38;5;249mi[0m[38;5;249m3[0m[38;5;249m2[0m[38;5;249m)[0m[38;5;249m [0m[38;5;249m-[0m[38;5;249m>[0m[38;5;249m [0m[38;5;249mi[0m[38;5;249m3[0m[38;5;249m2[0m[38;5;249m [0m[38;5;249m=[0m[38;5;249m [0m[38;5;249m\[0m[38;5;249m([0m[38;5;249mx[0m[38;5;249m [0m[38;5;249m:[0m[38;5;249m [0mf32[38;5;249m)[0m[38;5;249m [0m[38;5;249m{[0m
[38;5;240m │[0m ─┬─
[38;5;240m │[0m ╰─── Type mismatch. Expected i32, but got f32
───╯

[15] Error: Type mismatch. Expected i32, but got f32
╭─[tests/check/lambda.nemo:1:13]
│
9 │       return x + 1.0;
 │ ───┬───
 │ ╰───── Type mismatch. Expected i32, but got f32
───╯

[11] Error: Mismatched arg count. Expected 1 argument, but got 2
╭─[tests/check/lambda.nemo:1:13]
│
11 │ ╭─▶     let wrong_param_count : fn (i32) -> i32 = \(x, y) {
 ┆ ┆
13 │ ├─▶ };
 │ │
 │ ╰──────────── Mismatched arg count. Expected 1 argument, but got 2
────╯


----- stderr -----
Error: "Check failed with 2 errors"
Error: "Check failed with 4 errors"
28 changes: 23 additions & 5 deletions crates/cli/tests/snapshots/check@type_mismatch.nemo.snap
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ info:
args:
- check
- tests/check/type_mismatch.nemo
env:
RUST_BACKTRACE: "0"
input_file: crates/cli/tests/check/type_mismatch.nemo
---
success: false
Expand All @@ -18,14 +20,30 @@ exit_code: 1
 │ ╰─── Type mismatch. Expected i32, but got f32
───╯

[15] Error: Type mismatch. Expected fn (i32) -> i32, but got fn (f32) -> f32
[15] Error: Type mismatch. Expected i32, but got f32
╭─[tests/check/type_mismatch.nemo:1:13]
│
2 │ global my_fn : fn (i32) -> i32 = \(x : f32) -> f32 { x }
 │ ──┬─
 │ ╰─── Type mismatch. Expected i32, but got f32
───╯

[15] Error: Type mismatch. Expected i32, but got f32
╭─[tests/check/type_mismatch.nemo:1:13]
│
2 │ global my_fn : fn (i32) -> i32 = \(x : f32) -> f32 { x }
 │ ─┬─
 │ ╰─── Type mismatch. Expected i32, but got f32
───╯

[15] Error: Type mismatch. Expected i32, but got f32
╭─[tests/check/type_mismatch.nemo:1:13]
│
[38;5;246m2 │[0m [38;5;249mg[0m[38;5;249ml[0m[38;5;249mo[0m[38;5;249mb[0m[38;5;249ma[0m[38;5;249ml[0m[38;5;249m [0m[38;5;249mm[0m[38;5;249my[0m[38;5;249m_[0m[38;5;249mf[0m[38;5;249mn[0m[38;5;249m [0m[38;5;249m:[0m[38;5;249m [0m[38;5;249mf[0m[38;5;249mn[0m[38;5;249m [0m[38;5;249m([0m[38;5;249mi[0m[38;5;249m3[0m[38;5;249m2[0m[38;5;249m)[0m[38;5;249m [0m[38;5;249m-[0m[38;5;249m>[0m[38;5;249m [0m[38;5;249mi[0m[38;5;249m3[0m[38;5;249m2[0m[38;5;249m [0m[38;5;249m=[0m[38;5;249m [0mfn (x : f32) -> f32 { x }
[38;5;240m │[0m ────────────┬────────────
[38;5;240m │[0m ╰────────────── Type mismatch. Expected fn (i32) -> i32, but got fn (f32) -> f32
[38;5;246m2 │[0m [38;5;249mg[0m[38;5;249ml[0m[38;5;249mo[0m[38;5;249mb[0m[38;5;249ma[0m[38;5;249ml[0m[38;5;249m [0m[38;5;249mm[0m[38;5;249my[0m[38;5;249m_[0m[38;5;249mf[0m[38;5;249mn[0m[38;5;249m [0m[38;5;249m:[0m[38;5;249m [0m[38;5;249mf[0m[38;5;249mn[0m[38;5;249m [0m[38;5;249m([0m[38;5;249mi[0m[38;5;249m3[0m[38;5;249m2[0m[38;5;249m)[0m[38;5;249m [0m[38;5;249m-[0m[38;5;249m>[0m[38;5;249m [0m[38;5;249mi[0m[38;5;249m3[0m[38;5;249m2[0m[38;5;249m [0m[38;5;249m=[0m[38;5;249m [0m[38;5;249m\[0m[38;5;249m([0m[38;5;249mx[0m[38;5;249m [0m[38;5;249m:[0m[38;5;249m [0m[38;5;249mf[0m[38;5;249m3[0m[38;5;249m2[0m[38;5;249m)[0m[38;5;249m [0m[38;5;249m-[0m[38;5;249m>[0m[38;5;249m [0m[38;5;249mf[0m[38;5;249m3[0m[38;5;249m2[0m[38;5;249m [0m[38;5;249m{[0m[38;5;249m [0mx [38;5;249m}[0m
[38;5;240m │[0m ─┬
[38;5;240m │[0m ╰── Type mismatch. Expected i32, but got f32
───╯


----- stderr -----
Error: "Check failed with 2 errors"
Error: "Check failed with 4 errors"
1 change: 1 addition & 0 deletions crates/cli/tests/snapshots/[email protected]
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ exit_code: 0
----- stdout -----
11
5
3
0

----- stderr -----
22 changes: 14 additions & 8 deletions crates/frontend/src/parser/grammar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ impl Progress {
}
}

#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum TypAnnot {
Optional,
Required,
}

pub fn prog(p: &mut Parser) {
while !p.at(SyntaxKind::EOF) {
module(p);
Expand Down Expand Up @@ -197,7 +203,7 @@ fn top_func(p: &mut Parser) {
p.error("expected a function name")
}
typ_param_list(p);
param_list(p);
param_list(p, TypAnnot::Required);
if p.eat(T![->]) && !typ(p).made_progress() {
p.error("expected a return type")
}
Expand Down Expand Up @@ -225,15 +231,15 @@ fn typ_param_list(p: &mut Parser) -> Progress {
Progress::Made
}

fn param_list(p: &mut Parser) {
fn param_list(p: &mut Parser, typ_annot_opt: TypAnnot) {
if !p.eat(SyntaxKind::L_PAREN) {
p.error("expected a parameter list");
return;
}
while p.at(SyntaxKind::IDENT) {
let c = p.checkpoint();
p.bump(SyntaxKind::IDENT);
if !typ_annot(p).made_progress() {
if !typ_annot(p).made_progress() && typ_annot_opt == TypAnnot::Required {
p.error("expected a type annotation")
}
p.eat(SyntaxKind::COMMA);
Expand Down Expand Up @@ -428,8 +434,8 @@ fn expr(p: &mut Parser) -> Progress {
match_expr(p);
return Progress::Made;
}
if p.at(T![fn]) {
func_expr(p);
if p.at(T![lambda]) {
lambda_expr(p);
return Progress::Made;
}
if p.at(T![return]) {
Expand Down Expand Up @@ -537,10 +543,10 @@ fn return_expr(p: &mut Parser) {
p.finish_at(c, SyntaxKind::EReturn)
}

fn func_expr(p: &mut Parser) {
fn lambda_expr(p: &mut Parser) {
let c = p.checkpoint();
p.bump(T![fn]);
param_list(p);
p.bump(T![lambda]);
param_list(p, TypAnnot::Optional);
if p.eat(T![->]) && !typ(p).made_progress() {
p.error("expected a return type")
}
Expand Down
4 changes: 4 additions & 0 deletions crates/frontend/src/parser/lexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ pub enum SyntaxKind {
#[token("/")]
SLASH,

#[token("\\")]
BACKSLASH,

#[token("=")]
EQUALS,

Expand Down Expand Up @@ -449,6 +452,7 @@ macro_rules ! T {
[exports] => { SyntaxKind::EXPORTS_KW };
[use] => { SyntaxKind::USE_KW };
[fn] => { SyntaxKind::FN_KW };
[lambda] => { SyntaxKind::BACKSLASH };
[global] => { SyntaxKind::GLOBAL_KW };
[import] => { SyntaxKind::IMPORT_KW };
[from] => { SyntaxKind::FROM_KW };
Expand Down
11 changes: 7 additions & 4 deletions crates/frontend/src/scip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,13 @@ fn index_occurrence_map(
local_id: local_supply.to_string().into(),
};
let symbol = sym.to_string();
(sym, SymbolInformation {
symbol,
..Default::default()
})
(
sym,
SymbolInformation {
symbol,
..Default::default()
},
)
});
if !sym.is_local() {
symbols.push(symbol_information.clone());
Expand Down
9 changes: 0 additions & 9 deletions crates/frontend/src/syntax/nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -632,18 +632,9 @@ pub struct ELambda {
pub(crate) syntax: SyntaxNode,
}
impl ELambda {
pub fn fn_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T![fn])
}
pub fn l_paren_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T!['('])
}
pub fn params(&self) -> AstChildren<Param> {
support::children(&self.syntax)
}
pub fn r_paren_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T![')'])
}
pub fn return_ty(&self) -> Option<Type> {
support::child(&self.syntax)
}
Expand Down
2 changes: 1 addition & 1 deletion crates/frontend/src/syntax/nodes.ungram
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ EIf = 'if' condition:Expr then_branch:EBlock 'else' else_branch:EBlock
EMatch = 'match' scrutinee:Expr '{' EMatchBranch* '}'
EMatchBranch = Pattern '=>' body:EBlock

ELambda = 'fn' '(' Param* ')' return_ty:Type body:Expr
ELambda = Param* return_ty:Type body:Expr

EBlock = '{' Declaration* '}'
EReturn = 'return' Expr
Expand Down
Loading

0 comments on commit 061002e

Please sign in to comment.