Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[red-knot] Support assert_type #15194

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# `assert_type`
InSyncWithFoo marked this conversation as resolved.
Show resolved Hide resolved

## Basic

```py
from typing_extensions import assert_type

def _(x: int):
assert_type(x, int) # fine
assert_type(x, str) # error: [type-assertion-failure]
```

## Narrowing

The asserted type is checked against the inferred type, not the declared type.

```toml
[environment]
python-version = "3.10"
```

```py
from typing_extensions import assert_type

def _(x: int | str):
if isinstance(x, int):
reveal_type(x) # revealed: int
assert_type(x, int) # fine
```

## Equivalence

The actual type must match the asserted type precisely.
InSyncWithFoo marked this conversation as resolved.
Show resolved Hide resolved

```py
from typing import Any, Type, Union
from typing_extensions import assert_type

# Subtype does not count
def _(x: bool):
assert_type(x, int) # error: [type-assertion-failure]

def _(a: type[int], b: type[Any]):
assert_type(a, type[Any]) # error: [type-assertion-failure]
assert_type(b, type[int]) # error: [type-assertion-failure]

# The expression constructing the type is not taken into account
def _(a: type[int]):
# TODO: Infer the second argument as a type expression
assert_type(a, Type[int]) # error: [type-assertion-failure]
```

## Gradual types

```py
from typing import Any
from typing_extensions import Literal, assert_type

from knot_extensions import Unknown

# Any and Unknown are considered equivalent
def _(a: Unknown, b: Any):
reveal_type(a) # revealed: Unknown
assert_type(a, Any) # fine

reveal_type(b) # revealed: Any
assert_type(b, Unknown) # fine

def _(a: type[Unknown], b: type[Any]):
# TODO: Should be `type[Unknown]`
reveal_type(a) # revealed: @Todo(unsupported type[X] special form)
reveal_type(b) # revealed: type[Any]

# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(a, type[Unknown]) # error: [type-assertion-failure]
# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(b, type[Any]) # error: [type-assertion-failure]
```

## Tuples

Tuple types with the same elements are the same.

```py
from typing_extensions import assert_type

def _(a: tuple[int, str, bytes]):
# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(a, tuple[int, str, bytes]) # error: [type-assertion-failure]

assert_type(a, tuple[int, str]) # error: [type-assertion-failure]
assert_type(a, tuple[int, str, bytes, None]) # error: [type-assertion-failure]
assert_type(a, tuple[int, bytes, str]) # error: [type-assertion-failure]

def _(a: tuple[Any, ...]):
# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(a, tuple[Any, ...]) # error: [type-assertion-failure]
```

## Unions

Unions with the same elements are the same, regardless of order.

```toml
[environment]
python-version = "3.10"
```

```py
from typing_extensions import assert_type

def _(a: str | int):
# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(a, int | str) # error: [type-assertion-failure]
```

## Intersections

Intersections are the same when their positive and negative parts are respectively the same,
regardless of order.

```py
from typing_extensions import assert_type

from knot_extensions import Intersection, Not

class A: ...
class B: ...
class C: ...
class D: ...

def _(a: A):
if isinstance(a, B) and not isinstance(a, C) and not isinstance(a, D):
reveal_type(a) # revealed: A & B & ~C & ~D

# TODO: Infer the second argument as a type expression
# Should be fine
assert_type(a, Intersection[B, A, Not[D], Not[C]]) # error: [type-assertion-failure]
```
190 changes: 177 additions & 13 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::hash::Hash;
use rustc_hash::FxHasher;
use std::collections::HashSet;
use std::hash::{BuildHasherDefault, Hash};
use std::iter;

use context::InferContext;
use diagnostic::{report_not_iterable, report_not_iterable_possibly_unbound};
Expand Down Expand Up @@ -555,6 +558,8 @@ pub enum Type<'db> {
// TODO protocols, callable types, overloads, generics, type vars
}

type OrderedTypeSet<'a, 'db> = HashSet<&'a Type<'db>, BuildHasherDefault<FxHasher>>;

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What makes this set ordered? This is just a regular hash set, isn't it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True. I just thought its name should reflect the fact that the order of iteration is depended upon.

Copy link
Member

@MichaReiser MichaReiser Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is that the order isn't guaranteed. The element ordering may depend on insertion order. The IntersectionType positive and negative sets are ordered (although I don't remember the details of how they're ordered)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IntersectionType uses OrderedSet, whose iteration order is the same as insertion order. The comparison in question needs a self-ordered set type.

