Skip to content

Commit

Permalink
[red-knot] function parameter types (#14802)
Browse files Browse the repository at this point in the history
## Summary

Inferred and declared types for function parameters, in the function
body scope.

Fixes #13693.

## Test Plan

Added mdtests.

---------

Co-authored-by: Micha Reiser <[email protected]>
Co-authored-by: Alex Waygood <[email protected]>
  • Loading branch information
3 people authored Dec 6, 2024
1 parent 2119dca commit 3017b3b
Show file tree
Hide file tree
Showing 7 changed files with 340 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ from typing_extensions import TypeVarTuple
Ts = TypeVarTuple("Ts")

def append_int(*args: *Ts) -> tuple[*Ts, int]:
# TODO: should show some representation of the variadic generic type
reveal_type(args) # revealed: @Todo(function parameter type)
# TODO: tuple[*Ts]
reveal_type(args) # revealed: tuple

return (*args, 1)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Function parameter types

Within a function scope, the declared type of each parameter is its annotated type (or Unknown if
not annotated). The initial inferred type is the union of the declared type with the type of the
default value expression (if any). If both are fully static types, this union should simplify to the
annotated type (since the default value type must be assignable to the annotated type, and for fully
static types this means subtype-of, which simplifies in unions). But if the annotated type is
Unknown or another non-fully-static type, the default value type may still be relevant as lower
bound.

The variadic parameter is a variadic tuple of its annotated type; the variadic-keywords parameter is
a dictionary from strings to its annotated type.

## Parameter kinds

```py
from typing import Literal

def f(a, b: int, c=1, d: int = 2, /, e=3, f: Literal[4] = 4, *args: object, g=5, h: Literal[6] = 6, **kwargs: str):
reveal_type(a) # revealed: Unknown
reveal_type(b) # revealed: int
reveal_type(c) # revealed: Unknown | Literal[1]
reveal_type(d) # revealed: int
reveal_type(e) # revealed: Unknown | Literal[3]
reveal_type(f) # revealed: Literal[4]
reveal_type(g) # revealed: Unknown | Literal[5]
reveal_type(h) # revealed: Literal[6]

# TODO: should be `tuple[object, ...]` (needs generics)
reveal_type(args) # revealed: tuple

# TODO: should be `dict[str, str]` (needs generics)
reveal_type(kwargs) # revealed: dict
```

## Unannotated variadic parameters

...are inferred as tuple of Unknown or dict from string to Unknown.

```py
def g(*args, **kwargs):
# TODO: should be `tuple[Unknown, ...]` (needs generics)
reveal_type(args) # revealed: tuple

# TODO: should be `dict[str, Unknown]` (needs generics)
reveal_type(kwargs) # revealed: dict
```

## Annotation is present but not a fully static type

The default value type should be a lower bound on the inferred type.

```py
from typing import Any

def f(x: Any = 1):
reveal_type(x) # revealed: Any | Literal[1]
```

## Default value type must be assignable to annotated type

The default value type must be assignable to the annotated type. If not, we emit a diagnostic, and
fall back to inferring the annotated type, ignoring the default value type.

```py
# error: [invalid-parameter-default]
def f(x: int = "foo"):
reveal_type(x) # revealed: int

# The check is assignable-to, not subtype-of, so this is fine:
from typing import Any

def g(x: Any = "foo"):
reveal_type(x) # revealed: Any | Literal["foo"]
```
70 changes: 46 additions & 24 deletions crates/red_knot_python_semantic/src/semantic_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -606,24 +606,11 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
let function_table = index.symbol_table(function_scope_id);
assert_eq!(
names(&function_table),
vec!["a", "b", "c", "args", "d", "kwargs"],
vec!["a", "b", "c", "d", "args", "kwargs"],
);

let use_def = index.use_def_map(function_scope_id);
for name in ["a", "b", "c", "d"] {
let binding = use_def
.first_public_binding(
function_table
.symbol_id_by_name(name)
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(
binding.kind(&db),
DefinitionKind::ParameterWithDefault(_)
));
}
for name in ["args", "kwargs"] {
let binding = use_def
.first_public_binding(
function_table
Expand All @@ -633,6 +620,28 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
.unwrap();
assert!(matches!(binding.kind(&db), DefinitionKind::Parameter(_)));
}
let args_binding = use_def
.first_public_binding(
function_table
.symbol_id_by_name("args")
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(
args_binding.kind(&db),
DefinitionKind::VariadicPositionalParameter(_)
));
let kwargs_binding = use_def
.first_public_binding(
function_table
.symbol_id_by_name("kwargs")
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(
kwargs_binding.kind(&db),
DefinitionKind::VariadicKeywordParameter(_)
));
}

#[test]
Expand All @@ -654,25 +663,38 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
let lambda_table = index.symbol_table(lambda_scope_id);
assert_eq!(
names(&lambda_table),
vec!["a", "b", "c", "args", "d", "kwargs"],
vec!["a", "b", "c", "d", "args", "kwargs"],
);

let use_def = index.use_def_map(lambda_scope_id);
for name in ["a", "b", "c", "d"] {
let binding = use_def
.first_public_binding(lambda_table.symbol_id_by_name(name).expect("symbol exists"))
.unwrap();
assert!(matches!(
binding.kind(&db),
DefinitionKind::ParameterWithDefault(_)
));
}
for name in ["args", "kwargs"] {
let binding = use_def
.first_public_binding(lambda_table.symbol_id_by_name(name).expect("symbol exists"))
.unwrap();
assert!(matches!(binding.kind(&db), DefinitionKind::Parameter(_)));
}
let args_binding = use_def
.first_public_binding(
lambda_table
.symbol_id_by_name("args")
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(
args_binding.kind(&db),
DefinitionKind::VariadicPositionalParameter(_)
));
let kwargs_binding = use_def
.first_public_binding(
lambda_table
.symbol_id_by_name("kwargs")
.expect("symbol exists"),
)
.unwrap();
assert!(matches!(
kwargs_binding.kind(&db),
DefinitionKind::VariadicKeywordParameter(_)
));
}

