Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[red-knot] Narrowing For Truthiness Checks (if x or if not x) #14687

Merged
merged 6 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
# Narrowing For Truthiness Checks (`if x` or `if not x`)
carljm marked this conversation as resolved.
Show resolved Hide resolved

## Value Literals

```py
def foo() -> Literal[0, -1, True, False, "", "foo", b"", b"bar", None] | tuple[()]:
return 0

x = foo()

if x:
reveal_type(x) # revealed: Literal[-1] | Literal[True] | Literal["foo"] | Literal[b"bar"]
else:
reveal_type(x) # revealed: Literal[0] | Literal[False] | Literal[""] | Literal[b""] | None | tuple[()]

if not x:
reveal_type(x) # revealed: Literal[0] | Literal[False] | Literal[""] | Literal[b""] | None | tuple[()]
else:
reveal_type(x) # revealed: Literal[-1] | Literal[True] | Literal["foo"] | Literal[b"bar"]

if x and not x:
reveal_type(x) # revealed: Never
else:
reveal_type(x) # revealed: Literal[-1, 0] | bool | Literal["", "foo"] | Literal[b"", b"bar"] | None | tuple[()]

if not (x and not x):
reveal_type(x) # revealed: Literal[-1, 0] | bool | Literal["", "foo"] | Literal[b"", b"bar"] | None | tuple[()]
else:
reveal_type(x) # revealed: Never

if x or not x:
reveal_type(x) # revealed: Literal[-1, 0] | bool | Literal["foo", ""] | Literal[b"bar", b""] | None | tuple[()]
else:
reveal_type(x) # revealed: Never

if not (x or not x):
reveal_type(x) # revealed: Never
else:
reveal_type(x) # revealed: Literal[-1, 0] | bool | Literal["foo", ""] | Literal[b"bar", b""] | None | tuple[()]

if (isinstance(x, int) or isinstance(x, str)) and x:
reveal_type(x) # revealed: Literal[-1] | Literal[True] | Literal["foo"]
else:
reveal_type(x) # revealed: Literal[b"", b"bar"] | None | tuple[()] | Literal[0] | Literal[False] | Literal[""]
```

## Function Literals

Basically functions are always truthy.

```py
def flag() -> bool:
return True

def foo(hello: int) -> bytes:
return b""

def bar(world: str, *args, **kwargs) -> float:
return 0.0

x = foo if flag() else bar

if x:
reveal_type(x) # revealed: Literal[foo, bar]
else:
reveal_type(x) # revealed: Never
```

## Mutable Truthiness

### Truthiness of Instances

The boolean value of an instance is not always consistent. For example, `__bool__` can be customized
to return random values, or in the case of a `list()`, the result depends on the number of elements
in the list. Therefore, these types should not be narrowed by `if x` or `if not x`.

```py
class A: ...
class B: ...

def f(x: A | B):
if x:
reveal_type(x) # revealed: A & ~AlwaysFalsy | B & ~AlwaysFalsy
else:
reveal_type(x) # revealed: A & ~AlwaysTruthy | B & ~AlwaysTruthy

if x and not x:
reveal_type(x) # revealed: A & ~AlwaysFalsy & ~AlwaysTruthy | B & ~AlwaysFalsy & ~AlwaysTruthy
else:
reveal_type(x) # revealed: A & ~AlwaysTruthy | B & ~AlwaysTruthy | A & ~AlwaysFalsy | B & ~AlwaysFalsy

if x or not x:
reveal_type(x) # revealed: A & ~AlwaysFalsy | B & ~AlwaysFalsy | A & ~AlwaysTruthy | B & ~AlwaysTruthy
else:
reveal_type(x) # revealed: A & ~AlwaysTruthy & ~AlwaysFalsy | B & ~AlwaysTruthy & ~AlwaysFalsy
```

### Truthiness of Types

Also, types may not be Truthy. This is because `__bool__` can be customized via a metaclass.
Although this is a very rare case, we may consider metaclass checks in the future to handle this
more accurately.
carljm marked this conversation as resolved.
Show resolved Hide resolved

