Skip to content

Commit

Permalink
Sync to upstream/release/579 (#943)
Browse files Browse the repository at this point in the history
A pretty small changelist this week:

* When type inference fails to find any matching overload for a
function, we were declining to commit any changes to the type graph at
all. This was resulting in confusing type errors in certain cases. Now,
when a matching overload cannot be found, we always commit to the first
overload we tried.

JIT

* Fix missing variadic register invalidation in FALLBACK_GETVARARGS
* Add a missing null pointer check for the result of luaT_gettm

---------

Co-authored-by: Arseny Kapoulkine <[email protected]>
Co-authored-by: Vyacheslav Egorov <[email protected]>
  • Loading branch information
3 people authored Jun 2, 2023
1 parent 271c509 commit 63679f7
Show file tree
Hide file tree
Showing 28 changed files with 597 additions and 461 deletions.
11 changes: 9 additions & 2 deletions Analysis/include/Luau/TypeInfer.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,14 @@ struct ModuleResolver;

using Name = std::string;
using ScopePtr = std::shared_ptr<Scope>;
using OverloadErrorEntry = std::tuple<std::vector<TypeError>, std::vector<TypeId>, const FunctionType*>;

struct OverloadErrorEntry
{
TxnLog log;
ErrorVec errors;
std::vector<TypeId> arguments;
const FunctionType* fnTy;
};

bool doesCallError(const AstExprCall* call);
bool hasBreak(AstStat* node);
Expand Down Expand Up @@ -166,7 +173,7 @@ struct TypeChecker
const std::vector<OverloadErrorEntry>& errors);
void reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack,
const std::vector<Location>& argLocations, const std::vector<TypeId>& overloads, const std::vector<TypeId>& overloadsThatMatchArgCount,
const std::vector<OverloadErrorEntry>& errors);
std::vector<OverloadErrorEntry>& errors);

WithPredicate<TypePackId> checkExprList(const ScopePtr& scope, const Location& location, const AstArray<AstExpr*>& exprs,
bool substituteFreeForNil = false, const std::vector<bool>& lhsAnnotations = {},
Expand Down
47 changes: 24 additions & 23 deletions Analysis/src/ConstraintSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1300,6 +1300,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
Instantiation inst(TxnLog::empty(), arena, TypeLevel{}, constraint->scope);

std::vector<TypeId> arityMatchingOverloads;
std::optional<TxnLog> bestOverloadLog;

for (TypeId overload : overloads)
{
Expand Down Expand Up @@ -1330,29 +1331,24 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
if (const auto& e = hasUnificationTooComplex(u.errors))
reportError(*e);

if (const auto& e = hasCountMismatch(u.errors);
(!e || get<CountMismatch>(*e)->context != CountMismatch::Context::Arg) && get<FunctionType>(*instantiated))
{
const auto& e = hasCountMismatch(u.errors);
bool areArgumentsCompatible = (!e || get<CountMismatch>(*e)->context != CountMismatch::Context::Arg) && get<FunctionType>(*instantiated);
if (areArgumentsCompatible)
arityMatchingOverloads.push_back(*instantiated);
}

if (u.errors.empty())
{
if (c.callSite)
(*c.astOverloadResolvedTypes)[c.callSite] = *instantiated;

// We found a matching overload.
const auto [changedTypes, changedPacks] = u.log.getChanges();
u.log.commit();
unblock(changedTypes);
unblock(changedPacks);
unblock(c.result);

InstantiationQueuer queuer{constraint->scope, constraint->location, this};
queuer.traverse(fn);
queuer.traverse(inferredTy);

return true;
// This overload has no errors, so override the bestOverloadLog and use this one.
bestOverloadLog = std::move(u.log);
break;
}
else if (areArgumentsCompatible && !bestOverloadLog)
{
// This overload is erroneous. Replace its inferences with `any` iff there isn't already a TxnLog.
bestOverloadLog = std::move(u.log);
}
}

Expand All @@ -1365,15 +1361,20 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull<cons
(*c.astOverloadResolvedTypes)[c.callSite] = arityMatchingOverloads.at(0);
}

// We found no matching overloads.
Unifier u{normalizer, constraint->scope, Location{}, Covariant};
u.enableScopeTests();
// We didn't find any overload that were a viable candidate, so replace the inferences with `any`.
if (!bestOverloadLog)
{
Unifier u{normalizer, constraint->scope, Location{}, Covariant};
u.enableScopeTests();

u.tryUnify(inferredTy, builtinTypes->anyType);
u.tryUnify(fn, builtinTypes->anyType);
u.tryUnify(inferredTy, builtinTypes->anyType);
u.tryUnify(fn, builtinTypes->anyType);

const auto [changedTypes, changedPacks] = u.log.getChanges();
u.log.commit();
bestOverloadLog = std::move(u.log);
}

