diff --git a/test/common_utils.py b/test/common_utils.py index 9da3cf52d1c..8ecfd81d3a0 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -469,9 +469,9 @@ def sample_position(values, max_value): raise ValueError(f"Format {format} is not supported") out_boxes = torch.stack(parts, dim=-1).to(dtype=dtype, device=device) if tv_tensors.is_rotated_bounding_format(format): - # The rotated bounding boxes are not guaranteed to be within the canvas by design, - # so we apply clamping. We also add a 2 buffer to the canvas size to avoid - # numerical issues during the testing + # Rotated bounding boxes are not inherently confined within the canvas, so clamping is applied. + # Transform tests allow a 2-pixel tolerance relative to the canvas size. + # To prevent discrepancies when clamping with different canvas sizes, we add a 2-pixel buffer. buffer = 4 out_boxes = clamp_bounding_boxes( out_boxes, format=format, canvas_size=(canvas_size[0] - buffer, canvas_size[1] - buffer) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 7e667586ac1..19b832a14bd 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -4413,7 +4413,6 @@ def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, h [0, 0, 1], ], ) - affine_matrix = (resize_affine_matrix @ crop_affine_matrix)[:2, :] helper = ( reference_affine_rotated_bounding_boxes_helper @@ -4421,9 +4420,11 @@ def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, h else reference_affine_bounding_boxes_helper ) + bounding_boxes = helper(bounding_boxes, affine_matrix=crop_affine_matrix, new_canvas_size=(height, width)) + return helper( bounding_boxes, - affine_matrix=affine_matrix, + affine_matrix=resize_affine_matrix, new_canvas_size=size, ) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 7e9766bdaf5..b28f2aced28 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1104,8 +1104,9 @@ def _affine_bounding_boxes_with_expand( original_shape = bounding_boxes.shape dtype = bounding_boxes.dtype - need_cast = not bounding_boxes.is_floating_point() - bounding_boxes = bounding_boxes.float() if need_cast else bounding_boxes.clone() + acceptable_dtypes = [torch.float64] # Ensure consistency between CPU and GPU. + need_cast = dtype not in acceptable_dtypes + bounding_boxes = bounding_boxes.to(torch.float64) if need_cast else bounding_boxes.clone() device = bounding_boxes.device is_rotated = tv_tensors.is_rotated_bounding_format(format) intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY @@ -2397,11 +2398,11 @@ def elastic_bounding_boxes( original_shape = bounding_boxes.shape # TODO: first cast to float if bbox is int64 before convert_bounding_box_format - intermediate_format = tv_tensors.BoundingBoxFormat.XYXYXYXY if is_rotated else tv_tensors.BoundingBoxFormat.XYXY + intermediate_format = tv_tensors.BoundingBoxFormat.CXCYWHR if is_rotated else tv_tensors.BoundingBoxFormat.XYXY bounding_boxes = ( convert_bounding_box_format(bounding_boxes.clone(), old_format=format, new_format=intermediate_format) - ).reshape(-1, 8 if is_rotated else 4) + ).reshape(-1, 5 if is_rotated else 4) id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype) # We construct an approximation of inverse grid as inv_grid = id_grid - displacement @@ -2409,7 +2410,7 @@ def elastic_bounding_boxes( inv_grid = id_grid.sub_(displacement) # Get points from bboxes - points = bounding_boxes if is_rotated else bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]] + points = bounding_boxes[:, :2] if is_rotated else bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]] points = points.reshape(-1, 2) if points.is_floating_point(): points = points.ceil_() @@ -2421,8 +2422,8 @@ def elastic_bounding_boxes( transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5) if is_rotated: - transformed_points = transformed_points.reshape(-1, 8) - out_bboxes = _parallelogram_to_bounding_boxes(transformed_points).to(bounding_boxes.dtype) + transformed_points = transformed_points.reshape(-1, 2) + out_bboxes = torch.cat([transformed_points, bounding_boxes[:, 2:]], dim=1).to(bounding_boxes.dtype) else: transformed_points = transformed_points.reshape(-1, 4, 2) out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1) diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 1729aa4bbaf..96ee69c46c0 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -409,23 +409,17 @@ def _order_bounding_boxes_points( if indices is None: output_xyxyxyxy = bounding_boxes.reshape(-1, 8) x, y = output_xyxyxyxy[..., 0::2], output_xyxyxyxy[..., 1::2] - y_max = torch.max(y, dim=1, keepdim=True)[0] - _, x1 = ((y_max - y) / y_max + (x + 1) * 100).min(dim=1) + y_max = torch.max(y.abs(), dim=1, keepdim=True)[0] + _, x1 = (y / y_max + (x + 1) * 100).min(dim=1) indices = torch.ones_like(output_xyxyxyxy) indices[..., 0] = x1.mul(2) indices.cumsum_(1).remainder_(8) return indices, bounding_boxes.gather(1, indices.to(torch.int64)) -def _area(box: torch.Tensor) -> torch.Tensor: - x1, y1, x2, y2, x3, y3, x4, y4 = box.reshape(-1, 8).unbind(-1) - w = torch.sqrt((y2 - y1) ** 2 + (x2 - x1) ** 2) - h = torch.sqrt((y3 - y2) ** 2 + (x3 - x2) ** 2) - return w * h - - def _clamp_along_y_axis( bounding_boxes: torch.Tensor, + canvas_size: tuple[int, int], ) -> torch.Tensor: """ Adjusts bounding boxes along the y-axis based on specific conditions. @@ -448,29 +442,33 @@ def _clamp_along_y_axis( b2 = y2 + x2 / a b3 = y3 - a * x3 b4 = y4 + x4 / a - b23 = (b2 - b3) / 2 * a / (1 + a**2) - z = torch.zeros_like(b1) - case_a = torch.cat([x.unsqueeze(1) for x in [z, b1, x2, y2, x3, y3, x3 - x2, y3 + b1 - y2]], dim=1) - case_b = torch.cat([x.unsqueeze(1) for x in [z, b4, x2 - x1, y2 - y1 + b4, x3, y3, x4, y4]], dim=1) - case_c = torch.cat( - [x.unsqueeze(1) for x in [z, (b2 + b3) / 2, b23, -b23 / a + b2, x3, y3, b23, b23 * a + b3]], dim=1 + c = a / (1 + a**2) + b1 = b2.clamp(0).clamp(b1, b3) + b4 = b3.clamp(max=canvas_size[0]).clamp(b2, b4) + case_a = torch.stack( + ( + (b4 - b1) * c, + (b4 - b1) * c * a + b1, + (b2 - b1) * c, + (b1 - b2) * c / a + b2, + x3, + y3, + (b4 - b3) * c, + (b3 - b4) * c / a + b4, + ), + dim=-1, ) - case_d = torch.zeros_like(case_c) - case_e = torch.cat([x.unsqueeze(1) for x in [x1.clamp(0), y1, x2.clamp(0), y2, x3, y3, x4, y4]], dim=1) - - cond_a = (x1 < 0).logical_and(x2 >= 0).logical_and(x3 >= 0).logical_and(x4 >= 0) - cond_a = cond_a.logical_and(_area(case_a) > _area(case_b)) - cond_a = cond_a.logical_or((x1 < 0).logical_and(x2 >= 0).logical_and(x3 >= 0).logical_and(x4 <= 0)) - cond_b = (x1 < 0).logical_and(x2 >= 0).logical_and(x3 >= 0).logical_and(x4 >= 0) - cond_b = cond_b.logical_and(_area(case_a) <= _area(case_b)) - cond_b = cond_b.logical_or((x1 < 0).logical_and(x2 <= 0).logical_and(x3 >= 0).logical_and(x4 >= 0)) - cond_c = (x1 < 0).logical_and(x2 <= 0).logical_and(x3 >= 0).logical_and(x4 <= 0) - cond_d = (x1 < 0).logical_and(x2 <= 0).logical_and(x3 <= 0).logical_and(x4 <= 0) - cond_e = x1.isclose(x2) - + case_b = bounding_boxes.clone() + case_b[..., 0].clamp_(0) + case_b[..., 6].clamp_(0) + case_c = torch.zeros_like(case_b) + + cond_a = x1 < 0 + cond_b = y1.isclose(y2, rtol=1e-05, atol=1e-05) + cond_c = (x1 <= 0).logical_and(x2 <= 0).logical_and(x3 <= 0).logical_and(x4 <= 0) for cond, case in zip( - [cond_a, cond_b, cond_c, cond_d, cond_e], - [case_a, case_b, case_c, case_d, case_e], + [cond_a, cond_b, cond_c], + [case_a, case_b, case_c], ): bounding_boxes = torch.where(cond.unsqueeze(1).repeat(1, 8), case.reshape(-1, 8), bounding_boxes) return bounding_boxes.to(original_dtype).reshape(original_shape) @@ -512,7 +510,7 @@ def _clamp_rotated_bounding_boxes( for _ in range(4): # Iterate over the 4 vertices. indices, out_boxes = _order_bounding_boxes_points(out_boxes) - out_boxes = _clamp_along_y_axis(out_boxes) + out_boxes = _clamp_along_y_axis(out_boxes, canvas_size) _, out_boxes = _order_bounding_boxes_points(out_boxes, indices) # rotate 90 degrees counter clock wise out_boxes[:, ::2], out_boxes[:, 1::2] = (