From fad6ca6fe162970f5cde235d183c8270b10825a8 Mon Sep 17 00:00:00 2001 From: n24q02m Date: Sat, 14 Feb 2026 02:13:57 +0700 Subject: [PATCH 1/6] feat: add Qwen3-Embedding-0.6B and Qwen3-Reranker-0.6B support Add native support for Qwen3 embedding and reranker models: - Qwen3TextEmbedding: last-token pooling, MRL (32-1024 dims), instruction-aware - Qwen3CrossEncoder: causal LM yes/no logit scoring, chat-template formatting - last_token_pool() utility for causal embedding models - LAST_TOKEN pooling type in PoolingType enum - Graceful handling of missing special_tokens_map.json in preprocessor_utils - Fix pad_token_id=null and dict pad_token in tokenizer config ONNX models hosted at: - n24q02m/Qwen3-Embedding-0.6B-ONNX - n24q02m/Qwen3-Reranker-0.6B-ONNX Closes #528 Closes #529 Related to #530 --- fastembed/common/model_description.py | 1 + fastembed/common/preprocessor_utils.py | 12 +- fastembed/common/utils.py | 22 ++ .../cross_encoder/qwen3_cross_encoder.py | 205 ++++++++++++++++++ .../cross_encoder/text_cross_encoder.py | 2 + fastembed/text/qwen3_embedding.py | 182 ++++++++++++++++ fastembed/text/text_embedding.py | 2 + tests/test_text_cross_encoder.py | 1 + tests/test_text_onnx_embeddings.py | 3 + 9 files changed, 426 insertions(+), 4 deletions(-) create mode 100644 fastembed/rerank/cross_encoder/qwen3_cross_encoder.py create mode 100644 fastembed/text/qwen3_embedding.py diff --git a/fastembed/common/model_description.py b/fastembed/common/model_description.py index caa17710f..0c539fd84 100644 --- a/fastembed/common/model_description.py +++ b/fastembed/common/model_description.py @@ -49,4 +49,5 @@ class SparseModelDescription(BaseModelDescription): class PoolingType(str, Enum): CLS = "CLS" MEAN = "MEAN" + LAST_TOKEN = "LAST_TOKEN" DISABLED = "DISABLED" diff --git a/fastembed/common/preprocessor_utils.py b/fastembed/common/preprocessor_utils.py index 3b702f799..fe0d75f11 100644 --- a/fastembed/common/preprocessor_utils.py +++ 
b/fastembed/common/preprocessor_utils.py @@ -10,7 +10,7 @@ def load_special_tokens(model_dir: Path) -> dict[str, Any]: tokens_map_path = model_dir / "special_tokens_map.json" if not tokens_map_path.exists(): - raise ValueError(f"Could not find special_tokens_map.json in {model_dir}") + return {} with open(str(tokens_map_path)) as tokens_map_file: tokens_map = json.load(tokens_map_file) @@ -51,9 +51,13 @@ def load_tokenizer(model_dir: Path) -> tuple[Tokenizer, dict[str, int]]: tokenizer = Tokenizer.from_file(str(tokenizer_path)) tokenizer.enable_truncation(max_length=max_context) if not tokenizer.padding: - tokenizer.enable_padding( - pad_id=config.get("pad_token_id", 0), pad_token=tokenizer_config["pad_token"] - ) + pad_token_id = config.get("pad_token_id") + if pad_token_id is None: + pad_token_id = 0 + pad_token = tokenizer_config.get("pad_token", "") + if isinstance(pad_token, dict): + pad_token = pad_token.get("content", "") + tokenizer.enable_padding(pad_id=pad_token_id, pad_token=pad_token) for token in tokens_map.values(): if isinstance(token, str): diff --git a/fastembed/common/utils.py b/fastembed/common/utils.py index b61a8b9ce..90e3a2be6 100644 --- a/fastembed/common/utils.py +++ b/fastembed/common/utils.py @@ -15,6 +15,28 @@ T = TypeVar("T") +def last_token_pool(input_array: NumpyArray, attention_mask: NDArray[np.int64]) -> NumpyArray: + """Extract embedding from the last non-padding token position. + + Qwen3-Embedding uses last-token pooling (NOT CLS/mean pooling). + Handles both left-padding and right-padding. + + Args: + input_array: Model output, shape (batch_size, seq_len, hidden_dim). + attention_mask: Attention mask, shape (batch_size, seq_len). + + Returns: + Pooled embeddings, shape (batch_size, hidden_dim). 
+ """ + left_padding = bool(attention_mask[:, -1].sum() == attention_mask.shape[0]) + if left_padding: + return input_array[:, -1] + + sequence_lengths = attention_mask.sum(axis=1).astype(np.int64) - 1 + batch_size = input_array.shape[0] + return input_array[np.arange(batch_size), sequence_lengths] + + def normalize(input_array: NumpyArray, p: int = 2, dim: int = 1, eps: float = 1e-12) -> NumpyArray: # Calculate the Lp norm along the specified dimension norm = np.linalg.norm(input_array, ord=p, axis=dim, keepdims=True) diff --git a/fastembed/rerank/cross_encoder/qwen3_cross_encoder.py b/fastembed/rerank/cross_encoder/qwen3_cross_encoder.py new file mode 100644 index 000000000..f511064a7 --- /dev/null +++ b/fastembed/rerank/cross_encoder/qwen3_cross_encoder.py @@ -0,0 +1,205 @@ +"""Qwen3 reranker using causal LM with yes/no logit scoring. + +Unlike traditional cross-encoder rerankers (which concatenate query+document +as a pair, feed through a BERT-class model, and read a relevance head), the +Qwen3 reranker: + +1. Formats input as a **chat template** with system/user/assistant turns. +2. Runs a **causal language model** (Qwen3ForCausalLM). +3. Extracts the **last-token logits** for the "yes" and "no" tokens. +4. Applies **softmax** to obtain the relevance probability. + +This means the ONNX model output has shape ``(batch, seq_len, vocab_size)`` +instead of the typical ``(batch, num_labels)`` from cross-encoders. 
+""" + +from typing import Any + +import numpy as np + +from fastembed.common.model_description import BaseModelDescription, ModelSource +from fastembed.common.onnx_model import OnnxOutputContext +from fastembed.common.types import NumpyArray +from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import ( + OnnxTextCrossEncoder, + TextCrossEncoderWorker, +) +from fastembed.rerank.cross_encoder.onnx_text_model import TextRerankerWorker + +# --------------------------------------------------------------------------- +# Qwen3 reranker constants +# --------------------------------------------------------------------------- +# Token IDs in the Qwen3 tokenizer vocabulary +TOKEN_YES_ID = 9693 +TOKEN_NO_ID = 2132 + +SYSTEM_PROMPT = ( + "Judge whether the Document meets the requirements based on the Query " + 'and the Instruct provided. Note that the answer can only be "yes" or "no".' +) + +DEFAULT_INSTRUCTION = ( + "Given a query and a document, judge whether the document is relevant to the query." +) + +RERANK_TEMPLATE = ( + "<|im_start|>system\n{system}<|im_end|>\n" + "<|im_start|>user\n<Instruct>: {instruction}\n" + "<Query>: {query}\n<Document>: {document}<|im_end|>\n" + "<|im_start|>assistant\n<think>\n\n</think>\n\n" +) + +# --------------------------------------------------------------------------- +# Model registry +# --------------------------------------------------------------------------- +supported_qwen3_reranker_models: list[BaseModelDescription] = [ + BaseModelDescription( + model="Qwen/Qwen3-Reranker-0.6B", + description=( + "Qwen3 reranker (0.6B) using causal LM yes/no scoring. " + "Multilingual, 40960 input tokens, instruction-aware, 2025 year."
+ ), + license="apache-2.0", + size_in_GB=0.57, + sources=ModelSource(hf="n24q02m/Qwen3-Reranker-0.6B-ONNX"), + model_file="onnx/model.onnx", + ), +] + + +# --------------------------------------------------------------------------- +# Qwen3 reranker implementation +# --------------------------------------------------------------------------- +class Qwen3CrossEncoder(OnnxTextCrossEncoder): + """Qwen3 Reranker using causal LM with yes/no logit scoring. + + Usage:: + + from fastembed import TextCrossEncoder + + reranker = TextCrossEncoder("Qwen/Qwen3-Reranker-0.6B") + scores = list(reranker.rerank("What is AI?", ["doc1", "doc2"])) + + # Custom instruction + scores = list(reranker.rerank( + "What is AI?", + ["doc1", "doc2"], + instruction="Judge document relevance for code search.", + )) + """ + + @classmethod + def _list_supported_models(cls) -> list[BaseModelDescription]: + return supported_qwen3_reranker_models + + # ------------------------------------------------------------------ + # Chat template formatting + # ------------------------------------------------------------------ + @staticmethod + def _format_rerank_input( + query: str, + document: str, + instruction: str = DEFAULT_INSTRUCTION, + ) -> str: + """Build the chat-template string for a single query-document pair.""" + return RERANK_TEMPLATE.format( + system=SYSTEM_PROMPT, + instruction=instruction, + query=query, + document=document, + ) + + # ------------------------------------------------------------------ + # Yes/No logit scoring + # ------------------------------------------------------------------ + @staticmethod + def _compute_yes_no_scores(model_output: NumpyArray) -> NumpyArray: + """Extract yes/no logits from causal LM output and compute scores. + + Args: + model_output: Raw model output, shape ``(batch, seq_len, vocab_size)``. + + Returns: + Relevance scores (P(yes)), shape ``(batch,)``. 
+ """ + # Last token logits for each sample + last_logits: NumpyArray = model_output[:, -1, :] # (batch, vocab_size) + + # Stack [no, yes] logits + yes_no_logits = np.stack( + [last_logits[:, TOKEN_NO_ID], last_logits[:, TOKEN_YES_ID]], axis=1 + ) # (batch, 2) + + # Numerically stable softmax + max_logits = np.max(yes_no_logits, axis=1, keepdims=True) + exp_logits = np.exp(yes_no_logits - max_logits) + probs = exp_logits / np.sum(exp_logits, axis=1, keepdims=True) + + return probs[:, 1] # P(yes) + + # ------------------------------------------------------------------ + # Override ONNX inference to use chat-template + CausalLM scoring + # ------------------------------------------------------------------ + def onnx_embed(self, query: str, documents: list[str], **kwargs: Any) -> OnnxOutputContext: + """Score query-document pairs using the Qwen3 chat template.""" + instruction = kwargs.pop("instruction", DEFAULT_INSTRUCTION) + texts = [self._format_rerank_input(query, doc, instruction) for doc in documents] + return self._onnx_embed_texts(texts, **kwargs) + + def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any) -> OnnxOutputContext: + """Score pre-formed (query, document) pairs.""" + instruction = kwargs.pop("instruction", DEFAULT_INSTRUCTION) + texts = [self._format_rerank_input(query, doc, instruction) for query, doc in pairs] + return self._onnx_embed_texts(texts, **kwargs) + + def _onnx_embed_texts(self, texts: list[str], **kwargs: Any) -> OnnxOutputContext: + """Tokenise and run model one text at a time (static batch=1 ONNX graph), + then concatenate the yes/no scores.""" + assert self.tokenizer is not None, "Tokenizer not loaded. Call load_onnx_model() first." 
+ + all_scores: list[NumpyArray] = [] + for text in texts: + tokenized = self.tokenizer.encode_batch([text]) + + input_names: set[str] = {node.name for node in self.model.get_inputs()} # type: ignore[union-attr] + onnx_input: dict[str, NumpyArray] = { + "input_ids": np.array([tokenized[0].ids], dtype=np.int64), + } + if "attention_mask" in input_names: + onnx_input["attention_mask"] = np.array( + [tokenized[0].attention_mask], dtype=np.int64 + ) + if "token_type_ids" in input_names: + onnx_input["token_type_ids"] = np.zeros_like( + onnx_input["input_ids"], dtype=np.int64 + ) + + onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs) + outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore[union-attr] + scores = self._compute_yes_no_scores(outputs[0]) + all_scores.append(scores) + + concatenated = np.concatenate(all_scores).astype(np.float32) + return OnnxOutputContext(model_output=concatenated) + + # ------------------------------------------------------------------ + # Worker + # ------------------------------------------------------------------ + @classmethod + def _get_worker_class(cls) -> type[TextRerankerWorker]: + return Qwen3CrossEncoderWorker + + +class Qwen3CrossEncoderWorker(TextCrossEncoderWorker): + def init_embedding( + self, + model_name: str, + cache_dir: str, + **kwargs: Any, + ) -> OnnxTextCrossEncoder: + return Qwen3CrossEncoder( + model_name=model_name, + cache_dir=cache_dir, + threads=1, + **kwargs, + ) diff --git a/fastembed/rerank/cross_encoder/text_cross_encoder.py b/fastembed/rerank/cross_encoder/text_cross_encoder.py index 6f98cb24a..99048f161 100644 --- a/fastembed/rerank/cross_encoder/text_cross_encoder.py +++ b/fastembed/rerank/cross_encoder/text_cross_encoder.py @@ -5,6 +5,7 @@ from fastembed.common.types import Device from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder from fastembed.rerank.cross_encoder.custom_text_cross_encoder import CustomTextCrossEncoder +from 
fastembed.rerank.cross_encoder.qwen3_cross_encoder import Qwen3CrossEncoder from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase from fastembed.common.model_description import ( @@ -16,6 +17,7 @@ class TextCrossEncoder(TextCrossEncoderBase): CROSS_ENCODER_REGISTRY: list[Type[TextCrossEncoderBase]] = [ OnnxTextCrossEncoder, + Qwen3CrossEncoder, CustomTextCrossEncoder, ] diff --git a/fastembed/text/qwen3_embedding.py b/fastembed/text/qwen3_embedding.py new file mode 100644 index 000000000..38408c3d5 --- /dev/null +++ b/fastembed/text/qwen3_embedding.py @@ -0,0 +1,182 @@ +"""Qwen3 text embedding with last-token pooling and Matryoshka (MRL) support. + +Qwen3-Embedding uses a causal LM architecture with last-token pooling instead +of the traditional CLS or mean pooling used by BERT-family models. It also +supports Matryoshka Representation Learning (MRL), allowing truncation of +embeddings to smaller dimensions (32-1024) with graceful degradation. + +Key differences from standard text embedding models: + - Last-token pooling: embedding is extracted from the last non-padding token + - Left-padding: the tokenizer pads from the left (not right) + - Instruction-aware: queries use ``Instruct: {task}\\nQuery: {text}`` format + - MRL: pass ``dim=256`` (or any value 32-1024) to ``embed()`` / ``query_embed()`` +""" + +from collections.abc import Iterable +from typing import Any + +from fastembed.common.model_description import DenseModelDescription, ModelSource +from fastembed.common.onnx_model import OnnxOutputContext +from fastembed.common.types import NumpyArray +from fastembed.common.utils import last_token_pool, normalize +from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker + +# --------------------------------------------------------------------------- +# Model registry +# --------------------------------------------------------------------------- +supported_qwen3_models: list[DenseModelDescription] = [ + 
DenseModelDescription( + model="Qwen/Qwen3-Embedding-0.6B", + dim=1024, + description=( + "Qwen3 text embedding (0.6B) with last-token pooling and MRL support " + "(32-1024 dims). Multilingual, 32768 input tokens, instruction-aware, " + "2025 year." + ), + license="apache-2.0", + size_in_GB=0.57, + sources=ModelSource(hf="n24q02m/Qwen3-Embedding-0.6B-ONNX"), + model_file="onnx/model.onnx", + ), +] + +# --------------------------------------------------------------------------- +# Instruction template +# --------------------------------------------------------------------------- +DEFAULT_TASK = "Given a query, retrieve relevant documents that answer the query" +QUERY_INSTRUCTION_TEMPLATE = "Instruct: {task}\nQuery: {text}" + + +# --------------------------------------------------------------------------- +# Qwen3 embedding implementation +# --------------------------------------------------------------------------- +class Qwen3TextEmbedding(OnnxTextEmbedding): + """Qwen3 Embedding model with last-token pooling and MRL support. 
+ + Usage:: + + from fastembed import TextEmbedding + + model = TextEmbedding("Qwen/Qwen3-Embedding-0.6B") + embeddings = list(model.embed(["Hello world"])) + + # MRL: reduce dimension + embeddings_256 = list(model.embed(["Hello world"], dim=256)) + + # Query with custom task instruction + query_emb = list(model.query_embed("What is AI?", task="...")) + """ + + @classmethod + def _list_supported_models(cls) -> list[DenseModelDescription]: + return supported_qwen3_models + + def _post_process_onnx_output( + self, output: OnnxOutputContext, **kwargs: Any + ) -> Iterable[NumpyArray]: + if output.attention_mask is None: + raise ValueError("attention_mask must be provided for last-token pooling") + + embeddings = last_token_pool(output.model_output, output.attention_mask) + + # MRL: optionally truncate to requested dimension + dim: int | None = kwargs.get("dim") + if dim is not None: + embeddings = embeddings[:, :dim] + + return normalize(embeddings) + + # ------------------------------------------------------------------ + # embed / query_embed / passage_embed + # ------------------------------------------------------------------ + def embed( + self, + documents: str | Iterable[str], + batch_size: int = 1, + parallel: int | None = None, + **kwargs: Any, + ) -> Iterable[NumpyArray]: + """Encode documents into embeddings. + + Args: + documents: A single document string or an iterable of documents. + batch_size: Ignored -- always ``1`` because + the causal-LM ONNX graph has a static batch dimension. + parallel: Number of parallel workers (``None`` = single-threaded). + **kwargs: Extra arguments; ``dim`` (int) enables MRL truncation, + ``task`` (str) is used only by :meth:`query_embed`. + + Yields: + NumpyArray: L2-normalised embeddings, one per document. 
+ """ + yield from self._embed_documents( + model_name=self.model_name, + cache_dir=str(self.cache_dir), + documents=documents, + batch_size=1, + parallel=parallel, + providers=self.providers, + cuda=self.cuda, + device_ids=self.device_ids, + local_files_only=self._local_files_only, + specific_model_path=self._specific_model_path, + extra_session_options=self._extra_session_options, + **kwargs, + ) + + def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]: + """Embed queries with instruction prefix. + + The instruction prefix follows the Qwen3 format:: + + Instruct: {task} + Query: {query_text} + + Args: + query: A single query string or an iterable of queries. + **kwargs: ``task`` (str) overrides the default retrieval instruction. + ``dim`` (int) enables MRL truncation. + + Yields: + NumpyArray: L2-normalised query embeddings. + """ + task = kwargs.pop("task", DEFAULT_TASK) + if isinstance(query, str): + queries = [QUERY_INSTRUCTION_TEMPLATE.format(task=task, text=query)] + else: + queries = [QUERY_INSTRUCTION_TEMPLATE.format(task=task, text=q) for q in query] + yield from self.embed(queries, **kwargs) + + def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]: + """Embed passages (documents) without instruction prefix. + + Args: + texts: An iterable of passage strings. + **kwargs: ``dim`` (int) enables MRL truncation. + + Yields: + NumpyArray: L2-normalised passage embeddings. 
+ """ + yield from self.embed(texts, **kwargs) + + # ------------------------------------------------------------------ + # Worker + # ------------------------------------------------------------------ + @classmethod + def _get_worker_class(cls) -> type[OnnxTextEmbeddingWorker]: + return Qwen3TextEmbeddingWorker + + +class Qwen3TextEmbeddingWorker(OnnxTextEmbeddingWorker): + def init_embedding( + self, + model_name: str, + cache_dir: str, + **kwargs: Any, + ) -> OnnxTextEmbedding: + return Qwen3TextEmbedding( + model_name=model_name, + cache_dir=cache_dir, + threads=1, + **kwargs, + ) diff --git a/fastembed/text/text_embedding.py b/fastembed/text/text_embedding.py index a4ae48cc5..2c000f4d4 100644 --- a/fastembed/text/text_embedding.py +++ b/fastembed/text/text_embedding.py @@ -9,6 +9,7 @@ from fastembed.text.pooled_embedding import PooledEmbedding from fastembed.text.multitask_embedding import JinaEmbeddingV3 from fastembed.text.onnx_embedding import OnnxTextEmbedding +from fastembed.text.qwen3_embedding import Qwen3TextEmbedding from fastembed.text.text_embedding_base import TextEmbeddingBase from fastembed.common.model_description import DenseModelDescription, ModelSource, PoolingType @@ -20,6 +21,7 @@ class TextEmbedding(TextEmbeddingBase): PooledNormalizedEmbedding, PooledEmbedding, JinaEmbeddingV3, + Qwen3TextEmbedding, CustomTextEmbedding, ] diff --git a/tests/test_text_cross_encoder.py b/tests/test_text_cross_encoder.py index 4d0d5b7d6..53cf6d144 100644 --- a/tests/test_text_cross_encoder.py +++ b/tests/test_text_cross_encoder.py @@ -14,6 +14,7 @@ "jinaai/jina-reranker-v1-tiny-en": np.array([2.5911, 0.1122]), "jinaai/jina-reranker-v1-turbo-en": np.array([1.8295, -2.8908]), "jinaai/jina-reranker-v2-base-multilingual": np.array([1.6533, -1.6455]), + "Qwen/Qwen3-Reranker-0.6B": np.array([0.99449426, 0.01638701]), } diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index e919faf9d..92d9a4783 100644 --- 
a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -68,6 +68,9 @@ "Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]), "thenlper/gte-base": np.array([0.0038, 0.0355, 0.0181, 0.0092, 0.0654]), "jinaai/jina-clip-v1": np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472]), + "Qwen/Qwen3-Embedding-0.6B": np.array( + [-0.02225659, 0.01872586, -0.01449341, -0.08536665, 0.01223033] + ), } MULTI_TASK_MODELS = ["jinaai/jina-embeddings-v3"] From 80d726c4cc6489a34a294bf043e7dbd36d3b1aac Mon Sep 17 00:00:00 2001 From: n24q02m Date: Sat, 14 Feb 2026 09:36:48 +0700 Subject: [PATCH 2/6] fix: address code review feedback - Add logging warning when pad_token_id defaults to 0 - Hoist input_names computation out of per-text loop - Add dim parameter validation (1 <= dim <= max_dim) - Add batch_size warning when non-1 value is ignored - Add docstrings to all public/internal methods for coverage --- fastembed/common/preprocessor_utils.py | 7 +++++++ .../rerank/cross_encoder/qwen3_cross_encoder.py | 8 ++++++-- fastembed/text/qwen3_embedding.py | 17 +++++++++++++++++ 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/fastembed/common/preprocessor_utils.py b/fastembed/common/preprocessor_utils.py index fe0d75f11..1fa0acff5 100644 --- a/fastembed/common/preprocessor_utils.py +++ b/fastembed/common/preprocessor_utils.py @@ -1,4 +1,5 @@ import json +import logging from typing import Any from pathlib import Path @@ -6,6 +7,8 @@ from fastembed.image.transform.operators import Compose +logger = logging.getLogger(__name__) + def load_special_tokens(model_dir: Path) -> dict[str, Any]: tokens_map_path = model_dir / "special_tokens_map.json" @@ -53,6 +56,10 @@ def load_tokenizer(model_dir: Path) -> tuple[Tokenizer, dict[str, int]]: if not tokenizer.padding: pad_token_id = config.get("pad_token_id") if pad_token_id is None: + logger.warning( + "pad_token_id not found in config.json for %s, defaulting to 0", + model_dir.name, + ) 
pad_token_id = 0 pad_token = tokenizer_config.get("pad_token", "") if isinstance(pad_token, dict): diff --git a/fastembed/rerank/cross_encoder/qwen3_cross_encoder.py b/fastembed/rerank/cross_encoder/qwen3_cross_encoder.py index f511064a7..18b814f2d 100644 --- a/fastembed/rerank/cross_encoder/qwen3_cross_encoder.py +++ b/fastembed/rerank/cross_encoder/qwen3_cross_encoder.py @@ -90,6 +90,7 @@ class Qwen3CrossEncoder(OnnxTextCrossEncoder): @classmethod def _list_supported_models(cls) -> list[BaseModelDescription]: + """Return the list of supported Qwen3 reranker models.""" return supported_qwen3_reranker_models # ------------------------------------------------------------------ @@ -157,11 +158,10 @@ def _onnx_embed_texts(self, texts: list[str], **kwargs: Any) -> OnnxOutputContex then concatenate the yes/no scores.""" assert self.tokenizer is not None, "Tokenizer not loaded. Call load_onnx_model() first." + input_names: set[str] = {node.name for node in self.model.get_inputs()} # type: ignore[union-attr] all_scores: list[NumpyArray] = [] for text in texts: tokenized = self.tokenizer.encode_batch([text]) - - input_names: set[str] = {node.name for node in self.model.get_inputs()} # type: ignore[union-attr] onnx_input: dict[str, NumpyArray] = { "input_ids": np.array([tokenized[0].ids], dtype=np.int64), } @@ -187,16 +187,20 @@ def _onnx_embed_texts(self, texts: list[str], **kwargs: Any) -> OnnxOutputContex # ------------------------------------------------------------------ @classmethod def _get_worker_class(cls) -> type[TextRerankerWorker]: + """Return the worker class for parallel processing.""" return Qwen3CrossEncoderWorker class Qwen3CrossEncoderWorker(TextCrossEncoderWorker): + """Worker for parallel Qwen3 reranker inference.""" + def init_embedding( self, model_name: str, cache_dir: str, **kwargs: Any, ) -> OnnxTextCrossEncoder: + """Initialise a Qwen3CrossEncoder instance for the worker.""" return Qwen3CrossEncoder( model_name=model_name, cache_dir=cache_dir, diff 
--git a/fastembed/text/qwen3_embedding.py b/fastembed/text/qwen3_embedding.py index 38408c3d5..da95c8122 100644 --- a/fastembed/text/qwen3_embedding.py +++ b/fastembed/text/qwen3_embedding.py @@ -12,6 +12,7 @@ - MRL: pass ``dim=256`` (or any value 32-1024) to ``embed()`` / ``query_embed()`` """ +import logging from collections.abc import Iterable from typing import Any @@ -46,6 +47,8 @@ DEFAULT_TASK = "Given a query, retrieve relevant documents that answer the query" QUERY_INSTRUCTION_TEMPLATE = "Instruct: {task}\nQuery: {text}" +logger = logging.getLogger(__name__) + # --------------------------------------------------------------------------- # Qwen3 embedding implementation @@ -69,11 +72,13 @@ class Qwen3TextEmbedding(OnnxTextEmbedding): @classmethod def _list_supported_models(cls) -> list[DenseModelDescription]: + """Return the list of supported Qwen3 embedding models.""" return supported_qwen3_models def _post_process_onnx_output( self, output: OnnxOutputContext, **kwargs: Any ) -> Iterable[NumpyArray]: + """Apply last-token pooling, optional MRL truncation, and L2 normalisation.""" if output.attention_mask is None: raise ValueError("attention_mask must be provided for last-token pooling") @@ -82,6 +87,9 @@ def _post_process_onnx_output( # MRL: optionally truncate to requested dimension dim: int | None = kwargs.get("dim") if dim is not None: + max_dim = embeddings.shape[-1] + if not (1 <= dim <= max_dim): + raise ValueError(f"dim must be between 1 and {max_dim}, got {dim}") embeddings = embeddings[:, :dim] return normalize(embeddings) @@ -109,6 +117,11 @@ def embed( Yields: NumpyArray: L2-normalised embeddings, one per document. 
""" + if batch_size != 1: + logger.warning( + "batch_size=%d ignored for Qwen3; causal-LM ONNX graph requires batch_size=1", + batch_size, + ) yield from self._embed_documents( model_name=self.model_name, cache_dir=str(self.cache_dir), @@ -164,16 +177,20 @@ def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyAr # ------------------------------------------------------------------ @classmethod def _get_worker_class(cls) -> type[OnnxTextEmbeddingWorker]: + """Return the worker class for parallel processing.""" return Qwen3TextEmbeddingWorker class Qwen3TextEmbeddingWorker(OnnxTextEmbeddingWorker): + """Worker for parallel Qwen3 embedding inference.""" + def init_embedding( self, model_name: str, cache_dir: str, **kwargs: Any, ) -> OnnxTextEmbedding: + """Initialise a Qwen3TextEmbedding instance for the worker.""" return Qwen3TextEmbedding( model_name=model_name, cache_dir=cache_dir, From 4c5c99752e34b29de6a8dad62a4d97fc49c1b723 Mon Sep 17 00:00:00 2001 From: n24q02m Date: Sat, 14 Feb 2026 14:18:54 +0700 Subject: [PATCH 3/6] feat: add Q4F16 model variant support for Qwen3 - Register Q4F16 variants for Qwen3-Embedding and Qwen3-Reranker - Add float16-to-float32 cast after ONNX inference for Q4F16 outputs - Fix snapshot_download cache bug: verify model_file exists in cached snapshot before returning (prevents stale cache hit when multiple variants share the same HF repo) --- fastembed/common/model_management.py | 9 +++++++- .../cross_encoder/qwen3_cross_encoder.py | 22 ++++++++++++++++--- fastembed/text/onnx_text_model.py | 5 ++++- fastembed/text/qwen3_embedding.py | 19 +++++++++++++--- 4 files changed, 47 insertions(+), 8 deletions(-) diff --git a/fastembed/common/model_management.py b/fastembed/common/model_management.py index 5301def80..90b1e0ee2 100644 --- a/fastembed/common/model_management.py +++ b/fastembed/common/model_management.py @@ -409,7 +409,7 @@ def download_model(cls, model: T, cache_dir: str, retries: int = 3, **kwargs: An try: 
cache_kwargs = deepcopy(kwargs) cache_kwargs["local_files_only"] = True - return Path( + cached_dir = Path( cls.download_files_from_huggingface( hf_source, cache_dir=cache_dir, @@ -417,6 +417,13 @@ def download_model(cls, model: T, cache_dir: str, retries: int = 3, **kwargs: An **cache_kwargs, ) ) + # Verify all required files exist in cache before returning + missing = [p for p in extra_patterns if not (cached_dir / p).exists()] + if missing: + raise FileNotFoundError( + f"Cached snapshot missing files: {missing}" + ) + return cached_dir except Exception: pass finally: diff --git a/fastembed/rerank/cross_encoder/qwen3_cross_encoder.py b/fastembed/rerank/cross_encoder/qwen3_cross_encoder.py index 18b814f2d..a6d9b744d 100644 --- a/fastembed/rerank/cross_encoder/qwen3_cross_encoder.py +++ b/fastembed/rerank/cross_encoder/qwen3_cross_encoder.py @@ -57,12 +57,25 @@ model="Qwen/Qwen3-Reranker-0.6B", description=( "Qwen3 reranker (0.6B) using causal LM yes/no scoring. " - "Multilingual, 40960 input tokens, instruction-aware, 2025 year." + "INT8 dynamic quantized. Multilingual, 40960 input tokens, " + "instruction-aware, 2025 year." ), license="apache-2.0", size_in_GB=0.57, sources=ModelSource(hf="n24q02m/Qwen3-Reranker-0.6B-ONNX"), - model_file="onnx/model.onnx", + model_file="onnx/model_quantized.onnx", + ), + BaseModelDescription( + model="Qwen/Qwen3-Reranker-0.6B-Q4F16", + description=( + "Qwen3 reranker (0.6B) using causal LM yes/no scoring. " + "INT4 weights + FP16 activations (Q4F16). Multilingual, " + "40960 input tokens, instruction-aware, 2025 year." 
+ ), + license="apache-2.0", + size_in_GB=0.57, + sources=ModelSource(hf="n24q02m/Qwen3-Reranker-0.6B-ONNX"), + model_file="onnx/model_q4f16.onnx", ), ] @@ -176,7 +189,10 @@ def _onnx_embed_texts(self, texts: list[str], **kwargs: Any) -> OnnxOutputContex onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs) outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore[union-attr] - scores = self._compute_yes_no_scores(outputs[0]) + model_output = outputs[0] + if model_output.dtype == np.float16: + model_output = model_output.astype(np.float32) + scores = self._compute_yes_no_scores(model_output) all_scores.append(scores) concatenated = np.concatenate(all_scores).astype(np.float32) diff --git a/fastembed/text/onnx_text_model.py b/fastembed/text/onnx_text_model.py index c8001a917..7593c9847 100644 --- a/fastembed/text/onnx_text_model.py +++ b/fastembed/text/onnx_text_model.py @@ -94,8 +94,11 @@ def onnx_embed( onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs) model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore[union-attr] + embeddings = model_output[0] + if embeddings.dtype == np.float16: + embeddings = embeddings.astype(np.float32) return OnnxOutputContext( - model_output=model_output[0], + model_output=embeddings, attention_mask=onnx_input.get("attention_mask", attention_mask), input_ids=onnx_input.get("input_ids", input_ids), ) diff --git a/fastembed/text/qwen3_embedding.py b/fastembed/text/qwen3_embedding.py index da95c8122..cda8966b0 100644 --- a/fastembed/text/qwen3_embedding.py +++ b/fastembed/text/qwen3_embedding.py @@ -31,13 +31,26 @@ dim=1024, description=( "Qwen3 text embedding (0.6B) with last-token pooling and MRL support " - "(32-1024 dims). Multilingual, 32768 input tokens, instruction-aware, " - "2025 year." + "(32-1024 dims). INT8 dynamic quantized. Multilingual, 32768 input tokens, " + "instruction-aware, 2025 year." 
), license="apache-2.0", size_in_GB=0.57, sources=ModelSource(hf="n24q02m/Qwen3-Embedding-0.6B-ONNX"), - model_file="onnx/model.onnx", + model_file="onnx/model_quantized.onnx", + ), + DenseModelDescription( + model="Qwen/Qwen3-Embedding-0.6B-Q4F16", + dim=1024, + description=( + "Qwen3 text embedding (0.6B) with last-token pooling and MRL support " + "(32-1024 dims). INT4 weights + FP16 activations (Q4F16). Multilingual, " + "32768 input tokens, instruction-aware, 2025 year." + ), + license="apache-2.0", + size_in_GB=0.57, + sources=ModelSource(hf="n24q02m/Qwen3-Embedding-0.6B-ONNX"), + model_file="onnx/model_q4f16.onnx", ), ] From 00e932f6b6e1ee7ef06e062c0fc17dd7060f31ce Mon Sep 17 00:00:00 2001 From: n24q02m Date: Fri, 20 Feb 2026 21:25:57 +0700 Subject: [PATCH 4/6] test: add left-padding batching tests for Qwen3 Causal LM Adds tests to ensure that left-padding and last-token pooling correctly handles batch inference without losing positional context for short strings. --- tests/test_text_cross_encoder.py | 20 ++++++++++++++++++++ tests/test_text_onnx_embeddings.py | 19 +++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/tests/test_text_cross_encoder.py b/tests/test_text_cross_encoder.py index 53cf6d144..d6ed4debf 100644 --- a/tests/test_text_cross_encoder.py +++ b/tests/test_text_cross_encoder.py @@ -150,3 +150,23 @@ def test_session_options(model_cache, model_name) -> None: model = TextCrossEncoder(model_name=model_name, enable_cpu_mem_arena=False) session_options = model.model.model.get_session_options() assert session_options.enable_cpu_mem_arena is False + + +def test_qwen3_reranker_left_padding_batch(model_cache) -> None: + '''Test to ensure Qwen3 causal logit cross encoder works reliably when left-padded in batch.''' + model_name = "Qwen/Qwen3-Reranker-0.6B" + query = "Testing Qwen" + short_doc = "This is a short doc." 
+ long_doc = "This is a significantly longer string that will force the shorter string to be padded with `` tokens on the left side during the tokenization phase. The embedding pooling must ignore these left padding tokens." + + with model_cache(model_name) as model: + # Infer short string alone + single_result = list(model.rerank(query, [short_doc]))[0].score + + # Infer short string mixed in a batch with a very long string + batch_results = list(model.rerank(query, [long_doc, short_doc])) + batch_result_short = batch_results[1].score + + # Ensure the score is exactly the same, proving causal LM logit selection is precise + import numpy as np + assert np.allclose(single_result, batch_result_short, atol=1e-4) diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index 92d9a4783..e6eabdd5a 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -220,3 +220,22 @@ def test_token_count(model_cache, model_name) -> None: doc_token_count = model.token_count(documents) assert first_doc_token_count + second_doc_token_count == doc_token_count assert doc_token_count == model.token_count(documents, batch_size=1) + + +def test_qwen3_left_padding_batch(model_cache) -> None: + '''Test to ensure causal LMs like Qwen3 properly pool from the last actual token when using left padding in a batch''' + model_name = "Qwen/Qwen3-Embedding-0.6B" + short_text = "Hello." + long_text = "This is a significantly longer string that will force the shorter string to be padded with `` tokens on the left side during the tokenization phase. The embedding pooling must ignore these left padding tokens." 
+ + with model_cache(model_name) as model: + # Infer short string alone + single_result = list(model.embed([short_text]))[0] + + # Infer short string mixed in a batch with a very long string + batch_results = list(model.embed([long_text, short_text])) + batch_result_short = batch_results[1] + + # Ensure the vector is exactly the same, proving left-padding last-token pooling is precise + import numpy as np + assert np.allclose(single_result, batch_result_short, atol=1e-4) From deaf66404683fd72821f072d7a2f1642a6b3c2be Mon Sep 17 00:00:00 2001 From: n24q02m Date: Fri, 20 Feb 2026 21:45:58 +0700 Subject: [PATCH 5/6] test: address CodeRabbit review feedback on left-padding tests --- tests/test_text_cross_encoder.py | 3 +- tests/test_text_onnx_embeddings.py | 57 +++++++++++++++++++++--------- 2 files changed, 42 insertions(+), 18 deletions(-) diff --git a/tests/test_text_cross_encoder.py b/tests/test_text_cross_encoder.py index d6ed4debf..655e6a30c 100644 --- a/tests/test_text_cross_encoder.py +++ b/tests/test_text_cross_encoder.py @@ -161,12 +161,11 @@ def test_qwen3_reranker_left_padding_batch(model_cache) -> None: with model_cache(model_name) as model: # Infer short string alone - single_result = list(model.rerank(query, [short_doc]))[0].score + single_result = next(iter(model.rerank(query, [short_doc]))) # Infer short string mixed in a batch with a very long string batch_results = list(model.rerank(query, [long_doc, short_doc])) batch_result_short = batch_results[1].score # Ensure the score is exactly the same, proving causal LM logit selection is precise - import numpy as np assert np.allclose(single_result, batch_result_short, atol=1e-4) diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index e6eabdd5a..57eed33b2 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -222,20 +222,45 @@ def test_token_count(model_cache, model_name) -> None: assert doc_token_count == model.token_count(documents, 
batch_size=1) -def test_qwen3_left_padding_batch(model_cache) -> None: - '''Test to ensure causal LMs like Qwen3 properly pool from the last actual token when using left padding in a batch''' - model_name = "Qwen/Qwen3-Embedding-0.6B" - short_text = "Hello." - long_text = "This is a significantly longer string that will force the shorter string to be padded with `` tokens on the left side during the tokenization phase. The embedding pooling must ignore these left padding tokens." + +def test_qwen3_left_padding_batch_unit() -> None: + '''Directly verify last_token_pool behavior on synthetic left-padded hidden states.''' + from fastembed.common.utils import last_token_pool + import numpy as np - with model_cache(model_name) as model: - # Infer short string alone - single_result = list(model.embed([short_text]))[0] - - # Infer short string mixed in a batch with a very long string - batch_results = list(model.embed([long_text, short_text])) - batch_result_short = batch_results[1] - - # Ensure the vector is exactly the same, proving left-padding last-token pooling is precise - import numpy as np - assert np.allclose(single_result, batch_result_short, atol=1e-4) + # Simulate a batch of 2 sequences, max length 5, hidden size 4 + # Sequence 0: [pad, pad, token1, token2, token3] -> Left padded + # Sequence 1: [token1, token2, token3, token4, token5] -> Not padded + + hidden_states = np.array([ + # Seq 0 + [ + [0,0,0,0], # pad + [0,0,0,0], # pad + [1,1,1,1], # token1 + [2,2,2,2], # token2 + [3,3,3,3], # token3 (LAST TOKEN) + ], + # Seq 1 + [ + [9,9,9,9], # token1 + [8,8,8,8], # token2 + [7,7,7,7], # token3 + [6,6,6,6], # token4 + [5,5,5,5], # token5 (LAST TOKEN) + ] + ], dtype=np.float32) + + attention_mask = np.array([ + [0, 0, 1, 1, 1], # Seq 0 mask + [1, 1, 1, 1, 1] # Seq 1 mask + ], dtype=np.int64) + + pooled = last_token_pool(hidden_states, attention_mask) + + # Expected: The vector at index 4 (the last token) for both + expected_seq0 = np.array([3,3,3,3], 
dtype=np.float32) + expected_seq1 = np.array([5,5,5,5], dtype=np.float32) + + assert np.allclose(pooled[0], expected_seq0) + assert np.allclose(pooled[1], expected_seq1) From 637d04f53f7b3d85a6113f315dbf5a7cd2ae5a2e Mon Sep 17 00:00:00 2001 From: n24q02m Date: Fri, 6 Mar 2026 20:11:03 +0700 Subject: [PATCH 6/6] fix: address CodeRabbit review feedback - Add should_test_model guard to test_qwen3_reranker_left_padding_batch to prevent unconditional model download on every CI run - Remove redundant `import numpy as np` in test_qwen3_left_padding_batch_unit (already imported at module level) Co-Authored-By: Claude Opus 4.6 --- tests/test_text_cross_encoder.py | 9 +++++++++ tests/test_text_onnx_embeddings.py | 3 +-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_text_cross_encoder.py b/tests/test_text_cross_encoder.py index 655e6a30c..f30dcbe4f 100644 --- a/tests/test_text_cross_encoder.py +++ b/tests/test_text_cross_encoder.py @@ -154,7 +154,16 @@ def test_session_options(model_cache, model_name) -> None: def test_qwen3_reranker_left_padding_batch(model_cache) -> None: '''Test to ensure Qwen3 causal logit cross encoder works reliably when left-padded in batch.''' + is_ci = os.getenv("CI") + is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" model_name = "Qwen/Qwen3-Reranker-0.6B" + + for model_desc in TextCrossEncoder._list_supported_models(): + if model_desc.model != model_name: + continue + if not should_test_model(model_desc, model_name, is_ci, is_manual): + pytest.skip(f"Skipping {model_name} (not selected for this CI run)") + query = "Testing Qwen" short_doc = "This is a short doc." long_doc = "This is a significantly longer string that will force the shorter string to be padded with `` tokens on the left side during the tokenization phase. The embedding pooling must ignore these left padding tokens." 
diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index 57eed33b2..1e8da07ee 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -226,8 +226,7 @@ def test_token_count(model_cache, model_name) -> None: def test_qwen3_left_padding_batch_unit() -> None: '''Directly verify last_token_pool behavior on synthetic left-padded hidden states.''' from fastembed.common.utils import last_token_pool - import numpy as np - + # Simulate a batch of 2 sequences, max length 5, hidden size 4 # Sequence 0: [pad, pad, token1, token2, token3] -> Left padded # Sequence 1: [token1, token2, token3, token4, token5] -> Not padded