Skip to content

Commit

Permalink
added particle filter, starting to learn from tracking rew
Browse files Browse the repository at this point in the history
  • Loading branch information
mttga committed Nov 7, 2024
1 parent 56da077 commit d098dff
Show file tree
Hide file tree
Showing 6 changed files with 600 additions and 166 deletions.
28 changes: 16 additions & 12 deletions baselines/MAPPO/config/mappo_homogenous_rnn_utracking.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"LR": 0.0005
"NUM_ENVS": 128
"NUM_ENVS": 512
"NUM_STEPS": 128
"TOTAL_TIMESTEPS": 1e9
"TOTAL_TIMESTEPS": 2e9
"FC_DIM_SIZE": 128
"GRU_HIDDEN_DIM": 128
"UPDATE_EPOCHS": 4
Expand All @@ -14,31 +14,35 @@
"VF_COEF": 0.5
"MAX_GRAD_NORM": 0.5
"ACTIVATION": "relu"
"ANNEAL_LR": True
"ANNEAL_LR": False

# ENV
"ENV_NAME": "utracking"
"ENV_KWARGS": {
"num_agents": 2,
"num_landmarks": 2,
"max_steps": 128,
"num_agents": 1,
"num_landmarks": 1,
"max_steps": 256,
"dt": 30,
"prop_range_landmark": [0, 5, 10], # possible propulsor velocities of the landmarks
"prop_range_landmark": [5, 10, 15], # possible propulsor velocities of the landmarks
"rew_type": "tracking",
"rew_pred_thr": 3,
"min_valid_distance": 5,
"min_init_distance": 30,
"max_init_distance": 100,
"max_init_distance": 150,
"pre_init_pos_len": 1000000,
"max_range_dist": 800,
"max_range_dist": 500,
"tracking_method": "pf",
"pf_num_particles": 500,
}

# EXP
"SEED": 0
"NUM_SEEDS": 1
"TUNE": False
"SAVE_PATH": "models"
"ALG_NAME": "mappo_rnn_crashing_penalty"
"ANIMATION_LOG_INTERVAL": 0.25 # fraction of the total update steps between animations. Animating slows down training and uses more memory
"ANIMATION_MAX_STEPS": 128 # should match the env's max_steps
"ALG_NAME": "mappo_rnn_pf_minimize_tracking_error"
"ANIMATION_LOG_INTERVAL": 0.1 # fraction of the total update steps between animations. Animating slows down training and uses more memory
"ANIMATION_MAX_STEPS": 256 # should match the env's max_steps

# WANDB
"WANDB_MODE": "online"
Expand Down
4 changes: 2 additions & 2 deletions baselines/MAPPO/mappo_rnn_utracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def _env_step(runner_state, unused):
batchify(avail_actions, env.agents, config["NUM_ACTORS"])
)
obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"])
print(obs_batch.shape)
print("obs shape:", obs_batch.shape)
ac_in = (
obs_batch[np.newaxis, :],
last_done[np.newaxis, :],
Expand Down Expand Up @@ -643,7 +643,7 @@ def env_step(carry, _):
# TODO: check which dimension is squeezed, gives problem with 1 agent
new_last_done = batchify(
done, env.agents, env.num_agents
).squeeze()
).squeeze(-1)

