bayesbay.BaseBayesianInversion
- class bayesbay.BaseBayesianInversion(walkers_starting_states, perturbation_funcs, perturbation_weights=None, log_like_ratio_func=None, log_like_func=None, n_chains=10, save_dpred=True)
A low-level class for Bayesian sampling based on Markov chain Monte Carlo (McMC).
This class provides the basic structure for setting up and running McMC sampling, given user-provided definitions of the perturbations, the likelihood functions and the initial states of the walkers. At each iteration of the inference process, the current model \(\bf m\) is perturbed to produce \(\bf m'\), and the new model is accepted with probability
\[\begin{split}\alpha({\bf m' \mid m}) = \mbox{min} \Bigg[1, \underbrace{\frac{p\left({\bf m'}\right)}{p\left({\bf m}\right)}}_{\text{Prior ratio}} \underbrace{\frac{p\left({\bf d}_{obs} \mid {\bf m'}\right)}{p\left({\bf d}_{obs} \mid {\bf m}\right)}}_{\text{Likelihood ratio}} \underbrace{\frac{q\left({\bf m} \mid {\bf m'}\right)}{q\left({\bf m'} \mid {\bf m}\right)}}_{\text{Proposal ratio}} \underbrace{\lvert \mathbf{J} \rvert}_{\begin{array}{c} \text{Jacobian} \\ \text{determinant} \end{array}} \Bigg],\end{split}\]where \({\bf d}_{obs}\) denotes the observed data and \(\mathbf{J}\) the Jacobian of the transformation.
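The acceptance rule above can be illustrated with a single Metropolis-Hastings iteration written in log space. This is a minimal, self-contained sketch and not BayesBay's internal implementation. The helper name mcmc_step is hypothetical. It assumes that each perturbation function returns \(\bf m'\) together with the log of the prior ratio times the proposal ratio times \(\lvert \mathbf{J} \rvert\), and that the likelihood ratio function takes the current and proposed models in that order, as described in the parameters below. Passing perturbation_weights=None selects every perturbation function with equal probability.

```python
import math
import random
from typing import Any, Callable, Optional, Sequence, Tuple

# A perturbation maps m to (m', log of prior ratio x proposal ratio x |J|).
Perturbation = Callable[[Any], Tuple[Any, float]]


def mcmc_step(
    model: Any,
    perturbation_funcs: Sequence[Perturbation],
    log_like_ratio_func: Callable[[Any, Any], float],
    perturbation_weights: Optional[Sequence[float]] = None,
    rng: Optional[random.Random] = None,
) -> Tuple[Any, bool]:
    """Hypothetical sketch of one Metropolis-Hastings iteration; returns (next model, accepted?)."""
    rng = rng if rng is not None else random.Random()
    # Choose a perturbation; weights=None gives every function equal probability.
    perturb = rng.choices(perturbation_funcs, weights=perturbation_weights, k=1)[0]
    new_model, log_prior_proposal_jacobian = perturb(model)
    # A zero prior probability (log = -inf) means alpha = 0: reject without
    # evaluating the likelihood of an invalid model.
    if log_prior_proposal_jacobian == -math.inf:
        return model, False
    log_alpha = min(
        0.0, log_prior_proposal_jacobian + log_like_ratio_func(model, new_model)
    )
    # Accept with probability exp(log_alpha); u lies in (0, 1], so log(u) is finite.
    u = 1.0 - rng.random()
    if math.log(u) <= log_alpha:
        return new_model, True
    return model, False
```

Working with \(\log \alpha\) rather than \(\alpha\) avoids overflow and underflow when likelihood ratios are extreme. The sketch deliberately ignores anything a sampler might add on top of this rule, such as tempering.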
- Parameters:
walkers_starting_states (List[Any]) – a list of starting states, one for each chain. The states can be of any type, so long as they are consistent with the arguments accepted by the perturbation functions and probability functions. The length of this list must equal the number of chains, i.e. n_chains.
perturbation_funcs (List[Callable[[Any], Tuple[Any, Number]]]) – a list of perturbation functions. Each perturbation function should take in a model \(\mathbf{m}\) (any type is allowed, as long as it is consistent with walkers_starting_states and the functions below) and perturb it to produce the new model \(\bf m'\). Each perturbation function should return \(\bf m'\) along with \(\log( \frac{p({\bf m'})}{p({\bf m})} \frac{q\left({\bf m} \mid {\bf m'}\right)}{q\left({\bf m'} \mid {\bf m}\right)} \lvert \mathbf{J} \rvert)\), which is used in the calculation of the acceptance probability.
perturbation_weights (List[Number], optional) – a list of weights corresponding to each element of perturbation_funcs. If this is set to None (the default value), then each perturbation function has equal probability of being selected at each iteration.
log_like_ratio_func (Union[LogLikelihood, Callable[[Any, Any], Number]], optional) – the log likelihood ratio function \(\log(\frac{p(\mathbf{d}_{obs} \mid \mathbf{m'})} {p(\mathbf{d}_{obs} \mid \mathbf{m})})\). It takes the current and proposed models, \(\mathbf{m}\) and \(\mathbf{m'}\), whose type should be consistent with the other arguments of this class, and returns a scalar corresponding to the log likelihood ratio. This is used in the calculation of the acceptance probability. If None, log_like_func is used instead. Defaults to None.
log_like_func (Callable[[Any], Number], optional) – the log likelihood function \(\log(p(\mathbf{d}_{obs} \mid \mathbf{m}))\). It takes in a model \(\mathbf{m}\) (any type is allowed, as long as it is consistent with the other arguments of this class) and returns the log of the likelihood. This function is only used when log_like_ratio_func is None. Defaults to None.
n_chains (int, optional) – the number of chains in the McMC sampling. Defaults to 10.
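To make the expected shapes of perturbation_funcs, log_like_func and log_like_ratio_func concrete, here is a minimal sketch for a hypothetical toy problem. It assumes a single scalar parameter with a uniform prior, one noisy observation and an identity forward model. All names and constants (perturb_gaussian, log_like, log_like_ratio, LOWER, UPPER, D_OBS, SIGMA, STEP) are illustrative and are not part of BayesBay.

```python
import math
import random
from typing import Tuple

# Hypothetical toy problem: scalar m, uniform prior on [LOWER, UPPER],
# one observation D_OBS with Gaussian noise SIGMA, identity forward model.
LOWER, UPPER = -10.0, 10.0
D_OBS, SIGMA = 2.5, 0.5
STEP = 0.3  # standard deviation of the random-walk proposal


def perturb_gaussian(m: float) -> Tuple[float, float]:
    """Return (m', log of prior ratio x proposal ratio x |J|)."""
    m_new = m + random.gauss(0.0, STEP)
    # Symmetric Gaussian proposal: q(m | m') = q(m' | m), so the proposal ratio is 1.
    # No change of dimension or variables here, so |J| = 1.
    # Uniform prior: the ratio is 1 inside the bounds and 0 (log = -inf) outside.
    if not LOWER <= m_new <= UPPER:
        return m_new, -math.inf
    return m_new, 0.0


def log_like(m: float) -> float:
    """Gaussian log likelihood log p(d_obs | m), up to an additive constant."""
    residual = D_OBS - m
    return -0.5 * (residual / SIGMA) ** 2


def log_like_ratio(m: float, m_new: float) -> float:
    """log p(d_obs | m') - log p(d_obs | m), taking (current, proposed) in that order."""
    return log_like(m_new) - log_like(m)
```

Under these assumptions, the functions would be supplied to the constructor as perturbation_funcs=[perturb_gaussian] together with either log_like_func=log_like or log_like_ratio_func=log_like_ratio, and walkers_starting_states would be a list of n_chains floats inside the prior bounds. Dropping the additive normalisation constant is harmless because only likelihood ratios enter \(\alpha\).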
Reference Details
- chains
The MarkovChain instances of the current Bayesian inversion.
- get_results(keys=None, concatenate_chains=True)
Get the saved states from the current inversion.
- Parameters:
keys (Union[str, List[str]]) – key(s) to retrieve from the saved states. This is ignored when the states are not of type State or dict.
concatenate_chains (bool, optional) – whether to aggregate samples from all the Markov chains or to keep them separate, by default True
- Returns:
a dictionary mapping the name of each stored attribute to its values, or a list of saved states (if the base-level API is used and the states are not of type State or dict)
- Return type:
Union[Dict[str, list], list]
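The effect of keys and concatenate_chains can be pictured with a small sketch that merges dict-like saved states from several chains. This is a hypothetical illustration, not BayesBay's code: collect_results is an invented name, plain dicts stand in for State objects, and the layout returned when concatenate_chains=False (one list per chain under each key) is an assumption about what "keep them separate" means.

```python
from typing import Dict, List, Optional, Sequence, Union


def collect_results(
    saved_states_per_chain: Sequence[Sequence[dict]],
    keys: Optional[Union[str, List[str]]] = None,
    concatenate_chains: bool = True,
) -> Dict[str, list]:
    """Hypothetical sketch: gather per-key values from the saved states of several chains."""
    if isinstance(keys, str):
        keys = [keys]
    results: Dict[str, list] = {}
    for chain_states in saved_states_per_chain:
        # Collect this chain's values for each requested key, in saving order.
        per_chain: Dict[str, list] = {}
        for state in chain_states:
            for key, value in state.items():
                if keys is None or key in keys:
                    per_chain.setdefault(key, []).append(value)
        for key, values in per_chain.items():
            if concatenate_chains:
                # Pool the samples from all chains into one flat list per key.
                results.setdefault(key, []).extend(values)
            else:
                # Assumed layout: keep one list of samples per chain.
                results.setdefault(key, []).append(values)
    return results
```

Pooling chains is convenient for posterior histograms. Keeping them separate makes it easier to compare chains when checking convergence.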
- static get_results_from_chains(chains, keys=None, concatenate_chains=True)
Get the saved states from a given list of Markov chains.
- Parameters:
chains (Union[BaseMarkovChain, List[BaseMarkovChain]]) – Markov chain(s) from which the results are extracted
keys (Union[str, List[str]]) – key(s) to retrieve from the saved states. This is ignored when the states are not of type State or dict.
concatenate_chains (bool, optional) – whether to aggregate samples from all the Markov chains or to keep them separate, by default True
- Returns:
a dictionary mapping the name of each stored attribute to its values, or a list of saved states (if the base-level API is used and the states are not of type State or dict)
- Return type:
Union[Dict[str, list], list]
- run(sampler=None, n_iterations=1000, burnin_iterations=0, save_every=100, verbose=True, print_every=100, parallel_config=None)
Run the inversion.
- Parameters:
sampler (bayesbay.samplers.Sampler, optional) – a sampler instance describing how chains interact or modify their properties during sampling. This could be a sampler from the module bayesbay.samplers, such as bayesbay.samplers.VanillaSampler (default), bayesbay.samplers.ParallelTempering or bayesbay.samplers.SimulatedAnnealing, or a customised sampler instance.
n_iterations (int, optional) – total number of iterations to run, by default 1000
burnin_iterations (int, optional) – the iteration number from which we start to save samples, by default 0
save_every (int, optional) – the frequency with which we save the samples, by default 100, i.e. a state is saved every 100 iterations after the burn-in phase
verbose (bool, optional) – whether to print the progress during sampling or not, by default True
print_every (int, optional) – the frequency with which we print the progress and information during the sampling, by default 100 iterations
parallel_config (dict, optional) – keyword arguments passed to joblib.Parallel. Ignored when len(self.chains) is 1
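To estimate how many samples a run will keep, the interaction of n_iterations, burnin_iterations and save_every can be sketched as below. The exact counting convention used by BayesBay is not stated here. This hypothetical helper, saved_iterations, assumes 1-based iteration numbers and saves a state whenever the number of post-burn-in iterations is a multiple of save_every.

```python
from typing import List


def saved_iterations(
    n_iterations: int = 1000, burnin_iterations: int = 0, save_every: int = 100
) -> List[int]:
    """Hypothetical sketch: 1-based iterations at which a state would be saved."""
    return [
        k
        for k in range(1, n_iterations + 1)
        # Only iterations after the burn-in phase are eligible, spaced by save_every.
        if k > burnin_iterations and (k - burnin_iterations) % save_every == 0
    ]
```

Under this assumed convention, the defaults keep 10 states per chain, so the default n_chains=10 would give 100 pooled samples when concatenate_chains=True. A call such as run(n_iterations=1000, burnin_iterations=500, save_every=100) would keep 5 states per chain. For parallel runs, parallel_config accepts joblib.Parallel keyword arguments such as {"n_jobs": 4}.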
- set_perturbation_funcs(perturbation_funcs, perturbation_weights=None)