Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve types of BaseDatabaseOperations #2238

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions django-stubs/contrib/gis/db/backends/base/operations.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any

from django.db.backends.base.operations import _Converter
from django.db.models.expressions import Expression
from django.utils.functional import cached_property

class BaseSpatialOperations:
Expand All @@ -24,13 +26,13 @@ class BaseSpatialOperations:
def geo_db_type(self, f: Any) -> Any: ...
def get_distance(self, f: Any, value: Any, lookup_type: Any) -> Any: ...
def get_geom_placeholder(self, f: Any, value: Any, compiler: Any) -> Any: ...
def check_expression_support(self, expression: Any) -> None: ...
def check_expression_support(self, expression: Expression) -> None: ...
def spatial_aggregate_name(self, agg_name: Any) -> Any: ...
def spatial_function_name(self, func_name: Any) -> Any: ...
def geometry_columns(self) -> Any: ...
def spatial_ref_sys(self) -> Any: ...
distance_expr_for_lookup: Any
def get_db_converters(self, expression: Any) -> Any: ...
def get_db_converters(self, expression: Expression) -> list[_Converter]: ...
def get_geometry_converter(self, expression: Any) -> Any: ...
def get_area_att_for_field(self, field: Any) -> Any: ...
def get_distance_att_for_field(self, field: Any) -> Any: ...
4 changes: 3 additions & 1 deletion django-stubs/contrib/gis/db/models/aggregates.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any

from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import Aggregate
from django.db.models import Aggregate, Expression
from django.db.models.sql.compiler import SQLCompiler, _AsSqlType

class GeoAggregate(Aggregate):
Expand All @@ -15,10 +15,12 @@ class Collect(GeoAggregate):
class Extent(GeoAggregate):
name: str
def __init__(self, expression: Any, **extra: Any) -> None: ...
def convert_value(self, value: Any, expression: Expression, connection: BaseDatabaseWrapper) -> Any: ...

class Extent3D(GeoAggregate):
name: str
def __init__(self, expression: Any, **extra: Any) -> None: ...
def convert_value(self, value: Any, expression: Expression, connection: BaseDatabaseWrapper) -> Any: ...

class MakeLine(GeoAggregate):
name: str
Expand Down
31 changes: 17 additions & 14 deletions django-stubs/db/backends/base/operations.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from collections.abc import Iterable, Sequence
from collections.abc import Callable, Sequence
from datetime import date, time, timedelta
from datetime import datetime as real_datetime
from decimal import Decimal
Expand All @@ -13,13 +13,16 @@ from django.db.models.constants import OnConflict
from django.db.models.expressions import Case, Col, Expression
from django.db.models.fields import Field
from django.db.models.sql.compiler import SQLCompiler
from typing_extensions import TypeAlias

_Converter: TypeAlias = Callable[[Any, Expression, BaseDatabaseWrapper], Any]

class BaseDatabaseOperations:
compiler_module: str
integer_field_ranges: dict[str, tuple[int, int]]
set_operators: dict[str, str]
cast_data_types: dict[Any, Any]
cast_char_field_without_max_length: Any
cast_char_field_without_max_length: str | None
PRECEDING: str
FOLLOWING: str
UNBOUNDED_PRECEDING: str
Expand Down Expand Up @@ -57,7 +60,7 @@ class BaseDatabaseOperations:
def pk_default_value(self) -> str: ...
def prepare_sql_script(self, sql: Any) -> list[str]: ...
def process_clob(self, value: str) -> str: ...
def return_insert_columns(self, fields: Any) -> Any: ...
def return_insert_columns(self, fields: list[Field[Any, Any]]) -> tuple[str, list[Any]]: ...
def compiler(self, compiler_name: str) -> type[SQLCompiler]: ...
def quote_name(self, name: str) -> str: ...
def regex_lookup(self, lookup_type: str) -> str: ...
Expand All @@ -66,16 +69,16 @@ class BaseDatabaseOperations:
def savepoint_rollback_sql(self, sid: str) -> str: ...
def set_time_zone_sql(self) -> str: ...
def sql_flush(
self, style: Any, tables: Sequence[str], *, reset_sequences: bool = ..., allow_cascade: bool = ...
self, style: Style, tables: Sequence[str], *, reset_sequences: bool = ..., allow_cascade: bool = ...
) -> list[str]: ...
def execute_sql_flush(self, sql_list: Iterable[str]) -> None: ...
def sequence_reset_by_name_sql(self, style: Style | None, sequences: list[Any]) -> list[Any]: ...
def sequence_reset_sql(self, style: Style, model_list: Sequence[type[Model]]) -> list[Any]: ...
def execute_sql_flush(self, sql_list: list[str]) -> None: ...
def sequence_reset_by_name_sql(self, style: Style, sequences: list[dict[str, str | None]]) -> list[str]: ...
def sequence_reset_sql(self, style: Style, model_list: list[type[Model]]) -> list[str]: ...
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since sequence_reset_sql only iterates over the model_list, I typing this as list seems too strict. Sequence[] or Iterable[] would be better.

