-
Notifications
You must be signed in to change notification settings - Fork 373
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5224ee5
commit 8c88593
Showing
3 changed files
with
191 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
# Cog build configuration for the Motion Diffusion Model demo.
# Defines the GPU container image and the entry point used by Replicate.
build:
  gpu: true
  cuda: "11.3"
  python_version: 3.8
  # OS packages needed by OpenCV/matplotlib rendering (libGL) and glib.
  system_packages:
    - libgl1-mesa-glx
    - libglib2.0-0

  # Pinned Python dependencies; versions match the tested environment.
  python_packages:
    - imageio==2.22.2
    - matplotlib==3.1.3
    - spacy==3.3.1
    - smplx==0.1.28
    - chumpy==0.70
    - blis==0.7.8
    - click==8.1.3
    - confection==0.0.2
    - ftfy==6.1.1
    - importlib-metadata==5.0.0
    - lxml==4.9.1
    - murmurhash==1.0.8
    - preshed==3.0.7
    - pycryptodomex==3.15.0
    - regex==2022.9.13
    - srsly==2.4.4
    - thinc==8.0.17
    - typing-extensions==4.1.1
    - urllib3==1.26.12
    - wasabi==0.10.1
    - wcwidth==0.2.5

  # Extra setup commands run while building the image.
  run:
    - apt update -y && apt-get install ffmpeg -y
    # - python -m spacy download en_core_web_sm
    # CLIP is not on PyPI; install it from source as an editable package.
    - git clone https://github.com/openai/CLIP.git sub_modules/CLIP
    - pip install -e sub_modules/CLIP

# Entry point: the Predictor class Cog instantiates to serve requests.
predict: "sample/predict.py:Predictor"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import os | ||
import subprocess | ||
import typing | ||
from argparse import Namespace | ||
|
||
import torch | ||
from cog import BasePredictor, Input, Path | ||
|
||
import data_loaders.humanml.utils.paramUtil as paramUtil | ||
from data_loaders.get_data import get_dataset_loader | ||
from data_loaders.humanml.scripts.motion_process import recover_from_ric | ||
from data_loaders.humanml.utils.plot_script import plot_3d_motion | ||
from data_loaders.tensors import collate | ||
from model.cfg_sampler import ClassifierFreeSampleModel | ||
from utils import dist_util | ||
from utils.model_util import create_model_and_diffusion, load_model_wo_clip | ||
from sample.generate import construct_template_variables | ||
|
||
""" | ||
In case of matplot lib issues it may be needed to delete model/data_loaders/humanml/utils/plot_script.py" in lines 89~92 as | ||
suggested in https://github.com/GuyTevet/motion-diffusion-model/issues/6 | ||
""" | ||
|
||
|
||
def get_args():
    """Build the fixed generation settings used by the predictor.

    Returns:
        argparse.Namespace mirroring the CLI flags of the original generate
        script, pinned to the pretrained ``humanml_trans_enc_512`` checkpoint.
    """
    ns = Namespace()

    # Sampling / checkpoint settings.
    ns.fps = 20
    ns.model_path = './save/humanml_trans_enc_512/model000200000.pt'
    ns.guidance_param = 2.5
    ns.unconstrained = False
    ns.dataset = 'humanml'

    # Architecture settings — must match how the checkpoint was trained.
    ns.cond_mask_prob = 1
    ns.emb_trans_dec = False
    ns.latent_dim = 512
    ns.layers = 8
    ns.arch = 'trans_enc'

    # Diffusion settings.
    ns.noise_schedule = 'cosine'
    ns.sigma_small = True
    ns.lambda_vel = 0.0
    ns.lambda_rcxyz = 0.0
    ns.lambda_fc = 0.0

    return ns
