Skip to content

Commit 237ac22

Browse files
authored
Merge pull request #95 from mehdiataei/main
Get ready for 0.2.1 release. Fixed a bug running JAX on CPU.
2 parents f7bed81 + 2f8ffe8 commit 237ac22

File tree

2 files changed

+28
-12
lines changed

2 files changed

+28
-12
lines changed

CHANGELOG.md

+13-5
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88
## [Unreleased]
99
- _No changes yet_ <!-- Placeholder for future changes -->
1010

11+
12+
## [0.2.1] - 2024-12-05
13+
1114
### Fixed
12-
- mkdocs is now configured correctly for the new project structure
13-
- JAX installation is now handled correctly for different configurations (CPU, CUDA, TPU)
15+
- mkdocs is now configured correctly for the new project structure
16+
- JAX installation is now handled correctly for different configurations (CPU, CUDA, TPU)
17+
- Fixed a couple of bugs in 2D regularied_bc and kbc (Warp) that emerged after merging 2d and 3d kernels
18+
19+
### Added
20+
21+
- Added abstraction layer for boundary condition efficient encoding/decoding of auxiliary data
22+
- Added the capability to add profiles to boundary conditions
23+
- Added prepare_fields method to the Stepper class to allow for more automatic preparation of fields
1424

1525
## [0.2.0] - 2024-10-22
1626

1727
### Added
1828
- XLB is now installable via pip
1929
- Complete rewrite of the codebase for better modularity and extensibility based on "Operators" design pattern
20-
- Added NVIDIA's Warp backend for state-of-the-art performance
21-
- Added abstraction layer for boundary condition efficient encoding/decoding of auxiliary data
22-
- Added the capability to add profiles to boundary conditions
30+
- Added NVIDIA's Warp backend for state-of-the-art performance

xlb/default_config.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def init(velocity_set, default_backend, default_precision_policy):
2121

2222
wp.init() # TODO: Must be removed in the future versions of WARP
2323
elif default_backend == ComputeBackend.JAX:
24-
check_multi_gpu_support()
24+
check_backend_support()
2525
else:
2626
raise ValueError(f"Unsupported compute backend: {default_backend}")
2727

@@ -30,11 +30,19 @@ def default_backend() -> ComputeBackend:
3030
return DefaultConfig.default_backend
3131

3232

33-
def check_multi_gpu_support():
34-
gpus = jax.devices("gpu")
35-
if len(gpus) > 1:
36-
print("Multi-GPU support is available: {} GPUs detected.".format(len(gpus)))
37-
elif len(gpus) == 1:
38-
print("Single-GPU support is available: 1 GPU detected.")
33+
def check_backend_support():
34+
if jax.devices()[0].device_kind == "gpu":
35+
gpus = jax.devices("gpu")
36+
if len(gpus) > 1:
37+
print("Multi-GPU support is available: {} GPUs detected.".format(len(gpus)))
38+
elif len(gpus) == 1:
39+
print("Single-GPU support is available: 1 GPU detected.")
40+
41+
if jax.devices()[0].device_kind == "tpu":
42+
tpus = jax.devices("tpu")
43+
if len(tpus) > 1:
44+
print("Multi-TPU support is available: {} TPUs detected.".format(len(tpus)))
45+
elif len(tpus) == 1:
46+
print("Single-TPU support is available: 1 TPU detected.")
3947
else:
4048
print("No GPU support is available; CPU fallback will be used.")

0 commit comments

Comments
 (0)