Skip to content

[MPS] Improve runtime complexity of roi_align #9100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
import pytest
import time
import torch
import torch.fx
import torch.nn.functional as F
Expand Down Expand Up @@ -615,6 +616,31 @@ def test_jit_boxes_list(self):
model = PoolWrapper(ops.RoIAlign(output_size=[3, 3], spatial_scale=1.0, sampling_ratio=-1))
self._helper_jit_boxes_list(model)

@needs_mps
def test_performance_mps(self):
# Regression test for https://github.com/pytorch/pytorch/issues/124850
execution_time_ms_threshold = 1000 # ms = 1 second

num_imgs, n_channels, img_size, img_size = 1, 256, 200, 200
spatial_scale = 0.25
output_size = 7
sampling_ratio = 2
aligned = False
dtype = torch.float32
device = "mps"

x = torch.randint(50, 100, size=(num_imgs, n_channels, img_size, img_size), dtype=dtype).to(device)
rois = self._make_rois(img_size, num_imgs, dtype).to(device)

start = time.time()
_ = ops.roi_align(x, rois, output_size, spatial_scale, sampling_ratio, aligned)
torch.mps.synchronize()
end_execution = time.time()
execution_time_ms = 1000 * (end_execution - start)

assert (
execution_time_ms < execution_time_ms_threshold
), f"Expected execution to take < {execution_time_ms_threshold} ms, actually took {execution_time_ms} ms"

class TestPSRoIAlign(RoIOpTester):
mps_backward_atol = 5e-2
Expand Down
163 changes: 77 additions & 86 deletions torchvision/csrc/ops/mps/mps_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,105 +225,96 @@ kernel void nms<DTYPE ## 4, DTYPE>( \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tid2 [[thread_position_in_threadgroup]]);

template<typename T, typename integer_t>
template <typename T>
kernel void roi_align(
constant T * input [[buffer(0)]],
constant T * rois [[buffer(1)]],
device T * output [[buffer(2)]],
constant int64_t & output_size [[buffer(3)]],
constant float & spatial_scale [[buffer(3)]],
constant int64_t & channels [[buffer(4)]],
constant int64_t & height [[buffer(5)]],
constant int64_t & width [[buffer(6)]],
constant int64_t & pooled_height [[buffer(7)]],
constant int64_t & pooled_width [[buffer(8)]],
constant int64_t & sampling_ratio [[buffer(9)]],
constant bool & aligned [[buffer(10)]],
constant float & spatial_scale [[buffer(11)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, c, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t c = (index / pooled_width / pooled_height) % channels;
integer_t n = index / pooled_width / pooled_height / channels;

constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];

// Do not using rounding; this implementation detail is critical
T offset = aligned ? (T)0.5 : (T)0.0;
T roi_start_w = offset_rois[1] * spatial_scale - offset;
T roi_start_h = offset_rois[2] * spatial_scale - offset;
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;

T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
if (!aligned) {
// Force malformed ROIs to be 1x1
roi_width = max(roi_width, (T)1.);
roi_height = max(roi_height, (T)1.);
}

T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

constant T* offset_input =
input + (roi_batch_ind * channels + c) * height * width;

// We use roi_bin_grid to sample the grid and mimic integral
integer_t roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
integer_t roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);

// We do average (integral) pooling inside a bin
// When the grid is empty, output zeros.
const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast<integer_t>(1)); // e.g. = 4

T output_val = 0.;
for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
{
const T y = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
uint index [[thread_position_in_grid]])
{
// Decode linear index into (n, c, ph, pw)
int64_t pw = index % pooled_width;
int64_t ph = (index / pooled_width) % pooled_height;
int64_t c = (index / pooled_width / pooled_height) % channels;
int64_t n = index / (pooled_width * pooled_height * channels);

constant T* offset_rois = rois + n * 5;
int64_t roi_batch_ind = static_cast<int64_t>(offset_rois[0]);

// Do not using rounding; this implementation detail is critical
T offset = aligned ? static_cast<T>(0.5) : static_cast<T>(0.0);
T roi_start_w = offset_rois[1] * spatial_scale - offset;
T roi_start_h = offset_rois[2] * spatial_scale - offset;
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;

T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;

if (!aligned) {
// Force malformed ROIs to be 1x1
roi_width = max(roi_width, static_cast<T>(1.0));
roi_height = max(roi_height, static_cast<T>(1.0));
}

T val = bilinear_interpolate(offset_input, height, width, y, x, index);
output_val += val;
}
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);

constant T* offset_input = input + (roi_batch_ind * channels + c) * height * width;

