Initial commit
HolyWu committed Dec 2, 2023
1 parent 79ef316 commit fdcbce9
Showing 15 changed files with 1,726 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .gitattributes
@@ -0,0 +1,2 @@
# Auto detect text files and perform LF normalization
* text=auto
160 changes: 160 additions & 0 deletions .gitignore
@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
26 changes: 24 additions & 2 deletions README.md
@@ -1,2 +1,24 @@
# vs-ddcolor
DDColor function for VapourSynth
# DDColor
Towards Photo-Realistic Image Colorization via Dual Decoders, based on https://github.com/piddnad/DDColor.


## Dependencies
- [PyTorch](https://pytorch.org/get-started) 2.1.1 or later
- [VapourSynth](http://www.vapoursynth.com/) R62 or later


## Installation
```
pip install -U vsddcolor
python -m vsddcolor
```
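
Note that the `torch` wheel pulled in from PyPI may be a CPU-only build (e.g. on Windows), while this plugin requires CUDA. If needed, a CUDA-enabled build can be installed first, for example (an illustrative command for CUDA 12.1, in the spirit of the PyTorch get-started page linked above, not part of this repository's own instructions):
```
pip install torch --index-url https://download.pytorch.org/whl/cu121
```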


## Usage
```python
from vsddcolor import ddcolor

ret = ddcolor(clip)
```

See `__init__.py` for the description of the parameters.
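
For example, a hypothetical script might convert the source clip to RGBS (only RGBH and RGBS are accepted), pick the modelscope model, and use two CUDA streams. The parameter names come from the `ddcolor` signature in `__init__.py`; the stand-in source clip, the conversion matrix, and the chosen values are purely illustrative:

```python
import vapoursynth as vs

from vsddcolor import ddcolor

core = vs.core

# Stand-in source; in practice this would come from a real source filter.
src = core.std.BlankClip(width=640, height=480, format=vs.YUV420P8, length=24)

# ddcolor only accepts RGBH or RGBS input, so convert first (matrix value is illustrative).
rgb = core.resize.Bicubic(src, format=vs.RGBS, matrix_in_s="709")

# Illustrative parameter values; see __init__.py for their descriptions.
ret = ddcolor(rgb, device_index=0, num_streams=2, model=0, input_size=512)
ret.set_output()
```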
30 changes: 30 additions & 0 deletions pyproject.toml
@@ -0,0 +1,30 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "vsddcolor"
version = "1.0.0"
description = "DDColor function for VapourSynth"
readme = "README.md"
requires-python = ">=3.11"
authors = [{name = "HolyWu", email = "[email protected]"}]
keywords = ["DDColor", "VapourSynth"]
classifiers = [
    "Operating System :: OS Independent",
    "Programming Language :: Python :: 3 :: Only",
    "Topic :: Multimedia :: Video"
]
dependencies = [
    "kornia>=0.7.0",
    "numpy>=1.26.2",
    "requests>=2.31.0",
    "timm>=0.9.12",
    "torch>=2.1.1",
    "tqdm>=4.66.1",
    "VapourSynth>=62"
]

[project.urls]
"Homepage" = "https://github.com/HolyWu/vs-ddcolor"
"Bug Tracker" = "https://github.com/HolyWu/vs-ddcolor/issues"
119 changes: 119 additions & 0 deletions vsddcolor/__init__.py
@@ -0,0 +1,119 @@
from __future__ import annotations

import os
from threading import Lock

import kornia
import numpy as np
import torch
import torch.nn.functional as F
import vapoursynth as vs

from .ddcolor_arch import DDColor

__version__ = "1.0.0"

os.environ["CUDA_MODULE_LOADING"] = "LAZY"

model_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")


@torch.inference_mode()
def ddcolor(
    clip: vs.VideoNode, device_index: int | None = None, num_streams: int = 1, model: int = 1, input_size: int = 512
) -> vs.VideoNode:
    """Towards Photo-Realistic Image Colorization via Dual Decoders

    :param clip:            Clip to process. Only RGBH and RGBS formats are supported. RGBH uses the bfloat16
                            data type for inference while RGBS uses the float32 data type.
    :param device_index:    Device ordinal of the GPU.
    :param num_streams:     Number of CUDA streams to enqueue the kernels.
    :param model:           Model to use.
                            0 = ddcolor_modelscope
                            1 = ddcolor_artistic
    :param input_size:      Input size for model.
    """
    if not isinstance(clip, vs.VideoNode):
        raise vs.Error("ddcolor: this is not a clip")

    if clip.format.id not in [vs.RGBH, vs.RGBS]:
        raise vs.Error("ddcolor: only RGBH and RGBS formats are supported")

    if not torch.cuda.is_available():
        raise vs.Error("ddcolor: CUDA is not available")

    if num_streams < 1:
        raise vs.Error("ddcolor: num_streams must be at least 1")

    if model not in range(2):
        raise vs.Error("ddcolor: model must be 0 or 1")

    if os.path.getsize(os.path.join(model_dir, "ddcolor_artistic.pth")) == 0:
        raise vs.Error("ddcolor: model files have not been downloaded. run 'python -m vsddcolor' first")

    torch.set_float32_matmul_precision("high")

    device = torch.device("cuda", device_index)

    stream = [torch.cuda.Stream(device=device) for _ in range(num_streams)]
    stream_lock = [Lock() for _ in range(num_streams)]

    match model:
        case 0:
            model_name = "ddcolor_modelscope.pth"
        case 1:
            model_name = "ddcolor_artistic.pth"

    state_dict = torch.load(os.path.join(model_dir, model_name), map_location="cpu")["params"]

    module = DDColor(input_size=(input_size, input_size), num_output_channels=2, last_norm="Spectral", num_queries=100)
    module.load_state_dict(state_dict, strict=False)
    module.eval().to(device, memory_format=torch.channels_last)
    if clip.format.bits_per_sample == 16:
        module.bfloat16()

    index = -1
    index_lock = Lock()

    @torch.inference_mode()
    def inference(n: int, f: vs.VideoFrame) -> vs.VideoFrame:
        nonlocal index
        with index_lock:
            index = (index + 1) % num_streams
            local_index = index

        with stream_lock[local_index], torch.cuda.stream(stream[local_index]):
            # Convert the frame to a tensor and keep the original L (lightness) channel at full resolution.
            img = frame_to_tensor(f, device)
            orig_l = kornia.color.rgb_to_lab(img)[:, :1, :, :]

            # Downscale to the model's input size and build a grayscale RGB image from the L channel alone.
            img = F.interpolate(img, (input_size, input_size), mode="bilinear")
            img_l = kornia.color.rgb_to_lab(img)[:, :1, :, :]
            img_gray_lab = torch.cat([img_l, torch.zeros_like(img_l), torch.zeros_like(img_l)], dim=1)
            img_gray_rgb = kornia.color.lab_to_rgb(img_gray_lab)

            # The network predicts the a/b chroma channels for the grayscale input.
            output_ab = module(img_gray_rgb)

            # Upscale the predicted chroma to the clip's resolution and recombine with the original L channel.
            output_ab_resize = F.interpolate(output_ab, (clip.height, clip.width), mode="bilinear")
            output_lab = torch.cat([orig_l, output_ab_resize], dim=1)
            output = kornia.color.lab_to_rgb(output_lab)

            return tensor_to_frame(output, f.copy())

    return clip.std.FrameEval(lambda n: clip.std.ModifyFrame(clip, inference), clip_src=clip)


def frame_to_tensor(frame: vs.VideoFrame, device: torch.device) -> torch.Tensor:
    array = np.stack([np.asarray(frame[plane]) for plane in range(frame.format.num_planes)])
    tensor = torch.from_numpy(array).unsqueeze(0).to(device, memory_format=torch.channels_last)
    if tensor.dtype == torch.half:
        tensor = tensor.bfloat16()
    return tensor.clamp(0.0, 1.0)


def tensor_to_frame(tensor: torch.Tensor, frame: vs.VideoFrame) -> vs.VideoFrame:
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.half()
    array = tensor.squeeze(0).detach().cpu().numpy()
    for plane in range(frame.format.num_planes):
        np.copyto(np.asarray(frame[plane]), array[plane, :, :])
    return frame
28 changes: 28 additions & 0 deletions vsddcolor/__main__.py
@@ -0,0 +1,28 @@
import os

import requests
from tqdm import tqdm


def download_model(url: str) -> None:
    filename = url.split("/")[-1]
    r = requests.get(url, stream=True)
    with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", filename), "wb") as f:
        with tqdm(
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
            miniters=1,
            desc=filename,
            total=int(r.headers.get("content-length", 0)),
        ) as pbar:
            for chunk in r.iter_content(chunk_size=4096):
                f.write(chunk)
                pbar.update(len(chunk))


if __name__ == "__main__":
    url = "https://github.com/HolyWu/vs-ddcolor/releases/download/model/"
    models = ["ddcolor_artistic", "ddcolor_modelscope"]
    for model in models:
        download_model(url + model + ".pth")