Add Replicate demo
ariel415el committed Dec 24, 2022
1 parent 5224ee5 commit 8c88593
Showing 3 changed files with 191 additions and 0 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -4,6 +4,8 @@
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/human-motion-diffusion-model/motion-synthesis-on-humanml3d)](https://paperswithcode.com/sota/motion-synthesis-on-humanml3d?p=human-motion-diffusion-model)
[![arXiv](https://img.shields.io/badge/arXiv-<2209.14916>-<COLOR>.svg)](https://arxiv.org/abs/2209.14916)

<a href="https://replicate.com/arielreplicate/motion_diffusion_model"><img src="https://replicate.com/arielreplicate/motion_diffusion_model/badge"></a>

The official PyTorch implementation of the paper [**"Human Motion Diffusion Model"**](https://arxiv.org/abs/2209.14916).

Please visit our [**webpage**](https://guytevet.github.io/mdm-page/) for more details.
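
For reference, the hosted demo can also be invoked programmatically. Below is a minimal sketch using the `replicate` Python client; the version hash is a placeholder to be copied from the model page, and the input names mirror those defined in `sample/predict.py` in this commit:

```python
# Minimal sketch: calling the hosted Replicate demo from Python.
# Assumes `pip install replicate` and REPLICATE_API_TOKEN set in the environment.
# "<version-hash>" is a placeholder; copy the current hash from the model page.
import replicate

outputs = replicate.run(
    "arielreplicate/motion_diffusion_model:<version-hash>",
    input={
        "prompt": "the person walked forward and is picking up his toolbox.",
        "num_repetitions": 3,
    },
)
for url in outputs:  # one rendered motion clip per repetition
    print(url)
```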
38 changes: 38 additions & 0 deletions cog.yaml
@@ -0,0 +1,38 @@
build:
gpu: true
cuda: "11.3"
  python_version: "3.8"
system_packages:
- libgl1-mesa-glx
- libglib2.0-0

python_packages:
- imageio==2.22.2
- matplotlib==3.1.3
- spacy==3.3.1
- smplx==0.1.28
- chumpy==0.70
- blis==0.7.8
- click==8.1.3
- confection==0.0.2
- ftfy==6.1.1
- importlib-metadata==5.0.0
- lxml==4.9.1
- murmurhash==1.0.8
- preshed==3.0.7
- pycryptodomex==3.15.0
- regex==2022.9.13
- srsly==2.4.4
- thinc==8.0.17
- typing-extensions==4.1.1
- urllib3==1.26.12
- wasabi==0.10.1
- wcwidth==0.2.5

run:
- apt update -y && apt-get install ffmpeg -y
# - python -m spacy download en_core_web_sm
- git clone https://github.com/openai/CLIP.git sub_modules/CLIP
- pip install -e sub_modules/CLIP

predict: "sample/predict.py:Predictor"
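
Cog builds this spec into a container image that serves `Predictor` over HTTP. A hedged sketch of querying such a container locally, assuming it was built and started along the lines of `cog build -t mdm` and `docker run -p 5000:5000 --gpus all mdm`, using Cog's standard `/predictions` endpoint:

```python
# Hedged sketch: POSTing to a locally running Cog container built from this cog.yaml.
# Endpoint and payload shape follow Cog's standard HTTP API; host/port are assumptions.
import json
import urllib.request

payload = json.dumps({
    "input": {
        "prompt": "the person walked forward and is picking up his toolbox.",
        "num_repetitions": 1,
    }
}).encode()

req = urllib.request.Request(
    "http://localhost:5000/predictions",
    data=payload,
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["output"])  # list of output file URLs/paths
```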
151 changes: 151 additions & 0 deletions sample/predict.py
@@ -0,0 +1,151 @@
import os
import subprocess
import typing
from argparse import Namespace

import torch
from cog import BasePredictor, Input, Path

import data_loaders.humanml.utils.paramUtil as paramUtil
from data_loaders.get_data import get_dataset_loader
from data_loaders.humanml.scripts.motion_process import recover_from_ric
from data_loaders.humanml.utils.plot_script import plot_3d_motion
from data_loaders.tensors import collate
from model.cfg_sampler import ClassifierFreeSampleModel
from utils import dist_util
from utils.model_util import create_model_and_diffusion, load_model_wo_clip
from sample.generate import construct_template_variables

"""
In case of matplotlib issues, it may be necessary to delete lines 89~92 of data_loaders/humanml/utils/plot_script.py,
as suggested in https://github.com/GuyTevet/motion-diffusion-model/issues/6
"""


def get_args():
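    # Hard-coded sampling and architecture hyper-parameters matching the pretrained humanml_trans_enc_512 checkpoint loaded in setup()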
args = Namespace()
args.fps = 20
args.model_path = './save/humanml_trans_enc_512/model000200000.pt'
args.guidance_param = 2.5
args.unconstrained = False
args.dataset = 'humanml'

args.cond_mask_prob = 1
args.emb_trans_dec = False
args.latent_dim = 512
args.layers = 8
args.arch = 'trans_enc'

args.noise_schedule = 'cosine'
args.sigma_small = True
args.lambda_vel = 0.0
args.lambda_rcxyz = 0.0
args.lambda_fc = 0.0
return args


class Predictor(BasePredictor):
def setup(self):
        # Seed the local CLIP cache with the pre-downloaded ViT-B/32 weights so CLIP loads without network access
        subprocess.run(["mkdir", "/root/.cache/clip"])
        subprocess.run(["cp", "-r", "ViT-B-32.pt", "/root/.cache/clip"])

self.args = get_args()
        self.num_frames = self.args.fps * 6  # generate 6-second clips (120 frames at 20 fps)
print('Loading dataset...')

        # temporary dataset loader with a dummy batch size, used only to build the model; predict() reloads it with the requested batch size
self.data = get_dataset_loader(name=self.args.dataset,
batch_size=1,
num_frames=196,
split='test',
hml_mode='text_only')

self.data.fixed_length = float(self.num_frames)

print("Creating model and diffusion...")
self.model, self.diffusion = create_model_and_diffusion(self.args, self.data)

        print(f"Loading checkpoints from [{self.args.model_path}]...")
state_dict = torch.load(self.args.model_path, map_location='cpu')
load_model_wo_clip(self.model, state_dict)

if self.args.guidance_param != 1:
self.model = ClassifierFreeSampleModel(self.model) # wrapping model with the classifier-free sampler
self.model.to(dist_util.dev())
self.model.eval() # disable random masking

def predict(
self,
prompt: str = Input(default="the person walked forward and is picking up his toolbox."),
        num_repetitions: int = Input(default=3, description="How many motion variations to generate for the prompt"),
    ) -> typing.List[Path]:
args = self.args
args.num_repetitions = int(num_repetitions)

self.data = get_dataset_loader(name=self.args.dataset,
batch_size=args.num_repetitions,
num_frames=self.num_frames,
split='test',
hml_mode='text_only')

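        # Build a single-prompt batch: a dummy zero motion of fixed length; collate() packs the text into model_kwargs['y'] for conditioning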
collate_args = [{'inp': torch.zeros(self.num_frames), 'tokens': None, 'lengths': self.num_frames, 'text': str(prompt)}]
_, model_kwargs = collate(collate_args)

# add CFG scale to batch
if args.guidance_param != 1:
model_kwargs['y']['scale'] = torch.ones(args.num_repetitions, device=dist_util.dev()) * args.guidance_param

sample_fn = self.diffusion.p_sample_loop
sample = sample_fn(
self.model,
(args.num_repetitions, self.model.njoints, self.model.nfeats, self.num_frames),
clip_denoised=False,
model_kwargs=model_kwargs,
skip_timesteps=0, # 0 is the default value - i.e. don't skip any step
init_image=None,
progress=True,
dump_steps=None,
noise=None,
const_noise=False,
)
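        # sample: [num_repetitions, njoints, nfeats, num_frames], matching the shape tuple passed to p_sample_loop above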

# Recover XYZ *positions* from HumanML3D vector representation
if self.model.data_rep == 'hml_vec':
            n_joints = 22 if sample.shape[1] == 263 else 21  # 263-dim features are HumanML3D (22 joints); otherwise KIT-ML (21 joints)
sample = self.data.dataset.t2m_dataset.inv_transform(sample.cpu().permute(0, 2, 3, 1)).float()
sample = recover_from_ric(sample, n_joints)
sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)

rot2xyz_pose_rep = 'xyz' if self.model.data_rep in ['xyz', 'hml_vec'] else self.model.data_rep
rot2xyz_mask = None if rot2xyz_pose_rep == 'xyz' else model_kwargs['y']['mask'].reshape(args.num_repetitions,
self.num_frames).bool()
sample = self.model.rot2xyz(x=sample, mask=rot2xyz_mask, pose_rep=rot2xyz_pose_rep, glob=True, translation=True,
jointstype='smpl', vertstrans=True, betas=None, beta=0, glob_rot=None,
get_rotations_back=False)

all_motions = sample.cpu().numpy()

caption = str(prompt)

skeleton = paramUtil.t2m_kinematic_chain


sample_print_template, row_print_template, all_print_template, \
sample_file_template, row_file_template, all_file_template = construct_template_variables(
args.unconstrained)

rep_files = []
replicate_fnames = []
for rep_i in range(args.num_repetitions):
motion = all_motions[rep_i].transpose(2, 0, 1)[:self.num_frames]
save_file = sample_file_template.format(1, rep_i)
print(sample_print_template.format(caption, 1, rep_i, save_file))
plot_3d_motion(save_file, skeleton, motion, dataset=args.dataset, title=caption, fps=args.fps)
# Credit for visualization: https://github.com/EricGuo5513/text-to-motion
rep_files.append(save_file)

replicate_fnames.append(Path(save_file))

return replicate_fnames
