RuntimeError: Only last dimension is supported for softmax #50

Open
filipposkat opened this issue Jan 17, 2024 · 8 comments

@filipposkat

Thank you for all your effort; OpenCL is really the only option for old AMD GPUs.
I successfully built your code on Windows by following the instructions. There were some minor issues I would like to note in case you want to update the README accordingly:

  1. "-DCMAKE_INCLUDE_PATH=c:\deps\include\include" should be "-DCMAKE_INCLUDE_PATH=c:\deps\include"
  2. Adding the directory containing sqlite3.dll to PATH did not work for me (file-not-found error), probably due to this.
    I had to add this piece of code to my script (full sketch below):
    os.add_dll_directory(r"C:\deps\lib")

That was a side note; the real issue is that I get an error when trying to train a simple UNet with CrossEntropyLoss:

  File "G:\filip\Documents\Data Science Projects\Thesis\apnea-ppg\trainer.py", line 149, in train_loop
    loss = criterion(outputs, labels.long())
  File "C:\Users\filip\anaconda3\envs\torch_113\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\filip\anaconda3\envs\torch_113\lib\site-packages\torch\nn\modules\loss.py", line 1174, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "C:\Users\filip\anaconda3\envs\torch_113\lib\site-packages\torch\nn\functional.py", line 3026, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: Only last dimension is supported for softmax

where criterion is: nn.CrossEntropyLoss()
Is this normal?

Details:
Number of classes: 5
My batches have this shape: (256, 1, 512)
Output has shape: (256, 5, 512)
Labels have shape: (256, 512)
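
A minimal repro sketch with just these shapes (tensors created on CPU and moved to the device; assuming the backend is already loaded and exposed as privateuseone:0):

    import torch
    import torch.nn as nn

    device = "privateuseone:0"
    outputs = torch.randn(256, 5, 512).to(device)                       # (N, C, L) logits
    labels = torch.randint(0, 5, (256, 512), dtype=torch.long).to(device)
    loss = nn.CrossEntropyLoss()(outputs, labels)  # softmax over dim=1 -> RuntimeError here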

Specs:
OS: Windows 10
PyTorch version: 1.13
Python version: 3.10
GPU: RX 590

@artyom-beilis
Owner

Can you use nll_loss? I implemented the full range for log loss but not for direct softmax (since it is rarely used as is).

@filipposkat
Author

Changed to:

m = nn.LogSoftmax(dim=1)
criterion = nn.NLLLoss()
loss = criterion(m(outputs), labels.long())

and I get:

  File "G:\filip\Documents\Data Science Projects\Thesis\apnea-ppg\trainer.py", line 155, in train_loop
    loss = criterion(m(outputs), labels.long())
  File "C:\Users\filip\anaconda3\envs\torch_113\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\filip\anaconda3\envs\torch_113\lib\site-packages\torch\nn\modules\activation.py", line 1461, in forward
    return F.log_softmax(input, self.dim, _stacklevel=5)
  File "C:\Users\filip\anaconda3\envs\torch_113\lib\site-packages\torch\nn\functional.py", line 1930, in log_softmax
    ret = input.log_softmax(dim)
RuntimeError: Only last dimension is supported for softmax
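
(A possible interim workaround I may try, sketched here and not yet tested on this backend: move the class dimension to the last position so the softmax runs over the last dim, then flatten for NLLLoss.)

    # Workaround sketch: classes on the last dim, then flatten to (N*L, C) for NLLLoss.
    log_probs = nn.functional.log_softmax(outputs.permute(0, 2, 1), dim=-1)   # (N, L, C)
    loss = nn.NLLLoss()(log_probs.reshape(-1, log_probs.shape[-1]),           # (N*L, C)
                        labels.long().reshape(-1))                            # (N*L,)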

@artyom-beilis
Owner

Do you have the latest version?

@artyom-beilis
Owner

I added it around Nov 26.

@filipposkat
Author

Yes, I cloned it yesterday.

@artyom-beilis
Owner

OK, let me see what I can do.

Can you give me a full example of your loss function?

@filipposkat
Author

filipposkat commented Jan 17, 2024

Code for training one epoch:

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import DataLoader

    device = "privateuseone:0"

    # Create Network:
    net = UNet(nclass=5, in_chans=1, max_channels=512, depth=5, layers=2, kernel_size=5, sampling_method="conv_stride")
    net = net.to(device)

    criterion = nn.CrossEntropyLoss()

    optim_kwargs = {"lr": 0.001, "momentum": 0.7}
    optimizer = optim.SGD(net.parameters(), **optim_kwargs)


    for (i, data) in enumerate(train_dataloader):
        inputs, labels = data

        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = net(inputs)

        loss = criterion(outputs, labels.long())

        loss.backward()
        optimizer.step()

Code of the UNet class:

class EncoderBlock(nn.Module):
    # Consists of (Conv1d -> BatchNorm1d -> LeakyReLU(0.2)) x layers, then optional down-sampling (strided conv or max-pooling)
    def __init__(self, in_chans, out_chans, layers=2, kernel_size=3, sampling_factor=2, sampling_method="pooling"):
        super().__init__()
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.layers = layers
        self.kernel_size = kernel_size
        self.sampling_factor = sampling_factor
        self.sampling_method = sampling_method

        self.encoder = nn.ModuleList()

        for _ in range(layers):
            self.encoder.append(nn.Conv1d(in_chans, out_chans, kernel_size=kernel_size, stride=1, padding="same"))
            self.encoder.append(nn.BatchNorm1d(out_chans))
            self.encoder.append(nn.LeakyReLU(0.2))
            in_chans = out_chans

        if sampling_method == "conv_stride" and sampling_factor > 1:
            diff = (kernel_size - sampling_factor)
            if diff % 2 == 0:
                ks = kernel_size
                pad = diff // 2
            else:
                ks = kernel_size + 1
                pad = (ks - sampling_factor) // 2

            self.encoder.append(nn.Sequential(
                nn.Conv1d(out_chans, out_chans, kernel_size=ks,
                          stride=sampling_factor, padding=pad),
                nn.BatchNorm1d(out_chans),
                nn.LeakyReLU(0.2)))
        elif sampling_method == "pooling" and sampling_factor > 1:
            self.encoder.append(nn.MaxPool1d(sampling_factor))

    def forward(self, x):
        for enc in self.encoder:
            x = enc(x)
        return x


class DecoderBlock(nn.Module):
    # Consists of a transposed 1d convolution (up-sampling) -> (Conv1d -> BatchNorm1d -> LeakyReLU(0.2)) x layers
    def __init__(self, in_chans, out_chans, layers=2, kernel_size=3, skip_connection=True, sampling_factor=2,
                 dropout=0.0):
        super().__init__()
        self.skip_connection = skip_connection
        self.layers = layers
        self.kernel_size = kernel_size
        self.padding = "same"
        self.dropout = dropout

        skip_factor = 1 if skip_connection else 2
        self.decoder = nn.ModuleList()

        if kernel_size == sampling_factor:
            pad = 0
            out_pad = 0
        elif (kernel_size - sampling_factor) % 2 == 0:
            pad = (kernel_size - sampling_factor) // 2
            out_pad = 0
        else:
            out_pad = 1
            pad = (kernel_size - sampling_factor + out_pad) // 2

        assert (kernel_size - sampling_factor - 2 * pad + out_pad) == 0
        self.tconv = nn.ConvTranspose1d(in_chans, in_chans // 2, stride=sampling_factor, kernel_size=kernel_size,
                                        padding=pad, output_padding=out_pad)

        self.decoder.append(nn.Conv1d(in_chans // skip_factor, out_chans, kernel_size, 1, padding="same"))
        self.decoder.append(nn.BatchNorm1d(out_chans))
        self.decoder.append(nn.LeakyReLU(0.2))
        for _ in range(layers - 1):
            self.decoder.append(nn.Conv1d(out_chans, out_chans, kernel_size, 1, padding="same"))
            self.decoder.append(nn.BatchNorm1d(out_chans))
            self.decoder.append(nn.LeakyReLU(0.2))

        if dropout > 0.0:
            self.decoder.append(nn.Dropout(p=dropout))

    def forward(self, x, enc_features=None):
        x = self.tconv(x)
        if self.skip_connection:
            x = torch.cat((enc_features, x), dim=1)
        for dec in self.decoder:
            x = dec(x)
        return x


class UNet(nn.Module):
    def __init__(self, nclass=1, in_chans=1, max_channels=512, depth=5, layers=2, kernel_size=3, sampling_factor=2,
                 sampling_method="pooling", skip_connection=True):
        """
        :param nclass:
        :param in_chans:
        :param max_channels:
        :param depth:
        :param layers:
        :param kernel_size:
        :param sampling_factor:
        :param sampling_method: either "pooling" or "conv_stride"
        :param skip_connection:
        """
        super().__init__()
        self.nclass = nclass
        self.in_chans = in_chans
        self.max_channels = max_channels
        self.depth = depth
        self.layers = layers
        self.kernel_size = kernel_size
        self.sampling_factor = sampling_factor
        self.sampling_method = sampling_method
        self.skip_connection = skip_connection

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        first_out_chans = max_channels // (2 ** (depth - 1))
        if first_out_chans % 4 == 0:
            out_chans = first_out_chans
        else:
            out_chans = 4

        # The first block should not do any down-sampling (stride = 1 and no pooling):
        self.encoder.append(EncoderBlock(in_chans, out_chans, layers, kernel_size=kernel_size,
                                         sampling_factor=1, sampling_method="no sampling"))

        for _ in range(depth - 1):
            if out_chans * 2 <= max_channels:
                in_chans, out_chans = out_chans, out_chans * 2
            else:
                in_chans, out_chans = out_chans, out_chans
            self.encoder.append(EncoderBlock(in_chans, out_chans, layers=layers, kernel_size=kernel_size,
                                             sampling_factor=sampling_factor, sampling_method=sampling_method))

        for _ in range(depth - 1):
            if out_chans // 2 >= 4:
                in_chans, out_chans = out_chans, out_chans // 2
            else:
                in_chans, out_chans = out_chans, out_chans
            self.decoder.append(DecoderBlock(in_chans, out_chans, layers=layers, kernel_size=kernel_size,
                                             sampling_factor=sampling_factor))

        # Add a 1x1 convolution to produce final classes
        self.logits = nn.Conv1d(out_chans, nclass, 1, 1)

    def forward(self, x):
        encoded = []
        for enc in self.encoder:
            x = enc(x)
            encoded.append(x)

        # Last encoder output is not used in any skip_connection:
        _ = encoded.pop()

        for dec in self.decoder:
            enc_output = encoded.pop()
            x = dec(x, enc_output)

        # Return the logits
        return self.logits(x)
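
For reference, a quick CPU-only shape sanity check with the constructor arguments from the training snippet (just a sketch; not run on the OpenCL device here):

    net = UNet(nclass=5, in_chans=1, max_channels=512, depth=5, layers=2,
               kernel_size=5, sampling_method="conv_stride")
    net.eval()
    with torch.no_grad():
        out = net(torch.randn(2, 1, 512))   # two dummy windows of length 512
    print(out.shape)                        # expected: torch.Size([2, 5, 512])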

train_dataloader yields two tensors (X, y) where

  • X has shape (256, 1, 512) and contains float32 values
  • y has shape (256, 512) and contains uint8 integers ranging from 0 to 4 (5 classes)
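
A hypothetical synthetic stand-in with the same shapes and dtypes, so the training loop above runs end to end (the dataset size is made up):

    from torch.utils.data import TensorDataset, DataLoader

    X = torch.randn(1024, 1, 512)                            # float32 signals
    y = torch.randint(0, 5, (1024, 512), dtype=torch.uint8)  # labels for 5 classes
    train_dataloader = DataLoader(TensorDataset(X, y), batch_size=256,
                                  shuffle=True, drop_last=True)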

@artyom-beilis
Owner

Got it.
