From 67bfda07412b7570e99cfce381268ece58a82aad Mon Sep 17 00:00:00 2001 From: Christoph Hegemann Date: Thu, 31 Oct 2024 18:17:15 +0100 Subject: [PATCH] Introduces a ParamList AST node This lets us report better error spans for unannotated lambdas --- .../check@cant_infer_lambda.nemo.snap | 6 +- crates/frontend/src/parser/grammar.rs | 2 + crates/frontend/src/parser/lexer.rs | 1 + crates/frontend/src/syntax/nodes.rs | 69 ++++++++++++------- crates/frontend/src/syntax/nodes.ungram | 5 +- crates/frontend/src/types/check.rs | 50 ++++++++------ 6 files changed, 85 insertions(+), 48 deletions(-) diff --git a/crates/cli/tests/snapshots/check@cant_infer_lambda.nemo.snap b/crates/cli/tests/snapshots/check@cant_infer_lambda.nemo.snap index 5d28b50..5116dcd 100644 --- a/crates/cli/tests/snapshots/check@cant_infer_lambda.nemo.snap +++ b/crates/cli/tests/snapshots/check@cant_infer_lambda.nemo.snap @@ -15,9 +15,9 @@ exit_code: 1 [27] Error: Can't infer type of unannotated lambda â•­─[tests/check/cant_infer_lambda.nemo:1:13] │ - 2 │   let clos = \(z) { z }; -  │ ┬ -  │ ╰── Can't infer type of unannotated lambda + 2 │   let clos = \(z) { z }; +  │ ─┬─ +  │ ╰─── Can't infer type of unannotated lambda ───╯ diff --git a/crates/frontend/src/parser/grammar.rs b/crates/frontend/src/parser/grammar.rs index ff913d1..71d1e92 100644 --- a/crates/frontend/src/parser/grammar.rs +++ b/crates/frontend/src/parser/grammar.rs @@ -232,6 +232,7 @@ fn typ_param_list(p: &mut Parser) -> Progress { } fn param_list(p: &mut Parser, typ_annot_opt: TypAnnot) { + let c = p.checkpoint(); if !p.eat(SyntaxKind::L_PAREN) { p.error("expected a parameter list"); return; @@ -249,6 +250,7 @@ fn param_list(p: &mut Parser, typ_annot_opt: TypAnnot) { p.error("expected a closing paren"); // TODO recover } + p.finish_at(c, SyntaxKind::ParamList); } fn qualifier(p: &mut Parser) -> Progress { diff --git a/crates/frontend/src/parser/lexer.rs b/crates/frontend/src/parser/lexer.rs index 220054b..9847222 100644 --- a/crates/frontend/src/parser/lexer.rs +++ b/crates/frontend/src/parser/lexer.rs @@ -307,6 +307,7 @@ pub enum SyntaxKind { // Composite nodes Param, ParamTy, + ParamList, BinOp, TyArgList, EArgList, diff --git a/crates/frontend/src/syntax/nodes.rs b/crates/frontend/src/syntax/nodes.rs index 71734f7..f562d72 100644 --- a/crates/frontend/src/syntax/nodes.rs +++ b/crates/frontend/src/syntax/nodes.rs @@ -186,14 +186,8 @@ impl TopFn { pub fn param_tys(&self) -> AstChildren { support::children(&self.syntax) } - pub fn l_paren_token(&self) -> Option { - support::token(&self.syntax, T!['(']) - } - pub fn params(&self) -> AstChildren { - support::children(&self.syntax) - } - pub fn r_paren_token(&self) -> Option { - support::token(&self.syntax, T![')']) + pub fn param_list(&self) -> Option { + support::child(&self.syntax) } pub fn arrow_token(&self) -> Option { support::token(&self.syntax, T![->]) @@ -248,18 +242,12 @@ impl StructField { } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct Param { +pub struct ParamList { pub(crate) syntax: SyntaxNode, } -impl Param { - pub fn ident_token(&self) -> Option { - support::token(&self.syntax, T![ident]) - } - pub fn colon_token(&self) -> Option { - support::token(&self.syntax, T![:]) - } - pub fn ty(&self) -> Option { - support::child(&self.syntax) +impl ParamList { + pub fn params(&self) -> AstChildren { + support::children(&self.syntax) } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -278,6 +266,21 @@ impl EBlock { } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Param { + pub(crate) syntax: SyntaxNode, +} +impl Param { + pub fn ident_token(&self) -> Option { + support::token(&self.syntax, T![ident]) + } + pub fn colon_token(&self) -> Option { + support::token(&self.syntax, T![:]) + } + pub fn ty(&self) -> Option { + support::child(&self.syntax) + } +} +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct TyInt { pub(crate) syntax: SyntaxNode, } @@ -632,8 +635,8 @@ pub struct ELambda { pub(crate) syntax: SyntaxNode, } impl ELambda { - pub fn params(&self) -> AstChildren { - support::children(&self.syntax) + pub fn param_list(&self) -> Option { + support::child(&self.syntax) } pub fn return_ty(&self) -> Option { support::child(&self.syntax) @@ -1118,9 +1121,9 @@ impl AstNode for StructField { &self.syntax } } -impl AstNode for Param { +impl AstNode for ParamList { fn can_cast(kind: SyntaxKind) -> bool { - kind == Param + kind == ParamList } fn cast(syntax: SyntaxNode) -> Option { if Self::can_cast(syntax.kind()) { @@ -1148,6 +1151,21 @@ impl AstNode for EBlock { &self.syntax } } +impl AstNode for Param { + fn can_cast(kind: SyntaxKind) -> bool { + kind == Param + } + fn cast(syntax: SyntaxNode) -> Option { + if Self::can_cast(syntax.kind()) { + Some(Self { syntax }) + } else { + None + } + } + fn syntax(&self) -> &SyntaxNode { + &self.syntax + } +} impl AstNode for TyInt { fn can_cast(kind: SyntaxKind) -> bool { kind == TyInt @@ -2321,7 +2339,7 @@ impl std::fmt::Display for StructField { std::fmt::Display::fmt(self.syntax(), f) } } -impl std::fmt::Display for Param { +impl std::fmt::Display for ParamList { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self.syntax(), f) } @@ -2331,6 +2349,11 @@ impl std::fmt::Display for EBlock { std::fmt::Display::fmt(self.syntax(), f) } } +impl std::fmt::Display for Param { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(self.syntax(), f) + } +} impl std::fmt::Display for TyInt { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self.syntax(), f) diff --git a/crates/frontend/src/syntax/nodes.ungram b/crates/frontend/src/syntax/nodes.ungram index 3469719..d19269c 100644 --- a/crates/frontend/src/syntax/nodes.ungram +++ b/crates/frontend/src/syntax/nodes.ungram @@ -30,7 +30,8 @@ StructField = 'ident' ':' Type TopVariant = 'variant' 'upper_ident' type_params:ParamTy* '{' TopStruct* '}' -TopFn = 'fn' 'ident' ParamTy* '(' Param* ')' ('->' Type)? body:EBlock +TopFn = 'fn' 'ident' ParamTy* ParamList ('->' Type)? body:EBlock +ParamList = Param* Param = 'ident' ':' Type ParamTy = 'ident' @@ -107,7 +108,7 @@ EIf = 'if' condition:Expr then_branch:EBlock 'else' else_branch:EBlock EMatch = 'match' scrutinee:Expr '{' EMatchBranch* '}' EMatchBranch = Pattern '=>' body:EBlock -ELambda = Param* return_ty:Type body:Expr +ELambda = ParamList return_ty:Type body:Expr EBlock = '{' Declaration* '}' EReturn = 'return' Expr diff --git a/crates/frontend/src/types/check.rs b/crates/frontend/src/types/check.rs index 8359cd3..8d0b4df 100644 --- a/crates/frontend/src/types/check.rs +++ b/crates/frontend/src/types/check.rs @@ -577,8 +577,11 @@ impl Typechecker<'_> { ty_args.push(name) } + let Some(param_list) = top_fn.param_list() else { + continue; + }; let mut arguments = vec![]; - for param in top_fn.params() { + for param in param_list.params() { let ty = param .ty() .map(|t| self.check_ty(errors, &mut scope, &t)) @@ -785,7 +788,10 @@ impl Typechecker<'_> { scope.add_type_var(v, *name); } - for (param, ty) in top_fn.params().zip(func_ty.arguments.into_iter()) { + let Some(param_list) = top_fn.param_list() else { + continue; + }; + for (param, ty) in param_list.params().zip(func_ty.arguments.into_iter()) { let Some(ident_tkn) = param.ident_token() else { funcs = None; continue; @@ -1052,23 +1058,24 @@ impl Typechecker<'_> { scope.enter_block(); let prev_return_ty = scope.set_return_type(ty_func.result.clone()); let mut params = HashSet::new(); - for param in lambda.params() { - let Some(name_tkn) = param.ident_token() else { - continue; - }; - let (name, sym) = self.name_supply.local_idx(&name_tkn); - let ty = match param.ty() { - None => { - // TODO: report a single error spanning the entire param list - errors.report(¶m, CantInferLambda); - Ty::Error - } - Some(t) => self.check_ty(errors, scope, &t), - }; - builder.params(Some((name, ty.clone()))); - ty_func.arguments.push(ty.clone()); - params.insert(name); - scope.add_var(sym, ty, name); + if let Some(param_list) = lambda.param_list() { + for param in param_list.params() { + let Some(name_tkn) = param.ident_token() else { + continue; + }; + let (name, sym) = self.name_supply.local_idx(&name_tkn); + let ty = match param.ty() { + None => { + errors.report(¶m_list, CantInferLambda); + Ty::Error + } + Some(t) => self.check_ty(errors, scope, &t), + }; + builder.params(Some((name, ty.clone()))); + ty_func.arguments.push(ty.clone()); + params.insert(name); + scope.add_var(sym, ty, name); + } } if let Some(body) = lambda.body() { @@ -1983,7 +1990,10 @@ impl Typechecker<'_> { scope.enter_block(); let prev_return_ty = scope.set_return_type(expected.result.clone()); // TODO: Check for duplicate parameter names - let params: Vec = expr.params().collect(); + let params: Vec = expr + .param_list() + .map(|pl| pl.params().collect()) + .unwrap_or_default(); if params.len() != expected.arguments.len() { errors.report( // TODO: Make a node for the arg list, so we can report it instead of the whole lambda