const auto [changedTypes, changedPacks] = bestOverloadLog->getChanges();
bestOverloadLog->commit();

unblock(changedTypes);
unblock(changedPacks);
Expand Down
139 changes: 61 additions & 78 deletions Analysis/src/TxnLog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,94 +111,77 @@ void TxnLog::concatAsIntersections(TxnLog rhs, NotNull<TypeArena> arena)

void TxnLog::concatAsUnion(TxnLog rhs, NotNull<TypeArena> arena)
{
if (FFlag::DebugLuauDeferredConstraintResolution)
/*
* Check for cycles.
*
* We must not combine a log entry that binds 'a to 'b with a log that
* binds 'b to 'a.
*
* Of the two, identify the one with the 'bigger' scope and eliminate the
* entry that rebinds it.
*/
for (const auto& [rightTy, rightRep] : rhs.typeVarChanges)
{
/*
* Check for cycles.
*
* We must not combine a log entry that binds 'a to 'b with a log that
* binds 'b to 'a.
*
* Of the two, identify the one with the 'bigger' scope and eliminate the
* entry that rebinds it.
*/
for (const auto& [rightTy, rightRep] : rhs.typeVarChanges)
{
if (rightRep->dead)
continue;

// We explicitly use get_if here because we do not wish to do anything
// if the uncommitted type is already bound to something else.
const FreeType* rf = get_if<FreeType>(&rightTy->ty);
if (!rf)
continue;

const BoundType* rb = Luau::get<BoundType>(&rightRep->pending);
if (!rb)
continue;

const TypeId leftTy = rb->boundTo;
const FreeType* lf = get_if<FreeType>(&leftTy->ty);
if (!lf)
continue;

auto leftRep = typeVarChanges.find(leftTy);
if (!leftRep)
continue;

if ((*leftRep)->dead)
continue;

const BoundType* lb = Luau::get<BoundType>(&(*leftRep)->pending);
if (!lb)
continue;

if (lb->boundTo == rightTy)
{
// leftTy has been bound to rightTy, but rightTy has also been bound
// to leftTy. We find the one that belongs to the more deeply nested
// scope and remove it from the log.
const bool discardLeft = useScopes ? subsumes(lf->scope, rf->scope) : lf->level.subsumes(rf->level);

if (discardLeft)
(*leftRep)->dead = true;
else
rightRep->dead = true;
}
}
if (rightRep->dead)
continue;

for (auto& [ty, rightRep] : rhs.typeVarChanges)
// We explicitly use get_if here because we do not wish to do anything
// if the uncommitted type is already bound to something else.
const FreeType* rf = get_if<FreeType>(&rightTy->ty);
if (!rf)
continue;

const BoundType* rb = Luau::get<BoundType>(&rightRep->pending);
if (!rb)
continue;

const TypeId leftTy = rb->boundTo;
const FreeType* lf = get_if<FreeType>(&leftTy->ty);
if (!lf)
continue;

auto leftRep = typeVarChanges.find(leftTy);
if (!leftRep)
continue;

if ((*leftRep)->dead)
continue;

const BoundType* lb = Luau::get<BoundType>(&(*leftRep)->pending);
if (!lb)
continue;

if (lb->boundTo == rightTy)
{
if (rightRep->dead)
continue;

if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead)
{
TypeId leftTy = arena->addType((*leftRep)->pending);
TypeId rightTy = arena->addType(rightRep->pending);

if (follow(leftTy) == follow(rightTy))
typeVarChanges[ty] = std::move(rightRep);
else
typeVarChanges[ty]->pending.ty = UnionType{{leftTy, rightTy}};
}
// leftTy has been bound to rightTy, but rightTy has also been bound
// to leftTy. We find the one that belongs to the more deeply nested
// scope and remove it from the log.
const bool discardLeft = useScopes ? subsumes(lf->scope, rf->scope) : lf->level.subsumes(rf->level);

if (discardLeft)
(*leftRep)->dead = true;
else
typeVarChanges[ty] = std::move(rightRep);
rightRep->dead = true;
}
}
else

for (auto& [ty, rightRep] : rhs.typeVarChanges)
{
for (auto& [ty, rightRep] : rhs.typeVarChanges)
if (rightRep->dead)
continue;

if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead)
{
if (auto leftRep = typeVarChanges.find(ty))
{
TypeId leftTy = arena->addType((*leftRep)->pending);
TypeId rightTy = arena->addType(rightRep->pending);
typeVarChanges[ty]->pending.ty = UnionType{{leftTy, rightTy}};
}
else
TypeId leftTy = arena->addType((*leftRep)->pending);
TypeId rightTy = arena->addType(rightRep->pending);

if (follow(leftTy) == follow(rightTy))
typeVarChanges[ty] = std::move(rightRep);
else
typeVarChanges[ty]->pending.ty = UnionType{{leftTy, rightTy}};
}
else
typeVarChanges[ty] = std::move(rightRep);
}

