Skip to content

Commit

Permalink
GPU programs now launch for simple bodies. Need to fix type alignment and usage of np functionality.
Browse files Browse the repository at this point in the history
  • Loading branch information
braxtoncuneo committed Apr 22, 2024
1 parent 8942e3d commit 79445f0
Show file tree
Hide file tree
Showing 6 changed files with 263 additions and 133 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ __pycache__
*.nbc
*.nbi

# PTX cache
__ptxcache__

# Editor
.spyproject
*.swp
Expand Down
50 changes: 44 additions & 6 deletions mcdc/adapt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from numba import njit, jit, objmode, literal_unroll, cuda
from numba import njit, jit, objmode, literal_unroll, cuda, types
from numba.extending import intrinsic
import numba
import mcdc.type_ as type_
import mcdc.kernel as kernel
Expand All @@ -25,6 +26,46 @@ def unknown_target(target):
print_error(f"ERROR: Unrecognized target '{target}'")



# =============================================================================
# uintp/voidptr casters
# =============================================================================


@intrinsic
def cast_uintp_to_voidptr(typingctx, src):
    """
    Numba intrinsic that converts an integer value into a `voidptr`.

    Typing is offered only when `src` is a `types.Integer`; the declared
    signature is `voidptr(uintp)`. For any other argument type the function
    returns None, i.e. no signature/codegen pair is produced.
    """
    # Guard clause: reject everything that is not an integer type
    if not isinstance(src, types.Integer):
        return None

    def lower_to_pointer(context, builder, signature, args):
        # Emit an LLVM `inttoptr` targeting the lowered return type
        (address,) = args
        pointer_type = context.get_value_type(signature.return_type)
        return builder.inttoptr(address, pointer_type)

    return types.voidptr(types.uintp), lower_to_pointer


@intrinsic
def cast_voidptr_to_uintp(typingctx, src):
    """
    Numba intrinsic that converts a raw pointer value into a `uintp`.

    Typing is offered only when `src` is a `types.RawPointer`; the declared
    signature is `uintp(voidptr)`. For any other argument type the function
    returns None, i.e. no signature/codegen pair is produced.
    """
    # Guard clause: reject everything that is not a raw pointer type
    if not isinstance(src, types.RawPointer):
        return None

    def lower_to_integer(context, builder, signature, args):
        # Emit an LLVM `ptrtoint` targeting the lowered return type
        (pointer,) = args
        integer_type = context.get_value_type(signature.return_type)
        return builder.ptrtoint(pointer, integer_type)

    return types.uintp(types.voidptr), lower_to_integer


# =============================================================================
# Decorators
# =============================================================================
Expand Down Expand Up @@ -89,7 +130,7 @@ def eval_toggle():
else:
global do_nothing_id
name = func.__name__
#print(f"do_nothing_{do_nothing_id} for {name}")
print(f"do_nothing_{do_nothing_id} for {name}")
arg_count = len(inspect.signature(func).parameters)
overwrite_func(func,numba.njit(generate_do_nothing(arg_count)))

Expand Down Expand Up @@ -287,10 +328,7 @@ def add_active(particle,prog):
@for_gpu()
def add_active(particle,prog):
P = kernel.recordlike_to_particle(particle)
if SIMPLE_ASYNC:
step_async(prog,P)
else:
find_cell_async(prog,P)
step_async(prog,P)


@for_cpu()
Expand Down
33 changes: 23 additions & 10 deletions mcdc/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def adapt_rng(object_mode=False):
wrapping_mul = wrapping_mul_python


@njit(numba.uint64(numba.uint64, numba.uint64))
@njit
def split_seed(key, seed):
"""murmur_hash64a"""
multiplier = numba.uint64(0xC6A4A7935BD1E995)
Expand Down Expand Up @@ -597,7 +597,7 @@ def bank_rebalance(mcdc):


@njit
def distribute_work(N, mcdc, precursor=False):
def distribute_work_precursor(N, mcdc, precursor):
size = mcdc["mpi_size"]
rank = mcdc["mpi_rank"]

Expand Down Expand Up @@ -629,6 +629,9 @@ def distribute_work(N, mcdc, precursor=False):
mcdc["mpi_work_size_precursor"] = work_size
mcdc["mpi_work_size_total_precursor"] = work_size_total

@njit
def distribute_work(N, mcdc):
    """
    Distribute work with the precursor flag switched off.

    Forwards `N` and `mcdc` to `distribute_work_precursor`, always passing
    `False` as the `precursor` argument.
    """
    distribute_work_precursor(N, mcdc, False)

