This project provides a flexible template for PyTorch-based machine learning experiments. It includes configuration management, logging with Weights & Biases (wandb), hyperparameter optimization with Optuna, and a modular structure for easy customization and experimentation.
- `config.py`: Defines the `RunConfig` and `OptimizeConfig` classes for managing experiment configurations and optimization settings.
- `main.py`: The entry point of the project, handling command-line arguments and experiment execution.
- `model.py`: Contains the model architecture (currently an MLP).
- `util.py`: Utility functions for data loading, device selection, training, and analysis.
- `configs/run_template.yaml`: Template for run configuration.
- `configs/optimize_template.yaml`: Template for optimization configuration.
- `analyze.py`: Script for analyzing completed runs and optimizations, utilizing functions from `util.py`.
- Clone the repository:

  ```sh
  git clone https://github.com/yourusername/pytorch_template.git
  cd pytorch_template
  ```

- Install the required packages:

  ```sh
  # Use pip
  pip install torch wandb survey polars numpy optuna matplotlib scienceplots

  # Or use uv to sync requirements.txt (recommended)
  uv pip sync requirements.txt

  # Or use uv (fresh install)
  uv pip install -U torch wandb survey polars numpy optuna matplotlib scienceplots
  ```
- (Optional) Set up a Weights & Biases account for experiment tracking.
- Configure your experiment by modifying `configs/run_template.yaml` or creating a new YAML file based on it.
- (Optional) Configure hyperparameter optimization by modifying `configs/optimize_template.yaml` or creating a new YAML file based on it.
- Run the experiment:

  ```sh
  python main.py --run_config path/to/run_config.yaml [--optimize_config path/to/optimize_config.yaml]
  ```

  If `--optimize_config` is provided, the script will perform hyperparameter optimization using Optuna.
- Analyze the results:

  ```sh
  python analyze.py
  ```
Key fields in the run configuration (`configs/run_template.yaml`):

- `project`: Project name for wandb logging
- `device`: Device to run on (e.g., 'cpu', 'cuda:0')
- `net`: Model class to use
- `optimizer`: Optimizer class
- `scheduler`: Learning rate scheduler class
- `epochs`: Number of training epochs
- `batch_size`: Batch size for training
- `seeds`: List of random seeds for multiple runs
- `net_config`: Model-specific configuration
- `optimizer_config`: Optimizer-specific configuration
- `scheduler_config`: Scheduler-specific configuration
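Putting these fields together, a run configuration might look like the sketch below. The concrete values, the key names under `net_config`, and the dotted class-path format are assumptions for illustration; `configs/run_template.yaml` defines the actual schema.

```yaml
# Hypothetical run configuration -- values and formats are illustrative only.
project: pytorch_template
device: cuda:0
net: model.MLP
optimizer: torch.optim.AdamW
scheduler: torch.optim.lr_scheduler.CosineAnnealingLR
epochs: 50
batch_size: 256
seeds: [42, 43, 44]          # one full run per seed
net_config:
  nodes: 128
  layers: 4
optimizer_config:
  lr: 1.0e-3
scheduler_config:
  T_max: 50
  eta_min: 1.0e-5
```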
Key fields in the optimization configuration (`configs/optimize_template.yaml`):

- `study_name`: Name of the optimization study
- `trials`: Number of optimization trials
- `seed`: Random seed for optimization
- `metric`: Metric to optimize
- `direction`: Direction of optimization ('minimize' or 'maximize')
- `sampler`: Optuna sampler configuration
- `pruner`: (Optional) Pruner configuration
- `search_space`: Definition of the hyperparameter search space
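For illustration, a hypothetical optimization configuration could look like this. The `sampler` and `search_space` schemas shown here are assumptions; see `configs/optimize_template.yaml` for the real format.

```yaml
# Hypothetical optimization configuration -- schema details are assumptions.
study_name: mlp_cosine_tuning
trials: 50
seed: 42
metric: val_loss
direction: minimize
sampler:
  name: optuna.samplers.TPESampler
search_space:
  net_config.nodes:
    type: categorical
    choices: [64, 128, 256]
  optimizer_config.lr:
    type: float
    low: 1.0e-4
    high: 1.0e-2
    log: true
```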
- Custom model: Modify or add models in `model.py`. Models should accept an `hparams` argument as a dictionary, with keys matching the `net_config` parameters in the run configuration YAML file.
- Custom data: Modify the `load_data` function in `util.py`. The current example uses cosine regression. The `load_data` function should return train and validation datasets compatible with PyTorch's `DataLoader`.
- Custom training: Customize the `Trainer` class in `util.py` by modifying the `step`, `train_epoch`, `val_epoch`, and `train` methods to suit your task. Ensure that `train` returns `val_loss` or a custom metric for proper hyperparameter optimization.
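To make the first two customization points concrete, here is a minimal sketch of a model that reads its architecture from an `hparams` dict, plus a `load_data`-style function for cosine regression. The key names (`nodes`, `layers`) and the exact function signature are assumptions, not the template's actual interface.

```python
import math

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

class MLP(nn.Module):
    """Sketch of an MLP configured via an `hparams` dict.

    The keys `nodes` and `layers` are assumed to match `net_config`
    in the run YAML; the real template may use different names.
    """
    def __init__(self, hparams):
        super().__init__()
        nodes = hparams["nodes"]
        layers = hparams["layers"]
        blocks = [nn.Linear(1, nodes), nn.GELU()]
        for _ in range(layers - 1):
            blocks += [nn.Linear(nodes, nodes), nn.GELU()]
        blocks.append(nn.Linear(nodes, 1))
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        return self.net(x)

def load_data(n_train=1000, n_val=200):
    """Return (train, val) datasets for cosine regression,
    compatible with PyTorch's DataLoader."""
    def make(n):
        x = torch.linspace(0.0, 2.0 * math.pi, n).unsqueeze(1)
        y = torch.cos(x)
        return TensorDataset(x, y)
    return make(n_train), make(n_val)
```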
- Configurable experiments using YAML files
- Integration with Weights & Biases for experiment tracking
- Hyperparameter optimization using Optuna
- Support for multiple random seeds
- Flexible model architecture (currently MLP)
- Device selection (CPU/CUDA)
- Learning rate scheduling
- Analysis tools for completed runs and optimizations
The `analyze.py` script utilizes functions from `util.py` to analyze completed runs and optimizations. Key functions include:

- `select_group`: Select a run group for analysis
- `select_seed`: Select a specific seed from a run group
- `select_device`: Choose a device for analysis
- `load_model`: Load a trained model and its configuration
- `load_study`: Load an Optuna study
- `load_best_model`: Load the best model from an optimization study

These functions are defined in `util.py` and used within `analyze.py`.
To use the analysis tools:

- Run the `analyze.py` script:

  ```sh
  python analyze.py
  ```

- Follow the prompts to select the project, run group, and seed (if applicable).
- The script will load the selected model and perform basic analysis, such as calculating the validation loss.
- You can extend the `main()` function in `analyze.py` to add custom analysis as needed, utilizing the utility functions from `util.py`.
Contributions are welcome! Please feel free to submit a Pull Request.
This project is provided as a template and is intended to be freely used, modified, and distributed. Users of this template are encouraged to choose a license that best suits their specific project needs.
For the template itself:
- You are free to use, modify, and distribute this template.
- No attribution is required, although it is appreciated.
- The template is provided "as is", without warranty of any kind.
When using this template for your own project, please remember to:
- Remove this license section or replace it with your chosen license.
- Ensure all dependencies and libraries used in your project comply with their respective licenses.
For more information on choosing a license, visit choosealicense.com.
PFL (Predicted Final Loss) Pruner
The PFL pruner is a custom pruner that speeds up hyperparameter search by early-stopping unpromising trials. It maintains the top-k trials by validation loss and prunes a trial when its predicted final loss is worse than the worst saved PFL.
- Maintains top k trials based on validation loss
- Predicts final loss using loss history
- Supports multiple random seeds
- Compatible with Optuna's pruning interface
In your `optimize_template.yaml`, configure the pruner under the `pruner` section:
```yaml
pruner:
  name: pruner.PFLPruner
  kwargs:
    n_startup_trials: 10  # Number of trials to run before pruning starts
    n_warmup_epochs: 10   # Number of epochs to run before pruning can occur
    top_k: 10             # Number of best trials to maintain
    target_epoch: 50      # Target epoch for final loss prediction
```
- `n_startup_trials`: Number of trials to run before pruning starts
- `n_warmup_epochs`: Number of epochs to wait before pruning can occur within each trial
- `top_k`: Number of best trials to maintain for comparison
- `target_epoch`: Target epoch number used for final loss prediction
- For the first `n_startup_trials` trials, all trials run without pruning
- Within each trial, no pruning occurs during the first `n_warmup_epochs` epochs
- After warmup:
  - The pruner maintains a list of the top-k trials based on validation loss
  - For each trial, it predicts the final loss using the loss history
  - If a trial's predicted final loss is worse than that of all saved trials, it is pruned
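The pruning decision above can be sketched in a few lines. The actual `PFLPruner` plugs into Optuna's pruning interface; the log-linear extrapolation used here to predict the final loss is an assumption about how such a prediction could work, not necessarily the template's exact method.

```python
import numpy as np

def predict_final_loss(loss_history, target_epoch):
    """Predict the loss at `target_epoch` by fitting a line to the
    log of the validation-loss history and extrapolating.
    (Assumed prediction model -- the real pruner may differ.)"""
    epochs = np.arange(1, len(loss_history) + 1)
    log_loss = np.log(np.asarray(loss_history, dtype=float))
    slope, intercept = np.polyfit(epochs, log_loss, 1)
    return float(np.exp(slope * target_epoch + intercept))

def should_prune(loss_history, target_epoch, saved_pfls, n_warmup_epochs):
    """Hypothetical per-epoch decision: prune only after warmup, and only
    if the predicted final loss is worse than every saved trial's PFL."""
    if len(loss_history) < n_warmup_epochs:
        return False  # still inside this trial's warmup window
    pfl = predict_final_loss(loss_history, target_epoch)
    return bool(saved_pfls) and pfl > max(saved_pfls)
```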