Skip to content

Commit

Permalink
Remove globals (#43)
Browse files Browse the repository at this point in the history
* squash local branch commits--begin refactoring to remove globals

* working fit without globals via new STAC class

* updated viz function (renamed to viz_stac) to use the new STAC class

* updated STAC.transform

* update rodent CLI script

* update tests due to function type signature changes

* add xla flags to demo notebook, fix bug in rodent cli

* fix typo

* fix vmap args

* fix bug, add to unit tests

* fix bug in package_data, add test for STAC init

* fix package data bug

* update run_rodent so hydra_entry has no return values

* add docstrings

* more docstrings

* update README

* address comments, rerun demos

* add docstring

* period.

* fix dataloader test bug

* remove unused lines, add permalink
  • Loading branch information
charles-zhng authored Aug 26, 2024
1 parent 04568a1 commit f022418
Show file tree
Hide file tree
Showing 23 changed files with 946 additions and 737 deletions.
43 changes: 28 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,53 +35,66 @@ Our rendering functions support multiple backends: `egl`, `glfw`, and `osmesa`.
## Usage
1. Update the .yaml files in `config/` with the proper information (details WIP).

2. Run stac-mjx with its basic api: `load_configs` for loading configs and `run_stac` for the keypoint registration. Below is an example script, found in `demos/use_api.ipynb`.
2. Run stac-mjx with its basic api: `load_configs` for loading configs and `run_stac` for the keypoint registration. Below is an example script, found in `demos/use_api.ipynb`. A CLI script using the rodent model is also provided at `run_rodent.py`

```python
from stac_mjx import main
from stac_mjx import utils
from pathlib import Path
import os
# XLA flags for Nvidia GPU
if xla_bridge.get_backend().platform == "gpu":
os.environ["XLA_FLAGS"] = (
"--xla_gpu_enable_triton_softmax_fusion=true "
"--xla_gpu_triton_gemm_any=True "
)

# Set base path to the parent directory of your config files
base_path = Path.cwd()
stac_config_path = base_path / "demos/demo_stac.yaml"
model_config_path = base_path / "configs/rodent.yaml"

# Load configs
cfg = main.load_configs(stac_config_path, model_config_path)
stac_cfg, model_cfg = main.load_configs(stac_config_path, model_config_path)

# Load data
data_path = base_path / cfg.paths.data_path
kp_data = utils.load_data(data_path, utils.params)
kp_data, sorted_kp_names = utils.load_data(data_path, model_cfg)

# Run stac
fit_path, transform_path = main.run_stac(cfg, kp_data, base_path)
fit_path, transform_path = main.run_stac(
stac_cfg,
model_cfg,
kp_data,
sorted_kp_names,
base_path=base_path
)
```

3. Render the resulting data using `mujoco_viz()` (example notebook found in `demos/viz_usage.ipynb`):
```python
import os
import mediapy as media

from stac_mjx.viz import mujoco_viz
from stac_mjx.viz import viz_stac
from stac_mjx import main
from stac_mjx import utils
from pathlib import Path

stac_config_path = "../configs/stac.yaml"
model_config_path = "../configs/rodent.yaml"
base_path = Path.cwd()
stac_config_path = base_path / "demos/demo_stac.yaml"
model_config_path = base_path / "configs/rodent.yaml"

cfg = main.load_configs(stac_config_path, model_config_path)
stac_cfg, model_cfg = main.load_configs(stac_config_path, model_config_path)

xml_path = "../models/rodent.xml"
data_path = "../output.p"
n_frames=250
save_path="../videos/direct_render.mp4"
data_path = base_path / "demo_fit.p"
n_frames = 250
save_path = base_path / "videos/direct_render.mp4"

# Call mujoco_viz
frames = mujoco_viz(data_path, xml_path, n_frames, save_path, start_frame=0)
frames = viz_stac(data_path, stac_cfg, model_cfg, n_frames, save_path, start_frame=0, camera="close_profile", base_path=Path.cwd().parent)

# Show the video in the notebook (it is also saved to the save_path)
media.show_video(frames, fps=utils.params["RENDER_FPS"])
media.show_video(frames, fps=model_cfg["RENDER_FPS"])
```

4. If the rendering is poor, it's likely that some hyperparameter tuning is necessary. (details WIP)
12 changes: 9 additions & 3 deletions configs/rodent.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
MJCF_PATH: "models/rodent.xml"

# Frames per clip for transform.
N_FRAMES_PER_CLIP: 250

Expand Down Expand Up @@ -88,9 +90,13 @@ KEYPOINT_INITIAL_OFFSETS:
WristR: 0. 0. 0.0

TRUNK_OPTIMIZATION_KEYPOINTS:
- "Spine"
- "Hip"
- "Shoulder"
- "SpineF"
- "SpineL"
- "SpineM"
- "HipL"
- "HipR"
- "ShoulderL"
- "ShoulderR"
- "TailBase"

INDIVIDUAL_PART_OPTIMIZATION:
Expand Down
9 changes: 3 additions & 6 deletions configs/stac.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
paths:
model_config: "rodent"
xml: "models/rodent.xml"
fit_path: "fit.p"
transform_path: "transform.p"
data_path: "tests/data/test_rodent_mocap_1000_frames.nwb"
fit_path: "fit.p"
transform_path: "transform.p"
data_path: "tests/data/test_rodent_mocap_1000_frames.nwb"

n_fit_frames: 1000
skip_fit: False
Expand Down
244 changes: 125 additions & 119 deletions demos/api_usage.ipynb

Large diffs are not rendered by default.

9 changes: 3 additions & 6 deletions demos/demo_stac.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
paths:
model_config: "rodent"
xml: "models/rodent.xml"
fit_path: "demo_fit.p"
transform_path: "demo_transform.p"
data_path: "tests/data/test_rodent_mocap_1000_frames.mat"
fit_path: "demo_fit.p"
transform_path: "demo_transform.p"
data_path: "tests/data/test_rodent_mocap_1000_frames.mat"

n_fit_frames: 10
skip_fit: False
Expand Down
67 changes: 25 additions & 42 deletions demos/viz_usage.ipynb

Large diffs are not rendered by default.

57 changes: 26 additions & 31 deletions run_rodent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import jax
from jax import numpy as jnp
"""CLI script for running rodent skeletal registration"""

from jax.lib import xla_bridge
import numpy as np

import os
import logging
Expand All @@ -10,45 +9,41 @@

from stac_mjx import main
from stac_mjx import utils
from pathlib import Path


def load_and_run_stac(stac_cfg, model_cfg):
base_path = Path.cwd()

data_path = base_path / stac_cfg.data_path
kp_data, sorted_kp_names = utils.load_data(data_path, model_cfg)

fit_path, transform_path = main.run_stac(
stac_cfg, model_cfg, kp_data, sorted_kp_names, base_path=base_path
)

logging.info(
f"Run complete. \n fit path: {fit_path} \n transform path: {transform_path}"
)


@hydra.main(config_path="./configs", config_name="stac", version_base=None)
def hydra_entry(cfg: DictConfig):
# Initialize configs and convert to dictionaries
global_cfg = hydra.compose(config_name=cfg.paths.model_config)
logging.info(f"cfg: {OmegaConf.to_yaml(cfg)}")
logging.info(f"global_cfg: {OmegaConf.to_yaml(global_cfg)}")
utils.init_params(OmegaConf.to_container(global_cfg, resolve=True))
def hydra_entry(stac_cfg: DictConfig):
# Initialize configs
model_cfg = hydra.compose(config_name="rodent")
logging.info(f"cfg: {OmegaConf.to_yaml(stac_cfg)}")
logging.info(f"model_cfg: {OmegaConf.to_yaml(model_cfg)}")
model_cfg = OmegaConf.to_container(model_cfg, resolve=True)

# XLA flags for Nvidia GPU
if xla_bridge.get_backend().platform == "gpu":
os.environ["XLA_FLAGS"] = (
"--xla_gpu_enable_triton_softmax_fusion=true "
"--xla_gpu_triton_gemm_any=True "
)
# Set N_GPUS
utils.params["N_GPUS"] = jax.local_device_count("gpu")

# Set up mocap data
kp_names = utils.params["KP_NAMES"]
# argsort returns the indices that sort the array to match the order of marker sites
stac_keypoint_order = np.argsort(kp_names)
data_path = cfg.paths.data_path

# Load kp_data, /1000 to scale data (from mm to meters)
kp_data = utils.loadmat(data_path)["pred"][:] / 1000

# Preparing DANNCE data by reordering and reshaping
# Resulting kp_data is of shape (n_frames, n_keypoints)
kp_data = jnp.array(kp_data[:, :, stac_keypoint_order])
kp_data = jnp.transpose(kp_data, (0, 2, 1))
kp_data = jnp.reshape(kp_data, (kp_data.shape[0], -1))

return main.run_stac(cfg, kp_data)
load_and_run_stac(stac_cfg, model_cfg)


if __name__ == "__main__":
fit_path, transform_path = hydra_entry()
logging.info(
f"Run complete. \n fit path: {fit_path} \n transform path: {transform_path}"
)
hydra_entry()
4 changes: 4 additions & 0 deletions stac_mjx/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
"""This module exposes all high level APIs for stac-mjx."""

from stac_mjx.utils import enable_xla_flags, load_data
from stac_mjx.main import load_configs, run_stac
from stac_mjx.viz import viz_stac
Loading

0 comments on commit f022418

Please sign in to comment.