Skip to content

Commit

Permalink
Merge pull request #1223 from sebffischer/feat/ignite
Browse files Browse the repository at this point in the history
Feat/ignite
  • Loading branch information
dfalbel authored Jan 14, 2025
2 parents e93430f + 5629cd7 commit 6feaf60
Show file tree
Hide file tree
Showing 42 changed files with 4,090 additions and 307 deletions.
15 changes: 8 additions & 7 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,14 @@ Collate:
'gen-namespace.R'
'generator.R'
'help.R'
'utils-data.R'
'optim.R'
'optim-adam.R'
'optim-sgd.R'
'optim-rmsprop.R'
'optim-adagrad.R'
'optim-adamw.R'
'ignite.R'
'indexing.R'
'install.R'
'ivalue.R'
Expand All @@ -107,7 +115,6 @@ Collate:
'layout.R'
'linalg.R'
'memory_format.R'
'utils-data.R'
'nn.R'
'nn-activation.R'
'nn-batchnorm.R'
Expand Down Expand Up @@ -145,17 +152,11 @@ Collate:
'nnf-upsampling.R'
'nnf-vision.R'
'operators.R'
'optim.R'
'optim-adadelta.R'
'optim-adagrad.R'
'optim-adam.R'
'optim-adamw.R'
'optim-asgd.R'
'optim-lbfgs.R'
'optim-lr_scheduler.R'
'optim-rmsprop.R'
'optim-rprop.R'
'optim-sgd.R'
'positron.R'
'package.R'
'qscheme.R'
Expand Down
8 changes: 8 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ S3method(torch_save,nn_module)
S3method(torch_save,torch_tensor)
S3method(trunc,torch_tensor)
export("%>%")
export(OptimizerIgnite)
export(as_array)
export(as_iterator)
export(autograd_backward)
Expand Down Expand Up @@ -467,11 +468,18 @@ export(optim_adagrad)
export(optim_adam)
export(optim_adamw)
export(optim_asgd)
export(optim_ignite_adagrad)
export(optim_ignite_adam)
export(optim_ignite_adamw)
export(optim_ignite_rmsprop)
export(optim_ignite_sgd)
export(optim_lbfgs)
export(optim_required)
export(optim_rmsprop)
export(optim_rprop)
export(optim_sgd)
export(optimizer)
export(optimizer_ignite)
export(sampler)
export(set_autocast)
export(slc)
Expand Down
9 changes: 5 additions & 4 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

## Bug fixes

