Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jaxnav docs & improvements #112

Merged
merged 10 commits into from
Aug 29, 2024
Merged

Jaxnav docs & improvements #112

merged 10 commits into from
Aug 29, 2024

Conversation

amacrutherford
Copy link
Collaborator

No description provided.

@amacrutherford amacrutherford marked this pull request as draft August 27, 2024 21:10
@amacrutherford amacrutherford marked this pull request as ready for review August 28, 2024 14:39
Copy link
Contributor

@benellis3 benellis3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me -- left some typos and suggestions but feel free to ignore them.

The default map is square robots of width 0.5m moving within a world with grid based obstacled, with cells of size 1m x 1m. Map cell size can be varied to produce obstacles of higher fidelty or robot strucutre can be changed into any polygon or a circle.

We also include a map which uses polygon obstacles, but note we have not used this code is a while so there may well be issues with it.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: "is a while"


### Observation space
By default, each robot recieves 200 range readings from a 360-degree arc centered on their forward axis. These range readings have a max range of 6m but no minimum range and are discritised with a resultion of 0.05 m. Alongside these range readings, each robot recieves their current linear and angular velocities along with the direction to their goal. Their goal direction is given by a vector in polar form where the distance is either the max lidar range if the goal is beyond their "line of sight" or the actual distance if the goal is within their lidar range. There is no communication between agents.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"discritised" -> "discretised"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also "recieves" -> "receives"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the dyslexia do hit hard

map_collisions = self._check_map_collisions(new_pos, new_theta, agent_states.map_data)*(1-agent_states.done).astype(bool)
agent_collisions = self._check_agent_collisions(jnp.arange(agent_states.pos.shape[0]), new_pos, agent_states.done)*(1- agent_states.done).astype(bool)
map_collisions = jax.vmap(self._map_obj.check_agent_map_collision, in_axes=(0, 0, None))(new_pos, new_theta, agent_states.map_data)*(1-agent_states.done).astype(bool)
# map_collisions = self._check_map_collisions(new_pos, new_theta, agent_states.map_data)*(1-agent_states.done).astype(bool)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need this code?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nope :)

@@ -708,11 +708,6 @@ def make_jaxnav_singleton_collection(collection_id: str, **env_kwargs) -> Tuple[
"NarrowChicane2b",
"Chicane4",
],
"new": [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you remove these tests? Ideally tests should be fixed unless you are removing a feature

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh these were just used during development so tidying them into a clearer set

@@ -330,6 +342,75 @@ def _passable(grid, posa, posb):

def scale_coords(self, x):
return x / self.cell_size

def dikstra_path(self, map_data, pos1, pos2):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you implement dijkstra's algorithm in Jax 😆

absolutely nuts good work

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to describe how this works in a high level comment

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

at 11pm on a sunday yes haahah, Reviewer 4 asked for environment metrics..!

@@ -569,6 +651,44 @@ def _checkPointRect(x, y, rx, ry):
inside = _checkPolyWithinRect(sides, grid_idx[1], grid_idx[0])
return inside & map_grid[grid_idx[0], grid_idx[1]] & valid_idx

@partial(jax.jit, static_argnums=[0])
def check_all_agent_agent_collisions(self, agent_positions: chex.Array, agent_theta: chex.Array, agent_coords=None) -> chex.Array:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also would be good to provide more detail for people (like) me who have no idea what the separating axis theorem is.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea very fair point :)

@amacrutherford amacrutherford merged commit ded3239 into main Aug 29, 2024
2 checks passed
@amacrutherford amacrutherford deleted the jaxnav-docs branch August 29, 2024 13:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants