"""Thermodynamics-based sampling of free energies, concentrations and fluxes."""
import logging
import math
import multiprocessing
import os
import pickle
import random
import tempfile
from typing import List, Optional, Tuple, Union
import numpy as np
import pandas as pd
import scipy as sp
from scipy.stats.distributions import chi2
import _pta_python_binaries as pb
import cobra
import cvxpy as cp
from pta.sampling.convergence_manager import ConvergenceManager
from ..constants import (
default_max_psrf,
default_max_threads,
default_min_eigenvalue_tds_basis,
default_num_samples,
tfs_default_feasibility_cache_size,
tfs_default_min_rel_region_length,
)
from ..flux_space import FluxSpace
from ..pmo import PmoProblem, PmoProblemPool
from ..thermodynamic_space import ThermodynamicSpace, ThermodynamicSpaceBasis
from ..utils import apply_transform, covariance_square_root
from .commons import (
SamplerInterface,
SamplingResult,
apply_to_chains,
fill_common_sampling_settings,
sample_from_chains,
split_R,
)
from .uniform import sample_flux_space_uniform
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Per-process model used by the uniform flux sampling workers. It is only
# declared here; `_init_us_worker` assigns it and `_sample_orthant` reads it.
_us_model: cobra.Model
class TFSModel(SamplerInterface):
    """Object holding the information necessary to run TFS.

    Parameters
    ----------
    network : Union[cobra.Model, FluxSpace]
        Cobra model or `FluxSpace` object describing the flux space of the
        metabolic network. A `FluxSpace` is copied, a cobra model is converted.
    thermodynamic_space : ThermodynamicSpace
        Description of the thermodynamic space of the metabolic network.
    thermodynamic_space_basis : ThermodynamicSpaceBasis, optional
        A basis for the thermodynamic space. If specified, `m` will be defined
        in this basis and the basis must represent reaction energies only. If
        omitted, a basis without explicit standard reaction energies and
        log-concentrations is built from `thermodynamic_space`.
    confidence_level : float, optional
        Confidence level (in the range :math:`[0.0, 1.0[`) on the joint of the
        thermodynamic variables, by default 0.95.
    min_drg : float, optional
        Minimum magnitude for the reaction energy of each reaction, by default
        1e-1.
    max_drg : float, optional
        Maximum magnitude for the reaction energy of each reaction, by default
        1000.
    solver : Optional[str], optional
        Name of the solver to use, this can be any of the solvers supported by
        CVXPY, by default None.
    solver_options : dict, optional
        Dictionary specifying additional options for the solver.
    """

    def __init__(
        self,
        network: Union[cobra.Model, FluxSpace],
        thermodynamic_space: ThermodynamicSpace,
        thermodynamic_space_basis: Optional[ThermodynamicSpaceBasis] = None,
        confidence_level: float = 0.95,
        min_drg: float = 1e-1,
        max_drg: float = 1000,
        solver: Optional[str] = None,
        solver_options: Optional[dict] = None,
    ):
        # Create the flux space if needed.
        if isinstance(network, FluxSpace):
            self._F = network.copy()
        else:
            self._F = FluxSpace.from_cobrapy_model(network)

        # Create the basis of the thermodynamic space if needed.
        if thermodynamic_space_basis is not None:
            self._B = thermodynamic_space_basis
            assert (
                self.B.to_drg_transform is not None
                and self.B.to_drg0_transform is None
                and self.B.to_log_conc_transform is None
            ), "Currently TFS requires a basis that represents free energies only."
        else:
            self._B = ThermodynamicSpaceBasis(
                thermodynamic_space, explicit_drg0=False, explicit_log_conc=False
            )

        self._T = thermodynamic_space
        self._confidence_level = confidence_level
        self._drg_epsilon = min_drg
        # Positional arguments forwarded to `PmoProblemPool` when searching for
        # initial points.
        # NOTE(review): the `None` entry appears to be the PMO objective slot
        # (left at its default) - confirm against the PmoProblem signature.
        self._pmo_args = [
            self.F,
            self.T,
            self.B,
            None,
            confidence_level,
            min_drg,
            max_drg,
            solver,
            solver_options or {},
        ]

    @property
    def T(self) -> ThermodynamicSpace:
        """Gets the thermodynamic space used for sampling."""
        return self._T

    @property
    def F(self) -> FluxSpace:
        """Gets the flux space used for sampling."""
        return self._F

    @property
    def B(self) -> ThermodynamicSpaceBasis:
        """Gets the basis of the thermodynamic space used for sampling."""
        return self._B

    @property
    def pmo_args(self):
        """Gets the arguments used to construct PMO problems."""
        return self._pmo_args

    @property
    def dimensionality(self) -> int:
        """Gets the dimensionality of the basis of the thermodynamic space."""
        return self.B.dimensionality

    @property
    def confidence_radius(self):
        """Gets the radius of the selected confidence region.

        Computed as the square root of the chi-square quantile at the
        configured confidence level, with as many degrees of freedom as the
        basis has dimensions.
        """
        return math.sqrt(chi2.ppf(self._confidence_level, self.B.dimensionality))

    @property
    def drg_epsilon(self):
        """Gets the minimum magnitude of the reaction energy of irreversible
        reactions."""
        return self._drg_epsilon

    @property
    def reversible_rxn_ids(self):
        """Gets the identifiers of the reversible reactions in the thermodynamic
        space.

        The identifiers are returned in the order in which they appear in the
        flux space.
        """
        # Build the set once so that the membership test is O(1) per reaction.
        thermodynamic_ids = set(self.T.reaction_ids)
        return [
            rxn_id
            for i, rxn_id in enumerate(self.F.reaction_ids)
            if self.F.lb[i] < 0 and self.F.ub[i] > 0 and rxn_id in thermodynamic_ids
        ]

    def _reaction_idxs_F(self) -> List[int]:
        """Map each reaction of the thermodynamic space to its flux space index."""
        flux_ids = self.F.reaction_ids
        return [flux_ids.index(rxn_id) for rxn_id in self.T.reaction_ids]

    def _classify_reactions_T(self) -> Tuple[List[int], List[int], List[int]]:
        """Split the thermodynamic space reactions by allowed flux direction.

        Returns
        -------
        Tuple[List[int], List[int], List[int]]
            Indices (in the thermodynamic space) of the reactions that are
            forward-only (lower bound >= 0), backward-only (upper bound <= 0)
            and reversible (lower bound < 0 and upper bound > 0). A reaction
            with both bounds at zero appears in the first two lists.
        """
        idxs_F = self._reaction_idxs_F()
        lb, ub = self.F.lb, self.F.ub
        only_forward = [i for i, j in enumerate(idxs_F) if lb[j] >= 0]
        only_backward = [i for i, j in enumerate(idxs_F) if ub[j] <= 0]
        reversible = [i for i, j in enumerate(idxs_F) if lb[j] < 0 and ub[j] > 0]
        return only_forward, only_backward, reversible

    def get_reversible_reactions_ids_T(self) -> List:
        """Get the indices of the reversible reactions in the thermodynamic space.

        Returns
        -------
        List
            Positions, within `T.reaction_ids`, of the reactions whose flux
            bounds allow both directions.
        """
        return self._classify_reactions_T()[2]

    def to_drg(self, value: np.ndarray) -> np.ndarray:
        """Transform a point or matrix from the basis to reaction energies.

        Parameters
        ----------
        value : np.ndarray
            Input values in the basis.

        Returns
        -------
        np.ndarray
            The corresponding reaction energies.
        """
        assert self.B.to_drg_transform is not None
        return apply_transform(value, self.B.to_drg_transform)

    def simulate(
        self,
        settings: pb.SamplerSettings,
        initial_points: np.ndarray,
        directions_transform: np.ndarray,
    ) -> pb.TFSResult:
        """Run the sampler with the given parameters.

        Parameters
        ----------
        settings : pb.SamplerSettings
            Sampling settings.
        initial_points : np.ndarray
            The initial points for the chains.
        directions_transform : np.ndarray
            The transform for the directions sampler.

        Returns
        -------
        pb.TFSResult
            The result returned by the native free energy sampler.
        """
        drg_G, drg_h = _make_drg_polytope(self)
        # Column vector holding, for each reaction of the thermodynamic space,
        # its index in the flux space.
        constrained_rxn_idxs = np.array([self._reaction_idxs_F()], dtype=np.uint32).T
        return pb.sample_free_energies(
            self.F.S,
            self.F.lb,
            self.F.ub,
            constrained_rxn_idxs,
            initial_points,
            settings,
            drg_G,
            drg_h,
            self.B.to_drg_transform[0],
            self.B.to_drg_transform[1],
            directions_transform,
        )

    def compute_psrf(self, result: pb.TFSResult) -> pd.Series:
        """Compute the potential scale reduction factors for the variables of
        interest on a given set of chains.

        Parameters
        ----------
        result : pb.TFSResult
            The result of the sampling function.

        Returns
        -------
        pd.Series
            The computed potential scale reduction factors, first for the basis
            variables ("drg_var<i>") and then for the reaction energies
            ("<reaction id>_drg").
        """
        logger.debug("Computing PSRFs.")
        basis_var_names = ["drg_var" + str(i) for i in range(self.B.dimensionality)]
        var_names = [rxn_id + "_drg" for rxn_id in self.T.reaction_ids]
        psrf = np.hstack(
            [
                split_R(result.chains),
                split_R(apply_to_chains(result.chains, self.to_drg)),
            ]
        )
        return pd.Series(psrf, index=basis_var_names + var_names)

    def get_chains(self, result: pb.TFSResult) -> np.ndarray:
        """Extract the simulated chains from a given result.

        Parameters
        ----------
        result : pb.TFSResult
            The result of the native sampling function.

        Returns
        -------
        np.ndarray
            The simulated chains.
        """
        return result.chains

    def get_initial_points(self, num_points: int) -> np.ndarray:
        """Gets initial points for sampling reaction energies.

        Each point is obtained by optimizing the direction of one reaction.
        Directions of reversible reactions are used first; directions of
        irreversible reactions are only drawn (at random) once all reversible
        directions have been selected. If fewer candidate directions than
        `num_points` exist, fewer points are returned.

        Parameters
        ----------
        num_points : int
            Number of initial points to generate.

        Returns
        -------
        np.ndarray
            Array containing the initial points, stacked horizontally.
        """
        # pylint: disable=protected-access
        # Find candidate optimization directions as (reaction index, sign).
        only_forward, only_backward, reversible = self._classify_reactions_T()
        reversible_dirs = [(i, -1) for i in reversible] + [(i, 1) for i in reversible]
        irreversible_dirs = [(i, -1) for i in only_backward] + [
            (i, 1) for i in only_forward
        ]

        # Select optimization directions, giving precedence to the reversible
        # reactions.
        if num_points >= len(reversible_dirs):
            directions = reversible_dirs
            directions_pool = irreversible_dirs
            to_sample = min(num_points - len(reversible_dirs), len(irreversible_dirs))
        else:
            directions = []
            directions_pool = reversible_dirs
            to_sample = min(num_points, len(reversible_dirs))
        optimization_directions = directions + random.sample(directions_pool, to_sample)

        # Run the optimizations in a pool of PMO problems. The pool is closed
        # even if an optimization raises, so that it is never leaked.
        pool = PmoProblemPool(None, *self._pmo_args)
        try:
            initial_points = pool.map(_find_point, optimization_directions)
        finally:
            pool.close()

        assert all(p is not None for p in initial_points), (
            "One or more initial points could not be found. This could be due to "
            "an overconstrained model or numerical inaccuracies."
        )
        return np.hstack(initial_points)
