Refactor/simplify training module #221


Open
dimkart opened this issue Feb 25, 2025 · 0 comments
Labels: enhancement (New feature or request), up-for-grabs (Available to external contributors)



dimkart commented Feb 25, 2025

Introduction

We would like to refactor the training module to make it simpler and more intuitive to use. This can be done by focusing on the concept of unified training backends, rather than on individual trainers and models or on classical versus quantum models; these distinctions create unnecessary confusion.

Based on the current status of the package, we can identify the following training backends:

  • PyTorch
  • NumPy (with JAX as an option)
  • Tket
  • PennyLane

The classical/quantum distinction can be boiled down to the desired output type: tensor networks (TNs) or quantum circuits. In the revised interface this can be addressed by a dedicated flag (e.g. tn_output, "whether outputs should be TNs instead of circuits"), which will be False by default. This puts the necessary focus on the quantum models, and reduces the current "classical models" to a special optional case, as it should be.
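
To make the backend selection explicit, the backends above could be captured in a simple enumeration. The following is only a sketch: the name TrainingBackend and its members are assumptions, not existing lambeq API.

from enum import Enum

class TrainingBackend(Enum):
    """Hypothetical enumeration of the proposed training backends."""
    PYTORCH = 'pytorch'
    NUMPY = 'numpy'        # optionally JIT-compiled with JAX
    TKET = 'tket'
    PENNYLANE = 'pennylane'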

Unified interface

Instead of completely rewriting the model and trainer hierarchy in lambeq, we can introduce a new structure that combines models and trainers in an intuitive way. The new interface will be based on a single unified structure, Pipeline, which takes the required backend as a parameter. The call may have the following form:

pipeline = Pipeline(
    backend: TrainingBackend,        # one of the backends listed above
    backend_params: dict[str, Any],  # parameters specific to the selected backend
    optimiser: Optimiser,
    optim_params: dict[str, Any],
    epochs: int,
    loss: LossFunction,
    eval_functions: list[Callable],
    tn_output: bool,                 # whether outputs should be TNs instead of circuits;
                                     # raise NotImplementedError for backends that do
                                     # not support TNs
    ...                              # other training parameters (as in the current trainers)
)

Outline of use

The high-level training process is given below.

from lambeq.training import Pipeline, Dataset

pipeline = Pipeline(...)
train_data, dev_data = Dataset(...), Dataset(...)
pipeline.fit(train_data, dev_data, ...)
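
For concreteness, a filled-in version of this outline could look as follows. SPSAOptimizer and Dataset are existing lambeq components, but all parameter values, the accuracy function and the data variables are hypothetical placeholders:

import numpy as np

from lambeq import SPSAOptimizer
from lambeq.training import Dataset, Pipeline

pipeline = Pipeline(
    backend=TrainingBackend.TKET,
    backend_params={'shots': 8192},  # hypothetical backend option
    optimiser=SPSAOptimizer,
    optim_params={'a': 0.05, 'c': 0.06, 'A': 10},
    epochs=100,
    loss=lambda y_hat, y: -np.sum(y * np.log(y_hat)) / len(y),  # cross-entropy
    eval_functions=[accuracy],       # placeholder evaluation metric
    tn_output=False,                 # default: outputs are circuits
)

train_data = Dataset(train_circuits, train_labels)  # placeholder data
dev_data = Dataset(dev_circuits, dev_labels)
pipeline.fit(train_data, dev_data)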

Approach

Internally, the Pipeline object can keep separate attributes for lambeq Models and Trainers, whose use will be opaque to the user. This way we keep backwards compatibility, and give the user the option to continue using the old interface when more flexibility is required.

Depending on the selected backend and its tn_output parameter, Pipeline.model and Pipeline.trainer will be created and assigned automatically from existing lambeq models and trainers, so that, for example, Pipeline.fit() can simply call Pipeline.trainer.fit() in the background, as sketched below.
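
A minimal sketch of this dispatch, assuming the PyTorch backend maps onto the existing PytorchModel/PytorchTrainer pair; the Pipeline class, the TrainingBackend enum and the exact backend-to-model mapping are illustrative, not current lambeq API:

from typing import Any

from lambeq import PytorchModel, PytorchTrainer  # existing lambeq classes

class Pipeline:
    """Illustrative wrapper around existing lambeq models and trainers."""

    def __init__(self, backend: TrainingBackend, tn_output: bool = False,
                 **trainer_params: Any) -> None:
        if backend is TrainingBackend.PYTORCH:
            if not tn_output:
                # Assumption: this backend evaluates tensor networks only.
                raise NotImplementedError(
                    'the PyTorch backend does not output circuits')
            self.model = PytorchModel()
            self.trainer = PytorchTrainer(model=self.model, **trainer_params)
        # ... analogous branches for the NumPy, Tket and PennyLane backends
        else:
            raise ValueError(f'unknown backend: {backend!r}')

    def fit(self, train_data, dev_data=None, **kwargs: Any):
        # Opaque delegation: the existing trainer does the actual work.
        return self.trainer.fit(train_data, dev_data, **kwargs)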
