-
Notifications
You must be signed in to change notification settings - Fork 6
elegantly prevent data re-put within DP #50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5ff95b8
6da262c
d2d6927
a513570
9989e83
ce78ce4
8381033
e58fdd6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -18,7 +18,10 @@ | |||||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||||||
| import threading | ||||||||||||||||||||||||||||||||||||
| from functools import wraps | ||||||||||||||||||||||||||||||||||||
| from typing import Any, Callable | ||||||||||||||||||||||||||||||||||||
| from typing import TYPE_CHECKING, Any, Callable | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if TYPE_CHECKING: | ||||||||||||||||||||||||||||||||||||
| from verl.single_controller.base.decorator import Dispatch | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| from tensordict import TensorDict | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
@@ -144,7 +147,84 @@ def _update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta", fun | |||||||||||||||||||||||||||||||||||
| return updated_batch_meta | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def tqbridge(put_data: bool = True): | ||||||||||||||||||||||||||||||||||||
| def _compute_need_collect(dispatch_mode: dict | "Dispatch", args: list) -> bool: | ||||||||||||||||||||||||||||||||||||
| """Compute whether data collection is needed for the current worker. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| This function determines whether the current worker should collect data based on | ||||||||||||||||||||||||||||||||||||
| the dispatch mode configuration and worker parameters. It's used to optimize | ||||||||||||||||||||||||||||||||||||
| distributed data collection by ensuring only the appropriate rank collects data. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||
| dispatch_mode: Controls data collection logic for the current worker. Can be None, | ||||||||||||||||||||||||||||||||||||
| a Dispatch instance, or a dict with 'collect_fn' key. If None or Dispatch, | ||||||||||||||||||||||||||||||||||||
| always returns True (current worker should collect). If dict, checks | ||||||||||||||||||||||||||||||||||||
| collect_fn for lazy compute optimization. | ||||||||||||||||||||||||||||||||||||
| args: List of arguments passed to the function. Should contain a Worker instance | ||||||||||||||||||||||||||||||||||||
| as the first argument when using lazy compute mode. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||
| bool: True if data collection is needed, False otherwise. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Note: | ||||||||||||||||||||||||||||||||||||
| Only checks worker attributes when dispatch_mode is a dict with 'collect_fn', | ||||||||||||||||||||||||||||||||||||
| the collect_fn is 'collect_lazy_compute_data_proto', and args[0] is a Worker. | ||||||||||||||||||||||||||||||||||||
| Otherwise, returns True. For the lazy compute case, checks the worker's | ||||||||||||||||||||||||||||||||||||
| data parallel rank for the mesh specified in collect_fn.args[0] to determine | ||||||||||||||||||||||||||||||||||||
| if this worker should collect data. | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
| from verl.single_controller.base.decorator import Dispatch | ||||||||||||||||||||||||||||||||||||
| from verl.single_controller.base.worker import Worker | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if dispatch_mode is None or isinstance(dispatch_mode, Dispatch): | ||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| assert "collect_fn" in dispatch_mode.keys(), "collect_fn should be in dispatch_mode." | ||||||||||||||||||||||||||||||||||||
| collect_fn_name = dispatch_mode["collect_fn"].func.__name__ | ||||||||||||||||||||||||||||||||||||
| if collect_fn_name != "collect_lazy_compute_data_proto" or len(args) < 1 or not isinstance(args[0], Worker): | ||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| collect_mesh_name = dispatch_mode["collect_fn"].args[0] | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+182
to
+186
|
||||||||||||||||||||||||||||||||||||
| collect_fn_name = dispatch_mode["collect_fn"].func.__name__ | |
| if collect_fn_name != "collect_lazy_compute_data_proto" or len(args) < 1 or not isinstance(args[0], Worker): | |
| return True | |
| collect_mesh_name = dispatch_mode["collect_fn"].args[0] | |
| collect_fn = dispatch_mode["collect_fn"] | |
| base_fn = getattr(collect_fn, "func", collect_fn) | |
| # Prefer an explicit attribute on the collect function, fall back to name-based check. | |
| is_lazy_collect = getattr( | |
| base_fn, | |
| "is_lazy_compute_data_proto", | |
| base_fn.__name__ == "collect_lazy_compute_data_proto", | |
| ) | |
| if not is_lazy_collect or len(args) < 1 or not isinstance(args[0], Worker): | |
| return True | |
| collect_mesh_name = collect_fn.args[0] |
Uh oh!
There was an error while loading. Please reload this page.