# Source code for ecgan.utils.interpolation

"""Interpolation schemes between two data points in latent space."""
from typing import List, Union

import torch


def slerp(mu, low, high):
    """
    Spherical linear interpolation based on `White et al. 2016
    <https://arxiv.org/pdf/1609.04468.pdf>`_.

    Originally introduced in `Shoemake, 1985
    <https://www.engr.colostate.edu/ECE481A2/Readings/Rotation_Animation.pdf>`_
    and additional visualizations can be found in `Huszár 2017 (Blogpost)
    <https://www.inference.vc/high-dimensional-gaussian-distributions-are-soap-bubble/>`_.
    Implementation adapted from `ptrblck, GitHub
    <https://github.com/ptrblck/prog_gans_pytorch_inference/blob/master/utils.py>`_
    and `soumith, GitHub <https://github.com/soumith/dcgan.torch/issues/14>`_.

    Most probability mass in high-dimensional Gaussian latent spaces is in an
    annulus around the origin and points around the origin are scarce. To
    account for this, a meaningful interpolation should traverse the annulus
    and not travel through the center of the hypersphere.

    Args:
        mu: Parameter moving from 0 to 1 the closer it gets to `high`.
        low: Sample from latent space, origin of the interpolation process.
        high: Sample from latent space, target of the interpolation process.

    Returns:
        Mu-based interpolated sample between low and high.
    """
    low_norm = torch.linalg.norm(low)
    high_norm = torch.linalg.norm(high)
    # Avoid numerical instability: the norm is only zero if the vector itself
    # is zero, so the denominator can be set arbitrarily to avoid division by
    # zero. ones_like keeps the device and dtype of the input tensors (the
    # previous torch.ones(shape) allocated on the default device and failed
    # for e.g. CUDA inputs).
    low_norm = torch.where(low_norm != 0, low_norm, torch.ones_like(low_norm))
    high_norm = torch.where(high_norm != 0, high_norm, torch.ones_like(high_norm))
    low_norm_scale = torch.div(low, low_norm)
    high_norm_scale = torch.div(high, high_norm)
    # Angle between the two normalized samples on the hypersphere.
    omega = torch.acos(torch.matmul(low_norm_scale, high_norm_scale.t()))
    sin_omega = torch.sin(omega)
    # L'Hopital's rule/LERP fallback for (anti-)parallel vectors, from
    # https://github.com/soumith/dcgan.torch/issues/14.
    if sin_omega == 0:
        return (1.0 - mu) * low + mu * high
    res = (torch.sin((1.0 - mu) * omega) / sin_omega) * low + (
        torch.sin(mu * omega) / sin_omega
    ) * high
    return res
def spherical_interpolation(start, target, num_steps) -> torch.Tensor:
    """
    Perform the interpolation between two points in a specified amount of steps.

    Args:
        start: One point (in latent space), of dimensionality (n,).
        target: Point which is approached during interpolation.
        num_steps: Amount of steps/samples taken during interpolation.

    Returns:
        Tensor: (num_steps x samples.shape).
    """
    # Steps from 1/num_steps to 1 so the final sample coincides with `target`.
    interpolation_steps = torch.linspace(start=1 / num_steps, end=1, steps=num_steps)
    # torch.stack avoids the previous detach-to-NumPy round trip
    # (slerp(...).numpy() + torch.as_tensor), which broke for CUDA tensors
    # and discarded device/dtype information.
    return torch.stack([slerp(mu, start, target) for mu in interpolation_steps])
def latent_walk(
    base_sample: torch.Tensor,
    component: torch.nn.Module,
    walk_range: torch.Tensor,
    device: torch.device,
    latent_dims: Union[int, List],
) -> torch.Tensor:
    """
    Explore the latent space based on a single latent space sample.

    Up to 10 dims are visualized.

    Args:
        base_sample: Initial latent sample of dim [1,1,latent_dim].
        component: The generative module that is used to create new samples.
        walk_range: The area of the latent sample investigated.
        device: The device of the NN module.
        latent_dims: Amount of dims walked through. If more dims exist than
            selected: Use the first n latent_dims. Only up to 10 latent_dims
            are allowed for this visualization.

    Returns:
        A tensor of reconstructed samples where each latent dim is altered in
        the direction of walk_range.
    """
    # An explicit list selects exact dimensions; an int walks the first
    # min(latent_dims, 10) dimensions. isinstance against the builtin `list`
    # replaces the deprecated check against typing.List.
    latent_dim_range = (
        latent_dims if isinstance(latent_dims, list) else range(0, min(latent_dims, 10))
    )
    with torch.no_grad():
        samples = torch.empty(0).to(device)
        for dim in latent_dim_range:
            for offset in walk_range:
                # Perturb one latent dimension at a time, keeping the rest of
                # the base sample fixed.
                base_clone = base_sample.clone()
                base_clone[0][0][dim] = base_clone[0][0][dim] + offset
                gen_samples = component(base_clone)
                # Keep only the first channel of each generated sample.
                samples = torch.cat((samples, gen_samples[:, :, :1]))
    return samples