def _find_point(
    problem: PmoProblem, flux_objective: Tuple[int, int]
) -> Optional[np.ndarray]:
    """Search for an initial point in an orthant with a given reaction direction.

    The search has two stages. First `problem` is solved while maximizing the
    signed direction variable of the selected reaction, which selects an
    orthant. Then the problem is rebuilt for the directions found and a
    non-negative `distance` variable, added to the reaction energy and
    confidence region constraints, is maximized to move the point away from
    the constraints.

    Parameters
    ----------
    problem : PmoProblem
        The PMO problem to solve.
    flux_objective : Tuple[int, int]
        Index of the reaction (in the direction variables `d`) and the sign
        (+1 or -1) multiplying its direction variable in the objective.

    Returns
    -------
    Optional[np.ndarray]
        The value of the basis variables `m` at the point found, or None if
        either optimization does not terminate with status "optimal".
    """
    # pylint: disable=protected-access
    rxn_idx, sign = flux_objective

    # First, find an orthant with the given direction (if possible).
    problem.objective = lambda p: cp.Maximize(sign * p.d[rxn_idx])
    result = problem.solve()
    if result != "optimal":
        logger.warning(
            f"Initial points search failed for reaction {flux_objective[0]} in "
            f"direction {flux_objective[1]} with status {result}."
        )
        return None

    # Next, restrict PMO to the orthant found and search for a point maximizing
    # the distance from all the constraints.
    orthant_problem = problem.rebuild_for_directions(problem.d.value)
    assert orthant_problem.B.to_drg_transform is not None
    (T, s) = orthant_problem.B.to_drg_transform
    CI_square = chi2.ppf(
        orthant_problem._confidence_level, orthant_problem.B.dimensionality
    )
    distance = cp.Variable(name="distance", nonneg=True)
    # Scale the distance by the norm of each row of the transform.
    row_norms = np.linalg.norm(T, ord=2, axis=1, keepdims=True)
    shifted_drg = (
        cp.multiply(orthant_problem._big_M_r, orthant_problem.d)
        + T @ orthant_problem.m
        + row_norms * distance
        + s
    )
    orthant_problem._constraints.extend(
        [
            shifted_drg <= orthant_problem._big_M_r - orthant_problem._epsilon_r,
            shifted_drg >= orthant_problem._epsilon_r,
            cp.atoms.quad_form(
                orthant_problem.m, np.identity(orthant_problem.B.dimensionality)
            )
            + cp.square(distance)
            <= CI_square,
        ]
    )
    orthant_problem.objective = lambda p: cp.Maximize(distance)
    if orthant_problem.solver != "GUROBI":
        logger.warning(
            "Orthant problem does not have GUROBI as a solver, BarHomogeneous parameter specification may fail"
        )
    orthant_problem.solver_options["BarHomogeneous"] = 1

    # Check the status of THIS solve: previously the outcome was discarded and
    # the (already validated) status of the first solve was tested again, so a
    # failed orthant optimization was never detected.
    result = orthant_problem.solve()
    if result != "optimal":
        logger.warning(
            f"Initial points search failed for reaction {flux_objective[0]} in "
            f"direction {flux_objective[1]} with status {result}."
        )
        return None

    logger.info(f"Found initial point with distance {distance.value}")
    return orthant_problem.m.value
