Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix sample-counting and resuming bugs #41

Merged
merged 8 commits into from
Nov 17, 2023
9 changes: 6 additions & 3 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ max-line-length = 120
max-complexity = 45
ignore =
E203
W503 # line break before binary operator; conflicts with black
E722 # bare except ok
E731 # lambda expressions ok
# line break before binary operator; conflicts with black
W503
# bare except ok
E722
# lambda expressions ok
E731
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't need to do these edits. But if they gave you problems locally, I don't see a reason to decline this

exclude =
.git
.tox
Expand Down
148 changes: 99 additions & 49 deletions PTMCMCSampler/PTMCMCSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def __init__(
resume=False,
seed=None,
):

# MPI initialization
self.comm = comm
self.MPIrank = self.comm.Get_rank()
Expand Down Expand Up @@ -204,11 +203,12 @@ def initialize(
self.neff = neff
self.tstart = 0

N = int(maxIter / thin)
N = int(maxIter / thin) + 1 # first sample + those we generate

self._lnprob = np.zeros(N)
self._lnlike = np.zeros(N)
self._chain = np.zeros((N, self.ndim))
self.ind_next_write = 0 # Next index in these arrays to write out
self.naccepted = 0
self.swapProposed = 0
self.nswap_accepted = 0
Expand Down Expand Up @@ -291,13 +291,27 @@ def initialize(
print("Resuming run from chain file {0}".format(self.fname))
try:
self.resumechain = np.loadtxt(self.fname)
self.resumeLength = self.resumechain.shape[0]
except ValueError:
print("WARNING: Cant read in file. Removing last line.")
os.system("sed -ie '$d' {0}".format(self.fname))
self.resumechain = np.loadtxt(self.fname)
self.resumeLength = self.resumechain.shape[0]
self.resumeLength = self.resumechain.shape[0] # Number of samples read from old chain
except ValueError as error:
print("Reading old chain files failed with error", error)
raise Exception("Couldn't read old chain to resume")
self._chainfile = open(self.fname, "a")
if (
self.isave != self.thin
and self.resumeLength % (self.isave / self.thin) != 1 # This special case is always OK
): # Initial sample plus blocks of isave/thin
raise Exception(
(
"Old chain has {0} rows, which is not the initial sample plus a multiple of isave/thin = {1}"
).format(self.resumeLength, self.isave // self.thin)
)
print(
"Resuming with",
self.resumeLength,
"samples from file representing",
(self.resumeLength - 1) * self.thin + 1,
"original samples",
)
else:
self._chainfile = open(self.fname, "w")
self._chainfile.close()
Expand All @@ -319,18 +333,40 @@ def updateChains(self, p0, lnlike0, lnprob0, iter):
self._lnprob[ind] = lnprob0

# write to file
if iter % self.isave == 0 and iter > 1 and iter > self.resumeLength:
if iter % self.isave == 0:
self.writeOutput(iter)

def writeOutput(self, iter):
"""
Write chains and covariance matrix. Called every isave on samples or at end.
"""
if iter // self.thin >= self.ind_next_write:
if self.writeHotChains or self.MPIrank == 0:
self._writeToFile(iter)

# write output covariance matrix
np.save(self.outDir + "/cov.npy", self.cov)
if self.MPIrank == 0 and self.verbose and iter > 1:
sys.stdout.write("\r")
sys.stdout.write(
"Finished %2.2f percent in %f s Acceptance rate = %g"
% (iter / self.Niter * 100, time.time() - self.tstart, self.naccepted / iter)
)
if iter > 0:
np.save(self.outDir + "/cov.npy", self.cov)

if self.MPIrank == 0 and self.verbose:
if iter > 0:
sys.stdout.write("\r")
percent = iter / self.Niter * 100 # Percent of total work finished
acceptance = self.naccepted / iter if iter > 0 else 0
elapsed = time.time() - self.tstart
if self.resume:
# Percentage of new work done
percentnew = (
(iter - self.resumeLength * self.thin) / (self.Niter - self.resumeLength * self.thin) * 100
)
sys.stdout.write(
"Finished %2.2f percent (%2.2f percent of new work) in %f s Acceptance rate = %g"
% (percent, percentnew, elapsed, acceptance)
)
else:
sys.stdout.write(
"Finished %2.2f percent in %f s Acceptance rate = %g" % (percent, elapsed, acceptance)
)
sys.stdout.flush()

def sample(
Expand Down Expand Up @@ -368,7 +404,7 @@ def sample(
@param Tmin: Minimum temperature in ladder (default=1)
@param Tmax: Maximum temperature in ladder (default=None)
@param Tskip: Number of steps between proposed temperature swaps (default=100)
@param isave: Number of iterations before writing to file (default=1000)
@param isave: Write to file every isave samples (default=1000)
@param covUpdate: Number of iterations between AM covariance updates (default=1000)
@param SCAMweight: Weight of SCAM jumps in overall jump cycle (default=20)
@param AMweight: Weight of AM jumps in overall jump cycle (default=20)
Expand All @@ -381,7 +417,7 @@ def sample(
@param burn: Burn in time (DE jumps added after this iteration) (default=10000)
@param maxIter: Maximum number of iterations for high temperature chains
(default=2*self.Niter)
@param self.thin: Save every self.thin MCMC samples
@param self.thin: MCMC Samples are recorded every self.thin samples
@param i0: Iteration to start MCMC (if i0 !=0, do not re-initialize)
@param neff: Number of effective samples to collect before terminating

Expand All @@ -393,6 +429,15 @@ def sample(
elif maxIter is None and self.MPIrank == 0:
maxIter = Niter

if isave % thin != 0:
raise ValueError("isave = %d is not a multiple of thin = %d" % (isave, thin))

if Niter % thin != 0:
print(
"Niter = %d is not a multiple of thin = %d. The last %d samples will be lost"
% (Niter, thin, Niter % thin)
)

# set up arrays to store lnprob, lnlike and chain
# if picking up from previous run, don't re-initialize
if i0 == 0:
Expand Down Expand Up @@ -426,28 +471,28 @@ def sample(
# if resuming, just start with first point in chain
if self.resume and self.resumeLength > 0:
p0, lnlike0, lnprob0 = self.resumechain[0, :-4], self.resumechain[0, -3], self.resumechain[0, -4]
self.ind_next_write = self.resumeLength
else:
# compute prior
lp = self.logp(p0)

if lp == float(-np.inf):

lnprob0 = -np.inf
lnlike0 = -np.inf

else:

lnlike0 = self.logl(p0)
lnprob0 = 1 / self.temp * lnlike0 + lp

# record first values
self.tstart = time.time()
self.updateChains(p0, lnlike0, lnprob0, i0)

self.comm.barrier()

# start iterations
iter = i0
self.tstart = time.time()

runComplete = False
Neff = 0
while runComplete is False:
Expand All @@ -456,7 +501,7 @@ def sample(
# call PTMCMCOneStep
p0, lnlike0, lnprob0 = self.PTMCMCOneStep(p0, lnlike0, lnprob0, iter)

# compute effective number of samples
# compute effective number of samples in cold chain
if iter % 1000 == 0 and iter > 2 * self.burn and self.MPIrank == 0:
try:
Neff = iter / max(
Expand All @@ -468,19 +513,21 @@ def sample(
Neff = 0
pass

# stop if reached maximum number of iterations
if self.MPIrank == 0 and iter >= self.Niter - 1:
if self.verbose:
print("\nRun Complete")
runComplete = True
# rank 0 decides whether to stop
if self.MPIrank == 0:
if iter >= self.Niter: # stop if reached maximum number of iterations
message = "\nRun Complete"
runComplete = True
elif int(Neff) > self.neff: # stop if reached maximum number of iterations
message = "\nRun Complete with {0} effective samples".format(int(Neff))
runComplete = True

# stop if reached effective number of samples
if self.MPIrank == 0 and int(Neff) > self.neff:
if self.verbose:
print("\nRun Complete with {0} effective samples".format(int(Neff)))
runComplete = True
runComplete = self.comm.bcast(runComplete, root=0) # rank 0 tells others whether to stop

runComplete = self.comm.bcast(runComplete, root=0)
if runComplete:
self.writeOutput(iter) # Possibly write partial block
if self.MPIrank == 0 and self.verbose:
print(message)

def PTMCMCOneStep(self, p0, lnlike0, lnprob0, iter):
"""
Expand Down Expand Up @@ -541,12 +588,17 @@ def PTMCMCOneStep(self, p0, lnlike0, lnprob0, iter):

# jump proposal ###

# if resuming, just use previous chain points
if self.resume and self.resumeLength > 0 and iter < self.resumeLength:
p0, lnlike0, lnprob0 = self.resumechain[iter, :-4], self.resumechain[iter, -3], self.resumechain[iter, -4]
# if resuming, just use previous chain points. Use each one thin times to compensate for
# thinning when they were written out
if self.resume and self.resumeLength > 0 and iter < self.resumeLength * self.thin:
p0, lnlike0, lnprob0 = (
self.resumechain[iter // self.thin, :-4],
self.resumechain[iter // self.thin, -3],
self.resumechain[iter // self.thin, -4],
)

# update acceptance counter
self.naccepted = iter * self.resumechain[iter, -2]
self.naccepted = iter * self.resumechain[iter // self.thin, -2]
else:
y, qxy, jump_name = self._jump(p0, iter)
self.jumpDict[jump_name][0] += 1
Expand All @@ -555,18 +607,15 @@ def PTMCMCOneStep(self, p0, lnlike0, lnprob0, iter):
lp = self.logp(y)

if lp == -np.inf:

newlnprob = -np.inf

else:

newlnlike = self.logl(y)
newlnprob = 1 / self.temp * newlnlike + lp

# hastings step
diff = newlnprob - lnprob0 + qxy
if diff > np.log(self.stream.random()):

# accept jump
p0, lnlike0, lnprob0 = y, newlnlike, newlnprob

Expand Down Expand Up @@ -664,32 +713,35 @@ def temperatureLadder(self, Tmin, Tmax=None, tstep=None):

def _writeToFile(self, iter):
"""
Function to write chain file. File has 3+ndim columns,
the first is log-posterior (unweighted), log-likelihood,
and acceptance probability, followed by parameter values.
Function to write chain file. File has ndim+4 columns,
appended to the parameter values are log-posterior (unnormalized),
log-likelihood, acceptance rate, and PT acceptance rate.
Rates are as of time of writing.

@param iter: Iteration of sampler

"""

self._chainfile = open(self.fname, "a+")
for jj in range((iter - self.isave), iter, self.thin):
ind = int(jj / self.thin)
# index 0 is the initial element. So after 10*thin iterations we need to write elements 1..10
write_end = iter // self.thin + 1 # First element not to write.
for ind in range(self.ind_next_write, write_end):
pt_acc = 1
if self.MPIrank < self.nchain - 1 and self.swapProposed != 0:
pt_acc = self.nswap_accepted / self.swapProposed

self._chainfile.write("\t".join(["%22.22f" % (self._chain[ind, kk]) for kk in range(self.ndim)]))
self._chainfile.write(
"\t%f\t%f\t%f\t%f\n" % (self._lnprob[ind], self._lnlike[ind], self.naccepted / iter, pt_acc)
"\t%f\t%f\t%f\t%f\n"
% (self._lnprob[ind], self._lnlike[ind], self.naccepted / iter if iter > 0 else 0, pt_acc)
)
self._chainfile.close()
self.ind_next_write = write_end # Ready for next write

# write jump statistics files ####

# only for T=1 chain
if self.MPIrank == 0:

# first write file contaning jump names and jump rates
fout = open(self.outDir + "/jumps.txt", "w")
njumps = len(self.propCycle)
Expand Down Expand Up @@ -726,7 +778,6 @@ def _updateRecursive(self, iter, mem):
diff = np.zeros(ndim)
it += 1
for jj in range(ndim):

diff[jj] = self._AMbuffer[ii, jj] - self.mu[jj]
self.mu[jj] += diff[jj] / it

Expand Down Expand Up @@ -917,7 +968,6 @@ def DEJump(self, x, iter, beta):
scale = self.stream.random() * 2.4 / np.sqrt(2 * ndim) * np.sqrt(1 / beta)

for ii in range(ndim):

# jump size
sigma = self._DEbuffer[mm, self.groups[jumpind][ii]] - self._DEbuffer[nn, self.groups[jumpind][ii]]

Expand Down
Loading