diff --git a/coremltools/models/neural_network/quantization_utils.py b/coremltools/models/neural_network/quantization_utils.py
index 8af017153..0178b7b5f 100644
--- a/coremltools/models/neural_network/quantization_utils.py
+++ b/coremltools/models/neural_network/quantization_utils.py
@@ -1688,7 +1688,7 @@ def quantize_weights(
         raise Exception("updatable models cannot get quantized to FP16.")
 
     qspec = _quantize_spec_weights(spec, nbits, qmode, **kwargs)
-    quantized_model = _get_model(qspec, compute_units=full_precision_model.compute_unit)
+    quantized_model = _get_model(qspec, compute_units=full_precision_model.compute_unit, **kwargs)
 
     if _macos_version() >= (10, 14) and sample_data:
         compare_models(full_precision_model, quantized_model, sample_data)
diff --git a/coremltools/models/utils.py b/coremltools/models/utils.py
index 97408611e..c84ef740e 100644
--- a/coremltools/models/utils.py
+++ b/coremltools/models/utils.py
@@ -356,7 +356,7 @@ def _convert_neural_network_weights_to_fp16(full_precision_model):
     return _get_model(_convert_neural_network_spec_weights_to_fp16(spec))
 
 
-def _get_model(spec, compute_units=_ComputeUnit.ALL):
+def _get_model(spec, compute_units=_ComputeUnit.ALL, **kwargs):
     """
     Utility to get the model and the data.
     """
@@ -365,7 +365,7 @@ def _get_model(spec, compute_units=_ComputeUnit.ALL):
     if isinstance(spec, MLModel):
         return spec
     else:
-        return MLModel(spec, compute_units=compute_units)
+        return MLModel(spec, compute_units=compute_units, **kwargs)
 
 def evaluate_regressor(model, data, target="target", verbose=False):
     """