diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b065ed7..0bb29c7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -50,7 +50,7 @@ jobs: allow-prereleases: true - name: Install requirementes - run: python -m pip install -r requirements-dev.txt + run: python -m pip install -r requirements/dev.txt - name: Install package run: python -m pip install . diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 586cb6f..ad3a4ce 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -52,11 +52,10 @@ repos: - id: debug-statements - id: end-of-file-fixer - id: mixed-line-ending - - id: requirements-txt-fixer - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.8 + rev: v0.4.9 hooks: - id: ruff args: [--fix, --show-fixes] @@ -69,10 +68,11 @@ repos: files: src|tests args: [--no-install-types] additional_dependencies: - - dace==0.15.1 - - jax[cpu]==0.4.28 - - numpy==1.26.4 - - pytest==8.2.1 + - dace==0.16 + - jax[cpu]==0.4.29 + - numpy==2.0.0 + - pytest==8.2.2 + - typing-extensions==4.12.2 - repo: https://github.com/codespell-project/codespell rev: v2.3.0 hooks: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 19d5adb..555d9d5 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -9,6 +9,7 @@ The fastest way to start with development is to use nox. If you don't have nox, To use, run `nox`. This will lint and test using every installed version of Python on your system, skipping ones that are not installed. You can also run specific jobs: ```console +$ nox -s venv-3.10 # (or venv-3.11, or venv-3.12) Setup a fully working development environment $ nox -s lint # Lint only $ nox -s tests # Python tests $ nox -s docs -- --serve # Build and serve the docs @@ -25,16 +26,16 @@ You can set up a development environment by running: python3 -m venv .venv source ./.venv/bin/activate pip install --upgrade pip setuptools wheel -pip install -r requirements-dev.txt +pip install -r requirements/dev.txt pip install -v -e . ``` -If you have the [Python Launcher for Unix](https://github.com/brettcannon/python-launcher), you can instead do: +Or, if you have the [Python Launcher for Unix](https://github.com/brettcannon/python-launcher), you could do: ```bash py -m venv .venv py -m pip install --upgrade pip setuptools wheel -py -m pip install -r requirements-dev.txt +py -m pip install -r requirements/dev.txt py -m pip install -v -e . ``` @@ -43,7 +44,7 @@ py -m pip install -v -e . You should prepare pre-commit, which will help you by checking that commits pass required checks: ```bash -pip install pre-commit # or brew install pre-commit on macOS +pipx install pre-commit # or brew install pre-commit on macOS pre-commit install # Will install a pre-commit hook into the git repo ``` diff --git a/noxfile.py b/noxfile.py index 6c53e26..b6aec1b 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,22 +1,45 @@ +"""Nox session definitions.""" + from __future__ import annotations import argparse +import pathlib +import re import shutil -from pathlib import Path import nox -DIR = Path(__file__).parent.resolve() - nox.needs_version = ">=2024.3.2" -nox.options.sessions = ["lint", "pylint", "tests"] +nox.options.sessions = ["lint", "tests"] nox.options.default_venv_backend = "uv|virtualenv" -@nox.session +ROOT_DIR = pathlib.Path(__file__).parent.resolve() +DEFAULT_DEV_VENV_PATH = ROOT_DIR / ".venv" + + +def load_from_frozen_requirements(filename: str) -> dict[str, str]: + requirements = {} + with pathlib.Path(filename).open(encoding="locale") as f: + for raw_line in f: + if (end := raw_line.find("#")) != -1: + raw_line = raw_line[:end] # noqa: PLW2901 [redefined-loop-name] + line = raw_line.strip() + if line and not line.startswith("-"): + m = re.match(r"^([^=]*)\s*([^;]*)\s*;?\s*(.*)$", line) + if m: + requirements[m[1]] = m[2] + + return requirements + + +REQUIREMENTS = load_from_frozen_requirements(ROOT_DIR / "requirements" / "dev.txt") + + +@nox.session(python="3.10") def lint(session: nox.Session) -> None: - """Run the linter.""" + """Run the linter (pre-commit).""" session.install("pre-commit") session.run("pre-commit", "run", "--all-files", "--show-diff-on-failure", *session.posargs) @@ -24,13 +47,109 @@ def lint(session: nox.Session) -> None: @nox.session def tests(session: nox.Session) -> None: """Run the unit and regular tests.""" - session.install(".[test]") + session.install("-e", ".", "-r", "requirements/dev.txt") session.run("pytest", *session.posargs) +@nox.session(python=["3.10", "3.11", "3.12"]) +def venv(session: nox.Session) -> None: + """ + Sets up a Python development environment. Use as: `nox -s venv-3.xx -- [req_preset] [dest_path] + + req_preset: The requirements file to use as 'requirements/{req_preset}.txt'. + Default: 'dev' + dest_path (optional): The path to the virtualenv to create. + Default: '.venv-{3.xx}-{req_preset}' + + This session will: + - Create a python virtualenv for the session + - Install the `virtualenv` cli tool into this environment + - Use `virtualenv` to create a project virtual environment + - Invoke the python interpreter from the created project environment + to install the project and all it's development dependencies. + """ # noqa: W505 [doc-line-too-long] + req_preset = "dev" + venv_path = None + virtualenv_args = [] + if session.posargs: + req_preset, *more_pos_args = session.posargs + if more_pos_args: + venv_path, *_ = more_pos_args + if not venv_path: + venv_path = f"{DEFAULT_DEV_VENV_PATH}-{session.python}-{req_preset}" + venv_path = pathlib.Path(venv_path).resolve() + + if not venv_path.exists(): + print(f"Creating virtualenv at '{venv_path}' (options: {virtualenv_args})...") + session.install("virtualenv") + session.run("virtualenv", venv_path, silent=True) + elif venv_path.exists(): + assert ( + venv_path.is_dir() and (venv_path / "bin" / f"python{session.python}").exists + ), f"'{venv_path}' path already exists but is not a virtualenv with python{session.python}." + print(f"'{venv_path}' path already exists. Skipping virtualenv creation...") + + python_path = venv_path / "bin" / "python" + requirements_file = f"requirements/{req_preset}.txt" + + # Use the venv's interpreter to install the project along with + # all it's dev dependencies, this ensures it's installed in the right way + print(f"Setting up development environment from '{requirements_file}'...") + session.run( + python_path, + "-m", + "pip", + "install", + "-r", + requirements_file, + "-e.", + external=True, + ) + + +@nox.session(reuse_venv=True) +def requirements(session: nox.Session) -> None: + """Freeze requirements files from project specification and synchronize versions across tools.""" # noqa: W505 [doc-line-too-long] + requirements_path = ROOT_DIR / "requirements" + req_sync_tool = requirements_path / "sync_tool.py" + + dependencies = ["pre-commit"] + nox.project.load_toml(req_sync_tool)["dependencies"] + session.install(*dependencies) + session.install("pip-compile-multi") + + session.run("python", req_sync_tool, "pull") + session.run("pip-compile-multi", "-g", "--skip-constraints") + session.run("python", req_sync_tool, "push") + + session.run("pre-commit", "run", "--files", ".pre-commit-config.yaml", success_codes=[0, 1]) + + @nox.session(reuse_venv=True) def docs(session: nox.Session) -> None: - """Build the docs. Pass "--serve" to serve. Pass "-b linkcheck" to check links.""" + """Regenerate and build all API and user docs.""" + session.notify("api_docs") + session.notify("user_docs", posargs=session.posargs) + + +@nox.session(reuse_venv=True) +def api_docs(session: nox.Session) -> None: + """Build (regenerate) API docs.""" + session.install(f"sphinx=={REQUIREMENTS['sphinx']}") + session.chdir("docs") + session.run( + "sphinx-apidoc", + "-o", + "api/", + "--module-first", + "--no-toc", + "--force", + "../src/jace", + ) + + +@nox.session(reuse_venv=True) +def user_docs(session: nox.Session) -> None: + """Build the user docs. Pass "--serve" to serve. Pass "-b linkcheck" to check links.""" # noqa: W505 [doc-line-too-long] parser = argparse.ArgumentParser() parser.add_argument("--serve", action="store_true", help="Serve after building") parser.add_argument("-b", dest="builder", default="html", help="Build target (default: html)") @@ -40,8 +159,7 @@ def docs(session: nox.Session) -> None: session.error("Must not specify non-HTML builder with --serve") extra_installs = ["sphinx-autobuild"] if args.serve else [] - - session.install("-e.[docs]", *extra_installs) + session.install("-e", ".", "-r", "requirements/dev.txt", *extra_installs) session.chdir("docs") if args.builder == "linkcheck": @@ -63,28 +181,12 @@ def docs(session: nox.Session) -> None: session.run("sphinx-build", "--keep-going", *shared_args) -@nox.session -def build_api_docs(session: nox.Session) -> None: - """Build (regenerate) API docs.""" - session.install("sphinx") - session.chdir("docs") - session.run( - "sphinx-apidoc", - "-o", - "api/", - "--module-first", - "--no-toc", - "--force", - "../src/jace", - ) - - @nox.session def build(session: nox.Session) -> None: """Build an SDist and wheel.""" - build_path = DIR.joinpath("build") + build_path = ROOT_DIR / "build" if build_path.exists(): shutil.rmtree(build_path) - session.install("build") + session.install(f"build=={REQUIREMENTS['build']}") session.run("python", "-m", "build") diff --git a/pyproject.toml b/pyproject.toml index 37c0d3d..0add471 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ "Typing :: Typed", ] dependencies = [ - "dace>=0.15", + "dace>=0.16", "jax[cpu]>=0.4.24", "numpy>=1.26.0", ] @@ -211,7 +211,7 @@ tests = [ max-complexity = 12 [tool.ruff.lint.per-file-ignores] -"!tests/**.py" = ["PT"] # Ignore flake8-pytest-style outside 'tests/' +"!tests/**" = ["PT"] # Ignore flake8-pytest-style outside 'tests/' "docs/**" = [ "D", # pydocstyle "T10", # flake8-debugger @@ -219,6 +219,12 @@ max-complexity = 12 ] "noxfile.py" = [ "D", # pydocstyle + "T10", # flake8-debugger + "T20", # flake8-print +] +"requirements/**" = [ + "D", # pydocstyle + "T10", # flake8-debugger "T20", # flake8-print ] "tests/**" = [ diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index a7a822e..0000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,11 +0,0 @@ -furo>=2023.08.17 -mypy >= 1.9.0 -myst_parser>=0.13 -pytest >=6 -pytest-cov >=3 -ruff >= 0.3.5 -sphinx>=7.0 -sphinx_autodoc_typehints -sphinx_copybutton -types-all -typing-extensions>=4.10.0 diff --git a/requirements/base.in b/requirements/base.in new file mode 100644 index 0000000..9fee484 --- /dev/null +++ b/requirements/base.in @@ -0,0 +1,3 @@ +dace>=0.16 +jax[cpu]>=0.4.24 +numpy>=1.26.0 diff --git a/requirements/base.txt b/requirements/base.txt new file mode 100644 index 0000000..fb784ca --- /dev/null +++ b/requirements/base.txt @@ -0,0 +1,69 @@ +# SHA1:50585cb1d4e4cc2297a939939d360c886c4ee3e4 +# +# This file is autogenerated by pip-compile-multi +# To update, run: +# +# pip-compile-multi +# +aenum==3.1.15 + # via dace +astunparse==1.6.3 + # via dace +dace==0.16 + # via -r requirements/base.in +dill==0.3.8 + # via dace +fparser==0.1.4 + # via dace +jax[cpu]==0.4.29 + # via -r requirements/base.in +jaxlib==0.4.29 + # via jax +jinja2==3.1.4 + # via dace +markupsafe==2.1.5 + # via jinja2 +ml-dtypes==0.4.0 + # via + # jax + # jaxlib +mpmath==1.3.0 + # via sympy +networkx==3.3 + # via dace +numpy==2.0.0 + # via + # -r requirements/base.in + # dace + # jax + # jaxlib + # ml-dtypes + # opt-einsum + # scipy +opt-einsum==3.3.0 + # via jax +packaging==24.1 + # via setuptools-scm +ply==3.11 + # via dace +pyyaml==6.0.1 + # via dace +scipy==1.13.1 + # via + # jax + # jaxlib +setuptools-scm==8.1.0 + # via fparser +six==1.16.0 + # via astunparse +sympy==1.12.1 + # via dace +tomli==2.0.1 + # via setuptools-scm +websockets==12.0 + # via dace +wheel==0.43.0 + # via astunparse + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/requirements/cuda12.in b/requirements/cuda12.in new file mode 100644 index 0000000..d603a3b --- /dev/null +++ b/requirements/cuda12.in @@ -0,0 +1,4 @@ +-r base.in +cupy-cuda12x>=12.1.0 +jax[cuda12]>=0.4.24 +optuna>=3.4.0 diff --git a/requirements/cuda12.txt b/requirements/cuda12.txt new file mode 100644 index 0000000..ebeb3aa --- /dev/null +++ b/requirements/cuda12.txt @@ -0,0 +1,72 @@ +# SHA1:035352ab483a9ee349c593a1ff7f359a88012cc9 +# +# This file is autogenerated by pip-compile-multi +# To update, run: +# +# pip-compile-multi +# +-r base.txt +alembic==1.13.1 + # via optuna +colorlog==6.8.2 + # via optuna +cupy-cuda12x==13.2.0 + # via -r requirements/cuda12.in +fastrlock==0.8.2 + # via cupy-cuda12x +greenlet==3.0.3 + # via sqlalchemy +jax[cpu,cuda12]==0.4.29 + # via + # -r requirements/base.in + # -r requirements/cuda12.in +jax-cuda12-pjrt==0.4.29 + # via jax-cuda12-plugin +jax-cuda12-plugin==0.4.29 + # via jax +mako==1.3.5 + # via alembic +nvidia-cublas-cu12==12.5.2.13 + # via + # jax + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.5.39 + # via jax +nvidia-cuda-nvcc-cu12==12.5.40 + # via jax +nvidia-cuda-runtime-cu12==12.5.39 + # via jax +nvidia-cudnn-cu12==9.1.1.17 + # via jax +nvidia-cufft-cu12==11.2.3.18 + # via jax +nvidia-cusolver-cu12==11.6.2.40 + # via jax +nvidia-cusparse-cu12==12.4.1.24 + # via + # jax + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.21.5 + # via jax +nvidia-nvjitlink-cu12==12.5.40 + # via + # jax + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 +optuna==3.6.1 + # via -r requirements/cuda12.in +sqlalchemy==2.0.30 + # via + # alembic + # optuna +tqdm==4.66.4 + # via optuna +typing-extensions==4.12.2 + # via + # alembic + # sqlalchemy + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/requirements/dev-cuda12.in b/requirements/dev-cuda12.in new file mode 100644 index 0000000..496e623 --- /dev/null +++ b/requirements/dev-cuda12.in @@ -0,0 +1,2 @@ +-r cuda12.in +-r dev.in diff --git a/requirements/dev-cuda12.txt b/requirements/dev-cuda12.txt new file mode 100644 index 0000000..0dca1e7 --- /dev/null +++ b/requirements/dev-cuda12.txt @@ -0,0 +1,12 @@ +# SHA1:bdbfa7e1d9b9ca837d092c4efc6792c2b58238be +# +# This file is autogenerated by pip-compile-multi +# To update, run: +# +# pip-compile-multi +# +-r cuda12.txt +-r dev.txt + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/requirements/dev.in b/requirements/dev.in new file mode 100644 index 0000000..4421d27 --- /dev/null +++ b/requirements/dev.in @@ -0,0 +1,14 @@ +-r base.in +build>=1.2 +furo>=2023.08.17 +mypy>=1.9.0 +myst_parser>=0.13 +pytest>=6 +pytest-cov>=3 +ruff>=0.3.5 +sphinx>=7.0 +sphinx-autobuild>=2021.3.14 +sphinx_autodoc_typehints>=2.1 +sphinx_copybutton>=0.5 +tomlkit>=0.12.4 +typing-extensions>=4.10.0 diff --git a/requirements/dev.txt b/requirements/dev.txt new file mode 100644 index 0000000..b73bf83 --- /dev/null +++ b/requirements/dev.txt @@ -0,0 +1,136 @@ +# SHA1:60e060370596513d7e06534a0655974dcc750dcd +# +# This file is autogenerated by pip-compile-multi +# To update, run: +# +# pip-compile-multi +# +-r base.txt +alabaster==0.7.16 + # via sphinx +anyio==4.4.0 + # via + # starlette + # watchfiles +babel==2.15.0 + # via sphinx +beautifulsoup4==4.12.3 + # via furo +build==1.2.1 + # via -r requirements/dev.in +certifi==2024.6.2 + # via requests +charset-normalizer==3.3.2 + # via requests +click==8.1.7 + # via uvicorn +colorama==0.4.6 + # via sphinx-autobuild +coverage[toml]==7.5.3 + # via pytest-cov +docutils==0.21.2 + # via + # myst-parser + # sphinx +exceptiongroup==1.2.1 + # via + # anyio + # pytest +furo==2024.5.6 + # via -r requirements/dev.in +h11==0.14.0 + # via uvicorn +idna==3.7 + # via + # anyio + # requests +imagesize==1.4.1 + # via sphinx +iniconfig==2.0.0 + # via pytest +markdown-it-py==3.0.0 + # via + # mdit-py-plugins + # myst-parser +mdit-py-plugins==0.4.1 + # via myst-parser +mdurl==0.1.2 + # via markdown-it-py +mypy==1.10.0 + # via -r requirements/dev.in +mypy-extensions==1.0.0 + # via mypy +myst-parser==3.0.1 + # via -r requirements/dev.in +pluggy==1.5.0 + # via pytest +pygments==2.18.0 + # via + # furo + # sphinx +pyproject-hooks==1.1.0 + # via build +pytest==8.2.2 + # via + # -r requirements/dev.in + # pytest-cov +pytest-cov==5.0.0 + # via -r requirements/dev.in +requests==2.32.3 + # via sphinx +ruff==0.4.9 + # via -r requirements/dev.in +sniffio==1.3.1 + # via anyio +snowballstemmer==2.2.0 + # via sphinx +soupsieve==2.5 + # via beautifulsoup4 +sphinx==7.3.7 + # via + # -r requirements/dev.in + # furo + # myst-parser + # sphinx-autobuild + # sphinx-autodoc-typehints + # sphinx-basic-ng + # sphinx-copybutton +sphinx-autobuild==2024.4.16 + # via -r requirements/dev.in +sphinx-autodoc-typehints==2.1.1 + # via -r requirements/dev.in +sphinx-basic-ng==1.0.0b2 + # via furo +sphinx-copybutton==0.5.2 + # via -r requirements/dev.in +sphinxcontrib-applehelp==1.0.8 + # via sphinx +sphinxcontrib-devhelp==1.0.6 + # via sphinx +sphinxcontrib-htmlhelp==2.0.5 + # via sphinx +sphinxcontrib-jsmath==1.0.1 + # via sphinx +sphinxcontrib-qthelp==1.0.7 + # via sphinx +sphinxcontrib-serializinghtml==1.1.10 + # via sphinx +starlette==0.37.2 + # via sphinx-autobuild +tomlkit==0.12.5 + # via -r requirements/dev.in +typing-extensions==4.12.2 + # via + # -r requirements/dev.in + # anyio + # mypy + # uvicorn +urllib3==2.2.1 + # via requests +uvicorn==0.30.1 + # via sphinx-autobuild +watchfiles==0.22.0 + # via sphinx-autobuild + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/requirements/sync_tool.py b/requirements/sync_tool.py new file mode 100644 index 0000000..846cb0e --- /dev/null +++ b/requirements/sync_tool.py @@ -0,0 +1,214 @@ +#! /usr/bin/env python3 + +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "packaging>=24.0", +# "tomlkit>=0.12.4", +# "typer-slim>=0.12.3", +# "yamlpath>=3.8.2" +# ] +# /// + +"""Script to synchronize requirements across tools.""" + +from __future__ import annotations + +import pathlib +import re +import types +from collections.abc import Iterable, Mapping +from typing import NamedTuple, TypeAlias + +import tomlkit +import typer +import yamlpath +from packaging import ( + markers as pkg_markers, + requirements as pkg_requirements, + specifiers as pkg_specifiers, +) + + +# -- Classes -- +class RequirementSpec(NamedTuple): + """A parsed requirement specification.""" + + package: pkg_requirements.Requirement + specifiers: pkg_specifiers.SpecifierSet | None = None + marker: pkg_markers.Marker | None = None + + @classmethod + def from_text(cls, req_text: str) -> RequirementSpec: + req_text = req_text.strip() + assert req_text, "Requirement string cannot be empty" + + m = re.match(r"^([^><=~]*)\s*([^;]*)\s*;?\s*(.*)$", req_text) + return RequirementSpec( + pkg_requirements.Requirement(m[1]), + pkg_specifiers.Specifier(m[2]) if m[2] else None, + pkg_markers.Marker(m[3]) if m[3] else None, + ) + + def as_text(self) -> str: + return f"{self.package!s}{(self.specifiers or '')!s}{(self.marker or '')!s}".strip() + + +class Requirement(NamedTuple): + """An item in a list of requirements and its parsed specification.""" + + text: str + spec: RequirementSpec + + @classmethod + def from_text(cls, req_text: str) -> Requirement: + return Requirement(req_text, RequirementSpec.from_text(req_text)) + + @classmethod + def from_spec(cls, req: RequirementSpec) -> Requirement: + return Requirement(req.as_text(), req) + + def as_text(self, *, template: str | None = None) -> str: + template = template or "{req.text}" + return template.format(req=self) + + +class RequirementDumpSpec(NamedTuple): + value: Requirement | Iterable[Requirement] + template: str | None = None + + +DumpSpec: TypeAlias = ( + RequirementDumpSpec | tuple[Requirement | Iterable[Requirement], str | None] | str +) + + +# -- Functions -- +def make_requirements_map(requirements: Iterable[Requirement]) -> dict[str, Requirement]: + return {req.spec.package.name: req for req in requirements} + + +def load_from_requirements(filename: str) -> list[Requirement]: + requirements = [] + with pathlib.Path(filename).open(encoding="locale") as f: + for raw_line in f: + if (end := raw_line.find("#")) != -1: + raw_line = raw_line[:end] # noqa: PLW2901 [redefined-loop-name] + line = raw_line.strip() + if line and not line.startswith("-"): + requirements.append(Requirement.from_text(line)) + + return requirements + + +def load_from_toml(filename: str, key: str) -> list[Requirement]: + with pathlib.Path(filename).open(encoding="locale") as f: + toml_data = tomlkit.loads(f.read()) + + section = toml_data + for part in key.split("."): + section = section[part] + + return [Requirement.from_text(req) for req in section] + + +def dump(requirements: Iterable[Requirement], *, template: str | None = None) -> None: + return [req.as_text(template=template) for req in requirements] + + +def dump_to_requirements( + requirements: Iterable[Requirement], + filename: str, + *, + template: str | None = None, + header: str | None = None, + footer: str | None = None, +) -> None: + with pathlib.Path(filename).open("w", encoding="locale") as f: + if header: + f.write(f"{header}\n") + f.write("\n".join(dump(requirements, template=template))) + if footer: + f.write(f"{footer}\n") + f.write("\n") + + +def dump_to_yaml(requirements_map: Mapping[str, DumpSpec], filename: str) -> None: + file_path = pathlib.Path(filename) + logging_args = types.SimpleNamespace(quiet=False, verbose=False, debug=False) + console_log = yamlpath.wrappers.ConsolePrinter(logging_args) + yaml = yamlpath.common.Parsers.get_yaml_editor() + (yaml_data, doc_loaded) = yamlpath.common.Parsers.get_yaml_data(yaml, console_log, file_path) + assert doc_loaded + processor = yamlpath.Processor(console_log, yaml_data) + + for key_path, (value, template) in requirements_map.items(): + match value: + case str(): + processor.set_value(yamlpath.YAMLPath(key_path), value) + case Requirement(): + processor.set_value(yamlpath.YAMLPath(key_path), value.as_text(template=template)) + case Iterable(): + for _ in processor.delete_nodes(yamlpath.YAMLPath(key_path)): + pass + for i, req in enumerate(dump(value, template=template)): + item_path = yamlpath.YAMLPath(f"{key_path}[{i}]") + processor.set_value(item_path, req) + + with file_path.open("w") as f: + yaml.dump(yaml_data, f) + + +# -- CLI -- +app = typer.Typer() + + +@app.command() +def pull(): + base = load_from_toml("pyproject.toml", "project.dependencies") + dump_to_requirements(base, "requirements/base.in") + cuda12 = load_from_toml("pyproject.toml", "project.optional-dependencies.cuda12") + dump_to_requirements(cuda12, "requirements/cuda12.in", header="-r base.in") + + +@app.command() +def push(): + base_names = {r.spec.package for r in load_from_toml("pyproject.toml", "project.dependencies")} + base_versions = [ + r for r in load_from_requirements("requirements/base.txt") if r.spec.package in base_names + ] + dev_versions_map = make_requirements_map(load_from_requirements("requirements/dev.txt")) + mypy_req_versions = sorted( + base_versions + [dev_versions_map[r] for r in ("pytest", "typing-extensions")], + key=lambda r: str(r.spec.package), + ) + dump_to_yaml( + { + # ruff + "repos[.repo%https://github.com/astral-sh/ruff-pre-commit].rev": ( + dev_versions_map["ruff"], + "v{req.spec.specifiers.version}", + ), + # mypy + "repos[.repo%https://github.com/pre-commit/mirrors-mypy].rev": ( + dev_versions_map["mypy"], + "v{req.spec.specifiers.version}", + ), + "repos[.repo%https://github.com/pre-commit/mirrors-mypy].hooks[.id%mypy].additional_dependencies": ( + mypy_req_versions, + None, + ), + }, + ".pre-commit-config.yaml", + ) + + +if __name__ == "__main__": + app()