for (auto& [tp, rep] : rhs.typePackChanges)
Expand Down
10 changes: 7 additions & 3 deletions Analysis/src/TypeFamily.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,14 +347,18 @@ TypeFamilyReductionResult<TypeId> addFamilyFn(std::vector<TypeId> typeParams, st
const NormalizedType* normLhsTy = normalizer->normalize(lhsTy);
const NormalizedType* normRhsTy = normalizer->normalize(rhsTy);

if (normLhsTy && normRhsTy && normLhsTy->isNumber() && normRhsTy->isNumber())
if (!normLhsTy || !normRhsTy)
{
return {builtins->numberType, false, {}, {}};
return {std::nullopt, false, {}, {}};
}
else if (log->is<AnyType>(lhsTy) || log->is<AnyType>(rhsTy))
else if (log->is<AnyType>(normLhsTy->tops) || log->is<AnyType>(normRhsTy->tops))
{
return {builtins->anyType, false, {}, {}};
}
else if (normLhsTy->isNumber() && normRhsTy->isNumber())
{
return {builtins->numberType, false, {}, {}};
}
else if (log->is<ErrorType>(lhsTy) || log->is<ErrorType>(rhsTy))
{
return {builtins->errorRecoveryType(), false, {}, {}};
Expand Down
31 changes: 22 additions & 9 deletions Analysis/src/TypeInfer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure)
LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false)
LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false)
LUAU_FASTFLAGVARIABLE(LuauTypecheckClassTypeIndexers, false)
LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false)

namespace Luau
{
Expand Down Expand Up @@ -4387,7 +4388,12 @@ std::unique_ptr<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const
else
overloadsThatDont.push_back(fn);

errors.emplace_back(std::move(state.errors), args->head, ftv);
errors.push_back(OverloadErrorEntry{
std::move(state.log),
std::move(state.errors),
args->head,
ftv,
});
}
else
{
Expand All @@ -4407,7 +4413,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal
{
// No overloads succeeded: Scan for one that would have worked had the user
// used a.b() rather than a:b() or vice versa.
for (const auto& [_, argVec, ftv] : errors)
for (const auto& e : errors)
{
// Did you write foo:bar() when you should have written foo.bar()?
if (expr.self)
Expand All @@ -4418,7 +4424,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal
TypePackId editedArgPack = addTypePack(TypePack{editedParamList});

Unifier editedState = mkUnifier(scope, expr.location);
checkArgumentList(scope, *expr.func, editedState, editedArgPack, ftv->argTypes, editedArgLocations);
checkArgumentList(scope, *expr.func, editedState, editedArgPack, e.fnTy->argTypes, editedArgLocations);

if (editedState.errors.empty())
{
Expand All @@ -4433,7 +4439,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal
return true;
}
}
else if (ftv->hasSelf)
else if (e.fnTy->hasSelf)
{
// Did you write foo.bar() when you should have written foo:bar()?
if (AstExprIndexName* indexName = expr.func->as<AstExprIndexName>())
Expand All @@ -4449,7 +4455,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal

Unifier editedState = mkUnifier(scope, expr.location);

checkArgumentList(scope, *expr.func, editedState, editedArgPack, ftv->argTypes, editedArgLocations);
checkArgumentList(scope, *expr.func, editedState, editedArgPack, e.fnTy->argTypes, editedArgLocations);

if (editedState.errors.empty())
{
Expand All @@ -4472,11 +4478,14 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal

void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack,
const std::vector<Location>& argLocations, const std::vector<TypeId>& overloads, const std::vector<TypeId>& overloadsThatMatchArgCount,
const std::vector<OverloadErrorEntry>& errors)
std::vector<OverloadErrorEntry>& errors)
{
if (overloads.size() == 1)
{
reportErrors(std::get<0>(errors.front()));
if (FFlag::LuauAlwaysCommitInferencesOfFunctionCalls)
errors.front().log.commit();

reportErrors(errors.front().errors);
return;
}

Expand All @@ -4498,11 +4507,15 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast
const FunctionType* ftv = get<FunctionType>(overload);

auto error = std::find_if(errors.begin(), errors.end(), [ftv](const OverloadErrorEntry& e) {
return ftv == std::get<2>(e);
return ftv == e.fnTy;
});

LUAU_ASSERT(error != errors.end());
reportErrors(std::get<0>(*error));

if (FFlag::LuauAlwaysCommitInferencesOfFunctionCalls)
error->log.commit();

reportErrors(error->errors);

// If only one overload matched, we don't need this error because we provided the previous errors.
if (overloadsThatMatchArgCount.size() == 1)
Expand Down
Loading

0 comments on commit 63679f7

Please sign in to comment.