Skip to content

Commit 44393bb

Browse files
committed
make doctests pass for pymc_model.py #323
1 parent 15b5756 commit 44393bb

File tree

1 file changed

+26
-12
lines changed

1 file changed

+26
-12
lines changed

causalpy/pymc_models.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,17 @@ class ModelBuilder(pm.Model):
7272
... }
7373
... )
7474
>>> model.fit(X, y)
75-
Inference...
75+
<BLANKLINE>
76+
<BLANKLINE>
77+
Inference data...
7678
>>> X_new = rng.normal(loc=0, scale=1, size=(20,2))
7779
>>> model.predict(X_new)
78-
Inference...
79-
>>> model.score(X, y) # doctest: +NUMBER
80-
r2 0.3
81-
r2_std 0.0
80+
<BLANKLINE>
81+
Inference data...
82+
>>> model.score(X, y)
83+
<BLANKLINE>
84+
r2 0.390344
85+
r2_std 0.081135
8286
dtype: float64
8387
"""
8488

@@ -112,10 +116,7 @@ def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
112116

113117
# Ensure random_seed is used in sample_prior_predictive() and
114118
# sample_posterior_predictive() if provided in sample_kwargs.
115-
if "random_seed" in self.sample_kwargs:
116-
random_seed = self.sample_kwargs["random_seed"]
117-
else:
118-
random_seed = None
119+
random_seed = self.sample_kwargs.get("random_seed", None)
119120

120121
self.build_model(X, y, coords)
121122
with self:
@@ -137,10 +138,17 @@ def predict(self, X):
137138
138139
"""
139140

141+
# Ensure random_seed is used in sample_prior_predictive() and
142+
# sample_posterior_predictive() if provided in sample_kwargs.
143+
random_seed = self.sample_kwargs.get("random_seed", None)
144+
140145
self._data_setter(X)
141146
with self: # sample with new input data
142147
post_pred = pm.sample_posterior_predictive(
143-
self.idata, var_names=["y_hat", "mu"], progressbar=False
148+
self.idata,
149+
var_names=["y_hat", "mu"],
150+
progressbar=False,
151+
random_seed=random_seed,
144152
)
145153
return post_pred
146154

@@ -193,7 +201,9 @@ class WeightedSumFitter(ModelBuilder):
193201
>>> y = np.asarray(sc['actual']).reshape((sc.shape[0], 1))
194202
>>> wsf = WeightedSumFitter(sample_kwargs={"progressbar": False})
195203
>>> wsf.fit(X,y)
196-
Inference ...
204+
<BLANKLINE>
205+
<BLANKLINE>
206+
Inference data...
197207
""" # noqa: W605
198208

199209
def build_model(self, X, y, coords):
@@ -249,7 +259,9 @@ class LinearRegression(ModelBuilder):
249259
... 'obs_indx': np.arange(rd.shape[0])
250260
... },
251261
... )
252-
Inference...
262+
<BLANKLINE>
263+
<BLANKLINE>
264+
Inference data...
253265
""" # noqa: W605
254266

255267
def build_model(self, X, y, coords):
@@ -301,6 +313,8 @@ class InstrumentalVariableRegression(ModelBuilder):
301313
... "eta": 2,
302314
... "lkj_sd": 2,
303315
... })
316+
<BLANKLINE>
317+
<BLANKLINE>
304318
Inference data...
305319
"""
306320

0 commit comments

Comments
 (0)