Skip to content

Commit 339e1a5

Browse files
authored
Fix writing index variables for parallel batches (#123)
* Fix writing index variables for parallel batches Issue when running batches in parallel using a multi-process scheduler: model state is not returned by execution of dask delayed functions, so the state was lost. Fix by calling store.write_index_vars for each model run. It is not a big deal to write those datasets many times (should be same values across all simulations). * clean up
1 parent 9d7cbc0 commit 339e1a5

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

xsimlab/drivers.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,7 @@ def run_model(self):
328328
self.store.write_input_xr_dataset()
329329

330330
if self.batch_dim is None:
331-
model = self.model
332-
self._run_one_model(self.dataset, model, parallel=self.parallel)
331+
self._run_one_model(self.dataset, self.model, parallel=self.parallel)
333332

334333
else:
335334
ds_gby_batch = self.dataset.groupby(self.batch_dim)
@@ -348,8 +347,6 @@ def run_model(self):
348347
if self.parallel:
349348
dask.compute(futures, scheduler=self.scheduler)
350349

351-
self.store.write_index_vars(model=model)
352-
353350
def _run_one_model(self, dataset, model, batch=-1, parallel=False):
354351
"""Run one simulation.
355352
@@ -406,3 +403,5 @@ def _run_one_model(self, dataset, model, batch=-1, parallel=False):
406403
self.store.write_output_vars(batch, -1, model=model)
407404

408405
model.execute("finalize", rt_context, **execute_kwargs)
406+
407+
self.store.write_index_vars(model=model)

xsimlab/tests/test_xr_accessor.py

+4
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,10 @@ def test_run_batch_dim(self, dims, data, clock, parallel, scheduler):
477477
class P:
478478
in_var = xs.variable(dims=[(), "x"])
479479
out_var = xs.variable(dims=[(), "x"], intent="out")
480+
idx_var = xs.index(dims="x")
481+
482+
def initialize(self):
483+
self.idx_var = [0, 1]
480484

481485
def run_step(self):
482486
self.out_var = self.in_var * 2

0 commit comments

Comments
 (0)