@@ -31,7 +31,7 @@ def __next__(self):
31
31
self .count = self .count + 1
32
32
return ('[%d]' % count , elt )
33
33
34
- class _device_iterator (Iterator ):
34
+ class _cuda_iterator (Iterator ):
35
35
def __init__ (self , start , size ):
36
36
self .exec = exec
37
37
self .item = start
@@ -88,14 +88,14 @@ def __init__(self, val):
88
88
self .pointer = val ['m_storage' ]['m_begin' ]['m_iterator' ]
89
89
self .size = int (val ['m_size' ])
90
90
self .capacity = int (val ['m_storage' ]['m_size' ])
91
- self .is_device = False
92
- if str ( self .pointer . type ). startswith ( "thrust::device_ptr" ) :
91
+ self .is_device_vector = str ( self . pointer . type ). startswith ( "thrust::device_ptr" )
92
+ if self .is_device_vector :
93
93
self .pointer = self .pointer ['m_iterator' ]
94
- self .is_device = True
94
+ self .is_cuda_vector = "cuda" in str ( val [ 'm_storage' ][ 'm_allocator' ])
95
95
96
96
def children (self ):
97
- if self .is_device :
98
- return self ._device_iterator (self .pointer , self .size )
97
+ if self .is_cuda_vector :
98
+ return self ._cuda_iterator (self .pointer , self .size )
99
99
else :
100
100
return self ._host_accessible_iterator (self .pointer , self .size )
101
101
@@ -107,8 +107,8 @@ def display_hint(self):
107
107
return 'array'
108
108
109
109
110
- class ThrustReferencePrinter (gdb .printing .PrettyPrinter ):
111
- "Print a thrust::device_reference"
110
+ class ThrustCUDAReferencePrinter (gdb .printing .PrettyPrinter ):
111
+ "Print a thrust::device_reference that resides in CUDA memory space "
112
112
113
113
def __init__ (self , val ):
114
114
self .val = val
@@ -138,6 +138,22 @@ def to_string(self):
138
138
def display_hint (self ):
139
139
return None
140
140
141
+ class ThrustHostAccessibleReferencePrinter (gdb .printing .PrettyPrinter ):
142
+ def __init__ (self , val ):
143
+ self .val = val
144
+ self .pointer = val ['ptr' ]['m_iterator' ]
145
+
146
+ def children (self ):
147
+ return []
148
+
149
+ def to_string (self ):
150
+ typename = str (self .val .type )
151
+ return ('(%s) @%s: %s' % (typename , self .pointer , self .pointer .dereference ()))
152
+
153
+ def display_hint (self ):
154
+ return None
155
+
156
+
141
157
142
158
def lookup_thrust_type (val ):
143
159
if not str (val .type .unqualified ()).startswith ('thrust::' ):
@@ -146,7 +162,10 @@ def lookup_thrust_type(val):
146
162
if suffix .startswith ('host_vector' ) or suffix .startswith ('device_vector' ):
147
163
return ThrustVectorPrinter (val )
148
164
elif int (gdb .VERSION .split ("." )[0 ]) >= 10 and suffix .startswith ('device_reference' ):
149
- return ThrustReferencePrinter (val )
165
+ # look for tag in type name
166
+ if "cuda" in "" .join (str (field .type ) for field in val ["ptr" ].type .fields ()):
167
+ return ThrustCUDAReferencePrinter (val )
168
+ return ThrustHostAccessibleReferencePrinter (val )
150
169
return None
151
170
152
171
0 commit comments