class FreeEnergiesSamplingResult(SamplingResult):
    """Result of a reaction energy sampling run.

    Extends `SamplingResult` with the orthants visited by the sampler.

    Parameters
    ----------
    samples : pd.DataFrame
        Data frame containing the free energy samples.
    psrf : pd.Series
        The Potential Scale Reduction Factors of each variable.
    orthants : pd.DataFrame
        Data frame containing the signs of the reversible reactions for each
        orthant, plus a "weight" column describing the weight of the orthant.
    basis_samples : pd.DataFrame, optional
        The samples in the basis.
    chains : np.ndarray, optional
        The simulated chains.
    """

    def __init__(
        self,
        samples: pd.DataFrame,
        psrf: pd.Series,
        orthants: pd.DataFrame,
        basis_samples: pd.DataFrame = None,
        chains: np.ndarray = None,
    ):
        # Everything except the orthants is handled by the base class.
        super().__init__(samples, psrf, basis_samples, chains)
        self._orthants = orthants

    @property
    def orthants(self) -> pd.DataFrame:
        """Gets the data frame of sampled orthants, including the "weight"
        column describing the weight of each orthant."""
        return self._orthants
def sample_drg(
    model: TFSModel,
    num_samples: int = default_num_samples,
    num_direction_samples: int = default_num_samples,
    max_steps: int = -1,
    max_psrf: float = default_max_psrf,
    num_chains: int = -1,
    initial_points: np.ndarray = None,
    num_initial_steps: int = -1,
    feasibility_cache_size: int = tfs_default_feasibility_cache_size,
    min_rel_region_length: float = tfs_default_min_rel_region_length,
    max_threads: int = default_max_threads,
    convergence_manager: ConvergenceManager = None,
) -> FreeEnergiesSamplingResult:
    """Sample reaction energies under steady state flux constraints.

    Parameters
    ----------
    model : TFSModel
        The model to sample.
    num_samples : int, optional
        Number of samples to draw.
    num_direction_samples : int, optional
        Number of orthant samples to collect.
    max_steps : int, optional
        The maximum number of steps to simulate. Values below 1 are replaced
        by 2**63.
    max_psrf : float, optional
        Maximum value of the PSRFs for convergence. Must be greater than 1.
    num_chains : int, optional
        The number of chains to simulate. If negative, it is taken from the
        number of columns of `initial_points` when those are given, otherwise
        it is twice the number of reversible reactions (at least 1).
    initial_points : np.ndarray, optional
        The initial points for the chains. Generated from the model if omitted.
    num_initial_steps : int, optional
        Initial chains length. If not positive, the squared dimensionality of
        the basis plus `num_samples` is used.
    feasibility_cache_size : int, optional
        Maximum size of the cache storing the feasibility of the orthants
        encountered during the random walk.
    min_rel_region_length : float, optional
        Minimum length (relative to the length of the entire ray) of a segment
        in order to consider it for sampling.
    max_threads : int, optional
        The maximum number of parallel threads to use.
    convergence_manager : ConvergenceManager, optional
        The object to use to monitor and improve convergence. A default one is
        created if omitted.

    Returns
    -------
    FreeEnergiesSamplingResult
        The sampling result.

    Raises
    ------
    AssertionError
        If one of the settings is outside its valid range.
    SamplingException
        If sampling fails (presumably raised by the convergence manager or the
        native sampler; it is not raised directly by this function).
    """
    # Validate and process input settings.
    assert model.B.to_drg_transform is not None
    assert num_samples > 0
    assert num_direction_samples > 0
    if max_steps < 1:
        max_steps = 2**63
    assert max_psrf > 1.0
    if num_chains < 0:
        if initial_points is None:
            reversible_idxs = model.get_reversible_reactions_ids_T()
            num_chains = max(len(reversible_idxs) * 2, 1)
        else:
            num_chains = initial_points.shape[1]
    if num_initial_steps <= 0:
        num_initial_steps = model.B.dimensionality**2 + num_samples
    assert feasibility_cache_size >= 0
    assert min_rel_region_length > 0
    assert max_threads > 0
    if convergence_manager is None:
        convergence_manager = ConvergenceManager(model, num_initial_steps, True)

    # Construct settings and auxiliary functions.
    settings = pb.FreeEnergySamplerSettings()
    settings.truncation_multiplier = model.confidence_radius
    settings.feasibility_cache_size = feasibility_cache_size
    settings.drg_epsilon = model.drg_epsilon
    settings.flux_epsilon = 1e-4  # TODO: make this an actual setting.
    settings.min_rel_region_length = min_rel_region_length

    def update_settings_function(s, steps):
        # Refresh the step-dependent settings for a chain length of `steps`.
        return _fill_settings(
            s, num_samples, steps, num_chains, 2, max_threads, num_direction_samples
        )

    def make_sampling_result_function(r):
        # Convert a native result into the public result type.
        return _make_sampling_result(model, r, num_samples)

    update_settings_function(settings, num_initial_steps)

    # Generate the initial points if needed.
    if initial_points is None:
        initial_points = model.get_initial_points(settings.num_chains)

    # Run the sampler.
    return convergence_manager.run(
        settings,
        update_settings_function,
        make_sampling_result_function,
        initial_points,
        max_steps,
        max_psrf,
    )
