Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optional arguments #223

Open
wants to merge 15 commits into
base: mlscript
Choose a base branch
from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ project/Dependencies.scala
project/metals.sbt
**.worksheet.sc
.DS_Store
.idea/
16 changes: 16 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "java",
"request": "attach",
"name": "gqd debugger",
// "projectName": "mlscript",
"hostName": "localhost",
"port": "8000"
}
]
}
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@
"strings": "off"
}
},
"files.autoSave": "off"
"files.autoSave": "afterDelay"
}
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class ClassLifter(logDebugMsg: Boolean = false) {
selPath2Term(l.map(x => genParName(x.name)).updated(0, "this").reverse, v)
})
}
private def toFldsEle(trm: Term): (Option[Var], Fld) = (None, Fld(FldFlags(false, false, false), trm))
private def toFldsEle(trm: Term): (Option[Var], Fld) = (None, Fld(FldFlags(false, false, false, false), trm))

def getSupClsInfoByTerm(parentTerm: Term): (List[TypeName], List[(Var, Fld)]) = parentTerm match{
case Var(nm) => List(TypeName(nm)) -> Nil
Expand Down Expand Up @@ -498,7 +498,7 @@ class ClassLifter(logDebugMsg: Boolean = false) {
private def liftTypeField(target: Field)(using ctx: LocalContext, cache: ClassCache, globFuncs: Map[Var, (Var, LocalContext)], outer: Option[ClassInfoCache]): (Field, LocalContext) = {
val (inT, iCtx) = target.in.map(liftType).unzip
val (outT, oCtx) = liftType(target.out)
Field(inT, outT) -> (iCtx.getOrElse(emptyCtx) ++ oCtx)
Field(inT, outT, false) -> (iCtx.getOrElse(emptyCtx) ++ oCtx)
}

private def liftType(target: Type)(using ctx: LocalContext, cache: ClassCache, globFuncs: Map[Var, (Var, LocalContext)], outer: Option[ClassInfoCache]): (Type, LocalContext) = target match{
Expand Down
47 changes: 36 additions & 11 deletions shared/src/main/scala/mlscript/ConstraintSolver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class ConstraintSolver extends NormalForms { self: Typer =>
Nil)
val ty = d.typeSignature
S(
if (d.fd.isMut) FieldType(S(ty), ty)(d.prov)
if (d.fd.isMut) FieldType(S(ty), ty, false)(d.prov)
else ty.toUpper(d.prov)
)
case S(p: NuParam) =>
Expand Down Expand Up @@ -512,6 +512,7 @@ class ConstraintSolver extends NormalForms { self: Typer =>
case (LhsRefined(S(Without(b, _)), _, _, _), RhsBot) => rec(b, BotType, true)
case (LhsTop, _) | (LhsRefined(N, empty(), RecordType(Nil), empty()), _) =>
// TODO ^ actually get rid of LhsTop and RhsBot...? (might make constraint solving slower)
println(s"REPORT ERROR 515")
reportError()
case (LhsRefined(_, ts, _, trs), RhsBases(pts, _, _)) if ts.exists(pts.contains) => ()

Expand All @@ -530,6 +531,7 @@ class ConstraintSolver extends NormalForms { self: Typer =>
case (_, RhsBases(_, _, trs)) if trs.nonEmpty => die

case (_, RhsBot) | (_, RhsBases(Nil, N, _)) =>
println("REPORT ERRROR LINE 534")
reportError()

case (LhsRefined(S(f0@FunctionType(l0, r0)), ts, r, _)
Expand All @@ -554,7 +556,10 @@ class ConstraintSolver extends NormalForms { self: Typer =>
lit match {
case _: IntLit | _: DecLit => rec(fldTy.lb.getOrElse(TopType), DecType, false)
case _: StrLit => rec(fldTy.lb.getOrElse(TopType), StrType, false)
case _: UnitLit => reportError()
case _: UnitLit => {
println(s"REPORT ERROR 560")
reportError()
}
}

// * This deals with the implicit Eql type member for user-defined classes.
Expand Down Expand Up @@ -604,7 +609,10 @@ class ConstraintSolver extends NormalForms { self: Typer =>
case S(Without(b, ns)) =>
if (ns(n)) rec(b, RhsBases(ots, N, trs).toType(), true)
else rec(b, done_rs.toType(), true)
case _ => reportError()
case _ => {
println("REPORT ERROR 613")
reportError()
}
}
}
case (LhsRefined(N, ts, r, trs), RhsBases(pts, N, trs2)) =>
Expand All @@ -614,9 +622,15 @@ class ConstraintSolver extends NormalForms { self: Typer =>
case _ => Nil
}.contains(p.id)))
println(s"OK $ts <: $pts")
else reportError()
else {
println(s"REPORT ERROR 624")
reportError()
}
case (LhsRefined(N, ts, r, _), RhsBases(pts, S(L(_: FunctionType | _: ArrayBase)), _)) =>
reportError()
{
println(s"REPORT ERROR 627")
reportError()
}
case (LhsRefined(S(b: TupleType), ts, r, _), RhsBases(pts, S(L(ty: TupleType)), _))
if b.fields.size === ty.fields.size =>
(b.fields.unzip._2 lazyZip ty.fields.unzip._2).foreach { (l, r) =>
Expand Down Expand Up @@ -952,7 +966,7 @@ class ConstraintSolver extends NormalForms { self: Typer =>
}


case (TupleType(fs0), TupleType(fs1)) if fs0.size === fs1.size => // TODO generalize (coerce compatible tuples)
case (t0 @ TupleType(fs0), t1 @ TupleType(fs1)) if t0.isLengthCompatibleWith(t1) => {
fs0.lazyZip(fs1).foreach { case ((ln, l), (rn, r)) =>
ln.foreach { ln => rn.foreach { rn =>
if (ln =/= rn) err(
Expand All @@ -962,6 +976,7 @@ class ConstraintSolver extends NormalForms { self: Typer =>
recLb(r, l)
rec(l.ub, r.ub, false)
}
}
case (t: ArrayBase, a: ArrayType) =>
recLb(a.inner, t.inner)
rec(t.inner.ub, a.inner.ub, false)
Expand All @@ -984,7 +999,10 @@ class ConstraintSolver extends NormalForms { self: Typer =>
case (RecordType(fs0), RecordType(fs1)) =>
fs1.foreach { case (n1, t1) =>
fs0.find(_._1 === n1).fold {
reportError()
{
println(s"REPORTERROR 933")
reportError()
}
} { case (n0, t0) =>
recLb(t1, t0)
rec(t0.ub, t1.ub, false)
Expand Down Expand Up @@ -1040,7 +1058,10 @@ class ConstraintSolver extends NormalForms { self: Typer =>
if (tr1.mayHaveTransitiveSelfType) rec(tr1.expand, tr2.expand, true)
else (tr1.mkClsTag, tr2.mkClsTag) match {
case (S(tag1), S(tag2)) if !(tag1 <:< tag2) =>
reportError()
{
println(s"REPORTERROR 992")
reportError()
}
case _ =>
rec(tr1.expand, tr2.expand, true)
}
Expand Down Expand Up @@ -1139,14 +1160,17 @@ class ConstraintSolver extends NormalForms { self: Typer =>
goToWork(lhs, rhs)
case (_: ClassTag | _: TraitTag, _: TraitTag) =>
goToWork(lhs, rhs)
case _ => reportError()
case _ => {
println(s"REPORTERROR 1093")
reportError()
}
}}
}}()

def reportError(failureOpt: Opt[Message] = N)(implicit cctx: ConCtx, ctx: Ctx): Unit = {
val lhs = cctx._1.head
val rhs = cctx._2.head

println(s"CONSTRAINT FAILURE: $lhs <: $rhs")
// println(s"CTX: ${cctx.map(_.map(lr => s"${lr._1} <: ${lr._2} [${lr._1.prov}] [${lr._2.prov}]"))}")

Expand Down Expand Up @@ -1309,7 +1333,6 @@ class ConstraintSolver extends NormalForms { self: Typer =>
case _ => doesntMatch(rhs)
})


val mismatchMessage =
msg"Type mismatch in ${prov.desc}:" -> show(prov.loco) :: (
msg"${printProv(lhsProv)} `${lhs.expPos}` $failure"
Expand All @@ -1321,6 +1344,8 @@ class ConstraintSolver extends NormalForms { self: Typer =>
msg"but it flows into ${l.prov.desc}$expTyMsg" -> show(l.prov.loco) :: Nil
}.toList.flatten

println(s"dbg [ConstraintSolver]: ${flowHint}")

val constraintProvenanceHints = mk_constraintProvenanceHints

var first = true
Expand Down
6 changes: 3 additions & 3 deletions shared/src/main/scala/mlscript/JSBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1202,10 +1202,10 @@ abstract class JSBackend {

def prepare(nme: Str, fs: Ls[Opt[Var] -> Fld], pars: Ls[Term], unit: TypingUnit) = {
val params = fs.map {
case (S(nme), Fld(FldFlags(mut, spec, _), trm)) =>
case (S(nme), Fld(FldFlags(mut, spec, opt, _), trm)) =>
val ty = tt(trm)
nme -> Field(if (mut) S(ty) else N, ty)
case (N, Fld(FldFlags(mut, spec, _), nme: Var)) => nme -> Field(if (mut) S(Bot) else N, Top)
nme -> Field(if (mut) S(ty) else N, ty, opt)
case (N, Fld(FldFlags(mut, spec, opt, _), nme: Var)) => nme -> Field(if (mut) S(Bot) else N, Top, opt)
case _ => die
}
val publicCtors = fs.filter {
Expand Down
26 changes: 13 additions & 13 deletions shared/src/main/scala/mlscript/MLParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class MLParser(origin: Origin, indent: Int = 0, recordLocations: Bool = true) {
}
def toParamsTy(t: Type): Tuple = t match {
case t: Tuple => t
case _ => Tuple((N, Field(None, t)) :: Nil)
case _ => Tuple((N, Field(None, t, false)) :: Nil)
}

def letter[p: P] = P( lowercase | uppercase )
Expand Down Expand Up @@ -69,14 +69,14 @@ class MLParser(origin: Origin, indent: Int = 0, recordLocations: Bool = true) {

def parens[p: P]: P[Term] = locate(P( "(" ~/ parenCell.rep(0, ",") ~ ",".!.? ~ ")" ).map {
case (Seq(Right(t -> false)), N) => Bra(false, t)
case (Seq(Right(t -> true)), N) => Tup(N -> Fld(FldFlags(true, false, false), t) :: Nil) // ? single tuple with mutable
case (Seq(Right(t -> true)), N) => Tup(N -> Fld(FldFlags(true, false, false, false), t) :: Nil) // ? single tuple with mutable
case (ts, _) =>
if (ts.forall(_.isRight)) Tup(ts.iterator.map {
case R(f) => N -> Fld(FldFlags(f._2, false, false), f._1)
case R(f) => N -> Fld(FldFlags(f._2, false, false, false), f._1)
case _ => die // left unreachable
}.toList)
else Splc(ts.map {
case R((v, m)) => R(Fld(FldFlags(m, false, false), v))
case R((v, m)) => R(Fld(FldFlags(m, false, false, false), v))
case L(spl) => L(spl)
}.toList)
})
Expand Down Expand Up @@ -106,8 +106,8 @@ class MLParser(origin: Origin, indent: Int = 0, recordLocations: Bool = true) {
"{" ~/ (kw("mut").!.? ~ fieldName ~ "=" ~ term map L.apply).|(kw("mut").!.? ~
variable map R.apply).rep(sep = ";" | ",") ~ "}"
).map { fs => Rcd(fs.map{
case L((mut, v, t)) => v -> Fld(FldFlags(mut.isDefined, false, false), t)
case R(mut -> id) => id -> Fld(FldFlags(mut.isDefined, false, false), id) }.toList)})
case L((mut, v, t)) => v -> Fld(FldFlags(mut.isDefined, false, false, false), t)
case R(mut -> id) => id -> Fld(FldFlags(mut.isDefined, false, false, false), id) }.toList)})

def fun[p: P]: P[Term] = locate(P( kw("fun") ~/ term ~ "->" ~ term ).map(nb => Lam(toParams(nb._1), nb._2)))

Expand Down Expand Up @@ -274,8 +274,8 @@ class MLParser(origin: Origin, indent: Int = 0, recordLocations: Bool = true) {
def rcd[p: P]: P[Record] =
locate(P( "{" ~/ ( kw("mut").!.? ~ fieldName ~ ":" ~ ty).rep(sep = ";") ~ "}" )
.map(_.toList.map {
case (None, v, t) => v -> Field(None, t)
case (Some(_), v, t) => v -> Field(Some(t), t)
case (None, v, t) => v -> Field(None, t, false)
case (Some(_), v, t) => v -> Field(Some(t), t, false)
} pipe Record))

def parTyCell[p: P]: P[Either[Type, (Type, Boolean)]] = (("..." | kw("mut")).!.? ~ ty). map {
Expand All @@ -289,14 +289,14 @@ class MLParser(origin: Origin, indent: Int = 0, recordLocations: Bool = true) {
case (fs, _) =>
if (fs.forall(_._2.isRight))
Tuple(fs.map {
case (l, Right(t -> false)) => l -> Field(None, t)
case (l, Right(t -> true)) => l -> Field(Some(t), t)
case (l, Right(t -> false)) => l -> Field(None, t, false)
case (l, Right(t -> true)) => l -> Field(Some(t), t, false)
case _ => ??? // unreachable
})
else Splice(fs.map{ _._2 match {
case L(l) => L(l)
case R(r -> true) => R(Field(Some(r), r))
case R(r -> false) => R(Field(None, r))
case R(r -> true) => R(Field(Some(r), r, false))
case R(r -> false) => R(Field(None, r, false))
} })
})
def litTy[p: P]: P[Type] = P( lit.map(l => Literal(l).withLocOf(l)) )
Expand Down Expand Up @@ -327,7 +327,7 @@ class MLParser(origin: Origin, indent: Int = 0, recordLocations: Bool = true) {
}
def tup = parTy.map {
case t: Tuple => t
case t => Tuple(N -> Field(N, t) :: Nil)
case t => Tuple(N -> Field(N, t, false) :: Nil)
}
P((ctorName ~ tup.?).map {
case (id, S(body: Tuple)) =>
Expand Down
10 changes: 7 additions & 3 deletions shared/src/main/scala/mlscript/NewLexer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@ class NewLexer(origin: Origin, raise: Diagnostic => Unit, dbg: Bool) {
";",
// ",",
"#",
"`"
"`",
// ".",
// "<",
// ">",
"?:"
)

private val isAlphaOp = Set(
Expand Down Expand Up @@ -396,9 +397,12 @@ class NewLexer(origin: Origin, raise: Diagnostic => Unit, dbg: Bool) {
lex(k, ind, next(k, SELECT(name)))
}
else lex(j, ind, next(j, if (isSymKeyword.contains(n)) KEYWORD(n) else IDENT(n, true)))
}
else {
// println(s"dbg [NewLexer.scala]: n: $n, j: $j, isSymKeyword: ${isSymKeyword.contains(n)}")
// else go(j, i f (isSymKeyword.contains(n)) KEYWORD(n) else IDENT(n, true))
lex(j, ind, next(j, if (isSymKeyword.contains(n)) KEYWORD(n) else IDENT(n, true)))
}
// else go(j, if (isSymKeyword.contains(n)) KEYWORD(n) else IDENT(n, true))
else lex(j, ind, next(j, if (isSymKeyword.contains(n)) KEYWORD(n) else IDENT(n, true)))
case _ if isDigit(c) =>
val (lit, j) = num(i)
// go(j, LITVAL(IntLit(BigInt(str))))
Expand Down
Loading
Loading