From 589b734fb3ebf5bc3eb3ce2a7d9b7958274bfc1e Mon Sep 17 00:00:00 2001 From: bghira Date: Wed, 27 May 2026 13:50:18 +0000 Subject: [PATCH] Support loading AnyFlow embedder and weights as PEFT LoRA --- demo.py | 29 ++- far/utils/lora_adapter.py | 111 +++++++++ scripts/extract_adapter_common.py | 331 +++++++++++++++++++++++++++ scripts/extract_anyflow_peft_lora.py | 94 ++++++++ scripts/extract_peft_lora.py | 174 ++++++++++++++ 5 files changed, 735 insertions(+), 4 deletions(-) create mode 100644 far/utils/lora_adapter.py create mode 100644 scripts/extract_adapter_common.py create mode 100755 scripts/extract_anyflow_peft_lora.py create mode 100755 scripts/extract_peft_lora.py diff --git a/demo.py b/demo.py index 1f478f8..c23236d 100644 --- a/demo.py +++ b/demo.py @@ -27,6 +27,7 @@ from far.models.transformer_far_wan_model import FAR_Wan_Transformer3DModel from far.pipelines.pipeline_far_wan_anyflow import FARWanAnyFlowPipeline from far.pipelines.pipeline_wan_anyflow import WanAnyFlowPipeline +from far.utils.lora_adapter import load_transformer_lora from far.utils.video_util import select_frame_indices from far.utils.vis_util import draw_rectangle @@ -43,10 +44,15 @@ class DemoConfig: task_type: str = 't2v' # Where to write demo_*.mp4. save_dir: str = MISSING + # Optional PEFT LoRA safetensors path. AnyFlow sidecar tensors are loaded when present. + lora_path: str | None = None + lora_adapter_name: str = 'anyflow' -def inference_causal_demo(model_path, task_type, save_dir): +def inference_causal_demo(model_path, task_type, save_dir, lora_path=None, lora_adapter_name='anyflow'): transformer = FAR_Wan_Transformer3DModel.from_pretrained(model_path, subfolder='transformer') + if lora_path: + load_transformer_lora(transformer, lora_path, adapter_name=lora_adapter_name) pipeline = FARWanAnyFlowPipeline.from_pretrained(model_path, transformer=transformer).to('cuda', dtype=torch.bfloat16) os.makedirs(save_dir, exist_ok=True) @@ -110,8 +116,10 @@ def inference_causal_demo(model_path, task_type, save_dir): raise NotImplementedError -def inference_bidirectional_demo(model_path, task_type, save_dir): +def inference_bidirectional_demo(model_path, task_type, save_dir, lora_path=None, lora_adapter_name='anyflow'): transformer = FAR_Wan_Transformer3DModel.from_pretrained(model_path, subfolder='transformer') + if lora_path: + load_transformer_lora(transformer, lora_path, adapter_name=lora_adapter_name) pipeline = WanAnyFlowPipeline.from_pretrained(model_path, transformer=transformer).to('cuda', dtype=torch.bfloat16) os.makedirs(save_dir, exist_ok=True) @@ -137,9 +145,21 @@ def inference_bidirectional_demo(model_path, task_type, save_dir): OmegaConf.from_cli(), ) if 'AnyFlow-FAR' in cfg.model_path: - inference_causal_demo(cfg.model_path, task_type=cfg.task_type, save_dir=cfg.save_dir) + inference_causal_demo( + cfg.model_path, + task_type=cfg.task_type, + save_dir=cfg.save_dir, + lora_path=cfg.lora_path, + lora_adapter_name=cfg.lora_adapter_name, + ) elif 'AnyFlow-Wan' in cfg.model_path: - inference_bidirectional_demo(cfg.model_path, task_type=cfg.task_type, save_dir=cfg.save_dir) + inference_bidirectional_demo( + cfg.model_path, + task_type=cfg.task_type, + save_dir=cfg.save_dir, + lora_path=cfg.lora_path, + lora_adapter_name=cfg.lora_adapter_name, + ) else: raise NotImplementedError @@ -150,6 +170,7 @@ def inference_bidirectional_demo(model_path, task_type, save_dir): model_path — Diffusers folder path. task_type — t2v | ti2v | tv2v (default: t2v). save_dir — Output directory for demo_*.mp4. + lora_path — Optional PEFT LoRA safetensors path. Example (from repository root): python demo.py \ diff --git a/far/utils/lora_adapter.py b/far/utils/lora_adapter.py new file mode 100644 index 0000000..f4fd469 --- /dev/null +++ b/far/utils/lora_adapter.py @@ -0,0 +1,111 @@ +# Copyright 2026 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections import OrderedDict +from pathlib import Path +from typing import Optional + +import torch +from safetensors import safe_open + + +ANYFLOW_SIDECAR_PREFIXES = ("condition_embedder.delta_embedder.",) + + +def enable_anyflow_time_conditioning(transformer, *, gate_value: float = 0.25, deltatime_type: str = "r") -> None: + if deltatime_type not in {"r", "t-r"}: + raise ValueError("AnyFlow deltatime_type must be 'r' or 't-r'.") + condition_embedder = getattr(transformer, "condition_embedder", None) + if condition_embedder is not None and hasattr(condition_embedder, "delta_embedder"): + return + if not hasattr(transformer, "setup_flowmap_model"): + raise ValueError("Transformer does not support AnyFlow time conditioning.") + transformer.register_to_config(gate_value=float(gate_value), deltatime_type=deltatime_type) + transformer.setup_flowmap_model() + + +def _copy_sidecar_tensor(transformer, key: str, tensor: torch.Tensor, *, prefix: str) -> bool: + sidecar_prefix = f"{prefix}." + if not key.startswith(sidecar_prefix): + return False + model_key = key.removeprefix(sidecar_prefix) + if not model_key.startswith(ANYFLOW_SIDECAR_PREFIXES): + return False + + model_state = transformer.state_dict() + try: + destination = model_state[model_key] + except KeyError as exc: + raise ValueError( + f"LoRA file contains AnyFlow sidecar tensor `{key}`, but the transformer does not have `{model_key}`. " + "Call enable_anyflow_time_conditioning(...) before loading this adapter." + ) from exc + if destination.shape != tensor.shape: + raise ValueError( + f"Shape mismatch for AnyFlow sidecar tensor `{key}`: " + f"model {tuple(destination.shape)} vs file {tuple(tensor.shape)}." + ) + destination.copy_(tensor.to(device=destination.device, dtype=destination.dtype)) + return True + + +def load_transformer_lora( + transformer, + lora_path: str | Path, + *, + adapter_name: str = "anyflow", + prefix: str = "transformer", + gate_value: Optional[float] = None, + deltatime_type: Optional[str] = None, +) -> None: + lora_path = Path(lora_path).expanduser() + lora_state = OrderedDict() + + with safe_open(lora_path, framework="pt", device="cpu") as handle: + metadata = handle.metadata() or {} + sidecar_gate = float(gate_value if gate_value is not None else metadata.get("anyflow_gate_value", 0.25)) + sidecar_deltatime = str(deltatime_type or metadata.get("anyflow_deltatime_type", "r")) + has_sidecar = metadata.get("simpletuner_anyflow_sidecar") == "true" or any( + "condition_embedder.delta_embedder." in key for key in handle.keys() + ) + if has_sidecar: + enable_anyflow_time_conditioning( + transformer, + gate_value=sidecar_gate, + deltatime_type=sidecar_deltatime, + ) + + for key in handle.keys(): + tensor = handle.get_tensor(key) + if has_sidecar and _copy_sidecar_tensor(transformer, key, tensor, prefix=prefix): + continue + if ".lora_A." in key or ".lora_B." in key or ".alpha" in key or ".lora_alpha" in key: + lora_state[key] = tensor + + ordered_lora_state = OrderedDict() + for key, tensor in lora_state.items(): + if ".lora_A." in key or ".lora_B." in key: + ordered_lora_state[key] = tensor + for key, tensor in lora_state.items(): + if key not in ordered_lora_state: + ordered_lora_state[key] = tensor + + if not ordered_lora_state: + raise ValueError(f"No PEFT LoRA tensors found in {lora_path}.") + transformer.load_lora_adapter(ordered_lora_state, prefix=prefix, adapter_name=adapter_name) + transformer.set_adapter(adapter_name) diff --git a/scripts/extract_adapter_common.py b/scripts/extract_adapter_common.py new file mode 100644 index 0000000..babed92 --- /dev/null +++ b/scripts/extract_adapter_common.py @@ -0,0 +1,331 @@ +#!/usr/bin/env python3 +"""Shared helpers for adapter extraction scripts.""" + +from __future__ import annotations + +import json +import re +from contextlib import ExitStack +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Iterable, Optional + +import torch +from safetensors.torch import safe_open, save_file + +WEIGHT_FILENAMES = ("diffusion_pytorch_model.safetensors", "model.safetensors") +INDEX_FILENAMES = ( + "diffusion_pytorch_model.safetensors.index.json", + "model.safetensors.index.json", +) + + +def normalize_subfolder(value: Optional[str]) -> Optional[str]: + if value is None: + return None + value = str(value).strip().strip("/") + if value.lower() in {"", ".", "none", "null"}: + return None + return value + + +def parse_csv(value: Optional[str]) -> list[str]: + if value in (None, "", "none", "None"): + return [] + return [part.strip() for part in str(value).split(",") if part.strip()] + + +def dtype_from_name(name: str) -> torch.dtype: + mapping = { + "float32": torch.float32, + "fp32": torch.float32, + "float16": torch.float16, + "fp16": torch.float16, + "bfloat16": torch.bfloat16, + "bf16": torch.bfloat16, + } + try: + return mapping[name.lower()] + except KeyError as exc: + raise ValueError(f"Unsupported dtype `{name}`. Use float32, float16, or bfloat16.") from exc + + +@dataclass +class TensorSource: + label: str + files_by_key: dict[str, Path] + _exit_stack: Optional[ExitStack] = field(default=None, init=False, repr=False) + _handles: Optional[dict[Path, Any]] = field(default=None, init=False, repr=False) + + def __enter__(self) -> "TensorSource": + self._ensure_open() + return self + + def __exit__(self, exc_type, exc, traceback) -> None: + self.close() + + @property + def keys(self) -> set[str]: + return set(self.files_by_key.keys()) + + def close(self) -> None: + if self._exit_stack is not None: + self._exit_stack.close() + self._exit_stack = None + self._handles = None + + def _ensure_open(self) -> None: + if self._exit_stack is None: + self._exit_stack = ExitStack() + self._handles = {} + + def _handle_for_path(self, path: Path): + self._ensure_open() + assert self._exit_stack is not None + assert self._handles is not None + handle = self._handles.get(path) + if handle is None: + handle = self._exit_stack.enter_context(safe_open(path, framework="pt", device="cpu")) + self._handles[path] = handle + return handle + + def get_tensor(self, key: str) -> torch.Tensor: + try: + path = self.files_by_key[key] + except KeyError as exc: + raise KeyError(f"{self.label} does not contain tensor `{key}`.") from exc + return self._handle_for_path(path).get_tensor(key) + + +def _index_from_safetensor_files(label: str, files: Iterable[Path]) -> TensorSource: + files_by_key: dict[str, Path] = {} + for file_path in files: + with safe_open(file_path, framework="pt", device="cpu") as handle: + for key in handle.keys(): + if key in files_by_key: + raise ValueError(f"Duplicate tensor key `{key}` found while reading {label}.") + files_by_key[key] = file_path + if not files_by_key: + raise ValueError(f"No tensors found in {label}.") + return TensorSource(label=label, files_by_key=files_by_key) + + +def _source_from_index_file(label: str, index_path: Path, component_dir: Path) -> TensorSource: + with index_path.open("r", encoding="utf-8") as handle: + payload = json.load(handle) + weight_map = payload.get("weight_map") + if not isinstance(weight_map, dict): + raise ValueError(f"Index file {index_path} is missing a `weight_map` object.") + + files_by_key: dict[str, Path] = {} + for key, rel_file in weight_map.items(): + rel_path = Path(rel_file) + candidate = component_dir / rel_path + if not candidate.exists(): + candidate = component_dir / rel_path.name + if not candidate.exists(): + raise FileNotFoundError(f"Index {index_path} references missing safetensors shard {rel_file}.") + files_by_key[key] = candidate + return TensorSource(label=label, files_by_key=files_by_key) + + +def _source_from_local_dir(label: str, root: Path, subfolder: Optional[str]) -> TensorSource: + component_dir = root / subfolder if subfolder else root + if not component_dir.is_dir(): + raise FileNotFoundError(f"Component directory not found: {component_dir}") + + for index_name in INDEX_FILENAMES: + index_path = component_dir / index_name + if index_path.exists(): + return _source_from_index_file(label, index_path, component_dir) + + for filename in WEIGHT_FILENAMES: + candidate = component_dir / filename + if candidate.exists(): + return _index_from_safetensor_files(label, [candidate]) + + safetensors_files = sorted(component_dir.glob("*.safetensors")) + if len(safetensors_files) == 1: + return _index_from_safetensor_files(label, safetensors_files) + if len(safetensors_files) > 1: + raise ValueError( + f"Multiple safetensors files found in {component_dir}, but no supported index file exists. " + f"Expected one of: {', '.join(INDEX_FILENAMES)}." + ) + raise FileNotFoundError(f"No supported safetensors weights found in {component_dir}.") + + +def _source_from_hub_repo( + label: str, + repo_id: str, + subfolder: Optional[str], + revision: Optional[str], + cache_dir: Optional[str], +) -> TensorSource: + try: + from huggingface_hub import hf_hub_download, list_repo_files + except ImportError as exc: + raise ImportError("huggingface_hub is required to read remote model repositories.") from exc + + repo_files = set(list_repo_files(repo_id, revision=revision)) + prefix = f"{subfolder}/" if subfolder else "" + + for index_name in INDEX_FILENAMES: + index_repo_path = f"{prefix}{index_name}" + if index_repo_path not in repo_files: + continue + index_path = Path(hf_hub_download(repo_id, index_repo_path, revision=revision, cache_dir=cache_dir)) + with index_path.open("r", encoding="utf-8") as handle: + payload = json.load(handle) + weight_map = payload.get("weight_map") + if not isinstance(weight_map, dict): + raise ValueError(f"Remote index {index_repo_path} is missing a `weight_map` object.") + + downloaded: dict[str, Path] = {} + files_by_key: dict[str, Path] = {} + for key, rel_file in weight_map.items(): + repo_file = rel_file if rel_file in repo_files else f"{prefix}{Path(rel_file).name}" + if repo_file not in repo_files: + raise FileNotFoundError(f"Remote index {index_repo_path} references missing shard {rel_file}.") + if repo_file not in downloaded: + downloaded[repo_file] = Path(hf_hub_download(repo_id, repo_file, revision=revision, cache_dir=cache_dir)) + files_by_key[key] = downloaded[repo_file] + return TensorSource(label=label, files_by_key=files_by_key) + + for filename in WEIGHT_FILENAMES: + repo_file = f"{prefix}{filename}" + if repo_file in repo_files: + path = Path(hf_hub_download(repo_id, repo_file, revision=revision, cache_dir=cache_dir)) + return _index_from_safetensor_files(label, [path]) + + candidates = sorted(file for file in repo_files if file.startswith(prefix) and file.endswith(".safetensors")) + if len(candidates) == 1: + path = Path(hf_hub_download(repo_id, candidates[0], revision=revision, cache_dir=cache_dir)) + return _index_from_safetensor_files(label, [path]) + if len(candidates) > 1: + raise ValueError( + f"Remote repository {repo_id} has multiple safetensors files under `{prefix}` but no supported index file." + ) + raise FileNotFoundError(f"No supported safetensors weights found in {repo_id} under `{prefix}`.") + + +def resolve_tensor_source( + model_ref: str, + *, + label: str, + subfolder: Optional[str], + revision: Optional[str] = None, + cache_dir: Optional[str] = None, +) -> TensorSource: + expanded = Path(model_ref).expanduser() + subfolder = normalize_subfolder(subfolder) + if expanded.is_file(): + if expanded.suffix != ".safetensors": + raise ValueError(f"Only .safetensors files are supported as direct file inputs: {expanded}") + return _index_from_safetensor_files(label, [expanded]) + if expanded.is_dir(): + return _source_from_local_dir(label, expanded, subfolder) + if str(model_ref).endswith(".safetensors"): + raise FileNotFoundError(f"Safetensors file not found: {expanded}") + return _source_from_hub_repo(label, model_ref, subfolder, revision, cache_dir) + + +def key_matches_module(key: str, target_modules: list[str]) -> bool: + if not target_modules: + return True + module_name = key.removesuffix(".weight") + return any(module_name == target or module_name.endswith(f".{target}") for target in target_modules) + + +def normalize_target_modules(raw: str) -> list[str]: + value = raw.strip() + if value == "all-linear": + return [] + if value == "default": + return ["to_q", "to_k", "to_v", "to_out.0"] + return parse_csv(value) + + +def should_extract_key( + key: str, + tensor: torch.Tensor, + *, + target_modules: list[str], + include: Optional[re.Pattern[str]], + exclude: Optional[re.Pattern[str]], + include_conv: bool = False, +) -> bool: + if not key.endswith(".weight"): + return False + if tensor.ndim == 2: + pass + elif include_conv and tensor.ndim in {3, 4, 5}: + pass + else: + return False + if not key_matches_module(key, target_modules): + return False + if include is not None and include.search(key) is None: + return False + if exclude is not None and exclude.search(key) is not None: + return False + return True + + +def compile_optional_regex(value: Optional[str]) -> Optional[re.Pattern[str]]: + if value in (None, "", "none", "None"): + return None + return re.compile(str(value)) + + +def svd_low_rank( + delta: torch.Tensor, + *, + rank: int, + alpha: float, + device: str, +) -> tuple[torch.Tensor, torch.Tensor]: + if rank <= 0: + raise ValueError("rank must be greater than zero.") + if alpha <= 0: + raise ValueError("alpha must be greater than zero.") + + original_shape = tuple(delta.shape) + if delta.ndim == 2: + out_dim, in_dim = delta.shape + flat = delta + down_shape = (rank, in_dim) + up_shape = (out_dim, rank) + elif delta.ndim in {3, 4, 5}: + out_dim = delta.shape[0] + in_dim = delta.shape[1] + kernel_shape = tuple(delta.shape[2:]) + flat = delta.reshape(out_dim, -1) + down_shape = (rank, in_dim, *kernel_shape) + up_shape = (out_dim, rank, *([1] * len(kernel_shape))) + else: + raise ValueError(f"Cannot decompose tensor with shape {original_shape}.") + + matrix = flat.to(device=device, dtype=torch.float32) + max_rank = min(matrix.shape) + effective_rank = min(rank, max_rank) + u, s, vh = torch.linalg.svd(matrix, full_matrices=False) + scale_correction = float(rank) / float(alpha) + + down_flat = torch.zeros((rank, matrix.shape[1]), device=device, dtype=torch.float32) + up = torch.zeros((matrix.shape[0], rank), device=device, dtype=torch.float32) + down_flat[:effective_rank, :] = vh[:effective_rank, :] + up[:, :effective_rank] = u[:, :effective_rank] * s[:effective_rank].unsqueeze(0) * scale_correction + + return down_flat.reshape(down_shape).cpu(), up.reshape(up_shape).cpu() + + +def save_safetensors_with_metadata(state_dict: dict[str, torch.Tensor], output: str, metadata: dict[str, str]) -> Path: + output_path = Path(output).expanduser() + if output_path.suffix != ".safetensors": + output_path.mkdir(parents=True, exist_ok=True) + output_path = output_path / "pytorch_lora_weights.safetensors" + else: + output_path.parent.mkdir(parents=True, exist_ok=True) + save_file(state_dict, str(output_path), metadata=metadata) + return output_path diff --git a/scripts/extract_anyflow_peft_lora.py b/scripts/extract_anyflow_peft_lora.py new file mode 100755 index 0000000..e2f7c62 --- /dev/null +++ b/scripts/extract_anyflow_peft_lora.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +"""Extract a PEFT LoRA and embed AnyFlow-only conditioning weights.""" + +from __future__ import annotations + +import argparse +import sys +from pathlib import Path + +from safetensors import safe_open +from safetensors.torch import save_file + +from extract_adapter_common import resolve_tensor_source +from extract_peft_lora import build_parser as build_peft_parser +from extract_peft_lora import extract as extract_peft + + +def build_parser() -> argparse.ArgumentParser: + parser = build_peft_parser() + parser.description = ( + "Extract a PEFT LoRA from a base Wan Diffusers model to an AnyFlow target and store the " + "AnyFlow-only delta_embedder tensors in the same safetensors file." + ) + parser.add_argument( + "--anyflow-sidecar", + action="store_true", + help="Append target condition_embedder.delta_embedder tensors for AnyFlow time conditioning.", + ) + parser.add_argument("--anyflow-gate-value", type=float, default=0.25, help="AnyFlow delta embedding gate value.") + parser.add_argument( + "--anyflow-deltatime-type", + choices=("r", "t-r"), + default="r", + help="AnyFlow delta timestep mode.", + ) + return parser + + +def append_anyflow_sidecar(args: argparse.Namespace, output_path: Path) -> None: + state = {} + with safe_open(output_path, framework="pt", device="cpu") as handle: + metadata = dict(handle.metadata() or {}) + for key in handle.keys(): + state[key] = handle.get_tensor(key) + + target_subfolder = args.target_subfolder or args.component_subfolder + target = resolve_tensor_source( + args.target_model, + label="target", + subfolder=target_subfolder, + revision=args.target_revision, + cache_dir=args.cache_dir, + ) + + sidecar_keys = [] + prefix = f"{args.prefix}." if args.prefix else "" + with target: + for key in sorted(target.keys): + if not key.startswith("condition_embedder.delta_embedder."): + continue + out_key = f"{prefix}{key}" + state[out_key] = target.get_tensor(key) + sidecar_keys.append(out_key) + + if not sidecar_keys: + raise ValueError( + f"No condition_embedder.delta_embedder tensors found in target model `{args.target_model}`. " + "Use an AnyFlow target checkpoint when --anyflow-sidecar is set." + ) + + metadata.update( + { + "simpletuner_anyflow_sidecar": "true", + "anyflow_gate_value": str(args.anyflow_gate_value), + "anyflow_deltatime_type": args.anyflow_deltatime_type, + "anyflow_sidecar_keys": ",".join(sidecar_keys), + "anyflow_sidecar_source": args.target_model, + } + ) + save_file(state, str(output_path), metadata=metadata) + print(f"Appended {len(sidecar_keys)} AnyFlow sidecar tensors to {output_path}.") + + +def main() -> None: + parser = build_parser() + args = parser.parse_args() + output_path = extract_peft(args) + if args.anyflow_sidecar: + append_anyflow_sidecar(args, output_path) + + +if __name__ == "__main__": + sys.path.insert(0, str(Path(__file__).resolve().parent)) + main() diff --git a/scripts/extract_peft_lora.py b/scripts/extract_peft_lora.py new file mode 100755 index 0000000..0d40204 --- /dev/null +++ b/scripts/extract_peft_lora.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 +"""Extract a PEFT LoRA approximation from two safetensors model components.""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +import torch +from extract_adapter_common import ( + compile_optional_regex, + dtype_from_name, + normalize_subfolder, + normalize_target_modules, + resolve_tensor_source, + save_safetensors_with_metadata, + should_extract_key, + svd_low_rank, +) + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Extract a PEFT LoRA from the weight delta between a base model component and a target model component. " + "Inputs may be local .safetensors files, local Diffusers folders, or remote Hugging Face Diffusers repos." + ) + ) + parser.add_argument("base_model", help="Base model component: .safetensors, Diffusers folder, or HF repo id.") + parser.add_argument("target_model", help="Target model component: .safetensors, Diffusers folder, or HF repo id.") + parser.add_argument("output", help="Output .safetensors path or directory.") + parser.add_argument("--rank", type=int, required=True, help="LoRA rank to extract.") + parser.add_argument("--alpha", type=float, default=None, help="LoRA alpha. Defaults to rank.") + parser.add_argument("--algorithm", choices=("svd",), default="svd", help="Low-rank extraction algorithm.") + parser.add_argument( + "--component-subfolder", + default="transformer", + help="Diffusers component subfolder for both models. Use 'none' for direct component folders.", + ) + parser.add_argument("--base-subfolder", default=None, help="Override component subfolder for the base model.") + parser.add_argument("--target-subfolder", default=None, help="Override component subfolder for the target model.") + parser.add_argument("--base-revision", default=None, help="Optional HF revision for the base model.") + parser.add_argument("--target-revision", default=None, help="Optional HF revision for the target model.") + parser.add_argument("--cache-dir", default=None, help="Optional Hugging Face cache directory.") + parser.add_argument( + "--prefix", + default="transformer", + help="State-dict prefix expected by SimpleTuner's init_lora loader.", + ) + parser.add_argument( + "--target-modules", + default="default", + help=( + "Comma-separated module suffixes to extract, 'default' for to_q,to_k,to_v,to_out.0, " + "or 'all-linear' for every linear weight." + ), + ) + parser.add_argument("--include", default=None, help="Optional regex that tensor keys must match.") + parser.add_argument("--exclude", default=None, help="Optional regex for tensor keys to skip.") + parser.add_argument("--device", default="cpu", help="Device used for SVD, e.g. cpu, cuda, mps.") + parser.add_argument("--dtype", default="float16", help="Output dtype: float32, float16, or bfloat16.") + parser.add_argument( + "--min-delta-norm", + type=float, + default=0.0, + help="Skip tensors whose delta L2 norm is less than or equal to this value.", + ) + parser.add_argument( + "--skip-mismatched", + action="store_true", + help="Skip common tensor keys whose shapes differ instead of raising an error.", + ) + return parser + + +def extract(args: argparse.Namespace) -> Path: + alpha = float(args.rank if args.alpha is None else args.alpha) + dtype = dtype_from_name(args.dtype) + component_subfolder = normalize_subfolder(args.component_subfolder) + base_subfolder = normalize_subfolder(args.base_subfolder) or component_subfolder + target_subfolder = normalize_subfolder(args.target_subfolder) or component_subfolder + target_modules = normalize_target_modules(args.target_modules) + include = compile_optional_regex(args.include) + exclude = compile_optional_regex(args.exclude) + + base = resolve_tensor_source( + args.base_model, + label="base", + subfolder=base_subfolder, + revision=args.base_revision, + cache_dir=args.cache_dir, + ) + target = resolve_tensor_source( + args.target_model, + label="target", + subfolder=target_subfolder, + revision=args.target_revision, + cache_dir=args.cache_dir, + ) + + state_dict: dict[str, torch.Tensor] = {} + skipped_shape = 0 + skipped_filter = 0 + skipped_zero = 0 + with base, target: + common_keys = sorted(base.keys & target.keys) + if not common_keys: + raise ValueError("The base and target sources do not share any tensor keys.") + + for key in common_keys: + base_tensor = base.get_tensor(key) + target_tensor = target.get_tensor(key) + if base_tensor.shape != target_tensor.shape: + if args.skip_mismatched: + skipped_shape += 1 + continue + raise ValueError( + f"Shape mismatch for `{key}`: base {tuple(base_tensor.shape)} " f"vs target {tuple(target_tensor.shape)}" + ) + if not should_extract_key( + key, + base_tensor, + target_modules=target_modules, + include=include, + exclude=exclude, + include_conv=False, + ): + skipped_filter += 1 + continue + + delta = target_tensor.to(dtype=torch.float32) - base_tensor.to(dtype=torch.float32) + if args.min_delta_norm > 0: + delta_norm = torch.linalg.vector_norm(delta).item() + if delta_norm <= args.min_delta_norm: + skipped_zero += 1 + continue + + down, up = svd_low_rank(delta, rank=args.rank, alpha=alpha, device=args.device) + module_name = key.removesuffix(".weight") + lora_prefix = f"{args.prefix}.{module_name}" if args.prefix else module_name + state_dict[f"{lora_prefix}.lora_A.weight"] = down.to(dtype=dtype).contiguous() + state_dict[f"{lora_prefix}.lora_B.weight"] = up.to(dtype=dtype).contiguous() + state_dict[f"{lora_prefix}.alpha"] = torch.tensor(alpha, dtype=dtype) + + if not state_dict: + raise ValueError("No tensors were extracted. Check --target-modules, --include/--exclude, and --min-delta-norm.") + + metadata = { + "format": "simpletuner-peft-lora-extract", + "algorithm": args.algorithm, + "rank": str(args.rank), + "alpha": str(alpha), + "base_model": args.base_model, + "target_model": args.target_model, + "prefix": args.prefix, + "target_modules": json.dumps(target_modules or "all-linear"), + "component_subfolder": component_subfolder or "", + } + output_path = save_safetensors_with_metadata(state_dict, args.output, metadata) + print( + f"Saved {len(state_dict) // 3} PEFT LoRA modules to {output_path} " + f"(skipped: filtered={skipped_filter}, zero={skipped_zero}, shape={skipped_shape})." + ) + return output_path + + +def main() -> None: + args = build_parser().parse_args() + extract(args) + + +if __name__ == "__main__": + main()