# =============================================================================
# IC generator
Expand Down Expand Up @@ -1596,7 +1599,7 @@ def eigenvalue_tally(P, distance, mcdc):
ID_nuclide = material["nuclide_IDs"][i]
nuclide = mcdc["nuclides"][ID_nuclide]
for j in range(J):
nu_d = get_nu(NU_FISSION_DELAYED, nuclide, E, j)
nu_d = get_nu_group(NU_FISSION_DELAYED, nuclide, E, j)
decay = nuclide["ce_decay"][j]
total += nu_d / decay
C_density = flux * total * SigmaF / mcdc["k_eff"]
Expand Down Expand Up @@ -2573,7 +2576,7 @@ def fission_CE(P, nuclide, P_new):
nu_p = get_nu(NU_FISSION_PROMPT, nuclide, E)
nu_d = np.zeros(J)
for j in range(J):
nu_d[j] = get_nu(NU_FISSION_DELAYED, nuclide, E, j)
nu_d[j] = get_nu_group(NU_FISSION_DELAYED, nuclide, E, j)

# Delayed?
prompt = True
Expand Down Expand Up @@ -2636,7 +2639,9 @@ def fission_CE(P, nuclide, P_new):


@njit
def branchless_collision(P, mcdc):
def branchless_collision(P, prog):
mcdc = adapt.device(prog)

material = mcdc["materials"][P["material_ID"]]

# Adjust weight
Expand All @@ -2660,7 +2665,7 @@ def branchless_collision(P, mcdc):
idx_census = mcdc["idx_census"]
if P["t"] > mcdc["setting"]["census_time"][idx_census]:
P["alive"] = False
add_particle(split_particle(P), mcdc["bank_census"])
adapt.add_census(split_particle(P), prog)
elif P["t"] > mcdc["setting"]["time_boundary"]:
P["alive"] = False

Expand Down Expand Up @@ -4174,7 +4179,7 @@ def get_microXS(type_, nuclide, E):
@njit
def get_XS(data, E, E_grid, NE):
# Search XS energy bin index
idx = binary_search(E, E_grid, NE)
idx = binary_search_length(E, E_grid, NE)

# Extrapolate if E is outside the given data
if idx == -1:
Expand All @@ -4192,7 +4197,7 @@ def get_XS(data, E, E_grid, NE):


@njit
def get_nu(type_, nuclide, E, group=-1):
def get_nu_group(type_, nuclide, E, group):
if type_ == NU_FISSION:
nu = get_XS(nuclide["ce_nu_p"], E, nuclide["E_nu_p"], nuclide["NE_nu_p"])
for i in range(6):
Expand All @@ -4217,6 +4222,9 @@ def get_nu(type_, nuclide, E, group=-1):
nuclide["ce_nu_d"][group], E, nuclide["E_nu_d"], nuclide["NE_nu_d"]
)

@njit
def get_nu(type_, nuclide, E):
    """
    Evaluate nu for `type_`, `nuclide` and energy `E` without a group index.

    Delegates to `get_nu_group` with `group` fixed at -1.
    """
    # NOTE(review): -1 presumably means "no specific delayed group" --
    # confirm against get_nu_group.
    return get_nu_group(type_, nuclide, E, -1)

@njit
def sample_nuclide(material, P, type_, mcdc):
Expand All @@ -4239,7 +4247,7 @@ def sample_Eout(P_new, E_grid, NE, chi):
xi = rng(P_new)

# Determine bin index
idx = binary_search(xi, chi, NE)
idx = binary_search_length(xi, chi, NE)

# Linear interpolation
E1 = E_grid[idx]
Expand All @@ -4255,7 +4263,7 @@ def sample_Eout(P_new, E_grid, NE, chi):


@njit
def binary_search(val, grid, length=0):
def binary_search_length(val, grid, length):
"""
Binary search that returns the bin index of the value `val` given grid `grid`.
Expand All @@ -4281,6 +4289,11 @@ def binary_search(val, grid, length=0):
return int(right)


@njit
def binary_search(val, grid):
    """
    Binary search of `val` in `grid` without an explicit length.

    Delegates to `binary_search_length` with `length` fixed at 0.
    """
    # NOTE(review): a length of 0 presumably tells binary_search_length to
    # use the whole grid -- confirm against its implementation.
    return binary_search_length(val, grid, 0)


@njit
def lartg(f, g):
"""
Expand Down
Loading

0 comments on commit 79445f0

Please sign in to comment.