Commit 34d749d

Improve runtime complexity of roi_align
1 parent a63221d · commit 34d749d

File tree

  torchvision/csrc/ops/mps/mps_kernels.h
  torchvision/csrc/ops/mps/roi_align_kernel.mm

2 files changed: +81 -98 lines changed

torchvision/csrc/ops/mps/mps_kernels.h

Lines changed: 75 additions & 84 deletions
@@ -225,105 +225,96 @@ kernel void nms<DTYPE ## 4, DTYPE>( \
     uint2 tgid [[threadgroup_position_in_grid]], \
     uint2 tid2 [[thread_position_in_threadgroup]]);
 
-template<typename T, typename integer_t>
+template <typename T, typename integer_t>
 kernel void roi_align(
     constant T * input [[buffer(0)]],
     constant T * rois [[buffer(1)]],
     device T * output [[buffer(2)]],
-    constant int64_t & output_size [[buffer(3)]],
+    constant float & spatial_scale [[buffer(3)]],
     constant int64_t & channels [[buffer(4)]],
     constant int64_t & height [[buffer(5)]],
     constant int64_t & width [[buffer(6)]],
     constant int64_t & pooled_height [[buffer(7)]],
     constant int64_t & pooled_width [[buffer(8)]],
     constant int64_t & sampling_ratio [[buffer(9)]],
     constant bool & aligned [[buffer(10)]],
-    constant float & spatial_scale [[buffer(11)]],
-    uint2 tgid [[threadgroup_position_in_grid]],
-    uint2 tptg [[threads_per_threadgroup]],
-    uint2 tid2 [[thread_position_in_threadgroup]]){
-  MPS_1D_KERNEL_LOOP(index, output_size, 1) {
-    // (n, c, ph, pw) is an element in the pooled output
-    integer_t pw = index % pooled_width;
-    integer_t ph = (index / pooled_width) % pooled_height;
-    integer_t c = (index / pooled_width / pooled_height) % channels;
-    integer_t n = index / pooled_width / pooled_height / channels;
-
-    constant T* offset_rois = rois + n * 5;
-    integer_t roi_batch_ind = offset_rois[0];
-
-    // Do not using rounding; this implementation detail is critical
-    T offset = aligned ? (T)0.5 : (T)0.0;
-    T roi_start_w = offset_rois[1] * spatial_scale - offset;
-    T roi_start_h = offset_rois[2] * spatial_scale - offset;
-    T roi_end_w = offset_rois[3] * spatial_scale - offset;
-    T roi_end_h = offset_rois[4] * spatial_scale - offset;
-
-    T roi_width = roi_end_w - roi_start_w;
-    T roi_height = roi_end_h - roi_start_h;
-    if (!aligned) {
-      // Force malformed ROIs to be 1x1
-      roi_width = max(roi_width, (T)1.);
-      roi_height = max(roi_height, (T)1.);
-    }
-
-    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
-    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
-
-    constant T* offset_input =
-        input + (roi_batch_ind * channels + c) * height * width;
-
-    // We use roi_bin_grid to sample the grid and mimic integral
-    integer_t roi_bin_grid_h = (sampling_ratio > 0)
-        ? sampling_ratio
-        : ceil(roi_height / pooled_height); // e.g., = 2
-    integer_t roi_bin_grid_w =
-        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
-
-    // We do average (integral) pooling inside a bin
-    // When the grid is empty, output zeros.
-    const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast<integer_t>(1)); // e.g. = 4
-
-    T output_val = 0.;
-    for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
-    {
-      const T y = roi_start_h + ph * bin_size_h +
-          static_cast<T>(iy + .5f) * bin_size_h /
-              static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
-      for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
-        const T x = roi_start_w + pw * bin_size_w +
-            static_cast<T>(ix + .5f) * bin_size_w /
-                static_cast<T>(roi_bin_grid_w);
+    uint index [[thread_position_in_grid]])
+{
+  // Decode linear index into (n, c, ph, pw)
+  integer_t pw = index % pooled_width;
+  integer_t ph = (index / pooled_width) % pooled_height;
+  integer_t c = (index / pooled_width / pooled_height) % channels;
+  integer_t n = index / (pooled_width * pooled_height * channels);
+
+  constant T* offset_rois = rois + n * 5;
+  integer_t roi_batch_ind = static_cast<integer_t>(offset_rois[0]);
+
+  // Do not using rounding; this implementation detail is critical
+  T offset = aligned ? static_cast<T>(0.5) : static_cast<T>(0.0);
+  T roi_start_w = offset_rois[1] * spatial_scale - offset;
+  T roi_start_h = offset_rois[2] * spatial_scale - offset;
+  T roi_end_w = offset_rois[3] * spatial_scale - offset;
+  T roi_end_h = offset_rois[4] * spatial_scale - offset;
+
+  T roi_width = roi_end_w - roi_start_w;
+  T roi_height = roi_end_h - roi_start_h;
+
+  if (!aligned) {
+    // Force malformed ROIs to be 1x1
+    roi_width = max(roi_width, static_cast<T>(1.0));
+    roi_height = max(roi_height, static_cast<T>(1.0));
+  }
 
-        T val = bilinear_interpolate(offset_input, height, width, y, x, index);
-        output_val += val;
-      }
+  T bin_size_h = roi_height / static_cast<T>(pooled_height);
+  T bin_size_w = roi_width / static_cast<T>(pooled_width);
+
+  constant T* offset_input = input + (roi_batch_ind * channels + c) * height * width;
+
+  // We use roi_bin_grid to sample the grid and mimic integral
+  integer_t roi_bin_grid_h = sampling_ratio > 0
+      ? sampling_ratio
+      : static_cast<integer_t>(ceil(roi_height / static_cast<T>(pooled_height)));
+  integer_t roi_bin_grid_w = sampling_ratio > 0
+      ? sampling_ratio
+      : static_cast<integer_t>(ceil(roi_width / static_cast<T>(pooled_width)));
+
+  // We do average (integral) pooling inside a bin
+  // When the grid is empty, output zeros.
+  const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast<integer_t>(1));
+  T output_val = static_cast<T>(0.0);
+
+  for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) {
+    T y = roi_start_h + static_cast<T>(ph) * bin_size_h +
+        (static_cast<T>(iy) + static_cast<T>(0.5)) * bin_size_h / static_cast<T>(roi_bin_grid_h);
+    for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
+      T x = roi_start_w + static_cast<T>(pw) * bin_size_w +
+          (static_cast<T>(ix) + static_cast<T>(0.5)) * bin_size_w / static_cast<T>(roi_bin_grid_w);
+
+      T val = bilinear_interpolate(offset_input, height, width, y, x, index);
+      output_val += val;
     }
-    output_val /= count;
-
-    output[index] = output_val;
   }
+
+  output_val /= count;
+  output[index] = output_val;
 }
 
-#define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \
-template \
-[[host_name("roi_align_" #DTYPE)]] \
-kernel void roi_align<DTYPE, INT_DTYPE>( \
-    constant DTYPE * input [[buffer(0)]], \
-    constant DTYPE * rois [[buffer(1)]], \
-    device DTYPE * output [[buffer(2)]], \
-    constant int64_t & output_size [[buffer(3)]], \
-    constant int64_t & channels [[buffer(4)]], \
-    constant int64_t & height [[buffer(5)]], \
-    constant int64_t & width [[buffer(6)]], \
-    constant int64_t & pooled_height [[buffer(7)]], \
-    constant int64_t & pooled_width [[buffer(8)]], \
-    constant int64_t & sampling_ratio [[buffer(9)]], \
-    constant bool & aligned [[buffer(10)]], \
-    constant float & spatial_scale [[buffer(11)]], \
-    uint2 tgid [[threadgroup_position_in_grid]], \
-    uint2 tptg [[threads_per_threadgroup]], \
-    uint2 tid2 [[thread_position_in_threadgroup]]);
+#define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \
+template \
+[[host_name("roi_align_" #DTYPE)]] \
+kernel void roi_align<DTYPE, INT_DTYPE>( \
+    constant DTYPE * input [[buffer(0)]], \
+    constant DTYPE * rois [[buffer(1)]], \
+    device DTYPE * output [[buffer(2)]], \
+    constant float & spatial_scale [[buffer(3)]], \
+    constant int64_t & channels [[buffer(4)]], \
+    constant int64_t & height [[buffer(5)]], \
+    constant int64_t & width [[buffer(6)]], \
+    constant int64_t & pooled_height [[buffer(7)]], \
+    constant int64_t & pooled_width [[buffer(8)]], \
+    constant int64_t & sampling_ratio [[buffer(9)]], \
+    constant bool & aligned [[buffer(10)]], \
+    uint index [[thread_position_in_grid]]);
 
 template<typename T, typename integer_t>
 kernel void roi_align_backward(
@@ -1005,7 +996,7 @@ kernel void ps_roi_pool_backward<DTYPE, INT_DTYPE>( \
     constant int64_t & width [[buffer(7)]], \
     constant int64_t & pooled_height [[buffer(8)]], \
     constant int64_t & pooled_width [[buffer(9)]], \
-    constant int64_t & channels_out [[buffer(10)]], \
+    constant int64_t & channels_out [[buffer(10)]], \
     constant float & spatial_scale [[buffer(11)]], \
     uint2 tgid [[threadgroup_position_in_grid]], \
     uint2 tptg [[threads_per_threadgroup]], \
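
The kernel-side change above drops the MPS_1D_KERNEL_LOOP grid-stride loop: each thread now receives its own linear index via thread_position_in_grid and writes exactly one pooled output element. Below is a minimal standalone C++ sketch of the flat-index decode the rewritten kernel performs; the sizes are made up for illustration and the snippet is not part of the commit.

// Sketch: decode a flat output index into (n, c, ph, pw), as the new
// roi_align kernel does, and verify the decode by re-encoding it.
#include <cassert>
#include <cstdint>

int main() {
  // Hypothetical sizes chosen only for the demonstration.
  const int64_t n_rois = 4, channels = 3, pooled_height = 2, pooled_width = 2;
  const int64_t total = n_rois * channels * pooled_height * pooled_width;

  for (int64_t index = 0; index < total; ++index) {
    // Same arithmetic as the kernel: pw varies fastest, n slowest.
    int64_t pw = index % pooled_width;
    int64_t ph = (index / pooled_width) % pooled_height;
    int64_t c  = (index / pooled_width / pooled_height) % channels;
    int64_t n  = index / (pooled_width * pooled_height * channels);

    // Re-encoding the coordinates must reproduce the original linear index.
    int64_t re = ((n * channels + c) * pooled_height + ph) * pooled_width + pw;
    assert(re == index);
  }
  return 0;
}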

torchvision/csrc/ops/mps/roi_align_kernel.mm

Lines changed: 6 additions & 14 deletions
@@ -51,14 +51,14 @@
   dispatch_sync(mpsStream->queue(), ^() {
     @autoreleasepool {
       id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
-      MTLSize threadgroupsPerGrid = MTLSizeMake(
-          std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
-          1,
-          1);
 
       const std::string kernel = "roi_align_" + scalarToMetalTypeString(input.scalar_type());
       id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
 
+      auto threadsPerGrid = MTLSizeMake(output_size, 1, 1);
+      auto threadsPerThreadgroup =
+          MTLSizeMake(std::min(static_cast<int64_t>(visionPSO.maxTotalThreadsPerThreadgroup), output_size), 1, 1);
+
       // this function call is a no-op if MPS Profiler is not enabled
       getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_});

@@ -68,24 +68,16 @@
       [computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
       [computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];
 
-      [computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3];
+      [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:3];
       [computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4];
       [computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5];
       [computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6];
       [computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7];
       [computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8];
       [computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9];
       [computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10];
-      [computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];
-
-      // A threadGroup is equivalent to a cuda's block.
-      NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
-      if (tgSize > threadsPerBlock) {
-        tgSize = threadsPerBlock;
-      }
 
-      MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
-      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
+      [computeEncoder dispatchThreads:threadsPerGrid threadsPerThreadgroup:threadsPerThreadgroup];
 
       getMPSProfiler().endProfileKernel(visionPSO);
     }
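
On the host side, the old launch picked at most 4096 threadgroups of 512 threads and relied on the kernel's grid-stride loop to cover the remaining elements; the new launch calls dispatchThreads with exactly output_size threads and a threadgroup sized from the pipeline's maxTotalThreadsPerThreadgroup, so Metal partitions the grid and the kernel needs no loop. The plain-C++ sketch below only reproduces the two sizing computations; the workload numbers are made up and the PSO limit is a stand-in for the value queried at runtime.

// Sketch: grid sizing under the old and new dispatch schemes.
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical workload: 2 ROIs x 256 channels x 7x7 pooled output.
  const int64_t output_size = 2 * 256 * 7 * 7;
  // Assumed stand-in for visionPSO.maxTotalThreadsPerThreadgroup.
  const int64_t max_threads_per_threadgroup = 1024;

  // Old scheme: ceil_div(output_size, 512) threadgroups, capped at 4096;
  // the kernel then loops until all output_size elements are covered.
  const int64_t old_threadgroups =
      std::min((output_size + 511) / 512, static_cast<int64_t>(4096));

  // New scheme: dispatchThreads launches one thread per output element.
  const int64_t new_threads_per_threadgroup =
      std::min(max_threads_per_threadgroup, output_size);

  std::cout << "old: " << old_threadgroups << " threadgroups of 512 threads\n"
            << "new: " << output_size << " threads in groups of "
            << new_threads_per_threadgroup << "\n";
  return 0;
}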
