diff --git a/3rdparty/cmake/FindPaddle.cmake b/3rdparty/cmake/FindPaddle.cmake new file mode 100644 index 00000000000..a564e0061f9 --- /dev/null +++ b/3rdparty/cmake/FindPaddle.cmake @@ -0,0 +1,120 @@ +# Find the Paddle root and use the provided cmake module +# The following variables will be set: +# - Paddle_FOUND +# - Paddle_VERSION +# - Paddle_ROOT +# - Paddle_DEFINITIONS +# +# - PADDLE_FOUND +# - PADDLE_INCLUDE_DIRS +# - PADDLE_LIBRARY_DIRS +# - PADDLE_LIBRARIES +# - PADDLE_CXX_FLAGS +# +# and import the target 'paddle'. + +if(NOT Paddle_FOUND) + # Searching for Paddle requires the python executable + if (NOT Python3_EXECUTABLE) + message(FATAL_ERROR "Python 3 not found in top level file") + endif() + + if(BUILD_CUDA_MODULE) + find_package(CUDAToolkit REQUIRED) + string(SUBSTRING ${CUDAToolkit_VERSION} 0 4 CUDA_VERSION) + endif() + + message(STATUS "Getting Paddle properties ...") + + set(Paddle_FETCH_PROPERTIES + "import os" + "import paddle" + "import sysconfig" + "print(paddle.__version__, end=';')" + "print(os.path.dirname(paddle.__file__), end=';')" + "print(sysconfig.get_path('include', scheme='posix_prefix'), end=';')" + ) + execute_process( + COMMAND ${Python3_EXECUTABLE} "-c" "${Paddle_FETCH_PROPERTIES}" + OUTPUT_VARIABLE Paddle_PROPERTIES + ) + + + list(GET Paddle_PROPERTIES 0 Paddle_VERSION) + list(GET Paddle_PROPERTIES 1 Paddle_ROOT) + list(GET Paddle_PROPERTIES 2 Python_INCLUDE) + + set(Paddle_CXX11_ABI True) + + unset(Paddle_FETCH_PROPERTIES) + unset(Paddle_PROPERTIES) + + add_library(paddle STATIC IMPORTED) + + # handle include directories + set(PADDLE_INCLUDE_DIRS) + list(APPEND PADDLE_INCLUDE_DIRS "${Paddle_ROOT}/include") + list(APPEND PADDLE_INCLUDE_DIRS "${Paddle_ROOT}/include/third_party") + list(APPEND PADDLE_INCLUDE_DIRS "${Python_INCLUDE}") + + if(BUILD_CUDA_MODULE) + list(APPEND PADDLE_INCLUDE_DIRS "${CUDAToolkit_INCLUDE_DIRS}") + endif() + + # handle library directories + set(PADDLE_LIBRARY_DIRS) + list(APPEND PADDLE_LIBRARY_DIRS "${Paddle_ROOT}/libs") + list(APPEND PADDLE_LIBRARY_DIRS "${Paddle_ROOT}/base") + + if(BUILD_CUDA_MODULE) + list(APPEND PADDLE_LIBRARY_DIRS "${CUDAToolkit_LIBRARY_DIR}") + endif() + + # handle libraries + set(PADDLE_LIBRARIES) + find_library(PADDLE_LIB NAMES paddle PATHS "${Paddle_ROOT}/base") + list(APPEND PADDLE_LIBRARY_DIRS "${PADDLE_LIB}") + + if(BUILD_CUDA_MODULE) + find_library(CUDART_LIB NAMES cudart PATHS "${CUDAToolkit_LIBRARY_DIR}") + list(APPEND PADDLE_LIBRARY_DIRS "${CUDART_LIB}") + endif() + + # handle compile flags + set(PADDLE_CXX_FLAGS) + if(BUILD_CUDA_MODULE) + set(PADDLE_CXX_FLAGS "-DPADDLE_WITH_CUDA ${PADDLE_CXX_FLAGS}") + endif() + + set_target_properties(paddle PROPERTIES + IMPORTED_LOCATION "${PADDLE_LIB}" + ) + set_target_properties(paddle PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${PADDLE_INCLUDE_DIRS}" + ) + set_property(TARGET paddle PROPERTY INTERFACE_COMPILE_OPTIONS "${PADDLE_CXX_FLAGS}") + + set(PADDLE_FOUND True) +endif() + +if(PRINT_ONCE) + message(STATUS "Paddle version: ${Paddle_VERSION}") + message(STATUS " root dir: ${Paddle_ROOT}") + message(STATUS " compile flags: ${PADDLE_CXX_FLAGS}") + if (UNIX AND NOT APPLE) + message(STATUS " use cxx11 abi: ${Paddle_CXX11_ABI}") + endif() + foreach(idir ${PADDLE_INCLUDE_DIRS}) + message(STATUS " include dirs: ${idir}") + endforeach(idir) + foreach(ldir ${PADDLE_LIBRARY_DIRS}) + message(STATUS " library dirs: ${ldir}") + endforeach(ldir) + foreach(lib ${PADDLE_LIBRARIES}) + message(STATUS " libraries: ${lib}") + endforeach(lib) +endif() + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(Paddle DEFAULT_MSG Paddle_VERSION + Paddle_ROOT) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7b047cd0b82..1b2c4777947 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -131,6 +131,7 @@ option(BUILD_AZURE_KINECT "Build support for Azure Kinect sensor" OFF # ML library options option(BUILD_TENSORFLOW_OPS "Build ops for TensorFlow" OFF) option(BUILD_PYTORCH_OPS "Build ops for PyTorch" OFF) +option(BUILD_PADDLE_OPS "Build ops for Paddle" OFF) option(BUNDLE_OPEN3D_ML "Includes the Open3D-ML repo in the wheel" OFF) # Release build options @@ -288,6 +289,9 @@ endif() if(BUILD_SYCL_MODULE AND BUILD_PYTORCH_OPS) message(FATAL_ERROR "BUILD_SYCL_MODULE=ON requires BUILD_PYTORCH_OPS=OFF") endif() +if(BUILD_SYCL_MODULE AND BUILD_PADDLE_OPS) + message(FATAL_ERROR "BUILD_SYCL_MODULE=ON requires BUILD_PADDLE_OPS=OFF") +endif() if(BUILD_SYCL_MODULE AND BUILD_CUDA_MODULE) message(FATAL_ERROR "BUILD_SYCL_MODULE and BUILD_SYCL_MODULE cannot be on at the same time for now.") endif() diff --git a/cpp/open3d/ml/CMakeLists.txt b/cpp/open3d/ml/CMakeLists.txt index 35f0b65112a..be445a9f238 100644 --- a/cpp/open3d/ml/CMakeLists.txt +++ b/cpp/open3d/ml/CMakeLists.txt @@ -3,6 +3,10 @@ if (BUILD_TENSORFLOW_OPS AND WIN32) # see https://github.com/tensorflow/custom-op/issues/24 endif() +if (BUILD_PADDLE_OPS AND (WIN32 OR APPLE)) + message(FATAL_ERROR "Building Paddle ops on Windows or MacOS is currently not supported.") +endif() + if (BUILD_TENSORFLOW_OPS) add_subdirectory(tensorflow) endif() @@ -11,4 +15,9 @@ if (BUILD_PYTORCH_OPS) add_subdirectory(pytorch) endif() +if (BUILD_PADDLE_OPS) + add_subdirectory(paddle) +endif() + + add_subdirectory(contrib) diff --git a/cpp/open3d/ml/paddle/CMakeLists.txt b/cpp/open3d/ml/paddle/CMakeLists.txt new file mode 100644 index 00000000000..c6d0b78393e --- /dev/null +++ b/cpp/open3d/ml/paddle/CMakeLists.txt @@ -0,0 +1,119 @@ +if(BUILD_CUDA_MODULE) + message(STATUS "Building Paddle ops with CUDA") +else() + message(STATUS "Building Paddle ops") +endif() + +set(PRINT_ONCE ON) +find_package(Paddle REQUIRED) + +if (Paddle_VERSION VERSION_LESS 3.0.0) + message(FATAL_ERROR "Please update to Paddle 3.0.0+ to build Paddle Ops.") +endif() + +add_library(open3d_paddle_ops SHARED) + +target_sources(open3d_paddle_ops PRIVATE + PaddleHelper.cpp + misc/BuildSpatialHashTableOpKernel.cpp + misc/BuildSpatialHashTableOps.cpp + misc/FixedRadiusSearchOps.cpp + misc/FixedRadiusSearchOpKernel.cpp + misc/RadiusSearchOps.cpp + misc/RadiusSearchOpKernel.cpp + misc/InvertNeighborsListOps.cpp + misc/InvertNeighborsListOpKernel.cpp + misc/KnnSearchOps.cpp + misc/KnnSearchOpKernel.cpp + misc/RaggedToDenseOpKernel.cpp + misc/RaggedToDenseOps.cpp +) + +if (BUILD_CUDA_MODULE) + target_sources(open3d_paddle_ops PRIVATE + misc/BuildSpatialHashTableOpKernel.cu + misc/FixedRadiusSearchOpKernel.cu + misc/InvertNeighborsListOpKernel.cu + misc/RaggedToDenseOpKernel.cu + ) + +endif() + +open3d_show_and_abort_on_warning(open3d_paddle_ops) +open3d_set_global_properties(open3d_paddle_ops) + +# Set output directory according to architecture (cpu/cuda) +get_target_property(PADDLE_OPS_DIR open3d_paddle_ops LIBRARY_OUTPUT_DIRECTORY) +set(PADDLE_OPS_ARCH_DIR + "${PADDLE_OPS_DIR}/$,cuda,cpu>") +set_target_properties(open3d_paddle_ops PROPERTIES + LIBRARY_OUTPUT_DIRECTORY "${PADDLE_OPS_ARCH_DIR}" + ARCHIVE_OUTPUT_DIRECTORY "${PADDLE_OPS_ARCH_DIR}") + +# Do not add "lib" prefix +set_target_properties(open3d_paddle_ops PROPERTIES PREFIX "") +set_target_properties(open3d_paddle_ops PROPERTIES DEBUG_POSTFIX "_debug") + +target_include_directories(open3d_paddle_ops SYSTEM PRIVATE + ${PROJECT_SOURCE_DIR}/cpp + ${PADDLE_INCLUDE_DIRS} +) + +target_link_libraries(open3d_paddle_ops PRIVATE + paddle + Open3D::Open3D + Open3D::3rdparty_eigen3 + Open3D::3rdparty_fmt + Open3D::3rdparty_nanoflann + TBB::tbb +) +if (TARGET Open3D::3rdparty_parallelstl) + target_link_libraries(open3d_paddle_ops PRIVATE + Open3D::3rdparty_parallelstl + ) +endif() +if (TARGET Open3D::3rdparty_onedpl) + target_link_libraries(open3d_paddle_ops PRIVATE + Open3D::3rdparty_onedpl + ) +endif() + +if (BUILD_CUDA_MODULE) + target_link_libraries(open3d_paddle_ops PRIVATE + Open3D::3rdparty_cutlass + ${PADDLE_LIBRARIES} + CUDA::cuda_driver + ) + + if (TARGET Open3D::3rdparty_cub) + target_link_libraries(open3d_paddle_ops PRIVATE + Open3D::3rdparty_cub + ) + endif() +endif() + +install(TARGETS open3d_paddle_ops EXPORT Open3DPaddleOps + LIBRARY DESTINATION ${Open3D_INSTALL_LIB_DIR} +) +install(EXPORT Open3DPaddleOps NAMESPACE ${PROJECT_NAME}:: DESTINATION ${Open3D_INSTALL_CMAKE_DIR}) + +if (BUILD_SHARED_LIBS AND UNIX) +file(CONFIGURE OUTPUT open3d_paddle_ops.pc.in + CONTENT [=[ +prefix=${pcfiledir}/../.. +libdir=${prefix}/lib +includedir=${prefix}/include/ + +Name: Open3D Paddle Ops +Description: @PROJECT_DESCRIPTION@ This library contains 3D ML Ops for use with Paddle. +URL: @PROJECT_HOMEPAGE_URL@ +Version: @PROJECT_VERSION@ +Requires: Open3D = @PROJECT_VERSION@ +Cflags: +Libs: -lopen3d_paddle_ops]=] @ONLY NEWLINE_STYLE LF) + file(GENERATE OUTPUT open3d_paddle_ops.pc INPUT + "${CMAKE_CURRENT_BINARY_DIR}/open3d_paddle_ops.pc.in" + TARGET open3d_paddle_ops) + install(FILES "${CMAKE_CURRENT_BINARY_DIR}/open3d_paddle_ops.pc" + DESTINATION "${Open3D_INSTALL_LIB_DIR}/pkgconfig") +endif() diff --git a/cpp/open3d/ml/paddle/PaddleHelper.cpp b/cpp/open3d/ml/paddle/PaddleHelper.cpp new file mode 100644 index 00000000000..02ab96ff27d --- /dev/null +++ b/cpp/open3d/ml/paddle/PaddleHelper.cpp @@ -0,0 +1,42 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- + +#include "PaddleHelper.h" + +paddle::Tensor InitializedEmptyTensor(const phi::DataType dtype, + const phi::IntArray& shape, + const phi::Place& place) { + switch (dtype) { + case phi::DataType::INT8: + return InitializedEmptyTensor(shape, place); + break; + case phi::DataType::UINT8: + return InitializedEmptyTensor(shape, place); + break; + case phi::DataType::INT16: + return InitializedEmptyTensor(shape, place); + break; + case phi::DataType::FLOAT32: + return InitializedEmptyTensor(shape, place); + break; + case phi::DataType::INT32: + return InitializedEmptyTensor(shape, place); + break; + case phi::DataType::FLOAT64: + return InitializedEmptyTensor(shape, place); + break; + case phi::DataType::INT64: + return InitializedEmptyTensor(shape, place); + break; + default: + PD_CHECK(false, + "Only support phi::DataType as `INT8`, `UINT8`, `INT16`, " + "`FLOAT32`, `FLOAT64`, " + "`INT32` and `INT64` but got %s.", + phi::DataTypeToString(dtype)); + } +} diff --git a/cpp/open3d/ml/paddle/PaddleHelper.h b/cpp/open3d/ml/paddle/PaddleHelper.h new file mode 100644 index 00000000000..d7918ed8524 --- /dev/null +++ b/cpp/open3d/ml/paddle/PaddleHelper.h @@ -0,0 +1,286 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- + +#pragma once + +#include +#include + +#include +#include + +#include "open3d/ml/ShapeChecking.h" +#include "paddle/extension.h" +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/allocator.h" + +// Macros for checking tensor properties +#define CHECK_CUDA(x) \ + do { \ + PD_CHECK(x.is_gpu(), #x " must be a CUDA tensor"); \ + } while (0) + +// NOTE: The input Tensor will be preprocessed into a contiguous Tensor within +// the execution function of the custom operator, so CHECK_CONTIGUOUS will be +// always True as there is no need for an explicit conversion in Open3D. For +// reference, please see: +// https://github.com/PaddlePaddle/Paddle/blob/65126f558a5c0fbb0cd1aa0a42844a73632ff9e9/paddle/fluid/eager/custom_operator/custom_operator_utils.cc#L803-L810 +#define CHECK_CONTIGUOUS(x) \ + do { \ + } while (0) + +#define CHECK_TYPE(x, type) \ + do { \ + PD_CHECK(x.dtype() == type, #x " must have type " #type); \ + } while (0) + +#define CHECK_SAME_DEVICE_TYPE(...) \ + do { \ + if (!SameDeviceType({__VA_ARGS__})) { \ + PD_CHECK(false, \ + #__VA_ARGS__ \ + " must all have the same device type but got " + \ + TensorInfoStr({__VA_ARGS__})); \ + } \ + } while (0) + +#define CHECK_SAME_DTYPE(...) \ + do { \ + if (!SameDtype({__VA_ARGS__})) { \ + PD_CHECK(false, #__VA_ARGS__ \ + " must all have the same dtype but got " + \ + TensorInfoStr({__VA_ARGS__})); \ + } \ + } while (0) +// Conversion from standard types to paddle types +typedef std::remove_const::type + PaddleDtype_t; +template +inline PaddleDtype_t ToPaddleDtype() { + PD_CHECK(false, "Unsupported type"); +} +template <> +inline PaddleDtype_t ToPaddleDtype() { + return paddle::DataType::UINT8; +} +template <> +inline PaddleDtype_t ToPaddleDtype() { + return paddle::DataType::INT8; +} +template <> +inline PaddleDtype_t ToPaddleDtype() { + return paddle::DataType::INT16; +} +template <> +inline PaddleDtype_t ToPaddleDtype() { + return paddle::DataType::INT32; +} +template <> +inline PaddleDtype_t ToPaddleDtype() { + return paddle::DataType::INT64; +} +template <> +inline PaddleDtype_t ToPaddleDtype() { + return paddle::DataType::FLOAT32; +} +template <> +inline PaddleDtype_t ToPaddleDtype() { + return paddle::DataType::FLOAT64; +} + +// convenience function for comparing standard types with paddle types +template +inline bool ComparePaddleDtype(const TDtype& t) { + return ToPaddleDtype() == t; +} + +// convenience function to check if all tensors have the same device type +inline bool SameDeviceType(std::initializer_list tensors) { + if (tensors.size()) { + auto device_type = tensors.begin()->place(); + for (const auto& t : tensors) { + if (device_type != t.place()) { + return false; + } + } + } + return true; +} + +// convenience function to check if all tensors have the same dtype +inline bool SameDtype(std::initializer_list tensors) { + if (tensors.size()) { + auto dtype = tensors.begin()->dtype(); + for (const auto& t : tensors) { + if (dtype != t.dtype()) { + return false; + } + } + } + return true; +} + +inline std::string TensorInfoStr( + std::initializer_list tensors) { + std::stringstream sstr; + size_t count = 0; + for (const auto& t : tensors) { + sstr << "Tensor(" << t.size() << ", " << t.place() << ")"; + ++count; + if (count < tensors.size()) sstr << ", "; + } + return sstr.str(); +} + +// convenience function for creating a tensor for temp memory +inline paddle::Tensor CreateTempTensor(const int64_t size, + const paddle::Place& device, + void** ptr = nullptr) { + paddle::Tensor tensor = + paddle::empty({size}, ToPaddleDtype(), device); + if (ptr) { + *ptr = tensor.data(); + } + return tensor; +} + +inline std::vector GetShapeVector( + paddle::Tensor tensor) { + using namespace open3d::ml::op_util; + const auto old_shape = tensor.shape(); + std::vector shape; + for (auto i = 0u; i < old_shape.size(); ++i) { + shape.push_back(old_shape[i]); + } + return shape; +} + +template +std::tuple CheckShape(paddle::Tensor tensor, + TDimX&& dimex, + TArgs&&... args) { + return open3d::ml::op_util::CheckShape(GetShapeVector(tensor), + std::forward(dimex), + std::forward(args)...); +} + +// +// Macros for checking the shape of Tensors. +// Usage: +// { +// using namespace open3d::ml::op_util; +// Dim w("w"); +// Dim h("h"); +// CHECK_SHAPE(tensor1, 10, w, h); // checks if the first dim is 10 +// // and assigns w and h based on +// // the shape of tensor1 +// +// CHECK_SHAPE(tensor2, 10, 20, h); // this checks if the the last dim +// // of tensor2 matches the last dim +// // of tensor1. The first two dims +// // must match 10, 20. +// } +// +// +// See "../ShapeChecking.h" for more info and limitations. +// +#define CHECK_SHAPE(tensor, ...) \ + do { \ + bool cs_success_; \ + std::string cs_errstr_; \ + std::tie(cs_success_, cs_errstr_) = CheckShape(tensor, __VA_ARGS__); \ + PD_CHECK(cs_success_, \ + "invalid shape for '" #tensor "', " + cs_errstr_); \ + } while (0) + +#define CHECK_SHAPE_COMBINE_FIRST_DIMS(tensor, ...) \ + do { \ + bool cs_success_; \ + std::string cs_errstr_; \ + std::tie(cs_success_, cs_errstr_) = \ + CheckShape(tensor, __VA_ARGS__); \ + PD_CHECK(cs_success_, \ + "invalid shape for '" #tensor "', " + cs_errstr_); \ + } while (0) + +#define CHECK_SHAPE_IGNORE_FIRST_DIMS(tensor, ...) \ + do { \ + bool cs_success_; \ + std::string cs_errstr_; \ + std::tie(cs_success_, cs_errstr_) = \ + CheckShape(tensor, __VA_ARGS__); \ + PD_CHECK(cs_success_, \ + "invalid shape for '" #tensor "', " + cs_errstr_); \ + } while (0) + +#define CHECK_SHAPE_COMBINE_LAST_DIMS(tensor, ...) \ + do { \ + bool cs_success_; \ + std::string cs_errstr_; \ + std::tie(cs_success_, cs_errstr_) = \ + CheckShape(tensor, __VA_ARGS__); \ + PD_CHECK(cs_success_, \ + "invalid shape for '" #tensor "', " + cs_errstr_); \ + } while (0) + +#define CHECK_SHAPE_IGNORE_LAST_DIMS(tensor, ...) \ + do { \ + bool cs_success_; \ + std::string cs_errstr_; \ + std::tie(cs_success_, cs_errstr_) = \ + CheckShape(tensor, __VA_ARGS__); \ + PD_CHECK(cs_success_, \ + "invalid shape for '" #tensor "', " + cs_errstr_); \ + } while (0) + +#ifdef BUILD_CUDA_MODULE +static void cudaFreeWrapper(void* ptr) { + phi::gpuError_t result = cudaFree(ptr); + PADDLE_ENFORCE_GPU_SUCCESS(result); +} +#endif + +// NOTE: Hack to support empty tensor, like Tensor(shape=[0], []) +template +paddle::Tensor InitializedEmptyTensor(const phi::IntArray& shape, + const phi::Place& place) { + int64_t size = 1; + for (auto v : shape.GetData()) { + size *= v; + } + PD_CHECK(size == 0, "The numel of empty tensor is not equal to 0."); + + paddle::Deleter deleter; + T* ptr = nullptr; + if (phi::is_gpu_place(place)) { +#ifdef BUILD_CUDA_MODULE + phi::gpuError_t result = cudaMalloc(&ptr, sizeof(T) * 1); + PADDLE_ENFORCE_GPU_SUCCESS(result); + deleter = std::function(cudaFreeWrapper); +#else + PD_CHECK(false, + "InitializedEmptyTensor was not compiled with CUDA support"); +#endif + } else if (phi::is_cpu_place(place)) { + ptr = (T*)malloc(sizeof(T) * 1); + deleter = std::function(free); + } else { + PD_CHECK(false, "Not supported backend!"); + } + + // NOTE: In Paddle, the stride of an empty (0-size) tensor can be the same + // as its shape. + return paddle::from_blob(static_cast(ptr), shape, shape, + paddle::DataType(ToPaddleDtype()), + phi::DataLayout::NCHW, place, deleter); +} + +paddle::Tensor InitializedEmptyTensor(const phi::DataType dtype, + const phi::IntArray& shape, + const phi::Place& place); diff --git a/cpp/open3d/ml/paddle/misc/BuildSpatialHashTableOpKernel.cpp b/cpp/open3d/ml/paddle/misc/BuildSpatialHashTableOpKernel.cpp new file mode 100644 index 00000000000..bcb6d42a41e --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/BuildSpatialHashTableOpKernel.cpp @@ -0,0 +1,34 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/core/nns/FixedRadiusSearchImpl.h" +#include "open3d/ml/paddle/PaddleHelper.h" + +template +void BuildSpatialHashTableCPU(const paddle::Tensor& points, + double radius, + const paddle::Tensor& points_row_splits, + const std::vector& hash_table_splits, + paddle::Tensor& hash_table_index, + paddle::Tensor& hash_table_cell_splits) { + open3d::core::nns::impl::BuildSpatialHashTableCPU( + points.shape()[0], points.data(), T(radius), + points_row_splits.shape()[0], points_row_splits.data(), + hash_table_splits.data(), hash_table_cell_splits.shape()[0], + reinterpret_cast(const_cast( + hash_table_cell_splits.data())), + reinterpret_cast( + const_cast(hash_table_index.data()))); +} +#define INSTANTIATE(T) \ + template void BuildSpatialHashTableCPU( \ + const paddle::Tensor&, double, const paddle::Tensor&, \ + const std::vector&, paddle::Tensor&, paddle::Tensor&); + +INSTANTIATE(float) +INSTANTIATE(double) diff --git a/cpp/open3d/ml/paddle/misc/BuildSpatialHashTableOpKernel.cu b/cpp/open3d/ml/paddle/misc/BuildSpatialHashTableOpKernel.cu new file mode 100644 index 00000000000..061fcc68b1e --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/BuildSpatialHashTableOpKernel.cu @@ -0,0 +1,59 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/core/nns/FixedRadiusSearchImpl.cuh" +#include "open3d/ml/paddle/PaddleHelper.h" + +using namespace open3d::core::nns; + +template +void BuildSpatialHashTableCUDA(const paddle::Tensor& points, + double radius, + const paddle::Tensor& points_row_splits, + const std::vector& hash_table_splits, + paddle::Tensor& hash_table_index, + paddle::Tensor& hash_table_cell_splits) { + auto stream = points.stream(); + // -1 means current global place + auto cuda_place_props = phi::backends::gpu::GetDeviceProperties(-1); + const int texture_alignment = cuda_place_props.textureAlignment; + + void* temp_ptr = nullptr; + size_t temp_size = 0; + + // determine temp_size + impl::BuildSpatialHashTableCUDA( + stream, temp_ptr, temp_size, texture_alignment, points.shape()[0], + points.data(), T(radius), points_row_splits.shape()[0], + points_row_splits.data(), hash_table_splits.data(), + hash_table_cell_splits.shape()[0], + reinterpret_cast(const_cast( + hash_table_cell_splits.data())), + reinterpret_cast( + const_cast(hash_table_index.data()))); + auto place = points.place(); + auto temp_tensor = CreateTempTensor(temp_size, place, &temp_ptr); + + // actually build the table + impl::BuildSpatialHashTableCUDA( + stream, temp_ptr, temp_size, texture_alignment, points.shape()[0], + points.data(), T(radius), points_row_splits.shape()[0], + points_row_splits.data(), hash_table_splits.data(), + hash_table_cell_splits.shape()[0], + reinterpret_cast(const_cast( + hash_table_cell_splits.data())), + reinterpret_cast( + const_cast(hash_table_index.data()))); +} + +#define INSTANTIATE(T) \ + template void BuildSpatialHashTableCUDA( \ + const paddle::Tensor&, double, const paddle::Tensor&, \ + const std::vector&, paddle::Tensor&, paddle::Tensor&); + +INSTANTIATE(float) diff --git a/cpp/open3d/ml/paddle/misc/BuildSpatialHashTableOps.cpp b/cpp/open3d/ml/paddle/misc/BuildSpatialHashTableOps.cpp new file mode 100644 index 00000000000..a0f49e6a442 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/BuildSpatialHashTableOps.cpp @@ -0,0 +1,121 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include +#include + +#include + +#include "open3d/ml/paddle/PaddleHelper.h" + +template +void BuildSpatialHashTableCPU(const paddle::Tensor& points, + double radius, + const paddle::Tensor& points_row_splits, + const std::vector& hash_table_splits, + paddle::Tensor& hash_table_index, + paddle::Tensor& hash_table_cell_splits); +#ifdef BUILD_CUDA_MODULE +template +void BuildSpatialHashTableCUDA(const paddle::Tensor& points, + double radius, + const paddle::Tensor& points_row_splits, + const std::vector& hash_table_splits, + paddle::Tensor& hash_table_index, + paddle::Tensor& hash_table_cell_splits); +#endif + +std::vector BuildSpatialHashTable( + paddle::Tensor& points, + paddle::Tensor& points_row_splits, + double radius, + double hash_table_size_factor, + int64_t max_hash_table_size) { + points_row_splits = points_row_splits.copy_to(phi::CPUPlace(), false); + CHECK_TYPE(points_row_splits, paddle::DataType::INT64); + + // check input shapes + using namespace open3d::ml::op_util; + Dim num_points("num_points"); + Dim batch_size("batch_size"); + + CHECK_SHAPE(points, num_points, 3); + CHECK_SHAPE(points_row_splits, batch_size + 1); + + const auto& point_type = points.dtype(); + + std::vector hash_table_splits(batch_size.value() + 1, 0); + for (int i = 0; i < batch_size.value(); ++i) { + int64_t num_points_i = points_row_splits.data()[i + 1] - + points_row_splits.data()[i]; + int64_t hash_table_size = std::min( + std::max(hash_table_size_factor * num_points_i, 1), + max_hash_table_size); + hash_table_splits[i + 1] = hash_table_splits[i] + hash_table_size; + } + + auto place = points.place(); + paddle::Tensor hash_table_index; + if (points.shape()[0] != 0) { + hash_table_index = + paddle::empty({points.shape()[0]}, + paddle::DataType(ToPaddleDtype()), place); + } else { + hash_table_index = InitializedEmptyTensor({0}, place); + } + paddle::Tensor hash_table_cell_splits = + paddle::empty({hash_table_splits.back() + 1}, + paddle::DataType(ToPaddleDtype()), place); + paddle::Tensor out_hash_table_splits = paddle::empty( + {batch_size.value() + 1}, + paddle::DataType(ToPaddleDtype()), phi::CPUPlace()); + for (size_t i = 0; i < hash_table_splits.size(); ++i) { + out_hash_table_splits.data()[i] = hash_table_splits[i]; + } +#define FN_PARAMETERS \ + points, radius, points_row_splits, hash_table_splits, hash_table_index, \ + hash_table_cell_splits +#define CALL(type, fn) \ + if (ComparePaddleDtype(point_type)) { \ + fn(FN_PARAMETERS); \ + return {hash_table_index, hash_table_cell_splits, \ + out_hash_table_splits}; \ + } + if (points.is_gpu()) { +#ifdef BUILD_CUDA_MODULE + // pass to cuda function + CALL(float, BuildSpatialHashTableCUDA) +#else + PD_CHECK(false, + "BuildSpatialHashTable was not compiled with CUDA support"); +#endif + } else { + CALL(float, BuildSpatialHashTableCPU) + CALL(double, BuildSpatialHashTableCPU) + } + PD_CHECK(false, "BuildSpatialHashTable does not support " + + phi::DataTypeToString(points.dtype()) + + " as input for " + "points"); + + return std::vector(); +} + +std::vector BuildSpatialHashTableInferDtype() { + auto dtype = paddle::DataType::INT32; + return {dtype, dtype, dtype}; +} + +PD_BUILD_OP(open3d_build_spatial_hash_table) + .Inputs({"points", "points_row_splits"}) + .Outputs({"hash_table_index", "hash_table_cell_splits", + "hash_table_splits"}) + .Attrs({"radius: double", "hash_table_size_factor: double", + "max_hash_table_size: int64_t"}) + .SetKernelFn(PD_KERNEL(BuildSpatialHashTable)) + .SetInferDtypeFn(PD_INFER_DTYPE(BuildSpatialHashTableInferDtype)); diff --git a/cpp/open3d/ml/paddle/misc/FixedRadiusSearchOpKernel.cpp b/cpp/open3d/ml/paddle/misc/FixedRadiusSearchOpKernel.cpp new file mode 100644 index 00000000000..c48003d7891 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/FixedRadiusSearchOpKernel.cpp @@ -0,0 +1,67 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/core/nns/FixedRadiusSearchImpl.h" +#include "open3d/core/nns/NeighborSearchCommon.h" +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/misc/NeighborSearchAllocator.h" + +using namespace open3d::core::nns; + +template +void FixedRadiusSearchCPU(const paddle::Tensor& points, + const paddle::Tensor& queries, + double radius, + const paddle::Tensor& points_row_splits, + const paddle::Tensor& queries_row_splits, + const paddle::Tensor& hash_table_splits, + const paddle::Tensor& hash_table_index, + const paddle::Tensor& hash_table_cell_splits, + const Metric metric, + const bool ignore_query_point, + const bool return_distances, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_row_splits, + paddle::Tensor& neighbors_distance) { + NeighborSearchAllocator output_allocator(points.place()); + + impl::FixedRadiusSearchCPU( + neighbors_row_splits.data(), points.shape()[0], + points.data(), queries.shape()[0], queries.data(), T(radius), + points_row_splits.shape()[0], points_row_splits.data(), + queries_row_splits.shape()[0], queries_row_splits.data(), + reinterpret_cast( + const_cast(hash_table_splits.data())), + hash_table_cell_splits.shape()[0], + reinterpret_cast(const_cast( + hash_table_cell_splits.data())), + reinterpret_cast( + const_cast(hash_table_index.data())), + metric, ignore_query_point, return_distances, output_allocator); + + neighbors_index = output_allocator.NeighborsIndex(); + neighbors_distance = output_allocator.NeighborsDistance(); +} + +#define INSTANTIATE(T, TIndex) \ + template void FixedRadiusSearchCPU( \ + const paddle::Tensor& points, const paddle::Tensor& queries, \ + double radius, const paddle::Tensor& points_row_splits, \ + const paddle::Tensor& queries_row_splits, \ + const paddle::Tensor& hash_table_splits, \ + const paddle::Tensor& hash_table_index, \ + const paddle::Tensor& hash_table_cell_splits, const Metric metric, \ + const bool ignore_query_point, const bool return_distances, \ + paddle::Tensor& neighbors_index, \ + paddle::Tensor& neighbors_row_splits, \ + paddle::Tensor& neighbors_distance); + +INSTANTIATE(float, int32_t) +INSTANTIATE(float, int64_t) +INSTANTIATE(double, int32_t) +INSTANTIATE(double, int64_t) diff --git a/cpp/open3d/ml/paddle/misc/FixedRadiusSearchOpKernel.cu b/cpp/open3d/ml/paddle/misc/FixedRadiusSearchOpKernel.cu new file mode 100644 index 00000000000..3519d4783f9 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/FixedRadiusSearchOpKernel.cu @@ -0,0 +1,93 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/core/nns/FixedRadiusSearchImpl.cuh" +#include "open3d/core/nns/NeighborSearchCommon.h" +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/misc/NeighborSearchAllocator.h" + +using namespace open3d::core::nns; + +template +void FixedRadiusSearchCUDA(const paddle::Tensor& points, + const paddle::Tensor& queries, + double radius, + const paddle::Tensor& points_row_splits, + const paddle::Tensor& queries_row_splits, + const paddle::Tensor& hash_table_splits, + const paddle::Tensor& hash_table_index, + const paddle::Tensor& hash_table_cell_splits, + const Metric metric, + const bool ignore_query_point, + const bool return_distances, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_row_splits, + paddle::Tensor& neighbors_distance) { + auto stream = points.stream(); + auto cuda_place_props = phi::backends::gpu::GetDeviceProperties(-1); + const int texture_alignment = cuda_place_props.textureAlignment; + + auto place = points.place(); + + NeighborSearchAllocator output_allocator(place); + void* temp_ptr = nullptr; + size_t temp_size = 0; + + // determine temp_size + impl::FixedRadiusSearchCUDA( + stream, temp_ptr, temp_size, texture_alignment, + neighbors_row_splits.data(), points.shape()[0], + points.data(), queries.shape()[0], queries.data(), T(radius), + points_row_splits.shape()[0], points_row_splits.data(), + queries_row_splits.shape()[0], queries_row_splits.data(), + reinterpret_cast( + const_cast(hash_table_splits.data())), + hash_table_cell_splits.shape()[0], + reinterpret_cast(const_cast( + hash_table_cell_splits.data())), + reinterpret_cast( + const_cast(hash_table_index.data())), + metric, ignore_query_point, return_distances, output_allocator); + + auto temp_tensor = CreateTempTensor(temp_size, points.place(), &temp_ptr); + + // actually run the search + impl::FixedRadiusSearchCUDA( + stream, temp_ptr, temp_size, texture_alignment, + neighbors_row_splits.data(), points.shape()[0], + points.data(), queries.shape()[0], queries.data(), T(radius), + points_row_splits.shape()[0], points_row_splits.data(), + queries_row_splits.shape()[0], queries_row_splits.data(), + reinterpret_cast( + const_cast(hash_table_splits.data())), + hash_table_cell_splits.shape()[0], + reinterpret_cast(const_cast( + hash_table_cell_splits.data())), + reinterpret_cast( + const_cast(hash_table_index.data())), + metric, ignore_query_point, return_distances, output_allocator); + + neighbors_index = output_allocator.NeighborsIndex(); + neighbors_distance = output_allocator.NeighborsDistance(); +} + +#define INSTANTIATE(T, TIndex) \ + template void FixedRadiusSearchCUDA( \ + const paddle::Tensor& points, const paddle::Tensor& queries, \ + double radius, const paddle::Tensor& points_row_splits, \ + const paddle::Tensor& queries_row_splits, \ + const paddle::Tensor& hash_table_splits, \ + const paddle::Tensor& hash_table_index, \ + const paddle::Tensor& hash_table_cell_splits, const Metric metric, \ + const bool ignore_query_point, const bool return_distances, \ + paddle::Tensor& neighbors_index, \ + paddle::Tensor& neighbors_row_splits, \ + paddle::Tensor& neighbors_distance); + +INSTANTIATE(float, int32_t) +INSTANTIATE(float, int64_t) diff --git a/cpp/open3d/ml/paddle/misc/FixedRadiusSearchOps.cpp b/cpp/open3d/ml/paddle/misc/FixedRadiusSearchOps.cpp new file mode 100644 index 00000000000..26feb76bd4e --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/FixedRadiusSearchOps.cpp @@ -0,0 +1,190 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/core/Dtype.h" +#include "open3d/core/nns/NeighborSearchCommon.h" +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/utility/Helper.h" + +using namespace open3d::core::nns; + +template +void FixedRadiusSearchCPU(const paddle::Tensor& points, + const paddle::Tensor& queries, + double radius, + const paddle::Tensor& points_row_splits, + const paddle::Tensor& queries_row_splits, + const paddle::Tensor& hash_table_splits, + const paddle::Tensor& hash_table_index, + const paddle::Tensor& hash_table_cell_splits, + const Metric metric, + const bool ignore_query_point, + const bool return_distances, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_row_splits, + paddle::Tensor& neighbors_distance); +#ifdef BUILD_CUDA_MODULE +template +void FixedRadiusSearchCUDA(const paddle::Tensor& points, + const paddle::Tensor& queries, + double radius, + const paddle::Tensor& points_row_splits, + const paddle::Tensor& queries_row_splits, + const paddle::Tensor& hash_table_splits, + const paddle::Tensor& hash_table_index, + const paddle::Tensor& hash_table_cell_splits, + const Metric metric, + const bool ignore_query_point, + const bool return_distances, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_row_splits, + paddle::Tensor& neighbors_distance); +#endif + +std::vector FixedRadiusSearch( + paddle::Tensor& points, + paddle::Tensor& queries, + paddle::Tensor& points_row_splits, + paddle::Tensor& queries_row_splits, + paddle::Tensor& hash_table_splits, + paddle::Tensor& hash_table_index, + paddle::Tensor& hash_table_cell_splits, + double radius, + const std::string& index_dtype, + const std::string& metric_str, + const bool ignore_query_point, + const bool return_distances) { + Metric metric = L2; + if (metric_str == "L1") { + metric = L1; + } else if (metric_str == "L2") { + metric = L2; + } else if (metric_str == "Linf") { + metric = Linf; + } else { + PD_CHECK(false, + "metric must be one of (L1, L2, Linf) but got " + metric_str); + } + CHECK_TYPE(points_row_splits, paddle::DataType::INT64); + CHECK_TYPE(queries_row_splits, paddle::DataType::INT64); + CHECK_TYPE(hash_table_splits, paddle::DataType::INT32); + CHECK_TYPE(hash_table_index, paddle::DataType::INT32); + CHECK_TYPE(hash_table_cell_splits, paddle::DataType::INT32); + CHECK_SAME_DTYPE(points, queries); + CHECK_SAME_DEVICE_TYPE(points, queries); + // PD_CHECK(index_dtype == paddle::DataType::INT32 || index_dtype == + // paddle::DataType::INT64, + PD_CHECK(index_dtype == "int32" || index_dtype == "int64", + "index_dtype must be int32 or int64"); + // ensure that these are on the cpu + points_row_splits = points_row_splits.copy_to(paddle::CPUPlace(), false); + queries_row_splits = queries_row_splits.copy_to(paddle::CPUPlace(), false); + hash_table_splits = hash_table_splits.copy_to(paddle::CPUPlace(), false); + + // check input shapes + using namespace open3d::ml::op_util; + Dim num_points("num_points"); + Dim num_queries("num_queries"); + Dim batch_size("batch_size"); + Dim num_cells("num_cells"); + CHECK_SHAPE(points, num_points, 3); + CHECK_SHAPE(hash_table_index, num_points); + CHECK_SHAPE(queries, num_queries, 3); + CHECK_SHAPE(points_row_splits, batch_size + 1); + CHECK_SHAPE(queries_row_splits, batch_size + 1); + CHECK_SHAPE(hash_table_splits, batch_size + 1); + CHECK_SHAPE(hash_table_cell_splits, num_cells + 1); + + const auto& point_type = points.dtype(); + + auto place = points.place(); + + paddle::Tensor neighbors_index; + paddle::Tensor neighbors_row_splits = + paddle::empty({queries.shape()[0] + 1}, + paddle::DataType(ToPaddleDtype()), place); + paddle::Tensor neighbors_distance; + +#define FN_PARAMETERS \ + points, queries, radius, points_row_splits, queries_row_splits, \ + hash_table_splits, hash_table_index, hash_table_cell_splits, \ + metric, ignore_query_point, return_distances, neighbors_index, \ + neighbors_row_splits, neighbors_distance + + if (points.is_gpu()) { +#ifdef BUILD_CUDA_MODULE + // pass to cuda function + if (ComparePaddleDtype(point_type)) { + if (index_dtype == "int32") { + FixedRadiusSearchCUDA(FN_PARAMETERS); + } else { + FixedRadiusSearchCUDA(FN_PARAMETERS); + } + return {neighbors_index, neighbors_row_splits, neighbors_distance}; + } +#else + PD_CHECK(false, "FixedRadiusSearch was not compiled with CUDA support"); +#endif + } else { + if (ComparePaddleDtype(point_type)) { + if (index_dtype == "int32") { + FixedRadiusSearchCPU(FN_PARAMETERS); + } else { + FixedRadiusSearchCPU(FN_PARAMETERS); + } + } else { + if (index_dtype == "int32") { + FixedRadiusSearchCPU(FN_PARAMETERS); + } else { + FixedRadiusSearchCPU(FN_PARAMETERS); + } + } + return {neighbors_index, neighbors_row_splits, neighbors_distance}; + } + + // in torch the name is ToString, but paddle not have this function + PD_CHECK(false, "FixedRadiusSearch does not support " + + phi::DataTypeToString(points.dtype()) + + " as input for points"); + return std::vector(); +} + +std::vector FixedRadiusSearchInferDtype( + const std::string& index_dtype) { + paddle::DataType dtype = index_dtype == "int32" ? paddle::DataType::INT32 + : paddle::DataType::INT64; + return {dtype, paddle::DataType::INT64, dtype}; +} + +std::vector> FixedRadiusSearchInferShape( + std::vector queries_shape, const bool return_distances) { + // this just a temp impl , all return is fake data + // TODO(woodman3): impl real data + int64_t neighbors_row_splits_shape = queries_shape[0] + 1; + int64_t neighbors_distance_shape = return_distances ? 1 : 0; + return {{neighbors_row_splits_shape}, + {neighbors_row_splits_shape}, + {neighbors_distance_shape}}; +} + +PD_BUILD_OP(open3d_fixed_radius_search) + .Inputs({"points", "queries", "points_row_splits", "queries_row_splits", + "hash_table_splits", "hash_table_index", + "hash_table_cell_splits"}) + .Outputs({"neighbors_index", "neighbors_row_splits", + "neighbors_distance"}) + .Attrs({ + "radius: double", + "index_dtype:std::string", + "metric_str: std::string", + "ignore_query_point: bool", + "return_distances: bool", + }) + .SetKernelFn(PD_KERNEL(FixedRadiusSearch)) + .SetInferShapeFn(PD_INFER_SHAPE(FixedRadiusSearchInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(FixedRadiusSearchInferDtype)); diff --git a/cpp/open3d/ml/paddle/misc/InvertNeighborsListOpKernel.cpp b/cpp/open3d/ml/paddle/misc/InvertNeighborsListOpKernel.cpp new file mode 100644 index 00000000000..597167f0cb6 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/InvertNeighborsListOpKernel.cpp @@ -0,0 +1,67 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/ml/paddle/misc/InvertNeighborsListOpKernel.h" + +#include "open3d/ml/impl/misc/InvertNeighborsList.h" +#include "open3d/ml/paddle/PaddleHelper.h" + +template +std::vector InvertNeighborsListCPU( + int64_t num_points, + const paddle::Tensor& inp_neighbors_index, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& inp_neighbors_attributes) { + paddle::Tensor neighbors_index = + paddle::empty(inp_neighbors_index.shape(), + paddle::DataType(ToPaddleDtype())); + paddle::Tensor neighbors_row_splits = paddle::empty( + {num_points + 1}, paddle::DataType(paddle::DataType::INT64)); + + paddle::Tensor neighbors_attributes = + paddle::empty_like(inp_neighbors_attributes); + + int num_attributes; + if (inp_neighbors_attributes.shape()[0] == 0) { + num_attributes = 0; + neighbors_attributes = + InitializedEmptyTensor(inp_neighbors_attributes.dtype(), + inp_neighbors_attributes.shape(), + inp_neighbors_attributes.place()); + + } else { + num_attributes = 1; + for (size_t i = 1; i < inp_neighbors_attributes.shape().size(); ++i) + num_attributes *= inp_neighbors_attributes.shape()[i]; + } + + open3d::ml::impl::InvertNeighborsListCPU( + inp_neighbors_index.data(), + num_attributes ? inp_neighbors_attributes.data() : nullptr, + num_attributes, inp_neighbors_row_splits.data(), + inp_neighbors_row_splits.shape()[0] - 1, + neighbors_index.data(), + num_attributes ? neighbors_attributes.data() : nullptr, + neighbors_index.shape()[0], neighbors_row_splits.data(), + neighbors_row_splits.shape()[0] - 1); + + return {neighbors_index, neighbors_row_splits, neighbors_attributes}; +} +#define INSTANTIATE(TIndex, TAttr) \ + template std::vector \ + InvertNeighborsListCPU(int64_t, const paddle::Tensor&, \ + const paddle::Tensor&, \ + const paddle::Tensor&); + +INSTANTIATE(int32_t, uint8_t) +INSTANTIATE(int32_t, int8_t) +INSTANTIATE(int32_t, int16_t) +INSTANTIATE(int32_t, int32_t) +INSTANTIATE(int32_t, int64_t) +INSTANTIATE(int32_t, float) +INSTANTIATE(int32_t, double) diff --git a/cpp/open3d/ml/paddle/misc/InvertNeighborsListOpKernel.cu b/cpp/open3d/ml/paddle/misc/InvertNeighborsListOpKernel.cu new file mode 100644 index 00000000000..276797a79e7 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/InvertNeighborsListOpKernel.cu @@ -0,0 +1,91 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/ml/impl/misc/InvertNeighborsList.cuh" +#include "open3d/ml/paddle/PaddleHelper.h" + +template +std::vector InvertNeighborsListCUDA( + int64_t num_points, + const paddle::Tensor& inp_neighbors_index, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& inp_neighbors_attributes) { + auto place = inp_neighbors_index.place(); + paddle::Tensor neighbors_index = + paddle::empty(inp_neighbors_index.shape(), + paddle::DataType(ToPaddleDtype()), place); + paddle::Tensor neighbors_row_splits = paddle::empty( + {num_points + 1}, paddle::DataType(paddle::DataType::INT64), place); + paddle::Tensor neighbors_attributes = + paddle::empty_like(inp_neighbors_attributes); + + // maybe this can use torch's impl way ? + auto stream = inp_neighbors_index.stream(); + auto cuda_place_props = phi::backends::gpu::GetDeviceProperties(-1); + const int texture_alignment = cuda_place_props.textureAlignment; + + int num_attributes; + if (inp_neighbors_attributes.shape()[0] == 0) { + std::cout << inp_neighbors_attributes.dtype() << std::endl; + num_attributes = 0; + neighbors_attributes = + InitializedEmptyTensor(inp_neighbors_attributes.dtype(), + inp_neighbors_attributes.shape(), + inp_neighbors_attributes.place()); + } else { + num_attributes = 1; + for (int i = 1; i < inp_neighbors_attributes.dims().size(); ++i) + num_attributes *= inp_neighbors_attributes.shape()[i]; + } + + void* temp_ptr = nullptr; + size_t temp_size = 0; + + // determine temp_size + open3d::ml::impl::InvertNeighborsListCUDA( + stream, temp_ptr, temp_size, texture_alignment, + inp_neighbors_index.data(), + num_attributes ? inp_neighbors_attributes.data() : nullptr, + num_attributes, inp_neighbors_row_splits.data(), + inp_neighbors_row_splits.shape()[0] - 1, + neighbors_index.data(), + num_attributes ? neighbors_attributes.data() : nullptr, + neighbors_index.shape()[0], + neighbors_row_splits.data(), // NOLINT + neighbors_row_splits.shape()[0] - 1); + + auto temp_tensor = CreateTempTensor(temp_size, place, &temp_ptr); + + // actually invert the list + open3d::ml::impl::InvertNeighborsListCUDA( + stream, temp_ptr, temp_size, texture_alignment, + inp_neighbors_index.data(), + num_attributes ? inp_neighbors_attributes.data() : nullptr, + num_attributes, inp_neighbors_row_splits.data(), + inp_neighbors_row_splits.shape()[0] - 1, + neighbors_index.data(), + num_attributes ? neighbors_attributes.data() : nullptr, + neighbors_index.shape()[0], + neighbors_row_splits.data(), // NOLINT + neighbors_row_splits.shape()[0] - 1); + + return {neighbors_index, neighbors_row_splits, neighbors_attributes}; +} +#define INSTANTIATE(TIndex, TAttr) \ + template std::vector \ + InvertNeighborsListCUDA(int64_t, const paddle::Tensor&, \ + const paddle::Tensor&, \ + const paddle::Tensor&); + +INSTANTIATE(int32_t, uint8_t) +INSTANTIATE(int32_t, int8_t) +INSTANTIATE(int32_t, int16_t) +INSTANTIATE(int32_t, int32_t) +INSTANTIATE(int32_t, int64_t) +INSTANTIATE(int32_t, float) +INSTANTIATE(int32_t, double) diff --git a/cpp/open3d/ml/paddle/misc/InvertNeighborsListOpKernel.h b/cpp/open3d/ml/paddle/misc/InvertNeighborsListOpKernel.h new file mode 100644 index 00000000000..f97500abc55 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/InvertNeighborsListOpKernel.h @@ -0,0 +1,26 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// +#pragma once + +#include "open3d/ml/paddle/PaddleHelper.h" + +template +std::vector InvertNeighborsListCPU( + int64_t num_points, + const paddle::Tensor& inp_neighbors_index, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& inp_neighbors_attributes); + +#ifdef BUILD_CUDA_MODULE +template +std::vector InvertNeighborsListCUDA( + int64_t num_points, + const paddle::Tensor& inp_neighbors_index, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& inp_neighbors_attributes); +#endif diff --git a/cpp/open3d/ml/paddle/misc/InvertNeighborsListOps.cpp b/cpp/open3d/ml/paddle/misc/InvertNeighborsListOps.cpp new file mode 100644 index 00000000000..f9717b86e91 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/InvertNeighborsListOps.cpp @@ -0,0 +1,102 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/misc/InvertNeighborsListOpKernel.h" + +std::vector InvertNeighborsList( + paddle::Tensor& inp_neighbors_index, + paddle::Tensor& inp_neighbors_row_splits, + paddle::Tensor& inp_neighbors_attributes, + int64_t num_points) { + CHECK_TYPE(inp_neighbors_row_splits, paddle::DataType::INT64); + + // check input shapes + { + using namespace open3d::ml::op_util; + Dim num_neighbors("num_neighbors"); + + CHECK_SHAPE(inp_neighbors_index, num_neighbors); + CHECK_SHAPE_IGNORE_LAST_DIMS(inp_neighbors_attributes, + num_neighbors || 0); + CHECK_SHAPE(inp_neighbors_row_splits, Dim()); + } + + const auto& index_type = inp_neighbors_index.dtype(); + const auto& attr_type = inp_neighbors_attributes.dtype(); + +#define FN_PARAMETERS \ + num_points, inp_neighbors_index, inp_neighbors_row_splits, \ + inp_neighbors_attributes + +#define CALL(idx_t, attr_t, fn) \ + if (ComparePaddleDtype(index_type) && \ + ComparePaddleDtype(attr_type)) { \ + return fn(FN_PARAMETERS); \ + } + + CHECK_SAME_DEVICE_TYPE(inp_neighbors_index, inp_neighbors_row_splits, + inp_neighbors_attributes); + if (inp_neighbors_index.is_gpu()) { +#ifdef BUILD_CUDA_MODULE + // pass to cuda function + CALL(int32_t, uint8_t, InvertNeighborsListCUDA) + CALL(int32_t, int8_t, InvertNeighborsListCUDA) + CALL(int32_t, int16_t, InvertNeighborsListCUDA) + CALL(int32_t, int32_t, InvertNeighborsListCUDA) + CALL(int32_t, int64_t, InvertNeighborsListCUDA) + CALL(int32_t, float, InvertNeighborsListCUDA) + CALL(int32_t, double, InvertNeighborsListCUDA) +#else + PD_CHECK(false, + "InvertNeighborsList was not compiled with CUDA support"); +#endif + } else { + CALL(int32_t, uint8_t, InvertNeighborsListCPU) + CALL(int32_t, int8_t, InvertNeighborsListCPU) + CALL(int32_t, int16_t, InvertNeighborsListCPU) + CALL(int32_t, int32_t, InvertNeighborsListCPU) + CALL(int32_t, int64_t, InvertNeighborsListCPU) + CALL(int32_t, float, InvertNeighborsListCPU) + CALL(int32_t, double, InvertNeighborsListCPU) + } + + PD_CHECK(false, + "InvertNeighborsList does not support " + + phi::DataTypeToString(inp_neighbors_index.dtype()) + + " as input for inp_neighbors_index and " + + phi::DataTypeToString(inp_neighbors_attributes.dtype()) + + " as input for inp_neighbors_attributes"); + return {}; +} + +std::vector InvertNeighborsListInferDtype( + const paddle::DataType inp_neighbors_attributes_dtype) { + return {paddle::DataType::INT32, paddle::DataType::INT64, + inp_neighbors_attributes_dtype}; +} + +std::vector> InvertNeighborsListInferShape( + int64_t num_points, + std::vector inp_neighbors_index_shape, + std::vector inp_neighbors_attributes_shape) { + return {inp_neighbors_index_shape, + {num_points + 1}, + inp_neighbors_attributes_shape}; +} +PD_BUILD_OP(open3d_invert_neighbors_list) + .Inputs({"inp_neighbors_index", "inp_neighbors_row_splits", + "inp_neighbors_attributes"}) + .Outputs({"neighbors_index", "neighbors_row_splits", + "neighbors_attributes"}) + .Attrs({"num_points: int64_t"}) + .SetKernelFn(PD_KERNEL(InvertNeighborsList)) + .SetInferShapeFn(PD_INFER_SHAPE(InvertNeighborsListInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(InvertNeighborsListInferDtype)); diff --git a/cpp/open3d/ml/paddle/misc/InvertNeighborsListOps.h b/cpp/open3d/ml/paddle/misc/InvertNeighborsListOps.h new file mode 100644 index 00000000000..f0d836bbf88 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/InvertNeighborsListOps.h @@ -0,0 +1,18 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +// this file seem not use +#pragma once + +#include "open3d/ml/paddle/PaddleHelper.h" + +std::vector InvertNeighborsList( + int64_t num_points, + const paddle::Tensor& inp_neighbors_index, + const paddle::Tensor& inp_neighbors_row_splits, + const paddle::Tensor& inp_neighbors_attributes); diff --git a/cpp/open3d/ml/paddle/misc/KnnSearchOpKernel.cpp b/cpp/open3d/ml/paddle/misc/KnnSearchOpKernel.cpp new file mode 100644 index 00000000000..ebc90a07a5e --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/KnnSearchOpKernel.cpp @@ -0,0 +1,120 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/core/nns/NanoFlannImpl.h" +#include "open3d/core/nns/NeighborSearchCommon.h" +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/misc/NeighborSearchAllocator.h" + +using namespace open3d::core::nns; + +template +void KnnSearchCPU(const paddle::Tensor& points, + const paddle::Tensor& queries, + const int64_t k, + const paddle::Tensor& points_row_splits, + const paddle::Tensor& queries_row_splits, + const Metric metric, + const bool ignore_query_point, + const bool return_distances, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_row_splits, + paddle::Tensor& neighbors_distance) { + const int batch_size = points_row_splits.shape()[0] - 1; + + // run radius search for each batch item + std::vector> batch_output_allocators( + batch_size, NeighborSearchAllocator(points.place()) + + ); + int64_t last_neighbors_count = 0; + for (int i = 0; i < batch_size; ++i) { + const T* const points_i = + points.data() + 3 * points_row_splits.data()[i]; + const T* const queries_i = + queries.data() + 3 * queries_row_splits.data()[i]; + size_t num_points_i = points_row_splits.data()[i + 1] - + points_row_splits.data()[i]; + size_t num_queries_i = queries_row_splits.data()[i + 1] - + queries_row_splits.data()[i]; + + int64_t* neighbors_row_splits_i = neighbors_row_splits.data() + + queries_row_splits.data()[i]; + + std::unique_ptr holder = + impl::BuildKdTree(num_points_i, points_i, 3, metric); + + impl::KnnSearchCPU( + holder.get(), neighbors_row_splits_i, num_points_i, points_i, + num_queries_i, queries_i, 3, k, metric, ignore_query_point, + return_distances, batch_output_allocators[i]); + + if (i > 0) { + for (size_t j = 0; j <= num_queries_i; ++j) + neighbors_row_splits_i[j] += last_neighbors_count; + } + last_neighbors_count = neighbors_row_splits_i[num_queries_i]; + } + + if (batch_size == 1) { + // no need to combine just return the results from the first batch item + neighbors_index = batch_output_allocators[0].NeighborsIndex(); + neighbors_distance = batch_output_allocators[0].NeighborsDistance(); + return; + } + + NeighborSearchAllocator output_allocator(points.place()); + + // combine results + int64_t neighbors_index_size = 0; + int64_t neighbors_distance_size = 0; + for (const auto& a : batch_output_allocators) { + neighbors_index_size += a.NeighborsIndex().shape()[0]; + neighbors_distance_size += a.NeighborsDistance().shape()[0]; + } + TIndex* neighbors_index_data_ptr; + T* neighbors_distance_data_ptr; + output_allocator.AllocIndices(&neighbors_index_data_ptr, + neighbors_index_size); + output_allocator.AllocDistances(&neighbors_distance_data_ptr, + neighbors_distance_size); + + for (int i = 0; i < batch_size; ++i) { + auto& a = batch_output_allocators[i]; + if (a.NeighborsIndex().shape()[0]) { + for (int64_t j = 0; j < a.NeighborsIndex().shape()[0]; ++j) { + neighbors_index_data_ptr[0] = + a.IndicesPtr()[j] + + points_row_splits.data()[i]; + ++neighbors_index_data_ptr; + } + } + if (a.NeighborsDistance().shape()[0]) { + memcpy(neighbors_distance_data_ptr, a.DistancesPtr(), + a.NeighborsDistance().shape()[0] * sizeof(T)); + neighbors_distance_data_ptr += a.NeighborsDistance().shape()[0]; + } + } + neighbors_index = output_allocator.NeighborsIndex(); + neighbors_distance = output_allocator.NeighborsDistance(); +} + +#define INSTANTIATE(T, TIndex) \ + template void KnnSearchCPU( \ + const paddle::Tensor& points, const paddle::Tensor& queries, \ + const int64_t k, const paddle::Tensor& points_row_splits, \ + const paddle::Tensor& queries_row_splits, const Metric metric, \ + const bool ignore_query_point, const bool return_distances, \ + paddle::Tensor& neighbors_index, \ + paddle::Tensor& neighbors_row_splits, \ + paddle::Tensor& neighbors_distance); + +INSTANTIATE(float, int32_t) +INSTANTIATE(float, int64_t) +INSTANTIATE(double, int32_t) +INSTANTIATE(double, int64_t) diff --git a/cpp/open3d/ml/paddle/misc/KnnSearchOps.cpp b/cpp/open3d/ml/paddle/misc/KnnSearchOps.cpp new file mode 100644 index 00000000000..e1d9f143253 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/KnnSearchOps.cpp @@ -0,0 +1,140 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include "open3d/core/Dtype.h" +#include "open3d/core/nns/NeighborSearchCommon.h" +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/utility/Helper.h" +#include "paddle/extension.h" + +using namespace open3d::core::nns; + +template +void KnnSearchCPU(const paddle::Tensor& points, + const paddle::Tensor& queries, + const int64_t k, + const paddle::Tensor& points_row_splits, + const paddle::Tensor& queries_row_splits, + const Metric metric, + const bool ignore_query_point, + const bool return_distances, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_row_splits, + paddle::Tensor& neighbors_distance); + +std::vector KnnSearch(paddle::Tensor& points, + paddle::Tensor& queries, + paddle::Tensor& points_row_splits, + paddle::Tensor& queries_row_splits, + const int64_t k, + const std::string& index_dtype, + const std::string& metric_str, + const bool ignore_query_point, + const bool return_distances) { + Metric metric = L2; + if (metric_str == "L1") { + metric = L1; + } else if (metric_str == "L2") { + metric = L2; + } else { + PD_CHECK(false, "metric must be one of (L1, L2) but got " + metric_str); + } + PD_CHECK(k > 0, "k must be greater than zero"); + CHECK_TYPE(points_row_splits, phi::DataType::INT64); + CHECK_TYPE(queries_row_splits, phi::DataType::INT64); + CHECK_SAME_DTYPE(points, queries); + CHECK_SAME_DEVICE_TYPE(points, queries); + PD_CHECK(index_dtype == "int32" || index_dtype == "int64", + "index_dtype must be int32 or int64"); + // ensure that these are on the cpu + points_row_splits = points_row_splits.copy_to(paddle::CPUPlace(), false); + queries_row_splits = queries_row_splits.copy_to(paddle::CPUPlace(), false); + + // check input shapes + using namespace open3d::ml::op_util; + Dim num_points("num_points"); + Dim num_queries("num_queries"); + Dim batch_size("batch_size"); + Dim num_cells("num_cells"); + CHECK_SHAPE(points, num_points, 3); + CHECK_SHAPE(queries, num_queries, 3); + CHECK_SHAPE(points_row_splits, batch_size + 1); + CHECK_SHAPE(queries_row_splits, batch_size + 1); + + const auto& point_type = points.dtype(); + + auto place = points.place(); + + paddle::Tensor neighbors_index; + paddle::Tensor neighbors_row_splits = + paddle::empty({queries.shape()[0] + 1}, + paddle::DataType(ToPaddleDtype()), place); + paddle::Tensor neighbors_distance; + +#define FN_PARAMETERS \ + points, queries, k, points_row_splits, queries_row_splits, metric, \ + ignore_query_point, return_distances, neighbors_index, \ + neighbors_row_splits, neighbors_distance + + if (points.is_gpu()) { + PD_CHECK(false, "KnnSearch does not support CUDA"); + } else { + if (ComparePaddleDtype(point_type)) { + if (index_dtype == "int32") { + KnnSearchCPU(FN_PARAMETERS); + } else { + KnnSearchCPU(FN_PARAMETERS); + } + } else { + if (index_dtype == "int32") { + KnnSearchCPU(FN_PARAMETERS); + } else { + KnnSearchCPU(FN_PARAMETERS); + } + } + return {neighbors_index, neighbors_row_splits, neighbors_distance}; + } + PD_CHECK(false, "KnnSearch does not support " + + phi::DataTypeToString(points.dtype()) + + " as input for points"); + return std::vector(); +} + +std::vector KnnSearchInferDtype( + const std::string& index_dtype) { + paddle::DataType dtype = index_dtype == "int32" ? paddle::DataType::INT32 + : paddle::DataType::INT64; + return {dtype, paddle::DataType::INT64, dtype}; +} + +std::vector> KnnSearchInferShape( + std::vector queries_shape, const bool return_distances) { + int64_t neighbors_row_splits_shape = queries_shape[0] + 1; + int64_t neighbors_distance_shape = return_distances ? 1 : 0; + return {{neighbors_row_splits_shape}, + {neighbors_row_splits_shape}, + {neighbors_distance_shape}}; +} + +PD_BUILD_OP(open3d_knn_search) + .Inputs({"points", "queries", "points_row_splits", + "queries_row_splits"}) + .Outputs({"neighbors_index", "neighbors_row_splits", + "neighbors_distance"}) + .Attrs({ + "k: int64_t", + "index_dtype:std::string", + "metric_str: std::string", + "ignore_query_point: bool", + "return_distances: bool", + }) + .SetKernelFn(PD_KERNEL(KnnSearch)) + .SetInferShapeFn(PD_INFER_SHAPE(KnnSearchInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(KnnSearchInferDtype)); diff --git a/cpp/open3d/ml/paddle/misc/NeighborSearchAllocator.h b/cpp/open3d/ml/paddle/misc/NeighborSearchAllocator.h new file mode 100644 index 00000000000..05f631ba56e --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/NeighborSearchAllocator.h @@ -0,0 +1,54 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/ml/paddle/PaddleHelper.h" + +// These classes implement functors that can be passed to the neighbor search +// functions. + +template +class NeighborSearchAllocator { +public: + NeighborSearchAllocator(paddle::Place place) : place(place) {} + + void AllocIndices(TIndex** ptr, size_t num) { + if (num == 0) { + neighbors_index = InitializedEmptyTensor({0}, place); + } else { + neighbors_index = paddle::empty( + {int64_t(num)}, paddle::DataType(ToPaddleDtype()), + place); + } + *ptr = neighbors_index.data(); + } + + void AllocDistances(T** ptr, size_t num) { + if (num == 0) { + neighbors_distance = InitializedEmptyTensor({0}, place); + } else { + neighbors_distance = + paddle::empty({int64_t(num)}, + paddle::DataType(ToPaddleDtype()), place); + } + *ptr = neighbors_distance.data(); + } + + const TIndex* IndicesPtr() const { return neighbors_index.data(); } + + const T* DistancesPtr() const { return neighbors_distance.data(); } + + const paddle::Tensor& NeighborsIndex() const { return neighbors_index; } + const paddle::Tensor& NeighborsDistance() const { + return neighbors_distance; + } + +private: + paddle::Tensor neighbors_index; + paddle::Tensor neighbors_distance; + paddle::Place place; +}; diff --git a/cpp/open3d/ml/paddle/misc/RadiusSearchOpKernel.cpp b/cpp/open3d/ml/paddle/misc/RadiusSearchOpKernel.cpp new file mode 100644 index 00000000000..be25cb4c847 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/RadiusSearchOpKernel.cpp @@ -0,0 +1,125 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/core/nns/NanoFlannImpl.h" +#include "open3d/core/nns/NeighborSearchCommon.h" +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/misc/NeighborSearchAllocator.h" + +using namespace open3d::core::nns; + +template +void RadiusSearchCPU(const paddle::Tensor& points, + const paddle::Tensor& queries, + const paddle::Tensor& radii, + const paddle::Tensor& points_row_splits, + const paddle::Tensor& queries_row_splits, + const Metric metric, + const bool ignore_query_point, + const bool return_distances, + const bool normalize_distances, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_row_splits, + paddle::Tensor& neighbors_distance) { + const int batch_size = points_row_splits.shape()[0] - 1; + // run radius search for each batch item + std::vector> batch_output_allocators( + batch_size, NeighborSearchAllocator(points.place()) + + ); + int64_t last_neighbors_count = 0; + for (int i = 0; i < batch_size; ++i) { + const T* const points_i = + points.data() + 3 * points_row_splits.data()[i]; + const T* const queries_i = + queries.data() + 3 * queries_row_splits.data()[i]; + const T* const radius_i = + radii.data() + queries_row_splits.data()[i]; + size_t num_points_i = points_row_splits.data()[i + 1] - + points_row_splits.data()[i]; + size_t num_queries_i = queries_row_splits.data()[i + 1] - + queries_row_splits.data()[i]; + + int64_t* neighbors_row_splits_i = neighbors_row_splits.data() + + queries_row_splits.data()[i]; + + std::unique_ptr holder = + impl::BuildKdTree(num_points_i, points_i, 3, metric); + + impl::RadiusSearchCPU( + holder.get(), neighbors_row_splits_i, num_points_i, points_i, + num_queries_i, queries_i, 3, radius_i, metric, + ignore_query_point, return_distances, normalize_distances, + /* sort */ false, batch_output_allocators[i]); + + if (i > 0) { + for (size_t j = 0; j <= num_queries_i; ++j) + neighbors_row_splits_i[j] += last_neighbors_count; + } + last_neighbors_count = neighbors_row_splits_i[num_queries_i]; + } + + if (batch_size == 1) { + // no need to combine just return the results from the first batch + // item + neighbors_index = batch_output_allocators[0].NeighborsIndex(); + neighbors_distance = batch_output_allocators[0].NeighborsDistance(); + return; + } + + NeighborSearchAllocator output_allocator(points.place()); + + // combine results + int64_t neighbors_index_size = 0; + int64_t neighbors_distance_size = 0; + for (const auto& a : batch_output_allocators) { + neighbors_index_size += a.NeighborsIndex().shape()[0]; + neighbors_distance_size += a.NeighborsDistance().shape()[0]; + } + TIndex* neighbors_index_data_ptr; + T* neighbors_distance_data_ptr; + output_allocator.AllocIndices(&neighbors_index_data_ptr, + neighbors_index_size); + output_allocator.AllocDistances(&neighbors_distance_data_ptr, + neighbors_distance_size); + + for (int i = 0; i < batch_size; ++i) { + const auto& a = batch_output_allocators[i]; + if (a.NeighborsIndex().shape()[0]) { + for (int64_t j = 0; j < a.NeighborsIndex().shape()[0]; ++j) { + neighbors_index_data_ptr[0] = + a.IndicesPtr()[j] + + points_row_splits.data()[i]; + ++neighbors_index_data_ptr; + } + } + if (a.NeighborsDistance().shape()[0]) { + memcpy(neighbors_distance_data_ptr, a.DistancesPtr(), + a.NeighborsDistance().shape()[0] * sizeof(T)); + neighbors_distance_data_ptr += a.NeighborsDistance().shape()[0]; + } + } + neighbors_index = output_allocator.NeighborsIndex(); + neighbors_distance = output_allocator.NeighborsDistance(); +} + +#define INSTANTIATE(T, TIndex) \ + template void RadiusSearchCPU( \ + const paddle::Tensor& points, const paddle::Tensor& queries, \ + const paddle::Tensor& radii, \ + const paddle::Tensor& points_row_splits, \ + const paddle::Tensor& queries_row_splits, const Metric metric, \ + const bool ignore_query_point, const bool return_distances, \ + const bool normalize_distances, paddle::Tensor& neighbors_index, \ + paddle::Tensor& neighbors_row_splits, \ + paddle::Tensor& neighbors_distance); + +INSTANTIATE(float, int32_t) +INSTANTIATE(float, int64_t) +INSTANTIATE(double, int32_t) +INSTANTIATE(double, int64_t) diff --git a/cpp/open3d/ml/paddle/misc/RadiusSearchOps.cpp b/cpp/open3d/ml/paddle/misc/RadiusSearchOps.cpp new file mode 100644 index 00000000000..4f48057422c --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/RadiusSearchOps.cpp @@ -0,0 +1,141 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include "open3d/core/Dtype.h" +#include "open3d/core/nns/NeighborSearchCommon.h" +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/utility/Helper.h" + +using namespace open3d::core::nns; + +template +void RadiusSearchCPU(const paddle::Tensor& points, + const paddle::Tensor& queries, + const paddle::Tensor& radii, + const paddle::Tensor& points_row_splits, + const paddle::Tensor& queries_row_splits, + const Metric metric, + const bool ignore_query_point, + const bool return_distances, + const bool normalize_distances, + paddle::Tensor& neighbors_index, + paddle::Tensor& neighbors_row_splits, + paddle::Tensor& neighbors_distance); + +std::vector MultiRadiusSearch( + paddle::Tensor& points, + paddle::Tensor& queries, + paddle::Tensor& radii, + paddle::Tensor& points_row_splits, + paddle::Tensor& queries_row_splits, + const std::string& index_dtype, + const std::string& metric_str, + const bool ignore_query_point, + const bool return_distances, + const bool normalize_distances) { + Metric metric = L2; + if (metric_str == "L1") { + metric = L1; + } else if (metric_str == "L2") { + metric = L2; + } else { + PD_CHECK(false, "metric must be one of (L1, L2) but got " + metric_str); + } + CHECK_TYPE(points_row_splits, paddle::DataType::INT64); + CHECK_TYPE(queries_row_splits, paddle::DataType::INT64); + CHECK_SAME_DTYPE(points, queries, radii); + CHECK_SAME_DEVICE_TYPE(points, queries, radii); + PD_CHECK(index_dtype == "int32" || index_dtype == "int64", + "index_dtype must be int32 or int64"); + // ensure that these are on the cpu + points_row_splits = points_row_splits.copy_to(paddle::CPUPlace(), false); + queries_row_splits = queries_row_splits.copy_to(paddle::CPUPlace(), false); + + // check input shapes + using namespace open3d::ml::op_util; + Dim num_points("num_points"); + Dim num_queries("num_queries"); + Dim batch_size("batch_size"); + Dim num_cells("num_cells"); + CHECK_SHAPE(points, num_points, 3); + CHECK_SHAPE(queries, num_queries, 3); + CHECK_SHAPE(radii, num_queries); + CHECK_SHAPE(points_row_splits, batch_size + 1); + CHECK_SHAPE(queries_row_splits, batch_size + 1); + + const auto& point_type = points.dtype(); + + auto place = points.place(); + + paddle::Tensor neighbors_index; + paddle::Tensor neighbors_row_splits = + paddle::empty({queries.shape()[0] + 1}, + paddle::DataType(ToPaddleDtype()), place); + paddle::Tensor neighbors_distance; + +#define FN_PARAMETERS \ + points, queries, radii, points_row_splits, queries_row_splits, metric, \ + ignore_query_point, return_distances, normalize_distances, \ + neighbors_index, neighbors_row_splits, neighbors_distance + + if (points.is_gpu()) { + PD_CHECK(false, "MultiRadiusSearch does not support CUDA"); + } else { + if (ComparePaddleDtype(point_type)) { + if (index_dtype == "int32") { + RadiusSearchCPU(FN_PARAMETERS); + } else { + RadiusSearchCPU(FN_PARAMETERS); + } + } else { + if (index_dtype == "int32") { + RadiusSearchCPU(FN_PARAMETERS); + } else { + RadiusSearchCPU(FN_PARAMETERS); + } + } + return {neighbors_index, neighbors_row_splits, neighbors_distance}; + } + // same question of fixed_radius_search + PD_CHECK(false, "MultiRadiusSearch does not support " + + phi::DataTypeToString(points.dtype()) + + " as input for points"); + return {neighbors_index, neighbors_row_splits, neighbors_distance}; +} + +std::vector MultiRadiusSearchInferDtype( + const std::string& index_dtype) { + paddle::DataType dtype = index_dtype == "int32" ? paddle::DataType::INT32 + : paddle::DataType::INT64; + return {dtype, paddle::DataType::INT64, dtype}; +} + +std::vector> MultiRadiusSearchInferShape( + std::vector queries_shape, const bool return_distances) { + // this just a temp impl , all return is fake data + // TODO(woodman3): impl real data + int64_t neighbors_row_splits_shape = queries_shape[0] + 1; + int64_t neighbors_distance_shape = return_distances ? 1 : 0; + return {{neighbors_row_splits_shape}, + {neighbors_row_splits_shape}, + {neighbors_distance_shape}}; +} + +PD_BUILD_OP(open3d_radius_search) + .Inputs({"points", "queries", "radii", "points_row_splits", + "queries_row_splits"}) + .Outputs({"neighbors_index", "neighbors_row_splits", + "neighbors_distance"}) + .Attrs({"index_dtype: std::string", "metric_str: std::string", + "ignore_query_point: bool", "return_distances: bool", + "normalize_distances: bool"}) + .SetKernelFn(PD_KERNEL(MultiRadiusSearch)) + .SetInferShapeFn(PD_INFER_SHAPE(MultiRadiusSearchInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(MultiRadiusSearchInferDtype)); diff --git a/cpp/open3d/ml/paddle/misc/RaggedToDenseOpKernel.cpp b/cpp/open3d/ml/paddle/misc/RaggedToDenseOpKernel.cpp new file mode 100644 index 00000000000..64fec24034a --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/RaggedToDenseOpKernel.cpp @@ -0,0 +1,45 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/ml/paddle/misc/RaggedToDenseOpKernel.h" + +#include "open3d/ml/impl/misc/RaggedToDense.h" +#include "open3d/ml/paddle/PaddleHelper.h" + +template +paddle::Tensor RaggedToDenseCPU(const paddle::Tensor& values, + const paddle::Tensor& row_splits, + const int64_t out_col_size, + const paddle::Tensor& default_value) { + auto out_shape = values.shape(); + out_shape.erase(out_shape.begin()); + out_shape.insert(out_shape.begin(), + {row_splits.shape()[0] - 1, out_col_size}); + paddle::Tensor out = + paddle::empty(out_shape, paddle::DataType(ToPaddleDtype())); + + open3d::ml::impl::RaggedToDenseCPU( + values.data(), row_splits.data(), row_splits.shape()[0], + out_col_size, default_value.data(), default_value.numel(), + out.data()); + + return out; +} + +#define INSTANTIATE(T) \ + template paddle::Tensor RaggedToDenseCPU( \ + const paddle::Tensor&, const paddle::Tensor&, const int64_t, \ + const paddle::Tensor&); + +INSTANTIATE(uint8_t) +INSTANTIATE(int8_t) +INSTANTIATE(int16_t) +INSTANTIATE(int32_t) +INSTANTIATE(int64_t) +INSTANTIATE(float) +INSTANTIATE(double) diff --git a/cpp/open3d/ml/paddle/misc/RaggedToDenseOpKernel.cu b/cpp/open3d/ml/paddle/misc/RaggedToDenseOpKernel.cu new file mode 100644 index 00000000000..ae3677bf035 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/RaggedToDenseOpKernel.cu @@ -0,0 +1,48 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include "open3d/ml/impl/misc/RaggedToDense.cuh" +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/misc/RaggedToDenseOpKernel.h" +#include "paddle/extension.h" + +template +paddle::Tensor RaggedToDenseCUDA(const paddle::Tensor& values, + const paddle::Tensor& row_splits, + const int64_t out_col_size, + const paddle::Tensor& default_value) { + auto out_shape = values.shape(); + out_shape.erase(out_shape.begin()); + out_shape.insert(out_shape.begin(), + {row_splits.shape()[0] - 1, out_col_size}); + auto place = values.place(); + paddle::Tensor out = paddle::empty( + out_shape, paddle::DataType(ToPaddleDtype()), place); + + auto stream = values.stream(); + + open3d::ml::impl::RaggedToDenseCUDA( + stream, values.data(), row_splits.data(), + row_splits.shape()[0], out_col_size, default_value.data(), + default_value.numel(), out.data()); + + return out; +} + +#define INSTANTIATE(T) \ + template paddle::Tensor RaggedToDenseCUDA( \ + const paddle::Tensor&, const paddle::Tensor&, const int64_t, \ + const paddle::Tensor&); + +INSTANTIATE(uint8_t) +INSTANTIATE(int8_t) +INSTANTIATE(int16_t) +INSTANTIATE(int32_t) +INSTANTIATE(int64_t) +INSTANTIATE(float) +INSTANTIATE(double) diff --git a/cpp/open3d/ml/paddle/misc/RaggedToDenseOpKernel.h b/cpp/open3d/ml/paddle/misc/RaggedToDenseOpKernel.h new file mode 100644 index 00000000000..1834c710979 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/RaggedToDenseOpKernel.h @@ -0,0 +1,24 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// +#pragma once + +#include "paddle/extension.h" + +template +paddle::Tensor RaggedToDenseCPU(const paddle::Tensor& values, + const paddle::Tensor& row_splits, + const int64_t out_col_size, + const paddle::Tensor& default_value); + +#ifdef BUILD_CUDA_MODULE +template +paddle::Tensor RaggedToDenseCUDA(const paddle::Tensor& values, + const paddle::Tensor& row_splits, + const int64_t out_col_size, + const paddle::Tensor& default_value); +#endif diff --git a/cpp/open3d/ml/paddle/misc/RaggedToDenseOps.cpp b/cpp/open3d/ml/paddle/misc/RaggedToDenseOps.cpp new file mode 100644 index 00000000000..60ea9885dd2 --- /dev/null +++ b/cpp/open3d/ml/paddle/misc/RaggedToDenseOps.cpp @@ -0,0 +1,111 @@ +// ---------------------------------------------------------------------------- +// - Open3D: www.open3d.org - +// ---------------------------------------------------------------------------- +// Copyright (c) 2018-2024 www.open3d.org +// SPDX-License-Identifier: MIT +// ---------------------------------------------------------------------------- +// + +#include + +#include "open3d/ml/paddle/PaddleHelper.h" +#include "open3d/ml/paddle/misc/RaggedToDenseOpKernel.h" +#include "paddle/extension.h" + +std::vector RaggedToDense(paddle::Tensor& values, + paddle::Tensor& row_splits, + paddle::Tensor& default_value, + const int64_t out_col_size) { + CHECK_TYPE(row_splits, phi::DataType::INT64); + CHECK_SAME_DTYPE(values, default_value); + + // check input shapes + { + using namespace open3d::ml::op_util; + Dim num_rows("num_rows"); + CHECK_SHAPE(row_splits, num_rows + 1); + if (default_value.shape().size()) { + Dim item_size("item_size"); + CHECK_SHAPE_COMBINE_LAST_DIMS(default_value, item_size); + CHECK_SHAPE_COMBINE_LAST_DIMS(values, Dim(), item_size); + auto value_shape = values.shape(); + + // check shape tail + std::vector item_shape(value_shape.begin() + 1, + value_shape.end()); + auto default_value_shape = default_value.shape(); + PD_CHECK(default_value_shape == item_shape, + "default_value " + + phi::DataTypeToString(default_value.dtype()) + + "has incompatible with the shape of items in " + "values" + + TensorInfoStr({values})); + } else // scalar default_value + { + Dim num_values("num_values"); + CHECK_SHAPE_COMBINE_LAST_DIMS(values, num_values); + } + } + + // make sure everything is on the same place as 'values' + auto place = values.place(); + row_splits = row_splits.copy_to(place, false); + default_value = default_value.copy_to(place, false); + + const auto& value_type = values.dtype(); + +#define CALL(value_t, fn) \ + if (ComparePaddleDtype(value_type)) { \ + return {fn(values, row_splits, out_col_size, default_value)}; \ + } + + if (values.is_gpu()) { +#ifdef BUILD_CUDA_MODULE + // pass to cuda function + CALL(uint8_t, RaggedToDenseCUDA) + CALL(int8_t, RaggedToDenseCUDA) + CALL(int16_t, RaggedToDenseCUDA) + CALL(int32_t, RaggedToDenseCUDA) + CALL(int64_t, RaggedToDenseCUDA) + CALL(float, RaggedToDenseCUDA) + CALL(double, RaggedToDenseCUDA) +#else + PD_CHECK(false, "RaggedToDense was not compiled with CUDA support"); +#endif + } else { + CALL(uint8_t, RaggedToDenseCPU) + CALL(int8_t, RaggedToDenseCPU) + CALL(int16_t, RaggedToDenseCPU) + CALL(int32_t, RaggedToDenseCPU) + CALL(int64_t, RaggedToDenseCPU) + CALL(float, RaggedToDenseCPU) + CALL(double, RaggedToDenseCPU) + } + PD_CHECK(false, "RaggedToDense does not support " + + phi::DataTypeToString(values.dtype()) + + " as input for values"); +} + +std::vector RaggedToDenseInferDtype( + const paddle::DataType values_dtype) { + return {values_dtype}; +} + +std::vector> RaggedToDenseInferShape( + std::vector values_shape, + std::vector row_splits_shape, + const int64_t out_col_size) { + auto out_shape = values_shape; + out_shape.erase(out_shape.begin()); + out_shape.insert(out_shape.begin(), + {row_splits_shape[0] - 1, out_col_size}); + return {out_shape}; +} + +PD_BUILD_OP(open3d_ragged_to_dense) + .Inputs({"values", "row_splits", "default_value"}) + .Attrs({"out_col_size: int64_t"}) + .Outputs({"out"}) + .SetKernelFn(PD_KERNEL(RaggedToDense)) + .SetInferShapeFn(PD_INFER_SHAPE(RaggedToDenseInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(RaggedToDenseInferDtype)); diff --git a/cpp/pybind/CMakeLists.txt b/cpp/pybind/CMakeLists.txt index c79bbd96719..3fb90789c02 100644 --- a/cpp/pybind/CMakeLists.txt +++ b/cpp/pybind/CMakeLists.txt @@ -184,6 +184,25 @@ if (BUILD_PYTORCH_OPS) OUTPUT_VARIABLE Pytorch_VERSION) endif() +# add additional optional compiled modules +if (BUILD_PADDLE_OPS) + list( APPEND COMPILED_MODULE_PATH_LIST $ ) + add_custom_command( OUTPUT "${CMAKE_BINARY_DIR}/lib/ml/paddle/python/ops.py" "${CMAKE_BINARY_DIR}/lib/ml/paddle/python/return_types.py" + COMMAND ${Python3_EXECUTABLE} generate_paddle_ops_wrapper.py + --input_return_types_py_in "${PYTHON_PACKAGE_SRC_DIR}/open3d/ml/paddle/python/return_types.py.in" + --output_dir "${CMAKE_BINARY_DIR}/lib/ml/paddle/python/" + --lib $ + DEPENDS open3d_paddle_ops + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + COMMENT "Generating python ops.py and return_types.py" ) + + list(APPEND GENERATED_OUTPUTS "${CMAKE_BINARY_DIR}/lib/ml/paddle/python/ops.py" "${CMAKE_BINARY_DIR}/lib/ml/paddle/python/return_types.py") + # find paddle to get some info for the _build_config.py + set(PRINT_ONCE OFF) + find_package(Paddle) +endif() + + if (BUNDLE_OPEN3D_ML) find_path( OPEN3D_ML_ROOT @@ -229,6 +248,7 @@ add_custom_target(python-package -DBUILD_JUPYTER_EXTENSION=${BUILD_JUPYTER_EXTENSION} -DBUILD_TENSORFLOW_OPS=${BUILD_TENSORFLOW_OPS} -DBUILD_PYTORCH_OPS=${BUILD_PYTORCH_OPS} + -DBUILD_PADDLE_OPS=${BUILD_PADDLE_OPS} -DBUNDLE_OPEN3D_ML=${BUNDLE_OPEN3D_ML} -DOPEN3D_ML_ROOT=${OPEN3D_ML_ROOT} -DBUILD_GUI=${BUILD_GUI} diff --git a/cpp/pybind/_build_config.py.in b/cpp/pybind/_build_config.py.in index 6c32224de02..7ea14135ff2 100644 --- a/cpp/pybind/_build_config.py.in +++ b/cpp/pybind/_build_config.py.in @@ -1,6 +1,7 @@ _build_config = { "BUILD_TENSORFLOW_OPS" : $,True,False>, "BUILD_PYTORCH_OPS" : $,True,False>, + "BUILD_PADDLE_OPS" : $,True,False>, "BUILD_CUDA_MODULE" : $,True,False>, "BUILD_SYCL_MODULE" : $,True,False>, "BUILD_AZURE_KINECT" : $,True,False>, @@ -16,5 +17,6 @@ _build_config = { "CUDA_GENCODES" : "@CUDA_GENCODES@", "Tensorflow_VERSION" : "@Tensorflow_VERSION@", "Pytorch_VERSION" : "@Pytorch_VERSION@", + "Paddle_VERSION" : "@Paddle_VERSION@", "WITH_OPENMP" : $,True,False> } diff --git a/cpp/pybind/generate_paddle_ops_wrapper.py b/cpp/pybind/generate_paddle_ops_wrapper.py new file mode 100644 index 00000000000..40038e044e7 --- /dev/null +++ b/cpp/pybind/generate_paddle_ops_wrapper.py @@ -0,0 +1,197 @@ +import argparse +import textwrap +import sys +import os +from yapf.yapflib.yapf_api import FormatFile + + +from paddle.utils.cpp_extension.extension_utils import ( + load_op_meta_info_and_register_op, + _get_api_inputs_str, + _gen_output_content +) + + +def remove_op_name_prefix(op_name): + PADDLE_OPS_PREFIX = "open3d_" + + assert op_name.startswith(PADDLE_OPS_PREFIX), "Paddle operators should be start with `open3d_`." + func_name = op_name[len(PADDLE_OPS_PREFIX):] + + return func_name + + +def custom_api_header(): + HEADER = textwrap.dedent( + """ + # ---------------------------------------------------------------------------- + # - Open3D: www.open3d.org - + # ---------------------------------------------------------------------------- + # The MIT License (MIT) + # + # Copyright (c) 2018-2024 www.open3d.org + # + # Permission is hereby granted, free of charge, to any person obtaining a copy + # of this software and associated documentation files (the "Software"), to deal + # in the Software without restriction, including without limitation the rights + # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + # copies of the Software, and to permit persons to whom the Software is + # furnished to do so, subject to the following conditions: + # + # The above copyright notice and this permission notice shall be included in + # all copies or substantial portions of the Software. + # + # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + # IN THE SOFTWARE. + # ---------------------------------------------------------------------------- + + # This file is machine generated. Do not modify. + from paddle import _C_ops + from paddle.framework import in_dynamic_or_pir_mode + from paddle.base.layer_helper import LayerHelper + from . import return_types + """ + ).lstrip() + + return HEADER + + +def custom_api_footer(custom_ops): + FOOTER = textwrap.dedent( + """ + __all__ = [ + {export_func_name_strs} + ] + """ + ).lstrip() + + export_func_name_strs = "" + for op_name in custom_ops: + export_func_name_strs += f"'{remove_op_name_prefix(op_name)}', \n" + + return FOOTER.format( + export_func_name_strs = export_func_name_strs + ) + + +def custom_api_content(op_name): + ( + params_list, + ins_map, + attrs_map, + outs_list, + in_names, + _, + out_names, + inplace_reverse_idx, + ) = _get_api_inputs_str(op_name) + dynamic_content, static_content = _gen_output_content( + op_name, + in_names, + out_names, + ins_map, + attrs_map, + outs_list, + inplace_reverse_idx, + ) + API_TEMPLATE = textwrap.dedent( + """ + def {func_name}({params_list}): + # The output variable's dtype use default value 'float32', + # and the actual dtype of output variable will be inferred in runtime. + if in_dynamic_or_pir_mode(): + outs = _C_ops._run_custom_op("{op_name}", {params_list}) + {dynamic_content} + else: + {static_content} + """ + ).lstrip() + + # NOTE: Hack return express to wrapper multi return value by return_types + if len(out_names) > 1: + RETURN_NAMEDTUPLE_TEMPLATE = textwrap.dedent("""return return_types.{op_name}(*res)""").lstrip() + REPLACED_RETURN_TEMPLATE = textwrap.dedent("""return res[0] if len(res)==1 else res""").lstrip() + dynamic_content = dynamic_content.replace(REPLACED_RETURN_TEMPLATE, RETURN_NAMEDTUPLE_TEMPLATE.format(op_name=op_name)) + static_content = static_content.replace(REPLACED_RETURN_TEMPLATE, RETURN_NAMEDTUPLE_TEMPLATE.format(op_name=op_name)) + + func_name = remove_op_name_prefix(op_name) + + # generate python api file + api_content = API_TEMPLATE.format( + func_name=func_name, + op_name=op_name, + params_list=params_list, + dynamic_content=dynamic_content, + static_content=static_content, + ) + + NAMEDTUPLE_TEMPLATE= textwrap.dedent("""{op_name} = _namedtuple('{op_name}', '{out_names}')""").lstrip() + out_names = ' '.join([out_name for out_name in out_names]) + api_namedtuple = NAMEDTUPLE_TEMPLATE.format( + op_name=op_name, out_names=out_names) + + + return api_content, api_namedtuple + + +def main(): + parser = argparse.ArgumentParser( + description="Creates the ops.py and return_types.py files") + parser.add_argument("--input_return_types_py_in", + type=str, + required=True, + help="input file with header") + parser.add_argument("--lib", + type=str, + required=True, + help="path to open3d_paddle_ops.so") + parser.add_argument("--output_dir", + type=str, + required=True, + help="output directory") + args = parser.parse_args() + + generated_fuction_strs = "" + generated_namedtuple_strs = "" + custom_ops = load_op_meta_info_and_register_op(args.lib) + for _custom_op in custom_ops: + generated_fuction_str, generated_namedtuple_str = custom_api_content(_custom_op) + generated_fuction_strs += generated_fuction_str + "\n" + generated_namedtuple_strs += generated_namedtuple_str + "\n" + + CUSTOM_API_TEMPLATE = textwrap.dedent(""" + {custom_api_header} + + {custom_api_content} + + {custom_api_footer} + """).lstrip() + generated_ops_strs = CUSTOM_API_TEMPLATE.format( + custom_api_header = custom_api_header(), + custom_api_content = generated_fuction_strs, + custom_api_footer = custom_api_footer(custom_ops) + ) + + os.makedirs(args.output_dir, exist_ok=True) + output_ops_py_path = os.path.join(args.output_dir, 'ops.py') + with open(output_ops_py_path,'w') as f: + f.write(generated_ops_strs) + FormatFile(output_ops_py_path, in_place=True) + + output_return_types_py_path = os.path.join(args.output_dir, + 'return_types.py') + with open(args.input_return_types_py_in, 'r') as f: + input_header = f.read() + with open(output_return_types_py_path, 'w') as f: + f.write(input_header + generated_namedtuple_strs) + + return 0 + + +if __name__ == '__main__': + sys.exit(main()) \ No newline at end of file diff --git a/cpp/pybind/make_python_package.cmake b/cpp/pybind/make_python_package.cmake index 01e0d5663f5..9648f1743b0 100644 --- a/cpp/pybind/make_python_package.cmake +++ b/cpp/pybind/make_python_package.cmake @@ -65,7 +65,7 @@ configure_file("${PYTHON_PACKAGE_SRC_DIR}/../cpp/open3d/visualization/webrtc_ser file(COPY "${PYTHON_COMPILED_MODULE_DIR}/_build_config.py" DESTINATION "${PYTHON_PACKAGE_DST_DIR}/open3d/") -if (BUILD_TENSORFLOW_OPS OR BUILD_PYTORCH_OPS) +if (BUILD_TENSORFLOW_OPS OR BUILD_PYTORCH_OPS OR BUILD_PADDLE_OPS) # copy generated files file(COPY "${PYTHON_PACKAGE_DST_DIR}/../ml" DESTINATION "${PYTHON_PACKAGE_DST_DIR}/open3d/" ) diff --git a/python/open3d/ml/paddle/__init__.py b/python/open3d/ml/paddle/__init__.py new file mode 100644 index 00000000000..a53d37447f6 --- /dev/null +++ b/python/open3d/ml/paddle/__init__.py @@ -0,0 +1,27 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# Copyright (c) 2018-2024 www.open3d.org +# SPDX-License-Identifier: MIT +# ---------------------------------------------------------------------------- +from packaging.version import parse as _verp +import paddle as _paddle +from open3d import _build_config + +if not _build_config["Paddle_VERSION"]: + raise Exception('Open3D was not built with Paddle support!') +_o3d_paddle_version = _verp(_build_config["Paddle_VERSION"]) +# Check match with Paddle version, any patch level is OK +if _verp(_paddle.__version__).release[:2] != _o3d_paddle_version.release[:2]: + match_paddle_ver = '.'.join( + str(v) for v in _o3d_paddle_version.release[:2] + ('*',)) + raise Exception('Version mismatch: Open3D needs Paddle version {}, but ' + 'version {} is installed!'.format(match_paddle_ver, + _paddle.__version__)) + +from . import layers +from . import ops +from . import classes + +# put contrib at the same level +from open3d.ml import contrib diff --git a/python/open3d/ml/paddle/classes/__init__.py b/python/open3d/ml/paddle/classes/__init__.py new file mode 100644 index 00000000000..a321fa69f87 --- /dev/null +++ b/python/open3d/ml/paddle/classes/__init__.py @@ -0,0 +1,22 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# Copyright (c) 2018-2024 www.open3d.org +# SPDX-License-Identifier: MIT +# ---------------------------------------------------------------------------- +"""Paddle specific machine learning classes.""" +import paddle + +from .ragged_tensor import RaggedTensor + +DTYPE_MAP = { + paddle.bool: 'bool', + paddle.float16: 'float16', + paddle.float32: 'float32', + paddle.float64: 'float64', + paddle.int8: 'int8', + paddle.int16: 'int16', + paddle.int32: 'int32', + paddle.int64: 'int64', + paddle.bfloat16: 'uint16', +} diff --git a/python/open3d/ml/paddle/classes/ragged_tensor.py b/python/open3d/ml/paddle/classes/ragged_tensor.py new file mode 100644 index 00000000000..ecc6ad5b443 --- /dev/null +++ b/python/open3d/ml/paddle/classes/ragged_tensor.py @@ -0,0 +1,199 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# Copyright (c) 2018-2024 www.open3d.org +# SPDX-License-Identifier: MIT +# ---------------------------------------------------------------------------- + +import paddle +import numpy as np + +__all__ = ['RaggedTensor'] + + +class RaggedTensor: + + def __init__(self, values, row_splits, internal=False): + if not internal: + raise ValueError( + "RaggedTensor constructor is private, please use one of the factory method instead(e.g. RaggedTensor.from_row_splits())" + ) + self._values = values + self._row_splits = row_splits + + @classmethod + def _from_row_splits(cls, values, row_splits, validate=True): + if row_splits.dtype != paddle.int64: + raise ValueError("row_splits must have type paddle.int64") + + values = values.contiguous() + row_splits = row_splits.contiguous() + + if validate: + if len(row_splits.shape) != 1: + raise ValueError("row_splits must be of rank 1") + if row_splits[0] != 0: + raise ValueError( + f"Arguments to from_row_splits do not form a valid RaggedTensor. Expect row_splits[0] == 0 but received row_splits[0] == {row_splits[0]}." + ) + for i in range(0, row_splits.shape[0] - 1): + if row_splits[i] > row_splits[i + 1]: + raise ValueError( + "row_splits must be monotonically increasing") + + row_splits = row_splits.to(values.place) + + return values, row_splits + + @classmethod + def from_row_splits(cls, values, row_splits, validate=True, copy=True): + + if isinstance(values, list): + values = paddle.to_tensor(values, dtype=paddle.float64) + elif isinstance(values, np.ndarray): + values = paddle.to_tensor(values) + elif isinstance(values, paddle.Tensor) and copy: + values = values.clone() + + if isinstance(row_splits, list): + row_splits = paddle.to_tensor(row_splits, dtype=paddle.int64) + elif isinstance(row_splits, np.ndarray): + row_splits = paddle.to_tensor(row_splits) + elif isinstance(row_splits, paddle.Tensor) and copy: + row_splits = row_splits.clone() + + values, row_splits = cls._from_row_splits(values, row_splits, validate) + + return cls(values, row_splits, internal=True) + + @property + def values(self): + """The concatenated rows for this ragged tensor.""" + return self._values + + @property + def row_splits(self): + """The row-split indices for this ragged tensor's `values`.""" + return self._row_splits + + @property + def dtype(self): + """The `DType` of values in this ragged tensor.""" + return self._values.dtype + + @property + def device(self): + """The device of values in this ragged tensor.""" + return self._values.place + + @property + def shape(self): + """The statically known shape of this ragged tensor.""" + return [ + len(self._row_splits.shape[0] - 1), None, *self._values.shape[1:] + ] + + @property + def requires_grad(self): + """Read/writeble `requires_grad` for values.""" + return not self._values.stop_gradient + + @requires_grad.setter + def requires_grad(self, value): + # NOTE: stop_gradient=True means not requires grad + self._values.stop_gradient = not value + + def clone(self): + """Returns a clone of object.""" + return self.__class__(self._values.clone(), self._row_splits.clone(), + True) + + def to_list(self): + """Returns a list of tensors""" + return [tensor for tensor in self._values] + + def __getitem__(self, idx): + return self._values.slice([ + 0, + ], [ + self._row_splits[idx], + ], [ + self._row_splits[idx + 1], + ]) + + def __repr__(self): + return f"RaggedTensor(values={self._values}, row_splits={self._row_splits})" + + def __len__(self): + return len(self._row_splits.shape[0] - 1) + + def __add__(self, other): + values, row_splits = self.__class__._from_row_splits( + self._values + self.__convert_to_tensor(other), self._row_splits, + False) + return RaggedTensor(values, row_splits, True) + + def __iadd__(self, other): + paddle.assign(self._values + self.__convert_to_tensor(other), + self._values) + return self + + def __sub__(self, other): + values, row_splits = self.__class__._from_row_splits( + self._values - self.__convert_to_tensor(other), self._row_splits, + False) + return RaggedTensor(values.clone(), row_splits.clone(), True) + + def __isub__(self, other): + paddle.assign(self._values - self.__convert_to_tensor(other), + self._values) + return self + + def __mul__(self, other): + values, row_splits = self.__class__._from_row_splits( + self._values * self.__convert_to_tensor(other), self._row_splits, + False) + return RaggedTensor(values.clone(), row_splits.clone(), True) + + def __imul__(self, other): + paddle.assign(self._values * self.__convert_to_tensor(other), + self._values) + return self + + def __truediv__(self, other): + values, row_splits = self.__class__._from_row_splits( + self._values / self.__convert_to_tensor(other), self._row_splits, + False) + return RaggedTensor(values.clone(), row_splits.clone(), True) + + def __itruediv__(self, other): + paddle.assign(self._values / self.__convert_to_tensor(other), + self._values) + return self + + def __floordiv__(self, other): + values, row_splits = self.__class__._from_row_splits( + self._values // self.__convert_to_tensor(other), self._row_splits, + False) + return RaggedTensor(values.clone(), row_splits.clone(), True) + + def __ifloordiv__(self, other): + paddle.assign(self._values // self.__convert_to_tensor(other), + self._values) + return self + + def __convert_to_tensor(self, value): + """Converts scalar/tensor/RaggedTensor to paddle.Tensor""" + if isinstance(value, RaggedTensor): + if self._row_splits.shape != value.row_splits.shape or paddle.any( + self._row_splits != value.row_splits).item(): + raise ValueError( + f"Incompatible shape : {self._row_splits} and {value.row_splits}" + ) + return value.values + elif isinstance(value, paddle.Tensor): + return value + elif isinstance(value, (int, float, bool)): + return paddle.to_tensor([value], dtype=type(value)) + else: + raise ValueError(f"Unknown type : {type(value)}") diff --git a/python/open3d/ml/paddle/layers/__init__.py b/python/open3d/ml/paddle/layers/__init__.py new file mode 100644 index 00000000000..5a09a34e775 --- /dev/null +++ b/python/open3d/ml/paddle/layers/__init__.py @@ -0,0 +1,12 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# Copyright (c) 2018-2024 www.open3d.org +# SPDX-License-Identifier: MIT +# ---------------------------------------------------------------------------- +"""High level layer API for building networks. + +This module contains layers for processing 3D data. +All layers subclass paddle.nn.Layer +""" +from ..python.layers.neighbor_search import * diff --git a/python/open3d/ml/paddle/ops/__init__.py b/python/open3d/ml/paddle/ops/__init__.py new file mode 100644 index 00000000000..dd97bf61ed0 --- /dev/null +++ b/python/open3d/ml/paddle/ops/__init__.py @@ -0,0 +1,62 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# Copyright (c) 2018-2024 www.open3d.org +# SPDX-License-Identifier: MIT +# ---------------------------------------------------------------------------- +"""Functional API with operators. + +These are the building blocks for the layers. See The layer API for an easy to +use high level interface. +""" +import os as _os +import sys as _sys +import types as _types +import importlib as _importlib +import importlib.abc as _importlib_abc +import importlib.util as _importlib_util +import paddle as _paddle +from open3d import _build_config + +from ..python.ops import * + +_lib_path = [] +# allow overriding the path to the op library with an env var. +if 'OPEN3D_PADDLE_OP_LIB' in _os.environ: + _lib_path.append(_os.environ['OPEN3D_PADDLE_OP_LIB']) + +_this_dir = _os.path.dirname(__file__) +_package_root = _os.path.join(_this_dir, '..', '..', '..') +_lib_ext = {'linux': '.so', 'darwin': '.dylib', 'win32': '.dll'}[_sys.platform] +_lib_suffix = '_debug' if _build_config['CMAKE_BUILD_TYPE'] == 'Debug' else '' +_lib_arch = ('cpu',) +if _build_config["BUILD_CUDA_MODULE"] and _paddle.device.cuda.device_count( +) >= 1: + if _paddle.version.cuda() == _build_config["CUDA_VERSION"]: + _lib_arch = ('cuda', 'cpu') + else: + print("Warning: Open3D was built with CUDA {} but" + "Paddle was built with CUDA {}. Falling back to CPU for now." + "Otherwise, install Paddle with CUDA {}.".format( + _build_config["CUDA_VERSION"], _paddle.version.cuda(), + _build_config["CUDA_VERSION"])) +_lib_path.extend([ + _os.path.join(_package_root, la, + 'open3d_paddle_ops' + _lib_suffix + _lib_ext) + for la in _lib_arch +]) + +# only load first lib +_load_lib_path = _lib_path[0] +# load custom op shared library with abs path +_custom_ops = _paddle.utils.cpp_extension.load_op_meta_info_and_register_op( + _load_lib_path) + +try: + _spec = _importlib_util.spec_from_file_location(__name__, _load_lib_path) + assert _spec is not None + _mod = _importlib_util.module_from_spec(_spec) + assert isinstance(_spec.loader, _importlib_abc.Loader) + _spec.loader.exec_module(_mod) +except ImportError: + _mod = _types.ModuleType(__name__) diff --git a/python/open3d/ml/paddle/python/layers/neighbor_search.py b/python/open3d/ml/paddle/python/layers/neighbor_search.py new file mode 100644 index 00000000000..2600cf03b54 --- /dev/null +++ b/python/open3d/ml/paddle/python/layers/neighbor_search.py @@ -0,0 +1,376 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# Copyright (c) 2018-2024 www.open3d.org +# SPDX-License-Identifier: MIT +# ---------------------------------------------------------------------------- + +from ...python import ops +from ....paddle import classes +from ...classes import DTYPE_MAP +import paddle + +__all__ = ['FixedRadiusSearch', 'RadiusSearch', 'KNNSearch'] + + +class FixedRadiusSearch(paddle.nn.Layer): + """Fixed radius search for 3D point clouds. + + This layer computes the neighbors for a fixed radius on a point cloud. + + Example: + This example shows a neighbor search that returns the indices to the + found neighbors and the distances.:: + + import paddle + import open3d.ml.paddle as ml3d + + points = paddle.randn([20, 3]) + queries = paddle.randn([10, 3]) + radius = 0.8 + + nsearch = ml3d.layers.FixedRadiusSearch(return_distances=True) + ans = nsearch(points, queries, radius) + # returns a tuple of neighbors_index, neighbors_row_splits, and neighbors_distance + + + Arguments: + metric: Either L1, L2 or Linf. Default is L2. + + ignore_query_point: If True the points that coincide with the center of + the search window will be ignored. This excludes the query point if + 'queries' and 'points' are the same point cloud. + + return_distances: If True the distances for each neighbor will be returned. + If False a zero length Tensor will be returned instead. + """ + + def __init__(self, + metric='L2', + ignore_query_point=False, + return_distances=False, + max_hash_table_size=32 * 2**20, + index_dtype=paddle.int32, + **kwargs): + super().__init__() + self.metric = metric + self.ignore_query_point = ignore_query_point + self.return_distances = return_distances + self.max_hash_table_size = max_hash_table_size + assert index_dtype in [paddle.int32, paddle.int64] + self.index_dtype = DTYPE_MAP[index_dtype] + + def forward(self, + points, + queries, + radius, + points_row_splits=None, + queries_row_splits=None, + hash_table_size_factor=1 / 64, + hash_table=None): + """This function computes the neighbors within a fixed radius for each query point. + + Arguments: + + points: The 3D positions of the input points. It can be a RaggedTensor. + + queries: The 3D positions of the query points. It can be a RaggedTensor. + + radius: A scalar with the neighborhood radius + + points_row_splits: Optional 1D vector with the row splits information + if points is batched. This vector is [0, num_points] if there is + only 1 batch item. + + queries_row_splits: Optional 1D vector with the row splits information + if queries is batched. This vector is [0, num_queries] if there is + only 1 batch item. + + hash_table_size_factor: Scalar. The size of the hash table as fraction + of points. + + hash_table: A precomputed hash table generated with build_spatial_hash_table(). + This input can be used to explicitly force the reuse of a hash table in special + cases and is usually not needed. + Note that the hash table must have been generated with the same 'points' array. + + Returns: + 3 Tensors in the following order + + neighbors_index + The compact list of indices of the neighbors. The corresponding query point + can be inferred from the 'neighbor_count_row_splits' vector. + + neighbors_row_splits + The exclusive prefix sum of the neighbor count for the query points including + the total neighbor count as the last element. The size of this array is the + number of queries + 1. + + neighbors_distance + Stores the distance to each neighbor if 'return_distances' is True. + Note that the distances are squared if metric is L2. + This is a zero length Tensor if 'return_distances' is False. + """ + + if isinstance(points, classes.RaggedTensor): + points_row_splits = points.row_splits + points = points.values + if isinstance(queries, classes.RaggedTensor): + queries_row_splits = queries.row_splits + queries = queries.values + + if points_row_splits is None: + points_row_splits = paddle.to_tensor([0, points.shape[0]], + dtype="int64") + if queries_row_splits is None: + queries_row_splits = paddle.to_tensor([0, queries.shape[0]], + dtype="int64") + + if hash_table is None: + table = ops.build_spatial_hash_table( + max_hash_table_size=self.max_hash_table_size, + points=points, + radius=radius, + points_row_splits=points_row_splits, + hash_table_size_factor=hash_table_size_factor) + else: + table = hash_table + + result = ops.fixed_radius_search( + ignore_query_point=self.ignore_query_point, + return_distances=self.return_distances, + metric_str=self.metric, + points=points, + queries=queries, + radius=radius, + points_row_splits=points_row_splits, + queries_row_splits=queries_row_splits, + hash_table_splits=table.hash_table_splits, + hash_table_index=table.hash_table_index, + hash_table_cell_splits=table.hash_table_cell_splits, + index_dtype=self.index_dtype) + + return result + + +class RadiusSearch(paddle.nn.Layer): + """Radius search for 3D point clouds. + + This layer computes the neighbors for each query point with each query + having an individual radius. + + Example: + This example shows a neighbor search that returns the indices to the + found neighbors and the distances.:: + + import paddle + import open3d.ml.paddle as ml3d + + points = paddle.randn([20, 3]) + queries = paddle.randn([10, 3]) + radii = paddle.randn([10]) + 1.0 + + nsearch = ml3d.layers.RadiusSearch(return_distances=True) + ans = nsearch(points, queries, radii) + # returns a tuple of neighbors_index, neighbors_row_splits, and neighbors_distance + + + Arguments: + metric: Either L1, L2 or Linf. Default is L2. + + ignore_query_point: If True the points that coincide with the center of the + search window will be ignored. This excludes the query point if 'queries' + and 'points' are the same point cloud. + + return_distances: If True the distances for each neighbor will be returned. + If False a zero length Tensor will be returned instead. + + normalize_distances: If True the returned distances will be normalized with + the radii. + """ + + def __init__(self, + metric='L2', + ignore_query_point=False, + return_distances=False, + normalize_distances=False, + index_dtype=paddle.int32, + **kwargs): + super().__init__() + self.metric = metric + self.ignore_query_point = ignore_query_point + self.return_distances = return_distances + self.normalize_distances = normalize_distances + assert index_dtype in [paddle.int32, paddle.int64] + self.index_dtype = DTYPE_MAP[index_dtype] + + def forward(self, + points, + queries, + radii, + points_row_splits=None, + queries_row_splits=None): + """This function computes the neighbors within a radius for each query point. + + Arguments: + + points: The 3D positions of the input points. + + queries: The 3D positions of the query points. + + radii: A radius for each query point. + + points_row_splits: Optional 1D vector with the row splits information + if points is batched. This vector is [0, num_points] if there is + only 1 batch item. + + queries_row_splits: Optional 1D vector with the row splits information + if queries is batched. This vector is [0, num_queries] if there is + only 1 batch item. + + Returns: + 3 Tensors in the following order + + neighbors_index + The compact list of indices of the neighbors. The corresponding query point + can be inferred from the 'neighbor_count_row_splits' vector. + + neighbors_row_splits + The exclusive prefix sum of the neighbor count for the query points including + the total neighbor count as the last element. The size of this array is the + number of queries + 1. + + neighbors_distance + Stores the distance to each neighbor if 'return_distances' is True. + Note that the distances are squared if metric is L2. + This is a zero length Tensor if 'return_distances' is False. + """ + if points_row_splits is None: + points_row_splits = paddle.to_tensor([0, points.shape[0]], + dtype="int64") + if queries_row_splits is None: + queries_row_splits = paddle.to_tensor([0, queries.shape[0]], + dtype="int64") + + result = ops.radius_search(ignore_query_point=self.ignore_query_point, + return_distances=self.return_distances, + normalize_distances=self.normalize_distances, + metric_str=self.metric, + points=points, + queries=queries, + radii=radii, + points_row_splits=points_row_splits, + queries_row_splits=queries_row_splits, + index_dtype=self.index_dtype) + + return result + + +class KNNSearch(paddle.nn.Layer): + """KNN search for 3D point clouds. + + This layer computes the k nearest neighbors for each query point. + + Example: + This example shows a neighbor search that returns the indices to the + found neighbors and the distances.:: + + import paddle + import open3d.ml.paddle as ml3d + + points = paddle.randn([20, 3]) + queries = paddle.randn([10, 3]) + k = 8 + + nsearch = ml3d.layers.KNNSearch(return_distances=True) + ans = nsearch(points, queries, k) + # returns a tuple of neighbors_index, neighbors_row_splits, and neighbors_distance + # Since there are more than k points and we do not ignore any points we can + # reshape the output to [num_queries, k] with + neighbors_index = ans.neighbors_index.reshape(10,k) + neighbors_distance = ans.neighbors_distance.reshape(10,k) + + + Arguments: + metric: Either L1, L2 or Linf. Default is L2. + + ignore_query_point: If True the points that coincide with the center of the + search window will be ignored. This excludes the query point if 'queries' + and 'points' are the same point cloud. + + return_distances: If True the distances for each neighbor will be returned. + If False a zero length Tensor will be returned instead. + """ + + def __init__(self, + metric='L2', + ignore_query_point=False, + return_distances=False, + index_dtype=paddle.int32, + **kwargs): + super().__init__() + self.metric = metric + self.ignore_query_point = ignore_query_point + self.return_distances = return_distances + assert index_dtype in [paddle.int32, paddle.int64] + self.index_dtype = DTYPE_MAP[index_dtype] + + def forward(self, + points, + queries, + k, + points_row_splits=None, + queries_row_splits=None): + """This function computes the k nearest neighbors for each query point. + + Arguments: + points: The 3D positions of the input points. *This argument must be + given as a positional argument!* + + queries: The 3D positions of the query points. + + k: The number of nearest neighbors to search. + + points_row_splits: Optional 1D vector with the row splits information + if points is batched. + This vector is [0, num_points] if there is only 1 batch item. + + queries_row_splits: Optional 1D vector with the row splits information + if queries is batched. + This vector is [0, num_queries] if there is only 1 batch item. + + Returns: 3 Tensors in the following order + + neighbors_index + The compact list of indices of the neighbors. The corresponding query point + can be inferred from the 'neighbor_count_row_splits' vector. + + neighbors_row_splits + The exclusive prefix sum of the neighbor count for the query points including + the total neighbor count as the last element. The size of this array is the + number of queries + 1. + + neighbors_distance + Stores the distance to each neighbor if 'return_distances' is True. + Note that the distances are squared if metric is L2. + This is a zero length Tensor if 'return_distances' is False. + """ + + if points_row_splits is None: + points_row_splits = paddle.to_tensor([0, points.shape[0]], + dtype=paddle.int64) + if queries_row_splits is None: + queries_row_splits = paddle.to_tensor([0, queries.shape[0]], + dtype=paddle.int64) + + result = ops.knn_search(ignore_query_point=self.ignore_query_point, + return_distances=self.return_distances, + metric_str=self.metric, + points=points, + queries=queries, + k=k, + points_row_splits=points_row_splits, + queries_row_splits=queries_row_splits, + index_dtype=self.index_dtype) + + return result diff --git a/python/open3d/ml/paddle/python/return_types.py.in b/python/open3d/ml/paddle/python/return_types.py.in new file mode 100644 index 00000000000..a73a0765ed9 --- /dev/null +++ b/python/open3d/ml/paddle/python/return_types.py.in @@ -0,0 +1,29 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# The MIT License (MIT) +# +# Copyright (c) 2018-2024 www.open3d.org +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# ---------------------------------------------------------------------------- + +# This file is machine generated. Do not modify. +from collections import namedtuple as _namedtuple + diff --git a/python/test/ml_ops/mltest.py b/python/test/ml_ops/mltest.py index 0f25f5d46f6..c8dc8e60f1e 100644 --- a/python/test/ml_ops/mltest.py +++ b/python/test/ml_ops/mltest.py @@ -17,7 +17,8 @@ # skip all tests if the ml ops were not built default_marks = [ pytest.mark.skipif(not (o3d._build_config['BUILD_TENSORFLOW_OPS'] or - o3d._build_config['BUILD_PYTORCH_OPS']), + o3d._build_config['BUILD_PYTORCH_OPS'] or + o3d._build_config['BUILD_PADDLE_OPS']), reason='ml ops not built'), ] @@ -62,9 +63,24 @@ except ImportError: pass +try: + paddle = importlib.import_module('paddle') + ml3d_ops = importlib.import_module('open3d.ml.paddle.ops') + ml3d_layers = importlib.import_module('open3d.ml.paddle.layers') + ml3d_classes = importlib.import_module('open3d.ml.paddle.classes') + _ml_modules['paddle'] = MLModules(paddle, ml3d_ops, ml3d_layers, + ml3d_classes, 'cpu', 'cpu', False) + if paddle.device.is_compiled_with_cuda( + ) and o3d._build_config['BUILD_CUDA_MODULE']: + _ml_modules['paddle_cuda'] = MLModules(paddle, ml3d_ops, ml3d_layers, + ml3d_classes, 'cuda', 'cpu', + True) +except ImportError: + pass + def is_gpu_device_name(name): - return name in ('GPU:0', 'cuda') + return name in ('GPU:0', 'cuda', 'gpu:0', 'gpu') def to_numpy(tensor): @@ -75,6 +91,8 @@ def to_numpy(tensor): if tensor.device.type == 'cuda': tensor = tensor.cpu() + return tensor.numpy() + elif 'paddle' in _ml_modules and isinstance(tensor, paddle.Tensor): return tensor.numpy() else: return tensor.numpy() @@ -88,6 +106,25 @@ def to_torch(x, device): return x +def to_paddle(x, device): + """Converts x such that it can be used as input to a paddle op.""" + if isinstance(x, np.ndarray): + return paddle.to_tensor( + x, place='gpu') if device == 'cuda' else paddle.to_tensor( + x, place='cpu') + else: + return x + + +def paddle_cmp_device(x, device): + if device == 'cuda' and x.place.is_gpu_place(): + return True + elif device == 'cpu' and x.place.is_cpu_place(): + return True + else: + return False + + def run_op(ml, device_name, check_device, fn, *args, **kwargs): """Runs an op using an ml framework""" if ml.module.__name__ == 'tensorflow': @@ -126,7 +163,25 @@ def run_op(ml, device_name, check_device, fn, *args, **kwargs): x, torch.Tensor) and device_name == x.device.type: tensor_on_device = True assert tensor_on_device + elif ml.module.__name__ == 'paddle': + _args = [to_paddle(x, device_name) for x in args] + _kwargs = {k: to_paddle(v, device_name) for k, v in kwargs.items()} + + ans = fn(*_args, **_kwargs) + if check_device: + # not all returned tensor have to use the device. + # check if there is at least one tensor using device memory + tensor_on_device = False + if isinstance(ans, paddle.Tensor): + if paddle_cmp_device(ans, device_name): + tensor_on_device = True + else: + for x in ans: + if isinstance(x, paddle.Tensor) and paddle_cmp_device( + x, device_name): + tensor_on_device = True + assert tensor_on_device else: raise ValueError('unsupported ml framework {}'.format(ml.module)) @@ -189,6 +244,30 @@ def run_op_grad(ml, device_name, check_device, fn, x, y_attr_name, torch.Tensor) and device_name == dy_dx.device.type: tensor_on_device = True assert tensor_on_device + elif ml.module.__name__ == 'paddle': + x_var = to_paddle(x, device_name) + x_var.stop_gradient = False + _args = [x_var if a is x else to_paddle(a, device_name) for a in args] + _kwargs = { + k: x_var if a is x else to_paddle(a, device_name) + for k, a in kwargs.items() + } + + ans = fn(*_args, **_kwargs) + if y_attr_name: + y = getattr(ans, y_attr_name) + else: + y = ans + y.backward(to_paddle(backprop_values, device_name)) + dy_dx = x_var.grad + + if check_device: + # check if the gradient is using device memory + tensor_on_device = False + if isinstance(dy_dx, paddle.Tensor) and paddle_cmp_device( + dy_dx, device_name): + tensor_on_device = True + assert tensor_on_device else: raise ValueError('unsupported ml framework {}'.format(ml.module)) @@ -213,6 +292,8 @@ def set_seed(self, seed): self.module.random.set_seed(seed) elif self.module.__name__ == 'torch': self.module.manual_seed(seed) + elif self.module.__name__ == 'paddle': + self.module.seed(seed) else: raise Exception('Unsupported ml framework') @@ -221,19 +302,27 @@ def set_deterministic(self, deterministic): pass elif self.module.__name__ == 'torch': self.module.set_deterministic(deterministic) + elif self.module.__name__ == 'paddle': + paddle.set_flags({ + "FLAGS_cudnn_deterministic": "1", + "FLAGS_cpu_deterministic": "1" + }) else: raise Exception('Unsupported ml framework') - def random_uniform(self, size, dtype, minval=0, maxval=1): + def random_uniform(self, shape, dtype, minval=0, maxval=1): if isinstance(dtype, str): dtype = self.get_dtype(dtype) if self.module.__name__ == 'tensorflow': - return self.module.random.uniform(shape=size, + return self.module.random.uniform(shape=shape, dtype=dtype, minval=minval, maxval=maxval) elif self.module.__name__ == 'torch': - ans = self.module.empty(size=size, dtype=dtype) + ans = self.module.empty(size=shape, dtype=dtype) + return ans.uniform_(minval, maxval) + elif self.module.__name__ == 'paddle': + ans = self.module.empty(shape=shape, dtype=dtype) return ans.uniform_(minval, maxval) else: raise Exception('Unsupported ml framework') @@ -245,6 +334,8 @@ def empty(self, shape, dtype): return self.module.zeros(shape=shape, dtype=dtype) elif self.module.__name__ == 'torch': return self.module.empty(size=shape, dtype=dtype) + elif self.module.__name__ == 'paddle': + return self.module.empty(shape=shape, dtype=dtype) else: raise Exception('Unsupported ml framework') @@ -255,6 +346,8 @@ def zeros(self, shape, dtype): return self.module.zeros(shape=shape, dtype=dtype) elif self.module.__name__ == 'torch': return self.module.zeros(size=shape, dtype=dtype) + elif self.module.__name__ == 'paddle': + return self.module.zeros(shape=shape, dtype=dtype) else: raise Exception('Unsupported ml framework') @@ -272,6 +365,13 @@ def zeros(self, shape, dtype): ml_tf_only=pytest.mark.parametrize('ml', [ v for k, v in _ml_modules.items() if v.module.__name__ == 'tensorflow' ]), + ml_paddle_only=pytest.mark.parametrize( + 'ml', + [v for k, v in _ml_modules.items() if v.module.__name__ == 'paddle']), + ml_torch_and_paddle_only=pytest.mark.parametrize('ml', [ + v for k, v in _ml_modules.items() + if v.module.__name__ == 'paddle' or v.module.__name__ == 'torch' + ]), ) diff --git a/python/test/ml_ops/test_fixed_radius_search.py b/python/test/ml_ops/test_fixed_radius_search.py index b0a8f720129..c3da38894ae 100644 --- a/python/test/ml_ops/test_fixed_radius_search.py +++ b/python/test/ml_ops/test_fixed_radius_search.py @@ -14,6 +14,8 @@ import torch if o3d._build_config['BUILD_TENSORFLOW_OPS']: import tensorflow as tf +if o3d._build_config['BUILD_PADDLE_OPS']: + import paddle # skip all tests if the ml ops were not built pytestmark = mltest.default_marks @@ -62,6 +64,11 @@ def test_fixed_radius_search(dtype, ml, num_points_queries, radius, index_dtype_ = {'int32': tf.int32, 'int64': tf.int64}[index_dtype] elif ml.module.__name__ == 'torch': index_dtype_ = {'int32': torch.int32, 'int64': torch.int64}[index_dtype] + elif ml.module.__name__ == 'paddle': + index_dtype_ = { + 'int32': paddle.int32, + 'int64': paddle.int64 + }[index_dtype] else: raise Exception('Unsupported ml framework') @@ -209,6 +216,11 @@ def test_fixed_radius_search_batches(dtype, ml, batch_size, radius, index_dtype_ = {'int32': tf.int32, 'int64': tf.int64}[index_dtype] elif ml.module.__name__ == 'torch': index_dtype_ = {'int32': torch.int32, 'int64': torch.int64}[index_dtype] + elif ml.module.__name__ == 'paddle': + index_dtype_ = { + 'int32': paddle.int32, + 'int64': paddle.int64 + }[index_dtype] else: raise Exception('Unsupported ml framework') diff --git a/python/test/ml_ops/test_knn_search.py b/python/test/ml_ops/test_knn_search.py index 38d0e11d644..8deb5d453f5 100644 --- a/python/test/ml_ops/test_knn_search.py +++ b/python/test/ml_ops/test_knn_search.py @@ -14,6 +14,8 @@ import torch if o3d._build_config['BUILD_TENSORFLOW_OPS']: import tensorflow as tf +if o3d._build_config['BUILD_PADDLE_OPS']: + import paddle # skip all tests if the ml ops were not built pytestmark = mltest.default_marks @@ -59,6 +61,11 @@ def test_knn_search(dtype, ml, num_points_queries, metric, ignore_query_point, index_dtype_ = {'int32': tf.int32, 'int64': tf.int64}[index_dtype] elif ml.module.__name__ == 'torch': index_dtype_ = {'int32': torch.int32, 'int64': torch.int64}[index_dtype] + elif ml.module.__name__ == 'paddle': + index_dtype_ = { + 'int32': paddle.int32, + 'int64': paddle.int64 + }[index_dtype] else: raise Exception('Unsupported ml framework') diff --git a/python/test/ml_ops/test_radius_search.py b/python/test/ml_ops/test_radius_search.py index c23590c5596..a9d7a6cce1c 100644 --- a/python/test/ml_ops/test_radius_search.py +++ b/python/test/ml_ops/test_radius_search.py @@ -14,6 +14,8 @@ import torch if o3d._build_config['BUILD_TENSORFLOW_OPS']: import tensorflow as tf +if o3d._build_config['BUILD_PADDLE_OPS']: + import paddle # skip all tests if the tf ops were not built and disable warnings caused by # tensorflow @@ -58,6 +60,11 @@ def test_radius_search(dtype, ml, num_points_queries, metric, index_dtype_ = {'int32': tf.int32, 'int64': tf.int64}[index_dtype] elif ml.module.__name__ == 'torch': index_dtype_ = {'int32': torch.int32, 'int64': torch.int64}[index_dtype] + elif ml.module.__name__ == 'paddle': + index_dtype_ = { + 'int32': paddle.int32, + 'int64': paddle.int64 + }[index_dtype] else: raise Exception('Unsupported ml framework') diff --git a/python/test/ml_ops/test_ragged_tensor_paddle.py b/python/test/ml_ops/test_ragged_tensor_paddle.py new file mode 100644 index 00000000000..eaafd72dd03 --- /dev/null +++ b/python/test/ml_ops/test_ragged_tensor_paddle.py @@ -0,0 +1,209 @@ +# ---------------------------------------------------------------------------- +# - Open3D: www.open3d.org - +# ---------------------------------------------------------------------------- +# Copyright (c) 2018-2024 www.open3d.org +# SPDX-License-Identifier: MIT +# ---------------------------------------------------------------------------- + +# noqa # pylint: disable=unused-import +import open3d as o3d +import numpy as np +import pytest +import mltest +import paddle + +# skip all tests if the tf ops were not built and disable warnings caused by +# tensorflow +pytestmark = mltest.default_marks + +# the supported dtypes for the values +dtypes = pytest.mark.parametrize('dtype', + [np.int32, np.int64, np.float32, np.float64]) + +# this class is only available for torch + + +@dtypes +@mltest.parametrize.ml_paddle_only +def test_creation(dtype, ml): + values = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=dtype) + row_splits = np.array([0, 2, 4, 4, 5, 12, 13], dtype=np.int64) + + # From numpy arrays + r_tensor = ml.classes.RaggedTensor.from_row_splits(values, row_splits) + for i, tensor in enumerate(r_tensor): + np.testing.assert_equal(mltest.to_numpy(tensor), + values[row_splits[i]:row_splits[i + 1]]) + + # From List + r_tensor = ml.classes.RaggedTensor.from_row_splits(list(values), + list(row_splits)) + for i, tensor in enumerate(r_tensor): + np.testing.assert_equal(mltest.to_numpy(tensor), + values[row_splits[i]:row_splits[i + 1]]) + + # Incompatible tensors. + # Non zero first element. + row_splits = np.array([1, 2, 4, 4, 5, 12, 13], dtype=np.int64) + + context = np.testing.assert_raises(ValueError) + + with context: + ml.classes.RaggedTensor.from_row_splits(values, row_splits) + + # Rank > 1. + row_splits = np.array([[0, 2, 4, 4, 5, 12, 13]], dtype=np.int64) + with context: + ml.classes.RaggedTensor.from_row_splits(values, row_splits) + + # Not increasing monotonically. + row_splits = np.array([[0, 2, 4, 6, 5, 12, 13]], dtype=np.int64) + with context: + ml.classes.RaggedTensor.from_row_splits(values, row_splits) + + # Wrong dtype. + row_splits = np.array([0, 2, 4, 4, 5, 12, 13], dtype=np.float32) + with context: + ml.classes.RaggedTensor.from_row_splits(values, row_splits) + + +# test with more dimensions +@dtypes +@mltest.parametrize.ml_paddle_only +def test_creation_more_dims(dtype, ml): + values = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], + [7, 7], [8, 8], [9, 9], [10, 10], [11, 11], [12, 12]], + dtype=dtype) + row_splits = np.array([0, 2, 4, 4, 5, 12, 13], dtype=np.int64) + + # From numpy arrays + r_tensor = ml.classes.RaggedTensor.from_row_splits(values, row_splits) + for i, tensor in enumerate(r_tensor): + np.testing.assert_equal(mltest.to_numpy(tensor), + values[row_splits[i]:row_splits[i + 1]]) + + # From List + r_tensor = ml.classes.RaggedTensor.from_row_splits(list(values), + list(row_splits)) + for i, tensor in enumerate(r_tensor): + np.testing.assert_equal(mltest.to_numpy(tensor), + values[row_splits[i]:row_splits[i + 1]]) + + +@mltest.parametrize.ml_paddle_only +def test_backprop(ml): + # Create 3 different RaggedTensors and torch.tensor + t_1 = paddle.randn([10, 3]) + t_1.stop_gradient = False + + t_2 = paddle.randn([10, 3]) + t_2.stop_gradient = False + + t_3 = paddle.randn([10, 3]) + t_3.stop_gradient = False + + row_splits = paddle.to_tensor([0, 4, 6, 6, 8, 10]) + + r_1 = ml.classes.RaggedTensor.from_row_splits(t_1.detach().numpy(), + row_splits) + r_1.requires_grad = True + r_2 = ml.classes.RaggedTensor.from_row_splits(t_2.detach().numpy(), + row_splits) + r_2.requires_grad = True + r_3 = ml.classes.RaggedTensor.from_row_splits(t_3.detach().numpy(), + row_splits) + r_3.requires_grad = True + + r_ans = (r_1 + r_2) * r_3 + t_ans = (t_1 + t_2) * t_3 + + np.testing.assert_equal(mltest.to_numpy(t_ans), + mltest.to_numpy(r_ans.values)) + + # Compute gradients + t_ans.sum().backward() + r_ans.values.sum().backward() + + np.testing.assert_equal(mltest.to_numpy(t_1.grad), + mltest.to_numpy(r_1.values.grad)) + + +@dtypes +@mltest.parametrize.ml_paddle_only +def test_binary_ew_ops(dtype, ml): + # Binary Ops. + device = 'gpu' if ml.device == 'cuda' else 'cpu' + + t_1 = paddle.to_tensor( + np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + dtype=dtype)).to(device) + t_2 = paddle.to_tensor( + np.array([2, 3, 6, 3, 11, 3, 43, 12, 8, 15, 12, 87, 45], + dtype=dtype)).to(device) + + row_splits = paddle.to_tensor( + np.array([0, 2, 4, 4, 5, 12, 13], dtype=np.int64)).to(device) + + a = ml.classes.RaggedTensor.from_row_splits(t_1, row_splits) + b = ml.classes.RaggedTensor.from_row_splits(t_2, row_splits) + + np.testing.assert_equal( + (a + b).values.cpu().numpy(), + np.array([2, 4, 8, 6, 15, 8, 49, 19, 16, 24, 22, 98, 57])) + np.testing.assert_equal( + (a - b).values.cpu().numpy(), + np.array([-2, -2, -4, 0, -7, 2, -37, -5, 0, -6, -2, -76, -33])) + np.testing.assert_equal( + (a * b).values.cpu().numpy(), + np.array([0, 3, 12, 9, 44, 15, 258, 84, 64, 135, 120, 957, 540])) + np.testing.assert_equal((a / b).values.cpu().numpy(), + (t_1 / t_2).cpu().numpy()) + np.testing.assert_equal((a // b).values.cpu().numpy(), + np.array([0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0])) + + # Assignment Ops. + a = ml.classes.RaggedTensor.from_row_splits(t_1, row_splits) + a += b + np.testing.assert_equal( + a.values.cpu().numpy(), + np.array([2, 4, 8, 6, 15, 8, 49, 19, 16, 24, 22, 98, 57])) + + a = ml.classes.RaggedTensor.from_row_splits(t_1, row_splits) + a -= b + np.testing.assert_equal( + a.values.cpu().numpy(), + np.array([-2, -2, -4, 0, -7, 2, -37, -5, 0, -6, -2, -76, -33])) + + a = ml.classes.RaggedTensor.from_row_splits(t_1, row_splits) + a *= b + np.testing.assert_equal( + a.values.cpu().numpy(), + np.array([0, 3, 12, 9, 44, 15, 258, 84, 64, 135, 120, 957, 540])) + + a = ml.classes.RaggedTensor.from_row_splits(t_1, row_splits) + a //= b + np.testing.assert_equal(a.values.cpu().numpy(), + np.array([0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0])) + + # Failure cases with incompatible shape. + # Different row_splits. + row_splits = [0, 4, 5, 13] + a = ml.classes.RaggedTensor.from_row_splits(t_1, row_splits) + row_splits = [0, 4, 6, 13] + b = ml.classes.RaggedTensor.from_row_splits(t_2, row_splits) + + with np.testing.assert_raises(ValueError): + a + b # noqa # pylint: disable=pointless-statement + with np.testing.assert_raises(ValueError): + a += b # noqa # pylint: disable=pointless-statement + + # Different length + row_splits = [0, 4, 5, 13] + a = ml.classes.RaggedTensor.from_row_splits(t_1, row_splits) + row_splits = [0, 4, 13] + b = ml.classes.RaggedTensor.from_row_splits(t_2, row_splits) + + with np.testing.assert_raises(ValueError): + a + b # noqa # pylint: disable=pointless-statement + with np.testing.assert_raises(ValueError): + a += b diff --git a/python/test/ml_ops/test_ragged_to_dense.py b/python/test/ml_ops/test_ragged_to_dense.py index 9121faa1b46..289320032ff 100644 --- a/python/test/ml_ops/test_ragged_to_dense.py +++ b/python/test/ml_ops/test_ragged_to_dense.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: MIT # ---------------------------------------------------------------------------- +# noqa # pylint: disable=unused-import import open3d as o3d import numpy as np import pytest @@ -22,7 +23,7 @@ @dtypes -@mltest.parametrize.ml_torch_only +@mltest.parametrize.ml_torch_and_paddle_only def test_ragged_to_dense(dtype, ml): values = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=dtype) @@ -30,8 +31,14 @@ def test_ragged_to_dense(dtype, ml): out_col_size = 4 default_value = np.array(-1, dtype=dtype) - ans = mltest.run_op(ml, ml.device, True, ml.ops.ragged_to_dense, values, - row_splits, out_col_size, default_value) + ans = mltest.run_op(ml, + ml.device, + True, + ml.ops.ragged_to_dense, + values=values, + row_splits=row_splits, + out_col_size=out_col_size, + default_value=default_value) expected = np.full((row_splits.shape[0] - 1, out_col_size), default_value) for i in range(row_splits.shape[0] - 1): @@ -44,7 +51,7 @@ def test_ragged_to_dense(dtype, ml): # test with more dimensions @dtypes -@mltest.parametrize.ml_torch_only +@mltest.parametrize.ml_torch_and_paddle_only def test_ragged_to_dense_more_dims(dtype, ml): values = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], @@ -54,8 +61,14 @@ def test_ragged_to_dense_more_dims(dtype, ml): out_col_size = 4 default_value = np.array([-1, -1], dtype=dtype) - ans = mltest.run_op(ml, ml.device, True, ml.ops.ragged_to_dense, values, - row_splits, out_col_size, default_value) + ans = mltest.run_op(ml, + ml.device, + True, + ml.ops.ragged_to_dense, + values=values, + row_splits=row_splits, + out_col_size=out_col_size, + default_value=default_value) expected = np.full(( row_splits.shape[0] - 1, @@ -71,7 +84,7 @@ def test_ragged_to_dense_more_dims(dtype, ml): # test with larger random data @dtypes -@mltest.parametrize.ml_torch_only +@mltest.parametrize.ml_torch_and_paddle_only @pytest.mark.parametrize('seed', [123, 456]) def test_ragged_to_dense_random(dtype, ml, seed): @@ -87,8 +100,14 @@ def test_ragged_to_dense_random(dtype, ml, seed): default_value = np.array(-1, dtype=dtype) - ans = mltest.run_op(ml, ml.device, True, ml.ops.ragged_to_dense, values, - row_splits, out_col_size, default_value) + ans = mltest.run_op(ml, + ml.device, + True, + ml.ops.ragged_to_dense, + values=values, + row_splits=row_splits, + out_col_size=out_col_size, + default_value=default_value) expected = np.full((row_splits.shape[0] - 1, out_col_size), default_value) for i in range(row_splits.shape[0] - 1):