 import numpy as np
 import tensorrt as trt
 import torch
+from torch._subclasses.fake_tensor import unset_fake_temporarily
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -32,21 +33,22 @@ def batch_norm(
     source_ir: Optional[SourceIR],
     name: str,
     input: trt.ITensor,
-    weight: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
-    bias: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
-    running_mean: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
-    running_var: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]],
-    training: bool,
     momentum: float,
     eps: float,
-    cudnn_enabled: bool,
     return_mean_rstd: bool,
+    weight: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]] = None,
+    bias: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]] = None,
+    running_mean: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]] = None,
+    running_var: Optional[Union[trt.ITensor, torch.Tensor, np.ndarray]] = None,
+    training: bool = False,
+    cudnn_enabled: bool = False,
 ) -> Union[trt.ITensor, Tuple[trt.ITensor, torch.Tensor, torch.Tensor]]:
     if has_dynamic_shape(input.shape):
         assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm."
 
     # Save the original output shape for later use
     output_shape = input.shape
+    feature_num = output_shape[1]
     # We perform constant folding for batch norm when the weight, bias, running_mean, and running_var are all tensors.
     # Batch norm operation can be fused into a single layer, which is more efficient than the original implementation.
     # In this way, the batch norm layer will be fused with the Convolution layer and get a performance boost.
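Note: the constant folding the comments above refer to collapses inference-time batch norm, y = (x - running_mean) / sqrt(running_var + eps) * weight + bias, into a single per-channel scale and shift. A minimal standalone sketch of that algebra (illustrative only, not this module's batch_norm_constant_folding helper):

import torch

def fold_batch_norm(weight, bias, running_mean, running_var, eps):
    # Per-channel affine equivalent of inference-mode batch norm: y = x * scale + shift
    scale = weight / torch.sqrt(running_var + eps)
    shift = bias - running_mean * scale
    return scale, shift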
@@ -59,26 +61,41 @@ def batch_norm(
         ]
     ):
         # We name the weight here according to the state_dict name
-        weight = (
-            get_trt_tensor(ctx, 1.0, f"{name}_weight", dtype=input.dtype)
-            if weight is None
-            else get_trt_tensor(ctx, weight, f"{name}_weight")
-        )
-        bias = (
-            get_trt_tensor(ctx, 0.0, f"{name}_bias", dtype=input.dtype)
-            if bias is None
-            else get_trt_tensor(ctx, bias, f"{name}_bias")
-        )
-        running_mean = (
-            get_trt_tensor(ctx, 0.0, f"{name}_running_mean", dtype=input.dtype)
-            if running_mean is None
-            else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
-        )
-        running_var = (
-            get_trt_tensor(ctx, 1.0, f"{name}_running_var", dtype=input.dtype)
-            if running_var is None
-            else get_trt_tensor(ctx, running_var, f"{name}_running_var")
-        )
+        with unset_fake_temporarily():
+            weight = (
+                get_trt_tensor(
+                    ctx, torch.ones((feature_num,)), f"{name}_weight", dtype=input.dtype
+                )
+                if weight is None
+                else get_trt_tensor(ctx, weight, f"{name}_weight")
+            )
+            bias = (
+                get_trt_tensor(
+                    ctx, torch.zeros((feature_num,)), f"{name}_bias", dtype=input.dtype
+                )
+                if bias is None
+                else get_trt_tensor(ctx, bias, f"{name}_bias")
+            )
+            running_mean = (
+                get_trt_tensor(
+                    ctx,
+                    torch.zeros((feature_num,)),
+                    f"{name}_running_mean",
+                    dtype=input.dtype,
+                )
+                if running_mean is None
+                else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
+            )
+            running_var = (
+                get_trt_tensor(
+                    ctx,
+                    torch.ones((feature_num,)),
+                    f"{name}_running_var",
+                    dtype=input.dtype,
+                )
+                if running_var is None
+                else get_trt_tensor(ctx, running_var, f"{name}_running_var")
+            )
 
         # eps_tensor for numerical stability
         eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps", dtype=input.dtype)
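Note: the unset_fake_temporarily guard matters because this converter can run while torch's FakeTensorMode is active during tracing, and the default torch.ones/torch.zeros parameters must be real tensors whose values can be baked into engine constants. A rough sketch of the pattern (make_default_scale is a hypothetical helper, not part of this module):

import torch
from torch._subclasses.fake_tensor import unset_fake_temporarily

def make_default_scale(feature_num: int) -> torch.Tensor:
    # Materialize a real (non-fake) tensor even if a FakeTensorMode is active,
    # so its data can later be copied into a TensorRT constant.
    with unset_fake_temporarily():
        return torch.ones((feature_num,))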
@@ -110,8 +127,7 @@ def batch_norm(
 
         # Reshape scale and bias_adjusted to match input shape for broadcasting
         expanded_shape = [1] * len(output_shape)
-        expanded_shape[1] = output_shape[1]  # Set channel dimension
-
+        expanded_shape[1] = feature_num  # Set channel dimension
         scale_reshape = impl.shuffle.reshape(
             ctx,
             target,
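As a concrete illustration of the broadcast shape built in the hunk above (shapes are hypothetical): for a 4D input of shape (N, C, H, W), expanded_shape becomes [1, C, 1, 1], so the reshaped per-channel scale and shift broadcast over the batch and spatial dimensions.

# Hypothetical shapes, for illustration only
output_shape = (8, 64, 32, 32)             # (N, C, H, W)
feature_num = output_shape[1]              # 64 channels
expanded_shape = [1] * len(output_shape)
expanded_shape[1] = feature_num            # -> [1, 64, 1, 1]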
@@ -143,21 +159,24 @@ def batch_norm(
         )
 
     else:
-        if weight is None:
-            weight = 1.0
+        with unset_fake_temporarily():
+            if weight is None:
+                weight = torch.ones((feature_num,))
 
-        if bias is None:
-            bias = 0.0
+            if bias is None:
+                bias = torch.zeros((feature_num,))
 
-        if running_mean is None:
-            running_mean = 0.0
+            if running_mean is None:
+                running_mean = torch.zeros((feature_num,))
 
-        if running_var is None:
-            running_var = 1.0
-        adjusted_scale, adjusted_bias = batch_norm_constant_folding(
-            weight, bias, running_mean, running_var, eps
-        )
-        power = torch.ones_like(adjusted_scale)
+            if running_var is None:
+                running_var = torch.ones((feature_num,))
+
+            power = torch.ones_like(weight)
+
+            adjusted_scale, adjusted_bias = batch_norm_constant_folding(
+                weight, bias, running_mean, running_var, eps
+            )
 
         adjusted_scale = to_trt_weights(
             ctx,
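Background for the else branch above: TensorRT's IScaleLayer computes (input * scale + shift) ** power per channel, which is why power is an all-ones tensor here; the folded scale and shift carry the entire normalization. A hedged sketch of adding such a layer (assumes an existing trt.INetworkDefinition named network and trt.Weights objects scale_w, shift_w, power_w; not the exact code in this converter):

# shift, scale, power are passed in that order; CHANNEL mode applies them per channel
layer = network.add_scale(input, trt.ScaleMode.CHANNEL, shift_w, scale_w, power_w)
out = layer.get_output(0)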
@@ -188,9 +207,7 @@ def batch_norm(
             source_ir=source_ir,
         )
 
-        output_shape = input.shape
         if len(input.shape) < 4:
-
             new_shape = (
                 (input.shape[0], input.shape[1], 1, 1)
                 if len(input.shape) == 2