diff --git a/fastembed/common/model_description.py b/fastembed/common/model_description.py index caa17710f..0c539fd84 100644 --- a/fastembed/common/model_description.py +++ b/fastembed/common/model_description.py @@ -49,4 +49,5 @@ class SparseModelDescription(BaseModelDescription): class PoolingType(str, Enum): CLS = "CLS" MEAN = "MEAN" + LAST_TOKEN = "LAST_TOKEN" DISABLED = "DISABLED" diff --git a/fastembed/common/model_management.py b/fastembed/common/model_management.py index 5301def80..90b1e0ee2 100644 --- a/fastembed/common/model_management.py +++ b/fastembed/common/model_management.py @@ -409,7 +409,7 @@ def download_model(cls, model: T, cache_dir: str, retries: int = 3, **kwargs: An try: cache_kwargs = deepcopy(kwargs) cache_kwargs["local_files_only"] = True - return Path( + cached_dir = Path( cls.download_files_from_huggingface( hf_source, cache_dir=cache_dir, @@ -417,6 +417,13 @@ def download_model(cls, model: T, cache_dir: str, retries: int = 3, **kwargs: An **cache_kwargs, ) ) + # Verify all required files exist in cache before returning + missing = [p for p in extra_patterns if not (cached_dir / p).exists()] + if missing: + raise FileNotFoundError( + f"Cached snapshot missing files: {missing}" + ) + return cached_dir except Exception: pass finally: diff --git a/fastembed/common/preprocessor_utils.py b/fastembed/common/preprocessor_utils.py index 3b702f799..1fa0acff5 100644 --- a/fastembed/common/preprocessor_utils.py +++ b/fastembed/common/preprocessor_utils.py @@ -1,4 +1,5 @@ import json +import logging from typing import Any from pathlib import Path @@ -6,11 +7,13 @@ from fastembed.image.transform.operators import Compose +logger = logging.getLogger(__name__) + def load_special_tokens(model_dir: Path) -> dict[str, Any]: tokens_map_path = model_dir / "special_tokens_map.json" if not tokens_map_path.exists(): - raise ValueError(f"Could not find special_tokens_map.json in {model_dir}") + return {} with open(str(tokens_map_path)) as tokens_map_file: 
def last_token_pool(input_array: NumpyArray, attention_mask: NDArray[np.int64]) -> NumpyArray:
    """Extract the embedding at the last non-padding token of every sequence.

    Qwen3-Embedding uses last-token pooling (NOT CLS/mean pooling).

    The last attended position is located independently for each row, so the
    result is correct for left-padded batches, right-padded batches, and
    batches that mix both padding sides.

    Args:
        input_array: Model output, shape (batch_size, seq_len, hidden_dim).
        attention_mask: Attention mask, shape (batch_size, seq_len); non-zero
            entries mark real tokens.

    Returns:
        Pooled embeddings, shape (batch_size, hidden_dim). A row whose mask is
        entirely zero falls back to the final sequence position.
    """
    batch_size, seq_len = attention_mask.shape

    # Scan each row from the right: the index of the first non-zero entry in
    # the reversed mask equals the number of trailing padding positions.
    # For an all-zero row argmax returns 0, i.e. the final position is used.
    trailing_padding = np.argmax(attention_mask[:, ::-1] != 0, axis=1)
    last_positions = seq_len - 1 - trailing_padding

    return input_array[np.arange(batch_size), last_positions]
# ---------------------------------------------------------------------------
# Qwen3 reranker constants
# ---------------------------------------------------------------------------
# Token IDs in the Qwen3 tokenizer vocabulary.
# NOTE(review): these are hard-coded; confirm both values against
# ``tokenizer.convert_tokens_to_ids("yes")`` / ``("no")`` for the shipped
# tokenizer before release.
TOKEN_YES_ID = 9693
TOKEN_NO_ID = 2132

SYSTEM_PROMPT = (
    "Judge whether the Document meets the requirements based on the Query "
    'and the Instruct provided. Note that the answer can only be "yes" or "no".'
)

DEFAULT_INSTRUCTION = (
    "Given a query and a document, judge whether the document is relevant to the query."
)

# Chat template for one query/document pair. The ``<Instruct>``, ``<Query>``
# and ``<Document>`` field markers and the empty ``<think>`` block in the
# assistant turn are part of the prompt format the reranker expects; without
# them the model sees a prompt it was not trained on.
RERANK_TEMPLATE = (
    "<|im_start|>system\n{system}<|im_end|>\n"
    "<|im_start|>user\n<Instruct>: {instruction}\n"
    "<Query>: {query}\n<Document>: {document}<|im_end|>\n"
    "<|im_start|>assistant\n<think>\n\n</think>\n\n"
)
" + "INT4 weights + FP16 activations (Q4F16). Multilingual, " + "40960 input tokens, instruction-aware, 2025 year." + ), + license="apache-2.0", + size_in_GB=0.57, + sources=ModelSource(hf="n24q02m/Qwen3-Reranker-0.6B-ONNX"), + model_file="onnx/model_q4f16.onnx", + ), +] + + +# --------------------------------------------------------------------------- +# Qwen3 reranker implementation +# --------------------------------------------------------------------------- +class Qwen3CrossEncoder(OnnxTextCrossEncoder): + """Qwen3 Reranker using causal LM with yes/no logit scoring. + + Usage:: + + from fastembed import TextCrossEncoder + + reranker = TextCrossEncoder("Qwen/Qwen3-Reranker-0.6B") + scores = list(reranker.rerank("What is AI?", ["doc1", "doc2"])) + + # Custom instruction + scores = list(reranker.rerank( + "What is AI?", + ["doc1", "doc2"], + instruction="Judge document relevance for code search.", + )) + """ + + @classmethod + def _list_supported_models(cls) -> list[BaseModelDescription]: + """Return the list of supported Qwen3 reranker models.""" + return supported_qwen3_reranker_models + + # ------------------------------------------------------------------ + # Chat template formatting + # ------------------------------------------------------------------ + @staticmethod + def _format_rerank_input( + query: str, + document: str, + instruction: str = DEFAULT_INSTRUCTION, + ) -> str: + """Build the chat-template string for a single query-document pair.""" + return RERANK_TEMPLATE.format( + system=SYSTEM_PROMPT, + instruction=instruction, + query=query, + document=document, + ) + + # ------------------------------------------------------------------ + # Yes/No logit scoring + # ------------------------------------------------------------------ + @staticmethod + def _compute_yes_no_scores(model_output: NumpyArray) -> NumpyArray: + """Extract yes/no logits from causal LM output and compute scores. 
+ + Args: + model_output: Raw model output, shape ``(batch, seq_len, vocab_size)``. + + Returns: + Relevance scores (P(yes)), shape ``(batch,)``. + """ + # Last token logits for each sample + last_logits: NumpyArray = model_output[:, -1, :] # (batch, vocab_size) + + # Stack [no, yes] logits + yes_no_logits = np.stack( + [last_logits[:, TOKEN_NO_ID], last_logits[:, TOKEN_YES_ID]], axis=1 + ) # (batch, 2) + + # Numerically stable softmax + max_logits = np.max(yes_no_logits, axis=1, keepdims=True) + exp_logits = np.exp(yes_no_logits - max_logits) + probs = exp_logits / np.sum(exp_logits, axis=1, keepdims=True) + + return probs[:, 1] # P(yes) + + # ------------------------------------------------------------------ + # Override ONNX inference to use chat-template + CausalLM scoring + # ------------------------------------------------------------------ + def onnx_embed(self, query: str, documents: list[str], **kwargs: Any) -> OnnxOutputContext: + """Score query-document pairs using the Qwen3 chat template.""" + instruction = kwargs.pop("instruction", DEFAULT_INSTRUCTION) + texts = [self._format_rerank_input(query, doc, instruction) for doc in documents] + return self._onnx_embed_texts(texts, **kwargs) + + def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any) -> OnnxOutputContext: + """Score pre-formed (query, document) pairs.""" + instruction = kwargs.pop("instruction", DEFAULT_INSTRUCTION) + texts = [self._format_rerank_input(query, doc, instruction) for query, doc in pairs] + return self._onnx_embed_texts(texts, **kwargs) + + def _onnx_embed_texts(self, texts: list[str], **kwargs: Any) -> OnnxOutputContext: + """Tokenise and run model one text at a time (static batch=1 ONNX graph), + then concatenate the yes/no scores.""" + assert self.tokenizer is not None, "Tokenizer not loaded. Call load_onnx_model() first." 
+ + input_names: set[str] = {node.name for node in self.model.get_inputs()} # type: ignore[union-attr] + all_scores: list[NumpyArray] = [] + for text in texts: + tokenized = self.tokenizer.encode_batch([text]) + onnx_input: dict[str, NumpyArray] = { + "input_ids": np.array([tokenized[0].ids], dtype=np.int64), + } + if "attention_mask" in input_names: + onnx_input["attention_mask"] = np.array( + [tokenized[0].attention_mask], dtype=np.int64 + ) + if "token_type_ids" in input_names: + onnx_input["token_type_ids"] = np.zeros_like( + onnx_input["input_ids"], dtype=np.int64 + ) + + onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs) + outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore[union-attr] + model_output = outputs[0] + if model_output.dtype == np.float16: + model_output = model_output.astype(np.float32) + scores = self._compute_yes_no_scores(model_output) + all_scores.append(scores) + + concatenated = np.concatenate(all_scores).astype(np.float32) + return OnnxOutputContext(model_output=concatenated) + + # ------------------------------------------------------------------ + # Worker + # ------------------------------------------------------------------ + @classmethod + def _get_worker_class(cls) -> type[TextRerankerWorker]: + """Return the worker class for parallel processing.""" + return Qwen3CrossEncoderWorker + + +class Qwen3CrossEncoderWorker(TextCrossEncoderWorker): + """Worker for parallel Qwen3 reranker inference.""" + + def init_embedding( + self, + model_name: str, + cache_dir: str, + **kwargs: Any, + ) -> OnnxTextCrossEncoder: + """Initialise a Qwen3CrossEncoder instance for the worker.""" + return Qwen3CrossEncoder( + model_name=model_name, + cache_dir=cache_dir, + threads=1, + **kwargs, + ) diff --git a/fastembed/rerank/cross_encoder/text_cross_encoder.py b/fastembed/rerank/cross_encoder/text_cross_encoder.py index 6f98cb24a..99048f161 100644 --- a/fastembed/rerank/cross_encoder/text_cross_encoder.py +++ 
b/fastembed/rerank/cross_encoder/text_cross_encoder.py @@ -5,6 +5,7 @@ from fastembed.common.types import Device from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder from fastembed.rerank.cross_encoder.custom_text_cross_encoder import CustomTextCrossEncoder +from fastembed.rerank.cross_encoder.qwen3_cross_encoder import Qwen3CrossEncoder from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase from fastembed.common.model_description import ( @@ -16,6 +17,7 @@ class TextCrossEncoder(TextCrossEncoderBase): CROSS_ENCODER_REGISTRY: list[Type[TextCrossEncoderBase]] = [ OnnxTextCrossEncoder, + Qwen3CrossEncoder, CustomTextCrossEncoder, ] diff --git a/fastembed/text/onnx_text_model.py b/fastembed/text/onnx_text_model.py index c8001a917..7593c9847 100644 --- a/fastembed/text/onnx_text_model.py +++ b/fastembed/text/onnx_text_model.py @@ -94,8 +94,11 @@ def onnx_embed( onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs) model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore[union-attr] + embeddings = model_output[0] + if embeddings.dtype == np.float16: + embeddings = embeddings.astype(np.float32) return OnnxOutputContext( - model_output=model_output[0], + model_output=embeddings, attention_mask=onnx_input.get("attention_mask", attention_mask), input_ids=onnx_input.get("input_ids", input_ids), ) diff --git a/fastembed/text/qwen3_embedding.py b/fastembed/text/qwen3_embedding.py new file mode 100644 index 000000000..cda8966b0 --- /dev/null +++ b/fastembed/text/qwen3_embedding.py @@ -0,0 +1,212 @@ +"""Qwen3 text embedding with last-token pooling and Matryoshka (MRL) support. + +Qwen3-Embedding uses a causal LM architecture with last-token pooling instead +of the traditional CLS or mean pooling used by BERT-family models. It also +supports Matryoshka Representation Learning (MRL), allowing truncation of +embeddings to smaller dimensions (32-1024) with graceful degradation. 
+ +Key differences from standard text embedding models: + - Last-token pooling: embedding is extracted from the last non-padding token + - Left-padding: the tokenizer pads from the left (not right) + - Instruction-aware: queries use ``Instruct: {task}\\nQuery: {text}`` format + - MRL: pass ``dim=256`` (or any value 32-1024) to ``embed()`` / ``query_embed()`` +""" + +import logging +from collections.abc import Iterable +from typing import Any + +from fastembed.common.model_description import DenseModelDescription, ModelSource +from fastembed.common.onnx_model import OnnxOutputContext +from fastembed.common.types import NumpyArray +from fastembed.common.utils import last_token_pool, normalize +from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker + +# --------------------------------------------------------------------------- +# Model registry +# --------------------------------------------------------------------------- +supported_qwen3_models: list[DenseModelDescription] = [ + DenseModelDescription( + model="Qwen/Qwen3-Embedding-0.6B", + dim=1024, + description=( + "Qwen3 text embedding (0.6B) with last-token pooling and MRL support " + "(32-1024 dims). INT8 dynamic quantized. Multilingual, 32768 input tokens, " + "instruction-aware, 2025 year." + ), + license="apache-2.0", + size_in_GB=0.57, + sources=ModelSource(hf="n24q02m/Qwen3-Embedding-0.6B-ONNX"), + model_file="onnx/model_quantized.onnx", + ), + DenseModelDescription( + model="Qwen/Qwen3-Embedding-0.6B-Q4F16", + dim=1024, + description=( + "Qwen3 text embedding (0.6B) with last-token pooling and MRL support " + "(32-1024 dims). INT4 weights + FP16 activations (Q4F16). Multilingual, " + "32768 input tokens, instruction-aware, 2025 year." 
class Qwen3TextEmbedding(OnnxTextEmbedding):
    """Qwen3 embedding model: last-token pooling with optional MRL truncation.

    Usage::

        from fastembed import TextEmbedding

        model = TextEmbedding("Qwen/Qwen3-Embedding-0.6B")
        embeddings = list(model.embed(["Hello world"]))

        # MRL: reduce dimension
        embeddings_256 = list(model.embed(["Hello world"], dim=256))

        # Query with custom task instruction
        query_emb = list(model.query_embed("What is AI?", task="..."))
    """

    @classmethod
    def _list_supported_models(cls) -> list[DenseModelDescription]:
        """Expose the Qwen3 embedding model registry."""
        return supported_qwen3_models

    def _post_process_onnx_output(
        self, output: OnnxOutputContext, **kwargs: Any
    ) -> Iterable[NumpyArray]:
        """Pool the last token, optionally truncate to ``dim``, then normalise.

        Raises:
            ValueError: If the context carries no attention mask, or ``dim`` is
                outside ``1..hidden_size``.
        """
        mask = output.attention_mask
        if mask is None:
            raise ValueError("attention_mask must be provided for last-token pooling")

        pooled = last_token_pool(output.model_output, mask)

        # MRL: keep only the leading ``dim`` components when requested.
        dim: int | None = kwargs.get("dim")
        if dim is None:
            return normalize(pooled)

        max_dim = pooled.shape[-1]
        within_bounds = 1 <= dim <= max_dim
        if not within_bounds:
            raise ValueError(f"dim must be between 1 and {max_dim}, got {dim}")
        return normalize(pooled[:, :dim])

    # ------------------------------------------------------------------
    # embed / query_embed / passage_embed
    # ------------------------------------------------------------------
    def embed(
        self,
        documents: str | Iterable[str],
        batch_size: int = 1,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[NumpyArray]:
        """Encode documents into embeddings.

        Args:
            documents: One document string or an iterable of documents.
            batch_size: Accepted for interface compatibility; any value other
                than ``1`` is logged and replaced by ``1``.
            parallel: Number of parallel workers, forwarded unchanged.
            **kwargs: Forwarded to ``_embed_documents``; ``dim`` (int) enables
                MRL truncation during post-processing.

        Yields:
            NumpyArray: One embedding per document.
        """
        if batch_size != 1:
            logger.warning(
                "batch_size=%d ignored for Qwen3; causal-LM ONNX graph requires batch_size=1",
                batch_size,
            )

        # Fixed arguments describing this model instance; batch size is pinned to 1.
        instance_args: dict[str, Any] = {
            "model_name": self.model_name,
            "cache_dir": str(self.cache_dir),
            "documents": documents,
            "batch_size": 1,
            "parallel": parallel,
            "providers": self.providers,
            "cuda": self.cuda,
            "device_ids": self.device_ids,
            "local_files_only": self._local_files_only,
            "specific_model_path": self._specific_model_path,
            "extra_session_options": self._extra_session_options,
        }
        yield from self._embed_documents(**instance_args, **kwargs)

    def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
        """Embed queries after wrapping each one in the instruction prefix.

        Every query is rendered as::

            Instruct: {task}
            Query: {query_text}

        Args:
            query: One query string or an iterable of queries.
            **kwargs: ``task`` (str) replaces the default retrieval instruction;
                everything else is passed on to :meth:`embed`.

        Yields:
            NumpyArray: One embedding per query.
        """
        task = kwargs.pop("task", DEFAULT_TASK)
        raw_queries = [query] if isinstance(query, str) else query
        prompts = [QUERY_INSTRUCTION_TEMPLATE.format(task=task, text=item) for item in raw_queries]
        yield from self.embed(prompts, **kwargs)

    def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
        """Embed passages as-is, without any instruction prefix.

        Args:
            texts: Passage strings.
            **kwargs: Passed on to :meth:`embed`.

        Yields:
            NumpyArray: One embedding per passage.
        """
        passage_vectors = self.embed(texts, **kwargs)
        yield from passage_vectors

    # ------------------------------------------------------------------
    # Worker
    # ------------------------------------------------------------------
    @classmethod
    def _get_worker_class(cls) -> type[OnnxTextEmbeddingWorker]:
        """Name the worker class used for parallel processing."""
        return Qwen3TextEmbeddingWorker


class Qwen3TextEmbeddingWorker(OnnxTextEmbeddingWorker):
    """Parallel-inference worker that builds Qwen3 embedding models."""

    def init_embedding(
        self,
        model_name: str,
        cache_dir: str,
        **kwargs: Any,
    ) -> OnnxTextEmbedding:
        """Create the single-threaded Qwen3TextEmbedding used by this worker."""
        worker_model = Qwen3TextEmbedding(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=1,
            **kwargs,
        )
        return worker_model
def test_qwen3_reranker_left_padding_batch(model_cache) -> None:
    """Check that a document's score does not depend on what else is reranked with it.

    The short document is scored once on its own and once alongside a much
    longer document; both scores must agree within tolerance.
    """
    is_ci = os.getenv("CI")
    is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"
    model_name = "Qwen/Qwen3-Reranker-0.6B"

    for model_desc in TextCrossEncoder._list_supported_models():
        if model_desc.model != model_name:
            continue
        if not should_test_model(model_desc, model_name, is_ci, is_manual):
            pytest.skip(f"Skipping {model_name} (not selected for this CI run)")

    query = "Testing Qwen"
    short_doc = "This is a short doc."
    long_doc = (
        "This is a significantly longer string that will force the shorter string "
        "to be padded with `` tokens on the left side during the tokenization phase. "
        "The embedding pooling must ignore these left padding tokens."
    )

    with model_cache(model_name) as model:
        # Score the short document alone.
        single_score = next(iter(model.rerank(query, [short_doc])))

        # Score the short document together with a much longer one.
        batch_scores = list(model.rerank(query, [long_doc, short_doc]))
        assert len(batch_scores) == 2

        # rerank() yields plain float scores, so compare the values directly
        # (there is no ``.score`` attribute on the results).
        assert np.allclose(single_score, batch_scores[1], atol=1e-4)
token3 + [6,6,6,6], # token4 + [5,5,5,5], # token5 (LAST TOKEN) + ] + ], dtype=np.float32) + + attention_mask = np.array([ + [0, 0, 1, 1, 1], # Seq 0 mask + [1, 1, 1, 1, 1] # Seq 1 mask + ], dtype=np.int64) + + pooled = last_token_pool(hidden_states, attention_mask) + + # Expected: The vector at index 4 (the last token) for both + expected_seq0 = np.array([3,3,3,3], dtype=np.float32) + expected_seq1 = np.array([5,5,5,5], dtype=np.float32) + + assert np.allclose(pooled[0], expected_seq0) + assert np.allclose(pooled[1], expected_seq1)