Skip to content

Commit cfb4cfb

Browse files
authored
Merge pull request #105 from mehdiataei/functional_examples
Converted examples to functional. Made compute_backend name consistent.
2 parents 0cd49de + a3f1b17 commit cfb4cfb

30 files changed

+640
-585
lines changed

examples/cfd/flow_past_sphere_3d.py

+125-134
Original file line numberDiff line numberDiff line change
@@ -16,95 +16,54 @@
1616
import jax.numpy as jnp
1717
import time
1818

19+
# -------------------------- Simulation Setup --------------------------
20+
21+
omega = 1.6
22+
grid_shape = (512 // 2, 128 // 2, 128 // 2)
23+
compute_backend = ComputeBackend.WARP
24+
precision_policy = PrecisionPolicy.FP32FP32
25+
velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, compute_backend=compute_backend)
26+
u_max = 0.04
27+
num_steps = 10000
28+
post_process_interval = 1000
29+
30+
# Initialize XLB
31+
xlb.init(
32+
velocity_set=velocity_set,
33+
default_backend=compute_backend,
34+
default_precision_policy=precision_policy,
35+
)
1936

20-
class FlowOverSphere:
21-
def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy):
22-
# initialize backend
23-
xlb.init(
24-
velocity_set=velocity_set,
25-
default_backend=backend,
26-
default_precision_policy=precision_policy,
27-
)
28-
29-
self.grid_shape = grid_shape
30-
self.velocity_set = velocity_set
31-
self.backend = backend
32-
self.precision_policy = precision_policy
33-
self.omega = omega
34-
35-
self.boundary_conditions = []
36-
self.u_max = 0.04
37-
38-
# Create grid using factory
39-
self.grid = grid_factory(grid_shape, compute_backend=backend)
40-
41-
# Setup the simulation BC and stepper
42-
self._setup()
43-
44-
def _setup(self):
45-
self.setup_boundary_conditions()
46-
self.setup_stepper()
47-
48-
def define_boundary_indices(self):
49-
box = self.grid.bounding_box_indices()
50-
box_no_edge = self.grid.bounding_box_indices(remove_edges=True)
51-
inlet = box_no_edge["left"]
52-
outlet = box_no_edge["right"]
53-
walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(self.velocity_set.d)]
54-
walls = np.unique(np.array(walls), axis=-1).tolist()
55-
56-
sphere_radius = self.grid_shape[1] // 12
57-
x = np.arange(self.grid_shape[0])
58-
y = np.arange(self.grid_shape[1])
59-
z = np.arange(self.grid_shape[2])
60-
X, Y, Z = np.meshgrid(x, y, z, indexing="ij")
61-
indices = np.where(
62-
(X - self.grid_shape[0] // 6) ** 2 + (Y - self.grid_shape[1] // 2) ** 2 + (Z - self.grid_shape[2] // 2) ** 2 < sphere_radius**2
63-
)
64-
sphere = [tuple(indices[i]) for i in range(self.velocity_set.d)]
65-
66-
return inlet, outlet, walls, sphere
67-
68-
def setup_boundary_conditions(self):
69-
inlet, outlet, walls, sphere = self.define_boundary_indices()
70-
bc_left = RegularizedBC("velocity", profile=self.bc_profile(), indices=inlet)
71-
# bc_left = RegularizedBC("velocity", prescribed_value=(self.u_max, 0.0, 0.0), indices=inlet)
72-
bc_walls = FullwayBounceBackBC(indices=walls)
73-
bc_outlet = ExtrapolationOutflowBC(indices=outlet)
74-
bc_sphere = HalfwayBounceBackBC(indices=sphere)
75-
self.boundary_conditions = [bc_walls, bc_left, bc_outlet, bc_sphere]
76-
77-
def setup_stepper(self):
78-
self.stepper = IncompressibleNavierStokesStepper(
79-
grid=self.grid,
80-
boundary_conditions=self.boundary_conditions,
81-
collision_type="BGK",
82-
)
83-
self.f_0, self.f_1, self.bc_mask, self.missing_mask = self.stepper.prepare_fields()
84-
85-
def bc_profile(self):
86-
u_max = self.u_max # u_max = 0.04
87-
# Get the grid dimensions for the y and z directions
88-
H_y = float(self.grid_shape[1] - 1) # Height in y direction
89-
H_z = float(self.grid_shape[2] - 1) # Height in z direction
37+
# Create Grid
38+
grid = grid_factory(grid_shape, compute_backend=compute_backend)
9039

91-
@wp.func
92-
def bc_profile_warp(index: wp.vec3i):
93-
# Poiseuille flow profile: parabolic velocity distribution
94-
y = wp.float32(index[1])
95-
z = wp.float32(index[2])
40+
# Define Boundary Indices
41+
box = grid.bounding_box_indices()
42+
box_no_edge = grid.bounding_box_indices(remove_edges=True)
43+
inlet = box_no_edge["left"]
44+
outlet = box_no_edge["right"]
45+
walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(velocity_set.d)]
46+
walls = np.unique(np.array(walls), axis=-1).tolist()
9647

