Skip to content

Adding immersed boundary conditions to boundaries.py #250

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
190 changes: 170 additions & 20 deletions jax_cfd/base/boundaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Classes that specify how boundary conditions are applied to arrays."""

import dataclasses
from dataclasses import replace
import math
from typing import Optional, Sequence, Tuple, Union

Expand All @@ -29,11 +30,11 @@
GridVariableVector = grids.GridVariableVector
Array = Union[np.ndarray, jax.Array]


class BCType:
PERIODIC = 'periodic'
DIRICHLET = 'dirichlet'
NEUMANN = 'neumann'
IMMERSED = 'immersed'


class Padding:
Expand All @@ -49,8 +50,8 @@ class ConstantBoundaryConditions(BoundaryConditions):
grid = Grid((10, 10))
array = GridArray(np.zeros((10, 10)), offset=(0.5, 0.5), grid)
bc = ConstantBoundaryConditions(((BCType.PERIODIC, BCType.PERIODIC),
(BCType.DIRICHLET, BCType.DIRICHLET)),
((0.0, 10.0),(1.0, 0.0)))
(BCType.DIRICHLET, BCType.DIRICHLET)),
((0.0, 10.0), (1.0, 0.0)))
u = GridVariable(array, bc)

Attributes:
Expand All @@ -60,12 +61,21 @@ class ConstantBoundaryConditions(BoundaryConditions):
types: Tuple[Tuple[str, str], ...]
bc_values: Tuple[Tuple[Optional[float], Optional[float]], ...]

def __init__(self, types: Sequence[Tuple[str, str]],
values: Sequence[Tuple[Optional[float], Optional[float]]]):
mask: Optional[GridArray]
immersed_bc_value: float

def __init__(self,
types: Sequence[Tuple[str, str]],
values: Sequence[Tuple[Optional[float], Optional[float]]],
mask: Optional[GridArray] = None,
immersed_bc_value: float = 0.0):
types = tuple(types)
values = tuple(values)
object.__setattr__(self, 'types', types)
object.__setattr__(self, 'bc_values', values)
object.__setattr__(self, 'immersed_body_mask', mask)
object.__setattr__(self, 'immersed_bc_value', immersed_bc_value)
object.__setattr__(self, 'mask', mask)

@property
def constant_values(self) -> Tuple[Tuple[float, float], ...]:
Expand Down Expand Up @@ -197,6 +207,7 @@ def make_padding(width):
# self.values are ignored here
pad_kwargs = dict(mode='wrap')
data = jnp.pad(data, full_padding, **pad_kwargs)


elif bc_type == BCType.DIRICHLET:
if np.isclose(u.offset[axis] % 1, 0.5): # cell center
Expand Down Expand Up @@ -301,7 +312,12 @@ def make_padding(width):
data,
full_padding,
mode='constant',
constant_values=self.constant_values)))
constant_values=self.bc_values)))

elif bc_type == BCType.IMMERSED:
# Only have one padding immersed for now
data = jnp.pad(data, full_padding, mode='edge')

else:
raise ValueError('invalid boundary type')

Expand Down Expand Up @@ -509,7 +525,21 @@ def pad_and_impose_bc(
u = self._pad(u, 1, axis, mode=mode)
elif np.isclose(offset_to_pad_to[axis], 0.0):
u = self._pad(u, -1, axis, mode=mode)
return grids.GridVariable(u, self)

grid_var_with_bc = grids.GridVariable(u, self)

if self.immersed_body_mask is not None:
# In the case of an immersed body mask being present,
# we overwrite the 'solid' cells with immersed_bc_value.
masked_data = (self.immersed_body_mask.data * grid_var_with_bc.array.data
+ (1.0 - self.immersed_body_mask.data) * self.immersed_bc_value)
masked_arr = grids.GridArray(
masked_data, grid_var_with_bc.array.offset, grid_var_with_bc.array.grid
)
return grids.GridVariable(masked_arr, self)

else:
return grid_var_with_bc

def impose_bc(self, u: grids.GridArray) -> grids.GridVariable:
"""Returns GridVariable with correct boundary condition.
Expand All @@ -529,6 +559,40 @@ def impose_bc(self, u: grids.GridArray) -> grids.GridVariable:

trim = _trim

@staticmethod
def create_immersed_shape_mask(grid, shape, size_fraction):
domain_lengths = [grid.shape[i] * grid.step[i] for i in range(grid.ndim)]
center = tuple(0.5 * L for L in domain_lengths)
min_length = min(domain_lengths)

