From ac79a4670a8003029a81c9b361d402b81b290335 Mon Sep 17 00:00:00 2001 From: inaniloquentee <3051000145@qq.com> Date: Sat, 27 Jun 2026 20:38:51 +0800 Subject: [PATCH] Add RL-Kernel linear_logp integration with TP2 benchmark config --- scripts/run-qwen3-30B-A3B.sh | 113 ++++- tests/_unit_stubs.py | 89 +--- tests/test_rl_kernel_args.py | 83 ++++ .../test_rl_kernel_linear_logp_integration.py | 423 ++++++++++++++++++ tests/test_rl_kernel_logp_integration.py | 185 ++++++++ vime-RLK.md | 312 +++++++++++++ vime/backends/megatron_utils/loss.py | 159 +++++-- vime/backends/megatron_utils/model.py | 83 +++- vime/backends/megatron_utils/rl_kernel.py | 319 +++++++++++++ vime/ray/actor_group.py | 1 + vime/utils/arguments.py | 32 ++ vime/utils/rl_kernel.py | 79 ++++ 12 files changed, 1733 insertions(+), 145 deletions(-) create mode 100644 tests/test_rl_kernel_args.py create mode 100644 tests/test_rl_kernel_linear_logp_integration.py create mode 100644 tests/test_rl_kernel_logp_integration.py create mode 100644 vime-RLK.md create mode 100644 vime/backends/megatron_utils/rl_kernel.py create mode 100644 vime/utils/rl_kernel.py diff --git a/scripts/run-qwen3-30B-A3B.sh b/scripts/run-qwen3-30B-A3B.sh index 83dc0c6d..ac694315 100644 --- a/scripts/run-qwen3-30B-A3B.sh +++ b/scripts/run-qwen3-30B-A3B.sh @@ -1,7 +1,7 @@ #!/bin/bash # for rerun the task -pkill -9 vllm +pkill -9 -f "vllm serve" sleep 3 ray stop --force pkill -9 ray @@ -24,17 +24,53 @@ else fi echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" +if command -v nvidia-smi >/dev/null 2>&1; then + DETECTED_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l | tr -d ' ') + DETECTED_GPU_NAME=$(nvidia-smi --query-gpu=name --format=csv,noheader 2>/dev/null | head -n 1) +else + DETECTED_GPUS=0 + DETECTED_GPU_NAME="unknown" +fi +NUM_GPUS=${NUM_GPUS:-8} +if [ -z "$NUM_GPUS" ] || [ "$NUM_GPUS" -le 0 ]; then + NUM_GPUS=8 +fi +if [ "$DETECTED_GPUS" -gt 0 ] && [ "$NUM_GPUS" -gt "$DETECTED_GPUS" ]; then + echo "Requested NUM_GPUS=$NUM_GPUS but only detected $DETECTED_GPUS GPUs" >&2 + exit 1 +fi +echo "BENCHMARK_GPU: ${DETECTED_GPU_NAME}" +echo "NUM_GPUS: $NUM_GPUS" + SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +VIME_ROOT="$(cd -- "${SCRIPT_DIR}/.." &>/dev/null && pwd)" source "${SCRIPT_DIR}/models/qwen3-30B-A3B.sh" +MEGATRON_TP=${MEGATRON_TP:-2} +MEGATRON_EP=${MEGATRON_EP:-${NUM_GPUS}} +MEGATRON_CP=${MEGATRON_CP:-1} +MAX_TOKENS_PER_GPU=${MAX_TOKENS_PER_GPU:-20480} +NUM_ROLLOUT=${NUM_ROLLOUT:-3000} +ROLLOUT_BATCH_SIZE=${ROLLOUT_BATCH_SIZE:-32} +N_SAMPLES_PER_PROMPT=${N_SAMPLES_PER_PROMPT:-8} +ROLLOUT_MAX_RESPONSE_LEN=${ROLLOUT_MAX_RESPONSE_LEN:-8192} +GLOBAL_BATCH_SIZE=${GLOBAL_BATCH_SIZE:-$((ROLLOUT_BATCH_SIZE * N_SAMPLES_PER_PROMPT))} +ROLLOUT_NUM_GPUS_PER_ENGINE=${ROLLOUT_NUM_GPUS_PER_ENGINE:-${NUM_GPUS}} +VLLM_GPU_MEMORY_UTILIZATION=${VLLM_GPU_MEMORY_UTILIZATION:-0.7} +VIME_CKPT_DIR=${VIME_CKPT_DIR:-/root/Qwen3-30B-A3B_vime} + CKPT_ARGS=( --hf-checkpoint /root/Qwen3-30B-A3B #--hf-checkpoint /root/Qwen3-30B-A3B-FP8 --ref-load /root/Qwen3-30B-A3B_torch_dist - --load /root/Qwen3-30B-A3B_vime/ - --save /root/Qwen3-30B-A3B_vime/ - --save-interval 20 + --load "${VIME_CKPT_DIR}/" ) +if [[ "${VIME_DISABLE_SAVE:-0}" != "1" ]]; then + CKPT_ARGS+=( + --save "${VIME_CKPT_DIR}/" + --save-interval "${VIME_SAVE_INTERVAL:-20}" + ) +fi ROLLOUT_ARGS=( --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl @@ -43,13 +79,13 @@ ROLLOUT_ARGS=( --apply-chat-template --rollout-shuffle --rm-type deepscaler - --num-rollout 3000 - --rollout-batch-size 32 - --n-samples-per-prompt 8 - --rollout-max-response-len 8192 + --num-rollout "${NUM_ROLLOUT}" + --rollout-batch-size "${ROLLOUT_BATCH_SIZE}" + --n-samples-per-prompt "${N_SAMPLES_PER_PROMPT}" + --rollout-max-response-len "${ROLLOUT_MAX_RESPONSE_LEN}" --rollout-temperature 1 - --global-batch-size 256 + --global-batch-size "${GLOBAL_BATCH_SIZE}" --balance-data ) @@ -60,13 +96,16 @@ EVAL_ARGS=( --eval-max-response-len 16384 --eval-top-p 1 ) +if [[ "${VIME_SKIP_EVAL_BEFORE_TRAIN:-0}" == "1" ]]; then + EVAL_ARGS+=(--skip-eval-before-train) +fi PERF_ARGS=( - --tensor-model-parallel-size 4 + --tensor-model-parallel-size "${MEGATRON_TP}" --sequence-parallel --pipeline-model-parallel-size 1 - --context-parallel-size 1 - --expert-model-parallel-size 8 + --context-parallel-size "${MEGATRON_CP}" + --expert-model-parallel-size "${MEGATRON_EP}" --expert-tensor-parallel-size 1 --recompute-granularity full @@ -75,8 +114,11 @@ PERF_ARGS=( # --micro-batch-size 1 --use-dynamic-batch-size - --max-tokens-per-gpu 20480 + --max-tokens-per-gpu "${MAX_TOKENS_PER_GPU}" ) +if [[ "${VIME_NO_GRAD_ACCUM_FUSION:-0}" == "1" ]]; then + PERF_ARGS+=(--no-gradient-accumulation-fusion) +fi GRPO_ARGS=( --advantage-estimator grpo @@ -108,11 +150,24 @@ WANDB_ARGS=( # --wandb-key ${WANDB_KEY} ) +TB_ARGS=() +if [[ "${VIME_TENSORBOARD:-0}" == "1" ]]; then + export TENSORBOARD_DIR="${TENSORBOARD_DIR:-${VIME_ROOT}/tensorboard_log/${TB_EXPERIMENT_NAME:-qwen3-30B-A3B}}" + TB_ARGS+=(--use-tensorboard) + TB_ARGS+=(--tb-project-name "${TB_PROJECT_NAME:-vime-rlk}") + TB_ARGS+=(--tb-experiment-name "${TB_EXPERIMENT_NAME:-qwen3-30B-A3B}") +fi + VLLM_ARGS=( - --rollout-num-gpus-per-engine 8 - --vllm-gpu-memory-utilization 0.7 - --vllm-cudagraph-capture-sizes 1 2 4 8 $(seq 16 8 256) + --rollout-num-gpus-per-engine "${ROLLOUT_NUM_GPUS_PER_ENGINE}" + --vllm-gpu-memory-utilization "${VLLM_GPU_MEMORY_UTILIZATION}" + --vllm-enable-expert-parallel ) +if [[ "${VIME_VLLM_ENFORCE_EAGER:-0}" == "1" ]]; then + VLLM_ARGS+=(--vllm-enforce-eager) +else + VLLM_ARGS+=(--vllm-cudagraph-capture-sizes 1 2 4 8 $(seq 16 8 256)) +fi MISC_ARGS=( # default dropout in megatron is 0.1 @@ -125,16 +180,30 @@ MISC_ARGS=( --attention-backend flash ) +RLK_ARGS=() +if [[ "${VIME_RL_KERNEL:-0}" == "1" ]]; then + RLK_ARGS+=(--enable-rl-kernel --rl-kernel-ops "${VIME_RL_KERNEL_OPS:-linear_logp}") + if [[ "${VIME_RL_KERNEL_STRICT:-0}" == "1" ]]; then + RLK_ARGS+=(--rl-kernel-strict) + fi +fi + # launch the master node of ray in container export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 # Build the runtime environment JSON with proper variable substitution RUNTIME_ENV_JSON="{ \"env_vars\": { - \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"PYTHONPATH\": \"${VIME_ROOT}:/root/Megatron-LM/\", + \"PATH\": \"${PATH}\", + \"CUDA_HOME\": \"${CUDA_HOME:-}\", + \"LD_LIBRARY_PATH\": \"${LD_LIBRARY_PATH:-}\", + \"CPATH\": \"${CPATH:-}\", + \"LIBRARY_PATH\": \"${LIBRARY_PATH:-}\", \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", - \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", + \"TENSORBOARD_DIR\": \"${TENSORBOARD_DIR:-}\" } }" @@ -142,7 +211,7 @@ ray job submit --address="http://127.0.0.1:8265" \ --runtime-env-json="${RUNTIME_ENV_JSON}" \ -- python3 train.py \ --actor-num-nodes 1 \ - --actor-num-gpus-per-node 8 \ + --actor-num-gpus-per-node ${NUM_GPUS} \ --colocate \ ${MODEL_ARGS[@]} \ ${CKPT_ARGS[@]} \ @@ -150,7 +219,9 @@ ray job submit --address="http://127.0.0.1:8265" \ ${OPTIMIZER_ARGS[@]} \ ${GRPO_ARGS[@]} \ ${WANDB_ARGS[@]} \ + ${TB_ARGS[@]} \ ${PERF_ARGS[@]} \ ${EVAL_ARGS[@]} \ ${VLLM_ARGS[@]} \ - ${MISC_ARGS[@]} + ${MISC_ARGS[@]} \ + ${RLK_ARGS[@]} diff --git a/tests/_unit_stubs.py b/tests/_unit_stubs.py index 2ba863de..7fe27a24 100644 --- a/tests/_unit_stubs.py +++ b/tests/_unit_stubs.py @@ -57,7 +57,8 @@ def install_rollout_optional_stubs() -> None: """Stub rollout-side optional imports when not installed.""" ensure_ray_stub() - install_vllm_router_stub() + if not real_module_available("vllm_router"): + sys.modules["vllm_router"] = types.ModuleType("vllm_router") if not real_module_available("PIL"): pil = types.ModuleType("PIL") @@ -96,48 +97,6 @@ def _raise_os_error(*args, **kwargs): sys.modules["pylatexenc"] = pylatexenc sys.modules["pylatexenc.latex2text"] = latex2text - install_wandb_stub() - - -def install_vllm_router_stub() -> None: - if real_module_available("vllm_router"): - return - - class RouterArgs: - @classmethod - def add_cli_args(cls, parser, *args, **kwargs): # noqa: ARG003 - return parser - - @classmethod - def from_cli_args(cls, args, *unused_args, **unused_kwargs): # noqa: ARG003 - return types.SimpleNamespace() - - router_mod = types.ModuleType("vllm_router") - router_mod.__path__ = [] - launch_router_mod = types.ModuleType("vllm_router.launch_router") - router_args_mod = types.ModuleType("vllm_router.router_args") - launch_router_mod.RouterArgs = RouterArgs - router_args_mod.RouterArgs = RouterArgs - router_mod.launch_router = launch_router_mod - router_mod.router_args = router_args_mod - sys.modules["vllm_router"] = router_mod - sys.modules["vllm_router.launch_router"] = launch_router_mod - sys.modules["vllm_router.router_args"] = router_args_mod - - -def install_wandb_stub() -> None: - if real_module_available("wandb"): - return - wandb_mod = types.ModuleType("wandb") - wandb_mod.run = None - wandb_mod.log = MagicMock() - wandb_mod.finish = MagicMock() - wandb_mod.login = MagicMock() - wandb_mod.init = MagicMock() - wandb_mod.Settings = MagicMock() - wandb_mod.util = types.SimpleNamespace(generate_id=lambda: "unit-test") - sys.modules["wandb"] = wandb_mod - def save_sys_modules(names: Iterable[str]) -> dict[str, Any]: return {k: sys.modules.get(k) for k in names} @@ -237,57 +196,14 @@ def add_cli_args(cls, parser): # noqa: ARG003 arg_utils.AsyncEngineArgs = AsyncEngineArgs engine_mod.arg_utils = arg_utils - system_utils_mod = types.ModuleType("vllm.utils.system_utils") - system_utils_mod.kill_process_tree = lambda pid, include_parent=True: None # noqa: ARG005 - utils_mod.system_utils = system_utils_mod - - # vllm.entrypoints stubs (used by arguments.add_vllm_arguments and vllm_engine._vllm_server_field_names) - entrypoints_mod = types.ModuleType("vllm.entrypoints") - entrypoints_mod.__path__ = [] - openai_mod = types.ModuleType("vllm.entrypoints.openai") - openai_mod.__path__ = [] - cli_args_mod = types.ModuleType("vllm.entrypoints.openai.cli_args") - - import dataclasses as _dc - - @_dc.dataclass - class FrontendArgs: - @classmethod - def add_cli_args(cls, parser): # noqa: ARG003 - return parser - - cli_args_mod.FrontendArgs = FrontendArgs - cli_args_mod.make_arg_parser = lambda parser=None: parser - cli_args_mod.validate_parsed_serve_args = lambda args: args - openai_mod.cli_args = cli_args_mod - entrypoints_mod.openai = openai_mod - vllm_mod.entrypoints = entrypoints_mod - - cli_mod = types.ModuleType("vllm.entrypoints.cli") - cli_mod.__path__ = [] - serve_mod = types.ModuleType("vllm.entrypoints.cli.serve") - - class ServeSubcommand: - pass - - serve_mod.ServeSubcommand = ServeSubcommand - cli_mod.serve = serve_mod - entrypoints_mod.cli = cli_mod - vllm_mod.engine = engine_mod vllm_mod.utils = utils_mod sys.modules["vllm"] = vllm_mod sys.modules["vllm.utils"] = utils_mod sys.modules["vllm.utils.argparse_utils"] = argparse_utils - sys.modules["vllm.utils.system_utils"] = system_utils_mod sys.modules["vllm.engine"] = engine_mod sys.modules["vllm.engine.arg_utils"] = arg_utils - sys.modules["vllm.entrypoints"] = entrypoints_mod - sys.modules["vllm.entrypoints.openai"] = openai_mod - sys.modules["vllm.entrypoints.openai.cli_args"] = cli_args_mod - sys.modules["vllm.entrypoints.cli"] = cli_mod - sys.modules["vllm.entrypoints.cli.serve"] = serve_mod def install_triton_stub() -> None: @@ -307,4 +223,5 @@ def install_triton_stub() -> None: def install_vime_distributed_utils_stub() -> None: vime_utils = types.ModuleType("vime.utils.distributed_utils") vime_utils.get_gloo_group = MagicMock(return_value="gloo") + vime_utils.distributed_masked_whiten = MagicMock(side_effect=lambda values, *args, **kwargs: values) sys.modules.setdefault("vime.utils.distributed_utils", vime_utils) diff --git a/tests/test_rl_kernel_args.py b/tests/test_rl_kernel_args.py new file mode 100644 index 00000000..cd3916c4 --- /dev/null +++ b/tests/test_rl_kernel_args.py @@ -0,0 +1,83 @@ +from argparse import Namespace + +import pytest + +from vime.utils.rl_kernel import is_rl_kernel_op_enabled, normalize_rl_kernel_args, parse_rl_kernel_ops + +NUM_GPUS = 0 + + +@pytest.mark.unit +def test_parse_rl_kernel_ops_defaults_to_linear_logp(): + assert parse_rl_kernel_ops(None) == ("linear_logp",) + assert parse_rl_kernel_ops("") == ("linear_logp",) + + +@pytest.mark.unit +def test_parse_rl_kernel_ops_deduplicates_comma_and_space_separated_values(): + assert parse_rl_kernel_ops("linear_logp, linear_logp") == ("linear_logp",) + + +@pytest.mark.unit +def test_parse_rl_kernel_ops_rejects_unknown_ops(): + with pytest.raises(ValueError, match="Unsupported RL-Kernel op"): + parse_rl_kernel_ops("linear_logp,moe") + + +@pytest.mark.unit +def test_normalize_rl_kernel_args_keeps_default_disabled(): + args = Namespace(enable_rl_kernel=False, rl_kernel_ops="linear_logp", rl_kernel_strict=False) + + normalize_rl_kernel_args(args) + + assert args.enable_rl_kernel is False + assert args.rl_kernel_ops == ("linear_logp",) + assert is_rl_kernel_op_enabled(args, "linear_logp") is False + + +@pytest.mark.unit +def test_normalize_rl_kernel_args_accepts_env_enable(monkeypatch): + monkeypatch.setenv("VIME_RL_KERNEL", "1") + args = Namespace(enable_rl_kernel=False, rl_kernel_ops="linear_logp", rl_kernel_strict=False) + + normalize_rl_kernel_args(args) + + assert args.enable_rl_kernel is True + assert args.rl_kernel_ops == ("linear_logp",) + assert is_rl_kernel_op_enabled(args, "linear_logp") is True + + +@pytest.mark.unit +def test_normalize_rl_kernel_args_rejects_non_linear_logp_ops(): + args = Namespace(enable_rl_kernel=True, rl_kernel_ops="logp", rl_kernel_strict=False) + + with pytest.raises(ValueError, match="Unsupported RL-Kernel op"): + normalize_rl_kernel_args(args) + + +@pytest.mark.unit +@pytest.mark.parametrize("op", ["ratio_kl", "grpo_loss", "sampling"]) +def test_normalize_rl_kernel_args_rejects_out_of_scope_ops(op): + args = Namespace(enable_rl_kernel=True, rl_kernel_ops=op, rl_kernel_strict=False) + + with pytest.raises(ValueError, match="Unsupported RL-Kernel op"): + normalize_rl_kernel_args(args) + + +@pytest.mark.unit +def test_normalize_rl_kernel_args_accepts_linear_logp(): + args = Namespace(enable_rl_kernel=True, rl_kernel_ops="linear_logp", rl_kernel_strict=False) + + normalize_rl_kernel_args(args) + + assert args.rl_kernel_ops == ("linear_logp",) + assert is_rl_kernel_op_enabled(args, "linear_logp") is True + + +@pytest.mark.unit +def test_normalize_rl_kernel_args_rejects_bad_env_bool(monkeypatch): + monkeypatch.setenv("VIME_RL_KERNEL", "maybe") + args = Namespace(enable_rl_kernel=False, rl_kernel_ops="linear_logp", rl_kernel_strict=False) + + with pytest.raises(ValueError, match="VIME_RL_KERNEL"): + normalize_rl_kernel_args(args) diff --git a/tests/test_rl_kernel_linear_logp_integration.py b/tests/test_rl_kernel_linear_logp_integration.py new file mode 100644 index 00000000..1fe6b5e6 --- /dev/null +++ b/tests/test_rl_kernel_linear_logp_integration.py @@ -0,0 +1,423 @@ +from __future__ import annotations + +import builtins +import sys +import types +from argparse import Namespace +from pathlib import Path + +import pytest +import torch +import torch.nn.functional as F + +_tests_root = Path(__file__).resolve().parent +if str(_tests_root) not in sys.path: + sys.path.insert(0, str(_tests_root)) + +import _unit_stubs + +_unit_stubs.install_megatron_mpu_stub() +_unit_stubs.install_vime_distributed_utils_stub() + +from megatron.core import mpu # noqa: E402 + +from vime.backends.megatron_utils import loss as loss_mod # noqa: E402 +from vime.backends.megatron_utils import rl_kernel as rlk_mod # noqa: E402 + +NUM_GPUS = 0 + + +class _FakeLinearLogpOp: + calls: list[dict] = [] + + def __call__( + self, + hidden: torch.Tensor, + weight: torch.Tensor, + target_ids: torch.Tensor, + bias: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + type(self).calls.append( + { + "hidden_shape": tuple(hidden.shape), + "weight_shape": tuple(weight.shape), + "target_shape": tuple(target_ids.shape), + "bias": bias is not None, + "kwargs": kwargs, + } + ) + logits = F.linear(hidden.float(), weight.float(), None if bias is None else bias.float()) + return torch.gather(torch.log_softmax(logits, dim=-1), -1, target_ids.long().unsqueeze(-1)).squeeze(-1) + + +class _FakeLegacyLinearLogpOp: + def __call__( + self, + hidden: torch.Tensor, + weight: torch.Tensor, + target_ids: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + logits = F.linear(hidden.float(), weight.float(), None if bias is None else bias.float()) + return torch.gather(torch.log_softmax(logits, dim=-1), -1, target_ids.long().unsqueeze(-1)).squeeze(-1) + + +def _reset_rl_kernel_state(): + rlk_mod._LOGP_OP = None + rlk_mod._LOGP_OP_LOAD_ERROR = None + rlk_mod._LINEAR_LOGP_OP = None + rlk_mod._LINEAR_LOGP_OP_LOAD_ERROR = None + rlk_mod._WARNED_FALLBACK_REASONS.clear() + rlk_mod._FALLBACK_COUNTS.clear() + rlk_mod._FALLBACK_COUNTS.update({"logp": 0, "linear_logp": 0}) + _FakeLinearLogpOp.calls.clear() + + +def _install_fake_rl_engine(monkeypatch): + rl_engine = types.ModuleType("rl_engine") + kernels = types.ModuleType("rl_engine.kernels") + registry = types.ModuleType("rl_engine.kernels.registry") + registry.kernel_registry = types.SimpleNamespace(get_op=lambda name: _FakeLinearLogpOp()) + monkeypatch.setitem(sys.modules, "rl_engine", rl_engine) + monkeypatch.setitem(sys.modules, "rl_engine.kernels", kernels) + monkeypatch.setitem(sys.modules, "rl_engine.kernels.registry", registry) + + +def _install_fake_legacy_rl_engine(monkeypatch): + rl_engine = types.ModuleType("rl_engine") + kernels = types.ModuleType("rl_engine.kernels") + registry = types.ModuleType("rl_engine.kernels.registry") + registry.kernel_registry = types.SimpleNamespace(get_op=lambda name: _FakeLegacyLinearLogpOp()) + monkeypatch.setitem(sys.modules, "rl_engine", rl_engine) + monkeypatch.setitem(sys.modules, "rl_engine.kernels", kernels) + monkeypatch.setitem(sys.modules, "rl_engine.kernels.registry", registry) + + +def _make_args(**overrides) -> Namespace: + values = { + "enable_rl_kernel": True, + "rl_kernel_ops": ("linear_logp",), + "rl_kernel_strict": False, + "allgather_cp": False, + "qkv_format": "thd", + "rollout_temperature": 1.0, + "log_probs_chunk_size": -1, + "entropy_coef": 0.0, + "sequence_parallel": False, + "padded_vocab_size": None, + "vocab_size": None, + } + values.update(overrides) + return Namespace(**values) + + +@pytest.fixture(autouse=True) +def reset_parallelism(): + _reset_rl_kernel_state() + mpu.get_tensor_model_parallel_world_size.return_value = 1 + mpu.get_tensor_model_parallel_rank.return_value = 0 + mpu.get_tensor_model_parallel_group.return_value = None + mpu.get_context_parallel_world_size.return_value = 1 + mpu.get_context_parallel_rank.return_value = 0 + mpu.get_virtual_pipeline_model_parallel_world_size.return_value = None + mpu.is_pipeline_last_stage.return_value = True + yield + _reset_rl_kernel_state() + + +def _reference_logp(hidden: torch.Tensor, weight: torch.Tensor, target: torch.Tensor, bias: torch.Tensor | None): + logits = F.linear(hidden.float(), weight.float(), None if bias is None else bias.float()) + return torch.gather(torch.log_softmax(logits, dim=-1), -1, target.long().unsqueeze(-1)).squeeze(-1) + + +def _cpu_calculate_log_probs_and_entropy( + logits: torch.Tensor, + tokens: torch.Tensor, + tp_group, + *, + with_entropy: bool, + chunk_size: int, +): + del tp_group, chunk_size + log_probs = torch.log_softmax(logits.float(), dim=-1) + selected = torch.gather(log_probs, -1, tokens.long().unsqueeze(-1)).squeeze(-1) + entropy = None + if with_entropy: + probs = torch.softmax(logits.float(), dim=-1) + entropy = -(probs * log_probs).sum(dim=-1) + return selected, entropy + + +@pytest.mark.unit +def test_maybe_compute_linear_logp_passes_tensor_parallel_metadata(monkeypatch): + _install_fake_rl_engine(monkeypatch) + args = _make_args() + torch.manual_seed(1) + hidden = torch.randn(6, 5) + weight = torch.randn(8, 5) + bias = torch.randn(8) + target = torch.randint(0, 8, (6,)) + context = rlk_mod.LinearLogpContext( + lm_head_weight=weight, + bias=bias, + tp_group="tp", + vocab_start_index=16, + global_vocab_size=32, + ) + + actual = rlk_mod.maybe_compute_linear_logp(hidden, target, context=context, args=args, with_entropy=False) + + torch.testing.assert_close(actual, _reference_logp(hidden, weight, target, bias)) + assert _FakeLinearLogpOp.calls == [ + { + "hidden_shape": (6, 5), + "weight_shape": (8, 5), + "target_shape": (6,), + "bias": True, + "kwargs": { + "tp_group": "tp", + "vocab_start_index": 16, + "global_vocab_size": 32, + }, + } + ] + + +@pytest.mark.unit +def test_linear_logp_matches_vime_response_slicing_from_hidden_states(monkeypatch): + _install_fake_rl_engine(monkeypatch) + args = _make_args() + vocab_size = 23 + hidden_size = 7 + total_lengths = [5, 4, 6] + response_lengths = [2, 3, 4] + torch.manual_seed(2) + unconcat_tokens = [torch.randint(0, vocab_size, (length,), dtype=torch.long) for length in total_lengths] + hidden = torch.randn(sum(total_lengths), 1, hidden_size) + weight = torch.randn(vocab_size, hidden_size) + bias = torch.randn(vocab_size) + context = rlk_mod.LinearLogpContext(lm_head_weight=weight, bias=bias, tp_group=None) + + _, result = loss_mod.get_log_probs_and_entropy( + hidden, + args=args, + unconcat_tokens=unconcat_tokens, + total_lengths=total_lengths, + response_lengths=response_lengths, + with_entropy=False, + rl_kernel_linear_logp_context=context, + ) + + full_tokens = torch.zeros(sum(total_lengths), dtype=torch.long) + offset = 0 + for tokens, total_length in zip(unconcat_tokens, total_lengths, strict=False): + full_tokens[offset : offset + total_length - 1] = tokens[1:total_length] + offset += total_length + full_logp = _reference_logp(hidden.squeeze(1), weight, full_tokens, bias) + + expected = [] + offset = 0 + for total_length, response_length in zip(total_lengths, response_lengths, strict=False): + end = offset + total_length + start = end - response_length + expected.append(full_logp[start - 1 : end - 1]) + offset += total_length + + assert len(_FakeLinearLogpOp.calls) == 1 + for actual, expected_item in zip(result["log_probs"], expected, strict=True): + torch.testing.assert_close(actual, expected_item, rtol=1e-6, atol=1e-6) + + +@pytest.mark.unit +def test_linear_logp_materializes_logits_fallback_when_optional_package_missing(monkeypatch): + monkeypatch.setattr(loss_mod, "calculate_log_probs_and_entropy", _cpu_calculate_log_probs_and_entropy) + for name in list(sys.modules): + if name == "rl_engine" or name.startswith("rl_engine."): + monkeypatch.delitem(sys.modules, name, raising=False) + + original_import = builtins.__import__ + + def fail_rl_engine_import(name, *args, **kwargs): + if name == "rl_engine" or name.startswith("rl_engine."): + raise ModuleNotFoundError("No module named 'rl_engine'") + return original_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fail_rl_engine_import) + args = _make_args() + vocab_size = 19 + hidden_size = 5 + total_lengths = [4, 5] + response_lengths = [2, 3] + torch.manual_seed(3) + unconcat_tokens = [torch.randint(0, vocab_size, (length,), dtype=torch.long) for length in total_lengths] + hidden = torch.randn(1, sum(total_lengths), hidden_size) + weight = torch.randn(vocab_size, hidden_size) + bias = torch.randn(vocab_size) + context = rlk_mod.LinearLogpContext(lm_head_weight=weight, bias=bias, tp_group=None) + + _, result = loss_mod.get_log_probs_and_entropy( + hidden, + args=args, + unconcat_tokens=unconcat_tokens, + total_lengths=total_lengths, + response_lengths=response_lengths, + with_entropy=False, + rl_kernel_linear_logp_context=context, + ) + + logits = F.linear(hidden.squeeze(0).float(), weight.float(), bias.float()) + expected = [] + offset = 0 + for tokens, total_length, response_length in zip(unconcat_tokens, total_lengths, response_lengths, strict=False): + end = offset + total_length + start = end - response_length + target = tokens[-response_length:] + expected.append( + torch.gather(torch.log_softmax(logits[start - 1 : end - 1], dim=-1), -1, target.unsqueeze(-1)).squeeze(-1) + ) + offset += total_length + + assert rlk_mod.get_rl_kernel_fallback_count("linear_logp") == 1 + for actual, expected_item in zip(result["log_probs"], expected, strict=True): + torch.testing.assert_close(actual, expected_item, rtol=1e-6, atol=1e-6) + + +@pytest.mark.unit +def test_linear_logp_falls_back_when_op_lacks_tp_interface(monkeypatch): + _install_fake_legacy_rl_engine(monkeypatch) + args = _make_args() + hidden = torch.randn(3, 4) + weight = torch.randn(6, 4) + target = torch.randint(0, 6, (3,)) + context = rlk_mod.LinearLogpContext( + lm_head_weight=weight, + bias=None, + tp_group="tp_group", + vocab_start_index=6, + global_vocab_size=12, + ) + + actual = rlk_mod.maybe_compute_linear_logp(hidden, target, context=context, args=args, with_entropy=False) + + assert actual is None + assert rlk_mod.get_rl_kernel_fallback_count("linear_logp") == 1 + + +@pytest.mark.unit +def test_linear_logp_context_from_model_uses_tp_vocab_offsets(): + mpu.get_tensor_model_parallel_world_size.return_value = 4 + mpu.get_tensor_model_parallel_rank.return_value = 2 + mpu.get_tensor_model_parallel_group.return_value = "tp_group" + + output_layer = types.SimpleNamespace( + weight=torch.empty(8, 4), + bias=None, + sequence_parallel=True, + ) + model = types.SimpleNamespace(output_layer=output_layer, post_process=True) + args = _make_args(padded_vocab_size=32, sequence_parallel=False) + + context = rlk_mod.get_linear_logp_context_from_model(args, model) + + assert context is not None + assert context.lm_head_weight is output_layer.weight + assert context.tp_group == "tp_group" + assert context.vocab_start_index == 16 + assert context.global_vocab_size == 32 + assert context.sequence_parallel is True + + +@pytest.mark.unit +def test_linear_logp_context_prefers_output_layer_weight_for_untied_pp1_model(): + mpu.get_tensor_model_parallel_world_size.return_value = 1 + mpu.get_tensor_model_parallel_rank.return_value = 0 + mpu.get_tensor_model_parallel_group.return_value = None + + output_weight = torch.empty(8, 4) + embedding_weight = torch.empty(8, 4) + output_layer = types.SimpleNamespace(weight=output_weight, bias=None) + model = types.SimpleNamespace( + output_layer=output_layer, + post_process=True, + pre_process=True, + shared_embedding_or_output_weight=lambda: embedding_weight, + ) + args = _make_args() + + context = rlk_mod.get_linear_logp_context_from_model(args, model) + + assert context is not None + assert context.lm_head_weight is output_weight + + +@pytest.mark.unit +def test_linear_logp_context_uses_shared_weight_when_output_layer_weight_is_missing(): + mpu.get_tensor_model_parallel_world_size.return_value = 1 + mpu.get_tensor_model_parallel_rank.return_value = 0 + mpu.get_tensor_model_parallel_group.return_value = None + + embedding_weight = torch.empty(8, 4) + output_layer = types.SimpleNamespace(weight=None, bias=None) + model = types.SimpleNamespace( + output_layer=output_layer, + post_process=True, + pre_process=True, + shared_embedding_or_output_weight=lambda: embedding_weight, + ) + args = _make_args() + + context = rlk_mod.get_linear_logp_context_from_model(args, model) + + assert context is not None + assert context.lm_head_weight is embedding_weight + + +@pytest.mark.unit +def test_linear_logp_context_uses_covered_padded_vocab_when_padded_vocab_size_missing(): + mpu.get_tensor_model_parallel_world_size.return_value = 4 + mpu.get_tensor_model_parallel_rank.return_value = 1 + mpu.get_tensor_model_parallel_group.return_value = "tp_group" + + output_layer = types.SimpleNamespace(weight=torch.empty(8, 4), bias=None) + model = types.SimpleNamespace(output_layer=output_layer, post_process=True) + args = _make_args(padded_vocab_size=None, vocab_size=30) + + context = rlk_mod.get_linear_logp_context_from_model(args, model) + + assert context is not None + assert context.vocab_start_index == 8 + assert context.global_vocab_size == 32 + + +@pytest.mark.unit +def test_return_hidden_states_for_linear_logp_restores_post_process_flag(): + args = _make_args() + model = types.SimpleNamespace(post_process=True) + context = rlk_mod.LinearLogpContext( + lm_head_weight=torch.empty(4, 3), + bias=None, + tp_group=None, + ) + + with rlk_mod.return_hidden_states_for_linear_logp(args, model, context) as enabled: + assert enabled is True + assert model.post_process is False + + assert model.post_process is True + + +@pytest.mark.unit +def test_policy_loss_only_skips_entropy_when_linear_logp_context_is_active(): + args = _make_args(enable_rl_kernel=True, rl_kernel_ops=("linear_logp",), entropy_coef=0.0) + context = rlk_mod.LinearLogpContext( + lm_head_weight=torch.empty(4, 3), + bias=None, + tp_group=None, + ) + + assert loss_mod._policy_loss_needs_entropy(args, None) is True + assert loss_mod._policy_loss_needs_entropy(args, context) is False + + args.entropy_coef = 0.01 + assert loss_mod._policy_loss_needs_entropy(args, context) is True diff --git a/tests/test_rl_kernel_logp_integration.py b/tests/test_rl_kernel_logp_integration.py new file mode 100644 index 00000000..5d444768 --- /dev/null +++ b/tests/test_rl_kernel_logp_integration.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +import sys +import types +from argparse import Namespace +from pathlib import Path + +import pytest +import torch + +_tests_root = Path(__file__).resolve().parent +if str(_tests_root) not in sys.path: + sys.path.insert(0, str(_tests_root)) + +import _unit_stubs + +_unit_stubs.install_megatron_mpu_stub() +_unit_stubs.install_vime_distributed_utils_stub() + +from megatron.core import mpu # noqa: E402 + +from vime.backends.megatron_utils import loss as loss_mod # noqa: E402 +from vime.backends.megatron_utils import rl_kernel as rlk_mod # noqa: E402 + + +NUM_GPUS = 0 + + +class _FakeLogpOp: + calls = 0 + + def apply_fp32(self, logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: + type(self).calls += 1 + return torch.gather(torch.log_softmax(logits.float(), dim=-1), -1, token_ids.long().unsqueeze(-1)).squeeze(-1) + + +def _reset_rl_kernel_state(): + rlk_mod._LOGP_OP = None + rlk_mod._LOGP_OP_LOAD_ERROR = None + rlk_mod._LINEAR_LOGP_OP = None + rlk_mod._LINEAR_LOGP_OP_LOAD_ERROR = None + rlk_mod._WARNED_FALLBACK_REASONS.clear() + rlk_mod._FALLBACK_COUNTS.clear() + rlk_mod._FALLBACK_COUNTS.update({"logp": 0, "linear_logp": 0}) + _FakeLogpOp.calls = 0 + + +def _install_fake_rl_engine(monkeypatch): + rl_engine = types.ModuleType("rl_engine") + kernels = types.ModuleType("rl_engine.kernels") + registry = types.ModuleType("rl_engine.kernels.registry") + registry.kernel_registry = types.SimpleNamespace(get_op=lambda name: _FakeLogpOp()) + monkeypatch.setitem(sys.modules, "rl_engine", rl_engine) + monkeypatch.setitem(sys.modules, "rl_engine.kernels", kernels) + monkeypatch.setitem(sys.modules, "rl_engine.kernels.registry", registry) + + +def _make_args(**overrides) -> Namespace: + values = { + "enable_rl_kernel": True, + "rl_kernel_ops": ("logp",), + "rl_kernel_strict": False, + "allgather_cp": False, + "qkv_format": "thd", + "rollout_temperature": 1.0, + "log_probs_chunk_size": -1, + } + values.update(overrides) + return Namespace(**values) + + +@pytest.fixture(autouse=True) +def reset_parallelism(): + _reset_rl_kernel_state() + mpu.get_tensor_model_parallel_world_size.return_value = 1 + mpu.get_tensor_model_parallel_group.return_value = None + mpu.get_context_parallel_world_size.return_value = 1 + mpu.get_context_parallel_rank.return_value = 0 + yield + _reset_rl_kernel_state() + + +@pytest.mark.unit +def test_rl_kernel_logp_matches_vime_response_slicing(monkeypatch): + _install_fake_rl_engine(monkeypatch) + args = _make_args() + vocab_size = 257 + total_lengths = [9, 7, 11] + response_lengths = [4, 3, 6] + torch.manual_seed(17) + unconcat_tokens = [torch.randint(0, vocab_size, (length,), dtype=torch.long) for length in total_lengths] + logits = torch.randn(1, sum(total_lengths), vocab_size, dtype=torch.float32) + + with torch.no_grad(): + _, result = loss_mod.get_log_probs_and_entropy( + logits, + args=args, + unconcat_tokens=unconcat_tokens, + total_lengths=total_lengths, + response_lengths=response_lengths, + with_entropy=False, + ) + + expected = [] + offset = 0 + log_probs = torch.log_softmax(logits.squeeze(0), dim=-1) + for tokens, total_length, response_length in zip(unconcat_tokens, total_lengths, response_lengths, strict=False): + start = offset + total_length - response_length + end = offset + total_length + shifted_tokens = tokens[-response_length:] + expected.append(torch.gather(log_probs[start - 1 : end - 1], -1, shifted_tokens.unsqueeze(-1)).squeeze(-1)) + offset += total_length + + assert _FakeLogpOp.calls == 1 + assert len(result["log_probs"]) == len(expected) + for actual, expected_item in zip(result["log_probs"], expected, strict=True): + torch.testing.assert_close(actual, expected_item, rtol=1e-6, atol=1e-6) + + +@pytest.mark.unit +def test_rl_kernel_logp_falls_back_when_entropy_requested(monkeypatch): + _install_fake_rl_engine(monkeypatch) + args = _make_args() + logits = torch.randn(8, 16, dtype=torch.float32) + tokens = torch.randint(0, 16, (8,), dtype=torch.long) + + with torch.no_grad(): + actual = rlk_mod.maybe_compute_logp(logits, tokens, args=args, with_entropy=True) + + assert actual is None + assert _FakeLogpOp.calls == 0 + + +@pytest.mark.unit +def test_rl_kernel_logp_falls_back_for_tensor_parallel_vocab(monkeypatch): + _install_fake_rl_engine(monkeypatch) + mpu.get_tensor_model_parallel_world_size.return_value = 2 + args = _make_args() + logits = torch.randn(8, 16, dtype=torch.float32) + tokens = torch.randint(0, 16, (8,), dtype=torch.long) + + with torch.no_grad(): + actual = rlk_mod.maybe_compute_logp(logits, tokens, args=args, with_entropy=False) + + assert actual is None + assert _FakeLogpOp.calls == 0 + + +@pytest.mark.unit +def test_rl_kernel_logp_strict_mode_raises_on_unsupported_parallelism(monkeypatch): + _install_fake_rl_engine(monkeypatch) + mpu.get_tensor_model_parallel_world_size.return_value = 2 + args = _make_args(rl_kernel_strict=True) + logits = torch.randn(8, 16, dtype=torch.float32) + tokens = torch.randint(0, 16, (8,), dtype=torch.long) + + with pytest.raises(RuntimeError, match="tensor-parallel vocab shards"): + with torch.no_grad(): + rlk_mod.maybe_compute_logp(logits, tokens, args=args, with_entropy=False) + + +@pytest.mark.unit +def test_rl_kernel_logp_falls_back_when_optional_package_missing(monkeypatch): + import builtins + + for name in list(sys.modules): + if name == "rl_engine" or name.startswith("rl_engine."): + monkeypatch.delitem(sys.modules, name, raising=False) + + original_import = builtins.__import__ + + def fail_rl_engine_import(name, *args, **kwargs): + if name == "rl_engine" or name.startswith("rl_engine."): + raise ModuleNotFoundError("No module named 'rl_engine'") + return original_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fail_rl_engine_import) + args = _make_args() + logits = torch.randn(8, 16, dtype=torch.float32) + tokens = torch.randint(0, 16, (8,), dtype=torch.long) + + with torch.no_grad(): + actual = rlk_mod.maybe_compute_logp(logits, tokens, args=args, with_entropy=False) + + assert actual is None diff --git a/vime-RLK.md b/vime-RLK.md new file mode 100644 index 00000000..ca4ce493 --- /dev/null +++ b/vime-RLK.md @@ -0,0 +1,312 @@ +# vime + RL-Kernel linear_logp 主宣传实验 + +## 0. 我们要做什么 + +本轮只保留一个主宣传实验: + +```text +baseline: vime benchmark branch, Qwen3-30B-A3B, 8xH100 colocate, RL-Kernel off +candidate: vime benchmark branch + RL-Kernel linear_logp, Qwen3-30B-A3B, 8xH100 colocate +``` + +目标:证明 RL-Kernel 的 `linear_logp` 接入 vime 后,在同一套 Qwen3-30B-A3B MoE 训练配置下,不降低训练质量,并降低 selected-logprob 路径耗时或显存压力。 + +范围收口: + +- 只测 `linear_logp`。 +- 只跑 Qwen3-30B-A3B 主宣传实验。 +- 不跑 Qwen3-4B smoke、R3 单独对比、GLM-4.5、GB200/H200 硬件对照。 +- 不测 `logp`、`ratio_kl`、`grpo_loss`、`sampling` 的 vime 端到端收益。 +- 不做训推一致性专项 benchmark。 +- 不接 MoE expert/router 算子。 + +## 1. H100 支持结论 + +vime 支持 H100:当前 vime 文档已有 `Qwen3-30B-A3B with 8xH100` 和 `Qwen3-4B with 8xH100` 示例,代码里也有 H100 hardware mapping。 + +所以本轮主方案使用: + +```text +8xH100 +``` + +A100 不作为本轮主宣传配置。 + +## 2. 当前代码边界 + +vime candidate 只暴露一个 RL-Kernel op: + +```text +RL_KERNEL_SUPPORTED_OPS = ("linear_logp",) +RL_KERNEL_INTEGRATED_OPS = ("linear_logp",) +--rl-kernel-ops linear_logp +VIME_RL_KERNEL_OPS=linear_logp +``` + +主实验脚本: + +```text +scripts/run-qwen3-30B-A3B.sh +``` + +该脚本已经按 8 卡主宣传实验参数化: + +```text +NUM_GPUS=8 +MEGATRON_TP=2 +MEGATRON_EP=8 +MEGATRON_CP=1 +ROLLOUT_NUM_GPUS_PER_ENGINE=8 +ROLLOUT_BATCH_SIZE=32 +N_SAMPLES_PER_PROMPT=8 +GLOBAL_BATCH_SIZE=256 +MAX_TOKENS_PER_GPU=20480 +VLLM_GPU_MEMORY_UTILIZATION=0.7 +``` + +如遇 OOM,先降低: + +```text +MAX_TOKENS_PER_GPU=8192 +VLLM_GPU_MEMORY_UTILIZATION=0.55 +ROLLOUT_BATCH_SIZE=4 +GLOBAL_BATCH_SIZE=32 +``` + +## 3. 上卡准备 + +从官方仓库开始,不依赖当前本地目录: + +```bash +cd /workspace +git clone https://github.com/RL-Align/vime.git vime-main +git clone https://github.com/RL-Align/vime.git vime-benchmark +git clone https://github.com/RL-Align/vime.git vime-rlk-integration +git clone https://github.com/RL-Align/RL-Kernel.git RL-Kernel +``` + +RL-Kernel 必须使用含 TP 版 `linear_logp` 接口的版本。vime 这边会调用: + +```text +op(hidden, weight, target_ids, bias, tp_group=..., vocab_start_index=..., global_vocab_size=...) +``` + +当前使用 `RL-Align/RL-Kernel#189` 提供 TP 版 `linear_logp`。上卡后在 `/workspace/RL-Kernel` 里 checkout 该 PR 后再安装: + +```bash +cd /workspace/RL-Kernel +git checkout main +git pull origin main +gh pr checkout 189 +``` + +`vime-main` 只作为干净参考,不直接跑实验: + +```bash +cd /workspace/vime-main +git checkout main +git pull origin main +``` + +vime candidate 已经准备成 draft PR,baseline 仍然要保持 benchmark-only,避免把 baseline 和 candidate 混在一起: + +```text +vime-rlk-benchmark-8h100 +只包含 8xH100 benchmark harness,不包含 RL-Kernel 集成代码。 + +RL-Align/vime#1 +draft PR,基于 benchmark harness,再加入 RL-Kernel linear_logp 集成代码和测试。 +``` + +baseline 从干净 main 新建 benchmark 分支: + +```bash +cd /workspace/vime-benchmark +git checkout main +git pull origin main +git checkout -b vime-rlk-benchmark-8h100 +# 只应用 benchmark harness 改动,例如 scripts/run-qwen3-30B-A3B.sh 的 8xH100 参数化。 +# 不加入 --enable-rl-kernel、vime/utils/rl_kernel.py、megatron_utils/rl_kernel.py 等 RL-Kernel 集成改动。 +``` + +candidate 直接 checkout draft PR `RL-Align/vime#1`: + +```bash +cd /workspace/vime-rlk-integration +git checkout main +git pull origin main +gh pr checkout 1 +``` + +安装: + +```bash +cd /workspace/RL-Kernel +pip install -e . +python setup.py build_ext --inplace -v + +cd /workspace/vime-benchmark +pip install -e . + +cd /workspace/vime-rlk-integration +pip install -e . +``` + +下载模型和数据: + +```bash +pip install -U "huggingface_hub[cli]" +huggingface-cli login + +hf download Qwen/Qwen3-30B-A3B --local-dir /root/Qwen3-30B-A3B + +hf download --repo-type dataset zhuzilin/dapo-math-17k \ + --local-dir /root/dapo-math-17k + +hf download --repo-type dataset zhuzilin/aime-2024 \ + --local-dir /root/aime-2024 +``` + +转换 Megatron `torch_dist` checkpoint: + +```bash +cd /workspace/vime-benchmark +source scripts/models/qwen3-30B-A3B.sh + +PYTHONPATH=/root/Megatron-LM torchrun --nproc-per-node 8 \ + tools/convert_hf_to_torch_dist.py \ + ${MODEL_ARGS[@]} \ + --hf-checkpoint /root/Qwen3-30B-A3B \ + --save /root/Qwen3-30B-A3B_torch_dist +``` + +## 4. 运行主宣传实验 + +两边使用同一套 8 卡 colocate 环境变量: + +```bash +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export NUM_GPUS=8 +export MEGATRON_TP=2 +export MEGATRON_EP=8 +export MEGATRON_CP=1 +export ROLLOUT_NUM_GPUS_PER_ENGINE=8 +export ROLLOUT_BATCH_SIZE=32 +export N_SAMPLES_PER_PROMPT=8 +export GLOBAL_BATCH_SIZE=256 +export MAX_TOKENS_PER_GPU=20480 +export VLLM_GPU_MEMORY_UTILIZATION=0.7 +``` + +baseline: + +```bash +cd /workspace/vime-benchmark +unset VIME_RL_KERNEL VIME_RL_KERNEL_OPS VIME_RL_KERNEL_STRICT +bash scripts/run-qwen3-30B-A3B.sh +``` + +candidate: + +```bash +cd /workspace/vime-rlk-integration +export VIME_RL_KERNEL=1 +export VIME_RL_KERNEL_OPS=linear_logp +export VIME_RL_KERNEL_STRICT=1 +bash scripts/run-qwen3-30B-A3B.sh +``` + +每组至少跑 3 次;每次丢弃前 5-10 step warmup 后统计。 + +## 5. 必须记录 + +每个 run 保存: + +```text +hardware +gpu_name +num_gpus +model +dataset +vime_commit +rl_kernel_commit +candidate_enabled +enabled_rl_kernel_ops +selected_rl_kernel_backend +tp +ep +cp +rollout_batch_size +n_samples_per_prompt +global_batch_size +max_tokens_per_gpu +mean_step_time_s +p50_step_time_s +p90_step_time_s +mean_log_probs_time_s +p50_log_probs_time_s +p90_log_probs_time_s +peak_vram_gb +raw_reward_mean +train_rollout_logprob_abs_diff_mean +rl_kernel_fallback_count +``` + +验收线: + +```text +candidate 日志出现 RL-Kernel linear_logp backend +rl_kernel_fallback_count = 0 +candidate raw_reward 不低于 baseline 同量级 +candidate train_rollout_logprob_abs_diff 不持续高于 baseline +candidate mean_log_probs_time_s 或 peak_vram_gb 有可解释下降 +``` + +## 6. 最终图表 + +只输出主宣传图: + +1. `Qwen3-30B-A3B 8xH100 raw_reward` +2. `Qwen3-30B-A3B 8xH100 train_rollout_logprob_abs_diff` +3. `Qwen3-30B-A3B 8xH100 Step Time` +4. `Qwen3-30B-A3B 8xH100 Logprob Time / Peak VRAM` + +图表风格对齐 `vime_blog.md`:白底、虚线网格、baseline 蓝色、candidate 红色。 + +## 7. 本地验证 + +当前无 GPU 环境已完成: + +```text +# linear_logp 主路径与公共工具 +pytest tests/test_rl_kernel_args.py tests/test_rl_kernel_linear_logp_integration.py tests/test_value_temperature.py tests/test_metric_report.py -q +结果:39 passed + +# legacy logp compatibility regression,不属于本轮 benchmark 范围 +pytest tests/test_rl_kernel_logp_integration.py tests/test_rl_kernel_args.py tests/test_rl_kernel_linear_logp_integration.py -q +结果:24 passed + +pre-commit run --files <本轮 vime 相关文件> +结果:Passed +``` + +上卡后必须补跑: + +```text +8xH100 baseline: /workspace/vime-benchmark, benchmark-only branch +8xH100 candidate: /workspace/vime-rlk-integration, RL-Align/vime#1 +``` + +## 8. 宣传口径 + +英文: + +```text +RL-Kernel integrates with vime to accelerate the Qwen3-30B-A3B GRPO selected-logprob path through linear_logp. On the same 8xH100 setup, it reduces logprob-path cost while keeping reward and train-rollout logprob alignment stable. +``` + +中文: + +```text +RL-Kernel 接入 vime 后,通过 linear_logp 加速 Qwen3-30B-A3B GRPO selected-logprob 路径。在相同 8xH100 配置下,RL-Kernel 降低 logprob 路径开销,同时保持 reward 和 train-rollout logprob alignment 稳定。 +``` diff --git a/vime/backends/megatron_utils/loss.py b/vime/backends/megatron_utils/loss.py index 3f6ab29c..e205a45f 100644 --- a/vime/backends/megatron_utils/loss.py +++ b/vime/backends/megatron_utils/loss.py @@ -29,6 +29,7 @@ get_sum_of_sample_mean, slice_log_prob_with_cp, ) +from .rl_kernel import LinearLogpContext, get_rl_kernel_fallback_count, maybe_compute_linear_logp, maybe_compute_logp def get_responses( @@ -383,6 +384,65 @@ def _extract_per_sample( return log_probs_list, entropy_list +def _gather_sequence_parallel_hidden_if_needed( + hidden_states: torch.Tensor, + context: LinearLogpContext | None, +) -> torch.Tensor: + if context is None or not context.sequence_parallel: + return hidden_states + + from megatron.core import tensor_parallel + + return tensor_parallel.gather_from_sequence_parallel_region(hidden_states, tensor_parallel_output_grad=False) + + +def _flatten_logprob_model_output( + output_tensor: torch.Tensor, + *, + qkv_format: str, + max_seq_lens: list[int] | None, + linear_logp_context: LinearLogpContext | None, +) -> torch.Tensor: + if qkv_format == "thd": + assert len(output_tensor.shape) == 3, f"{output_tensor.shape}" + if output_tensor.size(0) == 1: + return output_tensor.squeeze(0) + if linear_logp_context is not None and output_tensor.size(1) == 1: + return output_tensor.squeeze(1) + assert output_tensor.size(0) == 1, f"{output_tensor.shape}" + else: + assert max_seq_lens is not None + return output_tensor.view(-1, output_tensor.size(-1)) + + raise AssertionError(f"Unsupported output tensor shape: {output_tensor.shape}") + + +def _materialize_linear_logits( + hidden_states: torch.Tensor, + *, + context: LinearLogpContext, + args: Namespace, +) -> torch.Tensor: + logits = F.linear(hidden_states, context.lm_head_weight, context.bias) + rollout_temperature = getattr(args, "rollout_temperature", 1.0) + if rollout_temperature != 1.0: + logits = logits / rollout_temperature + return logits.float() + + +def _policy_loss_needs_entropy( + args: Namespace, + rl_kernel_linear_logp_context: LinearLogpContext | None, +) -> bool: + if rl_kernel_linear_logp_context is None: + return True + return not ( + getattr(args, "enable_rl_kernel", False) + and "linear_logp" in getattr(args, "rl_kernel_ops", ()) + and getattr(args, "entropy_coef", 0.0) == 0 + ) + + def get_log_probs_and_entropy( logits: torch.Tensor, *, @@ -393,6 +453,7 @@ def get_log_probs_and_entropy( with_entropy: bool = False, non_loss_data: bool = True, max_seq_lens: list[int] | None = None, + rl_kernel_linear_logp_context: LinearLogpContext | None = None, ) -> dict[str, list[torch.Tensor]]: """Compute per-token log-probabilities (and optionally entropy) on responses. @@ -406,21 +467,28 @@ def get_log_probs_and_entropy( assert non_loss_data qkv_format = args.qkv_format - assert logits.dtype == torch.float32, f"{logits.dtype}" assert len(logits.shape) == 3, f"{logits.shape}" - if qkv_format == "thd": - assert logits.size(0) == 1, f"{logits.shape}" - logits = logits.squeeze(0) + linear_logp_context = rl_kernel_linear_logp_context + if linear_logp_context is not None: + logits = _gather_sequence_parallel_hidden_if_needed(logits, linear_logp_context) else: - assert max_seq_lens is not None - logits = logits.view(-1, logits.size(-1)) + assert logits.dtype == torch.float32, f"{logits.dtype}" + + logits = _flatten_logprob_model_output( + logits, + qkv_format=qkv_format, + max_seq_lens=max_seq_lens, + linear_logp_context=linear_logp_context, + ).contiguous() + + if linear_logp_context is None: + # Apply rollout temperature scaling to logits to match rollout-time log-probs. + rollout_temperature = getattr(args, "rollout_temperature", 1.0) + if rollout_temperature != 1.0: + logits = logits / rollout_temperature + logits = logits.contiguous() - # Apply rollout temperature scaling to logits to match rollout-time log-probs. - rollout_temperature = getattr(args, "rollout_temperature", 1.0) - if rollout_temperature != 1.0: - logits = logits / rollout_temperature - logits = logits.contiguous() T = logits.size(0) device = logits.device tp_group = mpu.get_tensor_model_parallel_group() @@ -431,14 +499,31 @@ def get_log_probs_and_entropy( T, device, unconcat_tokens, total_lengths, response_lengths, qkv_format, max_seq_lens, args.allgather_cp ) - # --- compute on full [T,V] logits at once via calculate_log_probs_and_entropy --- - log_prob_full, entropy_full = calculate_log_probs_and_entropy( - logits, - full_tokens, - tp_group, - with_entropy=with_entropy, - chunk_size=chunk_size, - ) + # --- compute on full [T,V] logits at once --- + log_prob_full = None + if linear_logp_context is not None: + log_prob_full = maybe_compute_linear_logp( + logits, + full_tokens, + context=linear_logp_context, + args=args, + with_entropy=with_entropy, + ) + if log_prob_full is None: + logits = _materialize_linear_logits(logits, context=linear_logp_context, args=args).contiguous() + else: + log_prob_full = maybe_compute_logp(logits, full_tokens, args=args, with_entropy=with_entropy) + + if log_prob_full is None: + log_prob_full, entropy_full = calculate_log_probs_and_entropy( + logits, + full_tokens, + tp_group, + with_entropy=with_entropy, + chunk_size=chunk_size, + ) + else: + entropy_full = None log_prob_full = log_prob_full.squeeze(-1) # [T, 1] -> [T] # --- extract per-sample response portions --- @@ -802,6 +887,7 @@ def policy_loss_function( batch: RolloutBatch, logits: torch.Tensor, sum_of_sample_mean: Callable[[torch.Tensor], torch.Tensor], + rl_kernel_linear_logp_context: LinearLogpContext | None = None, ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """Compute policy loss (PPO/GSPO) and metrics. @@ -834,14 +920,17 @@ def policy_loss_function( total_lengths = batch["total_lengths"] max_seq_lens = batch.get("max_seq_lens", None) + need_entropy = _policy_loss_needs_entropy(args, rl_kernel_linear_logp_context) + _, log_probs_and_entropy = get_log_probs_and_entropy( logits, args=args, unconcat_tokens=batch["unconcat_tokens"], total_lengths=total_lengths, response_lengths=response_lengths, - with_entropy=True, + with_entropy=need_entropy, max_seq_lens=max_seq_lens, + rl_kernel_linear_logp_context=rl_kernel_linear_logp_context, ) log_probs = log_probs_and_entropy["log_probs"] @@ -963,9 +1052,12 @@ def policy_loss_function( ppo_kl = sum_of_sample_mean(ppo_kl) # entropy loss - entropy = log_probs_and_entropy["entropy"] - entropy = torch.cat(entropy, dim=0) - entropy_loss = sum_of_sample_mean(entropy) + if need_entropy: + entropy = log_probs_and_entropy["entropy"] + entropy = torch.cat(entropy, dim=0) + entropy_loss = sum_of_sample_mean(entropy) + else: + entropy_loss = log_probs.new_zeros(()) loss = pg_loss - args.entropy_coef * entropy_loss @@ -1033,6 +1125,7 @@ def value_loss_function( batch: RolloutBatch, logits: torch.Tensor, sum_of_sample_mean: Callable[[torch.Tensor], torch.Tensor], + rl_kernel_linear_logp_context: LinearLogpContext | None = None, ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """Compute clipped value loss and metrics. @@ -1051,6 +1144,7 @@ def value_loss_function( Tuple of `(loss, metrics)` where `loss` is a scalar tensor and `metrics` contains detached scalars "value_loss" and "value_clipfrac". """ + del rl_kernel_linear_logp_context old_values = torch.cat(batch["values"], dim=0) _, values = get_values( @@ -1091,6 +1185,7 @@ def sft_loss_function( batch: RolloutBatch, logits: torch.Tensor, sum_of_sample_mean: Callable[[torch.Tensor], torch.Tensor], + rl_kernel_linear_logp_context: LinearLogpContext | None = None, ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """Compute supervised fine-tuning loss over response tokens. @@ -1119,6 +1214,7 @@ def sft_loss_function( response_lengths=response_lengths, with_entropy=False, max_seq_lens=batch.get("max_seq_lens", None), + rl_kernel_linear_logp_context=rl_kernel_linear_logp_context, ) log_probs = log_probs_and_entropy["log_probs"] @@ -1143,6 +1239,7 @@ def loss_function( num_microbatches: int, step_global_batch_size: int, logits: torch.Tensor, + rl_kernel_linear_logp_context: LinearLogpContext | None = None, ) -> tuple[torch.Tensor, int | torch.Tensor, dict[str, list[str] | torch.Tensor]]: """Dispatch to the configured loss and rescale for Megatron integration. @@ -1195,10 +1292,22 @@ def loss_function( case _: raise ValueError(f"Unknown loss type: {args.loss_type}") + if func in {policy_loss_function, value_loss_function, sft_loss_function}: + func_args = (args, batch, logits, sum_of_sample_mean, rl_kernel_linear_logp_context) + else: + func_args = (args, batch, logits, sum_of_sample_mean) + if args.recompute_loss_function: - loss, log = checkpoint(func, args, batch, logits, sum_of_sample_mean, use_reentrant=False) + loss, log = checkpoint(func, *func_args, use_reentrant=False) else: - loss, log = func(args, batch, logits, sum_of_sample_mean) + loss, log = func(*func_args) + + if getattr(args, "enable_rl_kernel", False): + log["rl_kernel_fallback_count"] = torch.tensor( + get_rl_kernel_fallback_count(), + device=logits.device, + dtype=torch.float32, + ) # With allgather-CP, some CP ranks may have no loss-contributing tokens (e.g., all # padding). Without this, gradient doesn't flow through their attention path, so diff --git a/vime/backends/megatron_utils/model.py b/vime/backends/megatron_utils/model.py index 6b602fc1..14f80f53 100644 --- a/vime/backends/megatron_utils/model.py +++ b/vime/backends/megatron_utils/model.py @@ -27,14 +27,22 @@ from megatron.core.pipeline_parallel.utils import unwrap_model except ImportError: from megatron.core.utils import unwrap_model + from vime.utils import logging_utils from vime.utils.memory_utils import clear_memory +from vime.utils.rl_kernel import is_rl_kernel_op_enabled from .checkpoint import load_checkpoint, save_checkpoint from .cp_utils import reduce_train_step_metrics from .data import DataIterator, get_batch -from .loss import loss_function +from .loss import get_log_probs_and_entropy, loss_function from .model_provider import get_model_provider_func +from .rl_kernel import ( + get_linear_logp_context_from_model, + return_hidden_states_for_linear_logp, + should_use_linear_logp_model_output, + warn_linear_logp_fallback, +) logger = logging.getLogger(__name__) @@ -73,6 +81,35 @@ def wrapped_forward_step(*args, **kwargs): return wrapped_forward_step +def _forward_only_should_return_hidden_for_linear_logp( + f: Callable[..., dict[str, list[torch.Tensor]]], + args: Namespace, +) -> bool: + return f is get_log_probs_and_entropy and should_use_linear_logp_model_output( + args, + with_entropy=args.use_rollout_entropy, + ) + + +def _train_should_return_hidden_for_linear_logp(args: Namespace, *, return_schedule_plan: bool) -> bool: + if not is_rl_kernel_op_enabled(args, "linear_logp"): + return False + + if args.loss_type not in {"policy_loss", "sft_loss"}: + return False + + if return_schedule_plan: + warn_linear_logp_fallback(args, "schedule-plan forward path is not supported") + return False + + if getattr(args, "enable_mtp_training", False): + warn_linear_logp_fallback(args, "MTP training path is not supported") + return False + + with_entropy = args.loss_type == "policy_loss" and getattr(args, "entropy_coef", 0.0) != 0 + return should_use_linear_logp_model_output(args, with_entropy=with_entropy) + + def _iter_critic_output_layers(model: Sequence[DDP]): for chunk_id, module in enumerate(unwrap_model(model)): output_layer = getattr(module, "output_layer", None) @@ -342,17 +379,25 @@ def forward_step( } if batch["multimodal_train_inputs"] is not None: forward_kwargs.update(batch["multimodal_train_inputs"]) - output_tensor = model(**forward_kwargs) + linear_logp_context = None + if _forward_only_should_return_hidden_for_linear_logp(f, args): + linear_logp_context = get_linear_logp_context_from_model(args, model) - return output_tensor, partial( - f, - args=args, - unconcat_tokens=unconcat_tokens, - total_lengths=total_lengths, - response_lengths=response_lengths, - with_entropy=args.use_rollout_entropy, - max_seq_lens=batch.get("max_seq_lens", None), - ) + with return_hidden_states_for_linear_logp(args, model, linear_logp_context): + output_tensor = model(**forward_kwargs) + + callback_kwargs = { + "args": args, + "unconcat_tokens": unconcat_tokens, + "total_lengths": total_lengths, + "response_lengths": response_lengths, + "with_entropy": args.use_rollout_entropy, + "max_seq_lens": batch.get("max_seq_lens", None), + } + if f is get_log_probs_and_entropy: + callback_kwargs["rl_kernel_linear_logp_context"] = linear_logp_context + + return output_tensor, partial(f, **callback_kwargs) # Turn on evaluation mode which disables dropout. for model_module in model: @@ -512,6 +557,10 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p old_stage = os.environ["ROUTING_REPLAY_STAGE"] os.environ["ROUTING_REPLAY_STAGE"] = "replay_forward" + linear_logp_context = None + if _train_should_return_hidden_for_linear_logp(args, return_schedule_plan=return_schedule_plan): + linear_logp_context = get_linear_logp_context_from_model(args, model) + if return_schedule_plan: assert not args.enable_mtp_training, "MTP training should not be enabled when using combined 1f1b" position_ids = None @@ -539,12 +588,20 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p if args.enable_mtp_training: forward_kwargs["mtp_kwargs"] = {"mtp_labels": batch["tokens"]} - output_tensor = model(**forward_kwargs) + with return_hidden_states_for_linear_logp(args, model, linear_logp_context): + output_tensor = model(**forward_kwargs) if os.environ.get("ENABLE_ROUTING_REPLAY", "0") == "1": os.environ["ROUTING_REPLAY_STAGE"] = old_stage - return output_tensor, partial(loss_function, args, batch, num_microbatches, step_global_batch_size) + return output_tensor, partial( + loss_function, + args, + batch, + num_microbatches, + step_global_batch_size, + rl_kernel_linear_logp_context=linear_logp_context, + ) # Forward pass. forward_backward_func = get_forward_backward_func() diff --git a/vime/backends/megatron_utils/rl_kernel.py b/vime/backends/megatron_utils/rl_kernel.py new file mode 100644 index 00000000..181afd40 --- /dev/null +++ b/vime/backends/megatron_utils/rl_kernel.py @@ -0,0 +1,319 @@ +from __future__ import annotations + +import logging +from argparse import Namespace +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any + +import torch +from megatron.core import mpu + +from vime.utils.rl_kernel import is_rl_kernel_op_enabled + +logger = logging.getLogger(__name__) + +_LOGP_OP = None +_LOGP_OP_LOAD_ERROR: Exception | None = None +_LINEAR_LOGP_OP = None +_LINEAR_LOGP_OP_LOAD_ERROR: Exception | None = None +_WARNED_FALLBACK_REASONS: set[str] = set() +_FALLBACK_COUNTS: dict[str, int] = {"logp": 0, "linear_logp": 0} + + +@dataclass(frozen=True) +class LinearLogpContext: + lm_head_weight: torch.Tensor + bias: torch.Tensor | None + tp_group: Any + vocab_start_index: int = 0 + global_vocab_size: int | None = None + sequence_parallel: bool = False + + +def get_rl_kernel_fallback_count(op: str | None = None) -> int: + if op is not None: + return _FALLBACK_COUNTS.get(op, 0) + return sum(_FALLBACK_COUNTS.values()) + + +def _warn_fallback(args: Namespace, op: str, reason: str) -> None: + _FALLBACK_COUNTS[op] = _FALLBACK_COUNTS.get(op, 0) + 1 + if getattr(args, "rl_kernel_strict", False): + raise RuntimeError(f"RL-Kernel {op} is enabled but unavailable: {reason}") + warning_key = f"{op}: {reason}" + if warning_key not in _WARNED_FALLBACK_REASONS: + logger.warning("Falling back to vime logprob path because RL-Kernel %s is unavailable: %s", op, reason) + _WARNED_FALLBACK_REASONS.add(warning_key) + + +def _get_logp_op(args: Namespace): + global _LOGP_OP, _LOGP_OP_LOAD_ERROR + if _LOGP_OP is not None: + return _LOGP_OP + if _LOGP_OP_LOAD_ERROR is not None: + _warn_fallback(args, "logp", str(_LOGP_OP_LOAD_ERROR)) + return None + + try: + from rl_engine.kernels.registry import kernel_registry + + _LOGP_OP = kernel_registry.get_op("logp") + logger.info("Using RL-Kernel logp op: %s", type(_LOGP_OP).__name__) + return _LOGP_OP + except Exception as exc: # pragma: no cover - exercised with missing optional package in integration envs + _LOGP_OP_LOAD_ERROR = exc + _warn_fallback(args, "logp", str(exc)) + return None + + +def _get_linear_logp_op(args: Namespace): + global _LINEAR_LOGP_OP, _LINEAR_LOGP_OP_LOAD_ERROR + if _LINEAR_LOGP_OP is not None: + return _LINEAR_LOGP_OP + if _LINEAR_LOGP_OP_LOAD_ERROR is not None: + _warn_fallback(args, "linear_logp", str(_LINEAR_LOGP_OP_LOAD_ERROR)) + return None + + try: + from rl_engine.kernels.registry import kernel_registry + + _LINEAR_LOGP_OP = kernel_registry.get_op("linear_logp") + logger.info("Using RL-Kernel linear_logp op: %s", type(_LINEAR_LOGP_OP).__name__) + return _LINEAR_LOGP_OP + except Exception as exc: # pragma: no cover - exercised with missing optional package in integration envs + _LINEAR_LOGP_OP_LOAD_ERROR = exc + _warn_fallback(args, "linear_logp", str(exc)) + return None + + +def _unwrap_model_chunk(model): + while hasattr(model, "module"): + model = model.module + return model + + +def _is_pipeline_last_stage_for_model(model) -> bool: + module = _unwrap_model_chunk(model) + vp_stage = getattr(module, "vp_stage", None) + try: + vp_world_size = mpu.get_virtual_pipeline_model_parallel_world_size() + except Exception: + vp_world_size = None + + try: + if vp_world_size is not None and vp_stage is not None: + return bool(mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage)) + return bool(mpu.is_pipeline_last_stage(ignore_virtual=True)) + except Exception: + return True + + +def _get_lm_head_weight(model, output_layer) -> torch.Tensor | None: + weight = getattr(output_layer, "weight", None) + if isinstance(weight, torch.Tensor): + return weight + + shared_weight = getattr(model, "shared_embedding_or_output_weight", None) + if callable(shared_weight): + try: + weight = shared_weight() + if isinstance(weight, torch.Tensor): + return weight + except Exception: + logger.debug("Unable to read shared embedding/output weight for RL-Kernel linear_logp.", exc_info=True) + + return None + + +def get_linear_logp_context_from_model(args: Namespace, model) -> LinearLogpContext | None: + if not is_rl_kernel_op_enabled(args, "linear_logp"): + return None + + if not _is_pipeline_last_stage_for_model(model): + return None + + module = _unwrap_model_chunk(model) + output_layer = getattr(module, "output_layer", None) + if output_layer is None: + _warn_fallback(args, "linear_logp", "model output_layer is unavailable") + return None + + weight = _get_lm_head_weight(module, output_layer) + if weight is None: + _warn_fallback(args, "linear_logp", "LM-head weight is unavailable") + return None + + bias = getattr(output_layer, "bias", None) + if not isinstance(bias, torch.Tensor): + bias = None + + tp_world_size = int(mpu.get_tensor_model_parallel_world_size()) + tp_group = mpu.get_tensor_model_parallel_group() if tp_world_size > 1 else None + vocab_start_index = 0 + global_vocab_size = None + if tp_world_size > 1: + local_vocab_size = int(weight.size(0)) + vocab_start_index = int(mpu.get_tensor_model_parallel_rank()) * local_vocab_size + global_vocab_size = getattr(args, "padded_vocab_size", None) + if global_vocab_size is None: + global_vocab_size = local_vocab_size * tp_world_size + + return LinearLogpContext( + lm_head_weight=weight, + bias=bias, + tp_group=tp_group, + vocab_start_index=vocab_start_index, + global_vocab_size=None if global_vocab_size is None else int(global_vocab_size), + sequence_parallel=bool(getattr(output_layer, "sequence_parallel", getattr(args, "sequence_parallel", False))), + ) + + +def _linear_logp_runtime_blocker(args: Namespace, *, with_entropy: bool) -> str | None: + if with_entropy: + return "entropy is requested" + if getattr(args, "qkv_format", "thd") != "thd": + return "only qkv_format=thd is supported by RL-Kernel linear_logp" + if mpu.get_context_parallel_world_size() != 1 or getattr(args, "allgather_cp", False): + return "context parallel logprob redistribution is not supported by RL-Kernel linear_logp" + if getattr(args, "rollout_temperature", 1.0) <= 0: + return "rollout_temperature must be positive" + return None + + +def should_use_linear_logp_model_output(args: Namespace, *, with_entropy: bool) -> bool: + if not is_rl_kernel_op_enabled(args, "linear_logp"): + return False + reason = _linear_logp_runtime_blocker(args, with_entropy=with_entropy) + if reason is not None: + _warn_fallback(args, "linear_logp", reason) + return False + return True + + +def warn_linear_logp_fallback(args: Namespace, reason: str) -> None: + _warn_fallback(args, "linear_logp", reason) + + +@contextmanager +def return_hidden_states_for_linear_logp(args: Namespace, model, context: LinearLogpContext | None): + if context is None: + yield False + return + + module = _unwrap_model_chunk(model) + if not hasattr(module, "post_process"): + _warn_fallback(args, "linear_logp", "model post_process flag is unavailable") + yield False + return + + old_post_process = module.post_process + module.post_process = False + try: + yield True + finally: + module.post_process = old_post_process + + +def maybe_compute_logp( + logits: torch.Tensor, + tokens: torch.Tensor, + *, + args: Namespace, + with_entropy: bool, +) -> torch.Tensor | None: + """Return selected log-probs from RL-Kernel when this runtime is safe. + + The first integration deliberately limits itself to forward-only logprob + precompute paths: no autograd, no vocab tensor parallelism, no CP + redistribution, and no entropy. Unsupported cases fall back to vime's + Megatron-aware implementation. + """ + if not is_rl_kernel_op_enabled(args, "logp"): + return None + + if with_entropy: + _warn_fallback(args, "logp", "entropy is requested") + return None + + if logits.requires_grad or torch.is_grad_enabled(): + _warn_fallback(args, "logp", "autograd is enabled") + return None + + if mpu.get_tensor_model_parallel_world_size() != 1: + _warn_fallback(args, "logp", "tensor-parallel vocab shards are not supported by RL-Kernel logp") + return None + + if mpu.get_context_parallel_world_size() != 1 or getattr(args, "allgather_cp", False): + _warn_fallback(args, "logp", "context parallel logprob redistribution is not supported by RL-Kernel logp") + return None + + if logits.size(0) == 0: + return logits.new_zeros((0,), dtype=torch.float32) + + op = _get_logp_op(args) + if op is None: + return None + + try: + if hasattr(op, "apply_fp32"): + log_prob = op.apply_fp32(logits, tokens) + else: + log_prob = op(logits, tokens).float() + except Exception as exc: + _warn_fallback(args, "logp", str(exc)) + return None + + return log_prob.reshape(-1) + + +def maybe_compute_linear_logp( + hidden_states: torch.Tensor, + target_ids: torch.Tensor, + *, + context: LinearLogpContext | None, + args: Namespace, + with_entropy: bool, +) -> torch.Tensor | None: + if not is_rl_kernel_op_enabled(args, "linear_logp"): + return None + + reason = _linear_logp_runtime_blocker(args, with_entropy=with_entropy) + if reason is not None: + _warn_fallback(args, "linear_logp", reason) + return None + + if context is None: + _warn_fallback(args, "linear_logp", "hidden-state linear_logp context is unavailable") + return None + + if target_ids.numel() == 0: + return hidden_states.new_zeros((0,), dtype=torch.float32) + + op = _get_linear_logp_op(args) + if op is None: + return None + + weight = context.lm_head_weight + bias = context.bias + rollout_temperature = float(getattr(args, "rollout_temperature", 1.0)) + if rollout_temperature != 1.0: + weight = weight / rollout_temperature + if bias is not None: + bias = bias / rollout_temperature + + try: + log_prob = op( + hidden_states, + weight, + target_ids.long(), + bias, + tp_group=context.tp_group, + vocab_start_index=context.vocab_start_index, + global_vocab_size=context.global_vocab_size, + ) + except Exception as exc: + _warn_fallback(args, "linear_logp", str(exc)) + return None + + return log_prob.float().reshape(-1) diff --git a/vime/ray/actor_group.py b/vime/ray/actor_group.py index db3500dc..0e601671 100644 --- a/vime/ray/actor_group.py +++ b/vime/ray/actor_group.py @@ -63,6 +63,7 @@ def _allocate_gpus_for_actor(self, pg, num_gpus_per_actor): import torch_memory_saver for path in [ + "torch_memory_saver_hook_mode_preload_cu13.abi3.so", "torch_memory_saver_hook_mode_preload_cu12.abi3.so", "torch_memory_saver_hook_mode_preload.abi3.so", ]: diff --git a/vime/utils/arguments.py b/vime/utils/arguments.py index 147f4d2b..a186a756 100644 --- a/vime/utils/arguments.py +++ b/vime/utils/arguments.py @@ -12,6 +12,7 @@ from vime.backends.vllm_utils.arguments import vllm_parse_args from vime.utils.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list from vime.utils.logging_utils import configure_logger +from vime.utils.rl_kernel import normalize_rl_kernel_args logger = logging.getLogger(__name__) @@ -157,6 +158,24 @@ def add_train_arguments(parser): parser.add_argument( "--log-probs-chunk-size", type=int, default=-1, help="Chunk size to compute log probs to save memory" ) + parser.add_argument( + "--enable-rl-kernel", + action="store_true", + default=False, + help="Enable optional RL-Kernel acceleration for supported vime training/forward-only paths.", + ) + parser.add_argument( + "--rl-kernel-ops", + type=str, + default="linear_logp", + help="Comma-separated RL-Kernel ops to enable. Current production integration supports: linear_logp.", + ) + parser.add_argument( + "--rl-kernel-strict", + action="store_true", + default=False, + help="Raise instead of falling back when an enabled RL-Kernel op is unavailable or unsupported.", + ) parser.add_argument( "--only-train-params-name-list", type=str, @@ -1606,6 +1625,8 @@ def _resolve_eval_datasets(args) -> list[EvalDatasetConfig]: def vime_validate_args(args): + normalize_rl_kernel_args(args) + args.eval_datasets = _resolve_eval_datasets(args) if args.kl_coef != 0 or args.use_kl_loss: @@ -1849,6 +1870,17 @@ def vime_validate_args(args): if hasattr(args, k): logger.info(f"Warning: Argument {k} is already set to {getattr(args, k)}, will override with {v}.") setattr(args, k, v) + # vllm launch_server_process distinguishes "user-supplied value" from + # "argparse default" via ``args._vllm_user_provided``. YAML overrides + # bypass argparse, so we register them explicitly here — without this, + # YAML values that happen to equal the vllm-side default (e.g. + # ``vllm_gpu_memory_utilization: 0.92``) would be treated as "default" + # and silently replaced by vime's preferred value. + if isinstance(k, str) and k.startswith("vllm_"): + if not hasattr(args, "_vllm_user_provided"): + args._vllm_user_provided = set() + args._vllm_user_provided.add(k) + normalize_rl_kernel_args(args) if args.eval_max_context_len is None: logger.info( diff --git a/vime/utils/rl_kernel.py b/vime/utils/rl_kernel.py new file mode 100644 index 00000000..e0e4a0b7 --- /dev/null +++ b/vime/utils/rl_kernel.py @@ -0,0 +1,79 @@ +import os +from argparse import Namespace +from collections.abc import Iterable + + +RL_KERNEL_SUPPORTED_OPS = ("linear_logp",) +RL_KERNEL_INTEGRATED_OPS = ("linear_logp",) +_TRUE_VALUES = {"1", "true", "yes", "on"} +_FALSE_VALUES = {"0", "false", "no", "off"} + + +def parse_rl_kernel_ops(value: str | Iterable[str] | None) -> tuple[str, ...]: + """Parse a comma/space separated RL-Kernel op list.""" + if value is None: + return ("linear_logp",) + + if isinstance(value, str): + raw_items = value.replace(",", " ").split() + else: + raw_items = [] + for item in value: + raw_items.extend(str(item).replace(",", " ").split()) + + ops: list[str] = [] + for item in raw_items: + op = item.strip().lower() + if not op: + continue + if op not in RL_KERNEL_SUPPORTED_OPS: + supported = ", ".join(RL_KERNEL_SUPPORTED_OPS) + raise ValueError(f"Unsupported RL-Kernel op '{op}'. Supported ops: {supported}.") + if op not in ops: + ops.append(op) + + return tuple(ops) if ops else ("linear_logp",) + + +def _env_bool(name: str) -> bool | None: + value = os.getenv(name) + if value is None: + return None + normalized = value.strip().lower() + if normalized in _TRUE_VALUES: + return True + if normalized in _FALSE_VALUES: + return False + raise ValueError(f"{name} must be one of: 1/0, true/false, yes/no, on/off.") + + +def normalize_rl_kernel_args(args: Namespace) -> Namespace: + """Apply environment overrides and validate RL-Kernel integration args.""" + env_enabled = _env_bool("VIME_RL_KERNEL") + if env_enabled is not None: + args.enable_rl_kernel = env_enabled + + env_strict = _env_bool("VIME_RL_KERNEL_STRICT") + if env_strict is not None: + args.rl_kernel_strict = env_strict + + env_ops = os.getenv("VIME_RL_KERNEL_OPS") + if env_ops is not None: + args.rl_kernel_ops = env_ops + + args.rl_kernel_ops = parse_rl_kernel_ops(getattr(args, "rl_kernel_ops", None)) + + if getattr(args, "enable_rl_kernel", False): + unsupported = [op for op in args.rl_kernel_ops if op not in RL_KERNEL_INTEGRATED_OPS] + if unsupported: + integrated = ", ".join(RL_KERNEL_INTEGRATED_OPS) + raise ValueError( + "This vime RL-Kernel integration currently supports only " + f"{integrated}. Requested future ops: {', '.join(unsupported)}." + ) + + return args + + +def is_rl_kernel_op_enabled(args: Namespace, op: str) -> bool: + return getattr(args, "enable_rl_kernel", False) and op in getattr(args, "rl_kernel_ops", ())