WIP: experiment with first class dim objects #1517

Open · wants to merge 4 commits into base: main

412 changes: 412 additions & 0 deletions doc/internal/named-dims.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pytensor/xtensor/__init__.py
@@ -2,10 +2,12 @@
 import pytensor.xtensor.rewriting
 from pytensor.xtensor import linalg, random
+from pytensor.xtensor.basic import ones, xtensor_from_tensor, zeros
 from pytensor.xtensor.math import dot
 from pytensor.xtensor.shape import concat
 from pytensor.xtensor.type import (
     as_xtensor,
+    dim,
     xtensor,
     xtensor_constant,
 )
81 changes: 58 additions & 23 deletions pytensor/xtensor/basic.py
@@ -1,9 +1,14 @@
 from collections.abc import Sequence

 from pytensor.compile.ops import TypeCastingOp
 from pytensor.graph import Apply, Op
+from pytensor.scalar.basic import uint64
+from pytensor.tensor.basic import ones as tensor_ones
+from pytensor.tensor.basic import zeros as tensor_zeros
+from pytensor.tensor.shape import specify_shape
 from pytensor.tensor.type import TensorType
-from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
+from pytensor.xtensor.type import DimVariable, XTensorType, as_xtensor, xtensor


+DIM_LENGTH_SCALAR = uint64
+

 class XOp(Op):
@@ -32,6 +37,7 @@ def make_node(self, x):
         return Apply(self, [x], [output])

     def L_op(self, inputs, outs, g_outs):
+        # TODO fix
         [x] = inputs
         [g_out] = g_outs
         return [xtensor_from_tensor(g_out, dims=x.type.dims)]
@@ -41,46 +47,49 @@ def L_op(self, inputs, outs, g_outs):


 class XTensorFromTensor(XTypeCastOp):
-    __props__ = ("dims",)
-
-    def __init__(self, dims: Sequence[str]):
-        super().__init__()
-        self.dims = tuple(dims)
+    __props__ = ()

-    def make_node(self, x):
+    def make_node(self, x, *dims):
         if not isinstance(x.type, TensorType):
             raise TypeError(f"x must be a TensorType type, got {type(x.type)}")
-        output = xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape)
-        return Apply(self, [x], [output])
+        output = xtensor(dtype=x.type.dtype, dims=dims)
+        return Apply(self, [x, *dims], [output])

     def L_op(self, inputs, outs, g_outs):
+        # TODO fix
         [g_out] = g_outs
         return [tensor_from_xtensor(g_out)]


-def xtensor_from_tensor(x, dims, name=None):
-    return XTensorFromTensor(dims=dims)(x, name=name)
+def xtensor_from_tensor(x, dims, name=None, check: bool = True):
+    if check:
+        x = specify_shape(x, [dim.size for dim in dims])
+    return XTensorFromTensor()(x, *dims, name=name)


-class Rename(XTypeCastOp):
-    __props__ = ("new_dims",)
+class MapDims(XTypeCastOp):
+    __props__ = ("new_dim_indices",)

-    def __init__(self, new_dims: tuple[str, ...]):
-        super().__init__()
-        self.new_dims = new_dims
+    def __init__(self, new_dim_indices: tuple[int, ...]):
+        self.new_dim_indices = new_dim_indices

-    def make_node(self, x):
+    def make_node(self, x, *new_dims):
         x = as_xtensor(x)
-        output = x.type.clone(dims=self.new_dims)()
-        return Apply(self, [x], [output])
+        out_dims = list(x.type.dims)
+        for i, idx in enumerate(self.new_dim_indices):
+            # swap in the replacement dim's type at each target index
+            out_dims[idx] = new_dims[i].type
+
+        output = x.type.clone(dims=out_dims)()
+        return Apply(self, [x, *new_dims], [output])

     def L_op(self, inputs, outs, g_outs):
+        # TODO fix
         [x] = inputs
         [g_out] = g_outs
-        return [rename(g_out, dims=x.type.dims)]
+        return [map_dims(g_out, dims=x.type.dims)]


-def rename(x, name_dict: dict[str, str] | None = None, **names: str):
+def map_dims(x, name_dict: dict[DimVariable, DimVariable] | None = None, **names):
     if name_dict is not None:
         if names:
             raise ValueError("Cannot use both positional and keyword names in rename")
@@ -97,4 +106,30 @@ def rename(x, name_dict: dict[str, str] | None = None, **names: str):
                 f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}"
             )

-    return Rename(tuple(new_names))(x)
+    return MapDims(tuple(new_names))(x)
+
+
+def zeros(*dims, dtype=None, name=None):
+    """Create a new XTensor filled with zeros."""
+    if not dims:
+        raise ValueError("At least one dimension must be specified")
+
+    return xtensor_from_tensor(
+        tensor_zeros(shape=[dim.size for dim in dims], dtype=dtype),
+        dims=dims,
+        name=name,
+        check=False,
+    )
+
+
+def ones(*dims, dtype=None, name=None):
+    """Create a new XTensor filled with ones."""
+    if not dims:
+        raise ValueError("At least one dimension must be specified")
+
+    return xtensor_from_tensor(
+        tensor_ones(shape=[dim.size for dim in dims], dtype=dtype),
+        dims=dims,
+        name=name,
+        check=False,
+    )
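
For orientation, a minimal sketch of how these construction helpers compose. The exact `dim(...)` signature is an assumption (this diff only shows `dim` being exported from pytensor.xtensor.type), as is the dtype string:

from pytensor.xtensor import ones, zeros
from pytensor.xtensor.type import dim

# assumed: dim(name, size) -> DimVariable; DimVariables expose `.size`,
# which zeros/ones use to build the underlying tensor shape
row = dim("row", 3)
col = dim("col", 4)

x = zeros(row, col)                  # XTensor with dims (row, col), all 0
y = ones(row, col, dtype="float64")  # same dims, all 1
z = x + y                            # dims align by identity, not by string name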
175 changes: 175 additions & 0 deletions pytensor/xtensor/dims.py
@@ -0,0 +1,175 @@
from __future__ import annotations

from uuid import uuid4

import numpy as np

from pytensor.graph.basic import Apply
from pytensor.graph.op import Op, Variable
from pytensor.xtensor.type import (
    DIM_LENGTH_TYPE,
    DIM_LENGTH_VARIABLE,
    BasicDim,
    CloneDim,
    DimType,
    DimVariable,
    XTensorVariable,
)


class DimOp(Op):
    def perform(self, node, inputs, outputs):
        raise NotImplementedError(
            f"xtensor operation {self} must be lowered to equivalent tensor operations"
        )


# Not a dim op, because it doesn't return a DimVariable
class Length(Op):
    __props__ = ()

    def make_node(self, *inputs: Variable) -> Apply:
        (x,) = inputs
        if not isinstance(x, DimVariable):
            raise TypeError(f"x must be a DimVariable, got {type(x.type)}")
        return Apply(self, [x], [DIM_LENGTH_TYPE()])

    def perform(self, node, inputs, outputs):
        # outputs[0][0] = np.int64(inputs[0])
        outputs[0][0] = np.array(inputs[0], dtype=DIM_LENGTH_TYPE.dtype)


def _dim_size(dim: DimVariable) -> DIM_LENGTH_VARIABLE:
    return Length()(dim)


class FromLength(DimOp):
    __props__ = ("dim_type",)

    def __init__(self, dim_type: DimType):
        super().__init__()
        self.dim_type = dim_type

    def make_node(self, *inputs: Variable) -> Apply:
        (length,) = inputs
        if not isinstance(length, DIM_LENGTH_VARIABLE):
            raise TypeError(
                f"length must be a DIM_LENGTH_VARIABLE, got {type(length.type)}"
            )
        if length.type != DIM_LENGTH_TYPE:
            raise TypeError(
                f"length must be of type DIM_LENGTH_TYPE, got {length.type.dtype}"
            )
        return Apply(self, [length], [self.dim_type()])

    def perform(self, node, inputs, outputs):
        """Forward the runtime length as the value of the dim."""
        outputs[0][0] = inputs[0]


def from_length(length: DIM_LENGTH_VARIABLE, name: str | None = None) -> DimVariable:
    # TODO add check for dtype
    if not isinstance(length, DIM_LENGTH_VARIABLE):
        raise TypeError(
            f"length must be a DIM_LENGTH_VARIABLE, got {type(length.type)}"
        )
    if length.type != DIM_LENGTH_TYPE:
        raise TypeError(
            f"length must be of type DIM_LENGTH_TYPE, got {length.type.dtype}"
        )

    uuid = uuid4()
    dim_type = BasicDim(uuid=uuid, name=name)
    op = FromLength(dim_type)
    return op(length, name=name)


class DimFromTensor(Op):
    __props__ = ("dim_type",)

    def __init__(self, dim_type: DimType):
        super().__init__()
        self.dim_type = dim_type

    def make_node(self, *inputs: Variable) -> Apply:
        (x,) = inputs
        if not isinstance(x, XTensorVariable):
            raise TypeError(f"x must be an XTensorVariable, got {type(x.type)}")
        return Apply(self, [x], [self.dim_type()])

    def perform(self, node, inputs, outputs):
        """Read the dim's length off the tensor's runtime shape."""
        (x,) = inputs
        (x_var,) = node.inputs
        for i, dim in enumerate(x_var.type.dims):
            if dim == self.dim_type:
                # outputs[0][0] = np.int64(x.shape[i])
                outputs[0][0] = np.array(x.shape[i], dtype=DIM_LENGTH_TYPE.dtype)
                return
        raise ValueError(
            f"Dimension {self.dim_type} not found in tensor {x_var.type.dims}"
        )


def _dim_from_tensor(x: XTensorVariable, idx: int) -> DimVariable:
    op = DimFromTensor(dim_type=x.type.dims[idx])
    return op(x, name=x.type.dims[idx].name)


class Clone(Op):
    __props__ = ("dim_type",)

    def __init__(self, dim_type):
        super().__init__()
        self.dim_type = dim_type

    def make_node(self, *inputs: Variable) -> Apply:
        (x,) = inputs
        if not isinstance(x, DimVariable):
            raise TypeError(f"x must be a DimVariable, got {type(x.type)}")
        return Apply(self, [x], [self.dim_type()])

    def perform(self, node, inputs, outputs):
        outputs[0][0] = inputs[0]


def _clone_dim(dim: DimVariable, *, name: str | None = None) -> DimVariable:
    """Clone a dimension variable under a fresh identity.

    Args:
        name: The name for the cloned dimension.

    Returns:
        A new DimVariable whose type is a CloneDim based on the original dim.
    """
    dim_type = CloneDim(uuid=uuid4(), base=dim.type, name=name)
    return Clone(dim_type)(dim, name=name)


class Product(Op):
    __props__ = ()

    def make_node(self, *dims: Variable) -> Apply:
        if not all(isinstance(dim, DimVariable) for dim in dims):
            raise TypeError("All inputs must be DimVariables.")
        # Assumption: give the product a fresh BasicDim identity
        # (a dedicated product DimType may be intended eventually)
        dim_type = BasicDim(uuid=uuid4(), name=None)
        out = dim_type()
        return Apply(self, list(dims), [out])

    def perform(self, node, inputs, outputs):
        outputs[0][0] = np.prod(inputs, dtype=DIM_LENGTH_TYPE.dtype).item()


def product_dim(*dims: DimVariable, name: str | None = None) -> DimVariable:
    return Product()(*dims, name=name)


def rebase_dim(dim: DimVariable, *tensors: XTensorVariable) -> DimVariable:
    if not isinstance(dim, DimVariable):
        raise TypeError(f"dim must be a DimVariable, got {type(dim)}")

    if not tensors:
        raise ValueError("At least one tensor must be provided for rebasing.")

    for tensor in tensors:
        for i, tensor_dim in enumerate(tensor.type.dims):
            if dim.type == tensor_dim:
                return _dim_from_tensor(tensor, idx=i)
    raise ValueError(f"Dimension {dim.type} not found in any of the provided tensors.")
32 changes: 19 additions & 13 deletions pytensor/xtensor/reduction.py
@@ -9,20 +9,20 @@
 from pytensor.xtensor.basic import XOp
 from pytensor.xtensor.math import neq, sqrt
 from pytensor.xtensor.math import sqr as square
-from pytensor.xtensor.type import as_xtensor, xtensor
+from pytensor.xtensor.type import DimType, DimVariable, as_xtensor, xtensor


-REDUCE_DIM = str | Sequence[str] | EllipsisType | None
+REDUCE_DIM = DimVariable | Sequence[DimVariable] | EllipsisType | None


 class XReduce(XOp):
     __slots__ = ("binary_op", "dims")

-    def __init__(self, binary_op, dims: Sequence[str]):
+    def __init__(self, binary_op, dims: Sequence[DimVariable]):
         super().__init__()
         self.binary_op = binary_op
         # Order of reduce dims doesn't change the behavior of the Op
-        self.dims = tuple(sorted(dims))
+        self.dims = tuple(dims)

     def make_node(self, x):
         x = as_xtensor(x)
@@ -43,17 +43,17 @@ def make_node(self, x):
                 if d not in reduce_dims_set
             ]
         )
-        output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
+        output = xtensor(dtype=x.type.dtype, dims=out_dims)
         return Apply(self, [x], [output])


-def _process_user_dims(x, dim: REDUCE_DIM) -> Sequence[str]:
-    if isinstance(dim, str):
-        return (dim,)
+def _process_user_dims(x, dim: REDUCE_DIM) -> Sequence[DimType]:
+    if isinstance(dim, DimVariable):
+        return (dim.type,)
     elif dim is None or dim is Ellipsis:
         x = as_xtensor(x)
-        return typing.cast(tuple[str], x.type.dims)
-    return dim
+        return typing.cast(tuple[DimType], x.type.dims)
+    return tuple(dim.type for dim in dim)


 def reduce(x, dim: REDUCE_DIM = None, *, binary_op):
@@ -80,8 +80,14 @@ def bool_reduce(x, dim: REDUCE_DIM = None, *, binary_op):

 def _infer_reduced_size(original_var, reduced_var):
     reduced_dims = reduced_var.dims
-    return variadic_mul(
-        *[size for dim, size in original_var.sizes if dim not in reduced_dims]
+    return as_xtensor(
+        variadic_mul(
+            *[
+                size
+                for dim, size in original_var.sizes.items()
+                if dim not in reduced_dims
+            ]
+        )
     )


@@ -96,7 +102,7 @@ def var(x, dim: REDUCE_DIM, *, ddof: int = 0):
     x = as_xtensor(x)
     x_mean = mean(x, dim)
     n = _infer_reduced_size(x, x_mean)
-    return square(x - x_mean) / (n - ddof)
+    return square(x - x_mean).sum(dim) / (n - ddof)


 def std(x, dim: REDUCE_DIM, *, ddof: int = 0):
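
For reference, `var` with a ddof correction computes

\operatorname{Var}_{\mathrm{ddof}}(x) = \frac{1}{n - \mathrm{ddof}} \sum_{i=1}^{n} \left( x_i - \bar{x} \right)^2, \qquad \bar{x} = \frac{1}{n} \sum_{i=1}^{n} x_i,

where n is the product of the reduced dim sizes, so the squared deviations are summed (not averaged) over the reduced dims before dividing by the corrected count; ddof = 0 gives the population variance and ddof = 1 the unbiased sample estimator.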