def _make_sampling_result(
    model: "TFSModel", result: pb.TFSResult, num_samples: int
) -> FreeEnergiesSamplingResult:
    """Convert the native sampler output into a `FreeEnergiesSamplingResult`.

    Parameters
    ----------
    model : TFSModel
        The sampled model.
    result : pb.TFSResult
        The result returned by the native sampler.
    num_samples : int
        Number of samples to extract from the chains.

    Returns
    -------
    FreeEnergiesSamplingResult
        Reaction energy samples, PSRFs, weighted orthants, samples in the
        basis and the raw chains.
    """
    basis_samples = sample_from_chains(result.chains, num_samples)
    drg_samples = model.to_drg(basis_samples.T).T

    # Convert the orthants from binary to integer description: unpack the
    # bits, keep one column per reversible reaction and encode 0 as -1.
    reversible_ids = model.reversible_rxn_ids
    unpacked_bits = np.unpackbits(result.directions, bitorder="little", axis=1)
    orthant_signs = unpacked_bits[:, : len(reversible_ids)].astype(np.int8)
    orthant_signs[orthant_signs == 0] = -1
    orthants = pd.DataFrame(orthant_signs, columns=reversible_ids)
    orthant_weights = pd.DataFrame(result.direction_counts, columns=["weight"])

    basis_var_names = [f"var{i}" for i in range(model.dimensionality)]
    return FreeEnergiesSamplingResult(
        pd.DataFrame(drg_samples, columns=model.T.reaction_ids),
        model.compute_psrf(result),
        orthant_weights.join(orthants),
        pd.DataFrame(basis_samples, columns=basis_var_names),
        result.chains,
    )
