"""Dataset base class extending the PyTorch Dataset class specifications."""
import random
from abc import ABC, abstractmethod
from typing import Dict
from torch import Tensor, stack
from torch.utils.data import Dataset
class BaseDataset(Dataset, ABC):
    """Extend the PyTorch Dataset class with explicit sampling and a mandatory __len__ method."""

    def sample(self, batch_size: int) -> Dict:
        """
        Sample a batch directly from the dataset.

        Args:
            batch_size: Number of samples to draw.

        Returns:
            Dict containing the samples and their attributes.
        """
        indices = random.sample(range(len(self)), batch_size)
        sample_dicts = [self[idx] for idx in indices]
        # Collate the per-sample dicts into one dict of lists, keyed by attribute.
        collated_samples: Dict = {key: [] for key in sample_dicts[0].keys()}
        for sample_dict in sample_dicts:
            for key, val in sample_dict.items():
                collated_samples[key].append(val)
        # Stack each attribute's list of tensors into a single batched tensor.
        return {key: stack(val) for key, val in collated_samples.items()}

    @abstractmethod
    def __len__(self) -> int:
        """
        Return number of samples in dataset.

        Returns:
            Number of samples in dataset.
        """
        raise NotImplementedError("Dataset needs to implement the `__len__` method.")
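
# Illustrative sketch, not part of the original module: a minimal concrete subclass showing
# the contract that `sample` relies on, namely `__len__` for drawing random indices and a
# `__getitem__` returning a dict of equally shaped tensors that `stack` can batch.
# The name `_ToyDataset` is hypothetical and exists purely for demonstration.
class _ToyDataset(BaseDataset):
    """Toy dataset wrapping a single tensor attribute."""

    def __init__(self, data: Tensor):
        self.data = data

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict:
        return {'data': self.data[idx]}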
class SeriesDataset(BaseDataset):
    """PyTorch Dataset for time series preprocessed with :code:`ecgan-preprocess`."""

    def __init__(self, data: Tensor, label: Tensor):
        """Hold the preprocessed data and labels in memory as tensors and derive the number of classes."""
        self.data = data
        self.label = label
        self.num_classes = len(self.label.unique())

    def __len__(self) -> int:
        """
        Return number of samples in dataset.

        Returns:
            Number of samples in dataset.
        """
        return len(self.label)

    def __getitem__(self, idx: int) -> Dict:
        """
        Given an index, return the corresponding data-label pair.

        Args:
            idx: Index of the entry in the dataset.

        Returns:
            Dict with the respective time series and its label.
        """
        return {'data': self.data[idx], 'label': self.label[idx]}
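
# Usage sketch under assumed shapes (not part of the original module): build a SeriesDataset
# from random tensors and draw a collated batch via `sample`. The shapes, the binary labels,
# and the batch size below are made up purely for illustration.
if __name__ == "__main__":
    import torch

    series = torch.randn(100, 1, 64)      # 100 univariate series of length 64
    labels = torch.randint(0, 2, (100,))  # binary labels, one per series
    dataset = SeriesDataset(series, labels)

    batch = dataset.sample(batch_size=8)
    print(batch['data'].shape)   # torch.Size([8, 1, 64])
    print(batch['label'].shape)  # torch.Size([8])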