/// Test case to validate that the comprehension scope is correctly identified and that the target
Expand Down
78 changes: 48 additions & 30 deletions crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use ruff_index::IndexVec;
use ruff_python_ast as ast;
use ruff_python_ast::name::Name;
use ruff_python_ast::visitor::{walk_expr, walk_pattern, walk_stmt, Visitor};
use ruff_python_ast::{AnyParameterRef, BoolOp, Expr};
use ruff_python_ast::{BoolOp, Expr};

use crate::ast_node_ref::AstNodeRef;
use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
Expand Down Expand Up @@ -479,21 +479,35 @@ impl<'db> SemanticIndexBuilder<'db> {
self.pop_scope();
}

fn declare_parameter(&mut self, parameter: AnyParameterRef<'db>) {
let symbol = self.add_symbol(parameter.name().id().clone());
fn declare_parameters(&mut self, parameters: &'db ast::Parameters) {
for parameter in parameters.iter_non_variadic_params() {
self.declare_parameter(parameter);
}
if let Some(vararg) = parameters.vararg.as_ref() {
let symbol = self.add_symbol(vararg.name.id().clone());
self.add_definition(
symbol,
DefinitionNodeRef::VariadicPositionalParameter(vararg),
);
}
if let Some(kwarg) = parameters.kwarg.as_ref() {
let symbol = self.add_symbol(kwarg.name.id().clone());
self.add_definition(symbol, DefinitionNodeRef::VariadicKeywordParameter(kwarg));
}
}

fn declare_parameter(&mut self, parameter: &'db ast::ParameterWithDefault) {
let symbol = self.add_symbol(parameter.parameter.name.id().clone());

let definition = self.add_definition(symbol, parameter);

if let AnyParameterRef::NonVariadic(with_default) = parameter {
// Insert a mapping from the parameter to the same definition.
// This ensures that calling `HasTy::ty` on the inner parameter returns
// a valid type (and doesn't panic)
let existing_definition = self.definitions_by_node.insert(
DefinitionNodeRef::from(AnyParameterRef::Variadic(&with_default.parameter)).key(),
definition,
);
debug_assert_eq!(existing_definition, None);
}
// Insert a mapping from the inner Parameter node to the same definition.
// This ensures that calling `HasTy::ty` on the inner parameter returns
// a valid type (and doesn't panic)
let existing_definition = self
.definitions_by_node
.insert((&parameter.parameter).into(), definition);
debug_assert_eq!(existing_definition, None);
}

pub(super) fn build(mut self) -> SemanticIndex<'db> {
Expand Down Expand Up @@ -556,34 +570,40 @@ where
fn visit_stmt(&mut self, stmt: &'ast ast::Stmt) {
match stmt {
ast::Stmt::FunctionDef(function_def) => {
for decorator in &function_def.decorator_list {
let ast::StmtFunctionDef {
decorator_list,
parameters,
type_params,
name,
returns,
body,
is_async: _,
range: _,
} = function_def;
for decorator in decorator_list {
self.visit_decorator(decorator);
}

self.with_type_params(
NodeWithScopeRef::FunctionTypeParameters(function_def),
function_def.type_params.as_deref(),
type_params.as_deref(),
|builder| {
builder.visit_parameters(&function_def.parameters);
if let Some(expr) = &function_def.returns {
builder.visit_annotation(expr);
builder.visit_parameters(parameters);
if let Some(returns) = returns {
builder.visit_annotation(returns);
}

builder.push_scope(NodeWithScopeRef::Function(function_def));

// Add symbols and definitions for the parameters to the function scope.
for parameter in &*function_def.parameters {
builder.declare_parameter(parameter);
}
builder.declare_parameters(parameters);

builder.visit_body(&function_def.body);
builder.visit_body(body);
builder.pop_scope()
},
);
// The default value of the parameters needs to be evaluated in the
// enclosing scope.
for default in function_def
.parameters
for default in parameters
.iter_non_variadic_params()
.filter_map(|param| param.default.as_deref())
{
Expand All @@ -592,7 +612,7 @@ where
// The symbol for the function name itself has to be evaluated
// at the end to match the runtime evaluation of parameter defaults
// and return-type annotations.
let symbol = self.add_symbol(function_def.name.id.clone());
let symbol = self.add_symbol(name.id.clone());
self.add_definition(symbol, function_def);
}
ast::Stmt::ClassDef(class) => {
Expand Down Expand Up @@ -1179,10 +1199,8 @@ where
self.push_scope(NodeWithScopeRef::Lambda(lambda));

// Add symbols and definitions for the parameters to the lambda scope.
if let Some(parameters) = &lambda.parameters {
for parameter in parameters {
self.declare_parameter(parameter);
}
if let Some(parameters) = lambda.parameters.as_ref() {
self.declare_parameters(parameters);
}

self.visit_expr(lambda.body.as_ref());
Expand Down
Loading

0 comments on commit 3017b3b

Please sign in to comment.