diff --git a/bitmat/bitlinear.py b/bitmat/bitlinear.py index fe1efc9..16f1f20 100644 --- a/bitmat/bitlinear.py +++ b/bitmat/bitlinear.py @@ -76,7 +76,7 @@ def forward(self, x): self.convert_weights_to_parameters() x_dtype = x.dtype x = self.norm(x.to(self.norm.weight.dtype)).to(x_dtype) - output = bitmat(self.weight.data, x, scale_w=self.scale_w) + output = bitmat(self.weight, x, scale_w=self.scale_w) if self.bias is not None: output += self.bias.unsqueeze(0).expand_as(output) return output diff --git a/bitmat/triton_kernels/rmsnorm_kernel.py b/bitmat/triton_kernels/rmsnorm_kernel.py index 7b3753f..66e57d2 100644 --- a/bitmat/triton_kernels/rmsnorm_kernel.py +++ b/bitmat/triton_kernels/rmsnorm_kernel.py @@ -166,11 +166,11 @@ def backward(ctx, dY): num_warps = ctx.num_warps, ) dX = dY.view(*shape) - return dX, None, None, None + return None, dX, None, None pass pass def fast_rms_layernorm(weight, X, eps, gemma = False): out = Fast_RMS_Layernorm.apply(X, weight, eps, gemma) return out -pass \ No newline at end of file +pass