diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000..dfe0770
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,2 @@
+# Auto detect text files and perform LF normalization
+* text=auto
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..68bc17f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,160 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+#   For a library or package, you might want to ignore these files since the code is
+#   intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+#   This is especially recommended for binary packages to ensure reproducibility, and is more
+#   commonly ignored for libraries.
+#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+#   in version control.
+#   https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+#  and can be added to the global gitignore or merged into this file.  For a more nuclear
+#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..44bf750
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,35 @@
+S-Lab License 1.0
+
+Copyright 2022 S-Lab
+
+Redistribution and use for non-commercial purpose in source and 
+binary forms, with or without modification, are permitted provided 
+that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright 
+   notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright 
+   notice, this list of conditions and the following disclaimer in 
+   the documentation and/or other materials provided with the 
+   distribution.
+
+3. Neither the name of the copyright holder nor the names of its 
+   contributors may be used to endorse or promote products derived 
+   from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+In the event that redistribution and/or use for commercial purpose in 
+source or binary forms, with or without modification is required, 
+please contact the contributor(s) of the work.
\ No newline at end of file
diff --git a/README.md b/README.md
index 7e1b777..021bf8e 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,26 @@
-# vs-codeformer
-CodeFormer function for VapourSynth
+# CodeFormer
+Towards Robust Blind Face Restoration with Codebook Lookup TransFormer, based on https://github.com/sczhou/CodeFormer.
+
+
+## Dependencies
+- [NumPy](https://numpy.org/install)
+- [OpenCV-Python](https://github.com/opencv/opencv-python)
+- [PyTorch](https://pytorch.org/get-started) 1.13.1
+- [VapourSynth](http://www.vapoursynth.com/) R55+
+
+
+## Installation
+```
+pip install -U vscodeformer
+python -m vscodeformer
+```
+
+
+## Usage
+```python
+from vscodeformer import codeformer
+
+ret = codeformer(clip)
+```
+
+See `__init__.py` for the description of the parameters.
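+
+A sketch of a more involved call, e.g. restoring faces at 2x with a separately upscaled background clip (the parameter values are illustrative and `clip` is assumed to be RGB24):
+
+```python
+from vscodeformer import codeformer
+
+# background upscaled to the final 2x resolution with any upscaler of your choice (bilinear here only as a placeholder)
+bg = clip.resize.Bilinear(clip.width * 2, clip.height * 2)
+
+ret = codeformer(clip, num_streams=2, upscale=2, detector=0, only_center_face=False, weight=0.5, bg_clip=bg)
+```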
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..fdf37a8
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,31 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "vscodeformer"
+version = "1.0.0"
+description = "CodeFormer function for VapourSynth"
+readme = "README.md"
+requires-python = ">=3.10"
+license = {file = "LICENSE"}
+authors = [{name = "HolyWu", email = "holywu@gmail.com"}]
+keywords = ["CodeFormer", "VapourSynth"]
+classifiers = [
+  "Operating System :: OS Independent",
+  "Programming Language :: Python :: 3.10",
+  "Topic :: Multimedia :: Video"
+]
+dependencies = [
+  "numpy>=1.24.2",
+  "opencv-python>=4.7.0.72",
+  "requests>=2.28.2",
+  "torch>=1.13.1",
+  "torchvision>=0.14.1",
+  "tqdm>=4.65.0",
+  "VapourSynth>=55"
+]
+
+[project.urls]
+"Homepage" = "https://github.com/HolyWu/vs-codeformer"
+"Bug Tracker" = "https://github.com/HolyWu/vs-codeformer/issues"
diff --git a/vscodeformer/__init__.py b/vscodeformer/__init__.py
new file mode 100644
index 0000000..de0fe4e
--- /dev/null
+++ b/vscodeformer/__init__.py
@@ -0,0 +1,180 @@
+from __future__ import annotations
+
+import os
+from threading import Lock
+
+import cv2
+import numpy as np
+import torch
+import vapoursynth as vs
+from torchvision.transforms.functional import normalize
+
+from .codeformer_arch import CodeFormer
+from .face_restoration_helper import FaceRestoreHelper
+from .img_util import img2tensor, tensor2img
+
+__version__ = "1.0.0"
+
+os.environ["CUDA_MODULE_LOADING"] = "LAZY"
+
+model_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+
+
+@torch.inference_mode()
+def codeformer(
+    clip: vs.VideoNode,
+    device_index: int | None = None,
+    num_streams: int = 1,
+    upscale: int = 2,
+    detector: int = 0,
+    only_center_face: bool = False,
+    weight: float = 0.5,
+    bg_clip: vs.VideoNode | None = None,
+) -> vs.VideoNode:
+    """Towards Robust Blind Face Restoration with Codebook Lookup TransFormer
+
+    :param clip:                Clip to process. Only RGB24 format is supported.
+    :param device_index:        Device ordinal of the GPU.
+    :param num_streams:         Number of CUDA streams to enqueue the kernels.
+    :param upscale:             Final upsampling scale.
+    :param detector:            Face detector.
+                                0 = retinaface_resnet50
+                                1 = dlib
+    :param only_center_face:    Only restore the center face.
+    :param weight:              Balance the quality and fidelity. Generally, smaller weight tends to produce a
+                                higher-quality result, while larger weight yields a higher-fidelity result.
+    :param bg_clip:             Background clip that has been upsampled to final scale. If None, bilinear will be used.
+    """
+    if not isinstance(clip, vs.VideoNode):
+        raise vs.Error("codeformer: this is not a clip")
+
+    if clip.format.id != vs.RGB24:
+        raise vs.Error("codeformer: only RGB24 format is supported")
+
+    if not torch.cuda.is_available():
+        raise vs.Error("codeformer: CUDA is not available")
+
+    if num_streams < 1:
+        raise vs.Error("codeformer: num_streams must be at least 1")
+
+    if num_streams > vs.core.num_threads:
+        raise vs.Error("codeformer: setting num_streams greater than `core.num_threads` is useless")
+
+    if upscale < 1:
+        raise vs.Error("codeformer: upscale must be at least 1")
+
+    if detector not in range(2):
+        raise vs.Error("codeformer: detector must be 0 or 1")
+
+    if weight < 0 or weight > 1:
+        raise vs.Error("codeformer: weight must be between 0.0 and 1.0 (inclusive)")
+
+    if bg_clip is not None:
+        if not isinstance(bg_clip, vs.VideoNode):
+            raise vs.Error("codeformer: bg_clip is not a clip")
+
+        if bg_clip.format.id != vs.RGB24:
+            raise vs.Error("codeformer: bg_clip must be of RGB24 format")
+
+        if bg_clip.width != clip.width * upscale or bg_clip.height != clip.height * upscale:
+            raise vs.Error("codeformer: dimensions of bg_clip must match final upsampling scale")
+
+        if bg_clip.num_frames != clip.num_frames:
+            raise vs.Error("codeformer: bg_clip must have the same number of frames as main clip")
+
+    if os.path.getsize(os.path.join(model_dir, "codeformer.pth")) == 0:
+        raise vs.Error("codeformer: model files have not been downloaded. run 'python -m vscodeformer' first")
+
+    torch.set_float32_matmul_precision("high")
+
+    device = torch.device("cuda", device_index)
+
+    stream = [torch.cuda.Stream(device=device) for _ in range(num_streams)]
+    stream_lock = [Lock() for _ in range(num_streams)]
+
+    model_path = os.path.join(model_dir, "codeformer.pth")
+
+    module = CodeFormer()
+    module.load_state_dict(torch.load(model_path, map_location="cpu")["params_ema"])
+    module.eval().to(device)
+
+    detection_model = "retinaface_resnet50" if detector == 0 else "dlib"
+    face_helper = [
+        FaceRestoreHelper(upscale, det_model=detection_model, use_parse=True, device=device) for _ in range(num_streams)
+    ]
+
+    index = -1
+    index_lock = Lock()
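+    # round-robin dispatch: each frame takes the next CUDA stream (and its dedicated FaceRestoreHelper) under index_lock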
+
+    @torch.inference_mode()
+    def inference(n: int, f: list[vs.VideoFrame]) -> vs.VideoFrame:
+        nonlocal index
+        with index_lock:
+            index = (index + 1) % num_streams
+            local_index = index
+
+        with stream_lock[local_index], torch.cuda.stream(stream[local_index]):
+            img = frame_to_ndarray(f[0])
+            bg_img = frame_to_ndarray(f[2]) if bg_clip is not None else None
+
+            face_helper[local_index].clean_all()
+            face_helper[local_index].read_image(img)
+            face_helper[local_index].get_face_landmarks_5(
+                only_center_face=only_center_face, resize=640, eye_dist_threshold=5
+            )
+            face_helper[local_index].align_warp_face()
+
+            for cropped_face in face_helper[local_index].cropped_faces:
+                cropped_face_t = img2tensor(cropped_face / 255.0)
+                normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+                cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
+                output = module(cropped_face_t, w=weight, adain=True)[0]
+                restored_face = tensor2img(output, min_max=(-1, 1))
+                face_helper[local_index].add_restored_face(restored_face, cropped_face)
+
+            face_helper[local_index].get_inverse_affine()
+            restored_img = face_helper[local_index].paste_faces_to_input_image(upsample_img=bg_img)
+            restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
+            return ndarray_to_frame(restored_img, f[1].copy())
+
+    pad_w = 512 - clip.width if clip.width < 512 else 0
+    pad_h = 512 - clip.height if clip.height < 512 else 0
+
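+    # pad clips smaller than 512x512 before processing; enlarging src_width/src_height makes resize.Point
+    # replicate the edge pixels, and the padding is cropped off again after upscaling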
+    if pad_w > 0 or pad_h > 0:
+        clip = clip.resize.Point(
+            clip.width + pad_w, clip.height + pad_h, src_width=clip.width + pad_w, src_height=clip.height + pad_h
+        )
+
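+    # BlankClip supplies frames at the output resolution; inference() writes the restored image into a copy of them (f[1].copy())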
+    new_clip = clip.std.BlankClip(width=clip.width * upscale, height=clip.height * upscale, keep=True)
+
+    if bg_clip is None:
+        ret = new_clip.std.FrameEval(
+            lambda n: new_clip.std.ModifyFrame([clip, new_clip], inference), clip_src=[clip, new_clip]
+        )
+    else:
+        bg_pad_w = new_clip.width - bg_clip.width
+        bg_pad_h = new_clip.height - bg_clip.height
+
+        if bg_pad_w > 0 or bg_pad_h > 0:
+            bg_clip = bg_clip.resize.Point(
+                bg_clip.width + bg_pad_w,
+                bg_clip.height + bg_pad_h,
+                src_width=bg_clip.width + bg_pad_w,
+                src_height=bg_clip.height + bg_pad_h,
+            )
+
+        ret = new_clip.std.FrameEval(
+            lambda n: new_clip.std.ModifyFrame([clip, new_clip, bg_clip], inference), clip_src=[clip, new_clip, bg_clip]
+        )
+
+    return ret.std.Crop(right=pad_w * upscale, bottom=pad_h * upscale)
+
+
+def frame_to_ndarray(frame: vs.VideoFrame) -> np.ndarray:
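+    # stack the R, G, B planes in reverse order so the resulting array is BGR, as expected by OpenCV and the face helper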
+    return np.stack([np.asarray(frame[plane]) for plane in range(frame.format.num_planes - 1, -1, -1)], axis=2)
+
+
+def ndarray_to_frame(array: np.ndarray, frame: vs.VideoFrame) -> vs.VideoFrame:
+    for plane in range(frame.format.num_planes):
+        np.copyto(np.asarray(frame[plane]), array[:, :, plane])
+    return frame
diff --git a/vscodeformer/__main__.py b/vscodeformer/__main__.py
new file mode 100644
index 0000000..3aec703
--- /dev/null
+++ b/vscodeformer/__main__.py
@@ -0,0 +1,34 @@
+import os
+
+import requests
+from tqdm import tqdm
+
+
+def download_model(url: str) -> None:
+    filename = url.split("/")[-1]
+    r = requests.get(url, stream=True)
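+    # stream the response to disk in 4 KiB chunks, updating the progress bar as data arrives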
+    with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", filename), "wb") as f:
+        with tqdm(
+            unit="B",
+            unit_scale=True,
+            unit_divisor=1024,
+            miniters=1,
+            desc=filename,
+            total=int(r.headers.get("content-length", 0)),
+        ) as pbar:
+            for chunk in r.iter_content(chunk_size=4096):
+                f.write(chunk)
+                pbar.update(len(chunk))
+
+
+if __name__ == "__main__":
+    url = "https://github.com/HolyWu/vs-codeformer/releases/download/model/"
+    models = [
+        "codeformer.pth",
+        "detection_Resnet50_Final.pth",
+        "mmod_human_face_detector-4cb19393.dat",
+        "parsing_parsenet.pth",
+        "shape_predictor_5_face_landmarks-c4b1e980.dat",
+    ]
+    for model in models:
+        download_model(url + model)
diff --git a/vscodeformer/align_trans.py b/vscodeformer/align_trans.py
new file mode 100644
index 0000000..07f1eb3
--- /dev/null
+++ b/vscodeformer/align_trans.py
@@ -0,0 +1,219 @@
+import cv2
+import numpy as np
+
+from .matlab_cp2tform import get_similarity_transform_for_cv2
+
+# reference facial points, a list of coordinates (x,y)
+REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278],
+                           [33.54930115, 92.3655014], [62.72990036, 92.20410156]]
+
+DEFAULT_CROP_SIZE = (96, 112)
+
+
+class FaceWarpException(Exception):
+
+    def __str__(self):
+        return 'In File {}:{}'.format(__file__, super().__str__())
+
+
+def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False):
+    """
+    Function:
+    ----------
+        get reference 5 key points according to crop settings:
+        0. Set default crop_size:
+            if default_square:
+                crop_size = (112, 112)
+            else:
+                crop_size = (96, 112)
+        1. Pad the crop_size by inner_padding_factor in each side;
+        2. Resize crop_size into (output_size - outer_padding*2),
+            pad into output_size with outer_padding;
+        3. Output reference_5point;
+    Parameters:
+    ----------
+        @output_size: (w, h) or None
+            size of aligned face image
+        @inner_padding_factor: (w_factor, h_factor)
+            padding factor for inner (w, h)
+        @outer_padding: (w_pad, h_pad)
+            each row is a pair of coordinates (x, y)
+        @default_square: True or False
+            if True:
+                default crop_size = (112, 112)
+            else:
+                default crop_size = (96, 112);
+        !!! make sure, if output_size is not None:
+                (output_size - outer_padding)
+                = some_scale * (default crop_size * (1.0 +
+                inner_padding_factor))
+    Returns:
+    ----------
+        @reference_5point: 5x2 np.array
+            each row is a pair of transformed coordinates (x, y)
+    """
+
+    tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
+    tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
+
+    # 0) make the inner region a square
+    if default_square:
+        size_diff = max(tmp_crop_size) - tmp_crop_size
+        tmp_5pts += size_diff / 2
+        tmp_crop_size += size_diff
+
+    if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]):
+
+        return tmp_5pts
+
+    if (inner_padding_factor == 0 and outer_padding == (0, 0)):
+        if output_size is None:
+            return tmp_5pts
+        else:
+            raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
+
+    # check output size
+    if not (0 <= inner_padding_factor <= 1.0):
+        raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
+
+    if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None):
+        output_size = (tmp_crop_size * (1 + inner_padding_factor * 2)).astype(np.int32)
+        output_size += np.array(outer_padding)
+    if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]):
+        raise FaceWarpException('Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])')
+
+    # 1) pad the inner region according inner_padding_factor
+    if inner_padding_factor > 0:
+        size_diff = tmp_crop_size * inner_padding_factor * 2
+        tmp_5pts += size_diff / 2
+        tmp_crop_size += np.round(size_diff).astype(np.int32)
+
+    # 2) resize the padded inner region
+    size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
+
+    if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
+        raise FaceWarpException('Must have (output_size - outer_padding) '
+                                '= some_scale * (crop_size * (1.0 + inner_padding_factor))')
+
+    scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
+    tmp_5pts = tmp_5pts * scale_factor
+    #    size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
+    #    tmp_5pts = tmp_5pts + size_diff / 2
+    tmp_crop_size = size_bf_outer_pad
+
+    # 3) add outer_padding to make output_size
+    reference_5point = tmp_5pts + np.array(outer_padding)
+    tmp_crop_size = output_size
+
+    return reference_5point
+
+
+def get_affine_transform_matrix(src_pts, dst_pts):
+    """
+    Function:
+    ----------
+        get affine transform matrix 'tfm' from src_pts to dst_pts
+    Parameters:
+    ----------
+        @src_pts: Kx2 np.array
+            source points matrix, each row is a pair of coordinates (x, y)
+        @dst_pts: Kx2 np.array
+            destination points matrix, each row is a pair of coordinates (x, y)
+    Returns:
+    ----------
+        @tfm: 2x3 np.array
+            transform matrix from src_pts to dst_pts
+    """
+
+    tfm = np.float32([[1, 0, 0], [0, 1, 0]])
+    n_pts = src_pts.shape[0]
+    ones = np.ones((n_pts, 1), src_pts.dtype)
+    src_pts_ = np.hstack([src_pts, ones])
+    dst_pts_ = np.hstack([dst_pts, ones])
+
+    A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_, rcond=None)
+
+    if rank == 3:
+        tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]])
+    elif rank == 2:
+        tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]])
+
+    return tfm
+
+
+def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type='similarity'):
+    """
+    Function:
+    ----------
+        warp and crop a face from src_img using an affine transform estimated from facial_pts to reference_pts
+    Parameters:
+    ----------
+        @src_img: HxWxC np.array
+            input image
+        @facial_pts: could be
+            1)a list of K coordinates (x,y)
+        or
+            2) Kx2 or 2xK np.array
+            each row or col is a pair of coordinates (x, y)
+        @reference_pts: could be
+            1) a list of K coordinates (x,y)
+        or
+            2) Kx2 or 2xK np.array
+            each row or col is a pair of coordinates (x, y)
+        or
+            3) None
+            if None, use default reference facial points
+        @crop_size: (w, h)
+            output face image size
+        @align_type: transform type, could be one of
+            1) 'similarity': use similarity transform
+            2) 'cv2_affine': use the first 3 points to do affine transform,
+                    by calling cv2.getAffineTransform()
+            3) 'affine': use all points to do affine transform
+    Returns:
+    ----------
+        @face_img: output face image with size (w, h) = @crop_size
+    """
+
+    if reference_pts is None:
+        if crop_size[0] == 96 and crop_size[1] == 112:
+            reference_pts = REFERENCE_FACIAL_POINTS
+        else:
+            default_square = False
+            inner_padding_factor = 0
+            outer_padding = (0, 0)
+            output_size = crop_size
+
+            reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding,
+                                                        default_square)
+
+    ref_pts = np.float32(reference_pts)
+    ref_pts_shp = ref_pts.shape
+    if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
+        raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2')
+
+    if ref_pts_shp[0] == 2:
+        ref_pts = ref_pts.T
+
+    src_pts = np.float32(facial_pts)
+    src_pts_shp = src_pts.shape
+    if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
+        raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2')
+
+    if src_pts_shp[0] == 2:
+        src_pts = src_pts.T
+
+    if src_pts.shape != ref_pts.shape:
+        raise FaceWarpException('facial_pts and reference_pts must have the same shape')
+
+    if align_type == 'cv2_affine':
+        tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
+    elif align_type == 'affine':
+        tfm = get_affine_transform_matrix(src_pts, ref_pts)
+    else:
+        tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)
+
+    face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))
+
+    return face_img
diff --git a/vscodeformer/codeformer_arch.py b/vscodeformer/codeformer_arch.py
new file mode 100644
index 0000000..bc1c8ce
--- /dev/null
+++ b/vscodeformer/codeformer_arch.py
@@ -0,0 +1,274 @@
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from .vqgan_arch import *
+
+
+def calc_mean_std(feat, eps=1e-5):
+    """Calculate mean and std for adaptive_instance_normalization.
+
+    Args:
+        feat (Tensor): 4D tensor.
+        eps (float): A small value added to the variance to avoid
+            divide-by-zero. Default: 1e-5.
+    """
+    size = feat.size()
+    assert len(size) == 4, 'The input feature should be 4D tensor.'
+    b, c = size[:2]
+    feat_var = feat.view(b, c, -1).var(dim=2) + eps
+    feat_std = feat_var.sqrt().view(b, c, 1, 1)
+    feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
+    return feat_mean, feat_std
+
+
+def adaptive_instance_normalization(content_feat, style_feat):
+    """Adaptive instance normalization.
+
+    Adjust the reference features to have the similar color and illuminations
+    as those in the degradate features.
+
+    Args:
+        content_feat (Tensor): The reference feature.
+        style_feat (Tensor): The degradate features.
+    """
+    size = content_feat.size()
+    style_mean, style_std = calc_mean_std(style_feat)
+    content_mean, content_std = calc_mean_std(content_feat)
+    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
+class PositionEmbeddingSine(nn.Module):
+    """
+    This is a more standard version of the position embedding, very similar to the one
+    used by the Attention is all you need paper, generalized to work on images.
+    """
+
+    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+        super().__init__()
+        self.num_pos_feats = num_pos_feats
+        self.temperature = temperature
+        self.normalize = normalize
+        if scale is not None and normalize is False:
+            raise ValueError("normalize should be True if scale is passed")
+        if scale is None:
+            scale = 2 * math.pi
+        self.scale = scale
+
+    def forward(self, x, mask=None):
+        if mask is None:
+            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
+        not_mask = ~mask
+        y_embed = not_mask.cumsum(1, dtype=torch.float32)
+        x_embed = not_mask.cumsum(2, dtype=torch.float32)
+        if self.normalize:
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+        pos_x = x_embed[:, :, :, None] / dim_t
+        pos_y = y_embed[:, :, :, None] / dim_t
+        pos_x = torch.stack(
+            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+        ).flatten(3)
+        pos_y = torch.stack(
+            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+        ).flatten(3)
+        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        return pos
+
+def _get_activation_fn(activation):
+    """Return an activation function given a string"""
+    if activation == "relu":
+        return F.relu
+    if activation == "gelu":
+        return F.gelu
+    if activation == "glu":
+        return F.glu
+    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
+
+
+class TransformerSALayer(nn.Module):
+    def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
+        # Implementation of Feedforward model - MLP
+        self.linear1 = nn.Linear(embed_dim, dim_mlp)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_mlp, embed_dim)
+
+        self.norm1 = nn.LayerNorm(embed_dim)
+        self.norm2 = nn.LayerNorm(embed_dim)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        return tensor if pos is None else tensor + pos
+
+    def forward(self, tgt,
+                tgt_mask: Optional[Tensor] = None,
+                tgt_key_padding_mask: Optional[Tensor] = None,
+                query_pos: Optional[Tensor] = None):
+
+        # self attention
+        tgt2 = self.norm1(tgt)
+        q = k = self.with_pos_embed(tgt2, query_pos)
+        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+                              key_padding_mask=tgt_key_padding_mask)[0]
+        tgt = tgt + self.dropout1(tgt2)
+
+        # ffn
+        tgt2 = self.norm2(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+        tgt = tgt + self.dropout2(tgt2)
+        return tgt
+
+class Fuse_sft_block(nn.Module):
+    def __init__(self, in_ch, out_ch):
+        super().__init__()
+        self.encode_enc = ResBlock(2*in_ch, out_ch)
+
+        self.scale = nn.Sequential(
+                    nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
+                    nn.LeakyReLU(0.2, True),
+                    nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
+
+        self.shift = nn.Sequential(
+                    nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
+                    nn.LeakyReLU(0.2, True),
+                    nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
+
+    def forward(self, enc_feat, dec_feat, w=1):
+        enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
+        scale = self.scale(enc_feat)
+        shift = self.shift(enc_feat)
+        residual = w * (dec_feat * scale + shift)
+        out = dec_feat + residual
+        return out
+
+
+class CodeFormer(VQAutoEncoder):
+    def __init__(self, dim_embd=512, n_head=8, n_layers=9,
+                codebook_size=1024, latent_size=256,
+                connect_list=['32', '64', '128', '256'],
+                fix_modules=['quantize','generator']):
+        super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest', 2, [16], codebook_size)
+
+        if fix_modules is not None:
+            for module in fix_modules:
+                for param in getattr(self, module).parameters():
+                    param.requires_grad = False
+
+        self.connect_list = connect_list
+        self.n_layers = n_layers
+        self.dim_embd = dim_embd
+        self.dim_mlp = dim_embd*2
+
+        self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
+        self.feat_emb = nn.Linear(256, self.dim_embd)
+
+        # transformer
+        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
+                                    for _ in range(self.n_layers)])
+
+        # logits_predict head
+        self.idx_pred_layer = nn.Sequential(
+            nn.LayerNorm(dim_embd),
+            nn.Linear(dim_embd, codebook_size, bias=False))
+
+        self.channels = {
+            '16': 512,
+            '32': 256,
+            '64': 256,
+            '128': 128,
+            '256': 128,
+            '512': 64,
+        }
+
+        # after second residual block for > 16, before attn layer for ==16
+        self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
+        # after first residual block for > 16, before attn layer for ==16
+        self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
+
+        # fuse_convs_dict
+        self.fuse_convs_dict = nn.ModuleDict()
+        for f_size in self.connect_list:
+            in_ch = self.channels[f_size]
+            self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
+
+    def _init_weights(self, module):
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            module.weight.data.normal_(mean=0.0, std=0.02)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+    def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
+        # ################### Encoder #####################
+        enc_feat_dict = {}
+        out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
+        for i, block in enumerate(self.encoder.blocks):
+            x = block(x)
+            if i in out_list:
+                enc_feat_dict[str(x.shape[-1])] = x.clone()
+
+        lq_feat = x
+        # ################# Transformer ###################
+        # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
+        pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
+        # BCHW -> BC(HW) -> (HW)BC
+        feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
+        query_emb = feat_emb
+        # Transformer encoder
+        for layer in self.ft_layers:
+            query_emb = layer(query_emb, query_pos=pos_emb)
+
+        # output logits
+        logits = self.idx_pred_layer(query_emb) # (hw)bn
+        logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
+
+        if code_only:  # for training stage II
+            # logits doesn't need softmax before cross_entropy loss
+            return logits, lq_feat
+
+        # ################# Quantization ###################
+        # if self.training:
+        #     quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
+        #     # b(hw)c -> bc(hw) -> bchw
+        #     quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
+        # ------------
+        soft_one_hot = F.softmax(logits, dim=2)
+        _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
+        quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
+        # preserve gradients
+        # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
+
+        if detach_16:
+            quant_feat = quant_feat.detach() # for training stage III
+        if adain:
+            quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
+
+        # ################## Generator ####################
+        x = quant_feat
+        fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
+
+        for i, block in enumerate(self.generator.blocks):
+            x = block(x)
+            if i in fuse_list: # fuse after i-th block
+                f_size = str(x.shape[-1])
+                if w>0:
+                    x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
+        out = x
+        # logits doesn't need softmax before cross_entropy loss
+        return out, logits, lq_feat
diff --git a/vscodeformer/face_restoration_helper.py b/vscodeformer/face_restoration_helper.py
new file mode 100644
index 0000000..284c0a3
--- /dev/null
+++ b/vscodeformer/face_restoration_helper.py
@@ -0,0 +1,538 @@
+import os
+from copy import deepcopy
+
+import cv2
+import numpy as np
+import torch
+from torchvision.transforms.functional import normalize
+
+from .misc import adain_npy, bgr2gray, img2tensor, imwrite, is_gray
+from .parsenet import ParseNet
+from .retinaface import RetinaFace
+
+model_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models')
+
+
+def get_largest_face(det_faces, h, w):
+
+    def get_location(val, length):
+        if val < 0:
+            return 0
+        elif val > length:
+            return length
+        else:
+            return val
+
+    face_areas = []
+    for det_face in det_faces:
+        left = get_location(det_face[0], w)
+        right = get_location(det_face[2], w)
+        top = get_location(det_face[1], h)
+        bottom = get_location(det_face[3], h)
+        face_area = (right - left) * (bottom - top)
+        face_areas.append(face_area)
+    largest_idx = face_areas.index(max(face_areas))
+    return det_faces[largest_idx], largest_idx
+
+
+def get_center_face(det_faces, h=0, w=0, center=None):
+    if center is not None:
+        center = np.array(center)
+    else:
+        center = np.array([w / 2, h / 2])
+    center_dist = []
+    for det_face in det_faces:
+        face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
+        dist = np.linalg.norm(face_center - center)
+        center_dist.append(dist)
+    center_idx = center_dist.index(min(center_dist))
+    return det_faces[center_idx], center_idx
+
+
+class FaceRestoreHelper(object):
+    """Helper for the face restoration pipeline (base class)."""
+
+    def __init__(self,
+                 upscale_factor,
+                 face_size=512,
+                 crop_ratio=(1, 1),
+                 det_model='retinaface_resnet50',
+                 save_ext='png',
+                 template_3points=False,
+                 pad_blur=False,
+                 use_parse=False,
+                 device=None):
+        self.template_3points = template_3points  # improve robustness
+        self.upscale_factor = int(upscale_factor)
+        # the cropped face ratio based on the square face
+        self.crop_ratio = crop_ratio  # (h, w)
+        assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ratio only supports >=1'
+        self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
+        self.det_model = det_model
+
+        if self.det_model == 'dlib':
+            # standard 5 landmarks for FFHQ faces with 1024 x 1024
+            self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
+                                        [337.91089109, 488.38613861], [437.95049505, 493.51485149],
+                                        [513.58415842, 678.5049505]])
+            self.face_template = self.face_template / (1024 // face_size)
+        elif self.template_3points:
+            self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
+        else:
+            # standard 5 landmarks for FFHQ faces with 512 x 512
+            # facexlib
+            self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
+                                           [201.26117, 371.41043], [313.08905, 371.15118]])
+
+            # dlib: left_eye: 36:41  right_eye: 42:47  nose: 30,32,33,34  left mouth corner: 48  right mouth corner: 54
+            # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
+            #                                 [198.22603, 372.82502], [313.91018, 372.75659]])
+
+        self.face_template = self.face_template * (face_size / 512.0)
+        if self.crop_ratio[0] > 1:
+            self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
+        if self.crop_ratio[1] > 1:
+            self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
+        self.save_ext = save_ext
+        self.pad_blur = pad_blur
+        if self.pad_blur is True:
+            self.template_3points = False
+
+        self.all_landmarks_5 = []
+        self.det_faces = []
+        self.affine_matrices = []
+        self.inverse_affine_matrices = []
+        self.cropped_faces = []
+        self.restored_faces = []
+        self.pad_input_imgs = []
+
+        self.device = device
+
+        # init face detection model
+        if self.det_model == 'dlib':
+            self.face_detector, self.shape_predictor_5 = self.init_dlib()
+        else:
+            self.face_detector = self.init_detection_model()
+
+        # init face parsing model
+        self.use_parse = use_parse
+        self.face_parse = self.init_parsing_model()
+
+    def init_detection_model(self):
+        model = RetinaFace(self.device)
+        model_path = os.path.join(model_dir, 'detection_Resnet50_Final.pth')
+        load_net = torch.load(model_path, map_location='cpu')
+        # remove unnecessary 'module.'
+        for k, v in deepcopy(load_net).items():
+            if k.startswith('module.'):
+                load_net[k[7:]] = v
+                load_net.pop(k)
+        model.load_state_dict(load_net)
+        model.eval().to(self.device)
+        return model
+
+    def init_parsing_model(self):
+        model = ParseNet(in_size=512, out_size=512)
+        model_path = os.path.join(model_dir, 'parsing_parsenet.pth')
+        load_net = torch.load(model_path, map_location='cpu')
+        model.load_state_dict(load_net)
+        model.eval().to(self.device)
+        return model
+
+    def set_upscale_factor(self, upscale_factor):
+        self.upscale_factor = upscale_factor
+
+    def read_image(self, img):
+        """img can be image path or cv2 loaded image."""
+        # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
+        if isinstance(img, str):
+            img = cv2.imread(img)
+
+        if np.max(img) > 256:  # 16-bit image
+            img = img / 65535 * 255
+        if len(img.shape) == 2:  # gray image
+            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+        elif img.shape[2] == 4:  # BGRA image with alpha channel
+            img = img[:, :, 0:3]
+
+        self.input_img = img
+        self.is_gray = is_gray(img, threshold=10)
+
+        if min(self.input_img.shape[:2])<512:
+            f = 512.0/min(self.input_img.shape[:2])
+            self.input_img = cv2.resize(self.input_img, (0,0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
+
+    def init_dlib(self):
+        """Initialize the dlib detectors and predictors."""
+        try:
+            import dlib
+        except ImportError as e:
+            raise ImportError('Please install dlib to use the dlib detector') from e
+        detection_path = os.path.join(model_dir, 'mmod_human_face_detector-4cb19393.dat')
+        landmark5_path = os.path.join(model_dir, 'shape_predictor_5_face_landmarks-c4b1e980.dat')
+        face_detector = dlib.cnn_face_detection_model_v1(detection_path)
+        shape_predictor_5 = dlib.shape_predictor(landmark5_path)
+        return face_detector, shape_predictor_5
+
+    def get_face_landmarks_5_dlib(self,
+                                only_keep_largest=False,
+                                scale=1):
+        det_faces = self.face_detector(self.input_img, scale)
+
+        if len(det_faces) == 0:
+            print('No face detected. Try to increase upsample_num_times.')
+            return 0
+        else:
+            if only_keep_largest:
+                print('Detect several faces and only keep the largest.')
+                face_areas = []
+                for i in range(len(det_faces)):
+                    face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
+                        det_faces[i].rect.bottom() - det_faces[i].rect.top())
+                    face_areas.append(face_area)
+                largest_idx = face_areas.index(max(face_areas))
+                self.det_faces = [det_faces[largest_idx]]
+            else:
+                self.det_faces = det_faces
+
+        if len(self.det_faces) == 0:
+            return 0
+
+        for face in self.det_faces:
+            shape = self.shape_predictor_5(self.input_img, face.rect)
+            landmark = np.array([[part.x, part.y] for part in shape.parts()])
+            self.all_landmarks_5.append(landmark)
+
+        return len(self.all_landmarks_5)
+
+
+    def get_face_landmarks_5(self,
+                             only_keep_largest=False,
+                             only_center_face=False,
+                             resize=None,
+                             blur_ratio=0.01,
+                             eye_dist_threshold=None):
+        if self.det_model == 'dlib':
+            return self.get_face_landmarks_5_dlib(only_keep_largest)
+
+        if resize is None:
+            scale = 1
+            input_img = self.input_img
+        else:
+            h, w = self.input_img.shape[0:2]
+            scale = resize / min(h, w)
+            scale = max(1, scale) # always scale up
+            h, w = int(h * scale), int(w * scale)
+            interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
+            input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)
+
+        with torch.no_grad():
+            bboxes = self.face_detector.detect_faces(input_img)
+
+        if bboxes is None or bboxes.shape[0] == 0:
+            return 0
+        else:
+            bboxes = bboxes / scale
+
+        for bbox in bboxes:
+            # remove faces with too small eye distance: side faces or too small faces
+            eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
+            if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
+                continue
+
+            if self.template_3points:
+                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
+            else:
+                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
+            self.all_landmarks_5.append(landmark)
+            self.det_faces.append(bbox[0:5])
+
+        if len(self.det_faces) == 0:
+            return 0
+        if only_keep_largest:
+            h, w, _ = self.input_img.shape
+            self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
+            self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
+        elif only_center_face:
+            h, w, _ = self.input_img.shape
+            self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
+            self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
+
+        # pad blurry images
+        if self.pad_blur:
+            self.pad_input_imgs = []
+            for landmarks in self.all_landmarks_5:
+                # get landmarks
+                eye_left = landmarks[0, :]
+                eye_right = landmarks[1, :]
+                eye_avg = (eye_left + eye_right) * 0.5
+                mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
+                eye_to_eye = eye_right - eye_left
+                eye_to_mouth = mouth_avg - eye_avg
+
+                # Get the oriented crop rectangle
+                # x: half width of the oriented crop rectangle
+                x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+                #  - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
+                # norm with the hypotenuse: get the direction
+                x /= np.hypot(*x)  # get the hypotenuse of a right triangle
+                rect_scale = 1.5
+                x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
+                # y: half height of the oriented crop rectangle
+                y = np.flipud(x) * [-1, 1]
+
+                # c: center
+                c = eye_avg + eye_to_mouth * 0.1
+                # quad: (left_top, left_bottom, right_bottom, right_top)
+                quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+                # qsize: side length of the square
+                qsize = np.hypot(*x) * 2
+                border = max(int(np.rint(qsize * 0.1)), 3)
+
+                # get pad
+                # pad: (width_left, height_top, width_right, height_bottom)
+                pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+                       int(np.ceil(max(quad[:, 1]))))
+                pad = [
+                    max(-pad[0] + border, 1),
+                    max(-pad[1] + border, 1),
+                    max(pad[2] - self.input_img.shape[0] + border, 1),
+                    max(pad[3] - self.input_img.shape[1] + border, 1)
+                ]
+
+                if max(pad) > 1:
+                    # pad image
+                    pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+                    # modify landmark coords
+                    landmarks[:, 0] += pad[0]
+                    landmarks[:, 1] += pad[1]
+                    # blur pad images
+                    h, w, _ = pad_img.shape
+                    y, x, _ = np.ogrid[:h, :w, :1]
+                    mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
+                                                       np.float32(w - 1 - x) / pad[2]),
+                                      1.0 - np.minimum(np.float32(y) / pad[1],
+                                                       np.float32(h - 1 - y) / pad[3]))
+                    blur = int(qsize * blur_ratio)
+                    if blur % 2 == 0:
+                        blur += 1
+                    blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
+                    # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
+
+                    pad_img = pad_img.astype('float32')
+                    pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+                    pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
+                    pad_img = np.clip(pad_img, 0, 255)  # float32, [0, 255]
+                    self.pad_input_imgs.append(pad_img)
+                else:
+                    self.pad_input_imgs.append(np.copy(self.input_img))
+
+        return len(self.all_landmarks_5)
+
+    def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
+        """Align and warp faces with face template.
+        """
+        if self.pad_blur:
+            assert len(self.pad_input_imgs) == len(
+                self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
+        for idx, landmark in enumerate(self.all_landmarks_5):
+            # use 5 landmarks to get affine matrix
+            # use cv2.LMEDS method for the equivalence to skimage transform
+            # ref: https://blog.csdn.net/yichxi/article/details/115827338
+            affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
+            self.affine_matrices.append(affine_matrix)
+            # warp and crop faces
+            if border_mode == 'constant':
+                border_mode = cv2.BORDER_CONSTANT
+            elif border_mode == 'reflect101':
+                border_mode = cv2.BORDER_REFLECT101
+            elif border_mode == 'reflect':
+                border_mode = cv2.BORDER_REFLECT
+            if self.pad_blur:
+                input_img = self.pad_input_imgs[idx]
+            else:
+                input_img = self.input_img
+            cropped_face = cv2.warpAffine(
+                input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132))  # gray
+            self.cropped_faces.append(cropped_face)
+            # save the cropped face
+            if save_cropped_path is not None:
+                path = os.path.splitext(save_cropped_path)[0]
+                save_path = f'{path}_{idx:02d}.{self.save_ext}'
+                imwrite(cropped_face, save_path)
+
+    def get_inverse_affine(self, save_inverse_affine_path=None):
+        """Get inverse affine matrix."""
+        for idx, affine_matrix in enumerate(self.affine_matrices):
+            inverse_affine = cv2.invertAffineTransform(affine_matrix)
+            inverse_affine *= self.upscale_factor
+            self.inverse_affine_matrices.append(inverse_affine)
+            # save inverse affine matrices
+            if save_inverse_affine_path is not None:
+                path, _ = os.path.splitext(save_inverse_affine_path)
+                save_path = f'{path}_{idx:02d}.pth'
+                torch.save(inverse_affine, save_path)
+
+
+    def add_restored_face(self, restored_face, input_face=None):
+        if self.is_gray:
+            restored_face = bgr2gray(restored_face) # convert img into grayscale
+            if input_face is not None:
+                restored_face = adain_npy(restored_face, input_face) # transfer the color
+        self.restored_faces.append(restored_face)
+
+
+    def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
+        h, w, _ = self.input_img.shape
+        h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
+
+        if upsample_img is None:
+            # simply resize the background
+            # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
+            upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
+        else:
+            upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
+
+        assert len(self.restored_faces) == len(
+            self.inverse_affine_matrices), 'lengths of restored_faces and inverse_affine_matrices are different.'
+
+        inv_mask_borders = []
+        for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
+            if face_upsampler is not None:
+                restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
+                inverse_affine /= self.upscale_factor
+                inverse_affine[:, 2] *= self.upscale_factor
+                face_size = (self.face_size[0]*self.upscale_factor, self.face_size[1]*self.upscale_factor)
+            else:
+                # Add an offset to inverse affine matrix, for more precise back alignment
+                if self.upscale_factor > 1:
+                    extra_offset = 0.5 * self.upscale_factor
+                else:
+                    extra_offset = 0
+                inverse_affine[:, 2] += extra_offset
+                face_size = self.face_size
+            inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
+
+            # if draw_box or not self.use_parse:  # use square parse maps
+            #     mask = np.ones(face_size, dtype=np.float32)
+            #     inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
+            #     # remove the black borders
+            #     inv_mask_erosion = cv2.erode(
+            #         inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
+            #     pasted_face = inv_mask_erosion[:, :, None] * inv_restored
+            #     total_face_area = np.sum(inv_mask_erosion)  # // 3
+            #     # add border
+            #     if draw_box:
+            #         h, w = face_size
+            #         mask_border = np.ones((h, w, 3), dtype=np.float32)
+            #         border = int(1400/np.sqrt(total_face_area))
+            #         mask_border[border:h-border, border:w-border,:] = 0
+            #         inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
+            #         inv_mask_borders.append(inv_mask_border)
+            #     if not self.use_parse:
+            #         # compute the fusion edge based on the area of face
+            #         w_edge = int(total_face_area**0.5) // 20
+            #         erosion_radius = w_edge * 2
+            #         inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+            #         blur_size = w_edge * 2
+            #         inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+            #         if len(upsample_img.shape) == 2:  # upsample_img is gray image
+            #             upsample_img = upsample_img[:, :, None]
+            #         inv_soft_mask = inv_soft_mask[:, :, None]
+
+            # always use square mask
+            mask = np.ones(face_size, dtype=np.float32)
+            inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
+            # remove the black borders
+            inv_mask_erosion = cv2.erode(
+                inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
+            pasted_face = inv_mask_erosion[:, :, None] * inv_restored
+            total_face_area = np.sum(inv_mask_erosion)  # // 3
+            # add border
+            if draw_box:
+                h, w = face_size
+                mask_border = np.ones((h, w, 3), dtype=np.float32)
+                border = int(1400/np.sqrt(total_face_area))
+                mask_border[border:h-border, border:w-border,:] = 0
+                inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
+                inv_mask_borders.append(inv_mask_border)
+            # compute the fusion edge based on the area of face
+            w_edge = int(total_face_area**0.5) // 20
+            erosion_radius = w_edge * 2
+            inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+            blur_size = w_edge * 2
+            inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+            if len(upsample_img.shape) == 2:  # upsample_img is gray image
+                upsample_img = upsample_img[:, :, None]
+            inv_soft_mask = inv_soft_mask[:, :, None]
+
+            # parse mask
+            if self.use_parse:
+                # inference
+                face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
+                face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
+                normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+                face_input = torch.unsqueeze(face_input, 0).to(self.device)
+                with torch.no_grad():
+                    out = self.face_parse(face_input)[0]
+                out = out.argmax(dim=1).squeeze().cpu().numpy()
+
+                parse_mask = np.zeros(out.shape)
+                MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
+                for idx, color in enumerate(MASK_COLORMAP):
+                    parse_mask[out == idx] = color
+                #  blur the mask
+                parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+                parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+                # remove the black borders
+                thres = 10
+                parse_mask[:thres, :] = 0
+                parse_mask[-thres:, :] = 0
+                parse_mask[:, :thres] = 0
+                parse_mask[:, -thres:] = 0
+                parse_mask = parse_mask / 255.
+
+                parse_mask = cv2.resize(parse_mask, face_size)
+                parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
+                inv_soft_parse_mask = parse_mask[:, :, None]
+                # pasted_face = inv_restored
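+                # Keep the parsing-based mask wherever it is tighter than the square
+                # soft mask, so background pixels inside the square crop are not
+                # pasted back onto the upsampled image.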
+                fuse_mask = (inv_soft_parse_mask<inv_soft_mask).astype('int')
+                inv_soft_mask = inv_soft_parse_mask*fuse_mask + inv_soft_mask*(1-fuse_mask)
+
+            if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4:  # alpha channel
+                alpha = upsample_img[:, :, 3:]
+                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
+                upsample_img = np.concatenate((upsample_img, alpha), axis=2)
+            else:
+                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
+
+        if np.max(upsample_img) > 256:  # 16-bit image
+            upsample_img = upsample_img.astype(np.uint16)
+        else:
+            upsample_img = upsample_img.astype(np.uint8)
+
+        # draw bounding box
+        if draw_box:
+            # upsample_input_img = cv2.resize(input_img, (w_up, h_up))
+            img_color = np.ones([*upsample_img.shape], dtype=np.float32)
+            img_color[:,:,0] = 0
+            img_color[:,:,1] = 255
+            img_color[:,:,2] = 0
+            for inv_mask_border in inv_mask_borders:
+                upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
+                # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img
+
+        if save_path is not None:
+            path = os.path.splitext(save_path)[0]
+            save_path = f'{path}.{self.save_ext}'
+            imwrite(upsample_img, save_path)
+        return upsample_img
+
+    def clean_all(self):
+        self.all_landmarks_5 = []
+        self.restored_faces = []
+        self.affine_matrices = []
+        self.cropped_faces = []
+        self.inverse_affine_matrices = []
+        self.det_faces = []
+        self.pad_input_imgs = []
diff --git a/vscodeformer/img_util.py b/vscodeformer/img_util.py
new file mode 100644
index 0000000..3e73ac3
--- /dev/null
+++ b/vscodeformer/img_util.py
@@ -0,0 +1,171 @@
+import math
+import os
+
+import cv2
+import numpy as np
+import torch
+from torchvision.utils import make_grid
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+    """Numpy array to tensor.
+
+    Args:
+        imgs (list[ndarray] | ndarray): Input images.
+        bgr2rgb (bool): Whether to change bgr to rgb.
+        float32 (bool): Whether to change to float32.
+
+    Returns:
+        list[tensor] | tensor: Tensor images. If returned results only have
+            one element, just return tensor.
+    """
+
+    def _totensor(img, bgr2rgb, float32):
+        if img.shape[2] == 3 and bgr2rgb:
+            if img.dtype == 'float64':
+                img = img.astype('float32')
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        img = torch.from_numpy(img.transpose(2, 0, 1))
+        if float32:
+            img = img.float()
+        return img
+
+    if isinstance(imgs, list):
+        return [_totensor(img, bgr2rgb, float32) for img in imgs]
+    else:
+        return _totensor(imgs, bgr2rgb, float32)
+
+
+def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
+    """Convert torch Tensors into image numpy arrays.
+
+    After clamping to [min, max], values will be normalized to [0, 1].
+
+    Args:
+        tensor (Tensor or list[Tensor]): Accept shapes:
+            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
+            2) 3D Tensor of shape (3/1 x H x W);
+            3) 2D Tensor of shape (H x W).
+            Tensor channel should be in RGB order.
+        rgb2bgr (bool): Whether to change rgb to bgr.
+        out_type (numpy type): output types. If ``np.uint8``, transform outputs
+            to uint8 type with range [0, 255]; otherwise, float type with
+            range [0, 1]. Default: ``np.uint8``.
+        min_max (tuple[int]): min and max values for clamp.
+
+    Returns:
+        (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D
+        ndarray of shape (H x W). The channel order is BGR.
+    """
+    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
+        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
+
+    if torch.is_tensor(tensor):
+        tensor = [tensor]
+    result = []
+    for _tensor in tensor:
+        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
+
+        n_dim = _tensor.dim()
+        if n_dim == 4:
+            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
+            img_np = img_np.transpose(1, 2, 0)
+            if rgb2bgr:
+                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+        elif n_dim == 3:
+            img_np = _tensor.numpy()
+            img_np = img_np.transpose(1, 2, 0)
+            if img_np.shape[2] == 1:  # gray image
+                img_np = np.squeeze(img_np, axis=2)
+            else:
+                if rgb2bgr:
+                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+        elif n_dim == 2:
+            img_np = _tensor.numpy()
+        else:
+            raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}')
+        if out_type == np.uint8:
+            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
+            img_np = (img_np * 255.0).round()
+        img_np = img_np.astype(out_type)
+        result.append(img_np)
+    if len(result) == 1:
+        result = result[0]
+    return result
+
+
+def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
+    """This implementation is slightly faster than tensor2img.
+    It now only supports torch tensor with shape (1, c, h, w).
+
+    Args:
+        tensor (Tensor): Now only support torch tensor with (1, c, h, w).
+        rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
+        min_max (tuple[int]): min and max values for clamp.
+    """
+    output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
+    output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
+    output = output.type(torch.uint8).cpu().numpy()
+    if rgb2bgr:
+        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
+    return output
+
+
+def imfrombytes(content, flag='color', float32=False):
+    """Read an image from bytes.
+
+    Args:
+        content (bytes): Image bytes got from files or other streams.
+        flag (str): Flags specifying the color type of a loaded image,
+            candidates are `color`, `grayscale` and `unchanged`.
+        float32 (bool): Whether to change to float32. If True, will also
+            normalize to [0, 1]. Default: False.
+
+    Returns:
+        ndarray: Loaded image array.
+    """
+    img_np = np.frombuffer(content, np.uint8)
+    imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
+    img = cv2.imdecode(img_np, imread_flags[flag])
+    if float32:
+        img = img.astype(np.float32) / 255.
+    return img
+
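+# Illustrative usage (not part of the original code), assuming a local image file:
+#
+#     with open('face.png', 'rb') as f:
+#         img = imfrombytes(f.read(), flag='color', float32=True)  # HWC float32 in [0, 1]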
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+    """Write image to file.
+
+    Args:
+        img (ndarray): Image array to be written.
+        file_path (str): Image file path.
+        params (None or list): Same as opencv's :func:`imwrite` interface.
+        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+            whether to create it automatically.
+
+    Returns:
+        bool: Successful or not.
+    """
+    if auto_mkdir:
+        dir_name = os.path.abspath(os.path.dirname(file_path))
+        os.makedirs(dir_name, exist_ok=True)
+    return cv2.imwrite(file_path, img, params)
+
+
+def crop_border(imgs, crop_border):
+    """Crop borders of images.
+
+    Args:
+        imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
+        crop_border (int): Crop border for each end of height and width.
+
+    Returns:
+        list[ndarray]: Cropped images.
+    """
+    if crop_border == 0:
+        return imgs
+    else:
+        if isinstance(imgs, list):
+            return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
+        else:
+            return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
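+
+
+if __name__ == '__main__':
+    # Minimal round-trip sketch (illustrative, not part of the original module):
+    # convert a random BGR uint8 image to a tensor and back, checking the shapes.
+    rng = np.random.default_rng(0)
+    img = rng.integers(0, 256, size=(64, 48, 3), dtype=np.uint8)
+    t = img2tensor(img.astype(np.float32) / 255., bgr2rgb=True, float32=True)
+    print('tensor shape:', tuple(t.shape))  # (3, 64, 48)
+    back = tensor2img(t, rgb2bgr=True, min_max=(0, 1))
+    print('round-trip shape:', back.shape)  # (64, 48, 3)
+    print('max abs diff:', int(np.abs(back.astype(np.int16) - img.astype(np.int16)).max()))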
diff --git a/vscodeformer/matlab_cp2tform.py b/vscodeformer/matlab_cp2tform.py
new file mode 100644
index 0000000..b2a8b54
--- /dev/null
+++ b/vscodeformer/matlab_cp2tform.py
@@ -0,0 +1,317 @@
+import numpy as np
+from numpy.linalg import inv, lstsq
+from numpy.linalg import matrix_rank as rank
+from numpy.linalg import norm
+
+
+class MatlabCp2tormException(Exception):
+
+    def __str__(self):
+        return 'In File {}: {}'.format(__file__, super().__str__())
+
+
+def tformfwd(trans, uv):
+    """
+    Function:
+    ----------
+        apply affine transform 'trans' to uv
+
+    Parameters:
+    ----------
+        @trans: 3x3 np.array
+            transform matrix
+        @uv: Kx2 np.array
+            each row is a pair of coordinates (x, y)
+
+    Returns:
+    ----------
+        @xy: Kx2 np.array
+            each row is a pair of transformed coordinates (x, y)
+    """
+    uv = np.hstack((uv, np.ones((uv.shape[0], 1))))
+    xy = np.dot(uv, trans)
+    xy = xy[:, 0:-1]
+    return xy
+
+
+def tforminv(trans, uv):
+    """
+    Function:
+    ----------
+        apply the inverse of affine transform 'trans' to uv
+
+    Parameters:
+    ----------
+        @trans: 3x3 np.array
+            transform matrix
+        @uv: Kx2 np.array
+            each row is a pair of coordinates (x, y)
+
+    Returns:
+    ----------
+        @xy: Kx2 np.array
+            each row is a pair of inverse-transformed coordinates (x, y)
+    """
+    Tinv = inv(trans)
+    xy = tformfwd(Tinv, uv)
+    return xy
+
+
+def findNonreflectiveSimilarity(uv, xy, options=None):
+    if options is None:
+        options = {'K': 2}
+
+    K = options['K']
+    M = xy.shape[0]
+    x = xy[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
+    y = xy[:, 1].reshape((-1, 1))  # use reshape to keep a column vector
+
+    tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
+    tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
+    X = np.vstack((tmp1, tmp2))
+
+    u = uv[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
+    v = uv[:, 1].reshape((-1, 1))  # use reshape to keep a column vector
+    U = np.vstack((u, v))
+
+    # We know that X * r = U
+    if rank(X) >= 2 * K:
+        r, _, _, _ = lstsq(X, U, rcond=-1)
+        r = np.squeeze(r)
+    else:
+        raise Exception('cp2tform:twoUniquePointsReq')
+    sc = r[0]
+    ss = r[1]
+    tx = r[2]
+    ty = r[3]
+
+    Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]])
+    T = inv(Tinv)
+    T[:, 2] = np.array([0, 0, 1])
+
+    return T, Tinv
+
+
+def findSimilarity(uv, xy, options=None):
+    if options is None:
+        options = {'K': 2}
+
+    #    uv = np.array(uv)
+    #    xy = np.array(xy)
+
+    # Solve for trans1
+    trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
+
+    # Solve for trans2
+
+    # manually reflect the xy data across the Y-axis
+    xyR = xy.copy()
+    xyR[:, 0] = -1 * xyR[:, 0]
+
+    trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
+
+    # manually reflect the tform to undo the reflection done on xyR
+    TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+    trans2 = np.dot(trans2r, TreflectY)
+
+    # Figure out if trans1 or trans2 is better
+    xy1 = tformfwd(trans1, uv)
+    norm1 = norm(xy1 - xy)
+
+    xy2 = tformfwd(trans2, uv)
+    norm2 = norm(xy2 - xy)
+
+    if norm1 <= norm2:
+        return trans1, trans1_inv
+    else:
+        trans2_inv = inv(trans2)
+        return trans2, trans2_inv
+
+
+def get_similarity_transform(src_pts, dst_pts, reflective=True):
+    """
+    Function:
+    ----------
+        Find Similarity Transform Matrix 'trans':
+            u = src_pts[:, 0]
+            v = src_pts[:, 1]
+            x = dst_pts[:, 0]
+            y = dst_pts[:, 1]
+            [x, y, 1] = [u, v, 1] * trans
+
+    Parameters:
+    ----------
+        @src_pts: Kx2 np.array
+            source points, each row is a pair of coordinates (x, y)
+        @dst_pts: Kx2 np.array
+            destination points, each row is a pair of transformed
+            coordinates (x, y)
+        @reflective: True or False
+            if True:
+                use reflective similarity transform
+            else:
+                use non-reflective similarity transform
+
+    Returns:
+    ----------
+       @trans: 3x3 np.array
+            transform matrix from uv to xy
+        trans_inv: 3x3 np.array
+            inverse of trans, transform matrix from xy to uv
+    """
+
+    if reflective:
+        trans, trans_inv = findSimilarity(src_pts, dst_pts)
+    else:
+        trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
+
+    return trans, trans_inv
+
+
+def cvt_tform_mat_for_cv2(trans):
+    """
+    Function:
+    ----------
+        Convert Transform Matrix 'trans' into 'cv2_trans' which could be
+        directly used by cv2.warpAffine():
+            u = src_pts[:, 0]
+            v = src_pts[:, 1]
+            x = dst_pts[:, 0]
+            y = dst_pts[:, 1]
+            [x, y].T = cv_trans * [u, v, 1].T
+
+    Parameters:
+    ----------
+        @trans: 3x3 np.array
+            transform matrix from uv to xy
+
+    Returns:
+    ----------
+        @cv2_trans: 2x3 np.array
+            transform matrix from src_pts to dst_pts, could be directly used
+            for cv2.warpAffine()
+    """
+    cv2_trans = trans[:, 0:2].T
+
+    return cv2_trans
+
+
+def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
+    """
+    Function:
+    ----------
+        Find Similarity Transform Matrix 'cv2_trans' which could be
+        directly used by cv2.warpAffine():
+            u = src_pts[:, 0]
+            v = src_pts[:, 1]
+            x = dst_pts[:, 0]
+            y = dst_pts[:, 1]
+            [x, y].T = cv_trans * [u, v, 1].T
+
+    Parameters:
+    ----------
+        @src_pts: Kx2 np.array
+            source points, each row is a pair of coordinates (x, y)
+        @dst_pts: Kx2 np.array
+            destination points, each row is a pair of transformed
+            coordinates (x, y)
+        reflective: True or False
+            if True:
+                use reflective similarity transform
+            else:
+                use non-reflective similarity transform
+
+    Returns:
+    ----------
+        @cv2_trans: 2x3 np.array
+            transform matrix from src_pts to dst_pts, could be directly used
+            for cv2.warpAffine()
+    """
+    trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
+    cv2_trans = cvt_tform_mat_for_cv2(trans)
+
+    return cv2_trans
+
+
+if __name__ == '__main__':
+    """
+    u = [0, 6, -2]
+    v = [0, 3, 5]
+    x = [-1, 0, 4]
+    y = [-1, -10, 4]
+
+    # In Matlab, run:
+    #
+    #   uv = [u'; v'];
+    #   xy = [x'; y'];
+    #   tform_sim=cp2tform(uv,xy,'similarity');
+    #
+    #   trans = tform_sim.tdata.T
+    #   ans =
+    #       -0.0764   -1.6190         0
+    #        1.6190   -0.0764         0
+    #       -3.2156    0.0290    1.0000
+    #   trans_inv = tform_sim.tdata.Tinv
+    #    ans =
+    #
+    #       -0.0291    0.6163         0
+    #       -0.6163   -0.0291         0
+    #       -0.0756    1.9826    1.0000
+    #    xy_m=tformfwd(tform_sim, u,v)
+    #
+    #    xy_m =
+    #
+    #       -3.2156    0.0290
+    #        1.1833   -9.9143
+    #        5.0323    2.8853
+    #    uv_m=tforminv(tform_sim, x,y)
+    #
+    #    uv_m =
+    #
+    #        0.5698    1.3953
+    #        6.0872    2.2733
+    #       -2.6570    4.3314
+    """
+    u = [0, 6, -2]
+    v = [0, 3, 5]
+    x = [-1, 0, 4]
+    y = [-1, -10, 4]
+
+    uv = np.array((u, v)).T
+    xy = np.array((x, y)).T
+
+    print('\n--->uv:')
+    print(uv)
+    print('\n--->xy:')
+    print(xy)
+
+    trans, trans_inv = get_similarity_transform(uv, xy)
+
+    print('\n--->trans matrix:')
+    print(trans)
+
+    print('\n--->trans_inv matrix:')
+    print(trans_inv)
+
+    print('\n---> apply transform to uv')
+    print('\nxy_m = uv_augmented * trans')
+    uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1))))
+    xy_m = np.dot(uv_aug, trans)
+    print(xy_m)
+
+    print('\nxy_m = tformfwd(trans, uv)')
+    xy_m = tformfwd(trans, uv)
+    print(xy_m)
+
+    print('\n---> apply inverse transform to xy')
+    print('\nuv_m = xy_augmented * trans_inv')
+    xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1))))
+    uv_m = np.dot(xy_aug, trans_inv)
+    print(uv_m)
+
+    print('\nuv_m = tformfwd(trans_inv, xy)')
+    uv_m = tformfwd(trans_inv, xy)
+    print(uv_m)
+
+    uv_m = tforminv(trans, xy)
+    print('\nuv_m = tforminv(trans, xy)')
+    print(uv_m)
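+
+    # Illustrative extra check (not in the original script): the 2x3 matrix
+    # returned by get_similarity_transform_for_cv2 is simply trans[:, 0:2].T,
+    # in the layout expected by cv2.warpAffine.
+    cv2_trans = get_similarity_transform_for_cv2(uv, xy)
+    print('\n---> cv2-style 2x3 transform matrix')
+    print(cv2_trans)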
diff --git a/vscodeformer/misc.py b/vscodeformer/misc.py
new file mode 100644
index 0000000..037d5ef
--- /dev/null
+++ b/vscodeformer/misc.py
@@ -0,0 +1,204 @@
+import os
+import os.path as osp
+from urllib.parse import urlparse
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from torch.hub import download_url_to_file, get_dir
+
+# from basicsr.utils.download_util import download_file_from_google_drive
+
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+
+def download_pretrained_models(file_ids, save_path_root):
+    import gdown
+
+    os.makedirs(save_path_root, exist_ok=True)
+
+    for file_name, file_id in file_ids.items():
+        file_url = 'https://drive.google.com/uc?id='+file_id
+        save_path = osp.abspath(osp.join(save_path_root, file_name))
+        if osp.exists(save_path):
+            user_response = input(f'{file_name} already exists. Do you want to overwrite it? Y/N\n')
+            if user_response.lower() == 'y':
+                print(f'Overwriting {file_name} at {save_path}')
+                gdown.download(file_url, save_path, quiet=False)
+                # download_file_from_google_drive(file_id, save_path)
+            elif user_response.lower() == 'n':
+                print(f'Skipping {file_name}')
+            else:
+                raise ValueError('Wrong input. Only accepts Y/N.')
+        else:
+            print(f'Downloading {file_name} to {save_path}')
+            gdown.download(file_url, save_path, quiet=False)
+            # download_file_from_google_drive(file_id, save_path)
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+    """Write image to file.
+
+    Args:
+        img (ndarray): Image array to be written.
+        file_path (str): Image file path.
+        params (None or list): Same as opencv's :func:`imwrite` interface.
+        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+            whether to create it automatically.
+
+    Returns:
+        bool: Successful or not.
+    """
+    if auto_mkdir:
+        dir_name = os.path.abspath(os.path.dirname(file_path))
+        os.makedirs(dir_name, exist_ok=True)
+    return cv2.imwrite(file_path, img, params)
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+    """Numpy array to tensor.
+
+    Args:
+        imgs (list[ndarray] | ndarray): Input images.
+        bgr2rgb (bool): Whether to change bgr to rgb.
+        float32 (bool): Whether to change to float32.
+
+    Returns:
+        list[tensor] | tensor: Tensor images. If returned results only have
+            one element, just return tensor.
+    """
+
+    def _totensor(img, bgr2rgb, float32):
+        if img.shape[2] == 3 and bgr2rgb:
+            if img.dtype == 'float64':
+                img = img.astype('float32')
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        img = torch.from_numpy(img.transpose(2, 0, 1))
+        if float32:
+            img = img.float()
+        return img
+
+    if isinstance(imgs, list):
+        return [_totensor(img, bgr2rgb, float32) for img in imgs]
+    else:
+        return _totensor(imgs, bgr2rgb, float32)
+
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+    """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+    """
+    if model_dir is None:
+        hub_dir = get_dir()
+        model_dir = os.path.join(hub_dir, 'checkpoints')
+
+    os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
+
+    parts = urlparse(url)
+    filename = os.path.basename(parts.path)
+    if file_name is not None:
+        filename = file_name
+    cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
+    if not os.path.exists(cached_file):
+        print(f'Downloading: "{url}" to {cached_file}\n')
+        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+    return cached_file
+
+
+def scandir(dir_path, suffix=None, recursive=False, full_path=False):
+    """Scan a directory to find the interested files.
+    Args:
+        dir_path (str): Path of the directory.
+        suffix (str | tuple(str), optional): File suffix that we are
+            interested in. Default: None.
+        recursive (bool, optional): If set to True, recursively scan the
+            directory. Default: False.
+        full_path (bool, optional): If set to True, include the dir_path.
+            Default: False.
+    Returns:
+        A generator for all the interested files with relative paths.
+    """
+
+    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+        raise TypeError('"suffix" must be a string or tuple of strings')
+
+    root = dir_path
+
+    def _scandir(dir_path, suffix, recursive):
+        for entry in os.scandir(dir_path):
+            if not entry.name.startswith('.') and entry.is_file():
+                if full_path:
+                    return_path = entry.path
+                else:
+                    return_path = osp.relpath(entry.path, root)
+
+                if suffix is None:
+                    yield return_path
+                elif return_path.endswith(suffix):
+                    yield return_path
+            else:
+                if recursive:
+                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
+                else:
+                    continue
+
+    return _scandir(dir_path, suffix=suffix, recursive=recursive)
+
+
+def is_gray(img, threshold=10):
+    img = Image.fromarray(img)
+    if len(img.getbands()) == 1:
+        return True
+    img1 = np.asarray(img.getchannel(channel=0), dtype=np.int16)
+    img2 = np.asarray(img.getchannel(channel=1), dtype=np.int16)
+    img3 = np.asarray(img.getchannel(channel=2), dtype=np.int16)
+    diff1 = (img1 - img2).var()
+    diff2 = (img2 - img3).var()
+    diff3 = (img3 - img1).var()
+    diff_sum = (diff1 + diff2 + diff3) / 3.0
+    if diff_sum <= threshold:
+        return True
+    else:
+        return False
+
+def rgb2gray(img, out_channel=3):
+    r, g, b = img[:,:,0], img[:,:,1], img[:,:,2]
+    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
+    if out_channel == 3:
+        gray = gray[:,:,np.newaxis].repeat(3, axis=2)
+    return gray
+
+def bgr2gray(img, out_channel=3):
+    b, g, r = img[:,:,0], img[:,:,1], img[:,:,2]
+    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
+    if out_channel == 3:
+        gray = gray[:,:,np.newaxis].repeat(3, axis=2)
+    return gray
+
+
+def calc_mean_std(feat, eps=1e-5):
+    """
+    Args:
+        feat (ndarray): 3D channels-last feature of shape (h, w, c).
+    """
+    size = feat.shape
+    assert len(size) == 3, 'The input feature should be a 3D array.'
+    c = size[2]
+    feat_var = feat.reshape(-1, c).var(axis=0) + eps
+    feat_std = np.sqrt(feat_var).reshape(1, 1, c)
+    feat_mean = feat.reshape(-1, c).mean(axis=0).reshape(1, 1, c)
+    return feat_mean, feat_std
+
+
+def adain_npy(content_feat, style_feat):
+    """Adaptive instance normalization for numpy.
+
+    Args:
+        content_feat (numpy): The input feature.
+        style_feat (numpy): The reference feature.
+    """
+    size = content_feat.shape
+    style_mean, style_std = calc_mean_std(style_feat)
+    content_mean, content_std = calc_mean_std(content_feat)
+    normalized_feat = (content_feat - np.broadcast_to(content_mean, size)) / np.broadcast_to(content_std, size)
+    return normalized_feat * np.broadcast_to(style_std, size) + np.broadcast_to(style_mean, size)
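+
+
+if __name__ == '__main__':
+    # Small smoke test (illustrative, not part of the original module): exercise the
+    # gray-scale helpers and adain_npy on random data.
+    rng = np.random.default_rng(0)
+    rgb = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
+    print('is_gray(rgb): ', is_gray(rgb))
+    print('is_gray(gray):', is_gray(rgb2gray(rgb).astype(np.uint8)))
+    content = rng.standard_normal((16, 16, 3)).astype(np.float32)
+    style = rng.standard_normal((16, 16, 3)).astype(np.float32) * 2.0 + 1.0
+    stylized = adain_npy(content, style)
+    print('adain_npy output shape:', stylized.shape)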
diff --git a/vscodeformer/models/codeformer.pth b/vscodeformer/models/codeformer.pth
new file mode 100644
index 0000000..e69de29
diff --git a/vscodeformer/models/detection_Resnet50_Final.pth b/vscodeformer/models/detection_Resnet50_Final.pth
new file mode 100644
index 0000000..e69de29
diff --git a/vscodeformer/models/mmod_human_face_detector-4cb19393.dat b/vscodeformer/models/mmod_human_face_detector-4cb19393.dat
new file mode 100644
index 0000000..e69de29
diff --git a/vscodeformer/models/parsing_parsenet.pth b/vscodeformer/models/parsing_parsenet.pth
new file mode 100644
index 0000000..e69de29
diff --git a/vscodeformer/models/shape_predictor_5_face_landmarks-c4b1e980.dat b/vscodeformer/models/shape_predictor_5_face_landmarks-c4b1e980.dat
new file mode 100644
index 0000000..e69de29
diff --git a/vscodeformer/parsenet.py b/vscodeformer/parsenet.py
new file mode 100644
index 0000000..e178ebe
--- /dev/null
+++ b/vscodeformer/parsenet.py
@@ -0,0 +1,194 @@
+"""Modified from https://github.com/chaofengc/PSFRGAN
+"""
+import numpy as np
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+class NormLayer(nn.Module):
+    """Normalization Layers.
+
+    Args:
+        channels: input channels, for batch norm and instance norm.
+        normalize_shape: input shape without batch size, for layer norm.
+    """
+
+    def __init__(self, channels, normalize_shape=None, norm_type='bn'):
+        super(NormLayer, self).__init__()
+        norm_type = norm_type.lower()
+        self.norm_type = norm_type
+        if norm_type == 'bn':
+            self.norm = nn.BatchNorm2d(channels, affine=True)
+        elif norm_type == 'in':
+            self.norm = nn.InstanceNorm2d(channels, affine=False)
+        elif norm_type == 'gn':
+            self.norm = nn.GroupNorm(32, channels, affine=True)
+        elif norm_type == 'pixel':
+            self.norm = lambda x: F.normalize(x, p=2, dim=1)
+        elif norm_type == 'layer':
+            self.norm = nn.LayerNorm(normalize_shape)
+        elif norm_type == 'none':
+            self.norm = lambda x: x * 1.0
+        else:
+            raise ValueError(f'Norm type {norm_type} is not supported.')
+
+    def forward(self, x, ref=None):
+        if self.norm_type == 'spade':
+            return self.norm(x, ref)
+        else:
+            return self.norm(x)
+
+
+class ReluLayer(nn.Module):
+    """Relu Layer.
+
+    Args:
+        relu_type: type of relu layer, candidates are
+            - ReLU
+            - LeakyReLU: default relu slope 0.2
+            - PRelu
+            - SELU
+            - none: direct pass
+    """
+
+    def __init__(self, channels, relu_type='relu'):
+        super(ReluLayer, self).__init__()
+        relu_type = relu_type.lower()
+        if relu_type == 'relu':
+            self.func = nn.ReLU(True)
+        elif relu_type == 'leakyrelu':
+            self.func = nn.LeakyReLU(0.2, inplace=True)
+        elif relu_type == 'prelu':
+            self.func = nn.PReLU(channels)
+        elif relu_type == 'selu':
+            self.func = nn.SELU(True)
+        elif relu_type == 'none':
+            self.func = lambda x: x * 1.0
+        else:
+            raise ValueError(f'Relu type {relu_type} is not supported.')
+
+    def forward(self, x):
+        return self.func(x)
+
+
+class ConvLayer(nn.Module):
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size=3,
+                 scale='none',
+                 norm_type='none',
+                 relu_type='none',
+                 use_pad=True,
+                 bias=True):
+        super(ConvLayer, self).__init__()
+        self.use_pad = use_pad
+        self.norm_type = norm_type
+        if norm_type in ['bn']:
+            bias = False
+
+        stride = 2 if scale == 'down' else 1
+
+        self.scale_func = lambda x: x
+        if scale == 'up':
+            self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest')
+
+        self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) / 2)))
+        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)
+
+        self.relu = ReluLayer(out_channels, relu_type)
+        self.norm = NormLayer(out_channels, norm_type=norm_type)
+
+    def forward(self, x):
+        out = self.scale_func(x)
+        if self.use_pad:
+            out = self.reflection_pad(out)
+        out = self.conv2d(out)
+        out = self.norm(out)
+        out = self.relu(out)
+        return out
+
+
+class ResidualBlock(nn.Module):
+    """
+    Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html
+    """
+
+    def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'):
+        super(ResidualBlock, self).__init__()
+
+        if scale == 'none' and c_in == c_out:
+            self.shortcut_func = lambda x: x
+        else:
+            self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)
+
+        scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']}
+        scale_conf = scale_config_dict[scale]
+
+        self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type)
+        self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none')
+
+    def forward(self, x):
+        identity = self.shortcut_func(x)
+
+        res = self.conv1(x)
+        res = self.conv2(res)
+        return identity + res
+
+
+class ParseNet(nn.Module):
+
+    def __init__(self,
+                 in_size=128,
+                 out_size=128,
+                 min_feat_size=32,
+                 base_ch=64,
+                 parsing_ch=19,
+                 res_depth=10,
+                 relu_type='LeakyReLU',
+                 norm_type='bn',
+                 ch_range=[32, 256]):
+        super().__init__()
+        self.res_depth = res_depth
+        act_args = {'norm_type': norm_type, 'relu_type': relu_type}
+        min_ch, max_ch = ch_range
+
+        ch_clip = lambda x: max(min_ch, min(x, max_ch))  # noqa: E731
+        min_feat_size = min(in_size, min_feat_size)
+
+        down_steps = int(np.log2(in_size // min_feat_size))
+        up_steps = int(np.log2(out_size // min_feat_size))
+
+        # =============== define encoder-body-decoder ====================
+        self.encoder = []
+        self.encoder.append(ConvLayer(3, base_ch, 3, 1))
+        head_ch = base_ch
+        for i in range(down_steps):
+            cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
+            self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args))
+            head_ch = head_ch * 2
+
+        self.body = []
+        for i in range(res_depth):
+            self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args))
+
+        self.decoder = []
+        for i in range(up_steps):
+            cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
+            self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args))
+            head_ch = head_ch // 2
+
+        self.encoder = nn.Sequential(*self.encoder)
+        self.body = nn.Sequential(*self.body)
+        self.decoder = nn.Sequential(*self.decoder)
+        self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
+        self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)
+
+    def forward(self, x):
+        feat = self.encoder(x)
+        x = feat + self.body(feat)
+        x = self.decoder(x)
+        out_img = self.out_img_conv(x)
+        out_mask = self.out_mask_conv(x)
+        return out_mask, out_img
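+
+
+if __name__ == '__main__':
+    # Shape-only smoke test (illustrative, not part of the original module):
+    # a randomly initialised ParseNet maps a 128x128 RGB batch to a 19-channel
+    # parsing map plus a reconstructed RGB image at the same resolution.
+    import torch
+    net = ParseNet(in_size=128, out_size=128)
+    net.eval()
+    with torch.no_grad():
+        mask, img = net(torch.randn(1, 3, 128, 128))
+    print('mask:', tuple(mask.shape))  # (1, 19, 128, 128)
+    print('img: ', tuple(img.shape))   # (1, 3, 128, 128)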
diff --git a/vscodeformer/retinaface.py b/vscodeformer/retinaface.py
new file mode 100644
index 0000000..b9ba567
--- /dev/null
+++ b/vscodeformer/retinaface.py
@@ -0,0 +1,363 @@
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+from torchvision.models._utils import IntermediateLayerGetter as IntermediateLayerGetter
+
+from .align_trans import get_reference_facial_points, warp_and_crop_face
+from .retinaface_net import FPN, SSH, MobileNetV1, make_bbox_head, make_class_head, make_landmark_head
+from .retinaface_utils import PriorBox, batched_decode, batched_decode_landm, decode, decode_landm, py_cpu_nms
+
+
+def generate_config(network_name):
+
+    cfg_mnet = {
+        'name': 'mobilenet0.25',
+        'min_sizes': [[16, 32], [64, 128], [256, 512]],
+        'steps': [8, 16, 32],
+        'variance': [0.1, 0.2],
+        'clip': False,
+        'loc_weight': 2.0,
+        'gpu_train': True,
+        'batch_size': 32,
+        'ngpu': 1,
+        'epoch': 250,
+        'decay1': 190,
+        'decay2': 220,
+        'image_size': 640,
+        'return_layers': {
+            'stage1': 1,
+            'stage2': 2,
+            'stage3': 3
+        },
+        'in_channel': 32,
+        'out_channel': 64
+    }
+
+    cfg_re50 = {
+        'name': 'Resnet50',
+        'min_sizes': [[16, 32], [64, 128], [256, 512]],
+        'steps': [8, 16, 32],
+        'variance': [0.1, 0.2],
+        'clip': False,
+        'loc_weight': 2.0,
+        'gpu_train': True,
+        'batch_size': 24,
+        'ngpu': 4,
+        'epoch': 100,
+        'decay1': 70,
+        'decay2': 90,
+        'image_size': 840,
+        'return_layers': {
+            'layer2': 1,
+            'layer3': 2,
+            'layer4': 3
+        },
+        'in_channel': 256,
+        'out_channel': 256
+    }
+
+    if network_name == 'mobile0.25':
+        return cfg_mnet
+    elif network_name == 'resnet50':
+        return cfg_re50
+    else:
+        raise NotImplementedError(f'network_name={network_name}')
+
+
+class RetinaFace(nn.Module):
+
+    def __init__(self, device, network_name='resnet50', half=False, phase='test'):
+        super(RetinaFace, self).__init__()
+        self.device = device
+        self.half_inference = half
+        cfg = generate_config(network_name)
+        self.backbone = cfg['name']
+
+        self.model_name = f'retinaface_{network_name}'
+        self.cfg = cfg
+        self.phase = phase
+        self.target_size, self.max_size = 1600, 2150
+        self.resize, self.scale, self.scale1 = 1., None, None
+        self.mean_tensor = torch.tensor([[[[104.]], [[117.]], [[123.]]]], device=device)
+        self.reference = get_reference_facial_points(default_square=True)
+        # Build network.
+        backbone = None
+        if cfg['name'] == 'mobilenet0.25':
+            backbone = MobileNetV1()
+            self.body = IntermediateLayerGetter(backbone, cfg['return_layers'])
+        elif cfg['name'] == 'Resnet50':
+            import torchvision.models as models
+            backbone = models.resnet50(pretrained=False)
+            self.body = IntermediateLayerGetter(backbone, cfg['return_layers'])
+
+        in_channels_stage2 = cfg['in_channel']
+        in_channels_list = [
+            in_channels_stage2 * 2,
+            in_channels_stage2 * 4,
+            in_channels_stage2 * 8,
+        ]
+
+        out_channels = cfg['out_channel']
+        self.fpn = FPN(in_channels_list, out_channels)
+        self.ssh1 = SSH(out_channels, out_channels)
+        self.ssh2 = SSH(out_channels, out_channels)
+        self.ssh3 = SSH(out_channels, out_channels)
+
+        self.ClassHead = make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
+        self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
+        self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])
+
+    def forward(self, inputs):
+        out = self.body(inputs)
+
+        if self.backbone == 'mobilenet0.25' or self.backbone == 'Resnet50':
+            out = list(out.values())
+        # FPN
+        fpn = self.fpn(out)
+
+        # SSH
+        feature1 = self.ssh1(fpn[0])
+        feature2 = self.ssh2(fpn[1])
+        feature3 = self.ssh3(fpn[2])
+        features = [feature1, feature2, feature3]
+
+        bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
+        classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1)
+        tmp = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)]
+        ldm_regressions = (torch.cat(tmp, dim=1))
+
+        if self.phase == 'train':
+            output = (bbox_regressions, classifications, ldm_regressions)
+        else:
+            output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
+        return output
+
+    def __detect_faces(self, inputs):
+        # get scale
+        height, width = inputs.shape[2:]
+        self.scale = torch.tensor([width, height, width, height], dtype=torch.float32, device=self.device)
+        tmp = [width, height, width, height, width, height, width, height, width, height]
+        self.scale1 = torch.tensor(tmp, dtype=torch.float32, device=self.device)
+
+        # forward
+        inputs = inputs.to(self.device)
+        if self.half_inference:
+            inputs = inputs.half()
+        loc, conf, landmarks = self(inputs)
+
+        # get priorbox
+        priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:])
+        priors = priorbox.forward().to(self.device)
+
+        return loc, conf, landmarks, priors
+
+    # single image detection
+    def transform(self, image, use_origin_size):
+        # convert to opencv format
+        if isinstance(image, Image.Image):
+            image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+        image = image.astype(np.float32)
+
+        # testing scale
+        im_size_min = np.min(image.shape[0:2])
+        im_size_max = np.max(image.shape[0:2])
+        resize = float(self.target_size) / float(im_size_min)
+
+        # prevent bigger axis from being more than max_size
+        if np.round(resize * im_size_max) > self.max_size:
+            resize = float(self.max_size) / float(im_size_max)
+        resize = 1 if use_origin_size else resize
+
+        # resize
+        if resize != 1:
+            image = cv2.resize(image, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
+
+        # convert to torch.tensor format
+        # image -= (104, 117, 123)
+        image = image.transpose(2, 0, 1)
+        image = torch.from_numpy(image).unsqueeze(0)
+
+        return image, resize
+
+    def detect_faces(
+        self,
+        image,
+        conf_threshold=0.8,
+        nms_threshold=0.4,
+        use_origin_size=True,
+    ):
+        """
+        Params:
+            image: BGR image (np.ndarray or PIL.Image).
+        """
+        image, self.resize = self.transform(image, use_origin_size)
+        image = image.to(self.device)
+        if self.half_inference:
+            image = image.half()
+        image = image - self.mean_tensor
+
+        loc, conf, landmarks, priors = self.__detect_faces(image)
+
+        boxes = decode(loc.data.squeeze(0), priors.data, self.cfg['variance'])
+        boxes = boxes * self.scale / self.resize
+        boxes = boxes.cpu().numpy()
+
+        scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
+
+        landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg['variance'])
+        landmarks = landmarks * self.scale1 / self.resize
+        landmarks = landmarks.cpu().numpy()
+
+        # ignore low scores
+        inds = np.where(scores > conf_threshold)[0]
+        boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds]
+
+        # sort
+        order = scores.argsort()[::-1]
+        boxes, landmarks, scores = boxes[order], landmarks[order], scores[order]
+
+        # do NMS
+        bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
+        keep = py_cpu_nms(bounding_boxes, nms_threshold)
+        bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep]
+        # self.t['forward_pass'].toc()
+        # print(self.t['forward_pass'].average_time)
+        # import sys
+        # sys.stdout.flush()
+        return np.concatenate((bounding_boxes, landmarks), axis=1)
+
+    def __align_multi(self, image, boxes, landmarks, limit=None):
+
+        if len(boxes) < 1:
+            return [], []
+
+        if limit:
+            boxes = boxes[:limit]
+            landmarks = landmarks[:limit]
+
+        faces = []
+        for landmark in landmarks:
+            facial5points = [[landmark[2 * j], landmark[2 * j + 1]] for j in range(5)]
+
+            warped_face = warp_and_crop_face(np.array(image), facial5points, self.reference, crop_size=(112, 112))
+            faces.append(warped_face)
+
+        return np.concatenate((boxes, landmarks), axis=1), faces
+
+    def align_multi(self, img, conf_threshold=0.8, limit=None):
+
+        rlt = self.detect_faces(img, conf_threshold=conf_threshold)
+        boxes, landmarks = rlt[:, 0:5], rlt[:, 5:]
+
+        return self.__align_multi(img, boxes, landmarks, limit)
+
+    # batched detection
+    def batched_transform(self, frames, use_origin_size):
+        """
+        Arguments:
+            frames: a list of PIL.Image, or torch.Tensor(shape=[n, h, w, c],
+                type=np.float32, BGR format).
+            use_origin_size: whether to use origin size.
+        """
+        from_PIL = True if isinstance(frames[0], Image.Image) else False
+
+        # convert to opencv format
+        if from_PIL:
+            frames = [cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames]
+            frames = np.asarray(frames, dtype=np.float32)
+
+        # testing scale
+        im_size_min = np.min(frames[0].shape[0:2])
+        im_size_max = np.max(frames[0].shape[0:2])
+        resize = float(self.target_size) / float(im_size_min)
+
+        # prevent bigger axis from being more than max_size
+        if np.round(resize * im_size_max) > self.max_size:
+            resize = float(self.max_size) / float(im_size_max)
+        resize = 1 if use_origin_size else resize
+
+        # resize
+        if resize != 1:
+            if not from_PIL:
+                frames = F.interpolate(frames, scale_factor=resize)
+            else:
+                frames = [
+                    cv2.resize(frame, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
+                    for frame in frames
+                ]
+
+        # convert to torch.tensor format
+        if not from_PIL:
+            frames = frames.transpose(1, 2).transpose(1, 3).contiguous()
+        else:
+            frames = frames.transpose((0, 3, 1, 2))
+            frames = torch.from_numpy(frames)
+
+        return frames, resize
+
+    def batched_detect_faces(self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True):
+        """
+        Arguments:
+            frames: a list of PIL.Image, or np.array(shape=[n, h, w, c],
+                type=np.uint8, BGR format).
+            conf_threshold: confidence threshold.
+            nms_threshold: nms threshold.
+            use_origin_size: whether to use origin size.
+        Returns:
+            final_bounding_boxes: list of np.array ([n_boxes, 5],
+                type=np.float32).
+            final_landmarks: list of np.array ([n_boxes, 10], type=np.float32).
+        """
+        # self.t['forward_pass'].tic()
+        frames, self.resize = self.batched_transform(frames, use_origin_size)
+        frames = frames.to(self.device)
+        frames = frames - self.mean_tensor
+
+        b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames)
+
+        final_bounding_boxes, final_landmarks = [], []
+
+        # decode
+        priors = priors.unsqueeze(0)
+        b_loc = batched_decode(b_loc, priors, self.cfg['variance']) * self.scale / self.resize
+        b_landmarks = batched_decode_landm(b_landmarks, priors, self.cfg['variance']) * self.scale1 / self.resize
+        b_conf = b_conf[:, :, 1]
+
+        # index for selection
+        b_indice = b_conf > conf_threshold
+
+        # concat
+        b_loc_and_conf = torch.cat((b_loc, b_conf.unsqueeze(-1)), dim=2).float()
+
+        for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice):
+
+            # ignore low scores
+            pred, landm = pred[inds, :], landm[inds, :]
+            if pred.shape[0] == 0:
+                final_bounding_boxes.append(np.array([], dtype=np.float32))
+                final_landmarks.append(np.array([], dtype=np.float32))
+                continue
+
+            # sort
+            # order = score.argsort(descending=True)
+            # box, landm, score = box[order], landm[order], score[order]
+
+            # to CPU
+            bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy()
+
+            # NMS
+            keep = py_cpu_nms(bounding_boxes, nms_threshold)
+            bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep]
+
+            # append
+            final_bounding_boxes.append(bounding_boxes)
+            final_landmarks.append(landmarks)
+        # self.t['forward_pass'].toc(average=True)
+        # self.batch_time += self.t['forward_pass'].diff
+        # self.total_frame += len(frames)
+        # print(self.batch_time / self.total_frame)
+
+        return final_bounding_boxes, final_landmarks
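+
+
+if __name__ == '__main__':
+    # Forward-shape smoke test (illustrative, not part of the original module); run as
+    # `python -m vscodeformer.retinaface` so the relative imports resolve. The network
+    # is randomly initialised here, so the outputs are only meaningful in shape.
+    net = RetinaFace(device='cpu', network_name='resnet50')
+    net.eval()
+    with torch.no_grad():
+        loc, conf, landms = net(torch.randn(1, 3, 224, 224))
+    print('bbox: ', tuple(loc.shape))     # (1, num_anchors, 4)
+    print('conf: ', tuple(conf.shape))    # (1, num_anchors, 2)
+    print('landm:', tuple(landms.shape))  # (1, num_anchors, 10)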
diff --git a/vscodeformer/retinaface_net.py b/vscodeformer/retinaface_net.py
new file mode 100644
index 0000000..ab6aa82
--- /dev/null
+++ b/vscodeformer/retinaface_net.py
@@ -0,0 +1,196 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def conv_bn(inp, oup, stride=1, leaky=0):
+    return nn.Sequential(
+        nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup),
+        nn.LeakyReLU(negative_slope=leaky, inplace=True))
+
+
+def conv_bn_no_relu(inp, oup, stride):
+    return nn.Sequential(
+        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+        nn.BatchNorm2d(oup),
+    )
+
+
+def conv_bn1X1(inp, oup, stride, leaky=0):
+    return nn.Sequential(
+        nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup),
+        nn.LeakyReLU(negative_slope=leaky, inplace=True))
+
+
+def conv_dw(inp, oup, stride, leaky=0.1):
+    return nn.Sequential(
+        nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
+        nn.BatchNorm2d(inp),
+        nn.LeakyReLU(negative_slope=leaky, inplace=True),
+        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
+        nn.BatchNorm2d(oup),
+        nn.LeakyReLU(negative_slope=leaky, inplace=True),
+    )
+
+
+class SSH(nn.Module):
+
+    def __init__(self, in_channel, out_channel):
+        super(SSH, self).__init__()
+        assert out_channel % 4 == 0
+        leaky = 0
+        if (out_channel <= 64):
+            leaky = 0.1
+        self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1)
+
+        self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky)
+        self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
+
+        self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky)
+        self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
+
+    def forward(self, input):
+        conv3X3 = self.conv3X3(input)
+
+        conv5X5_1 = self.conv5X5_1(input)
+        conv5X5 = self.conv5X5_2(conv5X5_1)
+
+        conv7X7_2 = self.conv7X7_2(conv5X5_1)
+        conv7X7 = self.conv7x7_3(conv7X7_2)
+
+        out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
+        out = F.relu(out)
+        return out
+
+
+class FPN(nn.Module):
+
+    def __init__(self, in_channels_list, out_channels):
+        super(FPN, self).__init__()
+        leaky = 0
+        if (out_channels <= 64):
+            leaky = 0.1
+        self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky)
+        self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky)
+        self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky)
+
+        self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
+        self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)
+
+    def forward(self, input):
+        # names = list(input.keys())
+        # input = list(input.values())
+
+        output1 = self.output1(input[0])
+        output2 = self.output2(input[1])
+        output3 = self.output3(input[2])
+
+        up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest')
+        output2 = output2 + up3
+        output2 = self.merge2(output2)
+
+        up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest')
+        output1 = output1 + up2
+        output1 = self.merge1(output1)
+
+        out = [output1, output2, output3]
+        return out
+
+
+class MobileNetV1(nn.Module):
+
+    def __init__(self):
+        super(MobileNetV1, self).__init__()
+        self.stage1 = nn.Sequential(
+            conv_bn(3, 8, 2, leaky=0.1),  # 3
+            conv_dw(8, 16, 1),  # 7
+            conv_dw(16, 32, 2),  # 11
+            conv_dw(32, 32, 1),  # 19
+            conv_dw(32, 64, 2),  # 27
+            conv_dw(64, 64, 1),  # 43
+        )
+        self.stage2 = nn.Sequential(
+            conv_dw(64, 128, 2),  # 43 + 16 = 59
+            conv_dw(128, 128, 1),  # 59 + 32 = 91
+            conv_dw(128, 128, 1),  # 91 + 32 = 123
+            conv_dw(128, 128, 1),  # 123 + 32 = 155
+            conv_dw(128, 128, 1),  # 155 + 32 = 187
+            conv_dw(128, 128, 1),  # 187 + 32 = 219
+        )
+        self.stage3 = nn.Sequential(
+            conv_dw(128, 256, 2),  # 219 + 32 = 251
+            conv_dw(256, 256, 1),  # 251 + 64 = 315
+        )
+        self.avg = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(256, 1000)
+
+    def forward(self, x):
+        x = self.stage1(x)
+        x = self.stage2(x)
+        x = self.stage3(x)
+        x = self.avg(x)
+        # x = self.model(x)
+        x = x.view(-1, 256)
+        x = self.fc(x)
+        return x
+
+
+class ClassHead(nn.Module):
+
+    def __init__(self, inchannels=512, num_anchors=3):
+        super(ClassHead, self).__init__()
+        self.num_anchors = num_anchors
+        self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0)
+
+    def forward(self, x):
+        out = self.conv1x1(x)
+        out = out.permute(0, 2, 3, 1).contiguous()
+
+        return out.view(out.shape[0], -1, 2)
+
+
+class BboxHead(nn.Module):
+
+    def __init__(self, inchannels=512, num_anchors=3):
+        super(BboxHead, self).__init__()
+        self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0)
+
+    def forward(self, x):
+        out = self.conv1x1(x)
+        out = out.permute(0, 2, 3, 1).contiguous()
+
+        return out.view(out.shape[0], -1, 4)
+
+
+class LandmarkHead(nn.Module):
+
+    def __init__(self, inchannels=512, num_anchors=3):
+        super(LandmarkHead, self).__init__()
+        self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0)
+
+    def forward(self, x):
+        out = self.conv1x1(x)
+        out = out.permute(0, 2, 3, 1).contiguous()
+
+        return out.view(out.shape[0], -1, 10)
+
+
+def make_class_head(fpn_num=3, inchannels=64, anchor_num=2):
+    classhead = nn.ModuleList()
+    for i in range(fpn_num):
+        classhead.append(ClassHead(inchannels, anchor_num))
+    return classhead
+
+
+def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2):
+    bboxhead = nn.ModuleList()
+    for i in range(fpn_num):
+        bboxhead.append(BboxHead(inchannels, anchor_num))
+    return bboxhead
+
+
+def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2):
+    landmarkhead = nn.ModuleList()
+    for i in range(fpn_num):
+        landmarkhead.append(LandmarkHead(inchannels, anchor_num))
+    return landmarkhead
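+
+
+if __name__ == '__main__':
+    # Shape-only smoke test (illustrative, not part of the original module) for the
+    # building blocks used by retinaface.py: the MobileNetV1 backbone and an SSH head.
+    backbone = MobileNetV1()
+    backbone.eval()
+    with torch.no_grad():
+        logits = backbone(torch.randn(1, 3, 224, 224))
+    print('MobileNetV1 logits:', tuple(logits.shape))  # (1, 1000)
+    ssh = SSH(in_channel=64, out_channel=64)
+    ssh.eval()
+    with torch.no_grad():
+        feat = ssh(torch.randn(1, 64, 28, 28))
+    print('SSH feature:       ', tuple(feat.shape))    # (1, 64, 28, 28)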
diff --git a/vscodeformer/retinaface_utils.py b/vscodeformer/retinaface_utils.py
new file mode 100644
index 0000000..1324af3
--- /dev/null
+++ b/vscodeformer/retinaface_utils.py
@@ -0,0 +1,422 @@
+from itertools import product
+from math import ceil
+
+import numpy as np
+import torch
+import torchvision
+
+
+class PriorBox(object):
+
+    def __init__(self, cfg, image_size=None, phase='train'):
+        super(PriorBox, self).__init__()
+        self.min_sizes = cfg['min_sizes']
+        self.steps = cfg['steps']
+        self.clip = cfg['clip']
+        self.image_size = image_size
+        self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps]
+        self.name = 's'
+
+    def forward(self):
+        anchors = []
+        for k, f in enumerate(self.feature_maps):
+            min_sizes = self.min_sizes[k]
+            for i, j in product(range(f[0]), range(f[1])):
+                for min_size in min_sizes:
+                    s_kx = min_size / self.image_size[1]
+                    s_ky = min_size / self.image_size[0]
+                    dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
+                    dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
+                    for cy, cx in product(dense_cy, dense_cx):
+                        anchors += [cx, cy, s_kx, s_ky]
+
+        # back to torch land
+        output = torch.Tensor(anchors).view(-1, 4)
+        if self.clip:
+            output.clamp_(max=1, min=0)
+        return output
+
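+# Illustrative note (not part of the original file): with cfg as produced by
+# generate_config() in retinaface.py, PriorBox(cfg, image_size=(640, 640)).forward()
+# returns an (N, 4) tensor of [cx, cy, w, h] anchors normalized to the input size.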
+
+def py_cpu_nms(dets, thresh):
+    """Non-maximum suppression via torchvision.ops.nms (keeps the original py_cpu_nms interface)."""
+    keep = torchvision.ops.nms(
+        boxes=torch.Tensor(dets[:, :4]),
+        scores=torch.Tensor(dets[:, 4]),
+        iou_threshold=thresh,
+    )
+
+    return list(keep)
+
+
+def point_form(boxes):
+    """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
+    representation for comparison to point form ground truth data.
+    Args:
+        boxes: (tensor) center-size default boxes from priorbox layers.
+    Return:
+        boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
+    """
+    return torch.cat(
+        (
+            boxes[:, :2] - boxes[:, 2:] / 2,  # xmin, ymin
+            boxes[:, :2] + boxes[:, 2:] / 2),
+        1)  # xmax, ymax
+
+
+def center_size(boxes):
+    """ Convert point_form boxes (xmin, ymin, xmax, ymax) to (cx, cy, w, h)
+    representation for comparison to center-size form ground truth data.
+    Args:
+        boxes: (tensor) point_form boxes
+    Return:
+        boxes: (tensor) Converted (cx, cy, w, h) form of boxes.
+    """
+    return torch.cat(
+        (
+            (boxes[:, 2:] + boxes[:, :2]) / 2,  # cx, cy
+            boxes[:, 2:] - boxes[:, :2]),  # w, h
+        1)
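+
+# Worked example (illustrative): a prior at (cx, cy, w, h) = (0.5, 0.5, 0.2, 0.2)
+# becomes (0.4, 0.4, 0.6, 0.6) in point form, and center_size() maps it back:
+#   b = torch.tensor([[0.5, 0.5, 0.2, 0.2]])
+#   point_form(b)               # tensor([[0.4000, 0.4000, 0.6000, 0.6000]])
+#   center_size(point_form(b))  # tensor([[0.5000, 0.5000, 0.2000, 0.2000]])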
+
+
+def intersect(box_a, box_b):
+    """ We resize both tensors to [A,B,2] without new malloc:
+    [A,2] -> [A,1,2] -> [A,B,2]
+    [B,2] -> [1,B,2] -> [A,B,2]
+    Then we compute the area of intersect between box_a and box_b.
+    Args:
+      box_a: (tensor) bounding boxes, Shape: [A,4].
+      box_b: (tensor) bounding boxes, Shape: [B,4].
+    Return:
+      (tensor) intersection area, Shape: [A,B].
+    """
+    A = box_a.size(0)
+    B = box_b.size(0)
+    max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
+    min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
+    inter = torch.clamp((max_xy - min_xy), min=0)
+    return inter[:, :, 0] * inter[:, :, 1]
+
+
+def jaccard(box_a, box_b):
+    """Compute the jaccard overlap of two sets of boxes.  The jaccard overlap
+    is simply the intersection over union of two boxes.  Here we operate on
+    ground truth boxes and default boxes.
+    E.g.:
+        A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
+    Args:
+        box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
+        box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
+    Return:
+        jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
+    """
+    inter = intersect(box_a, box_b)
+    area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]
+    area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]
+    union = area_a + area_b - inter
+    return inter / union  # [A,B]
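+
+# e.g. (illustrative): box_a = [[0., 0., 2., 2.]] and box_b = [[1., 1., 3., 3.]]
+# intersect in the unit square [1, 1, 2, 2], so jaccard(box_a, box_b) returns
+# 1 / (4 + 4 - 1) ≈ 0.143.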
+
+
+def matrix_iou(a, b):
+    """
+    Return the IoU of a and b (NumPy version, used for data augmentation).
+    """
+    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+    area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
+    return area_i / (area_a[:, np.newaxis] + area_b - area_i)
+
+
+def matrix_iof(a, b):
+    """
+    Return the IoF (intersection over foreground) of a and b (NumPy version, used for data augmentation).
+    """
+    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+    return area_i / np.maximum(area_a[:, np.newaxis], 1)
+
+
+def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
+    """Match each prior box with the ground truth box of the highest jaccard
+    overlap, encode the bounding boxes, then return the matched indices
+    corresponding to both confidence and location preds.
+    Args:
+        threshold: (float) The overlap threshold used when matching boxes.
+        truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
+        priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
+        variances: (tensor) Variances corresponding to each prior coord,
+            Shape: [num_priors, 4].
+        labels: (tensor) All the class labels for the image, Shape: [num_obj].
+        landms: (tensor) Ground truth landms, Shape [num_obj, 10].
+        loc_t: (tensor) Tensor to be filled w/ encoded location targets.
+        conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
+        landm_t: (tensor) Tensor to be filled w/ encoded landm targets.
+        idx: (int) current batch index
+    Return:
+        The matched indices corresponding to 1)location 2)confidence
+        3)landm preds.
+    """
+    # jaccard index
+    overlaps = jaccard(truths, point_form(priors))
+    # (Bipartite Matching)
+    # [1,num_objects] best prior for each ground truth
+    best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
+
+    # ignore hard gt
+    valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
+    best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
+    if best_prior_idx_filter.shape[0] <= 0:
+        loc_t[idx] = 0
+        conf_t[idx] = 0
+        return
+
+    # [1,num_priors] best ground truth for each prior
+    best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
+    best_truth_idx.squeeze_(0)
+    best_truth_overlap.squeeze_(0)
+    best_prior_idx.squeeze_(1)
+    best_prior_idx_filter.squeeze_(1)
+    best_prior_overlap.squeeze_(1)
+    best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2)  # ensure best prior
+    # TODO refactor: index best_prior_idx with long tensor
+    # ensure every gt matches with its prior of max overlap
+    for j in range(best_prior_idx.size(0)):  # decide which ground-truth box each of these anchors predicts
+        best_truth_idx[best_prior_idx[j]] = j
+    matches = truths[best_truth_idx]  # Shape: [num_priors,4]  matched gt box for every prior
+    conf = labels[best_truth_idx]  # Shape: [num_priors]  matched gt label for every prior
+    conf[best_truth_overlap < threshold] = 0  # priors with overlap < threshold are labelled as background
+    loc = encode(matches, priors, variances)
+
+    matches_landm = landms[best_truth_idx]
+    landm = encode_landm(matches_landm, priors, variances)
+    loc_t[idx] = loc  # [num_priors,4] encoded offsets to learn
+    conf_t[idx] = conf  # [num_priors] top class label for each prior
+    landm_t[idx] = landm
+
+
+def encode(matched, priors, variances):
+    """Encode the variances from the priorbox layers into the ground truth boxes
+    we have matched (based on jaccard overlap) with the prior boxes.
+    Args:
+        matched: (tensor) Coords of ground truth for each prior in point-form
+            Shape: [num_priors, 4].
+        priors: (tensor) Prior boxes in center-offset form
+            Shape: [num_priors,4].
+        variances: (list[float]) Variances of priorboxes
+    Return:
+        encoded boxes (tensor), Shape: [num_priors, 4]
+    """
+
+    # dist b/t match center and prior's center
+    g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
+    # encode variance
+    g_cxcy /= (variances[0] * priors[:, 2:])
+    # match wh / prior wh
+    g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
+    g_wh = torch.log(g_wh) / variances[1]
+    # return target for smooth_l1_loss
+    return torch.cat([g_cxcy, g_wh], 1)  # [num_priors,4]
+
+
+def encode_landm(matched, priors, variances):
+    """Encode the variances from the priorbox layers into the ground truth boxes
+    we have matched (based on jaccard overlap) with the prior boxes.
+    Args:
+        matched: (tensor) Coords of ground truth for each prior in point-form
+            Shape: [num_priors, 10].
+        priors: (tensor) Prior boxes in center-offset form
+            Shape: [num_priors,4].
+        variances: (list[float]) Variances of priorboxes
+    Return:
+        encoded landm (tensor), Shape: [num_priors, 10]
+    """
+
+    # dist b/t match center and prior's center
+    matched = torch.reshape(matched, (matched.size(0), 5, 2))
+    priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+    priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+    priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+    priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+    priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2)
+    g_cxcy = matched[:, :, :2] - priors[:, :, :2]
+    # encode variance
+    g_cxcy /= (variances[0] * priors[:, :, 2:])
+    # g_cxcy /= priors[:, :, 2:]
+    g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1)
+    # return target for smooth_l1_loss
+    return g_cxcy
+
+
+# Adapted from https://github.com/Hakuyume/chainer-ssd
+def decode(loc, priors, variances):
+    """Decode locations from predictions using priors to undo
+    the encoding we did for offset regression at train time.
+    Args:
+        loc (tensor): location predictions for loc layers,
+            Shape: [num_priors,4]
+        priors (tensor): Prior boxes in center-offset form.
+            Shape: [num_priors,4].
+        variances: (list[float]) Variances of priorboxes
+    Return:
+        decoded bounding box predictions
+    """
+
+    boxes = torch.cat((priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
+                       priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
+    boxes[:, :2] -= boxes[:, 2:] / 2
+    boxes[:, 2:] += boxes[:, :2]
+    return boxes
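+
+# Sanity-check sketch (the [0.1, 0.2] variances below are the common SSD
+# defaults, assumed here rather than read from a config):
+#   priors  = torch.tensor([[0.5, 0.5, 0.2, 0.2]])      # (cx, cy, w, h)
+#   gt      = torch.tensor([[0.45, 0.45, 0.65, 0.65]])  # point form
+#   offsets = encode(gt, priors, [0.1, 0.2])             # tensor([[2.5, 2.5, 0., 0.]])
+#   decode(offsets, priors, [0.1, 0.2])                  # recovers gt (up to float error)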
+
+
+def decode_landm(pre, priors, variances):
+    """Decode landm from predictions using priors to undo
+    the encoding we did for offset regression at train time.
+    Args:
+        pre (tensor): landm predictions for loc layers,
+            Shape: [num_priors,10]
+        priors (tensor): Prior boxes in center-offset form.
+            Shape: [num_priors,4].
+        variances: (list[float]) Variances of priorboxes
+    Return:
+        decoded landm predictions
+    """
+    tmp = (
+        priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
+        priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
+        priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
+        priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
+        priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
+    )
+    landms = torch.cat(tmp, dim=1)
+    return landms
+
+
+def batched_decode(b_loc, priors, variances):
+    """Decode locations from predictions using priors to undo
+    the encoding we did for offset regression at train time.
+    Args:
+        b_loc (tensor): location predictions for loc layers,
+            Shape: [num_batches,num_priors,4]
+        priors (tensor): Prior boxes in center-offset form.
+            Shape: [1,num_priors,4].
+        variances: (list[float]) Variances of priorboxes
+    Return:
+        decoded bounding box predictions
+    """
+    boxes = (
+        priors[:, :, :2] + b_loc[:, :, :2] * variances[0] * priors[:, :, 2:],
+        priors[:, :, 2:] * torch.exp(b_loc[:, :, 2:] * variances[1]),
+    )
+    boxes = torch.cat(boxes, dim=2)
+
+    boxes[:, :, :2] -= boxes[:, :, 2:] / 2
+    boxes[:, :, 2:] += boxes[:, :, :2]
+    return boxes
+
+
+def batched_decode_landm(pre, priors, variances):
+    """Decode landm from predictions using priors to undo
+    the encoding we did for offset regression at train time.
+    Args:
+        pre (tensor): landm predictions for loc layers,
+            Shape: [num_batches,num_priors,10]
+        priors (tensor): Prior boxes in center-offset form.
+            Shape: [1,num_priors,4].
+        variances: (list[float]) Variances of priorboxes
+    Return:
+        decoded landm predictions
+    """
+    landms = (
+        priors[:, :, :2] + pre[:, :, :2] * variances[0] * priors[:, :, 2:],
+        priors[:, :, :2] + pre[:, :, 2:4] * variances[0] * priors[:, :, 2:],
+        priors[:, :, :2] + pre[:, :, 4:6] * variances[0] * priors[:, :, 2:],
+        priors[:, :, :2] + pre[:, :, 6:8] * variances[0] * priors[:, :, 2:],
+        priors[:, :, :2] + pre[:, :, 8:10] * variances[0] * priors[:, :, 2:],
+    )
+    landms = torch.cat(landms, dim=2)
+    return landms
+
+
+def log_sum_exp(x):
+    """Utility function for computing log_sum_exp while determining
+    This will be used to determine unaveraged confidence loss across
+    all examples in a batch.
+    Args:
+        x (Variable(tensor)): conf_preds from conf layers
+    """
+    x_max = x.data.max()
+    return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max
+
+
+# Original author: Francisco Massa:
+# https://github.com/fmassa/object-detection.torch
+# Ported to PyTorch by Max deGroot (02/01/2017)
+def nms(boxes, scores, overlap=0.5, top_k=200):
+    """Apply non-maximum suppression at test time to avoid detecting too many
+    overlapping bounding boxes for a given object.
+    Args:
+        boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
+        scores: (tensor) The class prediction scores for the img, Shape:[num_priors].
+        overlap: (float) The overlap thresh for suppressing unnecessary boxes.
+        top_k: (int) The maximum number of box preds to consider.
+    Return:
+        The indices of the kept boxes with respect to num_priors.
+    """
+
+    keep = torch.Tensor(scores.size(0)).fill_(0).long()
+    if boxes.numel() == 0:
+        return keep
+    x1 = boxes[:, 0]
+    y1 = boxes[:, 1]
+    x2 = boxes[:, 2]
+    y2 = boxes[:, 3]
+    area = torch.mul(x2 - x1, y2 - y1)
+    v, idx = scores.sort(0)  # sort in ascending order
+    # I = I[v >= 0.01]
+    idx = idx[-top_k:]  # indices of the top-k largest vals
+    xx1 = boxes.new()
+    yy1 = boxes.new()
+    xx2 = boxes.new()
+    yy2 = boxes.new()
+    w = boxes.new()
+    h = boxes.new()
+
+    # keep = torch.Tensor()
+    count = 0
+    while idx.numel() > 0:
+        i = idx[-1]  # index of current largest val
+        # keep.append(i)
+        keep[count] = i
+        count += 1
+        if idx.size(0) == 1:
+            break
+        idx = idx[:-1]  # remove kept element from view
+        # load bboxes of next highest vals
+        torch.index_select(x1, 0, idx, out=xx1)
+        torch.index_select(y1, 0, idx, out=yy1)
+        torch.index_select(x2, 0, idx, out=xx2)
+        torch.index_select(y2, 0, idx, out=yy2)
+        # store element-wise max with next highest score
+        xx1 = torch.clamp(xx1, min=x1[i])
+        yy1 = torch.clamp(yy1, min=y1[i])
+        xx2 = torch.clamp(xx2, max=x2[i])
+        yy2 = torch.clamp(yy2, max=y2[i])
+        w.resize_as_(xx2)
+        h.resize_as_(yy2)
+        w = xx2 - xx1
+        h = yy2 - yy1
+        # check sizes of xx1 and xx2.. after each iteration
+        w = torch.clamp(w, min=0.0)
+        h = torch.clamp(h, min=0.0)
+        inter = w * h
+        # IoU = i / (area(a) + area(b) - i)
+        rem_areas = torch.index_select(area, 0, idx)  # load remaining areas
+        union = (rem_areas - inter) + area[i]
+        IoU = inter / union  # store result in iou
+        # keep only elements with an IoU <= overlap
+        idx = idx[IoU.le(overlap)]
+    return keep, count
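+
+
+if __name__ == '__main__':
+    # Tiny NMS smoke test (illustrative only): three boxes, two of which overlap
+    # heavily; both implementations keep the higher-scoring box of that pair.
+    boxes = torch.tensor([
+        [0.0, 0.0, 10.0, 10.0],
+        [1.0, 1.0, 11.0, 11.0],    # IoU with the first box is 81/119 ≈ 0.68
+        [50.0, 50.0, 60.0, 60.0],
+    ])
+    scores = torch.tensor([0.9, 0.8, 0.7])
+
+    keep, count = nms(boxes, scores, overlap=0.5)
+    print(keep[:count])  # tensor([0, 2])
+
+    dets = np.hstack([boxes.numpy(), scores.numpy()[:, None]]).astype(np.float32)
+    print(py_cpu_nms(dets, 0.5))  # same surviving indices via torchvision.ops.nms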
diff --git a/vscodeformer/vqgan_arch.py b/vscodeformer/vqgan_arch.py
new file mode 100644
index 0000000..cb429ee
--- /dev/null
+++ b/vscodeformer/vqgan_arch.py
@@ -0,0 +1,426 @@
+'''
+VQGAN code, adapted from the original created by the Unleashing Transformers authors:
+https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
+
+'''
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def normalize(in_channels):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+@torch.jit.script
+def swish(x):
+    return x*torch.sigmoid(x)
+
+
+#  Define VQVAE classes
+class VectorQuantizer(nn.Module):
+    def __init__(self, codebook_size, emb_dim, beta):
+        super(VectorQuantizer, self).__init__()
+        self.codebook_size = codebook_size  # number of embeddings
+        self.emb_dim = emb_dim  # dimension of embedding
+        self.beta = beta  # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+        self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
+        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
+
+    def forward(self, z):
+        # reshape z -> (batch, height, width, channel) and flatten
+        z = z.permute(0, 2, 3, 1).contiguous()
+        z_flattened = z.view(-1, self.emb_dim)
+
+        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+        d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
+            2 * torch.matmul(z_flattened, self.embedding.weight.t())
+
+        mean_distance = torch.mean(d)
+        # find closest encodings
+        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+        # min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
+        # [0-1], higher score, higher confidence
+        # min_encoding_scores = torch.exp(-min_encoding_scores/10)
+
+        min_encodings = z.new_zeros(min_encoding_indices.shape[0], self.codebook_size)
+        min_encodings.scatter_(1, min_encoding_indices, 1)
+
+        # get quantized latent vectors
+        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+        # compute loss for embedding
+        loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
+        # preserve gradients
+        z_q = z + (z_q - z).detach()
+
+        # perplexity
+        e_mean = torch.mean(min_encodings, dim=0)
+        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+        # reshape back to match original input shape
+        z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+        return z_q, loss, {
+            "perplexity": perplexity,
+            "min_encodings": min_encodings,
+            "min_encoding_indices": min_encoding_indices,
+            "mean_distance": mean_distance
+            }
+
+    def get_codebook_feat(self, indices, shape):
+        # input indices: batch*token_num -> (batch*token_num)*1
+        # shape: batch, height, width, channel
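+        # Usage sketch (illustrative; the 16x16 latent size and the `vqgan`
+        # variable are assumptions, not fixed by this file):
+        #   idx = quant_stats['min_encoding_indices']               # (b*16*16, 1)
+        #   z_q = vqgan.quantize.get_codebook_feat(idx, shape=(b, 16, 16, emb_dim))
+        #   img = vqgan.generator(z_q)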
+        indices = indices.view(-1, 1)
+        min_encodings = indices.new_zeros(indices.shape[0], self.codebook_size)
+        min_encodings.scatter_(1, indices, 1)
+        # get quantized latent vectors
+        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+        if shape is not None:  # reshape back to match original input shape
+            z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
+
+        return z_q
+
+
+class GumbelQuantizer(nn.Module):
+    def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
+        super().__init__()
+        self.codebook_size = codebook_size  # number of embeddings
+        self.emb_dim = emb_dim  # dimension of embedding
+        self.straight_through = straight_through
+        self.temperature = temp_init
+        self.kl_weight = kl_weight
+        self.proj = nn.Conv2d(num_hiddens, codebook_size, 1)  # projects last encoder layer to quantized logits
+        self.embed = nn.Embedding(codebook_size, emb_dim)
+
+    def forward(self, z):
+        hard = self.straight_through if self.training else True
+
+        logits = self.proj(z)
+
+        soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
+
+        z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
+
+        # + kl divergence to the prior loss
+        qy = F.softmax(logits, dim=1)
+        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
+        min_encoding_indices = soft_one_hot.argmax(dim=1)
+
+        return z_q, diff, {
+            "min_encoding_indices": min_encoding_indices
+        }
+
+
+class Downsample(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+    def forward(self, x):
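+        # Asymmetric (right/bottom) zero padding so the stride-2, padding-0 conv
+        # below exactly halves even spatial sizes (and floors odd ones).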
+        pad = (0, 1, 0, 1)
+        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+        x = self.conv(x)
+        return x
+
+
+class Upsample(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, x):
+        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+        x = self.conv(x)
+
+        return x
+
+
+class ResBlock(nn.Module):
+    def __init__(self, in_channels, out_channels=None):
+        super(ResBlock, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.norm1 = normalize(in_channels)
+        self.conv1 = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
+        self.norm2 = normalize(self.out_channels)
+        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
+        if self.in_channels != self.out_channels:
+            self.conv_out = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, x_in):
+        x = x_in
+        x = self.norm1(x)
+        x = swish(x)
+        x = self.conv1(x)
+        x = self.norm2(x)
+        x = swish(x)
+        x = self.conv2(x)
+        if self.in_channels != self.out_channels:
+            x_in = self.conv_out(x_in)
+
+        return x + x_in
+
+
+class AttnBlock(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = normalize(in_channels)
+        self.q = torch.nn.Conv2d(
+            in_channels,
+            in_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0
+        )
+        self.k = torch.nn.Conv2d(
+            in_channels,
+            in_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0
+        )
+        self.v = torch.nn.Conv2d(
+            in_channels,
+            in_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0
+        )
+        self.proj_out = torch.nn.Conv2d(
+            in_channels,
+            in_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0
+        )
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b, c, h, w = q.shape
+        q = q.reshape(b, c, h*w)
+        q = q.permute(0, 2, 1)
+        k = k.reshape(b, c, h*w)
+        w_ = torch.bmm(q, k)
+        w_ = w_ * (int(c)**(-0.5))
+        w_ = F.softmax(w_, dim=2)
+
+        # attend to values
+        v = v.reshape(b, c, h*w)
+        w_ = w_.permute(0, 2, 1)
+        h_ = torch.bmm(v, w_)
+        h_ = h_.reshape(b, c, h, w)
+
+        h_ = self.proj_out(h_)
+
+        return x+h_
+
+
+class Encoder(nn.Module):
+    def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
+        super().__init__()
+        self.nf = nf
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.attn_resolutions = attn_resolutions
+
+        curr_res = self.resolution
+        in_ch_mult = (1,)+tuple(ch_mult)
+
+        blocks = []
+        # initial convolution
+        blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
+
+        # residual and downsampling blocks, with attention on smaller res (16x16)
+        for i in range(self.num_resolutions):
+            block_in_ch = nf * in_ch_mult[i]
+            block_out_ch = nf * ch_mult[i]
+            for _ in range(self.num_res_blocks):
+                blocks.append(ResBlock(block_in_ch, block_out_ch))
+                block_in_ch = block_out_ch
+                if curr_res in attn_resolutions:
+                    blocks.append(AttnBlock(block_in_ch))
+
+            if i != self.num_resolutions - 1:
+                blocks.append(Downsample(block_in_ch))
+                curr_res = curr_res // 2
+
+        # non-local attention block
+        blocks.append(ResBlock(block_in_ch, block_in_ch))
+        blocks.append(AttnBlock(block_in_ch))
+        blocks.append(ResBlock(block_in_ch, block_in_ch))
+
+        # normalise and convert to latent size
+        blocks.append(normalize(block_in_ch))
+        blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
+        self.blocks = nn.ModuleList(blocks)
+
+    def forward(self, x):
+        for block in self.blocks:
+            x = block(x)
+
+        return x
+
+
+class Generator(nn.Module):
+    def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
+        super().__init__()
+        self.nf = nf
+        self.ch_mult = ch_mult
+        self.num_resolutions = len(self.ch_mult)
+        self.num_res_blocks = res_blocks
+        self.resolution = img_size
+        self.attn_resolutions = attn_resolutions
+        self.in_channels = emb_dim
+        self.out_channels = 3
+        block_in_ch = self.nf * self.ch_mult[-1]
+        curr_res = self.resolution // 2 ** (self.num_resolutions-1)
+
+        blocks = []
+        # initial conv
+        blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
+
+        # non-local attention block
+        blocks.append(ResBlock(block_in_ch, block_in_ch))
+        blocks.append(AttnBlock(block_in_ch))
+        blocks.append(ResBlock(block_in_ch, block_in_ch))
+
+        for i in reversed(range(self.num_resolutions)):
+            block_out_ch = self.nf * self.ch_mult[i]
+
+            for _ in range(self.num_res_blocks):
+                blocks.append(ResBlock(block_in_ch, block_out_ch))
+                block_in_ch = block_out_ch
+
+                if curr_res in self.attn_resolutions:
+                    blocks.append(AttnBlock(block_in_ch))
+
+            if i != 0:
+                blocks.append(Upsample(block_in_ch))
+                curr_res = curr_res * 2
+
+        blocks.append(normalize(block_in_ch))
+        blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
+
+        self.blocks = nn.ModuleList(blocks)
+
+
+    def forward(self, x):
+        for block in self.blocks:
+            x = block(x)
+
+        return x
+
+
+class VQAutoEncoder(nn.Module):
+    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
+                beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
+        super().__init__()
+        self.in_channels = 3
+        self.nf = nf
+        self.n_blocks = res_blocks
+        self.codebook_size = codebook_size
+        self.embed_dim = emb_dim
+        self.ch_mult = ch_mult
+        self.resolution = img_size
+        self.attn_resolutions = attn_resolutions
+        self.quantizer_type = quantizer
+        self.encoder = Encoder(
+            self.in_channels,
+            self.nf,
+            self.embed_dim,
+            self.ch_mult,
+            self.n_blocks,
+            self.resolution,
+            self.attn_resolutions
+        )
+        if self.quantizer_type == "nearest":
+            self.beta = beta  # 0.25
+            self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
+        elif self.quantizer_type == "gumbel":
+            self.gumbel_num_hiddens = emb_dim
+            self.straight_through = gumbel_straight_through
+            self.kl_weight = gumbel_kl_weight
+            self.quantize = GumbelQuantizer(
+                self.codebook_size,
+                self.embed_dim,
+                self.gumbel_num_hiddens,
+                self.straight_through,
+                self.kl_weight
+            )
+        self.generator = Generator(
+            self.nf,
+            self.embed_dim,
+            self.ch_mult,
+            self.n_blocks,
+            self.resolution,
+            self.attn_resolutions
+        )
+
+        if model_path is not None:
+            chkpt = torch.load(model_path, map_location='cpu')
+            if 'params_ema' in chkpt:
+                self.load_state_dict(chkpt['params_ema'])
+            elif 'params' in chkpt:
+                self.load_state_dict(chkpt['params'])
+            else:
+                raise ValueError(f"Unexpected checkpoint format in {model_path}: "
+                                 "expected a 'params_ema' or 'params' key.")
+
+
+    def forward(self, x):
+        x = self.encoder(x)
+        quant, codebook_loss, quant_stats = self.quantize(x)
+        x = self.generator(quant)
+        return x, codebook_loss, quant_stats
+
+
+
+# patch based discriminator
+class VQGANDiscriminator(nn.Module):
+    def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
+        super().__init__()
+
+        layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
+        ndf_mult = 1
+        ndf_mult_prev = 1
+        for n in range(1, n_layers):  # gradually increase the number of filters
+            ndf_mult_prev = ndf_mult
+            ndf_mult = min(2 ** n, 8)
+            layers += [
+                nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
+                nn.BatchNorm2d(ndf * ndf_mult),
+                nn.LeakyReLU(0.2, True)
+            ]
+
+        ndf_mult_prev = ndf_mult
+        ndf_mult = min(2 ** n_layers, 8)
+
+        layers += [
+            nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
+            nn.BatchNorm2d(ndf * ndf_mult),
+            nn.LeakyReLU(0.2, True)
+        ]
+
+        layers += [
+            nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)]  # output 1 channel prediction map
+        self.main = nn.Sequential(*layers)
+
+        if model_path is not None:
+            chkpt = torch.load(model_path, map_location='cpu')
+            if 'params_d' in chkpt:
+                self.load_state_dict(chkpt['params_d'])
+            elif 'params' in chkpt:
+                self.load_state_dict(chkpt['params'])
+            else:
+                raise ValueError(f"Unexpected checkpoint format in {model_path}: "
+                                 "expected a 'params_d' or 'params' key.")
+
+    def forward(self, x):
+        return self.main(x)
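+
+
+if __name__ == '__main__':
+    # Minimal smoke test (illustrative; the hyper-parameters below are small
+    # stand-ins, not the settings of any released checkpoint).
+    ae = VQAutoEncoder(img_size=64, nf=32, ch_mult=[1, 2, 4], res_blocks=1,
+                       attn_resolutions=[16], codebook_size=128, emb_dim=64)
+    x = torch.randn(1, 3, 64, 64)
+    recon, codebook_loss, stats = ae(x)
+    print(recon.shape, codebook_loss.item(), stats['perplexity'].item())
+
+    disc = VQGANDiscriminator(nc=3, ndf=32, n_layers=3)
+    print(disc(recon).shape)  # patch-level real/fake logits, here (1, 1, 6, 6)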