diff --git a/CHANGELOG.md b/CHANGELOG.md index a6082e9..098fcbb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,18 @@ and this project adheres to [Semantic Versioning][]. [keep a changelog]: https://keepachangelog.com/en/1.0.0/ [semantic versioning]: https://semver.org/spec/v2.0.0.html +## [Unrealeased] + +### Added + +### Changed + +- Encoding of `treedata` attributes in h5ad and zarr files. `label`, `allow_overlap`, `obst`, and `vart` are now separate fields in the file. (#31) + +### Fixed + +- `TreeData` objects with `.raw` specified can now be read (#31) + ## [0.0.4] - 2024-09-02 ### Added diff --git a/src/treedata/_core/read.py b/src/treedata/_core/read.py index 57ea6d5..6a804c8 100755 --- a/src/treedata/_core/read.py +++ b/src/treedata/_core/read.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from collections.abc import MutableMapping, Sequence +from collections.abc import MutableMapping from pathlib import Path from typing import ( Literal, @@ -9,37 +9,84 @@ import anndata as ad import h5py +import networkx as nx import zarr -from scipy import sparse -from treedata._core.aligned_mapping import AxisTrees from treedata._core.treedata import TreeData -from treedata._utils import dict_to_digraph -def _tdata_from_adata(tdata, treedata_attrs=None) -> TreeData: - """Create a TreeData object parsing attribute from AnnData uns field.""" - tdata.__class__ = TreeData +def _dict_to_digraph(graph_dict: dict) -> nx.DiGraph: + """Convert a dictionary to a networkx.DiGraph.""" + G = nx.DiGraph() + # Add nodes and their attributes + for node, attrs in graph_dict["nodes"].items(): + G.add_node(node, **attrs) + # Add edges and their attributes + for source, targets in graph_dict["edges"].items(): + for target, attrs in targets.items(): + G.add_edge(source, target, **attrs) + return G + + +def _parse_axis_trees(data: str) -> dict: + """Parse AxisTrees from a string.""" + return {k: _dict_to_digraph(v) for k, v in json.loads(data).items()} + + +def _parse_legacy(treedata_attrs: dict) -> dict: + """Parse tree attributes from AnnData uns field.""" if treedata_attrs is not None: - tdata._tree_label = treedata_attrs["label"] if "label" in treedata_attrs.keys() else None - tdata._allow_overlap = bool(treedata_attrs["allow_overlap"]) - tdata._obst = AxisTrees(tdata, 0, vals={k: dict_to_digraph(v) for k, v in treedata_attrs["obst"].items()}) - tdata._vart = AxisTrees(tdata, 1, vals={k: dict_to_digraph(v) for k, v in treedata_attrs["vart"].items()}) + for j in ["obst", "vart"]: + if j in treedata_attrs: + treedata_attrs[j] = {k: _dict_to_digraph(v) for k, v in treedata_attrs[j].items()} + treedata_attrs["allow_overlap"] = bool(treedata_attrs["allow_overlap"]) + treedata_attrs["label"] = treedata_attrs["label"] if "label" in treedata_attrs.keys() else None + return treedata_attrs + + +def _read_raw(f, backed): + """Read raw from file.""" + d = {} + for k in ["obs", "var"]: + if f"raw/{k}" in f: + d[k] = ad.experimental.read_elem(f[f"raw/{k}"]) + if not backed: + d["X"] = ad.experimental.read_elem(f["raw/X"]) + return d + + +def _read_tdata(f, filename, backed) -> dict: + """Read TreeData from file.""" + d = {} + if backed is None: + backed = False + elif backed is True: + backed = "r" + # Read X if not backed + if not backed: + d["X"] = ad.experimental.read_elem(f["X"]) else: - tdata._tree_label = None - tdata._allow_overlap = False - tdata._obst = AxisTrees(tdata, 0) - tdata._vart = AxisTrees(tdata, 1) - return tdata + d.update({"filename": filename, "filemode": backed}) + # Read standard elements + for k in ["obs", "var", "obsm", "varm", "obsp", "varp", "layers", "uns", "label", "allow_overlap"]: + if k in f: + d[k] = ad.experimental.read_elem(f[k]) + # Read raw + if "raw" in f: + d["raw"] = _read_raw(f, backed) + # Read axis tree elements + for k in ["obst", "vart"]: + if k in f: + d[k] = _parse_axis_trees(ad.experimental.read_elem(f[k])) + # Read legacy treedata format + if "raw.treedata" in f: + d.update(_parse_legacy(json.loads(ad.experimental.read_elem(f["raw.treedata"])))) + return d def read_h5ad( filename: str | Path = None, backed: Literal["r", "r+"] | bool | None = None, - *, - as_sparse: Sequence[str] = (), - as_sparse_fmt: type[sparse.spmatrix] = sparse.csr_matrix, - chunk_size: int = 6000, ) -> TreeData: """Read `.h5ad`-formatted hdf5 file. @@ -52,33 +99,10 @@ def read_h5ad( instead of fully loading it into memory (`memory` mode). If you want to modify backed attributes of the TreeData object, you need to choose `'r+'`. - as_sparse - If an array was saved as dense, passing its name here will read it as - a sparse_matrix, by chunk of size `chunk_size`. - as_sparse_fmt - Sparse format class to read elements from `as_sparse` in as. - chunk_size - Used only when loading sparse dataset that is stored as dense. - Loading iterates through chunks of the dataset of this row size - until it reads the whole dataset. - Higher size means higher memory consumption and higher (to a point) - loading speed. """ - adata = ad.read_h5ad( - filename, - backed=backed, - as_sparse=as_sparse, - as_sparse_fmt=as_sparse_fmt, - chunk_size=chunk_size, - ) with h5py.File(filename, "r") as f: - if "raw.treedata" in f: - treedata_attrs = json.loads(f["raw.treedata"][()]) - else: - treedata_attrs = None - tdata = _tdata_from_adata(adata, treedata_attrs) - - return tdata + d = _read_tdata(f, filename, backed) + return TreeData(**d) def read_zarr(store: str | Path | MutableMapping | zarr.Group) -> TreeData: @@ -89,13 +113,6 @@ def read_zarr(store: str | Path | MutableMapping | zarr.Group) -> TreeData: store The filename, a :class:`~typing.MutableMapping`, or a Zarr storage class. """ - adata = ad.read_zarr(store) - with zarr.open(store, mode="r") as f: - if "raw.treedata" in f: - treedata_attrs = json.loads(f["raw.treedata"][()]) - else: - treedata_attrs = None - tdata = _tdata_from_adata(adata, treedata_attrs) - - return tdata + d = _read_tdata(f, store, backed=False) + return TreeData(**d) diff --git a/src/treedata/_core/treedata.py b/src/treedata/_core/treedata.py index d9ff34d..ebae38b 100755 --- a/src/treedata/_core/treedata.py +++ b/src/treedata/_core/treedata.py @@ -1,6 +1,5 @@ from __future__ import annotations -import json from collections.abc import Iterable, Mapping, MutableMapping, Sequence from copy import deepcopy from pathlib import Path @@ -11,17 +10,12 @@ ) import anndata as ad -import h5py import networkx as nx import numpy as np import pandas as pd -import zarr from anndata._core.index import _subset -from anndata._io import write_h5ad, write_zarr from scipy import sparse -from treedata._utils import digraph_to_dict, make_serializable - from .aligned_mapping import ( AxisTrees, ) @@ -181,8 +175,14 @@ def _init_as_actual( # init from scratch else: - self._tree_label = label - self._allow_overlap = allow_overlap + if isinstance(label, str) or label is None: + self._tree_label = label + else: + raise ValueError("label has to be a string or None") + if isinstance(allow_overlap, bool) or isinstance(allow_overlap, np.bool_): + self._allow_overlap = bool(allow_overlap) + else: + raise ValueError("allow_overlap has to be a boolean") self._obst = AxisTrees(self, 0, vals=obst) self._vart = AxisTrees(self, 1, vals=vart) @@ -281,16 +281,6 @@ def to_adata(self) -> ad.AnnData: """Convert this TreeData object to an AnnData object.""" return ad.AnnData(self) - def _treedata_attrs(self) -> dict: - """Dictionary of TreeData attributes""" - attrs = { - "obst": {k: digraph_to_dict(v) for k, v in self.obst.items()}, - "vart": {k: digraph_to_dict(v) for k, v in self.vart.items()}, - "label": self.label, - "allow_overlap": self.allow_overlap, - } - return make_serializable(attrs) - def _mutated_copy(self, **kwargs): """Creating TreeData with attributes optionally specified via kwargs.""" if self.isbacked: @@ -366,7 +356,7 @@ def write_h5ad( filename: PathLike | None = None, compression: Literal["gzip", "lzf"] | None = None, compression_opts: int | Any = None, - as_dense: Sequence[str] = (), + **kwargs, ): """Write `.h5ad`-formatted hdf5 file. @@ -378,27 +368,18 @@ def write_h5ad( [`lzf`, `gzip`], see the h5py :ref:`dataset_compression`. compression_opts [`lzf`, `gzip`], see the h5py :ref:`dataset_compression`. - as_dense - Sparse arrays in TreeData object to write as dense. Currently only - supports `X` and `raw/X`. """ + from .write import write_h5ad + if filename is None and not self.isbacked: raise ValueError("Provide a filename!") if filename is None: filename = self.filename - write_h5ad( - Path(filename), - self, - compression=compression, - compression_opts=compression_opts, - as_dense=as_dense, - ) + write_h5ad(Path(filename), self, compression=compression, compression_opts=compression_opts) - with h5py.File(filename, "a") as f: - if "raw.treedata" in f: - del f["raw.treedata"] - f.create_dataset("raw.treedata", data=json.dumps(self._treedata_attrs())) + if self.isbacked: + self.file.filename = filename write = write_h5ad # a shortcut and backwards compat @@ -406,6 +387,7 @@ def write_zarr( self, store: MutableMapping | PathLike, chunks: bool | int | tuple[int, ...] | None = None, + **kwargs, ): """Write a hierarchical Zarr array store. @@ -416,12 +398,9 @@ def write_zarr( chunks Chunk shape. """ - write_zarr(store, self.to_adata(), chunks=chunks) + from .write import write_zarr - with zarr.open(store, mode="a") as f: - if "treedata" in f: - del f["raw.treedata"] - f.create_dataset("raw.treedata", data=json.dumps(self._treedata_attrs())) + write_zarr(Path(store), self, chunks=chunks) def to_memory(self, copy=False) -> TreeData: """Return a new AnnData object with all backed arrays loaded into memory. diff --git a/src/treedata/_core/write.py b/src/treedata/_core/write.py new file mode 100755 index 0000000..39548b2 --- /dev/null +++ b/src/treedata/_core/write.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import ( + Any, + Literal, +) + +import anndata as ad +import h5py +import networkx as nx +import numpy as np +import pandas as pd +import zarr + +from treedata._core.aligned_mapping import AxisTrees +from treedata._core.treedata import TreeData + + +def _make_serializable(data: dict) -> dict: + """Make a dictionary serializable.""" + if isinstance(data, dict): + return {k: _make_serializable(v) for k, v in data.items()} + elif isinstance(data, list | tuple | set): + return [_make_serializable(v) for v in data] + elif isinstance(data, np.ndarray): + return data.tolist() + elif isinstance(data, np.generic | np.number): + return data.item() + elif isinstance(data, pd.Series): + return data.tolist() + else: + return data + + +def _digraph_to_dict(G: nx.DiGraph) -> dict: + """Convert a networkx.DiGraph to a dictionary.""" + G = nx.DiGraph(G) + edge_dict = nx.to_dict_of_dicts(G) + # Get node data + node_dict = {node: G.nodes[node] for node in G.nodes()} + # Combine edge and node data in one dictionary + graph_dict = {"edges": edge_dict, "nodes": node_dict} + return graph_dict + + +def _serialize_axis_trees(trees: AxisTrees) -> dict: + """Serialize AxisTrees.""" + d = {k: _digraph_to_dict(v) for k, v in trees.items()} + return json.dumps(_make_serializable(d)) + + +def _write_tdata(f, tdata, filename, **kwargs) -> None: + """Write TreeData to file.""" + # Add encoding type and version + f = f["/"] + f.attrs.setdefault("encoding-type", "anndata") + f.attrs.setdefault("encoding-version", "0.1.0") + # Convert strings to categoricals + tdata.strings_to_categoricals() + # Write X if not backed + if not (tdata.isbacked and Path(tdata.filename) == Path(filename)): + ad.experimental.write_elem(f, "X", tdata.X, dataset_kwargs=kwargs) + # Write array elements + for key in ["obs", "var", "label", "allow_overlap"]: + ad.experimental.write_elem(f, key, getattr(tdata, key), dataset_kwargs=kwargs) + # Write group elements + for key in ["obsm", "varm", "obsp", "varp", "layers", "uns"]: + ad.experimental.write_elem(f, key, dict(getattr(tdata, key)), dataset_kwargs=kwargs) + # Write axis tree elements + for key in ["obst", "vart"]: + ad.experimental.write_elem(f, key, _serialize_axis_trees(getattr(tdata, key)), dataset_kwargs=kwargs) + # Write raw + if tdata.raw is not None: + tdata.strings_to_categoricals(tdata.raw.var) + ad.experimental.write_elem(f, "raw", tdata.raw, dataset_kwargs=kwargs) + # Close the file + tdata.file.close() + + +def write_h5ad( + filename: str | Path, + tdata: TreeData, + compression: Literal["gzip", "lzf"] | None = None, + compression_opts: int | Any = None, + **kwargs, +) -> None: + """Write `.h5ad`-formatted hdf5 file. + + Parameters + ---------- + filename + Filename of data file. Defaults to backing file. + tdata + TreeData object to write. + compression + [`lzf`, `gzip`], see the h5py :ref:`dataset_compression`. + compression_opts + [`lzf`, `gzip`], see the h5py :ref:`dataset_compression`. + """ + mode = "a" if tdata.isbacked else "w" + if tdata.isbacked: # close so that we can reopen below + tdata.file.close() + with h5py.File(filename, mode) as f: + _write_tdata(f, tdata, filename, compression=compression, compression_opts=compression_opts, **kwargs) + + +def write_zarr(filename: str | Path, tdata: TreeData, **kwargs) -> None: + """Write `.zarr`-formatted zarr file. + + Parameters + ---------- + filename + Filename of data file. Defaults to backing file. + tdata + TreeData object to write. + kwargs + Additional keyword arguments passed to :func:`zarr.save`. + """ + with zarr.open(filename, mode="w") as f: + _write_tdata(f, tdata, filename, **kwargs) diff --git a/src/treedata/_utils.py b/src/treedata/_utils.py index bd6797d..9f360fd 100755 --- a/src/treedata/_utils.py +++ b/src/treedata/_utils.py @@ -1,8 +1,6 @@ from collections import deque import networkx as nx -import numpy as np -import pandas as pd def subset_tree(tree: nx.DiGraph, leaves: list[str], asview: bool) -> nx.DiGraph: @@ -36,44 +34,3 @@ def combine_trees(subsets: list[nx.DiGraph]) -> nx.DiGraph: # The combined_tree now contains all nodes and edges from the subsets return combined_tree - - -def digraph_to_dict(G: nx.DiGraph) -> dict: - """Convert a networkx.DiGraph to a dictionary.""" - G = nx.DiGraph(G) - edge_dict = nx.to_dict_of_dicts(G) - # Get node data - node_dict = {node: G.nodes[node] for node in G.nodes()} - # Combine edge and node data in one dictionary - graph_dict = {"edges": edge_dict, "nodes": node_dict} - - return graph_dict - - -def dict_to_digraph(graph_dict: dict) -> nx.DiGraph: - """Convert a dictionary to a networkx.DiGraph.""" - G = nx.DiGraph() - # Add nodes and their attributes - for node, attrs in graph_dict["nodes"].items(): - G.add_node(node, **attrs) - # Add edges and their attributes - for source, targets in graph_dict["edges"].items(): - for target, attrs in targets.items(): - G.add_edge(source, target, **attrs) - return G - - -def make_serializable(data) -> dict: - """Make a graph dictionary serializable.""" - if isinstance(data, dict): - return {k: make_serializable(v) for k, v in data.items()} - elif isinstance(data, list | tuple | set): - return [make_serializable(v) for v in data] - elif isinstance(data, np.ndarray): - return data.tolist() - elif isinstance(data, np.generic | np.number): - return data.item() - elif isinstance(data, pd.Series): - return data.tolist() - else: - return data diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index 0923ee0..1b5bf2c 100755 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -29,15 +29,19 @@ def tdata(X, tree): def check_graph_equality(g1, g2): - assert nx.is_isomorphic(g1, g2, node_match=lambda n1, n2: n1 == n2, edge_match=lambda e1, e2: e1 == e2) + assert nx.is_isomorphic( + g1, g2, node_match=lambda n1, n2: set(n1) == set(n2), edge_match=lambda e1, e2: set(e1) == set(e2) + ) @pytest.mark.parametrize("backed", [None, "r"]) def test_h5ad_readwrite(tdata, tmp_path, backed): + tdata.raw = tdata file_path = tmp_path / "test.h5ad" tdata.write_h5ad(file_path) tdata2 = td.read_h5ad(file_path, backed=backed) assert np.array_equal(tdata2.X, tdata.X) + assert np.array_equal(tdata2.raw.X, tdata.raw.X) check_graph_equality(tdata2.obst["1"], tdata.obst["1"]) check_graph_equality(tdata2.vart["1"], tdata.vart["1"]) assert tdata2.label == "tree" @@ -93,7 +97,7 @@ def test_read_anndata(X, tmp_path): adata.write_h5ad(file_path) tdata = td.read_h5ad(file_path) assert np.array_equal(tdata.X, adata.X) - assert tdata.label is None + assert tdata.label == "tree" assert tdata.allow_overlap is False assert tdata.obst_keys() == []