|
||
|
||
class Predictor(BasePredictor):
    """Cog predictor that turns a text prompt into rendered human-motion clips
    using the Motion Diffusion Model (MDM)."""

    def setup(self):
        """Load the model and diffusion process once, at container start."""
        # Place the pre-downloaded CLIP weights where clip.load() looks for
        # them, so no network download happens at request time.
        # os.makedirs with exist_ok replaces the old `mkdir` subprocess call,
        # which printed an error when the directory already existed.
        os.makedirs("/root/.cache/clip", exist_ok=True)
        subprocess.run(["cp", "-r", "ViT-B-32.pt", "/root/.cache/clip"])

        self.args = get_args()
        # Generate 6 seconds of motion at the dataset frame rate.
        self.num_frames = self.args.fps * 6
        print('Loading dataset...')

        # Temporary loader (batch_size=1) used only to build the model;
        # predict() rebuilds it with the requested batch size.
        self.data = get_dataset_loader(name=self.args.dataset,
                                       batch_size=1,
                                       num_frames=196,
                                       split='test',
                                       hml_mode='text_only')

        self.data.fixed_length = float(self.num_frames)

        print("Creating model and diffusion...")
        self.model, self.diffusion = create_model_and_diffusion(self.args, self.data)

        # Was an f-string with no placeholder; now actually prints the path.
        print(f"Loading checkpoints from [{self.args.model_path}]...")
        state_dict = torch.load(self.args.model_path, map_location='cpu')
        load_model_wo_clip(self.model, state_dict)

        if self.args.guidance_param != 1:
            self.model = ClassifierFreeSampleModel(self.model)  # wrapping model with the classifier-free sampler
        self.model.to(dist_util.dev())
        self.model.eval()  # disable random masking

    def predict(
            self,
            prompt: str = Input(default="the person walked forward and is picking up his toolbox."),
            num_repetitions: int = Input(default=3, description="How many repetitions to generate"),
    ) -> typing.List[Path]:
        """Sample `num_repetitions` motions conditioned on `prompt` and return
        the rendered stick-figure animation files, one per repetition."""
        args = self.args
        args.num_repetitions = int(num_repetitions)

        # Rebuild the loader so the batch size matches the repetition count.
        self.data = get_dataset_loader(name=self.args.dataset,
                                       batch_size=args.num_repetitions,
                                       num_frames=self.num_frames,
                                       split='test',
                                       hml_mode='text_only')

        # One conditioning entry; collate broadcasts it into model_kwargs.
        collate_args = [{'inp': torch.zeros(self.num_frames), 'tokens': None, 'lengths': self.num_frames, 'text': str(prompt)}]
        _, model_kwargs = collate(collate_args)

        # add CFG scale to batch
        if args.guidance_param != 1:
            model_kwargs['y']['scale'] = torch.ones(args.num_repetitions, device=dist_util.dev()) * args.guidance_param

        sample_fn = self.diffusion.p_sample_loop
        sample = sample_fn(
            self.model,
            (args.num_repetitions, self.model.njoints, self.model.nfeats, self.num_frames),
            clip_denoised=False,
            model_kwargs=model_kwargs,
            skip_timesteps=0,  # 0 is the default value - i.e. don't skip any step
            init_image=None,
            progress=True,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )

        # Recover XYZ *positions* from HumanML3D vector representation
        if self.model.data_rep == 'hml_vec':
            # 263 features = 22-joint HumanML3D encoding; otherwise KIT (21).
            n_joints = 22 if sample.shape[1] == 263 else 21
            sample = self.data.dataset.t2m_dataset.inv_transform(sample.cpu().permute(0, 2, 3, 1)).float()
            sample = recover_from_ric(sample, n_joints)
            sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)

        rot2xyz_pose_rep = 'xyz' if self.model.data_rep in ['xyz', 'hml_vec'] else self.model.data_rep
        rot2xyz_mask = None if rot2xyz_pose_rep == 'xyz' else model_kwargs['y']['mask'].reshape(args.num_repetitions,
                                                                                               self.num_frames).bool()
        sample = self.model.rot2xyz(x=sample, mask=rot2xyz_mask, pose_rep=rot2xyz_pose_rep, glob=True, translation=True,
                                    jointstype='smpl', vertstrans=True, betas=None, beta=0, glob_rot=None,
                                    get_rotations_back=False)

        all_motions = sample.cpu().numpy()

        caption = str(prompt)

        skeleton = paramUtil.t2m_kinematic_chain

        sample_print_template, row_print_template, all_print_template, \
            sample_file_template, row_file_template, all_file_template = construct_template_variables(
                args.unconstrained)

        rep_files = []
        replicate_fnames = []
        for rep_i in range(args.num_repetitions):
            # (feats, joints, frames) -> (frames, joints, feats), clipped to length.
            motion = all_motions[rep_i].transpose(2, 0, 1)[:self.num_frames]
            save_file = sample_file_template.format(1, rep_i)
            print(sample_print_template.format(caption, 1, rep_i, save_file))
            plot_3d_motion(save_file, skeleton, motion, dataset=args.dataset, title=caption, fps=args.fps)
            # Credit for visualization: https://github.com/EricGuo5513/text-to-motion
            rep_files.append(save_file)

            replicate_fnames.append(Path(save_file))

        return replicate_fnames
|