forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtest_dispatch.py
381 lines (354 loc) · 17 KB
/
test_dispatch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
import torch._C as C
from torch.testing._internal.common_utils import TestCase, run_tests
import itertools
import unittest
# TODO: Expand the dispatcher API to be a generic API for interfacing with
# the dispatcher from Python!
#
# These are exhaustive tests for commutativity of dispatch behavior. If you're
# looking for more usage-info style tests, check op_registration_test.cpp
#
# Things not tested here:
# - Listeners
# - Top level namespace registrations
# - Fallback
# - Exotic overloads of CppFunction/schema
#
# Things not directly tested here:
# - Internal state of Dispatcher makes sense. This is indirectly
# tested by the invariant testing
class TestDispatch(TestCase):
    """Tests that dispatcher registrations commute and can be undone.

    Each test describes a handful of registrations as callables taking a
    library handle; ``commute``/``run_ops`` apply them in various orders and
    compare the dispatcher dump seen for every subset of active registrations.
    """

    # Counter used to mint a fresh "__test<N>__" namespace per run_ops call.
    # It is bumped on the class, so all instances share it.
    namespace_index = 0

    def test_all_invariants(self):
        # Check that the regular stuff is OK!
        C._dispatch_check_all_invariants()

    # You probably don't want to call this directly; if your constructors
    # don't commute, you can still run commute with a fixed ctor_order
    # so that you can test that the destructors still commute
    def run_ops(self, name, ops, ctor_order=None, dtor_order=None,
                results=None, expect_raises=False):
        """
        Given a list of operator registrations, run the registrations in the
        order specified by ctor_order, and then run the deregistrations in
        dtor_order.

        If results is specified, intermediate results are checked for
        consistency with results stored in results (and stored in results if
        this is the first time we've seen them).  Results are expected to be
        equivalent modulo commutativity and inverses (thus, results is keyed
        on a frozenset of in effect registrations from ops).  Results stores
        Tuple[str, provenance], where provenance is a string that describes
        how exactly we got this string.

        If expect_raises is True, it is not an error to raise an exception.
        Instead, we'll store the exception string (instead of the dispatcher
        state) in results.  In principle we should flag these differently,
        but it's very obvious when you get an error in one case but not
        another.

        Args:
            name: unqualified operator name; it is looked up as
                "<test namespace>::<name>" when dumping dispatcher state.
            ops: list of callables, each taking a library handle and
                performing one registration on it.
            ctor_order: sequence of indices into ops giving the registration
                order; defaults to 0..len(ops)-1.
            dtor_order: sequence of indices into ops giving the
                deregistration order; defaults to reversed ctor_order.
            results: dict shared across calls, mapping frozenset of active
                op indices to (state string, provenance); a fresh dict is
                used when omitted.
            expect_raises: if True, a RuntimeError from a registration is
                recorded instead of propagated, and it is a test failure
                for every registration to succeed.

        Returns:
            The recorded string for the full set of ops, or, if a
            registration raised, the recorded error string for the set of
            ops that was active at the point of failure.
        """
        # By allocating every test into a fresh namespace, this makes it less
        # likely that a bug in the testing framework will result in tests
        # interfering with each other
        self.__class__.namespace_index += 1
        if results is None:
            results = {}
        if ctor_order is None:
            ctor_order = list(range(len(ops)))
        if dtor_order is None:
            dtor_order = list(reversed(ctor_order))
        # Refs which retain the c10::Module object so we can explicitly control
        # when each deregistration happens (deregistration occurs when the
        # object gets deallocated).
        refs = [None] * len(ops)
        # Keep track of the set "in effect" registrations
        active_ops = set()
        # double underscore to make it less likely we conflict with something
        # else
        test_namespace = "__test{}__".format(self.namespace_index)

        def check_invariants(actual_provenance):
            # NOTE(review): name is passed unqualified here, while the dump
            # below uses the namespaced name -- confirm this is intended.
            C._dispatch_check_invariants(name)
            # Normalize the test namespace so that expected outputs are stable
            actual = C._dispatch_dump(
                "{}::{}".format(test_namespace, name)).replace(test_namespace, "test")
            # First observation of this subset records it; later ones compare.
            expected, expected_provenance = results.setdefault(
                frozenset(active_ops),
                (actual, actual_provenance)
            )
            self.assertMultiLineEqual(
                expected, actual,
                "expected from {}; actual from {}"
                .format(expected_provenance, actual_provenance)
            )

        results.setdefault(frozenset(), ("", "hardcoded initial state"))
        check_invariants("initial state")
        # In the order specified by ctor_order, run registrations
        set_to_report = frozenset(range(len(ops)))
        # Position in ctor_order of the last registration that succeeded.
        # Starts at -1 so the dtor provenance below is well defined even if
        # the very first registration raises.
        last_ctor = -1
        for i, op_ix in enumerate(ctor_order):
            # It would be better to DEF here, but because we manage
            # lifetime of multiple registrations with multiple Library
            # references (refs), we can't deal with the strict checking
            # from DEF.
            refs[op_ix] = C._dispatch_library("FRAGMENT", test_namespace, "")
            active_ops.add(op_ix)
            try:
                ops[op_ix](refs[op_ix])
                check_invariants("running ctors {}".format(ctor_order[:i + 1]))
            except RuntimeError as e:
                if not expect_raises:
                    raise
                actual = str(e).replace(test_namespace, "test")
                expected, expected_provenance = results.setdefault(
                    frozenset(active_ops),
                    (actual, "error after running ctors {}".format(ctor_order[:i + 1]))
                )
                self.assertMultiLineEqual(expected, actual, expected_provenance)
                set_to_report = frozenset(active_ops)
                active_ops.remove(op_ix)
                # NB: this final check asserts that if a registration fails,
                # the dispatcher is left in the same state *that it was before*!
                check_invariants(
                    "running ctors {} and then failing to run ctor {} "
                    "(did this failure leave the dispatcher in a wedged state? "
                    "it shouldn't!)"
                    .format(ctor_order[:i], op_ix))
                break
            last_ctor = i
        if expect_raises and len(active_ops) == len(ops):
            # Destroy references first, as some test frameworks (like pytest)
            # will retain references in the exception raised by the failure! EW!
            refs = None
            self.fail(
                "expected exception to be raised, but nothing was raised "
                "(after running ctors {})".format(ctor_order))
        # In the order specified by dtor_order, run deregistrations
        for i, op_ix in enumerate(dtor_order):
            # Trigger a destruction
            refs[op_ix] = None
            # discard not remove, since we may not have actually deregistered
            # anything if there was an error raised
            if expect_raises:
                active_ops.discard(op_ix)
            else:
                active_ops.remove(op_ix)
            check_invariants(
                "running ctors {}, then running dtors {}"
                .format(ctor_order[:last_ctor + 1], dtor_order[:i + 1])
            )
        return results[set_to_report][0]

    # Operator registrations are commutative (as static initializers can
    # run in any order) and invertible (by deregistration).  (Subject
    # to some caveats: some legacy behavior in the system are not commutative--
    # we want to get rid of these!)
    #
    # So while in principle we could simply test a set of operations
    # by just running them one by one in the order specified by the user,
    # we can get more assurance about these extra properties by doing
    # more work:
    #
    # 1. Don't run the registrations once in a fixed order: run every possible
    #    permutation.  Similarly, run every permutation of deregistration order.
    #
    # 2. Don't just check the end state of the dispatcher: for every
    #    subset of operator registrations, ensure that the computed
    #    intermediate state is path independent.  One thing to note:
    #    in this function, we assume each operation is unique.  In general,
    #    there may be duplicated registrations, but these are usually
    #    idempotent or legacy.  We test for behavior here separately.
    #
    # NB: checking all permutations means this function is exponential in
    # the length of ops (every ctor permutation is paired with every dtor
    # permutation)!  So don't pass too many ops to this function!
    def commute(self, name, ops, ctor_order=None, expect_raises=False):
        """Run ops under every dtor order (and every ctor order unless fixed).

        All runs share one results dict, so the state recorded for a given
        subset of active registrations must agree across every ordering.
        Returns the recorded string for the full set of ops.
        """
        results = {}

        def go(ctor_order):
            for dtor_order in itertools.permutations(range(len(ops))):
                self.run_ops(
                    name, ops, ctor_order, dtor_order,
                    results=results, expect_raises=expect_raises)

        if ctor_order is not None:
            go(ctor_order)
        else:
            for ctor_order in itertools.permutations(range(len(ops))):
                go(ctor_order)

        # Return the "full" state after all operations are run.
        # If this KeyErrors, that means that there did not exist any
        # ordering of ctors which got us to the "end".  That's an
        # error in test construction: it means you could have
        # factored the test into two smaller ones.
        return results[frozenset(range(len(ops)))][0]

    def test_def(self):
        r = self.commute("foo", [
            # m.def("foo(Tensor x) -> Tensor")
            lambda m: m.def_("foo(Tensor x) -> Tensor"),
            # m.impl("foo", [](const Tensor& x) { return x })
            lambda m: m.impl_t_t("foo"),
            # m.impl("foo", kAutograd, [](const Tensor& x) { return x })
            lambda m: m.impl_t_t("foo", dispatch="autograd")
        ])
        self.assertExpectedInline(r, '''\
name: test::foo
schema: test::foo(Tensor x) -> (Tensor)
debug: registered at /dev/null:0
alias analysis kind: FROM_SCHEMA
Autograd: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
catchall: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
''')

    def test_def_impl_schema_mismatch(self):
        # NB: an impl-impl mismatch is not reported eagerly; you'll find out
        # about it because one of them won't match with def
        r = self.commute("foo", [
            # m.def("foo(Tensor x, Tensor y) -> Tensor")
            lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"),
            # m.impl("foo", [](const Tensor & x) { return x })
            lambda m: m.impl_t_t("foo"),
        ], expect_raises=True)
        self.assertExpectedInline(r, '''In registration for test::foo: expected schema of operator to be "test::foo(Tensor x, Tensor y) -> (Tensor)" (registered at /dev/null:0), but got inferred schema "(Tensor _0) -> (Tensor _0)" (impl_t_t). The number of arguments is different. 2 vs 1.''')  # noqa

    def test_def_with_inference(self):
        r = self.commute("foo", [
            # m.def("foo", [](const Tensor & x) { return x })
            lambda m: m.def_name_t_t("foo"),
            # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
            lambda m: m.impl_t_t("foo", dispatch="autograd")
        ])
        self.assertExpectedInline(r, '''\
name: test::foo
schema: test::foo(Tensor _0) -> (Tensor _0)
debug: registered at /dev/null:0
alias analysis kind: CONSERVATIVE
Autograd: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
''')

    def test_def_only(self):
        r = self.commute("foo", [
            # m.def("foo(Tensor x, Tensor y) -> Tensor")
            lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"),
        ])
        self.assertExpectedInline(r, '''\
name: test::foo
schema: test::foo(Tensor x, Tensor y) -> (Tensor)
debug: registered at /dev/null:0
alias analysis kind: FROM_SCHEMA
''')

    def test_impl_only(self):
        r = self.commute("foo", [
            # m.impl("foo", [](const Tensor& x) { return x })
            lambda m: m.impl_t_t("foo"),
            # m.impl("foo", torch::kAutograd, [](const Tensor& x) { return x })
            lambda m: m.impl_t_t("foo", dispatch="autograd")
        ])
        self.assertExpectedInline(r, '''\
name: test::foo
schema: (none)
Autograd: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
catchall: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
''')

    # Can't do this yet for BC reasons
    @unittest.expectedFailure
    def test_multiple_def_error(self):
        # No registration raises here, so run_ops itself fails the test
        # ("expected exception to be raised"), which is the expected failure.
        self.commute("foo", [
            # m.def("foo(Tensor x, Tensor y) -> Tensor")
            lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"),
            # m.def("foo(Tensor x, Tensor y) -> Tensor")
            lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"),
        ], expect_raises=True)
        # TODO: fill in the error message here
        # self.assertExpectedInline(r, '''''')

    def test_def_with_explicit_alias(self):
        r = self.commute("foo", [
            # m.def(torch::schema(
            #   "foo(Tensor x, Tensor y) -> Tensor",
            #   AliasAnalysisKind::PURE_FUNCTION))
            lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor",
                             alias="PURE_FUNCTION")
        ])
        self.assertExpectedInline(r, '''\
name: test::foo
schema: test::foo(Tensor x, Tensor y) -> (Tensor)
debug: registered at /dev/null:0
alias analysis kind: PURE_FUNCTION
''')

    # TODO: get rid of this test when multiple defs are wrong
    def test_multiple_def_schema_mismatch(self):
        # error message is order dependent
        ops = [
            # m.def("foo(Tensor x, Tensor y) -> Tensor")
            lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"),
            # m.def("foo(Tensor x) -> Tensor")
            lambda m: m.def_("foo(Tensor x) -> Tensor"),
        ]
        self.assertExpectedInline(
            self.commute("foo", ops, ctor_order=(0, 1), expect_raises=True),
            '''Tried to register multiple operators with the same name and the same overload name but different schemas: test::foo(Tensor x) -> (Tensor) (registered at /dev/null:0) vs test::foo(Tensor x, Tensor y) -> (Tensor) (registered at /dev/null:0)'''  # noqa
        )
        self.assertExpectedInline(
            self.commute("foo", ops, ctor_order=(1, 0), expect_raises=True),
            '''Tried to register multiple operators with the same name and the same overload name but different schemas: test::foo(Tensor x, Tensor y) -> (Tensor) (registered at /dev/null:0) vs test::foo(Tensor x) -> (Tensor) (registered at /dev/null:0)'''  # noqa
        )

    def test_multiple_def_alias_defaulting(self):
        # TODO: should be an error in both directions soon
        ops = [
            # m.def(torch::schema("foo(Tensor x) -> Tensor",
            #                     c10::AliasAnalysisKind::PURE_FUNCTION))
            lambda m: m.def_("foo(Tensor x) -> Tensor", alias="PURE_FUNCTION"),
            # RegisterOperators().op("foo(Tensor x) -> Tensor")
            lambda m: m.def_legacy("foo(Tensor x) -> Tensor"),
        ]
        self.assertExpectedInline(
            self.commute("foo", ops, ctor_order=(0, 1)),
            '''\
name: test::foo
schema: test::foo(Tensor x) -> (Tensor)
debug: registered at /dev/null:0
alias analysis kind: PURE_FUNCTION
'''
        )
        # NB: When run with ctor order (1, 0), the destructors are NOT
        # COMMUTATIVE.  THIS IS A BUG, however we are purposely leaving the bug
        # in as it is very benign (only leaves us in a bad state during
        # destruction, when no useful work is being done), will be fixed when we
        # make alias defaulting a hard error, and is very nontrivial to fix
        # prior to that.

    def test_multiple_def_alias_mismatch(self):
        # error message is order dependent
        ops = [
            # m.def(torch::schema("foo(Tensor x) -> Tensor",
            #                     c10::AliasAnalysisKind::PURE_FUNCTION))
            lambda m: m.def_("foo(Tensor x) -> Tensor", alias="PURE_FUNCTION"),
            # m.def(torch::schema("foo(Tensor x) -> Tensor",
            #                     c10::AliasAnalysisKind::CONSERVATIVE))
            lambda m: m.def_("foo(Tensor x) -> Tensor", alias="CONSERVATIVE"),
        ]
        self.assertExpectedInline(
            self.commute("foo", ops, ctor_order=(0, 1), expect_raises=True),
            '''Tried to define the schema for test::foo with different alias analysis kinds: PURE_FUNCTION (registered at /dev/null:0) vs CONSERVATIVE (registered at /dev/null:0)'''  # noqa
        )
        self.assertExpectedInline(
            self.commute("foo", ops, ctor_order=(1, 0), expect_raises=True),
            '''Tried to define the schema for test::foo with different alias analysis kinds: CONSERVATIVE (registered at /dev/null:0) vs PURE_FUNCTION (registered at /dev/null:0)'''  # noqa
        )

    def test_multiple_fallback(self):
        global_m = C._dispatch_library("IMPL", "_", "xla")
        global_m.fallback_fallthrough()
        try:
            # A second fallback for the same dispatch key must be rejected.
            global_m.fallback_fallthrough()
        except RuntimeError as e:
            self.assertExpectedInline(
                str(e),
                '''Tried to register multiple backend fallbacks for the same dispatch key XLA; previous registration registered at /dev/null:0, new registration registered at /dev/null:0'''  # noqa
            )
        else:
            self.fail(
                "expected the second fallback registration to raise RuntimeError")

    def test_overwrite_catchall(self):
        ops = [
            lambda m: m.impl_t_t("foo", debug="fn1"),
            lambda m: m.impl_t_t("foo", debug="fn2"),
        ]
        # Not commutative
        self.assertExpectedInline(
            self.commute("foo", ops, ctor_order=(0, 1)),
            '''\
name: test::foo
schema: (none)
catchall: fn2 :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
catchall (inactive): fn1 :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
'''
        )
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    run_tests()