This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

ValueError: Input 0 of layer "generator" is incompatible with the layer #231

Open
loftusa opened this issue May 3, 2022 · 1 comment


loftusa commented May 3, 2022

Hi, I'm getting this error when I try to predict after training:

ValueError: Input 0 of layer "generator" is incompatible with the layer: expected shape=(None, 40, 40, 3), found shape=(None, 64, 64, 3)

I trained on a set of 64x64 images and their 512x512 upscaled versions. I split the original set of 64x64 images into a training set and a validation set, and got this error when I tried to predict on the validation set. I'm not sure why the generator expects a 40x40 input, given that these weights were trained on 64x64 images.
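A separate detail that may be worth checking: the training script sets `scale = 4`, but a 64x64 to 512x512 pair is an 8x upscale (512 / 64 = 8). If ISR's data handler crops HR patches at `scale` times the LR coordinates (my understanding, not verified here), 8x targets with `scale = 4` would produce misaligned training pairs. A minimal stdlib sketch of a sanity check, assuming square or rectangular pairs given as `(width, height)` tuples; `check_pair_scale` and `implied_scale` are hypothetical helpers, not part of ISR:

```python
def check_pair_scale(lr_size, hr_size, scale):
    """Return True if hr_size is exactly lr_size * scale in both dimensions.

    lr_size and hr_size are (width, height) tuples, e.g. from PIL's Image.size.
    Hypothetical helper for illustration; not part of ISR.
    """
    lr_w, lr_h = lr_size
    hr_w, hr_h = hr_size
    return hr_w == lr_w * scale and hr_h == lr_h * scale


def implied_scale(lr_size, hr_size):
    """Return the integer upscaling factor implied by a pair, or None if it is not uniform."""
    lr_w, lr_h = lr_size
    hr_w, hr_h = hr_size
    if hr_w % lr_w or hr_h % lr_h:
        return None  # not an integer multiple in at least one dimension
    sx, sy = hr_w // lr_w, hr_h // lr_h
    return sx if sx == sy else None


# The sizes described in this issue: 64x64 inputs paired with 512x512 targets.
print(implied_scale((64, 64), (512, 512)))        # 8
print(check_pair_scale((64, 64), (512, 512), 4))  # False
print(check_pair_scale((64, 64), (512, 512), 8))  # True
```

Running a check like this over one LR/HR pair before training makes the mismatch between the data and `arch_params['x']` visible early.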

Here is the full code for training / running:

```python
from ISR.train import Trainer
from ISR.models import RRDN, Cut_VGG19, Discriminator
import os
from PIL import Image
import numpy as np

loss_weights = {'generator': 0.0, 'feature_extractor': 0.0833, 'discriminator': 0.01}
losses = {'generator': 'mae', 'feature_extractor': 'mse', 'discriminator': 'binary_crossentropy'}
log_dirs = {'logs': '/workspace/image-super-resolution/logs', 'weights': '/workspace/image-super-resolution/weights'}
learning_rate = {'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 30}
flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}

# model hyperparams
lr_train_patch_size = 40
layers_to_extract = [5, 9]
scale = 4
hr_train_patch_size = lr_train_patch_size * scale
# pretrained_weights_loc = "/workspace/rrdn-C4-D3-G32-G032-T10-x4_epoch299.hdf5"

arch_params = {'C': 4, 'D': 3, 'G': 32, 'G0': 32, 'T': 10, 'x': scale}
rrdn = RRDN(arch_params=arch_params, patch_size=lr_train_patch_size, weights='gans')
f_ext = Cut_VGG19(patch_size=hr_train_patch_size, layers_to_extract=layers_to_extract)
discr = Discriminator(patch_size=hr_train_patch_size, kernel_size=3)

trainer = Trainer(
    generator=rrdn,
    discriminator=discr,
    feature_extractor=f_ext,
    lr_train_dir="/workspace/IR_preprocessed/train/lr_64",
    hr_train_dir="/workspace/IR_preprocessed/train/hr_512",
    loss_weights=loss_weights,
    learning_rate=learning_rate,
    flatness=flatness,
    dataname="IR_dataset",
    log_dirs=log_dirs,
    weights_generator=None,
    weights_discriminator=None,
    n_validation=40,
    lr_valid_dir="/workspace/IR_preprocessed/val/lr_64",
    hr_valid_dir="/workspace/IR_preprocessed/val/hr_512",
)

trainer.train(epochs=3, steps_per_epoch=10, batch_size=16, monitored_metrics={"val_generator_PSNR_Y": "max"})

# run validation
saved_weights = "/workspace/image-super-resolution/weights/rrdn-C4-D3-G32-G032-T10-x4/2022-05-03_1809/rrdn-C4-D3-G32-G032-T10-x4_best-val_generator_PSNR_Y_epoch003.hdf5"
rrdn.model.load_weights(saved_weights)
lr_valid_dir = "/workspace/IR_preprocessed/val/lr_64"
for imgfile in os.listdir(lr_valid_dir):
    if imgfile.endswith(".png"):
        imgfile = os.path.join(lr_valid_dir, imgfile)
        print(f"processing {imgfile}...")
        img = Image.open(imgfile)
        lr_img = np.array(img)
        sr_img = rrdn.predict(lr_img)
```
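The 40 in the error most likely comes from the script itself: `lr_train_patch_size = 40` is passed as `patch_size` to `RRDN`, and a Keras model built on a fixed-size input only accepts that size. My understanding of the ISR source (worth verifying against the installed version) is that `RRDN` builds its input as `(patch_size, patch_size, 3)` and that `patch_size` defaults to `None`, which leaves height and width unconstrained. The sketch below only illustrates the shape rule behind the error message, with `None` acting as a wildcard; `shapes_compatible` is a simplified stand-in written for this note, not Keras's actual implementation:

```python
from typing import Optional, Sequence

Shape = Sequence[Optional[int]]


def shapes_compatible(expected: Shape, found: Shape) -> bool:
    """Simplified version of the rule in the error message:
    same rank, and every dimension matches unless the expected one is None."""
    if len(expected) != len(found):
        return False
    return all(e is None or e == f for e, f in zip(expected, found))


# Generator built with patch_size=40: height and width are fixed at 40.
fixed_input = (None, 40, 40, 3)
# Generator built with patch_size=None: height and width are free.
flexible_input = (None, None, None, 3)

found_64px = (None, 64, 64, 3)

print(shapes_compatible(fixed_input, found_64px))            # False: the ValueError above
print(shapes_compatible(flexible_input, found_64px))         # True
print(shapes_compatible(flexible_input, (None, 64, 64, 4)))  # False: RGBA has 4 channels
```

Under that assumption, one untested workaround is to build a second generator for inference with the same `arch_params` but no `patch_size`, e.g. `rrdn_infer = RRDN(arch_params=arch_params)`, then call `rrdn_infer.model.load_weights(saved_weights)` and `rrdn_infer.predict(lr_img)`. The weights themselves do not depend on the spatial size, since the layers are convolutional, so only the input declaration changes.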


md-rifatkhan commented Oct 19, 2023

I had the same problem and fixed it with this:

```python
if lr_img.shape[2] == 4:
    # If it has four channels, remove the alpha channel (assumed to be the fourth channel)
    lr_img = lr_img[:, :, :3]
```

Full code:

```python
import numpy as np
from PIL import Image
from ISR.models import RRDN

# Load the image
img = Image.open('dd.png')

# Convert the image to a NumPy array of shape (height, width, channels)
lr_img = np.array(img)

# The generator expects 3 color channels: drop the alpha channel of RGBA images
if lr_img.ndim == 3 and lr_img.shape[2] == 4:
    lr_img = lr_img[:, :, :3]

# Load the pre-trained RRDN model
# (alternative: from ISR.models import RDN; rdn = RDN(weights='noise-cancel'))
rrdn = RRDN(weights='gans')

sr_img = rrdn.predict(lr_img)

# Convert the NumPy array back to an Image object
sr_img = Image.fromarray(sr_img)

# Save the super-resolved image
sr_img.save('output.png')
```
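The alpha-channel check above assumes a 3-D array, but a grayscale PNG (PIL mode `L`) loads as a 2-D `(H, W)` array, so `lr_img.shape[2]` would raise an `IndexError` there. A minimal NumPy sketch that normalizes the common layouts to the `(H, W, 3)` input the generator expects, assuming 8-bit images and that simply dropping alpha (rather than compositing it over a background) is acceptable; `to_rgb_array` is a hypothetical helper, and PIL's `img.convert('RGB')` is a simpler alternative when the `Image` object is still at hand:

```python
import numpy as np


def to_rgb_array(arr: np.ndarray) -> np.ndarray:
    """Return an (H, W, 3) array from grayscale, grayscale+alpha, RGB, or RGBA input.

    Alpha is dropped, not composited. Palette images ('P' mode) should be converted
    with PIL first, since their pixel values are palette indices, not colors.
    """
    if arr.ndim == 2:                       # (H, W) grayscale: replicate into 3 channels
        return np.stack([arr] * 3, axis=-1)
    if arr.ndim != 3:
        raise ValueError(f"expected a 2-D or 3-D image array, got shape {arr.shape}")
    channels = arr.shape[2]
    if channels == 1:                       # (H, W, 1) grayscale
        return np.repeat(arr, 3, axis=2)
    if channels == 2:                       # grayscale + alpha ('LA'): keep gray, drop alpha
        return np.repeat(arr[:, :, :1], 3, axis=2)
    if channels == 3:                       # already RGB
        return arr
    if channels == 4:                       # RGBA: drop alpha, as in the fix above
        return arr[:, :, :3]
    raise ValueError(f"unsupported channel count: {channels}")


gray = np.zeros((64, 64), dtype=np.uint8)
rgba = np.zeros((64, 64, 4), dtype=np.uint8)
print(to_rgb_array(gray).shape)  # (64, 64, 3)
print(to_rgb_array(rgba).shape)  # (64, 64, 3)
```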
