-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathzip.go
119 lines (100 loc) · 3.74 KB
/
zip.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package nunela
import (
"sync"
"github.com/vorduin/nune"
"github.com/vorduin/slices"
)
// TryZip applies f to each pair of corresponding elements of a and b and
// collects the results in a newly created tensor that has a's shape.
//
// If the shapes of a and b are not equal, no work is done and the error
// produced by NewErrDifferentShapesTwo is returned together with a nil tensor.
func TryZip[T, U, V Number](a *nune.Tensor[T], b *nune.Tensor[U], f func(T, U) V) (*nune.Tensor[V], error) {
	aShape, bShape := a.Shape(), b.Shape()
	if !slices.Equal(aShape, bShape) {
		return nil, NewErrDifferentShapesTwo(a, b)
	}

	// The destination starts out as zeros and is filled in by handleZip.
	result := nune.Zeros[V](a.Shape()...)
	workers := configCPU(a.Numel())
	handleZip(a, b, &result, f, workers)

	return &result, nil
}
// TryZipAssign applies f to each pair of corresponding elements of a and b
// and stores every result back into a, overwriting its previous contents.
//
// If the shapes of a and b are not equal, a is left untouched and the error
// produced by NewErrDifferentShapesTwo is returned.
func TryZipAssign[T, U Number](a *nune.Tensor[T], b *nune.Tensor[U], f func(T, U) T) error {
	aShape, bShape := a.Shape(), b.Shape()
	if !slices.Equal(aShape, bShape) {
		return NewErrDifferentShapesTwo(a, b)
	}

	// a is passed both as the left operand and as the destination.
	workers := configCPU(a.Numel())
	handleZip(a, b, a, f, workers)

	return nil
}
// Zip applies f to each pair of corresponding elements of a and b and
// returns the results as a new tensor.
//
// It is the panicking counterpart of TryZip: any error TryZip reports
// is raised with panic instead of being returned.
func Zip[T, U, V Number](a *nune.Tensor[T], b *nune.Tensor[U], f func(T, U) V) *nune.Tensor[V] {
	result, err := TryZip(a, b, f)
	if err == nil {
		return result
	}
	panic(err)
}
// ZipAssign applies f to each pair of corresponding elements of a and b
// and writes the results into a.
//
// It is the panicking counterpart of TryZipAssign: any error TryZipAssign
// reports is raised with panic instead of being returned.
func ZipAssign[T, U Number](a *nune.Tensor[T], b *nune.Tensor[U], f func(T, U) T) {
	if err := TryZipAssign(a, b, f); err != nil {
		panic(err)
	}
}
// Add returns a new tensor holding the element-wise sum a + b.
// It is built on Zip and therefore panics whenever Zip does.
func Add[T Number](a, b *nune.Tensor[T]) *nune.Tensor[T] {
	sum := func(lhs, rhs T) T { return lhs + rhs }
	return Zip(a, b, sum)
}
// Sub returns a new tensor holding the element-wise difference a - b.
// It is built on Zip and therefore panics whenever Zip does.
func Sub[T Number](a, b *nune.Tensor[T]) *nune.Tensor[T] {
	difference := func(lhs, rhs T) T { return lhs - rhs }
	return Zip(a, b, difference)
}
// Mul returns a new tensor holding the element-wise product a * b.
// It is built on Zip and therefore panics whenever Zip does.
func Mul[T Number](a, b *nune.Tensor[T]) *nune.Tensor[T] {
	product := func(lhs, rhs T) T { return lhs * rhs }
	return Zip(a, b, product)
}
// Div returns a new tensor holding the element-wise quotient a / b.
// It is built on Zip and therefore panics whenever Zip does.
//
// The division is Go's built-in `/` on T; if T is an integer type, a zero
// element in b causes Go's run-time division-by-zero panic.
func Div[T Number](a, b *nune.Tensor[T]) *nune.Tensor[T] {
	quotient := func(lhs, rhs T) T { return lhs / rhs }
	return Zip(a, b, quotient)
}
// Rem returns a new tensor holding the element-wise remainder a % b.
// It is built on Zip and therefore panics whenever Zip does.
//
// The remainder is Go's built-in `%` on T; a zero element in b causes
// Go's run-time division-by-zero panic.
func Rem[T Integer](a, b *nune.Tensor[T]) *nune.Tensor[T] {
	remainder := func(lhs, rhs T) T { return lhs % rhs }
	return Zip(a, b, remainder)
}
// AddAssign replaces every element of a with the element-wise sum a + b.
// It is built on ZipAssign and therefore panics whenever ZipAssign does.
func AddAssign[T Number](a, b *nune.Tensor[T]) {
	sum := func(lhs, rhs T) T { return lhs + rhs }
	ZipAssign(a, b, sum)
}
// SubAssign replaces every element of a with the element-wise difference a - b.
// It is built on ZipAssign and therefore panics whenever ZipAssign does.
func SubAssign[T Number](a, b *nune.Tensor[T]) {
	difference := func(lhs, rhs T) T { return lhs - rhs }
	ZipAssign(a, b, difference)
}
// MulAssign replaces every element of a with the element-wise product a * b.
// It is built on ZipAssign and therefore panics whenever ZipAssign does.
func MulAssign[T Number](a, b *nune.Tensor[T]) {
	product := func(lhs, rhs T) T { return lhs * rhs }
	ZipAssign(a, b, product)
}
// DivAssign replaces every element of a with the element-wise quotient a / b.
// It is built on ZipAssign and therefore panics whenever ZipAssign does.
//
// The division is Go's built-in `/` on T; if T is an integer type, a zero
// element in b causes Go's run-time division-by-zero panic.
func DivAssign[T Number](a, b *nune.Tensor[T]) {
	quotient := func(lhs, rhs T) T { return lhs / rhs }
	ZipAssign(a, b, quotient)
}
// RemAssign replaces every element of a with the element-wise remainder a % b.
// It is built on ZipAssign and therefore panics whenever ZipAssign does.
//
// The remainder is Go's built-in `%` on T; a zero element in b causes
// Go's run-time division-by-zero panic.
func RemAssign[T Integer](a, b *nune.Tensor[T]) {
	remainder := func(lhs, rhs T) T { return lhs % rhs }
	ZipAssign(a, b, remainder)
}
// Copyright © The Nune Author. All rights reserved.

