From 8a99546308196a8227902a79d94dae193a110c8e Mon Sep 17 00:00:00 2001 From: Rachit Nigam Date: Wed, 28 Feb 2024 01:35:59 -0500 Subject: [PATCH] Add dot-product kernel with multiplier reuse (#412) * Add dot-product kernel with multiplier reuse * simplify vmul implementation * complete design * main module * fix and add test * reorg files --------- Co-authored-by: ehg54 --- apps/blas/dot-alt/dot.fil | 135 +++++++++++++++++++++++++++++++++ apps/blas/dot-alt/dot.fil.data | 98 ++++++++++++++++++++++++ 2 files changed, 233 insertions(+) create mode 100644 apps/blas/dot-alt/dot.fil create mode 100644 apps/blas/dot-alt/dot.fil.data diff --git a/apps/blas/dot-alt/dot.fil b/apps/blas/dot-alt/dot.fil new file mode 100644 index 00000000..9d359ac9 --- /dev/null +++ b/apps/blas/dot-alt/dot.fil @@ -0,0 +1,135 @@ +import "primitives/core.fil"; +import "primitives/math/math.fil"; +import "primitives/reshape.fil"; + +/// Perform vector multiplication of 16 elements using parameterized number of +/// multipliers. +comp VMul[M]<'G:C>( + go: interface['G], + left[16]: ['G, 'G+1] 32, + right[16]: ['G, 'G+1] 32, +) -> ( + out[C][M]: for ['G+j+Lat, 'G+j+Lat+1] 32 +) with { + let C = 16 / M; + let Lat = 3; +} where 16 % M == 0, M > 0, M <= 16 { + // For each multiplier + for i in 0..M { + // Instantiate the multiplier + M := new FastMult[32] in ['G, 'G+C]; + for j in 0..C { + // Shift the inputs by the appropriate amount + ls := new Shift[32, j]<'G>(left{C*i+j}); + rs := new Shift[32, j]<'G>(right{C*i+j}); + m := M<'G+j>(ls.out, rs.out); + out{j}{i} = m.out; + } + } +} + +/// Dot-product implementation that uses exactly M multipliers +comp Dot[M]<'G:C+1>( + go: interface['G], + left[16]: ['G, 'G+1] 32, + right[16]: ['G, 'G+1] 32, +) -> ( + out: ['G+TLat+C, 'G+TLat+C+1] 32 +) with { + let C = 16 / M; + let Lat = 3; + let ALat = log2(M); // Latency of the reduction tree + let TLat = ALat + Lat; +} where 16 % M == 0, M > 1, M <= 16 { + // Vector multiplier that produces M values at a time + vmul := new VMul[M]<'G>(left{0..16}, right{0..16}); + + // Required for the reduce adder + assume ALat >= 0; + adder := new ReduceAdd[32, M] in ['G+Lat, 'G+Lat+C]; + + // Bundle to track output from the reduction tree + bundle add_out[C]: for ['G+TLat+i, 'G+TLat+i+1] 32; + + for j in 0..C { + // Reduce the M values to a single value + a := adder<'G+Lat+j>(vmul.out{j}{0..M}); + add_out{j} = a.out; + } + + r := new Prev[32, 1] in ['G+TLat, 'G+TLat+C+1]; + ar := new Add[32]; + for j in 0..C { + // Accumulate the results + // XXX(rachit): Tragic amount of duplication across branches. + if j == 0 { + zero := new Const[32, 0]<'G+TLat>(); + add := ar<'G+TLat>(zero.out, add_out{j}); + prev := r<'G+TLat>(add.out); + } else { + add := ar<'G+TLat+j>(prev.prev, add_out{j}); + prev := r<'G+TLat+j>(add.out); + } + } + + // Reset the prev to 0 + final := r<'G+TLat+C>(final.prev); + out = final.prev; +} + +// Flat interface for the main module. +// li and ri are left and right inputs, respectively at index i. +comp main<'G:5>( + go: interface['G], + l0: ['G, 'G+1] 32, + l1: ['G, 'G+1] 32, + l2: ['G, 'G+1] 32, + l3: ['G, 'G+1] 32, + l4: ['G, 'G+1] 32, + l5: ['G, 'G+1] 32, + l6: ['G, 'G+1] 32, + l7: ['G, 'G+1] 32, + l8: ['G, 'G+1] 32, + l9: ['G, 'G+1] 32, + l10: ['G, 'G+1] 32, + l11: ['G, 'G+1] 32, + l12: ['G, 'G+1] 32, + l13: ['G, 'G+1] 32, + l14: ['G, 'G+1] 32, + l15: ['G, 'G+1] 32, + r0: ['G, 'G+1] 32, + r1: ['G, 'G+1] 32, + r2: ['G, 'G+1] 32, + r3: ['G, 'G+1] 32, + r4: ['G, 'G+1] 32, + r5: ['G, 'G+1] 32, + r6: ['G, 'G+1] 32, + r7: ['G, 'G+1] 32, + r8: ['G, 'G+1] 32, + r9: ['G, 'G+1] 32, + r10: ['G, 'G+1] 32, + r11: ['G, 'G+1] 32, + r12: ['G, 'G+1] 32, + r13: ['G, 'G+1] 32, + r14: ['G, 'G+1] 32, + r15: ['G, 'G+1] 32, +) -> ( + out: ['G+9, 'G+10] 32 +) { + // Wrap inputs into bundles + bundle l[16]: ['G, 'G+1] 32; + l{0} = l0; l{1} = l1; l{2} = l2; l{3} = l3; + l{4} = l4; l{5} = l5; l{6} = l6; l{7} = l7; + l{8} = l8; l{9} = l9; l{10} = l10; l{11} = l11; + l{12} = l12; l{13} = l13; l{14} = l14; l{15} = l15; + + bundle r[16]: ['G, 'G+1] 32; + r{0} = r0; r{1} = r1; r{2} = r2; r{3} = r3; + r{4} = r4; r{5} = r5; r{6} = r6; r{7} = r7; + r{8} = r8; r{9} = r9; r{10} = r10; r{11} = r11; + r{12} = r12; r{13} = r13; r{14} = r14; r{15} = r15; + + // Perform the dot product + dot := new Dot[4]<'G>(l{0..16}, r{0..16}); + out = dot.out; +} \ No newline at end of file diff --git a/apps/blas/dot-alt/dot.fil.data b/apps/blas/dot-alt/dot.fil.data new file mode 100644 index 00000000..9be8901c --- /dev/null +++ b/apps/blas/dot-alt/dot.fil.data @@ -0,0 +1,98 @@ +{ + "l0": [ + 1 + ], + "l1": [ + 2 + ], + "l2": [ + 3 + ], + "l3": [ + 4 + ], + "l4": [ + 5 + ], + "l5": [ + 6 + ], + "l6": [ + 7 + ], + "l7": [ + 8 + ], + "l8": [ + 9 + ], + "l9": [ + 10 + ], + "l10": [ + 11 + ], + "l11": [ + 12 + ], + "l12": [ + 13 + ], + "l13": [ + 14 + ], + "l14": [ + 15 + ], + "l15": [ + 16 + ], + "r0": [ + 17 + ], + "r1": [ + 18 + ], + "r2": [ + 19 + ], + "r3": [ + 20 + ], + "r4": [ + 21 + ], + "r5": [ + 22 + ], + "r6": [ + 23 + ], + "r7": [ + 24 + ], + "r8": [ + 25 + ], + "r9": [ + 26 + ], + "r10": [ + 27 + ], + "r11": [ + 28 + ], + "r12": [ + 29 + ], + "r13": [ + 30 + ], + "r14": [ + 31 + ], + "r15": [ + 32 + ] +} \ No newline at end of file