4
4
from grunnur import API , Buffer , Context , Queue , cuda_api_id , opencl_api_id
5
5
6
6
7
- def test_allocate_and_copy (mock_or_real_context ): # noqa: PLR0915
7
+ @pytest .mark .parametrize ("sync" , [False , True ], ids = ["async" , "sync" ])
8
+ def test_transfer (mock_or_real_context , sync ):
8
9
context , mocked = mock_or_real_context
9
10
10
11
length = 100
@@ -18,73 +19,141 @@ def test_allocate_and_copy(mock_or_real_context): # noqa: PLR0915
18
19
assert buf .offset == 0
19
20
20
21
# Just covering the existence of the attribute.
21
- # Hard to actually check it without running a kernel
22
+ # Hard to actually check its validity without running a kernel
22
23
assert buf .kernel_arg is not None
23
24
24
25
queue = Queue (context .device )
25
- buf .set (queue , arr )
26
+
27
+ buf .set (queue , arr , no_async = sync )
26
28
27
29
# Read the whole buffer
28
30
res = numpy .empty_like (arr )
29
- buf .get (queue , res )
30
- queue .synchronize ()
31
+ buf .get (queue , res , async_ = not sync )
32
+ if not sync :
33
+ queue .synchronize ()
34
+ assert (res == arr ).all ()
35
+
36
+ # Device-to-device copy
37
+ res = numpy .empty_like (arr )
38
+ buf2 = Buffer .allocate (context .device , size )
39
+ buf2 .set (queue , buf , no_async = sync )
40
+ buf2 .get (queue , res , async_ = not sync )
41
+ if not sync :
42
+ queue .synchronize ()
31
43
assert (res == arr ).all ()
32
44
45
+
46
+ @pytest .mark .parametrize ("sync" , [False , True ], ids = ["async" , "sync" ])
47
+ def test_subregion (mock_or_real_context , sync ):
48
+ context , mocked = mock_or_real_context
49
+
50
+ length = 200
51
+ dtype = numpy .dtype ("int32" )
52
+ size = length * dtype .itemsize
53
+
54
+ arr = numpy .arange (length ).astype (dtype )
55
+
56
+ buf = Buffer .allocate (context .device , size )
57
+
58
+ queue = Queue (context .device )
59
+
60
+ region_offset = 64
61
+ region_length = 50
62
+
63
+ buf_region = buf .get_sub_region (region_offset * dtype .itemsize , region_length * dtype .itemsize )
64
+ buf .set (queue , arr ) # just to make the test harder, fill the buffer after creating a subregion
65
+
33
66
# Read a subregion
34
- buf_region = buf . get_sub_region ( 25 * dtype . itemsize , 50 * dtype . itemsize )
35
- arr_region = arr [25 : 25 + 50 ]
67
+ res = numpy . empty_like ( arr )
68
+ arr_region = arr [region_offset : region_offset + region_length ]
36
69
res_region = numpy .empty_like (arr_region )
37
- buf_region .get (queue , res_region )
38
- queue .synchronize ()
70
+ buf_region .get (queue , res_region , async_ = not sync )
71
+ if not sync :
72
+ queue .synchronize ()
39
73
assert (res_region == arr_region ).all ()
40
74
41
- # Check that our mock can detect taking a sub-region of a sub-region (segfault in OpenCL)
42
- if mocked and context .api .id == opencl_api_id ():
43
- with pytest .raises (RuntimeError , match = "Cannot create a subregion of subregion" ):
44
- buf_region .get_sub_region (0 , 10 )
45
-
46
75
# Write a subregion
47
- arr_region = (numpy .ones (50 ) * 100 ).astype (dtype )
48
- arr [25 : 25 + 50 ] = arr_region
49
- buf_region .set (queue , arr_region )
50
- buf .get (queue , res )
51
- queue .synchronize ()
76
+ res = numpy .empty_like (arr )
77
+ arr_region = (numpy .ones (50 ) * region_length ).astype (dtype )
78
+ arr [region_offset : region_offset + region_length ] = arr_region
79
+ buf_region .set (queue , arr_region , no_async = sync )
80
+ buf .get (queue , res , async_ = not sync )
81
+ if not sync :
82
+ queue .synchronize ()
52
83
assert (res == arr ).all ()
53
84
54
- # Subregion of subregion
55
- if context .api .id == cuda_api_id ():
56
- # In OpenCL that leads to segfault, but with CUDA we just emulate that with pointers.
57
- arr_region2 = (numpy .ones (20 ) * 200 ).astype (dtype )
58
- arr [25 + 20 : 25 + 40 ] = arr_region2
59
- buf_region2 = buf_region .get_sub_region (20 * dtype .itemsize , 20 * dtype .itemsize )
60
- buf_region2 .set (queue , arr_region2 )
61
- buf .get (queue , res )
62
- queue .synchronize ()
63
- assert (res == arr ).all ()
85
+
86
+ @pytest .mark .parametrize ("sync" , [False , True ], ids = ["async" , "sync" ])
87
+ def test_subregion_copy (mock_or_real_context , sync ):
88
+ context , mocked = mock_or_real_context
89
+
90
+ length = 100
91
+ dtype = numpy .dtype ("int32" )
92
+ size = length * dtype .itemsize
93
+
94
+ arr = numpy .arange (length ).astype (dtype )
95
+
96
+ buf = Buffer .allocate (context .device , size )
97
+
98
+ queue = Queue (context .device )
99
+ buf .set (queue , arr , no_async = sync )
100
+
101
+ region_offset = 64
102
+ region_length = 100
64
103
65
104
# Device-to-device copy
66
105
buf2 = Buffer .allocate (context .device , size * 2 )
67
106
buf2 .set (queue , numpy .ones (length * 2 , dtype ))
68
- buf2_view = buf2 .get_sub_region (50 * dtype .itemsize , 100 * dtype .itemsize )
69
- buf2_view .set (queue , buf )
107
+ buf2_view = buf2 .get_sub_region (region_offset * dtype .itemsize , region_length * dtype .itemsize )
108
+ buf2_view .set (queue , buf , no_async = sync )
70
109
res2 = numpy .empty (length * 2 , dtype )
71
- buf2 .get (queue , res2 )
72
- queue .synchronize ()
73
- assert (res2 [50 :150 ] == arr ).all ()
74
- assert (res2 [:50 ] == 1 ).all ()
75
- assert (res2 [150 :] == 1 ).all ()
110
+ buf2 .get (queue , res2 , async_ = not sync )
111
+ if not sync :
112
+ queue .synchronize ()
113
+ assert (res2 [region_offset : region_offset + region_length ] == arr ).all ()
114
+ assert (res2 [:region_offset ] == 1 ).all ()
115
+ assert (res2 [region_offset + region_length :] == 1 ).all ()
76
116
77
- # Device-to-device copy (no_async)
78
- buf2 = Buffer .allocate (context .device , size * 2 )
79
- buf2 .set (queue , numpy .ones (length * 2 , dtype ))
80
- buf2_view = buf2 .get_sub_region (50 * dtype .itemsize , 100 * dtype .itemsize )
81
- buf2_view .set (queue , buf , no_async = True )
82
- res2 = numpy .empty (length * 2 , dtype )
83
- buf2 .get (queue , res2 )
117
+
118
+ def test_subregion_of_subregion (mock_or_real_context ):
119
+ context , mocked = mock_or_real_context
120
+
121
+ length = 200
122
+ dtype = numpy .dtype ("int32" )
123
+ size = length * dtype .itemsize
124
+
125
+ r1_offset = 64
126
+ r1_length = 50
127
+
128
+ r2_offset = 20
129
+ r2_length = 20
130
+
131
+ buf = Buffer .allocate (context .device , size )
132
+ buf_region = buf .get_sub_region (r1_offset * dtype .itemsize , r1_length * dtype .itemsize )
133
+
134
+ if context .api .id == opencl_api_id ():
135
+ # Check that our mock can detect taking a sub-region of a sub-region (segfault in OpenCL)
136
+ if mocked :
137
+ with pytest .raises (RuntimeError , match = "Cannot create a subregion of subregion" ):
138
+ buf_region .get_sub_region (0 , 10 )
139
+ return
140
+
141
+ pytest .skip ("Subregions of subregions are not supported in OpenCL" )
142
+
143
+ queue = Queue (context .device )
144
+
145
+ arr = numpy .arange (length ).astype (dtype )
146
+ res = numpy .empty_like (arr )
147
+
148
+ buf .set (queue , arr )
149
+
150
+ arr_region2 = (numpy .ones (r2_length ) * 200 ).astype (dtype )
151
+ arr [r1_offset + r2_offset : r1_offset + r2_offset + r2_length ] = arr_region2
152
+ buf_region2 = buf_region .get_sub_region (r2_offset * dtype .itemsize , r2_length * dtype .itemsize )
153
+ buf_region2 .set (queue , arr_region2 )
154
+ buf .get (queue , res )
84
155
queue .synchronize ()
85
- assert (res2 [50 :150 ] == arr ).all ()
86
- assert (res2 [:50 ] == 1 ).all ()
87
- assert (res2 [150 :] == 1 ).all ()
156
+ assert (res == arr ).all ()
88
157
89
158
90
159
def test_subregion_overflow (mock_context ):
0 commit comments