 from typing import TYPE_CHECKING, Literal, cast

+import numpy as np
 from numpy import convolve as numpy_convolve

-from pytensor.graph import Apply
+from pytensor.gradient import DisconnectedType
+from pytensor.graph import Apply, Constant
 from pytensor.link.c.op import COp
+from pytensor.scalar import as_scalar
 from pytensor.scalar.basic import upcast
 from pytensor.tensor.basic import as_tensor_variable, join, zeros
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.math import maximum, minimum
+from pytensor.tensor.math import maximum, minimum, switch
 from pytensor.tensor.type import vector
 from pytensor.tensor.variable import TensorVariable


 class Convolve1d(COp):
-    __props__ = ("mode",)
-    gufunc_signature = "(n),(k)->(o)"
+    __props__ = ()
+    gufunc_signature = "(n),(k),()->(o)"

-    def __init__(self, mode: Literal["full", "valid"] = "full"):
-        if mode not in ("full", "valid"):
-            raise ValueError(f"Invalid mode: {mode}")
-        self.mode = mode
-
-    def make_node(self, in1, in2):
+    def make_node(self, in1, in2, full_mode):
         in1 = as_tensor_variable(in1)
         in2 = as_tensor_variable(in2)
+        full_mode = as_scalar(full_mode)

-        assert in1.ndim == 1
-        assert in2.ndim == 1
+        if not (in1.ndim == 1 and in2.ndim == 1):
+            raise ValueError("Convolution inputs must be vector (ndim=1)")
+        if not full_mode.dtype == "bool":
+            raise ValueError("Convolution mode must be a boolean type")

         dtype = upcast(in1.dtype, in2.dtype)
-
         n = in1.type.shape[0]
         k = in2.type.shape[0]
+        match full_mode:
+            case Constant():
+                static_mode = "full" if full_mode.data else "valid"
+            case _:
+                static_mode = None

-        if n is None or k is None:
+        if n is None or k is None or static_mode is None:
             out_shape = (None,)
-        elif self.mode == "full":
+        elif static_mode == "full":
             out_shape = (n + k - 1,)
         else:  # mode == "valid":
             out_shape = (max(n, k) - min(n, k) + 1,)

         out = vector(dtype=dtype, shape=out_shape)
-        return Apply(self, [in1, in2], [out])
+        return Apply(self, [in1, in2, full_mode], [out])

     def perform(self, node, inputs, outputs):
         # We use numpy_convolve as that's what scipy would use if method="direct" was passed.
         # And mode != "same", which this Op doesn't cover anyway.
-        outputs[0][0] = numpy_convolve(*inputs, mode=self.mode)
+        in1, in2, full_mode = inputs
+        outputs[0][0] = numpy_convolve(in1, in2, mode="full" if full_mode else "valid")

     def infer_shape(self, fgraph, node, shapes):
-        in1_shape, in2_shape = shapes
+        _, _, full_mode = node.inputs
+        in1_shape, in2_shape, _ = shapes
         n = in1_shape[0]
         k = in2_shape[0]
-        if self.mode == "full":
-            shape = n + k - 1
-        else:  # mode == "valid":
-            shape = maximum(n, k) - minimum(n, k) + 1
+        shape_valid = maximum(n, k) - minimum(n, k) + 1
+        shape_full = n + k - 1
+        shape = switch(full_mode, shape_full, shape_valid)
         return [[shape]]

+    def connection_pattern(self, node):
+        return [[True], [True], [False]]
+
     def L_op(self, inputs, outputs, output_grads):
-        in1, in2 = inputs
+        in1, in2, full_mode = inputs
         [grad] = output_grads

-        if self.mode == "full":
-            valid_conv = type(self)(mode="valid")
-            in1_bar = valid_conv(grad, in2[::-1])
-            in2_bar = valid_conv(grad, in1[::-1])
+        n = in1.shape[0]
+        k = in2.shape[0]

-        else:  # mode == "valid":
-            full_conv = type(self)(mode="full")
-            n = in1.shape[0]
-            k = in2.shape[0]
-            kmn = maximum(0, k - n)
-            nmk = maximum(0, n - k)
-            # We need mode="full" if k >= n else "valid" for `in1_bar` (opposite for `in2_bar`), but mode is not symbolic.
-            # Instead, we always use mode="full" and slice the result so it behaves like "valid" for the input that's shorter.
-            # There is a rewrite that optimizes this case when n, k are static
-            in1_bar = full_conv(grad, in2[::-1])
-            in1_bar = in1_bar[kmn : in1_bar.shape[0] - kmn]
-            in2_bar = full_conv(grad, in1[::-1])
-            in2_bar = in2_bar[nmk : in2_bar.shape[0] - nmk]
-
-        return [in1_bar, in2_bar]
+        # If mode is "full", or mode is "valid" and k >= n, then in1_bar mode should use "valid" convolve
+        # The expression below is equivalent to ~(full_mode | (k >= n))
+        full_mode_in1_bar = ~full_mode & (k < n)
+        # If mode is "full", or mode is "valid" and n >= k, then in2_bar mode should use "valid" convolve
+        # The expression below is equivalent to ~(full_mode | (n >= k))
+        full_mode_in2_bar = ~full_mode & (n < k)
+
+        return [
+            self(grad, in2[::-1], full_mode_in1_bar),
+            self(grad, in1[::-1], full_mode_in2_bar),
+            DisconnectedType()(),
+        ]

     def c_code_cache_version(self):
-        return (1,)
+        return None  # (2,)

     def c_code(self, node, name, inputs, outputs, sub):
-        # raise NotImplementedError()
-        in1, in2 = inputs
+        in1, in2, full_mode = inputs
         [out] = outputs
-        mode_str = self.mode
-
-        if mode_str == "full":
-            np_mode_val = 2  # NPY_CONVOLVE_FULL
-        elif mode_str == "valid":
-            np_mode_val = 0  # NPY_CONVOLVE_VALID
-        else:
-            # This case should ideally be prevented by __init__ or make_node
-            raise ValueError(f"Unsupported mode {mode_str}")

         code = f"""
         {{
@@ -158,7 +152,7 @@ def c_code(self, node, name, inputs, outputs, sub):

         // TODO: Use lower level implementation that allows reusing the output buffer
         Py_XDECREF({out});
-        {out} = (PyArrayObject*) PyArray_Correlate2((PyObject*){in1}, (PyObject*)in2_flipped_view, {np_mode_val});
+        {out} = (PyArrayObject*) PyArray_Correlate2((PyObject*){in1}, (PyObject*)in2_flipped_view, {full_mode} ? 2 : 0);
         Py_XDECREF(in2_flipped_view);  // Clean up the view if correlate fails
         if (!{out}) {{
             // PyArray_Correlate already set an error
@@ -169,6 +163,9 @@ def c_code(self, node, name, inputs, outputs, sub):
         return code


+blockwise_convolve_1d = Blockwise(Convolve1d())
+
+
 def convolve1d(
     in1: "TensorLike",
     in2: "TensorLike",
@@ -212,4 +209,5 @@ def convolve1d(
         )
         mode = "valid"

-    return cast(TensorVariable, Blockwise(Convolve1d(mode=mode))(in1, in2))
+    full_mode = as_scalar(np.bool_(mode == "full"))
+    return cast(TensorVariable, blockwise_convolve_1d(in1, in2, full_mode))
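A minimal usage sketch of the reworked interface, assuming `convolve1d` is importable from `pytensor.tensor.signal` (the import path and variable names here are illustrative, not part of the diff). Because the mode is now a boolean scalar input rather than an `__props__` entry, the gradient graph reuses the same `Convolve1d` op with a symbolic mode instead of instantiating a second op:

```python
# Sketch only: assumes convolve1d is exposed from pytensor.tensor.signal.
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.signal import convolve1d

x = pt.vector("x")
y = pt.vector("y")

# mode is still passed as a string; convolve1d lowers it to a boolean scalar input.
out = convolve1d(x, y, mode="valid")
gx = pytensor.grad(out.sum(), x)  # L_op picks the gradient mode symbolically

fn = pytensor.function([x, y], [out, gx])
res, grad_res = fn([0.0, 1.0, 2.0, 3.0], [1.0, 1.0])
```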
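The shape and gradient-mode bookkeeping in the diff can be sanity-checked with plain NumPy and Python booleans; this is an illustrative check, not code from the change:

```python
import numpy as np

n, k = 7, 3
in1, in2 = np.arange(n, dtype="float64"), np.arange(k, dtype="float64")

# Output lengths encoded in make_node/infer_shape:
# "full" has n + k - 1 entries, "valid" has max(n, k) - min(n, k) + 1.
assert np.convolve(in1, in2, mode="full").shape[0] == n + k - 1
assert np.convolve(in1, in2, mode="valid").shape[0] == max(n, k) - min(n, k) + 1

# De Morgan identity behind the gradient-mode expressions in L_op:
# ~(full_mode | (k >= n)) is the same as ~full_mode & (k < n).
for full_mode in (True, False):
    for n_, k_ in ((5, 2), (2, 5), (3, 3)):
        assert (not (full_mode or (k_ >= n_))) == ((not full_mode) and (k_ < n_))
```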