[WIP]add hp as fsdp engine to verl#6835
Conversation
|
zhang_xue_tong seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you have already a GitHub account, please add the email address used for this commit to your account. You have signed the CLA already but the status is still pending? Let us recheck it. |
There was a problem hiding this comment.
Code Review
This pull request introduces HyperParallel engine support to verl, adding new configuration files, the HyperParallelEngine and HyperParallelEngineWithLMHead implementations, and associated utility functions. It also improves the GSM8K reward scoring with a more flexible answer extraction method. The code review highlights several critical issues: a potential OOM error from calling state_dict() on all ranks, a missing Ulysses SP padding/slicing implementation causing a KeyError, a missing SkipDTensorDispatch context manager in the training loop, a potential TypeError from passing arbitrary keyword arguments in gsm8k_flexible.py, a copy-paste configuration error in hyperparallel_critic.yaml, and a missing divisibility check for fsdp_size.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| "reshard_after_forward": self.engine_config.reshard_after_forward, | ||
| } | ||
| # if not self.engine_config.forward_only: | ||
| full_state = module.state_dict() |
There was a problem hiding this comment.
Calling module.state_dict() on all ranks before the model is sharded will cause a massive CPU/GPU memory overhead and likely lead to Out-Of-Memory (OOM) errors on non-zero ranks for large models. Since fsdp2_load_full_state_dict broadcasts the state dict from rank 0, non-zero ranks only need an empty dictionary.
| full_state = module.state_dict() | |
| full_state = module.state_dict() if self.rank == 0 else {} |
| # for compute the log_prob | ||
| input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) | ||
| input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) | ||
| temperature_rmpad = temperature_rmpad.squeeze(0) | ||
| output_args["input_ids_rmpad_rolled"] = input_ids_rmpad_rolled | ||
| output_args["temperature_rmpad"] = temperature_rmpad |
There was a problem hiding this comment.
HyperParallelEngineWithLMHead.prepare_model_inputs is missing the Ulysses SP padding and slicing logic. When ulysses_sequence_parallel_size > 1, this will cause a KeyError: 'pad_size' in prepare_model_outputs because pad_size is never set in output_args. We should add the padding and slicing logic here.
# for compute the log_prob
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz)
if self.use_ulysses_sp:
from verl.utils.ulysses import ulysses_pad, ulysses_pad_and_slice_inputs
is_vlm_model = hasattr(getattr(self.module, "module", self.module).config, "vision_config")
if is_vlm_model:
input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad(
input_ids_rmpad,
position_ids_rmpad=position_ids_rmpad,
sp_size=self.ulysses_sequence_parallel_size,
)
else:
input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
input_ids_rmpad,
position_ids_rmpad=position_ids_rmpad,
sp_size=self.ulysses_sequence_parallel_size,
)
input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
input_ids_rmpad_rolled,
position_ids_rmpad=None,
sp_size=self.ulysses_sequence_parallel_size,
)
temperature_rmpad, _, _ = ulysses_pad_and_slice_inputs(
temperature_rmpad, position_ids_rmpad=None, sp_size=self.ulysses_sequence_parallel_size, pad_value=1
)
output_args["pad_size"] = pad_size
input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)
temperature_rmpad = temperature_rmpad.squeeze(0)
output_args["input_ids_rmpad_rolled"] = input_ids_rmpad_rolled
output_args["temperature_rmpad"] = temperature_rmpad| for micro_batch in micro_batches: | ||
| with ctx: | ||
| loss, meta_info = self.forward_step( | ||
| micro_batch, loss_function=loss_function, forward_only=forward_only, | ||
| ) | ||
| if not forward_only: | ||
| loss.backward() | ||
| output_lst.append(meta_info) |
There was a problem hiding this comment.
The comment explicitly mentions that the training loop is run inside SkipDTensorDispatch to avoid PyTorch's view-safety check error, but the context manager is missing in the actual code. Wrapping the micro-batch loop in SkipDTensorDispatch is necessary to prevent runtime crashes during training.
| for micro_batch in micro_batches: | |
| with ctx: | |
| loss, meta_info = self.forward_step( | |
| micro_batch, loss_function=loss_function, forward_only=forward_only, | |
| ) | |
| if not forward_only: | |
| loss.backward() | |
| output_lst.append(meta_info) | |
| with SkipDTensorDispatch(): | |
| for micro_batch in micro_batches: | |
| with ctx: | |
| loss, meta_info = self.forward_step( | |
| micro_batch, loss_function=loss_function, forward_only=forward_only, | |
| ) | |
| if not forward_only: | |
| loss.backward() | |
| output_lst.append(meta_info) |
| return gsm8k_compute_score( | ||
| solution_str=solution_str, | ||
| ground_truth=actual_ground_truth, | ||
| method="flexible", | ||
| **kwargs | ||
| ) |
There was a problem hiding this comment.
gsm8k_compute_score does not accept arbitrary **kwargs in its signature. Passing **kwargs directly can cause a TypeError if unexpected keyword arguments are present in the dictionary. We should explicitly extract and pass only the supported arguments (format_score and score).
| return gsm8k_compute_score( | |
| solution_str=solution_str, | |
| ground_truth=actual_ground_truth, | |
| method="flexible", | |
| **kwargs | |
| ) | |
| return gsm8k_compute_score( | |
| solution_str=solution_str, | |
| ground_truth=actual_ground_truth, | |
| method="flexible", | |
| format_score=kwargs.get("format_score", 0.0), | |
| score=kwargs.get("score", 1.0), | |
| ) |
| _target_: verl.workers.config.FSDPCriticConfig | ||
|
|
||
| strategy: fsdp |
There was a problem hiding this comment.
This configuration incorrectly targets FSDPCriticConfig and sets strategy: fsdp. This is a copy-paste error and will cause Hydra initialization errors because FSDPCriticConfig does not have a hyperparallel field, or it will fail to use the hyperparallel engine. It should target HyperParallelCriticConfig (or the appropriate config class) and use strategy: hyperparallel.
| if fsdp_size < 0 or fsdp_size >= world_size: | ||
| self.device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) | ||
| else: | ||
| self.device_mesh = init_device_mesh( | ||
| device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"] | ||
| ) |
There was a problem hiding this comment.
When fsdp_size is specified, we should assert that world_size is divisible by fsdp_size to prevent shape mismatch errors in init_device_mesh.
| if fsdp_size < 0 or fsdp_size >= world_size: | |
| self.device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) | |
| else: | |
| self.device_mesh = init_device_mesh( | |
| device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"] | |
| ) | |
| if fsdp_size < 0 or fsdp_size >= world_size: | |
| self.device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) | |
| else: | |
| assert world_size % fsdp_size == 0, f"world_size ({world_size}) must be divisible by fsdp_size ({fsdp_size})" | |
| self.device_mesh = init_device_mesh( | |
| device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"] | |
| ) |
|
@Alan-zr Thanks for your contribution, we're not going to maintain more training backend in source tree. Instead, we provide hook mechanism to dynamically plugin external training backend. |
What does this PR do?
Checklist Before Starting
[{modules}] {type}: {description}(This will be checked by the CI){modules}includefsdp,megatron,veomni,sglang,vllm,rollout,trainer,ci,training_utils,recipe,hardware,deployment,ray,worker,single_controller,misc,perf,model,algo,env,tool,ckpt,doc,data,cfg,reward,fully_async,one_step_off,like[megatron, fsdp, doc]{type}is infeat,fix,refactor,chore,test[BREAKING]to the beginning of the title.[BREAKING][fsdp, megatron] feat: dynamic batchingTest
API and Usage Example
# Add code snippet or script demonstrating how to use thisDesign & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=alwaysci-requestchannel in theverlSlack workspace. (If not accessible, please try the Feishu group (飞书群).)recipesubmodule, please also update the reference to the submodule commit viagit submodule update --remoteorcd recipe && git pull origin main.