Skip to content

Commit

Permalink
Add dot-product kernel with multiplier reuse (#412)
Browse files Browse the repository at this point in the history
* Add dot-product kernel with multiplier reuse

* simplify vmul implementation

* complete design

* main module

* fix and add test

* reorg files

---------

Co-authored-by: ehg54 <[email protected]>
  • Loading branch information
rachitnigam and gabizon103 authored Feb 28, 2024
1 parent 0656bae commit 8a99546
Show file tree
Hide file tree
Showing 2 changed files with 233 additions and 0 deletions.
135 changes: 135 additions & 0 deletions apps/blas/dot-alt/dot.fil
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import "primitives/core.fil";
import "primitives/math/math.fil";
import "primitives/reshape.fil";

/// Perform vector multiplication of 16 elements using parameterized number of
/// multipliers.
comp VMul[M]<'G:C>(
go: interface['G],
left[16]: ['G, 'G+1] 32,
right[16]: ['G, 'G+1] 32,
) -> (
out[C][M]: for<j> ['G+j+Lat, 'G+j+Lat+1] 32
) with {
let C = 16 / M;
let Lat = 3;
} where 16 % M == 0, M > 0, M <= 16 {
// For each multiplier
for i in 0..M {
// Instantiate the multiplier
M := new FastMult[32] in ['G, 'G+C];
for j in 0..C {
// Shift the inputs by the appropriate amount
ls := new Shift[32, j]<'G>(left{C*i+j});
rs := new Shift[32, j]<'G>(right{C*i+j});
m := M<'G+j>(ls.out, rs.out);
out{j}{i} = m.out;
}
}
}

/// Dot-product implementation that uses exactly M multipliers
comp Dot[M]<'G:C+1>(
go: interface['G],
left[16]: ['G, 'G+1] 32,
right[16]: ['G, 'G+1] 32,
) -> (
out: ['G+TLat+C, 'G+TLat+C+1] 32
) with {
let C = 16 / M;
let Lat = 3;
let ALat = log2(M); // Latency of the reduction tree
let TLat = ALat + Lat;
} where 16 % M == 0, M > 1, M <= 16 {
// Vector multiplier that produces M values at a time
vmul := new VMul[M]<'G>(left{0..16}, right{0..16});

// Required for the reduce adder
assume ALat >= 0;
adder := new ReduceAdd[32, M] in ['G+Lat, 'G+Lat+C];

// Bundle to track output from the reduction tree
bundle add_out[C]: for<i> ['G+TLat+i, 'G+TLat+i+1] 32;

for j in 0..C {
// Reduce the M values to a single value
a := adder<'G+Lat+j>(vmul.out{j}{0..M});
add_out{j} = a.out;
}

r := new Prev[32, 1] in ['G+TLat, 'G+TLat+C+1];
ar := new Add[32];
for j in 0..C {
// Accumulate the results
// XXX(rachit): Tragic amount of duplication across branches.
if j == 0 {
zero := new Const[32, 0]<'G+TLat>();
add := ar<'G+TLat>(zero.out, add_out{j});
prev := r<'G+TLat>(add.out);
} else {
add := ar<'G+TLat+j>(prev.prev, add_out{j});
prev := r<'G+TLat+j>(add.out);
}
}

// Reset the prev to 0
final := r<'G+TLat+C>(final.prev);
out = final.prev;
}

// Flat interface for the main module.
// li and ri are left and right inputs, respectively at index i.
comp main<'G:5>(
go: interface['G],
l0: ['G, 'G+1] 32,
l1: ['G, 'G+1] 32,
l2: ['G, 'G+1] 32,
l3: ['G, 'G+1] 32,
l4: ['G, 'G+1] 32,
l5: ['G, 'G+1] 32,
l6: ['G, 'G+1] 32,
l7: ['G, 'G+1] 32,
l8: ['G, 'G+1] 32,
l9: ['G, 'G+1] 32,
l10: ['G, 'G+1] 32,
l11: ['G, 'G+1] 32,
l12: ['G, 'G+1] 32,
l13: ['G, 'G+1] 32,
l14: ['G, 'G+1] 32,
l15: ['G, 'G+1] 32,
r0: ['G, 'G+1] 32,
r1: ['G, 'G+1] 32,
r2: ['G, 'G+1] 32,
r3: ['G, 'G+1] 32,
r4: ['G, 'G+1] 32,
r5: ['G, 'G+1] 32,
r6: ['G, 'G+1] 32,
r7: ['G, 'G+1] 32,
r8: ['G, 'G+1] 32,
r9: ['G, 'G+1] 32,
r10: ['G, 'G+1] 32,
r11: ['G, 'G+1] 32,
r12: ['G, 'G+1] 32,
r13: ['G, 'G+1] 32,
r14: ['G, 'G+1] 32,
r15: ['G, 'G+1] 32,
) -> (
out: ['G+9, 'G+10] 32
) {
// Wrap inputs into bundles
bundle l[16]: ['G, 'G+1] 32;
l{0} = l0; l{1} = l1; l{2} = l2; l{3} = l3;
l{4} = l4; l{5} = l5; l{6} = l6; l{7} = l7;
l{8} = l8; l{9} = l9; l{10} = l10; l{11} = l11;
l{12} = l12; l{13} = l13; l{14} = l14; l{15} = l15;

bundle r[16]: ['G, 'G+1] 32;
r{0} = r0; r{1} = r1; r{2} = r2; r{3} = r3;
r{4} = r4; r{5} = r5; r{6} = r6; r{7} = r7;
r{8} = r8; r{9} = r9; r{10} = r10; r{11} = r11;
r{12} = r12; r{13} = r13; r{14} = r14; r{15} = r15;

// Perform the dot product
dot := new Dot[4]<'G>(l{0..16}, r{0..16});
out = dot.out;
}
98 changes: 98 additions & 0 deletions apps/blas/dot-alt/dot.fil.data
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
{
"l0": [
1
],
"l1": [
2
],
"l2": [
3
],
"l3": [
4
],
"l4": [
5
],
"l5": [
6
],
"l6": [
7
],
"l7": [
8
],
"l8": [
9
],
"l9": [
10
],
"l10": [
11
],
"l11": [
12
],
"l12": [
13
],
"l13": [
14
],
"l14": [
15
],
"l15": [
16
],
"r0": [
17
],
"r1": [
18
],
"r2": [
19
],
"r3": [
20
],
"r4": [
21
],
"r5": [
22
],
"r6": [
23
],
"r7": [
24
],
"r8": [
25
],
"r9": [
26
],
"r10": [
27
],
"r11": [
28
],
"r12": [
29
],
"r13": [
30
],
"r14": [
31
],
"r15": [
32
]
}

0 comments on commit 8a99546

Please sign in to comment.