I have so far been unsuccessful in finding such a type. Any suggestions?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could collect the types into a vec and then order them. I'm just not sure why the order is important. I suspect that we instead have to implement an O(n^2) algorithm that, for every type in a, searches a type in b to which it is equivalent, disregarding ordering entirely. But I'm not sure if doing so is too naive.

Copy link
Contributor Author

@InSyncWithFoo InSyncWithFoo Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, this was done to avoid $$O(n^2)$$ time at the cost of a few $$O(n)$$ iterations and allocations. However, even for a $$O(n^2)$$ algorithm, I suspect $$O(n)$$ space is still needed to store paired indices.

impl<'db> Type<'db> {
pub const fn is_unknown(&self) -> bool {
matches!(self, Type::Unknown)
Expand Down Expand Up @@ -1431,6 +1436,88 @@ impl<'db> Type<'db> {
}
}

/// Returns true if this type and `other` are gradual equivalent.
///
/// > Two gradual types `A` and `B` are equivalent
/// > (that is, the same gradual type, not merely consistent with one another)
/// > if and only if all materializations of `A` are also materializations of `B`,
/// > and all materializations of `B` are also materializations of `A`.
/// >
/// > &mdash; [Summary of type relations]
///
/// Note: `Todo != Todo`.
///
/// This powers the `assert_type()` directive.
///
/// [Summary of type relations]: https://typing.readthedocs.io/en/latest/spec/concepts.html#summary-of-type-relations
pub(crate) fn is_gradual_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool {
let equivalent =
|(first, second): (&Type<'db>, &Type<'db>)| first.is_gradual_equivalent_to(db, *second);

match (self, other) {
(Type::Todo(_), Type::Todo(_)) => false,

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Todo is supposed to behave like Any or Unknown in the typing system, so I think our default approach should be that all three are gradually-equivalent to each other. Is there a concrete reason you felt it necessary to make Todo != Todo?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just thought that it would be rather nonsensical to do something like this:

def _(a: tuple[int, ...]):
	assert_type(a, tuple[str, ...])  # Pass because both are Todos

No strong feelings though.

(_, _) if self == other => true,

(Type::Any | Type::Unknown, Type::Any | Type::Unknown) => true,

(Type::SubclassOf(first), Type::SubclassOf(second)) => {
match (first.subclass_of(), second.subclass_of()) {
(ClassBase::Todo(_), ClassBase::Todo(_)) => false,
(ClassBase::Any | ClassBase::Unknown, ClassBase::Any | ClassBase::Unknown) => {
true
}
(ClassBase::Class(first), ClassBase::Class(second)) => first == second,
_ => false,
}
}

(Type::Tuple(first), Type::Tuple(second)) => {
first.len(db) == second.len(db)
&& iter::zip(first.elements(db), second.elements(db)).all(equivalent)
}

(Type::Union(first), Type::Union(second)) => {
let first_elements = first.elements(db);
let second_elements = second.elements(db);

if first_elements.len() != second_elements.len() {
return false;
}

let first_elements = first_elements.iter().collect::<OrderedTypeSet>();
let second_elements = second_elements.iter().collect::<OrderedTypeSet>();

iter::zip(first_elements, second_elements).all(equivalent)
}

(Type::Intersection(first), Type::Intersection(second)) => {
let first_positive = first.positive(db);
let first_negative = first.negative(db);

let second_positive = second.positive(db);
let second_negative = second.negative(db);

if first_positive.len() != second_positive.len()
|| first_negative.len() != second_negative.len()
{
return false;
}

let first_positive = first_positive.iter().collect::<OrderedTypeSet>();
let first_negative = first_negative.iter().collect::<OrderedTypeSet>();

let second_positive = second_positive.iter().collect::<OrderedTypeSet>();
let second_negative = second_negative.iter().collect::<OrderedTypeSet>();

iter::zip(first_positive, second_positive).all(equivalent)
&& iter::zip(first_negative, second_negative).all(equivalent)
}

_ => false,
}
}

/// Return true if there is just a single inhabitant for this type.
///
/// Note: This function aims to have no false positives, but might return `false`
Expand Down Expand Up @@ -1923,6 +2010,19 @@ impl<'db> Type<'db> {
CallOutcome::callable(binding)
}

Some(KnownFunction::AssertType) => {
let [_actual_type, asserted] = binding.parameter_tys() else {
return CallOutcome::callable(binding);
};

// TODO: Infer this as a type expression directly
let Ok(asserted_type) = asserted.in_type_expression(db) else {
return CallOutcome::callable(binding);
};

CallOutcome::asserted(binding, asserted_type)
}

_ => CallOutcome::callable(binding),
}
}
Expand Down Expand Up @@ -3239,6 +3339,9 @@ pub enum KnownFunction {
/// [`typing(_extensions).no_type_check`](https://typing.readthedocs.io/en/latest/spec/directives.html#no-type-check)
NoTypeCheck,

/// `typing(_extensions).assert_type`
AssertType,

/// `knot_extensions.static_assert`
StaticAssert,
/// `knot_extensions.is_equivalent_to`
Expand All @@ -3261,18 +3364,7 @@ impl KnownFunction {
pub fn constraint_function(self) -> Option<KnownConstraintFunction> {
match self {
Self::ConstraintFunction(f) => Some(f),
Self::RevealType
| Self::Len
| Self::Final
| Self::NoTypeCheck
| Self::StaticAssert
| Self::IsEquivalentTo
| Self::IsSubtypeOf
| Self::IsAssignableTo
| Self::IsDisjointFrom
| Self::IsFullyStatic
| Self::IsSingleton
| Self::IsSingleValued => None,
_ => None,
}
}

Expand All @@ -3294,6 +3386,7 @@ impl KnownFunction {
"no_type_check" if definition.is_typing_definition(db) => {
Some(KnownFunction::NoTypeCheck)
}
"assert_type" if definition.is_typing_definition(db) => Some(KnownFunction::AssertType),
"static_assert" if definition.is_knot_extensions_definition(db) => {
Some(KnownFunction::StaticAssert)
}
Expand Down Expand Up @@ -4648,6 +4741,77 @@ pub(crate) mod tests {
assert!(!from.into_type(&db).is_fully_static(&db));
}

#[test_case(Ty::Any, Ty::Any)]
#[test_case(Ty::Unknown, Ty::Unknown)]
#[test_case(Ty::Any, Ty::Unknown)]
#[test_case(Ty::Never, Ty::Never)]
#[test_case(Ty::AlwaysTruthy, Ty::AlwaysTruthy)]
#[test_case(Ty::AlwaysFalsy, Ty::AlwaysFalsy)]
#[test_case(Ty::LiteralString, Ty::LiteralString)]
#[test_case(Ty::BooleanLiteral(true), Ty::BooleanLiteral(true))]
#[test_case(Ty::BooleanLiteral(false), Ty::BooleanLiteral(false))]
#[test_case(Ty::SliceLiteral(0, 1, 2), Ty::SliceLiteral(0, 1, 2))]
#[test_case(Ty::BuiltinClassLiteral("str"), Ty::BuiltinClassLiteral("str"))]
#[test_case(Ty::SubclassOfAny, Ty::SubclassOfUnknown)]
#[test_case(
Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]),
Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")])
)]
#[test_case(
Ty::Intersection {
pos: vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")],
neg: vec![Ty::BuiltinInstance("bytes"), Ty::None]
},
Ty::Intersection {
pos: vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")],
neg: vec![Ty::None, Ty::BuiltinInstance("bytes")]
}
)]
#[test_case(
Ty::Intersection {
pos: vec![Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")])],
neg: vec![Ty::SubclassOfAny]
},
Ty::Intersection {
pos: vec![Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")])],
neg: vec![Ty::SubclassOfUnknown]
}
)]
fn is_gradual_equivalent_to(a: Ty, b: Ty) {
let db = setup_db();
let a = a.into_type(&db);
let b = b.into_type(&db);

assert!(a.is_gradual_equivalent_to(&db, b));
assert!(b.is_gradual_equivalent_to(&db, a));
}

#[test_case(Ty::Todo, Ty::Todo)]
#[test_case(
Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]),
Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"), Ty::BuiltinInstance("bytes")])
)]
#[test_case(
Ty::Union(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int"), Ty::BuiltinInstance("bytes")]),
Ty::Union(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"), Ty::BuiltinInstance("dict")])
)]
#[test_case(
Ty::Tuple(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]),
Ty::Tuple(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int"), Ty::BuiltinInstance("bytes")])
)]
#[test_case(
Ty::Tuple(vec![Ty::BuiltinInstance("str"), Ty::BuiltinInstance("int")]),
Ty::Tuple(vec![Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str")])
)]
fn is_not_gradual_equivalent_to(a: Ty, b: Ty) {
let db = setup_db();
let a = a.into_type(&db);
let b = b.into_type(&db);

assert!(!a.is_gradual_equivalent_to(&db, b));
assert!(!b.is_gradual_equivalent_to(&db, a));
}

#[test_case(Ty::IntLiteral(1); "is_int_literal_truthy")]
#[test_case(Ty::IntLiteral(-1))]
#[test_case(Ty::StringLiteral("foo"))]
Expand Down
Loading
Loading