```py
def flag() -> bool:
return True
carljm marked this conversation as resolved.
Show resolved Hide resolved

x = int if flag() else str
reveal_type(x) # revealed: Literal[int, str]

if x:
reveal_type(x) # revealed: Literal[int] & ~AlwaysFalsy | Literal[str] & ~AlwaysFalsy
else:
reveal_type(x) # revealed: Literal[int] & ~AlwaysTruthy | Literal[str] & ~AlwaysTruthy
```

## Determined Truthiness

Some custom classes can have a boolean value that is consistently determined as either `True` or
`False`, regardless of the instance's state. This is achieved by defining a `__bool__` method that
always returns a fixed value.

These types can always be fully narrowed in boolean contexts, as shown below:

```py
class T:
def __bool__(self) -> Literal[True]:
return True

class F:
def __bool__(self) -> Literal[False]:
return False

t = T()

if t:
reveal_type(t) # revealed: T
else:
reveal_type(t) # revealed: Never

f = F()

if f:
reveal_type(f) # revealed: Never
else:
reveal_type(f) # revealed: F
```

## Narrowing Complex Intersection and Union

```py
class A: ...
class B: ...

def flag() -> bool:
return True

def instance() -> A | B:
return A()

def literals() -> Literal[0, 42, "", "hello"]:
return 42

x = instance()
y = literals()

if isinstance(x, str) and not isinstance(x, B):
reveal_type(x) # revealed: A & str & ~B
reveal_type(y) # revealed: Literal[0, 42] | Literal["", "hello"]

z = x if flag() else y

reveal_type(z) # revealed: A & str & ~B | Literal[0, 42] | Literal["", "hello"]

if z:
reveal_type(z) # revealed: A & str & ~B & ~AlwaysFalsy | Literal[42] | Literal["hello"]
else:
reveal_type(z) # revealed: A & str & ~B & ~AlwaysTruthy | Literal[0] | Literal[""]
```

## Narrowing Multiple Variables

```py
def f(x: Literal[0, 1], y: Literal["", "hello"]):
if x and y and not x and not y:
reveal_type(x) # revealed: Never
reveal_type(y) # revealed: Never
else:
# ~(x or not x) and ~(y or not y)
reveal_type(x) # revealed: Literal[0, 1]
reveal_type(y) # revealed: Literal["", "hello"]

if (x or not x) and (y and not y):
reveal_type(x) # revealed: Literal[0, 1]
reveal_type(y) # revealed: Never
else:
# ~(x or not x) or ~(y and not y)
reveal_type(x) # revealed: Literal[0, 1]
reveal_type(y) # revealed: Literal["", "hello"]
```

## ControlFlow Merging

After merging control flows, when we take the union of all constraints applied in each branch, we
should return to the original state.

