This repository was archived by the owner on Feb 13, 2025. It is now read-only.

Implement Deep Q-Network #617

Merged: 36 commits (Aug 10, 2020)
Commits
5dc50eb
Copy code from Colab
seungjaeryanlee Jun 29, 2020
9f049af
Merge branch 'master' into dqn
seungjaeryanlee Jun 29, 2020
6268d58
Use .scalarized() to convert TF scalar to Swift
seungjaeryanlee Jun 30, 2020
51d1fad
Improve code clarity
seungjaeryanlee Jun 30, 2020
84df320
Save isDone as Tensor<Bool>
seungjaeryanlee Jun 30, 2020
2b99489
Save and use isDone for target calculation
seungjaeryanlee Jun 30, 2020
18e6294
Add commented parallelized training implementation
seungjaeryanlee Jun 30, 2020
36c1ddf
Save learning curve plot
seungjaeryanlee Jun 30, 2020
da57062
Use parallelized training with custom gatherNd
seungjaeryanlee Jul 1, 2020
2ec956c
Add minBufferSize parameter
seungjaeryanlee Jul 1, 2020
dab2a3f
Remove comments and refactor code
seungjaeryanlee Jul 1, 2020
0bc60ca
Fix bug where state was updated
seungjaeryanlee Jul 1, 2020
01074d9
Simplify code
seungjaeryanlee Jul 1, 2020
eca8a92
Save TD loss curve
seungjaeryanlee Jul 1, 2020
ae087dd
Purge uses of _Raw operations
seungjaeryanlee Jul 2, 2020
4acd6ce
Use Huber loss instead of MSE
seungjaeryanlee Jul 2, 2020
22aaf75
Simplify Tensor initialization
seungjaeryanlee Jul 2, 2020
24392f3
Set device explicitly on Tensor creation
seungjaeryanlee Jul 2, 2020
441ab35
Merge branch 'master' into dqn
seungjaeryanlee Aug 3, 2020
ccfa087
Add minBufferSize to Agent argument
seungjaeryanlee Aug 3, 2020
65de04e
Use soft target updates
seungjaeryanlee Aug 3, 2020
bcbb7e2
Fix bug where isDone was used wrong
seungjaeryanlee Aug 3, 2020
a203226
Fix bug where target net is initialized with soft update
seungjaeryanlee Aug 3, 2020
e757c0f
Follow hyperparameters in swift-rl
seungjaeryanlee Aug 3, 2020
d2be5bd
Run evaluation episode for every training episode
seungjaeryanlee Aug 3, 2020
6a118ab
Implement combined experience replay
seungjaeryanlee Aug 4, 2020
ce539e5
Implement double DQN
seungjaeryanlee Aug 4, 2020
cf7b96a
Add options to toggle CER and DDQN
seungjaeryanlee Aug 4, 2020
98b4647
Refactor code
seungjaeryanlee Aug 4, 2020
e00901a
Add updateTargetQNet to Agent class
seungjaeryanlee Aug 4, 2020
bca2614
Use TF-Agents hyperparameters
seungjaeryanlee Aug 4, 2020
45b880e
Changed ReplayBuffer to play better with GPU eager mode, restructured…
BradLarson Aug 5, 2020
356c989
Fix ReplayBuffer pass-by-value bug
seungjaeryanlee Aug 6, 2020
d774fad
Use epsilon decay for more consistent performance
seungjaeryanlee Aug 6, 2020
a10f201
Add documentation and improve names
seungjaeryanlee Aug 7, 2020
4aa9296
Document Agent and ReplayBuffer parameters
seungjaeryanlee Aug 7, 2020
170 changes: 170 additions & 0 deletions Gym/DQN/Agent.swift
@@ -0,0 +1,170 @@
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import PythonKit
import TensorFlow

// `np` (NumPy via PythonKit) is assumed to be defined elsewhere in this target,
// e.g. as `let np = Python.import("numpy")` in main.swift.

// Force unwrapping with `!` does not provide source location when unwrapping `nil`, so we instead
// make a utility function for debuggability.
extension Optional {
fileprivate func unwrapped(file: StaticString = #filePath, line: UInt = #line) -> Wrapped {
guard let unwrapped = self else {
fatalError("Value is nil", file: (file), line: line)
}
return unwrapped
}
}

/// A Deep Q-Network.
///
/// A Q-network is a neural network that receives the observation (state) as input and estimates
/// the action value (Q value) of each action. For more information, check Human-level control
/// through deep reinforcement learning (Mnih et al., 2015).
struct DeepQNetwork: Layer {
typealias Input = Tensor<Float>
typealias Output = Tensor<Float>

var l1, l2: Dense<Float>

init(observationSize: Int, hiddenSize: Int, actionCount: Int) {
l1 = Dense<Float>(inputSize: observationSize, outputSize: hiddenSize, activation: relu)
l2 = Dense<Float>(inputSize: hiddenSize, outputSize: actionCount, activation: identity)
}

@differentiable
func callAsFunction(_ input: Input) -> Output {
return input.sequenced(through: l1, l2)
}
}

/// Agent that uses the Deep Q-Network.
///
/// Deep Q-Network is an algorithm that trains a Q-network that estimates the action values of
/// each action given an observation (state). The Q-network is trained iteratively using the
/// Bellman equation. For more information, check Human-level control through deep reinforcement
/// learning (Mnih et al., 2015).
class DeepQNetworkAgent {
/// The Q-network used to estimate the action values.
var qNet: DeepQNetwork
/// The copy of the Q-network updated less frequently to stabilize the
/// training process.
var targetQNet: DeepQNetwork
/// The optimizer used to train the Q-network.
let optimizer: Adam<DeepQNetwork>
/// The replay buffer that stores experiences of the interactions between the
/// agent and the environment. The Q-network is trained from experiences
/// sampled from the replay buffer.
let replayBuffer: ReplayBuffer
/// The discount factor that measures how much weight to give to future
/// rewards when calculating the action value.
let discount: Float
/// The minimum replay buffer size before the training starts.
let minBufferSize: Int
/// If enabled, uses the Double DQN update equation instead of the original
/// DQN equation. This mitigates the overestimation problem of DQN. For more
/// information about Double DQN, check Deep Reinforcement Learning with
/// Double Q-learning (van Hasselt, Guez, and Silver, 2015).
let doubleDQN: Bool
/// The device on which the agent's tensors are created and computations are performed.
let device: Device

init(
qNet: DeepQNetwork,
targetQNet: DeepQNetwork,
optimizer: Adam<DeepQNetwork>,
replayBuffer: ReplayBuffer,
discount: Float,
minBufferSize: Int,
doubleDQN: Bool,
device: Device
) {
self.qNet = qNet
self.targetQNet = targetQNet
self.optimizer = optimizer
self.replayBuffer = replayBuffer
self.discount = discount
self.minBufferSize = minBufferSize
self.doubleDQN = doubleDQN
self.device = device

// Copy Q-network to Target Q-network before training
updateTargetQNet(tau: 1)
}

func getAction(state: Tensor<Float>, epsilon: Float) -> Tensor<Int32> {
// Epsilon-greedy exploration: with probability `epsilon`, act randomly. The random
// branch assumes two discrete actions (0 and 1), matching CartPole.
if Float(np.random.uniform()).unwrapped() < epsilon {
return Tensor<Int32>(numpy: np.array(np.random.randint(0, 2), dtype: np.int32))!
} else {
// Neural network input needs to be 2D
let tfState = Tensor<Float>(numpy: np.expand_dims(state.makeNumpyArray(), axis: 0))!
let qValues = qNet(tfState)[0]
return Tensor<Int32>(qValues[1].scalarized() > qValues[0].scalarized() ? 1 : 0, on: device)
}
}

func train(batchSize: Int) -> Float {
// Don't train if replay buffer is too small
if replayBuffer.count >= minBufferSize {
let (tfStateBatch, tfActionBatch, tfRewardBatch, tfNextStateBatch, tfIsDoneBatch) =
replayBuffer.sample(batchSize: batchSize)

let (loss, gradients) = valueWithGradient(at: qNet) { qNet -> Tensor<Float> in
// Compute prediction batch
let npActionBatch = tfActionBatch.makeNumpyArray()
let npFullIndices = np.stack(
[np.arange(batchSize, dtype: np.int32), npActionBatch], axis: 1)
let tfFullIndices = Tensor<Int32>(numpy: npFullIndices)!
let stateQValueBatch = qNet(tfStateBatch)
let predictionBatch = stateQValueBatch.dimensionGathering(atIndices: tfFullIndices)

// Compute target batch
let nextStateQValueBatch: Tensor<Float>
if self.doubleDQN == true {
// Double DQN
let npNextStateActionBatch = self.qNet(tfNextStateBatch).argmax(squeezingAxis: 1)
.makeNumpyArray()
let npNextStateFullIndices = np.stack(
[np.arange(batchSize, dtype: np.int32), npNextStateActionBatch], axis: 1)
let tfNextStateFullIndices = Tensor<Int32>(numpy: npNextStateFullIndices)!
nextStateQValueBatch = self.targetQNet(tfNextStateBatch).dimensionGathering(
atIndices: tfNextStateFullIndices)
} else {
// DQN
nextStateQValueBatch = self.targetQNet(tfNextStateBatch).max(squeezingAxes: 1)
}
let targetBatch: Tensor<Float> =
tfRewardBatch + self.discount * (1 - Tensor<Float>(tfIsDoneBatch)) * nextStateQValueBatch

return huberLoss(
predicted: predictionBatch,
expected: targetBatch,
delta: 1
)
}
optimizer.update(&qNet, along: gradients)

return loss.scalarized()
}
return 0
}

func updateTargetQNet(tau: Float) {
self.targetQNet.l1.weight =
tau * Tensor<Float>(self.qNet.l1.weight) + (1 - tau) * self.targetQNet.l1.weight
self.targetQNet.l1.bias =
tau * Tensor<Float>(self.qNet.l1.bias) + (1 - tau) * self.targetQNet.l1.bias
self.targetQNet.l2.weight =
tau * Tensor<Float>(self.qNet.l2.weight) + (1 - tau) * self.targetQNet.l2.weight
self.targetQNet.l2.bias =
tau * Tensor<Float>(self.qNet.l2.bias) + (1 - tau) * self.targetQNet.l2.bias
}
}
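For orientation, here is a minimal sketch (not part of this diff) of how these pieces fit together, assuming a CartPole-sized environment (4-dimensional observations, 2 actions); the hidden size, learning rate, and other hyperparameters below are illustrative placeholders rather than the PR's exact settings, and the actual driver lives in main.swift:

import TensorFlow

let device = Device.default

// Online and target networks; the agent's initializer immediately hard-copies the
// online weights into the target network (tau: 1).
let qNet = DeepQNetwork(observationSize: 4, hiddenSize: 100, actionCount: 2)
let targetQNet = DeepQNetwork(observationSize: 4, hiddenSize: 100, actionCount: 2)
let optimizer = Adam(for: qNet, learningRate: 0.001)
let replayBuffer = ReplayBuffer(capacity: 1000, combined: true)

let agent = DeepQNetworkAgent(
  qNet: qNet,
  targetQNet: targetQNet,
  optimizer: optimizer,
  replayBuffer: replayBuffer,
  discount: 0.99,
  minBufferSize: 32,
  doubleDQN: true,
  device: device
)

// Inside the environment loop (state, reward, nextState, isDone come from the Gym environment):
//   let action = agent.getAction(state: state, epsilon: epsilon)
//   replayBuffer.append(state: state, action: action, reward: reward,
//                       nextState: nextState, isDone: isDone)
//   let loss = agent.train(batchSize: 32)  // returns 0 until minBufferSize is reached
//   agent.updateTargetQNet(tau: 0.05)      // soft update toward the online network

Setting doubleDQN: true selects the Double DQN target in train(batchSize:), and combined: true on the buffer enables Combined Experience Replay; tau: 1 performs a hard copy while a small tau performs a soft target update.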
45 changes: 45 additions & 0 deletions Gym/DQN/Gathering.swift
@@ -0,0 +1,45 @@
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import TensorFlow

extension Tensor where Scalar: TensorFlowFloatingPoint {
@inlinable
@differentiable(wrt: self)
public func dimensionGathering<Index: TensorFlowIndex>(
atIndices indices: Tensor<Index>
) -> Tensor {
return _Raw.gatherNd(params: self, indices: indices)
}

/// Derivative of `_Raw.gatherNd`.
///
/// Ported from TensorFlow Python reference implementation:
/// https://github.com/tensorflow/tensorflow/blob/r2.2/tensorflow/python/ops/array_grad.py#L691-L701
@inlinable
@derivative(of: dimensionGathering)
func _vjpDimensionGathering<Index: TensorFlowIndex>(
atIndices indices: Tensor<Index>
) -> (value: Tensor, pullback: (Tensor) -> Tensor) {
let shapeTensor = Tensor<Index>(self.shapeTensor)
let value = _Raw.gatherNd(params: self, indices: indices)
return (
value,
{ v in
let dparams = _Raw.scatterNd(indices: indices, updates: v, shape: shapeTensor)
return dparams
}
)
}
}
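As a quick illustration (not part of the diff) of what dimensionGathering computes: given a batch of Q-values and one (row, action) index pair per row, it returns the Q-value of the chosen action in each row, which is how Agent.swift forms its prediction batch; the gradient flows back only to the gathered entries via scatterNd. The numbers below are illustrative.

import TensorFlow

// Q-values for a batch of 3 states with 2 actions each: [[1, 2], [3, 4], [5, 6]].
let qValues = Tensor<Float>(shape: [3, 2], scalars: [1, 2, 3, 4, 5, 6])
// One (batchIndex, actionIndex) pair per row, like the indices Agent.swift builds with np.stack:
// row 0 took action 1, row 1 took action 0, row 2 took action 1.
let indices = Tensor<Int32>(shape: [3, 2], scalars: [0, 1, 1, 0, 2, 1])
let chosen = qValues.dimensionGathering(atIndices: indices)
// chosen is [2.0, 3.0, 6.0]: the Q-value of the action taken in each row.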
105 changes: 105 additions & 0 deletions Gym/DQN/ReplayBuffer.swift
@@ -0,0 +1,105 @@
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import TensorFlow

/// Replay buffer to store the agent's experiences.
///
/// Vanilla Q-learning trains only on the latest experience. Deep Q-Network uses
/// a technique called "experience replay", where experiences are stored in
/// a replay buffer. By storing experiences, the agent can reuse them
/// and also train in batches. For more information, check Human-level control
/// through deep reinforcement learning (Mnih et al., 2015).
class ReplayBuffer {
/// The maximum size of the replay buffer. When the replay buffer is full,
/// new elements replace the oldest element in the replay buffer.
let capacity: Int
/// If enabled, uses Combined Experience Replay (CER) sampling instead of the
/// uniform random sampling in the original DQN paper. The original DQN samples
/// the batch uniformly at random from the replay buffer. CER always includes the
/// most recent element and samples the rest of the batch uniformly at random.
/// This makes the agent more robust to different replay buffer capacities.
/// For more information about Combined Experience Replay, check A Deeper Look
/// at Experience Replay (Zhang and Sutton, 2017).
let combined: Bool

/// The states that the agent observed.
@noDerivative var states: [Tensor<Float>] = []
/// The actions that the agent took.
@noDerivative var actions: [Tensor<Int32>] = []
/// The rewards that the agent received from the environment after taking
/// an action.
@noDerivative var rewards: [Tensor<Float>] = []
/// The next states that the agent received from the environment after taking
/// an action.
@noDerivative var nextStates: [Tensor<Float>] = []
/// The episode-terminal flag that the agent received after taking an action.
@noDerivative var isDones: [Tensor<Bool>] = []
/// The current size of the replay buffer.
var count: Int { return states.count }

init(capacity: Int, combined: Bool) {
self.capacity = capacity
self.combined = combined
}

func append(
state: Tensor<Float>,
action: Tensor<Int32>,
reward: Tensor<Float>,
nextState: Tensor<Float>,
isDone: Tensor<Bool>
) {
if count >= capacity {
// Erase the oldest transition (state, action, reward, next state, done) if the buffer is full
states.removeFirst()
actions.removeFirst()
rewards.removeFirst()
nextStates.removeFirst()
isDones.removeFirst()
}
states.append(state)
actions.append(action)
rewards.append(reward)
nextStates.append(nextState)
isDones.append(isDone)
}

func sample(batchSize: Int) -> (
stateBatch: Tensor<Float>,
actionBatch: Tensor<Int32>,
rewardBatch: Tensor<Float>,
nextStateBatch: Tensor<Float>,
isDoneBatch: Tensor<Bool>
) {
let indices: Tensor<Int32>
if self.combined == true {
// Combined Experience Replay
let sampledIndices = (0..<batchSize - 1).map { _ in Int32.random(in: 0..<Int32(count)) }
indices = Tensor<Int32>(shape: [batchSize], scalars: sampledIndices + [Int32(count) - 1])
} else {
// Vanilla Experience Replay
let sampledIndices = (0..<batchSize).map { _ in Int32.random(in: 0..<Int32(count)) }
indices = Tensor<Int32>(shape: [batchSize], scalars: sampledIndices)
}

let stateBatch = Tensor(stacking: states).gathering(atIndices: indices, alongAxis: 0)
let actionBatch = Tensor(stacking: actions).gathering(atIndices: indices, alongAxis: 0)
let rewardBatch = Tensor(stacking: rewards).gathering(atIndices: indices, alongAxis: 0)
let nextStateBatch = Tensor(stacking: nextStates).gathering(atIndices: indices, alongAxis: 0)
let isDoneBatch = Tensor(stacking: isDones).gathering(atIndices: indices, alongAxis: 0)

return (stateBatch, actionBatch, rewardBatch, nextStateBatch, isDoneBatch)
}
}
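And a small sketch of the buffer in isolation, assuming CartPole's 4-dimensional observations (the values and shapes are illustrative, not from the PR):

import TensorFlow

let buffer = ReplayBuffer(capacity: 100, combined: true)

// Store one transition; during training these come from stepping the environment.
buffer.append(
  state: Tensor<Float>([0.01, -0.02, 0.03, 0.04]),
  action: Tensor<Int32>(1),
  reward: Tensor<Float>(1),
  nextState: Tensor<Float>([0.02, -0.01, 0.02, 0.05]),
  isDone: Tensor<Bool>(false)
)

// Once enough transitions are stored, draw a training batch. With combined: true the
// newest transition is always included (Combined Experience Replay); the rest are
// sampled uniformly at random.
// let (states, actions, rewards, nextStates, isDones) = buffer.sample(batchSize: 32)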