// We use roi_bin_grid to sample the grid and mimic integral
int64_t roi_bin_grid_h = sampling_ratio > 0
? sampling_ratio
: static_cast<int64_t>(ceil(roi_height / static_cast<T>(pooled_height)));
int64_t roi_bin_grid_w = sampling_ratio > 0
? sampling_ratio
: static_cast<int64_t>(ceil(roi_width / static_cast<T>(pooled_width)));

// We do average (integral) pooling inside a bin
// When the grid is empty, output zeros.
const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast<int64_t>(1));
T output_val = static_cast<T>(0.0);

for (int64_t iy = 0; iy < roi_bin_grid_h; iy++) {
T y = roi_start_h + static_cast<T>(ph) * bin_size_h +
(static_cast<T>(iy) + static_cast<T>(0.5)) * bin_size_h / static_cast<T>(roi_bin_grid_h);
for (int64_t ix = 0; ix < roi_bin_grid_w; ix++) {
T x = roi_start_w + static_cast<T>(pw) * bin_size_w +
(static_cast<T>(ix) + static_cast<T>(0.5)) * bin_size_w / static_cast<T>(roi_bin_grid_w);

T val = bilinear_interpolate(offset_input, height, width, y, x, index);
output_val += val;
}
output_val /= count;

output[index] = output_val;
}

output_val /= count;
output[index] = output_val;
}

#define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("roi_align_" #DTYPE)]] \
kernel void roi_align<DTYPE, INT_DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
device DTYPE * output [[buffer(2)]], \
constant int64_t & output_size [[buffer(3)]], \
constant int64_t & channels [[buffer(4)]], \
constant int64_t & height [[buffer(5)]], \
constant int64_t & width [[buffer(6)]], \
constant int64_t & pooled_height [[buffer(7)]], \
constant int64_t & pooled_width [[buffer(8)]], \
constant int64_t & sampling_ratio [[buffer(9)]], \
constant bool & aligned [[buffer(10)]], \
constant float & spatial_scale [[buffer(11)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
#define REGISTER_ROI_ALIGN_OP(DTYPE) \
template \
[[host_name("roi_align_" #DTYPE)]] \
kernel void roi_align<DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
device DTYPE * output [[buffer(2)]], \
constant float & spatial_scale [[buffer(3)]], \
constant int64_t & channels [[buffer(4)]], \
constant int64_t & height [[buffer(5)]], \
constant int64_t & width [[buffer(6)]], \
constant int64_t & pooled_height [[buffer(7)]], \
constant int64_t & pooled_width [[buffer(8)]], \
constant int64_t & sampling_ratio [[buffer(9)]], \
constant bool & aligned [[buffer(10)]], \
uint index [[thread_position_in_grid]]);

template<typename T, typename integer_t>
kernel void roi_align_backward(
Expand Down Expand Up @@ -1005,16 +996,16 @@ kernel void ps_roi_pool_backward<DTYPE, INT_DTYPE>( \
constant int64_t & width [[buffer(7)]], \
constant int64_t & pooled_height [[buffer(8)]], \
constant int64_t & pooled_width [[buffer(9)]], \
constant int64_t & channels_out [[buffer(10)]], \
constant int64_t & channels_out [[buffer(10)]], \
constant float & spatial_scale [[buffer(11)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);

REGISTER_NMS_OP(float);
REGISTER_NMS_OP(half);
REGISTER_ROI_ALIGN_OP(float, int64_t);
REGISTER_ROI_ALIGN_OP(half, int64_t);
REGISTER_ROI_ALIGN_OP(float);
REGISTER_ROI_ALIGN_OP(half);
REGISTER_ROI_ALIGN_BACKWARD_OP(float, int64_t);
REGISTER_ROI_ALIGN_BACKWARD_OP(half, int64_t);
REGISTER_ROI_POOL_OP(float, int64_t);
Expand Down
20 changes: 6 additions & 14 deletions torchvision/csrc/ops/mps/roi_align_kernel.mm
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);

const std::string kernel = "roi_align_" + scalarToMetalTypeString(input.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);

auto threadsPerGrid = MTLSizeMake(output_size, 1, 1);
auto threadsPerThreadgroup =
MTLSizeMake(std::min(static_cast<int64_t>(visionPSO.maxTotalThreadsPerThreadgroup), output_size), 1, 1);

// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_});

Expand All @@ -68,24 +68,16 @@
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];

[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:3];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];

// A threadGroup is equivalent to a cuda's block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}

MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
[computeEncoder dispatchThreads:threadsPerGrid threadsPerThreadgroup:threadsPerThreadgroup];

getMPSProfiler().endProfileKernel(visionPSO);
}
Expand Down