From f565be3e67ea7ac6bdb660ff073aedcd67334b05 Mon Sep 17 00:00:00 2001 From: Yorick Peterse Date: Wed, 6 Sep 2023 01:19:57 +0200 Subject: [PATCH] WIP: better generics --- compiler/src/compiler.rs | 17 +- compiler/src/hir.rs | 8 +- compiler/src/llvm/layouts.rs | 15 +- compiler/src/mir/mod.rs | 110 +++++-- compiler/src/mir/passes.rs | 140 ++++---- compiler/src/mir/pattern_matching.rs | 4 +- compiler/src/mir/printer.rs | 15 +- compiler/src/mir/specialize.rs | 400 +++++++++++++++++++++++ compiler/src/symbol_names.rs | 45 ++- compiler/src/type_check/expressions.rs | 5 + rt/src/process.rs | 19 +- types/src/format.rs | 6 +- types/src/lib.rs | 360 +++++++++++++++++++-- types/src/specialize.rs | 422 +++++++++++++++++++++++++ 14 files changed, 1423 insertions(+), 143 deletions(-) create mode 100644 compiler/src/mir/specialize.rs create mode 100644 types/src/specialize.rs diff --git a/compiler/src/compiler.rs b/compiler/src/compiler.rs index 84edfad1a..6a2dd55d1 100644 --- a/compiler/src/compiler.rs +++ b/compiler/src/compiler.rs @@ -3,8 +3,10 @@ use crate::config::{Config, SOURCE, SOURCE_EXT, TESTS}; use crate::hir; use crate::linker::link; use crate::llvm; +use crate::mir::passes as mir; use crate::mir::printer::to_dot; -use crate::mir::{passes as mir, Mir}; +use crate::mir::specialize::Specialize; +use crate::mir::Mir; use crate::modules_parser::{ModulesParser, ParsedModule}; use crate::state::State; use crate::type_check::define_types::{ @@ -131,12 +133,12 @@ impl Compiler { } let mut mir = Mir::new(); + let state = &mut self.state; - mir::check_global_limits(&mut self.state) - .map_err(CompileError::Internal)?; + mir::check_global_limits(state).map_err(CompileError::Internal)?; - if mir::DefineConstants::run_all(&mut self.state, &mut mir, &modules) - && mir::LowerToMir::run_all(&mut self.state, &mut mir, modules) + if mir::DefineConstants::run_all(state, &mut mir, &modules) + && mir::LowerToMir::run_all(state, &mut mir, modules) { Ok(mir) } else { @@ -185,8 +187,9 @@ 
impl Compiler { } fn optimise_mir(&mut self, mir: &mut Mir) { - mir::ExpandDrop::run_all(&self.state.db, mir); - mir::ExpandReference::run_all(&self.state.db, mir); + Specialize::run_all(&mut self.state, mir); + mir::ExpandDrop::run_all(&self.state.db, mir); // TODO: remove in favour of specialization + mir::ExpandReference::run_all(&self.state.db, mir); // TODO: remove in favour of specialization mir::clean_up_basic_blocks(mir); } diff --git a/compiler/src/hir.rs b/compiler/src/hir.rs index e3da9d522..3ad67ccef 100644 --- a/compiler/src/hir.rs +++ b/compiler/src/hir.rs @@ -336,7 +336,7 @@ pub(crate) struct DefineVariant { } #[derive(Clone, Debug, PartialEq, Eq)] -pub(crate) struct AssignInstanceLiteralField { +pub(crate) struct AssignClassLiteralField { pub(crate) resolved_type: types::TypeRef, pub(crate) field_id: Option, pub(crate) field: Field, @@ -349,7 +349,7 @@ pub(crate) struct ClassLiteral { pub(crate) class_id: Option, pub(crate) resolved_type: types::TypeRef, pub(crate) class_name: Constant, - pub(crate) fields: Vec, + pub(crate) fields: Vec, pub(crate) location: SourceLocation, } @@ -2794,7 +2794,7 @@ impl<'a> LowerToHir<'a> { fields: node .fields .into_iter() - .map(|n| AssignInstanceLiteralField { + .map(|n| AssignClassLiteralField { resolved_type: types::TypeRef::Unknown, field_id: None, field: self.field(n.field), @@ -5923,7 +5923,7 @@ mod tests { name: "A".to_string(), location: cols(8, 8) }, - fields: vec![AssignInstanceLiteralField { + fields: vec![AssignClassLiteralField { resolved_type: types::TypeRef::Unknown, field_id: None, field: Field { diff --git a/compiler/src/llvm/layouts.rs b/compiler/src/llvm/layouts.rs index bf4c53187..c1fde6f6e 100644 --- a/compiler/src/llvm/layouts.rs +++ b/compiler/src/llvm/layouts.rs @@ -12,8 +12,8 @@ use inkwell::AddressSpace; use std::cmp::max; use std::collections::HashMap; use types::{ - ClassId, MethodId, MethodSource, BOOLEAN_ID, BYTE_ARRAY_ID, CALL_METHOD, - CHANNEL_ID, DROPPER_METHOD, FLOAT_ID, INT_ID, 
NIL_ID, + ClassId, MethodId, MethodSource, TraitId, BOOL_ID, BYTE_ARRAY_ID, + CALL_METHOD, CHANNEL_ID, DROPPER_METHOD, FLOAT_ID, INT_ID, NIL_ID, }; /// The size of an object header. @@ -180,12 +180,13 @@ impl<'ctx> Layouts<'ctx> { // // This information is defined first so we can update the `collision` // flag when generating this information for method implementations. - for mir_trait in mir.traits.values() { - for method in mir_trait - .id + for idx in 0..db.number_of_traits() { + let id = TraitId(idx as _); + + for method in id .required_methods(db) .into_iter() - .chain(mir_trait.id.default_methods(db)) + .chain(id.default_methods(db)) { let name = method.name(db); let hash = method_hasher.hash(name); @@ -249,7 +250,7 @@ impl<'ctx> Layouts<'ctx> { header, context.f64_type().into(), ), - BOOLEAN_ID | NIL_ID => { + BOOL_ID | NIL_ID => { let typ = context.opaque_struct(&name); typ.set_body(&[header.into()], false); diff --git a/compiler/src/mir/mod.rs b/compiler/src/mir/mod.rs index 5ee6a4264..faad3befd 100644 --- a/compiler/src/mir/mod.rs +++ b/compiler/src/mir/mod.rs @@ -2,13 +2,14 @@ //! //! MIR is used for various optimisations, analysing moves of values, compiling //! pattern matching into decision trees, and more. +use crate::symbol_names::class_name; use ast::source_location::SourceLocation; use std::collections::{HashMap, HashSet}; use std::fmt; use std::hash::{Hash, Hasher}; use std::rc::Rc; use types::collections::IndexMap; -use types::{BuiltinFunction, Database}; +use types::{BuiltinFunction, Database, MethodId, TypeArguments}; /// The number of reductions to perform after calling a method. 
const CALL_COST: u16 = 1; @@ -16,6 +17,7 @@ const CALL_COST: u16 = 1; pub(crate) mod passes; pub(crate) mod pattern_matching; pub(crate) mod printer; +pub(crate) mod specialize; fn join(values: &[RegisterId]) -> String { values.iter().map(|v| format!("r{}", v.0)).collect::>().join(", ") @@ -42,6 +44,10 @@ impl Registers { &self.values[register.0 as usize] } + pub(crate) fn get_mut(&mut self, register: RegisterId) -> &mut Register { + &mut self.values[register.0 as usize] + } + pub(crate) fn value_type(&self, register: RegisterId) -> types::TypeRef { self.get(register).value_type } @@ -49,6 +55,10 @@ impl Registers { pub(crate) fn len(&self) -> usize { self.values.len() } + + pub(crate) fn iter_mut(&mut self) -> impl Iterator { + self.values.iter_mut() + } } /// A directed control-flow graph. @@ -409,12 +419,14 @@ impl Block { register: RegisterId, method: types::MethodId, arguments: Vec, + type_arguments: Option, location: LocationId, ) { self.instructions.push(Instruction::CallStatic(Box::new(CallStatic { register, method, arguments, + type_arguments, location, }))); } @@ -425,10 +437,18 @@ impl Block { receiver: RegisterId, method: types::MethodId, arguments: Vec, + type_arguments: Option, location: LocationId, ) { self.instructions.push(Instruction::CallInstance(Box::new( - CallInstance { register, receiver, method, arguments, location }, + CallInstance { + register, + receiver, + method, + arguments, + type_arguments, + location, + }, ))); } @@ -453,10 +473,18 @@ impl Block { receiver: RegisterId, method: types::MethodId, arguments: Vec, + type_arguments: Option, location: LocationId, ) { self.instructions.push(Instruction::CallDynamic(Box::new( - CallDynamic { register, receiver, method, arguments, location }, + CallDynamic { + register, + receiver, + method, + arguments, + type_arguments, + location, + }, ))); } @@ -500,12 +528,14 @@ impl Block { receiver: RegisterId, method: types::MethodId, arguments: Vec, + type_arguments: Option, location: LocationId, ) { 
self.instructions.push(Instruction::Send(Box::new(Send { receiver, method, arguments, + type_arguments, location, }))); } @@ -740,6 +770,7 @@ pub(crate) struct Switch { pub(crate) location: LocationId, } +// TODO: remove after implementing specialization #[derive(Clone)] pub(crate) struct SwitchKind { pub(crate) register: RegisterId, @@ -893,6 +924,7 @@ pub(crate) struct CallStatic { pub(crate) register: RegisterId, pub(crate) method: types::MethodId, pub(crate) arguments: Vec, + pub(crate) type_arguments: Option, pub(crate) location: LocationId, } @@ -902,6 +934,7 @@ pub(crate) struct CallInstance { pub(crate) receiver: RegisterId, pub(crate) method: types::MethodId, pub(crate) arguments: Vec, + pub(crate) type_arguments: Option, pub(crate) location: LocationId, } @@ -919,6 +952,7 @@ pub(crate) struct CallDynamic { pub(crate) receiver: RegisterId, pub(crate) method: types::MethodId, pub(crate) arguments: Vec, + pub(crate) type_arguments: Option, pub(crate) location: LocationId, } @@ -943,6 +977,7 @@ pub(crate) struct Send { pub(crate) receiver: RegisterId, pub(crate) method: types::MethodId, pub(crate) arguments: Vec, + pub(crate) type_arguments: Option, pub(crate) location: LocationId, } @@ -1219,15 +1254,27 @@ impl Instruction { format!("return r{}", v.register.0) } Instruction::Allocate(ref v) => { - format!("r{} = allocate {}", v.register.0, v.class.name(db)) + format!( + "r{} = allocate {}", + v.register.0, + class_name(db, v.class), + ) } Instruction::Spawn(ref v) => { format!("r{} = spawn {}", v.register.0, v.class.name(db)) } Instruction::CallStatic(ref v) => { + let class = v + .method + .receiver(db) + .as_class(db) + .map(|v| v.name(db)) + .expect("static methods must have a valid receiver"); + format!( - "r{} = call_static {}({})", + "r{} = call_static {}.{}({})", v.register.0, + class, v.method.name(db), join(&v.arguments) ) @@ -1376,33 +1423,17 @@ impl Class { } } -pub(crate) struct Trait { - pub(crate) id: types::TraitId, - pub(crate) methods: Vec, 
-} - -impl Trait { - pub(crate) fn new(id: types::TraitId) -> Self { - Self { id, methods: Vec::new() } - } - - pub(crate) fn add_methods(&mut self, methods: &Vec) { - for method in methods { - self.methods.push(method.id); - } - } -} - #[derive(Clone)] pub(crate) struct Module { pub(crate) id: types::ModuleId, pub(crate) classes: Vec, pub(crate) constants: Vec, + pub(crate) location: LocationId, } impl Module { - pub(crate) fn new(id: types::ModuleId) -> Self { - Self { id, classes: Vec::new(), constants: Vec::new() } + pub(crate) fn new(id: types::ModuleId, location: LocationId) -> Self { + Self { id, classes: Vec::new(), constants: Vec::new(), location } } } @@ -1435,8 +1466,22 @@ pub(crate) struct Mir { pub(crate) constants: HashMap, pub(crate) modules: IndexMap, pub(crate) classes: HashMap, - pub(crate) traits: HashMap, pub(crate) methods: HashMap, + + /// The type arguments to expose to call instructions, used to specialize + /// types and method calls. + /// + /// This data is stored out of bounds and addressed through an index, as + /// it's only needed by the specialization pass, and this makes it easy to + /// remove the data once we no longer need it. + pub(crate) type_arguments: Vec, + + /// Methods called through traits/dynamic dispatch. + /// + /// If the method itself is generic, the specialized version is tracked in + /// this set, otherwise the original version is tracked. 
+ pub(crate) trait_calls: HashSet, + locations: Vec, } @@ -1446,8 +1491,9 @@ impl Mir { constants: HashMap::new(), modules: IndexMap::new(), classes: HashMap::new(), - traits: HashMap::new(), methods: HashMap::new(), + type_arguments: Vec::new(), + trait_calls: HashSet::new(), locations: Vec::new(), } } @@ -1458,6 +1504,18 @@ impl Mir { } } + pub(crate) fn add_type_arguments( + &mut self, + arguments: TypeArguments, + ) -> Option { + if arguments.is_empty() { + None + } else { + self.type_arguments.push(arguments); + Some(self.type_arguments.len() - 1) + } + } + pub(crate) fn add_location( &mut self, location: SourceLocation, diff --git a/compiler/src/mir/passes.rs b/compiler/src/mir/passes.rs index 8bb6ca1e5..dcda58af0 100644 --- a/compiler/src/mir/passes.rs +++ b/compiler/src/mir/passes.rs @@ -4,7 +4,7 @@ use crate::hir; use crate::mir::pattern_matching as pmatch; use crate::mir::{ Block, BlockId, CastType, Class, CloneKind, Constant, Goto, Instruction, - LocationId, Method, Mir, Module, RegisterId, Trait, + LocationId, Method, Mir, Module, RegisterId, }; use crate::state::State; use ast::source_location::SourceLocation; @@ -15,9 +15,9 @@ use std::path::PathBuf; use std::rc::Rc; use types::format::format_type; use types::{ - self, Block as _, ClassId, ForeignType, MethodId, TypeBounds, TypeId, - TypeRef, EQ_METHOD, FIELDS_LIMIT, OPTION_NONE, OPTION_SOME, RESULT_CLASS, - RESULT_ERROR, RESULT_MODULE, RESULT_OK, + self, Block as _, ClassId, ForeignType, MethodId, ModuleId, TypeBounds, + TypeId, TypeRef, EQ_METHOD, FIELDS_LIMIT, OPTION_NONE, OPTION_SOME, + RESULT_CLASS, RESULT_ERROR, RESULT_MODULE, RESULT_OK, }; const SELF_NAME: &str = "self"; @@ -346,17 +346,17 @@ impl DecisionState { } } -struct GenerateDropper<'a> { - state: &'a mut State, - mir: &'a mut Mir, - module: &'a mut Module, - class: &'a mut Class, - location: SourceLocation, +pub(crate) struct GenerateDropper<'a> { + pub(crate) state: &'a mut State, + pub(crate) mir: &'a mut Mir, + pub(crate) module: 
ModuleId, + pub(crate) class: ClassId, + pub(crate) location: LocationId, } impl<'a> GenerateDropper<'a> { - fn run(mut self) { - match self.class.id.kind(&self.state.db) { + pub(crate) fn run(mut self) { + match self.class.kind(&self.state.db) { types::ClassKind::Async => self.async_class(), types::ClassKind::Enum => self.enum_class(), _ => self.regular_class(), @@ -384,7 +384,7 @@ impl<'a> GenerateDropper<'a> { /// the async dropper is the last message. When run, it cleans up the object /// like a regular class, and the process shuts down. fn async_class(&mut self) { - let loc = self.mir.add_location(self.location.clone()); + let loc = self.location; let async_dropper = self.generate_dropper( types::ASYNC_DROPPER_METHOD, types::MethodKind::AsyncMutable, @@ -413,6 +413,7 @@ impl<'a> GenerateDropper<'a> { self_reg, async_dropper, Vec::new(), + None, loc, ); lower.reduce_call(TypeRef::nil(), loc); @@ -430,9 +431,9 @@ impl<'a> GenerateDropper<'a> { /// tag, certain fields may be set to NULL. As such we branch based on the /// tag value, and only drop the fields relevant for that tag. 
fn enum_class(&mut self) { - let loc = self.mir.add_location(self.location.clone()); + let loc = self.location; let name = types::DROPPER_METHOD; - let class = self.class.id; + let class = self.class; let drop_method_opt = class.method(&self.state.db, types::DROP_METHOD); let method_type = self.method_type(name, types::MethodKind::Mutable); let mut method = Method::new(method_type, loc); @@ -452,6 +453,7 @@ impl<'a> GenerateDropper<'a> { self_reg, id, Vec::new(), + None, loc, ); lower.reduce_call(typ, loc); @@ -516,10 +518,10 @@ impl<'a> GenerateDropper<'a> { free_self: bool, terminate: bool, ) -> MethodId { - let class = self.class.id; + let class = self.class; let drop_method_opt = class.method(&self.state.db, types::DROP_METHOD); let method_type = self.method_type(name, kind); - let loc = self.mir.add_location(self.location.clone()); + let loc = self.location; let mut method = Method::new(method_type, loc); let mut lower = LowerMethod::new(self.state, self.mir, self.module, &mut method); @@ -537,6 +539,7 @@ impl<'a> GenerateDropper<'a> { self_reg, id, Vec::new(), + None, loc, ); lower.reduce_call(typ, loc); @@ -583,7 +586,7 @@ impl<'a> GenerateDropper<'a> { fn method_type(&mut self, name: &str, kind: types::MethodKind) -> MethodId { let id = types::Method::alloc( &mut self.state.db, - self.module.id, + self.module, name.to_string(), types::Visibility::TypePrivate, kind, @@ -592,7 +595,7 @@ impl<'a> GenerateDropper<'a> { let self_type = types::TypeId::ClassInstance(types::ClassInstance::rigid( &mut self.state.db, - self.class.id, + self.class, &types::TypeBounds::new(), )); let receiver = TypeRef::Mut(self_type); @@ -603,8 +606,10 @@ impl<'a> GenerateDropper<'a> { } fn add_method(&mut self, name: &str, id: MethodId, method: Method) { - self.class.id.add_method(&mut self.state.db, name.to_string(), id); - self.class.methods.push(id); + let cid = self.class; + + cid.add_method(&mut self.state.db, name.to_string(), id); + 
self.mir.classes.get_mut(&cid).unwrap().methods.push(id); self.mir.methods.insert(id, method); } } @@ -823,7 +828,7 @@ impl<'a> DefineConstants<'a> { pub(crate) struct LowerToMir<'a> { state: &'a mut State, mir: &'a mut Mir, - module: &'a mut Module, + module: ModuleId, } impl<'a> LowerToMir<'a> { @@ -848,23 +853,23 @@ impl<'a> LowerToMir<'a> { ) }); + let loc = mir.add_location(module.location); + let id = module.module_id; + mod_types.push(types); mod_nodes.push(rest); - modules.push(Module::new(module.module_id)); + modules.push(id); + mir.modules.insert(id, Module::new(id, loc)); } - for (module, nodes) in modules.iter_mut().zip(mod_types.into_iter()) { + for (&module, nodes) in modules.iter().zip(mod_types.into_iter()) { LowerToMir { state, mir, module }.lower_types(nodes); } - for (module, nodes) in modules.iter_mut().zip(mod_nodes.into_iter()) { + for (&module, nodes) in modules.iter().zip(mod_nodes.into_iter()) { LowerToMir { state, mir, module }.lower_rest(nodes); } - for module in modules { - mir.modules.insert(module.id, module); - } - !state.diagnostics.has_errors() } @@ -886,13 +891,20 @@ impl<'a> LowerToMir<'a> { } fn lower_rest(&mut self, nodes: Vec) { - let id = self.module.id; + let id = self.module; let mut mod_methods = Vec::new(); for expr in nodes { match expr { hir::TopLevelExpression::Constant(n) => { - self.module.constants.push(n.constant_id.unwrap()) + let mod_id = self.module; + + self.mir + .modules + .get_mut(&mod_id) + .unwrap() + .constants + .push(n.constant_id.unwrap()) } hir::TopLevelExpression::ModuleMethod(n) => { mod_methods.push(self.define_module_method(*n)); @@ -916,7 +928,6 @@ impl<'a> LowerToMir<'a> { } fn define_trait(&mut self, node: hir::DefineTrait) { - let id = node.trait_id.unwrap(); let mut methods = Vec::new(); for expr in node.body { @@ -925,11 +936,7 @@ impl<'a> LowerToMir<'a> { } } - let mut mir_trait = Trait::new(id); - - mir_trait.add_methods(&methods); self.mir.add_methods(methods); - 
self.mir.traits.insert(id, mir_trait); } fn implement_trait(&mut self, node: hir::ImplementTrait) { @@ -986,19 +993,20 @@ impl<'a> LowerToMir<'a> { } let mut class = Class::new(id); + let loc = self.mir.add_location(node.location); + + class.add_methods(&methods); + self.mir.add_methods(methods); + self.add_class(id, class); GenerateDropper { state: self.state, mir: self.mir, module: self.module, - class: &mut class, - location: node.location, + class: id, + location: loc, } .run(); - - class.add_methods(&methods); - self.mir.add_methods(methods); - self.add_class(id, class); } fn define_extern_class(&mut self, node: hir::DefineExternClass) { @@ -1139,8 +1147,10 @@ impl<'a> LowerToMir<'a> { } fn add_class(&mut self, id: types::ClassId, class: Class) { + let mod_id = self.module; + self.mir.classes.insert(id, class); - self.module.classes.push(id); + self.mir.modules.get_mut(&mod_id).unwrap().classes.push(id); } } @@ -1148,7 +1158,7 @@ impl<'a> LowerToMir<'a> { pub(crate) struct LowerMethod<'a> { state: &'a mut State, mir: &'a mut Mir, - module: &'a mut Module, + module: ModuleId, method: &'a mut Method, scope: Box, current_block: BlockId, @@ -1193,7 +1203,7 @@ impl<'a> LowerMethod<'a> { fn new( state: &'a mut State, mir: &'a mut Mir, - module: &'a mut Module, + module: ModuleId, method: &'a mut Method, ) -> Self { let current_block = method.body.add_start_block(); @@ -1868,8 +1878,11 @@ impl<'a> LowerMethod<'a> { return result; } types::Receiver::Class => { + let targs = self.mir.add_type_arguments(info.type_arguments); + self.current_block_mut() - .call_static(result, info.id, arguments, location); + .call_static(result, info.id, arguments, targs, location); + self.reduce_call(info.returns, location); return result; @@ -1880,6 +1893,8 @@ impl<'a> LowerMethod<'a> { rec = self.receiver_for_moving_method(rec, location); } + let targs = self.mir.add_type_arguments(info.type_arguments); + if info.id.is_async(self.db()) { let rec_typ = self.register_type(rec); let 
msg_rec = self.new_register(rec_typ); @@ -1889,15 +1904,17 @@ impl<'a> LowerMethod<'a> { // (e.g. if new references are created before it runs). self.current_block_mut().increment_atomic(msg_rec, rec, location); self.current_block_mut() - .send(msg_rec, info.id, arguments, location); + .send(msg_rec, info.id, arguments, targs, location); + self.mark_register_as_moved(msg_rec); self.current_block_mut().nil_literal(result, location); } else if info.dynamic { self.current_block_mut() - .call_dynamic(result, rec, info.id, arguments, location); + .call_dynamic(result, rec, info.id, arguments, targs, location); } else { - self.current_block_mut() - .call_instance(result, rec, info.id, arguments, location); + self.current_block_mut().call_instance( + result, rec, info.id, arguments, targs, location, + ); } self.reduce_call(info.returns, location); @@ -2954,6 +2971,7 @@ impl<'a> LowerMethod<'a> { test_reg, eq_method, vec![val_reg], + None, loc, ); @@ -3242,7 +3260,7 @@ impl<'a> LowerMethod<'a> { fn closure(&mut self, node: hir::Closure) -> RegisterId { self.check_inferred(node.resolved_type, &node.location); - let module = self.module.id; + let module = self.module; let closure_id = node.closure_id.unwrap(); let moving = closure_id.is_moving(self.db()); let class_id = types::Class::alloc( @@ -3411,19 +3429,24 @@ impl<'a> LowerMethod<'a> { lower.run(node.body, loc); } + let mod_id = self.module; + + mir_class.methods.push(method_id); + self.mir.methods.insert(method_id, mir_method); + self.mir.classes.insert(class_id, mir_class); + self.mir.modules.get_mut(&mod_id).unwrap().classes.push(class_id); + + let loc = self.mir.add_location(node.location); + GenerateDropper { state: self.state, mir: self.mir, module: self.module, - class: &mut mir_class, - location: node.location, + class: class_id, + location: loc, } .run(); - mir_class.methods.push(method_id); - self.mir.methods.insert(method_id, mir_method); - self.mir.classes.insert(class_id, mir_class); - 
self.module.classes.push(class_id); gen_class_reg } @@ -4303,7 +4326,7 @@ impl<'a> LowerMethod<'a> { } fn file(&self) -> PathBuf { - self.module.id.file(&self.state.db) + self.module.file(&self.state.db) } fn self_type(&self) -> types::TypeId { @@ -4774,6 +4797,7 @@ impl<'a> ExpandDrop<'a> { value, method, Vec::new(), + None, location, ); } else if !typ.is_any(self.db) { diff --git a/compiler/src/mir/pattern_matching.rs b/compiler/src/mir/pattern_matching.rs index 50c300798..15ad9713f 100644 --- a/compiler/src/mir/pattern_matching.rs +++ b/compiler/src/mir/pattern_matching.rs @@ -25,7 +25,7 @@ use std::collections::{HashMap, HashSet}; use types::resolve::TypeResolver; use types::{ ClassInstance, ClassKind, Database, FieldId, TypeArguments, TypeBounds, - TypeId, TypeRef, VariableId, VariantId, BOOLEAN_ID, INT_ID, STRING_ID, + TypeId, TypeRef, VariableId, VariantId, BOOL_ID, INT_ID, STRING_ID, }; fn add_missing_patterns( @@ -836,7 +836,7 @@ impl<'a> Compiler<'a> { match class_id.0 { INT_ID => Type::Int, STRING_ID => Type::String, - BOOLEAN_ID => Type::Finite(vec![ + BOOL_ID => Type::Finite(vec![ (Constructor::False, Vec::new(), Vec::new()), (Constructor::True, Vec::new(), Vec::new()), ]), diff --git a/compiler/src/mir/printer.rs b/compiler/src/mir/printer.rs index 04c7b2d73..35e55b9c7 100644 --- a/compiler/src/mir/printer.rs +++ b/compiler/src/mir/printer.rs @@ -6,6 +6,7 @@ #![allow(unused)] use crate::mir::{BlockId, Method, Mir}; +use crate::symbol_names::{class_name, method_name}; use std::fmt::Write; use types::{Database, TypeId}; @@ -24,17 +25,17 @@ pub(crate) fn to_dot(db: &Database, mir: &Mir, methods: &[&Method]) -> String { buffer.push_str("edge[fontname=\"monospace\", fontsize=10];\n"); let rec_name = match method.id.receiver_id(db) { - TypeId::Class(id) => id.name(db), - TypeId::Trait(id) => id.name(db), - TypeId::ClassInstance(ins) => ins.instance_of().name(db), - TypeId::TraitInstance(ins) => ins.instance_of().name(db), - _ => "", + TypeId::Class(id) => 
id.name(db).clone(), + TypeId::Trait(id) => id.name(db).clone(), + TypeId::ClassInstance(ins) => ins.instance_of().name(db).clone(), + TypeId::TraitInstance(ins) => ins.instance_of().name(db).clone(), + _ => String::new(), }; let name = if rec_name.is_empty() { - format!("{}()", method.id.name(db)) + format!("{}()", method_name(db, method.id)) } else { - format!("{}.{}()", rec_name, method.id.name(db)) + format!("{}.{}()", rec_name, method_name(db, method.id)) }; let _ = writeln!(buffer, "label=\"{}\";", name); diff --git a/compiler/src/mir/specialize.rs b/compiler/src/mir/specialize.rs new file mode 100644 index 000000000..1a8769799 --- /dev/null +++ b/compiler/src/mir/specialize.rs @@ -0,0 +1,400 @@ +use crate::mir::passes::GenerateDropper; +use crate::mir::{Class as MirClass, Instruction, Mir, RegisterId, Registers}; +use crate::state::State; +use std::collections::{HashMap, VecDeque}; +use types::specialize::TypeSpecializer; +use types::{ + Block, Database, MethodId, MethodKind, Shape, TypeArguments, + TypeParameterId, TypeRef, +}; + +enum Specialized { + Existing(MethodId), + New(MethodId), +} + +struct Job { + /// The ID of the method that's being specialized. + method: MethodId, + + /// The shapes of the method (including its receiver), in the same order as + /// the type parameters. + shapes: HashMap, +} + +/// The methods that are pending specialization. +struct WorkList { + jobs: VecDeque, +} + +impl WorkList { + fn new() -> WorkList { + WorkList { jobs: VecDeque::new() } + } + + fn push( + &mut self, + method: MethodId, + shapes: HashMap, + ) { + self.jobs.push_back(Job { method, shapes }); + } + + fn pop(&mut self) -> Option { + if let Some(job) = self.jobs.pop_front() { + // self.scheduled.remove(&job); + Some(job) + } else { + None + } + } +} + +/// A compiler pass that specializes generic types. 
+pub(crate) struct Specialize<'a, 'b> { + method: MethodId, + + state: &'a mut State, + work: &'b mut WorkList, + + /// The shapes of the type parameters the method to specialize has access + /// to. + shapes: HashMap, +} + +impl<'a, 'b> Specialize<'a, 'b> { + pub(crate) fn run_all(state: &'a mut State, mir: &'a mut Mir) { + let mut work = WorkList::new(); + let main_method = state.db.main_method().unwrap(); + + work.push(main_method, HashMap::new()); + + while let Some(job) = work.pop() { + // TODO: remove + if !job + .method + .module(&state.db) + .file(&state.db) + .ends_with("Downloads/test.inko") + { + continue; + } + + Specialize { + state, + method: job.method, + shapes: job.shapes, + work: &mut work, + } + .run(mir); + } + + // We don't need the type arguments after this point. + mir.type_arguments = Vec::new(); + + // TODO: should we leave classes behind for static methods? + // let remove = mir + // .classes + // .keys() + // .filter(|&&id| { + // // TODO: use something better + // id.module(&state.db) + // .file(&state.db) + // .ends_with("Downloads/test.inko") + // && id.is_generic(&state.db) + // && id.get_specialization_id(&state.db).is_none() + // }) + // .cloned() + // .collect::>(); + // + // for module in mir.modules.values_mut() { + // module.classes.retain(|id| { + // // TODO: use something that isn't O(n*2). + // !remove.contains(&id) + // }); + // } + // + // for id in remove { + // let class = mir.classes.remove(&id).unwrap(); + // + // for method in class.methods { + // mir.methods.remove(&method); + // } + // } + } + + fn run(&mut self, mir: &mut Mir) { + // TODO: remove + println!("specializing '{}'", self.method.name(&self.state.db)); + + let mut methods = Vec::new(); + let before_class_id = self.state.db.last_class_id(); + let method = mir.methods.get_mut(&self.method).unwrap(); + + // Rather than specializing the registers of instructions that may + // produce generic types, we just specialize all of them. 
The type + // specializer bails out if this isn't needed anyway, and this makes our + // code not prone to accidentally forgetting to specialize a register + // when adding or changing MIR instructions. + for reg in method.registers.iter_mut() { + reg.value_type = + TypeSpecializer::new(&mut self.state.db, &self.shapes) + .specialize(reg.value_type); + } + + for block in &mut method.body.blocks { + for ins in &mut block.instructions { + // When specializing a method, we _don't_ store them in any + // class types. Different specializations of the same method use + // the same name, so if they are stored on the same class they'd + // overwrite each other. Since we don't need to look up any + // methods by their names at and beyond this point, we just not + // store them in the class types to begin with. + match ins { + Instruction::CallStatic(ins) + if ins.type_arguments.is_some() => + { + let rec = ins.method.receiver(&self.state.db); + let class = rec.as_class(&self.state.db).unwrap(); + let targs = ins + .type_arguments + .and_then(|i| mir.type_arguments.get(i)) + .unwrap(); + + match self.specialize_call(rec, ins.method, targs) { + Specialized::Existing(id) => ins.method = id, + Specialized::New(id) => { + methods.push((class, ins.method, id)); + ins.method = id; + } + } + } + Instruction::CallInstance(ins) + if ins.type_arguments.is_some() => + { + let rec = method.registers.value_type(ins.receiver); + let class = rec.class_id(&self.state.db).unwrap(); + let targs = ins + .type_arguments + .and_then(|i| mir.type_arguments.get(i)) + .unwrap(); + + match self.specialize_call(rec, ins.method, targs) { + Specialized::Existing(id) => ins.method = id, + Specialized::New(id) => { + methods.push((class, ins.method, id)); + ins.method = id; + } + } + } + Instruction::Send(ins) if ins.type_arguments.is_some() => { + let rec = method.registers.value_type(ins.receiver); + let class = rec.class_id(&self.state.db).unwrap(); + let targs = ins + .type_arguments + .and_then(|i| 
mir.type_arguments.get(i)) + .unwrap(); + + match self.specialize_call(rec, ins.method, targs) { + Specialized::Existing(id) => ins.method = id, + Specialized::New(id) => { + methods.push((class, ins.method, id)); + ins.method = id; + } + } + } + Instruction::CallDynamic(ins) + if ins.type_arguments.is_some() => + { + let targs = ins + .type_arguments + .and_then(|i| mir.type_arguments.get(i)) + .unwrap(); + + match self.specialize_dynamic_call(ins.method, targs) { + Specialized::Existing(id) => ins.method = id, + Specialized::New(id) => { + // methods.push((class, ins.method, id)); + ins.method = id; + } + } + } + Instruction::Allocate(ins) => { + ins.class = method + .registers + .value_type(ins.register) + .class_id(&self.state.db) + .unwrap(); + } + Instruction::Spawn(ins) => { + ins.class = method + .registers + .value_type(ins.register) + .class_id(&self.state.db) + .unwrap(); + } + Instruction::SetField(ins) => { + ins.class = method + .registers + .value_type(ins.receiver) + .class_id(&self.state.db) + .unwrap(); + } + Instruction::GetField(ins) => { + ins.class = method + .registers + .value_type(ins.receiver) + .class_id(&self.state.db) + .unwrap(); + } + Instruction::FieldPointer(ins) => { + ins.class = method + .registers + .value_type(ins.receiver) + .class_id(&self.state.db) + .unwrap(); + } + _ => {} + } + } + } + + // Specialization may create one or more new classes, which we need to + // add to MIR so we can generate code for them. Instead of tracking + // these in a bunch of places, we just grab all newly created ones + // (since the start of this method) here. 
+ for cid in self.state.db.classes_since(before_class_id) { + mir.classes.insert(cid, MirClass::new(cid)); + + let mod_id = cid.module(&self.state.db); + let module = mir.modules.get_mut(&mod_id).unwrap(); + let loc = module.location; + + module.classes.push(cid); + GenerateDropper { + state: self.state, + mir, + module: mod_id, + class: cid, + location: loc, + } + .run(); + } + + for (class, old, new) in methods { + let mut method = mir.methods[&old].clone(); + + method.id = new; + mir.classes.get_mut(&class).unwrap().methods.push(new); + mir.methods.insert(new, method); + } + } + + fn specialize_call( + &mut self, + receiver: TypeRef, + method: MethodId, + type_arguments: &TypeArguments, + ) -> Specialized { + let mut key = Vec::new(); + let shapes: HashMap<_, _> = type_arguments + .iter() + .map(|(&par, typ)| (par, typ.shape(&self.state.db, &self.shapes))) + .collect(); + + if let Some(class) = receiver.class_id(&self.state.db) { + for param in class.type_parameters(&self.state.db) { + key.push(shapes.get(¶m).unwrap().clone()); + } + } + + for param in method.type_parameters(&self.state.db) { + key.push(shapes.get(¶m).unwrap().clone()); + } + + if let Some(new) = method.get_specialization(&self.state.db, &key) { + return Specialized::Existing(new); + } + + let new_method = + method.clone_for_specialization(&mut self.state.db, key.clone()); + let old_ret = method.return_type(&self.state.db); + + for arg in method.arguments(&self.state.db) { + let arg_type = TypeSpecializer::new(&mut self.state.db, &shapes) + .specialize(arg.value_type); + let raw_var_type = arg.variable.value_type(&self.state.db); + let var_type = TypeSpecializer::new(&mut self.state.db, &shapes) + .specialize(raw_var_type); + + new_method.new_argument( + &mut self.state.db, + arg.name, + var_type, + arg_type, + ); + } + + let new_ret = TypeSpecializer::new(&mut self.state.db, &shapes) + .specialize(old_ret); + + new_method.set_return_type(&mut self.state.db, new_ret); + + // At this point we can 
only call methods on valid types, so it's safe to + // unwrap. + let rec_id = receiver.type_id(&self.state.db).unwrap(); + let new_rec = match new_method.kind(&self.state.db) { + // Async methods always access `self` through a reference even though + // processes are value types. This way we prevent immutable async + // methods from being able to mutate the process' internal state. + MethodKind::Async => TypeRef::Ref(rec_id), + MethodKind::AsyncMutable => TypeRef::Mut(rec_id), + + // For regular value types (e.g. Int), `self` is always an owned value. + _ if receiver.is_value_type(&self.state.db) => { + TypeRef::Owned(rec_id) + } + MethodKind::Instance => TypeRef::Ref(rec_id), + MethodKind::Mutable | MethodKind::Destructor => { + TypeRef::Mut(rec_id) + } + MethodKind::Static | MethodKind::Moving => TypeRef::Owned(rec_id), + MethodKind::Extern => receiver, + }; + + new_method.set_receiver(&mut self.state.db, new_rec); + method.add_specialization(&mut self.state.db, key, new_method); + self.work.push(new_method, shapes.clone()); + Specialized::New(new_method) + } + + fn specialize_dynamic_call( + &mut self, + method: MethodId, + type_arguments: &TypeArguments, + ) -> Specialized { + let mut key = Vec::new(); + let shapes: HashMap<_, _> = type_arguments + .iter() + .map(|(&par, typ)| (par, typ.shape(&self.state.db, &self.shapes))) + .collect(); + + // We deliberately don't include the receiver's shapes into the key. + // This is done so the LLVM codegen can look up specializations without + // needing to specify the receiver shapes. This in turn is needed for it + // to determine if a dynamic dispatch call site requires probing. + // + // Apart from that, we simply don't need the receiver types, as for + // scheduling the underlying implementation we'll use the shapes of said + // implementation. 
+ for param in method.type_parameters(&self.state.db) { + key.push(shapes.get(&param).unwrap().clone()); + } + + if let Some(new) = method.get_specialization(&self.state.db, &key) { + return Specialized::Existing(new); + } + + todo!() + } +} diff --git a/compiler/src/symbol_names.rs b/compiler/src/symbol_names.rs index 1d66bece9..cb7c48ebf 100644 --- a/compiler/src/symbol_names.rs +++ b/compiler/src/symbol_names.rs @@ -1,10 +1,33 @@ //! Mangled symbol names for native code. use crate::mir::Mir; use std::collections::HashMap; -use types::{ClassId, ConstantId, Database, MethodId, ModuleId}; +use types::{ClassId, ConstantId, Database, MethodId, ModuleId, Shape}; pub(crate) const SYMBOL_PREFIX: &str = "_I"; +pub(crate) fn name_with_shapes(base: &String, shapes: &[Shape]) -> String { + if shapes.is_empty() { + return base.clone(); + } + + let mut name = format!("{}[", base); + + for shape in shapes { + name.push_str(shape.identifier()); + } + + name.push(']'); + name +} + +pub(crate) fn class_name(db: &Database, id: ClassId) -> String { + name_with_shapes(id.name(db), id.get_shapes(db)) +} + +pub(crate) fn method_name(db: &Database, id: MethodId) -> String { + name_with_shapes(id.name(db), id.get_shapes(db)) +} + /// A cache of mangled symbol names. pub(crate) struct SymbolNames { pub(crate) classes: HashMap<ClassId, String>, @@ -29,25 +52,37 @@ impl SymbolNames { for &class in &module.classes { let is_mod = class.kind(db).is_module(); - let class_name = - format!("{}T_{}.{}", prefix, mod_name, class.name(db)); + let class_name = format!( + "{}T_{}.{}", + prefix, + mod_name, + class_name(db, class) + ); classes.insert(class, class_name); for &method in &mir.classes[&class].methods { + // Method names include their shapes to ensure specialized + // methods with the same name and type don't conflict with + // each other. let name = if is_mod { // This ensures that methods such as // `std::process.sleep` aren't formatted as // `std::process::std::process.sleep`. 
This in turn // makes stack traces easier to read. - format!("{}M_{}.{}", prefix, mod_name, method.name(db)) + format!( + "{}M_{}.{}", + prefix, + mod_name, + method_name(db, method) + ) } else { format!( "{}M_{}.{}.{}", prefix, mod_name, class.name(db), - method.name(db) + method_name(db, method) ) }; diff --git a/compiler/src/type_check/expressions.rs b/compiler/src/type_check/expressions.rs index 074dd0588..34f0995d9 100644 --- a/compiler/src/type_check/expressions.rs +++ b/compiler/src/type_check/expressions.rs @@ -2601,6 +2601,7 @@ impl<'a> CheckMethodBody<'a> { receiver: rec_kind, returns, dynamic: rec_id.use_dynamic_dispatch(), + type_arguments: call.type_arguments, }); node.resolved_type = returns; @@ -2721,6 +2722,7 @@ impl<'a> CheckMethodBody<'a> { receiver: rec_kind, returns, dynamic: rec_id.use_dynamic_dispatch(), + type_arguments: call.type_arguments, }); returns @@ -3301,6 +3303,7 @@ impl<'a> CheckMethodBody<'a> { receiver: rec_info, returns, dynamic: rec_id.use_dynamic_dispatch(), + type_arguments: call.type_arguments, }); returns @@ -3662,6 +3665,7 @@ impl<'a> CheckMethodBody<'a> { receiver: rec_info, returns, dynamic: rec_id.use_dynamic_dispatch(), + type_arguments: call.type_arguments, }); returns @@ -3759,6 +3763,7 @@ impl<'a> CheckMethodBody<'a> { receiver: rec_info, returns, dynamic: rec_id.use_dynamic_dispatch(), + type_arguments: call.type_arguments, }); returns diff --git a/rt/src/process.rs b/rt/src/process.rs index 376a484ff..48759a22d 100644 --- a/rt/src/process.rs +++ b/rt/src/process.rs @@ -647,17 +647,26 @@ impl Process { for frame in trace.frames() { backtrace::resolve(frame.ip(), |symbol| { let name = if let Some(sym_name) = symbol.name() { - let name = sym_name.as_str().unwrap_or(""); - // We only want to include frames for Inko source code, not // any additional frames introduced by the runtime library // and its dependencies. 
- if let Some(name) = - name.strip_prefix(INKO_SYMBOL_IDENTIFIER) + let base = if let Some(name) = sym_name + .as_str() + .unwrap_or("") + .strip_prefix(INKO_SYMBOL_IDENTIFIER) { - name.to_string() + name } else { return; + }; + + // Methods include the type shapes to prevent name conflicts + // as a result of specialization. We get rid of these to + // ensure the stacktraces are easier to understand. + if let Some(idx) = base.find('[') { + base[0..idx].to_string() + } else { + base.to_string() } } else { String::new() diff --git a/types/src/format.rs b/types/src/format.rs index 497ffe57d..23ca4bb8b 100644 --- a/types/src/format.rs +++ b/types/src/format.rs @@ -186,7 +186,9 @@ impl FormatType for TypeParameterId { if let Some(arg) = buffer.type_arguments.and_then(|a| a.get(*self)) { if let TypeRef::Placeholder(p) = arg { match p.value(buffer.db) { - Some(t) if t.as_type_parameter() == Some(*self) => { + Some(t) + if t.as_type_parameter(buffer.db) == Some(*self) => + { self.format_type_without_argument(buffer) } Some(t) => t.format_type(buffer), @@ -196,7 +198,7 @@ impl FormatType for TypeParameterId { return; } - if arg.as_type_parameter() == Some(*self) { + if arg.as_type_parameter(buffer.db) == Some(*self) { self.format_type_without_argument(buffer); return; } diff --git a/types/src/lib.rs b/types/src/lib.rs index 23b554ee2..ccf37c596 100644 --- a/types/src/lib.rs +++ b/types/src/lib.rs @@ -11,12 +11,14 @@ pub mod either; pub mod format; pub mod module_name; pub mod resolve; +pub mod specialize; use crate::collections::IndexMap; use crate::module_name::ModuleName; use crate::resolve::TypeResolver; use std::cell::Cell; use std::collections::{HashMap, HashSet}; +use std::iter::once; use std::path::PathBuf; // The IDs of these built-in types must match the order of the fields in the @@ -24,7 +26,7 @@ use std::path::PathBuf; pub const INT_ID: u32 = 0; pub const FLOAT_ID: u32 = 1; pub const STRING_ID: u32 = 2; -pub const BOOLEAN_ID: u32 = 3; +pub const BOOL_ID: u32 = 
3; pub const NIL_ID: u32 = 4; pub const BYTE_ARRAY_ID: u32 = 5; pub const CHANNEL_ID: u32 = 6; @@ -51,7 +53,7 @@ const INT_NAME: &str = "Int"; const FLOAT_NAME: &str = "Float"; const STRING_NAME: &str = "String"; const ARRAY_NAME: &str = "Array"; -const BOOLEAN_NAME: &str = "Bool"; +const BOOL_NAME: &str = "Bool"; const NIL_NAME: &str = "Nil"; const BYTE_ARRAY_NAME: &str = "ByteArray"; const CHANNEL_NAME: &str = "Channel"; @@ -297,7 +299,7 @@ impl TypeParameterId { } /// Type parameters and the types assigned to them. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct TypeArguments { /// We use a HashMap as parameters can be assigned in any order, and some /// may not be assigned at all. @@ -360,6 +362,16 @@ impl TypeArguments { } } } + + pub fn is_empty(&self) -> bool { + self.mapping.is_empty() + } + + pub fn iter( + &self, + ) -> std::collections::hash_map::Iter { + self.mapping.iter() + } } /// An Inko trait. @@ -930,6 +942,11 @@ pub struct Class { methods: HashMap, implemented_traits: HashMap, variants: IndexMap, + specialized: HashMap, ClassId>, + + /// The shapes of the type parameters of this class, in the same order as + /// the parameters. 
+ shapes: Vec<Shape>, } impl Class { @@ -968,6 +985,8 @@ impl Class { implemented_traits: HashMap::new(), variants: IndexMap::new(), module, + specialized: HashMap::new(), + shapes: Vec::new(), } } @@ -1036,7 +1055,7 @@ impl ClassId { } pub fn boolean() -> ClassId { - ClassId(BOOLEAN_ID) + ClassId(BOOL_ID) } pub fn nil() -> ClassId { @@ -1274,6 +1293,14 @@ impl ClassId { self.get(db).module } + pub fn set_shapes(self, db: &mut Database, shapes: Vec<Shape>) { + self.get_mut(db).shapes = shapes; + } + + pub fn get_shapes(self, db: &Database) -> &Vec<Shape> { + &self.get(db).shapes + } + pub fn number_of_type_parameters(self, db: &Database) -> usize { self.get(db).type_parameters.len() } @@ -1313,6 +1340,18 @@ impl ClassId { matches!(self.0, INT_ID | FLOAT_ID) } + fn shape(self, db: &Database, default: Shape) -> Shape { + match self.0 { + INT_ID => Shape::Int, + FLOAT_ID => Shape::Float, + BOOL_ID => Shape::Boolean, + STRING_ID => Shape::String, + CHANNEL_ID => Shape::Atomic, + _ if self.kind(db).is_async() => Shape::Atomic, + _ => default, + } + } + fn get(self, db: &Database) -> &Class { &db.classes[self.0 as usize] } @@ -1838,12 +1877,23 @@ pub struct Method { main: bool, variadic: bool, - /// The type of the receiver of the method, aka the type of `self` (not - /// `Self`). + /// The type of the receiver of the method. receiver: TypeRef, /// The fields this method has access to, along with their types. field_types: HashMap, + + /// The specializations of this method, if the method itself is generic. + /// + /// Each key is the combination of the receiver and method shapes, in the + /// same order as their type parameters. + specialized: HashMap<Vec<Shape>, MethodId>, + + /// The shapes of type parameters exposed to this method. + /// + /// The order matches the order of type parameters of the receiver (if any), + /// followed by the type parameters of the method (if any). 
+ shapes: Vec<Shape>, } impl Method { @@ -1869,6 +1919,8 @@ impl Method { field_types: HashMap::new(), main: false, variadic: false, + specialized: HashMap::new(), + shapes: Vec::new(), }; db.methods.push(method); @@ -2086,6 +2138,49 @@ impl MethodId { self.has_return_type(db) && !self.return_type(db).is_never(db) } + pub fn add_specialization( + self, + db: &mut Database, + shapes: Vec<Shape>, + method: MethodId, + ) { + self.get_mut(db).specialized.insert(shapes, method); + } + + pub fn get_specialization( + self, + db: &Database, + shapes: &[Shape], + ) -> Option<MethodId> { + self.get(db).specialized.get(shapes).cloned() + } + + pub fn set_shapes(self, db: &mut Database, shapes: Vec<Shape>) { + self.get_mut(db).shapes = shapes; + } + + pub fn get_shapes(self, db: &Database) -> &Vec<Shape> { + &self.get(db).shapes + } + + pub fn clone_for_specialization( + self, + db: &mut Database, + shapes: Vec<Shape>, + ) -> MethodId { + let (module, name, vis, kind, source) = { + let old = self.get(db); + + (old.module, old.name.clone(), old.visibility, old.kind, old.source) + }; + + let new = Method::alloc(db, module, name, vis, kind); + + new.set_source(db, source); + new.set_shapes(db, shapes); + new + } + fn get(self, db: &Database) -> &Method { &db.methods[self.0] } @@ -2184,6 +2279,7 @@ pub struct CallInfo { pub receiver: Receiver, pub returns: TypeRef, pub dynamic: bool, + pub type_arguments: TypeArguments, } #[derive(Clone, Debug, PartialEq, Eq)] @@ -2644,21 +2740,32 @@ impl ClosureId { db: &mut Database, value_type: TypeRef, ) { - let lambda = self.get_mut(db); - let name = lambda.arguments.len().to_string(); + let closure = self.get_mut(db); // Anonymous arguments can never be used, so the variable ID is never // used. As such we just set it to ID 0 so we don't need to wrap it in // an `Option` type. 
let var = VariableId(0); - lambda.arguments.new_argument(name, value_type, var); + closure.arguments.new_argument("_".to_string(), value_type, var); } pub fn is_moving(self, db: &Database) -> bool { self.get(db).moving } + pub fn contains_generic_types(self, db: &Database) -> bool { + let closure = self.get(db); + + closure + .arguments + .iter() + .map(|arg| arg.value_type) + .chain(once(closure.return_type)) + .chain(closure.captured.iter().map(|v| v.1)) + .any(|typ| typ.is_generic(db) || typ.is_type_parameter(db)) + } + pub fn set_captured_self_type( self, db: &mut Database, @@ -2734,6 +2841,55 @@ impl Block for ClosureId { } } +/// A type describing the "shape" of a type, which describes its size on the +/// stack, how to create aliases, etc. +#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] +pub enum Shape { + /// An owned value addressed through a pointer. + Owned, + + /// A mutable reference to a value. + Mut, + + /// An immutable reference to a value. + Ref, + + /// A 64-bits unboxed integer. + /// + /// These values are passed around using a simple copy. + Int, + + /// A 64-bits unboxed float. + /// + /// These values are passed around using a simple copy. In native code, + /// these values use the appropriate floating point registers. + Float, + + /// The value is a boolean. + Boolean, + + /// The value is a string. + String, + + /// The value is an owned value that uses atomic reference counting. + Atomic, +} + +impl Shape { + pub fn identifier(&self) -> &'static str { + match self { + Shape::Owned => "o", + Shape::Mut => "m", + Shape::Ref => "r", + Shape::Int => "i", + Shape::Float => "f", + Shape::Boolean => "b", + Shape::String => "s", + Shape::Atomic => "a", + } + } +} + /// A reference to a type. 
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] pub enum TypeRef { @@ -2805,7 +2961,7 @@ impl TypeRef { pub fn boolean() -> TypeRef { TypeRef::Owned(TypeId::ClassInstance(ClassInstance::new(ClassId( - BOOLEAN_ID, + BOOL_ID, )))) } @@ -3197,6 +3353,7 @@ impl TypeRef { TypeRef::Owned(_) => true, TypeRef::Ref(TypeId::ClassInstance(ins)) => { ins.instance_of.is_value_type(db) + && !ins.instance_of().kind(db).is_async() } TypeRef::Placeholder(id) => { id.value(db).map_or(false, |v| v.allow_mutating(db)) @@ -3339,9 +3496,16 @@ impl TypeRef { pub fn as_mut(self, db: &Database) -> Self { match self { - TypeRef::Owned(TypeId::RigidTypeParameter(id)) => { - if id.is_mutable(db) { - TypeRef::Mut(TypeId::RigidTypeParameter(id)) + TypeRef::Owned( + id @ TypeId::RigidTypeParameter(pid) + | id @ TypeId::TypeParameter(pid), + ) + | TypeRef::Infer( + id @ TypeId::RigidTypeParameter(pid) + | id @ TypeId::TypeParameter(pid), + ) => { + if pid.is_mutable(db) { + TypeRef::Mut(id) } else { self } @@ -3426,7 +3590,7 @@ impl TypeRef { } } - pub fn as_type_parameter(self) -> Option { + pub fn as_type_parameter(self, db: &Database) -> Option { match self { TypeRef::Owned(TypeId::TypeParameter(id)) | TypeRef::Uni(TypeId::TypeParameter(id)) @@ -3440,6 +3604,9 @@ impl TypeRef { | TypeRef::UniRef(TypeId::RigidTypeParameter(id)) | TypeRef::UniMut(TypeId::RigidTypeParameter(id)) | TypeRef::Infer(TypeId::RigidTypeParameter(id)) => Some(id), + TypeRef::Placeholder(id) => { + id.value(db).and_then(|v| v.as_type_parameter(db)) + } _ => None, } } @@ -3493,7 +3660,7 @@ impl TypeRef { | TypeRef::UniMut(TypeId::ClassInstance(ins)) | TypeRef::UniRef(TypeId::ClassInstance(ins)) => { ins.instance_of.kind(db).is_extern() - || matches!(ins.instance_of.0, BOOLEAN_ID | NIL_ID) + || matches!(ins.instance_of.0, BOOL_ID | NIL_ID) } TypeRef::Owned(TypeId::Foreign(_)) => true, TypeRef::Owned(TypeId::Module(_)) => true, @@ -3626,6 +3793,61 @@ impl TypeRef { ))) } + pub fn shape( + self, + db: &Database, + shapes: 
&HashMap<TypeParameterId, Shape>, + ) -> Shape { + match self { + TypeRef::Owned(TypeId::ClassInstance(ins)) + | TypeRef::Uni(TypeId::ClassInstance(ins)) => { + ins.instance_of.shape(db, Shape::Owned) + } + TypeRef::Mut(TypeId::ClassInstance(ins)) + | TypeRef::UniMut(TypeId::ClassInstance(ins)) => { + ins.instance_of.shape(db, Shape::Mut) + } + TypeRef::Ref(TypeId::ClassInstance(ins)) + | TypeRef::UniRef(TypeId::ClassInstance(ins)) => { + ins.instance_of.shape(db, Shape::Ref) + } + TypeRef::Infer( + TypeId::TypeParameter(id) | TypeId::RigidTypeParameter(id), + ) + | TypeRef::Owned( + TypeId::TypeParameter(id) | TypeId::RigidTypeParameter(id), + ) => shapes.get(&id).cloned().unwrap_or(Shape::Owned), + + // For `ref T` and `mut T`, where T is a type parameter assigned a + // value type, we return shape(T), otherwise we return the shape of + // the ownership. + TypeRef::Mut( + TypeId::TypeParameter(id) | TypeId::RigidTypeParameter(id), + ) + | TypeRef::UniMut( + TypeId::TypeParameter(id) | TypeId::RigidTypeParameter(id), + ) => match shapes.get(&id).cloned() { + Some(Shape::Owned) | None => Shape::Mut, + Some(shape) => shape, + }, + TypeRef::Ref( + TypeId::TypeParameter(id) | TypeId::RigidTypeParameter(id), + ) + | TypeRef::UniRef( + TypeId::TypeParameter(id) | TypeId::RigidTypeParameter(id), + ) => match shapes.get(&id).cloned() { + Some(Shape::Owned) | None => Shape::Ref, + Some(shape) => shape, + }, + TypeRef::Mut(_) | TypeRef::UniMut(_) => Shape::Mut, + TypeRef::Ref(_) | TypeRef::UniRef(_) => Shape::Ref, + TypeRef::Placeholder(id) => { + id.value(db).map_or(Shape::Int, |v| v.shape(db, shapes)) + } + _ => Shape::Owned, + } + } + fn is_instance_of(self, db: &Database, id: ClassId) -> bool { self.class_id(db) == Some(id) } @@ -3785,7 +4007,7 @@ impl Database { Class::value_type(INT_NAME.to_string()), Class::value_type(FLOAT_NAME.to_string()), Class::atomic(STRING_NAME.to_string()), - Class::value_type(BOOLEAN_NAME.to_string()), + Class::value_type(BOOL_NAME.to_string()), 
Class::value_type(NIL_NAME.to_string()), Class::regular(BYTE_ARRAY_NAME.to_string()), Class::atomic(CHANNEL_NAME.to_string()), @@ -3828,7 +4050,7 @@ impl Database { FLOAT_NAME => Some(ClassId::float()), STRING_NAME => Some(ClassId::string()), ARRAY_NAME => Some(ClassId::array()), - BOOLEAN_NAME => Some(ClassId::boolean()), + BOOL_NAME => Some(ClassId::boolean()), NIL_NAME => Some(ClassId::nil()), BYTE_ARRAY_NAME => Some(ClassId::byte_array()), CHANNEL_NAME => Some(ClassId::channel()), @@ -3883,6 +4105,10 @@ impl Database { self.trait_in_module(DROP_MODULE, DROP_TRAIT) } + pub fn number_of_traits(&self) -> usize { + self.traits.len() + } + pub fn number_of_modules(&self) -> usize { self.modules.len() } @@ -3891,6 +4117,16 @@ impl Database { self.classes.len() } + pub fn last_class_id(&self) -> ClassId { + ClassId((self.number_of_classes() - 1) as _) + } + + pub fn classes_since(&self, id: ClassId) -> Vec { + ((id.0 + 1)..(self.number_of_classes() as u32)) + .map(|idx| ClassId(idx)) + .collect() + } + pub fn number_of_methods(&self) -> usize { self.methods.len() } @@ -3924,8 +4160,8 @@ impl Database { mod tests { use super::*; use crate::test::{ - closure, immutable, instance, mutable, new_class, new_parameter, owned, - parameter, placeholder, rigid, uni, + closure, generic_instance_id, immutable, instance, mutable, new_class, + new_parameter, owned, parameter, placeholder, rigid, uni, }; use std::mem::size_of; @@ -4428,7 +4664,7 @@ mod tests { assert_eq!(&db.classes[FLOAT_ID as usize].name, FLOAT_NAME); assert_eq!(&db.classes[STRING_ID as usize].name, STRING_NAME); assert_eq!(&db.classes[ARRAY_ID as usize].name, ARRAY_NAME); - assert_eq!(&db.classes[BOOLEAN_ID as usize].name, BOOLEAN_NAME); + assert_eq!(&db.classes[BOOL_ID as usize].name, BOOL_NAME); assert_eq!(&db.classes[NIL_ID as usize].name, NIL_NAME); assert_eq!(&db.classes[BYTE_ARRAY_ID as usize].name, BYTE_ARRAY_NAME); assert_eq!(&db.classes[CHANNEL_ID as usize].name, CHANNEL_NAME); @@ -4557,6 +4793,10 @@ mod 
tests { ); assert_eq!(owned(rigid(param1)).as_mut(&db), owned(rigid(param1))); assert_eq!(owned(rigid(param2)).as_mut(&db), mutable(rigid(param2))); + assert_eq!( + owned(parameter(param2)).as_mut(&db), + mutable(parameter(param2)) + ); } #[test] @@ -4639,4 +4879,84 @@ mod tests { assert!(bla_mod.has_same_root_namespace(&db, test_mod)); assert!(!test_mod.has_same_root_namespace(&db, bla_mod)); } + + #[test] + fn test_type_ref_shape() { + let mut db = Database::new(); + let string = ClassId::string(); + let int = ClassId::int(); + let float = ClassId::float(); + let boolean = ClassId::boolean(); + let class = new_class(&mut db, "Thing"); + let var = TypePlaceholder::alloc(&mut db, None); + let param = new_parameter(&mut db, "T"); + let mut shapes = HashMap::new(); + + shapes.insert(param, Shape::Int); + var.assign(&db, TypeRef::int()); + + assert_eq!(TypeRef::int().shape(&db, &shapes), Shape::Int); + assert_eq!(TypeRef::float().shape(&db, &shapes), Shape::Float); + assert_eq!(TypeRef::boolean().shape(&db, &shapes), Shape::Boolean); + assert_eq!(TypeRef::nil().shape(&db, &shapes), Shape::Owned); + assert_eq!(TypeRef::string().shape(&db, &shapes), Shape::String); + assert_eq!(owned(instance(class)).shape(&db, &shapes), Shape::Owned); + assert_eq!(immutable(instance(class)).shape(&db, &shapes), Shape::Ref); + assert_eq!(mutable(instance(class)).shape(&db, &shapes), Shape::Mut); + assert_eq!(uni(instance(class)).shape(&db, &shapes), Shape::Owned); + assert_eq!(placeholder(var).shape(&db, &shapes), Shape::Int); + assert_eq!(owned(parameter(param)).shape(&db, &shapes), Shape::Int); + assert_eq!(immutable(parameter(param)).shape(&db, &shapes), Shape::Int); + assert_eq!(mutable(parameter(param)).shape(&db, &shapes), Shape::Int); + assert_eq!( + immutable(instance(string)).shape(&db, &shapes), + Shape::String + ); + assert_eq!(immutable(instance(int)).shape(&db, &shapes), Shape::Int); + assert_eq!( + immutable(instance(float)).shape(&db, &shapes), + Shape::Float + ); + 
assert_eq!( + immutable(instance(boolean)).shape(&db, &shapes), + Shape::Boolean + ); + assert_eq!( + mutable(instance(string)).shape(&db, &shapes), + Shape::String + ); + assert_eq!(mutable(instance(int)).shape(&db, &shapes), Shape::Int); + assert_eq!(mutable(instance(float)).shape(&db, &shapes), Shape::Float); + assert_eq!( + mutable(instance(boolean)).shape(&db, &shapes), + Shape::Boolean + ); + assert_eq!( + owned(generic_instance_id( + &mut db, + ClassId::channel(), + vec![TypeRef::int()] + )) + .shape(&db, &shapes), + Shape::Atomic + ); + } + + #[test] + fn test_database_last_class_id() { + let mut db = Database::new(); + let class = new_class(&mut db, "A"); + + assert_eq!(db.last_class_id(), class); + } + + #[test] + fn test_database_classes_since() { + let mut db = Database::new(); + let class1 = new_class(&mut db, "A"); + let class2 = new_class(&mut db, "A"); + let class3 = new_class(&mut db, "A"); + + assert_eq!(db.classes_since(class1), vec![class2, class3]); + } } diff --git a/types/src/specialize.rs b/types/src/specialize.rs new file mode 100644 index 000000000..439a629fc --- /dev/null +++ b/types/src/specialize.rs @@ -0,0 +1,422 @@ +use crate::{ + Block, Class, ClassId, ClassInstance, Closure, ClosureId, Database, Shape, + TypeId, TypeParameterId, TypeRef, +}; +use std::collections::HashMap; + +/// A type which takes a (potentially) generic type, and specializes it and its +/// fields (if it has any). +/// +/// This type handles only type signatures, closure _literals_ are not +/// specialized; instead the compiler does this itself in its specialization +/// pass. +pub struct TypeSpecializer<'a, 'b> { + db: &'a mut Database, + + /// A cache of existing shapes to use when encountering a type parameter. + /// + /// When specializing a class, it may have fields or variants that are or + /// contain its type parameter (e.g. `Array[T]` for a `Foo[T]`). 
When + /// encountering such types, we need to reuse the shape of the type + /// parameter as it was determined when creating the newly specialized + /// class. + shapes: &'b HashMap<TypeParameterId, Shape>, +} + +impl<'a, 'b> TypeSpecializer<'a, 'b> { + pub fn new( + db: &'a mut Database, + shapes: &'b HashMap<TypeParameterId, Shape>, + ) -> TypeSpecializer<'a, 'b> { + TypeSpecializer { db, shapes } + } + + pub fn specialize(&mut self, value: TypeRef) -> TypeRef { + match value { + // When specializing type parameters, we have to reuse existing + // shapes if there are any. This leads to a bit of duplication, but + // there's not really a way around that without making things more + // complicated than they already are. + TypeRef::Owned( + TypeId::TypeParameter(id) | TypeId::RigidTypeParameter(id), + ) + | TypeRef::Infer( + TypeId::TypeParameter(id) | TypeId::RigidTypeParameter(id), + ) + | TypeRef::Uni( + TypeId::TypeParameter(id) | TypeId::RigidTypeParameter(id), + ) => match self.shapes.get(&id) { + Some(Shape::Int) => TypeRef::int(), + Some(Shape::Float) => TypeRef::float(), + Some(Shape::Boolean) => TypeRef::boolean(), + Some(Shape::String) => TypeRef::string(), + Some(Shape::Ref) => value.as_ref(self.db), + Some(Shape::Mut) => value.as_mut(self.db), + _ => value, + }, + TypeRef::Ref( + TypeId::TypeParameter(id) | TypeId::RigidTypeParameter(id), + ) + | TypeRef::UniRef( + TypeId::TypeParameter(id) | TypeId::RigidTypeParameter(id), + ) => match self.shapes.get(&id) { + Some(Shape::Int) => TypeRef::int(), + Some(Shape::Float) => TypeRef::float(), + Some(Shape::Boolean) => TypeRef::boolean(), + Some(Shape::String) => TypeRef::string(), + _ => value.as_ref(self.db), + }, + TypeRef::Mut( + TypeId::TypeParameter(id) | TypeId::RigidTypeParameter(id), + ) + | TypeRef::UniMut( + TypeId::TypeParameter(id) | TypeId::RigidTypeParameter(id), + ) => match self.shapes.get(&id) { + Some(Shape::Int) => TypeRef::int(), + Some(Shape::Float) => TypeRef::float(), + Some(Shape::Boolean) => TypeRef::boolean(), + 
Some(Shape::String) => TypeRef::string(), + Some(Shape::Ref) => value.as_ref(self.db), + _ => value.as_mut(self.db), + }, + + TypeRef::Owned(id) | TypeRef::Infer(id) => { + TypeRef::Owned(self.specialize_type_id(id)) + } + TypeRef::Uni(id) => TypeRef::Uni(self.specialize_type_id(id)), + + // Value types should always be specialized as owned types, even + // when using e.g. `ref Int`. + TypeRef::Ref(TypeId::ClassInstance(ins)) + | TypeRef::Mut(TypeId::ClassInstance(ins)) + | TypeRef::UniRef(TypeId::ClassInstance(ins)) + | TypeRef::UniMut(TypeId::ClassInstance(ins)) + if ins.instance_of().is_value_type(self.db) => + { + TypeRef::Owned( + self.specialize_type_id(TypeId::ClassInstance(ins)), + ) + } + + TypeRef::Ref(id) => TypeRef::Ref(self.specialize_type_id(id)), + TypeRef::Mut(id) => TypeRef::Mut(self.specialize_type_id(id)), + TypeRef::UniRef(id) => TypeRef::UniRef(self.specialize_type_id(id)), + TypeRef::UniMut(id) => TypeRef::UniMut(self.specialize_type_id(id)), + TypeRef::Placeholder(id) => { + id.value(self.db).map_or(value, |v| self.specialize(v)) + } + _ => value, + } + } + + fn specialize_type_id(&mut self, id: TypeId) -> TypeId { + match id { + TypeId::ClassInstance(ins) + if ins.instance_of.is_generic(self.db) => + { + TypeId::ClassInstance(self.specialize_class_instance(ins)) + } + TypeId::Closure(id) if id.contains_generic_types(self.db) => { + TypeId::Closure(self.specialize_closure(id)) + } + _ => id, + } + } + + fn specialize_class_instance( + &mut self, + ins: ClassInstance, + ) -> ClassInstance { + let class = ins.instance_of; + let mut args = ins.type_arguments(self.db).clone(); + let mut shapes = Vec::new(); + + for param in class.type_parameters(self.db) { + let arg = self.specialize(args.get(param).unwrap()); + let shape = arg.shape(self.db, self.shapes); + + shapes.push(shape); + args.assign(param, arg); + } + + let new = class + .get(self.db) + .specialized + .get(&shapes) + .cloned() + .unwrap_or_else(|| self.specialize_class(class, shapes)); 
+ + ClassInstance::generic(self.db, new, args) + } + + fn specialize_class(&mut self, id: ClassId, shapes: Vec<Shape>) -> ClassId { + let (name, kind, vis, module) = { + let cls = id.get(self.db); + + (cls.name.clone(), cls.kind, cls.visibility, cls.module) + }; + + let new = Class::alloc(self.db, name, kind, vis, module); + + // We just copy over the type parameters as-is, as there's nothing + // stored in them that we can't share between the different class + // specializations. + for param in id.type_parameters(self.db) { + let name = param.name(self.db).clone(); + + new.get_mut(self.db).type_parameters.insert(name, param); + } + + new.set_shapes(self.db, shapes.clone()); + + // When specializing fields and variants, we want them to reuse the + // shapes we just created. + let mapping = id + .type_parameters(self.db) + .into_iter() + .zip(shapes.iter()) + .fold(HashMap::new(), |mut map, (param, &shape)| { + map.insert(param, shape); + map + }); + + if id.kind(self.db).is_enum() { + for old_var in id.get(self.db).variants.values().clone() { + let name = old_var.name(self.db).clone(); + let members = old_var + .members(self.db) + .into_iter() + .map(|v| { + let typ = TypeSpecializer::new(self.db, &mapping) + .specialize(v); + typ + }) + .collect(); + + new.new_variant(self.db, name, members); + } + } + + for (idx, old_field) in id.fields(self.db).into_iter().enumerate() { + let (name, orig_typ, vis, module) = { + let field = old_field.get(self.db); + + ( + field.name.clone(), + field.value_type, + field.visibility, + field.module, + ) + }; + + let typ = + TypeSpecializer::new(self.db, &mapping).specialize(orig_typ); + + new.new_field(self.db, name, idx as _, typ, vis, module); + } + + id.get_mut(self.db).specialized.insert(shapes, new); + new + } + + fn specialize_closure(&mut self, id: ClosureId) -> ClosureId { + let moving = id.is_moving(self.db); + let new = Closure::alloc(self.db, moving); + + for arg in id.arguments(self.db) { + let typ = 
self.specialize(arg.value_type); + + new.new_anonymous_argument(self.db, typ) + } + + let ret = self.specialize(id.return_type(self.db)); + + new.set_return_type(self.db, ret); + new + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::format::format_type; + use crate::test::{ + closure, generic_instance_id, immutable, infer, instance, mutable, + new_parameter, owned, parameter, rigid, + }; + use crate::{ClassId, ModuleId, TypeParameter, Visibility}; + + #[test] + fn test_specialize_array() { + let mut db = Database::new(); + let class = ClassId::array(); + let shapes = HashMap::new(); + + class.new_type_parameter(&mut db, "T".to_string()); + + let int = TypeRef::int(); + let raw1 = owned(generic_instance_id(&mut db, class, vec![int])); + let raw2 = owned(generic_instance_id(&mut db, class, vec![int])); + let spec1 = TypeSpecializer::new(&mut db, &shapes).specialize(raw1); + let spec2 = TypeSpecializer::new(&mut db, &shapes).specialize(raw2); + + assert_eq!(format_type(&db, spec1), "Array[Int]"); + assert_eq!(format_type(&db, spec2), "Array[Int]"); + assert_eq!(class.get(&db).specialized.len(), 1); + + let new_class = + *class.get(&db).specialized.get(&vec![Shape::Int]).unwrap(); + + assert_eq!(new_class.kind(&db), class.kind(&db)); + assert_eq!(new_class.get(&db).visibility, class.get(&db).visibility); + assert_eq!(new_class.module(&db), class.module(&db)); + + // This is to test if we reuse the cached results, instead of just + // creating a new specialized class every time. 
+ assert!(matches!( + spec1, + TypeRef::Owned(TypeId::ClassInstance(ins)) if ins.instance_of == new_class + )); + assert!(matches!( + spec2, + TypeRef::Owned(TypeId::ClassInstance(ins)) if ins.instance_of == new_class + )); + } + + #[test] + fn test_specialize_array_with_ref_value_types() { + let mut db = Database::new(); + let class = ClassId::array(); + let shapes = HashMap::new(); + + class.new_type_parameter(&mut db, "T".to_string()); + + let raw = owned(generic_instance_id( + &mut db, + class, + vec![immutable(instance(ClassId::int()))], + )); + let spec = TypeSpecializer::new(&mut db, &shapes).specialize(raw); + + assert_eq!(format_type(&db, spec), "Array[Int]"); + } + + #[test] + fn test_specialize_class_with_fields() { + let mut db = Database::new(); + let tup = ClassId::tuple3(); + let param1 = tup.new_type_parameter(&mut db, "A".to_string()); + let param2 = tup.new_type_parameter(&mut db, "B".to_string()); + let param3 = tup.new_type_parameter(&mut db, "C".to_string()); + + param3.set_mutable(&mut db); + + let rigid1 = new_parameter(&mut db, "X"); + let rigid2 = new_parameter(&mut db, "Y"); + + rigid2.set_mutable(&mut db); + + tup.new_field( + &mut db, + "0".to_string(), + 0, + infer(parameter(param1)), + Visibility::Public, + ModuleId(0), + ); + + tup.new_field( + &mut db, + "1".to_string(), + 1, + infer(parameter(param2)), + Visibility::Public, + ModuleId(0), + ); + + tup.new_field( + &mut db, + "2".to_string(), + 2, + infer(parameter(param3)), + Visibility::Public, + ModuleId(0), + ); + + let mut shapes = HashMap::new(); + + shapes.insert(rigid1, Shape::Owned); + shapes.insert(rigid2, Shape::Owned); + + let raw = owned(generic_instance_id( + &mut db, + tup, + vec![ + TypeRef::int(), + immutable(rigid(rigid1)), + mutable(rigid(rigid2)), + ], + )); + + let spec = TypeSpecializer::new(&mut db, &shapes).specialize(raw); + + assert_eq!(format_type(&db, spec), "(Int, ref X, mut Y: mut)"); + + let ins = if let TypeRef::Owned(TypeId::ClassInstance(ins)) = spec 
{ + ins + } else { + panic!("Expected an owned class instance"); + }; + + assert_ne!(ins.instance_of(), tup); + assert!(ins.instance_of().kind(&db).is_tuple()); + assert_eq!( + ins.instance_of().field_by_index(&db, 0).unwrap().value_type(&db), + TypeRef::int(), + ); + + assert_eq!( + ins.instance_of().field_by_index(&db, 1).unwrap().value_type(&db), + immutable(parameter(param2)), + ); + + assert_eq!( + ins.instance_of().field_by_index(&db, 2).unwrap().value_type(&db), + mutable(parameter(param3)), + ); + } + + #[test] + fn test_specialize_enum_class() { + // TODO: write tests + } + + #[test] + fn test_specialize_closure_type() { + let mut db = Database::new(); + let orig = Closure::alloc(&mut db, true); + let param = TypeParameter::alloc(&mut db, "A".to_string()); + + ClassId::array().new_type_parameter(&mut db, "T".to_string()); + + let array = owned(generic_instance_id( + &mut db, + ClassId::array(), + vec![owned(parameter(param))], + )); + + orig.new_anonymous_argument(&mut db, array); + orig.set_return_type(&mut db, array); + + let mut shapes = HashMap::new(); + + shapes.insert(param, Shape::Int); + + let new = TypeSpecializer::new(&mut db, &shapes) + .specialize(owned(closure(orig))); + + // Testing the type name is easier than writing a giant match in an + // assert, and is accurate enough. + assert_eq!(format_type(&db, new), "fn move (Array[Int]) -> Array[Int]"); + } +}