Skip to content

Commit

Permalink
docs: fix docs again
Browse files Browse the repository at this point in the history
  • Loading branch information
34j committed Nov 24, 2024
1 parent 1810177 commit f3c9eb4
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 30 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,5 @@ tmp/
*.egg
dist/
.DS_STORE
venv
.venv
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,13 @@ repos:
rev: 23.7.0
hooks:
- id: black

- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.0.0"
hooks:
- id: mypy
additional_dependencies: [typing_extensions>=4.4.0]
args:
- --ignore-missing-imports
- --config=pyproject.toml
files: ".*(_draft.*)$"
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,12 @@ build-backend = "setuptools.build_meta"

[tool.black]
line-length = 88

[tool.mypy]
python_version = "3.9"
mypy_path = "$MYPY_CONFIG_FILE_DIR/src/array_api_stubs/_draft/"
files = [
"src/array_api_stubs/_draft/**/*.py"
]
follow_imports = "silent"
disable_error_code = "empty-body,type-var"
4 changes: 4 additions & 0 deletions src/_array_api_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,11 @@
]
nitpick_ignore_regex = [
("py:class", ".*array"),
("py:class", ".*Array"),
("py:class", ".*device"),
("py:class", ".*Device"),
("py:class", ".*dtype"),
("py:class", ".*DType"),
("py:class", ".*NestedSequence"),
("py:class", ".*SupportsBufferProtocol"),
("py:class", ".*PyCapsule"),
Expand All @@ -84,6 +87,7 @@
"array": "array",
"Device": "device",
"Dtype": "dtype",
"DType": "dtype",
}

# Make autosummary show the signatures of functions in the tables using actual
Expand Down
22 changes: 12 additions & 10 deletions src/array_api_stubs/_draft/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
"Info",
]

from dataclasses import dataclass
from typing import (
Any,
List,
Expand All @@ -45,10 +44,13 @@
Protocol,
)
from enum import Enum
from .data_types import DType

array = TypeVar("array", bound="array_")
array = TypeVar("array", bound="Array")
device = TypeVar("device")
dtype = TypeVar("dtype")
dtype = TypeVar("dtype", bound=DType)
device_ = TypeVar("device_") # only used in this file
dtype_ = TypeVar("dtype_", bound=DType) # only used in this file
SupportsDLPack = TypeVar("SupportsDLPack")
SupportsBufferProtocol = TypeVar("SupportsBufferProtocol")
PyCapsule = TypeVar("PyCapsule")
Expand Down Expand Up @@ -88,7 +90,7 @@ def __len__(self, /) -> int:
...


class Info(Protocol):
class Info(Protocol[device]):
"""Namespace returned by `__array_namespace_info__`."""

def capabilities(self) -> Capabilities:
Expand Down Expand Up @@ -147,12 +149,12 @@ def dtypes(
)


class _array(Protocol[array, dtype, device]):
class Array(Protocol[array, dtype_, device_, PyCapsule]): # type: ignore
def __init__(self: array) -> None:
"""Initialize the attributes for the array object class."""

@property
def dtype(self: array) -> dtype:
def dtype(self: array) -> dtype_:
"""
Data type of the array elements.
Expand All @@ -163,7 +165,7 @@ def dtype(self: array) -> dtype:
"""

@property
def device(self: array) -> device:
def device(self: array) -> device_:
"""
Hardware device the array data resides on.
Expand Down Expand Up @@ -625,7 +627,7 @@ def __dlpack_device__(self: array, /) -> Tuple[Enum, int]:
ONE_API = 14
"""

def __eq__(self: array, other: Union[int, float, bool, array], /) -> array:
def __eq__(self: array, other: Union[int, float, bool, array], /) -> array: # type: ignore
r"""
Computes the truth value of ``self_i == other_i`` for each element of an array instance with the respective element of the array ``other``.
Expand Down Expand Up @@ -1072,7 +1074,7 @@ def __mul__(self: array, other: Union[int, float, array], /) -> array:
Added complex data type support.
"""

def __ne__(self: array, other: Union[int, float, bool, array], /) -> array:
def __ne__(self: array, other: Union[int, float, bool, array], /) -> array: # type: ignore
"""
Computes the truth value of ``self_i != other_i`` for each element of an array instance with the respective element of the array ``other``.
Expand Down Expand Up @@ -1342,7 +1344,7 @@ def __xor__(self: array, other: Union[int, bool, array], /) -> array:
"""

def to_device(
self: array, device: device, /, *, stream: Optional[Union[int, Any]] = None
self: array, device: device_, /, *, stream: Optional[Union[int, Any]] = None
) -> array:
"""
Copy the array from the device on which it currently resides to the specified ``device``.
Expand Down
6 changes: 6 additions & 0 deletions src/array_api_stubs/_draft/array_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from ._types import Array

# for documentation
array = Array

__all__ = ["array"]
36 changes: 19 additions & 17 deletions src/array_api_stubs/_draft/data_types.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
__all__ = ["__eq__"]
from __future__ import annotations

__all__ = ["DType"]

from ._types import dtype

from typing import Protocol

def __eq__(self: dtype, other: dtype, /) -> bool:
"""
Computes the truth value of ``self == other`` in order to test for data type object equality.

Parameters
----------
self: dtype
data type instance. May be any supported data type.
other: dtype
other data type instance. May be any supported data type.
Returns
-------
out: bool
a boolean indicating whether the data type objects are equal.
"""
class DType(Protocol):
def __eq__(self, other: DType, /) -> bool:
"""
Computes the truth value of ``self == other`` in order to test for data type object equality.
Parameters
----------
self: dtype
data type instance. May be any supported data type.
other: dtype
other data type instance. May be any supported data type.
Returns
-------
out: bool
a boolean indicating whether the data type objects are equal.
"""
...
6 changes: 3 additions & 3 deletions src/array_api_stubs/_draft/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def matrix_norm(
/,
*,
keepdims: bool = False,
ord: Optional[Union[int, float, Literal[inf, -inf, "fro", "nuc"]]] = "fro",
ord: Optional[Union[int, float, Literal[inf, -inf, "fro", "nuc"]]] = "fro", # type: ignore
) -> array:
"""
Computes the matrix norm of a matrix (or a stack of matrices) ``x``.
Expand Down Expand Up @@ -781,7 +781,7 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[dtype] = None) -> arr
"""


def vecdot(x1: array, x2: array, /, *, axis: int = None) -> array:
def vecdot(x1: array, x2: array, /, *, axis: Optional[int] = None) -> array:
"""Alias for :func:`~array_api.vecdot`."""


Expand All @@ -791,7 +791,7 @@ def vector_norm(
*,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims: bool = False,
ord: Union[int, float, Literal[inf, -inf]] = 2,
ord: Union[int, float, Literal[inf, -inf]] = 2, # type: ignore
) -> array:
r"""
Computes the vector norm of a vector (or batch of vectors) ``x``.
Expand Down

0 comments on commit f3c9eb4

Please sign in to comment.