-
Notifications
You must be signed in to change notification settings - Fork 149
/
Copy pathDecoder.swift
196 lines (176 loc) · 6.92 KB
/
Decoder.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation
import TensorFlow
// This whole struct should probably be merged into the PersonLab model struct once SwiftRT fixes
// the GPU->CPU copy issue and we no longer need to wrap the head outputs in CPUTensor.
struct PoseDecoder {
  /// Keypoint heatmap of the first batch item, indexed `[y, x, keypointIndex]`.
  let heatmap: CPUTensor<Float>
  /// Offsets of the first batch item, indexed `[y, x, channel]`. Channel `k` holds the y offset
  /// of keypoint `k`. Channel `k + KeypointIndex.allCases.count` holds its x offset.
  let offsets: CPUTensor<Float>
  /// Forward displacements of the first batch item, indexed `[y, x, channel]`. The first half
  /// of the channels hold y displacements. The second half hold the matching x displacements.
  /// Channels are selected per keypoint pair via `keypointPairToDisplacementIndexMap`.
  let displacementsFwd: CPUTensor<Float>
  /// Backward displacements, laid out the same way as `displacementsFwd`.
  let displacementsBwd: CPUTensor<Float>
  /// Decoding thresholds, radii, output stride and input image size.
  let config: Config

  /// Copies the first batch item of every head output into CPU tensors.
  ///
  /// - Parameters:
  ///   - results: Model head outputs. Only batch index 0 is read; batch size is currently
  ///     hardcoded to 1.
  ///   - config: Decoding configuration.
  init(for results: PersonlabHeadsResults, with config: Config) {
    self.heatmap = CPUTensor<Float>(results.heatmap[0])
    self.offsets = CPUTensor<Float>(results.offsets[0])
    self.displacementsFwd = CPUTensor<Float>(results.displacementsFwd[0])
    self.displacementsBwd = CPUTensor<Float>(results.displacementsBwd[0])
    self.config = config
  }

  /// Decodes all poses present in the heatmap.
  ///
  /// Candidate root keypoints (local heatmap maxima) are visited from highest to lowest score.
  /// A root is skipped when `Keypoint.isWithinRadiusOfCorrespondingKeypoints` reports it within
  /// `config.nmsRadius` of already-accepted poses. Otherwise a pose is grown from the root by
  /// following displacements. The pose is kept only if its score exceeds
  /// `config.poseScoreThreshold`.
  ///
  /// - Returns: The accepted poses, in the order they were accepted.
  func decode() -> [Pose] {
    var poses = [Pose]()
    // Iterate the sorted candidates directly. Repeatedly calling `removeFirst()` shifts every
    // remaining element on each call, which made this loop quadratic in the candidate count.
    for rootKeypoint in getSortedLocallyMaximumKeypoints() {
      if rootKeypoint.isWithinRadiusOfCorrespondingKeypoints(in: poses, radius: config.nmsRadius) {
        continue
      }
      var pose = Pose(resolution: config.inputImageSize)
      pose.add(rootKeypoint)
      // Recursively walk the keypoint tree from the root in both the forward and backward
      // directions, filling in each keypoint the pose is still missing.
      recursivellyAddNextKeypoint(after: rootKeypoint, into: &pose)
      if getPoseScore(for: pose, considering: poses) > config.poseScoreThreshold {
        poses.append(pose)
      }
    }
    return poses
  }

  /// Extends `pose` depth-first from `previousKeypoint`.
  ///
  /// The walk visits each neighbouring keypoint reported by
  /// `getNextKeypointIndexAndDirection` that the pose does not contain yet. For each one it
  /// follows the forward or backward displacements, depending on the edge direction.
  ///
  /// - Note: The misspelled name is kept so existing callers continue to compile.
  func recursivellyAddNextKeypoint(after previousKeypoint: Keypoint, into pose: inout Pose) {
    for (nextKeypointIndex, direction) in getNextKeypointIndexAndDirection(previousKeypoint.index)
    where pose.getKeypoint(nextKeypointIndex) == nil {
      let nextKeypoint = followDisplacement(
        from: previousKeypoint,
        to: nextKeypointIndex,
        using: direction == .fwd ? displacementsFwd : displacementsBwd
      )
      pose.add(nextKeypoint)
      recursivellyAddNextKeypoint(after: nextKeypoint, into: &pose)
    }
  }