def _fill_settings(
    settings,
    num_samples,
    steps,
    num_chains,
    num_warmup_steps,
    max_threads,
    num_direction_samples,
):
    """Fill the common sampling settings and the thinning of direction samples.

    After delegating to `fill_common_sampling_settings`, the thinning for the
    direction samples is set to the number of non-skipped steps divided by the
    number of direction samples each chain has to provide (rounded down).

    Raises
    ------
    AssertionError
        If the resulting thinning is not positive, i.e. the chains are too
        short to provide the requested number of direction samples.
    """
    fill_common_sampling_settings(
        settings, num_samples, steps, num_chains, num_warmup_steps, max_threads
    )
    steps_after_skip = settings.num_steps - settings.num_skipped_steps
    direction_samples_per_chain = math.ceil(
        num_direction_samples / settings.num_chains
    )
    settings.steps_thinning_directions = math.floor(
        steps_after_skip / direction_samples_per_chain
    )
    assert settings.steps_thinning_directions > 0, (
        "Unable to generate that many direction samples with the current settings. "
        "Please select a larger number of steps or chains."
    )
def _make_drg_polytope(model: "TFSModel") -> Tuple[np.ndarray, np.ndarray]:
    """Build the direction constraints of the irreversible reactions.

    The constraints are expressed in the basis of the thermodynamic space, in
    the form ``G * m <= h``. Rows of forward-only reactions (lower flux bound
    >= 0) come first, followed by the negated rows of backward-only reactions
    (upper flux bound <= 0).

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        The matrix `G` and the column vector `h`.
    """
    flux_ids = model.F.reaction_ids
    lower, upper = model.F.lb, model.F.ub

    # Find irreversible reactions, as indices in the thermodynamic space.
    forward_rows = []
    backward_rows = []
    for idx_T, rxn_id in enumerate(model.T.reaction_ids):
        idx_F = flux_ids.index(rxn_id)
        if lower[idx_F] >= 0:
            forward_rows.append(idx_T)
        if upper[idx_F] <= 0:
            backward_rows.append(idx_T)

    matrix = model.B.to_drg_transform[0]
    offset = model.B.to_drg_transform[1]
    drg_G = np.vstack((matrix[forward_rows, :], -matrix[backward_rows, :]))
    drg_h = -np.vstack((offset[forward_rows, :], -offset[backward_rows, :]))
    return drg_G, drg_h
