
Commit 9eaf92a (parent: 76799fa)

Remove test uses of deprecated useless transform=None in favor of default_transform=None

14 files changed, +34/-26 lines
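Context for the change: in current PyMC, the transform keyword is reserved for chaining an extra transform on top of the default one, so passing transform=None no longer strips a distribution's default transform (it became a useless no-op, hence the deprecation warning this commit silences in the test suite). default_transform=None is the supported way to keep a variable on its natural scale. A minimal sketch of the new idiom, not taken from the diff itself:

import pymc as pm

with pm.Model() as m:
    # default_transform=None disables the default interval transform,
    # so the value variable stays on the natural [0, 1] scale
    p = pm.Uniform("p", 0, 1, default_transform=None)

assert m.rvs_to_transforms[p] is None
assert m.rvs_to_values[p].name == "p"  # no "_interval__" suffix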

pymc/model/transform/conditioning.py (1 addition, 1 deletion)

@@ -249,7 +249,7 @@ def change_value_transforms(
         from pymc.model.transform.conditioning import change_value_transforms
 
         with pm.Model() as base_m:
-            p = pm.Uniform("p", 0, 1, transform=None)
+            p = pm.Uniform("p", 0, 1, default_transform=None)
             w = pm.Binomial("w", n=9, p=p, observed=6)
 
         with change_value_transforms(base_m, {"p": logodds}) as transformed_p:

pymc/testing.py (1 addition, 1 deletion)

@@ -242,7 +242,7 @@ def build_model(distfam, valuedomain, vardomains, extra_args=None):
         distfam(
             "value",
             **param_vars,
-            transform=None,
+            default_transform=None,
         )
     return m, param_vars
248248

tests/distributions/test_continuous.py (1 addition, 1 deletion)

@@ -370,7 +370,7 @@ def test_wald_logp_custom_points(self, value, mu, lam, phi, alpha, logp):
         # See e.g., doi: 10.1111/j.1467-9876.2005.00510.x, or
         # http://www.gamlss.org/.
         with pm.Model() as model:
-            pm.Wald("wald", mu=mu, lam=lam, phi=phi, alpha=alpha, transform=None)
+            pm.Wald("wald", mu=mu, lam=lam, phi=phi, alpha=alpha, default_transform=None)
         point = {"wald": value}
         decimals = select_by_precision(float64=6, float32=1)
         npt.assert_almost_equal(

tests/distributions/test_distribution.py (1 addition, 1 deletion)

@@ -83,7 +83,7 @@ def test_issue_4499(self):
         # Test for bug in Uniform and DiscreteUniform logp when setting check_bounds = False
         # https://github.com/pymc-devs/pymc/issues/4499
         with pm.Model(check_bounds=False) as m:
-            x = pm.Uniform("x", 0, 2, size=10, transform=None)
+            x = pm.Uniform("x", 0, 2, size=10, default_transform=None)
         npt.assert_almost_equal(m.compile_logp()({"x": np.ones(10)}), -np.log(2) * 10)
 
         with pm.Model(check_bounds=False) as m:

tests/distributions/test_mixture.py (17 additions, 9 deletions)

@@ -433,7 +433,7 @@ def test_list_mvnormals_logp(self):
         cov2 = np.diag([2.5, 3.5])
         obs = np.asarray([[0.5, 0.5], mu1, mu2])
         with Model() as model:
-            w = Dirichlet("w", floatX(np.ones(2)), transform=None, shape=(2,))
+            w = Dirichlet("w", floatX(np.ones(2)), default_transform=None, shape=(2,))
             mvncomp1 = MvNormal.dist(mu=mu1, cov=cov1)
             mvncomp2 = MvNormal.dist(mu=mu2, cov=cov2)
             y = Mixture("x_obs", w, [mvncomp1, mvncomp2], observed=obs)

@@ -630,19 +630,27 @@ def test_nested_mixture(self):
         with Model() as model:
             # mixtures components
             g_comp = Normal.dist(
-                mu=Exponential("mu_g", lam=1.0, shape=nbr, transform=None), sigma=1, shape=nbr
+                mu=Exponential("mu_g", lam=1.0, shape=nbr, default_transform=None),
+                sigma=1,
+                shape=nbr,
             )
             l_comp = LogNormal.dist(
-                mu=Exponential("mu_l", lam=1.0, shape=nbr, transform=None), sigma=1, shape=nbr
+                mu=Exponential("mu_l", lam=1.0, shape=nbr, default_transform=None),
+                sigma=1,
+                shape=nbr,
             )
             # weight vector for the mixtures
-            g_w = Dirichlet("g_w", a=floatX(np.ones(nbr) * 0.0000001), transform=None, shape=(nbr,))
-            l_w = Dirichlet("l_w", a=floatX(np.ones(nbr) * 0.0000001), transform=None, shape=(nbr,))
+            g_w = Dirichlet(
+                "g_w", a=floatX(np.ones(nbr) * 0.0000001), default_transform=None, shape=(nbr,)
+            )
+            l_w = Dirichlet(
+                "l_w", a=floatX(np.ones(nbr) * 0.0000001), default_transform=None, shape=(nbr,)
+            )
             # mixture components
             g_mix = Mixture.dist(w=g_w, comp_dists=g_comp)
             l_mix = Mixture.dist(w=l_w, comp_dists=l_comp)
             # mixture of mixtures
-            mix_w = Dirichlet("mix_w", a=floatX(np.ones(2)), transform=None, shape=(2,))
+            mix_w = Dirichlet("mix_w", a=floatX(np.ones(2)), default_transform=None, shape=(2,))
             mix = Mixture("mix", w=mix_w, comp_dists=[g_mix, l_mix], observed=np.exp(norm_x))
 
         test_point = model.initial_point()

@@ -1306,9 +1314,9 @@ def test_hierarchical_interval_transform(self):
         with Model() as model:
             lower = Normal("lower", 0.5)
             upper = Uniform("upper", 0, 1)
-            uniform = Uniform("uniform", -pt.abs(lower), pt.abs(upper), transform=None)
+            uniform = Uniform("uniform", -pt.abs(lower), pt.abs(upper), default_transform=None)
             triangular = Triangular(
-                "triangular", -pt.abs(lower), pt.abs(upper), c=0.25, transform=None
+                "triangular", -pt.abs(lower), pt.abs(upper), c=0.25, default_transform=None
             )
             comp_dists = [
                 Uniform.dist(-pt.abs(lower), pt.abs(upper)),

@@ -1334,7 +1342,7 @@ def test_logp(self):
             halfnorm = HalfNormal("halfnorm")
             comp_dists = [HalfNormal.dist(), HalfNormal.dist()]
             mix_transf = Mixture("mix_transf", w=[0.5, 0.5], comp_dists=comp_dists)
-            mix = Mixture("mix", w=[0.5, 0.5], comp_dists=comp_dists, transform=None)
+            mix = Mixture("mix", w=[0.5, 0.5], comp_dists=comp_dists, default_transform=None)
 
         logp_fn = m.compile_logp(vars=[halfnorm, mix_transf, mix], sum=False)
         test_point = {"halfnorm_log__": 1, "mix_transf_log__": 1, "mix": np.exp(1)}
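The test_logp hunk above also shows how the choice surfaces in value-variable names: the transformed mixture is addressed through its mix_transf_log__ value on the log scale, while the untransformed mix is addressed by its own name on the natural scale. A small self-contained sketch of that naming convention (variable names are illustrative, not from the diff):

import numpy as np
import pymc as pm

with pm.Model() as m:
    transf = pm.HalfNormal("transf")  # default log transform applied
    untransf = pm.HalfNormal("untransf", default_transform=None)

logp_fn = m.compile_logp(sum=False)
# Transformed variables get a "<name>_log__" value variable;
# untransformed ones keep their own name and natural support.
print(logp_fn({"transf_log__": 1.0, "untransf": np.exp(1.0)}))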

tests/distributions/test_multivariate.py (3 additions, 3 deletions)

@@ -559,7 +559,7 @@ def test_wishart(self, n):
     @pytest.mark.parametrize("x,eta,n,lp", LKJ_CASES)
     def test_lkjcorr(self, x, eta, n, lp):
         with pm.Model() as model:
-            pm.LKJCorr("lkj", eta=eta, n=n, transform=None, return_matrix=False)
+            pm.LKJCorr("lkj", eta=eta, n=n, default_transform=None, return_matrix=False)
 
         point = {"lkj": x}
         decimals = select_by_precision(float64=6, float32=4)

@@ -790,7 +790,7 @@ def test_dirichlet_multinomial_vectorized(self, n, a, extra_size):
     )
     def test_stickbreakingweights_logp(self, value, alpha, K, logp):
         with pm.Model() as model:
-            sbw = pm.StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None)
+            sbw = pm.StickBreakingWeights("sbw", alpha=alpha, K=K, default_transform=None)
         point = {"sbw": value}
         npt.assert_almost_equal(
             pm.logp(sbw, value).eval(),

@@ -817,7 +817,7 @@ def test_stickbreakingweights_invalid(self):
     def test_stickbreakingweights_vectorized(self, alpha, K, stickbreakingweights_logpdf):
         value = pm.StickBreakingWeights.dist(alpha, K).eval()
         with pm.Model():
-            sbw = pm.StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None)
+            sbw = pm.StickBreakingWeights("sbw", alpha=alpha, K=K, default_transform=None)
         point = {"sbw": value}
         npt.assert_almost_equal(
             pm.logp(sbw, value).eval(),

tests/distributions/test_truncated.py (1 addition, 1 deletion)

@@ -386,7 +386,7 @@ def test_truncated_default_transform():
 def test_truncated_transform_logp():
     with Model() as m:
         base_dist = rejection_normal(0, 1)
-        x = Truncated("x", base_dist, lower=0, upper=None, transform=None)
+        x = Truncated("x", base_dist, lower=0, upper=None, default_transform=None)
         y = Truncated("y", base_dist, lower=0, upper=None)
     logp_eval = m.compile_logp(sum=False)({"x": -1, "y_interval__": -1})
     assert logp_eval[0] == -np.inf
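This test makes the practical consequence explicit: without a transform the sampler-facing value can fall outside the support and the logp is simply -inf, while the transformed y lives on an unconstrained scale (y_interval__) where every value is valid. A minimal sketch of the same behavior using a stock HalfNormal in place of the test's rejection_normal helper:

import numpy as np
import pymc as pm

with pm.Model() as m:
    x = pm.HalfNormal("x", default_transform=None)  # natural-scale value
    y = pm.HalfNormal("y")                          # log-scale value

logp = m.compile_logp(sum=False)({"x": -1.0, "y_log__": -1.0})
assert logp[0] == -np.inf      # -1 is outside the HalfNormal support
assert np.isfinite(logp[1])    # any real log-scale value is in-support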

tests/model/test_core.py (1 addition, 1 deletion)

@@ -1671,7 +1671,7 @@ def test_invalid_parameter_cant_be_evaluated(self, fn, verbose, capfd):
     def test_invalid_value(self, capfd):
         with pm.Model() as m:
             x = pm.Normal("x", [1, -1, 1])
-            y = pm.HalfNormal("y", tau=pm.math.abs(x), initval=[-1, 1, -1], transform=None)
+            y = pm.HalfNormal("y", tau=pm.math.abs(x), initval=[-1, 1, -1], default_transform=None)
         m.debug()
 
         out, _ = capfd.readouterr()

tests/model/transform/test_conditioning.py (1 addition, 1 deletion)

@@ -255,7 +255,7 @@ def test_do_self_reference():
 
 def test_change_value_transforms():
     with pm.Model() as base_m:
-        p = pm.Uniform("p", 0, 1, transform=None)
+        p = pm.Uniform("p", 0, 1, default_transform=None)
         w = pm.Binomial("w", n=9, p=p, observed=6)
     assert base_m.rvs_to_transforms[p] is None
     assert base_m.rvs_to_values[p].name == "p"

tests/models.py (2 additions, 2 deletions)

@@ -161,13 +161,13 @@ def mv_simple_discrete():
 
 def non_normal(n=2):
     with pm.Model() as model:
-        pm.Beta("x", 3, 3, size=n, transform=None)
+        pm.Beta("x", 3, 3, size=n, default_transform=None)
     return model.initial_point(), model, (np.tile([0.5], n), None)
 
 
 def beta_bernoulli(n=2):
     with pm.Model() as model:
-        pm.Beta("x", 3, 1, size=n, transform=None)
+        pm.Beta("x", 3, 1, size=n, default_transform=None)
         pm.Bernoulli("y", 0.5)
     return model.initial_point(), model, None
tests/sampling/test_jax.py (1 addition, 1 deletion)

@@ -504,7 +504,7 @@ def test_sample_var_names():
 def test_convergence_warnings(caplog, nuts_sampler):
     with pm.Model() as m:
         # Model that should diverge
-        sigma = pm.Normal("sigma", initval=3, transform=None)
+        sigma = pm.Normal("sigma", initval=3, default_transform=None)
         pm.Normal("obs", mu=0, sigma=sigma, observed=[0.99, 1.0, 1.01])
 
     with caplog.at_level(logging.WARNING, logger="pymc"):

tests/sampling/test_mcmc.py (1 addition, 1 deletion)

@@ -627,7 +627,7 @@ def test_exec_nuts_init(method):
 )
 def test_init_jitter(initval, jitter_max_retries, expectation):
     with pm.Model() as m:
-        pm.HalfNormal("x", transform=None, initval=initval)
+        pm.HalfNormal("x", default_transform=None, initval=initval)
 
     with expectation:
         # Starting value is negative (invalid) when np.random.rand returns 0 (jitter = -1)

tests/step_methods/hmc/test_nuts.py (2 additions, 2 deletions)

@@ -108,14 +108,14 @@ def test_multiple_samplers(self, caplog):
 
     def test_bad_init_nonparallel(self):
         with pm.Model():
-            pm.HalfNormal("a", sigma=1, initval=-1, transform=None)
+            pm.HalfNormal("a", sigma=1, initval=-1, default_transform=None)
             with pytest.raises(SamplingError) as error:
                 pm.sample(chains=1, random_seed=1)
             error.match("Initial evaluation")
 
     def test_bad_init_parallel(self):
         with pm.Model():
-            pm.HalfNormal("a", sigma=1, initval=-1, transform=None)
+            pm.HalfNormal("a", sigma=1, initval=-1, default_transform=None)
             with pytest.raises(SamplingError) as error:
                 pm.sample(cores=2, random_seed=1)
             error.match("Initial evaluation")
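These two tests exercise the failure path downstream of the same change: with the default log transform removed, initval=-1 is an invalid starting point for a HalfNormal, so pm.sample aborts during its initial evaluation. A condensed sketch, assuming SamplingError is importable from pymc.exceptions:

import pymc as pm
from pymc.exceptions import SamplingError

with pm.Model():
    # initval=-1 is outside the support once the log transform is disabled
    pm.HalfNormal("a", sigma=1, initval=-1, default_transform=None)
    try:
        pm.sample(chains=1, random_seed=1)
    except SamplingError as err:
        print(err)  # message starts with "Initial evaluation"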

tests/test_initial_point.py (1 addition, 1 deletion)

@@ -66,7 +66,7 @@ class TestInitvalEvaluation:
     def test_make_initial_point_fns_per_chain_checks_kwargs(self):
         with pm.Model() as pmodel:
             A = pm.Uniform("A", 0, 1, initval=0.5)
-            B = pm.Uniform("B", lower=A, upper=1.5, transform=None, initval="support_point")
+            B = pm.Uniform("B", lower=A, upper=1.5, default_transform=None, initval="support_point")
         with pytest.raises(ValueError, match="Number of initval dicts"):
             make_initial_point_fns_per_chain(
                 model=pmodel,