From 2834b8706f980ba9c06663f4860f4f1d64e5d6a8 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Thu, 12 Oct 2023 22:06:40 -0700 Subject: [PATCH 1/9] fix overflow with softmax --- hls4ml/model/optimizer/__init__.py | 1 + .../passes/fix_softmax_table_size.py | 61 +++++++++++++++++++ test/pytest/test_softmax.py | 56 ++++++++--------- 3 files changed, 90 insertions(+), 28 deletions(-) create mode 100644 hls4ml/model/optimizer/passes/fix_softmax_table_size.py diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py index 2e9b197475..7c1518fdae 100644 --- a/hls4ml/model/optimizer/__init__.py +++ b/hls4ml/model/optimizer/__init__.py @@ -47,6 +47,7 @@ register_flow( 'optimize', [ + 'fix_softmax_table_size', 'eliminate_linear_activation', 'fuse_consecutive_batch_normalization', 'fuse_batch_normalization', diff --git a/hls4ml/model/optimizer/passes/fix_softmax_table_size.py b/hls4ml/model/optimizer/passes/fix_softmax_table_size.py new file mode 100644 index 0000000000..74735c1d3c --- /dev/null +++ b/hls4ml/model/optimizer/passes/fix_softmax_table_size.py @@ -0,0 +1,61 @@ +import warnings + +from hls4ml.model.layers import Layer, Softmax +from hls4ml.model.optimizer import OptimizerPass + + +class FixSoftmaxTableSize(OptimizerPass): + def match(self, node): + return isinstance(node, Softmax) + + def transform(self, model, node: Layer): + inp_layer = node.get_input_node() # type: ignore + if not isinstance(inp_layer, Layer): + raise RuntimeError(f'Softmax layer {node.name} does not have an input layer') + + input_bw: int = inp_layer.get_attr('result_t').precision.width # type: ignore + table_bw: int = node.get_attr('inv_table_t').precision.width # type: ignore + table_size = int(node.get_attr('table_size')) # type: ignore + + backend = model.config.config['Backend'] + + # Somehow, Intel want one extra bits for the table. + # I don't know why but if not simulation will crash with segmentation fault. + backend_limitation = -1 if backend == 'Quartus' else 0 + + if 2 ** (min(input_bw, table_bw) + backend_limitation) < table_size: + # If table size is too large w.r.t. input bitwidth and table bitwidth, + # reduce table size to avoid undefined behavior when cutting indices from, + # fixed point number. + node.set_attr('table_size', str(2 ** (min(input_bw, table_bw) + backend_limitation))) + if 2**input_bw < table_size: + # The warning message does not have to be looking like this, but you are asking + # 125 characters long line. + warnings.warn( + ( + f"Softmax layer {node.name} table size is too large for input" + "bitwidth {input_bw}. Setting table size to {2**input_bw}." + "To avoid this warning, please increase input bitwidth or" + "decrease table size." + ), + stacklevel=1, + ) + if 2**table_bw < table_size: + warnings.warn( + ( + f"Softmax layer {node.name} table size is too large for input" + "bitwidth {input_bw}. Setting table size to {2**input_bw}." + "To avoid this warning, please increase input bitwidth or" + "decrease table size." + ), + stacklevel=1, + ) + if backend == 'Quartus': + warnings.warn( + ( + "Quartus backend's table size is half of 2^min(input_bw-1,table_bw-1)" + " instead of 2^min(input_bw,table_bw)." + ), + stacklevel=1, + ) + return False diff --git a/test/pytest/test_softmax.py b/test/pytest/test_softmax.py index 15680638da..02a43903f8 100644 --- a/test/pytest/test_softmax.py +++ b/test/pytest/test_softmax.py @@ -10,50 +10,47 @@ test_root_path = Path(__file__).parent -def flat_distribution(shape): - return np.random.rand(*shape) - - -def high_accuracy_distribution(shape): - '''Start with a flat distribution, then pick a random member of each row to amplify''' - x = np.random.rand(*shape) - imax = np.random.randint(0, shape[1], size=shape[0]) - x[:, imax] *= 10 - return x +def normal_dist(shape): + return np.clip(np.random.normal(0, 8, shape), -32, 31) @pytest.fixture() -def generate_data(function, input_shape): - return function((1000, *input_shape)) +def generate_data(input_shape): + return normal_dist((1000, *input_shape)) @pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) -@pytest.mark.parametrize('strategy', ['stable', 'argmax']) +@pytest.mark.parametrize('strategy', ['stable', 'latency', 'argmax']) @pytest.mark.parametrize( - 'function,input_shape,io_type', + 'input_bits,input_shape,table_bits,io_type', [ - (flat_distribution, (8,), 'io_parallel'), - (high_accuracy_distribution, (8,), 'io_parallel'), - (flat_distribution, (8,), 'io_stream'), - (high_accuracy_distribution, (8,), 'io_stream'), - (flat_distribution, (8, 8, 3), 'io_stream'), - (high_accuracy_distribution, (8, 8, 3), 'io_stream'), + ('16,6', (8,), '18,8', 'io_parallel'), + ('16,6', (8,), '18,8', 'io_stream'), + ('16,6', (8,), '9,6', 'io_parallel'), + ('16,6', (8,), '9,6', 'io_stream'), + ('9,6', (8,), '18,8', 'io_parallel'), + ('9,6', (8,), '18,8', 'io_stream'), + ('16,6', (8, 8, 3), '18,8', 'io_stream'), ], ) -def test_softmax(backend, strategy, generate_data, input_shape, io_type, function): +def test_softmax(backend, strategy, generate_data, input_bits, input_shape, table_bits, io_type): X = generate_data model = tf.keras.models.Sequential() model.add(tf.keras.layers.Activation(input_shape=input_shape, activation='softmax', name='softmax')) model.compile() - f_type = 'ac_fixed<18,8,true,AC_RND,AC_SAT>' if backend == 'Quartus' else 'ap_fixed<18,8,AP_RND,AP_SAT>' + f_type = ( + f'ac_fixed<{table_bits},true,AC_RND,AC_SAT>' if backend == 'Quartus' else f'ap_fixed<{table_bits},AP_RND,AP_SAT>' + ) cfg = hls4ml.utils.config_from_keras_model(model, granularity='name') cfg['LayerName']['softmax']['Strategy'] = strategy cfg['LayerName']['softmax']['inv_table_t'] = f_type cfg['LayerName']['softmax']['exp_table_t'] = f_type + cfg['LayerName']['softmax_input']['Precision']['result'] = f'ap_fixed<{input_bits}>' - odir = str(test_root_path / 'hls4mlprj_softmax_{}_{}_{}_{}_{}').format( - backend, io_type, strategy, function.__name__, str(input_shape) + odir = str( + test_root_path + / f'hls4mlprj_softmax_{backend}_{io_type}_{strategy}_{input_shape}_input-bits={input_bits}_table-bits={table_bits}' ) hls_model = hls4ml.converters.convert_from_keras_model( model, hls_config=cfg, io_type=io_type, output_dir=odir, backend=backend @@ -92,7 +89,10 @@ def test_softmax_skipped(backend, io_type): assert len(hls_layers) == 2 # Verify hls4ml output is equal to Dense output - y_keras = model.predict(X) - y_hls4ml = hls_model.predict(X).reshape(y_keras.shape) - keras_trace = hls4ml.model.profiling.get_ymodel_keras(model, X) - np.testing.assert_allclose(y_hls4ml, keras_trace['dense'], rtol=0, atol=2e-2) + y_keras = model.layers[0](X).numpy() # type: ignore + y_hls4ml = hls_model.predict(X).reshape(y_keras.shape) # type: ignore + np.testing.assert_allclose(y_hls4ml, y_keras, rtol=0, atol=2e-2) + + +if __name__ == '__main__': + test_softmax('Quartus', 'stable', generate_data((8,)), '9,6', (8,), '9,6', 'io_parallel') From 04e2a9bed4d5505e841e167c6252c05f97ae4cf7 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Fri, 13 Oct 2023 15:27:09 -0700 Subject: [PATCH 2/9] fix warning message --- hls4ml/model/optimizer/passes/fix_softmax_table_size.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hls4ml/model/optimizer/passes/fix_softmax_table_size.py b/hls4ml/model/optimizer/passes/fix_softmax_table_size.py index 74735c1d3c..d683bc2e30 100644 --- a/hls4ml/model/optimizer/passes/fix_softmax_table_size.py +++ b/hls4ml/model/optimizer/passes/fix_softmax_table_size.py @@ -34,7 +34,7 @@ def transform(self, model, node: Layer): warnings.warn( ( f"Softmax layer {node.name} table size is too large for input" - "bitwidth {input_bw}. Setting table size to {2**input_bw}." + f"bitwidth {input_bw}. Setting table size to {2**input_bw}." "To avoid this warning, please increase input bitwidth or" "decrease table size." ), @@ -44,7 +44,7 @@ def transform(self, model, node: Layer): warnings.warn( ( f"Softmax layer {node.name} table size is too large for input" - "bitwidth {input_bw}. Setting table size to {2**input_bw}." + f"bitwidth {input_bw}. Setting table size to {2**input_bw}." "To avoid this warning, please increase input bitwidth or" "decrease table size." ), From 0a768518bc79640362d22b4d4a269e498f8c96b8 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Fri, 20 Oct 2023 08:23:27 -0700 Subject: [PATCH 3/9] move sources --- .../fpga}/passes/fix_softmax_table_size.py | 4 ++++ hls4ml/backends/quartus/quartus_backend.py | 7 ++++--- hls4ml/backends/vivado/vivado_backend.py | 1 + hls4ml/model/optimizer/__init__.py | 1 - test/pytest/test_softmax.py | 4 ---- 5 files changed, 9 insertions(+), 8 deletions(-) rename hls4ml/{model/optimizer => backends/fpga}/passes/fix_softmax_table_size.py (95%) diff --git a/hls4ml/model/optimizer/passes/fix_softmax_table_size.py b/hls4ml/backends/fpga/passes/fix_softmax_table_size.py similarity index 95% rename from hls4ml/model/optimizer/passes/fix_softmax_table_size.py rename to hls4ml/backends/fpga/passes/fix_softmax_table_size.py index d683bc2e30..4e04626d2e 100644 --- a/hls4ml/model/optimizer/passes/fix_softmax_table_size.py +++ b/hls4ml/backends/fpga/passes/fix_softmax_table_size.py @@ -59,3 +59,7 @@ def transform(self, model, node: Layer): stacklevel=1, ) return False + + +def register_softmax__table_size_fix(backend): + backend.register_pass('fix_softmax_table_size', FixSoftmaxTableSize) diff --git a/hls4ml/backends/quartus/quartus_backend.py b/hls4ml/backends/quartus/quartus_backend.py index 382cd40b7d..996d969110 100644 --- a/hls4ml/backends/quartus/quartus_backend.py +++ b/hls4ml/backends/quartus/quartus_backend.py @@ -72,6 +72,7 @@ def _register_flows(self): 'quartus:inplace_parallel_reshape', 'quartus:inplace_stream_flatten', 'quartus:skip_softmax', + 'quartus:fix_softmax_table_size', ] optimization_flow = register_flow('optimize', optimization_passes, requires=[init_flow], backend=self.name) @@ -332,7 +333,7 @@ def init_lstm(self, layer): name=f'weight_{weight_types[i]}', var_name=f'kernel_{weight_types[i]}_{{index}}', data=weights_data[ - 0 : layer.get_attr('n_in'), i * layer.get_attr('n_out') : (i + 1) * layer.get_attr('n_out') + 0: layer.get_attr('n_in'), i * layer.get_attr('n_out'): (i + 1) * layer.get_attr('n_out') ], quantizer=layer.get_attr('weight_quantizer'), compression=None, @@ -341,7 +342,7 @@ def init_lstm(self, layer): name=f'recurrent_weight_{weight_types[i]}', var_name=f'recurrent_kernel_{weight_types[i]}_{{index}}', data=rec_weights_data[ - 0 : layer.get_attr('n_out'), i * layer.get_attr('n_out') : (i + 1) * layer.get_attr('n_out') + 0: layer.get_attr('n_out'), i * layer.get_attr('n_out'): (i + 1) * layer.get_attr('n_out') ], quantizer=layer.get_attr('weight_quantizer'), compression=None, @@ -349,7 +350,7 @@ def init_lstm(self, layer): layer.add_weights_variable( name=f'bias_{weight_types[i]}', var_name=f'bias_{weight_types[i]}_{{index}}', - data=bias_data[i * layer.get_attr('n_out') : (i + 1) * (layer.get_attr('n_out'))], + data=bias_data[i * layer.get_attr('n_out'): (i + 1) * (layer.get_attr('n_out'))], quantizer=layer.get_attr('weight_quantizer'), compression=None, ) diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index 0f5dc5bd4d..c48ff574e5 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -108,6 +108,7 @@ def _register_flows(self): 'vivado:inplace_parallel_reshape', 'vivado:inplace_stream_flatten', 'vivado:skip_softmax', + 'vivado:fix_softmax_table_size', ] optimization_flow = register_flow('optimize', optimization_passes, requires=[init_flow], backend=self.name) diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py index 7c1518fdae..2e9b197475 100644 --- a/hls4ml/model/optimizer/__init__.py +++ b/hls4ml/model/optimizer/__init__.py @@ -47,7 +47,6 @@ register_flow( 'optimize', [ - 'fix_softmax_table_size', 'eliminate_linear_activation', 'fuse_consecutive_batch_normalization', 'fuse_batch_normalization', diff --git a/test/pytest/test_softmax.py b/test/pytest/test_softmax.py index 02a43903f8..5096bdb79d 100644 --- a/test/pytest/test_softmax.py +++ b/test/pytest/test_softmax.py @@ -92,7 +92,3 @@ def test_softmax_skipped(backend, io_type): y_keras = model.layers[0](X).numpy() # type: ignore y_hls4ml = hls_model.predict(X).reshape(y_keras.shape) # type: ignore np.testing.assert_allclose(y_hls4ml, y_keras, rtol=0, atol=2e-2) - - -if __name__ == '__main__': - test_softmax('Quartus', 'stable', generate_data((8,)), '9,6', (8,), '9,6', 'io_parallel') From fb64887c1a18ce4ac23ef19219f0b3ed28e9cbd5 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Fri, 20 Oct 2023 08:29:55 -0700 Subject: [PATCH 4/9] format not-my-code --- hls4ml/backends/quartus/quartus_backend.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hls4ml/backends/quartus/quartus_backend.py b/hls4ml/backends/quartus/quartus_backend.py index 996d969110..2c2c05916b 100644 --- a/hls4ml/backends/quartus/quartus_backend.py +++ b/hls4ml/backends/quartus/quartus_backend.py @@ -333,7 +333,7 @@ def init_lstm(self, layer): name=f'weight_{weight_types[i]}', var_name=f'kernel_{weight_types[i]}_{{index}}', data=weights_data[ - 0: layer.get_attr('n_in'), i * layer.get_attr('n_out'): (i + 1) * layer.get_attr('n_out') + 0 : layer.get_attr('n_in'), i * layer.get_attr('n_out') : (i + 1) * layer.get_attr('n_out') ], quantizer=layer.get_attr('weight_quantizer'), compression=None, @@ -342,7 +342,7 @@ def init_lstm(self, layer): name=f'recurrent_weight_{weight_types[i]}', var_name=f'recurrent_kernel_{weight_types[i]}_{{index}}', data=rec_weights_data[ - 0: layer.get_attr('n_out'), i * layer.get_attr('n_out'): (i + 1) * layer.get_attr('n_out') + 0 : layer.get_attr('n_out'), i * layer.get_attr('n_out') : (i + 1) * layer.get_attr('n_out') ], quantizer=layer.get_attr('weight_quantizer'), compression=None, @@ -350,7 +350,7 @@ def init_lstm(self, layer): layer.add_weights_variable( name=f'bias_{weight_types[i]}', var_name=f'bias_{weight_types[i]}_{{index}}', - data=bias_data[i * layer.get_attr('n_out'): (i + 1) * (layer.get_attr('n_out'))], + data=bias_data[i * layer.get_attr('n_out') : (i + 1) * (layer.get_attr('n_out'))], quantizer=layer.get_attr('weight_quantizer'), compression=None, ) From 01b8df6d100b47ff016f72ed8bdd88e1ed55a2a7 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 31 Oct 2023 10:20:41 -0700 Subject: [PATCH 5/9] test_softmax: cleanup, use long tain dist --- test/pytest/test_softmax.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/test/pytest/test_softmax.py b/test/pytest/test_softmax.py index 5096bdb79d..8c3f7370e7 100644 --- a/test/pytest/test_softmax.py +++ b/test/pytest/test_softmax.py @@ -8,15 +8,16 @@ import hls4ml test_root_path = Path(__file__).parent - - -def normal_dist(shape): - return np.clip(np.random.normal(0, 8, shape), -32, 31) +test_root_path = Path('/tmp/unit_test') @pytest.fixture() def generate_data(input_shape): - return normal_dist((1000, *input_shape)) + shape = (5000, *input_shape) + d = np.random.normal(0, 2, shape) + modify_entries = np.random.randint(0, 1, shape) < 0.05 + d[modify_entries] = d[modify_entries]*5+10 + return np.clip(d, -32, 31) @pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) @@ -46,7 +47,7 @@ def test_softmax(backend, strategy, generate_data, input_bits, input_shape, tabl cfg['LayerName']['softmax']['Strategy'] = strategy cfg['LayerName']['softmax']['inv_table_t'] = f_type cfg['LayerName']['softmax']['exp_table_t'] = f_type - cfg['LayerName']['softmax_input']['Precision']['result'] = f'ap_fixed<{input_bits}>' + cfg['LayerName']['softmax_input']['Precision']['result'] = f'fixed<{input_bits}>' odir = str( test_root_path @@ -70,9 +71,9 @@ def test_softmax(backend, strategy, generate_data, input_bits, input_shape, tabl @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) def test_softmax_skipped(backend, io_type): X = np.random.rand(100, 10) - model = tf.keras.models.Sequential() - model.add(tf.keras.layers.Dense(14, input_shape=(10,), name='dense')) - model.add(tf.keras.layers.Activation(activation='softmax', name='softmax')) + dense = tf.keras.layers.Dense(14, input_shape=(10,), name='dense') + softmax = tf.keras.layers.Activation(activation='softmax', name='softmax') + model = tf.keras.models.Sequential([dense,softmax]) model.compile() cfg = hls4ml.utils.config_from_keras_model(model, granularity='name') @@ -89,6 +90,6 @@ def test_softmax_skipped(backend, io_type): assert len(hls_layers) == 2 # Verify hls4ml output is equal to Dense output - y_keras = model.layers[0](X).numpy() # type: ignore - y_hls4ml = hls_model.predict(X).reshape(y_keras.shape) # type: ignore - np.testing.assert_allclose(y_hls4ml, y_keras, rtol=0, atol=2e-2) + y_keras_dense = dense(X).numpy() # type: ignore + y_hls4ml = hls_model.predict(X).reshape(y_keras_dense.shape) # type: ignore + np.testing.assert_allclose(y_hls4ml, y_keras_dense, rtol=0, atol=2e-2) From 0186f3a2a3ad88751d9eb315aaef8ee7f593305e Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 31 Oct 2023 10:44:07 -0700 Subject: [PATCH 6/9] format --- test/pytest/test_softmax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/pytest/test_softmax.py b/test/pytest/test_softmax.py index 8c3f7370e7..bd1aa03df0 100644 --- a/test/pytest/test_softmax.py +++ b/test/pytest/test_softmax.py @@ -16,7 +16,7 @@ def generate_data(input_shape): shape = (5000, *input_shape) d = np.random.normal(0, 2, shape) modify_entries = np.random.randint(0, 1, shape) < 0.05 - d[modify_entries] = d[modify_entries]*5+10 + d[modify_entries] = d[modify_entries] * 5 + 10 return np.clip(d, -32, 31) @@ -73,7 +73,7 @@ def test_softmax_skipped(backend, io_type): X = np.random.rand(100, 10) dense = tf.keras.layers.Dense(14, input_shape=(10,), name='dense') softmax = tf.keras.layers.Activation(activation='softmax', name='softmax') - model = tf.keras.models.Sequential([dense,softmax]) + model = tf.keras.models.Sequential([dense, softmax]) model.compile() cfg = hls4ml.utils.config_from_keras_model(model, granularity='name') From a0906497e0f75e0f320212ca7bb7b7f3c375253f Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 31 Oct 2023 13:40:43 -0700 Subject: [PATCH 7/9] cleanup v2 --- test/pytest/test_softmax.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/pytest/test_softmax.py b/test/pytest/test_softmax.py index bd1aa03df0..c71b279e24 100644 --- a/test/pytest/test_softmax.py +++ b/test/pytest/test_softmax.py @@ -8,7 +8,6 @@ import hls4ml test_root_path = Path(__file__).parent -test_root_path = Path('/tmp/unit_test') @pytest.fixture() @@ -40,9 +39,7 @@ def test_softmax(backend, strategy, generate_data, input_bits, input_shape, tabl model.add(tf.keras.layers.Activation(input_shape=input_shape, activation='softmax', name='softmax')) model.compile() - f_type = ( - f'ac_fixed<{table_bits},true,AC_RND,AC_SAT>' if backend == 'Quartus' else f'ap_fixed<{table_bits},AP_RND,AP_SAT>' - ) + f_type = 'fixed<{table_bits},true,RND,SAT>' cfg = hls4ml.utils.config_from_keras_model(model, granularity='name') cfg['LayerName']['softmax']['Strategy'] = strategy cfg['LayerName']['softmax']['inv_table_t'] = f_type From c7aa067eb9a4bd3bf682023839577504d52ed13d Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Tue, 31 Oct 2023 13:51:58 -0700 Subject: [PATCH 8/9] revert rm ap_/ac_ prefix --- test/pytest/test_softmax.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/pytest/test_softmax.py b/test/pytest/test_softmax.py index c71b279e24..6f6e193381 100644 --- a/test/pytest/test_softmax.py +++ b/test/pytest/test_softmax.py @@ -39,7 +39,10 @@ def test_softmax(backend, strategy, generate_data, input_bits, input_shape, tabl model.add(tf.keras.layers.Activation(input_shape=input_shape, activation='softmax', name='softmax')) model.compile() - f_type = 'fixed<{table_bits},true,RND,SAT>' + f_type = ( + f'ac_fixed<{table_bits},true,AC_RND,AC_SAT>' if backend == 'Quartus' else f'ap_fixed<{table_bits},AP_RND,AP_SAT>' + ) + cfg = hls4ml.utils.config_from_keras_model(model, granularity='name') cfg['LayerName']['softmax']['Strategy'] = strategy cfg['LayerName']['softmax']['inv_table_t'] = f_type From 3db1e836bbe0b8ad115c6c046593d0a8db9e85a3 Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Mon, 6 Nov 2023 22:39:35 +0100 Subject: [PATCH 9/9] Simplify table type names in softmax tests --- test/pytest/test_softmax.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/test/pytest/test_softmax.py b/test/pytest/test_softmax.py index 6f6e193381..3cab00745c 100644 --- a/test/pytest/test_softmax.py +++ b/test/pytest/test_softmax.py @@ -39,14 +39,12 @@ def test_softmax(backend, strategy, generate_data, input_bits, input_shape, tabl model.add(tf.keras.layers.Activation(input_shape=input_shape, activation='softmax', name='softmax')) model.compile() - f_type = ( - f'ac_fixed<{table_bits},true,AC_RND,AC_SAT>' if backend == 'Quartus' else f'ap_fixed<{table_bits},AP_RND,AP_SAT>' - ) + table_type = f'fixed<{table_bits}, RND, SAT>' cfg = hls4ml.utils.config_from_keras_model(model, granularity='name') cfg['LayerName']['softmax']['Strategy'] = strategy - cfg['LayerName']['softmax']['inv_table_t'] = f_type - cfg['LayerName']['softmax']['exp_table_t'] = f_type + cfg['LayerName']['softmax']['inv_table_t'] = table_type + cfg['LayerName']['softmax']['exp_table_t'] = table_type cfg['LayerName']['softmax_input']['Precision']['result'] = f'fixed<{input_bits}>' odir = str(