diff --git a/.github/workflows/torchao_experimental_test.yml b/.github/workflows/torchao_experimental_test.yml index 9fb52fb760..4f3fe41a86 100644 --- a/.github/workflows/torchao_experimental_test.yml +++ b/.github/workflows/torchao_experimental_test.yml @@ -54,28 +54,29 @@ jobs: python torchao/experimental/tests/test_embedding_xbit_quantizer.py python torchao/experimental/tests/test_quant_passes.py pytest -s test/prototype/test_dynamic_activation_lut.py - - name: Run kernels/cpu/aarch64/tests + - name: torchao/csrc/cpu - build and run C++ tests if: runner.os == 'macOS' run: | conda activate venv - pushd torchao/experimental/kernels/cpu/aarch64/tests + pushd torchao/csrc/cpu sh build_and_run_tests.sh - rm -rf /tmp/cmake-out + rm -rf cmake-out popd - - name: Run torchao/experimental/ops/tests + - name: torchao/csrc/cpu - build benchmarks if: runner.os == 'macOS' run: | conda activate venv - pushd torchao/experimental/ops/tests - sh build_and_run_tests.sh - rm -rf /tmp/cmake-out + pushd torchao/csrc/cpu + sh build_and_run_benchmarks.sh build_only + rm -rf cmake-out popd - - name: ET ops build + - name: torchao/csrc/cpu - build shared_kernels with ExecuTorch if: runner.os == 'macOS' run: | conda activate venv - pushd torchao/experimental - sh build_torchao_ops.sh executorch + pushd torchao/csrc/cpu + sh build_shared_kernels.sh executorch + rm -rf cmake-out popd # test-mps-ops: diff --git a/setup.py b/setup.py index d9d7e80506..668c72d4cc 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ def read_version(file_path="version.txt"): # ├── USE_CPU_KERNELS="1" + Linux → Include optimized CPU kernels (AVX512, etc.) 
# └── ARM64 + macOS → Auto-enable experimental builds (build_macos_arm_auto) # -# Level 3: Experimental builds (cmake-based) +# Level 3: Shared CPU kernel builds (cmake-based) # ├── BUILD_TORCHAO_EXPERIMENTAL="1" → Force experimental builds # ├── build_macos_arm_auto → Auto-enable on ARM64 macOS # └── When enabled, provides access to: @@ -322,6 +322,19 @@ def build_cmake(self, ext): ext_filename = os.path.basename(self.get_ext_filename(ext.name)) ext_basename = os.path.splitext(ext_filename)[0] + print( + "CMAKE COMMANG", + [ + "cmake", + ext.cmake_lists_dir, + ] + + ext.cmake_args + + [ + "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, + "-DTORCHAO_CMAKE_EXT_SO_NAME=" + ext_basename, + ], + ) + subprocess.check_call( [ "cmake", @@ -472,10 +485,22 @@ def get_extensions(): # Collect C++ source files sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True)) + + # Exclude C++ CPU sources that are built by CMake + cpu_cmake_sources = glob.glob( + os.path.join(extensions_dir, "cpu", "torch_free_kernels", "**", "*.cpp"), + recursive=True, + ) + cpu_cmake_sources += glob.glob( + os.path.join(extensions_dir, "cpu", "shared_kernels", "**", "*.cpp"), + recursive=True, + ) + sources = [s for s in sources if s not in cpu_cmake_sources] + if not use_cpu_kernels or not is_linux: # Remove csrc/cpu/*.cpp excluded_sources = list( - glob.glob(os.path.join(extensions_dir, "cpu/*.cpp"), recursive=True) + glob.glob(os.path.join(extensions_dir, "cpu/*.cpp"), recursive=False) ) sources = [s for s in sources if s not in excluded_sources] @@ -616,6 +641,7 @@ def get_extensions(): ext_modules = [] if len(sources) > 0: + print("SOURCES", sources) # Double-check to ensure mx_fp_cutlass_kernels.cu is not in sources sources = [ s for s in sources if os.path.basename(s) != "mx_fp_cutlass_kernels.cu" @@ -703,7 +729,7 @@ def get_extensions(): ) ) - # Build CMakeLists from /torchao/experimental - additional options become available : TORCHAO_BUILD_CPU_AARCH64, 
TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND + # Build CMakeLists from /torchao/csrc/cpu - additional options become available : TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND if build_macos_arm_auto or os.getenv("BUILD_TORCHAO_EXPERIMENTAL") == "1": build_options = BuildOptions() @@ -716,24 +742,20 @@ def bool_to_on_off(value): ext_modules.append( CMakeExtension( - "torchao._experimental_aten_ops", - cmake_lists_dir="torchao/experimental", + "torchao._C_cpu_shared_kernels_aten", + cmake_lists_dir="torchao/csrc/cpu", cmake_args=( [ f"-DCMAKE_BUILD_TYPE={'Debug' if use_debug_mode() else 'Release'}", f"-DTORCHAO_BUILD_CPU_AARCH64={bool_to_on_off(build_options.build_cpu_aarch64)}", f"-DTORCHAO_BUILD_KLEIDIAI={bool_to_on_off(build_options.build_kleidi_ai)}", - f"-DTORCHAO_BUILD_MPS_OPS={bool_to_on_off(build_options.build_experimental_mps)}", f"-DTORCHAO_ENABLE_ARM_NEON_DOT={bool_to_on_off(build_options.enable_arm_neon_dot)}", f"-DTORCHAO_ENABLE_ARM_I8MM={bool_to_on_off(build_options.enable_arm_i8mm)}", f"-DTORCHAO_PARALLEL_BACKEND={build_options.parallel_backend}", + "-DTORCHAO_BUILD_TESTS=OFF", + "-DTORCHAO_BUILD_BENCHMARKS=OFF", "-DTorch_DIR=" + torch_dir, ] - + ( - ["-DCMAKE_INSTALL_PREFIX=cmake-out"] - if build_options.build_experimental_mps - else [] - ) ), ) ) diff --git a/torchao/__init__.py b/torchao/__init__.py index c6b7f92f50..032fcd861c 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -31,11 +31,7 @@ torch.ops.load_library(str(file)) from . import ops - # The following library contains CPU kernels from torchao/experimental - # They are built automatically by ao/setup.py if on an ARM machine. 
- # They can also be built outside of the torchao install process by - # running the script `torchao/experimental/build_torchao_ops.sh ` - # For more information, see https://github.com/pytorch/ao/blob/main/torchao/experimental/docs/readme.md + # The following import registers meta kernels for experimental ops from torchao.experimental.op_lib import * # noqa: F403 except Exception as e: logger.debug(f"Skipping import of cpp extensions: {e}") diff --git a/torchao/csrc/cpu/CMakeLists.txt b/torchao/csrc/cpu/CMakeLists.txt new file mode 100644 index 0000000000..aaea27ec74 --- /dev/null +++ b/torchao/csrc/cpu/CMakeLists.txt @@ -0,0 +1,232 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) +include(CMakeDependentOption) + +project(torchao) + +set(CMAKE_CXX_STANDARD 17) + +if (NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + +# Platform options +option(TORCHAO_BUILD_ATEN_OPS "Building torchao ops for ATen." ON) +option(TORCHAO_BUILD_EXECUTORCH_OPS "Building torchao ops for ExecuTorch." 
OFF) +option(TORCHAO_BUILD_CPU_AARCH64 "Build torchao's CPU aarch64 kernels" OFF) +option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF) +option(TORCHAO_ENABLE_ARM_NEON_DOT "Enable ARM Neon Dot Product extension" OFF) +option(TORCHAO_ENABLE_ARM_I8MM "Enable ARM 8-bit Integer Matrix Multiply instructions" OFF) +option(TORCHAO_BUILD_TESTS "Build tests" OFF) +option(TORCHAO_BUILD_BENCHMARKS "Build benchmarks" OFF) + +# Set default compiler options +add_compile_options("-fPIC" "-Wall" "-Werror" "-Wno-deprecated") +if (CMAKE_SYSTEM_NAME STREQUAL "Linux") + add_compile_options( + "-Wno-error=unknown-pragmas" + "-Wno-array-parameter" + "-Wno-maybe-uninitialized" + "-Wno-sign-compare" + ) +elseif (APPLE) + add_compile_options("-Wno-shorten-64-to-32") +endif() + + + +if (NOT TARGET cpuinfo) + cmake_policy(PUSH) + cmake_policy(VERSION 3.5) # cpuinfo requires CMake 3.5 + + # For some reason cpuinfo package has unused functions/variables + # TODO (T215533422): fix upstream + add_compile_options(-Wno-unused-function -Wno-unused-variable) + + # set(CMAKE_POLICY_VERSION_MINIMUM 3.5) + include(FetchContent) + set(CPUINFO_BUILD_UNIT_TESTS OFF CACHE BOOL "" FORCE) + set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE BOOL "" FORCE) + set(CPUINFO_BUILD_BENCHMARKS OFF CACHE BOOL "" FORCE) + FetchContent_Declare(cpuinfo + GIT_REPOSITORY https://github.com/pytorch/cpuinfo.git + GIT_TAG c61fe919607bbc534d7a5a5707bdd7041e72c5ff + ) + FetchContent_MakeAvailable( + cpuinfo) + + cmake_policy(POP) +endif() + +if (TORCHAO_BUILD_TESTS) + include(FetchContent) + FetchContent_Declare( + googletest + URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip + ) + FetchContent_MakeAvailable(googletest) +endif() + +if (TORCHAO_BUILD_BENCHMARKS) + include(FetchContent) + FetchContent_Declare(googlebenchmark + GIT_REPOSITORY https://github.com/google/benchmark.git + GIT_TAG main) # need main for benchmark::benchmark + + 
set(BENCHMARK_ENABLE_TESTING OFF) + FetchContent_MakeAvailable( + googlebenchmark) +endif() + +if(NOT TORCHAO_INCLUDE_DIRS) + set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../..) +endif() + +if(NOT DEFINED TORCHAO_PARALLEL_BACKEND) + set(TORCHAO_PARALLEL_BACKEND aten_openmp) +endif() + +# Set default compiler options + +include(CMakePrintHelpers) +include(${CMAKE_CURRENT_SOURCE_DIR}/shared_kernels/Utils.cmake) + +message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}") +include_directories(${TORCHAO_INCLUDE_DIRS}) + + +# Build fallback kernels +add_subdirectory(torch_free_kernels/fallback) + +# Build cpu/aarch64 kernels +if(TORCHAO_BUILD_CPU_AARCH64) + message(STATUS "Building with cpu/aarch64") + add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64) + + if(TORCHAO_ENABLE_ARM_NEON_DOT) + message(STATUS "Building with ARM NEON dot product support") + add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT) + add_compile_options("-march=armv8.4-a+dotprod") + endif() + + if(TORCHAO_ENABLE_ARM_I8MM) + message(STATUS "Building with ARM I8MM support") + add_compile_definitions(TORCHAO_ENABLE_ARM_I8MM) + add_compile_options("-march=armv8.6-a") + endif() + + if(TORCHAO_BUILD_KLEIDIAI) + message(STATUS "Building with Arm KleidiAI library") + add_compile_definitions(TORCHAO_ENABLE_KLEIDI) + if (NOT TARGET kleidiai) + include(FetchContent) + # KleidiAI is an open-source library that provides optimized + # performance-critical routines, also known as micro-kernels, for artificial + # intelligence (AI) workloads tailored for Arm® CPUs. 
+ set(KLEIDIAI_BUILD_TESTS OFF CACHE BOOL "" FORCE) + set(KLEIDIAI_BUILD_BENCHMARKS OFF CACHE BOOL "" FORCE) + FetchContent_Declare(kleidiai + GIT_REPOSITORY https://git.gitlab.arm.com/kleidi/kleidiai.git + GIT_TAG v1.12.0 + ) + FetchContent_MakeAvailable(kleidiai) + endif() + endif() + + # Defines torchao_kernels_aarch64 + add_subdirectory(torch_free_kernels/aarch64) +endif() + +# Build ATen ops +if(TORCHAO_BUILD_ATEN_OPS) + find_package(Torch REQUIRED) + set(_torchao_op_srcs_aten) + list(APPEND _torchao_op_srcs_aten + shared_kernels/embedding_xbit/op_embedding_xbit_aten.cpp + shared_kernels/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp + shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp + shared_kernels/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp + shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp + ) + list(TRANSFORM _torchao_op_srcs_aten PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/") + + # Use the Python extension name if provided + add_library(torchao_ops_aten SHARED ${_torchao_op_srcs_aten}) + if(DEFINED TORCHAO_CMAKE_EXT_SO_NAME) + message(STATUS "Setting output name to: ${TORCHAO_CMAKE_EXT_SO_NAME}.so") + set_target_properties(torchao_ops_aten PROPERTIES + OUTPUT_NAME ${TORCHAO_CMAKE_EXT_SO_NAME} + PREFIX "" # Remove "lib" prefix for Python extensions + SUFFIX ".so" # Add ".so" suffix for Python extensions + ) + endif() + + target_link_torchao_parallel_backend(torchao_ops_aten "${TORCHAO_PARALLEL_BACKEND}") + if (TORCHAO_BUILD_CPU_AARCH64) + target_link_libraries(torchao_ops_aten PRIVATE torchao_kernels_aarch64) + if (TORCHAO_BUILD_KLEIDIAI) + target_link_libraries(torchao_ops_aten PRIVATE kleidiai) + endif() + endif() + target_link_libraries(torchao_ops_aten PRIVATE cpuinfo) + target_include_directories(torchao_ops_aten PRIVATE "${TORCH_INCLUDE_DIRS}") + target_link_libraries(torchao_ops_aten PRIVATE "${TORCH_LIBRARIES}") + target_compile_definitions(torchao_ops_aten 
PRIVATE TORCHAO_SHARED_KERNELS_BUILD_ATEN=1) + + if (TORCHAO_BUILD_TESTS) + add_subdirectory(shared_kernels/tests) + endif() + + if (TORCHAO_BUILD_BENCHMARKS) + add_subdirectory(shared_kernels/benchmarks) + endif() + + # Install ATen targets + install( + TARGETS torchao_ops_aten + EXPORT _targets + DESTINATION lib + ) +endif() + + +# Build ExecuTorch ops +if(TORCHAO_BUILD_EXECUTORCH_OPS) + # ExecuTorch package is not required, but EXECUTORCH_INCLUDE_DIRS and EXECUTORCH_LIBRARIES must + # be defined and EXECUTORCH_LIBRARIES must include the following libraries installed by ExecuTorch: + # libexecutorch.a + # libextension_threadpool.a + # libcpuinfo.a + # libpthreadpool.a + if(NOT DEFINED EXECUTORCH_INCLUDE_DIRS AND NOT DEFINED EXECUTORCH_LIBRARIES) + message(WARNING "EXECUTORCH_INCLUDE_DIRS and EXECUTORCH_LIBRARIES are not defined. Looking for ExecuTorch.") + find_package(ExecuTorch HINTS ${CMAKE_PREFIX_PATH}/executorch/share/cmake) + endif() + set(_torchao_op_srcs_executorch) + list(APPEND _torchao_op_srcs_executorch + shared_kernels/embedding_xbit/op_embedding_xbit_executorch.cpp + shared_kernels/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp + shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch.cpp + shared_kernels/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp + shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_executorch.cpp) + + list(TRANSFORM _torchao_op_srcs_executorch PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/") + add_library(torchao_ops_executorch STATIC ${_torchao_op_srcs_executorch}) + + target_compile_definitions(torchao_ops_executorch PRIVATE TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH=1) + + # This links to ExecuTorch + target_link_torchao_parallel_backend(torchao_ops_executorch executorch) + if (TORCHAO_BUILD_CPU_AARCH64) + target_link_libraries(torchao_ops_executorch PRIVATE torchao_kernels_aarch64) + if (TORCHAO_BUILD_KLEIDIAI) + target_link_libraries(torchao_ops_executorch 
PRIVATE kleidiai) + endif() + endif() + target_link_libraries(torchao_ops_executorch PRIVATE cpuinfo) +endif() diff --git a/torchao/csrc/cpu/build_and_run_benchmarks.sh b/torchao/csrc/cpu/build_and_run_benchmarks.sh new file mode 100644 index 0000000000..964fe9e5bf --- /dev/null +++ b/torchao/csrc/cpu/build_and_run_benchmarks.sh @@ -0,0 +1,38 @@ +set -eu + +if [[ $# -ne 1 ]]; then + echo "Usage: $0 "; + exit 1; +fi + +BENCHMARK_TYPE="${1}" + +export CMAKE_OUT=cmake-out + +export CMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())') +echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}" + +# Build +cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ + -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \ + -DTORCHAO_BUILD_EXECUTORCH_OPS=OFF \ + -DTORCHAO_BUILD_CPU_AARCH64=ON \ + -DTORCHAO_ENABLE_ARM_NEON_DOT=ON \ + -DCMAKE_BUILD_TYPE=Release \ + -DTORCHAO_BUILD_TESTS=OFF \ + -DTORCHAO_BUILD_BENCHMARKS=ON \ + -DOpenMP_ROOT=$(brew --prefix libomp) \ + -S . \ + -B ${CMAKE_OUT} +cmake --build ${CMAKE_OUT} -j 16 --config Release + + +# Run +TARGET_PREFIX="${CMAKE_OUT}/torch_free_kernels/aarch64/benchmarks/torchao_benchmarks_torch_free_kernels_aarch64_" +case "${BENCHMARK_TYPE}" in + build_only) echo "Build only"; exit 0; ;; + quantization) ${TARGET_PREFIX}benchmark_quantization; ;; + bitpacking) ${TARGET_PREFIX}benchmark_bitpacking; ;; + linear) ${TARGET_PREFIX}benchmark_linear; ;; + *) echo "Unknown benchmark: $1. Please specify quantization, bitpacking, or linear."; exit 1; ;; +esac diff --git a/torchao/csrc/cpu/build_and_run_tests.sh b/torchao/csrc/cpu/build_and_run_tests.sh new file mode 100644 index 0000000000..6d92a81d98 --- /dev/null +++ b/torchao/csrc/cpu/build_and_run_tests.sh @@ -0,0 +1,87 @@ +#!/bin/bash -eu +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +set -eu + + +target=${1:-"native"} +export CMAKE_OUT=cmake-out + +EXTRA_ARGS="" +if [[ "${target}" == "android" ]]; then + if [[ -z ${ANDROID_NDK} ]]; then + echo "Need to set ANDROID_NDK env variable to build for Android"; + exit 1; + fi + android_abi=arm64-v8a + android_platform=28 # must be >=28 for aligned_alloc + IS_ARM64=1 + BUILD_ARM_I8MM=1 # Hardcoded for now + CMAKE_OUT=${CMAKE_OUT/cmake-out/cmake-out-android} + toolchain_file="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" + if [[ -z ${toolchain_file} ]]; then + echo "Unable to find toolchain file at ANDROID_NDK location, looking for ${toolchain_file}" + exit 1; + fi + EXTRA_ARGS="\ + -DCMAKE_TOOLCHAIN_FILE=${toolchain_file} \ + -DANDROID_ABI=${android_abi} \ + -DANDROID_PLATFORM=${android_platform} + " + echo "Building tests for Android (${android_abi}) @ ${CMAKE_OUT}" +fi + + + + +export CMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())') +echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}" + + +cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ + -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \ + -DTORCHAO_BUILD_EXECUTORCH_OPS=OFF \ + -DTORCHAO_BUILD_CPU_AARCH64=ON \ + -DTORCHAO_ENABLE_ARM_NEON_DOT=ON \ + -DTORCHAO_BUILD_KLEIDIAI=ON \ + -DCMAKE_BUILD_TYPE=Debug \ + -DTORCHAO_BUILD_TESTS=ON \ + -S . \ + -B ${CMAKE_OUT} +cmake --build ${CMAKE_OUT} -j 16 --config Debug + + + +echo "Successfully built tests." 
+ +if [[ "${target}" != "native" ]]; then + echo "Skip running tests when cross compiling."; + exit 0; +fi + +# Torch-free aarch64 +TEST_TARGET_PREFIX="${CMAKE_OUT}/torch_free_kernels/aarch64/tests/torchao_tests_torch_free_kernels_aarch64_" +${TEST_TARGET_PREFIX}test_quantization +${TEST_TARGET_PREFIX}test_reduction +${TEST_TARGET_PREFIX}test_reduction +${TEST_TARGET_PREFIX}test_bitpacking +${TEST_TARGET_PREFIX}test_linear +${TEST_TARGET_PREFIX}test_embedding +${TEST_TARGET_PREFIX}test_weight_packing +${TEST_TARGET_PREFIX}test_qmatmul +${TEST_TARGET_PREFIX}test_lut +${TEST_TARGET_PREFIX}test_bitpack_fallback_compatibility +${TEST_TARGET_PREFIX}test_embedding_lut + +# Torch-free fallback +TEST_TARGET_PREFIX="${CMAKE_OUT}/torch_free_kernels/fallback/tests/torchao_tests_torch_free_kernels_fallback_" +${TEST_TARGET_PREFIX}test_bitpacking + +# Shared kernels +TEST_TARGET_PREFIX="${CMAKE_OUT}/shared_kernels/tests/torchao_tests_shared_kernels_" +${TEST_TARGET_PREFIX}test_linear_8bit_act_xbit_weight +${TEST_TARGET_PREFIX}test_groupwise_lowbit_weight_lut diff --git a/torchao/experimental/build_torchao_ops.sh b/torchao/csrc/cpu/build_shared_kernels.sh similarity index 93% rename from torchao/experimental/build_torchao_ops.sh rename to torchao/csrc/cpu/build_shared_kernels.sh index 1bcc1a9658..bfa9a55eef 100644 --- a/torchao/experimental/build_torchao_ops.sh +++ b/torchao/csrc/cpu/build_shared_kernels.sh @@ -23,6 +23,8 @@ cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ -DTORCHAO_BUILD_EXECUTORCH_OPS="${TORCHAO_BUILD_EXECUTORCH_OPS}" \ -DTORCHAO_BUILD_CPU_AARCH64=ON \ -DTORCHAO_ENABLE_ARM_NEON_DOT=ON \ + -DTORCHAO_BUILD_TESTS=OFF \ + -DTORCHAO_BUILD_BENCHMARKS=OFF \ -S . 
\ + -B ${CMAKE_OUT} cmake --build ${CMAKE_OUT} -j 16 --target install --config Release diff --git a/torchao/csrc/cpu/shared_kernels/README.md b/torchao/csrc/cpu/shared_kernels/README.md new file mode 100644 index 0000000000..37b4be6c7c --- /dev/null +++ b/torchao/csrc/cpu/shared_kernels/README.md @@ -0,0 +1,5 @@ +# Shared kernels + +This directory is for kernels that are shared between PyTorch/ATen and ExecuTorch. +Shared kernels are written with abstractions in internal/library.h. +These are compiled to either an ATen or ExecuTorch kernel based on compile flags. diff --git a/torchao/experimental/Utils.cmake b/torchao/csrc/cpu/shared_kernels/Utils.cmake similarity index 100% rename from torchao/experimental/Utils.cmake rename to torchao/csrc/cpu/shared_kernels/Utils.cmake diff --git a/torchao/csrc/cpu/shared_kernels/benchmarks/CMakeLists.txt b/torchao/csrc/cpu/shared_kernels/benchmarks/CMakeLists.txt new file mode 100644 index 0000000000..b5fd251a1f --- /dev/null +++ b/torchao/csrc/cpu/shared_kernels/benchmarks/CMakeLists.txt @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +project(torchao_benchmarks) +set(CMAKE_BUILD_TYPE Release) + +set(TARGET_PREFIX "torchao_benchmarks_shared_kernels_") + + +# TODO: fix benchmark. 
Got broken from refactor + +# add_executable(${TARGET_PREFIX}benchmark_linear_8bit_act_xbit_weight +# benchmark_linear_8bit_act_xbit_weight.cpp +# ${TORCHAO_INCLUDE_DIRS}/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp +# ) + +# target_link_torchao_parallel_backend(${TARGET_PREFIX}benchmark_linear_8bit_act_xbit_weight openmp) +# target_link_libraries( +# ${TARGET_PREFIX}benchmark_linear_8bit_act_xbit_weight +# PRIVATE +# benchmark::benchmark +# torchao_kernels_aarch64 +# ) diff --git a/torchao/experimental/ops/benchmarks/benchmark_linear_8bit_act_xbit_weight.cpp b/torchao/csrc/cpu/shared_kernels/benchmarks/benchmark_linear_8bit_act_xbit_weight.cpp similarity index 92% rename from torchao/experimental/ops/benchmarks/benchmark_linear_8bit_act_xbit_weight.cpp rename to torchao/csrc/cpu/shared_kernels/benchmarks/benchmark_linear_8bit_act_xbit_weight.cpp index 2efd425175..caf03acf21 100644 --- a/torchao/experimental/ops/benchmarks/benchmark_linear_8bit_act_xbit_weight.cpp +++ b/torchao/csrc/cpu/shared_kernels/benchmarks/benchmark_linear_8bit_act_xbit_weight.cpp @@ -5,11 +5,11 @@ // LICENSE file in the root directory of this source tree. 
#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #include using namespace torchao::ops::linear_8bit_act_xbit_weight; diff --git a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h b/torchao/csrc/cpu/shared_kernels/embedding_xbit/op_embedding_xbit-impl.h similarity index 87% rename from torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h rename to torchao/csrc/cpu/shared_kernels/embedding_xbit/op_embedding_xbit-impl.h index 8113a0566b..6c1181873b 100644 --- a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h +++ b/torchao/csrc/cpu/shared_kernels/embedding_xbit/op_embedding_xbit-impl.h @@ -7,14 +7,14 @@ #pragma once #if defined(TORCHAO_BUILD_CPU_AARCH64) -#include +#include #endif // TORCHAO_BUILD_CPU_AARCH64 -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include template void check_embedding_inputs( @@ -27,11 +27,11 @@ void check_embedding_inputs( int& group_size) { TORCHAO_CHECK( packed_weight_qvals.dim() == 1, "packed_weight_qvals must be 1D"); -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK( packed_weight_qvals.dtype() == torch::kInt8, "packed_weight_qvals must be byte"); -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK( (embedding_dim * weight_nbit) % 8 == 0, "embedding_dim * weight_nbit must be a multiple of 8"); @@ -53,11 +53,11 @@ void check_embedding_inputs( /*max_value_chunk_size=*/128), "packed_weights are not compatible with the kernel"); -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK( weight_scales.dtype() == torch::kFloat32, "weight_scales must be float32"); -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK(weight_scales.dim() == 2, "weight_scales must be 2D"); TORCHAO_CHECK( weight_scales.size(0) == num_embeddings, @@ -71,10 +71,10 @@ void check_embedding_inputs( group_size = embedding_dim / 
num_groups; TORCHAO_CHECK(group_size % 32 == 0, "group_size must be a multiple of 32"); -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK( weight_zeros.dtype() == torch::kInt8, "weight_zeros must be int8"); -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK(weight_zeros.dim() == 2, "weight_zeros must be 2D"); TORCHAO_CHECK( weight_zeros.size(0) == weight_scales.size(0) && @@ -88,7 +88,7 @@ void check_embedding_inputs( "indices must be int32 or int64"); } -#if defined(USE_ATEN) || defined(USE_EXECUTORCH) +#if defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) || defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) template Tensor embedding_out_cpu( const Tensor& packed_weight_qvals, @@ -149,9 +149,9 @@ Tensor embedding_out_cpu( return out; } -#endif // defined(USE_ATEN) || defined(USE_EXECUTORCH) +#endif // defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) || defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN template Tensor embedding_cpu( const Tensor& packed_weight_qvals, @@ -171,9 +171,9 @@ Tensor embedding_cpu( output_tensor); return output_tensor; } -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN template Tensor pack_embedding_cpu(const Tensor& weight_qvals) { TORCHAO_CHECK(weight_qvals.dim() == 2, "weight_qvals must be 2D"); @@ -213,9 +213,9 @@ Tensor pack_embedding_cpu(const Tensor& weight_qvals) { return out; } -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN template Tensor pack_embedding_meta(const Tensor& weight_qvals) { TORCHAO_CHECK(weight_qvals.dim() == 2, "weight_qvals must be 2D"); @@ -229,9 +229,9 @@ Tensor pack_embedding_meta(const Tensor& weight_qvals) { torchao::ops::PackedWeightsHeader::size() + (num_embeddings * packed_embedding_dim), options); } -#endif // USE_ATEN +#endif // 
TORCHAO_SHARED_KERNELS_BUILD_ATEN -#if defined(USE_ATEN) || defined(USE_EXECUTORCH) +#if defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) || defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) template Tensor shared_embedding_out_cpu( const Tensor& packed_weights, @@ -242,10 +242,10 @@ Tensor shared_embedding_out_cpu( Tensor& out) { // Check packed_weights are from linear op TORCHAO_CHECK(packed_weights.dim() == 1, "packed_weights must be 1D"); -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK( packed_weights.dtype() == torch::kInt8, "packed_weights must be int8"); -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK( packed_weights.size(0) >= torchao::ops::PackedWeightsHeader::size(), "packed_weights is not big enough to read the header."); @@ -308,7 +308,7 @@ Tensor shared_embedding_out_cpu( return out; } -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN template Tensor shared_embedding_cpu( const Tensor& packed_weights, @@ -321,6 +321,6 @@ Tensor shared_embedding_cpu( packed_weights, group_size, n, k, indices, output_tensor); return output_tensor; } -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN -#endif // defined(USE_ATEN) || defined(USE_EXECUTORCH) +#endif // defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) || defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) diff --git a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp b/torchao/csrc/cpu/shared_kernels/embedding_xbit/op_embedding_xbit_aten.cpp similarity index 98% rename from torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp rename to torchao/csrc/cpu/shared_kernels/embedding_xbit/op_embedding_xbit_aten.cpp index 318e648977..7129cd61c3 100644 --- a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp +++ b/torchao/csrc/cpu/shared_kernels/embedding_xbit/op_embedding_xbit_aten.cpp @@ -4,7 +4,7 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this 
source tree. -#include +#include #define DEFINE_OP(weight_nbit) \ m.def("_pack_embedding_" #weight_nbit "bit(Tensor weight_qvals) -> Tensor"); \ diff --git a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp b/torchao/csrc/cpu/shared_kernels/embedding_xbit/op_embedding_xbit_executorch.cpp similarity index 96% rename from torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp rename to torchao/csrc/cpu/shared_kernels/embedding_xbit/op_embedding_xbit_executorch.cpp index 2ffcba7e6b..0227f23327 100644 --- a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp +++ b/torchao/csrc/cpu/shared_kernels/embedding_xbit/op_embedding_xbit_executorch.cpp @@ -4,7 +4,7 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. -#include +#include #define DEFINE_OP(weight_nbit) \ Tensor _op_out_##weight_nbit( \ diff --git a/torchao/experimental/ops/embedding_xbit/packed_weights_header.h b/torchao/csrc/cpu/shared_kernels/embedding_xbit/packed_weights_header.h similarity index 85% rename from torchao/experimental/ops/embedding_xbit/packed_weights_header.h rename to torchao/csrc/cpu/shared_kernels/embedding_xbit/packed_weights_header.h index 8e47c2d1c0..addcd4181e 100644 --- a/torchao/experimental/ops/embedding_xbit/packed_weights_header.h +++ b/torchao/csrc/cpu/shared_kernels/embedding_xbit/packed_weights_header.h @@ -5,8 +5,8 @@ // LICENSE file in the root directory of this source tree. 
#pragma once -#include -#include +#include +#include namespace torchao::ops::embedding_xbit { diff --git a/torchao/experimental/ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp similarity index 96% rename from torchao/experimental/ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp rename to torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp index c0d452c95b..d6ffbc79e1 100644 --- a/torchao/experimental/ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp +++ b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp @@ -4,11 +4,11 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. -#include +#include -#include -#include -#include +#include +#include +#include #include #include #include diff --git a/torchao/experimental/ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.h b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.h similarity index 98% rename from torchao/experimental/ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.h rename to torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.h index f5293a3fc1..bb5624033b 100644 --- a/torchao/experimental/ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.h +++ b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.h @@ -5,7 +5,7 @@ // LICENSE file in the root directory of this source tree. 
#pragma once -#include +#include #include #include diff --git a/torchao/experimental/ops/groupwise_lowbit_weight_lut/kernel_config.h b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/kernel_config.h similarity index 99% rename from torchao/experimental/ops/groupwise_lowbit_weight_lut/kernel_config.h rename to torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/kernel_config.h index 6b3ab28310..1110e740e2 100644 --- a/torchao/experimental/ops/groupwise_lowbit_weight_lut/kernel_config.h +++ b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/kernel_config.h @@ -5,7 +5,7 @@ // LICENSE file in the root directory of this source tree. #pragma once -#include +#include #include #include #include diff --git a/torchao/experimental/ops/groupwise_lowbit_weight_lut/kernel_selector.h b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/kernel_selector.h similarity index 96% rename from torchao/experimental/ops/groupwise_lowbit_weight_lut/kernel_selector.h rename to torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/kernel_selector.h index e898ba5af4..f8bdc4cafb 100644 --- a/torchao/experimental/ops/groupwise_lowbit_weight_lut/kernel_selector.h +++ b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/kernel_selector.h @@ -6,14 +6,14 @@ #pragma once #include -#include -#include +#include +#include #include #include #include #if defined(TORCHAO_BUILD_CPU_AARCH64) -#include +#include #endif // TORCHAO_BUILD_CPU_AARCH64 namespace torchao::ops::groupwise_lowbit_weight_lut { diff --git a/torchao/experimental/ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut-impl.h b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut-impl.h similarity index 87% rename from torchao/experimental/ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut-impl.h rename to torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut-impl.h index 
f4e36870df..e3aca77844 100644 --- a/torchao/experimental/ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut-impl.h +++ b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut-impl.h @@ -6,16 +6,16 @@ #pragma once -#include -#include -#include -#include +#include +#include +#include +#include #include #include namespace { -#if defined(USE_ATEN) || defined(USE_EXECUTORCH) +#if defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) || defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) template Tensor linear_out_cpu( const Tensor& activations, @@ -29,10 +29,10 @@ Tensor linear_out_cpu( TORCHAO_CHECK(k >= 1, "k must be >= 1"); TORCHAO_CHECK(lut_group_size >= 1, "lut_group_size must be >= 1"); -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK( activations.dtype() == torch::kFloat32, "activations must be float32"); -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK(activations.dim() == 2, "activations must be 2D"); int m = activations.size(0); @@ -40,18 +40,18 @@ Tensor linear_out_cpu( TORCHAO_CHECK( k == k_, "activation shape is incompatible with packed weights."); -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK(out.dtype() == torch::kFloat32, "out must be float32"); -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN // Explicit cast from int64_t to int is required for Executorch TORCHAO_RESIZE_TENSOR(out, {(int)m, (int)n}); TORCHAO_CHECK(packed_weights.dim() == 1, "packed_weights must be 1D"); -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK( packed_weights.dtype() == torch::kInt8, "packed_weights must be int8"); -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK( packed_weights.size(0) >= torchao::ops::PackedWeightsHeader::size(), "packed_weights is not big enough to read the header."); @@ -80,9 +80,9 @@ Tensor linear_out_cpu( return out; } -#endif // defined(USE_ATEN) || defined(USE_EXECUTORCH) 
+#endif // defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) || defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN template Tensor linear_cpu( const Tensor& activations, @@ -102,9 +102,9 @@ Tensor linear_cpu( output_tensor); return output_tensor; } -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN template Tensor pack_weights_with_lut_cpu( const Tensor& weight_qval_idxs, @@ -195,9 +195,9 @@ Tensor pack_weights_with_lut_cpu( return packed_weights; } -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN template Tensor pack_weights_with_lut_meta( const Tensor& weight_qval_idxs, @@ -235,6 +235,6 @@ Tensor pack_weights_with_lut_meta( torch::TensorOptions().device(c10::DeviceType::Meta).dtype(torch::kInt8); return torch::empty({static_cast(packed_weight_data_size)}, options); } -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN } // namespace diff --git a/torchao/experimental/ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp similarity index 96% rename from torchao/experimental/ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp rename to torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp index 612ed4d656..c9b65f2152 100644 --- a/torchao/experimental/ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp +++ b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp @@ -4,7 +4,7 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. 
-#include +#include #define DEFINE_PACK_OP(weight_nbit) \ m.def( \ diff --git a/torchao/experimental/ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_executorch.cpp b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_executorch.cpp similarity index 93% rename from torchao/experimental/ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_executorch.cpp rename to torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_executorch.cpp index 42ae795fb9..d3e06dd538 100644 --- a/torchao/experimental/ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_executorch.cpp +++ b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_executorch.cpp @@ -1,4 +1,4 @@ -#include +#include #define DEFINE_OP(weight_nbit) \ Tensor _op_out_##weight_nbit( \ diff --git a/torchao/experimental/ops/groupwise_lowbit_weight_lut/packed_weights_format.h b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/packed_weights_format.h similarity index 97% rename from torchao/experimental/ops/groupwise_lowbit_weight_lut/packed_weights_format.h rename to torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/packed_weights_format.h index 9ea50425b7..d7c64fbebd 100644 --- a/torchao/experimental/ops/groupwise_lowbit_weight_lut/packed_weights_format.h +++ b/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/packed_weights_format.h @@ -6,7 +6,7 @@ #pragma once -#include +#include #include namespace torchao::ops::groupwise_lowbit_weight_lut { diff --git a/torchao/experimental/ops/library.h b/torchao/csrc/cpu/shared_kernels/internal/library.h similarity index 67% rename from torchao/experimental/ops/library.h rename to torchao/csrc/cpu/shared_kernels/internal/library.h index c518b31aee..204d97f5a7 100644 --- a/torchao/experimental/ops/library.h +++ b/torchao/csrc/cpu/shared_kernels/internal/library.h @@ -4,8 +4,8 @@ // This source code is 
licensed under the license found in the // LICENSE file in the root directory of this source tree. -#if defined(USE_ATEN) && !defined(USE_EXECUTORCH) -#pragma message("USE_ATEN") +#if defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) && !defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) +#pragma message("TORCHAO_SHARED_KERNELS_BUILD_ATEN") #include #include #include @@ -15,8 +15,8 @@ using Tensor = at::Tensor; #define TORCHAO_CHECK(cond, msg) TORCH_CHECK(cond, msg) #define TORCHAO_RESIZE_TENSOR(tensor, ...) tensor.resize_({__VA_ARGS__}) -#elif defined(USE_EXECUTORCH) && !defined(USE_ATEN) -#pragma message("USE_EXECUTORCH") +#elif defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) && !defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) +#pragma message("TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH") #include #include #include @@ -28,8 +28,8 @@ using RuntimeContext = torch::executor::KernelRuntimeContext; #define TORCHAO_RESIZE_TENSOR(tensor, ...) \ ET_CHECK_MSG(torch::executor::resize_tensor(tensor, {__VA_ARGS__}) == torch::executor::Error::Ok, "resize failed") -#elif !defined(USE_EXECUTORCH) && !defined(USE_ATEN) -#pragma message("Neither USE_ATEN or USE_EXECUTORCH defined") +#elif !defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) && !defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) +#pragma message("Neither TORCHAO_SHARED_KERNELS_BUILD_ATEN or TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH defined") #include #define TORCHAO_CHECK(cond, message) \ @@ -38,5 +38,5 @@ using RuntimeContext = torch::executor::KernelRuntimeContext; } #else -#error "Cannot define both USE_ATEN or USE_EXECUTORCH" +#error "Cannot define both TORCHAO_SHARED_KERNELS_BUILD_ATEN or TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH" #endif diff --git a/torchao/experimental/ops/memory.h b/torchao/csrc/cpu/shared_kernels/internal/memory.h similarity index 100% rename from torchao/experimental/ops/memory.h rename to torchao/csrc/cpu/shared_kernels/internal/memory.h diff --git a/torchao/experimental/ops/packed_weights_header.h 
b/torchao/csrc/cpu/shared_kernels/internal/packed_weights_header.h similarity index 100% rename from torchao/experimental/ops/packed_weights_header.h rename to torchao/csrc/cpu/shared_kernels/internal/packed_weights_header.h diff --git a/torchao/experimental/ops/parallel-aten-impl.h b/torchao/csrc/cpu/shared_kernels/internal/parallel-aten-impl.h similarity index 87% rename from torchao/experimental/ops/parallel-aten-impl.h rename to torchao/csrc/cpu/shared_kernels/internal/parallel-aten-impl.h index c2eb0b8498..9c825e48e5 100644 --- a/torchao/experimental/ops/parallel-aten-impl.h +++ b/torchao/csrc/cpu/shared_kernels/internal/parallel-aten-impl.h @@ -19,10 +19,6 @@ void torchao::parallel_1d(const int64_t begin, const int64_t end, const F& f) { }); } -inline void torchao::set_num_threads(int num_threads) { - torch::set_num_threads(num_threads); -} - inline int torchao::get_num_threads() { return torch::get_num_threads(); } diff --git a/torchao/experimental/ops/parallel-executorch-impl.h b/torchao/csrc/cpu/shared_kernels/internal/parallel-executorch-impl.h similarity index 80% rename from torchao/experimental/ops/parallel-executorch-impl.h rename to torchao/csrc/cpu/shared_kernels/internal/parallel-executorch-impl.h index 233f7250d4..01c8eb766f 100644 --- a/torchao/experimental/ops/parallel-executorch-impl.h +++ b/torchao/csrc/cpu/shared_kernels/internal/parallel-executorch-impl.h @@ -18,11 +18,6 @@ void torchao::parallel_1d(const int64_t begin, const int64_t end, const F& f) { end - begin); } -inline void torchao::set_num_threads(int num_threads) { - torch::executorch::threadpool::get_threadpool()->_unsafe_reset_threadpool( - num_threads); -} - inline int torchao::get_num_threads() { return torch::executorch::threadpool::get_threadpool()->get_thread_count(); } diff --git a/torchao/experimental/ops/parallel-openmp-impl.h b/torchao/csrc/cpu/shared_kernels/internal/parallel-openmp-impl.h similarity index 87% rename from torchao/experimental/ops/parallel-openmp-impl.h 
rename to torchao/csrc/cpu/shared_kernels/internal/parallel-openmp-impl.h index 236bb4e25f..e9b43653d2 100644 --- a/torchao/experimental/ops/parallel-openmp-impl.h +++ b/torchao/csrc/cpu/shared_kernels/internal/parallel-openmp-impl.h @@ -18,9 +18,6 @@ void torchao::parallel_1d(const int64_t begin, const int64_t end, const F& f) { } } -inline void torchao::set_num_threads(int num_threads) { - omp_set_num_threads(num_threads); -} inline int torchao::get_num_threads() { // omp_get_num_threads returns the number of threads // in the current code section, which will be 1 in the routines diff --git a/torchao/experimental/ops/parallel-pthreadpool-impl.h b/torchao/csrc/cpu/shared_kernels/internal/parallel-pthreadpool-impl.h similarity index 83% rename from torchao/experimental/ops/parallel-pthreadpool-impl.h rename to torchao/csrc/cpu/shared_kernels/internal/parallel-pthreadpool-impl.h index 9906cf4f3a..704349b59d 100644 --- a/torchao/experimental/ops/parallel-pthreadpool-impl.h +++ b/torchao/csrc/cpu/shared_kernels/internal/parallel-pthreadpool-impl.h @@ -33,13 +33,6 @@ class Threadpool { } return pthreadpool_get_threads_count(pthreadpool_); } - void set_num_threads(size_t num_threads) { - if (num_threads == get_num_threads()) { - return; - } - pthreadpool_destroy(pthreadpool_); - pthreadpool_ = pthreadpool_create(num_threads); - } }; template @@ -62,10 +55,6 @@ inline int torchao::get_num_threads() { return torchao::parallel::internal::threadpool.get_num_threads(); } -inline void torchao::set_num_threads(int num_threads) { - torchao::parallel::internal::threadpool.set_num_threads(num_threads); -} - template void torchao::parallel_1d(const int64_t begin, const int64_t end, const F& f) { auto context = torchao::parallel::internal::Context(f, begin); diff --git a/torchao/experimental/ops/parallel-single_threaded-impl.h b/torchao/csrc/cpu/shared_kernels/internal/parallel-single_threaded-impl.h similarity index 88% rename from 
torchao/experimental/ops/parallel-single_threaded-impl.h rename to torchao/csrc/cpu/shared_kernels/internal/parallel-single_threaded-impl.h index d9706829c2..74f067e39a 100644 --- a/torchao/experimental/ops/parallel-single_threaded-impl.h +++ b/torchao/csrc/cpu/shared_kernels/internal/parallel-single_threaded-impl.h @@ -13,7 +13,6 @@ void torchao::parallel_1d(const int64_t begin, const int64_t end, const F& f) { } } -inline void torchao::set_num_threads(int num_threads) {} inline int torchao::get_num_threads() { return 1; } diff --git a/torchao/experimental/ops/parallel-test_dummy-impl.h b/torchao/csrc/cpu/shared_kernels/internal/parallel-test_dummy-impl.h similarity index 86% rename from torchao/experimental/ops/parallel-test_dummy-impl.h rename to torchao/csrc/cpu/shared_kernels/internal/parallel-test_dummy-impl.h index de5a5f63ad..4a82cbd504 100644 --- a/torchao/experimental/ops/parallel-test_dummy-impl.h +++ b/torchao/csrc/cpu/shared_kernels/internal/parallel-test_dummy-impl.h @@ -15,9 +15,13 @@ void torchao::parallel_1d(const int64_t begin, const int64_t end, const F& f) { } } -inline void torchao::set_num_threads(int num_threads) { - torchao::parallel::internal::num_threads_test_dummy_ = num_threads; -} inline int torchao::get_num_threads() { return torchao::parallel::internal::num_threads_test_dummy_; } + + +namespace torchao::parallel { +inline void set_num_threads_in_test_dummy(int num_threads) { + torchao::parallel::internal::num_threads_test_dummy_ = num_threads; +} +} diff --git a/torchao/experimental/ops/parallel.h b/torchao/csrc/cpu/shared_kernels/internal/parallel.h similarity index 80% rename from torchao/experimental/ops/parallel.h rename to torchao/csrc/cpu/shared_kernels/internal/parallel.h index 5372c5a2dd..81f98b92c7 100644 --- a/torchao/experimental/ops/parallel.h +++ b/torchao/csrc/cpu/shared_kernels/internal/parallel.h @@ -12,8 +12,6 @@ namespace torchao { template void parallel_1d(const int64_t begin, const int64_t end, const F& f); -void 
set_num_threads(int num_threads); - int get_num_threads(); } // namespace torchao @@ -28,37 +26,37 @@ int get_num_threads(); #pragma message( \ "AT_PARALLEL_OPENMP is not set; TORCHAO_PARALLEL_ATEN may be single-threaded.") #endif -#include +#include #else #ifdef TORCHAO_PARALLEL_EXECUTORCH #pragma message( \ "TORCHAO_PARALLEL_EXECUTORCH is set. Using ExecuTorch parallel backend.") -#include +#include #else #ifdef TORCHAO_PARALLEL_PTHREADPOOL #pragma message( \ "TORCHAO_PARALLEL_PTHREADPOOL is set. Using pthreadpool parallel backend.") -#include +#include #else #ifdef TORCHAO_PARALLEL_OPENMP #pragma message( \ "TORCHAO_PARALLEL_OPENMP is set. Using OPENMP parallel backend.") -#include +#include #else #if defined TORCHAO_PARALLEL_SINGLE_THREADED #pragma message( \ "TORCHAO_PARALLEL_SINGLE_THREADED is set. Using single-threaded parallel backend.") -#include +#include #else #if defined TORCHAO_PARALLEL_TEST_DUMMY #pragma message( \ "TORCHAO_PARALLEL_TEST_DUMMY is set. Using test dummy parallel backend.") -#include +#include #else #error \ diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/kernel_config.h similarity index 98% rename from torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h rename to torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/kernel_config.h index b699bdd3d3..c54b8af090 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h +++ b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/kernel_config.h @@ -5,8 +5,8 @@ // LICENSE file in the root directory of this source tree. 
#pragma once -#include -#include +#include +#include #include #include diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/kernel_selector.h similarity index 97% rename from torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h rename to torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/kernel_selector.h index 2633920a51..d1bf056a43 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h +++ b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/kernel_selector.h @@ -6,19 +6,19 @@ #pragma once #include -#include -#include +#include +#include #include #include #include #if defined(TORCHAO_BUILD_CPU_AARCH64) #if defined(TORCHAO_ENABLE_ARM_NEON_DOT) -#include +#include #endif // TORCHAO_ENABLE_ARM_NEON_DOT #if defined(TORCHAO_ENABLE_KLEIDI) -#include +#include #endif // TORCHAO_ENABLE_KLEIDI #endif // TORCHAO_BUILD_CPU_AARCH64 diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp similarity index 96% rename from torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp rename to torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp index 8caffe4342..e95191d925 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp +++ b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp @@ -5,10 +5,10 @@ // LICENSE file in the root directory of this source tree. 
#include -#include -#include -#include -#include +#include +#include +#include +#include #include #include #include diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h similarity index 91% rename from torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h rename to torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h index 95e1640ad9..a148d3aa31 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h +++ b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h @@ -7,8 +7,8 @@ #pragma once #include #include -#include -#include +#include +#include #include #include diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h similarity index 90% rename from torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h rename to torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h index 08fa5c6d42..94df29d669 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h +++ b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h @@ -6,16 +6,16 @@ #pragma once -#include -#include -#include -#include +#include +#include +#include +#include #include #include namespace { -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN template Tensor pack_weights_cpu( const Tensor& weight_qvals, @@ -106,9 +106,9 @@ Tensor pack_weights_cpu( return packed_weights; } -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN template Tensor 
pack_weights_meta( const Tensor& weight_qvals, @@ -146,9 +146,9 @@ Tensor pack_weights_meta( torch::TensorOptions().device(c10::DeviceType::Meta).dtype(torch::kInt8); return torch::empty({static_cast(packed_weight_data_size)}, options); } -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN -#if defined(USE_ATEN) || defined(USE_EXECUTORCH) +#if defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) || defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) template Tensor linear_out_cpu( const Tensor& activations, @@ -161,10 +161,10 @@ Tensor linear_out_cpu( TORCHAO_CHECK(k >= 1, "k must be >= 1"); TORCHAO_CHECK(group_size >= 1, "group_size must be >= 1"); -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK( activations.dtype() == torch::kFloat32, "activations must be float32"); -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK(activations.dim() == 2, "activations must be 2D"); int m = activations.size(0); @@ -172,18 +172,18 @@ Tensor linear_out_cpu( TORCHAO_CHECK( k == k_, "activation shape is incompatible with packed weights."); -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK(out.dtype() == torch::kFloat32, "out must be float32"); -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN // Explicit cast from int64_t to int is required for Executorch TORCHAO_RESIZE_TENSOR(out, {(int)m, (int)n}); TORCHAO_CHECK(packed_weights.dim() == 1, "packed_weights must be 1D"); -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK( packed_weights.dtype() == torch::kInt8, "packed_weights must be int8"); -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN TORCHAO_CHECK( packed_weights.size(0) >= torchao::ops::PackedWeightsHeader::size(), "packed_weights is not big enough to read the header."); @@ -210,9 +210,9 @@ Tensor linear_out_cpu( return out; } -#endif // defined(USE_ATEN) || defined(USE_EXECUTORCH) +#endif // defined(TORCHAO_SHARED_KERNELS_BUILD_ATEN) || 
defined(TORCHAO_SHARED_KERNELS_BUILD_EXECUTORCH) -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN template Tensor linear_cpu( const Tensor& activations, @@ -225,9 +225,9 @@ Tensor linear_cpu( activations, packed_weights, group_size, n, k, output_tensor); return output_tensor; } -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN template Tensor pack_weights_with_lut_cpu( const Tensor& weight_qval_idxs, @@ -324,9 +324,9 @@ Tensor pack_weights_with_lut_cpu( return packed_weights; } -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN -#ifdef USE_ATEN +#ifdef TORCHAO_SHARED_KERNELS_BUILD_ATEN template Tensor pack_weights_with_lut_meta( const Tensor& weight_qval_idxs, @@ -361,6 +361,6 @@ Tensor pack_weights_with_lut_meta( torch::TensorOptions().device(c10::DeviceType::Meta).dtype(torch::kInt8); return torch::empty({static_cast(packed_weight_data_size)}, options); } -#endif // USE_ATEN +#endif // TORCHAO_SHARED_KERNELS_BUILD_ATEN } // namespace diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp similarity index 97% rename from torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp rename to torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp index 7e5799b5fd..466fd2567f 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp +++ b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp @@ -4,7 +4,7 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. 
-#include +#include #define DEFINE_OP(weight_nbit) \ m.def( \ diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch.cpp b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch.cpp similarity index 91% rename from torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch.cpp rename to torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch.cpp index 1275accbaa..78ccefecb7 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch.cpp +++ b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch.cpp @@ -1,4 +1,4 @@ -#include +#include #define DEFINE_OP(weight_nbit) \ Tensor _op_out_##weight_nbit( \ diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/packed_weights_format.h similarity index 96% rename from torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h rename to torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/packed_weights_format.h index e22082f9f1..e95593c13b 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h +++ b/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/packed_weights_format.h @@ -6,7 +6,7 @@ #pragma once -#include +#include namespace torchao::ops::linear_8bit_act_xbit_weight { diff --git a/torchao/csrc/cpu/shared_kernels/tests/CMakeLists.txt b/torchao/csrc/cpu/shared_kernels/tests/CMakeLists.txt new file mode 100644 index 0000000000..28bda6a1b8 --- /dev/null +++ b/torchao/csrc/cpu/shared_kernels/tests/CMakeLists.txt @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +project(torchao_tests) + +set(CMAKE_GTEST_DISCOVER_TESTS_DISCOVERY_MODE PRE_TEST) + +include_directories(${TORCHAO_INCLUDE_DIRS}) + +set(TEST_TARGET_PREFIX "torchao_tests_shared_kernels_") + +add_executable( + ${TEST_TARGET_PREFIX}test_linear_8bit_act_xbit_weight + test_linear_8bit_act_xbit_weight.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/csrc/cpu/shared_kernels/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp +) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_linear_8bit_act_xbit_weight + PRIVATE + GTest::gtest_main +) +if (TORCHAO_BUILD_CPU_AARCH64) + target_link_libraries( + ${TEST_TARGET_PREFIX}test_linear_8bit_act_xbit_weight + PRIVATE + torchao_kernels_aarch64 + ) +endif() +if (TORCHAO_BUILD_KLEIDIAI) + target_link_libraries( + ${TEST_TARGET_PREFIX}test_linear_8bit_act_xbit_weight + PRIVATE + kleidiai + ) +endif() +target_link_torchao_parallel_backend( ${TEST_TARGET_PREFIX}test_linear_8bit_act_xbit_weight test_dummy) + +add_executable( + ${TEST_TARGET_PREFIX}test_groupwise_lowbit_weight_lut + test_groupwise_lowbit_weight_lut.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/csrc/cpu/shared_kernels/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp +) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_groupwise_lowbit_weight_lut + PRIVATE + GTest::gtest_main +) +if (TORCHAO_BUILD_CPU_AARCH64) + target_link_libraries( + ${TEST_TARGET_PREFIX}test_groupwise_lowbit_weight_lut + PRIVATE + torchao_kernels_aarch64 + ) +endif() +target_link_torchao_parallel_backend(${TEST_TARGET_PREFIX}test_groupwise_lowbit_weight_lut test_dummy) + +include(GoogleTest) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_groupwise_lowbit_weight_lut) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_linear_8bit_act_xbit_weight) diff --git a/torchao/experimental/ops/tests/generate_tests.py b/torchao/csrc/cpu/shared_kernels/tests/generate_tests.py similarity 
index 100% rename from torchao/experimental/ops/tests/generate_tests.py rename to torchao/csrc/cpu/shared_kernels/tests/generate_tests.py diff --git a/torchao/experimental/ops/tests/test_groupwise_lowbit_weight_lut.cpp b/torchao/csrc/cpu/shared_kernels/tests/test_groupwise_lowbit_weight_lut.cpp similarity index 94% rename from torchao/experimental/ops/tests/test_groupwise_lowbit_weight_lut.cpp rename to torchao/csrc/cpu/shared_kernels/tests/test_groupwise_lowbit_weight_lut.cpp index a2a790a30b..10bf9bcd3c 100644 --- a/torchao/experimental/ops/tests/test_groupwise_lowbit_weight_lut.cpp +++ b/torchao/csrc/cpu/shared_kernels/tests/test_groupwise_lowbit_weight_lut.cpp @@ -6,12 +6,12 @@ #include #if defined(TORCHAO_BUILD_CPU_AARCH64) -#include +#include #endif // TORCHAO_BUILD_CPU_AARCH64 -#include -#include -#include -#include +#include +#include +#include +#include const float kTol = 1.0e-5; using namespace torchao::ops::groupwise_lowbit_weight_lut; @@ -86,7 +86,7 @@ void test_groupwise_lowbit_weight_lut( auto output = std::vector(m * n); for (auto num_threads : {1, 4, 500}) { - torchao::set_num_threads(num_threads); + torchao::parallel::set_num_threads_in_test_dummy(num_threads); EXPECT_EQ(torchao::get_num_threads(), num_threads); auto packed_weight_data_size = ukernel_config.packed_weights_size( n, diff --git a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp b/torchao/csrc/cpu/shared_kernels/tests/test_linear_8bit_act_xbit_weight.cpp similarity index 99% rename from torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp rename to torchao/csrc/cpu/shared_kernels/tests/test_linear_8bit_act_xbit_weight.cpp index 16c38aa8d3..7631d34a03 100644 --- a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp +++ b/torchao/csrc/cpu/shared_kernels/tests/test_linear_8bit_act_xbit_weight.cpp @@ -7,15 +7,15 @@ #include // TODO: move test_utils.h out of aarch64 #if defined(TORCHAO_BUILD_CPU_AARCH64) -#include +#include #endif // 
TORCHAO_BUILD_CPU_AARCH64 -#include -#include -#include -#include +#include +#include +#include +#include #if defined(TORCHAO_ENABLE_KLEIDI) -#include +#include using namespace torchao::kernels::cpu::aarch64::kleidi:: kai_matmul_clamp_f32_qai8dxp_qsi4c32p; #endif // TORCHAO_ENABLE_KLEIDI @@ -111,7 +111,7 @@ void test_linear_8bit_act_xbit_weight( auto output = std::vector(m * n); for (auto num_threads : {1, 4, 500}) { - torchao::set_num_threads(num_threads); + torchao::parallel::set_num_threads_in_test_dummy(num_threads); EXPECT_EQ(torchao::get_num_threads(), num_threads); // Pack weights diff --git a/torchao/csrc/cpu/torch_free_kernels/README.md b/torchao/csrc/cpu/torch_free_kernels/README.md new file mode 100644 index 0000000000..e1787bd980 --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/README.md @@ -0,0 +1,8 @@ +# Torch free kernels + +Kernels in this directory do not depend on Torch. Rather than use Tensor, they are written with raw pointers. These raw kernels are used by ATen/ExecuTorch kernels in torchao/csrc/cpu/shared_kernels. 
+ +Code is organized into subdirectories by CPU architecture: +* aarch64 (Arm) +* fallback (architecture-independent / generic C++) +* interface (high-level interface for fallback and architecture-specific code) diff --git a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt b/torchao/csrc/cpu/torch_free_kernels/aarch64/CMakeLists.txt similarity index 73% rename from torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt rename to torchao/csrc/cpu/torch_free_kernels/aarch64/CMakeLists.txt index dad1c91995..42f9cc82b7 100644 --- a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/CMakeLists.txt @@ -13,3 +13,11 @@ if (TORCHAO_BUILD_CPU_AARCH64) ${CMAKE_CURRENT_SOURCE_DIR}/valpacking/interleave.cpp ) endif() + +if (TORCHAO_BUILD_TESTS) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tests) +endif() + +if (TORCHAO_BUILD_BENCHMARKS) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/benchmarks) +endif() diff --git a/torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/CMakeLists.txt b/torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/CMakeLists.txt new file mode 100644 index 0000000000..d9d0480dfb --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/CMakeLists.txt @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +project(torchao_benchmarks) +set(CMAKE_BUILD_TYPE Release) + +set(TARGET_PREFIX "torchao_benchmarks_torch_free_kernels_aarch64_") + +add_library( + ${TARGET_PREFIX}dep + ${TORCHAO_INCLUDE_DIRS}/torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/find_min_and_max.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/compute_sum.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/csrc/cpu/torch_free_kernels/aarch64/quantization/quantize.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/csrc/cpu/torch_free_kernels/aarch64/valpacking/interleave.cpp +) + +add_executable(${TARGET_PREFIX}benchmark_quantization benchmark_quantization.cpp) +target_link_libraries( + ${TARGET_PREFIX}benchmark_quantization + PRIVATE + benchmark::benchmark + ${TARGET_PREFIX}dep +) + +add_executable(${TARGET_PREFIX}benchmark_bitpacking benchmark_bitpacking.cpp) +target_link_libraries( + ${TARGET_PREFIX}benchmark_bitpacking + PRIVATE + benchmark::benchmark + ${TARGET_PREFIX}dep +) + +# TODO: fix this, it's not working right now because of code refactors +# add_executable(${TARGET_PREFIX}benchmark_linear benchmark_linear.cpp) +# target_link_libraries( +# ${TARGET_PREFIX}benchmark_linear +# PRIVATE +# benchmark::benchmark +# ${TARGET_PREFIX}dep +# ) diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/benchmark_bitpacking.cpp similarity index 96% rename from torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/benchmark_bitpacking.cpp index a6bb8b478f..d31233b09b 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/benchmark_bitpacking.cpp @@ -9,15 +9,15 @@ #include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include #include namespace { diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/benchmark_linear.cpp similarity index 95% rename from torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/benchmark_linear.cpp index 4e9759ab2e..26abe6918a 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/benchmark_linear.cpp @@ -5,9 +5,9 @@ // LICENSE file in the root directory of this source tree. #include -#include -#include -#include +#include +#include +#include #include template @@ -92,7 +92,7 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot( int group_size = state.range(3); using namespace torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot; + channelwise_8bit_activation_groupwise_lowbit_weight; auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( @@ -164,7 +164,7 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot( int group_size = state.range(3); using namespace torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + channelwise_8bit_activation_groupwise_lowbit_weight; auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_quantization.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/benchmark_quantization.cpp similarity index 84% rename from torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_quantization.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/benchmark_quantization.cpp index 
7c81b963dc..d877b905d0 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_quantization.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/benchmarks/benchmark_quantization.cpp @@ -7,9 +7,9 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include -#include -#include +#include +#include +#include static void benchmark_quantize(benchmark::State& state) { int nbit = state.range(0); diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/bitpack.h similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/bitpack.h index ca5af62f33..01e8b85e1d 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/bitpack.h @@ -9,14 +9,14 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include #include namespace torchao { diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint1.h similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint1.h index de999a53d6..d24425745e 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint1.h @@ -8,7 +8,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include // This file contains bitpacking and unpacking methods for uint1. // These are not inteded to be used outside of bitpacking directory. 
diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint2.h similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint2.h index 630bc22798..b4874154e1 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint2.h @@ -9,7 +9,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include // This file contains bitpacking and unpacking methods for uint4. // These are not inteded to be used outside of bitpacking directory. diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint3.h similarity index 99% rename from torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint3.h index a808ee3a27..6063c12008 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint3.h @@ -9,7 +9,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include // This file contains bitpacking and unpacking methods for uint3. // These are not inteded to be used outside of bitpacking directory. 
diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint4.h similarity index 97% rename from torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint4.h index fba626ea57..2a36f3c429 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint4.h @@ -9,7 +9,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include // This file contains bitpacking and unpacking methods for uint4. // These are not inteded to be used outside of bitpacking directory. diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint5.h similarity index 99% rename from torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint5.h index 456706b76a..4771bab584 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint5.h @@ -9,7 +9,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include // This file contains bitpacking and unpacking methods for uint5. // These are not inteded to be used outside of bitpacking directory. 
diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint6.h similarity index 99% rename from torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint6.h index d15094ddfb..3ae83fab09 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint6.h @@ -9,7 +9,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include // This file contains bitpacking and unpacking methods for uint5. // These are not inteded to be used outside of bitpacking directory. diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint7.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint7.h similarity index 99% rename from torchao/experimental/kernels/cpu/aarch64/bitpacking/uint7.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint7.h index 1fc2a8d5cb..f1130c89bd 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint7.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/bitpacking/uint7.h @@ -9,7 +9,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include // This file contains bitpacking and unpacking methods for uint7. // These are not inteded to be used outside of bitpacking directory. 
diff --git a/torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/embedding/embedding.h similarity index 97% rename from torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/embedding/embedding.h index c750b6d534..0f6d8a2339 100644 --- a/torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/embedding/embedding.h @@ -9,9 +9,9 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include -#include -#include +#include +#include +#include #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/embedding/embedding_lut.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/embedding/embedding_lut.h similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/embedding/embedding_lut.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/embedding/embedding_lut.h index 1d551f9d2b..573fc8020d 100644 --- a/torchao/experimental/kernels/cpu/aarch64/embedding/embedding_lut.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/embedding/embedding_lut.h @@ -8,9 +8,9 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include -#include -#include +#include +#include +#include #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h similarity index 99% rename from torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h index aa338fc165..777d73cebc 100644 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h @@ -28,7 +28,7 @@ #include #endif // 
TORCHAO_ENABLE_ARM_I8MM -#include +#include namespace torchao::kernels::cpu::aarch64::kleidi { diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/kleidi/pack.h similarity index 100% rename from torchao/experimental/kernels/cpu/aarch64/kleidi/pack.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/kleidi/pack.h diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h similarity index 90% rename from torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h index ce0ac804c9..849d99cb8a 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/channelwise_8bit_activation_groupwise_lowbit_weight.h @@ -10,12 +10,12 @@ #include #include -#include -#include +#include +#include -#include -#include -#include +#include +#include +#include namespace torchao::kernels::cpu::aarch64::linear:: channelwise_8bit_activation_groupwise_lowbit_weight { diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x1x32_f32_neondot-impl.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x1x32_f32_neondot-impl.h similarity index 98% rename from 
torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x1x32_f32_neondot-impl.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x1x32_f32_neondot-impl.h index 1d48f6f2b0..535bf7a084 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x1x32_f32_neondot-impl.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x1x32_f32_neondot-impl.h @@ -8,7 +8,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) -#include +#include #include namespace torchao::kernels::cpu::aarch64::linear:: diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x4x16_f32_neondot-impl.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x4x16_f32_neondot-impl.h similarity index 99% rename from torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x4x16_f32_neondot-impl.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x4x16_f32_neondot-impl.h index e2bb78d385..40be2c5231 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x4x16_f32_neondot-impl.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x4x16_f32_neondot-impl.h @@ -8,7 +8,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) -#include +#include #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h 
b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h similarity index 99% rename from torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h index 7a53c7302c..78246e211d 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/kernel_1x8x16_f32_neondot-impl.h @@ -8,7 +8,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) -#include +#include #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_activations.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_activations.h similarity index 95% rename from torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_activations.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_activations.h index 5967c5b14e..d7558dd4ce 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_activations.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_activations.h @@ -8,8 +8,8 @@ #if defined(__aarch64__) || defined(__ARM_NEON) -#include -#include +#include +#include #include namespace torchao::kernels::cpu::aarch64::linear::channelwise_8bit_activation_groupwise_lowbit_weight::activation_packing { diff --git 
a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_weights.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_weights.h similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_weights.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_weights.h index 7412b795e7..133c4a7f25 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_weights.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_weights.h @@ -2,10 +2,10 @@ #if defined(__aarch64__) || defined(__ARM_NEON) -#include -#include -#include -#include +#include +#include +#include +#include #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/groupwise_lowbit_weight_lut.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/groupwise_lowbit_weight_lut.h similarity index 96% rename from torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/groupwise_lowbit_weight_lut.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/groupwise_lowbit_weight_lut.h index 897ec44549..b0fea65afb 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/groupwise_lowbit_weight_lut.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/groupwise_lowbit_weight_lut.h @@ -7,9 +7,9 @@ #include #include -#include -#include -#include +#include +#include +#include namespace torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut { diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/kernel_f32-impl.h 
b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/kernel_f32-impl.h similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/kernel_f32-impl.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/kernel_f32-impl.h index 3b97e54730..b50c886d11 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/kernel_f32-impl.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/kernel_f32-impl.h @@ -7,8 +7,8 @@ #if defined(aarch64) || defined(__ARM_NEON) #include -#include -#include +#include +#include #include #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/pack_activations.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/pack_activations.h similarity index 100% rename from torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/pack_activations.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/pack_activations.h diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/pack_weights.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/pack_weights.h similarity index 96% rename from torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/pack_weights.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/pack_weights.h index a219bcdfde..021693caec 100644 --- a/torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/pack_weights.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/linear/groupwise_lowbit_weight/pack_weights.h @@ -1,10 +1,10 @@ #pragma once #if defined(aarch64) || defined(__ARM_NEON) -#include -#include -#include -#include +#include +#include +#include +#include #include #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/lut/lut.h 
b/torchao/csrc/cpu/torch_free_kernels/aarch64/lut/lut.h similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/lut/lut.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/lut/lut.h index 6935412110..c8b76d979f 100644 --- a/torchao/experimental/kernels/cpu/aarch64/lut/lut.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/lut/lut.h @@ -9,7 +9,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include namespace torchao::lut { diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h similarity index 99% rename from torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h index 5ed3b686fd..925bbbb4bd 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h @@ -13,8 +13,8 @@ #include #include -#include -#include +#include +#include namespace torchao::kernels::cpu::aarch64::quantized_matmul { namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal::internal { diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h index 
c976be39f5..2c34cebc3c 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h @@ -13,8 +13,8 @@ #include #include -#include -#include +#include +#include namespace torchao::kernels::cpu::aarch64::quantized_matmul { namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::internal { diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h similarity index 99% rename from torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h index 19bde9dad9..80417f37e4 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h @@ -13,8 +13,8 @@ #include #include -#include -#include +#include +#include namespace torchao::kernels::cpu::aarch64::quantized_matmul { namespace channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot::internal { diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h index 4fc393fcaf..28f173e9bc 100644 --- 
a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h @@ -13,8 +13,8 @@ #include #include -#include -#include +#include +#include namespace torchao::kernels::cpu::aarch64::quantized_matmul { namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32::internal { diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_4x16x4_f32_impl.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/fp32_a_input_channelwise_8bit_b_4x16x4_f32_impl.h similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_4x16x4_f32_impl.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/fp32_a_input_channelwise_8bit_b_4x16x4_f32_impl.h index a3dd44a10b..ffcd0a1f1d 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_4x16x4_f32_impl.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/fp32_a_input_channelwise_8bit_b_4x16x4_f32_impl.h @@ -13,8 +13,8 @@ #include #include -#include -#include +#include +#include namespace torchao::kernels::cpu::aarch64::quantized_matmul { namespace fp32_a_input_channelwise_8bit_b_4x16x4_f32::internal { diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/matmul.h similarity index 91% rename from torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/matmul.h index 86b14a52aa..371dc55666 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/matmul.h @@ -5,7 +5,7 @@ // LICENSE file in the root directory of this source tree. 
// TODO: this file will be deleted and replaced by -// torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/include.h +// torchao/csrc/cpu/torch_free_kernels/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/include.h // It exists now to prevent breaking existing code in the interim. #pragma once @@ -309,10 +309,10 @@ void kernel( } // namespace fp32_a_input_channelwise_8bit_b_f32 } // namespace torchao::kernels::cpu::aarch64::quantized_matmul -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #endif // defined(__aarch64__) && defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/matmul_utils.h similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/matmul_utils.h index 0a3c8463a8..db577c39a8 100644 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/matmul/matmul_utils.h @@ -9,7 +9,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include +#include #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/packing/utils.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/packing/utils.h similarity index 100% rename from torchao/experimental/kernels/cpu/aarch64/packing/utils.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/packing/utils.h diff --git a/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/quantization/quantize.cpp similarity index 97% rename from torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/quantization/quantize.cpp index 3460d67fba..42301dc2fa 100644 --- 
a/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/quantization/quantize.cpp @@ -6,7 +6,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) -#include +#include #include #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/quantization/quantize.h similarity index 100% rename from torchao/experimental/kernels/cpu/aarch64/quantization/quantize.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/quantization/quantize.h diff --git a/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/compute_sum.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/compute_sum.cpp index 3a41307cb3..1b9d2aa97b 100644 --- a/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/compute_sum.cpp @@ -6,7 +6,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) -#include +#include #include int32_t torchao::kernels::cpu::aarch64::reduction::compute_sum( diff --git a/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/find_min_and_max.cpp similarity index 93% rename from torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/find_min_and_max.cpp index 89707eb0ac..ea4efcf1cc 100644 --- a/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/find_min_and_max.cpp @@ -6,7 +6,7 @@ #if defined(__aarch64__) || defined(__ARM_NEON) -#include +#include #include void torchao::kernels::cpu::aarch64::reduction::find_min_and_max( diff --git 
a/torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/reduction.h similarity index 100% rename from torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/reduction.h diff --git a/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/CMakeLists.txt b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/CMakeLists.txt new file mode 100644 index 0000000000..8d214b2e61 --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/CMakeLists.txt @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +project(torchao_tests) + + # Delay test discovery till runtime. Useful for cross-compiling. +set(CMAKE_GTEST_DISCOVER_TESTS_DISCOVERY_MODE PRE_TEST) + +set(TEST_TARGET_PREFIX "torchao_tests_torch_free_kernels_aarch64_") + +add_library( + ${TEST_TARGET_PREFIX}dep + ${TORCHAO_INCLUDE_DIRS}/torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/find_min_and_max.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/csrc/cpu/torch_free_kernels/aarch64/reduction/compute_sum.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/csrc/cpu/torch_free_kernels/aarch64/quantization/quantize.cpp +) + +enable_testing() + +add_executable(${TEST_TARGET_PREFIX}test_quantization test_quantization.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_quantization + PRIVATE + GTest::gtest_main + ${TEST_TARGET_PREFIX}dep +) + +add_executable(${TEST_TARGET_PREFIX}test_reduction test_reduction.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_reduction + PRIVATE + GTest::gtest_main + ${TEST_TARGET_PREFIX}dep +) + +add_executable(${TEST_TARGET_PREFIX}test_bitpacking test_bitpacking.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_bitpacking + PRIVATE + GTest::gtest_main + ${TEST_TARGET_PREFIX}dep +) + 
+add_executable(${TEST_TARGET_PREFIX}test_linear test_linear.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_linear + PRIVATE + GTest::gtest_main + ${TEST_TARGET_PREFIX}dep + torchao_kernels_aarch64 +) + +add_executable(${TEST_TARGET_PREFIX}test_embedding_lut test_embedding_lut.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_embedding_lut + PRIVATE + GTest::gtest_main + ${TEST_TARGET_PREFIX}dep +) + +add_executable(${TEST_TARGET_PREFIX}test_embedding test_embedding.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_embedding + PRIVATE + GTest::gtest_main + ${TEST_TARGET_PREFIX}dep +) + +add_executable(${TEST_TARGET_PREFIX}test_weight_packing test_weight_packing.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_weight_packing + PRIVATE + GTest::gtest_main + ${TEST_TARGET_PREFIX}dep +) + +add_executable(${TEST_TARGET_PREFIX}test_qmatmul test_qmatmul.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_qmatmul + PRIVATE + GTest::gtest_main + ${TEST_TARGET_PREFIX}dep +) + +add_executable(${TEST_TARGET_PREFIX}test_lut test_lut.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_lut + PRIVATE + GTest::gtest_main + ${TEST_TARGET_PREFIX}dep +) + +add_executable(${TEST_TARGET_PREFIX}test_bitpack_fallback_compatibility test_bitpack_fallback_compatibility.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_bitpack_fallback_compatibility + PRIVATE + GTest::gtest_main + ${TEST_TARGET_PREFIX}dep +) + +include(GoogleTest) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_quantization) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_reduction) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_bitpacking) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_linear) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_embedding) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_embedding_lut) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_weight_packing) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_qmatmul) 
+gtest_discover_tests(${TEST_TARGET_PREFIX}test_lut) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_bitpack_fallback_compatibility) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpack_fallback_compatibility.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_bitpack_fallback_compatibility.cpp similarity index 99% rename from torchao/experimental/kernels/cpu/aarch64/tests/test_bitpack_fallback_compatibility.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_bitpack_fallback_compatibility.cpp index d0a8622b36..ccae74cbcd 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpack_fallback_compatibility.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_bitpack_fallback_compatibility.cpp @@ -8,9 +8,9 @@ #include #include -#include -#include -#include +#include +#include +#include // --- Compatibility Tests for uint1 --- diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_bitpacking.cpp similarity index 97% rename from torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_bitpacking.cpp index 93e68eb86c..d052ae1d47 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_bitpacking.cpp @@ -8,15 +8,15 @@ #include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include TEST(test_bitpacking_8_uint1_values, PackUnpackAreSame) { diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_embedding.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_embedding.cpp similarity index 96% rename from torchao/experimental/kernels/cpu/aarch64/tests/test_embedding.cpp rename to 
torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_embedding.cpp index 8fe7e69574..e5cdfb0a1b 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_embedding.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_embedding.cpp @@ -7,9 +7,9 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include -#include -#include +#include +#include +#include #include float kTol = 0.0001; diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_embedding_lut.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_embedding_lut.cpp similarity index 96% rename from torchao/experimental/kernels/cpu/aarch64/tests/test_embedding_lut.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_embedding_lut.cpp index 23ef66b9e8..5802a179d0 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_embedding_lut.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_embedding_lut.cpp @@ -7,8 +7,8 @@ #if defined(__aarch64__) || defined(__ARM_NEON) #include -#include -#include +#include +#include #include float kTol = 0.0001; diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_linear.cpp similarity index 97% rename from torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_linear.cpp index 6d6101e3cf..bf99823052 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_linear.cpp @@ -10,9 +10,9 @@ #include #include -#include -#include -#include +#include +#include +#include float kTol = 0.0001; diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_lut.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_lut.cpp similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/tests/test_lut.cpp rename to 
torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_lut.cpp index 6cd9ee8dfa..6d9214eeba 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_lut.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_lut.cpp @@ -8,9 +8,9 @@ #include #include -#include -#include -#include +#include +#include +#include #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_qmatmul.cpp similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_qmatmul.cpp index 18c9986393..5d46937ccf 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_qmatmul.cpp @@ -10,9 +10,9 @@ #include #include -#include -#include -#include +#include +#include +#include float kTol = 0.0001; diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_quantization.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_quantization.cpp similarity index 92% rename from torchao/experimental/kernels/cpu/aarch64/tests/test_quantization.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_quantization.cpp index bb19528de7..ebe3fbdfa8 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_quantization.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_quantization.cpp @@ -8,8 +8,8 @@ #include #include -#include -#include +#include +#include #include // Demonstrate some basic assertions. 
diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_reduction.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_reduction.cpp similarity index 93% rename from torchao/experimental/kernels/cpu/aarch64/tests/test_reduction.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_reduction.cpp index 0720f2dcf8..44dbafafa5 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_reduction.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_reduction.cpp @@ -8,8 +8,8 @@ #include #include -#include -#include +#include +#include #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_utils.h similarity index 95% rename from torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_utils.h index 3bdf5df8c0..e5742d3f56 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_utils.h @@ -8,61 +8,15 @@ #if defined(__aarch64__) || defined(__ARM_NEON) -#include -#include +#include +#include +#include #include #include #include #include namespace torchao { -inline std::vector -get_random_vector(int size, float min = -1.0, float max = 1.0) { - assert(min < max); - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto dist = std::bind(std::uniform_real_distribution(min, max), rng); - std::vector res(size); - std::generate(res.begin(), res.end(), std::ref(dist)); - return res; -} - -inline std::vector get_random_lowbit_vector(int size, int nbit) { - assert(nbit >= 1); - assert(nbit <= 8); - - int min = 0; - int max = (1 << nbit) - 1; - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto dist = std::bind(std::uniform_int_distribution<>(min, max), rng); - - std::vector res(size); - std::generate(res.begin(), res.end(), 
std::ref(dist)); - return res; -} - -inline std::vector get_random_signed_lowbit_vector(int size, int nbit) { - assert(nbit >= 1); - assert(nbit <= 8); - - int min = 0; - int max = (1 << nbit) - 1; - int offset = (1 << (nbit - 1)); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto dist = std::bind(std::uniform_int_distribution<>(min, max), rng); - - std::vector res(size); - std::vector tmp(size); - std::generate(tmp.begin(), tmp.end(), std::ref(dist)); - for (int i = 0; i < size; i++) { - res[i] = tmp[i] - offset; - } - return res; -} // TODO move these to a common utils inline uint16_t get_bf16_from_float(float f) { @@ -612,13 +566,11 @@ struct lut_embedding_test_case { weight_scales(weight_scales_), weight_luts(weight_luts_), expected_outputs(expected_outputs_) { - const int total_weights = num_embeddings * embedding_dim; - assert(total_weights % lut_group_size == 0); + assert((num_embeddings * embedding_dim) % lut_group_size == 0); assert(embedding_dim % scale_group_size == 0); assert(this->weight_qval_idxs.size() == num_embeddings * embedding_dim); - const int scales_per_row = embedding_dim / scale_group_size; if (has_scales) { - assert(this->weight_scales.size() == num_embeddings * scales_per_row); + assert(this->weight_scales.size() == num_embeddings * (embedding_dim / scale_group_size)); } assert(this->expected_outputs.size() == num_embeddings * embedding_dim); } diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_utils_quantized_attention.h similarity index 98% rename from torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_utils_quantized_attention.h index 52fb0851bc..ba6fb83069 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h +++ 
b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_utils_quantized_attention.h @@ -8,9 +8,9 @@ #if defined(__aarch64__) || defined(__ARM_NEON) -#include -#include -#include +#include +#include +#include #include #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_weight_packing.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_weight_packing.cpp similarity index 95% rename from torchao/experimental/kernels/cpu/aarch64/tests/test_weight_packing.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_weight_packing.cpp index fba4fba391..b64d4b2754 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_weight_packing.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/tests/test_weight_packing.cpp @@ -5,8 +5,8 @@ // LICENSE file in the root directory of this source tree. #include -#include -#include +#include +#include template void test_weight_packing( diff --git a/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp b/torchao/csrc/cpu/torch_free_kernels/aarch64/valpacking/interleave.cpp similarity index 97% rename from torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp rename to torchao/csrc/cpu/torch_free_kernels/aarch64/valpacking/interleave.cpp index 0274b0889e..3818fac2d0 100644 --- a/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/aarch64/valpacking/interleave.cpp @@ -4,7 +4,7 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. 
-#include +#include #include #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h b/torchao/csrc/cpu/torch_free_kernels/aarch64/valpacking/valpack.h similarity index 100% rename from torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h rename to torchao/csrc/cpu/torch_free_kernels/aarch64/valpacking/valpack.h diff --git a/torchao/experimental/kernels/cpu/fallback/CMakeLists.txt b/torchao/csrc/cpu/torch_free_kernels/fallback/CMakeLists.txt similarity index 69% rename from torchao/experimental/kernels/cpu/fallback/CMakeLists.txt rename to torchao/csrc/cpu/torch_free_kernels/fallback/CMakeLists.txt index 0952fcc3f5..bf488ffab5 100644 --- a/torchao/experimental/kernels/cpu/fallback/CMakeLists.txt +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/CMakeLists.txt @@ -3,3 +3,7 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. + +if (TORCHAO_BUILD_TESTS) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tests) +endif() diff --git a/torchao/experimental/kernels/cpu/fallback/bitpacking/bitpack.h b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/bitpack.h similarity index 91% rename from torchao/experimental/kernels/cpu/fallback/bitpacking/bitpack.h rename to torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/bitpack.h index 1a558d27ac..c28c6ec90d 100644 --- a/torchao/experimental/kernels/cpu/fallback/bitpacking/bitpack.h +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/bitpack.h @@ -6,14 +6,14 @@ #pragma once -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include #include namespace torchao::kernels::cpu::fallback::bitpacking { diff --git a/torchao/experimental/kernels/cpu/fallback/bitpacking/uint1.h b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint1.h similarity index 98% rename from 
torchao/experimental/kernels/cpu/fallback/bitpacking/uint1.h rename to torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint1.h index 67d4512a2c..08e231716b 100644 --- a/torchao/experimental/kernels/cpu/fallback/bitpacking/uint1.h +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint1.h @@ -6,7 +6,7 @@ #pragma once -#include +#include #include namespace torchao::kernels::cpu::fallback::bitpacking { diff --git a/torchao/experimental/kernels/cpu/fallback/bitpacking/uint2.h b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint2.h similarity index 98% rename from torchao/experimental/kernels/cpu/fallback/bitpacking/uint2.h rename to torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint2.h index 2681110348..9dc1cce463 100644 --- a/torchao/experimental/kernels/cpu/fallback/bitpacking/uint2.h +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint2.h @@ -6,7 +6,7 @@ #pragma once -#include +#include #include namespace torchao::kernels::cpu::fallback::bitpacking { namespace internal { diff --git a/torchao/experimental/kernels/cpu/fallback/bitpacking/uint3.h b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint3.h similarity index 99% rename from torchao/experimental/kernels/cpu/fallback/bitpacking/uint3.h rename to torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint3.h index 635e1bca6c..277317d5a2 100644 --- a/torchao/experimental/kernels/cpu/fallback/bitpacking/uint3.h +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint3.h @@ -6,7 +6,7 @@ #pragma once -#include +#include #include namespace torchao::kernels::cpu::fallback::bitpacking { diff --git a/torchao/experimental/kernels/cpu/fallback/bitpacking/uint4.h b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint4.h similarity index 98% rename from torchao/experimental/kernels/cpu/fallback/bitpacking/uint4.h rename to torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint4.h index 27be9488d7..4b98a47143 100644 --- 
a/torchao/experimental/kernels/cpu/fallback/bitpacking/uint4.h +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint4.h @@ -6,7 +6,7 @@ #pragma once -#include +#include #include namespace torchao::kernels::cpu::fallback::bitpacking { diff --git a/torchao/experimental/kernels/cpu/fallback/bitpacking/uint5.h b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint5.h similarity index 99% rename from torchao/experimental/kernels/cpu/fallback/bitpacking/uint5.h rename to torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint5.h index 2ad408a75a..3de577e05f 100644 --- a/torchao/experimental/kernels/cpu/fallback/bitpacking/uint5.h +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint5.h @@ -6,7 +6,7 @@ #pragma once -#include +#include #include namespace torchao::kernels::cpu::fallback::bitpacking { diff --git a/torchao/experimental/kernels/cpu/fallback/bitpacking/uint6.h b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint6.h similarity index 98% rename from torchao/experimental/kernels/cpu/fallback/bitpacking/uint6.h rename to torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint6.h index 65325b030d..2fcd9334ec 100644 --- a/torchao/experimental/kernels/cpu/fallback/bitpacking/uint6.h +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint6.h @@ -6,7 +6,7 @@ #pragma once -#include +#include #include namespace torchao::kernels::cpu::fallback::bitpacking { diff --git a/torchao/experimental/kernels/cpu/fallback/bitpacking/uint7.h b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint7.h similarity index 98% rename from torchao/experimental/kernels/cpu/fallback/bitpacking/uint7.h rename to torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint7.h index ee4d501324..60493a20b2 100644 --- a/torchao/experimental/kernels/cpu/fallback/bitpacking/uint7.h +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/bitpacking/uint7.h @@ -6,7 +6,7 @@ #pragma once -#include +#include #include namespace 
torchao::kernels::cpu::fallback::bitpacking { diff --git a/torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h b/torchao/csrc/cpu/torch_free_kernels/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h similarity index 100% rename from torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h rename to torchao/csrc/cpu/torch_free_kernels/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h diff --git a/torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h b/torchao/csrc/cpu/torch_free_kernels/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h similarity index 100% rename from torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h rename to torchao/csrc/cpu/torch_free_kernels/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h diff --git a/torchao/csrc/cpu/torch_free_kernels/fallback/tests/CMakeLists.txt b/torchao/csrc/cpu/torch_free_kernels/fallback/tests/CMakeLists.txt new file mode 100644 index 0000000000..eab4f9e54b --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/tests/CMakeLists.txt @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +project(torchao_tests) + +set(TEST_TARGET_PREFIX "torchao_tests_torch_free_kernels_fallback_") + +enable_testing() + +add_executable(${TEST_TARGET_PREFIX}test_bitpacking test_bitpacking.cpp) +target_link_libraries( + ${TEST_TARGET_PREFIX}test_bitpacking + PRIVATE + GTest::gtest_main +) + +include(GoogleTest) +gtest_discover_tests(${TEST_TARGET_PREFIX}test_bitpacking) diff --git a/torchao/experimental/kernels/cpu/fallback/tests/test_bitpacking.cpp b/torchao/csrc/cpu/torch_free_kernels/fallback/tests/test_bitpacking.cpp similarity index 92% rename from torchao/experimental/kernels/cpu/fallback/tests/test_bitpacking.cpp rename to torchao/csrc/cpu/torch_free_kernels/fallback/tests/test_bitpacking.cpp index 980f1a1cbe..32177e63da 100644 --- a/torchao/experimental/kernels/cpu/fallback/tests/test_bitpacking.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/fallback/tests/test_bitpacking.cpp @@ -5,15 +5,15 @@ // LICENSE file in the root directory of this source tree. // test pack with cpp unpack with arm_neon #include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include TEST(FallbackBitpackingTest, PackUnpack8_uint1) { diff --git a/torchao/experimental/kernels/cpu/interface/quantized_matmul.h b/torchao/csrc/cpu/torch_free_kernels/interface/quantized_matmul.h similarity index 94% rename from torchao/experimental/kernels/cpu/interface/quantized_matmul.h rename to torchao/csrc/cpu/torch_free_kernels/interface/quantized_matmul.h index 826fe9e85b..da3fd32747 100644 --- a/torchao/experimental/kernels/cpu/interface/quantized_matmul.h +++ b/torchao/csrc/cpu/torch_free_kernels/interface/quantized_matmul.h @@ -8,11 +8,11 @@ #include -#include -#include +#include +#include #if defined(__aarch64__) && defined(__ARM_NEON) -#include +#include #endif // defined(__aarch64__) && defined(__ARM_NEON) namespace torchao::kernels::cpu::quantized_matmul { diff 
--git a/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp b/torchao/csrc/cpu/torch_free_kernels/interface/test_qmatmul_interface.cpp similarity index 99% rename from torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp rename to torchao/csrc/cpu/torch_free_kernels/interface/test_qmatmul_interface.cpp index 0fbe33ccdc..5ce1593732 100644 --- a/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp +++ b/torchao/csrc/cpu/torch_free_kernels/interface/test_qmatmul_interface.cpp @@ -11,7 +11,7 @@ #include #include -#include +#include float kTol = 0.0001; diff --git a/torchao/experimental/kernels/cpu/aarch64/macro.h b/torchao/csrc/cpu/torch_free_kernels/macro.h similarity index 100% rename from torchao/experimental/kernels/cpu/aarch64/macro.h rename to torchao/csrc/cpu/torch_free_kernels/macro.h diff --git a/torchao/csrc/cpu/torch_free_kernels/test_utils.h b/torchao/csrc/cpu/torch_free_kernels/test_utils.h new file mode 100644 index 0000000000..29b72b51c0 --- /dev/null +++ b/torchao/csrc/cpu/torch_free_kernels/test_utils.h @@ -0,0 +1,62 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. 
+ +#pragma once + +#include +#include +#include +#include + +namespace torchao { +inline std::vector +get_random_vector(int size, float min = -1.0, float max = 1.0) { + assert(min < max); + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto dist = std::bind(std::uniform_real_distribution(min, max), rng); + std::vector res(size); + std::generate(res.begin(), res.end(), std::ref(dist)); + return res; +} + +inline std::vector get_random_lowbit_vector(int size, int nbit) { + assert(nbit >= 1); + assert(nbit <= 8); + + int min = 0; + int max = (1 << nbit) - 1; + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto dist = std::bind(std::uniform_int_distribution<>(min, max), rng); + + std::vector res(size); + std::generate(res.begin(), res.end(), std::ref(dist)); + return res; +} + +inline std::vector get_random_signed_lowbit_vector(int size, int nbit) { + assert(nbit >= 1); + assert(nbit <= 8); + + int min = 0; + int max = (1 << nbit) - 1; + int offset = (1 << (nbit - 1)); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto dist = std::bind(std::uniform_int_distribution<>(min, max), rng); + + std::vector res(size); + std::vector tmp(size); + std::generate(tmp.begin(), tmp.end(), std::ref(dist)); + for (int i = 0; i < size; i++) { + res[i] = tmp[i] - offset; + } + return res; +} +} // namespace torchao diff --git a/torchao/experimental/CMakeLists.txt b/torchao/experimental/CMakeLists.txt index 317b35643b..84582f704e 100644 --- a/torchao/experimental/CMakeLists.txt +++ b/torchao/experimental/CMakeLists.txt @@ -17,12 +17,7 @@ endif() # Platform options option(TORCHAO_BUILD_ATEN_OPS "Building torchao ops for ATen." ON) -option(TORCHAO_BUILD_EXECUTORCH_OPS "Building torchao ops for ExecuTorch." 
OFF) option(TORCHAO_BUILD_MPS_OPS "Building torchao MPS ops" OFF) -option(TORCHAO_BUILD_CPU_AARCH64 "Build torchao's CPU aarch64 kernels" OFF) -option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF) -option(TORCHAO_ENABLE_ARM_NEON_DOT "Enable ARM Neon Dot Product extension" OFF) -option(TORCHAO_ENABLE_ARM_I8MM "Enable ARM 8-bit Integer Matrix Multiply instructions" OFF) if(NOT TORCHAO_INCLUDE_DIRS) set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../..) @@ -36,129 +31,17 @@ endif() add_compile_options("-Wall" "-Werror" "-Wno-deprecated" "-Wno-shorten-64-to-32") include(CMakePrintHelpers) -include(${CMAKE_CURRENT_SOURCE_DIR}/Utils.cmake) message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}") include_directories(${TORCHAO_INCLUDE_DIRS}) -# Build cpu/aarch64 kernels -if(TORCHAO_BUILD_CPU_AARCH64) - message(STATUS "Building with cpu/aarch64") - add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64) - - # Set aarch64 compiler options - if (CMAKE_SYSTEM_NAME STREQUAL "Linux") - message(STATUS "Add aarch64 linux compiler options") - add_compile_options( - "-fPIC" - "-Wno-error=unknown-pragmas" - "-Wno-array-parameter" - "-Wno-maybe-uninitialized" - "-Wno-sign-compare" - ) - - # Since versions are hierarchical (each includes features from prior versions): - # - dotprod is included by default in armv8.4-a and later - # - i8mm is included by default in armv8.6-a and later - if(TORCHAO_ENABLE_ARM_I8MM) - message(STATUS "Using armv8.6-a (includes 'i8mm' and 'dotprod' flags)") - add_compile_options("-march=armv8.6-a") - elseif(TORCHAO_ENABLE_ARM_NEON_DOT) - message(STATUS "Using armv8.4-a (includes '+dotprod' flag)") - add_compile_options("-march=armv8.4-a") - endif() - endif() - - if(TORCHAO_ENABLE_ARM_NEON_DOT) - message(STATUS "Building with ARM NEON dot product support") - add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT) - add_compile_options("-march=armv8.4-a+dotprod") - endif() - - if(TORCHAO_ENABLE_ARM_I8MM) - 
message(STATUS "Building with ARM I8MM support") - add_compile_definitions(TORCHAO_ENABLE_ARM_I8MM) - endif() - - if(TORCHAO_BUILD_KLEIDIAI) - message(STATUS "Building with Arm KleidiAI library") - add_compile_definitions(TORCHAO_ENABLE_KLEIDI) - endif() - - # Defines torchao_kernels_aarch64 - add_subdirectory(kernels/cpu/aarch64) -endif() - -if (NOT TARGET cpuinfo) - # For some reason cpuinfo package has unused functions/variables - # TODO (T215533422): fix upstream - add_compile_options(-Wno-unused-function -Wno-unused-variable) - set(CMAKE_POLICY_VERSION_MINIMUM 3.5) - include(FetchContent) - set(CPUINFO_BUILD_UNIT_TESTS OFF CACHE BOOL "" FORCE) - set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE BOOL "" FORCE) - set(CPUINFO_BUILD_BENCHMARKS OFF CACHE BOOL "" FORCE) - FetchContent_Declare(cpuinfo - GIT_REPOSITORY https://github.com/pytorch/cpuinfo.git - GIT_TAG c61fe919607bbc534d7a5a5707bdd7041e72c5ff - ) - FetchContent_MakeAvailable( - cpuinfo) -endif() - -if (TORCHAO_BUILD_KLEIDIAI) - if (NOT TARGET kleidiai) - include(FetchContent) - # KleidiAI is an open-source library that provides optimized - # performance-critical routines, also known as micro-kernels, for artificial - # intelligence (AI) workloads tailored for Arm® CPUs. 
- set(KLEIDIAI_BUILD_TESTS OFF CACHE BOOL "" FORCE) - set(KLEIDIAI_BUILD_BENCHMARKS OFF CACHE BOOL "" FORCE) - FetchContent_Declare(kleidiai - GIT_REPOSITORY https://git.gitlab.arm.com/kleidi/kleidiai.git - GIT_TAG v1.12.0 - ) - FetchContent_MakeAvailable(kleidiai) - endif() -endif() - - # Build ATen ops if(TORCHAO_BUILD_ATEN_OPS) find_package(Torch REQUIRED) - set(_torchao_op_srcs_aten) - list(APPEND _torchao_op_srcs_aten - ops/embedding_xbit/op_embedding_xbit_aten.cpp - ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp - ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp - ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp - ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_aten.cpp - ) - list(TRANSFORM _torchao_op_srcs_aten PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/") # Use the Python extension name if provided - add_library(torchao_ops_aten SHARED ${_torchao_op_srcs_aten}) - if(DEFINED TORCHAO_CMAKE_EXT_SO_NAME) - message(STATUS "Setting output name to: ${TORCHAO_CMAKE_EXT_SO_NAME}.so") - set_target_properties(torchao_ops_aten PROPERTIES - OUTPUT_NAME ${TORCHAO_CMAKE_EXT_SO_NAME} - PREFIX "" # Remove "lib" prefix for Python extensions - SUFFIX ".so" # Add ".so" suffix for Python extensions - ) - endif() - - target_link_torchao_parallel_backend(torchao_ops_aten "${TORCHAO_PARALLEL_BACKEND}") - if (TORCHAO_BUILD_CPU_AARCH64) - target_link_libraries(torchao_ops_aten PRIVATE torchao_kernels_aarch64) - if (TORCHAO_BUILD_KLEIDIAI) - target_link_libraries(torchao_ops_aten PRIVATE kleidiai) - endif() - endif() - target_link_libraries(torchao_ops_aten PRIVATE cpuinfo) - target_include_directories(torchao_ops_aten PRIVATE "${TORCH_INCLUDE_DIRS}") - target_link_libraries(torchao_ops_aten PRIVATE "${TORCH_LIBRARIES}") - target_compile_definitions(torchao_ops_aten PRIVATE USE_ATEN=1) + add_library(torchao_ops_aten SHARED) # Add MPS support if enabled if (TORCHAO_BUILD_MPS_OPS) @@ -174,40 +57,3 @@ 
if(TORCHAO_BUILD_ATEN_OPS) DESTINATION lib ) endif() - - -# Build ExecuTorch ops -if(TORCHAO_BUILD_EXECUTORCH_OPS) - # ExecuTorch package is not required, but EXECUTORCH_INCLUDE_DIRS and EXECUTORCH_LIBRARIES must - # be defined and EXECUTORCH_LIBRARIES must include the following libraries installed by ExecuTorch: - # libexecutorch.a - # libextension_threadpool.a - # libcpuinfo.a - # libpthreadpool.a - if(NOT DEFINED EXECUTORCH_INCLUDE_DIRS AND NOT DEFINED EXECUTORCH_LIBRARIES) - message(WARNING "EXECUTORCH_INCLUDE_DIRS and EXECUTORCH_LIBRARIES are not defined. Looking for ExecuTorch.") - find_package(ExecuTorch HINTS ${CMAKE_PREFIX_PATH}/executorch/share/cmake) - endif() - set(_torchao_op_srcs_executorch) - list(APPEND _torchao_op_srcs_executorch - ops/embedding_xbit/op_embedding_xbit_executorch.cpp - ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp - ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch.cpp - ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp - ops/groupwise_lowbit_weight_lut/op_groupwise_lowbit_weight_lut_executorch.cpp) - - list(TRANSFORM _torchao_op_srcs_executorch PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/") - add_library(torchao_ops_executorch STATIC ${_torchao_op_srcs_executorch}) - - target_compile_definitions(torchao_ops_executorch PRIVATE USE_EXECUTORCH=1) - - # This links to ExecuTorch - target_link_torchao_parallel_backend(torchao_ops_executorch executorch) - if (TORCHAO_BUILD_CPU_AARCH64) - target_link_libraries(torchao_ops_executorch PRIVATE torchao_kernels_aarch64) - if (TORCHAO_BUILD_KLEIDIAI) - target_link_libraries(torchao_ops_executorch PRIVATE kleidiai) - endif() - endif() - target_link_libraries(torchao_ops_executorch PRIVATE cpuinfo) -endif() diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/benchmarks/CMakeLists.txt deleted file mode 100644 index 5227ff1090..0000000000 --- 
a/torchao/experimental/kernels/cpu/aarch64/benchmarks/CMakeLists.txt +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -cmake_minimum_required(VERSION 3.19) -project(benchmarks) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Release) - -include(FetchContent) -FetchContent_Declare(googlebenchmark - GIT_REPOSITORY https://github.com/google/benchmark.git - GIT_TAG main) # need main for benchmark::benchmark - -set(BENCHMARK_ENABLE_TESTING OFF) -FetchContent_MakeAvailable( - googlebenchmark) - -add_compile_options("-Wall" "-Werror") - -include(CMakePrintHelpers) -message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}") -include_directories(${TORCHAO_LIBRARIES}) - -add_library( - dep - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp -) - -add_executable(benchmark_quantization benchmark_quantization.cpp) -target_link_libraries( - benchmark_quantization - PRIVATE - benchmark::benchmark - dep -) - -add_executable(benchmark_bitpacking benchmark_bitpacking.cpp) -target_link_libraries( - benchmark_bitpacking - PRIVATE - benchmark::benchmark - dep -) - -add_executable(benchmark_linear benchmark_linear.cpp) -target_link_libraries( - benchmark_linear - PRIVATE - benchmark::benchmark - dep -) diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh b/torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh deleted file mode 100644 index e7fa9402e2..0000000000 --- 
a/torchao/experimental/kernels/cpu/aarch64/benchmarks/build_and_run_benchmarks.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash -eu -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -set -eu - -if [[ $# -ne 1 ]]; then - echo "Usage: $0 "; - exit 1; -fi - -BENCHMARK_TYPE="${1}" -SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) - -export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../.. -export CMAKE_OUT=/tmp/cmake-out/torch_ao/benchmarks - -# Build -cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ - -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/benchmarks \ - -B ${CMAKE_OUT} - -cmake --build ${CMAKE_OUT} - -# Run -case "${BENCHMARK_TYPE}" in - quantization) ${CMAKE_OUT}/benchmark_quantization; ;; - bitpacking) ${CMAKE_OUT}/benchmark_bitpacking; ;; - linear) ${CMAKE_OUT}/benchmark_linear; ;; - *) echo "Unknown benchmark: $1. Please specify quantization, bitpacking, or linear."; exit 1; ;; -esac diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt deleted file mode 100644 index c89141ac07..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -cmake_minimum_required(VERSION 3.19) -project(tests) -set(CMAKE_CXX_STANDARD 17) - -include(FetchContent) -FetchContent_Declare( - googletest - URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip -) -FetchContent_MakeAvailable(googletest) - -if (ANDROID_ABI) - # We are cross compiling, delay test discovery till runtime - set(CMAKE_GTEST_DISCOVER_TESTS_DISCOVERY_MODE PRE_TEST) -endif() - -add_compile_options("-Wall" "-Werror") - -include(CMakePrintHelpers) -message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}") -include_directories(${TORCHAO_LIBRARIES}) - -add_library( - dep - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp -) - -if(NOT TORCHAO_INCLUDE_DIRS) - set(TORCHAO_INCLUDE_DIRS ${TORCHAO_LIBRARIES}) -endif() - -add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) - -if(TORCHAO_BUILD_KLEIDIAI) - add_compile_definitions(TORCHAO_ENABLE_KLEIDI) - add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT) -endif() - -if(TORCHAO_BUILD_ARM_I8MM) - add_compile_definitions(TORCHAO_ENABLE_ARM_I8MM) -endif() - -enable_testing() - -if (ANDROID_ABI) - # Given where we are today this is sufficent. But needs to be revisited. - # This is also needed for native builds, but keeping it only for cross builds - # for now given the hacky nature. 
- file(GLOB DOTPROD_SRC_FILES test*.cpp) - message(SRC_FILES: ${DOTPROD_SRC_FILES}) - set_property(SOURCE - ${DOTPROD_SRC_FILES} - APPEND_STRING PROPERTY - COMPILE_FLAGS " -march=armv8.2-a+dotprod ") -endif() - -add_executable(test_quantization test_quantization.cpp) -target_link_libraries( - test_quantization - PRIVATE - GTest::gtest_main - dep -) - -add_executable(test_reduction test_reduction.cpp) -target_link_libraries( - test_reduction - PRIVATE - GTest::gtest_main - dep -) - -add_executable(test_bitpacking test_bitpacking.cpp) -target_link_libraries( - test_bitpacking - PRIVATE - GTest::gtest_main - dep -) - -add_executable(test_linear test_linear.cpp) -target_link_libraries( - test_linear - PRIVATE - GTest::gtest_main - dep - torchao_kernels_aarch64 -) - -add_executable(test_embedding_lut test_embedding_lut.cpp) -target_link_libraries( - test_embedding_lut - PRIVATE - GTest::gtest_main - dep -) - -add_executable(test_embedding test_embedding.cpp) -target_link_libraries( - test_embedding - PRIVATE - GTest::gtest_main - dep -) - -add_executable(test_weight_packing test_weight_packing.cpp) -target_link_libraries( - test_weight_packing - PRIVATE - GTest::gtest_main - dep -) - -add_executable(test_qmatmul test_qmatmul.cpp) -target_link_libraries( - test_qmatmul - PRIVATE - GTest::gtest_main - dep -) - -add_executable(test_lut test_lut.cpp) -target_link_libraries( - test_lut - PRIVATE - GTest::gtest_main - dep -) - -add_executable(test_bitpack_fallback_compatibility test_bitpack_fallback_compatibility.cpp) -target_link_libraries( - test_bitpack_fallback_compatibility - PRIVATE - GTest::gtest_main - dep -) - -include(GoogleTest) -gtest_discover_tests(test_quantization) -gtest_discover_tests(test_reduction) -gtest_discover_tests(test_bitpacking) -gtest_discover_tests(test_linear) -gtest_discover_tests(test_embedding) -gtest_discover_tests(test_embedding_lut) -gtest_discover_tests(test_weight_packing) -gtest_discover_tests(test_qmatmul) 
-gtest_discover_tests(test_lut) -gtest_discover_tests(test_bitpack_fallback_compatibility) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh deleted file mode 100644 index 768b5db5f3..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh +++ /dev/null @@ -1,66 +0,0 @@ -#!/bin/bash -eu -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -set -eu -SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../.. -export CMAKE_OUT=/tmp/cmake-out/torch_ao/kernel_tests - -target=${1:-"native"} - -EXTRA_ARGS="" -if [[ "${target}" == "android" ]]; then - if [[ -z ${ANDROID_NDK} ]]; then - echo "Need to set ANDROID_NDK env variable to build for Android"; - exit 1; - fi - android_abi=arm64-v8a - android_platform=28 # must be >=28 for aligned_alloc - IS_ARM64=1 - BUILD_ARM_I8MM=1 # Hardcoded for now - CMAKE_OUT=${CMAKE_OUT/cmake-out/cmake-out-android} - toolchain_file="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" - if [[ -z ${toolchain_file} ]]; then - echo "Unable to find toolchain file at ANDROID_NDK location, looking for ${toolchain_file}" - exit 1; - fi - EXTRA_ARGS="\ - -DCMAKE_TOOLCHAIN_FILE=${toolchain_file} \ - -DANDROID_ABI=${android_abi} \ - -DANDROID_PLATFORM=${android_platform} - " - echo "Building tests for Android (${android_abi}) @ ${CMAKE_OUT}" -fi - -cmake \ - ${EXTRA_ARGS} \ - -DCMAKE_BUILD_TYPE=Debug \ - -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ - -DTORCHAO_BUILD_CPU_AARCH64=ON \ - -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/tests \ - -B ${CMAKE_OUT} - -cmake --build ${CMAKE_OUT} - -echo "Successfully built tests." 
- -if [[ "${target}" != "native" ]]; then - echo "Skip running tests when cross compiling."; - exit 0; -fi - -# Run -${CMAKE_OUT}/test_quantization -${CMAKE_OUT}/test_reduction -${CMAKE_OUT}/test_bitpacking -${CMAKE_OUT}/test_linear -${CMAKE_OUT}/test_embedding -${CMAKE_OUT}/test_weight_packing -${CMAKE_OUT}/test_qmatmul -${CMAKE_OUT}/test_lut -${CMAKE_OUT}/test_bitpack_fallback_compatibility -${CMAKE_OUT}/test_embedding_lut diff --git a/torchao/experimental/kernels/cpu/fallback/tests/CMakeLists.txt b/torchao/experimental/kernels/cpu/fallback/tests/CMakeLists.txt deleted file mode 100644 index 652475766b..0000000000 --- a/torchao/experimental/kernels/cpu/fallback/tests/CMakeLists.txt +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -cmake_minimum_required(VERSION 3.19) -project(tests) -set(CMAKE_CXX_STANDARD 17) - -include(FetchContent) -FetchContent_Declare( - googletest - URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip -) -FetchContent_MakeAvailable(googletest) - -add_compile_options("-Wall" "-Werror") - -include(CMakePrintHelpers) -message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}") -include_directories(${TORCHAO_LIBRARIES}) -add_library( - dep - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp -) -if(NOT TORCHAO_INCLUDE_DIRS) - set(TORCHAO_INCLUDE_DIRS ${TORCHAO_LIBRARIES}) -endif() - -add_subdirectory( -${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/fallback -${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_cpu_fallback -) - -enable_testing() - -add_executable(test_bitpacking test_bitpacking.cpp) 
-target_link_libraries( - test_bitpacking - PRIVATE - GTest::gtest_main - dep -) - -include(GoogleTest) -gtest_discover_tests(test_bitpacking) diff --git a/torchao/experimental/kernels/cpu/fallback/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/fallback/tests/build_and_run_tests.sh deleted file mode 100644 index 69590512ec..0000000000 --- a/torchao/experimental/kernels/cpu/fallback/tests/build_and_run_tests.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash -eu -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -set -eu -SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../.. -export CMAKE_OUT=/tmp/cmake-out/torch_ao/kernel_fallback_tests - -target=${1:-"native"} - -EXTRA_ARGS="" - -cmake \ - ${EXTRA_ARGS} \ - -DCMAKE_BUILD_TYPE=Debug \ - -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ - -DTORCHAO_BUILD_CPU_AARCH64=ON \ - -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/fallback/tests \ - -B ${CMAKE_OUT} - -cmake --build ${CMAKE_OUT} - -echo "Successfully built tests." - -if [[ "${target}" != "native" ]]; then - echo "Skip running tests when cross compiling."; - exit 0; -fi - -# Run -${CMAKE_OUT}/test_bitpacking diff --git a/torchao/experimental/op_lib.py b/torchao/experimental/op_lib.py index e895858d55..771bbfc4ce 100644 --- a/torchao/experimental/op_lib.py +++ b/torchao/experimental/op_lib.py @@ -4,54 +4,10 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
-from pathlib import Path - import torch from torch import Tensor from torch.library import impl -# Load C++ ops - use multiple potential paths -potential_paths = [ - # Standard path from the module location - Path(__file__).parent.parent, - # Site-packages installation path - Path(torch.__file__).parent.parent / "torchao", - # For editable installs - Path(__file__).parent.parent.parent / "torchao", -] - - -def find_and_load_libtorchao_ops(potential_paths): - """ - Finds and loads torchao._experimental_aten_ops from one of the provided paths - """ - - for lib_path in potential_paths: - libs = list(lib_path.glob("_experimental_aten_ops.*")) - - if not libs: - continue - - assert len(libs) == 1, ( - f"Expected to find one _experimental_aten_ops.* library at {lib_path}, but found {len(libs)}" - ) - - target_lib = libs[0] - print(f"Found library at: {target_lib}") - - try: - torch.ops.load_library(str(target_lib)) - return - except Exception as e: - print(f"Error loading library from {target_lib}: {e}") - - raise FileNotFoundError( - "Could not find libtorchao_ops_aten library in any of the provided paths" - ) - - -find_and_load_libtorchao_ops(potential_paths) - # Define meta ops. To support dynamic shapes, some meta ops need to # be defined in python instead of C++. torchao_lib = torch.library.Library("torchao", "IMPL") diff --git a/torchao/experimental/ops/benchmarks/CMakeLists.txt b/torchao/experimental/ops/benchmarks/CMakeLists.txt deleted file mode 100644 index d06526cf84..0000000000 --- a/torchao/experimental/ops/benchmarks/CMakeLists.txt +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -cmake_minimum_required(VERSION 3.19) -project(benchmarks) - -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Release) -add_compile_options("-Wall" "-Werror") - -set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) -set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) - -include(FetchContent) -FetchContent_Declare(googlebenchmark - GIT_REPOSITORY https://github.com/google/benchmark.git - GIT_TAG main) # need main for benchmark::benchmark - -set(BENCHMARK_ENABLE_TESTING OFF) -FetchContent_MakeAvailable( - googlebenchmark) - -include_directories(${TORCHAO_INCLUDE_DIRS}) - -set(TORCHAO_PARALLEL_BACKEND "openmp") - -include(${TORCHAO_ROOT}/Utils.cmake) - -add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) - -add_executable(benchmark_linear_8bit_act_xbit_weight - benchmark_linear_8bit_act_xbit_weight.cpp - ${TORCHAO_ROOT}/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp -) -target_link_torchao_parallel_backend(benchmark_linear_8bit_act_xbit_weight "${TORCHAO_PARALLEL_BACKEND}") -target_link_libraries( - benchmark_linear_8bit_act_xbit_weight - PRIVATE - benchmark::benchmark - torchao_kernels_aarch64 -) diff --git a/torchao/experimental/ops/benchmarks/build_and_run_benchmarks.sh b/torchao/experimental/ops/benchmarks/build_and_run_benchmarks.sh deleted file mode 100644 index b837b36fe4..0000000000 --- a/torchao/experimental/ops/benchmarks/build_and_run_benchmarks.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -# Call script with sh build_and_run_benchmarks.sh {BENCHAMRK} - -export CMAKE_OUT=/tmp/cmake-out/torchao/benchmarks -cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ - -S . 
\ - -B ${CMAKE_OUT} \ - -DOpenMP_ROOT=$(brew --prefix libomp) \ - -DTORCHAO_PARALLEL_OMP=ON - -cmake --build ${CMAKE_OUT} - -# Run -${CMAKE_OUT}/benchmark_linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/CMakeLists.txt b/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/CMakeLists.txt deleted file mode 100644 index 7ba8d20c6d..0000000000 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/CMakeLists.txt +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -project(examples) - -cmake_minimum_required(VERSION 3.19) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Release) - -include(CMakePrintHelpers) - -set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) -set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..) - -include_directories(${TORCHAO_INCLUDE_DIRS}) - -set(TORCHAO_PARALLEL_BACKEND "openmp") -add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) - -include(${TORCHAO_ROOT}/Utils.cmake) - -add_executable(separate_function_wrappers - separate_function_wrappers.cpp - ${TORCHAO_ROOT}/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp -) -target_link_libraries( - separate_function_wrappers - PRIVATE - torchao_kernels_aarch64 -) -target_link_torchao_parallel_backend(separate_function_wrappers "${TORCHAO_PARALLEL_BACKEND}") - -add_executable(stateful_class_wrapper - stateful_class_wrapper.cpp - ${TORCHAO_ROOT}/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp -) -target_link_libraries( - stateful_class_wrapper - PRIVATE - torchao_kernels_aarch64 -) -target_link_torchao_parallel_backend(stateful_class_wrapper "${TORCHAO_PARALLEL_BACKEND}") diff --git 
a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/Linear8BitActXBitWeightOperator.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/Linear8BitActXBitWeightOperator.h deleted file mode 100644 index 2250a60706..0000000000 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/Linear8BitActXBitWeightOperator.h +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once -#include -#include -#include -#include -#include - -namespace torchao::ops::linear_8bit_act_xbit_weight { - -class Linear8BitActXBitWeightOperator { - private: - torchao::aligned_byte_ptr packed_weight_data_{nullptr, nullptr}; - int packed_weight_data_size_{0}; - int preferred_packed_weight_data_alignment_{0}; - - torchao::aligned_byte_ptr activation_data_buffer_{nullptr, nullptr}; - - int m_{0}; - int n_{0}; - int k_{0}; - int group_size_{0}; - - // The class does not own this data - const int8_t* weight_qvals_{nullptr}; - const float* weight_scales_{nullptr}; - const int8_t* weight_zeros_{nullptr}; - - bool initialized_{false}; - - UKernelConfig ukernel_config_; - PackWeightDataTilingParams pack_weight_tiling_params_; - LinearTilingParams linear_tiling_params_; - LinearTileSchedulingPolicy linear_scheduling_policy_; - - public: - Linear8BitActXBitWeightOperator( - UKernelConfig ukernel_config, - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - int initial_m = 1, - std::optional pack_weight_tiling_params = {}, - std::optional linear_tiling_params = {}, - std::optional linear_scheduling_policy = {}) - : m_{initial_m}, - n_{n}, - k_{k}, - group_size_(group_size), - weight_qvals_{weight_qvals}, - weight_scales_{weight_scales}, - weight_zeros_{weight_zeros} { - TORCHAO_CHECK(n_ >= 1, "n must be 
>= 1"); - TORCHAO_CHECK(k_ >= 1, "k must be >= 1"); - TORCHAO_CHECK(group_size_ >= 1, "group_size must be >= 1"); - TORCHAO_CHECK(m_ >= 1, "initial_m must be >= 1"); - - ukernel_config_ = ukernel_config; - if (pack_weight_tiling_params.has_value()) { - pack_weight_tiling_params_ = pack_weight_tiling_params.value(); - } else { - pack_weight_tiling_params_ = get_default_pack_weight_data_tiling_params( - ukernel_config_, n_, /*target_panels_per_thread=*/1); - } - - if (linear_tiling_params.has_value()) { - linear_tiling_params_ = linear_tiling_params.value(); - } else { - linear_tiling_params_ = get_default_linear_tiling_params( - ukernel_config_, m_, n_, /*target_tiles_per_thread=*/5); - } - - if (linear_scheduling_policy.has_value()) { - linear_scheduling_policy_ = linear_scheduling_policy.value(); - } else { - linear_scheduling_policy_ = - LinearTileSchedulingPolicy::single_mc_parallel_nc; - } - } - - int get_m() { - return m_; - } - int get_n() { - return n_; - } - int get_k() { - return k_; - } - int get_group_size() { - return group_size_; - } - - void initialize() { - if (initialized_) { - return; - } - - // Pack weight data - auto packed_weight_data_size = - get_packed_weight_data_size(ukernel_config_, n_, k_, group_size_); - auto preferred_packed_weight_data_alignment = - get_preferred_packed_weight_data_alignment(ukernel_config_); - - packed_weight_data_size_ = packed_weight_data_size; - preferred_packed_weight_data_alignment_ = preferred_packed_weight_data_alignment; - packed_weight_data_ = torchao::make_aligned_byte_ptr( - preferred_packed_weight_data_alignment, packed_weight_data_size); - - pack_weight_data_operator( - ukernel_config_, - pack_weight_tiling_params_, - packed_weight_data_.get(), - n_, - k_, - group_size_, - weight_qvals_, - weight_scales_, - weight_zeros_); - - // Pre-allocate space for quantized/packed activations - // This buffer may be resized when calling the operator if m is changed - auto activation_data_buffer_size = 
get_activation_data_buffer_size( - ukernel_config_, - linear_tiling_params_, - linear_scheduling_policy_, - m_, - k_, - group_size_); - auto activation_data_buffer_alignment = - get_preferred_activation_data_buffer_alignment(ukernel_config_); - activation_data_buffer_ = torchao::make_aligned_byte_ptr( - activation_data_buffer_alignment, activation_data_buffer_size); - - // Mark as initialized - initialized_ = true; - } - - void operator()( - float* output, - const float* activations, - int m, - int k, - const float* bias, - float clamp_min, - float clamp_max) { - TORCHAO_CHECK(initialized_, "kernel is not initialized."); - TORCHAO_CHECK( - k == this->k_, - "activations have incompatible size with initialized kernel."); - - // Resize activation buffer if needed - if (m > m_) { - m_ = m; - auto activation_data_buffer_size = get_activation_data_buffer_size( - ukernel_config_, - linear_tiling_params_, - linear_scheduling_policy_, - m_, - k_, - group_size_); - auto activation_data_buffer_alignment = - get_preferred_activation_data_buffer_alignment(ukernel_config_); - activation_data_buffer_ = torchao::make_aligned_byte_ptr( - activation_data_buffer_alignment, activation_data_buffer_size); - } - - // Run linear operator - linear_operator( - ukernel_config_, - linear_tiling_params_, - linear_scheduling_policy_, - activation_data_buffer_.get(), - output, - // To support dynamic shapes, we use m from args, not m_ - // Note m_ can be larger than m - m, - n_, - k_, - group_size_, - packed_weight_data_.get(), - activations, - bias, - clamp_min, - clamp_max); - } -}; -} // namespace - // torchao::ops::linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/build_and_run_examples.sh b/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/build_and_run_examples.sh deleted file mode 100644 index 01185fdd3f..0000000000 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/build_and_run_examples.sh +++ /dev/null @@ 
-1,22 +0,0 @@ -#!/bin/bash -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" -echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}" -export CMAKE_OUT=/tmp/cmake-out/torchao/examples -cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ - -S . \ - -B ${CMAKE_OUT} \ - -DOpenMP_ROOT=$(brew --prefix libomp) -cmake --build ${CMAKE_OUT} - -# Run -case "$1" in - separate_function_wrappers) ${CMAKE_OUT}/separate_function_wrappers; ;; - stateful_class_wrapper) ${CMAKE_OUT}/stateful_class_wrapper; ;; - *) echo "Unknown example: $1. Please specify one of: separate_function_wrappers, stateful_class_wrapper."; exit 1; ;; -esac diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/separate_function_wrappers.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/separate_function_wrappers.cpp deleted file mode 100644 index 961c03e985..0000000000 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/separate_function_wrappers.cpp +++ /dev/null @@ -1,223 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#include -#include -#include -#include -#include -#include -// This file contains an example of wrapping the torchao weight packing and -// linear operators into two operators: one for weight packing and another -// for running the linear operator. Each surface (PyTorch custom class, PyTorch -// operator, ExecuTorch operator, ExecuTorch delegate) will need to write its -// own wrapper). 
In the example here, std::vector is used for storage, but in -// PyTorch a PyTorch Tensor would be used and in ExecuTorch, an ExecuTorch -// Tensor would be used. -// -// It is more efficient to combine weight-packing and the linear operator into -// one stateful class, but not all surfaces support this (see -// examples/stateful_class_wrapper.cpp for an example of this). - -namespace torchao::ops::linear_8bit_act_xbit_weight { - -template -UKernelConfig get_ukernel_config() { - UKernelConfig config; - - namespace ukernel = torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; - config.mr = 1; - config.nr = 8; - config.activation_data_size_fn = - &ukernel::activation_data_size; - config.preferred_activation_data_alignment = 16; // size of neon register - config.prepare_activation_data_fn = - &ukernel::prepare_activation_data; - config.weight_data_size_fn = - &ukernel::weight_data_size; - config.preferred_weight_data_alignment = 16; // size of neon register - config.prepare_weight_data_fn = - &ukernel::prepare_weight_data; - config.kernel_fn = - &ukernel::kernel; - - return config; -} - -torchao::aligned_byte_ptr pack_weight_data_operator( - UKernelConfig ukernel_config, - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - std::optional tiling_params = {}) { - PackWeightDataTilingParams tiling_params_; - if (tiling_params.has_value()) { - tiling_params_ = tiling_params.value(); - } else { - tiling_params_ = get_default_pack_weight_data_tiling_params( - ukernel_config, n, /*target_panels_per_thread=*/1); - } - - auto packed_weight_data_size = - get_packed_weight_data_size(ukernel_config, n, k, group_size); - auto preferred_packed_weight_data_alignment = - get_preferred_packed_weight_data_alignment(ukernel_config); - auto packed_weight_data = torchao::make_aligned_byte_ptr( - preferred_packed_weight_data_alignment, 
packed_weight_data_size); - - pack_weight_data_operator( - ukernel_config, - tiling_params_, - packed_weight_data.get(), - n, - k, - group_size, - weight_qvals, - weight_scales, - weight_zeros); - - return packed_weight_data; -} - -void linear_operator( - UKernelConfig ukernel_config, - float* output, - int m, - int n, - int k, - int group_size, - void* packed_weight_data, - float* activations, - const float* bias, - float clamp_min, - float clamp_max, - std::optional tiling_params = {}, - std::optional scheduling_policy = {}) { - LinearTilingParams tiling_params_; - if (tiling_params.has_value()) { - tiling_params_ = tiling_params.value(); - } else { - tiling_params_ = get_default_linear_tiling_params( - ukernel_config, m, n, /*target_tiles_per_thread=*/5); - } - - LinearTileSchedulingPolicy scheduling_policy_; - if (scheduling_policy.has_value()) { - scheduling_policy_ = scheduling_policy.value(); - } else { - scheduling_policy_ = LinearTileSchedulingPolicy::single_mc_parallel_nc; - } - - auto activation_data_buffer_size = get_activation_data_buffer_size( - ukernel_config, tiling_params_, scheduling_policy_, m, k, group_size); - auto activation_data_buffer_alignment = - get_preferred_activation_data_buffer_alignment(ukernel_config); - auto activation_data_buffer = torchao::make_aligned_byte_ptr( - activation_data_buffer_alignment, activation_data_buffer_size); - - linear_operator( - ukernel_config, - tiling_params_, - scheduling_policy_, - activation_data_buffer.get(), - output, - m, - n, - k, - group_size, - packed_weight_data, - activations, - bias, - clamp_min, - clamp_max); -} - -} // namespace - // torchao::ops::linear_8bit_act_xbit_weight - -int main() { - using namespace torchao::ops::linear_8bit_act_xbit_weight; - - torchao::set_num_threads(8); - std::cout << "Using " << torchao::get_num_threads() << " threads." 
- << std::endl; - - constexpr int weight_nbit = 3; - constexpr bool has_weight_zeros = false; - constexpr bool has_bias = false; - constexpr bool has_clamp = false; - - int m = 1; - int n = 4096 + 1; - int k = 4096; - int group_size = 16; - - std::cout << "Generating random test case." << std::endl; - auto test_case = torchao:: - channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( - m, - k, - n, - group_size, - weight_nbit, - has_weight_zeros, - has_bias, - has_clamp); - - auto output = std::vector(m * n); - - auto ukernel_config = - get_ukernel_config(); - - std::cout << "Running pack_weight_data_operator." << std::endl; - auto packed_weight_data = pack_weight_data_operator( - ukernel_config, - n, - k, - group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - test_case.weight_zeros.data()); - - std::cout << "Running linear_operator." << std::endl; - linear_operator( - ukernel_config, - output.data(), - m, - n, - k, - group_size, - packed_weight_data.get(), - test_case.activations.data(), - test_case.bias.data(), - test_case.clamp_min, - test_case.clamp_max); - - std::cout << "Checking results." << std::endl; - - bool passed = true; - float tol = 0.001; - for (int i = 0; i < output.size(); i++) { - if (std::abs(test_case.expected_output[i] - output[i]) > tol) { - std::cout << "Bad result at index " << i << "."; - std::cout << " Output: " << output[i] - << ". Expected: " << test_case.expected_output[i] << "." - << std::endl; - passed = false; - } - } - if (passed) { - std::cout << "Test passed." << std::endl; - } else { - std::cout << "Test failed." 
<< std::endl; - } - - return 0; -} diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/stateful_class_wrapper.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/stateful_class_wrapper.cpp deleted file mode 100644 index a45c32811b..0000000000 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/examples/stateful_class_wrapper.cpp +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#include -#include -#include -#include -#include -#include - -// This file contains an example of wrapping the torchao weight packing and -// linear operators into one stateful LinearOperator class. Each surface -// (PyTorch custom class, PyTorch operator, ExecuTorch operator, ExecuTorch -// delegate) will need to write its own wrapper. In the example here, -// std::vector is used for storage, but in PyTorch a PyTorch Tensor would be -// used and in ExecuTorch, an ExecuTorch Tensor would be used. -// -// Although more efficient, not all surfaces support stateful operators. See -// examples/separate_function_wrappers.cpp for an example of how to split the -// operations into two steps. 
- -using namespace torchao::ops::linear_8bit_act_xbit_weight; - -template -UKernelConfig get_ukernel_config() { - UKernelConfig config; - - namespace ukernel = torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; - config.mr = 1; - config.nr = 8; - config.activation_data_size_fn = - &ukernel::activation_data_size; - config.preferred_activation_data_alignment = 16; // size of neon register - config.prepare_activation_data_fn = - &ukernel::prepare_activation_data; - config.weight_data_size_fn = - &ukernel::weight_data_size; - config.preferred_weight_data_alignment = 16; // size of neon register - config.prepare_weight_data_fn = - &ukernel::prepare_weight_data; - config.kernel_fn = - &ukernel::kernel; - - return config; -} - -int main() { - int m = 13; - int n = 4096 + 1; - int k = 4096; - int group_size = 16; - - constexpr int weight_nbit = 4; - constexpr bool has_weight_zeros = false; - constexpr bool has_bias = false; - constexpr bool has_clamp = false; - - std::cout << "Generating random test case." << std::endl; - auto test_case = torchao:: - channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( - m, - k, - n, - group_size, - weight_nbit, - has_weight_zeros, - has_bias, - has_clamp); - - torchao::set_num_threads(8); - std::cout << "Using " << torchao::get_num_threads() << " threads." - << std::endl; - - std::cout << "Initializing linear_operator." << std::endl; - auto ukernel_config = - get_ukernel_config(); - - auto linear_operator = - Linear8BitActXBitWeightOperator( - ukernel_config, - n, - k, - group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - test_case.weight_zeros.data(), - // m may be resized during call to support dynamic shapes - /*initial_m=*/1); - - linear_operator.initialize(); - - std::cout << "Calling linear_operator." 
<< std::endl; - auto output = std::vector(m * n); - linear_operator( - output.data(), - test_case.activations.data(), - m, - k, - test_case.bias.data(), - test_case.clamp_min, - test_case.clamp_max); - - std::cout << "Checking results." << std::endl; - - bool passed = true; - float tol = 0.001; - for (int i = 0; i < output.size(); i++) { - if (std::abs(test_case.expected_output[i] - output[i]) > tol) { - std::cout << "Bad result at index " << i << "."; - std::cout << " Output: " << output[i] - << ". Expected: " << test_case.expected_output[i] << "." - << std::endl; - passed = false; - break; - } - } - if (passed) { - std::cout << "Test passed." << std::endl; - } else { - std::cout << "Test failed." << std::endl; - } - - return 0; -} diff --git a/torchao/experimental/ops/tests/CMakeLists.txt b/torchao/experimental/ops/tests/CMakeLists.txt deleted file mode 100644 index 1d0e40ba21..0000000000 --- a/torchao/experimental/ops/tests/CMakeLists.txt +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -cmake_minimum_required(VERSION 3.19) -project(tests) - -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Debug) -add_compile_options("-Wall" "-Werror") - -set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) -set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) 
- -include(FetchContent) -FetchContent_Declare( - googletest - URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip -) -FetchContent_MakeAvailable(googletest) -enable_testing() - -if(TORCHAO_BUILD_CPU_AARCH64) - add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64=1) - add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT) -endif() - -if(TORCHAO_BUILD_KLEIDIAI) - add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1) - # TODO: build tests at top-level so we can use same KleidiAI version - if (NOT TARGET kleidiai) - include(FetchContent) - # KleidiAI is an open-source library that provides optimized - # performance-critical routines, also known as micro-kernels, for artificial - # intelligence (AI) workloads tailored for Arm® CPUs. - set(KLEIDIAI_BUILD_TESTS OFF CACHE BOOL "" FORCE) - set(KLEIDIAI_BUILD_BENCHMARKS OFF CACHE BOOL "" FORCE) - FetchContent_Declare(kleidiai - GIT_REPOSITORY https://git.gitlab.arm.com/kleidi/kleidiai.git - GIT_TAG v1.12.0 - ) - FetchContent_MakeAvailable(kleidiai) - endif() -endif() - -if(TORCHAO_BUILD_ARM_I8MM) - add_compile_definitions(TORCHAO_ENABLE_ARM_I8MM) -endif() - -if (ANDROID_ABI) - # We are cross compiling, delay test discovery till runtime - set(CMAKE_GTEST_DISCOVER_TESTS_DISCOVERY_MODE PRE_TEST) -endif() - -include_directories(${TORCHAO_INCLUDE_DIRS}) - -set(TORCHAO_PARALLEL_BACKEND "test_dummy") - -if (TORCHAO_BUILD_CPU_AARCH64) - add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) - add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64) -endif() - -include(${TORCHAO_ROOT}/Utils.cmake) - -if (ANDROID_ABI) - # Given where we are today this is sufficent. But needs to be revisited. - # This is also needed for native builds, but keeping it only for cross builds - # for now given the hacky nature. 
- file(GLOB DOTPROD_SRC_FILES test*.cpp) - message(SRC_FILES: ${DOTPROD_SRC_FILES}) - set_property(SOURCE - ${DOTPROD_SRC_FILES} - APPEND_STRING PROPERTY - COMPILE_FLAGS " -march=armv8.2-a+dotprod ") -endif() - -add_executable( - test_linear_8bit_act_xbit_weight - test_linear_8bit_act_xbit_weight.cpp - ${TORCHAO_ROOT}/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp -) -target_link_libraries( - test_linear_8bit_act_xbit_weight - PRIVATE - GTest::gtest_main -) -if (TORCHAO_BUILD_CPU_AARCH64) - target_link_libraries( - test_linear_8bit_act_xbit_weight - PRIVATE - torchao_kernels_aarch64 - ) -endif() -if (TORCHAO_BUILD_KLEIDIAI) - target_link_libraries( - test_linear_8bit_act_xbit_weight - PRIVATE - kleidiai - ) -endif() -target_link_torchao_parallel_backend(test_linear_8bit_act_xbit_weight "${TORCHAO_PARALLEL_BACKEND}") - -add_executable( - test_groupwise_lowbit_weight_lut - test_groupwise_lowbit_weight_lut.cpp - ${TORCHAO_ROOT}/ops/groupwise_lowbit_weight_lut/groupwise_lowbit_weight_lut.cpp -) -target_link_libraries( - test_groupwise_lowbit_weight_lut - PRIVATE - GTest::gtest_main -) -if (TORCHAO_BUILD_CPU_AARCH64) - target_link_libraries( - test_groupwise_lowbit_weight_lut - PRIVATE - torchao_kernels_aarch64 - ) -endif() -target_link_torchao_parallel_backend(test_groupwise_lowbit_weight_lut "${TORCHAO_PARALLEL_BACKEND}") - -include(GoogleTest) -gtest_discover_tests(test_groupwise_lowbit_weight_lut) -gtest_discover_tests(test_linear_8bit_act_xbit_weight) diff --git a/torchao/experimental/ops/tests/build_and_run_tests.sh b/torchao/experimental/ops/tests/build_and_run_tests.sh deleted file mode 100644 index 4e6fef8ce1..0000000000 --- a/torchao/experimental/ops/tests/build_and_run_tests.sh +++ /dev/null @@ -1,66 +0,0 @@ -#!/bin/bash -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -target=${1:-"native"} -SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests - -export TORCH_DIR=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib() + '/torch/share/cmake/Torch')") - -IS_ARM64=0 -BUILD_ARM_I8MM=0 -EXTRA_ARGS="" -if [[ "${target}" == "android" ]]; then - if [[ -z ${ANDROID_NDK} ]]; then - echo "Need to set ANDROID_NDK env variable to build for Android"; - exit 1; - fi - android_abi=arm64-v8a - android_platform=28 # must be >=28 for aligned_alloc - IS_ARM64=1 - BUILD_ARM_I8MM=1 # Hardcoded for now - CMAKE_OUT=${CMAKE_OUT/cmake-out/cmake-out-android} - toolchain_file="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" - if [[ -z ${toolchain_file} ]]; then - echo "Unable to find toolchain file at ANDROID_NDK location, looking for ${toolchain_file}" - exit 1; - fi - EXTRA_ARGS="\ - -DCMAKE_TOOLCHAIN_FILE=${toolchain_file} \ - -DANDROID_ABI=${android_abi} \ - -DANDROID_PLATFORM=${android_platform} - " - echo "Building tests for Android (${android_abi}) @ ${CMAKE_OUT}" -fi - -hash arch; retval=$? -if [[ ${retval} -eq 0 && $(arch) == "arm64" ]]; then - IS_ARM64=1 -fi - -cmake \ - ${EXTRA_ARGS} \ - -DCMAKE_BUILD_TYPE=Debug \ - -DTORCHAO_BUILD_CPU_AARCH64=${IS_ARM64} \ - -DTORCHAO_BUILD_KLEIDIAI=${IS_ARM64} \ - -DTORCHAO_BUILD_ARM_I8MM=${BUILD_ARM_I8MM} \ - -DTorch_DIR=${TORCH_DIR} \ - -S . \ - -B ${CMAKE_OUT} - -cmake --build ${CMAKE_OUT} - -echo "Successfully built tests." - -if [[ "${target}" != "native" ]]; then - echo "Skip running tests when cross compiling."; - exit 0; -fi - -# Run -${CMAKE_OUT}/test_linear_8bit_act_xbit_weight -${CMAKE_OUT}/test_groupwise_lowbit_weight_lut diff --git a/torchao/experimental/temp_build.py b/torchao/experimental/temp_build.py deleted file mode 100644 index 3195e24581..0000000000 --- a/torchao/experimental/temp_build.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
-# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import glob -import subprocess -import tempfile - -import torch - - -def cmake_build_torchao_ops(cmake_lists_path, temp_build_dir): - from distutils.sysconfig import get_python_lib - - print("Building torchao ops for ATen target") - cmake_prefix_path = get_python_lib() - subprocess.run( - [ - "cmake", - "-DCMAKE_PREFIX_PATH=" + cmake_prefix_path, - "-DCMAKE_INSTALL_PREFIX=" + temp_build_dir.name, - "-S " + cmake_lists_path, - "-B " + temp_build_dir.name, - ] - ) - subprocess.run( - [ - "cmake", - "--build", - temp_build_dir.name, - "-j 16", - "--target install", - "--config Release", - ] - ) - - -def temp_build_and_load_torchao_ops(cmake_lists_path): - temp_build_dir = tempfile.TemporaryDirectory() - cmake_build_torchao_ops(cmake_lists_path, temp_build_dir) - libs = glob.glob(f"{temp_build_dir.name}/lib/libtorchao_ops_aten.*") - libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) - assert len(libs) == 1 - torch.ops.load_library(libs[0]) - print(f"TorchAO ops are loaded from {libs[0]}")