|
15 | 15 | table_stacking as jte_table_stacking,
|
16 | 16 | )
|
17 | 17 | from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
|
18 |
| -from keras.src import backend |
19 | 18 |
|
20 | 19 | from keras_rs.src import types
|
21 | 20 | from keras_rs.src.layers.embedding import base_distributed_embedding
|
@@ -247,23 +246,6 @@ def _create_sparsecore_distribution(
|
247 | 246 | )
|
248 | 247 | return sparsecore_distribution, sparsecore_layout
|
249 | 248 |
|
250 |
| - def _create_cpu_distribution( |
251 |
| - self, cpu_axis_name: str = "cpu" |
252 |
| - ) -> tuple[ |
253 |
| - keras.distribution.ModelParallel, keras.distribution.TensorLayout |
254 |
| - ]: |
255 |
| - """Share a variable across all CPU processes.""" |
256 |
| - cpu_devices = jax.devices("cpu") |
257 |
| - device_mesh = keras.distribution.DeviceMesh( |
258 |
| - (len(cpu_devices),), [cpu_axis_name], cpu_devices |
259 |
| - ) |
260 |
| - replicated_layout = keras.distribution.TensorLayout([], device_mesh) |
261 |
| - layout_map = keras.distribution.LayoutMap(device_mesh=device_mesh) |
262 |
| - cpu_distribution = keras.distribution.ModelParallel( |
263 |
| - layout_map=layout_map |
264 |
| - ) |
265 |
| - return cpu_distribution, replicated_layout |
266 |
| - |
267 | 249 | def _add_sparsecore_weight(
|
268 | 250 | self,
|
269 | 251 | name: str,
|
@@ -405,11 +387,6 @@ def sparsecore_build(
|
405 | 387 | self._sparsecore_layout = sparsecore_layout
|
406 | 388 | self._sparsecore_distribution = sparsecore_distribution
|
407 | 389 |
|
408 |
| - # Distribution for CPU operations. |
409 |
| - cpu_distribution, cpu_layout = self._create_cpu_distribution() |
410 |
| - self._cpu_distribution = cpu_distribution |
411 |
| - self._cpu_layout = cpu_layout |
412 |
| - |
413 | 390 | mesh = sparsecore_distribution.device_mesh.backend_mesh
|
414 | 391 | global_device_count = mesh.devices.size
|
415 | 392 | num_sc_per_device = jte_utils.num_sparsecores_per_device(
|
@@ -466,10 +443,6 @@ def sparsecore_build(
|
466 | 443 | # Collect all stacked tables.
|
467 | 444 | table_specs = embedding_utils.get_table_specs(feature_specs)
|
468 | 445 | table_stacks = embedding_utils.get_table_stacks(table_specs)
|
469 |
| - stacked_table_specs = { |
470 |
| - stack_name: stack[0].stacked_table_spec |
471 |
| - for stack_name, stack in table_stacks.items() |
472 |
| - } |
473 | 446 |
|
474 | 447 | # Create variables for all stacked tables and slot variables.
|
475 | 448 | with sparsecore_distribution.scope():
|
@@ -502,50 +475,6 @@ def sparsecore_build(
|
502 | 475 | )
|
503 | 476 | self._iterations.overwrite_with_gradient = True
|
504 | 477 |
|
505 |
| - with cpu_distribution.scope(): |
506 |
| - # Create variables to track static buffer size and max IDs for each |
507 |
| - # table during preprocessing. These variables are shared across all |
508 |
| - # processes on CPU. We don't add these via `add_weight` because we |
509 |
| - # can't have them passed to the training function. |
510 |
| - replicated_zeros_initializer = ShardedInitializer( |
511 |
| - "zeros", cpu_layout |
512 |
| - ) |
513 |
| - |
514 |
| - with backend.name_scope(self.name, caller=self): |
515 |
| - self._preprocessing_buffer_size = { |
516 |
| - table_name: backend.Variable( |
517 |
| - initializer=replicated_zeros_initializer, |
518 |
| - shape=(), |
519 |
| - dtype=backend.standardize_dtype("int32"), |
520 |
| - trainable=False, |
521 |
| - name=table_name + ":preprocessing:buffer_size", |
522 |
| - ) |
523 |
| - for table_name in stacked_table_specs.keys() |
524 |
| - } |
525 |
| - self._preprocessing_max_unique_ids_per_partition = { |
526 |
| - table_name: backend.Variable( |
527 |
| - shape=(), |
528 |
| - name=table_name |
529 |
| - + ":preprocessing:max_unique_ids_per_partition", |
530 |
| - initializer=replicated_zeros_initializer, |
531 |
| - dtype=backend.standardize_dtype("int32"), |
532 |
| - trainable=False, |
533 |
| - ) |
534 |
| - for table_name in stacked_table_specs.keys() |
535 |
| - } |
536 |
| - |
537 |
| - self._preprocessing_max_ids_per_partition = { |
538 |
| - table_name: backend.Variable( |
539 |
| - shape=(), |
540 |
| - name=table_name |
541 |
| - + ":preprocessing:max_ids_per_partition", |
542 |
| - initializer=replicated_zeros_initializer, |
543 |
| - dtype=backend.standardize_dtype("int32"), |
544 |
| - trainable=False, |
545 |
| - ) |
546 |
| - for table_name in stacked_table_specs.keys() |
547 |
| - } |
548 |
| - |
549 | 478 | self._config = jte_embedding_lookup.EmbeddingLookupConfiguration(
|
550 | 479 | feature_specs,
|
551 | 480 | mesh=mesh,
|
@@ -660,125 +589,64 @@ def _sparsecore_preprocess(
|
660 | 589 | mesh.devices.item(0)
|
661 | 590 | )
|
662 | 591 |
|
663 |
| - # Get current buffer size/max_ids. |
664 |
| - previous_max_ids_per_partition = keras.tree.map_structure( |
665 |
| - lambda max_ids_per_partition: max_ids_per_partition.value.item(), |
666 |
| - self._preprocessing_max_ids_per_partition, |
667 |
| - ) |
668 |
| - previous_max_unique_ids_per_partition = keras.tree.map_structure( |
669 |
| - lambda max_unique_ids_per_partition: ( |
670 |
| - max_unique_ids_per_partition.value.item() |
671 |
| - ), |
672 |
| - self._preprocessing_max_unique_ids_per_partition, |
673 |
| - ) |
674 |
| - previous_buffer_size = keras.tree.map_structure( |
675 |
| - lambda buffer_size: buffer_size.value.item(), |
676 |
| - self._preprocessing_buffer_size, |
677 |
| - ) |
678 |
| - |
679 | 592 | preprocessed, stats = embedding_utils.stack_and_shard_samples(
|
680 | 593 | self._config.feature_specs,
|
681 | 594 | samples,
|
682 | 595 | local_device_count,
|
683 | 596 | global_device_count,
|
684 | 597 | num_sc_per_device,
|
685 |
| - static_buffer_size=previous_buffer_size, |
686 | 598 | )
|
687 | 599 |
|
688 |
| - # Extract max unique IDs and buffer sizes. |
689 |
| - # We need to replicate this value across all local CPU devices. |
690 | 600 | if training:
|
| 601 | + # Synchronize input statistics across all devices and update the |
| 602 | + # underlying stacked tables specs in the feature specs. |
| 603 | + prev_stats = embedding_utils.get_stacked_table_stats( |
| 604 | + self._config.feature_specs |
| 605 | + ) |
| 606 | + |
| 607 | + # Take the maximum with existing stats. |
| 608 | + stats = keras.tree.map_structure(max, prev_stats, stats) |
| 609 | + |
| 610 | + # Flatten the stats so we can more efficiently transfer them |
| 611 | + # between hosts. We use jax.tree because we will later need to |
| 612 | + # unflatten. |
| 613 | + flat_stats, stats_treedef = jax.tree.flatten(stats) |
| 614 | + |
| 615 | + # In the case of multiple local CPU devices per host, we need to |
| 616 | + # replicate the stats to placate JAX collectives. |
691 | 617 | num_local_cpu_devices = jax.local_device_count("cpu")
|
692 |
| - local_max_ids_per_partition = { |
693 |
| - table_name: np.repeat( |
694 |
| - # Maximum across all partitions and previous max. |
695 |
| - np.maximum( |
696 |
| - np.max(elems), |
697 |
| - previous_max_ids_per_partition[table_name], |
698 |
| - ), |
699 |
| - num_local_cpu_devices, |
700 |
| - ) |
701 |
| - for table_name, elems in stats.max_ids_per_partition.items() |
702 |
| - } |
703 |
| - local_max_unique_ids_per_partition = { |
704 |
| - name: np.repeat( |
705 |
| - # Maximum across all partitions and previous max. |
706 |
| - np.maximum( |
707 |
| - np.max(elems), |
708 |
| - previous_max_unique_ids_per_partition[name], |
709 |
| - ), |
710 |
| - num_local_cpu_devices, |
711 |
| - ) |
712 |
| - for name, elems in stats.max_unique_ids_per_partition.items() |
713 |
| - } |
714 |
| - local_buffer_size = { |
715 |
| - table_name: np.repeat( |
716 |
| - np.maximum( |
717 |
| - np.max( |
718 |
| - # Round values up to the next multiple of 8. |
719 |
| - # Currently using this as a proxy for the actual |
720 |
| - # required buffer size. |
721 |
| - ((elems + 7) // 8) * 8 |
722 |
| - ) |
723 |
| - * global_device_count |
724 |
| - * num_sc_per_device |
725 |
| - * local_device_count |
726 |
| - * num_sc_per_device, |
727 |
| - previous_buffer_size[table_name], |
728 |
| - ), |
729 |
| - num_local_cpu_devices, |
730 |
| - ) |
731 |
| - for table_name, elems in stats.max_ids_per_partition.items() |
732 |
| - } |
| 618 | + tiled_stats = np.tile( |
| 619 | + np.array(flat_stats, dtype=np.int32), (num_local_cpu_devices, 1) |
| 620 | + ) |
733 | 621 |
|
734 | 622 | # Aggregate variables across all processes/devices.
|
735 | 623 | max_across_cpus = jax.pmap(
|
736 | 624 | lambda x: jax.lax.pmax( # type: ignore[no-untyped-call]
|
737 | 625 | x, "all_cpus"
|
738 | 626 | ),
|
739 | 627 | axis_name="all_cpus",
|
740 |
| - devices=self._cpu_layout.device_mesh.backend_mesh.devices, |
741 |
| - ) |
742 |
| - new_max_ids_per_partition = max_across_cpus( |
743 |
| - local_max_ids_per_partition |
744 |
| - ) |
745 |
| - new_max_unique_ids_per_partition = max_across_cpus( |
746 |
| - local_max_unique_ids_per_partition |
| 628 | + backend="cpu", |
747 | 629 | )
|
748 |
| - new_buffer_size = max_across_cpus(local_buffer_size) |
749 |
| - |
750 |
| - # Assign new preprocessing parameters. |
751 |
| - with self._cpu_distribution.scope(): |
752 |
| - # For each process, all max ids/buffer sizes are replicated |
753 |
| - # across all local devices. Take the value from the first |
754 |
| - # device. |
755 |
| - keras.tree.map_structure( |
756 |
| - lambda var, values: var.assign(values[0]), |
757 |
| - self._preprocessing_max_ids_per_partition, |
758 |
| - new_max_ids_per_partition, |
759 |
| - ) |
760 |
| - keras.tree.map_structure( |
761 |
| - lambda var, values: var.assign(values[0]), |
762 |
| - self._preprocessing_max_unique_ids_per_partition, |
763 |
| - new_max_unique_ids_per_partition, |
764 |
| - ) |
765 |
| - keras.tree.map_structure( |
766 |
| - lambda var, values: var.assign(values[0]), |
767 |
| - self._preprocessing_buffer_size, |
768 |
| - new_buffer_size, |
769 |
| - ) |
770 |
| - # Update parameters in the underlying feature specs. |
771 |
| - int_max_ids_per_partition = keras.tree.map_structure( |
772 |
| - lambda varray: varray.item(), new_max_ids_per_partition |
| 630 | + flat_stats = max_across_cpus(tiled_stats)[0].tolist() |
| 631 | + stats = jax.tree.unflatten(stats_treedef, stats) |
| 632 | + |
| 633 | + # Update configuration and repeat preprocessing if stats changed. |
| 634 | + if stats.values() != prev_stats.values(): |
| 635 | + embedding_utils.update_stacked_table_stats( |
| 636 | + self._config.feature_specs, stats |
773 | 637 | )
|
774 |
| - int_max_unique_ids_per_partition = keras.tree.map_structure( |
775 |
| - lambda varray: varray.item(), |
776 |
| - new_max_unique_ids_per_partition, |
| 638 | + |
| 639 | + prev_stats = embedding_utils.get_stacked_table_stats( |
| 640 | + self._config.feature_specs |
777 | 641 | )
|
778 |
| - embedding_utils.update_stacked_table_specs( |
| 642 | + |
| 643 | + # Re-execute preprocessing with consistent input statistics. |
| 644 | + preprocessed, _ = embedding_utils.stack_and_shard_samples( |
779 | 645 | self._config.feature_specs,
|
780 |
| - int_max_ids_per_partition, |
781 |
| - int_max_unique_ids_per_partition, |
| 646 | + samples, |
| 647 | + local_device_count, |
| 648 | + global_device_count, |
| 649 | + num_sc_per_device, |
782 | 650 | )
|
783 | 651 |
|
784 | 652 | return {"inputs": preprocessed}
|
|
0 commit comments