Skip to content

Commit

Permalink
Export take along dim (#1248)
Browse files Browse the repository at this point in the history
  • Loading branch information
dfalbel authored Jan 21, 2025
1 parent 94437bb commit 990cc85
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 0 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,7 @@ export(torch_sum)
export(torch_svd)
export(torch_t)
export(torch_take)
export(torch_take_along_dim)
export(torch_tan)
export(torch_tanh)
export(torch_tensor)
Expand Down
24 changes: 24 additions & 0 deletions R/gen-namespace-docs.R
Original file line number Diff line number Diff line change
Expand Up @@ -7479,3 +7479,27 @@ NULL
#' @name torch_kron
#' @export
NULL


#' Selects values from input at the 1-dimensional indices from indices along the given dim.
#'
#' @note If dim is `NULL`, the input array is treated as if it has been flattened to 1d.
#'
#' Functions that return indices along a dimension, like [torch_argmax()] and [torch_argsort()],
#' are designed to work with this function. See the examples below.
#'
#' @param input the input tensor.
#' @param indices the indices into input. Must have long dtype.
#' @param dim the dimension to select along. Default is `NULL`.
#'
#' @name torch_take_along_dim
#' @examples
#' t <- torch_tensor(matrix(c(10, 30, 20, 60, 40, 50), nrow = 2))
#' max_idx <- torch_argmax(t)
#' torch_take_along_dim(t, max_idx)
#'
#' sorted_idx <- torch_argsort(t, dim=2)
#' torch_take_along_dim(t, sorted_idx, dim=2)
#'
#' @export
NULL
35 changes: 35 additions & 0 deletions man/torch_take_along_dim.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 990cc85

Please sign in to comment.