Commit 95b4cb3

Make convolve mode symbolic to avoid unnecessary large convolution in gradient
1 parent 5b1db5a commit 95b4cb3

File tree

5 files changed: +164, -131 lines
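For context on the commit message: before this change, `Convolve1d` stored `mode` as a static Op property, so the gradient of a "valid" convolution had to fall back to an oversized "full" convolution followed by slicing (see the deleted `L_op` branch below). Making the mode a symbolic boolean input lets each gradient term select "full" or "valid" at runtime. A minimal sketch (not part of the commit) of the resulting user-facing behavior, assuming this commit is applied:

```python
# Sketch (assumes this commit): the mode is now the third, boolean input of
# the Convolve1d node instead of an Op property.
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.signal import convolve1d

x = pt.vector("x", shape=(183,))
y = pt.vector("y", shape=(6,))
out = convolve1d(x, y, mode="valid")
print(out.owner.inputs)  # roughly: [x, y, True] -- mode as a boolean constant

# The gradient of the shorter input no longer builds a length n + k - 1
# "full" convolution that is then sliced down to k; it uses "valid" directly.
gy = pytensor.grad(out.sum(), wrt=y)
fn = pytensor.function([x, y], gy)
print(fn(np.ones(183), np.ones(6)).shape)  # (6,)
```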
Lines changed: 12 additions & 3 deletions
@@ -1,14 +1,23 @@
 import jax
 
 from pytensor.link.jax.dispatch import jax_funcify
+from pytensor.tensor.basic import get_underlying_scalar_constant_value
+from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.signal.conv import Convolve1d
 
 
 @jax_funcify.register(Convolve1d)
 def jax_funcify_Convolve1d(op, node, **kwargs):
-    mode = op.mode
+    _, _, full_mode = node.inputs
+    try:
+        full_mode = get_underlying_scalar_constant_value(full_mode)
+    except NotScalarConstantError:
+        raise NotImplementedError(
+            "Cannot compile Convolve1D to jax without static mode"
+        )
+    static_mode = "full" if full_mode else "valid"
 
-    def conv1d(data, kernel):
-        return jax.numpy.convolve(data, kernel, mode=mode)
+    def conv1d(data, kernel, _):
+        return jax.numpy.convolve(data, kernel, mode=static_mode)
 
     return conv1d
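The JAX backend still needs a concrete mode string, because the mode determines the output shape and JIT tracing requires static shapes; hence the `get_underlying_scalar_constant_value` lookup above, which fails loudly when the mode cannot be resolved at compile time. A standalone illustration (not part of the commit) of why the mode must be static under JAX:

```python
# Standalone sketch: under jax.jit the mode must be a Python constant,
# since it selects the output shape.
import jax
import jax.numpy as jnp

@jax.jit
def conv_valid(data, kernel):
    return jnp.convolve(data, kernel, mode="valid")  # static string, OK

print(conv_valid(jnp.ones(8), jnp.ones(3)).shape)  # (6,)
```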

pytensor/link/numba/dispatch/signal/conv.py

Lines changed: 54 additions & 55 deletions
@@ -9,62 +9,61 @@
 @numba_funcify.register(Convolve1d)
 def numba_funcify_Convolve1d(op, node, **kwargs):
     # This specialized version is faster than the overloaded numba np.convolve
-    mode = op.mode
     a_dtype, b_dtype = node.inputs[0].type.dtype, node.inputs[1].type.dtype
     out_dtype = node.outputs[0].type.dtype
     innerprod = _get_inner_prod(a_dtype, b_dtype)
 
-    if mode == "valid":
-
-        def valid_convolve1d(x, y):
-            nx = len(x)
-            ny = len(y)
-            if nx < ny:
-                x, y = y, x
-                nx, ny = ny, nx
-            y_flipped = y[::-1]
-
-            length = nx - ny + 1
-            ret = np.empty(length, out_dtype)
-
-            for i in range(length):
-                ret[i] = innerprod(x[i : i + ny], y_flipped)
-
-            return ret
-
-        return numba_njit(valid_convolve1d)
-
-    elif mode == "full":
-
-        def full_convolve1d(x, y):
-            nx = len(x)
-            ny = len(y)
-            if nx < ny:
-                x, y = y, x
-                nx, ny = ny, nx
-            y_flipped = y[::-1]
-
-            length = nx + ny - 1
-            ret = np.empty(length, out_dtype)
-            idx = 0
-
-            for i in range(ny - 1):
-                k = i + 1
-                ret[idx] = innerprod(x[:k], y_flipped[-k:])
-                idx = idx + 1
-
-            for i in range(nx - ny + 1):
-                ret[idx] = innerprod(x[i : i + ny], y_flipped)
-                idx = idx + 1
-
-            for i in range(ny - 1):
-                k = ny - i - 1
-                ret[idx] = innerprod(x[-k:], y_flipped[:k])
-                idx = idx + 1
-
-            return ret
-
-        return numba_njit(full_convolve1d)
-
-    else:
-        raise ValueError(f"Unsupported mode: {mode}")
+    @numba_njit
+    def valid_convolve1d(x, y):
+        nx = len(x)
+        ny = len(y)
+        if nx < ny:
+            x, y = y, x
+            nx, ny = ny, nx
+        y_flipped = y[::-1]
+
+        length = nx - ny + 1
+        ret = np.empty(length, out_dtype)
+
+        for i in range(length):
+            ret[i] = innerprod(x[i : i + ny], y_flipped)
+
+        return ret
+
+    @numba_njit
+    def full_convolve1d(x, y):
+        nx = len(x)
+        ny = len(y)
+        if nx < ny:
+            x, y = y, x
+            nx, ny = ny, nx
+        y_flipped = y[::-1]
+
+        length = nx + ny - 1
+        ret = np.empty(length, out_dtype)
+        idx = 0
+
+        for i in range(ny - 1):
+            k = i + 1
+            ret[idx] = innerprod(x[:k], y_flipped[-k:])
+            idx = idx + 1
+
+        for i in range(nx - ny + 1):
+            ret[idx] = innerprod(x[i : i + ny], y_flipped)
+            idx = idx + 1
+
+        for i in range(ny - 1):
+            k = ny - i - 1
+            ret[idx] = innerprod(x[-k:], y_flipped[:k])
+            idx = idx + 1
+
+        return ret
+
+    @numba_njit
+    def convolve_1d(x, y, mode):
+        if mode:
+            return full_convolve1d(x, y)
+        else:
+            return valid_convolve1d(x, y)
+
+    return convolve_1d
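Instead of choosing one kernel at compile time, the Numba dispatch now compiles both kernels plus a thin wrapper that branches on the boolean mode at runtime; Numba specializes the branch cheaply. The same pattern in a self-contained sketch (assumes `numba` is installed; the helper names are illustrative, not from the codebase):

```python
import numba
import numpy as np

@numba.njit
def valid_out_len(x, y):
    return abs(len(x) - len(y)) + 1

@numba.njit
def full_out_len(x, y):
    return len(x) + len(y) - 1

@numba.njit
def out_len(x, y, full_mode):
    # A runtime boolean selects between the two compiled helpers.
    if full_mode:
        return full_out_len(x, y)
    return valid_out_len(x, y)

print(out_len(np.ones(8), np.ones(3), True))   # 10
print(out_len(np.ones(8), np.ones(3), False))  # 6
```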

pytensor/tensor/signal/conv.py

Lines changed: 54 additions & 56 deletions
@@ -1,13 +1,16 @@
 from typing import TYPE_CHECKING, Literal, cast
 
+import numpy as np
 from numpy import convolve as numpy_convolve
 
-from pytensor.graph import Apply
+from pytensor.gradient import DisconnectedType
+from pytensor.graph import Apply, Constant
 from pytensor.link.c.op import COp
+from pytensor.scalar import as_scalar
 from pytensor.scalar.basic import upcast
 from pytensor.tensor.basic import as_tensor_variable, join, zeros
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.math import maximum, minimum
+from pytensor.tensor.math import maximum, minimum, switch
 from pytensor.tensor.type import vector
 from pytensor.tensor.variable import TensorVariable
 
@@ -17,92 +20,83 @@
 
 
 class Convolve1d(COp):
-    __props__ = ("mode",)
-    gufunc_signature = "(n),(k)->(o)"
+    __props__ = ()
+    gufunc_signature = "(n),(k),()->(o)"
 
-    def __init__(self, mode: Literal["full", "valid"] = "full"):
-        if mode not in ("full", "valid"):
-            raise ValueError(f"Invalid mode: {mode}")
-        self.mode = mode
-
-    def make_node(self, in1, in2):
+    def make_node(self, in1, in2, full_mode):
        in1 = as_tensor_variable(in1)
        in2 = as_tensor_variable(in2)
+        full_mode = as_scalar(full_mode)
 
-        assert in1.ndim == 1
-        assert in2.ndim == 1
+        if not (in1.ndim == 1 and in2.ndim == 1):
+            raise ValueError("Convolution inputs must be vector (ndim=1)")
+        if not full_mode.dtype == "bool":
+            raise ValueError("Convolution mode must be a boolean type")
 
         dtype = upcast(in1.dtype, in2.dtype)
-
         n = in1.type.shape[0]
         k = in2.type.shape[0]
+        match full_mode:
+            case Constant():
+                static_mode = "full" if full_mode.data else "valid"
+            case _:
+                static_mode = None
 
-        if n is None or k is None:
+        if n is None or k is None or static_mode is None:
             out_shape = (None,)
-        elif self.mode == "full":
+        elif static_mode == "full":
             out_shape = (n + k - 1,)
         else:  # mode == "valid":
             out_shape = (max(n, k) - min(n, k) + 1,)
 
         out = vector(dtype=dtype, shape=out_shape)
-        return Apply(self, [in1, in2], [out])
+        return Apply(self, [in1, in2, full_mode], [out])
 
     def perform(self, node, inputs, outputs):
         # We use numpy_convolve as that's what scipy would use if method="direct" was passed.
         # And mode != "same", which this Op doesn't cover anyway.
-        outputs[0][0] = numpy_convolve(*inputs, mode=self.mode)
+        in1, in2, full_mode = inputs
+        outputs[0][0] = numpy_convolve(in1, in2, mode="full" if full_mode else "valid")
 
     def infer_shape(self, fgraph, node, shapes):
-        in1_shape, in2_shape = shapes
+        _, _, full_mode = node.inputs
+        in1_shape, in2_shape, _ = shapes
         n = in1_shape[0]
         k = in2_shape[0]
-        if self.mode == "full":
-            shape = n + k - 1
-        else:  # mode == "valid":
-            shape = maximum(n, k) - minimum(n, k) + 1
+        shape_valid = maximum(n, k) - minimum(n, k) + 1
+        shape_full = n + k - 1
+        shape = switch(full_mode, shape_full, shape_valid)
         return [[shape]]
 
+    def connection_pattern(self, node):
+        return [[True], [True], [False]]
+
     def L_op(self, inputs, outputs, output_grads):
-        in1, in2 = inputs
+        in1, in2, full_mode = inputs
        [grad] = output_grads
 
-        if self.mode == "full":
-            valid_conv = type(self)(mode="valid")
-            in1_bar = valid_conv(grad, in2[::-1])
-            in2_bar = valid_conv(grad, in1[::-1])
+        n = in1.shape[0]
+        k = in2.shape[0]
 
-        else:  # mode == "valid":
-            full_conv = type(self)(mode="full")
-            n = in1.shape[0]
-            k = in2.shape[0]
-            kmn = maximum(0, k - n)
-            nmk = maximum(0, n - k)
-            # We need mode="full" if k >= n else "valid" for `in1_bar` (opposite for `in2_bar`), but mode is not symbolic.
-            # Instead, we always use mode="full" and slice the result so it behaves like "valid" for the input that's shorter.
-            # There is a rewrite that optimizes this case when n, k are static
-            in1_bar = full_conv(grad, in2[::-1])
-            in1_bar = in1_bar[kmn : in1_bar.shape[0] - kmn]
-            in2_bar = full_conv(grad, in1[::-1])
-            in2_bar = in2_bar[nmk : in2_bar.shape[0] - nmk]
-
-        return [in1_bar, in2_bar]
+        # If mode is "full", or mode is "valid" and k >= n, then in1_bar should use "valid" convolve
+        # The expression below is equivalent to ~(full_mode | (k >= n))
+        full_mode_in1_bar = ~full_mode & (k < n)
+        # If mode is "full", or mode is "valid" and n >= k, then in2_bar should use "valid" convolve
+        # The expression below is equivalent to ~(full_mode | (n >= k))
+        full_mode_in2_bar = ~full_mode & (n < k)
+
+        return [
+            self(grad, in2[::-1], full_mode_in1_bar),
+            self(grad, in1[::-1], full_mode_in2_bar),
+            DisconnectedType()(),
+        ]
 
     def c_code_cache_version(self):
-        return (1,)
+        return None  # (2,)
 
     def c_code(self, node, name, inputs, outputs, sub):
-        # raise NotImplementedError()
-        in1, in2 = inputs
+        in1, in2, full_mode = inputs
         [out] = outputs
-        mode_str = self.mode
-
-        if mode_str == "full":
-            np_mode_val = 2  # NPY_CONVOLVE_FULL
-        elif mode_str == "valid":
-            np_mode_val = 0  # NPY_CONVOLVE_VALID
-        else:
-            # This case should ideally be prevented by __init__ or make_node
-            raise ValueError(f"Unsupported mode {mode_str}")
 
         code = f"""
         {{
@@ -158,7 +152,7 @@ def c_code(self, node, name, inputs, outputs, sub):
 
             // TODO: Use lower level implementation that allows reusing the output buffer
             Py_XDECREF({out});
-            {out} = (PyArrayObject*) PyArray_Correlate2((PyObject*){in1}, (PyObject*)in2_flipped_view, {np_mode_val});
+            {out} = (PyArrayObject*) PyArray_Correlate2((PyObject*){in1}, (PyObject*)in2_flipped_view, {full_mode} ? 2 : 0);
             Py_XDECREF(in2_flipped_view);  // Clean up the view if correlate fails
             if (!{out}) {{
                 // PyArray_Correlate already set an error
@@ -169,6 +163,9 @@ def c_code(self, node, name, inputs, outputs, sub):
         return code
 
 
+blockwise_convolve_1d = Blockwise(Convolve1d())
+
+
 def convolve1d(
     in1: "TensorLike",
     in2: "TensorLike",
@@ -212,4 +209,5 @@ def convolve1d(
         )
         mode = "valid"
 
-    return cast(TensorVariable, Blockwise(Convolve1d(mode=mode))(in1, in2))
+    full_mode = as_scalar(np.bool_(mode == "full"))
+    return cast(TensorVariable, blockwise_convolve_1d(in1, in2, full_mode))
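The boolean expressions in `L_op` encode the adjoint relationship between the two modes: the gradient of a "full" convolution is a "valid" convolution with the flipped other input, while for a "valid" forward pass the mode of each gradient term depends on which input is longer. A standalone NumPy finite-difference check of the first case (not part of the commit):

```python
import numpy as np

n, k = 8, 3
rng = np.random.default_rng(0)
x, y = rng.normal(size=n), rng.normal(size=k)
g = rng.normal(size=n + k - 1)  # cotangent for a mode="full" output

# Claimed adjoint: d/dx [g . convolve(x, y, "full")] == convolve(g, y[::-1], "valid")
in1_bar = np.convolve(g, y[::-1], mode="valid")

eps = 1e-6
base = g @ np.convolve(x, y, mode="full")
num = np.array([
    (g @ np.convolve(x + eps * np.eye(n)[j], y, mode="full") - base) / eps
    for j in range(n)
])
assert np.allclose(num, in1_bar, atol=1e-4)
print(in1_bar.shape)  # (8,) -- same length as x, no slicing required
```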

tests/link/numba/signal/test_conv.py

Lines changed: 7 additions & 8 deletions
@@ -7,6 +7,7 @@
 from pytensor.tensor import dmatrix, tensor
 from pytensor.tensor.signal import convolve1d
 from tests.link.numba.test_basic import compare_numba_and_py
+from tests.tensor.signal.test_conv import convolve1d_grad_benchmarker
 
 
 pytestmark = pytest.mark.filterwarnings("error")
@@ -32,14 +33,7 @@ def test_convolve1d(x_smaller, mode):
 @pytest.mark.parametrize("mode", ("full", "valid"), ids=lambda x: f"mode={x}")
 @pytest.mark.parametrize("batch", (False, True), ids=lambda x: f"batch={x}")
 def test_convolve1d_benchmark(batch, mode, benchmark):
-    x = tensor(
-        shape=(
-            7,
-            183,
-        )
-        if batch
-        else (183,)
-    )
+    x = tensor(shape=(7, 183) if batch else (183,))
     y = tensor(shape=(7, 6) if batch else (6,))
     out = convolve1d(x, y, mode=mode)
     fn = function([x, y], out, mode="NUMBA", trust_input=True)
@@ -57,3 +51,8 @@ def test_convolve1d_benchmark(batch, mode, benchmark):
         np_convolve1d(x_test, y_test),
     )
     benchmark(fn, x_test, y_test)
+
+
+@pytest.mark.parametrize("convolve_mode", ["full", "valid"])
+def test_convolve1d_grad_benchmark_numba(convolve_mode, benchmark):
+    convolve1d_grad_benchmarker(convolve_mode, "NUMBA", benchmark)
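The new test reuses `convolve1d_grad_benchmarker` from `tests/tensor/signal/test_conv.py`, the fifth changed file, which is not shown in this excerpt. A hedged sketch of the kind of work such a benchmark times, compiling the convolution gradient under the Numba backend (the actual helper's shapes and setup may differ):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.signal import convolve1d

x = pt.vector("x", shape=(183,))
y = pt.vector("y", shape=(6,))
grads = pytensor.grad(convolve1d(x, y, mode="valid").sum(), wrt=[x, y])
fn = pytensor.function([x, y], grads, mode="NUMBA", trust_input=True)
fn(np.ones(183), np.ones(6))  # the benchmark fixture would time repeated calls
```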
