Skip to content

Commit 3c21c9d

Browse files
authored
Wrap Array1<T> as torch::Tensor. (k2-fsa#173)
* Wrap Array1<T> as torch::Tensor. Fix k2host test cases. * interpret arc.weight from a float to an int. * update the comment for torch.h/torch.cu * fix linker errors for release build.
1 parent c0be6f5 commit 3c21c9d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

86 files changed

+2226
-1631
lines changed

.flake8

+8
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,11 @@
22
show-source=true
33
statistics=true
44
max-line-length=80
5+
exclude =
6+
.git,
7+
build,
8+
k2/python/host
9+
10+
ignore =
11+
# E127 continuation line over-indented for visual indent
12+
E127,

.github/workflows/style_check.yml

+3-4
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
runs-on: ubuntu-latest
1818
strategy:
1919
matrix:
20-
python-version: [3.5, 3.6, 3.7, 3.8]
20+
python-version: [3.7, 3.8]
2121

2222
steps:
2323
- uses: actions/checkout@v2
@@ -32,16 +32,15 @@ jobs:
3232
- name: Install Python dependencies
3333
run: |
3434
python3 -m pip install --upgrade pip
35-
python3 -m pip install --upgrade flake8
35+
python3 -m pip install --upgrade flake8==3.8.3
3636
3737
- name: Run flake8
3838
shell: bash
3939
working-directory: ${{github.workspace}}
4040
run: |
4141
# stop the build if there are Python syntax errors or undefined names
4242
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
43-
# exit-zero treats all errors as warnings.
44-
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=79 --statistics
43+
flake8 .
4544
4645
# TODO(fangjun): build a docker for style check
4746
# - name: Install cppcheck

CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,10 @@ enable_testing()
9797
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
9898
include(pybind11)
9999
if(USE_PYTORCH)
100+
add_definitions(-DK2_USE_PYTORCH)
100101
include(torch)
101102
endif()
102103
include(cub)
103104
include(googletest)
104105

105-
106106
add_subdirectory(k2)

cmake/pybind11.cmake

+6
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,15 @@ function(download_pybind11)
1111
set(pybind11_URL "https://github.com/pybind/pybind11/archive/v2.5.0.tar.gz")
1212
set(pybind11_HASH "SHA256=97504db65640570f32d3fdf701c25a340c8643037c3b69aec469c10c93dc8504")
1313

14+
set(double_quotes "\"")
15+
set(dollar "\$")
16+
set(semicolon "\;")
1417
FetchContent_Declare(pybind11
1518
URL ${pybind11_URL}
1619
URL_HASH ${pybind11_HASH}
20+
PATCH_COMMAND
21+
sed -i s/\\${double_quotes}-flto\\\\${dollar}/\\${double_quotes}-Xcompiler=-flto${dollar}/g "tools/pybind11Tools.cmake" &&
22+
sed -i s/${seimcolon}-fno-fat-lto-objects/${seimcolon}-Xcompiler=-fno-fat-lto-objects/g "tools/pybind11Tools.cmake"
1723
)
1824

1925
FetchContent_GetProperties(pybind11)

k2/csrc/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ else()
2828
endif()
2929

3030
# the target
31-
add_library(context STATIC ${context_srcs})
31+
add_library(context SHARED ${context_srcs})
3232
set_target_properties(context PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
3333

3434
# lib deps

k2/csrc/default_context.cu

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class CpuContext : public Context {
3939
int32_t ret = posix_memalign(&p, kAlignment, bytes);
4040
K2_CHECK_EQ(ret, 0);
4141
}
42-
if (deleter_context) *deleter_context = nullptr;
42+
if (deleter_context != nullptr) *deleter_context = nullptr;
4343
return p;
4444
}
4545

@@ -75,7 +75,7 @@ class CudaContext : public Context {
7575
auto ret = cudaMalloc(&p, bytes);
7676
K2_CHECK_CUDA_ERROR(ret);
7777
}
78-
if (deleter_context) *deleter_context = nullptr;
78+
if (deleter_context != nullptr) *deleter_context = nullptr;
7979
return p;
8080
}
8181

