Skip to content

Commit

Permalink
Merge pull request #56 from firedrakeproject/rckirby/nodat
Browse files Browse the repository at this point in the history
Remove dats from updates
  • Loading branch information
rckirby authored Oct 19, 2022
2 parents 1f5d7d6 + 51d4bd0 commit 73d3860
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 30 deletions.
33 changes: 18 additions & 15 deletions irksome/dirk_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,23 @@ def __init__(self):
pass

def __call__(self, u):
return u.dat.data_ro
return u


class BCCompOfNotMixedThingy:
def __init__(self, comp):
self.comp = comp

def __call__(self, u):
return u.dat.data_ro[:, self.comp]
return u[self.comp]


class BCMixedBitThingy:
def __init__(self, sub):
self.sub = sub

def __call__(self, u):
return u.dat[self.sub].data_ro
return u.sub(self.sub)


class BCCompOfMixedBitThingy:
Expand All @@ -39,7 +39,7 @@ def __init__(self, sub, comp):
self.comp = comp

def __call__(self, u):
return u.dat[self.sub].data_ro[:, self.comp]
return u.sub(self.sub)[self.comp]


def getThingy(V, bc):
Expand Down Expand Up @@ -119,7 +119,8 @@ def getFormDIRK(F, butch, t, dt, u0, bcs=None):
bcnew.append(new_bc)

dat4bc = getThingy(V, bc)
gblah.append((gdat, bcarg_stage, gmethod, dat4bc))
gdat2 = Function(gdat.function_space())
gblah.append((gdat, gdat2, bcarg_stage, gmethod, dat4bc))

return stage_F, (k, g, a, c), bcnew, gblah

Expand Down Expand Up @@ -182,6 +183,7 @@ def advance(self):
AA = bt.A
CC = bt.c
BB = bt.b
gsplit = g.split()
for i in range(self.num_stages):
# update a, c constants tucked into the variational problem
# for the current stage
Expand All @@ -191,24 +193,25 @@ def advance(self):
# variational form
g.assign(u0)
for j in range(i):
for (gd, kd) in zip(g.dat, ks[j].dat):
gd.data[:] += dtc * AA[i, j] * kd.data_ro[:]
ksplit = ks[j].split()
for (gbit, kbit) in zip(gsplit, ksplit):
gbit += dtc * float(AA[i, j]) * kbit

# update BC's for the variational problem
for (bc, (gdat, gcur, gmethod, dat4bc)) in zip(self.bcnew, self.gblah):
for (bc, (gdat, gdat2, gcur, gmethod, dat4bc)) in zip(self.bcnew, self.gblah):
# Evaluate the Dirichlet BC at the current stage time
gmethod(gdat, gcur)

# Now modify gdat based on the evolving solution
# subtract u0 from gdat
gdat.dat.data[:] -= dat4bc(u0)[:]
gmethod(gdat2, dat4bc(u0))
gdat -= gdat2

# Subtract previous stage values
for j in range(i):
gdat.dat.data[:] -= dtc * AA[i, j] * dat4bc(ks[j])[:]
gmethod(gdat2, dat4bc(ks[j]))
gdat -= dtc * float(AA[i, j]) * gdat2

# Rescale gdat
gdat.dat.data[:] /= dtc * AA[i, i]
gdat /= dtc * float(AA[i, i])

# solve new variational problem, stash the computed
# stage value.
Expand All @@ -223,5 +226,5 @@ def advance(self):

