16
16
import jax .numpy as jnp
17
17
import time
18
18
19
+ # -------------------------- Simulation Setup --------------------------
20
+
21
+ omega = 1.6
22
+ grid_shape = (512 // 2 , 128 // 2 , 128 // 2 )
23
+ compute_backend = ComputeBackend .WARP
24
+ precision_policy = PrecisionPolicy .FP32FP32
25
+ velocity_set = xlb .velocity_set .D3Q19 (precision_policy = precision_policy , compute_backend = compute_backend )
26
+ u_max = 0.04
27
+ num_steps = 10000
28
+ post_process_interval = 1000
29
+
30
+ # Initialize XLB
31
+ xlb .init (
32
+ velocity_set = velocity_set ,
33
+ default_backend = compute_backend ,
34
+ default_precision_policy = precision_policy ,
35
+ )
19
36
20
- class FlowOverSphere :
21
- def __init__ (self , omega , grid_shape , velocity_set , backend , precision_policy ):
22
- # initialize backend
23
- xlb .init (
24
- velocity_set = velocity_set ,
25
- default_backend = backend ,
26
- default_precision_policy = precision_policy ,
27
- )
28
-
29
- self .grid_shape = grid_shape
30
- self .velocity_set = velocity_set
31
- self .backend = backend
32
- self .precision_policy = precision_policy
33
- self .omega = omega
34
-
35
- self .boundary_conditions = []
36
- self .u_max = 0.04
37
-
38
- # Create grid using factory
39
- self .grid = grid_factory (grid_shape , compute_backend = backend )
40
-
41
- # Setup the simulation BC and stepper
42
- self ._setup ()
43
-
44
- def _setup (self ):
45
- self .setup_boundary_conditions ()
46
- self .setup_stepper ()
47
-
48
- def define_boundary_indices (self ):
49
- box = self .grid .bounding_box_indices ()
50
- box_no_edge = self .grid .bounding_box_indices (remove_edges = True )
51
- inlet = box_no_edge ["left" ]
52
- outlet = box_no_edge ["right" ]
53
- walls = [box ["bottom" ][i ] + box ["top" ][i ] + box ["front" ][i ] + box ["back" ][i ] for i in range (self .velocity_set .d )]
54
- walls = np .unique (np .array (walls ), axis = - 1 ).tolist ()
55
-
56
- sphere_radius = self .grid_shape [1 ] // 12
57
- x = np .arange (self .grid_shape [0 ])
58
- y = np .arange (self .grid_shape [1 ])
59
- z = np .arange (self .grid_shape [2 ])
60
- X , Y , Z = np .meshgrid (x , y , z , indexing = "ij" )
61
- indices = np .where (
62
- (X - self .grid_shape [0 ] // 6 ) ** 2 + (Y - self .grid_shape [1 ] // 2 ) ** 2 + (Z - self .grid_shape [2 ] // 2 ) ** 2 < sphere_radius ** 2
63
- )
64
- sphere = [tuple (indices [i ]) for i in range (self .velocity_set .d )]
65
-
66
- return inlet , outlet , walls , sphere
67
-
68
- def setup_boundary_conditions (self ):
69
- inlet , outlet , walls , sphere = self .define_boundary_indices ()
70
- bc_left = RegularizedBC ("velocity" , profile = self .bc_profile (), indices = inlet )
71
- # bc_left = RegularizedBC("velocity", prescribed_value=(self.u_max, 0.0, 0.0), indices=inlet)
72
- bc_walls = FullwayBounceBackBC (indices = walls )
73
- bc_outlet = ExtrapolationOutflowBC (indices = outlet )
74
- bc_sphere = HalfwayBounceBackBC (indices = sphere )
75
- self .boundary_conditions = [bc_walls , bc_left , bc_outlet , bc_sphere ]
76
-
77
- def setup_stepper (self ):
78
- self .stepper = IncompressibleNavierStokesStepper (
79
- grid = self .grid ,
80
- boundary_conditions = self .boundary_conditions ,
81
- collision_type = "BGK" ,
82
- )
83
- self .f_0 , self .f_1 , self .bc_mask , self .missing_mask = self .stepper .prepare_fields ()
84
-
85
- def bc_profile (self ):
86
- u_max = self .u_max # u_max = 0.04
87
- # Get the grid dimensions for the y and z directions
88
- H_y = float (self .grid_shape [1 ] - 1 ) # Height in y direction
89
- H_z = float (self .grid_shape [2 ] - 1 ) # Height in z direction
37
+ # Create Grid
38
+ grid = grid_factory (grid_shape , compute_backend = compute_backend )
90
39
91
- @wp .func
92
- def bc_profile_warp (index : wp .vec3i ):
93
- # Poiseuille flow profile: parabolic velocity distribution
94
- y = wp .float32 (index [1 ])
95
- z = wp .float32 (index [2 ])
40
+ # Define Boundary Indices
41
+ box = grid .bounding_box_indices ()
42
+ box_no_edge = grid .bounding_box_indices (remove_edges = True )
43
+ inlet = box_no_edge ["left" ]
44
+ outlet = box_no_edge ["right" ]
45
+ walls = [box ["bottom" ][i ] + box ["top" ][i ] + box ["front" ][i ] + box ["back" ][i ] for i in range (velocity_set .d )]
46
+ walls = np .unique (np .array (walls ), axis = - 1 ).tolist ()
96
47
97
- # Calculate normalized distance from center
98
- y_center = y - (H_y / 2.0 )
99
- z_center = z - (H_z / 2.0 )
100
- r_squared = (2.0 * y_center / H_y ) ** 2.0 + (2.0 * z_center / H_z ) ** 2.0
48
+ sphere_radius = grid_shape [1 ] // 12
49
+ x = np .arange (grid_shape [0 ])
50
+ y = np .arange (grid_shape [1 ])
51
+ z = np .arange (grid_shape [2 ])
52
+ X , Y , Z = np .meshgrid (x , y , z , indexing = "ij" )
53
+ indices = np .where ((X - grid_shape [0 ] // 6 ) ** 2 + (Y - grid_shape [1 ] // 2 ) ** 2 + (Z - grid_shape [2 ] // 2 ) ** 2 < sphere_radius ** 2 )
54
+ sphere = [tuple (indices [i ]) for i in range (velocity_set .d )]
101
55
102
- # Parabolic profile: u = u_max * (1 - r²)
103
- return wp .vec (u_max * wp .max (0.0 , 1.0 - r_squared ), length = 1 )
56
+
57
+ # Define Boundary Conditions
58
+ def bc_profile ():
59
+ H_y = float (grid_shape [1 ] - 1 ) # Height in y direction
60
+ H_z = float (grid_shape [2 ] - 1 ) # Height in z direction
61
+
62
+ if compute_backend == ComputeBackend .JAX :
104
63
105
64
def bc_profile_jax ():
106
- y = jnp .arange (self . grid_shape [1 ])
107
- z = jnp .arange (self . grid_shape [2 ])
65
+ y = jnp .arange (grid_shape [1 ])
66
+ z = jnp .arange (grid_shape [2 ])
108
67
Y , Z = jnp .meshgrid (y , z , indexing = "ij" )
109
68
110
69
# Calculate normalized distance from center
@@ -119,56 +78,88 @@ def bc_profile_jax():
119
78
120
79
return jnp .stack ([u_x , u_y , u_z ])
121
80
122
- if self .backend == ComputeBackend .JAX :
123
- return bc_profile_jax
124
- elif self .backend == ComputeBackend .WARP :
125
- return bc_profile_warp
81
+ return bc_profile_jax
82
+
83
+ elif compute_backend == ComputeBackend .WARP :
84
+
85
+ @wp .func
86
+ def bc_profile_warp (index : wp .vec3i ):
87
+ # Poiseuille flow profile: parabolic velocity distribution
88
+ y = wp .float32 (index [1 ])
89
+ z = wp .float32 (index [2 ])
90
+
91
+ # Calculate normalized distance from center
92
+ y_center = y - (H_y / 2.0 )
93
+ z_center = z - (H_z / 2.0 )
94
+ r_squared = (2.0 * y_center / H_y ) ** 2.0 + (2.0 * z_center / H_z ) ** 2.0
95
+
96
+ # Parabolic profile: u = u_max * (1 - r²)
97
+ return wp .vec (u_max * wp .max (0.0 , 1.0 - r_squared ), length = 1 )
98
+
99
+ return bc_profile_warp
100
+
101
+
102
+ # Initialize Boundary Conditions
103
+ bc_left = RegularizedBC ("velocity" , profile = bc_profile (), indices = inlet )
104
+ # Alternatively, use a prescribed velocity profile
105
+ # bc_left = RegularizedBC("velocity", prescribed_value=(u_max, 0.0, 0.0), indices=inlet)
106
+ bc_walls = FullwayBounceBackBC (indices = walls )
107
+ bc_outlet = ExtrapolationOutflowBC (indices = outlet )
108
+ bc_sphere = HalfwayBounceBackBC (indices = sphere )
109
+ boundary_conditions = [bc_walls , bc_left , bc_outlet , bc_sphere ]
110
+
111
+ # Setup Stepper
112
+ stepper = IncompressibleNavierStokesStepper (
113
+ grid = grid ,
114
+ boundary_conditions = boundary_conditions ,
115
+ collision_type = "BGK" ,
116
+ )
117
+ f_0 , f_1 , bc_mask , missing_mask = stepper .prepare_fields ()
118
+
119
+ # Define Macroscopic Calculation
120
+ macro = Macroscopic (
121
+ compute_backend = ComputeBackend .JAX ,
122
+ precision_policy = precision_policy ,
123
+ velocity_set = xlb .velocity_set .D3Q19 (precision_policy = precision_policy , compute_backend = ComputeBackend .JAX ),
124
+ )
125
+
126
+
127
+ # Post-Processing Function
128
+ def post_process (step , f_current ):
129
+ # Convert to JAX array if necessary
130
+ if not isinstance (f_current , jnp .ndarray ):
131
+ f_current = wp .to_jax (f_current )
132
+
133
+ rho , u = macro (f_current )
134
+
135
+ # Remove boundary cells
136
+ u = u [:, 1 :- 1 , 1 :- 1 , 1 :- 1 ]
137
+ rho = rho [:, 1 :- 1 , 1 :- 1 , 1 :- 1 ]
138
+ u_magnitude = jnp .sqrt (u [0 ] ** 2 + u [1 ] ** 2 + u [2 ] ** 2 )
139
+
140
+ fields = {
141
+ "u_magnitude" : u_magnitude ,
142
+ "u_x" : u [0 ],
143
+ "u_y" : u [1 ],
144
+ "u_z" : u [2 ],
145
+ "rho" : rho [0 ],
146
+ }
147
+
148
+ # Save the u_magnitude slice at the mid y-plane
149
+ save_image (fields ["u_magnitude" ][:, grid_shape [1 ] // 2 , :], timestep = step )
150
+ print (f"Post-processed step { step } : Saved u_magnitude slice at y={ grid_shape [1 ] // 2 } " )
151
+
152
+
153
+ # -------------------------- Simulation Loop --------------------------
154
+
155
+ start_time = time .time ()
156
+ for step in range (num_steps ):
157
+ f_0 , f_1 = stepper (f_0 , f_1 , bc_mask , missing_mask , omega , step )
158
+ f_0 , f_1 = f_1 , f_0 # Swap the buffers
126
159
127
- def run (self , num_steps , post_process_interval = 100 ):
160
+ if step % post_process_interval == 0 or step == num_steps - 1 :
161
+ post_process (step , f_0 )
162
+ end_time = time .time ()
163
+ elapsed = end_time - start_time
164
+ print (f"Completed step { step } . Time elapsed for { post_process_interval } steps: { elapsed :.6f} seconds." )
128
165
start_time = time .time ()
129
- for i in range (num_steps ):
130
- self .f_0 , self .f_1 = self .stepper (self .f_0 , self .f_1 , self .bc_mask , self .missing_mask , self .omega , i )
131
- self .f_0 , self .f_1 = self .f_1 , self .f_0
132
-
133
- if i % post_process_interval == 0 or i == num_steps - 1 :
134
- self .post_process (i )
135
- end_time = time .time ()
136
- print (f"Completing { i } iterations. Time elapsed for 1000 LBM steps in { end_time - start_time :.6f} seconds." )
137
- start_time = time .time ()
138
-
139
- def post_process (self , i ):
140
- # Write the results. We'll use JAX backend for the post-processing
141
- if not isinstance (self .f_0 , jnp .ndarray ):
142
- f_0 = wp .to_jax (self .f_0 )
143
- else :
144
- f_0 = self .f_0
145
-
146
- macro = Macroscopic (
147
- compute_backend = ComputeBackend .JAX ,
148
- precision_policy = self .precision_policy ,
149
- velocity_set = xlb .velocity_set .D3Q19 (precision_policy = self .precision_policy , backend = ComputeBackend .JAX ),
150
- )
151
- rho , u = macro (f_0 )
152
-
153
- # remove boundary cells
154
- u = u [:, 1 :- 1 , 1 :- 1 , 1 :- 1 ]
155
- rho = rho [:, 1 :- 1 , 1 :- 1 , 1 :- 1 ]
156
- u_magnitude = (u [0 ] ** 2 + u [1 ] ** 2 + u [2 ] ** 2 ) ** 0.5
157
-
158
- fields = {"u_magnitude" : u_magnitude , "u_x" : u [0 ], "u_y" : u [1 ], "u_z" : u [2 ], "rho" : rho [0 ]}
159
-
160
- # save_fields_vtk(fields, timestep=i)
161
- save_image (fields ["u_magnitude" ][:, self .grid_shape [1 ] // 2 , :], timestep = i )
162
-
163
-
164
- if __name__ == "__main__" :
165
- # Running the simulation
166
- grid_shape = (512 // 2 , 128 // 2 , 128 // 2 )
167
- backend = ComputeBackend .WARP
168
- precision_policy = PrecisionPolicy .FP32FP32
169
-
170
- velocity_set = xlb .velocity_set .D3Q19 (precision_policy = precision_policy , backend = backend )
171
- omega = 1.6
172
-
173
- simulation = FlowOverSphere (omega , grid_shape , velocity_set , backend , precision_policy )
174
- simulation .run (num_steps = 10000 , post_process_interval = 1000 )
0 commit comments