Skip to content

Implement a basic model for functools.total_ordering #537

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions pyrefly/lib/alt/class/class_metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use crate::alt::types::class_metadata::DataclassMetadata;
use crate::alt::types::class_metadata::EnumMetadata;
use crate::alt::types::class_metadata::NamedTupleMetadata;
use crate::alt::types::class_metadata::ProtocolMetadata;
use crate::alt::types::class_metadata::TotalOrderingMetadata;
use crate::alt::types::class_metadata::TypedDictMetadata;
use crate::binding::binding::Key;
use crate::binding::binding::KeyLegacyTypeParam;
Expand All @@ -49,7 +50,7 @@ use crate::types::types::Type;

/// Private helper type used to share part of the logic needed for the
/// binding-level work of finding legacy type parameters versus the type-level
/// work of computing inherticance information and the MRO.
/// work of computing inheritance information and the MRO.
#[derive(Debug, Clone)]
pub enum BaseClass {
TypedDict,
Expand Down Expand Up @@ -153,7 +154,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
cls: &Class,
bases: &[Expr],
keywords: &[(Name, Expr)],
decorators: &[Idx<Key>],
decorators: &[(Idx<Key>, TextRange)],
is_new_type: bool,
special_base: &Option<Box<BaseClass>>,
errors: &ErrorCollector,
Expand Down Expand Up @@ -405,8 +406,9 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
}
}
let mut is_final = false;
for decorator in decorators {
let decorator = self.get_idx(*decorator);
let mut total_ordering_metadata = None;
for (decorator_key, decorator_range) in decorators {
let decorator = self.get_idx(*decorator_key);
match decorator.ty().callee_kind() {
Some(CalleeKind::Function(FunctionKind::Dataclass(kws))) => {
let dataclass_fields = self.get_dataclass_fields(cls, &bases_with_metadata);
Expand All @@ -431,6 +433,11 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
);
}
}
Some(CalleeKind::Function(FunctionKind::TotalOrdering)) => {
total_ordering_metadata = Some(TotalOrderingMetadata {
location: *decorator_range,
});
}
_ => {}
}
}
Expand Down Expand Up @@ -477,6 +484,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
is_new_type,
is_final,
has_unknown_tparams,
total_ordering_metadata,
errors,
)
}
Expand Down
1 change: 1 addition & 0 deletions pyrefly/lib/alt/class/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ pub mod enums;
pub mod named_tuple;
pub mod new_type;
pub mod targs;
pub mod total_ordering;
pub mod typed_dict;
pub mod variance_inference;
85 changes: 85 additions & 0 deletions pyrefly/lib/alt/class/total_ordering.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
use ruff_python_ast::name::Name;
use starlark_map::small_map::SmallMap;

use crate::alt::answers::AnswersSolver;
use crate::alt::answers::LookupAnswer;
use crate::alt::types::class_metadata::ClassSynthesizedField;
use crate::alt::types::class_metadata::ClassSynthesizedFields;
use crate::binding::binding::KeyClassField;
use crate::dunder;
use crate::error::collector::ErrorCollector;
use crate::error::kind::ErrorKind;
use crate::types::class::Class;

// https://github.com/python/cpython/blob/a8ec511900d0d84cffbb4ee6419c9a790d131129/Lib/functools.py#L173
// conversion order of rich comparison methods:
const LT_CONVERSION_ORDER: &[Name; 3] = &[dunder::GT, dunder::LE, dunder::GE];
const GT_CONVERSION_ORDER: &[Name; 3] = &[dunder::LT, dunder::GE, dunder::LE];
const LE_CONVERSION_ORDER: &[Name; 3] = &[dunder::GE, dunder::LT, dunder::GT];
const GE_CONVERSION_ORDER: &[Name; 3] = &[dunder::LE, dunder::GT, dunder::LT];

impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
fn synthesize_rich_cmp(&self, cls: &Class, cmp: &Name) -> ClassSynthesizedField {
let conversion_order = if cmp == &dunder::LT {
LT_CONVERSION_ORDER
} else if cmp == &dunder::GT {
GT_CONVERSION_ORDER
} else if cmp == &dunder::LE {
LE_CONVERSION_ORDER
} else if cmp == &dunder::GE {
GE_CONVERSION_ORDER
} else {
unreachable!("Unexpected rich comparison method: {}", cmp);
};
// The first field in the conversion order is the one that we will use to synthesize the method.
for other_cmp in conversion_order {
let other_cmp_field = cls.fields().find(|f| **f == *other_cmp);
if other_cmp_field.is_some() {
let other_cmp_field =
self.get_from_class(cls, &KeyClassField(cls.index(), other_cmp.clone()));
let ty = other_cmp_field.as_named_tuple_type();
return ClassSynthesizedField::new(ty);
}
}
unreachable!("No rich comparison method found for {}", cmp);
}