# update the solution with now-computed stage values.
for i in range(self.num_stages):
for (u0d, kd) in zip(u0.dat, ks[i].dat):
u0d.data[:] += dtc * BB[i] * kd.data_ro[:]
for (u0bit, kbit) in zip(u0.split(), ks[i].split()):
u0bit += dtc * float(BB[i]) * kbit
13 changes: 7 additions & 6 deletions irksome/imex.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,11 @@ def __init__(self, F, Fexp, butcher_tableau,
pop_parent(self.u0.function_space().dm, self.UU.function_space().dm)

num_fields = len(self.u0.function_space())
for i, u0dat in enumerate(u0.dat):
u0split = u0.split()
for i, u0bit in enumerate(u0split):
for s in range(self.num_stages):
ii = s * num_fields + i
self.UU_old_split[ii].dat.data[:] = u0dat.data_ro[:]
self.UU_old_split[ii].assign(u0bit)

for _ in range(num_its_initial):
self.iterate()
Expand All @@ -256,17 +257,17 @@ def iterate(self):
push_parent(self.u0.function_space().dm, self.UU.function_space().dm)
self.it_solver.solve()
pop_parent(self.u0.function_space().dm, self.UU.function_space().dm)
for uod, uud in zip(self.UU_old.dat, self.UU.dat):
uod.data[:] = uud.data_ro[:]
self.UU_old.assign(self.UU)

def propagate(self):
"""Moves the solution forward in time, to be followed by 0 or
more calls to `iterate`."""

ns = self.num_stages
nf = self.num_fields
for i, u0dat in enumerate(self.u0.dat):
u0dat.data[:] = self.UU_old_split[(ns-1)*nf + i].dat.data_ro[:]
u0split = self.u0.split()
for i, u0bit in enumerate(u0split):
u0bit.assign(self.UU_old_split[(ns-1)*nf + i])

for gdat, gcur, gmethod in self.bcdat:
gmethod(gdat, gcur)
Expand Down
11 changes: 7 additions & 4 deletions irksome/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,16 +358,19 @@ def _update_stiff_acc(self):
nf = self.num_fields
ns = self.num_stages

for i, u0d in enumerate(u0.dat):
u0d.data[:] = UUs[nf*(ns-1)+i].dat.data_ro[:]
u0bits = u0.split()
for i, u0bit in enumerate(u0bits):
u0bit.assign(UUs[nf*(ns-1)+i])

def _update_general(self):
(unew, Fupdate, update_bcs, update_bcs_gblah) = self.update_stuff
for gdat, gcur, gmethod in update_bcs_gblah:
gmethod(gdat, gcur)
self.update_solver.solve()
for u0d, und in zip(self.u0.dat, unew.dat):
u0d.data[:] = und.data_ro[:]
u0bits = self.u0.split()
unewbits = unew.split()
for u0bit, unewbit in zip(u0bits, unewbits):
u0bit.assign(unewbit)

def advance(self):
for gdat, gcur, gmethod in self.bcdat:
Expand Down
10 changes: 6 additions & 4 deletions irksome/stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,10 @@ def _update_general(self):
nf = self.num_fields

ws = self.ws
u0bits = u0.split()
for s in range(ns):
for i, u0d in enumerate(u0.dat):
u0d.data[:] += dtc * b[s] * ws[nf*s+i].dat.data_ro
for i, u0bit in enumerate(u0bits):
u0bit += dtc * float(b[s]) * ws[nf*s+i]

def _update_A2Tmb(self):
"""Assuming the algebraic problem for the RK stages has been
Expand All @@ -240,8 +241,9 @@ def _update_A2Tmb(self):
nf = self.num_fields

ws = self.ws
for i, u0d in enumerate(u0.dat):
u0d.data[:] += dtc * ws[nf*(ns-1)+i].dat.data_ro
u0bits = u0.split()
for i, u0bit in enumerate(u0bits):
u0bit += dtc * ws[nf*(ns-1)+i]

def advance(self):
"""Advances the system from time `t` to time `t + dt`.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_stokes.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def NSETest(butch, stage_type, splitting):
@pytest.mark.parametrize('butch', (LobattoIIIC, RadauIIA))
def test_Stokes(N, butch, time_stages, stage_type, splitting):
error = StokesTest(N, butch(time_stages), stage_type, splitting)
assert abs(error) < 2e-9
assert abs(error) < 4e-9


@pytest.mark.parametrize('stage_type', ("deriv", "value"))
Expand Down

0 comments on commit 73d3860

Please sign in to comment.