Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Synthetic data smoke test. #75

Merged
merged 29 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
bbde0a4
Model runs + draws in notebook, no data output
jf514 Oct 8, 2024
4661907
Configs etc - model not yet working.
jf514 Oct 25, 2024
d5cdd97
Disable energy
jf514 Oct 25, 2024
9baa1c7
Offset shape bugfix (#73)
charles-zhng Oct 20, 2024
a6b983f
Smoke test (#74)
jf514 Oct 25, 2024
7ed9f25
IT'S WORKING
jf514 Oct 25, 2024
d4145c5
Offset shape bugfix (#73)
charles-zhng Oct 20, 2024
8248778
Configs etc - model not yet working.
jf514 Oct 25, 2024
be2f1aa
Offset shape bugfix (#73)
charles-zhng Oct 20, 2024
935c922
Configs etc - model not yet working.
jf514 Oct 25, 2024
2e6e01c
Merge remote-tracking branch 'origin/main' into synthetic-data
jf514 Oct 26, 2024
45774d9
Fix weird merge.
jf514 Oct 26, 2024
e9c1a4f
Clean up synth_model config file.
jf514 Oct 28, 2024
f14910b
Remove TIME_BINS (which was a merge accident.)
jf514 Oct 28, 2024
f6df7e3
Fix smoke test.
jf514 Oct 28, 2024
7b16b5a
Fix smoke test.
jf514 Oct 28, 2024
b75b354
Clean up.
jf514 Oct 28, 2024
6f7da79
Fixed root optimization, but still some debug code.
jf514 Oct 28, 2024
5433344
Add root_kp_index
jf514 Oct 29, 2024
2405815
Forgot model yaml.
jf514 Oct 29, 2024
9a2a599
Reset rodent configs, enable synth config.
jf514 Oct 29, 2024
86d333e
Add synth_data smoke test.
jf514 Oct 29, 2024
84793e9
Missed data file.
jf514 Oct 29, 2024
82c1887
Clean up.
jf514 Oct 30, 2024
d9fe782
Add root opt keypoint to model configs + clean up.
jf514 Oct 30, 2024
418a1e9
Clean up.
jf514 Oct 30, 2024
19a30e0
CR feedback.
jf514 Nov 1, 2024
dc26558
Add synth data generation program.
jf514 Nov 1, 2024
dc881c3
Add comments.
jf514 Nov 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
594 changes: 594 additions & 0 deletions Mat-to-Nwb-Synth-Data.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion configs/config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
defaults:
- stac: demo
- model: rodent
- model: synth_model
- _self_

##FLY_MODEL
Expand Down
2 changes: 0 additions & 2 deletions configs/model/fly_tethered.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -176,5 +176,3 @@ N_SAMPLE_FRAMES: 100
# If you have reason to believe that the initial offsets are correct for particular keypoints,
# you can regularize those sites using _SITES_TO_REGULARIZE.
M_REG_COEF: 1

TIME_BINS: 0.02
2 changes: 0 additions & 2 deletions configs/model/fly_treadmill.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,3 @@ N_SAMPLE_FRAMES: 100
# If you have reason to believe that the initial offsets are correct for particular keypoints,
# you can regularize those sites using _SITES_TO_REGULARIZE.
M_REG_COEF: 1

TIME_BINS: 0.02
8 changes: 5 additions & 3 deletions configs/model/rodent.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
MJCF_PATH: "models/rodent.xml"

# Frames per clip for ik_only.
N_FRAMES_PER_CLIP: 250
N_FRAMES_PER_CLIP: 25

# Tolerance for the optimizations of the full model, limb, and root.
# TODO: Re-implement optimizer loops to use these tolerances
Expand All @@ -10,7 +10,7 @@ ROOT_FTOL: 1.0e-05
LIMB_FTOL: 1.0e-06

# Number of alternating pose and offset optimization rounds.
N_ITERS: 6
N_ITERS: 1

KP_NAMES:
- 'Snout'
Expand Down Expand Up @@ -89,6 +89,8 @@ KEYPOINT_INITIAL_OFFSETS:
WristL: 0. 0. 0.0
WristR: 0. 0. 0.0

ROOT_OPTIMIZATION_KEYPOINT: SpineL

TRUNK_OPTIMIZATION_KEYPOINTS:
- "SpineF"
- "SpineL"
Expand Down Expand Up @@ -189,7 +191,7 @@ SITES_TO_REGULARIZE:

RENDER_FPS: 50

N_SAMPLE_FRAMES: 100
N_SAMPLE_FRAMES: 1

# If you have reason to believe that the initial offsets are correct for particular keypoints,
# you can regularize those sites using _SITES_TO_REGULARIZE.
Expand Down
58 changes: 58 additions & 0 deletions configs/model/synth_model.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@

MJCF_PATH: 'models/synth_model.xml'
jf514 marked this conversation as resolved.
Show resolved Hide resolved

# Frames per clip for transform.
N_FRAMES_PER_CLIP: 1
jf514 marked this conversation as resolved.
Show resolved Hide resolved

# Tolerance for the optimizations of the full model, limb, and root.
# TODO: Re-implement optimizer loops to use these tolerances
FTOL: 5.0e-03
ROOT_FTOL: 1.0e-05
LIMB_FTOL: 1.0e-06

# Number of alternating pose and offset optimization rounds.
N_ITERS: 1

KP_NAMES:
- part_0

ROOT_OPTIMIZATION_KEYPOINT: part_0

# The model sites used to register the keypoints.
KEYPOINT_MODEL_PAIRS:
part_0: base

# The initial offsets for each keypoint in meters.
KEYPOINT_INITIAL_OFFSETS:
part_0: 0 0 0.01

TRUNK_OPTIMIZATION_KEYPOINTS:
- part_0

INDIVIDUAL_PART_OPTIMIZATION:
model_base: [base]

# Color to use for each keypoint when visualizing the results
KEYPOINT_COLOR_PAIRS:
base: 0 .5 1 1
jf514 marked this conversation as resolved.
Show resolved Hide resolved

# What is the size of the animal you'd like to register, relative to the model?
SCALE_FACTOR: 1

# Multiplier to put the mocap data into the same scale as the data. Eg, if the
# mocap data is known to be in millimeters and the model is in meters, this is
# .001
MOCAP_SCALE_FACTOR: 1

# If you have reason to believe that the initial offsets are correct for particular keypoints,
# you can regularize those sites using this with M_REG_COEF.
SITES_TO_REGULARIZE:
- part_0

RENDER_FPS: 200

N_SAMPLE_FRAMES: 1

# If you have reason to believe that the initial offsets are correct for particular keypoints,
# you can regularize those sites using _SITES_TO_REGULARIZE.
M_REG_COEF: 1
jf514 marked this conversation as resolved.
Show resolved Hide resolved
10 changes: 5 additions & 5 deletions configs/stac/demo.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
fit_offsets_path: "demo_fit.p"
ik_only_path: "demo_ik_only.p"
data_path: "tests/data/test_rodent_mocap_1000_frames.mat"
fit_offsets_path: "synth_fit.p"
ik_only_path: "synth_ik_only.p"
data_path: "tests/data/save_data_AVG_2.nwb"

n_fit_frames: 1
skip_fit_offsets: False
jf514 marked this conversation as resolved.
Show resolved Hide resolved
skip_ik_only: True
skip_ik_only: False

mujoco:
solver: "newton"
solver: newton
iterations: 1
ls_iterations: 4
2 changes: 1 addition & 1 deletion configs/stac/stac.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ skip_ik_only: True
#skip_ik_only: False

mujoco:
solver: "newton"
solver: newton
iterations: 1
ls_iterations: 4
259 changes: 259 additions & 0 deletions demos/create_synth_data.ipynb
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How was the data saved from this notebook? And were there any offsets applied to the synthetic keypoint? I wanted to test whether it finds the ground truth offset when given different initial offsets in the model.yaml file, but it always returns the initial offset without any changes.

This notebook doesn't run as is; the rendering is cell 3 is different for me from the one shared, and the last cell throw an error: KeyError: "Invalid name '0'. Valid names: ['base', 'world']"

Copy link
Contributor Author

@jf514 jf514 Nov 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's just not ready for that. I changed to the title to reflect that it's just a smoke test... ie it runs with out crashing.

This PR isn't to provide that level of functionality, sadly. I've created #81 to track the next steps. Please feel free to add any specific requests to that.

As for the data generation, I created a single frame of fake data (not even collected from an actual Mujoco run) just to give an output.

Also, if you think there are any comments that need to be added to the code to make this clear, also lmk.

Large diffs are not rendered by default.

23 changes: 23 additions & 0 deletions models/synth_model.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
<mujoco>
<option timestep=".001">
<!-- <flag energy="enable" contact="disable"/> -->
</option>

<default>
<joint type="hinge" axis="0 -1 0"/>
<geom type="capsule" size=".02"/>
</default>

jf514 marked this conversation as resolved.
Show resolved Hide resolved
<worldbody>
<light pos="0 -.4 1"/>
<camera name="fixed" pos="0 -1 0" xyaxes="1 0 0 0 0 1"/>
<body name="base" pos="0 0 .2">
<joint name="root"/>
<geom fromto="0 0 0 0 0 -.25" rgba="1 1 0 1"/>
<!-- <body name="3" pos="0 0 -.25">
<joint/>
<geom fromto="0 0 0 0 0 -.2" rgba="0 0 1 1"/>
</body> -->
</body>
</worldbody>
</mujoco>
3 changes: 2 additions & 1 deletion stac_mjx/compute_stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def root_optimization(
mjx_model,
mjx_data,
kp_data: jp.ndarray,
root_kp_idx: int,
lb: jp.ndarray,
ub: jp.ndarray,
site_idxs: jp.ndarray,
Expand Down Expand Up @@ -50,7 +51,7 @@ def root_optimization(
# necessarily exactly so. The value of 3*18 is chosen for the
# rodent.xml, corresponding to the index of 'SpineL' keypoint.
# For the mouse model this should be 3*5, corresponding 'Trunk'
root_kp_idx = 3 * 18
# root_kp_idx = 3 * 18
# FLY_MODEL:
# root_kp_idx = 0
q0.at[:3].set(kp_data[frame, :][root_kp_idx : root_kp_idx + 3])
Expand Down
1 change: 1 addition & 0 deletions stac_mjx/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def load_data(cfg: DictConfig, base_path: Union[Path, None] = None):
f"Number of keypoint names ({len(kp_names)}) is not the same as the number of keypoints in data ({data.shape[1]})"
)

print("kpnames ", kp_names)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Consider using proper logging instead of print statements.

Debug print statements should be avoided in production code. Consider using Python's logging framework for better control over debug output.

Replace the print statement with a proper debug log:

-    print("kpnames ", kp_names)
+    import logging
+    logging.debug("Keypoint names: %s", kp_names)

This change offers several benefits:

  • Log level control (can be disabled in production)
  • Structured output format
  • Better integration with monitoring systems

Committable suggestion was skipped due to low confidence.


💡 Codebase verification

Print statement appears to be inconsistent with codebase's logging pattern

Looking at the context and print statements across the codebase:

  1. Most print statements in the codebase are used for meaningful progress updates (iterations, timings, errors) or warnings
  2. This particular print statement is a debug-style output without clear purpose
  3. The codebase doesn't have a consistent debug/verbose mode pattern
  4. The print statement appears in the middle of keypoint mapping logic without providing significant value

Recommendations:

  • Remove this print statement as it doesn't align with the codebase's logging pattern
  • If debugging is needed, consider adding proper logging with debug levels
  • For test visibility, use pytest's built-in capture system instead
🔗 Analysis chain

Verify if this print statement is required for E2E tests.

Given this PR's focus on synthetic E2E testing, please clarify if this print statement is intentionally added for test debugging purposes. If so, consider:

  1. Moving it behind a debug flag
  2. Adding it only in test configurations
  3. Using pytest's built-in capture system for test output

Let's check if this print statement is used in tests:


The previous search was too narrow. Let's:

  1. Get more context around the print statement
  2. Check for any debug configuration patterns
  3. Look for similar print statements to understand the logging pattern
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for test files that might rely on this output
rg -l "kpnames" tests/

Length of output: 22


Script:

#!/bin/bash
# Let's check the broader context of this print statement
rg -B 3 -A 3 "kpnames" .

# Also check if there are any test configurations or debug flags
rg -l "DEBUG|VERBOSE|debug_mode|verbose_mode" .

# Look for similar print statements in the codebase
rg "print\(" stac_mjx/

Length of output: 2462

model_inds = [
kp_names.index(src) for src, dst in cfg.model.KEYPOINT_MODEL_PAIRS.items()
]
Expand Down
32 changes: 27 additions & 5 deletions stac_mjx/stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,27 @@
self._mj_model.body(i).name for i in range(self._mj_model.nbody)
]

joint_names = [self._mj_model.joint(i).name for i in range(self._mj_model.njnt)]
if "ROOT_OPTIMIZATION_KEYPOINT" in self.cfg.model:
print("fouind root opt key")
self._root_kp_idx = self._kp_names.index(

Check warning on line 85 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L84-L85

Added lines #L84 - L85 were not covered by tests
self.cfg.model.ROOT_OPTIMIZATION_KEYPOINT
)
else:
print("NOOOO ROOT KP")
self._root_kp_idx = -1

print(self.cfg.model.keys())
print("Using root_kp_idx = ", self._root_kp_idx)
jf514 marked this conversation as resolved.
Show resolved Hide resolved

# Set up bounds and part_names based on joint ranges, taking into account the dimensionality of parameters
joint_names = [self._mj_model.joint(i).name for i in range(self._mj_model.njnt)]
self._lb, self._ub, self._part_names = _align_joint_dims(
self._mj_model.jnt_type, self._mj_model.jnt_range, joint_names
)

self._indiv_parts = self.part_opt_setup()

# Generate boolean flags for keypoints included in trunk optimization.
self._trunk_kps = jp.array(
[n in self.cfg.model.TRUNK_OPTIMIZATION_KEYPOINTS for n in kp_names],
)
Expand All @@ -113,7 +125,7 @@
[any(part in name for part in parts) for name in self._part_names]
)

if self.cfg.model.INDIVIDUAL_PART_OPTIMIZATION is None:
if "INDIVIDUAL_PART_OPTIMIZATION" not in self.cfg.model:
indiv_parts = []
else:
indiv_parts = jp.array(
Expand Down Expand Up @@ -224,11 +236,16 @@

# Begin optimization steps
# Skip root optimization if model is fixed (no free joint at root)
if self._mj_model.jnt_type[0] == mujoco.mjtJoint.mjJNT_FREE:
if self._root_kp_idx == -1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense 👍

print(

Check warning on line 240 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L239-L240

Added lines #L239 - L240 were not covered by tests
"Missing or invalid ROOT_OPTIMIZATION_KEYPOINT, skipping root_optimization()"
)
elif self._mj_model.jnt_type[0] == mujoco.mjtJoint.mjJNT_FREE:

Check warning on line 243 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L243

Added line #L243 was not covered by tests
jf514 marked this conversation as resolved.
Show resolved Hide resolved
mjx_data = compute_stac.root_optimization(
mjx_model,
mjx_data,
kp_data,
self._root_kp_idx,
self._lb,
self._ub,
self._body_site_idxs,
Expand Down Expand Up @@ -339,15 +356,20 @@
)

# q_phase - root
if self._mj_model.jnt_type[0] == mujoco.mjtJoint.mjJNT_FREE:
if self._root_kp_idx == -1:
print(

Check warning on line 360 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L359-L360

Added lines #L359 - L360 were not covered by tests
"Missing or invalid ROOT_OPTIMIZATION_KEYPOINT, skipping root_optimization()"
)
elif self._mj_model.jnt_type[0] == mujoco.mjtJoint.mjJNT_FREE:

Check warning on line 363 in stac_mjx/stac.py

View check run for this annotation

Codecov / codecov/patch

stac_mjx/stac.py#L363

Added line #L363 was not covered by tests
jf514 marked this conversation as resolved.
Show resolved Hide resolved
vmap_root_opt = jax.vmap(
compute_stac.root_optimization,
in_axes=(0, 0, 0, None, None, None, None),
in_axes=(0, 0, 0, None, None, None, None, None),
jf514 marked this conversation as resolved.
Show resolved Hide resolved
)
mjx_data = vmap_root_opt(
mjx_model,
mjx_data,
batched_kp_data,
self._root_kp_idx,
self._lb,
self._ub,
self._body_site_idxs,
Expand Down
Binary file added tests/data/save_data_AVG_2.nwb
Binary file not shown.
Loading