diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index dbfe65503b..ff05b48e49 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -1890,6 +1890,23 @@ def _get_vector_length_MakeVector(op, var):
     return len(var.owner.inputs)
 
 
+@_vectorize_node.register
+def vectorize_make_vector(op: MakeVector, node, *batch_inputs):
+    # We vectorize make_vector as a join along the last axis of the broadcasted inputs
+    from pytensor.tensor.extra_ops import broadcast_arrays
+
+    # Check if we need to broadcast at all
+    bcast_pattern = batch_inputs[0].type.broadcastable
+    if not all(
+        batch_input.type.broadcastable == bcast_pattern for batch_input in batch_inputs
+    ):
+        batch_inputs = broadcast_arrays(*batch_inputs)
+
+    # Join along the last axis
+    new_out = stack(batch_inputs, axis=-1)
+    return new_out.owner
+
+
 def transfer(var, target):
     """
     Return a version of `var` transferred to `target`.
@@ -2690,6 +2707,10 @@ def vectorize_join(op: Join, node, batch_axis, *batch_inputs):
     # We can vectorize join as a shifted axis on the batch inputs if:
     # 1. The batch axis is a constant and has not changed
     # 2. All inputs are batched with the same broadcastable pattern
+
+    # TODO: We can relax the second condition by broadcasting the batch dimensions
+    # This can be done with `broadcast_arrays` if the tensors shape match at the axis or reduction
+    # Or otherwise by calling `broadcast_to` for each tensor that needs it
     if (
         original_axis.type.ndim == 0
         and isinstance(original_axis, Constant)
diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py
index ed8909944a..49c8e9c38c 100644
--- a/tests/tensor/test_basic.py
+++ b/tests/tensor/test_basic.py
@@ -4577,6 +4577,45 @@ def core_np(x):
     )
 
 
+@pytest.mark.parametrize(
+    "batch_shapes",
+    [
+        ((3,),),  # edge case of make_vector with a single input
+        ((), (), ()),  # Useless
+        ((3,), (3,), (3,)),  # No broadcasting needed
+        ((3,), (5, 3), ()),  # Broadcasting needed
+    ],
+)
+def test_vectorize_make_vector(batch_shapes):
+    n_inputs = len(batch_shapes)
+    input_sig = ",".join(["()"] * n_inputs)
+    signature = f"{input_sig}->({n_inputs})"  # Something like "(),(),()->(3)"
+
+    def core_pt(*scalars):
+        return stack(scalars)
+
+    def core_np(*scalars):
+        return np.stack(scalars)
+
+    tensors = [tensor(shape=shape) for shape in batch_shapes]
+
+    vectorize_pt = function(tensors, vectorize(core_pt, signature=signature)(*tensors))
+    assert not any(
+        isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
+    )
+
+    test_values = [
+        np.random.normal(size=tensor.type.shape).astype(tensor.type.dtype)
+        for tensor in tensors
+    ]
+
+    np.testing.assert_allclose(
+        vectorize_pt(*test_values),
+        np.vectorize(core_np, signature=signature)(*test_values),
+    )
+
+
 @pytest.mark.parametrize("axis", [constant(1), constant(-2), shared(1)])
 @pytest.mark.parametrize("broadcasting_y", ["none", "implicit", "explicit"])
 @config.change_flags(cxx="")  # C code not needed