Source code for dsipts.data_structure.time_series_d2

"""
Time Series D2 Layer Module
This module provides the D2 layer for time series data processing:
- TSDataModule: LightningDataModule for time series data with support for training, validation, and testing
- TimeSeriesSubset: Subset implementation for train/val/test splits
- custom_collate_fn: Custom collate function for handling mixed data types
Key Features:
- Creates sliding windows from time series data
- Handles train/validation/test splits (percentage-based or group-based)
- Validates data points based on minimum valid requirements
- Creates DataLoaders for PyTorch Lightning integration
- Efficiently manages memory with caching mechanisms
"""

import random
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
import pytorch_lightning as pl
import pandas as pd
import numpy as np
from typing import Callable, List, Optional
import logging

# Import the D1 layer
from dsipts.data_structure.time_series_d1 import MultiSourceTSDataSet

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class TSDataModule(pl.LightningDataModule):
    """D2 Layer - Processes time series data for model consumption.

    This module:
    1. Creates sliding windows from time series data
    2. Handles train/validation/test splits
    3. Creates DataLoaders for PyTorch Lightning
    """
    def __init__(
        self,
        d1_dataset: MultiSourceTSDataSet,
        past_len: int,
        future_len: int,
        batch_size: int = 32,
        min_valid_length: Optional[int] = None,
        split_method: str = 'percentage',
        split_config: Optional[tuple] = None,
        num_workers: int = 0,
        sampler: Optional[Callable[[Dataset], Sampler]] = None,
        memory_efficient: bool = False,
        known_cols: Optional[List[str]] = None,
        unknown_cols: Optional[List[str]] = None,
        precompute: bool = True
    ):
        """
        Initialize the TSDataModule.

        Args:
            d1_dataset: The D1 dataset instance (MultiSourceTSDataSet)
            past_len: Number of past time steps for input
            future_len: Number of future time steps for prediction
            batch_size: Batch size for DataLoaders
            min_valid_length: Minimum number of valid points required in a window
                (defaults to past_len)
            split_method: Method for splitting data ('percentage' or 'group')
            split_config: Configuration for the split
            num_workers: Number of workers for DataLoader
            sampler: Optional factory that takes the training dataset and
                returns a Sampler (used instead of shuffling)
            memory_efficient: Whether to use memory-efficient mode
            known_cols: Columns known at prediction time (overrides D1 dataset settings)
            unknown_cols: Columns unknown at prediction time (overrides D1 dataset settings)
            precompute: Whether to precompute valid indices and create datasets
        """
        super().__init__()
        self.d1_dataset = d1_dataset
        self.past_len = past_len
        self.future_len = future_len
        self.batch_size = batch_size
        self.min_valid_length = min_valid_length or past_len
        self.split_method = split_method
        self.split_config = split_config
        self.num_workers = num_workers
        self.sampler = sampler
        self.memory_efficient = memory_efficient

        # Store reference to D1 dataset metadata
        self.metadata = d1_dataset.metadata.copy()

        # Set feature and target columns from D1 dataset
        self.feature_cols = d1_dataset.feature_cols
        self.target_cols = d1_dataset.target_cols
        self.static_cols = d1_dataset.static_cols

        # Override known/unknown columns if provided
        self.known_cols = known_cols if known_cols is not None else d1_dataset.known_cols
        self.unknown_cols = unknown_cols if unknown_cols is not None else d1_dataset.unknown_cols

        # Update metadata with known/unknown columns
        self.metadata['known_cols'] = self.known_cols
        self.metadata['unknown_cols'] = self.unknown_cols

        # Default split configuration based on method
        if split_config is None:
            if split_method == 'percentage':
                self.split_config = (0.7, 0.15, 0.15)  # Default: 70% train, 15% val, 15% test
            else:
                raise ValueError("For 'group' split method, split_config must be provided")

        # Whether to precompute valid indices and create datasets
        self.precompute = precompute

        # Initialize dataset attributes
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

        # Initialize the module
        self._initialize()

        # Add max classes information to metadata
        self._add_max_classes_to_metadata()
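    # Examples of split_config (illustrative values, not library defaults):
    #   percentage split: TSDataModule(d1, past_len=24, future_len=12,
    #                                  split_method='percentage',
    #                                  split_config=(0.7, 0.15, 0.15))
    #   group split:      TSDataModule(d1, past_len=24, future_len=12,
    #                                  split_method='group',
    #                                  split_config=(['store_1', 'store_2'],
    #                                                ['store_3'],
    #                                                ['store_4']))
    # For the 'group' method, each element lists the group ids assigned to
    # train, validation, and test respectively.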
    def _initialize(self):
        """Initialize the dataset by computing valid indices and mappings."""
        logger.info("Computing valid indices and mappings...")
        self.valid_indices = self._compute_valid_indices()
        self.mapping = self._create_global_mapping()
        self.length = len(self.mapping)

        # Create splits if not already done
        if not hasattr(self, 'train_indices') or not self.train_indices:
            if self.split_config is not None:
                logger.info(f"Creating {self.split_method} splits with config: {self.split_config}")
                self.train_indices, self.val_indices, self.test_indices = self._create_splits(self.split_config)
                logger.info(
                    f"Split statistics: Train: {len(self.train_indices)}, "
                    f"Validation: {len(self.val_indices)}, Test: {len(self.test_indices)}"
                )
            else:
                # Default to all indices as training
                self.train_indices = list(range(self.length))
                self.val_indices = []
                self.test_indices = []

        # Create subset datasets for train/val/test if precompute is enabled;
        # otherwise they are created on demand during setup()
        if self.precompute:
            logger.info("Precomputing datasets for train/val/test splits...")
            if self.train_indices:
                self.train_dataset = TimeSeriesSubset(self, self.train_indices)
            if self.val_indices:
                self.val_dataset = TimeSeriesSubset(self, self.val_indices)
            if self.test_indices:
                self.test_dataset = TimeSeriesSubset(self, self.test_indices)
        else:
            logger.info("Datasets will be created on-demand during setup...")

    def _compute_valid_indices(self):
        """
        Compute valid indices for all groups in the dataset.

        A window is valid when:
        1. The past window contains at least min_valid_length valid points
        2. The future window contains at least one valid point

        Optimizations:
        - Vectorized operations for NaN checks
        - Early termination over invalid regions

        Returns:
            Dictionary mapping group indices to lists of valid window start indices
        """
        valid_indices = {}

        for i in range(len(self.d1_dataset)):
            # Load group data from the D1 dataset
            group_data = self.d1_dataset[i]

            # Fetch feature and target tensors
            features = group_data.get('x', torch.tensor([]))
            targets = group_data.get('y', torch.tensor([]))

            # Skip groups that are too short to hold a single window
            ts_length = len(features)
            if ts_length < (self.past_len + self.future_len):
                valid_indices[i] = []
                continue

            # Pre-compute all possible window start indices
            all_indices = list(range(ts_length - (self.past_len + self.future_len) + 1))

            # Validity masks for features and targets (True where not NaN)
            feature_is_valid = ~torch.isnan(features) if features.numel() > 0 else torch.ones((ts_length, 1), dtype=torch.bool)
            target_is_valid = ~torch.isnan(targets) if targets.numel() > 0 else torch.ones((ts_length, 1), dtype=torch.bool)

            # For multi-dimensional data, a time step is valid only if every dimension is valid
            if feature_is_valid.dim() > 1:
                feature_is_valid = feature_is_valid.all(dim=1)
            if target_is_valid.dim() > 1:
                target_is_valid = target_is_valid.all(dim=1)

            # Combined mask for both features and targets
            combined_mask = feature_is_valid & target_is_valid

            group_valid_indices = []

            # Process each potential window
            t = 0
            while t < len(all_indices):
                idx = all_indices[t]
                end_past = idx + self.past_len
                end_future = end_past + self.future_len

                # Past validity: count valid points in the past window
                past_valid_count = combined_mask[idx:end_past].sum().item()
                past_valid = past_valid_count >= self.min_valid_length

                # Early termination: if the past window is invalid, jump ahead
                if not past_valid:
                    # Find the next valid point after the current position
                    next_valid_idx = None
                    for j in range(idx + 1, min(ts_length, idx + self.past_len * 2)):
                        if combined_mask[j]:
                            next_valid_idx = j
                            break

                    if next_valid_idx is not None:
                        # Skip to a position where this valid point falls inside
                        # the past window but not at its end, to maximize the
                        # number of valid points in the window
                        t = max(t + 1, next_valid_idx - self.past_len + 1)
                    else:
                        # No valid points ahead, skip to the end
                        t = len(all_indices)
                    continue

                # Future validity: at least one point must be valid
                future_valid = combined_mask[end_past:end_future].any().item()

                if past_valid and future_valid:
                    group_valid_indices.append(idx)

                t += 1

            valid_indices[i] = group_valid_indices

        # Report statistics
        total_valid = sum(len(indices) for indices in valid_indices.values())
        groups_with_valid = sum(1 for indices in valid_indices.values() if len(indices) > 0)
        logger.info(f"Found {total_valid} valid windows across {groups_with_valid} groups")

        return valid_indices

    def _create_global_mapping(self):
        """
        Create a global mapping from index to (group_idx, local_idx).

        This allows O(1) lookup of samples by global index.

        Returns:
            List of (group_idx, local_idx) tuples
        """
        mapping = []
        groups_with_valid = 0

        for group_idx, local_indices in self.valid_indices.items():
            if local_indices:  # Only add groups with valid indices
                for start_idx in local_indices:
                    mapping.append((group_idx, start_idx))
                groups_with_valid += 1

        logger.info(f"Created global mapping with {len(mapping)} windows from {groups_with_valid} groups")
        return mapping

    def _get_group_data(self, group_idx):
        """
        Get data for a specific group.

        Currently this delegates directly to the D1 dataset; any caching of
        group data is handled inside the D1 layer.

        Args:
            group_idx: Index of the group to retrieve

        Returns:
            Dictionary containing the group data
        """
        return self.d1_dataset[group_idx]

    def _add_max_classes_to_metadata(self):
        """
        Add information about the maximum number of classes for categorical
        features. This is useful for model architecture decisions.
        """
        # Copy max classes information from D1 dataset metadata
        if 'max_classes' in self.d1_dataset.metadata:
            self.metadata['max_classes'] = self.d1_dataset.metadata['max_classes'].copy()

        # Ensure known/unknown column categorization is reflected in metadata
        self.metadata['known_cols'] = self.known_cols
        self.metadata['unknown_cols'] = self.unknown_cols

        # Split known/unknown columns into categorical and numerical
        self.metadata['known_cat_cols'] = [col for col in self.known_cols if col in self.d1_dataset.cat_cols]
        self.metadata['known_num_cols'] = [col for col in self.known_cols if col not in self.d1_dataset.cat_cols]
        self.metadata['unknown_cat_cols'] = [col for col in self.unknown_cols if col in self.d1_dataset.cat_cols]
        self.metadata['unknown_num_cols'] = [col for col in self.unknown_cols if col not in self.d1_dataset.cat_cols]

    def _create_splits(self, split_config):
        """
        Create train/validation/test splits based on the given configuration.

        Args:
            split_config: Configuration for splits:
                - For 'percentage' method: (train_pct, val_pct, test_pct)
                - For 'group' method: (train_groups, val_groups, test_groups)

        Returns:
            Tuple of (train_indices, val_indices, test_indices)
        """
        if self.split_method == 'percentage':
            # Percentage-based split (temporal if possible, otherwise random)
            train_pct, val_pct, test_pct = split_config
            total_samples = len(self.mapping)

            # Normalize percentages if they don't sum to 1
            total_pct = train_pct + val_pct + test_pct
            if abs(total_pct - 1.0) > 1e-6:
                train_pct /= total_pct
                val_pct /= total_pct
                test_pct /= total_pct

            # Number of samples in each split
            train_size = int(total_samples * train_pct)
            val_size = int(total_samples * val_pct)

            all_indices = list(range(total_samples))

            # Try a temporal split: sort windows by their first time point
            try:
                time_values = []
                for idx in all_indices:
                    group_idx, local_idx = self.mapping[idx]
                    group_data = self.d1_dataset[group_idx]

                    # Use the first time point of each window
                    time_point = group_data['t'][local_idx]

                    # Convert to a numeric value for sorting
                    if isinstance(time_point, (str, np.datetime64)):
                        # Convert to timestamp (seconds since epoch)
                        time_numeric = pd.to_datetime(time_point).timestamp()
                    else:
                        time_numeric = float(time_point)
                    time_values.append(time_numeric)

                # Sort indices by time and slice into train/val/test
                sorted_indices = [idx for _, idx in sorted(zip(time_values, all_indices))]
                train_indices = sorted_indices[:train_size]
                val_indices = sorted_indices[train_size:train_size + val_size]
                test_indices = sorted_indices[train_size + val_size:]
                logger.info(
                    f"Temporal percentage-based split - Train: {len(train_indices)}, "
                    f"Val: {len(val_indices)}, Test: {len(test_indices)} samples"
                )
            except (TypeError, ValueError) as e:
                # Fall back to a random split if time-based sorting fails
                logger.warning(f"Could not sort by time ({e}), using random split instead")
                random.shuffle(all_indices)
                train_indices = all_indices[:train_size]
                val_indices = all_indices[train_size:train_size + val_size]
                test_indices = all_indices[train_size + val_size:]
                logger.info(
                    f"Random percentage-based split - Train: {len(train_indices)}, "
                    f"Val: {len(val_indices)}, Test: {len(test_indices)} samples"
                )

        elif self.split_method == 'group':
            # Group-based split
            train_groups, val_groups, test_groups = split_config

            # Convert to sets for O(1) membership tests
            train_groups_set = set(train_groups)
            val_groups_set = set(val_groups)
            test_groups_set = set(test_groups)

            # Assign windows to splits based on group membership
            train_indices = []
            val_indices = []
            test_indices = []
            for idx, (group_idx, _) in enumerate(self.mapping):
                group_id = self.d1_dataset[group_idx]['group_id']
                if group_id in train_groups_set:
                    train_indices.append(idx)
                elif group_id in val_groups_set:
                    val_indices.append(idx)
                elif group_id in test_groups_set:
                    test_indices.append(idx)
                else:
                    # Default to train if the group is not listed
                    train_indices.append(idx)

            logger.info(
                f"Group-based split - Train: {len(train_indices)} (from {len(train_groups)} groups), "
                f"Val: {len(val_indices)} (from {len(val_groups)} groups), "
                f"Test: {len(test_indices)} (from {len(test_groups)} groups) samples"
            )
        else:
            raise ValueError(f"Unknown split method: {self.split_method}")

        return train_indices, val_indices, test_indices
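    # Worked example (illustrative) of the validity check above: with
    # past_len=3, future_len=2, min_valid_length=2 and a combined mask
    #     [True, True, False, True, True, True, False]   (length 7)
    # there are 7 - (3 + 2) + 1 = 3 candidate windows. The window starting at
    # index 0 has 2 valid past points (>= min_valid_length) and a valid point
    # in its future slice [3:5], so it is accepted; starts 1 and 2 are checked
    # the same way.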
    def __len__(self):
        """Return the number of valid samples in the dataset."""
        if self.length is not None:
            return self.length

        # Compute the length on first access if it was not precomputed
        self.valid_indices = self._compute_valid_indices()
        self.mapping = self._create_global_mapping()
        self.length = len(self.mapping)
        return self.length
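    # Mapping sketch (illustrative): if group 0 has valid window starts [0, 1]
    # and group 2 has [5], the global mapping built above is
    #     [(0, 0), (0, 1), (2, 5)]
    # so len(self) == 3 and self[2] resolves to group 2, window starting at 5.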
    def __getitem__(self, idx):
        """
        Get a time series window by global index.

        This method:
        1. Maps the global index to a specific group and local index
        2. Extracts the window from the group data
        3. Returns the window in a format suitable for model training

        Args:
            idx: Global index of the window to retrieve

        Returns:
            Dictionary containing:
            - past_features: Tensor of past features
            - past_time: Array of past time points
            - future_targets: Tensor of future targets
            - future_time: Array of future time points
            - group_id: Group identifier
            - static: Static features tensor
        """
        # Map global index to group and local index
        group_idx, local_idx = self.mapping[idx]

        # Get the group data
        group_data = self._get_group_data(group_idx)

        # Window boundaries
        start_idx = local_idx
        past_end_idx = start_idx + self.past_len
        future_end_idx = past_end_idx + self.future_len

        # Extract past and future windows
        past_features = group_data['x'][start_idx:past_end_idx]
        past_time = group_data['t'][start_idx:past_end_idx]
        future_targets = group_data['y'][past_end_idx:future_end_idx]
        future_time = group_data['t'][past_end_idx:future_end_idx]

        # Static features (may be absent)
        static = group_data.get('st', torch.tensor([]))

        # Return the window as a dictionary
        return {
            'past_features': past_features,
            'past_time': past_time,
            'future_targets': future_targets,
            'future_time': future_time,
            'group_id': group_data['group_id'],
            'static': static
        }
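    # Shape sketch (illustrative, assuming the D1 layer stores 'x' and 'y' as
    # 2-D tensors of shape (time, columns)): a returned sample then satisfies
    #     sample['past_features'].shape  == (past_len, n_features)
    #     sample['future_targets'].shape == (future_len, n_targets)
    # while 'past_time'/'future_time' keep whatever dtype the D1 layer stores
    # in 't' (e.g. numpy datetime64 values).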
    def setup(self, stage=None):
        """
        Prepare data for the given stage.

        Args:
            stage: Either 'fit' or 'test'
        """
        # If valid indices haven't been precomputed yet, do it now
        if not hasattr(self, 'valid_indices') or self.valid_indices is None:
            logger.info("Computing valid indices...")
            self.valid_indices = self._compute_valid_indices()
            self.mapping = self._create_global_mapping()
            self.length = len(self.mapping)

        # Create splits if not already done
        if not hasattr(self, 'train_indices') or not self.train_indices:
            if self.split_config is not None:
                logger.info(f"Creating {self.split_method} splits with config: {self.split_config}")
                self.train_indices, self.val_indices, self.test_indices = self._create_splits(self.split_config)
                logger.info(
                    f"Split statistics: Train: {len(self.train_indices)}, "
                    f"Validation: {len(self.val_indices)}, Test: {len(self.test_indices)}"
                )
            else:
                # Default to all indices as training
                self.train_indices = list(range(self.length))
                self.val_indices = []
                self.test_indices = []

        # If precompute is False, create the subset datasets on demand;
        # otherwise they were already created during initialization
        if not self.precompute:
            if stage == 'fit' or stage is None:
                self.train_dataset = TimeSeriesSubset(self, self.train_indices)
                self.val_dataset = TimeSeriesSubset(self, self.val_indices)
            if stage == 'test' or stage is None:
                self.test_dataset = TimeSeriesSubset(self, self.test_indices)
    def train_dataloader(self):
        """Return a DataLoader for training."""
        if self.sampler is not None:
            # A custom sampler replaces shuffling (they are mutually exclusive
            # in DataLoader)
            return DataLoader(
                self.train_dataset,
                batch_size=self.batch_size,
                sampler=self.sampler(self.train_dataset),
                num_workers=self.num_workers,
                collate_fn=custom_collate_fn
            )
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            collate_fn=custom_collate_fn
        )
    def val_dataloader(self):
        """Return a DataLoader for validation, or None if there is no validation data."""
        if self.val_dataset is None or len(self.val_dataset) == 0:
            return None
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=custom_collate_fn
        )
    def test_dataloader(self):
        """Return a DataLoader for testing, or None if there is no test data."""
        if self.test_dataset is None or len(self.test_dataset) == 0:
            return None
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=custom_collate_fn
        )
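# Illustrative usage sketch (hypothetical, never called by this module):
# wiring a TSDataModule into a PyTorch Lightning run. The argument values are
# placeholders; `model` is any LightningModule compatible with the batches
# produced by custom_collate_fn.
def _example_training_run(d1: MultiSourceTSDataSet, model: pl.LightningModule):
    """Sketch only: build a data module from a D1 dataset and fit a model."""
    dm = TSDataModule(
        d1_dataset=d1,
        past_len=24,      # encoder length (placeholder value)
        future_len=12,    # forecast horizon (placeholder value)
        batch_size=64,
        split_method='percentage',
        split_config=(0.7, 0.15, 0.15),
    )
    trainer = pl.Trainer(max_epochs=10)
    trainer.fit(model, datamodule=dm)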
class TimeSeriesSubset(Dataset):
    """Subset of a D2 data module that implements the Dataset interface."""
    def __init__(self, data_module, indices):
        """
        Initialize the TimeSeriesSubset.

        Args:
            data_module: The TSDataModule instance (stored as a reference, not
                a copy, so all subsets share the same underlying data)
            indices: List of global indices to include in this subset
        """
        self.data_module = data_module
        self.indices = indices
    def __len__(self):
        """Return the number of samples in this subset."""
        return len(self.indices)
    def __getitem__(self, idx):
        """Get a sample from the data module using the mapped index."""
        return self.data_module[self.indices[idx]]
def custom_collate_fn(batch):
    """
    Custom collate function for the DataLoader to handle mixed data types.

    Tensor-valued keys are stacked into a batch tensor; keys holding
    non-tensor data (group ids, time arrays, static features that may be
    objects) are collected into plain lists.
    """
    elem = batch[0]
    result = {}

    # Process each key in the batch. The list covers the keys emitted by
    # TSDataModule.__getitem__ plus raw D1-style keys ('st', 't', 'v').
    for key in elem:
        if key in ('static', 'group_id', 'past_time', 'future_time', 'st', 't', 'v'):
            # Non-tensor data: collect as a list
            result[key] = [sample[key] for sample in batch]
        else:
            # Tensors: stack along a new batch dimension
            try:
                result[key] = torch.stack([sample[key] for sample in batch])
            except (TypeError, RuntimeError):
                # If stacking fails (ragged or non-tensor data), fall back to a list
                result[key] = [sample[key] for sample in batch]

    return result
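# Illustrative sketch (hypothetical sample dicts with the same keys as
# TSDataModule.__getitem__): tensor-valued keys are stacked along a new batch
# dimension, while keys such as 'group_id', 'past_time', 'future_time' and
# 'static' stay Python lists.
def _example_collate():
    """Sketch only: collate two toy samples and inspect the batch layout."""
    def make_sample(gid):
        return {
            'past_features': torch.zeros(24, 3),
            'past_time': np.arange(24),
            'future_targets': torch.zeros(12, 1),
            'future_time': np.arange(24, 36),
            'group_id': gid,
            'static': torch.tensor([1.0]),
        }

    batch = custom_collate_fn([make_sample('a'), make_sample('b')])
    assert batch['past_features'].shape == (2, 24, 3)  # tensors are stacked
    assert batch['group_id'] == ['a', 'b']             # non-tensor keys stay lists
    return batch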