diff --git a/pyrefly/lib/alt/class/class_metadata.rs b/pyrefly/lib/alt/class/class_metadata.rs index fa41a8b26..755cdca90 100644 --- a/pyrefly/lib/alt/class/class_metadata.rs +++ b/pyrefly/lib/alt/class/class_metadata.rs @@ -28,6 +28,7 @@ use crate::alt::types::class_metadata::DataclassMetadata; use crate::alt::types::class_metadata::EnumMetadata; use crate::alt::types::class_metadata::NamedTupleMetadata; use crate::alt::types::class_metadata::ProtocolMetadata; +use crate::alt::types::class_metadata::TotalOrderingMetadata; use crate::alt::types::class_metadata::TypedDictMetadata; use crate::binding::binding::Key; use crate::binding::binding::KeyLegacyTypeParam; @@ -49,7 +50,7 @@ use crate::types::types::Type; /// Private helper type used to share part of the logic needed for the /// binding-level work of finding legacy type parameters versus the type-level -/// work of computing inherticance information and the MRO. +/// work of computing inheritance information and the MRO. #[derive(Debug, Clone)] pub enum BaseClass { TypedDict, @@ -153,7 +154,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { cls: &Class, bases: &[Expr], keywords: &[(Name, Expr)], - decorators: &[Idx], + decorators: &[(Idx, TextRange)], is_new_type: bool, special_base: &Option>, errors: &ErrorCollector, @@ -405,8 +406,9 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } } let mut is_final = false; - for decorator in decorators { - let decorator = self.get_idx(*decorator); + let mut total_ordering_metadata = None; + for (decorator_key, decorator_range) in decorators { + let decorator = self.get_idx(*decorator_key); match decorator.ty().callee_kind() { Some(CalleeKind::Function(FunctionKind::Dataclass(kws))) => { let dataclass_fields = self.get_dataclass_fields(cls, &bases_with_metadata); @@ -431,6 +433,11 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ); } } + Some(CalleeKind::Function(FunctionKind::TotalOrdering)) => { + total_ordering_metadata = Some(TotalOrderingMetadata { + location: *decorator_range, + }); + } _ => {} } } @@ -477,6 +484,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { is_new_type, is_final, has_unknown_tparams, + total_ordering_metadata, errors, ) } diff --git a/pyrefly/lib/alt/class/mod.rs b/pyrefly/lib/alt/class/mod.rs index 9903d9cb1..5d947b164 100644 --- a/pyrefly/lib/alt/class/mod.rs +++ b/pyrefly/lib/alt/class/mod.rs @@ -13,5 +13,6 @@ pub mod enums; pub mod named_tuple; pub mod new_type; pub mod targs; +pub mod total_ordering; pub mod typed_dict; pub mod variance_inference; diff --git a/pyrefly/lib/alt/class/total_ordering.rs b/pyrefly/lib/alt/class/total_ordering.rs new file mode 100644 index 000000000..9685562b0 --- /dev/null +++ b/pyrefly/lib/alt/class/total_ordering.rs @@ -0,0 +1,85 @@ +use ruff_python_ast::name::Name; +use starlark_map::small_map::SmallMap; + +use crate::alt::answers::AnswersSolver; +use crate::alt::answers::LookupAnswer; +use crate::alt::types::class_metadata::ClassSynthesizedField; +use crate::alt::types::class_metadata::ClassSynthesizedFields; +use crate::binding::binding::KeyClassField; +use crate::dunder; +use crate::error::collector::ErrorCollector; +use crate::error::kind::ErrorKind; +use crate::types::class::Class; + +// https://github.com/python/cpython/blob/a8ec511900d0d84cffbb4ee6419c9a790d131129/Lib/functools.py#L173 +// conversion order of rich comparison methods: +const LT_CONVERSION_ORDER: &[Name; 3] = &[dunder::GT, dunder::LE, dunder::GE]; +const GT_CONVERSION_ORDER: &[Name; 3] = &[dunder::LT, dunder::GE, dunder::LE]; +const LE_CONVERSION_ORDER: &[Name; 3] = &[dunder::GE, dunder::LT, dunder::GT]; +const GE_CONVERSION_ORDER: &[Name; 3] = &[dunder::LE, dunder::GT, dunder::LT]; + +impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { + fn synthesize_rich_cmp(&self, cls: &Class, cmp: &Name) -> ClassSynthesizedField { + let conversion_order = if cmp == &dunder::LT { + LT_CONVERSION_ORDER + } else if cmp == &dunder::GT { + GT_CONVERSION_ORDER + } else if cmp == &dunder::LE { + LE_CONVERSION_ORDER + } else if cmp == &dunder::GE { + GE_CONVERSION_ORDER + } else { + unreachable!("Unexpected rich comparison method: {}", cmp); + }; + // The first field in the conversion order is the one that we will use to synthesize the method. + for other_cmp in conversion_order { + let other_cmp_field = cls.fields().find(|f| **f == *other_cmp); + if other_cmp_field.is_some() { + let other_cmp_field = + self.get_from_class(cls, &KeyClassField(cls.index(), other_cmp.clone())); + let ty = other_cmp_field.as_named_tuple_type(); + return ClassSynthesizedField::new(ty); + } + } + unreachable!("No rich comparison method found for {}", cmp); + } + + pub fn get_total_ordering_synthesized_fields( + &self, + errors: &ErrorCollector, + cls: &Class, + ) -> Option { + let metadata = self.get_metadata_for_class(cls); + if !metadata.is_total_ordering() { + return None; + } + // The class must have one of the rich comparison dunder methods defined + if !cls + .fields() + .any(|f| *f == dunder::LT || *f == dunder::LE || *f == dunder::GT || *f == dunder::GE) + { + let total_ordering_metadata = metadata.total_ordering_metadata().unwrap(); + self.error( + errors, + total_ordering_metadata.location, + ErrorKind::MissingAttribute, + None, + format!( + "Class `{}` must define at least one of the rich comparison methods.", + cls.name() + ), + ); + return None; + } + let rich_cmps_to_synthesize: Vec<_> = dunder::RICH_CMPS_TOTAL_ORDERING + .iter() + .filter(|cmp| !cls.contains(cmp)) + .collect(); + let mut fields = SmallMap::new(); + for cmp in rich_cmps_to_synthesize { + let synthesized_field = self.synthesize_rich_cmp(cls, cmp); + fields.insert(cmp.clone(), synthesized_field); + } + Some(ClassSynthesizedFields::new(fields)) + } +} diff --git a/pyrefly/lib/alt/solve.rs b/pyrefly/lib/alt/solve.rs index b53f8e40b..120ade78a 100644 --- a/pyrefly/lib/alt/solve.rs +++ b/pyrefly/lib/alt/solve.rs @@ -1326,16 +1326,30 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { pub fn solve_class_synthesized_fields( &self, + errors: &ErrorCollector, fields: &BindingClassSynthesizedFields, ) -> Arc { let fields = match &self.get_idx(fields.0).0 { None => ClassSynthesizedFields::default(), - Some(cls) => self - .get_typed_dict_synthesized_fields(cls) - .or_else(|| self.get_dataclass_synthesized_fields(cls)) - .or_else(|| self.get_named_tuple_synthesized_fields(cls)) - .or_else(|| self.get_new_type_synthesized_fields(cls)) - .unwrap_or_default(), + Some(cls) => { + let mut fields = ClassSynthesizedFields::default(); + if let Some(new_fields) = self.get_typed_dict_synthesized_fields(cls) { + fields = fields.combine(new_fields); + } + if let Some(new_fields) = self.get_dataclass_synthesized_fields(cls) { + fields = fields.combine(new_fields); + } + if let Some(new_fields) = self.get_named_tuple_synthesized_fields(cls) { + fields = fields.combine(new_fields); + } + if let Some(new_fields) = self.get_new_type_synthesized_fields(cls) { + fields = fields.combine(new_fields); + } + if let Some(new_fields) = self.get_total_ordering_synthesized_fields(errors, cls) { + fields = fields.combine(new_fields); + } + fields + } }; Arc::new(fields) } diff --git a/pyrefly/lib/alt/traits.rs b/pyrefly/lib/alt/traits.rs index 21b366b9d..0ee02bed7 100644 --- a/pyrefly/lib/alt/traits.rs +++ b/pyrefly/lib/alt/traits.rs @@ -262,9 +262,9 @@ impl Solve for KeyClassSynthesizedFields { fn solve( answers: &AnswersSolver, binding: &BindingClassSynthesizedFields, - _errors: &ErrorCollector, + errors: &ErrorCollector, ) -> Arc { - answers.solve_class_synthesized_fields(binding) + answers.solve_class_synthesized_fields(errors, binding) } fn create_recursive(_: &AnswersSolver, _: &Self::Value) -> Self::Recursive {} diff --git a/pyrefly/lib/alt/types/class_metadata.rs b/pyrefly/lib/alt/types/class_metadata.rs index d7bd0823b..13457d186 100644 --- a/pyrefly/lib/alt/types/class_metadata.rs +++ b/pyrefly/lib/alt/types/class_metadata.rs @@ -16,6 +16,7 @@ use pyrefly_derive::VisitMut; use pyrefly_util::display::commas_iter; use pyrefly_util::visit::VisitMut; use ruff_python_ast::name::Name; +use ruff_text_size::TextRange; use starlark_map::small_map::SmallMap; use starlark_map::small_set::SmallSet; use vec1::Vec1; @@ -49,6 +50,7 @@ pub struct ClassMetadata { /// Is it possible for this class to have type parameters that we don't know about? /// This can happen if, e.g., a class inherits from Any. has_unknown_tparams: bool, + total_ordering_metadata: Option, } impl VisitMut for ClassMetadata { @@ -80,6 +82,7 @@ impl ClassMetadata { is_new_type: bool, is_final: bool, has_unknown_tparams: bool, + total_ordering_metadata: Option, errors: &ErrorCollector, ) -> ClassMetadata { let mro = Mro::new(cls, &bases_with_metadata, errors); @@ -103,6 +106,7 @@ impl ClassMetadata { is_new_type, is_final, has_unknown_tparams, + total_ordering_metadata, } } @@ -166,6 +170,7 @@ impl ClassMetadata { is_new_type: false, is_final: false, has_unknown_tparams: false, + total_ordering_metadata: None, } } @@ -228,6 +233,14 @@ impl ClassMetadata { self.enum_metadata.is_some() } + pub fn is_total_ordering(&self) -> bool { + self.total_ordering_metadata.is_some() + } + + pub fn total_ordering_metadata(&self) -> Option<&TotalOrderingMetadata> { + self.total_ordering_metadata.as_ref() + } + pub fn protocol_metadata(&self) -> Option<&ProtocolMetadata> { self.protocol_metadata.as_ref() } @@ -296,6 +309,15 @@ impl ClassSynthesizedFields { pub fn get(&self, name: &Name) -> Option<&ClassSynthesizedField> { self.0.get(name) } + + /// Combines two sets of synthesized fields, with the second set + /// overwriting any fields in the first set that have the same name. + pub fn combine(mut self, other: Self) -> Self { + for (name, field) in other.0 { + self.0.insert(name, field); + } + self + } } impl Display for ClassSynthesizedFields { @@ -391,6 +413,12 @@ pub struct ProtocolMetadata { pub is_runtime_checkable: bool, } +#[derive(Clone, Debug, TypeEq, PartialEq, Eq)] +pub struct TotalOrderingMetadata { + /// Location of the decorator for `@total_ordering`. + pub location: TextRange, +} + /// A struct representing a class's ancestors, in method resolution order (MRO) /// and after dropping cycles and nonlinearizable inheritance. /// diff --git a/pyrefly/lib/binding/binding.rs b/pyrefly/lib/binding/binding.rs index 40ce8ac6e..1c56db4e9 100644 --- a/pyrefly/lib/binding/binding.rs +++ b/pyrefly/lib/binding/binding.rs @@ -1539,7 +1539,7 @@ pub struct BindingClassMetadata { pub class_idx: Idx, pub bases: Box<[Expr]>, pub keywords: Box<[(Name, Expr)]>, - pub decorators: Box<[Idx]>, + pub decorators: Box<[(Idx, TextRange)]>, pub is_new_type: bool, pub special_base: Option>, } diff --git a/pyrefly/lib/binding/class.rs b/pyrefly/lib/binding/class.rs index 6177561c5..291db6dbb 100644 --- a/pyrefly/lib/binding/class.rs +++ b/pyrefly/lib/binding/class.rs @@ -114,8 +114,10 @@ impl<'a> BindingsBuilder<'a> { let mut key_class_fields: SmallSet> = SmallSet::new(); let body = mem::take(&mut x.body); - let decorators = - self.ensure_and_bind_decorators(mem::take(&mut x.decorator_list), class_object.usage()); + let decorators_with_ranges = self.ensure_and_bind_decorators_with_ranges( + mem::take(&mut x.decorator_list), + class_object.usage(), + ); self.scopes.push(Scope::annotation(x.range)); @@ -180,7 +182,7 @@ impl<'a> BindingsBuilder<'a> { class_idx: class_indices.class_idx, bases: bases.clone().into_boxed_slice(), keywords: keywords.into_boxed_slice(), - decorators: decorators.clone().into_boxed_slice(), + decorators: decorators_with_ranges.clone().into_boxed_slice(), is_new_type: false, special_base: None, }, @@ -290,11 +292,14 @@ impl<'a> BindingsBuilder<'a> { } let legacy_tparams = legacy_tparam_builder.lookup_keys(); + let decorator_keys = decorators_with_ranges + .map(|(idx, _)| *idx) + .into_boxed_slice(); self.bind_definition_user( &x.name, class_object, - Binding::ClassDef(class_indices.class_idx, decorators.into_boxed_slice()), + Binding::ClassDef(class_indices.class_idx, decorator_keys), FlowStyle::Other, ); fields_possibly_defined_by_this_class.reserve(0); // Attempt to shrink to capacity diff --git a/pyrefly/lib/binding/expr.rs b/pyrefly/lib/binding/expr.rs index 9ac8795b3..c536efe4d 100644 --- a/pyrefly/lib/binding/expr.rs +++ b/pyrefly/lib/binding/expr.rs @@ -737,4 +737,19 @@ impl<'a> BindingsBuilder<'a> { } decorator_keys } + + pub fn ensure_and_bind_decorators_with_ranges( + &mut self, + decorators: Vec, + usage: &mut Usage, + ) -> Vec<(Idx, TextRange)> { + let mut decorator_keys_with_ranges = Vec::with_capacity(decorators.len()); + for mut x in decorators { + self.ensure_expr(&mut x.expression, usage); + let range = x.range(); + let k = self.insert_binding(Key::Anon(x.range), Binding::Decorator(x.expression)); + decorator_keys_with_ranges.push((k, range)); + } + decorator_keys_with_ranges + } } diff --git a/pyrefly/lib/dunder.rs b/pyrefly/lib/dunder.rs index 80632c0c7..2d137acb3 100644 --- a/pyrefly/lib/dunder.rs +++ b/pyrefly/lib/dunder.rs @@ -48,6 +48,8 @@ pub const SETITEM: Name = Name::new_static("__setitem__"); pub const BOOL: Name = Name::new_static("__bool__"); pub const RICH_CMPS: &[Name] = &[LT, LE, EQ, NE, GT, GE]; +/// Rich comparison methods supplied by the `functools.total_ordering` decorator +pub const RICH_CMPS_TOTAL_ORDERING: &[Name] = &[LT, LE, GT, GE]; /// Returns the associated dunder if `op` corresponds to a "rich comparison method": /// https://docs.python.org/3/reference/datamodel.html#object.__lt__. diff --git a/pyrefly/lib/module/module_name.rs b/pyrefly/lib/module/module_name.rs index af7a5fded..9c655f909 100644 --- a/pyrefly/lib/module/module_name.rs +++ b/pyrefly/lib/module/module_name.rs @@ -134,6 +134,10 @@ impl ModuleName { Self::from_str("dataclasses") } + pub fn functools() -> Self { + Self::from_str("functools") + } + pub fn type_checker_internals() -> Self { Self::from_str("_typeshed._type_checker_internals") } diff --git a/pyrefly/lib/test/decorators.rs b/pyrefly/lib/test/decorators.rs index 6b597aac9..a0b83da48 100644 --- a/pyrefly/lib/test/decorators.rs +++ b/pyrefly/lib/test/decorators.rs @@ -323,3 +323,98 @@ def f0(arg: Callable[..., int]) -> Callable[..., int]: ... def f0(arg: Callable[..., int]) -> Callable[..., int]: ... "#, ); + +// Reported in https://github.com/facebook/pyrefly/issues/491 +testcase!( + test_total_ordering, + r#" +from functools import total_ordering +from typing import reveal_type + +@total_ordering +class A: + def __init__(self, x: int) -> None: + self.x = x + def __eq__(self, other: "A") -> bool: + return self.x == other.x + def __lt__(self, other: "A") -> bool: + return self.x < other.x + +a = A(x=1) +b = A(x=2) + +# This should give the correct type for the method `__lt__` +reveal_type(A.__lt__) # E: revealed type: (self: Self@A, other: A) -> bool +# This should give be synthesized via `functools.total_ordering` +reveal_type(A.__gt__) # E: revealed type: (self: Self@A, other: A) -> bool +a <= b +"#, +); + +testcase!( + test_total_ordering_no_rich_cmp, + r#" +from functools import total_ordering + +@total_ordering # E: Class `A` must define at least one of the rich comparison methods. +class A: + def __init__(self, x: int) -> None: + self.x = x +"#, +); + +testcase!( + test_total_ordering_dataclass, + r#" +from dataclasses import dataclass +from functools import total_ordering +from typing import reveal_type + +@dataclass +@total_ordering +class A: + x: int + def __lt__(self, other: "A") -> bool: + return self.x < other.x + +a = A(x=1) +b = A(x=2) + +# This should give the correct type for the method `__lt__` +reveal_type(A.__lt__) # E: revealed type: (self: Self@A, other: A) -> bool +# This should give be synthesized via `functools.total_ordering` +reveal_type(A.__gt__) # E: revealed type: (self: Self@A, other: A) -> bool +a <= b +"#, +); + +testcase!( + test_total_ordering_precedence, + r#" +from functools import total_ordering +from typing import reveal_type + +@total_ordering +class A: + def __init__(self, x: int) -> None: + self.x = x + def __eq__(self, other: "A") -> bool: + return self.x == other.x + def __lt__(self, other: "A") -> bool: + return self.x < other.x + def __le__(self, other: object) -> bool: + if not isinstance(other, A): + return NotImplemented + return self.x <= other.x + +# This should give the correct type for the method `__lt__` +reveal_type(A.__lt__) # E: revealed type: (self: Self@A, other: A) -> bool +# This should give be synthesized via `functools.total_ordering` via `__lt__` +reveal_type(A.__gt__) # E: revealed type: (self: Self@A, other: A) -> bool + +# This should give the correct type for the method `__le__` +reveal_type(A.__le__) # E: revealed type: (self: Self@A, other: object) -> bool +# This should give be synthesized via `functools.total_ordering` via `__le__` +reveal_type(A.__ge__) # E: revealed type: (self: Self@A, other: object) -> bool +"#, +); diff --git a/pyrefly/lib/types/callable.rs b/pyrefly/lib/types/callable.rs index d67a990d6..fe97d41ee 100644 --- a/pyrefly/lib/types/callable.rs +++ b/pyrefly/lib/types/callable.rs @@ -253,6 +253,7 @@ pub enum FunctionKind { AbstractMethod, /// Instance of a protocol with a `__call__` method. The function has the `__call__` signature. CallbackProtocol(Box), + TotalOrdering, } /// A map from keywords to boolean values. Useful for storing sets of keyword arguments for various @@ -530,6 +531,7 @@ impl FunctionKind { ("typing", None, "runtime_checkable") => Self::RuntimeCheckable, ("typing_extensions", None, "runtime_checkable") => Self::RuntimeCheckable, ("abc", None, "abstractmethod") => Self::AbstractMethod, + ("functools", None, "total_ordering") => Self::TotalOrdering, _ => Self::Def(Box::new(FuncId { module, cls: cls.cloned(), @@ -611,6 +613,11 @@ impl FunctionKind { func: Name::new_static("abstractmethod"), }, Self::PropertySetter(func_id) | Self::Def(func_id) => (**func_id).clone(), + Self::TotalOrdering => FuncId { + module: ModuleName::functools(), + cls: None, + func: Name::new_static("total_ordering"), + }, } } }