97-
# Calculate normalized distance from center
98-
y_center = y - (H_y / 2.0)
99-
z_center = z - (H_z / 2.0)
100-
r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0
48+
sphere_radius = grid_shape[1] // 12
49+
x = np.arange(grid_shape[0])
50+
y = np.arange(grid_shape[1])
51+
z = np.arange(grid_shape[2])
52+
X, Y, Z = np.meshgrid(x, y, z, indexing="ij")
53+
indices = np.where((X - grid_shape[0] // 6) ** 2 + (Y - grid_shape[1] // 2) ** 2 + (Z - grid_shape[2] // 2) ** 2 < sphere_radius**2)
54+
sphere = [tuple(indices[i]) for i in range(velocity_set.d)]
10155

102-
# Parabolic profile: u = u_max * (1 - r²)
103-
return wp.vec(u_max * wp.max(0.0, 1.0 - r_squared), length=1)
56+
57+
# Define Boundary Conditions
58+
def bc_profile():
59+
H_y = float(grid_shape[1] - 1) # Height in y direction
60+
H_z = float(grid_shape[2] - 1) # Height in z direction
61+
62+
if compute_backend == ComputeBackend.JAX:
10463

10564
def bc_profile_jax():
106-
y = jnp.arange(self.grid_shape[1])
107-
z = jnp.arange(self.grid_shape[2])
65+
y = jnp.arange(grid_shape[1])
66+
z = jnp.arange(grid_shape[2])
10867
Y, Z = jnp.meshgrid(y, z, indexing="ij")
10968

11069
# Calculate normalized distance from center
@@ -119,56 +78,88 @@ def bc_profile_jax():
11978

12079
return jnp.stack([u_x, u_y, u_z])
12180

122-
if self.backend == ComputeBackend.JAX:
123-
return bc_profile_jax
124-
elif self.backend == ComputeBackend.WARP:
125-
return bc_profile_warp
81+
return bc_profile_jax
82+
83+
elif compute_backend == ComputeBackend.WARP:
84+
85+
@wp.func
86+
def bc_profile_warp(index: wp.vec3i):
87+
# Poiseuille flow profile: parabolic velocity distribution
88+
y = wp.float32(index[1])
89+
z = wp.float32(index[2])
90+
91+
# Calculate normalized distance from center
92+
y_center = y - (H_y / 2.0)
93+
z_center = z - (H_z / 2.0)
94+
r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0
95+
96+
# Parabolic profile: u = u_max * (1 - r²)
97+
return wp.vec(u_max * wp.max(0.0, 1.0 - r_squared), length=1)
98+
99+
return bc_profile_warp
100+
101+
102+
# Initialize Boundary Conditions
103+
bc_left = RegularizedBC("velocity", profile=bc_profile(), indices=inlet)
104+
# Alternatively, use a prescribed velocity profile
105+
# bc_left = RegularizedBC("velocity", prescribed_value=(u_max, 0.0, 0.0), indices=inlet)
106+
bc_walls = FullwayBounceBackBC(indices=walls)
107+
bc_outlet = ExtrapolationOutflowBC(indices=outlet)
108+
bc_sphere = HalfwayBounceBackBC(indices=sphere)
109+
boundary_conditions = [bc_walls, bc_left, bc_outlet, bc_sphere]
110+
111+
# Setup Stepper
112+
stepper = IncompressibleNavierStokesStepper(
113+
grid=grid,
114+
boundary_conditions=boundary_conditions,
115+
collision_type="BGK",
116+
)
117+
f_0, f_1, bc_mask, missing_mask = stepper.prepare_fields()
118+
119+
# Define Macroscopic Calculation
120+
macro = Macroscopic(
121+
compute_backend=ComputeBackend.JAX,
122+
precision_policy=precision_policy,
123+
velocity_set=xlb.velocity_set.D3Q19(precision_policy=precision_policy, compute_backend=ComputeBackend.JAX),
124+
)
125+
126+
127+
# Post-Processing Function
128+
def post_process(step, f_current):
129+
# Convert to JAX array if necessary
130+
if not isinstance(f_current, jnp.ndarray):
131+
f_current = wp.to_jax(f_current)
132+
133+
rho, u = macro(f_current)
134+
135+
# Remove boundary cells
136+
u = u[:, 1:-1, 1:-1, 1:-1]
137+
rho = rho[:, 1:-1, 1:-1, 1:-1]
138+
u_magnitude = jnp.sqrt(u[0] ** 2 + u[1] ** 2 + u[2] ** 2)
139+
140+
fields = {
141+
"u_magnitude": u_magnitude,
142+
"u_x": u[0],
143+
"u_y": u[1],
144+
"u_z": u[2],
145+
"rho": rho[0],
146+
}
147+
148+
# Save the u_magnitude slice at the mid y-plane
149+
save_image(fields["u_magnitude"][:, grid_shape[1] // 2, :], timestep=step)
150+
print(f"Post-processed step {step}: Saved u_magnitude slice at y={grid_shape[1] // 2}")
151+
152+
153+
# -------------------------- Simulation Loop --------------------------
154+
155+
start_time = time.time()
156+
for step in range(num_steps):
157+
f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, omega, step)
158+
f_0, f_1 = f_1, f_0 # Swap the buffers
126159

127-
def run(self, num_steps, post_process_interval=100):
160+
if step % post_process_interval == 0 or step == num_steps - 1:
161+
post_process(step, f_0)
162+
end_time = time.time()
163+
elapsed = end_time - start_time
164+
print(f"Completed step {step}. Time elapsed for {post_process_interval} steps: {elapsed:.6f} seconds.")
128165
start_time = time.time()
129-
for i in range(num_steps):
130-
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, self.omega, i)
131-
self.f_0, self.f_1 = self.f_1, self.f_0
132-
133-
if i % post_process_interval == 0 or i == num_steps - 1:
134-
self.post_process(i)
135-
end_time = time.time()
136-
print(f"Completing {i} iterations. Time elapsed for 1000 LBM steps in {end_time - start_time:.6f} seconds.")
137-
start_time = time.time()
138-
139-
def post_process(self, i):
140-
# Write the results. We'll use JAX backend for the post-processing
141-
if not isinstance(self.f_0, jnp.ndarray):
142-
f_0 = wp.to_jax(self.f_0)
143-
else:
144-
f_0 = self.f_0
145-
146-
macro = Macroscopic(
147-
compute_backend=ComputeBackend.JAX,
148-
precision_policy=self.precision_policy,
149-
velocity_set=xlb.velocity_set.D3Q19(precision_policy=self.precision_policy, backend=ComputeBackend.JAX),
150-
)
151-
rho, u = macro(f_0)
152-
153-
# remove boundary cells
154-
u = u[:, 1:-1, 1:-1, 1:-1]
155-
rho = rho[:, 1:-1, 1:-1, 1:-1]
156-
u_magnitude = (u[0] ** 2 + u[1] ** 2 + u[2] ** 2) ** 0.5
157-
158-
fields = {"u_magnitude": u_magnitude, "u_x": u[0], "u_y": u[1], "u_z": u[2], "rho": rho[0]}
159-
160-
# save_fields_vtk(fields, timestep=i)
161-
save_image(fields["u_magnitude"][:, self.grid_shape[1] // 2, :], timestep=i)
162-
163-
164-
if __name__ == "__main__":
165-
# Running the simulation
166-
grid_shape = (512 // 2, 128 // 2, 128 // 2)
167-
backend = ComputeBackend.WARP
168-
precision_policy = PrecisionPolicy.FP32FP32
169-
170-
velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend)
171-
omega = 1.6
172-
173-
simulation = FlowOverSphere(omega, grid_shape, velocity_set, backend, precision_policy)
174-
simulation.run(num_steps=10000, post_process_interval=1000)

examples/cfd/lid_driven_cavity_2d.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,24 @@
1313

1414

1515
class LidDrivenCavity2D:
16-
def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy):
17-
# initialize backend
16+
def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, compute_backend, precision_policy):
17+
# initialize compute_backend
1818
xlb.init(
1919
velocity_set=velocity_set,
20-
default_backend=backend,
20+
default_backend=compute_backend,
2121
default_precision_policy=precision_policy,
2222
)
2323

2424
self.grid_shape = grid_shape
2525
self.velocity_set = velocity_set
26-
self.backend = backend
26+
self.compute_backend = compute_backend
2727
self.precision_policy = precision_policy
2828
self.omega = omega
2929
self.boundary_conditions = []
3030
self.prescribed_vel = prescribed_vel
3131

3232
# Create grid using factory
33-
self.grid = grid_factory(grid_shape, compute_backend=backend)
33+
self.grid = grid_factory(grid_shape, compute_backend=compute_backend)
3434

3535
# Setup the simulation BC and stepper
3636
self._setup()
@@ -71,17 +71,17 @@ def run(self, num_steps, post_process_interval=100):
7171
self.post_process(i)
7272

7373
def post_process(self, i):
74-
# Write the results. We'll use JAX backend for the post-processing
74+
# Write the results. We'll use JAX compute_backend for the post-processing
7575
if not isinstance(self.f_0, jnp.ndarray):
76-
# If the backend is warp, we need to drop the last dimension added by warp for 2D simulations
76+
# If the compute_backend is warp, we need to drop the last dimension added by warp for 2D simulations
7777
f_0 = wp.to_jax(self.f_0)[..., 0]
7878
else:
7979
f_0 = self.f_0
8080

8181
macro = Macroscopic(
8282
compute_backend=ComputeBackend.JAX,
8383
precision_policy=self.precision_policy,
84-
velocity_set=xlb.velocity_set.D2Q9(precision_policy=self.precision_policy, backend=ComputeBackend.JAX),
84+
velocity_set=xlb.velocity_set.D2Q9(precision_policy=self.precision_policy, compute_backend=ComputeBackend.JAX),
8585
)
8686

8787
rho, u = macro(f_0)
@@ -101,10 +101,10 @@ def post_process(self, i):
101101
# Running the simulation
102102
grid_size = 500
103103
grid_shape = (grid_size, grid_size)
104-
backend = ComputeBackend.WARP
104+
compute_backend = ComputeBackend.WARP
105105
precision_policy = PrecisionPolicy.FP32FP32
106106

107-
velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend)
107+
velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, compute_backend=compute_backend)
108108

