Skip to content

Commit

Permalink
partial fix for the edit script
Browse files Browse the repository at this point in the history
  • Loading branch information
GuyTevet committed Feb 12, 2025
1 parent 020d9b4 commit c825854
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions sample/edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import numpy as np
import torch
from utils.parser_util import edit_args
from utils.model_util import create_model_and_diffusion, load_model_wo_clip
from sample.generate import save_multiple_samples, construct_template_variables
from utils.model_util import create_model_and_diffusion, load_saved_model
from utils import dist_util
from utils.sampler_util import ClassifierFreeSampleModel
from data_loaders.get_data import get_dataset_loader
Expand All @@ -27,6 +28,8 @@ def main():
niter = os.path.basename(args.model_path).replace('model', '').replace('.pt', '')
max_frames = 196 if args.dataset in ['kit', 'humanml'] else 60
fps = 12.5 if args.dataset == 'kit' else 20
n_frames = 120 # min(max_frames, int(args.motion_length*fps))

dist_util.setup_dist(args.device)
if out_path == '':
out_path = os.path.join(os.path.dirname(args.model_path),
Expand Down Expand Up @@ -54,8 +57,7 @@ def main():
model, diffusion = create_model_and_diffusion(args, data)

print(f"Loading checkpoints from [{args.model_path}]...")
state_dict = torch.load(args.model_path, map_location='cpu')
load_model_wo_clip(model, state_dict)
load_saved_model(model, args.model_path, use_avg=args.use_ema)

model = ClassifierFreeSampleModel(model) # wrapping model with the classifier-free sampler
model.to(dist_util.dev())
Expand Down Expand Up @@ -156,17 +158,25 @@ def main():
input_motions = input_motions.view(-1, *input_motions.shape[2:]).permute(0, 2, 3, 1).cpu().numpy()


sample_print_template, row_print_template, all_print_template, \
sample_file_template, row_file_template, all_file_template = construct_template_variables(args.unconstrained)
max_vis_samples = 6
num_vis_samples = min(args.num_samples, max_vis_samples)
animations = np.empty(shape=(args.num_samples, args.num_repetitions), dtype=object)
max_length = max(all_lengths)

for sample_i in range(args.num_samples):
caption = 'Input Motion'
length = model_kwargs['y']['lengths'][sample_i]
motion = input_motions[sample_i].transpose(2, 0, 1)[:length]
save_file = 'input_motion{:02d}.mp4'.format(sample_i)
animation_save_path = os.path.join(out_path, save_file)
rep_files = [animation_save_path]
print(f'[({sample_i}) "{caption}" | -> {save_file}]')
plot_3d_motion(animation_save_path, skeleton, motion, title=caption,
dataset=args.dataset, fps=fps, vis_mode='gt',
gt_frames=gt_frames_per_sample.get(sample_i, []))
# FIXME - fix and bring back the following:
# print(f'[({sample_i}) "{caption}" | -> {save_file}]')
# plot_3d_motion(animation_save_path, skeleton, motion, title=caption,
# dataset=args.dataset, fps=fps, vis_mode='gt',
# gt_frames=gt_frames_per_sample.get(sample_i, []))
for rep_i in range(args.num_repetitions):
caption = all_text[rep_i*args.batch_size + sample_i]
if caption == '':
Expand All @@ -178,10 +188,11 @@ def main():
save_file = 'sample{:02d}_rep{:02d}.mp4'.format(sample_i, rep_i)
animation_save_path = os.path.join(out_path, save_file)
rep_files.append(animation_save_path)
gt_frames = gt_frames_per_sample.get(sample_i, [])
print(f'[({sample_i}) "{caption}" | Rep #{rep_i} | -> {save_file}]')
plot_3d_motion(animation_save_path, skeleton, motion, title=caption,
dataset=args.dataset, fps=fps, vis_mode=args.edit_mode,
gt_frames=gt_frames_per_sample.get(sample_i, []))
animations[sample_i, rep_i] = plot_3d_motion(animation_save_path,
skeleton, motion, dataset=args.dataset, title=caption,
fps=fps, gt_frames=gt_frames)
# Credit for visualization: https://github.com/EricGuo5513/text-to-motion

all_rep_save_file = os.path.join(out_path, 'sample{:02d}.mp4'.format(sample_i))
Expand All @@ -190,6 +201,8 @@ def main():
ffmpeg_rep_cmd = f'ffmpeg -y -loglevel warning ' + ''.join(ffmpeg_rep_files) + f'{hstack_args} {all_rep_save_file}'
os.system(ffmpeg_rep_cmd)
print(f'[({sample_i}) "{caption}" | all repetitions | -> {all_rep_save_file}]')

save_multiple_samples(out_path, {'all': all_file_template}, animations, fps, max(list(all_lengths) + [n_frames]))

abs_path = os.path.abspath(out_path)
print(f'[Done] Results are at [{abs_path}]')
Expand Down

0 comments on commit c825854

Please sign in to comment.