diff --git a/CHANGELOG.md b/CHANGELOG.md index 52ea52a..00e3cfc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,8 @@ # Changelog -## 1.2.0-alpha.73 (2026-05-22) +## 1.2.0-alpha.73 (2026-05-27) -Full Changelog: [v1.2.0-alpha.72...v1.2.0-alpha.73](https://github.com/fw-ai-external/python-sdk/compare/v1.2.0-alpha.72...v1.2.0-alpha.73) - -### Chores -* autorelease single source of truth (#90) ([fc3f8c7](https://github.com/fw-ai-external/python-sdk/commit/fc3f8c72d7a72f40d95d32658488bea44e443d18)) +Full Changelog: [v1.2.0-alpha.73...v1.2.0-alpha.73](https://github.com/fw-ai-external/python-sdk/compare/v1.2.0-alpha.73...v1.2.0-alpha.73) ## 1.2.0-alpha.72 (2026-05-21) diff --git a/src/fireworks/training/sdk/__init__.py b/src/fireworks/training/sdk/__init__.py index 5084b42..edd2964 100644 --- a/src/fireworks/training/sdk/__init__.py +++ b/src/fireworks/training/sdk/__init__.py @@ -36,6 +36,7 @@ DeploymentManager, DeploymentSampler, SampledCompletion, + FiretitanSamplingClient, FixedConcurrencyController, AdaptiveConcurrencyController, ) @@ -69,6 +70,7 @@ "DeploymentInfo", "DeploymentManager", "DeploymentSampler", + "FiretitanSamplingClient", "AdaptiveConcurrencyController", "FixedConcurrencyController", "SampledCompletion", diff --git a/src/fireworks/training/sdk/deployment.py b/src/fireworks/training/sdk/deployment.py index f0cc5b9..0aea733 100644 --- a/src/fireworks/training/sdk/deployment.py +++ b/src/fireworks/training/sdk/deployment.py @@ -11,12 +11,16 @@ import random import asyncio import logging +import warnings +import threading from typing import TYPE_CHECKING, Any, List from dataclasses import dataclass +from concurrent.futures import Future as ConcurrentFuture import httpx if TYPE_CHECKING: + from tinker import types as tinker_types from transformers import PreTrainedTokenizerBase from fireworks.training.sdk.errors import ( @@ -1052,7 +1056,6 @@ def __init__( self._recent_metrics: list[ServerMetrics] = [] if max_concurrency is not None: - import 
warnings warnings.warn( "max_concurrency is deprecated and will be removed in a future release. " "Use concurrency_controller=FixedConcurrencyController(max_concurrency) " @@ -1337,9 +1340,9 @@ async def sample_with_prompt_tokens( and logprob extraction are inherited from :meth:`_do_one_completion` and :meth:`_parse_completions_result`. - ``stop`` preserves its caller-provided shape: ``list[str]`` is forwarded as - string stop sequences and ``list[int]`` as integer stop token IDs. No - coercion happens here. + ``list[str]`` stops are forwarded as string stop sequences. ``list[int]`` + stops are decoded with the sampler tokenizer before forwarding because + the completions API only accepts string stop sequences. """ if max_seq_len is not None and len(prompt_token_ids) >= max_seq_len: return [] @@ -1348,7 +1351,20 @@ async def sample_with_prompt_tokens( routing_requested = kwargs.get("include_routing_matrix", False) echo_mode = kwargs.get("echo", False) if stop is not None: - kwargs["stop"] = stop + if all(type(s) is str for s in stop): + kwargs["stop"] = stop + elif all(type(s) is int for s in stop): + if self.tokenizer is None: + raise ValueError( + "Tokenizer is required to convert integer stop token IDs " + "to string stop sequences for the completions API" + ) + kwargs["stop"] = [ + self.tokenizer.decode([token_id], skip_special_tokens=False) + for token_id in stop + ] + else: + raise ValueError("stop must be list[str] or list[int]") async def _one(_idx: int) -> List[SampledCompletion]: return await self._do_one_completion( @@ -1559,6 +1575,325 @@ def _parse_completions_result( return completions +class FiretitanSamplingClient: + """Tinker-compatible sampling wrapper backed by a ``DeploymentSampler``. + + The public surface mirrors ``tinker.lib.public_interfaces.SamplingClient`` + for sampling from an already-running Fireworks deployment. 
Methods return + ``concurrent.futures.Future`` objects for sync callers and provide async + convenience methods for coroutine users. + """ + + def __init__( + self, + deployment_sampler: DeploymentSampler, + ): + self.deployment_sampler = deployment_sampler + self._loop: asyncio.AbstractEventLoop | None = None + self._loop_thread: threading.Thread | None = None + self._loop_lock = threading.Lock() + self._closed = False + + @classmethod + def create( + cls, + *, + inference_url: str, + model: str, + api_key: str, + tokenizer: PreTrainedTokenizerBase | None = None, + concurrency_controller: "AdaptiveConcurrencyController | FixedConcurrencyController | None" = None, + max_concurrency: int | None = None, + ) -> "FiretitanSamplingClient": + """Build a Tinker-compatible client and the underlying sampler.""" + sampler = DeploymentSampler( + inference_url=inference_url, + model=model, + api_key=api_key, + tokenizer=tokenizer, + concurrency_controller=concurrency_controller, + max_concurrency=max_concurrency, + ) + return cls(sampler) + + @staticmethod + def _tinker_types(): + try: + from tinker import types + except ImportError as e: + raise ImportError( + "FiretitanSamplingClient requires the optional 'tinker' package. " + "Install fireworks-ai[training-sdk] to use this wrapper." 
+ ) from e + return types + + def _ensure_loop(self) -> asyncio.AbstractEventLoop: + if self._closed: + raise RuntimeError("FiretitanSamplingClient is closed") + + with self._loop_lock: + if ( + self._loop is not None + and self._loop_thread is not None + and self._loop_thread.is_alive() + ): + return self._loop + + loop = asyncio.new_event_loop() + started = threading.Event() + + def _run_loop() -> None: + asyncio.set_event_loop(loop) + started.set() + loop.run_forever() + + thread = threading.Thread( + target=_run_loop, + name="fireworks-sampling-client", + daemon=True, + ) + thread.start() + started.wait() + self._loop = loop + self._loop_thread = thread + return loop + + def _submit(self, coro) -> ConcurrentFuture[Any]: + loop = self._ensure_loop() + return asyncio.run_coroutine_threadsafe(coro, loop) + + @staticmethod + async def _await_concurrent_future(future: ConcurrentFuture[Any]) -> Any: + try: + from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture + except ImportError: + return await asyncio.wrap_future(future) + + return await AwaitableConcurrentFuture(future) + + @staticmethod + def _sampling_kwargs(sampling_params: "tinker_types.SamplingParams") -> dict[str, Any]: + kwargs: dict[str, Any] = {} + + top_p = getattr(sampling_params, "top_p", None) + if top_p is not None: + kwargs["top_p"] = top_p + + top_k = getattr(sampling_params, "top_k", None) + if top_k is not None and top_k >= 0: + kwargs["top_k"] = top_k + + seed = getattr(sampling_params, "seed", None) + if seed is not None: + kwargs["seed"] = seed + + return kwargs + + @staticmethod + def _stop_reason(finish_reason: str) -> str: + return "length" if finish_reason == "length" else "stop" + + async def _sample_async_impl( + self, + prompt: "tinker_types.ModelInput", + num_samples: int, + sampling_params: "tinker_types.SamplingParams", + include_prompt_logprobs: bool, + topk_prompt_logprobs: int = 0, + ) -> "tinker_types.SampleResponse": + if topk_prompt_logprobs: + raise 
NotImplementedError( + "FiretitanSamplingClient does not support topk_prompt_logprobs yet" + ) + + types = self._tinker_types() + prompt_token_ids = list(prompt.to_ints()) + max_tokens = getattr(sampling_params, "max_tokens", None) + if max_tokens is None: + max_tokens = 1024 + temperature = getattr(sampling_params, "temperature", 1.0) + stop = getattr(sampling_params, "stop", None) + + kwargs = self._sampling_kwargs(sampling_params) + kwargs["logprobs"] = True + if include_prompt_logprobs: + kwargs["echo"] = True + + completions = await self.deployment_sampler.sample_with_prompt_tokens( + prompt_token_ids, + n=num_samples, + max_tokens=max_tokens, + temperature=temperature, + stop=stop, + **kwargs, + ) + + sequences = [] + prompt_logprobs: list[float | None] | None = None + for completion in completions: + completion_tokens = completion.full_tokens[completion.prompt_len :] + completion_logprobs = completion.inference_logprobs + if completion.logprobs_echoed and completion_logprobs is not None: + prompt_logprobs = ( + [None] + completion_logprobs[: completion.prompt_len - 1] + if completion.prompt_len + else [] + ) + completion_logprobs = completion_logprobs[max(completion.prompt_len - 1, 0) :] + + sequences.append( + types.SampledSequence( + stop_reason=self._stop_reason(completion.finish_reason), + tokens=completion_tokens, + logprobs=completion_logprobs, + ) + ) + + return types.SampleResponse( + sequences=sequences, + prompt_logprobs=prompt_logprobs if include_prompt_logprobs else None, + ) + + def sample( + self, + prompt: "tinker_types.ModelInput", + num_samples: int, + sampling_params: "tinker_types.SamplingParams", + include_prompt_logprobs: bool = False, + topk_prompt_logprobs: int = 0, + ) -> ConcurrentFuture["tinker_types.SampleResponse"]: + """Generate completions with the Tinker ``SamplingClient`` signature.""" + return self._submit( + self._sample_async_impl( + prompt, + num_samples, + sampling_params, + include_prompt_logprobs, + 
topk_prompt_logprobs, + ) + ) + + async def sample_async( + self, + prompt: "tinker_types.ModelInput", + num_samples: int, + sampling_params: "tinker_types.SamplingParams", + include_prompt_logprobs: bool = False, + topk_prompt_logprobs: int = 0, + ) -> "tinker_types.SampleResponse": + """Async version of :meth:`sample`.""" + return await self._await_concurrent_future( + self.sample( + prompt, + num_samples, + sampling_params, + include_prompt_logprobs, + topk_prompt_logprobs, + ) + ) + + async def _compute_logprobs_async_impl( + self, prompt: "tinker_types.ModelInput" + ) -> list[float | None]: + types = self._tinker_types() + response = await self._sample_async_impl( + prompt, + num_samples=1, + sampling_params=types.SamplingParams(max_tokens=1, temperature=1.0, top_p=1.0), + include_prompt_logprobs=True, + ) + return list(response.prompt_logprobs or []) + + def compute_logprobs( + self, prompt: "tinker_types.ModelInput" + ) -> ConcurrentFuture[list[float | None]]: + """Compute prompt token logprobs with the Tinker client signature. + + Warning: we don't recommend using the sampling client to compute logprobs. + Please use ``training_client.forward()`` instead. + """ + warnings.warn( + "We don't recommend using the sampling client to compute logprobs. 
" + "Please use training_client.forward() instead.", + UserWarning, + stacklevel=2, + ) + return self._submit(self._compute_logprobs_async_impl(prompt)) + + async def compute_logprobs_async( + self, prompt: "tinker_types.ModelInput" + ) -> list[float | None]: + """Async version of :meth:`compute_logprobs`.""" + return await self._await_concurrent_future(self.compute_logprobs(prompt)) + + def get_tokenizer(self) -> PreTrainedTokenizerBase: + """Return the tokenizer attached to the underlying deployment sampler.""" + if self.deployment_sampler.tokenizer is None: + raise ValueError("DeploymentSampler was created without a tokenizer") + return self.deployment_sampler.tokenizer + + def get_base_model(self) -> str: + """Return the model/deployment name used by the underlying sampler.""" + return self.deployment_sampler.model + + async def get_base_model_async(self) -> str: + """Async version of :meth:`get_base_model`.""" + return self.get_base_model() + + def get_telemetry(self) -> None: + """Match Tinker's ``SamplingClient`` telemetry hook.""" + return None + + async def _aclose_sampler(self) -> None: + async_client = self.deployment_sampler._async_client + self.deployment_sampler._async_client = None + if async_client is not None and not async_client.is_closed: + await async_client.aclose() + self.deployment_sampler._sync_client.close() + + def close(self) -> None: + """Close the wrapper loop and the underlying sampler clients.""" + if self._closed: + return + self._closed = True + + loop = self._loop + thread = self._loop_thread + if loop is not None and loop.is_running(): + future = asyncio.run_coroutine_threadsafe(self._aclose_sampler(), loop) + try: + future.result(timeout=10) + except Exception: + logger.debug("Failed to close FiretitanSamplingClient cleanly", exc_info=True) + loop.call_soon_threadsafe(loop.stop) + if thread is not None and thread is not threading.current_thread(): + thread.join(timeout=10) + if thread is None or not thread.is_alive(): + 
loop.close() + else: + logger.debug( + "Skipped closing FiretitanSamplingClient loop because the loop thread " + "did not stop before the close timeout" + ) + else: + self.deployment_sampler.close() + + self._loop = None + self._loop_thread = None + + def __enter__(self) -> "FiretitanSamplingClient": + return self + + def __exit__(self, *args) -> None: + self.close() + + def __del__(self) -> None: + try: + self.close() + except Exception: + pass + + # ============================================================================= # FixedConcurrencyController — static semaphore # ============================================================================= diff --git a/src/fireworks/training/sdk/tests/test_deployment.py b/src/fireworks/training/sdk/tests/test_deployment.py index bd1b5d5..12a5a4c 100644 --- a/src/fireworks/training/sdk/tests/test_deployment.py +++ b/src/fireworks/training/sdk/tests/test_deployment.py @@ -2,6 +2,8 @@ from __future__ import annotations +import sys +import types as pytypes import asyncio from unittest.mock import MagicMock @@ -12,6 +14,7 @@ DeploymentConfig, DeploymentManager, DeploymentSampler, + FiretitanSamplingClient, AdaptiveConcurrencyController, _SSETruncationError, ) @@ -65,6 +68,61 @@ async def _fake(*args, **kwargs): sampler.async_completions_stream = _fake +class _FakeModelInput: + def __init__(self, tokens): + self._tokens = list(tokens) + + @classmethod + def from_ints(cls, tokens): + return cls(tokens) + + def to_ints(self): + return list(self._tokens) + + +class _FakeSamplingParams: + def __init__( + self, + max_tokens=None, + seed=None, + stop=None, + temperature=1, + top_k=-1, + top_p=1, + ): + self.max_tokens = max_tokens + self.seed = seed + self.stop = stop + self.temperature = temperature + self.top_k = top_k + self.top_p = top_p + + +class _FakeSampledSequence: + def __init__(self, stop_reason, tokens, logprobs=None): + self.stop_reason = stop_reason + self.tokens = tokens + self.logprobs = logprobs + + +class 
_FakeSampleResponse: + def __init__(self, sequences, prompt_logprobs=None): + self.sequences = sequences + self.prompt_logprobs = prompt_logprobs + + +@pytest.fixture +def fake_tinker(monkeypatch): + fake_types = pytypes.SimpleNamespace( + ModelInput=_FakeModelInput, + SamplingParams=_FakeSamplingParams, + SampledSequence=_FakeSampledSequence, + SampleResponse=_FakeSampleResponse, + ) + monkeypatch.setitem(sys.modules, "tinker", pytypes.SimpleNamespace(types=fake_types)) + return fake_types + + # --------------------------------------------------------------------------- # _should_verify_ssl # --------------------------------------------------------------------------- @@ -722,6 +780,195 @@ async def _fake_stream(*args, **kwargs): sampler.close() +class TestFiretitanSamplingClient: + def test_sample_returns_tinker_response(self, fake_tinker): + prompt_ids = [10, 20, 30] + completion_ids = [40, 50] + sampler = _make_sampler(tokenizer=None) + captured = {} + + async def _fake_stream(*args, **kwargs): + captured.update(kwargs) + return { + "choices": [{ + "text": "out", + "finish_reason": "stop", + "raw_output": {"completion_token_ids": completion_ids}, + "logprobs": {"content": [ + {"logprob": -0.3}, + {"logprob": -0.4}, + ]}, + }] + }, ServerMetrics() + + sampler.async_completions_stream = _fake_stream + client = FiretitanSamplingClient(sampler) + try: + response = client.sample( + prompt=fake_tinker.ModelInput.from_ints(prompt_ids), + num_samples=1, + sampling_params=fake_tinker.SamplingParams( + max_tokens=2, + stop=[99], + temperature=0.7, + top_p=0.9, + top_k=10, + seed=123, + ), + ).result(timeout=5) + finally: + client.close() + + assert len(response.sequences) == 1 + assert response.sequences[0].tokens == completion_ids + assert response.sequences[0].logprobs == [-0.3, -0.4] + assert response.sequences[0].stop_reason == "stop" + assert response.prompt_logprobs is None + assert captured["prompt"] == prompt_ids + assert captured["max_tokens"] == 2 + assert 
captured["temperature"] == 0.7 + assert captured["stop"] == [99] + assert captured["top_p"] == 0.9 + assert captured["top_k"] == 10 + assert captured["seed"] == 123 + assert captured["logprobs"] is True + + def test_sample_splits_echo_prompt_logprobs(self, fake_tinker): + prompt_ids = [10, 20, 30] + completion_ids = [40, 50] + sampler = _make_sampler(tokenizer=None) + captured = {} + + async def _fake_stream(*args, **kwargs): + captured.update(kwargs) + return { + "choices": [{ + "text": "out", + "finish_reason": "length", + "raw_output": {"completion_token_ids": prompt_ids + completion_ids}, + "logprobs": {"content": [ + {"logprob": 0.0}, + {"logprob": -0.1}, + {"logprob": -0.2}, + {"logprob": -0.3}, + {"logprob": -0.4}, + ]}, + }] + }, ServerMetrics() + + sampler.async_completions_stream = _fake_stream + client = FiretitanSamplingClient(sampler) + try: + response = client.sample( + prompt=fake_tinker.ModelInput.from_ints(prompt_ids), + num_samples=1, + sampling_params=fake_tinker.SamplingParams(max_tokens=2), + include_prompt_logprobs=True, + ).result(timeout=5) + finally: + client.close() + + assert captured["echo"] is True + assert response.prompt_logprobs == [None, -0.1, -0.2] + assert response.sequences[0].tokens == completion_ids + assert response.sequences[0].logprobs == [-0.3, -0.4] + assert response.sequences[0].stop_reason == "length" + + def test_compute_logprobs_uses_prompt_logprobs(self, fake_tinker): + prompt_ids = [10, 20, 30] + sampler = _make_sampler(tokenizer=None) + + async def _fake_stream(*args, **kwargs): + return { + "choices": [{ + "text": "x", + "finish_reason": "length", + "raw_output": {"completion_token_ids": prompt_ids + [40]}, + "logprobs": {"content": [ + {"logprob": 0.0}, + {"logprob": -0.1}, + {"logprob": -0.2}, + {"logprob": -0.3}, + ]}, + }] + }, ServerMetrics() + + sampler.async_completions_stream = _fake_stream + client = FiretitanSamplingClient(sampler) + try: + with pytest.warns(UserWarning, match="training_client.forward"): 
+ logprobs = client.compute_logprobs( + fake_tinker.ModelInput.from_ints(prompt_ids) + ).result(timeout=5) + finally: + client.close() + + assert logprobs == [None, -0.1, -0.2] + + def test_topk_prompt_logprobs_is_explicitly_unsupported(self, fake_tinker): + sampler = _make_sampler(tokenizer=None) + client = FiretitanSamplingClient(sampler) + try: + future = client.sample( + prompt=fake_tinker.ModelInput.from_ints([1, 2]), + num_samples=1, + sampling_params=fake_tinker.SamplingParams(max_tokens=1), + topk_prompt_logprobs=1, + ) + with pytest.raises(NotImplementedError, match="topk_prompt_logprobs"): + future.result(timeout=5) + finally: + client.close() + + def test_close_does_not_close_loop_if_thread_join_times_out(self, monkeypatch): + sampler = _make_sampler(tokenizer=None) + client = FiretitanSamplingClient(sampler) + + class _FakeLoop: + closed = False + + def is_running(self): + return True + + def call_soon_threadsafe(self, callback, *args): + return None + + def stop(self): + return None + + def close(self): + self.closed = True + raise RuntimeError("Cannot close a running event loop") + + class _FakeThread: + joined = False + + def is_alive(self): + return True + + def join(self, timeout=None): + self.joined = True + + class _FakeFuture: + def result(self, timeout=None): + raise TimeoutError + + def _run_coroutine_threadsafe(coro, loop): + coro.close() + return _FakeFuture() + + fake_loop = _FakeLoop() + fake_thread = _FakeThread() + client._loop = fake_loop + client._loop_thread = fake_thread + monkeypatch.setattr(asyncio, "run_coroutine_threadsafe", _run_coroutine_threadsafe) + + client.close() + + assert fake_thread.joined is True + assert fake_loop.closed is False + + # --------------------------------------------------------------------------- # Regression: legacy sample_with_tokens(messages=...) 
still calls apply_chat_template # --------------------------------------------------------------------------- diff --git a/src/fireworks/training/sdk/tests/test_firetitan_tinker_compat.py b/src/fireworks/training/sdk/tests/test_firetitan_tinker_compat.py new file mode 100644 index 0000000..400bfde --- /dev/null +++ b/src/fireworks/training/sdk/tests/test_firetitan_tinker_compat.py @@ -0,0 +1,435 @@ +"""Compatibility tests: ``FiretitanSamplingClient`` vs ``tinker.SamplingClient``. + +The other tests in this directory drive ``FiretitanSamplingClient`` against +*fake* tinker types (see the ``fake_tinker`` fixture). That proves behavior but +says nothing about whether the wrapper is actually drop-in compatible with the +*real* tinker client. This module closes that gap with three layers: + +1. ``TestInterfaceCompat`` — static, no network, no creds. Introspects the real + ``tinker.lib.public_interfaces.SamplingClient`` and asserts the Firetitan + wrapper exposes every method of the shared sampling surface with a + signature-compatible definition. + +2. ``TestFormatCompat`` — exercises ``FiretitanSamplingClient`` against a mocked + deployment using the *real* ``tinker.types``, then asserts the returned + objects are genuine ``tinker.SampleResponse`` / ``SampledSequence`` instances + whose structural *shape* matches a hand-built tinker reference object. + +3. ``TestLiveFormatCompat`` — runs both real clients end-to-end and diffs their + shapes side by side. Skipped unless ``TINKER_API_KEY`` and + ``FIREWORKS_API_KEY`` are set (and ``FIRETITAN_LIVE=1`` to opt in). + +The reusable comparison logic lives in :class:`SamplingFormatComparator`, which +is import-safe to reuse from notebooks/scripts. 
+ +Run just this file: + pytest src/fireworks/training/sdk/tests/test_firetitan_tinker_compat.py -v +Run the live layer too: + FIRETITAN_LIVE=1 pytest -k LiveFormatCompat -v +""" + +from __future__ import annotations + +import os +import inspect +from typing import Any + +import pytest + +from fireworks.training.sdk.deployment import ( + ServerMetrics, + DeploymentSampler, + FiretitanSamplingClient, +) + +# The whole module is meaningless without the real tinker package; skip cleanly +# (e.g. on a CI image that doesn't install the optional 'tinker' extra). +tinker = pytest.importorskip("tinker", reason="real tinker package required for compat checks") +from tinker import types as tinker_types # noqa: E402 + +# Skip the entire module unless both API credentials are present. Even the +# static/mocked layers are gated on this so the file is a no-op in environments +# without Tinker/Fireworks access (e.g. CI without secrets configured). +pytestmark = pytest.mark.skipif( + not (os.environ.get("TINKER_API_KEY") and os.environ.get("FIREWORKS_API_KEY")), + reason="TINKER_API_KEY and FIREWORKS_API_KEY required for tinker/fireworks compat checks", +) + +# The public sampling surface the wrapper promises to mirror. tinker-internal +# concerns (pickling, queue-state callbacks, retry handlers, the differing +# `create` constructor) are deliberately excluded. 
+SHARED_SURFACE = ( + "sample", + "sample_async", + "compute_logprobs", + "compute_logprobs_async", + "get_tokenizer", + "get_base_model", + "get_base_model_async", + "get_telemetry", +) + + +def _tinker_sampling_client_cls() -> type: + from tinker.lib.public_interfaces.sampling_client import SamplingClient + + return SamplingClient + + +# --------------------------------------------------------------------------- +# Reusable structural comparator (also handy from notebooks) +# --------------------------------------------------------------------------- + + +class SamplingFormatComparator: + """Compare the *structure* (not the values) of sampling return objects. + + ``shape_of`` reduces an arbitrary return value to a canonical fingerprint: + nested types, field names, element types and None-vs-present distinctions — + everything that matters for format compatibility, nothing that depends on + the actual sampled tokens or logprob magnitudes. + """ + + @staticmethod + def shape_of(obj: Any, _depth: int = 0) -> Any: + if obj is None: + return "None" + if isinstance(obj, bool): + return "bool" + if isinstance(obj, int): + return "int" + if isinstance(obj, float): + return "float" + if isinstance(obj, str): + return "str" + if isinstance(obj, (list, tuple)): + kind = type(obj).__name__ + if not obj: + return f"{kind}[empty]" + # Collapse element shapes; flag whether any element is None so that + # `[None, -0.1, -0.2]` (prompt_logprobs) is distinguishable from a + # pure `list[float]`. Dedup via repr because element shapes may be + # dicts (e.g. a list of SampledSequence), which are unhashable. 
+ has_none = any(e is None for e in obj) + seen: dict[str, Any] = {} + for e in obj: + if e is None: + continue + s = SamplingFormatComparator.shape_of(e, _depth + 1) + seen.setdefault(repr(s), s) + non_none = [seen[k] for k in sorted(seen)] + if not non_none: + inner: Any = "None" + elif len(non_none) == 1: + inner = non_none[0] + else: + inner = non_none # heterogeneous list: keep all variant shapes + if has_none: + inner = {"Optional": inner} + return {"__list__": kind, "elem": inner} + # pydantic model or dataclass: descend into named fields. + fields = SamplingFormatComparator._field_names(obj) + if fields is not None: + qualname = f"{type(obj).__module__.split('.')[0]}.{type(obj).__name__}" + return { + "__type__": qualname, + **{ + f: SamplingFormatComparator.shape_of(getattr(obj, f), _depth + 1) + for f in fields + }, + } + return type(obj).__name__ + + @staticmethod + def _field_names(obj: Any) -> list[str] | None: + # pydantic v2 + mf = getattr(type(obj), "model_fields", None) + if isinstance(mf, dict): + return list(mf.keys()) + # dataclass + df = getattr(type(obj), "__dataclass_fields__", None) + if isinstance(df, dict): + return list(df.keys()) + return None + + @classmethod + def diff(cls, expected: Any, actual: Any, path: str = "") -> list[str]: + """Return human-readable mismatches between two shapes (empty == match).""" + exp_shape = cls.shape_of(expected) if not cls._is_shape(expected) else expected + act_shape = cls.shape_of(actual) if not cls._is_shape(actual) else actual + return cls._diff_shapes(exp_shape, act_shape, path) + + @staticmethod + def _is_shape(x: Any) -> bool: + return isinstance(x, str) or (isinstance(x, dict) and "__type__" in x) + + @classmethod + def _diff_shapes(cls, exp: Any, act: Any, path: str) -> list[str]: + loc = path or "" + if isinstance(exp, dict) and isinstance(act, dict): + problems: list[str] = [] + for key in sorted(set(exp) | set(act)): + if key not in act: + problems.append(f"{loc}: missing field {key!r}") + 
elif key not in exp: + problems.append(f"{loc}: unexpected field {key!r}") + else: + problems.extend( + cls._diff_shapes(exp[key], act[key], f"{path}.{key}" if path else key) + ) + return problems + if exp != act: + return [f"{loc}: expected {exp!r}, got {act!r}"] + return [] + + +# --------------------------------------------------------------------------- +# Helpers: build a mocked Firetitan client backed by a canned completion +# --------------------------------------------------------------------------- + + +def _firetitan_with_completion(choice: dict, captured: dict | None = None) -> FiretitanSamplingClient: + """A FiretitanSamplingClient whose deployment returns a single canned choice.""" + sampler = DeploymentSampler( + inference_url="https://api.example.com", + model="accounts/fireworks/models/gpt-oss-20b", + api_key="key", + tokenizer=None, + ) + + async def _fake_stream(*args, **kwargs): + if captured is not None: + captured.update(kwargs) + return {"choices": [choice]}, ServerMetrics() + + sampler.async_completions_stream = _fake_stream + return FiretitanSamplingClient(sampler) + + +def _completion(completion_ids, logprobs=None, finish_reason="stop"): + choice: dict[str, Any] = { + "text": "out", + "finish_reason": finish_reason, + "raw_output": {"completion_token_ids": completion_ids}, + } + if logprobs is not None: + choice["logprobs"] = {"content": [{"logprob": lp} for lp in logprobs]} + return choice + + +# --------------------------------------------------------------------------- +# 1. 
Interface compatibility (static) +# --------------------------------------------------------------------------- + + +class TestInterfaceCompat: + @pytest.mark.parametrize("name", SHARED_SURFACE) + def test_method_present_and_callable(self, name): + attr = getattr(FiretitanSamplingClient, name, None) + assert attr is not None, f"FiretitanSamplingClient is missing {name!r}" + assert callable(attr), f"{name!r} is not callable" + + @pytest.mark.parametrize("name", SHARED_SURFACE) + def test_signature_matches_tinker(self, name): + tinker_cls = _tinker_sampling_client_cls() + if not hasattr(tinker_cls, name): + pytest.skip(f"tinker SamplingClient has no {name!r} in this version") + + tinker_params = _public_params(getattr(tinker_cls, name)) + fire_params = _public_params(getattr(FiretitanSamplingClient, name)) + + # Same parameter names, in the same order. + assert [p.name for p in fire_params] == [p.name for p in tinker_params], ( + f"{name}: parameter names/order differ\n" + f" tinker: {[p.name for p in tinker_params]}\n" + f" firetitan: {[p.name for p in fire_params]}" + ) + # Same defaults (so callers can rely on identical optional behavior). 
+            for tp, fp in zip(tinker_params, fire_params):
+                assert fp.default == tp.default, (
+                    f"{name}: default for {tp.name!r} differs "
+                    f"(tinker={tp.default!r}, firetitan={fp.default!r})"
+                )
+
+    def test_async_methods_are_coroutines(self):
+        for name in ("sample_async", "compute_logprobs_async", "get_base_model_async"):
+            assert inspect.iscoroutinefunction(getattr(FiretitanSamplingClient, name)), (
+                f"{name} should be a coroutine function, matching tinker"
+            )
+
+
+def _public_params(func) -> list[inspect.Parameter]:
+    """Positional/keyword params excluding self and any *args/**kwargs."""
+    params = list(inspect.signature(func).parameters.values())
+    out = []
+    for p in params:
+        if p.name == "self":
+            continue
+        if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD):
+            continue
+        out.append(p)
+    return out
+
+
+# ---------------------------------------------------------------------------
+# 2. Format compatibility (mocked deployment, real tinker.types)
+# ---------------------------------------------------------------------------
+
+
+class TestFormatCompat:
+    def test_sample_response_is_real_tinker_type(self):
+        client = _firetitan_with_completion(_completion([40, 50], logprobs=[-0.3, -0.4]))
+        try:
+            resp = client.sample(
+                prompt=tinker_types.ModelInput.from_ints([10, 20, 30]),
+                num_samples=1,
+                sampling_params=tinker_types.SamplingParams(max_tokens=2, temperature=0.7),
+            ).result(timeout=5)
+        finally:
+            client.close()
+
+        assert isinstance(resp, tinker_types.SampleResponse)
+        assert isinstance(resp.sequences[0], tinker_types.SampledSequence)
+
+    def test_sample_shape_matches_tinker_reference(self):
+        """Firetitan's SampleResponse must have the same shape as a hand-built
+        tinker one constructed from equivalent data."""
+        client = _firetitan_with_completion(_completion([40, 50], logprobs=[-0.3, -0.4]))
+        try:
+            actual = client.sample(
+                prompt=tinker_types.ModelInput.from_ints([10, 20, 30]),
+                num_samples=1,
+                sampling_params=tinker_types.SamplingParams(max_tokens=2),
+            ).result(timeout=5)
+        finally:
+            client.close()
+
+        reference = tinker_types.SampleResponse(
+            sequences=[
+                tinker_types.SampledSequence(
+                    stop_reason="stop", tokens=[40, 50], logprobs=[-0.3, -0.4]
+                )
+            ],
+            prompt_logprobs=None,
+        )
+
+        problems = SamplingFormatComparator.diff(reference, actual)
+        assert not problems, "shape mismatch vs tinker reference:\n " + "\n ".join(problems)
+
+    def test_prompt_logprobs_shape_matches(self):
+        """include_prompt_logprobs=True must yield the tinker
+        list[Optional[float]] shape with a leading None for the first token."""
+        prompt_ids = [10, 20, 30]
+        echoed = prompt_ids + [40, 50]
+        logprobs = [0.0, -0.1, -0.2, -0.3, -0.4]
+        client = _firetitan_with_completion(
+            _completion(echoed, logprobs=logprobs, finish_reason="length")
+        )
+        try:
+            actual = client.sample(
+                prompt=tinker_types.ModelInput.from_ints(prompt_ids),
+                num_samples=1,
+                sampling_params=tinker_types.SamplingParams(max_tokens=2),
+                include_prompt_logprobs=True,
+            ).result(timeout=5)
+        finally:
+            client.close()
+
+        reference = tinker_types.SampleResponse(
+            sequences=[
+                tinker_types.SampledSequence(
+                    stop_reason="length", tokens=[40, 50], logprobs=[-0.3, -0.4]
+                )
+            ],
+            prompt_logprobs=[None, -0.1, -0.2],
+        )
+        problems = SamplingFormatComparator.diff(reference, actual)
+        assert not problems, "prompt_logprobs shape mismatch:\n " + "\n ".join(problems)
+
+    def test_compute_logprobs_shape_matches(self):
+        prompt_ids = [10, 20, 30]
+        client = _firetitan_with_completion(
+            _completion(prompt_ids + [40], logprobs=[0.0, -0.1, -0.2, -0.3], finish_reason="length")
+        )
+        try:
+            actual = client.compute_logprobs(
+                tinker_types.ModelInput.from_ints(prompt_ids)
+            ).result(timeout=5)
+        finally:
+            client.close()
+
+        # tinker contract: list[float | None]
+        reference: list[float | None] = [None, -0.1, -0.2]
+        problems = SamplingFormatComparator.diff(reference, actual)
+        assert not problems, "compute_logprobs shape mismatch:\n " + "\n ".join(problems)
+
+    def test_num_samples_produces_list_of_sequences(self):
+        client = _firetitan_with_completion(_completion([40, 50], logprobs=[-0.3, -0.4]))
+        try:
+            # The mocked stream returns one choice per call regardless of n, so we
+            # only assert the tinker container contract: a non-empty list whose
+            # elements are all SampledSequence (all() alone passes on an empty list).
+            resp = client.sample(
+                prompt=tinker_types.ModelInput.from_ints([10, 20, 30]),
+                num_samples=3,
+                sampling_params=tinker_types.SamplingParams(max_tokens=2),
+            ).result(timeout=5)
+        finally:
+            client.close()
+
+        assert isinstance(resp.sequences, list) and resp.sequences, "expected at least one sequence"
+        assert all(isinstance(s, tinker_types.SampledSequence) for s in resp.sequences)
+
+
+# ---------------------------------------------------------------------------
+# 3. Live cross-check (both real clients) — opt-in
+# ---------------------------------------------------------------------------
+
+_LIVE_REASON = (
+    "set FIRETITAN_LIVE=1, TINKER_API_KEY and FIREWORKS_API_KEY to run live compat"
+)
+
+
+@pytest.mark.skipif(
+    not (
+        os.environ.get("FIRETITAN_LIVE") == "1"
+        and os.environ.get("TINKER_API_KEY")
+        and os.environ.get("FIREWORKS_API_KEY")
+    ),
+    reason=_LIVE_REASON,
+)
+class TestLiveFormatCompat:
+    TINKER_MODEL = os.environ.get("TINKER_COMPAT_MODEL", "Qwen/Qwen3-4B-Instruct-2507")
+    FIREWORKS_MODEL = os.environ.get(
+        "FIREWORKS_COMPAT_MODEL", "accounts/fireworks/models/gpt-oss-20b"
+    )
+
+    def test_sample_shapes_match_live(self):
+        params = tinker_types.SamplingParams(max_tokens=8, temperature=0.7)
+
+        # --- real tinker ---
+        sc = tinker.ServiceClient()
+        tinker_client = sc.create_sampling_client(base_model=self.TINKER_MODEL)
+        tok = tinker_client.get_tokenizer()
+        t_prompt = tinker_types.ModelInput.from_ints(tok.encode("The weather today is"))
+        t_resp = tinker_client.sample(
+            prompt=t_prompt, sampling_params=params, num_samples=2
+        ).result(timeout=120)
+
+        # --- firetitan ---
+        f_client = FiretitanSamplingClient.create(
+            inference_url="https://api.fireworks.ai",
+            model=self.FIREWORKS_MODEL,
+            api_key=os.environ["FIREWORKS_API_KEY"],
+            tokenizer=None,
+        )
+        try:
+            f_prompt = tinker_types.ModelInput.from_ints([3923, 374, 220, 17, 10, 17, 30])
+            f_resp = f_client.sample(
+                prompt=f_prompt, sampling_params=params, num_samples=2
+            ).result(timeout=120)
+        finally:
+            f_client.close()
+
+        problems = SamplingFormatComparator.diff(t_resp, f_resp)
+        assert not problems, "live shape mismatch:\n " + "\n ".join(problems)
diff --git a/src/fireworks/training/sdk/tests/test_weight_syncer.py b/src/fireworks/training/sdk/tests/test_weight_syncer.py
index 56782dc..9d2f200 100644
--- a/src/fireworks/training/sdk/tests/test_weight_syncer.py
+++ b/src/fireworks/training/sdk/tests/test_weight_syncer.py
@@ -127,6 +127,31 @@ def test_existing_snapshot_clears_base_identity(self):
         assert t.base_identity is None
 
 
+# ---------------------------------------------------------------------------
+# Sampling client helpers
+# ---------------------------------------------------------------------------
+
+
+class TestGetSamplingClient:
+    def test_returns_firetitan_sampling_client_for_deployment(self):
+        """get_sampling_client() wraps a sampler configured from the deploy manager."""
+        deploy = _make_deploy_mgr()
+        deploy.inference_url = "https://api.fireworks.ai"
+        deploy.api_key = "test-key"
+        t = _make_tracker(deploy_mgr=deploy, deployment_id="dep-123")
+
+        from fireworks.training.sdk.deployment import FiretitanSamplingClient
+
+        client = t.get_sampling_client()
+        try:
+            assert isinstance(client, FiretitanSamplingClient)
+            assert client.deployment_sampler.base_url == "https://api.fireworks.ai"
+            assert client.deployment_sampler.model == "accounts/test-acct/deployments/dep-123"
+            assert client.deployment_sampler.api_key == "test-key"
+        finally:
+            client.close()
+
+
 # ---------------------------------------------------------------------------
 # Warmup model resolution
 # ---------------------------------------------------------------------------
diff --git a/src/fireworks/training/sdk/weight_syncer.py b/src/fireworks/training/sdk/weight_syncer.py
index 12bb6f3..bc567d4 100644
--- a/src/fireworks/training/sdk/weight_syncer.py
+++ b/src/fireworks/training/sdk/weight_syncer.py
@@ -34,14 +34,19 @@
 
 import time
 import logging
 from typing import TYPE_CHECKING
 from dataclasses import field, dataclass
 
 from fireworks.training.sdk.client import FiretitanTrainingClient
-from fireworks.training.sdk.deployment import DEFAULT_CHECKSUM_FORMAT
+from fireworks.training.sdk.deployment import (
+    DEFAULT_CHECKSUM_FORMAT,
+    DeploymentManager,
+    DeploymentSampler,
+    FiretitanSamplingClient,
+)
 
 if TYPE_CHECKING:
-    from fireworks.training.sdk.deployment import DeploymentManager
+    from transformers import PreTrainedTokenizerBase
 
 logger = logging.getLogger(__name__)
 
@@ -480,11 +485,27 @@ def save_and_hotload(self, name: str, checkpoint_type: str | None = None) -> str
             )
             raise
 
-    def get_deployment_sampler(self):
-        """Get the deployment's current sampler"""
-        from fireworks.training.sdk.deployment import DeploymentSampler
+    def get_deployment_sampler(self, tokenizer: "PreTrainedTokenizerBase | None" = None) -> DeploymentSampler:
+        """Build a ``DeploymentSampler`` for this deployment.
+
+        Args:
+            tokenizer: Optional tokenizer forwarded to the sampler, which needs
+                one to decode integer stop token IDs into stop strings.
+        """
         return DeploymentSampler(
             inference_url=self.deploy_mgr.inference_url,
             model=self._get_model(),
             api_key=self.deploy_mgr.api_key,
+            tokenizer=tokenizer,
         )
+
+    def get_sampling_client(self, tokenizer: "PreTrainedTokenizerBase | None" = None) -> FiretitanSamplingClient:
+        """Build a Tinker-compatible sampling client for this deployment.
+
+        The client wraps a newly built ``DeploymentSampler``; call
+        ``close()`` on it when finished.
+
+        Args:
+            tokenizer: Optional tokenizer forwarded to the underlying sampler.
+        """
+        return FiretitanSamplingClient(self.get_deployment_sampler(tokenizer))