109109
# Setting fluid viscosity and relaxation parameter.
110110
Re = 200.0
@@ -113,5 +113,5 @@ def post_process(self, i):
113113
visc = prescribed_vel * clength / Re
114114
omega = 1.0 / (3.0 * visc + 0.5)
115115

116-
simulation = LidDrivenCavity2D(omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy)
116+
simulation = LidDrivenCavity2D(omega, prescribed_vel, grid_shape, velocity_set, compute_backend, precision_policy)
117117
simulation.run(num_steps=50000, post_process_interval=1000)

examples/cfd/lid_driven_cavity_2d_distributed.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88

99
class LidDrivenCavity2D_distributed(LidDrivenCavity2D):
10-
def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy):
11-
super().__init__(omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy)
10+
def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, compute_backend, precision_policy):
11+
super().__init__(omega, prescribed_vel, grid_shape, velocity_set, compute_backend, precision_policy)
1212

1313
def setup_stepper(self):
1414
# Create the base stepper
@@ -30,10 +30,12 @@ def setup_stepper(self):
3030
# Running the simulation
3131
grid_size = 512
3232
grid_shape = (grid_size, grid_size)
33-
backend = ComputeBackend.JAX # Must be JAX for distributed multi-GPU computations. Distributed computations on WARP are not supported yet!
33+
compute_backend = (
34+
ComputeBackend.JAX
35+
) # Must be JAX for distributed multi-GPU computations. Distributed computations on WARP are not supported yet!
3436
precision_policy = PrecisionPolicy.FP32FP32
3537

36-
velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend)
38+
velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, compute_backend=compute_backend)
3739

3840
# Setting fluid viscosity and relaxation parameter.
3941
Re = 200.0
@@ -42,5 +44,5 @@ def setup_stepper(self):
4244
visc = prescribed_vel * clength / Re
4345
omega = 1.0 / (3.0 * visc + 0.5)
4446

45-
simulation = LidDrivenCavity2D_distributed(omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy)
47+
simulation = LidDrivenCavity2D_distributed(omega, prescribed_vel, grid_shape, velocity_set, compute_backend, precision_policy)
4648
simulation.run(num_steps=50000, post_process_interval=1000)

0 commit comments

Comments
 (0)