Skip to content

Commit

Permalink
fix global pooling bw inference
Browse files Browse the repository at this point in the history
  • Loading branch information
calad0i committed Feb 6, 2025
1 parent a558b49 commit f9e22d5
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions hls4ml/model/optimizer/passes/bit_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,10 +407,7 @@ def _(layer: Conv1D | Conv2D):
@_produce_kif.register(GlobalPooling1D)
@_produce_kif.register(GlobalPooling2D)
def _(layer: Pooling1D | Pooling2D | GlobalPooling1D | GlobalPooling2D):
if isinstance(layer, (Pooling1D, GlobalPooling1D)):
px_shape = (layer.attributes['pool_width'],)
else:
px_shape = (layer.attributes['pool_height'], layer.attributes['pool_width'])
px_shape = _get_px_shape(layer)
ch_out = ch_in = layer.attributes['n_filt']

im2col_shape = *px_shape, ch_in, ch_out # conv kernel shape
Expand All @@ -432,6 +429,8 @@ def _(layer: Pooling1D | Pooling2D | GlobalPooling1D | GlobalPooling2D):
raise ValueError('Average pooling with non-power-of-2 pool size cannot be bit-exact')
f_out += int(f_add)

if isinstance(layer, (GlobalPooling1D, GlobalPooling2D)):
k_out, i_out, f_out = k_out[0], i_out[0], f_out[0]
return k_out, i_out, f_out


Expand Down Expand Up @@ -665,6 +664,22 @@ def _(node: UnaryLUT):
default_register_precision(node)


def _get_px_shape(node: Layer):
if isinstance(node, Pooling1D):
px_shape = (node.attributes['pool_width'],)
elif isinstance(node, GlobalPooling1D):
inp_shape = get_input_shapes(node)[0]
px_shape = (inp_shape[0],)
elif isinstance(node, Pooling2D):
px_shape = (node.attributes['pool_height'], node.attributes['pool_width'])
elif isinstance(node, GlobalPooling2D):
inp_shape = get_input_shapes(node)[0]
px_shape = (inp_shape[0], inp_shape[1])
else:
raise ValueError(f'Layer {node.class_name} is not supported for pooling precision derivation')
return px_shape


@register_precision.register(Pooling1D)
@register_precision.register(Pooling2D)
@register_precision.register(GlobalPooling1D)
Expand All @@ -674,10 +689,7 @@ def _(node: Pooling1D | Pooling2D | GlobalPooling1D | GlobalPooling2D):
pool_op = node.attributes['pool_op']
if pool_op != 'Average':
return
if isinstance(node, (Pooling1D, GlobalPooling1D)):
px_shape = (node.attributes['pool_width'],)
else:
px_shape = (node.attributes['pool_height'], node.attributes['pool_width'])
px_shape = _get_px_shape(node)
i_add = int(log2(prod(px_shape)))
node.attributes['accum_t'].precision.width += i_add
node.attributes['accum_t'].precision.integer += i_add
Expand Down

0 comments on commit f9e22d5

Please sign in to comment.