FastEHR.adapters.BEHRT

Classes

ConvertToBEHRT

Convert tokenized FastEHR patient sequences into BEHRT-compatible format.

BehrtDFBuilder

Build a BEHRT-ready DataFrame from batches of token and age tensors.

Module Contents

class FastEHR.adapters.BEHRT.ConvertToBEHRT(tokenizer, supervised=False)

Convert tokenized FastEHR patient sequences into BEHRT-compatible format.

This adapter:

  • Extends an existing FastEHR tokenizer to include BEHRT’s required special tokens.

  • Converts sequences of events grouped by visit into the token/age format expected by BEHRT, adding [CLS] at the start and [SEP] between visits.

  • Retains event values, even though BEHRT does not use them.

  • Removes baseline information (e.g. ethnicity, gender) as this is not used by BEHRT.
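The visit-to-sequence step described above can be sketched roughly as follows. This is an illustration only, not the adapter's code: the function name `visits_to_behrt`, the `final_sep` flag, and the choice to give [CLS] and [SEP] the age of the adjacent visit are all assumptions made for the example.

```python
def visits_to_behrt(visits, visit_ages, final_sep=True):
    """Flatten visits into parallel token and age lists (hypothetical helper).

    visits     -- list of visits, each a list of token strings
    visit_ages -- one age per visit
    final_sep  -- assumed flag: append SEP after the last visit
    """
    if len(visits) != len(visit_ages):
        raise ValueError("visits and visit_ages must have the same length")
    if not visits:
        return [], []

    # CLS opens the sequence; here it borrows the age of the first visit.
    tokens, ages = ["CLS"], [visit_ages[0]]
    for i, (visit, age) in enumerate(zip(visits, visit_ages)):
        tokens.extend(visit)
        ages.extend([age] * len(visit))
        is_last = i == len(visits) - 1
        # SEP closes each visit; the trailing one is optional.
        if not is_last or final_sep:
            tokens.append("SEP")
            ages.append(age)
    return tokens, ages


tokens, ages = visits_to_behrt([["I10", "E11"], ["N18"]], [54, 57])
print(tokens)  # ['CLS', 'I10', 'E11', 'SEP', 'N18', 'SEP']
print(ages)    # [54, 54, 54, 54, 57, 57]
```

Keeping tokens and ages as two lists of equal length mirrors the token/age pairing that BEHRT expects.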

Attributes

  • special_tokens (dict[str, int]): Mapping of BEHRT special tokens to fixed IDs: PAD=0, UNK=1, SEP=2, CLS=3, MASK=4.

  • fastehr_tokenizer (object): Original FastEHR tokenizer instance passed at init.

  • supervised (bool): Whether conversion targets a supervised task (affects final SEP).

  • tokenizer (dict[str, int]): Token-to-index mapping that includes the BEHRT special tokens and the original codes.
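One plausible way to build such a mapping is sketched below. Only the five fixed special-token IDs come from the attribute list above; the helper name `build_behrt_vocab` and the rule that original codes receive the next free IDs are assumptions for illustration.

```python
# Fixed IDs as listed for special_tokens.
SPECIAL_TOKENS = {"PAD": 0, "UNK": 1, "SEP": 2, "CLS": 3, "MASK": 4}


def build_behrt_vocab(codes):
    """Return a token -> id dict with specials first (hypothetical helper)."""
    vocab = dict(SPECIAL_TOKENS)
    for code in codes:
        # Assumed policy: each unseen code takes the next free id.
        if code not in vocab:
            vocab[code] = len(vocab)
    return vocab


vocab = build_behrt_vocab(["I10", "E11", "N18"])
# Unknown codes fall back to UNK when encoding.
ids = [vocab.get(tok, vocab["UNK"]) for tok in ["CLS", "I10", "Z99", "SEP"]]
print(ids)  # [3, 5, 1, 2]
```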

Example:

>>> converter = ConvertToBEHRT(fastehr_tokenizer)
>>> processed_list_of_patient_dicts = converter(list_of_patient_dicts)
special_tokens
create_behrt_tokenizer(tokenizer)
fastehr_tokenizer
supervised = False
tokenizer
convert_sample(data_sample: dict)
class FastEHR.adapters.BEHRT.BehrtDFBuilder(token_map: dict, pad_token: int | str = 'PAD', class_token: int | str = 'CLS', sep_token: int | str = 'SEP', id_prefix: str = 'P', zfill: int = 3, min_seq_len: int = 5)

Build a BEHRT-ready DataFrame from batches of token and age tensors.

Each batch must be shaped [batch_size, seq_len].

Parameters

token_map (dict)

Mapping from token string to token id (BEHRT-modified vocab).

pad_token, class_token, sep_token (str or int)

Special tokens, given either as names or as ids.

id_prefix (str)

Prefix for generated patient IDs.

zfill (int)

Zero-padding length for patient IDs.

min_seq_len (int)

Minimum number of tokens, not counting CLS and SEP, that a sample needs in order to be kept. Defaults to 5, following the BEHRT paper.
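As a rough illustration of how id_prefix, zfill, and min_seq_len might be applied, consider the sketch below. Both helpers are hypothetical and are not part of FastEHR; the filter also assumes that padding has already been stripped from the sequence.

```python
def make_patient_id(n, id_prefix="P", zfill=3):
    """Format a running counter as a zero-padded patient ID (hypothetical)."""
    return f"{id_prefix}{str(n).zfill(zfill)}"


def keep_sequence(tokens, min_seq_len=5, ignored=("CLS", "SEP")):
    """Return True if enough non-CLS/SEP tokens remain (hypothetical)."""
    n_events = sum(1 for tok in tokens if tok not in ignored)
    return n_events >= min_seq_len


print(make_patient_id(1))     # P001
print(make_patient_id(1234))  # P1234 (zfill never truncates)

short = ["CLS", "I10", "SEP", "E11", "SEP"]
long_ = ["CLS", "I10", "E11", "SEP", "N18", "I25", "J45", "SEP"]
print(keep_sequence(short))  # False: only 2 event tokens
print(keep_sequence(long_))  # True: 5 event tokens
```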

class_token_id = 'CLS'
pad_token_id = 'PAD'
sep_token_id = 'SEP'
id_prefix = 'P'
zfill = 3
min_seq_len = 5
rows = []
next_id = 1
add_batch(tokens_batch, ages_batch, target_event=None, target_time=None, target_value=None)

Add a batch of sequences to the builder.

Parameters:
  • tokens_batch (torch.Tensor, shape [B, T]) – Batch of token sequences; each element is a string token (or an integer ID).

  • ages_batch (torch.Tensor, shape [B, T]) – Ages aligned with tokens_batch.

  • target_event (torch.Tensor or None, shape [B]) – Outcome event token/ID for each sequence, or None.

  • target_time (torch.Tensor or None, shape [B]) – Time-to-event measured from the last token in tokens_batch, or None.

  • target_value (torch.Tensor or None, shape [B]) – Value associated with the outcome event, or None.

flush() → pandas.DataFrame

Return a DataFrame of all accumulated rows and clear the buffer. This helps manage memory when processing large datasets.
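The accumulate-then-flush pattern can be sketched with a toy stand-in. The `RowBuffer` class below is not the FastEHR builder: it takes nested lists instead of tensors and returns a plain list of rows rather than a DataFrame, so that the example has no dependencies. The flush interval of two batches is likewise arbitrary.

```python
class RowBuffer:
    """Toy stand-in for an accumulate-then-flush builder (hypothetical)."""

    def __init__(self):
        self.rows = []

    def add_batch(self, tokens_batch, ages_batch):
        # Both arguments are nested lists shaped [batch_size, seq_len].
        for tokens, ages in zip(tokens_batch, ages_batch):
            if len(tokens) != len(ages):
                raise ValueError("tokens and ages must be aligned")
            self.rows.append({"tokens": list(tokens), "ages": list(ages)})

    def flush(self):
        # Hand back everything collected so far and start a fresh buffer.
        rows, self.rows = self.rows, []
        return rows


# Three batches of two sequences each.
batches = [
    ([["CLS", "I10"], ["CLS", "E11"]], [[54, 54], [61, 61]]),
    ([["CLS", "N18"], ["CLS", "J45"]], [[47, 47], [33, 33]]),
    ([["CLS", "I25"], ["CLS", "K21"]], [[70, 70], [58, 58]]),
]

buffer = RowBuffer()
chunks = []
for i, (tokens_batch, ages_batch) in enumerate(batches, start=1):
    buffer.add_batch(tokens_batch, ages_batch)
    if i % 2 == 0:  # flush every two batches to bound memory use
        chunks.append(buffer.flush())
chunks.append(buffer.flush())  # collect whatever is left

print([len(chunk) for chunk in chunks])  # [4, 2]
```

Flushing at intervals means only a bounded number of rows is held in memory at once; each chunk can be written to disk before the next one is built.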