Commit 7b41cb4

Fix and clean up multihost preprocessing.
Multihost JAX requires the same program to run on every host. To enforce this, synchronize all input statistics across hosts and repeat preprocessing if the stats have changed. Also simplify the synchronization for better efficiency by collecting all input statistics into a single JAX array.
1 parent 0ef7408 commit 7b41cb4
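
The single-array synchronization works as follows, shown here as a minimal standalone sketch. The per-table stats dict and its values are hypothetical; the real code flattens the richer stats structure produced during preprocessing:

import jax
import numpy as np

# Hypothetical per-table input statistics gathered on this host
# (names and values are illustrative).
stats = {"table_a": 128, "table_b": 64}

# Collect all stats into one flat list so a single collective moves them all.
flat_stats, stats_treedef = jax.tree.flatten(stats)

# Replicate across local CPU devices: pmap expects one leading row per device.
num_local_cpu_devices = jax.local_device_count("cpu")
tiled_stats = np.tile(
    np.array(flat_stats, dtype=np.int32), (num_local_cpu_devices, 1)
)

# Element-wise maximum across every participating CPU device on all hosts.
max_across_cpus = jax.pmap(
    lambda x: jax.lax.pmax(x, "all_cpus"),
    axis_name="all_cpus",
    backend="cpu",
)
flat_stats = max_across_cpus(tiled_stats)[0].tolist()

# Restore the original structure; all hosts now agree on the stats.
stats = jax.tree.unflatten(stats_treedef, flat_stats)

Flattening first means one pmax collective moves every statistic at once, instead of one collective per table as in the previous implementation.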

File tree

2 files changed: +105 −191 lines

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 37 additions & 169 deletions
@@ -15,7 +15,6 @@
1515
table_stacking as jte_table_stacking,
1616
)
1717
from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
18-
from keras.src import backend
1918

2019
from keras_rs.src import types
2120
from keras_rs.src.layers.embedding import base_distributed_embedding
@@ -247,23 +246,6 @@ def _create_sparsecore_distribution(
         )
         return sparsecore_distribution, sparsecore_layout
 
-    def _create_cpu_distribution(
-        self, cpu_axis_name: str = "cpu"
-    ) -> tuple[
-        keras.distribution.ModelParallel, keras.distribution.TensorLayout
-    ]:
-        """Share a variable across all CPU processes."""
-        cpu_devices = jax.devices("cpu")
-        device_mesh = keras.distribution.DeviceMesh(
-            (len(cpu_devices),), [cpu_axis_name], cpu_devices
-        )
-        replicated_layout = keras.distribution.TensorLayout([], device_mesh)
-        layout_map = keras.distribution.LayoutMap(device_mesh=device_mesh)
-        cpu_distribution = keras.distribution.ModelParallel(
-            layout_map=layout_map
-        )
-        return cpu_distribution, replicated_layout
-
     def _add_sparsecore_weight(
         self,
         name: str,
@@ -405,11 +387,6 @@ def sparsecore_build(
         self._sparsecore_layout = sparsecore_layout
         self._sparsecore_distribution = sparsecore_distribution
 
-        # Distribution for CPU operations.
-        cpu_distribution, cpu_layout = self._create_cpu_distribution()
-        self._cpu_distribution = cpu_distribution
-        self._cpu_layout = cpu_layout
-
         mesh = sparsecore_distribution.device_mesh.backend_mesh
         global_device_count = mesh.devices.size
         num_sc_per_device = jte_utils.num_sparsecores_per_device(
@@ -466,10 +443,6 @@ def sparsecore_build(
         # Collect all stacked tables.
         table_specs = embedding_utils.get_table_specs(feature_specs)
         table_stacks = embedding_utils.get_table_stacks(table_specs)
-        stacked_table_specs = {
-            stack_name: stack[0].stacked_table_spec
-            for stack_name, stack in table_stacks.items()
-        }
 
         # Create variables for all stacked tables and slot variables.
         with sparsecore_distribution.scope():
@@ -502,50 +475,6 @@ def sparsecore_build(
             )
             self._iterations.overwrite_with_gradient = True
 
-        with cpu_distribution.scope():
-            # Create variables to track static buffer size and max IDs for each
-            # table during preprocessing. These variables are shared across all
-            # processes on CPU. We don't add these via `add_weight` because we
-            # can't have them passed to the training function.
-            replicated_zeros_initializer = ShardedInitializer(
-                "zeros", cpu_layout
-            )
-
-            with backend.name_scope(self.name, caller=self):
-                self._preprocessing_buffer_size = {
-                    table_name: backend.Variable(
-                        initializer=replicated_zeros_initializer,
-                        shape=(),
-                        dtype=backend.standardize_dtype("int32"),
-                        trainable=False,
-                        name=table_name + ":preprocessing:buffer_size",
-                    )
-                    for table_name in stacked_table_specs.keys()
-                }
-                self._preprocessing_max_unique_ids_per_partition = {
-                    table_name: backend.Variable(
-                        shape=(),
-                        name=table_name
-                        + ":preprocessing:max_unique_ids_per_partition",
-                        initializer=replicated_zeros_initializer,
-                        dtype=backend.standardize_dtype("int32"),
-                        trainable=False,
-                    )
-                    for table_name in stacked_table_specs.keys()
-                }
-
-                self._preprocessing_max_ids_per_partition = {
-                    table_name: backend.Variable(
-                        shape=(),
-                        name=table_name
-                        + ":preprocessing:max_ids_per_partition",
-                        initializer=replicated_zeros_initializer,
-                        dtype=backend.standardize_dtype("int32"),
-                        trainable=False,
-                    )
-                    for table_name in stacked_table_specs.keys()
-                }
-
         self._config = jte_embedding_lookup.EmbeddingLookupConfiguration(
             feature_specs,
             mesh=mesh,
@@ -660,125 +589,64 @@ def _sparsecore_preprocess(
             mesh.devices.item(0)
         )
 
-        # Get current buffer size/max_ids.
-        previous_max_ids_per_partition = keras.tree.map_structure(
-            lambda max_ids_per_partition: max_ids_per_partition.value.item(),
-            self._preprocessing_max_ids_per_partition,
-        )
-        previous_max_unique_ids_per_partition = keras.tree.map_structure(
-            lambda max_unique_ids_per_partition: (
-                max_unique_ids_per_partition.value.item()
-            ),
-            self._preprocessing_max_unique_ids_per_partition,
-        )
-        previous_buffer_size = keras.tree.map_structure(
-            lambda buffer_size: buffer_size.value.item(),
-            self._preprocessing_buffer_size,
-        )
-
         preprocessed, stats = embedding_utils.stack_and_shard_samples(
             self._config.feature_specs,
             samples,
             local_device_count,
             global_device_count,
             num_sc_per_device,
-            static_buffer_size=previous_buffer_size,
         )
 
-        # Extract max unique IDs and buffer sizes.
-        # We need to replicate this value across all local CPU devices.
         if training:
+            # Synchronize input statistics across all devices and update the
+            # underlying stacked table specs in the feature specs.
+            prev_stats = embedding_utils.get_stacked_table_stats(
+                self._config.feature_specs
+            )
+
+            # Take the maximum with existing stats.
+            stats = keras.tree.map_structure(max, prev_stats, stats)
+
+            # Flatten the stats so we can more efficiently transfer them
+            # between hosts. We use jax.tree because we will later need to
+            # unflatten.
+            flat_stats, stats_treedef = jax.tree.flatten(stats)
+
+            # In the case of multiple local CPU devices per host, we need to
+            # replicate the stats to placate JAX collectives.
             num_local_cpu_devices = jax.local_device_count("cpu")
-            local_max_ids_per_partition = {
-                table_name: np.repeat(
-                    # Maximum across all partitions and previous max.
-                    np.maximum(
-                        np.max(elems),
-                        previous_max_ids_per_partition[table_name],
-                    ),
-                    num_local_cpu_devices,
-                )
-                for table_name, elems in stats.max_ids_per_partition.items()
-            }
-            local_max_unique_ids_per_partition = {
-                name: np.repeat(
-                    # Maximum across all partitions and previous max.
-                    np.maximum(
-                        np.max(elems),
-                        previous_max_unique_ids_per_partition[name],
-                    ),
-                    num_local_cpu_devices,
-                )
-                for name, elems in stats.max_unique_ids_per_partition.items()
-            }
-            local_buffer_size = {
-                table_name: np.repeat(
-                    np.maximum(
-                        np.max(
-                            # Round values up to the next multiple of 8.
-                            # Currently using this as a proxy for the actual
-                            # required buffer size.
-                            ((elems + 7) // 8) * 8
-                        )
-                        * global_device_count
-                        * num_sc_per_device
-                        * local_device_count
-                        * num_sc_per_device,
-                        previous_buffer_size[table_name],
-                    ),
-                    num_local_cpu_devices,
-                )
-                for table_name, elems in stats.max_ids_per_partition.items()
-            }
+            tiled_stats = np.tile(
+                np.array(flat_stats, dtype=np.int32), (num_local_cpu_devices, 1)
+            )
 
             # Aggregate variables across all processes/devices.
             max_across_cpus = jax.pmap(
                 lambda x: jax.lax.pmax(  # type: ignore[no-untyped-call]
                     x, "all_cpus"
                 ),
                 axis_name="all_cpus",
-                devices=self._cpu_layout.device_mesh.backend_mesh.devices,
+                backend="cpu",
             )
-            new_max_ids_per_partition = max_across_cpus(
-                local_max_ids_per_partition
-            )
-            new_max_unique_ids_per_partition = max_across_cpus(
-                local_max_unique_ids_per_partition
-            )
-            new_buffer_size = max_across_cpus(local_buffer_size)
-
-            # Assign new preprocessing parameters.
-            with self._cpu_distribution.scope():
-                # For each process, all max ids/buffer sizes are replicated
-                # across all local devices. Take the value from the first
-                # device.
-                keras.tree.map_structure(
-                    lambda var, values: var.assign(values[0]),
-                    self._preprocessing_max_ids_per_partition,
-                    new_max_ids_per_partition,
-                )
-                keras.tree.map_structure(
-                    lambda var, values: var.assign(values[0]),
-                    self._preprocessing_max_unique_ids_per_partition,
-                    new_max_unique_ids_per_partition,
-                )
-                keras.tree.map_structure(
-                    lambda var, values: var.assign(values[0]),
-                    self._preprocessing_buffer_size,
-                    new_buffer_size,
-                )
-            # Update parameters in the underlying feature specs.
-            int_max_ids_per_partition = keras.tree.map_structure(
-                lambda varray: varray.item(), new_max_ids_per_partition
-            )
-            int_max_unique_ids_per_partition = keras.tree.map_structure(
-                lambda varray: varray.item(),
-                new_max_unique_ids_per_partition,
-            )
-            embedding_utils.update_stacked_table_specs(
-                self._config.feature_specs,
-                int_max_ids_per_partition,
-                int_max_unique_ids_per_partition,
-            )
+            flat_stats = max_across_cpus(tiled_stats)[0].tolist()
+            stats = jax.tree.unflatten(stats_treedef, flat_stats)
+
+            # Update configuration and repeat preprocessing if stats changed.
+            if stats != prev_stats:
+                embedding_utils.update_stacked_table_stats(
+                    self._config.feature_specs, stats
+                )
+
+                # Re-execute preprocessing with consistent input statistics.
+                preprocessed, _ = embedding_utils.stack_and_shard_samples(
+                    self._config.feature_specs,
+                    samples,
+                    local_device_count,
+                    global_device_count,
+                    num_sc_per_device,
+                )
 
         return {"inputs": preprocessed}