return (
rng,
Expand Down
39 changes: 39 additions & 0 deletions baselines/QLearning/config/alg/pqn_vdn_rnn_utracking.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# pqn_vdn_ff is suggested for MPE (faster and good enough), but pqn_vdn_rnn is also available
"TOTAL_TIMESTEPS": 1e7
"NUM_ENVS": 128
"MEMORY_WINDOW": 4
"NUM_STEPS": 128
"HIDDEN_SIZE": 128
"NUM_LAYERS": 2
"NORM_INPUT": False
"NORM_TYPE": "layer_norm"
"EPS_START": 1.0
"EPS_FINISH": 0.01
"EPS_DECAY": 0.1
"MAX_GRAD_NORM": 1
"NUM_MINIBATCHES": 8
"NUM_EPOCHS": 4
"LR": 0.00025
"LR_LINEAR_DECAY": True
"GAMMA": 0.99
"LAMBDA": 0.85
#"REW_SCALE": 10. # scale the reward to the original scale of SMAC

# ENV
"ENV_NAME": "utracking"
"ENV_KWARGS": {
"num_agents": 3,
"num_landmarks": 3,
"prop_range_landmark": [0],
"dt": 30,
"min_init_distance": 30,
"max_init_distance": 100,
}

# evaluate
"TEST_DURING_TRAINING": True
"TEST_INTERVAL": 0.05 # as a fraction of updates, i.e. log every 5% of training process
"TEST_NUM_STEPS": 128
"TEST_NUM_ENVS": 512 # number of episodes to average over, can affect performance

"ALG_NAME": "pqn_vdn_rnn_3v3_nomoving" # if you want to change the name of the algo in the metrics
1 change: 1 addition & 0 deletions jaxmarl/environments/hanabi/decks_test.json

Large diffs are not rendered by default.

218 changes: 218 additions & 0 deletions jaxmarl/environments/utracking/particle_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
import jax
from jax import numpy as jnp
from flax import struct
import chex
import numpy as np
from functools import partial


@struct.dataclass
class OneParticleState:
    """Kinematic state carried by a particle of the filter.

    The filter stores a whole population in one instance by giving every
    field a leading particle axis (it builds them with ``jax.vmap``).
    Field order is part of the pytree layout, so it must not be changed.
    """

    # Planar position.
    x: float
    y: float
    # Heading angle in radians.
    theta: float
    # Velocity components along x and y.
    vel_x: float
    vel_y: float


@struct.dataclass
class ParticlesState:
    """Complete particle-filter state: the particle population and its weights."""

    # Batched particles: every field holds one entry per particle.
    particles: OneParticleState
    # One weight per particle; the filter does not keep them normalised.
    weights: chex.Array


class ParticleFilter:
    """Range-only particle filter written with JAX.

    Particles model the object whose range is being measured: each one
    carries a planar position, a heading and a velocity. The instance only
    stores hyper-parameters; the particle population itself lives in a
    ``ParticlesState`` that every method receives and returns.

    All methods except ``__init__`` are jitted with ``self`` as a static
    argument, so the hyper-parameters are baked into the compiled code:
    changing an attribute after a method has been traced has no effect on
    the cached computation.
    """

    def __init__(
        self,
        num_particles,
        std_range=10,  # m (standard deviation error of the range measurements)
        mu_init_vel=0.1,  # m/s
        std_init_vel=0.1,  # m/s
        turn_noise=0.1,  # rad
        vel_noise=0.05,  # m/s
        min_weight=0.01,
    ):
        """
        Stores the hyper-parameters of the filter.
        - num_particles: size of the particle population
        - std_range: standard deviation of the range measurements (m)
        - mu_init_vel: mean of the initial particle speed (m/s)
        - std_init_vel: standard deviation of the initial particle speed (m/s)
        - turn_noise: half-width of the uniform heading noise per step (rad)
        - vel_noise: half-width of the uniform speed noise per step (m/s)
        - min_weight: stored for callers; not used by any method of this class
        """
        self.num_particles = num_particles
        self.std_range = std_range
        self.mu_init_vel = mu_init_vel
        self.std_init_vel = std_init_vel
        self.turn_noise = turn_noise
        self.vel_noise = vel_noise
        self.min_weight = min_weight

    @partial(jax.jit, static_argnums=0)
    def reset(self, key, position, range_obs):
        """
        Resets particles from a single observation.
        - key: rng key
        - position: position of the observer (x, y)
        - range_obs: range measured by the observer

        Particles are spread on a ring centred on the observer whose radius
        is uniform in [range_obs - std_range, range_obs + std_range].
        Returns a ParticlesState with all weights equal to one.
        """

        def init_particle(rng):
            # Randomly sample the initial position and velocity in the range around observer
            rng_a, rng_r, rng_v, rng_o = jax.random.split(rng, 4)
            angle = jax.random.uniform(rng_a, minval=0.0, maxval=2 * jnp.pi)
            # Symmetric band around the measured range, same `u * n * 2 - n`
            # form as the noise terms of update_particles. Without the `* 2`
            # every particle started at or inside the measured range.
            r = (
                jax.random.uniform(rng_r) * self.std_range * 2
                - self.std_range
                + range_obs
            )
            vel = jax.random.normal(rng_v) * self.std_init_vel + self.mu_init_vel
            orientation = jax.random.uniform(rng_o, minval=0, maxval=2 * jnp.pi)
            return OneParticleState(
                x=position[0] + r * jnp.cos(angle),
                y=position[1] + r * jnp.sin(angle),
                theta=orientation,
                vel_x=vel * jnp.cos(orientation),
                vel_y=vel * jnp.sin(orientation),
            )

        particles = jax.vmap(init_particle)(jax.random.split(key, self.num_particles))
        weights = jnp.ones(self.num_particles)

        return ParticlesState(particles=particles, weights=weights)

    @partial(jax.jit, static_argnums=0)
    def step_and_predict(self, rng, state, pos, obs, mask):
        """
        Step of the particle filter.
        - rng: rng key
        - state: ParticlesState
        - pos: positions of the observers (num_observers, x, y)
        - obs: observations (num_observers, range)
        - mask: mask for the observations (num_observers,)

        Returns the new ParticlesState and the estimated position (x, y).
        """

        key_update, key_resample = jax.random.split(rng, 2)

        # Update particles
        state = self.update_particles(key_update, state)

        # Update weights
        state = self.update_weights(state, pos, obs, mask)

        # Reinit when the weights cannot be normalised: NaN/inf entries, or a
        # total of zero (every likelihood underflowed), which would make the
        # resampling and the estimate divide by zero.
        weights = state.weights
        reinit_cond = jnp.logical_or(
            jnp.logical_not(jnp.isfinite(weights).all()),
            weights.sum() <= 0.0,
        )

        # Reset from the first observation that is not masked out; argmax
        # gives 0 when no observation is valid, i.e. the first observation.
        first_valid = jnp.argmax(jnp.asarray(mask).astype(bool))

        # Only one branch contributes to the result, so both may use
        # key_resample; key_update was already consumed by update_particles
        # and must not be reused.
        state = jax.lax.cond(
            reinit_cond,
            lambda _: self.reset(key_resample, pos[first_valid], obs[first_valid]),
            lambda _: self.resample(key_resample, state),
            operand=None,
        )

        # Estimate position
        pos_est = self.estimate_pos(state)

        return state, pos_est

    @partial(jax.jit, static_argnums=0)
    def update_particles(self, key, state, dt=30.0):
        """
        Updates the particles with a simple model.
        - key: rng key
        - state: ParticlesState
        - dt: time step in seconds

        Each particle keeps its current heading and speed, both perturbed by
        uniform noise, and moves forward for dt. Weights are left untouched.
        """

        def update_particle(rng, particle):
            # Update particle position and velocity with noise and a simple model
            rng_t, rng_v = jax.random.split(rng, 2)
            # The heading is recovered from the velocity vector
            turn = jnp.arctan2(particle.vel_y, particle.vel_x)
            orientation = (
                turn + jax.random.uniform(rng_t) * self.turn_noise * 2 - self.turn_noise
            )
            velocity = jnp.sqrt(particle.vel_x**2 + particle.vel_y**2)
            # The speed cannot become negative
            velocity = (
                velocity
                + jax.random.uniform(rng_v) * self.vel_noise * 2
                - self.vel_noise
            ).clip(0)
            forward = velocity * dt
            particle = OneParticleState(
                x=particle.x + jnp.cos(orientation) * forward,
                y=particle.y + jnp.sin(orientation) * forward,
                theta=orientation,
                vel_x=jnp.cos(orientation) * velocity,
                vel_y=jnp.sin(orientation) * velocity,
            )
            return particle

        particles = jax.vmap(update_particle)(
            jax.random.split(key, self.num_particles), state.particles
        )
        return state.replace(particles=particles)

    @partial(jax.jit, static_argnums=0)
    def update_weights(self, state, pos, obs, mask):
        """
        Updates the weights of the particles based on the observations.
        - state: ParticlesState
        - pos: positions of the observers (num_observers, x, y)
        - obs: observations (num_observers, range)
        - mask: mask for the observations (num_observers,)

        The previous weights are replaced (not multiplied) by the product of
        the Gaussian likelihoods of the unmasked observations.
        """

        def gaussian_pdf(x, mu, sigma):
            return jnp.exp(-((mu - x) ** 2) / (sigma**2) / 2.0) / jnp.sqrt(
                2.0 * jnp.pi * (sigma**2)
            )

        def get_prob(single_particle, single_pos, single_obs):
            # Compute the weight of the particle based on one observation
            dist = jnp.sqrt(
                (single_particle.x - single_pos[0]) ** 2
                + (single_particle.y - single_pos[1]) ** 2
            )
            return gaussian_pdf(dist, single_obs, self.std_range)

        def get_probs(particle):
            # Compute the weight of the particle based on all the observations
            return jax.vmap(get_prob, in_axes=(None, 0, 0))(
                particle, pos, obs
            )  # (num_observers,)

        probs = jax.vmap(get_probs)(state.particles)  # (num_particles, num_observers)
        # Masked-out observations contribute a neutral factor of 1
        probs = jnp.where(
            mask[jnp.newaxis], probs, 1.0
        )  # (num_particles, num_observers)
        weights = probs.prod(axis=1)  # (num_particles,)

        return state.replace(weights=weights)

    @partial(jax.jit, static_argnums=0)
    def resample(self, key, state):
        """
        Resampling of particles based on weights using systematic resampling.
        - key: rng key
        - state: ParticlesState

        Returns a ParticlesState whose particles are drawn proportionally to
        the weights and whose weights are all one: the information of the old
        weights is now encoded in how often each particle was duplicated.
        """

        # Normalize weights
        weights = state.weights / state.weights.sum()

        cumulative_sum = jnp.cumsum(weights)

        # Create evenly spaced positions with a random start point
        positions = (
            jax.random.uniform(key) + jnp.arange(self.num_particles)
        ) / self.num_particles

        # Use searchsorted to find the index for each position
        indexes = jnp.searchsorted(cumulative_sum, positions, side="right")
        # Rounding can leave the last cumulative value slightly below 1, in
        # which case searchsorted returns num_particles: keep indexes valid.
        indexes = jnp.clip(indexes, 0, self.num_particles - 1)

        # Gather resampled particles based on the computed indexes
        resampled_particles = jax.tree_util.tree_map(
            lambda x: x[indexes], state.particles
        )

        # The old weights are ordered like the old particles, so keeping them
        # would pair every resampled particle with an unrelated weight.
        return state.replace(
            particles=resampled_particles,
            weights=jnp.ones_like(state.weights),
        )

    @partial(jax.jit, static_argnums=0)
    def estimate_pos(self, state):
        """
        Estimates the position of the tracked target as the weighted mean of
        the particle positions.
        - state: ParticlesState

        Returns an array (x, y). If the weights are not finite or sum to
        zero, the unweighted mean is returned instead of NaN.
        """
        weights = state.weights
        usable = jnp.logical_and(jnp.isfinite(weights).all(), weights.sum() > 0.0)
        weights = jnp.where(usable, weights, jnp.ones_like(weights))
        total = weights.sum()

        x_est = (state.particles.x * weights).sum() / total
        y_est = (state.particles.y * weights).sum() / total
        return jnp.array([x_est, y_est])
Loading

0 comments on commit d098dff

Please sign in to comment.