
Commit a69b771

Authored Mar 21, 2025
Ume datamodule - allow downloads of HF datasets (#50)
* config
* train script
* train script
* train.sh
* remove train.sh
* download
* download=true
* env vars
* slurm
* data dir
* data dir
* data dir
1 parent 969afe2 commit a69b771

File tree: 6 files changed (+74, -8 lines)

 

.gitignore

+2 lines

@@ -36,3 +36,5 @@ dev
 outputs
 wandb
 lightning_logs
+
+.env

last.ckpt

+3 lines

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3305e5c9d1885db6c97e86372aeaeb02d7fec4d465bef5779a0d74bd0a7417a9
+size 137159028

slurm/README.md

+25 lines

@@ -0,0 +1,25 @@
+# Running LBSTER Jobs with SLURM
+
+This guide explains how to run a `lobster` training job using SLURM on a GPU-enabled system. It also describes which environment variables need to be exported for the job to run properly.
+
+# SLURM Job Script
+The provided example job script `scripts/train_ume.sh` is configured for training the `Ume` model on a GPU-enabled SLURM cluster.
+
+You will need to set specific environment variables to run the job. These will be read by the `Ume` hydra configuration file, which is located at `src/lobster/hydra_config/experiment/train_ume.yaml`.
+
+Variables:
+
+* `LOBSTER_DATA_DIR`: Path to the directory containing your training data. Datasets will be downloaded and cached to this directory (if `data.download` is set to `True` in the hydra configuration file).
+* `LOBSTER_RUNS_DIR`: Path to the directory where training results (model checkpoints, logs, etc.) will be stored.
+* `LOBSTER_USER`: The user entity for the logger (usually your wandb username).
+* `WANDB_BASE_URL`: The base URL for the Weights & Biases service. Optional - only needed if your wandb account is not on the default wandb server.
+
+Example:
+```bash
+export LOBSTER_DATA_DIR="/data/lobster/ume/data"
+export LOBSTER_RUNS_DIR="/data/lobster/ume/runs"
+export LOBSTER_USER=$(whoami)
+export WANDB_BASE_URL=https://your_org.wandb.io/
+```
+
+
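With these variables exported, the job can be submitted with `sbatch` from the repository root. The commit itself does not show the submission step, so the following is a minimal sketch: the `mkdir` is an assumption based on the `#SBATCH -o slurm/logs/%J.out` directive in the script below, and the job ID is illustrative.

```bash
# Assumed workflow, run from the repository root.
# The job script writes its output to slurm/logs/<job id>.out, so make sure
# that directory exists before submitting.
mkdir -p slurm/logs

# Submit the Ume training job; sbatch prints the assigned job ID.
sbatch slurm/scripts/train_ume.sh

# Follow the job output once it starts (replace 123456 with the printed job ID).
tail -f slurm/logs/123456.out
```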

slurm/scripts/train_ume.sh

+26 lines

@@ -0,0 +1,26 @@
+#!/usr/bin/env bash
+
+#SBATCH --partition gpu2
+#SBATCH --nodes 1
+#SBATCH --ntasks-per-node 2
+#SBATCH --gpus-per-node 2
+#SBATCH --cpus-per-task 8
+#SBATCH -o slurm/logs/%J.out
+# srun hostname
+
+nvidia-smi
+
+source .venv/bin/activate
+echo "SLURM_JOB_ID = ${SLURM_JOB_ID}"
+
+export WANDB_INSECURE_DISABLE_SSL=true
+export HYDRA_FULL_ERROR=1
+export PYTHONUNBUFFERED=1
+
+export TOKENIZERS_PARALLELISM=true
+
+srun -u --cpus-per-task 8 --cpu-bind=cores,verbose \
+    lobster_train experiment=train_ume \
+    logger.entity="$(whoami)"
+
+

src/lobster/data/_ume_datamodule.py

+8, -1 lines

@@ -50,7 +50,7 @@ def __post_init__(self):
             supported_splits={Split.TRAIN, Split.TEST},
             train_size=19_400_000,
             test_size=1_000_000,
-            kwargs={"download": False, "keys": ["smiles"]},
+            kwargs={"keys": ["smiles"]},
         ),
         DatasetInfo(
             name="Calm",
@@ -85,6 +85,7 @@ def __init__(
         tokenizer_max_length: int,
         *,
         datasets: None | Sequence[str] = None,
+        download: bool = False,
         root: Path | str | None = None,
         seed: int = 0,
         batch_size: int = 1,
@@ -104,6 +105,10 @@ def __init__(
         datasets : None | Sequence[str], optional
             List of dataset names to use. If None, all supported datasets will be used.
             Example: ["M320M", "Calm", "AMPLIFY", "Pinder"]
+        download: bool, optional
+            If True, will download the datasets first and stream locally.
+            Otherwise, streams directly from Hugging Face.
+            Downloaded datasets are cached in the `root` directory.
         root : Path | str | None, optional
             Root directory where the datasets are stored. If None, the default directory will be used.
         seed : int, optional
@@ -156,6 +161,7 @@ def __init__(
         self._stopping_condition = stopping_condition
         self._sample = sample
         self._weights = weights
+        self._download = download
 
         # Initialize tokenizer transforms for each modality
         tokenizer_instances = {
@@ -195,6 +201,7 @@ def _get_dataset(self, dataset_info: DatasetInfo, split: Split) -> Dataset:
 
         return dataset_class(
             root=self._root,
+            download=self._download,
             transform=transform,
             split=split.value,
             shuffle=(split == Split.TRAIN),
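Because the new `download` flag is also exposed in the hydra experiment config as `data.download` (see the config diff below), it can be toggled per run with a standard hydra command-line override instead of editing the YAML. A rough sketch, reusing the `lobster_train` entry point from the SLURM script above; the override values and the use of `$LOBSTER_DATA_DIR` are illustrative:

```bash
# Stream datasets directly from Hugging Face (no local copy is kept).
lobster_train experiment=train_ume data.download=false

# Download the datasets into $LOBSTER_DATA_DIR first, then stream from the local cache.
lobster_train experiment=train_ume data.download=true data.root="$LOBSTER_DATA_DIR"
```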

src/lobster/hydra_config/experiment/train_ume.yaml

+10, -7 lines

@@ -13,20 +13,20 @@ compile: false
 
 data:
   _target_: lobster.data.UmeLightningDataModule
-  root: ${paths.root_dir}/data
+  root: ${oc.env:LOBSTER_DATA_DIR}
   datasets: ["M320M", "Calm", "AMPLIFY", "Pinder"]
   batch_size: 128
   tokenizer_max_length: ${model.max_length}
   pin_memory: true
   shuffle_buffer_size: 1000
   num_workers: 4
   seed: 0
-  sample: false # if false, uses RoundRobinConcatIterableDataset, else MultiplexedSamplingDataset
-  stopping_condition: min # min or max, used only if sample is false
-  weights: null # used only if sample is true, if null and sample is true, samples with weights based on dataset sizes
+  download: true
+  sample: true
+  weights: null
 
 paths:
-  root_dir: ./runs
+  root_dir: ${oc.env:LOBSTER_RUNS_DIR}
 
 trainer:
   max_steps: 50_000
@@ -39,7 +39,7 @@ trainer:
   devices: auto
 
 model:
-  model_name: UME_mini
+  model_name: UME_small
   vocab_size: 1472
   pad_token_id: 1
   cls_token_id: 0
@@ -59,10 +59,13 @@ model:
 callbacks:
   moleculeace_linear_probe:
     max_length: ${model.max_length}
+    run_every_n_epochs: 1
   calm_linear_probe:
     max_length: ${model.max_length}
+    run_every_n_epochs: 1
 
 logger:
   name: ume_${model.model_name}_${now:%Y-%m-%d_%H-%M-%S}
   project: lobster
-  group: ume-dev-${now:%Y-%m-%d-%H-%M-%S}
+  group: ume-dev-${now:%Y-%m-%d-%H-%M-%S}
+  entity: ${oc.env:LOBSTER_USER}
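Since `data.root`, `paths.root_dir`, and `logger.entity` are now resolved through OmegaConf's `oc.env:` resolver, the corresponding environment variables have to be set before launch even for local, non-SLURM runs (with no default supplied, an unset variable should raise an interpolation error when the config is resolved). A minimal local launch might look like the sketch below; the paths are illustrative, not part of this commit:

```bash
# Environment variables consumed by ${oc.env:...} in train_ume.yaml (illustrative values).
export LOBSTER_DATA_DIR="$HOME/lobster/data"
export LOBSTER_RUNS_DIR="$HOME/lobster/runs"
export LOBSTER_USER=$(whoami)

# Launch the same experiment outside SLURM.
lobster_train experiment=train_ume
```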
