-
Notifications
You must be signed in to change notification settings - Fork 184
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Switch to using gpu_struct decorator
- Loading branch information
Showing
6 changed files
with
130 additions
and
105 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
86 changes: 0 additions & 86 deletions
86
python/cuda_parallel/cuda/parallel/experimental/_structwrapper.py
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
99 changes: 99 additions & 0 deletions
99
python/cuda_parallel/cuda/parallel/experimental/gpu_struct.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from dataclasses import dataclass | ||
from dataclasses import fields as dataclass_fields | ||
|
||
import numba | ||
import numpy as np | ||
from numba.core import cgutils | ||
from numba.core.extending import ( | ||
make_attribute_wrapper, | ||
models, | ||
register_model, | ||
typeof_impl, | ||
) | ||
from numba.core.typing import signature as nb_signature | ||
from numba.core.typing.templates import AttributeTemplate, ConcreteTemplate | ||
from numba.cuda.cudadecl import registry as cuda_registry | ||
from numba.extending import lower_builtin | ||
|
||
from .typing import GpuStruct | ||
|
||
|
||
def gpu_struct(this: type) -> GpuStruct: | ||
anns = getattr(this, "__annotations__", {}) | ||
|
||
# set the .dtype attribute on the class for numpy compatibility: | ||
setattr(this, "dtype", np.dtype(list(anns.items()))) | ||
|
||
# define __post_init__ to create a ctypes object from the fields, | ||
# and keep a reference to it in the `._data` attribute. | ||
def __post_init__(self): | ||
ctypes_typ = np.ctypeslib.as_ctypes_type(this.dtype) | ||
self._data = ctypes_typ(*(getattr(self, name) for name in this.dtype.names)) | ||
|
||
setattr(this, "__post_init__", __post_init__) | ||
|
||
# create a dataclass: | ||
this = dataclass(this) | ||
fields = dataclass_fields(this) | ||
|
||
# define a numba type corresponding to the dataclass: | ||
class ThisType(numba.types.Type): | ||
def __init__(self): | ||
super().__init__(name=this.__name__) | ||
|
||
this_type = ThisType() | ||
|
||
@typeof_impl.register(this) | ||
def typeof_this(val, c): | ||
return ThisType() | ||
|
||
# Data model corresponding to ThisType: | ||
@register_model(ThisType) | ||
class ThisModel(models.StructModel): | ||
def __init__(self, dmm, fe_type): | ||
members = [(field.name, numba.from_dtype(field.type)) for field in fields] | ||
super().__init__(dmm, fe_type, members) | ||
|
||
# Typing for accessing attributes (fields) of the dataclass: | ||
class ThisAttrsTemplate(AttributeTemplate): | ||
pass | ||
|
||
for field in fields: | ||
typ = field.type | ||
name = field.name | ||
|
||
def resolver(self, this): | ||
return numba.from_dtype(typ) | ||
|
||
setattr(ThisAttrsTemplate, f"resolve_{name}", resolver) | ||
|
||
@cuda_registry.register_attr | ||
class ThisAttrs(ThisAttrsTemplate): | ||
key = this_type | ||
|
||
# Lowering for attribute access: | ||
for field in fields: | ||
make_attribute_wrapper(ThisType, field.name, field.name) | ||
|
||
# Register typing for constructor. | ||
@cuda_registry.register | ||
class TypeConstructor(ConcreteTemplate): | ||
key = this | ||
cases = [ | ||
nb_signature(this_type, *[numba.from_dtype(field.type) for field in fields]) | ||
] | ||
|
||
cuda_registry.register_global(this, numba.types.Function(TypeConstructor)) | ||
|
||
def type_constructor(context, builder, sig, args): | ||
ty = sig.return_type | ||
retval = cgutils.create_struct_proxy(ty)(context, builder) | ||
for field, val in zip(fields, args): | ||
setattr(retval, field.name, val) | ||
return retval._getvalue() | ||
|
||
lower_builtin(this, *[numba.from_dtype(field.type) for field in fields])( | ||
type_constructor | ||
) | ||
|
||
return this |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters