From 2fc0036e4e1d661be11755e8111911071dd690b4 Mon Sep 17 00:00:00 2001 From: cdelteil Date: Sat, 20 Apr 2024 16:12:29 +0200 Subject: [PATCH] feat: clean code, add CI, pre-commit --- .github/workflows/ci.yml | 18 + .idea/.gitignore | 8 + .idea/dictionaries/cdelteil.xml | 3 + .idea/image-video-colorization.iml | 12 + .idea/inspectionProfiles/Project_Default.xml | 86 +++++ .../inspectionProfiles/profiles_settings.xml | 7 + .idea/misc.xml | 7 + .idea/modules.xml | 8 + .pre-commit-config.yaml | 24 ++ "01_\360\237\223\274_Upload_Video_File.py" | 45 +-- README.md | 12 + __init__.py | 0 models/__init__.py | 0 models/deep_colorization/__init__.py | 2 + .../deep_colorization/colorizers/__init__.py | 10 +- .../colorizers/base_color.py | 30 +- models/deep_colorization/colorizers/eccv16.py | 246 ++++++++++---- .../colorizers/siggraph17.py | 310 +++++++++++++----- models/deep_colorization/colorizers/util.py | 63 ++-- ...02_\360\237\216\245_Input_Youtube_Link.py" | 53 +-- ...0\237\226\274\357\270\217_Input_Images.py" | 55 ++-- pages/__init__.py | 0 setup.cfg | 12 + utils.py | 77 ++++- 24 files changed, 802 insertions(+), 286 deletions(-) create mode 100644 .github/workflows/ci.yml create mode 100644 .idea/.gitignore create mode 100644 .idea/dictionaries/cdelteil.xml create mode 100644 .idea/image-video-colorization.iml create mode 100644 .idea/inspectionProfiles/Project_Default.xml create mode 100644 .idea/inspectionProfiles/profiles_settings.xml create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 .pre-commit-config.yaml create mode 100644 __init__.py create mode 100644 models/__init__.py create mode 100644 models/deep_colorization/__init__.py create mode 100644 pages/__init__.py create mode 100644 setup.cfg diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..7b4c05b --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,18 @@ +name: CI + +on: [push, pull_request] + +jobs: + pre-commit: + runs-on: ubuntu-latest + name: Does the code respect Python standards?
+ steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.10.0' + - name: Install pre-commit + run: pip install pre-commit + - name: Run pre-commit + run: pre-commit run --all-files \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/dictionaries/cdelteil.xml b/.idea/dictionaries/cdelteil.xml new file mode 100644 index 0000000..e72f0f2 --- /dev/null +++ b/.idea/dictionaries/cdelteil.xml @@ -0,0 +1,3 @@ + + + \ No newline at end of file diff --git a/.idea/image-video-colorization.iml b/.idea/image-video-colorization.iml new file mode 100644 index 0000000..f5ea913 --- /dev/null +++ b/.idea/image-video-colorization.iml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..ebe8626 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,86 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..dd4c951 --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,7 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..8999ef5 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..7785f95 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..e5049b1 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,24 @@ +repos: + - repo: https://github.com/psf/black + rev: 24.4.0 + hooks: + - id: black + args: ['--line-length=120', '--verbose'] + exclude: '^models/' + + - repo: https://github.com/pycqa/flake8 + rev: '7.0.0' + hooks: + - id: flake8 + exclude: '^models/' + + - repo: https://github.com/pre-commit/mirrors-pylint + rev: v3.0.0a5 + hooks: + - id: pylint + name: pylint + entry: pylint + language: system + args: ['.', '--rcfile=setup.cfg', '--fail-under=8'] + exclude: '^models/' + types: [python] \ No newline at end of file diff --git "a/01_\360\237\223\274_Upload_Video_File.py" "b/01_\360\237\223\274_Upload_Video_File.py" index 22a84e3..ef014ec 100644 --- "a/01_\360\237\223\274_Upload_Video_File.py" +++ "b/01_\360\237\223\274_Upload_Video_File.py" @@ -6,44 +6,46 @@ import moviepy.editor as mp import numpy as np import streamlit as st -from streamlit_lottie import st_lottie -from tqdm import tqdm - -from models.deep_colorization.colorizers import eccv16 -from utils import load_lottieurl, format_time, colorize_frame, change_model -st.set_page_config(page_title="Image & Video Colorizer", page_icon="🎨", layout="wide") +from tqdm import tqdm -loaded_model = eccv16(pretrained=True).eval() -current_model = "None" +from utils import format_time, colorize_frame, change_model, load_model, setup_columns, set_page_config -col1, col2 = st.columns([1, 3]) -with col1: - lottie = 
load_lottieurl("https://assets5.lottiefiles.com/packages/lf20_RHdEuzVfEL.json") - st_lottie(lottie) +set_page_config() +loaded_model = load_model() +col2 = setup_columns() +current_model = None with col2: - st.write(""" + st.write( + """ ## B&W Videos Colorizer ##### Upload a black and white video and get a colorized version of it. ###### ➠ This space is using CPU Basic so it might take a while to colorize a video. - ###### ➠ If you want more models and GPU available please support this space by donating.""") + ###### ➠ If you want more models and GPU available please support this space by donating.""" + ) def main(): + """ + Main function to run this page + """ model = st.selectbox( - "Select Model (Both models have their pros and cons, I recommend trying both and keeping the best for your task)", - ["ECCV16", "SIGGRAPH17"], index=0) + "Select Model (Both models have their pros and cons, I recommend trying both and keeping the best for your " + "task)", + ["ECCV16", "SIGGRAPH17"], + index=0, + ) loaded_model = change_model(current_model, model) st.write(f"Model is now {model}") - uploaded_file = st.file_uploader("Upload your video here...", type=['mp4', 'mov', 'avi', 'mkv']) + uploaded_file = st.file_uploader("Upload your video here...", type=["mp4", "mov", "avi", "mkv"]) if st.button("Colorize"): if uploaded_file is not None: file_extension = os.path.splitext(uploaded_file.name)[1].lower() - if file_extension in ['.mp4', '.avi', '.mov', '.mkv']: + if file_extension in [".mp4", ".avi", ".mov", ".mkv"]: # Save the video file to a temporary location temp_file = tempfile.NamedTemporaryFile(delete=False) temp_file.write(uploaded_file.read()) @@ -73,7 +75,7 @@ def main(): start_time = time.time() time_text = st.text("Time Remaining: ") # Initialize text value - for _ in tqdm(range(total_frames), unit='frame', desc="Progress"): + for _ in tqdm(range(total_frames), unit="frame", desc="Progress"): ret, frame = video.read() if not ret: break @@ -123,7 +125,7 @@ def main(): st.download_button( label="Download Colorized Video", data=open(converted_filename, "rb").read(), - file_name="colorized_video.mp4" + file_name="colorized_video.mp4", ) # Close and delete the temporary file after processing @@ -135,4 +137,5 @@ def main(): main() st.markdown( "###### Made with :heart: by [Clément Delteil](https://www.linkedin.com/in/clementdelteil/) [![this is an " - "image link](https://i.imgur.com/thJhzOO.png)](https://www.buymeacoffee.com/clementdelteil)") + "image link](https://i.imgur.com/thJhzOO.png)](https://www.buymeacoffee.com/clementdelteil)" + ) diff --git a/README.md b/README.md index 3236a40..99caedf 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,18 @@ The following features are available: ## Interface +## Running Locally +If you want to process longer videos and you're limited by the Hugging Face Space's memory limits, you can run this app locally. + +FFmpeg is required to run this app; you can install it with `brew install ffmpeg` (or via your platform's package manager) and update the `IMAGEIO_FFMPEG_EXE` environment variable accordingly. + +```bash +git clone https://github.com/Wazzabeee/image-video-colorization +cd image-video-colorization +pip install -r requirements.txt +streamlit run 01_📼_Upload_Video_File.py +``` + ## Todos Other models based on GANs will probably be implemented in the future if my application for a community grant to gain access to a GPU on Hugging Face is successful.
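The README hunk above asks readers to update `IMAGEIO_FFMPEG_EXE` without showing the command. A minimal sketch of that setup follows; it is only an illustration under the assumption that `ffmpeg` ends up on your `PATH` after installation:

```bash
# Install ffmpeg (macOS example; use your platform's package manager elsewhere)
brew install ffmpeg

# Point imageio/moviepy at the installed binary before launching the app.
# Resolving the path via `which` avoids hardcoding a machine-specific location.
export IMAGEIO_FFMPEG_EXE="$(which ffmpeg)"
streamlit run 01_📼_Upload_Video_File.py
```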
diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/deep_colorization/__init__.py b/models/deep_colorization/__init__.py new file mode 100644 index 0000000..fa36340 --- /dev/null +++ b/models/deep_colorization/__init__.py @@ -0,0 +1,2 @@ +# inside models/deep_colorization/__init__.py +from .colorizers import eccv16, siggraph17, load_img, preprocess_img, postprocess_tens diff --git a/models/deep_colorization/colorizers/__init__.py b/models/deep_colorization/colorizers/__init__.py index 058dfb3..5d2a7b5 100644 --- a/models/deep_colorization/colorizers/__init__.py +++ b/models/deep_colorization/colorizers/__init__.py @@ -1,6 +1,4 @@ - -from .base_color import * -from .eccv16 import * -from .siggraph17 import * -from .util import * - +from .base_color import BaseColor +from .eccv16 import ECCVGenerator, eccv16 +from .siggraph17 import SIGGRAPHGenerator, siggraph17 +from .util import load_img, resize_img, preprocess_img, postprocess_tens diff --git a/models/deep_colorization/colorizers/base_color.py b/models/deep_colorization/colorizers/base_color.py index 00beb39..b1f1137 100644 --- a/models/deep_colorization/colorizers/base_color.py +++ b/models/deep_colorization/colorizers/base_color.py @@ -1,24 +1,22 @@ - -import torch from torch import nn -class BaseColor(nn.Module): - def __init__(self): - super(BaseColor, self).__init__() - self.l_cent = 50. - self.l_norm = 100. - self.ab_norm = 110. +class BaseColor(nn.Module): + def __init__(self): + super(BaseColor, self).__init__() - def normalize_l(self, in_l): - return (in_l-self.l_cent)/self.l_norm + self.l_cent = 50.0 + self.l_norm = 100.0 + self.ab_norm = 110.0 - def unnormalize_l(self, in_l): - return in_l*self.l_norm + self.l_cent + def normalize_l(self, in_l): + return (in_l - self.l_cent) / self.l_norm - def normalize_ab(self, in_ab): - return in_ab/self.ab_norm + def unnormalize_l(self, in_l): + return in_l * self.l_norm + self.l_cent - def unnormalize_ab(self, in_ab): - return in_ab*self.ab_norm + def normalize_ab(self, in_ab): + return in_ab / self.ab_norm + def unnormalize_ab(self, in_ab): + return in_ab * self.ab_norm diff --git a/models/deep_colorization/colorizers/eccv16.py b/models/deep_colorization/colorizers/eccv16.py index 896ed47..cb4a35b 100644 --- a/models/deep_colorization/colorizers/eccv16.py +++ b/models/deep_colorization/colorizers/eccv16.py @@ -1,4 +1,3 @@ - import torch import torch.nn as nn import numpy as np @@ -6,70 +5,175 @@ from .base_color import * + class ECCVGenerator(BaseColor): def __init__(self, norm_layer=nn.BatchNorm2d): super(ECCVGenerator, self).__init__() - model1=[nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True),] - model1+=[nn.ReLU(True),] - model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),] - model1+=[nn.ReLU(True),] - model1+=[norm_layer(64),] - - model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),] - model2+=[nn.ReLU(True),] - model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True),] - model2+=[nn.ReLU(True),] - model2+=[norm_layer(128),] - - model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),] - model3+=[nn.ReLU(True),] - model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),] - model3+=[nn.ReLU(True),] - model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True),] - model3+=[nn.ReLU(True),] - 
model3+=[norm_layer(256),] - - model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),] - model4+=[nn.ReLU(True),] - model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] - model4+=[nn.ReLU(True),] - model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] - model4+=[nn.ReLU(True),] - model4+=[norm_layer(512),] - - model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] - model5+=[nn.ReLU(True),] - model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] - model5+=[nn.ReLU(True),] - model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] - model5+=[nn.ReLU(True),] - model5+=[norm_layer(512),] - - model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] - model6+=[nn.ReLU(True),] - model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] - model6+=[nn.ReLU(True),] - model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] - model6+=[nn.ReLU(True),] - model6+=[norm_layer(512),] - - model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] - model7+=[nn.ReLU(True),] - model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] - model7+=[nn.ReLU(True),] - model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] - model7+=[nn.ReLU(True),] - model7+=[norm_layer(512),] - - model8=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True),] - model8+=[nn.ReLU(True),] - model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),] - model8+=[nn.ReLU(True),] - model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),] - model8+=[nn.ReLU(True),] - - model8+=[nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True),] + model1 = [ + nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True), + ] + model1 += [ + nn.ReLU(True), + ] + model1 += [ + nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True), + ] + model1 += [ + nn.ReLU(True), + ] + model1 += [ + norm_layer(64), + ] + + model2 = [ + nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True), + ] + model2 += [ + nn.ReLU(True), + ] + model2 += [ + nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True), + ] + model2 += [ + nn.ReLU(True), + ] + model2 += [ + norm_layer(128), + ] + + model3 = [ + nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True), + ] + model3 += [ + nn.ReLU(True), + ] + model3 += [ + nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), + ] + model3 += [ + nn.ReLU(True), + ] + model3 += [ + nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True), + ] + model3 += [ + nn.ReLU(True), + ] + model3 += [ + norm_layer(256), + ] + + model4 = [ + nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True), + ] + model4 += [ + nn.ReLU(True), + ] + model4 += [ + nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), + ] + model4 += [ + nn.ReLU(True), + ] + model4 += [ + nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), + ] + model4 += [ + nn.ReLU(True), + ] + model4 += [ + norm_layer(512), + ] + + model5 = [ + nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), + ] + model5 += [ + nn.ReLU(True), + ] + model5 += [ + nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), + ] + model5 += [ + 
nn.ReLU(True), + ] + model5 += [ + nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), + ] + model5 += [ + nn.ReLU(True), + ] + model5 += [ + norm_layer(512), + ] + + model6 = [ + nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), + ] + model6 += [ + nn.ReLU(True), + ] + model6 += [ + nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), + ] + model6 += [ + nn.ReLU(True), + ] + model6 += [ + nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), + ] + model6 += [ + nn.ReLU(True), + ] + model6 += [ + norm_layer(512), + ] + + model7 = [ + nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), + ] + model7 += [ + nn.ReLU(True), + ] + model7 += [ + nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), + ] + model7 += [ + nn.ReLU(True), + ] + model7 += [ + nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), + ] + model7 += [ + nn.ReLU(True), + ] + model7 += [ + norm_layer(512), + ] + + model8 = [ + nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True), + ] + model8 += [ + nn.ReLU(True), + ] + model8 += [ + nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), + ] + model8 += [ + nn.ReLU(True), + ] + model8 += [ + nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), + ] + model8 += [ + nn.ReLU(True), + ] + + model8 += [ + nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True), + ] self.model1 = nn.Sequential(*model1) self.model2 = nn.Sequential(*model2) @@ -82,7 +186,7 @@ def __init__(self, norm_layer=nn.BatchNorm2d): self.softmax = nn.Softmax(dim=1) self.model_out = nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False) - self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear') + self.upsample4 = nn.Upsample(scale_factor=4, mode="bilinear") def forward(self, input_l): conv1_2 = self.model1(self.normalize_l(input_l)) @@ -97,9 +201,17 @@ def forward(self, input_l): return self.unnormalize_ab(self.upsample4(out_reg)) + def eccv16(pretrained=True): - model = ECCVGenerator() - if(pretrained): - import torch.utils.model_zoo as model_zoo - model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/colorization_release_v2-9b330a0b.pth',map_location='cpu',check_hash=True)) - return model + model = ECCVGenerator() + if pretrained: + import torch.utils.model_zoo as model_zoo + + model.load_state_dict( + model_zoo.load_url( + "https://colorizers.s3.us-east-2.amazonaws.com/colorization_release_v2-9b330a0b.pth", + map_location="cpu", + check_hash=True, + ) + ) + return model diff --git a/models/deep_colorization/colorizers/siggraph17.py b/models/deep_colorization/colorizers/siggraph17.py index 775a23f..a89133d 100644 --- a/models/deep_colorization/colorizers/siggraph17.py +++ b/models/deep_colorization/colorizers/siggraph17.py @@ -3,108 +3,239 @@ from .base_color import * + class SIGGRAPHGenerator(BaseColor): def __init__(self, norm_layer=nn.BatchNorm2d, classes=529): super(SIGGRAPHGenerator, self).__init__() # Conv1 - model1=[nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=True),] - model1+=[nn.ReLU(True),] - model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True),] - model1+=[nn.ReLU(True),] - model1+=[norm_layer(64),] + model1 = [ + nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=True), + ] + model1 += [ + nn.ReLU(True), + ] + model1 += [ + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, 
bias=True), + ] + model1 += [ + nn.ReLU(True), + ] + model1 += [ + norm_layer(64), + ] # add a subsampling operation # Conv2 - model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),] - model2+=[nn.ReLU(True),] - model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),] - model2+=[nn.ReLU(True),] - model2+=[norm_layer(128),] + model2 = [ + nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True), + ] + model2 += [ + nn.ReLU(True), + ] + model2 += [ + nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True), + ] + model2 += [ + nn.ReLU(True), + ] + model2 += [ + norm_layer(128), + ] # add a subsampling layer operation # Conv3 - model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),] - model3+=[nn.ReLU(True),] - model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),] - model3+=[nn.ReLU(True),] - model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),] - model3+=[nn.ReLU(True),] - model3+=[norm_layer(256),] + model3 = [ + nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True), + ] + model3 += [ + nn.ReLU(True), + ] + model3 += [ + nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), + ] + model3 += [ + nn.ReLU(True), + ] + model3 += [ + nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), + ] + model3 += [ + nn.ReLU(True), + ] + model3 += [ + norm_layer(256), + ] # add a subsampling layer operation # Conv4 - model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),] - model4+=[nn.ReLU(True),] - model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] - model4+=[nn.ReLU(True),] - model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] - model4+=[nn.ReLU(True),] - model4+=[norm_layer(512),] + model4 = [ + nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True), + ] + model4 += [ + nn.ReLU(True), + ] + model4 += [ + nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), + ] + model4 += [ + nn.ReLU(True), + ] + model4 += [ + nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), + ] + model4 += [ + nn.ReLU(True), + ] + model4 += [ + norm_layer(512), + ] # Conv5 - model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] - model5+=[nn.ReLU(True),] - model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] - model5+=[nn.ReLU(True),] - model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] - model5+=[nn.ReLU(True),] - model5+=[norm_layer(512),] + model5 = [ + nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), + ] + model5 += [ + nn.ReLU(True), + ] + model5 += [ + nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), + ] + model5 += [ + nn.ReLU(True), + ] + model5 += [ + nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), + ] + model5 += [ + nn.ReLU(True), + ] + model5 += [ + norm_layer(512), + ] # Conv6 - model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] - model6+=[nn.ReLU(True),] - model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] - model6+=[nn.ReLU(True),] - model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] - model6+=[nn.ReLU(True),] - model6+=[norm_layer(512),] + model6 = [ + nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, 
bias=True), + ] + model6 += [ + nn.ReLU(True), + ] + model6 += [ + nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), + ] + model6 += [ + nn.ReLU(True), + ] + model6 += [ + nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), + ] + model6 += [ + nn.ReLU(True), + ] + model6 += [ + norm_layer(512), + ] # Conv7 - model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] - model7+=[nn.ReLU(True),] - model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] - model7+=[nn.ReLU(True),] - model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] - model7+=[nn.ReLU(True),] - model7+=[norm_layer(512),] + model7 = [ + nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), + ] + model7 += [ + nn.ReLU(True), + ] + model7 += [ + nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), + ] + model7 += [ + nn.ReLU(True), + ] + model7 += [ + nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), + ] + model7 += [ + nn.ReLU(True), + ] + model7 += [ + norm_layer(512), + ] # Conv7 - model8up=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True)] - model3short8=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),] + model8up = [nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True)] + model3short8 = [ + nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), + ] - model8=[nn.ReLU(True),] - model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),] - model8+=[nn.ReLU(True),] - model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),] - model8+=[nn.ReLU(True),] - model8+=[norm_layer(256),] + model8 = [ + nn.ReLU(True), + ] + model8 += [ + nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), + ] + model8 += [ + nn.ReLU(True), + ] + model8 += [ + nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), + ] + model8 += [ + nn.ReLU(True), + ] + model8 += [ + norm_layer(256), + ] # Conv9 - model9up=[nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),] - model2short9=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),] - # add the two feature maps above + model9up = [ + nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True), + ] + model2short9 = [ + nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True), + ] + # add the two feature maps above - model9=[nn.ReLU(True),] - model9+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),] - model9+=[nn.ReLU(True),] - model9+=[norm_layer(128),] + model9 = [ + nn.ReLU(True), + ] + model9 += [ + nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True), + ] + model9 += [ + nn.ReLU(True), + ] + model9 += [ + norm_layer(128), + ] # Conv10 - model10up=[nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True),] - model1short10=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),] + model10up = [ + nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True), + ] + model1short10 = [ + nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True), + ] # add the two feature maps above - model10=[nn.ReLU(True),] - model10+=[nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=True),] - model10+=[nn.LeakyReLU(negative_slope=.2),] + model10 = [ + nn.ReLU(True), + ] + model10 += [ + nn.Conv2d(128, 128, kernel_size=3, dilation=1, 
stride=1, padding=1, bias=True), + ] + model10 += [ + nn.LeakyReLU(negative_slope=0.2), + ] # classification output - model_class=[nn.Conv2d(256, classes, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),] + model_class = [ + nn.Conv2d(256, classes, kernel_size=1, padding=0, dilation=1, stride=1, bias=True), + ] # regression output - model_out=[nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),] - model_out+=[nn.Tanh()] + model_out = [ + nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=True), + ] + model_out += [nn.Tanh()] self.model1 = nn.Sequential(*model1) self.model2 = nn.Sequential(*model2) @@ -126,19 +257,27 @@ def __init__(self, norm_layer=nn.BatchNorm2d, classes=529): self.model_class = nn.Sequential(*model_class) self.model_out = nn.Sequential(*model_out) - self.upsample4 = nn.Sequential(*[nn.Upsample(scale_factor=4, mode='bilinear'),]) - self.softmax = nn.Sequential(*[nn.Softmax(dim=1),]) + self.upsample4 = nn.Sequential( + *[ + nn.Upsample(scale_factor=4, mode="bilinear"), + ] + ) + self.softmax = nn.Sequential( + *[ + nn.Softmax(dim=1), + ] + ) - def forward(self, input_A, input_B=None, mask_B=None): - if(input_B is None): - input_B = torch.cat((input_A*0, input_A*0), dim=1) - if(mask_B is None): - mask_B = input_A*0 + def forward(self, input_a, input_b=None, mask_b=None): + if input_b is None: + input_b = torch.cat((input_a * 0, input_a * 0), dim=1) + if mask_b is None: + mask_b = input_a * 0 - conv1_2 = self.model1(torch.cat((self.normalize_l(input_A),self.normalize_ab(input_B),mask_B),dim=1)) - conv2_2 = self.model2(conv1_2[:,:,::2,::2]) - conv3_3 = self.model3(conv2_2[:,:,::2,::2]) - conv4_3 = self.model4(conv3_3[:,:,::2,::2]) + conv1_2 = self.model1(torch.cat((self.normalize_l(input_a), self.normalize_ab(input_b), mask_b), dim=1)) + conv2_2 = self.model2(conv1_2[:, :, ::2, ::2]) + conv3_3 = self.model3(conv2_2[:, :, ::2, ::2]) + conv4_3 = self.model4(conv3_3[:, :, ::2, ::2]) conv5_3 = self.model5(conv4_3) conv6_3 = self.model6(conv5_3) conv7_3 = self.model7(conv6_3) @@ -159,10 +298,17 @@ def forward(self, input_A, input_B=None, mask_B=None): return self.unnormalize_ab(out_reg) + def siggraph17(pretrained=True): model = SIGGRAPHGenerator() - if(pretrained): + if pretrained: import torch.utils.model_zoo as model_zoo - model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth',map_location='cpu',check_hash=True)) - return model + model.load_state_dict( + model_zoo.load_url( + "https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth", + map_location="cpu", + check_hash=True, + ) + ) + return model diff --git a/models/deep_colorization/colorizers/util.py b/models/deep_colorization/colorizers/util.py index 79968ba..e479d2d 100644 --- a/models/deep_colorization/colorizers/util.py +++ b/models/deep_colorization/colorizers/util.py @@ -1,4 +1,3 @@ - from PIL import Image import numpy as np from skimage import color @@ -6,42 +5,46 @@ import torch.nn.functional as F from IPython import embed + def load_img(img_path): - out_np = np.asarray(Image.open(img_path)) - if(out_np.ndim==2): - out_np = np.tile(out_np[:,:,None],3) - return out_np + out_np = np.asarray(Image.open(img_path)) + if out_np.ndim == 2: + out_np = np.tile(out_np[:, :, None], 3) + return out_np + + +def resize_img(img, HW=(256, 256), resample=3): + return np.asarray(Image.fromarray(img).resize((HW[1], HW[0]), resample=resample)) + + +def preprocess_img(img_rgb_orig, HW=(256, 256), resample=3): + 
# return original size L and resized L as torch Tensors + img_rgb_rs = resize_img(img_rgb_orig, HW=HW, resample=resample) -def resize_img(img, HW=(256,256), resample=3): - return np.asarray(Image.fromarray(img).resize((HW[1],HW[0]), resample=resample)) + img_lab_orig = color.rgb2lab(img_rgb_orig) + img_lab_rs = color.rgb2lab(img_rgb_rs) -def preprocess_img(img_rgb_orig, HW=(256,256), resample=3): - # return original size L and resized L as torch Tensors - img_rgb_rs = resize_img(img_rgb_orig, HW=HW, resample=resample) - - img_lab_orig = color.rgb2lab(img_rgb_orig) - img_lab_rs = color.rgb2lab(img_rgb_rs) + img_l_orig = img_lab_orig[:, :, 0] + img_l_rs = img_lab_rs[:, :, 0] - img_l_orig = img_lab_orig[:,:,0] - img_l_rs = img_lab_rs[:,:,0] + tens_orig_l = torch.Tensor(img_l_orig)[None, None, :, :] + tens_rs_l = torch.Tensor(img_l_rs)[None, None, :, :] - tens_orig_l = torch.Tensor(img_l_orig)[None,None,:,:] - tens_rs_l = torch.Tensor(img_l_rs)[None,None,:,:] + return tens_orig_l, tens_rs_l - return (tens_orig_l, tens_rs_l) -def postprocess_tens(tens_orig_l, out_ab, mode='bilinear'): - # tens_orig_l 1 x 1 x H_orig x W_orig - # out_ab 1 x 2 x H x W +def postprocess_tens(tens_orig_l, out_ab, mode="bilinear"): + # tens_orig_l 1 x 1 x H_orig x W_orig + # out_ab 1 x 2 x H x W - HW_orig = tens_orig_l.shape[2:] - HW = out_ab.shape[2:] + HW_orig = tens_orig_l.shape[2:] + HW = out_ab.shape[2:] - # call resize function if needed - if(HW_orig[0]!=HW[0] or HW_orig[1]!=HW[1]): - out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode='bilinear') - else: - out_ab_orig = out_ab + # call resize function if needed + if HW_orig[0] != HW[0] or HW_orig[1] != HW[1]: + out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode="bilinear") + else: + out_ab_orig = out_ab - out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1) - return color.lab2rgb(out_lab_orig.data.cpu().numpy()[0,...].transpose((1,2,0))) + out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1) + return color.lab2rgb(out_lab_orig.data.cpu().numpy()[0, ...].transpose((1, 2, 0))) diff --git "a/pages/02_\360\237\216\245_Input_Youtube_Link.py" "b/pages/02_\360\237\216\245_Input_Youtube_Link.py" index fd9ec96..1e144eb 100644 --- "a/pages/02_\360\237\216\245_Input_Youtube_Link.py" +++ "b/pages/02_\360\237\216\245_Input_Youtube_Link.py" @@ -5,44 +5,52 @@ import numpy as np import streamlit as st from pytube import YouTube -from streamlit_lottie import st_lottie from tqdm import tqdm -from models.deep_colorization.colorizers import eccv16 -from utils import colorize_frame, format_time -from utils import load_lottieurl, change_model -st.set_page_config(page_title="Image & Video Colorizer", page_icon="🎨", layout="wide") +from utils import format_time, colorize_frame, change_model, load_model, setup_columns, set_page_config - -loaded_model = eccv16(pretrained=True).eval() -current_model = "None" - - -col1, col2 = st.columns([1, 3]) -with col1: - lottie = load_lottieurl("https://assets5.lottiefiles.com/packages/lf20_RHdEuzVfEL.json") - st_lottie(lottie) +set_page_config() +loaded_model = load_model() +col2 = setup_columns() +current_model = None with col2: - st.write(""" + st.write( + """ ## B&W Videos Colorizer ##### Input a YouTube black and white video link and get a colorized version of it. ###### ➠ This space is using CPU Basic so it might take a while to colorize a video. 
- ###### ➠ If you want more models and GPU available please support this space by donating.""") + ###### ➠ If you want more models and GPU available please support this space by donating.""" + ) @st.cache_data() def download_video(link): + """ + Download video from YouTube + """ yt = YouTube(link) - video = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first().download(filename="video.mp4") + video = ( + yt.streams.filter(progressive=True, file_extension="mp4") + .order_by("resolution") + .desc() + .first() + .download(filename="video.mp4") + ) return video def main(): + """ + Main function + """ model = st.selectbox( - "Select Model (Both models have their pros and cons, I recommend trying both and keeping the best for you task)", - ["ECCV16", "SIGGRAPH17"], index=0) + "Select Model (Both models have their pros and cons, " + "I recommend trying both and keeping the best for your task)", + ["ECCV16", "SIGGRAPH17"], + index=0, + ) loaded_model = change_model(current_model, model) st.write(f"Model is now {model}") @@ -71,7 +79,7 @@ def main(): start_time = time.time() time_text = st.text("Time Remaining: ") # Initialize text value - for _ in tqdm(range(total_frames), unit='frame', desc="Progress"): + for _ in tqdm(range(total_frames), unit="frame", desc="Progress"): ret, frame = video.read() if not ret: break @@ -121,7 +129,7 @@ def main(): st.download_button( label="Download Colorized Video", data=open(converted_filename, "rb").read(), - file_name="colorized_video.mp4" + file_name="colorized_video.mp4", ) # Close and delete the temporary file after processing @@ -132,4 +140,5 @@ def main(): main() st.markdown( "###### Made with :heart: by [Clément Delteil](https://www.linkedin.com/in/clementdelteil/) [![this is an " - "image link](https://i.imgur.com/thJhzOO.png)](https://www.buymeacoffee.com/clementdelteil)") + "image link](https://i.imgur.com/thJhzOO.png)](https://www.buymeacoffee.com/clementdelteil)" + ) diff --git "a/pages/03_\360\237\226\274\357\270\217_Input_Images.py" "b/pages/03_\360\237\226\274\357\270\217_Input_Images.py" index 30c5191..d970291 100644 --- "a/pages/03_\360\237\226\274\357\270\217_Input_Images.py" +++ "b/pages/03_\360\237\226\274\357\270\217_Input_Images.py" @@ -3,47 +3,47 @@ import streamlit as st from PIL import Image -from streamlit_lottie import st_lottie -from models.deep_colorization.colorizers import eccv16 -from utils import colorize_image, change_model, load_lottieurl +from utils import change_model, load_model, setup_columns, set_page_config, colorize_image -st.set_page_config(page_title="Image & Video Colorizer", page_icon="🎨", layout="wide") - - -loaded_model = eccv16(pretrained=True).eval() -current_model = "None" - - -col1, col2 = st.columns([1, 3]) -with col1: - lottie = load_lottieurl("https://assets5.lottiefiles.com/packages/lf20_RHdEuzVfEL.json") - st_lottie(lottie) +set_page_config() +loaded_model = load_model() +col2 = setup_columns() +current_model = None with col2: - st.write(""" + st.write( + """ ## B&W Images Colorizer ##### Input a black and white image and get a colorized version of it. ###### ➠ If you want to colorize multiple images just upload them all at once. ###### ➠ Uploading already colored images won't raise errors but images won't look good.
- ###### ➠ I recommend starting with the first model and then experimenting with the second one.""") + ###### ➠ I recommend starting with the first model and then experimenting with the second one.""" + ) def main(): + """ + Main function + """ model = st.selectbox( - "Select Model (Both models have their pros and cons, I recommend trying both and keeping the best for you task)", - ["ECCV16", "SIGGRAPH17"], index=0) + "Select Model (Both models have their pros and cons, " + "I recommend trying both and keeping the best for your task)", + ["ECCV16", "SIGGRAPH17"], + index=0, + ) # Make the user select a model loaded_model = change_model(current_model, model) st.write(f"Model is now {model}") # Ask the user if he wants to see colorization - display_results = st.checkbox('Display results in real time', value=True) + display_results = st.checkbox("Display results in real time", value=True) # Input for the user to upload images - uploaded_file = st.file_uploader("Upload your images here...", type=['jpg', 'png', 'jpeg'], - accept_multiple_files=True) + uploaded_file = st.file_uploader( + "Upload your images here...", type=["jpg", "png", "jpeg"], accept_multiple_files=True + ) # If the user clicks on the button if st.button("Colorize"): @@ -56,11 +56,11 @@ def main(): with col2: st.markdown('<center>After</center>
', unsafe_allow_html=True) else: - col1, col2, col3 = st.columns(3) + col1, col2, _ = st.columns(3) for i, file in enumerate(uploaded_file): file_extension = os.path.splitext(file.name)[1].lower() - if file_extension in ['.jpg', '.png', '.jpeg']: + if file_extension in [".jpg", ".png", ".jpeg"]: image = Image.open(file) if display_results: with col1: @@ -68,12 +68,12 @@ def main(): with col2: with st.spinner("Colorizing image..."): out_img, new_img = colorize_image(file, loaded_model) - new_img.save("IMG_" + str(i+1) + ".jpg") + new_img.save("IMG_" + str(i + 1) + ".jpg") st.image(out_img, use_column_width="always") else: out_img, new_img = colorize_image(file, loaded_model) - new_img.save("IMG_" + str(i+1) + ".jpg") + new_img.save("IMG_" + str(i + 1) + ".jpg") if len(uploaded_file) > 1: # Create a zip file @@ -98,11 +98,12 @@ def main(): ) else: - st.warning('Upload a file', icon="⚠️") + st.warning("Upload a file", icon="⚠️") if __name__ == "__main__": main() st.markdown( "###### Made with :heart: by [Clément Delteil](https://www.linkedin.com/in/clementdelteil/) [![this is an " - "image link](https://i.imgur.com/thJhzOO.png)](https://www.buymeacoffee.com/clementdelteil)") + "image link](https://i.imgur.com/thJhzOO.png)](https://www.buymeacoffee.com/clementdelteil)" + ) diff --git a/pages/__init__.py b/pages/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..8285bb6 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,12 @@ +[pylint] +disable=C0103,C0114,R0913,R0914,C0200,C0301,E1101 + +[pep8] +max-line-length=120 +ignore=E121,E123,E126,E226,E24,E704,E203,W503 +exclude=venv,test_env,test_venv + +[flake8] +max-line-length=120 +ignore=E121,E123,E126,E226,E24,E704,E203,W503 +exclude=venv,test_env,test_venv \ No newline at end of file diff --git a/utils.py b/utils.py index 4274c2f..b5e4666 100644 --- a/utils.py +++ b/utils.py @@ -2,21 +2,62 @@ import requests import streamlit as st from PIL import Image +from streamlit_lottie import st_lottie -from models.deep_colorization.colorizers import postprocess_tens, preprocess_img, load_img, eccv16, siggraph17 +from models.deep_colorization import eccv16 +from models.deep_colorization import siggraph17 +from models.deep_colorization import postprocess_tens, preprocess_img, load_img + + +class SameModelException(ValueError): + """Exception raised when the same model is attempted to be reloaded.""" + + +def set_page_config(): + """ + Sets up the page config. + """ + st.set_page_config(page_title="Image & Video Colorizer", page_icon="🎨", layout="wide") + + +def load_model(): + """ + Loads the default model. + """ + return eccv16(pretrained=True).eval() + + +def setup_columns(): + """ + Sets up the columns. + """ + col1, col2 = st.columns([1, 3]) + lottie = load_lottieurl("https://assets5.lottiefiles.com/packages/lf20_RHdEuzVfEL.json") + with col1: + st_lottie(lottie) + return col2 # Define a function that we can use to load lottie files from a link. 
@st.cache_data() def load_lottieurl(url: str): - r = requests.get(url) - if r.status_code != 200: + """ + Load lottieurl image + """ + try: + r = requests.get(url, timeout=10) # Timeout set to 10 seconds + r.raise_for_status() # This will raise an exception for HTTP errors + return r.json() + except requests.RequestException as e: + print(f"Request failed: {e}") return None - return r.json() @st.cache_resource() def change_model(current_model, model): + """ + Change model + """ loaded_model = "None" if current_model != model: @@ -25,38 +66,44 @@ def change_model(current_model, model): elif model == "SIGGRAPH17": loaded_model = siggraph17(pretrained=True).eval() return loaded_model - else: - raise Exception("Model is the same as the current one.") + + raise SameModelException("Model is the same as the current one.") def format_time(seconds: float) -> str: """Formats time in seconds to a human readable format""" if seconds < 60: return f"{int(seconds)} seconds" - elif seconds < 3600: + if seconds < 3600: minutes = seconds // 60 seconds %= 60 return f"{minutes} minutes and {int(seconds)} seconds" - elif seconds < 86400: + if seconds < 86400: hours = seconds // 3600 minutes = (seconds % 3600) // 60 seconds %= 60 return f"{hours} hours, {minutes} minutes, and {int(seconds)} seconds" - else: - days = seconds // 86400 - hours = (seconds % 86400) // 3600 - minutes = (seconds % 3600) // 60 - seconds %= 60 - return f"{days} days, {hours} hours, {minutes} minutes, and {int(seconds)} seconds" + + days = seconds // 86400 + hours = (seconds % 86400) // 3600 + minutes = (seconds % 3600) // 60 + seconds %= 60 + return f"{days} days, {hours} hours, {minutes} minutes, and {int(seconds)} seconds" # Function to colorize video frames def colorize_frame(frame, colorizer) -> np.ndarray: + """ + Colorize frame + """ tens_l_orig, tens_l_rs = preprocess_img(frame, HW=(256, 256)) return postprocess_tens(tens_l_orig, colorizer(tens_l_rs).cpu()) def colorize_image(file, loaded_model): + """ + Colorize image + """ img = load_img(file) # If user input a colored image with 4 channels, discard the fourth channel if img.shape[2] == 4: @@ -66,4 +113,4 @@ def colorize_image(file, loaded_model): out_img = postprocess_tens(tens_l_orig, loaded_model(tens_l_rs).cpu()) new_img = Image.fromarray((out_img * 255).astype(np.uint8)) - return out_img, new_img \ No newline at end of file + return out_img, new_img
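For reviewers who want to exercise the tooling this patch introduces without waiting for CI, the hooks can be run locally. A quick sketch mirroring the workflow's `pre-commit run --all-files` step (assumes a working Python environment with pip available):

```bash
pip install pre-commit

# Optional: register the hooks so they run automatically on each commit
pre-commit install

# Run black, flake8, and pylint over the whole repo, as the CI job does
pre-commit run --all-files
```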