pub fn get_total_ordering_synthesized_fields(
&self,
errors: &ErrorCollector,
cls: &Class,
) -> Option<ClassSynthesizedFields> {
let metadata = self.get_metadata_for_class(cls);
if !metadata.is_total_ordering() {
return None;
}
// The class must have one of the rich comparison dunder methods defined
if !cls
.fields()
.any(|f| *f == dunder::LT || *f == dunder::LE || *f == dunder::GT || *f == dunder::GE)
{
let total_ordering_metadata = metadata.total_ordering_metadata().unwrap();
self.error(
errors,
total_ordering_metadata.location,
ErrorKind::MissingAttribute,
None,
format!(
"Class `{}` must define at least one of the rich comparison methods.",
cls.name()
),
);
return None;
}
let rich_cmps_to_synthesize: Vec<_> = dunder::RICH_CMPS_TOTAL_ORDERING
.iter()
.filter(|cmp| !cls.contains(cmp))
.collect();
let mut fields = SmallMap::new();
for cmp in rich_cmps_to_synthesize {
let synthesized_field = self.synthesize_rich_cmp(cls, cmp);
fields.insert(cmp.clone(), synthesized_field);
}
Some(ClassSynthesizedFields::new(fields))
}
}
26 changes: 20 additions & 6 deletions pyrefly/lib/alt/solve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1326,16 +1326,30 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {

pub fn solve_class_synthesized_fields(
&self,
errors: &ErrorCollector,
fields: &BindingClassSynthesizedFields,
) -> Arc<ClassSynthesizedFields> {
let fields = match &self.get_idx(fields.0).0 {
None => ClassSynthesizedFields::default(),
Some(cls) => self
.get_typed_dict_synthesized_fields(cls)
.or_else(|| self.get_dataclass_synthesized_fields(cls))
.or_else(|| self.get_named_tuple_synthesized_fields(cls))
.or_else(|| self.get_new_type_synthesized_fields(cls))
.unwrap_or_default(),
Some(cls) => {
let mut fields = ClassSynthesizedFields::default();
if let Some(new_fields) = self.get_typed_dict_synthesized_fields(cls) {
fields = fields.combine(new_fields);
}
if let Some(new_fields) = self.get_dataclass_synthesized_fields(cls) {
fields = fields.combine(new_fields);
}
if let Some(new_fields) = self.get_named_tuple_synthesized_fields(cls) {
fields = fields.combine(new_fields);
}
if let Some(new_fields) = self.get_new_type_synthesized_fields(cls) {
fields = fields.combine(new_fields);
}
if let Some(new_fields) = self.get_total_ordering_synthesized_fields(errors, cls) {
fields = fields.combine(new_fields);
}
fields
}
};
Arc::new(fields)
}
Expand Down
4 changes: 2 additions & 2 deletions pyrefly/lib/alt/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,9 @@ impl<Ans: LookupAnswer> Solve<Ans> for KeyClassSynthesizedFields {
fn solve(
answers: &AnswersSolver<Ans>,
binding: &BindingClassSynthesizedFields,
_errors: &ErrorCollector,
errors: &ErrorCollector,
) -> Arc<ClassSynthesizedFields> {
answers.solve_class_synthesized_fields(binding)
answers.solve_class_synthesized_fields(errors, binding)
}

fn create_recursive(_: &AnswersSolver<Ans>, _: &Self::Value) -> Self::Recursive {}
Expand Down
28 changes: 28 additions & 0 deletions pyrefly/lib/alt/types/class_metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use pyrefly_derive::VisitMut;
use pyrefly_util::display::commas_iter;
use pyrefly_util::visit::VisitMut;
use ruff_python_ast::name::Name;
use ruff_text_size::TextRange;
use starlark_map::small_map::SmallMap;
use starlark_map::small_set::SmallSet;
use vec1::Vec1;
Expand Down Expand Up @@ -49,6 +50,7 @@ pub struct ClassMetadata {
/// Is it possible for this class to have type parameters that we don't know about?
/// This can happen if, e.g., a class inherits from Any.
has_unknown_tparams: bool,
total_ordering_metadata: Option<TotalOrderingMetadata>,
}

impl VisitMut<Type> for ClassMetadata {
Expand Down Expand Up @@ -80,6 +82,7 @@ impl ClassMetadata {
is_new_type: bool,
is_final: bool,
has_unknown_tparams: bool,
total_ordering_metadata: Option<TotalOrderingMetadata>,
errors: &ErrorCollector,
) -> ClassMetadata {
let mro = Mro::new(cls, &bases_with_metadata, errors);
Expand All @@ -103,6 +106,7 @@ impl ClassMetadata {
is_new_type,
is_final,
has_unknown_tparams,
total_ordering_metadata,
}
}

Expand Down Expand Up @@ -166,6 +170,7 @@ impl ClassMetadata {
is_new_type: false,
is_final: false,
has_unknown_tparams: false,
total_ordering_metadata: None,
}
}

Expand Down Expand Up @@ -228,6 +233,14 @@ impl ClassMetadata {
self.enum_metadata.is_some()
}

pub fn is_total_ordering(&self) -> bool {
self.total_ordering_metadata.is_some()
}

pub fn total_ordering_metadata(&self) -> Option<&TotalOrderingMetadata> {
self.total_ordering_metadata.as_ref()
}

pub fn protocol_metadata(&self) -> Option<&ProtocolMetadata> {
self.protocol_metadata.as_ref()
}
Expand Down Expand Up @@ -296,6 +309,15 @@ impl ClassSynthesizedFields {
pub fn get(&self, name: &Name) -> Option<&ClassSynthesizedField> {
self.0.get(name)
}

/// Combines two sets of synthesized fields, with the second set
/// overwriting any fields in the first set that have the same name.
pub fn combine(mut self, other: Self) -> Self {
for (name, field) in other.0 {
self.0.insert(name, field);
}
self
}
}

impl Display for ClassSynthesizedFields {
Expand Down Expand Up @@ -391,6 +413,12 @@ pub struct ProtocolMetadata {
pub is_runtime_checkable: bool,
}

#[derive(Clone, Debug, TypeEq, PartialEq, Eq)]
pub struct TotalOrderingMetadata {
/// Location of the decorator for `@total_ordering`.
pub location: TextRange,
}

/// A struct representing a class's ancestors, in method resolution order (MRO)
/// and after dropping cycles and nonlinearizable inheritance.
///
Expand Down
2 changes: 1 addition & 1 deletion pyrefly/lib/binding/binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1539,7 +1539,7 @@ pub struct BindingClassMetadata {
pub class_idx: Idx<KeyClass>,
pub bases: Box<[Expr]>,
pub keywords: Box<[(Name, Expr)]>,
pub decorators: Box<[Idx<Key>]>,
pub decorators: Box<[(Idx<Key>, TextRange)]>,
pub is_new_type: bool,
pub special_base: Option<Box<BaseClass>>,
}
Expand Down
13 changes: 9 additions & 4 deletions pyrefly/lib/binding/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,10 @@ impl<'a> BindingsBuilder<'a> {
let mut key_class_fields: SmallSet<Idx<KeyClassField>> = SmallSet::new();

let body = mem::take(&mut x.body);
let decorators =
self.ensure_and_bind_decorators(mem::take(&mut x.decorator_list), class_object.usage());
let decorators_with_ranges = self.ensure_and_bind_decorators_with_ranges(
mem::take(&mut x.decorator_list),
class_object.usage(),
);

self.scopes.push(Scope::annotation(x.range));

Expand Down Expand Up @@ -180,7 +182,7 @@ impl<'a> BindingsBuilder<'a> {
class_idx: class_indices.class_idx,
bases: bases.clone().into_boxed_slice(),
keywords: keywords.into_boxed_slice(),
decorators: decorators.clone().into_boxed_slice(),
decorators: decorators_with_ranges.clone().into_boxed_slice(),
is_new_type: false,
special_base: None,
},
Expand Down Expand Up @@ -290,11 +292,14 @@ impl<'a> BindingsBuilder<'a> {
}

let legacy_tparams = legacy_tparam_builder.lookup_keys();
let decorator_keys = decorators_with_ranges
.map(|(idx, _)| *idx)
.into_boxed_slice();

self.bind_definition_user(
&x.name,
class_object,
Binding::ClassDef(class_indices.class_idx, decorators.into_boxed_slice()),
Binding::ClassDef(class_indices.class_idx, decorator_keys),
FlowStyle::Other,
);
fields_possibly_defined_by_this_class.reserve(0); // Attempt to shrink to capacity
Expand Down
15 changes: 15 additions & 0 deletions pyrefly/lib/binding/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -737,4 +737,19 @@ impl<'a> BindingsBuilder<'a> {
}
decorator_keys
}

pub fn ensure_and_bind_decorators_with_ranges(
&mut self,
decorators: Vec<Decorator>,
usage: &mut Usage,
) -> Vec<(Idx<Key>, TextRange)> {
let mut decorator_keys_with_ranges = Vec::with_capacity(decorators.len());
for mut x in decorators {
self.ensure_expr(&mut x.expression, usage);
let range = x.range();
let k = self.insert_binding(Key::Anon(x.range), Binding::Decorator(x.expression));
decorator_keys_with_ranges.push((k, range));
}
decorator_keys_with_ranges
}
}
2 changes: 2 additions & 0 deletions pyrefly/lib/dunder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ pub const SETITEM: Name = Name::new_static("__setitem__");
pub const BOOL: Name = Name::new_static("__bool__");

pub const RICH_CMPS: &[Name] = &[LT, LE, EQ, NE, GT, GE];
/// Rich comparison methods supplied by the `functools.total_ordering` decorator
pub const RICH_CMPS_TOTAL_ORDERING: &[Name] = &[LT, LE, GT, GE];

/// Returns the associated dunder if `op` corresponds to a "rich comparison method":
/// https://docs.python.org/3/reference/datamodel.html#object.__lt__.
Expand Down
4 changes: 4 additions & 0 deletions pyrefly/lib/module/module_name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ impl ModuleName {
Self::from_str("dataclasses")
}

pub fn functools() -> Self {
Self::from_str("functools")
}

pub fn type_checker_internals() -> Self {
Self::from_str("_typeshed._type_checker_internals")
}
Expand Down
Loading