k2/csrc/host/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
# the target
66
# please sort the source files alphabetically
7-
add_library(fsa
7+
add_library(fsa SHARED
88
arcsort.cc
99
aux_labels.cc
1010
connect.cc

k2/csrc/host/fsa.cc

+2-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ inline std::size_t AlignTo(std::size_t b, std::size_t alignment) {
3030
namespace k2host {
3131

3232
std::ostream &operator<<(std::ostream &os, const Arc &arc) {
33-
os << arc.src_state << " " << arc.dest_state << " " << arc.label;
33+
os << arc.src_state << " " << arc.dest_state << " " << arc.label << " "
34+
<< arc.weight;
3435
return os;
3536
}
3637

k2/csrc/pytorch_context.cu

+23
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,27 @@ ContextPtr GetCudaContext(int32_t gpu_id /*= -1*/) {
2323
return std::make_shared<PytorchCudaContext>(gpu_id);
2424
}
2525

26+
RegionPtr NewRegion(torch::Tensor &tensor) {
27+
auto ans = std::make_shared<Region>();
28+
if (tensor.device().type() == torch::kCPU) {
29+
ans->context = GetCpuContext();
30+
} else if (tensor.is_cuda()) {
31+
ans->context = GetCudaContext(tensor.device().index());
32+
} else {
33+
K2_LOG(FATAL) << "Unsupported device: " << tensor.device()
34+
<< "\nOnly CPU and CUDA are supported";
35+
}
36+
37+
// NOTE: the tensor is passed from Python and we have
38+
// to retain it to avoid potential segmentation fault.
39+
//
40+
// It will be freed in `Context::Deallocate`.
41+
auto *managed_tensor = new ManagedTensor(tensor);
42+
ans->data = tensor.data_ptr();
43+
ans->deleter_context = managed_tensor;
44+
ans->num_bytes = tensor.nbytes();
45+
ans->bytes_used = ans->num_bytes;
46+
return ans;
47+
}
48+
2649
} // namespace k2

k2/csrc/pytorch_context.h

+33-6
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,21 @@
1616
#include <memory>
1717

1818
#include "c10/cuda/CUDACachingAllocator.h"
19+
#include "c10/cuda/CUDAFunctions.h"
1920
#include "k2/csrc/context.h"
2021
#include "k2/csrc/log.h"
2122
#include "torch/torch.h"
2223

2324
namespace k2 {
2425

26+
class ManagedTensor {
27+
public:
28+
explicit ManagedTensor(torch::Tensor &tensor) : handle_(tensor) {}
29+
30+
private:
31+
torch::Tensor handle_; // retain a copy of the tensor passed from Python
32+
};
33+
2534
class PytorchCpuContext : public Context {
2635
private:
2736
PytorchCpuContext() {
@@ -46,12 +55,18 @@ class PytorchCpuContext : public Context {
4655

4756
void *Allocate(std::size_t bytes, void **deleter_context) override {
4857
void *p = allocator_->raw_allocate(bytes);
49-
if (deleter_context) *deleter_context = nullptr;
58+
if (deleter_context != nullptr) *deleter_context = nullptr;
5059
return p;
5160
}
5261

53-
void Deallocate(void *data, void * /*deleter_context*/) override {
54-
allocator_->raw_deallocate(data);
62+
void Deallocate(void *data, void *deleter_context) override {
63+
if (deleter_context != nullptr) {
64+
// a non-empty `deleter_context` indicates that
65+
// the memory is passed from a `torch::Tensor`
66+
delete reinterpret_cast<ManagedTensor *>(deleter_context);
67+
} else {
68+
allocator_->raw_deallocate(data);
69+
}
5570
}
5671

5772
bool IsCompatible(const Context &other) const override {
@@ -94,12 +109,18 @@ class PytorchCudaContext : public Context {
94109

95110
void *Allocate(std::size_t bytes, void **deleter_context) override {
96111
void *p = allocator_->raw_allocate(bytes);
97-
if (deleter_context) *deleter_context = nullptr;
112+
if (deleter_context != nullptr) *deleter_context = nullptr;
98113
return p;
99114
}
100115

101-
void Deallocate(void *data, void * /*deleter_context*/) override {
102-
allocator_->raw_deallocate(data);
116+
void Deallocate(void *data, void *deleter_context) override {
117+
if (deleter_context != nullptr) {
118+
// a non-empty `deleter_context` indicates that
119+
// the memory is passed from a `torch::Tensor`
120+
delete reinterpret_cast<ManagedTensor *>(deleter_context);
121+
} else {
122+
allocator_->raw_deallocate(data);
123+
}
103124
}
104125

105126
bool IsCompatible(const Context &other) const override {
@@ -116,6 +137,12 @@ class PytorchCudaContext : public Context {
116137
int32_t gpu_id_;
117138
};
118139

140+
// Construct a region from a `torch::Tensor`.
141+
//
142+
// The resulting region shares the underlying memory with
143+
// the given tensor.
144+
RegionPtr NewRegion(torch::Tensor &tensor);
145+
119146
} // namespace k2
120147

121148
#endif // K2_CSRC_PYTORCH_CONTEXT_H_

k2/python/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
add_subdirectory(csrc)
22
add_subdirectory(tests)
3+
add_subdirectory(host)

k2/python/csrc/CMakeLists.txt

+20-13
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,23 @@
1-
# please sort the files alphabetically
2-
pybind11_add_module(_k2
3-
array.cc
4-
aux_labels.cc
5-
fsa.cc
6-
fsa_algo.cc
7-
fsa_equivalent.cc
8-
fsa_util.cc
9-
k2.cc
10-
properties.cc
11-
tensor.cc
12-
weights.cc
1+
# please keep the list sorted
2+
set(k2_srcs
3+
k2.cu
4+
torch.cu
135
)
146

15-
target_include_directories(_k2 PRIVATE ${CMAKE_SOURCE_DIR})
7+
if(USE_PYTORCH)
8+
add_definitions(-DTORCH_API_INCLUDE_EXTENSION_H)
9+
add_subdirectory(torch)
10+
set(k2_srcs ${k2_srcs} ${torch_srcs})
11+
set(k2_deps
12+
${TORCH_LIBRARIES}
13+
${TORCH_DIR}/lib/libtorch_python.so
14+
)
15+
else()
16+
message(FATAL_ERROR "Please select a framework.")
17+
endif()
18+
19+
pybind11_add_module(_k2 ${k2_srcs})
20+
target_link_libraries(_k2 PRIVATE ${k2_deps})
21+
target_link_libraries(_k2 PRIVATE context)
1622
target_link_libraries(_k2 PRIVATE fsa)
23+
target_include_directories(_k2 PRIVATE ${CMAKE_SOURCE_DIR})

k2/python/csrc/aux_labels.h

-14
This file was deleted.

k2/python/csrc/fsa_algo.h

-14
This file was deleted.

k2/python/csrc/fsa_equivalent.h

-14
This file was deleted.

k2/python/csrc/fsa_util.h

-14
This file was deleted.

k2/python/csrc/k2.cc

-30
This file was deleted.

k2/python/csrc/k2.cu

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
/**
2+
* @brief python wrappers for k2.
3+
*
4+
* @copyright
5+
* Copyright (c) 2020 Mobvoi AI Lab, Beijing, China (authors: Fangjun Kuang)
6+
*
7+
* @copyright
8+
* See LICENSE for clarification regarding multiple authors
9+
*/
10+
11+
#include "k2/python/csrc/k2.h"
12+
13+
#include "k2/python/csrc/torch.h"
14+
15+
PYBIND11_MODULE(_k2, m) {
16+
m.doc() = "pybind11 binding of k2";
17+
PybindTorch(m);
18+
}

k2/python/csrc/k2.h

+9-6
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
1-
// k2/python/csrc/k2.h
2-
3-
// Copyright (c) 2020 Fangjun Kuang ([email protected])
4-
5-
// See ../../../LICENSE for clarification regarding multiple authors
1+
/**
2+
* @brief python wrappers for k2.
3+
*
4+
* @copyright
5+
* Copyright (c) 2020 Mobvoi AI Lab, Beijing, China (authors: Fangjun Kuang)
6+
*
7+
* @copyright
8+
* See LICENSE for clarification regarding multiple authors
9+
*/
610

711
#ifndef K2_PYTHON_CSRC_K2_H_
812
#define K2_PYTHON_CSRC_K2_H_
913

1014
#include "pybind11/pybind11.h"
1115
#include "pybind11/stl.h"
12-
#include "k2/csrc/log.h"
1316

1417
namespace py = pybind11;
1518

0 commit comments

Comments
 (0)