Commit 84d1c3d

chore(accelerated_op): use correct Python Ctype for pybind11 function prototype (#52)

XuehaiPan authored Aug 7, 2022
1 parent 5b5b21d commit 84d1c3d
Showing 13 changed files with 162 additions and 156 deletions.
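In CPython, a `float` object wraps a C `double`, so binding these kernels with C++ `float`/`int` scalar parameters made pybind11 silently narrow every hyperparameter crossing the Python/C++ boundary. This commit introduces the aliases `pyfloat_t` (= `double`) and `pyuint_t` (= `std::size_t`) and threads them through all accelerated-op prototypes. A minimal sketch of the narrowing this avoids (illustrative only; the `ctype_demo` module and function names are hypothetical, not part of this commit):

```cpp
// ctype_demo.cpp -- illustrative sketch, not from this commit.
#include <pybind11/pybind11.h>

// A Python float arrives as a C double; taking it as C++ `float` narrows it.
double roundtrip_float(float x) { return x; }    // loses precision
double roundtrip_double(double x) { return x; }  // preserves the exact value

PYBIND11_MODULE(ctype_demo, m) {
  m.def("roundtrip_float", &roundtrip_float);
  m.def("roundtrip_double", &roundtrip_double);
}
```

From Python, `ctype_demo.roundtrip_double(0.1) == 0.1` holds, while `ctype_demo.roundtrip_float(0.1)` returns roughly `0.10000000149011612`, which is exactly the kind of drift you do not want in `b1`, `b2`, or `eps`.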
14 changes: 6 additions & 8 deletions .github/workflows/lint.yml
@@ -54,16 +54,9 @@ jobs:
         run: |
           python -m pip install --upgrade pip setuptools
-      - name: Install dependencies
-        run: |
-          python -m pip install --extra-index-url "${TORCH_INDEX_URL}" \
-            -r tests/requirements.txt
-          python -m pip install --extra-index-url "${TORCH_INDEX_URL}" \
-            -r docs/requirements.txt
       - name: Install TorchOpt
         run: |
-          python -m pip install -e .
+          python -m pip install -vvv -e '.[lint]'
       - name: pre-commit
         run: |
@@ -97,6 +90,11 @@ jobs:
         run: |
           make mypy
+      - name: Install dependencies
+        run: |
+          python -m pip install --extra-index-url "${TORCH_INDEX_URL}" \
+            -r docs/requirements.txt
       - name: docstyle
         run: |
           make docstyle
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -72,7 +72,7 @@ jobs:
       - name: Install TorchOpt
         run: |
-          python -m pip install -e .
+          python -m pip install -vvv -e .
       - name: Test with pytest
         run: |
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Use [`cibuildwheel`](https://github.com/pypa/cibuildwheel) to build wheels by [@XuehaiPan](https://github.com/XuehaiPan) in [#45](https://github.com/metaopt/TorchOpt/pull/45).
 - Use dynamic process number in CPU kernels by [@JieRen98](https://github.com/JieRen98) in [#42](https://github.com/metaopt/TorchOpt/pull/42).
 
+### Changed
+
+- Use correct Python Ctype for pybind11 function prototype by [@XuehaiPan](https://github.com/XuehaiPan) in [#52](https://github.com/metaopt/TorchOpt/pull/52).
+
 ------
 
 ## [0.4.2] - 2022-07-26
2 changes: 1 addition & 1 deletion conda-recipe.yaml
@@ -76,7 +76,7 @@ dependencies:
     - mypy
     - flake8
    - flake8-bugbear
-    - doc8
+    - doc8 < 1.0.0a0
    - pydocstyle
    - clang-format
    - clang-tools  # clang-tidy
27 changes: 15 additions & 12 deletions include/adam_op/adam_op.h
@@ -23,32 +23,35 @@
 namespace torchopt {
 TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
                                   const torch::Tensor& mu,
-                                  const torch::Tensor& nu, const float b1,
-                                  const float b2, const float eps,
-                                  const float eps_root, const int count);
+                                  const torch::Tensor& nu, const pyfloat_t b1,
+                                  const pyfloat_t b2, const pyfloat_t eps,
+                                  const pyfloat_t eps_root,
+                                  const pyuint_t count);
 
 torch::Tensor adamForwardMu(const torch::Tensor& updates,
-                            const torch::Tensor& mu, const float b1);
+                            const torch::Tensor& mu, const pyfloat_t b1);
 
 torch::Tensor adamForwardNu(const torch::Tensor& updates,
-                            const torch::Tensor& nu, const float b2);
+                            const torch::Tensor& nu, const pyfloat_t b2);
 
 torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu,
-                                 const torch::Tensor& new_nu, const float b1,
-                                 const float b2, const float eps,
-                                 const float eps_root, const int count);
+                                 const torch::Tensor& new_nu,
+                                 const pyfloat_t b1, const pyfloat_t b2,
+                                 const pyfloat_t eps, const pyfloat_t eps_root,
+                                 const pyuint_t count);
 
 TensorArray<2> adamBackwardMu(const torch::Tensor& dmu,
                               const torch::Tensor& updates,
-                              const torch::Tensor& mu, const float b1);
+                              const torch::Tensor& mu, const pyfloat_t b1);
 
 TensorArray<2> adamBackwardNu(const torch::Tensor& dnu,
                               const torch::Tensor& updates,
-                              const torch::Tensor& nu, const float b2);
+                              const torch::Tensor& nu, const pyfloat_t b2);
 
 TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates,
                                    const torch::Tensor& updates,
                                    const torch::Tensor& new_mu,
-                                   const torch::Tensor& new_nu, const float b1,
-                                   const float b2, const int count);
+                                   const torch::Tensor& new_nu,
+                                   const pyfloat_t b1, const pyfloat_t b2,
+                                   const pyuint_t count);
 }  // namespace torchopt
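The header above only declares the C++ side; the change pays off at the pybind11 `m.def(...)` registrations (not shown in this diff), since pybind11 derives its Python-to-C conversions from the C++ signature it is given. A hedged sketch of what such a registration looks like -- the module name and `py::arg` labels here are assumptions, not lines from this commit:

```cpp
#include <pybind11/pybind11.h>

#include "include/adam_op/adam_op.h"

namespace py = pybind11;

// With the prototypes above, pybind11 now converts Python floats to
// `pyfloat_t` (double) and the step count to `pyuint_t` (std::size_t).
PYBIND11_MODULE(adam_op, m) {
  m.def("forward_mu", &torchopt::adamForwardMu,
        py::arg("updates"), py::arg("mu"), py::arg("b1"));
  m.def("forward_updates", &torchopt::adamForwardUpdates,
        py::arg("new_mu"), py::arg("new_nu"), py::arg("b1"), py::arg("b2"),
        py::arg("eps"), py::arg("eps_root"), py::arg("count"));
}
```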
29 changes: 15 additions & 14 deletions include/adam_op/adam_op_impl_cpu.h
@@ -21,35 +21,36 @@
 #include "include/common.h"
 
 namespace torchopt {
-TensorArray<3> adamForwardInplaceCPU(const torch::Tensor& updates,
-                                     const torch::Tensor& mu,
-                                     const torch::Tensor& nu, const float b1,
-                                     const float b2, const float eps,
-                                     const float eps_root, const int count);
+TensorArray<3> adamForwardInplaceCPU(
+    const torch::Tensor& updates, const torch::Tensor& mu,
+    const torch::Tensor& nu, const pyfloat_t b1, const pyfloat_t b2,
+    const pyfloat_t eps, const pyfloat_t eps_root, const pyuint_t count);
 
 torch::Tensor adamForwardMuCPU(const torch::Tensor& updates,
-                               const torch::Tensor& mu, const float b1);
+                               const torch::Tensor& mu, const pyfloat_t b1);
 
 torch::Tensor adamForwardNuCPU(const torch::Tensor& updates,
-                               const torch::Tensor& nu, const float b2);
+                               const torch::Tensor& nu, const pyfloat_t b2);
 
 torch::Tensor adamForwardUpdatesCPU(const torch::Tensor& new_mu,
-                                    const torch::Tensor& new_nu, const float b1,
-                                    const float b2, const float eps,
-                                    const float eps_root, const int count);
+                                    const torch::Tensor& new_nu,
+                                    const pyfloat_t b1, const pyfloat_t b2,
+                                    const pyfloat_t eps,
+                                    const pyfloat_t eps_root,
+                                    const pyuint_t count);
 
 TensorArray<2> adamBackwardMuCPU(const torch::Tensor& dmu,
                                  const torch::Tensor& updates,
-                                 const torch::Tensor& mu, const float b1);
+                                 const torch::Tensor& mu, const pyfloat_t b1);
 
 TensorArray<2> adamBackwardNuCPU(const torch::Tensor& dnu,
                                  const torch::Tensor& updates,
-                                 const torch::Tensor& nu, const float b2);
+                                 const torch::Tensor& nu, const pyfloat_t b2);
 
 TensorArray<2> adamBackwardUpdatesCPU(const torch::Tensor& dupdates,
                                       const torch::Tensor& updates,
                                       const torch::Tensor& new_mu,
                                       const torch::Tensor& new_nu,
-                                      const float b1, const float b2,
-                                      const int count);
+                                      const pyfloat_t b1, const pyfloat_t b2,
+                                      const pyuint_t count);
 }  // namespace torchopt
28 changes: 14 additions & 14 deletions include/adam_op/adam_op_impl_cuda.cuh
@@ -21,36 +21,36 @@
 #include "include/common.h"
 
 namespace torchopt {
-TensorArray<3> adamForwardInplaceCUDA(const torch::Tensor &updates,
-                                      const torch::Tensor &mu,
-                                      const torch::Tensor &nu, const float b1,
-                                      const float b2, const float eps,
-                                      const float eps_root, const int count);
+TensorArray<3> adamForwardInplaceCUDA(
+    const torch::Tensor &updates, const torch::Tensor &mu,
+    const torch::Tensor &nu, const pyfloat_t b1, const pyfloat_t b2,
+    const pyfloat_t eps, const pyfloat_t eps_root, const pyuint_t count);
 
 torch::Tensor adamForwardMuCUDA(const torch::Tensor &updates,
-                                const torch::Tensor &mu, const float b1);
+                                const torch::Tensor &mu, const pyfloat_t b1);
 
 torch::Tensor adamForwardNuCUDA(const torch::Tensor &updates,
-                                const torch::Tensor &nu, const float b2);
+                                const torch::Tensor &nu, const pyfloat_t b2);
 
 torch::Tensor adamForwardUpdatesCUDA(const torch::Tensor &new_mu,
                                      const torch::Tensor &new_nu,
-                                     const float b1, const float b2,
-                                     const float eps, const float eps_root,
-                                     const int count);
+                                     const pyfloat_t b1, const pyfloat_t b2,
+                                     const pyfloat_t eps,
+                                     const pyfloat_t eps_root,
+                                     const pyuint_t count);
 
 TensorArray<2> adamBackwardMuCUDA(const torch::Tensor &dmu,
                                   const torch::Tensor &updates,
-                                  const torch::Tensor &mu, const float b1);
+                                  const torch::Tensor &mu, const pyfloat_t b1);
 
 TensorArray<2> adamBackwardNuCUDA(const torch::Tensor &dnu,
                                   const torch::Tensor &updates,
-                                  const torch::Tensor &nu, const float b2);
+                                  const torch::Tensor &nu, const pyfloat_t b2);
 
 TensorArray<2> adamBackwardUpdatesCUDA(const torch::Tensor &dupdates,
                                        const torch::Tensor &updates,
                                        const torch::Tensor &new_mu,
                                        const torch::Tensor &new_nu,
-                                       const float b1, const float b2,
-                                       const int count);
+                                       const pyfloat_t b1, const pyfloat_t b2,
+                                       const pyuint_t count);
 }  // namespace torchopt
4 changes: 4 additions & 0 deletions include/common.h
@@ -17,6 +17,10 @@
 #include <torch/extension.h>
 
 #include <array>
+#include <cstddef>
+
+using pyfloat_t = double;
+using pyuint_t = std::size_t;
 
 namespace torchopt {
 template <size_t _Nm>
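These two aliases are the heart of the commit: `pyfloat_t` is `double` because CPython's `float` is a C `double` (`PyFloat_AsDouble` returns one), and `pyuint_t` is `std::size_t` because the Adam step counter is a non-negative Python `int`. A small compile-time sketch of the invariants the aliases encode (illustrative, not part of the diff):

```cpp
#include <cstddef>
#include <type_traits>

using pyfloat_t = double;
using pyuint_t = std::size_t;

// b1/b2/eps/eps_root arrive from Python at double precision; asserting on
// the alias catches any accidental regression to a narrower type.
static_assert(std::is_same<pyfloat_t, double>::value,
              "pyfloat_t must match CPython's float representation");
// The step count is never negative, and size_t outlasts a 32-bit int on
// long training runs.
static_assert(std::is_unsigned<pyuint_t>::value,
              "pyuint_t must stay unsigned");
```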
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -66,7 +66,7 @@ lint = [
     "mypy",
     "flake8",
     "flake8-bugbear",
-    "doc8",
+    "doc8 < 1.0.0a0",
     "pydocstyle",
     "pyenchant",
     "cpplint",
27 changes: 15 additions & 12 deletions src/adam_op/adam_op.cpp
@@ -26,9 +26,10 @@
 namespace torchopt {
 TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
                                   const torch::Tensor& mu,
-                                  const torch::Tensor& nu, const float b1,
-                                  const float b2, const float eps,
-                                  const float eps_root, const int count) {
+                                  const torch::Tensor& nu, const pyfloat_t b1,
+                                  const pyfloat_t b2, const pyfloat_t eps,
+                                  const pyfloat_t eps_root,
+                                  const pyuint_t count) {
 #if defined(__CUDACC__)
   if (updates.device().is_cuda()) {
     return adamForwardInplaceCUDA(updates, mu, nu, b1, b2, eps, eps_root,
@@ -42,7 +43,7 @@ TensorArray<3> adamForwardInplace(const torch::Tensor& updates,
   }
 }
 torch::Tensor adamForwardMu(const torch::Tensor& updates,
-                            const torch::Tensor& mu, const float b1) {
+                            const torch::Tensor& mu, const pyfloat_t b1) {
 #if defined(__CUDACC__)
   if (updates.device().is_cuda()) {
     return adamForwardMuCUDA(updates, mu, b1);
@@ -56,7 +57,7 @@ torch::Tensor adamForwardMu(const torch::Tensor& updates,
 }
 
 torch::Tensor adamForwardNu(const torch::Tensor& updates,
-                            const torch::Tensor& nu, const float b2) {
+                            const torch::Tensor& nu, const pyfloat_t b2) {
 #if defined(__CUDACC__)
   if (updates.device().is_cuda()) {
     return adamForwardNuCUDA(updates, nu, b2);
@@ -70,9 +71,10 @@ torch::Tensor adamForwardNu(const torch::Tensor& updates,
 }
 
 torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu,
-                                 const torch::Tensor& new_nu, const float b1,
-                                 const float b2, const float eps,
-                                 const float eps_root, const int count) {
+                                 const torch::Tensor& new_nu,
+                                 const pyfloat_t b1, const pyfloat_t b2,
+                                 const pyfloat_t eps, const pyfloat_t eps_root,
+                                 const pyuint_t count) {
 #if defined(__CUDACC__)
   if (new_mu.device().is_cuda()) {
     return adamForwardUpdatesCUDA(new_mu, new_nu, b1, b2, eps, eps_root, count);
@@ -87,7 +89,7 @@ torch::Tensor adamForwardUpdates(const torch::Tensor& new_mu,
 
 TensorArray<2> adamBackwardMu(const torch::Tensor& dmu,
                               const torch::Tensor& updates,
-                              const torch::Tensor& mu, const float b1) {
+                              const torch::Tensor& mu, const pyfloat_t b1) {
 #if defined(__CUDACC__)
   if (dmu.device().is_cuda()) {
     return adamBackwardMuCUDA(dmu, updates, mu, b1);
@@ -102,7 +104,7 @@ TensorArray<2> adamBackwardMu(const torch::Tensor& dmu,
 
 TensorArray<2> adamBackwardNu(const torch::Tensor& dnu,
                               const torch::Tensor& updates,
-                              const torch::Tensor& nu, const float b2) {
+                              const torch::Tensor& nu, const pyfloat_t b2) {
 #if defined(__CUDACC__)
   if (dnu.device().is_cuda()) {
     return adamBackwardNuCUDA(dnu, updates, nu, b2);
@@ -118,8 +120,9 @@ TensorArray<2> adamBackwardNu(const torch::Tensor& dnu,
 TensorArray<2> adamBackwardUpdates(const torch::Tensor& dupdates,
                                    const torch::Tensor& updates,
                                    const torch::Tensor& new_mu,
-                                   const torch::Tensor& new_nu, const float b1,
-                                   const float b2, const int count) {
+                                   const torch::Tensor& new_nu,
+                                   const pyfloat_t b1, const pyfloat_t b2,
+                                   const pyuint_t count) {
 #if defined(__CUDACC__)
   if (dupdates.device().is_cuda()) {
     return adamBackwardUpdatesCUDA(dupdates, updates, new_mu, new_nu, b1, b2,
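Each definition in `adam_op.cpp` follows the same dispatch shape: when the translation unit is compiled by NVCC, `__CUDACC__` is defined and CUDA tensors are routed to the `*CUDA` kernels; otherwise only the CPU branch exists. Reduced to a self-contained sketch (the `scale*` function names are hypothetical stand-ins, not from this commit):

```cpp
#include <stdexcept>

#include <torch/extension.h>

using pyfloat_t = double;

torch::Tensor scaleCPU(const torch::Tensor& t, const pyfloat_t alpha) {
  return t * alpha;  // stand-in for a real CPU kernel
}

#if defined(__CUDACC__)
torch::Tensor scaleCUDA(const torch::Tensor& t, const pyfloat_t alpha) {
  return t * alpha;  // stand-in for a real CUDA kernel launch
}
#endif

// Same shape as adamForwardMu & friends: try CUDA if compiled for it,
// fall back to CPU, and reject anything else loudly.
torch::Tensor scale(const torch::Tensor& t, const pyfloat_t alpha) {
#if defined(__CUDACC__)
  if (t.device().is_cuda()) {
    return scaleCUDA(t, alpha);
  }
#endif
  if (t.device().is_cpu()) {
    return scaleCPU(t, alpha);
  }
  throw std::runtime_error("scale: unsupported device");
}
```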