diff --git a/icx360/utils/model_wrappers/huggingface.py b/icx360/utils/model_wrappers/huggingface.py index 4868536..28e116d 100644 --- a/icx360/utils/model_wrappers/huggingface.py +++ b/icx360/utils/model_wrappers/huggingface.py @@ -37,7 +37,7 @@ def __init__(self, model, tokenizer): self._tokenizer = tokenizer self._device = model.device - def convert_input(self, inputs, chat_template=False, system_prompt=None, **kwargs): + def convert_input(self, inputs, chat_template=False, system_prompt=None, unit_ranges=None, **kwargs): """ Encode input text as token IDs for HuggingFace model. @@ -48,6 +48,8 @@ def convert_input(self, inputs, chat_template=False, system_prompt=None, **kwarg Whether to apply chat template. system_prompt (str or None): System prompt to include in chat template. + unit_ranges (dict or None): + Mapping from chat template parts to ranges of input units. **kwargs (dict): Additional keyword arguments for tokenizer. @@ -60,37 +62,94 @@ def convert_input(self, inputs, chat_template=False, system_prompt=None, **kwarg # Batch of strings, enable padding and truncation kwargs["padding"] = True kwargs["truncation"] = True - if isinstance(inputs[0], list): - if chat_template: - # Join segmented strings - inputs = ["".join(inp) for inp in inputs] - else: - # Indicate to tokenizer that strings are segmented - kwargs["is_split_into_words"] = True + if isinstance(inputs[0], list) and not chat_template: + # Indicate to tokenizer that strings are segmented + kwargs["is_split_into_words"] = True if chat_template: - # Construct chat messages - if isinstance(inputs, list): - if system_prompt is not None: - messages = [[{"role": "system", "content": system_prompt}, - {"role": "user", "content": inp}] for inp in inputs] - else: - messages = [[{"role": "user", "content": inp}] for inp in inputs] + if isinstance(inputs, list) and isinstance(inputs[0], list) and unit_ranges is not None: + # Inputs are segmented into units and a mapping from chat template parts to units is given + inputs_formatted = self._construct_chat_template_from_mapping(inputs, unit_ranges) + # Encode chat + input_encoding = self._tokenizer(inputs_formatted, **kwargs).to(self._device) else: - if system_prompt is not None: - messages = [{"role": "system", "content": system_prompt}, - {"role": "user", "content": inputs}] + if isinstance(inputs, list) and isinstance(inputs[0], list): + # Inputs are segmented into units but no mapping given, just join units + inputs = ["".join(inp) for inp in inputs] + + # Construct chat messages + if isinstance(inputs, list): + if system_prompt is not None: + messages = [[{"role": "system", "content": system_prompt}, + {"role": "user", "content": inp}] for inp in inputs] + else: + messages = [[{"role": "user", "content": inp}] for inp in inputs] else: - messages = [{"role": "user", "content": inputs}] - # Encode chat - input_encoding = self._tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_dict=True, **kwargs).to(self._device) + if system_prompt is not None: + messages = [{"role": "system", "content": system_prompt}, + {"role": "user", "content": inputs}] + else: + messages = [{"role": "user", "content": inputs}] + + # Encode chat + input_encoding = self._tokenizer.apply_chat_template( + messages, add_generation_prompt=True, return_dict=True, **kwargs).to(self._device) else: # Encode text input_encoding = self._tokenizer(inputs, **kwargs).to(self._device) return input_encoding - def generate(self, inputs, chat_template=False, system_prompt=None, tokenizer_kwargs={}, text_only=True, **kwargs): + def _construct_chat_template_from_mapping(self, inputs, unit_ranges): + """ + Construct chat template given mapping from parts of the chat template to input units. + + Args: + inputs (List[List[str]]): + A list of input texts segmented into units. + unit_ranges (dict): + Mapping from chat template parts to ranges of input units. + + Returns: + inputs_formatted (List[str]): + List of inputs formatted according to chat template. + """ + inputs_formatted = [] + # Iterate over inputs + for inp in inputs: + + # Construct conversation turn by turn + conversation = [] + for turn_ranges in unit_ranges["conversation"]: + turn = {} + for key, rng in turn_ranges.items(): + # There should be only one range per turn + turn["role"] = key + turn["content"] = "".join(inp[rng[0] : rng[1]]) + conversation.append(turn) + + if "documents" in unit_ranges: + # Construct documents + documents = [] + for doc_id, doc_ranges in enumerate(unit_ranges["documents"]): + document = {"doc_id": doc_id + 1} + for key, rng in doc_ranges.items(): + # Document text and possibly a title + document[key] = "".join(inp[rng[0] : rng[1]]) + documents.append(document) + else: + documents = None + + # Construct chat template from conversation and documents + input_formatted = self._tokenizer.apply_chat_template(conversation, + documents=documents, + add_generation_prompt=True, + tokenize=False) + inputs_formatted.append(input_formatted) + + return inputs_formatted + + def generate(self, inputs, chat_template=False, system_prompt=None, unit_ranges=None, tokenizer_kwargs={}, text_only=True, **kwargs): """ Generate response from model. @@ -101,6 +160,8 @@ def generate(self, inputs, chat_template=False, system_prompt=None, tokenizer_kw Whether to apply chat template. system_prompt (str or None): System prompt to include in chat template. + unit_ranges (dict or None): + Mapping from chat template parts to ranges of input units. tokenizer_kwargs (dict): Additional keyword arguments for tokenizer. text_only (bool): @@ -117,7 +178,7 @@ def generate(self, inputs, chat_template=False, system_prompt=None, tokenizer_kw output_token_count: Maximum number of generated tokens. """ # Encode input text as token IDs - inputs = self.convert_input(inputs, chat_template, system_prompt, **tokenizer_kwargs) + inputs = self.convert_input(inputs, chat_template, system_prompt, unit_ranges, **tokenizer_kwargs) num_inputs, input_length = inputs["input_ids"].shape if num_inputs == 1 or not torch.cuda.is_available(): diff --git a/icx360/utils/model_wrappers/vllm.py b/icx360/utils/model_wrappers/vllm.py index 73ef0cc..a0050da 100644 --- a/icx360/utils/model_wrappers/vllm.py +++ b/icx360/utils/model_wrappers/vllm.py @@ -36,7 +36,7 @@ def __init__(self, model, model_name, tokenizer=None): self._model_name = model_name self._tokenizer = tokenizer - def convert_input(self, inputs, chat_template=False, system_prompt=None, **kwargs): + def convert_input(self, inputs, chat_template=False, system_prompt=None, unit_ranges=None, **kwargs): """ Convert input(s) into a list of strings. @@ -47,6 +47,8 @@ def convert_input(self, inputs, chat_template=False, system_prompt=None, **kwarg Whether to apply chat template. system_prompt (str or None): System prompt to include in chat template. + unit_ranges (dict or None): + Mapping from chat template parts to ranges of input units. Returns: inputs (List[str]): @@ -55,30 +57,87 @@ def convert_input(self, inputs, chat_template=False, system_prompt=None, **kwarg if isinstance(inputs, str): # Single input text, convert to list inputs = [inputs] - elif isinstance(inputs, list): - if isinstance(inputs[0], list): - # Join segmented texts - inputs = ["".join(inp) for inp in inputs] - else: + elif not isinstance(inputs, list): raise TypeError("Inputs must be a string or list for VLLMModel") if chat_template: if self._tokenizer is None: raise TypeError("HuggingFace tokenizer must be provided to apply chat template") - # Construct chat messages - if system_prompt is not None: - messages = [[{"role": "system", "content": system_prompt}, - {"role": "user", "content": inp}] for inp in inputs] + if isinstance(inputs, list) and isinstance(inputs[0], list) and unit_ranges is not None: + # Inputs are segmented into units and a mapping from chat template parts to units is given + inputs = self._construct_chat_template_from_mapping(inputs, unit_ranges) else: - messages = [[{"role": "user", "content": inp}] for inp in inputs] - - # Apply chat template - inputs = self._tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + if isinstance(inputs, list) and isinstance(inputs[0], list): + # Inputs are segmented into units but no mapping given, just join units + inputs = ["".join(inp) for inp in inputs] + + # Construct chat messages, placing each input into a single user message + if system_prompt is not None: + messages = [[{"role": "system", "content": system_prompt}, + {"role": "user", "content": inp}] for inp in inputs] + else: + messages = [[{"role": "user", "content": inp}] for inp in inputs] + + # Apply chat template + inputs = self._tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + else: + if isinstance(inputs, list) and isinstance(inputs[0], list): + # Join segmented units + inputs = ["".join(inp) for inp in inputs] return inputs - def generate(self, inputs, chat_template=False, system_prompt=None, text_only=True, **kwargs): + def _construct_chat_template_from_mapping(self, inputs, unit_ranges): + """ + Construct chat template given mapping from parts of the chat template to input units. + + Args: + inputs (List[List[str]]): + A list of input texts segmented into units. + unit_ranges (dict): + Mapping from chat template parts to ranges of input units. + + Returns: + inputs_formatted (List[str]): + List of inputs formatted according to chat template. + """ + inputs_formatted = [] + # Iterate over inputs + for inp in inputs: + + # Construct conversation turn by turn + conversation = [] + for turn_ranges in unit_ranges["conversation"]: + turn = {} + for key, rng in turn_ranges.items(): + # There should be only one range per turn + turn["role"] = key + turn["content"] = "".join(inp[rng[0] : rng[1]]) + conversation.append(turn) + + if "documents" in unit_ranges: + # Construct documents + documents = [] + for doc_id, doc_ranges in enumerate(unit_ranges["documents"]): + document = {"doc_id": doc_id + 1} + for key, rng in doc_ranges.items(): + # Document text and possibly a title + document[key] = "".join(inp[rng[0] : rng[1]]) + documents.append(document) + else: + documents = None + + # Construct chat template from conversation and documents + input_formatted = self._tokenizer.apply_chat_template(conversation, + documents=documents, + add_generation_prompt=True, + tokenize=False) + inputs_formatted.append(input_formatted) + + return inputs_formatted + + def generate(self, inputs, chat_template=False, system_prompt=None, text_only=True, unit_ranges=None, **kwargs): """ Generate response from model. @@ -91,6 +150,8 @@ def generate(self, inputs, chat_template=False, system_prompt=None, text_only=Tr System prompt to include in chat template. text_only (bool): Return only generated text (default) or an object containing additional outputs. + unit_ranges (dict or None): + Mapping from chat template parts to ranges of input units. **kwargs (dict): Additional keyword arguments for VLLM model. @@ -101,7 +162,7 @@ def generate(self, inputs, chat_template=False, system_prompt=None, text_only=Tr output_text: List of generated texts. """ # Convert input into list of strings if needed - inputs = self.convert_input(inputs, chat_template, system_prompt) + inputs = self.convert_input(inputs, chat_template, system_prompt, unit_ranges) # Generate output output_text = [] diff --git a/icx360/utils/scalarizers/prob.py b/icx360/utils/scalarizers/prob.py index a62e327..64d19ff 100644 --- a/icx360/utils/scalarizers/prob.py +++ b/icx360/utils/scalarizers/prob.py @@ -40,7 +40,9 @@ def __init__(self, model): if not isinstance(model, HFModel) and not isinstance(model, VLLMModel): raise TypeError("Model must be a HFModel (HuggingFace) or VLLMModel for ProbScalarizedModel") - def scalarize_output(self, inputs=None, outputs=None, ref_input=None, ref_output=None, chat_template=False, system_prompt=None, tokenizer_kwargs={}, transformation="log_prob_mean", **kwargs): + def scalarize_output(self, inputs=None, outputs=None, ref_input=None, ref_output=None, + chat_template=False, system_prompt=None, unit_ranges=None, tokenizer_kwargs={}, + transformation="log_prob_mean", **kwargs): """ Compute probability of generating reference output (or each unit thereof) conditioned on inputs. @@ -58,6 +60,8 @@ def scalarize_output(self, inputs=None, outputs=None, ref_input=None, ref_output Whether to apply chat template. system_prompt (str or None): System prompt to include in chat template. + unit_ranges (dict or None): + Mapping from chat template parts to ranges of input units. tokenizer_kwargs (dict): Additional keyword arguments for tokenizer. transformation (str, optional): @@ -77,7 +81,7 @@ def scalarize_output(self, inputs=None, outputs=None, ref_input=None, ref_output if inputs is None: raise ValueError("inputs must be provided for ProbScalarizedModel.scalarize_output()") else: - inputs = self.model.convert_input(inputs, chat_template, system_prompt, **tokenizer_kwargs) + inputs = self.model.convert_input(inputs, chat_template, system_prompt, unit_ranges, **tokenizer_kwargs) # Check for reference output if ref_output is None: raise ValueError("ref_output must be provided for ProbScalarizedModel.scalarize_output()")