Skip to content

Commit cfde6e1

Browse files
committed
Add a lint rule for torch/csrc/util/pybind.h include
We define specializations for pybind11 defined templates (in particular, PYBIND11_DECLARE_HOLDER_TYPE) and consequently it is important that these specializations *always* be #include'd when making use of pybind11 templates whose behavior depends on these specializations, otherwise we can cause an ODR violation. The easiest way to ensure that all the specializations are always loaded is to designate a header (in this case, torch/csrc/util/pybind.h) that ensures the specializations are defined, and then add a lint to ensure this header is included whenever pybind11 headers are included. The existing grep linter didn't have enough knobs to do this conveniently, so I added some features. I'm open to suggestions for how to structure the features better. The main changes: - Added an --allowlist-pattern flag, which turns off the grep lint if some other line exists. This is used to stop the grep lint from complaining about pybind11 includes if the util include already exists. - Added --match-first-only flag, which lets grep only match against the first matching line. This is because, even if there are multiple includes that are problematic, I only need to fix one of them. We don't /really/ need this, but when I was running lintrunner -a to fixup the preexisting codebase it was annoying without this, as the lintrunner overall driver fails if there are multiple edits on the same file. I excluded any files that didn't otherwise have a dependency on torch/ATen, this was mostly caffe2 and the valgrind wrapper compat bindings. Note the grep replacement is kind of crappy, but clang-tidy lint cleaned it up in most cases. See also pybind/pybind11#4099 Signed-off-by: Edward Z. Yang <[email protected]> [ghstack-poisoned]
1 parent 53f5689 commit cfde6e1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+135
-4
lines changed

.lintrunner.toml

+33
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,39 @@ command = [
387387
'@{{PATHSFILE}}'
388388
]
389389

390+
[[linter]]
391+
code = 'PYBIND11_INCLUDE'
392+
include_patterns = [
393+
'**/*.cpp',
394+
'**/*.h',
395+
]
396+
exclude_patterns = [
397+
'torch/csrc/utils/pybind.h',
398+
'torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp',
399+
'caffe2/**/*',
400+
]
401+
command = [
402+
'python3',
403+
'tools/linter/adapters/grep_linter.py',
404+
'--pattern=#include <pybind11\/',
405+
'--allowlist-pattern=#include <torch\/csrc\/utils\/pybind.h>',
406+
'--linter-name=PYBIND11_INCLUDE',
407+
'--match-first-only',
408+
'--error-name=direct include of pybind11',
409+
# https://stackoverflow.com/a/33416489/23845
410+
# NB: this won't work if the pybind11 include is on the first line;
411+
# but that's fine because it will just mean the lint will still fail
412+
# after applying the change and you will have to fix it manually
413+
'--replace-pattern=1,/(#include <pybind11\/)/ s/(#include <pybind11\/)/#include <torch\/csrc\/utils\/pybind.h>\n\1/',
414+
"""--error-description=\
415+
This #include directly includes pybind11 without also including \
416+
#include <torch/csrc/utils/pybind.h>; this means some important \
417+
specializations may not be included.\
418+
""",
419+
'--',
420+
'@{{PATHSFILE}}'
421+
]
422+
390423
[[linter]]
391424
code = 'PYPIDEP'
392425
include_patterns = ['.github/**']

test/cpp/jit/test_exception.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <torch/csrc/jit/frontend/parser.h>
1010
#include <torch/csrc/jit/frontend/resolver.h>
1111
#include <torch/csrc/jit/runtime/jit_exception.h>
12+
#include <torch/csrc/utils/pybind.h>
1213
#include <torch/jit.h>
1314
#include <iostream>
1415
#include <stdexcept>

tools/autograd/templates/python_enum_tag.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <torch/csrc/autograd/python_enum_tag.h>
2+
#include <torch/csrc/utils/pybind.h>
23
#include <pybind11/pybind11.h>
34
#include <ATen/core/enum_tag.h>
45

tools/autograd/templates/python_functions.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "torch/csrc/autograd/python_cpp_function.h"
1111
#include <torch/csrc/autograd/python_variable.h>
1212
#include <torch/csrc/autograd/saved_variable.h>
13+
#include <torch/csrc/utils/pybind.h>
1314
#include <pybind11/pybind11.h>
1415

1516
// NOTE: See [Sharded File] comment in VariableType

tools/linter/adapters/grep_linter.py

