optimizer.py
import torch


class SGD(object):
    """Implements stochastic gradient descent over (parameter, gradient) pairs."""

    def __init__(self, params, lr):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        self.lr = lr
        if isinstance(params, torch.Tensor):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of (parameter, gradient) pairs, but got " +
                            torch.typename(params))
        params = list(params)
        if len(params) == 0:
            raise ValueError("optimizer got an empty parameter list")
        self.params = params

    def step(self):
        """Performs a single optimization step: p <- p - lr * grad."""
        for (p, grad) in self.params:
            # Update the raw data in place, bypassing autograd bookkeeping.
            p.data -= self.lr * grad

    def zero_grad(self):
        """Zeroes the gradients of all tracked Tensors in place."""
        for (p, grad) in self.params:
            grad.zero_()
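
A minimal usage sketch, assuming the surrounding project supplies (parameter, gradient) pairs; the tensors below are hypothetical stand-ins for real parameters and computed gradients, not part of the file above:

# Hypothetical example of driving the optimizer by hand.
w = torch.empty(3, 2).normal_()     # a parameter tensor
g = torch.empty(3, 2).fill_(0.1)    # stand-in for its computed gradient

opt = SGD([(w, g)], lr=0.01)
opt.step()        # w <- w - 0.01 * g
opt.zero_grad()   # g is zeroed in place for the next backward pass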