Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions hls4ml/converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def convert_from_keras_model(
output_data_tb=None,
backend='Vivado',
hls_config=None,
bit_exact=None,
**kwargs,
):
"""Convert Keras model to hls4ml model based on the provided configuration.
Expand Down Expand Up @@ -194,6 +195,10 @@ def convert_from_keras_model(
'io_parallel' or 'io_stream'. Defaults to 'io_parallel'.
hls_config (dict, optional): The HLS config.
kwargs** (dict, optional): Additional parameters that will be used to create the config of the specified backend
bit_exact (bool, optional): If True, enable model-wise precision propagation
with **only fixed-point data types**. If False, disable it. If None, enable it only
if there is at least one FixedPointQuantizer layer in the model (currently these only
result from converting HGQ1/2 models). Defaults to None.

Raises:
Exception: If precision and reuse factor are not present in 'hls_config'.
Expand All @@ -214,6 +219,7 @@ def convert_from_keras_model(

model_config = hls_config.get('Model', None)
config['HLSConfig']['Model'] = _check_model_config(model_config)
config['HLSConfig']['Model']['BitExact'] = bit_exact

_check_hls_config(config, hls_config)
if 'KerasModel' in config:
Expand Down Expand Up @@ -306,6 +312,7 @@ def convert_from_onnx_model(
output_data_tb=None,
backend='Vivado',
hls_config=None,
bit_exact=None,
**kwargs,
):
"""Convert ONNX model to hls4ml model based on the provided configuration.
Expand Down Expand Up @@ -335,6 +342,10 @@ def convert_from_onnx_model(
'io_parallel' or 'io_stream'. Defaults to 'io_parallel'.
hls_config (dict, optional): The HLS config.
kwargs** (dict, optional): Additional parameters that will be used to create the config of the specified backend
bit_exact (bool, optional): If True, enable model-wise precision propagation
with **only fixed-point data types**. If None, enable if there is at least one
FixedPointQuantizer layer in the model (only resulting from converting HGQ1/2
models for now). By default, None.

Raises:
Exception: If precision and reuse factor are not present in 'hls_config'.
Expand Down
1 change: 1 addition & 0 deletions hls4ml/converters/keras/qkeras.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def get_activation_quantizer(keras_layer, input_names, activation_name='activati
layer[activation_name] = activation_config['class_name'].replace('quantized_', '')

layer[f'{activation_name}_quantizer'] = activation_config
layer['trusted'] = True

return layer

Expand Down
44 changes: 43 additions & 1 deletion hls4ml/model/optimizer/passes/bit_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,30 @@ def _(layer: Reshape):
@_request_kif.register
def _(layer: Activation):
fn_name = layer.attributes.get('activation')

if layer.attributes.get('trusted', False):
result_t = layer.get_output_variable().type.precision
if fn_name in ('linear', 'relu'):
output_shape = get_output_shape(layer)
k, w, f = result_t.signed, result_t.width, result_t.fractional
i = w - k - f
k = np.full(output_shape, k, dtype=np.int8)
i = np.full(output_shape, i, dtype=np.int8)
f = np.full(output_shape, f, dtype=np.int8)
if result_t.rounding_mode == RoundingMode.RND:
f += 1
elif result_t.rounding_mode != RoundingMode.TRN:
f = np.full(output_shape, 126, dtype=np.int8)
if result_t.saturation_mode != SaturationMode.WRAP:
k = np.ones(output_shape, dtype=np.int8)
i = np.full(output_shape, 126, dtype=np.int8)
if fn_name == 'linear':
return ((k, i, f),)
else:
k = np.ones(output_shape, dtype=np.int8)
i = np.full(output_shape, 126, dtype=np.int8)
return ((k, i, f),)

if fn_name == 'linear':
return (requested_kif(layer),)
if fn_name == 'relu':
Expand Down Expand Up @@ -531,6 +555,16 @@ def _(layer: Concatenate):
@_produce_kif.register
def _(layer: Activation):
fn_name = layer.attributes['activation'].lower()
if layer.attributes.get('trusted', False):
output_shape = get_output_shape(layer)
result_t = layer.get_output_variable().type.precision
k, w, f = result_t.signed, result_t.width, result_t.fractional
i = w - k - f
k = np.full(output_shape, k, dtype=np.int8)
i = np.full(output_shape, i, dtype=np.int8)
f = np.full(output_shape, f, dtype=np.int8)
return k, i, f

k, i, f = get_input_kifs(layer)[0]

match fn_name:
Expand Down Expand Up @@ -603,6 +637,10 @@ def requested_by_non_saturating_quantizer(layer: Layer) -> bool:


def default_register_precision(layer: Layer):
if layer.attributes.get('trusted', False):
# Trusted layers have their precision already set
return

_pk, _pi, _pf = produce_kif(layer) # Maximum possible k,i,f output from this layer
_rk, _ri, _rf = requested_kif(layer) # Maximum possible k,i,f may be utilized by the next layer
_oi, _of = np.minimum(_pi, _ri), np.minimum(_pf, _rf)
Expand Down Expand Up @@ -791,7 +829,11 @@ def has_fixed_quantizer(self, model: 'ModelGraph'):
return True

def _match(self, model: 'ModelGraph'):
return self.has_fixed_quantizer(model)
enabled = model.config.config['HLSConfig']['Model'].get('BitExact', None)
if enabled is None:
# Enable by default if any FixedPointQuantizer is present
enabled = self.has_fixed_quantizer(model)
return enabled

def transform(self, model: 'ModelGraph'):
if not self._match(model):
Expand Down
29 changes: 17 additions & 12 deletions hls4ml/model/optimizer/passes/hgq_proxy_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np

from hls4ml.model.attributes import Attribute, TypeAttribute, WeightAttribute
from hls4ml.model.layers import Layer, Reshape, register_layer
from hls4ml.model.layers import Activation, Layer, Reshape, register_layer
from hls4ml.model.optimizer import OptimizerPass, register_pass
from hls4ml.model.types import FixedPrecisionType, UnspecifiedPrecisionType

Expand Down Expand Up @@ -77,11 +77,13 @@ def userconf_ifdef(key: str, layer_name: str, model):

class FuseFixedPointQuantizer(OptimizerPass):
def match(self, node: Layer):
if not isinstance(node, FixedPointQuantizer):
return False
if any(np.unique(x).size > 1 for x in node.mask_kbi):
return False
return True
if isinstance(node, FixedPointQuantizer):
return all(np.unique(x).size == 1 for x in node.mask_kbi)

if isinstance(node, Activation):
return node.get_attr('activation') == 'linear' and node.get_attr('trusted', False)

return False

def propagate(self, node: Layer, precision: FixedPrecisionType):
from hls4ml.model.optimizer.passes.bit_exact import get_input_layers, get_output_layers
Expand Down Expand Up @@ -113,13 +115,16 @@ def propagate(self, node: Layer, precision: FixedPrecisionType):
def transform(self, model: 'ModelGraph', node: FixedPointQuantizer):
from hls4ml.model.optimizer.passes.bit_exact import get_input_layers, get_output_layers

# Rounding and saturation for FixedPointQuantizer are applied in generated code, thus not reflected in result_t.
if node.RND == 'TRN' and node.SAT == 'WRAP':
precision: FixedPrecisionType = copy(node.get_output_variable().type.precision)
if isinstance(node, FixedPointQuantizer):
# Rounding and saturation for FixedPointQuantizer are applied in generated code, thus not reflected in result_t.
if node.RND == 'TRN' and node.SAT == 'WRAP':
precision: FixedPrecisionType = copy(node.get_output_variable().type.precision)
else:
k, b, i = node.mask_kbi
k, b, i = bool(k.ravel()[0]), max(int(b.ravel()[0]), 1), int(i.ravel()[0])
precision = FixedPrecisionType(b, i, k, node.RND, node.SAT)
else:
k, b, i = node.mask_kbi
k, b, i = bool(k.ravel()[0]), max(int(b.ravel()[0]), 1), int(i.ravel()[0])
precision = FixedPrecisionType(b, i, k, node.RND, node.SAT)
precision = copy(node.get_output_variable().type.precision)

inp_layer = get_input_layers(node)[0]
can_fuse = len(get_output_layers(inp_layer)) == 1
Expand Down
44 changes: 25 additions & 19 deletions test/pytest/test_qkeras.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

import numpy as np
import pytest
from keras.layers import BatchNormalization, Input
from keras.models import Model, Sequential, model_from_json
from keras.utils import to_categorical
from qkeras import QGRU, QLSTM, QSimpleRNN
from qkeras.qconv2d_batchnorm import QConv2DBatchnorm
from qkeras.qconvolutional import QDepthwiseConv2D, QSeparableConv1D, QSeparableConv2D
Expand All @@ -20,9 +23,6 @@
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from tensorflow.keras.layers import BatchNormalization, Input
from tensorflow.keras.models import Model, Sequential, model_from_json
from tensorflow.keras.utils import to_categorical

import hls4ml

Expand Down Expand Up @@ -142,33 +142,39 @@ def test_single_dense_activation_exact(randX_100_16, bits, alpha, backend, io_ty
bit exactness with number of bits parameter
'''
X = randX_100_16
model = Sequential()
model.add(
QDense(
16,
input_shape=(16,),
name='fc1',
kernel_quantizer=quantized_bits(bits, 0, alpha=alpha),
bias_quantizer=quantized_bits(bits, 0, alpha=1),
kernel_initializer='lecun_uniform',
)
model = Sequential(
[
QActivation(activation=quantized_bits(bits, 0, alpha=1), input_shape=(16,), name='inp_quant'),
QDense(
16,
name='fc1',
kernel_quantizer=quantized_bits(bits, 0, alpha=alpha),
bias_quantizer=quantized_bits(bits, 0, alpha=1),
kernel_initializer='lecun_uniform',
),
QActivation(activation=quantized_relu(bits, 0), name='relu1'),
]
)
model.add(QActivation(activation=quantized_relu(bits, 0), name='relu1'))
model.compile()

config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend=backend)
output_dir = str(test_root_path / f'hls4mlprj_qkeras_single_dense_activation_exact_{bits}_{alpha}_{backend}_{io_type}')

bit_exact = alpha == 1
# alpha!=po2 case uses non-fixed-point data types, unsupported by the precision propagation flow
hls_model = hls4ml.converters.convert_from_keras_model(
model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type
model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type, bit_exact=bit_exact
)
hls_model.compile()

y_qkeras = model.predict(X)
y_hls4ml = hls_model.predict(X)
# Goal is to get it passing with all equal
# np.testing.assert_array_equal(y_qkeras, y_hls4ml)
# For now allow matching within 1 bit
np.testing.assert_allclose(y_qkeras.ravel(), y_hls4ml.ravel(), atol=2**-bits, rtol=1.0)

# alpha!=1 case for weights can be supported if weight conversion is done before writing
if bit_exact:
np.testing.assert_array_equal(y_qkeras, y_hls4ml)
else:
np.testing.assert_allclose(y_qkeras.ravel(), y_hls4ml.ravel(), atol=2**-bits, rtol=1.0)


@pytest.fixture
Expand Down
Loading