Allow Parallel(+, f)(x, y, z) to work like broadcasting, and enable Chain(identity, Parallel(+, f))(x, y, z) #2393

Merged: 7 commits merged on Nov 5, 2024

Conversation

mcabbott
Member

At present, Parallel accepts multiple layers with one input, but not one layer with multiple inputs. This PR extends it to allow both, much like broadcasting in connection((inputs .|> layers)...).

julia> Parallel(+, inv)(1, 2, 3)  # was an error
1.8333333333333333

julia> (1,2,3) .|> (inv,)
(1.0, 0.5, 0.3333333333333333)
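
The same computation can be spelled out in plain Julia without Flux. This is only a sketch of the broadcasting analogy; the variable names `connection`, `layers`, `inputs`, and `outputs` are chosen for illustration and are not taken from the PR's implementation.

```julia
# Plain Julia sketch; all variable names here are illustrative.
connection = +
layers = (inv,)          # one layer
inputs = (1, 2, 3)       # three inputs

# Broadcasting pairs the single layer with every input.
outputs = inputs .|> layers

# The connection then combines the per-input outputs.
result = connection(outputs...)
```

Here `result` should agree with the `Parallel(+, inv)(1, 2, 3)` value shown above, since both reduce to `inv(1) + inv(2) + inv(3)`.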

Does this have any unintended side-effects?

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

codecov bot commented Mar 10, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 71.49%. Comparing base (0a36651) to head (288f8d5).
Report is 3 commits behind head on master.

Additional details and impacted files
@@             Coverage Diff             @@
##           master    #2393       +/-   ##
===========================================
+ Coverage   34.58%   71.49%   +36.91%     
===========================================
  Files          31       31               
  Lines        1897     1961       +64     
===========================================
+ Hits          656     1402      +746     
+ Misses       1241      559      -682     


@mcabbott
Member Author

mcabbott commented Mar 13, 2024

Here's the complete run-down on where Flux does & doesn't splat at present:

julia> using Flux

julia> pr(x) = begin println("arg: ", x); x end;

julia> pr(x...) = begin println(length(x), " args: ", join(x, " & "), " -> tuple"); x end;

julia> c1 = Chain(pr, pr); ########## simple chain

julia> c1(1)
arg: 1
arg: 1
1

julia> c1((1, 2))
arg: (1, 2)
arg: (1, 2)
(1, 2)

julia> c1(1, 2)
ERROR: MethodError:
Closest candidates are:
  (::Chain)(::Any)

julia> p1 = Parallel(pr, a=pr);  ########## combiner + one layer

julia> p1(1)
arg: 1
arg: 1
1

julia> p1((1, 2))  # one 2-Tuple is NOT accepted, always splatted  --> changed by PR
ERROR: ArgumentError: Parallel with 1 sub-layers can take one input or 1 inputs, but got 2 inputs

julia> p1(1, 2)  # more obvious error  --> changed by PR
ERROR: ArgumentError: Parallel with 1 sub-layers can take one input or 1 inputs, but got 2 inputs

julia> p1((a=1, b=2))  # one NamedTuple is ok
arg: (a = 1, b = 2)
arg: (a = 1, b = 2)
(a = 1, b = 2)

julia> p1((((1,),),))  # splatted many times
arg: 1
arg: 1
1

julia> p2 = Parallel(pr, a=pr, b=pr);  ########## combiner + two layers

julia> p2(1)  # one non-tuple arg is broadcasted
arg: 1
arg: 1
2 args: 1 & 1 -> tuple
(1, 1)

julia> p2(1, 2)  # 2 args sent to 2 layers
arg: 1
arg: 2
2 args: 1 & 2 -> tuple
(1, 2)

julia> p2((1, 2))  # one tuple splatted
arg: 1
arg: 2
2 args: 1 & 2 -> tuple
(1, 2)

julia> p2((a=1, b=2))  # one NamedTuple sent to both
arg: (a = 1, b = 2)
arg: (a = 1, b = 2)
2 args: (a = 1, b = 2) & (a = 1, b = 2) -> tuple
((a = 1, b = 2), (a = 1, b = 2))

julia> p2(((1,2), ((3,4),)))  # only splatted once
arg: (1, 2)
arg: ((3, 4),)
2 args: (1, 2) & ((3, 4),) -> tuple
((1, 2), ((3, 4),))

julia> Chain(pr, p2, pr)((1, 2))  # here earlier layers cannot pass p2 two arguments
arg: (1, 2)
arg: 1
arg: 2
2 args: 1 & 2 -> tuple
arg: (1, 2)
(1, 2)

This PR changes the two error cases above:

julia> p1((1, 2))  # changed by PR
arg: 1
arg: 2
2 args: 1 & 2 -> tuple
(1, 2)

julia> p1(1, 2)  # changed by PR
arg: 1
arg: 2
2 args: 1 & 2 -> tuple
(1, 2)

You could argue that p1((1, 2)) already has a plausible meaning, apply one layer to one input Tuple. But this use of Parallel is really just Chain (or in this order, ). And it's an error at present.

I think p1(1, 2) has no other plausible meaning.

The rule after this PR is:

  1. (p::Parallel)(input::Tuple) always splats to p(input...)
  2. return combine((inputs .|> layers)...)

Step 1 is unchanged, but step 2 previously allowed only broadcasting of the input. And today, I have a use where I want to broadcast the layer instead (easier than sharing it). That's in fact the 3rd case mentioned here: #1685 (comment) but I think it never worked.
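
To make the two-step rule concrete, here is a minimal sketch using a hypothetical `MiniParallel` type. It is a stand-in written for illustration, assuming only the two steps listed above; it is not Flux's actual `Parallel` code and ignores named layers, error messages, and the zero-layer case.

```julia
# Hypothetical stand-in for Flux's Parallel; not the real implementation.
struct MiniParallel{F,T<:Tuple}
    connection::F
    layers::T
end

# Step 1: a single Tuple input is always splatted.
(p::MiniParallel)(input::Tuple) = p(input...)

# Step 2: broadcast inputs against layers, then combine the results.
# A length-1 tuple on either side is stretched to match the other.
(p::MiniParallel)(inputs...) = p.connection((inputs .|> p.layers)...)

one_layer  = MiniParallel(+, (inv,))                   # broadcast the layer
two_layers = MiniParallel(tuple, (identity, x -> 2x))  # broadcast the input
```

With these definitions, `one_layer(1, 2, 3)` and `one_layer((1, 2, 3))` take the same path, and `two_layers(1)` sends the one input to both layers. Mismatched lengths, such as two inputs with three layers, fail inside the broadcast with a `DimensionMismatch`.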

@mcabbott
Member Author

Reading old threads... around here #2101 (comment) it was agreed that adding (c::Chain)(xs...) = c(xs) would make sense, but there was never a PR.

That's the first MethodError in my list above. I would like this too, and perhaps should just add it here.
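
A rough sketch of what that one-line method does, using a hypothetical `MiniChain` type rather than Flux's `Chain` (whose internals may well differ): several arguments are packed into one tuple, which the first layer then receives as a single input.

```julia
# Hypothetical stand-in for Flux's Chain; not the real implementation.
struct MiniChain{T<:Tuple}
    layers::T
end

# One input: feed it through the layers in order.
(c::MiniChain)(x) = foldl((acc, f) -> f(acc), c.layers; init = x)

# Several inputs: pack them into a tuple and reuse the one-input method.
(c::MiniChain)(xs...) = c(xs)

# The second layer takes a tuple, much as Parallel splats a tuple input.
chain = MiniChain((identity, t -> sum(inv, t)))
```

Under these assumptions, `chain(1, 2, 3)` and `chain((1, 2, 3))` are the same call, which is what lets a tuple-splatting layer later in the chain see all three inputs.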

@mcabbott mcabbott changed the title from "Allow Parallel(+, f)(x, y, z) to work like broadcasting" to "Allow Parallel(+, f)(x, y, z) to work like broadcasting, and enable Chain(identity, Parallel(+, f))(x, y, z)" Mar 13, 2024
@mcabbott
Member Author

mcabbott commented Mar 13, 2024

Anyone remember why we allow Parallel(hcat)? You can write Returns(hcat()) if you really want that...

julia> Parallel(hcat)()
Any[]

julia> Parallel(hcat)(NaN)  # ignores input, but this case is tested
Any[]

julia> Parallel(hcat)(1,2,3)
ERROR: ArgumentError: Parallel with 0 sub-layers can take one input or 0 inputs, but got 3 inputs

Can we just make this an error on construction? I think that's basically what was agreed in #1685
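
For illustration only, rejecting the zero-layer case at construction could look like the sketch below. `StrictParallel` and its error message are hypothetical and say nothing about what Flux actually does; `Returns` is from Base (Julia 1.7 and later).

```julia
# Hypothetical type showing an error on construction with zero layers.
struct StrictParallel{F,T<:Tuple}
    connection::F
    layers::T
    function StrictParallel(connection::F, layers...) where {F}
        isempty(layers) &&
            throw(ArgumentError("StrictParallel needs at least one layer"))
        return new{F,typeof(layers)}(connection, layers)
    end
end

# The explicit alternative mentioned above: a callable that ignores its input.
constant_layer = Returns(hcat())
```

With this sketch, `StrictParallel(hcat)` throws immediately instead of producing a layer that silently ignores its input, while `constant_layer(NaN)` still gives the empty `hcat()` result for anyone who wants that behaviour.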

Member

@ToucheSir ToucheSir left a comment

Since it was linked here, can you quickly comment on the relationship between this and FluxML/Functors.jl#80?

@mcabbott mcabbott added this to the v0.15 milestone Oct 16, 2024
@mcabbott
Member Author

I put this on 0.15 milestone... I still think it's the right thing to do, but perhaps a breaking change is the right time to merge it.

@mcabbott mcabbott merged commit 7be1ca7 into FluxML:master Nov 5, 2024
16 of 19 checks passed
@mcabbott mcabbott deleted the parallel_bc branch November 5, 2024 03:50