diff --git a/irksome/dirk_stepper.py b/irksome/dirk_stepper.py index 7d7c480..fb87d66 100644 --- a/irksome/dirk_stepper.py +++ b/irksome/dirk_stepper.py @@ -14,7 +14,7 @@ def __init__(self): pass def __call__(self, u): - return u.dat.data_ro + return u class BCCompOfNotMixedThingy: @@ -22,7 +22,7 @@ def __init__(self, comp): self.comp = comp def __call__(self, u): - return u.dat.data_ro[:, self.comp] + return u[self.comp] class BCMixedBitThingy: @@ -30,7 +30,7 @@ def __init__(self, sub): self.sub = sub def __call__(self, u): - return u.dat[self.sub].data_ro + return u.sub(self.sub) class BCCompOfMixedBitThingy: @@ -39,7 +39,7 @@ def __init__(self, sub, comp): self.comp = comp def __call__(self, u): - return u.dat[self.sub].data_ro[:, self.comp] + return u.sub(self.sub)[self.comp] def getThingy(V, bc): @@ -119,7 +119,8 @@ def getFormDIRK(F, butch, t, dt, u0, bcs=None): bcnew.append(new_bc) dat4bc = getThingy(V, bc) - gblah.append((gdat, bcarg_stage, gmethod, dat4bc)) + gdat2 = Function(gdat.function_space()) + gblah.append((gdat, gdat2, bcarg_stage, gmethod, dat4bc)) return stage_F, (k, g, a, c), bcnew, gblah @@ -182,6 +183,7 @@ def advance(self): AA = bt.A CC = bt.c BB = bt.b + gsplit = g.split() for i in range(self.num_stages): # update a, c constants tucked into the variational problem # for the current stage @@ -191,24 +193,25 @@ def advance(self): # variational form g.assign(u0) for j in range(i): - for (gd, kd) in zip(g.dat, ks[j].dat): - gd.data[:] += dtc * AA[i, j] * kd.data_ro[:] + ksplit = ks[j].split() + for (gbit, kbit) in zip(gsplit, ksplit): + gbit += dtc * float(AA[i, j]) * kbit # update BC's for the variational problem - for (bc, (gdat, gcur, gmethod, dat4bc)) in zip(self.bcnew, self.gblah): + for (bc, (gdat, gdat2, gcur, gmethod, dat4bc)) in zip(self.bcnew, self.gblah): # Evaluate the Dirichlet BC at the current stage time gmethod(gdat, gcur) - # Now modify gdat based on the evolving solution - # subtract u0 from gdat - gdat.dat.data[:] -= dat4bc(u0)[:] + gmethod(gdat2, dat4bc(u0)) + gdat -= gdat2 # Subtract previous stage values for j in range(i): - gdat.dat.data[:] -= dtc * AA[i, j] * dat4bc(ks[j])[:] + gmethod(gdat2, dat4bc(ks[j])) + gdat -= dtc * float(AA[i, j]) * gdat2 # Rescale gdat - gdat.dat.data[:] /= dtc * AA[i, i] + gdat /= dtc * float(AA[i, i]) # solve new variational problem, stash the computed # stage value. @@ -223,5 +226,5 @@ def advance(self): # update the solution with now-computed stage values. for i in range(self.num_stages): - for (u0d, kd) in zip(u0.dat, ks[i].dat): - u0d.data[:] += dtc * BB[i] * kd.data_ro[:] + for (u0bit, kbit) in zip(u0.split(), ks[i].split()): + u0bit += dtc * float(BB[i]) * kbit diff --git a/irksome/imex.py b/irksome/imex.py index b27e2d1..8d03195 100644 --- a/irksome/imex.py +++ b/irksome/imex.py @@ -241,10 +241,11 @@ def __init__(self, F, Fexp, butcher_tableau, pop_parent(self.u0.function_space().dm, self.UU.function_space().dm) num_fields = len(self.u0.function_space()) - for i, u0dat in enumerate(u0.dat): + u0split = u0.split() + for i, u0bit in enumerate(u0split): for s in range(self.num_stages): ii = s * num_fields + i - self.UU_old_split[ii].dat.data[:] = u0dat.data_ro[:] + self.UU_old_split[ii].assign(u0bit) for _ in range(num_its_initial): self.iterate() @@ -256,8 +257,7 @@ def iterate(self): push_parent(self.u0.function_space().dm, self.UU.function_space().dm) self.it_solver.solve() pop_parent(self.u0.function_space().dm, self.UU.function_space().dm) - for uod, uud in zip(self.UU_old.dat, self.UU.dat): - uod.data[:] = uud.data_ro[:] + self.UU_old.assign(self.UU) def propagate(self): """Moves the solution forward in time, to be followed by 0 or @@ -265,8 +265,9 @@ def propagate(self): ns = self.num_stages nf = self.num_fields - for i, u0dat in enumerate(self.u0.dat): - u0dat.data[:] = self.UU_old_split[(ns-1)*nf + i].dat.data_ro[:] + u0split = self.u0.split() + for i, u0bit in enumerate(u0split): + u0bit.assign(self.UU_old_split[(ns-1)*nf + i]) for gdat, gcur, gmethod in self.bcdat: gmethod(gdat, gcur) diff --git a/irksome/stage.py b/irksome/stage.py index d538fba..36fe024 100644 --- a/irksome/stage.py +++ b/irksome/stage.py @@ -358,16 +358,19 @@ def _update_stiff_acc(self): nf = self.num_fields ns = self.num_stages - for i, u0d in enumerate(u0.dat): - u0d.data[:] = UUs[nf*(ns-1)+i].dat.data_ro[:] + u0bits = u0.split() + for i, u0bit in enumerate(u0bits): + u0bit.assign(UUs[nf*(ns-1)+i]) def _update_general(self): (unew, Fupdate, update_bcs, update_bcs_gblah) = self.update_stuff for gdat, gcur, gmethod in update_bcs_gblah: gmethod(gdat, gcur) self.update_solver.solve() - for u0d, und in zip(self.u0.dat, unew.dat): - u0d.data[:] = und.data_ro[:] + u0bits = self.u0.split() + unewbits = unew.split() + for u0bit, unewbit in zip(u0bits, unewbits): + u0bit.assign(unewbit) def advance(self): for gdat, gcur, gmethod in self.bcdat: diff --git a/irksome/stepper.py b/irksome/stepper.py index 67521ab..4518eca 100644 --- a/irksome/stepper.py +++ b/irksome/stepper.py @@ -224,9 +224,10 @@ def _update_general(self): nf = self.num_fields ws = self.ws + u0bits = u0.split() for s in range(ns): - for i, u0d in enumerate(u0.dat): - u0d.data[:] += dtc * b[s] * ws[nf*s+i].dat.data_ro + for i, u0bit in enumerate(u0bits): + u0bit += dtc * float(b[s]) * ws[nf*s+i] def _update_A2Tmb(self): """Assuming the algebraic problem for the RK stages has been @@ -240,8 +241,9 @@ def _update_A2Tmb(self): nf = self.num_fields ws = self.ws - for i, u0d in enumerate(u0.dat): - u0d.data[:] += dtc * ws[nf*(ns-1)+i].dat.data_ro + u0bits = u0.split() + for i, u0bit in enumerate(u0bits): + u0bit += dtc * ws[nf*(ns-1)+i] def advance(self): """Advances the system from time `t` to time `t + dt`. diff --git a/tests/test_stokes.py b/tests/test_stokes.py index cb9edc0..dd8262a 100644 --- a/tests/test_stokes.py +++ b/tests/test_stokes.py @@ -144,7 +144,7 @@ def NSETest(butch, stage_type, splitting): @pytest.mark.parametrize('butch', (LobattoIIIC, RadauIIA)) def test_Stokes(N, butch, time_stages, stage_type, splitting): error = StokesTest(N, butch(time_stages), stage_type, splitting) - assert abs(error) < 2e-9 + assert abs(error) < 4e-9 @pytest.mark.parametrize('stage_type', ("deriv", "value"))