Hi there,
I noticed a bug in JitTraceEnum_ELBO. My code runs fine with a previous version of PyTorch, or with JitTrace_ELBO (I can use RelaxedOneHotCategorical instead of OneHotCategorical for what I was enumerating). I don't personally need this fixed right now, and the bug is out of my depth to understand, but I figured I'd report it in case someone else hits the same problem.
The error seems to come from a TorchScript issue when computing the enumeration ELBO in pyro.infer.SVI:
315 def step(self, *args, **kwargs):
316 # Compute loss and gradients
317 with poutine.trace(param_only=True) as param_capture:
--> 318 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
320 loss_val = torch_item(loss)
321 self.losses.append(loss_val)
File /allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/utils/miniconda3/envs/pyro2/lib/python3.11/site-packages/pyro/infer/traceenum_elbo.py:564, in JitTraceEnum_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
563 def loss_and_grads(self, model, guide, *args, **kwargs):
--> 564 differentiable_loss = self.differentiable_loss(model, guide, *args, **kwargs)
565 differentiable_loss.backward() # this line triggers jit compilation
566 loss = differentiable_loss.item()
File /allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/utils/miniconda3/envs/pyro2/lib/python3.11/site-packages/pyro/infer/traceenum_elbo.py:561, in JitTraceEnum_ELBO.differentiable_loss(self, model, guide, *args, **kwargs)
557 return elbo * (-1.0 / self.num_particles)
559 self._differentiable_loss = differentiable_loss
--> 561 return self._differentiable_loss(*args, **kwargs)
File /allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/utils/miniconda3/envs/pyro2/lib/python3.11/site-packages/pyro/ops/jit.py:120, in CompiledFunction.__call__(self, *args, **kwargs)
118 with poutine.block(hide=self._param_names):
119 with poutine.trace(param_only=True) as param_capture:
--> 120 ret = self.compiled[key](*params_and_args)
122 for name in param_capture.trace.nodes.keys():
123 if name not in self._param_names:
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: default_program(23): error: extra text after expected end of number
aten_exp[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = expf(v - (tshift_1_1<-3.402823466385289e+38.f ? -3.402823466385289e+38.f : tshift_1_1));
^
default_program(23): error: extra text after expected end of number
aten_exp[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = expf(v - (tshift_1_1<-3.402823466385289e+38.f ? -3.402823466385289e+38.f : tshift_1_1));
^
2 errors detected in the compilation of "default_program".
nvrtc compilation failed:
#define NAN __int_as_float(0x7fffffff)
#define POS_INFINITY __int_as_float(0x7f800000)
#define NEG_INFINITY __int_as_float(0xff800000)
template<typename T>
__device__ T maximum(T a, T b) {
return isnan(a) ? a : (a > b ? a : b);
}
template<typename T>
__device__ T minimum(T a, T b) {
return isnan(a) ? a : (a < b ? a : b);
}
extern "C" __global__
void fused_clamp_sub_exp(float* tt_3, float* tshift_1, float* aten_exp) {
{
if ((long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)<45150ll ? 1 : 0) {
float tshift_1_1 = __ldg(tshift_1 + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
float v = __ldg(tt_3 + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
aten_exp[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = expf(v - (tshift_1_1<-3.402823466385289e+38.f ? -3.402823466385289e+38.f : tshift_1_1));
}}
}
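For context, the failing code path looks roughly like the following minimal sketch (a hypothetical model, guide, and data, not my actual code, and I haven't verified that this exact snippet reproduces the nvrtc failure on its own):

import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer import SVI, JitTraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

device = "cuda" if torch.cuda.is_available() else "cpu"
# made-up stand-in data; my real data and model are more involved
data = torch.cat([torch.randn(50) - 2.0, torch.randn(50) + 2.0]).to(device)

@config_enumerate
def model(data):
    probs = pyro.param("probs", torch.ones(2, device=data.device) / 2,
                       constraint=constraints.simplex)
    locs = pyro.param("locs", torch.tensor([-1.0, 1.0], device=data.device))
    with pyro.plate("data", data.shape[0]):
        # the site I was enumerating; with JitTrace_ELBO I can use
        # RelaxedOneHotCategorical here instead and the error goes away
        z = pyro.sample("z", dist.OneHotCategorical(probs))
        pyro.sample("obs", dist.Normal((z * locs).sum(-1), 1.0), obs=data)

def guide(data):
    pass  # all latent sites are enumerated out in the model

svi = SVI(model, guide, Adam({"lr": 0.01}),
          loss=JitTraceEnum_ELBO(max_plate_nesting=1))
for _ in range(100):
    svi.step(data)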
My environment is as follows:
Thanks for all the development work, pyro rules!

Thanks for the bug report. My guess is that this is an upstream bug in PyTorch code generation, where a floating point constant is emitted with two decimal points (-3.402823466385289e+38.f in the kernel above). I'm not sure what we can do but wait for an upstream fix.
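For reference, the fused kernel in the log corresponds to an op pattern roughly like the following sketch of the clamp/subtract/exp sequence (not the exact Pyro internals):

import torch

def clamp_sub_exp(t, shift):
    # Same clamp / subtract / exp sequence as the failing
    # "fused_clamp_sub_exp" kernel above. The clamp floor is float32 min,
    # the constant that the fuser prints with a stray '.' between the
    # exponent and the 'f' suffix ("-3.402823466385289e+38.f"), which is
    # what nvrtc rejects.
    floor = torch.finfo(torch.float32).min
    return torch.exp(t - shift.clamp(min=floor))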