FastEHR.dataloader.foundational_loader

Classes

FoundationalDataModule

PyTorch-Lightning datamodule for foundational models

FoundationalDataset

Initialize the FoundationalDataset, a dataset class for foundational model training.

Collator

Collates samples into padded batches; with supervised=True the last time point is taken as the target.

Module Contents

class FastEHR.dataloader.foundational_loader.FoundationalDataModule(path_to_db: str, path_to_ds: str, load: bool, tokenizer: str = 'tabular', batch_size: int = 64, min_workers: int = 1, overwrite_practice_ids: str | None = None, overwrite_meta_information: str | None = None, supervised: bool = False, supervised_time_scale: float = 1.0, subsample_training: int | None = None, seed: int = 42, adapter: str | None = None, **kwargs)

PyTorch-Lightning datamodule for foundational models

ARGS:
practice_patient_id (list[str])

List of practice patient identifiers which satisfy study criteria.

KWARGS:

batch_size (int):

unk_freq_threshold (float).

Initialize the FoundationalDataModule, a PyTorch Lightning DataModule for foundational models.

This module handles data loading, preprocessing, and batching for training, validation, and testing. It supports both supervised and unsupervised learning tasks and integrates tokenization, meta-information loading, and dataset preparation.

Parameters:
  • path_to_db (str) – Full path to the SQLite database folder.

  • path_to_ds (str) – Full path to the preprocessed dataset folder, either for loading or saving data.

  • load (bool) – If True, loads preprocessed dataset from Parquet files. If False, processes raw SQLite data and saves it as Parquet files.

  • tokenizer (str, optional) – Which tokenizer to use. "tabular" uses the Tabular tokenizer; any other string defaults to NonTabular.

  • batch_size (int, optional) – Number of samples per batch for the DataLoader (default 64).

  • min_workers (int, optional) – Minimum number of workers used for data loading (default 1).

  • overwrite_practice_ids (str | None, optional) – Path to a file containing new practice-ID allocations for train/validation/test splits. Prevents data leakage when creating fine-tuning datasets.

  • overwrite_meta_information (str | None, optional) – Path to an existing meta-information pickle file. If provided, prevents redundant preprocessing of meta-information.

  • supervised (bool, optional) – If True, enables supervised training mode (default False).

  • supervised_time_scale (float, optional) – Scaling factor applied to supervised target times produced in the collator (default 1.0). This multiplies the time_scale in FoundationalDataset, so with the default time_scale of 1825.0 (5 years) the supervised targets are scaled by 5 years.

  • subsample_training (int | None, optional) – If specified, reduces the training dataset to a random subset of this size.

  • seed (int, optional) – Random seed used when subsampling the training dataset (default 42).

  • kwargs (Any) – Additional keyword arguments passed to PolarsDataset.fit().

Notes:

  • Loads tokenizer vocabulary from meta-information and initializes dataset splits.

  • Supports dataset reprocessing or direct loading from Parquet files.

  • Uses Collator for batch collation (supports self-supervised and supervised tasks).

  • Handles training, validation, and test dataset creation with tokenization.
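A minimal construction sketch, assuming FastEHR is installed and a preprocessed Parquet dataset already exists. The keyword names follow the constructor signature documented above; both paths are hypothetical placeholders, and the helper `build_datamodule` is introduced here for illustration only.

```python
# Keyword arguments follow the documented constructor signature.
# Both paths are hypothetical placeholders, not real locations.
datamodule_kwargs = dict(
    path_to_db="/data/example/sqlite/",    # hypothetical SQLite database folder
    path_to_ds="/data/example/parquet/",   # hypothetical preprocessed dataset folder
    load=True,             # load existing Parquet files instead of reprocessing
    tokenizer="tabular",   # any other string selects the non-tabular tokenizer
    batch_size=64,
    min_workers=1,
    supervised=False,      # self-supervised batches
)


def build_datamodule(kwargs):
    """Illustrative helper: construct the datamodule from a kwargs dict."""
    # Imported lazily so the snippet can be read without FastEHR installed.
    from FastEHR.dataloader.foundational_loader import FoundationalDataModule

    return FoundationalDataModule(**kwargs)
```

Calling `build_datamodule(datamodule_kwargs)` would return an object exposing the `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` methods listed below.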

property is_supervised
property context_time_scale
property supervised_time_scale
batch_size = 64
min_workers = 1
tokenizer
train_set
test_set
val_set
adapter
collate_fn
standardise(event, value)
unstandardise(event, value)
encode(sequence: list[str])
decode(sequence: list[int])
train_dataloader()
val_dataloader()
test_dataloader()
class FastEHR.dataloader.foundational_loader.FoundationalDataset(parquet_path: str, split: str, tokenizer: FastEHR.dataloader.tokenizers_local.TokenizerBase, meta_information: dict, file_row_count_dict: dict, max_seq_length: int = 256, standardise_values: bool = True, global_diagnoses: bool = False, repeating_events: bool = True, random_context_window: bool = False, time_scale: float = 1825.0, subsample: int | None = None, seed: int = 42, **kwargs)

Initialize the FoundationalDataset, a dataset class for foundational model training.

This dataset is constructed from preprocessed Parquet files and provides tokenized and structured sequences of events, including dynamic and static features.

Parameters

parquet_path : str

Path to the directory containing Parquet dataset files.

split : str

Dataset split type (“train”, “val”, or “test”).

tokenizer : TokenizerBase

Tokenizer used for encoding event sequences.

meta_information : dict

Dictionary containing meta-information about the dataset, including measurement statistics.

file_row_count_dict : dict

Dictionary mapping Parquet filenames to the number of rows they contain.

max_seq_length : int, optional, default=256

The maximum number of tokens in a sequence. Longer sequences are truncated.

standardise_values : bool, optional, default=True

Whether to standardize event values based on dataset statistics.

global_diagnoses : bool, optional, default=False

If True, ensures all historical diagnoses are included in each sequence’s context window.

repeating_events : bool, optional, default=True

Whether to allow repeated events within a sequence.

  • True: Retains all occurrences of an event.

  • False: Keeps only the latest record of each event. These may still fall outside of a context window.

random_context_window : bool, optional, default=False

Whether to randomly sample context windows or use the latest events.

time_scale : float, optional, default=1825.0

The scaling factor applied to age values (default: 5 years, so a model using the unit interval looks 5 years ahead).

subsample : int, optional

If specified, limits the dataset to a random subsample of this size.
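The interaction of repeating_events, max_seq_length, and random_context_window can be sketched in plain Python. This is an illustration of the behaviour described above, not the library's code: the function names, the event names, and the order in which the two steps would be applied are assumptions, and global_diagnoses is not modelled.

```python
import random


def latest_only(events):
    """Keep only the most recent record of each event, preserving order.

    `events` is a list of (name, age) tuples sorted by age.
    """
    last_index = {name: i for i, (name, _) in enumerate(events)}
    return [ev for i, ev in enumerate(events) if last_index[ev[0]] == i]


def context_window(events, max_seq_length, random_window=False, rng=None):
    """Return at most `max_seq_length` consecutive events."""
    if len(events) <= max_seq_length:
        return list(events)
    if random_window:
        rng = rng or random.Random()
        # Any start position that leaves room for a full window.
        start = rng.randint(0, len(events) - max_seq_length)
    else:
        # Default behaviour: keep the latest events.
        start = len(events) - max_seq_length
    return events[start:start + max_seq_length]


# Hypothetical patient history: (event name, age in days).
history = [("BMI", 100), ("ASTHMA", 150), ("BMI", 300), ("HbA1c", 320), ("BMI", 400)]

deduplicated = latest_only(history)                 # repeating_events=False
window = context_window(history, max_seq_length=3)  # latest three events
```

Here `deduplicated` keeps one record per event name, while `window` keeps the three most recent records regardless of repetition.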

Notes

  • The dataset loads preprocessed patient event sequences and encodes them using the tokenizer.

  • Supports both fixed-length and randomly sampled context windows.

  • Static covariates are one-hot encoded and stored separately from dynamic sequences.

  • The dataset length is computed based on the sum of row counts across Parquet files.
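The last note can be illustrated with a short sketch. Summing the row counts follows the description above; the file names are hypothetical, and the `locate` helper (mapping a global sample index to a file and a row offset via cumulative counts) is an assumption about how such a lookup could work, not the library's implementation.

```python
from bisect import bisect_right
from itertools import accumulate

# Hypothetical Parquet file names and row counts.
file_row_count_dict = {"part-0.parquet": 3, "part-1.parquet": 5, "part-2.parquet": 2}

# Dataset length: the sum of row counts across files.
total_samples = sum(file_row_count_dict.values())

files = list(file_row_count_dict)
cumulative = list(accumulate(file_row_count_dict.values()))  # running totals


def locate(idx):
    """Map a global sample index to (file name, row offset within that file)."""
    if not 0 <= idx < total_samples:
        raise IndexError(idx)
    file_pos = bisect_right(cumulative, idx)
    rows_before = cumulative[file_pos - 1] if file_pos else 0
    return files[file_pos], idx - rows_before
```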

view_sample(idx, max_dynamic_events=None, report_time=False)

Displays a sample in a readable format.

property meta_static
property meta_measurement
parquet_path
sub_dir
tokenizer
max_seq_length = 256
standardise_values = True
global_diagnoses = False
repeating_events = True
random_context_window = False
meta_information
time_scale = 1825.0
subsample = None
seed = 42
warnings_raised = []
file_row_count_dict
total_samples
static_1hot
__getitem__(idx)
class FastEHR.dataloader.foundational_loader.Collator(supervised=False, supervised_time_scale=2.0, adapter=None)
supervised:

Whether to take the last time point as the target

supervised_time_scale: float, optional, default 2.0

The scaling factor applied to any supervised target times produced in the Collator (default: 2 × 5 years).

Note: The collator receives the times from FoundationalDataset() which are already scaled. The way this is coded leads to a multiplicative effect (TODO: put all scaling in the Collator to simplify code and API)
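The multiplicative effect described in the note can be made concrete with a little arithmetic. The sketch below assumes that time_scale is expressed in days (1825 = 5 × 365) and that scaling means dividing a raw time by the scale; both are inferences from the defaults documented here, and the function names are illustrative.

```python
DAYS_PER_YEAR = 365


def effective_target_scale(time_scale=1825.0, supervised_time_scale=2.0):
    """Combined scale: dataset time_scale multiplied by the collator factor."""
    return time_scale * supervised_time_scale


def scale_target_delta(delta_days, time_scale=1825.0, supervised_time_scale=2.0):
    """Express a raw time gap (in days) in units of the combined scale."""
    return delta_days / effective_target_scale(time_scale, supervised_time_scale)


# Collator defaults: 2.0 x 1825 days, i.e. a unit interval spanning 10 years.
collator_default_years = effective_target_scale() / DAYS_PER_YEAR

# With a factor of 1.0 the dataset scale is left unchanged (5 years).
unit_factor_years = effective_target_scale(supervised_time_scale=1.0) / DAYS_PER_YEAR
```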

supervised = False
supervised_time_scale = 2.0
adapter = None
collate_fn(data: list[dict])

Collect and collate separate dictionaries.

During this operation, pad the sequence lengths to the maximum length seen within the batch and tokenize
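The padding step can be sketched with plain lists in place of tensors. The key names mirror the batch keys documented for convert_to_supervised below; the per-sample dictionary layout and the fill values for ages (0) and values (NaN) are assumptions, and `pad_batch` is an illustrative name.

```python
PAD_TOKEN = 0  # padding token id, as documented for convert_to_supervised


def pad_batch(samples):
    """Right-pad every sequence to the longest length seen in the batch."""
    max_len = max(len(s["tokens"]) for s in samples)
    batch = {"tokens": [], "ages": [], "values": [], "attention_mask": []}
    for s in samples:
        n = len(s["tokens"])
        pad = max_len - n
        batch["tokens"].append(s["tokens"] + [PAD_TOKEN] * pad)
        batch["ages"].append(s["ages"] + [0] * pad)                 # assumed fill
        batch["values"].append(s["values"] + [float("nan")] * pad)  # assumed fill
        batch["attention_mask"].append([1] * n + [0] * pad)         # 1 = real token
    return batch


# Two hypothetical samples of different lengths.
samples = [
    {"tokens": [5, 7, 9], "ages": [10, 20, 35], "values": [1.0, 0.2, 2.5]},
    {"tokens": [4, 6], "ages": [3, 8], "values": [0.5, 1.5]},
]
padded = pad_batch(samples)
```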

static convert_to_supervised(batch, supervised_time_scale)

Convert a batch to a supervised format for non-causal tasks.

This method is used in conjunction with FoundationalDataModule for non-causal tasks, where the last non-padding token in each sequence is removed and used as a supervised target.

Specifically, this function:

  • Replaces the last non-padding token in each row with a padding token (0).

  • Creates new target vectors containing the removed tokens, values, and age deltas.

  • Ensures alignment of token sequences, masking, and corresponding values.

Parameters

batch : dict

A dictionary containing the following keys:

  • tokens (torch.Tensor): The input tensor containing tokenized sequences with padding.

  • ages (torch.Tensor): The tensor containing ages corresponding to each token in the matrix.

  • values (torch.Tensor): The tensor containing values corresponding to each token in the matrix.

  • attention_mask (torch.Tensor): The tensor containing masks indicating valid tokens.

Returns

dict

A modified batch dictionary containing the following keys:

  • tokens (torch.Tensor): The input sequence with the last non-padding token replaced by a padding token (0).

  • ages (torch.Tensor): The modified age matrix with the last non-padding age replaced by 0.

  • values (torch.Tensor): The modified value matrix with the last non-padding value replaced with np.nan.

  • attention_mask (torch.Tensor): The modified mask matrix with the last non-padding entry set to 0.

  • target_token (torch.Tensor): A vector containing the removed tokens.

  • target_age_delta (torch.Tensor): A vector containing the difference in age between the last two non-padding tokens.

  • target_value (torch.Tensor): A vector containing the values that were removed from values.

Notes

  • If a sample does not contain at least two non-padding events, a warning is logged, and the sample is removed.

  • This function prevents information leakage by ensuring that the last event in a sequence is not used as an input.

  • If convert_to_supervised() is applied multiple times to the same batch, it will be skipped to prevent redundant modifications.
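The transformation described above can be sketched with plain lists instead of tensors. This is an illustration under stated assumptions, not the library's code: it assumes right-padded sequences, omits the supervised_time_scale scaling, the warning, and the repeat-application guard, and the name `convert_to_supervised_sketch` is hypothetical.

```python
import math

PAD_TOKEN = 0
KEYS = ("tokens", "ages", "values", "attention_mask")


def convert_to_supervised_sketch(batch):
    """Move the last non-padding event of each row into target vectors."""
    out = {k: [] for k in KEYS}
    out.update(target_token=[], target_age_delta=[], target_value=[])
    for tokens, ages, values, mask in zip(*(batch[k] for k in KEYS)):
        n_valid = sum(mask)
        if n_valid < 2:
            continue  # too short to split into context and target: drop the row
        last = n_valid - 1  # index of the last real event (right-padding assumed)
        tokens, ages, values, mask = list(tokens), list(ages), list(values), list(mask)
        out["target_token"].append(tokens[last])
        out["target_age_delta"].append(ages[last] - ages[last - 1])
        out["target_value"].append(values[last])
        # Blank out the removed event so it cannot leak into the input.
        tokens[last], ages[last], values[last], mask[last] = PAD_TOKEN, 0, math.nan, 0
        for key, row in zip(KEYS, (tokens, ages, values, mask)):
            out[key].append(row)
    return out


nan = math.nan
example_batch = {
    "tokens": [[5, 7, 9, 0], [4, 0, 0, 0], [3, 8, 0, 0]],
    "ages": [[10, 20, 35, 0], [5, 0, 0, 0], [1, 4, 0, 0]],
    "values": [[1.0, nan, 2.5, nan], [0.3, nan, nan, nan], [0.5, 1.5, nan, nan]],
    "attention_mask": [[1, 1, 1, 0], [1, 0, 0, 0], [1, 1, 0, 0]],
}
result = convert_to_supervised_sketch(example_batch)
```

The second row contains a single event, so it is dropped; the other two rows each give up their final event as the target.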