```py
class A: ...

x = A()

if x and not x:
y = x
reveal_type(y) # revealed: A & ~AlwaysFalsy & ~AlwaysTruthy
else:
y = x
reveal_type(y) # revealed: A & ~AlwaysTruthy | A & ~AlwaysFalsy

# TODO: It should be A. We should improve UnionBuilder or IntersectionBuilder. (issue #15023)
reveal_type(y) # revealed: A & ~AlwaysTruthy | A & ~AlwaysFalsy
```
61 changes: 58 additions & 3 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,11 @@ pub enum Type<'db> {
Union(UnionType<'db>),
/// The set of objects in all of the types in the intersection
Intersection(IntersectionType<'db>),
/// Represents objects whose `__bool__` method is deterministic:
/// - `AlwaysTruthy`: `__bool__` always returns `True`
/// - `AlwaysFalsy`: `__bool__` always returns `False`
AlwaysTruthy,
AlwaysFalsy,
/// An integer literal
IntLiteral(i64),
/// A boolean literal, either `True` or `False`.
Expand Down Expand Up @@ -717,6 +722,15 @@ impl<'db> Type<'db> {
.all(|&neg_ty| self.is_disjoint_from(db, neg_ty))
}

// Note that the definition of `Type::AlwaysFalsy` depends on the return value of `__bool__`.
// If `__bool__` always returns True or False, it can be treated as a subtype of `AlwaysTruthy` or `AlwaysFalsy`, respectively.
(left, Type::AlwaysFalsy) => matches!(left.bool(db), Truthiness::AlwaysFalse),
(left, Type::AlwaysTruthy) => matches!(left.bool(db), Truthiness::AlwaysTrue),
// Currently, the only supertype of `AlwaysFalsy` and `AlwaysTruthy` is the universal set (object instance).
(Type::AlwaysFalsy | Type::AlwaysTruthy, _) => {
target.is_equivalent_to(db, KnownClass::Object.to_instance(db))
}

// All `StringLiteral` types are a subtype of `LiteralString`.
(Type::StringLiteral(_), Type::LiteralString) => true,

Expand Down Expand Up @@ -1105,6 +1119,16 @@ impl<'db> Type<'db> {
false
}

(Type::AlwaysTruthy, ty) | (ty, Type::AlwaysTruthy) => {
// `Truthiness::Ambiguous` may include `AlwaysTrue` as a subset, so it's not guaranteed to be disjoint.
// Thus, they are only disjoint if `ty.bool() == AlwaysFalse`.
matches!(ty.bool(db), Truthiness::AlwaysFalse)
}
(Type::AlwaysFalsy, ty) | (ty, Type::AlwaysFalsy) => {
// Similarly, they are only disjoint if `ty.bool() == AlwaysTrue`.
matches!(ty.bool(db), Truthiness::AlwaysTrue)
}

(Type::KnownInstance(left), right) => {
left.instance_fallback(db).is_disjoint_from(db, right)
}
Expand Down Expand Up @@ -1238,7 +1262,9 @@ impl<'db> Type<'db> {
| Type::LiteralString
| Type::BytesLiteral(_)
| Type::SliceLiteral(_)
| Type::KnownInstance(_) => true,
| Type::KnownInstance(_)
| Type::AlwaysFalsy
| Type::AlwaysTruthy => true,
Type::SubclassOf(SubclassOfType { base }) => matches!(base, ClassBase::Class(_)),
Type::ClassLiteral(_) | Type::Instance(_) => {
// TODO: Ideally, we would iterate over the MRO of the class, check if all
Expand Down Expand Up @@ -1340,6 +1366,7 @@ impl<'db> Type<'db> {
//
false
}
Type::AlwaysTruthy | Type::AlwaysFalsy => false,
}
}

Expand Down Expand Up @@ -1405,7 +1432,9 @@ impl<'db> Type<'db> {
| Type::Todo(_)
| Type::Union(..)
| Type::Intersection(..)
| Type::LiteralString => false,
| Type::LiteralString
| Type::AlwaysTruthy
| Type::AlwaysFalsy => false,
}
}

Expand Down Expand Up @@ -1573,6 +1602,10 @@ impl<'db> Type<'db> {
// TODO: implement tuple methods
todo_type!().into()
}
Type::AlwaysTruthy | Type::AlwaysFalsy => {
// TODO return `Callable[[], Literal[True/False]]` for `__bool__` access
KnownClass::Object.to_instance(db).member(db, name)
}
&todo @ Type::Todo(_) => todo.into(),
}
}
Expand All @@ -1595,6 +1628,8 @@ impl<'db> Type<'db> {
// TODO: see above
Truthiness::Ambiguous
}
Type::AlwaysTruthy => Truthiness::AlwaysTrue,
Type::AlwaysFalsy => Truthiness::AlwaysFalse,
instance_ty @ Type::Instance(InstanceType { class }) => {
if class.is_known(db, KnownClass::NoneType) {
Truthiness::AlwaysFalse
Expand Down Expand Up @@ -1907,7 +1942,9 @@ impl<'db> Type<'db> {
| Type::StringLiteral(_)
| Type::SliceLiteral(_)
| Type::Tuple(_)
| Type::LiteralString => Type::Unknown,
| Type::LiteralString
| Type::AlwaysTruthy
| Type::AlwaysFalsy => Type::Unknown,
}
}

Expand Down Expand Up @@ -2047,6 +2084,7 @@ impl<'db> Type<'db> {
ClassBase::try_from_ty(db, todo_type!("Intersection meta-type"))
.expect("Type::Todo should be a valid ClassBase"),
),
Type::AlwaysTruthy | Type::AlwaysFalsy => KnownClass::Type.to_instance(db),
Type::Todo(todo) => Type::subclass_of_base(ClassBase::Todo(*todo)),
}
}
Expand Down Expand Up @@ -3491,6 +3529,8 @@ pub(crate) mod tests {
SubclassOfAbcClass(&'static str),
StdlibModule(CoreStdlibModule),
SliceLiteral(i32, i32, i32),
AlwaysTruthy,
AlwaysFalsy,
}

impl Ty {
Expand Down Expand Up @@ -3558,6 +3598,8 @@ pub(crate) mod tests {
Some(stop),
Some(step),
)),
Ty::AlwaysTruthy => Type::AlwaysTruthy,
Ty::AlwaysFalsy => Type::AlwaysFalsy,
}
}
}
Expand Down Expand Up @@ -3696,6 +3738,12 @@ pub(crate) mod tests {
)]
#[test_case(Ty::SliceLiteral(1, 2, 3), Ty::BuiltinInstance("slice"))]
#[test_case(Ty::SubclassOfBuiltinClass("str"), Ty::Intersection{pos: vec![], neg: vec![Ty::None]})]
#[test_case(Ty::IntLiteral(1), Ty::AlwaysTruthy)]
#[test_case(Ty::IntLiteral(0), Ty::AlwaysFalsy)]
#[test_case(Ty::AlwaysTruthy, Ty::BuiltinInstance("object"))]
#[test_case(Ty::AlwaysFalsy, Ty::BuiltinInstance("object"))]
#[test_case(Ty::Never, Ty::AlwaysTruthy)]
#[test_case(Ty::Never, Ty::AlwaysFalsy)]
fn is_subtype_of(from: Ty, to: Ty) {
let db = setup_db();
assert!(from.into_type(&db).is_subtype_of(&db, to.into_type(&db)));
Expand Down Expand Up @@ -3730,6 +3778,10 @@ pub(crate) mod tests {
#[test_case(Ty::BuiltinClassLiteral("str"), Ty::SubclassOfAny)]
#[test_case(Ty::AbcInstance("ABCMeta"), Ty::SubclassOfBuiltinClass("type"))]
#[test_case(Ty::SubclassOfBuiltinClass("str"), Ty::BuiltinClassLiteral("str"))]
#[test_case(Ty::IntLiteral(1), Ty::AlwaysFalsy)]
#[test_case(Ty::IntLiteral(0), Ty::AlwaysTruthy)]
carljm marked this conversation as resolved.
Show resolved Hide resolved
#[test_case(Ty::BuiltinInstance("str"), Ty::AlwaysTruthy)]
#[test_case(Ty::BuiltinInstance("str"), Ty::AlwaysFalsy)]
fn is_not_subtype_of(from: Ty, to: Ty) {
let db = setup_db();
assert!(!from.into_type(&db).is_subtype_of(&db, to.into_type(&db)));
Expand Down Expand Up @@ -3864,6 +3916,7 @@ pub(crate) mod tests {
#[test_case(Ty::Tuple(vec![]), Ty::BuiltinClassLiteral("object"))]
#[test_case(Ty::SubclassOfBuiltinClass("object"), Ty::None)]
#[test_case(Ty::SubclassOfBuiltinClass("str"), Ty::LiteralString)]
#[test_case(Ty::AlwaysFalsy, Ty::AlwaysTruthy)]
carljm marked this conversation as resolved.
Show resolved Hide resolved
fn is_disjoint_from(a: Ty, b: Ty) {
let db = setup_db();
let a = a.into_type(&db);
Expand Down Expand Up @@ -3894,6 +3947,8 @@ pub(crate) mod tests {
#[test_case(Ty::BuiltinClassLiteral("str"), Ty::BuiltinInstance("type"))]
#[test_case(Ty::BuiltinClassLiteral("str"), Ty::SubclassOfAny)]
#[test_case(Ty::AbcClassLiteral("ABC"), Ty::AbcInstance("ABCMeta"))]
#[test_case(Ty::BuiltinInstance("str"), Ty::AlwaysTruthy)]
#[test_case(Ty::BuiltinInstance("str"), Ty::AlwaysFalsy)]
fn is_not_disjoint_from(a: Ty, b: Ty) {
let db = setup_db();
let a = a.into_type(&db);
Expand Down
Loading
Loading