  /// Predicts the keypoint `nextKeypointIndex` from `previousKeypoint`.
  ///
  /// The prediction first follows a displacement, then refines the heatmap cell it lands on
  /// with the offset head.
  ///
  /// - Parameters:
  ///   - previousKeypoint: An already-placed keypoint. Its coordinates are in input resolution,
  ///     i.e. heatmap cell × `config.outputStride`.
  ///   - nextKeypointIndex: The keypoint to predict. The pair
  ///     `{previousKeypoint.index, nextKeypointIndex}` must be present in
  ///     `keypointPairToDisplacementIndexMap`; a missing entry traps.
  ///   - displacements: Either `displacementsFwd` or `displacementsBwd`.
  /// - Returns: The predicted keypoint, scored with the heatmap value at the landing cell.
  func followDisplacement(
    from previousKeypoint: Keypoint, to nextKeypointIndex: KeypointIndex,
    using displacements: CPUTensor<Float>
  ) -> Keypoint {
    // The y channel for this keypoint pair; its x channel sits half the channels further on.
    let displacementKeypointIndexY = keypointPairToDisplacementIndexMap[
      Set([previousKeypoint.index, nextKeypointIndex])]!
    let displacementKeypointIndexX = displacementKeypointIndexY + displacements.shape[2] / 2

    let displacementYIndex = getUnstridedIndex(y: previousKeypoint.y)
    let displacementXIndex = getUnstridedIndex(x: previousKeypoint.x)
    let displacementY = displacements[
      displacementYIndex,
      displacementXIndex,
      displacementKeypointIndexY
    ]
    let displacementX = displacements[
      displacementYIndex,
      displacementXIndex,
      displacementKeypointIndexX
    ]

    let displacedY = getUnstridedIndex(y: previousKeypoint.y + displacementY)
    let displacedX = getUnstridedIndex(x: previousKeypoint.x + displacementX)
    let yOffset = offsets[
      displacedY,
      displacedX,
      nextKeypointIndex.rawValue
    ]
    let xOffset = offsets[
      displacedY,
      displacedX,
      nextKeypointIndex.rawValue + KeypointIndex.allCases.count
    ]

    // The offset was read at an exact heatmap cell, so it must be applied relative to that
    // cell. We therefore snap to the nearest cell, scale it back up by the output stride,
    // and then add the offset.
    let nextY = Float(displacedY * config.outputStride) + yOffset
    let nextX = Float(displacedX * config.outputStride) + xOffset

    return Keypoint(
      y: nextY,
      x: nextX,
      index: nextKeypointIndex,
      score: heatmap[
        displacedY, displacedX, nextKeypointIndex.rawValue
      ]
    )
  }

  /// Returns `true` when no cell in the square window of radius
  /// `config.keypointLocalMaximumRadius` around `(heatmapY, heatmapX)` has a strictly greater
  /// score for `keypointIndex`.
  ///
  /// The window is clipped to the heatmap bounds. Ties count as maxima.
  func scoreIsMaximumInLocalWindow(heatmapY: Int, heatmapX: Int, score: Float, keypointIndex: Int)
    -> Bool
  {
    let radius = config.keypointLocalMaximumRadius
    let yStart = max(heatmapY - radius, 0)
    let yEnd = min(heatmapY + radius, heatmap.shape[0] - 1)
    // The x bounds do not depend on the row, so compute them once.
    let xStart = max(heatmapX - radius, 0)
    let xEnd = min(heatmapX + radius, heatmap.shape[1] - 1)
    for windowY in yStart...yEnd {
      for windowX in xStart...xEnd {
        if heatmap[windowY, windowX, keypointIndex] > score {
          return false
        }
      }
    }
    return true
  }

  /// Maps an input-resolution y coordinate to the nearest heatmap row.
  ///
  /// The coordinate is divided by `config.outputStride`, rounded to the nearest row, and
  /// clamped to `0...(heatmap.shape[0] - 1)`.
  func getUnstridedIndex(y: Float) -> Int {
    let downScaled = y / Float(config.outputStride)
    let clamped = min(max(0, downScaled.rounded()), Float(heatmap.shape[0] - 1))
    return Int(clamped)
  }

  /// Maps an input-resolution x coordinate to the nearest heatmap column.
  ///
  /// The coordinate is divided by `config.outputStride`, rounded to the nearest column, and
  /// clamped to `0...(heatmap.shape[1] - 1)`.
  func getUnstridedIndex(x: Float) -> Int {
    let downScaled = x / Float(config.outputStride)
    let clamped = min(max(0, downScaled.rounded()), Float(heatmap.shape[1] - 1))
    return Int(clamped)
  }

  /// Collects every heatmap cell/keypoint whose score reaches `config.keypointScoreThreshold`
  /// and is a local maximum.
  ///
  /// Local maxima are determined by `scoreIsMaximumInLocalWindow`.
  ///
  /// - Returns: The candidate keypoints, sorted by descending score.
  func getSortedLocallyMaximumKeypoints() -> [Keypoint] {
    var sortedLocallyMaximumKeypoints = [Keypoint]()
    for heatmapY in 0..<heatmap.shape[0] {
      for heatmapX in 0..<heatmap.shape[1] {
        for keypointIndex in 0..<heatmap.shape[2] {
          let score = heatmap[heatmapY, heatmapX, keypointIndex]
          if score < config.keypointScoreThreshold { continue }
          if scoreIsMaximumInLocalWindow(
            heatmapY: heatmapY,
            heatmapX: heatmapX,
            score: score,
            keypointIndex: keypointIndex
          ) {
            sortedLocallyMaximumKeypoints.append(
              Keypoint(
                heatmapY: heatmapY,
                heatmapX: heatmapX,
                index: keypointIndex,
                score: score,
                offsets: offsets,
                outputStride: config.outputStride
              )
            )
          }
        }
      }
    }
    sortedLocallyMaximumKeypoints.sort { $0.score > $1.score }
    return sortedLocallyMaximumKeypoints
  }

  /// Scores `pose` against the already-accepted `poses`.
  ///
  /// Only keypoints that are not within `config.nmsRadius` of the corresponding keypoints in
  /// `poses` are counted. Their scores are summed and divided by the total number of keypoint
  /// types.
  func getPoseScore(for pose: Pose, considering poses: [Pose]) -> Float {
    var notOverlappedKeypointScoreAccumulator: Float = 0
    // Keypoints the tree walk never reached contribute 0. Previously they were force-unwrapped,
    // which trapped at runtime.
    for case let keypoint? in pose.keypoints
    where !keypoint.isWithinRadiusOfCorrespondingKeypoints(in: poses, radius: config.nmsRadius) {
      notOverlappedKeypointScoreAccumulator += keypoint.score
    }
    return notOverlappedKeypointScoreAccumulator / Float(KeypointIndex.allCases.count)
  }
}