def sample_log_conc_from_drg(
    thermodynamic_space: ThermodynamicSpace,
    drg_samples: pd.DataFrame,
    min_eigenvalue: float = default_min_eigenvalue_tds_basis,
) -> pd.DataFrame:
    """Sample log-concentrations conditioned on reaction energy samples.

    Draws one sample of the natural logarithm of the metabolite concentrations
    for each sample of reaction energies.

    Parameters
    ----------
    thermodynamic_space : ThermodynamicSpace
        The thermodynamic space of the network.
    drg_samples : pd.DataFrame
        Data frame containing the samples of reaction energies, one per row.
    min_eigenvalue : float, optional
        Minimum eigenvalue to keep when performing the truncated SVD of the
        covariance of the conditional probability.

    Returns
    -------
    pd.DataFrame
        Data frame containing the log-concentration samples, with one column
        per metabolite.
    """
    # Basis describing reaction energies and log-concentrations jointly.
    joint_basis = ThermodynamicSpaceBasis(
        thermodynamic_space,
        explicit_drg=True,
        explicit_drg0=False,
        explicit_log_conc=True,
        min_eigenvalue=min_eigenvalue,
    )
    observables_transform = joint_basis.to_observables_transform
    conditional_samples = _sample_conditional_mvn(
        observables_transform[1],
        observables_transform[0],
        drg_samples.to_numpy().T,
        joint_basis.observables_ranges["log_conc"],
        min_eigenvalue,
    )
    return pd.DataFrame(
        conditional_samples.T, columns=thermodynamic_space.metabolite_ids
    )
def sample_drg0_from_drg(
    thermodynamic_space: ThermodynamicSpace,
    drg_samples: pd.DataFrame,
    min_eigenvalue: float = default_min_eigenvalue_tds_basis,
) -> pd.DataFrame:
    """Sample standard reaction energies conditioned on reaction energy samples.

    Draws one sample of standard reaction energies for each sample of reaction
    energies.

    Parameters
    ----------
    thermodynamic_space : ThermodynamicSpace
        The thermodynamic space of the network.
    drg_samples : pd.DataFrame
        Data frame containing the samples of reaction energies, one per row.
    min_eigenvalue : float, optional
        Minimum eigenvalue to keep when performing the truncated SVD of the
        covariance of the conditional probability.

    Returns
    -------
    pd.DataFrame
        Data frame containing the standard reaction energy samples, with one
        column per reaction.
    """
    # Basis describing reaction energies and standard reaction energies jointly.
    joint_basis = ThermodynamicSpaceBasis(
        thermodynamic_space,
        explicit_drg=True,
        explicit_drg0=True,
        explicit_log_conc=False,
        min_eigenvalue=min_eigenvalue,
    )
    observables_transform = joint_basis.to_observables_transform
    conditional_samples = _sample_conditional_mvn(
        observables_transform[1],
        observables_transform[0],
        drg_samples.to_numpy().T,
        joint_basis.observables_ranges["drg0"],
        min_eigenvalue,
    )
    return pd.DataFrame(
        conditional_samples.T, columns=thermodynamic_space.reaction_ids
    )
