Skip to content

Add precommit #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jul 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ jobs:
python-version: ${{ matrix.python }}
- name: Install project dependencies
run: pip install -e ".[dev]"
- name: Test code's formatting (Black)
run: black --check docs tests src/xarray_dataclasses
- name: Test code's typing (Pyright)
run: pyright docs tests src/xarray_dataclasses
- name: Test code's execution (pytest)
run: pytest -v tests
- name: Test docs' building (Sphinx)
Expand Down
25 changes: 25 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-added-large-files
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.5
hooks:
- id: ruff-check
args: [--fix] # optional, to autofix lint errors
- id: ruff-format
- repo: https://github.com/fsouza/mirrors-pyright
rev: v1.1.403
hooks:
- id: pyright
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.24.1
hooks:
- id: validate-pyproject
name: "python · Validate pyproject.toml"
additional_dependencies: ["validate-pyproject-schema-store[all]"]
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Important Note
This repository was a fork from [here](https://github.com/astropenguin/xarray-dataclasses). We are grateful for the
work of the developer on this repo. That being said, sadly the state on main has been the case that the code was
deleted and there has been no development for a while. Therefore, we intially decided to fork the repository and
This repository was a fork from [here](https://github.com/astropenguin/xarray-dataclasses). We are grateful for the
work of the developer on this repo. That being said, sadly the state on main has been the case that the code was
deleted and there has been no development for a while. Therefore, we intially decided to fork the repository and
continue development here, where the community is better able to contribute to and maintain the project. We now changed
it into a standalone repository.

Expand Down Expand Up @@ -392,7 +392,7 @@ pixi run pyright
```

### Creating documentation
We also have a [documentation workflow] (Add link). However, if you want to locally create the documentation
We also have a [documentation workflow] (Add link). However, if you want to locally create the documentation
run the following:

```shell
Expand Down
2 changes: 1 addition & 1 deletion docs/_static/logo-dark.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/_static/logo-light.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1,826 changes: 934 additions & 892 deletions pixi.lock

Large diffs are not rendered by default.

18 changes: 8 additions & 10 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,14 @@ Documentation = "https://xarray-contrib.github.io/xarray-dataclasses/"

[project.optional-dependencies]
dev = [
"black>=24.4",
"build>=1.0.0",
"flake8>=7.0.0",
"ipython>=8.12",
"myst-parser>=3.0",
"pydata-sphinx-theme>=0.14",
"pyright>=1.1",
"pytest>=8.2",
"sphinx>=7.1",
"twine>=4.0.0",
"pre-commit>=4.2.0,<5"
]

[tool.hatch.build]
Expand All @@ -72,16 +70,16 @@ dev = { features = ["dev"], solve-group = "default" }

[tool.pixi.feature.dev.tasks]
tests = "pytest"
flake8 = "flake8 docs tests src/xarray_dataclasses"
black = "black docs tests src/xarray_dataclasses"
doc_build = { cmd = "sphinx-apidoc -efT -o docs/_apidoc src/xarray_dataclasses && sphinx-build -a docs docs/_build" }
pyright = "pyright docs tests src/xarray_dataclasses"

[tool.pyright]
reportImportCycles = "warning"
reportUnknownArgumentType = "warning"
reportUnknownMemberType = "warning"
reportUnknownVariableType = "warning"
reportMissingImports = "none"
reportImportCycles = "none"
reportUnknownArgumentType = "none"
reportUnknownMemberType = "none"
reportUnknownVariableType = "none"
reportUnknownParameterType = "none"
reportUntypedFunctionDecorator = "none"
typeCheckingMode = "strict"

[tool.pixi.pypi-dependencies]
Expand Down
17 changes: 12 additions & 5 deletions src/xarray_dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,15 @@
from . import dataoptions
from . import typing
from .__about__ import __version__
from .dataarray import *
from .dataset import *
from .datamodel import *
from .dataoptions import *
from .typing import *
from .dataarray import AsDataArray, asdataarray
from .dataset import AsDataset, asdataset
from .datamodel import DataModel
from .dataoptions import DataOptions
from .typing import (
Attr,
Coord,
Coordof,
Data,
Dataof,
Name,
)
27 changes: 14 additions & 13 deletions src/xarray_dataclasses/datamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import (
Any,
Dict,
get_type_hints,
Hashable,
List,
Literal,
Expand All @@ -22,7 +23,7 @@
# dependencies
import numpy as np
import xarray as xr
from typing_extensions import ParamSpec, get_type_hints
from typing_extensions import ParamSpec


# submodules
Expand Down Expand Up @@ -133,7 +134,7 @@ def __post_init__(self) -> None:
if model.names:
setattr(self, "name", model.names[0].value)

def __call__(self, reference: Optional[AnyXarray] = None) -> xr.DataArray:
def __call__(self, reference: Optional[AnyXarray] = None) -> xr.DataArray: # pyright: ignore[reportUnknownParameterType]
"""Create a DataArray object according to the entry."""
from .dataarray import asdataarray

Expand Down Expand Up @@ -187,12 +188,12 @@ def from_dataclass(cls, dataclass: AnyDataClass[PInit]) -> "DataModel":
model = cls()
eval_dataclass(dataclass)

for field in dataclass.__dataclass_fields__.values():
value = getattr(dataclass, field.name, MISSING)
entry = get_entry(field, value)
for field_value in dataclass.__dataclass_fields__.values():
value = getattr(dataclass, field_value.name, MISSING)
entry = get_entry(field_value, value)

if entry is not None:
model.entries[field.name] = entry
model.entries[field_value.name] = entry

return model

Expand All @@ -203,10 +204,10 @@ def eval_dataclass(dataclass: AnyDataClass[PInit]) -> None:
if not is_dataclass(dataclass):
raise TypeError("Not a dataclass or its object.")

fields = dataclass.__dataclass_fields__.values()
field_values = dataclass.__dataclass_fields__.values()

# do nothing if field types are already evaluated
if not any(isinstance(field.type, str) for field in fields):
if not any(isinstance(field_value.type, str) for field_value in field_values):
return

# otherwise, replace field types with evaluated types
Expand All @@ -215,8 +216,8 @@ def eval_dataclass(dataclass: AnyDataClass[PInit]) -> None:

types = get_type_hints(dataclass, include_extras=True)

for field in fields:
field.type = types[field.name]
for field_value in field_values:
field_value.type = types[field_value.name]


def get_entry(field: AnyField, value: Any) -> Optional[AnyEntry]:
Expand Down Expand Up @@ -250,11 +251,11 @@ def get_entry(field: AnyField, value: Any) -> Optional[AnyEntry]:
)


def get_typedarray(
def get_typedarray( # pyright: ignore[reportUnknownParameterType]
data: Any,
dims: Dims,
dtype: Optional[AnyDType],
reference: Optional[AnyXarray] = None,
dtype: Optional[AnyDType], # pyright: ignore[reportUnknownParameterType]
reference: Optional[AnyXarray] = None, # pyright: ignore[reportUnknownParameterType]
) -> xr.DataArray:
"""Create a DataArray object with given dims and dtype.

Expand Down
23 changes: 10 additions & 13 deletions src/xarray_dataclasses/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,15 @@
from enum import Enum
from itertools import chain
from typing import (
Annotated,
Any,
ClassVar,
Collection,
Dict,
Generic,
get_args,
get_origin,
get_type_hints,
Hashable,
Iterable,
Literal,
Expand All @@ -44,14 +48,7 @@
# dependencies
import numpy as np
import xarray as xr
from typing_extensions import (
Annotated,
ParamSpec,
TypeAlias,
get_args,
get_origin,
get_type_hints,
)
from typing_extensions import ParamSpec, TypeAlias


# type hints (private)
Expand All @@ -62,10 +59,10 @@
TDType = TypeVar("TDType", covariant=True)
THashable = TypeVar("THashable", bound=Hashable)

AnyArray: TypeAlias = "np.ndarray[Any, Any]"
AnyDType: TypeAlias = "np.dtype[Any]"
AnyField: TypeAlias = "Field[Any]"
AnyXarray: TypeAlias = "xr.DataArray | xr.Dataset"
AnyArray: TypeAlias = np.ndarray[Any, Any]
AnyDType: TypeAlias = np.dtype[Any]
AnyField: TypeAlias = Field[Any]
AnyXarray: TypeAlias = Union[xr.DataArray, xr.Dataset]
Dims = Tuple[str, ...]
Order = Literal["C", "F"]
Shape = Union[Sequence[int], int]
Expand Down Expand Up @@ -325,7 +322,7 @@ def get_dims(tp: Any) -> Dims:
return tuple(str(get_args(arg)[0]) for arg in args)


def get_dtype(tp: Any) -> Optional[AnyDType]:
def get_dtype(tp: Any) -> Optional[AnyDType]: # pyright: ignore[reportUnknownParameterType]
"""Extract a NumPy data type (dtype)."""
try:
dtype = get_args(get_args(get_annotated(tp))[1])[0]
Expand Down
8 changes: 2 additions & 6 deletions tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from dataclasses import dataclass
from typing import Literal, Tuple


# dependencies
import numpy as np
import xarray as xr
from typing_extensions import TypeAlias


# submodules
Expand All @@ -24,11 +24,7 @@
Y = Literal["y"]


# dataclasses
class Custom(xr.DataArray):
"""Custom DataArray."""

__slots__ = ()
Custom: TypeAlias = xr.DataArray


@dataclass
Expand Down
22 changes: 11 additions & 11 deletions tests/test_datamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_xaxis_attr() -> None:
assert units.tag == "attr"
assert units.type is str
assert units.value == "pixel"
assert units.cast == False
assert units.cast is False


def test_xaxis_data() -> None:
Expand All @@ -64,7 +64,7 @@ def test_xaxis_data() -> None:
assert data.dims == ("x",)
assert data.dtype == "int"
assert data.base is None
assert data.cast == True
assert data.cast is True


def test_yaxis_attr() -> None:
Expand All @@ -73,7 +73,7 @@ def test_yaxis_attr() -> None:
assert units.tag == "attr"
assert units.type is str
assert units.value == "pixel"
assert units.cast == False
assert units.cast is False


def test_yaxis_data() -> None:
Expand All @@ -83,7 +83,7 @@ def test_yaxis_data() -> None:
assert data.dims == ("y",)
assert data.dtype == "int"
assert data.base is None
assert data.cast == True
assert data.cast is True


def test_image_coord() -> None:
Expand All @@ -93,23 +93,23 @@ def test_image_coord() -> None:
assert mask.dims == ("x", "y")
assert mask.dtype == "bool"
assert mask.base is None
assert mask.cast == True
assert mask.cast is True

x = image_model.coords[1]
assert x.name == "x"
assert x.tag == "coord"
assert x.dims == ("x",)
assert x.dtype == "int"
assert x.base is XAxis
assert x.cast == True
assert x.cast is True

y = image_model.coords[2]
assert y.name == "y"
assert y.tag == "coord"
assert y.dims == ("y",)
assert y.dtype == "int"
assert y.base is YAxis
assert y.cast == True
assert y.cast is True


def test_image_data() -> None:
Expand All @@ -119,7 +119,7 @@ def test_image_data() -> None:
assert data.dims == ("x", "y")
assert data.dtype == "float"
assert data.base is None
assert data.cast == True
assert data.cast is True


def test_color_data() -> None:
Expand All @@ -129,20 +129,20 @@ def test_color_data() -> None:
assert red.dims == ("x", "y")
assert red.dtype == "float"
assert red.base is Image
assert red.cast == True
assert red.cast is True

green = color_model.data_vars[1]
assert green.name == "green"
assert green.tag == "data"
assert green.dims == ("x", "y")
assert green.dtype == "float"
assert green.base is Image
assert green.cast == True
assert green.cast is True

blue = color_model.data_vars[2]
assert blue.name == "blue"
assert blue.tag == "data"
assert blue.dims == ("x", "y")
assert blue.dtype == "float"
assert blue.base is Image
assert blue.cast == True
assert blue.cast is True
5 changes: 2 additions & 3 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# dependencies
import numpy as np
import xarray as xr
from typing_extensions import TypeAlias


# submodules
Expand All @@ -24,9 +25,7 @@
Y = Literal["y"]


# dataclasses
class Custom(xr.Dataset):
__slots__ = ()
Custom: TypeAlias = xr.Dataset


@dataclass
Expand Down