package tensor

import (
	"github.com/pkg/errors"
)

// This file handles matops. While by default most of these matops should already have been defined as part of the
// Tensor interface, not all are possible for every tensor type (for example, concatenating a sparse tensor), hence
// the need for the following functions.

// Narrow narrows the tensor along the given dimension, returning a view that covers the
// half-open range [start, start+length) of that dimension.
func Narrow(t Tensor, dim, start, length int) (View, error) {
	dim = resolveAxis(dim, t.Dims())
	slices := make([]Slice, MinInt(dim+1, t.Dims()))
	slices[dim] = S(start, start+length, 1)
	return t.Slice(slices...)
}
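
// A minimal usage sketch for Narrow (assuming a 2×3 *Dense of float64 built with this
// package's New constructor; the values are illustrative only):
//
//	a := New(WithShape(2, 3), WithBacking([]float64{1, 2, 3, 4, 5, 6}))
//	v, err := Narrow(a, 1, 0, 2) // view of columns 0 and 1; shape (2, 2)
//	if err != nil {
//		// handle error
//	}
//	_ = v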

// Repeat repeats a Tensor along the given axis, the given number of times.
func Repeat(t Tensor, axis int, repeats ...int) (retVal Tensor, err error) {
	if r, ok := t.Engine().(Repeater); ok {
		return r.Repeat(t, axis, repeats...)
	}
	return nil, errors.New("Engine does not support Repeat")
}

// RepeatReuse repeats a Tensor along the given axis, the given number of times, and puts the
// result in the provided reuse tensor. If the reuse tensor is not correctly sized, an error is
// returned, but the results are still valid.
func RepeatReuse(t, reuse Tensor, axis int, repeats ...int) (retVal Tensor, err error) {
	if r, ok := t.Engine().(Repeater); ok {
		return r.RepeatReuse(t, reuse, axis, repeats...)
	}
	return nil, errors.New("Engine does not support RepeatReuse")
}
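
// A minimal usage sketch for Repeat (assuming a length-3 vector *Dense whose engine
// implements the Repeater interface):
//
//	v := New(WithShape(3), WithBacking([]float64{1, 2, 3}))
//	r, err := Repeat(v, 0, 2) // each element repeated twice: [1, 1, 2, 2, 3, 3]
//	if err != nil {
//		// handle error
//	}
//	_ = r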

// T safely transposes a Tensor. It returns a tensor that is not a view of the input tensor -
// rather, the data is all copied.
func T(t Tensor, axes ...int) (retVal Tensor, err error) {
	switch tt := t.(type) {
	case *Dense:
		return tt.SafeT(axes...)
	}
	panic("Unreachable")
}

// Transpose performs transposition of a tensor according to its axes.
func Transpose(t Tensor, axes ...int) (retVal Tensor, err error) {
	switch tt := t.(type) {
	case *Dense:
		var ret *Dense
		if ret, err = tt.SafeT(axes...); err != nil {
			return
		}
		ret.Transpose()
		retVal = ret
		return
	}
	panic("Unreachable")
}
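
// A minimal usage sketch for T and Transpose (assuming a 2×3 *Dense; when no axes are given,
// the default is to reverse the dimensions):
//
//	a := New(WithShape(2, 3), WithBacking(Range(Float64, 0, 6)))
//	b, err := T(a) // shape (3, 2); data is copied, the input is untouched
//	if err != nil {
//		// handle error
//	}
//	_ = b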

// Concat concatenates a list of Tensors. At the moment the operation only supports Tensors of the same type
// (*Dense can only be concatenated with a bunch of *Dense, CSCs can only be concatenated with a bunch of CSCs, etc.).
func Concat(axis int, t Tensor, others ...Tensor) (retVal Tensor, err error) {
	if len(others) == 0 {
		return t, nil
	}

	switch T := t.(type) {
	case *Dense:
		ts := make([]*Dense, len(others))
		for i, o := range others {
			if ot, ok := o.(*Dense); ok {
				ts[i] = ot
				continue
			}
			return nil, errors.Errorf("Expected all Tensors to be *Dense")
		}
		return T.Concat(axis, ts...)
	}
	panic("Unreachable")
}
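
// A minimal usage sketch for Concat (assuming two 2×3 *Dense tensors; concatenating along
// axis 0 appends rows):
//
//	a := New(WithShape(2, 3), WithBacking(Range(Float64, 0, 6)))
//	b := New(WithShape(2, 3), WithBacking(Range(Float64, 6, 12)))
//	c, err := Concat(0, a, b) // shape (4, 3)
//	if err != nil {
//		// handle error
//	}
//	_ = c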

// Copy copies a tensor to another. For *Dense views, only the relevant slots are copied.
func Copy(dst, src Tensor) error {
	switch st := src.(type) {
	case DenseTensor:
		dt, ok := dst.(DenseTensor)
		if !ok {
			return errors.Errorf("Cannot copy from DenseTensor to %T", dst)
		}

		if st.RequiresIterator() || dt.RequiresIterator() {
			siter := st.Iterator()
			diter := dt.Iterator()
			_, err := copyDenseIter(dt, st, diter, siter)
			return err
		}
		copyDense(dt, st)
		return nil
	default:
		return errors.Errorf("NYI for Copy %T", src)
	}
	panic("Unreachable")
}
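
// A minimal usage sketch for Copy (assuming two equally shaped *Dense tensors; dst is
// overwritten with src's data):
//
//	src := New(WithShape(2, 2), WithBacking([]float64{1, 2, 3, 4}))
//	dst := New(WithShape(2, 2), Of(Float64))
//	if err := Copy(dst, src); err != nil {
//		// handle error
//	}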

// Stack stacks a list of other Tensors. At the moment the operation only supports Tensors of the same type
// (*Dense can only be stacked with *Dense, etc.).
func Stack(axis int, t Tensor, others ...Tensor) (retVal Tensor, err error) {
	if len(others) == 0 {
		return t, nil
	}

	switch T := t.(type) {
	case DenseTensor:
		var dts []DenseTensor
		if dts, err = tensorsToDenseTensors(others); err != nil {
			return nil, errors.Wrap(err, "Cannot convert others into a slice of DenseTensors")
		}
		return T.stackDense(axis, dts...)
	}
	panic("Unreachable")
}
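
// A minimal usage sketch for Stack (assuming two 2×3 *Dense tensors; stacking introduces a
// new axis at the given position):
//
//	a := New(WithShape(2, 3), WithBacking(Range(Float64, 0, 6)))
//	b := New(WithShape(2, 3), WithBacking(Range(Float64, 6, 12)))
//	s, err := Stack(0, a, b) // shape (2, 2, 3)
//	if err != nil {
//		// handle error
//	}
//	_ = s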

// Materialize takes a View and copies out the data into a new allocation.
func Materialize(t Tensor) Tensor {
	switch tt := t.(type) {
	case View:
		return tt.Materialize()
	default:
		return t
	}
}

// Diag performs diagonalization of the given tensor, provided that the tensor's engine
// supports the Diager interface.
func Diag(t Tensor) (retVal Tensor, err error) {
	if d, ok := t.Engine().(Diager); ok {
		return d.Diag(t)
	}
	return nil, errors.Errorf("Unable to perform diagonalization of tensor. Engine %T does not support Diag", t.Engine())
}

// ByIndices allows for selection of the values of `a` by the indices listed in the `indices` tensor.
// The `indices` tensor has to be a vector-like tensor of ints.
func ByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
	if axis >= a.Shape().Dims() {
		return nil, errors.Errorf("Cannot select by indices on axis %d. Input only has %d dims", axis, a.Shape().Dims())
	}
	if sbi, ok := a.Engine().(ByIndiceser); ok {
		return sbi.SelectByIndices(a, indices, axis, opts...)
	}
	return nil, errors.Errorf("Unable to select by indices. Engine %T does not support that.", a.Engine())
}
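
// A minimal usage sketch for ByIndices (assuming a 3×2 *Dense, an int vector of row indices,
// and an engine that implements ByIndiceser):
//
//	a := New(WithShape(3, 2), WithBacking(Range(Float64, 0, 6)))
//	idx := New(WithShape(2), WithBacking([]int{2, 0}))
//	sel, err := ByIndices(a, idx, 0) // rows 2 and 0 of a; shape (2, 2)
//	if err != nil {
//		// handle error
//	}
//	_ = sel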

// ByIndicesB is the backpropagation of ByIndices.
func ByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
	if axis >= a.Shape().Dims() {
		return nil, errors.Errorf("Cannot select by indices on axis %d. Input only has %d dims", axis, a.Shape().Dims())
	}
	if sbi, ok := a.Engine().(ByIndiceser); ok {
		return sbi.SelectByIndicesB(a, b, indices, axis, opts...)
	}
	return nil, errors.Errorf("Unable to select by indices. Engine %T does not support that.", a.Engine())
}

// LogSoftMax applies log softmax to the given tensor along the given axis.
func LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
	if sm, ok := x.Engine().(SoftMaxer); ok {
		return sm.LogSoftMax(x, axis, opts...)
	}
	return nil, errors.Errorf("Unable to apply LogSoftMax. Engine %T does not support that.", x.Engine())
}

// SoftMax applies softmax to the given tensor along the given axis.
func SoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
	if sm, ok := x.Engine().(SoftMaxer); ok {
		return sm.SoftMax(x, axis, opts...)
	}
	return nil, errors.Errorf("Unable to apply SoftMax. Engine %T does not support that.", x.Engine())
}
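
// A minimal usage sketch for SoftMax (assuming a 2×3 *Dense of float64 and an engine that
// implements SoftMaxer; applying along axis 1 makes each row sum to 1):
//
//	a := New(WithShape(2, 3), WithBacking([]float64{1, 2, 3, 1, 2, 3}))
//	sm, err := SoftMax(a, 1)
//	if err != nil {
//		// handle error
//	}
//	_ = sm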

// SoftMaxB computes the backwards (gradient) operation of SoftMax, given the forward output and the incoming gradient.
func SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
	if sm, ok := output.Engine().(SoftMaxer); ok {
		return sm.SoftMaxB(output, grad, axis, opts...)
	}
	return nil, errors.Errorf("Unable to apply SoftMaxB. Engine %T does not support that.", output.Engine())
}

// LogSoftMaxB computes the backwards (gradient) operation of LogSoftMax, given the forward output and the incoming gradient.
func LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
	if sm, ok := output.Engine().(SoftMaxer); ok {
		return sm.LogSoftMaxB(output, grad, axis, opts...)
	}
	return nil, errors.Errorf("Unable to apply LogSoftMaxB. Engine %T does not support that.", output.Engine())
}