Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experiment with using out_axes=-1 inside nmap's vmap #99

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions penzai/core/named_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def recursive_vectorize_step(current_views, remaining_names):

# Otherwise, we still have names to vectorize over. Pop one name off the
# stack and vectorize over it as needed.
vmap_name = remaining_names[0]
vmap_name = remaining_names[-1]
reduced_views = []
vmap_axes = []
for view in current_views:
Expand Down Expand Up @@ -298,10 +298,10 @@ def _shift_axis(other_axis):
return jax.vmap(
functools.partial(
recursive_vectorize_step,
remaining_names=remaining_names[1:],
remaining_names=remaining_names[:-1],
),
in_axes=(vmap_axes,),
out_axes=0,
out_axes=-1,
axis_name=vmap_name,
)(reduced_views)

Expand All @@ -324,8 +324,11 @@ def handle_result(leaf):
return NamedArrayView(
data_array=leaf,
data_shape=leaf.shape,
data_axis_for_name={name: i for i, name in enumerate(all_names)},
data_axis_for_logical_axis=tuple(range(len(all_names), leaf.ndim)),
data_axis_for_name={
name: leaf.ndim - i - 1
for i, name in enumerate(reversed(all_names))
},
data_axis_for_logical_axis=tuple(range(leaf.ndim - len(all_names))),
)

return jax.tree_util.tree_map(handle_result, result_data)
Expand Down