From 359a62314a458f890348387e9e91da07c7b9e24f Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Wed, 6 Sep 2023 12:01:58 -0500 Subject: [PATCH] find static fields from dataclass parents (#14) --- simple_pytree/pytree.py | 13 +++++++------ tests/test_pytree.py | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/simple_pytree/pytree.py b/simple_pytree/pytree.py index b63a8c8..060a1fb 100644 --- a/simple_pytree/pytree.py +++ b/simple_pytree/pytree.py @@ -239,10 +239,11 @@ def __setattr__(self: P, field: str, value: tp.Any): def _inherited_static_fields(cls: type) -> tp.Set[str]: static_fields = set() for parent_class in cls.mro(): - if ( - parent_class is not cls - and parent_class is not Pytree - and issubclass(parent_class, Pytree) - ): - static_fields.update(parent_class._pytree__static_fields) + if parent_class is not cls and parent_class is not Pytree: + if issubclass(parent_class, Pytree): + static_fields.update(parent_class._pytree__static_fields) + elif dataclasses.is_dataclass(parent_class): + for field in dataclasses.fields(parent_class): + if not field.metadata.get("pytree_node", True): + static_fields.add(field.name) return static_fields diff --git a/tests/test_pytree.py b/tests/test_pytree.py index 827fdff..be4e158 100644 --- a/tests/test_pytree.py +++ b/tests/test_pytree.py @@ -1,3 +1,4 @@ +import dataclasses from typing import Generic, TypeVar import jax @@ -270,3 +271,21 @@ class Foo(Pytree, mutable=True): # test mutation pytree.x = 4 assert pytree.x == 4 + + def test_dataclass_inheritance(self): + A = dataclasses.make_dataclass( + "A", + [("x", int), "y", ("z", int, static_field(default=5))], + ) + + @dataclass + class B(Pytree, A): + ... + + b = B(1, 2) + + assert b.x == 1 + assert b.y == 2 + assert b.z == 5 + + assert jax.tree_util.tree_leaves(b) == [1, 2]