Same for execute_sql_flush

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went for list as these methods receive an actual list at runtime (e.g. sequence_reset_sql receives the result of sql_flush), so I assumed it is safe to allow BaseDatabaseOperations subclasses to rely on the available attributes/methods of list in these methods. On the other hand, keeping Sequence seems fine as well

def start_transaction_sql(self) -> str: ...
def end_transaction_sql(self, success: bool = ...) -> str: ...
def tablespace_sql(self, tablespace: str | None, inline: bool = ...) -> str: ...
def prep_for_like_query(self, x: str) -> str: ...
prep_for_iexact_query: Any
def tablespace_sql(self, tablespace: str, inline: bool = ...) -> str: ...
def prep_for_like_query(self, x: object) -> str: ...
def prep_for_iexact_query(self, x: object) -> str: ...
Comment on lines +80 to +81
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I think Any has been preferred over object recently.

def validate_autopk_value(self, value: int) -> int: ...
def adapt_unknown_value(self, value: Any) -> Any: ...
def adapt_datefield_value(self, value: date | None) -> str | None: ...
Expand All @@ -89,14 +92,14 @@ class BaseDatabaseOperations:
def adapt_integerfield_value(self, value: Any, internal_type: Any) -> Any: ...
def year_lookup_bounds_for_date_field(self, value: int, iso_year: bool = ...) -> list[str]: ...
def year_lookup_bounds_for_datetime_field(self, value: int, iso_year: bool = ...) -> list[str]: ...
def get_db_converters(self, expression: Expression) -> list[Any]: ...
def get_db_converters(self, expression: Expression) -> list[_Converter]: ...
def convert_durationfield_value(
self, value: float | None, expression: Expression, connection: BaseDatabaseWrapper
) -> timedelta | None: ...
def check_expression_support(self, expression: Any) -> None: ...
def conditional_expression_supported_in_where_clause(self, expression: Any) -> bool: ...
def check_expression_support(self, expression: Expression) -> None: ...
def conditional_expression_supported_in_where_clause(self, expression: Expression) -> bool: ...
def combine_expression(self, connector: str, sub_expressions: list[str]) -> str: ...
def combine_duration_expression(self, connector: Any, sub_expressions: Any) -> str: ...
def combine_duration_expression(self, connector: str, sub_expressions: list[str]) -> str: ...
def binary_placeholder_sql(self, value: Case | None) -> str: ...
def modify_insert_params(self, placeholder: str, params: Any) -> Any: ...
def integer_field_range(self, internal_type: Any) -> tuple[int, int]: ...
Expand Down
8 changes: 0 additions & 8 deletions django-stubs/db/backends/mysql/operations.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,14 @@ class DatabaseOperations(BaseDatabaseOperations):
def force_no_ordering(self) -> Any: ...
def last_executed_query(self, cursor: Any, sql: Any, params: Any) -> Any: ...
def no_limit_value(self) -> Any: ...
def quote_name(self, name: str) -> Any: ...
def return_insert_columns(self, fields: Any) -> Any: ...
def sequence_reset_by_name_sql(self, style: Any, sequences: Any) -> Any: ...
def validate_autopk_value(self, value: Any) -> Any: ...
def adapt_datetimefield_value(self, value: Any) -> Any: ...
def adapt_timefield_value(self, value: Any) -> Any: ...
def max_name_length(self) -> Any: ...
def bulk_insert_sql(self, fields: Any, placeholder_rows: Any) -> Any: ...
def combine_expression(self, connector: Any, sub_expressions: Any) -> Any: ...
def get_db_converters(self, expression: Any) -> Any: ...
def convert_booleanfield_value(self, value: Any, expression: Any, connection: Any) -> Any: ...
def convert_datetimefield_value(self, value: Any, expression: Any, connection: Any) -> Any: ...
def convert_uuidfield_value(self, value: Any, expression: Any, connection: Any) -> Any: ...
def binary_placeholder_sql(self, value: Any) -> Any: ...
def subtract_temporals(self, internal_type: Any, lhs: Any, rhs: Any) -> Any: ...
def explain_query_prefix(self, format: Any | None = ..., **options: Any) -> Any: ...
def regex_lookup(self, lookup_type: str) -> Any: ...
def insert_statement(self, on_conflict: OnConflict | None = ...) -> str: ...
def lookup_cast(self, lookup_type: str, internal_type: Any | None = ...) -> Any: ...
10 changes: 0 additions & 10 deletions django-stubs/db/backends/oracle/operations.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ class DatabaseOperations(BaseDatabaseOperations):
def datetime_extract_sql(self, lookup_type: str, sql: str, params: Any, tzname: str | None) -> tuple[str, Any]: ...
def datetime_trunc_sql(self, lookup_type: str, sql: str, params: Any, tzname: str | None) -> str: ...
def time_trunc_sql(self, lookup_type: str, sql: str, params: Any, tzname: str | None = ...) -> str: ...
def get_db_converters(self, expression: Any) -> list[Any]: ...
def convert_textfield_value(self, value: Any, expression: Any, connection: Any) -> Any: ...
def convert_binaryfield_value(self, value: Any, expression: Any, connection: Any) -> Any: ...
def convert_booleanfield_value(self, value: Any, expression: Any, connection: Any) -> Any: ...
Expand All @@ -40,20 +39,11 @@ class DatabaseOperations(BaseDatabaseOperations):
def max_in_list_size(self) -> int: ...
def max_name_length(self) -> int: ...
def pk_default_value(self) -> str: ...
def prep_for_iexact_query(self, x: Any) -> str: ...
def process_clob(self, value: Any) -> Any: ...
def quote_name(self, name: str) -> str: ...
def regex_lookup(self, lookup_type: str) -> str: ...
def return_insert_columns(self, fields: Any) -> Any: ...
def sequence_reset_by_name_sql(self, style: Any, sequences: Any) -> list[str]: ...
def sequence_reset_sql(self, style: Any, model_list: Any) -> list[str]: ...
def start_transaction_sql(self) -> str: ...
def tablespace_sql(self, tablespace: Any, inline: bool = ...) -> str: ...
def adapt_datefield_value(self, value: Any) -> Any: ...
def adapt_datetimefield_value(self, value: Any) -> Any: ...
def adapt_timefield_value(self, value: Any) -> Any: ...
def combine_expression(self, connector: Any, sub_expressions: Any) -> Any: ...
def bulk_insert_sql(self, fields: Any, placeholder_rows: Any) -> str: ...
def subtract_temporals(self, internal_type: Any, lhs: Any, rhs: Any) -> Any: ...
def bulk_batch_size(self, fields: Any, objs: Any) -> int: ...
def conditional_expression_supported_in_where_clause(self, expression: Any) -> bool: ...
7 changes: 4 additions & 3 deletions django-stubs/db/models/expressions.pyi
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import datetime
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from collections.abc import Iterable, Iterator, Mapping, Sequence
from decimal import Decimal
from typing import Any, ClassVar, Generic, Literal, TypeVar

