Skip to content

Adjust clamping for rotated bboxes #9112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,9 +469,9 @@ def sample_position(values, max_value):
raise ValueError(f"Format {format} is not supported")
out_boxes = torch.stack(parts, dim=-1).to(dtype=dtype, device=device)
if tv_tensors.is_rotated_bounding_format(format):
# The rotated bounding boxes are not guaranteed to be within the canvas by design,
# so we apply clamping. We also add a 2 buffer to the canvas size to avoid
# numerical issues during the testing
# Rotated bounding boxes are not inherently confined within the canvas, so clamping is applied.
# Transform tests allow a 2-pixel tolerance relative to the canvas size.
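# To prevent discrepancies when clamping with different canvas sizes, we shrink the canvas used for
# clamping by a buffer of 4 pixels per dimension (twice the 2-pixel tolerance mentioned above).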
buffer = 4
out_boxes = clamp_bounding_boxes(
out_boxes, format=format, canvas_size=(canvas_size[0] - buffer, canvas_size[1] - buffer)
Expand Down
5 changes: 3 additions & 2 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4413,17 +4413,18 @@ def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, h
[0, 0, 1],
],
)
affine_matrix = (resize_affine_matrix @ crop_affine_matrix)[:2, :]

helper = (
reference_affine_rotated_bounding_boxes_helper
if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
else reference_affine_bounding_boxes_helper
)

bounding_boxes = helper(bounding_boxes, affine_matrix=crop_affine_matrix, new_canvas_size=(height, width))

return helper(
bounding_boxes,
affine_matrix=affine_matrix,
affine_matrix=resize_affine_matrix,
new_canvas_size=size,
)

Expand Down
15 changes: 8 additions & 7 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,8 +1104,9 @@ def _affine_bounding_boxes_with_expand(

original_shape = bounding_boxes.shape
dtype = bounding_boxes.dtype
need_cast = not bounding_boxes.is_floating_point()
bounding_boxes = bounding_boxes.float() if need_cast else bounding_boxes.clone()
acceptable_dtypes = [torch.float64] # Ensure consistency between CPU and GPU.
need_cast = dtype not in acceptable_dtypes
bounding_boxes = bounding_boxes.to(torch.float64) if need_cast else bounding_boxes.clone()
device = bounding_boxes.device
is_rotated = tv_tensors.is_rotated_bounding_format(format)
intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY
Expand Down Expand Up @@ -2397,19 +2398,19 @@ def elastic_bounding_boxes(

original_shape = bounding_boxes.shape
# TODO: first cast to float if bbox is int64 before convert_bounding_box_format
intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY
intermediate_format = tv_tensors.BoundingBoxFormat.CXCYWHR if is_rotated else tv_tensors.BoundingBoxFormat.XYXY

bounding_boxes = (
convert_bounding_box_format(bounding_boxes.clone(), old_format=format, new_format=intermediate_format)
).reshape(-1, 8 if is_rotated else 4)
).reshape(-1, 5 if is_rotated else 4)

id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)
# We construct an approximation of inverse grid as inv_grid = id_grid - displacement
# This is not an exact inverse of the grid
inv_grid = id_grid.sub_(displacement)

# Get points from bboxes
points = bounding_boxes if is_rotated else bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]]
points = bounding_boxes[:, :2] if is_rotated else bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]]
points = points.reshape(-1, 2)
if points.is_floating_point():
points = points.ceil_()
Expand All @@ -2421,8 +2422,8 @@ def elastic_bounding_boxes(
transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)

if is_rotated:
transformed_points = transformed_points.reshape(-1, 8)
out_bboxes = _parallelogram_to_bounding_boxes(transformed_points).to(bounding_boxes.dtype)
transformed_points = transformed_points.reshape(-1, 2)
out_bboxes = torch.cat([transformed_points, bounding_boxes[:, 2:]], dim=1).to(bounding_boxes.dtype)
else:
transformed_points = transformed_points.reshape(-1, 4, 2)
out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
Expand Down
60 changes: 29 additions & 31 deletions torchvision/transforms/v2/functional/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,23 +409,17 @@ def _order_bounding_boxes_points(
if indices is None:
output_xyxyxyxy = bounding_boxes.reshape(-1, 8)
x, y = output_xyxyxyxy[..., 0::2], output_xyxyxyxy[..., 1::2]
y_max = torch.max(y, dim=1, keepdim=True)[0]
_, x1 = ((y_max - y) / y_max + (x + 1) * 100).min(dim=1)
y_max = torch.max(y.abs(), dim=1, keepdim=True)[0]
_, x1 = (y / y_max + (x + 1) * 100).min(dim=1)
indices = torch.ones_like(output_xyxyxyxy)
indices[..., 0] = x1.mul(2)
indices.cumsum_(1).remainder_(8)
return indices, bounding_boxes.gather(1, indices.to(torch.int64))


def _area(box: torch.Tensor) -> torch.Tensor:
x1, y1, x2, y2, x3, y3, x4, y4 = box.reshape(-1, 8).unbind(-1)
w = torch.sqrt((y2 - y1) ** 2 + (x2 - x1) ** 2)
h = torch.sqrt((y3 - y2) ** 2 + (x3 - x2) ** 2)
return w * h


def _clamp_along_y_axis(
bounding_boxes: torch.Tensor,
canvas_size: tuple[int, int],
) -> torch.Tensor:
"""
Adjusts bounding boxes along the y-axis based on specific conditions.
Expand All @@ -448,29 +442,33 @@ def _clamp_along_y_axis(
b2 = y2 + x2 / a
b3 = y3 - a * x3
b4 = y4 + x4 / a
b23 = (b2 - b3) / 2 * a / (1 + a**2)
z = torch.zeros_like(b1)
case_a = torch.cat([x.unsqueeze(1) for x in [z, b1, x2, y2, x3, y3, x3 - x2, y3 + b1 - y2]], dim=1)
case_b = torch.cat([x.unsqueeze(1) for x in [z, b4, x2 - x1, y2 - y1 + b4, x3, y3, x4, y4]], dim=1)
case_c = torch.cat(
[x.unsqueeze(1) for x in [z, (b2 + b3) / 2, b23, -b23 / a + b2, x3, y3, b23, b23 * a + b3]], dim=1
c = a / (1 + a**2)
b1 = b2.clamp(0).clamp(b1, b3)
b4 = b3.clamp(max=canvas_size[0]).clamp(b2, b4)
case_a = torch.stack(
(
(b4 - b1) * c,
(b4 - b1) * c * a + b1,
(b2 - b1) * c,
(b1 - b2) * c / a + b2,
x3,
y3,
(b4 - b3) * c,
(b3 - b4) * c / a + b4,
),
dim=-1,
)
case_d = torch.zeros_like(case_c)
case_e = torch.cat([x.unsqueeze(1) for x in [x1.clamp(0), y1, x2.clamp(0), y2, x3, y3, x4, y4]], dim=1)

cond_a = (x1 < 0).logical_and(x2 >= 0).logical_and(x3 >= 0).logical_and(x4 >= 0)
cond_a = cond_a.logical_and(_area(case_a) > _area(case_b))
cond_a = cond_a.logical_or((x1 < 0).logical_and(x2 >= 0).logical_and(x3 >= 0).logical_and(x4 <= 0))
cond_b = (x1 < 0).logical_and(x2 >= 0).logical_and(x3 >= 0).logical_and(x4 >= 0)
cond_b = cond_b.logical_and(_area(case_a) <= _area(case_b))
cond_b = cond_b.logical_or((x1 < 0).logical_and(x2 <= 0).logical_and(x3 >= 0).logical_and(x4 >= 0))
cond_c = (x1 < 0).logical_and(x2 <= 0).logical_and(x3 >= 0).logical_and(x4 <= 0)
cond_d = (x1 < 0).logical_and(x2 <= 0).logical_and(x3 <= 0).logical_and(x4 <= 0)
cond_e = x1.isclose(x2)

case_b = bounding_boxes.clone()
case_b[..., 0].clamp_(0)
case_b[..., 6].clamp_(0)
case_c = torch.zeros_like(case_b)

cond_a = x1 < 0
cond_b = y1.isclose(y2, rtol=1e-05, atol=1e-05)
cond_c = (x1 <= 0).logical_and(x2 <= 0).logical_and(x3 <= 0).logical_and(x4 <= 0)
for cond, case in zip(
[cond_a, cond_b, cond_c, cond_d, cond_e],
[case_a, case_b, case_c, case_d, case_e],
[cond_a, cond_b, cond_c],
[case_a, case_b, case_c],
):
bounding_boxes = torch.where(cond.unsqueeze(1).repeat(1, 8), case.reshape(-1, 8), bounding_boxes)
return bounding_boxes.to(original_dtype).reshape(original_shape)
Expand Down Expand Up @@ -512,7 +510,7 @@ def _clamp_rotated_bounding_boxes(

for _ in range(4): # Iterate over the 4 vertices.
indices, out_boxes = _order_bounding_boxes_points(out_boxes)
out_boxes = _clamp_along_y_axis(out_boxes)
out_boxes = _clamp_along_y_axis(out_boxes, canvas_size)
_, out_boxes = _order_bounding_boxes_points(out_boxes, indices)
# rotate 90 degrees counterclockwise
out_boxes[:, ::2], out_boxes[:, 1::2] = (
Expand Down
Loading