-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathradixsort.cu
95 lines (71 loc) · 2.06 KB
/
radixsort.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#include "radixsort.h"
#include <cstdio>
extern "C" {
/*
 * Block-local split on key bit [k]: within each block, records whose bit is 0
 * are placed first and records whose bit is 1 after them, each group keeping
 * its original order.
 *
 * Parameters:
 * [in]  - records, three consecutive ints each; bit k is read from
 *         in[3 * i + 1] when k < 32 and from in[3 * i] (bit k - 32) otherwise
 * [n]   - number of records
 * [k]   - index of the key bit examined in this pass
 * Result:
 * [pos] - target positions in blocks only
 * [zerosInBlocks] - for each block number of zeros in that block
 *
 * NOTE(review): the kernel presumably expects blockDim.x == THREADS_PER_BLOCK
 * and gridDim.x == ceil(n / THREADS_PER_BLOCK); only the last block is treated
 * as partially filled - confirm against the launch site.
 */
__global__
void computeLocalPositions(int* in, int n, int* pos, int k, int* zerosInBlocks) {
    int thid = blockDim.x * blockIdx.x + threadIdx.x;
    int id = threadIdx.x;
    __shared__ int sh[THREADS_PER_BLOCK];

    // Number of records actually handled by this block (the last block may be short).
    int realShSize =
        blockIdx.x == gridDim.x - 1 ?
        n - blockIdx.x * THREADS_PER_BLOCK : THREADS_PER_BLOCK;

    // Threads past the end of the data must NOT return here: every thread of
    // the block has to reach each __syncthreads() below, otherwise the barrier
    // is executed in divergent code and the behaviour is undefined (possible
    // hang). They stay alive and simply skip all memory accesses.
    const bool active = id < realShSize;

    int bit = 0;
    if (active) {
        bit = k < 32 ?
            (in[3 * thid + 1] >> k) & 1 :
            (in[3 * thid] >> (k - 32)) & 1;
        sh[id] = bit;
    }

    // Inclusive prefix sum of the bits in shared memory, doubling the offset
    // each round. The read and the write are separated by a barrier so no
    // thread overwrites a slot another thread still has to read.
    for (int offset = 1; offset < THREADS_PER_BLOCK; offset *= 2) {
        __syncthreads();
        int tmp = 0;
        if (active && id >= offset)
            tmp = sh[id - offset];
        __syncthreads();
        if (active)
            sh[id] += tmp;
    }
    __syncthreads();

    // No barriers follow, so idle threads may leave now.
    if (!active) return;

    // sh[realShSize - 1] is the number of ones in the block.
    int zeros = realShSize - sh[realShSize - 1];

    // A one goes after all zeros, at the index of its rank among the ones;
    // a zero goes to the count of zeros preceding it (id minus ones before it).
    pos[thid] = bit ?
        zeros + sh[id] - 1 : id - sh[id];

    if (id == 0)
        zerosInBlocks[blockIdx.x] = zeros;
}
}
extern "C" {
/*
 * Converts the block-local target positions stored in [pos] into global
 * target positions for the split on key bit [k].
 *
 * Parameters:
 * [in]        - records, three consecutive ints each; bit k is read from
 *               in[3 * i + 1] when k < 32 and from in[3 * i] (bit k - 32) otherwise
 * [n]         - number of records; threads with index >= n do nothing
 * [pos]       - block-local positions on entry, updated in place
 * [k]         - index of the key bit examined in this pass
 * [zerosPref] - prefix sum of array [zerosInBlock] (per the original note,
 *               filled by function radixsort); the code reads zerosPref[b] as
 *               the number of zeros in blocks 0..b
 * Result:
 * [pos] - global target positions
 */
__global__
void computeGlobalPositions(int* in, int n, int* pos, int k, int* zerosPref) {
    const int thid = blockDim.x * blockIdx.x + threadIdx.x;
    if (thid < n) {
        const int* record = in + 3 * thid;

        int bit;
        if (k < 32) {
            bit = (record[1] >> k) & 1;
        } else {
            bit = (record[0] >> (k - 32)) & 1;
        }

        // Index of this block's first record in the whole array.
        const int firstSlotOfBlock = blockIdx.x * THREADS_PER_BLOCK;

        // Zeros contributed by the blocks that come before this one.
        int zerosInEarlierBlocks = 0;
        if (blockIdx.x != 0) {
            zerosInEarlierBlocks = zerosPref[blockIdx.x - 1];
        }

        // Zeros contributed by the blocks that come after this one.
        const int zerosInLaterBlocks =
            zerosPref[gridDim.x - 1] - zerosPref[blockIdx.x];

        // Zeros only move past the zeros of earlier blocks; ones additionally
        // skip every record of earlier blocks and the zeros of later blocks.
        const int shift = (bit == 0)
            ? zerosInEarlierBlocks
            : zerosInLaterBlocks + firstSlotOfBlock;

        pos[thid] += shift;
    }
}
}
extern "C" {
/*
 * Scatters records from array [in] into array [out]: the record at index i
 * is written to index pos[i]. A record is three consecutive ints.
 *
 * Parameters:
 * [in]  - source records
 * [n]   - number of records; threads with index >= n do nothing
 * [out] - destination records
 * [pos] - target index for each source record
 */
__global__
void permute(int* in, int n, int* out, int* pos) {
    const int thid = blockDim.x * blockIdx.x + threadIdx.x;
    if (thid < n) {
        // Copy the three ints of the record one after another.
        for (int field = 0; field < 3; ++field) {
            out[3 * pos[thid] + field] = in[3 * thid + field];
        }
    }
}
}