- Fix french translation (#1176 @cregouby)

## Bug fixes

- `torch_iinfo()` now support all integer dtypes (#1190 @cregouby)
- Fixed float key_padding_mask in `nnf_multi_head_attention_forward()` (#1205)
- Updated to LibTorch v2.5.1 (#1204)
- Fix french translation (#1176 @cregouby)
- Feature: Faster optimizers (`optim_ignite_<name>()`) are available: Adam, AdamW, Adagrad, RMSprop,SGD.
These can be used as drop-in replacements for `optim_<name>` but are considerably
faster as they wrap the LibTorch implementation of the optimizer.
The biggest speed differences can be observed for complex optimizers such as `AdamW`.

# torch 0.13.0

Expand Down
144 changes: 144 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -14889,6 +14889,150 @@ cpp_torch_cuda_set_rng_state <- function(device, state) {
invisible(.Call(`_torch_cpp_torch_cuda_set_rng_state`, device, state))
}

rcpp_ignite_optim_get_param_groups <- function(opt) {
.Call(`_torch_rcpp_ignite_optim_get_param_groups`, opt)
}

rcpp_ignite_optim_param_groups_size <- function(groups) {
.Call(`_torch_rcpp_ignite_optim_param_groups_size`, groups)
}

rcpp_ignite_optim_get_param_group_params <- function(groups, i) {
.Call(`_torch_rcpp_ignite_optim_get_param_group_params`, groups, i)
}

rcpp_ignite_optim_step <- function(opt) {
invisible(.Call(`_torch_rcpp_ignite_optim_step`, opt))
}

rcpp_ignite_optim_zero_grad <- function(opt) {
invisible(.Call(`_torch_rcpp_ignite_optim_zero_grad`, opt))
}

rcpp_ignite_optim_parameters_with_state <- function(opt) {
.Call(`_torch_rcpp_ignite_optim_parameters_with_state`, opt)
}

rcpp_ignite_adamw <- function(params, lr, betas, eps, weight_decay, amsgrad) {
.Call(`_torch_rcpp_ignite_adamw`, params, lr, betas, eps, weight_decay, amsgrad)
}

rcpp_ignite_adamw_get_states <- function(opt) {
.Call(`_torch_rcpp_ignite_adamw_get_states`, opt)
}

rcpp_ignite_adamw_set_states <- function(opt, params, states) {
invisible(.Call(`_torch_rcpp_ignite_adamw_set_states`, opt, params, states))
}

rcpp_ignite_adamw_add_param_group <- function(opt, params, lr, betas, eps, weight_decay, amsgrad) {
invisible(.Call(`_torch_rcpp_ignite_adamw_add_param_group`, opt, params, lr, betas, eps, weight_decay, amsgrad))
}

rcpp_as_list_adamw_param_groups <- function(groups) {
.Call(`_torch_rcpp_as_list_adamw_param_groups`, groups)
}

rcpp_ignite_adamw_set_param_group_options <- function(opt, list) {
invisible(.Call(`_torch_rcpp_ignite_adamw_set_param_group_options`, opt, list))
}

rcpp_ignite_adam <- function(params, lr, betas, eps, weight_decay, amsgrad) {
.Call(`_torch_rcpp_ignite_adam`, params, lr, betas, eps, weight_decay, amsgrad)
}

rcpp_ignite_adam_get_states <- function(opt) {
.Call(`_torch_rcpp_ignite_adam_get_states`, opt)
}

rcpp_ignite_adam_set_states <- function(opt, params, states) {
invisible(.Call(`_torch_rcpp_ignite_adam_set_states`, opt, params, states))
}

rcpp_ignite_adam_add_param_group <- function(opt, params, lr, betas, eps, weight_decay, amsgrad) {
invisible(.Call(`_torch_rcpp_ignite_adam_add_param_group`, opt, params, lr, betas, eps, weight_decay, amsgrad))
}

rcpp_as_list_adam_param_groups <- function(groups) {
.Call(`_torch_rcpp_as_list_adam_param_groups`, groups)
}

rcpp_ignite_adam_set_param_group_options <- function(opt, list) {
invisible(.Call(`_torch_rcpp_ignite_adam_set_param_group_options`, opt, list))
}

rcpp_ignite_sgd <- function(params, lr, momentum, dampening, weight_decay, nesterov) {
.Call(`_torch_rcpp_ignite_sgd`, params, lr, momentum, dampening, weight_decay, nesterov)
}

rcpp_ignite_sgd_get_states <- function(opt) {
.Call(`_torch_rcpp_ignite_sgd_get_states`, opt)
}

rcpp_ignite_sgd_set_states <- function(opt, params, states) {
invisible(.Call(`_torch_rcpp_ignite_sgd_set_states`, opt, params, states))
}

rcpp_ignite_sgd_add_param_group <- function(opt, params, lr, momentum, dampening, weight_decay, nesterov) {
invisible(.Call(`_torch_rcpp_ignite_sgd_add_param_group`, opt, params, lr, momentum, dampening, weight_decay, nesterov))
}

rcpp_as_list_sgd_param_groups <- function(groups) {
.Call(`_torch_rcpp_as_list_sgd_param_groups`, groups)
}

rcpp_ignite_sgd_set_param_group_options <- function(opt, list) {
invisible(.Call(`_torch_rcpp_ignite_sgd_set_param_group_options`, opt, list))
}

rcpp_ignite_rmsprop <- function(params, lr, alpha, eps, weight_decay, momentum, centered) {
.Call(`_torch_rcpp_ignite_rmsprop`, params, lr, alpha, eps, weight_decay, momentum, centered)
}

rcpp_ignite_rmsprop_get_states <- function(opt) {
.Call(`_torch_rcpp_ignite_rmsprop_get_states`, opt)
}

rcpp_ignite_rmsprop_set_states <- function(opt, params, states) {
invisible(.Call(`_torch_rcpp_ignite_rmsprop_set_states`, opt, params, states))
}

rcpp_ignite_rmsprop_add_param_group <- function(opt, params, lr, alpha, eps, weight_decay, momentum, centered) {
invisible(.Call(`_torch_rcpp_ignite_rmsprop_add_param_group`, opt, params, lr, alpha, eps, weight_decay, momentum, centered))
}

rcpp_as_list_rmsprop_param_groups <- function(groups) {
.Call(`_torch_rcpp_as_list_rmsprop_param_groups`, groups)
}

rcpp_ignite_rmsprop_set_param_group_options <- function(opt, list) {
invisible(.Call(`_torch_rcpp_ignite_rmsprop_set_param_group_options`, opt, list))
}

rcpp_ignite_adagrad <- function(params, lr, lr_decay, weight_decay, eps, initial_accumulator_value) {
.Call(`_torch_rcpp_ignite_adagrad`, params, lr, lr_decay, weight_decay, eps, initial_accumulator_value)
}

rcpp_ignite_adagrad_get_states <- function(opt) {
.Call(`_torch_rcpp_ignite_adagrad_get_states`, opt)
}

rcpp_ignite_adagrad_set_states <- function(opt, params, states) {
invisible(.Call(`_torch_rcpp_ignite_adagrad_set_states`, opt, params, states))
}

rcpp_ignite_adagrad_add_param_group <- function(opt, params, lr, lr_decay, weight_decay, eps, initial_accumulator_value) {
invisible(.Call(`_torch_rcpp_ignite_adagrad_add_param_group`, opt, params, lr, lr_decay, weight_decay, eps, initial_accumulator_value))
}

rcpp_as_list_adagrad_param_groups <- function(groups) {
.Call(`_torch_rcpp_as_list_adagrad_param_groups`, groups)
}

rcpp_ignite_adagrad_set_param_group_options <- function(opt, list) {
invisible(.Call(`_torch_rcpp_ignite_adagrad_set_param_group_options`, opt, list))
}

enquos0 <- function(env) {
.Call(`_torch_enquos0`, env)
}
Expand Down
Loading

0 comments on commit 6feaf60

Please sign in to comment.