Variationally Inferred Parameterization
Author: Madhav Kanda
The Hamiltonian Monte Carlo (HMC) sampler sometimes struggles to sample from a posterior distribution; Neal's funnel is an illustrative case. In such situations the conventional centered parameterization may prove inadequate, and we switch to a non-centered parameterization. However, there are instances where even the non-centered parameterization does not suffice. Variationally Inferred Parameterization (VIP) addresses this by learning, per latent variable, a centeredness parameter in the range 0 to 1 that interpolates between the two.
The purpose of this tutorial is to implement Variationally Inferred Parameterization, following the paper Automatic Reparameterization of Probabilistic Programs, using LocScaleReparam in NumPyro.
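As a standalone illustration (not part of the tutorial pipeline below), the following sketch shows Neal's funnel in its centered form, where the scale of x depends on y and HMC struggles, and its non-centered form, which samples a standard normal and rescales it deterministically; the function names here are ours:

[ ]:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def funnel_centered():
    # the scale of x depends exponentially on y, producing the funnel geometry
    y = numpyro.sample("y", dist.Normal(0.0, 3.0))
    numpyro.sample("x", dist.Normal(0.0, jnp.exp(y / 2.0)))


def funnel_noncentered():
    # sample a standard normal, then rescale it deterministically
    y = numpyro.sample("y", dist.Normal(0.0, 3.0))
    x_raw = numpyro.sample("x_raw", dist.Normal(0.0, 1.0))
    numpyro.deterministic("x", x_raw * jnp.exp(y / 2.0))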
[ ]:
%pip -qq install numpyro
%pip -qq install ucimlrepo
[ ]:
import arviz as az
import numpy as np
from ucimlrepo import fetch_ucirepo

import jax
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDiagonalNormal
from numpyro.infer.reparam import LocScaleReparam

rng_key = jax.random.PRNGKey(0)
1. Dataset
We will be using the German Credit Dataset for this illustration. The dataset consists of 1000 entries with 20 categorical/symbolic attributes prepared by Prof. Hofmann. Each entry represents a person who takes a credit from a bank, and each person is classified as a good or bad credit risk according to the set of attributes.
[ ]:
def load_german_credit():
    statlog_german_credit_data = fetch_ucirepo(id=144)
    X = statlog_german_credit_data.data.features
    y = statlog_german_credit_data.data.targets
    return X, y
[ ]:
X, y = load_german_credit()
X
| | Attribute1 | Attribute2 | Attribute3 | Attribute4 | Attribute5 | Attribute6 | Attribute7 | Attribute8 | Attribute9 | Attribute10 | Attribute11 | Attribute12 | Attribute13 | Attribute14 | Attribute15 | Attribute16 | Attribute17 | Attribute18 | Attribute19 | Attribute20 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | A11 | 6 | A34 | A43 | 1169 | A65 | A75 | 4 | A93 | A101 | 4 | A121 | 67 | A143 | A152 | 2 | A173 | 1 | A192 | A201 |
| 1 | A12 | 48 | A32 | A43 | 5951 | A61 | A73 | 2 | A92 | A101 | 2 | A121 | 22 | A143 | A152 | 1 | A173 | 1 | A191 | A201 |
| 2 | A14 | 12 | A34 | A46 | 2096 | A61 | A74 | 2 | A93 | A101 | 3 | A121 | 49 | A143 | A152 | 1 | A172 | 2 | A191 | A201 |
| 3 | A11 | 42 | A32 | A42 | 7882 | A61 | A74 | 2 | A93 | A103 | 4 | A122 | 45 | A143 | A153 | 1 | A173 | 2 | A191 | A201 |
| 4 | A11 | 24 | A33 | A40 | 4870 | A61 | A73 | 3 | A93 | A101 | 4 | A124 | 53 | A143 | A153 | 2 | A173 | 2 | A191 | A201 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 995 | A14 | 12 | A32 | A42 | 1736 | A61 | A74 | 3 | A92 | A101 | 4 | A121 | 31 | A143 | A152 | 1 | A172 | 1 | A191 | A201 |
| 996 | A11 | 30 | A32 | A41 | 3857 | A61 | A73 | 4 | A91 | A101 | 4 | A122 | 40 | A143 | A152 | 1 | A174 | 1 | A192 | A201 |
| 997 | A14 | 12 | A32 | A43 | 804 | A61 | A75 | 4 | A93 | A101 | 4 | A123 | 38 | A143 | A152 | 1 | A173 | 1 | A191 | A201 |
| 998 | A11 | 45 | A32 | A43 | 1845 | A61 | A73 | 4 | A93 | A101 | 4 | A124 | 23 | A143 | A153 | 1 | A173 | 1 | A192 | A201 |
| 999 | A12 | 45 | A34 | A41 | 4576 | A62 | A71 | 3 | A93 | A101 | 4 | A123 | 27 | A143 | A152 | 1 | A173 | 1 | A191 | A201 |
1000 rows × 20 columns
Here, X holds the values of the 20 attributes for each person in the dataset, and y is the corresponding target variable.
[ ]:
def data_transform(X, y):
    def categorical_to_int(x):
        # map each category level to an integer code
        d = {u: i for i, u in enumerate(np.unique(x))}
        return np.array([d[i] for i in x])

    categoricals = []
    numericals = []
    numericals.append(np.ones([len(y)]))  # intercept column
    for column in X:
        column = X[column]
        if column.dtype == "O":
            categoricals.append(categorical_to_int(column))
        else:
            # standardize numeric attributes
            numericals.append((column - column.mean()) / column.std())
    numericals = np.array(numericals).T
    status = np.array(y == 1, dtype=np.int32)  # binarize the target (class 1 vs. class 2)
    status = np.squeeze(status)
    return jnp.array(numericals), jnp.array(categoricals), jnp.array(status)
We transform the data into numeric arrays so it can be fed into the NumPyro model:
[ ]:
numericals, categoricals, status = data_transform(X, y)
[ ]:
x_numeric = numericals.astype(jnp.float32)
# one-hot encode each categorical feature
x_categorical = [jnp.eye(c.max() + 1)[c] for c in categoricals]
all_x = jnp.concatenate([x_numeric] + x_categorical, axis=1)
num_features = all_x.shape[1]
y = status[jnp.newaxis, Ellipsis]
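A quick sanity check on the resulting shapes (the 62 columns here come from 8 numeric columns, including the intercept, plus the one-hot encoded categoricals; the exact count depends on the encoding):

[ ]:
print(all_x.shape)  # (1000, 62)
print(y.shape)      # (1, 1000)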
2. Model
We will be using a logistic regression model with a hierarchical prior on the coefficient scales:
\begin{align} \log \tau_0 & \sim \mathcal{N}(0,10) & \log \tau_i & \sim \mathcal{N}\left(\log \tau_0, 1\right) \\ \beta_i & \sim \mathcal{N}\left(0, \tau_i\right) & y & \sim \operatorname{Bernoulli}\left(\sigma\left(\beta X^T\right)\right) \end{align}
[ ]:
def german_credit():
    # hierarchical prior on the log coefficient scales
    log_tau_zero = numpyro.sample("log_tau_zero", dist.Normal(0, 10))
    log_tau_i = numpyro.sample(
        "log_tau_i", dist.Normal(log_tau_zero, jnp.ones(num_features))
    )
    beta = numpyro.sample(
        "beta", dist.Normal(jnp.zeros(num_features), jnp.exp(log_tau_i))
    )
    numpyro.sample(
        "obs",
        dist.Bernoulli(logits=jnp.einsum("nd,md->mn", all_x, beta[jnp.newaxis, :])),
        obs=y,
    )
[ ]:
nuts_kernel = NUTS(german_credit)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc.run(rng_key, extra_fields=("num_steps",))
sample: 100%|██████████| 2000/2000 [00:21<00:00, 94.07it/s, 63 steps of size 6.31e-02. acc. prob=0.87]
[ ]:
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
beta[0] 0.13 0.38 0.05 -0.36 0.74 284.06 1.00
beta[1] -0.34 0.12 -0.34 -0.52 -0.15 621.55 1.00
beta[2] -0.27 0.13 -0.27 -0.45 -0.03 542.13 1.00
beta[3] -0.30 0.10 -0.30 -0.44 -0.11 566.55 1.00
beta[4] -0.00 0.07 -0.00 -0.12 0.11 782.35 1.00
beta[5] 0.12 0.09 0.11 -0.02 0.27 728.28 1.01
beta[6] -0.08 0.08 -0.07 -0.22 0.05 822.89 1.00
beta[7] -0.05 0.07 -0.04 -0.19 0.05 752.66 1.00
beta[8] -0.42 0.32 -0.39 -0.87 0.05 198.00 1.00
beta[9] -0.07 0.26 -0.02 -0.50 0.31 220.27 1.00
beta[10] 0.26 0.31 0.18 -0.15 0.78 404.97 1.00
beta[11] 1.23 0.34 1.25 0.68 1.79 227.34 1.01
beta[12] -0.26 0.34 -0.17 -0.81 0.22 349.10 1.00
beta[13] -0.30 0.34 -0.21 -0.86 0.13 387.72 1.00
beta[14] 0.07 0.20 0.04 -0.26 0.38 240.45 1.03
beta[15] 0.10 0.22 0.05 -0.18 0.50 287.41 1.02
beta[16] 0.76 0.30 0.76 0.22 1.24 364.73 1.03
beta[17] -0.53 0.28 -0.55 -0.94 -0.05 269.95 1.00
beta[18] 0.70 0.42 0.70 -0.02 1.29 367.28 1.00
beta[19] 0.17 0.40 0.06 -0.43 0.77 333.54 1.00
beta[20] 0.03 0.19 0.01 -0.23 0.39 381.57 1.00
beta[21] 0.18 0.22 0.13 -0.14 0.53 335.48 1.00
beta[22] -0.05 0.32 -0.01 -0.56 0.46 439.54 1.00
beta[23] -0.10 0.30 -0.04 -0.63 0.30 508.20 1.00
beta[24] -0.34 0.36 -0.25 -0.94 0.12 283.15 1.00
beta[25] 0.14 0.40 0.04 -0.46 0.71 433.69 1.00
beta[26] -0.01 0.19 -0.00 -0.34 0.28 438.64 1.00
beta[27] -0.36 0.27 -0.33 -0.78 0.04 377.33 1.01
beta[28] -0.07 0.22 -0.03 -0.43 0.26 493.09 1.00
beta[29] 0.01 0.22 0.00 -0.32 0.34 448.21 1.00
beta[30] 0.35 0.43 0.22 -0.18 1.08 314.69 1.00
beta[31] 0.41 0.33 0.40 -0.10 0.90 402.62 1.00
beta[32] -0.03 0.21 -0.01 -0.39 0.30 525.23 1.00
beta[33] -0.12 0.18 -0.09 -0.41 0.16 334.94 1.00
beta[34] -0.02 0.16 -0.01 -0.24 0.26 318.25 1.00
beta[35] 0.42 0.27 0.42 -0.04 0.81 455.99 1.00
beta[36] 0.05 0.17 0.03 -0.18 0.35 506.34 1.00
beta[37] -0.12 0.25 -0.06 -0.57 0.21 470.11 1.00
beta[38] -0.07 0.20 -0.04 -0.39 0.24 410.71 1.00
beta[39] 0.36 0.24 0.35 -0.04 0.71 359.55 1.00
beta[40] 0.05 0.20 0.02 -0.29 0.35 441.70 1.00
beta[41] -0.00 0.21 0.00 -0.34 0.37 513.67 1.00
beta[42] -0.13 0.27 -0.08 -0.59 0.23 402.64 1.00
beta[43] 0.55 0.46 0.49 -0.11 1.28 570.74 1.00
beta[44] 0.19 0.21 0.15 -0.14 0.50 379.76 1.00
beta[45] -0.00 0.16 0.00 -0.25 0.26 352.19 1.00
beta[46] 0.01 0.16 0.01 -0.25 0.25 411.05 1.00
beta[47] -0.16 0.24 -0.11 -0.55 0.18 455.59 1.00
beta[48] -0.12 0.24 -0.07 -0.55 0.21 322.67 1.04
beta[49] -0.04 0.23 -0.02 -0.45 0.30 437.47 1.02
beta[50] 0.38 0.28 0.37 -0.03 0.82 266.19 1.04
beta[51] -0.14 0.22 -0.09 -0.52 0.16 406.31 1.00
beta[52] 0.19 0.23 0.14 -0.14 0.55 338.97 1.00
beta[53] 0.04 0.22 0.02 -0.23 0.43 438.03 1.00
beta[54] 0.05 0.24 0.02 -0.32 0.41 522.43 1.00
beta[55] 0.02 0.14 0.01 -0.22 0.23 562.00 1.00
beta[56] -0.01 0.13 -0.01 -0.24 0.21 638.20 1.00
beta[57] 0.01 0.17 0.00 -0.25 0.34 590.99 1.00
beta[58] -0.07 0.18 -0.04 -0.34 0.23 481.37 1.00
beta[59] 0.13 0.19 0.09 -0.12 0.47 507.56 1.00
beta[60] -0.14 0.33 -0.06 -0.64 0.37 303.00 1.00
beta[61] 0.48 0.56 0.32 -0.18 1.41 438.86 1.00
log_tau_i[0] -1.51 0.95 -1.52 -3.03 0.11 290.78 1.00
log_tau_i[1] -1.07 0.67 -1.11 -2.12 0.03 641.04 1.00
log_tau_i[2] -1.24 0.76 -1.26 -2.47 0.03 666.31 1.00
log_tau_i[3] -1.16 0.65 -1.19 -2.20 -0.10 821.60 1.00
log_tau_i[4] -2.11 0.88 -2.13 -3.50 -0.61 806.15 1.00
log_tau_i[5] -1.71 0.86 -1.68 -3.28 -0.44 697.00 1.00
log_tau_i[6] -1.88 0.84 -1.91 -3.30 -0.58 623.56 1.00
log_tau_i[7] -1.99 0.90 -1.98 -3.51 -0.65 710.21 1.00
log_tau_i[8] -1.00 0.86 -0.96 -2.23 0.52 445.30 1.00
log_tau_i[9] -1.69 0.93 -1.63 -3.17 -0.14 326.33 1.00
log_tau_i[10] -1.41 0.95 -1.35 -2.93 0.19 441.60 1.01
log_tau_i[11] -0.11 0.57 -0.12 -0.97 0.80 539.60 1.00
log_tau_i[12] -1.36 0.96 -1.31 -3.16 0.01 336.11 1.00
log_tau_i[13] -1.30 0.95 -1.26 -2.85 0.28 335.04 1.00
log_tau_i[14] -1.72 0.89 -1.70 -3.05 -0.25 584.38 1.00
log_tau_i[15] -1.65 0.92 -1.63 -3.07 -0.10 345.77 1.03
log_tau_i[16] -0.51 0.65 -0.49 -1.42 0.59 676.64 1.00
log_tau_i[17] -0.84 0.76 -0.76 -2.09 0.34 303.14 1.00
log_tau_i[18] -0.69 0.82 -0.59 -2.03 0.61 359.35 1.00
log_tau_i[19] -1.45 0.99 -1.42 -2.97 0.25 397.18 1.00
log_tau_i[20] -1.75 0.94 -1.73 -3.39 -0.40 617.54 1.00
log_tau_i[21] -1.51 0.88 -1.49 -3.16 -0.27 488.52 1.00
log_tau_i[22] -1.56 0.93 -1.56 -3.06 -0.10 348.20 1.00
log_tau_i[23] -1.58 0.94 -1.57 -3.05 0.02 278.69 1.00
log_tau_i[24] -1.26 1.00 -1.12 -2.91 0.29 205.38 1.00
log_tau_i[25] -1.53 0.95 -1.56 -3.11 0.02 351.09 1.00
log_tau_i[26] -1.73 0.91 -1.74 -3.17 -0.22 492.18 1.00
log_tau_i[27] -1.15 0.89 -1.08 -2.66 0.17 485.34 1.00
log_tau_i[28] -1.69 0.92 -1.65 -3.15 -0.19 425.75 1.00
log_tau_i[29] -1.71 0.99 -1.71 -3.19 0.01 374.58 1.00
log_tau_i[30] -1.24 0.99 -1.20 -2.74 0.50 327.47 1.00
log_tau_i[31] -1.02 0.89 -0.91 -2.40 0.51 587.85 1.00
log_tau_i[32] -1.71 0.94 -1.70 -3.22 -0.11 511.74 1.00
log_tau_i[33] -1.69 0.90 -1.68 -3.13 -0.28 538.65 1.00
log_tau_i[34] -1.82 0.92 -1.81 -3.35 -0.35 423.01 1.00
log_tau_i[35] -1.06 0.82 -1.00 -2.30 0.34 470.50 1.00
log_tau_i[36] -1.79 0.87 -1.76 -3.15 -0.34 527.47 1.00
log_tau_i[37] -1.58 0.95 -1.54 -3.11 0.04 485.52 1.00
log_tau_i[38] -1.71 0.87 -1.65 -3.18 -0.34 482.67 1.00
log_tau_i[39] -1.12 0.85 -1.01 -2.44 0.33 337.59 1.00
log_tau_i[40] -1.76 0.96 -1.73 -3.58 -0.36 533.15 1.00
log_tau_i[41] -1.74 0.94 -1.70 -3.26 -0.22 500.91 1.00
log_tau_i[42] -1.57 0.95 -1.54 -3.04 0.01 499.44 1.00
log_tau_i[43] -0.87 0.93 -0.74 -2.28 0.58 445.98 1.00
log_tau_i[44] -1.52 0.89 -1.45 -2.95 -0.13 442.63 1.00
log_tau_i[45] -1.84 0.94 -1.79 -3.21 -0.09 673.31 1.00
log_tau_i[46] -1.82 0.85 -1.83 -3.26 -0.56 579.66 1.00
log_tau_i[47] -1.54 0.90 -1.51 -3.35 -0.30 428.50 1.00
log_tau_i[48] -1.62 0.89 -1.60 -3.00 -0.15 413.30 1.01
log_tau_i[49] -1.71 0.95 -1.68 -3.23 -0.13 514.04 1.00
log_tau_i[50] -1.12 0.92 -0.99 -2.67 0.38 206.76 1.03
log_tau_i[51] -1.61 0.92 -1.58 -3.07 -0.03 477.41 1.00
log_tau_i[52] -1.54 0.90 -1.49 -2.96 -0.09 459.83 1.00
log_tau_i[53] -1.74 0.92 -1.69 -3.13 -0.14 509.51 1.00
log_tau_i[54] -1.68 0.95 -1.67 -3.07 0.10 477.21 1.00
log_tau_i[55] -1.87 0.97 -1.88 -3.49 -0.35 514.38 1.00
log_tau_i[56] -1.87 0.96 -1.84 -3.23 -0.12 574.80 1.00
log_tau_i[57] -1.77 0.86 -1.72 -3.26 -0.36 646.10 1.00
log_tau_i[58] -1.78 0.92 -1.77 -3.18 -0.15 617.59 1.00
log_tau_i[59] -1.67 0.93 -1.61 -3.19 -0.21 510.74 1.00
log_tau_i[60] -1.50 0.99 -1.44 -3.09 0.08 386.86 1.00
log_tau_i[61] -1.09 1.06 -1.00 -2.79 0.52 421.27 1.00
log_tau_zero -1.49 0.26 -1.49 -1.90 -1.05 169.88 1.00

Number of divergences: 37
From mcmc.print_summary() it is evident that there are 37 divergences. We will therefore use Variationally Inferred Parameterization (VIP) to reduce these divergences.
[ ]:
data = az.from_numpyro(mcmc)
az.plot_trace(data, compact=True);
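The divergence count can also be read programmatically from the ArviZ InferenceData; a minimal check, assuming the diverging statistic lands in the sample_stats group (ArviZ's default for NumPyro):

[ ]:
# count the divergent transitions recorded during sampling
print(int(data.sample_stats.diverging.sum()))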
3. Reparameterization
We introduce a parameterization parameter \(\lambda \in [0,1]\) for any variable \(z \sim \mathcal{N}(\mu, \sigma)\), and replace \(z\) with an auxiliary variable \(\hat{z}\):
\begin{align} \hat{z} & \sim \mathcal{N}\left(\lambda\mu, \sigma^{\lambda}\right) \\ z & = \mu + \sigma^{1-\lambda}\left(\hat{z} - \lambda\mu\right) \end{align}
Setting \(\lambda = 1\) recovers the centered parameterization, and \(\lambda = 0\) the fully non-centered one.
For example, for a simple hierarchical model the original joint density \begin{align} p(\theta, \mu, \mathbf{y}) & =\mathcal{N}(\theta \mid 0,1) \times \mathcal{N}\left(\mu \mid \theta, \sigma_\mu\right) \times \mathcal{N}(\mathbf{y} \mid \mu, \sigma) \end{align} transforms, applying the above to \(\mu\) with auxiliary variable \(\hat{\mu}\), into:
\begin{align} p(\theta, \hat{\mu}, \mathbf{y}) & =\mathcal{N}(\theta \mid 0,1) \times \mathcal{N}\left(\hat{\mu} \mid \lambda \theta, \sigma_\mu^\lambda\right) \times \mathcal{N}\left(\mathbf{y} \mid \theta+\sigma_\mu^{1-\lambda}(\hat{\mu}-\lambda \theta), \sigma\right) \end{align}
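As a quick numerical check of the transformation (a standalone sketch with hypothetical values for \(\lambda\), \(\mu\), \(\sigma\)): sampling the auxiliary variable and applying the affine map yields a draw distributed as \(\mathcal{N}(\mu, \sigma)\), since the resulting mean is \(\mu + \sigma^{1-\lambda}(\lambda\mu - \lambda\mu) = \mu\) and the resulting scale is \(\sigma^{1-\lambda}\sigma^{\lambda} = \sigma\):

[ ]:
lam, mu, sigma = 0.3, 2.0, 4.0  # hypothetical values
z_hat = dist.Normal(lam * mu, sigma**lam).sample(jax.random.PRNGKey(1))
z = mu + sigma ** (1 - lam) * (z_hat - lam * mu)  # z is distributed N(mu, sigma)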
[ ]:
def german_credit_reparam(beta_centeredness=None):
    def model():
        log_tau_zero = numpyro.sample("log_tau_zero", dist.Normal(0, 10))
        log_tau_i = numpyro.sample(
            "log_tau_i", dist.Normal(log_tau_zero, jnp.ones(num_features))
        )
        # reparameterize beta; with beta_centeredness=None the centeredness is learned
        with numpyro.handlers.reparam(
            config={"beta": LocScaleReparam(beta_centeredness)}
        ):
            beta = numpyro.sample(
                "beta", dist.Normal(jnp.zeros(num_features), jnp.exp(log_tau_i))
            )
        numpyro.sample(
            "obs",
            dist.Bernoulli(
                logits=jnp.einsum("nd,md->mn", all_x, beta[jnp.newaxis, :])
            ),
            obs=y,
        )

    return model
Now, using SVI, we optimize \(\lambda\). Passing beta_centeredness=None to LocScaleReparam registers the centeredness as a learnable parameter named beta_centered.
[ ]:
model = german_credit_reparam()
guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, numpyro.optim.Adam(3e-4), Trace_ELBO(10))
svi_results = svi.run(rng_key, 10000)
100%|██████████| 10000/10000 [00:16<00:00, 588.87it/s, init loss: 2165.2424, avg. loss [9501-10000]: 576.7846]
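The learned per-coefficient centeredness values can be inspected directly; beta_centered follows LocScaleReparam's {site}_centered naming convention, and each value should lie in [0, 1]:

[ ]:
beta_lambda = svi_results.params["beta_centered"]
print(beta_lambda.shape)                     # one centeredness value per coefficient
print(beta_lambda.min(), beta_lambda.max())  # each lies in [0, 1]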
[ ]:
reparam_model = german_credit_reparam(
    beta_centeredness=svi_results.params["beta_centered"]
)
[ ]:
nuts_kernel = NUTS(reparam_model)
mcmc_reparam = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc_reparam.run(rng_key, extra_fields=("num_steps",))
sample: 100%|██████████| 2000/2000 [00:07<00:00, 285.41it/s, 31 steps of size 1.28e-01. acc. prob=0.89]
[ ]:
mcmc_reparam.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
beta_decentered[0] 0.12 0.40 0.06 -0.48 0.80 338.70 1.00
beta_decentered[1] -0.45 0.15 -0.45 -0.70 -0.21 791.23 1.00
beta_decentered[2] -0.38 0.17 -0.38 -0.65 -0.09 691.79 1.00
beta_decentered[3] -0.41 0.13 -0.41 -0.61 -0.19 1022.79 1.00
beta_decentered[4] -0.01 0.11 -0.01 -0.18 0.20 1176.84 1.00
beta_decentered[5] 0.19 0.14 0.19 -0.04 0.41 1194.41 1.00
beta_decentered[6] -0.13 0.14 -0.13 -0.36 0.09 1227.24 1.00
beta_decentered[7] -0.07 0.12 -0.06 -0.24 0.14 1096.31 1.00
beta_decentered[8] -0.46 0.34 -0.46 -0.99 0.08 330.30 1.00
beta_decentered[9] -0.03 0.32 -0.02 -0.57 0.49 310.35 1.00
beta_decentered[10] 0.35 0.39 0.30 -0.26 1.00 426.11 1.00
beta_decentered[11] 1.29 0.31 1.30 0.81 1.82 433.16 1.00
beta_decentered[12] -0.32 0.39 -0.25 -0.96 0.24 521.05 1.00
beta_decentered[13] -0.38 0.40 -0.32 -1.00 0.24 410.05 1.00
beta_decentered[14] 0.08 0.28 0.06 -0.37 0.57 457.72 1.00
beta_decentered[15] 0.14 0.30 0.10 -0.28 0.66 612.31 1.00
beta_decentered[16] 0.85 0.31 0.86 0.41 1.45 432.14 1.00
beta_decentered[17] -0.64 0.28 -0.65 -1.05 -0.14 523.15 1.00
beta_decentered[18] 0.78 0.42 0.78 0.07 1.46 545.52 1.00
beta_decentered[19] 0.15 0.39 0.08 -0.50 0.80 662.60 1.00
beta_decentered[20] 0.04 0.25 0.03 -0.39 0.40 445.85 1.00
beta_decentered[21] 0.24 0.27 0.21 -0.20 0.65 477.68 1.00
beta_decentered[22] -0.03 0.38 -0.01 -0.64 0.60 984.59 1.00
beta_decentered[23] -0.13 0.34 -0.08 -0.72 0.35 702.87 1.00
beta_decentered[24] -0.41 0.39 -0.37 -1.08 0.13 603.13 1.00
beta_decentered[25] 0.19 0.47 0.09 -0.48 0.92 529.68 1.00
beta_decentered[26] 0.00 0.25 0.01 -0.47 0.35 690.54 1.00
beta_decentered[27] -0.46 0.31 -0.46 -0.95 0.04 464.44 1.00
beta_decentered[28] -0.09 0.30 -0.06 -0.56 0.41 464.65 1.00
beta_decentered[29] 0.02 0.30 0.01 -0.47 0.52 747.44 1.00
beta_decentered[30] 0.38 0.44 0.31 -0.30 1.05 717.12 1.00
beta_decentered[31] 0.47 0.36 0.47 -0.09 1.03 564.18 1.00
beta_decentered[32] -0.03 0.26 -0.02 -0.44 0.44 572.03 1.00
beta_decentered[33] -0.17 0.25 -0.15 -0.63 0.19 713.40 1.00
beta_decentered[34] -0.02 0.21 -0.01 -0.39 0.32 620.45 1.01
beta_decentered[35] 0.53 0.31 0.55 -0.03 1.00 681.60 1.00
beta_decentered[36] 0.09 0.24 0.06 -0.27 0.49 610.06 1.00
beta_decentered[37] -0.14 0.31 -0.10 -0.74 0.28 826.87 1.00
beta_decentered[38] -0.12 0.25 -0.11 -0.53 0.30 493.49 1.00
beta_decentered[39] 0.44 0.28 0.44 -0.01 0.89 542.71 1.00
beta_decentered[40] 0.05 0.26 0.03 -0.39 0.45 709.78 1.00
beta_decentered[41] 0.02 0.30 0.01 -0.52 0.46 389.41 1.00
beta_decentered[42] -0.15 0.33 -0.10 -0.74 0.32 607.05 1.00
beta_decentered[43] 0.66 0.47 0.65 -0.11 1.38 539.31 1.01
beta_decentered[44] 0.25 0.27 0.23 -0.19 0.66 686.63 1.00
beta_decentered[45] -0.01 0.22 -0.00 -0.36 0.35 909.70 1.00
beta_decentered[46] 0.02 0.22 0.01 -0.33 0.39 741.33 1.00
beta_decentered[47] -0.25 0.31 -0.21 -0.72 0.25 487.26 1.00
beta_decentered[48] -0.18 0.30 -0.15 -0.67 0.30 400.57 1.00
beta_decentered[49] -0.07 0.30 -0.05 -0.56 0.41 594.38 1.00
beta_decentered[50] 0.46 0.31 0.46 -0.09 0.88 384.33 1.00
beta_decentered[51] -0.23 0.27 -0.19 -0.69 0.16 561.35 1.00
beta_decentered[52] 0.22 0.25 0.21 -0.19 0.60 456.41 1.00
beta_decentered[53] 0.07 0.29 0.04 -0.41 0.53 500.99 1.00
beta_decentered[54] 0.05 0.30 0.02 -0.49 0.51 896.08 1.00
beta_decentered[55] 0.01 0.23 0.01 -0.34 0.43 1033.13 1.00
beta_decentered[56] -0.02 0.19 -0.01 -0.33 0.28 717.87 1.00
beta_decentered[57] 0.01 0.23 0.01 -0.33 0.41 684.61 1.00
beta_decentered[58] -0.09 0.26 -0.08 -0.55 0.30 455.66 1.01
beta_decentered[59] 0.20 0.27 0.18 -0.20 0.63 413.57 1.01
beta_decentered[60] -0.14 0.37 -0.09 -0.76 0.46 411.83 1.00
beta_decentered[61] 0.58 0.57 0.50 -0.23 1.50 495.86 1.00
log_tau_i[0] -1.56 0.91 -1.53 -2.94 -0.01 480.69 1.00
log_tau_i[1] -1.07 0.64 -1.08 -1.99 0.06 950.02 1.00
log_tau_i[2] -1.25 0.74 -1.23 -2.36 -0.06 796.37 1.00
log_tau_i[3] -1.15 0.65 -1.18 -2.25 -0.14 838.40 1.00
log_tau_i[4] -2.09 0.90 -2.12 -3.51 -0.49 1120.31 1.00
log_tau_i[5] -1.70 0.84 -1.69 -3.10 -0.35 1291.74 1.00
log_tau_i[6] -1.87 0.92 -1.84 -3.67 -0.58 1066.38 1.00
log_tau_i[7] -2.06 0.89 -2.07 -3.42 -0.42 641.48 1.00
log_tau_i[8] -1.11 0.91 -1.03 -2.79 0.19 482.69 1.00
log_tau_i[9] -1.65 0.91 -1.62 -3.05 -0.04 772.73 1.00
log_tau_i[10] -1.31 0.94 -1.27 -2.77 0.24 578.34 1.00
log_tau_i[11] -0.04 0.49 -0.08 -0.81 0.77 736.20 1.00
log_tau_i[12] -1.37 0.98 -1.32 -2.86 0.31 623.14 1.00
log_tau_i[13] -1.28 0.97 -1.21 -2.78 0.30 653.51 1.00
log_tau_i[14] -1.70 0.95 -1.71 -3.20 -0.15 831.88 1.00
log_tau_i[15] -1.67 0.97 -1.65 -3.15 0.04 726.23 1.00
log_tau_i[16] -0.53 0.65 -0.49 -1.55 0.54 558.56 1.00
log_tau_i[17] -0.81 0.68 -0.80 -1.91 0.27 655.36 1.00
log_tau_i[18] -0.71 0.83 -0.60 -1.89 0.75 536.55 1.00
log_tau_i[19] -1.55 0.98 -1.53 -3.17 -0.02 675.27 1.00
log_tau_i[20] -1.81 0.90 -1.79 -3.36 -0.41 823.93 1.00
log_tau_i[21] -1.53 0.93 -1.51 -3.03 -0.04 767.64 1.00
log_tau_i[22] -1.60 0.99 -1.58 -3.25 -0.01 840.71 1.00
log_tau_i[23] -1.64 0.97 -1.60 -3.15 0.10 615.62 1.00
log_tau_i[24] -1.21 0.96 -1.11 -2.61 0.47 674.58 1.00
log_tau_i[25] -1.54 1.04 -1.49 -3.39 -0.01 509.69 1.00
log_tau_i[26] -1.77 0.95 -1.74 -3.27 -0.13 915.91 1.00
log_tau_i[27] -1.14 0.89 -1.07 -2.53 0.36 569.10 1.01
log_tau_i[28] -1.72 0.93 -1.66 -3.25 -0.20 911.70 1.01
log_tau_i[29] -1.70 0.91 -1.71 -3.16 -0.14 648.29 1.00
log_tau_i[30] -1.28 1.02 -1.28 -2.99 0.34 575.90 1.00
log_tau_i[31] -1.10 0.86 -1.04 -2.39 0.45 730.21 1.00
log_tau_i[32] -1.79 0.95 -1.82 -3.32 -0.19 667.95 1.00
log_tau_i[33] -1.64 0.86 -1.62 -3.04 -0.32 948.64 1.00
log_tau_i[34] -1.87 0.92 -1.88 -3.29 -0.25 909.49 1.00
log_tau_i[35] -1.00 0.85 -0.95 -2.44 0.25 672.42 1.00
log_tau_i[36] -1.76 0.92 -1.73 -3.19 -0.22 889.56 1.00
log_tau_i[37] -1.63 1.00 -1.58 -3.11 0.12 973.29 1.00
log_tau_i[38] -1.72 0.85 -1.70 -3.13 -0.30 837.73 1.00
log_tau_i[39] -1.15 0.80 -1.13 -2.32 0.18 627.15 1.00
log_tau_i[40] -1.76 0.93 -1.70 -3.36 -0.34 686.88 1.00
log_tau_i[41] -1.75 0.96 -1.74 -3.20 -0.09 612.83 1.00
log_tau_i[42] -1.62 0.91 -1.62 -3.21 -0.25 596.18 1.00
log_tau_i[43] -0.79 0.93 -0.74 -2.20 0.83 560.52 1.01
log_tau_i[44] -1.52 0.94 -1.48 -3.08 0.01 888.27 1.00
log_tau_i[45] -1.83 0.90 -1.84 -3.27 -0.44 1122.53 1.00
log_tau_i[46] -1.83 0.89 -1.78 -3.21 -0.31 990.03 1.00
log_tau_i[47] -1.56 0.93 -1.48 -3.20 -0.21 736.01 1.00
log_tau_i[48] -1.59 0.94 -1.54 -3.20 -0.01 588.77 1.00
log_tau_i[49] -1.71 0.93 -1.68 -3.17 -0.17 813.94 1.00
log_tau_i[50] -1.15 0.83 -1.11 -2.51 0.14 514.68 1.00
log_tau_i[51] -1.54 0.89 -1.51 -2.97 -0.13 780.67 1.00
log_tau_i[52] -1.59 0.93 -1.56 -3.21 -0.28 807.47 1.00
log_tau_i[53] -1.74 0.92 -1.72 -3.10 -0.15 657.89 1.00
log_tau_i[54] -1.74 0.93 -1.74 -3.09 -0.10 867.29 1.00
log_tau_i[55] -1.89 0.94 -1.87 -3.34 -0.27 1132.90 1.00
log_tau_i[56] -1.89 0.92 -1.89 -3.27 -0.27 980.72 1.00
log_tau_i[57] -1.84 0.91 -1.84 -3.33 -0.38 813.32 1.00
log_tau_i[58] -1.75 0.93 -1.74 -3.25 -0.34 583.52 1.00
log_tau_i[59] -1.59 0.96 -1.53 -3.11 -0.06 754.39 1.01
log_tau_i[60] -1.58 0.98 -1.52 -3.23 -0.08 561.05 1.00
log_tau_i[61] -0.98 1.03 -0.85 -2.55 0.79 502.63 1.00
log_tau_zero -1.49 0.26 -1.49 -1.91 -1.10 220.32 1.00

Number of divergences: 1
The number of divergences has significantly reduced, from 37 to 1.
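Since both runs collected the num_steps extra field, we can also compare the sampler effort; a minimal comparison, consistent with the progress bars above (63 vs. 31 leapfrog steps per iteration at the end of sampling):

[ ]:
# average number of leapfrog steps per iteration for each run
steps = mcmc.get_extra_fields()["num_steps"]
steps_vip = mcmc_reparam.get_extra_fields()["num_steps"]
print("mean leapfrog steps per iteration (original):", jnp.mean(steps))
print("mean leapfrog steps per iteration (VIP):", jnp.mean(steps_vip))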
[ ]:
data = az.from_numpyro(mcmc_reparam)
az.plot_trace(data, compact=True, figsize=(15, 25));