diff --git a/Tools/dea_tools/classification.py b/Tools/dea_tools/classification.py index 692119f6..a9d37e98 100644 --- a/Tools/dea_tools/classification.py +++ b/Tools/dea_tools/classification.py @@ -344,21 +344,29 @@ def _predict_func(model, input_xr, persist, proba, max_proba, clean, return_inpu out_proba = xr.DataArray( out_proba, coords={"x": x, "y": y}, dims=["y", "x"] ) + output_xr["Probabilities"] = out_proba else: print(" returning class probability array.") out_proba = out_proba * 100.0 - # Loop through each DataArray in the Dataset - for band_name in out_proba.data_vars: - reshaped_band = out_proba[band_name].values.reshape(len(y), len(x)) - reshaped_band = xr.DataArray( + + class_names = model.classes_ # Get the unique class names from the fitted classifier + + probabilities_dataset = xr.Dataset() + + # Loop through each class (band) + for i, class_name in enumerate(class_names): + reshaped_band = out_proba[:, i].reshape(len(y), len(x)) + reshaped_da = xr.DataArray( reshaped_band, coords={"x": x, "y": y}, dims=["y", "x"] ) - output_xr[out_proba] = reshaped_band + probabilities_dataset[f"prob_{class_name}"] = reshaped_da + + # merge in the probabilities + output_xr = xr.merge([output_xr, probabilities_dataset]) if clean == True: out_proba = da.where(da.isfinite(out_proba), out_proba, 0) - output_xr["Probabilities"] = out_proba if return_input == True: print(" input features...")