def _init_us_worker(args_filename):
    """Initialize a uniform sampling worker process.

    Loads the pickled model stored in `args_filename` into the module-level
    `_us_model`, from where `_sample_orthant` reads it.

    Parameters
    ----------
    args_filename : str
        Path of the file containing the pickled model.
    """
    global _us_model  # pylint: disable=global-statement
    # NOTE: unpickling executes arbitrary code; the file must be a trusted one
    # (in this module it is written by `sample_fluxes_from_drg`).
    # Use a context manager so that the file handle is closed after loading.
    with open(args_filename, "rb") as args_file:
        _us_model = pickle.load(args_file)
def _sample_orthant(args: Tuple[np.ndarray, np.ndarray, int]) -> pd.DataFrame:
    """Uniformly sample the fluxes of one orthant in a worker process.

    Parameters
    ----------
    args : Tuple[np.ndarray, np.ndarray, int]
        Lower bounds and upper bounds to apply to the reactions of the worker's
        model (indexed in the order of `model.reactions`), and the number of
        samples to draw.

    Returns
    -------
    pd.DataFrame
        The `samples` of the uniform sampling result.
    """
    lower_bounds = args[0]
    upper_bounds = args[1]
    # The model is loaded once per process by `_init_us_worker`; its bounds are
    # overwritten in place on every call.
    model = _us_model
    for position, reaction in enumerate(model.reactions):
        reaction.lower_bound = lower_bounds[position]
        reaction.upper_bound = upper_bounds[position]
    sampling_result = sample_flux_space_uniform(model, args[2], num_chains=1)
    return sampling_result.samples
def _get_sample_weight(orthant_weights, orthant_signs, sample_signs):
    """Look up the weight of the orthant matching a sign pattern.

    Parameters
    ----------
    orthant_weights
        Weights of the orthants, accessed by position through `.iloc`.
    orthant_signs
        Array with one row of signs per orthant.
    sample_signs
        The sign pattern to look up.

    Returns
    -------
    The weight of the single matching orthant, or 1 if no orthant matches.

    Raises
    ------
    Exception
        If more than one orthant matches the sign pattern.
    """
    row_matches = np.all(orthant_signs == sample_signs, axis=1)
    matching_rows = np.flatnonzero(row_matches)
    if len(matching_rows) == 0:
        return 1
    if len(matching_rows) > 1:
        raise Exception("Direction samples should be unique.")
    return orthant_weights.iloc[matching_rows[0]]
def sample_fluxes_from_drg(
    model: cobra.Model,
    drg_samples: pd.DataFrame,
    orthants: pd.DataFrame,
    num_approx_samples: int = default_num_samples,
) -> pd.DataFrame:
    """Sample the flux space using samples of reaction energies and orthants as
    prior.

    For each unique orthant implied by the reaction energy samples, this
    function draws a number of uniform flux samples proportional to the weight
    of the orthant in the thermodynamic space. Orthants that do not appear in
    `orthants` get a weight of 1.

    Parameters
    ----------
    model : cobra.Model
        cobrapy model describing the flux space.
    drg_samples : pd.DataFrame
        The input reaction energy samples.
    orthants : pd.DataFrame
        Data frame containing the sampled orthants (one column per reversible
        reaction) and their weights (column "weight").
    num_approx_samples : int, optional
        Approximate number of samples to draw.

    Returns
    -------
    pd.DataFrame
        Data frame containing the flux samples.
    """
    # Get the signs of the reversible reactions for each orthant.
    reversible_rxns_ids = list(orthants.columns)
    reversible_rxns_ids.remove("weight")
    orthant_signs = orthants[reversible_rxns_ids].to_numpy()

    # Get the unique sign patterns of the reversible reactions implied by the
    # samples of drg (a negative reaction energy maps to a positive sign).
    drg_samples = drg_samples[reversible_rxns_ids]
    samples_signs = np.unique(np.where(drg_samples < 0, 1, -1), axis=0).astype(np.int8)

    # Find the weight of the orthant corresponding to each sign pattern.
    samples_weights = np.array(
        [
            _get_sample_weight(orthants["weight"], orthant_signs, signs)
            for signs in samples_signs
        ]
    )

    # Scale the weights to obtain approximately the requested number of
    # samples. The fractional part is rounded up at random, with probability
    # equal to the fraction itself.
    num_orthant_samples = samples_weights / np.sum(samples_weights) * num_approx_samples
    to_round_up = np.random.rand(num_orthant_samples.size) < (
        num_orthant_samples - np.floor(num_orthant_samples)
    )
    num_orthant_samples = num_orthant_samples.astype(np.int64) + np.where(
        to_round_up, 1, 0
    )

    # Build the flux bounds of each orthant that received at least one sample.
    worker_args = []
    original_lb = np.array([r.lower_bound for r in model.reactions])
    original_ub = np.array([r.upper_bound for r in model.reactions])
    rev_rxns_idxs_F = [model.reactions.index(rxn_id) for rxn_id in reversible_rxns_ids]
    for signs, num_flux_samples in zip(samples_signs, num_orthant_samples):
        if num_flux_samples > 0:
            lb = original_lb.copy()
            ub = original_ub.copy()
            rev_lb = -np.ones(len(rev_rxns_idxs_F)) * np.inf
            rev_ub = np.ones(len(rev_rxns_idxs_F)) * np.inf
            rev_lb[signs > 0] = 0
            rev_ub[signs < 0] = 0
            lb[rev_rxns_idxs_F] = np.maximum(lb[rev_rxns_idxs_F], rev_lb)
            ub[rev_rxns_idxs_F] = np.minimum(ub[rev_rxns_idxs_F], rev_ub)
            worker_args.append((lb, ub, num_flux_samples))

    # `os.cpu_count()` returns None when the count cannot be determined, while
    # `multiprocessing.cpu_count()` raises instead, which made the fallback
    # below unreachable.
    num_processes = os.cpu_count()
    if num_processes is None:
        logger.warning(
            "Cannot determine the number of processors available. Assuming 1."
        )
        num_processes = 1

    # Store the model in a temporary file read by each worker. The pickle is
    # written through the temporary file's own handle, which is closed (and
    # thus flushed) before the workers start.
    with tempfile.NamedTemporaryFile(delete=False) as args_file:
        pickle.dump(model, args_file)
        args_filename = args_file.name

    # Create process pool and run uniform sampling on each orthant. The pool
    # and the temporary file are released even if sampling fails.
    try:
        pool = multiprocessing.Pool(
            num_processes,
            initializer=_init_us_worker,
            initargs=[args_filename],
        )
        try:
            pool_result = pool.map(_sample_orthant, worker_args)
        finally:
            pool.close()
            pool.join()
    finally:
        if os.path.isfile(args_filename):
            os.remove(args_filename)

    return pd.concat(pool_result, ignore_index=True)