+56-4
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,51 @@ def run_command(
6161

6262
def lint_file(
6363
matching_line: str,
64+
allowlist_pattern: str,
6465
replace_pattern: str,
6566
linter_name: str,
6667
error_name: str,
6768
error_description: str,
68-
) -> LintMessage:
69+
) -> Optional[LintMessage]:
6970
# matching_line looks like:
7071
# tools/linter/clangtidy_linter.py:13:import foo.bar.baz
7172
split = matching_line.split(":")
7273
filename = split[0]
7374

75+
if allowlist_pattern:
76+
try:
77+
proc = run_command(["grep", "-nEHI", allowlist_pattern, filename])
78+
except Exception as err:
79+
return LintMessage(
80+
path=None,
81+
line=None,
82+
char=None,
83+
code=linter_name,
84+
severity=LintSeverity.ERROR,
85+
name="command-failed",
86+
original=None,
87+
replacement=None,
88+
description=(
89+
f"Failed due to {err.__class__.__name__}:\n{err}"
90+
if not isinstance(err, subprocess.CalledProcessError)
91+
else (
92+
"COMMAND (exit code {returncode})\n"
93+
"{command}\n\n"
94+
"STDERR\n{stderr}\n\n"
95+
"STDOUT\n{stdout}"
96+
).format(
97+
returncode=err.returncode,
98+
command=" ".join(as_posix(x) for x in err.cmd),
99+
stderr=err.stderr.decode("utf-8").strip() or "(empty)",
100+
stdout=err.stdout.decode("utf-8").strip() or "(empty)",
101+
)
102+
),
103+
)
104+
105+
# allowlist pattern was found, abort lint
106+
if proc.returncode == 0:
107+
return None
108+
74109
original = None
75110
replacement = None
76111
if replace_pattern:
@@ -109,7 +144,7 @@ def lint_file(
109144

110145
return LintMessage(
111146
path=split[0],
112-
line=int(split[1]),
147+
line=int(split[1]) if len(split) > 1 else None,
113148
char=None,
114149
code=linter_name,
115150
severity=LintSeverity.ERROR,
@@ -130,11 +165,20 @@ def main() -> None:
130165
required=True,
131166
help="pattern to grep for",
132167
)
168+
parser.add_argument(
169+
"--allowlist-pattern",
170+
help="if this pattern is true in the file, we don't grep for pattern",
171+
)
133172
parser.add_argument(
134173
"--linter-name",
135174
required=True,
136175
help="name of the linter",
137176
)
177+
parser.add_argument(
178+
"--match-first-only",
179+
action="store_true",
180+
help="only match the first hit in the file",
181+
)
138182
parser.add_argument(
139183
"--error-name",
140184
required=True,
@@ -174,8 +218,14 @@ def main() -> None:
174218
stream=sys.stderr,
175219
)
176220

221+
files_with_matches = []
222+
if args.match_first_only:
223+
files_with_matches = ["--files-with-matches"]
224+
177225
try:
178-
proc = run_command(["grep", "-nEHI", args.pattern, *args.filenames])
226+
proc = run_command(
227+
["grep", "-nEHI", *files_with_matches, args.pattern, *args.filenames]
228+
)
179229
except Exception as err:
180230
err_msg = LintMessage(
181231
path=None,
@@ -209,12 +259,14 @@ def main() -> None:
209259
for line in lines:
210260
lint_message = lint_file(
211261
line,
262+
args.allowlist_pattern,
212263
args.replace_pattern,
213264
args.linter_name,
214265
args.error_name,
215266
args.error_description,
216267
)
217-
print(json.dumps(lint_message._asdict()), flush=True)
268+
if lint_message is not None:
269+
print(json.dumps(lint_message._asdict()), flush=True)
218270

219271

220272
if __name__ == "__main__":

torch/csrc/Exceptions.h

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <torch/csrc/jit/runtime/jit_exception.h>
1616
#include <torch/csrc/utils/auto_gil.h>
1717
#include <torch/csrc/utils/cpp_stacktraces.h>
18+
#include <torch/csrc/utils/pybind.h>
1819

1920
#if defined(USE_DISTRIBUTED) && defined(USE_C10D)
2021
#include <torch/csrc/distributed/c10d/exception.h>

torch/csrc/Module.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <pybind11/pybind11.h>
2222
#include <pybind11/stl.h>
2323
#include <torch/csrc/THConcat.h>
24+
#include <torch/csrc/utils/pybind.h>
2425
#include <cstdlib>
2526
#include <unordered_map>
2627

torch/csrc/Size.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <c10/util/irange.h>
22
#include <pybind11/pytypes.h>
33
#include <torch/csrc/Size.h>
4+
#include <torch/csrc/utils/pybind.h>
45

56
#include <torch/csrc/utils/object_ptr.h>
67
#include <torch/csrc/utils/python_arg_parser.h>

torch/csrc/Stream.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <pybind11/pybind11.h>
22
#include <torch/csrc/Device.h>
33
#include <torch/csrc/THP.h>
4+
#include <torch/csrc/utils/pybind.h>
45
#include <torch/csrc/utils/python_arg_parser.h>
56

67
#include <structmember.h>

torch/csrc/autograd/functions/pybind.h

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <pybind11/pybind11.h>
44
#include <pybind11/stl.h>
55
#include <torch/csrc/python_headers.h>
6+
#include <torch/csrc/utils/pybind.h>
67

78
#include <torch/csrc/autograd/python_cpp_function.h>
89
#include <torch/csrc/autograd/python_function.h>

torch/csrc/autograd/python_anomaly_mode.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <torch/csrc/python_headers.h>
77
#include <torch/csrc/utils/auto_gil.h>
88
#include <torch/csrc/utils/object_ptr.h>
9+
#include <torch/csrc/utils/pybind.h>
910
#include <torch/csrc/utils/python_strings.h>
1011

1112
#include <iostream>

torch/csrc/autograd/python_anomaly_mode.h

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <torch/csrc/autograd/anomaly_mode.h>
55
#include <torch/csrc/python_headers.h>
66
#include <torch/csrc/utils/auto_gil.h>
7+
#include <torch/csrc/utils/pybind.h>
78

89
namespace torch {
910
namespace autograd {

torch/csrc/autograd/python_cpp_function.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <torch/csrc/autograd/python_function.h>
1515
#include <torch/csrc/autograd/python_hook.h>
1616
#include <torch/csrc/autograd/python_variable.h>
17+
#include <torch/csrc/utils/pybind.h>
1718
#include <torch/csrc/utils/python_numbers.h>
1819
#include <torch/csrc/utils/python_strings.h>
1920

torch/csrc/autograd/python_engine.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <torch/csrc/autograd/python_anomaly_mode.h>
1414
#include <torch/csrc/autograd/python_function.h>
1515
#include <torch/csrc/autograd/python_saved_variable_hooks.h>
16+
#include <torch/csrc/utils/pybind.h>
1617
#include <torch/csrc/utils/pycfunction_helpers.h>
1718

1819
#ifndef _WIN32

torch/csrc/autograd/python_function.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <pybind11/pybind11.h>
77
#include <structmember.h>
88
#include <torch/csrc/python_headers.h>
9+
#include <torch/csrc/utils/pybind.h>
910

1011
#include <ATen/FuncTorchTLS.h>
1112
#include <torch/csrc/DynamicTypes.h>

torch/csrc/autograd/python_hook.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <torch/csrc/THP.h>
77
#include <torch/csrc/autograd/python_variable.h>
88
#include <torch/csrc/utils/object_ptr.h>
9+
#include <torch/csrc/utils/pybind.h>
910
#include <torch/csrc/utils/python_strings.h>
1011

1112
#include <sstream>

torch/csrc/autograd/python_saved_variable_hooks.h

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <torch/csrc/autograd/python_variable.h>
77
#include <torch/csrc/autograd/saved_variable_hooks.h>
88
#include <torch/csrc/python_headers.h>
9+
#include <torch/csrc/utils/pybind.h>
910

1011
namespace py = pybind11;
1112

torch/csrc/autograd/python_variable.h

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <torch/csrc/Exceptions.h>
1010
#include <torch/csrc/Export.h>
1111
#include <torch/csrc/autograd/variable.h>
12+
#include <torch/csrc/utils/pybind.h>
1213

1314
namespace py = pybind11;
1415

torch/csrc/cuda/Event.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <torch/csrc/cuda/Event.h>
55
#include <torch/csrc/cuda/Module.h>
66
#include <torch/csrc/cuda/Stream.h>
7+
#include <torch/csrc/utils/pybind.h>
78
#include <torch/csrc/utils/pycfunction_helpers.h>
89
#include <torch/csrc/utils/python_arg_parser.h>
910

torch/csrc/cuda/Stream.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <torch/csrc/THP.h>
44
#include <torch/csrc/cuda/Module.h>
55
#include <torch/csrc/cuda/Stream.h>
6+
#include <torch/csrc/utils/pybind.h>
67
#include <torch/csrc/utils/python_numbers.h>
78

89
#include <c10/cuda/CUDAGuard.h>

torch/csrc/cuda/python_nccl.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <torch/csrc/Types.h>
99
#include <torch/csrc/cuda/THCP.h>
1010
#include <torch/csrc/cuda/nccl.h>
11+
#include <torch/csrc/utils/pybind.h>
1112

1213
#include <c10/cuda/CUDAGuard.h>
1314
#include <c10/util/irange.h>

torch/csrc/deploy/interpreter/interpreter_impl.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <torch/csrc/autograd/generated/variable_factories.h>
1212
#include <torch/csrc/deploy/Exception.h>
1313
#include <torch/csrc/jit/python/pybind_utils.h>
14+
#include <torch/csrc/utils/pybind.h>
1415

1516
#include <cassert>
1617
#include <cstdio>

torch/csrc/deploy/test_deploy_lib.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <pybind11/pybind11.h>
2+
#include <torch/csrc/utils/pybind.h>
23
#include <cstdint>
34
#include <cstdio>
45
#include <iostream>

torch/csrc/deploy/test_deploy_python_ext.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <caffe2/torch/csrc/deploy/deploy.h>
22
#include <pybind11/pybind11.h>
3+
#include <torch/csrc/utils/pybind.h>
34
#include <cstdint>
45
#include <cstdio>
56
#include <iostream>

torch/csrc/init_flatbuffer_module.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <pybind11/pytypes.h>
1010
#include <pybind11/stl.h>
1111
#include <pybind11/stl_bind.h>
12+
#include <torch/csrc/utils/pybind.h>
1213

1314
#include <Python.h> // NOLINT
1415
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>

torch/csrc/jit/backends/backend_init.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <torch/csrc/jit/backends/backend_resolver.h>
66
#include <torch/csrc/jit/python/module_python.h>
77
#include <torch/csrc/jit/python/pybind_utils.h>
8+
#include <torch/csrc/utils/pybind.h>
89

910
namespace torch {
1011
namespace jit {

torch/csrc/jit/backends/coreml/cpp/preprocess.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <torch/csrc/jit/backends/backend.h>
44
#include <torch/csrc/jit/backends/backend_preprocess.h>
55
#include <torch/csrc/jit/python/pybind_utils.h>
6+
#include <torch/csrc/utils/pybind.h>
67
#include <torch/script.h>
78

89
namespace py = pybind11;

torch/csrc/jit/backends/nnapi/nnapi_backend_preprocess.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <torch/csrc/jit/backends/backend.h>
33
#include <torch/csrc/jit/backends/backend_preprocess.h>
44
#include <torch/csrc/jit/python/pybind_utils.h>
5+
#include <torch/csrc/utils/pybind.h>
56

67
namespace py = pybind11;
78

torch/csrc/jit/python/module_python.h

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <pybind11/pybind11.h>
33
#include <pybind11/stl.h>
44
#include <torch/csrc/jit/api/module.h>
5+
#include <torch/csrc/utils/pybind.h>
56

67
namespace py = pybind11;
78

torch/csrc/jit/python/python_dict.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <torch/csrc/jit/python/pybind_utils.h>
55
#include <torch/csrc/jit/python/python_dict.h>
66
#include <torch/csrc/jit/runtime/jit_exception.h>
7+
#include <torch/csrc/utils/pybind.h>
78
#include <sstream>
89
#include <stdexcept>
910

torch/csrc/jit/python/python_interpreter.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <torch/csrc/autograd/python_engine.h>
2020
#include <torch/csrc/autograd/python_variable.h>
2121
#include <torch/csrc/jit/python/pybind.h>
22+
#include <torch/csrc/utils/pybind.h>
2223

2324
namespace py = pybind11;
2425

torch/csrc/jit/python/python_ivalue.h

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <pybind11/pybind11.h>
44
#include <torch/csrc/jit/python/pybind_utils.h>
55
#include <torch/csrc/python_headers.h>
6+
#include <torch/csrc/utils/pybind.h>
67

78
namespace py = pybind11;
89

torch/csrc/jit/python/python_list.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <pybind11/pytypes.h>
55
#include <torch/csrc/jit/python/pybind_utils.h>
66
#include <torch/csrc/jit/python/python_list.h>
7+
#include <torch/csrc/utils/pybind.h>
78
#include <stdexcept>
89

910
namespace torch {

0 commit comments

Comments
 (0)