diff --git a/src/flyte/__init__.py b/src/flyte/__init__.py index 8cb31e442..cc43b765d 100644 --- a/src/flyte/__init__.py +++ b/src/flyte/__init__.py @@ -34,7 +34,7 @@ from ._resources import AMD_GPU, GPU, HABANA_GAUDI, TPU, Device, DeviceClass, Neuron, Resources from ._retry import Backoff, RetryStrategy from ._reusable_environment import ReusePolicy -from ._run import run, with_runcontext +from ._run import rerun, run, with_runcontext from ._run_python_script import run_python_script from ._secret import Secret, SecretRequest from ._serve import AppHandle, serve, with_servecontext @@ -106,6 +106,7 @@ def version() -> str: "logger", "map", "new_condition", + "rerun", "run", "run_python_script", "serve", diff --git a/src/flyte/_run.py b/src/flyte/_run.py index a2c19305b..02c02a6ea 100644 --- a/src/flyte/_run.py +++ b/src/flyte/_run.py @@ -74,6 +74,20 @@ def _get_main_run_mode() -> Mode | None: return _run_mode_var.get() +def _to_cache_lookup_scope(scope: CacheLookupScope | None = None): + """Map the SDK cache-lookup-scope literal onto its RunSpec enum value.""" + from flyteidl2.task import run_pb2 + + if scope == "global": + return run_pb2.CacheLookupScope.CACHE_LOOKUP_SCOPE_GLOBAL + elif scope == "project-domain": + return run_pb2.CacheLookupScope.CACHE_LOOKUP_SCOPE_PROJECT_DOMAIN + elif scope is None: + return run_pb2.CacheLookupScope.CACHE_LOOKUP_SCOPE_UNSPECIFIED + else: + raise ValueError(f"Unknown cache lookup scope: {scope}") + + class _Runner: def __init__( self, @@ -108,6 +122,7 @@ def __init__( cache_lookup_scope: CacheLookupScope = "global", preserve_original_types: bool | None = None, debug: bool = False, + recover: bool | str | None = False, _tracker: Any = None, _bundle_relative_paths: tuple[str, ...] | None = None, _bundle_from_dir: pathlib.Path | None = None, @@ -156,155 +171,151 @@ def __init__( preserve_original_types if preserve_original_types is not None else self._interactive_mode ) self._debug = debug + # Recover (reuse a prior run's succeeded actions). `True` = recover from the run being rerun; + # a run-name string = recover from that named run (the only form valid on a plain run()). + # Carried on RunSpec.recover; remote-only; gated in _apply_overrides until the flyteidl2 field + # + backend ship. See _resolve_recover_ref. + self._recover = recover - @requires_initialization - async def _run_remote(self, obj: TaskTemplate[P, R, F] | LazyEntity, *args: P.args, **kwargs: P.kwargs) -> Run: - from connectrpc.code import Code - from connectrpc.errors import ConnectError - from flyteidl2.common import identifier_pb2 - from flyteidl2.core import literals_pb2, security_pb2 - from flyteidl2.task import run_pb2 - from flyteidl2.workflow import run_definition_pb2, run_service_pb2 - from google.protobuf import wrappers_pb2 + def _resolve_recover_ref(self, rerun_run_name: str | None) -> str | None: + """Resolve `self._recover` to the reference run name to recover from (or None). + + `False`/`None` -> no recover. `True` -> the run being rerun (`rerun_run_name`); invalid on a + plain `run()` where there is no rerun target. A string -> that named run. + """ + r = self._recover + if not r: + return None + if r is True: + if rerun_run_name is None: + raise ValueError( + "recover=True is only valid with rerun() (it recovers from the run being rerun). " + "To recover a fresh run() from a prior run, pass its name: " + "with_runcontext(recover='').run(...)" + ) + return rerun_run_name + return r # explicit run-name string + + async def _build_task_spec_from_template(self, obj: TaskTemplate[P, R, F]) -> Tuple[Any, Any, str]: + """Build ``(task_spec, code_bundle, version)`` from a local ``TaskTemplate``. + Shared by ``_run_remote`` (local-task branch) and ``rerun`` with substitute code, so both + get identical fidelity (copy_files / dry_run / interactive_mode / include-files). Heavy + imports stay function-local to keep ``import flyte`` cheap. The built ``image_cache`` is + folded into the returned ``task_spec`` via the serialization context, so it is not returned. + """ import flyte.report - from flyte.remote import Run - from flyte.remote._task import LazyEntity, TaskDetails + from flyte._image import Image, resolve_code_bundle_layer from ._code_bundle import build_code_bundle, build_code_bundle_from_relative_paths, build_pkl_bundle from ._code_bundle._includes import collect_env_include_files - from ._deploy import build_images - from ._internal.runtime.convert import convert_from_native_to_inputs + from ._deploy import build_images, plan_deploy from ._internal.runtime.task_serde import translate_task_to_wire cfg = get_init_config() project = self._project or cfg.project domain = self._domain or cfg.domain - task: TaskTemplate[P, R, F] | TaskDetails - task_id = None - if isinstance(obj, (LazyEntity, TaskDetails)): - if isinstance(obj, LazyEntity): - task = await obj.fetch.aio() - else: - task = obj - task_spec = task.pb2.spec - # A fetched task is normally run by reference (task_id only). But if it was modified via - # `.override(...)`, the local spec no longer matches the registered task, so we must send - # the full spec instead. Setting task_id to None routes every downstream branch to the - # spec path. - task_id = None if task.overridden else task.pb2.task_id - inputs = await convert_from_native_to_inputs( - task.interface, *args, custom_context=self._custom_context, **kwargs - ) - version = task.pb2.task_id.version - code_bundle = None - elif isinstance(obj, TaskTemplate): - task = cast(TaskTemplate[P, R, F], obj) - if obj.parent_env is None: - raise ValueError("Task is not attached to an environment. Please attach the task to an environment") - - # Resolve any CodeBundleLayer layers before building images. - # Must cover the parent env AND all depends_on envs (recursively) - # so that _build_images can compute the content hash for every image. - parent_env = cast(Environment, obj.parent_env()) - from flyte._image import Image, resolve_code_bundle_layer - - from ._deploy import plan_deploy + if obj.parent_env is None: + raise ValueError("Task is not attached to an environment. Please attach the task to an environment") - plan_envs = list(plan_deploy(parent_env)[0].envs.values()) - for _env in plan_envs: - if isinstance(_env.image, Image): - _env.image = resolve_code_bundle_layer(_env.image, self._copy_files, pathlib.Path(cfg.root_dir)) + # Resolve any CodeBundleLayer layers before building images. + # Must cover the parent env AND all depends_on envs (recursively) + # so that _build_images can compute the content hash for every image. + parent_env = cast(Environment, obj.parent_env()) + plan_envs = list(plan_deploy(parent_env)[0].envs.values()) + for _env in plan_envs: + if isinstance(_env.image, Image): + _env.image = resolve_code_bundle_layer(_env.image, self._copy_files, pathlib.Path(cfg.root_dir)) - if not self._dry_run: - image_cache = await build_images.aio(parent_env) - else: - image_cache = None + if not self._dry_run: + image_cache = await build_images.aio(parent_env) + else: + image_cache = None - include_files = collect_env_include_files(plan_envs) - skip_cache = self._disable_run_cache + include_files = collect_env_include_files(plan_envs) + skip_cache = self._disable_run_cache - if self._interactive_mode: - if include_files: - raise ValueError( - "Environment.include is not supported in interactive/pkl runs. " - "Run from a file or remove `include` from the environment." - ) - code_bundle = await build_pkl_bundle( - obj, - upload_to_controlplane=not self._dry_run, - copy_bundle_to=self._copy_bundle_to, + if self._interactive_mode: + if include_files: + raise ValueError( + "Environment.include is not supported in interactive/pkl runs. " + "Run from a file or remove `include` from the environment." ) - elif self._copy_files == "custom": - if not self._bundle_relative_paths or not self._bundle_from_dir: - raise ValueError("copy_style='custom' requires _bundle_relative_paths and _bundle_from_dir") - merged_paths = tuple(self._bundle_relative_paths) + include_files - code_bundle = await build_code_bundle_from_relative_paths( - merged_paths, - from_dir=self._bundle_from_dir, - dryrun=self._dry_run, - copy_bundle_to=self._copy_bundle_to, - skip_cache=skip_cache, - ) - elif self._copy_files != "none": - code_bundle = await build_code_bundle( - from_dir=cfg.root_dir, - dryrun=self._dry_run, - copy_bundle_to=self._copy_bundle_to, - copy_style=self._copy_files, - additional_files=include_files, - skip_cache=skip_cache, - ) - elif include_files: - code_bundle = await build_code_bundle_from_relative_paths( - include_files, - from_dir=pathlib.Path(cfg.root_dir), - dryrun=self._dry_run, - copy_bundle_to=self._copy_bundle_to, - skip_cache=skip_cache, - ) - else: - code_bundle = None - - version = self._version or ( - code_bundle.computed_version if code_bundle and code_bundle.computed_version else None + code_bundle = await build_pkl_bundle( + obj, + upload_to_controlplane=not self._dry_run, + copy_bundle_to=self._copy_bundle_to, ) - if not version: - raise ValueError("Version is required when running a task") - s_ctx = SerializationContext( - code_bundle=code_bundle, - version=version, - image_cache=image_cache, - root_dir=cfg.root_dir, + elif self._copy_files == "custom": + if not self._bundle_relative_paths or not self._bundle_from_dir: + raise ValueError("copy_style='custom' requires _bundle_relative_paths and _bundle_from_dir") + merged_paths = tuple(self._bundle_relative_paths) + include_files + code_bundle = await build_code_bundle_from_relative_paths( + merged_paths, + from_dir=self._bundle_from_dir, + dryrun=self._dry_run, + copy_bundle_to=self._copy_bundle_to, + skip_cache=skip_cache, ) - action = ActionID( - name="{{.actionName}}", run_name="{{.runName}}", project=project, domain=domain, org=cfg.org + elif self._copy_files != "none": + code_bundle = await build_code_bundle( + from_dir=cfg.root_dir, + dryrun=self._dry_run, + copy_bundle_to=self._copy_bundle_to, + copy_style=self._copy_files, + additional_files=include_files, + skip_cache=skip_cache, ) - tctx = TaskContext( - action=action, - code_bundle=code_bundle, - output_path="", - version=version or "na", - raw_data_path=RawDataPath(path=""), - compiled_image_cache=image_cache, - run_base_dir="", - report=flyte.report.Report(name=action.name), - custom_context=self._custom_context, - ) - task_spec = translate_task_to_wire(obj, s_ctx, default_inputs=None, task_context=tctx) - inputs = await convert_from_native_to_inputs( - obj.native_interface, *args, custom_context=self._custom_context, **kwargs + elif include_files: + code_bundle = await build_code_bundle_from_relative_paths( + include_files, + from_dir=pathlib.Path(cfg.root_dir), + dryrun=self._dry_run, + copy_bundle_to=self._copy_bundle_to, + skip_cache=skip_cache, ) else: - raise ValueError(f"Not supported Task Type: {type(task)}") + code_bundle = None - env = self._env_vars or {} + version = self._version or ( + code_bundle.computed_version if code_bundle and code_bundle.computed_version else None + ) + if not version: + raise ValueError("Version is required when running a task") + s_ctx = SerializationContext( + code_bundle=code_bundle, + version=version, + image_cache=image_cache, + root_dir=cfg.root_dir, + ) + action = ActionID(name="{{.actionName}}", run_name="{{.runName}}", project=project, domain=domain, org=cfg.org) + tctx = TaskContext( + action=action, + code_bundle=code_bundle, + output_path="", + version=version or "na", + raw_data_path=RawDataPath(path=""), + compiled_image_cache=image_cache, + run_base_dir="", + report=flyte.report.Report(name=action.name), + custom_context=self._custom_context, + ) + task_spec = translate_task_to_wire(obj, s_ctx, default_inputs=None, task_context=tctx) + return task_spec, code_bundle, version + + def _build_env_dict(self) -> Dict[str, str]: + """Assemble the runtime env dict from runner config. + + User-supplied ``env_vars`` plus the always-injected LOG_* / debug / rust-controller / + sys-path keys. Shared by the fresh-build and inherited (rerun) RunSpec paths so debug's + ssh-env injection and the log settings apply identically. Returns a fresh dict (never + mutates ``self._env_vars``). + """ + cfg = get_init_config() + env: Dict[str, str] = dict(self._env_vars or {}) if env.get("LOG_LEVEL") is None: - if self._log_level: - env["LOG_LEVEL"] = str(self._log_level) - else: - env["LOG_LEVEL"] = str(logger.getEffectiveLevel()) + env["LOG_LEVEL"] = str(self._log_level) if self._log_level else str(logger.getEffectiveLevel()) env["LOG_FORMAT"] = self._log_format if self._user_log_level is not None: env["USER_LOG_LEVEL"] = str(self._user_log_level) @@ -320,16 +331,245 @@ async def _run_remote(self, obj: TaskTemplate[P, R, F] | LazyEntity, *args: P.ar # These paths will be appended to sys.path at runtime. if cfg.sync_local_sys_paths: root_dir_abs = pathlib.Path(cfg.root_dir).resolve() - added_paths = [ + env[FLYTE_SYS_PATH] = ":".join( f"./{pathlib.Path(p).relative_to(root_dir_abs)}" for p in sys.path if pathlib.Path(p).is_relative_to(root_dir_abs) - ] - env[FLYTE_SYS_PATH] = ":".join(added_paths) + ) # TODO: Remove once the actions service is the default and this env var is no longer needed. if os.getenv("_U_USE_ACTIONS") == "1": env["_U_USE_ACTIONS"] = "1" + return env + + def _resolve_run_target(self, project: str | None, domain: str | None, org: str | None): + """Resolve the create-run target: a RunIdentifier when a name is set, else a ProjectIdentifier.""" + from flyteidl2.common import identifier_pb2 + + if self._name: + return ( + identifier_pb2.RunIdentifier(project=project, domain=domain, org=org, name=self._name or None), + None, + ) + return None, identifier_pb2.ProjectIdentifier(name=project, domain=domain, organization=org) + + def _apply_overrides(self, base: Any, *, task: Any = None, recover_ref: str | None = None) -> Any: + """Build the ``RunSpec`` for ``create_run``. + + ``base is None`` -> a fresh spec from runner config (the run / recover path). + ``base`` set -> deep-copy a prior run's ``RunSpec`` and merge runner overrides by key + (the rerun path: env merge + explicitly-set field overrides). Pure proto assembly, no I/O. + This is the single place runner config maps onto a ``RunSpec``. ``recover_ref`` is the already- + resolved reference run to recover from (see ``_resolve_recover_ref``), or None. + """ + from flyteidl2.core import literals_pb2, security_pb2 + from flyteidl2.task import run_pb2 + from google.protobuf import wrappers_pb2 + + env = self._build_env_dict() + if base is not None: + # Inherit the prior run's env as the floor; runner overrides win. + merged = {kv.key: kv.value for kv in base.envs.values} + merged.update(env) + env = merged + + kv_pairs: List[literals_pb2.KeyValuePair] = [] + for k, v in env.items(): + if not isinstance(v, str): + raise ValueError(f"Environment variable {k} must be a string, got {type(v)}") + kv_pairs.append(literals_pb2.KeyValuePair(key=k, value=v)) + env_kv = run_pb2.Envs(values=kv_pairs) + + notification_rule_name = None + notification_rules = None + if self._notifications: + from flyte._internal.runtime.notifications_serde import resolve_notification_settings + + notification_rule_name, notification_rules = resolve_notification_settings(self._notifications) + + if base is None: + raw_data_storage = ( + run_pb2.RawDataStorage(raw_data_prefix=self._raw_data_path) if self._raw_data_path else None + ) + security_context = ( + security_pb2.SecurityContext(run_as=security_pb2.Identity(k8s_service_account=self._service_account)) + if self._service_account + else None + ) + run_spec = run_pb2.RunSpec( + overwrite_cache=self._overwrite_cache, + interruptible=wrappers_pb2.BoolValue(value=self._interruptible) + if self._interruptible is not None + else None, + annotations=run_pb2.Annotations(values=self._annotations), + labels=run_pb2.Labels(values=self._labels), + envs=env_kv, + cluster=self._queue or (task.queue if task is not None else ""), + max_action_concurrency=self._max_action_concurrency or 0, + raw_data_storage=raw_data_storage, + security_context=security_context, + cache_config=run_pb2.CacheConfig( + overwrite_cache=self._overwrite_cache, + cache_lookup_scope=_to_cache_lookup_scope(self._cache_lookup_scope) + if self._cache_lookup_scope + else None, + ), + notification_rule_name=notification_rule_name, + notification_rules=notification_rules, + ) + else: + # Deep-copy the fetched spec (it is shared/cached on the RunDetails); never mutate in place. + run_spec = run_pb2.RunSpec() + run_spec.CopyFrom(base) + run_spec.envs.CopyFrom(env_kv) + if self._interruptible is not None: + run_spec.interruptible.CopyFrom(wrappers_pb2.BoolValue(value=self._interruptible)) + if self._overwrite_cache: + run_spec.overwrite_cache = True + run_spec.cache_config.overwrite_cache = True + if self._labels: + for k, v in self._labels.items(): + run_spec.labels.values[k] = v + if self._annotations: + for k, v in self._annotations.items(): + run_spec.annotations.values[k] = v + if self._cache_lookup_scope: + run_spec.cache_config.cache_lookup_scope = _to_cache_lookup_scope(self._cache_lookup_scope) + if self._max_action_concurrency: + run_spec.max_action_concurrency = self._max_action_concurrency + if self._queue: + # TODO: cluster is being renamed to queue + run_spec.cluster = self._queue + if self._service_account: + run_spec.security_context.CopyFrom( + security_pb2.SecurityContext( + run_as=security_pb2.Identity(k8s_service_account=self._service_account) + ) + ) + if notification_rule_name: + run_spec.notification_rule_name = notification_rule_name + if notification_rules: + run_spec.notification_rules.CopyFrom(notification_rules) + + # recover: gated until flyteidl2 ships RunSpec.recover (+ backend support). One-line set then. + if recover_ref: + if "recover" not in run_pb2.RunSpec.DESCRIPTOR.fields_by_name: + raise NotImplementedError( + "recover is not yet supported by this backend " + "(RunSpec.recover is unavailable in this flyteidl2 build)." + ) + from flyteidl2.common import identifier_pb2 + + run_spec.recover.CopyFrom(run_pb2.Recover(run_id=identifier_pb2.RunIdentifier(name=recover_ref))) + + return run_spec + + async def _submit_remote( + self, *, task_spec: Any, task_id: Any, proto_inputs: Any, run_spec: Any, run_id: Any, project_id: Any + ) -> Run: + """Upload inputs and create the run. The single network call site for remote submission. + + Consumes an already-built ``run_spec`` (see ``_apply_overrides``), raw proto ``inputs`` + (``flyteidl2.task.Inputs``), and a task by reference (``task_id``) or by value + (``task_spec``); shared by ``_run_remote`` and ``rerun``. + """ + from connectrpc.code import Code + from connectrpc.errors import ConnectError + from flyteidl2.dataproxy import dataproxy_service_pb2 + from flyteidl2.workflow import run_service_pb2 + + import flyte.errors + from flyte.remote import Run + + try: + upload_req = dataproxy_service_pb2.UploadInputsRequest(inputs=proto_inputs) + # Reference an already-registered task by id; otherwise upload the full spec. + if task_id is not None: + upload_req.task_id.CopyFrom(task_id) + else: + upload_req.task_spec.CopyFrom(task_spec) + if run_id is not None: + upload_req.run_id.CopyFrom(run_id) + else: + upload_req.project_id.CopyFrom(project_id) + + upload_resp = await get_client().dataproxy_service.upload_inputs(upload_req) + + create_req = run_service_pb2.CreateRunRequest( + run_id=run_id, + project_id=project_id, + offloaded_input_data=upload_resp.offloaded_input_data, + run_spec=run_spec, + ) + # Reference an already-registered task by id; otherwise send the full spec. + if task_id is not None: + create_req.task_id.CopyFrom(task_id) + else: + create_req.task_spec.CopyFrom(task_spec) + + resp = await get_client().run_service.create_run(create_req) + return Run(pb2=resp.run, _preserve_original_types=self._preserve_original_types) + except ConnectError as e: + if e.code == Code.UNAVAILABLE: + raise flyte.errors.RuntimeSystemError( + "SystemUnavailableError", + "Flyte system is currently unavailable. check your configuration, or the service status.", + ) from e + elif e.code == Code.INVALID_ARGUMENT: + raise flyte.errors.RuntimeUserError("InvalidArgumentError", e.message) + elif e.code == Code.ALREADY_EXISTS: + # TODO maybe this should be a pass and return existing run? + raise flyte.errors.RuntimeUserError( + "RunAlreadyExistsError", + f"A run with the name '{self._name}' already exists. Please choose a different name.", + ) + else: + raise flyte.errors.RuntimeSystemError( + "RunCreationError", + f"Failed to create run: {e.message}", + ) from e + + @requires_initialization + async def _run_remote(self, obj: TaskTemplate[P, R, F] | LazyEntity, *args: P.args, **kwargs: P.kwargs) -> Run: + from flyteidl2.common import identifier_pb2 + from flyteidl2.workflow import run_definition_pb2 + + import flyte.errors + from flyte.remote import Run + from flyte.remote._task import LazyEntity, TaskDetails + + from ._internal.runtime.convert import convert_from_native_to_inputs + + cfg = get_init_config() + project = self._project or cfg.project + domain = self._domain or cfg.domain + + task: TaskTemplate[P, R, F] | TaskDetails + task_id = None + if isinstance(obj, (LazyEntity, TaskDetails)): + if isinstance(obj, LazyEntity): + task = await obj.fetch.aio() + else: + task = obj + task_spec = task.pb2.spec + # A fetched task is normally run by reference (task_id only). But if it was modified via + # `.override(...)`, the local spec no longer matches the registered task, so we must send + # the full spec instead. Setting task_id to None routes every downstream branch to the + # spec path. + task_id = None if task.overridden else task.pb2.task_id + inputs = await convert_from_native_to_inputs( + task.interface, *args, custom_context=self._custom_context, **kwargs + ) + version = task.pb2.task_id.version + code_bundle = None + elif isinstance(obj, TaskTemplate): + task = cast(TaskTemplate[P, R, F], obj) + task_spec, code_bundle, version = await self._build_task_spec_from_template(obj) + inputs = await convert_from_native_to_inputs( + obj.native_interface, *args, custom_context=self._custom_context, **kwargs + ) + else: + raise ValueError(f"Not supported Task Type: {type(task)}") if not self._dry_run: if get_client() is None: @@ -341,21 +581,7 @@ async def _run_remote(self, obj: TaskTemplate[P, R, F] | LazyEntity, *args: P.ar "Call flyte.init() with a valid endpoint/api-key before using this function" "or Call flyte.init_from_config() with a valid path to the config file", ) - run_id = None - project_id = None - if self._name: - run_id = identifier_pb2.RunIdentifier( - project=project, - domain=domain, - org=cfg.org, - name=self._name or None, - ) - else: - project_id = identifier_pb2.ProjectIdentifier( - name=project, - domain=domain, - organization=cfg.org, - ) + run_id, project_id = self._resolve_run_target(project, domain, cfg.org) # Fill in task id inside the task template if it's not provided. # Maybe this should be done here, or the backend. # Only needed for locally-defined tasks; a fetched task sent by reference (task_id set) @@ -371,111 +597,16 @@ async def _run_remote(self, obj: TaskTemplate[P, R, F] | LazyEntity, *args: P.ar if task_spec.task_template.id.version == "": task_spec.task_template.id.version = version - kv_pairs: List[literals_pb2.KeyValuePair] = [] - for k, v in env.items(): - if not isinstance(v, str): - raise ValueError(f"Environment variable {k} must be a string, got {type(v)}") - kv_pairs.append(literals_pb2.KeyValuePair(key=k, value=v)) - - env_kv = run_pb2.Envs(values=kv_pairs) - annotations = run_pb2.Annotations(values=self._annotations) - labels = run_pb2.Labels(values=self._labels) - raw_data_storage = ( - run_pb2.RawDataStorage(raw_data_prefix=self._raw_data_path) if self._raw_data_path else None - ) - security_context = ( - security_pb2.SecurityContext(run_as=security_pb2.Identity(k8s_service_account=self._service_account)) - if self._service_account - else None + run_spec = self._apply_overrides(None, task=task, recover_ref=self._resolve_recover_ref(None)) + return await self._submit_remote( + task_spec=task_spec, + task_id=task_id, + proto_inputs=inputs.proto_inputs, + run_spec=run_spec, + run_id=run_id, + project_id=project_id, ) - def _to_cache_lookup_scope(scope: CacheLookupScope | None = None) -> run_pb2.CacheLookupScope: - if scope == "global": - return run_pb2.CacheLookupScope.CACHE_LOOKUP_SCOPE_GLOBAL - elif scope == "project-domain": - return run_pb2.CacheLookupScope.CACHE_LOOKUP_SCOPE_PROJECT_DOMAIN - elif scope is None: - return run_pb2.CacheLookupScope.CACHE_LOOKUP_SCOPE_UNSPECIFIED - else: - raise ValueError(f"Unknown cache lookup scope: {scope}") - - notification_rule_name = None - notification_rules = None - if self._notifications: - from flyte._internal.runtime.notifications_serde import resolve_notification_settings - - notification_rule_name, notification_rules = resolve_notification_settings(self._notifications) - - try: - from flyteidl2.dataproxy import dataproxy_service_pb2 - - upload_req = dataproxy_service_pb2.UploadInputsRequest(inputs=inputs.proto_inputs) - # Reference an already-registered task by id; otherwise upload the full spec. - if task_id is not None: - upload_req.task_id.CopyFrom(task_id) - else: - upload_req.task_spec.CopyFrom(task_spec) - if run_id is not None: - upload_req.run_id.CopyFrom(run_id) - else: - upload_req.project_id.CopyFrom(project_id) - - upload_resp = await get_client().dataproxy_service.upload_inputs(upload_req) - - create_req = run_service_pb2.CreateRunRequest( - run_id=run_id, - project_id=project_id, - offloaded_input_data=upload_resp.offloaded_input_data, - run_spec=run_pb2.RunSpec( - overwrite_cache=self._overwrite_cache, - interruptible=wrappers_pb2.BoolValue(value=self._interruptible) - if self._interruptible is not None - else None, - annotations=annotations, - labels=labels, - envs=env_kv, - cluster=self._queue or task.queue, - max_action_concurrency=self._max_action_concurrency or 0, - raw_data_storage=raw_data_storage, - security_context=security_context, - cache_config=run_pb2.CacheConfig( - overwrite_cache=self._overwrite_cache, - cache_lookup_scope=_to_cache_lookup_scope(self._cache_lookup_scope) - if self._cache_lookup_scope - else None, - ), - notification_rule_name=notification_rule_name, - notification_rules=notification_rules, - ), - ) - # Reference an already-registered task by id; otherwise send the full spec. - if task_id is not None: - create_req.task_id.CopyFrom(task_id) - else: - create_req.task_spec.CopyFrom(task_spec) - - resp = await get_client().run_service.create_run(create_req) - return Run(pb2=resp.run, _preserve_original_types=self._preserve_original_types) - except ConnectError as e: - if e.code == Code.UNAVAILABLE: - raise flyte.errors.RuntimeSystemError( - "SystemUnavailableError", - "Flyte system is currently unavailable. check your configuration, or the service status.", - ) from e - elif e.code == Code.INVALID_ARGUMENT: - raise flyte.errors.RuntimeUserError("InvalidArgumentError", e.message) - elif e.code == Code.ALREADY_EXISTS: - # TODO maybe this should be a pass and return existing run? - raise flyte.errors.RuntimeUserError( - "RunAlreadyExistsError", - f"A run with the name '{self._name}' already exists. Please choose a different name.", - ) - else: - raise flyte.errors.RuntimeSystemError( - "RunCreationError", - f"Failed to create run: {e.message}", - ) from e - class DryRun(Run): def __init__(self, _task_spec, _inputs, _code_bundle): super().__init__( @@ -810,6 +941,11 @@ async def example_task(x: int, y: str) -> str: if not isinstance(task, TaskTemplate) and not isinstance(task, (LazyEntity, TaskDetails)): raise TypeError(f"On Flyte tasks can be run, not generic functions or methods '{type(task)}'.") + # recover is an actions-service / RunSpec concern — remote-only. Fail fast rather than silently + # ignoring it in local/hybrid mode. + if self._recover and self._mode != "remote": + raise ValueError("recover is only supported in remote mode") + # Set the run mode in the context variable so that offloaded types (files, directories, dataframes) # can check the mode for controlling auto-uploading behavior (only enabled in remote mode). _run_mode_var.set(self._mode) @@ -829,6 +965,103 @@ async def example_task(x: int, y: str) -> str: finally: _run_mode_var.set(None) + @syncify # type: ignore[arg-type] + async def rerun( + self, + run_name: str, + action_name: str = "a0", + task_template: TaskTemplate[P, R, F] | None = None, + inputs: Dict[str, Any] | None = None, + ) -> Run: + """Re-run a prior run, returning a new `Run`. + + - `rerun("r1")` re-runs with the prior run's exact inputs, fetching its task spec from the + platform (no local code needed). + - `rerun("r1", inputs={"x": 2})` changes input parameters (converted against the fetched + task interface). + - `rerun("r1", task_template=fixed)` substitutes new code, validated against the original + inputs (or `inputs` if given). + + The prior run's `RunSpec` is inherited and merged with this context's overrides + (`with_runcontext(env_vars=..., interruptible=..., recover=...)` etc.), so debug/recover + compose with rerun. Currently remote-only. + + :param run_name: Name of the prior run to re-run. + :param action_name: Action within the prior run to source the task + inputs from (default `a0`). + :param task_template: Optional task to substitute for the prior run's code. + :param inputs: Optional native kwargs to change input parameters; omit to reuse prior inputs. + :return: the new Run. + """ + if self._mode != "remote": + raise NotImplementedError(f"rerun is only supported in remote mode, got mode={self._mode!r}") + + from flyteidl2.dataproxy import dataproxy_service_pb2 + + from flyte.remote._action import ActionDetails + from flyte.remote._run import RunDetails + + from ._internal.runtime.convert import convert_from_native_to_inputs + + cfg = get_init_config() + project = self._project or cfg.project + domain = self._domain or cfg.domain + + run_details = await RunDetails.get.aio(name=run_name) + base_run_spec = run_details.pb2.run_spec + if action_name == "a0": + action_details = run_details.action_details + else: + action_details = await ActionDetails.get.aio(run_name=run_name, name=action_name) + + # Task source: substitute a freshly-built local spec, or reuse the prior action's spec. + if task_template is not None: + task_spec, _code_bundle, version = await self._build_task_spec_from_template(task_template) + else: + if not action_details.pb2.HasField("task"): + raise ValueError(f"Action {run_name}/{action_name} has no task spec to rerun.") + task_spec = action_details.pb2.task + version = task_spec.task_template.id.version + + # Inputs: reuse the prior raw proto inputs, or convert new native kwargs against the interface. + if inputs: + if task_template is not None: + iface = task_template.native_interface + else: + from flyte.types._interface import guess_interface + + iface = guess_interface(task_spec.task_template.interface) + converted = await convert_from_native_to_inputs(iface, custom_context=self._custom_context, **inputs) + proto_inputs = converted.proto_inputs + else: + resp = await get_client().dataproxy_service.get_action_data( + request=dataproxy_service_pb2.GetActionDataRequest(action_id=action_details.pb2.id) + ) + proto_inputs = resp.inputs + + run_id, project_id = self._resolve_run_target(project, domain, cfg.org) + + # A freshly-built substitute spec may carry empty ids; fill them like _run_remote does. + if task_template is not None: + tt_id = task_spec.task_template.id + if tt_id.project == "": + tt_id.project = project or "" + if tt_id.domain == "": + tt_id.domain = domain or "" + if tt_id.org == "": + tt_id.org = cfg.org or "" + if tt_id.version == "": + tt_id.version = version + + run_spec = self._apply_overrides(base_run_spec, recover_ref=self._resolve_recover_ref(run_name)) + return await self._submit_remote( + task_spec=task_spec, + task_id=None, + proto_inputs=proto_inputs, + run_spec=run_spec, + run_id=run_id, + project_id=project_id, + ) + def with_runcontext( mode: Mode | None = None, @@ -863,6 +1096,7 @@ def with_runcontext( cache_lookup_scope: CacheLookupScope = "global", preserve_original_types: bool = False, debug: bool = False, + recover: bool | str | None = False, _tracker: Any = None, ) -> _Runner: """ @@ -943,6 +1177,11 @@ async def example_task(x: int, y: str) -> str: explicitly by this parameter. :param debug: Optional If true, the task will be run as a VSCode debug task, starting a code-server in the container so users can connect via the UI to interactively debug/run the task. + :param recover: Recover (reuse a prior run's succeeded actions, re-running only what failed or + changed). ``True`` recovers from the run being rerun — only valid with ``.rerun(...)``; a + run-name string recovers from that named run and is the only form valid on ``.run(...)``. + Remote-only. Not yet supported by the backend (raises NotImplementedError at submit until + flyteidl2 RunSpec.recover ships). :param _tracker: This is an internal only parameter used by the CLI to render the TUI. :return: runner @@ -992,6 +1231,7 @@ async def example_task(x: int, y: str) -> str: cache_lookup_scope=cache_lookup_scope, preserve_original_types=preserve_original_types, debug=debug, + recover=recover, _tracker=_tracker, ) @@ -1007,3 +1247,25 @@ async def run(task: TaskTemplate[P, R, F], *args: P.args, **kwargs: P.kwargs) -> """ # using syncer causes problems return await _Runner().run.aio(task, *args, **kwargs) # type: ignore + + +@syncify +async def rerun( + run_name: str, + action_name: str = "a0", + task_template: TaskTemplate[P, R, F] | None = None, + **inputs: Any, +) -> Run: + """Re-run a prior run, returning a new `Run`. + + `rerun("r1")` reuses the prior run's exact inputs (fetching its code from the platform); + pass keyword inputs to change parameters (`rerun("r1", x=2)`), or `task_template=` to substitute + code. Use `with_runcontext(...).rerun(...)` to apply run-context overrides (env_vars, recover, …). + + :param run_name: Name of the prior run to re-run. + :param action_name: Action within the prior run to source the task + inputs from (default `a0`). + :param task_template: Optional task to substitute for the prior run's code. + :param inputs: Optional native keyword inputs to change parameters; omit to reuse prior inputs. + :return: the new Run. + """ + return await _Runner().rerun.aio(run_name, action_name, task_template, inputs=inputs or None) diff --git a/src/flyte/cli/_rerun.py b/src/flyte/cli/_rerun.py new file mode 100644 index 000000000..875fba9f5 --- /dev/null +++ b/src/flyte/cli/_rerun.py @@ -0,0 +1,119 @@ +"""``flyte rerun `` — re-run an existing run with its own code + exact inputs. + +Counterpart to ``flyte run``: where ``run`` launches *local* code (and can recover from a prior +run via ``--recover-from``), ``rerun`` re-launches an *existing* run — fetching its task + inputs +from the platform, no local code needed. ``--recover`` reuses that run's succeeded actions. To +re-run with *new* local code (reusing the prior run's inputs), use ``flyte run +--rerun-from ``. + +v1 reuses the prior run's exact inputs; changing inputs from the CLI is a follow-up +(`flyte.rerun(run, x=2)` covers it programmatically today). +""" + +from __future__ import annotations + +import asyncio +from typing import Dict, Optional, Tuple + +import rich_click as click + +from . import _common as common + + +def _parse_kv(items: Tuple[str, ...], flag: str) -> Optional[Dict[str, str]]: + """Parse repeated ``KEY=VALUE`` flag values into a dict (None if none given).""" + if not items: + return None + parsed: Dict[str, str] = {} + for item in items: + if "=" not in item: + raise click.BadParameter(f"Invalid {flag} value {item!r}: expected KEY=VALUE.") + key, value = item.split("=", 1) + if not key: + raise click.BadParameter(f"Invalid {flag} value {item!r}: key must not be empty.") + parsed[key] = value + return parsed + + +@click.command("rerun", cls=click.RichCommand) +@click.argument("run_name", required=True) +@click.option("-p", "--project", default=None, help="Project for the new run (defaults to config).") +@click.option("-d", "--domain", default=None, help="Domain for the new run (defaults to config).") +@click.option("--name", default=None, help="Name for the new run (a random name is generated if unset).") +@click.option("-e", "--env", "env", multiple=True, help="Env var KEY=VALUE for the new run. Repeatable.") +@click.option("--label", "label", multiple=True, help="Label KEY=VALUE for the new run. Repeatable.") +@click.option("--follow", "-f", is_flag=True, default=False, help="Stream the parent action logs after launch.") +@click.option( + "--recover", + is_flag=True, + default=False, + help="Recover from this run: reuse its succeeded actions, re-run only what failed or changed.", +) +@click.pass_context +def rerun( + ctx: click.Context, + run_name: str, + project: Optional[str], + domain: Optional[str], + name: Optional[str], + env: Tuple[str, ...], + label: Tuple[str, ...], + follow: bool, + recover: bool, +) -> None: + """Re-run an existing run RUN_NAME with its original code and inputs. + + Fetches the prior run's task + inputs from the platform (no local code needed) and launches a + new run that returns the same way ``flyte run`` does. ``--recover`` reuses the prior run's + succeeded actions (re-running only what failed or changed). To re-run with *new* local code + (reusing the prior run's inputs), use ``flyte run --rerun-from ``. + + Examples: + + $ flyte rerun ul56wcvgqrb9vzhzz5l2 + $ flyte rerun ul56wcvgqrb9vzhzz5l2 --name retry-1 --follow + $ flyte rerun ul56wcvgqrb9vzhzz5l2 --recover + """ + config = common.initialize_config(ctx, project=project, domain=domain) + asyncio.run(_execute(run_name, name, env, label, follow, recover, config)) + + +async def _execute( + run_name: str, + name: Optional[str], + env: Tuple[str, ...], + label: Tuple[str, ...], + follow: bool, + recover: bool, + config: common.CLIConfig, +) -> None: + import flyte + from flyte._status import status + + console = common.get_console() + try: + status.step(f"Re-running {run_name}...") + runner = flyte.with_runcontext( + mode="remote", + name=name, + env_vars=_parse_kv(env, "--env"), + labels=_parse_kv(label, "--label"), + recover=recover, + ) + result = await runner.rerun.aio(run_name) + except Exception as e: + console.print(f"[red]✕ Re-run failed:[/red] {e}") + return + + if config.output_format in ("json", "table-simple"): + run_info = f"Created Run: {result.name}\nURL: {result.url}" + else: + run_info = ( + f"[green bold]Created Run: {result.name}[/green bold]\n" + f"➡️ [blue bold][link={result.url}]{result.url}[/link][/blue bold]" + ) + console.print(common.get_panel("Rerun", run_info, config.output_format)) + + if follow: + status.step("Waiting for log stream...") + await result.show_logs.aio(max_lines=30, show_ts=True, raw=False) diff --git a/src/flyte/cli/_run.py b/src/flyte/cli/_run.py index 5f062b2ee..62a4ea1ec 100644 --- a/src/flyte/cli/_run.py +++ b/src/flyte/cli/_run.py @@ -289,6 +289,30 @@ class RunArguments: ) }, ) + recover_from: str | None = field( + default=None, + metadata={ + "click.option": click.Option( + ["--recover-from"], + type=str, + default=None, + help="Recover a fresh run from a prior run: reuse its succeeded actions and re-run " + "only what failed or changed. Remote-only.", + ) + }, + ) + rerun_from: str | None = field( + default=None, + metadata={ + "click.option": click.Option( + ["--rerun-from"], + type=str, + default=None, + help="Re-run an existing run with THIS local code, reusing that run's inputs " + "(no per-task input flags are needed). Remote-only.", + ) + }, + ) @classmethod def from_dict(cls, d: Dict[str, Any]) -> RunArguments: @@ -401,8 +425,13 @@ async def _execute_and_render(self, ctx: click.Context, config: common.CLIConfig env_vars=self.run_args.parsed_env_vars(), max_action_concurrency=self.run_args.max_action_concurrency, labels=self.run_args.parsed_labels(), + recover=self.run_args.recover_from, ) - result = await execution_context.run.aio(self.obj, **ctx.params) + if self.run_args.rerun_from: + # Re-run a prior run with THIS local code, reusing the prior run's inputs. + result = await execution_context.rerun.aio(self.run_args.rerun_from, task_template=self.obj) + else: + result = await execution_context.run.aio(self.obj, **ctx.params) except Exception as e: if isinstance(e, RuntimeSystemError): capture_exception(e) @@ -480,6 +509,8 @@ def invoke(self, ctx: click.Context): tuple(self.run_args.image) or None, not self.run_args.no_sync_local_sys_paths, ) + if self.run_args.rerun_from and self.run_args.local: + raise click.UsageError("--rerun-from requires remote mode (it cannot be combined with --local)") self._validate_required_params(ctx) if self.run_args.tui: if not self.run_args.local: @@ -491,6 +522,11 @@ def invoke(self, ctx: click.Context): def get_params(self, ctx: click.Context) -> List[click.Parameter]: # Note this function may be called multiple times by click. + # With --rerun-from, inputs come from the prior run, so don't expose (or require) per-task + # input options. (Overriding specific inputs alongside --rerun-from is a follow-up.) + if self.run_args.rerun_from: + return super().get_params(ctx) + task = self.obj from .._internal.runtime.types_serde import transform_native_to_typed_interface @@ -626,6 +662,7 @@ async def _execute_and_render(self, ctx: click.Context, config: common.CLIConfig env_vars=self.run_args.parsed_env_vars(), max_action_concurrency=self.run_args.max_action_concurrency, labels=self.run_args.parsed_labels(), + recover=self.run_args.recover_from, ) result = await execution_context.run.aio(task, **ctx.params) except Exception as e: diff --git a/src/flyte/cli/main.py b/src/flyte/cli/main.py index 94853be7e..502d6d1b4 100644 --- a/src/flyte/cli/main.py +++ b/src/flyte/cli/main.py @@ -16,6 +16,7 @@ from ._get import get from ._plugins import discover_and_register_plugins from ._prefetch import prefetch +from ._rerun import rerun from ._run import run from ._serve import serve from ._signal import signal @@ -31,7 +32,7 @@ "flyte": [ { "name": "Run and stop tasks", - "commands": ["run", "abort", "signal"], + "commands": ["run", "rerun", "abort", "signal"], }, { "name": "Serve Apps", @@ -283,6 +284,7 @@ def main( main.add_command(run) +main.add_command(rerun) main.add_command(deploy) main.add_command(get) # type: ignore main.add_command(create) # type: ignore diff --git a/tests/cli/test_rerun.py b/tests/cli/test_rerun.py new file mode 100644 index 000000000..2c8a7de8e --- /dev/null +++ b/tests/cli/test_rerun.py @@ -0,0 +1,63 @@ +"""Tests for the `flyte rerun ` CLI command.""" + +from unittest import mock + +from click.testing import CliRunner +from mock.mock import AsyncMock + +from flyte.cli._rerun import _parse_kv, rerun +from flyte.cli.main import main + + +def test_rerun_registered_on_main(): + assert "rerun" in main.commands + + +def test_rerun_has_recover_flag(): + opts = {o for p in rerun.params for o in p.opts} + assert "--recover" in opts + # Takes the run name as a positional argument. + assert any(p.name == "run_name" for p in rerun.params) + + +def test_parse_kv(): + assert _parse_kv((), "--env") is None + assert _parse_kv(("A=1", "B=2"), "--env") == {"A": "1", "B": "2"} + + +def test_rerun_delegates_to_runner_rerun(): + """`flyte rerun --name n -e K=V` builds the run context and calls runner.rerun(run).""" + runner_obj = mock.MagicMock() + runner_obj.rerun = mock.MagicMock(return_value=mock.MagicMock()) + runner_obj.rerun.aio = AsyncMock(return_value=mock.MagicMock(name="new", url="http://x")) + + with ( + mock.patch("flyte.cli._common.initialize_config") as init_cfg, + mock.patch("flyte.with_runcontext", return_value=runner_obj) as wrc, + ): + init_cfg.return_value = mock.MagicMock(output_format="table") + result = CliRunner().invoke(rerun, ["my-run", "--name", "n", "-e", "K=V"]) + + assert result.exit_code == 0, result.output + # recover flag default False, env parsed, name forwarded. + _, kwargs = wrc.call_args + assert kwargs["recover"] is False + assert kwargs["name"] == "n" + assert kwargs["env_vars"] == {"K": "V"} + assert kwargs["mode"] == "remote" + runner_obj.rerun.aio.assert_awaited_once_with("my-run") + + +def test_rerun_recover_flag_passed_through(): + runner_obj = mock.MagicMock() + runner_obj.rerun.aio = AsyncMock(return_value=mock.MagicMock(name="new", url="http://x")) + + with ( + mock.patch("flyte.cli._common.initialize_config") as init_cfg, + mock.patch("flyte.with_runcontext", return_value=runner_obj) as wrc, + ): + init_cfg.return_value = mock.MagicMock(output_format="table") + result = CliRunner().invoke(rerun, ["my-run", "--recover"]) + + assert result.exit_code == 0, result.output + assert wrc.call_args.kwargs["recover"] is True diff --git a/tests/cli/test_run.py b/tests/cli/test_run.py index ab2df20da..c8214213a 100644 --- a/tests/cli/test_run.py +++ b/tests/cli/test_run.py @@ -52,6 +52,18 @@ def test_run_arguments_max_action_concurrency_from_dict(): assert RunArguments.from_dict({}).max_action_concurrency is None +def test_run_command_has_recover_from_option(): + option_names = {decl for p in run.params for decl in p.opts} + assert "--recover-from" in option_names + + +def test_run_arguments_recover_from_from_dict(): + from flyte.cli._run import RunArguments + + assert RunArguments.from_dict({"recover_from": "r1"}).recover_from == "r1" + assert RunArguments.from_dict({}).recover_from is None + + def test_run_max_action_concurrency_rejects_negative(runner): result = runner.invoke(run, ["--max-action-concurrency", "-1", str(HELLO_WORLD_PY), "say_hello"]) assert result.exit_code != 0 @@ -85,6 +97,57 @@ def test_run_hello_world(runner): raise ve +def test_run_command_has_rerun_from_option(): + """--rerun-from is a visible option on `flyte run` (not hidden — rerun works today).""" + opt_names = {decl for p in run.params for decl in p.opts} + assert "--rerun-from" in opt_names + rerun_opt = next(p for p in run.params if "--rerun-from" in p.opts) + assert rerun_opt.hidden is False + + +def test_run_rerun_from_routes_to_rerun(runner): + """`flyte run --rerun-from r` routes to runner.rerun(r, task_template=task). + + The required `name` input is NOT demanded — inputs come from the prior run. + """ + from unittest import mock + + from mock.mock import AsyncMock + + runner_obj = mock.MagicMock() + runner_obj.rerun.aio = AsyncMock(return_value=mock.MagicMock()) + runner_obj.run.aio = AsyncMock() + + with mock.patch("flyte.with_runcontext", return_value=runner_obj): + cmd = ["--rerun-from", "r1", "--project", "p", "--domain", "d", str(HELLO_WORLD_PY), "say_hello"] + try: + result = runner.invoke(run, cmd) + except ValueError as ve: + if "I/O operation on closed file" in str(ve): + return + raise + + assert result.exit_code == 0, result.output + runner_obj.rerun.aio.assert_awaited_once() + args, kwargs = runner_obj.rerun.aio.call_args + assert args[0] == "r1" + assert "task_template" in kwargs # this local say_hello task is passed as the substitute code + runner_obj.run.aio.assert_not_awaited() + + +def test_run_rerun_from_rejects_local(runner): + """--rerun-from cannot be combined with --local (rerun is remote-only).""" + cmd = ["--local", "--rerun-from", "r1", str(HELLO_WORLD_PY), "say_hello"] + try: + result = runner.invoke(run, cmd) + except ValueError as ve: + if "I/O operation on closed file" in str(ve): + return + raise + assert result.exit_code != 0 + assert "requires remote" in result.output.lower() + + @pytest.mark.integration def test_run_complex_inputs(runner): result = runner.invoke( diff --git a/tests/flyte/test_rerun.py b/tests/flyte/test_rerun.py new file mode 100644 index 000000000..8e482f37c --- /dev/null +++ b/tests/flyte/test_rerun.py @@ -0,0 +1,136 @@ +"""Unit tests for flyte.rerun (folded into _Runner): re-run a prior run by fetching its +RunSpec + task spec + inputs and resubmitting via the shared _submit_remote path.""" + +from types import SimpleNamespace + +import mock +import pytest +from flyteidl2.common import run_pb2 as common_run_pb2 +from flyteidl2.core import literals_pb2 +from flyteidl2.dataproxy import dataproxy_service_pb2 +from flyteidl2.task import common_pb2 as task_common_pb2 +from flyteidl2.task import run_pb2 +from flyteidl2.workflow import run_definition_pb2, run_service_pb2 +from mock.mock import AsyncMock, MagicMock + +import flyte +from flyte._initialize import _init_for_testing + + +def _mock_client_with_run(): + """Mock client whose create_run captures the request and get_action_data returns prior inputs.""" + mock_client = MagicMock() + mock_run_service = AsyncMock() + mock_client.run_service = mock_run_service + + mock_dataproxy = AsyncMock() + mock_dataproxy.upload_inputs.return_value = dataproxy_service_pb2.UploadInputsResponse( + offloaded_input_data=common_run_pb2.OffloadedInputData(uri="s3://b/inputs", inputs_hash="h"), + ) + # Prior run's raw proto inputs (what get_action_data returns). + prior_inputs = task_common_pb2.Inputs( + literals=[ + task_common_pb2.NamedLiteral( + name="v", + value=literals_pb2.Literal( + scalar=literals_pb2.Scalar(primitive=literals_pb2.Primitive(string_value="prior")) + ), + ) + ] + ) + mock_dataproxy.get_action_data.return_value = dataproxy_service_pb2.GetActionDataResponse(inputs=prior_inputs) + mock_client.dataproxy_service = mock_dataproxy + return mock_client, mock_run_service, mock_dataproxy, prior_inputs + + +def _fake_prior_run(base_envs=None): + """A stand-in RunDetails: prior RunSpec + a root action carrying a task spec.""" + base_run_spec = run_pb2.RunSpec( + envs=run_pb2.Envs(values=[literals_pb2.KeyValuePair(key="KEEP", value="1")] + (base_envs or [])), + cluster="orig", + ) + task_spec = run_definition_pb2.ActionDetails( + id=run_definition_pb2.ActionDetails().id, + task=_task_spec_with_string_input(), + ) + action_details = SimpleNamespace(pb2=task_spec) + run_details = SimpleNamespace( + pb2=SimpleNamespace(run_spec=base_run_spec), + action_details=action_details, + ) + return run_details + + +def _task_spec_with_string_input(): + """A minimal TaskSpec with one string input `v` and a version, for fetch + guess_interface.""" + from flyteidl2.core import identifier_pb2, interface_pb2, tasks_pb2, types_pb2 + from flyteidl2.task import task_definition_pb2 + + iface = interface_pb2.TypedInterface( + inputs=interface_pb2.VariableMap( + variables=[ + interface_pb2.VariableEntry( + key="v", + value=interface_pb2.Variable(type=types_pb2.LiteralType(simple=types_pb2.SimpleType.STRING)), + ) + ] + ) + ) + tmpl = tasks_pb2.TaskTemplate( + id=identifier_pb2.Identifier(name="test.task1", version="v1"), + interface=iface, + ) + return task_definition_pb2.TaskSpec(task_template=tmpl) + + +@pytest.mark.asyncio +async def test_rerun_same_inputs_inherits_runspec_and_reuses_prior_inputs(): + mock_client, mock_run_service, mock_dataproxy, prior_inputs = _mock_client_with_run() + await _init_for_testing(client=mock_client, project="test", domain="test") + + with mock.patch("flyte.remote._run.RunDetails") as RD: + RD.get.aio = AsyncMock(return_value=_fake_prior_run()) + run = await flyte.with_runcontext(mode="remote", env_vars={"X": "1"}).rerun.aio("r1") + + assert run + # Prior inputs reused verbatim (no conversion). + mock_dataproxy.get_action_data.assert_called_once() + upload_req = mock_dataproxy.upload_inputs.call_args[0][0] + assert upload_req.inputs == prior_inputs + + req: run_service_pb2.CreateRunRequest = mock_run_service.create_run.call_args[0][0] + envs = {kv.key: kv.value for kv in req.run_spec.envs.values} + assert envs["KEEP"] == "1" # inherited from prior run + assert envs["X"] == "1" # runner override merged in + assert req.run_spec.cluster == "orig" # inherited (queue not overridden) + assert req.WhichOneof("task") == "task_spec" + assert req.task_spec.task_template.id.name == "test.task1" + + +@pytest.mark.asyncio +async def test_rerun_changed_inputs_converts_against_fetched_interface(): + mock_client, _mock_run_service, mock_dataproxy, _ = _mock_client_with_run() + await _init_for_testing(client=mock_client, project="test", domain="test") + + with mock.patch("flyte.remote._run.RunDetails") as RD: + RD.get.aio = AsyncMock(return_value=_fake_prior_run()) + run = await flyte.with_runcontext(mode="remote").rerun.aio("r1", inputs={"v": "changed"}) + + assert run + # Changed inputs => no prior-input fetch; converted against the fetched interface. + mock_dataproxy.get_action_data.assert_not_called() + upload_req = mock_dataproxy.upload_inputs.call_args[0][0] + assert upload_req.inputs.literals[0].name == "v" + assert upload_req.inputs.literals[0].value.scalar.primitive.string_value == "changed" + + +@pytest.mark.asyncio +async def test_rerun_rejects_non_remote_mode(): + await flyte.init.aio() + with pytest.raises(NotImplementedError, match="remote mode"): + await flyte.with_runcontext(mode="local").rerun.aio("r1") + + +def test_replay_is_removed(): + """flyte.replay was deleted in favor of flyte.rerun.""" + assert not hasattr(flyte, "replay") diff --git a/tests/flyte/test_run_runspec_chars.py b/tests/flyte/test_run_runspec_chars.py new file mode 100644 index 000000000..4ed03045d --- /dev/null +++ b/tests/flyte/test_run_runspec_chars.py @@ -0,0 +1,378 @@ +"""Characterization tests pinning the exact ``RunSpec`` / ``CreateRunRequest`` the +remote run path builds from ``with_runcontext(...)``. + +These are a safety net for the run/rerun/recover/debug unification refactor: every +field that ``_Runner._run_remote`` serializes is asserted here so the extraction of +``_build_task_spec_from_template`` / ``_submit_remote`` / ``_apply_overrides`` cannot +silently change the wire. The combined ``test_runspec_all_fields_snapshot`` is the +byte-for-byte oracle; the per-field tests localize any regression. +""" + +import mock +import pytest +from flyteidl2.common import run_pb2 as common_run_pb2 +from flyteidl2.dataproxy import dataproxy_service_pb2 +from flyteidl2.task import run_pb2 +from flyteidl2.workflow import run_service_pb2 +from mock.mock import AsyncMock, MagicMock + +import flyte +from flyte._initialize import _init_for_testing +from flyte.models import CodeBundle + +env = flyte.TaskEnvironment(name="test") + + +@env.task +async def task1(v: str) -> str: + return f"Hello, world {v}!" + + +def _make_mock_client(): + """Mocked ClientSet with run + dataproxy services wired for create_run capture.""" + mock_client = MagicMock() + mock_run_service = AsyncMock() + mock_client.run_service = mock_run_service + + mock_dataproxy_service = AsyncMock() + mock_offloaded = common_run_pb2.OffloadedInputData(uri="s3://bucket/inputs", inputs_hash="abc123") + mock_dataproxy_service.upload_inputs.return_value = dataproxy_service_pb2.UploadInputsResponse( + offloaded_input_data=mock_offloaded, + ) + mock_client.dataproxy_service = mock_dataproxy_service + return mock_client, mock_run_service + + +def _patch_build(fn): + """Stack the image-build + code-bundle mocks shared by every remote-path test. + + The patch applied first (closest to the function) is injected as the first arg, so + build_code_bundle must wrap first to match the ``(mock_code_bundler, mock_build_image_bg)`` + signature used below. + """ + fn = mock.patch("flyte._code_bundle.build_code_bundle", new_callable=AsyncMock)(fn) + fn = mock.patch("flyte._deploy._build_image_bg", new_callable=AsyncMock)(fn) + return fn + + +async def _run_and_capture(mock_build_image_bg, mock_code_bundler, **runcontext_kwargs): + """Run task1 in remote mode with the given runcontext kwargs; return the CreateRunRequest.""" + mock_client, mock_run_service = _make_mock_client() + mock_code_bundler.return_value = CodeBundle(computed_version="v1", tgz="test.tgz") + mock_build_image_bg.return_value = (env.name, "image_name", None) + + await _init_for_testing(client=mock_client, project="test", domain="test") + run = await flyte.with_runcontext(mode="remote", **runcontext_kwargs).run.aio(task1, "hello") + assert run + req: run_service_pb2.CreateRunRequest = mock_run_service.create_run.call_args[0][0] + return req + + +def _envs_dict(req): + return {kv.key: kv.value for kv in req.run_spec.envs.values} + + +@pytest.mark.asyncio +@_patch_build +async def test_runspec_env_vars(mock_code_bundler, mock_build_image_bg): + """User env_vars land on RunSpec.envs, alongside the always-injected LOG_* keys.""" + req = await _run_and_capture(mock_build_image_bg, mock_code_bundler, env_vars={"FOO": "bar"}) + envs = _envs_dict(req) + assert envs["FOO"] == "bar" + # Always-injected keys (see _run.py:302-308). + assert "LOG_LEVEL" in envs + assert "LOG_FORMAT" in envs + + +@pytest.mark.asyncio +@_patch_build +async def test_runspec_debug_injects_f_e_vs(mock_code_bundler, mock_build_image_bg): + """debug=True injects the _F_E_VS env flag.""" + req = await _run_and_capture(mock_build_image_bg, mock_code_bundler, debug=True) + assert _envs_dict(req)["_F_E_VS"] == "1" + + +@pytest.mark.asyncio +@_patch_build +async def test_runspec_labels_and_annotations(mock_code_bundler, mock_build_image_bg): + req = await _run_and_capture( + mock_build_image_bg, + mock_code_bundler, + labels={"team": "ml"}, + annotations={"note": "exp"}, + ) + assert req.run_spec.labels.values["team"] == "ml" + assert req.run_spec.annotations.values["note"] == "exp" + + +@pytest.mark.asyncio +@_patch_build +async def test_runspec_queue_to_cluster(mock_code_bundler, mock_build_image_bg): + """queue= maps to RunSpec.cluster.""" + req = await _run_and_capture(mock_build_image_bg, mock_code_bundler, queue="gpu-queue") + assert req.run_spec.cluster == "gpu-queue" + + +@pytest.mark.asyncio +@_patch_build +async def test_runspec_interruptible(mock_code_bundler, mock_build_image_bg): + """interruptible is a BoolValue (set only when not None).""" + req = await _run_and_capture(mock_build_image_bg, mock_code_bundler, interruptible=True) + assert req.run_spec.HasField("interruptible") + assert req.run_spec.interruptible.value is True + + req2 = await _run_and_capture(mock_build_image_bg, mock_code_bundler) + assert not req2.run_spec.HasField("interruptible") + + +@pytest.mark.asyncio +@_patch_build +async def test_runspec_overwrite_cache(mock_code_bundler, mock_build_image_bg): + """overwrite_cache sets both the top-level field and cache_config.""" + req = await _run_and_capture(mock_build_image_bg, mock_code_bundler, overwrite_cache=True) + assert req.run_spec.overwrite_cache is True + assert req.run_spec.cache_config.overwrite_cache is True + + +@pytest.mark.asyncio +@_patch_build +async def test_runspec_cache_lookup_scope(mock_code_bundler, mock_build_image_bg): + """Default scope is global; project-domain maps to its enum.""" + req_global = await _run_and_capture(mock_build_image_bg, mock_code_bundler) + assert req_global.run_spec.cache_config.cache_lookup_scope == run_pb2.CacheLookupScope.CACHE_LOOKUP_SCOPE_GLOBAL + + req_pd = await _run_and_capture(mock_build_image_bg, mock_code_bundler, cache_lookup_scope="project-domain") + assert req_pd.run_spec.cache_config.cache_lookup_scope == run_pb2.CacheLookupScope.CACHE_LOOKUP_SCOPE_PROJECT_DOMAIN + + +@pytest.mark.asyncio +@_patch_build +async def test_runspec_service_account(mock_code_bundler, mock_build_image_bg): + """service_account maps to security_context.run_as.k8s_service_account.""" + req = await _run_and_capture(mock_build_image_bg, mock_code_bundler, service_account="my-sa") + assert req.run_spec.security_context.run_as.k8s_service_account == "my-sa" + + req_none = await _run_and_capture(mock_build_image_bg, mock_code_bundler) + assert not req_none.run_spec.HasField("security_context") + + +@pytest.mark.asyncio +@_patch_build +async def test_runspec_all_fields_snapshot(mock_code_bundler, mock_build_image_bg): + """Combined oracle: pin the full RunSpec for a fully-populated with_runcontext config. + + The unification refactor must reproduce this RunSpec byte-for-byte; treat any diff + as a regression, not a re-baseline. + """ + req = await _run_and_capture( + mock_build_image_bg, + mock_code_bundler, + env_vars={"FOO": "bar"}, + labels={"team": "ml"}, + annotations={"note": "exp"}, + queue="gpu-queue", + interruptible=True, + overwrite_cache=True, + cache_lookup_scope="project-domain", + service_account="my-sa", + max_action_concurrency=4, + ) + rs = req.run_spec + assert rs.overwrite_cache is True + assert rs.interruptible.value is True + assert rs.labels.values["team"] == "ml" + assert rs.annotations.values["note"] == "exp" + assert _envs_dict(req)["FOO"] == "bar" + assert rs.cluster == "gpu-queue" + assert rs.max_action_concurrency == 4 + assert rs.security_context.run_as.k8s_service_account == "my-sa" + assert rs.cache_config.overwrite_cache is True + assert rs.cache_config.cache_lookup_scope == run_pb2.CacheLookupScope.CACHE_LOOKUP_SCOPE_PROJECT_DOMAIN + # task_spec is sent inline (not by reference) for a locally-defined task. + assert req.WhichOneof("task") == "task_spec" + assert req.offloaded_input_data.uri == "s3://bucket/inputs" + assert not req.HasField("inputs") + + +# --- mode dispatch + hybrid validation ------------------------------------------------- +# Hybrid mode runs the parent locally and enqueues children via a controller; it never +# builds a CreateRunRequest, so the only refactor-relevant invariants are (a) run() routes +# to the right _run_* method per mode, and (b) the hybrid guardrails fire. Both are cheap +# and robust to pin without a live controller/storage. + + +def test_with_runcontext_hybrid_requires_name_and_run_base_dir(): + with pytest.raises(ValueError, match="hybrid"): + flyte.with_runcontext(mode="hybrid") + + +@pytest.mark.asyncio +async def test_run_dispatches_per_mode(): + """`run()` routes to _run_remote / _run_local / _run_hybrid based on the resolved mode.""" + from flyte._run import _Runner + + for mode, target in (("remote", "_run_remote"), ("local", "_run_local"), ("hybrid", "_run_hybrid")): + runner = _Runner(force_mode=mode, name="r", run_base_dir="s3://b/md") + with mock.patch.object(_Runner, target, new_callable=AsyncMock) as m: + m.return_value = object() + await runner.run.aio(task1, "hello") + m.assert_called_once() + + +# --- notifications wiring (moves into _apply_overrides) -------------------------------- + + +@pytest.mark.asyncio +@_patch_build +async def test_runspec_notifications(mock_code_bundler, mock_build_image_bg): + """Notifications resolve into notification_rule_name / notification_rules on RunSpec.""" + import flyte.notify + + req = await _run_and_capture( + mock_build_image_bg, + mock_code_bundler, + notifications=flyte.notify.Email(on_phase="failed", recipients=("a@b.com",)), + ) + # Exactly one of the two notification carriers is populated (depends on resolve output). + assert req.run_spec.notification_rule_name or len(req.run_spec.notification_rules.rules) > 0 + + +# --- ConnectError mapping (moves into _submit_remote) ---------------------------------- + + +@pytest.mark.asyncio +@_patch_build +async def test_create_run_already_exists_maps_to_user_error(mock_code_bundler, mock_build_image_bg): + """create_run ALREADY_EXISTS → RuntimeUserError (RunAlreadyExistsError).""" + from connectrpc.code import Code + from connectrpc.errors import ConnectError + + import flyte.errors + + mock_client, mock_run_service = _make_mock_client() + mock_code_bundler.return_value = CodeBundle(computed_version="v1", tgz="test.tgz") + mock_build_image_bg.return_value = (env.name, "image_name", None) + mock_run_service.create_run.side_effect = ConnectError(Code.ALREADY_EXISTS, "dup") + + await _init_for_testing(client=mock_client, project="test", domain="test") + with pytest.raises(flyte.errors.RuntimeUserError, match="already exists"): + await flyte.with_runcontext(mode="remote", name="dup-run").run.aio(task1, "hello") + + +@pytest.mark.asyncio +@_patch_build +async def test_create_run_unavailable_maps_to_system_error(mock_code_bundler, mock_build_image_bg): + """create_run UNAVAILABLE → RuntimeSystemError (SystemUnavailableError).""" + from connectrpc.code import Code + from connectrpc.errors import ConnectError + + import flyte.errors + + mock_client, mock_run_service = _make_mock_client() + mock_code_bundler.return_value = CodeBundle(computed_version="v1", tgz="test.tgz") + mock_build_image_bg.return_value = (env.name, "image_name", None) + mock_run_service.create_run.side_effect = ConnectError(Code.UNAVAILABLE, "down") + + await _init_for_testing(client=mock_client, project="test", domain="test") + with pytest.raises(flyte.errors.RuntimeSystemError): + await flyte.with_runcontext(mode="remote").run.aio(task1, "hello") + + +# --- dry-run path (stays in _run_remote) ----------------------------------------------- + + +@pytest.mark.asyncio +@_patch_build +async def test_dry_run_returns_dryrun_without_create_run(mock_code_bundler, mock_build_image_bg): + """dry_run=True returns a DryRun carrying the task_spec and never calls create_run.""" + mock_client, mock_run_service = _make_mock_client() + mock_code_bundler.return_value = CodeBundle(computed_version="v1", tgz="test.tgz") + mock_build_image_bg.return_value = (env.name, "image_name", None) + + await _init_for_testing(client=mock_client, project="test", domain="test") + run = await flyte.with_runcontext(mode="remote", dry_run=True).run.aio(task1, "hello") + + assert run is not None + assert run.task_spec is not None + mock_run_service.create_run.assert_not_called() + + +# --- _apply_overrides inherited path (the rerun seam, base != None) -------------------- + + +@pytest.mark.asyncio +async def test_apply_overrides_inherited_merges_env_and_keys(): + """base != None: deep-copy the prior RunSpec, overlay env by key, apply only set overrides.""" + from flyteidl2.core import literals_pb2 + from flyteidl2.task import run_pb2 + + from flyte._run import _Runner + + mock_client, _ = _make_mock_client() + await _init_for_testing(client=mock_client, project="test", domain="test") + + base = run_pb2.RunSpec( + envs=run_pb2.Envs( + values=[ + literals_pb2.KeyValuePair(key="KEEP", value="1"), + literals_pb2.KeyValuePair(key="FOO", value="old"), + ] + ), + labels=run_pb2.Labels(values={"base": "yes"}), + cluster="orig-cluster", + ) + + runner = _Runner(force_mode="remote", env_vars={"FOO": "new", "BAR": "2"}, labels={"team": "ml"}) + out = runner._apply_overrides(base) + + envs = {kv.key: kv.value for kv in out.envs.values} + assert envs["KEEP"] == "1" # prior key preserved + assert envs["FOO"] == "new" # runner override wins + assert envs["BAR"] == "2" # new key added + assert out.labels.values["base"] == "yes" # prior label preserved + assert out.labels.values["team"] == "ml" # runner label merged + assert out.cluster == "orig-cluster" # queue not set on runner -> inherited cluster kept + + # base is not mutated (deep copy). + assert {kv.key: kv.value for kv in base.envs.values} == {"KEEP": "1", "FOO": "old"} + + +@pytest.mark.asyncio +async def test_apply_overrides_recover_gated(): + """recover raises until flyteidl2 RunSpec.recover ships (field absent today).""" + from flyteidl2.task import run_pb2 + + from flyte._run import _Runner + + mock_client, _ = _make_mock_client() + await _init_for_testing(client=mock_client, project="test", domain="test") + + runner = _Runner(force_mode="remote") + + if "recover" in run_pb2.RunSpec.DESCRIPTOR.fields_by_name: + pytest.skip("RunSpec.recover is available; gating no longer applies") + with pytest.raises(NotImplementedError, match="recover is not yet supported"): + runner._apply_overrides(None, recover_ref="some-run") + + +def test_resolve_recover_ref_semantics(): + """recover=False/True/str resolve to the right reference (or raise on run()).""" + from flyte._run import _Runner + + # default False -> no recover + assert _Runner()._resolve_recover_ref("r1") is None + # True -> the run being rerun + assert _Runner(recover=True)._resolve_recover_ref("r1") == "r1" + # True with no rerun target (a plain run()) -> error + with pytest.raises(ValueError, match="recover=True is only valid with rerun"): + _Runner(recover=True)._resolve_recover_ref(None) + # explicit name -> that name (works on run()) + assert _Runner(recover="other")._resolve_recover_ref(None) == "other" + + +@pytest.mark.asyncio +async def test_recover_rejected_in_local_mode(): + """recover is remote-only; a truthy recover in local mode fails fast on run().""" + await flyte.init.aio() + with pytest.raises(ValueError, match="recover is only supported in remote mode"): + await flyte.with_runcontext(mode="local", recover="r1").run.aio(task1, "hello")