ChituDiffusion is a high-performance diffusion inference framework focused on video generation workloads. It provides a compact runtime for distributed inference, multiple attention backends, FlexCache and DiTango acceleration, and optional evaluation utilities.
The project is under active development. The current public interface is the
repository launcher in run.sh plus the configuration in system_config.yaml.
- Distributed diffusion inference with context parallelism and CFG parallelism.
- Attention backend selection for FlashAttention, SageAttention, and SpargeAttention.
- FlexCache strategies including TeaCache and PAB, plus independent DiTango planner/runtime acceleration.
- Optional low-memory execution modes.
- Built-in timing, logging, output naming, and evaluation helpers.
- Initial model support for Wan text-to-video models.

Requirements:
- Python 3.12 or newer
- CUDA-capable NVIDIA GPU
- CUDA and PyTorch versions compatible with the selected pyproject.toml package index
- uv is recommended for dependency management
Clone the repository and initialize optional submodules as needed:
git clone <repo-url>
cd ChituDiffusion
git submodule update --init --recursive

Install the base environment:
uv sync

Optional extras are available for acceleration and evaluation:
uv sync --extra sage
uv sync --extra sparge
uv sync --extra eval
uv sync --extra vbench

For manual environments:
pip install -r requirements.txt
pip install -e .

Edit system_config.yaml before running:
model:
  name: Wan2.1-T2V-1.3B
  ckpt_dir: /path/to/Wan2.1-T2V-1.3B
launch:
  num_nodes: 1
  gpus_per_node: 8
parallel:
  cfp: 2
infer:
  attn_type: flash_attn
  enable_flexcache: true  # required for TeaCache/PAB; DiTango is independent

The most important field is model.ckpt_dir; it must point to a local model
checkpoint directory.
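Because a wrong checkpoint path is an easy mistake to make, it can help to check it before launching. The following is a minimal sketch and not part of ChituDiffusion: `check_ckpt_dir` is a hypothetical helper, and the configuration is passed as a plain dict mirroring the YAML above, so the example needs no YAML parser.

```python
from pathlib import Path


def check_ckpt_dir(config: dict) -> Path:
    """Return model.ckpt_dir as a Path, or raise if it is unset or not a directory.

    Hypothetical helper for illustration; not a ChituDiffusion API.
    """
    ckpt_dir = config.get("model", {}).get("ckpt_dir")
    if not ckpt_dir:
        raise ValueError("model.ckpt_dir is not set")
    path = Path(ckpt_dir).expanduser()
    if not path.is_dir():
        raise FileNotFoundError(f"model.ckpt_dir is not a directory: {path}")
    return path
```

Running a check like this before `bash run.sh system_config.yaml` fails fast on a bad path instead of after the distributed launch has started.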
Run generation through the single repository entry point:
bash run.sh system_config.yaml

Common overrides:
bash run.sh system_config.yaml --gpus-per-node 8 --cfp 2

run.sh reads system_config.yaml, builds dotlist overrides, and launches
the configured Python entry through the runtime script.
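As a rough illustration of the dotlist idea, the sketch below turns command-line flags into `key=value` override strings. The flag-to-key mapping (`--gpus-per-node` to `launch.gpus_per_node`, `--cfp` to `parallel.cfp`) is an assumption inferred from the example configuration, and `to_dotlist` is a hypothetical name; the actual logic lives in run.sh and may differ.

```python
import argparse

# Assumed mapping from parsed flag names to config keys; inferred from the
# example config, not taken from run.sh.
FLAG_TO_KEY = {
    "gpus_per_node": "launch.gpus_per_node",
    "cfp": "parallel.cfp",
}


def to_dotlist(argv: list[str]) -> list[str]:
    """Convert recognized CLI flags into dotlist-style overrides."""
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--gpus-per-node", type=int)
    parser.add_argument("--cfp", type=int)
    # Unknown flags are ignored rather than treated as errors.
    args, _unknown = parser.parse_known_args(argv)
    return [
        f"{key}={getattr(args, dest)}"
        for dest, key in FLAG_TO_KEY.items()
        if getattr(args, dest) is not None
    ]


print(to_dotlist(["--gpus-per-node", "8", "--cfp", "2"]))
# ['launch.gpus_per_node=8', 'parallel.cfp=2']
```

Only flags that were actually passed produce an override, so values in system_config.yaml remain the defaults.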
The current configuration set includes:
- Wan2.1-T2V-1.3B
- Wan2.1-T2V-14B
- Wan2.2-T2V-A14B
- FLUX.2-klein-4B
Model availability depends on the local checkpoint path and the corresponding configuration under the project config directory.
chitu_diffusion/core/           Configuration, schemas, distributed utilities, registry
chitu_diffusion/runtime/        Backend, generator, scheduler, task, main runtime API
chitu_diffusion/modules/        Model-specific and reusable diffusion modules
chitu_diffusion/flexcache/      Curvature FlexCache strategies and baselines
chitu_diffusion/ditango/        DiTango planner, runtime attention, visualization
chitu_diffusion/evaluation/     Evaluation manager, strategies, metric helpers
chitu_diffusion/observability/  Timing and magnitude logging helpers
script/                         Launch helpers for local and Slurm execution
test/                           Generation and acceleration test entry points
system_config.yaml              Default runtime configuration
run.sh                          Main launch entry point
This layout will continue to be simplified as the project is prepared for public release.
Evaluation can be enabled from system_config.yaml:
eval:
  eval_type: [psnr, lpips]
  reference_path: /path/to/reference/videos

Additional metric dependencies are installed with:
uv sync --extra eval

Each run writes to:
outputs/<tag>-<YYYYMMDD_HHMMSS>-<taskid>/
  request_params.json
  system_params.json
  run_config.yaml
  results/
    <task_id>/
      *.mp4
      *.json
  metrics/
    timing/
      summary.json
      <task_id>.json
    memory/
      rank<N>.json
    quality/
      summary.json
  logs/
    command.log
    run.log
    run.rank<N>.log
    <task_id>/
      *.ppm
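Since the run directory name encodes a tag, a timestamp, and a task id, a script can pick out the most recent run. This is a minimal sketch, assuming the naming pattern shown above and that the tag contains no timestamp-like segment of its own; `parse_run_dir` and `latest_run` are hypothetical helpers, not part of the project.

```python
import re
from datetime import datetime
from pathlib import Path

# <tag>-<YYYYMMDD_HHMMSS>-<taskid>, following the layout above.
RUN_DIR_RE = re.compile(r"^(?P<tag>.+?)-(?P<ts>\d{8}_\d{6})-(?P<task>.+)$")


def parse_run_dir(name: str):
    """Split a run directory name into (tag, timestamp, task id), or return None."""
    m = RUN_DIR_RE.match(name)
    if m is None:
        return None
    try:
        ts = datetime.strptime(m["ts"], "%Y%m%d_%H%M%S")
    except ValueError:
        # Digits matched the pattern but do not form a valid date/time.
        return None
    return m["tag"], ts, m["task"]


def latest_run(outputs: Path):
    """Return the run directory with the newest timestamp, or None if none match."""
    runs = []
    for p in outputs.iterdir():
        parsed = parse_run_dir(p.name) if p.is_dir() else None
        if parsed is not None:
            runs.append((parsed[1], p))
    return max(runs, key=lambda item: item[0])[1] if runs else None
```

For example, `latest_run(Path("outputs"))` would return the newest matching directory, which can then be used to locate `results/` or `metrics/`.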
results/<task_id>/ contains generated media and sidecar metadata. metrics/
contains JSON-only timing, memory, and quality files in separate subdirectories.
Timing JSON includes aggregate timer stats; timers.dit_forward.total_ms is the
overall DiT forward time, and records.dit_forward_step stores per-timestep DiT
forward times. Memory JSON is grouped by rank, so model_loaded,
task_complete, and final events for the same rank live in one file.
output.memory toggles memory metrics. output.log_ranks controls which ranks
write memory metrics and Python logs. Quality JSON includes by_task_id groups
for multi-request runs. logs/ contains process logs and per-task debug
visualizations. command.log captures the full launch command output,
including run.sh, srun, wrapper output, and Python stdout/stderr.
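The timing fields described above can be summarized with a few lines of Python. This is a sketch under stated assumptions: it treats `records.dit_forward_step` as a flat list of per-step times in milliseconds and reads `metrics/timing/summary.json`, but the real schema may nest these values differently or place them in the per-task file. `summarize_dit_timing` and `load_timing` are hypothetical helpers.

```python
import json
import statistics
from pathlib import Path


def summarize_dit_timing(timing: dict) -> dict:
    """Summarize DiT forward timing from a parsed timing JSON object.

    Assumes records.dit_forward_step is a flat list of millisecond values.
    """
    total_ms = timing["timers"]["dit_forward"]["total_ms"]
    steps = timing.get("records", {}).get("dit_forward_step", [])
    summary = {"total_ms": total_ms, "num_steps": len(steps)}
    if steps:
        summary["mean_step_ms"] = statistics.fmean(steps)
        summary["max_step_ms"] = max(steps)
    return summary


def load_timing(run_dir: Path) -> dict:
    """Read metrics/timing/summary.json from a run directory."""
    path = run_dir / "metrics" / "timing" / "summary.json"
    return json.loads(path.read_text(encoding="utf-8"))
```

A typical use would be `summarize_dit_timing(load_timing(run_dir))`, comparing the result across runs with different attention backends or cache settings.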
Run a lightweight import check:
python - <<'PY'
import chitu_diffusion.core
from chitu_diffusion.runtime.task import DiffusionUserParams
from chitu_diffusion.observability import Timer
print("imports ok")
PY

Run tests with:
pytest test

Some tests require CUDA, local checkpoints, and distributed launch settings.
This project is licensed under the Apache License 2.0. See LICENSE for
details.