from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.base.operations import _Converter
from django.db.models import Q, fields
from django.db.models.fields import Field
from django.db.models.lookups import Lookup, Transform
Expand Down Expand Up @@ -63,7 +64,7 @@ class BaseExpression:
window_compatible: bool
allowed_default: bool
def __init__(self, output_field: Field | None = ...) -> None: ...
def get_db_converters(self, connection: BaseDatabaseWrapper) -> list[Callable]: ...
def get_db_converters(self, connection: BaseDatabaseWrapper) -> list[_Converter]: ...
def get_source_expressions(self) -> list[Any]: ...
def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None: ...
@cached_property
Expand All @@ -89,7 +90,7 @@ class BaseExpression:
@cached_property
def output_field(self) -> Field: ...
@cached_property
def convert_value(self) -> Callable: ...
def convert_value(self) -> _Converter: ...
def get_lookup(self, lookup: str) -> type[Lookup] | None: ...
def get_transform(self, name: str) -> type[Transform] | None: ...
def relabeled_clone(self, change_map: Mapping[str, str]) -> Self: ...
Expand Down
3 changes: 2 additions & 1 deletion django-stubs/db/models/functions/datetime.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ from typing import Any, ClassVar
from django.db import models
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import Func, Transform
from django.db.models.expressions import Combinable
from django.db.models.expressions import Combinable, Expression
from django.db.models.fields import Field
from django.db.models.sql.compiler import SQLCompiler, _AsSqlType

Expand Down Expand Up @@ -44,6 +44,7 @@ class TruncBase(TimezoneMixin, Transform):
self, expression: Combinable | str, output_field: Field | None = ..., tzinfo: tzinfo | None = ..., **extra: Any
) -> None: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ... # type: ignore[override]
def convert_value(self, value: Any, expression: Expression, connection: BaseDatabaseWrapper) -> Any: ...

class Trunc(TruncBase):
def __init__(
Expand Down
Loading