def _sample_conditional_mvn(
    mean: np.ndarray,
    cov_sqrt: np.ndarray,
    Y_samples: np.ndarray,
    X_ids: List[int],
    min_eigenvalue: float = default_min_eigenvalue_tds_basis,
) -> np.ndarray:
    """Samples from the probability density of X conditioned on Y.

    The joint covariance is ``cov_sqrt @ cov_sqrt.T``. The variables at the
    positions `X_ids` form X, all remaining positions form Y. One sample of X
    is drawn for each column of `Y_samples`.

    Parameters
    ----------
    mean : np.ndarray
        Mean of the joint distribution.
    cov_sqrt : np.ndarray
        Square root of the covariance of the joint distribution.
    Y_samples : np.ndarray
        Observed values of Y, one sample per column.
    X_ids : List[int]
        Positions of the X variables in the joint distribution.
    min_eigenvalue : float, optional
        Passed to `covariance_square_root` when factorizing the conditional
        covariance.

    Returns
    -------
    np.ndarray
        The samples of X, one per column.
    """
    Y_ids = sorted(set(range(mean.size)) - set(X_ids))

    # Blocks of the joint covariance, see
    # https://en.wikipedia.org/wiki/Schur_complement
    # [a; b] * [a; b]' = [A, B; B', C]
    sqrt_X = cov_sqrt[X_ids, :]
    sqrt_Y = cov_sqrt[Y_ids, :]
    cov_XX = sqrt_X @ sqrt_X.T
    cov_XY = sqrt_X @ sqrt_Y.T
    cov_YY = sqrt_Y @ sqrt_Y.T
    cov_YY_inv = sp.linalg.pinvh(cov_YY)
    mean_X = mean[X_ids]
    mean_Y = mean[Y_ids]

    # Conditional covariance (Schur complement) and conditional means.
    gain = cov_XY @ cov_YY_inv
    conditional_cov = cov_XX - gain @ cov_XY.T
    conditional_sqrt = covariance_square_root(conditional_cov, min_eigenvalue)
    centered_transform = (conditional_sqrt, np.zeros_like(mean_X))
    conditional_means = mean_X + gain @ (Y_samples - mean_Y)

    # Draw standard normal samples and map them to the conditional density.
    num_draws = Y_samples.shape[1]
    num_dims = conditional_sqrt.shape[1]
    standard_draws = np.random.multivariate_normal(
        np.zeros(num_dims), np.identity(num_dims), num_draws
    ).T
    return apply_transform(standard_draws, centered_transform) + conditional_means