-
Notifications
You must be signed in to change notification settings - Fork 149
/
Copy pathCellRule.swift
127 lines (111 loc) · 4.52 KB
/
CellRule.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import TensorFlow
/// One update step of a neural cellular automaton: each cell perceives its 3x3
/// neighborhood through fixed Sobel/identity filters, computes a state delta with
/// two 1x1 convolutions, and applies it stochastically, masked by cell "aliveness".
struct CellRule: Layer {
  /// Fixed (non-trainable) depthwise filter bank: horizontal Sobel, vertical Sobel,
  /// and identity kernels, replicated across every state channel.
  @noDerivative var perceptionFilter: Tensor<Float>
  /// Per-step probability that a given cell applies its computed update.
  @noDerivative let fireRate: Float
  var conv1: Conv2D<Float>
  var conv2: Conv2D<Float>

  /// - Parameters:
  ///   - stateChannels: number of channels in the cell state tensor.
  ///   - fireRate: probability in [0, 1] that a cell fires on a given step.
  ///   - useBias: whether the second (output) convolution carries a bias term.
  init(stateChannels: Int, fireRate: Float, useBias: Bool) {
    self.fireRate = fireRate
    // 3x3 horizontal-gradient (Sobel) kernel, scaled by 1/8 so responses stay small.
    let horizontalSobelKernel =
      Tensor<Float>(
        shape: [3, 3, 1, 1], scalars: [-1.0, 0.0, 1.0, -2.0, 0.0, 2.0, -1.0, 0.0, 1.0]) / 8.0
    let horizontalSobelFilter = horizontalSobelKernel.broadcasted(to: [3, 3, stateChannels, 1])
    // 3x3 vertical-gradient (Sobel) kernel, also scaled by 1/8.
    let verticalSobelKernel =
      Tensor<Float>(
        shape: [3, 3, 1, 1], scalars: [-1.0, -2.0, -1.0, 0.0, 0.0, 0.0, 1.0, 2.0, 1.0]) / 8.0
    let verticalSobelFilter = verticalSobelKernel.broadcasted(to: [3, 3, stateChannels, 1])
    // Identity kernel: passes each cell's own state through unchanged.
    let identityKernel = Tensor<Float>(
      shape: [3, 3, 1, 1], scalars: [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0])
    let identityFilter = identityKernel.broadcasted(to: [3, 3, stateChannels, 1])
    // Stack the three filters along the channel-multiplier axis, so the depthwise
    // convolution in `perceive` yields stateChannels * 3 perception channels.
    perceptionFilter = Tensor(
      concatenating: [horizontalSobelFilter, verticalSobelFilter, identityFilter], alongAxis: 3)
    conv1 = Conv2D<Float>(filterShape: (1, 1, stateChannels * 3, 128))
    // Zero-initialized output conv: the automaton starts as the identity ("do nothing")
    // update, per the Growing Neural Cellular Automata training recipe.
    conv2 = Conv2D<Float>(
      filterShape: (1, 1, 128, stateChannels), useBias: useBias, filterInitializer: zeros())
  }

  /// Returns a 1-channel 0/1 mask marking cells that are "alive": a cell is alive
  /// when any cell in its 3x3 neighborhood has alpha (channel 3) greater than 0.1.
  /// The mask is excluded from differentiation via `withoutDerivative`.
  @differentiable
  func livingMask(_ input: Tensor<Float>) -> Tensor<Float> {
    // Alpha channel: assumes NHWC layout with alpha at channel index 3.
    let alphaChannel = input.slice(
      lowerBounds: [0, 0, 0, 3], sizes: [input.shape[0], input.shape[1], input.shape[2], 1])
    // 3x3 max pool (stride 1, same padding) = neighborhood maximum of alpha.
    let localMaximum =
      maxPool2D(alphaChannel, filterSize: (1, 3, 3, 1), strides: (1, 1, 1, 1), padding: .same)
    return withoutDerivative(at: input) { _ in localMaximum.mask { $0 .> 0.1 } }
  }

