@@ -225,105 +225,96 @@ kernel void nms<DTYPE ## 4, DTYPE>( \
225
225
uint2 tgid [[threadgroup_position_in_grid]], \
226
226
uint2 tid2 [[thread_position_in_threadgroup]]);
227
227
228
- template<typename T, typename integer_t>
228
+ template <typename T, typename integer_t>
229
229
kernel void roi_align(
230
230
constant T * input [[buffer(0)]],
231
231
constant T * rois [[buffer(1)]],
232
232
device T * output [[buffer(2)]],
233
- constant int64_t & output_size [[buffer(3)]],
233
+ constant float & spatial_scale [[buffer(3)]],
234
234
constant int64_t & channels [[buffer(4)]],
235
235
constant int64_t & height [[buffer(5)]],
236
236
constant int64_t & width [[buffer(6)]],
237
237
constant int64_t & pooled_height [[buffer(7)]],
238
238
constant int64_t & pooled_width [[buffer(8)]],
239
239
constant int64_t & sampling_ratio [[buffer(9)]],
240
240
constant bool & aligned [[buffer(10)]],
241
- constant float & spatial_scale [[buffer(11)]],
242
- uint2 tgid [[threadgroup_position_in_grid]],
243
- uint2 tptg [[threads_per_threadgroup]],
244
- uint2 tid2 [[thread_position_in_threadgroup]]){
245
- MPS_1D_KERNEL_LOOP(index, output_size, 1) {
246
- // (n, c, ph, pw) is an element in the pooled output
247
- integer_t pw = index % pooled_width;
248
- integer_t ph = (index / pooled_width) % pooled_height;
249
- integer_t c = (index / pooled_width / pooled_height) % channels;
250
- integer_t n = index / pooled_width / pooled_height / channels;
251
-
252
- constant T* offset_rois = rois + n * 5;
253
- integer_t roi_batch_ind = offset_rois[0];
254
-
255
- // Do not using rounding; this implementation detail is critical
256
- T offset = aligned ? (T)0.5 : (T)0.0;
257
- T roi_start_w = offset_rois[1] * spatial_scale - offset;
258
- T roi_start_h = offset_rois[2] * spatial_scale - offset;
259
- T roi_end_w = offset_rois[3] * spatial_scale - offset;
260
- T roi_end_h = offset_rois[4] * spatial_scale - offset;
261
-
262
- T roi_width = roi_end_w - roi_start_w;
263
- T roi_height = roi_end_h - roi_start_h;
264
- if (!aligned) {
265
- // Force malformed ROIs to be 1x1
266
- roi_width = max(roi_width, (T)1.);
267
- roi_height = max(roi_height, (T)1.);
268
- }
269
-
270
- T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
271
- T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
272
-
273
- constant T* offset_input =
274
- input + (roi_batch_ind * channels + c) * height * width;
275
-
276
- // We use roi_bin_grid to sample the grid and mimic integral
277
- integer_t roi_bin_grid_h = (sampling_ratio > 0)
278
- ? sampling_ratio
279
- : ceil(roi_height / pooled_height); // e.g., = 2
280
- integer_t roi_bin_grid_w =
281
- (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
282
-
283
- // We do average (integral) pooling inside a bin
284
- // When the grid is empty, output zeros.
285
- const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast<integer_t>(1)); // e.g. = 4
286
-
287
- T output_val = 0.;
288
- for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
289
- {
290
- const T y = roi_start_h + ph * bin_size_h +
291
- static_cast<T>(iy + .5f) * bin_size_h /
292
- static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
293
- for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
294
- const T x = roi_start_w + pw * bin_size_w +
295
- static_cast<T>(ix + .5f) * bin_size_w /
296
- static_cast<T>(roi_bin_grid_w);
241
+ uint index [[thread_position_in_grid]])
242
+ {
243
+ // Decode linear index into (n, c, ph, pw)
244
+ integer_t pw = index % pooled_width;
245
+ integer_t ph = (index / pooled_width) % pooled_height;
246
+ integer_t c = (index / pooled_width / pooled_height) % channels;
247
+ integer_t n = index / (pooled_width * pooled_height * channels);
248
+
249
+ constant T* offset_rois = rois + n * 5;
250
+ integer_t roi_batch_ind = static_cast<integer_t>(offset_rois[0]);
251
+
252
+ // Do not using rounding; this implementation detail is critical
253
+ T offset = aligned ? static_cast<T>(0.5) : static_cast<T>(0.0);
254
+ T roi_start_w = offset_rois[1] * spatial_scale - offset;
255
+ T roi_start_h = offset_rois[2] * spatial_scale - offset;
256
+ T roi_end_w = offset_rois[3] * spatial_scale - offset;
257
+ T roi_end_h = offset_rois[4] * spatial_scale - offset;
258
+
259
+ T roi_width = roi_end_w - roi_start_w;
260
+ T roi_height = roi_end_h - roi_start_h;
261
+
262
+ if (!aligned) {
263
+ // Force malformed ROIs to be 1x1
264
+ roi_width = max(roi_width, static_cast<T>(1.0));
265
+ roi_height = max(roi_height, static_cast<T>(1.0));
266
+ }
297
267
298
- T val = bilinear_interpolate(offset_input, height, width, y, x, index);
299
- output_val += val;
300
- }
268
+ T bin_size_h = roi_height / static_cast<T>(pooled_height);
269
+ T bin_size_w = roi_width / static_cast<T>(pooled_width);
270
+
271
+ constant T* offset_input = input + (roi_batch_ind * channels + c) * height * width;
272
+
273
+ // We use roi_bin_grid to sample the grid and mimic integral
274
+ integer_t roi_bin_grid_h = sampling_ratio > 0
275
+ ? sampling_ratio
276
+ : static_cast<integer_t>(ceil(roi_height / static_cast<T>(pooled_height)));
277
+ integer_t roi_bin_grid_w = sampling_ratio > 0
278
+ ? sampling_ratio
279
+ : static_cast<integer_t>(ceil(roi_width / static_cast<T>(pooled_width)));
280
+
281
+ // We do average (integral) pooling inside a bin
282
+ // When the grid is empty, output zeros.
283
+ const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast<integer_t>(1));
284
+ T output_val = static_cast<T>(0.0);
285
+
286
+ for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) {
287
+ T y = roi_start_h + static_cast<T>(ph) * bin_size_h +
288
+ (static_cast<T>(iy) + static_cast<T>(0.5)) * bin_size_h / static_cast<T>(roi_bin_grid_h);
289
+ for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
290
+ T x = roi_start_w + static_cast<T>(pw) * bin_size_w +
291
+ (static_cast<T>(ix) + static_cast<T>(0.5)) * bin_size_w / static_cast<T>(roi_bin_grid_w);
292
+
293
+ T val = bilinear_interpolate(offset_input, height, width, y, x, index);
294
+ output_val += val;
301
295
}
302
- output_val /= count;
303
-
304
- output[index] = output_val;
305
296
}
297
+
298
+ output_val /= count;
299
+ output[index] = output_val;
306
300
}
307
301
308
- #define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \
309
- template \
310
- [[host_name("roi_align_" #DTYPE)]] \
311
- kernel void roi_align<DTYPE, INT_DTYPE>( \
312
- constant DTYPE * input [[buffer(0)]], \
313
- constant DTYPE * rois [[buffer(1)]], \
314
- device DTYPE * output [[buffer(2)]], \
315
- constant int64_t & output_size [[buffer(3)]], \
316
- constant int64_t & channels [[buffer(4)]], \
317
- constant int64_t & height [[buffer(5)]], \
318
- constant int64_t & width [[buffer(6)]], \
319
- constant int64_t & pooled_height [[buffer(7)]], \
320
- constant int64_t & pooled_width [[buffer(8)]], \
321
- constant int64_t & sampling_ratio [[buffer(9)]], \
322
- constant bool & aligned [[buffer(10)]], \
323
- constant float & spatial_scale [[buffer(11)]], \
324
- uint2 tgid [[threadgroup_position_in_grid]], \
325
- uint2 tptg [[threads_per_threadgroup]], \
326
- uint2 tid2 [[thread_position_in_threadgroup]]);
302
+ #define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \
303
+ template \
304
+ [[host_name("roi_align_" #DTYPE)]] \
305
+ kernel void roi_align<DTYPE, INT_DTYPE>( \
306
+ constant DTYPE * input [[buffer(0)]], \
307
+ constant DTYPE * rois [[buffer(1)]], \
308
+ device DTYPE * output [[buffer(2)]], \
309
+ constant float & spatial_scale [[buffer(3)]], \
310
+ constant int64_t & channels [[buffer(4)]], \
311
+ constant int64_t & height [[buffer(5)]], \
312
+ constant int64_t & width [[buffer(6)]], \
313
+ constant int64_t & pooled_height [[buffer(7)]], \
314
+ constant int64_t & pooled_width [[buffer(8)]], \
315
+ constant int64_t & sampling_ratio [[buffer(9)]], \
316
+ constant bool & aligned [[buffer(10)]], \
317
+ uint index [[thread_position_in_grid]]);
327
318
328
319
template<typename T, typename integer_t>
329
320
kernel void roi_align_backward(
@@ -1005,7 +996,7 @@ kernel void ps_roi_pool_backward<DTYPE, INT_DTYPE>( \
1005
996
constant int64_t & width [[buffer(7)]], \
1006
997
constant int64_t & pooled_height [[buffer(8)]], \
1007
998
constant int64_t & pooled_width [[buffer(9)]], \
1008
- constant int64_t & channels_out [[buffer(10)]], \
999
+ constant int64_t & channels_out [[buffer(10)]], \
1009
1000
constant float & spatial_scale [[buffer(11)]], \
1010
1001
uint2 tgid [[threadgroup_position_in_grid]], \
1011
1002
uint2 tptg [[threads_per_threadgroup]], \
0 commit comments