diff --git a/pyrefly/lib/alt/function.rs b/pyrefly/lib/alt/function.rs index 1c1ed8451..456d21a7d 100644 --- a/pyrefly/lib/alt/function.rs +++ b/pyrefly/lib/alt/function.rs @@ -65,27 +65,86 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { || class_metadata.is_some_and(|idx| self.get_idx(*idx).is_protocol()); let def = self.get_idx(idx); if def.metadata.flags.is_overload { + if !skip_implementation && def.stub_or_impl == FunctionStubOrImpl::Impl { + self.error( + errors, + def.id_range, + ErrorKind::InvalidOverload, + None, + "@overload decorator should not be used on function implementations.".to_owned(), + ); + } + // This function is decorated with @overload. We should warn if this function is actually called anywhere. let successor = self.bindings().get(idx).successor; let ty = def.ty.clone(); if successor.is_none() { // This is the last definition in the chain. We should produce an overload type. let mut acc = Vec1::new((def.id_range, ty)); - let mut first = def; - while let Some(def) = self.step_overload_pred(predecessor) { - acc.push((def.id_range, def.ty.clone())); - first = def; + let mut has_overload_after = false; + let mut has_implementation_before_overload = false; + let mut has_any_implementation = false; + let mut temp_pred = *predecessor; + while let Some(current_pred_idx) = temp_pred { + let mut current_binding = self.bindings().get(current_pred_idx); + while let Binding::Forward(forward_key) = current_binding { + current_binding = self.bindings().get(*forward_key); + } + if let Binding::Function(func_idx, next_predecessor, _) = current_binding { + let func_def = self.get_idx(*func_idx); + + if func_def.metadata.flags.is_overload { + has_overload_after = true; + } + if func_def.stub_or_impl == FunctionStubOrImpl::Impl { + has_any_implementation = true; + if !func_def.metadata.flags.is_overload { + has_implementation_before_overload = true; + } + } + if has_overload_after && has_any_implementation && has_implementation_before_overload { + break; + } + temp_pred = *next_predecessor; + } else { + break; + } + } + + let mut first = def.clone(); + while let Some(predecessor_def) = self.step_overload_pred(predecessor) { + acc.push((predecessor_def.id_range, predecessor_def.ty.clone())); + first = predecessor_def; } if !skip_implementation { - self.error( - errors, - first.id_range, - ErrorKind::InvalidOverload, - None, - "Overloaded function must have an implementation".to_owned(), - ); + if !has_implementation_before_overload && def.stub_or_impl == FunctionStubOrImpl::Impl { + self.error( + errors, + def.id_range, + ErrorKind::InvalidOverload, + None, + "@overload decorator should not be used on function implementations.".to_owned(), + ); + } else if has_any_implementation { + self.error( + errors, + def.id_range, + ErrorKind::InvalidOverload, + None, + "@overload declarations must come before function implementation. ".to_owned(), + ); + } + else { + self.error( + errors, + first.id_range, + ErrorKind::InvalidOverload, + None, + "Overloaded function must have an implementation".to_owned(), + ); + } } - if acc.len() == 1 { + if acc.len() == 1 && !has_overload_after && !has_any_implementation { self.error( errors, first.id_range, @@ -463,6 +522,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { id_range: def.name.range, ty, metadata, + stub_or_impl, }) } diff --git a/pyrefly/lib/alt/types/decorated_function.rs b/pyrefly/lib/alt/types/decorated_function.rs index bb376e701..a46a9f3a6 100644 --- a/pyrefly/lib/alt/types/decorated_function.rs +++ b/pyrefly/lib/alt/types/decorated_function.rs @@ -21,6 +21,7 @@ use crate::types::callable::FuncId; use crate::types::callable::FuncMetadata; use crate::types::callable::FunctionKind; use crate::types::types::Type; +use crate::binding::binding::FunctionStubOrImpl; /// The type of a function definition after decorators are applied. Metadata arising from the /// decorators can be stored here. Note that the type might not be a function at all, since @@ -30,6 +31,7 @@ pub struct DecoratedFunction { pub id_range: TextRange, pub ty: Type, pub metadata: FuncMetadata, + pub stub_or_impl: FunctionStubOrImpl, } impl Display for DecoratedFunction { @@ -51,6 +53,7 @@ impl DecoratedFunction { })), flags: FuncFlags::default(), }, + stub_or_impl: FunctionStubOrImpl::Stub, } } } diff --git a/pyrefly/lib/binding/binding.rs b/pyrefly/lib/binding/binding.rs index 40ce8ac6e..7285765bc 100644 --- a/pyrefly/lib/binding/binding.rs +++ b/pyrefly/lib/binding/binding.rs @@ -754,7 +754,7 @@ impl IsAsync { } /// Is the body of this function stubbed out (contains nothing but `...`)? -#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, TypeEq, VisitMut)] pub enum FunctionStubOrImpl { /// The function body is `...`. Stub, diff --git a/pyrefly/lib/test/overload.rs b/pyrefly/lib/test/overload.rs index 79653a4f5..3f10cb790 100644 --- a/pyrefly/lib/test/overload.rs +++ b/pyrefly/lib/test/overload.rs @@ -405,3 +405,37 @@ def exponential() -> Any: f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(X())))))))))))))))))))))))) # E: # E: # E: # E: # E: # E: # E: # E: # E: # E: # E: # E: "#, ); + +testcase!( + test_implementation_with_overload, + r#" +from typing import overload + +@overload +def f(x: int) -> int: ... + +@overload +def f(x: int | str) -> int | str: # E: @overload decorator should not be used on function implementations. + return x + +@overload +def f(x: str) -> str: ... # E: @overload declarations must come before function implementation. + "#, +); + + +testcase!( + test_implementation_before_overload, + r#" +from typing import overload + +def f(x: int | str) -> int | str: + return x + +@overload +def f(x: int) -> int: ... + +@overload +def f(x: str) -> str: ... # E: @overload declarations must come before function implementation. + "#, +);