FastEHR.dataloader.foundational_loader
Classes
- FoundationalDataModule: PyTorch-Lightning datamodule for foundational models.
- FoundationalDataset: dataset class for foundational model training.
- Collator: collates samples into batches for self-supervised and supervised tasks.
Module Contents
- class FastEHR.dataloader.foundational_loader.FoundationalDataModule(path_to_db: str, path_to_ds: str, load: bool, tokenizer: str = 'tabular', batch_size: int = 64, min_workers: int = 1, overwrite_practice_ids: str | None = None, overwrite_meta_information: str | None = None, supervised: bool = False, supervised_time_scale: float = 1.0, subsample_training: int | None = None, seed: int = 42, adapter: str | None = None, **kwargs)
PyTorch-Lightning datamodule for foundational models
- ARGS:
practice_patient_id (list[str]): list of practice patient identifiers which satisfy study criteria.
- KWARGS:
batch_size (int); unk_freq_threshold (float).
Initialize the FoundationalDataModule, a PyTorch Lightning DataModule for foundational models.
This module handles data loading, preprocessing, and batching for training, validation, and testing. It supports both supervised and unsupervised learning tasks and integrates tokenization, meta-information loading, and dataset preparation.
- Parameters:
path_to_db (str) – Full path to the SQLite database folder.
path_to_ds (str) – Full path to the preprocessed dataset folder, used either for loading or for saving data.
load (bool) – If True, loads the preprocessed dataset from Parquet files. If False, processes the raw SQLite data and saves it as Parquet files.
tokenizer (str, optional) – Which tokenizer to use (default "tabular"). "tabular" uses the Tabular tokenizer; any other string selects NonTabular.
batch_size (int, optional) – Number of samples per batch for the DataLoader (default 64).
min_workers (int, optional) – Minimum number of workers used for data loading (default 1).
overwrite_practice_ids (str | None, optional) – Path to a file containing new practice-ID allocations for the train/validation/test splits. Prevents data leakage when creating fine-tuning datasets.
overwrite_meta_information (str | None, optional) – Path to an existing meta-information pickle file. If provided, avoids redundant preprocessing of meta-information.
supervised (bool, optional) – If True, enables supervised training mode (default False).
supervised_time_scale (float, optional) – Scaling factor applied to supervised target times produced in the collator (default 1.0). This factor multiplies the time_scale in FoundationalDataset.
subsample_training (int | None, optional) – If specified, reduces the training dataset to a random subset of this size.
seed (int, optional) – Random seed used when subsampling the training dataset (default 42).
kwargs (Any) – Additional keyword arguments passed to PolarsDataset.fit().
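The leakage argument behind overwrite_practice_ids is that every practice belongs to exactly one split. The sketch below illustrates that idea with a hypothetical `allocate_practices` helper. It does not read the allocation file FastEHR expects (its format is not documented here), and the 80/10/10 fractions are an assumption.

```python
import random


def allocate_practices(practice_ids, fractions=(0.8, 0.1, 0.1), seed=42):
    """Hypothetical helper: assign each practice (not each patient) to exactly one split."""
    ids = sorted(set(practice_ids))  # deterministic order before shuffling
    rng = random.Random(seed)        # local RNG: the global random state is untouched
    rng.shuffle(ids)
    n_train = int(fractions[0] * len(ids))
    n_val = int(fractions[1] * len(ids))
    return {
        "train": set(ids[:n_train]),
        "val": set(ids[n_train:n_train + n_val]),
        "test": set(ids[n_train + n_val:]),  # the remainder goes to test
    }


splits = allocate_practices([f"p{i}" for i in range(10)])
print({name: len(ids) for name, ids in splits.items()})  # {'train': 8, 'val': 1, 'test': 1}
```

Reusing the returned mapping when building a fine-tuning dataset keeps each practice in the split it was assigned during pre-training. This is the role the allocation file plays.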
Notes:
- Loads the tokenizer vocabulary from meta-information and initializes the dataset splits.
- Supports dataset reprocessing or direct loading from Parquet files.
- Uses Collator for batch collation (supports self-supervised and supervised tasks).
- Handles training, validation, and test dataset creation with tokenization.

- property is_supervised
- property context_time_scale
- property supervised_time_scale
- batch_size = 64
- min_workers = 1
- tokenizer
- train_set
- test_set
- val_set
- adapter
- collate_fn
- standardise(event, value)
- unstandardise(event, value)
- encode(sequence: list[str])
- decode(sequence: list[int])
- train_dataloader()
- val_dataloader()
- test_dataloader()
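standardise() and unstandardise() are listed without a description. Assuming they perform z-score standardisation with per-event statistics taken from meta_information, the pair would behave like the inverse mappings below. Both the z-score choice and the statistics layout are assumptions. `EXAMPLE_STATS` is a hypothetical stand-in for the meta-information, and passing values through unchanged for events without statistics is also assumed.

```python
# Hypothetical per-event statistics; the real structure of meta_information may differ.
EXAMPLE_STATS = {"bmi": {"mean": 27.0, "std": 5.0}}


def standardise(event, value, stats=EXAMPLE_STATS):
    """Z-score a value with its event's statistics (assumed behaviour)."""
    s = stats.get(event)
    if s is None or not s["std"]:
        return value  # assumption: events without usable statistics pass through
    return (value - s["mean"]) / s["std"]


def unstandardise(event, value, stats=EXAMPLE_STATS):
    """Invert standardise(): map a z-score back to the original units."""
    s = stats.get(event)
    if s is None or not s["std"]:
        return value
    return value * s["std"] + s["mean"]


z = standardise("bmi", 32.0)
print(z, unstandardise("bmi", z))  # 1.0 32.0
```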
- class FastEHR.dataloader.foundational_loader.FoundationalDataset(parquet_path: str, split: str, tokenizer: FastEHR.dataloader.tokenizers_local.TokenizerBase, meta_information: dict, file_row_count_dict: dict, max_seq_length: int = 256, standardise_values: bool = True, global_diagnoses: bool = False, repeating_events: bool = True, random_context_window: bool = False, time_scale: float = 1825.0, subsample: int | None = None, seed: int = 42, **kwargs)
Initialize the FoundationalDataset, a dataset class for foundational model training.
This dataset is constructed from preprocessed Parquet files and provides tokenized and structured sequences of events, including dynamic and static features.
Parameters
- parquet_path : str
Path to the directory containing the Parquet dataset files.
- split : str
Dataset split type (“train”, “val”, or “test”).
- tokenizer : TokenizerBase
Tokenizer used for encoding event sequences.
- meta_information : dict
Dictionary containing meta-information about the dataset, including measurement statistics.
- file_row_count_dict : dict
Dictionary mapping Parquet filenames to the number of rows they contain.
- max_seq_length : int, optional, default=256
The maximum number of tokens in a sequence. Longer sequences are truncated.
- standardise_values : bool, optional, default=True
Whether to standardise event values based on dataset statistics.
- global_diagnoses : bool, optional, default=False
If True, ensures all historical diagnoses are included in each sequence’s context window.
- repeating_events : bool, optional, default=True
Whether to allow repeated events within a sequence. True retains all occurrences of an event; False keeps only the latest record of each event, which may still fall outside the context window.
- random_context_window : bool, optional, default=False
Whether to randomly sample context windows instead of using the latest events.
- time_scale : float, optional, default=1825.0
The scaling factor applied to age values. The default of 1825 days (5 years) means a model using the unit interval looks 5 years ahead.
- subsample : int, optional
If specified, limits the dataset to a random subsample of this size.
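The interaction of max_seq_length, repeating_events, random_context_window and time_scale can be sketched on a single patient's event list. This is a simplified illustration, not the FastEHR implementation. `build_context_window` is hypothetical, events are assumed to be (token, age_in_days) tuples in chronological order, and global_diagnoses is not modelled.

```python
import random


def build_context_window(events, max_seq_length=256, repeating_events=True,
                         random_context_window=False, time_scale=1825.0, seed=42):
    """Hypothetical sketch: select and rescale one patient's events.

    events: list of (token, age_in_days) tuples in chronological order.
    """
    if not repeating_events:
        # Keep only the latest record of each token, preserving chronological order.
        last_index = {tok: i for i, (tok, _) in enumerate(events)}
        events = [ev for i, ev in enumerate(events) if last_index[ev[0]] == i]
    if len(events) > max_seq_length:
        if random_context_window:
            # Sample the start of a contiguous window uniformly at random.
            start = random.Random(seed).randint(0, len(events) - max_seq_length)
        else:
            # Default behaviour: keep the most recent events.
            start = len(events) - max_seq_length
        events = events[start:start + max_seq_length]
    # Rescale ages so that 1.0 corresponds to time_scale days (1825 days = 5 years).
    return [(tok, age / time_scale) for tok, age in events]


history = [("A", 0.0), ("B", 365.0), ("A", 730.0), ("C", 1825.0)]
print(build_context_window(history, max_seq_length=2, repeating_events=False))
# [('A', 0.4), ('C', 1.0)]
```

Because a fresh `random.Random(seed)` is created on every call, the same seed always yields the same window in this sketch. A training loader would more likely draw from a shared generator so that windows vary across epochs.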
Notes
- The dataset loads preprocessed patient event sequences and encodes them using the tokenizer.
- Supports both fixed-length and randomly sampled context windows.
- Static covariates are one-hot encoded and stored separately from dynamic sequences.
- The dataset length is computed as the sum of row counts across Parquet files.
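The dataset length is the sum of the per-file row counts, so a global index can be mapped back to a file and a row within it using cumulative counts and a binary search. The helpers below are a hypothetical sketch of that bookkeeping. They assume files are ordered as in file_row_count_dict and are not taken from the FastEHR source.

```python
import bisect
from itertools import accumulate


def total_samples(file_row_count_dict):
    """Dataset length: the sum of row counts across Parquet files."""
    return sum(file_row_count_dict.values())


def locate(idx, file_row_count_dict):
    """Map a global sample index to (parquet filename, row within that file)."""
    files = list(file_row_count_dict)  # dicts preserve insertion order
    ends = list(accumulate(file_row_count_dict[f] for f in files))  # exclusive end per file
    if not 0 <= idx < (ends[-1] if ends else 0):
        raise IndexError(f"index {idx} out of range")
    pos = bisect.bisect_right(ends, idx)  # first file whose end exceeds idx
    start = ends[pos - 1] if pos > 0 else 0
    return files[pos], idx - start


counts = {"part_0.parquet": 3, "part_1.parquet": 2}
print(total_samples(counts), locate(3, counts))  # 5 ('part_1.parquet', 0)
```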
- view_sample(idx, max_dynamic_events=None, report_time=False)
Displays a sample in a readable format.
- property meta_static
- property meta_measurement
- parquet_path
- sub_dir = 'split=<split>/'
- tokenizer
- max_seq_length = 256
- standardise_values = True
- global_diagnoses = False
- repeating_events = True
- random_context_window = False
- meta_information
- time_scale = 1825.0
- subsample = None
- seed = 42
- warnings_raised = []
- file_row_count_dict
- total_samples
- static_1hot
- __getitem__(idx)
- class FastEHR.dataloader.foundational_loader.Collator(supervised=False, supervised_time_scale=2.0, adapter=None)
- supervised : bool, optional, default False
Whether to take the last time point as the target.
- supervised_time_scale : float, optional, default 2.0
The scaling factor applied to any supervised target times produced in the Collator (default: 2 × 5 years).
Note: The collator receives times from FoundationalDataset() that are already scaled, so the two scale factors combine multiplicatively (TODO: put all scaling in the Collator to simplify the code and API).
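The multiplicative effect described in the note can be checked with a small calculation. The sketch assumes ages are in days and that each scale factor is applied by division, so one output unit spans time_scale × supervised_time_scale days. Both function names are hypothetical.

```python
def effective_scale_days(time_scale=1825.0, supervised_time_scale=2.0):
    """Days represented by one unit of a supervised target time (assumed model)."""
    return time_scale * supervised_time_scale


def target_time_units(delta_days, time_scale=1825.0, supervised_time_scale=2.0):
    """Hypothetical two-step scaling of a time delta expressed in days."""
    scaled_by_dataset = delta_days / time_scale     # FoundationalDataset step
    return scaled_by_dataset / supervised_time_scale  # Collator step, applied on top


print(effective_scale_days() / 365)  # 10.0, i.e. 2 x 5 years
print(target_time_units(3650.0))     # 1.0
```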
- supervised = False
- supervised_time_scale = 2.0
- adapter = None
- collate_fn(data: list[dict])
Collect and collate separate dictionaries.
During this operation, sequences are padded to the maximum length seen within the batch and tokenized.
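A minimal pure-Python sketch of the padding step follows. It assumes right padding, a padding token of 0, age padding of 0.0 and value padding of NaN, which are the conventions described for convert_to_supervised(). `pad_batch` and the per-sample keys `tokens`, `ages` and `values` are assumptions. The real Collator produces torch tensors, and plain lists are used here only to keep the sketch dependency-free.

```python
import math


def pad_batch(samples, pad_token=0):
    """Hypothetical sketch: right-pad tokens, ages and values to the longest sample."""
    max_len = max(len(s["tokens"]) for s in samples)
    batch = {"tokens": [], "ages": [], "values": [], "attention_mask": []}
    for s in samples:
        n = len(s["tokens"])
        n_pad = max_len - n
        batch["tokens"].append(list(s["tokens"]) + [pad_token] * n_pad)
        batch["ages"].append(list(s["ages"]) + [0.0] * n_pad)
        batch["values"].append(list(s["values"]) + [math.nan] * n_pad)
        batch["attention_mask"].append([1] * n + [0] * n_pad)  # 1 marks a real event
    return batch


b = pad_batch([
    {"tokens": [5, 6, 7], "ages": [0.1, 0.2, 0.3], "values": [1.0, 2.0, 3.0]},
    {"tokens": [8], "ages": [0.5], "values": [4.0]},
])
print(b["tokens"], b["attention_mask"])  # [[5, 6, 7], [8, 0, 0]] [[1, 1, 1], [1, 0, 0]]
```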
- static convert_to_supervised(batch, supervised_time_scale)
Convert a batch to a supervised format for non-causal tasks.
This method is used in conjunction with FoundationalDataModule for non-causal tasks, where the last non-padding token in each sequence is removed and used as a supervised target.
Specifically, this function:
- Replaces the last non-padding token in each row with a padding token (0).
- Creates new target vectors containing the removed tokens, values, and age deltas.
- Ensures alignment of token sequences, masking, and corresponding values.
Parameters
- batch : dict
A dictionary containing the following keys:
tokens (torch.Tensor): The input tensor containing tokenized sequences with padding.
ages (torch.Tensor): The tensor containing ages corresponding to each token in the matrix.
values (torch.Tensor): The tensor containing values corresponding to each token in the matrix.
attention_mask (torch.Tensor): The tensor containing masks indicating valid tokens.
Returns
- dict
A modified batch dictionary containing the following keys:
tokens (torch.Tensor): The input sequence with the last non-padding token replaced by a padding token (0).
ages (torch.Tensor): The modified age matrix with the last non-padding age replaced by 0.
values (torch.Tensor): The modified value matrix with the last non-padding value replaced with np.nan.
attention_mask (torch.Tensor): The modified mask matrix with the last non-padding entry set to 0.
target_token (torch.Tensor): A vector containing the removed tokens.
target_age_delta (torch.Tensor): A vector containing the difference in age between the last two non-padding tokens.
target_value (torch.Tensor): A vector containing the values that were removed from values.
Notes
If a sample does not contain at least two non-padding events, a warning is logged, and the sample is removed.
This function prevents information leakage by ensuring that the last event in a sequence is not used as an input.
If convert_to_supervised() is applied multiple times to the same batch, it will be skipped to prevent redundant modifications.
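The steps listed for convert_to_supervised() can be sketched over nested lists. `convert_to_supervised_sketch` is hypothetical. It assumes right padding, so the last non-padding event sits at index sum(mask) - 1, and it assumes the age delta is divided by supervised_time_scale. It also uses an invented `_converted_to_supervised` key to mimic the skip-on-repeat behaviour. The real method operates on torch tensors.

```python
import logging
import math


def convert_to_supervised_sketch(batch, supervised_time_scale=1.0, pad_token=0):
    """Hypothetical sketch: move each row's last non-padding event into targets."""
    if batch.get("_converted_to_supervised"):  # invented flag: skip repeated conversion
        return batch
    keys = ("tokens", "ages", "values", "attention_mask")
    out = {k: [] for k in keys}
    out.update(target_token=[], target_age_delta=[], target_value=[])
    for row in zip(*(batch[k] for k in keys)):
        tokens, ages, values, mask = (list(r) for r in row)  # copy: input stays unchanged
        n_valid = sum(mask)
        if n_valid < 2:
            logging.warning("Dropping sample with fewer than two non-padding events")
            continue
        last = n_valid - 1  # index of the last non-padding event (right padding assumed)
        out["target_token"].append(tokens[last])
        out["target_age_delta"].append((ages[last] - ages[last - 1]) / supervised_time_scale)
        out["target_value"].append(values[last])
        # Blank out the removed event so it cannot leak into the inputs.
        tokens[last], ages[last], values[last], mask[last] = pad_token, 0.0, math.nan, 0
        for k, v in zip(keys, (tokens, ages, values, mask)):
            out[k].append(v)
    out["_converted_to_supervised"] = True
    return out
```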