"""Implementations of custom distributions."""
import torch
class TruncatedNormal:
    """
    Sample from a normal distribution truncated to lie within a lower and an upper limit.

    Sampling uses the inverse-CDF method: a uniform variate ``p`` in ``[0, 1)`` is mapped
    into the interval ``[Φ(alpha), Φ(beta)]`` of the standard normal CDF, where ``alpha``
    and ``beta`` are the standardized limits, and then pushed through the inverse CDF.

    Inspired by https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/20.

    Args:
        mu: Mean of the parent normal distribution to sample from.
        sigma: Standard deviation of the parent normal distribution to sample from.
        lower_limit: Lower threshold of the truncated distribution.
        upper_limit: Upper threshold of the truncated distribution.
    """

    def __init__(
        self,
        mu: float = 0.0,
        sigma: float = 1.0,
        lower_limit: float = -2.0,
        upper_limit: float = 2.0,
    ) -> None:
        # Source of the uniform variates p fed into the inverse CDF.
        self.uniform = torch.distributions.uniform.Uniform(low=0, high=1)
        # Standard normal; only its CDF is used, evaluated at the standardized limits.
        self.normal = torch.distributions.normal.Normal(0, 1, validate_args=False)
        # Limits expressed in units of the standard normal: (limit - mu) / sigma.
        self.alpha = (lower_limit - mu) / sigma
        self.beta = (upper_limit - mu) / sigma
        self.mu = mu
        self.sigma = sigma
        self.lower_limit = lower_limit
        self.upper_limit = upper_limit

    def sample(self, shape) -> torch.Tensor:
        """Generate uniform random variables and apply the inverse CDF.

        Args:
            shape: Shape of the tensor to draw. May be a ``torch.Size``, a tuple or
                list of ints, or a single int (treated as a one-dimensional shape).

        Returns:
            Tensor of the requested shape whose entries lie within
            ``[lower_limit, upper_limit]``.
        """
        # Following https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
        # A bare int is not accepted by torch.Size(...), so wrap it into a 1-tuple first.
        if isinstance(shape, int):
            shape = (shape,)
        # p
        uniform = self.uniform.sample(torch.Size(shape))
        # Φ(a) and Φ(b): standard normal CDF at the standardized limits.
        alpha_normal_cdf = self.normal.cdf(self.alpha)
        beta_normal_cdf = self.normal.cdf(self.beta)
        # x = Φ^{-1}(Φ(a) + p · (Φ(b) − Φ(a))); this is the argument of Φ^{-1}:
        inner_inverse = alpha_normal_cdf + (beta_normal_cdf - alpha_normal_cdf) * uniform
        # Numerical stability: erfinv(±1) is infinite, so keep the argument strictly
        # inside (-1, 1) by one machine epsilon of the working dtype.
        epsilon = torch.finfo(inner_inverse.dtype).eps
        # Φ^{-1}(q) = sqrt(2) · erfinv(2q − 1); the sqrt(2) factor is applied below.
        erf = torch.clamp(2 * inner_inverse - 1, -1 + epsilon, 1 - epsilon)
        # Rescale the standard-normal draw xi to the parent distribution: x = mu + sigma · xi.
        samples = self.mu + self.sigma * torch.sqrt(torch.tensor(2.0)) * torch.erfinv(erf)
        # Guard against round-off pushing values marginally outside the limits.
        samples = torch.clamp(samples, self.lower_limit, self.upper_limit)
        return samples