[feature] ONNX export support for DDColor model (#51)
* onnx export support and demo notebook

* update docs

* update readme

* add usage example back
shubhamgupto authored Oct 25, 2024
1 parent 7552368 commit 1c7861d
Showing 5 changed files with 464 additions and 1 deletion.
4 changes: 4 additions & 0 deletions .gitignore
@@ -11,6 +11,8 @@ tmp/*
.vscode
.github

.onnx

# ignored files
version.py

@@ -133,3 +135,5 @@ venv.bak/

# meta file
data_list/*.txt

weights/
16 changes: 16 additions & 0 deletions README.md
@@ -179,6 +179,22 @@ python data_list/get_meta_file.py
sh scripts/train.sh
```

## ONNX export
Support for exporting the DDColor model to ONNX is now available.
### Additional dependencies
```
pip install onnx==1.16.1 onnxruntime==1.19.2 onnxsim==0.4.36
```

### Usage example
```
python export.py
usage: export.py [-h] [--input_size INPUT_SIZE] [--batch_size BATCH_SIZE] --model_path MODEL_PATH [--model_size MODEL_SIZE]
[--decoder_type DECODER_TYPE] [--export_path EXPORT_PATH] [--opset OPSET]
```
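
For example, to export the `ddcolor_paper_tiny` checkpoint (the checkpoint and output paths below are placeholders; point `--model_path` at your downloaded weights):
```
python export.py --model_path ./weights/ddcolor_paper_tiny.pth --model_size tiny --export_path ./weights/ddcolor_paper_tiny.onnx
```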

A demo of ONNX export using a `ddcolor_paper_tiny` model is available [here](notebooks/colorization_pipeline_onnxruntime.ipynb).
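
For reference, here is a minimal sketch of running the exported model with onnxruntime on a dummy input (the model path and the 512x512 shape are assumptions; the full Lab-space pre- and post-processing for colorization is shown in the notebook):
```
import numpy as np
import onnxruntime as ort

# Load the exported graph (path is a placeholder) on CPU.
session = ort.InferenceSession("./model.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name

# The export uses an NCHW float32 tensor with 3 channels at the chosen --input_size.
dummy = np.random.rand(1, 3, 512, 512).astype(np.float32)
(output,) = session.run(None, {input_name: dummy})
print(output.shape)  # expected (1, 2, 512, 512): the predicted ab channels
```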

## Citation

If our work is helpful for your research, please consider citing:
165 changes: 165 additions & 0 deletions export.py
@@ -0,0 +1,165 @@
import types

import argparse
import torch
import torch.nn.functional as F
import numpy as np
import onnx
import onnxsim

from basicsr.archs.ddcolor_arch import DDColor

from onnx import load_model, save_model, shape_inference
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference


def parse_args():
    parser = argparse.ArgumentParser(description="Export DDColor model to ONNX.")
    parser.add_argument(
        "--input_size",
        type=int,
        default=512,
        help="Input image dimension.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Input batch size.",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the pretrained DDColor checkpoint to export.",
    )
    parser.add_argument(
        "--model_size",
        type=str,
        default="tiny",
        help="Model size: 'tiny' uses the convnext-t encoder, any other value uses convnext-l.",
    )
    parser.add_argument(
        "--decoder_type",
        type=str,
        default="MultiScaleColorDecoder",
        help="Decoder type: 'MultiScaleColorDecoder' or 'SingleColorDecoder'.",
    )
    parser.add_argument(
        "--export_path",
        type=str,
        default="./model.onnx",
        help="Path to write the exported ONNX model to.",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=12,
        help="ONNX opset version.",
    )

    return parser.parse_args()


def create_onnx_export(args):
    input_size = args.input_size
    device = torch.device('cpu')
    if args.model_size == 'tiny':
        encoder_name = 'convnext-t'
    else:
        encoder_name = 'convnext-l'

    # hardcoded in inference/colorization_pipeline.py
    # decoder_type = "MultiScaleColorDecoder"

    if args.decoder_type == 'MultiScaleColorDecoder':
        model = DDColor(
            encoder_name=encoder_name,
            decoder_name='MultiScaleColorDecoder',
            input_size=[input_size, input_size],
            num_output_channels=2,
            last_norm='Spectral',
            do_normalize=False,
            num_queries=100,
            num_scales=3,
            dec_layers=9,
        ).to(device)
    elif args.decoder_type == 'SingleColorDecoder':
        model = DDColor(
            encoder_name=encoder_name,
            decoder_name='SingleColorDecoder',
            input_size=[input_size, input_size],
            num_output_channels=2,
            last_norm='Spectral',
            do_normalize=False,
            num_queries=256,
        ).to(device)
    else:
        raise NotImplementedError("decoder_type not implemented.")

    model.load_state_dict(
        torch.load(args.model_path, map_location=device)['params'],
        strict=False)
    model.eval()

    channels = 3  # RGB image has 3 channels

    random_input = torch.rand((args.batch_size, channels, input_size, input_size), dtype=torch.float32)

    # A value of 0 for --batch_size / --input_size marks that axis as dynamic in the exported graph.
    dynamic_axes = {}
    if args.batch_size == 0:
        dynamic_axes[0] = "batch"
    if input_size == 0:
        dynamic_axes[2] = "height"
        dynamic_axes[3] = "width"

    torch.onnx.export(
        model,
        random_input,
        args.export_path,
        opset_version=args.opset,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": dynamic_axes,
            "output": dynamic_axes
        },
    )

def check_onnx_export(export_path):
    # Run ONNX shape inference and write the annotated model back in place.
    save_model(
        shape_inference.infer_shapes(
            load_model(export_path),
            check_type=True,
            strict_mode=True,
            data_prop=True
        ),
        export_path
    )

    # Run onnxruntime's symbolic shape inference for shapes ONNX cannot resolve statically.
    save_model(
        SymbolicShapeInference.infer_shapes(
            load_model(export_path),
            auto_merge=True,
            guess_output_rank=True
        ),
        export_path,
    )

    model_onnx = onnx.load(export_path)  # load onnx model
    onnx.checker.check_model(model_onnx)  # check onnx model

    # Simplify the graph and overwrite the exported file.
    model_onnx, check = onnxsim.simplify(model_onnx)
    assert check, "onnxsim simplification check failed"
    onnx.save(model_onnx, export_path)


if __name__ == '__main__':
    args = parse_args()

    create_onnx_export(args)
    print(f'ONNX file successfully created at {args.export_path}')
    check_onnx_export(args.export_path)
    print(f'ONNX file at {args.export_path}: shapes verified and graph simplified')

276 changes: 276 additions & 0 deletions notebooks/colorization_pipeline_onnxruntime.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion requirements.txt
@@ -13,4 +13,6 @@ tqdm==4.65.0
wandb==0.15.5
scikit-image==0.22.0
tensorboard
huggingface_hub
ipykernel
matplotlib
