# semirings.py
from collections.abc import Sequence
from typing import Any, Generic, Optional, TypeVar

import torch
import torch.nn.functional as F

# Type aliases, for documentation purposes only.
DType = Any
PyTree = Any

# Type variables for semiring values.
T = TypeVar("T")
S = TypeVar("S")
def _tree_leaves(tree: PyTree) -> list[Any]:
if isinstance(tree, (list, tuple)):
leaves = []
for item in tree:
leaves.extend(_tree_leaves(item))
return leaves
elif isinstance(tree, dict):
leaves = []
for key, value in tree.items():
leaves.extend(_tree_leaves(value))
return leaves
else:
return [tree]
def value_shape(x: PyTree) -> tuple[int, ...]:
    """Obtains the common shape of the leaves of a semiring value."""
    leaves = _tree_leaves(x)
    if x is None or not leaves:
        raise ValueError(f"No common shape can be derived for an empty PyTree: {x!r}")
    shapes = [i.shape for i in leaves]
    result = shapes[0]
    for i in shapes[1:]:
        if i != result:
            raise ValueError(
                "A semiring value must consist of ndarrays of a common shape. "
                f"Got inconsistent shapes {result} vs {i} for PyTree: {x!r}"
            )
    return result
def value_dtype(x: PyTree) -> DType:
    """Obtains the dtypes of a semiring value.

    Different leaves of a semiring value may have different dtypes. Methods
    such as Semiring.{zeros,ones} can take a PyTree of dtypes in the same
    structure as the corresponding semiring values. This function can be used
    to extract such a dtype PyTree from a semiring value.

    Args:
      x: Some semiring value.

    Returns:
      dtypes in the same structure as x.
    """
    if isinstance(x, torch.Tensor):
        return x.dtype
    if isinstance(x, (list, tuple)):
        return type(x)(value_dtype(i) for i in x)
    if isinstance(x, dict):
        return {k: value_dtype(v) for k, v in x.items()}
    return x.dtype
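# Usage sketch (illustrative, not part of the original file): for a
# tuple-valued semiring value, value_shape returns the common leaf shape and
# value_dtype mirrors the tuple structure.
#
#   v = (torch.zeros(2, 3), torch.zeros(2, 3, dtype=torch.int64))
#   value_shape(v)    # torch.Size([2, 3])
#   value_dtype(v)    # (torch.float32, torch.int64)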
class Semiring(Generic[T]):
"""Base Semiring interface.
See https://en.wikipedia.org/wiki/Semiring for what a semiring is. A Semiring
object holds methods that implement the semiring operations. To simplify
non-semiring operations on the semiring values, the semiring values are not
typed: for most basic semirings, each value is a single ndarray; for some more
complex semirings (e.g. Expectation or Cartesian), the values can be a tuple
of ndarrays.
In general, a semiring value under some semiring is represented as a PyTree
of identically shaped ndarrays, with possibly different dtypes. The shape
and dtypes of a semiring value can be obtained with methods
`last.semirings.value_shape()` and `last.semirings.value_dtype()`.

Semiring is not an abstract base class because we allow operations to be
left unimplemented (e.g. `prod` is not commonly used).

Reductions (`sum`) can be tricky to implement correctly; here are two
important things to watch out for:
* `dim` can be in the range [-rank, rank).
* The input can have 0-sized dimensions.
"""
def zeros(self, shape: Sequence[int], dtype: Optional[DType] = None) -> T:
"""Semiring zeros in the given shape and dtype.
Args:
shape: Desired output shape.
dtype: Optional PyTree of dtypes.
Returns:
If dtype is None, semiring zero values in the specified shape with
reasonable default dtypes. Otherwise, semiring zero values in the
specified shape with the specified dtypes.
"""
raise NotImplementedError
def ones(self, shape: Sequence[int], dtype: Optional[DType] = None) -> T:
"""Semiring ones in the given shape and dtype.
Args:
shape: Desired output shape.
dtype: Optional PyTree of dtypes.
Returns:
If dtype is None, semiring one values in the specified shape with
reasonable default dtypes. Otherwise, semiring one values in the
specified shape with the specified dtypes.
"""
raise NotImplementedError
def times(self, a: T, b: T) -> T:
"""Semiring multiplication between two values."""
raise NotImplementedError
def plus(self, a: T, b: T) -> T:
"""Semiring addition between two values."""
raise NotImplementedError
def prod(self, a: T, dim: int) -> T:
"""Semiring multiplication along a single dim."""
raise NotImplementedError
def sum(self, a: T, dim: int) -> T:
"""Semiring addition along a single dim."""
raise NotImplementedError
class _Real(Semiring[torch.Tensor]):
@staticmethod
def zeros(shape: Sequence[int], dtype: Optional[DType] = None) -> torch.Tensor:
return torch.zeros(shape, dtype=dtype)
@staticmethod
def ones(shape: Sequence[int], dtype: Optional[DType] = None) -> torch.Tensor:
return torch.ones(shape, dtype=dtype)
@staticmethod
def times(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return a * b
@staticmethod
def plus(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return a + b
@staticmethod
def prod(a: torch.Tensor, dim: int) -> torch.Tensor:
return a.prod(dim=dim)
@staticmethod
def sum(a: torch.Tensor, dim: int) -> torch.Tensor:
return a.sum(dim=dim)
Real = _Real()
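# Usage sketch (illustrative, not part of the original file): the Real semiring
# is ordinary arithmetic, so `times`/`plus` are * and +, and `prod`/`sum`
# reduce with * and + respectively.
#
#   x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
#   Real.sum(x, dim=0)     # tensor([4., 6.])
#   Real.prod(x, dim=1)    # tensor([2., 12.])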
def _check_dim(a: torch.Tensor, dim: int) -> None:
if not isinstance(dim, int):
raise ValueError(f"Only int dimension is supported, got dim={dim!r}")
if not -a.dim() <= dim < a.dim():
raise ValueError(f"Invalid reduction dim={dim!r} for input shape {a.shape}")
class _Log(Semiring[torch.Tensor]):
@staticmethod
def zeros(shape: Sequence[int], dtype: Optional[DType] = None) -> torch.Tensor:
return torch.full(shape, float("-inf"), dtype=dtype)
@staticmethod
def ones(shape: Sequence[int], dtype: Optional[DType] = None) -> torch.Tensor:
return torch.zeros(shape, dtype=dtype)
@staticmethod
def times(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return a + b
@staticmethod
def plus(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return LogPlus.apply(a, b)
@staticmethod
def prod(a: torch.Tensor, dim: int) -> torch.Tensor:
return a.sum(dim=dim)
@classmethod
def sum(cls, a: torch.Tensor, dim: int) -> torch.Tensor:
_check_dim(a, dim)
if a.numel() > 0:
return LogSum.apply(a, dim)
if dim < 0:
dim += a.dim()
result_shape = a.shape[:dim] + a.shape[dim + 1 :]
return cls.zeros(result_shape, a.dtype)
# Specialized log{add,sum}exp with safe gradients.
#
# Scenarios:
# - All operands are finite: As expected.
# - All operands are -inf: Sum should be -inf. Gradient should be 0.
# - All operands are +inf: Sum should be +inf. Gradient should be NaN.
# - Mixed finite & -inf operands: Sum as expected. Gradient should be 0 for
# -inf; non-0 for others.
# - Mixed finite & +inf operands: Sum should be +inf. Gradient should be NaN for
# +inf; 0 for others.
# - Mixed -inf & +inf operands: Sum should be +inf. Gradient should be NaN for
# +inf; 0 for -inf.
# - Mixed finite, -inf & +inf operands: Sum should be +inf. Gradient should be
# NaN for +inf; 0 for others.
#
# The different treatment of -inf & +inf comes from their different sources.
# - +inf is an indicator of a true error, e.g. an overflow somewhere. It's
#   thus desirable not to silence such issues.
# - -inf often arises from perfectly legitimate computations such as
# `logaddexp(-inf, -inf + x)`, where `x` should not receive a NaN gradient.
class LogPlus(torch.autograd.Function):
@staticmethod
def forward(ctx, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
c = torch.maximum(a, b)
safe = torch.isfinite(c)
c = torch.where(safe, c, 0)
ea = torch.exp(a - c)
eb = torch.exp(b - c)
z = ea + eb
ctx.save_for_backward(ea, eb, z)
return c + torch.log(z)
@staticmethod
def backward(ctx, grad_out):
ea, eb, z = ctx.saved_tensors
safe = z != 0
z = torch.where(safe, z, 1)
scale = grad_out / z
return scale * ea, scale * eb
class LogSum(torch.autograd.Function):
@staticmethod
def forward(ctx, a: torch.Tensor, dim: int) -> torch.Tensor:
c = torch.max(a, dim=dim, keepdim=True)[0]
safe = torch.isfinite(c)
c = torch.where(safe, c, 0)
e = torch.exp(a - c)
z = torch.sum(e, dim=dim, keepdim=True)
result = torch.squeeze(c, dim=dim) + torch.log(torch.squeeze(z, dim=dim))
ctx.save_for_backward(e, z)
ctx.constant = dim
return result
@staticmethod
def backward(ctx, grad_out):
e, z = ctx.saved_tensors
dim = ctx.constant
safe = z != 0
z = torch.where(safe, z, 1)
grad_out = grad_out.unsqueeze(dim)
return ((grad_out / z) * e, None)
Log = _Log()
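# Usage sketch (illustrative, not part of the original file): under the log
# semiring, `times` is + and `sum` is logsumexp; Log.sum matches
# torch.logsumexp on finite inputs while keeping gradients finite when an
# entire slice is -inf.
#
#   x = torch.tensor([[0.0, 0.0], [float("-inf"), float("-inf")]],
#                    requires_grad=True)
#   y = Log.sum(x, dim=1)    # ~ tensor([0.6931, -inf])
#   y.sum().backward()       # x.grad == [[0.5, 0.5], [0., 0.]], no NaNs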
# Max-tropical (max-plus) semiring: `plus` takes the maximum, `times` adds.
class _MaxTropical(Semiring):
@staticmethod
def zeros(shape: Sequence[int], dtype: Optional[DType] = None) -> torch.Tensor:
return torch.full(shape, float("-inf"), dtype=dtype)
@staticmethod
def ones(shape: Sequence[int], dtype: Optional[DType] = None) -> torch.Tensor:
return torch.zeros(shape, dtype=dtype)
@staticmethod
def times(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return a + b
@staticmethod
def plus(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
a, b = torch.broadcast_tensors(a, b)
return TropicMaximum.apply(a, b)
@staticmethod
def prod(a: torch.Tensor, dim: int) -> torch.Tensor:
return a.sum(dim=dim)
@classmethod
def sum(cls, a: torch.Tensor, dim: int) -> torch.Tensor:
_check_dim(a, dim)
if a.numel() > 0:
return TropicMax.apply(a, dim)
if dim < 0:
dim += a.dim()
result_shape = a.shape[:dim] + a.shape[dim + 1 :]
return cls.zeros(result_shape, a.dtype)
MaxTropical = _MaxTropical()
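# Usage sketch (illustrative, not part of the original file): the max-tropical
# semiring uses max for `plus`/`sum` and + for `times`/`prod`; gradients of
# `sum` flow only to the maximizing entries.
#
#   x = torch.tensor([[1.0, 3.0, 2.0]], requires_grad=True)
#   MaxTropical.sum(x, dim=1)    # tensor([3.], grad_fn=...)
#   MaxTropical.times(x, x)      # tensor([[2., 6., 4.]], grad_fn=...)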
class TropicMaximum(torch.autograd.Function):
@staticmethod
def forward(ctx, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
ctx.save_for_backward(a >= b)
return torch.maximum(a, b)
@staticmethod
def backward(ctx, grad_out):
greater_a = ctx.saved_tensors
return grad_out * greater_a[0], grad_out * (1 - greater_a[0].long())
class TropicMax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a: torch.Tensor, dim: int) -> torch.Tensor:
        # Remember which entries attain the maximum so that the backward pass
        # can route gradients only to them.
        argmax = torch.argmax(a, dim=dim)
        width = a.shape[dim]
        ctx.constant = width, dim
        ctx.save_for_backward(argmax)
        return torch.max(a, dim=dim)[0]

    @staticmethod
    def backward(ctx, grad_out):
        argmax = ctx.saved_tensors[0]
        width, dim = ctx.constant
        # One-hot mask over the reduced dimension, marking the argmax entries.
        mask = _one_hot(argmax, width, dim, dtype=grad_out.dtype)
        grad_out = torch.unsqueeze(grad_out, dim)
        return (grad_out * mask, None)
def _one_hot(input, num_classes, dim, dtype=torch.float32):
    """One-hot encodes `input`, placing the class axis at position `dim`.

    Mirrors jax.nn.one_hot() with an explicit axis: the output has the shape of
    `input` with an extra axis of size `num_classes` inserted at `dim`.
    """
    if not isinstance(input, torch.Tensor):
        raise ValueError("input must be a PyTorch tensor.")
    if num_classes <= 0:
        raise ValueError("Number of classes must be positive.")
    # F.one_hot puts the class axis last; move it to the requested position.
    one_hot_tensor = F.one_hot(input, num_classes).to(dtype)
    return torch.movedim(one_hot_tensor, -1, dim)
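# Usage sketch (illustrative, not part of the original file): _one_hot inserts
# the class axis at `dim`, which is how TropicMax.backward builds the mask that
# routes gradients to the argmax entries.
#
#   idx = torch.tensor([[0, 2], [1, 1]])
#   _one_hot(idx, 3, dim=1).shape    # torch.Size([2, 3, 2])
#   _one_hot(idx, 3, dim=-1)[0, 0]   # tensor([1., 0., 0.])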