FastEHR.dataloader.foundational_loader
======================================

.. py:module:: FastEHR.dataloader.foundational_loader

Classes
-------

.. autoapisummary::

   FastEHR.dataloader.foundational_loader.FoundationalDataModule
   FastEHR.dataloader.foundational_loader.FoundationalDataset
   FastEHR.dataloader.foundational_loader.Collator

Module Contents
---------------

.. py:class:: FoundationalDataModule(path_to_db: str, path_to_ds: str, load: bool, tokenizer: str = 'tabular', batch_size: int = 64, min_workers: int = 1, overwrite_practice_ids: Optional[str] = None, overwrite_meta_information: Optional[str] = None, supervised: bool = False, supervised_time_scale: float = 1.0, subsample_training: Optional[int] = None, seed: int = 42, adapter: Optional[str] = None, **kwargs)

   PyTorch-Lightning datamodule for foundational models

   ARGS:
       practice_patient_id (list[str])
           List of practice patient identifiers which satisfy study criteria.

   KWARGS:
       batch_size (int):

       unk_freq_threshold (float).

   Initialize the `FoundationalDataModule`, a PyTorch Lightning DataModule for foundational models.

   This module handles data loading, preprocessing, and batching for training, validation, and testing.
   It supports both supervised and unsupervised learning tasks and integrates tokenization,
   meta-information loading, and dataset preparation.

   :param path_to_db: Full path to the SQLite database folder.
   :type path_to_db: str
   :param path_to_ds: Full path to the preprocessed dataset folder, either for loading or saving data.
   :type path_to_ds: str
   :param load: If ``True``, loads preprocessed dataset from Parquet files. If ``False``, processes raw SQLite data and saves it as Parquet files.
   :type load: bool
   :param tokenizer: Which tokenizer to use. ``"tabular"`` uses the Tabular tokenizer; any other string defaults to NonTabular.
   :type tokenizer: str, optional
   :param batch_size: Number of samples per batch for the DataLoader (default ``64``).
   :type batch_size: int, optional
   :param min_workers: Minimum number of workers used for data loading (default ``1``).
   :type min_workers: int, optional
   :param overwrite_practice_ids: Path to a file containing new practice-ID allocations for train/validation/test splits. Prevents data leakage when creating fine-tuning datasets.
   :type overwrite_practice_ids: str | None, optional
   :param overwrite_meta_information: Path to an existing meta-information pickle file. If provided, prevents redundant preprocessing of meta-information.
   :type overwrite_meta_information: str | None, optional
   :param supervised: If ``True``, enables supervised training mode (default ``False``).
   :type supervised: bool, optional
   :param supervised_time_scale: Scaling factor applied to supervised target times produced in the collator (default: 10 years). This multiplies the ``time_scale`` in ``FoundationalDataset``.
   :type supervised_time_scale: float, optional
   :param subsample_training: If specified, reduces the training dataset to a random subset of this size.
   :type subsample_training: int | None, optional
   :param seed: Random seed used when subsampling the training dataset.
   :type seed: int | None, optional
   :param kwargs: Additional keyword arguments passed to ``PolarsDataset.fit()``.
   :type kwargs: Any

   Notes:
       - Loads tokenizer vocabulary from meta-information and initializes dataset splits.
       - Supports dataset reprocessing or direct loading from Parquet files.
       - Uses ``Collator`` for batch collation (supports self-supervised and supervised tasks).
       - Handles training, validation, and test dataset creation with tokenization.

   .. py:property:: is_supervised

   .. py:property:: context_time_scale

   .. py:property:: supervised_time_scale

   .. py:attribute:: batch_size
      :value: 64

   .. py:attribute:: min_workers
      :value: 1

   .. py:attribute:: tokenizer

   .. py:attribute:: train_set

   .. py:attribute:: test_set

   .. py:attribute:: val_set

   .. py:attribute:: adapter

   .. py:attribute:: collate_fn

   .. py:method:: standardise(event, value)

   .. py:method:: unstandardise(event, value)

   .. py:method:: encode(sequence: list[str])

   .. py:method:: decode(sequence: list[int])

   .. py:method:: train_dataloader()

   .. py:method:: val_dataloader()

   .. py:method:: test_dataloader()

.. py:class:: FoundationalDataset(parquet_path: str, split: str, tokenizer: FastEHR.dataloader.tokenizers_local.TokenizerBase, meta_information: dict, file_row_count_dict: dict, max_seq_length: int = 256, standardise_values: bool = True, global_diagnoses: bool = False, repeating_events: bool = True, random_context_window: bool = False, time_scale: float = 1825.0, subsample: Optional[int] = None, seed: int = 42, **kwargs)

   Initialize the `FoundationalDataset`, a dataset class for foundational model training.

   This dataset is constructed from preprocessed Parquet files and provides tokenized and structured
   sequences of events, including dynamic and static features.

   Parameters
   ----------
   parquet_path : str
       Path to the directory containing Parquet dataset files.
   split : str
       Dataset split type (`"train"`, `"val"`, or `"test"`).
   tokenizer : TokenizerBase
       Tokenizer used for encoding event sequences.
   meta_information : dict
       Dictionary containing meta-information about the dataset, including measurement statistics.
   file_row_count_dict : dict
       Dictionary mapping Parquet filenames to the number of rows they contain.
   max_seq_length : int, optional, default=256
       The maximum number of tokens in a sequence. Longer sequences are truncated.
   standardise_values : bool, optional, default=True
       Whether to standardize event values based on dataset statistics.
   global_diagnoses : bool, optional, default=False
       If True, ensures all historical diagnoses are included in each sequence's context window.
   repeating_events : bool, optional, default=True
       Whether to allow repeated events within a sequence.

       - `True`: Retains all occurrences of an event.
       - `False`: Keeps only the latest record of each event.
         These may still fall outside of a context window.
   random_context_window : bool, optional, default=False
       Whether to randomly sample context windows or use the latest events.
   time_scale : float, optional, default=1825.0
       The scaling factor applied to age values (default: 5 years, so that a model using the unit
       interval looks 5 years ahead).
   subsample : int, optional
       If specified, limits the dataset to a random subsample of this size.

   Notes
   -----
   - The dataset loads preprocessed patient event sequences and encodes them using the tokenizer.
   - Supports both fixed-length and randomly sampled context windows.
   - Static covariates are one-hot encoded and stored separately from dynamic sequences.
   - The dataset length is computed based on the sum of row counts across Parquet files.

   .. py:method:: view_sample(idx, max_dynamic_events=None, report_time=False)

      Displays a sample in a readable format.

   .. py:property:: meta_static

   .. py:property:: meta_measurement

   .. py:attribute:: parquet_path

   .. py:attribute:: sub_dir
      :value: 'split=Uninferable/'

   .. py:attribute:: tokenizer

   .. py:attribute:: max_seq_length
      :value: 256

   .. py:attribute:: standardise_values
      :value: True

   .. py:attribute:: global_diagnoses
      :value: False

   .. py:attribute:: repeating_events
      :value: True

   .. py:attribute:: random_context_window
      :value: False

   .. py:attribute:: meta_information

   .. py:attribute:: time_scale
      :value: 1825.0

   .. py:attribute:: subsample
      :value: None

   .. py:attribute:: seed
      :value: 42

   .. py:attribute:: warnings_raised
      :value: []

   .. py:attribute:: file_row_count_dict

   .. py:attribute:: total_samples

   .. py:attribute:: static_1hot

   .. py:method:: getitem(idx)

.. py:class:: Collator(supervised=False, supervised_time_scale=2.0, adapter=None)

   supervised:
       Whether to take the last time point as the target.
   supervised_time_scale: float, optional, default 2.0
       The scaling factor applied to any supervised target times produced in the Collator (default: 2 × 5 years).
   Note: The collator receives the times from `FoundationalDataset()`, which are already scaled, so the
   two scaling factors combine multiplicatively.
   (TODO: put all scaling in the Collator to simplify code and API)

   .. py:attribute:: supervised
      :value: False

   .. py:attribute:: supervised_time_scale
      :value: 2.0

   .. py:attribute:: adapter
      :value: None

   .. py:method:: collate_fn(data: list[dict])

      Collect and collate separate dictionaries.

      During this operation, pad the sequences to the maximum length seen within the batch and tokenize.

   .. py:method:: convert_to_supervised(batch, supervised_time_scale)
      :staticmethod:

      Convert a batch to a supervised format for non-causal tasks.

      This method is used in conjunction with `FoundationalDataModule` for non-causal tasks, where the last
      non-padding token in each sequence is removed and used as a supervised target.

      Specifically, this function:

      - Replaces the last non-padding token in each row with a padding token (0).
      - Creates new target vectors containing the removed tokens, values, and age deltas.
      - Ensures alignment of token sequences, masking, and corresponding values.

      Parameters
      ----------
      batch : dict
          A dictionary containing the following keys:

          - **tokens** (*torch.Tensor*): The input tensor containing tokenized sequences with padding.
          - **ages** (*torch.Tensor*): The tensor containing ages corresponding to each token in the matrix.
          - **values** (*torch.Tensor*): The tensor containing values corresponding to each token in the matrix.
          - **attention_mask** (*torch.Tensor*): The tensor containing masks indicating valid tokens.

      Returns
      -------
      dict
          A modified batch dictionary containing the following keys:

          - **tokens** (*torch.Tensor*): The input sequence with the last non-padding token replaced by a padding token (0).
          - **ages** (*torch.Tensor*): The modified age matrix with the last non-padding age replaced by 0.
          - **values** (*torch.Tensor*): The modified value matrix with the last non-padding value replaced with `np.nan`.
          - **attention_mask** (*torch.Tensor*): The modified mask matrix with the last non-padding entry set to 0.
          - **target_token** (*torch.Tensor*): A vector containing the removed tokens.
          - **target_age_delta** (*torch.Tensor*): A vector containing the difference in age between the last two non-padding tokens.
          - **target_value** (*torch.Tensor*): A vector containing the values that were removed from `values`.

      Notes
      -----
      - If a sample does not contain at least two non-padding events, a warning is logged, and the sample is removed.
      - This function prevents information leakage by ensuring that the last event in a sequence is not used as an input.
      - If `convert_to_supervised()` is applied multiple times to the same batch, it will be skipped to prevent redundant modifications.
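
The transformation described for ``convert_to_supervised`` can be illustrated with a small, dependency-free sketch. This is not the FastEHR implementation. It uses plain Python lists instead of ``torch.Tensor``, the function name ``to_supervised_sketch`` is hypothetical, and dividing the age delta by ``supervised_time_scale`` is an assumption (the documentation above only says the factor is applied to target times). The logged warning and the skip-on-repeat behaviour are omitted for brevity.

```python
import math

PAD = 0  # padding token id, as stated in the documentation above


def to_supervised_sketch(batch, supervised_time_scale=1.0):
    """Illustrative list-based version; NOT the FastEHR implementation."""
    keys = ("tokens", "ages", "values", "attention_mask",
            "target_token", "target_age_delta", "target_value")
    out = {key: [] for key in keys}
    rows = zip(batch["tokens"], batch["ages"],
               batch["values"], batch["attention_mask"])
    for tokens, ages, values, mask in rows:
        # Positions of real (non-padding) events in this row.
        valid = [i for i, m in enumerate(mask) if m == 1]
        if len(valid) < 2:
            # Fewer than two events: no (input, target) pair, drop the sample.
            continue
        last, prev = valid[-1], valid[-2]
        # Targets: the removed token and value, plus the age gap between
        # the last two events (division by the scale is an assumption).
        out["target_token"].append(tokens[last])
        out["target_value"].append(values[last])
        out["target_age_delta"].append(
            (ages[last] - ages[prev]) / supervised_time_scale)
        # Inputs: copy the row and blank out the last real position.
        tokens, ages = list(tokens), list(ages)
        values, mask = list(values), list(mask)
        tokens[last], ages[last] = PAD, 0
        values[last], mask[last] = math.nan, 0
        out["tokens"].append(tokens)
        out["ages"].append(ages)
        out["values"].append(values)
        out["attention_mask"].append(mask)
    return out


# Two samples: the first has three events, the second only one (dropped).
batch = {
    "tokens": [[5, 7, 9, 0], [4, 0, 0, 0]],
    "ages": [[0.25, 0.5, 1.0, 0.0], [0.5, 0.0, 0.0, 0.0]],
    "values": [[math.nan, 1.5, -0.5, math.nan],
               [2.0, math.nan, math.nan, math.nan]],
    "attention_mask": [[1, 1, 1, 0], [1, 0, 0, 0]],
}
result = to_supervised_sketch(batch, supervised_time_scale=2.0)
print(result["tokens"], result["target_token"], result["target_age_delta"])
```

In this example the first sample keeps tokens ``[5, 7, 0, 0]`` as input and yields token ``9`` as its target, while the single-event sample is removed, mirroring the first note above.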