# Source code for pta.sampling.convergence_manager

"""Management of MCMC simulation to achieve desired convergence criteria."""
import logging
import time
from typing import Any, Callable

import _pta_python_binaries as pb
import numpy as np
from pta.sampling.commons import SamplerInterface, SamplingException, SamplingResult

logger = logging.getLogger(__name__)


class ConvergenceManager:
    """Class managing MCMC simulations to achieve desired convergence criteria.

    The manager repeatedly runs the wrapped sampler, doubling the number of steps
    after every iteration, until either all PSRF values reported by the sampler
    are below the requested threshold or the step budget is exhausted.

    Parameters
    ----------
    sampler : Sampler
        Object implementing a specific MCMC sampler.
    num_initial_steps : int, optional
        Initial number of steps for simulations, by default -1 (automatic). Any
        non-positive value selects ``sampler.dimensionality ** 2``.
    samples_based_rounding : bool, optional
        If true, the direction sampling distribution will be adjusted after each
        iteration based on the distribution of samples so far, by default False
    """

    def __init__(
        self,
        sampler: "SamplerInterface",
        num_initial_steps: int = -1,
        samples_based_rounding: bool = False,
    ):
        self._sampler = sampler
        # Non-positive values mean "choose automatically": use the square of the
        # sampler's dimensionality.
        self._num_initial_steps = (
            num_initial_steps if num_initial_steps > 0 else sampler.dimensionality ** 2
        )
        self._round = samples_based_rounding

    def run(
        self,
        settings: "pb.SamplerSettings",
        update_settings_function: Callable[["pb.SamplerSettings", int], None],
        make_sampling_result_function: Callable[[Any], "SamplingResult"],
        initial_points: np.ndarray,
        max_steps: int,
        max_psrf: float,
    ) -> "SamplingResult":
        """Run the sampler using this manager until the given PSRF or maximum
        number of steps is reached.

        Parameters
        ----------
        settings : pb.SamplerSettings
            The initial settings for the sampler.
        update_settings_function : Callable[[pb.SamplerSettings, int], None]
            Function for updating the settings with different numbers of steps.
        make_sampling_result_function: Callable[[Any], SamplingResult]
            Function for packing the result of the native sampler in a
            SamplingResult object.
        initial_points : np.ndarray
            The initial points for the simulation. Its second axis must have one
            entry per chain.
        max_steps : int
            Maximum number of steps to simulate. Must be positive.
        max_psrf : float
            Maximum PSRF to declare convergence.

        Returns
        -------
        SamplingResult
            The result of the last sampling iteration, packed with
            ``make_sampling_result_function``.

        Raises
        ------
        ValueError
            If ``max_steps`` is not positive, since no iteration could be run and
            there would be no result to return.
        SamplingException
            If the native sampler fails.
        """
        assert (
            initial_points.shape[1] == settings.num_chains
        ), "Sampling requires the same number of initial points and chains."
        # With a non-positive budget the loop below would never execute, leaving
        # both the sampler result and the PSRF undefined. Fail early and clearly.
        if max_steps <= 0:
            raise ValueError(f"max_steps must be positive, got {max_steps}.")

        logger.info(
            f"Starting sampler with {settings.num_chains} chains and "
            f"{settings.max_threads} threads."
        )

        # NOTE(review): the first iteration always uses the configured initial
        # number of steps, even if that exceeds max_steps; only later iterations
        # are capped by the remaining budget. Confirm this is intended.
        num_steps = self._num_initial_steps
        total_steps = 0
        iteration = 0
        converged = False
        directions_transform = np.identity(self._sampler.dimensionality)
        sampling_start_time = time.perf_counter()

        # Sample the space increasing the number of steps and optionally improving
        # the direction sampling distribution until one of the stopping conditions
        # is reached.
        while (not converged) and total_steps < max_steps:
            iteration += 1
            iteration_start_time = time.perf_counter()

            try:
                logger.info(
                    f"Running sampling iteration {iteration} with {num_steps} steps..."
                )
                update_settings_function(settings, num_steps)
                result = self._sampler.simulate(
                    settings, initial_points, directions_transform
                )
                total_steps += num_steps
            except Exception as e:
                # Any failure while configuring or running the sampler is reported
                # to callers as a SamplingException, keeping the cause chained.
                logger.error(f"Sampling failed. {e}")
                raise SamplingException("Sampling failed.") from e
            else:
                iteration_time = time.perf_counter() - iteration_start_time
                # Guard against a zero elapsed time on clocks with coarse
                # resolution; the rate is only used for logging.
                step_rate = (
                    num_steps / iteration_time if iteration_time > 0 else float("inf")
                )

                # Decide whether convergence was reached or we need to simulate
                # the chains for longer.
                psrf = self._sampler.compute_psrf(result)
                if np.all(psrf <= max_psrf):
                    converged = True
                else:
                    # Double the number of steps, without exceeding the remaining
                    # budget, and restart from the last point of the chains.
                    num_steps = min(2 * num_steps, max_steps - total_steps)
                    chains = self._sampler.get_chains(result)
                    initial_points = chains[-1, :, :]

                    # If needed update the directions transform as well.
                    if self._round:
                        directions_transform = self._updated_directions_transform(
                            chains, directions_transform
                        )

                logger.info(
                    f"Sampling iteration {iteration} completed in {iteration_time:.2f}"
                    f"s ({step_rate:.1f} steps/s). Max PSRF = {np.max(psrf):.2f}."
                )

        sampling_time = time.perf_counter() - sampling_start_time
        logger.info(
            f"Sampling completed in {sampling_time:.2f}s after a total of {total_steps}"
            f" steps. Convergence criteria were {'' if converged else 'not '}satisfied "
            f"(Max PSRF = {np.max(psrf):.3f}, PSRF threshold = {max_psrf:.3f})."
        )

        return make_sampling_result_function(result)

    def _updated_directions_transform(
        self, chains: np.ndarray, current_transform: np.ndarray
    ) -> np.ndarray:
        """Compute a new directions transform from the samples collected so far.

        The samples are stacked in a single matrix and decomposed with an SVD. The
        transform is built from the right singular vectors, scaled by the singular
        values normalized to the smallest one.

        Parameters
        ----------
        chains : np.ndarray
            Chains returned by the sampler. Presumably indexed as
            (step, dimension, chain) -- TODO confirm against the sampler.
        current_transform : np.ndarray
            The transform currently in use, returned unchanged if the samples do
            not allow computing a valid new one.

        Returns
        -------
        np.ndarray
            The transform to use in the next sampling iteration.
        """
        _, S, Vh = np.linalg.svd(np.hstack(chains).T, full_matrices=False)

        # A usable transform needs one strictly positive, finite singular value
        # per dimension. Otherwise the normalization would divide by zero or the
        # transform would have the wrong shape.
        is_full_rank = S.size == self._sampler.dimensionality
        if not is_full_rank or not np.isfinite(S).all() or np.min(S) <= 0:
            logger.warning(
                "Samples are degenerate, keeping the previous directions transform."
            )
            return current_transform

        return Vh.T @ np.diag(S / np.min(S))