v0.7.0
Serialization
Serialization has been completely revamped since the last release. Modules, optimizers, and learning rate schedulers now have an associated type that determines the type used to serialize and deserialize their state. The solution is documented in the new architecture doc.
State can be saved with any precision, regardless of the backend in use. Precision conversion is performed during serialization and deserialization, so the model is never held in memory twice at different precisions, keeping memory usage low.
All saved states can be loaded from any backend. The precision of the serialized state must be set correctly, but the element types of the backend can be anything.
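The associated-type pattern can be sketched roughly as follows. All names here are illustrative, not the actual burn API: each component declares a record type describing its serializable state, and the record owns the (de)serialization logic.

```rust
// Conceptual sketch of the associated-type pattern; names here are
// illustrative and do not match the actual burn API.

/// A serializable snapshot of some component's state.
trait Record: Sized {
    fn to_bytes(&self) -> Vec<u8>;
    fn from_bytes(bytes: &[u8]) -> Self;
}

/// Each module picks its own record type via an associated type.
trait Module: Sized {
    type Record: Record;
    fn into_record(self) -> Self::Record;
    fn load_record(self, record: Self::Record) -> Self;
}

/// A toy linear layer with a single weight vector.
struct Linear {
    weight: Vec<f32>,
}

struct LinearRecord {
    weight: Vec<f32>,
}

impl Record for LinearRecord {
    fn to_bytes(&self) -> Vec<u8> {
        // Serialize each f32 in little-endian byte order.
        self.weight.iter().flat_map(|w| w.to_le_bytes()).collect()
    }
    fn from_bytes(bytes: &[u8]) -> Self {
        let weight = bytes
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect();
        LinearRecord { weight }
    }
}

impl Module for Linear {
    type Record = LinearRecord;
    fn into_record(self) -> LinearRecord {
        LinearRecord { weight: self.weight }
    }
    fn load_record(mut self, record: LinearRecord) -> Self {
        self.weight = record.weight;
        self
    }
}
```

Because the record type is separate from the module, a recorder can convert the record to a different precision or format during save/load without the module itself knowing.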
Multiple (de)serialization recorders are provided:
- Default (gzip-compressed, named MessagePack format)
- Bincode
- Compressed gzip bincode
- Pretty JSON
Users can extend the current recorder using any serde implementation.
Multiple precision settings are available:
- Half (f16, i16)
- Full (f32, i32)
- Double (f64, i64)
Users can extend the current settings using any supported number type.
Optimizer
The optimizer API has undergone a complete overhaul. It now supports the new serialization paradigm with a simplified trait definition. The learning rate is now passed as a parameter to the step method, making it easier to integrate the new learning rate scheduler. The learning rate configuration is now a part of the learner API. For more information, please refer to the documentation.
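The shape of this design can be sketched as follows. This is a standalone illustration with hypothetical names, not the actual burn trait: the key point is that the learning rate is an argument to step, so a scheduler can vary it between calls.

```rust
// Illustrative sketch only: a simplified optimizer trait where the
// learning rate is passed to `step`, letting a scheduler change it on
// every call. Names do not match the actual burn API.
trait Optimizer {
    fn step(&mut self, lr: f64, params: &mut [f64], grads: &[f64]);
}

/// Plain SGD: params -= lr * grads.
struct Sgd;

impl Optimizer for Sgd {
    fn step(&mut self, lr: f64, params: &mut [f64], grads: &[f64]) {
        for (p, g) in params.iter_mut().zip(grads) {
            *p -= lr * g;
        }
    }
}
```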
Gradient Clipping
You can now clip gradients by norm or by value. Gradient clipping is integrated with optimizers and can be configured from the optimizer configs (Adam & SGD).
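The two strategies behave differently: clipping by value clamps each component independently, while clipping by norm rescales the whole gradient vector when its L2 norm is too large. A standalone sketch (not the burn API):

```rust
/// Clamp each gradient component into [-max, max].
fn clip_by_value(grads: &mut [f64], max: f64) {
    for g in grads.iter_mut() {
        *g = g.clamp(-max, max);
    }
}

/// Rescale the whole gradient vector if its L2 norm exceeds `max_norm`,
/// preserving its direction.
fn clip_by_norm(grads: &mut [f64], max_norm: f64) {
    let norm = grads.iter().map(|g| g * g).sum::<f64>().sqrt();
    if norm > max_norm {
        let scale = max_norm / norm;
        for g in grads.iter_mut() {
            *g *= scale;
        }
    }
}
```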
Learning Rate Scheduler
A new trait has been introduced for creating learning rate schedulers. This trait follows a pattern similar to the Module and Optimizer APIs, using an associated type that implements the Record trait for state (de)serialization.
The following learning rate schedulers are now available:
- Noam learning rate scheduler
- Constant learning rate scheduler
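For reference, the Noam schedule (from the original Transformer paper) warms up linearly and then decays with the inverse square root of the step count, peaking at the warmup step. A standalone sketch of the formula, independent of the burn API:

```rust
/// Noam learning rate schedule (step is 1-based):
/// lr = d_model^{-0.5} * min(step^{-0.5}, step * warmup^{-1.5})
fn noam_lr(d_model: f64, warmup: f64, step: f64) -> f64 {
    d_model.powf(-0.5) * (step.powf(-0.5)).min(step * warmup.powf(-1.5))
}
```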
Module
The module API has undergone changes. There is no longer a need to wrap modules with the Param struct; only the Tensor struct requires a parameter ID.
All modules can now be created with their configuration and state, eliminating the unnecessary tensor initializations during model deployment for inference.
Convolution
Significant improvements have been made to support all convolution configurations. The stride, dilation, and groups can now be set, with full support for both inference and training.
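As a reminder of how these parameters interact, the spatial output size along one dimension follows the standard convolution arithmetic (the same formula PyTorch documents for Conv layers); a small standalone sketch:

```rust
/// Output length of a convolution along one spatial dimension:
/// floor((input + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1
fn conv_out_size(input: usize, kernel: usize, stride: usize, padding: usize, dilation: usize) -> usize {
    (input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}
```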
Transposed convolutions are available in the backend API but do not currently support the backward pass. Once they are fully supported for both training and inference, they will be exposed as modules.
Pooling
The implementation of the average pooling module is now available.
Transformer
The transformer decoder has been implemented, offering efficient inference and autoregressive decoding through layer norm, position-wise feed-forward blocks, and self-attention and cross-attention caching.
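The idea behind the caching is that during autoregressive decoding, keys and values computed for earlier positions are reused rather than recomputed at every step. A minimal standalone sketch of such a cache (illustrative only, not the burn implementation):

```rust
/// A minimal key/value cache for one attention layer: each decoding
/// step appends the new position's key and value vectors instead of
/// recomputing attention inputs for the whole sequence.
struct KvCache {
    keys: Vec<Vec<f32>>,
    values: Vec<Vec<f32>>,
}

impl KvCache {
    fn new() -> Self {
        KvCache { keys: Vec::new(), values: Vec::new() }
    }

    /// Append one decoding step; attention then reads `keys` and
    /// `values` for all positions cached so far.
    fn push(&mut self, key: Vec<f32>, value: Vec<f32>) {
        self.keys.push(key);
        self.values.push(value);
    }

    fn len(&self) -> usize {
        self.keys.len()
    }
}
```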
Tensor
The developer experience of the Tensor API has been improved, providing more consistent error messages across different backends for common operations. The Tensor struct now implements Display, allowing values, shape, backend information, and other useful details to be displayed in an easily readable format.
New operations
- The flatten operation
- The mask scatter operation
Torch Backend
The Torch backend now supports bf16.
ONNX
The burn-import project now has the capability to generate the required Burn code and model state from an ONNX file, enabling users to easily import pre-trained models into Burn. The code generation utilizes the end user API, allowing the generated model to be fine-tuned and trained using the learner struct.
Please note that not all operations are currently supported, and assistance from the community is highly appreciated. For more details, please refer to the burn-import repository https://github.com/burn-rs/burn/tree/main/burn-import.
Bug Fixes
- Backward pass issue when there is implicit broadcasting in add #181
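The subtlety behind this fix is that when one operand of an add is implicitly broadcast, its gradient must be summed over the broadcast dimensions rather than taken elementwise. A standalone illustration of the correct backward rule:

```rust
/// Forward: c[i] = a[i] + b[0], where b (shape [1]) is broadcast to
/// a's shape. Backward: grad_a matches grad_c elementwise, but grad_b
/// must be the SUM of grad_c over the broadcast dimension, not a
/// single element of it.
fn broadcast_add_backward(grad_c: &[f64]) -> (Vec<f64>, f64) {
    let grad_a = grad_c.to_vec();
    let grad_b = grad_c.iter().sum();
    (grad_a, grad_b)
}
```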
Thanks 🙏
Thanks to all contributors @nathanielsimard, @antimora, @agelas, @bioinformatist, @sunny-g
Thanks to current sponsors: @smallstepman