From cfde5bb44e4b551c88c9e5a6a5f5f154290b8d67 Mon Sep 17 00:00:00 2001 From: John Lamprou <41962910+jlamprou@users.noreply.github.com> Date: Sun, 1 Dec 2024 00:07:27 +0200 Subject: [PATCH] fix: Handle different separator token values for uint16/uint32 dtypes When processing Qwen model data stored with a uint32 dtype, the -1 separator token wraps around to 4294967295 rather than 65535 (its value under uint16). Detect both values to ensure correct prompt/response segmentation. --- data_utils/lm_datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data_utils/lm_datasets.py b/data_utils/lm_datasets.py index 65c14b8b..6aa4a2cf 100644 --- a/data_utils/lm_datasets.py +++ b/data_utils/lm_datasets.py @@ -55,8 +55,8 @@ def _process_lm(self, i, samp, model_data, no_model_data, gen_data): source_len = 1 prompt = None - if 65535 in input_ids: - source_len = np.where(input_ids==65535)[0][0] + if 65535 in input_ids or 4294967295 in input_ids: + source_len = np.where((input_ids==65535) | (input_ids==4294967295))[0][0] prompt = input_ids[:source_len] input_ids = np.concatenate([input_ids[:source_len], input_ids[source_len+1:]], axis=0) input_ids = input_ids[:self.max_length]