Derive the TransferQueue "need collect" decision inside tqbridge from the
registered dispatch_mode instead of threading a `collect_from_rank` kwarg
through dispatch_nd_compute.

NOTE(review): line breaks and indentation were reconstructed from a
whitespace-collapsed copy of this patch. Hunk line counts match the headers,
but leading whitespace in context/removed lines is a best guess; if strict
matching fails, apply with `git apply --ignore-whitespace` or `patch -l`.

diff --git a/verl/single_controller/base/decorator.py b/verl/single_controller/base/decorator.py
index 2727ee07f99..f764718ce04 100644
--- a/verl/single_controller/base/decorator.py
+++ b/verl/single_controller/base/decorator.py
@@ -217,7 +217,7 @@ def collect_dp_compute_data_proto(worker_group, output):
     return _concat_data_proto_or_future(output)
 
 
-def dispatch_nd_compute(dp_rank_mapping: list[int], dp_size, worker_group, collect_mask, *args, **kwargs):
+def dispatch_nd_compute(dp_rank_mapping: list[int], dp_size, worker_group, *args, **kwargs):
     import os
 
     from verl.single_controller.base.worker_group import WorkerGroup
@@ -248,10 +248,6 @@ def dispatch_nd_compute(dp_rank_mapping: list[int], dp_size, worker_group, colle
             local_dp_rank = dp_rank_mapping[i]
             transformed_v.append(v[local_dp_rank])
         all_kwargs[k] = transformed_v
-
-    # add kwargs determing whether to collect from this rank
-    all_kwargs["collect_from_rank"] = [collect_mask[i] for i in range(worker_group.world_size)]
-
     return all_args, all_kwargs
 
 
@@ -269,9 +265,9 @@ def collect_nd_compute(collect_mask: list[bool], worker_group, output):
     return output_in_dp
 
 
-def dispatch_nd_compute_dataproto(dp_rank_mapping: list[int], dp_size, worker_group, collect_mask, *args, **kwargs):
+def dispatch_nd_compute_dataproto(dp_rank_mapping: list[int], dp_size, worker_group, *args, **kwargs):
     splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(dp_size, *args, **kwargs)
-    return dispatch_nd_compute(dp_rank_mapping, dp_size, worker_group, collect_mask, *splitted_args, **splitted_kwargs)
+    return dispatch_nd_compute(dp_rank_mapping, dp_size, worker_group, *splitted_args, **splitted_kwargs)
 
 
 def collect_nd_compute_dataproto(collect_mask: list[bool], worker_group, output):
@@ -297,20 +293,10 @@ def dispatch_lazy_compute_data_proto(mesh_name, worker_group, *args, **kwargs):
         worker_group._dispatch_info[mesh_name] = worker_group._query_dispatch_info(mesh_name)
         assert len(worker_group._dispatch_info[mesh_name]) == worker_group.world_size
 
-    # the dispatch info is stored in the worker group
-    assert mesh_name in worker_group._dispatch_info
-    if mesh_name not in worker_group._collect_info:
-        worker_group._collect_info[mesh_name] = worker_group._query_collect_info(mesh_name)
-        assert len(worker_group._collect_info[mesh_name]) == worker_group.world_size
-
     dp_rank_mapping = worker_group._dispatch_info[mesh_name]
-
-    # a boolean of whether the dp_rank is used for collect
-    collect_mask = worker_group._collect_info[mesh_name]
-
     # perform dispatch
     dp_size = max(dp_rank_mapping) + 1
-    return dispatch_nd_compute_dataproto(dp_rank_mapping, dp_size, worker_group, collect_mask, *args, **kwargs)
+    return dispatch_nd_compute_dataproto(dp_rank_mapping, dp_size, worker_group, *args, **kwargs)
 
 
 def collect_lazy_compute_data_proto(mesh_name, worker_group, *args, **kwargs):
@@ -456,7 +442,7 @@ def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocki
     _check_execute_mode(execute_mode=execute_mode)
 
     def decorator(func):
-        func = tqbridge()(func)
+        func = tqbridge(dispatch_mode=dispatch_mode)(func)
 
         @wraps(func)
         def inner(*args, **kwargs):
diff --git a/verl/utils/transferqueue_utils.py b/verl/utils/transferqueue_utils.py
index 0e9f8f30873..4a35cbfb56c 100644
--- a/verl/utils/transferqueue_utils.py
+++ b/verl/utils/transferqueue_utils.py
@@ -18,7 +18,10 @@
 import os
 import threading
 from functools import wraps
-from typing import Any, Callable
+from typing import TYPE_CHECKING, Any, Callable
+
+if TYPE_CHECKING:
+    from verl.single_controller.base.decorator import Dispatch
 
 from tensordict import TensorDict
 
@@ -144,7 +147,84 @@ def _update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta", fun
     return updated_batch_meta
 
 
-def tqbridge(put_data: bool = True):
+def _compute_need_collect(dispatch_mode: "dict | Dispatch | None", args: list) -> bool:
+    """Compute whether data collection is needed for the current worker.
+
+    This function determines whether the current worker should collect data based on
+    the dispatch mode configuration and worker parameters. It's used to optimize
+    distributed data collection by ensuring only the appropriate rank collects data.
+
+    Args:
+        dispatch_mode: Controls data collection logic for the current worker. Can be None,
+            a Dispatch instance, or a dict with 'collect_fn' key. If None or Dispatch,
+            always returns True (current worker should collect). If dict, checks
+            collect_fn for lazy compute optimization.
+        args: List of arguments passed to the function. Should contain a Worker instance
+            as the first argument when using lazy compute mode.
+
+    Returns:
+        bool: True if data collection is needed, False otherwise.
+
+    Note:
+        Only checks worker attributes when dispatch_mode is a dict with 'collect_fn',
+        the collect_fn is 'collect_lazy_compute_data_proto', and args[0] is a Worker.
+        Otherwise, returns True. For the lazy compute case, checks the worker's
+        data parallel rank for the mesh specified in collect_fn.args[0] to determine
+        if this worker should collect data.
+    """
+    from verl.single_controller.base.decorator import Dispatch
+    from verl.single_controller.base.worker import Worker
+
+    if dispatch_mode is None or isinstance(dispatch_mode, Dispatch):
+        return True
+
+    assert "collect_fn" in dispatch_mode, "collect_fn should be in dispatch_mode."
+    collect_fn_name = dispatch_mode["collect_fn"].func.__name__
+    if collect_fn_name != "collect_lazy_compute_data_proto" or len(args) < 1 or not isinstance(args[0], Worker):
+        return True
+
+    collect_mesh_name = dispatch_mode["collect_fn"].args[0]
+    return args[0]._Worker__collect_dp_rank[collect_mesh_name]
+
+
+def _postprocess_common(output, put_data, need_collect):
+    """Common post-processing logic for function outputs in TransferQueue bridge.
+
+    This function handles the final return value based on whether data should be
+    put into storage (put_data) and whether collection is needed (need_collect).
+    It ensures proper return types based on the execution context.
+
+    Args:
+        output: The original output from the decorated function. Can be any type,
+            typically DataProto when working with transfer queues.
+        put_data: bool, indicating whether the output should be stored in TransferQueue.
+            If True, output will be converted to BatchMeta; if False, returned as-is
+            or converted to DataProto.
+        need_collect: bool, indicating whether this process needs to collect data.
+            If False and put_data is True, returns empty BatchMeta to avoid
+            redundant storage.
+
+    Returns:
+        - BatchMeta.empty(): When put_data=True but need_collect=False, indicating
+          no data should be stored but BatchMeta structure is expected.
+        - DataProto(): When put_data=False, need_collect=False, and output is DataProto,
+          returning an empty DataProto.
+        - output: In all other cases, returns the original output unchanged.
+
+    Note:
+        This function is used in the tqbridge decorator to normalize return values
+        across different execution paths and avoid redundant data operations in
+        distributed scenarios.
+    """
+    if put_data and not need_collect:
+        return BatchMeta.empty()
+    elif not put_data and not need_collect and isinstance(output, DataProto):
+        return DataProto()
+    else:
+        return output
+
+
+def tqbridge(dispatch_mode: "dict | Dispatch | None" = None, put_data: bool = True):
     """Creates a decorator for bridging BatchMeta and DataProto.
 
     This decorator automatically handles conversions between `BatchMeta` and
@@ -155,6 +235,9 @@ def tqbridge(put_data: bool = True):
     simply calls the original function as-is).
 
     Args:
+        dispatch_mode: Controls data collection behavior for the current worker. Passed to
+                  _compute_need_collect to determine if current worker should collect data.
+                  If None, _compute_need_collect returns True (current worker collects).
         put_data: Whether put the DataProto into Storage after func return.
                   If True, after function execution, the output result will be
                   updated to `BatchMeta` and `BatchMeta` will be returned;
@@ -178,24 +261,14 @@ def inner(*args, **kwargs):
                     f"Task {func.__name__} (pid={pid}) is getting len_samples={batchmeta.size}, "
                     f"global_idx={batchmeta.global_indexes}"
                 )
-
-                if "collect_from_rank" in kwargs:
-                    collect_from_rank = kwargs["collect_from_rank"]
-                    kwargs.pop("collect_from_rank")
-                else:
-                    collect_from_rank = None
-
                 args = [_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args]
                 kwargs = {k: _batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v for k, v in kwargs.items()}
                 output = func(*args, **kwargs)
-
-                if put_data and collect_from_rank:
+                need_collect = _compute_need_collect(dispatch_mode, args)
+                if put_data and need_collect:
                     updated_batch_meta = _update_batchmeta_with_output(output, batchmeta, func.__name__)
                     return updated_batch_meta
-                elif collect_from_rank == False:
-                    return BatchMeta()
-                else:
-                    return output
+                return _postprocess_common(output, put_data, need_collect)
 
         @wraps(func)
         async def async_inner(*args, **kwargs):
@@ -207,39 +280,33 @@ async def async_inner(*args, **kwargs):
                     f"Task {func.__name__} (pid={pid}) is getting len_samples={batchmeta.size}, "
                     f"global_idx={batchmeta.global_indexes}"
                 )
-
-                if "collect_from_rank" in kwargs:
-                    collect_from_rank = kwargs["collect_from_rank"]
-                    print(f"{func.__name__} with TQ put={kwargs['collect_from_rank']}")
-                    kwargs.pop("collect_from_rank")
-                else:
-                    collect_from_rank = None
-
                 args = [await _async_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args]
                 kwargs = {
                     k: await _async_batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v
                     for k, v in kwargs.items()
                 }
                 output = await func(*args, **kwargs)
-
-                if put_data and collect_from_rank:
+                need_collect = _compute_need_collect(dispatch_mode, args)
+                if put_data and need_collect:
                     updated_batchmeta = await _async_update_batchmeta_with_output(output, batchmeta, func.__name__)
                     return updated_batchmeta
-                elif collect_from_rank == False:
-                    return BatchMeta()
-                return output
+                return _postprocess_common(output, put_data, need_collect)
 
         @wraps(func)
         def dummy_inner(*args, **kwargs):
-            if "collect_from_rank" in kwargs:
-                kwargs.pop("collect_from_rank")
-            return func(*args, **kwargs)
+            output = func(*args, **kwargs)
+            need_collect = _compute_need_collect(dispatch_mode, args)
+            if not need_collect:
+                return DataProto()
+            return output
 
         @wraps(func)
         async def dummy_async_inner(*args, **kwargs):
-            if "collect_from_rank" in kwargs:
-                kwargs.pop("collect_from_rank")
-            return await func(*args, **kwargs)
+            output = await func(*args, **kwargs)
+            need_collect = _compute_need_collect(dispatch_mode, args)
+            if not need_collect:
+                return DataProto()
+            return output
 
         wrapper_inner = inner if is_transferqueue_enabled else dummy_inner
         wrapper_async_inner = async_inner if is_transferqueue_enabled else dummy_async_inner