
Variationally Inferred Parameterization

Author: Madhav Kanda

Occasionally, the Hamiltonian Monte Carlo (HMC) sampler struggles to sample effectively from the posterior distribution; Neal's funnel is a classic example. In such situations the conventional centered parameterization may prove inadequate, leading us to employ a non-centered parameterization. However, there are cases where even the non-centered parameterization does not suffice, and Variationally Inferred Parameterization (VIP) can be used to learn the appropriate degree of centeredness, a value between 0 and 1, for each latent variable.

The purpose of this tutorial is to implement Variationally Inferred Parameterization, based on Automatic Reparameterization of Probabilistic Programs [1], using LocScaleReparam in NumPyro.
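
As a warm-up, here is a minimal, self-contained sketch (not part of this tutorial's pipeline) of how LocScaleReparam expresses the centered/non-centered choice on Neal's funnel; passing centered=0.0 gives the fully non-centered form, while centered=None (used later in this tutorial) makes the centeredness learnable:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import reparam
from numpyro.infer.reparam import LocScaleReparam

# Neal's funnel: the scale of x depends exponentially on y, which makes
# the centered parameterization hard for HMC to explore.
def funnel():
    y = numpyro.sample("y", dist.Normal(0.0, 3.0))
    numpyro.sample("x", dist.Normal(0.0, jnp.exp(y / 2.0)))

# centered=0.0 -> fully non-centered; centered=None -> learnable centeredness.
noncentered_funnel = reparam(funnel, config={"x": LocScaleReparam(centered=0.0)})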

[ ]:
%pip -qq install numpyro
%pip -qq install ucimlrepo
[ ]:
import arviz as az
import numpy as np
from ucimlrepo import fetch_ucirepo

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDiagonalNormal
from numpyro.infer.reparam import LocScaleReparam

rng_key = jax.random.PRNGKey(0)

1. Dataset

We will be using the German Credit dataset for this illustration. The dataset, prepared by Prof. Hofmann, consists of 1000 entries with 20 categorical/symbolic attributes. Each entry represents a person who takes a credit from a bank, and each person is classified as a good or bad credit risk according to the set of attributes.

[ ]:
def load_german_credit():
    statlog_german_credit_data = fetch_ucirepo(id=144)
    X = statlog_german_credit_data.data.features
    y = statlog_german_credit_data.data.targets
    return X, y
[ ]:
X, y = load_german_credit()
X
Attribute1 Attribute2 Attribute3 Attribute4 Attribute5 Attribute6 Attribute7 Attribute8 Attribute9 Attribute10 Attribute11 Attribute12 Attribute13 Attribute14 Attribute15 Attribute16 Attribute17 Attribute18 Attribute19 Attribute20
0 A11 6 A34 A43 1169 A65 A75 4 A93 A101 4 A121 67 A143 A152 2 A173 1 A192 A201
1 A12 48 A32 A43 5951 A61 A73 2 A92 A101 2 A121 22 A143 A152 1 A173 1 A191 A201
2 A14 12 A34 A46 2096 A61 A74 2 A93 A101 3 A121 49 A143 A152 1 A172 2 A191 A201
3 A11 42 A32 A42 7882 A61 A74 2 A93 A103 4 A122 45 A143 A153 1 A173 2 A191 A201
4 A11 24 A33 A40 4870 A61 A73 3 A93 A101 4 A124 53 A143 A153 2 A173 2 A191 A201
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
995 A14 12 A32 A42 1736 A61 A74 3 A92 A101 4 A121 31 A143 A152 1 A172 1 A191 A201
996 A11 30 A32 A41 3857 A61 A73 4 A91 A101 4 A122 40 A143 A152 1 A174 1 A192 A201
997 A14 12 A32 A43 804 A61 A75 4 A93 A101 4 A123 38 A143 A152 1 A173 1 A191 A201
998 A11 45 A32 A43 1845 A61 A73 4 A93 A101 4 A124 23 A143 A153 1 A173 1 A192 A201
999 A12 45 A34 A41 4576 A62 A71 3 A93 A101 4 A123 27 A143 A152 1 A173 1 A191 A201

1000 rows × 20 columns

Here, X holds the 20 attribute values for each person in the dataset, and y is the corresponding target variable.

[ ]:
def data_transform(X, y):
    def categorical_to_int(x):
        d = {u: i for i, u in enumerate(np.unique(x))}
        return np.array([d[i] for i in x])

    categoricals = []
    numericals = []
    # Intercept column of ones.
    numericals.append(np.ones([len(y)]))
    for column in X:
        column = X[column]
        if column.dtype == "O":
            # Object-dtype columns are categorical: map labels to integer codes.
            categoricals.append(categorical_to_int(column))
        else:
            # Numeric columns are standardized to zero mean and unit variance.
            numericals.append((column - column.mean()) / column.std())
    numericals = np.array(numericals).T
    # Binary target: 1 where the target class equals 1 (good credit risk).
    status = np.squeeze(np.array(y == 1, dtype=np.int32))
    return jnp.array(numericals), jnp.array(categoricals), jnp.array(status)

Now we transform the data so that it can be fed into the NumPyro model:

[ ]:
numericals, categoricals, status = data_transform(X, y)
[ ]:
x_numeric = numericals.astype(jnp.float32)
x_categorical = [jnp.eye(c.max() + 1)[c] for c in categoricals]
all_x = jnp.concatenate([x_numeric] + x_categorical, axis=1)
num_features = all_x.shape[1]
y = status[jnp.newaxis, Ellipsis]
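
As a quick sanity check of the shapes produced above (num_features works out to 62 for this dataset, matching the beta entries in the summaries below):

print(all_x.shape)   # (1000, 62): intercept + standardized numeric + one-hot columns
print(y.shape)       # (1, 1000)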

2. Model

We will be using a logistic regression model with a hierarchical prior on the coefficient scales:

\begin{align} \log \tau_0 & \sim \mathcal{N}(0,10) & \log \tau_i & \sim \mathcal{N}\left(\log \tau_0, 1\right) \\ \beta_i & \sim \mathcal{N}\left(0, \tau_i\right) & y & \sim \operatorname{Bernoulli}\left(\sigma\left(\beta X^T\right)\right) \end{align}

[ ]:
def german_credit():
    log_tau_zero = numpyro.sample("log_tau_zero", dist.Normal(0, 10))
    log_tau_i = numpyro.sample(
        "log_tau_i", dist.Normal(log_tau_zero, jnp.ones(num_features))
    )
    beta = numpyro.sample(
        "beta", dist.Normal(jnp.zeros(num_features), jnp.exp(log_tau_i))
    )
    numpyro.sample(
        "obs",
        dist.Bernoulli(logits=jnp.einsum("nd,md->mn", all_x, beta[jnp.newaxis, :])),
        obs=y,
    )
[ ]:
nuts_kernel = NUTS(german_credit)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc.run(rng_key, extra_fields=("num_steps",))
sample: 100%|██████████| 2000/2000 [00:21<00:00, 94.07it/s, 63 steps of size 6.31e-02. acc. prob=0.87]
[ ]:
mcmc.print_summary()
 mean std median 5.0% 95.0% n_eff r_hat
 beta[0] 0.13 0.38 0.05 -0.36 0.74 284.06 1.00
 beta[1] -0.34 0.12 -0.34 -0.52 -0.15 621.55 1.00
 beta[2] -0.27 0.13 -0.27 -0.45 -0.03 542.13 1.00
 beta[3] -0.30 0.10 -0.30 -0.44 -0.11 566.55 1.00
 beta[4] -0.00 0.07 -0.00 -0.12 0.11 782.35 1.00
 beta[5] 0.12 0.09 0.11 -0.02 0.27 728.28 1.01
 beta[6] -0.08 0.08 -0.07 -0.22 0.05 822.89 1.00
 beta[7] -0.05 0.07 -0.04 -0.19 0.05 752.66 1.00
 beta[8] -0.42 0.32 -0.39 -0.87 0.05 198.00 1.00
 beta[9] -0.07 0.26 -0.02 -0.50 0.31 220.27 1.00
 beta[10] 0.26 0.31 0.18 -0.15 0.78 404.97 1.00
 beta[11] 1.23 0.34 1.25 0.68 1.79 227.34 1.01
 beta[12] -0.26 0.34 -0.17 -0.81 0.22 349.10 1.00
 beta[13] -0.30 0.34 -0.21 -0.86 0.13 387.72 1.00
 beta[14] 0.07 0.20 0.04 -0.26 0.38 240.45 1.03
 beta[15] 0.10 0.22 0.05 -0.18 0.50 287.41 1.02
 beta[16] 0.76 0.30 0.76 0.22 1.24 364.73 1.03
 beta[17] -0.53 0.28 -0.55 -0.94 -0.05 269.95 1.00
 beta[18] 0.70 0.42 0.70 -0.02 1.29 367.28 1.00
 beta[19] 0.17 0.40 0.06 -0.43 0.77 333.54 1.00
 beta[20] 0.03 0.19 0.01 -0.23 0.39 381.57 1.00
 beta[21] 0.18 0.22 0.13 -0.14 0.53 335.48 1.00
 beta[22] -0.05 0.32 -0.01 -0.56 0.46 439.54 1.00
 beta[23] -0.10 0.30 -0.04 -0.63 0.30 508.20 1.00
 beta[24] -0.34 0.36 -0.25 -0.94 0.12 283.15 1.00
 beta[25] 0.14 0.40 0.04 -0.46 0.71 433.69 1.00
 beta[26] -0.01 0.19 -0.00 -0.34 0.28 438.64 1.00
 beta[27] -0.36 0.27 -0.33 -0.78 0.04 377.33 1.01
 beta[28] -0.07 0.22 -0.03 -0.43 0.26 493.09 1.00
 beta[29] 0.01 0.22 0.00 -0.32 0.34 448.21 1.00
 beta[30] 0.35 0.43 0.22 -0.18 1.08 314.69 1.00
 beta[31] 0.41 0.33 0.40 -0.10 0.90 402.62 1.00
 beta[32] -0.03 0.21 -0.01 -0.39 0.30 525.23 1.00
 beta[33] -0.12 0.18 -0.09 -0.41 0.16 334.94 1.00
 beta[34] -0.02 0.16 -0.01 -0.24 0.26 318.25 1.00
 beta[35] 0.42 0.27 0.42 -0.04 0.81 455.99 1.00
 beta[36] 0.05 0.17 0.03 -0.18 0.35 506.34 1.00
 beta[37] -0.12 0.25 -0.06 -0.57 0.21 470.11 1.00
 beta[38] -0.07 0.20 -0.04 -0.39 0.24 410.71 1.00
 beta[39] 0.36 0.24 0.35 -0.04 0.71 359.55 1.00
 beta[40] 0.05 0.20 0.02 -0.29 0.35 441.70 1.00
 beta[41] -0.00 0.21 0.00 -0.34 0.37 513.67 1.00
 beta[42] -0.13 0.27 -0.08 -0.59 0.23 402.64 1.00
 beta[43] 0.55 0.46 0.49 -0.11 1.28 570.74 1.00
 beta[44] 0.19 0.21 0.15 -0.14 0.50 379.76 1.00
 beta[45] -0.00 0.16 0.00 -0.25 0.26 352.19 1.00
 beta[46] 0.01 0.16 0.01 -0.25 0.25 411.05 1.00
 beta[47] -0.16 0.24 -0.11 -0.55 0.18 455.59 1.00
 beta[48] -0.12 0.24 -0.07 -0.55 0.21 322.67 1.04
 beta[49] -0.04 0.23 -0.02 -0.45 0.30 437.47 1.02
 beta[50] 0.38 0.28 0.37 -0.03 0.82 266.19 1.04
 beta[51] -0.14 0.22 -0.09 -0.52 0.16 406.31 1.00
 beta[52] 0.19 0.23 0.14 -0.14 0.55 338.97 1.00
 beta[53] 0.04 0.22 0.02 -0.23 0.43 438.03 1.00
 beta[54] 0.05 0.24 0.02 -0.32 0.41 522.43 1.00
 beta[55] 0.02 0.14 0.01 -0.22 0.23 562.00 1.00
 beta[56] -0.01 0.13 -0.01 -0.24 0.21 638.20 1.00
 beta[57] 0.01 0.17 0.00 -0.25 0.34 590.99 1.00
 beta[58] -0.07 0.18 -0.04 -0.34 0.23 481.37 1.00
 beta[59] 0.13 0.19 0.09 -0.12 0.47 507.56 1.00
 beta[60] -0.14 0.33 -0.06 -0.64 0.37 303.00 1.00
 beta[61] 0.48 0.56 0.32 -0.18 1.41 438.86 1.00
 log_tau_i[0] -1.51 0.95 -1.52 -3.03 0.11 290.78 1.00
 log_tau_i[1] -1.07 0.67 -1.11 -2.12 0.03 641.04 1.00
 log_tau_i[2] -1.24 0.76 -1.26 -2.47 0.03 666.31 1.00
 log_tau_i[3] -1.16 0.65 -1.19 -2.20 -0.10 821.60 1.00
 log_tau_i[4] -2.11 0.88 -2.13 -3.50 -0.61 806.15 1.00
 log_tau_i[5] -1.71 0.86 -1.68 -3.28 -0.44 697.00 1.00
 log_tau_i[6] -1.88 0.84 -1.91 -3.30 -0.58 623.56 1.00
 log_tau_i[7] -1.99 0.90 -1.98 -3.51 -0.65 710.21 1.00
 log_tau_i[8] -1.00 0.86 -0.96 -2.23 0.52 445.30 1.00
 log_tau_i[9] -1.69 0.93 -1.63 -3.17 -0.14 326.33 1.00
 log_tau_i[10] -1.41 0.95 -1.35 -2.93 0.19 441.60 1.01
 log_tau_i[11] -0.11 0.57 -0.12 -0.97 0.80 539.60 1.00
 log_tau_i[12] -1.36 0.96 -1.31 -3.16 0.01 336.11 1.00
 log_tau_i[13] -1.30 0.95 -1.26 -2.85 0.28 335.04 1.00
 log_tau_i[14] -1.72 0.89 -1.70 -3.05 -0.25 584.38 1.00
 log_tau_i[15] -1.65 0.92 -1.63 -3.07 -0.10 345.77 1.03
 log_tau_i[16] -0.51 0.65 -0.49 -1.42 0.59 676.64 1.00
 log_tau_i[17] -0.84 0.76 -0.76 -2.09 0.34 303.14 1.00
 log_tau_i[18] -0.69 0.82 -0.59 -2.03 0.61 359.35 1.00
 log_tau_i[19] -1.45 0.99 -1.42 -2.97 0.25 397.18 1.00
 log_tau_i[20] -1.75 0.94 -1.73 -3.39 -0.40 617.54 1.00
 log_tau_i[21] -1.51 0.88 -1.49 -3.16 -0.27 488.52 1.00
 log_tau_i[22] -1.56 0.93 -1.56 -3.06 -0.10 348.20 1.00
 log_tau_i[23] -1.58 0.94 -1.57 -3.05 0.02 278.69 1.00
 log_tau_i[24] -1.26 1.00 -1.12 -2.91 0.29 205.38 1.00
 log_tau_i[25] -1.53 0.95 -1.56 -3.11 0.02 351.09 1.00
 log_tau_i[26] -1.73 0.91 -1.74 -3.17 -0.22 492.18 1.00
 log_tau_i[27] -1.15 0.89 -1.08 -2.66 0.17 485.34 1.00
 log_tau_i[28] -1.69 0.92 -1.65 -3.15 -0.19 425.75 1.00
 log_tau_i[29] -1.71 0.99 -1.71 -3.19 0.01 374.58 1.00
 log_tau_i[30] -1.24 0.99 -1.20 -2.74 0.50 327.47 1.00
 log_tau_i[31] -1.02 0.89 -0.91 -2.40 0.51 587.85 1.00
 log_tau_i[32] -1.71 0.94 -1.70 -3.22 -0.11 511.74 1.00
 log_tau_i[33] -1.69 0.90 -1.68 -3.13 -0.28 538.65 1.00
 log_tau_i[34] -1.82 0.92 -1.81 -3.35 -0.35 423.01 1.00
 log_tau_i[35] -1.06 0.82 -1.00 -2.30 0.34 470.50 1.00
 log_tau_i[36] -1.79 0.87 -1.76 -3.15 -0.34 527.47 1.00
 log_tau_i[37] -1.58 0.95 -1.54 -3.11 0.04 485.52 1.00
 log_tau_i[38] -1.71 0.87 -1.65 -3.18 -0.34 482.67 1.00
 log_tau_i[39] -1.12 0.85 -1.01 -2.44 0.33 337.59 1.00
 log_tau_i[40] -1.76 0.96 -1.73 -3.58 -0.36 533.15 1.00
 log_tau_i[41] -1.74 0.94 -1.70 -3.26 -0.22 500.91 1.00
 log_tau_i[42] -1.57 0.95 -1.54 -3.04 0.01 499.44 1.00
 log_tau_i[43] -0.87 0.93 -0.74 -2.28 0.58 445.98 1.00
 log_tau_i[44] -1.52 0.89 -1.45 -2.95 -0.13 442.63 1.00
 log_tau_i[45] -1.84 0.94 -1.79 -3.21 -0.09 673.31 1.00
 log_tau_i[46] -1.82 0.85 -1.83 -3.26 -0.56 579.66 1.00
 log_tau_i[47] -1.54 0.90 -1.51 -3.35 -0.30 428.50 1.00
 log_tau_i[48] -1.62 0.89 -1.60 -3.00 -0.15 413.30 1.01
 log_tau_i[49] -1.71 0.95 -1.68 -3.23 -0.13 514.04 1.00
 log_tau_i[50] -1.12 0.92 -0.99 -2.67 0.38 206.76 1.03
 log_tau_i[51] -1.61 0.92 -1.58 -3.07 -0.03 477.41 1.00
 log_tau_i[52] -1.54 0.90 -1.49 -2.96 -0.09 459.83 1.00
 log_tau_i[53] -1.74 0.92 -1.69 -3.13 -0.14 509.51 1.00
 log_tau_i[54] -1.68 0.95 -1.67 -3.07 0.10 477.21 1.00
 log_tau_i[55] -1.87 0.97 -1.88 -3.49 -0.35 514.38 1.00
 log_tau_i[56] -1.87 0.96 -1.84 -3.23 -0.12 574.80 1.00
 log_tau_i[57] -1.77 0.86 -1.72 -3.26 -0.36 646.10 1.00
 log_tau_i[58] -1.78 0.92 -1.77 -3.18 -0.15 617.59 1.00
 log_tau_i[59] -1.67 0.93 -1.61 -3.19 -0.21 510.74 1.00
 log_tau_i[60] -1.50 0.99 -1.44 -3.09 0.08 386.86 1.00
 log_tau_i[61] -1.09 1.06 -1.00 -2.79 0.52 421.27 1.00
 log_tau_zero -1.49 0.26 -1.49 -1.90 -1.05 169.88 1.00
Number of divergences: 37

From mcmc.print_summary() it is evident that there are 37 divergences. We will therefore use Variationally Inferred Parameterization (VIP) to reduce these divergences.
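
As an aside, the divergence count can also be obtained programmatically: NUTS exposes a per-sample diverging flag that can be requested via extra_fields (a small sketch; note that this re-runs the sampler):

# Sketch: request the per-sample divergence flags in addition to "num_steps".
mcmc_div = MCMC(NUTS(german_credit), num_warmup=1000, num_samples=1000)
mcmc_div.run(rng_key, extra_fields=("num_steps", "diverging"))
print(int(mcmc_div.get_extra_fields()["diverging"].sum()))  # divergent transitions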

[ ]:
data = az.from_numpyro(mcmc)
az.plot_trace(data, compact=True);
[Image: trace plots of the posterior samples from the centered model]

3. Reparameterization

We introduce a parameterization parameter \(\lambda \in [0,1]\) for any variable \(z \sim \mathcal{N}(\mu, \sigma)\). Instead of sampling \(z\) directly, we sample an auxiliary variable

\(\tilde{z} \sim \mathcal{N}(\lambda\mu, \sigma^{\lambda})\)

and recover

\(z = \mu + \sigma^{1-\lambda}(\tilde{z} - \lambda\mu)\).

Here \(\lambda = 1\) recovers the centered parameterization and \(\lambda = 0\) the non-centered one.

Thus, for the model \(\theta \sim \mathcal{N}(0,1)\), \(\mu \sim \mathcal{N}(\theta, \sigma_\mu)\), \(\mathbf{y} \sim \mathcal{N}(\mu, \sigma)\), the joint density

\begin{align} p(\theta, \mu, \mathbf{y}) & =\mathcal{N}(\theta \mid 0,1) \times \mathcal{N}\left(\mu \mid \theta, \sigma_\mu\right) \times \mathcal{N}(\mathbf{y} \mid \mu, \sigma) \end{align}

is transformed into

\begin{align} p(\theta, \hat{\mu}, \mathbf{y}) & =\mathcal{N}(\theta \mid 0,1) \times \mathcal{N}\left(\hat{\mu} \mid \lambda \theta, \sigma_\mu^\lambda\right) \times \mathcal{N}\left(\mathbf{y} \mid \theta+\sigma_\mu^{1-\lambda}(\hat{\mu}-\lambda \theta), \sigma\right) \end{align}
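
As a quick sanity check that the transformation leaves the marginal of \(z\) unchanged, the sketch below (reusing dist and rng_key from the imports above, with arbitrary illustrative values for \(\mu\), \(\sigma\), \(\lambda\)) draws \(\tilde{z}\) and maps it back:

# Empirical check of the VIP identity: for any lam in [0, 1],
# z = mu + sigma**(1 - lam) * (z_tilde - lam * mu), with
# z_tilde ~ N(lam * mu, sigma**lam), has the same law as N(mu, sigma).
mu, sigma, lam = 2.0, 3.0, 0.4  # arbitrary illustrative values
z_tilde = dist.Normal(lam * mu, sigma**lam).sample(rng_key, (100_000,))
z = mu + sigma ** (1 - lam) * (z_tilde - lam * mu)
print(z.mean(), z.std())  # should be close to mu = 2.0 and sigma = 3.0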

[ ]:
def german_credit_reparam(beta_centeredness=None):
    def model():
        log_tau_zero = numpyro.sample("log_tau_zero", dist.Normal(0, 10))
        log_tau_i = numpyro.sample(
            "log_tau_i", dist.Normal(log_tau_zero, jnp.ones(num_features))
        )
        # With beta_centeredness=None, LocScaleReparam registers a learnable
        # per-coordinate centeredness parameter named "beta_centered".
        with numpyro.handlers.reparam(
            config={"beta": LocScaleReparam(beta_centeredness)}
        ):
            beta = numpyro.sample(
                "beta", dist.Normal(jnp.zeros(num_features), jnp.exp(log_tau_i))
            )
        numpyro.sample(
            "obs",
            dist.Bernoulli(logits=jnp.einsum("nd,md->mn", all_x, beta[jnp.newaxis, :])),
            obs=y,
        )

    return model

Now, using SVI with an AutoDiagonalNormal guide, we optimize \(\lambda\) jointly with the variational parameters.

[ ]:
model = german_credit_reparam()
guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, numpyro.optim.Adam(3e-4), Trace_ELBO(10))
svi_results = svi.run(rng_key, 10000)
100%|██████████| 10000/10000 [00:16<00:00, 588.87it/s, init loss: 2165.2424, avg. loss [9501-10000]: 576.7846]
[ ]:
reparam_model = german_credit_reparam(
 beta_centeredness=svi_results.params["beta_centered"]
)
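
Each coefficient now has its own learned centeredness in \([0, 1]\) (0 = fully non-centered, 1 = fully centered). A quick look at the fitted values:

# "beta_centered" is the parameter created by LocScaleReparam(None) above.
cent = svi_results.params["beta_centered"]
print(cent.shape)                            # one centeredness per coefficient
print(float(cent.min()), float(cent.max()))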
[ ]:
nuts_kernel = NUTS(reparam_model)
mcmc_reparam = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc_reparam.run(rng_key, extra_fields=("num_steps",))
sample: 100%|██████████| 2000/2000 [00:07<00:00, 285.41it/s, 31 steps of size 1.28e-01. acc. prob=0.89]
[ ]:
mcmc_reparam.print_summary()
 mean std median 5.0% 95.0% n_eff r_hat
 beta_decentered[0] 0.12 0.40 0.06 -0.48 0.80 338.70 1.00
 beta_decentered[1] -0.45 0.15 -0.45 -0.70 -0.21 791.23 1.00
 beta_decentered[2] -0.38 0.17 -0.38 -0.65 -0.09 691.79 1.00
 beta_decentered[3] -0.41 0.13 -0.41 -0.61 -0.19 1022.79 1.00
 beta_decentered[4] -0.01 0.11 -0.01 -0.18 0.20 1176.84 1.00
 beta_decentered[5] 0.19 0.14 0.19 -0.04 0.41 1194.41 1.00
 beta_decentered[6] -0.13 0.14 -0.13 -0.36 0.09 1227.24 1.00
 beta_decentered[7] -0.07 0.12 -0.06 -0.24 0.14 1096.31 1.00
 beta_decentered[8] -0.46 0.34 -0.46 -0.99 0.08 330.30 1.00
 beta_decentered[9] -0.03 0.32 -0.02 -0.57 0.49 310.35 1.00
beta_decentered[10] 0.35 0.39 0.30 -0.26 1.00 426.11 1.00
beta_decentered[11] 1.29 0.31 1.30 0.81 1.82 433.16 1.00
beta_decentered[12] -0.32 0.39 -0.25 -0.96 0.24 521.05 1.00
beta_decentered[13] -0.38 0.40 -0.32 -1.00 0.24 410.05 1.00
beta_decentered[14] 0.08 0.28 0.06 -0.37 0.57 457.72 1.00
beta_decentered[15] 0.14 0.30 0.10 -0.28 0.66 612.31 1.00
beta_decentered[16] 0.85 0.31 0.86 0.41 1.45 432.14 1.00
beta_decentered[17] -0.64 0.28 -0.65 -1.05 -0.14 523.15 1.00
beta_decentered[18] 0.78 0.42 0.78 0.07 1.46 545.52 1.00
beta_decentered[19] 0.15 0.39 0.08 -0.50 0.80 662.60 1.00
beta_decentered[20] 0.04 0.25 0.03 -0.39 0.40 445.85 1.00
beta_decentered[21] 0.24 0.27 0.21 -0.20 0.65 477.68 1.00
beta_decentered[22] -0.03 0.38 -0.01 -0.64 0.60 984.59 1.00
beta_decentered[23] -0.13 0.34 -0.08 -0.72 0.35 702.87 1.00
beta_decentered[24] -0.41 0.39 -0.37 -1.08 0.13 603.13 1.00
beta_decentered[25] 0.19 0.47 0.09 -0.48 0.92 529.68 1.00
beta_decentered[26] 0.00 0.25 0.01 -0.47 0.35 690.54 1.00
beta_decentered[27] -0.46 0.31 -0.46 -0.95 0.04 464.44 1.00
beta_decentered[28] -0.09 0.30 -0.06 -0.56 0.41 464.65 1.00
beta_decentered[29] 0.02 0.30 0.01 -0.47 0.52 747.44 1.00
beta_decentered[30] 0.38 0.44 0.31 -0.30 1.05 717.12 1.00
beta_decentered[31] 0.47 0.36 0.47 -0.09 1.03 564.18 1.00
beta_decentered[32] -0.03 0.26 -0.02 -0.44 0.44 572.03 1.00
beta_decentered[33] -0.17 0.25 -0.15 -0.63 0.19 713.40 1.00
beta_decentered[34] -0.02 0.21 -0.01 -0.39 0.32 620.45 1.01
beta_decentered[35] 0.53 0.31 0.55 -0.03 1.00 681.60 1.00
beta_decentered[36] 0.09 0.24 0.06 -0.27 0.49 610.06 1.00
beta_decentered[37] -0.14 0.31 -0.10 -0.74 0.28 826.87 1.00
beta_decentered[38] -0.12 0.25 -0.11 -0.53 0.30 493.49 1.00
beta_decentered[39] 0.44 0.28 0.44 -0.01 0.89 542.71 1.00
beta_decentered[40] 0.05 0.26 0.03 -0.39 0.45 709.78 1.00
beta_decentered[41] 0.02 0.30 0.01 -0.52 0.46 389.41 1.00
beta_decentered[42] -0.15 0.33 -0.10 -0.74 0.32 607.05 1.00
beta_decentered[43] 0.66 0.47 0.65 -0.11 1.38 539.31 1.01
beta_decentered[44] 0.25 0.27 0.23 -0.19 0.66 686.63 1.00
beta_decentered[45] -0.01 0.22 -0.00 -0.36 0.35 909.70 1.00
beta_decentered[46] 0.02 0.22 0.01 -0.33 0.39 741.33 1.00
beta_decentered[47] -0.25 0.31 -0.21 -0.72 0.25 487.26 1.00
beta_decentered[48] -0.18 0.30 -0.15 -0.67 0.30 400.57 1.00
beta_decentered[49] -0.07 0.30 -0.05 -0.56 0.41 594.38 1.00
beta_decentered[50] 0.46 0.31 0.46 -0.09 0.88 384.33 1.00
beta_decentered[51] -0.23 0.27 -0.19 -0.69 0.16 561.35 1.00
beta_decentered[52] 0.22 0.25 0.21 -0.19 0.60 456.41 1.00
beta_decentered[53] 0.07 0.29 0.04 -0.41 0.53 500.99 1.00
beta_decentered[54] 0.05 0.30 0.02 -0.49 0.51 896.08 1.00
beta_decentered[55] 0.01 0.23 0.01 -0.34 0.43 1033.13 1.00
beta_decentered[56] -0.02 0.19 -0.01 -0.33 0.28 717.87 1.00
beta_decentered[57] 0.01 0.23 0.01 -0.33 0.41 684.61 1.00
beta_decentered[58] -0.09 0.26 -0.08 -0.55 0.30 455.66 1.01
beta_decentered[59] 0.20 0.27 0.18 -0.20 0.63 413.57 1.01
beta_decentered[60] -0.14 0.37 -0.09 -0.76 0.46 411.83 1.00
beta_decentered[61] 0.58 0.57 0.50 -0.23 1.50 495.86 1.00
 log_tau_i[0] -1.56 0.91 -1.53 -2.94 -0.01 480.69 1.00
 log_tau_i[1] -1.07 0.64 -1.08 -1.99 0.06 950.02 1.00
 log_tau_i[2] -1.25 0.74 -1.23 -2.36 -0.06 796.37 1.00
 log_tau_i[3] -1.15 0.65 -1.18 -2.25 -0.14 838.40 1.00
 log_tau_i[4] -2.09 0.90 -2.12 -3.51 -0.49 1120.31 1.00
 log_tau_i[5] -1.70 0.84 -1.69 -3.10 -0.35 1291.74 1.00
 log_tau_i[6] -1.87 0.92 -1.84 -3.67 -0.58 1066.38 1.00
 log_tau_i[7] -2.06 0.89 -2.07 -3.42 -0.42 641.48 1.00
 log_tau_i[8] -1.11 0.91 -1.03 -2.79 0.19 482.69 1.00
 log_tau_i[9] -1.65 0.91 -1.62 -3.05 -0.04 772.73 1.00
 log_tau_i[10] -1.31 0.94 -1.27 -2.77 0.24 578.34 1.00
 log_tau_i[11] -0.04 0.49 -0.08 -0.81 0.77 736.20 1.00
 log_tau_i[12] -1.37 0.98 -1.32 -2.86 0.31 623.14 1.00
 log_tau_i[13] -1.28 0.97 -1.21 -2.78 0.30 653.51 1.00
 log_tau_i[14] -1.70 0.95 -1.71 -3.20 -0.15 831.88 1.00
 log_tau_i[15] -1.67 0.97 -1.65 -3.15 0.04 726.23 1.00
 log_tau_i[16] -0.53 0.65 -0.49 -1.55 0.54 558.56 1.00
 log_tau_i[17] -0.81 0.68 -0.80 -1.91 0.27 655.36 1.00
 log_tau_i[18] -0.71 0.83 -0.60 -1.89 0.75 536.55 1.00
 log_tau_i[19] -1.55 0.98 -1.53 -3.17 -0.02 675.27 1.00
 log_tau_i[20] -1.81 0.90 -1.79 -3.36 -0.41 823.93 1.00
 log_tau_i[21] -1.53 0.93 -1.51 -3.03 -0.04 767.64 1.00
 log_tau_i[22] -1.60 0.99 -1.58 -3.25 -0.01 840.71 1.00
 log_tau_i[23] -1.64 0.97 -1.60 -3.15 0.10 615.62 1.00
 log_tau_i[24] -1.21 0.96 -1.11 -2.61 0.47 674.58 1.00
 log_tau_i[25] -1.54 1.04 -1.49 -3.39 -0.01 509.69 1.00
 log_tau_i[26] -1.77 0.95 -1.74 -3.27 -0.13 915.91 1.00
 log_tau_i[27] -1.14 0.89 -1.07 -2.53 0.36 569.10 1.01
 log_tau_i[28] -1.72 0.93 -1.66 -3.25 -0.20 911.70 1.01
 log_tau_i[29] -1.70 0.91 -1.71 -3.16 -0.14 648.29 1.00
 log_tau_i[30] -1.28 1.02 -1.28 -2.99 0.34 575.90 1.00
 log_tau_i[31] -1.10 0.86 -1.04 -2.39 0.45 730.21 1.00
 log_tau_i[32] -1.79 0.95 -1.82 -3.32 -0.19 667.95 1.00
 log_tau_i[33] -1.64 0.86 -1.62 -3.04 -0.32 948.64 1.00
 log_tau_i[34] -1.87 0.92 -1.88 -3.29 -0.25 909.49 1.00
 log_tau_i[35] -1.00 0.85 -0.95 -2.44 0.25 672.42 1.00
 log_tau_i[36] -1.76 0.92 -1.73 -3.19 -0.22 889.56 1.00
 log_tau_i[37] -1.63 1.00 -1.58 -3.11 0.12 973.29 1.00
 log_tau_i[38] -1.72 0.85 -1.70 -3.13 -0.30 837.73 1.00
 log_tau_i[39] -1.15 0.80 -1.13 -2.32 0.18 627.15 1.00
 log_tau_i[40] -1.76 0.93 -1.70 -3.36 -0.34 686.88 1.00
 log_tau_i[41] -1.75 0.96 -1.74 -3.20 -0.09 612.83 1.00
 log_tau_i[42] -1.62 0.91 -1.62 -3.21 -0.25 596.18 1.00
 log_tau_i[43] -0.79 0.93 -0.74 -2.20 0.83 560.52 1.01
 log_tau_i[44] -1.52 0.94 -1.48 -3.08 0.01 888.27 1.00
 log_tau_i[45] -1.83 0.90 -1.84 -3.27 -0.44 1122.53 1.00
 log_tau_i[46] -1.83 0.89 -1.78 -3.21 -0.31 990.03 1.00
 log_tau_i[47] -1.56 0.93 -1.48 -3.20 -0.21 736.01 1.00
 log_tau_i[48] -1.59 0.94 -1.54 -3.20 -0.01 588.77 1.00
 log_tau_i[49] -1.71 0.93 -1.68 -3.17 -0.17 813.94 1.00
 log_tau_i[50] -1.15 0.83 -1.11 -2.51 0.14 514.68 1.00
 log_tau_i[51] -1.54 0.89 -1.51 -2.97 -0.13 780.67 1.00
 log_tau_i[52] -1.59 0.93 -1.56 -3.21 -0.28 807.47 1.00
 log_tau_i[53] -1.74 0.92 -1.72 -3.10 -0.15 657.89 1.00
 log_tau_i[54] -1.74 0.93 -1.74 -3.09 -0.10 867.29 1.00
 log_tau_i[55] -1.89 0.94 -1.87 -3.34 -0.27 1132.90 1.00
 log_tau_i[56] -1.89 0.92 -1.89 -3.27 -0.27 980.72 1.00
 log_tau_i[57] -1.84 0.91 -1.84 -3.33 -0.38 813.32 1.00
 log_tau_i[58] -1.75 0.93 -1.74 -3.25 -0.34 583.52 1.00
 log_tau_i[59] -1.59 0.96 -1.53 -3.11 -0.06 754.39 1.01
 log_tau_i[60] -1.58 0.98 -1.52 -3.23 -0.08 561.05 1.00
 log_tau_i[61] -0.98 1.03 -0.85 -2.55 0.79 502.63 1.00
 log_tau_zero -1.49 0.26 -1.49 -1.91 -1.10 220.32 1.00
Number of divergences: 1

The number of divergences has dropped significantly, from 37 to 1.
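
Both runs collected "num_steps", so we can also compare the total number of leapfrog steps as a rough measure of sampling cost (exact numbers vary from run to run):

steps_centered = int(mcmc.get_extra_fields()["num_steps"].sum())
steps_vip = int(mcmc_reparam.get_extra_fields()["num_steps"].sum())
print(steps_centered, steps_vip)  # the VIP run typically needs far fewer steps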

[ ]:
data = az.from_numpyro(mcmc_reparam)
az.plot_trace(data, compact=True, figsize=(15, 25));
[Image: trace plots of the posterior samples from the reparameterized model]

4. References:

  1. Maria I. Gorinova, Dave Moore, and Matthew D. Hoffman. Automatic Reparameterisation of Probabilistic Programs. https://arxiv.org/abs/1906.03028

  2. Reference implementation: https://github.com/mgorinova/autoreparam/tree/master