if shape.lower() == 'circle':
# For a circle, the radius is a fraction of the minimum domain length.
radius = size_fraction * min_length
return create_immersed_circle_mask(grid, center, radius)
elif shape.lower() == 'square':
# For a square, the half-width is a fraction of the minimum domain length.
half_width = size_fraction * min_length
return create_immersed_square_mask(grid, center, half_width)
else:
raise ValueError("Unsupported shape: choose 'circle' or 'square'")

def create_immersed_circle_mask(grid, center, radius):
coords = jnp.meshgrid(*[
(jnp.arange(grid.shape[i]) + 0.5) * grid.step[i]
for i in range(grid.ndim)], indexing='ij')
squared_dist = sum((coords[i] - center[i])**2 for i in range(grid.ndim))
mask_data = jnp.where(squared_dist <= radius**2, 0.0, 1.0)
return grids.GridArray(data=mask_data, offset=(0.5,) * grid.ndim, grid=grid)

def create_immersed_square_mask(grid, center, half_width):
coords = jnp.meshgrid(*[
(jnp.arange(grid.shape[i]) + 0.5) * grid.step[i]
for i in range(grid.ndim)], indexing='ij')
abs_dists = [jnp.abs(coords[i] - center[i]) for i in range(grid.ndim)]
max_dist = jnp.max(jnp.stack(abs_dists, axis=0), axis=0)
mask_data = jnp.where(max_dist <= half_width, 0.0, 1.0)
return grids.GridArray(data=mask_data, offset=(0.5,) * grid.ndim, grid=grid)


class HomogeneousBoundaryConditions(ConstantBoundaryConditions):
"""Boundary conditions for a PDE variable.
Expand Down Expand Up @@ -626,6 +690,47 @@ def channel_flow_boundary_conditions(
return HomogeneousBoundaryConditions(bc_type)
else:
return ConstantBoundaryConditions(bc_type, bc_vals)


def channel_flow_with_simple_immersed_body_boundary_conditions(
grid: grids.Grid,
bc_vals: Optional[Sequence[Tuple[float, float]]] = None,
shape: str = 'circle',
shape_size: float = 0.25,
bc_value: float = 0.0, # Set to 0.0 by default for "no-slip" homogenous condition
) -> ConstantBoundaryConditions:
"""Returns channel-flow BCs with a simple immersed boundary in the domain.

Boundary conditions are periodic in dimension 0 and Dirichlet in dimension 1.

An immersed solid body (e.g. circle or square) is placed in the center of the domain
enforcing constant values (`bc_value`) within the solid region.

Args:
grid: a Grid object defining the simulation domain and grid spacing.
bc_vals: optional tuple specifying lower and upper BC values per dimension.
If None, homogeneous Dirichlet (zero-value) BCs are used on walls,
and periodic dimensions use (None, None).
shape: type of immersed geometry, currently supporting 'circle' or 'square'.
shape_size: size of the immersed body as a fraction of the domain's smallest
dimension (e.g. radius for circle, half-width for square).
bc_value: scalar value enforced inside the immersed solid region.

Returns:
ConstantBoundaryConditions instance specifying the immersed body and
outer-domain channel flow boundary conditions.
"""
underlying_bc = channel_flow_boundary_conditions(grid.ndim, bc_vals)
immersed_mask = ConstantBoundaryConditions.create_immersed_shape_mask(
grid, shape, shape_size
)

return ConstantBoundaryConditions(
types=underlying_bc.types,
values=underlying_bc.bc_values,
mask=immersed_mask,
immersed_bc_value=bc_value,
)


def periodic_and_neumann_boundary_conditions(
Expand Down Expand Up @@ -763,14 +868,51 @@ def get_advection_flux_bc_from_velocity_and_scalar(
Returns:
BoundaryCondition instance for advection flux of c in flux_direction.
"""

# for the immersed body logic, If u.bc includes a 'mask' (immersed BC),
# we temporarily ignore the mask to compute "domain" flux boundaries.
# That means we create a dummy BC with mask=None and immersed_bc_value=0.0,
# then do the usual logic.
bc = u.bc
if getattr(bc, 'immersed_body_mask', None) is not None:


