From 11ec4ed4bc326f591fe20c9f206cc9763d5abde5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20H=C3=BCbotter?= Date: Wed, 28 Aug 2024 15:46:25 +0200 Subject: [PATCH] black --- afsl/acquisition_functions/__init__.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/afsl/acquisition_functions/__init__.py b/afsl/acquisition_functions/__init__.py index 9e03455..f84df9f 100644 --- a/afsl/acquisition_functions/__init__.py +++ b/afsl/acquisition_functions/__init__.py @@ -328,7 +328,9 @@ def compute_fn( indexed_dataset = _IndexedDataset(dataset) selected_indices = range(len(dataset)) - while len(selected_indices) > batch_size: # gradually shrinks size of selected batch, until the correct size is reached + while ( + len(selected_indices) > batch_size + ): # gradually shrinks size of selected batch, until the correct size is reached data_loader = DataLoader( indexed_dataset, batch_size=self.mini_batch_size, @@ -339,7 +341,9 @@ def compute_fn( selected_indices = [] selected_values = [] for data, idx in data_loader: - sub_idx, sub_val = self.select_from_minibatch(batch_size, model, data, device) + sub_idx, sub_val = self.select_from_minibatch( + batch_size, model, data, device + ) selected_indices.extend(idx[sub_idx].cpu().tolist()) selected_values.extend(sub_val.cpu().tolist()) if self.subsample: