This project implements a Deep Q-Network (DQN) to solve the classic CartPole control problem. The agent learns to balance a pole attached to a cart by applying horizontal forces to the cart. The implementation includes visualization tools, model persistence, and interactive gameplay modes.
The CartPole system consists of a cart that can move horizontally and a pole that can rotate around a pivot point on the cart. The system's state is described by four variables:
- $x$: Cart position
- $\dot{x}$: Cart velocity
- $\theta$: Pole angle
- $\dot{\theta}$: Pole angular velocity
The equations of motion for the system are:
- $g$: Gravity constant
- $F$: Applied force
- $m$: Pole mass
- $M$: Cart mass
- $l$: Pole length
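As a concrete reference for how these quantities interact, here is a minimal sketch of one widely used frictionless cart-pole formulation (the form used in Gymnasium's CartPole environment), integrated with an explicit Euler step. It is not necessarily the formulation this project uses: it assumes a uniform pole with $l$ measured from the pivot to the pole's center of mass, and the constants and the `step_dynamics` name are illustrative.

```python
import math

# Illustrative constants (assumptions, not taken from this project).
GRAVITY = 9.8            # g
CART_MASS = 1.0          # M
POLE_MASS = 0.1          # m
POLE_HALF_LENGTH = 0.5   # l, pivot to the pole's center of mass
DT = 0.02                # integration time step in seconds


def step_dynamics(state, force):
    """Advance (x, x_dot, theta, theta_dot) by one explicit Euler step."""
    x, x_dot, theta, theta_dot = state
    total_mass = CART_MASS + POLE_MASS
    sin_t, cos_t = math.sin(theta), math.cos(theta)

    # Shared term: applied force plus the centrifugal contribution of the pole.
    temp = (force + POLE_MASS * POLE_HALF_LENGTH * theta_dot ** 2 * sin_t) / total_mass
    theta_acc = (GRAVITY * sin_t - cos_t * temp) / (
        POLE_HALF_LENGTH * (4.0 / 3.0 - POLE_MASS * cos_t ** 2 / total_mass)
    )
    x_acc = temp - POLE_MASS * POLE_HALF_LENGTH * theta_acc * cos_t / total_mass

    return (
        x + DT * x_dot,
        x_dot + DT * x_acc,
        theta + DT * theta_dot,
        theta_dot + DT * theta_acc,
    )
```

With the pole exactly upright and no force applied, the state stays at rest; pushing the cart to the right accelerates the cart rightward and tips the pole the opposite way.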
The DQN uses a neural network to approximate the Q-function:
Network structure:
Input Layer (4) → Hidden Layer (64) → ReLU → Hidden Layer (64) → ReLU → Output Layer (2)
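The network maps a 4-dimensional state to one Q-value per action. Below is a dependency-free sketch of that forward pass, written in plain Python so it runs anywhere. The fully connected layers, the uniform weight initialization, and the names `init_params` and `q_values` are assumptions for illustration; the project presumably builds the network with a deep-learning framework.

```python
import random

LAYER_SIZES = [4, 64, 64, 2]  # input, hidden, hidden, output


def init_params(sizes, seed=0):
    """Create one (weights, biases) pair per layer with small random weights."""
    rng = random.Random(seed)
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        bound = 1.0 / n_in ** 0.5
        weights = [[rng.uniform(-bound, bound) for _ in range(n_in)] for _ in range(n_out)]
        biases = [0.0] * n_out
        params.append((weights, biases))
    return params


def q_values(params, state):
    """Forward pass: ReLU after each hidden layer, linear output layer."""
    activation = list(state)
    for index, (weights, biases) in enumerate(params):
        activation = [
            sum(w * a for w, a in zip(row, activation)) + b
            for row, b in zip(weights, biases)
        ]
        if index < len(params) - 1:  # no ReLU on the output layer
            activation = [max(0.0, v) for v in activation]
    return activation


params = init_params(LAYER_SIZES)
q = q_values(params, [0.0, 0.1, -0.05, 0.2])
greedy_action = q.index(max(q))  # index of the action with the highest Q-value
```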
├── # Interactive game environment
├── # DQN training implementation
├── # Training visualization tools
├── models/ # Saved model checkpoints
└── logs/ # Training logs and metrics
- Normalized state vector: $[x/W, \dot{x}/5, \theta/(\pi/2), \dot{\theta}/2]$
- Where $W$ is the screen width
- Binary action space: {left force (-0.2), right force (0.2)}
- +1 for each timestep the pole remains upright
- 0 on episode termination
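The state scaling, action mapping, and reward described above can be sketched as follows. The function and constant names are hypothetical, and treating action index 0 as "left" and 1 as "right" is an assumption about ordering.

```python
import math

ACTION_FORCES = (-0.2, 0.2)  # assumed ordering: index 0 pushes left, index 1 pushes right


def normalize_state(x, x_dot, theta, theta_dot, screen_width):
    """Scale the raw state variables with the divisors listed above."""
    return [
        x / screen_width,
        x_dot / 5.0,
        theta / (math.pi / 2),
        theta_dot / 2.0,
    ]


def reward_for_step(terminated):
    """+1 while the pole is upright, 0 on the terminating step."""
    return 0.0 if terminated else 1.0
```

For example, `normalize_state(400.0, 2.5, math.pi / 4, 1.0, 800.0)` maps every component to `0.5`, which keeps the network inputs on a comparable scale.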
MEMORY_SIZE = 100000 # Experience replay buffer size
BATCH_SIZE = 64 # Training batch size
GAMMA = 0.99 # Discount factor
EPSILON_START = 1.0 # Initial exploration rate
EPSILON_END = 0.01 # Final exploration rate
EPSILON_DECAY = 0.995 # Exploration decay rate
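To get a feel for the exploration settings, the following sketch computes how many decay steps it takes for the exploration rate to fall from its initial value to its floor. It assumes epsilon is multiplied by `EPSILON_DECAY` once per episode, which the settings themselves do not state; if the decay is applied per training step instead, the same count applies in steps.

```python
import math

EPSILON_START = 1.0
EPSILON_END = 0.01
EPSILON_DECAY = 0.995

# Smallest n with EPSILON_START * EPSILON_DECAY**n <= EPSILON_END.
episodes_to_floor = math.ceil(
    math.log(EPSILON_END / EPSILON_START) / math.log(EPSILON_DECAY)
)
print(episodes_to_floor)  # 919
```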
poetry install
This will:
- Initialize the DQN agent
- Train for specified episodes
- Save model checkpoints and logs
- Switch between AI and human control with 'M' key
- Use arrow keys for manual control
- Watch trained agent perform
Generates plots for:
- Training rewards
- Episode lengths
- Learning curves
- Q-value distributions
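Raw per-episode rewards are usually noisy, so learning curves are often drawn from a smoothed series. A minimal sketch of a trailing moving average that could feed such a plot; the `moving_average` helper and the window size are assumptions, not part of the project's visualization tools.

```python
def moving_average(values, window=50):
    """Trailing mean over up to `window` most recent values."""
    smoothed = []
    running_sum = 0.0
    for i, value in enumerate(values):
        running_sum += value
        if i >= window:
            running_sum -= values[i - window]  # drop the value that left the window
        smoothed.append(running_sum / min(i + 1, window))
    return smoothed
```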
The agent typically achieves:
- Convergence within 500-1000 episodes
- Average episode length >200 steps after training
- Stable pole balancing for extended periods
Uses two networks to reduce overestimation:
- Policy network: Action selection
- Target network: Value estimation
Update rule:
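next_action = policy_net(next_state).argmax()  # policy network selects the action
target = reward + GAMMA * (1 - done) * target_net(next_state)[next_action]  # target network evaluates it
loss = MSE(policy_net(state)[action], target)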
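The effect of splitting action selection from value estimation can be seen with toy numbers. The sketch below compares a target built from the maximum of the target network's estimates with one where the policy network chooses the action and the target network scores it. The Q-values are made up for illustration, and both function names are hypothetical.

```python
GAMMA = 0.99


def max_target(reward, done, target_q_next):
    """Target that takes the maximum over the target network's estimates."""
    return reward if done else reward + GAMMA * max(target_q_next)


def double_target(reward, done, policy_q_next, target_q_next):
    """Policy network picks the action; target network supplies its value."""
    if done:
        return reward  # no bootstrapping past a terminal state
    best_action = policy_q_next.index(max(policy_q_next))
    return reward + GAMMA * target_q_next[best_action]


# Toy Q-values for the next state (made-up numbers for illustration).
policy_q_next = [1.0, 2.0]
target_q_next = [1.5, 0.5]

print(max_target(1.0, False, target_q_next))                     # 1 + 0.99 * 1.5
print(double_target(1.0, False, policy_q_next, target_q_next))   # 1 + 0.99 * 0.5
```

Because the two networks disagree about which action is best, the second form does not automatically pick up the largest (and possibly overestimated) value.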
Stores transitions in a replay buffer for later sampling:
self.memory.append((state, action, reward, next_state, done))
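A minimal sketch of a buffer built around that `append` call, assuming a bounded `deque` and uniform random sampling. The `ReplayBuffer` class and its method names are hypothetical; only `self.memory`, `MEMORY_SIZE`, and `BATCH_SIZE` come from this document.

```python
import random
from collections import deque

MEMORY_SIZE = 100000
BATCH_SIZE = 64


class ReplayBuffer:
    """Fixed-size FIFO store of (state, action, reward, next_state, done)."""

    def __init__(self, capacity=MEMORY_SIZE):
        self.memory = deque(maxlen=capacity)  # oldest transitions drop off automatically

    def push(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def sample(self, batch_size=BATCH_SIZE):
        """Draw a uniform random batch and regroup it field by field."""
        batch = random.sample(self.memory, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.memory)
```

Training would typically wait until `len(buffer) >= BATCH_SIZE` before sampling, since `random.sample` raises `ValueError` when asked for more items than the buffer holds.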
Epsilon-greedy with decay:
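A minimal sketch of this exploration rule, assuming multiplicative decay clipped at `EPSILON_END`. The helper names `select_action` and `decay_epsilon` are hypothetical, and decaying once per episode is an assumption.

```python
import random

EPSILON_START = 1.0
EPSILON_END = 0.01
EPSILON_DECAY = 0.995


def select_action(q_values, epsilon, rng=random):
    """Random action with probability epsilon, otherwise the greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return q_values.index(max(q_values))


def decay_epsilon(epsilon):
    """Multiplicative decay, clipped at the floor."""
    return max(EPSILON_END, epsilon * EPSILON_DECAY)


epsilon = EPSILON_START
for episode in range(3):
    action = select_action([0.1, 0.7], epsilon)  # placeholder Q-values
    epsilon = decay_epsilon(epsilon)
```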
- Prioritized Experience Replay
- Dueling DQN architecture
- Noisy Networks for exploration
- Multi-step returns
- Experiment tracking with wandb