Skip to content

Commit ce433ed

Browse files
authored
Upgrade pytorch (meteofrance#86)
* upgrade pytorch and python to fit the docker image * upgrade requirements.txt * CI automatically gets the python version * upgrade lint requirements * upgrade python version in pyproject.toml * lint * add exception to bandit
1 parent ede4ba5 commit ce433ed

23 files changed

+71
-64
lines changed

.github/workflows/tests.yml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,15 @@ jobs:
1818

1919
steps:
2020
- uses: actions/checkout@v4
21-
- name: Set up Python 3.10
21+
22+
- name: Get python version from env.yaml
23+
id: python_version
24+
run: echo PYTHON_VERS=$(cat env.yaml | grep python= | sed 's/.*python=\([0-9]\.[0-9]*\.[0-9]*\).*/\1/') >> $GITHUB_OUTPUT
25+
26+
- name: Set up Python ${{ steps.python_version.outputs.PYTHON_VERS }}
2227
uses: actions/setup-python@v3
2328
with:
24-
python-version: "3.10"
29+
python-version: "${{ steps.python_version.outputs.PYTHON_VERS }}"
2530

2631
- name: Get pytorch version from env.yaml
2732
id: pytorch_version

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ ARG DOCKER_REGISTRY=docker.io
22
ARG TORCH_VERS=2.2.2
33
ARG CUDA_VERS=12.1
44

5-
FROM ${DOCKER_REGISTRY}/pytorch/pytorch:${TORCH_VERS}-cuda${CUDA_VERS}-cudnn8-devel
5+
FROM ${DOCKER_REGISTRY}/pytorch/pytorch:${TORCH_VERS}-cuda${CUDA_VERS}-cudnn9-devel
66

77
ARG INJECT_MF_CERT
88

env.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ channels:
33
- pytorch
44
- nvidia
55
dependencies:
6-
- python=3.10.13 # The reference python version for the project
6+
- python=3.11.9 # The reference python version for the project
77
- pytorch-cuda=12.1 # The reference cuda version for the project
8-
- pytorch==2.2.2 # The reference pytorch version for the project
8+
- pytorch==2.4.1 # The reference pytorch version for the project
99
- pip
1010
- pip:
1111
- -r requirements.txt

lint.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ pwd
77
flake8 --ignore E203,W503 --max-line-length 120 $1
88
isort --profile black -c $1
99
black --check $1
10-
bandit -ll -r --skip B402,B321 $1
10+
bandit -ll -r --skip B402,B321,B614 $1

py4cast/datasets/base.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Base classes defining our software components
33
and their interfaces
44
"""
5+
56
import warnings
67
from abc import ABC, abstractclassmethod, abstractmethod, abstractproperty
78
from copy import deepcopy
@@ -298,9 +299,11 @@ def index_select_dim(
298299
return NamedTensor(
299300
self.index_select_dim(dim_name, indices, bare_tensor=True),
300301
self.names,
301-
self.feature_names
302-
if dim_name != self.feature_dim_name
303-
else [self.feature_names[i] for i in indices],
302+
(
303+
self.feature_names
304+
if dim_name != self.feature_dim_name
305+
else [self.feature_names[i] for i in indices]
306+
),
304307
feature_dim_name=self.feature_dim_name,
305308
)
306309

@@ -614,7 +617,9 @@ class DatasetInfo:
614617
units: Dict[str, str] # d[shortname] = unit (str)
615618
weather_dim: int
616619
forcing_dim: int
617-
step_duration: float # Duration (in hour) of one step in the dataset. 0.25 means 15 minutes.
620+
step_duration: (
621+
float # Duration (in hour) of one step in the dataset. 0.25 means 15 minutes.
622+
)
618623
statics: Statics # A lot of static variable
619624
stats: Stats
620625
diff_stats: Stats

py4cast/datasets/dummy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
Can be used as a starting point to implement your own dataset.
44
inputs, outputs and forcing are filled with random tensors.
55
"""
6+
67
from dataclasses import dataclass
78
from functools import cached_property
89
from io import BytesIO

py4cast/datasets/poesy.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,9 @@ class PoesySettings:
241241
term: dict
242242
num_input_steps: int # = 2 # Number of input timesteps
243243
num_output_steps: int # = 1 # Number of output timesteps (= 0 for inference)
244-
num_inference_pred_steps: int = 0 # 0 in training config ; else used to provide future information about forcings
244+
num_inference_pred_steps: int = (
245+
0 # 0 in training config ; else used to provide future information about forcings
246+
)
245247
standardize: bool = False
246248
members: Tuple[int] = (0,)
247249

py4cast/datasets/smeagol.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,9 @@ class SmeagolSettings:
251251
term: dict # Terms used in this configuration. Should be present in nc files.
252252
num_input_steps: int # = 2 # Number of input timesteps
253253
num_output_steps: int # = 1 # Number of output timesteps (= 0 for inference)
254-
num_inference_pred_steps: int = 0 # 0 in training config ; else used to provide future information about forcings
254+
num_inference_pred_steps: int = (
255+
0 # 0 in training config ; else used to provide future information about forcings
256+
)
255257
standardize: bool = True # Do we need to standardize our data ?
256258
members: Tuple[int] = (
257259
0,

py4cast/ideas/minimal_leak.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
of Torch tensor on CPU.
55
Using numpy seems to work fine.
66
"""
7+
78
import gc
89
import json
910
import os

py4cast/ideas/recursive_dict_register.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55

66
class RegisterDictMixin:
7-
87
"""
98
Register dictionnaries.
109
Enable to recursively register dictionnary (or other object with a getitem method).

py4cast/lightning.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -477,9 +477,11 @@ def _common_step(
477477

478478
if scale_y:
479479
step_diff_std, step_diff_mean = self._step_diffs(
480-
self.output_feature_names
481-
if inference
482-
else batch.outputs.feature_names,
480+
(
481+
self.output_feature_names
482+
if inference
483+
else batch.outputs.feature_names
484+
),
483485
prev_states.device,
484486
)
485487

py4cast/models/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Abstract Base Class for all models
33
Contains also a few functionnality used in various model.
44
"""
5+
56
from abc import ABC, abstractproperty
67
from typing import Tuple
78

py4cast/models/nlam/create_mesh.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def mk_2d_graph(xy, nx, ny):
107107

108108
# turn into directed graph
109109
dg = networkx.DiGraph(g)
110-
for (u, v) in g.edges():
110+
for u, v in g.edges():
111111
d = np.sqrt(np.sum((g.nodes[u]["pos"] - g.nodes[v]["pos"]) ** 2))
112112
dg.edges[u, v]["len"] = d
113113
dg.edges[u, v]["vdiff"] = g.nodes[u]["pos"] - g.nodes[v]["pos"]

py4cast/models/vision/conv.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Convolutional neural network models
33
for pn-ia.
44
"""
5+
56
from collections import OrderedDict
67
from dataclasses import dataclass
78
from functools import reduce

py4cast/models/vision/transformers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
and adapted to our needs (upsampler + extra settings)
55
"""
66

7-
87
from dataclasses import dataclass
98
from functools import partial
109
from math import sqrt
@@ -195,15 +194,15 @@ def forward(self, x, return_layer_outputs=False):
195194

196195
layer_outputs = []
197196
i = 0
198-
for (get_overlap_patches, overlap_embed, layers) in self.stages:
197+
for get_overlap_patches, overlap_embed, layers in self.stages:
199198
x = get_overlap_patches(x)
200199

201200
num_patches = x.shape[-1]
202201
ratio = int(sqrt((h * w) / num_patches))
203202
x = rearrange(x, "b c (h w) -> b c h w", h=h // ratio)
204203

205204
x = overlap_embed(x)
206-
for (attn, ff) in layers:
205+
for attn, ff in layers:
207206
x = attn(x) + x
208207
x = ff(x) + x
209208

py4cast/models/vision/unetrpp.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -676,9 +676,7 @@ def __init__(
676676
self.hidden_size = settings.hidden_size
677677
self.spatial_dims = settings.spatial_dims
678678
# Number of pixels after stem layer
679-
no_pixels = (input_shape[0] * input_shape[1]) // (
680-
settings.downsampling_rate**2
681-
)
679+
no_pixels = (input_shape[0] * input_shape[1]) // (settings.downsampling_rate**2)
682680
encoder_input_size = [
683681
no_pixels,
684682
no_pixels // 4,

py4cast_plugin_example.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
A simple plugin example for py4cast model with a Identity model.
33
"""
44

5-
65
from dataclasses import dataclass
76

87
import torch

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ authors = [
1010
]
1111
description = "Library to train a variety of Neural Network architectures on various weather forecasting datasets."
1212
readme = "README.md"
13-
requires-python = ">=3.10"
13+
requires-python = ">=3.11"
1414
classifiers = [
1515
"Programming Language :: Python :: 3",
1616
"License :: OSI Approved :: Apache 2",

requirements.txt

Lines changed: 24 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,30 @@
1-
numpy>=1.24.2
2-
wandb>=0.13.10
3-
matplotlib>=3.7.0
4-
scipy>=1.10.0
5-
pytorch-lightning>=2.1.2
6-
lightning==2.2.2
7-
shapely>=2.0.1
8-
networkx>=3.0
9-
Cartopy>=0.22.0
10-
pyproj>=3.4.1
11-
tueplots>=0.0.8
12-
plotly>=5.15.0
13-
tornado>=6.3.3
14-
cython>=3
15-
cfgrib==0.9.14.0
16-
dataclasses-json==0.6.4
17-
xarray==2024.7.0
1+
numpy==1.26.4
2+
matplotlib==3.9.2
3+
scipy==1.14.1
4+
lightning==2.4.0
5+
networkx==3.3
6+
Cartopy==0.24.1
7+
tueplots==0.0.17
8+
cfgrib==0.9.14.1
9+
dataclasses-json==0.6.7
10+
xarray==2024.9.0
1811
argparse-dataclass==2.0.0
19-
tensorboard==2.16.1
20-
typer==0.9.0
21-
netCDF4==1.6.5
12+
tensorboard==2.17.1
13+
typer==0.12.5
14+
netCDF4==1.7.1.post2
2215
tensorboard-plugin-profile==2.17.0
23-
torch-tb-profiler==0.4.1
24-
einops==0.7.0
16+
torch-tb-profiler==0.4.3
17+
einops==0.8.0
2518
torchinfo==1.8.0
2619
tabulate==0.9.0
27-
pytest==8.1.1
28-
coverage==7.6.1
29-
onnx==1.16.1
30-
onnxruntime==1.18.1
31-
onnxruntime-gpu==1.18.1
32-
onnxscript==0.1.0.dev20240905
33-
monai==1.3.1
20+
pytest==8.3.3
21+
coverage==7.6.3
22+
onnx==1.17.0
23+
onnxruntime-gpu==1.19.2
24+
onnxscript==0.1.0.dev20241018
25+
monai==1.4.0
3426
gif==23.3.0
3527
scikit-image==0.24.0
36-
typer==0.9.0
37-
transformers==4.44.2
38-
tensorflow==2.16.1
39-
py-spy
40-
torch-geometric==2.3.1
28+
typer==0.12.5
29+
transformers==4.45.2
30+
torch-geometric==2.6.1

requirements_lint.txt

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
pylint
2-
bandit==1.7.4
3-
flake8==4.0.1
4-
black==22.3.0
5-
isort==5.11.2
1+
bandit==1.7.10
2+
flake8==7.1.1
3+
black==24.10.0
4+
isort==5.13.2

tests/test_datasets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Unit tests for datasets and NamedTensor.
33
"""
4+
45
import datetime
56

67
import numpy as np

tests/test_io.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Unit tests for datasets and NamedTensor.
33
"""
4+
45
from dataclasses import dataclass
56
from functools import cached_property
67

tests/test_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
3. onnx exported
66
4. onnx loaded and used for inference
77
"""
8+
89
import tempfile
910
from dataclasses import dataclass
1011
from pathlib import Path

0 commit comments

Comments
 (0)