From bcbb18a42880fcf8b96deb313b9e12f746e74f19 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sat, 5 Jul 2025 08:53:49 +0200 Subject: [PATCH 1/4] Simplify python implementation of ScalarFromTensor --- pytensor/tensor/basic.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 5a3cf0036f..506df7db09 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -678,10 +678,9 @@ def make_node(self, t): self, [t], [ps.get_scalar_type(dtype=t.type.dtype).make_variable()] ) - def perform(self, node, inp, out_): - (s,) = inp - (out,) = out_ - out[0] = s.flatten()[0] + def perform(self, node, inputs, output_storage): + # not using .item() because that returns a Python scalar, not a numpy scalar + output_storage[0][0] = inputs[0][()] def infer_shape(self, fgraph, node, in_shapes): return [()] From 4c22301ada1c111bed8fabc9fa755e37ef401154 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sat, 5 Jul 2025 12:42:47 +0200 Subject: [PATCH 2/4] local_subtensor_make_vector: don't return make_vector when slice keeps only one item --- pytensor/tensor/rewriting/subtensor_lift.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index eb31514463..5a367a302a 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -613,10 +613,6 @@ def local_subtensor_make_vector(fgraph, node): something more general for constant ``*Subtensor*`` graphs (or perhaps include this kind of work in the constant folding). 
""" - - if not isinstance(node.op, Subtensor | AdvancedSubtensor1): - return False - x = node.inputs[0] if not (x.owner and isinstance(x.owner.op, MakeVector)): @@ -666,7 +662,11 @@ def local_subtensor_make_vector(fgraph, node): const_slice = get_constant_idx( node.op.idx_list, node.inputs, allow_partial=False )[0] - ret = make_vector_op(*x.owner.inputs[const_slice]) + sliced_inputs = x.owner.inputs[const_slice] + if len(sliced_inputs) == 1: + ret = expand_dims(sliced_inputs[0], axis=0) + else: + ret = make_vector_op(*sliced_inputs) copy_stack_trace(node.outputs, ret) return [ret] except NotScalarConstantError: From 07c998e71595df43294e23f3362c7b477c9f61b2 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sat, 5 Jul 2025 14:15:10 +0200 Subject: [PATCH 3/4] Allow Blockwise to create dummy core nodes with outer inputs, if these are unbatched --- pytensor/link/jax/dispatch/blockwise.py | 18 +--- pytensor/link/numba/dispatch/blockwise.py | 5 +- pytensor/tensor/blockwise.py | 125 +++++++++++++++++----- pytensor/tensor/rewriting/blockwise.py | 13 +-- 4 files changed, 105 insertions(+), 56 deletions(-) diff --git a/pytensor/link/jax/dispatch/blockwise.py b/pytensor/link/jax/dispatch/blockwise.py index 5e691c141b..7151394354 100644 --- a/pytensor/link/jax/dispatch/blockwise.py +++ b/pytensor/link/jax/dispatch/blockwise.py @@ -1,24 +1,16 @@ import jax.numpy as jnp -from pytensor.graph import FunctionGraph from pytensor.link.jax.dispatch import jax_funcify from pytensor.tensor.blockwise import Blockwise @jax_funcify.register(Blockwise) -def funcify_Blockwise(op: Blockwise, node, *args, **kwargs): +def jax_funcify_Blockwise(op: Blockwise, node, **kwargs): signature = op.signature - core_node = op._create_dummy_core_node(node.inputs) - core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs) - tuple_core_fn = jax_funcify(core_fgraph) - - if len(node.outputs) == 1: - - def core_fn(*inputs): - return tuple_core_fn(*inputs)[0] - - else: - core_fn = 
tuple_core_fn + core_node = op._create_dummy_core_node( + node.inputs, propagate_unbatched_core_inputs=True + ) + core_fn = jax_funcify(core_node.op, node=core_node, **kwargs) vect_fn = jnp.vectorize(core_fn, signature=signature) diff --git a/pytensor/link/numba/dispatch/blockwise.py b/pytensor/link/numba/dispatch/blockwise.py index b7481bd5a3..45df8341ea 100644 --- a/pytensor/link/numba/dispatch/blockwise.py +++ b/pytensor/link/numba/dispatch/blockwise.py @@ -16,7 +16,7 @@ from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape -@numba_funcify.register +@numba_funcify.register(BlockwiseWithCoreShape) def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs): [blockwise_node] = op.fgraph.apply_nodes blockwise_op: Blockwise = blockwise_node.op @@ -26,7 +26,8 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs): core_shapes_len = tuple(get_vector_length(sh) for sh in node.inputs[nin:]) core_node = blockwise_op._create_dummy_core_node( - cast(tuple[TensorVariable], blockwise_node.inputs) + cast(tuple[TensorVariable], node.inputs[:nin]), + propagate_unbatched_core_inputs=True, ) core_op_fn = numba_funcify( core_op, diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index 4cc59fd0cf..4b2a246795 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Sequence -from typing import Any, cast +from typing import Any, Literal, cast, overload import numpy as np from numpy import broadcast_shapes, empty @@ -32,6 +32,17 @@ from pytensor.tensor.variable import TensorVariable +def _squeeze_left(x, stop_at_dim: int | None = None): + """Squeeze any leading dims of `x` until a real dim or `stop_at_dim` (if not None) is reached.""" + x_dims = x.type.broadcastable + squeeze_ndim = len(x_dims) if all(x_dims) else x_dims.index(False) + if stop_at_dim is not None: + squeeze_ndim = min(squeeze_ndim, stop_at_dim) + if squeeze_ndim == 0: + 
return x + return x.squeeze(axis=tuple(range(squeeze_ndim))) + + def _vectorize_node_perform( core_node: Apply, batch_bcast_patterns: Sequence[tuple[bool, ...]], @@ -143,8 +154,6 @@ def _check_runtime_broadcast_core(numerical_inputs, batch_bcast_patterns, batch_ class Blockwise(COp): """Generalizes a core `Op` to work with batched dimensions. - TODO: Dispatch JAX (should be easy with the vectorize macro) - TODO: Dispatch Numba TODO: C implementation? TODO: Fuse Blockwise? """ @@ -202,21 +211,52 @@ def __init__( super().__init__(**kwargs) - def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply: - core_input_types = [] + @overload + def _create_dummy_core_node( + self, + inputs: Sequence[TensorVariable], + *, + propagate_unbatched_core_inputs: bool = False, + return_dummy_inputs: Literal[False] = ..., + ) -> Apply: ... + + @overload + def _create_dummy_core_node( + self, + inputs: Sequence[TensorVariable], + *, + propagate_unbatched_core_inputs: bool = False, + return_dummy_inputs: Literal[True] = ..., + ) -> tuple[Apply, list[TensorVariable]]: ... 
+ + def _create_dummy_core_node( + self, + inputs: Sequence[TensorVariable], + *, + propagate_unbatched_core_inputs: bool = False, + return_dummy_inputs: bool = False, + ) -> Apply | tuple[Apply, list[TensorVariable]]: + core_inputs = [] + core_dummy_inputs = [] for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)): if inp.type.ndim < len(sig): raise ValueError( f"Input {i} {inp} has insufficient core dimensions for signature {self.signature}" ) # ndim_supp = 0 case - if not sig: - core_shape = () + inp_ndim = inp.type.ndim + batch_ndim = inp_ndim - len(sig) + core_shape = inp.type.shape[batch_ndim:] + if propagate_unbatched_core_inputs and all( + inp.type.broadcastable[:batch_ndim] + ): + core_inputs.append(_squeeze_left(inp, batch_ndim)) else: - core_shape = inp.type.shape[-len(sig) :] - core_input_types.append(tensor(dtype=inp.type.dtype, shape=core_shape)) + dummy_inp = tensor(dtype=inp.type.dtype, shape=core_shape) + core_inputs.append(dummy_inp) + core_dummy_inputs.append(dummy_inp) - core_node = self.core_op.make_node(*core_input_types) + core_node = self.core_op.make_node(*core_inputs) if len(core_node.outputs) != len(self.outputs_sig): raise ValueError( @@ -230,6 +270,9 @@ def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply: f"Output {i} of {self.core_op} has wrong number of core dimensions for signature {self.signature}: {core_out.type.ndim}" ) + if return_dummy_inputs: + return core_node, core_dummy_inputs + return core_node def make_node(self, *inputs): @@ -298,11 +341,17 @@ def infer_shape( batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True) - # Try to extract the core shapes from the core_op - core_op_infer_shape = getattr(self.core_op, "infer_shape", None) - if core_op_infer_shape is not None: - dummy_core_node = self._create_dummy_core_node(node.inputs) - dummy_core_inputs = tuple(explicit_graph_inputs(dummy_core_node.inputs)) + def extract_core_shape_from_infer_shape(): + # Try to 
extract the core shapes from the core_op + core_op_infer_shape = getattr(self.core_op, "infer_shape", None) + if core_op_infer_shape is None: + return [[None] * out.ndim for out in node.outputs] + + dummy_core_node, dummy_core_inputs = self._create_dummy_core_node( + node.inputs, + return_dummy_inputs=True, + propagate_unbatched_core_inputs=True, + ) dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False) core_input_shapes = [ input_shape[batch_ndims:] for input_shape in input_shapes @@ -311,6 +360,25 @@ def infer_shape( dummy_fgraph, dummy_core_node, core_input_shapes ) + # Set to None those core_shapes that depend on dummy_core_inputs, + # meaning their value may not be constant across batch dims of the Blockwise + if not dummy_core_inputs: + # All inputs are unbatched, so the core_shape can be used as is + return core_output_shapes + else: + set_dummy_core_inputs = set(dummy_core_inputs) + safe_core_output_shapes = [list(shape) for shape in core_output_shapes] + for core_out_shape in safe_core_output_shapes: + for o, core_out_dim in enumerate(core_out_shape): + if set_dummy_core_inputs & set( + explicit_graph_inputs([core_out_dim]) + ): + core_out_shape[o] = None + + return safe_core_output_shapes + + safe_core_out_shape = None + out_shapes = [] for o, (output, sig) in enumerate( zip(node.outputs, self.outputs_sig, strict=True) @@ -321,19 +389,15 @@ def infer_shape( if dim_name in core_dims: core_out_shape.append(core_dims[dim_name]) else: - if core_op_infer_shape is not None: - # If the input values are needed to compute the dimension length, we can't use the infer_shape - # of the core_node as the value is not constant across batch dims of the Blockwise - core_out_dim = core_output_shapes[o][i] - if not ( - set(dummy_core_inputs) - & set(explicit_graph_inputs([core_out_dim])) - ): - core_out_shape.append(core_out_dim) - continue - - # Fallback shape requires evaluating the Blockwise Op - core_out_shape.append(Shape_i(batch_ndims + 
i)(output)) + if safe_core_out_shape is None: + # Extract the core shape from the core_op infer_shape on demand + # For many Ops we never need to do this, because all info is in their signature + safe_core_out_shape = extract_core_shape_from_infer_shape() + if (core_out_dim := safe_core_out_shape[o][i]) is not None: + core_out_shape.append(core_out_dim) + else: + # Fallback shape requires evaluating the Blockwise Op + core_out_shape.append(Shape_i(batch_ndims + i)(output)) out_shapes.append((*batch_shape, *core_out_shape)) return out_shapes @@ -448,7 +512,10 @@ def gufunc( ) return core_func(*inputs) else: - core_node = self._create_dummy_core_node(node.inputs) # type: ignore + core_node = self._create_dummy_core_node( + cast(list[TensorVariable], node.inputs), + propagate_unbatched_core_inputs=True, + ) gufunc = _vectorize_node_perform( core_node, batch_bcast_patterns=batch_bcast_patterns, diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py index 88ad4c1522..4879f86a72 100644 --- a/pytensor/tensor/rewriting/blockwise.py +++ b/pytensor/tensor/rewriting/blockwise.py @@ -4,7 +4,7 @@ from pytensor.graph.replace import vectorize_node from pytensor.graph.rewriting.basic import copy_stack_trace, out2in from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft -from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.blockwise import Blockwise, _squeeze_left from pytensor.tensor.math import Dot from pytensor.tensor.rewriting.basic import ( register_canonicalize, @@ -90,17 +90,6 @@ def local_eager_useless_unbatched_blockwise(fgraph, node): return local_useless_unbatched_blockwise.fn(fgraph, node) -def _squeeze_left(x, stop_at_dim: int | None = None): - """Squeeze any leading dims of `x` until a real dim or `stop_at_dim` (if not None) is reached.""" - x_dims = x.type.broadcastable - squeeze_ndim = len(x_dims) if all(x_dims) else x_dims.index(False) - if stop_at_dim is not None: - squeeze_ndim = 
min(squeeze_ndim, stop_at_dim) - if squeeze_ndim == 0: - return x - return x.squeeze(axis=tuple(range(squeeze_ndim))) - - @register_specialize("shape_unsafe") @node_rewriter([Blockwise]) def local_blockwise_alloc(fgraph, node): From 9026dd8d40e603c46995821336b15ec157839ddd Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 4 Jul 2025 13:54:56 +0200 Subject: [PATCH 4/4] Make convolve mode symbolic to avoid unnecessary large convolution in gradient --- pytensor/link/jax/dispatch/signal/conv.py | 16 ++- pytensor/link/numba/dispatch/signal/conv.py | 109 ++++++++++--------- pytensor/tensor/blockwise.py | 4 +- pytensor/tensor/rewriting/__init__.py | 1 - pytensor/tensor/rewriting/conv.py | 78 -------------- pytensor/tensor/signal/conv.py | 110 ++++++++++---------- tests/link/numba/signal/test_conv.py | 17 ++- tests/tensor/signal/test_conv.py | 46 ++++++-- 8 files changed, 168 insertions(+), 213 deletions(-) delete mode 100644 pytensor/tensor/rewriting/conv.py diff --git a/pytensor/link/jax/dispatch/signal/conv.py b/pytensor/link/jax/dispatch/signal/conv.py index 92414ac59a..788d9cc073 100644 --- a/pytensor/link/jax/dispatch/signal/conv.py +++ b/pytensor/link/jax/dispatch/signal/conv.py @@ -1,14 +1,24 @@ import jax from pytensor.link.jax.dispatch import jax_funcify +from pytensor.tensor.basic import get_underlying_scalar_constant_value +from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.signal.conv import Convolve1d @jax_funcify.register(Convolve1d) def jax_funcify_Convolve1d(op, node, **kwargs): - mode = op.mode + _, _, full_mode = node.inputs + try: + full_mode = get_underlying_scalar_constant_value(full_mode) + except NotScalarConstantError: + raise NotImplementedError( + "Cannot compile Convolve1D to jax without static mode" + ) + static_mode = "full" if full_mode else "valid" - def conv1d(data, kernel): - return jax.numpy.convolve(data, kernel, mode=mode) + def conv1d(data, kernel, _runtime_full_mode): + # _runtime_full_mode is 
not used, as we only support static mode + return jax.numpy.convolve(data, kernel, mode=static_mode) return conv1d diff --git a/pytensor/link/numba/dispatch/signal/conv.py b/pytensor/link/numba/dispatch/signal/conv.py index cf163228ad..15d1bb29b1 100644 --- a/pytensor/link/numba/dispatch/signal/conv.py +++ b/pytensor/link/numba/dispatch/signal/conv.py @@ -9,62 +9,61 @@ @numba_funcify.register(Convolve1d) def numba_funcify_Convolve1d(op, node, **kwargs): # This specialized version is faster than the overloaded numba np.convolve - mode = op.mode a_dtype, b_dtype = node.inputs[0].type.dtype, node.inputs[1].type.dtype out_dtype = node.outputs[0].type.dtype innerprod = _get_inner_prod(a_dtype, b_dtype) - if mode == "valid": - - def valid_convolve1d(x, y): - nx = len(x) - ny = len(y) - if nx < ny: - x, y = y, x - nx, ny = ny, nx - y_flipped = y[::-1] - - length = nx - ny + 1 - ret = np.empty(length, out_dtype) - - for i in range(length): - ret[i] = innerprod(x[i : i + ny], y_flipped) - - return ret - - return numba_njit(valid_convolve1d) - - elif mode == "full": - - def full_convolve1d(x, y): - nx = len(x) - ny = len(y) - if nx < ny: - x, y = y, x - nx, ny = ny, nx - y_flipped = y[::-1] - - length = nx + ny - 1 - ret = np.empty(length, out_dtype) - idx = 0 - - for i in range(ny - 1): - k = i + 1 - ret[idx] = innerprod(x[:k], y_flipped[-k:]) - idx = idx + 1 - - for i in range(nx - ny + 1): - ret[idx] = innerprod(x[i : i + ny], y_flipped) - idx = idx + 1 - - for i in range(ny - 1): - k = ny - i - 1 - ret[idx] = innerprod(x[-k:], y_flipped[:k]) - idx = idx + 1 - - return ret - - return numba_njit(full_convolve1d) - - else: - raise ValueError(f"Unsupported mode: {mode}") + @numba_njit + def valid_convolve1d(x, y): + nx = len(x) + ny = len(y) + if nx < ny: + x, y = y, x + nx, ny = ny, nx + y_flipped = y[::-1] + + length = nx - ny + 1 + ret = np.empty(length, out_dtype) + + for i in range(length): + ret[i] = innerprod(x[i : i + ny], y_flipped) + + return ret + + @numba_njit + 
def full_convolve1d(x, y): + nx = len(x) + ny = len(y) + if nx < ny: + x, y = y, x + nx, ny = ny, nx + y_flipped = y[::-1] + + length = nx + ny - 1 + ret = np.empty(length, out_dtype) + idx = 0 + + for i in range(ny - 1): + k = i + 1 + ret[idx] = innerprod(x[:k], y_flipped[-k:]) + idx = idx + 1 + + for i in range(nx - ny + 1): + ret[idx] = innerprod(x[i : i + ny], y_flipped) + idx = idx + 1 + + for i in range(ny - 1): + k = ny - i - 1 + ret[idx] = innerprod(x[-k:], y_flipped[:k]) + idx = idx + 1 + + return ret + + @numba_njit + def convolve_1d(x, y, mode): + if mode: + return full_convolve1d(x, y) + else: + return valid_convolve1d(x, y) + + return convolve_1d diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index 4b2a246795..14d9a53251 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -360,12 +360,12 @@ def extract_core_shape_from_infer_shape(): dummy_fgraph, dummy_core_node, core_input_shapes ) - # Set to None those core_shapes that depend on dummy_core_inputs, - # meaning their value may not be constant across batch dims of the Blockwise if not dummy_core_inputs: # All inputs are unbatched, so the core_shape can be used as is return core_output_shapes else: + # Set to None those core_shapes that depend on dummy_core_inputs, + # meaning their value may not be constant across batch dims of the Blockwise set_dummy_core_inputs = set(dummy_core_inputs) safe_core_output_shapes = [list(shape) for shape in core_output_shapes] for core_out_shape in safe_core_output_shapes: diff --git a/pytensor/tensor/rewriting/__init__.py b/pytensor/tensor/rewriting/__init__.py index 34e070bfcf..6d411d3827 100644 --- a/pytensor/tensor/rewriting/__init__.py +++ b/pytensor/tensor/rewriting/__init__.py @@ -3,7 +3,6 @@ import pytensor.tensor.rewriting.blas_c import pytensor.tensor.rewriting.blas_scipy import pytensor.tensor.rewriting.blockwise -import pytensor.tensor.rewriting.conv import pytensor.tensor.rewriting.einsum import 
pytensor.tensor.rewriting.elemwise import pytensor.tensor.rewriting.extra_ops diff --git a/pytensor/tensor/rewriting/conv.py b/pytensor/tensor/rewriting/conv.py deleted file mode 100644 index 37a3fdc00f..0000000000 --- a/pytensor/tensor/rewriting/conv.py +++ /dev/null @@ -1,78 +0,0 @@ -from pytensor.graph.basic import Constant -from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter -from pytensor.tensor.blockwise import Blockwise -from pytensor.tensor.rewriting.basic import register_specialize, register_stabilize -from pytensor.tensor.signal import convolve1d -from pytensor.tensor.signal.conv import Convolve1d -from pytensor.tensor.subtensor import Subtensor, indices_from_subtensor - - -@register_stabilize -@register_specialize -@node_rewriter([Subtensor]) -def local_sliced_full_conv_to_valid_conv(fgraph, node): - """Rewrite sliced full conv that are equivalent to valid. - - The gradient of a valid Conv1d always implements the worst case scenario - full convolution - - because it would need to know which input is larger to do something smarter. - If we find out (through rewrites or static shape) we provide the direct implementation - which can be orders of magnitude faster. 
- - # if x.shape[-1] > y.shape[-1] - # z = convolve1d(x, y, mode="full") - # z[..., y.shape[-1] - 1: z.shape[-1] - y.shape[-1] - 1] -> convolve1d(x, y, mode="valid") - """ - conv, *other_idx_vars = node.inputs - - if not ( - conv.owner is not None - and isinstance(conv.owner.op, Blockwise) - and isinstance(conv.owner.op.core_op, Convolve1d) - and conv.owner.op.core_op.mode == "full" - ): - return None - - # Check we have an (a:b) constant slice at the last axis of the input - idx_list = node.op.idx_list - if not (len(idx_list) == conv.type.ndim and isinstance(idx_list[-1], slice)): - return None - - last_slice = idx_list[-1] - if not ( - last_slice.start is not None - and last_slice.stop is not None - and last_slice.step is None - ): - return None - - *other_idx_vars, start, stop = other_idx_vars - if not (isinstance(start, Constant) and isinstance(stop, Constant)): - return None - - x, y = conv.owner.inputs - len_x = x.type.shape[-1] - len_y = y.type.shape[-1] - if len_x is None or len_y is None: - return None - - start, stop = start.data, stop.data - if len_x < len_y: - # Convolution is symmetric with input order - x, y = y, x - len_x, len_y = len_y, len_x - - if ( - start == len_y - 1 - # equivalent to stop = conv.shape[-1] - len_y - 1 - and stop == start + (len_x - len_y) + 1 - ): - new_conv = convolve1d(x, y, mode="valid") - copy_stack_trace(conv, new_conv) - - if other_idx_vars: - # If there were more than just empty slices besides the last one - new_indices = indices_from_subtensor(idx_list[:-1], other_idx_vars) - new_conv = new_conv[new_indices] - copy_stack_trace(node.out, new_conv) - - return [new_conv] diff --git a/pytensor/tensor/signal/conv.py b/pytensor/tensor/signal/conv.py index 26c210fda3..5d5d0c8f40 100644 --- a/pytensor/tensor/signal/conv.py +++ b/pytensor/tensor/signal/conv.py @@ -1,13 +1,16 @@ from typing import TYPE_CHECKING, Literal, cast +import numpy as np from numpy import convolve as numpy_convolve -from pytensor.graph import Apply +from 
pytensor.gradient import DisconnectedType +from pytensor.graph import Apply, Constant from pytensor.link.c.op import COp +from pytensor.scalar import as_scalar from pytensor.scalar.basic import upcast from pytensor.tensor.basic import as_tensor_variable, join, zeros from pytensor.tensor.blockwise import Blockwise -from pytensor.tensor.math import maximum, minimum +from pytensor.tensor.math import maximum, minimum, switch from pytensor.tensor.type import vector from pytensor.tensor.variable import TensorVariable @@ -17,92 +20,83 @@ class Convolve1d(COp): - __props__ = ("mode",) - gufunc_signature = "(n),(k)->(o)" + __props__ = () + gufunc_signature = "(n),(k),()->(o)" - def __init__(self, mode: Literal["full", "valid"] = "full"): - if mode not in ("full", "valid"): - raise ValueError(f"Invalid mode: {mode}") - self.mode = mode - - def make_node(self, in1, in2): + def make_node(self, in1, in2, full_mode): in1 = as_tensor_variable(in1) in2 = as_tensor_variable(in2) + full_mode = as_scalar(full_mode) - assert in1.ndim == 1 - assert in2.ndim == 1 + if not (in1.ndim == 1 and in2.ndim == 1): + raise ValueError("Convolution inputs must be vector (ndim=1)") + if not full_mode.dtype == "bool": + raise ValueError("Convolution mode must be a boolean type") dtype = upcast(in1.dtype, in2.dtype) - n = in1.type.shape[0] k = in2.type.shape[0] + match full_mode: + case Constant(): + static_mode = "full" if full_mode.data else "valid" + case _: + static_mode = None - if n is None or k is None: + if n is None or k is None or static_mode is None: out_shape = (None,) - elif self.mode == "full": + elif static_mode == "full": out_shape = (n + k - 1,) else: # mode == "valid": out_shape = (max(n, k) - min(n, k) + 1,) out = vector(dtype=dtype, shape=out_shape) - return Apply(self, [in1, in2], [out]) + return Apply(self, [in1, in2, full_mode], [out]) def perform(self, node, inputs, outputs): # We use numpy_convolve as that's what scipy would use if method="direct" was passed. 
# And mode != "same", which this Op doesn't cover anyway. - outputs[0][0] = numpy_convolve(*inputs, mode=self.mode) + in1, in2, full_mode = inputs + outputs[0][0] = numpy_convolve(in1, in2, mode="full" if full_mode else "valid") def infer_shape(self, fgraph, node, shapes): - in1_shape, in2_shape = shapes + _, _, full_mode = node.inputs + in1_shape, in2_shape, _ = shapes n = in1_shape[0] k = in2_shape[0] - if self.mode == "full": - shape = n + k - 1 - else: # mode == "valid": - shape = maximum(n, k) - minimum(n, k) + 1 + shape_valid = maximum(n, k) - minimum(n, k) + 1 + shape_full = n + k - 1 + shape = switch(full_mode, shape_full, shape_valid) return [[shape]] + def connection_pattern(self, node): + return [[True], [True], [False]] + def L_op(self, inputs, outputs, output_grads): - in1, in2 = inputs + in1, in2, full_mode = inputs [grad] = output_grads - if self.mode == "full": - valid_conv = type(self)(mode="valid") - in1_bar = valid_conv(grad, in2[::-1]) - in2_bar = valid_conv(grad, in1[::-1]) + n = in1.shape[0] + k = in2.shape[0] - else: # mode == "valid": - full_conv = type(self)(mode="full") - n = in1.shape[0] - k = in2.shape[0] - kmn = maximum(0, k - n) - nmk = maximum(0, n - k) - # We need mode="full" if k >= n else "valid" for `in1_bar` (opposite for `in2_bar`), but mode is not symbolic. - # Instead, we always use mode="full" and slice the result so it behaves like "valid" for the input that's shorter. 
- # There is a rewrite that optimizes this case when n, k are static - in1_bar = full_conv(grad, in2[::-1]) - in1_bar = in1_bar[kmn : in1_bar.shape[0] - kmn] - in2_bar = full_conv(grad, in1[::-1]) - in2_bar = in2_bar[nmk : in2_bar.shape[0] - nmk] - - return [in1_bar, in2_bar] + # If mode is "full", or mode is "valid" and k >= n, then in1_bar mode should use "valid" convolve + # The expression below is equivalent to ~(full_mode | (k >= n)) + full_mode_in1_bar = ~full_mode & (k < n) + # If mode is "full", or mode is "valid" and n >= k, then in2_bar mode should use "valid" convolve + # The expression below is equivalent to ~(full_mode | (n >= k)) + full_mode_in2_bar = ~full_mode & (n < k) + + return [ + self(grad, in2[::-1], full_mode_in1_bar), + self(grad, in1[::-1], full_mode_in2_bar), + DisconnectedType()(), + ] def c_code_cache_version(self): - return (1,) + return (2,) def c_code(self, node, name, inputs, outputs, sub): - # raise NotImplementedError() - in1, in2 = inputs + in1, in2, full_mode = inputs [out] = outputs - mode_str = self.mode - - if mode_str == "full": - np_mode_val = 2 # NPY_CONVOLVE_FULL - elif mode_str == "valid": - np_mode_val = 0 # NPY_CONVOLVE_VALID - else: - # This case should ideally be prevented by __init__ or make_node - raise ValueError(f"Unsupported mode {mode_str}") code = f""" {{ @@ -158,7 +152,7 @@ def c_code(self, node, name, inputs, outputs, sub): // TODO: Use lower level implementation that allows reusing the output buffer Py_XDECREF({out}); - {out} = (PyArrayObject*) PyArray_Correlate2((PyObject*){in1}, (PyObject*)in2_flipped_view, {np_mode_val}); + {out} = (PyArrayObject*) PyArray_Correlate2((PyObject*){in1}, (PyObject*)in2_flipped_view, {full_mode} ? 
2 : 0); Py_XDECREF(in2_flipped_view); // Clean up the view if correlate fails if (!{out}) {{ // PyArray_Correlate already set an error @@ -169,6 +163,9 @@ def c_code(self, node, name, inputs, outputs, sub): return code +blockwise_convolve_1d = Blockwise(Convolve1d()) + + def convolve1d( in1: "TensorLike", in2: "TensorLike", @@ -212,4 +209,5 @@ def convolve1d( ) mode = "valid" - return cast(TensorVariable, Blockwise(Convolve1d(mode=mode))(in1, in2)) + full_mode = as_scalar(np.bool_(mode == "full")) + return cast(TensorVariable, blockwise_convolve_1d(in1, in2, full_mode)) diff --git a/tests/link/numba/signal/test_conv.py b/tests/link/numba/signal/test_conv.py index d1e90a6dae..20d80bd0ab 100644 --- a/tests/link/numba/signal/test_conv.py +++ b/tests/link/numba/signal/test_conv.py @@ -7,6 +7,7 @@ from pytensor.tensor import dmatrix, tensor from pytensor.tensor.signal import convolve1d from tests.link.numba.test_basic import compare_numba_and_py +from tests.tensor.signal.test_conv import convolve1d_grad_benchmarker pytestmark = pytest.mark.filterwarnings("error") @@ -31,15 +32,8 @@ def test_convolve1d(x_smaller, mode): @pytest.mark.parametrize("mode", ("full", "valid"), ids=lambda x: f"mode={x}") @pytest.mark.parametrize("batch", (False, True), ids=lambda x: f"batch={x}") -def test_convolve1d_benchmark(batch, mode, benchmark): - x = tensor( - shape=( - 7, - 183, - ) - if batch - else (183,) - ) +def test_convolve1d_benchmark_numba(batch, mode, benchmark): + x = tensor(shape=(7, 183) if batch else (183,)) y = tensor(shape=(7, 6) if batch else (6,)) out = convolve1d(x, y, mode=mode) fn = function([x, y], out, mode="NUMBA", trust_input=True) @@ -57,3 +51,8 @@ def test_convolve1d_benchmark(batch, mode, benchmark): np_convolve1d(x_test, y_test), ) benchmark(fn, x_test, y_test) + + +@pytest.mark.parametrize("convolve_mode", ["full", "valid"]) +def test_convolve1d_grad_benchmark_numba(convolve_mode, benchmark): + convolve1d_grad_benchmarker(convolve_mode, "NUMBA", benchmark) 
diff --git a/tests/tensor/signal/test_conv.py b/tests/tensor/signal/test_conv.py index d6b0d69d7c..a22a07d101 100644 --- a/tests/tensor/signal/test_conv.py +++ b/tests/tensor/signal/test_conv.py @@ -7,7 +7,7 @@ from pytensor import config, function, grad from pytensor.graph.basic import ancestors, io_toposort from pytensor.graph.rewriting import rewrite_graph -from pytensor.tensor import matrix, vector +from pytensor.tensor import matrix, tensor, vector from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.signal.conv import Convolve1d, convolve1d from tests import unittest_tools as utt @@ -86,11 +86,8 @@ def test_convolve1d_batch_graph(mode): @pytest.mark.parametrize("static_shape", [False, True]) -def test_convolve1d_valid_grad_rewrite(static_shape): - """Test that we don't do a useless full convolve1d when taking the gradient of a valid convolve wrt to the smallest input. - - This can only be achieved when the two inputs have static shapes, so we know which one is larger - """ +def test_convolve1d_valid_grad(static_shape): + """Test we don't do a full convolve in the gradient of the smaller input to a valid convolve.""" larger = vector("larger", shape=(128 if static_shape else None,)) smaller = vector("smaller", shape=(64 if static_shape else None,)) out = convolve1d(larger, smaller, mode="valid") @@ -103,9 +100,40 @@ def test_convolve1d_valid_grad_rewrite(static_shape): "local_useless_unbatched_blockwise", ), ) - [conv_op] = [ - node.op + [conv_node] = [ + node for node in io_toposort([larger, smaller], [grad_out]) if isinstance(node.op, Convolve1d) ] - assert conv_op.mode == ("valid" if static_shape else "full") + full_mode = conv_node.inputs[-1] + # If shape is static we get constant mode == "valid", otherwise it depends on the input shapes + # ignoring E712 because np.True_ and np.False_ need to be compared with `==` to produce a valid boolean + if static_shape: + assert full_mode.eval() == False # noqa: E712 + else: + dtype = larger.dtype + 
larger_test = np.zeros((128,), dtype=dtype) + smaller_test = np.zeros((64,), dtype=dtype) + assert full_mode.eval({larger: larger_test, smaller: smaller_test}) == False # noqa: E712 + assert full_mode.eval({larger: smaller_test, smaller: larger_test}) == True # noqa: E712 + + +def convolve1d_grad_benchmarker(convolve_mode, mode, benchmark): + # Use None core shape so PyTensor doesn't know which mode to use until runtime. + larger = tensor("larger", shape=(8, None)) + smaller = tensor("smaller", shape=(8, None)) + grad_wrt_smaller = grad( + convolve1d(larger, smaller, mode=convolve_mode).sum(), wrt=smaller + ) + + fn = function([larger, smaller], grad_wrt_smaller, trust_input=True, mode=mode) + + rng = np.random.default_rng([119, convolve_mode == "full"]) + test_larger = rng.normal(size=(8, 1024)).astype(larger.type.dtype) + test_smaller = rng.normal(size=(8, 16)).astype(smaller.type.dtype) + benchmark(fn, test_larger, test_smaller) + + +@pytest.mark.parametrize("convolve_mode", ["full", "valid"]) +def test_convolve1d_grad_benchmark_c(convolve_mode, benchmark): + convolve1d_grad_benchmarker(convolve_mode, "FAST_RUN", benchmark)