# Source code for pnpl.preprocessing.serialization

"""
Serialization utilities for converting MNE data to H5 format.

Provides functions to convert:
- Continuous Raw data to H5 (LibriBrain-style)
- Epoched data to H5 (MegNIST-style)
"""

import os
from typing import Optional, TYPE_CHECKING
import numpy as np
import h5py

if TYPE_CHECKING:
    import mne


def fif_to_h5(
    raw: "mne.io.Raw",
    output_path: str,
    dtype: np.dtype = np.float32,
    chunk_size: int = 50,
    compression: Optional[str] = None,
    compression_opts: int = 4,
) -> str:
    """
    Convert MNE Raw data to H5 format.

    Creates an H5 file with structure:
    - data: (channels, time) - MEG data
    - times: (time,) - Time vector
    - Attributes: sample_frequency, highpass_cutoff, lowpass_cutoff,
      channel_names, channel_types (the last two are stored as a single
      ", "-joined string each)

    Args:
        raw: MNE Raw object (preprocessed)
        output_path: Output H5 file path. May be a bare file name, in which
            case the file is written to the current working directory.
        dtype: Data type for storage (default: float32)
        chunk_size: Chunk size along the time axis for H5 datasets. It is
            clamped to the number of samples when the recording is shorter
            than one chunk.
        compression: Compression algorithm ('gzip' or None)
        compression_opts: Compression level (1-9)

    Returns:
        Path to created H5 file

    Raises:
        ValueError: If ``raw.info['bads']`` is not empty.
    """
    import mne

    # Check for bad channels
    if raw.info['bads']:
        raise ValueError(f"Raw data contains bad channels: {raw.info['bads']}")

    # Extract data (MEG channels only)
    times = raw.times.astype(dtype)
    meg_picks = mne.pick_types(raw.info, meg=True, eeg=False, eog=False)
    data = raw.get_data(picks=meg_picks).astype(dtype)

    # Get channel info
    channel_names = [raw.ch_names[idx] for idx in meg_picks]
    # MNE's channel type helper moved between versions; prefer the stable APIs.
    # Only fall back on the errors an API mismatch would produce (missing
    # method / unsupported ``picks`` argument) so genuine failures surface.
    try:
        channel_types = raw.get_channel_types(picks=meg_picks)
    except (AttributeError, TypeError):
        channel_types = [mne.channel_type(raw.info, idx) for idx in meg_picks]

    # Create output directory. dirname() is "" for a bare file name and
    # os.makedirs("") raises, so only create it when there is one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # h5py rejects a chunk shape that exceeds the shape of a fixed-size
    # dataset, so never ask for more samples per chunk than actually exist.
    data_chunks = (data.shape[0], min(chunk_size, data.shape[1]))
    times_chunks = (min(chunk_size, times.shape[0]),)

    # Compression settings are only forwarded when compression is requested.
    storage_kwargs = {}
    if compression:
        storage_kwargs["compression"] = compression
        storage_kwargs["compression_opts"] = compression_opts

    # Write H5 file
    with h5py.File(output_path, "w") as f:
        # Data datasets
        f.create_dataset("data", data=data, chunks=data_chunks, **storage_kwargs)
        f.create_dataset("times", data=times, chunks=times_chunks, **storage_kwargs)

        # Metadata attributes
        f.attrs["sample_frequency"] = raw.info["sfreq"]
        f.attrs["highpass_cutoff"] = raw.info["highpass"]
        f.attrs["lowpass_cutoff"] = raw.info["lowpass"]
        f.attrs["channel_names"] = ", ".join(channel_names)
        f.attrs["channel_types"] = ", ".join(channel_types)

    return output_path
def epochs_to_h5(
    epochs: "mne.Epochs",
    output_path: str,
    dtype: np.dtype = np.float32,
    compression: Optional[str] = None,
    compression_opts: int = 4,
) -> str:
    """
    Convert MNE Epochs to H5 format.

    Creates an H5 file with structure:
    - data: (trials, channels, time) - Epoched MEG data
    - labels: (trials,) - Event labels
    - times: (time,) - Time vector
    - channel_names: (channels,) - Channel names
    - channel_types: (channels,) - Channel types
    - sensor_xyz: (channels, 3) - Sensor positions

    Only MEG channels are exported, and the channel axis of ``data`` is in
    the same order as ``channel_names``, ``channel_types`` and
    ``sensor_xyz``.

    Args:
        epochs: MNE Epochs object
        output_path: Output H5 file path. May be a bare file name, in which
            case the file is written to the current working directory.
        dtype: Data type for storage (default: float32)
        compression: Compression algorithm ('gzip' or None)
        compression_opts: Compression level (1-9)

    Returns:
        Path to created H5 file

    Raises:
        ValueError: If the epochs contain no MEG channels.
    """
    import mne

    # Select MEG channels first so that the data and all per-channel
    # metadata below are built from the same picks.
    meg_picks = mne.pick_types(epochs.info, meg=True, eeg=False, eog=False)
    if len(meg_picks) == 0:
        raise ValueError("Epochs contain no MEG channels")

    # Extract data
    data = epochs.get_data(picks=meg_picks).astype(dtype)  # (trials, channels, time)
    times = epochs.times.astype(dtype)

    # Get labels from events (third column is event ID)
    labels = epochs.events[:, 2].astype(np.int32)

    # Get channel info
    channel_names = [epochs.ch_names[idx] for idx in meg_picks]
    # Only fall back on the errors an API mismatch would produce (missing
    # method / unsupported ``picks`` argument) so genuine failures surface.
    try:
        channel_types = epochs.get_channel_types(picks=meg_picks)
    except (AttributeError, TypeError):
        channel_types = [mne.channel_type(epochs.info, idx) for idx in meg_picks]

    # Get sensor positions (first three entries of each channel's 'loc')
    locs = np.array([epochs.info['chs'][idx]['loc'][:3] for idx in meg_picks])

    # Channel info is stored as fixed-width bytes for compatibility. Keep the
    # historical minimum widths (10 / 15) but widen the field when needed so
    # longer names are not silently truncated.
    name_bytes = [name.encode("utf-8") for name in channel_names]
    type_bytes = [str(kind).encode("utf-8") for kind in channel_types]
    name_width = max([10] + [len(item) for item in name_bytes])
    type_width = max([15] + [len(item) for item in type_bytes])

    # Create output directory. dirname() is "" for a bare file name and
    # os.makedirs("") raises, so only create it when there is one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Compression settings are only forwarded when compression is requested.
    storage_kwargs = {}
    if compression:
        storage_kwargs["compression"] = compression
        storage_kwargs["compression_opts"] = compression_opts

    # Write H5 file
    with h5py.File(output_path, "w") as f:
        # Main datasets
        f.create_dataset("data", data=data, **storage_kwargs)
        f.create_dataset("times", data=times, **storage_kwargs)
        f.create_dataset("labels", data=labels)

        # Channel info (stored as bytes for compatibility)
        f.create_dataset(
            "channel_names", data=np.array(name_bytes, dtype=f"S{name_width}")
        )
        f.create_dataset(
            "channel_types", data=np.array(type_bytes, dtype=f"S{type_width}")
        )
        f.create_dataset("sensor_xyz", data=locs.astype(dtype))

    return output_path