Skip to content

[WIP]add hp as fsdp engine to verl#6835

Draft
Alan-zr wants to merge 1 commit into
verl-project:mainfrom
Alan-zr:hp
Draft

[WIP]add hp as fsdp engine to verl#6835
Alan-zr wants to merge 1 commit into
verl-project:mainfrom
Alan-zr:hp

Conversation

@Alan-zr

@Alan-zr Alan-zr commented Jun 24, 2026

Copy link
Copy Markdown

What does this PR do?

Add concise overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, veomni, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, fully_async, one_step_off
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

@CLAassistant

Copy link
Copy Markdown

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.


zhang_xue_tong seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you have already a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces HyperParallel engine support to verl, adding new configuration files, the HyperParallelEngine and HyperParallelEngineWithLMHead implementations, and associated utility functions. It also improves the GSM8K reward scoring with a more flexible answer extraction method. The code review highlights several critical issues: a potential OOM error from calling state_dict() on all ranks, a missing Ulysses SP padding/slicing implementation causing a KeyError, a missing SkipDTensorDispatch context manager in the training loop, a potential TypeError from passing arbitrary keyword arguments in gsm8k_flexible.py, a copy-paste configuration error in hyperparallel_critic.yaml, and a missing divisibility check for fsdp_size.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

"reshard_after_forward": self.engine_config.reshard_after_forward,
}
# if not self.engine_config.forward_only:
full_state = module.state_dict()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Calling module.state_dict() on all ranks before the model is sharded will cause a massive CPU/GPU memory overhead and likely lead to Out-Of-Memory (OOM) errors on non-zero ranks for large models. Since fsdp2_load_full_state_dict broadcasts the state dict from rank 0, non-zero ranks only need an empty dictionary.

Suggested change
full_state = module.state_dict()
full_state = module.state_dict() if self.rank == 0 else {}

Comment on lines +440 to +445
# for compute the log_prob
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz)
input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad)
temperature_rmpad = temperature_rmpad.squeeze(0)
output_args["input_ids_rmpad_rolled"] = input_ids_rmpad_rolled
output_args["temperature_rmpad"] = temperature_rmpad

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

HyperParallelEngineWithLMHead.prepare_model_inputs is missing the Ulysses SP padding and slicing logic. When ulysses_sequence_parallel_size > 1, this will cause a KeyError: 'pad_size' in prepare_model_outputs because pad_size is never set in output_args. We should add the padding and slicing logic here.

            # for compute the log_prob
            input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)

            if self.use_ulysses_sp:
                from verl.utils.ulysses import ulysses_pad, ulysses_pad_and_slice_inputs
                is_vlm_model = hasattr(getattr(self.module, "module", self.module).config, "vision_config")
                if is_vlm_model:
                    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad(
                        input_ids_rmpad,
                        position_ids_rmpad=position_ids_rmpad,
                        sp_size=self.ulysses_sequence_parallel_size,
                    )
                else:
                    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
                        input_ids_rmpad,
                        position_ids_rmpad=position_ids_rmpad,
                        sp_size=self.ulysses_sequence_parallel_size,
                    )
                input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
                    input_ids_rmpad_rolled,
                    position_ids_rmpad=None,
                    sp_size=self.ulysses_sequence_parallel_size,
                )
                temperature_rmpad, _, _ = ulysses_pad_and_slice_inputs(
                    temperature_rmpad, position_ids_rmpad=None, sp_size=self.ulysses_sequence_parallel_size, pad_value=1
                )
                output_args["pad_size"] = pad_size

            input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)
            temperature_rmpad = temperature_rmpad.squeeze(0)
            output_args["input_ids_rmpad_rolled"] = input_ids_rmpad_rolled
            output_args["temperature_rmpad"] = temperature_rmpad

Comment on lines +303 to +310
for micro_batch in micro_batches:
with ctx:
loss, meta_info = self.forward_step(
micro_batch, loss_function=loss_function, forward_only=forward_only,
)
if not forward_only:
loss.backward()
output_lst.append(meta_info)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The comment explicitly mentions that the training loop is run inside SkipDTensorDispatch to avoid PyTorch's view-safety check error, but the context manager is missing in the actual code. Wrapping the micro-batch loop in SkipDTensorDispatch is necessary to prevent runtime crashes during training.

Suggested change
for micro_batch in micro_batches:
with ctx:
loss, meta_info = self.forward_step(
micro_batch, loss_function=loss_function, forward_only=forward_only,
)
if not forward_only:
loss.backward()
output_lst.append(meta_info)
with SkipDTensorDispatch():
for micro_batch in micro_batches:
with ctx:
loss, meta_info = self.forward_step(
micro_batch, loss_function=loss_function, forward_only=forward_only,
)
if not forward_only:
loss.backward()
output_lst.append(meta_info)

Comment on lines +47 to +52
return gsm8k_compute_score(
solution_str=solution_str,
ground_truth=actual_ground_truth,
method="flexible",
**kwargs
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

gsm8k_compute_score does not accept arbitrary **kwargs in its signature. Passing **kwargs directly can cause a TypeError if unexpected keyword arguments are present in the dictionary. We should explicitly extract and pass only the supported arguments (format_score and score).

Suggested change
return gsm8k_compute_score(
solution_str=solution_str,
ground_truth=actual_ground_truth,
method="flexible",
**kwargs
)
return gsm8k_compute_score(
solution_str=solution_str,
ground_truth=actual_ground_truth,
method="flexible",
format_score=kwargs.get("format_score", 0.0),
score=kwargs.get("score", 1.0),
)

Comment on lines +15 to +17
_target_: verl.workers.config.FSDPCriticConfig

strategy: fsdp

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This configuration incorrectly targets FSDPCriticConfig and sets strategy: fsdp. This is a copy-paste error and will cause Hydra initialization errors because FSDPCriticConfig does not have a hyperparallel field, or it will fail to use the hyperparallel engine. It should target HyperParallelCriticConfig (or the appropriate config class) and use strategy: hyperparallel.

Comment on lines +117 to +122
if fsdp_size < 0 or fsdp_size >= world_size:
self.device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=["fsdp"])
else:
self.device_mesh = init_device_mesh(
device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"]
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

When fsdp_size is specified, we should assert that world_size is divisible by fsdp_size to prevent shape mismatch errors in init_device_mesh.

Suggested change
if fsdp_size < 0 or fsdp_size >= world_size:
self.device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=["fsdp"])
else:
self.device_mesh = init_device_mesh(
device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"]
)
if fsdp_size < 0 or fsdp_size >= world_size:
self.device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=["fsdp"])
else:
assert world_size % fsdp_size == 0, f"world_size ({world_size}) must be divisible by fsdp_size ({fsdp_size})"
self.device_mesh = init_device_mesh(
device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"]
)

@wuxibin89

Copy link
Copy Markdown
Collaborator

@Alan-zr Thanks for your contribution, we're not going to maintain more training backend in source tree. Instead, we provide hook mechanism to dynamically plugin external training backend.

https://verl.readthedocs.io/en/latest/extend_guide.html#i-m-a-training-framework-developer-how-do-i-extend-verl-to-support-my-own-training-framework

@Alan-zr Alan-zr changed the title add hyperparallel as fsdp engine to verl [WIP]add hyperparallel as fsdp engine to verl Jun 24, 2026
@Alan-zr Alan-zr changed the title [WIP]add hyperparallel as fsdp engine to verl [WIP]add hp as fsdp engine to verl Jun 24, 2026
@Alan-zr Alan-zr marked this pull request as draft June 24, 2026 08:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants