forked from hjimce/Depth-Map-Prediction
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy paththeano_test_value_size.patch
58 lines (51 loc) · 2.52 KB
/
theano_test_value_size.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
diff --git a/theano/configdefaults.py b/theano/configdefaults.py
index 58ed2e9..97d6564 100644
--- a/theano/configdefaults.py
+++ b/theano/configdefaults.py
@@ -470,6 +470,14 @@ AddConfigVar('compute_test_value',
EnumStr('off', 'ignore', 'warn', 'raise', 'pdb'),
in_c_key=False)
+AddConfigVar('store_test_value_maxsize',
+ ("Maximum size for test values that are kept. If compute_test_value "
+ "is enabled, keeps test values smaller than the given size (in "
+ "number of entries). Beyond that, only the shape is stored; a "
+ "an array with the same shape and type is created on demand, filled "
+ "with a single random entry from the array."),
+ IntParam(sys.maxint),
+ in_c_key=False)
AddConfigVar('compute_test_value_opt',
("For debugging Theano optimization only."
diff --git a/theano/gof/op.py b/theano/gof/op.py
index ac85eec..a306077 100644
--- a/theano/gof/op.py
+++ b/theano/gof/op.py
@@ -18,6 +18,7 @@ import numpy
import os
import sys
import warnings
+import numpy
import theano
from theano import config
@@ -461,6 +462,10 @@ class PureOp(object):
elif isinstance(v, graph.Variable) and hasattr(v.tag, 'test_value'):
# ensure that the test value is correct
return v.type.filter(v.tag.test_value)
+ elif isinstance(v, graph.Variable) and hasattr(v.tag, 'test_shape'):
+ test_value = numpy.empty(v.tag.test_shape, dtype=v.type.dtype)
+ test_value.fill(v.tag.test_value_fill)
+ return v.type.filter(test_value, strict=False, allow_downcast=True)
raise AttributeError('%s has no test value' % v)
@@ -552,7 +557,14 @@ class PureOp(object):
# add 'test_value' to output tag, so that downstream ops can use these
# numerical values as inputs to their perform method.
- output.tag.test_value = storage_map[output][0]
+ test_value = storage_map[output][0]
+ if not hasattr(test_value, 'size') or \
+ test_value.size < config.store_test_value_maxsize:
+ output.tag.test_value = test_value
+ elif hasattr(test_value, 'shape'):
+ test_value = numpy.asarray(test_value)
+ output.tag.test_shape = test_value.shape
+ output.tag.test_value_fill = test_value.flat[0]
if self.default_output is not None:
rval = node.outputs[self.default_output]