Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduces a ParamList AST node #43

Merged
merged 1 commit into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions crates/cli/tests/snapshots/check@cant_infer_lambda.nemo.snap
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ exit_code: 1
[27] Error: Can't infer type of unannotated lambda
╭─[tests/check/cant_infer_lambda.nemo:1:13]
│
2 │   let clos = \(z) { z };
 │
 │ ╰── Can't infer type of unannotated lambda
2 │   let clos = \(z) { z };
 │ ─┬─
 │ ╰── Can't infer type of unannotated lambda
───╯


Expand Down
2 changes: 2 additions & 0 deletions crates/frontend/src/parser/grammar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ fn typ_param_list(p: &mut Parser) -> Progress {
}

fn param_list(p: &mut Parser, typ_annot_opt: TypAnnot) {
let c = p.checkpoint();
if !p.eat(SyntaxKind::L_PAREN) {
p.error("expected a parameter list");
return;
Expand All @@ -249,6 +250,7 @@ fn param_list(p: &mut Parser, typ_annot_opt: TypAnnot) {
p.error("expected a closing paren");
// TODO recover
}
p.finish_at(c, SyntaxKind::ParamList);
}

fn qualifier(p: &mut Parser) -> Progress {
Expand Down
1 change: 1 addition & 0 deletions crates/frontend/src/parser/lexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ pub enum SyntaxKind {
// Composite nodes
Param,
ParamTy,
ParamList,
BinOp,
TyArgList,
EArgList,
Expand Down
69 changes: 46 additions & 23 deletions crates/frontend/src/syntax/nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,14 +186,8 @@ impl TopFn {
pub fn param_tys(&self) -> AstChildren<ParamTy> {
support::children(&self.syntax)
}
pub fn l_paren_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T!['('])
}
pub fn params(&self) -> AstChildren<Param> {
support::children(&self.syntax)
}
pub fn r_paren_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T![')'])
pub fn param_list(&self) -> Option<ParamList> {
support::child(&self.syntax)
}
pub fn arrow_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T![->])
Expand Down Expand Up @@ -248,18 +242,12 @@ impl StructField {
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Param {
pub struct ParamList {
pub(crate) syntax: SyntaxNode,
}
impl Param {
pub fn ident_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T![ident])
}
pub fn colon_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T![:])
}
pub fn ty(&self) -> Option<Type> {
support::child(&self.syntax)
impl ParamList {
pub fn params(&self) -> AstChildren<Param> {
support::children(&self.syntax)
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
Expand All @@ -278,6 +266,21 @@ impl EBlock {
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Param {
pub(crate) syntax: SyntaxNode,
}
impl Param {
pub fn ident_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T![ident])
}
pub fn colon_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T![:])
}
pub fn ty(&self) -> Option<Type> {
support::child(&self.syntax)
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TyInt {
pub(crate) syntax: SyntaxNode,
}
Expand Down Expand Up @@ -632,8 +635,8 @@ pub struct ELambda {
pub(crate) syntax: SyntaxNode,
}
impl ELambda {
pub fn params(&self) -> AstChildren<Param> {
support::children(&self.syntax)
pub fn param_list(&self) -> Option<ParamList> {
support::child(&self.syntax)
}
pub fn return_ty(&self) -> Option<Type> {
support::child(&self.syntax)
Expand Down Expand Up @@ -1118,9 +1121,9 @@ impl AstNode for StructField {
&self.syntax
}
}
impl AstNode for Param {
impl AstNode for ParamList {
fn can_cast(kind: SyntaxKind) -> bool {
kind == Param
kind == ParamList
}
fn cast(syntax: SyntaxNode) -> Option<Self> {
if Self::can_cast(syntax.kind()) {
Expand Down Expand Up @@ -1148,6 +1151,21 @@ impl AstNode for EBlock {
&self.syntax
}
}
impl AstNode for Param {
fn can_cast(kind: SyntaxKind) -> bool {
kind == Param
}
fn cast(syntax: SyntaxNode) -> Option<Self> {
if Self::can_cast(syntax.kind()) {
Some(Self { syntax })
} else {
None
}
}
fn syntax(&self) -> &SyntaxNode {
&self.syntax
}
}
impl AstNode for TyInt {
fn can_cast(kind: SyntaxKind) -> bool {
kind == TyInt
Expand Down Expand Up @@ -2321,7 +2339,7 @@ impl std::fmt::Display for StructField {
std::fmt::Display::fmt(self.syntax(), f)
}
}
impl std::fmt::Display for Param {
impl std::fmt::Display for ParamList {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self.syntax(), f)
}
Expand All @@ -2331,6 +2349,11 @@ impl std::fmt::Display for EBlock {
std::fmt::Display::fmt(self.syntax(), f)
}
}
impl std::fmt::Display for Param {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self.syntax(), f)
}
}
impl std::fmt::Display for TyInt {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self.syntax(), f)
Expand Down
5 changes: 3 additions & 2 deletions crates/frontend/src/syntax/nodes.ungram
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ StructField = 'ident' ':' Type

TopVariant = 'variant' 'upper_ident' type_params:ParamTy* '{' TopStruct* '}'

TopFn = 'fn' 'ident' ParamTy* '(' Param* ')' ('->' Type)? body:EBlock
TopFn = 'fn' 'ident' ParamTy* ParamList ('->' Type)? body:EBlock
ParamList = Param*
Param = 'ident' ':' Type
ParamTy = 'ident'

Expand Down Expand Up @@ -107,7 +108,7 @@ EIf = 'if' condition:Expr then_branch:EBlock 'else' else_branch:EBlock
EMatch = 'match' scrutinee:Expr '{' EMatchBranch* '}'
EMatchBranch = Pattern '=>' body:EBlock

ELambda = Param* return_ty:Type body:Expr
ELambda = ParamList return_ty:Type body:Expr

EBlock = '{' Declaration* '}'
EReturn = 'return' Expr
Expand Down
50 changes: 30 additions & 20 deletions crates/frontend/src/types/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -577,8 +577,11 @@ impl Typechecker<'_> {
ty_args.push(name)
}

let Some(param_list) = top_fn.param_list() else {
continue;
};
let mut arguments = vec![];
for param in top_fn.params() {
for param in param_list.params() {
let ty = param
.ty()
.map(|t| self.check_ty(errors, &mut scope, &t))
Expand Down Expand Up @@ -785,7 +788,10 @@ impl Typechecker<'_> {
scope.add_type_var(v, *name);
}

for (param, ty) in top_fn.params().zip(func_ty.arguments.into_iter()) {
let Some(param_list) = top_fn.param_list() else {
continue;
};
for (param, ty) in param_list.params().zip(func_ty.arguments.into_iter()) {
let Some(ident_tkn) = param.ident_token() else {
funcs = None;
continue;
Expand Down Expand Up @@ -1052,23 +1058,24 @@ impl Typechecker<'_> {
scope.enter_block();
let prev_return_ty = scope.set_return_type(ty_func.result.clone());
let mut params = HashSet::new();
for param in lambda.params() {
let Some(name_tkn) = param.ident_token() else {
continue;
};
let (name, sym) = self.name_supply.local_idx(&name_tkn);
let ty = match param.ty() {
None => {
// TODO: report a single error spanning the entire param list
errors.report(&param, CantInferLambda);
Ty::Error
}
Some(t) => self.check_ty(errors, scope, &t),
};
builder.params(Some((name, ty.clone())));
ty_func.arguments.push(ty.clone());
params.insert(name);
scope.add_var(sym, ty, name);
if let Some(param_list) = lambda.param_list() {
for param in param_list.params() {
let Some(name_tkn) = param.ident_token() else {
continue;
};
let (name, sym) = self.name_supply.local_idx(&name_tkn);
let ty = match param.ty() {
None => {
errors.report(&param_list, CantInferLambda);
Ty::Error
}
Some(t) => self.check_ty(errors, scope, &t),
};
builder.params(Some((name, ty.clone())));
ty_func.arguments.push(ty.clone());
params.insert(name);
scope.add_var(sym, ty, name);
}
}

if let Some(body) = lambda.body() {
Expand Down Expand Up @@ -1983,7 +1990,10 @@ impl Typechecker<'_> {
scope.enter_block();
let prev_return_ty = scope.set_return_type(expected.result.clone());
// TODO: Check for duplicate parameter names
let params: Vec<Param> = expr.params().collect();
let params: Vec<Param> = expr
.param_list()
.map(|pl| pl.params().collect())
.unwrap_or_default();
if params.len() != expected.arguments.len() {
errors.report(
// TODO: Make a node for the arg list, so we can report it instead of the whole lambda
Expand Down
Loading