Part II: Trans-dimensional GMM

In this tutorial, we demonstrate how to use BayesBay to retrieve a Gaussian Mixture Model (GMM) while treating the number of components in the mixture as a free parameter, to be inferred from the data. If you haven’t already, we recommend reviewing Part I: Known Number of Mixture Components; this provides a detailed explanation of the code required to define and solve the simpler problem of deriving a GMM with a predetermined number of mixture components.

Import libraries and define constants

import bayesbay as bb
from math import sqrt, pi
import numpy as np
import matplotlib.pyplot as plt
np.random.seed(30)
MEANS = [140, 162, 177]  # Means of the Gaussians
STDS = [12, 5, 6]  # Standard deviations of the Gaussians
WEIGHTS = [0.4, 0.3, 0.3]  # Weights of each Gaussian in the mixture
N_SAMPLES = 10_000 # Number of samples to generate
def gaussian(x, mu, sigma):
    return 1 / (sigma * sqrt(2*pi)) * np.exp(-0.5 * ((x - mu) / sigma)**2)

Observed data

assert np.isclose(sum(WEIGHTS), 1), "The weights must sum to 1."

def generate_random_samples(means, stds, weights):
    """Generate random samples for each component of the Gaussian mixture"""
    samples = []
    for mean, std, weight in zip(means, stds, weights):
        n_samples = int(N_SAMPLES * weight) # number of samples for each component
        samples.append(np.random.normal(mean, std, n_samples))
    samples = np.concatenate(samples)
    np.random.shuffle(samples)
    return samples

samples = generate_random_samples(MEANS, STDS, WEIGHTS)
data_obs, bins = np.histogram(samples, bins=50, density=True)
data_x = (bins[:-1] + bins[1:]) / 2 # Height associated with each data point
# Define the range of x-values (i.e., heights) over which to evaluate the PDF
x_min, x_max = min(MEANS) - 3 * max(STDS), max(MEANS) + 3 * max(STDS)
xs = np.linspace(x_min, x_max, 1000)

pdf_true = np.zeros_like(xs)
for mean, std, weight in zip(MEANS, STDS, WEIGHTS):
    pdf_true += weight * gaussian(xs, mean, std)
fig, ax = plt.subplots()
ax.set_title('Histogram of measured heights')
ax.hist(samples, bins=50, density=True, ec='w')
ax.plot(xs, pdf_true, label='True mixture PDF', color='r')
ax.plot(data_x, data_obs, 'ko', label='Observed data', markerfacecolor='None')

ax.set_xlabel('Height')
ax.set_ylabel('Density')
ax.grid()
ax.legend()
plt.show()
[Figure: histogram of measured heights, with the true mixture PDF and the observed data]

Setting up the Bayesian sampling

Prior probability

mean = bb.prior.UniformPrior(name="mean", vmin=100, vmax=200, perturb_std=2)
std = bb.prior.UniformPrior(name="std", vmin=1, vmax=20, perturb_std=0.4)
weight = bb.prior.UniformPrior(name="weight", vmin=0, vmax=1, perturb_std=0.02)

Parameter space and parameterization

As explained in Part I: Known Number of Mixture Components, the above free parameters should be used in BayesBay to create what we call a ParameterSpace. ParameterSpace can be conceptualized as an \(n\)-dimensional vector space, or similarly, as a specialized container. It not only groups a number of free parameters but also determines their dimensionality.

Important

The dimensionality of ParameterSpace can be treated as unknown by setting n_dimensions=None at the initialization of the class instance, as shown in the following block. This enables the definition of trans-dimensional inference problems.

In this tutorial’s case, ParameterSpace will contain three distinct free parameters (namely, \(\boldsymbol{\omega}\), \(\boldsymbol{\mu}\), and \(\boldsymbol{\sigma}\)). Each parameter is a vector whose dimensionality will align with that of ParameterSpace throughout the inversion process. Whenever a dimension is added to or removed from ParameterSpace (via a BirthPerturbation/DeathPerturbation), the dimensionality of the free parameters linked to it is correspondingly adjusted.

param_space = bb.parameterization.ParameterSpace(
    name='my_param_space', 
    n_dimensions=None, # Trans-dimensional setting
    n_dimensions_min=1, # Minimum number of dimensions (i.e., Gaussians in the mixture)
    n_dimensions_max=7, # Maximum number of dimensions (i.e., Gaussians in the mixture)
    parameters=[mean, std, weight], 
)
parameterization = bb.parameterization.Parameterization(param_space)
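
To make the idea of a variable-dimension parameter space more concrete, the following NumPy-only sketch mimics what happens to the three parameter vectors when a dimension is added or removed. It does not use BayesBay: the functions `birth` and `death` and the dictionary layout are hypothetical stand-ins, and drawing the new component from the uniform priors is an assumption about the proposal, not a description of how BirthPerturbation is implemented.

```python
import numpy as np

rng = np.random.default_rng(0)

# Uniform prior bounds copied from the priors defined earlier in this tutorial
BOUNDS = {"mean": (100, 200), "std": (1, 20), "weight": (0, 1)}
N_MIN, N_MAX = 1, 7  # same limits as n_dimensions_min / n_dimensions_max

def birth(state):
    """Return a copy of `state` with one extra component drawn from the priors."""
    if len(state["mean"]) >= N_MAX:
        return state  # upper bound reached: no component can be added
    return {k: np.append(v, rng.uniform(*BOUNDS[k])) for k, v in state.items()}

def death(state):
    """Return a copy of `state` with one randomly chosen component removed."""
    n = len(state["mean"])
    if n <= N_MIN:
        return state  # lower bound reached: no component can be removed
    i = rng.integers(n)
    return {k: np.delete(v, i) for k, v in state.items()}

# Start from a three-component mixture and apply one move of each kind
state = {k: rng.uniform(lo, hi, size=3) for k, (lo, hi) in BOUNDS.items()}
after_birth = birth(state)
after_death = death(state)
print({k: v.size for k, v in after_birth.items()})
print({k: v.size for k, v in after_death.items()})
```

The point of the sketch is that all three vectors always share the same length: a birth or death acts on the mean, standard deviation, and weight of one component at once.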

Forward problem

def _forward(means, stds, weights):
    weights /= np.sum(weights)
    data_pred = np.zeros_like(data_x)
    for i in range(len(means)):
        data_pred += weights[i] * gaussian(data_x, means[i], stds[i])
    return data_pred

def fwd_function(state: bb.State) -> np.ndarray:
    means = state['my_param_space']['mean']
    stds = state['my_param_space']['std']
    weights = state['my_param_space']['weight']
    return _forward(means, stds, weights)

Observed data: the Target

target = bb.Target('my_data', 
                   data_obs, 
                   std_min=0, 
                   std_max=0.01, 
                   std_perturb_std=0.001,
                   noise_is_correlated=False)

Log likelihood

log_likelihood = bb.LogLikelihood(targets=target, fwd_functions=fwd_function)
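
Because the Target above treats the noise standard deviation as a free parameter (bounded by std_min and std_max) with uncorrelated noise, it can help to see how the noise level enters a Gaussian log likelihood. The sketch below uses the textbook expression for independent Gaussian errors; the function name `gaussian_log_likelihood` is hypothetical, and we do not claim that BayesBay evaluates the likelihood in exactly this form internally.

```python
import numpy as np

def gaussian_log_likelihood(data_obs, data_pred, noise_std):
    """Log likelihood for independent Gaussian errors with a common std."""
    residuals = np.asarray(data_obs) - np.asarray(data_pred)
    n = residuals.size
    return (
        -0.5 * np.sum((residuals / noise_std) ** 2)  # data misfit term
        - n * np.log(noise_std)                      # penalizes large noise levels
        - 0.5 * n * np.log(2 * np.pi)                # normalization constant
    )

# Synthetic example: 50 data points whose residuals all have magnitude 0.002
data_pred = np.full(50, 0.02)
signs = np.where(np.arange(50) % 2 == 0, 1.0, -1.0)
data_obs = data_pred + 0.002 * signs

for noise_std in (0.0005, 0.002, 0.008):
    ll = gaussian_log_likelihood(data_obs, data_pred, noise_std)
    print(f"noise_std={noise_std}: log likelihood = {ll:.1f}")
```

In this toy case the log likelihood is largest when noise_std matches the size of the residuals: a smaller value inflates the misfit term, while a larger value is penalized by the logarithmic term. This trade-off is what allows the noise level to be inferred together with the mixture parameters.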

Run the Bayesian sampling

Given the trans-dimensional nature of the inverse problem we are addressing, we will employ 20 Markov chains running in parallel for 450,000 iterations, discarding the initial 150,000 as part of the burn-in phase. Additionally, we set the starting temperature of each chain to 5 using our SimulatedAnnealing class (for details on simulated annealing, see Kirkpatrick et al. 1983). The temperature \(T\) enters the acceptance probability of the perturbed model \(\mathbf{m}'\) through the tempered likelihood ratio \(\left[\frac{p\left(\mathbf{d}_{obs} \mid \mathbf{m'}\right)}{p\left(\mathbf{d}_{obs} \mid \mathbf{m}\right)}\right]^{\frac{1}{T}}\): temperatures higher than one downweight the impact of this ratio, giving each chain more freedom to explore the model space. During the burn-in phase, the temperature of each chain is gradually reduced according to an exponential decay, reaching a value of 1 by the end of this phase.
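
The effect of the temperature can be illustrated with a few lines of plain Python. The sketch below is not BayesBay code: `temperature` and `log_acceptance` are hypothetical helpers, and the specific decay law (chosen so that the temperature starts at 5 and reaches exactly 1 at the end of the burn-in) is an assumption consistent with the description above, not necessarily the schedule used by SimulatedAnnealing.

```python
import numpy as np

def temperature(iteration, burnin_iterations, t_start=5.0):
    """Exponential decay from t_start to 1 over the burn-in, then constant at 1."""
    if iteration >= burnin_iterations:
        return 1.0
    # Equivalent to t_start * exp(-k * iteration) with k = ln(t_start) / burnin_iterations
    return t_start ** (1.0 - iteration / burnin_iterations)

def log_acceptance(log_like_new, log_like_old, temp, log_prior_proposal_ratio=0.0):
    """Log acceptance probability with the likelihood ratio raised to 1 / temp."""
    return min(0.0, log_prior_proposal_ratio + (log_like_new - log_like_old) / temp)

burnin = 150_000
for it in (0, 75_000, 150_000):
    print(f"iteration {it}: T = {temperature(it, burnin):.3f}")

# A proposal that lowers the log likelihood by 10 units
for temp in (5.0, 1.0):
    p = np.exp(log_acceptance(-10.0, 0.0, temp))
    print(f"T = {temp}: acceptance probability = {p:.2e}")
```

A proposal that worsens the log likelihood by 10 units is accepted with probability exp(-2) at T = 5 but only exp(-10) at T = 1, which is why hot chains move more freely early in the burn-in.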

See also

Through the module bayesbay.samplers you can implement arbitrary sampling criteria or use our built-in class ParallelTempering (e.g., Sambridge 2014).

inversion = bb.BayesianInversion(
    log_likelihood=log_likelihood, 
    parameterization=parameterization, 
    n_chains=20
)
inversion.run(
    sampler=bb.samplers.SimulatedAnnealing(temperature_start=5),
    n_iterations=450_000, 
    burnin_iterations=150_000, 
    save_every=100, 
    verbose=False, 
)
for chain in inversion.chains:
    chain.print_statistics()
Chain ID: 0
TEMPERATURE: 1
EXPLORED MODELS: 450000
ACCEPTANCE RATE: 65197/450000 (14.49 %)
PARTIAL ACCEPTANCE RATES:
	BirthPerturbation(my_param_space): 2825/75051 (3.76%)
	DeathPerturbation(my_param_space): 2822/73838 (3.82%)
	NoisePerturbation(my_data): 15740/74627 (21.09%)
	ParamPerturbation(['my_param_space.mean', 'my_param_space.std', 'my_param_space.weight']): 43810/226484 (19.34%)
Chain ID: 1
TEMPERATURE: 1
EXPLORED MODELS: 450000
ACCEPTANCE RATE: 66472/450000 (14.77 %)
PARTIAL ACCEPTANCE RATES:
	BirthPerturbation(my_param_space): 2908/75077 (3.87%)
	DeathPerturbation(my_param_space): 2906/74040 (3.92%)
	NoisePerturbation(my_data): 15748/75349 (20.90%)
	ParamPerturbation(['my_param_space.mean', 'my_param_space.std', 'my_param_space.weight']): 44910/225534 (19.91%)
Chain ID: 2
TEMPERATURE: 1
EXPLORED MODELS: 450000
ACCEPTANCE RATE: 59628/450000 (13.25 %)
PARTIAL ACCEPTANCE RATES:
	BirthPerturbation(my_param_space): 2613/75539 (3.46%)
	DeathPerturbation(my_param_space): 2611/74536 (3.50%)
	NoisePerturbation(my_data): 15146/74962 (20.20%)
	ParamPerturbation(['my_param_space.mean', 'my_param_space.std', 'my_param_space.weight']): 39258/224963 (17.45%)
Chain ID: 3
TEMPERATURE: 1
EXPLORED MODELS: 450000
ACCEPTANCE RATE: 60787/450000 (13.51 %)
PARTIAL ACCEPTANCE RATES:
	BirthPerturbation(my_param_space): 2439/74394 (3.28%)
	DeathPerturbation(my_param_space): 2438/74830 (3.26%)
	NoisePerturbation(my_data): 14820/75220 (19.70%)
	ParamPerturbation(['my_param_space.mean', 'my_param_space.std', 'my_param_space.weight']): 41090/225556 (18.22%)
Chain ID: 4
TEMPERATURE: 1
EXPLORED MODELS: 450000
ACCEPTANCE RATE: 59737/450000 (13.27 %)
PARTIAL ACCEPTANCE RATES:
	BirthPerturbation(my_param_space): 2559/75227 (3.40%)
	DeathPerturbation(my_param_space): 2557/74293 (3.44%)
	NoisePerturbation(my_data): 15341/75074 (20.43%)
	ParamPerturbation(['my_param_space.mean', 'my_param_space.std', 'my_param_space.weight']): 39280/225406 (17.43%)
Chain ID: 5
TEMPERATURE: 1
EXPLORED MODELS: 450000
ACCEPTANCE RATE: 58972/450000 (13.10 %)
PARTIAL ACCEPTANCE RATES:
	BirthPerturbation(my_param_space): 2478/74763 (3.31%)
	DeathPerturbation(my_param_space): 2475/74234 (3.33%)
	NoisePerturbation(my_data): 15324/75107 (20.40%)
	ParamPerturbation(['my_param_space.mean', 'my_param_space.std', 'my_param_space.weight']): 38695/225896 (17.13%)
Chain ID: 6
TEMPERATURE: 1
EXPLORED MODELS: 450000
ACCEPTANCE RATE: 66545/450000 (14.79 %)
PARTIAL ACCEPTANCE RATES:
	BirthPerturbation(my_param_space): 2615/75104 (3.48%)
	DeathPerturbation(my_param_space): 2614/74131 (3.53%)
	NoisePerturbation(my_data): 15354/75098 (20.45%)
	ParamPerturbation(['my_param_space.mean', 'my_param_space.std', 'my_param_space.weight']): 45962/225667 (20.37%)
Chain ID: 7
TEMPERATURE: 1
EXPLORED MODELS: 450000
ACCEPTANCE RATE: 64728/450000 (14.38 %)
PARTIAL ACCEPTANCE RATES:
	BirthPerturbation(my_param_space): 2779/74939 (3.71%)
	DeathPerturbation(my_param_space): 2778/73787 (3.76%)
	NoisePerturbation(my_data): 15585/75336 (20.69%)
	ParamPerturbation(['my_param_space.mean', 'my_param_space.std', 'my_param_space.weight']): 43586/225938 (19.29%)
Chain ID: 8
TEMPERATURE: 1
EXPLORED MODELS: 450000
ACCEPTANCE RATE: 58536/450000 (13.01 %)
PARTIAL ACCEPTANCE RATES:
	BirthPerturbation(my_param_space): 2444/74991 (3.26%)
	DeathPerturbation(my_param_space): 2442/74370 (3.28%)
	NoisePerturbation(my_data): 14831/75300 (19.70%)
	ParamPerturbation(['my_param_space.mean', 'my_param_space.std', 'my_param_space.weight']): 38819/225339 (17.23%)
Chain ID: 9
TEMPERATURE: 1
EXPLORED MODELS: 450000
ACCEPTANCE RATE: 66436/450000 (14.76 %)
PARTIAL ACCEPTANCE RATES:
	BirthPerturbation(my_param_space): 2838/75008 (3.78%)
	DeathPerturbation(my_param_space): 2837/73446 (3.86%)
	NoisePerturbation(my_data): 15757/75467 (20.88%)
	ParamPerturbation(['my_param_space.mean', 'my_param_space.std', 'my_param_space.weight']): 45004/226079 (19.91%)
Chain ID: 10
TEMPERATURE: 1
EXPLORED MODELS: 450000
ACCEPTANCE RATE: 67444/450000 (14.99 %)
PARTIAL ACCEPTANCE RATES:
	BirthPerturbation(my_param_space): 2963/75357 (3.93%)
	DeathPerturbation(my_param_space): 2961/73788 (4.01%)
	NoisePerturbation(my_data): 16026/75147 (21.33%)
	ParamPerturbation(['my_param_space.mean', 'my_param_space.std', 'my_param_space.weight']): 45494/225708 (20.16%)
Chain ID: 11
TEMPERATURE: 1
EXPLORED MODELS: 450000
ACCEPTANCE RATE: 63388/450000 (14.09 %)
PARTIAL ACCEPTANCE RATES:
	BirthPerturbation(my_param_space): 2743/75399 (3.64%)
	DeathPerturbation(my_param_space): 2741/74206 (3.69%)
	NoisePerturbation(my_data): 15733/75140 (20.94%)
	ParamPerturbation(['my_param_space.mean', 'my_param_space.std', 'my_param_space.weight']): 42171/225255 (18.72%)
Chain ID: 12
TEMPERATURE: 1
EXPLORED MODELS: 450000
ACCEPTANCE RATE: 64964/450000 (14.44 %)
PARTIAL ACCEPTANCE RATES:
	BirthPerturbation(my_param_space): 2789/75015 (3.72%)
	DeathPerturbation(my_param_space): 2787/73897 (3.77%)
	NoisePerturbation(my_data): 15793/75206 (21.00%)
	ParamPerturbation(['my_param_space.mean', 'my_param_space.std', 'my_param_space.weight']): 43595/225882 (19.30%)
Chain ID: 13
TEMPERATURE: 1
EXPLORED MODELS: 450000
ACCEPTANCE RATE: 57073/450000 (12.68 %)
PARTIAL ACCEPTANCE RATES:
	BirthPerturbation(my_param_space): 2414/74991 (3.22%)
	DeathPerturbation(my_param_space): 2412/74684 (3.23%)
	NoisePerturbation(my_data): 14846/75391 (19.69%)
	ParamPerturbation(['my_param_space.mean', 'my_param_space.std', 'my_param_space.weight']): 37401/224934 (16.63%)
Chain ID: 14
TEMPERATURE: 1
EXPLORED MODELS: 450000
ACCEPTANCE RATE: 64725/450000 (14.38 %)
PARTIAL ACCEPTANCE RATES:
	BirthPerturbation(my_param_space): 2709/75246 (3.60%)
	DeathPerturbation(my_param_space): 2707/74144 (3.65%)
	NoisePerturbation(my_data): 15454/75230 (20.54%)
	ParamPerturbation(['my_param_space.mean', 'my_param_space.std', 'my_param_space.weight']): 43855/225380 (19.46%)
Chain ID: 15
TEMPERATURE: 1
EXPLORED MODELS: 450000
ACCEPTANCE RATE: 65151/450000 (14.48 %)
PARTIAL ACCEPTANCE RATES:
	BirthPerturbation(my_param_space): 2751/74235 (3.71%)
	DeathPerturbation(my_param_space): 2750/74155 (3.71%)
	NoisePerturbation(my_data): 15144/75199 (20.14%)
	ParamPerturbation(['my_param_space.mean', 'my_param_space.std', 'my_param_space.weight']): 44506/226411 (19.66%)
Chain ID: 16
TEMPERATURE: 1
EXPLORED MODELS: 450000
ACCEPTANCE RATE: 62163/450000 (13.81 %)
PARTIAL ACCEPTANCE RATES:
	BirthPerturbation(my_param_space): 2665/74744 (3.57%)
	DeathPerturbation(my_param_space): 2664/74141 (3.59%)
	NoisePerturbation(my_data): 15172/75503 (20.09%)
	ParamPerturbation(['my_param_space.mean', 'my_param_space.std', 'my_param_space.weight']): 41662/225612 (18.47%)
Chain ID: 17
TEMPERATURE: 1
EXPLORED MODELS: 450000
ACCEPTANCE RATE: 67334/450000 (14.96 %)
PARTIAL ACCEPTANCE RATES:
	BirthPerturbation(my_param_space): 2791/75149 (3.71%)
	DeathPerturbation(my_param_space): 2790/73966 (3.77%)
	NoisePerturbation(my_data): 15962/75506 (21.14%)
	ParamPerturbation(['my_param_space.mean', 'my_param_space.std', 'my_param_space.weight']): 45791/225379 (20.32%)
Chain ID: 18
TEMPERATURE: 1
EXPLORED MODELS: 450000
ACCEPTANCE RATE: 64308/450000 (14.29 %)
PARTIAL ACCEPTANCE RATES:
	BirthPerturbation(my_param_space): 2697/75415 (3.58%)
	DeathPerturbation(my_param_space): 2695/74081 (3.64%)
	NoisePerturbation(my_data): 15194/75000 (20.26%)
	ParamPerturbation(['my_param_space.mean', 'my_param_space.std', 'my_param_space.weight']): 43722/225504 (19.39%)
Chain ID: 19
TEMPERATURE: 1
EXPLORED MODELS: 450000
ACCEPTANCE RATE: 58912/450000 (13.09 %)
PARTIAL ACCEPTANCE RATES:
	BirthPerturbation(my_param_space): 2530/75217 (3.36%)
	DeathPerturbation(my_param_space): 2529/73960 (3.42%)
	NoisePerturbation(my_data): 15135/75167 (20.14%)
	ParamPerturbation(['my_param_space.mean', 'my_param_space.std', 'my_param_space.weight']): 38718/225656 (17.16%)

Retrieve the results and plot

In the following blocks, we first plot a histogram of the ParameterSpace dimensionality of the sampled models, to verify that the majority of them correspond to three Gaussians in the mixture. We then select all sampled models with three dimensions and plot the retrieved PDF and model parameters, as in Part I: Known Number of Mixture Components.

results = inversion.get_results()
n_gaussians = results['my_param_space.n_dimensions']
fig, ax = plt.subplots()
ax.hist(n_gaussians, bins=np.arange(0.5, 8.5), ec='w')  # one bin per allowed dimensionality (1 to 7)
ax.set_xlabel('No. Dimensions (Gaussians in the mixture)')
ax.set_ylabel('Sampled models')
plt.show()
[Figure: histogram of the number of dimensions (Gaussians in the mixture) of the sampled models]
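
The same information can be summarized numerically as the fraction of sampled models at each dimensionality. The sketch below uses a small synthetic array in place of results['my_param_space.n_dimensions'], so the counts are made up for illustration and are not the output of the inversion above.

```python
import numpy as np

# Synthetic stand-in for the sampled dimensionalities (illustrative counts only)
n_gaussians = np.repeat([2, 3, 4, 5], [50, 700, 200, 50])

# Fraction of sampled models for each number of mixture components
values, counts = np.unique(n_gaussians, return_counts=True)
probabilities = counts / counts.sum()
for k, p in zip(values, probabilities):
    print(f"{k} Gaussians: {p:.2f}")

# Most frequently sampled dimensionality and the indexes of the matching models
most_probable = values[np.argmax(probabilities)]
idx = np.flatnonzero(n_gaussians == most_probable)
print(f"Most frequently sampled: {most_probable} ({idx.size} models)")
```

Using np.flatnonzero gives the same indexes as a list comprehension over the sampled dimensionalities, and the relative frequencies approximate the posterior probability of each number of components.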
idx = [i for i, n_comp in enumerate(n_gaussians) if n_comp==3]

def sort_mixture(means, stds, weights):
    indexes = [np.argsort(row) for row in means]
    for i, idx in enumerate(indexes):
        means[i] = means[i][idx]
        stds[i] = stds[i][idx]
        weights[i] = weights[i][idx]
    return means, stds, weights

means, stds, weights = sort_mixture(np.array([results['my_param_space.mean'][i] for i in idx]), 
                                    np.array([results['my_param_space.std'][i] for i in idx]), 
                                    np.array([results['my_param_space.weight'][i] for i in idx]))
# Estimate true data noise
datasets = [generate_random_samples(MEANS, STDS, WEIGHTS) for _ in range(10000)]
histograms = np.array([np.histogram(dataset, bins=50, density=True)[0] for dataset in datasets])
true_std = np.median(np.std(histograms, axis=0))
pdf_pred = _forward(np.median(means, axis=0), 
                    np.median(stds, axis=0), 
                    np.median(weights, axis=0)
                    )

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
ax1.hist(samples, bins=50, density=True, ec='w', fc='gray', alpha=0.8)
ax1.plot(xs, pdf_true, label='True mixture PDF', color='r')
ax1.plot(data_x, pdf_pred, label='Inferred', color='b')
ax1.set_xlabel('Height [cm]')
ax1.set_ylabel('Density')
ax1.legend()

ax2.axvline(x=true_std, color='r', lw=3, alpha=0.3, label='"True" (i.e., estimated) noise')
pdf, bins, _ = ax2.hist(results['my_data.std'], density=True, bins=4, ec='w', zorder=100, label='Posterior')
ax2.fill_between([target.std_min, target.std_max], 1 / (target.std_max - target.std_min), alpha=0.2, label='Prior')
ax2.set_xlabel('Noise standard deviation')
ax2.ticklabel_format(style='sci', axis='x', scilimits=(0,0))
ax2.legend(framealpha=0.9)
plt.show()
[Figure: left, histogram of the samples with the true and inferred mixture PDFs; right, prior and posterior of the noise standard deviation]
import arviz as az

for key, inferred_value, true_value in zip(
    ['mean', 'std', 'weight'], 
    [means, stds, weights],
    [MEANS, STDS, WEIGHTS]
                     ):
    fig, axes = plt.subplots(3, 3, figsize=(10, 8))
    _ = az.plot_pair(
        {f'{key}_1': inferred_value[:,0], 
         f'{key}_2': inferred_value[:,1], 
         f'{key}_3': inferred_value[:,2]},
        marginals=True,
        reference_values={f'{key}_1': true_value[0], 
                          f'{key}_2': true_value[1], 
                          f'{key}_3': true_value[2]},
        reference_values_kwargs={'color': 'yellow',
                                 'ms': 10},
        kind='kde',
        kde_kwargs={
            'hdi_probs': [0.3, 0.6, 0.9],  # Plot 30%, 60% and 90% HDI contours
            'contourf_kwargs': {'cmap': 'Blues'},
            },
        ax=axes,
        textsize=10
        )
    fig.suptitle(key.upper())
    plt.tight_layout()
    plt.show()
[Figures: pair plots of the posterior means, standard deviations, and weights of the three Gaussians]

References

[1] Kirkpatrick et al. (1983), Optimization by simulated annealing. Science

[2] Sambridge (2014), A parallel tempering algorithm for probabilistic sampling and multimodal optimization. Geophysical Journal International