diff --git a/roma/euler.py b/roma/euler.py index b1b2005..8af25e0 100644 --- a/roma/euler.py +++ b/roma/euler.py @@ -159,7 +159,7 @@ def unitquat_to_euler(convention : str, quat, degrees=False, epsilon=1e-7): # Compute second angle... angles = [torch.empty(N, device=quat.device, dtype=quat.dtype) for _ in range(3)] - angles[1] = 2 * torch.atan2(torch.hypot(c, d), torch.hypot(a, b)) + angles[1] = 2 * torch.atan2(roma.internal.hypot(c, d), roma.internal.hypot(a, b)) # ... and check if equal to is 0 or pi, causing a singularity case1 = torch.abs(angles[1]) <= epsilon diff --git a/roma/internal.py b/roma/internal.py index b3dee1b..eaac4e2 100644 --- a/roma/internal.py +++ b/roma/internal.py @@ -104,4 +104,12 @@ def norm(x, dim=None, keepdim=False): except AttributeError: # torch.linalg.norm was introduced in PyTorch 1.7, and torch.norm is deprecated. def norm(x, dim=None, keepdim=False): - return torch.norm(x, dim=dim, keepdim=keepdim) \ No newline at end of file + return torch.norm(x, dim=dim, keepdim=keepdim) + +try: + torch.hypot + hypot = torch.hypot +except AttributeError: + # torch.hypot is not available in PyTorch 1.6. + def hypot(x, y): + return torch.sqrt(torch.square(x) + torch.square(y))