Skip to content

Commit 6f61614

Browse files
authored
Add accelerate API support for Super Resolution example (#1358)
Update super_resolution example to support accelerate API
1 parent 2944a9d commit 6f61614

File tree

4 files changed

+15
-23
lines changed

4 files changed

+15
-23
lines changed

run_python_examples.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ function fx() {
136136
}
137137

138138
function super_resolution() {
139-
uv run main.py --upscale_factor 3 --batchSize 4 --testBatchSize 100 --nEpochs 1 --lr 0.001 --mps || error "super resolution failed"
140-
uv run super_resolve.py --input_image dataset/BSDS300/images/test/16077.jpg --model model_epoch_1.pth --output_filename out.png || error "super resolution upscaling failed"
139+
uv run main.py --upscale_factor 3 --batchSize 4 --testBatchSize 100 --nEpochs 1 --lr 0.001 $ACCEL_FLAG || error "super resolution failed"
140+
uv run super_resolve.py --input_image dataset/BSDS300/images/test/16077.jpg --model model_epoch_1.pth --output_filename out.png $ACCEL_FLAG || error "super resolution upscaling failed"
141141
}
142142

143143
function time_sequence_prediction() {

super_resolution/README.md

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ This example illustrates how to use the efficient sub-pixel convolution layer de
55
```
66
usage: main.py [-h] --upscale_factor UPSCALE_FACTOR [--batchSize BATCHSIZE]
77
[--testBatchSize TESTBATCHSIZE] [--nEpochs NEPOCHS] [--lr LR]
8-
[--cuda] [--threads THREADS] [--seed SEED]
8+
[--accel] [--threads THREADS] [--seed SEED]
99
1010
PyTorch Super Res Example
1111
@@ -16,8 +16,7 @@ optional arguments:
1616
--testBatchSize testing batch size
1717
--nEpochs number of epochs to train for
1818
--lr Learning Rate. Default=0.01
19-
--cuda use cuda
20-
--mps enable GPU on macOS
19+
--accel use accelerator
2120
--threads number of threads for data loader to use Default=4
2221
--seed random seed to use. Default=123
2322
```
@@ -29,11 +28,11 @@ This example trains a super-resolution network on the [BSD300 dataset](https://w
2928
### Train
3029

3130
```bash
32-
python main.py --upscale_factor 3 --batchSize 4 --testBatchSize 100 --nEpochs 30 --lr 0.001
31+
python main.py --upscale_factor 3 --batchSize 4 --testBatchSize 100 --nEpochs 30 --lr 0.001 --accel
3332
```
3433

3534
### Super Resolve
3635

3736
```bash
38-
python super_resolve.py --input_image dataset/BSDS300/images/test/16077.jpg --model model_epoch_30.pth --output_filename out.png
37+
python super_resolve.py --input_image dataset/BSDS300/images/test/16077.jpg --model model_epoch_30.pth --output_filename out.png --accel
3938
```

super_resolution/main.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,18 @@
1616
parser.add_argument('--testBatchSize', type=int, default=10, help='testing batch size')
1717
parser.add_argument('--nEpochs', type=int, default=2, help='number of epochs to train for')
1818
parser.add_argument('--lr', type=float, default=0.01, help='Learning Rate. Default=0.01')
19-
parser.add_argument('--cuda', action='store_true', help='use cuda?')
20-
parser.add_argument('--mps', action='store_true', default=False, help='enables macOS GPU training')
19+
parser.add_argument('--accel', action='store_true', help='Enables acceleration for training, if available')
2120
parser.add_argument('--threads', type=int, default=4, help='number of threads for data loader to use')
2221
parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123')
2322
opt = parser.parse_args()
2423

2524
print(opt)
2625

27-
if opt.cuda and not torch.cuda.is_available():
28-
raise Exception("No GPU found, please run without --cuda")
29-
if not opt.mps and torch.backends.mps.is_available():
30-
raise Exception("Found mps device, please run with --mps to enable macOS GPU")
31-
3226
torch.manual_seed(opt.seed)
33-
use_mps = opt.mps and torch.backends.mps.is_available()
3427

35-
if opt.cuda:
36-
device = torch.device("cuda")
37-
elif use_mps:
38-
device = torch.device("mps")
28+
29+
if opt.accel and torch.accelerator.is_available():
30+
device = torch.accelerator.current_accelerator()
3931
else:
4032
device = torch.device("cpu")
4133

super_resolution/super_resolve.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
parser.add_argument('--input_image', type=str, required=True, help='input image to use')
1313
parser.add_argument('--model', type=str, required=True, help='model file to use')
1414
parser.add_argument('--output_filename', type=str, help='where to save the output image')
15-
parser.add_argument('--cuda', action='store_true', help='use cuda')
15+
parser.add_argument('--accel', action='store_true', help='Enables acceleration device, if available')
1616
opt = parser.parse_args()
1717

1818
print(opt)
@@ -32,9 +32,10 @@
3232
img_to_tensor = ToTensor()
3333
input = img_to_tensor(y).view(1, -1, y.size[1], y.size[0])
3434

35-
if opt.cuda:
36-
model = model.cuda()
37-
input = input.cuda()
35+
if opt.accel:
36+
device = torch.accelerator.current_accelerator()
37+
model = model.to(device)
38+
input = input.to(device)
3839

3940
out = model(input)
4041
out = out.cpu()

0 commit comments

Comments
 (0)