From 89e09c161e1b8b6fb8511cb2bfe99b58bb9b4a95 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 25 Nov 2024 11:55:09 -0800 Subject: [PATCH] Experiment with using out_axes=-1 inside nmap's vmap --- penzai/core/named_axes.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/penzai/core/named_axes.py b/penzai/core/named_axes.py index af10bb0..cac9140 100644 --- a/penzai/core/named_axes.py +++ b/penzai/core/named_axes.py @@ -251,7 +251,7 @@ def recursive_vectorize_step(current_views, remaining_names): # Otherwise, we still have names to vectorize over. Pop one name off the # stack and vectorize over it as needed. - vmap_name = remaining_names[0] + vmap_name = remaining_names[-1] reduced_views = [] vmap_axes = [] for view in current_views: @@ -298,10 +298,10 @@ def _shift_axis(other_axis): return jax.vmap( functools.partial( recursive_vectorize_step, - remaining_names=remaining_names[1:], + remaining_names=remaining_names[:-1], ), in_axes=(vmap_axes,), - out_axes=0, + out_axes=-1, axis_name=vmap_name, )(reduced_views) @@ -324,8 +324,11 @@ def handle_result(leaf): return NamedArrayView( data_array=leaf, data_shape=leaf.shape, - data_axis_for_name={name: i for i, name in enumerate(all_names)}, - data_axis_for_logical_axis=tuple(range(len(all_names), leaf.ndim)), + data_axis_for_name={ + name: leaf.ndim - i - 1 + for i, name in enumerate(reversed(all_names)) + }, + data_axis_for_logical_axis=tuple(range(leaf.ndim - len(all_names))), ) return jax.tree_util.tree_map(handle_result, result_data)