Skip to content

Commit

Permalink
Introduces a ParamList AST node (#43)
Browse files Browse the repository at this point in the history
This lets us report better error spans for unannotated lambdas
  • Loading branch information
kritzcreek authored Oct 31, 2024
1 parent 32bc746 commit 4d2d575
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 48 deletions.
6 changes: 3 additions & 3 deletions crates/cli/tests/snapshots/check@cant_infer_lambda.nemo.snap
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ exit_code: 1
[27] Error: Can't infer type of unannotated lambda
╭─[tests/check/cant_infer_lambda.nemo:1:13]
│
[38;5;246m2 │[0m [38;5;249m [0m[38;5;249m [0m[38;5;249ml[0m[38;5;249me[0m[38;5;249mt[0m[38;5;249m [0m[38;5;249mc[0m[38;5;249ml[0m[38;5;249mo[0m[38;5;249ms[0m[38;5;249m [0m[38;5;249m=[0m[38;5;249m [0m[38;5;249m\[0m[38;5;249m([0mz[38;5;249m)[0m[38;5;249m [0m[38;5;249m{[0m[38;5;249m [0m[38;5;249mz[0m[38;5;249m [0m[38;5;249m}[0m[38;5;249m;[0m
[38;5;240m │[0m
[38;5;240m │[0m ╰── Can't infer type of unannotated lambda
[38;5;246m2 │[0m [38;5;249m [0m[38;5;249m [0m[38;5;249ml[0m[38;5;249me[0m[38;5;249mt[0m[38;5;249m [0m[38;5;249mc[0m[38;5;249ml[0m[38;5;249mo[0m[38;5;249ms[0m[38;5;249m [0m[38;5;249m=[0m[38;5;249m [0m[38;5;249m\[0m(z)[38;5;249m [0m[38;5;249m{[0m[38;5;249m [0m[38;5;249mz[0m[38;5;249m [0m[38;5;249m}[0m[38;5;249m;[0m
[38;5;240m │[0m ─┬─
[38;5;240m │[0m ╰── Can't infer type of unannotated lambda
───╯


Expand Down
2 changes: 2 additions & 0 deletions crates/frontend/src/parser/grammar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ fn typ_param_list(p: &mut Parser) -> Progress {
}

fn param_list(p: &mut Parser, typ_annot_opt: TypAnnot) {
let c = p.checkpoint();
if !p.eat(SyntaxKind::L_PAREN) {
p.error("expected a parameter list");
return;
Expand All @@ -249,6 +250,7 @@ fn param_list(p: &mut Parser, typ_annot_opt: TypAnnot) {
p.error("expected a closing paren");
// TODO recover
}
p.finish_at(c, SyntaxKind::ParamList);
}

fn qualifier(p: &mut Parser) -> Progress {
Expand Down
1 change: 1 addition & 0 deletions crates/frontend/src/parser/lexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ pub enum SyntaxKind {
// Composite nodes
Param,
ParamTy,
ParamList,
BinOp,
TyArgList,
EArgList,
Expand Down
69 changes: 46 additions & 23 deletions crates/frontend/src/syntax/nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,14 +186,8 @@ impl TopFn {
pub fn param_tys(&self) -> AstChildren<ParamTy> {
support::children(&self.syntax)
}
pub fn l_paren_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T!['('])
}
pub fn params(&self) -> AstChildren<Param> {
support::children(&self.syntax)
}
pub fn r_paren_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T![')'])
pub fn param_list(&self) -> Option<ParamList> {
support::child(&self.syntax)
}
pub fn arrow_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T![->])
Expand Down Expand Up @@ -248,18 +242,12 @@ impl StructField {
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Param {
pub struct ParamList {
pub(crate) syntax: SyntaxNode,
}
impl Param {
pub fn ident_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T![ident])
}
pub fn colon_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T![:])
}
pub fn ty(&self) -> Option<Type> {
support::child(&self.syntax)
impl ParamList {
pub fn params(&self) -> AstChildren<Param> {
support::children(&self.syntax)
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
Expand All @@ -278,6 +266,21 @@ impl EBlock {
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Param {
pub(crate) syntax: SyntaxNode,
}
impl Param {
pub fn ident_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T![ident])
}
pub fn colon_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T![:])
}
pub fn ty(&self) -> Option<Type> {
support::child(&self.syntax)
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TyInt {
pub(crate) syntax: SyntaxNode,
}
Expand Down Expand Up @@ -632,8 +635,8 @@ pub struct ELambda {
pub(crate) syntax: SyntaxNode,
}
impl ELambda {
pub fn params(&self) -> AstChildren<Param> {
support::children(&self.syntax)
pub fn param_list(&self) -> Option<ParamList> {
support::child(&self.syntax)
}
pub fn return_ty(&self) -> Option<Type> {
support::child(&self.syntax)
Expand Down Expand Up @@ -1118,9 +1121,9 @@ impl AstNode for StructField {
&self.syntax
}
}
impl AstNode for Param {
impl AstNode for ParamList {
fn can_cast(kind: SyntaxKind) -> bool {
kind == Param
kind == ParamList
}
fn cast(syntax: SyntaxNode) -> Option<Self> {
if Self::can_cast(syntax.kind()) {
Expand Down Expand Up @@ -1148,6 +1151,21 @@ impl AstNode for EBlock {
&self.syntax
}
}
impl AstNode for Param {
fn can_cast(kind: SyntaxKind) -> bool {
kind == Param
}
fn cast(syntax: SyntaxNode) -> Option<Self> {
if Self::can_cast(syntax.kind()) {
Some(Self { syntax })
} else {
None
}
}
fn syntax(&self) -> &SyntaxNode {
&self.syntax
}
}
impl AstNode for TyInt {
fn can_cast(kind: SyntaxKind) -> bool {
kind == TyInt
Expand Down Expand Up @@ -2321,7 +2339,7 @@ impl std::fmt::Display for StructField {
std::fmt::Display::fmt(self.syntax(), f)
}
}
impl std::fmt::Display for Param {
impl std::fmt::Display for ParamList {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self.syntax(), f)
}
Expand All @@ -2331,6 +2349,11 @@ impl std::fmt::Display for EBlock {
std::fmt::Display::fmt(self.syntax(), f)
}
}
impl std::fmt::Display for Param {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self.syntax(), f)
}
}
impl std::fmt::Display for TyInt {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self.syntax(), f)
Expand Down
5 changes: 3 additions & 2 deletions crates/frontend/src/syntax/nodes.ungram
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ StructField = 'ident' ':' Type

TopVariant = 'variant' 'upper_ident' type_params:ParamTy* '{' TopStruct* '}'

TopFn = 'fn' 'ident' ParamTy* '(' Param* ')' ('->' Type)? body:EBlock
TopFn = 'fn' 'ident' ParamTy* ParamList ('->' Type)? body:EBlock
ParamList = Param*
Param = 'ident' ':' Type
ParamTy = 'ident'

Expand Down Expand Up @@ -107,7 +108,7 @@ EIf = 'if' condition:Expr then_branch:EBlock 'else' else_branch:EBlock
EMatch = 'match' scrutinee:Expr '{' EMatchBranch* '}'
EMatchBranch = Pattern '=>' body:EBlock

ELambda = Param* return_ty:Type body:Expr
ELambda = ParamList return_ty:Type body:Expr

EBlock = '{' Declaration* '}'
EReturn = 'return' Expr
Expand Down
50 changes: 30 additions & 20 deletions crates/frontend/src/types/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -577,8 +577,11 @@ impl Typechecker<'_> {
ty_args.push(name)
}

let Some(param_list) = top_fn.param_list() else {
continue;
};
let mut arguments = vec![];
for param in top_fn.params() {
for param in param_list.params() {
let ty = param
.ty()
.map(|t| self.check_ty(errors, &mut scope, &t))
Expand Down Expand Up @@ -785,7 +788,10 @@ impl Typechecker<'_> {
scope.add_type_var(v, *name);
}

for (param, ty) in top_fn.params().zip(func_ty.arguments.into_iter()) {
let Some(param_list) = top_fn.param_list() else {
continue;
};
for (param, ty) in param_list.params().zip(func_ty.arguments.into_iter()) {
let Some(ident_tkn) = param.ident_token() else {
funcs = None;
continue;
Expand Down Expand Up @@ -1052,23 +1058,24 @@ impl Typechecker<'_> {
scope.enter_block();
let prev_return_ty = scope.set_return_type(ty_func.result.clone());
let mut params = HashSet::new();
for param in lambda.params() {
let Some(name_tkn) = param.ident_token() else {
continue;
};
let (name, sym) = self.name_supply.local_idx(&name_tkn);
let ty = match param.ty() {
None => {
// TODO: report a single error spanning the entire param list
errors.report(&param, CantInferLambda);
Ty::Error
}
Some(t) => self.check_ty(errors, scope, &t),
};
builder.params(Some((name, ty.clone())));
ty_func.arguments.push(ty.clone());
params.insert(name);
scope.add_var(sym, ty, name);
if let Some(param_list) = lambda.param_list() {
for param in param_list.params() {
let Some(name_tkn) = param.ident_token() else {
continue;
};
let (name, sym) = self.name_supply.local_idx(&name_tkn);
let ty = match param.ty() {
None => {
errors.report(&param_list, CantInferLambda);
Ty::Error
}
Some(t) => self.check_ty(errors, scope, &t),
};
builder.params(Some((name, ty.clone())));
ty_func.arguments.push(ty.clone());
params.insert(name);
scope.add_var(sym, ty, name);
}
}

if let Some(body) = lambda.body() {
Expand Down Expand Up @@ -1983,7 +1990,10 @@ impl Typechecker<'_> {
scope.enter_block();
let prev_return_ty = scope.set_return_type(expected.result.clone());
// TODO: Check for duplicate parameter names
let params: Vec<Param> = expr.params().collect();
let params: Vec<Param> = expr
.param_list()
.map(|pl| pl.params().collect())
.unwrap_or_default();
if params.len() != expected.arguments.len() {
errors.report(
// TODO: Make a node for the arg list, so we can report it instead of the whole lambda
Expand Down

0 comments on commit 4d2d575

Please sign in to comment.