diff --git a/test/test_ops.py b/test/test_ops.py index 3f0d8312c01..d8ea34e0e72 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -7,6 +7,7 @@ import numpy as np import pytest +import time import torch import torch.fx import torch.nn.functional as F @@ -615,6 +616,33 @@ def test_jit_boxes_list(self): model = PoolWrapper(ops.RoIAlign(output_size=[3, 3], spatial_scale=1.0, sampling_ratio=-1)) self._helper_jit_boxes_list(model) + @needs_mps + def test_performance_mps(self): + # Regression test for https://github.com/pytorch/pytorch/issues/124850 + execution_time_ms_threshold = 1000 # ms = 1 second + + num_imgs, n_channels, img_size = 1, 256, 200 + spatial_scale = 0.25 + output_size = 7 + sampling_ratio = 2 + aligned = False + dtype = torch.float32 + device = "mps" + + x = torch.randint(50, 100, size=(num_imgs, n_channels, img_size, img_size), dtype=dtype).to(device) + rois = self._make_rois(img_size, num_imgs, dtype).to(device) + + # Drain MPS work queued while building the inputs so it is not timed; perf_counter() is monotonic, unlike time.time(). + torch.mps.synchronize() + start = time.perf_counter() + _ = ops.roi_align(x, rois, output_size, spatial_scale, sampling_ratio, aligned) + torch.mps.synchronize() + end_execution = time.perf_counter() + execution_time_ms = 1000 * (end_execution - start) + + assert ( + execution_time_ms < execution_time_ms_threshold + ), f"Expected execution to take < {execution_time_ms_threshold} ms, actually took {execution_time_ms} ms" class TestPSRoIAlign(RoIOpTester): mps_backward_atol = 5e-2 diff --git a/torchvision/csrc/ops/mps/mps_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h index f85546a6c41..a672aa35e32 100644 --- a/torchvision/csrc/ops/mps/mps_kernels.h +++ b/torchvision/csrc/ops/mps/mps_kernels.h @@ -225,12 +225,12 @@ kernel void nms( \ uint2 tgid [[threadgroup_position_in_grid]], \ uint2 tid2 [[thread_position_in_threadgroup]]); -template +template kernel void roi_align( constant T * input [[buffer(0)]], constant T * rois [[buffer(1)]], device T * output [[buffer(2)]], - constant int64_t & output_size [[buffer(3)]], + constant float & spatial_scale 
[[buffer(3)]], constant int64_t & channels [[buffer(4)]], constant int64_t & height [[buffer(5)]], constant int64_t & width [[buffer(6)]], @@ -238,92 +238,83 @@ kernel void roi_align( constant int64_t & pooled_width [[buffer(8)]], constant int64_t & sampling_ratio [[buffer(9)]], constant bool & aligned [[buffer(10)]], - constant float & spatial_scale [[buffer(11)]], - uint2 tgid [[threadgroup_position_in_grid]], - uint2 tptg [[threads_per_threadgroup]], - uint2 tid2 [[thread_position_in_threadgroup]]){ - MPS_1D_KERNEL_LOOP(index, output_size, 1) { - // (n, c, ph, pw) is an element in the pooled output - integer_t pw = index % pooled_width; - integer_t ph = (index / pooled_width) % pooled_height; - integer_t c = (index / pooled_width / pooled_height) % channels; - integer_t n = index / pooled_width / pooled_height / channels; - - constant T* offset_rois = rois + n * 5; - integer_t roi_batch_ind = offset_rois[0]; - - // Do not using rounding; this implementation detail is critical - T offset = aligned ? (T)0.5 : (T)0.0; - T roi_start_w = offset_rois[1] * spatial_scale - offset; - T roi_start_h = offset_rois[2] * spatial_scale - offset; - T roi_end_w = offset_rois[3] * spatial_scale - offset; - T roi_end_h = offset_rois[4] * spatial_scale - offset; - - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; - if (!aligned) { - // Force malformed ROIs to be 1x1 - roi_width = max(roi_width, (T)1.); - roi_height = max(roi_height, (T)1.); - } - - T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - constant T* offset_input = - input + (roi_batch_ind * channels + c) * height * width; - - // We use roi_bin_grid to sample the grid and mimic integral - integer_t roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : ceil(roi_height / pooled_height); // e.g., = 2 - integer_t roi_bin_grid_w = - (sampling_ratio > 0) ? 
sampling_ratio : ceil(roi_width / pooled_width); - - // We do average (integral) pooling inside a bin - // When the grid is empty, output zeros. - const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast(1)); // e.g. = 4 - - T output_val = 0.; - for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 - { - const T y = roi_start_h + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 - for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) { - const T x = roi_start_w + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); + uint index [[thread_position_in_grid]]) +{ + // Decode linear index into (n, c, ph, pw) + int64_t pw = index % pooled_width; + int64_t ph = (index / pooled_width) % pooled_height; + int64_t c = (index / pooled_width / pooled_height) % channels; + int64_t n = index / (pooled_width * pooled_height * channels); + + constant T* offset_rois = rois + n * 5; + int64_t roi_batch_ind = static_cast(offset_rois[0]); + + // Do not use rounding; this implementation detail is critical + T offset = aligned ? 
static_cast(0.5) : static_cast(0.0); + T roi_start_w = offset_rois[1] * spatial_scale - offset; + T roi_start_h = offset_rois[2] * spatial_scale - offset; + T roi_end_w = offset_rois[3] * spatial_scale - offset; + T roi_end_h = offset_rois[4] * spatial_scale - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = max(roi_width, static_cast(1.0)); + roi_height = max(roi_height, static_cast(1.0)); + } - T val = bilinear_interpolate(offset_input, height, width, y, x, index); - output_val += val; - } + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + constant T* offset_input = input + (roi_batch_ind * channels + c) * height * width; + + // We use roi_bin_grid to sample the grid and mimic integral + int64_t roi_bin_grid_h = sampling_ratio > 0 + ? sampling_ratio + : static_cast(ceil(roi_height / static_cast(pooled_height))); + int64_t roi_bin_grid_w = sampling_ratio > 0 + ? sampling_ratio + : static_cast(ceil(roi_width / static_cast(pooled_width))); + + // We do average (integral) pooling inside a bin + // When the grid is empty, output zeros. 
+ const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast(1)); + T output_val = static_cast(0.0); + + for (int64_t iy = 0; iy < roi_bin_grid_h; iy++) { + T y = roi_start_h + static_cast(ph) * bin_size_h + + (static_cast(iy) + static_cast(0.5)) * bin_size_h / static_cast(roi_bin_grid_h); + for (int64_t ix = 0; ix < roi_bin_grid_w; ix++) { + T x = roi_start_w + static_cast(pw) * bin_size_w + + (static_cast(ix) + static_cast(0.5)) * bin_size_w / static_cast(roi_bin_grid_w); + + T val = bilinear_interpolate(offset_input, height, width, y, x, index); + output_val += val; } - output_val /= count; - - output[index] = output_val; } + + output_val /= count; + output[index] = output_val; } -#define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \ -template \ -[[host_name("roi_align_" #DTYPE)]] \ -kernel void roi_align( \ - constant DTYPE * input [[buffer(0)]], \ - constant DTYPE * rois [[buffer(1)]], \ - device DTYPE * output [[buffer(2)]], \ - constant int64_t & output_size [[buffer(3)]], \ - constant int64_t & channels [[buffer(4)]], \ - constant int64_t & height [[buffer(5)]], \ - constant int64_t & width [[buffer(6)]], \ - constant int64_t & pooled_height [[buffer(7)]], \ - constant int64_t & pooled_width [[buffer(8)]], \ - constant int64_t & sampling_ratio [[buffer(9)]], \ - constant bool & aligned [[buffer(10)]], \ - constant float & spatial_scale [[buffer(11)]], \ - uint2 tgid [[threadgroup_position_in_grid]], \ - uint2 tptg [[threads_per_threadgroup]], \ - uint2 tid2 [[thread_position_in_threadgroup]]); +#define REGISTER_ROI_ALIGN_OP(DTYPE) \ +template \ +[[host_name("roi_align_" #DTYPE)]] \ +kernel void roi_align( \ + constant DTYPE * input [[buffer(0)]], \ + constant DTYPE * rois [[buffer(1)]], \ + device DTYPE * output [[buffer(2)]], \ + constant float & spatial_scale [[buffer(3)]], \ + constant int64_t & channels [[buffer(4)]], \ + constant int64_t & height [[buffer(5)]], \ + constant int64_t & width [[buffer(6)]], \ + constant int64_t & pooled_height 
[[buffer(7)]], \ + constant int64_t & pooled_width [[buffer(8)]], \ + constant int64_t & sampling_ratio [[buffer(9)]], \ + constant bool & aligned [[buffer(10)]], \ + uint index [[thread_position_in_grid]]); template kernel void roi_align_backward( @@ -1005,7 +996,7 @@ kernel void ps_roi_pool_backward( \ constant int64_t & width [[buffer(7)]], \ constant int64_t & pooled_height [[buffer(8)]], \ constant int64_t & pooled_width [[buffer(9)]], \ - constant int64_t & channels_out [[buffer(10)]], \ + constant int64_t & channels_out [[buffer(10)]], \ constant float & spatial_scale [[buffer(11)]], \ uint2 tgid [[threadgroup_position_in_grid]], \ uint2 tptg [[threads_per_threadgroup]], \ @@ -1013,8 +1004,8 @@ kernel void ps_roi_pool_backward( \ REGISTER_NMS_OP(float); REGISTER_NMS_OP(half); -REGISTER_ROI_ALIGN_OP(float, int64_t); -REGISTER_ROI_ALIGN_OP(half, int64_t); +REGISTER_ROI_ALIGN_OP(float); +REGISTER_ROI_ALIGN_OP(half); REGISTER_ROI_ALIGN_BACKWARD_OP(float, int64_t); REGISTER_ROI_ALIGN_BACKWARD_OP(half, int64_t); REGISTER_ROI_POOL_OP(float, int64_t); diff --git a/torchvision/csrc/ops/mps/roi_align_kernel.mm b/torchvision/csrc/ops/mps/roi_align_kernel.mm index d4ed8b43fd2..6cf07761060 100644 --- a/torchvision/csrc/ops/mps/roi_align_kernel.mm +++ b/torchvision/csrc/ops/mps/roi_align_kernel.mm @@ -51,14 +51,14 @@ dispatch_sync(mpsStream->queue(), ^() { @autoreleasepool { id computeEncoder = mpsStream->commandEncoder(); - MTLSize threadgroupsPerGrid = MTLSizeMake( - std::min(ceil_div(static_cast(output_size), static_cast(512)), static_cast(4096)), - 1, - 1); const std::string kernel = "roi_align_" + scalarToMetalTypeString(input.scalar_type()); id visionPSO = mps::visionPipelineState(device, kernel); + auto threadsPerGrid = MTLSizeMake(output_size, 1, 1); + auto threadsPerThreadgroup = + MTLSizeMake(std::min(static_cast(visionPSO.maxTotalThreadsPerThreadgroup), output_size), 1, 1); + // this function call is a no-op if MPS Profiler is not enabled 
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_}); @@ -68,7 +68,7 @@ [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1]; [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2]; - [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3]; + [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:3]; [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4]; [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5]; [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6]; @@ -76,16 +76,8 @@ [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8]; [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9]; [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10]; - [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11]; - - // A threadGroup is equivalent to a cuda's block. - NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup; - if (tgSize > threadsPerBlock) { - tgSize = threadsPerBlock; - } - MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); - [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + [computeEncoder dispatchThreads:threadsPerGrid threadsPerThreadgroup:threadsPerThreadgroup]; getMPSProfiler().endProfileKernel(visionPSO); }