# Build a "dummy" BC with the same domain boundary settings, but no mask
bc_no_mask = ConstantBoundaryConditions(
types=bc.types,
values=bc.bc_values, # Note the attribute name vs parameter name difference
mask=None,
immersed_bc_value=0.0
)
# Wrap this in a dummy velocity var to pass below:
class _DummyGridVar:
def __init__(self, array, bc):
self.array = array
self.bc = bc
self.grid = array.grid

dummy_vel = _DummyGridVar(u.array, bc_no_mask)
# Now we do the normal logic on dummy_vel
u_for_flux = dummy_vel
else:
# If there's no mask, proceed as usual:
u_for_flux = u

# usual (non-immersed boundary condition) logic after this line
# ----------------------------------------------------------------------

# only no penetration and periodic boundaries are supported.
flux_bc_types = []
flux_bc_values = []
if not isinstance(u.bc, HomogeneousBoundaryConditions):
if not (isinstance(u_for_flux.bc, HomogeneousBoundaryConditions) or
(isinstance(u_for_flux.bc, ConstantBoundaryConditions) and
all(all(v == 0.0 for v in values if v is not None) for values in u_for_flux.bc.bc_values))):
raise NotImplementedError(
f'Flux boundary condition is not implemented for velocity with {u.bc}')
f'Flux boundary condition is not implemented for velocity with {u_for_flux.bc}'
)

for axis in range(c.grid.ndim):
if u.bc.types[axis][0] == 'periodic':
if u_for_flux.bc.types[axis][0] == 'periodic':
flux_bc_types.append((BCType.PERIODIC, BCType.PERIODIC))
flux_bc_values.append((None, None))
elif flux_direction != axis:
Expand All @@ -790,32 +932,40 @@ def get_advection_flux_bc_from_velocity_and_scalar(
else:
flux_bc_types_ax = []
flux_bc_values_ax = []
for i in range(2): # lower and upper boundary.
for i in range(2): # lower and upper boundary

# case 1: nonpourous boundary
if (u.bc.types[axis][i] == BCType.DIRICHLET and
u.bc.bc_values[axis][i] == 0.0):
if (
u_for_flux.bc.types[axis][i] == BCType.DIRICHLET
and u_for_flux.bc.bc_values[axis][i] == 0.0
):
flux_bc_types_ax.append(BCType.DIRICHLET)
flux_bc_values_ax.append(0.0)

# case 2: zero flux boundary
elif (u.bc.types[axis][i] == BCType.NEUMANN and
c.bc.types[axis][i] == BCType.NEUMANN):
elif (
u_for_flux.bc.types[axis][i] == BCType.NEUMANN
and c.bc.types[axis][i] == BCType.NEUMANN
):
if not isinstance(c.bc, ConstantBoundaryConditions):
raise NotImplementedError(
'Flux boundary condition is not implemented for scalar' +
f' with {c.bc}')
'Flux boundary condition is not implemented for scalar'
f' with {c.bc}'
)
if not np.isclose(c.bc.bc_values[axis][i], 0.0):
raise NotImplementedError(
'Flux boundary condition is not implemented for scalar' +
f' with {c.bc}')
'Flux boundary condition is not implemented for scalar'
f' with {c.bc}'
)
flux_bc_types_ax.append(BCType.NEUMANN)
flux_bc_values_ax.append(0.0)

# no other case is supported
else:
raise NotImplementedError(
f'Flux boundary condition is not implemented for {u.bc, c.bc}')
f'Flux boundary condition is not implemented for {(u_for_flux.bc, c.bc)}'
)
flux_bc_types.append(flux_bc_types_ax)
flux_bc_values.append(flux_bc_values_ax)

return ConstantBoundaryConditions(flux_bc_types, flux_bc_values)
6 changes: 6 additions & 0 deletions jax_cfd/data/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,19 @@ def trajectory_to_images(
compute_norm_fn: NormFn = quantile_normalize_fn,
cmap: mpl.colors.ListedColormap = sns.cm.icefire, # pytype: disable=module-attr
longest_side: Optional[int] = None,
rotation_angle: int = 0, # in degrees
) -> List[Image.Image]:
"""Converts scalar trajectory with leading time axis into a list of images."""
images = []

for i, image_data in enumerate(trajectory):
norm = compute_norm_fn(image_data, i)
mappable = cm.ScalarMappable(norm=norm, cmap=cmap)
img = Image.fromarray(mappable.to_rgba(image_data, bytes=True))

if rotation_angle != 0:
img = img.rotate(rotation_angle, expand=True)

if longest_side is not None:
img = resize_image(img, longest_side)
images.append(img)
Expand Down
Loading