// handleZip is the shared worker behind the Zip and ZipAssign families.
//
// It writes f(a[i], b[i]) into out[i] for every position i of the flat
// buffers returned by Ravel. A rank-0 tensor is handled directly as a single
// element. Otherwise the buffers are cut into nCPU contiguous chunks, each
// chunk is processed in its own goroutine, and handleZip returns only after
// all of them have finished.
//
// Parameters:
//   - a, b: the operands; callers are expected to have verified that their
//     shapes match, handleZip itself does not check.
//   - out: the destination; it may be the same tensor as a (in-place update),
//     since every position is read before it is written.
//   - f: the element-wise operation. It is invoked concurrently from several
//     goroutines, so it must be safe for concurrent use.
//   - nCPU: the requested number of worker goroutines. Values below 1 are
//     treated as 1 and values above the element count are reduced to the
//     element count, so the output is always fully written and no goroutine
//     is started for an empty chunk.
//
// NOTE(review): this relies on Ravel returning a view of the tensor's
// storage rather than a copy — results are stored by writing through
// out.Ravel(); confirm against nune.
func handleZip[T, U, V Number](a *nune.Tensor[T], b *nune.Tensor[U], out *nune.Tensor[V], f func(T, U) V, nCPU int) {
	// Fetch the flat buffers once instead of three times per chunk.
	aData, bData, outData := a.Ravel(), b.Ravel(), out.Ravel()

	if a.Rank() == 0 {
		outData[0] = f(aData[0], bData[0])
		return
	}

	numel := a.Numel()
	if numel == 0 {
		// Nothing to compute; avoid spawning goroutines for empty chunks.
		return
	}

	// Clamp the worker count to [1, numel]. With nCPU <= 0 the loop below
	// would never run and out would silently keep its previous contents.
	if nCPU < 1 {
		nCPU = 1
	}
	if nCPU > numel {
		nCPU = numel
	}

	var wg sync.WaitGroup
	wg.Add(nCPU)
	for i := 0; i < nCPU; i++ {
		// Chunk i covers [lo, hi); consecutive chunks tile [0, numel) exactly.
		lo := i * numel / nCPU
		hi := (i + 1) * numel / nCPU

		go func(aBuf []T, bBuf []U, outBuf []V) {
			defer wg.Done()
			for j, x := range aBuf {
				outBuf[j] = f(x, bBuf[j])
			}
		}(aData[lo:hi], bData[lo:hi], outData[lo:hi])
	}
	wg.Wait()
}