"""Custom torch layers for neural architectures."""
from functools import partial
from typing import Optional, cast
import torch
from torch import Tensor, nn
from torch.distributions.normal import Normal
from torch.nn import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from ecgan.config import NormalInitializationConfig, UniformInitializationConfig, WeightInitializationConfig
from ecgan.utils.custom_types import WeightInitialization
class MinibatchDiscrimination(nn.Module):
    """Minibatch discrimination layer based on https://gist.github.com/t-ae/732f78671643de97bbe2c46519972491."""
    def __init__(
        self,
        in_features: int,
        out_features: int,
        kernel_dims: int = 16,
        calc_mean: bool = False,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.kernel_dims = kernel_dims
        self.calc_mean = calc_mean
        self.t_mat = nn.Parameter(Normal(0, 1).sample((in_features, out_features, kernel_dims)))
    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of the Minibatch Discriminator."""
        # x is NxA
        # T is AxBxC
        out = x.mm(self.t_mat.view(self.in_features, -1))
        out = out.view(-1, self.out_features, self.kernel_dims)
        out = out.unsqueeze(0)  # 1xNxBxC
        out_perm = out.permute(1, 0, 2, 3)  # Nx1xBxC
        norm = torch.abs(out - out_perm).sum(3)  # NxNxB
        exp_norm = torch.exp(-norm)
        o_b = exp_norm.sum(0) - 1  # NxB, subtract self distance
        if self.calc_mean:
            o_b /= x.shape[0] - 1
        x = torch.cat([x, o_b], dim=1)
        return x  
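# Illustrative usage sketch (not part of the original module; the feature sizes below are
# assumptions): the layer appends ``out_features`` minibatch similarity statistics to each sample.
def _example_minibatch_discrimination() -> Tensor:
    mbd = MinibatchDiscrimination(in_features=128, out_features=32, kernel_dims=16)
    features = torch.randn(8, 128)  # batch of N=8 feature vectors (NxA)
    return mbd(features)  # shape (8, 160): 128 original features plus 32 minibatch statistics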
class MinibatchDiscriminationSimple(nn.Module):
    """From `Karras et al. 2018 <https://arxiv.org/pdf/1710.10196.pdf>`_."""
    @staticmethod
    def forward(x: Tensor) -> Tensor:
        """Forward pass of the Minibatch Discriminator."""
        out_stds = torch.std(x, dim=0)
        out = torch.mean(out_stds).unsqueeze(0)
        return out.expand(x.shape[0], 1).detach()  
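# Illustrative usage sketch (not part of the original module): the averaged standard deviation is
# broadcast over the batch and can be concatenated to the discriminator features, following the
# minibatch stddev idea of Karras et al.
def _example_minibatch_discrimination_simple() -> Tensor:
    features = torch.randn(8, 128)  # batch of N=8 feature vectors
    stddev_feature = MinibatchDiscriminationSimple()(features)  # shape (8, 1), identical value per row
    return torch.cat([features, stddev_feature], dim=1)  # shape (8, 129)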
def initialize_weights(
    network: nn.Module,
    init_config: WeightInitializationConfig,
) -> None:
    """
    Initialize weights of a Torch architecture.
    Currently supported are:
        - 'normal': Sampling from a normal distribution. Parameters: mean, std
        - 'uniform': Sampling from a uniform distribution. Parameters: upper_bound,
           lower_bound
        - 'he': He initialization . He, K. et al. (2015)
        - 'glorot': Glorot, X. & Bengio, Y. (2010)
    Biases and BatchNorm are not initialized with this function as different strategies are applicable for these
    tensors/layers. Therefore the standard initialization of PyTorch when creating the layers is taken in these cases.
    """
    weight_init = partial(_initialize_weights, init_cfg=init_config)
    network.apply(weight_init) 
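# Illustrative usage sketch (not part of the original module; how the configuration object is built
# depends on ecgan.config and is assumed to be provided by the caller here):
def _example_initialize_weights(init_config: WeightInitializationConfig) -> nn.Module:
    model = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 1))
    initialize_weights(model, init_config)  # applies the configured scheme to all non-normalization layers
    return model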
def _initialize_weights(layer: nn.Module, init_cfg: WeightInitializationConfig) -> None:
    """
    Initialize the weights of a given layer.
    Args:
        layer: Layer to initialize weights in.
        init_cfg: Configuration for weight initialization.
    """
    if is_normalization_layer(layer):
        return
    if init_cfg.weight_init_type == WeightInitialization.NORMAL:
        normal_cfg = cast(NormalInitializationConfig, init_cfg)
        _init_normal(layer, mean=normal_cfg.MEAN, std=normal_cfg.STD)
    elif init_cfg.weight_init_type == WeightInitialization.UNIFORM:
        uniform_cfg = cast(UniformInitializationConfig, init_cfg)
        _init_uniform(layer, lower_bound=uniform_cfg.LOWER_BOUND, upper_bound=uniform_cfg.UPPER_BOUND)
    elif init_cfg.weight_init_type == WeightInitialization.HE:
        _init_he(layer)
    elif init_cfg.weight_init_type == WeightInitialization.GLOROT_UNIFORM:
        _init_glorot_uniform(layer)
    elif init_cfg.weight_init_type == WeightInitialization.GLOROT_NORMAL:
        _init_glorot_normal(layer)
    else:
        raise ValueError('Initialization "{}" is not known.'.format(init_cfg.NAME))
def initialize_batchnorm(module: nn.Module, **kwargs) -> None:
    """Explicitly initialize batchnorm layers with a normal distribution."""
    for layer in module.modules():
        if isinstance(layer, _BatchNorm):
            layer.weight.data.normal_(kwargs.get('mean', 1.0), kwargs.get('std', 0.02))
            layer.bias.data.fill_(kwargs.get('bias', 0)) 
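# Illustrative usage sketch (not part of the original module; the architecture is made up): explicitly
# re-initialize the affine BatchNorm parameters after the general weight initialization.
def _example_initialize_batchnorm() -> nn.Module:
    model = nn.Sequential(nn.Conv1d(1, 8, kernel_size=3), nn.BatchNorm1d(8), nn.ReLU())
    initialize_batchnorm(model, mean=1.0, std=0.02, bias=0.0)
    return model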
def _init_normal(module: nn.Module, mean: Optional[float] = None, std: Optional[float] = None) -> None:
    """Initialize a nn.Module by sampling from a normal distribution."""
    mean = 0.0 if mean is None else mean
    std = 0.02 if std is None else std
    if hasattr(module, 'weight'):
        nn.init.normal_(module.weight, mean, std)  # type: ignore
def _init_uniform(
    module: nn.Module,
    lower_bound: Optional[float] = None,
    upper_bound: Optional[float] = None,
) -> None:
    """Initialize a nn.Module by sampling from a uniform distribution."""
    lower_bound = 0.0 if lower_bound is None else lower_bound
    upper_bound = 1.0 if upper_bound is None else upper_bound
    if hasattr(module, 'weight'):
        nn.init.uniform_(module.weight, lower_bound, upper_bound)  # type: ignore
def _init_he(module: nn.Module) -> None:
    """Initialize a nn.Module with He initialization."""
    if hasattr(module, 'weight'):
        nn.init.kaiming_normal_(module.weight)
def _init_glorot_uniform(module: nn.Module) -> None:
    """Initialize a nn.Module with Glorot initialization."""
    if hasattr(module, 'weight'):
        nn.init.xavier_uniform_(module.weight)  # type: ignore
def _init_glorot_normal(layer: nn.Module) -> None:
    """Initialize a nn.Module with Glorot initialization."""
    if hasattr(layer, 'weight'):
        nn.init.xavier_normal_(layer.weight)  # type: ignore
def is_normalization_layer(module: nn.Module) -> bool:
    """Check whether a module is a normalization layer (BatchNorm or GroupNorm)."""
    return isinstance(module, (_BatchNorm, GroupNorm))
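# Illustrative sketch (not part of the original module): _initialize_weights returns early for these
# layers, so normalization parameters keep their defaults unless initialize_batchnorm is called.
def _example_is_normalization_layer() -> None:
    assert is_normalization_layer(nn.BatchNorm1d(8))
    assert is_normalization_layer(GroupNorm(num_groups=2, num_channels=8))
    assert not is_normalization_layer(nn.Linear(4, 4))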