Skip to content

Commit

Permalink
Fix expansions on expanded types
Browse files Browse the repository at this point in the history
  • Loading branch information
stanhebben committed Nov 1, 2024
1 parent b8546f9 commit d84c9a8
Show file tree
Hide file tree
Showing 20 changed files with 96 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ public void addExpansion(ExpansionDefinition expansion) {
expansions.add(expansion);
}

@Override
public List<ExpansionSymbol> getAvailableExpansions() {
return expansions;
}

public Optional<IGlobal> findGlobal(String name) {
return Optional.ofNullable(globals.get(name));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import org.openzen.zenscript.codemodel.GenericName;
import org.openzen.zenscript.codemodel.compilation.impl.capture.LocalExpression;
import org.openzen.zenscript.codemodel.expression.LambdaClosure;
import org.openzen.zenscript.codemodel.identifiers.ExpansionSymbol;
import org.openzen.zenscript.codemodel.type.TypeID;

import java.util.List;
Expand Down Expand Up @@ -41,4 +42,6 @@ public interface ExpressionCompiler extends TypeResolver {
ExpressionCompiler withDollar(CompilingExpression value);

StatementCompiler forLambda(LambdaClosure closure, FunctionHeader header);

List<ExpansionSymbol> getAvailableExpansions();
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.openzen.zenscript.codemodel.expression.CallArguments;
import org.openzen.zenscript.codemodel.expression.Expression;
import org.openzen.zenscript.codemodel.generic.TypeParameter;
import org.openzen.zenscript.codemodel.identifiers.ExpansionSymbol;
import org.openzen.zenscript.codemodel.identifiers.instances.MethodInstance;
import org.openzen.zenscript.codemodel.type.BasicTypeID;
import org.openzen.zenscript.codemodel.type.TypeID;
Expand Down Expand Up @@ -143,7 +144,7 @@ private static <T extends AnyMethod> MatchedCallArguments<T> match(
}

// Type inference
Optional<TypeID[]> inferred = inferTypeArguments(expansionTypeArguments, method, result, typeArguments, arguments);
Optional<TypeID[]> inferred = inferTypeArguments(expansionTypeArguments, method, result, typeArguments, compiler.getAvailableExpansions(), arguments);
if (!inferred.isPresent()) {
return new MatchedCallArguments<>(
method,
Expand Down Expand Up @@ -320,6 +321,7 @@ private static <T extends AnyMethod> Optional<TypeID[]> inferTypeArguments(
T method,
TypeID result,
TypeID[] typeArguments,
List<ExpansionSymbol> expansions,
CompilingExpression... arguments
) {
int providedTypeArguments = typeArguments == null ? 0 : typeArguments.length;
Expand All @@ -335,7 +337,7 @@ private static <T extends AnyMethod> Optional<TypeID[]> inferTypeArguments(
// attempt to infer type arguments from the return type
final Map<TypeParameter, TypeID> typeArgumentMap = new HashMap<>();
if (result != null) {
typeArgumentMap.putAll(method.getHeader().getReturnType().inferTypeParameters(result));
typeArgumentMap.putAll(method.getHeader().getReturnType().inferTypeParameters(result, expansions));
}

// create a mapping with everything found so far
Expand All @@ -348,7 +350,7 @@ private static <T extends AnyMethod> Optional<TypeID[]> inferTypeArguments(
Expression evaluated = argument.eval();
if (evaluated.type != BasicTypeID.UNDETERMINED) {
TypeID parameterType = mapper.map(method.getHeader().parameters[i].type);
Map<TypeParameter, TypeID> mapping = parameterType.inferTypeParameters(evaluated.type);
Map<TypeParameter, TypeID> mapping = parameterType.inferTypeParameters(evaluated.type, expansions);
if (mapping != null)
typeArgumentMap.putAll(mapping);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
package org.openzen.zenscript.codemodel.compilation;

import org.openzen.zenscript.codemodel.identifiers.ExpansionSymbol;
import org.openzen.zenscript.codemodel.type.TypeID;

import java.util.List;

public interface TypeResolver {
List<ExpansionSymbol> getAvailableExpansions();

ResolvedType resolve(TypeID type);
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.openzen.zenscript.codemodel.definition.ZSPackage;
import org.openzen.zenscript.codemodel.expression.*;
import org.openzen.zenscript.codemodel.expression.modifiable.ModifiableExpression;
import org.openzen.zenscript.codemodel.identifiers.ExpansionSymbol;
import org.openzen.zenscript.codemodel.identifiers.instances.FieldInstance;
import org.openzen.zenscript.codemodel.identifiers.instances.MethodInstance;
import org.openzen.zenscript.codemodel.member.ref.ImplementationMemberInstance;
Expand Down Expand Up @@ -189,6 +190,11 @@ public StatementCompiler forLambda(LambdaClosure closure, FunctionHeader header)
return new StatementCompilerImpl(context, localType, types, header, newLocals, null);
}

@Override
public List<ExpansionSymbol> getAvailableExpansions() {
return context.getAvailableExpansions();
}

@Override
public ResolvedType resolve(TypeID type) {
return context.resolve(type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

import org.openzen.zenscript.codemodel.FunctionHeader;
import org.openzen.zenscript.codemodel.compilation.*;
import org.openzen.zenscript.codemodel.identifiers.ExpansionSymbol;
import org.openzen.zenscript.codemodel.type.TypeID;

import java.util.List;

public class MemberCompilerImpl implements MemberCompiler {
private final CompileContext context;
private final LocalType localType;
Expand Down Expand Up @@ -42,6 +45,11 @@ public DefinitionCompiler forInner() {
return definitionCompiler;
}

@Override
public List<ExpansionSymbol> getAvailableExpansions() {
return context.getAvailableExpansions();
}

@Override
public ResolvedType resolve(TypeID type) {
return definitionCompiler.resolve(type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,21 @@
import org.openzen.zenscript.codemodel.HighLevelDefinition;
import org.openzen.zenscript.codemodel.Modifiers;
import org.openzen.zenscript.codemodel.compilation.ResolvedType;
import org.openzen.zenscript.codemodel.compilation.ResolvingType;
import org.openzen.zenscript.codemodel.generic.TypeParameter;
import org.openzen.zenscript.codemodel.identifiers.ExpansionSymbol;
import org.openzen.zenscript.codemodel.identifiers.ModuleSymbol;
import org.openzen.zenscript.codemodel.member.IDefinitionMember;
import org.openzen.zenscript.codemodel.type.TypeID;
import org.openzen.zenscript.codemodel.type.TypeMatcher;
import org.openzen.zenscript.codemodel.type.member.InterfaceResolvingType;
import org.openzen.zenscript.codemodel.type.member.MemberSet;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class ExpansionDefinition extends HighLevelDefinition implements ExpansionSymbol {
Expand All @@ -40,11 +45,11 @@ public String getName() {
}

@Override
public Optional<ResolvedType> resolve(TypeID expandingType) {
public Optional<ResolvedType> resolve(TypeID expandingType, List<ExpansionSymbol> expansions) {
if (target == null)
throw new RuntimeException(position.toString() + ": Missing expansion target");

Map<TypeParameter, TypeID> mapping = TypeMatcher.match(expandingType, target);
Map<TypeParameter, TypeID> mapping = TypeMatcher.match(expandingType, target, expansions);
if (mapping == null)
return Optional.empty();

Expand All @@ -54,6 +59,15 @@ public Optional<ResolvedType> resolve(TypeID expandingType) {
for (IDefinitionMember member : members)
member.registerTo(expandingType, resolution, mapper);

return Optional.of(resolution.buildWithoutExpansions());
List<TypeID> interfaces = this.members.stream()
.map(IDefinitionMember::asImplementation)
.filter(Optional::isPresent)
.map(Optional::get)
.map(mapper::map)
.collect(Collectors.toList());

ResolvingType resolved = resolution.build();
ResolvingType withInterfaces = InterfaceResolvingType.of(resolved, interfaces);
return Optional.of(withInterfaces.withExpansions(Collections.emptyList()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import org.openzen.zenscript.codemodel.GenericMapper;
import org.openzen.zenscript.codemodel.compilation.ResolvingType;
import org.openzen.zenscript.codemodel.identifiers.ExpansionSymbol;
import org.openzen.zenscript.codemodel.type.TypeID;

import java.util.List;
import java.util.Optional;

public final class ParameterSuperBound implements TypeParameterBound {
Expand All @@ -24,8 +26,8 @@ public Optional<ResolvingType> resolveMembers() {
}

@Override
public boolean matches(TypeID type) {
return type.extendsOrImplements(type);
public boolean matches(TypeID type, List<ExpansionSymbol> expansions) {
return type.extendsOrImplements(type, expansions);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import org.openzen.zencode.shared.CodePosition;
import org.openzen.zenscript.codemodel.GenericMapper;
import org.openzen.zenscript.codemodel.compilation.ResolvingType;
import org.openzen.zenscript.codemodel.identifiers.ExpansionSymbol;
import org.openzen.zenscript.codemodel.type.TypeID;

import java.util.List;
import java.util.Optional;

public final class ParameterTypeBound implements TypeParameterBound {
Expand All @@ -27,8 +29,8 @@ public Optional<ResolvingType> resolveMembers() {
}

@Override
public boolean matches(TypeID type) {
return type.extendsOrImplements(this.type);
public boolean matches(TypeID type, List<ExpansionSymbol> expansions) {
return type.extendsOrImplements(this.type, expansions);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import org.openzen.zencode.shared.CodePosition;
import org.openzen.zencode.shared.Taggable;
import org.openzen.zenscript.codemodel.GenericMapper;
import org.openzen.zenscript.codemodel.identifiers.ExpansionSymbol;
import org.openzen.zenscript.codemodel.type.TypeID;

import java.util.ArrayList;
Expand Down Expand Up @@ -32,10 +33,10 @@ public boolean isObjectType() {
return false;
}

public boolean matches(TypeID type, GenericMapper mapper) {
public boolean matches(TypeID type, GenericMapper mapper, List<ExpansionSymbol> expansions) {
for (TypeParameterBound bound : bounds) {
TypeParameterBound instanced = bound.instance(mapper);
if (!instanced.matches(type))
if (!instanced.matches(type, expansions))
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import org.openzen.zenscript.codemodel.GenericMapper;
import org.openzen.zenscript.codemodel.compilation.ResolvingType;
import org.openzen.zenscript.codemodel.identifiers.ExpansionSymbol;
import org.openzen.zenscript.codemodel.type.TypeID;

import java.util.List;
import java.util.Optional;

public interface TypeParameterBound {
Expand All @@ -15,7 +17,7 @@ public interface TypeParameterBound {

Optional<ResolvingType> resolveMembers();

boolean matches(TypeID type);
boolean matches(TypeID type, List<ExpansionSymbol> expansions);

TypeParameterBound instance(GenericMapper mapper);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import org.openzen.zenscript.codemodel.compilation.ResolvedType;
import org.openzen.zenscript.codemodel.type.TypeID;

import java.util.List;
import java.util.Optional;

public interface ExpansionSymbol extends DefinitionSymbol {
Optional<ResolvedType> resolve(TypeID expandingType);
Optional<ResolvedType> resolve(TypeID expandingType, List<ExpansionSymbol> expansions);
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ public GenericTypeID(TypeParameter parameter) {
this.parameter = parameter;
}

public boolean matches(TypeID type) {
public boolean matches(TypeID type, List<ExpansionSymbol> expansions) {
GenericMapper mapper = GenericMapper.single(parameter, type);
return parameter.matches(type, mapper);
return parameter.matches(type, mapper, expansions);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.openzen.zenscript.codemodel.compilation.ResolvingType;
import org.openzen.zenscript.codemodel.expression.Expression;
import org.openzen.zenscript.codemodel.generic.TypeParameter;
import org.openzen.zenscript.codemodel.identifiers.ExpansionSymbol;

import java.util.*;

Expand Down Expand Up @@ -44,8 +45,8 @@ default Expression getDefaultValue() {
*
* @return inferred type parameters, or null if no match was found
*/
default Map<TypeParameter, TypeID> inferTypeParameters(TypeID targetType) {
return TypeMatcher.match(this, targetType);
default Map<TypeParameter, TypeID> inferTypeParameters(TypeID targetType, List<ExpansionSymbol> expansions) {
return TypeMatcher.match(this, targetType, expansions);
}

void extractTypeParameters(List<TypeParameter> typeParameters);
Expand Down Expand Up @@ -163,8 +164,8 @@ default ResolvedType resolveWithoutExpansions() {
return this.resolve().withExpansions(Collections.emptyList());
}

default boolean extendsOrImplements(TypeID type) {
return resolveWithoutExpansions().extendsOrImplements(type);
default boolean extendsOrImplements(TypeID type, List<ExpansionSymbol> expansions) {
return this.resolve().withExpansions(expansions).extendsOrImplements(type);
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package org.openzen.zenscript.codemodel.type;

import org.openzen.zenscript.codemodel.generic.TypeParameter;
import org.openzen.zenscript.codemodel.identifiers.ExpansionSymbol;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class TypeMatcher implements TypeVisitorWithContext<TypeMatcher.Matching, Boolean, RuntimeException> {
Expand All @@ -11,8 +13,8 @@ public class TypeMatcher implements TypeVisitorWithContext<TypeMatcher.Matching,
private TypeMatcher() {
}

public static Map<TypeParameter, TypeID> match(TypeID type, TypeID pattern) {
Matching matching = new Matching(type);
public static Map<TypeParameter, TypeID> match(TypeID type, TypeID pattern, List<ExpansionSymbol> expansions) {
Matching matching = new Matching(type, expansions);
if (pattern.accept(matching, INSTANCE))
return matching.mapping;

Expand Down Expand Up @@ -113,7 +115,7 @@ public Boolean visitGeneric(Matching context, GenericTypeID generic) {
if (context.mapping.containsKey(generic.parameter)) {
TypeID argument = context.mapping.get(generic.parameter);
return argument == context.type;
} else if (context.type == generic || generic.matches(context.type)) {
} else if (context.type == generic || generic.matches(context.type, context.expansions)) {
context.mapping.put(generic.parameter, context.type);
return true;
} else {
Expand Down Expand Up @@ -153,19 +155,22 @@ public Boolean visitGenericMap(Matching context, GenericMapTypeID map) {
public static final class Matching {
public final TypeID type;
public final Map<TypeParameter, TypeID> mapping;
private final List<ExpansionSymbol> expansions;

public Matching(TypeID type) {
public Matching(TypeID type, List<ExpansionSymbol> expansions) {
this.type = type;
mapping = new HashMap<>();
this.expansions = expansions;
}

private Matching(TypeID type, Map<TypeParameter, TypeID> mapping) {
private Matching(TypeID type, Map<TypeParameter, TypeID> mapping, List<ExpansionSymbol> expansions) {
this.type = type;
this.mapping = mapping;
this.expansions = expansions;
}

public Matching withType(TypeID type) {
return new Matching(type, mapping);
return new Matching(type, mapping, expansions);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public class ExpandedResolvedType implements ResolvedType {
public static ResolvedType resolve(ResolvedType base, List<ExpansionSymbol> expansions) {
List<ResolvedType> resolutions = new ArrayList<>();
for (ExpansionSymbol expansion : expansions) {
expansion.resolve(base.getType()).ifPresent(resolutions::add);
expansion.resolve(base.getType(), expansions).ifPresent(resolutions::add);
}
return ExpandedResolvedType.of(base, resolutions);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public ResolvedType withExpansions(List<ExpansionSymbol> expansions) {
List<ResolvedType> resolvedInterfaces = implementedInterfaces.stream().map(iface -> iface.resolve().withExpansions(expansions)).collect(Collectors.toList());

List<ResolvedType> interfaceExpansions = implementedInterfaces.stream()
.flatMap(iface -> expansions.stream().map(expansion -> expansion.resolve(iface)).filter(Optional::isPresent).map(Optional::get))
.flatMap(iface -> expansions.stream().map(expansion -> expansion.resolve(iface, expansions)).filter(Optional::isPresent).map(Optional::get))
.collect(Collectors.toList());

return SubtypeResolvedType.ofImplementation(
Expand Down
Loading

0 comments on commit d84c9a8

Please sign in to comment.