  /// Expands the state into perception features by depthwise-convolving with the
  /// fixed Sobel/identity filter bank.
  @differentiable
  func perceive(_ input: Tensor<Float>) -> Tensor<Float> {
    return depthwiseConv2D(
      input, filter: perceptionFilter, strides: (1, 1, 1, 1), padding: .same)
  }

  /// Applies one automaton step: compute a state delta, apply it to a random
  /// subset of cells (rate `fireRate`), and zero out cells that are dead both
  /// before and after the update.
  @differentiable
  func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
    let perception = perceive(input)
    let dx = conv2(relu(conv1(perception)))
    // Per-cell uniform random draw used to decide which cells fire this step.
    let updateDistribution = Tensor<Float>(
      randomUniform: [input.shape[0], input.shape[1], input.shape[2], 1], on: input.device)
    // Stochastic 0/1 fire mask; excluded from the gradient computation.
    let updateMask = withoutDerivative(at: input) { _ in
      updateDistribution.mask { $0 .< fireRate }
    }
    let updatedState = input + (dx * updateMask)
    // Keep only cells alive both before and after the update (pre/post masking).
    let combinedLivingMask = livingMask(input) * livingMask(updatedState)
    return updatedState * combinedLivingMask
  }
}
/// Returns a copy of `gradient` with every tensor component rescaled to unit
/// L2 norm (an epsilon in the denominator guards against division by zero).
func normalizeGradient(_ gradient: CellRule.TangentVector) -> CellRule.TangentVector {
  var normalized = gradient
  for keyPath in gradient.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
    let component = gradient[keyPath: keyPath]
    let l2Norm = sqrt(component.squared().sum())
    normalized[keyPath: keyPath] = component / (l2Norm + 1e-8)
  }
  return normalized
}
extension Tensor where Scalar: Numeric {
  /// The first four channels of the last axis — presumably RGBA color components;
  /// verify against callers. Accepts rank-3 (HWC) or rank-4 (NHWC) tensors.
  @differentiable(where Scalar: TensorFlowFloatingPoint)
  var colorComponents: Tensor {
    precondition(self.rank == 3 || self.rank == 4)
    if self.rank == 3 {
      return self.slice(
        lowerBounds: [0, 0, 0], sizes: [self.shape[0], self.shape[1], 4])
    } else {
      return self.slice(
        lowerBounds: [0, 0, 0, 0], sizes: [self.shape[0], self.shape[1], self.shape[2], 4])
    }
  }

  /// Returns a tensor of the same shape containing 1 where `condition` holds
  /// and 0 elsewhere.
  func mask(condition: (Tensor) -> Tensor<Bool>) -> Tensor {
    let satisfied = condition(self)
    return Tensor(zerosLike: self)
      .replacing(with: Tensor(onesLike: self), where: satisfied)
  }
}
/// An identity function whose sole purpose is to cut the backward trace into
/// smaller identical traces, improving X10 performance (see the custom
/// derivative registered below).
@inlinable
@differentiable
func clipBackwardsTrace(_ input: Tensor<Float>) -> Tensor<Float> {
  input
}
/// Custom derivative for `clipBackwardsTrace`: the forward pass is the identity,
/// and the pullback passes the cotangent through unchanged after issuing a
/// `LazyTensorBarrier()`, which forces the lazily-built backward trace to be cut
/// (materialized) at this point so individual X10 traces stay small.
@inlinable
@derivative(of: clipBackwardsTrace)
func _vjpClipBackwardsTrace(
  _ input: Tensor<Float>
) -> (value: Tensor<Float>, pullback: (Tensor<Float>) -> Tensor<Float>) {
  return (
    input,
    {
      // Barrier first, then forward the incoming gradient untouched.
      LazyTensorBarrier()
      return $0
    }
  )
}