diff --git a/.gitmodules b/.gitmodules index 51d8eac03..6ebcea592 100644 --- a/.gitmodules +++ b/.gitmodules @@ -6,3 +6,6 @@ path = text_to_image/torchtitan url = https://github.com/pytorch/torchtitan.git branch = mlperf-training-flux.1 +[submodule "recommendation_v4/cutlass"] + path = recommendation_v4/generative_recommenders/ops/cpp/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/recommendation_v4/.gitignore b/recommendation_v4/.gitignore new file mode 100644 index 000000000..5edddc5b3 --- /dev/null +++ b/recommendation_v4/.gitignore @@ -0,0 +1,159 @@ +# Don't check in parsed data files and other temporary files +tmp/ +exps/ +ckpts/ +results/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ diff --git a/recommendation_v4/Dockerfile b/recommendation_v4/Dockerfile new file mode 100644 index 000000000..450a5ab55 --- /dev/null +++ b/recommendation_v4/Dockerfile @@ -0,0 +1,86 @@ +# MI350X path — implements docs/training_recipe.md §"MI350X". + +FROM rocm/primus:v26.3 + +ENV PYTHONUNBUFFERED=1 \ + PIP_NO_CACHE_DIR=1 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 + +WORKDIR /workspace/recommendation_v4 + +# torch / torchvision / torchaudio — training_recipe.md:38-40. +RUN pip install --upgrade --no-deps \ + --index-url https://download.pytorch.org/whl/rocm7.2 \ + torch==2.12.0+rocm7.2 \ + torchvision==0.27.0+rocm7.2 \ + torchaudio==2.11.0+rocm7.2 + +# torchrec — training_recipe.md:43. +RUN pip install --force-reinstall --no-deps \ + "git+https://github.com/pytorch/torchrec.git@v2026.06.01.00" + +# fbgemm_gpu — training_recipe.md:42. Build from FBGEMM commit 10b77573 for +# gfx950 against the replaced torch. ~30-60 min. +RUN apt-get update && apt-get install -y --no-install-recommends git build-essential && \ + rm -rf /var/lib/apt/lists/* && \ + git clone --recursive https://github.com/pytorch/FBGEMM.git /tmp/FBGEMM && \ + cd /tmp/FBGEMM && \ + git checkout 10b775730212923f65f7b78f79b6a01d80cf3c29 && \ + git submodule update --init --recursive && \ + cd fbgemm_gpu && \ + # Filter `fairscale` and the torch family from fbgemm's requirements.txt: + # fairscale pulls a CPU torch that would clobber the +rocm7.2 wheel installed + # above. fairscale is a distributed-training lib used by fbgemm tests, not + # by the build itself. + grep -v -E '^(fairscale|torch|torchvision|torchaudio)([<>=!]|$)' requirements.txt > /tmp/req.txt && \ + pip install -r /tmp/req.txt && \ + python setup.py -j 32 bdist_wheel \ + --build-target=default \ + --build-variant=rocm \ + -DHIP_ROOT_DIR=/opt/rocm \ + -DAMDGPU_TARGETS=gfx950 && \ + pip install --force-reinstall --no-deps dist/fbgemm_gpu_nightly_rocm*.whl && \ + cd / && rm -rf /tmp/FBGEMM + +# polars-u64-idx — training_recipe.md:44 (mandatory; yambda-5b > 4.29 B rows). +# Remaining packages — training_recipe.md:156-159 ("Additional Python deps") plus +# `datasets` + `huggingface_hub`, which the recipe does not list but +# preprocess_public_data.py:278 imports to download yambda from HuggingFace. +RUN pip install \ + polars-u64-idx==1.33.1 \ + gin-config \ + absl-py \ + datasets \ + huggingface_hub \ + pyre-extensions \ + iopath \ + typing-inspect \ + psutil \ + tqdm \ + pyyaml \ + lightning-utilities && \ + # torchmetrics and tensordict declare `torch` as a dep; without --no-deps + # pip pulls torch==2.12.0+cu130 from PyPI which clobbers the +rocm7.2 wheel + # we installed above (libtorch_hip.so disappears, fbgemm_gpu fails to load). + pip install --no-deps \ + torchmetrics==1.0.3 \ + tensordict + +# mlperf_logging — required by train/mlperf_logging_utils.py for MLPerf +# compliance logs. Pinned to the Training 6.0 tag for reproducibility; --no-deps +# so pip does not resolve requirements.txt's torch/fbgemm_gpu/torchrec pins and +# clobber the +rocm7.2 wheels above. +RUN pip install --no-deps "git+https://github.com/mlcommons/logging.git@6.0.0-rc6" + +# Smoke-test the 6 imports the launch script checks at +# scripts/launch_smoke_8gpu.sh:26. +RUN python -c "import torch, fbgemm_gpu, torchrec, polars, xxhash, gin; \ +print('torch', torch.__version__, '| hip', getattr(torch.version, 'hip', None))" + +COPY . /workspace/recommendation_v4 + +ENV PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \ + HSTU_HAMMER_KERNEL=TRITON \ + DLRM_DATA_PATH=/data/mlperf_dlrm_v4 + +CMD ["bash"] diff --git a/recommendation_v4/LICENSE b/recommendation_v4/LICENSE new file mode 100644 index 000000000..d64569567 --- /dev/null +++ b/recommendation_v4/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/recommendation_v4/README.MD b/recommendation_v4/README.MD new file mode 100644 index 000000000..e078bcf0a --- /dev/null +++ b/recommendation_v4/README.MD @@ -0,0 +1,380 @@ +# Recommendation v4 — HSTU sequential recommendation (Yambda-5b) + +MLPerf Training reference benchmark. This is a fork of +[meta-recsys/generative-recommenders](https://github.com/meta-recsys/generative-recommenders) +extended to train an HSTU (Hierarchical Sequential Transduction Units) ranking +model on the [Yambda-5b](https://huggingface.co/datasets/yandex/yambda) +music-recommendation dataset, sized as an MLPerf-style training benchmark inside +the `mlcommons/training` tree. + +## 1. Summary + +This benchmark trains a model that predicts what a person will listen to next. +Given the history of songs a user has played, liked, or skipped, the model +learns to rank which song the user is most likely to genuinely listen to (rather +than skip) next. This is the same kind of "what should we recommend next?" +problem that powers music and video streaming feeds. The model is trained on a +large public dataset of anonymized music-listening events and is scored on how +well it predicts future listens it has never seen. + +## 2. Benchmark overview (technical) + +The model is a **sequential recommender**: instead of treating each interaction +independently (as classic click-through-rate models like DLRM-DCNv2 do), it +consumes a user's chronologically ordered interaction history as a sequence and +applies a Transformer-style attention stack (HSTU) over it. Each training +example is one "anchor" listen event together with that user's prior history +(user interaction history, or UIH) and a set of contextual/cross features. The +supervised target is a binary `listen_plus` label (a real listen: played for at +least 50% of the track) versus a skip. + +Training is **streaming / temporal-order**: the timeline is sliced into +fixed-duration windows and the model trains on window `T` then evaluates on the +strictly-future window `T+1`, so every reported metric is genuine +next-period generalization with no future leakage. The quality metric is +**AUC** on the held-out future window, and the convergence target is +**AUC >= 0.80275** (matching the DLRM-DCNv2-style target). + +The reference runs on 8 GPUs (validated on AMD Instinct MI350X / MI355X and +NVIDIA B200; see [docs/training_recipe.md](docs/training_recipe.md)) and scales +to multi-node via SLURM. + +## 3. Directions — steps to run + +The benchmark follows the standard MLPerf reference script flow: + +```bash +# 0. build/enter the container (canonical frozen environment) +docker build -t recommendation_v4 . +docker run --rm -it --device=/dev/kfd --device=/dev/dri \ + -v /path/to/dlrm_data:/data/mlperf_dlrm_v4 recommendation_v4 + +# 1. download + preprocess the dataset +DLRM_DATA_PATH=/data/mlperf_dlrm_v4 ./download_dataset.sh + +# 2. verify the preprocessed dataset +DLRM_DATA_PATH=/data/mlperf_dlrm_v4 ./verify_dataset.sh + +# 3. run the benchmark to the quality target and report wall-clock time +DLRM_DATA_PATH=/data/mlperf_dlrm_v4 ./run_and_time.sh +``` + +- [`download_dataset.sh`](download_dataset.sh) wraps the preprocessing pipeline + in `generative_recommenders.dlrm_v3.preprocess_public_data` (HuggingFace + download + temporal split + session segmentation + item-popularity counts). +- [`verify_dataset.sh`](verify_dataset.sh) checks the preprocessed files against + [`md5sums_yambda_5b_processed.txt`](md5sums_yambda_5b_processed.txt) (falls + back to a layout check until the canonical checksums are pinned). +- [`run_and_time.sh`](run_and_time.sh) runs the full-reference streaming + train+eval sweep on a single 8-GPU host with `AUC_THRESHOLD=0.80275` and MLPerf + compliance logging, printing the elapsed time of the timed region. + +### 3.1 Multi-node (SLURM) + +For N >= 1 nodes use [`scripts/launch_slurm.sh`](scripts/launch_slurm.sh), which +provisions the container on each node and launches the same trainer. A bare +submit runs a small functional smoke run; set the run-shape knobs for the full +sweep: + +```bash +# smoke (fast functional check) +sbatch --nodes=1 scripts/launch_slurm.sh + +# full reference sweep +START_TS=0 NUM_TRAIN_TS=299 \ +NUM_TRAIN_BATCHES=0 NUM_EVAL_BATCHES=0 \ +EVAL_EVERY_N_WINDOWS=0 EVAL_EVERY_DATA_PCT=0.005 \ +AUC_THRESHOLD=0.80275 \ +sbatch --nodes=2 scripts/launch_slurm.sh +``` + +Multi-node uses real RDMA (RoCEv2); the fabric/NCCL setup is documented in +[docs/multi_node_config.md](docs/multi_node_config.md). Keep all run outputs +(log, checkpoints, mllog, TensorBoard) under a writable scratch path you own — +the dataset mount is read-only. + +## 4. Model + +The model is **HSTU** (Hierarchical Sequential Transduction Units), the +generative-recommender architecture from Meta's ICML'24 paper *Actions Speak +Louder than Words: Trillion-Parameter Sequential Transducers for Generative +Recommendations* ([arXiv:2402.17152](https://arxiv.org/abs/2402.17152)). + +HSTU replaces the feature-interaction stack of a classic DLRM with a stack of +pointwise-attention "transducer" layers operating over the user's interaction +sequence. In this benchmark (the `dlrm_v3` path): + +- **Embeddings**: sparse tables for `item_id`, `artist_id`, `album_id`, `uid`, + and 7 cross-feature hashes (e.g. `user_x_artist`, `item_x_hour`), sharded + across GPUs with TorchRec `DistributedModelParallel`. +- **Sequence model**: an HSTU attention stack (`HSTU_NUM_LAYERS`, default 3) + over the interleaved UIH, computed with a fused jagged-attention Triton kernel + in bf16. +- **Supervision**: a single `listen_plus` binary task. The candidate event's + `action_weight` carries the supervision bit, and BCE loss is masked to + listen_plus candidates. + +See `generative_recommenders/dlrm_v3/configs.py` +(`get_hstu_configs`, `get_embedding_table_config`) for the exact architecture +and table specs, and the upstream README for the original modeling code. + +## 5. Dataset + +[Yambda-5b](https://huggingface.co/datasets/yandex/yambda) is a public +anonymized music-recommendation dataset from Yandex. The `5b` variant is used +for the reference. Statistics after preprocessing: + +| | | +|---|---| +| Total interaction events | **4.76 B** | +| Unique users | **1.00 M** | +| Max events per user | 27,738 | +| Median events per user | 2,695 | +| Mean events per user | 4,763 | +| Train events (300d) | 4.76 B | +| Test events (1d) | 22.4 M | +| Item catalog size | 9.39 M | + +### 5.1 Per-event-type distribution (across the full 4.76 B corpus) + +| Pool | Definition | Count | Share | +|---|---|---|---| +| **listen_plus (lp)** | `is_listen AND played_ratio >= 50%` | 2.92 B | **61.3%** | +| **skip** | `is_listen AND played_ratio < 50%` | 1.71 B | **35.9%** | +| **like** | explicit thumbs-up action | 89 M | **1.9%** | +| other | dislike / unlike / undislike | 47 M | 1.0% | + +The `like` pool is roughly **30x rarer** than `lp` — important context for the +gather strategy in §6. + +### 5.2 Preprocessing & download + +`./download_dataset.sh` (which calls +`python3 -m generative_recommenders.dlrm_v3.preprocess_public_data --dataset +yambda-5b --data-path `) downloads the 5b variant from HuggingFace, then: + +1. **Encodes** the raw `event_type` string into a uint8 lookup (listen=0, + like=1, dislike=2, unlike=3, undislike=4). +2. **Splits** events temporally — 300 train days, 30-min gap, 1 test day — by + Global Temporal Split (GTS). +3. **Segments** per-user event timelines into sessions on a 30-min inactivity + gap. +4. **Computes** per-item popularity for downstream metric weighting. +5. **Writes** the layout `DLRMv3YambdaDataset` expects: + +``` +/ +├── raw/5b/multi_event.parquet 50 GB (downloaded) +├── shared_metadata/ +│ ├── artist_item_mapping.parquet 60 MB +│ ├── album_item_mapping.parquet 76 MB +│ └── embeddings.parquet 18 GB (unused by HSTU training) +└── processed_5b/ + ├── train_sessions.parquet 47 GB ← main training input + ├── test_events.parquet 152 MB + ├── session_index.parquet 600 MB + ├── item_popularity.npy 75 MB + └── split_meta.json anchor + boundary stats +``` + +For smaller variants (`yambda-50m` / `yambda-500m`) substitute the dataset name +(`DATASET=yambda-50m ./download_dataset.sh`). Preprocessing takes ~2 min for 50m +and ~53 min for 5b end-to-end. + +Integrity is verified with `./verify_dataset.sh` against +[`md5sums_yambda_5b_processed.txt`](md5sums_yambda_5b_processed.txt). + +## 6. How data is fed to HSTU + +For every training anchor (a LISTEN event with >= `min_history` prior events), +the dataset builds a `(uih_kjt, candidate_kjt)` pair: + +``` +UIH (User Interaction History): + ┌─ Sequence features (chronologically interleaved across 3 pools) + │ item_id, artist_id, album_id ← per-position + │ action_weight ← per-position (LP_BIT/LIKE_BIT/SKIP_BIT) + │ action_timestamp, dummy_watch_time ← per-position + └─ Contextual features (length 1 each) + uid + 7 cross-feature hashes (user_x_artist, item_x_hour, …) + = 8 contextual entries + +CANDIDATE (the LISTEN event at the anchor): + item_id, artist_id, album_id, item_query_time, + item_action_weight (LP_BIT if listen_plus, else 0), + item_dummy_watchtime +``` + +The candidate's `action_weight` is **the supervision label**: HSTU's +`_get_supervision_labels_and_weights` masks BCE training to +`(supervision_bitmask & task_weight) > 0`, with `task_weight = 1` (LP bit) for +the single `listen_plus` task — so only listen_plus candidates supervise. + +### 6.1 Per-pool gather (the cap = L // 3 strategy) + +The UIH is built by `DLRMv3YambdaDataset._gather_interleaved_history`. For each +anchor it: + +1. Scans the most recent `scan_window` (default 20,000) events of any type + before the anchor, **clipped to user_start**. +2. From those, takes **the last `L // 3` events** from each of the three pools + (lp, like, skip) independently. +3. Concatenates and **re-sorts chronologically** to produce an interleaved + sequence. +4. Tags each event's pool identity into `action_weight` via OR'd bitmask + (LP=1, LIKE=2, SKIP=4). + +With `history_length = 4086` and `max_seq_len = 4096`: per-pool cap = `4086 // +3 = 1362`, and `3 × 1362 + 8 contextual + 1 candidate = 4095 <= 4096` (no +truncation). Because the `like` pool is rare (1.9%) it under-fills (~105 events +per anchor on average); the Triton jagged-attention backend skips unfilled +slots, so the under-fill costs sequence budget but not GPU compute. + +## 7. Optimizer + +Two optimizers, configured in +[`yambda_5b.gin`](generative_recommenders/dlrm_v3/train/gin/yambda_5b.gin): + +| component | optimizer | gin binding | key settings | +|---|---|---|---| +| Dense params (HSTU blocks, MLPs) | **Adam** | `dense_optimizer_factory_and_class.*` | lr `DENSE_LR`, betas (0.95, 0.999), eps 1e-8, weight_decay 0 | +| Sparse embedding tables | **RowWiseAdagrad** (fused FBGEMM TBE) | `sparse_optimizer_factory_and_class.*` | lr `SPARSE_LR`, eps 1e-8, weight_decay 0 | + +Gradient clipping (`GRAD_CLIP_NORM`, default `max_norm=1.0`) is applied to the +dense parameters on the streaming path; the fused sparse optimizer is +unaffected. Training is bf16 mixed precision (`make_model.bf16_training=True`). + +## 8. Hyperparameters + +All tunable hyperparameters live in +[`yambda_5b.gin`](generative_recommenders/dlrm_v3/train/gin/yambda_5b.gin) +(the config-file source of truth) and are **overridable via environment +variables** (the env value takes precedence over the gin default, per MLPerf +CONTRIBUTING rule 4d). The gin macros (`@env_int`, `@env_float`, `@env_str`) +enforce the correct type for each parameter. + +| hyperparameter | env var | gin binding | type | default | tuning rule | +|---|---|---|---|---|---| +| Per-rank batch size | `BATCH_SIZE` | `batch_size` | int | 1024 | positive integer (global batch = `BATCH_SIZE × world_size`) | +| Dense learning rate | `DENSE_LR` | `dense_optimizer_factory_and_class.learning_rate` | float | 1e-7 | positive float | +| Sparse learning rate | `SPARSE_LR` | `sparse_optimizer_factory_and_class.learning_rate` | float | 1e-7 | positive float | +| Grad clip max-norm | `GRAD_CLIP_NORM` | `streaming_train_eval_loop.grad_clip_norm` | float | 1.0 | float >= 0 (0 disables) | +| RNG seed | `SEED` | `seed_everything.seed` | int | 1 | any integer (-1 = random per run) | +| HSTU attention layers | `HSTU_NUM_LAYERS` | `get_hstu_configs.hstu_attn_num_layers` | int | 3 | positive integer | +| UIH history length | `HISTORY_LENGTH` | `get_dataset.history_length` | int | 4086 | positive integer (per-pool cap = L//3) | +| Max sequence length | `MAX_SEQ_LEN` | `get_hstu_configs.max_seq_len` | int | 4096 | positive integer (>= `history_length + 9`) | +| History strategy | `HISTORY_STRATEGY` | `get_dataset.history_strategy` | str | `interleaved` | one of `interleaved` \| `last_n` | +| Min history (anchor floor) | `MIN_HISTORY` | `get_dataset.min_history` | int | 4086 | integer >= 0 | +| Train user split | `TRAIN_SPLIT_PERCENTAGE` | `*.train_split_percentage` | float | 1.0 | float in (0, 1] | +| Streaming shuffle fraction | `STREAMING_SHUFFLE_FRACTION` | `get_dataset.streaming_shuffle_fraction` | float | 0.0 | float in [0, 1] | +| Streaming shuffle seed | `STREAMING_SHUFFLE_SEED` | `get_dataset.streaming_shuffle_seed` | int | 0 | any integer | +| Split salt | `SPLIT_SALT` | `get_dataset.split_salt` | int | 0 | any integer | +| Start window | `START_TS` | `streaming_train_eval_loop.start_ts` | int | 150 | integer >= 0 | +| Number of train windows | `NUM_TRAIN_TS` | `streaming_train_eval_loop.num_train_ts` | int | 149 | positive integer (clamped to available) | +| Sparse A2A fwd precision | `SPARSE_A2A_FWD` | `make_optimizer_and_shard.sparse_a2a_forward_precision` | str | `fp32` | one of `fp32` \| `bf16` \| `fp16` | +| Sparse A2A bwd precision | `SPARSE_A2A_BWD` | `make_optimizer_and_shard.sparse_a2a_backward_precision` | str | `fp32` | one of `fp32` \| `bf16` \| `fp16` | + +Non-tunable / fixed reference values (optimizer betas (0.95, 0.999), eps 1e-8, +weight_decay 0, bf16 training, streaming window = 86400 s) are pinned in the gin +file. Submitters tuning hyperparameters must follow the allowed values above and +the +[MLPerf training rules](https://github.com/mlcommons/training_policies/blob/master/training_rules.adoc#hyperparameters). + +## 9. Quality target & evaluation + +- **Metric**: AUC on the held-out future evaluation window (`window_auc` for the + `listen_plus` task), computed by `MetricsLogger` in + `generative_recommenders/dlrm_v3/utils.py`. +- **Target**: **eval AUC >= 0.80275**. Set via `AUC_THRESHOLD=0.80275` + (`MetricsLogger.auc_threshold`); the run logs `RUN_STOP` with `SUCCESS` and + stops once the target is reached. The gin default of `1.0` is unreachable + (trains all windows with no early stop) and is overridden by the reference + scripts. +- **Evaluation frequency**: the full-reference run uses + `EVAL_EVERY_DATA_PCT=0.005` — evaluate every 0.5% of the training stream + (~200 evenly-data-spaced eval points), independent of node count. The + alternative per-window cadence (`EVAL_EVERY_N_WINDOWS`) is mutually exclusive. +- **Evaluation set**: a fixed held-out future window (`eval_holdout_ts`, default + `start_ts + num_train_ts`); with `TRAIN_SPLIT_PERCENTAGE < 1.0` the held-out + users' anchors over that window form the eval set. The temporal one-window + lead guarantees no future leakage (see §11). + +Evaluation is always one window ahead of training, so reported AUC is genuine +next-period generalization. + +## 10. Reference Convergence Points (RCP) + +*Placeholder — to be generated.* + +RCPs have **not yet been generated** for this benchmark. Per the MLPerf +[CONTRIBUTING guidance](https://github.com/mlcommons/training_policies/blob/master/CONTRIBUTING.md), +RCPs must be generated for at least 3 reasonable batch sizes using at least 2N +seeds (N = number of submission runs), in FP32 or BF16, with the exact precision +recorded in the RCP JSON. The convergence curves (steps/samples to reach +AUC >= 0.80275) will be added under [`rcp/`](rcp/) once the convergence runs are +complete. This section is intentionally left blank for now. + +## 11. Streaming (temporal-order) training + +`scripts/launch_slurm.sh` and `run_and_time.sh` default to +`--mode streaming-train-eval`, which trains Yambda in strict wall-clock order +instead of shuffling the whole corpus. The timeline is sliced into +fixed-duration **windows** (default 1 day, +`get_dataset.streaming_window_seconds = 86400`), and the loop walks them forward: + +``` +window T: train window T+1: eval (then train) window T+2: eval (then train) ... + └─ train window T ─┐ + └─ eval window T+1 ─┐ + └─ train window T+1 ─┐ + └─ eval window T+2 ... +``` + +i.e. for each step it **trains window T, then evaluates window T+1** before +advancing — always predicting the immediate future from the past. + +### 11.1 Temporal guarantee + +The streaming path enforces **no future leakage** at two levels: + +1. **Across windows** — a window is the set of anchors whose target/candidate + timestamp falls in `[t_min + T·W, t_min + (T+1)·W)`. Training only ever sees + windows `<= T`; the evaluation window `T+1` is strictly in the future of every + training anchor it is scored against. +2. **Within an anchor** — history is gathered **causally**: the UIH scan is + `scan_start:flat_pos` (events strictly before the anchor), so no event at or + after the anchor's timestamp can enter its features. + +This is a *temporal* split on the training stream — distinct from the +preprocessing GTS split (§5) that carves off the final test day. Windows are +indexed off the per-anchor target timestamp via a lazily-built, mmap'd +`anchor_ts` cache keyed by `(history_length, min_history)`. + +### 11.2 Streaming knobs + +All configurable via +[`yambda_5b.gin`](generative_recommenders/dlrm_v3/train/gin/yambda_5b.gin) with +env overrides: + +| env | gin | default | meaning | +|---|---|---|---| +| `START_TS` | `streaming_train_eval_loop.start_ts` | 150 | first window (early windows are near-empty warm-up) | +| `NUM_TRAIN_TS` | `streaming_train_eval_loop.num_train_ts` | 149 | number of train windows (clamped to available) | +| `PERSISTENT_LOADER` | `streaming_train_eval_loop.persistent_loader` | 1 | reuse one worker pool across windows | +| `DOUBLE_BUFFER` | `streaming_train_eval_loop.double_buffer` | 1 | prepare the next window in a background thread | +| `EVAL_EVERY_N_WINDOWS` | `streaming_train_eval_loop.eval_every_n_windows` | 1 | eval cadence by window count (0 to use data-pct) | +| `EVAL_EVERY_DATA_PCT` | `streaming_train_eval_loop.eval_every_data_pct` | 0.0 | eval cadence by fraction of train data (full ref: 0.005) | +| `MIN_HISTORY` | `get_dataset.min_history` | 4086 | anchor-eligibility floor (0 = ~all users incl. cold-start) | + +### 11.3 Checkpointing & resume + +The streaming loop is resume-aware: set `CKPT_PATH` to enable DMP checkpoint +save/load (auto-resolves to the highest-numbered subdir), with retention via +`KEEP_LAST_N` and cadences `IN_WINDOW_CKPT_FREQ` / `CKPT_STEP_FREQ` / +`CKPT_TIME_INTERVAL_S`. The MLPerf run state (run-started flag, global sample +count) is persisted across resume so compliance logging is continuous. See +`generative_recommenders/dlrm_v3/checkpoint.py`. + +## 12. License + +Apache 2.0 (inherited from upstream). diff --git a/recommendation_v4/docs/multi_node_config.md b/recommendation_v4/docs/multi_node_config.md new file mode 100644 index 000000000..52fdbbb69 --- /dev/null +++ b/recommendation_v4/docs/multi_node_config.md @@ -0,0 +1,230 @@ +# Multi-Node Training Enablement (yambda-5b, MI350X / Broadcom bnxt_re RoCE) + +How N-node (N×8-GPU) distributed training was brought up for the yambda-5b HSTU +ranker on the `meta64` cv350 cluster, the hard problems solved, and **exactly +which settings are cluster/fabric-specific** so this can be reused or re-tuned +when the underlying network changes. + +Companion to [`perf_opt.md`](./perf_opt.md) and [`training_recipe.md`](./training_recipe.md). +The single entry point is [`scripts/launch_slurm.sh`](../scripts/launch_slurm.sh); +the Python side is `generative_recommenders/dlrm_v3/train/{train_ranker,utils}.py`. + +--- + +## TL;DR + +- Multi-node works over **real RDMA** (RoCEv2 on 8× Broadcom bnxt_re HCAs). + 2-node = `world_size=16`, clean `rc=0`, ~7.7–8.0k `global_sps` (≈1.28× of + 1-node 6.2k; weak scaling, per-rank batch fixed). +- The one non-obvious blocker was a **userspace RDMA provider ABI mismatch** + inside the container, fixed with an `LD_PRELOAD`/`LD_LIBRARY_PATH` **overlay** + of the host's matched `rdma-core` (no container lib surgery). +- Everything is one script with three auto-detected phases + (`orchestrate` → `provision` → `worker`) plus small Python changes for global + ranks. All cluster-specific knobs are env-overridable and tagged + `[CLUSTER-SPECIFIC]` in the script. + +--- + +## Architecture: one script, three phases + +`launch_slurm.sh` self-dispatches by context (`LAUNCH_SLURM_PHASE`, else +auto-detected via `/.dockerenv`): + +| Phase | Runs on | Does | +|---|---|---| +| `orchestrate` | SLURM batch host | Resolve rendezvous (`MASTER_ADDR/PORT`), ensure container on every node (calls `provision`), then `docker exec` the `worker` phase on every node (one srun task per node). | +| `provision` | each compute node (host) | Ensure the `yambda_primus` container is up (baked image if present, else base image + pip), stage the host RDMA overlay on NFS. | +| `worker` | inside the container | Derive topology, set NCCL/RDMA env, apply the RDMA overlay, spawn this node's 8 GPU ranks via `train_ranker`. `NNODES==1` => legacy single-node path unchanged. | + +Why one script: multi-node enablement is then a single committable file. The +worker phase is also what the streaming-e2e supervisor invokes directly +(single-node, already inside the container), so the production path is unchanged. + +``` +sbatch --nodes=N launch_slurm.sh + │ (batch host: orchestrate) + ├─ srun: provision ──> docker container up + RDMA overlay staged (×N nodes) + └─ srun: docker exec launch_slurm.sh (worker) (×N nodes) + │ in container: topology + NCCL/RDMA env + LD overlay + └─ python train_ranker ──> 8 local ranks ──> RCCL rendezvous over RDMA +``` + +--- + +## The hard problems (lessons learned) + +### 1. RDMA provider ABI mismatch — the core blocker + +**Symptom:** multi-node RCCL died at init with +`ibv_create_qp ... Bad address`. + +**Root cause:** the container image (`rocm/primus:v26.3`) ships an **older** +userspace `rdma-core` (v34, `libbnxt_re-rdmav34.so`) than the **host kernel** +bnxt_re driver's uapi (host `rdma-core` v61 / `libbnxt_re-rdmav59.so`). The v34 +provider enumerates the HCAs and creates *shallow* QPs fine, but **faults when +creating a deep send queue** — RCCL uses `max_send_wr=256`. Verified with a +parameterized verbs probe: v34 `create_qp` is OK at depth ≤16 and faults at ≥64; +the host v59 provider works at **every** depth. So it is purely the **userspace +provider**, not the kernel or the fabric (a 2-node RoCEv2 RDMA-write test passes +on the stock stack, and bare-metal RCCL benchmarks run fine with the host libs). + +**Fix (no container surgery):** the `provision` phase stages the host's matched +`rdma-core` on shared NFS (`$OVERLAY`): + +``` +$OVERLAY/lib/libibverbs.so.1 # host libibverbs v61 +$OVERLAY/lib/libibverbs.so -> .so.1 # UNVERSIONED symlink (critical, see below) +$OVERLAY/lib/libnl-3.so.200, libnl-route-3.so.200 +$OVERLAY/lib/libibverbs/.so # incl. libbnxt_re-rdmav59.so +``` + +The `worker` phase makes RCCL load it at runtime: + +```bash +export LD_LIBRARY_PATH="$OVERLAY/lib:$OVERLAY/lib/libibverbs:$LD_LIBRARY_PATH" +export LD_PRELOAD="$OVERLAY/lib/libibverbs.so.1:$LD_PRELOAD" +``` + +We do **not** modify the container's system libs — only this process tree's +`LD_*`. Single-node and other users keep the stock stack. + +### 2. The UNVERSIONED `libibverbs.so` symlink is mandatory + +An earlier overlay attempt set `LD_LIBRARY_PATH` but still failed with +`Bad address`. Reason: at `import torch` the ROCm stack pulls in the +**unversioned** soname `libibverbs.so` (not `libibverbs.so.1`). If the overlay +only has `libibverbs.so.1`, that unversioned lookup misses the overlay, falls +through to the **container's** old lib, which then occupies the `libibverbs.so.1` +slot — so RCCL's later `dlopen("libibverbs.so.1")` binds the v34 stack and +`create_qp(256)` faults again. The overlay **must** expose +`libibverbs.so -> libibverbs.so.1`. With it (verified via `/proc//maps`), +the process maps **only** the host lib. `LD_PRELOAD` is belt-and-braces so the +host lib claims the soname slot first. + +### 3. Two network planes — pin TCP bootstrap, RDMA for data + +The container is `--network=host`, so RCCL sees **all** host interfaces and, left +to auto-detect, picks the wrong one. These nodes expose: +- `benic1p1..benic8p1` — per-GPU point-to-point RoCE links on `192.168.{1..8}.x/31`. + These are **not node-routable** for plain TCP; the very first bring-up **hung** + in `init_process_group` because RCCL tried the TCP bootstrap over a + non-routable `192.168.x` backend addr. +- `fenic0` — the routable front-end (`10.190.x`). + +So we split the planes explicitly: +- `NCCL_SOCKET_IFNAME=fenic0` → TCP bootstrap/rendezvous over the routable NIC. +- `NCCL_IB_HCA=bnxt_re0..7` → RDMA **data** over the 8 RoCE HCAs (the RoCEv2 + fabric *is* reachable rail-to-rail at the RDMA layer even though plain IP is not). + +### 4. Minimal proven bnxt_re NCCL config + +The minimal set proven on these nodes (matches cmcknigh's bare-metal RCCL +benchmarks): `NCCL_IB_GID_INDEX=3` (RoCEv2 IPv4 GID), `NCCL_IB_TC=104` (RoCE +lossless / PFC traffic class). **Do not** add the heavy +`QPS_PER_CONNECTION / ECE / DMABUF` block — that belongs to a different +(ionic AINIC) fabric and is counterproductive on bnxt_re. GPU-Direct RDMA +(`NCCL_NET_GDR_LEVEL`) is left **off**: it needs DMABUF/peermem, unavailable +in-container here, so RCCL stages through host memory (still real RDMA). + +### 5. Rendezvous must be resolved on the host + +The container image has **no SLURM client** (`scontrol` absent). So the +`orchestrate` phase resolves `MASTER_ADDR` (first host of the allocation) and a +deterministic `MASTER_PORT` (`20000 + job_id % 20000`, same on all nodes) **on +the host** and forwards them into the container via `docker exec -e`. + +### 6. Global rank derivation (Python) + +`mp.start_processes` hands out a node-local `local_rank` (0..7). Every downstream +consumer (data sharding, checkpoint I/O, metrics) needs the **global** rank: + +```python +rank = node_rank * gpus_per_node + local_rank # train_ranker._main_func +device = torch.device(f"cuda:{local_rank}") # CUDA device stays node-local +``` + +Also: `make_optimizer_and_shard(local_world_size=gpus_per_node)` so the TorchRec +planner respects the intra-node GPU count, and `MetricsLogger(world_size=...)` +gets the live world size (the gin default of 8 would mis-normalize multi-node). +`NNODES==1` makes `rank == local_rank` — identical to the old single-node path. + +### 7. `$0` is the staged `slurm_script`, not the repo path + +For an sbatch batch script, `$0` = +`/var/spool/slurmd/job/slurm_script` (node-local), so deriving the script / +repo path from `$0` gives a path that **doesn't exist on other nodes** (`bash +$SELF` → "No such file", and the worker's `cd $REPO` → exit 127). The +`orchestrate` phase instead resolves the real shared-NFS path from SLURM: + +```bash +SCRIPT_PATH=$(scontrol show job "$SLURM_JOB_ID" | grep -oP 'Command=\K\S+') +# fallbacks: $SLURM_SUBMIT_DIR/scripts/launch_slurm.sh, then $SELF +REPO=$(cd "$(dirname "$SCRIPT_PATH")/.." && pwd) +``` + +### 8. `srun ... bash -c "…"` host-vs-remote expansion + +Inside the double-quoted srun command string, **plain `$VAR` expands now on the +batch host** (values computed in orchestrate: `$MASTER_ADDR`, `$SCRIPT_PATH`, …) +while **`\$VAR` is deferred to each compute node** (`\$SLURM_NODEID`, +`\$(hostname)`) where the per-node SLURM env lives. Mixing these up sends every +rank the wrong node id. + +### 9. `memlock` ulimit for QP registration + +`docker run --ulimit memlock=-1:-1` is **required** — RDMA QP memory +registration needs unlimited locked memory. A container started with the default +8 MB memlock fails QP creation regardless of the overlay. + +### 10. Provisioning & the image-bake caveat + +Fresh nodes otherwise re-download a **6.1 GB** ROCm torch wheel + pip + build +torchrec-from-git every time. The script supports a pre-baked image +(`docker commit` → NFS tar → `docker load` offline). **Caveat:** the committed +image is **~127 GB** (ROCm base is huge), so the full-image NFS tar is impractical +(loading it can be slower than re-downloading 6 GB). For true download-avoidance +prefer a **local pip wheelhouse** (`pip install --no-index --find-links` from +~8 GB of NFS wheels) or a **local registry** (ships only the ~35 GB delta layer). +The bake hook is left in (`BAKE_IMAGE=1`) but defaults off; provisioning falls +back to base-image + pip. + +### Debunked theory (do not re-introduce) + +An earlier claim that the container's rdma-core was "too old → 0 devices / +Bad address" and needed an **in-place lib copy** was a red herring: the "0 +devices" came from a *broken in-place copy* of the host EL9 libs (mixing v34 +tooling that links `IBVERBS_PRIVATE_34` with host v61 libs breaks symbol-version +lookup). The stock container enumerates all 8 HCAs fine. The real issue is only +the deep-QP create path; the fix is the **LD overlay**, never in-place surgery. + +--- + +## Cluster-specific settings — change these when the fabric/hardware changes + +All are env-overridable and tagged `[CLUSTER-SPECIFIC]` in `launch_slurm.sh` +(`grep '\[CLUSTER-SPECIFIC\]' scripts/launch_slurm.sh`). + +| Setting | Default (meta64) | What it is | How to find the right value | +|---|---|---|---| +| `#SBATCH --partition` | `meta64` | scheduler partition | `sinfo` | +| bind mounts + default paths | `/home/chcai`, `/apps/chcai` | repo + scratch, **must be shared/NFS on all nodes** | `df -h`, cluster docs | +| `IMAGE` | `rocm/primus:v26.3` | base container (GPU arch + ROCm version) | vendor image registry | +| docker `--device` | `/dev/kfd /dev/dri` (AMD) | GPU passthrough | NVIDIA: `--gpus all` / nvidia runtime | +| `--ulimit memlock` | `-1` | locked mem for RDMA QP | keep `-1` for any RDMA fabric | +| `TORCH_IDX` / torch,vision,audio | `rocm7.2`, `2.12.0+rocm7.2` … | ROCm-version'd wheels | `download.pytorch.org/whl/` | +| `FBGEMM_WHL` | gfx950 wheel on NFS | GPU-arch fbgemm | build/stage per arch | +| `NCCL_SOCKET_IFNAME` | `fenic0` | **routable** host NIC for TCP bootstrap | `ip -br addr` (pick the routable one; NOT the per-GPU RDMA NICs) | +| `NCCL_IB_HCA` | `bnxt_re0..7` | RDMA HCA device names | `ibv_devices` (vendor: `mlx5_*`, `ionic_*`, …) | +| `NCCL_IB_GID_INDEX` | `3` | RoCEv2 IPv4 GID index | `show_gids` (v1/v2 & IPv4/IPv6 differ per port) | +| `NCCL_IB_TC` | `104` | RoCE lossless / PFC traffic class | fabric/switch admin | +| `RDMA_OVERLAY` (+ provider .so) | `/apps/chcai/rdma_host_el9_new` | host rdma-core overlay | only needed if container rdma-core < host kernel uapi; else set `RDMA_OVERLAY=` to disable. Stage the host's matching `/usr/lib64/libibverbs/.so` | + +**Different NIC vendor (e.g. Mellanox `mlx5`)** typically means: change +`NCCL_IB_HCA` names, re-check `NCCL_IB_GID_INDEX`/`NCCL_IB_TC`, and the RDMA +overlay is often **unnecessary** (Mellanox userspace in the image usually matches +the host) — set `RDMA_OVERLAY=` to skip it. + +**Emergency fallback:** `NCCL_NET_TRANSPORT=socket` disables IB and runs +allreduce over TCP (`fenic0`). Functional but ~100–200× slower; use only to +isolate a fabric problem. diff --git a/recommendation_v4/docs/perf_opt.md b/recommendation_v4/docs/perf_opt.md new file mode 100644 index 000000000..7799848ad --- /dev/null +++ b/recommendation_v4/docs/perf_opt.md @@ -0,0 +1,73 @@ +# Performance Optimizations — MI350X HSTU / OneTrans (yambda-5b, bs=1024, TRITON) + +Performance work for the 8× MI350X HSTU ranker on `yambda-5b` at `batch_size=1024` +with the **TRITON** HSTU kernel and bf16 training. Companion to +[`training_recipe.md`](./training_recipe.md) (environment + reproduction). + +Throughput numbers are global samples/sec across 8 GPUs (`global_sps`), measured +at steady state (instantaneous, computed from consecutive logged steps). + +--- + +## LN-dropout: multi-row, separated-RNG path on MI350 + +### What + +`_ln_mul_dropout_*` has two kernel variants: + +- **legacy** — single program per row, RNG fused inline (`_ln_mul_dropout_fwd`). +- **separated-RNG** — multiple rows per program, dropout mask precomputed once + and reused by the backward (`_ln_mul_dropout_fwd_rng` / + `_ln_mul_dropout_bwd_dx_du_rng`). + +The separated path was previously gated to Blackwell only (`is_sm100_plus()`). +MI350X (`gfx950`) benefits from the same structure, so the gate now also enables +it on MI350. + +### Where + +| file | change | +|---|---| +| `ops/utils.py` | `is_amd_mi350()` (gfx950 detect) + `use_separated_rng_ln_mul_dropout()` gate | +| `ops/triton/triton_hstu_linear.py` | dispatch LN-dropout fwd to the separated-RNG path when the gate is true | + +```python +# ops/utils.py +def use_separated_rng_ln_mul_dropout() -> bool: + return is_sm100_plus() or is_amd_mi350() +``` + +### Perf + +**+5.6% end-to-end → 14,222 global sps** (separated-RNG vs legacy fused, identical +config, full boost clocks — see the caveat below). + +--- + +## Caveat — GPU clock lock can mask all perf changes + +A node-level GPU clock lock will silently invalidate any benchmark on this +machine, so check it before trusting numbers. + +During this work all 8 GPUs were stuck in **`perf_determinism`** performance +level at **sclk 1093 MHz** (DPM level 1) while the real max is **2200 MHz** +(level 2) — despite 100% utilization, ~370 W of power headroom (629 / 1000 W), +and low temps (~50 °C). This was **not** thermal/power throttling; it was +leftover node state from a prior job. + +Effect: a **uniform ~1.87× slowdown of every Triton compute kernel** +(`2200 / 1093 ≈ 2.0×`), including kernels unrelated to any code change. It made +the LN-dropout fix above look like a regression until the clock state was found. + +### Detect + fix + +```bash +rocm-smi --showperflevel # expect "auto", not perf_determinism/manual/low +rocm-smi -d 0 --showclocks # expect sclk ~2000+ MHz under load +rocm-smi --setperflevel auto # restore boost +``` + +`scripts/launch_slurm.sh` (worker phase) now logs the perf level + a live `sclk` sample on +every launch, auto-restores `auto` if it finds a `perf_determinism`/`manual`/`low` +lock, and warns (to reset from the host) if it lacks permission inside the +container. **Always sanity-check `sclk ≈ 2000+ MHz` before trusting a benchmark.** diff --git a/recommendation_v4/docs/training_recipe.md b/recommendation_v4/docs/training_recipe.md new file mode 100644 index 000000000..cf31a9fff --- /dev/null +++ b/recommendation_v4/docs/training_recipe.md @@ -0,0 +1,203 @@ +# Training Recipe + +Reproducible environment + configuration for training HSTU / DLRM-v3 on the +`yambda-5b` dataset. + +--- + +## MI350X + +Single-node, 8× AMD **Instinct MI350X** (`gfx950`, ~288 GiB HBM3e each), HSTU +ranker on `yambda-5b` with the **TRITON** HSTU kernel and **bf16** +mixed-precision training. + +### Hardware / host + +| item | value | +|---|---| +| GPUs | 8× AMD Instinct MI350X (`gfx950`, ROCm 7.2.1) | +| Host CPU | AMD EPYC 9655 96-Core (192 cores × 2 threads) | + +### Container image + +``` +rocm/primus:v26.3 +``` + +### Dependency versions + +Aligned with the B200 path: same torch major.minor, same torchrec commit, +same fbgemm SHA. The image's native torch / torchvision / torchaudio / +torchrec / fbgemm_gpu are all replaced; only the image's triton stays. + +| package | version | install | +|---|---|---| +| **torch** | `2.12.0+rocm7.2` | `pip install --upgrade --no-deps --index-url https://download.pytorch.org/whl/rocm7.2 torch==2.12.0+rocm7.2` | +| **torchvision** | `0.27.0+rocm7.2` | `pip install --upgrade --no-deps --index-url https://download.pytorch.org/whl/rocm7.2 torchvision` — ABI must match torch 2.12 | +| **torchaudio** | `2.11.0+rocm7.2` | `pip install --upgrade --no-deps --index-url https://download.pytorch.org/whl/rocm7.2 torchaudio` — ABI must match torch 2.12 | +| **triton** | `3.6.0` | image native, unchanged | +| **fbgemm_gpu** | `fbgemm_gpu_nightly_rocm-2026.6.2` (built from FBGEMM commit `10b77573`, same SHA as the B200 path) for `gfx950` | rebuild from source against the replaced torch. Build command: `python setup.py -j 32 bdist_wheel --build-target=default --build-variant=rocm -DHIP_ROOT_DIR=/opt/rocm -DAMDGPU_TARGETS=gfx950` | +| **torchrec** | `1.7.0a0+bf55480` (git tag `v2026.06.01.00`) | `pip install --force-reinstall --no-deps "git+https://github.com/pytorch/torchrec.git@v2026.06.01.00"` | +| **polars-u64-idx** | `1.33.1` | 64-bit row index — `yambda-5b` has > 4.29 B rows. Installed from a pre-staged local tarball by `scripts/launch_smoke_8gpu.sh` | + +### Training configuration + +From `generative_recommenders/dlrm_v3/train/gin/yambda_5b.gin`: + +| parameter | value | gin binding | +|---|---|---| +| num_workers (dataloader) | 4 | `make_train_test_dataloaders.num_workers` | +| prefetch_factor | 8 | `make_train_test_dataloaders.prefetch_factor` | +| num_blocks | 1 | `make_train_test_dataloaders.num_blocks` | +| train_split_percentage | 0.90 | `make_train_test_dataloaders.train_split_percentage` | +| history_length (per-sample UIH budget) | 2039 | `get_dataset.history_length` | +| max_seq_len (attention budget) | 2048 | `get_hstu_configs.max_seq_len` | +| bf16 training | True | `make_model.bf16_training` | +| HBM cap (per GPU) | 260 GiB | `make_optimizer_and_shard.hbm_cap_gb` (env `HBM_CAP_GB`) | +| **triton autotune pinning** | **False (pinned)** | `apply_env_bootstrap.TRITON_FULL_AUTOTUNE` | +| dense optimizer | Adam, lr 1e-3, betas (0.95, 0.999), eps 1e-8 | `dense_optimizer_factory_and_class.*` | +| sparse optimizer | RowWiseAdagrad, lr 1e-3, betas (0.95, 0.999), eps 1e-8 | `sparse_optimizer_factory_and_class.*` | +| world_size | 8 | `MetricsLogger.world_size` | + +Effective global batch = `batch_size × world_size = 32 × 8 = 256` samples/step. + +### Environment variables + +| var | value | purpose | +|---|---|---| +| `HSTU_HAMMER_KERNEL` | `TRITON` | fast HSTU kernel (vs `PYTORCH` fallback) | +| `DLRM_DATA_PATH` | dataset root | overrides gin default `/apps/chcai/dlrm_data` | +| `HBM_CAP_GB` | (optional) | embedding planner HBM budget per GPU | +| `RUN_NAME` | run id | results dir → `results//` | +| `PYTORCH_CUDA_ALLOC_CONF` | `expandable_segments:True` | allocator headroom | +| `HIP_VISIBLE_DEVICES` / `CUDA_VISIBLE_DEVICES` | `0,1,2,3,4,5,6,7` | rank visibility | + +`TRITON_FULL_AUTOTUNE` is set automatically by the gin-driven bootstrap +(`generative_recommenders.dlrm_v3.train._env_bootstrap.apply_env_bootstrap`), +which runs in `train_ranker._main_func` BEFORE the triton kernel modules +import — so the gin file is the source of truth. + +### Measured performance + +| variant | steady-state ms/step | global sps | epoch ETA (3.23B anchors) | +|---|---|---|---| +| nightly + fp32 + PYTORCH attn (baseline) | ~190 | ~1340 | ~28 d | +| nightly + bf16 + TRITON attn | ~93 | ~2787 | ~13.4 d | +| primus + bf16 + TRITON attn | ~67.5 | ~3793 | ~9.9 d | +| primus + fbgemm HEAD + bf16 + TRITON, autotune drift | ~53 fast / ~70 slow | 3700–4860 | 7.7–10.2 d | +| **primus + fbgemm HEAD + bf16 + TRITON + pinning (default)** | **~52** | **~4970** | **~7.6 d** | + +The "pinning" line is the deterministic per-cold-start equilibrium — +three layer-norm / jagged triton kernels have two stable autotune winners +and the pin forces the fast one every run. + +### Known pitfalls + +- The image ships `fbgemm_gpu==2026.5.14`. The wheel built from FBGEMM HEAD + (`2026.6.1`) is required for the 70 → 52 ms step. Build inside the + container so the wheel links against the image's native torch. +- Stock `polars` silently overflows on `yambda-5b` (> 4.29 B rows); always + use `polars-u64-idx`. +- When changing shape (batch size, history length), GPU, or triton/torch + version, flip `apply_env_bootstrap.TRITON_FULL_AUTOTUNE = True` and run + with `TRITON_PRINT_AUTOTUNING=1` to re-capture winners, then update the + pinned configs at the `pinned_or_full(...)` call sites in + `generative_recommenders/ops/triton/`. +- Do not run with bf16 on the `PYTORCH` HSTU attention backend at our + sequence length — `pt_hstu_attention`'s QK einsum backward overflows in + bf16 at N > 1k and produces NaN at step 1. bf16 is only safe with TRITON. + +--- + +## B200 + +Single-node, 8× NVIDIA **B200** (Blackwell, `sm_100`, ~183 GiB HBM each), HSTU +ranker on `yambda-5b` with the **TRITON** HSTU kernel and **bf16** mixed-precision +training. + +### Hardware / host + +| item | value | +|---|---| +| GPUs | 8× NVIDIA B200 (`sm_100`, compute capability 10.0) | +| Host driver | 580.159.03 (reports CUDA 13.0) | +| Forward-compat userspace driver | `libcuda.so.595.58.03` (CUDA 13.2.1; engaged automatically by the NGC image) | + +### Container image + +``` +nvcr.io/nvidia/pytorch:26.04-py3 +``` + +Digest: `sha256:192d749b4d773610ec9e01c0443a9df545d196c412b7b8fd33bfa3da362a49e7` + +The image's native PyTorch is kept as-is and must not be reinstalled (so CUPTI +stays matched to the driver and `sm_100` support is preserved). + +`nvcr.io/nvidia/pytorch:26.01-py3` (torch `2.10.0a0` / CUDA 13.1, digest +`sha256:38ed2ecb2c16d10677006d73fb0a150855d6ec81db8fc66e800b5ae92741007e`) is +also validated and performance-equivalent — rebuild `fbgemm_gpu` against +whichever image's torch you run. + +### Dependency versions + +| package | version | notes | +|---|---|---| +| **torch** | `2.12.0a0+0291f960b6.nv26.04.48445190` (CUDA 13.2) | native to the image; not reinstalled | +| **triton** | `3.6.0` | native to the image; provides `triton.language.make_tensor_descriptor` (required by the TRITON HSTU path) | +| **fbgemm_gpu** | FBGEMM commit `10b775730212923f65f7b78f79b6a01d80cf3c29` (2026-06-01 `main`, CUDA 13.2, `sm_100`) | built from source against the native torch; public wheels are ABI-incompatible with the NGC torch. The built wheel is named `fbgemm_gpu_nightly-2026.6.1` — that version is the build date, not the source date, so always identify the build by the commit above. Build command: `TORCH_CUDA_ARCH_LIST=10.0 python setup.py bdist_wheel --build-target default --build-variant cuda --package_channel nightly --nvml_lib_path /usr/lib/x86_64-linux-gnu/libnvidia-ml.so` (~55 min — the `sm_100` TBE-forward kernels dominate via `ptxas`) | +| **torchrec** | `1.7.0.dev20260601+cu130` (nightly, tested) | installed `--no-deps` from `https://download.pytorch.org/whl/nightly/cu130`. Perf-neutral vs stable `1.4.0`; use `1.4.0` (latest stable) if you prefer a non-pre-release | +| **polars-u64-idx** | `1.33.1` | 64-bit row index — `yambda-5b` has > 4.29 B rows (overflows stock polars' 32-bit index) | +| CUPTI (for `torch.profiler`) | 13.2 (native) | matches the driver; the `+cu128` stack's CUPTI 12.8 fails on B200 (`CUPTI_ERROR_INVALID_DEVICE`) | + +Additional Python deps: +`xxhash`, `gin-config`, `absl-py`, `pandas`, `tensorboard`, `pyarrow`, `pyyaml`, +`tqdm`, `psutil`, `torchmetrics==1.0.3`, `tensordict`, `pyre-extensions`, +`iopath`, `typing-inspect`. + +### Training configuration + +From `generative_recommenders/dlrm_v3/train/gin/yambda_5b.gin`: + +| parameter | value | gin binding | +|---|---|---| +| batch_size (train) | 32 | `make_train_test_dataloaders.batch_size` | +| eval_batch_size | 32 | `make_train_test_dataloaders.eval_batch_size` | +| num_workers (dataloader) | 4 | `make_train_test_dataloaders.num_workers` | +| prefetch_factor | 8 | `make_train_test_dataloaders.prefetch_factor` | +| num_blocks | 1 | `make_train_test_dataloaders.num_blocks` | +| train_split_percentage | 0.90 | `make_train_test_dataloaders.train_split_percentage` | +| history_length (per-sample UIH budget) | 2039 | `get_dataset.history_length` | +| max_seq_len (attention budget) | 2048 | `get_hstu_configs.max_seq_len` | +| bf16 training | True | `make_model.bf16_training` | +| HBM cap (per GPU) | 150 GiB | `make_optimizer_and_shard.hbm_cap_gb` (env `HBM_CAP_GB`) | +| **triton autotune pinning** | **True (full autotune)** | `apply_env_bootstrap.TRITON_FULL_AUTOTUNE` — the pinned configs are MI350X-specific, so B200 runs full autotune to find its own `sm_100` winners | +| dense optimizer | Adam, lr 1e-3, betas (0.95, 0.999), eps 1e-8 | `dense_optimizer_factory_and_class.*` | +| sparse optimizer | RowWiseAdagrad, lr 1e-3, betas (0.95, 0.999), eps 1e-8 | `sparse_optimizer_factory_and_class.*` | +| world_size | 8 | `MetricsLogger.world_size` | + +Effective global batch = `batch_size × world_size = 32 × 8 = 256` samples/step. + +### Environment variables + +| var | value | purpose | +|---|---|---| +| `HSTU_HAMMER_KERNEL` | `TRITON` | fast HSTU kernel (vs `PYTORCH` fallback) | +| `TORCH_CUDA_ARCH_LIST` | `10.0` | target `sm_100` for JIT / Triton compilation | +| `DLRM_DATA_PATH` | dataset root | overrides gin default `/apps/chcai/dlrm_data` | +| `HBM_CAP_GB` | `150` | embedding planner HBM budget per GPU | +| `RUN_NAME` | run id | results dir → `results//` | +| `PYTORCH_CUDA_ALLOC_CONF` | `expandable_segments:True` | allocator headroom | +| `TRITON_CACHE_DIR` | cache path | persist compiled Triton kernels across runs | +| `WORLD_SIZE` / `LOCAL_WORLD_SIZE` | `8` | mp.spawn rank count | + +### Known pitfalls + +- Never reinstall torch in this image — a cu12x wheel breaks CUPTI and may drop + `sm_100`. +- The `+cu128` stack (`torch==2.7.1+cu128` + `fbgemm-gpu==1.2.0+cu128` + + `torchrec==1.2.0+cu128`) runs on B200 but cannot profile GPU activity (CUPTI + 12.8 vs the 13.2 driver). +- Stock `polars` silently overflows on `yambda-5b` (> 4.29 B rows); always use + `polars-u64-idx`. +- `EmbeddingBoundsCheck ... Setting idx to zero` warnings are benign data clamps. diff --git a/recommendation_v4/download_dataset.sh b/recommendation_v4/download_dataset.sh new file mode 100755 index 000000000..d02f382dc --- /dev/null +++ b/recommendation_v4/download_dataset.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash +# MLPerf Training reference script: download + preprocess the dataset. +# +# Downloads the Yambda dataset from HuggingFace (yandex/yambda) and runs the +# preprocessing pipeline (event-type encoding, temporal GTS split, session +# segmentation, item-popularity counts) into the on-disk layout that +# DLRMv3YambdaDataset consumes. This is a thin wrapper over +# generative_recommenders.dlrm_v3.preprocess_public_data +# so the full reference data pipeline lives in one place. +# +# Usage: +# DLRM_DATA_PATH=/path/to/dlrm_data ./download_dataset.sh +# DATASET=yambda-50m DLRM_DATA_PATH=/path/to/dlrm_data ./download_dataset.sh +# +# Env: +# DATASET dataset variant (default: yambda-5b). One of +# kuairand-1k | kuairand-27k | yambda-50m | yambda-500m | yambda-5b +# DLRM_DATA_PATH destination data root (required). +set -euo pipefail + +DATASET="${DATASET:-yambda-5b}" +: "${DLRM_DATA_PATH:?Set DLRM_DATA_PATH to the destination data root}" + +REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "${REPO_ROOT}" + +echo "[download_dataset] dataset=${DATASET} data-path=${DLRM_DATA_PATH}" +mkdir -p "${DLRM_DATA_PATH}" + +python3 -m generative_recommenders.dlrm_v3.preprocess_public_data \ + --dataset "${DATASET}" \ + --data-path "${DLRM_DATA_PATH}" + +echo "[download_dataset] done. Preprocessed layout under ${DLRM_DATA_PATH}:" +echo " raw/5b/multi_event.parquet" +echo " shared_metadata/{artist,album}_item_mapping.parquet, embeddings.parquet" +echo " processed_5b/{train_sessions,test_events,session_index}.parquet" +echo " processed_5b/item_popularity.npy, processed_5b/split_meta.json" +echo "[download_dataset] verify integrity with: ./verify_dataset.sh" diff --git a/recommendation_v4/generative_recommenders/README.md b/recommendation_v4/generative_recommenders/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/recommendation_v4/generative_recommenders/common.py b/recommendation_v4/generative_recommenders/common.py new file mode 100644 index 000000000..2ff8edf80 --- /dev/null +++ b/recommendation_v4/generative_recommenders/common.py @@ -0,0 +1,513 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import abc +import copy +import os +from enum import Enum, unique +from typing import Any, Callable, List, Optional, Tuple + +import torch + +# @manual=//triton:triton +import triton +from generative_recommenders.ops.utils import is_sm100_plus, is_sm90_plus +from torch.fx._symbolic_trace import is_fx_tracing +from torch.utils._python_dispatch import _get_current_dispatch_mode_stack + +# @manual=//triton:triton +from triton.runtime.autotuner import Autotuner + +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") +except OSError: + pass + +try: + # @manual=//triton:triton + import triton.language.extra.tlx # type: ignore + + HAS_TLX = True +except ImportError: + HAS_TLX = False + +try: + from generative_recommenders.fb.triton_cc.utils import triton_cc + from hammer.ops.triton.utils import triton_autotune + from hammer.utils import is_dev_mode, set_dev_mode, set_verbose_level +except ImportError: + # pyre-ignore + def triton_cc(annotations): + # pyre-ignore + def decorator(fn): + return fn + + return decorator + + # pyre-ignore + def triton_autotune( + configs: List[triton.Config], + key: List[str], + # pyre-ignore + prune_configs_by=None, + # pyre-ignore + reset_to_zero=None, + # pyre-ignore + restore_value=None, + warmup: int = 25, + rep: int = 100, + ): + # pyre-ignore + def decorator(fn): + return Autotuner( + fn, + fn.arg_names, + configs, + key, + reset_to_zero, + restore_value, + pre_hook=None, + post_hook=None, + prune_configs_by=prune_configs_by, + warmup=warmup, + rep=rep, + ) + + return decorator + + DEV_MODE: bool = False + VERBOSE_LEVEL: int = 0 + + def set_dev_mode(val: bool) -> None: + global DEV_MODE + DEV_MODE = val + + def is_dev_mode() -> bool: + global DEV_MODE # noqa: F824 + return DEV_MODE + + def set_verbose_level(level: int) -> None: + global VERBOSE_LEVEL + VERBOSE_LEVEL = level + + def get_verbose_level() -> int: + global VERBOSE_LEVEL # noqa: F824 + return VERBOSE_LEVEL + + +@unique +class HammerKernel(Enum): + TRITON = "TRITON" + TLX = "TLX" + PYTORCH = "PYTORCH" + CUDA = "CUDA" + TRITON_CC = "TRITON_CC" + TRITON_INFERENCE = "TRITON_INFERENCE" + CUTEDSL = "CUTEDSL" + + +class HammerModule(torch.nn.Module, abc.ABC): + _is_inference: bool = False + _use_triton_cc: bool = True + _training_dtype: torch.dtype = torch.float32 + _hammer_kernel: Optional[HammerKernel] = None + + def __init__( + self, + is_inference: bool, + training_dytpe: torch.dtype = torch.float32, + use_triton_cc: bool = _use_triton_cc, + hammer_kernel: Optional[HammerKernel] = None, + ) -> None: + super().__init__() + self._is_inference = is_inference + self._training_dtype = training_dytpe + self._hammer_kernel = hammer_kernel + self._use_triton_cc = use_triton_cc + + def hammer_kernel(self) -> HammerKernel: + kernel = self._hammer_kernel + if kernel is not None: + return kernel + if self._is_inference and self._use_triton_cc: + return HammerKernel.TRITON_CC + else: + return HammerKernel.TRITON + + # pyre-ignore[2] + def recursive_setattr(self, name: str, value: Any) -> None: + for _, module in self.named_modules(): + if hasattr(module, name): + setattr(module, name, value) + + def set_use_triton_cc(self, use_triton_cc: bool) -> None: + self._use_triton_cc = use_triton_cc + self.recursive_setattr("_use_triton_cc", use_triton_cc) + + def set_is_inference(self, is_inference: bool) -> None: + self._is_inference = is_inference + self.recursive_setattr("_is_inference", is_inference) + + def set_training_dtype(self, training_dtype: torch.dtype) -> None: + self._training_dtype = training_dtype + self.recursive_setattr("_training_dtype", training_dtype) + + def set_hammer_kernel(self, hammer_kernel: HammerKernel) -> None: + self._hammer_kernel = hammer_kernel + self.recursive_setattr("_hammer_kernel", hammer_kernel) + + @property + def is_inference(self) -> bool: + return self._is_inference + + @property + def is_eval(self) -> bool: + return (not self._is_inference) and (not self.training) + + @property + def is_train(self) -> bool: + return (not self._is_inference) and self.training + + +def generate_sparse_seq_len( + size: int, + max_seq_len: int, + sparsity: float, + device: torch.device, +) -> torch.Tensor: + if sparsity == 0.0: + return torch.zeros(size=(size,), device=device, dtype=torch.int) + elif sparsity == 1.0: + return torch.ones(size=(size,), device=device, dtype=torch.int) * max_seq_len + elif sparsity >= 0.5: + min_seq_len: int = int((2 * sparsity - 1.0) * max_seq_len) + return torch.randint( + low=min_seq_len, + high=max_seq_len, + size=(size,), + device=device, + dtype=torch.int, + ) + else: + min_seq_len: int = 0 + max_seq_len: int = int(2 * sparsity * max_seq_len) + return torch.randint( + low=min_seq_len, + high=max_seq_len, + size=(size,), + device=device, + dtype=torch.int, + ) + + +def apply_sampling( + lengths: torch.Tensor, + alpha: float, + max_seq_len: int, +) -> torch.Tensor: + threshold = int(max_seq_len ** (alpha / 2)) + no_sample_prob = (max_seq_len**alpha) / torch.pow(lengths, 2) + users_to_sample = torch.logical_and( + lengths > threshold, + torch.rand_like(no_sample_prob) < 1 - no_sample_prob, + ) + lengths = torch.where(users_to_sample, threshold, lengths) + return lengths + + +nv_gpu_unavailable: Tuple[bool, str] = ( + not torch.cuda.is_available() or torch.cuda.device_count() == 0, + "CUDA is not available or no GPUs detected", +) +nv_gpu_available: bool = not nv_gpu_unavailable[0] + + +amd_gpu_unavailable: Tuple[bool, str] = ( + not torch.version.hip, + "AMD HIP not available or no GPUs detected", +) +amd_gpu_available: bool = not amd_gpu_unavailable[0] + +gpu_unavailable: Tuple[bool, str] = ( + not nv_gpu_available and not amd_gpu_available, + "CUDA/HIP is not available or no GPUs detected", +) + +gpu_available: bool = not gpu_unavailable[0] + +blackwell_tlx_unavailable: Tuple[bool, str] = ( + not is_sm100_plus() or not HAS_TLX, + "Skip TLX and blackwell only tests", +) + +tma_unavailable: Tuple[bool, str] = ( + not is_sm90_plus(), # noqa + "Skip TMA only tests", +) + + +def switch_to_contiguous_if_needed(x: torch.Tensor) -> torch.Tensor: + if torch.jit.is_scripting(): + if x.stride(-1) == 1: + return x + return x.contiguous() + if torch.compiler.is_compiling(): + # Tell Dynamo this data-dependent value is in the range (0, 10**9) + torch._check(x.size(0) > 0) + torch._check(x.size(0) < 10**9) + # FX cannot trace Python control flow over symbolic stride checks + # (`x.stride(-1) == 1`). For AOT-T lowering, conservatively emit the + # contiguous op instead of branching on a symbolic value. + if is_fx_tracing(): + return x.contiguous() + if x.stride(-1) == 1: + return x + return x.contiguous() + + +def cdiv(x: int, y: int) -> int: + return (x + y - 1) // y + + +def backend_allow_tf32() -> bool: + return True + + +BACKEND_ALLOW_TF32: bool = backend_allow_tf32() + + +def next_power_of_2(n: int) -> int: + """Return the smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n + + +def _prev_power_of_2_bitwise(x: int) -> int: + """Return the largest power of 2 less than or equal to x.""" + x |= x >> 1 + x |= x >> 2 + x |= x >> 4 + x |= x >> 8 + x |= x >> 16 + x |= x >> 32 + return (x >> 1) + 1 + + +@torch.fx.wrap +def _prev_power_of_2_legacy(x: int) -> int: + if torch.compiler.is_compiling(): + # Re-write to make Dynamo happy + x_tensor = torch.scalar_tensor(x, dtype=torch.int64) # type: ignore[arg-type] + x_tensor_orig = x_tensor.clone() + out_val = next_power_of_2(int(x_tensor.item())) # type: ignore[arg-type] + out = torch.scalar_tensor(out_val, dtype=torch.int64) + return int(torch.where(torch.lt(x_tensor_orig, out), out // 2, out).item()) # type: ignore[return-value] + else: + out = next_power_of_2(x) + return out // 2 if out > x else out + + +prev_power_of_2: Callable[[int], int] = ( + _prev_power_of_2_legacy + if os.environ.get("PREV_POWER_OF_2_IMPL", "legacy") == "legacy" + else _prev_power_of_2_bitwise +) + + +STATIC_MAX_SEQ_LENS: List[int] = [] +USE_RUNTIME_MAX_SEQ_LEN: bool = False + + +def set_static_max_seq_lens(max_seq_lens: List[int]) -> None: + global STATIC_MAX_SEQ_LENS + STATIC_MAX_SEQ_LENS = copy.deepcopy(max_seq_lens) + STATIC_MAX_SEQ_LENS.sort() + + +def set_use_runtime_max_seq_len(use_runtime_max_seq_len: bool) -> None: + global USE_RUNTIME_MAX_SEQ_LEN + USE_RUNTIME_MAX_SEQ_LEN = use_runtime_max_seq_len + + +def autotune_max_seq_len(runtime_max_seq_len: int) -> int: + global USE_RUNTIME_MAX_SEQ_LEN # noqa: F824 + + if USE_RUNTIME_MAX_SEQ_LEN: + return prev_power_of_2(runtime_max_seq_len) + else: + if STATIC_MAX_SEQ_LENS == []: + return 1 + for max_len in STATIC_MAX_SEQ_LENS: + if max_len >= runtime_max_seq_len: + return max_len + return STATIC_MAX_SEQ_LENS[-1] + + +def fine_grained_autotune_max_seq_len(runtime_max_seq_len: int) -> int: + global USE_RUNTIME_MAX_SEQ_LEN # noqa: F824 + + if USE_RUNTIME_MAX_SEQ_LEN: + return _fine_grained_bucket_size(runtime_max_seq_len) + else: + if STATIC_MAX_SEQ_LENS == []: + return 1 + for max_len in STATIC_MAX_SEQ_LENS: + if max_len >= runtime_max_seq_len: + return max_len + return STATIC_MAX_SEQ_LENS[-1] + + +def _generate_fine_grained_buckets() -> List[int]: + buckets = [ + 1024, + 2048, + 4096, + 8192, + 12288, + 16384, + 24576, + 32768, + 40960, + 49152, + 65536, + 81920, + 98304, + ] + return buckets + + +@torch.fx.wrap +def _fine_grained_bucket_size(x: int) -> int: + if torch.compiler.is_compiling(): + x_tensor = torch.scalar_tensor(x, dtype=torch.int64) + buckets = torch.tensor(_generate_fine_grained_buckets(), dtype=torch.int64) + + mask = buckets >= x_tensor + valid_buckets = torch.where( + mask, buckets, torch.tensor(2**31 - 1, dtype=torch.int64) + ) + + result = torch.where(mask.any(), valid_buckets.min(), buckets[-1]) + + return int(result.item()) + else: + buckets = _generate_fine_grained_buckets() + + for bucket in buckets: + if x <= bucket: + return bucket + + return buckets[-1] + + +@torch.fx.wrap +def fx_unwrap_optional_tensor(optional: Optional[torch.Tensor]) -> torch.Tensor: + assert optional is not None, "Expected optional to be non-None Tensor" + return optional + + +@torch.fx.wrap +def fx_arange(len: int, device: torch.device) -> torch.Tensor: + return torch.arange(len, device=device) + + +@torch.fx.wrap +def fx_infer_max_len( + lengths: torch.Tensor, +) -> int: + # Do not call ".item()" to avoid unbacked symint problems for lowering + max_len = int(lengths.max()) + if not torch.jit.is_scripting() and torch.compiler.is_compiling(): + # Tell Dynamo this data-dependent value is in the range [0, 10**9) + torch._check_is_size(max_len) + torch._check(max_len < 10**9) + torch._check(max_len > 0) + return max_len + + +@torch.fx.wrap +def fx_mark_length_features(tensor: torch.Tensor) -> torch.Tensor: + return tensor + + +@torch.fx.wrap +def fx_torch_ones( + shape: List[int], + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + return torch.ones(shape, device=device, dtype=dtype) + + +@torch.fx.wrap +def fx_torch_zeros(shape: List[int], device: torch.device) -> torch.Tensor: + return torch.zeros(shape, device=device) + + +def _is_in_dispatch_modes(mode_names: List[str]) -> bool: + modes = _get_current_dispatch_mode_stack() + return any(mode.__class__.__name__ in mode_names for mode in modes) + + +def should_trigger_eager_impl() -> bool: + if torch.jit.is_scripting(): + return True + if torch.compiler.is_compiling(): + return False + return _is_in_dispatch_modes(["SplitDispatchMode", "FakeTensorMode"]) + + +@torch.fx.wrap +def jagged_to_padded_dense( + values: torch.Tensor, + offsets: List[torch.Tensor], + max_lengths: List[int], + padding_value: float, +) -> torch.Tensor: + return torch.ops.fbgemm.jagged_to_padded_dense( + values=values, + offsets=offsets, + max_lengths=max_lengths, + padding_value=padding_value, + ) + + +@torch.fx.wrap +def dense_to_jagged( + dense: torch.Tensor, + x_offsets: List[torch.Tensor], +) -> torch.Tensor: + return torch.ops.fbgemm.dense_to_jagged( + dense=dense, + x_offsets=x_offsets, + )[0] + + +def init_mlp_weights_optional_bias(m: torch.nn.Module) -> None: + if isinstance(m, torch.nn.Linear): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + m.bias.data.fill_(0.0) diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/checkpoint.py b/recommendation_v4/generative_recommenders/dlrm_v3/checkpoint.py new file mode 100644 index 000000000..46cc10e2e --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/checkpoint.py @@ -0,0 +1,680 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict +""" +Checkpoint utilities for saving and loading DLRMv3 model checkpoints. + +This module provides functions for saving and loading distributed model checkpoints, +including both sparse (embedding) and dense (non-embedding) components. +""" + +import gc +import logging +import os +import random +import shutil +import time +from datetime import datetime +from typing import Any, Dict, Optional, Set, Tuple + +import gin +import numpy as np +import torch +from generative_recommenders.dlrm_v3.utils import ( + BinnedCumulativeAUC, + LifetimeAUCMetricComputation, + MetricsLogger, +) +from torch.distributed.checkpoint.stateful import Stateful +from torch.optim.optimizer import Optimizer +from torchrec.distributed.types import ShardedTensor + +logger: logging.Logger = logging.getLogger(__name__) + +# Sentinel meaning "the saved window completed in full" — when the loop reads +# this back it advances start_ts past the saved train_ts. Anything >=0 means the +# saved checkpoint stopped mid-window after K batches; resume continues that +# window at batch K. +WINDOW_COMPLETE: int = -1 + +# Filename (per-rank) holding the lifetime-AUC trailing buffers, mirroring the +# rng_rank{rank}.pt pattern. The buffers are per-rank-local, so a single +# rank-0 copy in non_sparse.ckpt would (wrongly) restore 1/world_size of the +# true history to every rank — hence a dedicated per-rank artifact. +METRICBUF_FILE_FMT: str = "metricbuf_rank{rank}.pt" + + +def _metric_blob_state_dict(m: torch.nn.Module) -> Dict[str, Any]: + """State dict for the shared (rank-0) non_sparse.ckpt metric blob. + + Both lifetime-AUC backends carry per-rank-local state that is persisted + authoritatively per-rank in ``metricbuf_rank{rank}.pt``; we must keep it out + of the shared blob so a rank's load doesn't inherit rank-0's counts: + + - ``LifetimeAUCMetricComputation``: drop the explicitly-serialized trailing + buffer keys (the rest of the blob keys are the parent's persistent state). + - ``BinnedCumulativeAUC``: zero the histogram buffers (they are persistent so + the keys must remain for a strict load, but the values are neutralized). + + All other metrics serialize normally. In both cases the per-rank file is + loaded afterward and is authoritative. + """ + sd = m.state_dict() + if isinstance(m, LifetimeAUCMetricComputation): + prefix = LifetimeAUCMetricComputation._LIFETIME_KEY_PREFIX + sd = {k: v for k, v in sd.items() if not k.startswith(prefix)} + elif isinstance(m, BinnedCumulativeAUC): + sd = { + k: (torch.zeros_like(v) if torch.is_tensor(v) else v) + for k, v in sd.items() + } + return sd + + +def _collect_perrank_metric_state( + metric_logger: "MetricsLogger", +) -> Dict[str, Dict[str, Any]]: + """Map "||" -> state_dict for every metric whose + cumulative state is per-rank-local and must be restored per-rank: + + - lifetime-AUC instances (`LifetimeAUCMetricComputation` trailing buffer, or + `BinnedCumulativeAUC` histograms) in class_metrics train/eval. Covers the + train lifetime AUC and, in legacy single-set eval, the eval lifetime AUC, + under either configured backend. + - the ENTIRE cumulative eval set (`eval_cum`, both class + regression) used + by the streaming dual-set eval: the lifetime-AUC backend state plus the + persistent cumulative scalar sums of NE/Accuracy/GAUC/MSE/MAE. + + Selected by structure/isinstance (not a hard index) since metric positions + depend on the configured tasks/mode. + """ + out: Dict[str, Dict[str, Any]] = {} + for mode in ("train", "eval"): + for idx, m in enumerate(metric_logger.class_metrics.get(mode, [])): + if isinstance(m, (LifetimeAUCMetricComputation, BinnedCumulativeAUC)): + out[f"class_metrics|{mode}|{idx}"] = m.state_dict() + for coll in ("class_metrics", "regression_metrics"): + for idx, m in enumerate(getattr(metric_logger, coll).get("eval_cum", [])): + out[f"{coll}|eval_cum|{idx}"] = m.state_dict() + return out + + +def _restore_perrank_metric_state( + metric_logger: "MetricsLogger", state: Dict[str, Dict[str, Any]] +) -> None: + for key, sd in state.items(): + coll, mode, idx_str = key.split("|") + getattr(metric_logger, coll)[mode][int(idx_str)].load_state_dict(sd) + + +def _perrank_sample_counts(metric_logger: "MetricsLogger") -> Dict[str, int]: + out: Dict[str, int] = {} + + def _count(m: torch.nn.Module) -> Optional[int]: + if isinstance(m, LifetimeAUCMetricComputation): + return m.lifetime_sample_count() + if isinstance(m, BinnedCumulativeAUC): + return m.cumulative_sample_count() + return None + + for mode in ("train", "eval", "eval_cum"): + for idx, m in enumerate(metric_logger.class_metrics.get(mode, [])): + n = _count(m) + if n is not None: + out[f"class|{mode}|{idx}"] = n + return out + + +class SparseState(Stateful): + """ + Stateful wrapper for sparse (embedding) tensors in a model. + + This class implements the Stateful interface for distributed checkpointing, + allowing sparse tensors to be saved and loaded separately from dense tensors. + + Args: + model: The PyTorch model containing sparse tensors. + sparse_tensor_keys: Set of keys identifying sparse tensors in the model's state dict. + """ + + def __init__(self, model: torch.nn.Module, sparse_tensor_keys: Set[str]) -> None: + self.model = model + self.sparse_tensor_keys = sparse_tensor_keys + + def state_dict(self) -> Dict[str, torch.Tensor]: + out_dict: Dict[str, torch.Tensor] = {} + is_sharded_tensor: Optional[bool] = None + for k, v in self.model.state_dict().items(): + if k in self.sparse_tensor_keys: + if is_sharded_tensor is None: + is_sharded_tensor = isinstance(v, ShardedTensor) + assert is_sharded_tensor == isinstance(v, ShardedTensor) + out_dict[k] = v + return out_dict + + def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: + incompatible_keys = self.model.load_state_dict(state_dict, strict=False) + assert not incompatible_keys.unexpected_keys + + +def is_sparse_key(k: str, v: torch.Tensor) -> bool: + return isinstance(v, ShardedTensor) or "embedding_collection" in k + + +def load_dense_state_dict(model: torch.nn.Module, state_dict: Dict[str, Any]) -> None: + own_state = model.state_dict() + own_state_dense_keys = {k for k, v in own_state.items() if not is_sparse_key(k, v)} + state_dict_dense_keys = { + k for k, v in state_dict.items() if not is_sparse_key(k, v) + } + assert own_state_dense_keys == state_dict_dense_keys, ( + f"expects {own_state_dense_keys} but gets {state_dict_dense_keys}" + ) + for name in state_dict_dense_keys: + param = state_dict[name] + if isinstance(param, torch.nn.Parameter): + # backwards compatibility for serialized parameters + param = param.data + own_state[name].copy_(param) + + +def _rng_state(device: torch.device) -> Dict[str, Any]: + """Snapshot every RNG source bit-equal training depends on. + + HSTU has stochastic dropout (input_dropout=0.2, linear_dropout_rate=0.1) + consuming the per-device CUDA RNG cycle each step. Without round-tripping + these, a resumed run draws different dropout masks and the resumed AUC + trajectory diverges from the uninterrupted run within a few steps. + """ + return { + "cpu": torch.get_rng_state(), + "cuda": torch.cuda.get_rng_state(device), + "numpy": np.random.get_state(), + "random": random.getstate(), + } + + +def _restore_rng_state(state: Dict[str, Any], device: torch.device) -> None: + torch.set_rng_state(state["cpu"]) + torch.cuda.set_rng_state(state["cuda"], device) + np.random.set_state(state["numpy"]) + random.setstate(state["random"]) + + +def _list_numeric_subdirs(base_path: str) -> list[str]: + """Return subdir names of `base_path` that look like an int, sorted ascending. + + Filters out `*.tmp` (orphaned in-progress saves), `*.sparse/` and any other + non-numeric entries. + """ + if not os.path.isdir(base_path): + return [] + out: list[str] = [] + for name in os.listdir(base_path): + if name.isdigit(): + out.append(name) + return sorted(out, key=int) + + +def _resolve_latest_subdir(path: str) -> str: + """Map a base ckpt dir → its highest-numbered numeric subdir. + + Used so users can set `load_dmp_checkpoint.path = ""` (or + `CKPT_PATH=`) and automatically pick up the most recent save without + needing to know which step number to point at. If `path` already names a leaf save (numeric basename) it's returned + unchanged. If the base dir has no numeric subdirs yet — the cold-start case + where ``CKPT_PATH`` is configured but nothing has been saved (e.g. the + interrupt phase of the resume test starts from a freshly-cleaned dir) — we + return ``""`` so ``load_*_checkpoint`` no-ops instead of asserting on a + missing ``sparse/.metadata``. + """ + if not path: + return path + base = path.rstrip("/") + leaf = os.path.basename(base) + if leaf.isdigit(): + return base # already a leaf, caller knows what it wants + subs = _list_numeric_subdirs(base) + if not subs: + logger.info("No checkpoint subdirs under %s — cold start (no load).", base) + return "" # nothing to load → load_*_checkpoint short-circuits + resolved = os.path.join(base, subs[-1]) + logger.info("Auto-latest checkpoint: %s → %s", base, resolved) + return resolved + + +def _prune_old_checkpoints(base_path: str, keep_last_n: int, just_saved_subdir: str) -> None: + """Delete numeric subdirs older than the keep_last_n most recent. + + Defensive: never prune `just_saved_subdir` even if it would be evicted by + the keep_last_n window (shouldn't happen since we just wrote it, but + catches off-by-one bugs). Skipped entirely when keep_last_n<=0. + """ + if keep_last_n <= 0: + return + subs = _list_numeric_subdirs(base_path) + if len(subs) <= keep_last_n: + return + to_prune = subs[:-keep_last_n] + for name in to_prune: + full = os.path.join(base_path, name) + if os.path.realpath(full) == os.path.realpath(just_saved_subdir): + continue + try: + shutil.rmtree(full) + logger.info("Pruned old checkpoint: %s", full) + except OSError as e: + logger.warning("Failed to prune %s: %s", full, e) + + +def _cleanup_stale_tmps(base_path: str) -> None: + """Remove `*.tmp`/`*.old` subdirs left by a crashed prior save attempt. + + `*.tmp` = an interrupted write; `*.old` = an interrupted atomic-overwrite + swap (see the promotion step in save_dmp_checkpoint). Both are non-numeric + so `_resolve_latest_subdir` already ignores them; this just reclaims disk. + """ + if not os.path.isdir(base_path): + return + for name in os.listdir(base_path): + if name.endswith(".tmp") or name.endswith(".old"): + full = os.path.join(base_path, name) + try: + shutil.rmtree(full) + logger.warning("Removed stale checkpoint dir: %s", full) + except OSError as e: + logger.warning("Failed to remove stale dir %s: %s", full, e) + + +@gin.configurable +def save_dmp_checkpoint( + model: torch.nn.Module, + optimizer: Optimizer, + metric_logger: MetricsLogger, + rank: int, + batch_idx: int, + path: str = "", + keep_last_n: int = 1, + train_ts: Optional[int] = None, + batch_idx_in_window: int = WINDOW_COMPLETE, + device: Optional[torch.device] = None, + split_contract: Optional[Dict[str, Any]] = None, +) -> None: + """ + Save a distributed model checkpoint including sparse and dense components. + + Writes into a per-rank-coordinated atomic layout: + /.tmp/ ← directory written into during save + // ← atomically renamed from .tmp on success + + A crash mid-save leaves the `.tmp/` orphan, which `_cleanup_stale_tmps` + sweeps on the next save attempt and which `_resolve_latest_subdir` ignores + (non-numeric basename). The previous successful `/` remains valid. + + Args: + model: The model to checkpoint. + optimizer: The optimizer whose state should be saved. + metric_logger: The metrics logger containing training/eval metrics. + rank: The current process rank in distributed training. + batch_idx: Subdir name (for streaming we set this == train_ts so the + on-disk layout monotonically increases). + path: Base path for saving the checkpoint. If empty, no checkpoint is saved. + keep_last_n: Number of most-recent numeric subdirs to retain after a + successful save. Set 1 (default) for disk-bounded long runs; + <=0 disables pruning. + train_ts: For streaming-train-eval, the current train timestamp. + Stored in non_sparse.ckpt so resume knows which window to enter. + batch_idx_in_window: For streaming-train-eval, batches completed within + train_ts. WINDOW_COMPLETE (-1) means the window finished; resume + advances to train_ts+1. >=0 means crash happened mid-window; resume + re-enters train_ts at batch_idx_in_window. + device: CUDA device for the per-rank RNG snapshot. Required for + bit-equal trajectories across resume (HSTU dropout consumes the + per-device RNG cycle). + """ + if path == "": + return + # Exclude checkpoint wall-time from the train step-time window so step_ms + # reports canonical compute latency; the duration is surfaced separately + # (window_ckpt_time_ms + the per-save log below). pause/resume are no-ops if + # metric_logger is None. Not wrapped in try/finally: a save that raises + # crashes the process (supervisor restarts fresh), so a dangling pause on + # the soon-dead logger is irrelevant. + _t_ckpt_start = time.perf_counter() + if metric_logger is not None: + metric_logger.pause_perf("ckpt") + base_path = path + # Atomic-save layout: write to .tmp, rename to final, prune older. + tmp_subdir = f"{base_path}/{batch_idx}.tmp" + final_subdir = f"{base_path}/{batch_idx}" + + if rank == 0: + _cleanup_stale_tmps(base_path) + # Always (re)write into a fresh .tmp. An existing `final_subdir` with the + # same batch_idx (e.g. a later in-window save for the same train_ts, or a + # deterministic re-run at the same step) is overwritten atomically at the + # promotion step below — NOT skipped here. Skipping would desync ranks: + # the collective barrier/checkpoint.save calls below run on *every* rank, + # so a rank-0-only early return deadlocks ranks 1..N on the next barrier. + shutil.rmtree(tmp_subdir, ignore_errors=True) + os.makedirs(tmp_subdir, exist_ok=True) + os.makedirs(f"{tmp_subdir}/sparse/", exist_ok=True) + torch.distributed.barrier() + sparse_path = f"{tmp_subdir}/sparse/" + non_sparse_ckpt = f"{tmp_subdir}/non_sparse.ckpt" + + sparse_tensor_keys = { + k for k, v in model.state_dict().items() if isinstance(v, ShardedTensor) + } + if rank == 0: + dense_state_dict = { + k: v + for k, v in model.state_dict().items() + if not isinstance(v, ShardedTensor) + } + class_metric_state_dict = { + "train": [ + _metric_blob_state_dict(m) + for m in metric_logger.class_metrics["train"] + ], + "eval": [ + _metric_blob_state_dict(m) + for m in metric_logger.class_metrics["eval"] + ], + } + regression_metric_state_dict = { + "train": [ + m.state_dict() for m in metric_logger.regression_metrics["train"] + ], + "eval": [m.state_dict() for m in metric_logger.regression_metrics["eval"]], + } + torch.save( + { + "dense_dict": dense_state_dict, + "optimizer_dict": optimizer.state_dict(), + "class_metrics": class_metric_state_dict, + "reg_metrics": regression_metric_state_dict, + "global_step": metric_logger.global_step, + # MLPerf progress counter (global trained samples). Defaulted on + # load so pre-existing checkpoints restore as 0 and resume the + # count from there. + "cumulative_train_samples": metric_logger.cumulative_train_samples, + # MLPerf run-marker state: lets a resume relaunch continue the + # SAME run's event stream without re-emitting INIT_START/RUN_START. + "mlperf_run_started": metric_logger.mlperf_run_started, + "sparse_tensor_keys": sparse_tensor_keys, + # Streaming resume fields. Defaulted on load so old checkpoints + # (pre-streaming-resume) still load as a normal restart. + "train_ts": train_ts, + "batch_idx_in_window": batch_idx_in_window, + # Immutable train:eval split + resume-determinism contract + # (train_split_percentage, split_salt, eval holdout window, + # batch_size, world_size). Validated on resume so a relaunch + # cannot silently change the split (which would desync the skip + # offset and/or train on held-out eval users). None for + # non-holdout / legacy runs. + "split_contract": split_contract, + }, + non_sparse_ckpt, + ) + + # Per-rank RNG snapshot. Written even on a single rank because dropout's + # randomness comes from the CUDA generator which differs across devices. + if device is not None: + rng_path = f"{tmp_subdir}/rng_rank{rank}.pt" + torch.save(_rng_state(device), rng_path) + + # Per-rank cumulative metric state (lifetime-AUC buffers + cumulative-eval + # histograms/scalar sums). Written by EVERY rank (outside the rank-0 block) + # because this state is per-rank-local; restoring rank-0's copy to all ranks + # would lose (world_size-1)/world_size of the history. + if metric_logger is not None: + perrank_state = _collect_perrank_metric_state(metric_logger) + if perrank_state: + torch.save( + perrank_state, + f"{tmp_subdir}/{METRICBUF_FILE_FMT.format(rank=rank)}", + ) + logger.info( + "checkpoint save: cumulative metric state rank=%d samples=%s", + rank, + _perrank_sample_counts(metric_logger), + ) + + torch.distributed.barrier() + sparse_dict = {"sparse_dict": SparseState(model, sparse_tensor_keys)} + torch.distributed.checkpoint.save( + sparse_dict, + storage_writer=torch.distributed.checkpoint.FileSystemWriter(sparse_path), + ) + torch.distributed.barrier() + # Promote .tmp → final, then prune. Done on rank 0 only since the directory + # operations are global filesystem state. + if rank == 0: + if os.path.exists(final_subdir): + # POSIX rename() refuses to replace a non-empty directory, so we + # can't os.replace(tmp, final) directly. Swap the old snapshot aside + # (instant rename), move the new one into place, then delete the old. + # The `.old` name is non-numeric → ignored by _resolve_latest_subdir + # and swept by _cleanup_stale_tmps on the next save if we crash mid-swap. + old_aside = f"{final_subdir}.old" + shutil.rmtree(old_aside, ignore_errors=True) + os.replace(final_subdir, old_aside) + os.replace(tmp_subdir, final_subdir) + shutil.rmtree(old_aside, ignore_errors=True) + else: + os.replace(tmp_subdir, final_subdir) + _prune_old_checkpoints(base_path, keep_last_n, final_subdir) + logger.info( + "checkpoint successfully saved → %s (wall-time %.2fs)", + final_subdir, + time.perf_counter() - _t_ckpt_start, + ) + torch.distributed.barrier() + if metric_logger is not None: + metric_logger.resume_perf("ckpt") + + +@gin.configurable +def load_sparse_checkpoint( + model: torch.nn.Module, + path: str = "", +) -> None: + if path == "": + return + sparse_path = f"{path}/sparse/" + + sparse_tensor_keys = { + k for k, v in model.state_dict().items() if is_sparse_key(k, v) + } + sparse_dict = {"sparse_dict": SparseState(model, sparse_tensor_keys)} + gc.collect() + torch.distributed.checkpoint.load( + sparse_dict, + storage_reader=torch.distributed.checkpoint.FileSystemReader(sparse_path), + ) + gc.collect() + print("sparse checkpoint successfully loaded") + + +@gin.configurable +def load_nonsparse_checkpoint( + model: torch.nn.Module, + device: torch.device, + optimizer: Optional[Optimizer] = None, + metric_logger: Optional[MetricsLogger] = None, + path: str = "", + rank: int = 0, +) -> Tuple[Optional[int], int, Optional[Dict[str, Any]]]: + """ + Load non-sparse (dense) components from a checkpoint. + + Loads dense model parameters, and optionally optimizer state and metrics. + Also restores per-rank RNG state if a matching `rng_rank{rank}.pt` is found + next to `non_sparse.ckpt`. + + Returns: + (train_ts, batch_idx_in_window, split_contract) — the streaming resume + hint and the saved train:eval split contract (None for legacy / non- + holdout checkpoints). `(None, WINDOW_COMPLETE, None)` if not a streaming + checkpoint or no path supplied. + """ + if path == "": + return None, WINDOW_COMPLETE, None + non_sparse_ckpt = f"{path}/non_sparse.ckpt" + + # weights_only=False: these are our own trusted checkpoints, and they hold + # non-tensor objects (optimizer/metric state dicts, numpy-backed RNG state) + # that PyTorch>=2.6's weights_only=True default refuses to unpickle. + non_sparse_state_dict = torch.load( + non_sparse_ckpt, map_location=device, weights_only=False + ) + load_dense_state_dict(model, non_sparse_state_dict["dense_dict"]) + print("dense checkpoint successfully loaded") + if optimizer is not None: + optimizer.load_state_dict(non_sparse_state_dict["optimizer_dict"]) + print("optimizer checkpoint successfully loaded") + if metric_logger is not None: + metric_logger.global_step = non_sparse_state_dict["global_step"] + # Defaulted for legacy checkpoints written before the counter existed. + metric_logger.cumulative_train_samples = non_sparse_state_dict.get( + "cumulative_train_samples", 0 + ) + # Defaulted False for legacy/cold checkpoints: a resume that loads a + # checkpoint where the run was already open continues without re-emitting + # the run markers. + metric_logger.mlperf_run_started = non_sparse_state_dict.get( + "mlperf_run_started", False + ) + class_metric_state_dict = non_sparse_state_dict["class_metrics"] + regression_metric_state_dict = non_sparse_state_dict["reg_metrics"] + # Length-safe positional restore: if a checkpoint was written with a + # different metric set (e.g. tasks added/removed since), restore the + # overlap instead of crashing with an IndexError at run end. + def _restore_metric_list( + live: list, saved: Optional[list], label: str + ) -> None: + saved = saved or [] + if len(live) != len(saved): + logger.warning( + "metric count mismatch for %s: live=%d saved=%d; " + "restoring overlapping %d", + label, + len(live), + len(saved), + min(len(live), len(saved)), + ) + for i in range(min(len(live), len(saved))): + live[i].load_state_dict(saved[i]) + + _restore_metric_list( + metric_logger.class_metrics["train"], + class_metric_state_dict.get("train"), + "class/train", + ) + _restore_metric_list( + metric_logger.class_metrics["eval"], + class_metric_state_dict.get("eval"), + "class/eval", + ) + _restore_metric_list( + metric_logger.regression_metrics["train"], + regression_metric_state_dict.get("train"), + "reg/train", + ) + _restore_metric_list( + metric_logger.regression_metrics["eval"], + regression_metric_state_dict.get("eval"), + "reg/eval", + ) + + # Per-rank cumulative metric state restore. This runs AFTER the generic + # load above so it is authoritative: the shared blob carries no lifetime + # buffers (stripped at save) nor any eval_cum state, and each rank + # restores its OWN cumulative state here. Missing file = legacy/pre-fix + # checkpoint; cumulative metrics self-heal (lifetime AUC refills; the + # binned-AUC histograms / scalar sums restart from zero). + mb_path = f"{path}/{METRICBUF_FILE_FMT.format(rank=rank)}" + if os.path.exists(mb_path): + perrank_state = torch.load( + mb_path, map_location=device, weights_only=False + ) + _restore_perrank_metric_state(metric_logger, perrank_state) + logger.info( + "checkpoint load: cumulative metric state rank=%d samples=%s", + rank, + _perrank_sample_counts(metric_logger), + ) + else: + logger.info( + "checkpoint load: no per-rank cumulative metric state at %s " + "(legacy/pre-fix checkpoint); cumulative metrics will refill", + mb_path, + ) + + # Per-rank RNG restore. Missing file = bit-equal trajectory not requested at + # save time; we silently continue (the test harness checks for both). + rng_path = f"{path}/rng_rank{rank}.pt" + if os.path.exists(rng_path): + # weights_only=False: RNG state is numpy/Python tuples, not tensors. + rng_state = torch.load(rng_path, map_location="cpu", weights_only=False) + _restore_rng_state(rng_state, device) + logger.info("RNG state restored from %s", rng_path) + + train_ts = non_sparse_state_dict.get("train_ts") + batch_idx_in_window = non_sparse_state_dict.get( + "batch_idx_in_window", WINDOW_COMPLETE + ) + split_contract = non_sparse_state_dict.get("split_contract") + return train_ts, batch_idx_in_window, split_contract + + +@gin.configurable +def load_dmp_checkpoint( + model: torch.nn.Module, + optimizer: Optimizer, + metric_logger: MetricsLogger, + device: torch.device, + path: str = "", + rank: int = 0, +) -> Tuple[Optional[int], int, Optional[Dict[str, Any]], bool]: + """ + Load a complete distributed model checkpoint (both sparse and dense components). + + `path` is auto-resolved: if it points at a directory containing numeric + subdirs (e.g. CKPT_PATH=/), the highest-numbered subdir is used. If it + already names a leaf save (e.g. /300), it's used as-is. Empty string = + no load. + + Returns: + (train_ts, batch_idx_in_window, split_contract, cold_start) — streaming + resume hint plus the saved split contract, and `cold_start` which is True + iff there was nothing to load (no checkpoint resolved). `cold_start` + distinguishes a genuine fresh run (no weights loaded) from a resume that + merely lacks a split contract (e.g. a legacy/non-streaming checkpoint), + which the caller's split-contract guard must still reject. + """ + resolved = _resolve_latest_subdir(path) + cold_start = resolved == "" + load_sparse_checkpoint(model=model, path=resolved) + train_ts, batch_idx_in_window, split_contract = load_nonsparse_checkpoint( + model=model, + optimizer=optimizer, + metric_logger=metric_logger, + path=resolved, + device=device, + rank=rank, + ) + return train_ts, batch_idx_in_window, split_contract, cold_start diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/configs.py b/recommendation_v4/generative_recommenders/dlrm_v3/configs.py new file mode 100644 index 000000000..387fb4900 --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/configs.py @@ -0,0 +1,830 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict +""" +Configuration module for DLRMv3 model. + +This module provides configuration functions for the HSTU model architecture and embedding table configurations. +""" + +import hashlib +import math +import os +from typing import Callable, Dict, Optional, Tuple + +import gin +import torch + +from generative_recommenders.modules.dlrm_hstu import DlrmHSTUConfig +from generative_recommenders.modules.multitask_module import ( + MultitaskTaskType, + TaskConfig, +) +from torchrec.modules.embedding_configs import DataType, EmbeddingConfig + +HSTU_EMBEDDING_DIM = 512 # final DLRMv3 model +HASH_SIZE = 10_000_000 +HASH_SIZE_1B = 1_000_000_000 + +# (name, keys, num_embeddings, salt) — single source of truth for both +# get_embedding_table_config("yambda-5b") and the dataset's cross-hash inputs. +# Sizes mirror Primus-DLRM/configs/bench_onetrans_large_5b_cross_feat_shampoo.yaml. +YAMBDA_5B_CROSS_SPECS = [ + ("user_x_artist", ("uid", "artist_id"), 100_000_000, 0), + ("user_x_album", ("uid", "album_id"), 40_000_000, 0), + ("user_x_hour", ("uid", "hour_of_day"), 24_000_000, 0), + ("item_x_hour", ("item_id", "hour_of_day"), 40_000_000, 0), + ("artist_x_hour", ("artist_id", "hour_of_day"), 32_000_000, 0), + ("user_x_is_organic", ("uid", "is_organic"), 2_000_000, 0), + ("user_x_artist_x_hour", ("uid", "artist_id", "hour_of_day"), 40_000_000, 0), +] + + +@gin.configurable +def get_hstu_configs( + dataset: str = "debug", + max_seq_len: Optional[int] = None, + max_num_candidates: Optional[int] = None, + hstu_embedding_table_dim: Optional[int] = None, + hstu_transducer_embedding_dim: Optional[int] = None, + hstu_num_heads: Optional[int] = None, + hstu_attn_num_layers: Optional[int] = None, + hstu_attn_linear_dim: Optional[int] = None, + hstu_attn_qk_dim: Optional[int] = None, + hstu_input_dropout_ratio: Optional[float] = None, + hstu_linear_dropout_rate: Optional[float] = None, +) -> DlrmHSTUConfig: + """ + Create and return HSTU model configuration. + + Builds a complete DlrmHSTUConfig with default hyperparameters for the HSTU + architecture including attention settings, embedding dimensions, dropout rates, + and feature name mappings. + + Args: + dataset: Dataset identifier (currently unused, reserved for dataset-specific configs). + + Returns: + DlrmHSTUConfig: Complete configuration object for the HSTU model. + """ + hstu_config = DlrmHSTUConfig( + hstu_num_heads=4, + hstu_attn_linear_dim=128, + hstu_attn_qk_dim=128, + hstu_attn_num_layers=5, + hstu_embedding_table_dim=HSTU_EMBEDDING_DIM, + hstu_preprocessor_hidden_dim=256, + hstu_transducer_embedding_dim=512, + hstu_group_norm=False, + hstu_input_dropout_ratio=0.2, + hstu_linear_dropout_rate=0.1, + causal_multitask_weights=0.2, + ) + if "movielens" in dataset: + assert dataset in [ + "movielens-1m", + "movielens-20m", + "movielens-13b", + "movielens-18b", + ] + hstu_config.user_embedding_feature_names = ( + [ + "movie_id", + "user_id", + "sex", + "age_group", + "occupation", + "zip_code", + ] + if dataset == "movielens-1m" + else [ + "movie_id", + "user_id", + ] + ) + hstu_config.item_embedding_feature_names = [ + "item_movie_id", + ] + hstu_config.uih_post_id_feature_name = "movie_id" + hstu_config.uih_action_time_feature_name = "action_timestamp" + hstu_config.candidates_querytime_feature_name = "item_query_time" + hstu_config.candidates_weight_feature_name = "item_action_weights" + hstu_config.uih_weight_feature_name = "item_weights" + hstu_config.candidates_watchtime_feature_name = "item_movie_rating" + hstu_config.action_weights = [1, 2, 4, 8, 16] + hstu_config.contextual_feature_to_max_length = ( + { + "user_id": 1, + "sex": 1, + "age_group": 1, + "occupation": 1, + "zip_code": 1, + } + if dataset == "movielens-1m" + else { + "user_id": 1, + } + ) + hstu_config.contextual_feature_to_min_uih_length = ( + { + "user_id": 20, + "sex": 20, + "age_group": 20, + "occupation": 20, + "zip_code": 20, + } + if dataset == "movielens-1m" + else { + "user_id": 20, + } + ) + hstu_config.merge_uih_candidate_feature_mapping = [ + ("movie_id", "item_movie_id"), + ("movie_rating", "item_movie_rating"), + ("action_timestamp", "item_query_time"), + ("item_weights", "item_action_weights"), + ("dummy_watch_time", "item_dummy_watchtime"), + ] + hstu_config.hstu_uih_feature_names = ( + [ + "user_id", + "sex", + "age_group", + "occupation", + "zip_code", + "movie_id", + "movie_rating", + "action_timestamp", + "item_weights", + "dummy_watch_time", + ] + if dataset == "movielens-1m" + else [ + "user_id", + "movie_id", + "movie_rating", + "action_timestamp", + "item_weights", + "dummy_watch_time", + ] + ) + hstu_config.hstu_candidate_feature_names = [ + "item_movie_id", + "item_movie_rating", + "item_query_time", + "item_action_weights", + "item_dummy_watchtime", + ] + hstu_config.max_num_candidates = 10 + hstu_config.max_num_candidates_inference = ( + 5 if dataset not in ["movielens-13b", "movielens-18b"] else 2048 + ) + hstu_config.multitask_configs = [ + TaskConfig( + task_name="rating", + task_weight=1, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ) + ] + elif "streaming" in dataset: + hstu_config.user_embedding_feature_names = [ + "item_id", + "user_id", + "item_category_id", + ] + hstu_config.item_embedding_feature_names = [ + "item_candidate_id", + "item_candidate_category_id", + ] + hstu_config.uih_post_id_feature_name = "item_id" + hstu_config.uih_action_time_feature_name = "action_timestamp" + hstu_config.candidates_querytime_feature_name = "item_query_time" + hstu_config.candidates_weight_feature_name = "item_action_weights" + hstu_config.uih_weight_feature_name = "item_weights" + hstu_config.candidates_watchtime_feature_name = "item_rating" + hstu_config.action_weights = [1, 2, 4, 8, 16] + hstu_config.action_embedding_init_std = 5.0 + hstu_config.contextual_feature_to_max_length = {"user_id": 1} + hstu_config.contextual_feature_to_min_uih_length = {"user_id": 20} + hstu_config.merge_uih_candidate_feature_mapping = [ + ("item_id", "item_candidate_id"), + ("item_rating", "item_candidate_rating"), + ("action_timestamp", "item_query_time"), + ("item_weights", "item_action_weights"), + ("dummy_watch_time", "item_dummy_watchtime"), + ("item_category_id", "item_candidate_category_id"), + ] + hstu_config.hstu_uih_feature_names = [ + "user_id", + "item_id", + "item_rating", + "action_timestamp", + "item_weights", + "dummy_watch_time", + "item_category_id", + ] + hstu_config.hstu_candidate_feature_names = [ + "item_candidate_id", + "item_candidate_rating", + "item_query_time", + "item_action_weights", + "item_dummy_watchtime", + "item_candidate_category_id", + ] + hstu_config.max_num_candidates = 32 + hstu_config.max_num_candidates_inference = 2048 + hstu_config.multitask_configs = [ + TaskConfig( + task_name="rating", + task_weight=1, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ) + ] + elif "kuairand" in dataset: + hstu_config.user_embedding_feature_names = [ + "video_id", + "user_id", + "user_active_degree", + "follow_user_num_range", + "fans_user_num_range", + "friend_user_num_range", + "register_days_range", + ] + hstu_config.item_embedding_feature_names = [ + "item_video_id", + ] + hstu_config.uih_post_id_feature_name = "video_id" + hstu_config.uih_action_time_feature_name = "action_timestamp" + hstu_config.candidates_querytime_feature_name = "item_query_time" + hstu_config.uih_weight_feature_name = "action_weight" + hstu_config.candidates_weight_feature_name = "item_action_weight" + hstu_config.candidates_watchtime_feature_name = "item_target_watchtime" + # There are more contextual features in the dataset, see https://kuairand.com/ for details + hstu_config.contextual_feature_to_max_length = { + "user_id": 1, + "user_active_degree": 1, + "follow_user_num_range": 1, + "fans_user_num_range": 1, + "friend_user_num_range": 1, + "register_days_range": 1, + } + hstu_config.merge_uih_candidate_feature_mapping = [ + ("video_id", "item_video_id"), + ("action_timestamp", "item_query_time"), + ("action_weight", "item_action_weight"), + ("watch_time", "item_target_watchtime"), + ] + hstu_config.hstu_uih_feature_names = [ + "user_id", + "user_active_degree", + "follow_user_num_range", + "fans_user_num_range", + "friend_user_num_range", + "register_days_range", + "video_id", + "action_timestamp", + "action_weight", + "watch_time", + ] + hstu_config.hstu_candidate_feature_names = [ + "item_video_id", + "item_action_weight", + "item_target_watchtime", + "item_query_time", + ] + hstu_config.multitask_configs = [ + TaskConfig( + task_name="is_click", + task_weight=1, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ), + TaskConfig( + task_name="is_like", + task_weight=2, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ), + TaskConfig( + task_name="is_follow", + task_weight=4, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ), + TaskConfig( + task_name="is_comment", + task_weight=8, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ), + TaskConfig( + task_name="is_forward", + task_weight=16, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ), + TaskConfig( + task_name="is_hate", + task_weight=32, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ), + TaskConfig( + task_name="long_view", + task_weight=64, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ), + TaskConfig( + task_name="is_profile_enter", + task_weight=128, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ), + ] + hstu_config.action_weights = [1, 2, 4, 8, 16, 32, 64, 128] + elif "yambda" in dataset: + assert dataset in ["yambda-5b"] + cross_names = [name for (name, _k, _n, _s) in YAMBDA_5B_CROSS_SPECS] + # Per-table dim defaults to HSTU_EMBEDDING_DIM (512); override via the + # `get_hstu_configs.hstu_embedding_table_dim = N` gin binding if needed. + # Note: the embedding tables in get_embedding_table_config also use + # HSTU_EMBEDDING_DIM and must stay aligned with this value. + hstu_config.hstu_embedding_table_dim = HSTU_EMBEDDING_DIM + hstu_config.hstu_transducer_embedding_dim = 512 + hstu_config.max_seq_len = 8192 + hstu_config.max_num_candidates = 1 + hstu_config.max_num_candidates_inference = 1 + # Per dlrm_hstu convention (see streaming-100b/movielens): + # - user_embedding_feature_names = UIH-side post-id features + contextual features. + # After main_forward merges UIH + candidate, only these entries hold the merged + # sequence (used by user-side transducer). + # - item_embedding_feature_names = candidate-side names only. _item_forward + # concats these along dim=-1 to feed the item MLP (per-candidate, not per-position). + hstu_config.user_embedding_feature_names = ( + ["uid"] + + cross_names + + ["item_id", "artist_id", "album_id"] + ) + hstu_config.item_embedding_feature_names = [ + "item_candidate_id", + "item_candidate_artist_id", + "item_candidate_album_id", + ] + hstu_config.uih_post_id_feature_name = "item_id" + hstu_config.uih_action_time_feature_name = "action_timestamp" + hstu_config.uih_weight_feature_name = "action_weight" + hstu_config.candidates_querytime_feature_name = "item_query_time" + hstu_config.candidates_weight_feature_name = "item_action_weight" + hstu_config.candidates_watchtime_feature_name = "item_dummy_watchtime" + hstu_config.action_weights = [1, 2, 4] # lp, like, skip bits + hstu_config.contextual_feature_to_max_length = { + "uid": 1, + **{name: 1 for name in cross_names}, + } + hstu_config.contextual_feature_to_min_uih_length = { + "uid": 0, + **{name: 0 for name in cross_names}, + } + # uih names map to candidate names (no name collisions allowed): + # item_id/artist_id/album_id appear with prefix "item_" on candidate side. + hstu_config.merge_uih_candidate_feature_mapping = [ + ("item_id", "item_candidate_id"), + ("artist_id", "item_candidate_artist_id"), + ("album_id", "item_candidate_album_id"), + ("action_weight", "item_action_weight"), + ("action_timestamp", "item_query_time"), + ("dummy_watch_time", "item_dummy_watchtime"), + ] + hstu_config.hstu_uih_feature_names = ( + ["uid"] + + cross_names + + [ + "item_id", + "artist_id", + "album_id", + "action_weight", + "action_timestamp", + "dummy_watch_time", + ] + ) + hstu_config.hstu_candidate_feature_names = [ + "item_candidate_id", + "item_candidate_artist_id", + "item_candidate_album_id", + "item_query_time", + "item_action_weight", + "item_dummy_watchtime", + ] + hstu_config.multitask_configs = [ + TaskConfig( + task_name="listen_plus", + task_weight=1, # matches action_weights[0] (lp bit) + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ) + ] + else: + hstu_config.user_embedding_feature_names = [ + "uih_post_id", + "uih_owner_id", + "viewer_id", + "dummy_contexual", + ] + hstu_config.item_embedding_feature_names = [ + "item_post_id", + "item_owner_id", + ] + hstu_config.uih_post_id_feature_name = "uih_post_id" + hstu_config.uih_action_time_feature_name = "uih_action_time" + hstu_config.candidates_querytime_feature_name = "item_query_time" + hstu_config.candidates_weight_feature_name = "item_action_weight" + hstu_config.candidates_watchtime_feature_name = "item_target_watchtime" + hstu_config.contextual_feature_to_max_length = { + "viewer_id": 1, + "dummy_contexual": 1, + } + hstu_config.contextual_feature_to_min_uih_length = { + "viewer_id": 128, + "dummy_contexual": 128, + } + hstu_config.merge_uih_candidate_feature_mapping = [ + ("uih_post_id", "item_post_id"), + ("uih_owner_id", "item_owner_id"), + ("uih_action_time", "item_query_time"), + ("uih_weight", "item_action_weight"), + ("uih_watchtime", "item_target_watchtime"), + ("uih_video_length", "item_video_length"), + ("uih_surface_type", "item_surface_type"), + ] + hstu_config.hstu_uih_feature_names = [ + "uih_post_id", + "uih_action_time", + "uih_weight", + "uih_owner_id", + "uih_watchtime", + "uih_surface_type", + "uih_video_length", + "viewer_id", + "dummy_contexual", + ] + hstu_config.hstu_candidate_feature_names = [ + "item_post_id", + "item_owner_id", + "item_surface_type", + "item_video_length", + "item_action_weight", + "item_target_watchtime", + "item_query_time", + ] + hstu_config.multitask_configs = [ + TaskConfig( + task_name="vvp100", + task_weight=1, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ) + ] + + # Apply gin overrides last so a value set in the gin file wins over the + # per-dataset defaults above. Anything left as None inherits the default + # the dataset branch (or DlrmHSTUConfig) chose. Example in a gin file: + # get_hstu_configs.max_seq_len = 4096 + # get_hstu_configs.hstu_embedding_table_dim = 256 + _gin_overrides = { + "max_seq_len": max_seq_len, + "max_num_candidates": max_num_candidates, + "max_num_candidates_inference": max_num_candidates, + "hstu_embedding_table_dim": hstu_embedding_table_dim, + "hstu_transducer_embedding_dim": hstu_transducer_embedding_dim, + "hstu_num_heads": hstu_num_heads, + "hstu_attn_num_layers": hstu_attn_num_layers, + "hstu_attn_linear_dim": hstu_attn_linear_dim, + "hstu_attn_qk_dim": hstu_attn_qk_dim, + "hstu_input_dropout_ratio": hstu_input_dropout_ratio, + "hstu_linear_dropout_rate": hstu_linear_dropout_rate, + } + for _name, _val in _gin_overrides.items(): + if _val is not None: + setattr(hstu_config, _name, _val) + + return hstu_config + + +def _stable_table_seed(init_seed: int, table_name: str) -> int: + """Deterministic 63-bit seed from (init_seed, table_name). + + Uses sha256 (not Python's salted built-in ``hash()``) so the per-table seed + is identical across processes/ranks/runs for a given ``$SEED`` + table name. + """ + digest = hashlib.sha256(f"{init_seed}:{table_name}".encode("utf-8")).digest() + return int.from_bytes(digest[:8], "big") & 0x7FFF_FFFF_FFFF_FFFF + + +def _uniform_init_bounds(cfg: EmbeddingConfig) -> Tuple[float, float]: + """Mirror TorchREC's default per-table init bounds. + + TorchREC falls back to ``uniform_(-1/sqrt(N), +1/sqrt(N))`` when a table does + not set ``weight_init_min/max``; honor any explicit bounds the config carries. + """ + bound = math.sqrt(1.0 / cfg.num_embeddings) + lo = -bound if cfg.weight_init_min is None else cfg.weight_init_min + hi = bound if cfg.weight_init_max is None else cfg.weight_init_max + return lo, hi + + +def _make_seeded_uniform_init( + table_seed: int, lo: float, hi: float +) -> Callable[[torch.Tensor], torch.Tensor]: + """Build a seeded in-place uniform initializer for one table's weight. + + TorchREC/FBGEMM calls ``init_fn`` with the (per-rank) local shard tensor on + its compute device, so we seed a generator on that same device. For a fixed + sharding plan (world size + plan unchanged) this makes embedding init + byte-reproducible run-to-run. + """ + + def _init(weight: torch.Tensor) -> torch.Tensor: + # TorchREC builds the unsharded EmbeddingCollection on the META device + # first (DMP materializes real storage on the compute device later). + # Meta tensors have no storage and torch.Generator(device="meta") is + # invalid ("META device type not an accelerator"), so skip them: the + # seeded init for the sharded/fused TBE path is provided by the RNG + # re-seed right before DMP in make_optimizer_and_shard. On a real + # device (eager/non-meta path) we still apply the per-table seeded fill. + if weight.device.type == "meta": + return weight + gen = torch.Generator(device=weight.device) + gen.manual_seed(table_seed) + with torch.no_grad(): + weight.uniform_(lo, hi, generator=gen) + return weight + + return _init + + +@gin.configurable +def get_embedding_table_config( + dataset: str = "debug", + embedding_dim: Optional[int] = None, + init_seed: Optional[int] = None, +) -> Dict[str, EmbeddingConfig]: + """ + Create and return embedding table configurations. + + Defines the embedding table configurations for item IDs, category IDs, and user IDs + with their respective dimensions and data types. + + Args: + dataset: Dataset identifier (currently unused, reserved for dataset-specific configs). + embedding_dim: Per-table embedding width override. When set via gin + (e.g. `get_embedding_table_config.embedding_dim = 256`), wins over + `HSTU_EMBEDDING_DIM`. Keep in sync with the matching gin override on + `get_hstu_configs.hstu_embedding_table_dim` — the model and the + tables must agree on dim or sharding will reject the plan. + init_seed: Base seed for the per-table seeded `init_fn` (Tier 1 + reproducible embedding init). When None, falls back to `$SEED` + (default 1), matching `seed_everything`. Each table draws from a + generator seeded by `sha256(init_seed, table_name)` so init is + reproducible run-to-run for a fixed sharding plan. + + Returns: + Dict mapping table names to their EmbeddingConfig objects. + """ + tables = _build_embedding_table_config(dataset=dataset, embedding_dim=embedding_dim) + + if init_seed is None: + init_seed = int(os.environ.get("SEED", "1")) + for name, cfg in tables.items(): + lo, hi = _uniform_init_bounds(cfg) + cfg.init_fn = _make_seeded_uniform_init( + _stable_table_seed(init_seed, name), lo, hi + ) + return tables + + +def _build_embedding_table_config( + dataset: str = "debug", + embedding_dim: Optional[int] = None, +) -> Dict[str, EmbeddingConfig]: + DIM = embedding_dim if embedding_dim is not None else HSTU_EMBEDDING_DIM + if "movielens" in dataset: + assert dataset in [ + "movielens-1m", + "movielens-20m", + "movielens-13b", + "movielens-18b", + ] + return ( + { + "movie_id": EmbeddingConfig( + num_embeddings=HASH_SIZE, + embedding_dim=DIM, + name="movie_id", + data_type=DataType.FP16, + feature_names=["movie_id", "item_movie_id"], + ), + "user_id": EmbeddingConfig( + num_embeddings=HASH_SIZE, + embedding_dim=DIM, + name="user_id", + data_type=DataType.FP16, + feature_names=["user_id"], + ), + "sex": EmbeddingConfig( + num_embeddings=HASH_SIZE, + embedding_dim=DIM, + name="sex", + data_type=DataType.FP16, + feature_names=["sex"], + ), + "age_group": EmbeddingConfig( + num_embeddings=HASH_SIZE, + embedding_dim=DIM, + name="age_group", + data_type=DataType.FP16, + feature_names=["age_group"], + ), + "occupation": EmbeddingConfig( + num_embeddings=HASH_SIZE, + embedding_dim=DIM, + name="occupation", + data_type=DataType.FP16, + feature_names=["occupation"], + ), + "zip_code": EmbeddingConfig( + num_embeddings=HASH_SIZE, + embedding_dim=DIM, + name="zip_code", + data_type=DataType.FP16, + feature_names=["zip_code"], + ), + } + if dataset == "movielens-1m" + else { + "movie_id": EmbeddingConfig( + num_embeddings=HASH_SIZE_1B, + embedding_dim=DIM, + name="movie_id", + data_type=DataType.FP16, + feature_names=["movie_id", "item_movie_id"], + ), + "user_id": EmbeddingConfig( + num_embeddings=3_000_000, + embedding_dim=DIM, + name="user_id", + data_type=DataType.FP16, + feature_names=["user_id"], + ), + } + ) + elif "streaming" in dataset: + return { + "item_id": EmbeddingConfig( + num_embeddings=HASH_SIZE_1B, + embedding_dim=DIM, + name="item_id", + data_type=DataType.FP16, + feature_names=["item_id", "item_candidate_id"], + ), + "item_category_id": EmbeddingConfig( + num_embeddings=128, + embedding_dim=DIM, + name="item_category_id", + data_type=DataType.FP16, + weight_init_max=1.0, + weight_init_min=-1.0, + feature_names=["item_category_id", "item_candidate_category_id"], + ), + "user_id": EmbeddingConfig( + num_embeddings=10_000_000, + embedding_dim=DIM, + name="user_id", + data_type=DataType.FP16, + feature_names=["user_id"], + ), + } + elif "kuairand" in dataset: + return { + "video_id": EmbeddingConfig( + num_embeddings=HASH_SIZE, + embedding_dim=DIM, + name="video_id", + data_type=DataType.FP16, + feature_names=["video_id", "item_video_id"], + ), + "user_id": EmbeddingConfig( + num_embeddings=HASH_SIZE, + embedding_dim=DIM, + name="user_id", + data_type=DataType.FP16, + feature_names=["user_id"], + ), + "user_active_degree": EmbeddingConfig( + num_embeddings=8, + embedding_dim=DIM, + name="user_active_degree", + data_type=DataType.FP16, + feature_names=["user_active_degree"], + ), + "follow_user_num_range": EmbeddingConfig( + num_embeddings=9, + embedding_dim=DIM, + name="follow_user_num_range", + data_type=DataType.FP16, + feature_names=["follow_user_num_range"], + ), + "fans_user_num_range": EmbeddingConfig( + num_embeddings=9, + embedding_dim=DIM, + name="fans_user_num_range", + data_type=DataType.FP16, + feature_names=["fans_user_num_range"], + ), + "friend_user_num_range": EmbeddingConfig( + num_embeddings=8, + embedding_dim=DIM, + name="friend_user_num_range", + data_type=DataType.FP16, + feature_names=["friend_user_num_range"], + ), + "register_days_range": EmbeddingConfig( + num_embeddings=8, + embedding_dim=DIM, + name="register_days_range", + data_type=DataType.FP16, + feature_names=["register_days_range"], + ), + } + elif "yambda" in dataset: + assert dataset in ["yambda-5b"] + tables: Dict[str, EmbeddingConfig] = { + "item_id": EmbeddingConfig( + num_embeddings=9_390_624, + embedding_dim=DIM, + name="item_id", + data_type=DataType.FP32, + feature_names=["item_id", "item_candidate_id"], + ), + "artist_id": EmbeddingConfig( + num_embeddings=1_293_395, + embedding_dim=DIM, + name="artist_id", + data_type=DataType.FP32, + feature_names=["artist_id", "item_candidate_artist_id"], + ), + "album_id": EmbeddingConfig( + num_embeddings=3_367_692, + embedding_dim=DIM, + name="album_id", + data_type=DataType.FP32, + feature_names=["album_id", "item_candidate_album_id"], + ), + "uid": EmbeddingConfig( + num_embeddings=1_000_001, + embedding_dim=DIM, + name="uid", + data_type=DataType.FP32, + feature_names=["uid"], + ), + } + for name, _keys, num_embeddings, _salt in YAMBDA_5B_CROSS_SPECS: + tables[name] = EmbeddingConfig( + num_embeddings=num_embeddings, + embedding_dim=DIM, + name=name, + data_type=DataType.FP32, + feature_names=[name], + ) + return tables + else: + return { + "post_id": EmbeddingConfig( + num_embeddings=HASH_SIZE, + embedding_dim=DIM, + name="post_id", + data_type=DataType.FP16, + feature_names=[ + "uih_post_id", + "item_post_id", + "uih_owner_id", + "item_owner_id", + ], + ), + "viewer_id": EmbeddingConfig( + num_embeddings=HASH_SIZE, + embedding_dim=DIM, + name="viewer_id", + data_type=DataType.FP16, + feature_names=["viewer_id"], + ), + "dummy_contexual": EmbeddingConfig( + num_embeddings=HASH_SIZE, + embedding_dim=DIM, + name="dummy_contexual", + data_type=DataType.FP16, + feature_names=["dummy_contexual"], + ), + } diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/datasets/dataset.py b/recommendation_v4/generative_recommenders/dlrm_v3/datasets/dataset.py new file mode 100644 index 000000000..204c06df1 --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/datasets/dataset.py @@ -0,0 +1,461 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-unsafe +""" +Dataset implementations for DLRMv3. + +This module provides dataset classes for loading and processing recommendation +data, including sample containers, collation functions, and random data generation. +""" + +import logging +import time +from dataclasses import dataclass +from typing import Dict, List, Tuple + +import torch +from generative_recommenders.modules.dlrm_hstu import DlrmHSTUConfig +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.streamable import Pipelineable + + +logging.basicConfig(level=logging.INFO) +logger: logging.Logger = logging.getLogger("dlrmv3_dataset") + + +@dataclass +class Samples(Pipelineable): + """ + Container for batched samples with user interaction history and candidate features. + + Attributes: + uih_features_kjt: User interaction history features as KeyedJaggedTensor. + candidates_features_kjt: Candidate item features as KeyedJaggedTensor. + """ + + uih_features_kjt: KeyedJaggedTensor + candidates_features_kjt: KeyedJaggedTensor + # UIH + candidate features concatenated into the single KJT that the model's + # sharded EmbeddingCollection consumes. Pre-built here (dataloader/CPU) rather + # than inside DlrmHSTU.forward so the embedding lookup's input is a plain + # attribute of the batch — which lets TorchRec's TrainPipelineSparseDist hoist + # its input_dist into the prefetch stage (otherwise the runtime cat + + # from_lengths_sync counts as an "input modification" and the embedding + # collection is left un-pipelined). + merged_sparse_features: KeyedJaggedTensor + + def to(self, device: torch.device, non_blocking: bool = False) -> "Samples": + """ + Move all tensors to the specified device (in place) and return self. + + Returning ``self`` (rather than ``None``) and accepting ``non_blocking`` + makes ``Samples`` conform to TorchRec's ``Pipelineable`` protocol so it + can be driven by ``TrainPipelineSparseDist``. Existing call sites that + use ``sample.to(device)`` for its side effect continue to work unchanged. + """ + for attr in vars(self): + setattr( + self, + attr, + getattr(self, attr).to(device=device, non_blocking=non_blocking), + ) + return self + + def record_stream(self, stream: torch.Stream) -> None: + """Record the contained KJTs on ``stream`` (Pipelineable protocol). + + Required by ``TrainPipelineSparseDist`` so the prefetched batch's H2D + copy on the side stream is not freed before compute consumes it. + """ + self.uih_features_kjt.record_stream(stream) + self.candidates_features_kjt.record_stream(stream) + self.merged_sparse_features.record_stream(stream) + + def pin_memory(self) -> "Samples": + """Pin the contained KJTs' host memory (Pipelineable protocol).""" + self.uih_features_kjt = self.uih_features_kjt.pin_memory() + self.candidates_features_kjt = self.candidates_features_kjt.pin_memory() + self.merged_sparse_features = self.merged_sparse_features.pin_memory() + return self + + def batch_size(self) -> int: + """ + Get the batch size of the samples. + + Returns: + Number of samples in the batch. + """ + return self.uih_features_kjt.stride() + + +def merge_uih_candidate_kjts( + uih_features: KeyedJaggedTensor, + candidates_features: KeyedJaggedTensor, +) -> KeyedJaggedTensor: + """Concatenate the UIH and candidate KJTs into the single KJT consumed by the + model's ``EmbeddingCollection``. + + Must mirror ``DlrmHSTU.preprocess`` exactly (key order = uih + candidates, + values/lengths concatenated in that order). Built on the dataloader side so + the model can read it straight off the batch and TorchRec can pipeline the + embedding ``input_dist``. + """ + return KeyedJaggedTensor.from_lengths_sync( + keys=uih_features.keys() + candidates_features.keys(), + values=torch.cat( + [uih_features.values(), candidates_features.values()], + dim=0, + ), + lengths=torch.cat( + [uih_features.lengths(), candidates_features.lengths()], + dim=0, + ), + ) + + +def collate_fn( + samples: List[Tuple[KeyedJaggedTensor, KeyedJaggedTensor]], +) -> Samples: + """ + Collate multiple samples into a batched Samples object. + + Args: + samples: List of (uih_features, candidates_features) tuples. + + Returns: + Batched Samples object with concatenated features. + """ + ( + uih_features_kjt_list, + candidates_features_kjt_list, + ) = list(zip(*samples)) + + uih_features_kjt = kjt_batch_func(uih_features_kjt_list) + candidates_features_kjt = kjt_batch_func(candidates_features_kjt_list) + return Samples( + uih_features_kjt=uih_features_kjt, + candidates_features_kjt=candidates_features_kjt, + merged_sparse_features=merge_uih_candidate_kjts( + uih_features_kjt, candidates_features_kjt + ), + ) + + +class Dataset: + """ + Base dataset class for DLRMv3. + + Provides the interface for loading, accessing, and managing samples + for recommendation model training and inference. + + Args: + hstu_config: HSTU model configuration. + **args: Additional arguments (unused in base class). + """ + + def __init__(self, hstu_config: DlrmHSTUConfig, **args): + self.arrival = None + self.image_list = [] + self.label_list = [] + self.image_list_inmemory = {} + self.last_loaded = -1.0 + + def preprocess(self, use_cache=True): + """ + Preprocess the dataset. + + Args: + use_cache: Whether to use cached preprocessed data. + + Raises: + NotImplementedError: Subclasses must implement this method. + """ + raise NotImplementedError("Dataset:preprocess") + + def get_item_count(self): + """ + Get the total number of items in the dataset. + + Returns: + Number of items. + """ + return len(self.image_list) + + def load_query_samples(self, sample_list): + """ + Load specified samples into memory. + + Args: + sample_list: List of sample indices to load. + + Raises: + NotImplementedError: Subclasses must implement this method. + """ + raise NotImplementedError("Dataset:load_query_samples") + + def unload_query_samples(self, sample_list): + """ + Unload specified samples from memory. + + Args: + sample_list: List of sample indices to unload. + + Raises: + NotImplementedError: Subclasses must implement this method. + """ + raise NotImplementedError("Dataset:unload_query_samples") + + def get_sample(self, id: int): + """ + Get a single sample by ID. + + Args: + id: Sample identifier. + + Raises: + NotImplementedError: Subclasses must implement this method. + """ + raise NotImplementedError("Dataset:get_sample") + + def get_samples(self, id_list: List[int]) -> Samples: + """ + Get multiple samples and collate them into a batch. + + Args: + id_list: List of sample identifiers. + + Returns: + Collated Samples object containing the batch. + """ + list_samples = [self.get_sample(ix) for ix in id_list] + return collate_fn(list_samples) + + +@torch.jit.script +def kjt_batch_func( + kjt_list: List[KeyedJaggedTensor], +) -> KeyedJaggedTensor: + """ + Batch multiple KeyedJaggedTensors into a single tensor. + + Uses FBGEMM operations for efficient batching and reordering of + jagged tensor data. + + Args: + kjt_list: List of KeyedJaggedTensors to batch. + + Returns: + Batched KeyedJaggedTensor with reordered indices and lengths. + """ + bs_list = [kjt.stride() for kjt in kjt_list] + bs = sum(bs_list) + batched_length = torch.cat([kjt.lengths() for kjt in kjt_list], dim=0) + batched_indices = torch.cat([kjt.values() for kjt in kjt_list], dim=0) + bs_offset = torch.ops.fbgemm.asynchronous_complete_cumsum( + torch.tensor(bs_list) + ).int() + batched_offset = torch.ops.fbgemm.asynchronous_complete_cumsum(batched_length) + reorder_length = torch.ops.fbgemm.reorder_batched_ad_lengths( + batched_length, bs_offset, bs + ) + reorder_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(reorder_length) + reorder_indices = torch.ops.fbgemm.reorder_batched_ad_indices( + batched_offset, batched_indices, reorder_offsets, bs_offset, bs + ) + out = KeyedJaggedTensor( + keys=kjt_list[0].keys(), + lengths=reorder_length.long(), + values=reorder_indices.long(), + ) + return out + + +def get_random_data( + contexual_features: List[str], + hstu_uih_keys: List[str], + hstu_candidates_keys: List[str], + uih_max_seq_len: int, + max_num_candidates: int, + value_bound: int = 1000, +): + """ + Generate random sample data for testing and debugging. + + Creates synthetic user interaction history and candidate features + with random values. + + Args: + contexual_features: List of contextual feature names. + hstu_uih_keys: List of UIH feature keys. + hstu_candidates_keys: List of candidate feature keys. + uih_max_seq_len: Maximum sequence length for UIH. + max_num_candidates: Maximum number of candidates. + value_bound: Upper bound for random values. + + Returns: + Tuple of (uih_features_kjt, candidates_features_kjt). + """ + uih_non_seq_feature_keys = contexual_features + uih_seq_feature_keys = [ + k for k in hstu_uih_keys if k not in uih_non_seq_feature_keys + ] + uih_seq_len = torch.randint( + int(uih_max_seq_len * 0.8), + uih_max_seq_len + 1, + (1,), + ).item() + uih_lengths = torch.tensor( + [1 for _ in uih_non_seq_feature_keys] + + [uih_seq_len for _ in uih_seq_feature_keys] + ) + # logging.info(f"uih_lengths: {uih_lengths}") + uih_values = torch.randint( + 1, + value_bound, + # pyre-ignore[6] + (uih_seq_len * len(uih_seq_feature_keys) + len(uih_non_seq_feature_keys),), + ) + uih_features_kjt = KeyedJaggedTensor( + keys=uih_non_seq_feature_keys + uih_seq_feature_keys, + lengths=uih_lengths.long(), + values=uih_values.long(), + ) + num_candidates = torch.randint( + 1, + max_num_candidates + 1, + (1,), + ).item() + candidates_lengths = num_candidates * torch.ones(len(hstu_candidates_keys)) + candidates_values = torch.randint( + 1, + value_bound, + (num_candidates * len(hstu_candidates_keys),), # pyre-ignore[6] + ) + candidates_features_kjt = KeyedJaggedTensor( + keys=hstu_candidates_keys, + lengths=candidates_lengths.long(), + values=candidates_values.long(), + ) + return uih_features_kjt, candidates_features_kjt + + +class DLRMv3RandomDataset(Dataset): + """ + Dataset that generates random synthetic data for DLRMv3. + + Useful for testing and benchmarking without real data dependencies. + + Args: + hstu_config: HSTU model configuration. + num_aggregated_samples: Total number of samples to generate. + is_inference: Whether the dataset is used for inference mode. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + """ + + def __init__( + self, + hstu_config: DlrmHSTUConfig, + num_aggregated_samples: int = 10000, + is_inference: bool = False, + *args, + **kwargs, + ): + super().__init__( + hstu_config=hstu_config, + ) + self.hstu_config: DlrmHSTUConfig = hstu_config + self._max_num_candidates: int = hstu_config.max_num_candidates + self._max_num_candidates_inference: int = ( + hstu_config.max_num_candidates_inference + ) + self._max_seq_len: int = hstu_config.max_seq_len + self._uih_keys: List[str] = hstu_config.hstu_uih_feature_names + self._candidates_keys: List[str] = hstu_config.hstu_candidate_feature_names + self._contextual_feature_to_max_length: Dict[str, int] = ( + hstu_config.contextual_feature_to_max_length + ) + self._max_uih_len: int = ( + self._max_seq_len + - self._max_num_candidates + - ( + len(self._contextual_feature_to_max_length) + if self._contextual_feature_to_max_length + else 0 + ) + ) + self._is_inference = is_inference + + self.contexual_features = [] + if hstu_config.contextual_feature_to_max_length is not None: + self.contexual_features = [ + p[0] for p in hstu_config.contextual_feature_to_max_length + ] + + self.num_aggregated_samples = num_aggregated_samples + self.items_in_memory = {} + + def get_sample(self, id: int) -> Tuple[KeyedJaggedTensor, KeyedJaggedTensor]: + """ + Get a sample by ID from in-memory storage. + + Args: + id: Sample identifier. + + Returns: + Tuple of (uih_features_kjt, candidates_features_kjt). + """ + return self.items_in_memory[id] + + def get_item_count(self): + """ + Get the total number of samples in the dataset. + + Returns: + Number of aggregated samples. + """ + return self.num_aggregated_samples + + def unload_query_samples(self, sample_list): + """ + Clear all samples from memory. + + Args: + sample_list: Ignored; clears all samples. + """ + self.items_in_memory = {} + + def load_query_samples(self, sample_list): + """ + Generate and load random samples into memory. + + Args: + sample_list: List of sample IDs to generate. + """ + max_num_candidates = ( + self._max_num_candidates_inference + if self._is_inference + else self._max_num_candidates + ) + self.items_in_memory = {} + for sample in sample_list: + self.items_in_memory[sample] = get_random_data( + contexual_features=self.contexual_features, + hstu_uih_keys=self.hstu_config.hstu_uih_feature_names, + hstu_candidates_keys=self.hstu_config.hstu_candidate_feature_names, + uih_max_seq_len=self._max_uih_len, + max_num_candidates=max_num_candidates, + ) + self.last_loaded = time.time() diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/datasets/kuairand.py b/recommendation_v4/generative_recommenders/dlrm_v3/datasets/kuairand.py new file mode 100644 index 000000000..f6cd9e672 --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/datasets/kuairand.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-unsafe +import json +import time +from functools import partial +from typing import Any, Dict, List + +import pandas as pd +import torch +from generative_recommenders.dlrm_v3.datasets.dataset import DLRMv3RandomDataset +from generative_recommenders.dlrm_v3.datasets.utils import ( + maybe_truncate_seq, + separate_uih_candidates, +) +from generative_recommenders.modules.dlrm_hstu import DlrmHSTUConfig +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +def process_and_hash_x(x: Any, hash_size: int) -> Any: + if isinstance(x, str): + x = json.loads(x) + if isinstance(x, list): + return [x_i % hash_size for x_i in x] + else: + return x % hash_size + + +class DLRMv3KuaiRandDataset(DLRMv3RandomDataset): + def __init__( + self, + hstu_config: DlrmHSTUConfig, + embedding_config: Dict[str, Any], + seq_logs_file: str, + is_inference: bool, + **kwargs, + ) -> None: + super().__init__(hstu_config=hstu_config, is_inference=is_inference) + self.seq_logs_frame: pd.DataFrame = pd.read_csv(seq_logs_file, delimiter=",") + # apply hashing from embedding table config + for key, table in embedding_config.items(): + assert key in self.seq_logs_frame.columns, ( + "Rename key in embedding table configs!" + ) + hash_size = table.num_embeddings + self.seq_logs_frame[key] = self.seq_logs_frame[key].apply( + partial(process_and_hash_x, hash_size=hash_size) + ) + + def get_item_count(self): + return len(self.seq_logs_frame) + + def unload_query_samples(self, sample_list): + self.items_in_memory = {} + + def load_query_samples(self, sample_list): + max_num_candidates = ( + self._max_num_candidates_inference + if self._is_inference + else self._max_num_candidates + ) + self.items_in_memory = {} + for idx in sample_list: + data = self.seq_logs_frame.iloc[idx] + if len(data.video_id) <= max_num_candidates: + continue + sample = self.load_item(data, max_num_candidates) + self.items_in_memory[idx] = sample + + self.last_loaded = time.time() + + def load_item(self, data, max_num_candidates): + with torch.profiler.record_function("load_item"): + video_history_uih, video_history_candidates = separate_uih_candidates( + data.video_id, + candidates_max_seq_len=max_num_candidates, + ) + action_weights_uih, action_weights_candidates = separate_uih_candidates( + data.action_weights, + candidates_max_seq_len=max_num_candidates, + ) + timestamps_uih, _ = separate_uih_candidates( + data.time_ms, + candidates_max_seq_len=max_num_candidates, + ) + watch_time_uih, watch_time_candidates = separate_uih_candidates( + data.play_time_ms, + candidates_max_seq_len=max_num_candidates, + ) + + video_history_uih = maybe_truncate_seq(video_history_uih, self._max_uih_len) + action_weights_uih = maybe_truncate_seq( + action_weights_uih, self._max_uih_len + ) + timestamps_uih = maybe_truncate_seq(timestamps_uih, self._max_uih_len) + watch_time_uih = maybe_truncate_seq(watch_time_uih, self._max_uih_len) + + uih_seq_len = len(video_history_uih) + assert uih_seq_len == len(timestamps_uih), ( + "history len differs from timestamp len." + ) + assert uih_seq_len == len(action_weights_uih), ( + "history len differs from weights len." + ) + assert uih_seq_len == len(watch_time_uih), ( + "history len differs from watch time len." + ) + + uih_kjt_values: List[torch.Tensor] = [] + uih_kjt_lengths: List[torch.Tensor] = [] + for name, length in self._contextual_feature_to_max_length.items(): + uih_kjt_values.append(data[name]) + uih_kjt_lengths.append(length) + + uih_kjt_values.extend( + video_history_uih + timestamps_uih + action_weights_uih + watch_time_uih + ) + + uih_kjt_lengths.extend( + [ + uih_seq_len + for _ in range( + len(self._uih_keys) + - len(self._contextual_feature_to_max_length) + ) + ] + ) + + dummy_query_time = max(timestamps_uih) + uih_features_kjt = KeyedJaggedTensor( + keys=self._uih_keys, + lengths=torch.tensor(uih_kjt_lengths).long(), + values=torch.tensor(uih_kjt_values).long(), + ) + + candidates_kjt_lengths = max_num_candidates * torch.ones( + len(self._candidates_keys) + ) + candidates_kjt_values = ( + video_history_candidates + + action_weights_candidates + + watch_time_candidates + + [dummy_query_time] * max_num_candidates + ) + candidates_features_kjt = KeyedJaggedTensor( + keys=self._candidates_keys, + lengths=torch.tensor(candidates_kjt_lengths).long(), + values=torch.tensor(candidates_kjt_values).long(), + ) + + return uih_features_kjt, candidates_features_kjt diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/datasets/movie_lens.py b/recommendation_v4/generative_recommenders/dlrm_v3/datasets/movie_lens.py new file mode 100644 index 000000000..d74fb575b --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/datasets/movie_lens.py @@ -0,0 +1,177 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-unsafe +import logging +import time +from typing import List, Optional + +import pandas as pd +import torch +from generative_recommenders.dlrm_v3.datasets.dataset import DLRMv3RandomDataset +from generative_recommenders.dlrm_v3.datasets.utils import ( + maybe_truncate_seq, + separate_uih_candidates, +) +from generative_recommenders.modules.dlrm_hstu import DlrmHSTUConfig +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + +logger = logging.getLogger(__name__) + + +class DLRMv3MovieLensDataset(DLRMv3RandomDataset): + def __init__( + self, + hstu_config: DlrmHSTUConfig, + ratings_file: str, + is_inference: bool, + *args, + **kwargs, + ): + super().__init__(hstu_config=hstu_config, is_inference=is_inference) + self.ratings_frame: Optional[pd.DataFrame] = None + if ratings_file != "": + self.ratings_frame = pd.read_csv( + ratings_file, + delimiter=",", + ) + assert hstu_config.action_weights is not None + self.action_weights: List[int] = hstu_config.action_weights + + def get_item_count(self): + assert self.ratings_frame is not None + return len(self.ratings_frame) + + def unload_query_samples(self, sample_list): + self.items_in_memory = {} + + def iloc(self, idx): + assert self.ratings_frame is not None + return self.ratings_frame.iloc[idx] + + def load_query_samples(self, sample_list): + max_num_candidates = ( + self._max_num_candidates_inference + if self._is_inference + else self._max_num_candidates + ) + self.items_in_memory = {} + for idx in sample_list: + data = self.iloc(idx) + if len(data.sequence_item_ids) <= max_num_candidates: + continue + sample = self.load_item(data, max_num_candidates) + self.items_in_memory[idx] = sample + + self.last_loaded = time.time() + + def get_timestamp_uih(self, data, max_num_candidates, size): + movie_timestamps_uih, _ = separate_uih_candidates( + data.sequence_timestamps, + candidates_max_seq_len=max_num_candidates, + ) + return movie_timestamps_uih + + def load_item(self, data, max_num_candidates): + movie_history_uih, movie_history_candidates = separate_uih_candidates( + data.sequence_item_ids, + candidates_max_seq_len=max_num_candidates, + ) + movie_history_ratings_uih, movie_history_ratings_candidates = ( + separate_uih_candidates( + data.sequence_ratings, + candidates_max_seq_len=max_num_candidates, + ) + ) + movie_timestamps_uih = self.get_timestamp_uih( + data=data, + max_num_candidates=max_num_candidates, + size=len(movie_history_uih), + ) + + assert len(movie_history_uih) == len(movie_timestamps_uih), ( + "history len differs from timestamp len." + ) + assert len(movie_history_uih) == len(movie_history_ratings_uih), ( + "history len differs from ratings len." + ) + + movie_history_uih = maybe_truncate_seq(movie_history_uih, self._max_uih_len) + movie_history_ratings_uih = maybe_truncate_seq( + movie_history_ratings_uih, self._max_uih_len + ) + movie_timestamps_uih = maybe_truncate_seq( + movie_timestamps_uih, self._max_uih_len + ) + + uih_kjt_values: List[torch.Tensor] = [] + uih_kjt_lengths: List[torch.Tensor] = [] + for name, length in self._contextual_feature_to_max_length.items(): + uih_kjt_values.append(data[name]) + uih_kjt_lengths.append(length) + + uih_seq_len = len(movie_history_uih) + movie_dummy_watch_times_uih = [0 for _ in range(uih_seq_len)] + action_weights_uih = [ + self.action_weights[int(rating) - 1] for rating in movie_history_ratings_uih + ] + uih_kjt_values.extend( + movie_history_uih + + movie_history_ratings_uih + + movie_timestamps_uih + + action_weights_uih + + movie_dummy_watch_times_uih + ) + uih_kjt_lengths.extend( + [ + uih_seq_len + for _ in range( + len(self._uih_keys) - len(self._contextual_feature_to_max_length) + ) + ] + ) + + dummy_query_time = ( + 0 if movie_timestamps_uih == [] else max(movie_timestamps_uih) + ) + uih_kjt_values.append(dummy_query_time) + uih_kjt_lengths.append(1) + uih_features_kjt = KeyedJaggedTensor( + keys=self._uih_keys + ["dummy_query_time"], + lengths=torch.tensor(uih_kjt_lengths).long(), + values=torch.tensor(uih_kjt_values).long(), + ) + + candidates_kjt_lengths = max_num_candidates * torch.ones( + len(self._candidates_keys) + ) + action_weights_candidates = [ + int(rating >= 3.5) for rating in movie_history_ratings_candidates + ] + candidates_kjt_values = ( + movie_history_candidates + + movie_history_ratings_candidates + + [dummy_query_time] * max_num_candidates # item_query_time + + action_weights_candidates + + [1] * max_num_candidates # item_dummy_watchtime + ) + candidates_features_kjt = KeyedJaggedTensor( + keys=self._candidates_keys, + lengths=candidates_kjt_lengths.detach().clone().long(), + values=torch.tensor(candidates_kjt_values).long(), + ) + return ( + uih_features_kjt, + candidates_features_kjt, + ) diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/datasets/synthetic_movie_lens.py b/recommendation_v4/generative_recommenders/dlrm_v3/datasets/synthetic_movie_lens.py new file mode 100644 index 000000000..6cf8a5f56 --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/datasets/synthetic_movie_lens.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-unsafe +import csv +import linecache +import logging +import sys +from typing import List + +import numpy as np +import pandas as pd +from generative_recommenders.dlrm_v3.datasets.movie_lens import DLRMv3MovieLensDataset +from generative_recommenders.modules.dlrm_hstu import DlrmHSTUConfig + +csv.field_size_limit(sys.maxsize) +logger = logging.getLogger(__name__) + + +class DLRMv3SyntheticMovieLensDataset(DLRMv3MovieLensDataset): + def __init__( + self, + hstu_config: DlrmHSTUConfig, + ratings_file_prefix: str, + is_inference: bool, + *args, + **kwargs, + ): + super().__init__( + hstu_config=hstu_config, is_inference=is_inference, ratings_file="" + ) + self.ratings_file_prefix = ratings_file_prefix + with open(f"{self.ratings_file_prefix}_users.csv", "r") as file: + reader = csv.reader(file) + self.users_cumsum: List[int] = np.cumsum( + [int(row[1]) for row in reader] + ).tolist() + + def get_item_count(self): + return self.users_cumsum[-1] + + def _process_line(self, line: str) -> pd.Series: + reader = csv.reader([line]) + parsed_line = next(reader) + user_id = int(parsed_line[0]) + sequence_item_ids = parsed_line[1] + sequence_ratings = parsed_line[2] + return pd.Series( + data={ + "user_id": user_id, + "sequence_item_ids": sequence_item_ids, + "sequence_ratings": sequence_ratings, + } + ) + + def iloc(self, idx) -> pd.Series: + assert idx < self.users_cumsum[-1] + file_idx: int = 0 + while self.users_cumsum[file_idx] <= idx: + file_idx += 1 + if file_idx == 0: + local_idx = idx + else: + local_idx = idx - self.users_cumsum[file_idx - 1] + line = linecache.getline( + f"{self.ratings_file_prefix}_{file_idx}.csv", local_idx + 1 + ) + data = self._process_line(line) + return data + + def get_timestamp_uih(self, data, max_num_candidates, size): + return [1] * size diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/datasets/synthetic_streaming.py b/recommendation_v4/generative_recommenders/dlrm_v3/datasets/synthetic_streaming.py new file mode 100644 index 000000000..6e38fe334 --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/datasets/synthetic_streaming.py @@ -0,0 +1,403 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict +""" +Synthetic streaming dataset for DLRMv3 inference benchmarking. + +This module provides a streaming dataset implementation that loads user interaction +data from pre-generated CSV files with temporal (timestamp) organization, suitable +for simulating real-time recommendation scenarios. +""" + +import csv +import logging +import sys +import time +from typing import Any, Dict, List, Set, Tuple + +import pandas as pd +import torch +from generative_recommenders.dlrm_v3.datasets.dataset import ( + collate_fn, + DLRMv3RandomDataset, + Samples, +) +from generative_recommenders.dlrm_v3.datasets.utils import ( + json_loads, + maybe_truncate_seq, +) +from generative_recommenders.modules.dlrm_hstu import DlrmHSTUConfig +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + +csv.field_size_limit(sys.maxsize) +logger: logging.Logger = logging.getLogger(__name__) + + +class DLRMv3SyntheticStreamingDataset(DLRMv3RandomDataset): + """ + Streaming dataset that loads pre-generated synthetic recommendation data. + + Supports timestamp-based data organization for simulating streaming scenarios + where user interaction histories evolve over time. + + Args: + hstu_config: HSTU model configuration. + ratings_file_prefix: Path prefix for rating data files. + is_inference: Whether dataset is used for inference. + train_ts: Number of timestamps used for training. + total_ts: Total number of timestamps in the data. + num_files: Number of data files (for parallelization). + num_users: Total number of users in the dataset. + num_items: Total number of items in the catalog. + num_categories: Number of item categories. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + """ + + def __init__( + self, + hstu_config: DlrmHSTUConfig, + ratings_file_prefix: str, + is_inference: bool, + train_ts: int, + total_ts: int, + num_files: int, + num_users: int, + num_items: int, + num_categories: int, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(hstu_config=hstu_config, is_inference=is_inference) + self.ratings_file_prefix = ratings_file_prefix + self.file_to_offsets: Dict[int, List[int]] = {} + with open(f"{self.ratings_file_prefix}offset.csv", "r") as file: + reader = csv.reader(file) + for size in range(num_files): + row = next(reader) + assert len(row) == 1 + offset = json_loads(row[0]) + assert len(offset) == num_users // num_files + self.file_to_offsets[size] = offset + self.ts_requests_offsets: List[int] = [] + with open(f"{self.ratings_file_prefix}requests_per_ts_offset.csv", "r") as file: + reader = csv.reader(file) + row = next(reader) + assert len(row) == 1 + self.ts_requests_offsets = json_loads(row[0]) + assert len(self.ts_requests_offsets) == total_ts + self.requests: List[int] = [] + self.ts_to_users_cumsum: Dict[int, List[int]] = {} + with open( + f"{self.ratings_file_prefix}users_cumsum_per_ts.csv", "r" + ) as cumsum_file: + reader = csv.reader(cumsum_file) + ts = 0 + for row in reader: + assert len(row) == 1 + cumsum = json_loads(row[0]) + self.ts_to_users_cumsum[ts] = cumsum + ts += 1 + self.train_ts = train_ts + self.total_ts = total_ts + self.num_files = num_files + self.ts: int = -1 + self.is_inference: bool = False + self.is_eval: bool = False + self.users_per_file: int = num_users // num_files + self.cached_files: Set[str] = set() + self.items_per_category: int = num_items // num_categories + assert hstu_config.action_weights is not None + self.action_weights: List[int] = hstu_config.action_weights + self.items_in_memory: Dict[ + int, Dict[int, Tuple[KeyedJaggedTensor, KeyedJaggedTensor]] + ] = {} + + def get_item_count(self) -> int: + return len(self.requests) + + def load_query_samples(self, sample_list: List[int]) -> None: + max_num_candidates = ( + self._max_num_candidates_inference + if self._is_inference + else self._max_num_candidates + ) + for idx in sample_list: + data = self.iloc(idx) + sample = self.load_item(data, max_num_candidates) + if self.ts not in self.items_in_memory: + self.items_in_memory[self.ts] = {} + self.items_in_memory[self.ts][idx] = sample + + self.last_loaded = time.time() + + def unload_query_samples(self, sample_list: List[int]) -> None: + self.items_in_memory = {} + + def get_sample(self, id: int) -> Tuple[KeyedJaggedTensor, KeyedJaggedTensor]: + return self.items_in_memory[self.ts][id] + + def get_sample_with_ts( + self, id: int, ts: int + ) -> Tuple[KeyedJaggedTensor, KeyedJaggedTensor]: + """ + Get a sample for a specific timestamp. + + Args: + id: Sample identifier. + ts: Timestamp index. + + Returns: + Tuple of (uih_features_kjt, candidates_features_kjt). + """ + return self.items_in_memory[ts][id] + + def get_samples_with_ts(self, id_list: List[int], ts: int) -> Samples: + """ + Get and collate multiple samples for a specific timestamp. + + Args: + id_list: List of sample identifiers. + ts: Timestamp index. + + Returns: + Collated Samples object. + """ + list_samples = [self.get_sample_with_ts(ix, ts) for ix in id_list] + return collate_fn(list_samples) + + def _process_line(self, line: str, user_id: int) -> pd.Series: + """ + Parse a CSV line into a pandas Series with user interaction data. + + Args: + line: CSV line containing user data. + user_id: User identifier. + + Returns: + pd.Series with parsed user interaction history and candidates. + """ + reader = csv.reader([line]) + parsed_line = next(reader) + # total ts + one more eval ts + one base ts so that uih won't be zero + # for each ts, ordered as candidate_ids, candidate_ratings, uih_ids, uih_ratings + assert len(parsed_line) == 4 * (self.total_ts + 2) + uih_item_ids_list = [] + uih_ratings_list = [] + candidate_item_ids = "" + candidate_ratings = "" + if (not self.is_eval) and (not self.is_inference): + assert self.ts < self.train_ts + for i in range(self.ts + 1): + if parsed_line[4 * i]: + uih_item_ids_list.append(parsed_line[2 + 4 * i]) + uih_ratings_list.append(parsed_line[3 + 4 * i]) + candidate_item_ids = parsed_line[4 * (self.ts + 1)] + candidate_ratings = parsed_line[1 + 4 * (self.ts + 1)] + elif self.is_eval: + for i in range(self.ts + 1): + if parsed_line[4 * i]: + uih_item_ids_list.append(parsed_line[2 + 4 * i]) + uih_ratings_list.append(parsed_line[3 + 4 * i]) + candidate_item_ids = parsed_line[4 * (self.ts + 1)] + candidate_ratings = parsed_line[1 + 4 * (self.ts + 1)] + else: + assert self.is_inference is True + assert self.ts >= self.train_ts + for i in range(self.train_ts + 1): + if parsed_line[4 * i]: + uih_item_ids_list.append(parsed_line[2 + 4 * i]) + uih_ratings_list.append(parsed_line[3 + 4 * i]) + for i in range(self.train_ts + 2, self.ts + 2): + if parsed_line[4 * i]: + uih_item_ids_list.append(parsed_line[2 + 4 * i]) + uih_ratings_list.append(parsed_line[3 + 4 * i]) + candidate_item_ids = parsed_line[4 * (self.ts + 2)] + candidate_ratings = parsed_line[1 + 4 * (self.ts + 2)] + uih_item_ids = ",".join(uih_item_ids_list) + uih_ratings = ",".join(uih_ratings_list) + assert candidate_item_ids != "" and candidate_ratings != "" + return pd.Series( + data={ + "user_id": user_id, + "uih_item_ids": uih_item_ids, + "uih_ratings": uih_ratings, + "candidate_item_ids": candidate_item_ids, + "candidate_ratings": candidate_ratings, + } + ) + + def iloc(self, idx: int) -> pd.Series: + """ + Get user data by request index using file offsets for efficient access. + + Args: + idx: Request index within the current timestamp. + + Returns: + pd.Series with parsed user interaction data. + """ + cumsum: List[int] = self.ts_to_users_cumsum[self.ts] + assert cumsum != [] + assert idx < cumsum[-1] + file_idx: int = 0 + while cumsum[file_idx] <= idx: + file_idx += 1 + user_idx = self.requests[idx] + filename = f"{self.ratings_file_prefix}{file_idx}.csv" + with open(filename, "r") as file: + idx = user_idx % self.users_per_file + file.seek(self.file_to_offsets[file_idx][idx]) + line = file.readline() + data = self._process_line(line=line, user_id=user_idx) + return data + + def get_timestamp_uih( + self, data: pd.Series, max_num_candidates: int, size: int + ) -> List[int]: + return [1] * size + + def set_ts(self, ts: int, train_only: bool = False) -> None: + """ + Set the current timestamp and load associated request data. + + Args: + ts: Timestamp index to set. + train_only: Accepted for API parity with the yambda dataset (which + supports a user-level train:eval holdout). This synthetic + dataset has no holdout, so the flag is ignored. + """ + logger.warning(f"Streaming dataset ts set to {ts}") + if ts == self.ts: + return + self.ts = ts + with open( + f"{self.ratings_file_prefix}requests_per_ts.csv", "r" + ) as request_file: + request_file.seek(self.ts_requests_offsets[self.ts]) + line = request_file.readline() + reader = csv.reader([line]) + row = next(reader) + assert len(row) == 1 + requests = json_loads(row[0]) + self.requests = requests + logger.warning(f"DLRMv3SyntheticStreamingDataset: ts={ts} requests loaded") + assert self.ts_to_users_cumsum[self.ts][-1] == len(self.requests) + logger.warning( + f"DLRMv3SyntheticStreamingDataset: ts={ts} users_cumsum={self.ts_to_users_cumsum[self.ts]}" + ) + + def load_item( + self, data: pd.Series, max_num_candidates: int + ) -> Tuple[KeyedJaggedTensor, KeyedJaggedTensor]: + """ + Load and process a single user's data into KeyedJaggedTensors. + + Converts parsed user data into feature tensors suitable for model input, + including truncation to maximum sequence lengths. + + Args: + data: pd.Series with user interaction history and candidates. + max_num_candidates: Maximum number of candidates to include. + + Returns: + Tuple of (uih_features_kjt, candidates_features_kjt). + """ + ids_uih = json_loads(data.uih_item_ids) + ids_candidates = json_loads(data.candidate_item_ids) + ratings_uih = json_loads(data.uih_ratings) + ratings_candidates = json_loads(data.candidate_ratings) + timestamps_uih = self.get_timestamp_uih( + data=data, + max_num_candidates=max_num_candidates, + size=len(ids_uih), + ) + assert len(ids_uih) == len(timestamps_uih), ( + "history len differs from timestamp len." + ) + assert len(ids_uih) == len(ratings_uih), ( + f"history len {len(ids_uih)} differs from ratings len {len(ratings_uih)}." + ) + assert len(ids_candidates) == len(ratings_candidates), ( + f"candidates len {len(ids_candidates)} differs from ratings len {len(ratings_candidates)}." + ) + + ids_uih = maybe_truncate_seq(ids_uih, self._max_uih_len) + ratings_uih = maybe_truncate_seq(ratings_uih, self._max_uih_len) + timestamps_uih = maybe_truncate_seq(timestamps_uih, self._max_uih_len) + ids_candidates = maybe_truncate_seq(ids_candidates, max_num_candidates) + num_candidates = len(ids_candidates) + ratings_candidates = maybe_truncate_seq(ratings_candidates, max_num_candidates) + action_weights_uih = [ + self.action_weights[int(rating) - 1] for rating in ratings_uih + ] + action_weights_candidates = [ + int(rating >= 3.5) for rating in ratings_candidates + ] + + uih_kjt_values: List[int] = [] + uih_kjt_lengths: List[int] = [] + for name, length in self._contextual_feature_to_max_length.items(): + uih_kjt_values.append(data[name]) + uih_kjt_lengths.append(length) + + uih_seq_len = len(ids_uih) + dummy_watch_times_uih = [0 for _ in range(uih_seq_len)] + item_category_ids = [id // self.items_per_category for id in ids_uih] + extend_uih_kjt_values: List[int] = ( + ids_uih + + ratings_uih + + timestamps_uih + + action_weights_uih + + dummy_watch_times_uih + + item_category_ids + ) + uih_kjt_values.extend(extend_uih_kjt_values) + uih_kjt_lengths.extend( + [ + uih_seq_len + for _ in range( + len(self._uih_keys) - len(self._contextual_feature_to_max_length) + ) + ] + ) + + dummy_query_time = 0 if timestamps_uih == [] else max(timestamps_uih) + uih_kjt_values.append(dummy_query_time) + uih_kjt_lengths.append(1) + uih_features_kjt: KeyedJaggedTensor = KeyedJaggedTensor( + keys=self._uih_keys + ["dummy_query_time"], + lengths=torch.tensor(uih_kjt_lengths).long(), + values=torch.tensor(uih_kjt_values).long(), + ) + + candidates_kjt_lengths = num_candidates * torch.ones(len(self._candidates_keys)) + item_candidate_category_ids = [ + id // self.items_per_category for id in ids_candidates + ] + candidates_kjt_values = ( + ids_candidates + + ratings_candidates + + [dummy_query_time] * num_candidates # item_query_time + + action_weights_candidates + + [1] * num_candidates # item_dummy_watchtime + + item_candidate_category_ids + ) + candidates_features_kjt: KeyedJaggedTensor = KeyedJaggedTensor( + keys=self._candidates_keys, + lengths=candidates_kjt_lengths.detach().clone().long(), + values=torch.tensor(candidates_kjt_values).long(), + ) + return uih_features_kjt, candidates_features_kjt diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/datasets/utils.py b/recommendation_v4/generative_recommenders/dlrm_v3/datasets/utils.py new file mode 100644 index 000000000..aeca75d41 --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/datasets/utils.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-unsafe +""" +Utility functions for dataset processing. + +This module provides helper functions for parsing and processing data +in the DLRMv3 dataset pipeline. +""" + +import json +import struct +from typing import Dict, List, Sequence, Tuple + +import numpy as np +import xxhash + + +def json_loads( + x: str | int | List[int], +) -> List[int]: + """ + Parse a JSON-like string into a list of integers. + + Handles multiple input formats including JSON arrays, comma-separated + strings, and single values. + + Args: + x: Input that can be a JSON array string, a single integer, + or already a list of integers. + + Returns: + List of integers parsed from the input. + """ + if isinstance(x, str): + if x[0] != "[" and x[-1] != "]": + x = "[" + x + "]" + y = json.loads(x) + else: + y = x + y_list = [y] if type(y) == int else list(y) + return y_list + + +def separate_uih_candidates( + x: str | int | List[int], + candidates_max_seq_len: int, +) -> Tuple[List[int], List[int]]: + """ + Separate a sequence into user interaction history (UIH) and candidates. + + Splits the input sequence such that the last `candidates_max_seq_len` + elements become candidates and the rest become UIH. + + Args: + x: Input sequence as JSON string, single int, or list of ints. + candidates_max_seq_len: Number of items at the end to use as candidates. + + Returns: + Tuple of (uih, candidates) where both are lists of integers. + """ + if isinstance(x, str): + if x[0] != "[" and x[-1] != "]": + x = "[" + x + "]" + y = json.loads(x) + else: + y = x + y_list = [y] if type(y) == int else list(y) + candidates, uih = ( + y_list[-candidates_max_seq_len:], + y_list[:-candidates_max_seq_len], + ) + return uih, candidates + + +def maybe_truncate_seq( + y: List[int], + max_seq_len: int, +) -> List[int]: + """ + Truncate a sequence if it exceeds the maximum length. + + Args: + y: Input sequence to potentially truncate. + max_seq_len: Maximum allowed sequence length. + + Returns: + The input sequence, truncated to max_seq_len if necessary. + """ + y_len = len(y) + if y_len > max_seq_len: + y = y[:max_seq_len] + return y + + +def xxhash_cross( + anchor: Dict[str, int], + keys: Sequence[str], + table_size: int, + salt: int = 0, +) -> int: + """xxhash64(seed=salt) over little-endian int64 concat(anchor[k] for k in keys), mod table_size. + + Bit-identical to primus_dlrm.data.hashing.cross_hash_nway — embedding rows + are interchangeable with Primus-trained ones. + """ + n = len(keys) + assert n >= 2, f"xxhash_cross needs >=2 keys, got {n}" + digest = xxhash.xxh64(seed=salt) + digest.update(struct.Struct(f"<{n}q").pack(*(int(anchor[k]) for k in keys))) + return digest.intdigest() % table_size + + +def xxhash_cross_batch( + arr_by_key: Dict[str, np.ndarray], + keys: Sequence[str], + table_size: int, + salt: int = 0, +) -> np.ndarray: + """Vectorised xxhash_cross over equal-length int64 arrays (one per key).""" + n = len(keys) + assert n >= 2 + cols = [np.asarray(arr_by_key[k], dtype=np.int64).ravel() for k in keys] + length = cols[0].shape[0] + for c in cols: + assert c.shape[0] == length + pack = struct.Struct(f"<{n}q").pack + digest_cls = xxhash.xxh64 + out = np.empty(length, dtype=np.int64) + for i in range(length): + d = digest_cls(seed=salt) + d.update(pack(*(int(c[i]) for c in cols))) + out[i] = d.intdigest() % table_size + return out diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/datasets/yambda.py b/recommendation_v4/generative_recommenders/dlrm_v3/datasets/yambda.py new file mode 100644 index 000000000..00b22cff9 --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/datasets/yambda.py @@ -0,0 +1,1066 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +# pyre-unsafe +""" +Yambda dataset for the DLRMv3 HSTU `modules/` path. + +Reads the parquets produced by `dlrm_v3/preprocess_public_data.py +--dataset yambda-`. Each sample is one anchor LISTEN event with: + * label = (played_ratio >= LISTEN_PLUS_THRESHOLD) — the listen_plus bit + * a chronologically interleaved 3-pool history (listen+/like/skip), with + pool identity tagged per-position in `action_weight` (bits 1/2/4) + * 7 pre-hashed cross-feature ids exposed as length-1 contextual entries +""" + +import logging +import mmap as _mmap_mod +import os +import time +from pathlib import Path +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +import polars as pl +import torch +from generative_recommenders.dlrm_v3.datasets.dataset import DLRMv3RandomDataset +from generative_recommenders.dlrm_v3.datasets.utils import xxhash_cross +from generative_recommenders.modules.dlrm_hstu import DlrmHSTUConfig +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + +logger = logging.getLogger(__name__) + + +def _load_npy_readonly(path: Union[str, Path]) -> np.ndarray: + # MAP_SHARED + PROT_READ so the kernel does not charge the mapping against + # vm.overcommit_memory=2 limits. numpy's mmap_mode='r' uses MAP_PRIVATE and + # reserves per-process commit; at 8 ranks × ~190 GB store, that OOMs. + path = Path(path) + with open(path, "rb") as f: + version = np.lib.format.read_magic(f) + if version[0] == 1: + shape, _, dtype = np.lib.format.read_array_header_1_0(f) + else: + shape, _, dtype = np.lib.format.read_array_header_2_0(f) + offset = f.tell() + fd = os.open(str(path), os.O_RDONLY) + try: + buf = _mmap_mod.mmap(fd, 0, access=_mmap_mod.ACCESS_READ) + finally: + os.close(fd) + arr = np.ndarray(shape, dtype=dtype, buffer=buf, offset=offset) + arr.flags.writeable = False + return arr + + +def _uid_unit_hash(uids: np.ndarray, salt: int) -> np.ndarray: + """Deterministic uniform-in-[0,1) hash of user ids (splitmix64 finalizer). + + Pure function of (uid, salt): the same uid always maps to the same value, + so the train/eval user split is identical across processes, ranks, and + crash/resume — the property the no-leakage holdout relies on. Vectorized + uint64 arithmetic wraps mod 2**64 (defined for unsigned), so we silence the + benign overflow warnings. + """ + GOLDEN = np.uint64(0x9E3779B97F4A7C15) + M1 = np.uint64(0xBF58476D1CE4E5B9) + M2 = np.uint64(0x94D049BB133111EB) + s30, s27, s31 = np.uint64(30), np.uint64(27), np.uint64(31) + with np.errstate(over="ignore"): + z = uids.astype(np.uint64) + GOLDEN + np.uint64(salt & 0xFFFFFFFFFFFFFFFF) + z = (z ^ (z >> s30)) * M1 + z = (z ^ (z >> s27)) * M2 + z = z ^ (z >> s31) + # Top 53 bits -> uniform [0, 1) double (same trick numpy uses for randoms). + return (z >> np.uint64(11)).astype(np.float64) * (1.0 / 9007199254740992.0) + +# Yambda event-type encoding written by preprocess_public_data.py. +LISTEN_TYPE = 0 +LIKE_TYPE = 1 +LISTEN_PLUS_THRESHOLD = 50 + +# Action-weight bits (must match hstu_config.action_weights = [1, 2, 4]). +LP_BIT = 1 +LIKE_BIT = 2 +SKIP_BIT = 4 + + +class _FlatEventStore: + """Per-user flat event index built from the preprocessed sessions parquet. + + Reads `train_sessions.parquet` and explodes per-session arrays into flat + numpy columns + per-user `(start, end)` index arrays. Cache-compatible + layout, but writes nothing (rebuilds from parquet each construction). + """ + + # On-disk column layout. + _MMAP_COLS = ( + "flat_uid", "flat_item_ids", "flat_timestamps", + "flat_event_types", "flat_played_ratio", + "flat_is_listen_plus", "flat_is_like", "flat_is_skip", + "flat_is_organic", + "user_start", "user_end", "unique_uids", + ) + + def __init__(self, sessions_df: pl.DataFrame) -> None: + logger.info("Building flat event store from sessions...") + sorted_sessions = sessions_df.sort(["uid", "session_id"]) + exploded = sorted_sessions.explode( + ["item_ids", "timestamps", "event_types", "is_organic", "played_ratio_pct"] + ) + + self.flat_uid: np.ndarray = exploded["uid"].to_numpy().astype(np.int64) + self.flat_item_ids: np.ndarray = exploded["item_ids"].to_numpy().astype(np.int64) + self.flat_timestamps: np.ndarray = exploded["timestamps"].to_numpy().astype(np.int64) + self.flat_event_types: np.ndarray = exploded["event_types"].to_numpy().astype(np.int64) + self.flat_played_ratio: np.ndarray = exploded["played_ratio_pct"].to_numpy().astype(np.float32) + self.flat_is_organic: np.ndarray = exploded["is_organic"].to_numpy().astype(np.int8) + np.nan_to_num(self.flat_played_ratio, copy=False, nan=0.0) + + is_listen = self.flat_event_types == LISTEN_TYPE + self.flat_is_listen_plus: np.ndarray = is_listen & ( + self.flat_played_ratio >= LISTEN_PLUS_THRESHOLD + ) + self.flat_is_like: np.ndarray = self.flat_event_types == LIKE_TYPE + self.flat_is_skip: np.ndarray = is_listen & ( + self.flat_played_ratio < LISTEN_PLUS_THRESHOLD + ) + + uid_changes = np.where(np.diff(self.flat_uid) != 0)[0] + 1 + starts = np.concatenate([[0], uid_changes]) + ends = np.concatenate([uid_changes, [len(self.flat_uid)]]) + uid_vals = self.flat_uid[starts] + max_uid = int(uid_vals.max()) + 1 + self.user_start: np.ndarray = np.full(max_uid, -1, dtype=np.int64) + self.user_end: np.ndarray = np.full(max_uid, -1, dtype=np.int64) + self.user_start[uid_vals] = starts + self.user_end[uid_vals] = ends + self.unique_uids: np.ndarray = uid_vals + self.num_users: int = len(uid_vals) + self.total_events: int = len(self.flat_item_ids) + logger.info( + f"FlatEventStore: {self.total_events:,} events, {self.num_users:,} users" + ) + + @classmethod + def load_mmap(cls, cache_dir: Union[str, Path]) -> "_FlatEventStore": + """Load flat columns by MAP_SHARED+PROT_READ from a prebuilt cache. + All ranks on a node share the same physical pages.""" + import json as _json + cache_dir = Path(cache_dir) + with open(cache_dir / "store_meta.json") as f: + meta = _json.load(f) + store = object.__new__(cls) + for name in cls._MMAP_COLS: + setattr(store, name, _load_npy_readonly(cache_dir / f"{name}.npy")) + store.num_users = int(meta["num_users"]) + store.total_events = int(meta["total_events"]) + logger.info( + f"FlatEventStore mmap from {cache_dir}: " + f"{store.total_events:,} events, {store.num_users:,} users" + ) + return store + + def save_mmap(self, cache_dir: Union[str, Path]) -> None: + """Persist flat columns to disk as .npy, then write a sentinel. + Subsequent runs (any rank, any node sharing the FS) load via mmap.""" + import json as _json + cache_dir = Path(cache_dir) + cache_dir.mkdir(parents=True, exist_ok=True) + for name in self._MMAP_COLS: + np.save(cache_dir / f"{name}.npy", getattr(self, name)) + with open(cache_dir / "store_meta.json", "w") as f: + _json.dump( + {"num_users": self.num_users, "total_events": self.total_events}, f + ) + # Sentinel — readers check this before mmap'ing to avoid partial files. + (cache_dir / "_READY").touch() + logger.info(f"FlatEventStore saved to {cache_dir}") + + +class DLRMv3YambdaDataset(DLRMv3RandomDataset): + """Yambda-5b dataset for the DLRMv3 HSTU modules/ path. + + Args: + hstu_config: DlrmHSTUConfig (must come from `get_hstu_configs("yambda-5b")`). + processed_dir: directory with `train_sessions.parquet` + `item_popularity.npy`. + metadata_dir: directory with `{artist,album}_item_mapping.parquet`. + history_length: UIH cap. Under "interleaved" it is the per-pool cap + (total ≤ 3 * history_length // 3); under "last_n" it is the literal + total number of pooled events kept. + scan_window: how far back to scan when filling each pool. + history_strategy: "interleaved" (equal per-pool L//3 cap, re-interleaved) + or "last_n" (last history_length pooled events, no per-pool split). + cross_specs: list of (name, keys, num_embeddings, salt). Source of truth + in `dlrm_v3/configs.py:YAMBDA_5B_CROSS_SPECS`. + is_inference: passed through to base class. + """ + + def __init__( + self, + hstu_config: DlrmHSTUConfig, + processed_dir: str, + metadata_dir: str, + history_length: int = 2048, + scan_window: int = 20000, + min_history: Optional[int] = None, + history_strategy: str = "interleaved", + cross_specs: Optional[Sequence[Tuple[str, Sequence[str], int, int]]] = None, + cache_dir: Optional[str] = None, + is_inference: bool = False, + streaming_window_seconds: int = 86400, + streaming_sort_within_window: bool = False, + streaming_shuffle_fraction: float = 0.0, + streaming_shuffle_seed: int = 0, + train_split_percentage: float = 1.0, + split_salt: int = 0, + *args, + **kwargs, + ) -> None: + super().__init__(hstu_config=hstu_config, is_inference=is_inference) + self._processed_dir: str = processed_dir + self._metadata_dir: str = metadata_dir + self._history_length: int = history_length + self._scan_window: int = scan_window + # UIH construction strategy: + # "interleaved" (default) — equal history_length//3 cap per behavior + # pool (listen+/like/skip), re-interleaved chronologically. Likes are + # ~1.9% of the corpus so the like pool over-fills relative to its + # natural frequency while the sequence under-fills overall. + # "last_n" — take the last history_length events of ANY pool type with + # no per-pool split. Fills the sequence to ~history_length (higher + # effective length) and lets the like share fall to its natural rate. + # Both exclude dislike/unlike/undislike (no action-weight bit). + if history_strategy not in ("interleaved", "last_n"): + raise ValueError( + f"history_strategy must be 'interleaved' or 'last_n', got " + f"{history_strategy!r}" + ) + self._history_strategy: str = history_strategy + # Minimum prior-event count for a LISTEN event to qualify as an anchor. + # Decoupled from history_length (which is only the gather/truncation cap): + # jagged attention handles short UIH, so we no longer require a full + # history_length of context to include a sample. Default None preserves + # the legacy "need a full history_length of prior events" behavior (which + # dropped ~60% of users); set small (e.g. 1) to include ~all users. + self._min_history: int = ( + history_length if min_history is None else int(min_history) + ) + # Streaming/temporal-order state. Everything here is LAZY: nothing is + # built or read until the first set_ts()/num_windows() call (only the + # streaming-train-eval loop does that), so the default train-eval path + # is byte-for-byte unaffected. + self._streaming_window_seconds: int = streaming_window_seconds + self._streaming_sort_within_window: bool = streaming_sort_within_window + # In-window shuffle dial in [0, 1] to break user-major batching (default + # 0.0 = off, user-major order preserved for page-local mmap scans). Maps to + # a within-segment shuffle with K = round(fraction * per-window train-anchor + # count): 1.0 = full per-element shuffle (max user diversity per batch), + # intermediate = interpolation. Computed from the global anchor count BEFORE + # round-robin striding, so a given fraction yields the same diversity + # regardless of world_size / #nodes / batch_size (config-invariant). The + # permutation is a pure function of (seed, ts) so the per-rank round-robin + # slice + mid-window resume skip stay deterministic across restarts. + self._streaming_shuffle_fraction: float = streaming_shuffle_fraction + self._streaming_shuffle_seed: int = streaming_shuffle_seed + # User-level train:eval split. `train_split_percentage >= 1.0` means no + # holdout (legacy behavior: every anchor is trainable). Otherwise the + # top `1 - train_split_percentage` fraction of users (by a deterministic + # hash of `uid + split_salt`) are held out: NEVER trained, used only to + # build the fixed eval set. The split is a pure function of (uid, salt), + # so it is identical across crash/resume (no leakage on failover). + self._train_split_percentage: float = train_split_percentage + self._split_salt: int = split_salt + # Cache only the (small) fixed eval-holdout index list; the per-window + # train filter is computed on the fly to avoid a full-length mask. + self._eval_holdout_cache: Optional[np.ndarray] = None + self._eval_holdout_cache_key: Optional[Tuple[int, int]] = None + self._active: Optional[np.ndarray] = None + self.is_eval: bool = False + self._anchor_ts: Optional[np.ndarray] = None + self._t_min: Optional[int] = None + self._t_max: Optional[int] = None + self._cache_dir: Optional[str] = cache_dir + self._cross_specs: List[Tuple[str, Tuple[str, ...], int, int]] = [ + (name, tuple(keys), n, s) for (name, keys, n, s) in (cross_specs or []) + ] + assert hstu_config.action_weights is not None + self._action_weights: List[int] = hstu_config.action_weights + + self._load_metadata(metadata_dir) + # Build-once-mmap-many: first rank to arrive acquires the build lock + # and explodes the parquet (one ~190 GB in-memory pass), then writes + # flat .npy columns + _READY sentinel. All ranks (including the + # builder, after dropping its in-memory copy) reload via MAP_SHARED+ + # PROT_READ — kernel shares physical pages across ranks so the steady- + # state per-rank RSS for the dataset is ~0. + if cache_dir is None: + cache_dir = os.path.join(processed_dir, f"hstu_cache_L{history_length}") + self._cache_dir = cache_dir + self._ensure_cache_built(cache_dir, processed_dir, history_length) + self.store: _FlatEventStore = _FlatEventStore.load_mmap(cache_dir) + # Anchor positions depend on min_history (the eligibility floor), not + # just history_length (the gather cap), so they live in a + # min_history-versioned file that shares the flat store. Built + # independently of the _READY sentinel so changing the floor rebuilds + # only this (cheap) array, not the whole 150 GB cache. + self._positions_name: str = self._positions_filename( + history_length, self._min_history + ) + self._ensure_positions_built( + cache_dir, self._positions_name, self._min_history + ) + self._positions: np.ndarray = _load_npy_readonly( + os.path.join(cache_dir, self._positions_name) + ) + logger.info( + f"Yambda dataset ready: {self.store.total_events:,} events, " + f"{len(self._positions):,} training positions" + ) + + @staticmethod + def _positions_filename(history_length: int, min_history: int) -> str: + """Anchor-positions filename. Uses the legacy name when the floor equals + the gather cap (the historical "full history required" behavior) so + existing caches are reused as-is; otherwise a min_history-tagged name.""" + if min_history == history_length: + return f"positions_L{history_length}.npy" + return f"positions_L{history_length}_m{min_history}.npy" + + @staticmethod + def _ensure_positions_built( + cache_dir: str, positions_name: str, min_history: int + ) -> None: + """Build the anchor-positions array for ``min_history`` if absent. + + Anchors are LISTEN events whose user-local offset is >= ``min_history`` + (i.e. the user already has that many prior events). This is decoupled + from the _READY-gated flat-store build so a new floor only rebuilds this + (cheap, ~one int64 scan) array rather than the whole 150 GB cache. + Multi-rank safe via an exclusive lock + atomic rename; all ranks then + mmap the result read-only. + """ + import fcntl + + positions_path = os.path.join(cache_dir, positions_name) + if os.path.exists(positions_path): + return + lock_path = os.path.join(cache_dir, "_positions_lock") + with open(lock_path, "w") as lf: + logger.info(f"Acquiring positions build lock for {positions_path}...") + fcntl.flock(lf, fcntl.LOCK_EX) + try: + if os.path.exists(positions_path): + return + flat_uid = _load_npy_readonly( + os.path.join(cache_dir, "flat_uid.npy") + ) + event_types = _load_npy_readonly( + os.path.join(cache_dir, "flat_event_types.npy") + ) + user_start = _load_npy_readonly( + os.path.join(cache_dir, "user_start.npy") + ) + idx = np.arange(len(flat_uid), dtype=np.int64) + keep = (idx - user_start[flat_uid] >= min_history) & ( + event_types == LISTEN_TYPE + ) + positions = np.where(keep)[0].astype(np.int64) + tmp = positions_path + ".tmp.npy" + np.save(tmp, positions) + os.replace(tmp, positions_path) + logger.info( + f"Wrote {positions_name}: {len(positions):,} anchors " + f"(min_history={min_history})" + ) + finally: + fcntl.flock(lf, fcntl.LOCK_UN) + + @staticmethod + def _ensure_cache_built( + cache_dir: str, processed_dir: str, history_length: int + ) -> None: + """File-locked one-shot build with column-at-a-time explode. + + A naive `pl.read_parquet(...).explode([5 list cols])` peaks at ~1.6 TB + on the 5b dataset (polars holds input list-columns + dense output + + parallel-worker scratch all together). Instead we: + 1) Read parquet + sort once (sorted list-column DF, ~80 GB). + 2) For each output column: select that single list, explode, write + .npy, drop. Bounds incremental peak to one column (~38 GB). + 3) Derive bool flags and indices from the on-disk mmaps. + + Peak RAM: ~150 GB. Steady state across all ranks afterward: ~0 + incremental thanks to MAP_SHARED in load_mmap. + """ + import fcntl + import gc + import json as _json + + ready = os.path.join(cache_dir, "_READY") + if os.path.exists(ready): + return + os.makedirs(cache_dir, exist_ok=True) + lock_path = os.path.join(cache_dir, "_lock") + with open(lock_path, "w") as lf: + logger.info(f"Acquiring build lock for {cache_dir}...") + fcntl.flock(lf, fcntl.LOCK_EX) + try: + if os.path.exists(ready): + return + parquet_path = os.path.join(processed_dir, "train_sessions.parquet") + logger.info( + f"Building flat-event cache from {parquet_path} " + f"(column-at-a-time, ~150 GB peak RAM)" + ) + + # Step 1: read + sort. List columns stay nested at this stage. + sessions = pl.read_parquet(parquet_path).sort(["uid", "session_id"]) + logger.info(f"Sessions sorted: {sessions.shape}") + + # Per-session lengths + uids — used to derive flat_uid via + # np.repeat (cheap) without exploding the whole DF at once. + lengths = ( + sessions.select(pl.col("item_ids").list.len()) + .to_numpy() + .reshape(-1) + .astype(np.int64) + ) + session_uids = sessions["uid"].to_numpy().astype(np.int64) + N = int(lengths.sum()) + num_users = int(np.unique(session_uids).shape[0]) + logger.info(f"Total events: {N:,}, users: {num_users:,}") + + # Step 2: column-at-a-time explode → save → drop. + # uid is per-session scalar; expand via np.repeat. + flat_uid = np.repeat(session_uids, lengths).astype(np.int64) + np.save(os.path.join(cache_dir, "flat_uid.npy"), flat_uid) + del flat_uid, session_uids, lengths + gc.collect() + logger.info("Wrote flat_uid.npy") + + # Derived columns flat_is_listen_plus/like/skip depend on + # event_types + played_ratio. Save those two first, then + # derive the bools from the mmaps. + _list_cols = [ + ("item_ids", "flat_item_ids", np.int64), + ("timestamps", "flat_timestamps", np.int64), + ("event_types", "flat_event_types", np.int64), + ("is_organic", "flat_is_organic", np.int8), + ("played_ratio_pct", "flat_played_ratio", np.float32), + ] + for src_col, dst_name, dtype in _list_cols: + exploded = sessions.select(pl.col(src_col).explode()) + arr = exploded[src_col].to_numpy().astype(dtype, copy=False) + if dtype == np.float32: + np.nan_to_num(arr, copy=False, nan=0.0) + np.save(os.path.join(cache_dir, f"{dst_name}.npy"), arr) + del exploded, arr + gc.collect() + logger.info(f"Wrote {dst_name}.npy") + + # Drop the sessions DF now that all source columns are on disk. + del sessions + gc.collect() + + # Step 3: derive bool flags from the just-written mmaps. + event_types = _load_npy_readonly( + os.path.join(cache_dir, "flat_event_types.npy") + ) + played_ratio = _load_npy_readonly( + os.path.join(cache_dir, "flat_played_ratio.npy") + ) + is_listen = event_types == LISTEN_TYPE + np.save( + os.path.join(cache_dir, "flat_is_listen_plus.npy"), + is_listen & (played_ratio >= LISTEN_PLUS_THRESHOLD), + ) + np.save( + os.path.join(cache_dir, "flat_is_like.npy"), + event_types == LIKE_TYPE, + ) + np.save( + os.path.join(cache_dir, "flat_is_skip.npy"), + is_listen & (played_ratio < LISTEN_PLUS_THRESHOLD), + ) + del is_listen, played_ratio + gc.collect() + logger.info("Wrote flat_is_listen_plus/like/skip.npy") + + # user_start / user_end / unique_uids from flat_uid mmap. + flat_uid = _load_npy_readonly( + os.path.join(cache_dir, "flat_uid.npy") + ) + uid_changes = np.where(np.diff(flat_uid) != 0)[0] + 1 + starts = np.concatenate([[0], uid_changes]) + ends = np.concatenate([uid_changes, [len(flat_uid)]]) + uid_vals = flat_uid[starts] + max_uid = int(uid_vals.max()) + 1 + user_start = np.full(max_uid, -1, dtype=np.int64) + user_end = np.full(max_uid, -1, dtype=np.int64) + user_start[uid_vals] = starts + user_end[uid_vals] = ends + np.save(os.path.join(cache_dir, "user_start.npy"), user_start) + np.save(os.path.join(cache_dir, "user_end.npy"), user_end) + np.save(os.path.join(cache_dir, "unique_uids.npy"), uid_vals) + logger.info("Wrote user_start/end/unique_uids.npy") + + # Positions: LISTEN events with ≥history_length prior history. + # Done now (before dropping user_start) so all sibling ranks + # just mmap the result instead of each running a 75 GB build. + user_start_per_event = user_start[flat_uid] + idx = np.arange(len(flat_uid), dtype=np.int64) + keep = (idx - user_start_per_event >= history_length) & ( + event_types == LISTEN_TYPE + ) + positions = np.where(keep)[0].astype(np.int64) + np.save( + os.path.join(cache_dir, f"positions_L{history_length}.npy"), + positions, + ) + logger.info( + f"Wrote positions_L{history_length}.npy: {len(positions):,}" + ) + del ( + flat_uid, event_types, user_start, user_end, uid_vals, + starts, ends, uid_changes, idx, user_start_per_event, + keep, positions, + ) + gc.collect() + + # Meta + sentinel — written last; readers gate on _READY. + with open(os.path.join(cache_dir, "store_meta.json"), "w") as f: + _json.dump( + {"num_users": num_users, "total_events": N}, f + ) + open(os.path.join(cache_dir, "_READY"), "w").close() + logger.info(f"Cache build complete: {cache_dir}") + finally: + fcntl.flock(lf, fcntl.LOCK_UN) + + def _load_metadata(self, metadata_dir: str) -> None: + item_pop_path = os.path.join(metadata_dir, "item_popularity.npy") + if os.path.exists(item_pop_path): + item_popularity = np.load(item_pop_path) + else: + # Fallback: derive vocab size from the artist+album maps. + item_popularity = None + + artist_map = pl.read_parquet(os.path.join(metadata_dir, "artist_item_mapping.parquet")) + album_map = pl.read_parquet(os.path.join(metadata_dir, "album_item_mapping.parquet")) + n_items = int( + max( + int(artist_map["item_id"].max()) + 1, + int(album_map["item_id"].max()) + 1, + len(item_popularity) if item_popularity is not None else 0, + ) + ) + self.item_to_artist: np.ndarray = np.zeros(n_items, dtype=np.int64) + valid = artist_map.filter(pl.col("item_id") < n_items) + self.item_to_artist[valid["item_id"].to_numpy()] = valid["artist_id"].to_numpy() + self.item_to_album: np.ndarray = np.zeros(n_items, dtype=np.int64) + valid = album_map.filter(pl.col("item_id") < n_items) + self.item_to_album[valid["item_id"].to_numpy()] = valid["album_id"].to_numpy() + self.num_items: int = n_items + + def get_item_count(self) -> int: + # Streaming mode restricts the active set to the current time window; + # otherwise the full (user-major) anchor list is used (train-eval). + if self._active is not None: + return int(len(self._active)) + return int(len(self._positions)) + + def iloc(self, idx: int) -> int: + if self._active is not None: + return int(self._positions[self._active[idx]]) + return int(self._positions[idx]) + + def _ensure_streaming_index(self) -> None: + """Lazily build + mmap the per-anchor target-timestamp array used for + time-windowed streaming. + + Built only on the first ``set_ts()``/``num_windows()`` call, so the + default train-eval path never reads timestamps or writes a new file. + Multi-rank safe via an exclusive file lock + atomic rename; all ranks + then mmap the result read-only (shared physical pages, ~0 anon). + """ + if self._anchor_ts is not None: + return + import fcntl + + assert self._cache_dir is not None + # Target-ts array is per-anchor, so it must track the same min_history + # versioning as the positions file it indexes into. + anchor_path = os.path.join( + self._cache_dir, + self._positions_name.replace("positions_", "anchor_ts_", 1), + ) + if not os.path.exists(anchor_path): + lock_path = os.path.join(self._cache_dir, "_anchor_ts_lock") + with open(lock_path, "w") as lf: + logger.info(f"Acquiring anchor-ts build lock for {anchor_path}...") + fcntl.flock(lf, fcntl.LOCK_EX) + if not os.path.exists(anchor_path): + logger.info( + f"Building {anchor_path}: target ts for " + f"{len(self._positions):,} anchors" + ) + anchor_ts = self.store.flat_timestamps[self._positions] + tmp = anchor_path + ".tmp.npy" + np.save(tmp, anchor_ts) + os.replace(tmp, anchor_path) + del anchor_ts + self._anchor_ts = _load_npy_readonly(anchor_path) + self._t_min = int(self._anchor_ts.min()) + self._t_max = int(self._anchor_ts.max()) + + def num_windows(self) -> int: + """Number of fixed-duration windows spanning [t_min, t_max].""" + self._ensure_streaming_index() + assert self._t_min is not None and self._t_max is not None + span = self._t_max - self._t_min + 1 + w = self._streaming_window_seconds + return int((span + w - 1) // w) + + def window_indices( + self, ts: int, sort_by_time: Optional[bool] = None + ) -> np.ndarray: + """Global anchor indices (into ``_positions``) whose target timestamp is + in window ``ts``: ``[t_min + ts*W, t_min + (ts+1)*W)``. + + Returned in ascending global-index order (user-major), which keeps the + per-sample history scans page-local in the mmap'd event arrays. Used by + the per-window path (via ``set_ts``) and the persistent path (shipped to + workers through the sampler). ``sort_by_time`` defaults to + ``streaming_sort_within_window``. + + Note: an O(log N) variant using a cached argsort of the timestamps was + evaluated but rejected — it doubles resident mmap (sorted-ts + order + permutation, ~52 GB) and that extra residency evicts the event-array + page cache, stalling dataloader workers (NCCL watchdog timeouts). The + O(N) mask here keeps only one ~26 GB array resident and is robust. + """ + self._ensure_streaming_index() + assert self._anchor_ts is not None and self._t_min is not None + w = self._streaming_window_seconds + lo = self._t_min + ts * w + hi = lo + w + idx = np.where((self._anchor_ts >= lo) & (self._anchor_ts < hi))[0] + do_sort = ( + self._streaming_sort_within_window if sort_by_time is None else sort_by_time + ) + if do_sort and idx.size > 0: + idx = idx[np.argsort(self._anchor_ts[idx], kind="stable")] + logger.warning(f"window_indices({ts}): [{lo}, {hi}) -> {idx.size:,} anchors") + return idx.astype(np.int64) + + def _eval_anchor_mask(self, anchor_idx: np.ndarray) -> np.ndarray: + """Bool mask (aligned to ``anchor_idx``) marking held-out eval users. + + Computed on the fly for just this slice of anchors (a window is ~tens of + millions, not the full ~3B ``_positions``), so we never materialize a + full-length mask. ``uid``-hash >= ``train_split_percentage`` -> eval. + """ + uids = self.store.flat_uid[self._positions[anchor_idx]] + return _uid_unit_hash(uids, self._split_salt) >= self._train_split_percentage + + def _shuffle_window(self, idx: np.ndarray, ts: int) -> np.ndarray: + """Optionally break user-major ordering within a train window. + + ``streaming_shuffle_fraction`` (0..1) is the single diversity dial. It + maps to a within-segment shuffle with ``K = round(fraction * N)`` where + ``N`` is this window's train-anchor count: + + - 0.0 -> off: return ``idx`` unchanged (user-major, page-local scans). + - 1.0 -> full per-element shuffle (max user diversity per batch). + - else -> permute WITHIN each contiguous size-K segment (segment order + preserved). A per-rank batch then draws across a bounded user-major + region, so diversity scales with the fraction while the concurrently + touched mmap working set stays within ~one K-segment (page locality). + + Because ``N`` is a property of the dataset/window (not the compute layout) + and the permutation is applied BEFORE the per-rank round-robin striding, a + given fraction yields the same diversity across world_size / #nodes / + batch_size (config-invariant). + + The permutation is a pure function of ``(seed, ts)`` via + ``np.random.default_rng(seed + ts)``, so every (re)run of this window + yields the IDENTICAL order. This keeps the per-rank round-robin slice and + the mid-window resume ``skip_samples`` offset consistent across restarts, + exactly like the unshuffled path. + """ + frac = self._streaming_shuffle_fraction + if idx.size <= 1 or not frac or frac <= 0.0: + return idx + rng = np.random.default_rng(self._streaming_shuffle_seed + ts) + if frac >= 1.0: + return idx[rng.permutation(idx.size)] + # Within-segment shuffle (K = round(fraction * N)): a single vectorized + # lexsort over per-element random keys, stable within each size-K segment + # so elements never cross a segment boundary (bounds the working set). O(N + # log K), run once per window in the background prep thread. + n = idx.size + k = max(1, int(round(frac * n))) + seg = np.arange(n, dtype=np.int64) // k + keys = rng.random(n) + order = np.lexsort((keys, seg)) + return idx[order] + + def train_window_indices(self, ts: int) -> np.ndarray: + """Global anchor indices for TRAIN in window ``ts``: ``window_indices`` + with held-out eval users removed. Identical across resume because + ``window_indices``, the uid hash, and the (seed,ts)-keyed in-window + shuffle are all pure functions, so the per-rank round-robin slice (and + the mid-window skip offset) stay consistent.""" + idx = self.window_indices(ts) + if self._train_split_percentage >= 1.0: + return self._shuffle_window(idx, ts) + kept = idx[~self._eval_anchor_mask(idx)] + logger.warning( + f"train_window_indices({ts}): {idx.size:,} -> {kept.size:,} anchors " + f"(holdout tsp={self._train_split_percentage}, salt={self._split_salt})" + ) + return self._shuffle_window(kept, ts) + + def eval_holdout_indices(self, start_ts: int, num_windows: int = 1) -> np.ndarray: + """Fixed eval set: held-out users' anchors over windows + ``[start_ts, start_ts + num_windows)``. Computed once and cached, so the + SAME anchors are evaluated at every eval step (stable, comparable curve). + With no holdout (tsp>=1.0) this falls back to the full window(s).""" + key = (int(start_ts), int(num_windows)) + if self._eval_holdout_cache is not None and self._eval_holdout_cache_key == key: + return self._eval_holdout_cache + parts: List[np.ndarray] = [] + for ts in range(start_ts, start_ts + max(1, num_windows)): + idx = self.window_indices(ts) + if self._train_split_percentage < 1.0: + idx = idx[self._eval_anchor_mask(idx)] + parts.append(idx) + holdout = ( + np.concatenate(parts).astype(np.int64) + if parts + else np.empty(0, dtype=np.int64) + ) + logger.warning( + f"eval_holdout_indices(start_ts={start_ts}, num_windows={num_windows}): " + f"{holdout.size:,} held-out anchors (tsp={self._train_split_percentage})" + ) + self._eval_holdout_cache = holdout + self._eval_holdout_cache_key = key + return holdout + + def total_train_anchors(self, start_ts: int, num_ts: int) -> int: + """Total TRAIN anchors across windows ``[start_ts, start_ts + num_ts)``. + + A single O(N) pass over the cached ``_anchor_ts`` array (NOT per-window + ``train_window_indices`` scans). Used to convert a "fraction of training + data" eval cadence into a global train-step interval. With a user holdout + (``train_split_percentage`` < 1.0) the held-out eval users are excluded + via the SAME uid hash as ``train_window_indices``, so the count matches + what is actually trained. + + NOTE: this is an UPPER BOUND on the realized train STEP count — the + per-window samplers truncate each window to a multiple of ``world_size`` + and drop the last partial per-rank batch (``drop_last=True``). The small + overcount is acceptable for a cadence knob (it only shifts the eval grid + by a fraction of a window). + """ + self._ensure_streaming_index() + assert self._anchor_ts is not None and self._t_min is not None + if num_ts <= 0: + return 0 + w = self._streaming_window_seconds + lo = self._t_min + start_ts * w + hi = self._t_min + (start_ts + num_ts) * w + in_range = (self._anchor_ts >= lo) & (self._anchor_ts < hi) + if self._train_split_percentage >= 1.0: + total = int(np.count_nonzero(in_range)) + else: + sel = np.where(in_range)[0] + total = int(np.count_nonzero(~self._eval_anchor_mask(sel))) + logger.warning( + f"total_train_anchors(start_ts={start_ts}, num_ts={num_ts}): " + f"{total:,} train anchors (tsp={self._train_split_percentage})" + ) + return total + + def set_ts(self, ts: int, train_only: bool = False) -> None: + """Restrict the active sample set to anchors in window ``ts`` (used by + the per-window-DataLoader path, where ``iloc``/``get_item_count`` index + through ``_active``). + + ``train_only=True`` removes held-out eval users so the non-persistent + TRAIN loader never sees them (closes the leakage path). Forward-only + temporal slicing for streaming train/eval. History for any anchor is + still gathered causally (``scan_start:flat_pos``) and may span earlier + windows, so there is no feature leakage from future events. + """ + self._active = ( + self.train_window_indices(ts) if train_only else self.window_indices(ts) + ) + + def set_active_indices(self, indices: np.ndarray) -> None: + """Restrict the active sample set to an explicit array of global anchor + indices (into ``_positions``). Used by the non-persistent eval path to + iterate the fixed user-holdout set (which spans a window range, not a + single ``ts``).""" + self._active = np.asarray(indices, dtype=np.int64) + + def load_query_samples(self, sample_list) -> None: + max_num_candidates = ( + self._max_num_candidates_inference + if self._is_inference + else self._max_num_candidates + ) + self.items_in_memory = {} + for idx in sample_list: + flat_pos = self.iloc(idx) + self.items_in_memory[idx] = self._build_sample(flat_pos, max_num_candidates) + self.last_loaded = time.time() + + def get_sample(self, idx: int) -> Tuple[KeyedJaggedTensor, KeyedJaggedTensor]: + if idx in self.items_in_memory: + return self.items_in_memory[idx] + max_num_candidates = ( + self._max_num_candidates_inference + if self._is_inference + else self._max_num_candidates + ) + flat_pos = self.iloc(idx) + return self._build_sample(flat_pos, max_num_candidates) + + @staticmethod + def _empty_history() -> Tuple[ + np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray + ]: + empty = np.empty(0, dtype=np.int64) + return empty, empty, empty, empty, empty + + def _read_scan_window( + self, flat_pos: int, user_start: int + ) -> Optional[ + Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray] + ]: + """Read the causal scan window [scan_start, flat_pos) for an anchor. + Returns (item_ids, timestamps, is_lp, is_like, is_skip) views, or None + if the window is empty.""" + scan_start = max(int(user_start), int(flat_pos) - self._scan_window) + scan_end = int(flat_pos) + if scan_end <= scan_start: + return None + return ( + self.store.flat_item_ids[scan_start:scan_end], + self.store.flat_timestamps[scan_start:scan_end], + self.store.flat_is_listen_plus[scan_start:scan_end], + self.store.flat_is_like[scan_start:scan_end], + self.store.flat_is_skip[scan_start:scan_end], + ) + + def _materialize_history( + self, + keep_local: np.ndarray, + item_ids: np.ndarray, + timestamps: np.ndarray, + is_lp: np.ndarray, + is_like: np.ndarray, + is_skip: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Gather item/artist/album/ts + pool-bitmask `weight` for the kept + (chronologically-ordered) local indices.""" + items = item_ids[keep_local] + ts = timestamps[keep_local] + artists = self.item_to_artist[np.clip(items, 0, self.item_to_artist.shape[0] - 1)] + albums = self.item_to_album[np.clip(items, 0, self.item_to_album.shape[0] - 1)] + # Pool bitmask per kept event (LP/LIKE/SKIP are mutually exclusive in + # the source data, but OR is safe and forward-compatible). + weight = np.zeros(keep_local.shape[0], dtype=np.int64) + weight[is_lp[keep_local]] |= LP_BIT + weight[is_like[keep_local]] |= LIKE_BIT + weight[is_skip[keep_local]] |= SKIP_BIT + return items, artists, albums, ts, weight + + def _gather_history( + self, flat_pos: int, user_start: int + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Dispatch UIH construction to the configured strategy.""" + if self._history_strategy == "last_n": + return self._gather_last_n_history(flat_pos, user_start) + return self._gather_interleaved_history(flat_pos, user_start) + + def _gather_interleaved_history( + self, flat_pos: int, user_start: int + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Build a single chronologically-ordered history sequence from the 3 + behavior pools. Each event's `action_weight` carries the pool bitmask + (LP_BIT/LIKE_BIT/SKIP_BIT). Per-pool cap = history_length // 3.""" + L = self._history_length + per_pool = max(1, L // 3) + scan = self._read_scan_window(flat_pos, user_start) + if scan is None: + return self._empty_history() + item_ids, timestamps, is_lp, is_like, is_skip = scan + + # Local indices into the scan window — preserves chronological order + # within each pool and lets us interleave by re-sorting. + idx_all = np.arange(item_ids.shape[0], dtype=np.int64) + lp_idx = idx_all[is_lp][-per_pool:] + like_idx = idx_all[is_like][-per_pool:] + skip_idx = idx_all[is_skip][-per_pool:] + + keep_local = np.concatenate([lp_idx, like_idx, skip_idx]) + if keep_local.size == 0: + return self._empty_history() + + order = np.argsort(keep_local, kind="stable") + keep_local = keep_local[order] + + return self._materialize_history( + keep_local, item_ids, timestamps, is_lp, is_like, is_skip + ) + + def _gather_last_n_history( + self, flat_pos: int, user_start: int + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Build the UIH from the last `history_length` events of ANY pool type + (listen+/like/skip) with no per-pool split. Vs the interleaved strategy + this fills the sequence to ~history_length (higher effective length) and + lets the like share fall to its natural corpus rate (~1.9%). Events + outside the 3 pools (dislike/unlike/undislike) are excluded as before.""" + L = self._history_length + scan = self._read_scan_window(flat_pos, user_start) + if scan is None: + return self._empty_history() + item_ids, timestamps, is_lp, is_like, is_skip = scan + + member = is_lp | is_like | is_skip + # Last L pooled events, in chronological order (already position-sorted + # within the scan window, so no re-sort is needed). + keep_local = np.arange(item_ids.shape[0], dtype=np.int64)[member][-L:] + if keep_local.size == 0: + return self._empty_history() + + return self._materialize_history( + keep_local, item_ids, timestamps, is_lp, is_like, is_skip + ) + + def _build_sample( + self, flat_pos: int, max_num_candidates: int + ) -> Tuple[KeyedJaggedTensor, KeyedJaggedTensor]: + uid = int(self.store.flat_uid[flat_pos]) + user_start = int(self.store.user_start[uid]) + + items, artists, albums, ts, weight = self._gather_history( + flat_pos, user_start + ) + + target_item = int(self.store.flat_item_ids[flat_pos]) + target_artist = int( + self.item_to_artist[target_item] + if target_item < self.item_to_artist.shape[0] + else 0 + ) + target_album = int( + self.item_to_album[target_item] + if target_item < self.item_to_album.shape[0] + else 0 + ) + target_ts = int(self.store.flat_timestamps[flat_pos]) + + played_ratio = float(self.store.flat_played_ratio[flat_pos]) + is_lp = ( + int(self.store.flat_event_types[flat_pos]) == LISTEN_TYPE + and played_ratio >= LISTEN_PLUS_THRESHOLD + ) + # Label encoded into the candidate's action_weight via the LP bit, so + # _get_supervision_labels_and_weights sees the right supervision. + candidate_action_weight = LP_BIT if is_lp else 0 + + cross_id_anchor: Dict[str, int] = { + "uid": uid, + "item_id": target_item, + "artist_id": target_artist, + "album_id": target_album, + "hour_of_day": int((target_ts // 3600) % 24), + "is_organic": int(self.store.flat_is_organic[flat_pos]), + } + cross_ids: Dict[str, int] = { + name: xxhash_cross(cross_id_anchor, list(keys), n, salt) + for (name, keys, n, salt) in self._cross_specs + } + + # ---- Truncate UIH to fit max_seq_len budget ---- + uih_seq_len_budget = ( + self._max_seq_len + - max_num_candidates + - len(self._contextual_feature_to_max_length or {}) + ) + if items.shape[0] > uih_seq_len_budget: + items = items[-uih_seq_len_budget:] + artists = artists[-uih_seq_len_budget:] + albums = albums[-uih_seq_len_budget:] + ts = ts[-uih_seq_len_budget:] + weight = weight[-uih_seq_len_budget:] + uih_seq_len = int(items.shape[0]) + dummy_watch_time = np.zeros(uih_seq_len, dtype=np.int64) + + # ---- Build UIH KJT ---- + # Contextual features (length-1 each) iterated in the same order as + # `_contextual_feature_to_max_length` (matches movielens reference). + uih_kjt_values: List[int] = [] + uih_kjt_lengths: List[int] = [] + for name, length in (self._contextual_feature_to_max_length or {}).items(): + assert length == 1, f"yambda contextuals are length-1, got {name}={length}" + if name == "uid": + uih_kjt_values.append(uid) + else: + uih_kjt_values.append(int(cross_ids[name])) + uih_kjt_lengths.append(1) + + # Sequential features — order must match the trailing entries of + # hstu_uih_feature_names in configs.py: + # item_id, artist_id, album_id, action_weight, action_timestamp, dummy_watch_time + uih_kjt_values.extend(items.tolist()) + uih_kjt_values.extend(artists.tolist()) + uih_kjt_values.extend(albums.tolist()) + uih_kjt_values.extend(weight.tolist()) + uih_kjt_values.extend(ts.tolist()) + uih_kjt_values.extend(dummy_watch_time.tolist()) + n_sequential = len(self._uih_keys) - len(self._contextual_feature_to_max_length or {}) + uih_kjt_lengths.extend([uih_seq_len] * n_sequential) + + dummy_query_time = int(ts[-1]) if uih_seq_len > 0 else target_ts + uih_kjt_values.append(dummy_query_time) + uih_kjt_lengths.append(1) + + uih_features_kjt = KeyedJaggedTensor( + keys=self._uih_keys + ["dummy_query_time"], + lengths=torch.tensor(uih_kjt_lengths, dtype=torch.long), + values=torch.tensor(uih_kjt_values, dtype=torch.long), + ) + + # ---- Build candidates KJT ---- + # Order must match configs.py:hstu_candidate_feature_names exactly: + # item_candidate_id, item_candidate_artist_id, item_candidate_album_id, + # item_query_time, item_action_weight, item_dummy_watchtime + candidates_kjt_lengths = max_num_candidates * torch.ones( + len(self._candidates_keys), dtype=torch.long + ) + candidates_kjt_values: List[int] = ( + [target_item] * max_num_candidates + + [target_artist] * max_num_candidates + + [target_album] * max_num_candidates + + [dummy_query_time] * max_num_candidates + + [candidate_action_weight] * max_num_candidates + + [1] * max_num_candidates # item_dummy_watchtime + ) + candidates_features_kjt = KeyedJaggedTensor( + keys=self._candidates_keys, + lengths=candidates_kjt_lengths, + values=torch.tensor(candidates_kjt_values, dtype=torch.long), + ) + return uih_features_kjt, candidates_features_kjt diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/preprocess_public_data.py b/recommendation_v4/generative_recommenders/dlrm_v3/preprocess_public_data.py new file mode 100644 index 000000000..06b43aa48 --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/preprocess_public_data.py @@ -0,0 +1,539 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-unsafe +import argparse +import json +import logging +import os +import tarfile +from pathlib import Path +from typing import Dict, List, Optional, Tuple +from urllib.request import urlretrieve + +import numpy as np +import pandas as pd + + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger("main") + +""" +Usage: + mkdir -p data/ && python3 preprocess_public_data.py --dataset kuairand-1k + python3 preprocess_public_data.py --dataset yambda-5b --data-path + python3 preprocess_public_data.py --dataset yambda-500m --data-path + python3 preprocess_public_data.py --dataset yambda-50m --data-path +""" + +SUPPORTED_DATASETS = [ + "kuairand-1k", + "kuairand-27k", + "yambda-50m", + "yambda-500m", + "yambda-5b", +] + + +def get_feature_merge_weights(dataset: str = "debug") -> Dict[str, int]: + if "kuairand" in dataset: + return { + "is_click": 1, + "is_like": 2, + "is_follow": 4, + "is_comment": 8, + "is_forward": 16, + "is_hate": 32, + "long_view": 64, + "is_profile_enter": 128, + } + else: + return {"dummy": 1} + + +class DataProcessor: + def __init__( + self, + download_url: str, + data_path: str, + file_name: str, + prefix: str, + ) -> None: + self._download_url = download_url + self._data_path = data_path + self._file_name = file_name + self._prefix = prefix + + def download(self) -> None: + return + + def preprocess(self) -> None: + return + + def file_exists(self, name: str) -> bool: + return os.path.isfile("%s/%s" % (os.getcwd(), name)) + + +class DLRMKuaiRandProcessor(DataProcessor): + def __init__( + self, + download_url: str, + data_path: str, + file_name: str, + prefix: str, + ) -> None: + super().__init__(download_url, data_path, file_name, prefix) + if prefix == "KuaiRand-1K": + self._log_files: List[str] = [ + f"{data_path}{prefix}/data/log_standard_4_08_to_4_21_1k.csv", + f"{data_path}{prefix}/data/log_standard_4_22_to_5_08_1k.csv", + ] + self._user_features_file: str = ( + f"{data_path}{prefix}/data/user_features_1k.csv" + ) + elif prefix == "KuaiRand-27K": + self._log_files: List[str] = [ + f"{data_path}{prefix}/data/log_standard_4_08_to_4_21_27k_part1.csv", + f"{data_path}{prefix}/data/log_standard_4_08_to_4_21_27k_part2.csv", + f"{data_path}{prefix}/data/log_standard_4_22_to_5_08_27k_part1.csv", + f"{data_path}{prefix}/data/log_standard_4_22_to_5_08_27k_part2.csv", + ] + self._user_features_file: str = ( + f"{data_path}{prefix}/data/user_features_27k.csv" + ) + self._output_file: str = f"{data_path}{prefix}/data/processed_seqs.csv" + self._event_merge_weight: Dict[str, int] = get_feature_merge_weights( + prefix.lower() + ) + + def download(self) -> None: + file_path = f"{self._data_path}{self._file_name}" + if not self.file_exists(file_path): + log.info(f"Downloading {self._download_url}") + urlretrieve(self._download_url, file_path) + log.info(f"Downloaded to {file_path}") + with tarfile.open(file_path, "r:*") as tar_ref: + tar_ref.extractall(path=self._data_path) + log.info("Data files extracted") + os.remove(file_path) + log.info("Tar file removed") + + def preprocess(self) -> None: + self.download() + log.info("Preprocessing data...") + seq_cols = [ + "video_id", + "time_ms", + "action_weights", + "play_time_ms", + "duration_ms", + ] + df = None + for idx, log_file in enumerate(self._log_files): + log.info(f"Processing {log_file}...") + log_df = pd.read_csv( + log_file, + delimiter=",", + ) + df_grouped_by_user = log_df.groupby("user_id").agg(list).reset_index() + + for event, weight in self._event_merge_weight.items(): + df_grouped_by_user[event] = df_grouped_by_user[event].apply( + lambda seq: np.where(np.array(seq) == 0, 0, weight) + ) + + events = list(self._event_merge_weight.keys()) + df_grouped_by_user["action_weights"] = df_grouped_by_user.apply( + lambda row: [int(sum(x)) for x in zip(*[row[col] for col in events])], + axis=1, + ) + df_grouped_by_user = df_grouped_by_user[["user_id"] + seq_cols] + + if idx == 0: + df = df_grouped_by_user + else: + df = df.merge(df_grouped_by_user, on="user_id", suffixes=("_x", "_y")) + for col in seq_cols: + df[col] = df.apply( + lambda row: row[col + "_x"] + row[col + "_y"], axis=1 + ) + df = df.drop(columns=[col + "_x", col + "_y"]) + + max_seq_len = df["video_id"].apply(len).max() + min_seq_len = df["video_id"].apply(len).min() + average_seq_len = df["video_id"].apply(len).mean() + log.info(f"{max_seq_len=}, {min_seq_len=}, {average_seq_len=}") + + log.info("Merging user features...") + user_features_df = pd.read_csv(self._user_features_file, delimiter=",") + + def _one_hot_encode(row): + mapping = {category: i + 1 for i, category in enumerate(row.unique())} + row = row.map(mapping) + return row + + for col in [ + "user_active_degree", + "follow_user_num_range", + "fans_user_num_range", + "friend_user_num_range", + "register_days_range", + ]: + user_features_df[col] = _one_hot_encode(user_features_df[col]) + + final_df = pd.merge(df, user_features_df, on="user_id") + final_df.to_csv(self._output_file, index=False, sep=",") + log.info(f"Processed file saved to {self._output_file}") + + +# ---------------------------------------------------------------------------- +# Yambda processor +# ---------------------------------------------------------------------------- +# +# Yambda is hosted on HuggingFace at `yandex/yambda` and comes in three sizes: +# 50m, 500m, 5b. Each size shares the same catalog metadata (embeddings, +# artist/album mappings); only the interaction stream differs. +# +# This processor: +# 1) Downloads `multi_event.parquet` for the chosen size + the catalog +# metadata files via the `datasets` library. +# 2) Encodes event_type strings into uint8. +# 3) Splits temporally into train + test (Global Temporal Split, GTS). +# 4) Builds per-user sessions by inactivity gap. +# 5) Computes item popularity counts. +# 6) Writes the layout expected by `DLRMv3YambdaDataset`: +# +# /processed_/ +# train_sessions.parquet +# test_events.parquet +# session_index.parquet +# item_popularity.npy +# split_meta.json +# +# /shared_metadata/ +# artist_item_mapping.parquet +# album_item_mapping.parquet +# embeddings.parquet (optional; not used by HSTU training) +# +# The HSTU training path then auto-builds an `hstu_cache_L/` mmap under +# `processed_/` on first use. +# ---------------------------------------------------------------------------- + +YAMBDA_HF_REPO = "yandex/yambda" +YAMBDA_SIZES = {"yambda-50m": "50m", "yambda-500m": "500m", "yambda-5b": "5b"} +YAMBDA_METADATA_FILES = ( + "artist_item_mapping", + "album_item_mapping", + "embeddings", +) + +# Yambda timestamps are seconds (rounded to 5s boundaries). +SECONDS_PER_DAY = 86400 +# Polars chunk size for streaming the 5b parquet (~150 GB on disk). +YAMBDA_CHUNK_SIZE = 10_000_000 +EVENT_TYPE_MAP = {"listen": 0, "like": 1, "dislike": 2, "unlike": 3, "undislike": 4} + + +class DLRMYambdaProcessor(DataProcessor): + """Download + preprocess Yambda (50m / 500m / 5b) for DLRMv3YambdaDataset.""" + + def __init__( + self, + data_path: str, + size: str, + session_gap_seconds: int = 1800, + train_days: int = 300, + gap_minutes: int = 30, + test_days: int = 1, + ) -> None: + assert size in {"50m", "500m", "5b"}, f"unknown yambda size {size}" + super().__init__( + download_url="", # download is via HuggingFace `datasets` lib + data_path=data_path.rstrip("/") + "/", + file_name=f"{size}/multi_event.parquet", + prefix=f"yambda-{size}", + ) + self._size: str = size + self._raw_dir: Path = Path(self._data_path) / "raw" + self._processed_dir: Path = Path(self._data_path) / f"processed_{size}" + self._shared_dir: Path = Path(self._data_path) / "shared_metadata" + self._session_gap_seconds: int = session_gap_seconds + self._train_days: int = train_days + self._gap_minutes: int = gap_minutes + self._test_days: int = test_days + + def download(self) -> None: + try: + from datasets import DatasetDict, load_dataset + except ImportError as e: + raise ImportError( + "Downloading Yambda requires the `datasets` package " + "(`pip install datasets`)." + ) from e + + self._raw_dir.mkdir(parents=True, exist_ok=True) + self._shared_dir.mkdir(parents=True, exist_ok=True) + + # Size-specific interaction stream. + event_path = self._raw_dir / self._size / "multi_event.parquet" + if not event_path.exists(): + event_path.parent.mkdir(parents=True, exist_ok=True) + log.info( + f"Downloading multi_event.parquet for {self._size} " + f"from {YAMBDA_HF_REPO} ..." + ) + ds = load_dataset( + YAMBDA_HF_REPO, + data_dir=f"flat/{self._size}", + data_files="multi_event.parquet", + ) + assert isinstance(ds, DatasetDict) + ds["train"].to_parquet(str(event_path)) + log.info(f"Saved {event_path}") + else: + log.info(f"Already exists: {event_path}") + + # Catalog metadata files (shared across sizes). + for name in YAMBDA_METADATA_FILES: + shared_path = self._shared_dir / f"{name}.parquet" + if shared_path.exists(): + log.info(f"Already exists: {shared_path}") + continue + log.info(f"Downloading {name}.parquet from {YAMBDA_HF_REPO} ...") + ds = load_dataset(YAMBDA_HF_REPO, data_files=f"{name}.parquet") + assert isinstance(ds, DatasetDict) + ds["train"].to_parquet(str(shared_path)) + log.info(f"Saved {shared_path}") + + def preprocess(self) -> None: + self.download() + try: + import polars as pl + except ImportError as e: + raise ImportError( + "Yambda preprocessing requires polars " + "(`pip install polars-u64-idx` is recommended for the 5b " + "variant — stock polars overflows its 32-bit row index)." + ) from e + + self._processed_dir.mkdir(parents=True, exist_ok=True) + event_path = self._raw_dir / self._size / "multi_event.parquet" + + log.info(f"Loading multi_event from {event_path} ...") + events = self._load_events(pl, event_path) + log.info(f"Loaded {len(events):,} events") + + events = self._encode_event_types(pl, events) + t_min = int(events["timestamp"].min()) + t_max = int(events["timestamp"].max()) + log.info( + f"Timestamp range: {t_min}..{t_max} " + f"({(t_max - t_min) / SECONDS_PER_DAY:.1f} days)" + ) + + train_start, train_end, test_start, test_end = self._split_boundaries(t_max) + log.info( + f"GTS train=[{train_start},{train_end}) gap=[{train_end},{test_start}) " + f"test=[{test_start},{test_end})" + ) + train_events, test_events = self._temporal_split( + pl, events, train_start, train_end, test_start, test_end + ) + log.info( + f"Train: {len(train_events):,} events, Test: {len(test_events):,} events" + ) + + gap_units = self._session_gap_seconds # 1 unit = 1 second + sessions = self._build_sessions(pl, train_events, gap_units) + log.info(f"Built {len(sessions):,} sessions") + + session_index = self._build_session_index(pl, sessions) + log.info(f"Session index covers {len(session_index):,} users") + + item_popularity = self._compute_item_popularity(train_events) + + sessions.write_parquet(str(self._processed_dir / "train_sessions.parquet")) + test_events.write_parquet(str(self._processed_dir / "test_events.parquet")) + session_index.write_parquet(str(self._processed_dir / "session_index.parquet")) + np.save(self._processed_dir / "item_popularity.npy", item_popularity) + + with open(self._processed_dir / "split_meta.json", "w") as f: + json.dump( + { + "size": self._size, + "t_min": t_min, + "t_max": t_max, + "train_start": train_start, + "train_end": train_end, + "test_start": test_start, + "test_end": test_end, + "train_days": self._train_days, + "gap_minutes": self._gap_minutes, + "test_days": self._test_days, + "session_gap_seconds": self._session_gap_seconds, + "num_train_events": int(len(train_events)), + "num_test_events": int(len(test_events)), + "num_sessions": int(len(sessions)), + "num_users": int(len(session_index)), + }, + f, + indent=2, + ) + log.info(f"Preprocessing complete: {self._processed_dir}") + + # ------- helpers -------- + + def _load_events(self, pl, parquet_path: Path): + # 5b is too large to load in one polars pass on most boxes (~150 GB + # peak in-RAM with eager read). Stream in 10M-row chunks for safety. + if self._size == "5b": + log.info(f"Streaming load (chunk_size={YAMBDA_CHUNK_SIZE:,})...") + lf = pl.scan_parquet(parquet_path) + n = lf.select(pl.len()).collect().item() + log.info(f"Total rows: {n:,}") + chunks = [] + for off in range(0, n, YAMBDA_CHUNK_SIZE): + chunk = lf.slice(off, YAMBDA_CHUNK_SIZE).collect() + chunks.append(chunk) + log.info(f" loaded {off:,}..{off + len(chunk):,}") + return pl.concat(chunks) + return pl.read_parquet(parquet_path) + + def _encode_event_types(self, pl, events): + dt = events["event_type"].dtype + if dt == pl.Utf8 or isinstance(dt, (pl.Categorical, pl.Enum)): + events = events.with_columns( + pl.col("event_type") + .cast(pl.Utf8) + .replace_strict(EVENT_TYPE_MAP) + .cast(pl.UInt8) + .alias("event_type") + ) + return events + + def _split_boundaries(self, t_max: int) -> Tuple[int, int, int, int]: + test_end = t_max + test_start = test_end - self._test_days * SECONDS_PER_DAY + train_end = test_start - self._gap_minutes * 60 + train_start = train_end - self._train_days * SECONDS_PER_DAY + return train_start, train_end, test_start, test_end + + def _temporal_split(self, pl, events, train_start, train_end, test_start, test_end): + train = events.filter( + (pl.col("timestamp") >= train_start) & (pl.col("timestamp") < train_end) + ) + test_all = events.filter( + (pl.col("timestamp") >= test_start) & (pl.col("timestamp") < test_end) + ) + # Test users must also appear in train (next-item prediction setup). + train_users = train.select("uid").unique() + test = test_all.join(train_users, on="uid", how="inner") + return train, test + + def _build_sessions(self, pl, events, session_gap_units: int): + sorted_events = events.sort(["uid", "timestamp"]) + return ( + sorted_events + .with_columns( + ( + (pl.col("timestamp").diff().fill_null(0) > session_gap_units) + .cast(pl.UInt32) + .cum_sum() + ) + .over("uid") + .alias("session_id") + ) + .group_by(["uid", "session_id"]) + .agg( + pl.col("item_id").alias("item_ids"), + pl.col("timestamp").alias("timestamps"), + pl.col("event_type").alias("event_types"), + pl.col("is_organic").alias("is_organic"), + pl.col("played_ratio_pct").alias("played_ratio_pct"), + pl.col("track_length_seconds").alias("track_length_seconds"), + ) + .sort(["uid", "session_id"]) + ) + + def _build_session_index(self, pl, sessions): + return ( + sessions + .with_columns(pl.col("item_ids").list.len().alias("session_len")) + .group_by("uid") + .agg( + pl.col("session_id").alias("session_ids"), + pl.col("session_len").alias("session_lens"), + pl.col("session_len").cum_sum().alias("session_offsets"), + ) + .sort("uid") + ) + + def _compute_item_popularity(self, train_events) -> np.ndarray: + counts = ( + train_events + .group_by("item_id") + .len() + .sort("item_id") + ) + max_item = int(counts["item_id"].max()) + popularity = np.zeros(max_item + 1, dtype=np.int64) + popularity[counts["item_id"].to_numpy()] = counts["len"].to_numpy() + return popularity + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--dataset", + choices=SUPPORTED_DATASETS, + required=True, + help="dataset", + ) + parser.add_argument( + "--data-path", + default="data/", + help=( + "Root directory for raw + processed data. KuaiRand defaults to " + "the existing `data/` convention; Yambda defaults to `data/` too " + "but is commonly overridden to a shared filesystem location with " + "enough space for the 5b variant (~500 GB)." + ), + ) + args = parser.parse_args() + + data_path = args.data_path.rstrip("/") + "/" + + if args.dataset == "kuairand-1k": + DLRMKuaiRandProcessor( + download_url="https://zenodo.org/records/10439422/files/KuaiRand-1K.tar.gz", + data_path=data_path, + file_name="KuaiRand-1K.tar.gz", + prefix="KuaiRand-1K", + ).preprocess() + elif args.dataset == "kuairand-27k": + DLRMKuaiRandProcessor( + download_url="https://zenodo.org/records/10439422/files/KuaiRand-27K.tar.gz", + data_path=data_path, + file_name="KuaiRand-27K.tar.gz", + prefix="KuaiRand-27K", + ).preprocess() + elif args.dataset in YAMBDA_SIZES: + DLRMYambdaProcessor( + data_path=data_path, + size=YAMBDA_SIZES[args.dataset], + ).preprocess() + + +if __name__ == "__main__": + main() diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/streaming_synthetic_data.py b/recommendation_v4/generative_recommenders/dlrm_v3/streaming_synthetic_data.py new file mode 100644 index 000000000..bb9e508af --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/streaming_synthetic_data.py @@ -0,0 +1,664 @@ +# pyre-strict +""" +Streaming synthetic data generator for DLRMv3. + +This module generates synthetic streaming recommendation data for benchmarking +and testing purposes. It creates user-item interaction histories with timestamps, +ratings, and category-based item distributions. +""" + +import csv +import logging +import math +import multiprocessing +import os +import random +import shutil +import time +from typing import Dict, List, Tuple + +import numpy as np + +logger: logging.Logger = logging.getLogger(__name__) + + +class StreamingSyntheticDataGenerator: + """ + Generator for streaming synthetic recommendation data. + + Creates realistic user-item interaction data with temporal dynamics, + category preferences, and rating distributions for benchmarking + recommendation systems. + + Args: + num_categories: Number of item categories. + categories_per_user: Number of categories each user is interested in. + num_users: Total number of users to generate. + num_items: Total number of items in the catalog. + num_timestamps: Number of time periods in the streaming data. + avg_samples_per_item: Average number of interactions per item. + train_ratio: Fraction of timestamps used for training. + user_sampling_ratio: Probability of sampling a user at each timestamp. + num_eval_candidates: Number of candidates for evaluation. + num_inference_candidates: Number of candidates for inference. + debug: If True, use deterministic ratings for debugging. + rank: Process rank for distributed generation. + """ + + def __init__( + self, + num_categories: int, + categories_per_user: int, + num_users: int, + num_items: int, + num_timestamps: int, + avg_samples_per_item: int, + train_ratio: float, + user_sampling_ratio: float, + num_eval_candidates: int, + num_inference_candidates: int, + debug: bool = False, + rank: int = 0, + ) -> None: + self.num_categories = num_categories + self.categories_per_user = categories_per_user + self.num_users = num_users + self.num_items = num_items + self.num_timestamps = num_timestamps + self.avg_samples_per_item = avg_samples_per_item + self.avg_seq_len_per_timestamp = int( + num_items * avg_samples_per_item / num_users / num_timestamps + ) + self.items_per_category: int = num_items // num_categories + self.category_to_start_end_item_idx: Dict[int, Tuple[int, int]] = {} + for i in range(num_categories): + start_idx = i * self.items_per_category + end_idx = (i + 1) * self.items_per_category + self.category_to_start_end_item_idx[i] = (start_idx, end_idx) + self.alpha_range = (1, 500) + self.min_seq_len: int = num_eval_candidates + 1 + self.train_ratio = train_ratio + self.num_eval_candidates = num_eval_candidates + self.num_inference_candidates = num_inference_candidates + self.debug = debug + self.total_cnt = 0 + self.rank = rank + logger.warning(f"rank {self.rank}: start generating item rating") + np.random.seed(1001) + self.item_rating = np.random.choice( # pyre-ignore [4] + [5.0, 4.0, 3.0, 2.0, 1.0], size=num_items, p=[0.2, 0.25, 0.25, 0.2, 0.1] + ) + logger.warning(f"rank {self.rank}: finish generating item rating") + self.user_sampling_ratio = user_sampling_ratio + + def generate_one_timestamp( + self, + category_to_cnt: Dict[int, int], + categories: List[int], + t: int, + id: int, + output_folder: str, + uih_seq_len: int, + eval: bool, + inference: bool, + file_idx: int, + ts_buffers: Dict[int, List[int]], + ) -> Tuple[List[int], List[float], List[int], List[float], Dict[int, int]]: + """ + Generate interaction data for a single user at one timestamp. + + Args: + category_to_cnt: Running count of interactions per category. + categories: Categories this user is interested in. + t: Current timestamp index. + id: User ID. + output_folder: Output directory for files. + uih_seq_len: Length of user interaction history to generate. + eval: Whether this is for evaluation. + inference: Whether this is for inference. + file_idx: File index for output. + ts_buffers: Buffer for timestamp data. + + Returns: + Tuple of (uih_item_ids, uih_ratings, candidate_ids, candidate_ratings, + updated_category_counts). + """ + if t >= 0 and (not eval): + if t not in ts_buffers: + ts_buffers[t] = [] + ts_buffers[t].append(id) + seq_len: int = self.num_inference_candidates if inference else uih_seq_len + self.total_cnt += seq_len + alpha = random.randint(self.alpha_range[0], self.alpha_range[1]) + total_cnt = sum(category_to_cnt.values()) + p = np.array( + [ + (alpha / len(categories) + category_to_cnt[c]) / (alpha + total_cnt) + for c in categories + ] + ) + item_categories = np.random.choice(categories, size=seq_len, p=p) + unique, counts = np.unique(item_categories, return_counts=True) + for cat, cnt in zip(unique, counts): + category_to_cnt[cat] += int(cnt) + sample_end_idx = int( + self.items_per_category * max((t + 1), 1) / self.num_timestamps + ) + sample_inds = np.random.randint(0, sample_end_idx, size=seq_len) + offsets = np.array( + [self.category_to_start_end_item_idx[cat][0] for cat in item_categories] + ) + sample_inds = sample_inds + offsets + num_categories = len(categories) + quarter = num_categories // 4 + half = num_categories // 2 + three_quarter = num_categories // 4 * 3 + category_to_ratings = {} + cos1 = math.cos(t * math.pi / 4) + cos2 = math.cos((t + 2) * math.pi / 4) + cos3 = math.cos((t + 4) * math.pi / 4) + for i, cat in enumerate(categories): + if i < quarter: + if self.debug: + ratings = np.full(seq_len, 5.0) + else: + ratings = np.random.choice( + [4.5 + 0.5 * cos1, 4.0 + 0.5 * cos2], + size=seq_len, + p=[0.8, 0.2], + ) + elif i < half: + if self.debug: + ratings = np.full(seq_len, 4.0) + else: + ratings = np.random.choice( + [4.5 + 0.5 * cos1, 4.0 + 0.5 * cos2, 3.5 + 0.5 * cos3], + size=seq_len, + p=[0.1, 0.8, 0.1], + ) + elif i < three_quarter: + if self.debug: + ratings = np.full(seq_len, 3.0) + else: + ratings = np.random.choice( + [3.5 + 0.5 * cos1, 3.0 + 0.5 * cos2, 2.5 + 0.5 * cos3], + size=seq_len, + p=[0.1, 0.8, 0.1], + ) + else: + if self.debug: + ratings = np.full(seq_len, 2.0) + else: + ratings = np.random.choice( + [2.5 + 0.5 * cos1, 2.0 + 0.5 * cos2, 1.5 + 0.5 * cos3], + size=seq_len, + p=[0.1, 0.8, 0.1], + ) + category_to_ratings[cat] = ratings + sample_inds = sample_inds.tolist() + sample_ratings = [ + ( + category_to_ratings[item_categories[i]][i] + + self.item_rating[sample_inds[i]] + ) + / 2 + for i in range(seq_len) + ] + if not inference: + sub_indices = random.sample(range(seq_len), self.num_eval_candidates) + sample_candidate_inds = [sample_inds[i] for i in sub_indices] + sample_candidate_ratings = [sample_ratings[i] for i in sub_indices] + sample_uih_inds = sample_inds + sample_uih_ratings = sample_ratings + else: + sub_indices = random.sample(range(seq_len), uih_seq_len) + sample_uih_inds = [sample_inds[i] for i in sub_indices] + sample_uih_ratings = [sample_ratings[i] for i in sub_indices] + sample_candidate_inds = sample_inds + sample_candidate_ratings = sample_ratings + return ( + sample_uih_inds, + sample_uih_ratings, + sample_candidate_inds, + sample_candidate_ratings, + category_to_cnt, + ) + + def gen_rand_seq_len(self) -> int: + """ + Generate a random sequence length from a Gaussian distribution. + + Returns: + Sequence length, guaranteed to be at least min_seq_len. + """ + seq_len = round( + random.gauss( + self.avg_seq_len_per_timestamp, self.avg_seq_len_per_timestamp // 4 + ) + ) + seq_len = self.min_seq_len if seq_len < self.min_seq_len else seq_len + return seq_len + + def get_timestamp_sample(self, t: int) -> int: + """ + Determine if a user should be sampled at this timestamp. + + Args: + t: Timestamp index. Base timestamp (-1) is always sampled. + + Returns: + 1 if the user should be sampled, 0 otherwise. + """ + if t == -1: + sample = 1 + else: + sample = np.random.choice( + [1, 0], + size=1, + p=[self.user_sampling_ratio, 1 - self.user_sampling_ratio], + )[0] + return sample + + def generate_one_user( + self, + id: int, + output_folder: str, + file_idx: int, + ts_buffers: Dict[int, List[int]], + ) -> List[str]: + """ + Generate complete interaction history for one user. + + Creates training, evaluation, and inference data for a single user + across all timestamps. + + Args: + id: User ID. + output_folder: Output directory. + file_idx: File index for output. + ts_buffers: Buffer for timestamp metadata. + + Returns: + List of CSV row values for this user's data. + """ + categories = random.sample(range(self.num_categories), self.categories_per_user) + category_to_cnt = {c: 0 for c in categories} + out_list: List[str] = [] + # t = -1 as base UIH + ( + sample_inds, + sample_ratings, + sample_candidate_inds, + sample_candidate_ratings, + category_to_cnt, + ) = self.generate_one_timestamp( + category_to_cnt=category_to_cnt, + categories=categories, + t=-1, + id=id, + output_folder=output_folder, + uih_seq_len=self.gen_rand_seq_len(), + eval=False, + inference=False, + file_idx=file_idx, + ts_buffers=ts_buffers, + ) + out_list.append(",".join([str(ind) for ind in sample_candidate_inds])) + out_list.append(",".join([str(rat) for rat in sample_candidate_ratings])) + out_list.append(",".join([str(ind) for ind in sample_inds])) + out_list.append(",".join([str(rat) for rat in sample_ratings])) + # train + for t in range(int(self.num_timestamps * self.train_ratio)): + if self.get_timestamp_sample(t): + ( + sample_inds, + sample_ratings, + sample_candidate_inds, + sample_candidate_ratings, + category_to_cnt, + ) = self.generate_one_timestamp( + category_to_cnt=category_to_cnt, + categories=categories, + t=t, + id=id, + output_folder=output_folder, + uih_seq_len=self.gen_rand_seq_len(), + eval=False, + inference=False, + file_idx=file_idx, + ts_buffers=ts_buffers, + ) + out_list.append(",".join([str(ind) for ind in sample_candidate_inds])) + out_list.append( + ",".join([str(rat) for rat in sample_candidate_ratings]) + ) + out_list.append(",".join([str(ind) for ind in sample_inds])) + out_list.append(",".join([str(rat) for rat in sample_ratings])) + else: + out_list += ["", "", "", ""] + # eval + ( + sample_inds, + sample_ratings, + sample_candidate_inds, + sample_candidate_ratings, + category_to_cnt, + ) = self.generate_one_timestamp( + category_to_cnt=category_to_cnt, + categories=categories, + t=int(self.num_timestamps * self.train_ratio), + id=id, + output_folder=output_folder, + uih_seq_len=self.num_eval_candidates, + eval=True, + inference=False, + file_idx=file_idx, + ts_buffers=ts_buffers, + ) + out_list.append(",".join([str(ind) for ind in sample_candidate_inds])) + out_list.append(",".join([str(rat) for rat in sample_candidate_ratings])) + out_list.append(",".join([str(ind) for ind in sample_inds])) + out_list.append(",".join([str(rat) for rat in sample_ratings])) + # inference + for t in range( + int(self.num_timestamps * self.train_ratio), self.num_timestamps + ): + if self.get_timestamp_sample(t): + ( + sample_inds, + sample_ratings, + sample_candidate_inds, + sample_candidate_ratings, + category_to_cnt, + ) = self.generate_one_timestamp( + category_to_cnt=category_to_cnt, + categories=categories, + t=t, + id=id, + output_folder=output_folder, + uih_seq_len=self.gen_rand_seq_len(), + eval=False, + inference=True, + file_idx=file_idx, + ts_buffers=ts_buffers, + ) + out_list.append(",".join([str(ind) for ind in sample_candidate_inds])) + out_list.append( + ",".join([str(rat) for rat in sample_candidate_ratings]) + ) + out_list.append(",".join([str(ind) for ind in sample_inds])) + out_list.append(",".join([str(rat) for rat in sample_ratings])) + else: + out_list += ["", "", "", ""] + return out_list + + def write_dataset( + self, output_folder: str, num_files: int, file_idx: int, seed: int + ) -> None: + """ + Write dataset for a single file partition. + + Args: + output_folder: Output directory path. + num_files: Total number of files in the dataset. + file_idx: Index of this file partition. + seed: Random seed for reproducibility. + """ + t0 = time.time() + num_users_per_file = self.num_users // num_files + user_id: int = num_users_per_file * file_idx + random.seed(seed + file_idx) + np.random.seed(seed + file_idx) + # Buffer timestamp data in memory to avoid excessive file I/O + ts_buffers: Dict[int, List[int]] = {} + output_file = output_folder + f"{file_idx}.csv" + with open(output_file, "w") as file: + writer = csv.writer(file) + for i in range(num_users_per_file): + out_list = self.generate_one_user( + id=user_id, + output_folder=output_folder, + file_idx=file_idx, + ts_buffers=ts_buffers, + ) + user_id += 1 + writer.writerow(out_list) + if i % 10000 == 0: + logger.warning( + f"rank {self.rank}: Done with users {i} for file {file_idx + 1} / {num_files}, total_cnt = {self.total_cnt}, spends {time.time() - t0} seconds." + ) + # Write buffered timestamp data after all users are processed + for ts, user_ids in ts_buffers.items(): + ts_file = output_folder + f"ts_{file_idx}_{ts}.csv" + with open(ts_file, "w") as f: + writer = csv.writer(f) + for uid in user_ids: + writer.writerow([uid]) + logger.warning( + f"rank {self.rank}: Wrote {len(ts_buffers)} timestamp files for file {file_idx}" + ) + + +def worker( + rank: int, + world_size: int, + num_files: int, + num_users: int, + num_items: int, + num_categories: int, + categories_per_user: int, + num_timestamps: int, + avg_samples_per_item: int, + num_eval_candidates: int, + num_inference_candidates: int, + train_ratio: float, + user_sampling_ratio: float, + output_folder: str, +) -> None: + """ + Worker function for parallel data generation. + + Each worker generates a subset of the dataset files. + + Args: + rank: Worker rank. + world_size: Total number of workers. + num_files: Total files to generate. + num_users: Total users in dataset. + num_items: Total items in catalog. + num_categories: Number of item categories. + categories_per_user: Categories per user. + num_timestamps: Number of time periods. + avg_samples_per_item: Average interactions per item. + num_eval_candidates: Eval candidates count. + num_inference_candidates: Inference candidates count. + train_ratio: Training data fraction. + user_sampling_ratio: User sampling probability. + output_folder: Output directory. + """ + generator = StreamingSyntheticDataGenerator( + num_categories=num_categories, + categories_per_user=categories_per_user, + num_users=num_users, + num_items=num_items, + num_timestamps=num_timestamps, + avg_samples_per_item=avg_samples_per_item, + train_ratio=train_ratio, + user_sampling_ratio=user_sampling_ratio, + num_eval_candidates=num_eval_candidates, + num_inference_candidates=num_inference_candidates, + debug=False, + rank=rank, + ) + num_files_per_rank = num_files // world_size + file_indices = [i + rank * num_files_per_rank for i in range(num_files_per_rank)] + for file_idx in file_indices: + logger.warning(f"rank {rank}: start generating file {file_idx}") + generator.write_dataset( + output_folder=output_folder, + num_files=num_files, + file_idx=file_idx, + seed=1001, + ) + logger.warning(f"rank {rank}: finish generating file {file_idx}") + + +def write_offset(output_folder: str, num_files: int, num_users: int) -> None: + """ + Write file byte offsets for random access to user data. + + Creates an offset.csv file containing byte positions for each user + within their respective data files. + + Args: + output_folder: Directory containing data files. + num_files: Number of data files. + num_users: Total number of users. + """ + with open(output_folder + "offset.csv", "a") as output_file: + writer = csv.writer(output_file) + for i in range(num_files): + input_file = output_folder + f"{i}.csv" + offsets = [] + with open(input_file, "r") as f: + while True: + offset = f.tell() + line = f.readline() + if not line: + break + offsets.append(offset) + assert len(offsets) == num_users // num_files, ( + f"num_users {num_users // num_files} != {len(offsets)}" + ) + logger.warning(f"offsets for file {i} finished") + writer.writerow([",".join([str(offset) for offset in offsets])]) + + +def write_ts_metadata(output_folder: str, total_ts: int, num_files: int) -> None: + """ + Write timestamp metadata for streaming simulation. + + Creates files tracking which users are active at each timestamp + and cumulative counts for efficient streaming access. + + Args: + output_folder: Output directory path. + total_ts: Total number of timestamps. + num_files: Number of data files. + """ + with open(output_folder + "requests_per_ts.csv", "w") as file_requests: + with open(output_folder + "users_cumsum_per_ts.csv", "w") as file_cumsum: + requests_writer = csv.writer(file_requests) + cumsum_writer = csv.writer(file_cumsum) + for ts in range(total_ts): + requests = [] + num_users_per_file = [] + for file in range(num_files): + with open(f"{output_folder}ts_{file}_{ts}.csv", "r") as file: + reader = csv.reader(file) + size = 0 + for row in reader: + requests.append(int(row[0])) + size += 1 + num_users_per_file.append(size) + cumsum = np.cumsum(num_users_per_file).tolist() + assert cumsum[-1] == len(requests) + requests_writer.writerow([",".join([str(r) for r in requests])]) + cumsum_writer.writerow([",".join([str(s) for s in cumsum])]) + logger.warning(f"ts {ts} finished") + with open( + output_folder + "requests_per_ts_offset.csv", "w" + ) as file_requests_offset: + writer = csv.writer(file_requests_offset) + input_file = output_folder + "requests_per_ts.csv" + offsets = [] + with open(input_file, "r") as f: + while True: + offset = f.tell() + line = f.readline() + if not line: + break + offsets.append(offset) + assert len(offsets) == total_ts, f"total_ts {total_ts} != {len(offsets)}" + logger.warning("offsets for file requests_per_ts.csv finished") + writer.writerow([",".join([str(offset) for offset in offsets])]) + + +def copy_sub_dataset(src_folder: str) -> None: + """ + Copy a subset of dataset files for quick testing. + + Creates a sampled_data subdirectory with essential files. + + Args: + src_folder: Source folder containing full dataset. + """ + dst_folder = src_folder + "sampled_data/" + files_to_copy = [ + "0.csv", + "offset.csv", + "requests_per_ts.csv", + "requests_per_ts_offset.csv", + "users_cumsum_per_ts.csv", + ] + os.makedirs(dst_folder, exist_ok=True) + for filename in files_to_copy: + src_path = os.path.join(src_folder, filename) + dst_path = os.path.join(dst_folder, filename) + shutil.copy2(src_path, dst_path) + logger.warning("Files copied successfully.") + + +def main() -> None: + """ + Main entry point for synthetic data generation. + + Configures and launches parallel workers to generate a complete + streaming recommendation dataset. + """ + processes = [] + num_files = 100 + num_users = 5_000_000 + num_items = 1_000_000_000 + num_categories = 128 + categories_per_user = 4 + num_timestamps = 100 + avg_samples_per_item = 50 + num_eval_candidates = 32 + num_inference_candidates = 2048 + train_ratio = 0.9 + user_sampling_ratio = 0.7 + world_size = 5 + username = os.getlogin() + output_folder = f"/home/{username}/data/streaming-100b/" + for i in range(world_size): + p = multiprocessing.Process( + target=worker, + args=( + i, + world_size, + num_files, + num_users, + num_items, + num_categories, + categories_per_user, + num_timestamps, + avg_samples_per_item, + num_eval_candidates, + num_inference_candidates, + train_ratio, + user_sampling_ratio, + output_folder, + ), + ) + processes.append(p) + p.start() + for p in processes: + p.join() + write_offset(output_folder, num_files, num_users) + write_ts_metadata(output_folder, num_timestamps, num_files) + copy_sub_dataset(src_folder=output_folder) + + +if __name__ == "__main__": + main() diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/tests/test_lifetime_auc_resume.py b/recommendation_v4/generative_recommenders/dlrm_v3/tests/test_lifetime_auc_resume.py new file mode 100644 index 000000000..3d5bbd1ee --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/tests/test_lifetime_auc_resume.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict + +""" +Round-trip correctness test for ``LifetimeAUCMetricComputation`` checkpoint +serialization. + +Background: torchrec's ``AUCMetricComputation`` registers its +PREDICTIONS/LABELS/WEIGHTS buffers with ``persistent=False``, so the default +``state_dict()`` returns them empty and a separate ``_num_samples`` counter is +dropped too. Without the overrides on ``LifetimeAUCMetricComputation`` every +checkpoint resume would silently restart the lifetime AUC from an empty buffer. + +These tests assert: + 1. update -> compute == A; state_dict -> load_state_dict on a fresh metric -> + compute == A (buffers survive the round trip). + 2. ``_num_samples`` round-trips exactly (required so the next update() does + not take the init-sentinel branch and desync windowed eviction). + 3. The shared-blob path (buffers stripped) leaves a fresh metric empty, so the + per-rank artifact is the sole authority for the trailing buffer. + +Runs in <1s on CPU. Skipped automatically if torchrec is unavailable. +""" + +import unittest + +import torch + +try: + from generative_recommenders.dlrm_v3.utils import LifetimeAUCMetricComputation + + _HAVE_DEPS = True +except Exception: # pragma: no cover - import guard for envs without torchrec + _HAVE_DEPS = False + + +def _make_metric(n_tasks: int = 1, window: int = 10_000_000): + return LifetimeAUCMetricComputation( + my_rank=0, + batch_size=128, + n_tasks=n_tasks, + window_size=window, + ) + + +def _feed(metric, preds, labels, weights) -> None: + metric.update( + predictions=preds, + labels=labels, + weights=weights, + ) + + +@unittest.skipUnless(_HAVE_DEPS, "torchrec / generative_recommenders not importable") +class LifetimeAUCResumeTest(unittest.TestCase): + def setUp(self) -> None: + torch.manual_seed(0) + self.n_tasks = 1 + self.n = 4096 + self.preds = torch.rand(self.n_tasks, self.n) + self.labels = (torch.rand(self.n_tasks, self.n) > 0.5).float() + self.weights = torch.ones(self.n_tasks, self.n) + + def _compute_value(self, metric) -> float: + reports = metric._compute() + return float(reports[0].value.flatten()[0].item()) + + def test_state_dict_round_trip_preserves_auc(self) -> None: + m = _make_metric(self.n_tasks) + _feed(m, self.preds, self.labels, self.weights) + auc_a = self._compute_value(m) + n_a = m.lifetime_sample_count() + self.assertEqual(n_a, self.n) + + sd = m.state_dict() + + fresh = _make_metric(self.n_tasks) + fresh.load_state_dict(sd) + auc_b = self._compute_value(fresh) + + self.assertEqual(fresh.lifetime_sample_count(), self.n) + self.assertAlmostEqual(auc_a, auc_b, places=6) + + def test_num_samples_round_trips(self) -> None: + m = _make_metric(self.n_tasks) + _feed(m, self.preds, self.labels, self.weights) + sd = m.state_dict() + fresh = _make_metric(self.n_tasks) + fresh.load_state_dict(sd) + self.assertEqual(fresh._num_samples, m._num_samples) + + def test_continued_update_after_resume_matches_uninterrupted(self) -> None: + # Splitting a stream and resuming in the middle must equal feeding it all + # at once (this is what fails when _num_samples is not restored). + half = self.n // 2 + p1, p2 = self.preds[:, :half], self.preds[:, half:] + l1, l2 = self.labels[:, :half], self.labels[:, half:] + w1, w2 = self.weights[:, :half], self.weights[:, half:] + + ref = _make_metric(self.n_tasks) + _feed(ref, p1, l1, w1) + _feed(ref, p2, l2, w2) + auc_ref = self._compute_value(ref) + + part = _make_metric(self.n_tasks) + _feed(part, p1, l1, w1) + resumed = _make_metric(self.n_tasks) + resumed.load_state_dict(part.state_dict()) + _feed(resumed, p2, l2, w2) + auc_resumed = self._compute_value(resumed) + + self.assertAlmostEqual(auc_ref, auc_resumed, places=6) + + def test_blob_state_dict_strips_buffers(self) -> None: + from generative_recommenders.dlrm_v3.checkpoint import ( + _metric_blob_state_dict, + ) + + m = _make_metric(self.n_tasks) + _feed(m, self.preds, self.labels, self.weights) + blob = _metric_blob_state_dict(m) + prefix = LifetimeAUCMetricComputation._LIFETIME_KEY_PREFIX + self.assertFalse(any(k.startswith(prefix) for k in blob.keys())) + + # A fresh metric loaded from the stripped blob must NOT have history — + # the per-rank artifact is the only source of the trailing buffer. + fresh = _make_metric(self.n_tasks) + fresh.load_state_dict(blob) + self.assertEqual(fresh.lifetime_sample_count(), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/train/_env_bootstrap.py b/recommendation_v4/generative_recommenders/dlrm_v3/train/_env_bootstrap.py new file mode 100644 index 000000000..5470d4e39 --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/train/_env_bootstrap.py @@ -0,0 +1,38 @@ +"""Gin-driven env-var bootstrap. + +Some env vars must be set *before* certain modules import (e.g. Triton's +`@triton.autotune` decorator reads `TRITON_FULL_AUTOTUNE` at module load +time, well before `gin.parse_config_file` runs in the default ordering). + +`apply_env_bootstrap()` is `@gin.configurable`, so the gin file becomes the +canonical source of truth. `train_ranker.py` parses gin with +`skip_unknown=True` early in `_main_func`, calls this function to push the +bindings into `os.environ`, then does the heavy imports. +""" + +import logging +import os +from typing import Optional + +import gin + +logger: logging.Logger = logging.getLogger(__name__) + + +@gin.configurable +def apply_env_bootstrap( + TRITON_FULL_AUTOTUNE: Optional[bool] = None, +) -> None: + # A pre-set environment variable wins over the gin binding. The pinned + # triton configs are MI350X-specific, so a different GPU arch (e.g. B200 + # sm_100) sets TRITON_FULL_AUTOTUNE=1 in the launcher environment to + # re-enable the full autotune search WITHOUT editing this (AMD-default) + # gin file. Cross-cluster launchers thus stay config-as-code via env. + if "TRITON_FULL_AUTOTUNE" in os.environ: + logger.info( + "env bootstrap: honoring pre-set TRITON_FULL_AUTOTUNE=%s (overrides gin binding)", + os.environ["TRITON_FULL_AUTOTUNE"], + ) + elif TRITON_FULL_AUTOTUNE is not None: + os.environ["TRITON_FULL_AUTOTUNE"] = "1" if TRITON_FULL_AUTOTUNE else "0" + logger.info("env bootstrap: TRITON_FULL_AUTOTUNE=%s", os.environ["TRITON_FULL_AUTOTUNE"]) diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/train/gin/debug.gin b/recommendation_v4/generative_recommenders/dlrm_v3/train/gin/debug.gin new file mode 100644 index 000000000..9261dc222 --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/train/gin/debug.gin @@ -0,0 +1,35 @@ +batch_size = 16 +dataset = "debug" + +# model parameters +make_model.dataset = %dataset + +# dense model optimizer +dense_optimizer_factory_and_class.learning_rate = 0.001 +dense_optimizer_factory_and_class.optimizer_name = "Adam" +dense_optimizer_factory_and_class.momentum = 0 +dense_optimizer_factory_and_class.weight_decay = 0 +dense_optimizer_factory_and_class.eps = 1e-8 +dense_optimizer_factory_and_class.betas = (0.95, 0.999) + +# sparse model optimizer +sparse_optimizer_factory_and_class.learning_rate = 0.001 +sparse_optimizer_factory_and_class.optimizer_name = "RowWiseAdagrad" +sparse_optimizer_factory_and_class.momentum = 0 +sparse_optimizer_factory_and_class.weight_decay = 0 +sparse_optimizer_factory_and_class.eps = 1e-8 +sparse_optimizer_factory_and_class.betas = (0.95, 0.999) + +# dataloader configs +make_train_test_dataloaders.batch_size = %batch_size +make_train_test_dataloaders.dataset_type = %dataset +make_train_test_dataloaders.train_split_percentage = 0.75 + +# train loop variables +train_loop.num_batches = 10 +train_loop.num_epochs = 1000 +train_loop.output_trace = True +train_loop.metric_log_frequency = 10 + +# logger variables +MetricsLogger.tensorboard_log_path = "/tmp/tensorboard_log_path.log" diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/train/gin/yambda_5b.gin b/recommendation_v4/generative_recommenders/dlrm_v3/train/gin/yambda_5b.gin new file mode 100644 index 000000000..5da1d31c7 --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/train/gin/yambda_5b.gin @@ -0,0 +1,669 @@ +# Per-rank batch size. Env-overridable so a denser per-sample shape can fit in +# HBM by lowering it without editing gin. Default 1024 preserves prior runs; +# streaming train+eval both read this macro (make_*streaming_dataloader.batch_size). +batch_size = @bs/env_int() +bs/env_int.key = "BATCH_SIZE" +bs/env_int.default = 1024 +# Dataloader parallelism. Env-overridable so a perf sweep can probe whether the +# shuffle steady-state cost is CPU-gather latency (hidden by more workers) vs +# GPU-side embedding work (not). Defaults preserve prior behavior (4 / 8). +num_workers = @nw/env_int() +nw/env_int.key = "NUM_WORKERS" +nw/env_int.default = 4 +prefetch_factor = @pf/env_int() +pf/env_int.key = "PREFETCH_FACTOR" +pf/env_int.default = 8 +dataset = "yambda-5b" + +# model parameters +make_model.dataset = %dataset +make_model.bf16_training = True +# HSTU attention/compute backend: "TRITON" (fused, flash-style — low HBM) or +# "PYTORCH" (unfused; materializes the dense [B,H,N,N] score tensor, ~32 GiB at +# N=2048/bs=1024). TRITON validated on MI350/ROCm. The HSTU_HAMMER_KERNEL env +# var, if set, overrides this binding for one-off runs. +make_model.hammer_kernel = "TRITON" + +# False = use pinned triton kernel configs (deterministic; whether that's +# the fast or slow equilibrium depends on which config was pinned for the +# current training shape + GPU). For a NEW training config (new shape, +# new GPU, new triton/torch version), set True and run with +# TRITON_PRINT_AUTOTUNING=1 to discover the fast configs, then update the +# pinned constants in ops/triton/_autotune_pinning.py call sites. +apply_env_bootstrap.TRITON_FULL_AUTOTUNE = False + +# ============================================================================= +# $SEED — global RNG seed for reproducible MODEL INITIALIZATION. +# +# WHAT IT CONTROLS (all weight init is a deterministic function of $SEED): +# 1. Dense parameters (HSTU transformer blocks, MLPs, action embeddings, the +# postprocessor). seed_everything() seeds python/numpy/torch/torch.cuda +# with $SEED right before make_model(). The SAME seed is set on every rank, +# so dense weights are initialized identically across ranks AND reproducibly +# run-to-run. +# 2. Sparse embedding tables (item_id, artist_id, album_id, uid, the cross +# tables). These are materialized on-device while DMP shards the model, NOT +# in make_model(), so seed_everything() alone does not pin them. Two +# mechanisms tie them to $SEED: +# (a) make_optimizer_and_shard() RE-SEEDS torch/torch.cuda from $SEED +# immediately before DistributedModelParallel(...), so the fused +# FBGEMM TBE init (which draws from the global RNG on-device) is +# reproducible. This is the path that actually applies here. +# (b) get_embedding_table_config() also attaches a per-table seeded +# init_fn (seed derived from sha256($SEED, table_name)) for the +# eager/non-meta code path. It no-ops on the meta device that DMP +# uses, so (a) is the effective guarantee for this setup. +# 3. Any other seeded RNG consumers (e.g. dropout's init-time draws). +# +# SCOPE ("Tier 1"): reproducible for a FIXED sharding plan — same GPU/world +# size AND same planner output. It is NOT invariant to changing the GPU count +# or sharding (the per-shard draw boundaries move). The init DISTRIBUTION is +# unchanged from stock (uniform +/-1/sqrt(num_embeddings), or a table's +# explicit weight_init_min/max), so $SEED affects determinism, not quality. +# +# VERIFY: INIT_CHECKSUM=1 (OFF by default) logs a per-table fingerprint + a +# one-line "[init-checksum] SEED=.. digest=.." right after DMP. Two builds with +# the same $SEED + plan print the same digest; different seeds differ. It is OFF +# by default because the fp64 per-shard reductions materialize a full fp64 copy +# of each local embedding shard (>150 GiB for the big tables), leaving almost no +# HBM headroom after sharding and OOMing the build on any node with residual +# memory. Enable only for an explicit reproducibility check (ideally a clean +# node / smaller batch). +# +# WHAT IT DOES *NOT* CONTROL (separate, independent knobs): +# - Streaming data order / shuffle permutation -> $STREAMING_SHUFFLE_SEED +# (get_dataset.streaming_shuffle_seed, below). +# - Train/eval holdout user split -> $SPLIT_SALT (below). +# So holding $STREAMING_SHUFFLE_SEED + $SPLIT_SALT fixed and varying ONLY $SEED +# isolates the effect of model initialization (a clean init-seed A/B / sweep). +# +# PARSE NOTE: seed_everything() runs right before make_model() in train_ranker +# (after the full gin parse), so this binding resolves in the second parse where +# env_int is registered. Override per-run via $SEED. +# Default 1 gives a fixed, reproducible seed each run; override $SEED to vary it +# (set $SEED = -1 to draw a fresh random seed per run). +seed_everything.seed = @seed/env_int() +seed/env_int.key = "SEED" +seed/env_int.default = 1 + +# $DECORRELATE_DROPOUT — re-seed torch/cuda with $SEED + rank after init so HSTU +# dropout masks differ per data-parallel rank. 1 = on, 0 = identical masks (default). +decorrelate_runtime_rng.enabled = @drr/env_int() +drr/env_int.key = "DECORRELATE_DROPOUT" +drr/env_int.default = 0 + +# dense model optimizer +# Learning rate is env-overridable via $DENSE_LR (default 1e-7). +dense_optimizer_factory_and_class.learning_rate = @dlr/env_float() +dlr/env_float.key = "DENSE_LR" +dlr/env_float.default = 0.0000001 +dense_optimizer_factory_and_class.optimizer_name = "Adam" +dense_optimizer_factory_and_class.momentum = 0 +dense_optimizer_factory_and_class.weight_decay = 0 +dense_optimizer_factory_and_class.eps = 1e-8 +dense_optimizer_factory_and_class.betas = (0.95, 0.999) + +# sparse model optimizer +# Learning rate is env-overridable via $SPARSE_LR (default 1e-7). +sparse_optimizer_factory_and_class.learning_rate = @slr/env_float() +slr/env_float.key = "SPARSE_LR" +slr/env_float.default = 0.0000001 +sparse_optimizer_factory_and_class.optimizer_name = "RowWiseAdagrad" +sparse_optimizer_factory_and_class.momentum = 0 +sparse_optimizer_factory_and_class.weight_decay = 0 +sparse_optimizer_factory_and_class.eps = 1e-8 +sparse_optimizer_factory_and_class.betas = (0.95, 0.999) + +# Gradient clipping for the STREAMING path (clips dense params; sparse tables +# use a fused optimizer and are unaffected). Env-overridable via $GRAD_CLIP_NORM. +# Default 1.0 = ON (max_norm=1.0); set 0.0 via $GRAD_CLIP_NORM to disable. +streaming_train_eval_loop.grad_clip_norm = @gcn/env_float() +gcn/env_float.key = "GRAD_CLIP_NORM" +gcn/env_float.default = 1.0 + +# Data root: resolved at runtime from $DLRM_DATA_PATH if set, else the literal +# below. Used by both make_train_test_dataloaders and get_dataset. +# Scoped (`data/env_path`) so this binding doesn't collide with the RUN_NAME +# env_path binding below — every distinct env_path() call site needs its own +# scope or the later `env_path.key=...` overrides earlier ones. +DATA_PATH = @data/env_path() +data/env_path.key = "DLRM_DATA_PATH" +data/env_path.default = "/apps/chcai/dlrm_data" + +# Shared train:eval split: fraction of USERS used for training; the remaining +# (1 - this) fraction are held out as a fixed eval set and NEVER trained. +# Bound to BOTH the static train-eval path (make_train_test_dataloaders, a +# positional split) and the streaming path (get_dataset, an explicit by-user +# hash split), so one value configures the holdout in either mode. +# 1.0 = no holdout (legacy streaming behavior). Override via $TRAIN_SPLIT_PERCENTAGE. +# Default 1.0: all users are trained AND evaluated (full-coverage eval), matching +# the alleval/qa2a production runs; set <1.0 (e.g. 0.90) for a clean held-out cohort. +TRAIN_SPLIT_PERCENTAGE = @tsp/env_float() +tsp/env_float.key = "TRAIN_SPLIT_PERCENTAGE" +tsp/env_float.default = 1.0 + +# dataloader configs +make_train_test_dataloaders.batch_size = %batch_size +make_train_test_dataloaders.eval_batch_size = 1024 +make_train_test_dataloaders.dataset_type = %dataset +make_train_test_dataloaders.train_split_percentage = %TRAIN_SPLIT_PERCENTAGE +make_train_test_dataloaders.new_path_prefix = %DATA_PATH +make_train_test_dataloaders.num_workers = %num_workers +make_train_test_dataloaders.prefetch_factor = %prefetch_factor +make_train_test_dataloaders.num_blocks = 1 + +# embedding planner: per-rank HBM ceiling the torchrec sharder targets. +# Override via $HBM_CAP_GB (e.g. lower to 150 to force more CW sharding). +make_optimizer_and_shard.hbm_cap_gb = @env_int() +env_int.key = "HBM_CAP_GB" +env_int.default = 260 + +# Embedding table placement. Global default applied to EVERY table: +# auto -> no constraint; planner decides from the HBM cap (default) +# hbm -> FUSED (resident in GPU HBM) +# uvm -> FUSED_UVM (host DDR via UVM, no HBM cache) +# uvm_caching -> FUSED_UVM_CACHING (host DDR + HBM cache) +# "force HBM" for all tables = set this to "hbm" (or export EMB_PLACEMENT=hbm). +make_optimizer_and_shard.embedding_placement = @emp/env_str() +emp/env_str.key = "EMB_PLACEMENT" +emp/env_str.default = "auto" + +# Per-table placement overrides (win over the global default above). Configure +# each table independently here as a gin dict; allowed values per table: +# "hbm" | "uvm" | "uvm_caching" | "auto" (unlisted tables use the global +# default above). merge=True means a per-run env override LAYERS ON TOP of this +# dict per key (tweak one table at launch, keep the rest), e.g.: +# EMB_PLACEMENT_OVERRIDES="uid=hbm" # only retargets uid; others stay as below +make_optimizer_and_shard.embedding_placement_overrides = @env_str_map() +env_str_map.key = "EMB_PLACEMENT_OVERRIDES" +env_str_map.merge = True +env_str_map.default = {} + +# Per-table SHARDING-TYPE overrides (shard layout, orthogonal to placement above). +# Absent/"auto" tables keep the planner's choice (ROW_WISE for the large yambda +# tables). Allowed per table: "row_wise"|"column_wise"|"table_wise"| +# "table_row_wise"|"table_column_wise"|"data_parallel"|"auto" (aliases rw/cw/tw/ +# twrw). DEFAULT OFF ({}) -> plan is byte-identical to the legacy all-ROW_WISE +# path. Opt in per run via env, e.g.: +# EMB_SHARDING_OVERRIDES="album_id=column_wise,artist_id=column_wise" +# WHY: ROW_WISE routes every lookup of row r to the single owner rank, so a few +# hot IDs concentrate the embedding all-to-all onto one rank and OOM it (the +# yambda-5b skew hang; album_id ~2.8x, artist_id ~1.3x per-rank load). COLUMN_WISE +# splits the table by embedding dim (every rank holds all rows, dim/world cols), +# so the a2a load is balanced by RANK regardless of which IDs are hot — removing +# the value-skew OOM — at identical per-rank table bytes. Convert only the skewed +# high-volume tables (album_id, artist_id); leave the balanced, highest-volume +# item_id and the tiny length-1 contextual/cross tables on ROW_WISE. +# WHY THIS GETS WORSE AT LARGER GLOBAL BATCH: the ROW_WISE a2a input buffer on the +# owner rank is sized by how many times its hot IDs appear across the WHOLE global +# batch, so that transient scales ~linearly with global batch size (here 32 ranks +# x 1024 = 32768, each carrying ~4096-token UIH sequences -> tens of millions of +# lookups/step, heavily re-hitting the same few popular albums/artists). Doubling +# the global batch ~doubles the hot-rank burst while every other rank stays idle, +# which is exactly what tipped GPU5/GPU3 from a saturated steady state (~208-238 +# GiB) over 288 GiB at window ~248. COLUMN_WISE makes each rank receive dim/world +# of EVERY lookup, so the per-rank a2a volume is ~global_batch/world (balanced) +# and grows with world size, not with which IDs are hot -> it scales cleanly as +# you push global batch / add ranks, instead of piling the growth onto one shard. +# Example (the yambda-5b 4-node run, global batch 32768, that hit the skew OOM): +# EMB_SHARDING_OVERRIDES="album_id=column_wise,artist_id=column_wise" +# validated by the CW smoke: reshard-loaded the ROW_WISE ckpt cleanly, window_auc +# stayed ~0.78-0.80, and the hot rank sat at ~58% (~120 GiB free) THROUGH the old +# OOM window (vs 16 GiB free right before the ROW_WISE crash). +# NOTE: changing the sharding plan changes the on-disk shard layout; a checkpoint +# written under a different plan must be resharded/validated on load. +make_optimizer_and_shard.embedding_sharding_overrides = @esh/env_str_map() +esh/env_str_map.key = "EMB_SHARDING_OVERRIDES" +esh/env_str_map.merge = True +esh/env_str_map.default = {} + +# Sparse embedding all-to-all wire precision. The embedding shuffle is the +# dominant, bandwidth-bound (esp. multi-node) collective; quantizing it via +# TorchRec QCommsConfig halves (bf16/fp16, both 2 bytes) the wire volume. +# Forward and backward are set independently (each: "fp32" | "bf16" | "fp16"). +# Both "fp32" = off (numerically identical to baseline trunk). +# Default is now fp16/fp16: an A/B run vs the fp32 baseline matched window AUC to +# within fixed-seed noise (mean Δ≈-1e-6, max|Δ|≈5e-5, r=1.0 over 0-54.5% data) +# while cutting the embedding-a2a wire volume in half (~6% end-to-end speedup). +# TorchRec golden_training suggests fwd=fp16 / bwd=bf16 (bf16's wider exponent +# hedges gradient overflow), but on yambda-5b the grads stay well in fp16 range +# (grad-clip=1.0, lr=1e-7), so fp16/fp16 is convergence-neutral here. +# Override via $SPARSE_A2A_FWD / $SPARSE_A2A_BWD (e.g. set both "fp32" to disable). +make_optimizer_and_shard.sparse_a2a_forward_precision = @saaf/env_str() +saaf/env_str.key = "SPARSE_A2A_FWD" +saaf/env_str.default = "fp16" +make_optimizer_and_shard.sparse_a2a_backward_precision = @saab/env_str() +saab/env_str.key = "SPARSE_A2A_BWD" +saab/env_str.default = "fp16" + +# ============================================================================= +# RUNTIME PATCHES / MONKEYPATCHES +# ----------------------------------------------------------------------------- +# Third-party kernels we override AT RUNTIME (no fork of the dependency). Each +# knob here is a kill switch for one such patch: ON by default (it fixes a real +# bug we hit), but flippable per-run so we can A/B the patch, reproduce the +# original failure, or fall back to stock if a future dependency version makes +# the patch unnecessary or incorrect. Patches are applied during model +# build/shard (make_optimizer_and_shard), before any training step. +# ============================================================================= +# +# --- qcomm_lowmem_clamp_cast: low-memory fbgemm fp16/bf16 quant codec --------- +# WHAT IT PATCHES +# fbgemm_gpu's embedding-a2a quantizer, fp32_to_fp16_with_clamp (and the bf16 +# variant), which the TorchRec qcomm codec calls to pack the sparse +# embedding all-to-all payload onto the wire. Stock implementation is: +# torch.clamp(tensor, HALF_MIN, HALF_MAX).half() +# torch.clamp() materializes a SECOND full-size fp32 tensor (same numel as the +# input) BEFORE the cast, so the transient peak is +# input(fp32) + clamp_temp(fp32) + output(fp16) ~= 2.5x the input. +# The patch reorders to an in-place, allocation-free-equivalent: +# tensor.half().clamp_(HALF_MIN, HALF_MAX) +# i.e. cast FIRST (only the fp16 output is allocated), then clamp IN PLACE — +# dropping the full-size fp32 clamp temp and cutting the transient peak by the +# size of the input tensor. +# +# WHY IT'S NEEDED (the bug it fixes) +# Embeddings are ROW-WISE sharded: every lookup of row r routes to the single +# rank that owns r. On a data-skewed batch a few extremely hot IDs send a +# disproportionate share of the global lookups to ONE owner rank, so that +# rank's a2a input tensor balloons. With the stock codec the extra fp32 clamp +# temp on that rank reached ~81.5 GiB, which OOM'd the rank mid-forward. The +# OOM'd rank then dropped out of the collective while its peers blocked forever +# in the embedding all-to-all -> ~30-min NCCL watchdog timeout -> the whole job +# SIGABRTs. This was a DETERMINISTIC hang (same window/step every run; e.g. +# yambda-5b 4-node fp16 a2a hit it at window 235 / global step 43621). +# See generative_recommenders/dlrm_v3/train/utils.py +# (_patch_fbgemm_lowmem_clamp_cast) for the full diagnosis + flight-recorder +# evidence. +# +# CORRECTNESS / PERF +# Numerically IDENTICAL (bit-for-bit) to stock: the single fp32->fp16 cast is +# the only rounding step in both orders; an fp32 value above HALF_MAX casts to +# +inf which clamp_ maps back to HALF_MAX, and NaNs pass through unchanged. +# In-place is safe because the codec encode() runs inside the qcomm autograd +# Function.forward with grad disabled (no graph to corrupt). No throughput +# regression (measured step time within run-to-run noise; it does strictly +# LESS memory traffic + allocation than stock). +# +# WHEN TO TURN OFF (set 0): only to reproduce the original OOM/hang for +# debugging, or to revalidate against stock after an fbgemm_gpu upgrade. +# Override via $QCOMM_LOWMEM_CODEC. Only takes effect when the embedding a2a is +# actually quantized (SPARSE_A2A_FWD/BWD not both fp32); with fp32 a2a there is +# no codec to patch and this knob is a no-op. +make_optimizer_and_shard.qcomm_lowmem_clamp_cast = @qlcc/env_int() +qlcc/env_int.key = "QCOMM_LOWMEM_CODEC" +qlcc/env_int.default = 1 + +get_dataset.name = %dataset +get_dataset.new_path_prefix = %DATA_PATH +# Total user-interaction-history (UIH) budget per sample, distributed evenly +# across 3 behaviour pools (listen+ / like / skip) at L//3 events each. +# Per-sample sequence the model sees = +# 3 × (L // 3) + 8 contextual + 1 candidate +# Choosing 4086 makes 3 × 1362 + 9 = 4095, the largest value that fits +# get_hstu_configs.max_seq_len = 4096 with no dataset-side truncation (the +# 4k analog of the previous 2039/2048 shape, where 3×679+9=2046 ≤ 2048). +# Larger L overflows the budget; the dataset truncates UIH events to fit. +# Note: like events are only 1.9% of the yambda corpus and max user lifetime +# is ~28k events, so the like pool fills to ~105 events per anchor on +# average (not 1362) — TRITON's jagged attention skips the unfilled slots, +# so the under-fill costs sequence budget but not GPU compute. +# Cache is keyed by L on disk under /hstu_cache_L/; +# switching L reuses an existing cache or builds a new one. Override via +# $HISTORY_LENGTH (default 4086 = the 4k-no-truncation shape; use 2039 with +# MAX_SEQ_LEN=2048 to reuse the previous 2k single-task cache). +get_dataset.history_length = @hl/env_int() +hl/env_int.key = "HISTORY_LENGTH" +hl/env_int.default = 4086 + +# UIH construction strategy — how a user's prior events become the sequence the +# model attends over. Override via $HISTORY_STRATEGY. Both strategies scan the +# last `scan_window` (20000) events before the anchor and consider ONLY the 3 +# behavior pools (listen+ / like / skip); dislike/unlike/undislike are excluded +# (no action-weight bit), so the model's action_weights=[1,2,4] are unchanged. +# The two differ only in HOW they pick events out of that scan window: +# +# "interleaved" (default): +# - Budget = equal per-pool quota of HISTORY_LENGTH//3 (=1362) events; take +# the last 1362 of EACH pool independently, then merge + re-sort by time. +# - HISTORY_LENGTH is thus a PER-POOL cap (nominal total = 3 * L//3). +# - Consequence: likes are only ~1.9% of the corpus, so the like pool +# under-fills (~105 of its 1362 slots) and that budget is NOT reallocated +# — the sequence comes up short. Measured: ~2663 events effective, ~4.0% +# of them likes (i.e. likes OVER-represented vs their natural rate). +# - Use when you want guaranteed like exposure per sample (class-balanced UIH). +# +# "last_n": +# - Budget = a single pool-agnostic cap: take the last HISTORY_LENGTH events +# of ANY pool type, already chronological (no per-pool split, no re-sort). +# - HISTORY_LENGTH is thus the LITERAL TOTAL UIH cap (not a per-pool L//3). +# - Consequence: the sequence fills to ~HISTORY_LENGTH (the only limit is how +# many pooled events exist in the scan window), so effective length is +# higher and the like share falls to its natural rate. Measured: ~4085 +# events effective (~1.5x interleaved), ~1.2% of them likes. +# - Trade-off: more recent listen+/skip context (longer effective sequence) +# at the cost of fewer likes; ~1.4x eval/step compute (more non-skipped +# jagged-attention positions). Keep HISTORY_LENGTH=4086 to fill the 4k +# model budget. +# +# Both strategies are strategy-INDEPENDENT on disk: the gather runs at +# sample-build time, so switching reuses the SAME hstu_cache_L/ +# and positions cache — no rebuild needed. +get_dataset.history_strategy = @hs/env_str() +hs/env_str.key = "HISTORY_STRATEGY" +hs/env_str.default = "interleaved" + +# Anchor-eligibility floor: a LISTEN event becomes a trainable/eval anchor once +# the user has >= MIN_HISTORY prior events. Decoupled from history_length (which +# is only the gather/truncation cap) — jagged attention handles short UIH, so we +# no longer need a full history_length of context to include a sample. The legacy +# behavior (require a full history_length, which dropped ~60% of users) is +# MIN_HISTORY=4086; the default below (0) includes ~all users AND their +# cold-start first event (zero prior context). Positions/anchor-ts caches are +# keyed by (L, MIN_HISTORY) so floors don't collide. Override via $MIN_HISTORY. +get_dataset.min_history = @mh/env_int() +mh/env_int.key = "MIN_HISTORY" +mh/env_int.default = 4086 + +# Model-side attention budget. Dataset truncates UIH to fit this value if +# `history_length + contextual + candidate` would overflow. Override via +# $MAX_SEQ_LEN (default 4096, the 4k-no-truncation shape paired with +# HISTORY_LENGTH=4086: 3*1362+9=4095 ≤ 4096). Set MAX_SEQ_LEN=2048 with +# HISTORY_LENGTH=2039 for the previous 2k production single-task shape. +get_hstu_configs.max_seq_len = @msl/env_int() +msl/env_int.key = "MAX_SEQ_LEN" +msl/env_int.default = 4096 + +# HSTU transformer depth (number of attention layers). Default 3. +# Override per-run via $HSTU_NUM_LAYERS. NOTE: changing depth changes the model +# shape, so a run with a new depth MUST use a FRESH CKPT_PATH (incompatible with +# checkpoints of a different depth). Resolved in the full gin parse +# (get_hstu_configs is not registered during the early skip_unknown parse), so +# the @env_int reference is skipped on the first pass — same safe path as the LR knobs. +get_hstu_configs.hstu_attn_num_layers = @nl/env_int() +nl/env_int.key = "HSTU_NUM_LAYERS" +nl/env_int.default = 3 + +# --- streaming (temporal-order) training ------------------------------------- +# Only consumed under `--mode streaming-train-eval`; the default train-eval +# path above is unaffected. Trains time window T then evals window T+1, +# advancing forward in wall-clock time (no future leakage). Window size is the +# temporal-ordering granularity knob (default 1 day). num_train_ts is clamped +# to the dataset's available window count at runtime; override via $NUM_TRAIN_TS. +get_dataset.streaming_window_seconds = 86400 +get_dataset.streaming_sort_within_window = False +# In-window shuffle to break user-major batching (consecutive sliding-window +# anchors otherwise come from the same few users -> few unique embedding reads). +# Diversity dial in [0,1] -- the AGREED, config-invariant benchmark knob. Maps to +# a within-segment shuffle with K = round(fraction * per-window train-anchor +# count): 0 = off (user-major, page-local mmap scans), 1 = full per-element +# shuffle (max diversity), intermediate = interpolation. Same fraction => same +# diversity regardless of world_size / #nodes / batch_size. +# +# Default 0.0 (no shuffle -> user-major, page-local mmap scans) so the standard +# run matches the production streaming order. Together with the fixed seed below +# the in-window order is fully DETERMINISTIC and identical across runs/resumes. +# Override per-run via $STREAMING_SHUFFLE_FRACTION (e.g. 1.0 for a full +# per-element shuffle / max diversity, 0.03 for the diversity/locality sweet spot). +streaming_shuffle_fraction = 0.0 +get_dataset.streaming_shuffle_fraction = @ssf/env_float() +ssf/env_float.key = "STREAMING_SHUFFLE_FRACTION" +ssf/env_float.default = %streaming_shuffle_fraction +# Fixed shuffle seed -> reproducible permutation. Exposed as a knob; override via +# $STREAMING_SHUFFLE_SEED only if you deliberately want a different draw. +get_dataset.streaming_shuffle_seed = @ssfseed/env_int() +ssfseed/env_int.key = "STREAMING_SHUFFLE_SEED" +ssfseed/env_int.default = 0 +# User-level train:eval holdout for the streaming path. With tsp<1.0, the top +# (1 - tsp) fraction of users (by a deterministic hash of uid+split_salt) are +# held out as a FIXED eval set and never trained -> no temporal/user leakage, +# stable comparable eval curve, bounded eval time. split_salt lets you draw a +# different holdout without changing the ratio. Override salt via $SPLIT_SALT. +get_dataset.train_split_percentage = %TRAIN_SPLIT_PERCENTAGE +get_dataset.split_salt = @ssalt/env_int() +ssalt/env_int.key = "SPLIT_SALT" +ssalt/env_int.default = 0 +make_streaming_dataloader.batch_size = %batch_size +make_streaming_dataloader.num_workers = %num_workers +make_streaming_dataloader.prefetch_factor = %prefetch_factor +make_persistent_streaming_dataloader.batch_size = %batch_size +make_persistent_streaming_dataloader.num_workers = %num_workers +make_persistent_streaming_dataloader.prefetch_factor = %prefetch_factor +streaming_train_eval_loop.num_train_ts = @nts/env_int() +nts/env_int.key = "NUM_TRAIN_TS" +# 299 daily windows -> with start_ts=0 the run sweeps ts=0..298, the full corpus +# (matches the long fp16 e2e run). Clamped to the dataset's available window +# count at runtime. Override via $NUM_TRAIN_TS. +nts/env_int.default = 299 +# Anchors need >= history_length prior events, so the first ~130 daily windows +# are near-empty warm-up. Default start_ts=0 trains the full corpus from the +# start (matches the canonical fp16 run); set $START_TS=150 to skip warm-up and +# begin at a dense window. Override via $START_TS. +streaming_train_eval_loop.start_ts = @sts/env_int() +sts/env_int.key = "START_TS" +sts/env_int.default = 0 +# Per-step metric logging cadence. Default 50 (one compute_and_log GPU->CPU +# sync per 50 batches). The streaming-resume test sets METRIC_LOG_FREQ=1 so +# every step emits a parseable "Step N metrics" line for trajectory comparison. +streaming_train_eval_loop.metric_log_frequency = @mlf/env_int() +mlf/env_int.key = "METRIC_LOG_FREQ" +mlf/env_int.default = 50 +# MLPerf train_loss event cadence (global train steps), INDEPENDENT of +# METRIC_LOG_FREQ above. 0 (default) = fall back to METRIC_LOG_FREQ, preserving +# the prior coupled behavior. Set $MLPERF_TRAIN_LOSS_LOG_FREQ>0 to log the MLPerf +# train_loss event at a different rate than the console/TB metrics. Disable the +# whole MLPerf stream with $MLPERF_LOGGING=0. +streaming_train_eval_loop.mlperf_train_loss_log_frequency = @mltlf/env_int() +mltlf/env_int.key = "MLPERF_TRAIN_LOSS_LOG_FREQ" +mltlf/env_int.default = 0 +# Diagnostic: log per-batch unique/total embedding-id counts on logged steps +# (rank 0). Quantifies the user-major batching redundancy and the realized +# diversity from get_dataset.streaming_shuffle_fraction. Off; set $DIAG_UNIQUE_EMB=1. +streaming_train_eval_loop.streaming_diag_unique_emb = @due/env_int() +due/env_int.key = "DIAG_UNIQUE_EMB" +due/env_int.default = 0 +# Chrome trace capture: reuses the shared Profiler.* bindings below (5-step +# window at step 52). The streaming step counter advances across train+eval +# batches, so step 52 lands in the first (train) window's compute. Off by +# default to avoid profiler overhead in production runs; set $OUTPUT_TRACE=1. +streaming_train_eval_loop.output_trace = @ot/env_int() +ot/env_int.key = "OUTPUT_TRACE" +ot/env_int.default = 0 +# Reuse one DataLoader (persistent workers) across windows instead of respawning +# per window. Skip eval to isolate window-reset cost. Override via env. +streaming_train_eval_loop.persistent_loader = @pl/env_int() +pl/env_int.key = "PERSISTENT_LOADER" +pl/env_int.default = 1 +# ---- Eval cadence: choose EXACTLY ONE of the two knobs below ---------------- +# The streaming loop can decide WHEN to run the full-holdout eval pass in one of +# two ways. They are MUTUALLY EXCLUSIVE — enabling the data-fraction cadence +# (EVAL_EVERY_DATA_PCT>0) requires EVAL_EVERY_N_WINDOWS=0; setting both >0 raises +# a ValueError at startup. The final end-of-run eval always runs in either mode. +# +# (1) PER-WINDOW cadence (EVAL_EVERY_N_WINDOWS). +# Full-holdout eval cadence (single knob; replaces the old EVAL_EACH_WINDOW +# on/off switch). 0 (default) = per-window cadence OFF -> defer to the +# data-fraction cadence below (EVAL_EVERY_DATA_PCT); if that is also 0, eval is +# disabled entirely (train-only, e.g. perf benchmarking or the resume test; the +# eval dataloader isn't even built). 1 = eval after every window. N>1 (e.g. 5 +# via $EVAL_EVERY_N_WINDOWS) = eval every Nth window (and always the final one) +# to amortize the cost of consuming the full next-day eval window. The cadence is +# anchored to the absolute ts grid so eval points stay stable across a mid-run +# resume. +# NOTE: each daily window has a DIFFERENT number of training samples, so a +# per-window cadence produces eval points that are UNEVENLY spaced in terms of +# how much data was trained between them. This is why the data-fraction cadence +# below is now the default; enable this per-window knob only if you specifically +# want eval anchored to the daily window grid. +streaming_train_eval_loop.eval_every_n_windows = @evn/env_int() +evn/env_int.key = "EVAL_EVERY_N_WINDOWS" +evn/env_int.default = 0 +# +# (2) DATA-FRACTION cadence (EVAL_EVERY_DATA_PCT). +# Run the full-holdout eval every time the run has trained this FRACTION of the +# run's TOTAL training data, so eval points are EVENLY spaced by data volume +# (compute), independent of how many samples each daily window happens to hold. +# This is the fix for the per-window cadence's uneven spacing noted above. +# value semantics (it is a fraction in (0, 1], NOT a percent number): +# 0.0 = OFF -> fall back to the per-window EVAL_EVERY_N_WINDOWS; +# if that is also 0, eval is disabled entirely (train-only). +# 0.005 (default)= eval every 0.5% of the data -> ~200 eval points total. +# 0.01 = eval every 1% of the data -> ~100 eval points total. +# 0.05 = eval every 5% of the data -> ~20 eval points total. +# 0.10 = eval every 10% of the data -> ~10 eval points total. +# How "fraction of data" becomes an eval trigger: at startup the fraction is +# converted ONCE into a global train-step interval +# eval_interval_steps = round(pct * total_train_anchors / (batch_size * world_size)) +# where total_train_anchors is counted over the ORIGINAL requested window range +# [start_ts, start_ts+num_train_ts). Eval then fires whenever the monotonic +# train global_step crosses a multiple of that interval (it can fire MID-WINDOW, +# i.e. partway through a daily window, and across window boundaries). Because the +# interval is computed over the original range and global_step is +# checkpoint-restored, the eval grid is identical on a cold start and on every +# resume. Each eval record's label carries "@step=" so the +# trajectory can be plotted against data volume. Override via $EVAL_EVERY_DATA_PCT. +streaming_train_eval_loop.eval_every_data_pct = @edp/env_float() +edp/env_float.key = "EVAL_EVERY_DATA_PCT" +edp/env_float.default = 0.005 +# Double-buffer windows: prepare the next window (index mask + first-batch +# prefetch) in a background thread during the current window's compute, hiding +# the per-window reset. Needs persistent_loader=1. Override via env. +streaming_train_eval_loop.double_buffer = @db/env_int() +db/env_int.key = "DOUBLE_BUFFER" +db/env_int.default = 1 +# Fixed eval-holdout window range (held-out users' anchors over these windows +# form the eval set evaluated at EVERY eval step). Default 299 = the window just +# past the ts=0..298 training sweep (matches the canonical fp16 run). Set +# EVAL_HOLDOUT_TS<0 to instead resolve at runtime to start_ts+num_train_ts (also +# stable across resume). EVAL_HOLDOUT_NUM_WINDOWS widens the eval span. +streaming_train_eval_loop.eval_holdout_ts = @eht/env_int() +eht/env_int.key = "EVAL_HOLDOUT_TS" +eht/env_int.default = 299 +streaming_train_eval_loop.eval_holdout_num_windows = @ehnw/env_int() +ehnw/env_int.key = "EVAL_HOLDOUT_NUM_WINDOWS" +ehnw/env_int.default = 1 +# num_train_batches / num_eval_batches unset => consume each full window. +# Set them (e.g. via gin) to cap per-window steps for short experiments. + +# Default (non-streaming) train-eval loop variables; used unless +# `--mode streaming-train-eval` selects the temporal-order path configured above. +train_eval_loop.num_epochs = 1 +train_eval_loop.metric_log_frequency = 50 +train_eval_loop.eval_frequency = 500 +train_eval_loop.num_eval_batches = 200 +train_eval_loop.checkpoint_frequency = 1000000000 # disable mid-training checkpoints (disk-full guard) +train_eval_loop.output_trace = True +# 3-stage TorchRec pipeline: overlaps the embedding all-to-all (the dominant +# exposed-comm bottleneck) with dense fwd/bwd. Set False to fall back to the +# sequential fwd/bwd loop. +train_eval_loop.use_pipeline = False + +# Run name → recommendation_v4/results// (override via $RUN_NAME env). +RUN_NAME = @run/env_path() +run/env_path.key = "RUN_NAME" +run/env_path.default = "default" +run_results_dir.run_name = %RUN_NAME + +# profiler: capture a 5-step window starting at step 52, every rank. +# `trim_warmup = True` (default) post-filters the chrome trace so only the 5 +# `active` steps appear; the 1 WARMUP step still runs (lets CUPTI/ROCprof +# settle) but its events are dropped from the output file. +Profiler.trace_steps = [52] +Profiler.warmup = 1 +Profiler.active = 5 +Profiler.trace_dir = @run_results_dir() + +# logger variables +# TensorBoard event dir. DISABLED BY DEFAULT (empty path): the shared-NFS +# tfevents writer is the only metrics sink whose I/O error is uncaught, and it +# repeatedly crashed trainers on transient /apps `Errno 121` (Remote I/O) hiccups. +# Nothing we consume reads TensorBoard — eval-window AUCs come from the durable +# `.metrics.jsonl` sink (try/except-guarded) and the text run log. An empty +# path makes MetricsLogger install a no-op writer (see _NoOpSummaryWriter). +# Re-enable for a run by setting $TENSORBOARD_LOG_PATH to a non-empty dir. +MetricsLogger.tensorboard_log_path = @tbp/env_path() +tbp/env_path.key = "TENSORBOARD_LOG_PATH" +tbp/env_path.default = "" +MetricsLogger.world_size = 8 +# MLPerf convergence target: run stops when the selected eval AUC reaches it. +# Default 1.0 is unreachable, so the run trains through all windows (no early +# stop) out of the box; set $AUC_THRESHOLD=0.80275 for the MLPerf target. +MetricsLogger.auc_threshold = @at/env_float() +at/env_float.key = "AUC_THRESHOLD" +at/env_float.default = 1.0 +# EVAL_ACCURACY + early-stop are driven by per-window AUC (window_auc) vs the +# threshold above. +# Lifetime-AUC backend, selectable independently for the train cumulative AUC and +# the eval cumulative ("lifetime_*") AUC. Both default to "binned": +# "binned" = BinnedCumulativeAUC: exact-cumulative AUC from an O(bins) score +# histogram (additive all-reduce, memory independent of #samples/#windows). +# "capped" = LifetimeAUCMetricComputation: AUC over a trailing buffer of +# `lifetime_auc_window` samples/rank (legacy; per-rank buffer all-gathered). +# Override per-run via $TRAIN_LIFETIME_AUC_MODE / $EVAL_LIFETIME_AUC_MODE. +MetricsLogger.train_lifetime_auc_mode = @tlam/env_str() +tlam/env_str.key = "TRAIN_LIFETIME_AUC_MODE" +tlam/env_str.default = "binned" +MetricsLogger.eval_lifetime_auc_mode = @elam/env_str() +elam/env_str.key = "EVAL_LIFETIME_AUC_MODE" +elam/env_str.default = "binned" +# Score-histogram resolution for the "binned" backend. Higher = finer AUC +# resolution at O(bins) memory. Override via $CUMULATIVE_AUC_BINS. +MetricsLogger.cumulative_auc_bins = @cab/env_int() +cab/env_int.key = "CUMULATIVE_AUC_BINS" +cab/env_int.default = 100000 +# Trailing-buffer size (samples/rank) for the "capped" backend. Override via +# $LIFETIME_AUC_WINDOW. Ignored when the backend is "binned". +MetricsLogger.lifetime_auc_window = @law/env_int() +law/env_int.key = "LIFETIME_AUC_WINDOW" +law/env_int.default = 10000000 +# Checkpointing disabled by default — a full DMP checkpoint is ~100s of GB and +# the streaming loop always saves on the final window. save_dmp_checkpoint +# no-ops on the empty path. Set $CKPT_PATH to a directory to re-enable; the +# load path auto-resolves to the highest-numbered numeric subdir inside it. +save_dmp_checkpoint.path = @ckpt/env_path() +ckpt/env_path.key = "CKPT_PATH" +ckpt/env_path.default = "" +load_dmp_checkpoint.path = @ckpt/env_path() +# Retention: keep only the most-recent N numeric subdirs after each successful +# save (atomic rename + prune-older). Override via $KEEP_LAST_N. +save_dmp_checkpoint.keep_last_n = @kln/env_int() +kln/env_int.key = "KEEP_LAST_N" +kln/env_int.default = 1 +# In-window checkpoint cadence for streaming-train-eval. 0 (default) = end-of- +# window only; on crash mid-window the partial progress is lost. Set N>0 (e.g. +# via $IN_WINDOW_CKPT_FREQ) for batch-granularity exact-once-on-resume — paired +# with the auto-latest load above, the resumed run skips the K already-trained +# batches of the partial window and continues with bit-equal trajectory. +streaming_train_eval_loop.in_window_checkpoint_frequency = @iwcf/env_int() +iwcf/env_int.key = "IN_WINDOW_CKPT_FREQ" +iwcf/env_int.default = 0 +# Global-step checkpoint cadence: save whenever the monotonic train global_step +# crosses a multiple of N (a true "every 1000 steps" trigger that spans windows +# and survives resume). 0 (default) = off. Override via $CKPT_STEP_FREQ. +streaming_train_eval_loop.checkpoint_step_frequency = @csf/env_int() +csf/env_int.key = "CKPT_STEP_FREQ" +csf/env_int.default = 0 +# Wall-clock checkpoint cadence in seconds: save when >= this many seconds have +# elapsed since the last save. Rank 0 owns the clock and broadcasts the decision +# so all ranks save together. Default 3600 = hourly saves (matches the canonical +# fp16 run); set 0.0 to disable. Override via $CKPT_TIME_INTERVAL_S. +streaming_train_eval_loop.checkpoint_time_interval_s = @ctis/env_float() +ctis/env_float.key = "CKPT_TIME_INTERVAL_S" +ctis/env_float.default = 3600.0 +# Cap each train_ts window's batch count (mostly for the resume test driver). +# Unset / 0 = use the full window. +streaming_train_eval_loop.num_train_batches = @ntb/env_int() +ntb/env_int.key = "NUM_TRAIN_BATCHES" +ntb/env_int.default = 0 +# Cap each eval (full-holdout) window's batch count. Unset / <=0 = consume the +# full eval window (the genuine full-holdout NE/AUC; this is what the long run +# uses). Set >0 via $NUM_EVAL_BATCHES to subsample eval for fast validation. +streaming_train_eval_loop.num_eval_batches = @neb/env_int() +neb/env_int.key = "NUM_EVAL_BATCHES" +neb/env_int.default = 0 +# Test-only failure injection: when >=0 and metric_logger.global_step['train'] +# reaches this, the process exits with code 42 right after the in-window save +# fires. Used by the streaming resume test harness (train/tests/) to verify resume. +streaming_train_eval_loop.die_at_step = @das/env_int() +das/env_int.key = "DIE_AT_STEP" +das/env_int.default = -1 diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/train/mlperf_logging_utils.py b/recommendation_v4/generative_recommenders/dlrm_v3/train/mlperf_logging_utils.py new file mode 100644 index 000000000..8a716f87b --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/train/mlperf_logging_utils.py @@ -0,0 +1,413 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-unsafe +"""MLPerf Training compliance logging for the DLRMv3 streaming-train-eval path. + +Rank-0-gated wrapper around ``mlperf_logging.mllog`` so the streaming loop emits +the MLPerf event stream without every call site re-checking rank or the dep. +""" + +import logging +import os +from typing import Any, Dict, Optional + +import gin +import torch + +logger: logging.Logger = logging.getLogger(__name__) + +try: + from mlperf_logging import mllog + from mlperf_logging.mllog import constants as mllog_constants + + _MLLOG_AVAILABLE = True +except Exception as e: # pragma: no cover - import-time guard + mllog = None # type: ignore[assignment] + mllog_constants = None # type: ignore[assignment] + _MLLOG_AVAILABLE = False + logger.warning( + "mlperf_logging not importable (%s); MLPerf logging disabled. " + "Install via `pip install git+https://github.com/mlcommons/logging.git`.", + e, + ) + + +def _rank() -> int: + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return torch.distributed.get_rank() + return 0 + + +def _barrier() -> None: + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.barrier() + + +class MLPerfLogger: + """Rank-0-gated facade over ``mllog``. + + Event methods no-op on non-zero ranks and when mlperf_logging is absent. + ``sync=True`` barriers before emit so the timestamp reflects the slowest rank + (required for INIT_STOP/RUN_START/RUN_STOP). + """ + + def __init__( + self, + rank: Optional[int] = None, + log_path: Optional[str] = None, + default_stack_offset: int = 2, + benchmark_name: str = "hstu", + submitter_name: str = "AMD", + submission_platform: str = "MI355X", + fresh: bool = True, + ): + self.enabled: bool = _MLLOG_AVAILABLE + # Use the EXPLICIT caller rank: this is built before init_process_group, + # when dist.get_rank() would return 0 on every rank (all would log). + self.rank: int = rank if rank is not None else _rank() + self.benchmark_name: str = benchmark_name + self.submitter_name: str = submitter_name + self.submission_platform: str = submission_platform + self._logger = None + if not self.enabled: + return + # Only rank 0 emits, so only rank 0 needs the file handler. + if log_path and self.rank == 0: + log_dir = os.path.dirname(log_path) + if log_dir: # guard: os.makedirs("") raises for a bare filename + os.makedirs(log_dir, exist_ok=True) + # mllog's FileHandler APPENDS (mode "a"), which is what a resume needs + # so the single run's event stream accumulates across relaunches into + # one file. On a genuine cold start, truncate first so a re-used run + # dir / a previous crashed-cold-start's orphaned stream can't leave a + # second run_start in the file (the compliance checker requires + # EXACTLY_ONE). Resume (fresh=False) appends to continue the stream. + if fresh: + open(log_path, "w").close() + mllog.config(filename=log_path, default_stack_offset=default_stack_offset) + else: + mllog.config(default_stack_offset=default_stack_offset) + self._logger = mllog.get_mllogger() + + @property + def constants(self): # pyre-ignore[3] + return mllog_constants + + def event( + self, + key: str, + value: Any = None, + metadata: Optional[Dict[str, Any]] = None, + sync: bool = False, + ) -> None: + if sync: + _barrier() + if self.enabled and self.rank == 0: + self._logger.event(key=key, value=value, metadata=metadata or {}) + + def start( + self, + key: str, + value: Any = None, + metadata: Optional[Dict[str, Any]] = None, + sync: bool = False, + ) -> None: + if sync: + _barrier() + if self.enabled and self.rank == 0: + self._logger.start(key=key, value=value, metadata=metadata or {}) + + def end( + self, + key: str, + value: Any = None, + metadata: Optional[Dict[str, Any]] = None, + sync: bool = False, + ) -> None: + if sync: + _barrier() + if self.enabled and self.rank == 0: + self._logger.end(key=key, value=value, metadata=metadata or {}) + + def submission_info(self) -> None: + """Emit the five SUBMISSION_* events required for a valid submission.""" + if not (self.enabled and self.rank == 0): + return + c = mllog_constants + self.event(key=c.SUBMISSION_BENCHMARK, value=self.benchmark_name) + self.event(key=c.SUBMISSION_ORG, value=self.submitter_name) + self.event(key=c.SUBMISSION_DIVISION, value=c.CLOSED) + self.event(key=c.SUBMISSION_STATUS, value=c.ONPREM) + self.event(key=c.SUBMISSION_PLATFORM, value=self.submission_platform) + + def log_run_start( + self, + global_batch_size: int, + seed: int, + gradient_accumulation_steps: int = 1, + ) -> None: + """Emit submission info + core hyperparameters, then INIT_STOP + RUN_START. + + Optimizer names/LRs are read from gin (dense Adam + sparse RowWiseAdagrad), + resolving env-macro refs to concrete values. Call once on a genuine cold + start, after the model is built. INIT_STOP/RUN_START barrier so the + timestamp reflects the slowest rank, so ALL ranks must call this together + (non-rank-0 / disabled calls no-op the emit but still hit the barrier). + """ + c = self.constants + self.submission_info() + self.event(key=c.GLOBAL_BATCH_SIZE, value=int(global_batch_size)) + self.event( + key=c.GRADIENT_ACCUMULATION_STEPS, value=int(gradient_accumulation_steps) + ) + self.event(key=c.SEED, value=int(seed)) + self.event( + key=c.OPT_NAME, + value=_gin_param("dense_optimizer_factory_and_class.optimizer_name", "Adam"), + ) + self.event( + key=c.OPT_BASE_LR, + value=_gin_param("dense_optimizer_factory_and_class.learning_rate", None), + ) + self.event( + key="opt_sparse_name", + value=_gin_param( + "sparse_optimizer_factory_and_class.optimizer_name", "RowWiseAdagrad" + ), + ) + self.event( + key="opt_sparse_base_learning_rate", + value=_gin_param( + "sparse_optimizer_factory_and_class.learning_rate", None + ), + ) + self.end(key=c.INIT_STOP, sync=True) + self.start(key=c.RUN_START, sync=True) + + +def _gin_param(name: str, default: Any) -> Any: + """Read a gin-bound parameter, resolving env-macro refs to concrete values. + + Returns ``default`` if the parameter is unbound or a macro ref cannot be + resolved (so env-overridden LRs log as numbers, not unencodable objects). + """ + try: + value = gin.query_parameter(name) + except (ValueError, KeyError): + return default + if hasattr(value, "scoped_configurable_fn"): + try: + return value.scoped_configurable_fn() + except Exception: + return default + return value + + +class MLPerfRunTracker: + """Centralized MLPerf run-boundary state machine for the streaming loop. + + Owns the block/eval/run markers, the SAMPLES_COUNT/EPOCH_NUM progress + metadata, and the convergence decision (per-window AUC vs the configured + ``auc_threshold``). Every method no-ops when ``logger`` is None, so the + streaming loop can call them unconditionally. The convergence metric is + fixed to per-window AUC (higher-is-better). + """ + + # MetricsLogger.compute key short name for per-window AUC. + _EVAL_METRIC_SHORT = "window_auc" + + def __init__( + self, + logger: Optional[MLPerfLogger], + metric_logger: Any, + total_train_samples: int, + rank: int, + device: Any, + ): + self.logger = logger + self.metric_logger = metric_logger + self.total_train_samples = int(total_train_samples) + self.rank = int(rank) + self.device = device + self.run_stopped: bool = False + # Idempotency flag so the boundary helpers and the outer loop can both + # call start/stop without risking a double BLOCK_START/STOP. + self._block_open: bool = False + + @property + def enabled(self) -> bool: + return self.logger is not None + + def _progress(self) -> Dict[str, Any]: + c = self.logger.constants + samples = self.metric_logger.cumulative_train_samples + epoch = ( + samples / self.total_train_samples if self.total_train_samples > 0 else 0.0 + ) + return {c.SAMPLES_COUNT: samples, c.EPOCH_NUM: epoch} + + def log_dataset_sizes(self, eval_samples: Optional[int] = None) -> None: + if not self.enabled: + return + c = self.logger.constants + self.logger.event(key=c.TRAIN_SAMPLES, value=self.total_train_samples) + if eval_samples is not None: + self.logger.event(key=c.EVAL_SAMPLES, value=int(eval_samples)) + + def block_start(self) -> None: + if self.enabled and not self._block_open: + self.logger.start( + key=self.logger.constants.BLOCK_START, metadata=self._progress() + ) + self._block_open = True + + def block_stop(self) -> None: + if self.enabled and self._block_open: + self.logger.end( + key=self.logger.constants.BLOCK_STOP, metadata=self._progress() + ) + self._block_open = False + + def eval_start(self) -> None: + if self.enabled: + self.logger.start( + key=self.logger.constants.EVAL_START, metadata=self._progress() + ) + + def _target_metric(self, metrics: Dict[str, float]) -> Optional[float]: + # Key format `metric/{prefix}_{name}/{task}` (see MetricsLogger.compute); + # match the per-window AUC short name. + for key, val in metrics.items(): + short = key.split("/")[-2] if "/" in key else key + if short == self._EVAL_METRIC_SHORT: + return float(val) + return None + + def _meets_target(self, value: Optional[float]) -> bool: + thr = self.metric_logger.auc_threshold + if value is None or thr is None: + return False + return value >= thr + + def run_stop(self, status: object) -> None: + # Emit RUN_STOP exactly once, with an all-rank barrier so the timestamp + # reflects the slowest rank (MLPerf requirement). + if not self.enabled or self.run_stopped: + return + c = self.logger.constants + self.logger.end( + key=c.RUN_STOP, + metadata={c.STATUS: status, **self._progress()}, + sync=True, + ) + self.run_stopped = True + + def eval_stop(self, eval_metrics: Dict[str, float]) -> bool: + # Emit EVAL_ACCURACY + EVAL_STOP, early SUCCESS RUN_STOP on target. + # Rank 0 decides + broadcasts the stop bool so all ranks break in lockstep + # (a per-rank test could diverge and hang the next all-to-all). + if not self.enabled: + return False + c = self.logger.constants + eval_value = self._target_metric(eval_metrics) + if eval_value is not None: + self.logger.event( + key=c.EVAL_ACCURACY, value=eval_value, metadata=self._progress() + ) + self.logger.end(key=c.EVAL_STOP, metadata=self._progress()) + decision = torch.zeros(1, device=self.device) + if self.rank == 0 and not self.run_stopped and self._meets_target(eval_value): + decision[0] = 1.0 + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.broadcast(decision, src=0) + should_stop = bool(decision.item() > 0.5) + if should_stop: + # All ranks agree -> all reach the RUN_STOP barrier together. + self.run_stop(c.SUCCESS) + return should_stop + + def finalize(self, final_metrics: Dict[str, float]) -> None: + # End-of-run RUN_STOP when the target was never crossed: SUCCESS iff the + # final eval metric meets the target, else ABORTED. + if not self.enabled or self.run_stopped: + return + c = self.logger.constants + success = self._meets_target(self._target_metric(final_metrics)) + self.run_stop(c.SUCCESS if success else c.ABORTED) + + +def mlperf_checkpoint_present(ckpt_path: str) -> bool: + """True iff ``ckpt_path`` resolves to an existing checkpoint (i.e. a resume). + + A dependency-light mirror of ``checkpoint._resolve_latest_subdir`` so + ``train_ranker`` can decide cold-start vs resume BEFORE the heavy checkpoint + import + ``setup()``. This gates the one-time INIT_START/RUN_START markers: + emit them on a genuine cold start only, and never re-emit on a resume + relaunch (the MLPerf run spans the resume). Matches the loader's resolution: + empty path or a base dir with no numeric subdirs => cold start. + """ + if not ckpt_path: + return False + base = ckpt_path.rstrip("/") + # A leaf save (numeric basename) is a resume iff that dir actually exists. + if os.path.basename(base).isdigit(): + return os.path.isdir(base) + if not os.path.isdir(base): + return False + for name in os.listdir(base): + if name.isdigit() and os.path.isdir(os.path.join(base, name)): + return True + return False + + +@gin.configurable +def get_mlperf_logger( + rank: int = 0, + log_path: str = "", + benchmark_name: str = "hstu", + submitter_name: str = "AMD", + submission_platform: str = "MI355X", + fresh: bool = True, +) -> Optional[MLPerfLogger]: + """Build a configured :class:`MLPerfLogger`, or ``None`` if unavailable. + + Path defaults to ``$MLPERF_LOG_PATH``. Returns ``None`` (not a disabled + logger) so callers' ``is not None`` guards cleanly skip logging. + + Disable knob: set ``$MLPERF_LOGGING=0`` (or false/no/off) to turn the whole + MLPerf event stream off — returns ``None`` on EVERY rank, so the train loop's + ``is not None`` guards skip emission AND the cross-rank train-loss all-reduce + in lockstep. Default (unset / "1") = enabled, preserving prior behavior. + """ + if not _MLLOG_AVAILABLE: + return None + if os.environ.get("MLPERF_LOGGING", "1").strip().lower() in ( + "0", "false", "no", "off", + ): + logger.info("MLPerf logging disabled via $MLPERF_LOGGING=0") + return None + resolved_path = os.environ.get("MLPERF_LOG_PATH", log_path) + # SUBMISSION_PLATFORM defaults to "MI355X"; override per-submitter via env. + resolved_platform = os.environ.get( + "MLPERF_SUBMISSION_PLATFORM", submission_platform + ) + return MLPerfLogger( + rank=rank, + log_path=resolved_path, + benchmark_name=benchmark_name, + submitter_name=submitter_name, + submission_platform=resolved_platform, + fresh=fresh, + ) diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/train/tests/checkpoint_cadence_test.py b/recommendation_v4/generative_recommenders/dlrm_v3/train/tests/checkpoint_cadence_test.py new file mode 100644 index 000000000..4a483c262 --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/train/tests/checkpoint_cadence_test.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict +"""Unit tests for `select_in_window_checkpoint_reason` — the pure decision that +drives the streaming loop's three fine-grained checkpoint cadences: + + * `in_window_checkpoint_frequency` — per-window-local batch count + * `checkpoint_step_frequency` — monotonic global step ("every 1000 steps") + * `checkpoint_time_interval_s` — wall-clock ("hourly") + +These run without a GPU / distributed init: the loop broadcasts a single +`elapsed_since_last_save` from rank 0 and then calls this pure function, so +exercising the function directly fully covers the trigger semantics. +""" +import unittest + +from generative_recommenders.dlrm_v3.train.utils import ( + select_in_window_checkpoint_reason, +) + + +def _reason( + *, + batch: int = 1, + step: int = 1, + elapsed: float = 0.0, + in_window: int = 0, + step_freq: int = 0, + time_s: float = 0.0, +) -> str | None: + return select_in_window_checkpoint_reason( + train_batch_idx=batch, + global_step=step, + elapsed_since_last_save=elapsed, + in_window_checkpoint_frequency=in_window, + checkpoint_step_frequency=step_freq, + checkpoint_time_interval_s=time_s, + ) + + +class CheckpointCadenceTest(unittest.TestCase): + def test_all_disabled_never_fires(self) -> None: + for batch in (1, 100, 1000): + for step in (1, 1000, 5000): + self.assertIsNone(_reason(batch=batch, step=step, elapsed=1e9)) + + def test_step_based_every_1000(self) -> None: + # Fires exactly on multiples of the step frequency. + self.assertEqual(_reason(step=1000, step_freq=1000), "global_step") + self.assertEqual(_reason(step=2000, step_freq=1000), "global_step") + # Does not fire just off a boundary. + self.assertIsNone(_reason(step=999, step_freq=1000)) + self.assertIsNone(_reason(step=1001, step_freq=1000)) + + def test_step_zero_does_not_trigger(self) -> None: + # global_step==0 must not trivially satisfy `0 % N == 0`. + self.assertIsNone(_reason(step=0, step_freq=1000)) + + def test_time_based_interval(self) -> None: + # At/over the interval -> fires; under -> no save. + self.assertEqual( + _reason(step=3, elapsed=3600.0, time_s=3600.0), "time_interval" + ) + self.assertEqual( + _reason(step=3, elapsed=4000.0, time_s=3600.0), "time_interval" + ) + self.assertIsNone(_reason(step=3, elapsed=3599.9, time_s=3600.0)) + + def test_in_window_batch_cadence(self) -> None: + self.assertEqual(_reason(batch=5, in_window=5), "in_window_batch") + self.assertEqual(_reason(batch=10, in_window=5), "in_window_batch") + self.assertIsNone(_reason(batch=4, in_window=5)) + + def test_precedence_in_window_over_step_over_time(self) -> None: + # All three would fire this batch; precedence picks in_window first. + self.assertEqual( + _reason( + batch=5, + step=1000, + elapsed=9999.0, + in_window=5, + step_freq=1000, + time_s=3600.0, + ), + "in_window_batch", + ) + # in_window not due this batch -> step wins over time. + self.assertEqual( + _reason( + batch=4, + step=1000, + elapsed=9999.0, + in_window=5, + step_freq=1000, + time_s=3600.0, + ), + "global_step", + ) + # Neither batch nor step due -> time wins. + self.assertEqual( + _reason( + batch=4, + step=999, + elapsed=9999.0, + in_window=5, + step_freq=1000, + time_s=3600.0, + ), + "time_interval", + ) + + def test_step_and_time_combined_independent(self) -> None: + # Step frequency enabled, time disabled: only step boundaries fire. + self.assertEqual(_reason(step=1000, step_freq=1000, time_s=0.0), "global_step") + self.assertIsNone(_reason(step=1000, elapsed=1e9, step_freq=0, time_s=0.0)) + + +if __name__ == "__main__": + unittest.main() diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/train/tests/streaming_resume_test.py b/recommendation_v4/generative_recommenders/dlrm_v3/train/tests/streaming_resume_test.py new file mode 100644 index 000000000..e3f78886c --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/train/tests/streaming_resume_test.py @@ -0,0 +1,287 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end failure-injection test for streaming resume. + +Two scenarios, driven by the sibling `streaming_resume_test.sh` (see its header +for the full, platform-general launch wiring — NVIDIA B200 and AMD MI350/MI355). +This module is the shared log parser + a CLI the driver shells out to. + +SCENARIO `midwindow` — exact-once mid-window resume. Validates the four +single-window resume features end-to-end on the yambda-5b stack: + 1. Mid-window save (in_window_checkpoint_frequency) + 2. Within-window exact-once skip (StreamingWindowSampler.set_window skip) + 3. Auto-detect-latest checkpoint subdir + 4. keep_last_n retention (default 1) + +SCENARIO `multiwindow` — distributed-sync regression guard for the two fixes the +mid-window test cannot reach (it runs ONE window with per-window eval off): + A. total_train_anchors() computed once on rank 0 + broadcast (not world_size×). + B. window-boundary dist.barrier() before the first forward of each window. +Both only matter across >=2 windows with the data-fraction eval cadence +(EVAL_EVERY_DATA_PCT>0) active, and the deadlock they fix originally struck at a +window boundary mid-run — so the scenario trains multiple windows AND resumes +across a completed-window boundary. The signals are extracted by `summarize` +(see `summarize_run`) and asserted in the shell driver. + +Test flow (driven by the sibling `streaming_resume_test.sh`): + Phase 1 (baseline): Run streaming-train-eval for N=2 train_ts × K batches/window + with die_at_step=-1. Capture per-batch window_ne / window_auc into traj_baseline.json. + Phase 2 (interrupt): Same config but die_at_step=M (M mid-window-2). Expect + process to exit(42) after the in-window checkpoint at step M lands. + Phase 3 (resume): Re-launch with same CKPT_PATH (auto-latest picks the + in-window save). Continue to the same total step count. Capture + traj_resumed.json (which only contains the post-resume steps). + + Correctness is proven by the FUNCTIONAL INVARIANTS checked in the shell + driver (resumed at exactly batch_idx_in_window, per-rank RNG restored, atomic + save + keep_last_n), NOT by bit-equal trajectory matching. The training stack + is nondeterministic across runs (non-deterministic atomic scatter-add in the + embedding/attention backward on ROCm): two independent *cold* runs already + drift ~7e-4 in window_ne over 20 steps, and early-training chaos amplifies + it, so resume-vs-baseline can differ by a few percent even when resume is + perfect. The trajectory comparison here is therefore a LOOSE closeness bound + (default atol below) that only flags gross divergence — wrong data slice or + unrestored model state — while tolerating nondeterministic drift. + +This module also provides a CLI entry point used by the shell driver to (a) +parse a train.log into a step-keyed dict of metrics, and (b) compare two such +dicts and fail loudly on mismatch. +""" + +import argparse +import json +import re +import sys +from typing import Dict, List, Optional, Tuple + +# Per-step metrics from MetricsLogger.compute_and_log are emitted like: +# "train - Step 51 metrics: {'metric/lifetime_ne/listen_plus': tensor(1.0954, ...) +# 'metric/window_ne/listen_plus': tensor(0.9940, ...), +# 'metric/window_accuracy/listen_plus': tensor(0.6231, ...) ..." +_STEP_RE = re.compile(r"train - Step (\d+) metrics:") +_WNE_RE = re.compile(r"window_ne/listen_plus.*?tensor\(([0-9.]+)") +_WAUC_RE = re.compile(r"window_auc/listen_plus.*?tensor\(([0-9.]+)") +_WACC_RE = re.compile(r"window_accuracy/listen_plus.*?tensor\(([0-9.]+)") + + +# --- multi-window / data-pct-eval regression signals ------------------------- +# These cover the two distributed-sync fixes that the single-window mid-window +# test above does NOT exercise (it runs one window with per-window eval off): +# +# (A) total_train_anchors() rank-0 broadcast. The data-fraction eval cadence +# needs total_train_anchors — a multi-minute, single-threaded O(N) gather +# + uid-hash over the mmap'd anchor array. Run on EVERY rank it both wastes +# 8x CPU and desyncs the NCCL stream (a fast rank races into the first +# embedding all-to-all while slow ranks still hash) → deadlock. The fix +# computes it ONCE on rank 0 and broadcasts the scalar. yambda logs exactly +# one `total_train_anchors(start_ts=…)` line per call, so the regression +# guard is: that line appears EXACTLY ONCE per launch (was world_size×). +# +# (B) window-boundary barrier. Per-window data prep (`window_indices`, an O(N) +# mask over the ~18GB mmap) finishes at very different times across ranks; +# without a sync before the first forward the collective stream desyncs and +# the job hangs at the boundary. The fix adds a dist.barrier() at each +# window boundary. It is silent on the healthy path, so the trainer emits a +# `[window-barrier] … rendezvous complete` line (rank 0) per crossed window +# ONLY under WINDOW_BARRIER_DEBUG=1 — the guard counts those == #windows. +_TTA_RE = re.compile(r"total_train_anchors\(start_ts=(\d+),\s*num_ts=(\d+)\):") +_BARRIER_RE = re.compile(r"\[window-barrier\] train_ts=(\d+) rendezvous complete") +_DATA_PCT_SETUP_RE = re.compile( + r"\[data-pct-eval\] eval_every_data_pct=.*?eval_interval_steps=(\d+)" +) +_DATA_PCT_TRIGGER_RE = re.compile(r"\[data-pct-eval\] trigger eval train_ts=(\d+)") +_RESUME_COMPLETED_RE = re.compile(r"Resuming from completed train_ts=(\d+)") +_RESUME_MIDWINDOW_RE = re.compile( + r"Resuming mid-window at train_ts=(\d+) batch_idx_in_window=(\d+)" +) +# Test driver appends this sentinel after the trainer returns (clean OR crash); +# code 0 == the run finished all requested windows + final eval without hanging. +_PHASE_EXIT_RE = re.compile(r"PHASE_EXIT=(-?\d+)") + + +def summarize_run(log_path: str) -> Dict[str, object]: + """Extract the multi-window / data-pct-eval regression signals from a run log. + + Returns a JSON-able dict the shell driver asserts on. All counts are over the + WHOLE log (one launch's worth — the driver uses a fresh per-phase log).""" + tta_calls: List[Tuple[int, int]] = [] + barrier_windows: List[int] = [] + data_pct_eval_setup: bool = False + data_pct_eval_interval: Optional[int] = None + data_pct_eval_triggers: List[int] = [] + resume_completed_ts: Optional[int] = None + resume_midwindow: Optional[Tuple[int, int]] = None + phase_exit: Optional[int] = None + with open(log_path, "r", errors="replace") as f: + for line in f: + m = _TTA_RE.search(line) + if m: + tta_calls.append((int(m.group(1)), int(m.group(2)))) + m = _BARRIER_RE.search(line) + if m: + barrier_windows.append(int(m.group(1))) + m = _DATA_PCT_SETUP_RE.search(line) + if m: + data_pct_eval_setup = True + data_pct_eval_interval = int(m.group(1)) + m = _DATA_PCT_TRIGGER_RE.search(line) + if m: + data_pct_eval_triggers.append(int(m.group(1))) + m = _RESUME_COMPLETED_RE.search(line) + if m: + resume_completed_ts = int(m.group(1)) + m = _RESUME_MIDWINDOW_RE.search(line) + if m: + resume_midwindow = (int(m.group(1)), int(m.group(2))) + m = _PHASE_EXIT_RE.search(line) + if m: + phase_exit = int(m.group(1)) + return { + # (A) rank-0 broadcast: must be exactly 1 (was world_size× before the fix) + "total_train_anchors_calls": len(tta_calls), + "total_train_anchors_args": tta_calls, + # (B) barrier executed once per crossed window (rank 0, debug-gated) + "window_barrier_count": len(barrier_windows), + "windows_trained": sorted(set(barrier_windows)), + # data-fraction eval cadence active + actually fired + "data_pct_eval_setup": data_pct_eval_setup, + "data_pct_eval_interval_steps": data_pct_eval_interval, + "data_pct_eval_trigger_count": len(data_pct_eval_triggers), + # resume classification + "resume_completed_ts": resume_completed_ts, + "resume_midwindow": resume_midwindow, + # terminal status (None => still running / killed without sentinel) + "phase_exit": phase_exit, + } + + +def parse_trajectory(log_path: str) -> Dict[int, Dict[str, float]]: + """Extract a {step: {window_ne, window_auc, window_accuracy}} dict from a + train.log. The grep is loose on the metric line itself — we accept the + very long truncated form MetricsLogger prints.""" + out: Dict[int, Dict[str, float]] = {} + with open(log_path, "r", errors="replace") as f: + for line in f: + m = _STEP_RE.search(line) + if not m: + continue + step = int(m.group(1)) + wne = _WNE_RE.search(line) + wauc = _WAUC_RE.search(line) + wacc = _WACC_RE.search(line) + if not (wne and wauc and wacc): + continue + # Only keep ONE entry per step — log can have duplicate per-rank + # prints; first one wins (they're identical). + if step in out: + continue + out[step] = { + "window_ne": float(wne.group(1)), + "window_auc": float(wauc.group(1)), + "window_accuracy": float(wacc.group(1)), + } + return out + + +def compare_trajectories( + baseline: Dict[int, Dict[str, float]], + resumed: Dict[int, Dict[str, float]], + min_resume_step: int, + atol: float = 0.15, +) -> Tuple[bool, str]: + """Compare baseline vs resumed trajectories for steps >= min_resume_step. + + This is a LOOSE closeness bound, not a bit-equality check — see the module + docstring. `atol` defaults to a value that tolerates the nondeterministic + cross-run drift of this stack while still catching gross resume bugs. + Returns (ok, message). `ok=False` on any divergence outside `atol`.""" + steps = sorted(s for s in resumed if s >= min_resume_step) + if not steps: + return False, f"No resumed steps >= {min_resume_step}" + mismatches = [] + for s in steps: + if s not in baseline: + mismatches.append(f"step {s}: missing from baseline") + continue + b = baseline[s] + r = resumed[s] + for key in ("window_ne", "window_auc", "window_accuracy"): + if abs(b[key] - r[key]) > atol: + mismatches.append( + f"step {s} {key}: baseline={b[key]:.6f} " + f"resumed={r[key]:.6f} diff={b[key]-r[key]:+.6f}" + ) + if mismatches: + return False, ( + f"{len(mismatches)} mismatches across {len(steps)} resumed steps " + f"(atol={atol}):\n " + "\n ".join(mismatches[:10]) + ) + return True, ( + f"{len(steps)} resumed steps match baseline within atol={atol} " + f"(range: step {steps[0]}..{steps[-1]})" + ) + + +def main() -> int: + ap = argparse.ArgumentParser() + sub = ap.add_subparsers(dest="cmd", required=True) + + p_parse = sub.add_parser("parse", help="Parse a train.log → traj JSON") + p_parse.add_argument("log") + p_parse.add_argument("out") + + p_cmp = sub.add_parser("compare", help="Compare baseline vs resumed traj JSONs") + p_cmp.add_argument("baseline") + p_cmp.add_argument("resumed") + p_cmp.add_argument("--min-resume-step", type=int, required=True) + p_cmp.add_argument("--atol", type=float, default=0.15) + + p_sum = sub.add_parser( + "summarize", + help="Emit multi-window / data-pct-eval regression signals from a run log", + ) + p_sum.add_argument("log") + p_sum.add_argument("out", nargs="?", help="optional JSON output path") + + args = ap.parse_args() + if args.cmd == "parse": + traj = parse_trajectory(args.log) + with open(args.out, "w") as f: + json.dump(traj, f, indent=2) + print(f"Wrote {len(traj)} step entries to {args.out}", file=sys.stderr) + return 0 + if args.cmd == "compare": + with open(args.baseline) as f: + baseline = {int(k): v for k, v in json.load(f).items()} + with open(args.resumed) as f: + resumed = {int(k): v for k, v in json.load(f).items()} + ok, msg = compare_trajectories( + baseline, resumed, args.min_resume_step, atol=args.atol + ) + print(msg) + return 0 if ok else 1 + if args.cmd == "summarize": + summary = summarize_run(args.log) + out = json.dumps(summary, indent=2) + if args.out: + with open(args.out, "w") as f: + f.write(out) + print(out) + return 0 + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/train/tests/streaming_resume_test.sh b/recommendation_v4/generative_recommenders/dlrm_v3/train/tests/streaming_resume_test.sh new file mode 100755 index 000000000..e093fb31a --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/train/tests/streaming_resume_test.sh @@ -0,0 +1,566 @@ +#!/bin/bash +# End-to-end failure-injection + resume test for streaming-train-eval. +# +# ============================================================================ +# WHY TWO TESTS (the intuition, in plain language) +# ============================================================================ +# Training runs over consecutive time WINDOWS (window 0, then 1, then 2, ...). +# All N GPUs must march from one window to the next IN LOCKSTEP: they constantly +# do "everybody-talk-at-once" group ops (NCCL collectives — sharing embeddings +# across GPUs), and every GPU must enter each group-op at the same time. If one +# GPU is late, the rest wait for it forever and the whole job FREEZES (deadlock). +# +# The two scenarios check DIFFERENT kinds of failure — not bigger vs smaller: +# +# midwindow = CORRECTNESS. "When I crash and resume, do I land on the exact +# right batch with the right RNG state and produce the same +# numbers?" (stays inside ONE window; never crosses a seam.) +# +# multiwindow = LIVENESS / NO-DEADLOCK. "Can all N GPUs hand off across a +# window SEAM together without one falling out of step and +# hanging the job?" (needs >=2 windows so a seam actually exists.) +# +# The dangerous spot is the SEAM between two windows: there, each GPU does solo +# prep work (load next window's data; count anchors for the eval cadence) and +# they DON'T all finish at the same speed. Two bugs lived exactly there, and BOTH +# are invisible to a single-window test: +# (A) every GPU separately ran a slow O(N) "count all the data" pass -> they +# finished at different times -> fast GPU barged into the next group-op +# while others were still counting -> freeze. +# FIX: only rank 0 counts, then broadcasts the number to everyone else. +# (B) no rendezvous at the seam -> uneven data-prep -> same desync -> freeze. +# FIX: a dist.barrier() at every window boundary (all GPUs wait, then cross +# together). WINDOW_BARRIER_DEBUG=1 makes rank 0 log one line per seam. +# +# TIMELINE — without the fixes (each GPU on its own clock at the seam): +# win0 train | solo prep (varies) | next group-op +# GPU0 ########|=====| >> waiting.......... +# GPU1 ########|========| >> waiting....... +# GPU2 ########|===========| >> waiting.... +# GPU3 ########|==============| >> never lines up -> HANG +# +# TIMELINE — with the fixes (rank 0 counts + a barrier gate at the seam): +# win0 train | [== BARRIER: all wait ==] | win1 train +# GPU0 ########| count | wait # |######## +# GPU1 ########| | wait # |######## +# GPU2 ########| | wait # |######## +# GPU3 ########| | wait # |######## +# rank0 shares count ^ all cross together ^ -> OK +# +# Why midwindow can NOT catch (A)/(B): it runs a SINGLE window with per-window +# eval off (NUM_TRAIN_TS=1, EVAL_EVERY_N_WINDOWS=0, split=1.0), so it never +# reaches a seam and never turns on the data-fraction-eval/anchor-count path. +# A broken barrier or broken broadcast passes midwindow silently. +# +# Why the multiwindow RESUME phase (P3 below) is the meanest case: restarting +# from a checkpoint loads the saved window and then IMMEDIATELY steps across a +# seam into the next window — landing right on the spot that used to freeze, AND +# re-running all that slow setup on the resume path. If (A)/(B) regressed, P3 +# hangs and the test fails by timing out. +# +# | midwindow | multiwindow +# --------------+----------------------+----------------------------- +# proves | resume to RIGHT spot | cross seam WITHOUT freezing +# windows | 1 (no seam) | >=2 (crosses >=1 seam) +# data-pct eval | off | on (exercises the anchor count) +# catches | wrong batch/RNG/ckpt | missing barrier/broadcast -> HANG +# failure mode | wrong NUMBERS | job FREEZES forever +# They are complementary: you need BOTH. +# ============================================================================ +# +# PLATFORM-GENERAL: runs on both NVIDIA B200 and AMD MI350/MI355 (ROCm/meta64). +# The only hardware-specific bits are picked by --platform (auto-detected from the +# running container if omitted): the container name, the dataset path, and the +# checkpoint root. Everything else — the worker entrypoint (scripts/launch_slurm.sh, +# which is the shared launcher both clusters' supervisors use), the env-driven gin +# knobs, and all assertions — is identical across platforms. +# +# Two scenarios (select with --scenario; default runs both): +# +# midwindow — exact-once MID-WINDOW resume (single window). +# P1 baseline: uninterrupted 1 train_ts × K batches. +# P2 interrupted: same + die_at_step=M → exits AFTER the in-window ckpt at M. +# P3 resume: relaunch w/ same CKPT_PATH → auto-latest picks the in-window +# save, skips the M already-trained batches, finishes. +# Gates: re-entered at batch_idx_in_window=M, per-rank RNG restored, first +# resumed step == M+1, atomic save + keep_last_n, trajectory within --atol. +# +# multiwindow — distributed-sync REGRESSION guard for the two fixes the +# mid-window test cannot reach (it runs ONE window with per-window eval off): +# (A) total_train_anchors() computed ONCE on rank 0 + broadcast (the +# data-fraction eval cadence needs it; running the multi-minute O(N) +# mmap gather + uid-hash on every rank desynced NCCL → boundary hang). +# (B) a dist.barrier() at every window boundary before the first forward +# (per-rank data-prep skew otherwise desyncs the collective stream). +# Both only bite across >=2 windows with EVAL_EVERY_DATA_PCT>0, and the +# deadlock struck at a boundary mid-run, so: +# P1 mw_baseline: cold run over MW_TS windows w/ data-pct eval. Asserts +# total_train_anchors logged EXACTLY ONCE (computed at setup + broadcast +# from rank 0), the barrier fired on EVERY window, the data-pct cadence +# was set up, and the run COMPLETED (no boundary hang). +# P2 mw_seed: 1 window → clean end-of-window (WINDOW_COMPLETE) ckpt. +# P3 mw_resume: relaunch over MW_TS windows w/ same CKPT_PATH → resumes +# past the completed window and CROSSES the boundary into the next +# windows. Asserts "Resuming from completed", barrier fired on each +# remaining window, anchors broadcast once, and the run COMPLETED — +# i.e. the exact boundary-crossing-on-resume case that used to hang. +# +# Driven entirely via env-driven gin knobs (yambda_5b.gin) through the SAME worker +# entrypoint both platforms' production supervisors use: `bash scripts/launch_slurm.sh` +# (worker phase, auto-detected inside the container). WINDOW_BARRIER_DEBUG=1 makes +# the otherwise-silent barrier emit one rank-0 line per crossed window. +# +# CHECKPOINT/DATASET PLACEMENT (the one real platform difference): +# * B200: virtiofs/NFS WEDGES under the trainer's concurrent mmap LOAD, so the +# checkpoint root AND the mmap'd dataset cache MUST be node-local (defaults +# /tmp/...). The dataset must already be staged node-local at --data-path +# (the e2e supervisor's stage_data_in does this); the test fails fast if not. +# * MI350/MI355 (meta64): NFS mmap is fine, so the checkpoint root + dataset +# read directly from shared NFS (defaults /apps/chcai/...), as the original +# test did. No staging needed. +# Logs always use read()/write() only, so they live on shared /apps/chcai and +# are grep-able from the head node on both platforms. +# +# Usage: +# bash generative_recommenders/dlrm_v3/train/tests/streaming_resume_test.sh \ +# --jobid [--platform b200|mi350] [--scenario all] +# [--container ] [--data-path ] [--ckpt-root ] [--start-ts 150] +# [--num-train-batches 200] [--die-at-step 100] # midwindow knobs +# [--mw-num-train-ts 3] [--mw-num-train-batches 20] # multiwindow knobs +# [--mw-eval-pct 0.34] [--phase-timeout S] [--mw-run-timeout S] [--keep] +# --platform is auto-detected from the running container when omitted. Any of +# --container/--data-path/--ckpt-root override the platform default. +# Per-phase wait budgets default per-platform (B200 node-local NVMe: 1800/3600s; +# MI350/MI355 shared-NFS full-model ckpts ~9 min each: 5400/5400s) and can be +# overridden with --phase-timeout (midwindow) / --mw-run-timeout (multiwindow). + +set -uo pipefail + +JOBID="" +# Repo root is derived from THIS script's location +# (/generative_recommenders/dlrm_v3/train/tests/streaming_resume_test.sh — +# four levels up), so the test is not pinned to any one user's home. Override with +# --repo if the repo is mounted at a different path inside the container. +_SELF_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +REPO=$(cd "$_SELF_DIR/../../../.." && pwd) +DATASET_SUBDIR=processed_5b/hstu_cache_L4086 +SCENARIO=all # midwindow | multiwindow | all +START_TS=150 +KEEP=0 +LOG_DIR=/apps/chcai/streaming_resume_test # shared NFS (read()/write() only) +# Platform + the three platform-specific paths. Empty sentinels here; filled by +# apply_platform_defaults() AFTER platform detection unless the user overrode +# them on the command line. (DATA_PATH uses a distinct sentinel because an +# explicit empty value is meaningful: "do not inject DLRM_DATA_PATH; let the gin +# default apply".) +PLATFORM="" # b200 | mi350 | mi355 ; auto if empty +CONTAINER="" # default: per-platform +DATA_PATH="__AUTO__" # default: per-platform +CKPT_ROOT="" # default: per-platform (node-local on B200) + +# --- midwindow knobs --- +NUM_TRAIN_BATCHES=200 +NUM_EVAL_BATCHES=5 # cap the per-phase FINAL eval (0 = full holdout, very slow) +DIE_AT_STEP=100 +IN_WINDOW_FREQ=50 +ATOL=0.15 # trajectory closeness bound (NOT bit-equality; see py module) +# Per-phase wait budget. Left empty here and filled per-platform below (a B200 +# ckpt save/load hits node-local NVMe and is fast; on meta64 each full-model DCP +# save/load lands on shared NFS and takes ~9 min, and the resume phase does a +# LOAD + several in-window saves + an end-of-window save, so it needs far longer). +# Override explicitly with --phase-timeout. +MW_TIMEOUT="" + +# --- multiwindow knobs --- +MW_TS=3 # windows to train (>=2 to cross a boundary) +MW_BATCHES=20 # train batches per window (small = fast) +MW_EVAL_BATCHES=5 # holdout eval batches per fired eval +MW_EVAL_PCT=0.34 # data-fraction eval cadence (>0 enables the anchors path) +MW_SPLIT=0.90 # train split (<1 => holdout exists => uid-hash anchor path) +MW_HOLDOUT_TS=200 # PINNED holdout window (must match across seed→resume) +# generous: init + planner + anchors gather can take min; on NFS add ckpt save/load. +# Empty => filled per-platform below. Override with --mw-run-timeout. +MW_RUN_TIMEOUT="" + +while [[ $# -gt 0 ]]; do + case $1 in + --jobid) JOBID="$2"; shift 2;; + --platform) PLATFORM="$2"; shift 2;; + --container) CONTAINER="$2"; shift 2;; + --repo) REPO="$2"; shift 2;; + --data-path) DATA_PATH="$2"; shift 2;; + --dataset-subdir) DATASET_SUBDIR="$2"; shift 2;; + --scenario) SCENARIO="$2"; shift 2;; + --start-ts) START_TS="$2"; shift 2;; + --ckpt-root) CKPT_ROOT="$2"; shift 2;; + --log-dir) LOG_DIR="$2"; shift 2;; + --num-train-batches) NUM_TRAIN_BATCHES="$2"; shift 2;; + --num-eval-batches) NUM_EVAL_BATCHES="$2"; shift 2;; + --die-at-step) DIE_AT_STEP="$2"; shift 2;; + --in-window-freq) IN_WINDOW_FREQ="$2"; shift 2;; + --atol) ATOL="$2"; shift 2;; + --phase-timeout) MW_TIMEOUT="$2"; shift 2;; + --mw-run-timeout) MW_RUN_TIMEOUT="$2"; shift 2;; + --mw-num-train-ts) MW_TS="$2"; shift 2;; + --mw-num-train-batches) MW_BATCHES="$2"; shift 2;; + --mw-num-eval-batches) MW_EVAL_BATCHES="$2"; shift 2;; + --mw-eval-pct) MW_EVAL_PCT="$2"; shift 2;; + --mw-split) MW_SPLIT="$2"; shift 2;; + --mw-holdout-ts) MW_HOLDOUT_TS="$2"; shift 2;; + --keep) KEEP=1; shift;; + *) echo "Unknown arg: $1"; exit 1;; + esac +done +[[ -z "$JOBID" ]] && { echo "Error: --jobid required"; exit 1; } +case "$SCENARIO" in midwindow|multiwindow|all) ;; *) echo "Error: --scenario must be midwindow|multiwindow|all"; exit 1;; esac +(( MW_TS < 2 )) && { echo "Error: --mw-num-train-ts must be >=2 to cross a boundary"; exit 1; } +[[ -n "$PLATFORM" ]] && case "$PLATFORM" in b200|mi350|mi355) ;; *) echo "Error: --platform must be b200|mi350|mi355"; exit 1;; esac + +# --- resolve platform + its three hardware-specific paths -------------------- +# Precedence: explicit --platform > inferred from explicit --container > probe +# the allocation's docker for a known training container > default b200. +if [[ -z "$PLATFORM" ]]; then + if [[ "$CONTAINER" == "yambda_b200" ]]; then PLATFORM=b200 + elif [[ "$CONTAINER" == "yambda_primus" ]]; then PLATFORM=mi350 + else + _names=$(srun --jobid="$JOBID" --overlap docker ps -a --format '{{.Names}}' 2>/dev/null) + if grep -qx yambda_b200 <<<"$_names"; then PLATFORM=b200 + elif grep -qx yambda_primus <<<"$_names"; then PLATFORM=mi350 + else + # No known training container yet (e.g. container not provisioned). + # Fall back to probing the allocation's GPU vendor on the host so we + # do NOT silently assume a platform. + _vendor=$(srun --jobid="$JOBID" --overlap bash -lc \ + 'if command -v rocm-smi >/dev/null 2>&1; then echo amd; \ + elif command -v nvidia-smi >/dev/null 2>&1; then echo nvidia; \ + else echo unknown; fi' 2>/dev/null | head -1) + case "$_vendor" in + amd) PLATFORM=mi350; echo "[$(date)] no known container — detected AMD GPU host (rocm-smi) → mi350";; + nvidia) PLATFORM=b200; echo "[$(date)] no known container — detected NVIDIA GPU host (nvidia-smi) → b200";; + *) echo "Error: could not auto-detect platform on job $JOBID (no yambda_b200/yambda_primus container and no rocm-smi/nvidia-smi). Pass --platform b200|mi350|mi355."; exit 1;; + esac + fi + fi + echo "[$(date)] auto-detected platform: $PLATFORM" +fi +case "$PLATFORM" in + b200) + : "${CONTAINER:=yambda_b200}" + # B200: mmap (ckpt LOAD + dataset cache) must NOT touch virtiofs/NFS. + [[ "$DATA_PATH" == "__AUTO__" ]] && DATA_PATH=/tmp/yambda_data + : "${CKPT_ROOT:=/tmp/yambda_resume_test/ckpts}" + # Node-local NVMe: full-model save/load is fast. + : "${MW_TIMEOUT:=1800}" + : "${MW_RUN_TIMEOUT:=3600}" + ;; + mi350|mi355) + : "${CONTAINER:=yambda_primus}" + # meta64: NFS mmap is fine — read dataset + write ckpt directly on NFS + # (matches the original MI350 test). /apps/chcai/dlrm_data is the gin default. + [[ "$DATA_PATH" == "__AUTO__" ]] && DATA_PATH=/apps/chcai/dlrm_data + : "${CKPT_ROOT:=/apps/chcai/ckpts_resume_test}" + # Shared NFS: each full-model DCP save/load is ~9 min. The midwindow resume + # phase chains a LOAD + multiple in-window saves + an end-of-window save + # (>2000s observed), so the B200 budgets are far too tight — abandoning a + # still-running trainer leaks GPU VRAM and OOMs the next phase. Be generous. + : "${MW_TIMEOUT:=5400}" + : "${MW_RUN_TIMEOUT:=5400}" + ;; +esac +echo "[$(date)] platform=$PLATFORM container=$CONTAINER data_path=${DATA_PATH:-} ckpt_root=$CKPT_ROOT phase_timeout=${MW_TIMEOUT}s mw_run_timeout=${MW_RUN_TIMEOUT}s" + +mkdir -p "$LOG_DIR" +PYHELPER="$REPO/generative_recommenders/dlrm_v3/train/tests/streaming_resume_test.py" + +# --- container helpers (inspect CKPT/dataset via docker exec — works whether the +# path is node-local on B200 or shared NFS on MI350) --- +sx() { srun --jobid="$JOBID" --overlap docker exec "$CONTAINER" bash -lc "$1" 2>/dev/null; } + +# Kill any lingering trainer procs from a prior phase AND block until they are +# really gone, so the freed GPU VRAM is reclaimed before the next phase shards +# its embedding tables (otherwise it OOMs on the leaked memory). +# * Bracketed patterns ([t]rain_ranker, …) are REQUIRED: a plain `pkill -f +# train_ranker` issued inside `bash -lc "...train_ranker..."` matches its OWN +# command line and SIGKILLs this very shell (docker exec returns 137), which +# silently aborted the rest of the old cleanup and leaked trainers/VRAM. +# * After signalling, poll until no trainer remains (bounded), then a short +# settle so the driver finishes reclaiming device memory. +cleanup_workers() { + sx ' + for pat in "[t]rain_ranker" "[g]enerative_recommenders" "[s]pawn_main" "[m]ultiprocessing"; do + pkill -9 -f "$pat" 2>/dev/null + done + for _ in $(seq 1 30); do + pgrep -f "[t]rain_ranker" >/dev/null 2>&1 || \ + pgrep -f "[g]enerative_recommenders" >/dev/null 2>&1 || break + sleep 2 + done + sleep 3; true' || true +} +clean_ckpt() { sx "rm -rf '$1'" || true; } + +# Precheck: the dataset cache must be readable at $DATA_PATH. On B200 it must be +# staged node-local (the supervisor's stage_data_in does this) since mmap from +# virtiofs/NFS wedges; on MI350 it reads directly from NFS. Skipped when DATA_PATH +# is empty (the trainer falls back to its gin default and we don't know the path). +precheck_data() { + [[ -z "$DATA_PATH" ]] && { echo "[$(date)] data path unset — trainer will use its gin default; skipping precheck"; return 0; } + local ok + ok=$(sx "[ -d '$DATA_PATH/$DATASET_SUBDIR' ] && echo yes || echo no") + if [[ "$ok" != "yes" ]]; then + echo "FAIL: dataset cache not found at $DATA_PATH/$DATASET_SUBDIR inside '$CONTAINER' (platform=$PLATFORM)." + if [[ "$PLATFORM" == "b200" ]]; then + echo " B200: stage it node-local first (the e2e supervisor does this via stage_data_in)," + echo " or pass --data-path to an already-staged local mirror. mmap from virtiofs/NFS wedges." + else + echo " MI350/MI355: pass --data-path to the NFS dataset root (gin default is /apps/chcai/dlrm_data)." + fi + exit 1 + fi +} + +# Wait (host-side grep on the shared-NFS log) for a target regex OR a crash +# sentinel. 0=target found, 1=crash first, 2=timeout. +wait_for_log() { + local log="$LOG_DIR/$1.log"; local target_re="$2"; local timeout_s="${3:-1800}" + local elapsed=0 + while (( elapsed < timeout_s )); do + grep -qE "$target_re" "$log" 2>/dev/null && return 0 + grep -qE "Traceback|RuntimeError|OutOfMemoryError|CUDA error" "$log" 2>/dev/null && return 1 + sleep 5; elapsed=$((elapsed + 5)) + done + return 2 +} + +# Launch one trainer phase (detached), appending a PHASE_EXIT sentinel after the +# trainer returns (clean OR crash) — exactly like the production supervisor. The +# common env (data path, mode, start_ts, barrier debug) is fixed; per-phase knobs +# are passed as additional "K=V" words. +run_phase() { + local name="$1"; shift + local log="$LOG_DIR/${name}.log" + # `$*` (joined into ONE word), NOT `$@`: embedded mid-string in the + # double-quoted `bash -lc "..."`, `$@` would expand to multiple args and + # bash -lc would stop after the first override (silent 0-byte log). + local env_overrides="$*" + # Inject DLRM_DATA_PATH only when a path is set; an empty DATA_PATH means + # "use the trainer's gin default" (the meta64 NFS root). + local data_env="" + [[ -n "$DATA_PATH" ]] && data_env="DLRM_DATA_PATH=$DATA_PATH" + : > "$log" + echo "[$(date)] === phase '$name' ===" + cleanup_workers + srun --jobid="$JOBID" --overlap docker exec -d "$CONTAINER" bash -lc " + cd $REPO && + $data_env \ + HSTU_HAMMER_KERNEL=TRITON \ + MODE=streaming-train-eval \ + START_TS=$START_TS \ + WINDOW_BARRIER_DEBUG=1 \ + RUN_NAME=resume_test_$name \ + LOG=$log \ + $env_overrides \ + bash scripts/launch_slurm.sh; + echo \"PHASE_EXIT=\$? \$(date '+%F %T')\" >> $log + " +} + +# Read a scalar field from a summarize-JSON. +jget() { python3 -c "import json,sys;print(json.load(open(sys.argv[1])).get(sys.argv[2]))" "$1" "$2"; } + +FAIL=0 +fail() { echo "FAIL: $*"; FAIL=1; } + +precheck_data + +# ============================================================================= +# SCENARIO: midwindow +# ============================================================================= +run_midwindow() { + echo "########## scenario: midwindow ##########" + local LAST_STEP=$NUM_TRAIN_BATCHES + if (( DIE_AT_STEP <= 0 || DIE_AT_STEP >= NUM_TRAIN_BATCHES )); then + echo "Warning: die_at_step=$DIE_AT_STEP not strictly inside window (0, $NUM_TRAIN_BATCHES)" >&2 + fi + if (( DIE_AT_STEP % IN_WINDOW_FREQ != 0 )); then + echo "Warning: die_at_step=$DIE_AT_STEP not a multiple of in_window_freq=$IN_WINDOW_FREQ; no save lands exactly at crash" >&2 + fi + + # P1 baseline + clean_ckpt "$CKPT_ROOT" + run_phase baseline \ + "NUM_TRAIN_TS=1" "EVAL_EVERY_N_WINDOWS=0" "METRIC_LOG_FREQ=1" \ + "NUM_TRAIN_BATCHES=$NUM_TRAIN_BATCHES" "NUM_EVAL_BATCHES=$NUM_EVAL_BATCHES" \ + "TRAIN_SPLIT_PERCENTAGE=1.0" "DIE_AT_STEP=-1" + wait_for_log baseline "PHASE_EXIT=0" "$MW_TIMEOUT"; local rc=$? + cleanup_workers + (( rc != 0 )) && { echo "FAIL: midwindow baseline didn't finish (rc=$rc)"; tail -20 "$LOG_DIR/baseline.log"; return 1; } + + # P2 interrupted + clean_ckpt "$CKPT_ROOT" + run_phase interrupt \ + "NUM_TRAIN_TS=1" "EVAL_EVERY_N_WINDOWS=0" "METRIC_LOG_FREQ=1" \ + "NUM_TRAIN_BATCHES=$NUM_TRAIN_BATCHES" "NUM_EVAL_BATCHES=$NUM_EVAL_BATCHES" \ + "TRAIN_SPLIT_PERCENTAGE=1.0" \ + "IN_WINDOW_CKPT_FREQ=$IN_WINDOW_FREQ" "KEEP_LAST_N=1" \ + "DIE_AT_STEP=$DIE_AT_STEP" "CKPT_PATH=$CKPT_ROOT" + wait_for_log interrupt "die_at_step=$DIE_AT_STEP hit" "$MW_TIMEOUT"; rc=$? + cleanup_workers + (( rc != 0 )) && { echo "FAIL: interrupt didn't hit die_at_step (rc=$rc)"; tail -20 "$LOG_DIR/interrupt.log"; return 1; } + echo "Saved checkpoints after interrupt: $(sx "ls '$CKPT_ROOT' 2>/dev/null | tr '\n' ' '")" + + # P3 resume + run_phase resume \ + "NUM_TRAIN_TS=1" "EVAL_EVERY_N_WINDOWS=0" "METRIC_LOG_FREQ=1" \ + "NUM_TRAIN_BATCHES=$NUM_TRAIN_BATCHES" "NUM_EVAL_BATCHES=$NUM_EVAL_BATCHES" \ + "TRAIN_SPLIT_PERCENTAGE=1.0" \ + "IN_WINDOW_CKPT_FREQ=$IN_WINDOW_FREQ" "KEEP_LAST_N=1" \ + "DIE_AT_STEP=-1" "CKPT_PATH=$CKPT_ROOT" + # PHASE_EXIT=0 only after the (blocking) end-of-window save renames cleanly, + # so this also confirms the final atomic save completed. + wait_for_log resume "PHASE_EXIT=0" "$MW_TIMEOUT"; rc=$? + cleanup_workers + (( rc != 0 )) && { echo "FAIL: resume didn't finish (rc=$rc)"; tail -20 "$LOG_DIR/resume.log"; return 1; } + + # HARD functional invariants (deterministic; the real correctness proof). + if ! grep -qE "Resuming mid-window at train_ts=[0-9]+ batch_idx_in_window=$DIE_AT_STEP\b" "$LOG_DIR/resume.log" 2>/dev/null; then + fail "resume did not re-enter mid-window at batch_idx_in_window=$DIE_AT_STEP" + grep -E "Resuming" "$LOG_DIR/resume.log" 2>/dev/null | head -2 + fi + local rng_restored + rng_restored=$(grep -c "RNG state restored from" "$LOG_DIR/resume.log" 2>/dev/null || echo 0) + echo "RNG state restored on $rng_restored ranks" + (( rng_restored < 1 )) && fail "no RNG state restored on resume" + local first_resumed + first_resumed=$(grep -oE 'train - Step [0-9]+ metrics: \{.metric' "$LOG_DIR/resume.log" 2>/dev/null \ + | grep -oE 'Step [0-9]+' | awk '{print $2}' | sort -n | head -1) + echo "First resumed train step: $first_resumed (expect $((DIE_AT_STEP + 1)))" + [[ "$first_resumed" != "$((DIE_AT_STEP + 1))" ]] && fail "resume did not continue at step $((DIE_AT_STEP + 1)) (got $first_resumed)" + + # On-disk: atomic save + retention. + local num_ckpt stale_ckpt + num_ckpt=$(sx "ls '$CKPT_ROOT' 2>/dev/null | grep -E '^[0-9]+$' | wc -l" | tr -d ' ') + stale_ckpt=$(sx "ls '$CKPT_ROOT' 2>/dev/null | grep -E '\\.(tmp|old|staging)$' | wc -l" | tr -d ' ') + echo "Final: $num_ckpt numeric ckpt subdirs, $stale_ckpt stale dirs (expect 1, 0)" + [[ "$num_ckpt" != "1" ]] && fail "keep_last_n=1 violated (got $num_ckpt)" + [[ "$stale_ckpt" != "0" ]] && fail "stale .tmp/.old/.staging dirs left behind ($stale_ckpt)" + + # Trajectory closeness (loose sanity bound, NOT bit-equality). + python3 "$PYHELPER" parse "$LOG_DIR/baseline.log" "$LOG_DIR/traj_baseline.json" + python3 "$PYHELPER" parse "$LOG_DIR/resume.log" "$LOG_DIR/traj_resumed.json" + if ! python3 "$PYHELPER" compare "$LOG_DIR/traj_baseline.json" "$LOG_DIR/traj_resumed.json" \ + --min-resume-step $((DIE_AT_STEP + 1)) --atol "$ATOL"; then + fail "trajectory diverged beyond $ATOL (likely wrong data slice / unrestored state)" + fi + (( FAIL == 0 )) && echo "=== midwindow: PASS ===" || echo "=== midwindow: FAIL ===" +} + +# ============================================================================= +# SCENARIO: multiwindow (regression guard for the broadcast + barrier fixes) +# ============================================================================= +# Common split contract — MUST be byte-identical between mw_seed and mw_resume, +# else the resume aborts on a split-contract mismatch (the holdout_ts default of +# start_ts+num_train_ts differs between a 1-window seed and an MW_TS resume, so +# it is PINNED here). +MW_SPLIT_ENV=( "TRAIN_SPLIT_PERCENTAGE=$MW_SPLIT" "SPLIT_SALT=0" + "EVAL_HOLDOUT_TS=$MW_HOLDOUT_TS" "EVAL_HOLDOUT_NUM_WINDOWS=1" ) + +run_multiwindow() { + echo "########## scenario: multiwindow ##########" + local sum + + # P1 mw_baseline — cold multi-window run with data-pct eval. + clean_ckpt "$CKPT_ROOT" + run_phase mw_baseline \ + "NUM_TRAIN_TS=$MW_TS" "NUM_TRAIN_BATCHES=$MW_BATCHES" "NUM_EVAL_BATCHES=$MW_EVAL_BATCHES" \ + "EVAL_EVERY_N_WINDOWS=0" "EVAL_EVERY_DATA_PCT=$MW_EVAL_PCT" "METRIC_LOG_FREQ=1" \ + "${MW_SPLIT_ENV[@]}" + wait_for_log mw_baseline "PHASE_EXIT=" "$MW_RUN_TIMEOUT"; local rc=$? + cleanup_workers + (( rc == 1 )) && { echo "FAIL: mw_baseline crashed"; tail -30 "$LOG_DIR/mw_baseline.log"; return 1; } + (( rc == 2 )) && { echo "FAIL: mw_baseline timed out (possible boundary deadlock)"; tail -30 "$LOG_DIR/mw_baseline.log"; return 1; } + + sum="$LOG_DIR/mw_baseline.summary.json" + python3 "$PYHELPER" summarize "$LOG_DIR/mw_baseline.log" "$sum" >/dev/null + echo "--- mw_baseline summary ---"; cat "$sum" + local exit_code anchors barriers dpct_setup dpct_trig + exit_code=$(jget "$sum" phase_exit) + anchors=$(jget "$sum" total_train_anchors_calls) + barriers=$(jget "$sum" window_barrier_count) + dpct_setup=$(jget "$sum" data_pct_eval_setup) + dpct_trig=$(jget "$sum" data_pct_eval_trigger_count) + # (barrier B) ran through ALL windows and exited 0 — no boundary deadlock. + [[ "$exit_code" != "0" ]] && fail "mw_baseline did not complete cleanly (phase_exit=$exit_code)" + [[ "$barriers" != "$MW_TS" ]] && fail "window barrier fired $barriers times, expected $MW_TS (one per window; need world_size>=2)" + # (broadcast A) total_train_anchors computed exactly once (rank 0), not Nx. + # It is computed at loop SETUP (before any training), so this exercises the + # broadcast regardless of whether an eval later fires. + [[ "$anchors" != "1" ]] && fail "total_train_anchors computed $anchors times, expected 1 (rank-0 broadcast regressed)" + # data-fraction eval cadence set up (the path that needs total_train_anchors). + [[ "$dpct_setup" != "True" ]] && fail "data-pct eval cadence not set up (total_train_anchors path not reached)" + # Trigger firing depends on (full-window) anchor count vs the few test steps, + # so it is informational — not required to exercise the broadcast fix. + echo "data-pct eval triggers fired: $dpct_trig (informational)" + + # P2 mw_seed — 1 window → clean WINDOW_COMPLETE checkpoint. + clean_ckpt "$CKPT_ROOT" + run_phase mw_seed \ + "NUM_TRAIN_TS=1" "NUM_TRAIN_BATCHES=$MW_BATCHES" "NUM_EVAL_BATCHES=$MW_EVAL_BATCHES" \ + "EVAL_EVERY_N_WINDOWS=0" "EVAL_EVERY_DATA_PCT=$MW_EVAL_PCT" "METRIC_LOG_FREQ=1" \ + "KEEP_LAST_N=1" "CKPT_PATH=$CKPT_ROOT" "${MW_SPLIT_ENV[@]}" + wait_for_log mw_seed "PHASE_EXIT=0" "$MW_RUN_TIMEOUT"; rc=$? + cleanup_workers + (( rc != 0 )) && { echo "FAIL: mw_seed didn't finish/checkpoint (rc=$rc)"; tail -30 "$LOG_DIR/mw_seed.log"; return 1; } + local seed_ckpt + seed_ckpt=$(sx "ls '$CKPT_ROOT' 2>/dev/null | grep -E '^[0-9]+$' | sort -n | tail -1" | tr -d ' ') + echo "mw_seed end-of-window checkpoint: ${seed_ckpt:-} (expect $START_TS)" + [[ "$seed_ckpt" != "$START_TS" ]] && { fail "mw_seed did not save end-of-window ckpt $START_TS (got '$seed_ckpt')"; return 1; } + + # P3 mw_resume — relaunch over MW_TS windows; resume past the completed + # window and CROSS the boundary into the remaining windows (the exact case + # that used to deadlock). The full split contract matches the seed. + run_phase mw_resume \ + "NUM_TRAIN_TS=$MW_TS" "NUM_TRAIN_BATCHES=$MW_BATCHES" "NUM_EVAL_BATCHES=$MW_EVAL_BATCHES" \ + "EVAL_EVERY_N_WINDOWS=0" "EVAL_EVERY_DATA_PCT=$MW_EVAL_PCT" "METRIC_LOG_FREQ=1" \ + "KEEP_LAST_N=1" "CKPT_PATH=$CKPT_ROOT" "${MW_SPLIT_ENV[@]}" + wait_for_log mw_resume "PHASE_EXIT=" "$MW_RUN_TIMEOUT"; rc=$? + cleanup_workers + (( rc == 1 )) && { echo "FAIL: mw_resume crashed"; tail -30 "$LOG_DIR/mw_resume.log"; return 1; } + (( rc == 2 )) && { echo "FAIL: mw_resume timed out (possible boundary deadlock on resume)"; tail -30 "$LOG_DIR/mw_resume.log"; return 1; } + + sum="$LOG_DIR/mw_resume.summary.json" + python3 "$PYHELPER" summarize "$LOG_DIR/mw_resume.log" "$sum" >/dev/null + echo "--- mw_resume summary ---"; cat "$sum" + local r_exit r_resume_ts r_anchors r_barriers r_dpct + r_exit=$(jget "$sum" phase_exit) + r_resume_ts=$(jget "$sum" resume_completed_ts) + r_anchors=$(jget "$sum" total_train_anchors_calls) + r_barriers=$(jget "$sum" window_barrier_count) + r_dpct=$(jget "$sum" data_pct_eval_setup) + # Resumed from the completed seed window (advanced past the boundary). + [[ "$r_resume_ts" != "$START_TS" ]] && fail "mw_resume did not resume from completed train_ts=$START_TS (got $r_resume_ts)" + # Crossed the boundary into the remaining MW_TS-1 windows and exited 0. + [[ "$r_exit" != "0" ]] && fail "mw_resume did not complete cleanly (phase_exit=$r_exit) — boundary deadlock on resume?" + [[ "$r_barriers" != "$((MW_TS - 1))" ]] && fail "mw_resume barrier fired $r_barriers times, expected $((MW_TS - 1)) (windows after the resumed one)" + # Broadcast still once on the resume path; data-pct cadence rebuilt. + [[ "$r_anchors" != "1" ]] && fail "mw_resume total_train_anchors computed $r_anchors times, expected 1" + [[ "$r_dpct" != "True" ]] && fail "mw_resume data-pct eval cadence not set up" + + (( FAIL == 0 )) && echo "=== multiwindow: PASS ===" || echo "=== multiwindow: FAIL ===" +} + +# ============================================================================= +[[ "$SCENARIO" == "midwindow" || "$SCENARIO" == "all" ]] && run_midwindow +[[ "$SCENARIO" == "multiwindow" || "$SCENARIO" == "all" ]] && run_multiwindow + +if [[ "$KEEP" != "1" ]]; then + rm -rf "$LOG_DIR" + clean_ckpt "$CKPT_ROOT" +fi + +if (( FAIL == 0 )); then + echo "=== PASS: all selected scenarios validated ===" + exit 0 +fi +echo "=== FAIL: one or more scenarios failed (see above) ===" +exit 1 diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/train/tests/train_test.py b/recommendation_v4/generative_recommenders/dlrm_v3/train/tests/train_test.py new file mode 100644 index 000000000..dfd58f6e5 --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/train/tests/train_test.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict +import unittest + +from generative_recommenders.common import gpu_unavailable +from generative_recommenders.dlrm_v3.train.train_ranker import main + + +class DLRMV3TrainTest(unittest.TestCase): + @unittest.skipIf(*gpu_unavailable) + def test_e2e(self) -> None: + main() + + +if __name__ == "__main__": + unittest.main() diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/train/train_ranker.py b/recommendation_v4/generative_recommenders/dlrm_v3/train/train_ranker.py new file mode 100644 index 000000000..a57153f60 --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/train/train_ranker.py @@ -0,0 +1,370 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict +import argparse +import logging + +logging.basicConfig(level=logging.INFO) +import os +import sys +import traceback + +import gin +import torch +from torch import multiprocessing as mp +from torchrec.test_utils import get_free_port + +# NOTE: heavy imports of generative_recommenders.dlrm_v3.* are deferred to +# inside _main_func so that gin-driven env-var bootstrap (see +# _env_bootstrap.apply_env_bootstrap) can run BEFORE the triton kernel +# modules evaluate their `@triton.autotune` decorators at module-load time. + +logger: logging.Logger = logging.getLogger(__name__) + + +SUPPORTED_CONFIGS = { + "debug": "debug.gin", + "kuairand-1k": "kuairand_1k.gin", + "movielens-1m": "movielens_1m.gin", + "movielens-20m": "movielens_20m.gin", + "movielens-13b": "movielens_13b.gin", + "movielens-18b": "movielens_18b.gin", + "streaming-400m": "streaming_400m.gin", + "streaming-200b": "streaming_200b.gin", + "streaming-100b": "streaming_100b.gin", + "yambda-5b": "yambda_5b.gin", +} + + +def _main_func( + local_rank: int, + world_size: int, + node_rank: int, + gpus_per_node: int, + master_addr: str, + master_port: int, + gin_file: str, + mode: str, +) -> None: + # `local_rank` is the index handed out by mp.start_processes (0..gpus_per_node-1) + # and indexes this node's GPUs. The GLOBAL rank is what every downstream + # consumer wants (data sharding via StreamingWindowSampler, checkpoint I/O, + # metrics), so derive it once and pass it through as `rank`. Only the CUDA + # device must be node-local. Single-node (node_rank=0) → rank == local_rank, + # exactly as before. + rank = node_rank * gpus_per_node + local_rank + device = torch.device(f"cuda:{local_rank}") + logger.info( + f"rank: {rank} (node_rank={node_rank} local_rank={local_rank}), " + f"world_size: {world_size}, device: {device}" + ) + # Phase 1: parse gin early with skip_unknown=True so env-bootstrap + # bindings take effect BEFORE any module-level @gin.configurable + # discovers itself. This is required because triton @triton.autotune + # decorators in generative_recommenders.ops.triton.* read env vars at + # module import time, and the heavy imports below pull those in. + from generative_recommenders.dlrm_v3.train._env_bootstrap import apply_env_bootstrap + from generative_recommenders.dlrm_v3.train.mlperf_logging_utils import ( + get_mlperf_logger, + mlperf_checkpoint_present, + ) + + gin.parse_config_file(gin_file, skip_unknown=True) + apply_env_bootstrap() + + # Cold-start vs resume, decided from the on-disk checkpoint BEFORE setup so + # the one-time INIT/RUN markers fire on a genuine cold start only and are NOT + # re-emitted on a resume relaunch — the MLPerf run (run_start..run_stop) spans + # the resume as a single coherent event stream in one appended log file. + mlperf_resume = mlperf_checkpoint_present(os.environ.get("CKPT_PATH", "")) + # Rank-0-gated MLPerf logger, only for the streaming-train-eval path. `fresh` + # truncates the log on cold start (one run_start per file) but appends on a + # resume so the pre-crash events are preserved and continued. + mlperf_logger = ( + get_mlperf_logger(rank=rank, fresh=not mlperf_resume) + if mode == "streaming-train-eval" + else None + ) + # INIT_START fires before setup on a cold start only (resume continues the + # already-open run, whose markers were emitted by the original process). + mlperf_cold_start = mlperf_logger is not None and not mlperf_resume + if mlperf_cold_start: + mlperf_logger.event(key=mlperf_logger.constants.CACHE_CLEAR, value=True) + mlperf_logger.start(key=mlperf_logger.constants.INIT_START) + + # Phase 2: heavy imports. Triton kernel modules evaluate their autotune + # decorators here, using the env vars set above. + from generative_recommenders.dlrm_v3.checkpoint import load_dmp_checkpoint + from generative_recommenders.dlrm_v3.train.utils import ( + cleanup, + decorrelate_runtime_rng, + eval_loop, + make_model, + make_optimizer_and_shard, + make_train_test_dataloaders, + seed_everything, + setup, + streaming_train_eval_loop, + train_eval_loop, + train_loop, + ) + from generative_recommenders.dlrm_v3.utils import ( + MetricsLogger, + get_gpu_peak_flops, + ) + + setup( + rank=rank, + world_size=world_size, + master_addr=master_addr, + master_port=master_port, + device=device, + ) + # Phase 3: re-parse to bind the @gin.configurables now that they are + # registered. The earlier skip_unknown pass already consumed the + # env-bootstrap binding, but bindings are idempotent so re-applying is + # fine, and this pass is the one that actually wires up make_model, + # make_train_test_dataloaders, etc. + gin.parse_config_file(gin_file) + + # Seed all RNGs (gin-configurable $SEED) BEFORE make_model() so weight init + # is reproducible run-to-run. Must follow the full parse above so the binding + # is wired, and precede make_model() below. + seed_everything(rank=rank) + + model, model_configs, embedding_table_configs = make_model() + model, optimizer = make_optimizer_and_shard( + model=model, + device=device, + world_size=world_size, + local_world_size=gpus_per_node, + embedding_table_configs=embedding_table_configs, + ) + # Decorrelate forward-time stochasticity (HSTU dropout) per data-parallel + # rank. MUST run after make_model() + make_optimizer_and_shard() so the + # replicated dense weights and sharded embeddings stay init-identical across + # ranks; this only offsets the global RNG by rank so dropout masks differ. + decorrelate_runtime_rng(rank=rank) + train_dataloader, test_dataloader = make_train_test_dataloaders( + hstu_config=model_configs, + embedding_table_configs=embedding_table_configs, + ) + # TFLOPS/MFU reporting: query the model's static dense estimate + + # current GPU's peak FLOPS. Both default to 0 if the model doesn't + # expose get_num_flops_per_sample, in which case MetricsLogger silently + # drops the tflops fields from the perf line. + inner_model = model.module if hasattr(model, "module") else model + num_flops_per_sample = ( + float(inner_model.get_num_flops_per_sample()) + if hasattr(inner_model, "get_num_flops_per_sample") + else 0.0 + ) + gpu_peak_flops = get_gpu_peak_flops( + "bf16" if getattr(model_configs, "bf16_training", True) else "fp32" + ) + # Streaming fixed-holdout eval uses the dual fresh/cumulative metric sets: + # window_* = fresh per-pass full-holdout, lifetime_* = cumulative across + # passes (AUC via O(bins) histogram). Other modes keep the legacy single set. + metrics = MetricsLogger( + multitask_configs=model_configs.multitask_configs, + batch_size=train_dataloader.batch_size, + window_size=2500, + device=device, + rank=rank, + # Pass the live world_size so metric normalization is correct at any + # node count; the gin's MetricsLogger.world_size default (=8) is only a + # single-node fallback and would mis-normalize a multi-node run. + world_size=world_size, + num_flops_per_sample=num_flops_per_sample, + gpu_peak_flops=gpu_peak_flops, + model=model, + eval_cumulative=(mode == "streaming-train-eval"), + # Lifetime-AUC backend + bins/window come from gin (see yambda_5b.gin: + # MetricsLogger.{train,eval}_lifetime_auc_mode / cumulative_auc_bins / + # lifetime_auc_window), env-overridable. eval_cumulative stays explicit + # because it is runtime-mode dependent, not a config knob. + ) + # Capture streaming resume hint (None for cold start / non-streaming + # checkpoints). For the streaming-train-eval mode, we forward this into + # streaming_train_eval_loop so it can advance past the last completed + # window OR re-enter the partial window and skip already-trained batches. + resume_train_ts, resume_batch_idx_in_window, resume_split_contract, resume_cold_start = ( + load_dmp_checkpoint( + model=model, + optimizer=optimizer, + metric_logger=metrics, + device=device, + rank=rank, + ) + ) + + # MLPerf run markers: open the run exactly once. On a cold start emit + # submission info + hyperparameters + INIT_STOP/RUN_START and mark the run as + # started (persisted in the checkpoint via metrics.mlperf_run_started). On a + # resume, load_dmp_checkpoint restored mlperf_run_started=True, so we skip the + # markers and just continue the stream. `metrics.mlperf_run_started` guards a + # double-emit even if cold/resume detection and the checkpoint ever disagree. + if mlperf_cold_start and not metrics.mlperf_run_started: + # Submission info + hyperparameters + INIT_STOP/RUN_START, all emitted by + # the logger (optimizer names/LRs read from gin internally). Seed is the + # value setup() resolved and exported to $SEED. + mlperf_logger.log_run_start( + global_batch_size=world_size * int(train_dataloader.batch_size), + seed=int(os.environ.get("SEED", "1")), + ) + metrics.mlperf_run_started = True + # Pass the logger to the loop whenever MLPerf logging is enabled, so block / + # eval / train_loss / run_stop events emit on BOTH a cold start and a resume. + mlperf_run_active = mlperf_logger is not None and metrics.mlperf_run_started + + # train loop + try: + if mode == "train": + train_loop( + rank=rank, + model=model, + dataloader=train_dataloader, + optimizer=optimizer, + metric_logger=metrics, + device=device, + ) + elif mode == "eval": + # reinit metrics logger for eval + metrics = MetricsLogger( + multitask_configs=model_configs.multitask_configs, + batch_size=train_dataloader.batch_size, + window_size=1000, + device=device, + rank=rank, + world_size=world_size, + ) + eval_loop( + rank=rank, + model=model, + dataloader=test_dataloader, + metric_logger=metrics, + device=device, + ) + elif mode == "train-eval": + train_eval_loop( + rank=rank, + model=model, + train_dataloader=train_dataloader, + eval_dataloader=test_dataloader, + optimizer=optimizer, + metric_logger=metrics, + device=device, + ) + elif mode == "streaming-train-eval": + streaming_train_eval_loop( + rank=rank, + model=model, + optimizer=optimizer, + metric_logger=metrics, + device=device, + hstu_config=model_configs, + embedding_table_configs=embedding_table_configs, + resume_train_ts=resume_train_ts, + resume_batch_idx_in_window=resume_batch_idx_in_window, + resume_split_contract=resume_split_contract, + resume_cold_start=resume_cold_start, + # Only pass the logger when run boundaries were emitted, so the + # loop never produces orphan block/eval events. + mlperf_logger=mlperf_logger if mlperf_run_active else None, + ) + except Exception as e: + logger.info(traceback.format_exc()) + raise Exception(e) + finally: + # Graceful distributed teardown on both success and failure: barrier so + # all ranks finish in lockstep, then destroy the process group (best- + # effort) to avoid noisy TCPStore/NCCL shutdown warnings at exit. + if torch.distributed.is_initialized(): + try: + torch.distributed.barrier() + except Exception: + logger.info("teardown barrier failed (non-fatal)") + try: + cleanup() + except Exception: + logger.info("teardown destroy_process_group failed (non-fatal)") + + +def get_args(): # pyre-ignore [3] + """Parse commandline.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--dataset", default="debug", choices=SUPPORTED_CONFIGS.keys(), help="dataset" + ) + parser.add_argument( + "--mode", + default="train", + choices=["train", "eval", "train-eval", "streaming-train-eval"], + help="mode", + ) + args, unknown_args = parser.parse_known_args() + logger.warning(f"unknown_args: {unknown_args}") + return args + + +def main() -> None: + args = get_args() + logger.info(args) + assert args.dataset in SUPPORTED_CONFIGS, f"Unsupported dataset: {args.dataset}" + assert args.mode in [ + "train", + "eval", + "train-eval", + "streaming-train-eval", + ], f"Unsupported mode: {args.mode}" + # Distributed topology (single-node defaults reproduce the legacy behavior): + # GPUS_PER_NODE local procs to spawn on THIS node (default: all visible GPUs) + # NNODES/NODE_RANK multi-node fan-out, set by the SLURM launcher + # WORLD_SIZE global rank count = NNODES * GPUS_PER_NODE + # MASTER_ADDR/PORT rank-0 rendezvous; the port MUST match across nodes, so + # honor it from the env when set and only fall back to a + # random free port for the standalone single-node path. + GPUS_PER_NODE = int(os.environ.get("GPUS_PER_NODE", 0)) or torch.cuda.device_count() + NNODES = int(os.environ.get("NNODES", 1)) + NODE_RANK = int(os.environ.get("NODE_RANK", 0)) + WORLD_SIZE = NNODES * GPUS_PER_NODE + MASTER_ADDR = os.environ.get("MASTER_ADDR", "localhost") + MASTER_PORT = str(os.environ.get("MASTER_PORT") or get_free_port()) + gin_path = f"{os.path.dirname(__file__)}/gin/{SUPPORTED_CONFIGS[args.dataset]}" + logger.info( + f"launching: nnodes={NNODES} node_rank={NODE_RANK} " + f"gpus_per_node={GPUS_PER_NODE} world_size={WORLD_SIZE} " + f"master={MASTER_ADDR}:{MASTER_PORT}" + ) + + mp.start_processes( + _main_func, + args=( + WORLD_SIZE, + NODE_RANK, + GPUS_PER_NODE, + MASTER_ADDR, + MASTER_PORT, + gin_path, + args.mode, + ), + nprocs=GPUS_PER_NODE, + join=True, + start_method="spawn", + ) + + +if __name__ == "__main__": + main() diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/train/utils.py b/recommendation_v4/generative_recommenders/dlrm_v3/train/utils.py new file mode 100644 index 000000000..f226595a8 --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/train/utils.py @@ -0,0 +1,2938 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict +import logging +import os +import threading +import time +from collections.abc import Iterator +from datetime import timedelta +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Type, + Union, +) + +import gin +import torch +import torchrec +from generative_recommenders.dlrm_v3.checkpoint import save_dmp_checkpoint, WINDOW_COMPLETE +from generative_recommenders.dlrm_v3.configs import ( + get_embedding_table_config, + get_hstu_configs, +) +from generative_recommenders.dlrm_v3.datasets.dataset import ( + collate_fn, + Dataset, + Samples, +) +from generative_recommenders.dlrm_v3.utils import get_dataset, MetricsLogger, Profiler +from generative_recommenders.common import HammerKernel +from generative_recommenders.modules.dlrm_hstu import DlrmHSTU, DlrmHSTUConfig +from torch import distributed as dist +from torch.distributed.optim import ( + _apply_optimizer_in_backward as apply_optimizer_in_backward, +) +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader, Dataset as TorchDataset +from torch.utils.data.distributed import _T_co, DistributedSampler +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology +from torchrec.distributed.planner.types import ParameterConstraints +from torchrec.distributed.sharding_plan import get_default_sharders +from torchrec.distributed.types import ShardedTensor, ShardingEnv, ShardingType +from torchrec.modules.embedding_configs import EmbeddingConfig +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingCollection, +) +from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper +from torchrec.optim.optimizers import in_backward_optimizer_filter +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + +logger: logging.Logger = logging.getLogger(__name__) + +TORCHREC_TYPES: Set[Type[Union[EmbeddingBagCollection, EmbeddingCollection]]] = { + EmbeddingBagCollection, + EmbeddingCollection, +} + +# Embedding placement vocabulary -> torchrec compute kernel. Used by +# make_optimizer_and_shard to translate the gin/env placement strings +# ("hbm"/"uvm"/"uvm_caching") into ParameterConstraints. "auto" (or anything not +# in this map) means "no constraint": the planner decides from the HBM cap. +_PLACEMENT_TO_KERNEL: Dict[str, EmbeddingComputeKernel] = { + "hbm": EmbeddingComputeKernel.FUSED, + "uvm": EmbeddingComputeKernel.FUSED_UVM, + "uvm_caching": EmbeddingComputeKernel.FUSED_UVM_CACHING, +} + +# Per-table sharding-type vocabulary -> torchrec ShardingType. Used by +# make_optimizer_and_shard to pin a table's shard layout via ParameterConstraints +# (e.g. move a hot, data-skewed table off ROW_WISE to COLUMN_WISE so its +# embedding all-to-all load is balanced by rank instead of routed by row/value). +# "auto" (or anything not in this map) means "no constraint": the planner decides. +# Short aliases (rw/cw/tw/twrw) are accepted alongside the canonical names. +_SHARDING_TO_TYPE: Dict[str, ShardingType] = { + "row_wise": ShardingType.ROW_WISE, + "column_wise": ShardingType.COLUMN_WISE, + "table_wise": ShardingType.TABLE_WISE, + "table_row_wise": ShardingType.TABLE_ROW_WISE, + "table_column_wise": ShardingType.TABLE_COLUMN_WISE, + "data_parallel": ShardingType.DATA_PARALLEL, + "rw": ShardingType.ROW_WISE, + "cw": ShardingType.COLUMN_WISE, + "tw": ShardingType.TABLE_WISE, + "twrw": ShardingType.TABLE_ROW_WISE, +} + + +@gin.configurable +def seed_everything(seed: int = -1, rank: int = 0) -> None: + """Seed all RNGs (same value on every rank) for reproducible dense weight init. + + Call right before make_model(), after setup() (process group needed for the + broadcast) and the gin parse. seed < 0 ($SEED unset) draws a fresh random seed + per run (rank 0 broadcasts; exported to $SEED); seed >= 0 reproduces a run. + Data order/split are independent of this seed (StreamingWindowSampler/$SPLIT_SALT). + """ + import random + + import numpy as np + + pinned = seed >= 0 + if not pinned: + # rank 0 draws a random seed and broadcasts it so all ranks agree (an + # identical seed on every rank is REQUIRED, else dense weight init + # diverges across ranks and DDP/AllReduce trains garbage). + if dist.is_available() and dist.is_initialized(): + device = torch.device(f"cuda:{torch.cuda.current_device()}") + drawn = int.from_bytes(os.urandom(4), "little") if rank == 0 else 0 + _seed_t = torch.tensor([drawn], dtype=torch.int64, device=device) + dist.broadcast(_seed_t, src=0) + seed = int(_seed_t.item()) + else: + seed = int.from_bytes(os.urandom(4), "little") + # Export the resolved value so the run is reproducible after the fact. + os.environ["SEED"] = str(seed) + + logger.info( + f"[rank {rank}] seeding all RNGs with SEED={seed} " + f"({'pinned via $SEED' if pinned else 'random per-run; set $SEED to reproduce'})" + ) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +@gin.configurable +def decorrelate_runtime_rng(rank: int = 0, enabled: bool = True) -> None: + """Re-seed torch/cuda with $SEED + rank so HSTU dropout draws different masks + per data-parallel rank (seed_everything's identical seed would draw the same). + + MUST run after make_model() + make_optimizer_and_shard() so init stays + identical across ranks; it perturbs only forward-time stochasticity. + Reproducible (pure fn of $SEED + rank; RNG state checkpointed). enabled=False + keeps the legacy identical-mask behavior. + """ + if not enabled: + logger.info( + f"[rank {rank}] decorrelate_runtime_rng disabled; dropout masks " + f"identical across ranks" + ) + return + base = int(os.environ.get("SEED", "1")) + offset_seed = base + int(rank) + torch.manual_seed(offset_seed) + torch.cuda.manual_seed_all(offset_seed) + logger.info( + f"[rank {rank}] decorrelated runtime RNG: SEED={base} + rank={rank} " + f"=> {offset_seed} (per-rank dropout masks)" + ) + + +def setup( + rank: int, + world_size: int, + master_port: int, + device: torch.device, + master_addr: str = "localhost", +) -> dist.ProcessGroup: + # Default "localhost" keeps the single-node path unchanged; multi-node + # launches pass the rank-0 host so every node rendezvouses at the same addr. + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) + + BACKEND = dist.Backend.NCCL + # Process-group / NCCL watchdog timeout (seconds). Env-overridable so a + # diagnostic run can use a short, finite timeout that trips the NCCL flight + # recorder dump (TORCH_NCCL_TRACE_BUFFER_SIZE + TORCH_NCCL_DUMP_ON_TIMEOUT) + # on a collective desync instead of hanging for the full default. + TIMEOUT = int(os.environ.get("PG_TIMEOUT_S", "1800")) + + # set device BEFORE init_process_group so NCCL binds this rank to its + # own GPU; otherwise every rank's first CUDA context lands on GPU 0, + # leaving stale allocations and triggering OOMs on rank 0. + torch.cuda.set_device(device) + + # NOTE: RNG seeding for reproducible weight init lives in seed_everything(), + # which train_ranker calls right before make_model() (after setup() so the + # process group is initialized for the cross-rank seed broadcast, and after + # the full gin parse so the gin-configurable $SEED is bound). Seeding here + # would be too early to be gin-configurable. + + # initialize the process group + # + # The default PG timeout must match TIMEOUT (not the 600s NCCL default): + # checkpoint saves go through DCP collectives on *this* default PG, and the + # 560GB sparse-embedding write is both slow on shared NFS and badly skewed + # across ranks (shards range ~37GB..~95GB), so the fastest rank can sit in + # the post-write allgather/barrier well past 600s waiting for the slowest + # rank. The stock 600s watchdog then SIGABRTs an otherwise-healthy job. + if not dist.is_initialized(): + dist.init_process_group( + "nccl", + rank=rank, + world_size=world_size, + device_id=device, + timeout=timedelta(seconds=TIMEOUT), + ) + + pg = dist.new_group( + backend=BACKEND, + timeout=timedelta(seconds=TIMEOUT), + ) + + return pg + + +def cleanup() -> None: + dist.destroy_process_group() + + +def _window_boundary_barrier( + device: torch.device, world_size: int, train_ts: int +) -> None: + """Collective rendezvous at a streaming window boundary. + + The per-window data prep (``window_indices``: an O(N) mask over the ~18 GB + mmap'd ``anchor_ts`` array) can complete at very different times across + ranks. The embedding input-dist all-to-all that follows is a collective, so + if a fast rank reaches it while a slow rank is still in prep, the NCCL + stream desyncs and the job deadlocks (one rank a collective behind the + rest). Synchronizing here makes prep-time skew harmless: every rank waits + until all ranks have a ready window before any issues the first forward. + + Cost is one near-zero-payload barrier per window (299 total over a full + run). In the healthy case prep already overlapped the previous window's + compute via the prefetcher, so the barrier returns immediately; it only + blocks for the real prep skew it is there to absorb. + """ + if not (dist.is_available() and dist.is_initialized()) or world_size <= 1: + return + t0 = time.time() + if device.type == "cuda": + dist.barrier(device_ids=[device.index]) + else: + dist.barrier() + waited = time.time() - t0 + # Surface non-trivial skew (the thing this barrier exists to absorb) so a + # node with a slow rank is visible without trawling the flight recorder. + if waited > 5.0: + logger.warning( + "[window-barrier] train_ts=%d: waited %.1fs at boundary " + "rendezvous (per-rank data-prep skew)", + train_ts, + waited, + ) + # Test/debug observability: the healthy-path barrier is otherwise SILENT + # (the skew warning above only fires on >5s waits), so the resume e2e test + # has no signal that the boundary rendezvous actually executed. When + # WINDOW_BARRIER_DEBUG=1, rank 0 emits exactly one line per crossed window + # so the test can assert the barrier ran at EVERY boundary (regression guard + # for the desync deadlock the barrier fixes). Off by default — zero prod cost. + if os.environ.get("WINDOW_BARRIER_DEBUG") == "1" and dist.get_rank() == 0: + logger.info( + "[window-barrier] train_ts=%d rendezvous complete (waited %.3fs)", + train_ts, + waited, + ) + + +class HammerToTorchDataset(TorchDataset): + def __init__( + self, + dataset: Dataset, + ) -> None: + self.dataset: Dataset = dataset + + def __getitem__(self, idx: int) -> Tuple[KeyedJaggedTensor, KeyedJaggedTensor]: + self.dataset.load_query_samples([idx]) + sample = self.dataset.get_sample(idx) + self.dataset.unload_query_samples([idx]) + return sample + + def __getitems__( + self, indices: List[int] + ) -> List[Tuple[KeyedJaggedTensor, KeyedJaggedTensor]]: + self.dataset.load_query_samples(indices) + samples = [self.dataset.get_sample(i) for i in indices] + self.dataset.unload_query_samples(indices) + return samples + + +class _ChainedRanges: + """O(1) __len__ + O(log K) __getitem__ over a sequence of `range`s. + + Lets `torch.utils.data.Subset(dataset, _ChainedRanges([r1, r2, ...]))` + avoid materializing a Python list of all per-block indices (which at + multi-billion totals is ~28 B/int and dominates host RAM). + """ + + def __init__(self, ranges: List[range]) -> None: + self._ranges: List[range] = list(ranges) + offsets = [0] + for r in self._ranges: + offsets.append(offsets[-1] + len(r)) + self._offsets: List[int] = offsets + + def __len__(self) -> int: + return self._offsets[-1] + + def __getitem__(self, idx: int) -> int: + import bisect + if idx < 0: + idx += self._offsets[-1] + if idx < 0 or idx >= self._offsets[-1]: + raise IndexError(idx) + bucket = bisect.bisect_right(self._offsets, idx) - 1 + return self._ranges[bucket][idx - self._offsets[bucket]] + + +class ChunkDistributedSampler(DistributedSampler[_T_co]): + """ + Each rank reads a contiguous chunk (trunk) of the input data + """ + + def __init__( + self, + dataset: TorchDataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 1, + drop_last: bool = False, + ) -> None: + super().__init__( + dataset=dataset, + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + seed=seed, + drop_last=drop_last, + ) + + def __iter__(self) -> Iterator[_T_co]: + if self.shuffle: + g = torch.Generator() + g.manual_seed(self.seed + self.epoch * 1001 + int(self.rank)) + indices_t = torch.randperm(self.num_samples, generator=g) + else: + indices_t = torch.arange(self.num_samples) + assert self.drop_last is True, ( + "drop_last must be True for ChunkDistributedSampler" + ) + indices_t = indices_t + (self.num_samples * int(self.rank)) + assert indices_t.numel() == self.num_samples + # Iterate via the numpy view, NOT directly over the tensor: iter(Tensor) + # calls Tensor.unbind(0) which eagerly materializes one zero-dim Tensor + # object per element (~600 B each). For 40 M eval / 525 M train samples + # that's 24 GB / 315 GB of [heap] growth per rank, blowing host RAM + # before the first batch. numpy's iter yields one Python int at a time + # with O(1) extra memory. + indices_np = indices_t.numpy() + return (int(x) for x in indices_np) + + def set_epoch(self, epoch: int) -> None: + logger.warning(f"Setting epoch to {epoch}") + self.epoch = epoch + + +@gin.configurable +def make_model( + dataset: str, + bf16_training: bool = False, + hammer_kernel: Optional[str] = None, +) -> Tuple[torch.nn.Module, DlrmHSTUConfig, Dict[str, EmbeddingConfig]]: + hstu_config = get_hstu_configs(dataset) + table_config = get_embedding_table_config(dataset) + + # bf16 autocast is off by default: on the PYTORCH attn backend the + # pt_hstu_attention QK einsum backward overflows in bf16 at long + # sequences (NaN at step 1 when N>1k). Safe with TRITON; flip via + # `make_model.bf16_training = True` in the gin. + model = DlrmHSTU( + hstu_configs=hstu_config, + embedding_tables=table_config, + is_inference=False, + bf16_training=bf16_training, + ) + + # HSTU attention/compute kernel backend. Precedence: + # HSTU_HAMMER_KERNEL env var > make_model.hammer_kernel gin > model default. + # The env var stays as an ad-hoc override (e.g. forcing PYTORCH for a one-off + # debug run) without editing the gin. Note: the fused TRITON path avoids + # materializing the dense [B, H, N, N] attention-score tensor that the PYTORCH + # path allocates (~32 GiB at N=2048, bs=1024), so TRITON is both faster and + # far lighter on HBM. On older ROCm, TRITON could hit PassManager errors at + # some shapes (make_ttgir) — fall back to PYTORCH via the gin/env if so. + kernel_choice = ( + os.environ.get("HSTU_HAMMER_KERNEL", "").upper() + or (hammer_kernel.upper() if hammer_kernel else "") + ) + if kernel_choice: + model.set_hammer_kernel(HammerKernel[kernel_choice]) + logger.warning(f"HSTU hammer kernel set to: {kernel_choice}") + + return ( + model, + hstu_config, + table_config, + ) + + +@gin.configurable() +def dense_optimizer_factory_and_class( + optimizer_name: str, + betas: Tuple[float, float], + eps: float, + weight_decay: float, + momentum: float, + learning_rate: float, +) -> Tuple[ + Type[Optimizer], Dict[str, Any], Callable[[Iterable[torch.Tensor]], Optimizer] +]: + kwargs: Dict[str, Any] = {"lr": learning_rate} + if optimizer_name == "Adam": + optimizer_cls = torch.optim.Adam + kwargs.update({"betas": betas, "eps": eps, "weight_decay": weight_decay}) + elif optimizer_name == "SGD": + optimizer_cls = torch.optim.SGD + kwargs.update({"weight_decay": weight_decay, "momentum": momentum}) + elif optimizer_name == "AdamW": + optimizer_cls = torch.optim.AdamW + kwargs.update({"betas": betas, "eps": eps, "weight_decay": weight_decay}) + else: + raise Exception("Unsupported optimizer!") + + optimizer_factory = lambda params: optimizer_cls(params, **kwargs) + + logger.info( + f"[dense optimizer] {optimizer_name} learning_rate={learning_rate} " + f"(resolved from gin; override via $DENSE_LR)" + ) + return optimizer_cls, kwargs, optimizer_factory + + +@gin.configurable() +def sparse_optimizer_factory_and_class( + optimizer_name: str, + betas: Tuple[float, float], + eps: float, + weight_decay: float, + momentum: float, + learning_rate: float, +) -> Tuple[ + Type[Optimizer], Dict[str, Any], Callable[[Iterable[torch.Tensor]], Optimizer] +]: + kwargs: Dict[str, Any] = {"lr": learning_rate} + if optimizer_name == "Adam": + optimizer_cls = torch.optim.Adam + beta1, beta2 = betas + kwargs.update( + {"beta1": beta1, "beta2": beta2, "eps": eps, "weight_decay": weight_decay} + ) + elif optimizer_name == "SGD": + optimizer_cls = torchrec.optim.SGD + kwargs.update({"weight_decay": weight_decay, "momentum": momentum}) + elif optimizer_name == "RowWiseAdagrad": + optimizer_cls = torchrec.optim.RowWiseAdagrad + beta1, beta2 = betas + kwargs.update( + { + "eps": eps, + "beta1": beta1, + "beta2": beta2, + "weight_decay": weight_decay, + } + ) + else: + raise Exception("Unsupported optimizer!") + + optimizer_factory = lambda params: optimizer_cls(params, **kwargs) + + logger.info( + f"[sparse optimizer] {optimizer_name} learning_rate={learning_rate} " + f"(resolved from gin; override via $SPARSE_LR)" + ) + return optimizer_cls, kwargs, optimizer_factory + + +_FBGEMM_LOWMEM_PATCHED = False + + +def _patch_fbgemm_lowmem_clamp_cast(enabled: bool = True, rank0: bool = False) -> None: + """Replace fbgemm's quant clamp+cast with a memory-frugal equivalent. + + ``enabled`` is the gin/env-driven kill switch (see + ``make_optimizer_and_shard.qcomm_lowmem_clamp_cast`` / + ``$QCOMM_LOWMEM_CODEC``). Default ON; pass ``enabled=False`` to fall back to + stock fbgemm (e.g. to reproduce the pre-patch OOM, or if a future fbgemm + version changes the codec and the patch needs revalidation). + + fbgemm's ``fp32_to_fp16_with_clamp`` (and the bf16 variant) does + ``torch.clamp(tensor, MIN, MAX).half()``. ``torch.clamp(...)`` allocates a + SECOND full-size fp32 tensor (same numel as the input) *before* the cast, so + the transient peak is input(fp32) + clamp_temp(fp32) + output(fp16) ~= 2.5x + the input. On a skewed row-wise-sharded batch the hottest shard's embedding + tensor is huge (observed 81.5 GiB clamp temp), and that extra fp32 copy is + exactly the allocation that OOMs the rank — which then exits the train loop + while peers block forever in the a2a (a 30-min NCCL-watchdog hang). See + HANG_ROOTCAUSE.md / flight-recorder dump for the full diagnosis. + + Cast FIRST then clamp IN PLACE: ``tensor.half().clamp_(MIN, MAX)``. This + allocates only the fp16 output (no full-size fp32 temp), cutting the peak by + the size of the input tensor, while being numerically identical: an fp32 + value above HALF_MAX casts to +inf, which clamp_ maps back to HALF_MAX (and + NaNs pass through unchanged), matching clamp-then-cast bit for bit. Safe to + do in place because the codec ``encode()`` runs inside the qcomm autograd + ``Function.forward`` (grad disabled), so there is no graph to corrupt. + """ + global _FBGEMM_LOWMEM_PATCHED + if not enabled: + if rank0: + logger.warning( + "[qcomm-lowmem] DISABLED (qcomm_lowmem_clamp_cast=False / " + "QCOMM_LOWMEM_CODEC=0) — running stock fbgemm clamp+cast, which " + "allocates a full-size fp32 clamp temp and can OOM->hang the " + "hottest row-wise embedding shard on skewed batches." + ) + return + if _FBGEMM_LOWMEM_PATCHED: + return + try: + from fbgemm_gpu import quantize_comm, quantize_utils + + _HMIN = quantize_utils.TORCH_HALF_MIN + _HMAX = quantize_utils.TORCH_HALF_MAX + _BMIN = quantize_utils.TORCH_BFLOAT16_MIN + _BMAX = quantize_utils.TORCH_BFLOAT16_MAX + + def _lowmem_fp16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.half().clamp_(_HMIN, _HMAX) + + def _lowmem_bf16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.bfloat16().clamp_(_BMIN, _BMAX) + + # Patch BOTH the definition module and quantize_comm, which imported the + # names directly (``from .quantize_utils import fp32_to_fp16_with_clamp``) + # so its module-level reference must be overridden too. + for _mod in (quantize_utils, quantize_comm): + if hasattr(_mod, "fp32_to_fp16_with_clamp"): + _mod.fp32_to_fp16_with_clamp = _lowmem_fp16 + if hasattr(_mod, "fp32_to_bf16_with_clamp"): + _mod.fp32_to_bf16_with_clamp = _lowmem_bf16 + + _FBGEMM_LOWMEM_PATCHED = True + if rank0: + logger.info( + "[qcomm-lowmem] patched fbgemm fp32->fp16/bf16 clamp+cast to " + "cast-then-clamp_ (drops the full-size fp32 clamp temp; avoids " + "OOM on skewed row-wise embedding a2a)" + ) + except Exception as e: # noqa: BLE001 — patch is best-effort, never fatal + if rank0: + logger.warning( + "[qcomm-lowmem] could not patch fbgemm clamp+cast (%s: %s); " + "running with stock (higher-peak) quantizer", + type(e).__name__, + e, + ) + + +def _maybe_apply_qcomm_a2a( + sharders: List[Any], + device: torch.device, + forward_precision: str = "fp32", + backward_precision: str = "fp32", + lowmem_clamp_cast: bool = True, +) -> List[Any]: + """Optionally quantize the embedding all-to-all payload via TorchRec qcomm. + + The yambda-5b embedding shuffle is the dominant, bandwidth-bound (multi-node) + collective (~14.5 GB/rank fp32); a bf16/fp16 wire dtype halves it. Quant/ + dequant happen inside the comm op, transparent to the lookup consumer. Ported + from the DLRMv2 R2 lever, retargeted from ``EmbeddingBagCollectionSharder`` to + the sequence ``EmbeddingCollectionSharder`` this model uses. + + Forward and backward are configured independently because they have different + numerical needs (TorchRec golden_training/train_dlrm.py recommends + forward=fp16, backward=bf16): the forward carries bounded embedding + activations where fp16's extra mantissa helps, while gradients have a wider + range that can overflow fp16, so bf16 (fp32 exponent range) is safer there. + bf16 and fp16 are both 2 bytes, so the wire volume / perf is identical — the + choice is purely numerical. + + Args (set via gin on ``make_optimizer_and_shard``, env-overridable). Each is + one of ``fp32`` (that direction unquantized) | ``bf16`` | ``fp16``. If BOTH + are fp32 the sharders are returned untouched (identical to baseline trunk). + """ + _COMM = {"bf16": "BF16", "fp16": "FP16", "fp32": "FP32"} + fwd = (forward_precision or "fp32").strip().lower() + bwd = (backward_precision or "fp32").strip().lower() + rank0 = (not dist.is_initialized()) or dist.get_rank() == 0 + for name, p in (("forward", fwd), ("backward", bwd)): + if p not in _COMM: + # Misconfigured precision: fail loudly rather than silently running + # fp32. A typo in SPARSE_A2A_{FWD,BWD} must not pass as "no quant". + raise ValueError( + f"DLRMV4 qcomm a2a: unknown {name} precision {p!r} " + f"(want one of fp32|bf16|fp16)" + ) + if fwd == "fp32" and bwd == "fp32": + return sharders + # Before building the codec, swap fbgemm's clamp+cast for a memory-frugal + # equivalent — see _patch_fbgemm_lowmem_clamp_cast for why (avoids a full + # extra fp32 temp that OOMs the hottest row-wise shard on skewed batches). + # Gated by `lowmem_clamp_cast` (gin/env); ON by default. + _patch_fbgemm_lowmem_clamp_cast(enabled=lowmem_clamp_cast, rank0=rank0) + try: + from torchrec.distributed.embedding import EmbeddingCollectionSharder + from torchrec.distributed.fbgemm_qcomm_codec import ( + CommType, + get_qcomm_codecs_registry, + QCommsConfig, + ) + + qcfg = QCommsConfig( + forward_precision=getattr(CommType, _COMM[fwd]), + backward_precision=getattr(CommType, _COMM[bwd]), + ) + registry = get_qcomm_codecs_registry(qcfg, device=device) + except Exception as e: # noqa: BLE001 + # A configured quantized a2a that fails to build is a hard error. Silently + # downgrading to fp32 would change numerics/throughput with no signal, and + # a partial failure (one rank fp32, others fp16) would also desync the + # collectives. Raise on every rank so the whole job aborts consistently. + raise RuntimeError( + f"DLRMV4 qcomm a2a: failed to enable configured quantization " + f"(forward={fwd} backward={bwd}): {type(e).__name__}: {e}" + ) from e + + new_sharders = [] + replaced = False + for s in sharders: + if type(s).__name__ == "EmbeddingCollectionSharder" and not replaced: + new_sharders.append( + EmbeddingCollectionSharder(qcomm_codecs_registry=registry) + ) + replaced = True + else: + new_sharders.append(s) + if not replaced: + # Codec registry built fine, but there was no EmbeddingCollectionSharder to + # bind it to, so the quantized a2a would be silently inert. Treat this as a + # hard failure too — "configured but not applied" is the bug we want caught. + raise RuntimeError( + f"DLRMV4 qcomm a2a: quantization configured (forward={fwd} " + f"backward={bwd}) but no EmbeddingCollectionSharder was found to attach " + f"the qcomm codec registry to; refusing to run with quantization " + f"silently disabled" + ) + if rank0: + logger.info( + "DLRMV4 qcomm a2a ENABLED: forward=%s backward=%s " + "replaced_ec_sharder=%s", + fwd, + bwd, + replaced, + ) + return new_sharders + + +def _embedding_table_names( + model: torch.nn.Module, + embedding_table_configs: Optional[Dict[str, EmbeddingConfig]], +) -> List[str]: + """All embedding table names the planner will place. + + Prefers the authoritative `embedding_table_configs` (keys == table names == + planner parameter names). Falls back to walking the model's EBC/EC modules + when the configs are not passed in. + """ + if embedding_table_configs: + return list(embedding_table_configs.keys()) + names: List[str] = [] + for _, module in model.named_modules(): + if type(module) in TORCHREC_TYPES: + if isinstance(module, EmbeddingBagCollection): + names.extend(c.name for c in module.embedding_bag_configs()) + elif isinstance(module, EmbeddingCollection): + names.extend(c.name for c in module.embedding_configs()) + return names + + +def _build_placement_constraints( + model: torch.nn.Module, + embedding_placement: str, + embedding_placement_overrides: Dict[str, str], + embedding_table_configs: Optional[Dict[str, EmbeddingConfig]], + embedding_sharding_overrides: Optional[Dict[str, str]] = None, +) -> Dict[str, ParameterConstraints]: + """Translate gin/env placement + sharding strings into ParameterConstraints. + + Two orthogonal per-table knobs are merged into one constraint per table: + + * Placement (compute kernel / memory tier): + ``embedding_placement_overrides.get(name, embedding_placement)``. + ``auto``/empty -> no compute-kernel constraint (planner decides from HBM). + * Sharding type (shard layout): ``embedding_sharding_overrides.get(name)``. + ``auto``/empty (or absent) -> no sharding-type constraint (planner decides, + which is ROW_WISE for the large yambda tables). Use e.g. ``column_wise`` + to move a hot, data-skewed table off ROW_WISE so its embedding all-to-all + is balanced by rank instead of routed by (hot) row. + + A table is added to the returned dict only if at least one knob is set for it + (so with everything ``auto`` we return {} and the plan is byte-identical to + the legacy path). Unknown values raise ValueError. + """ + embedding_sharding_overrides = embedding_sharding_overrides or {} + valid_place = set(_PLACEMENT_TO_KERNEL) | {"auto", ""} + for where, val in [ + ("embedding_placement", embedding_placement), + *[ + (f"embedding_placement_overrides[{k}]", v) + for k, v in embedding_placement_overrides.items() + ], + ]: + if val not in valid_place: + raise ValueError( + f"Invalid embedding placement {val!r} for {where}; " + f"expected one of {sorted(valid_place - {''})}." + ) + valid_shard = set(_SHARDING_TO_TYPE) | {"auto", ""} + for k, v in embedding_sharding_overrides.items(): + if v not in valid_shard: + raise ValueError( + f"Invalid embedding sharding {v!r} for " + f"embedding_sharding_overrides[{k}]; " + f"expected one of {sorted(valid_shard - {''})}." + ) + + names = _embedding_table_names(model, embedding_table_configs) + unknown = ( + set(embedding_placement_overrides) | set(embedding_sharding_overrides) + ) - set(names) + if unknown: + logger.warning( + "[emb-placement] override(s) for unknown table(s) %s ignored; " + "known tables: %s", + sorted(unknown), + sorted(names), + ) + + constraints: Dict[str, ParameterConstraints] = {} + resolved_place: Dict[str, str] = {} + resolved_shard: Dict[str, str] = {} + for name in names: + placement = embedding_placement_overrides.get(name, embedding_placement) + sharding = embedding_sharding_overrides.get(name, "auto") + resolved_place[name] = placement or "auto" + resolved_shard[name] = sharding or "auto" + kernel = _PLACEMENT_TO_KERNEL.get(placement) + stype = _SHARDING_TO_TYPE.get(sharding) + kwargs: Dict[str, Any] = {} + if kernel is not None: + kwargs["compute_kernels"] = [kernel.value] + if stype is not None: + kwargs["sharding_types"] = [stype.value] + if kwargs: + constraints[name] = ParameterConstraints(**kwargs) + + rank = dist.get_rank() if dist.is_initialized() else 0 + if rank == 0: + logger.info( + "[emb-placement] placement(global=%r overrides=%s) sharding(overrides=%s) " + "-> resolved_placement=%s resolved_sharding=%s " + "(constrained=%d/%d tables; the rest are planner-auto)", + embedding_placement, + embedding_placement_overrides or {}, + embedding_sharding_overrides or {}, + resolved_place, + resolved_shard, + len(constraints), + len(names), + ) + return constraints + + +@gin.configurable +def make_optimizer_and_shard( + model: torch.nn.Module, + device: torch.device, + world_size: int, + local_world_size: Optional[int] = None, + hbm_cap_gb: int = 260, + sparse_a2a_forward_precision: str = "fp32", + sparse_a2a_backward_precision: str = "fp32", + qcomm_lowmem_clamp_cast: bool = True, + embedding_placement: str = "auto", + embedding_placement_overrides: Optional[Dict[str, str]] = None, + embedding_sharding_overrides: Optional[Dict[str, str]] = None, + embedding_table_configs: Optional[Dict[str, EmbeddingConfig]] = None, +) -> Tuple[DistributedModelParallel, torch.optim.Optimizer]: + dense_opt_cls, dense_opt_args, dense_opt_factory = ( + dense_optimizer_factory_and_class() + ) + + sparse_opt_cls, sparse_opt_args, sparse_opt_factory = ( + sparse_optimizer_factory_and_class() + ) + # Fuse sparse optimizer to backward step + for k, module in model.named_modules(): + if type(module) in TORCHREC_TYPES: + for _, param in module.named_parameters(prefix=k): + if param.requires_grad: + apply_optimizer_in_backward( + sparse_opt_cls, [param], sparse_opt_args + ) + sharders = get_default_sharders() + sharders = _maybe_apply_qcomm_a2a( + sharders, + device, + forward_precision=sparse_a2a_forward_precision, + backward_precision=sparse_a2a_backward_precision, + lowmem_clamp_cast=qcomm_lowmem_clamp_cast, + ) + # local_world_size = GPUs per node so the planner respects the intra-node + # (xGMI/NVLink) vs inter-node hierarchy when placing shards. Defaults to + # world_size for the single-node case (no behavior change). + logger.info( + "[hbm-cap] make_optimizer_and_shard: hbm_cap_gb=%s (planner Topology hbm_cap=%d bytes), " + "world_size=%s local_world_size=%s", + hbm_cap_gb, + hbm_cap_gb * 1024 * 1024 * 1024, + world_size, + local_world_size or world_size, + ) + # Resolve per-table embedding placement (gin/env-driven). Global default + # `embedding_placement` applies to every table; `embedding_placement_overrides` + # (table name -> placement) wins per table. Tables resolving to "auto" carry + # no constraint (planner decides from hbm_cap). When nothing is constrained we + # pass constraints=None so the plan is byte-identical to the legacy path. + constraints = _build_placement_constraints( + model=model, + embedding_placement=embedding_placement, + embedding_placement_overrides=embedding_placement_overrides or {}, + embedding_sharding_overrides=embedding_sharding_overrides or {}, + embedding_table_configs=embedding_table_configs, + ) + planner = EmbeddingShardingPlanner( + topology=Topology( + local_world_size=local_world_size or world_size, + world_size=world_size, + compute_device="cuda", + hbm_cap=hbm_cap_gb * 1024 * 1024 * 1024, + ddr_cap=0, + ), + constraints=constraints or None, + ) + pg = dist.GroupMember.WORLD + env = ShardingEnv.from_process_group(pg) # pyre-ignore [6] + pg = env.process_group + + plan = planner.collective_plan(model, sharders, pg) + + # Authoritative placement log: report the compute kernel the planner ACTUALLY + # assigned to each table (vs the [emb-placement] line above, which reports what + # was requested). "fused" = HBM; "fused_uvm"/"fused_uvm_caching" = UVM-backed. + # Rank 0 only, best-effort (never break the build over a logging shape change). + if (dist.get_rank() if dist.is_initialized() else 0) == 0: + try: + for module_path, param_plans in plan.plan.items(): + for param_name, ps in param_plans.items(): + logger.info( + "[emb-placement] plan: %s.%s -> compute_kernel=%s " + "sharding_type=%s", + module_path, + param_name, + getattr(ps, "compute_kernel", "?"), + getattr(ps, "sharding_type", "?"), + ) + except Exception as e: # logging only; must never fail the build + logger.warning("[emb-placement] could not dump plan kernels: %s", e) + + # Re-seed right before DMP materializes/inits the sharded embedding tables. + # The per-table seeded init_fn (configs.get_embedding_table_config) handles + # the eager path, but the fused FBGEMM TBE path inits weights on-device and + # may bypass init_fn, drawing from the global RNG instead. Re-seeding here + # (same value on every rank) makes embedding init reproducible run-to-run for + # a fixed sharding plan (Tier 1). Dense params are already initialized in + # make_model, so this does not perturb them. + _emb_seed = int(os.environ.get("SEED", "1")) + torch.manual_seed(_emb_seed) + torch.cuda.manual_seed_all(_emb_seed) + logger.info(f"[emb-init] re-seeded RNGs before DMP with SEED={_emb_seed}") + + # Shard model + model = DistributedModelParallel( + module=model, + device=device, + plan=plan, + sharders=sharders, + ) + + # --- startup init checksum (reproducibility probe) ------------------------- + # Right after DMP materializes real weights, log a deterministic fingerprint + # of every parameter so two builds with the same $SEED + sharding plan can be + # diffed for byte-level init reproducibility. For sharded embeddings we + # all-reduce the per-shard (count, sum, sumsq) so the fingerprint covers the + # WHOLE table independent of how rows split across ranks; replicated dense + # params use rank 0's local copy. + # OFF BY DEFAULT: the fp64 reductions below (.sum(dtype=float64) / + # vector_norm(dtype=float64)) materialize a full fp64 copy of each local + # embedding shard (~2x the fp32 shard, i.e. >150 GiB for the big tables), + # which leaves almost no HBM headroom after sharding and will OOM the build on + # any node with residual memory. Only enable for explicit reproducibility + # checks, ideally with a smaller batch / on a clean node. Enable with + # INIT_CHECKSUM=1. + if os.environ.get("INIT_CHECKSUM", "0") == "1": + import hashlib + + _rank = dist.get_rank() if dist.is_initialized() else 0 + _fps: List[str] = [] + for _name, _p in sorted(model.named_parameters(), key=lambda kv: kv[0]): + _sharded = isinstance(_p, ShardedTensor) + if _sharded: + _shards = _p.local_shards() + _loc = _shards[0].tensor if _shards else None + else: + _loc = _p + if _loc is None or _loc.numel() == 0: + _cnt, _sm, _sq = 0.0, 0.0, 0.0 + else: + _det = _loc.detach() + _cnt = float(_det.numel()) + _sm = _det.sum(dtype=torch.float64).item() + _nrm = torch.linalg.vector_norm( + _det, ord=2, dtype=torch.float64 + ).item() + _sq = _nrm * _nrm + if _sharded and dist.is_initialized(): + _stat = torch.tensor( + [_cnt, _sm, _sq], dtype=torch.float64, device=device + ) + dist.all_reduce(_stat, op=dist.ReduceOp.SUM) + _cnt, _sm, _sq = _stat.tolist() + _fps.append(f"{_name}|{int(_cnt)}|{_sm:.6f}|{_sq:.6f}") + if _rank == 0: + logger.info( + f"[init-checksum] {'sharded' if _sharded else 'dense'} " + f"{_name} n={int(_cnt)} sum={_sm:.6f} sumsq={_sq:.6f}" + ) + if _rank == 0: + _digest = hashlib.sha256("\n".join(_fps).encode()).hexdigest()[:16] + logger.info( + f"[init-checksum] SEED={os.environ.get('SEED', '?')} " + f"params={len(_fps)} digest={_digest}" + ) + + # Create keyed optimizer + all_optimizers = [] + all_params = {} + non_fused_sparse_params = {} + for k, v in in_backward_optimizer_filter(model.named_parameters()): + if v.requires_grad: + if isinstance(v, ShardedTensor): + non_fused_sparse_params[k] = v + else: + all_params[k] = v + + if non_fused_sparse_params: + all_optimizers.append( + ( + "sparse_non_fused", + KeyedOptimizerWrapper( + params=non_fused_sparse_params, optim_factory=sparse_opt_factory + ), + ) + ) + + if all_params: + all_optimizers.append( + ( + "dense", + KeyedOptimizerWrapper( + params=all_params, + optim_factory=dense_opt_factory, + ), + ) + ) + output_optimizer = CombinedOptimizer(all_optimizers) + output_optimizer.init_state(set(model.sparse_grad_parameter_names())) + return model, output_optimizer + + +@gin.configurable +def make_streaming_dataloader( + dataset: HammerToTorchDataset, + ts: Optional[int] = None, + batch_size: int = 0, + num_workers: int = 0, + prefetch_factor: int = 0, + train_only: bool = False, + indices: Optional["np.ndarray"] = None, +) -> DataLoader: + # `indices` (explicit anchor index array) is used by the eval path to + # iterate the FIXED user-holdout set, which spans a window range rather than + # a single ts. Otherwise restrict to window `ts`; train_only=True drops + # held-out eval users so the non-persistent TRAIN loader never trains on + # them (no-leakage guarantee). + if indices is not None: + dataset.dataset.set_active_indices(indices) # pyre-ignore [16] + else: + assert ts is not None, "make_streaming_dataloader needs ts or indices" + dataset.dataset.set_ts(ts, train_only=train_only) # pyre-ignore [16] + total_items = dataset.dataset.get_item_count() + subset = torch.utils.data.Subset(dataset, range(total_items)) + # shuffle=False keeps temporal order within the window: a non-shuffling + # DistributedSampler hands rank r the strided slice indices[r::num_replicas] + # (round-robin), so all ranks stay on the same time front and consume the + # window in index order. Fork ctx mirrors the train path (COW-share the + # mmap'd store instead of pickling it into every worker). + mp_ctx = "fork" if num_workers and num_workers > 0 else None + dataloader = DataLoader( + dataset=subset, + batch_size=batch_size, + shuffle=False, + collate_fn=collate_fn, + drop_last=True, + num_workers=num_workers, + prefetch_factor=prefetch_factor, + sampler=DistributedSampler(subset, shuffle=False, drop_last=True), + multiprocessing_context=mp_ctx, + ) + return dataloader + + +class StreamingWindowSampler(torch.utils.data.Sampler): + """Per-rank sampler whose index list is swapped each window. + + Yields this rank's round-robin slice of the active window's GLOBAL anchor + indices (into the dataset's ``_positions``). Because indices are global, a + single DataLoader with ``persistent_workers=True`` can be reused across all + windows: the main process re-iterates this sampler each window and ships the + new indices to the already-forked workers, which map any global index via + the shared mmap. No per-window worker respawn / dataset re-pickle. + + Round-robin striding (rank r gets ``indices[r::world_size]``) over the + time-sorted window keeps every rank on the same time front; the window is + truncated to a multiple of ``world_size`` so all ranks get equal counts + (required for DDP collective lockstep). + """ + + def __init__(self, rank: int, world_size: int) -> None: + self._rank: int = rank + self._world_size: int = world_size + self._indices: List[int] = [] + + def set_window(self, global_indices, skip_samples: int = 0) -> None: + """Install this window's per-rank index list, optionally fast-forwarding. + + ``skip_samples`` drops the first N per-rank samples from the list so the + next ``__iter__`` starts at sample N+1 in this rank's slice. Used on + resume to skip batches that were already trained: pass + ``skip_samples = batch_size * batches_completed`` and the dataloader + emits batches starting at exactly the next unseen batch. + + The skip is safe because the sample order is fully deterministic given + (global_indices, rank, world_size): we re-derive the same per-rank list + as the pre-crash run, just hand back a tail slice of it. + """ + n = (len(global_indices) // self._world_size) * self._world_size + per_rank = global_indices[:n][self._rank :: self._world_size].tolist() + if skip_samples < 0 or skip_samples > len(per_rank): + raise ValueError( + f"skip_samples={skip_samples} out of [0, {len(per_rank)}] " + f"for rank={self._rank} world_size={self._world_size}" + ) + self._indices = per_rank[skip_samples:] + + def __iter__(self): + return iter(self._indices) + + def __len__(self) -> int: + return len(self._indices) + + +@gin.configurable +def make_persistent_streaming_dataloader( + dataset: HammerToTorchDataset, + sampler: StreamingWindowSampler, + batch_size: int, + num_workers: int, + prefetch_factor: int, +) -> DataLoader: + """One reusable DataLoader for the whole streaming run. ``sampler`` is + mutated per window via ``set_window``; workers persist across windows.""" + use_workers = bool(num_workers and num_workers > 0) + return DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=False, + collate_fn=collate_fn, + drop_last=True, + num_workers=num_workers, + prefetch_factor=prefetch_factor if use_workers else None, + sampler=sampler, + persistent_workers=use_workers, + multiprocessing_context="fork" if use_workers else None, + ) + + +class _PrefetchingWindowLoader: + """Double-buffered window loader for the persistent streaming path. + + Holds ``n_buffers`` pre-forked persistent worker pools that ping-pong: while + the current window trains on one pool, the *next* window's index selection + (``window_indices``) and first-batch prefetch are prepared on another pool + in a background thread. By the time training advances, that window is warm, + so the per-window reset (mask + first-batch stall) is hidden behind GPU + compute (~0 dead time at the boundary). + + Worker pools are forked once on the main thread at the start of ``stream``; + afterwards only iterator resets happen (no forks), so background-thread + preparation cannot fork while other threads hold locks. + """ + + def __init__( + self, + dataset: "HammerToTorchDataset", + sampler_factory, + dl_factory, + n_buffers: int = 2, + ) -> None: + self._dataset = dataset + self._n = n_buffers + self._samplers = [sampler_factory() for _ in range(n_buffers)] + self._dls = [dl_factory(s) for s in self._samplers] + self._iters: List[Optional[object]] = [None] * n_buffers + + def _prepare(self, buf: int, ts: int, skip_samples: int = 0) -> None: + # train_window_indices() is the O(N) mask (+ uid-hash filter for the + # holdout); numpy releases the GIL for it, so it overlaps the main + # thread's GPU dispatch. iter() then kicks off this pool's background + # prefetch. This is a TRAIN-only loader, so held-out eval users are + # excluded here. `skip_samples` is non-zero only for the very first + # window after a mid-window resume; subsequent windows always start at 0. + self._samplers[buf].set_window( + self._dataset.dataset.train_window_indices(ts), skip_samples=skip_samples + ) + self._iters[buf] = iter(self._dls[buf]) + + def stream(self, ts_list: List[int], first_skip_samples: int = 0): + """Stream (ts, iterator) pairs. `first_skip_samples` is applied ONLY to + the first ts in ``ts_list`` (the mid-window-resumed window); every + subsequent window starts at sample 0 of its own per-rank list.""" + n = len(ts_list) + if n == 0: + return + threads: List[Optional[threading.Thread]] = [None] * self._n + # Prime the first n_buffers windows on the main thread (forks all pools). + for b in range(min(self._n, n)): + skip = first_skip_samples if b == 0 else 0 + self._prepare(b, ts_list[b], skip_samples=skip) + for i in range(n): + buf = i % self._n + if threads[buf] is not None: + threads[buf].join() + threads[buf] = None + yield ts_list[i], self._iters[buf] + # This pool is now free; prefetch the window n_buffers ahead. + # No skip on subsequent windows — only the first prepared window + # carries `first_skip_samples`. + j = i + self._n + if j < n: + th = threading.Thread( + target=self._prepare, args=(buf, ts_list[j]), daemon=True + ) + th.start() + threads[buf] = th + + +@gin.configurable +def make_train_test_dataloaders( + batch_size: int, + dataset_type: str, + hstu_config: DlrmHSTUConfig, + train_split_percentage: float, + embedding_table_configs: Dict[str, EmbeddingConfig], + new_path_prefix: str = "", + num_workers: int = 0, + num_blocks: int = 1, + prefetch_factor: Optional[int] = None, + eval_batch_size: Optional[int] = None, +) -> Tuple[DataLoader, DataLoader]: + dataset_class, kwargs = get_dataset( + name=dataset_type, new_path_prefix=new_path_prefix + ) + kwargs["embedding_config"] = embedding_table_configs + + # Create dataset + dataset = HammerToTorchDataset( + dataset=dataset_class(hstu_config=hstu_config, is_inference=False, **kwargs) + ) + total_items = dataset.dataset.get_item_count() + items_per_block = total_items // num_blocks + train_size_per_block = round(train_split_percentage * items_per_block) + # Avoid `extend(range(...))` which materializes a Python list of all sample + # indices — at 3.2B yambda samples × 28 bytes/int ≈ 90 GB/rank just for + # train_inds. Subset accepts any sequence with O(1) __len__ and __getitem__, + # so pass range objects (or a tiny chained view) directly. + if num_blocks == 1: + train_inds = range(0, train_size_per_block) + test_inds = range(train_size_per_block, items_per_block) + else: + train_inds = _ChainedRanges([ + range(i * items_per_block, i * items_per_block + train_size_per_block) + for i in range(num_blocks) + ]) + test_inds = _ChainedRanges([ + range(i * items_per_block + train_size_per_block, (i + 1) * items_per_block) + for i in range(num_blocks) + ]) + train_set = torch.utils.data.Subset(dataset, train_inds) + test_set = torch.utils.data.Subset(dataset, test_inds) + + # When the parent rank is started via mp.start_processes(start_method="spawn"), + # torch.multiprocessing's default Process context is also "spawn". DataLoader + # then pickles `self._dataset` to send to each worker — which for our mmap'd + # 211 GB yambda store materializes the entire dataset into the parent's anon + # memory (~230 GB/rank). Forcing "fork" lets workers inherit the parent's + # mmap'd pages via COW with zero extra anon. + mp_ctx = "fork" if num_workers and num_workers > 0 else None + train_dataloader = DataLoader( + dataset=train_set, + batch_size=batch_size, + shuffle=False, + collate_fn=collate_fn, + drop_last=True, + num_workers=num_workers, + prefetch_factor=prefetch_factor, + sampler=ChunkDistributedSampler(train_set, drop_last=True, shuffle=True), + multiprocessing_context=mp_ctx, + ) + test_dataloader = DataLoader( + dataset=test_set, + batch_size=eval_batch_size if eval_batch_size is not None else batch_size, + shuffle=False, + collate_fn=collate_fn, + drop_last=True, + num_workers=num_workers, + prefetch_factor=prefetch_factor, + sampler=ChunkDistributedSampler(test_set, drop_last=True, shuffle=True), + multiprocessing_context=mp_ctx, + ) + return train_dataloader, test_dataloader + + +# THROWAWAY DIAG state: per-embedding-table lookup stats accumulated across a +# fixed window of steps (DIAG_EMB_STEPS, default 100) so the reported numbers are +# averaged/aggregated rather than a single noisy batch. Rank-0 only. +_EMB_DIAG_ACC: Dict[str, Dict[str, Any]] = {} +_EMB_DIAG_NBATCH: int = 0 +# Cap (per-batch lookups) below which we also track a TRUE global-unique set +# across the whole window (cheap for contextual/cross tables, total == batch +# size). Sequential tables (item/artist/album) blow past this and only get the +# per-batch averages. +_EMB_DIAG_GLOBAL_CAP: int = 1 << 17 # 131072 + + +def _log_unique_embedding_diag( + sample, rank: int, step: int, max_steps: int = 100, log_every: int = 50 +) -> None: + """Diagnostic: aggregate per-embedding-table lookup stats over a step window. + + Quantifies the user-major batching concern — when consecutive sliding-window + anchors come from the same few users, a batch reads very few UNIQUE embedding + rows (low unique/total), so embedding lookups are highly redundant. Covers the + base id tables AND the cross-feature tables (user_x_* / *_x_hour hashed + combos); the user_x_* tables should be the most redundant under shuffle OFF. + + Accumulates over ``max_steps`` batches and emits an aggregate summary (mean + per-batch unique%/hot%/top10%, plus a true global-unique over the whole + window for the small contextual/cross tables). Rank-0 only, non-fatal. + """ + if rank != 0: + return + global _EMB_DIAG_NBATCH + try: + from generative_recommenders.dlrm_v3.configs import YAMBDA_5B_CROSS_SPECS + + cross_caps = {name: n for (name, _k, n, _s) in YAMBDA_5B_CROSS_SPECS} + + def _table_of(key: str): + # cross tables match by exact name; resolve BEFORE substring fallbacks + # so e.g. 'user_x_artist' isn't misread as the artist_id table. + if key in cross_caps: + return key, cross_caps[key] + if key == "uid" or key.endswith("_uid") or key.endswith(".uid"): + return "uid", 0 + if "artist" in key: + return "artist_id", 0 + if "album" in key: + return "album_id", 0 + if "item" in key and key.endswith("id"): + return "item_id", 0 + return None, 0 + + for tag, kjt in ( + ("uih", sample.uih_features_kjt), + ("cand", sample.candidates_features_kjt), + ): + for key in kjt.keys(): + table, cap = _table_of(key) + if table is None: + continue + vals = kjt[key].values() + total = int(vals.numel()) + if total == 0: + continue + u, counts = torch.unique(vals, return_counts=True) + uniq = int(u.numel()) + hot1 = int(counts.max().item()) + k = min(10, uniq) + topk = int(torch.topk(counts, k).values.sum().item()) + + slot = _EMB_DIAG_ACC.setdefault( + f"{tag}.{key}", + { + "table": table, + "cap": cap, + "n": 0, + "tot": 0, + "uniq": 0, + "upct": 0.0, + "upct_min": 100.0, + "upct_max": 0.0, + "hot1pct": 0.0, + "topkpct": 0.0, + "glob": None, # running global-unique id tensor (small tables) + }, + ) + upct = 100.0 * uniq / total + slot["n"] += 1 + slot["tot"] += total + slot["uniq"] += uniq + slot["upct"] += upct + slot["upct_min"] = min(slot["upct_min"], upct) + slot["upct_max"] = max(slot["upct_max"], upct) + slot["hot1pct"] += 100.0 * hot1 / total + slot["topkpct"] += 100.0 * topk / total + if total <= _EMB_DIAG_GLOBAL_CAP: + prev = slot["glob"] + merged = u if prev is None else torch.cat([prev, u]) + slot["glob"] = torch.unique(merged) + + _EMB_DIAG_NBATCH += 1 + n = _EMB_DIAG_NBATCH + if n % log_every == 0 or n >= max_steps: + lines = [f"emb-diag AGGREGATE over {n} batches (step<= {step}):"] + for name in sorted(_EMB_DIAG_ACC): + s = _EMB_DIAG_ACC[name] + c = max(1, s["n"]) + cap_s = f" cap={s['cap']/1e6:.0f}M" if s["cap"] else "" + glob_s = "" + if s["glob"] is not None: + g = int(s["glob"].numel()) + glob_s = ( + f" | global_uniq={g} over {s['tot']} seen " + f"({s['tot']/max(1,g):.1f}x reuse)" + ) + lines.append( + f" {name}[{s['table']}]{cap_s}: " + f"avg_tot={s['tot']/c:.0f} " + f"avg_uniq%={s['upct']/c:.1f} " + f"(min={s['upct_min']:.1f} max={s['upct_max']:.1f}) " + f"avg_hot1%={s['hot1pct']/c:.1f} " + f"avg_top10%={s['topkpct']/c:.1f}" + f"{glob_s}" + ) + logger.info("\n".join(lines)) + except Exception as e: # diagnostic must never break training + logger.warning(f"emb-diag failed: {e}") + + +@gin.configurable +def train_loop( + rank: int, + model: torch.nn.Module, + dataloader: torch.utils.data.DataLoader, + optimizer: Optimizer, + metric_logger: MetricsLogger, + device: torch.device, + num_epochs: int, + num_batches: Optional[int] = None, + output_trace: bool = False, + metric_log_frequency: int = 1, + checkpoint_frequency: int = 100, + start_batch_idx: int = 0, + streaming_diag_unique_emb: bool = False, + # lr_scheduler: to-do: Add a scheduler +) -> None: + model.train() + batch_idx: int = start_batch_idx + profiler = Profiler(rank) if output_trace else None + + for epoch in range(num_epochs): + dataloader.sampler.set_epoch(epoch) # pyre-ignore [16] + for sample in dataloader: + if streaming_diag_unique_emb and batch_idx < int( + os.environ.get("DIAG_EMB_STEPS", "100") + ): + _log_unique_embedding_diag( + sample, + rank, + batch_idx, + max_steps=int(os.environ.get("DIAG_EMB_STEPS", "100")), + log_every=metric_log_frequency, + ) + optimizer.zero_grad() + sample.to(device) + ( + _, + _, + aux_losses, + mt_target_preds, + mt_target_labels, + mt_target_weights, + ) = model.forward( + sample.uih_features_kjt, + sample.candidates_features_kjt, + ) + # pyre-ignore + sum(aux_losses.values()).backward() + optimizer.step() + metric_logger.update( + mode="train", + predictions=mt_target_preds, + labels=mt_target_labels, + weights=mt_target_weights, + num_candidates=sample.candidates_features_kjt.lengths().view( + len(sample.candidates_features_kjt.keys()), -1 + )[0], + ) + if batch_idx % metric_log_frequency != 0: + metric_logger.compute_and_log( + mode="train", + additional_logs={ + "losses": aux_losses, + }, + ) + if batch_idx % checkpoint_frequency == 0 and batch_idx > 0: + save_dmp_checkpoint( + model=model, + optimizer=optimizer, + metric_logger=metric_logger, + rank=rank, + batch_idx=batch_idx, + ) + batch_idx += 1 + if output_trace: + assert profiler is not None + profiler.step() + if num_batches is not None and batch_idx >= num_batches: + break + if num_batches is not None and batch_idx >= num_batches: + break + + +@gin.configurable +def eval_loop( + rank: int, + model: torch.nn.Module, + dataloader: torch.utils.data.DataLoader, + metric_logger: MetricsLogger, + device: torch.device, + metric_log_frequency: int = 1, + num_batches: Optional[int] = None, + output_trace: bool = False, + # lr_scheduler: to-do: Add a scheduler +) -> None: + model.eval() + # Exclude eval wall-time from the train step-time window (see _run_eval_window). + metric_logger.pause_perf("eval") + batch_idx: int = 0 + profiler = Profiler(rank) if output_trace else None + metric_logger.reset(mode="eval") + with torch.no_grad(): + for sample in dataloader: + sample.to(device) + ( + _, + _, + _, + mt_target_preds, + mt_target_labels, + mt_target_weights, + ) = model.forward( + sample.uih_features_kjt, + sample.candidates_features_kjt, + ) + metric_logger.update( + mode="eval", + predictions=mt_target_preds, + labels=mt_target_labels, + weights=mt_target_weights, + num_candidates=sample.candidates_features_kjt.lengths().view( + len(sample.candidates_features_kjt.keys()), -1 + )[0], + ) + if batch_idx % metric_log_frequency != 0: + metric_logger.compute_and_log(mode="eval") + batch_idx += 1 + if output_trace: + assert profiler is not None + profiler.step() + if num_batches is not None and batch_idx >= num_batches: + break + metric_logger.compute_and_log(mode="eval") + for k, v in metric_logger.compute(mode="eval").items(): + print(f"{k}: {v}") + metric_logger.resume_perf("eval") + + +class _PipelineModelWrapper(torch.nn.Module): + """Adapt ``DlrmHSTU.forward`` to the ``(loss, output)`` contract that + ``TrainPipelineSparseDist`` expects. + + The wrapped ``model`` is the same DMP instance handed to the pipeline as + ``model=``; the pipeline rewrites its sharded ``EmbeddingCollection`` in + place, so calling it here is what lets the embedding all-to-all overlap the + dense forward/backward compute. + """ + + def __init__(self, model: torch.nn.Module) -> None: + super().__init__() + self._model = model + + def forward( + self, batch: Samples + ) -> Tuple[torch.Tensor, Tuple[Any, ...]]: + # The model runs in `_pipeline_mode`: it takes the whole batch as its + # single arg and reads the pre-merged sparse KJT off it. This keeps the + # EmbeddingCollection input a plain getattr on the batch placeholder so + # TorchRec pipelines its input_dist (instead of skipping it for "input + # modifications"). + ( + _, + _, + aux_losses, + mt_target_preds, + mt_target_labels, + mt_target_weights, + ) = self._model(batch) + loss = sum(aux_losses.values()) + num_candidates = batch.candidates_features_kjt.lengths().view( + len(batch.candidates_features_kjt.keys()), -1 + )[0] + output = ( + aux_losses, + mt_target_preds, + mt_target_labels, + mt_target_weights, + num_candidates, + ) + return loss, output + + +def build_train_pipeline( + model: torch.nn.Module, + optimizer: Optimizer, + device: torch.device, + grad_clip_norm: float = 1.0, +) -> Any: + """Build a ``TrainPipelineSparseDist`` for the DMP-wrapped HSTU model. + + The 3-stage pipeline overlaps (1) H2D transfer of batch N+2, (2) the sparse + data-dist all-to-all of batch N+1's embedding lookup, and (3) dense fwd/bwd + of batch N, on separate CUDA streams. Requires the model to be wrapped with + ``DistributedModelParallel`` (see ``make_optimizer_and_shard``). + """ + # Lazy import: keeps module import working on torchrec builds that move or + # rename the pipeline, and matches the reference Primus-DLRM setup. + from torchrec.distributed.train_pipeline import TrainPipelineSparseDist + + # Switch the (DMP-wrapped) HSTU model into pipeline mode so both the fx trace + # and the live forward consume the batch as a single arg and read the + # pre-merged sparse KJT off it — required for the embedding input_dist to be + # pipelined. Eval call sites pass the batch the same way (see train_eval_loop). + underlying = model.module if hasattr(model, "module") else model + underlying._pipeline_mode = True + + # The pipeline calls backward()+optimizer.step() internally inside + # progress(), leaving no in-loop hook point for gradient clipping. Clip via + # a full-backward hook (fires after autograd populates dense grads, before + # the optimizer step) to preserve parity with the sequential path's + # clip_grad_norm_(model.parameters(), max_norm=1.0). + if grad_clip_norm and grad_clip_norm > 0: + + def _clip_grads(_m: torch.nn.Module, _gi: Any, _go: Any) -> None: + torch.nn.utils.clip_grad_norm_( + model.parameters(), max_norm=grad_clip_norm + ) + + model.register_full_backward_hook(_clip_grads) + + return TrainPipelineSparseDist( + model=model, + optimizer=optimizer, + device=device, + execute_all_batches=True, + custom_model_fwd=_PipelineModelWrapper(model), + ) + + +@gin.configurable +def train_eval_loop( + rank: int, + model: torch.nn.Module, + optimizer: Optimizer, + metric_logger: MetricsLogger, + device: torch.device, + num_epochs: int, + num_train_batches: Optional[int] = None, + num_eval_batches: Optional[int] = None, + train_dataloader: Optional[torch.utils.data.DataLoader] = None, + eval_dataloader: Optional[torch.utils.data.DataLoader] = None, + output_trace: bool = False, + metric_log_frequency: int = 1, + checkpoint_frequency: int = 100, + eval_frequency: int = 1, + start_train_batch_idx: int = 0, + start_eval_batch_idx: int = 0, + use_pipeline: bool = False, + # lr_scheduler: to-do: Add a scheduler +) -> None: + train_batch_idx: int = start_train_batch_idx + eval_batch_idx: int = start_eval_batch_idx + profiler = Profiler(rank) if output_trace else None + assert train_dataloader is not None and eval_dataloader is not None + + eval_data_iterator = iter(eval_dataloader) + train_data_iterator = iter(train_dataloader) + + # 3-stage TorchRec pipeline (overlaps embedding a2a with dense compute). + # When enabled, progress() owns H2D copy, sparse-dist, fwd/bwd and the + # optimizer step; grad clipping moves to a full-backward hook (see builder). + train_pipeline = ( + build_train_pipeline(model, optimizer, device) if use_pipeline else None + ) + + for epoch in range(num_epochs): + train_dataloader.sampler.set_epoch(epoch) # pyre-ignore [16] + while True: + model.train() + if train_pipeline is not None: + try: + ( + aux_losses, + mt_target_preds, + mt_target_labels, + mt_target_weights, + num_candidates, + ) = train_pipeline.progress(train_data_iterator) + except StopIteration: + train_data_iterator = iter(train_dataloader) + break + else: + try: + sample = next(train_data_iterator) + except StopIteration: + train_data_iterator = iter(train_dataloader) + break + optimizer.zero_grad() + sample.to(device) + ( + _, + _, + aux_losses, + mt_target_preds, + mt_target_labels, + mt_target_weights, + ) = model.forward( + sample.uih_features_kjt, + sample.candidates_features_kjt, + ) + # pyre-ignore + sum(aux_losses.values()).backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + num_candidates = sample.candidates_features_kjt.lengths().view( + len(sample.candidates_features_kjt.keys()), -1 + )[0] + metric_logger.update( + mode="train", + predictions=mt_target_preds, + labels=mt_target_labels, + weights=mt_target_weights, + num_candidates=num_candidates, + ) + if train_batch_idx % metric_log_frequency == 0: + metric_logger.compute_and_log( + mode="train", + additional_logs={ + "losses": aux_losses, + }, + ) + if train_batch_idx % checkpoint_frequency == 0 and train_batch_idx > 0: + save_dmp_checkpoint( + model=model, + optimizer=optimizer, + metric_logger=metric_logger, + rank=rank, + batch_idx=train_batch_idx, + ) + train_batch_idx += 1 + if output_trace: + assert profiler is not None + profiler.step() + if train_batch_idx % eval_frequency == 0: + model.eval() + eval_batch_idx: int = 0 + with torch.no_grad(): + while True: + try: + sample = next(eval_data_iterator) + except StopIteration: + eval_data_iterator = iter(eval_dataloader) + sample = next(eval_data_iterator) + sample.to(device) + ( + _, + _, + _, + mt_target_preds, + mt_target_labels, + mt_target_weights, + ) = ( + # In pipeline mode the model takes the batch as one + # arg (see _PipelineModelWrapper / DlrmHSTU.forward). + model.forward(sample) + if use_pipeline + else model.forward( + sample.uih_features_kjt, + sample.candidates_features_kjt, + ) + ) + metric_logger.update( + mode="eval", + predictions=mt_target_preds, + labels=mt_target_labels, + weights=mt_target_weights, + num_candidates=sample.candidates_features_kjt.lengths().view( + len(sample.candidates_features_kjt.keys()), -1 + )[0], + ) + eval_batch_idx += 1 + if output_trace: + assert profiler is not None + profiler.step() + if eval_batch_idx % metric_log_frequency == 0: + metric_logger.compute_and_log(mode="eval") + if ( + num_eval_batches is not None + and eval_batch_idx >= num_eval_batches + ): + break + for k, v in metric_logger.compute(mode="eval").items(): + print(f"{k}: {v}") + model.train() + # `num_train_batches` cap: None or 0 = run the whole window. >0 caps + # batches per window (mostly the streaming-resume test driver uses + # this to keep test windows short). + if num_train_batches and train_batch_idx >= num_train_batches: + break + + +def select_in_window_checkpoint_reason( + *, + train_batch_idx: int, + global_step: int, + elapsed_since_last_save: float, + in_window_checkpoint_frequency: int, + checkpoint_step_frequency: int, + checkpoint_time_interval_s: float, +) -> Optional[str]: + """Decide which (if any) in-window checkpoint cadence fires this batch. + + Pure / distributed-agnostic so it can be unit-tested without a real run. + The caller computes `elapsed_since_last_save` (broadcast from rank 0 in the + streaming loop) so all ranks pass the same value and reach the same verdict. + + Precedence (at most one save per batch): per-window-local batch count > + monotonic global step > wall-clock interval. Returns the trigger reason + string, or None when no cadence fires. A cadence is disabled when its + frequency/interval is 0 / 0.0. + + Counter conventions match the loop: `train_batch_idx` is already + post-incremented (>=1 on the first batch), and `global_step` is guarded + >0 so step 0 doesn't trivially satisfy `% N == 0`. + """ + if ( + in_window_checkpoint_frequency > 0 + and train_batch_idx % in_window_checkpoint_frequency == 0 + ): + return "in_window_batch" + if ( + checkpoint_step_frequency > 0 + and global_step > 0 + and global_step % checkpoint_step_frequency == 0 + ): + return "global_step" + if ( + checkpoint_time_interval_s > 0 + and elapsed_since_last_save >= checkpoint_time_interval_s + ): + return "time_interval" + return None + + +def _validate_split_contract( + saved: Optional[Dict[str, Any]], + live: Dict[str, Any], + rank: int, +) -> None: + """Guarantee the train:eval split (and the inputs the resume skip-offset + depends on) are unchanged across a crash/resume. + + `saved` is the contract recovered from the checkpoint (None on cold start or + legacy pre-holdout checkpoints). Any mismatch is fatal: continuing would + either desync the mid-window skip (duplicate/skip batches) or reassign users + so that previously held-out eval users get trained (leakage). Set + ALLOW_SPLIT_MISMATCH=1 to override (e.g. intentionally resuming a legacy + checkpoint into a holdout run, accepting the risk). + """ + allow = os.environ.get("ALLOW_SPLIT_MISMATCH", "0") == "1" + if saved is None: + # Legacy / cold-start checkpoint with no recorded contract. Only a + # problem if this run actually holds users out (tsp < 1.0): we cannot + # prove the earlier run used the same split. + if live.get("train_split_percentage", 1.0) < 1.0 and not allow: + raise RuntimeError( + "Resuming a checkpoint with NO saved split contract into a " + f"user-holdout run (train_split_percentage=" + f"{live['train_split_percentage']}). The earlier run's split " + "cannot be verified, so held-out eval users may have been " + "trained. Set ALLOW_SPLIT_MISMATCH=1 to override." + ) + return + mismatches = { + k: (saved.get(k), live.get(k)) + for k in live + if saved.get(k) != live.get(k) + } + if mismatches: + msg = ( + "Split/resume contract mismatch between checkpoint and current run: " + + ", ".join( + f"{k}: checkpoint={s!r} current={c!r}" for k, (s, c) in mismatches.items() + ) + + ". Resuming would desync the skip offset and/or leak held-out " + "users into training." + ) + if allow: + if rank == 0: + logger.warning("%s ALLOW_SPLIT_MISMATCH=1 set — continuing anyway.", msg) + else: + raise RuntimeError(msg + " Set ALLOW_SPLIT_MISMATCH=1 to override.") + elif rank == 0: + logger.info("Split/resume contract verified against checkpoint: %s", live) + + +@gin.configurable +def streaming_train_eval_loop( + rank: int, + model: torch.nn.Module, + optimizer: Optimizer, + metric_logger: MetricsLogger, + device: torch.device, + num_train_ts: int, + hstu_config: DlrmHSTUConfig, + embedding_table_configs: Dict[str, EmbeddingConfig], + num_train_batches: Optional[int] = None, + num_eval_batches: Optional[int] = None, + output_trace: bool = False, + metric_log_frequency: int = 1, + # MLPerf `train_loss` event cadence, in global train steps, INDEPENDENT of + # metric_log_frequency (the console/TB cadence). 0 = fall back to + # metric_log_frequency (preserves prior coupled behavior). Wired to + # $MLPERF_TRAIN_LOSS_LOG_FREQ via gin. + mlperf_train_loss_log_frequency: int = 0, + checkpoint_frequency: int = 100, + start_ts: int = 0, + persistent_loader: bool = False, + eval_every_n_windows: int = 1, + # Data-fraction eval cadence (mutually exclusive with eval_every_n_windows). + # >0 = eval every this FRACTION of the run's total training data, converted + # once to a global train-step interval. 0.0 = OFF (use the per-window + # cadence). Wired to $EVAL_EVERY_DATA_PCT via gin. + eval_every_data_pct: float = 0.0, + double_buffer: bool = False, + # --- fixed user-holdout eval set --- + # Window range the fixed eval set is drawn from. None -> default to + # original_end_ts (start_ts + num_train_ts), the window just past training. + eval_holdout_ts: Optional[int] = None, + eval_holdout_num_windows: int = 1, + # --- resume / mid-window-exact-once knobs --- + resume_train_ts: Optional[int] = None, + resume_batch_idx_in_window: int = WINDOW_COMPLETE, + # Split contract recovered from the checkpoint (None on cold start or + # legacy checkpoints). Validated below against the live split so a resumed + # run cannot silently train a different user-split (would leak). + resume_split_contract: Optional[Dict[str, Any]] = None, + # True iff no checkpoint was loaded (genuine fresh run). Distinguishes a + # cold start (safe to establish a new split) from a resume that merely lacks + # a contract (legacy/non-streaming checkpoint), which the guard must reject. + resume_cold_start: bool = False, + in_window_checkpoint_frequency: int = 0, + # --- global step / wall-clock checkpoint cadences --- + checkpoint_step_frequency: int = 0, + checkpoint_time_interval_s: float = 0.0, + # --- gradient clipping (streaming path, dense params). 0.0 = OFF, which + # preserves legacy streaming behavior. Wired to $GRAD_CLIP_NORM via gin. --- + grad_clip_norm: float = 0.0, + # --- diagnostic: log per-batch unique/total embedding-id counts --- + streaming_diag_unique_emb: bool = False, + # --- test-only failure injection knob --- + die_at_step: int = -1, + # MLPerf logger (rank-0-gated); None disables all MLPerf event emission. + mlperf_logger: Optional[Any] = None, +) -> None: + """Streaming train+eval loop with per-window (and optionally mid-window) + checkpoints. + + Resume semantics (set by train_ranker after `load_dmp_checkpoint` returns): + - resume_train_ts=None: cold start; honor `start_ts` as-is. + - resume_train_ts=N, resume_batch_idx_in_window=WINDOW_COMPLETE(-1): + previous run finished window N cleanly. Start at N+1 from sample 0. + - resume_train_ts=N, resume_batch_idx_in_window=K (K>=0): previous run + crashed mid-window after K completed batches. Re-enter window N and + skip the first K batches of THIS rank's per-rank sample list (deterministic + slice since `window_indices(N)` is a pure function of the anchor_ts cache). + + Checkpoint cadences (all independent; any combination may be enabled): + - `checkpoint_frequency`: window-granularity. End-of-window save every + Nth train_ts (and always on the final window). Uses WINDOW_COMPLETE. + - `in_window_checkpoint_frequency`: per-window-local batch count. Fires + every N batches *within* a window (counter resets each window). + - `checkpoint_step_frequency`: global-step granularity. Fires whenever + the monotonic `metric_logger.global_step['train']` hits a multiple of + N — i.e. a true "every 1000 steps" trigger that spans windows and + survives resume (global_step is restored from the checkpoint). + - `checkpoint_time_interval_s`: wall-clock granularity. Fires when at + least this many seconds have elapsed since the last save (e.g. 3600 + for hourly). Rank 0 owns the clock and broadcasts the decision so all + ranks save together (avoids the collective barrier in + `save_dmp_checkpoint` deadlocking on a split decision). + + All in-window triggers (`in_window_checkpoint_frequency`, + `checkpoint_step_frequency`, `checkpoint_time_interval_s`) route through + `_save_mid_window`, which stamps `batch_idx_in_window=K` so a crash leaves + a resumable partial-window checkpoint. End-of-window saves + (`checkpoint_frequency`) always use the WINDOW_COMPLETE sentinel. 0 / 0.0 + disables a given cadence (the default for all three fine-grained ones). + + `die_at_step` is a test-only hook: when `metric_logger.global_step['train']` + reaches this value, the process exits with code 42 right after the in-window + save fires. Used by the failure-injection test to crash at a deterministic + boundary and then resume. + """ + # Exactly one eval cadence may be active. eval_every_n_windows defaults to 1 + # (eval every window), so enabling the data-fraction cadence REQUIRES + # explicitly disabling the per-window one (EVAL_EVERY_N_WINDOWS=0). Fail fast + # on a contradictory config rather than silently picking one. + if (eval_every_data_pct and eval_every_data_pct > 0) and eval_every_n_windows > 0: + raise ValueError( + "Conflicting eval cadences: eval_every_data_pct=" + f"{eval_every_data_pct} (>0) AND eval_every_n_windows=" + f"{eval_every_n_windows} (>0). They are mutually exclusive. To use " + "the data-fraction cadence set EVAL_EVERY_N_WINDOWS=0; to use the " + "per-window cadence set EVAL_EVERY_DATA_PCT=0." + ) + # MLPerf train_loss cadence: independent of metric_log_frequency. 0 (the + # env-binding default) falls back to metric_log_frequency so unset behavior + # matches the prior coupled implementation. + mlperf_loss_every = ( + mlperf_train_loss_log_frequency + if mlperf_train_loss_log_frequency and mlperf_train_loss_log_frequency > 0 + else metric_log_frequency + ) + profiler = Profiler(rank) if output_trace else None + # Normalize the per-window caps: <=0 (the env-binding default) means "no cap + # = consume the full window". The eval-break check below is `is not None and + # eval_batch_idx >= num_eval_batches`, so a literal 0 would (wrongly) break + # after the first batch — map it to None instead for the full-holdout eval. + if num_eval_batches is not None and num_eval_batches <= 0: + num_eval_batches = None + if num_train_batches is not None and num_train_batches <= 0: + num_train_batches = None + dataset_class, kwargs = get_dataset() + kwargs["embedding_config"] = embedding_table_configs + dataset = HammerToTorchDataset( + dataset=dataset_class(hstu_config=hstu_config, is_inference=False, **kwargs) + ) + # Persistent path: build ONE DataLoader + a stateful sampler whose indices + # are swapped per window, so workers fork once and are reused across all + # windows (eliminates the per-window dataloader respawn + first-batch + # warmup). The non-persistent path recreates a DataLoader per window. + window_sampler: Optional[StreamingWindowSampler] = None + persistent_dl: Optional[DataLoader] = None + world_size = ( + torch.distributed.get_world_size() + if torch.distributed.is_initialized() + else 1 + ) + if persistent_loader: + window_sampler = StreamingWindowSampler(rank=rank, world_size=world_size) + persistent_dl = make_persistent_streaming_dataloader( + dataset=dataset, sampler=window_sampler + ) + + # The fixed user-holdout eval is yambda-specific (needs window_indices + + # the split API). Other streaming datasets (synthetic) keep the legacy + # per-window eval. Detect support once. + supports_holdout = hasattr(dataset.dataset, "eval_holdout_indices") + + # Fixed eval-holdout window range. Captured from the REQUESTED (start_ts, + # num_train_ts) BEFORE the resume block mutates them, so it is identical on + # cold start and on every resume (the supervisor relaunches with the same + # START_TS / NUM_TRAIN_TS). Defaults to the window just past training. + requested_end_ts = start_ts + num_train_ts + # Eval-cadence anchor: the ORIGINAL requested start_ts, captured BEFORE the + # resume block rebases start_ts. `_should_eval` keys the every-N-windows + # cadence off the absolute window ts relative to THIS anchor, so the eval + # grid (e.g. 150,160,170,...) is identical on cold start and on every resume. + # (Keying off the per-call loop index instead would re-anchor the grid to + # whatever window a mid-run resume happens to restart from.) + eval_anchor_ts = start_ts + # None (Python default) or <0 (the env-binding default) both mean "use the + # window just past training", which is stable across resume. + eval_holdout_ts_resolved = ( + eval_holdout_ts + if (eval_holdout_ts is not None and eval_holdout_ts >= 0) + else requested_end_ts + ) + + # Data-fraction eval cadence: convert eval_every_data_pct into a global + # train-step interval ONCE, over the ORIGINAL requested window range + # [eval_anchor_ts, requested_end_ts). Keying the later trigger off + # `global_step % eval_interval_steps` (global_step is monotonic and + # checkpoint-restored) makes the eval grid identical on cold start and on + # every resume, exactly like checkpoint_step_frequency. 0 => disabled. + eval_interval_steps = 0 + if eval_every_data_pct and eval_every_data_pct > 0: + # Per-rank batch size: the persistent loader carries it directly; the + # per-window path uses the same gin %batch_size (env BATCH_SIZE, + # default 1024 — matches make_streaming_dataloader.batch_size). + bs = ( + persistent_dl.batch_size + if persistent_dl is not None + else int(os.environ.get("BATCH_SIZE", "1024")) + ) + if hasattr(dataset.dataset, "total_train_anchors"): + # total_train_anchors does a full-range gather over the mmap'd uid + # array for ~billions of positions + a uid hash. It is slow + # (minutes, single-threaded) AND, run on every rank independently, + # a large per-rank skew source: a fast rank finishes and races into + # the first embedding all-to-all while slow ranks are still hashing, + # desyncing the NCCL collective stream and hanging the job. The + # result is a pure function of the (identical) dataset + split, so + # compute it ONCE on rank 0 and broadcast the scalar; ranks 1..N + # skip the gather entirely (no 8x mmap/CPU contention, no skew). + if world_size > 1 and torch.distributed.is_initialized(): + _tta = ( + dataset.dataset.total_train_anchors( # pyre-ignore[16] + eval_anchor_ts, requested_end_ts - eval_anchor_ts + ) + if rank == 0 + else 0 + ) + _tta_t = torch.tensor([_tta], dtype=torch.int64, device=device) + torch.distributed.broadcast(_tta_t, src=0) + total_train_anchors = int(_tta_t.item()) + else: + total_train_anchors = dataset.dataset.total_train_anchors( # pyre-ignore[16] + eval_anchor_ts, requested_end_ts - eval_anchor_ts + ) + total_train_steps = total_train_anchors // max(1, bs * world_size) + eval_interval_steps = max( + 1, round(eval_every_data_pct * total_train_steps) + ) + if rank == 0: + logger.info( + "[data-pct-eval] eval_every_data_pct=%.6g -> " + "eval_interval_steps=%d (total_train_anchors=%d bs=%d " + "world_size=%d total_train_steps=%d over windows [%d, %d))", + eval_every_data_pct, + eval_interval_steps, + total_train_anchors, + bs, + world_size, + total_train_steps, + eval_anchor_ts, + requested_end_ts, + ) + elif rank == 0: + logger.warning( + "[data-pct-eval] dataset %s has no total_train_anchors(); " + "data-fraction eval is DISABLED (no per-window eval either, " + "since EVAL_EVERY_N_WINDOWS must be 0 to reach here) — only the " + "final eval will run.", + type(dataset.dataset).__name__, + ) + + # The split is an immutable run contract: a silent change across resume + # would both desync the mid-window skip offset AND turn held-out eval users + # into trained users (leakage). Build the live contract and validate the + # one recovered from the checkpoint against it; abort on any mismatch unless + # ALLOW_SPLIT_MISMATCH=1 is set (e.g. deliberately resuming a legacy run). + live_split_contract: Optional[Dict[str, Any]] = None + if supports_holdout: + live_split_contract = { + "train_split_percentage": dataset.dataset._train_split_percentage, # pyre-ignore[16] + "split_salt": dataset.dataset._split_salt, # pyre-ignore[16] + "eval_holdout_ts": eval_holdout_ts_resolved, + "eval_holdout_num_windows": eval_holdout_num_windows, + "batch_size": persistent_dl.batch_size if persistent_dl is not None else None, + "world_size": world_size, + } + # Only validate on an actual resume. On a genuine cold start there is no + # prior split to verify and establishing this run's split is always safe; + # validating there would wrongly reject every fresh holdout run. A resume + # that lacks a contract (legacy/non-streaming checkpoint) is NOT a cold + # start and is still validated (and rejected) below. + if not resume_cold_start: + _validate_split_contract(resume_split_contract, live_split_contract, rank) + + # Apply resume hint: advance start_ts past the last completed window, or + # re-enter the partial window with a per-rank skip on its first iter. + # Shrink num_train_ts by the same amount so the resumed run finishes at + # the same final timestamp (start_ts + num_train_ts) as a fresh run would + # — i.e. resumed and uninterrupted produce identical total work. + first_skip_samples = 0 + if resume_train_ts is not None: + original_end_ts = start_ts + num_train_ts + if resume_batch_idx_in_window == WINDOW_COMPLETE: + new_start = resume_train_ts + 1 + if rank == 0: + logger.info( + "Resuming from completed train_ts=%d → start_ts=%d " + "(num_train_ts %d → %d)", + resume_train_ts, new_start, + num_train_ts, max(0, original_end_ts - new_start), + ) + start_ts = new_start + else: + if rank == 0: + logger.info( + "Resuming mid-window at train_ts=%d batch_idx_in_window=%d " + "(skipping batches already trained)", + resume_train_ts, + resume_batch_idx_in_window, + ) + start_ts = resume_train_ts + # `batch_size` is per-rank from the persistent dataloader (set via + # gin `make_persistent_streaming_dataloader.batch_size`). The + # skip-samples-per-rank below maps "K batches done" → "K * bs + # samples in this rank's index list", since each batch draws bs + # samples from this rank's deterministic round-robin slice. + assert persistent_dl is not None, ( + "Mid-window resume requires persistent_loader=True" + ) + first_skip_samples = resume_batch_idx_in_window * persistent_dl.batch_size + num_train_ts = max(0, original_end_ts - start_ts) + if num_train_ts == 0 and rank == 0: + logger.info( + "Resume target already reached (end_ts=%d, start_ts=%d) — " + "no further training windows; skipping straight to final eval.", + original_end_ts, start_ts, + ) + + if rank == 0: + logger.info( + "[grad-clip] streaming path gradient clipping %s (max_norm=%.4g via $GRAD_CLIP_NORM)", + "ENABLED" if (grad_clip_norm and grad_clip_norm > 0) else "OFF", + grad_clip_norm, + ) + + def _window_iter(ts: int, skip_samples: int = 0): + # TRAIN-only iterator: both branches exclude held-out eval users via + # train_window_indices / set_ts(train_only=True). (Eval uses the fixed + # holdout set, never this helper.) + if persistent_loader: + assert window_sampler is not None and persistent_dl is not None + window_sampler.set_window( + dataset.dataset.train_window_indices(ts), # pyre-ignore [16] + skip_samples=skip_samples, + ) + return iter(persistent_dl) + if skip_samples != 0: + raise NotImplementedError( + "skip_samples>0 requires persistent_loader=True" + ) + return iter( + make_streaming_dataloader(dataset=dataset, ts=ts, train_only=True) + ) + # Windows are [start_ts, start_ts + num_train_ts); each step trains window T + # then evals window T+1, so the last eval window is start_ts + num_train_ts, + # which must be < num_windows(). Anchors require >= history_length prior + # events, so the earliest windows are near-empty warm-up — use start_ts to + # begin at a dense window. Clamp instead of failing. + if hasattr(dataset.dataset, "num_windows"): + available = dataset.dataset.num_windows() # pyre-ignore [16] + max_count = max(0, available - 1 - start_ts) + if num_train_ts > max_count: + logger.warning( + f"start_ts={start_ts} + num_train_ts={num_train_ts} exceeds " + f"available windows ({available}); clamping num_train_ts to {max_count}." + ) + num_train_ts = max_count + # Wall-clock anchor for time-based checkpointing. Mutable single-element + # list so the nested train loop can reset it after each save. Starts at + # loop entry so the first time-trigger fires ~interval seconds in. + last_ckpt_time = [time.time()] + + def _broadcast_elapsed() -> float: + """Seconds since the last save, owned by rank 0 and broadcast to all + ranks. save_dmp_checkpoint runs a collective barrier, so every rank must + feed the same wall-clock value into the cadence decision — otherwise a + split verdict (rank 0 saves, rank 1 doesn't) would deadlock. Broadcasting + rank 0's elapsed keeps the (pure) decision identical everywhere.""" + elapsed = time.time() - last_ckpt_time[0] + if torch.distributed.is_initialized() and world_size > 1: + t = torch.tensor([elapsed], device=device, dtype=torch.float64) + torch.distributed.broadcast(t, src=0) + elapsed = float(t.item()) + return elapsed + + def _save_mid_window(train_ts: int, batch_idx_in_window: int) -> None: + """In-window checkpoint helper. Snapshots the same state as the + end-of-window save but stamps `batch_idx_in_window=K` instead of + WINDOW_COMPLETE so the resume path knows to skip K batches. + Uses train_ts as the numeric subdir name — every save into the same + train_ts overwrites the previous in-window snapshot (via atomic + replace), so disk stays bounded to keep_last_n train_ts dirs.""" + save_dmp_checkpoint( + model=model, + optimizer=optimizer, + metric_logger=metric_logger, + rank=rank, + batch_idx=train_ts, + train_ts=train_ts, + batch_idx_in_window=batch_idx_in_window, + device=device, + split_contract=live_split_contract, + ) + + def _run_train_window( + train_data_iterator, + train_ts: int, + start_batch_idx: int = 0, + label: Optional[str] = None, + do_eval: Optional[Callable[[int, int], None]] = None, + ) -> None: + # `start_batch_idx` is set when we're re-entering a window that was + # interrupted mid-way (in_window resume); the dataloader iterator was + # already advanced past those batches via the sampler skip, and we + # account for them in the local counter so in-window saves and the + # die_at_step hook fire at the right relative offsets. + train_batch_idx = start_batch_idx + first_wait: Optional[float] = None + while True: + model.train() + _t_next = time.perf_counter() if (label and rank == 0) else None + try: + sample = next(train_data_iterator) + except StopIteration: + break + if _t_next is not None and first_wait is None: + first_wait = time.perf_counter() - _t_next + if streaming_diag_unique_emb and train_batch_idx < int( + os.environ.get("DIAG_EMB_STEPS", "100") + ): + _log_unique_embedding_diag( + sample, + rank, + train_batch_idx, + max_steps=int(os.environ.get("DIAG_EMB_STEPS", "100")), + log_every=metric_log_frequency, + ) + optimizer.zero_grad() + sample.to(device) + ( + _, + _, + aux_losses, + mt_target_preds, + mt_target_labels, + mt_target_weights, + ) = model.forward( + sample.uih_features_kjt, + sample.candidates_features_kjt, + ) + # pyre-ignore + sum(aux_losses.values()).backward() + # Gradient clipping for the streaming path. Clips dense params (the + # sparse embedding tables use a fused optimizer and are unaffected, + # same as the non-streaming path's clip_grad_norm_). OFF by default + # (grad_clip_norm=0.0 via $GRAD_CLIP_NORM) so legacy streaming runs + # are byte-for-byte unchanged; set >0 to enable. + if grad_clip_norm and grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_( + model.parameters(), max_norm=grad_clip_norm + ) + optimizer.step() + metric_logger.update( + mode="train", + predictions=mt_target_preds, + labels=mt_target_labels, + weights=mt_target_weights, + num_candidates=sample.candidates_features_kjt.lengths().view( + len(sample.candidates_features_kjt.keys()), -1 + )[0], + ) + # MLPerf train_loss event on its own cadence (decoupled from the + # console/TB metric cadence below). Called every step; the cross-rank + # all-reduce only fires on the cadence, gated by the rank-identical + # global_step inside the method, so it stays in lockstep. + metric_logger.maybe_log_mlperf_train_loss(aux_losses, every=mlperf_loss_every) + if train_batch_idx % metric_log_frequency == 0: + metric_logger.compute_and_log( + mode="train", + additional_logs={ + "losses": aux_losses, + }, + ) + train_batch_idx += 1 + if output_trace: + assert profiler is not None + profiler.step() + # Fine-grained in-window checkpoint triggers. All stamp + # batch_idx_in_window so a crash here leaves a resumable partial + # checkpoint, and all fire AFTER the metric update so restored + # state reflects the just-completed batch. Triggers are mutually + # short-circuited (one save per batch max) but evaluated on the + # same deterministic counters across all ranks, so the collective + # inside save_dmp_checkpoint stays in lockstep. + gstep = metric_logger.global_step["train"] + # Wall-clock elapsed is broadcast from rank 0 so every rank feeds + # the same value into the (otherwise pure) cadence decision. + elapsed = ( + _broadcast_elapsed() if checkpoint_time_interval_s > 0 else 0.0 + ) + save_reason = select_in_window_checkpoint_reason( + train_batch_idx=train_batch_idx, + global_step=gstep, + elapsed_since_last_save=elapsed, + in_window_checkpoint_frequency=in_window_checkpoint_frequency, + checkpoint_step_frequency=checkpoint_step_frequency, + checkpoint_time_interval_s=checkpoint_time_interval_s, + ) + if save_reason is not None: + if rank == 0: + logger.info( + "checkpoint trigger=%s train_ts=%d batch=%d global_step=%d", + save_reason, + train_ts, + train_batch_idx, + gstep, + ) + _save_mid_window(train_ts, train_batch_idx) + # Reset the wall-clock anchor on ANY save so the next time + # trigger is measured from the most recent checkpoint. + last_ckpt_time[0] = time.time() + # Data-fraction eval cadence: run the full-holdout eval whenever the + # monotonic global step crosses a multiple of eval_interval_steps + # (i.e. every eval_every_data_pct of the training data). Keyed off + # global_step (checkpoint-restored) so the eval grid is identical + # across resume. Mid-window-safe: eval sets model.eval(), so restore + # train mode + dataset.is_eval afterward. do_eval is None unless the + # data-pct cadence is enabled. + if ( + do_eval is not None + and eval_interval_steps > 0 + and gstep > 0 + and gstep % eval_interval_steps == 0 + ): + if rank == 0: + logger.info( + "[data-pct-eval] trigger eval train_ts=%d global_step=%d " + "(interval=%d)", + train_ts, + gstep, + eval_interval_steps, + ) + do_eval(train_ts, gstep) + model.train() + dataset.dataset.is_eval = False # pyre-ignore [16] + # Data-fraction eval may hit the MLPerf target and emit RUN_STOP + # (via _do_eval_*). Stop the window immediately so we don't train + # past the convergence point; the outer window loop checks the + # same flag and breaks too. + if mlt.run_stopped: + break + # Test-only: deterministic crash for the failure-injection test. + # Triggered AFTER the save above, so on resume we re-enter at + # batch_idx_in_window=train_batch_idx and emit batches [K+1, end). + if ( + die_at_step >= 0 + and metric_logger.global_step["train"] >= die_at_step + ): + if rank == 0: + logger.warning( + "die_at_step=%d hit at train_ts=%d batch=%d global_step=%d " + "→ sys.exit(42)", + die_at_step, + train_ts, + train_batch_idx, + metric_logger.global_step["train"], + ) + # Distributed barrier so all ranks exit together rather than + # leaving a few ranks hanging on NCCL ops. + torch.distributed.barrier() + import sys + sys.exit(42) + # `num_train_batches` cap: None or 0 = run the whole window. >0 caps + # batches per window (mostly the streaming-resume test driver uses + # this to keep test windows short). + if num_train_batches and train_batch_idx >= num_train_batches: + break + if label and rank == 0 and first_wait is not None: + logger.info( + f"[boundary] {label} train first-batch data-wait={first_wait * 1000:.1f}ms" + ) + + def _run_eval_window( + eval_data_iterator, label: Optional[str] = None + ) -> Dict[str, float]: + # DO NOT add a checkpoint trigger anywhere inside this function. The eval + # data iterator's position is not serializable, so a checkpoint taken + # mid-eval could not be resumed deterministically. `_maybe_checkpoint` + # only fires after a completed eval window or mid-train-window, so any + # restored state always sits on a completed-eval boundary -- which is + # also why the eval reset below is safe across resume. + # + # Exclude this eval pass's wall-time from the train step-time window so + # step_ms stays canonical even when eval coincides with a train interval; + # the duration is reported separately (window_eval_time_ms + total_eval + # below). Resumed unconditionally at the end of this function. + metric_logger.pause_perf("eval") + model.eval() + # Reset eval metrics so each pass reports a clean number over the FIXED + # holdout set. Without this, lifetime/window eval metrics would keep + # accumulating across eval steps (the old behavior, made worse now that + # every step sees the identical set), making the eval-AUC trajectory + # uninterpretable. With the reset, each eval point == AUC over the whole + # fixed holdout at that train step -> directly comparable across steps. + metric_logger.reset(mode="eval") + eval_batch_idx = 0 + first_wait: Optional[float] = None + _t_enter = time.perf_counter() if (label and rank == 0) else None + with torch.no_grad(): + while True: + _t_next = time.perf_counter() if (label and rank == 0) else None + try: + sample = next(eval_data_iterator) + except StopIteration: + break + if _t_next is not None and first_wait is None: + first_wait = time.perf_counter() - _t_next + sample.to(device) + ( + _, + _, + _, + mt_target_preds, + mt_target_labels, + mt_target_weights, + ) = model.forward( + sample.uih_features_kjt, + sample.candidates_features_kjt, + ) + metric_logger.update( + mode="eval", + predictions=mt_target_preds, + labels=mt_target_labels, + weights=mt_target_weights, + num_candidates=sample.candidates_features_kjt.lengths().view( + len(sample.candidates_features_kjt.keys()), -1 + )[0], + ) + eval_batch_idx += 1 + if output_trace: + assert profiler is not None + profiler.step() + if eval_batch_idx % metric_log_frequency == 0: + metric_logger.compute_and_log(mode="eval") + if num_eval_batches is not None and eval_batch_idx >= num_eval_batches: + break + _eval_metrics = metric_logger.compute(mode="eval") + for k, v in _eval_metrics.items(): + print(f"{k}: {v}") + if label and rank == 0 and _t_enter is not None: + _eval_total = time.perf_counter() - _t_enter + _fw = (first_wait * 1000) if first_wait is not None else float("nan") + logger.info( + f"[boundary] {label} eval first-batch data-wait={_fw:.1f}ms " + f"total_eval={_eval_total * 1000:.1f}ms batches={eval_batch_idx}" + ) + # Dedicated per-eval metrics sink. One JSON line per eval boundary + # capturing the END-OF-PASS metric over the FIXED holdout -- the single + # correct value for that eval point (no interim/averaging ambiguity). + # Rank 0 only; append-only so it survives restarts and the trajectory + # accumulates across resumes. Written next to the main run log + # (".metrics.jsonl"), falling back to cwd if LOG is unset. + import json + import re as _re + + _log = os.environ.get("LOG") + if _log: + _base = _log[:-4] if _log.endswith(".log") else _log + _metrics_path = f"{_base}.metrics.jsonl" + else: + _metrics_path = "streaming_eval_metrics.jsonl" + _ts_m = _re.search(r"train_ts=(\d+)", label) + _rec = { + "label": label, + "train_ts": int(_ts_m.group(1)) if _ts_m else None, + "global_step": int(metric_logger.global_step.get("train", -1)), + "eval_batches": eval_batch_idx, + "total_eval_ms": round(_eval_total * 1000, 1), + "wall_time": time.time(), + } + for _k, _v in _eval_metrics.items(): + try: + _rec[_k] = float(_v) + except (TypeError, ValueError): + pass + try: + with open(_metrics_path, "a") as _f: + _f.write(json.dumps(_rec) + "\n") + except OSError as _e: + logger.warning("failed to write metrics sink %s: %s", _metrics_path, _e) + metric_logger.resume_perf("eval") + # Return metrics (on every rank) so the MLPerf eval hooks can consume them. + return _eval_metrics + + def _maybe_checkpoint(train_ts: int) -> None: + if ( + train_ts % checkpoint_frequency == 0 and train_ts > 0 + ) or train_ts == start_ts + num_train_ts - 1: + # End-of-window save: stamp WINDOW_COMPLETE so resume advances past + # this train_ts. `device` enables per-rank RNG snapshot for + # bit-equal resume of dropout-bearing modules. + save_dmp_checkpoint( + model=model, + optimizer=optimizer, + metric_logger=metric_logger, + rank=rank, + batch_idx=train_ts, + train_ts=train_ts, + batch_idx_in_window=WINDOW_COMPLETE, + device=device, + split_contract=live_split_contract, + ) + last_ckpt_time[0] = time.time() + + # Apply start_ts shift from resume (may have moved past the original start). + # num_train_ts is the requested *count*; preserve it so the loop runs for + # the same total number of windows post-resume as a fresh run would have. + train_ts_list = list(range(start_ts, start_ts + num_train_ts)) + n_train = len(train_ts_list) + + def _should_eval(i: int) -> bool: + """Whether to run the full-holdout eval after training window index `i`. + + Single cadence knob `eval_every_n_windows`: + * <=0 -> eval disabled entirely (train-only; e.g. perf benchmarking or + the resume test). The eval dataloader is not even built. + * 1 (default) -> eval after every window. + * K>1 -> eval when the ABSOLUTE window ts is on the grid anchored at + `eval_anchor_ts` (the original start_ts), i.e. ts in {anchor, + anchor+K, anchor+2K, ...}, and ALWAYS on the final window so the + trajectory ends with an eval point. Anchoring to the absolute ts + (not the per-call loop index `i`) keeps the eval grid (e.g. + 150,160,170,...) stable across a mid-run resume, which rebases + start_ts/`train_ts_list` to the resume window. + """ + if eval_every_n_windows <= 0: + return False + if eval_every_n_windows == 1: + return True + return (train_ts_list[i] - eval_anchor_ts) % eval_every_n_windows == 0 or i == n_train - 1 + + # Fixed eval set: held-out users' anchors over the resolved holdout window + # range, computed ONCE and reused at every eval step. Same anchors every + # step -> stable, comparable eval-AUC curve, and bounded eval time + # (~(1 - train_split_percentage) of a window). Cached inside the dataset so + # re-deriving it (e.g. on resume) returns the identical set. None for + # datasets without holdout support (synthetic) -> legacy per-window eval. + eval_global_indices: Optional["np.ndarray"] = None + if supports_holdout: + eval_global_indices = dataset.dataset.eval_holdout_indices( # pyre-ignore [16] + eval_holdout_ts_resolved, eval_holdout_num_windows + ) + if rank == 0: + logger.info( + "Fixed eval holdout: ts=[%d, %d) -> %d anchors (train_split_percentage=%s)", + eval_holdout_ts_resolved, + eval_holdout_ts_resolved + eval_holdout_num_windows, + len(eval_global_indices), + dataset.dataset._train_split_percentage, # pyre-ignore[16] + ) + + # --- MLPerf run tracking -------------------------------------------------- + # total_train_samples = epoch_num denominator (global trainable samples over + # the window range), computed once and logged as TRAIN_SAMPLES. + total_train_samples = 0 + if mlperf_logger is not None: + _idx_fn = getattr( + dataset.dataset, "train_window_indices", None + ) or getattr(dataset.dataset, "window_indices", None) + if _idx_fn is not None: + for _ts in train_ts_list: + total_train_samples += int(_idx_fn(_ts).size) + if rank == 0: + logger.info( + "MLPerf: total_train_samples=%d over %d windows", + total_train_samples, + n_train, + ) + # Wire the logger + LR getter so MetricsLogger.compute emits train_loss. + metric_logger.mlperf_logger = mlperf_logger + + def _current_lr() -> float: + return float(optimizer.param_groups[0]["lr"]) + + metric_logger.lr_getter = _current_lr + + # Centralized MLPerf run-boundary state machine: owns block/eval/run markers, + # SAMPLES_COUNT/EPOCH_NUM progress metadata, and the per-window-AUC vs + # auc_threshold convergence decision. Every method no-ops when mlperf_logger + # is None, so the loop below calls them unconditionally. + from generative_recommenders.dlrm_v3.train.mlperf_logging_utils import ( + MLPerfRunTracker, + ) + + mlt = MLPerfRunTracker( + logger=mlperf_logger, + metric_logger=metric_logger, + total_train_samples=total_train_samples, + rank=rank, + device=device, + ) + mlt.log_dataset_sizes( + eval_samples=eval_global_indices.size + if eval_global_indices is not None + else None + ) + + if persistent_loader and double_buffer: + # Double-buffered: next window prepared in the background during the + # current window's compute. Eval (if enabled) uses its own pre-forked + # pool, primed up front on the main thread so no fork races a bg thread. + prefetcher = _PrefetchingWindowLoader( + dataset=dataset, + sampler_factory=lambda: StreamingWindowSampler(rank, world_size), + dl_factory=lambda s: make_persistent_streaming_dataloader( + dataset=dataset, sampler=s + ), + ) + eval_sampler: Optional[StreamingWindowSampler] = None + eval_dl: Optional[DataLoader] = None + # Eval iterator is built one window ahead: the eval pool (idle while the + # current train window runs) prefetches the next eval window's first + # batches concurrently with train compute, so eval starts warm. yambda's + # sample content depends only on the sampler window, not is_eval, so + # prefetching during train is safe. + eval_iter: Optional[Iterator] = None + # Build/fork the eval pool when EITHER cadence needs it: the per-window + # cadence (eval_every_n_windows>0) or the data-fraction cadence + # (eval_interval_steps>0). Both are never simultaneously on (validated + # at entry), so this is "eval is enabled at all". + if (eval_every_n_windows > 0 or eval_interval_steps > 0) and len( + train_ts_list + ) > 0: + eval_sampler = StreamingWindowSampler(rank, world_size) + eval_dl = make_persistent_streaming_dataloader( + dataset=dataset, sampler=eval_sampler + ) + # CRITICAL: fork the eval worker pool HERE, on the main thread, + # BEFORE prefetcher.stream() below spins up its background prep + # thread. The pool is persistent_workers=True, so this first iter() + # is the ONLY fork; every later iter() merely resets and reuses these + # workers (no fork), so it can never deadlock against the background + # thread holding an allocator/GIL-released lock. (Deferring this + # first fork into the loop — as a sparse-eval cadence naively might — + # hangs the run.) _should_eval(0) is always True when eval is enabled + # (0 % K == 0). The eval set is the FIXED holdout (same every step), + # so we install it on the sampler ONCE here; later evals just call + # iter() again to replay the identical set (no set_window churn). + eval_sampler.set_window(eval_global_indices) + eval_iter = iter(eval_dl) + + # Data-fraction eval callback (double-buffer path). Fired mid-window by + # _run_train_window on the global-step cadence. Reuses the already-forked + # persistent eval pool: iter(eval_dl) here runs on the MAIN thread (a + # reset, not a fork — the only fork was the up-front iter() above), so it + # stays safe alongside the background window-prefetch thread. + def _do_eval_db(train_ts: int, gstep: int) -> None: + # Data-fraction eval boundary: this closes the current MLPerf block, + # runs the holdout eval with full EVAL_START/EVAL_STOP + EVAL_ACCURACY + # + convergence, then opens the next block. The block thus brackets + # exactly one eval_interval_steps of training (MLPerf block == work + # between two evals), instead of one timestamp window. + dataset.dataset.is_eval = True # pyre-ignore [16] + assert eval_dl is not None + mlt.block_stop() + mlt.eval_start() + eval_metrics = _run_eval_window( + iter(eval_dl), + label=f"eval_holdout@train_ts={train_ts}@step={gstep}", + ) + # Emits RUN_STOP (sets mlt.run_stopped) if the target is met; + # _run_train_window / the window loop break on that flag. + mlt.eval_stop(eval_metrics) + if not mlt.run_stopped: + mlt.block_start() + + _db_do_eval = _do_eval_db if eval_interval_steps > 0 else None + + # Block placement depends on the eval cadence. Per-window cadence + # (eval_every_n_windows>0): one block per timestamp window. Otherwise + # (data-fraction cadence, or no eval): a single block spans the whole + # run, split at each data-fraction eval boundary by _do_eval_db. Open + # the first block here for the latter; the boundary helper + the + # post-loop stop handle the rest. + _per_window_blocks = eval_every_n_windows > 0 + if not _per_window_blocks: + mlt.block_start() + + for i, (train_ts, train_data_iterator) in enumerate( + # Only the FIRST window after a mid-window resume needs the skip + # (handed via prefetcher.stream's first_skip_samples). The skip is + # zero on cold start (resume_train_ts is None → first_skip_samples=0) + # and on completed-window resume (mid-window slice is 0 too). + prefetcher.stream(train_ts_list, first_skip_samples=first_skip_samples) + ): + dataset.dataset.is_eval = False # pyre-ignore [16] + # First iteration after a mid-window resume carries + # resume_batch_idx_in_window so in-window saves and the die_at_step + # hook keep accurate counters; otherwise count from 0. + start_batch = ( + resume_batch_idx_in_window + if i == 0 and resume_batch_idx_in_window > 0 + else 0 + ) + # Rendezvous all ranks at the window boundary BEFORE the first + # forward of this window. The prefetcher has already handed back a + # ready iterator (this window's window_indices mmap scan is done), + # but that O(N) scan over the ~18 GB anchor_ts array can finish at + # very different times across ranks. Without this barrier a fast + # rank issues the first embedding all-to-all while a slow rank is + # still in prep, desyncing the NCCL collective stream and hanging + # the job (observed at a window boundary via the flight recorder: + # ranks split across consecutive collective seq ids). This only + # absorbs prep skew (one near-zero sync per window); it does not + # serialize the background prefetch of future windows. + _window_boundary_barrier(device, world_size, train_ts) + if _per_window_blocks: + mlt.block_start() + _run_train_window( + train_data_iterator, + train_ts=train_ts, + start_batch_idx=start_batch, + label=f"train_ts={train_ts}", + do_eval=_db_do_eval, + ) + if _per_window_blocks: + mlt.block_stop() + should_stop = False + if _should_eval(i): + dataset.dataset.is_eval = True # pyre-ignore [16] + assert eval_sampler is not None and eval_dl is not None + mlt.eval_start() + eval_metrics = _run_eval_window( + eval_iter, label=f"eval_holdout@train_ts={train_ts}" + ) + should_stop = mlt.eval_stop(eval_metrics) + # Re-arm the (already-forked) eval pool for the NEXT eval. The + # holdout set is fixed, so the sampler window is unchanged; we + # only need a fresh iter() to replay it. iter() reuses the + # persistent workers — no fork, safe alongside the bg thread. + next_eval_i = next( + (j for j in range(i + 1, n_train) if _should_eval(j)), None + ) + if next_eval_i is not None: + eval_iter = iter(eval_dl) + _maybe_checkpoint(train_ts) + # should_stop: per-window convergence. mlt.run_stopped: + # data-fraction convergence (RUN_STOP fired mid-window by _do_eval_db). + if should_stop or mlt.run_stopped: + # MLPerf target reached: RUN_STOP already emitted; stop training. + break + + # Close the run-spanning block for the data-fraction / no-eval case. + # Idempotent: a no-op if the last eval boundary already closed it (i.e. + # convergence stopped the run) or if per-window blocks were used. + if not _per_window_blocks: + mlt.block_stop() + else: + # Data-fraction eval callback (non-double-buffer path). Builds a fresh + # eval dataloader per call over the FIXED holdout set (or the legacy + # next-window eval when the dataset has no holdout support). + def _do_eval_nb(train_ts: int, gstep: int) -> None: + # Data-fraction eval boundary (non-double-buffer path). See _do_eval_db: + # close the current MLPerf block, run the eval with full markers + + # convergence, then open the next block so a block brackets one + # eval_interval_steps of training rather than a timestamp window. + dataset.dataset.is_eval = True # pyre-ignore [16] + mlt.block_stop() + mlt.eval_start() + if eval_global_indices is not None: + eval_metrics = _run_eval_window( + iter( + make_streaming_dataloader( + dataset=dataset, indices=eval_global_indices + ) + ), + label=f"eval_holdout@train_ts={train_ts}@step={gstep}", + ) + else: + eval_metrics = _run_eval_window( + iter(make_streaming_dataloader(dataset=dataset, ts=train_ts + 1)), + label=f"eval@train_ts={train_ts}@step={gstep}", + ) + mlt.eval_stop(eval_metrics) + if not mlt.run_stopped: + mlt.block_start() + + _nb_do_eval = _do_eval_nb if eval_interval_steps > 0 else None + + # See the double-buffer branch: per-window blocks for the per-window + # cadence, else a single run-spanning block split at data-fraction eval + # boundaries by _do_eval_nb. + _per_window_blocks = eval_every_n_windows > 0 + if not _per_window_blocks: + mlt.block_start() + + for i, train_ts in enumerate(train_ts_list): + dataset.dataset.is_eval = False # pyre-ignore [16] + skip = first_skip_samples if i == 0 else 0 + start_batch = ( + resume_batch_idx_in_window + if i == 0 and resume_batch_idx_in_window > 0 + else 0 + ) + # See the double-buffer path: rendezvous all ranks at the window + # boundary before the first forward so per-rank data-prep skew + # cannot desync the NCCL collective stream and hang the job. + _window_boundary_barrier(device, world_size, train_ts) + if _per_window_blocks: + mlt.block_start() + _run_train_window( + _window_iter(train_ts, skip_samples=skip), + train_ts=train_ts, + start_batch_idx=start_batch, + do_eval=_nb_do_eval, + ) + if _per_window_blocks: + mlt.block_stop() + should_stop = False + if _should_eval(i): + dataset.dataset.is_eval = True # pyre-ignore [16] + mlt.eval_start() + if eval_global_indices is not None: + eval_metrics = _run_eval_window( + iter( + make_streaming_dataloader( + dataset=dataset, indices=eval_global_indices + ) + ), + label=f"eval_holdout@train_ts={train_ts}", + ) + else: + # Legacy per-window eval (datasets without user holdout). + eval_metrics = _run_eval_window( + iter(make_streaming_dataloader(dataset=dataset, ts=train_ts + 1)) + ) + should_stop = mlt.eval_stop(eval_metrics) + _maybe_checkpoint(train_ts) + # should_stop: per-window convergence. mlt.run_stopped: + # data-fraction convergence (RUN_STOP fired mid-window by _do_eval_nb). + if should_stop or mlt.run_stopped: + # MLPerf target reached: RUN_STOP already emitted; stop training. + break + + # Close the run-spanning block for the data-fraction / no-eval case + # (idempotent; no-op under per-window blocks or after a convergence stop). + if not _per_window_blocks: + mlt.block_stop() + + # Final eval over the fixed user-holdout set (legacy final-window eval + # otherwise). Skipped if the MLPerf target already stopped the run mid-run. + if not mlt.run_stopped: + dataset.dataset.is_eval = True # pyre-ignore [16] + mlt.eval_start() + if eval_global_indices is not None: + final_metrics = _run_eval_window( + iter( + make_streaming_dataloader( + dataset=dataset, indices=eval_global_indices + ) + ), + label="eval_holdout@final", + ) + else: + final_metrics = _run_eval_window( + iter(make_streaming_dataloader(dataset=dataset, ts=num_train_ts)), + label="eval@final", + ) + mlt.eval_stop(final_metrics) + if rank == 0: + for k, v in final_metrics.items(): + print(f"{k}: {v}") + # End-of-run RUN_STOP: SUCCESS if final metric met target, else ABORTED. + mlt.finalize(final_metrics) diff --git a/recommendation_v4/generative_recommenders/dlrm_v3/utils.py b/recommendation_v4/generative_recommenders/dlrm_v3/utils.py new file mode 100644 index 000000000..1e5d79993 --- /dev/null +++ b/recommendation_v4/generative_recommenders/dlrm_v3/utils.py @@ -0,0 +1,1806 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-unsafe +""" +mlperf dlrm_v3 inference benchmarking tool. +""" + +import contextlib +import logging +import os +import time +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional + +import gin +import tensorboard # @manual=//tensorboard:lib # noqa: F401 - required implicit dep when using torch.utils.tensorboard +import torch +from generative_recommenders.dlrm_v3.datasets.dataset import DLRMv3RandomDataset +from generative_recommenders.dlrm_v3.datasets.kuairand import DLRMv3KuaiRandDataset +from generative_recommenders.dlrm_v3.datasets.movie_lens import DLRMv3MovieLensDataset +from generative_recommenders.dlrm_v3.datasets.synthetic_movie_lens import ( + DLRMv3SyntheticMovieLensDataset, +) +from generative_recommenders.dlrm_v3.datasets.synthetic_streaming import ( + DLRMv3SyntheticStreamingDataset, +) +from generative_recommenders.dlrm_v3.datasets.yambda import DLRMv3YambdaDataset +from generative_recommenders.modules.multitask_module import ( + MultitaskTaskType, + TaskConfig, +) +from torch.profiler import profile, profiler, ProfilerActivity # pyre-ignore [21] +from torch.utils.tensorboard import SummaryWriter +from torchrec.metrics.accuracy import AccuracyMetricComputation +from torchrec.metrics.auc import AUCMetricComputation, compute_auc +from torchrec.metrics.gauc import GAUCMetricComputation +from torchrec.metrics.mae import MAEMetricComputation +from torchrec.metrics.metrics_namespace import MetricName, MetricPrefix +from torchrec.metrics.mse import MSEMetricComputation +from torchrec.metrics.ne import NEMetricComputation +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetricComputation, +) + + +class LifetimeAUCMetricComputation(AUCMetricComputation): + """AUC over a 10M-sample (~5.5 eval-window) trailing buffer; emits with the + LIFETIME prefix. + + NOTE: despite the name, this is NOT an uncapped since-step-0 AUC. The parent + ``AUCMetricComputation`` evicts the prediction/label/weight buffers down to + ``window_size`` in ``update()``; we instantiate it with + ``window_size=10_000_000``, so "lifetime" is a ~10M-sample trailing window. + Raise ``window_size`` (accepting unbounded buffer growth) if true cumulative + AUC is ever required. + + Checkpoint correctness: torchrec registers the PREDICTIONS/LABELS/WEIGHTS + buffers with ``persistent=False`` (so the default ``state_dict()`` drops + them) and tracks a separate ``self._num_samples`` counter. Without the + overrides below, every checkpoint resume would silently restart this metric + from an empty buffer. We therefore serialize the buffers AND ``_num_samples`` + explicitly; restoring ``_num_samples`` is mandatory, since leaving it at 0 + makes the next ``update()`` take the init-sentinel branch and desync the + windowed eviction. These buffers are per-rank-local (cross-rank gather only + happens transiently at compute time), so the checkpoint layer MUST persist + and restore them per-rank — see ``checkpoint.py``. + """ + + # Prefix used for the explicitly-serialized non-persistent buffers so the + # keys can't collide with any persistent state the parent might register. + _LIFETIME_KEY_PREFIX: str = "_lifetime_" + + def _compute(self) -> List[MetricComputationReport]: + from typing import cast as _cast + from torchrec.metrics.auc import LABELS, PREDICTIONS, WEIGHTS + return [ + MetricComputationReport( + name=MetricName.AUC, + metric_prefix=MetricPrefix.LIFETIME, + value=compute_auc( + self._n_tasks, + _cast(List[torch.Tensor], getattr(self, PREDICTIONS)), + _cast(List[torch.Tensor], getattr(self, LABELS)), + _cast(List[torch.Tensor], getattr(self, WEIGHTS)), + self._apply_bin, + ), + ) + ] + + def lifetime_sample_count(self) -> int: + """Current number of buffered samples (greppable for sanity logs).""" + return int(getattr(self, "_num_samples", 0)) + + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + ) -> Dict[str, Any]: + from torchrec.metrics.auc import LABELS, PREDICTIONS, WEIGHTS + + destination = super().state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + # The parent registers these buffers persistent=False, so they are absent + # from `destination`. Concatenate each buffer list to one (n_tasks, N) + # tensor and serialize it alongside the sample counter. + for attr in (PREDICTIONS, LABELS, WEIGHTS): + buf = getattr(self, attr) + if isinstance(buf, (list, tuple)) and len(buf) > 0: + flat = torch.cat([t for t in buf], dim=-1) + elif isinstance(buf, torch.Tensor): + flat = buf + else: + flat = torch.empty(0) + destination[prefix + self._LIFETIME_KEY_PREFIX + attr] = ( + flat.detach().cpu().clone() + ) + destination[prefix + self._LIFETIME_KEY_PREFIX + "num_samples"] = ( + torch.tensor(int(getattr(self, "_num_samples", 0)), dtype=torch.long) + ) + return destination + + def load_state_dict( + self, + state_dict: Dict[str, Any], + strict: bool = True, + ) -> Any: + from torchrec.metrics.auc import LABELS, PREDICTIONS, WEIGHTS + + # Copy so we can strip our custom keys before delegating to the parent + # (whose strict load would otherwise reject them as unexpected). + remaining = dict(state_dict) + saved_bufs: Dict[str, torch.Tensor] = {} + for attr in (PREDICTIONS, LABELS, WEIGHTS): + key = self._LIFETIME_KEY_PREFIX + attr + if key in remaining: + saved_bufs[attr] = remaining.pop(key) + num_key = self._LIFETIME_KEY_PREFIX + "num_samples" + saved_num = remaining.pop(num_key, None) + + result = super().load_state_dict(remaining, strict=strict) + + if saved_bufs: + # Device of the live (init-sentinel) buffers; keep restored buffers + # co-located so subsequent update()/compute() stay on-device. + existing = getattr(self, PREDICTIONS) + dev = ( + existing[0].device + if isinstance(existing, (list, tuple)) and len(existing) > 0 + else torch.device("cpu") + ) + for attr, val in saved_bufs.items(): + setattr(self, attr, [val.to(dev)]) + if saved_num is not None: + self._num_samples = int(saved_num.item()) + return result + + +# Sentinel "window size" used for the FRESH eval metrics so torchrec's windowed +# eviction never fires within a single eval pass (the per-pass reset bounds the +# buffer to exactly one full holdout pass). 1<<60 is far above any realistic +# per-rank sample count and avoids sys.maxsize overflow inside torchrec math. +UNBOUNDED_WINDOW: int = 1 << 60 + + +class BinnedCumulativeAUC(RecMetricComputation): + """Cumulative AUC via a fixed-resolution score histogram (LIFETIME prefix). + + Global AUC is a rank statistic, so it has no fixed-size additive sufficient + statistic the way NE/Accuracy do - exact cumulative AUC otherwise needs every + (score, label) pair retained and sorted (the buffer-based ``AUCMetricComputation`` + / ``LifetimeAUCMetricComputation``). Instead we keep two weighted histograms of + positive/negative mass per score bin. This gives an AUC exact up to bin width + with O(num_bins) memory that does NOT grow with sample count, and - because + histograms are additive - cross-rank sync is a cheap all-reduce (dist_reduce_fx + "sum") rather than all-gathering millions of predictions. The state is truly + cumulative across all eval passes (never evicted, never reset on eval). + + Predictions MUST be probabilities in [0, 1] (the same tensor feeds NE, which + requires probabilities; the model applies sigmoid in multitask_module). Values + are clamped into [0, 1] defensively. + """ + + def __init__(self, *args, num_bins: int = 100_000, **kwargs) -> None: + # window_size is irrelevant here (no windowed state); pass through. + super().__init__(*args, **kwargs) + self._num_bins: int = int(num_bins) + self._add_state( + "pos_hist", + torch.zeros((self._n_tasks, self._num_bins), dtype=torch.float64), + add_window_state=False, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "neg_hist", + torch.zeros((self._n_tasks, self._num_bins), dtype=torch.float64), + add_window_state=False, + dist_reduce_fx="sum", + persistent=True, + ) + + def cumulative_sample_count(self) -> int: + """Total weighted samples in the histograms (greppable for sanity logs).""" + return int((self.pos_hist.sum() + self.neg_hist.sum()).item()) + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + if predictions is None or weights is None: + raise ValueError( + "BinnedCumulativeAUC.update requires predictions and weights" + ) + preds = predictions.float().clamp_(0.0, 1.0) # (n_tasks, n_examples) + labels = labels.float() + weights = weights.float() + # Bin index per example; the top edge (p==1.0) folds into the last bin. + idx = (preds * self._num_bins).long().clamp_(0, self._num_bins - 1) + pos_w = (weights * labels).to(self.pos_hist.dtype) + neg_w = (weights * (1.0 - labels)).to(self.neg_hist.dtype) + self.pos_hist.scatter_add_(1, idx, pos_w) + self.neg_hist.scatter_add_(1, idx, neg_w) + + def _compute(self) -> List[MetricComputationReport]: + # By compute() time torchmetrics has all-reduced (summed) the histograms + # across ranks, so these are the global per-bin masses. + pos = self.pos_hist # (n_tasks, num_bins) + neg = self.neg_hist + total_pos = pos.sum(dim=1) + total_neg = neg.sum(dim=1) + # Lower bin index == lower score. A positive in bin b outranks every + # negative in bins < b (exclusive prefix sum), and ties in bin b score + # 0.5. AUC = sum_b pos_b * (neg_below_b + 0.5*neg_b) / (P * N). + neg_below = torch.cumsum(neg, dim=1) - neg + numerator = (pos * (neg_below + 0.5 * neg)).sum(dim=1) + denom = total_pos * total_neg + auc = torch.where( + denom > 0, + numerator / denom, + torch.full_like(numerator, 0.5), + ).to(torch.float32) + return [ + MetricComputationReport( + name=MetricName.AUC, + metric_prefix=MetricPrefix.LIFETIME, + value=auc, + ) + ] + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("utils") + + +def _trim_warmup_from_trace(path: str, keep_n_active: int) -> None: + """Post-process a chrome trace to drop events from WARMUP-phase steps. + + torch.profiler captures events during BOTH the WARMUP and RECORD phases + of a schedule and writes them all to the exported trace. There is no + built-in flag to exclude WARMUP from the export. We approximate it by: + + 1) Finding all ``ProfilerStep#N`` spans in the file. + 2) Keeping only the last ``keep_n_active`` of them (sorted by start + timestamp) as the "active" range. + 3) Filtering ``traceEvents`` to events whose ``ts`` falls inside that + range. Metadata events (``ph='M'``) are always preserved. + + Mutates the file in place. + """ + import json as _json + with open(path) as f: + d = _json.load(f) + events = d.get("traceEvents", []) + + # ProfilerStep spans mark training-step boundaries; we filter by their + # time ranges rather than by name index because step numbering can offset + # between schedule_fn argument and the value printed in the trace. + # torch.profiler emits one ProfilerStep#N span per CPU thread that ran + # during that step, so dedupe by name first so "5 active steps" means + # 5 distinct step numbers, not 5 spans. + name_to_span: Dict[str, tuple] = {} + for e in events: + nm = e.get("name", "") + if "ProfilerStep" not in nm or e.get("ph") != "X" or "ts" not in e: + continue + ts = e["ts"] + end = ts + e.get("dur", 0) + prev = name_to_span.get(nm) + if prev is None: + name_to_span[nm] = (ts, end) + else: + name_to_span[nm] = (min(prev[0], ts), max(prev[1], end)) + if len(name_to_span) <= keep_n_active: + return + sorted_spans = sorted(name_to_span.values()) + active = sorted_spans[-keep_n_active:] + t_start = min(s for s, _ in active) + t_end = max(e for _, e in active) + + def _keep(e: dict) -> bool: + if e.get("ph") == "M": + return True + ts = e.get("ts") + if ts is None: + return True + return t_start <= ts < t_end + + kept = [e for e in events if _keep(e)] + d["traceEvents"] = kept + with open(path, "w") as f: + _json.dump(d, f) + logger.warning( + f"Trimmed WARMUP events from {path}: {len(events):,} -> {len(kept):,} " + f"(kept active range [{t_start:.0f}, {t_end:.0f}] us)" + ) + + +# GPU activity categories used to detect GPU stream rows and their busy time. +_GPU_KERNEL_CATS = frozenset({"kernel", "gpu_memcpy", "gpu_memset"}) + + +def _is_rocm() -> bool: + """True on ROCm/AMD builds (``torch.version.hip`` set), False on CUDA/B200. + + The ProfilerStep-layout normalization and the sub-us kernel de-overlap are + workarounds for how roctracer projects annotations/kernels onto HIP streams; + CUDA/CUPTI traces don't have those artifacts, so these passes must be skipped + on NVIDIA to avoid touching otherwise-correct traces. + """ + return getattr(torch.version, "hip", None) is not None + + +def _normalize_profilerstep_layout(path: str) -> None: + """Collapse fragmented GPU-side ``ProfilerStep#N`` spans into one span/step. + + ``torch.profiler`` emits ``ProfilerStep#N`` as a CPU ``user_annotation`` that + Kineto projects onto the GPU timeline as ``gpu_user_annotation`` spans. On + CUDA the blocking H2D copy shares the compute stream, so each step projects + onto a single GPU stream and renders as one full-width span. On ROCm a + blocking H2D copy lands on HIP's null stream (a different stream than the + non-null compute stream), so the step splits across two GPU rows and looks + truncated in Perfetto — a pure rendering artifact (every kernel is still + captured, and the underlying GPU is busy for the whole step). + + This rewrites each per-step GPU ``ProfilerStep`` annotation to a single span + on the rank's busiest (compute) GPU stream, covering the kernel extent inside + that step's CPU window. Works on a raw per-rank trace (GPU streams are tids + under one pid) by keying the busiest stream on ``(pid, tid)``. No-op when the + annotation already lives on a single GPU stream (the CUDA case), so it is + safe to run on every platform. Mutates the file in place. + """ + import json as _json + + with open(path) as f: + d = _json.load(f) + events = d.get("traceEvents", []) + + # Per (pid,tid) GPU busy time -> identify the busiest = compute stream. + stream_busy: Dict[tuple, int] = {} + for e in events: + if e.get("ph") == "X" and e.get("cat") in _GPU_KERNEL_CATS: + dur = e.get("dur", 0) + if dur > 0: + key = (e.get("pid"), e.get("tid")) + stream_busy[key] = stream_busy.get(key, 0) + dur + if not stream_busy: + return + busiest = max(stream_busy, key=lambda k: stream_busy[k]) + + # Existing GPU-side ProfilerStep spans and the streams they sit on. + gpu_ps_streams = set() + template = None + for e in events: + if e.get("cat") == "gpu_user_annotation" and str( + e.get("name", "") + ).startswith("ProfilerStep"): + gpu_ps_streams.add((e.get("pid"), e.get("tid"))) + if template is None: + template = e + # No fragmentation (single stream or none) -> leave the trace untouched. + if len(gpu_ps_streams) <= 1: + return + + # CPU ProfilerStep windows: step name -> [min ts, max end]. + cpu_win: Dict[str, list] = {} + for e in events: + if ( + e.get("cat") == "user_annotation" + and e.get("ph") == "X" + and str(e.get("name", "")).startswith("ProfilerStep") + ): + ts = e.get("ts", 0) + end = ts + e.get("dur", 0) + w = cpu_win.get(e["name"]) + if w is None: + cpu_win[e["name"]] = [ts, end] + else: + w[0] = min(w[0], ts) + w[1] = max(w[1], end) + + # GPU kernel extents (any stream) for clamping each step's span. + gpu_kernels = [ + (e.get("ts", 0), e.get("ts", 0) + e.get("dur", 0)) + for e in events + if e.get("ph") == "X" + and e.get("cat") in _GPU_KERNEL_CATS + and e.get("dur", 0) > 0 + ] + + new_spans = [] + for sname, (cs, ce) in cpu_win.items(): + ks = [(ts, end) for ts, end in gpu_kernels if end > cs and ts < ce] + if not ks: + continue + gmin = min(ts for ts, _ in ks) + gmax = max(end for _, end in ks) + span = dict(template) if template else {} + span.update( + { + "ph": "X", + "cat": "gpu_user_annotation", + "name": sname, + "pid": busiest[0], + "tid": busiest[1], + "ts": gmin, + "dur": gmax - gmin, + "args": {"normalized_profilerstep": True}, + } + ) + new_spans.append(span) + + if not new_spans: + return + + out = [ + e + for e in events + if not ( + e.get("cat") == "gpu_user_annotation" + and str(e.get("name", "")).startswith("ProfilerStep") + ) + ] + dropped = len(events) - len(out) + out.extend(new_spans) + d["traceEvents"] = out + with open(path, "w") as f: + _json.dump(d, f) + logger.warning( + f"Normalized GPU ProfilerStep layout in {path}: dropped {dropped} " + f"fragmented span(s) across {len(gpu_ps_streams)} stream(s), wrote " + f"{len(new_spans)} span(s) on busiest stream pid={busiest[0]} " + f"tid={busiest[1]}" + ) + + +def _deoverlap_gpu_slices(path: str, max_snap_us: float = 5.0) -> None: + """Remove sub-microsecond kernel overlaps that break Perfetto's renderer. + + Perfetto draws all ``ph=="X"`` slices on a single track (one ``(pid, tid)``) + as a strict nested stack ordered by start time: a slice that *opens* while a + previous slice on the same track is still open is treated as that slice's + child and is **clipped to the parent's end**. ROCm's roctracer reports + per-stream kernel timestamps at ns granularity, so two back-to-back kernels + on the same compute stream occasionally overlap by a fraction of a + microsecond (e.g. an 88 ns ``elementwise`` epilogue ending 0.075 us *after* + the next 21 ms ``_hstu_attn_bwd`` kernel begins). Perfetto then nests the + long kernel inside the tiny one and clips it to a sub-pixel sliver, so the + kernel "disappears" from the timeline even though it is fully present in the + JSON. + + This pulls each slice's end back to just *before* the next slice's start + whenever they overlap by less than ``max_snap_us`` (a measurement artifact, + not real concurrency — kernels on one stream are serialized), leaving genuine + nesting (a small kernel fully contained in a larger one) untouched. The + adjustment is sub-microsecond and does not change any reported duration + meaningfully. Mutates the file in place; best-effort. + + Critically, the slices are separated by a tiny ``_GAP_US`` (~1 ns) rather + than snapped to an *exactly equal* end==start timestamp. A coincident + end==start is just as fatal as an overlap in Perfetto: it nests the next + slice inside the previous one and clips it to zero width (this is the ~1 ns + gap that roctracer leaves between cleanly-rendered back-to-back kernels). So + we also fix exact-touch (``a_end == b.ts``) boundaries, not just overlaps. + """ + import json as _json + from collections import defaultdict + + # ~1 ns. Matches the natural inter-kernel gap roctracer leaves between + # back-to-back kernels that Perfetto already renders correctly. Must be + # strictly > 0 so end != start after the nudge. + _GAP_US = 0.001 + + with open(path) as f: + d = _json.load(f) + events = d.get("traceEvents", []) + + tracks: Dict[tuple, list] = defaultdict(list) + for e in events: + if ( + e.get("ph") == "X" + and e.get("cat") in _GPU_KERNEL_CATS + and e.get("dur", 0) > 0 + ): + tracks[(e.get("pid"), e.get("tid"))].append(e) + + snapped = 0 + max_clip = 0.0 + for sl in tracks.values(): + # Sort by start, then longest-first so a container precedes the slices + # it nests; consecutive pairs are then either disjoint, properly nested, + # or a tiny artifact overlap. + sl.sort(key=lambda e: (e["ts"], -e["dur"])) + for i in range(len(sl) - 1): + a = sl[i] + b = sl[i + 1] + a_end = a["ts"] + a["dur"] + b_end = b["ts"] + b["dur"] + # Touching (a_end == b.ts) or partial overlap (a ends inside b) both + # break rendering; true containment (a_end >= b_end) is valid nesting + # and is left alone. + if b["ts"] <= a_end < b_end: + desired_end = b["ts"] - _GAP_US + clip = a_end - desired_end + if a["ts"] < desired_end and 0 < clip < max_snap_us: + a["dur"] = desired_end - a["ts"] + snapped += 1 + if clip > max_clip: + max_clip = clip + + if snapped: + with open(path, "w") as f: + _json.dump(d, f) + logger.warning( + f"De-overlapped GPU slices in {path}: snapped {snapped} sub-us " + f"overlap(s) (max {max_clip:.3f}us) so Perfetto renders every kernel" + ) + + +def _deoverlap_gpu_annotations(path: str, max_snap_us: float = 5.0) -> None: + """Separate touching/overlapping *sibling* GPU annotations so Perfetto draws + each one full width (the B200-style stacked layout). + + Same root cause as :func:`_deoverlap_gpu_slices`, but at the annotation + boundary instead of the kernel boundary. The forward/backward phase + annotations Kineto projects onto the GPU stream (``## item_forward ##``, + ``## user_forward ##``, ``## multitask_module ##``, the ``## stu_* ##`` + pairs, ...) are emitted as a chain of siblings laid end-to-end: each is meant + to end exactly where the next begins. Perfetto stores timestamps as int64 ns, + and the absolute step timestamps are ~5.4e12 us where a float64's quantum is + already ~1 ns, so a sibling boundary that should be coincident instead lands + a few ns off. When the earlier sibling's end falls *at or after* the next + sibling's start, Perfetto nests the next sibling inside it and clips it to a + sub-pixel sliver — so e.g. the 100+ ms ``## user_forward ##`` span vanishes on + some ranks/steps and renders on others purely by rounding luck. + + Unlike kernels (all flat on one stream), annotations form a real nesting + hierarchy — ``## user_forward ##`` legitimately *contains* the ``## stu_* ##`` + spans and their kernels — so this cannot blindly snap consecutive slices. It + walks the per-track slice stack (sorted by start, longest-first) and only + snaps a slice ``a`` back when the next slice ``b`` is **not** contained in it + (``b`` extends beyond ``a``'s end), i.e. they are siblings rather than + parent/child. Real containment is left untouched, and a snap is skipped if it + would clip into ``a``'s own descendants (kernels or child annotations). + Mutates the file in place; best-effort. Run after :func:`_deoverlap_gpu_slices` + so kernel boundaries are already clean. + """ + import json as _json + from collections import defaultdict + + # ~2 ns. The annotation boundaries sit at ~5.4e12 us where a float64's + # quantum is ~0.98 ns, so a 1 ns nudge can round back onto the neighbour's + # timestamp (an exact touch, which Perfetto still nests+clips). 2 ns (~2 + # quanta) reliably separates them and is still far below any visible width. + _GAP_US = 0.002 + + with open(path) as f: + d = _json.load(f) + events = d.get("traceEvents", []) + + # Stack the full per-track hierarchy over BOTH kernels and annotations so a + # parent annotation knows the extent of its descendants (the snap guard), + # but only annotation slices are ever trimmed. + _ANN = "gpu_user_annotation" + tracks: Dict[tuple, list] = defaultdict(list) + for e in events: + if ( + e.get("ph") == "X" + and e.get("dur", 0) > 0 + and (e.get("cat") in _GPU_KERNEL_CATS or e.get("cat") == _ANN) + ): + tracks[(e.get("pid"), e.get("tid"))].append(e) + + snapped = 0 + max_clip = 0.0 + for sl in tracks.values(): + # Longest-first on ties so a container precedes the slices it nests. + sl.sort(key=lambda e: (e["ts"], -e["dur"])) + # Each frame: [event, max_descendant_end]. The stack holds the chain of + # currently-open ancestors for the slice being placed. + stack: list = [] + for b in sl: + b_ts = b["ts"] + b_end = b_ts + b["dur"] + while stack: + a = stack[-1][0] + a_end = a["ts"] + a["dur"] + if a_end < b_ts: + # a closed strictly before b begins -> disjoint sibling, pop. + frame = stack.pop() + eff = frame[0]["ts"] + frame[0]["dur"] + if stack: + stack[-1][1] = max(stack[-1][1], eff, frame[1]) + continue + if a_end < b_end: + # b starts at/inside a but extends past a's end => they are + # siblings (not parent/child), and a's tail nests+clips b in + # Perfetto. Snap a's end to just before b. This fires for both + # annotation tails (## item_forward ## overhanging + # ## user_forward ##) and kernel tails that straddle an + # annotation boundary (a layer-norm kernel ending a few ns + # past the start of the next phase span) -- both are sub-us + # roctracer/rounding artifacts, since kernels on one stream + # are serialized and phase spans are sequential. + desired_end = b_ts - _GAP_US + clip = a_end - desired_end + # Guard: only snap when a's deepest descendant ends at or + # before b's start. If a child (kernel or nested span) + # actually extends *past* b.ts, trimming a wouldn't fix b's + # clipping (the child would still nest b) and could drop a + # real child into b's territory, so leave it. A descendant + # ending exactly at the boundary is itself rounding noise and + # is clipped by <=1 ns, which is fine. + if ( + a["ts"] < desired_end + and stack[-1][1] <= b_ts + and 0 < clip < max_snap_us + ): + a["dur"] = desired_end - a["ts"] + snapped += 1 + if clip > max_clip: + max_clip = clip + frame = stack.pop() + eff = frame[0]["ts"] + frame[0]["dur"] + if stack: + stack[-1][1] = max(stack[-1][1], eff, frame[1]) + continue + # a_end >= b_end: a fully contains b -> b is a child, stop. + break + stack.append([b, b_ts]) + + if snapped: + with open(path, "w") as f: + _json.dump(d, f) + logger.warning( + f"De-overlapped GPU annotations in {path}: snapped {snapped} sub-us " + f"sibling overlap(s) (max {max_clip:.3f}us) so Perfetto renders every " + f"annotation full width" + ) + + +def _on_trace_ready_fn( + rank: Optional[int] = None, + trace_dir: str = "/tmp/dlrm_v3_traces", + keep_n_active: Optional[int] = None, + trace_steps: Optional[List[int]] = None, +) -> Callable[[torch.profiler.profile], None]: + """Create the on_trace_ready callback that exports a chrome trace to disk. + + Filename follows ``trace_step{step}_rank{rank}.json`` so multi-rank + captures don't collide and ``scripts/stitch_traces.py`` can merge them + by step number. + + The ``{step}`` label: + + * If ``trace_steps`` is provided (multi-window mode), the Nth callback + invocation labels its file with ``trace_steps[N]`` -- i.e. the + user-requested step that triggered the window. This is the most + intuitive labelling. + * Otherwise falls back to ``p.step_num`` (torch.profiler's internal + counter at trigger time, off by ~warmup+active from the schedule + arg). + + If ``keep_n_active`` is set, the exported file is post-processed to keep + only the last N ProfilerStep-spans worth of events (i.e. drop WARMUP). + """ + state = {"fire_count": 0} + + def handle_fn(p: torch.profiler.profile) -> None: + os.makedirs(trace_dir, exist_ok=True) + if trace_steps: + i = state["fire_count"] + step_label = ( + trace_steps[i] if i < len(trace_steps) else getattr(p, "step_num", 0) + ) + else: + step_label = getattr(p, "step_num", 0) + state["fire_count"] += 1 + rank_str = f"_rank{rank}" if rank is not None else "" + file_name = f"trace_step{step_label}{rank_str}.json" + path = os.path.join(trace_dir, file_name) + logger.warning( + p.key_averages(group_by_input_shape=True).table( + sort_by="self_cuda_time_total" + ) + ) + # Tracing is best-effort: a write/trim failure (permissions, disk full, + # malformed export) must never crash the training run. Degrade to a + # warning so the loop continues — especially important since streaming + # enables output_trace by default. + try: + p.export_chrome_trace(path) + logger.warning(f"Trace written to: {path}") + if keep_n_active is not None and keep_n_active > 0: + _trim_warmup_from_trace(path, keep_n_active) + # ROCm/AMD-only rendering fixes. CUDA/CUPTI (e.g. B200) traces don't + # exhibit the fragmented-ProfilerStep or sub-us kernel-overlap + # artifacts, so skip entirely on NVIDIA to avoid touching otherwise + # correct traces. Best-effort like trim above. + if _is_rocm(): + # Normalize the GPU-side ProfilerStep layout so ROCm traces + # render with one full-width step span per stream like CUDA. + _normalize_profilerstep_layout(path) + # Snap roctracer's sub-us kernel overlaps so Perfetto doesn't + # mis-nest and hide long kernels. + _deoverlap_gpu_slices(path) + # Same fix at the annotation-sibling boundary so phase spans + # (## user_forward ##, ## stu_* ##, ...) render full width. + _deoverlap_gpu_annotations(path) + except Exception as exc: + logger.warning(f"Trace export/trim failed for {path}: {exc!r} (skipping)") + + return handle_fn + + +def profiler_or_nullcontext( + enabled: bool, with_stack: bool, trace_dir: str = "/tmp/dlrm_v3_traces" +): + """One-shot profile context for ad-hoc captures (no scheduling).""" + return ( + profile( + # pyre-fixme[16]: Module `profiler` has no attribute `ProfilerActivity`. + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + on_trace_ready=_on_trace_ready_fn(trace_dir=trace_dir), + with_stack=with_stack, + ) + if enabled + else contextlib.nullcontext() + ) + + +def _multi_window_schedule( + trace_steps, + warmup: int, + active: int, +): + """Custom schedule that profiles around each step in ``trace_steps``. + + Step s gets: + [s - warmup, s) -> WARMUP + [s, s + active - 1) -> RECORD + s + active - 1 -> RECORD_AND_SAVE + """ + windows = [(s - warmup, s, s + active) for s in sorted(trace_steps)] + + def schedule_fn(step: int) -> torch.profiler.ProfilerAction: + for warmup_start, active_start, active_end in windows: + if warmup_start <= step < active_start: + return torch.profiler.ProfilerAction.WARMUP + if active_start <= step < active_end - 1: + return torch.profiler.ProfilerAction.RECORD + if step == active_end - 1: + return torch.profiler.ProfilerAction.RECORD_AND_SAVE + return torch.profiler.ProfilerAction.NONE + + return schedule_fn + + +@gin.configurable +class Profiler: + """Scheduled torch.profiler wrapper that writes Chrome traces to disk. + + Two modes (set via gin): + + * Single window (default): ``wait=10, warmup=20, active=50, repeat=1``. + Captures one contiguous window starting after ``wait`` steps. + * Multi-window: ``trace_steps=[500, 1000, 5000]`` (overrides wait+repeat). + Captures a separate window around each listed step. + + All knobs are gin-tunable, e.g. in a gin file:: + + Profiler.trace_dir = "/path/to/results/exp42/trace" + Profiler.trace_steps = [500, 1000, 5000] + Profiler.warmup = 5 + Profiler.active = 10 + """ + + def __init__( + self, + rank: int, + active: int = 50, + wait: int = 10, + warmup: int = 20, + repeat: int = 1, + trace_steps: Optional[List[int]] = None, + trace_dir: str = "/tmp/dlrm_v3_traces", + trim_warmup: bool = True, + record_shapes: bool = True, + profile_memory: bool = False, + with_stack: bool = False, + with_flops: bool = False, + with_modules: bool = False, + ) -> None: + self.rank = rank + self.trace_dir = trace_dir + if trace_steps: + sched = _multi_window_schedule(trace_steps, warmup, active) + else: + sched = torch.profiler.schedule( + wait=wait, warmup=warmup, active=active, repeat=repeat + ) + keep_n = active if trim_warmup else None + self._profiler: profiler.profile = torch.profiler.profile( + schedule=sched, + on_trace_ready=_on_trace_ready_fn( + self.rank, trace_dir, keep_n, trace_steps + ), + # pyre-fixme[16]: Module `profiler` has no attribute `ProfilerActivity`. + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=record_shapes, + profile_memory=profile_memory, + with_stack=with_stack, + with_flops=with_flops, + with_modules=with_modules, + ) + + def step(self) -> None: + """Advance the profiler to the next step.""" + self._profiler.step() + + +class _NoOpSummaryWriter: + """Drop-in stand-in for SummaryWriter used when TensorBoard is disabled + (empty ``tensorboard_log_path``). All scalar writes become no-ops so the + metrics path (compute + text-log + ``.metrics.jsonl`` sinks) runs unchanged + and never crashes on TB file I/O. The shared ``/apps`` tfevents writer was a + crash source under transient filer ``Errno 121`` (Remote I/O) errors, and + nothing we consume reads TensorBoard — eval-window AUCs come from the JSONL + sink and the text log.""" + + def add_scalar(self, *args, **kwargs) -> None: + pass + + def flush(self) -> None: + pass + + def close(self) -> None: + pass + + +@gin.configurable +class MetricsLogger: + """ + Logger for tracking and computing recommendation metrics. + + Supports both classification metrics (NE, Accuracy, GAUC) and regression + metrics (MSE, MAE) based on multitask configuration. + + Args: + multitask_configs: List of task configurations defining metric types. + batch_size: Batch size for metric computation. + window_size: Window size for running metric aggregation. + device: Device to place metric tensors on. + rank: Process rank for distributed training. + tensorboard_log_path: Optional path for TensorBoard logging. + """ + + def __init__( + self, + multitask_configs: List[TaskConfig], + batch_size: int, + window_size: int, + device: torch.device, + rank: int, + tensorboard_log_path: str = "", + world_size: int = 1, + auc_threshold: Optional[float] = None, + num_flops_per_sample: float = 0.0, + gpu_peak_flops: float = 0.0, + model: Optional[torch.nn.Module] = None, + eval_cumulative: bool = False, + cumulative_auc_bins: int = 100_000, + train_lifetime_auc_mode: str = "binned", + eval_lifetime_auc_mode: str = "binned", + lifetime_auc_window: int = 10_000_000, + ) -> None: + # tflops/mfu reporting state (optional — when both num_flops_per_sample + # and gpu_peak_flops are set, the train perf line gains tflops_algo/gpu, + # mfu, tflops_real/gpu, hfu, fill. The jagged ("real") numbers come + # from `model._last_jagged_flops_per_sample` stashed by DlrmHSTU.main_forward. + self._num_flops_per_sample: float = max(0.0, float(num_flops_per_sample)) + self._gpu_peak_flops: float = max(0.0, float(gpu_peak_flops)) + self._model_ref: Optional[torch.nn.Module] = model + if rank == 0 and self._num_flops_per_sample > 0 and self._gpu_peak_flops > 0: + logger.info( + f"FLOPS reporting enabled: {self._num_flops_per_sample / 1e9:.1f} " + f"GFLOP/sample (dense fwd+bwd), GPU peak {self._gpu_peak_flops / 1e12:.0f} TFLOPS" + ) + self.multitask_configs: List[TaskConfig] = multitask_configs + all_classification_tasks: List[str] = [ + task.task_name + for task in self.multitask_configs + if task.task_type != MultitaskTaskType.REGRESSION + ] + all_regression_tasks: List[str] = [ + task.task_name + for task in self.multitask_configs + if task.task_type == MultitaskTaskType.REGRESSION + ] + assert all_classification_tasks + all_regression_tasks == [ + task.task_name for task in multitask_configs + ] + self.task_names: List[str] = all_classification_tasks + all_regression_tasks + + # Eval metric semantics: + # eval_cumulative=False (default, legacy / static / non-streaming eval): + # a single eval set with the configured window_size, including a + # lifetime AUC. Unchanged behavior. + # eval_cumulative=True (streaming fixed-holdout eval): a FRESH eval set + # (window_size=UNBOUNDED, reset each pass -> per-pass full-holdout + # "window_*") PLUS a CUMULATIVE set ("eval_cum", never reset -> + # "lifetime_*"). NE/Accuracy/GAUC are cumulative for free via their + # persistent scalar sums; AUC cumulative uses the selected backend. + # + # Lifetime-AUC backend is configurable independently for train and eval: + # "binned" (default): BinnedCumulativeAUC - exact-cumulative AUC via an + # O(num_bins) score histogram (additive all-reduce, no unbounded + # buffer, memory independent of #samples/#windows). + # "capped": LifetimeAUCMetricComputation - AUC over a trailing buffer of + # `lifetime_auc_window` samples/rank (the legacy approach; per-rank + # buffer all-gathered at compute). + self._eval_cumulative: bool = eval_cumulative + self._cumulative_auc_bins: int = int(cumulative_auc_bins) + self._train_lifetime_auc_mode: str = str(train_lifetime_auc_mode) + self._eval_lifetime_auc_mode: str = str(eval_lifetime_auc_mode) + self._lifetime_auc_window: int = int(lifetime_auc_window) + n_cls = len(all_classification_tasks) + n_reg = len(all_regression_tasks) + + def _make_lifetime_auc(mode: str) -> RecMetricComputation: + if mode == "binned": + # window_size=0: no torchrec windowed state; histograms only. + return BinnedCumulativeAUC( + my_rank=rank, batch_size=batch_size, n_tasks=n_cls, + window_size=0, num_bins=self._cumulative_auc_bins, + ).to(device) + if mode == "capped": + return LifetimeAUCMetricComputation( + my_rank=rank, batch_size=batch_size, n_tasks=n_cls, + window_size=self._lifetime_auc_window, + ).to(device) + raise ValueError( + f"lifetime_auc_mode must be 'binned' or 'capped', got {mode!r}" + ) + + def _make_class(ws: int, lifetime_mode: Optional[str]) -> List[RecMetricComputation]: + mets: List[RecMetricComputation] = [ + NEMetricComputation(my_rank=rank, batch_size=batch_size, n_tasks=n_cls, window_size=ws).to(device), + AccuracyMetricComputation(my_rank=rank, batch_size=batch_size, n_tasks=n_cls, window_size=ws).to(device), + GAUCMetricComputation(my_rank=rank, batch_size=batch_size, n_tasks=n_cls, window_size=ws).to(device), + AUCMetricComputation(my_rank=rank, batch_size=batch_size, n_tasks=n_cls, window_size=ws).to(device), + ] + if lifetime_mode is not None: + mets.append(_make_lifetime_auc(lifetime_mode)) + return mets + + def _make_class_cumulative() -> List[RecMetricComputation]: + # NE/Accuracy/GAUC: cumulative via persistent lifetime sums (window + # value ignored at compute). AUC: selected lifetime backend. + return [ + NEMetricComputation(my_rank=rank, batch_size=batch_size, n_tasks=n_cls, window_size=window_size).to(device), + AccuracyMetricComputation(my_rank=rank, batch_size=batch_size, n_tasks=n_cls, window_size=window_size).to(device), + GAUCMetricComputation(my_rank=rank, batch_size=batch_size, n_tasks=n_cls, window_size=window_size).to(device), + _make_lifetime_auc(self._eval_lifetime_auc_mode), + ] + + def _make_reg(ws: int) -> List[RecMetricComputation]: + return [ + MSEMetricComputation(my_rank=rank, batch_size=batch_size, n_tasks=n_reg, window_size=ws).to(device), + MAEMetricComputation(my_rank=rank, batch_size=batch_size, n_tasks=n_reg, window_size=ws).to(device), + ] + + self.class_metrics: Dict[str, List[RecMetricComputation]] = {"train": [], "eval": []} + self.regression_metrics: Dict[str, List[RecMetricComputation]] = {"train": [], "eval": []} + if eval_cumulative: + self.class_metrics["eval_cum"] = [] + self.regression_metrics["eval_cum"] = [] + + if all_classification_tasks: + self.class_metrics["train"] = _make_class(window_size, lifetime_mode=self._train_lifetime_auc_mode) + if eval_cumulative: + self.class_metrics["eval"] = _make_class(UNBOUNDED_WINDOW, lifetime_mode=None) + self.class_metrics["eval_cum"] = _make_class_cumulative() + else: + self.class_metrics["eval"] = _make_class(window_size, lifetime_mode=self._eval_lifetime_auc_mode) + + if all_regression_tasks: + self.regression_metrics["train"] = _make_reg(window_size) + if eval_cumulative: + self.regression_metrics["eval"] = _make_reg(UNBOUNDED_WINDOW) + self.regression_metrics["eval_cum"] = _make_reg(window_size) + else: + self.regression_metrics["eval"] = _make_reg(window_size) + + self.global_step: Dict[str, int] = {"train": 0, "eval": 0} + # MLPerf `samples_count` progress unit: global trained samples, persisted + # alongside global_step so a resumed run continues the count. + self.cumulative_train_samples: int = 0 + # Whether the MLPerf run markers (RUN_START etc.) were already emitted for + # this logical run. Checkpointed so a resume relaunch knows the run is + # already open and does NOT re-emit INIT_START/RUN_START (the compliance + # checker requires EXACTLY_ONE); the resumed process continues the same + # event stream and emits the single RUN_STOP at convergence/end. + self.mlperf_run_started: bool = False + self._rank: int = int(rank) + # Optional MLPerf logger + LR accessor wired by the streaming loop (duck- + # typed to avoid a train-module import cycle); drives the train_loss event. + self.mlperf_logger: Optional[Any] = None + self.lr_getter: Optional[Callable[[], float]] = None + self.tb_logger: Optional[SummaryWriter] = None + if tensorboard_log_path != "": + self.tb_logger = SummaryWriter(log_dir=tensorboard_log_path, purge_step=0) + self.tb_logger.flush() + else: + # TB disabled: use a no-op writer so the existing call sites (and the + # `assert self.tb_logger is not None` in compute_and_log) keep working + # while no tfevents are written to the fragile shared filer. + self.tb_logger = _NoOpSummaryWriter() + + # Throughput / time-to-target tracking. Counters are train-only; eval + # samples are not relevant for headline samples/sec numbers. + self._world_size: int = max(1, int(world_size)) + self._auc_threshold: Optional[float] = auc_threshold + self._time_to_target_logged: bool = False + self._perf_t_start: float = time.perf_counter() + self._perf_t_window: float = self._perf_t_start + self._perf_steps_in_window: int = 0 + self._perf_total_samples: int = 0 + self._perf_samples_counter: torch.Tensor = torch.zeros( + 1, dtype=torch.long, device=device + ) + # Non-train wall-time to exclude from the train step-time window so + # `step_ms` reports the canonical per-step compute latency even when an + # interval coincides with eval or checkpointing. Categorized so the + # excluded time is also reportable (eval_ms / ckpt_ms) rather than just + # discarded. The trainer brackets eval/ckpt regions via + # pause_perf(cat)/resume_perf(cat); accumulators reset each train-perf log. + self._perf_excluded: Dict[str, float] = {"eval": 0.0, "ckpt": 0.0} + self._perf_pause: Dict[str, Optional[float]] = {} + + def pause_perf(self, category: str) -> None: + """Start excluding wall-time under `category` (e.g. "eval"/"ckpt") from + the train step-time window. Idempotent: a second pause without an + intervening resume is a no-op (keeps the earliest start).""" + if self._perf_pause.get(category) is None: + self._perf_pause[category] = time.perf_counter() + + def resume_perf(self, category: str) -> None: + """Stop excluding `category` and fold the elapsed interval into the + per-category accumulator. No-op if not currently paused.""" + t0 = self._perf_pause.get(category) + if t0 is not None: + self._perf_excluded[category] = ( + self._perf_excluded.get(category, 0.0) + + (time.perf_counter() - t0) + ) + self._perf_pause[category] = None + + @property + def auc_threshold(self) -> Optional[float]: + """Configured time-to-target AUC threshold (None if unset).""" + return self._auc_threshold + + @property + def all_metrics(self) -> Dict[str, List[RecMetricComputation]]: + """ + Get all metrics for train and eval modes. + + Returns: + Dictionary mapping mode ('train'/'eval') to list of metric computations. + """ + out = { + "train": self.class_metrics["train"] + self.regression_metrics["train"], + "eval": self.class_metrics["eval"] + self.regression_metrics["eval"], + } + if "eval_cum" in self.class_metrics or "eval_cum" in self.regression_metrics: + out["eval_cum"] = self.class_metrics.get( + "eval_cum", [] + ) + self.regression_metrics.get("eval_cum", []) + return out + + def update( + self, + predictions: torch.Tensor, + weights: torch.Tensor, + labels: torch.Tensor, + num_candidates: torch.Tensor, + mode: str = "train", + ) -> None: + """ + Update metrics with new batch of predictions and labels. + + Args: + predictions: Model prediction tensor. + weights: Sample weight tensor. + labels: Ground truth label tensor. + num_candidates: Number of candidates per sample (for GAUC). + mode: Either 'train' or 'eval'. + """ + # On eval, update BOTH the fresh set and the never-reset cumulative set + # (if enabled) from the same batch. + update_targets = list(self.all_metrics[mode]) + if mode == "eval" and "eval_cum" in self.all_metrics: + update_targets = update_targets + self.all_metrics["eval_cum"] + for metric in update_targets: + if isinstance(metric, GAUCMetricComputation): + metric.update( + predictions=predictions, + labels=labels, + weights=weights, + num_candidates=num_candidates, + ) + else: + metric.update( + predictions=predictions, + labels=labels, + weights=weights, + ) + self.global_step[mode] += 1 + if mode == "train": + # Accumulate on-device to avoid a per-step GPU->CPU sync; we read + # the counter only at compute_and_log boundaries. + self._perf_samples_counter += num_candidates.sum().to( + self._perf_samples_counter.dtype + ) + self._perf_steps_in_window += 1 + # MLPerf progress counter: per-rank sample count scaled by world size + # approximates global trained samples without an extra collective. + self.cumulative_train_samples += int(num_candidates.numel()) * self._world_size + + def compute(self, mode: str = "train") -> Dict[str, float]: + """ + Compute and return all metrics for the current window. + + Args: + mode: Either 'train' or 'eval'. + + Returns: + Dictionary mapping metric names to their computed values. + """ + all_computed_metrics = {} + + if mode == "eval" and "eval_cum" in self.all_metrics: + # Dual-set eval: `window_*` (fresh per-pass) from the reset-each-pass + # set; `lifetime_*` (cumulative across passes) from the never-reset + # set. Filter each set to the matching prefix, and drop GAUC's + # auxiliary `*_num_samples` reports. Key names are unchanged + # (`window_auc`, `lifetime_ne`, ...) so dashboards keep working. + def _emit( + metrics: List[RecMetricComputation], keep_prefix: str + ) -> None: + for metric in metrics: + for computed in metric.compute(): + pfx = str(computed.metric_prefix) + name = str(computed.name) + if pfx != keep_prefix or name.endswith("num_samples"): + continue + all_values = computed.value.cpu() + for i, task_name in enumerate(self.task_names): + if i >= len(all_values): + break + all_computed_metrics[f"metric/{pfx}{name}/{task_name}"] = ( + all_values[i] + ) + + _emit(self.all_metrics["eval"], "window_") + _emit(self.all_metrics["eval_cum"], "lifetime_") + else: + for metric in self.all_metrics[mode]: + computed_metrics = metric.compute() + for computed in computed_metrics: + all_values = computed.value.cpu() + for i, task_name in enumerate(self.task_names): + if i >= len(all_values): + break + key = f"metric/{str(computed.metric_prefix) + str(computed.name)}/{task_name}" + all_computed_metrics[key] = all_values[i] + + logger.info( + f"{mode} - Step {self.global_step[mode]} metrics: {all_computed_metrics}" + ) + return all_computed_metrics + + def _global_mean_loss(self, loss_terms: Dict[str, torch.Tensor]) -> float: + """Cross-rank mean of the summed per-task losses. + + The 1-element all-reduce MUST run on every rank in lockstep; callers gate + it on a rank-identical counter (global_step / a deterministic frequency) + so it cannot desync. + """ + loss_t = torch.stack( + [v.detach().float().sum() for v in loss_terms.values()] + ).sum() + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.all_reduce(loss_t, op=torch.distributed.ReduceOp.SUM) + loss_t = loss_t / self._world_size + return float(loss_t) + + def maybe_log_mlperf_train_loss( + self, aux_losses: Dict[str, torch.Tensor], every: int + ) -> None: + """Emit the MLPerf ``train_loss`` event on its OWN cadence. + + Decoupled from the console/TB metric cadence (``compute_and_log``): call + this every train step with the just-computed ``aux_losses`` and the + desired interval ``every`` (in global train steps). The cross-rank loss + all-reduce only fires on the cadence, gated by ``global_step["train"]`` + which is incremented identically on all ranks in ``update()`` — so the + collective stays in lockstep. No-op when MLPerf logging is disabled + (``mlperf_logger is None`` on every rank) or ``every <= 0``. + """ + if self.mlperf_logger is None or every <= 0 or not aux_losses: + return + if self.global_step["train"] % every != 0: + return + train_loss = self._global_mean_loss(aux_losses) + c = self.mlperf_logger.constants + md: Dict[str, Any] = {c.SAMPLES_COUNT: self.cumulative_train_samples} + if self.lr_getter is not None: + try: + md["lr"] = float(self.lr_getter()) + except Exception: + pass + self.mlperf_logger.event( + key=getattr(c, "TRAIN_LOSS", "train_loss"), + value=train_loss, + metadata=md, + ) + + def compute_and_log( + self, + mode: str = "train", + additional_logs: Optional[Dict[str, Dict[str, torch.Tensor]]] = None, + ) -> Dict[str, float]: + """ + Compute metrics and log to TensorBoard. + + Args: + mode: Either 'train' or 'eval'. + additional_logs: Optional additional data to log. + + Returns: + Dictionary mapping metric names to their computed values. + + Raises: + AssertionError: If TensorBoard logger is not configured. + """ + assert self.tb_logger is not None + all_computed_metrics = self.compute(mode) + for k, v in all_computed_metrics.items(): + self.tb_logger.add_scalar( # pyre-ignore [16] + f"{mode}_{k}", + v, + global_step=self.global_step[mode], + ) + + if additional_logs is not None: + for tag, data in additional_logs.items(): + for data_name, data_value in data.items(): + self.tb_logger.add_scalar( + f"{tag}/{mode}_{data_name}", + data_value.detach().clone().cpu(), + global_step=self.global_step[mode], + ) + + # Global (cross-rank mean) train loss to console/TB at the metric-logging + # cadence. The 1-element all-reduce runs on every rank in lockstep, so it + # cannot desync. The MLPerf `train_loss` EVENT is emitted separately via + # ``maybe_log_mlperf_train_loss`` so its cadence can be tuned independently + # of this console/TB cadence (see that method). + if mode == "train" and additional_logs is not None and "losses" in additional_logs: + loss_terms = additional_logs["losses"] + if loss_terms: + train_loss = self._global_mean_loss(loss_terms) + self.tb_logger.add_scalar( + "train_loss", train_loss, global_step=self.global_step["train"] + ) + if self._rank == 0: + logger.info( + f"train - Step {self.global_step['train']} " + f"train_loss={train_loss:.5f}" + ) + + # Throughput metrics (train only). One GPU->CPU sync per call. + if mode == "train" and self._perf_steps_in_window > 0: + now = time.perf_counter() + wall_dt = max(now - self._perf_t_window, 1e-6) + # Subtract bracketed eval/checkpoint wall-time so step_ms / sps / + # MFU reflect canonical train-step compute, not eval+ckpt stalls + # that happened to land in this window. The excluded time is also + # surfaced separately below (eval_ms / ckpt_ms) rather than discarded. + eval_s = self._perf_excluded.get("eval", 0.0) + ckpt_s = self._perf_excluded.get("ckpt", 0.0) + dt = max(wall_dt - eval_s - ckpt_s, 1e-6) + n_samples = int(self._perf_samples_counter.item()) + self._perf_total_samples += n_samples + local_sps = n_samples / dt + global_sps = local_sps * self._world_size + step_ms = dt * 1000.0 / self._perf_steps_in_window + wall_step_ms = wall_dt * 1000.0 / self._perf_steps_in_window + eval_ms = eval_s * 1000.0 + ckpt_ms = ckpt_s * 1000.0 + elapsed = now - self._perf_t_start + step = self.global_step["train"] + self.tb_logger.add_scalar( + "perf/train_samples_per_sec_local", local_sps, global_step=step + ) + self.tb_logger.add_scalar( + "perf/train_samples_per_sec_global", global_sps, global_step=step + ) + self.tb_logger.add_scalar( + "perf/train_step_time_ms", step_ms, global_step=step + ) + # Inclusive (old-semantics) per-step wall time and the eval/ckpt + # breakdown that was excluded from step_ms above. + self.tb_logger.add_scalar( + "perf/train_wall_step_time_ms", wall_step_ms, global_step=step + ) + self.tb_logger.add_scalar( + "perf/window_eval_time_ms", eval_ms, global_step=step + ) + self.tb_logger.add_scalar( + "perf/window_ckpt_time_ms", ckpt_ms, global_step=step + ) + self.tb_logger.add_scalar( + "perf/train_total_samples", self._perf_total_samples, global_step=step + ) + self.tb_logger.add_scalar( + "perf/train_elapsed_sec", elapsed, global_step=step + ) + # TFLOPS / MFU reporting (algo = dense yardstick, real = jagged). + # tflops_algo/gpu, mfu — uses max_seq_len^2 attention work (the + # MFU yardstick: the FLOPs the workload would do if every + # user's UIH filled the padded seq length). + # tflops_real/gpu, hfu — uses this batch's mean(s_i^2) (actual + # GPU work; hardware utilization). + # fill — real / algo as a percent; how much of + # the algo budget the model actually executed this batch. + # The jagged stash is read from the inner model; the model ref may + # be a DMP wrapper, so unwrap via .module if present. + tflops_str = "" + if self._num_flops_per_sample > 0 and self._gpu_peak_flops > 0: + local_flops = self._num_flops_per_sample * local_sps + tflops_algo = local_flops / 1e12 + mfu = 100.0 * local_flops / self._gpu_peak_flops + self.tb_logger.add_scalar("perf/train_tflops_algo_gpu", tflops_algo, global_step=step) + self.tb_logger.add_scalar("perf/train_mfu_pct", mfu, global_step=step) + tflops_str = f" tflops_algo/gpu={tflops_algo:.1f} mfu={mfu:.1f}%" + jagged_t = None + m = self._model_ref + if m is not None: + inner = m.module if hasattr(m, "module") else m + jagged_t = getattr(inner, "_last_jagged_flops_per_sample", None) + if jagged_t is not None: + jagged = float(jagged_t.item()) + if 0 < jagged < self._num_flops_per_sample: + tflops_real = jagged * local_sps / 1e12 + hfu = 100.0 * jagged * local_sps / self._gpu_peak_flops + fill = 100.0 * jagged / self._num_flops_per_sample + self.tb_logger.add_scalar("perf/train_tflops_real_gpu", tflops_real, global_step=step) + self.tb_logger.add_scalar("perf/train_hfu_pct", hfu, global_step=step) + self.tb_logger.add_scalar("perf/train_fill_pct", fill, global_step=step) + tflops_str += f" tflops_real/gpu={tflops_real:.1f} hfu={hfu:.1f}% fill={fill:.1f}%" + logger.info( + f"train - Step {step} perf: local_sps={local_sps:.1f} " + f"global_sps={global_sps:.1f} step_ms={step_ms:.2f} " + f"elapsed_sec={elapsed:.1f} total_samples={self._perf_total_samples} " + f"wall_step_ms={wall_step_ms:.2f} eval_ms={eval_ms:.1f} " + f"ckpt_ms={ckpt_ms:.1f}" + + tflops_str + ) + self._perf_t_window = now + self._perf_steps_in_window = 0 + self._perf_samples_counter.zero_() + # Reset the excluded-time accumulators for the next window. Any + # still-open pause (eval/ckpt straddling this log) is cleared so its + # remaining time is not double-counted; in practice perf logs fire + # only after a train step, never mid eval/ckpt. + self._perf_excluded = {"eval": 0.0, "ckpt": 0.0} + self._perf_pause = {} + + # Time-to-target: latch wall-clock once any task's AUC crosses threshold. + # Matches MLPerf DLRM-DCNv2 reporting style (default upstream target 0.80275). + if ( + self._auc_threshold is not None + and not self._time_to_target_logged + ): + for key, val in all_computed_metrics.items(): + metric_short = key.split("/")[-2] if "/" in key else key + if metric_short.endswith("auc") and not metric_short.endswith("gauc"): + if float(val) >= self._auc_threshold: + ttt = time.perf_counter() - self._perf_t_start + self.tb_logger.add_scalar( + f"perf/time_to_auc_{self._auc_threshold:.5f}_sec", + ttt, + global_step=self.global_step[mode], + ) + logger.info( + f"REACHED AUC>={self._auc_threshold} on {key}=" + f"{float(val):.6f} at elapsed_sec={ttt:.2f} " + f"step={self.global_step[mode]}" + ) + self._time_to_target_logged = True + break + + return all_computed_metrics + + def reset(self, mode: str = "train"): + """ + Reset all metrics for a given mode. + + Args: + mode: Either 'train' or 'eval'. + """ + for metric in self.all_metrics[mode]: + metric.reset() + + +# the datasets we support +SUPPORTED_DATASETS = [ + "debug", + "movielens-1m", + "movielens-20m", + "movielens-13b", + "movielens-18b", + "kuairand-1k", + "streaming-400m", + "streaming-200b", + "streaming-100b", + "sampled-streaming-100b", + "yambda-5b", +] + + +@gin.configurable +def env_path(key: str = "", default: str = "") -> str: + """Resolve a path from os.environ[key], falling back to `default`. + + Intended as a gin macro so paths can be overridden via env vars without + editing the gin file. Example gin usage: + + DATA_PATH = @env_path() + env_path.key = "DLRM_DATA_PATH" + env_path.default = "/some/default/path" + make_train_test_dataloaders.new_path_prefix = %DATA_PATH + """ + return os.environ.get(key, default) if key else default + + +@gin.configurable +def env_str(key: str = "", default: str = "") -> str: + """Resolve a string from os.environ[key], falling back to `default`. + + Companion to `env_int`/`env_float` for categorical/string overrides (e.g. a + metric backend selector). Example gin usage: + + MetricsLogger.train_lifetime_auc_mode = @tlam/env_str() + tlam/env_str.key = "TRAIN_LIFETIME_AUC_MODE" + tlam/env_str.default = "binned" + """ + raw = os.environ.get(key) if key else None + return raw if raw else default + + +@gin.configurable +def env_str_map( + key: str = "", + default: Optional[Dict[str, str]] = None, + merge: bool = False, +) -> Dict[str, str]: + """Parse os.environ[key] as 'name=value,name=value' into a dict. + + Falls back to `default` (gin) when the env var is unset/empty. Companion to + `env_str` for map-valued overrides (e.g. per-table embedding placement). + Example gin usage: + + make_optimizer_and_shard.embedding_placement_overrides = @env_str_map() + env_str_map.key = "EMB_PLACEMENT_OVERRIDES" + env_str_map.default = {} + + Example env override: EMB_PLACEMENT_OVERRIDES="uid=uvm_caching,item_id=hbm". + + `merge` controls how the parsed env entries combine with `default`: + * False (default): the env var REPLACES `default` wholesale (whole-dict + override) — set the env var and you fully define the map. + * True: the parsed env entries are LAYERED ON TOP of `default` (per-key + override) — gin `default` defines the base map and the env var tweaks + only the named keys, leaving the rest of `default` intact. + """ + base = dict(default or {}) + raw = os.environ.get(key) if key else None + if not raw: + return base + parsed: Dict[str, str] = {} + for pair in raw.split(","): + pair = pair.strip() + if not pair: + continue + k, _, v = pair.partition("=") + parsed[k.strip()] = v.strip() + return {**base, **parsed} if merge else parsed + + +@gin.configurable +def env_int(key: str = "", default: int = 0) -> int: + """Resolve an int from os.environ[key], falling back to `default`. + + Companion to `env_path` for numeric overrides. Example gin usage: + + make_optimizer_and_shard.hbm_cap_gb = @env_int() + env_int.key = "HBM_CAP_GB" + env_int.default = 260 + """ + raw = os.environ.get(key) if key else None + return int(raw) if raw else default + + +@gin.configurable +def env_float(key: str = "", default: float = 0.0) -> float: + """Resolve a float from os.environ[key], falling back to `default`. + + Companion to `env_int` for fractional/duration overrides (e.g. a + checkpoint time interval in seconds). Example gin usage: + + streaming_train_eval_loop.checkpoint_time_interval_s = @env_float() + env_float.key = "CKPT_TIME_INTERVAL_S" + env_float.default = 3600.0 + """ + raw = os.environ.get(key) if key else None + return float(raw) if raw else default + + +_GPU_PEAK_FLOPS_TABLE: Dict[str, Dict[str, float]] = { + # Per-GPU peak TFLOPS by dtype. Values from vendor datasheets / Primus-DLRM + # peak_table. Used as the denominator in MFU/HFU. Keyed by case-insensitive + # substring of torch.cuda.get_device_name(0). + "MI355X": {"bf16": 2300e12, "fp32": 575e12}, + "MI350X": {"bf16": 2300e12, "fp32": 575e12}, + "MI300X": {"bf16": 1300e12, "fp32": 653e12}, + "MI325X": {"bf16": 1300e12, "fp32": 653e12}, + "B200": {"bf16": 2250e12, "fp32": 1125e12}, + "H100": {"bf16": 990e12, "fp32": 67e12}, + "A100": {"bf16": 312e12, "fp32": 19.5e12}, +} + + +def get_gpu_peak_flops(dtype: str = "bf16") -> float: + """Peak FLOPS for the current GPU at the given dtype. + + Falls back to MI350X's number with a warning when the device name doesn't + match any table entry — better to over-report MFU than to silently skip. + """ + if not torch.cuda.is_available(): + return 0.0 + name = torch.cuda.get_device_name(0) + for gpu_key, peaks in _GPU_PEAK_FLOPS_TABLE.items(): + if gpu_key in name: + return peaks.get(dtype, peaks["bf16"]) + logger.warning( + f"Unknown GPU for peak FLOPS: {name}; defaulting to MI350X bf16 (2300 TF)" + ) + return _GPU_PEAK_FLOPS_TABLE["MI350X"]["bf16"] + + +@gin.configurable +def run_results_dir(run_name: str = "default", subdir: str = "results") -> str: + """Resolve ``//`` from this file's location. + + Used as a gin macro to give per-run output directories that persist on the + host (recommendation_v4 is bind-mounted into the training container). + + Example gin usage:: + + RUN_NAME = @env_path() + env_path.key = "RUN_NAME" + env_path.default = "default" + run_results_dir.run_name = %RUN_NAME + Profiler.trace_dir = @run_results_dir() + """ + # utils.py lives at /generative_recommenders/dlrm_v3/utils.py; + # parents[2] climbs to /. + repo_root = Path(__file__).resolve().parents[2] + return str(repo_root / subdir / run_name) + + +@gin.configurable +def get_dataset( + name: str, + new_path_prefix: str = "", + history_length: Optional[int] = None, + min_history: Optional[int] = None, + history_strategy: str = "interleaved", + streaming_window_seconds: int = 86400, + streaming_sort_within_window: bool = False, + streaming_shuffle_fraction: float = 0.0, + streaming_shuffle_seed: int = 0, + train_split_percentage: float = 1.0, + split_salt: int = 0, +): + """ + Get dataset class and configuration by name. + + Args: + name: Dataset identifier (must be in SUPPORTED_DATASETS). + new_path_prefix: Optional prefix to prepend to data paths. + + Returns: + Tuple of (dataset_class, kwargs_dict) for dataset instantiation. + + Raises: + AssertionError: If dataset name is not supported. + """ + assert name in SUPPORTED_DATASETS, f"dataset {name} not supported" + if name == "debug": + return DLRMv3RandomDataset, {} + if name == "movielens-1m": + return ( + DLRMv3MovieLensDataset, + { + "ratings_file": os.path.join( + new_path_prefix, "data/ml-1m/sasrec_format.csv" + ), + }, + ) + if name == "movielens-20m": + return ( + DLRMv3MovieLensDataset, + { + "ratings_file": os.path.join( + new_path_prefix, "data/ml-20m/sasrec_format.csv" + ), + }, + ) + if name == "movielens-13b": + return ( + DLRMv3SyntheticMovieLensDataset, + { + "ratings_file_prefix": os.path.join( + new_path_prefix, "data/ml-13b/16x16384" + ), + }, + ) + if name == "movielens-18b": + return ( + DLRMv3SyntheticMovieLensDataset, + { + "ratings_file_prefix": os.path.join( + new_path_prefix, "data/ml-18b/20x36864" + ), + }, + ) + if name == "kuairand-1k": + return ( + DLRMv3KuaiRandDataset, + { + "seq_logs_file": os.path.join( + new_path_prefix, "data/KuaiRand-1K/data/processed_seqs.csv" + ), + }, + ) + if name == "streaming-400m": + return ( + DLRMv3SyntheticStreamingDataset, + { + "ratings_file_prefix": os.path.join( + new_path_prefix, "data/streaming-400m/" + ), + "train_ts": 8, + "total_ts": 10, + "num_files": 3, + "num_users": 150_000, + "num_items": 1_500_000, + "num_categories": 128, + }, + ) + if name == "streaming-200b": + return ( + DLRMv3SyntheticStreamingDataset, + { + "ratings_file_prefix": os.path.join( + new_path_prefix, "data/streaming-200b/" + ), + "train_ts": 90, + "total_ts": 100, + "num_files": 100, + "num_users": 10_000_000, + "num_items": 1_000_000_000, + "num_categories": 128, + }, + ) + if name == "streaming-100b": + return ( + DLRMv3SyntheticStreamingDataset, + { + "ratings_file_prefix": os.path.join( + new_path_prefix, "data/streaming-100b/" + ), + "train_ts": 90, + "total_ts": 100, + "num_files": 100, + "num_users": 5_000_000, + "num_items": 1_000_000_000, + "num_categories": 128, + }, + ) + if name == "yambda-5b": + from generative_recommenders.dlrm_v3.configs import YAMBDA_5B_CROSS_SPECS + + return ( + DLRMv3YambdaDataset, + { + # Layout: /processed_5b/{train_sessions.parquet,...} + # and /shared_metadata/{artist,album}_item_mapping.parquet. + # The dataset auto-builds a MAP_SHARED-mmap'd cache of the + # flat columns + LISTEN-anchor positions under + # /hstu_cache_L/ on first use; + # all ranks on a node share the same physical pages. + "processed_dir": os.path.join(new_path_prefix, "processed_5b"), + "metadata_dir": os.path.join(new_path_prefix, "shared_metadata"), + # Per-pool truncation cap; total interleaved UIH ~ 3*L/3 = L. + # Override via `get_dataset.history_length = N` in gin. + "history_length": history_length if history_length is not None else 4096, + "scan_window": 20000, + # Anchor-eligibility floor: a LISTEN event qualifies once the + # user has >= min_history prior events. Decoupled from + # history_length (gather cap) since jagged attention handles + # short UIH. None = legacy (require a full history_length). + # Override via `get_dataset.min_history = N` / $MIN_HISTORY. + "min_history": min_history, + # UIH construction: "interleaved" (per-pool L//3 cap) or + # "last_n" (last history_length pooled events, no per-pool + # split). Strategy-independent on disk — both reuse the same + # hstu_cache_L/ and positions file (the gather + # runs at sample-construction time), so switching needs no + # rebuild. Override via $HISTORY_STRATEGY. + "history_strategy": history_strategy, + "cross_specs": YAMBDA_5B_CROSS_SPECS, + # Temporal-streaming knobs (only used under --mode + # streaming-train-eval; ignored by the default train-eval path). + "streaming_window_seconds": streaming_window_seconds, + "streaming_sort_within_window": streaming_sort_within_window, + # In-window shuffle diversity dial in [0,1]: K=round(frac*N) within- + # segment shuffle. 0=off/user-major, 1=full. Config-invariant and + # deterministic by (seed, ts). + "streaming_shuffle_fraction": streaming_shuffle_fraction, + "streaming_shuffle_seed": streaming_shuffle_seed, + # User-level train:eval holdout for the streaming path. 1.0 = + # no holdout (legacy). <1.0 holds out (1 - tsp) of users as a + # fixed eval set; those users are never trained. + "train_split_percentage": train_split_percentage, + "split_salt": split_salt, + }, + ) + if name == "sampled-streaming-100b": + return ( + DLRMv3SyntheticStreamingDataset, + { + "ratings_file_prefix": os.path.join( + new_path_prefix, "data/streaming-100b/sampled_data/" + ), + "train_ts": 90, + "total_ts": 100, + "num_files": 1, + "num_users": 50_000, + "num_items": 1_000_000_000, + "num_categories": 128, + }, + ) diff --git a/recommendation_v4/generative_recommenders/modules/action_encoder.py b/recommendation_v4/generative_recommenders/modules/action_encoder.py new file mode 100644 index 000000000..13b65557e --- /dev/null +++ b/recommendation_v4/generative_recommenders/modules/action_encoder.py @@ -0,0 +1,112 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Dict, List, Optional, Tuple + +import torch +from generative_recommenders.common import HammerModule +from generative_recommenders.ops.jagged_tensors import concat_2D_jagged + + +class ActionEncoder(HammerModule): + def __init__( + self, + action_embedding_dim: int, + action_feature_name: str, + action_weights: List[int], + watchtime_feature_name: str = "", + watchtime_to_action_thresholds_and_weights: Optional[ + List[Tuple[int, int]] + ] = None, + embedding_init_std: float = 0.1, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + self._watchtime_feature_name: str = watchtime_feature_name + self._action_feature_name: str = action_feature_name + self._watchtime_to_action_thresholds_and_weights: List[Tuple[int, int]] = ( + watchtime_to_action_thresholds_and_weights + if watchtime_to_action_thresholds_and_weights is not None + else [] + ) + self.register_buffer( + "_combined_action_weights", + torch.tensor( + action_weights + + [x[1] for x in self._watchtime_to_action_thresholds_and_weights] + ), + ) + self._num_action_types: int = len(action_weights) + len( + self._watchtime_to_action_thresholds_and_weights + ) + self._action_embedding_dim = action_embedding_dim + self._action_embedding_table: torch.nn.Parameter = torch.nn.Parameter( + torch.empty((self._num_action_types, action_embedding_dim)).normal_( + mean=0, std=embedding_init_std + ), + ) + self._target_action_embedding_table: torch.nn.Parameter = torch.nn.Parameter( + torch.empty((1, self._num_action_types * action_embedding_dim)).normal_( + mean=0, std=embedding_init_std + ), + ) + + @property + def output_embedding_dim(self) -> int: + return self._action_embedding_dim * self._num_action_types + + def forward( + self, + max_uih_len: int, + max_targets: int, + uih_offsets: torch.Tensor, + target_offsets: torch.Tensor, + seq_embeddings: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> torch.Tensor: + seq_actions = seq_payloads[self._action_feature_name] + if len(self._watchtime_to_action_thresholds_and_weights) > 0: + watchtimes = seq_payloads[self._watchtime_feature_name] + for threshold, weight in self._watchtime_to_action_thresholds_and_weights: + seq_actions = torch.bitwise_or( + seq_actions, (watchtimes >= threshold).to(torch.int64) * weight + ) + exploded_actions = ( + torch.bitwise_and( + seq_actions.unsqueeze(-1), self._combined_action_weights.unsqueeze(0) + ) + > 0 + ) + action_embeddings = ( + exploded_actions.unsqueeze(-1) * self._action_embedding_table.unsqueeze(0) + ).view(-1, self._num_action_types * self._action_embedding_dim) + total_targets: int = seq_embeddings.size(0) - action_embeddings.size(0) + action_embeddings = concat_2D_jagged( + max_seq_len=max_uih_len + max_targets, + values_left=action_embeddings, + values_right=self._target_action_embedding_table.tile( + total_targets, + 1, + ), + max_len_left=max_uih_len, + max_len_right=max_targets, + offsets_left=uih_offsets, + offsets_right=target_offsets, + kernel=self.hammer_kernel(), + ) + return action_embeddings diff --git a/recommendation_v4/generative_recommenders/modules/content_encoder.py b/recommendation_v4/generative_recommenders/modules/content_encoder.py new file mode 100644 index 000000000..acca82dbf --- /dev/null +++ b/recommendation_v4/generative_recommenders/modules/content_encoder.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Dict, List, Optional + +import torch +from generative_recommenders.common import HammerModule +from generative_recommenders.ops.jagged_tensors import concat_2D_jagged + + +class ContentEncoder(HammerModule): + def __init__( + self, + input_embedding_dim: int, + additional_content_features: Optional[Dict[str, int]] = None, + target_enrich_features: Optional[Dict[str, int]] = None, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + self._input_embedding_dim: int = input_embedding_dim + self._additional_content_features: Dict[str, int] = ( + additional_content_features + if additional_content_features is not None + else {} + ) + self._target_enrich_features: Dict[str, int] = ( + target_enrich_features if target_enrich_features is not None else {} + ) + self._target_enrich_dummy_embeddings: torch.nn.ParameterDict = ( + torch.nn.ParameterDict( + { + name: torch.nn.Parameter( + torch.empty((1, dim)).normal_(mean=0, std=0.1), + ) + for name, dim in self._target_enrich_features.items() + } + ) + ) + + @property + def output_embedding_dim(self) -> int: + return self._input_embedding_dim + sum( + list(self._additional_content_features.values()) + + list(self._target_enrich_features.values()) + ) + + def forward( + self, + max_uih_len: int, + max_targets: int, + uih_offsets: torch.Tensor, + target_offsets: torch.Tensor, + seq_embeddings: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> torch.Tensor: + content_embeddings_list: List[torch.Tensor] = [seq_embeddings] + if len(self._additional_content_features) > 0: + content_embeddings_list = content_embeddings_list + [ + (seq_payloads[x].to(seq_embeddings.dtype)) + for x in self._additional_content_features.keys() + ] + + if self._target_enrich_dummy_embeddings: + total_seq_len: int = seq_embeddings.size(0) + for name, param in self._target_enrich_dummy_embeddings.items(): + enrich_embeddings_target = seq_payloads[name].to(seq_embeddings.dtype) + total_targets: int = enrich_embeddings_target.size(0) + total_uih_len: int = total_seq_len - total_targets + enrich_embeddings_uih = param.tile(total_uih_len, 1).to( + seq_embeddings.dtype + ) + enrich_embeddings = concat_2D_jagged( + max_seq_len=max_uih_len + max_targets, + values_left=enrich_embeddings_uih, + values_right=enrich_embeddings_target, + max_len_left=max_uih_len, + max_len_right=max_targets, + offsets_left=uih_offsets, + offsets_right=target_offsets, + kernel=self.hammer_kernel(), + ) + content_embeddings_list.append(enrich_embeddings) + + if ( + len(self._target_enrich_features) == 0 + and len(self._additional_content_features) == 0 + ): + return seq_embeddings + else: + content_embeddings = torch.cat( + content_embeddings_list, + dim=1, + ) + return content_embeddings diff --git a/recommendation_v4/generative_recommenders/modules/contextual_interleave_preprocessor.py b/recommendation_v4/generative_recommenders/modules/contextual_interleave_preprocessor.py new file mode 100644 index 000000000..fff0d72f0 --- /dev/null +++ b/recommendation_v4/generative_recommenders/modules/contextual_interleave_preprocessor.py @@ -0,0 +1,357 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from math import sqrt +from typing import Callable, Dict, Optional, Tuple + +import torch +from generative_recommenders.common import fx_unwrap_optional_tensor +from generative_recommenders.modules.action_encoder import ActionEncoder +from generative_recommenders.modules.content_encoder import ContentEncoder +from generative_recommenders.modules.contextualize_mlps import ( + ContextualizedMLP, + ParameterizedContextualizedMLP, +) +from generative_recommenders.modules.preprocessors import ( + get_contextual_input_embeddings, + InputPreprocessor, +) +from generative_recommenders.ops.jagged_tensors import concat_2D_jagged + + +class ContextualInterleavePreprocessor(InputPreprocessor): + def __init__( + self, + input_embedding_dim: int, + output_embedding_dim: int, + contextual_feature_to_max_length: Dict[str, int], + contextual_feature_to_min_uih_length: Dict[str, int], + content_encoder: ContentEncoder, + content_contextualize_mlp_fn: Callable[ + [int, int, int, bool], ContextualizedMLP + ], + action_encoder: ActionEncoder, + action_contextualize_mlp_fn: Callable[[int, int, int, bool], ContextualizedMLP], + pmlp_contextual_dropout_ratio: float = 0.0, + enable_interleaving: bool = False, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + self._input_embedding_dim: int = input_embedding_dim + self._output_embedding_dim: int = output_embedding_dim + self._contextual_feature_to_max_length: Dict[str, int] = ( + contextual_feature_to_max_length + ) + self._max_contextual_seq_len: int = sum( + contextual_feature_to_max_length.values() + ) + self._contextual_feature_to_min_uih_length: Dict[str, int] = ( + contextual_feature_to_min_uih_length + ) + std = 1.0 * sqrt(2.0 / float(input_embedding_dim + output_embedding_dim)) + self._batched_contextual_linear_weights = torch.nn.Parameter( + torch.empty( + ( + self._max_contextual_seq_len, + input_embedding_dim, + output_embedding_dim, + ) + ).normal_(0.0, std) + ) + self._pmlp_contextual_dropout_ratio: float = pmlp_contextual_dropout_ratio + self._batched_contextual_linear_bias = torch.nn.Parameter( + torch.empty((self._max_contextual_seq_len, 1, output_embedding_dim)).fill_( + 0.0 + ) + ) + contextual_embedding_dim: int = ( + self._max_contextual_seq_len * input_embedding_dim + ) + self._content_encoder: ContentEncoder = content_encoder + self._content_embedding_mlp: ContextualizedMLP = content_contextualize_mlp_fn( + self._content_encoder.output_embedding_dim, + output_embedding_dim, + contextual_embedding_dim, + is_inference, + ) + self._action_encoder: ActionEncoder = action_encoder + self._action_embedding_mlp: ContextualizedMLP = action_contextualize_mlp_fn( + self._action_encoder.output_embedding_dim, + output_embedding_dim, + contextual_embedding_dim, + is_inference, + ) + self._enable_interleaving: bool = enable_interleaving + + def combine_embeddings( + self, + max_uih_len: int, + max_targets: int, + total_uih_len: int, + total_targets: int, + seq_lengths: torch.Tensor, + seq_timestamps: torch.Tensor, + content_embeddings: torch.Tensor, + action_embeddings: torch.Tensor, + contextual_embeddings: Optional[torch.Tensor], + num_targets: torch.Tensor, + ) -> Tuple[ + int, + int, + int, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: + if self._enable_interleaving: + output_seq_timestamps = seq_timestamps.repeat_interleave(2) + output_seq_embeddings = torch.stack( + [content_embeddings, action_embeddings], dim=1 + ).reshape(-1, self._output_embedding_dim) + if self.interleave_targets(): + output_seq_lengths = seq_lengths * 2 + output_max_seq_len = (max_uih_len + max_targets) * 2 + output_num_targets = num_targets * 2 + output_total_uih_len = total_uih_len * 2 + output_total_targets = total_targets * 2 + else: + seq_lengths_by_2 = seq_lengths * 2 + output_seq_lengths = seq_lengths_by_2 - num_targets + output_max_seq_len = 2 * max_uih_len + max_targets + indices = torch.arange( + 2 * (max_uih_len + max_targets), device=seq_lengths.device + ).view(1, -1) + valid_mask = torch.logical_and( + indices < seq_lengths_by_2.view(-1, 1), + torch.logical_or( + indices < (output_seq_lengths - num_targets).view(-1, 1), + torch.remainder(indices, 2) == 0, + ), + ) + jagged_valid_mask = ( + torch.ops.fbgemm.dense_to_jagged( + valid_mask.int().unsqueeze(-1), + [ + torch.ops.fbgemm.asynchronous_complete_cumsum( + seq_lengths_by_2 + ) + ], + )[0] + .to(torch.bool) + .squeeze(1) + ) + output_seq_embeddings = output_seq_embeddings[jagged_valid_mask] + output_seq_timestamps = output_seq_timestamps[jagged_valid_mask] + output_num_targets = num_targets + output_total_uih_len = total_uih_len * 2 + output_total_targets = total_targets + else: + output_max_seq_len = max_uih_len + max_targets + output_seq_lengths = seq_lengths + output_num_targets = num_targets + output_seq_timestamps = seq_timestamps + output_seq_embeddings = content_embeddings + action_embeddings + output_total_uih_len = total_uih_len + output_total_targets = total_targets + + # concat contextual embeddings + output_seq_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + output_seq_lengths + ) + if self._max_contextual_seq_len > 0: + output_seq_embeddings = concat_2D_jagged( + max_seq_len=self._max_contextual_seq_len + output_max_seq_len, + values_left=fx_unwrap_optional_tensor(contextual_embeddings).reshape( + -1, self._output_embedding_dim + ), + values_right=output_seq_embeddings, + max_len_left=self._max_contextual_seq_len, + max_len_right=output_max_seq_len, + offsets_left=None, + offsets_right=output_seq_offsets, + kernel=self.hammer_kernel(), + ) + output_seq_timestamps = concat_2D_jagged( + max_seq_len=self._max_contextual_seq_len + output_max_seq_len, + values_left=torch.zeros( + (output_seq_lengths.size(0) * self._max_contextual_seq_len, 1), + dtype=output_seq_timestamps.dtype, + device=output_seq_timestamps.device, + ), + values_right=output_seq_timestamps.unsqueeze(-1), + max_len_left=self._max_contextual_seq_len, + max_len_right=output_max_seq_len, + offsets_left=None, + offsets_right=output_seq_offsets, + kernel=self.hammer_kernel(), + ).squeeze(-1) + output_max_seq_len = output_max_seq_len + self._max_contextual_seq_len + output_total_uih_len = ( + output_total_uih_len + + self._max_contextual_seq_len * output_seq_lengths.size(0) + ) + output_seq_lengths = output_seq_lengths + self._max_contextual_seq_len + output_seq_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + output_seq_lengths + ) + + return ( + output_max_seq_len, + output_total_uih_len, + output_total_targets, + output_seq_lengths, + output_seq_offsets, + output_seq_timestamps, + output_seq_embeddings, + output_num_targets, + ) + + def forward( # noqa C901 + self, + max_uih_len: int, + max_targets: int, + total_uih_len: int, + total_targets: int, + seq_lengths: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_embeddings: torch.Tensor, + num_targets: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> Tuple[ + int, + int, + int, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Dict[str, torch.Tensor], + ]: + max_seq_len = max_uih_len + max_targets + with torch.autocast( + "cuda", + dtype=torch.bfloat16, + enabled=(not self.is_inference and self._training_dtype == torch.bfloat16), + ): + # get contextual_embeddings + contextual_embeddings: Optional[torch.Tensor] = None + pmlp_contextual_embeddings: Optional[torch.Tensor] = None + if self._max_contextual_seq_len > 0: + contextual_input_embeddings = get_contextual_input_embeddings( + seq_lengths=seq_lengths, + seq_payloads=seq_payloads, + contextual_feature_to_max_length=self._contextual_feature_to_max_length, + contextual_feature_to_min_uih_length=self._contextual_feature_to_min_uih_length, + dtype=seq_embeddings.dtype, + ) + if isinstance( + self._action_embedding_mlp, ParameterizedContextualizedMLP + ) or isinstance( + self._action_embedding_mlp, ParameterizedContextualizedMLP + ): + pmlp_contextual_embeddings = torch.nn.functional.dropout( + contextual_input_embeddings, + p=self._pmlp_contextual_dropout_ratio, + training=self.training, + ) + contextual_embeddings = torch.baddbmm( + self._batched_contextual_linear_bias.to( + contextual_input_embeddings.dtype + ), + contextual_input_embeddings.view( + -1, self._max_contextual_seq_len, self._input_embedding_dim + ).transpose(0, 1), + self._batched_contextual_linear_weights.to( + contextual_input_embeddings.dtype + ), + ).transpose(0, 1) + + # content embeddings + seq_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(seq_lengths) + target_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(num_targets) + uih_offsets = seq_offsets - target_offsets + content_embeddings = self._content_encoder( + max_uih_len=max_uih_len, + max_targets=max_targets, + uih_offsets=uih_offsets, + target_offsets=target_offsets, + seq_embeddings=seq_embeddings, + seq_payloads=seq_payloads, + ) + content_embeddings = self._content_embedding_mlp( + seq_embeddings=content_embeddings, + seq_offsets=seq_offsets, + max_seq_len=max_seq_len, + contextual_embeddings=pmlp_contextual_embeddings, + ) + + # action embeddings + action_embeddings = self._action_encoder( + max_uih_len=max_uih_len, + max_targets=max_targets, + uih_offsets=uih_offsets, + target_offsets=target_offsets, + seq_embeddings=seq_embeddings, + seq_payloads=seq_payloads, + ).to(seq_embeddings.dtype) + action_embeddings = self._action_embedding_mlp( + seq_embeddings=action_embeddings, + seq_offsets=seq_offsets, + max_seq_len=max_seq_len, + contextual_embeddings=pmlp_contextual_embeddings, + ) + + ( + output_max_seq_len, + output_total_uih_len, + output_total_targets, + output_seq_lengths, + output_seq_offsets, + output_seq_timestamps, + output_seq_embeddings, + output_num_targets, + ) = self.combine_embeddings( + max_uih_len=max_uih_len, + max_targets=max_targets, + total_uih_len=total_uih_len, + total_targets=total_targets, + seq_lengths=seq_lengths, + seq_timestamps=seq_timestamps, + content_embeddings=content_embeddings, + action_embeddings=action_embeddings, + contextual_embeddings=contextual_embeddings, + num_targets=num_targets, + ) + + return ( + output_max_seq_len, + output_total_uih_len, + output_total_targets, + output_seq_lengths, + output_seq_offsets, + output_seq_timestamps, + output_seq_embeddings, + output_num_targets, + seq_payloads, + ) + + def interleave_targets(self) -> bool: + return self.is_train and self._enable_interleaving diff --git a/recommendation_v4/generative_recommenders/modules/contextualize_mlps.py b/recommendation_v4/generative_recommenders/modules/contextualize_mlps.py new file mode 100644 index 000000000..dc49effeb --- /dev/null +++ b/recommendation_v4/generative_recommenders/modules/contextualize_mlps.py @@ -0,0 +1,141 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict +import abc +from typing import Optional + +import torch +from generative_recommenders.common import HammerModule, init_mlp_weights_optional_bias +from generative_recommenders.ops.jagged_tensors import jagged_dense_bmm_broadcast_add +from generative_recommenders.ops.layer_norm import LayerNorm, SwishLayerNorm +from libfb.py.pyre import none_throws + + +class ContextualizedMLP(HammerModule): + @abc.abstractmethod + def forward( + self, + max_seq_len: int, + seq_embeddings: torch.Tensor, + seq_offsets: torch.Tensor, + contextual_embeddings: Optional[torch.Tensor], + ) -> torch.Tensor: + """ + Args: + seq_embeddings: (L, D) + seq_offsets: (B + 1,) + max_seq_len: int + contextual_embeddings: (B, D') + """ + pass + + +class SimpleContextualizedMLP(ContextualizedMLP): + def __init__( + self, + sequential_input_dim: int, + sequential_output_dim: int, + hidden_dim: int, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + self._mlp: torch.nn.Module = torch.nn.Sequential( + torch.nn.Linear( + in_features=sequential_input_dim, + out_features=hidden_dim, + ), + SwishLayerNorm(hidden_dim, is_inference=is_inference), + torch.nn.Linear( + in_features=hidden_dim, + out_features=sequential_output_dim, + ), + LayerNorm(sequential_output_dim), + ).apply(init_mlp_weights_optional_bias) + + def forward( + self, + seq_embeddings: torch.Tensor, + seq_offsets: torch.Tensor, + max_seq_len: int, + contextual_embeddings: Optional[torch.Tensor], + ) -> torch.Tensor: + return self._mlp(seq_embeddings) + + +class ParameterizedContextualizedMLP(ContextualizedMLP): + def __init__( + self, + contextual_embedding_dim: int, + sequential_input_dim: int, + sequential_output_dim: int, + hidden_dim: int, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + + self._sequential_input_dim: int = sequential_input_dim + self._sequential_output_dim: int = sequential_output_dim + + self._dense_features_compress: torch.nn.Module = torch.nn.Linear( + in_features=contextual_embedding_dim, + out_features=hidden_dim, + ).apply(init_mlp_weights_optional_bias) + + self._attn_raw_weights: torch.nn.Module = torch.nn.Sequential( + torch.nn.Linear( + in_features=hidden_dim, + out_features=sequential_input_dim * sequential_output_dim, + ), + ).apply(init_mlp_weights_optional_bias) + + self._attn_weights_norm: torch.nn.Module = torch.nn.LayerNorm( + [sequential_input_dim, sequential_output_dim] + ) + + self._res_weights: torch.nn.Module = torch.nn.Sequential( + torch.nn.Linear( + in_features=hidden_dim, + out_features=hidden_dim, + ), + SwishLayerNorm(hidden_dim), + torch.nn.Linear( + in_features=hidden_dim, + out_features=sequential_output_dim, + ), + ).apply(init_mlp_weights_optional_bias) + + def forward( + self, + seq_embeddings: torch.Tensor, + seq_offsets: torch.Tensor, + max_seq_len: int, + contextual_embeddings: Optional[torch.Tensor], + ) -> torch.Tensor: + shared_input = self._dense_features_compress(none_throws(contextual_embeddings)) + attn_weights = self._attn_weights_norm( + self._attn_raw_weights(shared_input).reshape( + -1, self._sequential_input_dim, self._sequential_output_dim + ) + ) + return jagged_dense_bmm_broadcast_add( + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + jagged=seq_embeddings, + dense=attn_weights.to(seq_embeddings.dtype), + bias=self._res_weights(shared_input), + kernel=self.hammer_kernel(), + ) diff --git a/recommendation_v4/generative_recommenders/modules/dlrm_hstu.py b/recommendation_v4/generative_recommenders/modules/dlrm_hstu.py new file mode 100644 index 000000000..f11bb226a --- /dev/null +++ b/recommendation_v4/generative_recommenders/modules/dlrm_hstu.py @@ -0,0 +1,757 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + + +import logging +from dataclasses import dataclass, field +from typing import Dict, List, NamedTuple, Optional, Tuple + +import torch +from generative_recommenders.common import ( + fx_infer_max_len, + fx_mark_length_features, + HammerKernel, + HammerModule, + init_mlp_weights_optional_bias, + set_static_max_seq_lens, +) +from generative_recommenders.modules.hstu_transducer import HSTUTransducer +from generative_recommenders.modules.multitask_module import ( + DefaultMultitaskModule, + MultitaskTaskType, + TaskConfig, +) +from generative_recommenders.modules.positional_encoder import HSTUPositionalEncoder +from generative_recommenders.modules.postprocessors import ( + LayerNormPostprocessor, + TimestampLayerNormPostprocessor, +) +from generative_recommenders.modules.preprocessors import ContextualPreprocessor +from generative_recommenders.modules.stu import STU, STULayer, STULayerConfig, STUStack +from generative_recommenders.ops.jagged_tensors import concat_2D_jagged +from generative_recommenders.ops.layer_norm import LayerNorm, SwishLayerNorm +from torch.autograd.profiler import record_function +from torchrec import KeyedJaggedTensor +from torchrec.modules.embedding_configs import EmbeddingConfig +from torchrec.modules.embedding_modules import EmbeddingCollection + +logger: logging.Logger = logging.getLogger(__name__) + +def fx_total_targets(num_candidates: torch.Tensor) -> int: + """Sum a per-sample candidate-count tensor to a Python int. + + Wrapped with ``torch.fx.wrap`` so ``TrainPipelineSparseDist``'s symbolic + trace treats it as an opaque leaf instead of recursing into the data- + dependent ``int(Proxy.sum().item())`` (which raises during tracing). + """ + return int(num_candidates.sum().item()) + + +torch.fx.wrap("fx_infer_max_len") +torch.fx.wrap("fx_total_targets") +torch.fx.wrap("len") + + +class SequenceEmbedding(NamedTuple): + lengths: torch.Tensor + embedding: torch.Tensor + + +@dataclass +class DlrmHSTUConfig: + max_seq_len: int = 16384 + max_num_candidates: int = 10 + max_num_candidates_inference: int = 5 + hstu_num_heads: int = 1 + hstu_attn_linear_dim: int = 256 + hstu_attn_qk_dim: int = 128 + hstu_attn_num_layers: int = 12 + hstu_embedding_table_dim: int = 192 + hstu_preprocessor_hidden_dim: int = 256 + hstu_transducer_embedding_dim: int = 0 + hstu_group_norm: bool = False + hstu_input_dropout_ratio: float = 0.2 + hstu_linear_dropout_rate: float = 0.2 + contextual_feature_to_max_length: Dict[str, int] = field(default_factory=dict) + contextual_feature_to_min_uih_length: Dict[str, int] = field(default_factory=dict) + candidates_weight_feature_name: str = "" + candidates_watchtime_feature_name: str = "" + candidates_querytime_feature_name: str = "" + causal_multitask_weights: float = 0.2 + multitask_configs: List[TaskConfig] = field(default_factory=list) + user_embedding_feature_names: List[str] = field(default_factory=list) + item_embedding_feature_names: List[str] = field(default_factory=list) + uih_post_id_feature_name: str = "" + uih_action_time_feature_name: str = "" + uih_weight_feature_name: str = "" + hstu_uih_feature_names: List[str] = field(default_factory=list) + hstu_candidate_feature_names: List[str] = field(default_factory=list) + merge_uih_candidate_feature_mapping: List[Tuple[str, str]] = field( + default_factory=list + ) + action_weights: Optional[List[int]] = None + action_embedding_init_std: float = 0.1 + enable_postprocessor: bool = True + use_layer_norm_postprocessor: bool = False + + +def _get_supervision_labels_and_weights( + supervision_bitmasks: torch.Tensor, + watchtime_sequence: torch.Tensor, + task_configs: List[TaskConfig], +) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + supervision_labels: Dict[str, torch.Tensor] = {} + supervision_weights: Dict[str, torch.Tensor] = {} + for task in task_configs: + if task.task_type == MultitaskTaskType.REGRESSION: + supervision_labels[task.task_name] = watchtime_sequence.to(torch.float32) + elif task.task_type == MultitaskTaskType.BINARY_CLASSIFICATION: + supervision_labels[task.task_name] = ( + torch.bitwise_and(supervision_bitmasks, task.task_weight) > 0 + ).to(torch.float32) + else: + raise RuntimeError("Unsupported MultitaskTaskType") + return supervision_labels, supervision_weights + + +class DlrmHSTU(HammerModule): + def __init__( # noqa C901 + self, + hstu_configs: DlrmHSTUConfig, + embedding_tables: Dict[str, EmbeddingConfig], + is_inference: bool, + is_dense: bool = False, + bf16_training: bool = True, + ) -> None: + super().__init__(is_inference=is_inference) + logger.info(f"Initialize HSTU module with configs {hstu_configs}") + # When True, forward() takes the whole `Samples` batch as its single + # positional arg and reads the pre-merged sparse KJT off it. This keeps + # the EmbeddingCollection's input a plain getattr on the batch placeholder + # so TorchRec's TrainPipelineSparseDist can pipeline its input_dist. Set + # by build_train_pipeline(); leave False for eager / inference / eval. + self._pipeline_mode: bool = False + self._hstu_configs = hstu_configs + self._bf16_training: bool = bf16_training + # Last batch's jagged FLOPs/sample (0-d tensor on GPU). Populated by + # main_forward; MetricsLogger reads + .item()s on each compute_and_log + # to compute tflops_real/gpu and hfu (vs dense yardstick from + # get_num_flops_per_sample()). + self._last_jagged_flops_per_sample: Optional[torch.Tensor] = None + set_static_max_seq_lens([self._hstu_configs.max_seq_len]) + + if not is_dense: + self._embedding_collection: EmbeddingCollection = EmbeddingCollection( + tables=list(embedding_tables.values()), + need_indices=False, + device=torch.device("meta"), + ) + + # multitask configs must be sorted by task types + self._multitask_configs: List[TaskConfig] = hstu_configs.multitask_configs + self._multitask_module = DefaultMultitaskModule( + task_configs=self._multitask_configs, + embedding_dim=hstu_configs.hstu_transducer_embedding_dim, + prediction_fn=lambda in_dim, num_tasks: torch.nn.Sequential( + torch.nn.Linear(in_features=in_dim, out_features=512), + SwishLayerNorm(512), + torch.nn.Linear(in_features=512, out_features=num_tasks), + ).apply(init_mlp_weights_optional_bias), + causal_multitask_weights=hstu_configs.causal_multitask_weights, + is_inference=self._is_inference, + ) + self._additional_embedding_features: List[str] = [ + uih_feature_name + for ( + uih_feature_name, + candidate_feature_name, + ) in self._hstu_configs.merge_uih_candidate_feature_mapping + if ( + candidate_feature_name + in self._hstu_configs.item_embedding_feature_names + ) + and (uih_feature_name in self._hstu_configs.user_embedding_feature_names) + and (uih_feature_name is not self._hstu_configs.uih_post_id_feature_name) + ] + + # preprocessor setup + preprocessor = ContextualPreprocessor( + input_embedding_dim=hstu_configs.hstu_embedding_table_dim, + hidden_dim=hstu_configs.hstu_preprocessor_hidden_dim, + output_embedding_dim=hstu_configs.hstu_transducer_embedding_dim, + contextual_feature_to_max_length=hstu_configs.contextual_feature_to_max_length, + contextual_feature_to_min_uih_length=hstu_configs.contextual_feature_to_min_uih_length, + action_embedding_dim=8, + action_feature_name=self._hstu_configs.uih_weight_feature_name, + action_weights=self._hstu_configs.action_weights, + action_embedding_init_std=self._hstu_configs.action_embedding_init_std, + additional_embedding_features=self._additional_embedding_features, + is_inference=is_inference, + ) + + # positional encoder + positional_encoder = HSTUPositionalEncoder( + num_position_buckets=8192, + num_time_buckets=2048, + embedding_dim=hstu_configs.hstu_transducer_embedding_dim, + contextual_seq_len=sum( + dict(hstu_configs.contextual_feature_to_max_length).values() + ), + is_inference=self._is_inference, + ) + + if hstu_configs.enable_postprocessor: + if hstu_configs.use_layer_norm_postprocessor: + postprocessor = LayerNormPostprocessor( + embedding_dim=hstu_configs.hstu_transducer_embedding_dim, + eps=1e-5, + is_inference=self._is_inference, + ) + else: + postprocessor = TimestampLayerNormPostprocessor( + embedding_dim=hstu_configs.hstu_transducer_embedding_dim, + time_duration_features=[ + (60 * 60, 24), # hour of day + (24 * 60 * 60, 7), # day of week + # (24 * 60 * 60, 365), # time of year (approximate) + ], + eps=1e-5, + is_inference=self._is_inference, + ) + else: + postprocessor = None + + # construct HSTU + stu_module: STU = STUStack( + stu_list=[ + STULayer( + config=STULayerConfig( + embedding_dim=hstu_configs.hstu_transducer_embedding_dim, + num_heads=hstu_configs.hstu_num_heads, + hidden_dim=hstu_configs.hstu_attn_linear_dim, + attention_dim=hstu_configs.hstu_attn_qk_dim, + output_dropout_ratio=hstu_configs.hstu_linear_dropout_rate, + use_group_norm=hstu_configs.hstu_group_norm, + causal=True, + target_aware=True, + max_attn_len=None, + attn_alpha=None, + recompute_normed_x=True, + recompute_uvqk=True, + recompute_y=True, + sort_by_length=True, + contextual_seq_len=0, + ), + is_inference=is_inference, + ) + for _ in range(hstu_configs.hstu_attn_num_layers) + ], + is_inference=is_inference, + ) + self._hstu_transducer: HSTUTransducer = HSTUTransducer( + stu_module=stu_module, + input_preprocessor=preprocessor, + output_postprocessor=postprocessor, + input_dropout_ratio=hstu_configs.hstu_input_dropout_ratio, + positional_encoder=positional_encoder, + is_inference=self._is_inference, + return_full_embeddings=False, + listwise=False, + ) + + # item embeddings + self._item_embedding_mlp: torch.nn.Module = torch.nn.Sequential( + torch.nn.Linear( + in_features=hstu_configs.hstu_embedding_table_dim + * len(self._hstu_configs.item_embedding_feature_names), + out_features=512, + ), + SwishLayerNorm(512), + torch.nn.Linear( + in_features=512, + out_features=hstu_configs.hstu_transducer_embedding_dim, + ), + LayerNorm(hstu_configs.hstu_transducer_embedding_dim), + ).apply(init_mlp_weights_optional_bias) + + # -- FLOPs estimation ----------------------------------------------------- + # Convention matches TorchTitan / Primus-DLRM: matmul = 6 × M × N × K + # (×3 fwd+bwd, ×2 FMA), attention = 2 matmuls (Q·K^T + att·V). + # Embedding lookups excluded — they're memory-bound, not compute. + # + # HSTU vs OneTrans: HSTU collapses attention + FFN into a single UVQK + # projection plus SiLU(U) ⊙ y elementwise gating. There is NO separate + # FFN block (which dominates FLOPs in a standard transformer), so HSTU + # is intentionally compute-leaner per layer for the same N. + def _hstu_layer_flops( + self, n_tokens_linear: float, n_tokens_attn_sq: float + ) -> float: + """Per-layer FLOPs given linear-op token count and attention-token² + count. Dense estimate uses ``N`` and ``N²``; jagged estimate + substitutes ``mean(s_i)`` and ``mean(s_i²)``.""" + cfg = self._hstu_configs + D = cfg.hstu_embedding_table_dim + H = cfg.hstu_num_heads + hd = cfg.hstu_attn_linear_dim # V/U head dim + qd = cfg.hstu_attn_qk_dim # Q/K head dim + uvqk = 6 * n_tokens_linear * D * (2 * hd + 2 * qd) * H + attn = 6 * n_tokens_attn_sq * H * (qd + hd) # Q·K^T + att·V + out = 6 * n_tokens_linear * (3 * H * hd) * D + return uvqk + attn + out + + def get_num_flops_per_sample(self) -> float: + """Dense-equivalent fwd+bwd FLOPs per sample at ``max_seq_len``. + + Used as the MFU yardstick (peak utilization the workload could + theoretically reach if every sample's sequence were the full padded + length). The actual ``tflops_real``/``hfu`` reported per step uses + the jagged estimate stashed by ``main_forward``. + """ + cfg = self._hstu_configs + N = float(cfg.max_seq_len) + n_layers = cfg.hstu_attn_num_layers + flops = n_layers * self._hstu_layer_flops( + n_tokens_linear=N, n_tokens_attn_sq=N * N + ) + # Multitask head (Linear(D, n_tasks)) — negligible but cheap to add. + n_tasks = len(self._multitask_configs) + if n_tasks > 0: + flops += 6 * n_tasks * cfg.hstu_embedding_table_dim + return float(flops) + + def _compute_jagged_flops_per_sample( + self, + uih_seq_lengths: torch.Tensor, + num_candidates: torch.Tensor, + ) -> torch.Tensor: + """Jagged fwd+bwd FLOPs per sample for THIS batch's actual lengths. + + Per-sample merged sequence length s_i = uih_seq_lengths[i] + + num_candidates[i]. Returns a 0-d tensor on the batch's device; + caller should ``.item()`` it (one D→H sync per logging interval). + """ + s = (uih_seq_lengths + num_candidates).float() + mean_s = s.mean() + mean_s_sq = (s * s).mean() + cfg = self._hstu_configs + n_layers = cfg.hstu_attn_num_layers + flops = n_layers * ( + 6 * mean_s * cfg.hstu_embedding_table_dim + * (2 * cfg.hstu_attn_linear_dim + 2 * cfg.hstu_attn_qk_dim) + * cfg.hstu_num_heads + + 6 * mean_s_sq * cfg.hstu_num_heads + * (cfg.hstu_attn_qk_dim + cfg.hstu_attn_linear_dim) + + 6 * mean_s * (3 * cfg.hstu_num_heads * cfg.hstu_attn_linear_dim) + * cfg.hstu_embedding_table_dim + ) + n_tasks = len(self._multitask_configs) + if n_tasks > 0: + flops = flops + 6 * n_tasks * cfg.hstu_embedding_table_dim + return flops + + def _construct_payload( + self, + payload_features: Dict[str, torch.Tensor], + seq_embeddings: Dict[str, SequenceEmbedding], + ) -> Dict[str, torch.Tensor]: + if len(self._hstu_configs.contextual_feature_to_max_length) > 0: + contextual_offsets: List[torch.Tensor] = [] + for x in self._hstu_configs.contextual_feature_to_max_length.keys(): + contextual_offsets.append( + torch.ops.fbgemm.asynchronous_complete_cumsum( + seq_embeddings[x].lengths + ) + ) + else: + # Dummy, offsets are unused + contextual_offsets = torch.empty((0, 0)) + if torch.jit.is_scripting(): + # Explicit loops are TS-clean (avoid the dict-merge / dict-comp + # idioms below, which TorchScript cannot script). + out: Dict[str, torch.Tensor] = {} + for k, v in payload_features.items(): + out[k] = v + for x in self._hstu_configs.contextual_feature_to_max_length.keys(): + out[x] = seq_embeddings[x].embedding + i = 0 + for x in self._hstu_configs.contextual_feature_to_max_length.keys(): + # pyre-ignore[6] + out[x + "_offsets"] = contextual_offsets[i] + i += 1 + for x in self._additional_embedding_features: + out[x] = seq_embeddings[x].embedding + return out + return { + **payload_features, + **{ + x: seq_embeddings[x].embedding + for x in self._hstu_configs.contextual_feature_to_max_length.keys() + }, + **{ + x + "_offsets": contextual_offsets[i] + for i, x in enumerate( + list(self._hstu_configs.contextual_feature_to_max_length.keys()) + ) + }, + **{ + x: seq_embeddings[x].embedding + for x in self._additional_embedding_features + }, + } + + def _user_forward( + self, + max_uih_len: int, + max_candidates: int, + seq_embeddings: Dict[str, SequenceEmbedding], + payload_features: Dict[str, torch.Tensor], + num_candidates: torch.Tensor, + total_uih_len: Optional[int] = None, + total_targets: Optional[int] = None, + ) -> torch.Tensor: + source_lengths = seq_embeddings[ + self._hstu_configs.uih_post_id_feature_name + ].lengths + source_timestamps = concat_2D_jagged( + max_seq_len=max_uih_len + max_candidates, + max_len_left=max_uih_len, + offsets_left=payload_features["uih_offsets"], + values_left=payload_features[ + self._hstu_configs.uih_action_time_feature_name + ].unsqueeze(-1), + max_len_right=max_candidates, + offsets_right=payload_features["candidate_offsets"], + values_right=payload_features[ + self._hstu_configs.candidates_querytime_feature_name + ].unsqueeze(-1), + kernel=self.hammer_kernel(), + ).squeeze(-1) + if total_targets is None: + total_targets = fx_total_targets(num_candidates) + if total_uih_len is None: + total_uih_len = source_timestamps.numel() - total_targets + embedding = seq_embeddings[ + self._hstu_configs.uih_post_id_feature_name + ].embedding + dtype = embedding.dtype + if (not self.is_inference) and self._bf16_training: + embedding = embedding.to(torch.bfloat16) + if torch.jit.is_scripting(): + # TorchScript does not support ``with torch.autocast(...)``. + # In script-mode inference the dense path is already in bf16 + # (move_sparse_output_to_device upcasts on the C++ side), so + # autocast is a no-op for the path the predictor exercises. + candidates_user_embeddings, _ = self._hstu_transducer( + max_uih_len=max_uih_len, + max_targets=max_candidates, + total_uih_len=total_uih_len, + total_targets=total_targets, + seq_embeddings=embedding, + seq_lengths=source_lengths, + seq_timestamps=source_timestamps, + seq_payloads=self._construct_payload( + payload_features=payload_features, + seq_embeddings=seq_embeddings, + ), + num_targets=num_candidates, + ) + else: + with torch.autocast( + "cuda", + dtype=torch.bfloat16, + enabled=(not self.is_inference) and self._bf16_training, + ): + candidates_user_embeddings, _ = self._hstu_transducer( + max_uih_len=max_uih_len, + max_targets=max_candidates, + total_uih_len=total_uih_len, + total_targets=total_targets, + seq_embeddings=embedding, + seq_lengths=source_lengths, + seq_timestamps=source_timestamps, + seq_payloads=self._construct_payload( + payload_features=payload_features, + seq_embeddings=seq_embeddings, + ), + num_targets=num_candidates, + ) + candidates_user_embeddings = candidates_user_embeddings.to(dtype) + + return candidates_user_embeddings + + def _item_forward( + self, + seq_embeddings: Dict[str, SequenceEmbedding], + ) -> torch.Tensor: # [L, D] + all_embeddings = torch.cat( + [ + seq_embeddings[name].embedding + for name in self._hstu_configs.item_embedding_feature_names + ], + dim=-1, + ) + item_embeddings = self._item_embedding_mlp(all_embeddings) + return item_embeddings + + def preprocess( + self, + uih_features: KeyedJaggedTensor, + candidates_features: KeyedJaggedTensor, + merged_sparse_features: Optional[KeyedJaggedTensor] = None, + ) -> Tuple[ + Dict[str, SequenceEmbedding], + Dict[str, torch.Tensor], + int, + torch.Tensor, + int, + torch.Tensor, + ]: + # Embedding lookup for uih + candidates. When the caller (the pipeline + # path) supplies the pre-merged KJT from the batch, feed it straight to + # the EmbeddingCollection: that keeps the lookup's input a plain getattr + # off the batch so TorchRec's TrainPipelineSparseDist can hoist its + # input_dist into the prefetch stage. Building it here (cat + + # from_lengths_sync's .sync()) is an "input modification" that makes + # TorchRec skip pipelining the embedding collection. + if merged_sparse_features is None: + merged_sparse_features = KeyedJaggedTensor.from_lengths_sync( + keys=uih_features.keys() + candidates_features.keys(), + values=torch.cat( + [uih_features.values(), candidates_features.values()], + dim=0, + ), + lengths=torch.cat( + [uih_features.lengths(), candidates_features.lengths()], + dim=0, + ), + ) + seq_embeddings_dict = self._embedding_collection(merged_sparse_features) + num_candidates = fx_mark_length_features( + candidates_features.lengths().view(len(candidates_features.keys()), -1) + )[0] + max_num_candidates = fx_infer_max_len(num_candidates) + uih_seq_lengths = uih_features[ + self._hstu_configs.uih_post_id_feature_name + ].lengths() + max_uih_len = fx_infer_max_len(uih_seq_lengths) + + # prepare payload features + payload_features: Dict[str, torch.Tensor] = {} + for ( + uih_feature_name, + candidate_feature_name, + ) in self._hstu_configs.merge_uih_candidate_feature_mapping: + if ( + candidate_feature_name + not in self._hstu_configs.item_embedding_feature_names + and uih_feature_name + not in self._hstu_configs.user_embedding_feature_names + ): + values_left = uih_features[uih_feature_name].values() + if self._is_inference and ( + candidate_feature_name + == self._hstu_configs.candidates_weight_feature_name + or candidate_feature_name + == self._hstu_configs.candidates_watchtime_feature_name + ): + total_candidates = torch.sum(num_candidates).item() + values_right = torch.zeros( + total_candidates, # pyre-ignore + dtype=torch.int64, + device=values_left.device, + ) + else: + values_right = candidates_features[candidate_feature_name].values() + payload_features[uih_feature_name] = values_left + payload_features[candidate_feature_name] = values_right + payload_features["uih_offsets"] = torch.ops.fbgemm.asynchronous_complete_cumsum( + uih_seq_lengths + ) + payload_features["candidate_offsets"] = ( + torch.ops.fbgemm.asynchronous_complete_cumsum(num_candidates) + ) + + seq_embeddings = { + k: SequenceEmbedding( + lengths=seq_embeddings_dict[k].lengths(), + embedding=seq_embeddings_dict[k].values(), + ) + for k in self._hstu_configs.user_embedding_feature_names + + self._hstu_configs.item_embedding_feature_names + } + + return ( + seq_embeddings, + payload_features, + max_uih_len, + uih_seq_lengths, + max_num_candidates, + num_candidates, + ) + + def main_forward( + self, + seq_embeddings: Dict[str, SequenceEmbedding], + payload_features: Dict[str, torch.Tensor], + max_uih_len: int, + uih_seq_lengths: torch.Tensor, + max_num_candidates: int, + num_candidates: torch.Tensor, + total_uih_len: Optional[int] = None, + total_targets: Optional[int] = None, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + Dict[str, torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + ]: + # Stash this batch's jagged FLOPs/sample for MetricsLogger to read. + # No D->H sync: the .item() happens once per metric_log_frequency in + # the trainer, not on every step. Eval-mode batches also produce a + # stash but the trainer only consumes it on train batches. + if not torch.jit.is_scripting(): + self._last_jagged_flops_per_sample = ( + self._compute_jagged_flops_per_sample( + uih_seq_lengths=uih_seq_lengths, + num_candidates=num_candidates, + ) + ) + + # merge uih and candidates embeddings + for ( + uih_feature_name, + candidate_feature_name, + ) in self._hstu_configs.merge_uih_candidate_feature_mapping: + if uih_feature_name in seq_embeddings: + seq_embeddings[uih_feature_name] = SequenceEmbedding( + lengths=uih_seq_lengths + num_candidates, + embedding=concat_2D_jagged( + max_seq_len=max_uih_len + max_num_candidates, + max_len_left=max_uih_len, + offsets_left=torch.ops.fbgemm.asynchronous_complete_cumsum( + uih_seq_lengths + ), + values_left=seq_embeddings[uih_feature_name].embedding, + max_len_right=max_num_candidates, + offsets_right=torch.ops.fbgemm.asynchronous_complete_cumsum( + num_candidates + ), + values_right=seq_embeddings[candidate_feature_name].embedding, + kernel=self.hammer_kernel(), + ), + ) + + with record_function("## item_forward ##"): + candidates_item_embeddings = self._item_forward( + seq_embeddings, + ) + with record_function("## user_forward ##"): + candidates_user_embeddings = self._user_forward( + max_uih_len=max_uih_len, + max_candidates=max_num_candidates, + seq_embeddings=seq_embeddings, + payload_features=payload_features, + num_candidates=num_candidates, + total_uih_len=total_uih_len, + total_targets=total_targets, + ) + with record_function("## multitask_module ##"): + supervision_labels, supervision_weights = ( + _get_supervision_labels_and_weights( + supervision_bitmasks=payload_features[ + self._hstu_configs.candidates_weight_feature_name + ], + watchtime_sequence=payload_features[ + self._hstu_configs.candidates_watchtime_feature_name + ], + task_configs=self._multitask_configs, + ) + ) + mt_target_preds, mt_target_labels, mt_target_weights, mt_losses = ( + self._multitask_module( + encoded_user_embeddings=candidates_user_embeddings, + item_embeddings=candidates_item_embeddings, + supervision_labels=supervision_labels, + supervision_weights=supervision_weights, + ) + ) + + aux_losses: Dict[str, torch.Tensor] = {} + if not self._is_inference and self.training: + for i, task in enumerate(self._multitask_configs): + aux_losses[task.task_name] = mt_losses[i] + + return ( + candidates_user_embeddings, + candidates_item_embeddings, + aux_losses, + mt_target_preds, + mt_target_labels, + mt_target_weights, + ) + + def forward( + self, + uih_features: KeyedJaggedTensor, + candidates_features: Optional[KeyedJaggedTensor] = None, + merged_sparse_features: Optional[KeyedJaggedTensor] = None, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + Dict[str, torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + ]: + # Pipeline mode: TorchRec fx-traces this forward (via DMP.module) and the + # pipeline calls it with the single `Samples` batch. Unpacking the KJTs + # here — rather than in the wrapper — makes the EmbeddingCollection's + # input `batch.merged_sparse_features` a getattr off the batch placeholder, + # which is what lets TrainPipelineSparseDist hoist the embedding input_dist + # into the prefetch stage. Guarded from TorchScript (inference path). + if not torch.jit.is_scripting() and self._pipeline_mode: + batch = uih_features + uih_features = batch.uih_features_kjt + candidates_features = batch.candidates_features_kjt + merged_sparse_features = batch.merged_sparse_features + + with record_function("## preprocess ##"): + ( + seq_embeddings, + payload_features, + max_uih_len, + uih_seq_lengths, + max_num_candidates, + num_candidates, + ) = self.preprocess( + uih_features=uih_features, + candidates_features=candidates_features, + merged_sparse_features=merged_sparse_features, + ) + + with record_function("## main_forward ##"): + return self.main_forward( + seq_embeddings=seq_embeddings, + payload_features=payload_features, + max_uih_len=max_uih_len, + uih_seq_lengths=uih_seq_lengths, + max_num_candidates=max_num_candidates, + num_candidates=num_candidates, + ) diff --git a/recommendation_v4/generative_recommenders/modules/dynamic_stu.py b/recommendation_v4/generative_recommenders/modules/dynamic_stu.py new file mode 100644 index 000000000..e1fe8ad16 --- /dev/null +++ b/recommendation_v4/generative_recommenders/modules/dynamic_stu.py @@ -0,0 +1,304 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict +import abc +import contextlib +from typing import Any, Generator, Optional, Tuple + +import torch +from generative_recommenders.common import fx_infer_max_len +from generative_recommenders.modules.stu import STU +from generative_recommenders.ops.jagged_tensors import ( + hstu_concat_l2_embeddings, + hstu_split_l2_embeddings, +) + + +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") +except OSError: + pass + + +@contextlib.contextmanager +# pyre-ignore[3] +def _freeze_rng_state() -> Generator[Any, None, None]: + rng_state = torch.get_rng_state() + if torch.cuda.is_available(): + cuda_rng_state = torch.cuda.get_rng_state() + try: + yield + finally: + if torch.cuda.is_available(): + # pyre-ignore[61] + torch.cuda.set_rng_state(cuda_rng_state) + torch.set_rng_state(rng_state) + + +class DynamicSTU(STU): + def __init__(self, stu: STU, is_inference: bool) -> None: + super().__init__(is_inference) + self._stu = stu + + @abc.abstractmethod + def _preprocess( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + x_offsets: torch.Tensor, + max_seq_len: int, + num_targets: torch.Tensor, + max_kv_caching_len: int = 0, + kv_caching_lengths: Optional[torch.Tensor] = None, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + int, + torch.Tensor, + int, + Optional[torch.Tensor], + ]: + pass + + @abc.abstractmethod + def _postprocess( + self, + stu_output: torch.Tensor, + ) -> torch.Tensor: + pass + + def forward( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + x_offsets: torch.Tensor, + max_seq_len: int, + num_targets: torch.Tensor, + max_kv_caching_len: int = 0, + kv_caching_lengths: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + ( + x, + x_lengths, + x_offsets, + max_seq_len, + num_targets, + max_kv_caching_len, + kv_caching_lengths, + ) = self._preprocess( + x=x, + x_lengths=x_lengths, + x_offsets=x_offsets, + max_seq_len=max_seq_len, + num_targets=num_targets, + max_kv_caching_len=max_kv_caching_len, + kv_caching_lengths=kv_caching_lengths, + ) + + stu_output = self._stu( + x=x, + x_lengths=x_lengths, + x_offsets=x_offsets, + max_seq_len=max_seq_len, + num_targets=num_targets, + max_kv_caching_len=max_kv_caching_len, + kv_caching_lengths=kv_caching_lengths, + ) + + return self._postprocess( + stu_output=stu_output, + ) + + +class SDSTU(DynamicSTU): + def __init__( + self, + stu: STU, + is_inference: bool, + dropout_ratio: float = 0.5, + seed: int = 0, + ) -> None: + """ + Stochastic Depth STU + """ + super().__init__(stu=stu, is_inference=is_inference) + self._dropout_ratio: float = dropout_ratio + self._iter: int = 0 + self._seed: int = seed + self._skip_x: Optional[torch.Tensor] = None + + def _preprocess( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + x_offsets: torch.Tensor, + max_seq_len: int, + num_targets: torch.Tensor, + max_kv_caching_len: int = 0, + kv_caching_lengths: Optional[torch.Tensor] = None, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + int, + torch.Tensor, + int, + Optional[torch.Tensor], + ]: + if self.training: + with _freeze_rng_state(): + torch.manual_seed(self._iter + self._seed) + prob = torch.rand(1) + if prob.item() <= self._dropout_ratio: + new_x = torch.empty(size=(0, x.shape[1]), device=x.device) + self._skip_x = x + new_x_lengths = torch.zeros_like(x_lengths) + new_x_offsets = torch.zeros_like(x_offsets) + new_max_seq_len = 1 + else: + new_x = x + new_x_lengths = x_lengths + new_x_offsets = x_offsets + new_max_seq_len = max_seq_len + self._iter += 1 + else: + new_x = x + new_x_lengths = x_lengths + new_x_offsets = x_offsets + new_max_seq_len = max_seq_len + return ( + new_x, + new_x_lengths, + new_x_offsets, + new_max_seq_len, + num_targets, + max_kv_caching_len, + kv_caching_lengths, + ) + + def _postprocess( + self, + stu_output: torch.Tensor, + ) -> torch.Tensor: + if self.training and self._skip_x is not None: + ret = self._skip_x + self._skip_x = None + return ret + else: + return stu_output + + +@torch.fx.wrap +def _fx_unwrap_optional_tuple_tensor( + optional: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]], +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert optional is not None, "Expected optional to be non-None" + return optional + + +class L2STU(DynamicSTU): + def __init__( + self, + stu: STU, + max_l2_len: int, + is_inference: bool, + contextual_seq_len: int = 0, + ) -> None: + """ + Stochastic Depth STU + """ + super().__init__(stu=stu, is_inference=is_inference) + self._max_l2_len: int = max_l2_len + self._contextual_seq_len: int = contextual_seq_len + self._saved_tensors: Optional[ + Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ] = None + self._runtime_max_l2_len: int = 0 + self._runtime_prefix_len: int = 0 + + def _preprocess( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + x_offsets: torch.Tensor, + max_seq_len: int, + num_targets: torch.Tensor, + max_kv_caching_len: int = 0, + kv_caching_lengths: Optional[torch.Tensor] = None, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + int, + torch.Tensor, + int, + Optional[torch.Tensor], + ]: + prefix_lengths = ( + x_lengths - self._max_l2_len - num_targets - self._contextual_seq_len + ) + prefix_lengths = torch.clamp(prefix_lengths, min=0) + prefix_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(prefix_lengths) + l2_lengths = x_lengths - prefix_lengths + l2_offsets = x_offsets - prefix_offsets + self._runtime_max_l2_len: int = fx_infer_max_len(l2_lengths) + self._runtime_prefix_len: int = fx_infer_max_len(prefix_lengths) + prefix_x, l2_x = hstu_split_l2_embeddings( + max_seq_len=max_seq_len, + x=x, + prefix_offsets=prefix_offsets, + l2_offsets=l2_offsets, + contextual_seq_len=self._contextual_seq_len, + kernel=self.hammer_kernel(), + ) + self._saved_tensors = ( + prefix_offsets, + prefix_x, + l2_offsets, + ) + return ( + l2_x, + l2_lengths, + l2_offsets, + self._runtime_max_l2_len, + num_targets, + max_kv_caching_len, + kv_caching_lengths, + ) + + def _postprocess( + self, + stu_output: torch.Tensor, + ) -> torch.Tensor: + ( + prefix_offsets, + prefix_x, + l2_offsets, + ) = _fx_unwrap_optional_tuple_tensor(self._saved_tensors) + self._saved_tensors = None + return hstu_concat_l2_embeddings( + max_prefix_len=self._runtime_prefix_len, + prefix_x=prefix_x, + prefix_offsets=prefix_offsets, + max_l2_len=self._runtime_max_l2_len, + l2_x=stu_output, + l2_offsets=l2_offsets, + contextual_seq_len=self._contextual_seq_len, + kernel=self.hammer_kernel(), + ) diff --git a/recommendation_v4/generative_recommenders/modules/hstu_transducer.py b/recommendation_v4/generative_recommenders/modules/hstu_transducer.py new file mode 100644 index 000000000..ce91a67c9 --- /dev/null +++ b/recommendation_v4/generative_recommenders/modules/hstu_transducer.py @@ -0,0 +1,330 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import logging +from typing import Dict, Optional, Tuple + +import torch +from generative_recommenders.common import fx_unwrap_optional_tensor, HammerModule +from generative_recommenders.modules.positional_encoder import HSTUPositionalEncoder +from generative_recommenders.modules.postprocessors import ( + L2NormPostprocessor, + OutputPostprocessor, +) +from generative_recommenders.modules.preprocessors import InputPreprocessor +from generative_recommenders.modules.stu import STU +from generative_recommenders.ops.jagged_tensors import split_2D_jagged +from torch.profiler import record_function + +logger: logging.Logger = logging.getLogger(__name__) +torch.fx.wrap("len") + +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") +except OSError: + pass + + +@torch.fx.wrap +def default_seq_payload( + seq_payloads: Optional[Dict[str, torch.Tensor]], +) -> Dict[str, torch.Tensor]: + if seq_payloads is None: + return {} + else: + return torch.jit._unwrap_optional(seq_payloads) + + +class HSTUTransducer(HammerModule): + def __init__( + self, + stu_module: STU, + input_preprocessor: InputPreprocessor, + output_postprocessor: Optional[OutputPostprocessor] = None, + input_dropout_ratio: float = 0.0, + positional_encoder: Optional[HSTUPositionalEncoder] = None, + is_inference: bool = True, + return_full_embeddings: bool = False, + listwise: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + self._stu_module = stu_module + self._input_preprocessor: InputPreprocessor = input_preprocessor + self._output_postprocessor: OutputPostprocessor = ( + output_postprocessor + if output_postprocessor is not None + else L2NormPostprocessor(is_inference=is_inference) + ) + assert self._is_inference == self._input_preprocessor._is_inference, ( + f"input_preprocessor must have the same mode; self: {self._is_inference} vs input_preprocessor {self._input_preprocessor._is_inference}" + ) + self._positional_encoder: Optional[HSTUPositionalEncoder] = positional_encoder + self._input_dropout_ratio: float = input_dropout_ratio + self._return_full_embeddings: bool = return_full_embeddings + self._listwise_training: bool = listwise and self.is_train + + for name, m in self.named_modules(): + if "_stu_module" in name: + continue + elif isinstance(m, torch.nn.Linear): + torch.nn.init.xavier_normal_(m.weight) + elif isinstance(m, torch.nn.LayerNorm): + if m.weight.dim() >= 2: + torch.nn.init.xavier_normal_(m.weight) + if m.bias is not None and m.bias.dim() >= 2: + torch.nn.init.xavier_normal_(m.bias) + + def _preprocess( + self, + max_uih_len: int, + max_targets: int, + total_uih_len: int, + total_targets: int, + seq_lengths: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_embeddings: torch.Tensor, + num_targets: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> Tuple[ + int, + int, + int, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Dict[str, torch.Tensor], + ]: + seq_payloads = default_seq_payload(seq_payloads) + + with record_function("hstu_input_preprocessor"): + ( + output_max_seq_len, + output_total_uih_len, + output_total_targets, + output_seq_lengths, + output_seq_offsets, + output_seq_timestamps, + output_seq_embeddings, + output_num_targets, + output_seq_payloads, + ) = self._input_preprocessor( + max_uih_len=max_uih_len, + max_targets=max_targets, + total_uih_len=total_uih_len, + total_targets=total_targets, + seq_lengths=seq_lengths, + seq_timestamps=seq_timestamps, + seq_embeddings=seq_embeddings, + num_targets=num_targets, + seq_payloads=seq_payloads, + ) + + with record_function("hstu_positional_encoder"): + if self._positional_encoder is not None: + output_seq_embeddings = self._positional_encoder( + max_seq_len=output_max_seq_len, + seq_lengths=output_seq_lengths, + seq_offsets=output_seq_offsets, + seq_timestamps=output_seq_timestamps, + seq_embeddings=output_seq_embeddings, + num_targets=( + None if self._listwise_training else output_num_targets + ), + ) + + output_seq_embeddings = torch.nn.functional.dropout( + output_seq_embeddings, + p=self._input_dropout_ratio, + training=self.training, + ) + + return ( + output_max_seq_len, + output_total_uih_len, + output_total_targets, + output_seq_lengths, + output_seq_offsets, + output_seq_timestamps, + output_seq_embeddings, + output_num_targets, + output_seq_payloads, + ) + + def _hstu_compute( + self, + max_seq_len: int, + seq_lengths: torch.Tensor, + seq_offsets: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_embeddings: torch.Tensor, + num_targets: torch.Tensor, + ) -> torch.Tensor: + with record_function("hstu"): + seq_embeddings = self._stu_module( + max_seq_len=max_seq_len, + x=seq_embeddings, + x_lengths=seq_lengths, + x_offsets=seq_offsets, + num_targets=(None if self._listwise_training else num_targets), + ) + return seq_embeddings + + def _postprocess( + self, + max_seq_len: int, + max_uih_len: int, + max_targets: int, + total_uih_len: int, + total_targets: int, + seq_lengths: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_embeddings: torch.Tensor, + num_targets: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: + with record_function("hstu_output_postprocessor"): + if self._return_full_embeddings: + seq_embeddings = self._output_postprocessor( + seq_embeddings=seq_embeddings, + seq_timestamps=seq_timestamps, + seq_payloads=seq_payloads, + ) + uih_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + seq_lengths - num_targets + ) + candidates_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + num_targets + ) + _, candidate_embeddings = split_2D_jagged( + values=seq_embeddings, + max_seq_len=max_seq_len, + total_len_left=total_uih_len, + total_len_right=total_targets, + max_len_left=max_uih_len, + max_len_right=max_targets, + offsets_left=uih_offsets, + offsets_right=candidates_offsets, + kernel=self.hammer_kernel(), + ) + interleave_targets: bool = self._input_preprocessor.interleave_targets() + if interleave_targets: + candidate_embeddings = candidate_embeddings.view( + -1, 2, candidate_embeddings.size(-1) + )[:, 0, :] + if not self._return_full_embeddings: + _, candidate_timestamps = split_2D_jagged( + values=seq_timestamps.unsqueeze(-1), + max_seq_len=max_seq_len, + total_len_left=total_uih_len, + total_len_right=total_targets, + max_len_left=max_uih_len, + max_len_right=max_targets, + offsets_left=uih_offsets, + offsets_right=candidates_offsets, + kernel=self.hammer_kernel(), + ) + candidate_timestamps = candidate_timestamps.squeeze(-1) + if interleave_targets: + candidate_timestamps = candidate_timestamps.view(-1, 2)[:, 0] + candidate_embeddings = self._output_postprocessor( + seq_embeddings=candidate_embeddings, + seq_timestamps=candidate_timestamps, + seq_payloads=seq_payloads, + ) + + return ( + seq_embeddings if self._return_full_embeddings else None, + candidate_embeddings, + ) + + def forward( + self, + max_uih_len: int, + max_targets: int, + total_uih_len: int, + total_targets: int, + seq_lengths: torch.Tensor, + seq_embeddings: torch.Tensor, + seq_timestamps: torch.Tensor, + num_targets: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> Tuple[ + torch.Tensor, + Optional[torch.Tensor], + ]: + orig_dtype = seq_embeddings.dtype + if not self._is_inference: + seq_embeddings = seq_embeddings.to(self._training_dtype) + + ( + max_seq_len, + total_uih_len, + total_targets, + seq_lengths, + seq_offsets, + seq_timestamps, + seq_embeddings, + num_targets, + seq_payloads, + ) = self._preprocess( + max_uih_len=max_uih_len, + max_targets=max_targets, + total_uih_len=total_uih_len, + total_targets=total_targets, + seq_lengths=seq_lengths, + seq_timestamps=seq_timestamps, + seq_embeddings=seq_embeddings, + num_targets=num_targets, + seq_payloads=seq_payloads, + ) + + encoded_embeddings = self._hstu_compute( + max_seq_len=max_seq_len, + seq_lengths=seq_lengths, + seq_offsets=seq_offsets, + seq_timestamps=seq_timestamps, + seq_embeddings=seq_embeddings, + num_targets=num_targets, + ) + + encoded_embeddings, encoded_candidate_embeddings = self._postprocess( + max_seq_len=max_seq_len, + max_uih_len=max_uih_len, + max_targets=max_targets, + total_uih_len=total_uih_len, + total_targets=total_targets, + seq_lengths=seq_lengths, + seq_embeddings=encoded_embeddings, + seq_timestamps=seq_timestamps, + num_targets=num_targets, + seq_payloads=seq_payloads, + ) + + if not self._is_inference: + encoded_candidate_embeddings = encoded_candidate_embeddings.to(orig_dtype) + if self._return_full_embeddings: + encoded_embeddings = fx_unwrap_optional_tensor(encoded_embeddings).to( + orig_dtype + ) + return ( + encoded_candidate_embeddings, + encoded_embeddings, + ) diff --git a/recommendation_v4/generative_recommenders/modules/multitask_module.py b/recommendation_v4/generative_recommenders/modules/multitask_module.py new file mode 100644 index 000000000..3cb11996f --- /dev/null +++ b/recommendation_v4/generative_recommenders/modules/multitask_module.py @@ -0,0 +1,288 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import abc +import logging +from dataclasses import dataclass +from enum import IntEnum +from typing import Callable, Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from generative_recommenders.common import HammerModule + +logger: logging.Logger = logging.getLogger(__name__) + + +class MultitaskTaskType(IntEnum): + BINARY_CLASSIFICATION = 0 + REGRESSION = 1 + + +@dataclass +class TaskConfig: + task_name: str + task_weight: int + task_type: MultitaskTaskType + + +class MultitaskModule(HammerModule): + @abc.abstractmethod + def forward( + self, + encoded_user_embeddings: torch.Tensor, + item_embeddings: torch.Tensor, + supervision_labels: Dict[str, torch.Tensor], + supervision_weights: Dict[str, torch.Tensor], + ) -> Tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + ]: + """ + Computes multi-task predictions. + + Args: + encoded_user_embeddings: (L, D) x float. + item_embeddings: (L, D) x float. + supervision_labels: Dict[T, L] x float or int + supervision_weights: Dict[T', L] x float or int, T' <= T + Returns: + (T, L) x float, predictions, labels, weights, losses + """ + pass + + +def _compute_pred_and_logits( + prediction_module: torch.nn.Module, + encoded_user_embeddings: torch.Tensor, + item_embeddings: torch.Tensor, + task_offsets: List[int], + has_multiple_task_types: bool, +) -> Tuple[torch.Tensor, torch.Tensor]: + mt_logits = prediction_module(encoded_user_embeddings * item_embeddings).transpose( + 0, 1 + ) + mt_preds_list: List[torch.Tensor] = [] + for task_type in MultitaskTaskType: + logits = mt_logits[ + task_offsets[task_type] : task_offsets[task_type + 1], + :, + ] + if task_offsets[task_type + 1] - task_offsets[task_type] > 0: + if task_type == MultitaskTaskType.REGRESSION: + mt_preds_list.append(logits) + else: + mt_preds_list.append(F.sigmoid(logits)) + if has_multiple_task_types: + mt_preds: torch.Tensor = torch.concat(mt_preds_list, dim=0) + else: + mt_preds: torch.Tensor = mt_preds_list[0] + + return mt_preds, mt_logits + + +def _compute_labels_and_weights( + supervision_labels: Dict[str, torch.Tensor], + supervision_weights: Dict[str, torch.Tensor], + task_configs: List[TaskConfig], + device: torch.device, + dtype: torch.dtype = torch.float32, +) -> Tuple[torch.Tensor, torch.Tensor]: + first_label: torch.Tensor = list(supervision_labels.values())[0] + default_supervision_weight = torch.ones_like( + first_label, + dtype=dtype, + device=device, + ) + mt_lables_list: List[torch.Tensor] = [] + mt_weights_list: List[torch.Tensor] = [] + for task in task_configs: + mt_lables_list.append(supervision_labels[task.task_name]) + mt_weights_list.append( + supervision_weights.get(task.task_name, default_supervision_weight) + ) + if len(task_configs) > 1: + mt_labels = torch.stack(mt_lables_list, dim=0) + mt_weights = torch.stack(mt_weights_list, dim=0) + else: + mt_labels = mt_lables_list[0].unsqueeze(0) + mt_weights = mt_weights_list[0].unsqueeze(0) + return mt_labels, mt_weights + + +def _compute_loss( + task_offsets: List[int], + causal_multitask_weights: float, + mt_logits: torch.Tensor, + mt_labels: torch.Tensor, + mt_weights: torch.Tensor, + has_multiple_task_types: bool, +) -> torch.Tensor: + mt_losses_list: List[torch.Tensor] = [] + for task_type in MultitaskTaskType: + if task_offsets[task_type + 1] - task_offsets[task_type] > 0: + logits = mt_logits[ + task_offsets[task_type] : task_offsets[task_type + 1], + :, + ] + labels = mt_labels[ + task_offsets[task_type] : task_offsets[task_type + 1], + :, + ] + weights = mt_weights[ + task_offsets[task_type] : task_offsets[task_type + 1], + :, + ] + if task_type == MultitaskTaskType.REGRESSION: + mt_losses_list.append( + F.mse_loss(logits, labels, reduction="none") * weights + ) + else: + mt_losses_list.append( + F.binary_cross_entropy_with_logits( + input=logits, target=labels, reduction="none" + ) + * weights + ) + + if has_multiple_task_types: + mt_losses = torch.concat(mt_losses_list, dim=0) + else: + mt_losses = mt_losses_list[0] + mt_losses = ( + mt_losses.sum(-1) / mt_weights.sum(-1).clamp(min=1.0) * causal_multitask_weights + ) + return mt_losses + + +class DefaultMultitaskModule(MultitaskModule): + def __init__( + self, + task_configs: List[TaskConfig], + embedding_dim: int, + prediction_fn: Callable[[int, int], torch.nn.Module], + causal_multitask_weights: float, + is_inference: bool, + ) -> None: + super().__init__(is_inference) + assert sorted(task_configs, key=lambda x: x.task_type) == task_configs, ( + "task_configs must be sorted by task_type." + ) + assert len(task_configs) > 0, "task_configs must be non-empty." + self._task_configs: List[TaskConfig] = task_configs + self._task_offsets: List[int] = [0] * (len(MultitaskTaskType) + 1) + for task in self._task_configs: + self._task_offsets[task.task_type + 1] += 1 + self._has_multiple_task_types: bool = self._task_offsets.count(0) < len( + MultitaskTaskType + ) + self._task_offsets[1:] = np.cumsum(self._task_offsets[1:]).tolist() + self._causal_multitask_weights: float = causal_multitask_weights + self._prediction_module: torch.nn.Module = prediction_fn( + embedding_dim, len(task_configs) + ) + + def forward( + self, + encoded_user_embeddings: torch.Tensor, + item_embeddings: torch.Tensor, + supervision_labels: Dict[str, torch.Tensor], + supervision_weights: Dict[str, torch.Tensor], + ) -> Tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + ]: + orig_dtype = encoded_user_embeddings.dtype + if not self._is_inference: + encoded_user_embeddings = encoded_user_embeddings.to(self._training_dtype) + item_embeddings = item_embeddings.to(self._training_dtype) + + if torch.jit.is_scripting(): + # Script-mode fast path: skip torch.autocast (unsupported in TS) + # and inline _compute_pred_and_logits to avoid its + # `torch.nn.Module` parameter annotation (TS only knows + # concrete module types). The dense module is already in bf16 + # at this point, so autocast is a no-op for the predictor path. + mt_logits = self._prediction_module( + encoded_user_embeddings * item_embeddings + ).transpose(0, 1) + mt_preds_list: List[torch.Tensor] = [] + # MultitaskTaskType is an IntEnum (BINARY_CLASSIFICATION=0, + # REGRESSION=1) but TorchScript treats it as an opaque Enum. + # Iterate by the integer task indices directly. + for task_type in range(len(self._task_offsets) - 1): + start = self._task_offsets[task_type] + end = self._task_offsets[task_type + 1] + logits = mt_logits[start:end, :] + if end - start > 0: + # 1 == MultitaskTaskType.REGRESSION + if task_type == 1: + mt_preds_list.append(logits) + else: + mt_preds_list.append(F.sigmoid(logits)) + if self._has_multiple_task_types: + mt_preds: torch.Tensor = torch.concat(mt_preds_list, dim=0) + else: + mt_preds: torch.Tensor = mt_preds_list[0] + return mt_preds, None, None, None + + with torch.autocast( + "cuda", + dtype=torch.bfloat16, + enabled=(not self.is_inference and self._training_dtype == torch.bfloat16), + ): + mt_preds, mt_logits = _compute_pred_and_logits( + prediction_module=self._prediction_module, + encoded_user_embeddings=encoded_user_embeddings, + item_embeddings=item_embeddings, + task_offsets=self._task_offsets, + has_multiple_task_types=self._has_multiple_task_types, + ) + + # losses are always computed in fp32 + mt_labels: Optional[torch.Tensor] = None + mt_weights: Optional[torch.Tensor] = None + mt_losses: Optional[torch.Tensor] = None + if not self._is_inference: + mt_labels, mt_weights = _compute_labels_and_weights( + supervision_labels=supervision_labels, + supervision_weights=supervision_weights, + task_configs=self._task_configs, + device=encoded_user_embeddings.device, + ) + mt_losses = _compute_loss( + task_offsets=self._task_offsets, + causal_multitask_weights=self._causal_multitask_weights, + mt_logits=mt_logits.to(mt_labels.dtype), + mt_labels=mt_labels, + mt_weights=mt_weights, + has_multiple_task_types=self._has_multiple_task_types, + ) + mt_preds = mt_preds.to(orig_dtype) + + return ( + mt_preds, + mt_labels, + mt_weights, + mt_losses, + ) diff --git a/recommendation_v4/generative_recommenders/modules/positional_encoder.py b/recommendation_v4/generative_recommenders/modules/positional_encoder.py new file mode 100644 index 000000000..99d904fd4 --- /dev/null +++ b/recommendation_v4/generative_recommenders/modules/positional_encoder.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from math import sqrt +from typing import Optional + +import torch +from generative_recommenders.common import HammerModule +from generative_recommenders.ops.position import add_timestamp_positional_embeddings + + +class HSTUPositionalEncoder(HammerModule): + def __init__( + self, + num_position_buckets: int, + num_time_buckets: int, + embedding_dim: int, + contextual_seq_len: int, + is_inference: bool = True, + ) -> None: + super().__init__(is_inference=is_inference) + self._embedding_dim: int = embedding_dim + self._contextual_seq_len: int = contextual_seq_len + self._position_embeddings_weight: torch.nn.Parameter = torch.nn.Parameter( + torch.empty(num_position_buckets, embedding_dim).uniform_( + -sqrt(1.0 / num_position_buckets), + sqrt(1.0 / num_position_buckets), + ), + ) + self._timestamp_embeddings_weight: torch.nn.Parameter = torch.nn.Parameter( + torch.empty(num_time_buckets + 1, embedding_dim).uniform_( + -sqrt(1.0 / num_time_buckets), + sqrt(1.0 / num_time_buckets), + ), + ) + + def forward( + self, + max_seq_len: int, + seq_lengths: torch.Tensor, + seq_offsets: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_embeddings: torch.Tensor, + num_targets: Optional[torch.Tensor], + ) -> torch.Tensor: + seq_embeddings = add_timestamp_positional_embeddings( + alpha=self._embedding_dim**0.5, + max_seq_len=max_seq_len, + max_contextual_seq_len=self._contextual_seq_len, + position_embeddings_weight=self._position_embeddings_weight, + timestamp_embeddings_weight=self._timestamp_embeddings_weight, + seq_offsets=seq_offsets, + seq_lengths=seq_lengths, + seq_embeddings=seq_embeddings, + timestamps=seq_timestamps, + num_targets=num_targets, + interleave_targets=False, + kernel=self.hammer_kernel(), + ) + return seq_embeddings diff --git a/recommendation_v4/generative_recommenders/modules/postprocessors.py b/recommendation_v4/generative_recommenders/modules/postprocessors.py new file mode 100644 index 000000000..7958e3fa9 --- /dev/null +++ b/recommendation_v4/generative_recommenders/modules/postprocessors.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from abc import abstractmethod +from typing import Dict, List, Tuple + +import torch +from generative_recommenders.common import HammerModule, init_mlp_weights_optional_bias + + +@torch.fx.wrap +def _cast_dtype(t: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + if t.dtype != dtype: + return t.to(dtype) + return t + + +class OutputPostprocessor(HammerModule): + """An abstract class for post-processing user embeddings after HSTU layers.""" + + @abstractmethod + def forward( + self, + seq_embeddings: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> torch.Tensor: + """ + Args: + seq_embeddings: (L, D) + seq_timestamps: (L, ) + seq_payloads: str-keyed tensors. Implementation specific. + + Returns: + postprocessed seq_embeddings, (L, D) + """ + pass + + +class L2NormPostprocessor(OutputPostprocessor): + """Postprocesses user embeddings with l2 norm.""" + + def __init__(self, is_inference: bool = False) -> None: + super().__init__(is_inference=is_inference) + + def forward( + self, + seq_embeddings: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> torch.Tensor: + return seq_embeddings / torch.linalg.norm( + seq_embeddings, ord=2, dim=-1, keepdim=True + ).clamp(min=1e-6) + + +class LayerNormPostprocessor(OutputPostprocessor): + """Postprocesses user embeddings with layer norm.""" + + def __init__( + self, + embedding_dim: int, + eps: float = 1e-5, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + + self._layer_norm: torch.nn.Module = torch.nn.LayerNorm( + normalized_shape=[embedding_dim], eps=eps + ) + + def forward( + self, + seq_embeddings: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> torch.Tensor: + # pyre-fixme[6]: For 1st argument expected `dtype` but got `Union[dtype, + # Tensor, Module]`. + return self._layer_norm(seq_embeddings.to(self._layer_norm.weight.dtype)) + + +@torch.fx.wrap +def _unsqueeze_if_needed(t: torch.Tensor, embedding: torch.Tensor) -> torch.Tensor: + if embedding.dim() == 3: + return t.unsqueeze(0) + return t + + +class TimestampLayerNormPostprocessor(OutputPostprocessor): + """Postprocesses user embeddings with timestamp-based MLP -> layer norm.""" + + def __init__( + self, + embedding_dim: int, + time_duration_features: List[Tuple[int, int]], + eps: float = 1e-5, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + + self._layer_norm: torch.nn.Module = torch.nn.LayerNorm( + normalized_shape=[embedding_dim], eps=eps + ) + self.register_buffer( + "_period_units", + torch.Tensor([f[0] for f in time_duration_features]).view(1, -1), + ) + self.register_buffer( + "_units_per_period", + torch.Tensor([f[1] for f in time_duration_features]).view(1, -1), + ) + self._time_feature_combiner: torch.nn.Module = torch.nn.Linear( + embedding_dim + 2 * len(time_duration_features), + embedding_dim, + ).apply(init_mlp_weights_optional_bias) + + def _concat_time_features( + self, + combined_embeddings: torch.Tensor, + timestamps: torch.Tensor, # [B] or [B, D] + ) -> torch.Tensor: + # concat time representation to combined embeddings + period_units = self._period_units + units_per_period = self._units_per_period + + timestamps = timestamps.unsqueeze(-1) + period_units = _unsqueeze_if_needed(period_units, combined_embeddings) + units_per_period = _unsqueeze_if_needed( + units_per_period, combined_embeddings + ).float() + # Compute time features in float32 to avoid bf16 precision loss through + # discontinuous floor/remainder ops, matching Inductor fusion behavior. + _units_elapsed_type: torch.dtype = combined_embeddings.dtype + _units_since_epoch = torch.div( + timestamps.float(), period_units.float(), rounding_mode="floor" + ) # [sum(N_i), num_time_features] or [B, N, num_time_features] + _units_elapsed = ( + (torch.remainder(_units_since_epoch, units_per_period) / units_per_period) + * 2 + * 3.14 + ) + _units_elapsed = torch.view_as_real( + torch.polar( + _cast_dtype(torch.ones_like(_units_elapsed), torch.float32), + _cast_dtype(_units_elapsed, torch.float32), + ) + ).flatten( + -2, -1 + ) # [sum(N_i), num_time_features * 2] or [B, N, num_time_features * 2] + _units_elapsed = _cast_dtype(_units_elapsed, _units_elapsed_type) + combined_embeddings = torch.cat([combined_embeddings, _units_elapsed], dim=-1) + return combined_embeddings + + def forward( + self, + seq_embeddings: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> torch.Tensor: + user_embeddings = self._time_feature_combiner( + self._concat_time_features(seq_embeddings, timestamps=seq_timestamps).to( + self._time_feature_combiner.weight.dtype # pyre-fixme[6]: For 1st argument expected `dtype` but got `Union[dtype, + # Tensor, Module]`. + ) + ) + return self._layer_norm(user_embeddings) diff --git a/recommendation_v4/generative_recommenders/modules/preprocessors.py b/recommendation_v4/generative_recommenders/modules/preprocessors.py new file mode 100644 index 000000000..dc7806bb4 --- /dev/null +++ b/recommendation_v4/generative_recommenders/modules/preprocessors.py @@ -0,0 +1,334 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import abc +from math import sqrt +from typing import Dict, List, Optional, Tuple + +import torch +from generative_recommenders.common import ( + fx_unwrap_optional_tensor, + HammerModule, + init_mlp_weights_optional_bias, + jagged_to_padded_dense, +) +from generative_recommenders.modules.action_encoder import ActionEncoder +from generative_recommenders.ops.jagged_tensors import concat_2D_jagged +from generative_recommenders.ops.layer_norm import LayerNorm, SwishLayerNorm + + +class InputPreprocessor(HammerModule): + """An abstract class for pre-processing sequence embeddings before HSTU layers.""" + + @abc.abstractmethod + def forward( + self, + max_uih_len: int, + max_targets: int, + total_uih_len: int, + total_targets: int, + seq_lengths: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_embeddings: torch.Tensor, + num_targets: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> Tuple[ + int, + int, + int, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Dict[str, torch.Tensor], + ]: + """ + Args: + max_uih_len: int + max_targets: int + total_uih_len: int + total_targets: int + seq_lengths: (B,) + seq_embeddings: (L, D) + seq_timestamps: (B, N) + num_targets: (B,) Optional. + seq_payloads: str-keyed tensors. Implementation specific. + + Returns: + (max_seq_len, total_uih_len, total_targets, lengths, offsets, timestamps, embeddings, num_targets, payloads) updated based on input preprocessor. + """ + pass + + def interleave_targets(self) -> bool: + return False + + +def get_contextual_input_embeddings( + seq_lengths: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + contextual_feature_to_max_length: Dict[str, int], + contextual_feature_to_min_uih_length: Dict[str, int], + dtype: torch.dtype, +) -> torch.Tensor: + padded_values: List[torch.Tensor] = [] + for key, max_len in contextual_feature_to_max_length.items(): + v = torch.flatten( + jagged_to_padded_dense( + values=seq_payloads[key].to(dtype), + offsets=[seq_payloads[key + "_offsets"]], + max_lengths=[max_len], + padding_value=0.0, + ), + 1, + 2, + ) + min_uih_length = contextual_feature_to_min_uih_length.get(key, 0) + if min_uih_length > 0: + v = v * (seq_lengths.view(-1, 1) >= min_uih_length) + padded_values.append(v) + return torch.cat(padded_values, dim=1) + + +class ContextualPreprocessor(InputPreprocessor): + def __init__( + self, + input_embedding_dim: int, + hidden_dim: int, + output_embedding_dim: int, + contextual_feature_to_max_length: Dict[str, int], + contextual_feature_to_min_uih_length: Dict[str, int], + action_embedding_dim: int = 8, + action_feature_name: str = "", + action_weights: Optional[List[int]] = None, + additional_embedding_features: List[str] = [], + action_embedding_init_std: float = 0.1, + is_inference: bool = True, + ) -> None: + super().__init__(is_inference=is_inference) + self._output_embedding_dim: int = output_embedding_dim + self._input_embedding_dim: int = input_embedding_dim + self._hidden_dim: int = hidden_dim + self._contextual_feature_to_max_length: Dict[str, int] = ( + contextual_feature_to_max_length + ) + self._max_contextual_seq_len: int = sum( + contextual_feature_to_max_length.values() + ) + self._contextual_feature_to_min_uih_length: Dict[str, int] = ( + contextual_feature_to_min_uih_length + ) + if self._max_contextual_seq_len > 0: + std = 1.0 * sqrt( + 2.0 / float(input_embedding_dim + self._output_embedding_dim) + ) + self._batched_contextual_linear_weights: torch.nn.Parameter = ( + torch.nn.Parameter( + torch.empty( + ( + self._max_contextual_seq_len, + input_embedding_dim, + self._output_embedding_dim, + ) + ).normal_(0.0, std) + ) + ) + self._batched_contextual_linear_bias: torch.nn.Parameter = ( + torch.nn.Parameter( + torch.empty( + (self._max_contextual_seq_len, self._output_embedding_dim) + ).fill_(0.0) + ) + ) + self._content_embedding_mlp: torch.nn.Module = torch.nn.Sequential( + torch.nn.Linear( + in_features=self._input_embedding_dim, + out_features=self._hidden_dim, + ), + SwishLayerNorm(self._hidden_dim), + torch.nn.Linear( + in_features=self._hidden_dim, + out_features=self._output_embedding_dim, + ), + LayerNorm(self._output_embedding_dim), + ).apply(init_mlp_weights_optional_bias) + self._additional_embedding_features: List[str] = additional_embedding_features + self._additional_embedding_mlp: torch.nn.Module = torch.nn.Sequential( + torch.nn.Linear( + in_features=self._input_embedding_dim + * len(additional_embedding_features), + out_features=self._hidden_dim, + ), + SwishLayerNorm(self._hidden_dim), + torch.nn.Linear( + in_features=self._hidden_dim, + out_features=self._output_embedding_dim, + ), + LayerNorm(self._output_embedding_dim), + ).apply(init_mlp_weights_optional_bias) + self._action_feature_name: str = action_feature_name + self._action_weights: Optional[List[int]] = action_weights + if self._action_weights is not None: + self._action_encoder: ActionEncoder = ActionEncoder( + action_feature_name=action_feature_name, + action_weights=self._action_weights, + action_embedding_dim=action_embedding_dim, + embedding_init_std=action_embedding_init_std, + is_inference=is_inference, + ) + self._action_embedding_mlp: torch.nn.Module = torch.nn.Sequential( + torch.nn.Linear( + in_features=self._action_encoder.output_embedding_dim, + out_features=self._hidden_dim, + ), + SwishLayerNorm(self._hidden_dim), + torch.nn.Linear( + in_features=self._hidden_dim, + out_features=self._output_embedding_dim, + ), + LayerNorm(self._output_embedding_dim), + ).apply(init_mlp_weights_optional_bias) + + def forward( # noqa C901 + self, + max_uih_len: int, + max_targets: int, + total_uih_len: int, + total_targets: int, + seq_lengths: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_embeddings: torch.Tensor, + num_targets: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> Tuple[ + int, + int, + int, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Dict[str, torch.Tensor], + ]: + output_seq_embeddings = self._content_embedding_mlp(seq_embeddings) + if len(self._additional_embedding_features) > 0: + additional_embeddings = torch.cat( + [ + seq_payloads[feature] + for feature in self._additional_embedding_features + ], + dim=1, + ) + output_seq_embeddings = ( + output_seq_embeddings + + self._additional_embedding_mlp(additional_embeddings) + ) + max_seq_len = max_uih_len + max_targets + target_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(num_targets) + seq_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(seq_lengths) + uih_offsets = seq_offsets - target_offsets + if self._action_weights is not None: + action_embeddings = self._action_encoder( + max_uih_len=max_uih_len, + max_targets=max_targets, + uih_offsets=uih_offsets, + target_offsets=target_offsets, + seq_embeddings=seq_embeddings, + seq_payloads=seq_payloads, + ) + output_seq_embeddings = output_seq_embeddings + self._action_embedding_mlp( + action_embeddings + ) + + output_max_seq_len = max_seq_len + output_total_uih_len = total_uih_len + output_total_targets = total_targets + output_seq_lengths = seq_lengths + output_num_targets = num_targets + output_seq_timestamps = seq_timestamps + output_seq_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + output_seq_lengths + ) + # concat contextual embeddings + if self._max_contextual_seq_len > 0: + contextual_input_embeddings = get_contextual_input_embeddings( + seq_lengths=seq_lengths, + seq_payloads=seq_payloads, + contextual_feature_to_max_length=self._contextual_feature_to_max_length, + contextual_feature_to_min_uih_length=self._contextual_feature_to_min_uih_length, + dtype=seq_embeddings.dtype, + ) + contextual_embeddings = torch.baddbmm( + self._batched_contextual_linear_bias.view( + -1, 1, self._output_embedding_dim + ).to(contextual_input_embeddings.dtype), + contextual_input_embeddings.view( + -1, self._max_contextual_seq_len, self._input_embedding_dim + ).transpose(0, 1), + self._batched_contextual_linear_weights.to( + contextual_input_embeddings.dtype + ), + ).transpose(0, 1) + output_seq_embeddings = concat_2D_jagged( + max_seq_len=self._max_contextual_seq_len + output_max_seq_len, + values_left=fx_unwrap_optional_tensor(contextual_embeddings).reshape( + -1, self._output_embedding_dim + ), + values_right=output_seq_embeddings, + max_len_left=self._max_contextual_seq_len, + max_len_right=output_max_seq_len, + offsets_left=None, + offsets_right=output_seq_offsets, + kernel=self.hammer_kernel(), + ) + output_seq_timestamps = concat_2D_jagged( + max_seq_len=self._max_contextual_seq_len + output_max_seq_len, + values_left=torch.zeros( + (output_seq_lengths.size(0) * self._max_contextual_seq_len, 1), + dtype=output_seq_timestamps.dtype, + device=output_seq_timestamps.device, + ), + values_right=output_seq_timestamps.unsqueeze(-1), + max_len_left=self._max_contextual_seq_len, + max_len_right=output_max_seq_len, + offsets_left=None, + offsets_right=output_seq_offsets, + kernel=self.hammer_kernel(), + ).squeeze(-1) + output_max_seq_len = output_max_seq_len + self._max_contextual_seq_len + output_seq_lengths = output_seq_lengths + self._max_contextual_seq_len + output_seq_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + output_seq_lengths + ) + output_total_uih_len = ( + output_total_uih_len + + self._max_contextual_seq_len * output_seq_lengths.size(0) + ) + + return ( + output_max_seq_len, + output_total_uih_len, + output_total_targets, + output_seq_lengths, + output_seq_offsets, + output_seq_timestamps, + output_seq_embeddings, + output_num_targets, + seq_payloads, + ) diff --git a/recommendation_v4/generative_recommenders/modules/stu.py b/recommendation_v4/generative_recommenders/modules/stu.py new file mode 100644 index 000000000..45c6ea5f3 --- /dev/null +++ b/recommendation_v4/generative_recommenders/modules/stu.py @@ -0,0 +1,471 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict +import abc +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch +from generative_recommenders.common import fx_unwrap_optional_tensor, HammerModule +from generative_recommenders.ops.hstu_attention import delta_hstu_mha +from generative_recommenders.ops.hstu_compute import ( + hstu_compute_output, + hstu_compute_uqvk, + hstu_preprocess_and_attention, +) +from generative_recommenders.ops.jagged_tensors import concat_2D_jagged, split_2D_jagged +from torch.autograd.profiler import record_function + + +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") +except OSError: + pass + + +class STU(HammerModule, abc.ABC): + def cached_forward( + self, + delta_x: torch.Tensor, + num_targets: torch.Tensor, + max_kv_caching_len: int = 0, + kv_caching_lengths: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + @abc.abstractmethod + def forward( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + x_offsets: torch.Tensor, + max_seq_len: int, + num_targets: torch.Tensor, + max_kv_caching_len: int = 0, + kv_caching_lengths: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pass + + +@dataclass +class STULayerConfig: + embedding_dim: int + num_heads: int + hidden_dim: int + attention_dim: int + output_dropout_ratio: float = 0.3 + causal: bool = True + target_aware: bool = True + max_attn_len: Optional[int] = None + attn_alpha: Optional[float] = None + use_group_norm: bool = False + recompute_normed_x: bool = True + recompute_uvqk: bool = True + recompute_y: bool = True + sort_by_length: bool = True + contextual_seq_len: int = 0 + + +@torch.fx.wrap +def _update_kv_cache( + max_seq_len: int, + seq_offsets: torch.Tensor, + k: Optional[torch.Tensor], + v: Optional[torch.Tensor], + max_kv_caching_len: int, + kv_caching_lengths: Optional[torch.Tensor], + orig_k_cache: Optional[torch.Tensor], + orig_v_cache: Optional[torch.Tensor], + orig_max_kv_caching_len: int, + orig_kv_caching_offsets: Optional[torch.Tensor], +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], int, Optional[torch.Tensor]]: + if kv_caching_lengths is not None: + kv_caching_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + kv_caching_lengths + ) + delta_offsets = seq_offsets - kv_caching_offsets + k_cache, _ = split_2D_jagged( + max_seq_len=max_seq_len, + values=fx_unwrap_optional_tensor(k).flatten(1, 2), + max_len_left=None, + max_len_right=None, + offsets_left=kv_caching_offsets, + offsets_right=delta_offsets, + ) + v_cache, _ = split_2D_jagged( + max_seq_len=max_seq_len, + values=fx_unwrap_optional_tensor(v).flatten(1, 2), + max_len_left=None, + max_len_right=None, + offsets_left=kv_caching_offsets, + offsets_right=delta_offsets, + ) + if max_kv_caching_len == 0: + max_kv_caching_len = int(kv_caching_lengths.max().item()) + return ( + k_cache, + v_cache, + max_kv_caching_len, + kv_caching_offsets, + ) + else: + return ( + orig_k_cache, + orig_v_cache, + orig_max_kv_caching_len, + orig_kv_caching_offsets, + ) + + +@torch.fx.wrap +def _construct_full_kv( + delta_k: torch.Tensor, + delta_v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + max_kv_caching_len: int, + kv_caching_offsets: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor]: + L, _ = delta_k.shape + B = kv_caching_offsets.shape[0] - 1 + delta_size = L // B + full_k = concat_2D_jagged( + max_seq_len=max_kv_caching_len + delta_size, + values_left=k_cache, + values_right=delta_k, + max_len_left=max_kv_caching_len, + max_len_right=delta_size, + offsets_left=kv_caching_offsets, + offsets_right=None, + ) + full_v = concat_2D_jagged( + max_seq_len=max_kv_caching_len + delta_size, + values_left=v_cache, + values_right=delta_v, + max_len_left=max_kv_caching_len, + max_len_right=delta_size, + offsets_left=kv_caching_offsets, + offsets_right=None, + ) + full_kv_caching_offsets = kv_caching_offsets + delta_size * torch.arange( + B + 1, device=delta_k.device + ) + return ( + full_k, + full_v, + max_kv_caching_len + delta_size, + full_kv_caching_offsets, + ) + + +class STULayer(STU): + max_kv_caching_len: int + k_cache: Optional[torch.Tensor] + v_cache: Optional[torch.Tensor] + kv_caching_offsets: Optional[torch.Tensor] + + def __init__( + self, + config: STULayerConfig, + is_inference: bool = False, + ) -> None: + super().__init__( + is_inference=is_inference, + ) + self.reset_kv_cache() + self._num_heads: int = config.num_heads + self._embedding_dim: int = config.embedding_dim + self._hidden_dim: int = config.hidden_dim + self._attention_dim: int = config.attention_dim + self._output_dropout_ratio: float = config.output_dropout_ratio + self._target_aware: bool = config.target_aware + self._causal: bool = config.causal + self._max_attn_len: int = config.max_attn_len or 0 + self._attn_alpha: float = config.attn_alpha or 1.0 / (self._attention_dim**0.5) + self._use_group_norm: bool = config.use_group_norm + self._recompute_normed_x: bool = config.recompute_normed_x + self._recompute_uvqk: bool = config.recompute_uvqk + self._recompute_y: bool = config.recompute_y + self._sort_by_length: bool = config.sort_by_length + self._contextual_seq_len: int = config.contextual_seq_len + + self._uvqk_weight: torch.nn.Parameter = torch.nn.Parameter( + torch.empty( + ( + self._embedding_dim, + (self._hidden_dim * 2 + self._attention_dim * 2) * self._num_heads, + ) + ), + ) + torch.nn.init.xavier_uniform_(self._uvqk_weight) + self._uvqk_beta: torch.nn.Parameter = torch.nn.Parameter( + torch.zeros( + (self._hidden_dim * 2 + self._attention_dim * 2) * self._num_heads, + ), + ) + self._input_norm_weight: torch.nn.Parameter = torch.nn.Parameter( + torch.ones((self._embedding_dim,)), + ) + self._input_norm_bias: torch.nn.Parameter = torch.nn.Parameter( + torch.zeros((self._embedding_dim,)), + ) + self._output_weight = torch.nn.Parameter( + torch.empty( + ( + self._hidden_dim * self._num_heads * 3, + self._embedding_dim, + ) + ), + ) + torch.nn.init.xavier_uniform_(self._output_weight) + output_norm_shape: int = ( + self._hidden_dim * self._num_heads + if not self._use_group_norm + else self._num_heads + ) + self._output_norm_weight: torch.nn.Parameter = torch.nn.Parameter( + torch.ones((output_norm_shape,)), + ) + self._output_norm_bias: torch.nn.Parameter = torch.nn.Parameter( + torch.zeros((output_norm_shape,)), + ) + + def reset_kv_cache(self) -> None: + self.k_cache = None + self.v_cache = None + self.kv_caching_offsets = None + self.max_kv_caching_len = 0 + + def update_kv_cache( + self, + max_seq_len: int, + seq_offsets: torch.Tensor, + k: Optional[torch.Tensor], + v: Optional[torch.Tensor], + max_kv_caching_len: int, + kv_caching_lengths: Optional[torch.Tensor], + ) -> None: + self.k_cache, self.v_cache, self.max_kv_caching_len, self.kv_caching_offsets = ( + _update_kv_cache( + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + k=k, + v=v, + max_kv_caching_len=max_kv_caching_len, + kv_caching_lengths=kv_caching_lengths, + orig_k_cache=self.k_cache, + orig_v_cache=self.v_cache, + orig_max_kv_caching_len=self.max_kv_caching_len, + orig_kv_caching_offsets=self.kv_caching_offsets, + ) + ) + + def construct_full_kv( + self, + delta_k: torch.Tensor, + delta_v: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor]: + return _construct_full_kv( + delta_k=delta_k, + delta_v=delta_v, + k_cache=fx_unwrap_optional_tensor(self.k_cache), + v_cache=fx_unwrap_optional_tensor(self.v_cache), + max_kv_caching_len=self.max_kv_caching_len, + kv_caching_offsets=self.kv_caching_offsets, + ) + + def forward( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + x_offsets: torch.Tensor, + max_seq_len: int, + num_targets: torch.Tensor, + max_kv_caching_len: int = 0, + kv_caching_lengths: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + with record_function("## stu_preprocess_and_attention ##"): + u, attn_output, k, v = hstu_preprocess_and_attention( + x=x, + norm_weight=self._input_norm_weight.to(x.dtype), + norm_bias=self._input_norm_bias.to(x.dtype), + norm_eps=1e-6, + num_heads=self._num_heads, + attn_dim=self._attention_dim, + hidden_dim=self._hidden_dim, + uvqk_weight=self._uvqk_weight.to(x.dtype), + uvqk_bias=self._uvqk_beta.to(x.dtype), + max_seq_len=max_seq_len, + seq_offsets=x_offsets, + attn_alpha=self._attn_alpha, + causal=self._causal, + num_targets=num_targets if self._target_aware else None, + max_attn_len=self._max_attn_len, + contextual_seq_len=self._contextual_seq_len, + recompute_uvqk_in_backward=self._recompute_uvqk, + recompute_normed_x_in_backward=self._recompute_normed_x, + sort_by_length=self._sort_by_length, + prefill=kv_caching_lengths is not None, + kernel=self.hammer_kernel(), + ) + + self.update_kv_cache( + max_seq_len=max_seq_len, + seq_offsets=x_offsets, + k=k, + v=v, + max_kv_caching_len=max_kv_caching_len, + kv_caching_lengths=kv_caching_lengths, + ) + + with record_function("## stu_compute_output ##"): + return hstu_compute_output( + attn=attn_output, + u=u, + x=x, + norm_weight=self._output_norm_weight.to(x.dtype), + norm_bias=self._output_norm_bias.to(x.dtype), + norm_eps=1e-6, + dropout_ratio=self._output_dropout_ratio, + output_weight=self._output_weight.to(x.dtype), + group_norm=self._use_group_norm, + num_heads=self._num_heads, + linear_dim=self._hidden_dim, + concat_u=True, + concat_x=True, + mul_u_activation_type="none", + training=self.training, + kernel=self.hammer_kernel(), + recompute_y_in_backward=self._recompute_y, + ) + + def cached_forward( + self, + delta_x: torch.Tensor, + num_targets: torch.Tensor, + max_kv_caching_len: int = 0, + kv_caching_lengths: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + with record_function("## stu_compute_uqvk ##"): + delta_u, delta_q, delta_k, delta_v = hstu_compute_uqvk( + x=delta_x, + norm_weight=self._input_norm_weight.to(delta_x.dtype), + norm_bias=self._input_norm_bias.to(delta_x.dtype), + norm_eps=1e-6, + num_heads=self._num_heads, + attn_dim=self._attention_dim, + hidden_dim=self._hidden_dim, + uvqk_weight=self._uvqk_weight.to(delta_x.dtype), + uvqk_bias=self._uvqk_beta.to(delta_x.dtype), + kernel=self.hammer_kernel(), + ) + k, v, max_seq_len, seq_offsets = self.construct_full_kv( + delta_k=delta_k.flatten(1, 2), + delta_v=delta_v.flatten(1, 2), + ) + self.update_kv_cache( + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + k=k, + v=v, + max_kv_caching_len=max_kv_caching_len, + kv_caching_lengths=kv_caching_lengths, + ) + k = k.view(-1, self._num_heads, self._attention_dim) + v = v.view(-1, self._num_heads, self._hidden_dim) + with record_function("## delta_hstu_mha ##"): + delta_attn_output = delta_hstu_mha( + max_seq_len=max_seq_len, + alpha=self._attn_alpha, + delta_q=delta_q, + k=k, + v=v, + seq_offsets=seq_offsets, + num_targets=num_targets if self._target_aware else None, + max_attn_len=self._max_attn_len, + contextual_seq_len=self._contextual_seq_len, + kernel=self.hammer_kernel(), + ).view(-1, self._hidden_dim * self._num_heads) + with record_function("## stu_compute_output ##"): + return hstu_compute_output( + attn=delta_attn_output, + u=delta_u, + x=delta_x, + norm_weight=self._output_norm_weight.to(delta_x.dtype), + norm_bias=self._output_norm_bias.to(delta_x.dtype), + norm_eps=1e-6, + dropout_ratio=self._output_dropout_ratio, + output_weight=self._output_weight.to(delta_x.dtype), + group_norm=self._use_group_norm, + num_heads=self._num_heads, + linear_dim=self._hidden_dim, + concat_u=True, + concat_x=True, + mul_u_activation_type="none", + training=self.training, + kernel=self.hammer_kernel(), + recompute_y_in_backward=self._recompute_y, + ) + + +class STUStack(STU): + def __init__( + self, + stu_list: List[STU], + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + self._stu_layers: torch.nn.ModuleList = torch.nn.ModuleList(modules=stu_list) + + def forward( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + x_offsets: torch.Tensor, + max_seq_len: int, + num_targets: torch.Tensor, + max_kv_caching_len: int = 0, + kv_caching_lengths: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + for layer in self._stu_layers: + x = layer( + x=x, + x_lengths=x_lengths, + x_offsets=x_offsets, + max_seq_len=max_seq_len, + num_targets=num_targets, + max_kv_caching_len=max_kv_caching_len, + kv_caching_lengths=kv_caching_lengths, + ) + return x + + def cached_forward( + self, + delta_x: torch.Tensor, + num_targets: torch.Tensor, + max_kv_caching_len: int = 0, + kv_caching_lengths: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + for layer in self._stu_layers: + delta_x = layer.cached_forward( # pyre-ignore [29] + delta_x=delta_x, + num_targets=num_targets, + max_kv_caching_len=max_kv_caching_len, + kv_caching_lengths=kv_caching_lengths, + ) + return delta_x diff --git a/recommendation_v4/generative_recommenders/modules/tests/action_encoder_test.py b/recommendation_v4/generative_recommenders/modules/tests/action_encoder_test.py new file mode 100644 index 000000000..184b314ea --- /dev/null +++ b/recommendation_v4/generative_recommenders/modules/tests/action_encoder_test.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import unittest + +import torch +from generative_recommenders.common import gpu_unavailable +from generative_recommenders.modules.action_encoder import ActionEncoder + + +class ActionEncoderTest(unittest.TestCase): + @unittest.skipIf(*gpu_unavailable) + def test_forward(self) -> None: + device = torch.device("cuda") + action_embedding_dim = 32 + action_weights = [1, 2, 4, 8, 16] + watchtime_to_action_thresholds_and_weights = [ + (30, 32), + (60, 64), + (100, 128), + ] + num_action_types = len(action_weights) + len( + watchtime_to_action_thresholds_and_weights + ) + combined_action_weights = action_weights + [ + x[1] for x in watchtime_to_action_thresholds_and_weights + ] + enabled_actions = [ + [0], + [0, 1], + [1, 3, 4], + [1, 2, 3, 4], + [1, 2], + [2], + ] + watchtimes = [40, 20, 110, 31, 26, 55] + for i, wt in enumerate(watchtimes): + for j, w in enumerate(watchtime_to_action_thresholds_and_weights): + if wt > w[0]: + enabled_actions[i].append(j + len(action_weights)) + actions = [ + sum([combined_action_weights[t] for t in x]) for x in enabled_actions + ] + + encoder = ActionEncoder( + watchtime_feature_name="watchtimes", + action_feature_name="actions", + action_weights=action_weights, + watchtime_to_action_thresholds_and_weights=watchtime_to_action_thresholds_and_weights, + action_embedding_dim=action_embedding_dim, + is_inference=False, + ).to(device) + + seq_lengths = [6, 3] + seq_offsets = [0, 6, 9] + num_targets = [2, 1] + uih_offsets = [0, 4, 6] + target_offsets = [0, 2, 3] + seq_embeddings = torch.rand(9, 128, device=device) + action_embeddings = encoder( + max_uih_len=4, + max_targets=2, + uih_offsets=torch.tensor(uih_offsets, device=device), + target_offsets=torch.tensor(target_offsets, device=device), + seq_embeddings=seq_embeddings, + seq_payloads={ + "watchtimes": torch.tensor(watchtimes, device=device), + "actions": torch.tensor(actions, device=device), + }, + ) + self.assertEqual( + action_embeddings.shape, (9, action_embedding_dim * num_action_types) + ) + for b in range(len(seq_lengths)): + b_start = seq_offsets[b] + b_end = seq_offsets[b + 1] + u_start = uih_offsets[b] + for j in range(b_start, b_end): + embedding = action_embeddings[j].view(num_action_types, -1) + for atype in range(num_action_types): + if b_end - j <= num_targets[b]: + torch.testing.assert_allclose( + embedding[atype], + encoder._target_action_embedding_table.view( + num_action_types, -1 + )[atype], + ) + else: + if atype in enabled_actions[j - b_start + u_start]: + torch.testing.assert_allclose( + embedding[atype], + encoder._action_embedding_table[atype], + ) + else: + torch.testing.assert_allclose( + embedding[atype], torch.zeros_like(embedding[atype]) + ) + action_embeddings.sum().backward() diff --git a/recommendation_v4/generative_recommenders/modules/tests/content_encoder_test.py b/recommendation_v4/generative_recommenders/modules/tests/content_encoder_test.py new file mode 100644 index 000000000..e67656388 --- /dev/null +++ b/recommendation_v4/generative_recommenders/modules/tests/content_encoder_test.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import unittest + +import torch +from generative_recommenders.common import gpu_unavailable +from generative_recommenders.modules.content_encoder import ContentEncoder + + +class ContentEncoderTest(unittest.TestCase): + @unittest.skipIf(*gpu_unavailable) + def test_forward(self) -> None: + device = torch.device("cuda") + input_embedding_dim = 32 + additional_embedding_dim = 64 + enrich_embedding_dim = 16 + encoder = ContentEncoder( + input_embedding_dim=input_embedding_dim, + additional_content_features={ + "a0": additional_embedding_dim, + "a1": additional_embedding_dim, + }, + target_enrich_features={ + "t0": enrich_embedding_dim, + "t1": enrich_embedding_dim, + }, + is_inference=False, + ).to(device) + seq_lengths = [6, 3] + num_targets = [2, 1] + uih_offsets = [0, 4, 6] + target_offsets = [0, 2, 3] + seq_embeddings = torch.rand( + sum(seq_lengths), input_embedding_dim, device=device + ).requires_grad_(True) + seq_payloads = { + "a0": torch.rand( + sum(seq_lengths), additional_embedding_dim, device=device + ).requires_grad_(True), + "a1": torch.rand( + sum(seq_lengths), additional_embedding_dim, device=device + ).requires_grad_(True), + "t0": torch.rand( + sum(num_targets), enrich_embedding_dim, device=device + ).requires_grad_(True), + "t1": torch.rand( + sum(num_targets), enrich_embedding_dim, device=device + ).requires_grad_(True), + } + content_embeddings = encoder( + max_uih_len=4, + max_targets=2, + uih_offsets=torch.tensor(uih_offsets, device=device), + target_offsets=torch.tensor(target_offsets, device=device), + seq_embeddings=seq_embeddings, + seq_payloads=seq_payloads, + ) + content_embeddings.sum().backward() diff --git a/recommendation_v4/generative_recommenders/modules/tests/contextual_interleave_preprocessor_test.py b/recommendation_v4/generative_recommenders/modules/tests/contextual_interleave_preprocessor_test.py new file mode 100644 index 000000000..c3202072c --- /dev/null +++ b/recommendation_v4/generative_recommenders/modules/tests/contextual_interleave_preprocessor_test.py @@ -0,0 +1,499 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import unittest + +import torch +from generative_recommenders.common import gpu_unavailable +from generative_recommenders.modules.action_encoder import ActionEncoder +from generative_recommenders.modules.content_encoder import ContentEncoder +from generative_recommenders.modules.contextual_interleave_preprocessor import ( + ContextualInterleavePreprocessor, +) +from generative_recommenders.modules.contextualize_mlps import ( + ParameterizedContextualizedMLP, + SimpleContextualizedMLP, +) +from hypothesis import given, settings, strategies as st, Verbosity + + +class ContextualInterleavePreprocessorTest(unittest.TestCase): + # pyre-ignore + @given( + enable_interleaving=st.sampled_from([True, False]), + enable_pmlp=st.sampled_from([True, False]), + is_train=st.sampled_from([True, False]), + dtype=st.sampled_from( + [torch.float32, torch.bfloat16] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + ) + @unittest.skipIf(*gpu_unavailable) + @settings(verbosity=Verbosity.verbose, max_examples=50, deadline=None) + def test_forward( + self, + enable_interleaving: bool, + enable_pmlp: bool, + is_train: bool, + dtype: torch.dtype, + ) -> None: + device = torch.device("cuda") + + input_embedding_dim = 64 + output_embedding_dim = 32 + action_embedding_dim = 16 + action_encoder_hidden_dim = 256 + content_encoder_hidden_dim = 128 + contextual_len = 3 + + content_encoder = ContentEncoder( + input_embedding_dim=input_embedding_dim, + additional_content_features={ + "a0": input_embedding_dim, + "a1": input_embedding_dim, + }, + target_enrich_features={ + "t0": input_embedding_dim, + "t1": input_embedding_dim, + }, + is_inference=False, + ).to(device) + action_embedding_dim = 32 + action_weights = [1, 2, 4, 8, 16] + watchtime_to_action_thresholds_and_weights = [ + (30, 32), + (60, 64), + (100, 128), + ] + action_encoder = ActionEncoder( + watchtime_feature_name="watchtimes", + action_feature_name="actions", + action_weights=action_weights, + watchtime_to_action_thresholds_and_weights=watchtime_to_action_thresholds_and_weights, + action_embedding_dim=action_embedding_dim, + is_inference=False, + ).to(device) + + preprocessor = ContextualInterleavePreprocessor( + input_embedding_dim=input_embedding_dim, + output_embedding_dim=output_embedding_dim, + contextual_feature_to_max_length={"c_0": 1, "c_1": 2}, + contextual_feature_to_min_uih_length={"c_1": 4}, + pmlp_contextual_dropout_ratio=0.2, + content_encoder=content_encoder, + content_contextualize_mlp_fn=lambda in_dim, + out_dim, + contextual_dim, + is_inference: ParameterizedContextualizedMLP( + contextual_embedding_dim=contextual_dim, + sequential_input_dim=in_dim, + sequential_output_dim=out_dim, + hidden_dim=content_encoder_hidden_dim, + is_inference=is_inference, + ) + if enable_pmlp + else SimpleContextualizedMLP( + sequential_input_dim=in_dim, + sequential_output_dim=out_dim, + hidden_dim=content_encoder_hidden_dim, + is_inference=is_inference, + ), + action_encoder=action_encoder, + action_contextualize_mlp_fn=lambda in_dim, + out_dim, + contextual_dim, + is_inference: ParameterizedContextualizedMLP( + contextual_embedding_dim=contextual_dim, + sequential_input_dim=in_dim, + sequential_output_dim=out_dim, + hidden_dim=action_encoder_hidden_dim, + is_inference=is_inference, + ) + if enable_pmlp + else SimpleContextualizedMLP( + sequential_input_dim=in_dim, + sequential_output_dim=out_dim, + hidden_dim=action_encoder_hidden_dim, + is_inference=is_inference, + ), + enable_interleaving=enable_interleaving, + is_inference=False, + ).to(device) + preprocessor.set_training_dtype(dtype) + if not is_train: + preprocessor.eval() + + # inputs + seq_lengths = [6, 3] + num_targets = [2, 1] + seq_embeddings = torch.rand( + (sum(seq_lengths), input_embedding_dim), + device=device, + dtype=dtype, + ) + seq_timestamps = torch.tensor( + [1, 2, 3, 4, 5, 6, 10, 20, 30], + device=device, + ) + watchtimes = [40, 20, 110, 31, 26, 55] + actions = [1, 3, 26, 30, 6, 4] + ( + output_max_seq_len, + output_total_uih_len, + output_total_targets, + output_seq_lengths, + output_seq_offsets, + output_seq_timestamps, + output_seq_embeddings, + output_num_targets, + _, + ) = preprocessor( + max_uih_len=4, + max_targets=2, + total_uih_len=sum(seq_lengths) - sum(num_targets), + total_targets=sum(num_targets), + seq_lengths=torch.tensor(seq_lengths, device=device), + seq_timestamps=seq_timestamps, + seq_embeddings=seq_embeddings, + seq_payloads={ + # contextual + "c_0": torch.rand((2, input_embedding_dim), device=device, dtype=dtype), + "c_0_offsets": torch.tensor([0, 1, 1], device=device), + "c_1": torch.rand((4, input_embedding_dim), device=device, dtype=dtype), + "c_1_offsets": torch.tensor([0, 2, 3], device=device), + # action + "watchtimes": torch.tensor(watchtimes, device=device), + "actions": torch.tensor(actions, device=device), + # content + "a0": torch.rand_like(seq_embeddings).requires_grad_(True), + "a1": torch.rand_like(seq_embeddings).requires_grad_(True), + "t0": torch.rand( + sum(num_targets), input_embedding_dim, device=device, dtype=dtype + ).requires_grad_(True), + "t1": torch.rand( + sum(num_targets), input_embedding_dim, device=device, dtype=dtype + ).requires_grad_(True), + }, + num_targets=torch.tensor(num_targets, device=device), + ) + if enable_interleaving: + if is_train: + expected_output_seq_lengths = [ + 2 * s + contextual_len for s in seq_lengths + ] + expected_max_seq_len = max(expected_output_seq_lengths) + expected_output_num_targets = [2 * s for s in num_targets] + expected_seq_embedding_size = ( + sum(expected_output_seq_lengths), + output_embedding_dim, + ) + expected_seq_timestamps_size = (sum(expected_output_seq_lengths),) + expected_output_seq_timestamps = [ + 0, + 0, + 0, + 1, + 1, + 2, + 2, + 3, + 3, + 4, + 4, + 5, + 5, + 6, + 6, + 0, + 0, + 0, + 10, + 10, + 20, + 20, + 30, + 30, + ] + else: + expected_output_seq_lengths = [ + 2 * s - n + contextual_len for s, n in zip(seq_lengths, num_targets) + ] + expected_max_seq_len = max(expected_output_seq_lengths) + expected_output_num_targets = num_targets + expected_seq_embedding_size = ( + sum(expected_output_seq_lengths), + output_embedding_dim, + ) + expected_seq_timestamps_size = (sum(expected_output_seq_lengths),) + expected_output_seq_timestamps = [ + 0, + 0, + 0, + 1, + 1, + 2, + 2, + 3, + 3, + 4, + 4, + 5, + 6, + 0, + 0, + 0, + 10, + 10, + 20, + 20, + 30, + ] + else: + expected_output_seq_lengths = [s + contextual_len for s in seq_lengths] + expected_max_seq_len = max(expected_output_seq_lengths) + expected_output_num_targets = num_targets + expected_seq_embedding_size = ( + sum(expected_output_seq_lengths), + output_embedding_dim, + ) + expected_seq_timestamps_size = (sum(expected_output_seq_lengths),) + expected_output_seq_timestamps = [ + 0, + 0, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 0, + 0, + 0, + 10, + 20, + 30, + ] + + self.assertEqual(output_max_seq_len, expected_max_seq_len) + self.assertEqual(output_seq_lengths.tolist(), expected_output_seq_lengths) + torch.testing.assert_allclose( + torch.ops.fbgemm.asynchronous_complete_cumsum(output_seq_lengths), + output_seq_offsets, + ) + self.assertEqual(output_num_targets.tolist(), expected_output_num_targets) + self.assertEqual( + output_seq_embeddings.size(), + expected_seq_embedding_size, + ) + self.assertEqual( + output_seq_timestamps.size(), + expected_seq_timestamps_size, + ) + self.assertEqual( + output_seq_timestamps.tolist(), + expected_output_seq_timestamps, + ) + + # test combine embeddings + batch_size = 10 + max_uih_len = 100 + max_targets = 20 + max_seq_len = max_uih_len + max_targets + seq_lengths = torch.randint(0, max_uih_len, (batch_size,), device=device) + total_uih_len = int(seq_lengths.sum().item()) + num_targets = torch.randint(1, max_targets, (batch_size,), device=device) + total_targets = int(num_targets.sum().item()) + seq_lengths = seq_lengths + num_targets + seq_offsets = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + seq_offsets[1:] = torch.cumsum(seq_lengths, dim=0) + total_seq_len = int(torch.sum(seq_lengths).item()) + seq_timestamps = torch.randint(0, 1000000, (total_seq_len,), device=device) + content_embeddings = torch.rand( + (total_seq_len, output_embedding_dim), + device=device, + ).requires_grad_(True) + action_embeddings = torch.rand( + (total_seq_len, output_embedding_dim), + device=device, + ).requires_grad_(True) + contextual_embeddings = torch.rand( + (total_seq_len, 3 * output_embedding_dim), device=device + ).requires_grad_(True) + ( + output_max_seq_len, + output_total_uih_len, + output_total_targets, + output_seq_lengths, + output_seq_offsets, + output_seq_timestamps, + output_seq_embeddings, + output_num_targets, + ) = preprocessor.combine_embeddings( + max_uih_len=max_uih_len, + max_targets=max_targets, + total_uih_len=total_uih_len, + total_targets=total_targets, + seq_lengths=seq_lengths, + seq_timestamps=seq_timestamps, + content_embeddings=content_embeddings, + action_embeddings=action_embeddings, + contextual_embeddings=contextual_embeddings, + num_targets=num_targets, + ) + seq_embeddings = action_embeddings + content_embeddings + if enable_interleaving: + if is_train: + self.assertEqual( + output_max_seq_len, + max_seq_len * 2 + contextual_len, + ) + self.assertEqual( + output_total_uih_len, + total_uih_len * 2 + contextual_len * batch_size, + ) + self.assertEqual( + output_total_targets, + total_targets * 2, + ) + torch.testing.assert_allclose( + output_seq_lengths, seq_lengths * 2 + contextual_len + ) + torch.testing.assert_allclose(output_num_targets, num_targets * 2) + else: + self.assertEqual( + output_max_seq_len, + max_uih_len * 2 + max_targets + contextual_len, + ) + self.assertEqual( + output_total_uih_len, + total_uih_len * 2 + contextual_len * batch_size, + ) + self.assertEqual( + output_total_targets, + total_targets, + ) + torch.testing.assert_allclose( + output_seq_lengths, seq_lengths * 2 - num_targets + contextual_len + ) + torch.testing.assert_allclose(output_num_targets, num_targets) + else: + self.assertEqual( + output_max_seq_len, + max_seq_len + contextual_len, + ) + self.assertEqual( + output_total_uih_len, + total_uih_len + contextual_len * batch_size, + ) + self.assertEqual( + output_total_targets, + total_targets, + ) + torch.testing.assert_allclose( + output_seq_lengths, seq_lengths + contextual_len + ) + torch.testing.assert_allclose(output_num_targets, num_targets) + for b in range(batch_size): + input_start = int(seq_offsets[b].item()) + input_end = int(seq_offsets[b + 1].item()) + output_start = int(output_seq_offsets[b].item()) + output_end = int(output_seq_offsets[b + 1].item()) + input_targets = int(num_targets[b].item()) + output_targets = int(output_num_targets[b].item()) + torch.testing.assert_allclose( + output_seq_timestamps[output_start : output_start + contextual_len], + torch.zeros(3, device=device), + ) + torch.testing.assert_allclose( + output_seq_embeddings[ + output_start : output_start + contextual_len + ].view(-1), + contextual_embeddings[b], + ) + if enable_interleaving: + torch.testing.assert_allclose( + output_seq_timestamps[ + output_start + contextual_len : output_end - output_targets + ].view(-1, 2)[:, 0], + seq_timestamps[input_start : input_end - input_targets], + ) + torch.testing.assert_allclose( + output_seq_timestamps[ + output_start + contextual_len : output_end - output_targets + ].view(-1, 2)[:, 1], + seq_timestamps[input_start : input_end - input_targets], + ) + torch.testing.assert_allclose( + output_seq_embeddings[ + output_start + contextual_len : output_end - output_targets + ].view(-1, 2, output_embedding_dim)[:, 0, :], + content_embeddings[input_start : input_end - input_targets], + ) + torch.testing.assert_allclose( + output_seq_embeddings[ + output_start + contextual_len : output_end - output_targets + ].view(-1, 2, output_embedding_dim)[:, 1, :], + action_embeddings[input_start : input_end - input_targets], + ) + if is_train: + torch.testing.assert_allclose( + output_seq_timestamps[ + output_end - output_targets : output_end + ].view(-1, 2)[:, 0], + seq_timestamps[input_end - input_targets : input_end], + ) + torch.testing.assert_allclose( + output_seq_timestamps[ + output_end - output_targets : output_end + ].view(-1, 2)[:, 1], + seq_timestamps[input_end - input_targets : input_end], + ) + torch.testing.assert_allclose( + output_seq_embeddings[ + output_end - output_targets : output_end + ].view(-1, 2, output_embedding_dim)[:, 0, :], + content_embeddings[input_end - input_targets : input_end], + ) + torch.testing.assert_allclose( + output_seq_embeddings[ + output_end - output_targets : output_end + ].view(-1, 2, output_embedding_dim)[:, 1, :], + action_embeddings[input_end - input_targets : input_end], + ) + else: + torch.testing.assert_allclose( + output_seq_timestamps[output_end - output_targets : output_end], + seq_timestamps[input_end - input_targets : input_end], + ) + torch.testing.assert_allclose( + output_seq_embeddings[output_end - output_targets : output_end], + content_embeddings[input_end - input_targets : input_end], + ) + else: + torch.testing.assert_allclose( + output_seq_timestamps[output_start + contextual_len : output_end], + seq_timestamps[input_start:input_end], + ) + torch.testing.assert_allclose( + output_seq_embeddings[output_start + contextual_len : output_end], + seq_embeddings[input_start:input_end], + ) diff --git a/recommendation_v4/generative_recommenders/modules/tests/dynamic_stu_test.py b/recommendation_v4/generative_recommenders/modules/tests/dynamic_stu_test.py new file mode 100644 index 000000000..c1c598f1f --- /dev/null +++ b/recommendation_v4/generative_recommenders/modules/tests/dynamic_stu_test.py @@ -0,0 +1,279 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import copy +import unittest +from typing import List + +import torch +from generative_recommenders.common import gpu_unavailable, HammerKernel, set_dev_mode +from generative_recommenders.modules.dynamic_stu import L2STU, SDSTU +from generative_recommenders.modules.stu import STU, STULayer, STULayerConfig, STUStack +from hypothesis import given, settings, strategies as st, Verbosity + + +class DynamicStuTest(unittest.TestCase): + # pyre-ignore + @given( + causal=st.sampled_from([True]), + num_layers=st.sampled_from([2]), + num_heads=st.sampled_from([2]), + max_uih_len=st.sampled_from([300]), + batch_size=st.sampled_from([8]), + embedding_dim=st.sampled_from([16]), + attention_dim=st.sampled_from([32]), + linear_hidden_dim=st.sampled_from([64]), + has_multiple_targets=st.sampled_from([True, False]), + contextual_seq_len=st.sampled_from([0, 10]), + dtype=st.sampled_from( + [torch.float32, torch.bfloat16] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + ) + @unittest.skipIf(*gpu_unavailable) + @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) + def test_l2_stu( + self, + causal: bool, + num_layers: int, + num_heads: int, + max_uih_len: int, + batch_size: int, + embedding_dim: int, + attention_dim: int, + linear_hidden_dim: int, + has_multiple_targets: bool, + contextual_seq_len: int, + dtype: torch.dtype, + ) -> None: + set_dev_mode(True) + device = torch.device("cuda") + l3_stu_layers: List[STU] = [ + STULayer( + config=STULayerConfig( + embedding_dim=embedding_dim, + num_heads=num_heads, + hidden_dim=linear_hidden_dim, + attention_dim=attention_dim, + output_dropout_ratio=0.0, + causal=causal, + target_aware=has_multiple_targets, + contextual_seq_len=contextual_seq_len, + ), + is_inference=False, + ) + for _ in range(num_layers) + ] + l3_stu: List[STU] = [ + L2STU( + stu=STUStack( + stu_list=l3_stu_layers, + is_inference=False, + ), + max_l2_len=100, + contextual_seq_len=contextual_seq_len, + is_inference=False, + ) + ] + l2_stu_layers: List[STU] = [ + STULayer( + config=STULayerConfig( + embedding_dim=embedding_dim, + num_heads=num_heads, + hidden_dim=linear_hidden_dim, + attention_dim=attention_dim, + output_dropout_ratio=0.0, + causal=causal, + target_aware=has_multiple_targets, + contextual_seq_len=contextual_seq_len, + ), + is_inference=False, + ) + for _ in range(num_layers) + ] + l3_stu + l2_stu: List[STU] = [ + L2STU( + stu=STUStack( + stu_list=l2_stu_layers, + is_inference=False, + ), + max_l2_len=200, + contextual_seq_len=contextual_seq_len, + is_inference=False, + ) + ] + stu_layers: List[STU] = [ + STULayer( + config=STULayerConfig( + embedding_dim=embedding_dim, + num_heads=num_heads, + hidden_dim=linear_hidden_dim, + attention_dim=attention_dim, + output_dropout_ratio=0.0, + causal=causal, + target_aware=has_multiple_targets, + contextual_seq_len=contextual_seq_len, + ), + is_inference=False, + ) + for _ in range(num_layers) + ] + l2_stu + stu = STUStack( + stu_list=stu_layers, + is_inference=False, + ).to(device) + stu.recursive_setattr("_hammer_kernel", HammerKernel.TRITON) + + x_lengths = torch.randint(max_uih_len + 1, (batch_size,), device=device) + x_lengths = x_lengths + contextual_seq_len + max_seq_len = max_uih_len + contextual_seq_len + max_targets = 20 + num_targets = torch.randint(1, max_targets, size=(batch_size,), device=device) + if has_multiple_targets: + x_lengths = x_lengths + num_targets + max_seq_len = max_seq_len + max_targets + x_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(x_lengths) + total_seq_len = int(x_offsets[-1].cpu().item()) + x = torch.randn( + int(total_seq_len), + embedding_dim, + device=device, + dtype=dtype, + ).requires_grad_(True) + stu_output = stu( + x=x, + x_lengths=x_lengths, + x_offsets=x_offsets, + max_seq_len=max_seq_len, + num_targets=num_targets, + ) + dout = torch.randn_like(stu_output) + stu_output.backward(dout) + self.assertTrue(stu_output.shape == x.shape) + + # pyre-ignore + @given( + causal=st.sampled_from([True]), + num_layers=st.sampled_from([2]), + num_heads=st.sampled_from([2]), + max_uih_len=st.sampled_from([300]), + batch_size=st.sampled_from([8]), + embedding_dim=st.sampled_from([16]), + attention_dim=st.sampled_from([32]), + linear_hidden_dim=st.sampled_from([64]), + has_multiple_targets=st.sampled_from([True, False]), + contextual_seq_len=st.sampled_from([0, 10]), + dropout_ratio=st.sampled_from([0.0, 0.3, 1.0]), + dtype=st.sampled_from( + [torch.float32, torch.bfloat16] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + ) + @unittest.skipIf(*gpu_unavailable) + @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) + def test_sd_stu( + self, + causal: bool, + num_layers: int, + num_heads: int, + max_uih_len: int, + batch_size: int, + embedding_dim: int, + attention_dim: int, + linear_hidden_dim: int, + has_multiple_targets: bool, + contextual_seq_len: int, + dropout_ratio: float, + dtype: torch.dtype, + ) -> None: + set_dev_mode(True) + device = torch.device("cuda") + stu_layers: List[STU] = [ + STULayer( + config=STULayerConfig( + embedding_dim=embedding_dim, + num_heads=num_heads, + hidden_dim=linear_hidden_dim, + attention_dim=attention_dim, + output_dropout_ratio=0.0, + causal=causal, + target_aware=has_multiple_targets, + contextual_seq_len=contextual_seq_len, + ), + is_inference=False, + ) + for _ in range(num_layers) + ] + stu = STUStack( + stu_list=stu_layers, + is_inference=False, + ).to(device) + sd_stu = SDSTU( + stu=copy.deepcopy(stu), + dropout_ratio=dropout_ratio, + is_inference=False, + ).to(device) + stu.recursive_setattr("_hammer_kernel", HammerKernel.PYTORCH) + sd_stu.recursive_setattr("_hammer_kernel", HammerKernel.PYTORCH) + x_lengths = torch.randint(max_uih_len + 1, (batch_size,), device=device) + x_lengths = x_lengths + contextual_seq_len + max_seq_len = max_uih_len + contextual_seq_len + max_targets = 20 + num_targets = torch.randint(1, max_targets, size=(batch_size,), device=device) + if has_multiple_targets: + x_lengths = x_lengths + num_targets + max_seq_len = max_seq_len + max_targets + x_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(x_lengths) + total_seq_len = int(x_offsets[-1].cpu().item()) + x = torch.randn( + int(total_seq_len), + embedding_dim, + device=device, + dtype=dtype, + ).requires_grad_(True) + stu_output = stu( + x=x, + x_lengths=x_lengths, + x_offsets=x_offsets, + max_seq_len=max_seq_len, + num_targets=num_targets, + ) + dout = torch.randn_like(stu_output) + stu_output.backward(dout) + assert x.grad is not None + d_x, x.grad = x.grad.detach().clone(), None + x = x.detach().clone().requires_grad_(True) + sd_stu_output = sd_stu( + x=x, + x_lengths=x_lengths, + x_offsets=x_offsets, + max_seq_len=max_seq_len, + num_targets=num_targets, + ) + dout = dout.detach().clone() + sd_stu_output.backward(dout) + d_sd_x = x.grad.detach().clone() + + self.assertTrue(sd_stu_output.shape == x.shape) + if dropout_ratio == 0.0: + torch.testing.assert_close(stu_output, sd_stu_output) + torch.testing.assert_close(d_x, d_sd_x) + if dropout_ratio == 1.0: + torch.testing.assert_close(x, sd_stu_output) diff --git a/recommendation_v4/generative_recommenders/modules/tests/multitask_module_test.py b/recommendation_v4/generative_recommenders/modules/tests/multitask_module_test.py new file mode 100644 index 000000000..66f2db185 --- /dev/null +++ b/recommendation_v4/generative_recommenders/modules/tests/multitask_module_test.py @@ -0,0 +1,233 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import unittest +from typing import Dict, List, Tuple + +import torch +from generative_recommenders.common import gpu_unavailable, set_dev_mode +from generative_recommenders.modules.multitask_module import ( + DefaultMultitaskModule, + MultitaskTaskType, + TaskConfig, +) +from generative_recommenders.ops.layer_norm import SwishLayerNorm +from hypothesis import given, settings, strategies as st, Verbosity + + +_task_configs: List[List[TaskConfig]] = [ + [ + TaskConfig( + task_name="is_click", + task_weight=1, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ), + ], + [ + TaskConfig( + task_name="vvp", + task_weight=2, + task_type=MultitaskTaskType.REGRESSION, + ), + ], + [ + TaskConfig( + task_name="is_click", + task_weight=1, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ), + TaskConfig( + task_name="is_like", + task_weight=2, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ), + TaskConfig( + task_name="is_follow", + task_weight=4, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ), + ], + [ + TaskConfig( + task_name="rating", + task_weight=1, + task_type=MultitaskTaskType.REGRESSION, + ), + TaskConfig( + task_name="vvp", + task_weight=2, + task_type=MultitaskTaskType.REGRESSION, + ), + ], + [ + TaskConfig( + task_name="type_1", + task_weight=2, + task_type=MultitaskTaskType.REGRESSION, + ), + ], + [ + TaskConfig( + task_name="is_click", + task_weight=1, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ), + TaskConfig( + task_name="is_like", + task_weight=2, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ), + TaskConfig( + task_name="is_follow", + task_weight=4, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ), + TaskConfig( + task_name="rating", + task_weight=1, + task_type=MultitaskTaskType.REGRESSION, + ), + TaskConfig( + task_name="vvp", + task_weight=2, + task_type=MultitaskTaskType.REGRESSION, + ), + ], + [ + TaskConfig( + task_name="is_click", + task_weight=1, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ), + TaskConfig( + task_name="is_like", + task_weight=2, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ), + TaskConfig( + task_name="is_follow", + task_weight=4, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ), + TaskConfig( + task_name="rating", + task_weight=1, + task_type=MultitaskTaskType.REGRESSION, + ), + TaskConfig( + task_name="vvp", + task_weight=2, + task_type=MultitaskTaskType.REGRESSION, + ), + ], +] + + +def _get_random_supervision_labels_and_weights( + num_examples: int, + task_configs: List[TaskConfig], + device: torch.device, +) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + supervision_labels: Dict[str, torch.Tensor] = {} + supervision_weights: Dict[str, torch.Tensor] = {} + for task in task_configs: + if task.task_type == MultitaskTaskType.REGRESSION: + supervision_labels[task.task_name] = torch.randn( + num_examples, device=device + ).to(torch.float32) + else: + supervision_labels[task.task_name] = torch.randint( + 0, + 11, + (num_examples,), + device=device, + ).to(torch.float32) + + return supervision_labels, supervision_weights + + +class MultiTaskModuleTest(unittest.TestCase): + # pyre-ignore + @given( + task_config_idx=st.sampled_from(range(len(_task_configs))), + training=st.booleans(), + is_inference=st.booleans(), + dtype=st.sampled_from( + [torch.float32, torch.bfloat16] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + ) + @unittest.skipIf(*gpu_unavailable) + @settings(verbosity=Verbosity.verbose, max_examples=50, deadline=None) + def test_default_multitask_module( + self, + task_config_idx: int, + training: bool, + is_inference: bool, + dtype: torch.dtype, + ) -> None: + set_dev_mode(True) + device = torch.device("cuda") + + L = 200 + embedding_dim = 64 + causal_multitask_weights = 0.3 + + task_configs: List[TaskConfig] = _task_configs[task_config_idx] + task_configs.sort(key=lambda x: x.task_type) + multitask_module = DefaultMultitaskModule( + task_configs=task_configs, + embedding_dim=embedding_dim, + prediction_fn=lambda in_dim, num_tasks: torch.nn.Sequential( + torch.nn.Linear(in_features=in_dim, out_features=512), + SwishLayerNorm(512), + torch.nn.Linear(in_features=512, out_features=num_tasks), + ), + causal_multitask_weights=causal_multitask_weights, + is_inference=is_inference, + ).to(device) + multitask_module.set_training_dtype(dtype) + supervision_labels, supervision_weights = ( + _get_random_supervision_labels_and_weights( + num_examples=L, + task_configs=task_configs, + device=device, + ) + ) + encoded_user_embeddings = torch.rand(L, embedding_dim, device=device) + item_embeddings = torch.rand(L, embedding_dim, device=device) + + ( + mt_preds, + mt_labels, + mt_weights, + mt_losses, + ) = multitask_module( + encoded_user_embeddings=encoded_user_embeddings, + item_embeddings=item_embeddings, + supervision_labels=supervision_labels, + supervision_weights=supervision_weights, + ) + + self.assertEqual(mt_preds.size(), (len(task_configs), L)) + if not is_inference: + self.assertEqual(mt_labels.size(), (len(task_configs), L)) + self.assertEqual(mt_weights.size(), (len(task_configs), L)) + if training: + self.assertEqual(mt_losses.size(), (len(task_configs),)) diff --git a/recommendation_v4/generative_recommenders/modules/tests/stu_test.py b/recommendation_v4/generative_recommenders/modules/tests/stu_test.py new file mode 100644 index 000000000..f6440e55e --- /dev/null +++ b/recommendation_v4/generative_recommenders/modules/tests/stu_test.py @@ -0,0 +1,453 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import copy +import unittest +from typing import List + +import torch +from generative_recommenders.common import gpu_unavailable, HammerKernel, set_dev_mode +from generative_recommenders.modules.stu import STU, STULayer, STULayerConfig, STUStack +from generative_recommenders.ops.jagged_tensors import split_2D_jagged +from hypothesis import given, settings, strategies as st, Verbosity + + +def _inplace_swap( + batch_size: int, + x: torch.Tensor, + swap_from: torch.Tensor, + swap_to: torch.Tensor, +) -> torch.Tensor: + for i in range(batch_size): + tmp = x[i, swap_from[i], :].detach().clone() + x[i, swap_from[i], :] = x[i, swap_to[i], :] + x[i, swap_to[i], :] = tmp + return x + + +class StuTest(unittest.TestCase): + # pyre-ignore + @given( + causal=st.sampled_from([True]), + num_layers=st.sampled_from([2]), + num_heads=st.sampled_from([1, 2]), + max_uih_len=st.sampled_from([20, 64]), + batch_size=st.sampled_from([8]), + embedding_dim=st.sampled_from([16]), + attention_dim=st.sampled_from([32]), + linear_hidden_dim=st.sampled_from([64]), + has_multiple_targets=st.sampled_from([True, False]), + contextual_seq_len=st.sampled_from([0, 10]), + use_group_norm=st.sampled_from([True, False]), + recompute_uvqk_in_backward=st.sampled_from([True, False]), + recompute_normed_x_in_backward=st.sampled_from([True, False]), + recompute_y_in_backward=st.sampled_from([True, False]), + empty_inputs=st.sampled_from([False]), + dtype=st.sampled_from( + [torch.float32] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + ) + @unittest.skipIf(*gpu_unavailable) + @settings(verbosity=Verbosity.verbose, max_examples=100, deadline=None) + def test_triton( + self, + causal: bool, + num_layers: int, + num_heads: int, + max_uih_len: int, + batch_size: int, + embedding_dim: int, + attention_dim: int, + linear_hidden_dim: int, + has_multiple_targets: bool, + contextual_seq_len: int, + use_group_norm: bool, + recompute_uvqk_in_backward: bool, + recompute_normed_x_in_backward: bool, + recompute_y_in_backward: bool, + empty_inputs: bool, # test the case where all the seqlen in the batch are 0 + dtype: torch.dtype, + ) -> None: + set_dev_mode(True) + device = torch.device("cuda") + + stu_layers: List[STU] = [ + STULayer( + config=STULayerConfig( + embedding_dim=embedding_dim, + num_heads=num_heads, + hidden_dim=linear_hidden_dim, + attention_dim=attention_dim, + output_dropout_ratio=0.0, + causal=causal, + target_aware=has_multiple_targets, + max_attn_len=None, + attn_alpha=None, + use_group_norm=use_group_norm, + recompute_normed_x=recompute_normed_x_in_backward, + recompute_uvqk=recompute_uvqk_in_backward, + recompute_y=recompute_y_in_backward, + sort_by_length=True, + contextual_seq_len=contextual_seq_len, + ), + is_inference=False, + ) + for _ in range(num_layers) + ] + stu = STUStack( + stu_list=stu_layers, + is_inference=False, + ).to(device) + stu.recursive_setattr("_hammer_kernel", HammerKernel.PYTORCH) + stu_triton = copy.deepcopy(stu) + stu_triton.recursive_setattr("_hammer_kernel", HammerKernel.TRITON) + + if empty_inputs: + x_lengths = torch.zeros(batch_size, dtype=torch.int32, device=device) + num_targets = torch.zeros(batch_size, dtype=torch.int32, device=device) + contextual_seq_len = 0 + max_seq_len = 16 + else: + x_lengths = torch.randint(max_uih_len + 1, (batch_size,), device=device) + x_lengths = x_lengths + contextual_seq_len + max_seq_len = max_uih_len + contextual_seq_len + max_targets = 20 + num_targets = torch.randint( + 1, max_targets, size=(batch_size,), device=device + ) + if has_multiple_targets: + x_lengths = x_lengths + num_targets + max_seq_len = max_seq_len + max_targets + x_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(x_lengths) + total_seq_len = int(x_offsets[-1].cpu().item()) + x = torch.randn( + int(total_seq_len), + embedding_dim, + device=device, + dtype=dtype, + ).requires_grad_(True) + x_triton = x.clone().detach().requires_grad_() + stu_output = stu( + x=x, + x_lengths=x_lengths, + x_offsets=x_offsets, + max_seq_len=max_seq_len, + num_targets=num_targets, + ) + stu_triton_output = stu_triton( + x=x_triton, + x_lengths=x_lengths, + x_offsets=x_offsets, + max_seq_len=max_seq_len, + num_targets=num_targets, + ) + atol = 5e-3 if dtype == torch.bfloat16 else None + rtol = 1e-2 if dtype == torch.bfloat16 else None + torch.testing.assert_close(stu_triton_output, stu_output, atol=atol, rtol=rtol) + dout = torch.randn_like(stu_output) + stu_output.backward(dout) + dout = dout.detach().clone() + stu_triton_output.backward(dout) + torch.testing.assert_close(x.grad, x_triton.grad, atol=atol, rtol=rtol) + + # pyre-ignore + @given( + dtype=st.sampled_from( + [torch.float32] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + ) + @unittest.skipIf(*gpu_unavailable) + @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + def test_target_invariance( + self, + dtype: torch.dtype, + ) -> None: + set_dev_mode(True) + device = torch.device("cuda") + num_layers = 2 + num_heads = 2 + max_seq_len = 32 + batch_size = 8 + embedding_dim = 16 + attention_dim = 32 + linear_hidden_dim = 32 + causal = True + use_group_norm = False + recompute_normed_x_in_backward = False + recompute_uvqk_in_backward = False + recompute_y_in_backward = False + max_attn_len = None + stu_layers: List[STU] = [ + STULayer( + config=STULayerConfig( + embedding_dim=embedding_dim, + num_heads=num_heads, + hidden_dim=linear_hidden_dim, + attention_dim=attention_dim, + output_dropout_ratio=0.0, + causal=causal, + target_aware=True, + max_attn_len=max_attn_len, + attn_alpha=None, + use_group_norm=use_group_norm, + recompute_normed_x=recompute_normed_x_in_backward, + recompute_uvqk=recompute_uvqk_in_backward, + recompute_y=recompute_y_in_backward, + sort_by_length=True, + contextual_seq_len=0, + ), + is_inference=False, + ) + for _ in range(num_layers) + ] + stu = STUStack( + stu_list=stu_layers, + is_inference=False, + ).to(device) + + x_lengths = torch.randint( + low=2, high=max_seq_len + 1, size=(batch_size,), device=device + ) + num_targets = torch.randint(low=2, high=10, size=(batch_size,), device=device) + x_lengths = x_lengths + num_targets + x_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(x_lengths) + total_seq_len = int(x_offsets[-1].cpu()) + + swap_from = torch.remainder( + torch.randint(20, (batch_size,), device=device), num_targets + ) + swap_to = torch.remainder( + torch.randint(20, (batch_size,), device=device), num_targets + ) + swap_from = x_lengths - 1 - swap_from + swap_to = x_lengths - 1 - swap_to + max_seq_len = int(x_lengths.max().item()) + + # forward() + x = torch.randn( + int(total_seq_len), + embedding_dim, + device=device, + dtype=dtype, + ).requires_grad_(True) + stu_output = stu( + x=x, + x_lengths=x_lengths, + x_offsets=x_offsets, + max_seq_len=max_seq_len, + num_targets=num_targets, + ) + stu_output_dense = torch.ops.fbgemm.jagged_to_padded_dense( + values=stu_output, + offsets=[x_offsets], + max_lengths=[max_seq_len], + padding_value=0.0, + ) + + # swapped forward(). + dense_x = torch.ops.fbgemm.jagged_to_padded_dense( + x.detach(), + [x_offsets], + [max_seq_len], + ) + swapped_dense_x = _inplace_swap(batch_size, dense_x, swap_from, swap_to) + swapped_x = torch.ops.fbgemm.dense_to_jagged( + swapped_dense_x, + [x_offsets], + )[0].requires_grad_(True) + swapped_stu_output = stu( + x=swapped_x, + x_lengths=x_lengths, + x_offsets=x_offsets, + max_seq_len=max_seq_len, + num_targets=num_targets, + ) + swapped_stu_output_dense = torch.ops.fbgemm.jagged_to_padded_dense( + values=swapped_stu_output, + offsets=[x_offsets], + max_lengths=[max_seq_len], + padding_value=0.0, + ) + + # backward + dout = torch.randn_like(stu_output_dense) + stu_output_dense.backward(dout) + dout = dout.detach().clone() + swapped_stu_output_dense.backward( + _inplace_swap(batch_size, dout, swap_from, swap_to) + ) + + swapped_swapped_stu_output_dense = _inplace_swap( + batch_size, swapped_stu_output_dense, swap_from, swap_to + ) + torch.testing.assert_close(stu_output_dense, swapped_swapped_stu_output_dense) + + # backward + torch.testing.assert_close( + torch.ops.fbgemm.jagged_to_padded_dense( + swapped_x.grad, + [x_offsets], + [max_seq_len], + ), + _inplace_swap( + batch_size, + torch.ops.fbgemm.jagged_to_padded_dense( + x.grad, + [x_offsets], + [max_seq_len], + ), + swap_from, + swap_to, + ), + ) + + # pyre-ignore[56] + @given( + num_layers=st.sampled_from([1, 2, 4]), + num_heads=st.sampled_from([1, 4]), + max_uih_len=st.sampled_from([20, 128]), + batch_size=st.sampled_from([4, 8]), + embedding_dim=st.sampled_from([32]), + attention_dim=st.sampled_from([16]), + linear_hidden_dim=st.sampled_from([64]), + contextual_seq_len=st.sampled_from([0, 10]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) + @unittest.skipIf(*gpu_unavailable) + @torch.inference_mode() + def test_cached_forward( + self, + num_layers: int, + num_heads: int, + max_uih_len: int, + batch_size: int, + embedding_dim: int, + attention_dim: int, + linear_hidden_dim: int, + contextual_seq_len: int, + ) -> None: + set_dev_mode(True) + device = torch.device("cuda") + + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + + use_group_norm = False + recompute_normed_x_in_backward = False + recompute_uvqk_in_backward = False + recompute_y_in_backward = False + max_attn_len = None + stu_layers: List[STU] = [ + STULayer( + config=STULayerConfig( + embedding_dim=embedding_dim, + num_heads=num_heads, + hidden_dim=linear_hidden_dim, + attention_dim=attention_dim, + output_dropout_ratio=0.0, + causal=True, + target_aware=True, + max_attn_len=max_attn_len, + attn_alpha=None, + use_group_norm=use_group_norm, + recompute_normed_x=recompute_normed_x_in_backward, + recompute_uvqk=recompute_uvqk_in_backward, + recompute_y=recompute_y_in_backward, + sort_by_length=True, + contextual_seq_len=contextual_seq_len, + ), + is_inference=True, + ) + for _ in range(num_layers) + ] + stu = STUStack( + stu_list=stu_layers, + is_inference=True, + ).to(device) + stu.recursive_setattr("_hammer_kernel", HammerKernel.TRITON) + stu.eval() + + x_lengths = torch.randint( + max_uih_len, max_uih_len + 1, (batch_size,), device=device + ) + x_lengths = x_lengths + contextual_seq_len + max_seq_len = max_uih_len + contextual_seq_len + delta_size = 20 + max_targets = delta_size * 2 + num_targets = torch.randint( + delta_size, max_targets + 1, size=(batch_size,), device=device + ) + x_lengths = x_lengths + num_targets + contextual_seq_len + max_seq_len = max_seq_len + max_targets + contextual_seq_len + x_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(x_lengths) + total_seq_len = int(x_offsets[-1].cpu().item()) + x = torch.randn( + int(total_seq_len), + embedding_dim, + device=device, + ).requires_grad_(True) + + # default forward(). + ref_y = stu( + x=x, + x_lengths=x_lengths, + x_offsets=x_offsets, + max_seq_len=max_seq_len, + num_targets=num_targets, + ) + prime_lengths = x_lengths - delta_size + prime_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(prime_lengths) + _, ref_delta_y = split_2D_jagged( + max_seq_len=max_seq_len, + values=ref_y, + max_len_left=None, + max_len_right=delta_size, + offsets_left=prime_offsets, + offsets_right=None, + kernel=HammerKernel.TRITON, + ) + + # cached forward(). + prime_x, delta_x = split_2D_jagged( + max_seq_len=max_seq_len, + values=x, + max_len_left=None, + max_len_right=delta_size, + offsets_left=prime_offsets, + offsets_right=None, + kernel=HammerKernel.TRITON, + ) + _ = stu( + x=prime_x, + x_lengths=prime_lengths, + x_offsets=prime_offsets, + max_seq_len=max_seq_len, + num_targets=num_targets - delta_size, + max_kv_caching_len=max_seq_len - delta_size, + kv_caching_lengths=x_lengths - delta_size, + ) + delta_y = stu.cached_forward( + delta_x=delta_x, + num_targets=num_targets, + ) + + torch.testing.assert_close(ref_delta_y, delta_y) diff --git a/recommendation_v4/generative_recommenders/ops/benchmarks/addmm_bench.py b/recommendation_v4/generative_recommenders/ops/benchmarks/addmm_bench.py new file mode 100644 index 000000000..b1be3a803 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/benchmarks/addmm_bench.py @@ -0,0 +1,174 @@ +# pyre-unsafe +import time +from typing import List, Optional, Tuple + +import click +import pandas as pd +import torch + +# @manual=//triton:triton +import triton +from generative_recommenders.common import HammerKernel +from generative_recommenders.ops.triton.triton_addmm import ( + triton_addmm_fwd, + triton_addmm_fwd_tma_persistent, + triton_addmm_fwd_tma_ws_persistent_tlx, + triton_addmm_fwd_tma_ws_tlx, +) +from generative_recommenders.ops.utils import is_sm100 + +try: + # @manual=//triton:triton + import triton.language.extra.tlx as tlx # type: ignore + + HAS_TLX = True +except ImportError: + tlx = None + HAS_TLX = False + + +def get_kernel(provider: str) -> HammerKernel: + if provider == "triton": + return HammerKernel.TRITON + elif provider == "pytorch": + return HammerKernel.PYTORCH + else: + raise ValueError(f"Unknown provider {provider}") + + +def get_dtype(dtype: str) -> torch.dtype: + if dtype == "bfloat16": + return torch.bfloat16 + elif dtype == "float32": + return torch.float32 + elif dtype == "float16": + return torch.float16 + else: + raise ValueError(f"Not supported dtype {dtype}") + + +@click.command() +@click.option("--m", type=int, default=0) +@click.option("--k", type=int, default=4096) +@click.option("--n", type=int, default=4096) +@click.option("--dtype", type=str, default="bfloat16") +@click.option("--return-result", type=bool, default=False) +@click.option("--broadcast-y", type=bool, is_flag=True, default=False) +def main( + m: int, + k: int, + n: int, + dtype: str, + return_result: bool, + broadcast_y: bool, +) -> Optional[Tuple[List[triton.testing.Benchmark], List[pd.DataFrame]]]: + if m == 0: + batch_sizes = [64, 128, 256, 512] + else: + batch_sizes = [m] + line_vals = [ + "pytorch", + "triton", + "triton_tma_persistent", + "triton_tma_persistent_ws", + ] + line_names = [ + "PyTorch", + "Triton", + "Triton TMA Persistent", + "Triton TMA Persistent WS", + ] + styles = [ + ("red", "-"), + ("green", "-"), + ("orange", "-"), + ("purple", "-"), + ] + if is_sm100() and HAS_TLX: # tmem is only supported on Blackwell + line_vals.append("triton_tma_ws_tlx") + line_names.append("Triton TMA WS TLX") + styles.append(("cyan", "-")) + line_vals.append("triton_tma_persistent_ws_tlx") + line_names.append("Triton TMA Persistent WS TLX") + styles.append(("magenta", "-")) + configs: List[triton.testing.Benchmark] = [ + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=batch_sizes, + line_arg="provider", + line_vals=line_vals, + line_names=line_names, + styles=styles, + ylabel="ms", + plot_name=f"addmm-K-{k}-N-{n}-mode-{mode}-dtype-{dtype}-broadcast_y-{broadcast_y}", + args={ + "K": k, + "N": n, + "dtype": dtype, + "broadcast_y": broadcast_y, + }, + ) + for mode in ["fwd"] + ] + + @triton.testing.perf_report(configs) + def bench_addmm( + batch_size: int, + K: int, + N: int, + dtype: str, + provider: str, + broadcast_y: bool, + ) -> float: + warmup = 20 + rep = 2000 + x = torch.randn( + (batch_size, K), dtype=get_dtype(dtype), device=torch.device("cuda") + ).requires_grad_(True) + weight = torch.randn( + (N, K), dtype=get_dtype(dtype), device=torch.device("cuda") + ).requires_grad_(True) + if broadcast_y: + y = torch.randn( + (N), dtype=get_dtype(dtype), device=torch.device("cuda") + ).requires_grad_(True) + else: + y = torch.randn( + (batch_size, N), dtype=get_dtype(dtype), device=torch.device("cuda") + ).requires_grad_(True) + + # Make sure tensors are contiguous for TMA kernels + weight_t_contiguous = weight.T.contiguous() + + if provider == "pytorch": + fn = lambda: torch.addmm(y, x, weight.T) # noqa E731 + elif provider == "triton_tma_persistent": + fn = lambda: triton_addmm_fwd_tma_persistent( + x, weight_t_contiguous, y, warp_specialize=False + ) # noqa E731 + elif provider == "triton_tma_persistent_ws": + fn = lambda: triton_addmm_fwd_tma_persistent( + x, weight_t_contiguous, y, warp_specialize=True + ) # noqa E731 + elif provider == "triton_tma_persistent_ws_tlx": + fn = lambda: triton_addmm_fwd_tma_ws_persistent_tlx( + x, weight_t_contiguous, y + ) # noqa E731 + elif provider == "triton_tma_ws_tlx": + fn = lambda: triton_addmm_fwd_tma_ws_tlx(x, weight_t_contiguous, y) # noqa E731 + elif provider == "triton": + fn = lambda: triton_addmm_fwd(x, weight_t_contiguous, y) # noqa E731 + else: + raise ValueError(f"Unknown provider: {provider}") + time.sleep(2) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + + df = bench_addmm.run(print_data=True, return_df=return_result) + + if return_result: + return configs, df + + +if __name__ == "__main__": + main() diff --git a/recommendation_v4/generative_recommenders/ops/benchmarks/jagged_dense_bmm_bench.py b/recommendation_v4/generative_recommenders/ops/benchmarks/jagged_dense_bmm_bench.py new file mode 100644 index 000000000..dcfb9819e --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/benchmarks/jagged_dense_bmm_bench.py @@ -0,0 +1,199 @@ +# pyre-strict +import math +from typing import List, Optional, Tuple + +import click +import pandas as pd +import torch + +# @manual=//triton:triton +import triton +from generative_recommenders.common import HammerKernel +from generative_recommenders.ops.triton.triton_jagged import triton_jagged_dense_bmm + +# buck2 run @mode/{opt,inplace} //generative_recommenders/ops/benchmarks:jagged_dense_bmm_bench -- --fwd-only + + +def jagged_dense_bmm( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> torch.Tensor: + """ + Wrapper function for jagged_dense_bmm with kernel selection. + Computing out = jagged x dense + jagged has shape (sum_B(M_i), K), dense has shape (B, K, N) + out has shape (sum_B(M_i), N) + """ + if kernel == HammerKernel.TRITON: + return triton_jagged_dense_bmm( + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + jagged=jagged, + dense=dense, + ) + elif kernel == HammerKernel.PYTORCH: + # PyTorch implementation - manual implementation using standard operations + B, K, N = dense.shape + outputs = [] + for i in range(B): + start_idx = seq_offsets[i] + end_idx = seq_offsets[i + 1] + if start_idx < end_idx: + jagged_seq = jagged[start_idx:end_idx] # (seq_len, K) + dense_seq = dense[i] # (K, N) + output_seq = torch.mm(jagged_seq, dense_seq) # (seq_len, N) + outputs.append(output_seq) + return ( + torch.cat(outputs, dim=0) + if outputs + else torch.empty(0, N, device=jagged.device, dtype=jagged.dtype) + ) + else: + raise ValueError(f"Unsupported kernel: {kernel}") + + +@click.command() +@click.option( + "--batch-size", + type=int, + default=512, +) +@click.option( + "--max-seq-len", + type=int, + default=8192, + show_default=True, +) +@click.option( + "-d", + type=int, + default=64, + show_default=True, +) +@click.option( + "-k", + type=int, + default=64, + show_default=True, +) +@click.option("--dtype", type=str, default="bf16") +@click.option("--fwd-only", is_flag=True) +@click.option("--return-result", type=bool, default=False) +def main( + batch_size: int, + max_seq_len: int, + d: int, + k: int, + dtype: str, + fwd_only: bool, + return_result: bool, +) -> Optional[Tuple[List[triton.testing.Benchmark], List[pd.DataFrame]]]: + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + max_seq_len_log2 = int(round(math.log2(max_seq_len))) + if dtype == "fp32": + pt_dtype = torch.float32 + elif dtype == "fp16": + pt_dtype = torch.float16 + elif dtype == "bf16": + pt_dtype = torch.bfloat16 + else: + raise ValueError(f"Unsupported data type: {dtype}.") + + configs: List[triton.testing.Benchmark] = [ + triton.testing.Benchmark( + x_names=["seq_len"], + x_vals=[2**i for i in range(5, max_seq_len_log2 + 1)], + line_arg="provider", + line_vals=["triton", "pytorch"], + line_names=["Triton", "Pytorch"], + styles=[("red", "-"), ("blue", "-")], + ylabel="ms", + plot_name=f"jagged_dense_bmm-b{batch_size}-D{d}-K{k}-{dtype}", + args={ + "batch_size": batch_size, + "D": d, + "K": k, + "dtype": pt_dtype, + "mode": mode, + }, + ) + for mode in (["fwd"] if fwd_only else ["fwd", "bwd"]) + ] + + @triton.testing.perf_report(configs) + def bench_jagged_dense_bmm( + batch_size: int, + seq_len: int, + D: int, + K: int, + mode: str, + provider: str, + dtype: torch.dtype, + ) -> float: + assert mode in ["fwd", "bwd"] + warmup = 25 + rep = 100 + + max_seq_len = seq_len + lengths = torch.randint( + max_seq_len + 1, size=(batch_size,), device=torch.device("cuda") + ) + seq_offsets = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + seq_offsets[1:] = torch.cumsum(lengths, dim=0) + jagged_size = int(seq_offsets[-1].item()) + jagged = ( + torch.empty((jagged_size, D), dtype=dtype, device=torch.device("cuda")) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + dense = ( + torch.empty((batch_size, D, K), dtype=dtype, device=torch.device("cuda")) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + + if provider == "triton": + fn = lambda: jagged_dense_bmm( # noqa E731 + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + jagged=jagged, + dense=dense, + kernel=HammerKernel.TRITON, + ) + if mode == "bwd": + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) # noqa E731 + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + elif provider == "pytorch": + fn = lambda: jagged_dense_bmm( # noqa E731 + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + jagged=jagged, + dense=dense, + kernel=HammerKernel.PYTORCH, + ) + if mode == "bwd": + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) # noqa E731 + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + else: + raise ValueError(f"unsupported provider: {provider}") + + df = bench_jagged_dense_bmm.run(print_data=True, return_df=return_result) + + if return_result: + return configs, df + + +if __name__ == "__main__": + main() diff --git a/recommendation_v4/generative_recommenders/ops/benchmarks/jagged_dense_bmm_broadcast_add_bench.py b/recommendation_v4/generative_recommenders/ops/benchmarks/jagged_dense_bmm_broadcast_add_bench.py new file mode 100644 index 000000000..193704bf9 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/benchmarks/jagged_dense_bmm_broadcast_add_bench.py @@ -0,0 +1,270 @@ +# pyre-strict +import math +import pickle +from typing import List + +import click +import torch + +# @manual=//triton:triton +import triton +from generative_recommenders.common import HammerKernel +from generative_recommenders.ops.jagged_tensors import jagged_dense_bmm_broadcast_add +from generative_recommenders.ops.triton.triton_jagged import ( + jagged_dense_bmm_broadcast_add_kernel, + triton_jagged_dense_bmm, + triton_jagged_dense_broadcast_add, +) + +# buck2 run @mode/{opt,inplace} //generative_recommenders/ops/benchmarks:jagged_dense_bmm_broadcast_add_bench -- --fwd-only + +# To dump the jagged_dense_bmm_broadcast_add_kernel cache +# buck2 run @mode/opt //generative_recommenders/ops/benchmarks:jagged_dense_bmm_broadcast_add_bench -- --fwd-only --dump-cache-dir=/home/${USER}/fbsource/fbcode/generative_recommenders/ops/triton/jagged_dense_bmm_broadcast_add_kernel_cache.pkl + + +def get_kernel(provider: str) -> HammerKernel: + if provider == "triton": + return HammerKernel.TRITON + elif provider == "pytorch": + return HammerKernel.PYTORCH + else: + raise ValueError(f"Unknown provider {provider}") + + +def jagged_dense_broadcast_add( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> torch.Tensor: + """ + Wrapper function for jagged_dense_broadcast_add with kernel selection. + Computing out = jagged + dense (broadcasted) + jagged has shape (sum_B(M_i), N), dense has shape (B, N) + out has shape (sum_B(M_i), N) + """ + if kernel == HammerKernel.TRITON: + return triton_jagged_dense_broadcast_add( + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + jagged=jagged, + dense=dense, + ) + elif kernel == HammerKernel.PYTORCH: + # PyTorch implementation - manual implementation using standard operations + B, N = dense.shape + outputs = [] + for i in range(B): + start_idx = seq_offsets[i] + end_idx = seq_offsets[i + 1] + if start_idx < end_idx: + jagged_seq = jagged[start_idx:end_idx] # (seq_len, N) + dense_seq = dense[i] # (N,) + output_seq = jagged_seq + dense_seq # (seq_len, N) + outputs.append(output_seq) + return ( + torch.cat(outputs, dim=0) + if outputs + else torch.empty(0, N, device=jagged.device, dtype=jagged.dtype) + ) + else: + raise ValueError(f"Unsupported kernel: {kernel}") + + +def jagged_dense_bmm( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> torch.Tensor: + """ + Wrapper function for jagged_dense_bmm with kernel selection. + Computing out = jagged x dense + jagged has shape (sum_B(M_i), K), dense has shape (B, K, N) + out has shape (sum_B(M_i), N) + """ + if kernel == HammerKernel.TRITON: + return triton_jagged_dense_bmm( + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + jagged=jagged, + dense=dense, + ) + elif kernel == HammerKernel.PYTORCH: + # PyTorch implementation - manual implementation using standard operations + B, K, N = dense.shape + outputs = [] + for i in range(B): + start_idx = seq_offsets[i] + end_idx = seq_offsets[i + 1] + if start_idx < end_idx: + jagged_seq = jagged[start_idx:end_idx] # (seq_len, K) + dense_seq = dense[i] # (K, N) + output_seq = torch.mm(jagged_seq, dense_seq) # (seq_len, N) + outputs.append(output_seq) + return ( + torch.cat(outputs, dim=0) + if outputs + else torch.empty(0, N, device=jagged.device, dtype=jagged.dtype) + ) + else: + raise ValueError(f"Unsupported kernel: {kernel}") + + +@click.command() +@click.option( + "--batch-size", + type=int, + default=384, +) +@click.option( + "--max-seq-len", + type=int, + default=4096, + show_default=True, +) +@click.option( + "-d", + type=int, + default=512, + show_default=True, +) +@click.option( + "-k", + type=int, + default=512, + show_default=True, +) +@click.option("--dtype", type=str, default="bf16") +@click.option("--fwd-only", is_flag=True) +@click.option("--dump-cache-dir", type=str, default="") +def main( + batch_size: int, + max_seq_len: int, + d: int, + k: int, + dtype: str, + fwd_only: bool, + dump_cache_dir: str, +) -> None: + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + max_seq_len_log2 = int(round(math.log2(max_seq_len))) + if dtype == "fp32": + pt_dtype = torch.float32 + elif dtype == "fp16": + pt_dtype = torch.float16 + elif dtype == "bf16": + pt_dtype = torch.bfloat16 + else: + raise ValueError(f"Unsupported data type: {dtype}.") + + configs: List[triton.testing.Benchmark] = [ + triton.testing.Benchmark( + x_names=["seq_len"], + x_vals=[2**i for i in range(8, max_seq_len_log2 + 1)], + line_arg="provider", + line_vals=["triton", "pytorch", "triton_nonfused"], + line_names=["Triton", "Pytorch", "Triton_Nonfused"], + styles=[("red", "-"), ("blue", "-"), ("green", "-")], + ylabel="ms", + plot_name=f"jagged_dense_bmm_broadcast_add-{mode}-b{batch_size}-D{d}-K{k}-{dtype}", + args={ + "batch_size": batch_size, + "D": d, + "K": k, + "dtype": pt_dtype, + "mode": mode, + }, + ) + for mode in (["fwd"] if fwd_only else ["fwd", "bwd"]) + ] + + @triton.testing.perf_report(configs) + def bench_jagged_dense_bmm_broadcast_add( + batch_size: int, + seq_len: int, + D: int, + K: int, + mode: str, + provider: str, + dtype: torch.dtype, + ) -> float: + assert mode in ["fwd", "bwd"] + warmup = 25 + rep = 100 + + max_seq_len = seq_len + lengths = torch.randint( + max_seq_len + 1, size=(batch_size,), device=torch.device("cuda") + ) + seq_offsets = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + seq_offsets[1:] = torch.cumsum(lengths, dim=0) + jagged_size = int(seq_offsets[-1].item()) + jagged = ( + torch.empty((jagged_size, D), dtype=dtype, device=torch.device("cuda")) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + dense = ( + torch.empty((batch_size, D, K), dtype=dtype, device=torch.device("cuda")) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + bias = ( + torch.empty((batch_size, K), dtype=dtype, device=torch.device("cuda")) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + + if provider in ["triton", "pytorch"]: + fn = lambda: jagged_dense_bmm_broadcast_add( # noqa E731 + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + jagged=jagged, + dense=dense, + bias=bias, + kernel=get_kernel(provider), + ) + if mode == "bwd": + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) # noqa E731 + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + elif provider == "triton_nonfused": + fn = lambda: jagged_dense_broadcast_add( # noqa E731 + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + jagged=jagged_dense_bmm( + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + jagged=jagged, + dense=dense, + kernel=HammerKernel.TRITON, + ), + dense=bias, + kernel=HammerKernel.TRITON, + ) + if mode == "bwd": + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) # noqa E731 + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + else: + raise ValueError(f"unsupported provider: {provider}") + + bench_jagged_dense_bmm_broadcast_add.run(print_data=True) + if dump_cache_dir: + with open(dump_cache_dir, "wb") as data: + # @lint-ignore PYTHONPICKLEISBAD + pickle.dump(jagged_dense_bmm_broadcast_add_kernel.cache, data) + + +if __name__ == "__main__": + main() diff --git a/recommendation_v4/generative_recommenders/ops/benchmarks/jagged_dense_broadcast_add_bench.py b/recommendation_v4/generative_recommenders/ops/benchmarks/jagged_dense_broadcast_add_bench.py new file mode 100644 index 000000000..049258e3d --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/benchmarks/jagged_dense_broadcast_add_bench.py @@ -0,0 +1,205 @@ +# pyre-strict +import math +import pickle +from typing import List, Optional, Tuple + +import click +import pandas as pd +import torch + +# @manual=//triton:triton +import triton +from generative_recommenders.common import HammerKernel +from generative_recommenders.ops.triton.triton_jagged import ( + jagged_dense_broadcast_add_kernel, + triton_jagged_dense_broadcast_add, +) + +# buck2 run @mode/{opt,inplace} //generative_recommenders/ops/benchmarks:jagged_dense_broadcast_add_bench + + +# To dump the jagged_dense_broadcast_add_kernel cache, run: +# buck2 run @mode/{opt,inplace} //generative_recommenders/ops/benchmarks:jagged_dense_broadcast_add_bench -- --dump-ragged-tuner-cache-dir=/home/${USER}/fbsource/fbcode/generative_recommenders/ops/triton/jagged_dense_broadcast_add_kernel_cache.pkl + + +def jagged_dense_broadcast_add( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> torch.Tensor: + """ + Wrapper function for jagged_dense_broadcast_add with kernel selection. + Computing out = jagged + dense (broadcasted) + jagged has shape (sum_B(M_i), N), dense has shape (B, N) + out has shape (sum_B(M_i), N) + """ + if kernel == HammerKernel.TRITON: + return triton_jagged_dense_broadcast_add( + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + jagged=jagged, + dense=dense, + ) + elif kernel == HammerKernel.PYTORCH: + # PyTorch implementation - manual implementation using standard operations + B, N = dense.shape + outputs = [] + for i in range(B): + start_idx = seq_offsets[i] + end_idx = seq_offsets[i + 1] + if start_idx < end_idx: + jagged_seq = jagged[start_idx:end_idx] # (seq_len, N) + dense_seq = dense[i] # (N,) + output_seq = jagged_seq + dense_seq # (seq_len, N) + outputs.append(output_seq) + return ( + torch.cat(outputs, dim=0) + if outputs + else torch.empty(0, N, device=jagged.device, dtype=jagged.dtype) + ) + else: + raise ValueError(f"Unsupported kernel: {kernel}") + + +@click.command() +@click.option( + "--batch-size", + type=int, + default=512, +) +@click.option( + "--max-seq-len", + type=int, + default=8192, + show_default=True, +) +@click.option( + "-d", + type=int, + default=64, + show_default=True, +) +@click.option("--dtype", type=str, default="fp32") +@click.option("--fwd-only", is_flag=True) +@click.option("--dump-ragged-tuner-cache-dir", type=str, default="") +@click.option("--return-result", type=bool, default=False) +def main( + batch_size: int, + max_seq_len: int, + d: int, + dtype: str, + fwd_only: bool, + dump_ragged_tuner_cache_dir: str, + return_result: bool, +) -> Optional[Tuple[List[triton.testing.Benchmark], List[pd.DataFrame]]]: + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + max_seq_len_log2 = int(round(math.log2(max_seq_len))) + if dtype == "fp32": + pt_dtype = torch.float32 + elif dtype == "fp16": + pt_dtype = torch.float16 + elif dtype == "bf16": + pt_dtype = torch.bfloat16 + else: + raise ValueError(f"Unsupported data type: {dtype}.") + + configs: List[triton.testing.Benchmark] = [ + triton.testing.Benchmark( + x_names=["seq_len"], + x_vals=[2**i for i in range(5, max_seq_len_log2 + 1)], + line_arg="provider", + line_vals=["triton", "pytorch"], + line_names=["Triton", "Pytorch"], + styles=[("red", "-"), ("blue", "-")], + ylabel="ms", + plot_name=f"jagged_dense_broadcast_add-b{batch_size}-D{d}-{dtype}-{mode}", + args={ + "batch_size": batch_size, + "D": d, + "dtype": pt_dtype, + "mode": mode, + }, + ) + for mode in (["fwd"] if fwd_only else ["fwd", "fwd+bwd"]) + ] + + @triton.testing.perf_report(configs) + def bench_jagged_dense_broadcast_add( + batch_size: int, + seq_len: int, + D: int, + mode: str, + provider: str, + dtype: torch.dtype, + ) -> float: + assert mode in ["fwd", "bwd", "fwd+bwd"] + warmup = 25 + rep = 100 + + max_seq_len = seq_len + lengths = torch.randint( + max_seq_len + 1, size=(batch_size,), device=torch.device("cuda") + ) + seq_offsets = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + seq_offsets[1:] = torch.cumsum(lengths, dim=0) + jagged_size = int(seq_offsets[-1].item()) + jagged = ( + torch.empty((jagged_size, D), dtype=dtype, device=torch.device("cuda")) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + dense = ( + torch.empty((batch_size, D), dtype=dtype, device=torch.device("cuda")) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + + if provider == "triton": + fn = lambda: jagged_dense_broadcast_add( # noqa E731 + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + jagged=jagged, + dense=dense, + kernel=HammerKernel.TRITON, + ) + if mode == "bwd": + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) # noqa E731 + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + elif provider == "pytorch": + fn = lambda: jagged_dense_broadcast_add( # noqa E731 + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + jagged=jagged, + dense=dense, + kernel=HammerKernel.PYTORCH, + ) + if mode == "bwd": + o = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) # noqa E731 + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + else: + raise ValueError(f"unsupported provider: {provider}") + + df = bench_jagged_dense_broadcast_add.run(print_data=True, return_df=return_result) + + if dump_ragged_tuner_cache_dir: + with open(dump_ragged_tuner_cache_dir, "wb") as data: + # @lint-ignore PYTHONPICKLEISBAD + pickle.dump(jagged_dense_broadcast_add_kernel.cache, data) + + if return_result: + return configs, df + + +if __name__ == "__main__": + main() diff --git a/recommendation_v4/generative_recommenders/ops/hstu_attention.py b/recommendation_v4/generative_recommenders/ops/hstu_attention.py new file mode 100644 index 000000000..137482227 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/hstu_attention.py @@ -0,0 +1,353 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Optional + +import torch +from generative_recommenders.common import HammerKernel, switch_to_contiguous_if_needed +from generative_recommenders.ops.pytorch.pt_hstu_attention import ( + pytorch_cached_hstu_mha, + pytorch_hstu_mha, +) +from generative_recommenders.ops.triton.triton_hstu_attention import ( + triton_cached_hstu_mha, + triton_hstu_mha, +) + +try: + # @manual=//generative_recommenders/ops/triton_aot:triton_ragged_hstu_attention + from generative_recommenders.ops.triton_aot.triton_ragged_hstu_attention import ( # pyre-ignore[21] + aot_triton_kernel_wrapper_cached_hstu_mha, + aot_triton_kernel_wrapper_ragged_hstu_mha, + ) +except ImportError: + + def aot_triton_kernel_wrapper_cached_hstu_mha( + *args: object, + **kwargs: object, + ) -> torch.Tensor: + raise ImportError( + "AOT-T is required for the TRITON_INFERENCE cached_hstu_mha kernel." + ) + + def aot_triton_kernel_wrapper_ragged_hstu_mha( + *args: object, + **kwargs: object, + ) -> torch.Tensor: + raise ImportError( + "AOT-T is required for the TRITON_INFERENCE ragged_hstu_mha kernel." + ) + + +try: + from hammer.ops.triton.cc.hstu_attention.triton_cc_hstu_attention import ( + triton_cc_hstu_mha, + ) + from hammer.v2.ops.triton.template.tlx_bw_hstu_attention import ( + tlx_bw_hstu_mha_wrapper, + ) +except ImportError: + tlx_bw_hstu_mha_wrapper = None + from generative_recommenders.ops.triton.triton_hstu_attention import ( + triton_hstu_mha as triton_cc_hstu_mha, + ) +from torch.fx._symbolic_trace import is_fx_tracing + +torch.fx.wrap("triton_hstu_mha") +torch.fx.wrap("triton_cached_hstu_mha") + + +@torch.fx.wrap +def hstu_mha_cuda( + max_seq_len: int, + alpha: float, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + num_targets: Optional[torch.Tensor] = None, + max_attn_len: int = 0, + contextual_seq_len: int = 0, +) -> torch.Tensor: + """TorchScript-friendly inference forwarder onto ``torch.ops.hstu.hstu_mha``. + + Bypasses the ``HammerKernel`` enum dispatch in :func:`hstu_mha` so the + scripted graph has a single concrete C++ op to call. Mirrors the + inference-only path of + :func:`generative_recommenders.ops.cpp.cuda_hstu_attention.cuda_hstu_mha_inference_wrapper` + with the subset of arguments :class:`STULayer` actually uses. + """ + return torch.ops.hstu.hstu_mha( + max_seq_len, + alpha, + q, + k, + v, + seq_offsets, + True, # causal + num_targets, + None, # attn_scale + max_attn_len, + 0, # min_full_attn_seq_len + contextual_seq_len, + None, # q_descale + None, # k_descale + None, # v_descale + False, # sort_by_length + False, # deterministic + 0, # sm_margin + 0, # max_q_len + None, # seq_offsets_q + 0, # num_softmax_heads + False, # training + None, # max_seq_len_tensor + None, # contextual_seq_len_tensor + None, # max_attn_len_tensor + None, # min_full_attn_seq_len_tensor + 1, # num_groups + ) + + +def hstu_mha( + max_seq_len: int, + alpha: float, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + causal: bool = True, + dropout_pr: float = 0.0, + training: bool = True, + num_targets: Optional[torch.Tensor] = None, + attn_scale: Optional[torch.Tensor] = None, + max_attn_len: int = 0, + contextual_seq_len: int = 0, + min_full_attn_seq_len: int = 0, + sort_by_length: bool = False, + kernel: HammerKernel = HammerKernel.PYTORCH, + enable_tma: bool = False, +) -> torch.Tensor: + _, H, _ = q.shape + if not is_fx_tracing(): + torch._assert(max_seq_len > 0, "max_seq_len must be larger than 0") + torch._assert(q.dim() == 3, "q must be 3-D") + torch._assert(k.shape == q.shape, "k must be the same shape as q") + torch._assert(v.dim() == 3, "v must be 3-D") + torch._assert(v.shape[0] == q.shape[0], "wrong v shape[0]") + torch._assert(v.shape[1] == H, "wrong v shape[1]") + torch._assert(causal, "only support causal attention") + + if kernel in [ + HammerKernel.TRITON, + HammerKernel.TLX, + HammerKernel.TRITON_CC, + HammerKernel.TRITON_INFERENCE, + ]: + if not is_fx_tracing() and kernel == HammerKernel.TRITON: + torch._assert(q.is_cuda, "q must be CUDA tensor") + torch._assert(k.is_cuda, "k must be CUDA tensor") + torch._assert(v.is_cuda, "v must be CUDA tensor") + torch._assert(seq_offsets.is_cuda, "seq_offsets must be CUDA tensor") + torch._assert(dropout_pr < 1e-6, "dropout for triton path not implemented") + torch._assert( + min_full_attn_seq_len == 0, "min_full_attn_seq_len not implemented" + ) + assert attn_scale is None, "attn_scale not implemented" + q = switch_to_contiguous_if_needed(q) + k = switch_to_contiguous_if_needed(k) + v = switch_to_contiguous_if_needed(v) + seq_offsets = seq_offsets.contiguous() + + if kernel == HammerKernel.TRITON: + return triton_hstu_mha( + N=max_seq_len, + alpha=alpha, + q=q, + k=k, + v=v, + seq_offsets=seq_offsets, + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + sort_by_length=sort_by_length, + enable_tma=enable_tma, + ) + elif kernel == HammerKernel.TLX: + if tlx_bw_hstu_mha_wrapper is None: + raise ImportError( + "hammer.v2 is required for the TLX kernel. " + "Falling back to TRITON or PYTORCH kernel instead." + ) + return tlx_bw_hstu_mha_wrapper( + max_seq_len=max_seq_len, + alpha=alpha, + q=q, + k=k, + v=v, + seq_offsets=seq_offsets, + attn_scale=torch.tensor(1.0 / max_seq_len, device=q.device), + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + sort_by_length=sort_by_length, + ) + elif kernel == HammerKernel.TRITON_CC: + return triton_cc_hstu_mha( + N=max_seq_len, + alpha=alpha, + q=q, + k=k, + v=v, + seq_offsets=seq_offsets, + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + ) + elif kernel == HammerKernel.TRITON_INFERENCE: + return aot_triton_kernel_wrapper_ragged_hstu_mha( + N=max_seq_len, + alpha=alpha, + q=q, + k=k, + v=v, + seq_offsets=seq_offsets, + invalid_attn_mask_type="causal", + num_targets=num_targets, + attn_scale=attn_scale, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + full_attn_size=min_full_attn_seq_len, + num_softmax_heads=0, + ) + else: + return pytorch_hstu_mha( + max_seq_len=max_seq_len, + alpha=alpha, + q=q, + k=k, + v=v, + seq_offsets=seq_offsets, + causal=True, + dropout_pr=dropout_pr, + training=training, + num_targets=num_targets, + attn_scale=attn_scale, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + min_full_attn_seq_len=min_full_attn_seq_len, + ) + + +def delta_hstu_mha( + max_seq_len: int, + alpha: float, + delta_q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + num_targets: Optional[torch.Tensor] = None, + max_attn_len: int = 0, + contextual_seq_len: int = 0, + kernel: HammerKernel = HammerKernel.PYTORCH, + enable_tma: bool = False, +) -> torch.Tensor: + L, H, D = delta_q.shape + B = seq_offsets.size(0) - 1 + DeltaSize = L // B + if not is_fx_tracing(): + torch._assert(max_seq_len > 0, "max_seq_len must be larger than 0") + torch._assert(delta_q.dim() == 3, "delta_q must be 3-D") + torch._assert(L % B == 0, "delta_q must be padded") + torch._assert(k.dim() == 3, "k must be 3-D") + torch._assert(k.shape[1] == H, "wrong k shape[1]") + torch._assert(k.shape[2] == D, "wrong k shape[2]") + torch._assert(v.dim() == 3, "v must be 3-D") + torch._assert(v.shape[1] == H, "wrong v shape[1]") + if kernel in [ + HammerKernel.TRITON, + HammerKernel.TRITON_CC, + HammerKernel.TRITON_INFERENCE, + ]: + if not is_fx_tracing() and kernel == HammerKernel.TRITON: + torch._assert(delta_q.is_cuda, "q must be CUDA tensor") + torch._assert(seq_offsets.is_cuda, "seq_offsets must be CUDA tensor") + if num_targets is not None: + torch._assert(num_targets.is_cuda, "num_targets must be CUDA tensor") + seq_offsets = seq_offsets.contiguous() + delta_q = switch_to_contiguous_if_needed(delta_q) + k = switch_to_contiguous_if_needed(k) + v = switch_to_contiguous_if_needed(v) + + if kernel == HammerKernel.TRITON: + return triton_cached_hstu_mha( + N=max_seq_len, + alpha=alpha, + delta_q=delta_q, + k=k, + v=v, + seq_offsets=seq_offsets, + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + enable_tma=enable_tma, + ) + elif kernel == HammerKernel.TRITON_CC: + return triton_cc_hstu_mha( + N=max_seq_len, + alpha=alpha, + q=delta_q, + k=k, + v=v, + seq_offsets=seq_offsets, + num_targets=num_targets, + is_delta_q=True, + delta_size=DeltaSize, + ) + elif kernel == HammerKernel.TRITON_INFERENCE: + delta_x_offsets = torch.arange( + 0, + L + 1, + DeltaSize, + device=delta_q.device, + dtype=seq_offsets.dtype, + ) + return aot_triton_kernel_wrapper_cached_hstu_mha( + N=max_seq_len, + alpha=alpha, + delta_q=delta_q, + k=k, + v=v, + delta_x_offsets=delta_x_offsets, + seq_offsets=seq_offsets, + num_targets=num_targets, + attn_scale=None, + max_attn_len=max_attn_len, + full_attn_size=0, + ) + else: + return pytorch_cached_hstu_mha( + max_seq_len=max_seq_len, + alpha=alpha, + delta_q=delta_q, + k=k, + v=v, + seq_offsets=seq_offsets, + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + ) diff --git a/recommendation_v4/generative_recommenders/ops/hstu_compute.py b/recommendation_v4/generative_recommenders/ops/hstu_compute.py new file mode 100644 index 000000000..7728c1454 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/hstu_compute.py @@ -0,0 +1,390 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from generative_recommenders.ops.layer_norm import layer_norm +from generative_recommenders.ops.mm import addmm +from generative_recommenders.ops.pytorch.pt_hstu_linear import ( + pytorch_hstu_compute_output, +) + +try: + from hammer.ops.triton.cc.addmm.triton_cc_addmm import triton_cc_addmm + from hammer.ops.triton.cc.group_norm_mul_dropout.triton_cc_group_norm_mul_dropout import ( + triton_cc_group_norm_mul_dropout_wrapper, + ) + from hammer.ops.triton.cc.layer_norm_mul_dropout.triton_cc_layer_norm_mul_dropout import ( + triton_cc_layer_norm_mul_dropout_wrapper, + ) +except ImportError: + triton_cc_addmm = None + triton_cc_group_norm_mul_dropout_wrapper = None + triton_cc_layer_norm_mul_dropout_wrapper = None +from generative_recommenders.common import HammerKernel +from generative_recommenders.ops.hstu_attention import hstu_mha, hstu_mha_cuda +from generative_recommenders.ops.triton.triton_hstu_linear import ( + triton_hstu_compute_output, +) +from generative_recommenders.ops.triton.triton_hstu_preprocess_and_attention import ( + triton_hstu_preprocess_and_attention, +) +from torch.fx._symbolic_trace import is_fx_tracing + +try: + # @manual=//generative_recommenders/ops/triton_aot:triton_group_norm_mul_dropout + from generative_recommenders.ops.triton_aot.triton_group_norm_mul_dropout import ( # pyre-ignore[21] + aot_triton_kernel_wrapper_group_norm_mul_dropout, + ) + + # @manual=//generative_recommenders/ops/triton_aot:triton_layer_norm_mul_dropout + from generative_recommenders.ops.triton_aot.triton_layer_norm_mul_dropout import ( # pyre-ignore[21] + aot_triton_kernel_wrapper_layer_norm_mul_dropout, + ) +except ImportError: + + def aot_triton_kernel_wrapper_group_norm_mul_dropout( + *args: object, + **kwargs: object, + ) -> torch.Tensor: + raise ImportError( + "AOT-T is required for the TRITON_INFERENCE group_norm_mul_dropout kernel." + ) + + def aot_triton_kernel_wrapper_layer_norm_mul_dropout( + *args: object, + **kwargs: object, + ) -> torch.Tensor: + raise ImportError( + "AOT-T is required for the TRITON_INFERENCE layer_norm_mul_dropout kernel." + ) + + +torch.fx.wrap("triton_hstu_compute_output") + + +def hstu_compute_uqvk( + x: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + norm_eps: float, + num_heads: int, + attn_dim: int, + hidden_dim: int, + uvqk_weight: torch.Tensor, + uvqk_bias: torch.Tensor, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if torch.jit.is_scripting(): + # Script-mode fast path: pure PyTorch, no HammerKernel dispatch. + normed_x = F.layer_norm( + x, + normalized_shape=(x.shape[-1],), + weight=norm_weight, + bias=norm_bias, + eps=norm_eps, + ) + uvqk = torch.addmm(uvqk_bias, normed_x, uvqk_weight) + else: + normed_x = layer_norm( + x, + weight=norm_weight, + bias=norm_bias, + eps=norm_eps, + kernel=kernel, + ) + # NOTE: for AMD training, we go with torch.addmm instead of the triton + # version before Triton on AMD achieves on-par perf with NV GPU. + if torch.version.hip and kernel == HammerKernel.TRITON: + uvqk = torch.addmm(uvqk_bias, normed_x, uvqk_weight) + else: + uvqk = addmm(uvqk_bias, normed_x, uvqk_weight, kernel) + u, v, q, k = torch.split( + uvqk, + [ + hidden_dim * num_heads, + hidden_dim * num_heads, + attn_dim * num_heads, + attn_dim * num_heads, + ], + dim=1, + ) + u = F.silu(u) + q = q.view(-1, num_heads, attn_dim) + k = k.view(-1, num_heads, attn_dim) + v = v.view(-1, num_heads, hidden_dim) + return u, q, k, v + + +def hstu_compute_output( + attn: torch.Tensor, + u: torch.Tensor, + x: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + norm_eps: float, + output_weight: torch.Tensor, + num_heads: int, + linear_dim: int, + dropout_ratio: float, + training: bool, + concat_u: bool, + concat_x: bool, + mul_u_activation_type: str, + group_norm: bool, + recompute_y_in_backward: bool, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> torch.Tensor: + if torch.jit.is_scripting(): + return pytorch_hstu_compute_output( + attn=attn, + u=u, + x=x, + norm_weight=norm_weight, + norm_bias=norm_bias, + output_weight=output_weight, + eps=norm_eps, + dropout_ratio=dropout_ratio, + training=training, + concat_u=concat_u, + concat_x=concat_x, + mul_u_activation_type=mul_u_activation_type, + group_norm=group_norm, + num_heads=num_heads, + linear_dim=linear_dim, + ) + if kernel == HammerKernel.TRITON: + return triton_hstu_compute_output( + attn=attn, + u=u, + x=x, + norm_weight=norm_weight, + norm_bias=norm_bias, + output_weight=output_weight, + eps=norm_eps, + dropout_ratio=dropout_ratio, + training=training, + concat_u=concat_u, + concat_x=concat_x, + mul_u_activation_type=mul_u_activation_type, + group_norm=group_norm, + num_heads=num_heads, + linear_dim=linear_dim, + seed=None, + recompute_y_in_backward=recompute_y_in_backward, + ) + elif kernel == HammerKernel.TRITON_INFERENCE: + if group_norm: + y = aot_triton_kernel_wrapper_group_norm_mul_dropout( + x=attn, + u=u, + weight=norm_weight, + bias=norm_bias, + eps=norm_eps, + silu_u=mul_u_activation_type == "silu", + concat_ux=concat_u and concat_x, + num_heads=num_heads, + linear_dim=linear_dim, + ) + else: + y = aot_triton_kernel_wrapper_layer_norm_mul_dropout( + x=attn, + u=u, + weight=norm_weight, + bias=norm_bias, + eps=norm_eps, + dropout_ratio=dropout_ratio, + training=training, + silu_u=mul_u_activation_type == "silu", + concat_ux=concat_u and concat_x, + mul_u_activation_type=mul_u_activation_type, + ) + return addmm(x, y, output_weight, kernel) + elif kernel == HammerKernel.TRITON_CC: + if triton_cc_group_norm_mul_dropout_wrapper is None or triton_cc_addmm is None: + raise ImportError( + "hammer is required for the TRITON_CC kernel in hstu_compute_output." + ) + if group_norm: + y = triton_cc_group_norm_mul_dropout_wrapper( + x=attn, + u=u, + weight=norm_weight, + bias=norm_bias, + eps=norm_eps, + dropout_ratio=dropout_ratio, + training=training, + concat_ux=concat_u and concat_x, + num_heads=num_heads, + linear_dim=linear_dim, + ) + else: + y = triton_cc_layer_norm_mul_dropout_wrapper( + x=attn, + u=u, + weight=norm_weight, + bias=norm_bias, + eps=norm_eps, + dropout_ratio=dropout_ratio, + training=training, + concat_u=concat_u, + concat_x=concat_x, + mul_u_activation_type=mul_u_activation_type, + ) + return triton_cc_addmm(x, y, output_weight) + else: + return pytorch_hstu_compute_output( + attn=attn, + u=u, + x=x, + norm_weight=norm_weight, + norm_bias=norm_bias, + output_weight=output_weight, + eps=norm_eps, + dropout_ratio=dropout_ratio, + training=training, + concat_u=concat_u, + concat_x=concat_x, + mul_u_activation_type=mul_u_activation_type, + group_norm=group_norm, + num_heads=num_heads, + linear_dim=linear_dim, + ) + + +def hstu_preprocess_and_attention( + x: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + norm_eps: float, + num_heads: int, + attn_dim: int, + hidden_dim: int, + uvqk_weight: torch.Tensor, + uvqk_bias: torch.Tensor, + max_seq_len: int, + seq_offsets: torch.Tensor, + attn_alpha: float, + causal: bool, + num_targets: Optional[torch.Tensor], + max_attn_len: int, + contextual_seq_len: int, + recompute_uvqk_in_backward: bool, + recompute_normed_x_in_backward: bool, + sort_by_length: bool, + prefill: bool = False, + kernel: HammerKernel = HammerKernel.PYTORCH, + enable_tma: Optional[bool] = None, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + if not is_fx_tracing(): + torch._assert(max_seq_len > 0, "max_seq_len must be larger than 0") + torch._assert(x.dim() == 2, "x must be 2-D") + torch._assert( + x.shape[1] == uvqk_weight.shape[0], + "x.shape[1] must equal uvqk_weight.shape[0]", + ) + torch._assert( + uvqk_weight.shape[1] == 2 * num_heads * (hidden_dim + attn_dim), + "uvqk_weight.shape[1] must equal 2 * num_heads * (hidden_dim + attn_dim)", + ) + torch._assert(causal is True, "only causal attention is supported.") + if torch.jit.is_scripting(): + # Script-mode: compute uvqk via PyTorch fallback then call the + # libtorch-callable CUDA HSTU MHA op directly. Avoids both the + # HammerKernel enum dispatch and the Triton-only fused path. + u, q, k, v = hstu_compute_uqvk( + x=x, + norm_weight=norm_weight, + norm_bias=norm_bias, + norm_eps=norm_eps, + num_heads=num_heads, + attn_dim=attn_dim, + hidden_dim=hidden_dim, + uvqk_weight=uvqk_weight, + uvqk_bias=uvqk_bias, + kernel=HammerKernel.PYTORCH, + ) + attn_output = hstu_mha_cuda( + max_seq_len=max_seq_len, + alpha=attn_alpha, + q=q, + k=k, + v=v, + seq_offsets=seq_offsets, + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + ).view(-1, hidden_dim * num_heads) + return u, attn_output, k, v + if kernel == HammerKernel.TRITON and prefill is False: + u, attn_output = triton_hstu_preprocess_and_attention( + x=x, + norm_weight=norm_weight, + norm_bias=norm_bias, + norm_eps=norm_eps, + num_heads=num_heads, + attn_dim=attn_dim, + hidden_dim=hidden_dim, + uvqk_weight=uvqk_weight, + uvqk_bias=uvqk_bias, + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + attn_alpha=attn_alpha, + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + recompute_uvqk_in_backward=recompute_uvqk_in_backward, + recompute_normed_x_in_backward=recompute_normed_x_in_backward, + sort_by_length=sort_by_length, + enable_tma=enable_tma, + ) + attn_output = attn_output.view(-1, hidden_dim * num_heads) + k = None + v = None + else: + u, q, k, v = hstu_compute_uqvk( + x=x, + norm_weight=norm_weight, + norm_bias=norm_bias, + norm_eps=norm_eps, + num_heads=num_heads, + attn_dim=attn_dim, + hidden_dim=hidden_dim, + uvqk_weight=uvqk_weight, + uvqk_bias=uvqk_bias, + kernel=kernel, + ) + attn_output = hstu_mha( + max_seq_len=max_seq_len, + alpha=attn_alpha, + q=q, + k=k, + v=v, + seq_offsets=seq_offsets, + causal=causal, + dropout_pr=0.0, + training=False, + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + sort_by_length=sort_by_length, + kernel=kernel, + ).view(-1, hidden_dim * num_heads) + return u, attn_output, k, v diff --git a/recommendation_v4/generative_recommenders/ops/jagged_tensors.py b/recommendation_v4/generative_recommenders/ops/jagged_tensors.py new file mode 100644 index 000000000..73e3c4a73 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/jagged_tensors.py @@ -0,0 +1,451 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Optional, Tuple + +import torch +from generative_recommenders.common import HammerKernel +from generative_recommenders.ops.pytorch.pt_jagged import pytorch_jagged_dense_bmm_add +from generative_recommenders.ops.pytorch.pt_jagged_tensors import ( + pytorch_concat_2D_jagged, + pytorch_hstu_concat_l2_embeddings, + pytorch_hstu_split_l2_embeddings, + pytorch_split_2D_jagged, +) +from generative_recommenders.ops.triton.triton_jagged import triton_jagged_dense_bmm_add +from generative_recommenders.ops.triton.triton_jagged_tensors import ( + triton_concat_2D_jagged, + triton_concat_2D_jagged_multirow, + triton_split_2D_jagged, + triton_split_2D_jagged_multirow, +) +from torch.fx._symbolic_trace import is_fx_tracing + +try: + # @manual=//generative_recommenders/ops/triton_aot:triton_concat_2d_jagged + from generative_recommenders.ops.triton_aot.triton_concat_2d_jagged import ( # pyre-ignore[21] + aot_triton_kernel_wrapper_concat_2D_jagged, + ) + + # @manual=//generative_recommenders/ops/triton_aot:triton_split_2d_jagged + from generative_recommenders.ops.triton_aot.triton_split_2d_jagged import ( # pyre-ignore[21] + aot_triton_kernel_wrapper_split_2D_jagged, + ) +except ImportError: + + def aot_triton_kernel_wrapper_concat_2D_jagged( + *args: object, + **kwargs: object, + ) -> torch.Tensor: + raise ImportError( + "AOT-T is required for the TRITON_INFERENCE concat_2D_jagged kernel." + ) + + def aot_triton_kernel_wrapper_split_2D_jagged( + *args: object, + **kwargs: object, + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise ImportError( + "AOT-T is required for the TRITON_INFERENCE split_2D_jagged kernel." + ) + + +torch.fx.wrap("triton_jagged_dense_bmm_add") + +try: + from hammer.ops.triton.cc.jagged_dense_bmm.triton_cc_jagged_dense_bmm import ( + triton_cc_jagged_dense_bmm, + ) +except ImportError: + triton_cc_jagged_dense_bmm = None + + +torch.fx.wrap("triton_concat_2D_jagged") +torch.fx.wrap("triton_split_2D_jagged") +torch.fx.wrap("triton_concat_2D_jagged_multirow") +torch.fx.wrap("triton_split_2D_jagged_multirow") + + +def concat_2D_jagged( + max_seq_len: int, + values_left: torch.Tensor, + values_right: torch.Tensor, + max_len_left: Optional[int] = None, + max_len_right: Optional[int] = None, + offsets_left: Optional[torch.Tensor] = None, + offsets_right: Optional[torch.Tensor] = None, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> torch.Tensor: + if torch.jit.is_scripting(): + return pytorch_concat_2D_jagged( + values_left=values_left, + values_right=values_right, + max_len_left=max_len_left, + max_len_right=max_len_right, + offsets_left=offsets_left, + offsets_right=offsets_right, + ) + if not is_fx_tracing(): + torch._assert(values_left.dim() == 2, "values_left must be 2D") + torch._assert(values_right.dim() == 2, "values_right must be 2D") + torch._assert( + values_right.shape[1] == values_left.shape[1], + f"values_left shape[1] must be equal to values_right shape[1] {values_left.shape[1]} vs {values_right.shape[1]}", + ) + if kernel == HammerKernel.TRITON: + return triton_concat_2D_jagged( + max_seq_len=max_seq_len, + values_left=values_left, + values_right=values_right, + max_len_left=max_len_left, + max_len_right=max_len_right, + offsets_left=offsets_left, + offsets_right=offsets_right, + ) + elif kernel == HammerKernel.TRITON_INFERENCE: + aott_values_left = values_left + aott_values_right = values_right + if offsets_left is None: + assert max_len_left is not None + aott_values_left = values_left.reshape( + -1, + max_len_left, + values_left.shape[-1], + ) + if offsets_right is None: + assert max_len_right is not None + aott_values_right = values_right.reshape( + -1, + max_len_right, + values_right.shape[-1], + ) + return aot_triton_kernel_wrapper_concat_2D_jagged( + max_seq_len=max_seq_len, + values_a=aott_values_left, + values_b=aott_values_right, + offsets_a=offsets_left, + offsets_b=offsets_right, + ) + else: + return pytorch_concat_2D_jagged( + values_left=values_left, + values_right=values_right, + max_len_left=max_len_left, + max_len_right=max_len_right, + offsets_left=offsets_left, + offsets_right=offsets_right, + ) + + +def split_2D_jagged( + max_seq_len: int, + values: torch.Tensor, + total_len_left: Optional[int] = None, + total_len_right: Optional[int] = None, + max_len_left: Optional[int] = None, + max_len_right: Optional[int] = None, + offsets_left: Optional[torch.Tensor] = None, + offsets_right: Optional[torch.Tensor] = None, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> Tuple[torch.Tensor, torch.Tensor]: + if torch.jit.is_scripting(): + return pytorch_split_2D_jagged( + max_seq_len=max_seq_len, + values=values, + max_len_left=max_len_left, + max_len_right=max_len_right, + offsets_left=offsets_left, + offsets_right=offsets_right, + ) + if not is_fx_tracing(): + torch._assert(values.dim() == 2, "values must be 2D") + torch._assert( + offsets_left is not None or offsets_right is not None, + "offsets_left and offsets_right cannot be None at the same time", + ) + if offsets_left is None: + torch._assert( + max_len_left is not None, + "max_len_left must be provided when offsets_left is None", + ) + if offsets_right is None: + torch._assert( + max_len_right is not None, + "max_len_right must be provided when offsets_right is None", + ) + if offsets_left is not None and offsets_right is not None: + torch._assert( + offsets_left.shape[0] == offsets_right.shape[0], + "offsets_left shape[0] must be equal to offsets_right shape[0]", + ) + if kernel == HammerKernel.TRITON: + return triton_split_2D_jagged( + max_seq_len=max_seq_len, + values=values, + total_len_left=total_len_left, + total_len_right=total_len_right, + max_len_left=max_len_left, + max_len_right=max_len_right, + offsets_left=offsets_left, + offsets_right=offsets_right, + ) + elif kernel == HammerKernel.TRITON_INFERENCE: + dense_size = 0 + if offsets_left is None and max_len_left is not None: + dense_size = max_len_left + elif offsets_right is None and max_len_right is not None: + dense_size = max_len_right + split_left, split_right = aot_triton_kernel_wrapper_split_2D_jagged( + values=values, + max_seq_len=max_seq_len, + offsets_a=offsets_left, + offsets_b=offsets_right, + dense_size=dense_size, + ) + if offsets_left is None: + split_left = split_left.reshape(-1, split_left.shape[-1]) + if offsets_right is None: + split_right = split_right.reshape(-1, split_right.shape[-1]) + return split_left, split_right + else: + return pytorch_split_2D_jagged( + max_seq_len=max_seq_len, + values=values, + max_len_left=max_len_left, + max_len_right=max_len_right, + offsets_left=offsets_left, + offsets_right=offsets_right, + ) + + +def hstu_split_l2_embeddings( + max_seq_len: int, + x: torch.Tensor, + prefix_offsets: torch.Tensor, + l2_offsets: torch.Tensor, + contextual_seq_len: int, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> Tuple[torch.Tensor, torch.Tensor]: + if kernel == HammerKernel.TRITON: + return triton_split_2D_jagged( + max_seq_len=max_seq_len, + values=x, + total_len_right=None, + total_len_left=None, + max_len_left=None, + max_len_right=None, + offsets_left=prefix_offsets, + offsets_right=l2_offsets, + n_prefix_to_right=contextual_seq_len, + ) + else: + return pytorch_hstu_split_l2_embeddings( + max_seq_len=max_seq_len, + x=x, + prefix_offsets=prefix_offsets, + l2_offsets=l2_offsets, + contextual_seq_len=contextual_seq_len, + ) + + +def hstu_concat_l2_embeddings( + max_prefix_len: int, + prefix_x: torch.Tensor, + prefix_offsets: torch.Tensor, + max_l2_len: int, + l2_x: torch.Tensor, + l2_offsets: torch.Tensor, + contextual_seq_len: int, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> torch.Tensor: + if kernel == HammerKernel.TRITON: + return triton_concat_2D_jagged( + max_seq_len=max_prefix_len + max_l2_len, + values_left=prefix_x, + values_right=l2_x, + max_len_left=max_prefix_len, + max_len_right=max_l2_len, + offsets_left=prefix_offsets, + offsets_right=l2_offsets, + n_prefix_from_right=contextual_seq_len, + ) + else: + return pytorch_hstu_concat_l2_embeddings( + contextual_seq_len=contextual_seq_len, + max_prefix_len=max_prefix_len, + prefix_x=prefix_x, + prefix_offsets=prefix_offsets, + max_l2_len=max_l2_len, + l2_x=l2_x, + l2_offsets=l2_offsets, + ) + + +def jagged_dense_bmm_broadcast_add( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, + bias: torch.Tensor, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> torch.Tensor: + """ + Computing out = jagged x dense + bias + jagged has shape (sum_B(M_i), K), dense has shape (B, K, N), and bias has shape (B, N) + out has shape (sum_B(M_i), N) + """ + if not is_fx_tracing(): + _, K = jagged.shape + B, _, N = dense.shape + torch._assert(dense.shape[1] == K, "wrong dense shape[1]") + torch._assert(seq_offsets.shape[0] == B + 1, "wrong seq_offsets shape[0]") + torch._assert(bias.shape[0] == B, "wrong bias shape[0]") + torch._assert(bias.shape[1] == N, "wrong bias shape[1]") + if kernel == HammerKernel.TRITON: + return triton_jagged_dense_bmm_add( + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + jagged=jagged, + dense=dense, + bias=bias, + elementwise=False, + ) + elif kernel == HammerKernel.TRITON_CC: + if triton_cc_jagged_dense_bmm is None: + raise ImportError( + "hammer is required for the TRITON_CC kernel in jagged_dense_bmm_broadcast_add." + ) + return triton_cc_jagged_dense_bmm( + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + jagged=jagged, + dense=dense, + bias=bias, + ) + else: + return pytorch_jagged_dense_bmm_add( + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + jagged=jagged, + dense=dense, + bias=bias, + ) + + +def concat_2D_jagged_multirow( + max_seq_len: int, + values_left: torch.Tensor, + values_right: torch.Tensor, + offsets_left: Optional[torch.Tensor], + offsets_right: Optional[torch.Tensor], + max_len_left: int, + max_len_right: int, + kernel: HammerKernel = HammerKernel.TRITON, +) -> torch.Tensor: + if not is_fx_tracing(): + torch._assert(values_left.dim() == 2, "values_left must be 2D") + torch._assert(values_right.dim() == 2, "values_right must be 2D") + torch._assert( + values_right.shape[1] == values_left.shape[1], + f"values_left shape[1] must be equal to values_right shape[1] {values_left.shape[1]} vs {values_right.shape[1]}", + ) + if offsets_left is not None and offsets_right is not None: + torch._assert( + offsets_left.shape[0] == offsets_right.shape[0], + "offsets_left and offsets_right must have the same batch dimension", + ) + + if kernel == HammerKernel.TRITON: + return triton_concat_2D_jagged_multirow( + max_seq_len=max_seq_len, + values_a=values_left, + values_b=values_right, + offsets_a=offsets_left, + offsets_b=offsets_right, + max_len_a=max_len_left, + max_len_b=max_len_right, + ) + else: + return concat_2D_jagged( + max_seq_len=max_seq_len, + values_left=values_left, + values_right=values_right, + max_len_left=max_len_left, + max_len_right=max_len_right, + offsets_left=offsets_left, + offsets_right=offsets_right, + kernel=kernel, + ) + + +def split_2D_jagged_multirow( + max_seq_len: int, + values: torch.Tensor, + total_len_left: Optional[int] = None, + total_len_right: Optional[int] = None, + max_len_left: Optional[int] = None, + max_len_right: Optional[int] = None, + offsets_left: Optional[torch.Tensor] = None, + offsets_right: Optional[torch.Tensor] = None, + kernel: HammerKernel = HammerKernel.TRITON, +) -> Tuple[torch.Tensor, torch.Tensor]: + if not is_fx_tracing(): + torch._assert(values.dim() == 2, "values must be 2D") + torch._assert( + offsets_left is not None or offsets_right is not None, + "offsets_left and offsets_right cannot be None at the same time", + ) + if offsets_left is None: + torch._assert( + max_len_left is not None, + "max_len_left must be provided when offsets_left is None", + ) + if offsets_right is None: + torch._assert( + max_len_right is not None, + "max_len_right must be provided when offsets_right is None", + ) + if offsets_left is not None and offsets_right is not None: + torch._assert( + offsets_left.shape[0] == offsets_right.shape[0], + "offsets_left and offsets_right must have the same batch dimension", + ) + + if kernel == HammerKernel.TRITON: + return triton_split_2D_jagged_multirow( + max_seq_len=max_seq_len, + values=values, + total_len_left=total_len_left, + total_len_right=total_len_right, + max_len_left=max_len_left, + max_len_right=max_len_right, + offsets_left=offsets_left, + offsets_right=offsets_right, + ) + else: + return split_2D_jagged( + max_seq_len=max_seq_len, + values=values, + total_len_left=total_len_left, + total_len_right=total_len_right, + max_len_left=max_len_left, + max_len_right=max_len_right, + offsets_left=offsets_left, + offsets_right=offsets_right, + kernel=kernel, + ) diff --git a/recommendation_v4/generative_recommenders/ops/layer_norm.py b/recommendation_v4/generative_recommenders/ops/layer_norm.py new file mode 100644 index 000000000..74ed377d6 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/layer_norm.py @@ -0,0 +1,330 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + + +from typing import List + +import torch +from generative_recommenders.ops.pytorch.pt_layer_norm import ( + pytorch_layer_norm, + pytorch_rms_norm, + pytorch_swish_layer_norm, +) +from generative_recommenders.ops.triton.triton_layer_norm import triton_rms_norm + +try: + from hammer.ops.triton.cc.rms_norm.triton_cc_rms_norm import triton_cc_rms_norm + from hammer.ops.triton.cc.swish_layer_norm.triton_cc_swish_layer_norm import ( + triton_cc_swish_layer_norm, + ) +except ImportError: + triton_cc_swish_layer_norm = None + triton_cc_rms_norm = None +from generative_recommenders.common import HammerKernel, HammerModule +from generative_recommenders.ops.triton.triton_layer_norm import ( + triton_layer_norm, + triton_swish_layer_norm, +) +from torch.fx._symbolic_trace import is_fx_tracing + +try: + # @manual=//generative_recommenders/ops/triton_aot:triton_layer_norm + from generative_recommenders.ops.triton_aot.triton_layer_norm import ( # pyre-ignore[21] + aot_triton_kernel_wrapper_swish_layer_norm, + ) + + # @manual=//generative_recommenders/ops/triton_aot:triton_rms_norm + from generative_recommenders.ops.triton_aot.triton_rms_norm import ( # pyre-ignore[21] + aot_triton_kernel_wrapper_rms_norm, + ) +except ImportError: + + def aot_triton_kernel_wrapper_swish_layer_norm( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + is_swish: bool, + ) -> torch.Tensor: + raise ImportError( + "AOT-T is required for the TRITON_INFERENCE swish_layer_norm kernel." + ) + + def aot_triton_kernel_wrapper_rms_norm( + x: torch.Tensor, + weight: torch.Tensor, + eps: float, + silu: bool, + ) -> torch.Tensor: + raise ImportError("AOT-T is required for the TRITON_INFERENCE rms_norm kernel.") + + +torch.fx.wrap("triton_layer_norm") +torch.fx.wrap("triton_swish_layer_norm") +torch.fx.wrap("triton_rms_norm") + + +def layer_norm( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-5, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> torch.Tensor: + if torch.jit.is_scripting(): + # Script-mode fast path: bypass the HammerKernel ladder (which would + # drag in is_fx_tracing()'s closed-over global bool). + return torch.nn.functional.layer_norm( + x, + normalized_shape=(x.shape[-1],), + weight=weight, + bias=bias, + eps=eps, + ) + if kernel == HammerKernel.TRITON: + if not is_fx_tracing(): + torch._assert(not x.is_cpu, "x must be device tensor") + torch._assert(not weight.is_cpu, "weight must be device tensor") + torch._assert(not bias.is_cpu, "bias must be device tensor") + return triton_layer_norm(x, weight, bias, eps) + elif kernel == HammerKernel.TRITON_INFERENCE: + return aot_triton_kernel_wrapper_swish_layer_norm( + x, + weight, + bias, + eps, + is_swish=False, + ) + elif kernel == HammerKernel.TRITON_CC: + if triton_cc_swish_layer_norm is None: + raise ImportError( + "hammer is required for the TRITON_CC kernel in layer_norm." + ) + return triton_cc_swish_layer_norm( + x, + weight, + bias, + eps, + is_swish=False, + ) + else: + return pytorch_layer_norm( + x, + [ + x.shape[-1], + ], + weight, + bias, + eps, + ) + + +def rms_norm( + x: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-5, + kernel: HammerKernel = HammerKernel.PYTORCH, + silu: bool = False, +) -> torch.Tensor: + if torch.jit.is_scripting(): + # Script-mode fast path: bypass the HammerKernel ladder. + x_f = x.float() + norm = torch.rsqrt(x_f.pow(2).mean(-1, keepdim=True) + eps) + out = (x_f * norm * weight.float()).to(x.dtype) + if silu: + out = torch.nn.functional.silu(out) + return out + if kernel == HammerKernel.TRITON: + if not is_fx_tracing(): + torch._assert(not x.is_cpu, "x must be device tensor") + torch._assert(not weight.is_cpu, "weight must be device tensor") + return triton_rms_norm(x, weight, eps, silu) + elif kernel == HammerKernel.TRITON_INFERENCE: + return aot_triton_kernel_wrapper_rms_norm(x, weight, eps, silu) + elif kernel == HammerKernel.TRITON_CC: + if triton_cc_rms_norm is None: + raise ImportError( + "hammer is required for the TRITON_CC kernel in rms_norm." + ) + return triton_cc_rms_norm( + x, + weight, + eps, + silu=silu, + ) + else: + return pytorch_rms_norm( + x, + [ + x.shape[-1], + ], + weight, + eps, + silu, + ) + + +def swish_layer_norm( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-5, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> torch.Tensor: + if torch.jit.is_scripting(): + # Script-mode fast path: bypass the HammerKernel ladder (which + # otherwise drags in is_fx_tracing(), Triton/Triton_CC closures, + # etc.) and call pure PyTorch directly. + return pytorch_swish_layer_norm( + x, + [x.shape[-1]], + weight, + bias, + eps, + ) + if kernel == HammerKernel.TRITON: + if not is_fx_tracing(): + torch._assert(not x.is_cpu, "x must be device tensor") + torch._assert(not weight.is_cpu, "weight must be device tensor") + torch._assert(not bias.is_cpu, "bias must be device tensor") + return triton_swish_layer_norm(x, [x.shape[-1]], weight, bias, eps) + elif kernel == HammerKernel.TRITON_INFERENCE: + return aot_triton_kernel_wrapper_swish_layer_norm( + x, + weight, + bias, + eps, + is_swish=True, + ) + elif kernel == HammerKernel.TRITON_CC: + if triton_cc_swish_layer_norm is None: + raise ImportError( + "hammer is required for the TRITON_CC kernel in swish_layer_norm." + ) + return triton_cc_swish_layer_norm( + x, + weight, + bias, + eps, + is_swish=True, + ) + else: + return pytorch_swish_layer_norm( + x, + [ + x.shape[-1], + ], + weight, + bias, + eps, + ) + + +class LayerNorm(HammerModule): + def __init__( + self, + dim: int, + eps: float = 1e-5, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + self._normalized_shape: List[int] = [dim] + self._eps = eps + self.weight = torch.nn.Parameter( + torch.ones(self._normalized_shape), + ) + self.bias = torch.nn.Parameter( + torch.zeros(self._normalized_shape), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return layer_norm( + x=x, + weight=self.weight, + bias=self.bias, + eps=self._eps, + kernel=self.hammer_kernel(), + ) + + +class RMSNorm(HammerModule): + def __init__( + self, + dim: int, + eps: float = 1e-5, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + self._eps = eps + self.weight = torch.nn.Parameter(torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return rms_norm( + x, + self.weight, + self._eps, + silu=False, + kernel=self.hammer_kernel(), + ) + + +class RMSNormSilu(HammerModule): + def __init__( + self, + dim: int, + eps: float = 1e-5, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + self._eps = eps + self.weight = torch.nn.Parameter(torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return rms_norm( + x, + self.weight, + self._eps, + silu=True, + kernel=self.hammer_kernel(), + ) + + +class SwishLayerNorm(HammerModule): + def __init__( + self, + dim: int, + eps: float = 1e-5, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + self._normalized_shape: List[int] = [dim] + self.weight = torch.nn.Parameter(torch.ones(self._normalized_shape)) + self.bias = torch.nn.Parameter(torch.zeros(self._normalized_shape)) + self._eps = eps + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + return swish_layer_norm( + x=x, + weight=self.weight, + bias=self.bias, + eps=self._eps, + kernel=self.hammer_kernel(), + ) diff --git a/recommendation_v4/generative_recommenders/ops/mm.py b/recommendation_v4/generative_recommenders/ops/mm.py new file mode 100644 index 000000000..31a5c5d36 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/mm.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import torch + +try: + from hammer.ops.triton.cc.addmm.triton_cc_addmm import triton_cc_addmm +except ImportError: + triton_cc_addmm = None +from generative_recommenders.common import HammerKernel +from generative_recommenders.ops.triton.triton_addmm import triton_addmm + +try: + # @manual=//generative_recommenders/ops/triton_aot:triton_addmm + from generative_recommenders.ops.triton_aot.triton_addmm import ( # pyre-ignore[21] + aot_triton_kernel_wrapper_addmm, + ) +except ImportError: + + def aot_triton_kernel_wrapper_addmm( + input: torch.Tensor, + mat1: torch.Tensor, + mat2: torch.Tensor, + ) -> torch.Tensor: + raise ImportError("AOT-T is required for the TRITON_INFERENCE addmm kernel.") + + +def addmm( + input: torch.Tensor, + mat1: torch.Tensor, + mat2: torch.Tensor, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> torch.Tensor: + if torch.jit.is_scripting(): + return torch.addmm(input, mat1, mat2) + if kernel == HammerKernel.TRITON: + return triton_addmm(input, mat1, mat2) + elif kernel == HammerKernel.TRITON_INFERENCE: + return aot_triton_kernel_wrapper_addmm(input, mat1, mat2) + elif kernel == HammerKernel.TRITON_CC: + if triton_cc_addmm is None: + raise ImportError("hammer is required for the TRITON_CC kernel in addmm.") + return triton_cc_addmm(input, mat1, mat2) + else: + return torch.addmm(input, mat1, mat2) diff --git a/recommendation_v4/generative_recommenders/ops/position.py b/recommendation_v4/generative_recommenders/ops/position.py new file mode 100644 index 000000000..e090827e3 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/position.py @@ -0,0 +1,147 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Optional + +import torch +from generative_recommenders.ops.pytorch.pt_position import ( + pytorch_add_timestamp_positional_embeddings, +) + +try: + from hammer.ops.triton.cc.add_timestamp_position_embeddings.triton_cc_add_timestamp_position_embeddings import ( + triton_cc_add_timestamp_position_embeddings, + ) +except ImportError: + triton_cc_add_timestamp_position_embeddings = None +from generative_recommenders.common import HammerKernel +from generative_recommenders.ops.triton.triton_position import ( + triton_add_timestamp_positional_embeddings, +) + +try: + # @manual=//generative_recommenders/ops/triton_aot:triton_position + from generative_recommenders.ops.triton_aot.triton_position import ( # pyre-ignore[21] + aot_triton_kernel_wrapper_position, + ) +except ImportError: + + def aot_triton_kernel_wrapper_position( + *args: object, + **kwargs: object, + ) -> torch.Tensor: + raise ImportError("AOT-T is required for the TRITON_INFERENCE position kernel.") + + +torch.fx.wrap("triton_add_timestamp_positional_embeddings") + + +def add_timestamp_positional_embeddings( + alpha: float, + max_seq_len: int, + max_contextual_seq_len: int, + position_embeddings_weight: torch.Tensor, + timestamp_embeddings_weight: torch.Tensor, + seq_offsets: torch.Tensor, + seq_lengths: torch.Tensor, + seq_embeddings: torch.Tensor, + timestamps: torch.Tensor, + num_targets: Optional[torch.Tensor], + interleave_targets: bool, + time_bucket_fn: str = "sqrt", + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> torch.Tensor: + if torch.jit.is_scripting(): + # Script-mode fast path: bypass the HammerKernel ladder. + seq_embeddings = seq_embeddings * alpha + return pytorch_add_timestamp_positional_embeddings( + seq_embeddings=seq_embeddings, + seq_offsets=seq_offsets, + pos_embeddings=position_embeddings_weight, + ts_embeddings=timestamp_embeddings_weight, + timestamps=timestamps, + max_seq_len=max_seq_len, + max_contextual_seq_len=max_contextual_seq_len, + seq_lengths=seq_lengths, + num_targets=num_targets, + interleave_targets=interleave_targets, + time_bucket_fn=time_bucket_fn, + ) + assert time_bucket_fn in ["sqrt", "log"] + seq_embeddings = seq_embeddings * alpha + if kernel == HammerKernel.TRITON: + return triton_add_timestamp_positional_embeddings( + seq_embeddings=seq_embeddings, + seq_offsets=seq_offsets, + pos_embeddings=position_embeddings_weight, + ts_embeddings=timestamp_embeddings_weight, + timestamps=timestamps, + max_seq_len=max_seq_len, + max_contextual_seq_len=max_contextual_seq_len, + seq_lengths=seq_lengths, + num_targets=num_targets, + interleave_targets=interleave_targets, + time_bucket_fn=time_bucket_fn, + ) + elif kernel == HammerKernel.TRITON_INFERENCE: + return aot_triton_kernel_wrapper_position( + alpha=1.0, + max_seq_len=max_seq_len, + max_contextual_seq_len=max_contextual_seq_len, + position_embeddings_weight=position_embeddings_weight.to(torch.float32), + timestamp_embeddings_weight=timestamp_embeddings_weight.to(torch.float32), + seq_offsets=seq_offsets, + seq_lengths=seq_lengths, + seq_embeddings=seq_embeddings, + timestamps=timestamps, + num_targets=num_targets, + interleave_targets=interleave_targets, + time_bucket_fn=time_bucket_fn, + ) + elif kernel == HammerKernel.TRITON_CC: + if triton_cc_add_timestamp_position_embeddings is None: + raise ImportError( + "hammer is required for the TRITON_CC kernel in add_timestamp_positional_embeddings." + ) + return triton_cc_add_timestamp_position_embeddings( + seq_embeddings=seq_embeddings, + seq_offsets=seq_offsets, + pos_embeddings=position_embeddings_weight, + ts_embeddings=timestamp_embeddings_weight, + timestamps=timestamps, + max_seq_len=max_seq_len, + max_contextual_seq_len=max_contextual_seq_len, + seq_lengths=seq_lengths, + num_targets=num_targets, + interleave_targets=interleave_targets, + time_bucket_fn=time_bucket_fn, + ) + else: + return pytorch_add_timestamp_positional_embeddings( + seq_embeddings=seq_embeddings, + seq_offsets=seq_offsets, + pos_embeddings=position_embeddings_weight, + ts_embeddings=timestamp_embeddings_weight, + timestamps=timestamps, + max_seq_len=max_seq_len, + max_contextual_seq_len=max_contextual_seq_len, + seq_lengths=seq_lengths, + num_targets=num_targets, + interleave_targets=interleave_targets, + time_bucket_fn=time_bucket_fn, + ) diff --git a/recommendation_v4/generative_recommenders/ops/pytorch/pt_hstu_attention.py b/recommendation_v4/generative_recommenders/ops/pytorch/pt_hstu_attention.py new file mode 100644 index 000000000..32575c4db --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/pytorch/pt_hstu_attention.py @@ -0,0 +1,251 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F + + +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") +except OSError: + pass + + +@torch.fx.wrap +def _get_valid_attn_mask( + device: torch.device, + causal: bool, + N: int, + seq_lengths: torch.Tensor, + num_targets: Optional[torch.Tensor] = None, + max_attn_len: int = 0, + contextual_seq_len: int = 0, + min_full_attn_seq_len: int = 0, +) -> torch.Tensor: + ids = torch.arange(0, N, device=device).view(1, N) + max_ids = seq_lengths.view(-1, 1, 1) + if contextual_seq_len > 0: + ids = ids - contextual_seq_len + 1 + ids = torch.clamp(ids, min=0) + max_ids = max_ids - contextual_seq_len + 1 + if num_targets is not None: + max_ids = max_ids - num_targets.view(-1, 1, 1) + ids = torch.clamp( + ids, + max=max_ids, + ) + row_ids = ids.view(-1, N, 1).expand(-1, N, N) + col_ids = ids.view(-1, 1, N).expand(-1, N, N) + else: + row_ids = ids.view(N, 1).expand(N, N) + col_ids = row_ids.t() + row_ids = row_ids.view(1, N, N) + col_ids = col_ids.view(1, N, N) + row_col_dist = row_ids - col_ids + valid_attn_mask = torch.eye(N, device=device, dtype=torch.bool).view(1, N, N) + if not causal: + row_col_dist = torch.where(row_col_dist > 0, row_col_dist, -row_col_dist) + valid_attn_mask = torch.logical_or(valid_attn_mask, row_col_dist > 0) + if max_attn_len > 0: + if min_full_attn_seq_len > 0: + valid_attn_mask = torch.logical_and( + valid_attn_mask, + torch.logical_or( + row_col_dist <= max_attn_len, + row_ids >= max_ids - min_full_attn_seq_len, + ), + ) + else: + valid_attn_mask = torch.logical_and( + valid_attn_mask, row_col_dist <= max_attn_len + ) + if contextual_seq_len > 0: + valid_attn_mask = torch.logical_or( + valid_attn_mask, torch.logical_and(row_ids == 0, col_ids < max_ids) + ) + return valid_attn_mask + + +def _pad_qkv( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + N: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + L, H, D = q.shape + V = v.shape[2] + padded_q = ( + torch.ops.fbgemm.jagged_to_padded_dense( + values=q.reshape(L, H * D), + offsets=[seq_offsets], + max_lengths=[N], + padding_value=0.0, + ) + .view(-1, N, H, D) + .transpose(1, 2) + ) # [B, H, N, A] + padded_k = ( + torch.ops.fbgemm.jagged_to_padded_dense( + values=k.reshape(L, H * D), + offsets=[seq_offsets], + max_lengths=[N], + padding_value=0.0, + ) + .view(-1, N, H, D) + .transpose(1, 2) + ) # [B, H, N, A] + padded_v = ( + torch.ops.fbgemm.jagged_to_padded_dense( + values=v.reshape(L, H * V), + offsets=[seq_offsets], + max_lengths=[N], + padding_value=0.0, + ) + .view(-1, N, H, V) + .transpose(1, 2) + ) # [B, H, N, D] + return padded_q, padded_k, padded_v + + +@torch.fx.wrap +def pytorch_hstu_mha( + max_seq_len: int, + alpha: float, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + causal: bool = True, + dropout_pr: float = 0.0, + training: bool = True, + num_targets: Optional[torch.Tensor] = None, + attn_scale: Optional[torch.Tensor] = None, + max_attn_len: int = 0, + contextual_seq_len: int = 0, + min_full_attn_seq_len: int = 0, +) -> torch.Tensor: + L, H, _ = q.shape + V = v.shape[2] + q, k, v = _pad_qkv( + q, k, v, seq_offsets, max_seq_len + ) # [B, H, N, D) and [B, H, N, V] + qk_attn = torch.einsum("bhxa,bhya->bhxy", q, k) * alpha + if attn_scale is not None: + if attn_scale.ndim > 0: + attn_scale = ( + torch.ops.fbgemm.jagged_to_padded_dense( + values=attn_scale.unsqueeze(-1), + offsets=[seq_offsets], + max_lengths=[max_seq_len], + padding_value=0.0, + ) + .unsqueeze(1) + .to(qk_attn.dtype) + ) + + qk_attn = F.silu(qk_attn) * attn_scale + else: + qk_attn = F.silu(qk_attn) / max_seq_len + valid_attn_mask = _get_valid_attn_mask( + device=q.device, + causal=causal, + N=max_seq_len, + seq_lengths=seq_offsets[1:] - seq_offsets[:-1], + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + min_full_attn_seq_len=min_full_attn_seq_len, + ) + # raise NotImplementedError(valid_attn_mask[0, :, :].to(torch.int32)) + qk_attn = qk_attn * valid_attn_mask.unsqueeze(1) + if dropout_pr > 0.0: + qk_attn = F.dropout(qk_attn, p=dropout_pr, training=training) + attn_dense = torch.einsum("bhxd,bhdv->bhxv", qk_attn, v) # [B, H, N, V] + return torch.ops.fbgemm.dense_to_jagged( + attn_dense.transpose(1, 2).flatten(2, 3), # [B, N, H, V]->[B, N, H * V] + [seq_offsets], + L, + )[0].view(L, H, V) + + +@torch.fx.wrap +def pytorch_cached_hstu_mha( + max_seq_len: int, + alpha: float, + delta_q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + num_targets: Optional[torch.Tensor] = None, + max_attn_len: int = 0, + contextual_seq_len: int = 0, +) -> torch.Tensor: + L, H, D = delta_q.shape + _, _, V = v.shape + B = seq_offsets.size(0) - 1 + delta_size = L // B + delta_q = delta_q.view(B, -1, H, D).transpose(1, 2) + full_k = ( + torch.ops.fbgemm.jagged_to_padded_dense( + values=k.reshape(-1, H * D), + offsets=[seq_offsets], + max_lengths=[max_seq_len], + padding_value=0.0, + ) + .view(B, -1, H, D) + .transpose(1, 2) + ) + full_v = ( + torch.ops.fbgemm.jagged_to_padded_dense( + values=v.reshape(-1, H * V), + offsets=[seq_offsets], + max_lengths=[max_seq_len], + padding_value=0.0, + ) + .view(B, -1, H, V) + .transpose(1, 2) + ) + qk_attn = torch.einsum("bhxa,bhya->bhxy", delta_q, full_k) * alpha + qk_attn = F.silu(qk_attn) / max_seq_len + full_valid_attn_mask = _get_valid_attn_mask( + device=delta_q.device, + causal=True, + N=max_seq_len, + seq_lengths=seq_offsets[1:] - seq_offsets[:-1], + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + ) + seq_lengths = seq_offsets[1:] - seq_offsets[:-1] + mask = torch.arange(max_seq_len, device=delta_q.device).view(1, -1) + mask = torch.logical_and( + mask >= (seq_lengths - delta_size).view(-1, 1), + mask < seq_lengths.view(-1, 1), + ) + valid_attn_mask = ( + full_valid_attn_mask.expand(B, -1, -1) + .flatten(0, 1)[mask.view(-1), :] + .view(-1, delta_size, max_seq_len) + ) + qk_attn = qk_attn * valid_attn_mask.unsqueeze(1) + attn_output = torch.einsum("bhxd,bhdv->bhxv", qk_attn, full_v) + return attn_output.transpose(1, 2).reshape(-1, H, V) diff --git a/recommendation_v4/generative_recommenders/ops/pytorch/pt_hstu_linear.py b/recommendation_v4/generative_recommenders/ops/pytorch/pt_hstu_linear.py new file mode 100644 index 000000000..6ea94a565 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/pytorch/pt_hstu_linear.py @@ -0,0 +1,130 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import torch +import torch.nn.functional as F + + +def pytorch_norm_mul_dropout( + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_u: bool = False, + concat_x: bool = False, + mul_u_activation_type: str = "none", + group_norm: bool = False, + num_heads: int = 1, + linear_dim: int = -1, +) -> torch.Tensor: + dtype = x.dtype + x = x.to(torch.float32) + u = u.to(torch.float32) + if group_norm: + if silu_u: + u = F.silu(u) + u = u.to(torch.float32) + y = u * F.group_norm( + x.view(-1, num_heads, linear_dim), + num_groups=num_heads, + weight=weight.to(torch.float32), + bias=bias.to(torch.float32), + eps=eps, + ).view(-1, num_heads * linear_dim) + if concat_u and concat_x: + y = torch.cat([u, x, y], dim=1) + else: + mul_u = u + if mul_u_activation_type == "sigmoid": + mul_u = torch.sigmoid(u) + elif mul_u_activation_type == "silu": + mul_u = F.silu(u) + y = mul_u * F.layer_norm( + x, + normalized_shape=(x.shape[-1],), + weight=weight.to(torch.float32), + bias=bias.to(torch.float32), + eps=eps, + ) + if concat_u: + if silu_u: + u = F.silu(u) + if concat_x: + y = torch.cat([u, x, y], dim=1) + else: + y = torch.cat([u, y], dim=1) + elif concat_x: + y = torch.cat([x, y], dim=1) + y = F.dropout( + y, + p=dropout_ratio, + training=training, + ) + return y.to(dtype) + + +def pytorch_hstu_compute_output( + attn: torch.Tensor, + u: torch.Tensor, + x: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + output_weight: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_u: bool = False, + concat_x: bool = False, + mul_u_activation_type: str = "none", + group_norm: bool = False, + num_heads: int = 1, + linear_dim: int = -1, +) -> torch.Tensor: + dtype = x.dtype + y = pytorch_norm_mul_dropout( + x=attn, + u=u, + weight=norm_weight, + bias=norm_bias, + eps=eps, + dropout_ratio=dropout_ratio, + training=training, + silu_u=silu_u, + concat_u=concat_u, + concat_x=concat_x, + mul_u_activation_type=mul_u_activation_type, + group_norm=group_norm, + num_heads=num_heads, + linear_dim=linear_dim, + ) + return torch.addmm(x, y, output_weight.to(x.dtype)).to(dtype) + + +def pytorch_swiglu( + x: torch.Tensor, + w_gate: torch.Tensor, + w_up: torch.Tensor, +) -> torch.Tensor: + gate = F.silu(F.linear(x, w_gate)) + up = F.linear(x, w_up) + return gate * up diff --git a/recommendation_v4/generative_recommenders/ops/pytorch/pt_jagged.py b/recommendation_v4/generative_recommenders/ops/pytorch/pt_jagged.py new file mode 100644 index 000000000..82d82f402 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/pytorch/pt_jagged.py @@ -0,0 +1,258 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Tuple + +import torch + + +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") +except OSError: + pass + + +def pytorch_jagged_dense_bmm( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, +) -> torch.Tensor: + dtype = jagged.dtype + jagged = jagged.to(torch.float32) + dense = dense.to(torch.float32) + padded_jagged = torch.ops.fbgemm.jagged_to_padded_dense( + values=jagged, + offsets=[seq_offsets], + max_lengths=[max_seq_len], + padding_value=0.0, + ) + bmm_out = torch.bmm(padded_jagged, dense) + jagged_bmm_out = torch.ops.fbgemm.dense_to_jagged( + bmm_out, [seq_offsets], total_L=jagged.shape[0] + )[0] + jagged_bmm_out = jagged_bmm_out.to(dtype) + return jagged_bmm_out + + +def pytorch_jagged_dense_broadcast_add( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, +) -> torch.Tensor: + dtype = jagged.dtype + jagged = jagged.to(torch.float32) + dense = dense.to(torch.float32) + padded_jagged = torch.ops.fbgemm.jagged_to_padded_dense( + values=jagged, + offsets=[seq_offsets], + max_lengths=[max_seq_len], + padding_value=0.0, + ) + out = padded_jagged + dense.unsqueeze(1) + jagged_out = torch.ops.fbgemm.dense_to_jagged( + out, [seq_offsets], total_L=jagged.shape[0] + )[0] + jagged_out = jagged_out.to(dtype) + return jagged_out + + +def pytorch_jagged_dense_bmm_add( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, + bias: torch.Tensor, + elementwise: bool = False, +) -> torch.Tensor: + dtype = jagged.dtype + jagged = jagged.to(torch.float32) + dense = dense.to(torch.float32) + padded_jagged = torch.ops.fbgemm.jagged_to_padded_dense( + values=jagged, + offsets=[seq_offsets], + max_lengths=[max_seq_len], + padding_value=0.0, + ) + bmm_out = torch.bmm(padded_jagged, dense) + + if elementwise: + jagged_out = ( + torch.ops.fbgemm.dense_to_jagged( + bmm_out, [seq_offsets], total_L=jagged.shape[0] + )[0] + + bias + ) + else: + jagged_out = torch.ops.fbgemm.dense_to_jagged( + bmm_out + bias.unsqueeze(1), [seq_offsets], total_L=jagged.shape[0] + )[0] + + jagged_out = jagged_out.to(dtype) + return jagged_out + + +@torch.fx.wrap +def _arange(len: int, device: torch.device) -> torch.Tensor: + return torch.arange(len, device=device) + + +def pytorch_concat_2D_dense_jagged( + jagged_max_seq_len: int, + jagged_offsets: torch.Tensor, + jagged_values: torch.Tensor, + dense_values: torch.Tensor, +) -> torch.Tensor: + B, dense_size, D = dense_values.size() + jagged_dense = torch.ops.fbgemm.jagged_to_padded_dense( + values=jagged_values, + offsets=[jagged_offsets], + max_lengths=[jagged_max_seq_len], + padding_value=0.0, + ) + concatted_dense = torch.cat([dense_values, jagged_dense], dim=1) + concatted_offsets = ( + dense_size * _arange(B + 1, device=jagged_offsets.device) + jagged_offsets + ) + return torch.ops.fbgemm.dense_to_jagged( + concatted_dense, + [concatted_offsets], + total_L=jagged_values.shape[0] + dense_size * B, + )[0] + + +def pytorch_concat_2D_jagged_jagged( + max_seq_len_left: int, + offsets_left: torch.Tensor, + values_left: torch.Tensor, + max_seq_len_right: int, + offsets_right: torch.Tensor, + values_right: torch.Tensor, + is_replace: bool = False, + n_prefix_from_right: int = 0, +) -> torch.Tensor: + # is_replace with n_prefix_from_right != 0 is not supported yet (neither in triton) + if is_replace: + return pytorch_replace_last_n_with_jagged( + max_seq_len_left, + offsets_left, + values_left, + offsets_right, + values_right, + ) + + lengths_a = offsets_left[1:] - offsets_left[:-1] + lengths_b = offsets_right[1:] - offsets_right[:-1] + + # Compute output offsets via cumsum (no dynamic shapes). + output_lengths = lengths_a + lengths_b + output_offsets = torch.nn.functional.pad( + torch.cumsum(output_lengths, dim=0), (1, 0) + ) + + total_len = values_left.shape[0] + values_right.shape[0] + positions = torch.arange(total_len, device=values_left.device) + batch_idx = torch.searchsorted(output_offsets[1:], positions, right=True) + local_pos = positions - output_offsets[batch_idx] + + per_batch_lengths_a = lengths_a[batch_idx] + + # Classify each output position into prefix / left / suffix. + is_prefix = local_pos < n_prefix_from_right + is_left = (local_pos >= n_prefix_from_right) & ( + local_pos < n_prefix_from_right + per_batch_lengths_a + ) + + # Pad with a sentinel zero row so index_select works on empty tensors + values_left_safe = torch.nn.functional.pad(values_left, (0, 0, 0, 1)) + values_right_safe = torch.nn.functional.pad(values_right, (0, 0, 0, 1)) + + left_idx = (offsets_left[batch_idx] + (local_pos - n_prefix_from_right)).clamp( + min=0, max=values_left.shape[0] + ) + right_prefix_idx = offsets_right[batch_idx] + local_pos + right_suffix_idx = offsets_right[batch_idx] + (local_pos - per_batch_lengths_a) + right_idx = torch.where(is_prefix, right_prefix_idx, right_suffix_idx).clamp( + min=0, max=values_right.shape[0] + ) + + left_values = values_left_safe.index_select(0, left_idx) + right_values = values_right_safe.index_select(0, right_idx) + + return torch.where(is_left.unsqueeze(-1), left_values, right_values) + + +def pytorch_jagged_remove_first_or_last_1D( + values: torch.Tensor, + lengths: torch.Tensor, + offsets: torch.Tensor, + max_seq_len: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + values = values.view(-1, 1) + shrunk_lengths = lengths - 1 + k_lengths = torch.stack([shrunk_lengths, torch.ones_like(lengths)], dim=1).view(-1) + q_lengths = torch.stack([torch.ones_like(lengths), shrunk_lengths], dim=1).view(-1) + all_indices = torch.arange( + start=0, end=q_lengths.numel(), device=values.device + ).reshape(-1, 2) + q_indices, k_indices = all_indices[:, 1], all_indices[:, 0] + values_no_first, _ = torch.ops.fbgemm.jagged_index_select( + values, q_lengths, q_indices + ) + values_no_last, _ = torch.ops.fbgemm.jagged_index_select( + values, k_lengths, k_indices + ) + return values_no_first.squeeze(), values_no_last.squeeze() + + +@torch.fx.wrap +def fx_apply_mask( + tensor: torch.Tensor, mask: torch.Tensor, fill_value: torch.Tensor +) -> torch.Tensor: + tensor[mask] = fill_value + return tensor + + +def pytorch_replace_last_n_with_jagged( + max_seq_len_left: int, + offsets_left: torch.Tensor, + values_left: torch.Tensor, + offsets_right: torch.Tensor, + values_right: torch.Tensor, +) -> torch.Tensor: + lengths_a = offsets_left[1:] - offsets_left[:-1] + lengths_b = offsets_right[1:] - offsets_right[:-1] + + total_len = values_left.shape[0] + positions = torch.arange(total_len, device=values_left.device) + batch_idx = torch.searchsorted(offsets_left[1:], positions, right=True) + local_pos = positions - offsets_left[batch_idx] + + # Positions >= (lengths_a - lengths_b) within each batch are in the replace zone. + threshold = lengths_a[batch_idx] - lengths_b[batch_idx] + in_replace_zone = local_pos >= threshold + + # Pad with a sentinel zero row so index_select works on empty tensors + values_right_safe = torch.nn.functional.pad(values_right, (0, 0, 0, 1)) + right_idx = (offsets_right[batch_idx] + (local_pos - threshold)).clamp( + min=0, max=values_right.shape[0] + ) + right_values = values_right_safe.index_select(0, right_idx) + return torch.where(in_replace_zone.unsqueeze(-1), right_values, values_left) diff --git a/recommendation_v4/generative_recommenders/ops/pytorch/pt_jagged_tensors.py b/recommendation_v4/generative_recommenders/ops/pytorch/pt_jagged_tensors.py new file mode 100644 index 000000000..27817f7fb --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/pytorch/pt_jagged_tensors.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Optional, Tuple + +import torch +from generative_recommenders.common import fx_arange + +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") +except OSError: + pass + + +def _concat_2D_jagged_jagged( + values_left: torch.Tensor, + values_right: torch.Tensor, + max_len_left: int, + max_len_right: int, + offsets_left: torch.Tensor, + offsets_right: torch.Tensor, +) -> torch.Tensor: + max_seq_len = max_len_left + max_len_right + lengths_left = offsets_left[1:] - offsets_left[:-1] + lengths_right = offsets_right[1:] - offsets_right[:-1] + padded_left = torch.ops.fbgemm.jagged_to_padded_dense( + values=values_left, + offsets=[offsets_left], + max_lengths=[max_len_left], + padding_value=0.0, + ) + padded_right = torch.ops.fbgemm.jagged_to_padded_dense( + values=values_right, + offsets=[offsets_right], + max_lengths=[max_len_right], + padding_value=0.0, + ) + concatted_dense = torch.cat([padded_left, padded_right], dim=1) + mask = fx_arange(max_seq_len, device=offsets_left.device).view(1, -1) + mask = torch.logical_or( + mask < lengths_left.view(-1, 1), + torch.logical_and( + mask >= max_len_left, + mask < max_len_left + lengths_right.view(-1, 1), + ), + ) + return concatted_dense.flatten(0, 1)[mask.view(-1), :] + + +@torch.fx.wrap +def pytorch_concat_2D_jagged( + values_left: torch.Tensor, + values_right: torch.Tensor, + max_len_left: Optional[int], + max_len_right: Optional[int], + offsets_left: Optional[torch.Tensor], + offsets_right: Optional[torch.Tensor], +) -> torch.Tensor: + if offsets_left is None: + assert max_len_left is not None + B = values_left.shape[0] // max_len_left + offsets_left_non_optional = max_len_left * torch.arange( + B + 1, device=values_left.device + ) + else: + offsets_left_non_optional = offsets_left + if offsets_right is None: + assert max_len_right is not None + B = values_right.shape[0] // max_len_right + offsets_right_non_optional = max_len_right * torch.arange( + B + 1, device=values_left.device + ) + else: + offsets_right_non_optional = offsets_right + max_len_left = ( + int( + (offsets_left_non_optional[1:] - offsets_left_non_optional[:-1]) + .max() + .item() + ) + if max_len_left is None + else max_len_left + ) + max_len_right = ( + int( + (offsets_right_non_optional[1:] - offsets_right_non_optional[:-1]) + .max() + .item() + ) + if max_len_right is None + else max_len_right + ) + return _concat_2D_jagged_jagged( + values_left=values_left, + values_right=values_right, + max_len_left=max_len_left, + max_len_right=max_len_right, + offsets_left=offsets_left_non_optional, + offsets_right=offsets_right_non_optional, + ) + + +def _split_2D_jagged_jagged( + max_seq_len: int, + values: torch.Tensor, + offsets_left: torch.Tensor, + offsets_right: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + offsets = offsets_left + offsets_right + padded_values = torch.ops.fbgemm.jagged_to_padded_dense( + values=values, + offsets=[offsets], + max_lengths=[max_seq_len], + padding_value=0.0, + ).flatten(0, 1) + lengths_left = offsets_left[1:] - offsets_left[:-1] + lengths_right = offsets_right[1:] - offsets_right[:-1] + mask = fx_arange(max_seq_len, device=values.device).view(1, -1) + mask_left = mask < lengths_left.view(-1, 1) + mask_right = torch.logical_and( + mask >= lengths_left.view(-1, 1), + mask < (lengths_left + lengths_right).view(-1, 1), + ) + return padded_values[mask_left.view(-1), :], padded_values[mask_right.view(-1), :] + + +@torch.fx.wrap +def pytorch_split_2D_jagged( + max_seq_len: int, + values: torch.Tensor, + max_len_left: Optional[int], + max_len_right: Optional[int], + offsets_left: Optional[torch.Tensor], + offsets_right: Optional[torch.Tensor], +) -> Tuple[torch.Tensor, torch.Tensor]: + if offsets_left is None: + assert max_len_left is not None + assert offsets_right is not None + offsets_left_non_optional = max_len_left * torch.arange( + offsets_right.shape[0], device=values.device + ) + else: + offsets_left_non_optional = offsets_left + if offsets_right is None: + assert max_len_right is not None + assert offsets_left is not None + offsets_right_non_optional = max_len_right * torch.arange( + offsets_left.shape[0], device=values.device + ) + else: + offsets_right_non_optional = offsets_right + return _split_2D_jagged_jagged( + max_seq_len=max_seq_len, + values=values, + offsets_left=offsets_left_non_optional, + offsets_right=offsets_right_non_optional, + ) + + +def pytorch_hstu_split_l2_embeddings( + max_seq_len: int, + x: torch.Tensor, + prefix_offsets: torch.Tensor, + l2_offsets: torch.Tensor, + contextual_seq_len: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + x_offsets = prefix_offsets + l2_offsets + x_lengths = x_offsets[1:] - x_offsets[:-1] + padded_x = torch.ops.fbgemm.jagged_to_padded_dense( + values=x, + offsets=[x_offsets], + max_lengths=[max_seq_len], + padding_value=0.0, + ).flatten(0, 1) + prefix_lengths = prefix_offsets[1:] - prefix_offsets[:-1] + mask = fx_arange(max_seq_len, device=x_offsets.device).view(1, -1) + mask_prefix = torch.logical_and( + mask >= contextual_seq_len, + mask < prefix_lengths.view(-1, 1) + contextual_seq_len, + ) + mask_l2 = torch.logical_or( + mask < contextual_seq_len, + torch.logical_and( + mask >= prefix_lengths.view(-1, 1) + contextual_seq_len, + mask < x_lengths.view(-1, 1), + ), + ) + return padded_x[mask_prefix.view(-1), :], padded_x[mask_l2.view(-1), :] + + +def pytorch_hstu_concat_l2_embeddings( + max_prefix_len: int, + prefix_x: torch.Tensor, + prefix_offsets: torch.Tensor, + max_l2_len: int, + l2_x: torch.Tensor, + l2_offsets: torch.Tensor, + contextual_seq_len: int, +) -> torch.Tensor: + padded_prefix_x = torch.ops.fbgemm.jagged_to_padded_dense( + values=prefix_x, + offsets=[prefix_offsets], + max_lengths=[max_prefix_len], + padding_value=0.0, + ) + padded_l2_x = torch.ops.fbgemm.jagged_to_padded_dense( + values=l2_x, + offsets=[l2_offsets], + max_lengths=[max_l2_len], + padding_value=0.0, + ) + padded_x = torch.cat( + [ + padded_l2_x[:, 0:contextual_seq_len, :], + padded_prefix_x, + padded_l2_x[:, contextual_seq_len:, :], + ], + dim=1, + ) + mask = fx_arange(max_prefix_len + max_l2_len, device=prefix_x.device).view(1, -1) + prefix_lengths = prefix_offsets[1:] - prefix_offsets[:-1] + l2_lengths = l2_offsets[1:] - l2_offsets[:-1] + mask = torch.logical_or( + mask < prefix_lengths.view(-1, 1) + contextual_seq_len, + torch.logical_and( + mask >= max_prefix_len + contextual_seq_len, + mask < max_prefix_len + l2_lengths.view(-1, 1), + ), + ) + return padded_x.flatten(0, 1)[mask.view(-1), :] diff --git a/recommendation_v4/generative_recommenders/ops/pytorch/pt_layer_norm.py b/recommendation_v4/generative_recommenders/ops/pytorch/pt_layer_norm.py new file mode 100644 index 000000000..0666212ce --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/pytorch/pt_layer_norm.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# pyre-strict + + +from typing import List + +import torch + + +def pytorch_layer_norm( + x: torch.Tensor, + normalized_shape: List[int], + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, +) -> torch.Tensor: + dtype = x.dtype + return torch.nn.functional.layer_norm( + x.to(torch.float32), + normalized_shape, + weight.to(torch.float32), + bias.to(torch.float32), + eps, + ).to(dtype) + + +def pytorch_rms_norm( + x: torch.Tensor, + normalized_shape: List[int], + weight: torch.Tensor, + eps: float, + silu: bool = False, +) -> torch.Tensor: + dtype = x.dtype + x_float = x.to(torch.float32) + normalized = torch.nn.functional.rms_norm( + x_float, + normalized_shape, + weight.to(torch.float32), + eps, + ) + if silu: + normalized = torch.nn.functional.silu(normalized) + return normalized.to(dtype) + + +def pytorch_swish_layer_norm( + x: torch.Tensor, + normalized_shape: List[int], + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, +) -> torch.Tensor: + dtype = x.dtype + x = x.to(torch.float32) + return ( + x + * torch.sigmoid( + torch.nn.functional.layer_norm( + x, + normalized_shape, + weight.to(torch.float32), + bias.to(torch.float32), + eps, + ) + ) + ).to(dtype) diff --git a/recommendation_v4/generative_recommenders/ops/pytorch/pt_position.py b/recommendation_v4/generative_recommenders/ops/pytorch/pt_position.py new file mode 100644 index 000000000..fced57e4f --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/pytorch/pt_position.py @@ -0,0 +1,142 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Optional + +import torch +from generative_recommenders.common import ( + fx_unwrap_optional_tensor, + jagged_to_padded_dense, +) + +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") +except OSError: + pass + + +@torch.fx.wrap +def torch_arange(end: int, device: torch.device) -> torch.Tensor: + return torch.arange(end, device=device) + + +@torch.fx.wrap +def _get_col_indices( + max_seq_len: int, + max_contextual_seq_len: int, + max_pos_ind: int, + seq_lengths: torch.Tensor, + num_targets: Optional[torch.Tensor], + interleave_targets: bool, +) -> torch.Tensor: + B = seq_lengths.size(0) + col_indices = torch.arange(max_seq_len, device=seq_lengths.device).expand( + B, max_seq_len + ) + if num_targets is not None: + if interleave_targets: + high_inds = seq_lengths - fx_unwrap_optional_tensor(num_targets) * 2 + else: + high_inds = seq_lengths - fx_unwrap_optional_tensor(num_targets) + col_indices = torch.clamp(col_indices, max=high_inds.view(-1, 1)) + col_indices = high_inds.view(-1, 1) - col_indices + else: + col_indices = seq_lengths.view(-1, 1) - col_indices + col_indices = col_indices + max_contextual_seq_len + col_indices = torch.clamp(col_indices, max=max_pos_ind - 1) + if max_contextual_seq_len > 0: + col_indices[:, :max_contextual_seq_len] = torch.arange( + 0, + max_contextual_seq_len, + device=col_indices.device, + dtype=col_indices.dtype, + ).view(1, -1) + return col_indices + + +def pytorch_add_timestamp_positional_embeddings( + seq_embeddings: torch.Tensor, + seq_offsets: torch.Tensor, + pos_embeddings: torch.Tensor, + ts_embeddings: torch.Tensor, + timestamps: torch.Tensor, + max_seq_len: int, + max_contextual_seq_len: int, + seq_lengths: torch.Tensor, + num_targets: Optional[torch.Tensor], + interleave_targets: bool, + time_bucket_fn: str, +) -> torch.Tensor: + max_pos_ind = int(pos_embeddings.size(0)) + # position encoding + pos_inds = _get_col_indices( + max_seq_len=max_seq_len, + max_contextual_seq_len=max_contextual_seq_len, + max_pos_ind=max_pos_ind, + seq_lengths=seq_lengths, + num_targets=num_targets, + interleave_targets=interleave_targets, + ) + B, _ = pos_inds.shape + # timestamp encoding + num_time_buckets = ts_embeddings.size(1) - 1 + time_bucket_increments = 60.0 + time_bucket_divisor = 1.0 + time_delta = 0 + timestamps = jagged_to_padded_dense( + values=timestamps.unsqueeze(-1), + offsets=[seq_offsets], + max_lengths=[max_seq_len], + padding_value=0.0, + ).squeeze(-1) + query_time = torch.gather( + timestamps, dim=1, index=(seq_lengths - 1).unsqueeze(1).clamp(min=0) + ) + ts = query_time - timestamps + ts = ts + time_delta + ts = ts.clamp(min=1e-6) / time_bucket_increments + if time_bucket_fn == "log": + ts = torch.log(ts) + else: + ts = torch.sqrt(ts) + ts = (ts / time_bucket_divisor).clamp(min=0).int() + ts = torch.clamp( + ts, + min=0, + max=num_time_buckets, + ) + position_embeddings = torch.index_select( + pos_embeddings, 0, pos_inds.reshape(-1) + ).view(B, max_seq_len, -1) + time_embeddings = torch.index_select(ts_embeddings, 0, ts.reshape(-1)).view( + B, max_seq_len, -1 + ) + padded_emb = torch.ops.fbgemm.jagged_to_padded_dense( + values=seq_embeddings, + offsets=[seq_offsets], + max_lengths=[max_seq_len], + padding_value=0.0, + ) + summed = padded_emb + (time_embeddings + position_embeddings).to( + seq_embeddings.dtype + ) + result, _ = torch.ops.fbgemm.dense_to_jagged( + summed, [seq_offsets], seq_embeddings.shape[0] + ) + return result diff --git a/recommendation_v4/generative_recommenders/ops/tests/fake_signature_test.py b/recommendation_v4/generative_recommenders/ops/tests/fake_signature_test.py new file mode 100644 index 000000000..fb9047454 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/tests/fake_signature_test.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +""" +Tests to ensure fake and real implementations of triton functions +have the same function signatures. This is critical for PT2 compile compatibility. +""" + +import inspect +import unittest +from typing import Any, Callable, List + + +def get_custom_op_params(func: Callable[..., object]) -> List[str]: + """ + Get parameter names from a function, handling custom_op decorated functions. + + For maybe_register_custom_op decorated functions, inspect.signature may return + *args, **kwargs instead of the actual parameters. In this case, we need to + access the underlying schema to get the real parameter names. + """ + sig = inspect.signature(func) + params = list(sig.parameters.keys()) + + if params == ["args", "kwargs"]: + func_any: Any = func + if hasattr(func_any, "_opoverload"): + schema = func_any._opoverload._schema + return [arg.name for arg in schema.arguments] + + return params + + +class FakeSignatureTest(unittest.TestCase): + """Test to ensure fake and real implementations have the same function signatures.""" + + def test_triton_addmm_fwd_and_fake_have_same_signature(self) -> None: + """Verify triton_addmm_fwd and triton_addmm_fwd_fake have the same arguments.""" + from generative_recommenders.ops.triton.triton_addmm import ( + triton_addmm_fwd, + triton_addmm_fwd_fake, + ) + + real_params = get_custom_op_params(triton_addmm_fwd) + fake_params = get_custom_op_params(triton_addmm_fwd_fake) + + self.assertEqual( + real_params, + fake_params, + f"triton_addmm_fwd and triton_addmm_fwd_fake have different arguments.\n" + f"Real: {real_params}\n" + f"Fake: {fake_params}", + ) + + def test_maybe_triton_addmm_fwd_and_fake_have_same_signature(self) -> None: + """Verify maybe_triton_addmm_fwd and maybe_triton_addmm_fwd_fake have the same arguments.""" + from generative_recommenders.ops.triton.triton_addmm import ( + maybe_triton_addmm_fwd, + maybe_triton_addmm_fwd_fake, + ) + + real_params = get_custom_op_params(maybe_triton_addmm_fwd) + fake_params = get_custom_op_params(maybe_triton_addmm_fwd_fake) + + self.assertEqual( + real_params, + fake_params, + f"maybe_triton_addmm_fwd and maybe_triton_addmm_fwd_fake have different arguments.\n" + f"Real: {real_params}\n" + f"Fake: {fake_params}", + ) + + def test_triton_hstu_attention_fwd_and_fake_have_same_signature(self) -> None: + """Verify triton_hstu_attention_fwd and _triton_hstu_attention_fwd_fake have the same arguments.""" + from generative_recommenders.ops.triton.triton_hstu_attention import ( + _triton_hstu_attention_fwd_fake, + triton_hstu_attention_fwd, + ) + + real_params = get_custom_op_params(triton_hstu_attention_fwd) + fake_params = get_custom_op_params(_triton_hstu_attention_fwd_fake) + + self.assertEqual( + real_params, + fake_params, + f"triton_hstu_attention_fwd and _triton_hstu_attention_fwd_fake have different arguments.\n" + f"Real: {real_params}\n" + f"Fake: {fake_params}", + ) + + def test_triton_hstu_attention_bwd_and_fake_have_same_signature(self) -> None: + """Verify triton_hstu_attention_bwd and _triton_hstu_attention_bwd_fake have the same arguments.""" + from generative_recommenders.ops.triton.triton_hstu_attention import ( + _triton_hstu_attention_bwd_fake, + triton_hstu_attention_bwd, + ) + + real_params = get_custom_op_params(triton_hstu_attention_bwd) + fake_params = get_custom_op_params(_triton_hstu_attention_bwd_fake) + + self.assertEqual( + real_params, + fake_params, + f"triton_hstu_attention_bwd and _triton_hstu_attention_bwd_fake have different arguments.\n" + f"Real: {real_params}\n" + f"Fake: {fake_params}", + ) + + def test_triton_layer_norm_mul_dropout_fwd_impl_and_fake_have_same_signature( + self, + ) -> None: + """Verify _triton_layer_norm_mul_dropout_fwd_impl and its fake have the same arguments.""" + from generative_recommenders.ops.triton.triton_hstu_linear import ( + _triton_layer_norm_mul_dropout_fwd_impl, + _triton_layer_norm_mul_dropout_fwd_impl_fake, + ) + + real_params = get_custom_op_params(_triton_layer_norm_mul_dropout_fwd_impl) + fake_params = get_custom_op_params(_triton_layer_norm_mul_dropout_fwd_impl_fake) + + self.assertEqual( + real_params, + fake_params, + f"_triton_layer_norm_mul_dropout_fwd_impl and _triton_layer_norm_mul_dropout_fwd_impl_fake have different arguments.\n" + f"Real: {real_params}\n" + f"Fake: {fake_params}", + ) + + def test_triton_layer_norm_mul_dropout_bwd_impl_and_fake_have_same_signature( + self, + ) -> None: + """Verify _triton_layer_norm_mul_dropout_bwd_impl and its fake have the same arguments.""" + from generative_recommenders.ops.triton.triton_hstu_linear import ( + _triton_layer_norm_mul_dropout_bwd_impl, + _triton_layer_norm_mul_dropout_bwd_impl_fake, + ) + + real_params = get_custom_op_params(_triton_layer_norm_mul_dropout_bwd_impl) + fake_params = get_custom_op_params(_triton_layer_norm_mul_dropout_bwd_impl_fake) + + self.assertEqual( + real_params, + fake_params, + f"_triton_layer_norm_mul_dropout_bwd_impl and _triton_layer_norm_mul_dropout_bwd_impl_fake have different arguments.\n" + f"Real: {real_params}\n" + f"Fake: {fake_params}", + ) diff --git a/recommendation_v4/generative_recommenders/ops/tests/hstu_attention_test.py b/recommendation_v4/generative_recommenders/ops/tests/hstu_attention_test.py new file mode 100644 index 000000000..ef70b8cc2 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/tests/hstu_attention_test.py @@ -0,0 +1,485 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import random +import unittest +from typing import Optional + +import torch +from generative_recommenders.common import ( + generate_sparse_seq_len, + gpu_unavailable, + HammerKernel, + set_dev_mode, +) +from generative_recommenders.ops.jagged_tensors import split_2D_jagged +from hypothesis import given, settings, strategies as st, Verbosity + + +def test_attn( + batch_size: int, + heads: int, + max_uih_len: int, + max_targets: int, + attn_dim: int, + hidden_dim: int, + causal: bool, + has_multiple_targets: bool, + has_max_attn_len: bool, + dtype: torch.dtype, + test_backward: bool, + ref_kernel: HammerKernel, + real_kernel: HammerKernel, + skip_comparisons: bool = False, + sparsity: float = -1.0, + contextual_seq_len: int = 0, + atol: Optional[float] = None, + rtol: Optional[float] = None, + enable_tma: bool = False, +) -> None: + set_dev_mode(True) + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + from generative_recommenders.ops.hstu_attention import hstu_mha + + alpha = 1.0 / (attn_dim**0.5) + if sparsity > 0.0: + lengths = generate_sparse_seq_len( + size=batch_size, + max_seq_len=max_uih_len, + sparsity=sparsity, + device=torch.device("cuda"), + ) + else: + lengths = torch.randint( + max_uih_len + 1, size=(batch_size,), device=torch.device("cuda") + ) + num_targets = torch.randint( + 1, max_targets + 1, size=(batch_size,), device=torch.device("cuda") + ) + lengths = lengths + num_targets + contextual_seq_len + max_seq_len = max_uih_len + max_targets + contextual_seq_len + if has_max_attn_len: + max_attn_len = random.randint(1, max_uih_len // 5) + else: + max_attn_len = 0 + seq_offsets = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + seq_offsets[1:] = torch.cumsum(lengths, dim=0) + + L = int(seq_offsets[-1].item()) + q = ( + torch.empty((L, heads, attn_dim), dtype=dtype, device=torch.device("cuda")) + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + k = ( + torch.empty((L, heads, attn_dim), dtype=dtype, device=torch.device("cuda")) + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + v = ( + torch.empty((L, heads, hidden_dim), dtype=dtype, device=torch.device("cuda")) + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + + # ref implementation + ref_out = hstu_mha( + max_seq_len=max_seq_len, + alpha=alpha, + q=q, + k=k, + v=v, + seq_offsets=seq_offsets, + causal=causal, + num_targets=num_targets if has_multiple_targets else None, + dropout_pr=0.0, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + kernel=ref_kernel, + enable_tma=enable_tma, + ) + dout = torch.randn_like(ref_out) + ref_out.backward(dout) + + if skip_comparisons: + return + + # pyre-ignore + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + + # triton implementation + q = q.detach().clone().requires_grad_() + k = k.detach().clone().requires_grad_() + v = v.detach().clone().requires_grad_() + dout = dout.detach().clone() + real_out = hstu_mha( + max_seq_len=max_seq_len, + alpha=alpha, + q=q, + k=k, + v=v, + seq_offsets=seq_offsets, + causal=causal, + num_targets=num_targets if has_multiple_targets else None, + dropout_pr=0.0, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + kernel=real_kernel, + enable_tma=enable_tma, + ) + + torch.testing.assert_close( + ref_out, + real_out, + atol=atol, + rtol=rtol, + ) + if test_backward: + real_out.backward(dout) + real_dq, real_dk, real_dv = q.grad.clone(), k.grad.clone(), v.grad.clone() + torch.testing.assert_close(ref_dv, real_dv, atol=atol, rtol=rtol) + torch.testing.assert_close(ref_dk, real_dk, atol=atol, rtol=rtol) + torch.testing.assert_close(ref_dq, real_dq, atol=atol, rtol=rtol) + + +def test_delta_attn( + batch_size: int, + heads: int, + max_uih_len: int, + max_targets: int, + delta_size: int, + attn_dim: int, + hidden_dim: int, + has_multiple_targets: bool, + has_max_attn_len: bool, + dtype: torch.dtype, + ref_kernel: HammerKernel, + real_kernel: HammerKernel, + contextual_seq_len: int = 0, + atol: Optional[float] = None, + rtol: Optional[float] = None, + enable_tma: bool = False, +) -> None: + set_dev_mode(True) + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + from generative_recommenders.ops.hstu_attention import delta_hstu_mha + + alpha = 1.0 / (attn_dim**0.5) + lengths = torch.randint( + max_uih_len + 1, size=(batch_size,), device=torch.device("cuda") + ) + num_targets = torch.randint( + 1, delta_size + 1, size=(batch_size,), device=torch.device("cuda") + ) + lengths = lengths + delta_size + contextual_seq_len + max_seq_len = max_uih_len + delta_size + contextual_seq_len + if has_max_attn_len: + max_attn_len = random.randint(1, max_uih_len // 5) + else: + max_attn_len = 0 + seq_offsets = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + seq_offsets[1:] = torch.cumsum(lengths, dim=0) + + L = int(seq_offsets[-1].item()) + delta_q = torch.empty( + (batch_size * delta_size, heads, attn_dim), + dtype=dtype, + device=torch.device("cuda"), + ).uniform_(-0.1, 0.1) + k = torch.empty( + (L, heads, attn_dim), dtype=dtype, device=torch.device("cuda") + ).uniform_(-0.1, 0.1) + v = torch.empty( + (L, heads, hidden_dim), dtype=dtype, device=torch.device("cuda") + ).uniform_(-0.1, 0.1) + + # ref implementation + ref_out = delta_hstu_mha( + max_seq_len=max_seq_len, + alpha=alpha, + delta_q=delta_q, + k=k, + v=v, + seq_offsets=seq_offsets, + num_targets=num_targets if has_multiple_targets else None, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + kernel=ref_kernel, + enable_tma=enable_tma, + ) + + # real implementation + real_out = delta_hstu_mha( + max_seq_len=max_seq_len, + alpha=alpha, + delta_q=delta_q, + k=k, + v=v, + seq_offsets=seq_offsets, + num_targets=num_targets if has_multiple_targets else None, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + kernel=real_kernel, + enable_tma=enable_tma, + ) + torch.testing.assert_close( + ref_out, + real_out, + atol=atol, + rtol=rtol, + ) + + +class HSTUAttentionTest(unittest.TestCase): + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore + @given( + batch_size=st.integers(4, 8), + heads=st.integers(1, 4), + max_uih_len=st.sampled_from([20, 100, 128, 256]), + max_targets=st.sampled_from([20, 512]), + attn_dim=st.sampled_from([16, 32, 64, 128]), + hidden_dim=st.sampled_from([16, 32, 64, 128]), + causal=st.sampled_from([True]), + has_multiple_targets=st.sampled_from([True, False]), + dtype=st.sampled_from( + [torch.bfloat16, torch.float32] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + has_max_attn_len=st.sampled_from([True, False]), + contextual_seq_len=st.sampled_from([0, 10]), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=200, + deadline=None, + ) + # pyre-ignore[2] + def test_attn_triton(self, *args, **kwargs) -> None: + test_attn( + *args, + **kwargs, + test_backward=True, + ref_kernel=HammerKernel.PYTORCH, + real_kernel=HammerKernel.TRITON, + ) + + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore + @given( + batch_size=st.just(64), + heads=st.just(4), + max_uih_len=st.sampled_from([32768]), + max_targets=st.sampled_from([32]), + attn_dim=st.just(128), + hidden_dim=st.just(128), + causal=st.sampled_from([True]), + has_multiple_targets=st.sampled_from([True, False]), + dtype=st.sampled_from([torch.bfloat16]), + has_max_attn_len=st.sampled_from([True, False]), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=5, + deadline=None, + ) + # pyre-ignore[2] + def test_attn_triton_long_seqs(self, *args, **kwargs) -> None: + test_attn( + *args, + **kwargs, + test_backward=True, + ref_kernel=HammerKernel.TRITON, + real_kernel=HammerKernel.TRITON, + skip_comparisons=True, + sparsity=1.0, + ) + + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore + @given( + batch_size=st.integers(4, 8), + heads=st.integers(1, 4), + max_uih_len=st.sampled_from([100, 128, 256]), + max_targets=st.sampled_from([20, 512]), + delta_size=st.sampled_from([20, 512]), + attn_dim=st.sampled_from([16, 32, 64, 128]), + hidden_dim=st.sampled_from([16, 32, 64, 128]), + has_multiple_targets=st.sampled_from([True, False]), + dtype=st.sampled_from( + [torch.bfloat16, torch.float32] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + has_max_attn_len=st.sampled_from([False, True]), + contextual_seq_len=st.sampled_from([0, 10]), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=200, + deadline=None, + ) + # pyre-ignore[2] + def test_delta_attn_triton(self, *args, **kwargs) -> None: + test_delta_attn( + *args, + **kwargs, + ref_kernel=HammerKernel.PYTORCH, + real_kernel=HammerKernel.TRITON, + ) + + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore + @given( + batch_size=st.integers(4, 8), + heads=st.integers(1, 4), + max_uih_len=st.sampled_from([20, 100, 128]), + max_targets=st.sampled_from([20, 512]), + delta_size=st.sampled_from([20, 512]), + attn_dim=st.sampled_from([16, 32, 64]), + hidden_dim=st.sampled_from([16, 32, 64]), + has_multiple_targets=st.sampled_from([True, False]), + dtype=st.sampled_from( + [torch.bfloat16, torch.float32] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + has_max_attn_len=st.sampled_from([False, True]), + contextual_seq_len=st.sampled_from([0, 10]), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=200, + deadline=None, + ) + def test_cache( + self, + batch_size: int, + heads: int, + max_uih_len: int, + max_targets: int, + delta_size: int, + attn_dim: int, + hidden_dim: int, + has_multiple_targets: bool, + dtype: torch.dtype, + has_max_attn_len: bool, + contextual_seq_len: int, + ) -> None: + set_dev_mode(True) + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + from generative_recommenders.ops.hstu_attention import delta_hstu_mha, hstu_mha + + alpha = 1.0 / (attn_dim**0.5) + lengths = torch.randint( + max_uih_len + 1, size=(batch_size,), device=torch.device("cuda") + ) + num_targets = torch.randint( + 1, delta_size + 1, size=(batch_size,), device=torch.device("cuda") + ) + lengths = lengths + delta_size + contextual_seq_len + max_seq_len = max_uih_len + delta_size + contextual_seq_len + if has_max_attn_len: + max_attn_len = random.randint(1, max_uih_len // 5) + else: + max_attn_len = 0 + seq_offsets = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + seq_offsets[1:] = torch.cumsum(lengths, dim=0) + + L = int(seq_offsets[-1].item()) + q = torch.empty( + (L, heads, attn_dim), + dtype=dtype, + device=torch.device("cuda"), + ).uniform_(-0.1, 0.1) + _, delta_q = split_2D_jagged( + max_seq_len=max_seq_len, + values=q.view(-1, heads * attn_dim), + max_len_left=None, + max_len_right=delta_size, + offsets_left=torch.ops.fbgemm.asynchronous_complete_cumsum( + lengths - delta_size + ), + offsets_right=None, + kernel=HammerKernel.TRITON, + ) + delta_q = delta_q.view(-1, heads, attn_dim) + k = torch.empty( + (L, heads, attn_dim), dtype=dtype, device=torch.device("cuda") + ).uniform_(-0.1, 0.1) + v = torch.empty( + (L, heads, hidden_dim), dtype=dtype, device=torch.device("cuda") + ).uniform_(-0.1, 0.1) + prime_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + lengths - delta_size + ) + + # ref implementation + ref_out = hstu_mha( + max_seq_len=max_seq_len, + alpha=alpha, + q=q, + k=k, + v=v, + seq_offsets=seq_offsets, + causal=True, + num_targets=num_targets if has_multiple_targets else None, + dropout_pr=0.0, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + kernel=HammerKernel.TRITON, + ) + _, delta_out = split_2D_jagged( + max_seq_len=max_seq_len, + values=ref_out.view(-1, heads * hidden_dim), + max_len_left=None, + max_len_right=delta_size, + offsets_left=prime_offsets, + offsets_right=None, + kernel=HammerKernel.TRITON, + ) + delta_out = delta_out.view(-1, heads, hidden_dim) + + # real implementation + real_delta_out = delta_hstu_mha( + max_seq_len=max_seq_len, + alpha=alpha, + delta_q=delta_q, + k=k, + v=v, + seq_offsets=seq_offsets, + num_targets=num_targets if has_multiple_targets else None, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + ) + torch.testing.assert_close( + delta_out, + real_delta_out, + ) diff --git a/recommendation_v4/generative_recommenders/ops/tests/hstu_attention_tma_test.py b/recommendation_v4/generative_recommenders/ops/tests/hstu_attention_tma_test.py new file mode 100644 index 000000000..8bb264af6 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/tests/hstu_attention_tma_test.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import random +import unittest + +import torch +from generative_recommenders.common import ( + HammerKernel, + nv_gpu_unavailable, + set_dev_mode, +) +from generative_recommenders.ops.jagged_tensors import split_2D_jagged +from generative_recommenders.ops.tests.hstu_attention_test import ( + test_attn, + test_delta_attn, +) +from hypothesis import given, settings, strategies as st, Verbosity + + +class HSTUAttentionTmaTest(unittest.TestCase): + @unittest.skipIf(*nv_gpu_unavailable) + # pyre-ignore + @given( + batch_size=st.integers(4, 8), + heads=st.integers(1, 4), + max_uih_len=st.sampled_from([20, 100, 128, 256]), + max_targets=st.sampled_from([20, 512]), + attn_dim=st.sampled_from([16, 32, 64, 128]), + hidden_dim=st.sampled_from([16, 32, 64, 128]), + causal=st.sampled_from([True]), + has_multiple_targets=st.sampled_from([True, False]), + dtype=st.sampled_from( + [torch.bfloat16, torch.float32] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + has_max_attn_len=st.sampled_from([True, False]), + contextual_seq_len=st.sampled_from([0, 10]), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=200, + deadline=None, + ) + # pyre-ignore[2] + def test_attn_triton_tma(self, *args, **kwargs) -> None: + test_attn( + *args, + **kwargs, + test_backward=True, + ref_kernel=HammerKernel.PYTORCH, + real_kernel=HammerKernel.TRITON, + enable_tma=True, + ) + + @unittest.skipIf(*nv_gpu_unavailable) + # pyre-ignore + @given( + batch_size=st.just(64), + heads=st.just(4), + max_uih_len=st.sampled_from([32768]), + max_targets=st.sampled_from([32]), + attn_dim=st.just(128), + hidden_dim=st.just(128), + causal=st.sampled_from([True]), + has_multiple_targets=st.sampled_from([True, False]), + dtype=st.sampled_from([torch.bfloat16]), + has_max_attn_len=st.sampled_from([True, False]), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=5, + deadline=None, + ) + # pyre-ignore[2] + def test_attn_triton_long_seqs_tma(self, *args, **kwargs) -> None: + test_attn( + *args, + **kwargs, + test_backward=True, + ref_kernel=HammerKernel.TRITON, + real_kernel=HammerKernel.TRITON, + skip_comparisons=True, + sparsity=1.0, + enable_tma=True, + ) + + @unittest.skipIf(*nv_gpu_unavailable) + # pyre-ignore + @given( + batch_size=st.integers(4, 8), + heads=st.integers(1, 4), + max_uih_len=st.sampled_from([100, 128, 256]), + max_targets=st.sampled_from([20, 512]), + delta_size=st.sampled_from([20, 512]), + attn_dim=st.sampled_from([16, 32, 64, 128]), + hidden_dim=st.sampled_from([16, 32, 64, 128]), + has_multiple_targets=st.sampled_from([True, False]), + dtype=st.sampled_from( + [torch.bfloat16, torch.float32] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + has_max_attn_len=st.sampled_from([False, True]), + contextual_seq_len=st.sampled_from([0, 10]), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=200, + deadline=None, + ) + # pyre-ignore[2] + def test_delta_attn_triton_tma(self, *args, **kwargs) -> None: + test_delta_attn( + *args, + **kwargs, + ref_kernel=HammerKernel.PYTORCH, + real_kernel=HammerKernel.TRITON, + enable_tma=True, + ) + + @unittest.skipIf(*nv_gpu_unavailable) + # pyre-ignore + @given( + batch_size=st.integers(4, 8), + heads=st.integers(1, 4), + max_uih_len=st.sampled_from([20, 100, 128]), + max_targets=st.sampled_from([20, 512]), + delta_size=st.sampled_from([20, 512]), + attn_dim=st.sampled_from([16, 32, 64]), + hidden_dim=st.sampled_from([16, 32, 64]), + has_multiple_targets=st.sampled_from([True, False]), + dtype=st.sampled_from( + [torch.bfloat16, torch.float32] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + has_max_attn_len=st.sampled_from([False, True]), + contextual_seq_len=st.sampled_from([0, 10]), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=200, + deadline=None, + ) + def test_cache_tma( + self, + batch_size: int, + heads: int, + max_uih_len: int, + max_targets: int, + delta_size: int, + attn_dim: int, + hidden_dim: int, + has_multiple_targets: bool, + dtype: torch.dtype, + has_max_attn_len: bool, + contextual_seq_len: int, + ) -> None: + set_dev_mode(True) + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + from generative_recommenders.ops.hstu_attention import delta_hstu_mha, hstu_mha + + alpha = 1.0 / (attn_dim**0.5) + lengths = torch.randint( + max_uih_len + 1, size=(batch_size,), device=torch.device("cuda") + ) + num_targets = torch.randint( + 1, delta_size + 1, size=(batch_size,), device=torch.device("cuda") + ) + lengths = lengths + delta_size + contextual_seq_len + max_seq_len = max_uih_len + delta_size + contextual_seq_len + if has_max_attn_len: + max_attn_len = random.randint(1, max_uih_len // 5) + else: + max_attn_len = 0 + seq_offsets = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + seq_offsets[1:] = torch.cumsum(lengths, dim=0) + + L = int(seq_offsets[-1].item()) + q = torch.empty( + (L, heads, attn_dim), + dtype=dtype, + device=torch.device("cuda"), + ).uniform_(-0.1, 0.1) + _, delta_q = split_2D_jagged( + max_seq_len=max_seq_len, + values=q.view(-1, heads * attn_dim), + max_len_left=None, + max_len_right=delta_size, + offsets_left=torch.ops.fbgemm.asynchronous_complete_cumsum( + lengths - delta_size + ), + offsets_right=None, + kernel=HammerKernel.TRITON, + ) + delta_q = delta_q.view(-1, heads, attn_dim) + k = torch.empty( + (L, heads, attn_dim), dtype=dtype, device=torch.device("cuda") + ).uniform_(-0.1, 0.1) + v = torch.empty( + (L, heads, hidden_dim), dtype=dtype, device=torch.device("cuda") + ).uniform_(-0.1, 0.1) + prime_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + lengths - delta_size + ) + + # ref implementation + ref_out = hstu_mha( + max_seq_len=max_seq_len, + alpha=alpha, + q=q, + k=k, + v=v, + seq_offsets=seq_offsets, + causal=True, + num_targets=num_targets if has_multiple_targets else None, + dropout_pr=0.0, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + kernel=HammerKernel.TRITON, + enable_tma=True, + ) + _, delta_out = split_2D_jagged( + max_seq_len=max_seq_len, + values=ref_out.view(-1, heads * hidden_dim), + max_len_left=None, + max_len_right=delta_size, + offsets_left=prime_offsets, + offsets_right=None, + kernel=HammerKernel.TRITON, + ) + delta_out = delta_out.view(-1, heads, hidden_dim) + + # real implementation + real_delta_out = delta_hstu_mha( + max_seq_len=max_seq_len, + alpha=alpha, + delta_q=delta_q, + k=k, + v=v, + seq_offsets=seq_offsets, + num_targets=num_targets if has_multiple_targets else None, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + enable_tma=True, + ) + torch.testing.assert_close( + delta_out, + real_delta_out, + ) diff --git a/recommendation_v4/generative_recommenders/ops/tests/hstu_compute_test.py b/recommendation_v4/generative_recommenders/ops/tests/hstu_compute_test.py new file mode 100644 index 000000000..57f217895 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/tests/hstu_compute_test.py @@ -0,0 +1,503 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import random +import unittest +from typing import Optional + +import torch +from generative_recommenders.common import ( + generate_sparse_seq_len, + gpu_unavailable, + HammerKernel, + set_dev_mode, +) +from hypothesis import given, settings, strategies as st, Verbosity + + +class HSTUComputeTest(unittest.TestCase): + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore[56] + @given( + N=st.integers(min_value=1000, max_value=1000), + D=st.integers(min_value=128, max_value=128), + L=st.integers(min_value=512, max_value=512), + concat_u=st.booleans(), + concat_x=st.booleans(), + mul_u_activation_type=st.sampled_from(["silu", "sigmoid", "none"]), + group_norm=st.booleans(), + num_heads=st.sampled_from([4]), + training=st.just(False), + recompute_y_in_backward=st.sampled_from([True, False]), + dtype=st.sampled_from( + [torch.float32, torch.bfloat16] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + ) + @settings( + deadline=None, + verbosity=Verbosity.verbose, + max_examples=20, + ) + # pyre-ignore[2] + def test_compute_output(self, *args, **kwargs) -> None: + self._test_compute_output( + *args, + **kwargs, + test_backward=True, + ref_kernel=HammerKernel.PYTORCH, + opt_kernel=HammerKernel.TRITON, + ) + + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore[56] + @given( + N=st.just(1500000), + D=st.just(512), + L=st.just(512), + concat_u=st.sampled_from([True]), + concat_x=st.sampled_from([True]), + mul_u_activation_type=st.sampled_from(["none"]), + group_norm=st.sampled_from([False]), + num_heads=st.sampled_from([4]), + training=st.just(False), + recompute_y_in_backward=st.sampled_from([False]), + dtype=st.just(torch.bfloat16), + ) + @settings( + deadline=None, + verbosity=Verbosity.verbose, + max_examples=1, + ) + # pyre-ignore[2] + def test_long_sequences_compute_output(self, *args, **kwargs) -> None: + self._test_compute_output( + *args, + **kwargs, + test_backward=False, + ref_kernel=HammerKernel.TRITON, + opt_kernel=HammerKernel.TRITON, + skip_comparisons=True, + ) + + def _test_compute_output( + self, + N: int, + D: int, + L: int, + concat_u: bool, + concat_x: bool, + mul_u_activation_type: str, + group_norm: bool, + num_heads: int, + training: bool, + recompute_y_in_backward: bool, + dtype: torch.dtype, + test_backward: bool, + ref_kernel: HammerKernel, + opt_kernel: HammerKernel, + skip_comparisons: bool = False, + atol: Optional[float] = None, + rtol: Optional[float] = None, + ) -> None: + from generative_recommenders.ops.hstu_compute import hstu_compute_output + + torch.manual_seed(0) + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + dropout_ratio = 0.3 if training else 0.0 + attn = ( + torch.empty((N, L), dtype=dtype, device=torch.device("cuda")) + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + u = ( + torch.empty((N, L), dtype=dtype, device=torch.device("cuda")) + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + norm_weight = ( + torch.empty( + (L if not group_norm else num_heads,), + dtype=dtype, + device=torch.device("cuda"), + ) + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + norm_bias = ( + torch.empty( + (L if not group_norm else num_heads,), + dtype=dtype, + device=torch.device("cuda"), + ) + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + norm_eps = 1e-6 + # When group_norm=True, only concat_ux = concat_u and concat_x is supported + if group_norm: + L_mult = 3 if (concat_u and concat_x) else 1 + else: + L_mult = 1 + if concat_u: + L_mult += 1 + if concat_x: + L_mult += 1 + output_weight = ( + torch.empty((L * L_mult, D), dtype=dtype, device=torch.device("cuda")) + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + x = ( + torch.empty((N, D), dtype=dtype, device=torch.device("cuda")) + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + + # ref + ref_out = hstu_compute_output( + attn=attn, + u=u, + x=x, + norm_weight=norm_weight, + norm_bias=norm_bias, + norm_eps=norm_eps, + dropout_ratio=dropout_ratio, + output_weight=output_weight, + concat_u=concat_u, + concat_x=concat_x, + mul_u_activation_type=mul_u_activation_type, + group_norm=group_norm, + num_heads=num_heads, + linear_dim=L // num_heads, + training=training, + recompute_y_in_backward=recompute_y_in_backward, + kernel=ref_kernel, + ) + dout = torch.randn_like(ref_out) * 0.1 + ref_out.backward(dout) + if skip_comparisons: + return + # pyre-ignore[16] + ref_dattn, attn.grad = attn.grad.detach().clone(), None + ref_du, u.grad = u.grad.detach().clone(), None + ref_d_norm_w, norm_weight.grad = norm_weight.grad.detach().clone(), None + ref_d_norm_b, norm_bias.grad = norm_bias.grad.detach().clone(), None + ref_dx, x.grad = x.grad.detach().clone(), None + ref_d_output_w, output_weight.grad = output_weight.grad.detach().clone(), None + + # opt + attn = attn.detach().clone().requires_grad_() + u = u.detach().clone().requires_grad_() + norm_weight = norm_weight.detach().clone().requires_grad_() + norm_bias = norm_bias.detach().clone().requires_grad_() + output_weight = output_weight.detach().clone().requires_grad_() + x = x.detach().clone().requires_grad_() + opt_out = hstu_compute_output( + attn=attn, + u=u, + x=x, + norm_weight=norm_weight, + norm_bias=norm_bias, + norm_eps=norm_eps, + dropout_ratio=dropout_ratio, + output_weight=output_weight, + concat_u=concat_u, + concat_x=concat_x, + mul_u_activation_type=mul_u_activation_type, + group_norm=group_norm, + num_heads=num_heads, + linear_dim=L // num_heads, + training=training, + recompute_y_in_backward=recompute_y_in_backward, + kernel=opt_kernel, + ) + torch.testing.assert_close( + ref_out, + opt_out, + atol=atol, + rtol=rtol, + ) + + if test_backward: + dout = dout.detach().clone() + opt_out.backward(dout) + opt_dattn, attn.grad = attn.grad.detach().clone(), None + opt_du, u.grad = u.grad.detach().clone(), None + opt_d_norm_w, norm_weight.grad = norm_weight.grad.detach().clone(), None + opt_d_norm_b, norm_bias.grad = norm_bias.grad.detach().clone(), None + opt_dx, x.grad = x.grad.detach().clone(), None + opt_d_output_w, output_weight.grad = ( + output_weight.grad.detach().clone(), + None, + ) + torch.testing.assert_close(ref_du, opt_du) + torch.testing.assert_close(ref_dattn, opt_dattn) + torch.testing.assert_close(ref_d_norm_w, opt_d_norm_w) + torch.testing.assert_close(ref_d_norm_b, opt_d_norm_b) + torch.testing.assert_close(ref_dx, opt_dx) + torch.testing.assert_close(ref_d_output_w, opt_d_output_w) + + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore + @given( + batch_size=st.integers(4, 8), + heads=st.integers(1, 4), + max_uih_len=st.sampled_from([100, 128, 256, 1300]), + max_targets=st.sampled_from([20, 512]), + embedding_dim=st.sampled_from([16, 32, 64]), + attn_dim=st.sampled_from([16, 32, 64, 128]), + hidden_dim=st.sampled_from([16, 32, 64, 128]), + causal=st.sampled_from([True]), + has_multiple_targets=st.sampled_from([True, False]), + dtype=st.sampled_from( + [torch.float32] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + contextual_seq_len=st.sampled_from([0]), + has_max_attn_len=st.sampled_from([False, True]), + sort_by_length=st.sampled_from([True, False]), + recompute_uvqk_in_backward=st.sampled_from([True, False]), + recompute_normed_x_in_backward=st.sampled_from([True, False]), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=150, + deadline=None, + ) + # pyre-ignore[2] + def test_preprocess_and_attention(self, *args, **kwargs) -> None: + self._test_hstu_preprocess_and_attention( + *args, + **kwargs, + test_backward=True, + ref_kernel=HammerKernel.PYTORCH, + real_kernel=HammerKernel.TRITON, + ) + + def _test_hstu_preprocess_and_attention( + self, + batch_size: int, + heads: int, + max_uih_len: int, + max_targets: int, + embedding_dim: int, + attn_dim: int, + hidden_dim: int, + causal: bool, + has_multiple_targets: bool, + has_max_attn_len: bool, + dtype: torch.dtype, + ref_kernel: HammerKernel, + real_kernel: HammerKernel, + test_backward: bool, + contextual_seq_len: int, + sort_by_length: bool, + recompute_uvqk_in_backward: bool, + recompute_normed_x_in_backward: bool, + sparsity: float = -1.0, + atol: Optional[float] = None, + rtol: Optional[float] = None, + ) -> None: + set_dev_mode(True) + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + + from generative_recommenders.ops.hstu_compute import ( + hstu_preprocess_and_attention, + ) + + alpha = 1.0 / (attn_dim**0.5) + if sparsity > 0.0: + lengths = generate_sparse_seq_len( + size=batch_size, + max_seq_len=max_uih_len, + sparsity=sparsity, + device=torch.device("cuda"), + ) + else: + lengths = torch.randint( + max_uih_len + 1, size=(batch_size,), device=torch.device("cuda") + ) + + num_targets = torch.randint( + max_targets + 1, size=(batch_size,), device=torch.device("cuda") + ) + lengths = lengths + num_targets + max_seq_len = max_uih_len + max_targets + if has_max_attn_len: + max_attn_len = random.randint(1, max_uih_len // 5) + else: + max_attn_len = 0 + seq_offsets = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + seq_offsets[1:] = torch.cumsum(lengths, dim=0) + + L = int(seq_offsets[-1].item()) + + x = ( + torch.empty((L, embedding_dim), dtype=dtype, device=torch.device("cuda")) + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + norm_weight = ( + torch.empty((embedding_dim,), dtype=dtype, device=torch.device("cuda")) + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + norm_bias = ( + torch.empty( + (embedding_dim,), + dtype=dtype, + device=torch.device("cuda"), + ) + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + norm_eps = 1e-6 + uvqk_weight = ( + torch.empty( + ( + embedding_dim, + (hidden_dim * 2 + attn_dim * 2) * heads, + ), + dtype=dtype, + device=torch.device("cuda"), + ) + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + uvqk_bias = ( + torch.empty( + (hidden_dim * 2 + attn_dim * 2) * heads, + dtype=dtype, + device=torch.device("cuda"), + ) + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + + # ref implementation + ref_u, ref_attn_output, _, _ = hstu_preprocess_and_attention( + x=x, + norm_weight=norm_weight, + norm_bias=norm_bias, + norm_eps=norm_eps, + num_heads=heads, + attn_dim=attn_dim, + hidden_dim=hidden_dim, + uvqk_weight=uvqk_weight, + uvqk_bias=uvqk_bias, + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + attn_alpha=alpha, + causal=causal, + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + recompute_uvqk_in_backward=recompute_uvqk_in_backward, + recompute_normed_x_in_backward=recompute_normed_x_in_backward, + sort_by_length=sort_by_length, + kernel=ref_kernel, + ) + ref_out = ref_u + ref_attn_output + dout = torch.randn_like(ref_out) * 0.01 + ref_out.backward(dout) + + # pyre-ignore + ref_dx, x.grad = x.grad.clone(), None + ref_d_norm_weight, norm_weight.grad = norm_weight.grad.clone(), None + ref_d_norm_bias, norm_bias.grad = norm_bias.grad.clone(), None + ref_d_uvqk_weight, uvqk_weight.grad = uvqk_weight.grad.clone(), None + ref_d_uvqk_bias, uvqk_bias.grad = uvqk_bias.grad.clone(), None + + # real implementation + x = x.detach().clone().requires_grad_() + norm_weight = norm_weight.detach().clone().requires_grad_() + norm_bias = norm_bias.detach().clone().requires_grad_() + uvqk_weight = uvqk_weight.detach().clone().requires_grad_() + uvqk_bias = uvqk_bias.detach().clone().requires_grad_() + real_u, real_attn_output, _, _ = hstu_preprocess_and_attention( + x=x, + norm_weight=norm_weight, + norm_bias=norm_bias, + norm_eps=norm_eps, + num_heads=heads, + attn_dim=attn_dim, + hidden_dim=hidden_dim, + uvqk_weight=uvqk_weight, + uvqk_bias=uvqk_bias, + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + attn_alpha=alpha, + causal=causal, + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + recompute_uvqk_in_backward=recompute_uvqk_in_backward, + recompute_normed_x_in_backward=recompute_normed_x_in_backward, + sort_by_length=sort_by_length, + kernel=real_kernel, + ) + real_out = real_u + real_attn_output + torch.testing.assert_close( + ref_u, + real_u, + atol=atol, + rtol=rtol, + ) + torch.testing.assert_close( + ref_attn_output, + real_attn_output, + atol=atol, + rtol=rtol, + ) + if test_backward: + # real implementation + dout = dout.detach().clone() + real_out.backward(dout) + ( + real_dx, + real_d_norm_weight, + real_d_norm_bias, + real_d_uvqk_weight, + real_d_uvqk_bias, + ) = ( + x.grad.clone(), + norm_weight.grad.clone(), + norm_bias.grad.clone(), + uvqk_weight.grad.clone(), + uvqk_bias.grad.clone(), + ) + torch.testing.assert_close(ref_dx, real_dx, atol=atol, rtol=rtol) + torch.testing.assert_close( + ref_d_norm_weight, real_d_norm_weight, atol=atol, rtol=rtol + ) + torch.testing.assert_close( + ref_d_norm_bias, real_d_norm_bias, atol=atol, rtol=rtol + ) + torch.testing.assert_close( + ref_d_uvqk_weight, real_d_uvqk_weight, atol=atol, rtol=rtol + ) + torch.testing.assert_close( + ref_d_uvqk_bias, real_d_uvqk_bias, atol=atol, rtol=rtol + ) diff --git a/recommendation_v4/generative_recommenders/ops/tests/jagged_tensors_test.py b/recommendation_v4/generative_recommenders/ops/tests/jagged_tensors_test.py new file mode 100644 index 000000000..e03d68d0b --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/tests/jagged_tensors_test.py @@ -0,0 +1,963 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import unittest +from typing import Optional + +import torch +from generative_recommenders.common import ( + generate_sparse_seq_len, + gpu_unavailable, + HammerKernel, + set_dev_mode, +) +from generative_recommenders.ops.jagged_tensors import ( + concat_2D_jagged, + concat_2D_jagged_multirow, + split_2D_jagged, + split_2D_jagged_multirow, +) +from hypothesis import given, settings, strategies as st, Verbosity + + +class JaggedTensorsTest(unittest.TestCase): + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore + @given( + batch_size=st.integers(2, 8), + max_len_a=st.integers(20, 100), + max_len_b=st.integers(20, 100), + D=st.integers(10, 30), + is_dense_a=st.sampled_from([True, False]), + is_dense_b=st.sampled_from([True, False]), + dtype=st.sampled_from( + [torch.bfloat16, torch.float32] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=20, + deadline=None, + ) + # pyre-ignore[2] + def test_split_2D_jagged_triton(self, *args, **kwargs) -> None: + self._test_split_2D_jagged( + *args, + **kwargs, + test_backward=True, + ref_kernel=HammerKernel.PYTORCH, + real_kernel=HammerKernel.TRITON, + ) + + def _test_split_2D_jagged( + self, + batch_size: int, + max_len_a: int, + max_len_b: int, + D: int, + is_dense_a: bool, + is_dense_b: bool, + test_backward: bool, + ref_kernel: HammerKernel, + real_kernel: HammerKernel, + dtype: torch.dtype = torch.float32, + skip_comparisons: bool = False, + ) -> None: + set_dev_mode(True) + from generative_recommenders.ops.jagged_tensors import split_2D_jagged + + max_seq_len = max_len_a + max_len_b + if not is_dense_a: + lengths_a = torch.randint( + 1, max_len_a + 1, size=(batch_size,), device=torch.device("cuda") + ) + offsets_a = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + offsets_a[1:] = torch.cumsum(lengths_a, dim=0) + total_len_a = int(offsets_a[-1].item()) + else: + offsets_a = None + total_len_a = batch_size * max_len_a + is_dense_b = False + if not is_dense_b: + lengths_b = torch.randint( + 1, max_len_b + 1, size=(batch_size,), device=torch.device("cuda") + ) + offsets_b = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + offsets_b[1:] = torch.cumsum(lengths_b, dim=0) + total_len_b = int(offsets_b[-1].item()) + else: + offsets_b = None + total_len_b = batch_size * max_len_b + values = ( + torch.empty( + (total_len_a + total_len_b, D), + dtype=dtype, + device=torch.device("cuda"), + ) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + + ref_values_a, ref_values_b = split_2D_jagged( + max_seq_len=max_len_a + max_len_b, + values=values, + max_len_left=max_len_a if is_dense_a else None, + max_len_right=max_len_b if is_dense_b else None, + offsets_left=offsets_a, + offsets_right=offsets_b, + kernel=ref_kernel, + ) + d_values_a = torch.randn_like(ref_values_a) + d_values_b = torch.randn_like(ref_values_b) + ref_values_a.backward(d_values_a, retain_graph=True) + ref_values_b.backward(d_values_b) + if skip_comparisons: + return + + assert values.grad is not None + ref_d_values, values.grad = values.grad.clone(), None + + values = values.detach().clone().requires_grad_() + real_values_a, real_values_b = split_2D_jagged( + max_seq_len=max_seq_len, + values=values, + max_len_left=max_len_a if is_dense_a else None, + max_len_right=max_len_b if is_dense_b else None, + offsets_left=offsets_a, + offsets_right=offsets_b, + kernel=real_kernel, + ) + torch.testing.assert_close(ref_values_a, real_values_a) + torch.testing.assert_close(ref_values_b, real_values_b) + + if test_backward: + d_values_a = d_values_a.detach().clone() + d_values_b = d_values_b.detach().clone() + real_values_a.backward(d_values_a, retain_graph=True) + real_values_b.backward(d_values_b) + real_d_values = values.grad.clone() + torch.testing.assert_close(ref_d_values, real_d_values) + + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore + @given( + batch_size=st.integers(2, 8), + max_len_a=st.integers(20, 100), + max_len_b=st.integers(20, 100), + D=st.integers(10, 30), + is_dense_a=st.sampled_from([True, False]), + is_dense_b=st.sampled_from([True, False]), + dtype=st.sampled_from( + [torch.bfloat16, torch.float32] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=20, + deadline=None, + ) + # pyre-ignore[2] + def test_concat_2D_jagged_triton(self, *args, **kwargs) -> None: + self._test_concat_2D_jagged( + *args, + **kwargs, + test_backward=True, + ref_kernel=HammerKernel.PYTORCH, + real_kernel=HammerKernel.TRITON, + ) + + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore + @given( + batch_size=st.sampled_from([130]), + max_len_a=st.sampled_from([32768]), + max_len_b=st.sampled_from([10]), + D=st.sampled_from([512]), + is_dense_a=st.sampled_from([True, False]), + is_dense_b=st.sampled_from([True, False]), + dtype=st.sampled_from( + [torch.bfloat16, torch.float32] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=20, + deadline=None, + ) + # pyre-ignore[2] + def test_concat_2D_jagged_large_tensor(self, *args, **kwargs) -> None: + self._test_concat_2D_jagged( + *args, + **kwargs, + test_backward=True, + skip_comparisons=True, + ref_kernel=HammerKernel.TRITON, + real_kernel=HammerKernel.TRITON, + ) + + def _test_concat_2D_jagged( + self, + batch_size: int, + max_len_a: int, + max_len_b: int, + D: int, + is_dense_a: bool, + is_dense_b: bool, + test_backward: bool, + ref_kernel: HammerKernel, + real_kernel: HammerKernel, + dtype: torch.dtype = torch.float32, + skip_comparisons: bool = False, + ) -> None: + set_dev_mode(True) + from generative_recommenders.ops.jagged_tensors import concat_2D_jagged + + if not is_dense_a: + lengths_a = torch.randint( + 1, max_len_a + 1, size=(batch_size,), device=torch.device("cuda") + ) + offsets_a = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + offsets_a[1:] = torch.cumsum(lengths_a, dim=0) + total_len_a = int(offsets_a[-1].item()) + else: + offsets_a = None + total_len_a = batch_size * max_len_a + values_a = ( + torch.empty( + (total_len_a, D), + dtype=dtype, + device=torch.device("cuda"), + ) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + + if not is_dense_b: + lengths_b = torch.randint( + 1, max_len_b + 1, size=(batch_size,), device=torch.device("cuda") + ) + offsets_b = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + offsets_b[1:] = torch.cumsum(lengths_b, dim=0) + total_len_b = int(offsets_b[-1].item()) + else: + offsets_b = None + total_len_b = batch_size * max_len_b + values_b = ( + torch.empty( + (total_len_b, D), + dtype=dtype, + device=torch.device("cuda"), + ) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + + ref_values = concat_2D_jagged( + max_seq_len=max_len_a + max_len_b, + values_left=values_a, + values_right=values_b, + max_len_left=max_len_a, + max_len_right=max_len_b, + offsets_left=offsets_a, + offsets_right=offsets_b, + kernel=ref_kernel, + ) + dout = torch.randn_like(ref_values) + ref_values.backward(dout) + if skip_comparisons: + return + + assert values_a.grad is not None + ref_d_a, values_a.grad = values_a.grad.clone(), None + assert values_b.grad is not None + ref_d_b, values_b.grad = values_b.grad.clone(), None + + values_a = values_a.detach().clone().requires_grad_() + values_b = values_b.detach().clone().requires_grad_() + dout = dout.detach().clone() + real_values = concat_2D_jagged( + max_seq_len=max_len_a + max_len_b, + values_left=values_a, + values_right=values_b, + max_len_left=max_len_a, + max_len_right=max_len_b, + offsets_left=offsets_a, + offsets_right=offsets_b, + kernel=real_kernel, + ) + torch.testing.assert_close(ref_values, real_values) + + if test_backward: + real_values.backward(dout) + real_d_a = values_a.grad.clone() + real_d_b = values_b.grad.clone() + torch.testing.assert_close(ref_d_a, real_d_a) + torch.testing.assert_close(ref_d_b, real_d_b) + + # pyre-ignore + @given( + batch_size=st.integers(2, 8), + max_uih_len=st.integers(20, 100), + max_l2_len=st.integers(10, 30), + contextual_seq_len=st.sampled_from([0, 10]), + max_targets=st.sampled_from([10, 20]), + D=st.integers(10, 30), + dtype=st.sampled_from( + [torch.bfloat16, torch.float32] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=20, + deadline=None, + ) + def test_hstu_split_l2_embeddings( + self, + batch_size: int, + max_uih_len: int, + max_l2_len: int, + contextual_seq_len: int, + max_targets: int, + D: int, + dtype: torch.dtype, + ) -> None: + set_dev_mode(True) + from generative_recommenders.ops.jagged_tensors import hstu_split_l2_embeddings + + max_seq_len = max_uih_len + max_targets + contextual_seq_len + num_targets = torch.randint( + 1, max_targets + 1, size=(batch_size,), device=torch.device("cuda") + ) + x_lengths = torch.randint( + 0, + max_uih_len + 1, + size=(batch_size,), + device=torch.device("cuda"), + ) + x_lengths = num_targets + x_lengths + contextual_seq_len + x_offsets = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + x_offsets[1:] = torch.cumsum(x_lengths, dim=0) + total_len = int(x_offsets[-1].item()) + x = ( + torch.empty( + (total_len, D), + dtype=dtype, + device=torch.device("cuda"), + ) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + prefix_lengths = x_lengths - max_l2_len - num_targets - contextual_seq_len + prefix_lengths = torch.clamp(prefix_lengths, min=0) + prefix_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(prefix_lengths) + l2_offsets = x_offsets - prefix_offsets + ref_prefix_x, ref_l2_x = hstu_split_l2_embeddings( + max_seq_len=max_seq_len, + x=x, + prefix_offsets=prefix_offsets, + l2_offsets=l2_offsets, + contextual_seq_len=contextual_seq_len, + kernel=HammerKernel.PYTORCH, + ) + d_prefix_x = torch.randn_like(ref_prefix_x) + d_l2_x = torch.randn_like(ref_l2_x) + ref_prefix_x.backward(d_prefix_x, retain_graph=True) + ref_l2_x.backward(d_l2_x) + assert x.grad is not None + ref_d_x, x.grad = x.grad.clone(), None + x = x.detach().clone().requires_grad_() + real_prefix_x, real_l2_x = hstu_split_l2_embeddings( + max_seq_len=max_seq_len, + x=x, + prefix_offsets=prefix_offsets, + l2_offsets=l2_offsets, + contextual_seq_len=contextual_seq_len, + kernel=HammerKernel.TRITON, + ) + print(ref_prefix_x.shape, real_prefix_x.shape) + torch.testing.assert_close(ref_prefix_x, real_prefix_x) + torch.testing.assert_close(ref_l2_x, real_l2_x) + d_prefix_x = d_prefix_x.detach().clone() + d_l2_x = d_l2_x.detach().clone() + real_prefix_x.backward(d_prefix_x, retain_graph=True) + real_l2_x.backward(d_l2_x) + real_d_x = x.grad.clone() + torch.testing.assert_close(ref_d_x, real_d_x) + + # pyre-ignore + @given( + batch_size=st.integers(1, 1), + max_prefix_len=st.integers(10, 10), + max_l2_len=st.integers(5, 5), + contextual_seq_len=st.sampled_from([3]), + max_targets=st.sampled_from([2]), + D=st.integers(10, 10), + dtype=st.sampled_from( + [torch.bfloat16, torch.float32] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=20, + deadline=None, + ) + def test_hstu_concat_l2_embeddings( + self, + batch_size: int, + max_prefix_len: int, + max_l2_len: int, + contextual_seq_len: int, + max_targets: int, + D: int, + dtype: torch.dtype, + ) -> None: + set_dev_mode(True) + from generative_recommenders.ops.jagged_tensors import hstu_concat_l2_embeddings + + num_targets = torch.randint( + 1, max_targets + 1, size=(batch_size,), device=torch.device("cuda") + ) + l2_lengths = torch.randint( + 0, + max_l2_len + 1, + size=(batch_size,), + device=torch.device("cuda"), + ) + l2_lengths = num_targets + l2_lengths + contextual_seq_len + max_l2_len = max_l2_len + contextual_seq_len + max_targets + l2_offsets = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + l2_offsets[1:] = torch.cumsum(l2_lengths, dim=0) + total_l2_len = int(l2_offsets[-1].item()) + l2_x = ( + torch.empty( + (total_l2_len, D), + dtype=dtype, + device=torch.device("cuda"), + ) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + prefix_lengths = torch.randint( + 0, + max_prefix_len + 1, + size=(batch_size,), + device=torch.device("cuda"), + ) + prefix_lengths = torch.randint( + 0, + max_prefix_len + 1, + size=(batch_size,), + device=torch.device("cuda"), + ) + prefix_offsets = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + prefix_offsets[1:] = torch.cumsum(prefix_lengths, dim=0) + total_prefix_len = int(prefix_offsets[-1].item()) + prefix_x = ( + torch.empty( + (total_prefix_len, D), + dtype=dtype, + device=torch.device("cuda"), + ) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + ref_x = hstu_concat_l2_embeddings( + max_prefix_len=max_prefix_len, + prefix_x=prefix_x, + prefix_offsets=prefix_offsets, + max_l2_len=max_l2_len, + l2_x=l2_x, + l2_offsets=l2_offsets, + contextual_seq_len=contextual_seq_len, + kernel=HammerKernel.PYTORCH, + ) + dout = torch.randn_like(ref_x) + ref_x.backward(dout) + + assert prefix_x.grad is not None + ref_d_prefix_x, prefix_x.grad = prefix_x.grad.clone(), None + assert l2_x.grad is not None + ref_d_l2_x, l2_x.grad = l2_x.grad.clone(), None + + prefix_x = prefix_x.detach().clone().requires_grad_() + l2_x = l2_x.detach().clone().requires_grad_() + real_x = hstu_concat_l2_embeddings( + max_prefix_len=max_prefix_len, + prefix_x=prefix_x, + prefix_offsets=prefix_offsets, + max_l2_len=max_l2_len, + l2_x=l2_x, + l2_offsets=l2_offsets, + contextual_seq_len=contextual_seq_len, + kernel=HammerKernel.TRITON, + ) + torch.testing.assert_close(ref_x, real_x) + dout = dout.detach().clone() + real_x.backward(dout) + real_d_prefix_x = prefix_x.grad.clone() + real_d_l2_x = l2_x.grad.clone() + torch.testing.assert_close(ref_d_prefix_x, real_d_prefix_x) + torch.testing.assert_close(ref_d_l2_x, real_d_l2_x) + + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore + @given( + batch_size=st.integers(4, 8), + max_seq_len=st.integers(50, 500), + D=st.integers(20, 200), + K=st.integers(30, 200), + dtype=st.sampled_from( + [torch.float32, torch.bfloat16, torch.float16] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + contiguous=st.booleans(), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=20, + deadline=None, + ) + # pyre-ignore[2] + def test_jagged_dense_bmm_broadcast_add_triton(self, *args, **kwargs) -> None: + self._test_jagged_dense_bmm_broadcast_add( + *args, + **kwargs, + test_backward=True, + atol=None, + rtol=None, + ref_kernel=HammerKernel.PYTORCH, + real_kernel=HammerKernel.TRITON, + ) + + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore + @given( + batch_size=st.sampled_from([130]), + max_seq_len=st.sampled_from([32768]), + D=st.sampled_from([512]), + K=st.sampled_from([512]), + dtype=st.sampled_from([torch.float32, torch.bfloat16]), + contiguous=st.booleans(), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=1, + deadline=None, + ) + def test_jagged_dense_bmm_broadcast_add_triton_large_tensor( + self, + # pyre-fixme[2]: Parameter must be annotated. + *args, + **kwargs, # pyre-ignore[2] + ) -> None: + self._test_jagged_dense_bmm_broadcast_add( + *args, + **kwargs, + test_backward=True, + atol=None, + rtol=None, + ref_kernel=HammerKernel.TRITON, + real_kernel=HammerKernel.TRITON, + ) + + def _test_jagged_dense_bmm_broadcast_add( + self, + batch_size: int, + max_seq_len: int, + D: int, + K: int, + dtype: torch.dtype, + ref_kernel: HammerKernel, + real_kernel: HammerKernel, + test_backward: bool, + contiguous: bool = True, + atol: Optional[float] = None, + rtol: Optional[float] = None, + sparsity: float = -1, + ) -> None: + set_dev_mode(True) + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + from generative_recommenders.ops.jagged_tensors import ( + jagged_dense_bmm_broadcast_add, + ) + + if sparsity > 0.0: + lengths = generate_sparse_seq_len( + size=batch_size, + max_seq_len=max_seq_len, + sparsity=sparsity, + device=torch.device("cuda"), + ).to(torch.int64) + else: + lengths = torch.randint( + max_seq_len + 1, size=(batch_size,), device=torch.device("cuda") + ) + # Test the edge case with an empty row + seq_offsets = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + seq_offsets[1:] = torch.cumsum(lengths, dim=0) + jagged_size = int(seq_offsets[-1].item()) + jagged = ( + torch.empty((jagged_size, D), dtype=dtype, device=torch.device("cuda")) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + dense = ( + torch.empty((batch_size, D, K), dtype=dtype, device=torch.device("cuda")) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + bias = ( + torch.empty((batch_size, K), dtype=dtype, device=torch.device("cuda")) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + + if not contiguous: + dense = ( + dense.transpose(1, 2) + .contiguous() + .transpose(1, 2) + .detach() + .clone() + .requires_grad_() + ) + + ref_out = jagged_dense_bmm_broadcast_add( + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + jagged=jagged, + dense=dense, + bias=bias, + kernel=ref_kernel, + ).to(jagged.dtype) + if test_backward: + dout = torch.randn_like(ref_out) * 0.01 + ref_out.backward(dout) + # pyre-ignore + ref_d_jagged, jagged.grad = jagged.grad.clone(), None + ref_d_dense, dense.grad = dense.grad.clone(), None + ref_d_bias, bias.grad = bias.grad.clone(), None + + jagged = jagged.detach().clone().requires_grad_() + dense = dense.detach().clone().requires_grad_() + bias = bias.detach().clone().requires_grad_() + real_out = jagged_dense_bmm_broadcast_add( + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + jagged=jagged, + dense=dense, + bias=bias, + kernel=real_kernel, + ) + torch.testing.assert_close( + ref_out, + real_out, + atol=atol, + rtol=rtol, + ) + if test_backward: + real_out.backward(dout) # pyre-ignore + real_d_jagged = jagged.grad.clone() + real_d_dense = dense.grad.clone() + real_d_bias = bias.grad.clone() + torch.testing.assert_close( + ref_d_jagged, # pyre-ignore + real_d_jagged, + atol=atol, + rtol=rtol, + ) + torch.testing.assert_close( + ref_d_dense, # pyre-ignore + real_d_dense, + atol=atol, + rtol=rtol, + ) + torch.testing.assert_close( + ref_d_bias, # pyre-ignore + real_d_bias, + atol=atol, + rtol=rtol, + ) + + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore + @given( + batch_size=st.integers(2, 8), + max_len_a=st.integers(20, 100), + max_len_b=st.integers(20, 100), + D=st.integers(10, 30), + is_dense_a=st.sampled_from([True, False]), + is_dense_b=st.sampled_from([True, False]), + dtype=st.sampled_from( + [torch.bfloat16, torch.float32] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=20, + deadline=None, + ) + # pyre-ignore[2] + def test_concat_2D_jagged_multirow_triton(self, *args, **kwargs) -> None: + self._test_concat_2D_jagged_multirow( + *args, + **kwargs, + test_backward=True, + ref_kernel=HammerKernel.PYTORCH, + real_kernel=HammerKernel.TRITON, + ) + + def _test_concat_2D_jagged_multirow( + self, + batch_size: int, + max_len_a: int, + max_len_b: int, + D: int, + is_dense_a: bool, + is_dense_b: bool, + test_backward: bool, + ref_kernel: HammerKernel, + real_kernel: HammerKernel, + dtype: torch.dtype = torch.float32, + ) -> None: + set_dev_mode(True) + + if not is_dense_a: + lengths_a = torch.randint( + 1, max_len_a + 1, size=(batch_size,), device=torch.device("cuda") + ) + offsets_a = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + offsets_a[1:] = torch.cumsum(lengths_a, dim=0) + total_len_a = int(offsets_a[-1].item()) + else: + offsets_a = None + total_len_a = batch_size * max_len_a + values_a = ( + torch.empty( + (total_len_a, D), + dtype=dtype, + device=torch.device("cuda"), + ) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + + if not is_dense_b: + lengths_b = torch.randint( + 1, max_len_b + 1, size=(batch_size,), device=torch.device("cuda") + ) + offsets_b = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + offsets_b[1:] = torch.cumsum(lengths_b, dim=0) + total_len_b = int(offsets_b[-1].item()) + else: + offsets_b = None + total_len_b = batch_size * max_len_b + values_b = ( + torch.empty( + (total_len_b, D), + dtype=dtype, + device=torch.device("cuda"), + ) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + + ref_values = concat_2D_jagged( + max_seq_len=max_len_a + max_len_b, + values_left=values_a, + values_right=values_b, + max_len_left=max_len_a, + max_len_right=max_len_b, + offsets_left=offsets_a, + offsets_right=offsets_b, + kernel=ref_kernel, + ) + dout = torch.randn_like(ref_values) * 0.1 + ref_values.backward(dout) + assert values_a.grad is not None + ref_d_a, values_a.grad = values_a.grad.clone(), None + assert values_b.grad is not None + ref_d_b, values_b.grad = values_b.grad.clone(), None + + values_a = values_a.detach().clone().requires_grad_() + values_b = values_b.detach().clone().requires_grad_() + dout = dout.detach().clone() + + real_values = concat_2D_jagged_multirow( + max_seq_len=max_len_a + max_len_b, + values_left=values_a, + values_right=values_b, + offsets_left=offsets_a, + offsets_right=offsets_b, + max_len_left=max_len_a, + max_len_right=max_len_b, + kernel=real_kernel, + ) + torch.testing.assert_close(ref_values, real_values) + if test_backward: + real_values.backward(dout) + real_d_a = values_a.grad.clone() + real_d_b = values_b.grad.clone() + torch.testing.assert_close(ref_d_a, real_d_a) + torch.testing.assert_close(ref_d_b, real_d_b) + + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore + @given( + batch_size=st.integers(2, 8), + max_len_a=st.integers(20, 100), + max_len_b=st.integers(20, 100), + D=st.integers(10, 30), + dtype=st.sampled_from( + [torch.bfloat16, torch.float32] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=20, + deadline=None, + ) + # pyre-ignore[2] + def test_split_2D_jagged_multirow_triton(self, *args, **kwargs) -> None: + self._test_split_2D_jagged_multirow( + *args, + **kwargs, + test_backward=True, + ref_kernel=HammerKernel.PYTORCH, + real_kernel=HammerKernel.TRITON, + ) + + def _test_split_2D_jagged_multirow( + self, + batch_size: int, + max_len_a: int, + max_len_b: int, + D: int, + test_backward: bool, + ref_kernel: HammerKernel, + real_kernel: HammerKernel, + dtype: torch.dtype = torch.float32, + ) -> None: + set_dev_mode(True) + + lengths_a = torch.randint( + 1, max_len_a + 1, size=(batch_size,), device=torch.device("cuda") + ) + offsets_a = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + offsets_a[1:] = torch.cumsum(lengths_a, dim=0) + + lengths_b = torch.randint( + 1, max_len_b + 1, size=(batch_size,), device=torch.device("cuda") + ) + offsets_b = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + offsets_b[1:] = torch.cumsum(lengths_b, dim=0) + + total_len = int(offsets_a[-1].item()) + int(offsets_b[-1].item()) + values = ( + torch.empty( + (total_len, D), + dtype=dtype, + device=torch.device("cuda"), + ) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + + ref_values_a, ref_values_b = split_2D_jagged( + max_seq_len=max_len_a + max_len_b, + values=values, + total_len_left=int(offsets_a[-1].item()), + total_len_right=int(offsets_b[-1].item()), + max_len_left=max_len_a, + max_len_right=max_len_b, + offsets_left=offsets_a, + offsets_right=offsets_b, + kernel=ref_kernel, + ) + d_values_a = torch.randn_like(ref_values_a) * 0.1 + d_values_b = torch.randn_like(ref_values_b) * 0.1 + ref_values_a.backward(d_values_a, retain_graph=True) + ref_values_b.backward(d_values_b) + assert values.grad is not None + ref_d_values, values.grad = values.grad.clone(), None + + values = values.detach().clone().requires_grad_() + d_values_a = d_values_a.detach().clone() + d_values_b = d_values_b.detach().clone() + + max_len_a_actual = int((offsets_a[1:] - offsets_a[:-1]).max().item()) + max_len_b_actual = int((offsets_b[1:] - offsets_b[:-1]).max().item()) + + real_values_a, real_values_b = split_2D_jagged_multirow( + max_seq_len=max_len_a + max_len_b, + values=values, + total_len_left=int(offsets_a[-1].item()), + total_len_right=int(offsets_b[-1].item()), + max_len_left=max_len_a_actual, + max_len_right=max_len_b_actual, + offsets_left=offsets_a, + offsets_right=offsets_b, + kernel=real_kernel, + ) + torch.testing.assert_close(ref_values_a, real_values_a) + torch.testing.assert_close(ref_values_b, real_values_b) + if test_backward: + real_values_a.backward(d_values_a, retain_graph=True) + real_values_b.backward(d_values_b) + real_d_values = values.grad.clone() + torch.testing.assert_close(ref_d_values, real_d_values) diff --git a/recommendation_v4/generative_recommenders/ops/tests/layer_norm_test.py b/recommendation_v4/generative_recommenders/ops/tests/layer_norm_test.py new file mode 100644 index 000000000..62540967a --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/tests/layer_norm_test.py @@ -0,0 +1,231 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import copy +import unittest + +import torch +from generative_recommenders.common import gpu_unavailable, HammerKernel, set_dev_mode +from generative_recommenders.ops.layer_norm import ( + layer_norm, + LayerNorm, + swish_layer_norm, + SwishLayerNorm, +) +from hypothesis import given, settings, strategies as st, Verbosity + + +class LayerNormTest(unittest.TestCase): + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore[56] + @given( + N=st.sampled_from([4200000]), + D=st.sampled_from([512]), + is_swish=st.sampled_from([False]), + dtype=st.sampled_from( + [torch.bfloat16] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + ) + @settings( + deadline=None, + verbosity=Verbosity.verbose, + max_examples=1, + ) + # pyre-ignore[2] + def test_large_tensors(self, *args, **kwargs) -> None: + self._test_layernorm( + *args, + **kwargs, + ref_kernel=HammerKernel.TRITON, + real_kernel=HammerKernel.TRITON, + skip_comparisons=True, + ) + + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore[56] + @given( + N=st.integers(min_value=0, max_value=10000), + D=st.integers(min_value=32, max_value=512), + is_swish=st.sampled_from([True, False]), + dtype=st.sampled_from( + [torch.bfloat16, torch.float32] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + ) + @settings( + deadline=None, + verbosity=Verbosity.verbose, + max_examples=20, + ) + # pyre-ignore[2] + def test_ln(self, *args, **kwargs) -> None: + self._test_layernorm( + *args, + **kwargs, + ref_kernel=HammerKernel.PYTORCH, + real_kernel=HammerKernel.TRITON, + ) + + def _test_layernorm( + self, + N: int, + D: int, + is_swish: bool, + dtype: torch.dtype, + ref_kernel: HammerKernel, + real_kernel: HammerKernel, + skip_comparisons: bool = False, + ) -> None: + N = N // 4 * 4 + # enable auto-tuning to verify correctness of multi-row kernel + set_dev_mode(False) + x = ( + torch.empty((N, D), dtype=dtype, device=torch.device("cuda")) + .normal_(0.0, 1.0) + .requires_grad_() + ) + weight = ( + torch.empty((D,), device=torch.device("cuda")) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + bias = ( + torch.empty((D,), device=torch.device("cuda")) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + if is_swish: + layer_norm_func = swish_layer_norm + else: + layer_norm_func = layer_norm + # ref + ref_out = layer_norm_func(x, weight, bias, eps=1e-6, kernel=ref_kernel) + dout = torch.randn_like(ref_out) * 0.05 + ref_out.backward(dout) + if skip_comparisons: + return + # pyre-ignore[16] + ref_dx, x.grad = x.grad.detach().clone(), None + ref_dw, weight.grad = weight.grad.detach().clone(), None + ref_db, bias.grad = bias.grad.detach().clone(), None + # opt + x = x.detach().clone().requires_grad_() + weight = weight.detach().clone().requires_grad_() + bias = bias.detach().clone().requires_grad_() + opt_out = layer_norm_func(x, weight, bias, eps=1e-6, kernel=real_kernel) + dout = dout.detach().clone() + opt_out.backward(dout) + opt_dx, x.grad = x.grad.detach().clone(), None + opt_dw, weight.grad = weight.grad.detach().clone(), None + opt_db, bias.grad = bias.grad.detach().clone(), None + torch.testing.assert_close(ref_out, opt_out) + torch.testing.assert_close(ref_dx, opt_dx) + torch.testing.assert_close(ref_dw, opt_dw) + torch.testing.assert_close(ref_db, opt_db) + + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore[56] + @given( + N=st.integers(min_value=32, max_value=10000), + D=st.integers(min_value=32, max_value=512), + is_swish=st.sampled_from([True, False]), + dtype=st.sampled_from( + [torch.bfloat16, torch.float32] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + ) + @settings( + deadline=None, + verbosity=Verbosity.verbose, + max_examples=20, + ) + # pyre-ignore[2] + def test_modules(self, *args, **kwargs) -> None: + self._test_layer_norm_module( + *args, + **kwargs, + ref_kernel=HammerKernel.PYTORCH, + real_kernel=HammerKernel.TRITON, + ) + + def _test_layer_norm_module( + self, + N: int, + D: int, + is_swish: bool, + dtype: torch.dtype, + ref_kernel: HammerKernel, + real_kernel: HammerKernel, + skip_comparisons: bool = False, + ) -> None: + set_dev_mode(True) + x = ( + torch.empty((N, D), dtype=dtype, device=torch.device("cuda")) + .normal_(0.0, 1.0) + .requires_grad_() + ) + # ref + if is_swish: + ref_layer = SwishLayerNorm( + dim=D, + eps=1e-6, + ).to(device="cuda") + ref_layer._hammer_kernel = ref_kernel + else: + ref_layer = LayerNorm( + dim=D, + eps=1e-6, + ).to(device="cuda") + ref_layer._hammer_kernel = ref_kernel + opt_layer = copy.deepcopy(ref_layer) + opt_layer._hammer_kernel = real_kernel + + ref_out = ref_layer(x) + dout = torch.randn_like(ref_out) * 0.05 + ref_out.backward(dout) + if skip_comparisons: + return + # pyre-ignore[16] + ref_dx, x.grad = x.grad.detach().clone(), None + ref_dw = ref_layer.weight.grad.detach().clone() + ref_db = ref_layer.bias.grad.detach().clone() + # opt + x = x.detach().clone().requires_grad_() + opt_out = opt_layer(x) + dout = dout.detach().clone() + opt_out.backward(dout) + opt_dx, x.grad = x.grad.detach().clone(), None + opt_dw = opt_layer.weight.grad.detach().clone() + opt_db = opt_layer.bias.grad.detach().clone() + torch.testing.assert_close(ref_out, opt_out) + torch.testing.assert_close( + ref_dx, + opt_dx, + ) + torch.testing.assert_close( + ref_dw, + opt_dw, + ) + torch.testing.assert_close( + ref_db, + opt_db, + ) diff --git a/recommendation_v4/generative_recommenders/ops/tests/mm_test.py b/recommendation_v4/generative_recommenders/ops/tests/mm_test.py new file mode 100644 index 000000000..0695275e3 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/tests/mm_test.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import unittest +from typing import Optional + +import torch +from generative_recommenders.common import gpu_unavailable, HammerKernel +from generative_recommenders.ops.mm import addmm +from hypothesis import given, settings, strategies as st, Verbosity + + +class MMlTest(unittest.TestCase): + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore[56] + @given( + M=st.integers(min_value=100, max_value=300), + N=st.integers(min_value=100, max_value=300), + K=st.sampled_from([128, 256]), + broadcast=st.booleans(), + dtype=st.sampled_from( + [torch.float32, torch.bfloat16, torch.float16] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=20, + deadline=None, + ) + def test_addmm( + self, + M: int, + N: int, + K: int, + broadcast: bool, + dtype: torch.dtype, + ) -> None: + self._test_addmm( + M=M, + N=N, + K=K, + broadcast=broadcast, + dtype=dtype, + kernel_type=HammerKernel.TRITON, + ) + + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore[56] + @given( + M=st.integers(min_value=100, max_value=300), + N=st.sampled_from([16, 48, 128, 144, 256]), + K=st.sampled_from([16, 48, 128, 144, 256]), + broadcast=st.booleans(), + dtype=st.sampled_from( + [torch.float32, torch.bfloat16, torch.float16] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=20, + deadline=None, + ) + def test_addmm_tma( + self, + M: int, + N: int, + K: int, + broadcast: bool, + dtype: torch.dtype, + ) -> None: + self._test_addmm( + M=M, + N=N, + K=K, + broadcast=broadcast, + dtype=dtype, + kernel_type=HammerKernel.TRITON, + ) + + def _test_addmm( + self, + M: int, + N: int, + K: int, + broadcast: bool, + dtype: torch.dtype, + kernel_type: HammerKernel, + atol: Optional[float] = None, + rtol: Optional[float] = None, + ) -> None: + # to enable more deterministic results. + torch.manual_seed(0) + + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + + x: torch.Tensor = torch.rand((M, K), dtype=dtype, device="cuda").requires_grad_( + True + ) + w: torch.Tensor = torch.rand((K, N), dtype=dtype, device="cuda").requires_grad_( + True + ) + + if broadcast: + y: torch.Tensor = torch.rand( + (N), dtype=dtype, device="cuda" + ).requires_grad_(True) + else: + y: torch.Tensor = torch.rand( + (M, N), dtype=dtype, device="cuda" + ).requires_grad_(True) + + ref_z = addmm(y, x, w, kernel=HammerKernel.PYTORCH) + dz = torch.randn_like(ref_z) * 0.1 + ref_z.backward(dz) + # pyre-ignore[16] + ref_dx, x.grad = x.grad.detach().clone(), None + ref_dw, w.grad = w.grad.detach().clone(), None + ref_dy, y.grad = y.grad.detach().clone(), None + + x = x.detach().clone().requires_grad_(True) + w = w.detach().clone().requires_grad_(True) + y = y.detach().clone().requires_grad_(True) + real_z = addmm(y, x, w, kernel=kernel_type) + + torch.testing.assert_close(ref_z, real_z, atol=atol, rtol=rtol) + + # triton cc doesn't support backward + if kernel_type != HammerKernel.TRITON_CC: + real_z.backward(dz) + real_dx, x.grad = x.grad.detach().clone(), None + real_dw, w.grad = w.grad.detach().clone(), None + real_dy, y.grad = y.grad.detach().clone(), None + + torch.testing.assert_close(ref_dx, real_dx, atol=atol, rtol=rtol) + torch.testing.assert_close(ref_dw, real_dw, atol=atol, rtol=rtol) + torch.testing.assert_close(ref_dy, real_dy, atol=atol, rtol=rtol) diff --git a/recommendation_v4/generative_recommenders/ops/tests/position_test.py b/recommendation_v4/generative_recommenders/ops/tests/position_test.py new file mode 100644 index 000000000..ab9e1b415 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/tests/position_test.py @@ -0,0 +1,234 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import unittest + +import torch +from generative_recommenders.common import ( + generate_sparse_seq_len, + gpu_unavailable, + HammerKernel, + set_dev_mode, +) +from hypothesis import given, settings, strategies as st, Verbosity + + +class PositionEmbeddingsTest(unittest.TestCase): + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore + @given( + alpha=st.sampled_from([0.5]), + max_uih_len=st.integers(50, 500), + max_contextual_seq_len=st.sampled_from([10]), + interleave_targets=st.sampled_from([True, False]), + batch_size=st.integers(16, 32), + D=st.integers(20, 200), + max_targets=st.sampled_from([10, 20]), + time_bucket_fn=st.sampled_from(["log"]), + dtype=st.sampled_from([torch.float32, torch.bfloat16, torch.float16]), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=20, + deadline=None, + ) + # pyre-ignore[2] + def test_add_timestamp_positional_embeddings_triton(self, *args, **kwargs) -> None: + self._test_add_timestamp_positional_embeddings( + *args, + **kwargs, + test_backward=True, + ref_kernel=HammerKernel.PYTORCH, + real_kernel=HammerKernel.TRITON, + ) + + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore + @given( + alpha=st.sampled_from([0.5]), + max_uih_len=st.sampled_from([32768]), + max_contextual_seq_len=st.sampled_from([10]), + interleave_targets=st.sampled_from([False]), + batch_size=st.sampled_from([130]), + D=st.sampled_from([512]), + max_targets=st.sampled_from([10]), + time_bucket_fn=st.sampled_from(["log"]), + dtype=st.sampled_from([torch.bfloat16]), + ) + @settings( + verbosity=Verbosity.verbose, + max_examples=1, + deadline=None, + ) + def test_add_timestamp_positional_embeddings_triton_large_tensor( + self, + # pyre-fixme[2]: Parameter must be annotated. + *args, + # pyre-ignore[2] + **kwargs, + ) -> None: + self._test_add_timestamp_positional_embeddings( + *args, + **kwargs, + test_backward=False, + ref_kernel=HammerKernel.TRITON, + real_kernel=HammerKernel.TRITON, + sparsity=1.0, + ) + + def _test_add_timestamp_positional_embeddings( + self, + alpha: float, + max_uih_len: int, + max_contextual_seq_len: int, + interleave_targets: bool, + batch_size: int, + D: int, + max_targets: int, + time_bucket_fn: str, + dtype: torch.dtype, + ref_kernel: HammerKernel, + real_kernel: HammerKernel, + test_backward: bool, + sparsity: float = -1, + ) -> None: + set_dev_mode(True) + from generative_recommenders.ops.position import ( + add_timestamp_positional_embeddings, + ) + + num_targets = torch.randint( + max_targets + 1, size=(batch_size,), device=torch.device("cuda") + ) + if sparsity > 0.0: + lengths = generate_sparse_seq_len( + size=batch_size, + max_seq_len=max_uih_len, + sparsity=sparsity, + device=torch.device("cuda"), + ).to(torch.int64) + else: + lengths = torch.randint( + max_uih_len + 1, size=(batch_size,), device=torch.device("cuda") + ) + seq_offsets = torch.zeros( + (batch_size + 1,), dtype=torch.int64, device=torch.device("cuda") + ) + seq_offsets[1:] = torch.cumsum(lengths, dim=0) + max_seq_len = max_uih_len + max_targets + + position_embeddings_weight = ( + torch.empty( + (max_seq_len, D), dtype=torch.float32, device=torch.device("cuda") + ) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + num_time_buckets = 1000 + timestamp_embeddings_weight = ( + torch.empty( + (num_time_buckets, D), dtype=torch.float32, device=torch.device("cuda") + ) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + seq_embeddings = ( + torch.empty( + (int(seq_offsets[-1].item()), D), + dtype=dtype, + device=torch.device("cuda"), + ) + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + timestamp_deltas: torch.Tensor = torch.randint( + 86400, + size=(batch_size, max_seq_len), + device="cuda", + ) + timestamps = timestamp_deltas.cumsum(dim=1) + mask = torch.arange(max_seq_len, device=timestamps.device) < lengths.unsqueeze( + 1 + ) + timestamps = timestamps[mask.view(batch_size, -1)] + + ref_out = add_timestamp_positional_embeddings( + alpha=alpha, + max_seq_len=max_seq_len, + max_contextual_seq_len=max_contextual_seq_len, + position_embeddings_weight=position_embeddings_weight, + timestamp_embeddings_weight=timestamp_embeddings_weight, + seq_offsets=seq_offsets, + seq_lengths=lengths, + seq_embeddings=seq_embeddings, + timestamps=timestamps, + num_targets=num_targets, + interleave_targets=interleave_targets, + time_bucket_fn=time_bucket_fn, + kernel=ref_kernel, + ) + dout = torch.randn_like(ref_out) * 0.01 + ref_out.backward(dout) + # pyre-ignore + ref_d_seq_embeddings, seq_embeddings.grad = seq_embeddings.grad.clone(), None + ref_d_position_embeddings_weight, position_embeddings_weight.grad = ( + position_embeddings_weight.grad.clone(), + None, + ) + ref_d_timestamp_embeddings_weight, timestamp_embeddings_weight.grad = ( + timestamp_embeddings_weight.grad.clone(), + None, + ) + + real_out = add_timestamp_positional_embeddings( + alpha=alpha, + max_seq_len=max_seq_len, + max_contextual_seq_len=max_contextual_seq_len, + position_embeddings_weight=position_embeddings_weight, + timestamp_embeddings_weight=timestamp_embeddings_weight, + seq_offsets=seq_offsets, + seq_lengths=lengths, + seq_embeddings=seq_embeddings, + timestamps=timestamps, + num_targets=num_targets, + interleave_targets=interleave_targets, + time_bucket_fn=time_bucket_fn, + kernel=real_kernel, + ) + + torch.testing.assert_close(ref_out, real_out) + if test_backward: + real_out.backward(dout) + real_d_seq_embeddings = seq_embeddings.grad.clone() + real_d_position_embeddings_weight = position_embeddings_weight.grad.clone() + real_d_timestamp_embeddings_weight = ( + timestamp_embeddings_weight.grad.clone() + ) + torch.testing.assert_close(ref_d_seq_embeddings, real_d_seq_embeddings) + torch.testing.assert_close( + ref_d_position_embeddings_weight, + real_d_position_embeddings_weight, + atol=5e-2 if dtype != torch.float32 else None, + rtol=2e-2 if dtype != torch.float32 else None, + ) + torch.testing.assert_close( + ref_d_timestamp_embeddings_weight, + real_d_timestamp_embeddings_weight, + atol=5e-2 if dtype != torch.float32 else None, + rtol=2e-2 if dtype != torch.float32 else None, + ) diff --git a/recommendation_v4/generative_recommenders/ops/tests/rms_norm_test.py b/recommendation_v4/generative_recommenders/ops/tests/rms_norm_test.py new file mode 100644 index 000000000..4e5c1a871 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/tests/rms_norm_test.py @@ -0,0 +1,229 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import copy +import unittest + +import torch +from generative_recommenders.common import gpu_unavailable, HammerKernel, set_dev_mode +from generative_recommenders.ops.layer_norm import rms_norm, RMSNorm +from hammer.ops.triton.cc.utils import set_triton_cc_version +from hypothesis import given, settings, strategies as st, Verbosity + + +class LayerNormTest(unittest.TestCase): + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore[56] + @given( + N=st.sampled_from([2000000]), + D=st.sampled_from([512]), + dtype=st.sampled_from( + [torch.bfloat16] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + silu=st.booleans(), + ) + @settings( + deadline=None, + verbosity=Verbosity.verbose, + max_examples=1, + ) + # pyre-ignore[2] + def test_large_tensors(self, *args, **kwargs) -> None: + self._test_rms_norm( + *args, + **kwargs, + ref_kernel=HammerKernel.TRITON, + real_kernel=HammerKernel.TRITON, + skip_comparisons=True, + ) + + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore[56] + @given( + N=st.integers(min_value=0, max_value=10000), + D=st.integers(min_value=32, max_value=512), + dtype=st.sampled_from( + [torch.bfloat16, torch.float32] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + silu=st.booleans(), + ) + @settings( + deadline=None, + verbosity=Verbosity.verbose, + max_examples=50, + ) + # pyre-ignore[2] + def test_rms_norm(self, *args, **kwargs) -> None: + self._test_rms_norm( + *args, + **kwargs, + ref_kernel=HammerKernel.PYTORCH, + real_kernel=HammerKernel.TRITON, + ) + + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore[56] + @given( + N=st.integers(min_value=4, max_value=10000), + D=st.sampled_from([256, 512]), + dtype=st.sampled_from([torch.bfloat16, torch.float16]), + triton_cc_version=st.sampled_from(["", "repkg"]), + silu=st.just(False), + ) + @settings( + deadline=None, + verbosity=Verbosity.verbose, + max_examples=10, + ) + # pyre-ignore[2] + def test_rms_norm_triton_cc(self, triton_cc_version: str, *args, **kwargs) -> None: + set_triton_cc_version(triton_cc_version) + self._test_rms_norm( + *args, + **kwargs, + ref_kernel=HammerKernel.PYTORCH, + real_kernel=HammerKernel.TRITON_CC, + test_backward=False, + ) + + def _test_rms_norm( + self, + N: int, + D: int, + dtype: torch.dtype, + silu: bool, + ref_kernel: HammerKernel, + real_kernel: HammerKernel, + skip_comparisons: bool = False, + test_backward: bool = True, + ) -> None: + N = N // 4 * 4 + # enable auto-tuning to verify correctness of multi-row kernel + set_dev_mode(False) + x = ( + torch.empty((N, D), dtype=dtype, device=torch.device("cuda")) + .normal_(0.0, 1.0) + .requires_grad_() + ) + weight = ( + torch.empty((D,), device=torch.device("cuda")) + .uniform_(-1.0, 1.0) + .requires_grad_() + ) + # ref + ref_out = rms_norm(x, weight, eps=1e-6, silu=silu, kernel=ref_kernel) + opt_x = x.detach().clone().requires_grad_() + opt_weight = weight.detach().clone().requires_grad_() + opt_out = rms_norm(opt_x, opt_weight, eps=1e-6, silu=silu, kernel=real_kernel) + torch.testing.assert_close(ref_out, opt_out) + + if not test_backward: + return + + dout = torch.randn_like(ref_out) * 0.05 + ref_out.backward(dout) + if skip_comparisons: + return + # pyre-ignore[16] + ref_dx, x.grad = x.grad.detach().clone(), None + ref_dw, weight.grad = weight.grad.detach().clone(), None + # opt + dout = dout.detach().clone() + opt_out.backward(dout) + opt_dx, x.grad = opt_x.grad.detach().clone(), None + opt_dw, weight.grad = opt_weight.grad.detach().clone(), None + torch.testing.assert_close(ref_dx, opt_dx) + torch.testing.assert_close(ref_dw, opt_dw) + + @unittest.skipIf(*gpu_unavailable) + # pyre-ignore[56] + @given( + N=st.integers(min_value=32, max_value=10000), + D=st.integers(min_value=32, max_value=512), + dtype=st.sampled_from( + [torch.bfloat16, torch.float32] + if torch.cuda.get_device_capability(torch.device("cuda"))[0] >= 8 + else [torch.float32] + ), + ) + @settings( + deadline=None, + verbosity=Verbosity.verbose, + max_examples=50, + ) + # pyre-ignore[2] + def test_modules(self, *args, **kwargs) -> None: + self._test_rms_norm_module( + *args, + **kwargs, + ref_kernel=HammerKernel.PYTORCH, + real_kernel=HammerKernel.TRITON, + ) + + def _test_rms_norm_module( + self, + N: int, + D: int, + dtype: torch.dtype, + ref_kernel: HammerKernel, + real_kernel: HammerKernel, + skip_comparisons: bool = False, + ) -> None: + set_dev_mode(True) + x = ( + torch.empty((N, D), dtype=dtype, device=torch.device("cuda")) + .normal_(0.0, 1.0) + .requires_grad_() + ) + # ref + ref_layer = RMSNorm( + dim=D, + eps=1e-6, + ).to(device="cuda") + ref_layer._hammer_kernel = ref_kernel + opt_layer = copy.deepcopy(ref_layer) + opt_layer._hammer_kernel = real_kernel + + ref_out = ref_layer(x) + dout = torch.randn_like(ref_out) * 0.05 + ref_out.backward(dout) + if skip_comparisons: + return + # pyre-ignore[16] + ref_dx, x.grad = x.grad.detach().clone(), None + ref_dw = ref_layer.weight.grad.detach().clone() + # opt + x = x.detach().clone().requires_grad_() + opt_out = opt_layer(x) + dout = dout.detach().clone() + opt_out.backward(dout) + opt_dx, x.grad = x.grad.detach().clone(), None + opt_dw = opt_layer.weight.grad.detach().clone() + torch.testing.assert_close(ref_out.to(dtype), opt_out) + torch.testing.assert_close( + ref_dx, + opt_dx, + ) + torch.testing.assert_close( + ref_dw, + opt_dw, + ) diff --git a/recommendation_v4/generative_recommenders/ops/triton/_autotune_pinning.py b/recommendation_v4/generative_recommenders/ops/triton/_autotune_pinning.py new file mode 100644 index 000000000..5aa24eb85 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/triton/_autotune_pinning.py @@ -0,0 +1,27 @@ +"""Triton autotune pinning helper. + +A handful of Triton kernels in this directory have two stable autotune +equilibria on MI350X gfx950 at our yambda bs=32 L=2039 shape: a fast one +(~52 ms/step) and a slow one (~71 ms/step). The autotuner's measurement +noise puts the choice on a coin flip per cold start. We pin the winning +config for these kernels so every cold start lands at the fast equilibrium +deterministically. + +Set `TRITON_FULL_AUTOTUNE=1` to bypass the pin and re-enable the full +autotune search (useful when validating a new shape, GPU, or Triton version +before re-capturing winners). +""" + +import os +from typing import Callable, List + +import triton + + +def pinned_or_full( + pinned: List[triton.Config], + full_configs_fn: Callable[[], List[triton.Config]], +) -> List[triton.Config]: + if os.environ.get("TRITON_FULL_AUTOTUNE", "0") == "1": + return full_configs_fn() + return pinned diff --git a/recommendation_v4/generative_recommenders/ops/triton/triton_addmm.py b/recommendation_v4/generative_recommenders/ops/triton/triton_addmm.py new file mode 100644 index 000000000..487aae189 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/triton/triton_addmm.py @@ -0,0 +1,1708 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-unsafe + +#!/usr/bin/env python3 + + +import math +from typing import List, Optional, Tuple + +import torch + +# @manual=//triton:triton +import triton + +# @manual=//triton:triton +import triton.language as tl +from generative_recommenders.ops.utils import is_sm100_plus, maybe_register_custom_op + +try: + # @manual=//triton:triton + from triton.language.extra.subtile_ops import _split_n_2D +except ImportError: + _split_n_2D = None + +try: + # @manual=//triton:triton + import triton.language.extra.tlx as tlx # type: ignore + + HAS_TLX = True +except ImportError: + tlx = None + HAS_TLX = False + +from generative_recommenders.common import triton_autotune, triton_cc + +try: + # @manual=//triton:triton + from triton.tools.tensor_descriptor import TensorDescriptor + + TMA_AVAILABLE = True +except ImportError: + TMA_AVAILABLE = False + pass + + +ENABLE_FULL_TURNING_SPACE = False + + +def _use_meta_ws() -> bool: + """Check if Meta's warp specialization is available, enabled, and on SM100+.""" + return ( + is_sm100_plus() + and hasattr(triton, "knobs") + and hasattr(triton.knobs, "nvidia") + # `use_meta_ws` is absent in some Triton builds (e.g. nvcr.io/nvidia/pytorch:26.01-py3); + # use getattr so import doesn't crash on AttributeError before any step runs. + and getattr(triton.knobs.nvidia, "use_meta_ws", False) + ) + + +def _check_tma_alignment( + x: torch.Tensor, w: torch.Tensor, y: torch.Tensor, min_alignment: int = 16 +) -> bool: + """Check if tensors meet TMA alignment requirements. + + TMA (Tensor Memory Accelerator) on H100 requires: + 1. Base addresses to be 64-byte aligned + 2. Dimensions to be multiples of 64 for optimal performance + 3. Contiguous inner dimensions (stride=1) + + Args: + x: Input tensor [M, K] + w: Weight tensor [K, N] + y: Bias tensor [N] or [M, N] + min_alignment: Minimum alignment requirement (default: 64) + + Returns: + True if all tensors meet TMA alignment requirements + """ + _, K = x.shape + KB, N = w.shape + assert K == KB, f"incompatible dimensions {K}, {KB}" + + is_y_1d = y.dim() == 1 + NY = y.shape[0] if is_y_1d else y.shape[1] + assert N == NY, f"incompatible dimensions {N}, {NY}" + + return (K % min_alignment == 0) and (N % min_alignment == 0) + + +def _prune_persistent_autows_configs(configs, named_args, **kwargs): # noqa + if not _use_meta_ws(): + return configs + BROADCAST_Y = kwargs.get("BROADCAST_Y", False) + pruned = [] + for c in configs: + BLOCK_M = c.kwargs.get("BLOCK_M", 0) + BLOCK_N = c.kwargs.get("BLOCK_N", 0) + EPILOGUE_SUBTILE = c.kwargs.get("EPILOGUE_SUBTILE", 1) + DP = c.kwargs.get("DATA_PARTITION_FACTOR", 1) + # DATA_PARTITION_FACTOR=2 is only supported with BLOCK_M=256 + if DP == 2 and BLOCK_M != 256: + continue + if (BLOCK_N // EPILOGUE_SUBTILE) < 32: + continue + if BROADCAST_Y and (BLOCK_N // EPILOGUE_SUBTILE) < 64: + continue + pruned.append(c) + return pruned + + +def _prune_configs_for_tlx_persistent_addmm(configs, named_args, **kwargs): # noqa + M = named_args.get("M", 0) + N = named_args.get("N", 0) + BROADCAST_Y = kwargs.get("BROADCAST_Y", False) + + pruned = [] + for c in configs: + BLOCK_M = c.kwargs.get("BLOCK_M", 0) + BLOCK_N = c.kwargs.get("BLOCK_N", 0) + EPILOGUE_SUBTILE = c.kwargs.get("EPILOGUE_SUBTILE", 1) + NUM_MMA_GROUPS = c.kwargs.get("NUM_MMA_GROUPS", 1) + BLOCK_M_SPLIT = BLOCK_M // NUM_MMA_GROUPS + NUM_SMEM_BUFFERS = c.kwargs.get("NUM_SMEM_BUFFERS", 1) + + # Hardware constraint: Always make MMA tile 128. + if BLOCK_M_SPLIT != 128: + continue + + # BLOCK_N >= 64 required for PAIR_CTA + if BLOCK_N < 64: + continue + + # Subslice loads cannot be smaller than 32 + if (BLOCK_N // EPILOGUE_SUBTILE) < 32: + continue + + # TMA loads must be at least 128 bytes. With BROADCAST_Y + # this may not be met. + if BROADCAST_Y and (BLOCK_N // EPILOGUE_SUBTILE) < 64: + continue + + # Prune the support SMEM_BUFFER configurations. + if BROADCAST_Y: + if NUM_MMA_GROUPS == 1 and NUM_SMEM_BUFFERS != 5: + continue + elif NUM_MMA_GROUPS == 2 and NUM_SMEM_BUFFERS != 4: + continue + else: + if NUM_MMA_GROUPS == 1 and NUM_SMEM_BUFFERS != 4: + continue + elif NUM_MMA_GROUPS == 2 and NUM_SMEM_BUFFERS != 3: + continue + + # PAIR_CTA requires even number of M tiles and even total tiles + num_tiles_m = math.ceil(M / BLOCK_M) if BLOCK_M > 0 else 0 + num_tiles_n = math.ceil(N / BLOCK_N) if BLOCK_N > 0 else 0 + total_tiles = num_tiles_m * num_tiles_n + + # PAIR_CTA incompatible with MMA M=64 + pair_cta_compatible = ( + (num_tiles_m % 2 == 0) + and (total_tiles % 2 == 0) + and BLOCK_M == 128 + and NUM_MMA_GROUPS == 1 + ) + + c.kwargs["PAIR_CTA"] = pair_cta_compatible + # Set ctas_per_cga for CUDA-native cluster launch semantics (TLX way) + c.ctas_per_cga = (2, 1, 1) if pair_cta_compatible else None + + pruned.append(c) + return pruned + + +def get_mm_configs(pre_hook=None) -> List[triton.Config]: + if torch.version.hip: + if ENABLE_FULL_TURNING_SPACE: + block_m_range = [32, 64, 128, 256] + block_n_range = [32, 64, 128, 256] + block_k_range = [32, 64] + group_m_range = [4, 8] + matrix_instr_nonkdim_range = [16] + waves_per_eu_range = [0] + kpack_range = [1, 2] + num_warps_range = [4, 8] + num_stage_range = [2] if triton.__version__ >= "3.2.0" else [0] + else: + block_m_range = [256] + block_n_range = [256] + block_k_range = [32] + group_m_range = [8] + matrix_instr_nonkdim_range = [16] + waves_per_eu_range = [0] + kpack_range = [2] + num_warps_range = [8] + num_stage_range = [2] if triton.__version__ >= "3.2.0" else [0] + + return [ + triton.Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "GROUP_M": group_m, + "matrix_instr_nonkdim": matrix_instr_nonkdim, + "waves_per_eu": waves_per_eu, + "kpack": kpack, + }, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=pre_hook, + ) + for block_m in block_m_range + for block_n in block_n_range + for block_k in block_k_range + for group_m in group_m_range + for matrix_instr_nonkdim in matrix_instr_nonkdim_range + for waves_per_eu in waves_per_eu_range + for kpack in kpack_range + for num_stages in num_stage_range + for num_warps in num_warps_range + ] + else: + block_m_range = [32, 64, 128, 256] + block_n_range = [32, 64, 128, 256] + block_k_range = [32, 64] + group_m_range = [4, 8] + # WARP_SPECIALIZE only works with num_warps >=4 + num_warps_range = [4, 8] if is_sm100_plus() else [2, 4, 8] + num_stage_range = [2, 3, 4, 5] + if ENABLE_FULL_TURNING_SPACE: + return [ + triton.Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "GROUP_M": group_m, + }, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=pre_hook, + ) + for block_m in block_m_range + for block_n in block_n_range + for block_k in block_k_range + for group_m in group_m_range + for num_stages in num_stage_range + for num_warps in num_warps_range + ] + else: + configs = [ + triton.Config( + { + "BLOCK_M": 32, + "BLOCK_N": 64, + "BLOCK_K": 32, + "GROUP_M": 8, + }, + num_stages=5, + num_warps=2, + pre_hook=pre_hook, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 256, + "BLOCK_K": 64, + "GROUP_M": 8, + }, + num_stages=3, + num_warps=8, + pre_hook=pre_hook, + ), + triton.Config( + { + "BLOCK_M": 64, + "BLOCK_N": 256, + "BLOCK_K": 32, + "GROUP_M": 8, + }, + num_stages=4, + num_warps=4, + pre_hook=pre_hook, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 128, + "BLOCK_K": 32, + "GROUP_M": 8, + }, + num_stages=4, + num_warps=4, + pre_hook=pre_hook, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "BLOCK_K": 32, + "GROUP_M": 8, + }, + num_stages=4, + num_warps=4, + pre_hook=pre_hook, + ), + triton.Config( + { + "BLOCK_M": 64, + "BLOCK_N": 128, + "BLOCK_K": 32, + "GROUP_M": 8, + }, + num_stages=4, + num_warps=4, + pre_hook=pre_hook, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 32, + "BLOCK_K": 32, + "GROUP_M": 8, + }, + num_stages=4, + num_warps=4, + pre_hook=pre_hook, + ), + triton.Config( + { + "BLOCK_M": 64, + "BLOCK_N": 32, + "BLOCK_K": 32, + "GROUP_M": 8, + }, + num_stages=5, + num_warps=2, + pre_hook=pre_hook, + ), + ] + if is_sm100_plus(): + configs += [ + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 256, + "BLOCK_K": 64, + "GROUP_M": 8, + }, + num_stages=3, + num_warps=4, + pre_hook=pre_hook, + ), + ] + return [c for c in configs if c.num_warps >= 4] + + return configs + + +def _get_addmm_tma_ws_persistent_configs(pre_hook=None) -> List[triton.Config]: + """Get configs for _addmm_fwd_tma_ws_persistent (sm100+ TLX kernel). + + This kernel has unique requirements (warp specialization, PAIR_CTA, + EPILOGUE_SUBTILE) that don't apply to the other addmm kernels. + """ + if ENABLE_FULL_TURNING_SPACE: + block_m_range = [64, 128, 256] + block_n_range = [64, 128, 256] + block_k_range = [64, 128, 256] + group_m_range = [8] + num_warps_range = [4] + num_stage_range = [1] + epilogue_subtile_range = [1, 2, 4] + num_mma_groups_range = [1, 2] + return [ + triton.Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "GROUP_M": group_m, + "EPILOGUE_SUBTILE": epilogue_subtile, + "NUM_MMA_GROUPS": num_mma_groups, + "NUM_TMEM_BUFFERS": 1 if num_mma_groups == 2 else 2, + "NUM_SMEM_BUFFERS": num_smem_buffers, + }, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=pre_hook, + ) + for block_m in block_m_range + for block_n in block_n_range + for block_k in block_k_range + for group_m in group_m_range + for num_stages in num_stage_range + for num_warps in num_warps_range + for epilogue_subtile in epilogue_subtile_range + for num_mma_groups in num_mma_groups_range + for num_smem_buffers in [3, 4, 5] + ] + else: + configs = [] + for block_m, block_n, block_k in [ + (128, 256, 64), + (128, 128, 64), + (64, 128, 64), + (64, 256, 64), + (128, 64, 128), + ]: + # Note: num_smem_buffers is pruned to 1 in + # the pruning function. + for num_smem_buffers in [3, 4, 5]: + configs.append( + triton.Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "GROUP_M": 8, + "EPILOGUE_SUBTILE": 1, + "NUM_MMA_GROUPS": 1, + "NUM_TMEM_BUFFERS": 2, + "NUM_SMEM_BUFFERS": num_smem_buffers, + }, + num_stages=1, + num_warps=4, + pre_hook=pre_hook, + ), + ) + return configs + + +def get_triton_persistent_configs(pre_hook=None) -> List[triton.Config]: + if not _use_meta_ws(): + configs = get_mm_configs(pre_hook=pre_hook) + for c in configs: + c.kwargs["DATA_PARTITION_FACTOR"] = 1 + c.kwargs["EPILOGUE_SUBTILE"] = 1 + return configs + # TODO: Prune configs to best configs. + return [ + triton.Config( # pyre-ignore[28] + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "GROUP_M": 8, + "EPILOGUE_SUBTILE": subtile, + "DATA_PARTITION_FACTOR": DP, + }, + num_stages=num_stages, + num_warps=4, + pre_hook=pre_hook, + early_tma_store_lowering=1, + maxRegAutoWS=255, + ) + for block_m in [64, 128, 256] + for block_n in [64, 128, 256] + for block_k in [64, 128, 256] + for num_stages in [2, 3, 4] + for subtile in [1, 2, 4, 8] + for DP in [1, 2] + ] + + +@triton_cc( + annotations={ + "M": "i32", + "N": ("i32", 16), + "K": ("i32", 16), + "stride_xm": ("i32", 16), + "stride_xk": ("i32", 1), + "stride_wk": ("i32", 16), + "stride_wn": ("i32", 1), + "stride_ym": ("i32", 16), + "stride_yn": ("i32", 1), + "stride_zm": ("i32", 16), + "stride_zn": ("i32", 1), + }, +) +@triton_autotune( + configs=get_mm_configs(), + key=["N", "K"], +) +@triton.jit +def _addmm_fwd( + x_ptr, + w_ptr, + y_ptr, + z_ptr, + M, + N, + K, + stride_xm, + stride_xk, + stride_wk, + stride_wn, + stride_ym, + stride_yn, + stride_zm, + stride_zn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BROADCAST_Y: tl.constexpr, +): + pid_0, pid_1 = tl.program_id(axis=0), tl.program_id(axis=1) + pid = pid_0 * tl.num_programs(axis=1) + pid_1 + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, BLOCK_K) + offs_n = tl.arange(0, BLOCK_N) + mask_m = (pid_m * BLOCK_M + offs_m)[:, None] < M + mask_n = (pid_n * BLOCK_N + offs_n)[None, :] < N + x_ptr += pid_m.to(tl.int64) * BLOCK_M * stride_xm + x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk) + w_ptr += pid_n.to(tl.int64) * BLOCK_N * stride_wn + w_ptrs = w_ptr + (offs_k[:, None] * stride_wk + offs_n[None, :] * stride_wn) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + mask_k = offs_k[None, :] < K - k * BLOCK_K + x = tl.load(x_ptrs, mask=mask_k & mask_m, other=0.0) + mask_k = offs_k[:, None] < K - k * BLOCK_K + w = tl.load(w_ptrs, mask=mask_k & mask_n, other=0.0) + accumulator += tl.dot(x, w, allow_tf32=ALLOW_TF32) + x_ptrs += BLOCK_K * stride_xk + w_ptrs += BLOCK_K * stride_wk + + z_mask = mask_m & mask_n + if BROADCAST_Y: + # y is a vector, broadcast to add to z + y_ptr += pid_n.to(tl.int64) * BLOCK_N * stride_yn + y_ptrs = y_ptr + stride_yn * offs_n[None, :] + y = tl.load(y_ptrs, mask=mask_n) + else: + y_ptr += pid_m.to(tl.int64) * BLOCK_M * stride_ym + y_ptr += pid_n.to(tl.int64) * BLOCK_N * stride_yn + y_ptrs = y_ptr + stride_ym * offs_m[:, None] + stride_yn * offs_n[None, :] + y = tl.load(y_ptrs, mask=z_mask) + z = (accumulator + y.to(tl.float32)).to(z_ptr.dtype.element_ty) + z_ptr += pid_m.to(tl.int64) * BLOCK_M * stride_zm + z_ptr += pid_n.to(tl.int64) * BLOCK_N * stride_zn + z_ptrs = z_ptr + stride_zm * offs_m[:, None] + stride_zn * offs_n[None, :] + tl.store(z_ptrs, z, mask=z_mask) + + +def _addmm_tma_set_block_size_hook(nargs): + BLOCK_M = nargs["BLOCK_M"] + BLOCK_N = nargs["BLOCK_N"] + BLOCK_K = nargs["BLOCK_K"] + NUM_MMA_GROUPS = nargs.get("NUM_MMA_GROUPS", 1) + BLOCK_M_SPLIT = BLOCK_M // NUM_MMA_GROUPS + PAIR_CTA = nargs.get("PAIR_CTA", False) + nargs["x_desc"].block_shape = [BLOCK_M_SPLIT, BLOCK_K] + # In PAIR_CTA mode, each CTA loads BLOCK_N // 2 of W + if PAIR_CTA: + nargs["w_desc"].block_shape = [BLOCK_K, BLOCK_N // 2] + else: + nargs["w_desc"].block_shape = [BLOCK_K, BLOCK_N] + EPILOGUE_SUBTILE = nargs.get("EPILOGUE_SUBTILE", 1) + nargs["z_desc"].block_shape = [BLOCK_M_SPLIT, BLOCK_N // EPILOGUE_SUBTILE] + if nargs["BROADCAST_Y"]: + nargs["y_desc"].block_shape = [1, BLOCK_N // EPILOGUE_SUBTILE] + else: + nargs["y_desc"].block_shape = [BLOCK_M_SPLIT, BLOCK_N // EPILOGUE_SUBTILE] + + +@triton.jit +def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS): + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + return pid_m, pid_n + + +@triton.jit +def _addmm_persistent_tile_body( + x_desc, + w_desc, + y_desc, + z_desc, + tile_id, + num_pid_in_group, + num_pid_m, + k_tiles, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BROADCAST_Y: tl.constexpr, + NUM_SMS: tl.constexpr, + EPILOGUE_SUBTILE: tl.constexpr, + INNER_WARP_SPECIALIZE: tl.constexpr, +): + pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_M, NUM_SMS) + offs_xm = pid_m * BLOCK_M + offs_wn = pid_n * BLOCK_N + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in tl.range(0, k_tiles, warp_specialize=INNER_WARP_SPECIALIZE): + offs_k = k * BLOCK_K + x = x_desc.load([offs_xm, offs_k]) + w = w_desc.load([offs_k, offs_wn]) + accumulator = tl.dot(x, w, accumulator, allow_tf32=ALLOW_TF32) + + # Epilogue subtiling breaks the store into multiple pieces to reduce + # shared memory consumption and allow higher stage counts. + tl.static_assert( + EPILOGUE_SUBTILE <= 8, + "EPILOGUE_SUBTILE > 8 is not supported", + ) + acc_subtiles = _split_n_2D(accumulator, EPILOGUE_SUBTILE) # pyre-ignore[16] + slice_size: tl.constexpr = BLOCK_N // EPILOGUE_SUBTILE + for i in tl.static_range(EPILOGUE_SUBTILE): + if BROADCAST_Y: + y_i = y_desc.load([0, offs_wn + i * slice_size]) + else: + y_i = y_desc.load([offs_xm, offs_wn + i * slice_size]) + z_i = (acc_subtiles[i] + y_i.to(tl.float32)).to(z_desc.dtype) + z_desc.store([offs_xm, offs_wn + i * slice_size], z_i) + + +@triton_autotune( + configs=get_triton_persistent_configs(pre_hook=_addmm_tma_set_block_size_hook), + key=["M", "N", "K", "WARP_SPECIALIZE"], + prune_configs_by={"early_config_prune": _prune_persistent_autows_configs}, +) +@triton.jit +def _addmm_fwd_tma_persistent( + x_desc, + w_desc, + y_desc, + z_desc, + M, + N, + K, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BROADCAST_Y: tl.constexpr, + WARP_SPECIALIZE: tl.constexpr, + NUM_SMS: tl.constexpr, + EPILOGUE_SUBTILE: tl.constexpr, + DATA_PARTITION_FACTOR: tl.constexpr, + USE_META_WS: tl.constexpr, +): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = num_pid_m * num_pid_n + + num_pid_in_group = GROUP_M * num_pid_n + + if USE_META_WS: + # Some arguments are only available in FBexperimental. + # pyre-ignore[28]: smem_alloc_algo is FBexperimental + for tile_id in tl.range( + start_pid, + num_tiles, + NUM_SMS, + flatten=False, + warp_specialize=WARP_SPECIALIZE, + data_partition_factor=DATA_PARTITION_FACTOR, + smem_alloc_algo=1, + ): + _addmm_persistent_tile_body( + x_desc, + w_desc, + y_desc, + z_desc, + tile_id, + num_pid_in_group, + num_pid_m, + k_tiles, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + GROUP_M=GROUP_M, + ALLOW_TF32=ALLOW_TF32, + BROADCAST_Y=BROADCAST_Y, + NUM_SMS=NUM_SMS, + EPILOGUE_SUBTILE=EPILOGUE_SUBTILE, + INNER_WARP_SPECIALIZE=tl.constexpr(False), + ) + else: + # Pure OAI Triton version. + for tile_id in tl.range( + start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=WARP_SPECIALIZE + ): + _addmm_persistent_tile_body( + x_desc, + w_desc, + y_desc, + z_desc, + tile_id, + num_pid_in_group, + num_pid_m, + k_tiles, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + GROUP_M=GROUP_M, + ALLOW_TF32=ALLOW_TF32, + BROADCAST_Y=BROADCAST_Y, + NUM_SMS=NUM_SMS, + EPILOGUE_SUBTILE=EPILOGUE_SUBTILE, + INNER_WARP_SPECIALIZE=WARP_SPECIALIZE, + ) + + +@triton_autotune( + configs=get_mm_configs(pre_hook=_addmm_tma_set_block_size_hook), + key=["N", "K"], +) +@triton.jit +def _addmm_fwd_tma_ws( + x_desc, + w_desc, + y_desc, + z_desc, + M, + N, + K, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BROADCAST_Y: tl.constexpr, + NUM_SMEM_BUFFERS: tl.constexpr, +): + x_buffers = tlx.local_alloc((BLOCK_M, BLOCK_K), x_desc.dtype, NUM_SMEM_BUFFERS) + w_buffers = tlx.local_alloc((BLOCK_K, BLOCK_N), w_desc.dtype, NUM_SMEM_BUFFERS) + acc_tmem_buffer = tlx.local_alloc( + (BLOCK_M, BLOCK_N), tl.float32, tl.constexpr(1), tlx.storage_kind.tmem + ) + + if BROADCAST_Y: + y_buffer = tlx.local_alloc((1, BLOCK_N), y_desc.dtype, tl.constexpr(1)) + else: + y_buffer = tlx.local_alloc((BLOCK_M, BLOCK_N), y_desc.dtype, tl.constexpr(1)) + z_buffer = tlx.local_alloc((BLOCK_M, BLOCK_N), z_desc.dtype, tl.constexpr(1)) + + smem_full_bars = tlx.alloc_barriers(num_barriers=NUM_SMEM_BUFFERS, arrive_count=1) + smem_empty_bars = tlx.alloc_barriers(num_barriers=NUM_SMEM_BUFFERS, arrive_count=1) + y_load_barrier = tlx.alloc_barriers(num_barriers=1, arrive_count=1) + + with tlx.async_tasks(): + # Producer task: TMA loads + with tlx.async_task("default"): + pid_0, pid_1 = tl.program_id(axis=0), tl.program_id(axis=1) + pid = pid_0 * tl.num_programs(axis=1) + pid_1 + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_xm = pid_m * BLOCK_M + offs_wn = pid_n * BLOCK_N + k_tiles = tl.cdiv(K, BLOCK_K) + + load_phase = 0 + for k in range(0, k_tiles): + buf = k % int(NUM_SMEM_BUFFERS) + + # Wait for buffer to be free + if k >= NUM_SMEM_BUFFERS: + tlx.barrier_wait(smem_empty_bars[buf], load_phase ^ 1) + + offs_k = k * BLOCK_K + tlx.barrier_expect_bytes( + smem_full_bars[buf], + 2 * (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N), + ) + tlx.async_descriptor_load( + x_desc, x_buffers[buf], [offs_xm, offs_k], smem_full_bars[buf] + ) + tlx.async_descriptor_load( + w_desc, w_buffers[buf], [offs_k, offs_wn], smem_full_bars[buf] + ) + + load_phase = load_phase ^ (buf == NUM_SMEM_BUFFERS - 1) + + # Consumer task: async_dot MMA + with tlx.async_task(num_warps=4, num_regs=232): + pid_0, pid_1 = tl.program_id(axis=0), tl.program_id(axis=1) + pid = pid_0 * tl.num_programs(axis=1) + pid_1 + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_xm = pid_m * BLOCK_M + offs_wn = pid_n * BLOCK_N + k_tiles = tl.cdiv(K, BLOCK_K) + + # Start async load of y early + y_buf_view = tlx.local_view(y_buffer, 0) + y_load_bar = tlx.local_view(y_load_barrier, 0) + if BROADCAST_Y: + tlx.barrier_expect_bytes(y_load_bar, 1 * BLOCK_N * 2) + tlx.async_descriptor_load(y_desc, y_buf_view, [0, offs_wn], y_load_bar) + else: + tlx.barrier_expect_bytes(y_load_bar, BLOCK_M * BLOCK_N * 2) + tlx.async_descriptor_load( + y_desc, y_buf_view, [offs_xm, offs_wn], y_load_bar + ) + + dot_phase = 0 + for k in range(0, k_tiles): + buf = k % int(NUM_SMEM_BUFFERS) + tlx.barrier_wait(smem_full_bars[buf], dot_phase) + + tlx.async_dot( + x_buffers[buf], + w_buffers[buf], + acc_tmem_buffer[0], + use_acc=k > 0, + mBarriers=[smem_empty_bars[buf]], + out_dtype=tl.float32, + ) + + dot_phase = dot_phase ^ (buf == NUM_SMEM_BUFFERS - 1) + + last_buf = (k_tiles - 1) % NUM_SMEM_BUFFERS + last_dot_phase = dot_phase ^ (last_buf == NUM_SMEM_BUFFERS - 1) + tlx.barrier_wait(smem_empty_bars[last_buf], last_dot_phase) + + tmem_result = tlx.local_load(acc_tmem_buffer[0]) + + tlx.barrier_wait(y_load_bar, 0) + y = tlx.local_load(y_buf_view) + + z = (tmem_result + y.to(tl.float32)).to(z_desc.dtype) + z_buf_view = tlx.local_view(z_buffer, 0) + tlx.local_store(z_buf_view, z) + tlx.async_descriptor_store(z_desc, z_buf_view, [offs_xm, offs_wn]) + tlx.async_descriptor_store_wait(0) + + +@triton_autotune( + configs=_get_addmm_tma_ws_persistent_configs( + pre_hook=_addmm_tma_set_block_size_hook + ), + key=["M", "N", "K"], + prune_configs_by={"early_config_prune": _prune_configs_for_tlx_persistent_addmm}, +) +@triton.jit +def _addmm_fwd_tma_ws_persistent( + x_desc, + w_desc, + y_desc, + z_desc, + M, + N, + K, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BROADCAST_Y: tl.constexpr, + NUM_SMEM_BUFFERS: tl.constexpr, + NUM_TMEM_BUFFERS: tl.constexpr, + NUM_SMS: tl.constexpr, + EPILOGUE_SUBTILE: tl.constexpr, + PAIR_CTA: tl.constexpr, + NUM_MMA_GROUPS: tl.constexpr, +): + BLOCK_M_SPLIT: tl.constexpr = BLOCK_M // NUM_MMA_GROUPS + + # Allocate buffers once for all tiles + x_buffers = tlx.local_alloc( + (BLOCK_M_SPLIT, BLOCK_K), x_desc.dtype, NUM_SMEM_BUFFERS * NUM_MMA_GROUPS + ) + # In pair CTA mode, each CTA only needs to load half of W + if PAIR_CTA: + w_buffers = tlx.local_alloc( + (BLOCK_K, BLOCK_N // 2), w_desc.dtype, NUM_SMEM_BUFFERS + ) + else: + w_buffers = tlx.local_alloc((BLOCK_K, BLOCK_N), w_desc.dtype, NUM_SMEM_BUFFERS) + tmem_buffers = tlx.local_alloc( + (BLOCK_M_SPLIT, BLOCK_N), + tl.float32, + NUM_TMEM_BUFFERS * NUM_MMA_GROUPS, + tlx.storage_kind.tmem, + ) + slice_size: tl.constexpr = BLOCK_N // EPILOGUE_SUBTILE + + Y_Z_SHARED: tl.constexpr = NUM_MMA_GROUPS == 2 and not BROADCAST_Y + if Y_Z_SHARED: + NUM_Z_BUFFERS: tl.constexpr = EPILOGUE_SUBTILE * NUM_MMA_GROUPS + else: + NUM_Z_BUFFERS: tl.constexpr = NUM_MMA_GROUPS + + if Y_Z_SHARED: + bias_storage_alias = tlx.storage_alias_spec(storage=tlx.storage_kind.smem) + y_buffers = tlx.local_alloc( + (BLOCK_M_SPLIT, slice_size), + y_desc.dtype, + NUM_Z_BUFFERS, + reuse=bias_storage_alias, + ) + z_buffers = tlx.local_alloc( + (BLOCK_M_SPLIT, slice_size), + z_desc.dtype, + NUM_Z_BUFFERS, + reuse=bias_storage_alias, + ) + # Define y and z to share a single buffer + bias_storage_alias.set_buffer_overlap( + tlx.reuse_group( + y_buffers, + z_buffers, + group_type=tlx.reuse_group_type.shared, + ) + ) + else: + if BROADCAST_Y: + y_buffers = tlx.local_alloc( + (1, slice_size), y_desc.dtype, EPILOGUE_SUBTILE * NUM_MMA_GROUPS + ) + else: + y_buffers = tlx.local_alloc( + (BLOCK_M_SPLIT, slice_size), + y_desc.dtype, + EPILOGUE_SUBTILE * NUM_MMA_GROUPS, + ) + z_buffers = tlx.local_alloc( + (BLOCK_M_SPLIT, slice_size), z_desc.dtype, NUM_Z_BUFFERS + ) + + cluster_cta_rank = tlx.cluster_cta_rank() + pred_cta0 = cluster_cta_rank == 0 + if PAIR_CTA: + cta_bars = tlx.alloc_barriers( + num_barriers=NUM_SMEM_BUFFERS * NUM_MMA_GROUPS, arrive_count=2 + ) + + # Barriers for producer <-> MMA (separate X and W barriers) + x_smem_full_bars = tlx.alloc_barriers( + num_barriers=NUM_SMEM_BUFFERS * NUM_MMA_GROUPS, arrive_count=1 + ) + x_smem_empty_bars = tlx.alloc_barriers( + num_barriers=NUM_SMEM_BUFFERS * NUM_MMA_GROUPS, arrive_count=1 + ) + w_smem_full_bars = tlx.alloc_barriers(num_barriers=NUM_SMEM_BUFFERS, arrive_count=1) + # Barriers for MMA <-> Epilogue + tmem_full_bars = tlx.alloc_barriers( + num_barriers=NUM_TMEM_BUFFERS * NUM_MMA_GROUPS, arrive_count=1 + ) + tmem_empty_bars = tlx.alloc_barriers( + num_barriers=NUM_TMEM_BUFFERS * NUM_MMA_GROUPS, arrive_count=1 + ) + # Barriers for producer <-> Epilogue + # y_load_bar: producer signals when y data is ready + # y_empty_bar: epilogue signals when done using y buffer + y_load_bars = tlx.alloc_barriers( + num_barriers=EPILOGUE_SUBTILE * NUM_MMA_GROUPS, arrive_count=1 + ) + y_empty_bars = tlx.alloc_barriers( + num_barriers=EPILOGUE_SUBTILE * NUM_MMA_GROUPS, arrive_count=1 + ) + z_load_bars = tlx.alloc_barriers(num_barriers=NUM_Z_BUFFERS, arrive_count=1) + z_empty_bars = tlx.alloc_barriers(num_barriers=NUM_Z_BUFFERS, arrive_count=1) + + with tlx.async_tasks(): + # Epilogue consumer: waits for Y from producer, adds bias, stores to SMEM. + with tlx.async_task("default"): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + if PAIR_CTA: + # Round up to even for proper CTA pairing + num_pid_m = (num_pid_m + 1) // 2 * 2 + num_pid_n = tl.cdiv(N, BLOCK_N) + num_tiles = num_pid_m * num_pid_n + + tmem_read_phase = 0 + cur_tmem_buf = 0 + y_load_phase = 0 + z_load_phase = 0 + + z_idx = 0 + for _ in range(start_pid, num_tiles, NUM_SMS): + for slice_id in tl.static_range(EPILOGUE_SUBTILE): + for group_id in tl.static_range(NUM_MMA_GROUPS): + buf_idx = group_id * NUM_TMEM_BUFFERS + cur_tmem_buf + acc_tmem = tmem_buffers[buf_idx] + if slice_id == 0: + # Wait for MMA to finish computing this group + tlx.barrier_wait(tmem_full_bars[buf_idx], tmem_read_phase) + + # Load result from TMEM and add bias + acc_subslice = tlx.subslice( + acc_tmem, slice_id * slice_size, slice_size + ) + result = tlx.local_load(acc_subslice) + if slice_id == EPILOGUE_SUBTILE - 1: + # Signal MMA that this TMEM buffer is now free + tlx.barrier_arrive(tmem_empty_bars[buf_idx], 1) + + y_idx = slice_id * NUM_MMA_GROUPS + group_id + y_buf_view = tlx.local_view(y_buffers, y_idx) + y_full = tlx.local_view(y_load_bars, y_idx) + tlx.barrier_wait(y_full, y_load_phase) + y = tlx.local_load(y_buf_view) + # If Y and Z are not shared signal we can load the next bias. + if not Y_Z_SHARED: + y_empty = tlx.local_view(y_empty_bars, y_idx) + tlx.barrier_arrive(y_empty, 1) + z = (result + y.to(tl.float32)).to(z_desc.dtype) + z_buf_view = tlx.local_view(z_buffers, z_idx) + # If Y and Z are not shared wait for Z to be empty. + # If there are shared this already guaranteed. + if not Y_Z_SHARED: + z_empty = tlx.local_view(z_empty_bars, z_idx) + tlx.barrier_wait(z_empty, z_load_phase ^ 1) + tlx.local_store(z_buf_view, z) + z_full = tlx.local_view(z_load_bars, z_idx) + tlx.barrier_arrive(z_full, 1) + z_load_phase = z_load_phase ^ (z_idx == (NUM_Z_BUFFERS - 1)) + # pyre-ignore[58] + z_idx = (z_idx + 1) % NUM_Z_BUFFERS + + tmem_read_phase = tmem_read_phase ^ ( + cur_tmem_buf == int(NUM_TMEM_BUFFERS) - 1 + ) + y_load_phase = y_load_phase ^ 1 + + cur_tmem_buf = (cur_tmem_buf + 1) % int(NUM_TMEM_BUFFERS) + + # MMA consumer: performs matrix multiplication + with tlx.async_task(num_warps=1, num_regs=24): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + if PAIR_CTA: + # Round up to even for proper CTA pairing + num_pid_m = (num_pid_m + 1) // 2 * 2 + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_M * num_pid_n + num_tiles = num_pid_m * num_pid_n + k_tiles = tl.cdiv(K, BLOCK_K) + + dot_phase = 0 + tmem_write_phase = 1 + cur_tmem_buf = 0 + processed_k_iters = 0 + + for tile_id in range(start_pid, num_tiles, NUM_SMS): + pid_m, pid_n = _compute_pid( + tile_id, num_pid_in_group, num_pid_m, GROUP_M, NUM_SMS + ) + + # First K iteration (peeled): use_acc=False + buf = processed_k_iters % int(NUM_SMEM_BUFFERS) + tlx.barrier_wait(w_smem_full_bars[buf], dot_phase) + + for group_id in tl.static_range(NUM_MMA_GROUPS): + a_buf = group_id * NUM_SMEM_BUFFERS + buf + acc_buf = group_id * NUM_TMEM_BUFFERS + cur_tmem_buf + + tlx.barrier_wait(x_smem_full_bars[a_buf], dot_phase) + + # Wait for epilogue to finish with this TMEM buffer + tlx.barrier_wait(tmem_empty_bars[acc_buf], tmem_write_phase) + + if PAIR_CTA: + # pyre-ignore[61] + tlx.barrier_arrive(cta_bars[a_buf], 1, remote_cta_rank=0) + # pyre-ignore[61] + tlx.barrier_wait( + # pyre-ignore[61] + cta_bars[a_buf], + phase=dot_phase, + pred=pred_cta0, + ) + + tlx.async_dot( + x_buffers[a_buf], + w_buffers[buf], + tmem_buffers[acc_buf], + use_acc=False, + mBarriers=[x_smem_empty_bars[a_buf]], + two_ctas=PAIR_CTA, + out_dtype=tl.float32, + ) + + dot_phase = dot_phase ^ (buf == int(NUM_SMEM_BUFFERS) - 1) + + # Remaining K iterations: use_acc=True + for k in range(1, k_tiles): + buf = (processed_k_iters + k) % int(NUM_SMEM_BUFFERS) + tlx.barrier_wait(w_smem_full_bars[buf], dot_phase) + + for group_id in tl.static_range(NUM_MMA_GROUPS): + a_buf = group_id * NUM_SMEM_BUFFERS + buf + acc_buf = group_id * NUM_TMEM_BUFFERS + cur_tmem_buf + + tlx.barrier_wait(x_smem_full_bars[a_buf], dot_phase) + + if PAIR_CTA: + # pyre-ignore[61] + tlx.barrier_arrive(cta_bars[a_buf], 1, remote_cta_rank=0) + # pyre-ignore[61] + tlx.barrier_wait( + # pyre-ignore[61] + cta_bars[a_buf], + phase=dot_phase, + # pyre-ignore[61] + pred=pred_cta0, + ) + + tlx.async_dot( + x_buffers[a_buf], + w_buffers[buf], + tmem_buffers[acc_buf], + use_acc=True, + mBarriers=[x_smem_empty_bars[a_buf]], + two_ctas=PAIR_CTA, + out_dtype=tl.float32, + ) + + dot_phase = dot_phase ^ (buf == int(NUM_SMEM_BUFFERS) - 1) + + # Wait for last MMA to complete and signal epilogue + last_buf = (processed_k_iters + k_tiles - 1) % int(NUM_SMEM_BUFFERS) + last_dot_phase = dot_phase ^ (last_buf == int(NUM_SMEM_BUFFERS) - 1) + for group_id in tl.static_range(NUM_MMA_GROUPS): + a_buf = group_id * NUM_SMEM_BUFFERS + last_buf + tlx.barrier_wait(x_smem_empty_bars[a_buf], last_dot_phase) + acc_buf = group_id * NUM_TMEM_BUFFERS + cur_tmem_buf + # Signal epilogue that result is ready + tlx.barrier_arrive(tmem_full_bars[acc_buf], 1) + + tmem_write_phase = tmem_write_phase ^ ( + cur_tmem_buf == int(NUM_TMEM_BUFFERS) - 1 + ) + cur_tmem_buf = (cur_tmem_buf + 1) % int(NUM_TMEM_BUFFERS) + processed_k_iters += k_tiles + + # Producer: TMA loads for X, W, and Y + with tlx.async_task(num_warps=1, num_regs=24): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + if PAIR_CTA: + # Round up to even for proper CTA pairing + num_pid_m = (num_pid_m + 1) // 2 * 2 + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_M * num_pid_n + num_tiles = num_pid_m * num_pid_n + k_tiles = tl.cdiv(K, BLOCK_K) + + load_phase = 0 + y_load_phase = 0 + processed_k_iters = 0 + + for tile_id in range(start_pid, num_tiles, NUM_SMS): + pid_m, pid_n = _compute_pid( + tile_id, num_pid_in_group, num_pid_m, GROUP_M, NUM_SMS + ) + offs_xm = pid_m * BLOCK_M + # Full tile offset for y loading (both CTAs use same y) + offs_wn_full = pid_n * BLOCK_N + # Split W into two parts so each CTA has different offset + if PAIR_CTA: + # pyre-ignore[61] + offs_wn = pid_n * BLOCK_N + cluster_cta_rank * (BLOCK_N // 2) + else: + offs_wn = pid_n * BLOCK_N + + for k in range(0, k_tiles): + buf = (processed_k_iters + k) % int(NUM_SMEM_BUFFERS) + offs_k = k * BLOCK_K + + # Load X for group 0 + a_buf = buf # 0 * NUM_SMEM_BUFFERS + buf + tlx.barrier_wait(x_smem_empty_bars[a_buf], load_phase ^ 1) + tlx.barrier_expect_bytes( + x_smem_full_bars[a_buf], + 2 * BLOCK_M_SPLIT * BLOCK_K, + ) + tlx.async_descriptor_load( + x_desc, + x_buffers[a_buf], + [offs_xm, offs_k], + x_smem_full_bars[a_buf], + ) + + # Load W (wait for last group's x_empty to know W is free) + last_a_buf = (NUM_MMA_GROUPS - 1) * NUM_SMEM_BUFFERS + buf + tlx.barrier_wait(x_smem_empty_bars[last_a_buf], load_phase ^ 1) + if PAIR_CTA: + tlx.barrier_expect_bytes( + w_smem_full_bars[buf], + 2 * BLOCK_K * (BLOCK_N // 2), + ) + else: + tlx.barrier_expect_bytes( + w_smem_full_bars[buf], + 2 * BLOCK_K * BLOCK_N, + ) + tlx.async_descriptor_load( + w_desc, + w_buffers[buf], + [offs_k, offs_wn], + w_smem_full_bars[buf], + ) + + # Load X for remaining groups + for group_id in tl.static_range(1, NUM_MMA_GROUPS): + a_buf = group_id * NUM_SMEM_BUFFERS + buf + tlx.barrier_wait(x_smem_empty_bars[a_buf], load_phase ^ 1) + offs_xm2 = offs_xm + group_id * BLOCK_M_SPLIT + tlx.barrier_expect_bytes( + x_smem_full_bars[a_buf], + 2 * BLOCK_M_SPLIT * BLOCK_K, + ) + tlx.async_descriptor_load( + x_desc, + x_buffers[a_buf], + [offs_xm2, offs_k], + x_smem_full_bars[a_buf], + ) + + load_phase = load_phase ^ (buf == int(NUM_SMEM_BUFFERS) - 1) + + for slice_id in tl.static_range(EPILOGUE_SUBTILE): + for group_id in tl.static_range(NUM_MMA_GROUPS): + y_idx = slice_id * NUM_MMA_GROUPS + group_id + y_buf_view = tlx.local_view(y_buffers, y_idx) + y_bar = tlx.local_view(y_load_bars, y_idx) + # If Y and Z are shared we need to wait for Z to be empty. + if Y_Z_SHARED: + y_empty = tlx.local_view(z_empty_bars, y_idx) + else: + y_empty = tlx.local_view(y_empty_bars, y_idx) + tlx.barrier_wait(y_empty, y_load_phase ^ 1) + if BROADCAST_Y: + tlx.barrier_expect_bytes(y_bar, 1 * slice_size * 2) + tlx.async_descriptor_load( + y_desc, + y_buf_view, + [0, offs_wn_full + slice_id * slice_size], + y_bar, + ) + else: + tlx.barrier_expect_bytes( + y_bar, BLOCK_M_SPLIT * slice_size * 2 + ) + tlx.async_descriptor_load( + y_desc, + y_buf_view, + [ + offs_xm + group_id * BLOCK_M_SPLIT, + offs_wn_full + slice_id * slice_size, + ], + y_bar, + ) + + y_load_phase = y_load_phase ^ 1 + + processed_k_iters += k_tiles + + # TMA Store consumer. Added to simplify the barrier + # logic. + with tlx.async_task(num_warps=1, num_regs=24): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + if PAIR_CTA: + # Round up to even for proper CTA pairing + num_pid_m = (num_pid_m + 1) // 2 * 2 + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_M * num_pid_n + num_tiles = num_pid_m * num_pid_n + z_load_phase = 0 + + # Unroll the first iteration. + # This guraranteed safe from our grid size. + pid_m, pid_n = _compute_pid( + start_pid, num_pid_in_group, num_pid_m, GROUP_M, NUM_SMS + ) + offs_xm = pid_m * BLOCK_M + offs_wn = pid_n * BLOCK_N + z_idx = 0 + for slice_id in tl.static_range(EPILOGUE_SUBTILE): + for group_id in tl.static_range(NUM_MMA_GROUPS): + # Determine the base "index" to decide if we need to wait on TMA. + z_idx_unrolled = slice_id * NUM_MMA_GROUPS + group_id + if z_idx_unrolled >= NUM_Z_BUFFERS: + tlx.async_descriptor_store_wait(NUM_Z_BUFFERS - 1) + z_empty = tlx.local_view(z_empty_bars, z_idx) + tlx.barrier_arrive(z_empty, 1) + + z_full = tlx.local_view(z_load_bars, z_idx) + tlx.barrier_wait(z_full, z_load_phase) + z_buf_view = tlx.local_view(z_buffers, z_idx) + tlx.fence_async_shared() + tlx.async_descriptor_store( + z_desc, + z_buf_view, + [ + offs_xm + group_id * BLOCK_M_SPLIT, + offs_wn + slice_id * slice_size, + ], + ) + + z_load_phase = z_load_phase ^ (z_idx == (NUM_Z_BUFFERS - 1)) + # pyre-ignore[58] + z_idx = (z_idx + 1) % NUM_Z_BUFFERS + + for tile_id in range(start_pid + NUM_SMS, num_tiles, NUM_SMS): + pid_m, pid_n = _compute_pid( + tile_id, num_pid_in_group, num_pid_m, GROUP_M, NUM_SMS + ) + offs_xm = pid_m * BLOCK_M + offs_wn = pid_n * BLOCK_N + for slice_id in tl.static_range(EPILOGUE_SUBTILE): + for group_id in tl.static_range(NUM_MMA_GROUPS): + # Wait on prior store to finish. + tlx.async_descriptor_store_wait(NUM_Z_BUFFERS - 1) + z_empty = tlx.local_view(z_empty_bars, z_idx) + tlx.barrier_arrive(z_empty, 1) + # Wait for the next load to be ready + z_full = tlx.local_view(z_load_bars, z_idx) + tlx.barrier_wait(z_full, z_load_phase) + z_buf_view = tlx.local_view(z_buffers, z_idx) + tlx.async_descriptor_store( + z_desc, + z_buf_view, + [ + offs_xm + group_id * BLOCK_M_SPLIT, + offs_wn + slice_id * slice_size, + ], + ) + z_load_phase = z_load_phase ^ (z_idx == (NUM_Z_BUFFERS - 1)) + # pyre-ignore[58] + z_idx = (z_idx + 1) % NUM_Z_BUFFERS + + # Wait for the last store. + tlx.async_descriptor_store_wait(0) + + +@torch.fx.wrap +def triton_addmm_fwd_tma_persistent( + x: torch.Tensor, + w: torch.Tensor, + y: torch.Tensor, + warp_specialize: bool | None = None, +) -> torch.Tensor: + _meta_ws = _use_meta_ws() + if warp_specialize is None: + warp_specialize = _meta_ws + + M, K = x.shape + _, N = w.shape + + is_y_1d = y.dim() == 1 + + # Allocate output + z = torch.empty((M, N), device=x.device, dtype=x.dtype) + if M == 0 or N == 0: + return z + + # A dummy block value that will be overwritten when we have the real block size + dummy_block = [1, 1] + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + x_desc = TensorDescriptor(x, x.shape, x.stride(), dummy_block) + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + w_desc = TensorDescriptor(w, w.shape, w.stride(), dummy_block) + y = y.reshape(1, -1) if is_y_1d else y + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + y_desc = TensorDescriptor(y, y.shape, y.stride(), dummy_block) + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + z_desc = TensorDescriptor(z, z.shape, z.stride(), dummy_block) + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + def grid(meta): + BLOCK_M = meta["BLOCK_M"] + BLOCK_N = meta["BLOCK_N"] + return ( + min( + NUM_SMS, + triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), + ), + ) + + _addmm_fwd_tma_persistent[grid]( + x_desc, + w_desc, + y_desc, + z_desc, + M, + N, + K, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + BROADCAST_Y=is_y_1d, + WARP_SPECIALIZE=warp_specialize, + NUM_SMS=NUM_SMS, + USE_META_WS=_meta_ws, + ) + return z + + +@torch.fx.wrap +def triton_addmm_fwd_tma_ws_tlx( + x: torch.Tensor, + w: torch.Tensor, + y: torch.Tensor, +) -> torch.Tensor: + M, K = x.shape + _, N = w.shape + + is_y_1d = y.dim() == 1 + + # Allocate output + z = torch.empty((M, N), device=x.device, dtype=x.dtype) + if M == 0 or N == 0: + return z + + # A dummy block value that will be overwritten when we have the real block size + dummy_block = [1, 1] + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + x_desc = TensorDescriptor(x, x.shape, x.stride(), dummy_block) + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + w_desc = TensorDescriptor(w, w.shape, w.stride(), dummy_block) + y = y.reshape(1, -1) if is_y_1d else y + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + y_desc = TensorDescriptor(y, y.shape, y.stride(), dummy_block) + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + z_desc = TensorDescriptor(z, z.shape, z.stride(), dummy_block) + + def grid(meta): + BLOCK_M = meta["BLOCK_M"] + BLOCK_N = meta["BLOCK_N"] + return ( + triton.cdiv(M, BLOCK_M), + triton.cdiv(N, BLOCK_N), + ) + + _addmm_fwd_tma_ws[grid]( + x_desc, + w_desc, + y_desc, + z_desc, + M, + N, + K, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + BROADCAST_Y=is_y_1d, + NUM_SMEM_BUFFERS=2, # Double buffering + ) + return z + + +@torch.fx.wrap +def triton_addmm_fwd_tma_ws_persistent_tlx( + x: torch.Tensor, + w: torch.Tensor, + y: torch.Tensor, +) -> torch.Tensor: + M, K = x.shape + _, N = w.shape + + is_y_1d = y.dim() == 1 + + # Allocate output + z = torch.empty((M, N), device=x.device, dtype=x.dtype) + if M == 0 or N == 0: + return z + + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + # A dummy block value that will be overwritten by the hook + dummy_block = [1, 1] + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + x_desc = TensorDescriptor(x, x.shape, x.stride(), dummy_block) + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + w_desc = TensorDescriptor(w, w.shape, w.stride(), dummy_block) + y = y.reshape(1, -1) if is_y_1d else y + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + y_desc = TensorDescriptor(y, y.shape, y.stride(), dummy_block) + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + z_desc = TensorDescriptor(z, z.shape, z.stride(), dummy_block) + + def grid(meta): + BLOCK_M = meta["BLOCK_M"] + BLOCK_N = meta["BLOCK_N"] + num_pid_m = triton.cdiv(M, BLOCK_M) + num_pid_n = triton.cdiv(N, BLOCK_N) + # Round up num_pid_m to even for PAIR_CTA cluster compatibility + num_pid_m = (num_pid_m + 1) // 2 * 2 + total_tiles = num_pid_m * num_pid_n + grid_size = min(NUM_SMS, total_tiles) + # Ensure grid is even for cluster compatibility + if grid_size % 2 == 1: + grid_size = min(grid_size + 1, NUM_SMS) + # If rounding up exceeds NUM_SMS and NUM_SMS is odd, round down instead + if grid_size % 2 == 1: + grid_size = grid_size - 1 + return (grid_size,) + + _addmm_fwd_tma_ws_persistent[grid]( + x_desc, + w_desc, + y_desc, + z_desc, + M, + N, + K, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + BROADCAST_Y=is_y_1d, + NUM_SMS=NUM_SMS, + ) + return z + + +@maybe_register_custom_op("generative_recommenders::triton_addmm_fwd", mutates_args=()) +def triton_addmm_fwd( + x: torch.Tensor, + w: torch.Tensor, + y: torch.Tensor, +) -> torch.Tensor: + M, K = x.shape + KB, N = w.shape + assert K == KB, f"incompatible dimensions {K}, {KB}" + + is_y_1d = y.dim() == 1 + NY = y.shape[0] if is_y_1d else y.shape[1] + assert N == NY, f"incompatible dimensions {N}, {NY}" + + # Allocate output + z = torch.empty((M, N), device=x.device, dtype=x.dtype) + if M == 0 or N == 0: + return z + + grid = lambda meta: ( # noqa E731 + triton.cdiv(M, meta["BLOCK_M"]), + triton.cdiv(N, meta["BLOCK_N"]), + ) + + _addmm_fwd[grid]( + x, + w, + y, + z, + M, + N, + K, + x.stride(0), + x.stride(1), + w.stride(0), + w.stride(1), + y.stride(0) if not is_y_1d else 0, + y.stride(1) if not is_y_1d else y.stride(0), + z.stride(0), + z.stride(1), + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + BROADCAST_Y=is_y_1d, + ) + return z + + +@triton_addmm_fwd.register_fake +def triton_addmm_fwd_fake( + x: torch.Tensor, + w: torch.Tensor, + y: torch.Tensor, +) -> torch.Tensor: + """Fake implementation for FakeTensor tracing.""" + M, _ = x.shape + _, N = w.shape + return torch.empty((M, N), device=x.device, dtype=x.dtype) + + +def triton_addmm_bwd( + x: torch.Tensor, + w: torch.Tensor, + dz: torch.Tensor, + is_y_1d: bool, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if is_y_1d: + dy = torch.sum(dz, dim=0) + else: + dy = dz + dw = torch.mm(x.t(), dz) + dx = torch.mm(dz, w.t()) + + return dx, dw, dy + + +@maybe_register_custom_op( + "generative_recommenders::maybe_triton_addmm_fwd", mutates_args=() +) +def maybe_triton_addmm_fwd( + x: torch.Tensor, + w: torch.Tensor, + y: Optional[torch.Tensor], +) -> torch.Tensor: + # triton addmm is slower than torch (cublas) on AMD/Blackwell. + # Default to pytorch addmm on AMD/Blackwell for now. + if y is None: + return torch.mm(x, w) + if is_sm100_plus() or torch.version.hip is not None: + return torch.addmm(y, x, w) + else: + return triton_addmm_fwd(x=x, w=w, y=y) + + +@maybe_triton_addmm_fwd.register_fake +def maybe_triton_addmm_fwd_fake( + x: torch.Tensor, + w: torch.Tensor, + y: torch.Tensor, +) -> torch.Tensor: + """Fake implementation for FakeTensor tracing.""" + M, _ = x.shape + _, N = w.shape + return torch.empty((M, N), device=x.device, dtype=x.dtype) + + +class _AddMmFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + x: torch.Tensor, + w: torch.Tensor, + y: torch.Tensor, + ) -> torch.Tensor: + ctx.save_for_backward(x, w) + ctx.is_y_1d = y.dim() == 1 + if is_sm100_plus() and TMA_AVAILABLE and _check_tma_alignment(x, w, y): + if x.dtype == torch.float32 or HAS_TLX == False: + return triton_addmm_fwd_tma_persistent(x, w, y, warp_specialize=True) + else: + return triton_addmm_fwd_tma_ws_persistent_tlx( + x, w, y + ) # tlx.async_dot doesn't support fp32 inputs because of WGMMA requirements + else: + return triton_addmm_fwd(x, w, y) + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, dz: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + (x, w) = ctx.saved_tensors + return triton_addmm_bwd(x, w, dz, ctx.is_y_1d) + + +def triton_addmm( + input: torch.Tensor, + mat1: torch.Tensor, + mat2: torch.Tensor, +) -> torch.Tensor: + return _AddMmFunction.apply(mat1, mat2, input) diff --git a/recommendation_v4/generative_recommenders/ops/triton/triton_attention_utils.py b/recommendation_v4/generative_recommenders/ops/triton/triton_attention_utils.py new file mode 100644 index 000000000..61fd614f3 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/triton/triton_attention_utils.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-unsafe + +#!/usr/bin/env python3 + + +# @manual=//triton:triton +import triton + +# @manual=//triton:triton +import triton.language as tl + + +@triton.jit +def acc_dq( + dq_ptrs_trans, + start_m, + stride_dqm, + k, + dqk_trans, + alpha, + mask_m, + MAX_SEQ_LEN, + LOCK, + BLOCK_M: tl.constexpr, + ATOMIC_ADD: tl.constexpr, + ALLOW_TF32: tl.constexpr, +): + if ATOMIC_ADD: + lock_id = start_m // BLOCK_M + stride_lock = tl.cdiv(MAX_SEQ_LEN, BLOCK_M) + lock = LOCK + tl.program_id(0) * stride_lock + lock_id + tl.debug_barrier() # add a barrier to force sync + while tl.atomic_cas(lock, 0, 1) == 1: + pass + dq_trans = tl.load( + dq_ptrs_trans + start_m * stride_dqm, + mask=mask_m[None, :], + other=0.0, + eviction_policy="evict_last", + ) + dq_trans += tl.dot(tl.trans(k), dqk_trans, allow_tf32=ALLOW_TF32) * alpha + dq_trans = dq_trans.to(k.dtype) + tl.store( + dq_ptrs_trans + start_m * stride_dqm, + dq_trans, + mask=mask_m[None, :], + eviction_policy="evict_last", + ) + if ATOMIC_ADD: + tl.atomic_xchg(lock, 0) # pyre-ignore [61] diff --git a/recommendation_v4/generative_recommenders/ops/triton/triton_hstu_attention.py b/recommendation_v4/generative_recommenders/ops/triton/triton_hstu_attention.py new file mode 100644 index 000000000..768ef0013 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/triton/triton_hstu_attention.py @@ -0,0 +1,3242 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#!/usr/bin/env python3 + +# pyre-unsafe + +from typing import List, Optional, Tuple + +import torch + +# @manual=//triton:triton +import triton + +# @manual=//triton:triton +import triton.language as tl +from generative_recommenders.ops.utils import ( + copy_if_different_ptr, + maybe_register_custom_op, +) + +try: + # @manual=//triton:triton + import triton.language.extra.tlx as tlx # type: ignore + + HAS_TLX = True +except ImportError: + # suppress type checking errors + tlx = None + + HAS_TLX = False + +from generative_recommenders.common import ( + autotune_max_seq_len, + prev_power_of_2, + switch_to_contiguous_if_needed, + triton_autotune, +) +from generative_recommenders.ops.triton._autotune_pinning import pinned_or_full +from triton.language.extra.libdevice import ( # @manual=//triton:triton + fast_dividef, + fast_expf, +) + +try: + # @manual=//triton:triton + from triton.tools.tensor_descriptor import TensorDescriptor + + tensor_descriptor_tma = True +except ImportError: + tensor_descriptor_tma = False + +try: + from generative_recommenders.ops.triton.fb.triton_attention_utils import acc_dq +except ImportError: + from generative_recommenders.ops.triton.triton_attention_utils import acc_dq + + +def _host_descriptor_pre_hook(nargs): + if not tensor_descriptor_tma: + return + + if not isinstance(nargs["Q"], TensorDescriptor): + return + BLOCK_M = nargs["BLOCK_M"] + BLOCK_N = nargs["BLOCK_N"] + BLOCK_D_Q = nargs["BLOCK_D_Q"] + BLOCK_D_V = nargs["BLOCK_D_V"] + if "USE_TLX" in nargs and nargs["USE_TLX"]: + BLOCK_M = BLOCK_M // nargs["NUM_MMA_GROUPS"] + nargs["Q"].block_shape = [BLOCK_M, BLOCK_D_Q] + nargs["V"].block_shape = [BLOCK_N, BLOCK_D_V] + nargs["K"].block_shape = [BLOCK_N, BLOCK_D_Q] + + +# pyre-ignore[2] +def _early_config_prune( + configs: List[triton.Config], + named_args, + **kwargs, +) -> List[triton.Config]: + """Filter autotune configs that are incompatible with the current call. + + The TLX (warp-specialized) variant of ``_hstu_attn_fwd`` calls + ``tlx.async_descriptor_load(Q, ...)`` which requires Q/K/V to be real TMA + tensor descriptors (``tl.tensor_descriptor_base``). They are only + constructed by the host wrapper when ``ENABLE_TMA=True`` AND the host + ``TensorDescriptor`` API is importable. If the kernel is invoked without + those preconditions, raw tensors flow into the TLX path and the + ``isinstance(desc, tl.tensor_descriptor_base)`` assert in + ``triton/language/extra/tlx/mem_ops.py`` fires at compile time. + + We make autotuning robust to that mismatch by dropping any config with + ``USE_TLX=True`` whenever ENABLE_TMA is not set or TMA host descriptors + are unavailable. This is purely defensive: if the caller threads + ``enable_tma=True`` (see ``_should_enable_tma`` below) the TLX configs + remain eligible. + """ + enable_tma = kwargs.get("ENABLE_TMA", None) + if enable_tma is None: + enable_tma = named_args.get("ENABLE_TMA", False) + if enable_tma and tensor_descriptor_tma: + return configs + pruned = [c for c in configs if not c.kwargs.get("USE_TLX", False)] + # Safety: never return an empty config list. + return pruned if pruned else configs + + +def _should_enable_tma() -> bool: + """Return True iff the TMA / TLX fast path can be safely enabled. + + Conditions: + * The host ``triton.tools.tensor_descriptor.TensorDescriptor`` API is + importable (``tensor_descriptor_tma``). + * CUDA is available and the device is Hopper (compute capability 9), + which is the only architecture for which TLX configs are emitted in + ``_get_fw_configs``. + """ + if not tensor_descriptor_tma: + return False + if not torch.cuda.is_available(): + return False + # NVIDIA-only gate: TMA (Tensor Memory Accelerator) is Hopper-specific + # hardware. On ROCm/HIP, `torch.cuda.get_device_capability()` mirrors the + # gfx name into a major.minor tuple — gfx950 (MI350X) returns (9, 5), which + # would otherwise pass the `device_capability == 9` check below and trick + # the kernel into taking the TMA path. The TMA path uses + # `triton.tools.tensor_descriptor.TensorDescriptor` and `TensorDescriptor.load` + # which lower to PTX `cp.async.bulk.tensor.*`; on AMD this either fails to + # compile or produces a kernel with mismatched reduction-dim shapes for + # `tl.dot(silu, v)` in `_hstu_attn_fwd_one_block` (see WARNING in + # `_hstu_attn_fwd` for the cascade). Bail out early on HIP so the + # non-TMA path is selected and AMD gets a working kernel. + if torch.version.hip: + return False + try: + device_capability = torch.cuda.get_device_capability()[0] + except (RuntimeError, AssertionError): + return False + return device_capability == 9 + + +def _get_fw_configs() -> List[triton.Config]: # noqa: C901 + configs = [] + if torch.version.hip: + for BLOCK_M in [32, 64, 128]: + for BLOCK_N in [32, 64]: + for num_stages in [1, 2]: + for num_warps in [4, 8]: + for matrix_instr_nonkdim in [16, 32]: + configs.append( + triton.Config( + { + "BLOCK_M": BLOCK_M, + "BLOCK_N": BLOCK_N, + "matrix_instr_nonkdim": matrix_instr_nonkdim, + "waves_per_eu": 0, + "kpack": 2, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + else: + configs = [ + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 32}, + num_stages=2, + num_warps=2, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32}, + num_stages=2, + num_warps=2, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32}, + num_stages=4, + num_warps=2, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32}, + num_stages=2, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32}, + num_stages=4, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64}, + num_stages=2, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64}, + num_stages=4, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64}, + num_stages=4, + num_warps=8, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 128}, + num_stages=2, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 128}, + num_stages=2, + num_warps=8, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32}, + num_stages=4, + num_warps=2, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32}, + num_stages=2, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32}, + num_stages=4, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32}, + num_stages=2, + num_warps=8, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64}, + num_stages=2, + num_warps=2, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64}, + num_stages=2, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64}, + num_stages=4, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64}, + num_stages=4, + num_warps=8, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32}, + num_stages=2, + num_warps=2, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32}, + num_stages=4, + num_warps=2, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32}, + num_stages=2, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32}, + num_stages=4, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32}, + num_stages=2, + num_warps=8, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32}, + num_stages=4, + num_warps=8, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64}, + num_stages=2, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64}, + num_stages=2, + num_warps=8, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64}, + num_stages=4, + num_warps=8, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128}, + num_stages=4, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128}, + num_stages=2, + num_warps=8, + pre_hook=_host_descriptor_pre_hook, + ), + ] + + # The `_hstu_attn_fwd` kernel signature unconditionally declares the four + # constexprs `USE_TLX`, `NUM_BUFFERS`, `NUM_MMA_WARPS_PER_GROUP`, + # `NUM_MMA_GROUPS` (introduced for the Hopper TLX warp-specialized variant). + # Triton requires every constexpr be bound at autotune time; missing any one + # of them triggers `TypeError: dynamic_func() missing N required positional + # arguments` during kernel dispatch. This loop populates the non-TLX defaults + # so the kernel call site doesn't have to know about TLX at all. + # + # IMPORTANT: this loop must apply to BOTH the HIP branch and the CUDA branch + # above. It used to live inside the CUDA `else:` block which meant HIP + # configs reached `_hstu_attn_fwd[grid](...)` without these defaults and + # crashed at dispatch. Keep this hoisted (outside the if/else) when + # editing — see commit message for the symptom. + for config in configs: + if not config.kwargs.get("USE_TLX", False): + config.kwargs["USE_TLX"] = False + config.kwargs["NUM_BUFFERS"] = 1 + config.kwargs["NUM_MMA_WARPS_PER_GROUP"] = 1 + config.kwargs["NUM_MMA_GROUPS"] = 1 + + # TLX (Triton Language Extension) warp-specialized configs are Hopper-only. + # Guard with `not torch.version.hip` so AMD never sees them — the TLX code + # path inside `_hstu_attn_fwd` calls `tlx.async_descriptor_load(...)` which + # requires real TMA tensor descriptors and only compiles on CUDA. + if not torch.version.hip: + if HAS_TLX: + try: + device_capability = torch.cuda.get_device_capability()[0] + except (RuntimeError, AssertionError): + # No CUDA device available + device_capability = None + + if device_capability == 9: + # H100 configs + configs.append( + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "USE_TLX": True, + "NUM_BUFFERS": 2, + "NUM_MMA_WARPS_PER_GROUP": 4, + "NUM_MMA_GROUPS": 2, + }, + num_stages=0, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + ) + + return configs + + +@triton.jit +def _hstu_attn_fwd_one_block( # noqa: C901 + start_n, + seq_len, + offs_m, + offs_n, + q, + K, + V, + K_block_ptr, + V_block_ptr, + offset_kh, + offset_vh, + seq_start, + n_targets, + alpha, + MAX_SEQ_LEN, + contextual_seq_len, + max_attn_len, + HAS_MULTIPLE_TARGETS: tl.constexpr, + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, + BLOCK_N: tl.constexpr, + ENABLE_TMA: tl.constexpr, +): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = None + qk = None + if ENABLE_TMA: + k = K.load( + [(seq_start + start_n).to(tl.int32), offset_kh.to(tl.int32)], + ) + # tma can only be loaded in one order, use trans afterwards + qk = tl.dot(q, tl.trans(k), allow_tf32=ALLOW_TF32) * alpha + else: + k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero") + qk = tl.dot(q, k, allow_tf32=ALLOW_TF32) * alpha + invalid_mask = offs_m[:, None] == offs_n[None, :] + max_ids = seq_len + if HAS_CONTEXTUAL_SEQ_LEN: + offs_m = offs_m - contextual_seq_len + 1 + offs_m = tl.where( + offs_m > 0, + offs_m, + 0, + ) + offs_n = offs_n - contextual_seq_len + 1 + offs_n = tl.where( + offs_n > 0, + offs_n, + 0, + ) + max_ids = max_ids - contextual_seq_len + 1 + if HAS_MULTIPLE_TARGETS: + max_ids = max_ids - n_targets + offs_m = tl.where( + offs_m < max_ids, + offs_m, + max_ids, + ) + offs_n = tl.where( + offs_n < max_ids, + offs_n, + max_ids, + ) + offs_m_minus_n = offs_m[:, None] - offs_n[None, :] + invalid_mask = invalid_mask or (offs_m_minus_n > 0) + if HAS_MAX_ATTN_LEN: + invalid_mask = invalid_mask and offs_m_minus_n <= max_attn_len + if HAS_CONTEXTUAL_SEQ_LEN: + invalid_mask = invalid_mask or ( + offs_m[:, None] == 0 and offs_n[None, :] < max_ids + ) + scale = tl.where(invalid_mask, (1.0 / MAX_SEQ_LEN), 0.0) + silu = fast_dividef(qk, 1.0 + fast_expf(-qk)) * scale + v = None + if ENABLE_TMA: + v = V.load( + [(seq_start + start_n).to(tl.int32), offset_vh.to(tl.int32)], + ) + else: + v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero") + silu = silu.to(v.dtype) + return tl.dot(silu, v, allow_tf32=ALLOW_TF32) + + +@triton.jit +def _hstu_attn_fwd_compute( # noqa C901 + Q, + K, + V, + H, + DimQ, + DimV, + workspace_ptr, + seq_offsets, + num_targets, + Out, + stride_qm, + stride_qh, + stride_kn, + stride_kh, + stride_vn, + stride_vh, + stride_om, + stride_oh, + alpha, + MAX_SEQ_LEN, + DeltaSize, + contextual_seq_len, + max_attn_len, + off_z, + off_h, + pid, + HAS_MULTIPLE_TARGETS: tl.constexpr, + IS_DELTA_Q: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, + ENABLE_TMA: tl.constexpr, + TMA_DESC_SIZE: tl.constexpr, +): + seq_start = tl.load(seq_offsets + off_z).to(tl.int64) + off_h = off_h.to(tl.int64) + off_z = off_z.to(tl.int64) + seq_end = tl.load(seq_offsets + off_z + 1) + seq_len = (seq_end - seq_start).to(tl.int32) + + if IS_DELTA_Q: + start_m_delta = pid * BLOCK_M + start_m = (start_m_delta + seq_len - DeltaSize).to(tl.int32) + else: + start_m_delta = 0 + start_m = pid * BLOCK_M + if start_m < seq_len: + if HAS_MULTIPLE_TARGETS: + n_targets = tl.load(num_targets + off_z).to(tl.int32) + else: + n_targets = None + + # initialize offsets + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + Q_block_ptr = None + K_block_ptr = None + V_block_ptr = None + if not ENABLE_TMA: + if IS_DELTA_Q: + Q_block_ptr = tl.make_block_ptr( + base=Q + off_h * stride_qh + off_z * DeltaSize * stride_qm, + shape=(DeltaSize, BLOCK_D_Q), + strides=(stride_qm, 1), + offsets=(start_m_delta, 0), + block_shape=(BLOCK_M, BLOCK_D_Q), + order=(1, 0), + ) + else: + Q_block_ptr = tl.make_block_ptr( + base=Q + off_h * stride_qh + seq_start * stride_qm, + shape=(seq_len, BLOCK_D_Q), + strides=(stride_qm, 1), + offsets=(start_m, 0), + block_shape=(BLOCK_M, BLOCK_D_Q), + order=(1, 0), + ) + q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero") + + K_block_ptr = tl.make_block_ptr( + base=K + off_h * stride_kh + seq_start * stride_kn, + shape=(BLOCK_D_Q, seq_len), + strides=(1, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_D_Q, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=V + off_h * stride_vh + seq_start * stride_vn, + shape=(seq_len, BLOCK_D_V), + strides=(stride_vn, 1), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_D_V), + order=(1, 0), + ) + else: + if IS_DELTA_Q: + q = Q.load( + [ + (off_z * DeltaSize + start_m_delta).to(tl.int32), + (off_h * stride_qh).to(tl.int32), + ] + ) + else: + q = Q.load( + [ + (seq_start + start_m).to(tl.int32), + (off_h * stride_qh).to(tl.int32), + ] + ) + + acc = tl.zeros([BLOCK_M, BLOCK_D_V], dtype=tl.float32) + if HAS_MULTIPLE_TARGETS: + uih_end = seq_len - n_targets + else: + uih_end = seq_len + if HAS_CONTEXTUAL_SEQ_LEN is True and start_m < contextual_seq_len: + # uih_end must be larger than start_m + low = 0 + high = seq_len + else: + low = 0 + high = start_m + BLOCK_M + if HAS_MAX_ATTN_LEN: + if start_m > uih_end: + low = uih_end - max_attn_len + else: + low = start_m - max_attn_len + if HAS_CONTEXTUAL_SEQ_LEN: + low = low if low > contextual_seq_len else 0 + else: + low = low if low > 0 else 0 + if HAS_MULTIPLE_TARGETS: + uih_end = (uih_end + BLOCK_N - 1) // BLOCK_N * BLOCK_N + if uih_end < start_m: + high = seq_len - n_targets + + if low > 0: + if not ENABLE_TMA: + K_block_ptr = tl.advance(K_block_ptr, (0, low)) + V_block_ptr = tl.advance(V_block_ptr, (low, 0)) + end_n = low + for start_n in range(low, high, BLOCK_N): + acc += _hstu_attn_fwd_one_block( + start_n=start_n, + seq_len=seq_len, + offs_m=offs_m, + offs_n=offs_n + start_n, + q=q, + K=K, + V=V, + K_block_ptr=K_block_ptr, + V_block_ptr=V_block_ptr, + offset_kh=off_h * stride_kh, + offset_vh=off_h * stride_vh, + seq_start=seq_start, + n_targets=n_targets if HAS_MULTIPLE_TARGETS else None, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + ALLOW_TF32=ALLOW_TF32, + BLOCK_D_Q=BLOCK_D_Q, + BLOCK_D_V=BLOCK_D_V, + BLOCK_N=BLOCK_N, + ENABLE_TMA=ENABLE_TMA, + ) + if not ENABLE_TMA: + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + end_n += BLOCK_N + + if HAS_MULTIPLE_TARGETS: + # pyre-ignore[61] + if uih_end < start_m: + low_delta = start_m + high_delta = start_m + BLOCK_M + offset = (low_delta - end_n).to(tl.int32) + if not ENABLE_TMA: + K_block_ptr = tl.advance(K_block_ptr, (0, offset)) + V_block_ptr = tl.advance(V_block_ptr, (offset, 0)) + for start_delta in tl.range( + low_delta, high_delta, BLOCK_N, num_stages=0 + ): + acc += _hstu_attn_fwd_one_block( + start_n=start_delta, + seq_len=seq_len, + offs_m=offs_m, + offs_n=offs_n + start_delta, + q=q, + K=K, + V=V, + K_block_ptr=K_block_ptr, + V_block_ptr=V_block_ptr, + offset_kh=off_h * stride_kh, + offset_vh=off_h * stride_vh, + seq_start=seq_start, + n_targets=n_targets if HAS_MULTIPLE_TARGETS else None, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + ALLOW_TF32=ALLOW_TF32, + BLOCK_D_Q=BLOCK_D_Q, + BLOCK_D_V=BLOCK_D_V, + BLOCK_N=BLOCK_N, + ENABLE_TMA=ENABLE_TMA, + ) + if not ENABLE_TMA: + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + # Don't use TMA in Jagged case since we don't want to overwrite + # the output of another sequence + if IS_DELTA_Q: + start_m_delta = pid * BLOCK_M + offs_m_delta = start_m_delta + tl.arange(0, BLOCK_M) + offs_v_d = tl.arange(0, BLOCK_D_V) + off_o = Out + off_z * DeltaSize * stride_om + off_h * stride_oh + out_ptrs = off_o + offs_m_delta[:, None] * stride_om + offs_v_d[None, :] + tl.store(out_ptrs, acc, mask=(offs_m_delta < DeltaSize)[:, None]) + else: + # rematerialize offsets to save registers + start_m = pid * BLOCK_M + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_v_d = tl.arange(0, BLOCK_D_V) + off_o = Out + seq_start * stride_om + off_h * stride_oh + out_ptrs = off_o + offs_m[:, None] * stride_om + offs_v_d[None, :] + tl.store(out_ptrs, acc, mask=(offs_m < seq_len)[:, None]) + + +@triton.jit +def _hstu_attn_fwd_compute_main_loop_tlx( # noqa C901 + low, + high, + seq_len, + offs_m, + offs_n, + acc, + q_tiles, + k_tiles, + v_tiles, + q_fulls, + k_fulls, + v_fulls, + k_empties, + v_empties, + v_dtype, + n_targets, + alpha, + end_n, + loop_trip_cnt, + max_attn_len, + HAS_MULTIPLE_TARGETS: tl.constexpr, + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, + cid: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_BUFFERS: tl.constexpr, + MAX_SEQ_LEN: tl.constexpr, + WAIT_FOR_Q: tl.constexpr, +): + if WAIT_FOR_Q: + # wait for the Q buffer to be populated by the producer + q_full = tlx.local_view(q_fulls, cid) + tlx.barrier_wait(q_full, 0) + + q_tile = tlx.local_view(q_tiles, cid) + + for start in tl.range(low + BLOCK_N, high, BLOCK_N, num_stages=0): + buf_id = loop_trip_cnt % NUM_BUFFERS + # buffers in a row share the same phase + kv_phase = (loop_trip_cnt // NUM_BUFFERS) % 2 + + start_n = tl.multiple_of(start, BLOCK_N) + offs_n_start = offs_n + offs_n = offs_n_start + start_n + + # wait for the K buffer to be populated by the producer + k_full = tlx.local_view(k_fulls, buf_id) + tlx.barrier_wait(k_full, kv_phase) + k_tile = tlx.local_view(k_tiles, buf_id) + + # tma can only be loaded in one order, use trans afterwards + k_tile = tlx.local_trans(k_tile) + # second + qk = tlx.async_dot(q_tile, k_tile) + # wait for the MMA using to complete + qk = tlx.async_dot_wait(0, qk) + # release the K buffer + k_empty = tlx.local_view(k_empties, buf_id) + tlx.barrier_arrive(k_empty, 1) + + qk = qk * alpha + + invalid_mask = offs_m[:, None] == offs_n[None, :] + max_ids = seq_len + if HAS_MULTIPLE_TARGETS: + max_ids = max_ids - n_targets + offs_m = tl.where( + offs_m < max_ids, + offs_m, + max_ids, + ) + offs_n = tl.where( + offs_n < max_ids, + offs_n, + max_ids, + ) + offs_m_minus_n = offs_m[:, None] - offs_n[None, :] + invalid_mask = invalid_mask or (offs_m_minus_n > 0) + if HAS_MAX_ATTN_LEN: + invalid_mask = invalid_mask and offs_m_minus_n <= max_attn_len + if HAS_CONTEXTUAL_SEQ_LEN: + invalid_mask = invalid_mask or ( + offs_m[:, None] == 0 and offs_n[None, :] < max_ids + ) + scale = tl.where(invalid_mask, (1.0 / MAX_SEQ_LEN), 0.0) + silu = fast_dividef(qk, 1.0 + fast_expf(-qk)) * scale + silu = silu.to(v_dtype) + + # wait for the V buffer to be populated by the producer + v_full = tlx.local_view(v_fulls, buf_id) + tlx.barrier_wait(v_full, kv_phase) + v_tile = tlx.local_view(v_tiles, buf_id) + acc = tlx.async_dot(silu, v_tile, acc) + # wait for the MMA using to complete + acc = tlx.async_dot_wait(0, acc) + # release the V buffer + v_empty = tlx.local_view(v_empties, buf_id) + tlx.barrier_arrive(v_empty, 1) + + end_n += BLOCK_N + + # increment loop trip counts + loop_trip_cnt += 1 + + return acc, end_n, loop_trip_cnt + + +@triton.jit +def _hstu_attn_fwd_compute_main_loop_tlx_pipelined( # noqa C901 + low, + high, + seq_len, + offs_m, + offs_n, + acc, + q_tiles, + k_tiles, + v_tiles, + q_fulls, + k_fulls, + v_fulls, + k_empties, + v_empties, + v_dtype, + n_targets, + alpha, + end_n, + loop_trip_cnt, + max_attn_len, + HAS_MULTIPLE_TARGETS: tl.constexpr, + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, + cid: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_BUFFERS: tl.constexpr, + MAX_SEQ_LEN: tl.constexpr, + WAIT_FOR_Q: tl.constexpr, +): + if WAIT_FOR_Q: + # wait for the Q buffer to be populated by the producer + q_full = tlx.local_view(q_fulls, cid) + tlx.barrier_wait(q_full, 0) + q_tile = tlx.local_view(q_tiles, cid) + + # wait for the K buffer to be populated by the producer + k_buf_id = loop_trip_cnt % NUM_BUFFERS + # buffers in a row share the same phase + k_phase = (loop_trip_cnt // NUM_BUFFERS) % 2 + + k_full = tlx.local_view(k_fulls, k_buf_id) + tlx.barrier_wait(k_full, k_phase) + k_tile = tlx.local_view(k_tiles, k_buf_id) + + # tma can only be loaded in one order, use trans afterwards + k_tile = tlx.local_trans(k_tile) + + # Pingpong + if cid == 0: + # Consumer 0 waits for Consumer 1 to reach synchronization point at barrier 9. + tlx.named_barrier_wait(9, 256) + else: + # Consumer 1 signals its arrival at barrier 9. + tlx.named_barrier_arrive(9, 256) + # Then waits at barrier 10 until Consumer 0 finishes issuing its async_dot. + tlx.named_barrier_wait(10, 256) + + qk = tlx.async_dot(q_tile, k_tile) + + if cid == 0: + # After issuing async_dot, Consumer 0 signals barrier 10 to unblock Consumer 1. + tlx.named_barrier_arrive(10, 256) + + # wait for the MMA using to complete + qk = tlx.async_dot_wait(0, qk) + # release the K buffer + k_empty = tlx.local_view(k_empties, k_buf_id) + tlx.barrier_arrive(k_empty, 1) + + qk = qk * alpha + + start_n = tl.multiple_of(low, BLOCK_N) + offs_n_start = offs_n + offs_n = offs_n_start + start_n + + invalid_mask = offs_m[:, None] == offs_n[None, :] + max_ids = seq_len + if HAS_MULTIPLE_TARGETS: + max_ids = max_ids - n_targets + offs_m = tl.where( + offs_m < max_ids, + offs_m, + max_ids, + ) + offs_n = tl.where( + offs_n < max_ids, + offs_n, + max_ids, + ) + offs_m_minus_n = offs_m[:, None] - offs_n[None, :] + invalid_mask = invalid_mask or (offs_m_minus_n > 0) + if HAS_MAX_ATTN_LEN: + invalid_mask = invalid_mask and offs_m_minus_n <= max_attn_len + if HAS_CONTEXTUAL_SEQ_LEN: + invalid_mask = invalid_mask or ( + offs_m[:, None] == 0 and offs_n[None, :] < max_ids + ) + scale = tl.where(invalid_mask, (1.0 / MAX_SEQ_LEN), 0.0) + silu = fast_dividef(qk, 1.0 + fast_expf(-qk)) * scale + silu = silu.to(v_dtype) + + loop_trip_cnt += 1 + + for start in tl.range(low + BLOCK_N, high, BLOCK_N, num_stages=0): + start_n = tl.multiple_of(start, BLOCK_N) + offs_n = offs_n_start + start_n + + k_buf_id = loop_trip_cnt % NUM_BUFFERS + # buffers in a row share the same phase + k_phase = k_phase ^ (k_buf_id == 0) + + # wait for the K buffer to be populated by the producer + k_full = tlx.local_view(k_fulls, k_buf_id) + tlx.barrier_wait(k_full, k_phase) + k_tile = tlx.local_view(k_tiles, k_buf_id) + + # tma can only be loaded in one order, use trans afterwards + k_tile = tlx.local_trans(k_tile) + + qk = tlx.async_dot(q_tile, k_tile) + # wait for the MMA using to complete + prev_silu = silu + + v_buf_id = (loop_trip_cnt - 1) % NUM_BUFFERS + # v_phase = v_phase ^ (v_buf_id == 0) + v_phase = ((loop_trip_cnt - 1) // NUM_BUFFERS) % 2 + v_full = tlx.local_view(v_fulls, v_buf_id) + tlx.barrier_wait(v_full, v_phase) + v_tile = tlx.local_view(v_tiles, v_buf_id) + acc = tlx.async_dot(prev_silu, v_tile, acc) + qk = tlx.async_dot_wait(1, qk) + + # release the K buffer + k_empty = tlx.local_view(k_empties, k_buf_id) + tlx.barrier_arrive(k_empty, 1) + + qk = qk * alpha + invalid_mask = offs_m[:, None] == offs_n[None, :] + max_ids = seq_len + if HAS_MULTIPLE_TARGETS: + max_ids = max_ids - n_targets + offs_m = tl.where( + offs_m < max_ids, + offs_m, + max_ids, + ) + offs_n = tl.where( + offs_n < max_ids, + offs_n, + max_ids, + ) + offs_m_minus_n = offs_m[:, None] - offs_n[None, :] + invalid_mask = invalid_mask or (offs_m_minus_n > 0) + if HAS_MAX_ATTN_LEN: + invalid_mask = invalid_mask and offs_m_minus_n <= max_attn_len + if HAS_CONTEXTUAL_SEQ_LEN: + invalid_mask = invalid_mask or ( + offs_m[:, None] == 0 and offs_n[None, :] < max_ids + ) + scale = tl.where(invalid_mask, (1.0 / MAX_SEQ_LEN), 0.0) + silu = fast_dividef(qk, 1.0 + fast_expf(-qk)) * scale + silu = silu.to(v_dtype) + + acc = tlx.async_dot_wait(0, acc) + # release the V buffer + v_empty = tlx.local_view(v_empties, v_buf_id) + tlx.barrier_arrive(v_empty, 1) + + end_n += BLOCK_N + + # increment loop trip counts + loop_trip_cnt += 1 + # v_buf_id = loop_trip_cnt % NUM_BUFFERS + # v_phase = (loop_trip_cnt // NUM_BUFFERS) % 2 + + # wait for the V buffer to be populated by the producer + v_buf_id = (loop_trip_cnt - 1) % NUM_BUFFERS + v_phase = ((loop_trip_cnt - 1) // NUM_BUFFERS) % 2 + v_full = tlx.local_view(v_fulls, v_buf_id) + # tlx.barrier_wait(v_full, v_buf_id) + v_tile = tlx.local_view(v_tiles, v_buf_id) + tlx.barrier_wait(v_full, v_phase) + acc = tlx.async_dot(silu, v_tile, acc) + acc = tlx.async_dot_wait(0, acc) + # release the V buffer + v_empty = tlx.local_view(v_empties, v_buf_id) + tlx.barrier_arrive(v_empty, 1) + + return acc, end_n, loop_trip_cnt + + +@triton.jit +def _hstu_attn_fwd_load_K_or_V( + K, + k_tiles, + k_empties, + k_fulls, + buf_id, + k_phase, + start_n, + seq_start, + offset_kh, + BLOCK_D_Q: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # wait for the K buffer to be released by the consumer + k_empty = tlx.local_view(k_empties, buf_id) + tlx.barrier_wait(k_empty, k_phase) + # load K + k_full = tlx.local_view(k_fulls, buf_id) + k_tile = tlx.local_view(k_tiles, buf_id) + tlx.barrier_expect_bytes(k_full, 2 * BLOCK_N * BLOCK_D_Q) # float16 + tlx.async_descriptor_load( + K, + k_tile, + [(seq_start + start_n).to(tl.int32), offset_kh.to(tl.int32)], + k_full, + ) + + +@triton.jit +def _hstu_attn_fwd_load_Q( + Q, + q_tiles, + q_fulls, + cid, + off_z, + off_h, + stride_qh, + start_m, + seq_start, + DeltaSize, + IS_DELTA_Q: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_M: tl.constexpr, +): + q_full = tlx.local_view(q_fulls, cid) + tlx.barrier_expect_bytes(q_full, 2 * BLOCK_M * BLOCK_D_Q) # float16 + q_tile = tlx.local_view(q_tiles, cid) + seq_offset = start_m + cid * BLOCK_M + if IS_DELTA_Q: + tlx.async_descriptor_load( + Q, + q_tile, + [ + (off_z * DeltaSize + start_m).to(tl.int32), + (off_h * stride_qh).to(tl.int32), + ], + q_full, + ) + else: + tlx.async_descriptor_load( + Q, + q_tile, + [ + (seq_start + seq_offset).to(tl.int32), + (off_h * stride_qh).to(tl.int32), + ], + q_full, + ) + + +@triton.jit +def _hstu_attn_fwd_caculate_range( + seq_len, + start_m, + n_targets, + contextual_seq_len, + max_attn_len, + HAS_MULTIPLE_TARGETS: tl.constexpr, + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + if HAS_MULTIPLE_TARGETS: + uih_end = seq_len - n_targets + else: + uih_end = seq_len + + if HAS_CONTEXTUAL_SEQ_LEN is True and start_m < contextual_seq_len: + # uih_end must be larger than start_m + low = 0 + high = seq_len + else: + low = 0 + high = start_m + BLOCK_M + if HAS_MAX_ATTN_LEN: + if start_m > uih_end: + low = uih_end - max_attn_len + else: + low = start_m - max_attn_len + if HAS_CONTEXTUAL_SEQ_LEN: + low = low if low > contextual_seq_len else 0 + else: + low = low if low > 0 else 0 + if HAS_MULTIPLE_TARGETS: + uih_end = (uih_end + BLOCK_N - 1) // BLOCK_N * BLOCK_N + if uih_end < start_m: + high = seq_len - n_targets + + return low, high, uih_end + + +@triton.jit +def _hstu_attn_fwd_load_Q_K_V( + Q, + K, + V, + q_tiles, + k_tiles, + v_tiles, + q_fulls, + k_fulls, + v_fulls, + k_empties, + v_empties, + stride_qh, + stride_kh, + stride_vh, + contextual_seq_len, + max_attn_len, + DeltaSize, + off_z, + off_h, + start_m, + seq_start, + seq_len, + n_targets, + HAS_MULTIPLE_TARGETS: tl.constexpr, + IS_DELTA_Q: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_BUFFERS: tl.constexpr, + NUM_MMA_GROUPS: tl.constexpr, + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, +): + # load q: it will stay in SRAM throughout + BLOCK_M_SPLIT: tl.constexpr = BLOCK_M // NUM_MMA_GROUPS + + _hstu_attn_fwd_load_Q( + Q=Q, + q_tiles=q_tiles, + q_fulls=q_fulls, + cid=0, + off_z=off_z, + off_h=off_h, + stride_qh=stride_qh, + start_m=start_m, + seq_start=seq_start, + DeltaSize=DeltaSize, + IS_DELTA_Q=IS_DELTA_Q, + BLOCK_D_Q=BLOCK_D_Q, + BLOCK_M=BLOCK_M_SPLIT, + ) + + off_h = off_h.to(tl.int64) + off_z = off_z.to(tl.int64) + offset_kh = off_h * stride_kh + offset_vh = off_h * stride_vh + + low, high, uih_end = _hstu_attn_fwd_caculate_range( + seq_len, + start_m, + n_targets, + contextual_seq_len, + max_attn_len, + HAS_MULTIPLE_TARGETS, + HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN, + BLOCK_M, + BLOCK_N, + ) + + kv_phase = 0 + loop_trip_cnt = 0 + + # pyre-ignore[58] + buf_id = loop_trip_cnt % NUM_BUFFERS + # buffers in a row share the same phase + kv_phase = kv_phase ^ (buf_id == 0) + + start_n = tl.multiple_of(low, BLOCK_N) + + _hstu_attn_fwd_load_K_or_V( + K, + k_tiles, + k_empties, + k_fulls, + buf_id, + kv_phase, + start_n, + seq_start, + offset_kh, + BLOCK_D_Q, + BLOCK_N, + ) + + for cid in tl.range(1, NUM_MMA_GROUPS, loop_unroll_factor=NUM_MMA_GROUPS - 1): + _hstu_attn_fwd_load_Q( + Q, + q_tiles, + q_fulls, + cid, + off_z, + off_h, + stride_qh, + start_m, + seq_start, + DeltaSize, + IS_DELTA_Q, + BLOCK_D_Q, + BLOCK_M_SPLIT, + ) + + _hstu_attn_fwd_load_K_or_V( + V, + v_tiles, + v_empties, + v_fulls, + buf_id, + kv_phase, + start_n, + seq_start, + offset_vh, + BLOCK_D_V, + BLOCK_N, + ) + + loop_trip_cnt += 1 + + for start in range(low + BLOCK_N, high, BLOCK_N): + # pyre-ignore[58] + buf_id = loop_trip_cnt % NUM_BUFFERS + # buffers in a row share the same phase + kv_phase = kv_phase ^ (buf_id == 0) + + start_n = tl.multiple_of(start, BLOCK_N) + + _hstu_attn_fwd_load_K_or_V( + K, + k_tiles, + k_empties, + k_fulls, + buf_id, + kv_phase, + start_n, + seq_start, + offset_kh, + BLOCK_D_Q, + BLOCK_N, + ) + + _hstu_attn_fwd_load_K_or_V( + V, + v_tiles, + v_empties, + v_fulls, + buf_id, + kv_phase, + start_n, + seq_start, + offset_vh, + BLOCK_D_V, + BLOCK_N, + ) + + # increment loop trip counts + loop_trip_cnt += 1 + + # pyre-ignore[61] + if uih_end < start_m: + low_delta = start_m + high_delta = start_m + BLOCK_M + for start_delta in tl.range(low_delta, high_delta, BLOCK_N, num_stages=0): + # pyre-ignore[58] + buf_id = loop_trip_cnt % NUM_BUFFERS + # buffers in a row share the same phase + kv_phase = kv_phase ^ (buf_id == 0) + + start_n = tl.multiple_of(start_delta, BLOCK_N) + + _hstu_attn_fwd_load_K_or_V( + K, + k_tiles, + k_empties, + k_fulls, + buf_id, + kv_phase, + start_n, + seq_start, + offset_kh, + BLOCK_D_Q, + BLOCK_N, + ) + + _hstu_attn_fwd_load_K_or_V( + V, + v_tiles, + v_empties, + v_fulls, + buf_id, + kv_phase, + start_n, + seq_start, + offset_vh, + BLOCK_D_V, + BLOCK_N, + ) + + # increment loop trip counts + loop_trip_cnt += 1 + + +@triton.jit +def _hstu_attn_fwd_compute_tlx( # noqa C901 + Q, + K, + V, + H, + DimQ, + DimV, + seq_offsets, + num_targets, + Out, + stride_qh, + stride_kh, + stride_vh, + stride_om, + stride_oh, + alpha, + MAX_SEQ_LEN, + DeltaSize, + contextual_seq_len, + max_attn_len, + off_z, + off_h, + pid, + HAS_MULTIPLE_TARGETS: tl.constexpr, + IS_DELTA_Q: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_BUFFERS: tl.constexpr, # + NUM_MMA_WARPS_PER_GROUP: tl.constexpr, # + NUM_MMA_GROUPS: tl.constexpr, # + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, +): + seq_start = tl.load(seq_offsets + off_z).to(tl.int64) + seq_end = tl.load(seq_offsets + off_z + 1) + seq_len = (seq_end - seq_start).to(tl.int32) + + if IS_DELTA_Q: + start_m = pid * BLOCK_M + start_m = (start_m + seq_len - DeltaSize).to(tl.int32) + else: + start_m = pid * BLOCK_M + + if start_m >= seq_len: + return + + if HAS_MULTIPLE_TARGETS: + n_targets = tl.load(num_targets + off_z).to(tl.int32) + else: + n_targets = None + + BLOCK_M_SPLIT: tl.constexpr = BLOCK_M // NUM_MMA_GROUPS + # allocate buffers + q_tiles = tlx.local_alloc( + (BLOCK_M_SPLIT, BLOCK_D_Q), tlx.dtype_of(Q), NUM_MMA_GROUPS + ) + k_tiles = tlx.local_alloc((BLOCK_N, BLOCK_D_Q), tlx.dtype_of(K), NUM_BUFFERS) + v_tiles = tlx.local_alloc((BLOCK_N, BLOCK_D_V), tlx.dtype_of(V), NUM_BUFFERS) + + # allocate barriers + q_fulls = tlx.alloc_barriers(num_barriers=NUM_MMA_GROUPS, arrive_count=1) + k_empties = tlx.alloc_barriers( + num_barriers=NUM_BUFFERS, arrive_count=NUM_MMA_GROUPS + ) + k_fulls = tlx.alloc_barriers(num_barriers=NUM_BUFFERS, arrive_count=1) + v_empties = tlx.alloc_barriers( + num_barriers=NUM_BUFFERS, arrive_count=NUM_MMA_GROUPS + ) + v_fulls = tlx.alloc_barriers(num_barriers=NUM_BUFFERS, arrive_count=1) + + with tlx.async_tasks(): + # producer group + with tlx.async_task("default"): + _hstu_attn_fwd_load_Q_K_V( + Q=Q, + K=K, + V=V, + q_tiles=q_tiles, + k_tiles=k_tiles, + v_tiles=v_tiles, + q_fulls=q_fulls, + k_fulls=k_fulls, + v_fulls=v_fulls, + k_empties=k_empties, + v_empties=v_empties, + stride_qh=stride_qh, + stride_kh=stride_kh, + stride_vh=stride_vh, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + DeltaSize=DeltaSize, + off_z=off_z, + off_h=off_h, + start_m=start_m, + seq_start=seq_start, + seq_len=seq_len, + n_targets=n_targets, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + IS_DELTA_Q=IS_DELTA_Q, + BLOCK_D_Q=BLOCK_D_Q, + BLOCK_D_V=BLOCK_D_V, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + NUM_BUFFERS=NUM_BUFFERS, + NUM_MMA_GROUPS=NUM_MMA_GROUPS, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + ) + + # consumer groups + with tlx.async_task( + num_warps=NUM_MMA_WARPS_PER_GROUP, registers=232, replicate=NUM_MMA_GROUPS + ): + cid = tlx.async_task_replica_id() + acc = tl.zeros([BLOCK_M_SPLIT, BLOCK_D_V], dtype=tl.float32) + # initialize offsets + offs_m = start_m + tl.arange(0, BLOCK_M_SPLIT) + cid * BLOCK_M_SPLIT + offs_n = tl.arange(0, BLOCK_N) + + low, high, uih_end = _hstu_attn_fwd_caculate_range( + seq_len, + start_m, + n_targets, + contextual_seq_len, + max_attn_len, + HAS_MULTIPLE_TARGETS, + HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN, + BLOCK_M, + BLOCK_N, + ) + + end_n = low + loop_trip_cnt = 0 + + acc, end_n, loop_trip_cnt = _hstu_attn_fwd_compute_main_loop_tlx_pipelined( + low=low, + high=high, + seq_len=seq_len, + offs_m=offs_m, + offs_n=offs_n, + acc=acc, + q_tiles=q_tiles, + k_tiles=k_tiles, + v_tiles=v_tiles, + q_fulls=q_fulls, + k_fulls=k_fulls, + v_fulls=v_fulls, + k_empties=k_empties, + v_empties=v_empties, + v_dtype=tlx.dtype_of(V), + n_targets=n_targets, + alpha=alpha, + end_n=end_n, + loop_trip_cnt=loop_trip_cnt, + max_attn_len=max_attn_len, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + cid=cid, + BLOCK_N=BLOCK_N, + NUM_BUFFERS=NUM_BUFFERS, + MAX_SEQ_LEN=MAX_SEQ_LEN, + WAIT_FOR_Q=1, + ) + + # pyre-ignore[61] + if uih_end < start_m: + low_delta = start_m + high_delta = start_m + BLOCK_M + acc, end_n, loop_trip_cnt = _hstu_attn_fwd_compute_main_loop_tlx( + low=low_delta, + high=high_delta, + seq_len=seq_len, + offs_m=offs_m, + offs_n=offs_n, + acc=acc, + q_tiles=q_tiles, + k_tiles=k_tiles, + v_tiles=v_tiles, + q_fulls=q_fulls, + k_fulls=k_fulls, + v_fulls=v_fulls, + k_empties=k_empties, + v_empties=v_empties, + v_dtype=tlx.dtype_of(V), + n_targets=n_targets, + alpha=alpha, + end_n=end_n, + loop_trip_cnt=loop_trip_cnt, + max_attn_len=max_attn_len, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + cid=cid, + BLOCK_N=BLOCK_N, + NUM_BUFFERS=NUM_BUFFERS, + MAX_SEQ_LEN=MAX_SEQ_LEN, + WAIT_FOR_Q=0, + ) + + # Don't use TMA in Jagged case since we don't want to overwrite + # the output of another sequence + if IS_DELTA_Q: + start_m_delta = pid * BLOCK_M + cid * BLOCK_M_SPLIT + offs_m_delta = start_m_delta + tl.arange(0, BLOCK_M_SPLIT) + offs_v_d = tl.arange(0, BLOCK_D_V) + off_o = Out + off_z * DeltaSize * stride_om + off_h * stride_oh + out_ptrs = off_o + offs_m_delta[:, None] * stride_om + offs_v_d[None, :] + tl.store(out_ptrs, acc, mask=(offs_m_delta < DeltaSize)[:, None]) + else: + # rematerialize offsets to save registers + start_m = pid * BLOCK_M + cid * BLOCK_M_SPLIT + offs_m = start_m + tl.arange(0, BLOCK_M_SPLIT) + offs_v_d = tl.arange(0, BLOCK_D_V) + off_o = Out + seq_start * stride_om + off_h * stride_oh + out_ptrs = off_o + offs_m[:, None] * stride_om + offs_v_d[None, :] + tl.store(out_ptrs, acc, mask=(offs_m < seq_len)[:, None]) + + +def _get_fw_pinned_configs() -> List[triton.Config]: + # Pinned forward-attention configs for MI350X gfx950. The full search is a + # coin flip between matrix_instr_nonkdim 16 (fast) and 32 (slow); pinning the + # known winners makes cold starts deterministic. See _autotune_pinning. + # + # This is a LIST, one entry per training shape we've tuned (currently just + # bs=1024). With >1 entry the autotuner still runs a tiny benchmark over only + # these candidates and caches the winner per `key` (AUTOTUNE_Z / H / + # AUTOTUNE_MAX_SEQ_LEN / DimQ / DimV / ...), so each batch size automatically + # picks its own config — same pattern as the layer-norm pins. + # + # TO ADD A NEW BATCH SIZE / SHAPE: + # 1. Run once with TRITON_FULL_AUTOTUNE=1 TRITON_PRINT_AUTOTUNING=1. + # 2. Grep the log for "best config selected:" under "_hstu_attn_fwd". + # 3. Append that config below (copy BLOCK_M/BLOCK_N/matrix_instr_nonkdim/ + # waves_per_eu/kpack/num_stages/num_warps verbatim). + # The four USE_TLX/NUM_* defaults below are required by the kernel signature + # (see the USE_TLX-default loop in _get_fw_configs); the pinned path bypasses + # that loop, so every pinned entry must set them explicitly. + if torch.version.hip: + return [ + # --- yambda bs=1024, L=2048 winner (from capture log) --- + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 32, + "matrix_instr_nonkdim": 16, + "waves_per_eu": 0, + "kpack": 2, + "USE_TLX": False, + "NUM_BUFFERS": 1, + "NUM_MMA_WARPS_PER_GROUP": 1, + "NUM_MMA_GROUPS": 1, + }, + num_stages=2, + num_warps=8, + ), + # --- add more (bs, L) winners here; see "TO ADD A NEW BATCH SIZE" --- + ] + return _get_fw_configs() + + +@triton_autotune( + configs=pinned_or_full(_get_fw_pinned_configs(), _get_fw_configs), + key=[ + "AUTOTUNE_Z", + "H", + "AUTOTUNE_MAX_SEQ_LEN", + "DimQ", + "DimV", + "DeltaSize", + "IS_DELTA_Q", + ], + prune_configs_by={"early_config_prune": _early_config_prune}, +) +@triton.jit +def _hstu_attn_fwd( # noqa C901 + Q, + K, + V, + workspace_ptr, + sort_by_length_indices, + seq_offsets, + num_targets, + Out, + stride_qm, + stride_qh, + stride_kn, + stride_kh, + stride_vn, + stride_vh, + stride_om, + stride_oh, + alpha, + Z, + AUTOTUNE_Z, + H, + MAX_SEQ_LEN, + AUTOTUNE_MAX_SEQ_LEN, # Quantized MAX_SEQ_LEN used as an autotuning key + DimQ, + DimV, + DeltaSize, + contextual_seq_len, + max_attn_len, + HAS_MULTIPLE_TARGETS: tl.constexpr, + IS_DELTA_Q: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + USE_TLX: tl.constexpr, + NUM_BUFFERS: tl.constexpr, # + NUM_MMA_WARPS_PER_GROUP: tl.constexpr, # + NUM_MMA_GROUPS: tl.constexpr, # + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, + HAS_SORT_BY_LENGTH_INDICES: tl.constexpr, + ENABLE_TMA: tl.constexpr, + TMA_DESC_SIZE: tl.constexpr, +): + off_hz = tl.program_id(1) + off_z = off_hz // H + if HAS_SORT_BY_LENGTH_INDICES: + off_z = tl.load(sort_by_length_indices + off_z) + off_h = off_hz % H + pid = tl.program_id(0) + if USE_TLX: + _hstu_attn_fwd_compute_tlx( + Q=Q, + K=K, + V=V, + H=H, + DimQ=DimQ, + DimV=DimV, + seq_offsets=seq_offsets, + num_targets=num_targets, + Out=Out, + stride_qh=stride_qh, + stride_kh=stride_kh, + stride_vh=stride_vh, + stride_om=stride_om, + stride_oh=stride_oh, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + DeltaSize=DeltaSize, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + off_z=off_z, + off_h=off_h, + pid=pid, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + IS_DELTA_Q=IS_DELTA_Q, + ALLOW_TF32=ALLOW_TF32, + BLOCK_D_Q=BLOCK_D_Q, + BLOCK_D_V=BLOCK_D_V, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + NUM_BUFFERS=NUM_BUFFERS, + NUM_MMA_WARPS_PER_GROUP=NUM_MMA_WARPS_PER_GROUP, + NUM_MMA_GROUPS=NUM_MMA_GROUPS, + ) + else: + _hstu_attn_fwd_compute( + Q=Q, + K=K, + V=V, + H=H, + DimQ=DimQ, + DimV=DimV, + workspace_ptr=workspace_ptr, + seq_offsets=seq_offsets, + num_targets=num_targets, + Out=Out, + stride_qm=stride_qm, + stride_qh=stride_qh, + stride_kn=stride_kn, + stride_kh=stride_kh, + stride_vn=stride_vn, + stride_vh=stride_vh, + stride_om=stride_om, + stride_oh=stride_oh, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + DeltaSize=DeltaSize, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + off_z=off_z, + off_h=off_h, + pid=pid, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + IS_DELTA_Q=IS_DELTA_Q, + ALLOW_TF32=ALLOW_TF32, + BLOCK_D_Q=BLOCK_D_Q, + BLOCK_D_V=BLOCK_D_V, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ENABLE_TMA=ENABLE_TMA, + TMA_DESC_SIZE=TMA_DESC_SIZE, + ) + + +@triton_autotune( + configs=pinned_or_full(_get_fw_pinned_configs(), _get_fw_configs), + key=[ + "AUTOTUNE_Z", + "H", + "AUTOTUNE_MAX_SEQ_LEN", + "DimQ", + "DimV", + "DeltaSize", + "IS_DELTA_Q", + ], + prune_configs_by={"early_config_prune": _early_config_prune}, +) +@triton.jit +def _hstu_attn_fwd_persistent( # noqa C901 + Q, + K, + V, + workspace_ptr, + sort_by_length_indices, + seq_offsets, + num_targets, + Out, + stride_qm, + stride_qh, + stride_kn, + stride_kh, + stride_vn, + stride_vh, + stride_om, + stride_oh, + alpha, + Z, + AUTOTUNE_Z, + H, + MAX_SEQ_LEN, + AUTOTUNE_MAX_SEQ_LEN, # Quantized MAX_SEQ_LEN used as an autotuning key + DimQ, + DimV, + DeltaSize, + contextual_seq_len, + max_attn_len, + HAS_MULTIPLE_TARGETS: tl.constexpr, + IS_DELTA_Q: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + USE_TLX: tl.constexpr, + NUM_BUFFERS: tl.constexpr, # + NUM_MMA_WARPS_PER_GROUP: tl.constexpr, # + NUM_MMA_GROUPS: tl.constexpr, # + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, + HAS_SORT_BY_LENGTH_INDICES: tl.constexpr, + ENABLE_TMA: tl.constexpr, + TMA_DESC_SIZE: tl.constexpr, +): + n_tile_num = tl.cdiv(MAX_SEQ_LEN, BLOCK_M) + prog_id = tl.program_id(0) + num_progs = tl.num_programs(0) + + total_tiles = n_tile_num * Z * H + + tiles_per_sm = total_tiles // num_progs + if prog_id < total_tiles % num_progs: + tiles_per_sm += 1 + + tile_idx = prog_id + for _ in range(0, tiles_per_sm): + pid = (total_tiles - tile_idx - 1) // (Z * H) + off_hz = (total_tiles - tile_idx - 1) % (Z * H) + off_z = off_hz // H + off_h = off_hz % H + _hstu_attn_fwd_compute( + Q=Q, + K=K, + V=V, + H=H, + DimQ=DimQ, + DimV=DimV, + workspace_ptr=workspace_ptr, + seq_offsets=seq_offsets, + num_targets=num_targets, + Out=Out, + stride_qm=stride_qm, + stride_qh=stride_qh, + stride_kn=stride_kn, + stride_kh=stride_kh, + stride_vn=stride_vn, + stride_vh=stride_vh, + stride_om=stride_om, + stride_oh=stride_oh, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + DeltaSize=DeltaSize, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + off_z=off_z, + off_h=off_h, + pid=pid, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + IS_DELTA_Q=IS_DELTA_Q, + ALLOW_TF32=ALLOW_TF32, + BLOCK_D_Q=BLOCK_D_Q, + BLOCK_D_V=BLOCK_D_V, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ENABLE_TMA=ENABLE_TMA, + TMA_DESC_SIZE=TMA_DESC_SIZE, + ) + tile_idx += num_progs + + +@triton.jit +def _hstu_attn_bwd_one_block( # noqa C901 + start_m, + offs_n, + offs_m, + q_ptrs_trans, + dq_ptrs_trans, + do_ptrs, + device_desc_q, + device_desc_do, + dk, + dv, + k, + v, + pos_offs_n, + seq_len, + max_ids, + contextual_seq_len, + max_attn_len, + LOCK, + off_h, + stride_qh, + stride_doh, + stride_qm, + stride_dom, + stride_dqm, + alpha, + MAX_SEQ_LEN, + HAS_MULTIPLE_TARGETS: tl.constexpr, + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_M: tl.constexpr, + ATOMIC_ADD: tl.constexpr, + ENABLE_TMA: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, +): + pos_offs_m = offs_m + start_m + mask_m = pos_offs_m < seq_len + invalid_mask_trans = pos_offs_m[None, :] == offs_n[:, None] + # recompute qk and silu + if HAS_CONTEXTUAL_SEQ_LEN: + pos_offs_m = pos_offs_m - contextual_seq_len + 1 + pos_offs_m = tl.where( + pos_offs_m > 0, + pos_offs_m, + 0, + ) + if HAS_MULTIPLE_TARGETS: + pos_offs_m = tl.where( + pos_offs_m < max_ids, + pos_offs_m, + max_ids, + ) + if ENABLE_TMA: + q = device_desc_q.load( + [start_m, (off_h * stride_qh).to(tl.int32)], + ) + q_trans = tl.trans(q) + else: + q_trans = tl.load( + q_ptrs_trans + start_m * stride_qm, + mask=mask_m[None, :], + other=0.0, + ) + qk_trans = tl.dot(k, q_trans, allow_tf32=ALLOW_TF32) * alpha + sig_trans = fast_dividef(1.0, 1.0 + tl.exp(-qk_trans)) + silu_trans = qk_trans * sig_trans * (1.0 / MAX_SEQ_LEN) + pos_offs_m_minus_n = pos_offs_m[None, :] - pos_offs_n[:, None] + invalid_mask_trans = invalid_mask_trans or (pos_offs_m_minus_n > 0) + if HAS_MAX_ATTN_LEN: + invalid_mask_trans = invalid_mask_trans and pos_offs_m_minus_n <= max_attn_len + if HAS_CONTEXTUAL_SEQ_LEN: + invalid_mask_trans = invalid_mask_trans or ( + pos_offs_m[None, :] == 0 and pos_offs_n[:, None] < max_ids + ) + silu_trans = tl.where(invalid_mask_trans, silu_trans, 0) + silu_trans = silu_trans.to(k.dtype) + # compute dv + if ENABLE_TMA: + do = device_desc_do.load( + [start_m, (off_h * stride_doh).to(tl.int32)], + ) + else: + do = tl.load( + do_ptrs + start_m * stride_dom, + mask=mask_m[:, None], + other=0.0, + ) + dv += tl.dot(silu_trans, do, allow_tf32=ALLOW_TF32) + + # compute dk and dq + dqk_trans = tl.dot(v, tl.trans(do), allow_tf32=ALLOW_TF32) + dqk_trans = ( + dqk_trans * sig_trans * (1 + qk_trans * (1 - sig_trans)) * (1.0 / MAX_SEQ_LEN) + ) + dqk_trans = tl.where(invalid_mask_trans, dqk_trans, 0) + dqk_trans = dqk_trans.to(k.dtype) + + # Note: the factor `alpha` is delayed until the end of the function to reduce the cost + dk += tl.dot(dqk_trans, tl.trans(q_trans), allow_tf32=ALLOW_TF32) + acc_dq( + dq_ptrs_trans=dq_ptrs_trans, + start_m=start_m, + stride_dqm=stride_dqm, + k=k, + dqk_trans=dqk_trans, + alpha=alpha, + mask_m=mask_m, + MAX_SEQ_LEN=MAX_SEQ_LEN, + LOCK=LOCK, + BLOCK_M=BLOCK_M, + ATOMIC_ADD=ATOMIC_ADD, + ALLOW_TF32=ALLOW_TF32, + ) + return dk, dv + + +@triton.jit +def _hstu_attn_bwd_one_col_block( # noqa C901 + start_n, + seq_len, + n_targets, + contextual_seq_len, + max_attn_len, + Q, + K, + V, + DOut, + DQ, + DK, + DV, + device_desc_q, + device_desc_k, + device_desc_v, + device_desc_do, + device_desc_dk, + device_desc_dv, + LOCK, + off_h, + stride_qh, + stride_kh, + stride_vh, + stride_doh, + stride_dkh, + stride_dvh, + stride_qm, + stride_kn, + stride_vn, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + alpha, + MAX_SEQ_LEN, + HAS_MULTIPLE_TARGETS: tl.constexpr, + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + UNROLL: tl.constexpr, + ATOMIC_ADD: tl.constexpr, + ENABLE_TMA: tl.constexpr, +): + if HAS_MULTIPLE_TARGETS: + low = start_n + if HAS_MAX_ATTN_LEN: + high = start_n + max_attn_len + BLOCK_N + high = high if high + n_targets < seq_len else seq_len + else: + high = seq_len + else: + low = start_n + if HAS_MAX_ATTN_LEN: + high = start_n + max_attn_len + BLOCK_N + high = high if high < seq_len else seq_len + else: + high = seq_len + if HAS_CONTEXTUAL_SEQ_LEN: + contextual_block_end = tl.cdiv(contextual_seq_len, BLOCK_M) * BLOCK_M + if low < contextual_block_end: + low = contextual_block_end + + offs_m = tl.arange(0, BLOCK_M) + offs_qk_d = tl.arange(0, BLOCK_D_Q) + offs_v_d = tl.arange(0, BLOCK_D_V) + offs_n = start_n + tl.arange(0, BLOCK_N) + + dq_ptrs_trans = DQ + (offs_m[None, :] * stride_dqm + offs_qk_d[:, None]) + dv = tl.zeros([BLOCK_N, BLOCK_D_V], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_D_Q], dtype=tl.float32) + if ENABLE_TMA: + q_ptrs_trans = None + do_ptrs = None + k = device_desc_k.load( + [start_n, (off_h * stride_kh).to(tl.int32)], + ) + v = device_desc_v.load( + [start_n, (off_h * stride_vh).to(tl.int32)], + ) + else: + mask_n = offs_n < seq_len + q_ptrs_trans = Q + (offs_m[None, :] * stride_qm + offs_qk_d[:, None]) + do_ptrs = DOut + (offs_m[:, None] * stride_dom + offs_v_d[None, :]) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_qk_d[None, :]) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_v_d[None, :]) + k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0) + v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0) + max_ids = seq_len + if HAS_CONTEXTUAL_SEQ_LEN: + pos_offs_n = offs_n - contextual_seq_len + 1 + pos_offs_n = tl.where( + pos_offs_n > 0, + pos_offs_n, + 0, + ) + max_ids = max_ids - contextual_seq_len + 1 + else: + pos_offs_n = offs_n + if HAS_MULTIPLE_TARGETS: + max_ids = max_ids - n_targets + pos_offs_n = tl.where( + pos_offs_n < max_ids, + pos_offs_n, + max_ids, + ) + # loop over rows + if HAS_CONTEXTUAL_SEQ_LEN: + for start_m in range(0, contextual_seq_len, BLOCK_M): + start_m = tl.multiple_of(start_m, BLOCK_M) + dk, dv = _hstu_attn_bwd_one_block( + start_m=start_m, + offs_n=offs_n, + offs_m=offs_m, + q_ptrs_trans=q_ptrs_trans, + dq_ptrs_trans=dq_ptrs_trans, + do_ptrs=do_ptrs, + device_desc_q=device_desc_q, + device_desc_do=device_desc_do, + dk=dk, + dv=dv, + k=k, + v=v, + pos_offs_n=pos_offs_n, + seq_len=seq_len, + max_ids=max_ids, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + LOCK=LOCK, + off_h=off_h, + stride_qh=stride_qh, + stride_doh=stride_doh, + stride_qm=stride_qm, + stride_dom=stride_dom, + stride_dqm=stride_dqm, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + ALLOW_TF32=ALLOW_TF32, + BLOCK_M=BLOCK_M, + ATOMIC_ADD=ATOMIC_ADD, + ENABLE_TMA=ENABLE_TMA, + BLOCK_D_Q=BLOCK_D_Q, + BLOCK_D_V=BLOCK_D_V, + ) + for start_m in tl.range(low, high, BLOCK_M, loop_unroll_factor=UNROLL): + start_m = tl.multiple_of(start_m, BLOCK_M) + dk, dv = _hstu_attn_bwd_one_block( + start_m=start_m, + offs_n=offs_n, + offs_m=offs_m, + q_ptrs_trans=q_ptrs_trans, + dq_ptrs_trans=dq_ptrs_trans, + do_ptrs=do_ptrs, + device_desc_q=device_desc_q, + device_desc_do=device_desc_do, + dk=dk, + dv=dv, + k=k, + v=v, + pos_offs_n=pos_offs_n, + seq_len=seq_len, + max_ids=max_ids, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + LOCK=LOCK, + off_h=off_h, + stride_qh=stride_qh, + stride_doh=stride_doh, + stride_qm=stride_qm, + stride_dom=stride_dom, + stride_dqm=stride_dqm, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + ALLOW_TF32=ALLOW_TF32, + BLOCK_M=BLOCK_M, + ATOMIC_ADD=ATOMIC_ADD, + ENABLE_TMA=ENABLE_TMA, + BLOCK_D_Q=BLOCK_D_Q, + BLOCK_D_V=BLOCK_D_V, + ) + # write-back + dk = dk * alpha + if ENABLE_TMA: + device_desc_dv.store( + [start_n, (off_h * stride_dvh).to(tl.int32)], + dv.to(k.dtype), + ) + device_desc_dk.store( + [start_n, (off_h * stride_dkh).to(tl.int32)], + dk.to(k.dtype), + ) + else: + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_v_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_qk_d[None, :]) + tl.store(dv_ptrs, dv.to(k.dtype), mask=mask_n[:, None]) # pyre-ignore[61] + tl.store(dk_ptrs, dk.to(k.dtype), mask=mask_n[:, None]) # pyre-ignore[61] + + +def _bwd_pre_hook(nargs): + nargs["DQ"].zero_() + if nargs["SEQUENCE_PARALLEL"] is True: + nargs["LOCK"].zero_() + + +def _get_bw_configs() -> List[triton.Config]: + if torch.version.hip: + configs = [] + for BLOCK_M in [32, 64]: + for BLOCK_N in [32, 64, 128]: + for num_stages in [1, 2]: + for num_warps in [4, 8]: + for matrix_instr_nonkdim in [16, 32]: + for waves_per_eu in [0, 2, 4]: + for sp in [True, False]: + configs.append( + triton.Config( + { + "BLOCK_M": BLOCK_M, + "BLOCK_N": BLOCK_N, + "matrix_instr_nonkdim": matrix_instr_nonkdim, + "waves_per_eu": waves_per_eu, + "SEQUENCE_PARALLEL": sp, + "UNROLL": 1, + }, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=_bwd_pre_hook, + ) + ) + return configs + + configs = [ + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 32, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=2, + num_warps=2, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 16, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=2, + num_warps=2, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 32, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=2, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 32, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=1, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=1, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=2, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=2, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=2, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=1, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=2, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=1, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=2, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=2, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=3, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False, "UNROLL": 4}, + num_stages=2, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 32, "SEQUENCE_PARALLEL": True, "UNROLL": 1}, + num_stages=2, + num_warps=2, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "SEQUENCE_PARALLEL": True, "UNROLL": 1}, + num_stages=1, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "SEQUENCE_PARALLEL": True, "UNROLL": 1}, + num_stages=2, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True, "UNROLL": 1}, + num_stages=1, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True, "UNROLL": 1}, + num_stages=2, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + ] + if torch.cuda.is_available() and torch.version.cuda < "12.8": + configs += [ + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=1, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=1, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=1, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True, "UNROLL": 1}, + num_stages=1, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True, "UNROLL": 1}, + num_stages=3, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True, "UNROLL": 1}, + num_stages=1, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True, "UNROLL": 1}, + num_stages=2, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + { + "BLOCK_M": 32, + "BLOCK_N": 128, + "SEQUENCE_PARALLEL": False, + "UNROLL": 2, + }, + num_stages=2, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + ] + else: + print("WARNING: temporarily disabled some autotune configs for CUDA 12.8+") + return configs + + +def _get_bw_pinned_configs() -> List[triton.Config]: + # Pinned backward-attention configs for MI350X gfx950. Pins the fast + # matrix_instr_nonkdim=16 winner(s) to avoid the 16-vs-32 autotune lottery. + # + # LIST, one entry per tuned shape (currently just bs=1024). With >1 entry the + # autotuner benchmarks only these candidates and caches the winner per `key` + # (AUTOTUNE_Z / H / AUTOTUNE_MAX_SEQ_LEN / DimQ / DimV), so each batch size + # picks its own config automatically — same pattern as the layer-norm pins. + # + # TO ADD A NEW BATCH SIZE / SHAPE: + # 1. Run once with TRITON_FULL_AUTOTUNE=1 TRITON_PRINT_AUTOTUNING=1. + # 2. Grep the log for "best config selected:" under "_hstu_attn_bwd". + # 3. Append that config below (verbatim BLOCK_M/BLOCK_N/matrix_instr_nonkdim/ + # waves_per_eu/SEQUENCE_PARALLEL/UNROLL/num_stages/num_warps). + # Keep pre_hook=_bwd_pre_hook on every entry (the bwd configs require it). + if torch.version.hip: + return [ + # --- yambda bs=1024, L=2048 winner (from capture log) --- + triton.Config( + { + "BLOCK_M": 32, + "BLOCK_N": 128, + "matrix_instr_nonkdim": 16, + "waves_per_eu": 0, + "SEQUENCE_PARALLEL": False, + "UNROLL": 1, + }, + num_stages=1, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + # --- add more (bs, L) winners here; see "TO ADD A NEW BATCH SIZE" --- + ] + return _get_bw_configs() + + +@triton_autotune( + configs=pinned_or_full(_get_bw_pinned_configs(), _get_bw_configs), + key=[ + "AUTOTUNE_Z", + "H", + "AUTOTUNE_MAX_SEQ_LEN", + "DimQ", + "DimV", + ], +) +@triton.jit +def _hstu_attn_bwd( # noqa C901 + Q, + K, + V, + tma_workspace_ptr, + sort_by_length_indices, + seq_offsets, + num_targets, + DOut, + DQ, + DK, + DV, + LOCK, + stride_qm, + stride_qh, + stride_kn, + stride_kh, + stride_vn, + stride_vh, + stride_dom, + stride_doh, + stride_dqm, + stride_dqh, + stride_dkn, + stride_dkh, + stride_dvn, + stride_dvh, + alpha, + contextual_seq_len, + max_attn_len, + Z, + AUTOTUNE_Z, + H, + MAX_SEQ_LEN, + AUTOTUNE_MAX_SEQ_LEN, # Quantized MAX_SEQ_LEN used as an autotuning key + DimQ, + DimV, + HAS_MULTIPLE_TARGETS: tl.constexpr, + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + UNROLL: tl.constexpr, + HAS_SORT_BY_LENGTH_INDICES: tl.constexpr, + ENABLE_TMA: tl.constexpr, + TMA_DESC_SIZE: tl.constexpr, + ENABLE_BUFFER_OPS_ASSUMES: tl.constexpr, +): + off_hz = tl.program_id(0) + off_z = off_hz // H + if HAS_SORT_BY_LENGTH_INDICES: + off_z = tl.load(sort_by_length_indices + off_z) + off_h = off_hz % H + off_h = off_h.to(tl.int64) + seq_start = tl.load(seq_offsets + off_z).to(tl.int64) + seq_end = tl.load(seq_offsets + off_z + 1) + seq_len = (seq_end - seq_start).to(tl.int32) + if HAS_MULTIPLE_TARGETS: + n_targets = tl.load(num_targets + off_z).to(tl.int32) + else: + n_targets = None + if ENABLE_BUFFER_OPS_ASSUMES: + tl.assume(off_hz >= 0) + tl.assume(off_z >= 0) + tl.assume(off_h >= 0) + tl.assume(seq_start >= 0) + tl.assume(stride_qm >= 0) + tl.assume(stride_qh >= 0) + tl.assume(stride_kn >= 0) + tl.assume(stride_kh >= 0) + tl.assume(stride_vn >= 0) + tl.assume(stride_vh >= 0) + tl.assume(stride_dom >= 0) + tl.assume(stride_doh >= 0) + tl.assume(stride_dqm >= 0) + tl.assume(stride_dqh >= 0) + tl.assume(stride_dkn >= 0) + tl.assume(stride_dkh >= 0) + tl.assume(stride_dvn >= 0) + tl.assume(stride_dvh >= 0) + + # offset pointers for batch/head + Q = Q + seq_start * stride_qm + K = K + seq_start * stride_kn + V = V + seq_start * stride_vn + DOut = DOut + seq_start * stride_dom + DQ = DQ + seq_start * stride_dqm + off_h * stride_dqh + DK = DK + seq_start * stride_dkn + DV = DV + seq_start * stride_dvn + device_desc_q = None + device_desc_k = None + device_desc_v = None + device_desc_do = None + device_desc_dk = None + device_desc_dv = None + if ENABLE_TMA: + device_desc_q = tl.make_tensor_descriptor( + Q, + shape=[seq_len, H * DimQ], + strides=[H * DimQ, 1], + block_shape=[BLOCK_M, BLOCK_D_Q], + ) + device_desc_do = tl.make_tensor_descriptor( + DOut, + shape=[seq_len, H * DimV], + strides=[H * DimV, 1], + block_shape=[BLOCK_M, BLOCK_D_V], + ) + device_desc_k = tl.make_tensor_descriptor( + K, + shape=[seq_len, H * DimQ], + strides=[H * DimQ, 1], + block_shape=[BLOCK_N, BLOCK_D_Q], + ) + device_desc_dk = tl.make_tensor_descriptor( + DK, + shape=[seq_len, H * DimQ], + strides=[H * DimQ, 1], + block_shape=[BLOCK_N, BLOCK_D_Q], + ) + device_desc_v = tl.make_tensor_descriptor( + V, + shape=[seq_len, H * DimV], + strides=[H * DimV, 1], + block_shape=[BLOCK_N, BLOCK_D_V], + ) + device_desc_dv = tl.make_tensor_descriptor( + DV, + shape=[seq_len, H * DimV], + strides=[H * DimV, 1], + block_shape=[BLOCK_N, BLOCK_D_V], + ) + else: + Q += off_h * stride_qh + K += off_h * stride_kh + V += off_h * stride_vh + DOut += off_h * stride_doh + DK += off_h * stride_dkh + DV += off_h * stride_dvh + if SEQUENCE_PARALLEL: + start_n = tl.program_id(1) * BLOCK_N + if start_n >= seq_len: + return + _hstu_attn_bwd_one_col_block( + start_n=start_n, + seq_len=seq_len, + n_targets=n_targets, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + Q=Q, + K=K, + V=V, + DOut=DOut, + DQ=DQ, + DK=DK, + DV=DV, + device_desc_q=device_desc_q, + device_desc_k=device_desc_k, + device_desc_v=device_desc_v, + device_desc_do=device_desc_do, + device_desc_dk=device_desc_dk, + device_desc_dv=device_desc_dv, + LOCK=LOCK, + off_h=off_h, + stride_qh=stride_qh, + stride_kh=stride_kh, + stride_vh=stride_vh, + stride_doh=stride_doh, + stride_dkh=stride_dkh, + stride_dvh=stride_dvh, + stride_qm=stride_qm, + stride_kn=stride_kn, + stride_vn=stride_vn, + stride_dom=stride_dom, + stride_dqm=stride_dqm, + stride_dkn=stride_dkn, + stride_dvn=stride_dvn, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + ALLOW_TF32=ALLOW_TF32, + BLOCK_D_Q=BLOCK_D_Q, + BLOCK_D_V=BLOCK_D_V, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + UNROLL=UNROLL, + ATOMIC_ADD=True, + ENABLE_TMA=ENABLE_TMA, + ) + else: + for start_n in range(0, seq_len, BLOCK_N): + _hstu_attn_bwd_one_col_block( + start_n=start_n, + seq_len=seq_len, + n_targets=n_targets, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + Q=Q, + K=K, + V=V, + DOut=DOut, + DQ=DQ, + DK=DK, + DV=DV, + device_desc_q=device_desc_q, + device_desc_k=device_desc_k, + device_desc_v=device_desc_v, + device_desc_do=device_desc_do, + device_desc_dk=device_desc_dk, + device_desc_dv=device_desc_dv, + LOCK=LOCK, + off_h=off_h, + stride_qh=stride_qh, + stride_kh=stride_kh, + stride_vh=stride_vh, + stride_doh=stride_doh, + stride_dkh=stride_dkh, + stride_dvh=stride_dvh, + stride_qm=stride_qm, + stride_kn=stride_kn, + stride_vn=stride_vn, + stride_dom=stride_dom, + stride_dqm=stride_dqm, + stride_dkn=stride_dkn, + stride_dvn=stride_dvn, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + ALLOW_TF32=ALLOW_TF32, + BLOCK_D_Q=BLOCK_D_Q, + BLOCK_D_V=BLOCK_D_V, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + UNROLL=UNROLL, + ATOMIC_ADD=False, + ENABLE_TMA=ENABLE_TMA, + ) + + +@maybe_register_custom_op( + "generative_recommenders::triton_hstu_attention_fwd", mutates_args=() +) +def triton_hstu_attention_fwd( + N: int, + alpha: float, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + num_targets: Optional[torch.Tensor], + max_attn_len: int, + contextual_seq_len: int, + sort_by_length_indices: Optional[torch.Tensor], + enable_tma: bool, + num_softmax_heads: int, +) -> torch.Tensor: + Z = seq_offsets.numel() - 1 + AUTOTUNE_Z = prev_power_of_2(Z) + L, H, DimQ = q.shape + _, _, DimV = v.shape + out = torch.empty_like(v) + has_multiple_targets = num_targets is not None + has_contextual_seq_len = contextual_seq_len > 0 + has_max_attn_len = max_attn_len > 0 + has_sort_by_length_indices = sort_by_length_indices is not None + if L == 0: + return out + + TMA_DESC_SIZE = 128 + workspace = None + desc_q = q + desc_k = k + desc_v = v + + if enable_tma and tensor_descriptor_tma: + dummy_block = [1, 1] + desc_q = TensorDescriptor( + q, + shape=[L, H * DimQ], + strides=[H * DimQ, 1], + block_shape=dummy_block, + ) + desc_v = TensorDescriptor( + v, + shape=[L, H * DimV], + strides=[H * DimV, 1], + block_shape=dummy_block, + ) + desc_k = TensorDescriptor( + k, + shape=[L, H * DimQ], + strides=[H * DimQ, 1], + block_shape=dummy_block, + ) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + assert align == TMA_DESC_SIZE + return torch.empty(size, dtype=torch.int8, device="cuda") + + # pyre-ignore [6] + triton.set_allocator(alloc_fn) + grid = lambda meta: ( # noqa E731 + triton.cdiv(N, meta["BLOCK_M"]), + Z * H, + ) + + _hstu_attn_fwd[grid]( + Q=desc_q, + K=desc_k, + V=desc_v, + workspace_ptr=workspace, + sort_by_length_indices=sort_by_length_indices, + seq_offsets=seq_offsets, + num_targets=num_targets, + Out=out, + stride_qm=q.stride(0), + stride_qh=q.stride(1), + stride_kn=k.stride(0), + stride_kh=k.stride(1), + stride_vn=v.stride(0), + stride_vh=v.stride(1), + stride_om=out.stride(0), + stride_oh=out.stride(1), + alpha=alpha, + Z=Z, + AUTOTUNE_Z=AUTOTUNE_Z, + H=H, + MAX_SEQ_LEN=N, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(N), + DimQ=DimQ, + DimV=DimV, + DeltaSize=0, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + HAS_MULTIPLE_TARGETS=has_multiple_targets, + IS_DELTA_Q=False, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + BLOCK_D_Q=DimQ, + BLOCK_D_V=DimV, + HAS_CONTEXTUAL_SEQ_LEN=has_contextual_seq_len, + HAS_MAX_ATTN_LEN=has_max_attn_len, + HAS_SORT_BY_LENGTH_INDICES=has_sort_by_length_indices, + ENABLE_TMA=enable_tma, + TMA_DESC_SIZE=TMA_DESC_SIZE, + ) + return out + + +@maybe_register_custom_op( + "generative_recommenders::triton_hstu_attention_bwd", + mutates_args=("dq", "dk", "dv"), +) +def triton_hstu_attention_bwd( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, + seq_offsets: torch.Tensor, + num_targets: Optional[torch.Tensor], + N: int, + alpha: float, + max_attn_len: int, + contextual_seq_len: int, + sort_by_length_indices: Optional[torch.Tensor], + enable_tma: bool, + num_softmax_heads: int, +) -> None: + orig_dq, orig_dk, orig_dv = dq, dk, dv + dout = switch_to_contiguous_if_needed(dout) + dq = switch_to_contiguous_if_needed(dq) + dk = switch_to_contiguous_if_needed(dk) + dv = switch_to_contiguous_if_needed(dv) + if dout.shape[0] == 0: + orig_dq.zero_() + orig_dk.zero_() + orig_dv.zero_() + return + Z = seq_offsets.numel() - 1 + _, H, DimQ = q.shape + _, _, DimV = v.shape + grid = lambda meta: ( # noqa E731 + Z * H, + (triton.cdiv(N, meta["BLOCK_N"]) if meta["SEQUENCE_PARALLEL"] else 1), + ) + # The minimum size of BLOCK_M used in `_get_bw_configs`. + # TODO (linjianma): avoid hardcoding the value. + MIN_BLOCK_M = 16 + lock = torch.empty( + (Z * H, triton.cdiv(N, MIN_BLOCK_M)), + dtype=torch.int32, + device=q.device, + ) + AUTOTUNE_Z = prev_power_of_2(Z) + TMA_DESC_SIZE = 128 + tma_workspace = None + + def alloc_fn(size: int, align: int, stream: Optional[int]): + assert align == TMA_DESC_SIZE + return torch.empty(size, dtype=torch.int8, device="cuda") + + # pyre-ignore [6] + triton.set_allocator(alloc_fn) + + # Enable BufferOps on AMD + ENABLE_BUFFER_OPS_ASSUMES = torch.version.hip is not None + _hstu_attn_bwd[grid]( + Q=q, + K=k, + V=v, + tma_workspace_ptr=tma_workspace, + sort_by_length_indices=sort_by_length_indices, + seq_offsets=seq_offsets, + num_targets=num_targets, + DOut=dout, + DQ=dq, + DK=dk, + DV=dv, + LOCK=lock, + stride_qm=q.stride(0), + stride_qh=q.stride(1), + stride_kn=k.stride(0), + stride_kh=k.stride(1), + stride_vn=v.stride(0), + stride_vh=v.stride(1), + stride_dom=dout.stride(0), + stride_doh=dout.stride(1), + stride_dqm=dq.stride(0), + stride_dqh=dq.stride(1), + stride_dkn=dk.stride(0), + stride_dkh=dk.stride(1), + stride_dvn=dv.stride(0), + stride_dvh=dv.stride(1), + alpha=alpha, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + Z=Z, + AUTOTUNE_Z=AUTOTUNE_Z, + H=H, + MAX_SEQ_LEN=N, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(N), + DimQ=DimQ, + DimV=DimV, + HAS_MULTIPLE_TARGETS=num_targets is not None, + HAS_CONTEXTUAL_SEQ_LEN=contextual_seq_len > 0, + HAS_MAX_ATTN_LEN=max_attn_len > 0, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + BLOCK_D_Q=DimQ, + BLOCK_D_V=DimV, + HAS_SORT_BY_LENGTH_INDICES=sort_by_length_indices is not None, + ENABLE_TMA=enable_tma, + TMA_DESC_SIZE=TMA_DESC_SIZE, + ENABLE_BUFFER_OPS_ASSUMES=ENABLE_BUFFER_OPS_ASSUMES, + ) + + copy_if_different_ptr(orig_dq, dq) + copy_if_different_ptr(orig_dk, dk) + copy_if_different_ptr(orig_dv, dv) + + +@triton_hstu_attention_fwd.register_fake +def _triton_hstu_attention_fwd_fake( + N: int, + alpha: float, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + num_targets: Optional[torch.Tensor], + max_attn_len: int, + contextual_seq_len: int, + sort_by_length_indices: Optional[torch.Tensor], + enable_tma: bool, + num_softmax_heads: int, +) -> torch.Tensor: + L, H, _ = q.shape + _, _, DimV = v.shape + out = torch.empty((L, H, DimV), dtype=v.dtype, device=v.device) + return out + + +@triton_hstu_attention_bwd.register_fake +def _triton_hstu_attention_bwd_fake( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, + seq_offsets: torch.Tensor, + num_targets: Optional[torch.Tensor], + N: int, + alpha: float, + max_attn_len: int, + contextual_seq_len: int, + sort_by_length_indices: Optional[torch.Tensor], + enable_tma: bool, + num_softmax_heads: int, +) -> None: + return None + + +class _AttentionFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + N: int, + alpha: float, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + num_targets: Optional[torch.Tensor], + max_attn_len: int, + contextual_seq_len: int, + sort_by_length: bool, + enable_tma: bool, + ) -> torch.Tensor: + sort_by_length_indices = None + if sort_by_length: + seq_lengths = seq_offsets[1:] - seq_offsets[:-1] + _, sort_by_length_indices = torch.sort( + seq_lengths, descending=True, stable=False + ) + saved_tensors = [q, k, v, seq_offsets] + if num_targets is not None: + saved_tensors.append(num_targets) + if sort_by_length_indices is not None: + saved_tensors.append(sort_by_length_indices) + ctx.save_for_backward(*saved_tensors) + ctx.alpha = alpha + ctx.has_multiple_targets = num_targets is not None + ctx.max_attn_len = max_attn_len + ctx.N = N + ctx.contextual_seq_len = contextual_seq_len + ctx.sort_by_length = sort_by_length + ctx.enable_tma = enable_tma + return triton_hstu_attention_fwd( + N=N, + alpha=alpha, + q=q, + k=k, + v=v, + seq_offsets=seq_offsets, + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + sort_by_length_indices=sort_by_length_indices, + enable_tma=enable_tma, + num_softmax_heads=0, + ) + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, dout: torch.Tensor + ) -> Tuple[ + None, + None, + torch.Tensor, + torch.Tensor, + torch.Tensor, + None, + None, + None, + None, + None, + None, + ]: + with torch.inference_mode(): + q, k, v, seq_offsets = ctx.saved_tensors[:4] + idx = 4 + if ctx.has_multiple_targets: + num_targets = ctx.saved_tensors[idx] + idx += 1 + else: + num_targets = None + if ctx.sort_by_length: + sort_by_length_indices = ctx.saved_tensors[idx] + else: + sort_by_length_indices = None + + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + triton_hstu_attention_bwd( + dout=dout, + q=q, + k=k, + v=v, + dq=dq, + dk=dk, + dv=dv, + seq_offsets=seq_offsets, + num_targets=num_targets, + N=ctx.N, + alpha=ctx.alpha, + max_attn_len=ctx.max_attn_len, + contextual_seq_len=ctx.contextual_seq_len, + sort_by_length_indices=sort_by_length_indices, + enable_tma=ctx.enable_tma, + num_softmax_heads=0, + ) + return ( + None, + None, + dq, + dk, + dv, + None, + None, + None, + None, + None, + None, + ) + + +@torch.jit.unused +@torch.fx.wrap +def triton_hstu_mha( + N: int, + alpha: float, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + num_targets: Optional[torch.Tensor] = None, + max_attn_len: int = 0, + contextual_seq_len: int = 0, + sort_by_length: bool = False, + enable_tma: bool = False, +) -> torch.Tensor: + return _AttentionFunction.apply( + N, + alpha, + q, + k, + v, + seq_offsets, + num_targets, + max_attn_len, + contextual_seq_len, + sort_by_length, + enable_tma, + ) + + +@torch.jit.unused +@torch.fx.wrap +def triton_cached_hstu_mha( + N: int, + alpha: float, + delta_q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + num_targets: Optional[torch.Tensor] = None, + max_attn_len: int = 0, + contextual_seq_len: int = 0, + enable_tma: bool = False, +) -> torch.Tensor: + Z = seq_offsets.size(0) - 1 + AUTOTUNE_Z = prev_power_of_2(Z) + DELTA_L, H, DimQ = delta_q.shape + DeltaSize = DELTA_L // Z + L, _, DimV = v.shape + out = torch.empty((DELTA_L, H, DimV), dtype=delta_q.dtype, device=delta_q.device) + + TMA_DESC_SIZE = 128 + desc_q = delta_q + desc_k = k + desc_v = v + + if enable_tma and tensor_descriptor_tma: + dummy_block = [1, 1] + desc_q = TensorDescriptor( + delta_q, + shape=[DELTA_L, H * DimQ], + strides=[H * DimQ, 1], + block_shape=dummy_block, + ) + desc_v = TensorDescriptor( + v, + shape=[L, H * DimV], + strides=[H * DimV, 1], + block_shape=dummy_block, + ) + desc_k = TensorDescriptor( + k, + shape=[L, H * DimQ], + strides=[H * DimQ, 1], + block_shape=dummy_block, + ) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + assert align == TMA_DESC_SIZE + return torch.empty(size, dtype=torch.int8, device="cuda") + + # pyre-ignore [6] + triton.set_allocator(alloc_fn) + grid = lambda meta: ( # noqa E731 + triton.cdiv(DeltaSize, meta["BLOCK_M"]), + Z * H, + ) + + has_contextual_seq_len = contextual_seq_len > 0 + has_max_attn_len = max_attn_len > 0 + _hstu_attn_fwd[grid]( + Q=desc_q, + K=desc_k, + V=desc_v, + workspace_ptr=None, + sort_by_length_indices=None, + seq_offsets=seq_offsets, + num_targets=num_targets, + Out=out, + stride_qm=delta_q.stride(0), + stride_qh=delta_q.stride(1), + stride_kn=k.stride(0), + stride_kh=k.stride(1), + stride_vn=v.stride(0), + stride_vh=v.stride(1), + stride_om=out.stride(0), + stride_oh=out.stride(1), + alpha=alpha, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + Z=Z, + AUTOTUNE_Z=AUTOTUNE_Z, + H=H, + MAX_SEQ_LEN=N, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(N), + DimQ=DimQ, + DimV=DimV, + DeltaSize=DeltaSize, + HAS_MULTIPLE_TARGETS=num_targets is not None, + IS_DELTA_Q=True, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + BLOCK_D_Q=DimQ, + BLOCK_D_V=DimV, + HAS_CONTEXTUAL_SEQ_LEN=has_contextual_seq_len, + HAS_MAX_ATTN_LEN=has_max_attn_len, + HAS_SORT_BY_LENGTH_INDICES=False, + ENABLE_TMA=enable_tma, + TMA_DESC_SIZE=TMA_DESC_SIZE, + ) + return out diff --git a/recommendation_v4/generative_recommenders/ops/triton/triton_hstu_linear.py b/recommendation_v4/generative_recommenders/ops/triton/triton_hstu_linear.py new file mode 100644 index 000000000..516a15664 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/triton/triton_hstu_linear.py @@ -0,0 +1,3047 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-unsafe + +#!/usr/bin/env python3 + + +from typing import List, Optional, Tuple + +import torch + +# @manual=//triton:triton +import triton + +# @manual=//triton:triton +import triton.language as tl +from generative_recommenders.common import ( + switch_to_contiguous_if_needed, + triton_autotune, +) +from generative_recommenders.ops.triton.triton_addmm import maybe_triton_addmm_fwd +from generative_recommenders.ops.utils import maybe_register_custom_op + + +def _get_layer_norm_mul_dropout_fwd_multirow_configs() -> List[triton.Config]: + """Generate autotune configs for multi-row LayerNorm multiplication with dropout kernels.""" + configs = [] + for BLOCK_N in [1, 2, 4, 8, 16]: + for num_warps in [1, 2, 4]: + configs.append( + triton.Config( + {"BLOCK_N": BLOCK_N}, + num_warps=num_warps, + ) + ) + return configs + + +from generative_recommenders.ops.utils import use_separated_rng_ln_mul_dropout + +# @manual=//triton:triton +from triton.language.extra import libdevice + +try: + # @manual=//triton:triton + from triton.language.extra.libdevice import fast_dividef +except ImportError: + try: + # @manual=//triton:triton + from triton.language.extra.cuda.libdevice import fast_dividef + except ImportError: + # pyre-ignore: Undefined import [21] + # @manual=//triton:triton + from triton.language.math import fast_dividef + + +COMPUTE_OUTPUT_LN_FAST_DROPOUT = False + + +def set_compute_output_ln_fast_dropout(value: bool) -> None: + global COMPUTE_OUTPUT_LN_FAST_DROPOUT + COMPUTE_OUTPUT_LN_FAST_DROPOUT = value + + +FUSE_OUTPUT_LN_RNG_BLACKWELL = False + + +# Only impact B200 training when CONCAT_UX is False +def set_fuse_output_ln_rng_blackwell(value: bool) -> None: + global FUSE_OUTPUT_LN_RNG_BLACKWELL + FUSE_OUTPUT_LN_RNG_BLACKWELL = value + + +@triton.jit +def rand3x(seed, offsets, n_rounds: tl.constexpr = 10): # pyre-ignore [9] + i1, i2, i3, _ = tl.randint4x(seed, offsets, n_rounds) + u1 = tl.uint_to_uniform_float(i1) + u2 = tl.uint_to_uniform_float(i2) + u3 = tl.uint_to_uniform_float(i3) + return u1, u2, u3 + + +@triton.jit +def _generate_random_mask( + MASK_BUFFER, + N, + dropout_ratio, + seed, + D: tl.constexpr, + STRIDE: tl.constexpr, + BLOCK_D: tl.constexpr, + NUM_MASKS: tl.constexpr, +): + """Generate bit-packed dropout masks for (N, D) tensors. Outputs int8. + + Processes 4 rows per program using rand4x. Mask j occupies bit j. + Extraction: y = val & 1, x = val & 2, u = val & 4. + """ + pid = tl.program_id(0) + cols = tl.arange(0, BLOCK_D) + col_mask = cols < D + start_row = pid.to(tl.int64) * 4 + + base_ptr = MASK_BUFFER + start_row * STRIDE + cols + row0_mask = (start_row < N) & col_mask + row1_mask = ((start_row + 1) < N) & col_mask + row2_mask = ((start_row + 2) < N) & col_mask + row3_mask = ((start_row + 3) < N) & col_mask + + # Each pid uses NUM_MASKS consecutive BLOCK_D chunks for Philox offsets + rand_offset = pid * (NUM_MASKS * BLOCK_D) + cols + + packed0 = tl.zeros([BLOCK_D], dtype=tl.int8) + packed1 = tl.zeros([BLOCK_D], dtype=tl.int8) + packed2 = tl.zeros([BLOCK_D], dtype=tl.int8) + packed3 = tl.zeros([BLOCK_D], dtype=tl.int8) + + for j in tl.static_range(NUM_MASKS): + r0, r1, r2, r3 = tl.rand4x(seed, rand_offset) + packed0 |= (r0 > dropout_ratio).to(tl.int8) << j + packed1 |= (r1 > dropout_ratio).to(tl.int8) << j + packed2 |= (r2 > dropout_ratio).to(tl.int8) << j + packed3 |= (r3 > dropout_ratio).to(tl.int8) << j + rand_offset += BLOCK_D + + tl.store(base_ptr, packed0, mask=row0_mask) + tl.store(base_ptr + STRIDE, packed1, mask=row1_mask) + tl.store(base_ptr + 2 * STRIDE, packed2, mask=row2_mask) + tl.store(base_ptr + 3 * STRIDE, packed3, mask=row3_mask) + + +@triton_autotune( + configs=_get_layer_norm_mul_dropout_fwd_multirow_configs(), + key=["BLOCK_D"], +) +@triton.jit +def _ln_mul_dropout_fwd_rng( + X, + U, + Y, + W, + B, + Mean, + Rstd, + RANDOM_MASK, + N, + D, + eps, + dropout_ratio, + stride_x, + stride_u, + stride_y, + stride_mask, + SILU_U: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, + TRAINING: tl.constexpr, + CONCAT_U: tl.constexpr, + CONCAT_X: tl.constexpr, + MUL_U_ACTIVATION_TYPE: tl.constexpr, +): + block_id = tl.program_id(0) + start_row = block_id * BLOCK_N + + # Create block pointers for X, U, and Y + X_block_ptr = tl.make_block_ptr( + base=X, + shape=(N, D), + strides=(stride_x, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + U_block_ptr = tl.make_block_ptr( + base=U, + shape=(N, D), + strides=(stride_u, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + # Load data blocks + x_block = tl.load(X_block_ptr, boundary_check=(0, 1), padding_option="zero").to( + tl.float32 + ) + u_block = tl.load(U_block_ptr, boundary_check=(0, 1), padding_option="zero").to( + tl.float32 + ) + + cols = tl.arange(0, BLOCK_D) + col_mask = cols < D + rows = start_row + tl.arange(0, BLOCK_N) + row_mask = rows < N + # Pre-compute 2D mask for reuse in dropout and masked operations + mask_2d = row_mask[:, None] & col_mask[None, :] + + # Pre-compute inv_D to replace divisions with multiplications (optimization) + inv_D = 1.0 / D + + mean = tl.sum(x_block, axis=1) * inv_D + tl.store(Mean + rows, mean, mask=row_mask) + mean = tl.expand_dims(mean, 1) + + x_mean = x_block - mean + x_mean = tl.where(mask_2d, x_mean, 0.0) + _var = x_mean * x_mean + var = tl.sum(_var, axis=1) * inv_D + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + rows, rstd, mask=row_mask) + rstd = tl.expand_dims(rstd, 1) + + y = x_mean * rstd + w = tl.load(W + cols, mask=col_mask).to(tl.float32) + b = tl.load(B + cols, mask=col_mask).to(tl.float32) + y = y * w[None, :] + b[None, :] + + # Pre-compute sigmoid once to avoid redundant computation + sigmoid_u_block = tl.sigmoid(u_block) + silu_u_block = u_block * sigmoid_u_block + + if MUL_U_ACTIVATION_TYPE == "silu": + y = y * silu_u_block + elif MUL_U_ACTIVATION_TYPE == "sigmoid": + y = y * sigmoid_u_block + else: + y = y * u_block + + if CONCAT_U and SILU_U: + # pyre-fixme[16] + u_block = silu_u_block + + if TRAINING: + # Reuse rows (as int64 for pointer arithmetic) and pre-computed mask_2d + row_offsets_i64 = rows.to(tl.int64) + # Pre-compute loop-invariant values + dropout_scale = 1.0 / (1.0 - dropout_ratio) + offsets = row_offsets_i64[:, None] * stride_mask + cols[None, :] + + if CONCAT_U or CONCAT_X: + # All 2+ mask cases use compressed int8 format - load once + compressed = tl.load(RANDOM_MASK + offsets, mask=mask_2d, other=0).to( + tl.int32 + ) + # Bit 0 is always y_mask + y_keep = (compressed & 1) != 0 + + if CONCAT_U and CONCAT_X: + # 3-mask: (u_mask << 2) | (x_mask << 1) | y_mask + x_keep = (compressed & 2) != 0 + u_keep = (compressed & 4) != 0 + u_block = tl.where(u_keep, u_block * dropout_scale, 0.0) + x_block = tl.where(x_keep, x_block * dropout_scale, 0.0) + elif CONCAT_U: + # 2-mask: (u_mask << 1) | y_mask + u_keep = (compressed & 2) != 0 + u_block = tl.where(u_keep, u_block * dropout_scale, 0.0) + else: # CONCAT_X + # 2-mask: (x_mask << 1) | y_mask + x_keep = (compressed & 2) != 0 + x_block = tl.where(x_keep, x_block * dropout_scale, 0.0) + + y = tl.where(y_keep, y * dropout_scale, 0.0) + else: + # 1-mask: y_mask at bit 0 + y_keep = tl.load(RANDOM_MASK + offsets, mask=mask_2d, other=True) + y = tl.where(y_keep, y * dropout_scale, 0.0) + + if CONCAT_U and CONCAT_X: + Y_block_ptr_u = tl.make_block_ptr( + base=Y, + shape=(N, 3 * D), + strides=(stride_y, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + Y_block_ptr_x = tl.make_block_ptr( + base=Y, + shape=(N, 3 * D), + strides=(stride_y, 1), + offsets=(start_row, D), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + Y_block_ptr_y = tl.make_block_ptr( + base=Y, + shape=(N, 3 * D), + strides=(stride_y, 1), + offsets=(start_row, 2 * D), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + tl.store(Y_block_ptr_u, u_block.to(Y.dtype.element_ty), boundary_check=(0, 1)) + tl.store(Y_block_ptr_x, x_block.to(Y.dtype.element_ty), boundary_check=(0, 1)) + tl.store(Y_block_ptr_y, y.to(Y.dtype.element_ty), boundary_check=(0, 1)) + elif CONCAT_U: + Y_block_ptr_u = tl.make_block_ptr( + base=Y, + shape=(N, 2 * D), + strides=(stride_y, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + Y_block_ptr_y = tl.make_block_ptr( + base=Y, + shape=(N, 2 * D), + strides=(stride_y, 1), + offsets=(start_row, D), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + tl.store(Y_block_ptr_u, u_block.to(Y.dtype.element_ty), boundary_check=(0, 1)) + tl.store(Y_block_ptr_y, y.to(Y.dtype.element_ty), boundary_check=(0, 1)) + elif CONCAT_X: + Y_block_ptr_x = tl.make_block_ptr( + base=Y, + shape=(N, 2 * D), + strides=(stride_y, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + Y_block_ptr_y = tl.make_block_ptr( + base=Y, + shape=(N, 2 * D), + strides=(stride_y, 1), + offsets=(start_row, D), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + tl.store(Y_block_ptr_x, x_block.to(Y.dtype.element_ty), boundary_check=(0, 1)) + tl.store(Y_block_ptr_y, y.to(Y.dtype.element_ty), boundary_check=(0, 1)) + else: + Y_block_ptr = tl.make_block_ptr( + base=Y, + shape=(N, D), + strides=(stride_y, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + tl.store(Y_block_ptr, y.to(Y.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def _ln_mul_dropout_fwd( + X, + U, + Y, + W, + B, + Mean, + Rstd, + D, + eps, + seed, + dropout_ratio, + stride_x, + stride_u, + stride_y, + SILU_U: tl.constexpr, + BLOCK_D: tl.constexpr, + TRAINING: tl.constexpr, + CONCAT_U: tl.constexpr, + CONCAT_X: tl.constexpr, + MUL_U_ACTIVATION_TYPE: tl.constexpr, + FAST_DROPOUT: tl.constexpr, +): + row = tl.program_id(0) + X += row.to(tl.int64) * stride_x + U += row.to(tl.int64) * stride_u + Y += row.to(tl.int64) * stride_y + cols = tl.arange(0, BLOCK_D) + + # Compute mean + mean = 0.0 + x = tl.load(X + cols, mask=cols < D, other=0.0).to(tl.float32) + mean = tl.sum(x, axis=0) / D + + # Compute variance + _var = tl.zeros([BLOCK_D], dtype=tl.float32) + x_mean = tl.where(cols < D, x - mean, 0.0) + _var += x_mean * x_mean + var = tl.sum(_var, axis=0) / D + rstd = 1 / tl.sqrt(var + eps) + tl.store(Mean + row, mean) + tl.store(Rstd + row, rstd) + + # Normalize and apply linear transformation + mask = cols < D + y = x_mean * rstd + w = tl.load(W + cols, mask=mask).to(tl.float32) + b = tl.load(B + cols, mask=mask).to(tl.float32) + y = y * w + b + u = tl.load(U + cols, mask=cols < D, other=0.0).to(tl.float32) + sigmoid_u = tl.sigmoid(u) + silu_u = u * sigmoid_u + + if MUL_U_ACTIVATION_TYPE == "silu": + y = y * silu_u + elif MUL_U_ACTIVATION_TYPE == "sigmoid": + y = y * sigmoid_u + else: + y = y * u + + if CONCAT_U and SILU_U: + u = silu_u + + if TRAINING: + random_offsets = 3 * row * BLOCK_D + cols + if CONCAT_U and CONCAT_X: + # apply dropout on u + if FAST_DROPOUT: + random_u, random_x, random_y = rand3x(seed, random_offsets) + else: + random_u = tl.rand(seed, random_offsets) + u_keep = random_u > dropout_ratio + u = tl.where(u_keep, u / (1.0 - dropout_ratio), 0.0) + # apply dropout on x + if not FAST_DROPOUT: + random_x = tl.rand(seed, random_offsets + D) + x_keep = random_x > dropout_ratio # pyre-ignore [61] + x = tl.where(x_keep, x / (1.0 - dropout_ratio), 0.0) + # apply dropout on y + if not FAST_DROPOUT: + random_y = tl.rand(seed, random_offsets + 2 * D) + y_keep = random_y > dropout_ratio # pyre-ignore [61] + y = tl.where(y_keep, y / (1.0 - dropout_ratio), 0.0) + elif CONCAT_U: + # apply dropout on u + if FAST_DROPOUT: + random_u, random_y, _ = rand3x(seed, random_offsets) + else: + random_u = tl.rand(seed, random_offsets) + u_keep = random_u > dropout_ratio + u = tl.where(u_keep, u / (1.0 - dropout_ratio), 0.0) + # apply dropout on y + if not FAST_DROPOUT: + random_y = tl.rand(seed, random_offsets + D) + y_keep = random_y > dropout_ratio # pyre-ignore [61] + y = tl.where(y_keep, y / (1.0 - dropout_ratio), 0.0) + elif CONCAT_X: + # apply dropout on x + if FAST_DROPOUT: + random_x, random_y, _ = rand3x(seed, random_offsets) + else: + random_x = tl.rand(seed, random_offsets) + x_keep = random_x > dropout_ratio + x = tl.where(x_keep, x / (1.0 - dropout_ratio), 0.0) + # apply dropout on y + if not FAST_DROPOUT: + random_y = tl.rand(seed, random_offsets + D) + y_keep = random_y > dropout_ratio # pyre-ignore [61] + y = tl.where(y_keep, y / (1.0 - dropout_ratio), 0.0) + else: + random = tl.rand(seed, random_offsets) + y_keep = random > dropout_ratio + # write-back + y = tl.where(y_keep, y / (1.0 - dropout_ratio), 0.0) + + # Write output + if CONCAT_U and CONCAT_X: + tl.store(Y + cols, u.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + D + cols, x.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + 2 * D + cols, y.to(Y.dtype.element_ty), mask=mask) + elif CONCAT_U: + tl.store(Y + cols, u.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + D + cols, y.to(Y.dtype.element_ty), mask=mask) + elif CONCAT_X: + tl.store(Y + cols, x.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + D + cols, y.to(Y.dtype.element_ty), mask=mask) + else: + tl.store(Y + cols, y.to(Y.dtype.element_ty), mask=mask) + + +@triton.jit +def _ln_mul_dropout_bwd_dx_du_rng( + DX, + DU, + DY, + DW, + DB, + X, + U, + Y, + W, + B, + Mean, + Rstd, + RANDOM_MASK, + stride_dx, + stride_du, + stride_dy, + stride_x, + stride_u, + stride_y, + stride_mask, + D, + eps, + dropout_ratio, + N, + SILU_U: tl.constexpr, + BLOCK_D: tl.constexpr, + TRAINING: tl.constexpr, + CONCAT_U: tl.constexpr, + CONCAT_X: tl.constexpr, + MUL_U_ACTIVATION_TYPE: tl.constexpr, + COMPUTE_Y: tl.constexpr, +): + pid = tl.program_id(0) + tile_num = tl.num_programs(0) + rows_per_tile = N // tile_num + if pid < N % tile_num: + rows_per_tile += 1 + + if rows_per_tile == 0: + return + + cols = tl.arange(0, BLOCK_D) + mask = cols < D + + row = pid + # Pre-compute row and pid as int64 once for initial pointer setup + row_i64 = row.to(tl.int64) + pid_i64 = pid.to(tl.int64) + X += row_i64 * stride_x + U += row_i64 * stride_u + if COMPUTE_Y: + Y += row_i64 * stride_y + DY += row_i64 * stride_dy + DX += row_i64 * stride_dx + DU += row_i64 * stride_du + DW = DW + pid_i64 * D + cols + DB = DB + pid_i64 * D + cols + + # Pre-compute mask pointer offset (all cases use stride_mask for (N, D) shape) + RANDOM_MASK += row_i64 * stride_mask + + partial_dw = tl.zeros((BLOCK_D,), dtype=tl.float32) + partial_db = tl.zeros((BLOCK_D,), dtype=tl.float32) + w = tl.load(W + cols, mask=mask).to(tl.float32) + b = tl.load(B + cols, mask=mask).to(tl.float32) + + dropout_scale = 0.0 + if TRAINING: + dropout_scale = 1.0 / (1.0 - dropout_ratio) + + # Pre-compute inv_D to replace divisions with multiplications (optimization) + inv_D = 1.0 / D + + # Pre-compute tile_num as int64 to avoid repeated conversion in the loop + tile_num_i64 = tile_num.to(tl.int64) + for _ in range(0, rows_per_tile): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + if CONCAT_U and CONCAT_X: + du = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + dx = tl.load(DY + D + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + 2 * D + cols, mask=mask, other=0).to(tl.float32) + elif CONCAT_U: + du = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + dx = tl.zeros([BLOCK_D], dtype=tl.float32) + dy = tl.load(DY + D + cols, mask=mask, other=0).to(tl.float32) + elif CONCAT_X: + du = tl.zeros([BLOCK_D], dtype=tl.float32) + dx = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + D + cols, mask=mask, other=0).to(tl.float32) + else: + du = tl.zeros([BLOCK_D], dtype=tl.float32) + dx = tl.zeros([BLOCK_D], dtype=tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if TRAINING: + if CONCAT_U or CONCAT_X: + # All 2+ mask cases use compressed int8 format - load once + compressed = tl.load(RANDOM_MASK + cols, mask=mask, other=0).to( + tl.int32 + ) + dy_keep = (compressed & 1) != 0 # Bit 0 always y_mask + + if CONCAT_U and CONCAT_X: + # Format: (u_mask << 2) | (x_mask << 1) | y_mask + dx_keep = (compressed & 2) != 0 + du_keep = (compressed & 4) != 0 + du = tl.where(du_keep, du * dropout_scale, 0.0) + dx = tl.where(dx_keep, dx * dropout_scale, 0.0) + elif CONCAT_U: + # Format: (u_mask << 1) | y_mask + du_keep = (compressed & 2) != 0 + du = tl.where(du_keep, du * dropout_scale, 0.0) + else: # CONCAT_X + # Format: (x_mask << 1) | y_mask + dx_keep = (compressed & 2) != 0 + dx = tl.where(dx_keep, dx * dropout_scale, 0.0) + dy = tl.where(dy_keep, dy * dropout_scale, 0.0) + else: + # 1-mask: y_mask at bit 0 + dy_keep = tl.load(RANDOM_MASK + cols, mask=mask, other=True) + dy = tl.where(dy_keep, dy * dropout_scale, 0.0) + + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + + # Compute dx + xhat = (x - mean) * rstd + u = tl.load(U + cols, mask=mask, other=0).to(tl.float32) + ln = xhat * w + b + du_y = dy * ln + mul_u = u + sig_u = tl.sigmoid(u) + + # Pre-compute commonly used expressions to avoid redundant computation + silu_u = u * sig_u # silu(u) - used multiple times + dsig_u = sig_u * (1.0 - sig_u) # sigmoid derivative - used multiple times + dsilu_u = sig_u + silu_u * ( + 1.0 - sig_u + ) # silu derivative - used multiple times + + if MUL_U_ACTIVATION_TYPE == "silu": + mul_u = silu_u + du_y = dy * ln * dsilu_u + dy = dy * silu_u + elif MUL_U_ACTIVATION_TYPE == "sigmoid": + mul_u = sig_u + du_y = dy * ln * dsig_u + dy = dy * sig_u + else: + dy = dy * u + + du_u = du + if CONCAT_U and SILU_U: + du_u *= dsilu_u + u = silu_u + + du = du_y + du_u + + tl.store(DU + cols, du.to(DU.dtype.element_ty), mask=mask) + + wdy = w * dy + if COMPUTE_Y: + y = ln * mul_u + if TRAINING: + if CONCAT_U: + u = tl.where( + du_keep, # pyre-ignore [61] + u * dropout_scale, + 0.0, + ) + if CONCAT_X: + x = tl.where( + dx_keep, # pyre-ignore [61] + x * dropout_scale, + 0.0, + ) + y = tl.where( + dy_keep, # pyre-ignore [61] + y * dropout_scale, + 0.0, + ) + if CONCAT_U and CONCAT_X: + tl.store(Y + cols, u.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + D + cols, x.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + 2 * D + cols, y.to(Y.dtype.element_ty), mask=mask) + elif CONCAT_U: + tl.store(Y + cols, u.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + D + cols, y.to(Y.dtype.element_ty), mask=mask) + elif CONCAT_X: + tl.store(Y + cols, x.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + D + cols, y.to(Y.dtype.element_ty), mask=mask) + else: + tl.store(Y + cols, y.to(Y.dtype.element_ty), mask=mask) + Y += tile_num_i64 * stride_y + + # Note: xhat and wdy are already 0 outside valid range due to masked loads, + # so no additional tl.where masking is needed before reduction + c1 = tl.sum(xhat * wdy, axis=0) * inv_D + c2 = tl.sum(wdy, axis=0) * inv_D + dx += (wdy - (xhat * c1 + c2)) * rstd + # Write dx + tl.store(DX + cols, dx, mask=mask) + + # Accumulate partial sums for dw/db + partial_dw += dy * xhat + partial_db += dy + X += tile_num_i64 * stride_x + U += tile_num_i64 * stride_u + DY += tile_num_i64 * stride_dy + DX += tile_num_i64 * stride_dx + DU += tile_num_i64 * stride_du + # Increment mask pointer + RANDOM_MASK += tile_num_i64 * stride_mask + row += tile_num + tl.store(DW, partial_dw, mask=mask) + tl.store(DB, partial_db, mask=mask) + + +@triton.jit +def _ln_mul_dropout_bwd_dx_du( + DX, + DU, + DY, + DW, + DB, + X, + U, + Y, + W, + B, + Mean, + Rstd, + stride_dx, + stride_du, + stride_dy, + stride_x, + stride_u, + stride_y, + D, + eps, + seed, + dropout_ratio, + N, + SILU_U: tl.constexpr, + BLOCK_D: tl.constexpr, + TRAINING: tl.constexpr, + CONCAT_U: tl.constexpr, + CONCAT_X: tl.constexpr, + MUL_U_ACTIVATION_TYPE: tl.constexpr, + COMPUTE_Y: tl.constexpr, + FAST_DROPOUT: tl.constexpr, +): + pid = tl.program_id(0) + tile_num = tl.num_programs(0) + rows_per_tile = N // tile_num + if pid < N % tile_num: + rows_per_tile += 1 + + if rows_per_tile == 0: + return + + cols = tl.arange(0, BLOCK_D) + mask = cols < D + + row = pid + X += row.to(tl.int64) * stride_x + U += row.to(tl.int64) * stride_u + if COMPUTE_Y: + Y += row.to(tl.int64) * stride_y + DY += row.to(tl.int64) * stride_dy + DX += row.to(tl.int64) * stride_dx + DU += row.to(tl.int64) * stride_du + DW = DW + pid * D + cols + DB = DB + pid * D + cols + + partial_dw = tl.zeros((BLOCK_D,), dtype=tl.float32) + partial_db = tl.zeros((BLOCK_D,), dtype=tl.float32) + w = tl.load(W + cols, mask=mask).to(tl.float32) + b = tl.load(B + cols, mask=mask).to(tl.float32) + for _idx in range(0, rows_per_tile): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + if CONCAT_U and CONCAT_X: + du = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + dx = tl.load(DY + D + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + 2 * D + cols, mask=mask, other=0).to(tl.float32) + elif CONCAT_U: + du = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + dx = tl.zeros([BLOCK_D], dtype=tl.float32) + dy = tl.load(DY + D + cols, mask=mask, other=0).to(tl.float32) + elif CONCAT_X: + du = tl.zeros([BLOCK_D], dtype=tl.float32) + dx = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + D + cols, mask=mask, other=0).to(tl.float32) + else: + du = tl.zeros([BLOCK_D], dtype=tl.float32) + dx = tl.zeros([BLOCK_D], dtype=tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if TRAINING: + random_offsets = 3 * row * BLOCK_D + cols + if CONCAT_U and CONCAT_X: + # apply dropout on du + if FAST_DROPOUT: + random_du, random_dx, random_dy = rand3x(seed, random_offsets) + else: + random_du = tl.rand(seed, random_offsets) + du_keep = random_du > dropout_ratio + du = tl.where(du_keep, du / (1.0 - dropout_ratio), 0.0) + # apply dropout on dx + if not FAST_DROPOUT: + random_dx = tl.rand(seed, random_offsets + D) + dx_keep = random_dx > dropout_ratio # pyre-ignore [61] + dx = tl.where(dx_keep, dx / (1.0 - dropout_ratio), 0.0) + # apply dropout on dy + if not FAST_DROPOUT: + random_dy = tl.rand(seed, random_offsets + 2 * D) + dy_keep = random_dy > dropout_ratio # pyre-ignore [61] + dy = tl.where(dy_keep, dy / (1.0 - dropout_ratio), 0.0) + elif CONCAT_U: + # apply dropout on du + if FAST_DROPOUT: + random_du, _, random_dy = rand3x(seed, random_offsets) + else: + random_du = tl.rand(seed, random_offsets) + du_keep = random_du > dropout_ratio + du = tl.where(du_keep, du / (1.0 - dropout_ratio), 0.0) + # apply dropout on dy + if not FAST_DROPOUT: + random_dy = tl.rand(seed, random_offsets + D) + dy_keep = random_dy > dropout_ratio # pyre-ignore [61] + dy = tl.where(dy_keep, dy / (1.0 - dropout_ratio), 0.0) + elif CONCAT_X: + # apply dropout on dx + if FAST_DROPOUT: + _, random_dx, random_dy = rand3x(seed, random_offsets) + else: + random_dx = tl.rand(seed, random_offsets) + dx_keep = random_dx > dropout_ratio # pyre-ignore [61] + dx = tl.where(dx_keep, dx / (1.0 - dropout_ratio), 0.0) + # apply dropout on dy + if not FAST_DROPOUT: + random_dy = tl.rand(seed, random_offsets + D) + dy_keep = random_dy > dropout_ratio # pyre-ignore [61] + dy = tl.where(dy_keep, dy / (1.0 - dropout_ratio), 0.0) + else: + random = tl.rand(seed, random_offsets) + dy_keep = random > dropout_ratio + # write-back + dy = tl.where(dy_keep, dy / (1.0 - dropout_ratio), 0.0) + + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + + # Compute dx + xhat = (x - mean) * rstd + u = tl.load(U + cols, mask=mask, other=0).to(tl.float32) + ln = xhat * w + b + du_y = dy * ln + mul_u = u + sig_u = tl.sigmoid(u) + + if MUL_U_ACTIVATION_TYPE == "silu": + mul_u = u * sig_u + du_y = dy * ln * (sig_u + u * sig_u * (1.0 - sig_u)) + dy = dy * u * sig_u + elif MUL_U_ACTIVATION_TYPE == "sigmoid": + mul_u = sig_u + du_y = dy * ln * (sig_u * (1.0 - sig_u)) + dy = dy * sig_u + else: + dy = dy * u + + du_u = du + if CONCAT_U: + if SILU_U: + du_u *= sig_u + u * sig_u * (1.0 - sig_u) + u = u * sig_u + + du = du_y + du_u + + tl.store(DU + cols, du.to(DU.dtype.element_ty), mask=mask) + wdy = w * dy + if COMPUTE_Y: + y = ln * mul_u + if TRAINING: + if CONCAT_U: + u = tl.where( + du_keep, # pyre-ignore [61] + u / (1.0 - dropout_ratio), + 0.0, + ) + if CONCAT_X: + x = tl.where( + dx_keep, # pyre-ignore [61] + x / (1.0 - dropout_ratio), + 0.0, + ) + y = tl.where( + dy_keep, # pyre-ignore [61] + y / (1.0 - dropout_ratio), + 0.0, + ) + if CONCAT_U and CONCAT_X: + tl.store(Y + cols, u.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + D + cols, x.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + 2 * D + cols, y.to(Y.dtype.element_ty), mask=mask) + elif CONCAT_U: + tl.store(Y + cols, u.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + D + cols, y.to(Y.dtype.element_ty), mask=mask) + elif CONCAT_X: + tl.store(Y + cols, x.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + D + cols, y.to(Y.dtype.element_ty), mask=mask) + else: + tl.store(Y + cols, y.to(Y.dtype.element_ty), mask=mask) + Y += tile_num.to(tl.int64) * stride_y + + xhat = tl.where(mask, xhat, 0.0) + wdy = tl.where(mask, wdy, 0.0) + c1 = tl.sum(xhat * wdy, axis=0) / D + c2 = tl.sum(wdy, axis=0) / D + dx += (wdy - (xhat * c1 + c2)) * rstd + # Write dx + tl.store(DX + cols, dx, mask=mask) + + # Accumulate partial sums for dw/db + partial_dw += dy * xhat + partial_db += dy + X += tile_num.to(tl.int64) * stride_x + U += tile_num.to(tl.int64) * stride_u + DY += tile_num.to(tl.int64) * stride_dy + DX += tile_num.to(tl.int64) * stride_dx + DU += tile_num.to(tl.int64) * stride_du + row += tile_num + tl.store(DW, partial_dw, mask=mask) + tl.store(DB, partial_db, mask=mask) + + +def _get_bwd_dwdb_configs() -> List[triton.Config]: + configs = [] + for BLOCK_N in [32, 64, 128, 256]: + for num_warps in [8, 16] + ([] if torch.ops.hip else [32]): + configs.append( + triton.Config( + {"BLOCK_N": BLOCK_N}, + num_warps=num_warps, + ) + ) + return configs + + +@triton_autotune( + configs=_get_bwd_dwdb_configs(), + key=["D"], +) +@triton.jit +def _ln_mul_dropout_bwd_dwdb( + DW, + DB, + FINAL_DW, + FINAL_DB, + N, + D, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid = tl.program_id(0).to(tl.int64) + cols = pid * BLOCK_D + tl.arange(0, BLOCK_D) + dw = tl.zeros((BLOCK_N, BLOCK_D), dtype=tl.float32) + db = tl.zeros((BLOCK_N, BLOCK_D), dtype=tl.float32) + + for i in range(0, N, BLOCK_N): + rows = i + tl.arange(0, BLOCK_N) + # pyre-fixme[16]: `int` has no attribute `__getitem__`. + mask = (rows[:, None] < N) & (cols[None, :] < D) + offs = rows[:, None] * D + cols[None, :] + dw += tl.load(DW + offs, mask=mask, other=0.0) + db += tl.load(DB + offs, mask=mask, other=0.0) + + sum_dw = tl.sum(dw, axis=0) + sum_db = tl.sum(db, axis=0) + tl.store(FINAL_DW + cols, sum_dw.to(FINAL_DW.dtype.element_ty), mask=cols < D) + tl.store(FINAL_DB + cols, sum_db.to(FINAL_DB.dtype.element_ty), mask=cols < D) + + +def _create_dropout_mask( + N: int, + D: int, + BLOCK_D: int, + concat_u: bool, + concat_x: bool, + dropout_ratio: float, + seed: int, + device: torch.device, +) -> torch.Tensor: + """Create dropout mask tensor for layer norm mul dropout. + + Args: + N: Number of rows + D: Feature dimension + BLOCK_D: Block size for D dimension + concat_u: Whether to concatenate u + concat_x: Whether to concatenate x + dropout_ratio: Dropout ratio + seed: Random seed + device: Device to create tensor on + + Returns: + random_mask: (N, D) int8 tensor. Mask j at bit j. + + Bit layout: y = val & 1, x = val & 2, u = val & 4. + """ + num_masks = 1 + int(concat_u) + int(concat_x) + # Torch uses 1 byte for bool internally, same as int8, so always use int8. + random_mask = torch.empty([N, D], dtype=torch.int8, device=device) + _generate_random_mask[(triton.cdiv(N, 4),)]( + random_mask, + N, + dropout_ratio, + seed, + D, # pyre-ignore[6] + random_mask.stride(0), # pyre-ignore[6] + BLOCK_D, # pyre-fixme[6]: Triton constexpr param + num_masks, # pyre-ignore[6]: NUM_MASKS constexpr + ) + return random_mask + + +@maybe_register_custom_op( + "generative_recommenders::_triton_layer_norm_mul_dropout_fwd_impl", mutates_args=() +) +def _triton_layer_norm_mul_dropout_fwd_impl( + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool, + concat_u: bool, + concat_x: bool, + mul_u_activation_type: str, + seed: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Internal implementation that returns only tensors for custom_op compatibility. + + Returns (y, mean, rstd, random_mask) where random_mask is empty when not used. + """ + N, D = x.shape + + if concat_u and concat_x: + y = torch.empty((N, 3 * D), dtype=x.dtype, device=x.device) + elif concat_u: + y = torch.empty((N, 2 * D), dtype=x.dtype, device=x.device) + elif concat_x: + y = torch.empty((N, 2 * D), dtype=x.dtype, device=x.device) + else: + y = torch.empty_like(x) + mean = torch.empty((N,), dtype=torch.float32, device=x.device) + rstd = torch.empty((N,), dtype=torch.float32, device=x.device) + if N == 0: + return y, mean, rstd, torch.empty(0, dtype=x.dtype, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_D: int = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BLOCK_D: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + + num_warps: int = min(max(BLOCK_D // 256, 1), 8) + random_mask: torch.Tensor = torch.empty(0, dtype=x.dtype, device=x.device) + # Separating RNG from the ln_mul_dropout kernel lets us batch multiple rows per + # program (autotuned _ln_mul_dropout_fwd_rng) and reuse the precomputed mask in the + # backward, instead of launching one program per row with fused RNG. This is a large + # win on Blackwell (sm_100) and AMD MI350 (gfx950); other GPUs keep the fused path. + # Extended to support concat_u + concat_x for mask reuse optimization + if ( + not FUSE_OUTPUT_LN_RNG_BLACKWELL + and use_separated_rng_ln_mul_dropout() + and training + ): + random_mask = _create_dropout_mask( + N=N, + D=D, + BLOCK_D=BLOCK_D, + concat_u=concat_u, + concat_x=concat_x, + dropout_ratio=dropout_ratio, + seed=seed, + device=x.device, + ) + + def grid(META): + return (triton.cdiv(N, META["BLOCK_N"]),) + + # pyre-ignore[28] + _ln_mul_dropout_fwd_rng[grid]( + x, + u, + y, + weight, + bias, + mean, + rstd, + random_mask, + N, + D, + eps, + dropout_ratio, + x.stride(0), + u.stride(0), + y.stride(0), + random_mask.stride(0), + SILU_U=silu_u, + BLOCK_D=BLOCK_D, + TRAINING=training, + CONCAT_U=concat_u, + CONCAT_X=concat_x, + MUL_U_ACTIVATION_TYPE=mul_u_activation_type, + ) + + else: + # Default path: fused RNG generation + # Mask cannot be saved with fused RNG - it's generated inline in the kernel + # pyre-ignore[28] + _ln_mul_dropout_fwd[(N,)]( + x, + u, + y, + weight, + bias, + mean, + rstd, + D, + eps, + seed, + dropout_ratio, + x.stride(0), + u.stride(0), + y.stride(0), + SILU_U=silu_u, + BLOCK_D=BLOCK_D, + TRAINING=training, + CONCAT_U=concat_u, + CONCAT_X=concat_x, + MUL_U_ACTIVATION_TYPE=mul_u_activation_type, + FAST_DROPOUT=COMPUTE_OUTPUT_LN_FAST_DROPOUT, + num_warps=num_warps, + ) + return y, mean, rstd, random_mask + + +@_triton_layer_norm_mul_dropout_fwd_impl.register_fake +def _triton_layer_norm_mul_dropout_fwd_impl_fake( + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool, + concat_u: bool, + concat_x: bool, + mul_u_activation_type: str, + seed: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Fake implementation for FakeTensor tracing.""" + N, D = x.shape + if concat_u and concat_x: + y = torch.empty((N, 3 * D), dtype=x.dtype, device=x.device) + elif concat_u: + y = torch.empty((N, 2 * D), dtype=x.dtype, device=x.device) + elif concat_x: + y = torch.empty((N, 2 * D), dtype=x.dtype, device=x.device) + else: + y = torch.empty_like(x) + mean = torch.empty((N,), dtype=torch.float32, device=x.device) + rstd = torch.empty((N,), dtype=torch.float32, device=x.device) + random_mask = torch.empty(0, dtype=x.dtype, device=x.device) + return y, mean, rstd, random_mask + + +def triton_layer_norm_mul_dropout_fwd( + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_u: bool = False, + concat_x: bool = False, + mul_u_activation_type: str = "none", + seed: Optional[int] = None, +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, Optional[torch.Tensor] +]: # y, mean, rstd, BLOCK_D, num_warps, seed, random_mask + """Forward pass for layer norm + mul + dropout. + + Args: + x: Input tensor of shape (N, D) + u: Second input tensor of shape (N, D) + weight: Layer norm weight of shape (D,) + bias: Layer norm bias of shape (D,) + eps: Layer norm epsilon + dropout_ratio: Dropout probability + training: Whether in training mode + silu_u: Whether to apply SiLU to u before concatenation + concat_u: Whether to concatenate u to output + concat_x: Whether to concatenate x to output + mul_u_activation_type: Activation type for u multiplication + seed: Random seed for dropout + + Returns: + Tuple of (y, mean, rstd, BLOCK_D, num_warps, seed, random_mask) + - random_mask is None when using fused RNG path (non-SM100+) + - random_mask is always returned when using separate RNG path (SM100+) + for reuse in backward pass (avoids redundant mask generation) + """ + assert x.dim() == 2 + x = switch_to_contiguous_if_needed(x) + N, D = x.shape + assert weight.dim() == 1 + assert bias.dim() == 1 + assert weight.numel() == D + assert bias.numel() == D + + if N == 0: + D = x.shape[1] + if concat_u and concat_x: + y = torch.empty((0, 3 * D), dtype=x.dtype, device=x.device) + elif concat_u or concat_x: + y = torch.empty((0, 2 * D), dtype=x.dtype, device=x.device) + else: + y = torch.empty_like(x) + return ( + y, + torch.empty((N,), dtype=torch.float32, device=x.device), + torch.empty((N,), dtype=torch.float32, device=x.device), + 0, + 0, + 0, + None, + ) + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_D: int = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BLOCK_D: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + + if seed is None and training: + # pyre-ignore[9]: torch.randint with dtype=int64 always returns int + seed = torch.randint(low=0, high=2**62, size=(1,), dtype=torch.int64).item() + num_warps: int = min(max(BLOCK_D // 256, 1), 8) + + # Call internal implementation + y, mean, rstd, random_mask_tensor = _triton_layer_norm_mul_dropout_fwd_impl( + x, + u, + weight, + bias, + eps, + dropout_ratio, + training, + silu_u, + concat_u, + concat_x, + mul_u_activation_type, + seed if seed is not None else 0, + ) + + # Convert empty tensor back to None + random_mask: Optional[torch.Tensor] = ( + random_mask_tensor if random_mask_tensor.numel() > 0 else None + ) + + return y, mean, rstd, BLOCK_D, num_warps, seed, random_mask # pyre-ignore[7] + + +@maybe_register_custom_op( + "generative_recommenders::_triton_layer_norm_mul_dropout_bwd_impl", mutates_args=() +) +def _triton_layer_norm_mul_dropout_bwd_impl( + dy: torch.Tensor, + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + mean: torch.Tensor, + rstd: torch.Tensor, + BLOCK_D: int, + num_warps: int, + eps: float, + training: bool, + dropout_ratio: float, + seed: int, + silu_u: bool, + concat_u: bool, + concat_x: bool, + mul_u_activation_type: str, + compute_y: bool, + random_mask: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Internal implementation that returns only tensors for custom_op compatibility. + + When compute_y is False, y is returned as an empty tensor. + random_mask with numel() == 0 means no mask (fused RNG path). + """ + N, D = x.shape + if compute_y: + if concat_u and concat_x: + y = torch.empty((N, 3 * D), dtype=x.dtype, device=x.device) + elif concat_u: + y = torch.empty((N, 2 * D), dtype=x.dtype, device=x.device) + elif concat_x: + y = torch.empty((N, 2 * D), dtype=x.dtype, device=x.device) + else: + y = torch.empty_like(x) + else: + y = torch.empty(0, dtype=x.dtype, device=x.device) + + if N == 0: + return ( + torch.zeros_like(x), + torch.zeros_like(u), + torch.zeros((D,), dtype=weight.dtype, device=x.device), + torch.zeros((D,), dtype=weight.dtype, device=x.device), + y, + ) + dx = torch.empty_like(x) + du = torch.empty_like(u) + sms = torch.cuda.get_device_properties(x.device).multi_processor_count + tile_num = max(1, min(sms * 64, N // 4)) + _dweight = torch.empty((tile_num, D), dtype=torch.float32, device=x.device) + _dbias = torch.empty((tile_num, D), dtype=torch.float32, device=x.device) + dweight = torch.empty((D,), dtype=weight.dtype, device=x.device) + dbias = torch.empty((D,), dtype=weight.dtype, device=x.device) + + # Use separated RNG when random_mask is provided (from forward pass on SM100+ path) + has_random_mask = random_mask.numel() > 0 + if has_random_mask: + # pyre-ignore[28] + _ln_mul_dropout_bwd_dx_du_rng[(tile_num,)]( + dx, + du, + dy, + _dweight, + _dbias, + x, + u, + y if compute_y else None, + weight, + bias, + mean, + rstd, + random_mask, + dx.stride(0), + du.stride(0), + dy.stride(0), + x.stride(0), + u.stride(0), + y.stride(0) if compute_y else 0, + random_mask.stride(0), + D, + eps, + dropout_ratio, + N=N, + SILU_U=silu_u, + BLOCK_D=BLOCK_D, + TRAINING=training, + CONCAT_U=concat_u, + CONCAT_X=concat_x, + MUL_U_ACTIVATION_TYPE=mul_u_activation_type, + COMPUTE_Y=compute_y, + num_warps=num_warps, + ) + + else: + # pyre-ignore[28] + _ln_mul_dropout_bwd_dx_du[(tile_num,)]( + dx, + du, + dy, + _dweight, + _dbias, + x, + u, + y if compute_y else None, + weight, + bias, + mean, + rstd, + dx.stride(0), + du.stride(0), + dy.stride(0), + x.stride(0), + u.stride(0), + y.stride(0) if compute_y else 0, + D, + eps, + seed, + dropout_ratio, + N=N, + SILU_U=silu_u, + BLOCK_D=BLOCK_D, + TRAINING=training, + CONCAT_U=concat_u, + CONCAT_X=concat_x, + MUL_U_ACTIVATION_TYPE=mul_u_activation_type, + COMPUTE_Y=compute_y, + FAST_DROPOUT=COMPUTE_OUTPUT_LN_FAST_DROPOUT, + num_warps=num_warps, + ) + + def grid(META): + return (triton.cdiv(D, META["BLOCK_D"]),) + + blocks = triton.next_power_of_2(sms * 4) + BLOCK_D_bwd = triton.next_power_of_2(triton.cdiv(D, blocks)) + BLOCK_D_bwd = min(max(BLOCK_D_bwd, 4), 128) + _ln_mul_dropout_bwd_dwdb[grid]( + _dweight, + _dbias, + dweight, + dbias, + tile_num, + D, + BLOCK_D=BLOCK_D_bwd, + ) + return dx, du, dweight, dbias, y + + +@_triton_layer_norm_mul_dropout_bwd_impl.register_fake +def _triton_layer_norm_mul_dropout_bwd_impl_fake( + dy: torch.Tensor, + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + mean: torch.Tensor, + rstd: torch.Tensor, + BLOCK_D: int, + num_warps: int, + eps: float, + training: bool, + dropout_ratio: float, + seed: int, + silu_u: bool, + concat_u: bool, + concat_x: bool, + mul_u_activation_type: str, + compute_y: bool, + random_mask: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Fake implementation for FakeTensor tracing.""" + N, D = x.shape + dx = torch.empty_like(x) + du = torch.empty_like(u) + dweight = torch.empty((D,), dtype=weight.dtype, device=x.device) + dbias = torch.empty((D,), dtype=weight.dtype, device=x.device) + if compute_y: + if concat_u and concat_x: + y = torch.empty((N, 3 * D), dtype=x.dtype, device=x.device) + elif concat_u: + y = torch.empty((N, 2 * D), dtype=x.dtype, device=x.device) + elif concat_x: + y = torch.empty((N, 2 * D), dtype=x.dtype, device=x.device) + else: + y = torch.empty_like(x) + else: + y = torch.empty(0, dtype=x.dtype, device=x.device) + return dx, du, dweight, dbias, y + + +def triton_layer_norm_mul_dropout_bwd( + dy: torch.Tensor, + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + mean: torch.Tensor, + rstd: torch.Tensor, + BLOCK_D: int, + num_warps: int, + eps: float, + training: bool, + dropout_ratio: float, + seed: Optional[int] = None, + silu_u: bool = False, + concat_u: bool = False, + concat_x: bool = False, + mul_u_activation_type: str = "none", + compute_y: bool = False, + random_mask: Optional[torch.Tensor] = None, +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor] +]: + N, D = x.shape + + # Use empty tensor as sentinel for no random_mask + random_mask_tensor = ( + random_mask + if random_mask is not None + else torch.empty(0, dtype=x.dtype, device=x.device) + ) + + dx, du, dweight, dbias, y_tensor = _triton_layer_norm_mul_dropout_bwd_impl( + dy, + x, + u, + weight, + bias, + mean, + rstd, + BLOCK_D, + num_warps, + eps, + training, + dropout_ratio, + seed if seed is not None else 0, + silu_u, + concat_u, + concat_x, + mul_u_activation_type, + compute_y, + random_mask_tensor, + ) + + # Convert empty tensor back to None + y: Optional[torch.Tensor] = y_tensor if compute_y else None + return dx, du, dweight, dbias, y + + +class LayerNormMulDropoutFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_ux: bool = False, + seed: Optional[int] = None, + ) -> torch.Tensor: + if dropout_ratio == 0.0: + # skip dropout computation if dropout ratio is 0 + training = False + # skipping supporting concat_u and concat_x separately here because seems like this code path is only used in v1 of hstu_linear + concat_u, concat_x = concat_ux, concat_ux + + # Call forward function which generates and returns random_mask + # On SM100+ path, random_mask is always returned for backward reuse + # On fused RNG path, random_mask is None (mask generated inline in kernel) + y, mean, rstd, BLOCK_D, num_warps, returned_seed, random_mask = ( + triton_layer_norm_mul_dropout_fwd( + x=x, + u=u, + weight=weight, + bias=bias, + eps=eps, + dropout_ratio=dropout_ratio, + training=training, + silu_u=silu_u, + concat_u=concat_u, + concat_x=concat_x, + seed=seed, + ) + ) + + # Save tensors for backward pass + # When random_mask is generated (SM100+ path), always save it for reuse + # in backward pass. This avoids redundant _generate_random_mask execution. + if random_mask is not None: + ctx.save_for_backward(x, u, weight, bias, mean, rstd, random_mask) + ctx.has_random_mask = True + else: + ctx.save_for_backward(x, u, weight, bias, mean, rstd) + ctx.has_random_mask = False + + ctx.BLOCK_D = BLOCK_D + ctx.num_warps = num_warps + ctx.eps = eps + ctx.seed = returned_seed + ctx.training = training + ctx.concat_ux = concat_ux + ctx.silu_u = silu_u + ctx.dropout_ratio = dropout_ratio + return y + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, dy: torch.Tensor + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + None, + None, + None, + None, + None, + None, + ]: + # Extract saved tensors including optional random mask + if ctx.has_random_mask: + x, u, weight, bias, mean, rstd, random_mask = ctx.saved_tensors + else: + x, u, weight, bias, mean, rstd = ctx.saved_tensors + random_mask = None + + dx, du, dweight, dbias, _ = triton_layer_norm_mul_dropout_bwd( + dy=dy, + x=x, + u=u, + weight=weight, + bias=bias, + mean=mean, + rstd=rstd, + BLOCK_D=ctx.BLOCK_D, + num_warps=ctx.num_warps, + eps=ctx.eps, + training=ctx.training, + dropout_ratio=ctx.dropout_ratio, + seed=ctx.seed, + silu_u=ctx.silu_u, + concat_u=ctx.concat_ux, + concat_x=ctx.concat_ux, + compute_y=False, + random_mask=random_mask, # Pass saved mask to backward + ) + return dx, du, dweight, dbias, None, None, None, None, None, None + + +@triton.jit +def _group_norm_mul_dropout_fwd( + X, + U, + Y, + W, + B, + Mean, + Rstd, + D, + Heads, + eps, + seed, + dropout_ratio, + stride_x, + stride_u, + stride_y, + SILU_U: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_H: tl.constexpr, + TRAINING: tl.constexpr, + CONCAT_UX: tl.constexpr, +): + row = tl.program_id(0) + X += row.to(tl.int64) * stride_x + U += row.to(tl.int64) * stride_u + Y += row.to(tl.int64) * stride_y + cols = tl.arange(0, BLOCK_D) + heads = tl.arange(0, BLOCK_H) + offsets = heads[:, None] * D + cols[None, :] + mask_h = heads < Heads + mask_c = cols < D + mask = mask_c[None, :] & mask_h[:, None] + + # Compute mean + mean = 0.0 + x = tl.load(X + offsets, mask=mask, other=0.0).to(tl.float32) + mean = tl.sum(x, axis=1) / D + mean = tl.ravel(mean) + + # Compute variance + _var = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32) + x_mean = tl.where(mask, x - mean[:, None], 0.0) + _var += x_mean * x_mean + var = tl.sum(_var, axis=1) / D + var = tl.ravel(var) + rstd = 1 / tl.sqrt(var + eps) + tl.store(Mean + row * Heads + heads, mean, mask=mask_h) + tl.store(Rstd + row * Heads + heads, rstd, mask=mask_h) + + # Normalize and apply linear transformation + y = x_mean * rstd[:, None] # pyre-ignore [16] + w = tl.load(W + heads, mask=mask_h).to(tl.float32) + b = tl.load(B + heads, mask=mask_h).to(tl.float32) + y = y * w[:, None] + b[:, None] + u = tl.load(U + offsets, mask=mask, other=0.0).to(tl.float32) + if SILU_U: + u = u * tl.sigmoid(u) + y = y * u + + if TRAINING: + if CONCAT_UX: + random_offsets = row * 3 * D * Heads + offsets + # apply dropout on u + random_u = tl.rand(seed, random_offsets) + u_keep = random_u > dropout_ratio + u = tl.where(u_keep, u / (1.0 - dropout_ratio), 0.0) + # apply dropout on x + random_x = tl.rand(seed, random_offsets + Heads * D) + x_keep = random_x > dropout_ratio + x = tl.where(x_keep, x / (1.0 - dropout_ratio), 0.0) + # apply dropout on y + random_y = tl.rand(seed, random_offsets + 2 * Heads * D) + y_keep = random_y > dropout_ratio + y = tl.where(y_keep, y / (1.0 - dropout_ratio), 0.0) + else: + random_offsets = row * D * Heads + offsets + random = tl.rand(seed, random_offsets) + y_keep = random > dropout_ratio + # write-back + y = tl.where(y_keep, y / (1.0 - dropout_ratio), 0.0) + + # Write output + if CONCAT_UX: + tl.store(Y + offsets, u.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + Heads * D + offsets, x.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + 2 * Heads * D + offsets, y.to(Y.dtype.element_ty), mask=mask) + else: + tl.store(Y + offsets, y.to(Y.dtype.element_ty), mask=mask) + + +@triton.jit +def _group_norm_mul_dropout_bwd_dx_du( + DX, + DU, + DY, + DW, + DB, + X, + U, + Y, + W, + B, + Mean, + Rstd, + stride_dx, + stride_du, + stride_dy, + stride_x, + stride_u, + stride_y, + D, + Heads, + eps, + seed, + dropout_ratio, + SILU_U: tl.constexpr, + GROUP_N: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_H: tl.constexpr, + TRAINING: tl.constexpr, + CONCAT_UX: tl.constexpr, + COMPUTE_Y: tl.constexpr, +): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_D) + off_heads = tl.arange(0, BLOCK_H) + mask_c = cols < D + mask_h = off_heads < Heads + mask = mask_c[None, :] & mask_h[:, None] + X += row.to(tl.int64) * stride_x + U += row.to(tl.int64) * stride_u + DY += row.to(tl.int64) * stride_dy + DX += row.to(tl.int64) * stride_dx + DU += row.to(tl.int64) * stride_du + offsets = off_heads[:, None] * D + cols[None, :] + + # Load data to SRAM + x = tl.load(X + offsets, mask=mask, other=0).to(tl.float32) + if CONCAT_UX: + du = tl.load(DY + offsets, mask=mask, other=0).to(tl.float32) + dx = tl.load(DY + Heads * D + offsets, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + 2 * Heads * D + offsets, mask=mask, other=0).to(tl.float32) + else: + du = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32) + dx = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32) + dy = tl.load(DY + offsets, mask=mask, other=0).to(tl.float32) + if TRAINING: + if CONCAT_UX: + random_offsets = row * 3 * D * Heads + offsets + # apply dropout on du + random_du = tl.rand(seed, random_offsets) + du_keep = random_du > dropout_ratio + du = tl.where(du_keep, du / (1.0 - dropout_ratio), 0.0) + # apply dropout on dx + random_dx = tl.rand(seed, random_offsets + Heads * D) + dx_keep = random_dx > dropout_ratio + dx = tl.where(dx_keep, dx / (1.0 - dropout_ratio), 0.0) + # apply dropout on dy + random_dy = tl.rand(seed, random_offsets + 2 * Heads * D) + dy_keep = random_dy > dropout_ratio + dy = tl.where(dy_keep, dy / (1.0 - dropout_ratio), 0.0) + else: + random_offsets = row * D * Heads + offsets + random = tl.rand(seed, random_offsets) + dy_keep = random > dropout_ratio + # write-back + dy = tl.where(dy_keep, dy / (1.0 - dropout_ratio), 0.0) + + mean = tl.load(Mean + row * Heads + off_heads) + rstd = tl.load(Rstd + row * Heads + off_heads) + + # Compute dx + xhat = (x - mean[:, None]) * rstd[:, None] + w = tl.load(W + off_heads, mask=mask_h).to(tl.float32) + b = tl.load(B + off_heads, mask=mask_h).to(tl.float32) + u = tl.load(U + offsets, mask=mask, other=0).to(tl.float32) + ln = xhat * w[:, None] + b[:, None] + du += dy * ln + if SILU_U: + sig_u = tl.sigmoid(u) + silu_u = u * sig_u + du = du * sig_u * (1 + u - silu_u) + u = silu_u + tl.store(DU + offsets, du.to(DU.dtype.element_ty), mask=mask) + dy = dy * u + wdy = w[:, None] * dy + if COMPUTE_Y: + Y += row.to(tl.int64) * stride_y + y = ln * u + if TRAINING: + if CONCAT_UX: + u = tl.where( + du_keep, # pyre-ignore [61] + u / (1.0 - dropout_ratio), + 0.0, + ) + x = tl.where( + dx_keep, # pyre-ignore [61] + x / (1.0 - dropout_ratio), + 0.0, + ) + y = tl.where( + dy_keep, # pyre-ignore [61] + y / (1.0 - dropout_ratio), + 0.0, + ) + else: + y = tl.where( + dy_keep, # pyre-ignore [61] + y / (1.0 - dropout_ratio), + 0.0, + ) + if CONCAT_UX: + tl.store(Y + offsets, u.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + Heads * D + offsets, x.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + 2 * Heads * D + offsets, y.to(Y.dtype.element_ty), mask=mask) + else: + tl.store(Y + offsets, y.to(Y.dtype.element_ty), mask=mask) + + xhat = tl.where(mask, xhat, 0.0) + wdy = tl.where(mask, wdy, 0.0) + c1 = tl.sum(xhat * wdy, axis=1) / D + c2 = tl.sum(wdy, axis=1) / D + dx += (wdy - (xhat * c1[:, None] + c2[:, None])) * rstd[:, None] + # Write dx + tl.store(DX + offsets, dx, mask=mask) + + # Offset locks and weights/biases gradient pointer for parallel reduction + lock_id = row % GROUP_N + DW = DW + lock_id * Heads + off_heads + DB = DB + lock_id * Heads + off_heads + # Accumulate partial sums for dw/db + partial_dw = tl.sum(dy * xhat, axis=1) + partial_dw = tl.ravel(partial_dw) + partial_db = tl.sum(dy, axis=1) + partial_db = tl.ravel(partial_db) + tl.atomic_add( + DW, + partial_dw, + mask=mask_h, + sem="relaxed", + ) + tl.atomic_add( + DB, + partial_db, + mask=mask_h, + sem="relaxed", + ) + + +def triton_group_norm_mul_dropout_fwd( + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_ux: bool = False, + num_heads: int = 1, + linear_dim: int = -1, + seed: Optional[int] = None, +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, int +]: # y, mean, rstd, BLOCK_D, BLOCK_H, num_warps, seed + assert x.dim() == 2 + assert x.shape == u.shape + assert x.shape[1] == num_heads * linear_dim + x = switch_to_contiguous_if_needed(x) + u = switch_to_contiguous_if_needed(u) + N, _ = x.shape + assert weight.dim() == 1 + assert bias.dim() == 1 + assert weight.numel() == num_heads + assert bias.numel() == num_heads + + if concat_ux: + y = torch.empty((N, 3 * num_heads * linear_dim), dtype=x.dtype, device=x.device) + else: + y = torch.empty((N, num_heads * linear_dim), dtype=x.dtype, device=x.device) + mean = torch.empty((N * num_heads,), dtype=torch.float32, device=x.device) + rstd = torch.empty((N * num_heads,), dtype=torch.float32, device=x.device) + if N == 0: + return y, mean, rstd, 0, 0, 0, 0 + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_D: int = triton.next_power_of_2(linear_dim) + BLOCK_H: int = triton.next_power_of_2(num_heads) + if BLOCK_D * BLOCK_H > MAX_FUSED_SIZE: + raise RuntimeError( + "This group norm doesn't support num_heads * linear_dim >= 64KB." + ) + + if seed is None: + seed = torch.randint(low=0, high=2**62, size=(1,), dtype=torch.int64).item() + num_warps: int = min(max(BLOCK_D * BLOCK_H // 256, 1), 8) + # pyre-ignore[28] + _group_norm_mul_dropout_fwd[(N,)]( + x, + u, + y, + weight, + bias, + mean, + rstd, + linear_dim, + num_heads, + eps, + seed, + dropout_ratio, + x.stride(0), + u.stride(0), + y.stride(0), + SILU_U=silu_u, + BLOCK_D=BLOCK_D, + BLOCK_H=BLOCK_H, + TRAINING=training, + CONCAT_UX=concat_ux, + num_warps=num_warps, + ) + return y, mean, rstd, BLOCK_D, BLOCK_H, num_warps, seed # pyre-ignore [7] + + +def triton_group_norm_mul_dropout_bwd( + dy: torch.Tensor, + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + mean: torch.Tensor, + rstd: torch.Tensor, + BLOCK_D: int, + BLOCK_H: int, + num_warps: int, + eps: float, + training: bool, + dropout_ratio: float, + seed: Optional[int] = None, + silu_u: bool = False, + concat_ux: bool = False, + num_heads: int = 1, + linear_dim: int = -1, + compute_y: bool = False, +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor] +]: + y = None + N, dim = x.shape + if compute_y: + if concat_ux: + y = torch.empty( + (N, 3 * num_heads * linear_dim), dtype=x.dtype, device=x.device + ) + else: + y = torch.empty((N, num_heads * linear_dim), dtype=x.dtype, device=x.device) + if N == 0: + return ( + torch.zeros_like(x), + torch.zeros_like(u), + torch.zeros_like(weight), + torch.zeros_like(bias), + y, + ) + dx = torch.empty_like(x) + du = torch.empty_like(u) + if dim <= 1024: + GROUP_N = 256 * 8 + elif dim <= 4096: + GROUP_N = 128 * 8 + elif dim <= 8192: + GROUP_N = 96 * 8 + else: + GROUP_N = 64 * 8 + GROUP_N = N if GROUP_N > N else GROUP_N + _dweight = torch.zeros((GROUP_N, num_heads), dtype=torch.float32, device=x.device) + _dbias = torch.zeros((GROUP_N, num_heads), dtype=torch.float32, device=x.device) + dweight = torch.empty((num_heads,), dtype=weight.dtype, device=x.device) + dbias = torch.empty((num_heads,), dtype=weight.dtype, device=x.device) + # pyre-ignore[28] + _group_norm_mul_dropout_bwd_dx_du[(N,)]( + dx, + du, + dy, + _dweight, + _dbias, + x, + u, + y, + weight, + bias, + mean, + rstd, + dx.stride(0), + du.stride(0), + dy.stride(0), + x.stride(0), + u.stride(0), + y.stride(0) if compute_y else 0, # pyre-ignore [16] + linear_dim, + num_heads, + eps, + seed, + dropout_ratio, + SILU_U=silu_u, + GROUP_N=GROUP_N, + BLOCK_D=BLOCK_D, + BLOCK_H=BLOCK_H, + TRAINING=training, + CONCAT_UX=concat_ux, + COMPUTE_Y=compute_y, + num_warps=num_warps, + ) + _group_norm_bwd_dwdb[(num_heads,)]( + _dweight, + _dbias, + dweight, + dbias, + GROUP_N, + ) + return dx, du, dweight, dbias, y + + +def _get_bwd_dwdb_configs() -> List[triton.Config]: + configs = [] + for BLOCK_N in [32, 64, 128, 256]: + for num_warps in [8, 16] + ([] if torch.ops.hip else [32]): + configs.append( + triton.Config( + {"BLOCK_N": BLOCK_N}, + num_warps=num_warps, + ) + ) + return configs + + +@triton_autotune( + configs=_get_bwd_dwdb_configs(), + key=[], +) +@triton.jit +def _group_norm_bwd_dwdb( + DW, + DB, + FINAL_DW, + FINAL_DB, + N, + BLOCK_N: tl.constexpr, +): + col = tl.program_id(0) + num_heads = tl.num_programs(0) + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + + for i in range(0, N, BLOCK_N): + rows = i + tl.arange(0, BLOCK_N) + mask = rows < N + offs = rows * num_heads + col + dw += tl.load(DW + offs, mask=mask, other=0.0) + db += tl.load(DB + offs, mask=mask, other=0.0) + + sum_dw = tl.sum(dw, axis=0) + sum_db = tl.sum(db, axis=0) + tl.store(FINAL_DW + col, sum_dw.to(FINAL_DW.dtype.element_ty)) + tl.store(FINAL_DB + col, sum_db.to(FINAL_DB.dtype.element_ty)) + + +class GroupNormMulDropoutFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_ux: bool = False, + num_heads: int = 1, + linear_dim: int = -1, + seed: Optional[int] = None, + ) -> torch.Tensor: + y, mean, rstd, BLOCK_D, BLOCK_H, num_warps, seed = ( + triton_group_norm_mul_dropout_fwd( + x=x, + u=u, + weight=weight, + bias=bias, + eps=eps, + dropout_ratio=dropout_ratio, + training=training, + silu_u=silu_u, + concat_ux=concat_ux, + num_heads=num_heads, + linear_dim=linear_dim, + seed=seed, + ) + ) + ctx.save_for_backward(x, u, weight, bias, mean, rstd) + ctx.BLOCK_D = BLOCK_D + ctx.BLOCK_H = BLOCK_H + ctx.num_warps = num_warps + ctx.eps = eps + ctx.seed = seed + ctx.training = training + ctx.silu_u = silu_u + ctx.concat_ux = concat_ux + ctx.dropout_ratio = dropout_ratio + ctx.num_heads = num_heads + ctx.linear_dim = linear_dim + return y + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, dy: torch.Tensor + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + None, + None, + None, + None, + None, + None, + None, + None, + ]: + x, u, weight, bias, mean, rstd = ctx.saved_tensors + dx, du, dweight, dbias, _ = triton_group_norm_mul_dropout_bwd( + dy=dy, + x=x, + u=u, + weight=weight, + bias=bias, + mean=mean, + rstd=rstd, + BLOCK_D=ctx.BLOCK_D, + BLOCK_H=ctx.BLOCK_H, + num_warps=ctx.num_warps, + eps=ctx.eps, + training=ctx.training, + dropout_ratio=ctx.dropout_ratio, + seed=ctx.seed, + silu_u=ctx.silu_u, + concat_ux=ctx.concat_ux, + num_heads=ctx.num_heads, + linear_dim=ctx.linear_dim, + compute_y=False, + ) + return ( + dx, + du, + dweight, + dbias, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +class HSTUComputeOutputFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + attn: torch.Tensor, + u: torch.Tensor, + x: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + output_weight: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_u: bool = False, + concat_x: bool = False, + mul_u_activation_type: str = "none", + group_norm: bool = False, + num_heads: int = 1, + linear_dim: int = -1, + seed: Optional[int] = None, + recompute_y_in_backward: bool = False, + ) -> torch.Tensor: + if dropout_ratio == 0.0: + training = False + + if group_norm: + y, mean, rstd, BLOCK_D, BLOCK_H, num_warps, seed = ( + triton_group_norm_mul_dropout_fwd( + x=attn, + u=u, + weight=norm_weight, + bias=norm_bias, + eps=eps, + dropout_ratio=dropout_ratio, + training=training, + silu_u=silu_u, + concat_ux=concat_u and concat_x, + num_heads=num_heads, + linear_dim=linear_dim, + seed=seed, + ) + ) + ctx.BLOCK_H = BLOCK_H + random_mask = None + else: + y, mean, rstd, BLOCK_D, num_warps, seed, random_mask = ( + triton_layer_norm_mul_dropout_fwd( + x=attn, + u=u, + weight=norm_weight, + bias=norm_bias, + eps=eps, + dropout_ratio=dropout_ratio, + training=training, + silu_u=silu_u, + concat_u=concat_u, + concat_x=concat_x, + seed=seed, + ) + ) + + out = maybe_triton_addmm_fwd(x=y, w=output_weight, y=x) + + saved_tensors = [attn, u, norm_weight, norm_bias, mean, rstd, output_weight] + if not recompute_y_in_backward: + saved_tensors.append(y) + # Save random_mask for reuse in backward pass (avoids regenerating mask) + # When random_mask is available (SM100+ path), always save it. + if random_mask is not None: + saved_tensors.append(random_mask) + ctx.has_random_mask = True + else: + ctx.has_random_mask = False + ctx.save_for_backward(*saved_tensors) + ctx.BLOCK_D = BLOCK_D + ctx.num_warps = num_warps + ctx.eps = eps + ctx.seed = seed + ctx.training = training + ctx.concat_u = concat_u + ctx.concat_x = concat_x + ctx.dropout_ratio = dropout_ratio + ctx.num_heads = num_heads + ctx.linear_dim = linear_dim + ctx.group_norm = group_norm + ctx.recompute_y_in_backward = recompute_y_in_backward + ctx.silu_u = silu_u + ctx.mul_u_activation_type = mul_u_activation_type + return out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, dout: torch.Tensor + ) -> Tuple[ + torch.Tensor, # dattn + torch.Tensor, # du + torch.Tensor, # dx + torch.Tensor, # d_norm_weight + torch.Tensor, # d_norm_bias + torch.Tensor, # d_output_weight + None, # eps + None, # dropout_ratio + None, # training + None, # silu_u + None, # concat_u + None, # concat_x + None, # mul_u_activation_type + None, # group_norm + None, # num_heads + None, # linear_dim + None, # seed + None, # recompute_y_in_backward + ]: + attn, u, norm_weight, norm_bias, mean, rstd, output_weight = ctx.saved_tensors[ + :7 + ] + # Extract optional saved tensors based on flags + next_idx = 7 + if not ctx.recompute_y_in_backward: + saved_y = ctx.saved_tensors[next_idx] + next_idx += 1 + else: + saved_y = None + if ctx.has_random_mask: + random_mask = ctx.saved_tensors[next_idx] + else: + random_mask = None + dy = torch.mm(dout, output_weight.t()) + + if ctx.group_norm: + dattn, du, d_norm_weight, d_norm_bias, y = ( + triton_group_norm_mul_dropout_bwd( + dy=dy, + x=attn, + u=u, + weight=norm_weight, + bias=norm_bias, + mean=mean, + rstd=rstd, + BLOCK_D=ctx.BLOCK_D, + BLOCK_H=ctx.BLOCK_H, + num_warps=ctx.num_warps, + eps=ctx.eps, + training=ctx.training, + dropout_ratio=ctx.dropout_ratio, + seed=ctx.seed, + silu_u=ctx.silu_u, + concat_ux=ctx.concat_u and ctx.concat_x, + num_heads=ctx.num_heads, + linear_dim=ctx.linear_dim, + compute_y=ctx.recompute_y_in_backward, + ) + ) + else: + dattn, du, d_norm_weight, d_norm_bias, y = ( + triton_layer_norm_mul_dropout_bwd( + dy=dy, + x=attn, + u=u, + weight=norm_weight, + bias=norm_bias, + mean=mean, + rstd=rstd, + BLOCK_D=ctx.BLOCK_D, + num_warps=ctx.num_warps, + eps=ctx.eps, + training=ctx.training, + dropout_ratio=ctx.dropout_ratio, + seed=ctx.seed, + silu_u=ctx.silu_u, + concat_u=ctx.concat_u, + concat_x=ctx.concat_x, + mul_u_activation_type=ctx.mul_u_activation_type, + compute_y=ctx.recompute_y_in_backward, + random_mask=random_mask, + ) + ) + if not ctx.recompute_y_in_backward: + y = saved_y + d_output_weight = torch.mm(y.t(), dout) + return ( + dattn, + du, + dout, + d_norm_weight, + d_norm_bias, + d_output_weight, + None, # eps + None, # dropout_ratio + None, # training + None, # silu_u + None, # concat_u + None, # concat_x + None, # mul_u_activation_type + None, # group_norm + None, # num_heads + None, # linear_dim + None, # seed + None, # recompute_y_in_backward + ) + + +@triton.jit +def _helion_ln_mul_dropout_fwd( + x, + weight, + bias, + u, + y, + mean, + rstd, + eps, + seed, + dropout_ratio, + D: tl.constexpr, + stride_x: tl.constexpr, + stride_u: tl.constexpr, + stride_y: tl.constexpr, + BLOCK_D: tl.constexpr, + CONCAT_UX: tl.constexpr, + SILU_U: tl.constexpr, + TRAINING: tl.constexpr, +): + row = tl.program_id(0) + x += row.to(tl.int64) * stride_x + u += row.to(tl.int64) * stride_u + y += row.to(tl.int64) * stride_y + cols = tl.arange(0, BLOCK_D) + mask = cols < D + + # Load input + x_val = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32) + + # Precompute inverse of D for faster computation + inv_D = 1.0 / D + + # Compute mean + mean_val = tl.sum(x_val, axis=0) * inv_D + + # Center the data + x_mean = tl.where(mask, x_val - mean_val, 0.0) + + # Compute variance + var = tl.sum(x_mean * x_mean, axis=0) * inv_D + + # Compute reciprocal standard deviation + # pyre-fixme[16] + rstd_val = libdevice.rsqrt(var + eps) + + # Normalize + y_norm = x_mean * rstd_val + + # Apply weight and bias + w = tl.load(weight + cols, mask=mask, other=0.0).to(tl.float32) + b = tl.load(bias + cols, mask=mask, other=0.0).to(tl.float32) + y_ln = y_norm * w + b + + # Load u and optionally apply SiLU activation + u_val = tl.load(u + cols, mask=mask, other=0.0).to(tl.float32) + if SILU_U: + u_processed = u_val * tl.sigmoid(u_val) + else: + u_processed = u_val + + y_out = y_ln * u_processed + + if TRAINING: + # Compute dropout scale + # pyre-fixme[16] + dropout_scale = fast_dividef(1.0, 1.0 - dropout_ratio) + + if CONCAT_UX: + # Generate dropout masks + random_offsets = 3 * row * BLOCK_D + cols + random_u, random_x, random_y = rand3x(seed, random_offsets) + + u_keep = random_u > dropout_ratio + x_keep = random_x > dropout_ratio + y_keep = random_y > dropout_ratio + + # Apply dropout to u, x, y + u_output = tl.where(u_keep, u_processed * dropout_scale, 0.0) + x_output = tl.where(x_keep, x_val * dropout_scale, 0.0) + y_output = tl.where(y_keep, y_out * dropout_scale, 0.0) + else: + # Generate dropout mask for y + random_offsets = row * BLOCK_D + cols + random_y = tl.rand(seed, random_offsets) + y_keep = random_y > dropout_ratio + + # Apply dropout to y + y_output = tl.where(y_keep, y_out * dropout_scale, 0.0) + else: + if CONCAT_UX: + u_output = u_processed + x_output = x_val + y_output = y_out + + # Store outputs + if CONCAT_UX: + tl.store(y + cols, u_output.to(y.dtype.element_ty), mask=mask) + tl.store(y + D + cols, x_output.to(y.dtype.element_ty), mask=mask) + tl.store(y + 2 * D + cols, y_output.to(y.dtype.element_ty), mask=mask) + else: + tl.store(y + cols, y_output.to(y.dtype.element_ty), mask=mask) + + # Store mean and rstd + tl.store(mean + row, mean_val) + tl.store(rstd + row, rstd_val) + + +def helion_layer_norm_mul_dropout_fwd( + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_ux: bool = False, + seed: Optional[int] = None, +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, int, int, int +]: # y, mean, rstd, BLOCK_D, num_warps, seed + N, D = x.shape + + if seed is None: + seed = torch.randint(low=0, high=2**62, size=(1,), dtype=torch.int64).item() + + if concat_ux: + y = torch.empty([N, 3 * D], dtype=x.dtype, device=x.device) + else: + y = torch.empty([N, D], dtype=x.dtype, device=x.device) + mean = torch.empty([N], dtype=torch.float32, device=x.device) + rstd = torch.empty([N], dtype=torch.float32, device=x.device) + + BLOCK_D = triton.next_power_of_2(D) + # pyre-ignore[28] + _helion_ln_mul_dropout_fwd[(N,)]( + x, + weight, + bias, + u, + y, + mean, + rstd, + eps, + seed, + dropout_ratio, + D, + x.stride(0), + u.stride(0), + y.stride(0), + BLOCK_D, + CONCAT_UX=concat_ux, + SILU_U=silu_u, + TRAINING=training, + num_warps=1, + ) + + return y, mean, rstd, BLOCK_D, 1, seed # pyre-ignore [7] + + +@triton.jit +def _helion_ln_mul_dropout_bwd_dx_du( + DX, + DU, + DY, + DW, + DB, + X, + U, + Y, + W, + B, + Mean, + Rstd, + stride_dx, + stride_du, + stride_dy, + stride_x, + stride_u, + stride_y, + D: tl.constexpr, + eps, + seed, + dropout_ratio, + N, + SILU_U: tl.constexpr, + BLOCK_D: tl.constexpr, + TRAINING: tl.constexpr, + CONCAT_UX: tl.constexpr, + COMPUTE_Y: tl.constexpr, +): + pid = tl.program_id(0) + tile_num = tl.num_programs(0) + rows_per_tile = N // tile_num + if pid < N % tile_num: + rows_per_tile += 1 + + if rows_per_tile == 0: + return + + cols = tl.arange(0, BLOCK_D) + mask = cols < D + + # precompute inverse of D + inv_D: tl.constexpr = 1.0 / D + + row = pid + X += row.to(tl.int64) * stride_x + U += row.to(tl.int64) * stride_u + if COMPUTE_Y: + Y += row.to(tl.int64) * stride_y + DY += row.to(tl.int64) * stride_dy + DX += row.to(tl.int64) * stride_dx + DU += row.to(tl.int64) * stride_du + DW = DW + pid * D + cols + DB = DB + pid * D + cols + + partial_dw = tl.zeros((BLOCK_D,), dtype=tl.float32) + partial_db = tl.zeros((BLOCK_D,), dtype=tl.float32) + w = tl.load(W + cols, mask=mask).to(tl.float32) + b = tl.load(B + cols, mask=mask).to(tl.float32) + + for _idx in range(0, rows_per_tile): + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + if CONCAT_UX: + du = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + dx = tl.load(DY + D + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + 2 * D + cols, mask=mask, other=0).to(tl.float32) + else: + du = tl.zeros([BLOCK_D], dtype=tl.float32) + dx = tl.zeros([BLOCK_D], dtype=tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + + if TRAINING: + # pyre-fixme[16] + dropout_scale = fast_dividef(1.0, 1.0 - dropout_ratio) + if CONCAT_UX: + random_offsets = 3 * row * BLOCK_D + cols + # apply dropout on du + random_du, random_dx, random_dy = rand3x(seed, random_offsets) + du_keep = random_du > dropout_ratio + du = tl.where(du_keep, du * dropout_scale, 0.0) + # apply dropout on dx + dx_keep = random_dx > dropout_ratio + dx = tl.where(dx_keep, dx * dropout_scale, 0.0) + # apply dropout on dy + dy_keep = random_dy > dropout_ratio + dy = tl.where(dy_keep, dy * dropout_scale, 0.0) + else: + random_offsets = row * BLOCK_D + cols + random = tl.rand(seed, random_offsets) + dy_keep = random > dropout_ratio + dy = tl.where(dy_keep, dy * dropout_scale, 0.0) + + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + + # Compute dx + xhat = (x - mean) * rstd + u = tl.load(U + cols, mask=mask, other=0).to(tl.float32) + ln = xhat * w + b + du += dy * ln + + if SILU_U: + sig_u = tl.sigmoid(u) + silu_u = u * sig_u + du = du * sig_u * (1 + u - silu_u) + u = silu_u + + tl.store(DU + cols, du.to(DU.dtype.element_ty), mask=mask) + dy = dy * u + wdy = w * dy + + if COMPUTE_Y: + y = ln * u + if TRAINING: + # pyre-fixme[16] + dropout_scale_y = fast_dividef(1.0, 1.0 - dropout_ratio) + if CONCAT_UX: + u = tl.where(du_keep, u * dropout_scale_y, 0.0) # pyre-ignore [61] + x = tl.where(dx_keep, x * dropout_scale_y, 0.0) # pyre-ignore [61] + y = tl.where(dy_keep, y * dropout_scale_y, 0.0) # pyre-ignore [61] + else: + y = tl.where(dy_keep, y * dropout_scale_y, 0.0) # pyre-ignore [61] + if CONCAT_UX: + tl.store(Y + cols, u.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + D + cols, x.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + 2 * D + cols, y.to(Y.dtype.element_ty), mask=mask) + else: + tl.store(Y + cols, y.to(Y.dtype.element_ty), mask=mask) + Y += tile_num.to(tl.int64) * stride_y + + xhat = tl.where(mask, xhat, 0.0) + wdy = tl.where(mask, wdy, 0.0) + # multiply by inv_D + c1 = tl.sum(xhat * wdy, axis=0) * inv_D + c2 = tl.sum(wdy, axis=0) * inv_D + dx += (wdy - (xhat * c1 + c2)) * rstd + + # Write dx + tl.store(DX + cols, dx, mask=mask) + + # Accumulate partial sums for dw/db + partial_dw += dy * xhat + partial_db += dy + + X += tile_num.to(tl.int64) * stride_x + U += tile_num.to(tl.int64) * stride_u + DY += tile_num.to(tl.int64) * stride_dy + DX += tile_num.to(tl.int64) * stride_dx + DU += tile_num.to(tl.int64) * stride_du + row += tile_num + + tl.store(DW, partial_dw, mask=mask) + tl.store(DB, partial_db, mask=mask) + + +@triton_autotune( + configs=_get_bwd_dwdb_configs(), + key=["D"], +) +@triton.jit +def _helion_ln_mul_dropout_bwd_dwdb( + DW, + DB, + FINAL_DW, + FINAL_DB, + N, + D, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid = tl.program_id(0) + cols = pid * BLOCK_D + tl.arange(0, BLOCK_D) + dw = tl.zeros((BLOCK_N, BLOCK_D), dtype=tl.float32) + db = tl.zeros((BLOCK_N, BLOCK_D), dtype=tl.float32) + + for i in range(0, N, BLOCK_N): + rows = i + tl.arange(0, BLOCK_N) + # pyre-fixme[16]: `int` has no attribute `__getitem__`. + off_mask = (rows[:, None] < N) & (cols[None, :] < D) + offs = rows[:, None] * D + cols[None, :] + dw += tl.load(DW + offs, mask=off_mask, other=0.0) + db += tl.load(DB + offs, mask=off_mask, other=0.0) + + sum_dw = tl.sum(dw, axis=0) + sum_db = tl.sum(db, axis=0) + tl.store(FINAL_DW + cols, sum_dw.to(FINAL_DW.dtype.element_ty), mask=cols < D) + tl.store(FINAL_DB + cols, sum_db.to(FINAL_DB.dtype.element_ty), mask=cols < D) + + +def helion_layer_norm_mul_dropout_bwd( + dy: torch.Tensor, + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + mean: torch.Tensor, + rstd: torch.Tensor, + BLOCK_D: int, + num_warps: int, + eps: float, + training: bool, + dropout_ratio: float, + seed: Optional[int] = None, + silu_u: bool = False, + concat_ux: bool = False, + compute_y: bool = False, +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor] +]: + y = None + N, D = x.shape + if compute_y: + if concat_ux: + y = torch.empty((N, 3 * D), dtype=x.dtype, device=x.device) + else: + y = torch.empty_like(x) + if N == 0: + return ( + torch.zeros_like(x), + torch.zeros_like(u), + torch.zeros((D,), dtype=weight.dtype, device=x.device), + torch.zeros((D,), dtype=weight.dtype, device=x.device), + y, + ) + dx = torch.empty_like(x) + du = torch.empty_like(u) + sms = torch.cuda.get_device_properties(x.device).multi_processor_count + tile_num = max(1, min(sms * 64, N // 4)) + _dweight = torch.empty((tile_num, D), dtype=torch.float32, device=x.device) + _dbias = torch.empty((tile_num, D), dtype=torch.float32, device=x.device) + dweight = torch.empty((D,), dtype=weight.dtype, device=x.device) + dbias = torch.empty((D,), dtype=weight.dtype, device=x.device) + + # pyre-ignore[28] + _helion_ln_mul_dropout_bwd_dx_du[(tile_num,)]( + dx, + du, + dy, + _dweight, + _dbias, + x, + u, + y, + weight, + bias, + mean, + rstd, + dx.stride(0), + du.stride(0), + dy.stride(0), + x.stride(0), + u.stride(0), + y.stride(0) if compute_y else 0, # pyre-ignore [16] + D, + eps, + seed, + dropout_ratio, + N=N, + SILU_U=silu_u, + BLOCK_D=BLOCK_D, + TRAINING=training, + CONCAT_UX=concat_ux, + COMPUTE_Y=compute_y, + num_warps=num_warps, + ) + + blocks = triton.next_power_of_2(sms * 4) + BLOCK_D_DWDB = triton.next_power_of_2(triton.cdiv(D, blocks)) + BLOCK_D_DWDB = min(max(BLOCK_D_DWDB, 4), 128) + _helion_ln_mul_dropout_bwd_dwdb[(triton.cdiv(D, BLOCK_D_DWDB),)]( + _dweight, + _dbias, + dweight, + dbias, + tile_num, + D, + BLOCK_D=BLOCK_D_DWDB, + ) + return dx, du, dweight, dbias, y + + +class HelionLayerNormMulDropoutFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_ux: bool = False, + seed: Optional[int] = None, + ) -> torch.Tensor: + if dropout_ratio == 0.0: + # skip dropout computation if dropout ratio is 0 + training = False + y, mean, rstd, BLOCK_D, num_warps, seed = helion_layer_norm_mul_dropout_fwd( + x=x, + u=u, + weight=weight, + bias=bias, + eps=eps, + dropout_ratio=dropout_ratio, + training=training, + silu_u=silu_u, + concat_ux=concat_ux, + seed=seed, + ) + ctx.save_for_backward(x, u, weight, bias, mean, rstd) + ctx.BLOCK_D = BLOCK_D + ctx.num_warps = num_warps + ctx.eps = eps + ctx.seed = seed + ctx.training = training + ctx.silu_u = silu_u + ctx.concat_ux = concat_ux + ctx.dropout_ratio = dropout_ratio + return y + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, dy: torch.Tensor + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + None, + None, + None, + None, + None, + None, + ]: + x, u, weight, bias, mean, rstd = ctx.saved_tensors + dx, du, dweight, dbias, _ = helion_layer_norm_mul_dropout_bwd( + dy=dy, + x=x, + u=u, + weight=weight, + bias=bias, + mean=mean, + rstd=rstd, + BLOCK_D=ctx.BLOCK_D, + num_warps=ctx.num_warps, + eps=ctx.eps, + training=ctx.training, + dropout_ratio=ctx.dropout_ratio, + seed=ctx.seed, + silu_u=ctx.silu_u, + concat_ux=ctx.concat_ux, + compute_y=False, + ) + return dx, du, dweight, dbias, None, None, None, None, None, None + + +@torch.fx.wrap +def helion_norm_mul_dropout( + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_ux: bool = False, + seed: Optional[int] = None, +) -> torch.Tensor: + return HelionLayerNormMulDropoutFunction.apply( + x, u, weight, bias, eps, dropout_ratio, training, silu_u, concat_ux, seed + ) + + +@torch.fx.wrap +def triton_norm_mul_dropout( + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_u: bool = False, + concat_x: bool = False, + group_norm: bool = False, + num_heads: int = 1, + linear_dim: int = -1, + seed: Optional[int] = None, +) -> torch.Tensor: + if group_norm: + return GroupNormMulDropoutFunction.apply( + x, + u, + weight, + bias, + eps, + dropout_ratio, + training, + silu_u, + concat_u and concat_x, + num_heads, + linear_dim, + seed, + ) + else: + return LayerNormMulDropoutFunction.apply( + x, + u, + weight, + bias, + eps, + dropout_ratio, + training, + silu_u, + concat_u and concat_x, + seed, + ) + + +@torch.jit.unused +@torch.fx.wrap +def triton_hstu_compute_output( + attn: torch.Tensor, + u: torch.Tensor, + x: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + output_weight: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_u: bool = False, + concat_x: bool = False, + mul_u_activation_type: str = "none", + group_norm: bool = False, + num_heads: int = 1, + linear_dim: int = -1, + seed: Optional[int] = None, + recompute_y_in_backward: bool = False, +) -> torch.Tensor: + return HSTUComputeOutputFunction.apply( + attn, + u, + x, + norm_weight, + norm_bias, + output_weight, + eps, + dropout_ratio, + training, + silu_u, + concat_u, + concat_x, + mul_u_activation_type, + group_norm, + num_heads, + linear_dim, + seed, + recompute_y_in_backward, + ) diff --git a/recommendation_v4/generative_recommenders/ops/triton/triton_hstu_preprocess_and_attention.py b/recommendation_v4/generative_recommenders/ops/triton/triton_hstu_preprocess_and_attention.py new file mode 100644 index 000000000..bda97ff96 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/triton/triton_hstu_preprocess_and_attention.py @@ -0,0 +1,342 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Optional, Tuple + +import torch +from generative_recommenders.ops.triton.triton_addmm import ( + maybe_triton_addmm_fwd, + triton_addmm_bwd, + triton_addmm_fwd, +) +from generative_recommenders.ops.triton.triton_hstu_attention import ( + _should_enable_tma, + triton_hstu_attention_bwd, + triton_hstu_attention_fwd, +) +from generative_recommenders.ops.triton.triton_layer_norm import ( + compute_BLOCK_D, + triton_weighted_layer_norm_bwd, + triton_weighted_layer_norm_fwd, +) +from torch.nn import functional as F + + +class _HSTUPreprocessAndAttentionFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore [14] + def forward( + ctx, # pyre-ignore [2] + x: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + norm_eps: float, + num_heads: int, + attn_dim: int, + hidden_dim: int, + uvqk_weight: torch.Tensor, + uvqk_bias: torch.Tensor, + max_seq_len: int, + seq_offsets: torch.Tensor, + attn_alpha: float, + num_targets: Optional[torch.Tensor], + max_attn_len: int, + contextual_seq_len: int, + recompute_uvqk_in_backward: bool, + recompute_normed_x_in_backward: bool, + sort_by_length: bool, + enable_tma: bool, + num_softmax_heads: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + assert num_softmax_heads == 0, "Softmax attention is not supported" + normed_x, x_mean, x_rstd = triton_weighted_layer_norm_fwd( + x=x, + weight=norm_weight, + bias=norm_bias, + eps=norm_eps, + ) + BLOCK_D = compute_BLOCK_D(x) + uvqk = maybe_triton_addmm_fwd( + x=normed_x, w=uvqk_weight, y=uvqk_bias + ).contiguous() + u, v, q, k = uvqk.split( + [ + hidden_dim * num_heads, + hidden_dim * num_heads, + attn_dim * num_heads, + attn_dim * num_heads, + ], + dim=1, + ) + q = q.view(-1, num_heads, attn_dim) + k = k.view(-1, num_heads, attn_dim) + v = v.view(-1, num_heads, hidden_dim) + silu_u = F.silu(u) + sort_by_length_indices = None + if sort_by_length: + seq_lengths = seq_offsets[1:] - seq_offsets[:-1] + _, sort_by_length_indices = torch.sort( + seq_lengths, descending=True, stable=False + ) + out = triton_hstu_attention_fwd( + N=max_seq_len, + alpha=attn_alpha, + q=q, + k=k, + v=v, + seq_offsets=seq_offsets, + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + sort_by_length_indices=sort_by_length_indices, + enable_tma=enable_tma, + num_softmax_heads=num_softmax_heads, + ) + # update ctx + saved_tensors = [ + x, + norm_weight, + norm_bias, + x_mean, + x_rstd, + uvqk_weight, + seq_offsets, + ] + if num_targets is not None: + saved_tensors.append(num_targets) + if not recompute_normed_x_in_backward: + saved_tensors.append(normed_x) + if recompute_uvqk_in_backward: + saved_tensors.append(uvqk_bias) + else: + saved_tensors.append(uvqk) + if sort_by_length: + saved_tensors.append(sort_by_length_indices) + ctx.save_for_backward(*saved_tensors) + ctx.attn_alpha = attn_alpha + ctx.has_multiple_targets = num_targets is not None + ctx.max_seq_len = max_seq_len + ctx.max_attn_len = max_attn_len + ctx.recompute_normed_x_in_backward = recompute_normed_x_in_backward + ctx.recompute_uvqk_in_backward = recompute_uvqk_in_backward + ctx.hidden_dim = hidden_dim + ctx.attn_dim = attn_dim + ctx.num_heads = num_heads + ctx.uvqk_bias_1d = uvqk_bias.dim() == 1 + ctx.norm_eps = norm_eps + ctx.norm_BLOCK_D = BLOCK_D + ctx.contextual_seq_len = contextual_seq_len + ctx.sort_by_length = sort_by_length + ctx.enable_tma = enable_tma + ctx.num_softmax_heads = num_softmax_heads + return silu_u, out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, # pyre-ignore[2] + dsilu_u: torch.Tensor, + dout: torch.Tensor, + ) -> Tuple[ + torch.Tensor, # d_x + torch.Tensor, # d_norm_weight + torch.Tensor, # d_norm_bias + None, + None, + None, + None, + torch.Tensor, # d_uvqk_weight + torch.Tensor, # d_uvqk_bias + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ]: + x, norm_weight, norm_bias, x_mean, x_rstd, uvqk_weight, seq_offsets = ( + ctx.saved_tensors[:7] + ) + idx = 7 + if ctx.has_multiple_targets: + num_targets = ctx.saved_tensors[idx] + idx += 1 + else: + num_targets = None + if ctx.recompute_normed_x_in_backward: + normed_x, _, _ = triton_weighted_layer_norm_fwd( + x=x, + weight=norm_weight, + bias=norm_bias, + eps=ctx.norm_eps, + mean=x_mean, + rstd=x_rstd, + ) + else: + normed_x = ctx.saved_tensors[idx] + idx += 1 + if ctx.recompute_uvqk_in_backward: + uvqk_bias = ctx.saved_tensors[idx] + uvqk = maybe_triton_addmm_fwd(x=normed_x, w=uvqk_weight, y=uvqk_bias) + idx += 1 + else: + uvqk = ctx.saved_tensors[idx] + idx += 1 + if ctx.sort_by_length: + sort_by_length_indices = ctx.saved_tensors[idx] + else: + sort_by_length_indices = None + + duvqk = torch.empty_like(uvqk) + du, dv, dq, dk = duvqk.split( + [ + ctx.hidden_dim * ctx.num_heads, + ctx.hidden_dim * ctx.num_heads, + ctx.attn_dim * ctx.num_heads, + ctx.attn_dim * ctx.num_heads, + ], + dim=1, + ) + u, v, q, k = uvqk.split( + [ + ctx.hidden_dim * ctx.num_heads, + ctx.hidden_dim * ctx.num_heads, + ctx.attn_dim * ctx.num_heads, + ctx.attn_dim * ctx.num_heads, + ], + dim=1, + ) + q = q.view(-1, ctx.num_heads, ctx.attn_dim) + k = k.view(-1, ctx.num_heads, ctx.attn_dim) + v = v.view(-1, ctx.num_heads, ctx.hidden_dim) + dq = dq.view(-1, ctx.num_heads, ctx.attn_dim) + dk = dk.view(-1, ctx.num_heads, ctx.attn_dim) + dv = dv.view(-1, ctx.num_heads, ctx.hidden_dim) + # Note: the operation below updates duvqk in place + triton_hstu_attention_bwd( + dout=dout, + q=q, + k=k, + v=v, + dq=dq, + dk=dk, + dv=dv, + seq_offsets=seq_offsets, + num_targets=num_targets, + N=ctx.max_seq_len, + max_attn_len=ctx.max_attn_len, + alpha=ctx.attn_alpha, + contextual_seq_len=ctx.contextual_seq_len, + sort_by_length_indices=sort_by_length_indices, + enable_tma=ctx.enable_tma, + num_softmax_heads=ctx.num_softmax_heads, + ) + torch.ops.aten.silu_backward(dsilu_u, u, grad_input=du) + d_normed_x, d_uvqk_weight, d_uvqk_bias = triton_addmm_bwd( + x=normed_x, + w=uvqk_weight, + dz=duvqk, + is_y_1d=ctx.uvqk_bias_1d, + ) + d_x, d_norm_weight, d_norm_bias = triton_weighted_layer_norm_bwd( + dy=d_normed_x, + x=x, + weight=norm_weight, + bias=norm_bias, + mean=x_mean, + rstd=x_rstd, + learnable=True, + eps=ctx.norm_eps, + BLOCK_D=ctx.norm_BLOCK_D, + ) + # pyre-ignore[7] + return ( + d_x, + d_norm_weight, + d_norm_bias, + None, + None, + None, + None, + d_uvqk_weight, + d_uvqk_bias, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def triton_hstu_preprocess_and_attention( + x: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + norm_eps: float, + num_heads: int, + attn_dim: int, + hidden_dim: int, + uvqk_weight: torch.Tensor, + uvqk_bias: torch.Tensor, + max_seq_len: int, + seq_offsets: torch.Tensor, + attn_alpha: float, + num_targets: Optional[torch.Tensor], + max_attn_len: int = 0, + contextual_seq_len: int = 0, + recompute_uvqk_in_backward: bool = False, + recompute_normed_x_in_backward: bool = False, + sort_by_length: bool = False, + enable_tma: Optional[bool] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + # When the caller does not specify enable_tma, auto-detect whether the + # TMA / TLX fast path is safe on this device. Resolving here (vs inside + # the autograd Function.forward) keeps a concrete bool flowing through + # ctx.save_for_backward / ctx attributes. + if enable_tma is None: + enable_tma = _should_enable_tma() + return _HSTUPreprocessAndAttentionFunction.apply( + x, + norm_weight, + norm_bias, + norm_eps, + num_heads, + attn_dim, + hidden_dim, + uvqk_weight, + uvqk_bias, + max_seq_len, + seq_offsets, + attn_alpha, + num_targets, + max_attn_len, + contextual_seq_len, + recompute_uvqk_in_backward, + recompute_normed_x_in_backward, + sort_by_length, + enable_tma, + ) diff --git a/recommendation_v4/generative_recommenders/ops/triton/triton_jagged.py b/recommendation_v4/generative_recommenders/ops/triton/triton_jagged.py new file mode 100644 index 000000000..3f5609d75 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/triton/triton_jagged.py @@ -0,0 +1,2537 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-unsafe + +#!/usr/bin/env python3 + + +import os +from typing import List, Optional, Tuple + +import torch + +# @manual=//triton:triton +import triton + +# @manual=//triton:triton +import triton.language as tl +from generative_recommenders.common import ( + autotune_max_seq_len, + fine_grained_autotune_max_seq_len, + switch_to_contiguous_if_needed, + triton_autotune, +) +from generative_recommenders.ops.triton._autotune_pinning import pinned_or_full +from generative_recommenders.ops.utils import is_sm100_plus, is_sm90 +from torch._inductor.runtime import triton_helpers + +try: + torch.ops.load_library( + "//generative_recommenders/fb/ultra/ops/hopper/jagged_dense_bmm_add:jagged_dense_bmm_add" + ) +except OSError: + pass + +CUDA_JAGGED_DENSE_BMM_FWD = False +CUDA_JAGGED_DENSE_BMM_BWD = False + +SPLIT_2D_JAGGED_KERNEL = None +GLN_MUL_DROPOUT_KERNEL = None +CONCAT_2D_JAGGED_KERNEL = None + + +def set_cuda_jagged_dense_bmm_fwd(value: bool) -> None: + global CUDA_JAGGED_DENSE_BMM_FWD + CUDA_JAGGED_DENSE_BMM_FWD = value + + +def get_cuda_jagged_dense_bmm_fwd() -> bool: + # currently only supports H100 + return CUDA_JAGGED_DENSE_BMM_FWD and is_sm90() + + +def set_cuda_jagged_dense_bmm_bwd(value: bool) -> None: + global CUDA_JAGGED_DENSE_BMM_BWD + CUDA_JAGGED_DENSE_BMM_BWD = value + + +def get_cuda_jagged_dense_bmm_bwd() -> bool: + # currently only supports H100 + return CUDA_JAGGED_DENSE_BMM_BWD and is_sm90() + + +def set_split_2d_jagged_kernel(value: Optional[str]) -> None: + global SPLIT_2D_JAGGED_KERNEL + SPLIT_2D_JAGGED_KERNEL = value + + +def get_split_2d_jagged_kernel() -> Optional[str]: + # only override during training + if torch.is_grad_enabled(): + return SPLIT_2D_JAGGED_KERNEL + return None + + +def set_concat_2d_jagged_kernel(value: Optional[str]) -> None: + global CONCAT_2D_JAGGED_KERNEL + CONCAT_2D_JAGGED_KERNEL = value + + +def get_concat_2d_jagged_kernel() -> Optional[str]: + # only override during training + if torch.is_grad_enabled(): + return CONCAT_2D_JAGGED_KERNEL + return None + + +def _should_use_multirow() -> bool: + """Check if multirow kernel should be used based on current hardware. + + Can be overridden via the JAGGED_USE_MULTIROW_MI350 environment variable: + JAGGED_USE_MULTIROW_MI350=1 -> force multirow on + JAGGED_USE_MULTIROW_MI350=0 -> force multirow off + unset -> auto-detect based on hardware (SM100+ or MI350) + """ + env = os.environ.get("JAGGED_USE_MULTIROW_MI350") + if env is not None: + return env == "1" + return is_sm100_plus() + + +def set_gln_mul_dropout_kernel(value: Optional[str]) -> None: + global GLN_MUL_DROPOUT_KERNEL + GLN_MUL_DROPOUT_KERNEL = value + + +def get_gln_mul_dropout_kernel() -> Optional[str]: + # only override during training + return GLN_MUL_DROPOUT_KERNEL + + +def _triton_concat_2D_jagged_internal( + values_a: torch.Tensor, + values_b: torch.Tensor, + values_out: torch.Tensor, + max_seq_len: int, + B: int, + offsets_a: Optional[torch.Tensor], + offsets_b: Optional[torch.Tensor], + D: int, + dense_size: int, + stride_dense_batch: int, + n_prefix: int, + is_dense_a: bool, + is_dense_b: bool, + is_replace: bool, + BLOCK_D: int, +) -> None: + use_multirow = _should_use_multirow() + if n_prefix != 0: + if use_multirow: + + def grid(meta): + return (triton.cdiv(max_seq_len, meta["BLOCK_N"]), B) + + concat_2D_jagged_jagged_w_prefix_multirow[grid]( + OffsetsA=offsets_a, + ValuesA=values_a, + OffsetsB=offsets_b, + ValuesB=values_b, + Out=values_out, + D=D, + stride_ad=values_a.stride(-2), + stride_bd=values_b.stride(-2), + stride_od=values_out.stride(0), + n_prefix_from_B=n_prefix, + BLOCK_D=BLOCK_D, + ) + else: + concat_2D_jagged_jagged_w_prefix[(max_seq_len, B)]( + OffsetsA=offsets_a, + ValuesA=values_a, + OffsetsB=offsets_b, + ValuesB=values_b, + Out=values_out, + D=D, + stride_ad=values_a.stride(-2), + stride_bd=values_b.stride(-2), + stride_od=values_out.stride(0), + n_prefix_from_B=n_prefix, + BLOCK_D=BLOCK_D, # pyre-ignore[6] + ) + else: + if use_multirow: + + def grid(meta): + return (triton.cdiv(max_seq_len, meta["BLOCK_N"]), B) + + concat_2D_jagged_multirow[grid]( + OffsetsA=offsets_a, + ValuesA=values_a, + OffsetsB=offsets_b, + ValuesB=values_b, + DenseSize=dense_size, + Out=values_out, + D=D, + stride_ad=values_a.stride(-2), + stride_bd=values_b.stride(-2), + stride_dense_batch=stride_dense_batch, + stride_od=values_out.stride(0), + IS_DENSE_A=is_dense_a, # pyre-ignore[6] + IS_DENSE_B=is_dense_b, # pyre-ignore[6] + BLOCK_D=BLOCK_D, # pyre-ignore[6] + IS_REPLACE=is_replace, # pyre-ignore[6] + ) + else: + concat_2D_jagged[(max_seq_len, B)]( + OffsetsA=offsets_a, + ValuesA=values_a, + OffsetsB=offsets_b, + ValuesB=values_b, + DenseSize=dense_size, + Out=values_out, + D=D, + stride_ad=values_a.stride(-2), + stride_bd=values_b.stride(-2), + stride_dense_batch=stride_dense_batch, + stride_od=values_out.stride(0), + IS_DENSE_A=is_dense_a, # pyre-ignore[6] + IS_DENSE_B=is_dense_b, # pyre-ignore[6] + BLOCK_D=BLOCK_D, # pyre-ignore[6] + IS_REPLACE=is_replace, # pyre-ignore[6] + ) + + +def _get_split_concat_2d_jagged_multirow_configs() -> List[triton.Config]: + configs = [] + for BLOCK_N in [1, 2, 4, 8]: + for num_warps in [1, 2, 4]: + configs.append( + triton.Config( + {"BLOCK_N": BLOCK_N}, + num_warps=num_warps, + ) + ) + return configs + + +def _get_split_concat_2d_jagged_multirow_configs_wrapper() -> List[triton.Config]: + # Use extended config space only when JAGGED_USE_MULTIROW_MI350 is explicitly set, + # otherwise fall back to the existing configs to avoid breaking autotune. + if os.environ.get("JAGGED_USE_MULTIROW_MI350") is not None: + configs = [] + # Extended config space for MI350 tuning + # - BLOCK_N: number of rows processed per block + # - num_warps: number of warps (AMD wavefront = 64 threads) + # - num_stages: software pipeline depth for memory latency hiding + # NOTE: num_stages=0 is invalid for AMD GPUs, start from 1 + # - waves_per_eu: AMD-specific, controls occupancy (waves per execution unit) + for BLOCK_N in [1, 2, 4, 8, 16, 32]: + for num_warps in [1, 2, 4, 8, 16, 32]: + for num_stages in [1, 2, 3, 4]: + for waves_per_eu in [0, 1, 2, 4]: + configs.append( + triton.Config( + {"BLOCK_N": BLOCK_N, "waves_per_eu": waves_per_eu}, + num_warps=num_warps, + num_stages=num_stages, + ) + ) + return configs + return _get_split_concat_2d_jagged_multirow_configs() + + +def _get_bmm_configs() -> List[triton.Config]: + configs = [] + for BLOCK_M in [64, 128]: + for BLOCK_N in [64, 128, 256]: + for BLOCK_K in [32, 64]: + for num_stages in [3, 5]: + for num_warps in [4, 8]: + configs.append( + triton.Config( + { + "BLOCK_M": BLOCK_M, + "BLOCK_N": BLOCK_N, + "BLOCK_K": BLOCK_K, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + return configs + + +@triton_autotune( + configs=_get_bmm_configs(), + key=["AUTOTUNE_MAX_SEQ_LEN", "N", "K", "ELEMENTWISE", "HAS_BIAS"], +) +@triton.jit +def jagged_dense_bmm_broadcast_add_kernel( + seq_offsets, + Jagged, + Dense, + Bias, + Out, + AUTOTUNE_MAX_SEQ_LEN, + N, + K, + stride_jm, + stride_db, + stride_dk, + stride_dn, + stride_bias_b, + stride_om, + HAS_BIAS: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ELEMENTWISE: tl.constexpr, +): + """ + Computing bmm Out = Jagged x Dense + Bias + M is the jagged dimension + Jagged has shape (sum_B(M_i), K), Dense has shape (B, K, N), Bias has shape (B, N), and Out has shape (sum_B(M_i), N) + """ + + off_n = tl.program_id(0) + off_m = tl.program_id(1).to(tl.int64) + off_b = tl.program_id(2) + + seq_start = tl.load(seq_offsets + off_b).to(tl.int64) + seq_end = tl.load(seq_offsets + off_b + 1) + seq_len = seq_end - seq_start + start_m = off_m * BLOCK_M + start_n = off_n * BLOCK_N + if start_m >= seq_len: + return + + Jagged += (seq_start + start_m) * stride_jm + Dense += off_b.to(tl.int64) * stride_db + Out += seq_start * stride_om + + offs_m = tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + jg_ptrs = Jagged + offs_m[:, None] * stride_jm + offs_k[None, :] + dn_ptrs = Dense + offs_k[:, None] * stride_dk + offs_n[None, :] * stride_dn + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, K, BLOCK_K): + jg = tl.load( + jg_ptrs, + # pyre-fixme[16]: `int` has no attribute `__getitem__`. + mask=(offs_m[:, None] < (seq_len - start_m)) & ((k + offs_k)[None, :] < K), + other=0.0, + ) + dn = tl.load( + dn_ptrs, + mask=((k + offs_k)[:, None] < K) & (offs_n[None, :] < N), + other=0.0, + ) + accumulator += tl.dot(jg, dn, allow_tf32=ALLOW_TF32) + jg_ptrs += BLOCK_K + dn_ptrs += BLOCK_K * stride_dk + + if HAS_BIAS: + if ELEMENTWISE: + Bias += (seq_start + start_m) * stride_bias_b + bias_ptrs = Bias + offs_m[:, None] * stride_bias_b + offs_n[None, :] + bias = tl.load( + bias_ptrs, + mask=(offs_m[:, None] < (seq_len - start_m)) & (offs_n[None, :] < N), + other=0.0, + ) + accumulator += bias.to(tl.float32) + else: + bias_ptrs = Bias + off_b.to(tl.int64) * stride_bias_b + offs_n + bias = tl.load(bias_ptrs, mask=offs_n < N) + accumulator += bias[None, :].to(tl.float32) + + out = accumulator.to(Out.dtype.element_ty) + + offs_m = tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + Out += start_m * stride_om + out_ptrs = Out + offs_m[:, None] * stride_om + offs_n[None, :] + tl.store( + out_ptrs, + out, + mask=(offs_m[:, None] < (seq_len - start_m)) & (offs_n[None, :] < N), + ) + + +def _get_bmm_reduce_sum_configs() -> List[triton.Config]: + configs = [] + for BLOCK_M in [64, 128]: + for BLOCK_N in [64, 128]: + for BLOCK_K in [64, 128]: + for num_stages in [3, 4]: + for num_warps in [4, 8]: + configs.append( + triton.Config( + { + "BLOCK_M": BLOCK_M, + "BLOCK_N": BLOCK_N, + "BLOCK_K": BLOCK_K, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + return configs + + +@triton_autotune( + configs=_get_bmm_reduce_sum_configs(), + key=["M", "N", "AUTOTUNE_MAX_SEQ_LEN"], +) +@triton.jit +def _jagged_jagged_bmm_reduce_sum( + seq_offsets, + JaggedA, + JaggedB, + Out, + ReduceOut, + M, + N, + AUTOTUNE_MAX_SEQ_LEN, + stride_ak, + stride_bk, + stride_ob, + stride_om, + stride_on, + stride_orb, + stride_orn, + REDUCE_JAGGEDB: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + """ + Computing bmm Out = Jagged x Jagged + K is the jagged dimension + JaggedA has shape (sum_B(K_i), M), JaggedB has shape (sum_B(K_i), N), and Out has shape (B, M, N) + """ + + off_m = tl.program_id(0).to(tl.int64) + off_n = tl.program_id(1) + off_b = tl.program_id(2) + + seq_start = tl.load(seq_offsets + off_b).to(tl.int64) + seq_end = tl.load(seq_offsets + off_b + 1) + seq_len = seq_end - seq_start + + start_m = off_m * BLOCK_M + start_n = off_n * BLOCK_N + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + Out += off_b.to(tl.int64) * stride_ob + offs_m = tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + Out += start_m * stride_om + out_ptrs = Out + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + if REDUCE_JAGGEDB: + out_reduce_ptrs = ( + ReduceOut + off_b.to(tl.int64) * stride_orb + offs_n * stride_orn + ) + acc_reduce = tl.zeros((BLOCK_N,), dtype=tl.float32) + if seq_len == 0: + out = accumulator.to(Out.dtype.element_ty) + tl.store( + out_ptrs, + out, + mask=(offs_m[:, None] < (M - start_m)) & (offs_n[None, :] < N), + ) + if REDUCE_JAGGEDB: + if off_m == 0: + tl.store( + out_reduce_ptrs, # pyre-ignore [61] + acc_reduce.to(ReduceOut.dtype.element_ty), + mask=(offs_n < N), + ) + return + + JaggedA += seq_start * stride_ak + JaggedB += seq_start * stride_bk + offs_k = tl.arange(0, BLOCK_K) + jg_a_ptrs = JaggedA + offs_k[None, :] * stride_ak + (start_m + offs_m)[:, None] + jg_b_ptrs = JaggedB + offs_k[:, None] * stride_bk + offs_n[None, :] + + for k in range(0, seq_len, BLOCK_K): + jg_a = tl.load( + jg_a_ptrs, + # pyre-fixme[16]: `int` has no attribute `__getitem__`. + mask=(offs_m[:, None] < (M - start_m)) & ((k + offs_k)[None, :] < seq_len), + other=0.0, + ) + jg_b = tl.load( + jg_b_ptrs, + mask=(offs_n[None, :] < N) & ((k + offs_k)[:, None] < seq_len), + other=0.0, + ) + + accumulator += tl.dot(jg_a, jg_b, allow_tf32=ALLOW_TF32) + if REDUCE_JAGGEDB: + if off_m == 0: + acc_reduce += tl.sum(jg_b.to(tl.float32), axis=0) + + jg_a_ptrs += BLOCK_K * stride_ak + jg_b_ptrs += BLOCK_K * stride_bk + + out = accumulator.to(Out.dtype.element_ty) + tl.store( + out_ptrs, + out, + mask=(offs_m[:, None] < (M - start_m)) & (offs_n[None, :] < N), + ) + if REDUCE_JAGGEDB: + if off_m == 0: + tl.store( + out_reduce_ptrs, # pyre-ignore [61] + acc_reduce.to(ReduceOut.dtype.element_ty), + mask=(offs_n < N), + ) + + +class _JaggedDenseBmmFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, + ): + jagged = switch_to_contiguous_if_needed(jagged) + L, D = jagged.shape + B, _, K = dense.shape + bmm_out = torch.empty((L, K), dtype=jagged.dtype, device=jagged.device) + + grid = lambda meta: ( # noqa E731 + triton.cdiv(K, meta["BLOCK_N"]), + triton.cdiv(max_seq_len, meta["BLOCK_M"]), + B, + ) + + jagged_dense_bmm_broadcast_add_kernel[grid]( + seq_offsets=seq_offsets, + Jagged=jagged, + Dense=dense, + Bias=0, + Out=bmm_out, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(max_seq_len), + N=K, + K=D, + stride_jm=jagged.stride(0), + stride_db=dense.stride(0), + stride_dk=dense.stride(1), + stride_dn=dense.stride(2), + stride_bias_b=0, + stride_om=bmm_out.stride(0), + HAS_BIAS=False, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + ELEMENTWISE=False, + ) + + ctx.save_for_backward(seq_offsets, jagged, dense) + ctx.B = B + ctx.max_seq_len = max_seq_len + ctx.K = K + ctx.D = D + return bmm_out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, d_bmm_out: torch.Tensor + ) -> Tuple[None, None, torch.Tensor, torch.Tensor]: + seq_offsets, jagged, dense = ctx.saved_tensors + d_jagged = torch.empty_like(jagged) + d_dense = torch.empty_like(dense) + + grid = lambda meta: ( # noqa E731 + triton.cdiv(ctx.D, meta["BLOCK_N"]), + triton.cdiv(ctx.max_seq_len, meta["BLOCK_M"]), + ctx.B, + ) + jagged_dense_bmm_broadcast_add_kernel[grid]( + seq_offsets=seq_offsets, + Jagged=d_bmm_out, + Dense=dense, + Bias=None, + Out=d_jagged, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(ctx.max_seq_len), + N=ctx.D, + K=ctx.K, + stride_jm=d_bmm_out.stride(0), + stride_db=dense.stride(0), + stride_dk=dense.stride(2), + stride_dn=dense.stride(1), + stride_bias_b=0, + stride_om=d_jagged.stride(0), + HAS_BIAS=False, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + ELEMENTWISE=False, + ) + + grid = lambda meta: ( # noqa E731 + triton.cdiv(ctx.D, meta["BLOCK_M"]), + triton.cdiv(ctx.K, meta["BLOCK_N"]), + ctx.B, + ) + _jagged_jagged_bmm_reduce_sum[grid]( + seq_offsets=seq_offsets, + JaggedA=jagged, + JaggedB=d_bmm_out, + Out=d_dense, + ReduceOut=None, + M=ctx.D, + N=ctx.K, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(ctx.max_seq_len), + stride_ak=jagged.stride(0), + stride_bk=d_bmm_out.stride(0), + stride_ob=d_dense.stride(0), + stride_om=d_dense.stride(1), + stride_on=d_dense.stride(2), + stride_orb=0, + stride_orn=0, + REDUCE_JAGGEDB=False, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + ) + + return None, None, d_jagged, d_dense + + +def _get_jagged_dense_broadcast_add_configs() -> List[triton.Config]: + configs = [] + for BLOCK_N in [16, 32, 64]: + for num_stages in [1, 2]: + for num_warps in [2, 4, 8]: + configs.append( + triton.Config( + { + "BLOCK_N": BLOCK_N, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + return configs + + +@triton_autotune( + configs=_get_jagged_dense_broadcast_add_configs(), + key=["AUTOTUNE_MAX_SEQ_LEN"], +) +@triton.jit +def jagged_dense_broadcast_add_kernel( + seq_offsets, + Jagged, + Dense, + Out, + AUTOTUNE_MAX_SEQ_LEN, + D, + stride_jn, + stride_db, + stride_on, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + """ + Computing Out = Jagged + Dense + JaggedA has shape (sum_B(N_i), D), Dense has shape (B, D), and Out has shape (sum_B(N_i), D) + """ + + off_b = tl.program_id(0) + off_n = tl.program_id(1) + seq_start = tl.load(seq_offsets + off_b) + seq_end = tl.load(seq_offsets + off_b + 1) + seq_len = seq_end - seq_start + start_n = off_n * BLOCK_N + if start_n >= seq_len: + return + Jagged += seq_start * stride_jn + Dense += off_b * stride_db + Out += seq_start * stride_on + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + jagged_ptrs = Jagged + offs_n[:, None] * stride_jn + offs_d[None, :] + dense_ptrs = Dense + offs_d + out_ptrs = Out + offs_n[:, None] * stride_jn + offs_d[None, :] + for d in range(0, D, BLOCK_D): + jg = tl.load( + jagged_ptrs, + # pyre-fixme[16]: `int` has no attribute `__getitem__`. + mask=(offs_n[:, None] < seq_len) & ((d + offs_d)[None, :] < D), + ) + dn = tl.load(dense_ptrs, mask=d + offs_d < D) + out = jg + dn[None, :] + tl.store( + out_ptrs, + out, + mask=(offs_n[:, None] < seq_len) & ((d + offs_d)[None, :] < D), + ) + dense_ptrs += BLOCK_D + jagged_ptrs += BLOCK_D + out_ptrs += BLOCK_D + + +@triton.jit +def jagged_reduce_sum( + seq_offsets, + Jagged, + Out, + D, + stride_jn, + stride_ob, + BLOCK_D: tl.constexpr, +): + """ + Computing Out = Jagged + Dense + JaggedA has shape (sum_B(N_i), D), Dense has shape (B, D), and Out has shape (sum_B(N_i), D) + """ + off_b = tl.program_id(0) + off_d = tl.program_id(1) * BLOCK_D + seq_start = tl.load(seq_offsets + off_b) + seq_end = tl.load(seq_offsets + off_b + 1) + seq_len = seq_end - seq_start + Jagged += seq_start * stride_jn + Out += off_b * stride_ob + offs_d = off_d + tl.arange(0, BLOCK_D) + jagged_ptrs = Jagged + offs_d + out_ptrs = Out + offs_d + accumulator = tl.zeros((BLOCK_D,), dtype=tl.float32) + for _ in range(0, seq_len): + jg = tl.load( + jagged_ptrs, + mask=offs_d < D, + ) + accumulator += jg + jagged_ptrs += stride_jn + out = accumulator.to(Out.dtype.element_ty) + tl.store( + out_ptrs, + out, + mask=offs_d < D, + ) + + +class _JaggedDenseBroadcastAddFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, + ): + jagged = switch_to_contiguous_if_needed(jagged) + dense = switch_to_contiguous_if_needed(dense) + L, D = jagged.shape + B, _ = dense.shape + out = torch.empty_like(jagged) + + grid = lambda meta: ( # noqa E731 + B, + triton.cdiv(max_seq_len, meta["BLOCK_N"]), + ) + BLOCK_D = triton.next_power_of_2(D) if D < 64 else 64 + jagged_dense_broadcast_add_kernel[grid]( + seq_offsets=seq_offsets, + Jagged=jagged, + Dense=dense, + Out=out, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(max_seq_len), + D=D, + stride_jn=jagged.stride(0), + stride_db=dense.stride(0), + stride_on=out.stride(0), + BLOCK_D=BLOCK_D, + ) + + ctx.save_for_backward(seq_offsets) + ctx.max_seq_len = max_seq_len + ctx.B = B + ctx.D = D + return out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, d_out: torch.Tensor + ) -> Tuple[None, None, torch.Tensor, torch.Tensor]: + seq_offsets = ctx.saved_tensors[0] + d_dense = torch.empty((ctx.B, ctx.D), device=d_out.device, dtype=d_out.dtype) + BLOCK_D = triton.next_power_of_2(ctx.D) if ctx.D < 64 else 64 + jagged_reduce_sum[(ctx.B, triton.cdiv(ctx.D, BLOCK_D))]( + seq_offsets=seq_offsets, + Jagged=d_out, + Out=d_dense, + D=ctx.D, + stride_jn=d_out.stride(0), + stride_ob=d_dense.stride(0), + BLOCK_D=BLOCK_D, + ) + return None, None, d_out, d_dense + + +def triton_jagged_dense_bmm_add_fwd( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, + bias: torch.Tensor, + elementwise: bool = False, +) -> Tuple[torch.Tensor, int, int, int]: + jagged = switch_to_contiguous_if_needed(jagged) + bias = switch_to_contiguous_if_needed(bias) + L, K = jagged.shape + B, _, N = dense.shape + out = torch.empty((L, N), dtype=jagged.dtype, device=jagged.device) + + grid = lambda meta: ( # noqa E731 + triton.cdiv(N, meta["BLOCK_N"]), + triton.cdiv(max_seq_len, meta["BLOCK_M"]), + B, + ) + + jagged_dense_bmm_broadcast_add_kernel[grid]( + seq_offsets=seq_offsets, + Jagged=jagged, + Dense=dense, + Bias=bias, + Out=out, + AUTOTUNE_MAX_SEQ_LEN=fine_grained_autotune_max_seq_len(max_seq_len), + N=N, + K=K, + stride_jm=jagged.stride(0), + stride_db=dense.stride(0), + stride_dk=dense.stride(1), + stride_dn=dense.stride(2), + stride_bias_b=bias.stride(0), + stride_om=out.stride(0), + HAS_BIAS=True, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + ELEMENTWISE=elementwise, + ) + + return out, B, K, N + + +def triton_jagged_dense_bmm_add_bwd_jagged( + max_seq_len: int, + seq_offsets: torch.Tensor, + d_jagged: torch.Tensor, + dense: torch.Tensor, + d_out: torch.Tensor, + K: int, + B: int, + N: int, +) -> torch.Tensor: + grid = lambda meta: ( # noqa E731 + triton.cdiv(K, meta["BLOCK_N"]), + triton.cdiv(max_seq_len, meta["BLOCK_M"]), + B, + ) + jagged_dense_bmm_broadcast_add_kernel[grid]( + seq_offsets=seq_offsets, + Jagged=d_out, + Dense=dense, + Bias=None, + Out=d_jagged, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(max_seq_len), + N=K, + K=N, + stride_jm=d_out.stride(0), + stride_db=dense.stride(0), + stride_dk=dense.stride(2), + stride_dn=dense.stride(1), + stride_bias_b=0, + stride_om=d_jagged.stride(0), + HAS_BIAS=False, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + ELEMENTWISE=False, + ) + + return d_jagged + + +def triton_jagged_dense_bmm_add_bwd_dense_bias( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + d_dense: torch.Tensor, + B: int, + K: int, + N: int, + d_out: torch.Tensor, + elementwise: bool, +) -> Tuple[torch.Tensor, torch.Tensor]: + d_bias = torch.empty((B, N), device=d_out.device, dtype=d_out.dtype) + + grid = lambda meta: ( # noqa E731 + triton.cdiv(K, meta["BLOCK_M"]), + triton.cdiv(N, meta["BLOCK_N"]), + B, + ) + + if elementwise: + d_bias = d_out + reduce_out = None + stride_orb = 0 + stride_orn = 0 + reduce_jaggedb = False + else: + reduce_out = d_bias + stride_orb = d_bias.stride(0) + stride_orn = d_bias.stride(1) + reduce_jaggedb = True + + _jagged_jagged_bmm_reduce_sum[grid]( + seq_offsets=seq_offsets, + JaggedA=jagged, + JaggedB=d_out, + Out=d_dense, + ReduceOut=reduce_out, + M=K, + N=N, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(max_seq_len), + stride_ak=jagged.stride(0), + stride_bk=d_out.stride(0), + stride_ob=d_dense.stride(0), + stride_om=d_dense.stride(1), + stride_on=d_dense.stride(2), + stride_orb=stride_orb, + stride_orn=stride_orn, + REDUCE_JAGGEDB=reduce_jaggedb, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + ) + + return d_dense, d_bias + + +class _JaggedDenseBmmAddFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, + bias: torch.Tensor, + elementwise: bool = False, + ): + if get_cuda_jagged_dense_bmm_fwd(): + jagged = switch_to_contiguous_if_needed(jagged) + bias = switch_to_contiguous_if_needed(bias) + # Ensure bias has same dtype as jagged (required by CUDA kernel) + bias = bias.to(jagged.dtype) + # Ensure seq_offsets is int64 (required by CUDA kernel) + seq_offsets = seq_offsets.to(torch.int64) + _, K = jagged.shape + B, _, N = dense.shape + out = torch.ops.jagged_dense_bmm_broadcast_add.jagged_dense_bmm_broadcast_add_fwd( + max_seq_len, seq_offsets, jagged, dense, bias, elementwise + ) + else: + out, B, K, N = triton_jagged_dense_bmm_add_fwd( + max_seq_len, seq_offsets, jagged, dense, bias, elementwise + ) + + ctx.save_for_backward(seq_offsets, jagged, dense) + ctx.B = B + ctx.max_seq_len = max_seq_len + ctx.K = K + ctx.N = N + ctx.elementwise = elementwise + return out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, d_out: torch.Tensor + ) -> Tuple[None, None, torch.Tensor, torch.Tensor, torch.Tensor, None]: + seq_offsets, jagged, dense = ctx.saved_tensors + if get_cuda_jagged_dense_bmm_bwd(): + d_jagged, d_dense, d_bias = ( + torch.ops.jagged_dense_bmm_broadcast_add.jagged_dense_bmm_broadcast_add_bwd( + ctx.max_seq_len, + d_out, + seq_offsets.to(torch.int64), + jagged, + dense, + ctx.elementwise, + ) + ) + else: + d_jagged = triton_jagged_dense_bmm_add_bwd_jagged( + ctx.max_seq_len, + seq_offsets, + torch.empty_like(jagged), + dense, + d_out, + ctx.K, + ctx.B, + ctx.N, + ) + d_dense, d_bias = triton_jagged_dense_bmm_add_bwd_dense_bias( + ctx.max_seq_len, + seq_offsets, + jagged, + torch.empty_like(dense), + ctx.B, + ctx.K, + ctx.N, + d_out, + ctx.elementwise, + ) + + return None, None, d_jagged, d_dense, d_bias, None + + +@triton.jit +def concat_2D_jagged_w_prefix( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + DenseSize, + Out, + D, + stride_ad, + stride_bd, + stride_dense_batch, + stride_od, + n_prefix_from_B, # nonzero is not supported when IS_REPLACE=True + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + IS_REPLACE: tl.constexpr, +): + off_z = tl.program_id(1) + off_n = tl.program_id(0) + if IS_DENSE_A: + seq_start_a = off_z * DenseSize + seq_len_a = DenseSize + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + elif IS_DENSE_B: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + seq_start_b = off_z * DenseSize + seq_len_b = DenseSize + else: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + + if IS_REPLACE: + seq_len = seq_len_a + else: + seq_len = seq_len_a + seq_len_b + if off_n >= seq_len: + return + + offs_d = tl.arange(0, BLOCK_D) + if IS_REPLACE: + out_seq_start = seq_start_a + off_n + out_seq_b_start = seq_len_a - seq_len_b + else: + out_seq_start = seq_start_a + seq_start_b + off_n + out_seq_b_start = seq_len_a + n_prefix_from_B + + out_ptrs = Out + out_seq_start.to(tl.int64) * stride_od + offs_d + if off_n < out_seq_b_start and off_n >= n_prefix_from_B: + off_a = off_n - n_prefix_from_B + if IS_DENSE_A: + in_ptrs = ( + ValuesA + + off_a.to(tl.int64) * stride_ad + + off_z.to(tl.int64) * stride_dense_batch + + offs_d + ) + else: + in_ptrs = ValuesA + (off_a + seq_start_a).to(tl.int64) * stride_ad + offs_d + else: + off_b = off_n - out_seq_b_start + n_prefix_from_B + if off_n < n_prefix_from_B: + off_b += out_seq_b_start - n_prefix_from_B + if IS_DENSE_B: + in_ptrs = ( + ValuesB + + off_b.to(tl.int64) * stride_bd + + off_z.to(tl.int64) * stride_dense_batch + + offs_d + ) + else: + in_ptrs = ValuesB + (off_b + seq_start_b).to(tl.int64) * stride_bd + offs_d + v = tl.load(in_ptrs, mask=offs_d < D) + tl.store(out_ptrs, v, mask=offs_d < D) + + +@triton.jit +def concat_2D_jagged( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + DenseSize, + Out, + D, + stride_ad, + stride_bd, + stride_dense_batch, + stride_od, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + IS_REPLACE: tl.constexpr, +): + concat_2D_jagged_w_prefix( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + DenseSize, + Out, + D, + stride_ad, + stride_bd, + stride_dense_batch, + stride_od, + 0, + IS_DENSE_A, + IS_DENSE_B, + BLOCK_D, + IS_REPLACE, + ) + + +@triton.jit +def concat_2D_jagged_jagged_w_prefix( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + Out, + D, + stride_ad, + stride_bd, + stride_od, + n_prefix_from_B, + BLOCK_D: tl.constexpr, +): + concat_2D_jagged_w_prefix( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + 0, + Out, + D, + stride_ad, + stride_bd, + 0, + stride_od, + n_prefix_from_B, + IS_DENSE_A=False, + IS_DENSE_B=False, + BLOCK_D=BLOCK_D, + IS_REPLACE=False, + ) + + +@triton.jit +def split_2D_jagged_w_prefix( + JaggedIn, + DenseSize, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + IS_REPLACE: tl.constexpr, +): + off_z = tl.program_id(1) + off_n = tl.program_id(0) + if IS_DENSE_A: + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_start_a = off_z * DenseSize + seq_len_a = DenseSize + seq_len_b = seq_end_b - seq_start_b + elif IS_DENSE_B: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + seq_start_b = off_z * DenseSize + seq_len_b = DenseSize + else: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + if IS_REPLACE: + seq_len = seq_len_a + else: + seq_len = seq_len_a + seq_len_b + if off_n >= seq_len: + return + + if IS_REPLACE: + seq_start = seq_start_a + out_seq_b_start = seq_len_a - seq_len_b + else: + seq_start = seq_start_a + seq_start_b + out_seq_b_start = seq_len_a + n_prefix_to_B + + offs_d = tl.arange(0, BLOCK_D) + in_ptrs = JaggedIn + (seq_start + off_n).to(tl.int64) * stride_id + offs_d + if off_n < out_seq_b_start and off_n >= n_prefix_to_B: + off_a = off_n - n_prefix_to_B + out_ptrs = OutA + (off_a + seq_start_a).to(tl.int64) * stride_ad + offs_d + else: + off_b = off_n - out_seq_b_start + n_prefix_to_B + if off_n < n_prefix_to_B: + off_b += out_seq_b_start - n_prefix_to_B + out_ptrs = OutB + (off_b + seq_start_b).to(tl.int64) * stride_bd + offs_d + v = tl.load(in_ptrs, mask=offs_d < D) + tl.store(out_ptrs, v, mask=offs_d < D) + + +@triton.jit +def split_2D_jagged( + JaggedIn, + DenseSize, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + IS_REPLACE: tl.constexpr, +): + split_2D_jagged_w_prefix( + JaggedIn, + DenseSize, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + 0, + IS_DENSE_A, + IS_DENSE_B, + BLOCK_D, + IS_REPLACE, + ) + + +@triton.jit +def split_2D_jagged_jagged_w_prefix( + JaggedIn, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + BLOCK_D: tl.constexpr, +): + split_2D_jagged_w_prefix( + JaggedIn, + 0, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + IS_DENSE_A=False, + IS_DENSE_B=False, + BLOCK_D=BLOCK_D, + IS_REPLACE=False, + ) + + +def _triton_split_2D_jagged_internal( + jagged_in: torch.Tensor, + max_seq_len: int, + B: int, + offsets_a: Optional[torch.Tensor], + offsets_b: Optional[torch.Tensor], + out_a: torch.Tensor, + out_b: torch.Tensor, + D: int, + dense_size: int, + n_prefix: int, + is_dense_a: bool, + is_dense_b: bool, + is_replace: bool, + BLOCK_D: int, +) -> None: + use_multirow = _should_use_multirow() + if n_prefix != 0: + if use_multirow: + + def grid(meta): + return (triton.cdiv(max_seq_len, meta["BLOCK_N"]), B) + + split_2D_jagged_jagged_w_prefix_multirow[grid]( + JaggedIn=jagged_in, + OffsetsA=offsets_a, + OffsetsB=offsets_b, + OutA=out_a, + OutB=out_b, + D=D, + stride_id=jagged_in.stride(0), + stride_ad=out_a.stride(0), + stride_bd=out_b.stride(0), + n_prefix_to_B=n_prefix, + BLOCK_D=BLOCK_D, + ) + else: + split_2D_jagged_jagged_w_prefix[(max_seq_len, B)]( + JaggedIn=jagged_in, + OffsetsA=offsets_a, + OffsetsB=offsets_b, + OutA=out_a, + OutB=out_b, + D=D, + stride_id=jagged_in.stride(0), + stride_ad=out_a.stride(0), + stride_bd=out_b.stride(0), + n_prefix_to_B=n_prefix, + BLOCK_D=BLOCK_D, # pyre-ignore[6] + ) + else: + if use_multirow: + + def grid(meta): + return (triton.cdiv(max_seq_len, meta["BLOCK_N"]), B) + + split_2D_jagged_multirow[grid]( + JaggedIn=jagged_in, + DenseSize=dense_size, + OffsetsA=offsets_a, + OffsetsB=offsets_b, + OutA=out_a, + OutB=out_b, + D=D, + stride_id=jagged_in.stride(0), + stride_ad=out_a.stride(0), + stride_bd=out_b.stride(0), + IS_DENSE_A=is_dense_a, # pyre-ignore[6] + IS_DENSE_B=is_dense_b, # pyre-ignore[6] + BLOCK_D=BLOCK_D, # pyre-ignore[6] + IS_REPLACE=is_replace, # pyre-ignore[6] + ) + else: + split_2D_jagged[(max_seq_len, B)]( + JaggedIn=jagged_in, + DenseSize=dense_size, + OffsetsA=offsets_a, + OffsetsB=offsets_b, + OutA=out_a, + OutB=out_b, + D=D, + stride_id=jagged_in.stride(0), + stride_ad=out_a.stride(0), + stride_bd=out_b.stride(0), + IS_DENSE_A=is_dense_a, # pyre-ignore[6] + IS_DENSE_B=is_dense_b, # pyre-ignore[6] + BLOCK_D=BLOCK_D, # pyre-ignore[6] + IS_REPLACE=is_replace, # pyre-ignore[6] + ) + + +class _Concat2DJaggedFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + max_seq_len: int, + values_a: torch.Tensor, + values_b: torch.Tensor, + offsets_a: Optional[torch.Tensor] = None, + offsets_b: Optional[torch.Tensor] = None, + is_replace: bool = False, + n_prefix_from_right: int = 0, + ): + values_a = switch_to_contiguous_if_needed(values_a) + values_b = switch_to_contiguous_if_needed(values_b) + is_dense_a = offsets_a is None + is_dense_b = offsets_b is None + dense_size: int = 0 + if is_dense_a: + assert offsets_b is not None + B, dense_size, D = values_a.shape + seq_len_a = dense_size * B + seq_len_b, _ = values_b.shape + device = values_b.device + dtype = values_b.dtype + stride_dense_batch = values_a.stride(0) + elif is_dense_b: + assert offsets_a is not None + B, dense_size, D = values_b.shape + seq_len_a, _ = values_a.shape + seq_len_b = dense_size * B + device = values_a.device + dtype = values_a.dtype + stride_dense_batch = values_b.stride(0) + else: + assert offsets_a is not None and offsets_b is not None + B = offsets_a.shape[0] - 1 + seq_len_a, D = values_a.shape + seq_len_b, _ = values_b.shape + device = values_a.device + dtype = values_a.dtype + stride_dense_batch = 0 + + BLOCK_D = triton.next_power_of_2(D) + if is_replace: + values_out = torch.empty_like(values_a) + else: + values_out = torch.empty( + (seq_len_a + seq_len_b, D), device=device, dtype=dtype + ) + _triton_concat_2D_jagged_internal( + values_a=values_a, + values_b=values_b, + values_out=values_out, + max_seq_len=max_seq_len, + B=B, + offsets_a=offsets_a, + offsets_b=offsets_b, + D=D, + dense_size=dense_size, + stride_dense_batch=stride_dense_batch, + n_prefix=n_prefix_from_right, + is_dense_a=is_dense_a, + is_dense_b=is_dense_b, + is_replace=is_replace, + BLOCK_D=BLOCK_D, + ) + ctx.save_for_backward(offsets_a, offsets_b) + ctx.max_seq_len = max_seq_len + ctx.seq_len_a = seq_len_a + ctx.seq_len_b = seq_len_b + ctx.is_dense_a = is_dense_a + ctx.is_dense_b = is_dense_b + ctx.dense_size = dense_size + ctx.is_replace = is_replace + ctx.n_prefix_from_right = n_prefix_from_right + return values_out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, d_out: torch.Tensor + ) -> Tuple[None, torch.Tensor, torch.Tensor, None, None, None, None]: + offsets_a, offsets_b = ctx.saved_tensors + is_dense_a, is_dense_b, is_replace = ( + ctx.is_dense_a, + ctx.is_dense_b, + ctx.is_replace, + ) + dense_size = ctx.dense_size + if is_dense_a: + B = offsets_b.shape[0] - 1 + else: + B = offsets_a.shape[0] - 1 + _, D = d_out.shape + BLOCK_D = triton.next_power_of_2(D) + values_a = torch.zeros( + (ctx.seq_len_a, D), device=d_out.device, dtype=d_out.dtype + ) + values_b = torch.empty( + (ctx.seq_len_b, D), device=d_out.device, dtype=d_out.dtype + ) + _triton_split_2D_jagged_internal( + jagged_in=d_out, + max_seq_len=ctx.max_seq_len, + B=B, + offsets_a=offsets_a, + offsets_b=offsets_b, + out_a=values_a, + out_b=values_b, + D=D, + dense_size=dense_size, + n_prefix=ctx.n_prefix_from_right, + is_dense_a=is_dense_a, + is_dense_b=is_dense_b, + is_replace=is_replace, + BLOCK_D=BLOCK_D, + ) + + if is_dense_a: + values_a = values_a.reshape((B, dense_size, D)) + elif is_dense_b: + values_b = values_b.reshape((B, dense_size, D)) + return None, values_a, values_b, None, None, None, None + + +class _HelionConcat2DJaggedFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + max_seq_len: int, + values_a: torch.Tensor, + values_b: torch.Tensor, + offsets_a: Optional[torch.Tensor] = None, + offsets_b: Optional[torch.Tensor] = None, + ): + values_a = switch_to_contiguous_if_needed(values_a) + values_b = switch_to_contiguous_if_needed(values_b) + + assert offsets_a is not None and offsets_b is not None + B = offsets_a.shape[0] - 1 + seq_len_a, D = values_a.shape + seq_len_b, _ = values_b.shape + device = values_a.device + dtype = values_a.dtype + + BLOCK_D = triton.next_power_of_2(D) + values_out = torch.empty((seq_len_a + seq_len_b, D), device=device, dtype=dtype) + _triton_concat_2D_jagged_internal( + values_a=values_a, + values_b=values_b, + values_out=values_out, + max_seq_len=max_seq_len, + B=B, + offsets_a=offsets_a, + offsets_b=offsets_b, + D=D, + dense_size=0, + stride_dense_batch=0, + n_prefix=0, + is_dense_a=False, + is_dense_b=False, + is_replace=False, + BLOCK_D=BLOCK_D, + ) + ctx.save_for_backward(offsets_a, offsets_b) + ctx.max_seq_len = max_seq_len + ctx.seq_len_a = seq_len_a + ctx.seq_len_b = seq_len_b + return values_out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, d_out: torch.Tensor + ) -> Tuple[None, torch.Tensor, torch.Tensor, None, None, None, None]: + offsets_a, offsets_b = ctx.saved_tensors + d_out = switch_to_contiguous_if_needed(d_out) + values_a, values_b = _helion_split_2D_jagged_impl( + values=d_out, + max_seq_len=ctx.max_seq_len, + offsets_a=offsets_a, + offsets_b=offsets_b, + dense_size=0, + total_len_a=ctx.seq_len_a, + total_len_b=ctx.seq_len_b, + ) + + return None, values_a, values_b, None, None, None, None + + +class _Split2DJaggedFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + values: torch.Tensor, + max_seq_len: int, + offsets_a: Optional[torch.Tensor] = None, + offsets_b: Optional[torch.Tensor] = None, + dense_size: int = 0, + n_prefix_to_right: int = 0, + seq_len_a: Optional[int] = None, + seq_len_b: Optional[int] = None, + total_seq_len: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + values = switch_to_contiguous_if_needed(values) + is_dense_a: bool = offsets_a is None + is_dense_b: bool = offsets_b is None + if is_dense_a: + L, _ = values.shape + assert offsets_b is not None + B = offsets_b.shape[0] - 1 + seq_len_a = dense_size * B + seq_len_b = L - seq_len_a + offsets_a = offsets_b.new_empty(0) + elif is_dense_b: + L, _ = values.shape + assert offsets_a is not None + B = offsets_a.shape[0] - 1 + seq_len_b = dense_size * B + seq_len_a = L - seq_len_b + offsets_b = offsets_a.new_empty(0) + else: + assert offsets_a is not None and offsets_b is not None + B = offsets_a.shape[0] - 1 + + # Select the last offset item using torch.index_select instead of + # "int(offsets_a[-1].item())" so that it won't cause "Cannot cast + # FakeTensor to python number" error for AOTI. + if torch.compiler.is_compiling(): + offsets_b_last_idx = torch.tensor(offsets_b.size(0) - 1).to( + offsets_b.device, non_blocking=True + ) + if seq_len_b is None: + seq_len_b = offsets_b.index_select(dim=0, index=offsets_b_last_idx) + if seq_len_a is None and total_seq_len is None: + offsets_a_last_idx = torch.tensor(offsets_a.size(0) - 1).to( + offsets_a.device, non_blocking=True + ) + seq_len_a = offsets_a.index_select(dim=0, index=offsets_a_last_idx) + else: + if seq_len_b is None: + seq_len_b = int(offsets_b[-1].item()) + if seq_len_a is None and total_seq_len is None: + seq_len_a = int(offsets_a[-1].item()) + _, D = values.shape + BLOCK_D = triton.next_power_of_2(D) + # pyre-ignore[6] Incompatible parameter type + values_b = torch.empty((seq_len_b, D), device=values.device, dtype=values.dtype) + if seq_len_a is None: + # Derive seq_len_a from total_seq_len and values_b.size(0). + # values_b.size(0) is a SymInt (from the torch.empty above), + # so this is SymInt arithmetic — no new unbacked SymInt. + assert total_seq_len is not None + seq_len_a = total_seq_len - values_b.size(0) + # pyre-ignore[6] Incompatible parameter type + values_a = torch.empty((seq_len_a, D), device=values.device, dtype=values.dtype) + _triton_split_2D_jagged_internal( + jagged_in=values, + max_seq_len=max_seq_len, + B=B, + offsets_a=offsets_a, + offsets_b=offsets_b, + out_a=values_a, + out_b=values_b, + D=D, + dense_size=dense_size, + n_prefix=n_prefix_to_right, + is_dense_a=is_dense_a, + is_dense_b=is_dense_b, + is_replace=False, + BLOCK_D=BLOCK_D, + ) + if is_dense_a: + values_a = values_a.reshape(B, dense_size, D) + if is_dense_b: + values_b = values_b.reshape(B, dense_size, D) + ctx.save_for_backward(offsets_a, offsets_b) + ctx.max_seq_len = max_seq_len + ctx.seq_len_a = seq_len_a + ctx.seq_len_b = seq_len_b + ctx.is_dense_a = is_dense_a + ctx.is_dense_b = is_dense_b + ctx.dense_size = dense_size + ctx.B = B + ctx.D = D + ctx.n_prefix_to_right = n_prefix_to_right + return values_a, values_b + + @staticmethod + def backward( + ctx, *d_values + ) -> Tuple[torch.Tensor, None, None, None, None, None, None, None, None]: + offsets_a, offsets_b = ctx.saved_tensors + is_dense_a, is_dense_b = ctx.is_dense_a, ctx.is_dense_b + values_a, values_b = d_values + if is_dense_a: + stride_dense_batch = values_a.stride(0) + elif is_dense_b: + stride_dense_batch = values_b.stride(0) + else: + stride_dense_batch = 0 + + BLOCK_D = triton.next_power_of_2(ctx.D) + dvalues = torch.empty( + (ctx.seq_len_a + ctx.seq_len_b, ctx.D), + device=values_a.device, + dtype=values_b.dtype, + ) + _triton_concat_2D_jagged_internal( + values_a=values_a, + values_b=values_b, + values_out=dvalues, + max_seq_len=ctx.max_seq_len, + B=ctx.B, + offsets_a=offsets_a, + offsets_b=offsets_b, + D=ctx.D, + dense_size=ctx.dense_size, + stride_dense_batch=stride_dense_batch, + n_prefix=ctx.n_prefix_to_right, + is_dense_a=is_dense_a, + is_dense_b=is_dense_b, + is_replace=False, + BLOCK_D=BLOCK_D, + ) + + return dvalues, None, None, None, None, None, None, None, None + + +@torch.jit.unused +@torch.fx.wrap +def triton_jagged_dense_bmm_add( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, + bias: torch.Tensor, + elementwise: bool = False, +) -> torch.Tensor: + """ + Computing bmm Out = Jagged x Dense + Bias + M is the jagged dimension + Jagged has shape (sum_B(M_i), K), Dense has shape (B, K, N), Bias has shape (B, N) or (sum_B(M_i), N) depending on Elementwise, and Out has shape (sum_B(M_i), N) + """ + return _JaggedDenseBmmAddFunction.apply( + max_seq_len, seq_offsets, jagged, dense, bias, elementwise + ) + + +@torch.fx.wrap +def triton_concat_2D_jagged( + max_seq_len: int, + values_a: torch.Tensor, + values_b: torch.Tensor, + offsets_a: Optional[torch.Tensor] = None, + offsets_b: Optional[torch.Tensor] = None, + is_replace: bool = False, + n_prefix_from_right: int = 0, +) -> torch.Tensor: + return _Concat2DJaggedFunction.apply( + max_seq_len, + values_a, + values_b, + offsets_a, + offsets_b, + is_replace, + n_prefix_from_right, + ) + + +@torch.fx.wrap +def triton_concat_2D_jagged_jagged( + max_seq_len_left: int, + offsets_left: torch.Tensor, + values_left: torch.Tensor, + max_seq_len_right: int, + offsets_right: torch.Tensor, + values_right: torch.Tensor, + is_replace: bool, + n_prefix_from_right: int, +) -> torch.Tensor: + return triton_concat_2D_jagged( + max_seq_len=max_seq_len_left + max_seq_len_right, + values_a=values_left, + values_b=values_right, + offsets_a=offsets_left, + offsets_b=offsets_right, + is_replace=is_replace, + n_prefix_from_right=n_prefix_from_right, + ) + + +@torch.fx.wrap +def helion_concat_2D_jagged( + max_seq_len: int, + values_a: torch.Tensor, + values_b: torch.Tensor, + offsets_a: Optional[torch.Tensor] = None, + offsets_b: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return _HelionConcat2DJaggedFunction.apply( + max_seq_len, + values_a, + values_b, + offsets_a, + offsets_b, + ) + + +@torch.fx.wrap +def triton_concat_2D_dense_jagged( + jagged_max_seq_len: int, + jagged_offsets: torch.Tensor, + jagged_values: torch.Tensor, + dense_values: torch.Tensor, +) -> torch.Tensor: + B, dense_size, D = dense_values.size() + max_seq_len = jagged_max_seq_len + dense_size + return triton_concat_2D_jagged( + max_seq_len=max_seq_len, + values_a=dense_values, + values_b=jagged_values, + offsets_a=None, + offsets_b=jagged_offsets, + ) + + +def triton_jagged_dense_bmm( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, +) -> torch.Tensor: + return _JaggedDenseBmmFunction.apply(max_seq_len, seq_offsets, jagged, dense) + + +@torch.jit.unused +def triton_split_2D_jagged( + values: torch.Tensor, + max_seq_len: int, + offsets_a: Optional[torch.Tensor] = None, + offsets_b: Optional[torch.Tensor] = None, + dense_size: int = 0, + n_prefix_to_right: int = 0, + seq_len_a: Optional[int] = None, + seq_len_b: Optional[int] = None, + total_seq_len: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _Split2DJaggedFunction.apply( + values, + max_seq_len, + offsets_a, + offsets_b, + dense_size, + n_prefix_to_right, + seq_len_a, + seq_len_b, + total_seq_len, + ) + + +@torch.jit.unused +def helion_split_2D_jagged( + values: torch.Tensor, + max_seq_len: int, + offsets_a: Optional[torch.Tensor] = None, + offsets_b: Optional[torch.Tensor] = None, + dense_size: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _HelionSplit2DJaggedFunction.apply( + values, + max_seq_len, + offsets_a, + offsets_b, + dense_size, + ) + + +@triton.jit +def concat_2D_jagged_w_prefix_multirow( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + DenseSize, + Out, + D, + stride_ad, + stride_bd, + stride_dense_batch, + stride_od, + n_prefix_from_B, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_REPLACE: tl.constexpr, +): + off_z = tl.program_id(1) + off_block_n = tl.program_id(0) + + if IS_DENSE_A: + seq_start_a = off_z * DenseSize + seq_len_a = DenseSize + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + elif IS_DENSE_B: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + seq_start_b = off_z * DenseSize + seq_len_b = DenseSize + else: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + + if IS_REPLACE: + seq_len = seq_len_a + out_seq_start = seq_start_a + out_seq_b_start = seq_len_a - seq_len_b + else: + seq_len = seq_len_a + seq_len_b + out_seq_start = seq_start_a + seq_start_b + out_seq_b_start = seq_len_a + n_prefix_from_B + + start_n = off_block_n * BLOCK_N + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + if start_n >= seq_len: + return + valid_mask = offs_n < seq_len + + out_ptrs = ( + Out + + (out_seq_start + offs_n[:, None]).to(tl.int64) * stride_od + + offs_d[None, :] + ) + + to_a_mask = (offs_n < out_seq_b_start) & (offs_n >= n_prefix_from_B) & valid_mask + to_b_mask = ~to_a_mask & valid_mask + + off_a = offs_n - n_prefix_from_B + if IS_DENSE_A: + in_a_ptrs = ( + ValuesA + + off_a[:, None].to(tl.int64) * stride_ad + + off_z.to(tl.int64) * stride_dense_batch + + offs_d[None, :] + ) + else: + in_a_ptrs = ( + ValuesA + + (off_a[:, None] + seq_start_a).to(tl.int64) * stride_ad + + offs_d[None, :] + ) + + v_a = tl.load(in_a_ptrs, mask=to_a_mask[:, None] & (offs_d[None, :] < D), other=0.0) + tl.store(out_ptrs, v_a, mask=to_a_mask[:, None] & (offs_d[None, :] < D)) + + prefix_mask = offs_n < n_prefix_from_B + + off_b = tl.where(prefix_mask, offs_n, offs_n - out_seq_b_start + n_prefix_from_B) + if IS_DENSE_B: + in_b_ptrs = ( + ValuesB + + off_b[:, None].to(tl.int64) * stride_bd + + off_z.to(tl.int64) * stride_dense_batch + + offs_d[None, :] + ) + else: + in_b_ptrs = ( + ValuesB + + (off_b[:, None] + seq_start_b).to(tl.int64) * stride_bd + + offs_d[None, :] + ) + + v_b = tl.load(in_b_ptrs, mask=to_b_mask[:, None] & (offs_d[None, :] < D), other=0.0) + tl.store(out_ptrs, v_b, mask=to_b_mask[:, None] & (offs_d[None, :] < D)) + + +@triton_autotune( + configs=_get_split_concat_2d_jagged_multirow_configs_wrapper(), + key=["BLOCK_D"], +) +@triton.jit +def concat_2D_jagged_multirow( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + DenseSize, + Out, + D, + stride_ad, + stride_bd, + stride_dense_batch, + stride_od, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_REPLACE: tl.constexpr, +): + concat_2D_jagged_w_prefix_multirow( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + DenseSize, + Out, + D, + stride_ad, + stride_bd, + stride_dense_batch, + stride_od, + 0, + IS_DENSE_A, + IS_DENSE_B, + BLOCK_D, + BLOCK_N, + IS_REPLACE, + ) + + +@triton_autotune( + configs=_get_split_concat_2d_jagged_multirow_configs(), + key=["BLOCK_D"], +) +@triton.jit +def concat_2D_jagged_jagged_w_prefix_multirow( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + Out, + D, + stride_ad, + stride_bd, + stride_od, + n_prefix_from_B, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, +): + concat_2D_jagged_w_prefix_multirow( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + 0, + Out, + D, + stride_ad, + stride_bd, + 0, + stride_od, + n_prefix_from_B, + IS_DENSE_A=False, + IS_DENSE_B=False, + BLOCK_D=BLOCK_D, + BLOCK_N=BLOCK_N, + IS_REPLACE=False, + ) + + +@triton.jit +def split_2D_jagged_w_prefix_multirow( + JaggedIn, + DenseSize, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_REPLACE: tl.constexpr, +): + off_z = tl.program_id(1) + off_block_n = tl.program_id(0) + + if IS_DENSE_A: + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_start_a = off_z * DenseSize + seq_len_a = DenseSize + seq_len_b = seq_end_b - seq_start_b + elif IS_DENSE_B: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + seq_start_b = off_z * DenseSize + seq_len_b = DenseSize + else: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + + if IS_REPLACE: + seq_len = seq_len_a + else: + seq_len = seq_len_a + seq_len_b + + if IS_REPLACE: + seq_start = seq_start_a + out_seq_b_start = seq_len_a - seq_len_b + else: + seq_start = seq_start_a + seq_start_b + out_seq_b_start = seq_len_a + n_prefix_to_B + + start_n = off_block_n * BLOCK_N + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + if start_n >= seq_len: + return + valid_mask = offs_n < seq_len + + in_ptrs = ( + JaggedIn + + (seq_start + offs_n[:, None]).to(tl.int64) * stride_id + + offs_d[None, :] + ) + + v = tl.load(in_ptrs, mask=valid_mask[:, None] & (offs_d[None, :] < D), other=0.0) + + to_a_mask = (offs_n < out_seq_b_start) & (offs_n >= n_prefix_to_B) & valid_mask + to_b_mask = ~to_a_mask & valid_mask + + off_a = offs_n - n_prefix_to_B + out_a_ptrs = ( + OutA + (off_a[:, None] + seq_start_a).to(tl.int64) * stride_ad + offs_d[None, :] + ) + tl.store(out_a_ptrs, v, mask=to_a_mask[:, None] & (offs_d[None, :] < D)) + + prefix_mask = offs_n < n_prefix_to_B + + off_b = tl.where(prefix_mask, offs_n, offs_n - out_seq_b_start + n_prefix_to_B) + out_b_ptrs = ( + OutB + (off_b[:, None] + seq_start_b).to(tl.int64) * stride_bd + offs_d[None, :] + ) + tl.store(out_b_ptrs, v, mask=to_b_mask[:, None] & (offs_d[None, :] < D)) + + +@triton_autotune( + configs=pinned_or_full( + [triton.Config({"BLOCK_N": 8}, num_warps=1)], + _get_split_concat_2d_jagged_multirow_configs_wrapper, + ), + key=["BLOCK_D"], +) +@triton.jit +def split_2D_jagged_multirow( + JaggedIn, + DenseSize, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_REPLACE: tl.constexpr, +): + split_2D_jagged_w_prefix_multirow( + JaggedIn, + DenseSize, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + 0, + IS_DENSE_A, + IS_DENSE_B, + BLOCK_D, + BLOCK_N, + IS_REPLACE, + ) + + +@triton_autotune( + configs=_get_split_concat_2d_jagged_multirow_configs(), + key=["BLOCK_D"], +) +@triton.jit +def split_2D_jagged_jagged_w_prefix_multirow( + JaggedIn, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, +): + split_2D_jagged_w_prefix_multirow( + JaggedIn, + 0, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + IS_DENSE_A=False, + IS_DENSE_B=False, + BLOCK_D=BLOCK_D, + BLOCK_N=BLOCK_N, + IS_REPLACE=False, + ) + + +def triton_jagged_dense_broadcast_add( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, +) -> torch.Tensor: + return _JaggedDenseBroadcastAddFunction.apply( + max_seq_len, seq_offsets, jagged, dense + ) + + +@triton.jit +def _helion_split_2d_jagged_kernel( + offsets_a, + offsets_b, + values_flat, + out_a_flat, + out_b_flat, + max_seq_len, + D: tl.constexpr, + _BLOCK_SIZE_0: tl.constexpr, + _BLOCK_SIZE_1: tl.constexpr, +) -> None: + # Get program ID and decompose to batch and sequence block coordinates + program_id = tl.program_id(0) + flat_program_id = program_id + batch_id = triton_helpers.div_floor_integer( + flat_program_id, + triton_helpers.div_floor_integer( + -1 + _BLOCK_SIZE_0 + max_seq_len, _BLOCK_SIZE_0 + ), + ) + seq_block_id = triton_helpers.remainder_integer( # noqa: F841 + flat_program_id, + triton_helpers.div_floor_integer( + -1 + _BLOCK_SIZE_0 + max_seq_len, _BLOCK_SIZE_0 + ), + ) + # Load output boundaries for part A + out_a_start = tl.load(offsets_a + batch_id * 1, None, eviction_policy="evict_last") + batch_id_plus_1 = 1 + triton_helpers.div_floor_integer( + flat_program_id, + triton_helpers.div_floor_integer( + -1 + _BLOCK_SIZE_0 + max_seq_len, _BLOCK_SIZE_0 + ), + ) + out_a_end = tl.load( + offsets_a + batch_id_plus_1 * 1, None, eviction_policy="evict_last" + ) + len_a = out_a_end - out_a_start + # Load output boundaries for part B + out_b_start = tl.load(offsets_b + batch_id * 1, None) + out_b_end = tl.load( + offsets_b + batch_id_plus_1 * 1, None, eviction_policy="evict_last" + ) + len_b = out_b_end - out_b_start + # Compute input start and total length for this batch + input_start = out_a_start + out_b_start + total_len = len_a + len_b + # Calculate sequence offset for this block + seq_offset = _BLOCK_SIZE_0 * triton_helpers.remainder_integer( + flat_program_id, + triton_helpers.div_floor_integer( + -1 + _BLOCK_SIZE_0 + max_seq_len, _BLOCK_SIZE_0 + ), + ) + has_work = total_len > seq_offset + if has_work: + # Generate row indices for this sequence block + seq_range = tl.arange(0, _BLOCK_SIZE_0) + seq_offset_i32 = tl.cast(seq_offset, tl.int32) + row_indices = seq_range + seq_offset_i32 + + # Create masks for valid rows and parts A/B + total_len_i32 = tl.cast(total_len[None], tl.int32) + len_a_i32 = tl.cast(len_a[None], tl.int32) + valid_mask = row_indices < total_len_i32 + is_part_a = row_indices < len_a_i32 + is_part_b = (row_indices >= len_a_i32) & valid_mask + + # Extract scalar values once + input_start_i32 = tl.cast(input_start[None, None], tl.int32) + out_a_start_i32 = tl.cast(out_a_start[None, None], tl.int32) + out_b_start_i32 = tl.cast(out_b_start[None, None], tl.int32) + + # Process features in smaller tiles + for feature_offset in tl.range( + 0, + D, + _BLOCK_SIZE_1, + loop_unroll_factor=1, + num_stages=4, + disallow_acc_multi_buffer=True, + flatten=True, + ): + feature_indices = feature_offset + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) + + # Compute D constant and feature mask once per feature iteration + D_const = tl.full([], tl.cast(D, tl.int32), tl.int32) + D_i32 = tl.cast(D, tl.int32) + feature_mask = feature_indices < D_i32 + + # Compute indices for part A + row_subscript = row_indices[:, None] + input_row_a = input_start_i32 + row_subscript + input_idx_a = ( + tl.cast(input_row_a * D_const, tl.int32) + feature_indices[None, :] + ) + + out_a_row = out_a_start_i32 + row_subscript + out_a_idx = ( + tl.cast(out_a_row * D_const, tl.int32) + feature_indices[None, :] + ) + + mask_a = is_part_a[:, None] & valid_mask[:, None] & feature_mask[None, :] + + # Load and store part A data + slice_a = tl.load( + values_flat + input_idx_a * 1, + mask_a, + other=0, + eviction_policy="evict_first", + ) + tl.store(out_a_flat + out_a_idx * 1, slice_a, mask_a) + + # Compute indices for part B + input_idx_b = ( + tl.cast((input_start_i32 + row_subscript) * D_const, tl.int32) + + feature_indices[None, :] + ) + + row_minus_len_a = row_subscript - len_a_i32 + out_b_row = out_b_start_i32 + row_minus_len_a + out_b_idx = ( + tl.cast(out_b_row * D_const, tl.int32) + feature_indices[None, :] + ) + + mask_b = is_part_b[:, None] & feature_mask[None, :] + + # Load and store part B data + slice_b = tl.load( + values_flat + input_idx_b * 1, + mask_b, + other=0, + eviction_policy="evict_first", + ) + tl.store(out_b_flat + out_b_idx * 1, slice_b, mask_b) + + +class _HelionSplit2DJaggedFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + values: torch.Tensor, + max_seq_len: int, + offsets_a: torch.Tensor, + offsets_b: torch.Tensor, + dense_size: int = 0, # noqa: F841 + ) -> Tuple[torch.Tensor, torch.Tensor]: + values = switch_to_contiguous_if_needed(values) + B = offsets_a.shape[0] - 1 + D = values.size(1) + + # TODO: maybe check if torch.compiler.is_compiling() and use index_select instead + seq_len_a = int(offsets_a[-1].item()) + seq_len_b = int(offsets_b[-1].item()) + + values_a, values_b = _helion_split_2D_jagged_impl( + values=values, + max_seq_len=max_seq_len, + offsets_a=offsets_a, + offsets_b=offsets_b, + dense_size=dense_size, + total_len_a=seq_len_a, + total_len_b=seq_len_b, + ) + + ctx.save_for_backward(offsets_a, offsets_b) + ctx.max_seq_len = max_seq_len + ctx.seq_len_a = seq_len_a + ctx.seq_len_b = seq_len_b + ctx.dense_size = dense_size + ctx.B = B + ctx.D = D + return values_a, values_b + + @staticmethod + def backward(ctx, *d_values) -> Tuple[torch.Tensor, None, None, None, None]: + offsets_a, offsets_b = ctx.saved_tensors + values_a, values_b = d_values + BLOCK_D = triton.next_power_of_2(ctx.D) + + dvalues = torch.empty( + (ctx.seq_len_a + ctx.seq_len_b, ctx.D), + device=values_a.device, + dtype=values_a.dtype, + ) + _triton_concat_2D_jagged_internal( + values_a=values_a, + values_b=values_b, + values_out=dvalues, + max_seq_len=ctx.max_seq_len, + B=ctx.B, + offsets_a=offsets_a, + offsets_b=offsets_b, + D=ctx.D, + dense_size=0, + stride_dense_batch=0, + n_prefix=0, + is_dense_a=False, + is_dense_b=False, + is_replace=False, + BLOCK_D=BLOCK_D, + ) + return dvalues, None, None, None, None + + +def _helion_split_2D_jagged_impl( + values: torch.Tensor, + max_seq_len: int, + offsets_a: torch.Tensor, + offsets_b: torch.Tensor, + dense_size: int = 0, # noqa: F841 + total_len_a: Optional[int] = None, + total_len_b: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + D = values.size(1) + + # Select dtype-specific optimal parameters + if values.dtype == torch.float32: + # FP32-optimized parameters + block_size_0 = 64 + block_size_1 = 64 + num_warps = 4 + num_stages = 4 + else: + # BF16/FP16-optimized parameters + block_size_0 = 128 + block_size_1 = triton.next_power_of_2(D) + num_warps = 32 + num_stages = 7 + + return _helion_split_2d_jagged( + values, + max_seq_len, + offsets_a, + offsets_b, + dense_size, + block_size_0=block_size_0, + block_size_1=block_size_1, + num_warps=num_warps, + num_stages=num_stages, + total_len_a=total_len_a, + total_len_b=total_len_b, + ) + + +def _helion_split_2d_jagged( + values: torch.Tensor, + max_seq_len: int, + offsets_a: torch.Tensor, + offsets_b: torch.Tensor, + dense_size: int, # noqa: F841 + block_size_0: int = 64, + block_size_1: int = 64, + num_warps: int = 4, + num_stages: int = 4, + total_len_a: Optional[int] = None, + total_len_b: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + values = values.contiguous() + num_batches = offsets_a.size(0) - 1 + D = values.size(1) + num_seq_blocks = (max_seq_len + block_size_0 - 1) // block_size_0 + if total_len_a is None: + total_len_a = int(offsets_a[-1].item()) + if total_len_b is None: + total_len_b = int(offsets_b[-1].item()) + out_a = torch.empty([total_len_a, D], dtype=values.dtype, device=values.device) + out_b = torch.empty([total_len_b, D], dtype=values.dtype, device=values.device) + values_flat = values.view(-1) + out_a_flat = out_a.view(-1) + out_b_flat = out_b.view(-1) + total_programs = num_batches * num_seq_blocks + + # pyre-ignore[28] + _helion_split_2d_jagged_kernel[(total_programs,)]( + offsets_a, + offsets_b, + values_flat, + out_a_flat, + out_b_flat, + max_seq_len, + D, + block_size_0, + block_size_1, + num_warps=num_warps, + num_stages=num_stages, + ) + return (out_a, out_b) diff --git a/recommendation_v4/generative_recommenders/ops/triton/triton_jagged_tensors.py b/recommendation_v4/generative_recommenders/ops/triton/triton_jagged_tensors.py new file mode 100644 index 000000000..2c3728c0a --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/triton/triton_jagged_tensors.py @@ -0,0 +1,1088 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-unsafe + +#!/usr/bin/env python3 + + +from typing import Optional, Tuple + +import torch + +# @manual=//triton:triton +import triton + +# @manual=//triton:triton +import triton.language as tl +from generative_recommenders.common import ( + switch_to_contiguous_if_needed, + triton_autotune, +) +from generative_recommenders.ops.utils import is_sm100_plus + + +def _triton_concat_2D_jagged_internal( + values_a: torch.Tensor, + values_b: torch.Tensor, + values_out: torch.Tensor, + max_seq_len: int, + B: int, + offsets_a: Optional[torch.Tensor], + offsets_b: Optional[torch.Tensor], + max_len_a: Optional[int], + max_len_b: Optional[int], + D: int, + n_prefix_from_B: int, + is_dense_a: bool, + is_dense_b: bool, + BLOCK_D: int, +) -> None: + if is_sm100_plus() or (torch.cuda.is_available() and torch.version.hip): + # Route AMD/ROCm through the multirow kernel. + # + # The basic `_concat_2D_jagged` kernel below issues one program per + # output row (grid = `(max_seq_len, B)`). On ROCm Triton this fails to + # lower in the `TritonAMDGPUCanonicalizePointers` pass with + # `RuntimeError: PassManager::run failed` at `make_ttgir`. The + # multirow variant tiles rows with a tunable `BLOCK_N` (grid = + # `(cdiv(max_seq_len, BLOCK_N), B)`) and compiles cleanly on ROCm. + # The original `is_sm100_plus()` gate was conservative — only Blackwell + # was opted in. Adding HIP keeps NVIDIA H100/A100 on the basic kernel + # they were validated against and unblocks AMD without behavior change + # on existing NVIDIA paths. + + def grid(meta): + return (triton.cdiv(max_seq_len, meta["BLOCK_N"]), B) + + concat_2D_jagged_multirow[grid]( + ValuesA=values_a, + ValuesB=values_b, + OffsetsA=offsets_a, + OffsetsB=offsets_b, + MaxLenA=max_len_a, + MaxLenB=max_len_b, + Out=values_out, + D=D, + stride_ad=values_a.stride(-2), + stride_bd=values_b.stride(-2), + stride_od=values_out.stride(-2), + n_prefix_from_B=n_prefix_from_B, + IS_DENSE_A=is_dense_a, # pyre-ignore[6] + IS_DENSE_B=is_dense_b, # pyre-ignore[6] + BLOCK_D=BLOCK_D, # pyre-ignore[6] + ) + else: + _concat_2D_jagged[(max_seq_len, B)]( + ValuesA=values_a, + ValuesB=values_b, + OffsetsA=offsets_a, + OffsetsB=offsets_b, + MaxLenA=max_len_a, + MaxLenB=max_len_b, + Out=values_out, + D=D, + stride_ad=values_a.stride(-2), + stride_bd=values_b.stride(-2), + stride_od=values_out.stride(-2), + n_prefix_from_B=n_prefix_from_B, + IS_DENSE_A=is_dense_a, # pyre-ignore[6] + IS_DENSE_B=is_dense_b, # pyre-ignore[6] + BLOCK_D=BLOCK_D, # pyre-ignore[6] + ) + + +def _triton_split_2D_jagged_internal( + jagged_in: torch.Tensor, + max_seq_len: int, + B: int, + offsets_a: Optional[torch.Tensor], + offsets_b: Optional[torch.Tensor], + max_len_a: Optional[int], + max_len_b: Optional[int], + out_a: torch.Tensor, + out_b: torch.Tensor, + D: int, + n_prefix_to_B: int, + is_dense_a: bool, + is_dense_b: bool, + BLOCK_D: int, +) -> None: + if is_sm100_plus() or (torch.cuda.is_available() and torch.version.hip): + # Route AMD/ROCm through the multirow kernel for the same reason as + # `_triton_concat_2D_jagged_internal` above: basic `_split_2D_jagged` + # hits `PassManager::run failed` in `TritonAMDGPUCanonicalizePointers` + # on ROCm; multirow lowers cleanly. + + def grid(meta): + return (triton.cdiv(max_seq_len, meta["BLOCK_N"]), B) + + split_2D_jagged_multirow[grid]( + JaggedIn=jagged_in, + OffsetsA=offsets_a, + OffsetsB=offsets_b, + MaxLenA=max_len_a, + MaxLenB=max_len_b, + OutA=out_a, + OutB=out_b, + D=D, + stride_id=jagged_in.stride(0), + stride_ad=out_a.stride(0), + stride_bd=out_b.stride(0), + n_prefix_to_B=n_prefix_to_B, + IS_DENSE_A=is_dense_a, # pyre-ignore[6] + IS_DENSE_B=is_dense_b, # pyre-ignore[6] + BLOCK_D=BLOCK_D, # pyre-ignore[6] + ) + else: + _split_2D_jagged[(max_seq_len, B)]( + JaggedIn=jagged_in, + OffsetsA=offsets_a, + OffsetsB=offsets_b, + MaxLenA=max_len_a, + MaxLenB=max_len_b, + OutA=out_a, + OutB=out_b, + D=D, + stride_id=jagged_in.stride(0), + stride_ad=out_a.stride(0), + stride_bd=out_b.stride(0), + n_prefix_to_B=n_prefix_to_B, + IS_DENSE_A=is_dense_a, # pyre-ignore[6] + IS_DENSE_B=is_dense_b, # pyre-ignore[6] + BLOCK_D=BLOCK_D, # pyre-ignore[6] + ) + + +def _get_concat_split_2d_jagged_multirow_configs(): + configs = [] + for BLOCK_N in [1, 2, 4, 8]: + for num_warps in [1, 2, 4]: + configs.append( + triton.Config( + {"BLOCK_N": BLOCK_N}, + num_warps=num_warps, + ) + ) + return configs + + +@triton.jit +def _concat_2D_jagged_multirow( + ValuesA, + ValuesB, + OffsetsA, + OffsetsB, + MaxLenA, + MaxLenB, + Out, + D, + stride_ad, + stride_bd, + stride_od, + n_prefix_from_B, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_z = tl.program_id(1) + block_n = tl.program_id(0) + + if IS_DENSE_A: + seq_start_a = off_z * MaxLenA + seq_len_a = MaxLenA + else: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + if IS_DENSE_B: + seq_start_b = off_z * MaxLenB + seq_len_b = MaxLenB + else: + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + seq_len = seq_len_a + seq_len_b + + start_n = block_n * BLOCK_N + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + valid_mask = offs_n < seq_len + + out_seq_start = seq_start_a + seq_start_b + offs_n + out_ptrs = Out + out_seq_start[:, None].to(tl.int64) * stride_od + offs_d[None, :] + + from_prefix_b_mask = (offs_n < n_prefix_from_B) & valid_mask + from_a_mask = ( + (offs_n >= n_prefix_from_B) + & (offs_n < seq_len_a + n_prefix_from_B) + & valid_mask + ) + from_suffix_b_mask = (offs_n >= seq_len_a + n_prefix_from_B) & valid_mask + + in_b1_ptrs = ( + ValuesB + + (offs_n[:, None] + seq_start_b).to(tl.int64) * stride_bd + + offs_d[None, :] + ) + v_b1 = tl.load( + in_b1_ptrs, mask=from_prefix_b_mask[:, None] & (offs_d[None, :] < D), other=0.0 + ) + tl.store(out_ptrs, v_b1, mask=from_prefix_b_mask[:, None] & (offs_d[None, :] < D)) + + off_a = offs_n - n_prefix_from_B + in_a_ptrs = ( + ValuesA + + (off_a[:, None] + seq_start_a).to(tl.int64) * stride_ad + + offs_d[None, :] + ) + v_a = tl.load( + in_a_ptrs, mask=from_a_mask[:, None] & (offs_d[None, :] < D), other=0.0 + ) + tl.store(out_ptrs, v_a, mask=from_a_mask[:, None] & (offs_d[None, :] < D)) + + off_b = offs_n - seq_len_a + in_b2_ptrs = ( + ValuesB + + (off_b[:, None] + seq_start_b).to(tl.int64) * stride_bd + + offs_d[None, :] + ) + v_b2 = tl.load( + in_b2_ptrs, mask=from_suffix_b_mask[:, None] & (offs_d[None, :] < D), other=0.0 + ) + tl.store(out_ptrs, v_b2, mask=from_suffix_b_mask[:, None] & (offs_d[None, :] < D)) + + +@triton_autotune( + configs=_get_concat_split_2d_jagged_multirow_configs(), + key=["BLOCK_D"], +) +@triton.jit +def concat_2D_jagged_multirow( + ValuesA, + ValuesB, + OffsetsA, + OffsetsB, + MaxLenA, + MaxLenB, + Out, + D, + stride_ad, + stride_bd, + stride_od, + n_prefix_from_B, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, +): + _concat_2D_jagged_multirow( + ValuesA, + ValuesB, + OffsetsA, + OffsetsB, + MaxLenA, + MaxLenB, + Out, + D, + stride_ad, + stride_bd, + stride_od, + n_prefix_from_B, + IS_DENSE_A, + IS_DENSE_B, + BLOCK_D, + BLOCK_N, + ) + + +@triton.jit +def _split_2D_jagged_multirow( + JaggedIn, + OffsetsA, + OffsetsB, + MaxLenA, + MaxLenB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_z = tl.program_id(1) + block_n = tl.program_id(0) + + if IS_DENSE_A: + seq_start_a = off_z * MaxLenA + seq_len_a = MaxLenA + else: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + if IS_DENSE_B: + seq_start_b = off_z * MaxLenB + seq_len_b = MaxLenB + else: + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + seq_len = seq_len_a + seq_len_b + seq_start = seq_start_a + seq_start_b + + start_n = block_n * BLOCK_N + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + valid_mask = offs_n < seq_len + + in_ptrs = ( + JaggedIn + + (seq_start + offs_n[:, None]).to(tl.int64) * stride_id + + offs_d[None, :] + ) + + v = tl.load(in_ptrs, mask=valid_mask[:, None] & (offs_d[None, :] < D), other=0.0) + + to_prefix_b_mask = (offs_n < n_prefix_to_B) & valid_mask + to_a_mask = ( + (offs_n >= n_prefix_to_B) & (offs_n < seq_len_a + n_prefix_to_B) & valid_mask + ) + to_suffix_b_mask = (offs_n >= seq_len_a + n_prefix_to_B) & valid_mask + + out_b1_ptrs = ( + OutB + + (offs_n[:, None] + seq_start_b).to(tl.int64) * stride_bd + + offs_d[None, :] + ) + tl.store(out_b1_ptrs, v, mask=to_prefix_b_mask[:, None] & (offs_d[None, :] < D)) + + off_a = offs_n - n_prefix_to_B + out_a_ptrs = ( + OutA + (off_a[:, None] + seq_start_a).to(tl.int64) * stride_ad + offs_d[None, :] + ) + tl.store(out_a_ptrs, v, mask=to_a_mask[:, None] & (offs_d[None, :] < D)) + + off_b = offs_n - seq_len_a + out_b2_ptrs = ( + OutB + (off_b[:, None] + seq_start_b).to(tl.int64) * stride_bd + offs_d[None, :] + ) + tl.store(out_b2_ptrs, v, mask=to_suffix_b_mask[:, None] & (offs_d[None, :] < D)) + + +@triton_autotune( + configs=_get_concat_split_2d_jagged_multirow_configs(), + key=["BLOCK_D"], +) +@triton.jit +def split_2D_jagged_multirow( + JaggedIn, + OffsetsA, + OffsetsB, + MaxLenA, + MaxLenB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, +): + _split_2D_jagged_multirow( + JaggedIn, + OffsetsA, + OffsetsB, + MaxLenA, + MaxLenB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + IS_DENSE_A, + IS_DENSE_B, + BLOCK_D, + BLOCK_N, + ) + + +@triton.jit +def _concat_2D_jagged( + ValuesA, + ValuesB, + OffsetsA, + OffsetsB, + MaxLenA, + MaxLenB, + Out, + D, + stride_ad, + stride_bd, + stride_od, + n_prefix_from_B, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, +): + off_z = tl.program_id(1) + off_n = tl.program_id(0) + if IS_DENSE_A: + seq_start_a = off_z * MaxLenA + seq_len_a = MaxLenA + else: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + if IS_DENSE_B: + seq_start_b = off_z * MaxLenB + seq_len_b = MaxLenB + else: + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + seq_len = seq_len_a + seq_len_b + if off_n >= seq_len: + return + offs_d = tl.arange(0, BLOCK_D) + out_seq_start = seq_start_a + seq_start_b + off_n + out_ptrs = Out + out_seq_start.to(tl.int64) * stride_od + offs_d + if off_n < n_prefix_from_B: + in_ptrs = ValuesB + (off_n + seq_start_b).to(tl.int64) * stride_bd + offs_d + elif off_n < seq_len_a + n_prefix_from_B: + in_ptrs = ( + ValuesA + + (off_n - n_prefix_from_B + seq_start_a).to(tl.int64) * stride_ad + + offs_d + ) + else: + in_ptrs = ( + ValuesB + + (off_n - seq_len_a + seq_start_b).to(tl.int64) * stride_bd + + offs_d + ) + v = tl.load(in_ptrs, mask=offs_d < D) + tl.store(out_ptrs, v, mask=offs_d < D) + + +@triton.jit +def _split_2D_jagged( + JaggedIn, + OffsetsA, + OffsetsB, + MaxLenA, + MaxLenB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, +): + off_z = tl.program_id(1) + off_n = tl.program_id(0) + if IS_DENSE_A: + seq_start_a = off_z * MaxLenA + seq_len_a = MaxLenA + else: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + if IS_DENSE_B: + seq_start_b = off_z * MaxLenB + seq_len_b = MaxLenB + else: + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + seq_len = seq_len_a + seq_len_b + if off_n >= seq_len: + return + seq_start = seq_start_a + seq_start_b + offs_d = tl.arange(0, BLOCK_D) + in_ptrs = JaggedIn + (seq_start + off_n).to(tl.int64) * stride_id + offs_d + if off_n < n_prefix_to_B: + out_ptrs = OutB + (off_n + seq_start_b).to(tl.int64) * stride_bd + offs_d + elif off_n < seq_len_a + n_prefix_to_B: + out_ptrs = ( + OutA + + (off_n - n_prefix_to_B + seq_start_a).to(tl.int64) * stride_ad + + offs_d + ) + else: + out_ptrs = ( + OutB + (off_n - seq_len_a + seq_start_b).to(tl.int64) * stride_bd + offs_d + ) + v = tl.load(in_ptrs, mask=offs_d < D) + tl.store(out_ptrs, v, mask=offs_d < D) + + +class _Concat2DJaggedFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + max_seq_len: int, + values_a: torch.Tensor, + values_b: torch.Tensor, + max_len_a: Optional[int], + max_len_b: Optional[int], + offsets_a: Optional[torch.Tensor], + offsets_b: Optional[torch.Tensor], + n_prefix_from_B: int, + ): + values_a = switch_to_contiguous_if_needed(values_a) + values_b = switch_to_contiguous_if_needed(values_b) + is_dense_a = offsets_a is None + is_dense_b = offsets_b is None + total_len_a, D = values_a.shape + total_len_b, _ = values_b.shape + if is_dense_a: + assert max_len_a is not None + B = total_len_a // max_len_a + else: + assert offsets_a is not None + B = offsets_a.shape[0] - 1 + if is_dense_b: + assert max_len_b is not None + B = total_len_b // max_len_b + else: + assert offsets_b is not None + B = offsets_b.shape[0] - 1 + total_seq_len = total_len_a + total_len_b + BLOCK_D = triton.next_power_of_2(D) + values_out = torch.empty( + (total_seq_len, D), device=values_a.device, dtype=values_a.dtype + ) + _triton_concat_2D_jagged_internal( + values_a=values_a, + values_b=values_b, + values_out=values_out, + max_seq_len=max_seq_len, + B=B, + offsets_a=offsets_a, + offsets_b=offsets_b, + max_len_a=max_len_a, + max_len_b=max_len_b, + D=D, + n_prefix_from_B=n_prefix_from_B, + is_dense_a=is_dense_a, + is_dense_b=is_dense_b, + BLOCK_D=BLOCK_D, + ) + ctx.save_for_backward(offsets_a, offsets_b) + ctx.max_seq_len = max_seq_len + ctx.total_len_a = total_len_a + ctx.total_len_b = total_len_b + ctx.is_dense_a = is_dense_a + ctx.is_dense_b = is_dense_b + ctx.max_len_a = max_len_a + ctx.max_len_b = max_len_b + ctx.B = B + ctx.n_prefix_from_B = n_prefix_from_B + return values_out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, d_out: torch.Tensor + ) -> Tuple[None, torch.Tensor, torch.Tensor, None, None, None, None, None]: + offsets_a, offsets_b = ctx.saved_tensors + _, D = d_out.shape + BLOCK_D = triton.next_power_of_2(D) + d_values_a = torch.zeros( + (ctx.total_len_a, D), device=d_out.device, dtype=d_out.dtype + ) + d_values_b = torch.empty( + (ctx.total_len_b, D), device=d_out.device, dtype=d_out.dtype + ) + # Go through `_triton_split_2D_jagged_internal` (not raw + # `_split_2D_jagged[grid]`) so this backward pass benefits from the same + # AMD-routing-through-multirow workaround as the forward. Calling the + # raw kernel directly would hit `PassManager::run failed` on ROCm at + # `TritonAMDGPUCanonicalizePointers`. If you refactor this, do not + # collapse it back to `_split_2D_jagged[(ctx.max_seq_len, ctx.B)](...)`. + _triton_split_2D_jagged_internal( + jagged_in=d_out, + max_seq_len=ctx.max_seq_len, + B=ctx.B, + offsets_a=offsets_a, + offsets_b=offsets_b, + max_len_a=ctx.max_len_a, + max_len_b=ctx.max_len_b, + out_a=d_values_a, + out_b=d_values_b, + D=D, + n_prefix_to_B=ctx.n_prefix_from_B, + is_dense_a=ctx.is_dense_a, + is_dense_b=ctx.is_dense_b, + BLOCK_D=BLOCK_D, + ) + return None, d_values_a, d_values_b, None, None, None, None, None + + +class _Split2DJaggedFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + max_seq_len: int, + values: torch.Tensor, + total_len_left: Optional[int], + total_len_right: Optional[int], + max_len_a: Optional[int], + max_len_b: Optional[int], + offsets_a: Optional[torch.Tensor], + offsets_b: Optional[torch.Tensor], + n_prefix_to_B: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + values = switch_to_contiguous_if_needed(values) + is_dense_a: bool = offsets_a is None + is_dense_b: bool = offsets_b is None + total_seq_len, D = values.shape + if is_dense_a: + assert is_dense_b is False + assert offsets_b is not None + assert max_len_a is not None + B = offsets_b.shape[0] - 1 + total_len_a = max_len_a * B + total_len_b = total_seq_len - total_len_a + elif is_dense_b: + assert is_dense_a is False + assert offsets_a is not None + assert max_len_b is not None + B = offsets_a.shape[0] - 1 + total_len_b = max_len_b * B + total_len_a = total_seq_len - total_len_b + else: + assert offsets_a is not None and offsets_b is not None + B = offsets_a.shape[0] - 1 + if total_len_left is not None and total_len_right is not None: + assert total_len_left + total_len_right == total_seq_len + total_len_a = total_len_left + total_len_b = total_len_right + else: + total_len_a = int(offsets_a[-1].item()) + total_len_b = values.size(0) - total_len_a + _, D = values.shape + BLOCK_D = triton.next_power_of_2(D) + values_a = torch.empty( + (total_len_a, D), device=values.device, dtype=values.dtype + ) + values_b = torch.empty( + (total_len_b, D), device=values.device, dtype=values.dtype + ) + _triton_split_2D_jagged_internal( + jagged_in=values, + max_seq_len=max_seq_len, + B=B, + offsets_a=offsets_a, + offsets_b=offsets_b, + max_len_a=max_len_a, + max_len_b=max_len_b, + out_a=values_a, + out_b=values_b, + D=D, + n_prefix_to_B=n_prefix_to_B, + is_dense_a=is_dense_a, + is_dense_b=is_dense_b, + BLOCK_D=BLOCK_D, + ) + ctx.save_for_backward(offsets_a, offsets_b) + ctx.max_seq_len = max_seq_len + ctx.total_seq_len = total_seq_len + ctx.max_len_a = max_len_a + ctx.max_len_b = max_len_b + ctx.is_dense_a = is_dense_a + ctx.is_dense_b = is_dense_b + ctx.B = B + ctx.D = D + ctx.n_prefix_to_B = n_prefix_to_B + return values_a, values_b + + @staticmethod + def backward( + ctx, *d_values + ) -> Tuple[None, torch.Tensor, None, None, None, None, None, None, None]: + offsets_a, offsets_b = ctx.saved_tensors + d_values_a, d_values_b = d_values + BLOCK_D = triton.next_power_of_2(ctx.D) + d_jagged_in = torch.empty( + (ctx.total_seq_len, ctx.D), + device=d_values_a.device, + dtype=d_values_a.dtype, + ) + _triton_concat_2D_jagged_internal( + values_a=d_values_a, + values_b=d_values_b, + values_out=d_jagged_in, + max_seq_len=ctx.max_seq_len, + B=ctx.B, + offsets_a=offsets_a, + offsets_b=offsets_b, + max_len_a=ctx.max_len_a, + max_len_b=ctx.max_len_b, + D=ctx.D, + n_prefix_from_B=ctx.n_prefix_to_B, + is_dense_a=ctx.is_dense_a, + is_dense_b=ctx.is_dense_b, + BLOCK_D=BLOCK_D, + ) + + return None, d_jagged_in, None, None, None, None, None, None, None + + +@torch.jit.unused +@torch.fx.wrap +def triton_concat_2D_jagged( + max_seq_len: int, + values_left: torch.Tensor, + values_right: torch.Tensor, + max_len_left: Optional[int], + max_len_right: Optional[int], + offsets_left: Optional[torch.Tensor], + offsets_right: Optional[torch.Tensor], + n_prefix_from_right: int = 0, +) -> torch.Tensor: + return _Concat2DJaggedFunction.apply( + max_seq_len, + values_left, + values_right, + max_len_left, + max_len_right, + offsets_left, + offsets_right, + n_prefix_from_right, + ) + + +@torch.jit.unused +@torch.fx.wrap +def triton_split_2D_jagged( + max_seq_len: int, + values: torch.Tensor, + total_len_left: Optional[int], + total_len_right: Optional[int], + max_len_left: Optional[int], + max_len_right: Optional[int], + offsets_left: Optional[torch.Tensor], + offsets_right: Optional[torch.Tensor], + n_prefix_to_right: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _Split2DJaggedFunction.apply( + max_seq_len, + values, + total_len_left, + total_len_right, + max_len_left, + max_len_right, + offsets_left, + offsets_right, + n_prefix_to_right, + ) + + +class _Concat2DJaggedMultirowFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + max_seq_len: int, + values_left: torch.Tensor, + values_right: torch.Tensor, + max_len_left: Optional[int], + max_len_right: Optional[int], + offsets_left: Optional[torch.Tensor], + offsets_right: Optional[torch.Tensor], + n_prefix_from_right: int, + ) -> torch.Tensor: + values_left = switch_to_contiguous_if_needed(values_left) + values_right = switch_to_contiguous_if_needed(values_right) + is_dense_left = offsets_left is None + is_dense_right = offsets_right is None + total_len_left, D = values_left.shape + total_len_right, _ = values_right.shape + if is_dense_left: + assert max_len_left is not None + B = total_len_left // max_len_left + else: + assert offsets_left is not None + B = offsets_left.shape[0] - 1 + if is_dense_right: + assert max_len_right is not None + B = total_len_right // max_len_right + else: + assert offsets_right is not None + B = offsets_right.shape[0] - 1 + total_seq_len = total_len_left + total_len_right + BLOCK_D = triton.next_power_of_2(D) + values_out = torch.empty( + (total_seq_len, D), device=values_left.device, dtype=values_left.dtype + ) + + def grid(meta): + return (triton.cdiv(max_seq_len, meta["BLOCK_N"]), B) + + concat_2D_jagged_multirow[grid]( + ValuesA=values_left, + ValuesB=values_right, + OffsetsA=offsets_left, + OffsetsB=offsets_right, + MaxLenA=max_len_left, + MaxLenB=max_len_right, + Out=values_out, + D=D, + stride_ad=values_left.stride(-2), + stride_bd=values_right.stride(-2), + stride_od=values_out.stride(-2), + n_prefix_from_B=n_prefix_from_right, + IS_DENSE_A=is_dense_left, + IS_DENSE_B=is_dense_right, + BLOCK_D=BLOCK_D, + ) + ctx.save_for_backward(offsets_left, offsets_right) + ctx.max_seq_len = max_seq_len + ctx.total_len_left = total_len_left + ctx.total_len_right = total_len_right + ctx.is_dense_left = is_dense_left + ctx.is_dense_right = is_dense_right + ctx.max_len_left = max_len_left + ctx.max_len_right = max_len_right + ctx.B = B + ctx.n_prefix_from_right = n_prefix_from_right + return values_out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, d_out: torch.Tensor + ) -> Tuple[None, torch.Tensor, torch.Tensor, None, None, None, None, None]: + offsets_left, offsets_right = ctx.saved_tensors + _, D = d_out.shape + BLOCK_D = triton.next_power_of_2(D) + d_values_left = torch.zeros( + (ctx.total_len_left, D), device=d_out.device, dtype=d_out.dtype + ) + d_values_right = torch.empty( + (ctx.total_len_right, D), device=d_out.device, dtype=d_out.dtype + ) + + def grid(meta): + return (triton.cdiv(ctx.max_seq_len, meta["BLOCK_N"]), ctx.B) + + split_2D_jagged_multirow[grid]( + JaggedIn=d_out, + OffsetsA=offsets_left, + OffsetsB=offsets_right, + MaxLenA=ctx.max_len_left, + MaxLenB=ctx.max_len_right, + OutA=d_values_left, + OutB=d_values_right, + D=D, + stride_id=d_out.stride(-2), + stride_ad=d_values_left.stride(-2), + stride_bd=d_values_right.stride(-2), + n_prefix_to_B=ctx.n_prefix_from_right, + IS_DENSE_A=ctx.is_dense_left, + IS_DENSE_B=ctx.is_dense_right, + BLOCK_D=BLOCK_D, + ) + return None, d_values_left, d_values_right, None, None, None, None, None + + +class _Split2DJaggedMultirowFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + max_seq_len: int, + values: torch.Tensor, + total_len_left: Optional[int], + total_len_right: Optional[int], + max_len_left: Optional[int], + max_len_right: Optional[int], + offsets_left: Optional[torch.Tensor], + offsets_right: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + values = switch_to_contiguous_if_needed(values) + is_dense_left: bool = offsets_left is None + is_dense_right: bool = offsets_right is None + total_seq_len, D = values.shape + + if is_dense_left: + assert is_dense_right is False + assert offsets_right is not None + assert max_len_left is not None + B = offsets_right.shape[0] - 1 + total_len_a = max_len_left * B + total_len_b = total_seq_len - total_len_a + elif is_dense_right: + assert is_dense_left is False + assert offsets_left is not None + assert max_len_right is not None + B = offsets_left.shape[0] - 1 + total_len_b = max_len_right * B + total_len_a = total_seq_len - total_len_b + else: + assert offsets_left is not None and offsets_right is not None + B = offsets_left.shape[0] - 1 + if total_len_left is not None and total_len_right is not None: + assert total_len_left + total_len_right == total_seq_len + total_len_a = total_len_left + total_len_b = total_len_right + else: + total_len_a = int(offsets_left[-1].item()) + total_len_b = values.size(0) - total_len_a + + BLOCK_D = triton.next_power_of_2(D) + values_a = torch.empty( + (total_len_a, D), device=values.device, dtype=values.dtype + ) + values_b = torch.empty( + (total_len_b, D), device=values.device, dtype=values.dtype + ) + + def grid(meta): + return (triton.cdiv(max_seq_len, meta["BLOCK_N"]), B) + + split_2D_jagged_multirow[grid]( + JaggedIn=values, + OffsetsA=offsets_left, + OffsetsB=offsets_right, + MaxLenA=max_len_left, + MaxLenB=max_len_right, + OutA=values_a, + OutB=values_b, + D=D, + stride_id=values.stride(-2), + stride_ad=values_a.stride(-2), + stride_bd=values_b.stride(-2), + n_prefix_to_B=0, + IS_DENSE_A=is_dense_left, + IS_DENSE_B=is_dense_right, + BLOCK_D=BLOCK_D, + ) + + ctx.save_for_backward(offsets_left, offsets_right) + ctx.max_seq_len = max_seq_len + ctx.total_seq_len = total_seq_len + ctx.max_len_left = max_len_left + ctx.max_len_right = max_len_right + ctx.is_dense_left = is_dense_left + ctx.is_dense_right = is_dense_right + ctx.B = B + ctx.D = D + + return values_a, values_b + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, *d_values + ) -> Tuple[None, torch.Tensor, None, None, None, None, None, None]: + offsets_left, offsets_right = ctx.saved_tensors + d_values_a, d_values_b = d_values + BLOCK_D = triton.next_power_of_2(ctx.D) + d_jagged_in = torch.empty( + (ctx.total_seq_len, ctx.D), + device=d_values_a.device, + dtype=d_values_a.dtype, + ) + + def grid(meta): + return (triton.cdiv(ctx.max_seq_len, meta["BLOCK_N"]), ctx.B) + + concat_2D_jagged_multirow[grid]( + ValuesA=d_values_a, + ValuesB=d_values_b, + OffsetsA=offsets_left, + OffsetsB=offsets_right, + MaxLenA=ctx.max_len_left, + MaxLenB=ctx.max_len_right, + Out=d_jagged_in, + D=ctx.D, + stride_ad=d_values_a.stride(-2), + stride_bd=d_values_b.stride(-2), + stride_od=d_jagged_in.stride(-2), + n_prefix_from_B=0, + IS_DENSE_A=ctx.is_dense_left, + IS_DENSE_B=ctx.is_dense_right, + BLOCK_D=BLOCK_D, + ) + + return None, d_jagged_in, None, None, None, None, None, None + + +@torch.jit.unused +@torch.fx.wrap +def triton_concat_2D_jagged_multirow( + max_seq_len: int, + values_a: torch.Tensor, + values_b: torch.Tensor, + offsets_a: Optional[torch.Tensor], + offsets_b: Optional[torch.Tensor], + max_len_a: int, + max_len_b: int, +) -> torch.Tensor: + return _Concat2DJaggedMultirowFunction.apply( + max_seq_len, + values_a, + values_b, + max_len_a, + max_len_b, + offsets_a, + offsets_b, + 0, + ) + + +@torch.jit.unused +@torch.fx.wrap +def triton_split_2D_jagged_multirow( + max_seq_len: int, + values: torch.Tensor, + total_len_left: Optional[int] = None, + total_len_right: Optional[int] = None, + max_len_left: Optional[int] = None, + max_len_right: Optional[int] = None, + offsets_left: Optional[torch.Tensor] = None, + offsets_right: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _Split2DJaggedMultirowFunction.apply( + max_seq_len, + values, + total_len_left, + total_len_right, + max_len_left, + max_len_right, + offsets_left, + offsets_right, + ) diff --git a/recommendation_v4/generative_recommenders/ops/triton/triton_layer_norm.py b/recommendation_v4/generative_recommenders/ops/triton/triton_layer_norm.py new file mode 100644 index 000000000..cc513433b --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/triton/triton_layer_norm.py @@ -0,0 +1,1337 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-unsafe + +#!/usr/bin/env python3 + + +from typing import List, Optional, Tuple + +import torch + +# @manual=//triton:triton +import triton + +# @manual=//triton:triton +import triton.language as tl +from generative_recommenders.common import ( + switch_to_contiguous_if_needed, + triton_autotune, +) +from generative_recommenders.ops.triton._autotune_pinning import pinned_or_full +from generative_recommenders.ops.utils import ( + is_sm100_plus, + is_sm90, + maybe_register_custom_op, +) + +try: + # @manual=//triton:triton + from triton.language.extra.libdevice import fast_dividef, rsqrt as libdevice_rsqrt +except ImportError: + try: + # @manual=//triton:triton + from triton.language.extra.cuda.libdevice import ( + fast_dividef, + rsqrt as libdevice_rsqrt, + ) + except ImportError: + # pyre-ignore: Undefined import [21] + # @manual=//triton:triton + from triton.language.math import fast_dividef, rsqrt as libdevice_rsqrt + + +def _get_layer_norm_fwd_configs() -> List[triton.Config]: + """Generate autotune configs for multi-row LayerNorm kernels.""" + configs = [] + block_ns = [4, 8, 16] if is_sm100_plus() else [1, 2, 4, 8] + for BLOCK_N in block_ns: + for num_warps in [1, 2, 4, 8]: + configs.append( + triton.Config( + {"BLOCK_N": BLOCK_N}, + num_warps=num_warps, + ) + ) + return configs + + +def _bwd_pre_hook(nargs): + nargs["DW"].zero_() + if "DB" in nargs: + nargs["DB"].zero_() + + +def _get_norm_bwd_configs() -> List[triton.Config]: + """Generate autotune configs for multi-row LayerNorm kernels.""" + configs = [] + if is_sm100_plus(): + block_ns = [8, 16] + num_shards_list = [8, 16] + num_warps_list = [2, 4] + elif is_sm90(): + block_ns = [2, 4] + num_shards_list = [8] + num_warps_list = [2, 4] + else: + block_ns = [1, 2] + num_shards_list = [8] + num_warps_list = [2, 4] + for BLOCK_N in block_ns: + for num_warps in num_warps_list: + for num_shards in num_shards_list: + configs.append( + triton.Config( + {"BLOCK_N": BLOCK_N, "SHARDS_PER_SM": num_shards}, + num_warps=num_warps, + ) + ) + return configs + + +@triton_autotune( + configs=_get_layer_norm_fwd_configs(), + key=["BLOCK_D"], +) +@triton.jit +def _layer_norm_fwd( + X, + Y, + Mean, + Rstd, + N, + D, + eps, + stride_x, + stride_y, + TRAINING: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, + COMPUTE_MEAN_AND_RSTD: tl.constexpr, +): + block_id = tl.program_id(0) + start_row = block_id * BLOCK_N + + X_block_ptr = tl.make_block_ptr( + base=X, + shape=(N, D), + strides=(stride_x, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + Y_block_ptr = tl.make_block_ptr( + base=Y, + shape=(N, D), + strides=(stride_y, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + x_block = tl.load(X_block_ptr, boundary_check=(0, 1), padding_option="zero").to( + tl.float32 + ) + + cols = tl.arange(0, BLOCK_D) + col_mask = cols < D + rows = start_row + tl.arange(0, BLOCK_N) + row_mask = rows < N + + if COMPUTE_MEAN_AND_RSTD: + mean = tl.sum(x_block, axis=1) / D + if TRAINING: + tl.store(Mean + rows, mean, row_mask) + mean = tl.expand_dims(mean, 1) + else: + mean = tl.load(Mean + rows, row_mask, other=0.0) + mean = tl.expand_dims(mean, 1) + + x_mean = x_block - mean + x_mean = tl.where(row_mask[:, None] & col_mask[None, :], x_mean, 0.0) + + if COMPUTE_MEAN_AND_RSTD: + _var = x_mean * x_mean + var = tl.sum(_var, axis=1) / D + rstd = 1 / tl.sqrt(var + eps) + if TRAINING: + tl.store(Rstd + rows, rstd, row_mask) + else: + rstd = tl.load(Rstd + rows, row_mask, other=0.0) + + rstd = tl.expand_dims(rstd, 1) + y = x_mean * rstd + + tl.store(Y_block_ptr, y.to(Y.dtype.element_ty), boundary_check=(0, 1)) + + +@triton_autotune( + configs=_get_layer_norm_fwd_configs(), + key=["BLOCK_D"], +) +@triton.jit +def _weighted_layer_norm_fwd( + X, + Y, + W, + B, + Mean, + Rstd, + N, + D, + eps, + stride_x, + stride_y, + IS_SWISH: tl.constexpr, + TRAINING: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, + COMPUTE_MEAN_AND_RSTD: tl.constexpr, +): + # Get the block ID and calculate starting row + block_id = tl.program_id(0) + start_row = block_id * BLOCK_N + + # Load weight and bias once (shared across all rows in this block) + cols = tl.arange(0, BLOCK_D) + col_mask = cols < D + w = tl.load(W + cols, mask=col_mask, other=0.0).to(tl.float32) + b = tl.load(B + cols, mask=col_mask, other=0.0).to(tl.float32) + + # Create block pointers for X and Y + X_block_ptr = tl.make_block_ptr( + base=X, + shape=(N, D), + strides=(stride_x, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + Y_block_ptr = tl.make_block_ptr( + base=Y, + shape=(N, D), + strides=(stride_y, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + x_block = tl.load(X_block_ptr, boundary_check=(0, 1), padding_option="zero").to( + tl.float32 + ) + + rows = start_row + tl.arange(0, BLOCK_N) + row_mask = rows < N + + if COMPUTE_MEAN_AND_RSTD: + mean = tl.sum(x_block, axis=1) / D + if TRAINING: + tl.store(Mean + rows, mean, row_mask) + mean = tl.expand_dims(mean, 1) + else: + mean = tl.load(Mean + rows, row_mask, other=0.0) + mean = tl.expand_dims(mean, 1) + + x_mean = x_block - mean + x_mean = tl.where(row_mask[:, None] & col_mask[None, :], x_mean, 0.0) + + if COMPUTE_MEAN_AND_RSTD: + _var = x_mean * x_mean + var = tl.sum(_var, axis=1) / D + rstd = libdevice_rsqrt(var + eps) + if TRAINING: + tl.store(Rstd + rows, rstd, row_mask) + else: + rstd = tl.load(Rstd + rows, row_mask, other=0.0) + + rstd = tl.expand_dims(rstd, 1) + y = x_mean * rstd + y = y * w[None, :] + b[None, :] + + if IS_SWISH: + y = tl.sigmoid(y) * x_block + + tl.store(Y_block_ptr, y.to(Y.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def _layer_norm_bwd_dx( + DX, + DY, + X, + Mean, + Rstd, + stride_dx, + stride_dy, + stride_x, + D, + eps, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_D) + mask = cols < D + X += row.to(tl.int64) * stride_x + DY += row.to(tl.int64) * stride_dy + DX += row.to(tl.int64) * stride_dx + + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + + # Compute dx + xhat = (x - mean) * rstd + xhat = tl.where(mask, xhat, 0.0) + dy = tl.where(mask, dy, 0.0) + c1 = tl.sum(xhat * dy, axis=0) / D + c2 = tl.sum(dy, axis=0) / D + dx = (dy - (xhat * c1 + c2)) * rstd + # Write dx + tl.store(DX + cols, dx, mask=mask) + + +@triton_autotune( + configs=pinned_or_full( + [ + triton.Config({"BLOCK_N": 1}, num_warps=1), # bs=32 winner + triton.Config({"BLOCK_N": 8}, num_warps=1), # bs=1024 winner + ], + _get_layer_norm_fwd_configs, + ), + key=["BLOCK_D"], +) +@triton.jit +def _weighted_layer_norm_bwd_dx( + DX, + DY, + DW, + DB, + X, + W, + B, + Mean, + Rstd, + stride_dx, + stride_dy, + stride_x, + D, + eps, + IS_SWISH: tl.constexpr, + N, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid = tl.program_id(0) + tile_num = tl.num_programs(0) + num_blocks = tl.cdiv(N, BLOCK_N) + blocks_per_tile = num_blocks // tile_num + if pid < num_blocks % tile_num: + blocks_per_tile += 1 + + cols = tl.arange(0, BLOCK_D) + col_mask = cols < D + w = tl.load(W + cols, mask=col_mask, other=0.0).to(tl.float32) + + acc_dw = tl.zeros([BLOCK_D], dtype=tl.float32) + acc_db = tl.zeros([BLOCK_D], dtype=tl.float32) + + start_block = pid + + for idx in range(blocks_per_tile): + current_block = start_block + idx * tile_num + start_row = current_block * BLOCK_N + + X_block_ptr = tl.make_block_ptr( + base=X, + shape=(N, D), + strides=(stride_x, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + DX_block_ptr = tl.make_block_ptr( + base=DX, + shape=(N, D), + strides=(stride_dx, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + DY_block_ptr = tl.make_block_ptr( + base=DY, + shape=(N, D), + strides=(stride_dy, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + # Load data blocks + x_block = tl.load(X_block_ptr, boundary_check=(0, 1), padding_option="zero").to( + tl.float32 + ) + dy_block = tl.load( + DY_block_ptr, boundary_check=(0, 1), padding_option="zero" + ).to(tl.float32) + + # Load mean and rstd for all rows in this block + rows = start_row + tl.arange(0, BLOCK_N) + row_mask = rows < N + mean = tl.load(Mean + rows, row_mask, other=0.0) + rstd = tl.load(Rstd + rows, row_mask, other=0.0) + + # Expand dimensions for broadcasting + mean = tl.expand_dims(mean, 1) + rstd = tl.expand_dims(rstd, 1) + + xhat = (x_block - mean) * rstd + + xhat = tl.where(row_mask[:, None] & col_mask[None, :], xhat, 0.0) + wdy = w[None, :] * dy_block + wdy = tl.where(row_mask[:, None] & col_mask[None, :], wdy, 0.0) + + # Compute dx + if IS_SWISH: + b = tl.load(B + cols, mask=col_mask, other=0.0).to(tl.float32) + sigmoid_layer_norm = tl.sigmoid(xhat * w[None, :] + b[None, :]) + sigmoid_layer_norm = tl.where( + row_mask[:, None] & col_mask[None, :], sigmoid_layer_norm, 0.0 + ) + + sigmoid_deriv = sigmoid_layer_norm * (1 - sigmoid_layer_norm) + x_ = wdy * x_block * sigmoid_deriv + x_ = tl.where(row_mask[:, None] & col_mask[None, :], x_, 0.0) + + c1 = tl.sum(xhat * x_, axis=1) / D + c2 = tl.sum(x_, axis=1) / D + c1 = tl.expand_dims(c1, 1) + c2 = tl.expand_dims(c2, 1) + dx = (x_ - (xhat * c1 + c2)) * rstd + + dx = dy_block * sigmoid_layer_norm + dx + # Write dx + tl.store(DX_block_ptr, dx.to(DX.dtype.element_ty), boundary_check=(0, 1)) + partial_dw = tl.sum(dy_block * x_block * xhat * sigmoid_deriv, axis=0) + partial_db = tl.sum(dy_block * x_block * sigmoid_deriv, axis=0) + else: + c1 = tl.sum(xhat * wdy, axis=1) / D + c2 = tl.sum(wdy, axis=1) / D + c1 = tl.expand_dims(c1, 1) + c2 = tl.expand_dims(c2, 1) + dx = (wdy - (xhat * c1 + c2)) * rstd + # Write dx + tl.store(DX_block_ptr, dx.to(DX.dtype.element_ty), boundary_check=(0, 1)) + partial_dw = tl.sum(dy_block * xhat, axis=0) + partial_db = tl.sum(dy_block, axis=0) + + # Accumulate partial sums in shared memory + acc_dw += partial_dw + acc_db += partial_db + + # Store accumulated sums back to global memory + dw_ptrs = DW + pid.to(tl.int64) * D + cols + db_ptrs = DB + pid.to(tl.int64) * D + cols + tl.store(dw_ptrs, acc_dw, mask=col_mask) + tl.store(db_ptrs, acc_db, mask=col_mask) + + +def _get_bwd_dwdb_configs() -> List[triton.Config]: + configs = [] + BLOCK_N_CHOICES = [32, 64, 128, 256] + if is_sm100_plus(): + BLOCK_N_CHOICES = [128, 256, 512, 1024] + for BLOCK_N in BLOCK_N_CHOICES: + for num_warps in [8, 16] + ([] if torch.ops.hip else [32]): + configs.append( + triton.Config( + {"BLOCK_N": BLOCK_N}, + num_warps=num_warps, + ) + ) + return configs + + +@triton_autotune( + configs=pinned_or_full( + [triton.Config({"BLOCK_N": 128}, num_warps=8)], + _get_bwd_dwdb_configs, + ), + key=["D"], +) +@triton.jit +def _layer_norm_bwd_dwdb( + DW, + DB, + FINAL_DW, + FINAL_DB, + N, + D, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid = tl.program_id(0) + cols = pid * BLOCK_D + tl.arange(0, BLOCK_D) + dw = tl.zeros((BLOCK_N, BLOCK_D), dtype=tl.float32) + db = tl.zeros((BLOCK_N, BLOCK_D), dtype=tl.float32) + + for i in range(0, N, BLOCK_N): + rows = i + tl.arange(0, BLOCK_N) + # pyre-fixme[16]: `int` has no attribute `__getitem__`. + mask = (rows[:, None] < N) & (cols[None, :] < D) + offs = rows[:, None] * D + cols[None, :] + dw += tl.load(DW + offs, mask=mask, other=0.0) + db += tl.load(DB + offs, mask=mask, other=0.0) + + sum_dw = tl.sum(dw, axis=0) + sum_db = tl.sum(db, axis=0) + tl.store(FINAL_DW + cols, sum_dw.to(FINAL_DW.dtype.element_ty), mask=cols < D) + tl.store(FINAL_DB + cols, sum_db.to(FINAL_DB.dtype.element_ty), mask=cols < D) + + +def compute_BLOCK_D(x: torch.Tensor) -> int: + """Compute the BLOCK_D parameter for layer norm kernels.""" + D = x.shape[-1] + MAX_FUSED_SIZE = 65536 // x.element_size() + return min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + + +@maybe_register_custom_op( + "generative_recommenders::triton_weighted_layer_norm_fwd", mutates_args=() +) +def triton_weighted_layer_norm_fwd( + x: torch.Tensor, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + eps: float, + mean: Optional[torch.Tensor] = None, + rstd: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert x.dim() == 2, f"x.dim() == {x.dim()}, expected 2" + x = switch_to_contiguous_if_needed(x) + N, D = x.shape + learnable = weight is not None + if learnable: + assert bias is not None and weight is not None + assert weight.dim() == 1 + assert bias.dim() == 1 + assert weight.numel() == D + assert bias.numel() == D + + y = torch.empty_like(x) + compute_mean_and_rstd = mean is None or rstd is None + # Always allocate new tensors to avoid aliasing inputs with outputs + out_mean = torch.empty((N,), dtype=torch.float32, device=x.device) + out_rstd = torch.empty((N,), dtype=torch.float32, device=x.device) + if not compute_mean_and_rstd: + assert mean is not None and rstd is not None + out_mean.copy_(mean) + out_rstd.copy_(rstd) + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_D: int = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BLOCK_D: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + + if N == 0: + return y, out_mean, out_rstd + + # pyre-ignore[28] + grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]),) # noqa E731 + if learnable: + _weighted_layer_norm_fwd[grid]( + x, + y, + weight, + bias, + out_mean, + out_rstd, + N, + D, + eps, + x.stride(0), + y.stride(0), + IS_SWISH=False, + TRAINING=True, + BLOCK_D=BLOCK_D, + COMPUTE_MEAN_AND_RSTD=compute_mean_and_rstd, + ) + else: + _layer_norm_fwd[grid]( + x, + y, + out_mean, + out_rstd, + N, + D, + eps, + x.stride(0), + y.stride(0), + TRAINING=True, + BLOCK_D=BLOCK_D, + COMPUTE_MEAN_AND_RSTD=compute_mean_and_rstd, + ) + + return y, out_mean, out_rstd + + +@triton_weighted_layer_norm_fwd.register_fake +def _( + x: torch.Tensor, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + eps: float, + mean: Optional[torch.Tensor] = None, + rstd: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + N = x.shape[0] + y = torch.empty_like(x) + # Always allocate new tensors to avoid aliasing inputs with outputs + out_mean = torch.empty((N,), dtype=torch.float32, device=x.device) + out_rstd = torch.empty((N,), dtype=torch.float32, device=x.device) + return y, out_mean, out_rstd + + +@maybe_register_custom_op( + "generative_recommenders::triton_weighted_layer_norm_bwd", mutates_args=() +) +def _triton_weighted_layer_norm_bwd_impl( + dy: torch.Tensor, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + mean: torch.Tensor, + rstd: torch.Tensor, + learnable: bool, + eps: float, + BLOCK_D: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + num_warps: int = min(max(BLOCK_D // 256, 1), 8) + if learnable: + N, D = x.shape + dx = torch.empty_like(x) + sms = torch.cuda.get_device_properties(x.device).multi_processor_count + tile_num = max(1, min(sms * 8, N // 4)) + _dweight = torch.empty((tile_num, D), dtype=torch.float32, device=x.device) + _dbias = torch.empty((tile_num, D), dtype=torch.float32, device=x.device) + dweight = torch.empty((D,), dtype=weight.dtype, device=x.device) + dbias = torch.empty((D,), dtype=weight.dtype, device=x.device) + if N == 0: + dweight.zero_() + dbias.zero_() + return dx, dweight, dbias + # pyre-ignore[28] + _weighted_layer_norm_bwd_dx[(tile_num,)]( + dx, + dy, + _dweight, + _dbias, + x, + weight, + bias, + mean, + rstd, + dx.stride(0), + dy.stride(0), + x.stride(0), + D, + eps, + IS_SWISH=False, + N=N, + BLOCK_D=BLOCK_D, + ) + + def grid(META): + return (triton.cdiv(D, META["BLOCK_D"]),) + + blocks = triton.next_power_of_2(sms * 4) + BLOCK_D = triton.next_power_of_2(triton.cdiv(D, blocks)) + BLOCK_D = min(max(BLOCK_D, 4), 128) + _layer_norm_bwd_dwdb[grid]( + _dweight, + _dbias, + dweight, + dbias, + tile_num, + D, + BLOCK_D=BLOCK_D, + ) + + return dx, dweight, dbias + else: + N, D = x.shape + dx = torch.empty_like(x) + # Return empty tensors as sentinels for None + dweight = torch.empty(0, dtype=x.dtype, device=x.device) + dbias = torch.empty(0, dtype=x.dtype, device=x.device) + if N == 0: + return dx, dweight, dbias + # pyre-ignore[28] + _layer_norm_bwd_dx[(N,)]( + dx, + dy, + x, + mean, + rstd, + dx.stride(0), + dy.stride(0), + x.stride(0), + D, + eps, + BLOCK_D=BLOCK_D, + num_warps=num_warps, + ) + return dx, dweight, dbias + + +@_triton_weighted_layer_norm_bwd_impl.register_fake +def _( + dy: torch.Tensor, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + mean: torch.Tensor, + rstd: torch.Tensor, + learnable: bool, + eps: float, + BLOCK_D: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + dx = torch.empty_like(x) + if learnable: + D = x.shape[-1] + dweight = torch.empty((D,), dtype=weight.dtype, device=x.device) + dbias = torch.empty((D,), dtype=weight.dtype, device=x.device) + else: + dweight = torch.empty(0, dtype=x.dtype, device=x.device) + dbias = torch.empty(0, dtype=x.dtype, device=x.device) + return dx, dweight, dbias + + +def triton_weighted_layer_norm_bwd( + dy: torch.Tensor, + x: torch.Tensor, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + mean: torch.Tensor, + rstd: torch.Tensor, + learnable: bool, + eps: float, + BLOCK_D: int, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + # Use sentinel tensors for custom_op compatibility (can't return Optional[Tensor]) + _weight = ( + weight if weight is not None else torch.empty(0, dtype=x.dtype, device=x.device) + ) + _bias = bias if bias is not None else torch.empty(0, dtype=x.dtype, device=x.device) + dx, dweight, dbias = _triton_weighted_layer_norm_bwd_impl( + dy=dy, + x=x, + weight=_weight, + bias=_bias, + mean=mean, + rstd=rstd, + learnable=learnable, + eps=eps, + BLOCK_D=BLOCK_D, + ) + if not learnable: + return dx, None, None + return dx, dweight, dbias + + +class LayerNormFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + x: torch.Tensor, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + eps: float, + ) -> torch.Tensor: + y, mean, rstd = triton_weighted_layer_norm_fwd( + x=x, + weight=weight, + bias=bias, + eps=eps, + ) + BLOCK_D = compute_BLOCK_D(x) + learnable = weight is not None + if learnable: + ctx.save_for_backward(x, weight, bias, mean, rstd) + else: + ctx.save_for_backward(x, mean, rstd) + ctx.BLOCK_D = BLOCK_D + ctx.eps = eps + ctx.learnable = learnable + return y + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, dy: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], None]: + if ctx.learnable: + x, weight, bias, mean, rstd = ctx.saved_tensors + else: + x, mean, rstd = ctx.saved_tensors + weight, bias = None, None + dx, dweight, dbias = triton_weighted_layer_norm_bwd( + dy=dy, + x=x, + weight=weight, + bias=bias, + mean=mean, + rstd=rstd, + learnable=ctx.learnable, + eps=ctx.eps, + BLOCK_D=ctx.BLOCK_D, + ) + return dx, dweight, dbias, None + + +def _get_rms_norm_fwd_configs() -> List[triton.Config]: + """Generate autotune configs for multi-row RMSNorm kernels.""" + configs = [] + for BLOCK_N in [1, 4, 16]: + for num_warps in [2, 4]: + configs.append( + triton.Config( + {"BLOCK_N": BLOCK_N}, + num_warps=num_warps, + ) + ) + return configs + + +@triton.autotune( + configs=_get_rms_norm_fwd_configs(), + key=["BLOCK_D", "SILU"], +) +@triton.jit +def _weighted_rms_norm_fwd( + X, + Y, + W, + Rstd, + N, + D: tl.constexpr, + eps, + stride_x, + stride_y, + SILU: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, +): + block_id = tl.program_id(0) + start_row = block_id * BLOCK_N + + # Load weight once (shared across all rows in this block) + cols = tl.arange(0, BLOCK_D) + col_mask = cols < D + w = tl.load(W + cols, mask=col_mask, other=0.0).to(tl.float32) + + # Create block pointers for X and Y + X_block_ptr = tl.make_block_ptr( + base=X, + shape=(N, D), + strides=(stride_x, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + Y_block_ptr = tl.make_block_ptr( + base=Y, + shape=(N, D), + strides=(stride_y, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + x_block = tl.load(X_block_ptr, boundary_check=(0, 1), padding_option="zero").to( + tl.float32 + ) + + rows = start_row + tl.arange(0, BLOCK_N) + row_mask = rows < N + + # Compute variance (RMS norm uses x directly, not x - mean) + x_masked = tl.where(row_mask[:, None] & col_mask[None, :], x_block, 0.0) + _var = x_masked * x_masked + var = tl.sum(_var, axis=1) / D + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + rows, rstd, row_mask) + + # Normalize and apply linear transformation + rstd = tl.expand_dims(rstd, 1) + y = x_block * rstd + y = y * w[None, :] + + if SILU: + # pyre-ignore[16]: Module `triton.language.math` has no attribute `fast_dividef` + y = fast_dividef(y, 1.0 + tl.exp(-y)) + + tl.store(Y_block_ptr, y.to(Y.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def _weighted_rms_norm_bwd_dx( + DX, + DY, + DW, + X, + W, + Rstd, + Lock, + stride_dx, + stride_dy, + stride_x, + D: tl.constexpr, + eps, + GROUP_N, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_D) + mask = cols < D + X += row.to(tl.int64) * stride_x + DY += row.to(tl.int64) * stride_dy + DX += row.to(tl.int64) * stride_dx + + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + rstd = tl.load(Rstd + row) + + # Compute dx + xhat = x * rstd + w = tl.load(W + cols, mask=mask).to(tl.float32) + wdy = w * dy + + xhat = tl.where(mask, xhat, 0.0) + wdy = tl.where(mask, wdy, 0.0) + c1 = tl.sum(xhat * wdy, axis=0) / D + dx = (wdy - (xhat * c1)) * rstd + # Write dx + tl.store(DX + cols, dx, mask=mask) + + # Offset locks and weights/biases gradient pointer for parallel reduction + lock_id = row % GROUP_N + Lock += lock_id + Count = Lock + GROUP_N + DW = DW + lock_id * D + cols + # Accumulate partial sums for dw/db + partial_dw = dy * xhat + while tl.atomic_cas(Lock, 0, 1) == 1: + pass + count = tl.load(Count) + # First store doesn't accumulate + if count == 0: + tl.atomic_xchg(Count, 1) + else: + partial_dw += tl.load(DW, mask=mask) + tl.store(DW, partial_dw, mask=mask) + # Release the lock + tl.atomic_xchg(Lock, 0) + + +@triton_autotune( + configs=_get_norm_bwd_configs(), + key=["BLOCK_D", "SILU"], + reset_to_zero=["DW"], +) +@triton.jit +def _weighted_rms_norm_bwd( + DX, + DY, + DW, + X, + W, + Rstd, + stride_dx, + stride_dy, + stride_x, + D, + eps, + N, + SILU: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, + SHARDS_PER_SM: tl.constexpr, +): + pid = tl.program_id(0) + tile_num = tl.num_programs(0) + num_blocks = tl.cdiv(N, BLOCK_N) + blocks_per_tile = num_blocks // tile_num + if pid < num_blocks % tile_num: + blocks_per_tile += 1 + + cols = tl.arange(0, BLOCK_D) + col_mask = cols < D + w = tl.load(W + cols, mask=col_mask, other=0.0).to(tl.float32) + + start_block = pid + + acc_dw = tl.zeros([BLOCK_D], dtype=tl.float32) + + for idx in range(blocks_per_tile): + current_block = start_block + idx * tile_num + start_row = current_block * BLOCK_N + + X_block_ptr = tl.make_block_ptr( + base=X, + shape=(N, D), + strides=(stride_x, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + DX_block_ptr = tl.make_block_ptr( + base=DX, + shape=(N, D), + strides=(stride_dx, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + DY_block_ptr = tl.make_block_ptr( + base=DY, + shape=(N, D), + strides=(stride_dy, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + # Load data blocks + x_block = tl.load(X_block_ptr, boundary_check=(0, 1), padding_option="zero").to( + tl.float32 + ) + dy_block = tl.load( + DY_block_ptr, boundary_check=(0, 1), padding_option="zero" + ).to(tl.float32) + + # Load rstd for all rows in this block + rows = start_row + tl.arange(0, BLOCK_N) + row_mask = rows < N + rstd = tl.load(Rstd + rows, row_mask, other=0.0) + + # Expand dimensions for broadcasting + rstd = tl.expand_dims(rstd, 1) + + # Compute dx + xhat = x_block * rstd + + # Apply SILU backward if enabled + if SILU: + y_before_silu = xhat * w[None, :] + # pyre-fixme[16] + sig_y = fast_dividef(1.0, 1.0 + tl.exp(-y_before_silu)) + # SILU derivative: sigmoid(y) + y * sigmoid(y) * (1 - sigmoid(y)) + dy_block = dy_block * (sig_y + y_before_silu * sig_y * (1.0 - sig_y)) + + wdy = w[None, :] * dy_block + + c1 = tl.sum(xhat * wdy, axis=1) / D + c1 = tl.expand_dims(c1, 1) + dx = (wdy - (xhat * c1)) * rstd + + # Write dx + tl.store(DX_block_ptr, dx.to(DX.dtype.element_ty), boundary_check=(0, 1)) + + # Accumulate partial sums for dw + # Compute dw for all rows, then sum locally before atomic operation + partial_dw_block = dy_block * xhat + # Local reduction: sum across all rows in this block + partial_dw = tl.sum(partial_dw_block, axis=0) + acc_dw += partial_dw + + DW_ptr = DW + cols + tl.atomic_add(DW_ptr, acc_dw, col_mask) + + +@triton_autotune( + configs=_get_bwd_dwdb_configs(), + key=["D"], +) +@triton.jit +def _rms_norm_bwd_dwdb( + DW, + FINAL_DW, + N, + D, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid = tl.program_id(0) + cols = pid * BLOCK_D + tl.arange(0, BLOCK_D) + dw = tl.zeros((BLOCK_N, BLOCK_D), dtype=tl.float32) + + for i in range(0, N, BLOCK_N): + rows = i + tl.arange(0, BLOCK_N) + # pyre-fixme[16]: `int` has no attribute `__getitem__`. + mask = (rows[:, None] < N) & (cols[None, :] < D) + offs = rows[:, None] * D + cols[None, :] + dw += tl.load(DW + offs, mask=mask, other=0.0) + + sum_dw = tl.sum(dw, axis=0) + tl.store(FINAL_DW + cols, sum_dw.to(FINAL_DW.dtype.element_ty), mask=cols < D) + + +class RMSNormFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + x: torch.Tensor, + weight: torch.Tensor, + eps: float, + silu: bool, + ) -> torch.Tensor: + assert x.dim() == 2 + x = switch_to_contiguous_if_needed(x) + N, D = x.shape + assert weight.dim() == 1 + assert weight.numel() == D + + y = torch.empty_like(x) + rstd = torch.empty((N,), dtype=torch.float32, device=x.device) + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_D = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BLOCK_D: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + + ctx.save_for_backward(x, weight, rstd) + ctx.silu = silu + if N == 0: + return y + + # pyre-ignore[28] + grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]),) # noqa E731 + _weighted_rms_norm_fwd[grid]( + x, + y, + weight, + rstd, + N, + D, + eps, + x.stride(0), + y.stride(0), + SILU=silu, + BLOCK_D=BLOCK_D, + ) + + ctx.BLOCK_D = BLOCK_D + ctx.eps = eps + return y + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, dy: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], None, None]: + x, weight, rstd = ctx.saved_tensors + N, D = x.shape + dx = torch.empty_like(x) + dweight = torch.zeros((D,), dtype=weight.dtype, device=x.device) + if N == 0: + dweight.zero_() + return dx, dweight, None, None + + sms = torch.cuda.get_device_properties(x.device).multi_processor_count + + # pyre-ignore[28] + grid = lambda meta: ( # noqa E731 + max(1, min(sms * meta["SHARDS_PER_SM"], N // 4)), + ) + _weighted_rms_norm_bwd[grid]( + dx, + dy, + dweight, + x, + weight, + rstd, + dx.stride(0), + dy.stride(0), + x.stride(0), + D, + ctx.eps, + N=N, + SILU=ctx.silu, + BLOCK_D=ctx.BLOCK_D, + ) + + return dx, dweight, None, None + + +class SwishLayerNormFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + ) -> torch.Tensor: + assert x.dim() == 2, f"x.dim() == {x.dim()}, expected 2" + x = switch_to_contiguous_if_needed(x) + N, D = x.shape + + assert bias is not None and weight is not None + assert weight.dim() == 1 + assert bias.dim() == 1 + assert weight.numel() == D + assert bias.numel() == D + + y = torch.empty_like(x) + mean = torch.empty((N,), dtype=torch.float32, device=x.device) + rstd = torch.empty((N,), dtype=torch.float32, device=x.device) + + BLOCK_D = triton.next_power_of_2(D) + num_warps = min(max(BLOCK_D // 256, 1), 8) + + ctx.save_for_backward(x, weight, bias, mean, rstd) + ctx.BLOCK_D = BLOCK_D + ctx.num_warps = num_warps + ctx.eps = eps + if N == 0: + return y + + # pyre-ignore[28] + grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]),) # noqa E731 + _weighted_layer_norm_fwd[grid]( + x, + y, + weight, + bias, + mean, + rstd, + N, + D, + eps, + x.stride(0), + y.stride(0), + IS_SWISH=True, + TRAINING=True, + BLOCK_D=BLOCK_D, + COMPUTE_MEAN_AND_RSTD=True, + ) + + return y + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, dy: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], None]: + x, weight, bias, mean, rstd = ctx.saved_tensors + N, D = x.shape + dx = torch.empty_like(x) + sms = torch.cuda.get_device_properties(x.device).multi_processor_count + tile_num = max(1, min(sms * 8, N // 4)) + _dweight = torch.empty((tile_num, D), dtype=torch.float32, device=x.device) + _dbias = torch.empty((tile_num, D), dtype=torch.float32, device=x.device) + dweight = torch.empty((D,), dtype=weight.dtype, device=x.device) + dbias = torch.empty((D,), dtype=weight.dtype, device=x.device) + if N == 0: + dweight.zero_() + dbias.zero_() + return dx, dweight, dbias, None + # pyre-ignore[28] + _weighted_layer_norm_bwd_dx[(tile_num,)]( + dx, + dy, + _dweight, + _dbias, + x, + weight, + bias, + mean, + rstd, + dx.stride(0), + dy.stride(0), + x.stride(0), + D, + ctx.eps, + IS_SWISH=True, + N=N, + BLOCK_D=ctx.BLOCK_D, + ) + + def grid(META): + return (triton.cdiv(D, META["BLOCK_D"]),) + + blocks = triton.next_power_of_2(sms * 4) + BLOCK_D = triton.next_power_of_2(triton.cdiv(D, blocks)) + BLOCK_D = min(max(BLOCK_D, 4), 128) + _layer_norm_bwd_dwdb[grid]( + _dweight, + _dbias, + dweight, + dbias, + tile_num, + D, + BLOCK_D=BLOCK_D, + ) + + return dx, dweight, dbias, None + + +@torch.jit.unused +@torch.fx.wrap +def triton_layer_norm( + x: torch.Tensor, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + eps: float, +) -> torch.Tensor: + return LayerNormFunction.apply(x, weight, bias, eps) + + +@torch.jit.unused +@torch.fx.wrap +def triton_rms_norm( + x: torch.Tensor, + weight: Optional[torch.Tensor], + eps: float, + silu: bool = False, +) -> torch.Tensor: + return RMSNormFunction.apply(x, weight, eps, silu) + + +@torch.jit.unused +@torch.fx.wrap +def triton_swish_layer_norm( + x: torch.Tensor, + normalized_shape: List[int], + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + eps: float, +) -> torch.Tensor: + return SwishLayerNormFunction.apply(x, weight, bias, eps) diff --git a/recommendation_v4/generative_recommenders/ops/triton/triton_position.py b/recommendation_v4/generative_recommenders/ops/triton/triton_position.py new file mode 100644 index 000000000..72c43f9ac --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/triton/triton_position.py @@ -0,0 +1,438 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-unsafe + +#!/usr/bin/env python3 + + +from typing import List, Optional, Tuple + +import torch + +# @manual=//triton:triton +import triton + +# @manual=//triton:triton +import triton.language as tl + +try: + torch.ops.load_library("//hammer/ops/cuda:cuda_ops") +except OSError: + pass + +from generative_recommenders.common import ( + autotune_max_seq_len, + prev_power_of_2, + switch_to_contiguous_if_needed, + triton_autotune, +) + + +def _autotune_configs() -> List[triton.Config]: + configs = [] + for BLOCK_N in [16, 32, 64]: + for num_stages in [1, 2]: + for num_warps in [2, 4, 8]: + configs.append( + triton.Config( + { + "BLOCK_N": BLOCK_N, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + return configs + + +@triton_autotune( + configs=_autotune_configs(), + key=["AUTOTUNE_MAX_SEQ_LEN"], +) +@triton.jit +def _add_timestamp_position_embeddings_kernel( + SeqEmb, + Offsets, + Lengths, + PosEmb, + TsEmb, + Out, + TS, + PosInds, + TsInds, + NumTargets, + AUTOTUNE_MAX_SEQ_LEN, + D, + num_time_buckets, + time_bucket_increments, + time_bucket_scale, + time_delta, + max_contextual_seq_len, + max_pos_ind, + stride_sn, + stride_pn, + stride_tn, + stride_on, + TRAINING: tl.constexpr, + HAS_MULTIPLE_TARGETS: tl.constexpr, + INTERLEAVE_TARGETS: tl.constexpr, + TIME_BUCKET_FN: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + SeqEmb has shape (sum_B(N_i), D), + PosEmb has shape (N_p, D), + TsEmb has shape (N_t, D), + Out has shape (sum_B(N_i), D) + """ + + off_b = tl.program_id(0) + off_n = tl.program_id(1) + seq_start = tl.load(Offsets + off_b) + seq_end = tl.load(Offsets + off_b + 1) + seq_len = seq_end - seq_start + start_n = off_n * BLOCK_N + if start_n >= seq_len: + return + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + seq_emb_offsets = offs_n[:, None] * stride_sn + offs_d[None, :] + SeqEmb += seq_start.to(tl.int64) * stride_sn + mask_n = offs_n < seq_len + # position encoding + seq_len = tl.load(Lengths + off_b) + if HAS_MULTIPLE_TARGETS: + num_targets = tl.load(NumTargets + off_b) + if INTERLEAVE_TARGETS: + high_ind = seq_len - num_targets * 2 + else: + high_ind = seq_len - num_targets + else: + high_ind = seq_len + pos_inds = tl.where(offs_n < high_ind, offs_n, high_ind) + pos_inds = high_ind - pos_inds + max_contextual_seq_len + pos_inds = tl.where(pos_inds < max_pos_ind - 1, pos_inds, max_pos_ind - 1) + pos_inds = tl.where(offs_n < max_contextual_seq_len, offs_n, pos_inds) + if TRAINING: + tl.store(PosInds + seq_start + offs_n, pos_inds, mask=mask_n) + pos_emb_offsets = pos_inds[:, None] * stride_pn + offs_d[None, :] + # timestamp encoding + ts = tl.load(TS + seq_start + offs_n, mask=mask_n) + query_time = tl.load(TS + seq_end - 1) + ts = query_time - ts + time_delta + ts = tl.where(ts > 1e-6, ts, 1e-6) / time_bucket_increments + if TIME_BUCKET_FN == "log": + ts = tl.log(ts) + else: + ts = tl.sqrt(ts) + ts = ts * time_bucket_scale + ts = ts.to(tl.int32) + ts = tl.where(ts > 0, ts, 0) + ts = tl.where(ts < num_time_buckets, ts, num_time_buckets) + if TRAINING: + tl.store(TsInds + seq_start + offs_n, ts, mask=mask_n) + ts_emb_offsets = ts[:, None] * stride_tn + offs_d[None, :] + Out += seq_start.to(tl.int64) * stride_on + out_offsets = Out + offs_n[:, None] * stride_on + offs_d[None, :] + for _d in range(0, D, BLOCK_D): + mask = (offs_n[:, None] < seq_len) and offs_d[None, :] < D + seq_emb = tl.load(SeqEmb + seq_emb_offsets, mask=mask) + pos_emb = tl.load(PosEmb + pos_emb_offsets, mask=mask) + ts_emb = tl.load(TsEmb + ts_emb_offsets, mask=mask) + tl.store(out_offsets, seq_emb + (pos_emb + ts_emb).to(seq_emb.dtype), mask=mask) + seq_emb_offsets += BLOCK_D + pos_emb_offsets += BLOCK_D + ts_emb_offsets += BLOCK_D + out_offsets += BLOCK_D + offs_d += BLOCK_D + + +def bwd_pre_hook(nargs): + nargs["Out"].zero_() + + +def _add_embeddings_bwd_configs() -> List[triton.Config]: + configs = [] + for BLOCK in [32, 64, 128]: + for num_stages in [2, 3, 4]: + for num_warps in [2, 4, 8]: + configs.append( + triton.Config( + { + "BLOCK": BLOCK, + }, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=bwd_pre_hook, + ) + ) + return configs + + +@triton_autotune( + configs=_add_embeddings_bwd_configs(), + key=["AUTOTUNE_MAX_SEQ_LEN", "AUTOTUNE_B", "D"], +) +@triton.jit +def _add_embeddings_bwd_kernel( + In, + KeyInds, + ValueInds, + Out, + AUTOTUNE_MAX_SEQ_LEN, + AUTOTUNE_B, + D, + jagged_size, + stride_in, + stride_on, + BLOCK_D: tl.constexpr, + BLOCK: tl.constexpr, +): + off_block = tl.program_id(0) + offs_d = tl.arange(0, BLOCK_D) + mask_d = offs_d < D + key_ind = -1 + key_ind = key_ind.to(KeyInds.dtype.element_ty) # pyre-ignore[16] + accumulator = tl.zeros((BLOCK_D,), dtype=In.dtype.element_ty) + for off_i in range(0, BLOCK): + off = off_block * BLOCK + off_i + if off < jagged_size: + value_ind = tl.load(ValueInds + off) + in_offset = In + value_ind.to(tl.int64) * stride_in + jagged_in = tl.load(in_offset + offs_d, mask=mask_d) + key_ind_new = tl.load(KeyInds + off) + if key_ind == key_ind_new: + accumulator += jagged_in + else: + if key_ind >= 0: + out_offset = Out + key_ind.to(tl.int64) * stride_on + tl.atomic_add( + out_offset + offs_d, + accumulator.to(Out.dtype.element_ty), + mask=mask_d, + sem="relaxed", + ) + key_ind = key_ind_new + accumulator = jagged_in + if key_ind >= 0: + out_offset = Out + key_ind.to(tl.int64) * stride_on + tl.atomic_add( + out_offset + offs_d, + accumulator.to(Out.dtype.element_ty), + mask=mask_d, + sem="relaxed", + ) + + +class _AddTimestampPositionEmbeddingsFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + seq_embeddings: torch.Tensor, + seq_offsets: torch.Tensor, + pos_embeddings: torch.Tensor, + ts_embeddings: torch.Tensor, + timestamps: torch.Tensor, + max_seq_len: int, + max_contextual_seq_len: int, + seq_lengths: torch.Tensor, + num_targets: Optional[torch.Tensor], + interleave_targets: bool, + time_bucket_fn: str, + ): + seq_embeddings = switch_to_contiguous_if_needed(seq_embeddings) + pos_embeddings = switch_to_contiguous_if_needed(pos_embeddings) + ts_embeddings = switch_to_contiguous_if_needed(ts_embeddings) + + max_pos_ind = pos_embeddings.shape[0] + B = seq_lengths.shape[0] + N, D = seq_embeddings.shape + assert len(pos_embeddings.shape) == 2 + assert len(ts_embeddings.shape) == 2 + assert pos_embeddings.shape[1] == D, ( + "shape[1] of pos_embeddings much match seq_embeddings" + ) + assert ts_embeddings.shape[1] == D, ( + "shape[1] of ts_embeddings much match seq_embeddings" + ) + out = torch.empty_like(seq_embeddings) + + timestamps = switch_to_contiguous_if_needed(timestamps) + ts_inds = torch.empty_like(seq_embeddings[:, 0], dtype=torch.int32) + pos_inds = torch.empty_like(seq_embeddings[:, 0], dtype=torch.int32) + ts_emb_size = ts_embeddings.shape[0] + + grid = lambda meta: ( # noqa E731 + B, + triton.cdiv(max_seq_len, meta["BLOCK_N"]), + ) + BLOCK_D = triton.next_power_of_2(D) if D < 64 else 64 + _add_timestamp_position_embeddings_kernel[grid]( + SeqEmb=seq_embeddings, + Offsets=seq_offsets, + Lengths=seq_lengths, + PosEmb=pos_embeddings, + TsEmb=ts_embeddings, + Out=out, + TS=timestamps, + PosInds=pos_inds, + TsInds=ts_inds, + NumTargets=num_targets, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(max_seq_len), + D=D, + num_time_buckets=ts_emb_size - 1, + time_bucket_increments=60.0, + time_bucket_scale=1.0, + time_delta=0, + max_contextual_seq_len=max_contextual_seq_len, + max_pos_ind=max_pos_ind, + stride_sn=seq_embeddings.stride(0), + stride_pn=pos_embeddings.stride(0), + stride_tn=ts_embeddings.stride(0), + stride_on=out.stride(0), + TRAINING=True, + HAS_MULTIPLE_TARGETS=num_targets is not None, + INTERLEAVE_TARGETS=interleave_targets, + TIME_BUCKET_FN=time_bucket_fn, + BLOCK_D=BLOCK_D, + ) + try: + values = torch.arange(0, N, dtype=torch.int32, device=timestamps.device) + sorted_ts_key_inds, sorted_ts_value_inds = torch.ops.hammer.sort_kv_pairs( + ts_inds, values + ) + sorted_pos_key_inds, sorted_pos_value_inds = torch.ops.hammer.sort_kv_pairs( + pos_inds, values + ) + except Exception: + sorted_ts_key_inds, sorted_ts_value_inds = torch.sort(ts_inds) + sorted_pos_key_inds, sorted_pos_value_inds = torch.sort(pos_inds) + ctx.save_for_backward( + sorted_pos_key_inds, + sorted_pos_value_inds, + sorted_ts_key_inds, + sorted_ts_value_inds, + ) + ctx.B = B + ctx.D = D + ctx.max_seq_len = max_seq_len + ctx.pos_emb_size = pos_embeddings.shape[0] + ctx.ts_emb_size = ts_emb_size + ctx.pos_dtype = pos_embeddings.dtype + ctx.ts_dtype = ts_embeddings.dtype + return out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, d_out: torch.Tensor + ) -> Tuple[ + torch.Tensor, + None, + torch.Tensor, + torch.Tensor, + None, + None, + None, + None, + None, + None, + None, + ]: + ( + sorted_pos_key_inds, + sorted_pos_value_inds, + sorted_ts_key_inds, + sorted_ts_value_inds, + ) = ctx.saved_tensors + d_pos_embeddings = torch.empty( + (ctx.pos_emb_size, ctx.D), device=d_out.device, dtype=torch.float32 + ) + d_ts_embeddings = torch.empty( + (ctx.ts_emb_size, ctx.D), device=d_out.device, dtype=torch.float32 + ) + grid = lambda meta: (triton.cdiv(d_out.shape[0], meta["BLOCK"]),) # noqa E731 + AUTOTUNE_B = prev_power_of_2(ctx.B) + _add_embeddings_bwd_kernel[grid]( + In=d_out, + KeyInds=sorted_pos_key_inds, + ValueInds=sorted_pos_value_inds, + Out=d_pos_embeddings, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(ctx.max_seq_len), + AUTOTUNE_B=AUTOTUNE_B, + D=ctx.D, + jagged_size=d_out.shape[0], + stride_in=d_out.stride(0), + stride_on=d_pos_embeddings.stride(0), + BLOCK_D=triton.next_power_of_2(ctx.D), + ) + _add_embeddings_bwd_kernel[grid]( + In=d_out, + KeyInds=sorted_ts_key_inds, + ValueInds=sorted_ts_value_inds, + Out=d_ts_embeddings, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(ctx.max_seq_len), + AUTOTUNE_B=AUTOTUNE_B, + D=ctx.D, + jagged_size=d_out.shape[0], + stride_in=d_out.stride(0), + stride_on=d_ts_embeddings.stride(0), + BLOCK_D=triton.next_power_of_2(ctx.D), + ) + return ( + d_out, + None, + d_pos_embeddings.to(ctx.pos_dtype), + d_ts_embeddings.to(ctx.ts_dtype), + None, + None, + None, + None, + None, + None, + None, + ) + + +@torch.jit.unused +@torch.fx.wrap +def triton_add_timestamp_positional_embeddings( + seq_embeddings: torch.Tensor, + seq_offsets: torch.Tensor, + pos_embeddings: torch.Tensor, + ts_embeddings: torch.Tensor, + timestamps: torch.Tensor, + max_seq_len: int, + max_contextual_seq_len: int, + seq_lengths: torch.Tensor, + num_targets: Optional[torch.Tensor], + interleave_targets: bool, + time_bucket_fn: str, +) -> torch.Tensor: + return _AddTimestampPositionEmbeddingsFunction.apply( + seq_embeddings, + seq_offsets, + pos_embeddings, + ts_embeddings, + timestamps, + max_seq_len, + max_contextual_seq_len, + seq_lengths, + num_targets, + interleave_targets, + time_bucket_fn, + ) diff --git a/recommendation_v4/generative_recommenders/ops/triton/triton_swiglu.py b/recommendation_v4/generative_recommenders/ops/triton/triton_swiglu.py new file mode 100644 index 000000000..5f30da53b --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/triton/triton_swiglu.py @@ -0,0 +1,753 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-unsafe + +from typing import List + +import torch + +# @manual=//triton:triton +import triton + +# @manual=//triton:triton +import triton.language as tl +from generative_recommenders.common import triton_autotune +from generative_recommenders.ops.utils import is_sm100_plus + +TMA_AVAILABLE = False +try: + # @manual=//triton:triton + from triton.tools.tensor_descriptor import TensorDescriptor + + TMA_AVAILABLE = True +except ImportError: + pass + +HAS_TLX = False +try: + # @manual=//triton:triton + import triton.language.extra.tlx as tlx # type: ignore + + HAS_TLX = True +except ImportError: + pass + + +def is_blackwell_triton_swiglu_supported() -> bool: + return is_sm100_plus() and TMA_AVAILABLE and HAS_TLX + + +def _swiglu_tma_set_block_size_hook(nargs) -> None: + BLOCK_M = nargs["BLOCK_M"] + BLOCK_N = nargs["BLOCK_N"] + BLOCK_K = nargs["BLOCK_K"] + EPILOGUE_SUBTILE = nargs.get("EPILOGUE_SUBTILE", 1) + + nargs["x_desc"].block_shape = [BLOCK_M, BLOCK_K] + nargs["w_gate_desc"].block_shape = [BLOCK_N, BLOCK_K] + nargs["w_up_desc"].block_shape = [BLOCK_N, BLOCK_K] + nargs["out_desc"].block_shape = [BLOCK_M, BLOCK_N // EPILOGUE_SUBTILE] + + +def get_swiglu_configs(pre_hook) -> List[triton.Config]: + return [ + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 128, + "BLOCK_K": 64, + "GROUP_M": 8, + }, + num_stages=1, + num_warps=4, + pre_hook=pre_hook, + ), + triton.Config( + { + "BLOCK_M": 64, + "BLOCK_N": 128, + "BLOCK_K": 64, + "GROUP_M": 8, + }, + num_stages=1, + num_warps=4, + pre_hook=pre_hook, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "BLOCK_K": 64, + "GROUP_M": 8, + }, + num_stages=1, + num_warps=4, + pre_hook=pre_hook, + ), + ] + + +@triton.jit +def _compute_pid_swiglu( + tile_id, + num_pid_in_group, + num_pid_m, + GROUP_M: tl.constexpr, + NUM_SMS: tl.constexpr, +): + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + return pid_m, pid_n + + +@triton_autotune( + configs=get_swiglu_configs(pre_hook=_swiglu_tma_set_block_size_hook), + key=["M_BLOCK", "N", "K"], +) +@triton.jit +def _swiglu_fwd_tma_ws_persistent( + x_desc, + w_gate_desc, + w_up_desc, + out_desc, + M, + N, + K, + M_BLOCK, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + NUM_SMEM_BUFFERS: tl.constexpr, + NUM_TMEM_BUFFERS: tl.constexpr, + NUM_SMS: tl.constexpr, + EPILOGUE_SUBTILE: tl.constexpr, +): + # Allocate SMEM buffers + x_buffers = tlx.local_alloc((BLOCK_M, BLOCK_K), x_desc.dtype, NUM_SMEM_BUFFERS) + + # Allocate SMEM buffers for W_gate and W_up + w_gate_buffers = tlx.local_alloc( + (BLOCK_N, BLOCK_K), w_gate_desc.dtype, NUM_SMEM_BUFFERS + ) + w_up_buffers = tlx.local_alloc( + (BLOCK_N, BLOCK_K), w_up_desc.dtype, NUM_SMEM_BUFFERS + ) + + # Allocate TMEM for accumulators + tmem_gate_buffers = tlx.local_alloc( + (BLOCK_M, BLOCK_N), tl.float32, NUM_TMEM_BUFFERS, tlx.storage_kind.tmem + ) + tmem_up_buffers = tlx.local_alloc( + (BLOCK_M, BLOCK_N), tl.float32, NUM_TMEM_BUFFERS, tlx.storage_kind.tmem + ) + + # Barriers for Producer <-> MMA synchronization + smem_full_bars_x_gate = tlx.alloc_barriers( + num_barriers=NUM_SMEM_BUFFERS, + arrive_count=1, # pyre-ignore[6] + ) + smem_full_bars_up = tlx.alloc_barriers( + num_barriers=NUM_SMEM_BUFFERS, + arrive_count=1, # pyre-ignore[6] + ) + # Empty barriers: arrive_count=2 because both GEMM1 and GEMM2 signal completion + smem_empty_bars = tlx.alloc_barriers( + num_barriers=NUM_SMEM_BUFFERS, + arrive_count=2, # pyre-ignore[6] + ) + + # Barriers for MMA <-> Epilogue synchronization + # pyre-ignore[6] + tmem_full_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS, arrive_count=1) + # pyre-ignore[6] + tmem_empty_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS, arrive_count=1) + + with tlx.async_tasks(): + # Epilogue Consumer: Reads from TMEM, applies SwiGLU, and stores to output + with tlx.async_task("default"): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_M * num_pid_n + num_tiles = num_pid_m * num_pid_n + + # Initialize buffer tracking + processed_tiles = 0 + + for tile_id in range(start_pid, num_tiles, NUM_SMS): + pid_m, pid_n = _compute_pid_swiglu( + tile_id, num_pid_in_group, num_pid_m, GROUP_M, NUM_SMS + ) + offs_m = pid_m * BLOCK_M + offs_n = pid_n * BLOCK_N + + cur_tmem_buf = processed_tiles % int(NUM_TMEM_BUFFERS) + tmem_read_phase = (processed_tiles // int(NUM_TMEM_BUFFERS)) & 1 + + # Wait for MMA to finish writing to TMEM + # pyre-ignore[16] + tlx.barrier_wait(tmem_full_bars[cur_tmem_buf], tmem_read_phase) + + # Load gate and up results from TMEM + # pyre-ignore[16] + gate_tmem = tmem_gate_buffers[cur_tmem_buf] + up_tmem = tmem_up_buffers[cur_tmem_buf] + + if EPILOGUE_SUBTILE > 1: + # Process tile in subtiles + slice_size: tl.constexpr = BLOCK_N // EPILOGUE_SUBTILE + for slice_id in tl.static_range(EPILOGUE_SUBTILE): + gate_subslice = tlx.local_slice( + gate_tmem, + [0, slice_id * slice_size], + # pyre-ignore[6] + [BLOCK_M, slice_size], + ) + up_subslice = tlx.local_slice( + up_tmem, + [0, slice_id * slice_size], + # pyre-ignore[6] + [BLOCK_M, slice_size], + ) + + gate = tlx.local_load(gate_subslice).to(out_desc.dtype) + up = tlx.local_load(up_subslice).to(out_desc.dtype) + + gate_fp32 = gate.to(tl.float32) + silu_gate = (gate_fp32 * tl.sigmoid(gate_fp32)).to( + out_desc.dtype + ) + result = silu_gate * up + + out_desc.store([offs_m, offs_n + slice_id * slice_size], result) + else: + # Process full tile + gate = tlx.local_load(gate_tmem).to(out_desc.dtype) + up = tlx.local_load(up_tmem).to(out_desc.dtype) + + gate_fp32 = gate.to(tl.float32) + silu_gate = (gate_fp32 * tl.sigmoid(gate_fp32)).to(out_desc.dtype) + result = silu_gate * up + + out_desc.store([offs_m, offs_n], result) + + # Signal MMA that TMEM buffer is free + # pyre-ignore[6] + tlx.barrier_arrive(tmem_empty_bars[cur_tmem_buf], 1) + + processed_tiles += 1 + + # MMA Consumer: Computes both GEMMs: gate = X @ W_gate, up = X @ W_up + with tlx.async_task(num_warps=4, num_regs=232): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_M * num_pid_n + num_tiles = num_pid_m * num_pid_n + k_tiles = tl.cdiv(K, BLOCK_K) + + processed_k_iters = 0 + processed_tiles = 0 + + for tile_id in range(start_pid, num_tiles, NUM_SMS): + pid_m, pid_n = _compute_pid_swiglu( + tile_id, num_pid_in_group, num_pid_m, GROUP_M, NUM_SMS + ) + + cur_tmem_buf = processed_tiles % int(NUM_TMEM_BUFFERS) + tmem_write_phase = (processed_tiles // int(NUM_TMEM_BUFFERS)) & 1 + + # Wait for epilogue to finish + tlx.barrier_wait(tmem_empty_bars[cur_tmem_buf], tmem_write_phase ^ 1) + + # Perform K-dimension reduction for both GEMMs + for k in range(0, k_tiles): + buf = (processed_k_iters + k) % int(NUM_SMEM_BUFFERS) + + total_iters = processed_k_iters + k + dot_phase = (total_iters // int(NUM_SMEM_BUFFERS)) & 1 + + # Wait for x and w_gate to be loaded, then start GEMM1 + tlx.barrier_wait(smem_full_bars_x_gate[buf], dot_phase) + + # Transpose weight buffer for MMA + w_gate_trans = tlx.local_trans(w_gate_buffers[buf]) + + # GEMM 1: gate = X @ W_gate.T + tlx.async_dot( + x_buffers[buf], + w_gate_trans, + tmem_gate_buffers[cur_tmem_buf], + # pyre-ignore[6] + use_acc=(k > 0), + mBarriers=[smem_empty_bars[buf]], + out_dtype=tl.float32, + ) + + # Wait for w_up to be loaded before starting GEMM2 + tlx.barrier_wait(smem_full_bars_up[buf], dot_phase) + + w_up_trans = tlx.local_trans(w_up_buffers[buf]) + + # GEMM 2: up = X @ W_up.T + tlx.async_dot( + x_buffers[buf], + w_up_trans, + tmem_up_buffers[cur_tmem_buf], + # pyre-ignore[6] + use_acc=(k > 0), + mBarriers=[smem_empty_bars[buf]], + out_dtype=tl.float32, + ) + + # Wait for last MMA to complete + last_buf = (processed_k_iters + k_tiles - 1) % int(NUM_SMEM_BUFFERS) + last_total_iters = processed_k_iters + k_tiles - 1 + last_dot_phase = (last_total_iters // int(NUM_SMEM_BUFFERS)) & 1 + tlx.barrier_wait(smem_empty_bars[last_buf], last_dot_phase) + + # Signal epilogue that results are ready + # pyre-ignore[6] + tlx.barrier_arrive(tmem_full_bars[cur_tmem_buf], 1) + + processed_tiles += 1 + processed_k_iters += k_tiles + + # Producer: TMA loads for X, W_gate, W_up + with tlx.async_task(num_warps=1, num_regs=24): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_M * num_pid_n + num_tiles = num_pid_m * num_pid_n + k_tiles = tl.cdiv(K, BLOCK_K) + + # Initialize phase tracking + processed_k_iters = 0 + + for tile_id in range(start_pid, num_tiles, NUM_SMS): + pid_m, pid_n = _compute_pid_swiglu( + tile_id, num_pid_in_group, num_pid_m, GROUP_M, NUM_SMS + ) + offs_m = pid_m * BLOCK_M + offs_n = pid_n * BLOCK_N + + for k in range(0, k_tiles): + buf = (processed_k_iters + k) % int(NUM_SMEM_BUFFERS) + + total_iters = processed_k_iters + k + load_phase = (total_iters // int(NUM_SMEM_BUFFERS)) & 1 + + # Wait for buffer to be free + tlx.barrier_wait(smem_empty_bars[buf], load_phase ^ 1) + + offs_k = k * BLOCK_K + + # Set expected bytes for x+w_gate barrier + tlx.barrier_expect_bytes( + smem_full_bars_x_gate[buf], + # pyre-ignore[6] + 2 * (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N), + ) + + # Set expected bytes for w_up barrier + tlx.barrier_expect_bytes( + smem_full_bars_up[buf], + # pyre-ignore[6] + 2 * (BLOCK_K * BLOCK_N), + ) + + # Load x and w_gate first, signal smem_full_bars_x_gate + tlx.async_descriptor_load( + x_desc, + x_buffers[buf], + [offs_m, offs_k], + smem_full_bars_x_gate[buf], + ) + + # Weights are in [N, K] layout, load with [offs_n, offs_k] + tlx.async_descriptor_load( + w_gate_desc, + w_gate_buffers[buf], + [offs_n, offs_k], + smem_full_bars_x_gate[buf], + ) + + # Load w_up separately, signal smem_full_bars_up + tlx.async_descriptor_load( + w_up_desc, + w_up_buffers[buf], + [offs_n, offs_k], + smem_full_bars_up[buf], + ) + + processed_k_iters += k_tiles + + +@torch.fx.wrap +def triton_swiglu_fwd_tma_ws_persistent_tlx( + x: torch.Tensor, + w_gate: torch.Tensor, + w_up: torch.Tensor, +) -> torch.Tensor: + M, K = x.shape + N, K_gate = w_gate.shape + N_up, K_up = w_up.shape + + # Only bf16/fp16 supported by the kernel + supported_dtypes = (torch.bfloat16, torch.float16) + assert x.dtype in supported_dtypes, ( + f"x.dtype must be bfloat16 or float16, got {x.dtype}" + ) + assert w_gate.dtype in supported_dtypes, ( + f"w_gate.dtype must be bfloat16 or float16, got {w_gate.dtype}" + ) + assert w_up.dtype in supported_dtypes, ( + f"w_up.dtype must be bfloat16 or float16, got {w_up.dtype}" + ) + + assert K == K_gate, f"Incompatible dimensions: x.K={K}, w_gate.K={K_gate}" + assert K == K_up, f"Incompatible dimensions: x.K={K}, w_up.K={K_up}" + assert N == N_up, f"Incompatible dimensions: w_gate.N={N}, w_up.N={N_up}" + + # Allocate output + out = torch.empty((M, N), device=x.device, dtype=x.dtype) + if M == 0 or N == 0: + return out + + M_BLOCK = triton.next_power_of_2(M) + + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + # A dummy block value that will be overwritten by the hook + dummy_block = [1, 1] + + # pyre-ignore[6] + x_desc = TensorDescriptor(x, x.shape, x.stride(), dummy_block) + # pyre-ignore[6] + w_gate_desc = TensorDescriptor(w_gate, w_gate.shape, w_gate.stride(), dummy_block) + # pyre-ignore[6] + w_up_desc = TensorDescriptor(w_up, w_up.shape, w_up.stride(), dummy_block) + # pyre-ignore[6] + out_desc = TensorDescriptor(out, out.shape, out.stride(), dummy_block) + + def grid(meta): + BLOCK_M = meta["BLOCK_M"] + BLOCK_N = meta["BLOCK_N"] + return ( + min( + NUM_SMS, + triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), + ), + ) + + _swiglu_fwd_tma_ws_persistent[grid]( + x_desc, + w_gate_desc, + w_up_desc, + out_desc, + M, + N, + K, + M_BLOCK, + NUM_SMS=NUM_SMS, + NUM_SMEM_BUFFERS=4, + NUM_TMEM_BUFFERS=2, + EPILOGUE_SUBTILE=2, + ) + return out + + +# ============================================================================= +# Standard fused SwiGLU kernel for A100/H100 (non-TLX path). +# +# Fuses silu(x @ W_gate^T) * (x @ W_up^T) into a single kernel launch. +# Uses standard Triton pointer arithmetic (no TMA), works on SM80+. +# +# Key optimization: x is loaded from HBM ONCE and reused for both GEMMs. +# Activation (silu * up) is computed in float32 registers, no HBM round-trip. +# +# Weight layout: expects [N, K] (nn.Linear native format). +# The wrapper transposes to [K, N] for the GEMM internally. +# ============================================================================= + + +def _get_swiglu_fwd_configs() -> List[triton.Config]: + """ + Autotune configs for the standard (non-TLX) fused SwiGLU kernel. + + Two float32 accumulators (gate + up) double register pressure vs single + GEMM, so smaller block sizes are included. + """ + configs = [ + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8}, + num_stages=3, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8}, + num_stages=3, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_M": 8}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_M": 8}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_M": 8}, + num_stages=3, + num_warps=8, + ), + ] + if torch.version.hip: + hip_num_stages = 2 if triton.__version__ >= "3.2.0" else 0 + configs.extend( + [ + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_M": 8}, + num_stages=hip_num_stages, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_M": 8}, + num_stages=hip_num_stages, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8}, + num_stages=hip_num_stages, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8}, + num_stages=hip_num_stages, + num_warps=4, + ), + ] + ) + return configs + + +@triton_autotune( + configs=_get_swiglu_fwd_configs(), + key=["M_BLOCK", "N", "K"], +) +@triton.jit +def _swiglu_fwd_kernel( + # Pointers to input/output tensors + x_ptr, # [M, K] input activation + w_gate_ptr, # [K, N] gate weight (already transposed from [N, K]) + w_up_ptr, # [K, N] up weight (already transposed from [N, K]) + out_ptr, # [M, N] output = silu(x @ w_gate) * (x @ w_up) + # Matrix dimensions + M, # rows in x (batch_size * seq_len) + N, # output dimension (hidden_dim) + K, # input/reduction dimension (input_dim) + M_BLOCK, # next_power_of_2(M) for stable autotuning + # Strides + stride_xm, + stride_xk, + stride_wgk, + stride_wgn, + stride_wuk, + stride_wun, + stride_om, + stride_on, + # Compile-time constants + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + ALLOW_TF32: tl.constexpr, +): + """ + Fused SwiGLU forward: out = silu(x @ W_gate) * (x @ W_up). + + Each thread block computes one [BLOCK_M, BLOCK_N] output tile. + Two accumulators share the same x tile loads (the fusion benefit). + """ + # -- Step 1: Compute tile coordinates with grouped ordering (L2 reuse) -- + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_M + + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_in_group = pid % num_pid_in_group + pid_m = first_pid_m + (pid_in_group % group_size_m) + pid_n = pid_in_group // group_size_m + + # -- Step 2: Set up pointers for x, w_gate, w_up tiles -- + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + mask_m = (pid_m * BLOCK_M + offs_m)[:, None] < M + mask_n = (pid_n * BLOCK_N + offs_n)[None, :] < N + # [BLOCK_M, BLOCK_K] + x_ptrs = ( + x_ptr + + (pid_m.to(tl.int64) * BLOCK_M + offs_m)[:, None] * stride_xm + + offs_k[None, :] * stride_xk + ) + # [BLOCK_K, BLOCK_N] + wg_ptrs = ( + w_gate_ptr + + offs_k[:, None] * stride_wgk + + (pid_n.to(tl.int64) * BLOCK_N + offs_n)[None, :] * stride_wgn + ) + + # [BLOCK_K, BLOCK_N] + wu_ptrs = ( + w_up_ptr + + offs_k[:, None] * stride_wuk + + (pid_n.to(tl.int64) * BLOCK_N + offs_n)[None, :] * stride_wun + ) + + # -- Step 3: K-loop - two GEMMs sharing the same x tile -- + acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_K)): + mask_k = offs_k[None, :] < K - k * BLOCK_K + x = tl.load(x_ptrs, mask=mask_m & mask_k, other=0.0) + mask_k = offs_k[:, None] < K - k * BLOCK_K + wg = tl.load(wg_ptrs, mask=mask_k & mask_n, other=0.0) + wu = tl.load(wu_ptrs, mask=mask_k & mask_n, other=0.0) + + acc_gate += tl.dot(x, wg, allow_tf32=ALLOW_TF32) + acc_up += tl.dot(x, wu, allow_tf32=ALLOW_TF32) + + x_ptrs += BLOCK_K * stride_xk + wg_ptrs += BLOCK_K * stride_wgk + wu_ptrs += BLOCK_K * stride_wuk + + # -- Step 4: Apply SwiGLU activation in registers (no HBM round-trip) -- + gate_activated = acc_gate * tl.sigmoid(acc_gate) # silu + result = (gate_activated * acc_up).to(out_ptr.dtype.element_ty) + + # -- Step 5: Store result -- + offs_m = pid_m * BLOCK_M + offs_m + offs_n = pid_n * BLOCK_N + offs_n + out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + tl.store(out_ptrs, result, mask=mask_m & mask_n) + + +def triton_swiglu_fwd( + x: torch.Tensor, + w_gate: torch.Tensor, + w_up: torch.Tensor, +) -> torch.Tensor: + """ + Forward pass of fused SwiGLU (non-TLX path). Works on A100/H100/MI300X. + + Computes: silu(x @ w_gate^T) * (x @ w_up^T) + + Args: + x: [M, K] input tensor + w_gate: [N, K] gate weight (nn.Linear format) + w_up: [N, K] up weight (nn.Linear format) + + Returns: + [M, N] output tensor + """ + M, K = x.shape + N, K_gate = w_gate.shape + N_up, K_up = w_up.shape + assert K == K_gate, f"x.K={K} != w_gate.K={K_gate}" + assert K == K_up, f"x.K={K} != w_up.K={K_up}" + assert N == N_up, f"w_gate.N={N} != w_up.N={N_up}" + + out = torch.empty((M, N), device=x.device, dtype=x.dtype) + if M == 0 or N == 0: + return out + + M_BLOCK = triton.next_power_of_2(M) + + # Transpose weights from [N, K] to [K, N] for the GEMM kernel + w_gate_t = w_gate.t().contiguous() + w_up_t = w_up.t().contiguous() + + grid = lambda meta: ( # noqa E731 + triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]), + ) + + _swiglu_fwd_kernel[grid]( + x, + w_gate_t, + w_up_t, + out, + M, + N, + K, + M_BLOCK, + x.stride(0), + x.stride(1), + w_gate_t.stride(0), + w_gate_t.stride(1), + w_up_t.stride(0), + w_up_t.stride(1), + out.stride(0), + out.stride(1), + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + ) + return out + + +def triton_swiglu( + x: torch.Tensor, + w_gate: torch.Tensor, + w_up: torch.Tensor, +) -> torch.Tensor: + if is_sm100_plus() and TMA_AVAILABLE and HAS_TLX: + # Blackwell: use the fast TLX persistent kernel with TMA + _, K = x.shape + N, _ = w_gate.shape + assert K % 16 == 0 and N % 16 == 0, ( + f"K ({K}) and N ({N}) must be divisible by 16 for TMA alignment" + ) + return triton_swiglu_fwd_tma_ws_persistent_tlx(x, w_gate, w_up) + else: + # A100/H100: use the standard fused kernel + return triton_swiglu_fwd(x, w_gate, w_up) diff --git a/recommendation_v4/generative_recommenders/ops/utils.py b/recommendation_v4/generative_recommenders/ops/utils.py new file mode 100644 index 000000000..16edd99a9 --- /dev/null +++ b/recommendation_v4/generative_recommenders/ops/utils.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-ignore-all-errors + +import functools +import os + +import torch + + +class _PlainFuncWrapper: + """Thin wrapper around a plain function that provides no-op register_fake + and register_kernel methods, mirroring the CustomOpDef API so that + downstream @func.register_fake / func.register_kernel("cpu") calls + don't break when the function is not wrapped as a custom op.""" + + def __init__(self, func): + self._func = func + functools.update_wrapper(self, func) + + def __call__(self, *args, **kwargs): + return self._func(*args, **kwargs) + + def register_fake(self, fake_func): + return fake_func + + def register_kernel(self, device): + def inner(func): + return func + + return inner + + +def maybe_register_custom_op(op_name, mutates_args): + """ + Conditionally registers a function as a torch custom op. + + When AOTI_LOWER is set in the environment, the function is returned + unwrapped so that torch.export / Dynamo can trace through the plain + Python implementation instead of treating the custom op as opaque. + """ + + def decorator(func): + if os.environ.get("AOTI_LOWER"): + return _PlainFuncWrapper(func) + return torch.library.custom_op(op_name, func, mutates_args=mutates_args) + + return decorator + + +def is_sm100_plus() -> bool: + """ + Check if this is a Blackwell Datacenter GPU. + These are between 100 and 103 for B200-GB300. + """ + if not torch.cuda.is_available(): + return False + props = torch.cuda.get_device_properties(0) + return props.major == 10 and (props.minor >= 0 and props.minor <= 3) + + +def is_sm90() -> bool: + if not torch.cuda.is_available(): + return False + props = torch.cuda.get_device_properties(0) + return props.major == 9 and props.minor == 0 + + +def is_sm90_plus() -> bool: + return is_sm100_plus() or is_sm90() + + +@functools.lru_cache(maxsize=None) +def is_amd_mi350() -> bool: + """Detect an AMD Instinct MI350-class GPU (gfx950) running under ROCm. + + MI350 benefits from the same multi-row, separated-RNG layer-norm-mul-dropout + path as Blackwell datacenter parts (sm_100), so it is gated together with + is_sm100_plus() at the kernel dispatch sites. + """ + if not torch.cuda.is_available(): + return False + if getattr(torch.version, "hip", None) is None: + return False + try: + arch = torch.cuda.get_device_properties(0).gcnArchName or "" + except (AssertionError, RuntimeError, AttributeError): + return False + return "gfx950" in arch + + +def use_separated_rng_ln_mul_dropout() -> bool: + """Hardware that should use the autotuned, multi-row ``_ln_mul_dropout_fwd_rng`` + kernel with a precomputed dropout mask instead of the legacy single-row, + fused-RNG ``_ln_mul_dropout_fwd`` kernel. + + Blackwell datacenter GPUs (sm_100-103) and AMD MI350 (gfx950) both prefer the + separated-RNG path: it batches rows per program and lets the backward reuse the + same mask, which is a large win over launching one program per row. + """ + return is_sm100_plus() or is_amd_mi350() + + +def copy_if_different_ptr(dst: torch.Tensor, src: torch.Tensor) -> None: + if torch.compiler.is_compiling(): + # .data_ptr() will break PT2 + dst.copy_(src) + else: + if dst.data_ptr() != src.data_ptr(): + dst.copy_(src) diff --git a/recommendation_v4/generative_recommenders/tests/test_common.py b/recommendation_v4/generative_recommenders/tests/test_common.py new file mode 100644 index 000000000..be3823d67 --- /dev/null +++ b/recommendation_v4/generative_recommenders/tests/test_common.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import unittest + +import torch +from generative_recommenders.common import switch_to_contiguous_if_needed + + +class SwitchToContiguousIfNeededTest(unittest.TestCase): + def test_torchscript_does_not_compile_fx_tracing_helper(self) -> None: + class ContiguousModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return switch_to_contiguous_if_needed(x) + + scripted = torch.jit.script(ContiguousModule()) + x = torch.arange(12).reshape(3, 4).transpose(0, 1) + + out = scripted(x) + + self.assertTrue(torch.equal(out, x)) + self.assertTrue(out.is_contiguous()) + + +if __name__ == "__main__": + unittest.main() diff --git a/recommendation_v4/md5sums_yambda_5b_processed.txt b/recommendation_v4/md5sums_yambda_5b_processed.txt new file mode 100644 index 000000000..82998cca3 --- /dev/null +++ b/recommendation_v4/md5sums_yambda_5b_processed.txt @@ -0,0 +1,22 @@ +# MD5 checksums for the preprocessed yambda-5b dataset (processed_5b/). +# +# Format: standard `md5sum` output -> " ". +# Paths are relative to ${DLRM_DATA_PATH}/processed_5b/. +# +# These hashes are PLACEHOLDERS (TODO). They must be generated from a canonical +# preprocessing run before this benchmark is submitted, e.g.: +# +# cd "${DLRM_DATA_PATH}/processed_5b" +# md5sum train_sessions.parquet test_events.parquet session_index.parquet \ +# item_popularity.npy split_meta.json \ +# > /md5sums_yambda_5b_processed.txt +# +# Until then `verify_dataset.sh` falls back to an existence/layout check and +# warns that checksums are not yet pinned. +# +# TODO(rcp/data): replace the lines below with real md5 hashes. +TODO_GENERATE_HASH train_sessions.parquet +TODO_GENERATE_HASH test_events.parquet +TODO_GENERATE_HASH session_index.parquet +TODO_GENERATE_HASH item_popularity.npy +TODO_GENERATE_HASH split_meta.json diff --git a/recommendation_v4/rcp/README.md b/recommendation_v4/rcp/README.md new file mode 100644 index 000000000..02a977862 --- /dev/null +++ b/recommendation_v4/rcp/README.md @@ -0,0 +1,28 @@ +# Reference Convergence Points (RCP) + +**Status: placeholder — RCPs not yet generated. Intentionally left blank.** + +This directory will hold the Reference Convergence Points for the +recommendation_v4 (HSTU / yambda-5b) benchmark once convergence runs are +complete. + +Per the MLPerf Training +[CONTRIBUTING guidance](https://github.com/mlcommons/training_policies/blob/master/CONTRIBUTING.md) +("Some things to note while generating reference convergence points"): + +- Use FP32 or BF16 precision and record the exact precision used in the RCP JSON. +- Generate RCPs for at least **3 reasonable batch sizes**. +- Run RCPs with an eval frequency **higher** than the chosen benchmark eval + frequency (more data points for picking the target accuracy). +- Run at least **2N seeds**, where N = number of submission runs. + +The convergence target for this benchmark is **eval AUC >= 0.80275** (see +[../README.MD](../README.MD) §9). The RCP JSON files and convergence-curve plots +(samples-to-converge vs. batch size / seed) will be committed here. + +## TODO + +- [ ] Run >= 2N-seed convergence sweeps at >= 3 batch sizes. +- [ ] Record precision (FP32/BF16) per the rules. +- [ ] Add `rcp_.json` files in the mlperf_logging RCP format. +- [ ] Add convergence-curve plots and the chosen target-accuracy justification. diff --git a/recommendation_v4/requirements.txt b/recommendation_v4/requirements.txt new file mode 100644 index 000000000..852aa149b --- /dev/null +++ b/recommendation_v4/requirements.txt @@ -0,0 +1,43 @@ +# Frozen dependency versions for the recommendation_v4 (HSTU / yambda-5b) MLPerf +# reference. The CANONICAL, fully-reproducible environment is the Dockerfile +# (built on rocm/primus:v26.3); see docs/training_recipe.md for the per-platform +# (MI350X / B200) install commands and rationale. The pins below mirror that +# stack. torch / torchvision / torchaudio / fbgemm_gpu / torchrec are +# accelerator-specific and MUST be installed with --no-deps from the matching +# index (see Dockerfile) so pip does not clobber the +rocm wheels. + +# --- accelerator stack (install via Dockerfile; --no-deps, matching index) --- +# torch==2.12.0+rocm7.2 # --index-url https://download.pytorch.org/whl/rocm7.2 +# torchvision==0.27.0+rocm7.2 +# torchaudio==2.11.0+rocm7.2 +# fbgemm_gpu # built from FBGEMM commit 10b775730212923f65f7b78f79b6a01d80cf3c29 for gfx950 +torch==2.12.0 +fbgemm_gpu==1.7.0 +torchrec @ git+https://github.com/pytorch/torchrec.git@v2026.06.01.00 + +# --- data / config / logging ------------------------------------------------- +polars-u64-idx==1.33.1 +gin-config==0.5.0 +absl-py==2.1.0 +pandas==2.2.3 +pyarrow==17.0.0 +numpy==1.26.4 +xxhash==3.5.0 +datasets==3.2.0 +huggingface_hub==0.27.0 + +# --- metrics / training utils ------------------------------------------------ +torchmetrics==1.0.3 +tensordict==0.6.2 +tensorboard==2.19.0 +pyre-extensions==0.0.32 +iopath==0.1.10 +typing-inspect==0.9.0 +psutil==6.1.0 +tqdm==4.67.1 +pyyaml==6.0.2 +pybind11==2.13.6 +lightning-utilities==0.11.9 + +# --- MLPerf compliance logging (pinned to the Training 6.0 tag) -------------- +mlperf-logging @ git+https://github.com/mlcommons/logging.git@6.0.0-rc6 diff --git a/recommendation_v4/run_and_time.sh b/recommendation_v4/run_and_time.sh new file mode 100755 index 000000000..a0207e71c --- /dev/null +++ b/recommendation_v4/run_and_time.sh @@ -0,0 +1,65 @@ +#!/usr/bin/env bash +# MLPerf Training reference script: run the benchmark and report wall-clock time. +# +# Runs the full-reference HSTU / yambda-5b streaming train+eval sweep to the +# MLPerf quality target (eval AUC >= 0.80275) and prints the elapsed time of the +# timed region. This is the canonical single-host (8-GPU) entry point; for +# multi-node SLURM launches use scripts/launch_slurm.sh (which calls into the +# same trainer). +# +# Usage: +# DLRM_DATA_PATH=/path/to/dlrm_data ./run_and_time.sh +# +# Env (run shape / cadence -- defaults are the FULL reference sweep): +# DLRM_DATA_PATH data root (required). +# SEED RNG seed (default 1). +# START_TS / NUM_TRAIN_TS window range (default 0 / 299 = full sweep). +# EVAL_EVERY_DATA_PCT eval cadence as a fraction of train data (default 0.005). +# AUC_THRESHOLD convergence target (default 0.80275). +# GPUS_PER_NODE GPUs on this host (default 8). +# RUN_NAME results dir name under results/ (default reference_run). +set -euo pipefail + +: "${DLRM_DATA_PATH:?Set DLRM_DATA_PATH to the data root}" + +REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "${REPO_ROOT}" + +# ---- Reference run shape (full sweep to the quality target) ----------------- +export SEED="${SEED:-1}" +export START_TS="${START_TS:-0}" +export NUM_TRAIN_TS="${NUM_TRAIN_TS:-299}" +export NUM_TRAIN_BATCHES="${NUM_TRAIN_BATCHES:-0}" +export NUM_EVAL_BATCHES="${NUM_EVAL_BATCHES:-0}" +export EVAL_EVERY_N_WINDOWS="${EVAL_EVERY_N_WINDOWS:-0}" +export EVAL_EVERY_DATA_PCT="${EVAL_EVERY_DATA_PCT:-0.005}" +export AUC_THRESHOLD="${AUC_THRESHOLD:-0.80275}" +export RUN_NAME="${RUN_NAME:-reference_run}" + +# ---- Single-host distributed topology (override for multi-node) ------------- +export GPUS_PER_NODE="${GPUS_PER_NODE:-8}" +export NNODES="${NNODES:-1}" +export NODE_RANK="${NODE_RANK:-0}" +export WORLD_SIZE="${WORLD_SIZE:-$((NNODES * GPUS_PER_NODE))}" +export MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}" +export MASTER_PORT="${MASTER_PORT:-29500}" + +# ---- MLPerf compliance logging ---------------------------------------------- +export MLPERF_LOGGING="${MLPERF_LOGGING:-1}" +export MLPERF_LOG_PATH="${MLPERF_LOG_PATH:-${REPO_ROOT}/results/${RUN_NAME}/mlperf/yambda_5b_mlperf.log}" +export MLPERF_SUBMISSION_PLATFORM="${MLPERF_SUBMISSION_PLATFORM:-MI355X}" +mkdir -p "$(dirname "${MLPERF_LOG_PATH}")" + +# ---- Timed region ----------------------------------------------------------- +# Pull the start timestamp into a clear region per the MLPerf run_and_time.sh idiom. +start=$(date +%s) +echo "STARTING TIMING RUN AT $(date -u '+%Y-%m-%d %r')" + +python -m generative_recommenders.dlrm_v3.train.train_ranker \ + --dataset yambda-5b \ + --mode streaming-train-eval + +end=$(date +%s) +result=$(( end - start )) +echo "ENDING TIMING RUN AT $(date -u '+%Y-%m-%d %r')" +echo "RESULT,recommendation_v4_hstu_yambda_5b,${SEED},${result},$(whoami),$(date -u '+%Y-%m-%d %r')" diff --git a/recommendation_v4/scripts/build_ne_auc_trajectory.py b/recommendation_v4/scripts/build_ne_auc_trajectory.py new file mode 100644 index 000000000..edc57cfbf --- /dev/null +++ b/recommendation_v4/scripts/build_ne_auc_trajectory.py @@ -0,0 +1,287 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Build a combined train+eval NE/AUC trajectory from a streaming-train-eval log. + +The streaming loop (generative_recommenders/dlrm_v3/train/utils.py) emits, via +MetricsLogger.compute(), one line per logged step of the form: + + INFO:utils:train - Step 201 metrics: {'metric/lifetime_ne/listen_plus': + tensor(1.0182, dtype=torch.float64), 'metric/window_ne/listen_plus': + tensor(0.9846, ...), ..., 'metric/window_auc/listen_plus': tensor(0.5912), + 'metric/lifetime_auc/listen_plus': tensor(0.5480)} + +and the analogous `eval - Step N metrics:` lines during each (full-holdout) eval +window, plus throughput lines: + + INFO:utils:train - Step 201 perf: local_sps=97.0 global_sps=776.2 + step_ms=10553.89 elapsed_sec=680.6 total_samples=205824 + +This script parses all three, for a chosen task (default listen_plus), and writes: + * /trajectory.json — {"train": {step: {...}}, "eval": {...}, "perf": [...]} + * /trajectory.csv — long-form rows (mode, step, metric, value) + * /trajectory_ne_auc.png — NE and AUC vs train step, train + eval overlaid + (skipped gracefully if matplotlib is absent) + +It is dependency-light (stdlib + optional matplotlib) so it runs anywhere the +log is readable, including the head node. + +Usage: + python3 scripts/build_ne_auc_trajectory.py LOG [--out DIR] [--task listen_plus] +""" + +import argparse +import csv +import json +import os +import re +import sys +from typing import Dict, List, Optional, Tuple + +# `train - Step 201 metrics: {...}` / `eval - Step 17 metrics: {...}` +_STEP_RE = re.compile(r"(train|eval) - Step (\d+) metrics: \{(.*)\}") +# `metric//': tensor(` — value may be int/float/sci. +_METRIC_RE = re.compile( + r"metric/([A-Za-z0-9_]+)/([A-Za-z0-9_+]+)'?\s*:\s*tensor\(\s*([-0-9.eE+]+)" +) +# `train - Step 201 perf: local_sps=97.0 global_sps=776.2 step_ms=10553.89 ` +# `elapsed_sec=680.6 total_samples=205824` +_PERF_RE = re.compile( + r"train - Step (\d+) perf: local_sps=([-0-9.eE+]+) global_sps=([-0-9.eE+]+) " + r"step_ms=([-0-9.eE+]+) elapsed_sec=([-0-9.eE+]+) total_samples=(\d+)" +) +# `[boundary] eval_ts=181 eval first-batch ...` — marks the start of a full-holdout +# eval block; the eval runs at whatever the latest train global step was, so we use +# it to anchor each eval's metrics onto the shared train-global-step x-axis. +_EVAL_BOUNDARY_RE = re.compile(r"\[boundary\] eval_ts=(\d+) eval first-batch") + +# Metrics we surface in the trajectory (others are still captured if present). +_KEEP = ("window_ne", "lifetime_ne", "window_auc", "lifetime_auc", + "window_accuracy", "lifetime_accuracy", "window_gauc", "lifetime_gauc") + + +def _parse_metrics(body: str, task: str) -> Dict[str, float]: + row: Dict[str, float] = {} + for name, tname, val in _METRIC_RE.findall(body): + if tname != task: + continue + try: + row[name] = float(val) + except ValueError: + continue + return row + + +def parse_log( + log_path: str, task: str +) -> Tuple[Dict[str, Dict[int, Dict[str, float]]], List[Dict[str, float]]]: + """Return ({'train': {step: {metric: val}}, 'eval': {...}}, perf_rows). + + Train is keyed by train global step (last write wins — duplicate per-rank + prints are identical). Eval uses a per-rank-resetting internal step counter + that restarts every eval window, so we instead anchor each eval window onto + the *train global step at which it ran* (the loop trains window T then evals + window T+1, so the eval's anchor is the last train step before it). Each eval + window collapses to a single point carrying its final, most-aggregated + full-holdout metrics, plus `eval_window` (the eval_ts) for reference. + """ + out: Dict[str, Dict[int, Dict[str, float]]] = {"train": {}, "eval": {}} + perf: List[Dict[str, float]] = [] + + last_train_step = 0 + cur_anchor: Optional[int] = None # train global step this eval block runs at + cur_ts: Optional[int] = None # eval window id (eval_ts) + cur_row: Optional[Dict[str, float]] = None # final row of the current block + cur_internal: Optional[int] = None # last eval internal step (reset detection) + + def flush_eval() -> None: + nonlocal cur_anchor, cur_ts, cur_row, cur_internal + if cur_row: + anchor = cur_anchor if cur_anchor is not None else last_train_step + row = dict(cur_row) + if cur_ts is not None: + row["eval_window"] = float(cur_ts) + key = anchor + while key in out["eval"]: # keep distinct evals from colliding + key += 1 + out["eval"][key] = row + cur_anchor = cur_ts = cur_row = cur_internal = None + + with open(log_path, "r", errors="replace") as f: + for line in f: + pm = _PERF_RE.search(line) + if pm: + perf.append({ + "step": int(pm.group(1)), + "local_sps": float(pm.group(2)), + "global_sps": float(pm.group(3)), + "step_ms": float(pm.group(4)), + "elapsed_sec": float(pm.group(5)), + "total_samples": int(pm.group(6)), + }) + continue + bm = _EVAL_BOUNDARY_RE.search(line) + if bm: + # The boundary line (a different logger) can interleave before OR + # after this eval's metric lines, so don't use it to delimit the + # block — just tag the current block with its eval_ts. Block + # boundaries come from eval-step resets / training resuming. + if cur_anchor is None: + cur_anchor = last_train_step + cur_ts = int(bm.group(1)) + continue + m = _STEP_RE.search(line) + if not m: + continue + mode, step_s, body = m.group(1), m.group(2), m.group(3) + step = int(step_s) + row = _parse_metrics(body, task) + if mode == "train": + last_train_step = step + if cur_anchor is not None or cur_row is not None: + flush_eval() # an eval block ends when training resumes + if row: + out["train"][step] = row # last write wins + else: # eval — accumulate into the current block (last = most aggregated) + # Fallback for logs without a boundary marker: a drop in the eval + # internal step counter signals a fresh eval window. + if (cur_internal is not None and step < cur_internal + and cur_anchor is None): + flush_eval() + if cur_anchor is None: + cur_anchor = last_train_step + cur_internal = step + if row: + cur_row = row + flush_eval() + return out, perf + + +def write_outputs( + traj: Dict[str, Dict[int, Dict[str, float]]], + perf: List[Dict[str, float]], + out_dir: str, + task: str, +) -> None: + os.makedirs(out_dir, exist_ok=True) + + json_path = os.path.join(out_dir, "trajectory.json") + with open(json_path, "w") as f: + json.dump( + { + "task": task, + "train": {str(k): v for k, v in sorted(traj["train"].items())}, + "eval": {str(k): v for k, v in sorted(traj["eval"].items())}, + "perf": perf, + }, + f, + indent=2, + ) + + csv_path = os.path.join(out_dir, "trajectory.csv") + with open(csv_path, "w", newline="") as f: + w = csv.writer(f) + w.writerow(["mode", "step", "metric", "value"]) + for mode in ("train", "eval"): + for step in sorted(traj[mode]): + for metric, val in traj[mode][step].items(): + w.writerow([mode, step, metric, val]) + + n_train = len(traj["train"]) + n_eval = len(traj["eval"]) + print(f"Parsed {n_train} train points, {n_eval} eval points, " + f"{len(perf)} perf points (task={task}).", file=sys.stderr) + print(f"Wrote {json_path}", file=sys.stderr) + print(f"Wrote {csv_path}", file=sys.stderr) + + _maybe_plot(traj, out_dir, task) + + +def _maybe_plot( + traj: Dict[str, Dict[int, Dict[str, float]]], out_dir: str, task: str +) -> None: + try: + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + except Exception as e: # noqa: BLE001 + print(f"matplotlib unavailable ({e}); skipping plot.", file=sys.stderr) + return + + def series(mode: str, metric: str) -> Tuple[List[int], List[float]]: + steps = sorted(s for s in traj[mode] if metric in traj[mode][s]) + return steps, [traj[mode][s][metric] for s in steps] + + fig, (ax_ne, ax_auc) = plt.subplots(2, 1, figsize=(11, 9), sharex=True) + + for metric, style in (("window_ne", "-"), ("lifetime_ne", "--")): + xs, ys = series("train", metric) + if xs: + ax_ne.plot(xs, ys, style, label=f"train/{metric}", alpha=0.85) + for metric, marker in (("window_ne", "o"), ("lifetime_ne", "s")): + xs, ys = series("eval", metric) + if xs: + ax_ne.plot(xs, ys, marker=marker, ms=5, ls="-", lw=1.0, alpha=0.9, + label=f"eval/{metric}") + ax_ne.set_ylabel("NE (normalized entropy)") + ax_ne.set_title(f"yambda-5b streaming train+eval trajectory — task={task}") + ax_ne.grid(True, alpha=0.3) + ax_ne.legend(fontsize=8, ncol=2) + + for metric, style in (("window_auc", "-"), ("lifetime_auc", "--")): + xs, ys = series("train", metric) + if xs: + ax_auc.plot(xs, ys, style, label=f"train/{metric}", alpha=0.85) + for metric, marker in (("window_auc", "o"), ("lifetime_auc", "s")): + xs, ys = series("eval", metric) + if xs: + ax_auc.plot(xs, ys, marker=marker, ms=5, ls="-", lw=1.0, alpha=0.9, + label=f"eval/{metric}") + ax_auc.set_ylabel("AUC") + ax_auc.set_xlabel("train global step (eval points anchored to the step they ran at)") + ax_auc.grid(True, alpha=0.3) + ax_auc.legend(fontsize=8, ncol=2) + + png_path = os.path.join(out_dir, "trajectory_ne_auc.png") + fig.tight_layout() + fig.savefig(png_path, dpi=120) + print(f"Wrote {png_path}", file=sys.stderr) + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("log", help="Path to the streaming train.log") + ap.add_argument("--out", default=None, + help="Output dir (default: /_trajectory)") + ap.add_argument("--task", default="listen_plus", + help="Task name to extract (default: listen_plus)") + args = ap.parse_args() + + if not os.path.exists(args.log): + print(f"Log not found: {args.log}", file=sys.stderr) + return 2 + out_dir = args.out + if out_dir is None: + stem = os.path.splitext(os.path.basename(args.log))[0] + out_dir = os.path.join(os.path.dirname(os.path.abspath(args.log)), + f"{stem}_trajectory") + + traj, perf = parse_log(args.log, args.task) + write_outputs(traj, perf, out_dir, args.task) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/recommendation_v4/scripts/launch_local.sh b/recommendation_v4/scripts/launch_local.sh new file mode 100755 index 000000000..45f4bb6f2 --- /dev/null +++ b/recommendation_v4/scripts/launch_local.sh @@ -0,0 +1,157 @@ +#!/bin/bash +# ============================================================================= +# launch_local.sh — single-host, NON-SLURM launcher for the yambda-5b trainer. +# +# This is the SLURM-free analog of scripts/launch_slurm.sh's `worker` phase: +# it sets the single-node distributed topology + sane env and invokes the SAME +# entry point (`train_ranker.py --dataset yambda-5b`) reading the SAME +# train/gin/yambda_5b.gin config. No scheduler, no docker, no RDMA overlay — +# everything runs directly on this host against an already-prepared dataset. +# +# Use it to: +# * Smoke-test the launch path on a single GPU box (SMOKE=1, the default — +# a few train/eval batches of one streaming window), or +# * Run the full gin-default workload (SMOKE=0 — consumes whole windows). +# +# PREREQUISITES +# 1) Data prepared (run once, CPU-only — no GPU needed): +# python generative_recommenders/dlrm_v3/preprocess_public_data.py \ +# --dataset yambda-5b --data-path "$DLRM_DATA_PATH" +# producing $DLRM_DATA_PATH/processed_5b/{train_sessions.parquet,...} +# and $DLRM_DATA_PATH/shared_metadata/{artist,album}_item_mapping.parquet +# 2) The train_recipe GPU stack importable by $PYTHON (see docs/training_recipe.md): +# torch (rocm or cuda build), fbgemm_gpu, torchrec, polars-u64-idx, +# gin-config, xxhash, pandas, tensorboard, ... +# This box must have visible GPUs (the trainer shards embeddings onto HBM). +# +# USAGE +# # smoke (default): one window, 20 train + 10 eval batches +# DLRM_DATA_PATH=/home/chcai/dlrm_data bash scripts/launch_local.sh +# +# # full gin-default run (whole windows; long) +# SMOKE=0 DLRM_DATA_PATH=/home/chcai/dlrm_data bash scripts/launch_local.sh +# +# # restrict to 2 GPUs, custom log, plain (non-streaming) train-eval +# GPUS_PER_NODE=2 MODE=train-eval LOG=/tmp/y.log bash scripts/launch_local.sh +# +# Every knob below is env-overridable; defaults reproduce launch_slurm.sh's +# single-node smoke path so a local run matches the known-good cluster path. +# ============================================================================= +set -uo pipefail + +REPO_ROOT=$(cd "$(dirname "$0")/.." && pwd) +cd "$REPO_ROOT" + +# ---- interpreter ------------------------------------------------------------ +# Default to the venv created for data prep if present, else system python3. +# Override with PYTHON=/path/to/python (e.g. the in-container recipe python). +DEFAULT_PY=/home/chcai/dlrmv4_venv/bin/python +PYTHON=${PYTHON:-$([ -x "$DEFAULT_PY" ] && echo "$DEFAULT_PY" || echo python3)} + +# ---- dataset / data path ---------------------------------------------------- +DATASET=${DATASET:-yambda-5b} +MODE=${MODE:-streaming-train-eval} +# Mirrors the yambda_5b.gin default ("/apps/chcai/dlrm_data"); point at wherever +# preprocess_public_data.py wrote processed_5b/ + shared_metadata/. +export DLRM_DATA_PATH=${DLRM_DATA_PATH:-/home/chcai/dlrm_data} + +LOG=${LOG:-$REPO_ROOT/yambda_local.$(date +%Y%m%d_%H%M%S).log} + +# ---- single-node distributed topology -------------------------------------- +# train_ranker reads these from the env (see train_ranker.main): it spawns +# GPUS_PER_NODE ranks via torch.multiprocessing on THIS host. localhost +# rendezvous; empty MASTER_PORT => train_ranker picks a free port. +export NNODES=${NNODES:-1} +export NODE_RANK=${NODE_RANK:-0} +export MASTER_ADDR=${MASTER_ADDR:-localhost} +export MASTER_PORT=${MASTER_PORT:-} +# GPUS_PER_NODE: 0/unset => train_ranker auto-detects torch.cuda.device_count(). +export GPUS_PER_NODE=${GPUS_PER_NODE:-0} + +# ---- runtime env (matches launch_slurm.sh worker defaults) ------------------ +export HSTU_HAMMER_KERNEL=${HSTU_HAMMER_KERNEL:-TRITON} +export PYTORCH_CUDA_ALLOC_CONF=${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True} +export PYTHONPATH="$REPO_ROOT:${PYTHONPATH:-}" +# Single-node RCCL bootstrap: all ranks rendezvous over localhost, so pin the +# loopback NIC. Left to auto-detect, RCCL can grab a non-routable per-GPU RoCE +# NIC and hang/"No route to host" at init (same failure launch_slurm.sh pins +# fenic0 to avoid). Override NCCL_SOCKET_IFNAME for a routable multi-host setup. +export NCCL_SOCKET_IFNAME=${NCCL_SOCKET_IFNAME:-lo} +export NCCL_DEBUG=${NCCL_DEBUG:-WARN} + +# ---- smoke caps ------------------------------------------------------------- +# SMOKE=1 (default): apply small per-window batch caps so a launch finishes in +# minutes (validates the path end-to-end). SMOKE=0: leave the gin defaults +# untouched (consume full windows — the real workload). +SMOKE=${SMOKE:-1} +if [ "$SMOKE" = "1" ]; then + export START_TS=${START_TS:-150} + export NUM_TRAIN_TS=${NUM_TRAIN_TS:-1} + export NUM_TRAIN_BATCHES=${NUM_TRAIN_BATCHES:-20} + export NUM_EVAL_BATCHES=${NUM_EVAL_BATCHES:-10} + # Default eval cadence: per-window OFF (0), data-fraction every 0.5% of data + # (0.005). Mutually exclusive (both >0 raises a ValueError at startup). + export EVAL_EVERY_N_WINDOWS=${EVAL_EVERY_N_WINDOWS:-0} + export EVAL_EVERY_DATA_PCT=${EVAL_EVERY_DATA_PCT:-0.005} + export METRIC_LOG_FREQ=${METRIC_LOG_FREQ:-5} + # Smaller per-sample shape keeps the smoke run light; drop these to use the + # gin defaults (4086/4096). Reuse an existing hstu_cache_L/ if present. + export BATCH_SIZE=${BATCH_SIZE:-32} +fi + +mkdir -p "$(dirname "$LOG")" +{ + echo "[$(date)] launch_local: dataset=$DATASET mode=$MODE smoke=$SMOKE" + echo "[$(date)] PYTHON=$PYTHON" + echo "[$(date)] DLRM_DATA_PATH=$DLRM_DATA_PATH" + echo "[$(date)] topology: nnodes=$NNODES node_rank=$NODE_RANK gpus_per_node(req)=$GPUS_PER_NODE master=$MASTER_ADDR:${MASTER_PORT:-}" +} | tee -a "$LOG" + +# ---- preflight: data present? ---------------------------------------------- +SUFFIX=${DATASET#yambda-} +PROCESSED="$DLRM_DATA_PATH/processed_${SUFFIX}/train_sessions.parquet" +META="$DLRM_DATA_PATH/shared_metadata/artist_item_mapping.parquet" +if [ "$DATASET" = "yambda-5b" ] && { [ ! -f "$PROCESSED" ] || [ ! -f "$META" ]; }; then + echo "[$(date)] ERROR: prepared data not found." | tee -a "$LOG" + echo " expected: $PROCESSED" | tee -a "$LOG" + echo " and: $META" | tee -a "$LOG" + echo " run preprocessing first:" | tee -a "$LOG" + echo " $PYTHON generative_recommenders/dlrm_v3/preprocess_public_data.py --dataset $DATASET --data-path $DLRM_DATA_PATH" | tee -a "$LOG" + exit 1 +fi + +# ---- preflight: GPU stack importable + GPUs visible? ------------------------ +echo "[$(date)] preflight: checking torch / fbgemm_gpu / torchrec + GPU count" | tee -a "$LOG" +"$PYTHON" - <<'PY' 2>&1 | tee -a "$LOG" +import sys +missing = [] +for m in ("torch", "fbgemm_gpu", "torchrec", "polars", "gin", "xxhash"): + try: + __import__(m) + except Exception as e: + missing.append(f"{m} ({e.__class__.__name__})") +if missing: + print("PREFLIGHT FAIL: missing/broken imports: " + ", ".join(missing)) + print("Install the train_recipe GPU stack (see docs/training_recipe.md).") + sys.exit(3) +import torch +n = torch.cuda.device_count() +print(f"imports OK, torch {torch.__version__}, cuda/hip available={torch.cuda.is_available()}, {n} GPU(s)") +if n == 0: + print("PREFLIGHT FAIL: no GPUs visible — the HSTU trainer shards embeddings " + "onto GPU HBM and cannot run CPU-only. Launch on a GPU host.") + sys.exit(4) +PY +pf=${PIPESTATUS[0]} +if [ "$pf" -ne 0 ]; then + echo "[$(date)] preflight failed (rc=$pf) — not launching trainer." | tee -a "$LOG" + exit "$pf" +fi + +# ---- launch ----------------------------------------------------------------- +echo "[$(date)] launching train_ranker ($DATASET, mode=$MODE)" | tee -a "$LOG" +"$PYTHON" -m generative_recommenders.dlrm_v3.train.train_ranker \ + --dataset "$DATASET" --mode "$MODE" 2>&1 | tee -a "$LOG" +rc=${PIPESTATUS[0]} +echo "[$(date)] launch_local finished rc=$rc" | tee -a "$LOG" +exit "$rc" diff --git a/recommendation_v4/scripts/launch_slurm.sh b/recommendation_v4/scripts/launch_slurm.sh new file mode 100755 index 000000000..cb171593c --- /dev/null +++ b/recommendation_v4/scripts/launch_slurm.sh @@ -0,0 +1,662 @@ +#!/bin/bash +#SBATCH --job-name=yambda_slurm +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --exclusive +#SBATCH --partition=meta64 # [CLUSTER-SPECIFIC] partition name +#SBATCH --output=/apps/chcai/yambda_slurm.%j.out +# ============================================================================= +# launch_slurm.sh — single entry point for the yambda-5b trainer on N>=1 nodes. +# +# Consolidates what used to be separate scripts so multi-node enablement is +# ONE committable script (plus the train_ranker.py / utils.py python changes): +# * orchestrate phase (host SLURM glue) — formerly sbatch_smoke_multinode.sh +# * provision phase (container + RDMA) — formerly _provision_yambda_primus.sh +# * worker phase (in-container train) — now inlined below +# +# PHASES (auto-detected from context; force with LAUNCH_SLURM_PHASE=): +# orchestrate Runs on the SLURM batch host (no /.dockerenv). Resolves the +# rendezvous (MASTER_ADDR/PORT), ensures the container on every +# node (provision phase), then `docker exec`s the worker phase on +# every node, one task per node. +# provision Runs on a compute-node host. Ensures the `yambda_primus` +# container is up (loads the pre-baked image if present — no +# internet/pip — else builds from the base image) and stages the +# host RDMA userspace overlay on shared NFS. +# worker Runs INSIDE the container. Sets the distributed topology + +# NCCL/RDMA env and spawns this node's GPU ranks via train_ranker. +# N==1 transparently uses the legacy single-node path (localhost, +# node_rank 0), byte-for-byte as before, so the streaming-e2e +# supervisor's direct `bash scripts/launch_slurm.sh` is unchanged. +# +# USAGE +# Multi-node (N>=1): sbatch --nodes=2 scripts/launch_slurm.sh +# Single-node direct: bash scripts/launch_slurm.sh (already inside container; +# what run_streaming_e2e.sh invokes per relaunch) +# Perf pair: +# LOG=/apps/chcai/perf_1node.log NUM_TRAIN_BATCHES=200 NUM_EVAL_BATCHES=0 \ +# EVAL_EVERY_N_WINDOWS=0 METRIC_LOG_FREQ=20 \ +# sbatch --nodes=1 --job-name=y1 scripts/launch_slurm.sh +# LOG=/apps/chcai/perf_2node.log NUM_TRAIN_BATCHES=200 NUM_EVAL_BATCHES=0 \ +# EVAL_EVERY_N_WINDOWS=0 METRIC_LOG_FREQ=20 \ +# sbatch --nodes=2 --job-name=y2 scripts/launch_slurm.sh +# # then: bash scripts/compare_node_perf.sh /apps/chcai/perf_1node.log /apps/chcai/perf_2node.log +# +# ONE-TIME IMAGE BAKE (so fresh nodes skip the multi-GB torch download + pip): +# BAKE_IMAGE=1 LAUNCH_SLURM_PHASE=provision bash scripts/launch_slurm.sh +# (commits the deps-installed container to $BAKED_IMAGE and `docker save`s it to +# $BAKED_TAR on NFS; subsequent provisions `docker load` it offline.) +# +# ----------------------------------------------------------------------------- +# PORTABILITY — what to change for a DIFFERENT cluster / network / hardware. +# Every such knob is also tagged inline with "[CLUSTER-SPECIFIC]" (grep for it). +# All are env-overridable, so you can adapt without editing this file. +# +# A) SLURM / scheduler +# - #SBATCH --partition=meta64 : partition name. CHANGE per cluster. +# - #SBATCH --time / --exclusive : policy; adjust to taste. +# +# B) Filesystems (must be shared/NFS across ALL nodes — this script re-invokes +# itself and reads the overlay + data from these paths cluster-wide) +# - REPO_MOUNT (repo + this script, e.g. /home/) is bind-mounted rw; +# DATA_MOUNT (e.g. /apps/chcai) holds the read-only dataset + overlay + +# baked tar + pip tarball; SCRATCH (e.g. /home//yambda_runs) is the +# writable log/output root. Override any via env — nothing is user-hardwired. +# +# C) Container image / GPU software stack (tied to the GPU arch + ROCm version) +# - IMAGE=rocm/primus:v26.3 : base image. ROCm/AMD-specific. +# - docker run --device=/dev/kfd --device=/dev/dri --group-add video : AMD ROCm +# device passthrough. For NVIDIA this is --gpus all / nvidia runtime instead. +# - --ulimit memlock=-1 : REQUIRED for RDMA QP registration (do not drop). +# - TORCH_IDX (rocm7.2), torch/vision/audio ==*+rocm7.2, FBGEMM_WHL (a gfx950 +# wheel), torchrec pin : the whole deps set is arch/ROCm-version-specific. +# +# D) Network fabric — THE trickiest part; defaults are PROVEN on meta64 cv350 +# (Broadcom bnxt_re RoCEv2). On a different fabric these almost certainly change +# (see the worker-phase block for the full rationale): +# - NCCL_SOCKET_IFNAME=fenic0 : the ONE routable host NIC for TCP bootstrap. +# Find yours with `ip -br addr`; the per-GPU RDMA NICs are usually NOT +# routable for plain TCP, so auto-detect hangs init — you MUST pin this. +# - NCCL_IB_HCA=bnxt_re0..7 : the RDMA HCA device names. List with `ibv_devices`. +# Different NIC vendor (e.g. mlx5_*, ionic_*) => different names AND a +# different userspace provider, which changes the RDMA overlay below. +# - NCCL_IB_GID_INDEX=3 : RoCEv2 IPv4 GID index. Check `show_gids`; v1/v2 and +# IPv4/IPv6 live at different indices per port. +# - NCCL_IB_TC=104 : RoCE lossless (PFC) traffic class. Fabric/switch-specific. +# - RDMA overlay (provision phase): only needed when the CONTAINER's rdma-core +# is older than the HOST kernel driver's uapi (our bnxt_re v34-vs-v59 case). +# Different NIC/host => different /usr/lib64 provider .so to stage, or the +# overlay may be unnecessary entirely (set RDMA_OVERLAY= to disable). If RDMA +# can't be made to work, NCCL_NET_TRANSPORT=socket falls back to TCP. +# +# E) Not cluster-specific (auto-derived): GPUS_PER_NODE (torch.cuda.device_count), +# NNODES/NODE_RANK/MASTER_ADDR (from SLURM), WORLD_SIZE. +# ============================================================================= +set -uo pipefail + +# Absolute path to THIS script so the orchestrate phase can re-invoke it on every +# node (home is shared NFS, so the same path resolves cluster-wide). +SELF=$(cd "$(dirname "$0")" && pwd)/$(basename "$0") +REPO_ROOT=$(cd "$(dirname "$0")/.." && pwd) + +# ---- phase detection -------------------------------------------------------- +PHASE="${LAUNCH_SLURM_PHASE:-}" +if [ -z "$PHASE" ]; then + if [ -f /.dockerenv ]; then PHASE=worker; else PHASE=orchestrate; fi +fi + +# ---- shared config (env-overridable) ---------------------------------------- +CONTAINER=${CONTAINER:-yambda_primus} +REPO=${REPO:-$REPO_ROOT} # repo path inside the container +IMAGE=${IMAGE:-rocm/primus:v26.3} # [CLUSTER-SPECIFIC] ROCm/arch base image +BAKED_IMAGE=${BAKED_IMAGE:-yambda_primus_baked:latest} +BAKED_TAR=${BAKED_TAR:-/apps/chcai/yambda_primus_baked.tar} # [CLUSTER-SPECIFIC] shared-NFS path +USE_BAKED=${USE_BAKED:-1} +OVERLAY=${RDMA_OVERLAY:-/apps/chcai/rdma_host_el9_new} # [CLUSTER-SPECIFIC] shared-NFS RDMA overlay + +REPO_MOUNT=${REPO_MOUNT:-$HOME} # NFS home holding the repo (must contain $REPO); override if your repo lives elsewhere +DATA_MOUNT=${DATA_MOUNT:-/apps/chcai} # shared dataset + RDMA overlay + pip/fbgemm assets (read-only) +SCRATCH=${SCRATCH:-$HOME/yambda_runs} # writable output root (logs / tb / traces) + +# ============================================================================= +# PHASE: orchestrate (SLURM batch host) +# ============================================================================= +orchestrate() { + # When run as the SLURM batch script, $0 is the node-local staged copy + # (/var/spool/slurmd/job/slurm_script), so $SELF / $REPO_ROOT are WRONG + # here (they don't exist on other nodes). Resolve the REAL shared-NFS script + # path + repo root from SLURM so we can re-invoke this script on every node and + # `cd` to the right repo inside the container. + SCRIPT_PATH=$(scontrol show job "${SLURM_JOB_ID:-0}" 2>/dev/null | grep -oP 'Command=\K\S+') + [ -f "${SCRIPT_PATH:-}" ] || SCRIPT_PATH="${SLURM_SUBMIT_DIR:-$REPO_ROOT}/scripts/launch_slurm.sh" + [ -f "$SCRIPT_PATH" ] || SCRIPT_PATH="$SELF" + REPO=$(cd "$(dirname "$SCRIPT_PATH")/.." && pwd) + + mkdir -p "$SCRATCH" 2>/dev/null || true + LOG=${LOG:-$SCRATCH/yambda_slurm.${SLURM_JOB_ID:-manual}.log} + + # Smoke defaults — override via env for a perf run (see header USAGE). + MODE=${MODE:-streaming-train-eval} + START_TS=${START_TS:-150} + NUM_TRAIN_TS=${NUM_TRAIN_TS:-1} + NUM_TRAIN_BATCHES=${NUM_TRAIN_BATCHES:-20} + NUM_EVAL_BATCHES=${NUM_EVAL_BATCHES:-10} + EVAL_EACH_WINDOW=${EVAL_EACH_WINDOW:-1} + # Default eval cadence: per-window OFF (0), data-fraction every 0.5% of data + # (0.005). The two are mutually exclusive (both >0 raises a ValueError). + EVAL_EVERY_N_WINDOWS=${EVAL_EVERY_N_WINDOWS:-0} + EVAL_EVERY_DATA_PCT=${EVAL_EVERY_DATA_PCT:-0.005} + METRIC_LOG_FREQ=${METRIC_LOG_FREQ:-5} + FORCE_PROVISION=${FORCE_PROVISION:-0} + + # Truncate the metrics log on a FRESH run; APPEND on a supervised relaunch + # (APPEND_LOG=1) so the full-run NE/AUC history survives crash/node-failover + # resubmits instead of being wiped on every attempt (mirrors the single-node + # supervisor's init-once/append model). + if [ "${APPEND_LOG:-0}" = "1" ]; then + echo "[$(date)] === resume: appending to existing $LOG (APPEND_LOG=1) ===" >> "$LOG" + else + : > "$LOG" + fi + echo "[$(date)] launch_slurm/orchestrate: job=${SLURM_JOB_ID:-?} nodes=${SLURM_JOB_NODELIST:-?} nnodes=${SLURM_NNODES:-1}" | tee -a "$LOG" + echo "[$(date)] resolved SCRIPT_PATH=$SCRIPT_PATH REPO=$REPO" | tee -a "$LOG" + echo "[$(date)] config: MODE=$MODE START_TS=$START_TS NUM_TRAIN_TS=$NUM_TRAIN_TS NUM_TRAIN_BATCHES=$NUM_TRAIN_BATCHES NUM_EVAL_BATCHES=$NUM_EVAL_BATCHES METRIC_LOG_FREQ=$METRIC_LOG_FREQ" | tee -a "$LOG" + + # Rendezvous resolved on the HOST (the container image has no SLURM client). + MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" 2>/dev/null | head -1) + MASTER_ADDR=${MASTER_ADDR:-localhost} + MASTER_PORT=$(( 20000 + ${SLURM_JOB_ID:-0} % 20000 )) + echo "[$(date)] rendezvous: MASTER_ADDR=$MASTER_ADDR MASTER_PORT=$MASTER_PORT" | tee -a "$LOG" + + # Optional NCCL/RCCL fabric overrides — forwarded into the container only when + # set at submit time (docker exec does NOT inherit the srun task env). The + # worker phase applies its own validated multi-node bnxt_re defaults when these + # are unset. Common: NCCL_NET_TRANSPORT=socket (TCP fallback), NCCL_DEBUG=INFO. + NCCL_ENV_ARGS="" + for v in NCCL_NET_TRANSPORT NCCL_DEBUG NCCL_SOCKET_IFNAME NCCL_IB_HCA NCCL_IB_GID_INDEX \ + NCCL_IB_TC NCCL_IB_TIMEOUT NCCL_IGNORE_CPU_AFFINITY RCCL_MSCCL_ENABLE NCCL_NET_GDR_LEVEL \ + NCCL_IB_PCI_RELAXED_ORDERING NCCL_IB_USE_INLINE NCCL_IB_QPS_PER_CONNECTION \ + NCCL_IB_ECE_ENABLE NCCL_DMABUF_ENABLE NCCL_GDRCOPY_ENABLE NCCL_GDR_FLUSH_DISABLE \ + NCCL_PXN_DISABLE NCCL_CHECKS_DISABLE NCCL_CROSS_NIC RDMA_OVERLAY; do + eval "val=\${$v:-}" + if [ -n "$val" ]; then NCCL_ENV_ARGS="$NCCL_ENV_ARGS -e $v=$val"; fi + done + + # TRICKY — variable expansion inside the `srun ... bash -c "..."` blocks below: + # the string is double-quoted, so PLAIN $VAR expands NOW on the batch host (e.g. + # $MASTER_ADDR, $CONTAINER, $SCRIPT_PATH — values computed above), while + # BACKSLASH-escaped \$VAR is passed through literally and expands LATER on each + # compute node inside the srun task (e.g. \$SLURM_NODEID, \$(hostname)) where the + # per-node SLURM_* env actually lives. Mixing these up sends every rank the + # wrong node id or breaks the docker exec — keep the \$ on per-node values. + + # --- step 1: ensure the container is up on every node ---------------------- + echo "[$(date)] ensuring container '$CONTAINER' on all nodes (force=$FORCE_PROVISION)" | tee -a "$LOG" + srun --ntasks-per-node=1 bash -c " + # Reap stale/foreign GPU containers from prior jobs BEFORE (re)provisioning. + # The node is allocated --exclusive, so any GPU container other than + # '$CONTAINER' is an orphan left by a previous job (its container outlives the + # SLURM allocation). We remove every such container that has GPU access + # (/dev/kfd or /dev/dri) — running OR stopped, whether or not it currently + # pins VRAM ('docker ps -aq' includes stopped ones) — since idle orphans can + # still hold device handles or wake up; leaked HBM from these has caused both + # OOMs and RCCL collective hangs. We deliberately SKIP non-GPU containers + # (e.g. 'k8s-node-services-*' and other cluster system services) so we don't + # disrupt node infrastructure. docker teardown lets the driver reclaim HBM. + for _c in \$(docker ps -aq 2>/dev/null); do + _nm=\$(docker inspect -f '{{.Name}}' \"\$_c\" 2>/dev/null | sed 's#^/##') + [ \"\$_nm\" = \"$CONTAINER\" ] && continue + _dev=\$(docker inspect -f '{{range .HostConfig.Devices}}{{.PathOnHost}} {{end}}' \"\$_c\" 2>/dev/null) + case \"\$_dev\" in + *kfd*|*dri*) + echo \"[\$(hostname)] reaping stale GPU container \$_nm (\$_c)\" + docker rm -f \"\$_c\" >/dev/null 2>&1 || true ;; + *) + echo \"[\$(hostname)] keeping non-GPU/system container \$_nm (\$_c)\" ;; + esac + done + # Reuse a STOPPED '$CONTAINER' (its installed deps persist in the container + # fs) instead of destructively re-provisioning from the base image + pip. + # Harmless no-op on a fresh node (no such container) -> falls through to + # provision below. Repo code is bind-mounted, so live edits are still picked up. + docker start $CONTAINER >/dev/null 2>&1 || true + if [ \"$FORCE_PROVISION\" = \"1\" ] || ! docker exec $CONTAINER true >/dev/null 2>&1; then + echo \"[\$(hostname)] (re)provisioning container\" + LAUNCH_SLURM_PHASE=provision CONTAINER=$CONTAINER IMAGE=$IMAGE \ + BAKED_IMAGE=$BAKED_IMAGE BAKED_TAR=$BAKED_TAR USE_BAKED=$USE_BAKED \ + BAKE_IMAGE=${BAKE_IMAGE:-0} RDMA_OVERLAY=$OVERLAY REPO=$REPO \ + REPO_MOUNT=$REPO_MOUNT DATA_MOUNT=$DATA_MOUNT SCRATCH=$SCRATCH bash $SCRIPT_PATH + else + # Container persists across jobs; the reap above only removes FOREIGN GPU + # containers, so our own '$CONTAINER' can still pin HBM via stray trainer + # ranks left by a prior OOM/crash (this caused repeated 'CUDA out of memory' + # on relaunch onto the same node). Restart it to kill every exec'd proc and + # let the driver reclaim HBM — cheap (keeps the installed deps in the + # container fs; NFS RDMA overlay also persists), no full re-provision. + echo \"[\$(hostname)] container already up — restarting to free any leaked HBM before launch\" + docker restart $CONTAINER >/dev/null 2>&1 || true + # Readiness gate: a bare 'docker exec true' can pass while the runtime is + # still settling, so the SUBSEQUENT (heavier) worker exec races the restart + # and dies with 'container is not running' / OCI 'setns' errors (observed on + # c07-08 and e08-08 -> the peer never joins rendezvous -> master 600s + # TCPStore timeout). Require State.Running=true AND a successful probe, then + # a short settle, before considering the container ready. + for _w in \$(seq 1 30); do + [ \"\$(docker inspect -f '{{.State.Running}}' $CONTAINER 2>/dev/null)\" = \"true\" ] \ + && docker exec $CONTAINER true >/dev/null 2>&1 && break + sleep 2 + done + sleep 2 + echo \"[\$(hostname)] container restarted (HBM reclaimed; running=\$(docker inspect -f '{{.State.Running}}' $CONTAINER 2>/dev/null))\" + fi + " 2>&1 | tee -a "$LOG" + + # --- step 2: launch the worker (trainer) inside the container on every node - + echo "[$(date)] launching trainer (worker phase) on all nodes" | tee -a "$LOG" + srun --ntasks-per-node=1 bash -c " + # Pre-flight readiness gate (per node): step 1 ran in a SEPARATE srun, so the + # container can still be settling here. Wait for State.Running=true + a probe + # before the worker exec so we don't race a just-restarted container. + for _w in \$(seq 1 30); do + [ \"\$(docker inspect -f '{{.State.Running}}' $CONTAINER 2>/dev/null)\" = \"true\" ] \ + && docker exec $CONTAINER true >/dev/null 2>&1 && break + [ \$_w -eq 1 ] && echo \"[\$(hostname)] worker pre-flight: waiting for container to be ready...\" + sleep 2 + done + # Retry wrapper: docker exec startup failures (rc 125 daemon 'container is not + # running', 126/127 OCI/setns 'exec failed') mean the container wasn't ready, + # NOT that the trainer ran and failed. Restart + re-gate + retry a few times. + # Any OTHER rc (the trainer actually started and exited) is propagated so the + # supervisor's resume-from-checkpoint logic owns real failures. + _wattempt=0 + while : ; do + _wattempt=\$((_wattempt+1)) + docker exec \ + -e LAUNCH_SLURM_PHASE=worker \ + -e SCRATCH=$SCRATCH \ + -e SLURM_NNODES=\$SLURM_NNODES \ + -e SLURM_NODEID=\$SLURM_NODEID \ + -e SLURM_PROCID=\$SLURM_PROCID \ + -e SLURM_JOB_NODELIST=\"\$SLURM_JOB_NODELIST\" \ + -e SLURM_JOB_ID=\$SLURM_JOB_ID \ + -e MASTER_ADDR=$MASTER_ADDR \ + -e MASTER_PORT=$MASTER_PORT \ + -e HSTU_HAMMER_KERNEL=${HSTU_HAMMER_KERNEL:-TRITON} \ + -e MODE=$MODE \ + -e START_TS=$START_TS \ + -e NUM_TRAIN_TS=$NUM_TRAIN_TS \ + -e EVAL_EVERY_N_WINDOWS=$EVAL_EVERY_N_WINDOWS \ + ${EVAL_EVERY_DATA_PCT:+-e EVAL_EVERY_DATA_PCT=$EVAL_EVERY_DATA_PCT} \ + -e NUM_TRAIN_BATCHES=$NUM_TRAIN_BATCHES \ + -e NUM_EVAL_BATCHES=$NUM_EVAL_BATCHES \ + ${DIE_AT_STEP:+-e DIE_AT_STEP=$DIE_AT_STEP} \ + -e METRIC_LOG_FREQ=$METRIC_LOG_FREQ \ + ${MLPERF_LOGGING:+-e MLPERF_LOGGING=$MLPERF_LOGGING} \ + ${MLPERF_TRAIN_LOSS_LOG_FREQ:+-e MLPERF_TRAIN_LOSS_LOG_FREQ=$MLPERF_TRAIN_LOSS_LOG_FREQ} \ + ${STREAMING_SHUFFLE_FRACTION:+-e STREAMING_SHUFFLE_FRACTION=$STREAMING_SHUFFLE_FRACTION} \ + ${STREAMING_SHUFFLE_SEED:+-e STREAMING_SHUFFLE_SEED=$STREAMING_SHUFFLE_SEED} \ + ${NUM_WORKERS:+-e NUM_WORKERS=$NUM_WORKERS} \ + ${PREFETCH_FACTOR:+-e PREFETCH_FACTOR=$PREFETCH_FACTOR} \ + ${DIAG_UNIQUE_EMB:+-e DIAG_UNIQUE_EMB=$DIAG_UNIQUE_EMB} \ + ${DIAG_EMB_STEPS:+-e DIAG_EMB_STEPS=$DIAG_EMB_STEPS} \ + ${OUTPUT_TRACE:+-e OUTPUT_TRACE=$OUTPUT_TRACE} \ + ${MIN_HISTORY:+-e MIN_HISTORY=$MIN_HISTORY} \ + ${SEED:+-e SEED=$SEED} \ + ${DENSE_LR:+-e DENSE_LR=$DENSE_LR} \ + ${SPARSE_LR:+-e SPARSE_LR=$SPARSE_LR} \ + ${GRAD_CLIP_NORM:+-e GRAD_CLIP_NORM=$GRAD_CLIP_NORM} \ + ${HSTU_NUM_LAYERS:+-e HSTU_NUM_LAYERS=$HSTU_NUM_LAYERS} \ + ${MAX_SEQ_LEN:+-e MAX_SEQ_LEN=$MAX_SEQ_LEN} \ + ${HISTORY_LENGTH:+-e HISTORY_LENGTH=$HISTORY_LENGTH} \ + ${BATCH_SIZE:+-e BATCH_SIZE=$BATCH_SIZE} \ + ${CKPT_TIME_INTERVAL_S:+-e CKPT_TIME_INTERVAL_S=$CKPT_TIME_INTERVAL_S} \ + ${KEEP_LAST_N:+-e KEEP_LAST_N=$KEEP_LAST_N} \ + ${IN_WINDOW_CKPT_FREQ:+-e IN_WINDOW_CKPT_FREQ=$IN_WINDOW_CKPT_FREQ} \ + ${CKPT_STEP_FREQ:+-e CKPT_STEP_FREQ=$CKPT_STEP_FREQ} \ + -e TRAIN_SPLIT_PERCENTAGE=${TRAIN_SPLIT_PERCENTAGE:-1.0} \ + -e AUC_THRESHOLD=${AUC_THRESHOLD:-1.0} \ + ${MLPERF_SUBMISSION_PLATFORM:+-e MLPERF_SUBMISSION_PLATFORM=$MLPERF_SUBMISSION_PLATFORM} \ + -e SPLIT_SALT=${SPLIT_SALT:-0} \ + -e EVAL_HOLDOUT_TS=${EVAL_HOLDOUT_TS:--1} \ + -e EVAL_HOLDOUT_NUM_WINDOWS=${EVAL_HOLDOUT_NUM_WINDOWS:-1} \ + ${WORKER_CMD:+-e WORKER_CMD=\"$WORKER_CMD\"} \ + ${RUN_NAME:+-e RUN_NAME=$RUN_NAME} \ + ${TENSORBOARD_LOG_PATH:+-e TENSORBOARD_LOG_PATH=$TENSORBOARD_LOG_PATH} \ + ${MLPERF_LOG_PATH:+-e MLPERF_LOG_PATH=$MLPERF_LOG_PATH} \ + ${CKPT_PATH:+-e CKPT_PATH=$CKPT_PATH} \ + ${SPARSE_A2A_FWD:+-e SPARSE_A2A_FWD=$SPARSE_A2A_FWD} \ + ${SPARSE_A2A_BWD:+-e SPARSE_A2A_BWD=$SPARSE_A2A_BWD} \ + ${QCOMM_LOWMEM_CODEC:+-e QCOMM_LOWMEM_CODEC=$QCOMM_LOWMEM_CODEC} \ + ${EMB_SHARDING_OVERRIDES:+-e EMB_SHARDING_OVERRIDES=$EMB_SHARDING_OVERRIDES} \ + ${EMB_PLACEMENT_OVERRIDES:+-e EMB_PLACEMENT_OVERRIDES=$EMB_PLACEMENT_OVERRIDES} \ + ${EMB_PLACEMENT:+-e EMB_PLACEMENT=$EMB_PLACEMENT} \ + ${PG_TIMEOUT_S:+-e PG_TIMEOUT_S=$PG_TIMEOUT_S} \ + ${TORCH_NCCL_TRACE_BUFFER_SIZE:+-e TORCH_NCCL_TRACE_BUFFER_SIZE=$TORCH_NCCL_TRACE_BUFFER_SIZE} \ + ${TORCH_NCCL_DUMP_ON_TIMEOUT:+-e TORCH_NCCL_DUMP_ON_TIMEOUT=$TORCH_NCCL_DUMP_ON_TIMEOUT} \ + ${TORCH_NCCL_TRACE_CPP_STACK:+-e TORCH_NCCL_TRACE_CPP_STACK=$TORCH_NCCL_TRACE_CPP_STACK} \ + ${TORCH_NCCL_DEBUG_INFO_TEMP_FILE:+-e TORCH_NCCL_DEBUG_INFO_TEMP_FILE=$TORCH_NCCL_DEBUG_INFO_TEMP_FILE} \ + -e LOG=$LOG \ + $NCCL_ENV_ARGS \ + $CONTAINER bash -lc 'cd $REPO && LAUNCH_SLURM_PHASE=worker bash scripts/launch_slurm.sh' + _wrc=\$? + if { [ \$_wrc -eq 125 ] || [ \$_wrc -eq 126 ] || [ \$_wrc -eq 127 ]; } && [ \$_wattempt -lt 5 ]; then + echo \"[\$(hostname)] worker exec failed to START (rc=\$_wrc, attempt \$_wattempt/5) — container not ready; restarting + retrying\" + docker restart $CONTAINER >/dev/null 2>&1 || true + for _w in \$(seq 1 30); do + [ \"\$(docker inspect -f '{{.State.Running}}' $CONTAINER 2>/dev/null)\" = \"true\" ] \ + && docker exec $CONTAINER true >/dev/null 2>&1 && break + sleep 2 + done + sleep 3 + continue + fi + exit \$_wrc + done + " 2>&1 | tee -a "$LOG" + rc=${PIPESTATUS[0]} + echo "[$(date)] launch_slurm/orchestrate finished rc=$rc" | tee -a "$LOG" + exit $rc +} + +# ============================================================================= +# PHASE: provision (compute-node host) +# ============================================================================= +provision() { + export PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:${PATH:-}" + DOCKER=$(command -v docker 2>/dev/null || true); DOCKER=${DOCKER:-/usr/bin/docker} + FBGEMM_WHL=${FBGEMM_WHL:-/apps/chcai/FBGEMM/fbgemm_gpu/dist/fbgemm_gpu_nightly_rocm-2026.6.2-cp312-cp312-linux_x86_64.whl} # [CLUSTER-SPECIFIC] gfx950/ROCm wheel + TORCH_IDX=${TORCH_IDX:-https://download.pytorch.org/whl/rocm7.2} # [CLUSTER-SPECIFIC] ROCm version index + echo "[provision] host=$(hostname) container=$CONTAINER docker=$DOCKER" + + # Resolve which image to run + whether deps must be installed. Prefer a pre-baked + # image (deps already installed) to skip the multi-GB torch download + pip / + # torchrec-from-git build on every fresh node: + # 1) baked image in this node's docker -> use it, skip deps + # 2) baked image tar on NFS -> docker load (local, no internet) + # 3) neither -> base image + pip (slow path, which + # can then be baked via BAKE_IMAGE=1) + NEED_DEPS=1 + RUN_IMAGE="$IMAGE" + if [ "$USE_BAKED" = "1" ]; then + if "$DOCKER" image inspect "$BAKED_IMAGE" >/dev/null 2>&1; then + echo "[provision] using baked image $BAKED_IMAGE (deps preinstalled, no download)" + RUN_IMAGE="$BAKED_IMAGE"; NEED_DEPS=0 + elif [ -f "$BAKED_TAR" ]; then + echo "[provision] loading baked image from $BAKED_TAR (local, no internet)..." + if "$DOCKER" load -i "$BAKED_TAR" >/dev/null 2>&1 && "$DOCKER" image inspect "$BAKED_IMAGE" >/dev/null 2>&1; then + RUN_IMAGE="$BAKED_IMAGE"; NEED_DEPS=0; echo "[provision] baked image loaded" + else + echo "[provision] WARNING: docker load failed; falling back to base-image + pip" + fi + fi + fi + if ! "$DOCKER" image inspect "$RUN_IMAGE" >/dev/null 2>&1; then + echo "[provision] pulling $RUN_IMAGE (this can take a while)..."; "$DOCKER" pull "$RUN_IMAGE" + fi + + echo "[provision] (re)starting container $CONTAINER from $RUN_IMAGE" + "$DOCKER" rm -f "$CONTAINER" >/dev/null 2>&1 || true + "$DOCKER" run -d --name "$CONTAINER" \ + --network=host --ipc=host --shm-size=64g \ + --device=/dev/kfd --device=/dev/dri --group-add video \ + `# [CLUSTER-SPECIFIC] AMD ROCm device passthrough; NVIDIA uses --gpus all / nvidia runtime` \ + --cap-add=SYS_PTRACE --cap-add=CAP_SYS_ADMIN --cap-add=IPC_LOCK \ + --ulimit memlock=-1:-1 --ulimit stack=67108864:67108864 \ + `# memlock=-1 is REQUIRED for RDMA QP memory registration — do not drop` \ + --security-opt seccomp=unconfined --privileged \ + -v "$REPO_MOUNT:$REPO_MOUNT" \ + -v "$DATA_MOUNT:$DATA_MOUNT" \ + `# shared-NFS bind mounts: repo home (REPO_MOUNT, rw) + dataset/build assets (DATA_MOUNT)` \ + -w "$REPO" \ + "$RUN_IMAGE" sleep infinity + + # --- RDMA userspace overlay for in-container RCCL (bnxt_re) ----------------- + # The image (rocm/primus, rdma-core 50/libbnxt_re-rdmav34) ships an OLDER RDMA + # userspace than the host kernel bnxt_re driver. The stock v34 provider faults + # RCCL's deep-queue create_qp (max_send_wr=256) against the newer kernel uapi + # -> "ibv_create_qp ... Bad address". Fix: stage the host's matched rdma-core + # (libibverbs v61 + libbnxt_re-rdmav59 + libnl) on NFS so the worker phase makes + # RCCL load it via LD_PRELOAD + LD_LIBRARY_PATH. The UNVERSIONED libibverbs.so + # symlink is essential (import torch pulls the unversioned soname; without it + # the lookup falls through to the container v34 lib and the fix regresses). + if [ "${FORCE_OVERLAY:-0}" != "1" ] && ls "$OVERLAY/lib/libibverbs/"libbnxt_re-rdmav*.so >/dev/null 2>&1 && [ -L "$OVERLAY/lib/libibverbs.so" ]; then + echo "[provision] host RDMA overlay already staged at $OVERLAY (shared NFS) — skipping" + else + echo "[provision] staging host RDMA userspace overlay -> $OVERLAY" + rm -rf "${OVERLAY}.tmp" 2>/dev/null + mkdir -p "${OVERLAY}.tmp/lib/libibverbs" "${OVERLAY}.tmp/libibverbs.d" + cp -L /usr/lib64/libibverbs.so.1 /usr/lib64/libnl-3.so.200 /usr/lib64/libnl-route-3.so.200 "${OVERLAY}.tmp/lib/" 2>/dev/null || true + ln -sf libibverbs.so.1 "${OVERLAY}.tmp/lib/libibverbs.so" + cp -L /usr/lib64/libibverbs/*.so "${OVERLAY}.tmp/lib/libibverbs/" 2>/dev/null || true + cp /etc/libibverbs.d/*.driver "${OVERLAY}.tmp/libibverbs.d/" 2>/dev/null || true + if ls "${OVERLAY}.tmp/lib/libibverbs/"libbnxt_re-rdmav*.so >/dev/null 2>&1; then + rm -rf "$OVERLAY" 2>/dev/null + mv "${OVERLAY}.tmp" "$OVERLAY" 2>/dev/null || { mkdir -p "$OVERLAY"; cp -a "${OVERLAY}.tmp/." "$OVERLAY/"; } + echo "[provision] host RDMA overlay staged: $(ls "$OVERLAY/lib/libibverbs" | wc -l) providers + libibverbs.so symlink" + else + echo "[provision] WARNING: host bnxt_re provider not found at /usr/lib64/libibverbs — multi-node RDMA will fail 'Bad address'; use NCCL_NET_TRANSPORT=socket" + fi + fi + + if [ "$NEED_DEPS" = "0" ]; then + echo "[provision] baked image — deps preinstalled; verifying imports only" + "$DOCKER" exec "$CONTAINER" bash -lc ' +python -c "import torch, fbgemm_gpu, torchrec, polars, xxhash, gin; print(\"imports OK,\", torch.__version__, torch.version.hip, torch.cuda.device_count(), \"gpus\")" +' || echo "[provision] WARNING: baked-image import smoke failed" + else + echo "[provision] installing recipe deps (base image, slow path)" + # Install misc deps FIRST, then pin the rocm torch stack + fbgemm + torchrec + # LAST with --no-deps so nothing pulls a CUDA torch over the rocm build. + "$DOCKER" exec "$CONTAINER" bash -lc ' +set -e +echo "=== native torch ==="; python -c "import torch;print(torch.__version__)" || true +echo "=== misc python deps ===" +pip install --no-cache-dir polars-u64-idx pyarrow pyyaml tqdm psutil numba xxhash gin-config \ + absl-py pandas tensorboard torchmetrics tensordict pyre-extensions iopath typing-inspect 2>&1 | tail -3 || true +echo "=== rocm torch stack (force, no-deps, LAST) ===" +pip install --force-reinstall --no-deps --index-url '"$TORCH_IDX"' \ + torch==2.12.0+rocm7.2 torchvision==0.27.0+rocm7.2 torchaudio==2.11.0+rocm7.2 +echo "=== fbgemm (local gfx950 wheel) ===" +pip install --force-reinstall --no-deps '"$FBGEMM_WHL"' +echo "=== torchrec v2026.06.01.00 (force, no-deps) ===" +pip install --force-reinstall --no-deps "git+https://github.com/pytorch/torchrec.git@v2026.06.01.00" +echo "=== import smoke ===" +python -c "import torch, fbgemm_gpu, torchrec, polars, xxhash, gin; print(\"imports OK,\", torch.__version__, torch.version.hip, torch.cuda.device_count(), \"gpus\")" +' + fi + + # --- one-time bake: snapshot the deps-installed container into a reusable image + # and save it to NFS so future nodes skip the download/pip path entirely. + if [ "${BAKE_IMAGE:-0}" = "1" ]; then + echo "[provision] baking: docker commit $CONTAINER -> $BAKED_IMAGE" + if "$DOCKER" commit "$CONTAINER" "$BAKED_IMAGE" >/dev/null; then + echo "[provision] saving $BAKED_IMAGE -> $BAKED_TAR (one-time, tens of GB)" + if "$DOCKER" save "$BAKED_IMAGE" -o "${BAKED_TAR}.tmp.$$" && mv -f "${BAKED_TAR}.tmp.$$" "$BAKED_TAR"; then + echo "[provision] bake done: $(ls -lh "$BAKED_TAR" 2>/dev/null | awk '{print $5}')" + else + echo "[provision] WARNING: docker save failed"; rm -f "${BAKED_TAR}.tmp.$$" 2>/dev/null + fi + else + echo "[provision] WARNING: docker commit failed" + fi + fi + echo "[provision] DONE" +} + +# ============================================================================= +# PHASE: worker (inside the container) +# ============================================================================= +worker() { + cd "$REPO_ROOT" + mkdir -p "$SCRATCH" 2>/dev/null || true + LOG=${LOG:-$SCRATCH/yambda_5b_8gpu.log} + # Append (not truncate): under the streaming-e2e supervisor a run may relaunch + # many times into the SAME $LOG; the supervisor initializes it once at run start. + # MLPerf compliance log (rank 0 writes it). Per-job filename so each standalone + # sbatch gets a clean log; the e2e supervisor pins MLPERF_LOG_PATH itself. + export MLPERF_LOG_PATH=${MLPERF_LOG_PATH:-$SCRATCH/mlperf/yambda_5b_mlperf.${SLURM_JOB_ID:-manual}.log} + echo "[$(date)] REPO_ROOT=$REPO_ROOT" | tee -a "$LOG" + + # polars-u64-idx (NOT stock polars) — yambda parquet's flat-explode overruns + # 32-bit row index. Reserved node has no outbound DNS, so install from a + # pre-staged tarball under /apps/chcai/. Override PIP_LOCAL_TGZ for other hosts. + PIP_LOCAL_TGZ=${PIP_LOCAL_TGZ:-/apps/chcai/pip_local_yambda.tgz} # [CLUSTER-SPECIFIC] shared-NFS path + PIP_LOCAL_DIR=${PIP_LOCAL_DIR:-/tmp/pip_local} + if [ ! -f "$PIP_LOCAL_DIR/lib/python3.12/site-packages/polars/__init__.py" ]; then + rm -rf "$PIP_LOCAL_DIR" + mkdir -p "$PIP_LOCAL_DIR" && tar xzf "$PIP_LOCAL_TGZ" -C "$(dirname "$PIP_LOCAL_DIR")" 2>&1 | tail -3 | tee -a "$LOG" + fi + + export PYTHONPATH="$PIP_LOCAL_DIR/lib/python3.12/site-packages:$REPO_ROOT:${PYTHONPATH:-}" + export HOME=${HOME:-/tmp} + echo "[$(date)] PYTHONPATH=$PYTHONPATH" | tee -a "$LOG" + python -c "import torch, fbgemm_gpu, torchrec, polars, xxhash, gin; print('imports OK,', torch.__version__, torch.cuda.device_count(),'gpus')" 2>&1 | tee -a "$LOG" + + export HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7} + export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7} + + # --- distributed topology --------------------------------------------------- + GPUS_PER_NODE=$(python -c "import torch; print(torch.cuda.device_count())") + # Multi-node when launched one-task-per-node under SLURM (SLURM_NNODES>1); + # otherwise fall through to legacy single-node defaults (localhost, node_rank 0). + if [ "${SLURM_NNODES:-1}" -gt 1 ] && [ -n "${SLURM_JOB_NODELIST:-}" ]; then + NNODES=${SLURM_NNODES} + NODE_RANK=${SLURM_NODEID:-${SLURM_PROCID:-0}} + # PREFER a MASTER_ADDR/PORT forwarded from the orchestrate phase (resolved on + # the host, which has scontrol); the container image carries no SLURM client. + if [ -z "${MASTER_ADDR:-}" ]; then + MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" 2>/dev/null | head -1) + MASTER_ADDR=${MASTER_ADDR:-localhost} + fi + MASTER_PORT=${MASTER_PORT:-$(( 20000 + ${SLURM_JOB_ID:-0} % 20000 ))} + else + NNODES=${NNODES:-1} + NODE_RANK=${NODE_RANK:-0} + # Single-node: all ranks live on THIS host, so rendezvous over loopback and + # do NOT use the SLURM hostname. On some nodes the hostname resolves to a + # non-routable per-GPU RoCE /31 (benic 192.168.x) address; using it makes the + # NCCL bootstrap fail with "No route to host". localhost is node-independent. + MASTER_ADDR=localhost + MASTER_PORT=${MASTER_PORT:-} # empty => train_ranker picks a free port + fi + export NNODES NODE_RANK GPUS_PER_NODE MASTER_ADDR MASTER_PORT + export WORLD_SIZE=$(( NNODES * GPUS_PER_NODE )) + echo "[$(date)] topology: nnodes=$NNODES node_rank=$NODE_RANK gpus_per_node=$GPUS_PER_NODE world_size=$WORLD_SIZE master=$MASTER_ADDR:${MASTER_PORT:-}" | tee -a "$LOG" + + # NCCL bootstrap NIC — pin for BOTH single- and multi-node. The container is + # --network=host so RCCL sees ALL host interfaces; if left to auto-detect, NCCL + # can pick a non-routable per-GPU RoCE /31 (benic* 192.168.x) link and fail + # bootstrap with "No route to host" (this is node-dependent: it happened to + # work on some nodes and not others, causing repetitive single-node init + # failures). Pinning the routable host NIC fixes it everywhere. + # [CLUSTER-SPECIFIC] routable host NIC for TCP bootstrap (find via `ip -br addr`). + export NCCL_SOCKET_IFNAME=${NCCL_SOCKET_IFNAME:-fenic0} + + # Multi-node additionally needs the RDMA data-plane (bnxt_re HCAs) configured; + # single-node uses intra-node P2P (XGMI/PCIe) so only the bootstrap NIC matters. + if [ "$NNODES" -gt 1 ]; then + NCCL_NET_TRANSPORT=${NCCL_NET_TRANSPORT:-ib} + if [ "$NCCL_NET_TRANSPORT" = "socket" ]; then + export NCCL_IB_DISABLE=1 + echo "[$(date)] NCCL: IB disabled — allreduce over TCP (fenic0). Functional, not RDMA-fast." | tee -a "$LOG" + else + # bnxt_re userspace provider ABI overlay (REQUIRED for RCCL). The stock v34 + # provider faults RCCL's create_qp (256 WRs) against the host kernel uapi + # ("Bad address"); the host v61/v59 set staged by the provision phase works. + # The libibverbs.so (UNVERSIONED) symlink + LD_PRELOAD are both required so + # the torch process maps ONLY the host lib (see provision phase comment). + if [ -e "$OVERLAY/lib/libibverbs.so.1" ]; then + [ -e "$OVERLAY/lib/libibverbs.so" ] || ln -sf libibverbs.so.1 "$OVERLAY/lib/libibverbs.so" 2>/dev/null || true + export LD_LIBRARY_PATH="$OVERLAY/lib:$OVERLAY/lib/libibverbs:${LD_LIBRARY_PATH:-}" + export LD_PRELOAD="$OVERLAY/lib/libibverbs.so.1${LD_PRELOAD:+:$LD_PRELOAD}" + echo "[$(date)] NCCL: bnxt_re provider overlay -> $OVERLAY (host rdma-core v61/v59; symlink+LD_PRELOAD so RCCL binds the host lib for QP creation)" | tee -a "$LOG" + else + echo "[$(date)] WARNING: RDMA overlay $OVERLAY missing — RCCL QP creation will fail 'Bad address' on stock v34 provider; set RDMA_OVERLAY or use NCCL_NET_TRANSPORT=socket" | tee -a "$LOG" + fi + # MINIMAL bnxt_re set PROVEN on these meta64 cv350 nodes (cmcknigh RCCL + # benchmarks + confirmed e2e here). NCCL_IB_TC=104 (RoCE lossless PFC class) + # is required; do NOT add the ionic-AINIC QPS/ECE/DMABUF block. + # [CLUSTER-SPECIFIC] RDMA HCA names (`ibv_devices`); other vendors => mlx5_*/ionic_* + export NCCL_IB_HCA=${NCCL_IB_HCA:-bnxt_re0,bnxt_re1,bnxt_re2,bnxt_re3,bnxt_re4,bnxt_re5,bnxt_re6,bnxt_re7} + export NCCL_IB_GID_INDEX=${NCCL_IB_GID_INDEX:-3} # [CLUSTER-SPECIFIC] RoCEv2 IPv4 GID idx (`show_gids`) + export NCCL_IB_TC=${NCCL_IB_TC:-104} # [CLUSTER-SPECIFIC] RoCE lossless/PFC traffic class + export NCCL_IB_TIMEOUT=${NCCL_IB_TIMEOUT:-14} + export NCCL_IGNORE_CPU_AFFINITY=${NCCL_IGNORE_CPU_AFFINITY:-1} + export RCCL_MSCCL_ENABLE=${RCCL_MSCCL_ENABLE:-0} + # GPU-Direct RDMA: ENABLED by default. The brcmrdma host kernel ships the + # inbox peer-memory client (`ib_register_peer_memory_client` in + # /proc/kallsyms), so RCCL does true GPU<->NIC DMA over bnxt_re instead of + # bouncing through host memory. Measured ~+22% throughput at 2 nodes + # (65.7%->79.8% weak-scaling efficiency) vs the old host-staged path. + # GDR_LEVEL=5 (most permissive) is required so GDR is used even when the GPU + # and NIC cross the CPU root complex. NCCL_DMABUF_ENABLE=1 is a harmless + # no-op here (kernel lacks CONFIG_DMABUF_MOVE_NOTIFY/CONFIG_PCI_P2PDMA, so + # peermem carries it). Enabling is non-fatal: if peermem is ever absent RCCL + # just logs "GDR 0" and falls back to host staging. Override with + # NCCL_NET_GDR_LEVEL=0 to force the legacy host-staged path. + export NCCL_NET_GDR_LEVEL=${NCCL_NET_GDR_LEVEL:-5} + export NCCL_DMABUF_ENABLE=${NCCL_DMABUF_ENABLE:-1} + echo "[$(date)] NCCL: RDMA over bnxt_re (GID idx ${NCCL_IB_GID_INDEX}, TC ${NCCL_IB_TC}, GDR_LEVEL=${NCCL_NET_GDR_LEVEL}, DMABUF=${NCCL_DMABUF_ENABLE}; meta64 bnxt_re config, validated)" | tee -a "$LOG" + fi + fi + export NCCL_DEBUG=${NCCL_DEBUG:-WARN} + export HSTU_HAMMER_KERNEL=${HSTU_HAMMER_KERNEL:-} + export PYTORCH_CUDA_ALLOC_CONF=${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True} + + # --- GPU clock sanity guard ------------------------------------------------- + # A leftover perf_determinism cap (half clock) silently slows every kernel ~1.9x. + # Log the perf level + a live sclk sample and try to restore boost (non-fatal). + if command -v rocm-smi >/dev/null 2>&1; then + echo "[$(date)] GPU perf-level check:" | tee -a "$LOG" + rocm-smi --showperflevel 2>/dev/null | grep -iE "GPU\[[0-9]+\]" | tee -a "$LOG" || true + if rocm-smi --showperflevel 2>/dev/null | grep -iqE "Performance Level: *(perf_determinism|manual|low)"; then + echo "[$(date)] WARNING: GPUs not in 'auto' perf level — attempting --setperflevel auto" | tee -a "$LOG" + rocm-smi --setperflevel auto 2>/dev/null | grep -iE "set to auto" | tee -a "$LOG" \ + || echo "[$(date)] WARNING: could not set perf level (no permission?). Run 'rocm-smi --setperflevel auto' on the HOST before benchmarking — clocks may be capped." | tee -a "$LOG" + fi + echo "[$(date)] sclk sample (GPU0):$(rocm-smi -d 0 --showclocks 2>/dev/null | grep -i 'sclk clock level' | sed -E 's/.*sclk clock level//')" | tee -a "$LOG" || true + fi + + # --- stray-trainer / leaked-VRAM guard ------------------------------------- + # The trainer runs via `docker exec` into a long-lived container, so its procs + # live in the container PID namespace, NOT the SLURM job cgroup. If a prior job + # OOM'd/crashed, a rank can leak and keep holding ~half of every GPU's VRAM, + # which persists across jobs (container survives) and guarantees the next + # attempt OOMs. Before launching, reap any pre-existing trainer procs (there + # should be none at this point) and wait for VRAM to drain. [g]-guard avoids + # self-match. Non-fatal. + if pgrep -f '[g]enerative_recommenders' >/dev/null 2>&1; then + echo "[$(date)] WARNING: leaked trainer procs found pre-launch — killing." | tee -a "$LOG" + pkill -9 -f '[g]enerative_recommenders' 2>/dev/null || true + for _i in $(seq 1 15); do + pgrep -f '[g]enerative_recommenders' >/dev/null 2>&1 || break + sleep 2 + done + sleep 5 # let the driver release VRAM after process exit + if command -v rocm-smi >/dev/null 2>&1; then + echo "[$(date)] post-cleanup GPU0 used GiB:$(rocm-smi --showmeminfo vram 2>/dev/null | awk -F: '/Used/{printf " %.0f", $3/1073741824; exit}')" | tee -a "$LOG" + fi + fi + + # WORKER_CMD override: run an arbitrary in-container command (e.g. an a2a/RCCL + # micro-benchmark) instead of the trainer, REUSING all the NCCL/RDMA/topology + # setup above so it exercises the exact transport the trainer uses. The + # supervisor never sets WORKER_CMD, so the training path is unchanged. + if [ -n "${WORKER_CMD:-}" ]; then + echo "[$(date)] WORKER_CMD override (WORLD_SIZE=$WORLD_SIZE): $WORKER_CMD" | tee -a "$LOG" + bash -lc "cd $REPO_ROOT && $WORKER_CMD" 2>&1 | tee -a "$LOG" + return + fi + + echo "[$(date)] launching train_ranker with WORLD_SIZE=$WORLD_SIZE" | tee -a "$LOG" + python -m generative_recommenders.dlrm_v3.train.train_ranker \ + --dataset yambda-5b --mode "${MODE:-streaming-train-eval}" 2>&1 | tee -a "$LOG" +} + +# ---- dispatch --------------------------------------------------------------- +case "$PHASE" in + orchestrate) orchestrate ;; + provision) provision ;; + worker) worker ;; + *) echo "launch_slurm.sh: unknown LAUNCH_SLURM_PHASE='$PHASE'" >&2; exit 2 ;; +esac diff --git a/recommendation_v4/scripts/run_streaming_e2e.sh b/recommendation_v4/scripts/run_streaming_e2e.sh new file mode 100755 index 000000000..c913bac7b --- /dev/null +++ b/recommendation_v4/scripts/run_streaming_e2e.sh @@ -0,0 +1,295 @@ +#!/bin/bash +# ============================================================================= +# run_streaming_e2e.sh — self-healing supervisor for a yambda-5b streaming +# train+eval run (sbatch-job level). Works for 1..N nodes. +# ============================================================================= +# +# WHAT IT SUPERVISES +# The run is an `sbatch [--nodes=N] scripts/launch_slurm.sh` BATCH job. That +# batch script is fully self-contained: it runs orchestrate -> provision +# (container + RDMA) -> worker (in-container trainer) on EVERY node, so it +# handles single-node (--nodes=1) and multi-node (world_size=8N) identically. +# This supervisor wraps THAT job: it monitors it and, on crash / node-failure +# / hang, RESUBMITS it (which resumes from the latest checkpoint via load +# auto-latest), bounded by --max-relaunch. There is no docker-exec lifecycle +# or in-place node failover here — node replacement is SLURM's job on resubmit. +# +# RESUME MODEL (why a resubmit "just works") +# The trainer checkpoints to $CKPT_PATH and on startup load_dmp_checkpoint +# auto-resolves to the highest-numbered subdir, restoring model+optimizer+RNG +# and skipping already-trained batches of a partial window. So resubmitting the +# SAME submit-script (same CKPT_PATH/LOG) continues from where it died. +# Resubmits set APPEND_LOG=1 so the metrics log is preserved across attempts. +# +# WHAT IT DETECTS (poll every --poll-s) +# * job left the queue -> read sacct State/ExitCode: +# COMPLETED+0 => run finished (success, exit 0) +# CANCELLED => user intent (stop, exit 0 — NOT our place to resubmit) +# FAILED/NODE_FAIL/TIMEOUT/OUT_OF_MEMORY/BOOT_FAIL/PREEMPTED => relaunch +# * hang watchdog: job RUNNING but LOG frozen >= --stall-s AND no trainer +# process alive on ANY node (cross-node pgrep) => scancel + relaunch. +# * disk guard before each (re)submit: require --min-free-gib on the ckpt vol. +# +# WHERE IT RUNS +# On the SLURM head node (NFS-mounted /home/chcai code + /apps/chcai +# ckpts/logs are visible here for squeue/sacct/df and the cross-node pgrep). +# +# USAGE +# # Submit a fresh job from the launch script, then supervise it: +# nohup bash scripts/run_streaming_e2e.sh \ +# --submit-script /apps/chcai/yambda_5b_e2e//launch_1node.sh \ +# --log /apps/chcai/yambda_5b_e2e//.log \ +# --ckpt-path /apps/chcai/yambda_5b_e2e//ckpts \ +# --run-name \ +# > /apps/chcai/yambda_5b_e2e//.supervisor.console.log 2>&1 & +# +# # Adopt an already-submitted job instead of submitting a new one: +# nohup bash scripts/run_streaming_e2e.sh --jobid 13235 \ +# --submit-script .../launch_2node.sh --log .../run.log \ +# --ckpt-path .../ckpts --run-name > .../console.log 2>&1 & +# +# The node count, partition, and reservation all live in the --submit-script's +# sbatch line (launch_1node.sh / launch_2node.sh / ...), not here. +# +# EXIT CODES +# 0 run completed (COMPLETED+0) or user-cancelled +# 1 exhausted --max-relaunch without completion (or submit failed) +# 3 disk guard tripped +# ============================================================================= +set -uo pipefail + +JOBID="" # adopt this job; empty => submit fresh +SUBMIT_SCRIPT="" +LOG="" +CKPT_PATH="" +RUN_NAME="yambda_5b_e2e" +CONTAINER=yambda_primus +MAX_RELAUNCH=50 +MIN_FREE_GIB=1200 +STALL_S=2400 # 40 min: comfortably exceeds a full-holdout eval + # window + a blocking ckpt save; only trips when + # the log is frozen AND no trainer proc is alive. +POLL_S=30 + +while [[ $# -gt 0 ]]; do + case $1 in + --jobid) JOBID="$2"; shift 2;; + --submit-script) SUBMIT_SCRIPT="$2"; shift 2;; + --log) LOG="$2"; shift 2;; + --ckpt-path) CKPT_PATH="$2"; shift 2;; + --run-name) RUN_NAME="$2"; shift 2;; + --container) CONTAINER="$2"; shift 2;; + --max-relaunch) MAX_RELAUNCH="$2"; shift 2;; + --min-free-gib) MIN_FREE_GIB="$2"; shift 2;; + --stall-s) STALL_S="$2"; shift 2;; + --poll-s) POLL_S="$2"; shift 2;; + *) echo "Unknown arg: $1"; exit 1;; + esac +done + +[[ -n "$SUBMIT_SCRIPT" && -f "$SUBMIT_SCRIPT" ]] || { echo "FATAL: --submit-script required and must exist"; exit 1; } +[[ -n "$LOG" ]] || { echo "FATAL: --log required"; exit 1; } + +SUP_LOG="${LOG%.log}.supervisor.log" +sup() { echo "[$(date '+%F %T')] [supervisor] $*" | tee -a "$SUP_LOG"; } + +# Is the job in the queue right now (single read)? +job_in_queue() { [[ -n "$(squeue -h -j "$1" -o '%T' 2>/dev/null | head -1)" ]]; } +job_state() { squeue -h -j "$1" -o '%T' 2>/dev/null | head -1; } + +# Is the job still active? squeue/the SLURM control plane can transiently return +# empty during an NFS/controller blip even though the job is alive (this once +# killed all supervisors at once: empty squeue -> sacct said RUNNING -> a bogus +# "relaunch"). So a SINGLE empty read is not trusted: re-check a few times before +# believing the job is really gone. +job_active() { + job_in_queue "$1" && return 0 + local k + for k in 1 2 3; do + sleep 10 + job_in_queue "$1" && return 0 + done + return 1 +} + +# Terminal State + ExitCode from accounting once the job has left the queue. +job_final() { sacct -j "$1" -X -n -o State,ExitCode 2>/dev/null | head -1 | tr -s ' '; } + +# sacct/SLURM states that mean the job is STILL ALIVE (not terminal). If we see +# one of these after the monitor loop exits, squeue lied (transient) — resume +# monitoring instead of relaunching (which could spawn a DUPLICATE job). +is_active_state() { + case "$1" in + RUNNING|PENDING|CONFIGURING|COMPLETING|REQUEUED|RESIZING|SUSPENDED|REQUEUE_HOLD|REQUEUE_FED|SIGNALING|STAGE_OUT) return 0;; + *) return 1;; + esac +} + +# Any trainer process alive on ANY node of the allocation? (cross-node pgrep via +# overlap srun into each node's container). [g]enerative self-match guard avoids +# pgrep matching its own command line. +trainer_alive() { + local jid="$1" n + n=$(timeout 70 srun --jobid="$jid" --overlap --ntasks-per-node=1 bash -c \ + "docker exec $CONTAINER bash -lc 'set -f; pgrep -f [g]enerative_recommenders | wc -l' 2>/dev/null" 2>/dev/null \ + | awk '{s+=$1} END{print s+0}') + [[ "${n:-0}" -gt 0 ]] +} + +# Free GiB on the ckpt volume (NFS is mounted on this head node, so df locally). +disk_free_gib() { + df -BG --output=avail "$CKPT_PATH" 2>/dev/null | tail -1 | tr -dc '0-9' +} + +disk_guard() { + [[ -z "$CKPT_PATH" ]] && return 0 + local free; free=$(disk_free_gib); free=${free:-0} + sup "disk guard: ${free} GiB free on $CKPT_PATH (min ${MIN_FREE_GIB})" + if (( free < MIN_FREE_GIB )); then + sup "FATAL: insufficient free space (${free} < ${MIN_FREE_GIB} GiB). Aborting." + return 1 + fi + return 0 +} + +# Resubmit the run; resumes from latest checkpoint. APPEND_LOG=1 preserves the +# metrics log. Echoes the new jobid. +resubmit() { + local out newjid + out=$(APPEND_LOG=1 bash "$SUBMIT_SCRIPT" 2>&1) + newjid=$(echo "$out" | grep -oE 'Submitted batch job [0-9]+' | grep -oE '[0-9]+' | head -1) + echo "$out" | sed 's/^/ /' >> "$SUP_LOG" + echo "$newjid" +} + +# Submit with retries+backoff. A transient NFS / control-plane error (e.g. +# "sbatch: error: ... I/O error writing script/environment to file") must NOT +# kill the supervisor — it leaves runs unsupervised / unlaunched. Echoes a jobid +# on success, or empty after all retries. +submit_retry() { + local cand sub_try + for sub_try in $(seq 1 12); do + cand=$(resubmit) + if [[ "$cand" =~ ^[0-9]+$ ]]; then echo "$cand"; return 0; fi + sup "submit attempt $sub_try/12 failed (transient sbatch/NFS error) — backing off." + sleep $(( sub_try < 5 ? 30 : 120 )) + done + return 1 +} + +sup "=== streaming e2e supervisor start ===" +sup "run=$RUN_NAME submit=$SUBMIT_SCRIPT log=$LOG ckpt=$CKPT_PATH" +sup "max_relaunch=$MAX_RELAUNCH min_free_gib=$MIN_FREE_GIB stall_s=$STALL_S poll_s=$POLL_S" + +attempt=0 +if [[ -z "$JOBID" ]]; then + if ! disk_guard; then exit 3; fi + sup "no --jobid given; submitting a fresh job" + JOBID=$(submit_retry) + [[ "$JOBID" =~ ^[0-9]+$ ]] || { sup "FATAL: could not submit after 12 retries — aborting."; exit 1; } +fi +attempt=1 +sup "supervising jobid=$JOBID (attempt $attempt/$MAX_RELAUNCH)" + +while (( attempt <= MAX_RELAUNCH )); do + # --- wait for the job to be schedulable / running --- + wait_pend=0 + while job_active "$JOBID" && [[ "$(job_state "$JOBID")" != "RUNNING" ]]; do + (( wait_pend % 10 == 0 )) && sup "job $JOBID state=$(job_state "$JOBID") — waiting to run…" + sleep "$POLL_S"; wait_pend=$((wait_pend + 1)) + done + [[ "$(job_state "$JOBID")" == "RUNNING" ]] && sup "job $JOBID RUNNING on $(squeue -h -j "$JOBID" -o '%N' 2>/dev/null | head -1)" + + # --- monitor loop --- + last_size=0; stall_accum=0; hb=0; self_cancelled=0 + while job_active "$JOBID"; do + st=$(job_state "$JOBID") + if [[ "$st" == "RUNNING" ]]; then + cur_size=$(stat -c %s "$LOG" 2>/dev/null || echo 0) + if [[ "$cur_size" == "$last_size" ]]; then + # frozen log: only count as a stall if no trainer proc is alive + # (a long eval / blocking save keeps the process up -> not a stall) + hb=$((hb + 1)) + if (( hb % 4 == 0 )); then + if trainer_alive "$JOBID"; then + stall_accum=0 + else + stall_accum=$((stall_accum + POLL_S * 4)) + sup "log frozen + no trainer alive (${stall_accum}s/${STALL_S}s)" + if (( stall_accum >= STALL_S )); then + sup "STALL: hung run — scancel $JOBID and relaunch." + self_cancelled=1 + scancel "$JOBID" 2>/dev/null || true + sleep 20 + break + fi + fi + fi + else + stall_accum=0; last_size=$cur_size + fi + fi + sleep "$POLL_S" + done + + # --- job has left the queue (or we scancel'd it): decide --- + sleep 5 + final=$(job_final "$JOBID") + state=$(echo "$final" | awk '{print $1}') + code=$(echo "$final" | awk '{print $2}') + sup "job $JOBID ended: state='${state:-?}' exit='${code:-?}'" + + # The monitor loop only exits when squeue has been empty across several + # confirming reads. If accounting STILL reports an active state, the job is + # actually alive (squeue/control-plane blip) — resume monitoring rather than + # relaunching, which would create a duplicate job. + if is_active_state "$state"; then + sup "sacct reports still-active state '$state' — transient squeue blip; resuming monitoring (NOT relaunching)." + sleep "$POLL_S" + continue + fi + + case "$state" in + COMPLETED) + if [[ "$code" == "0:0" ]]; then + sup "RUN COMPLETED CLEANLY on attempt $attempt." + sup "=== supervisor done (success) ===" + exit 0 + fi + sup "COMPLETED but nonzero exit ($code) — relaunching." + ;; + CANCELLED*) + if (( self_cancelled )); then + sup "job CANCELLED by our own stall recovery — relaunching from latest checkpoint." + else + sup "job CANCELLED (user/admin intent) — NOT resubmitting. Stopping supervisor." + sup "=== supervisor done (cancelled) ===" + exit 0 + fi + ;; + FAILED|NODE_FAIL|TIMEOUT|OUT_OF_MEMORY|BOOT_FAIL|PREEMPTED|"") + sup "failure state '${state:-unknown}' — will relaunch from latest checkpoint." + ;; + *) + sup "unrecognized terminal state '${state}' — relaunching to be safe." + ;; + esac + + if (( attempt >= MAX_RELAUNCH )); then break; fi + if ! disk_guard; then exit 3; fi + sleep $(( attempt < 5 ? 20 : 60 )) # small backoff + # Resubmit with retries. A transient NFS / control-plane error (e.g. + # "sbatch: error: Batch job submission failed: I/O error writing + # script/environment to file") must NOT kill the supervisor — that once + # left a live run permanently unsupervised. Retry with backoff first. + JOBID=$(submit_retry) + if ! [[ "$JOBID" =~ ^[0-9]+$ ]]; then + sup "FATAL: resubmit failed after 12 retries — aborting."; exit 1 + fi + attempt=$((attempt + 1)) + sup "relaunched as jobid=$JOBID (attempt $attempt/$MAX_RELAUNCH)" +done + +sup "FATAL: exhausted MAX_RELAUNCH=$MAX_RELAUNCH without completion." +sup "=== supervisor done (failure) ===" +exit 1 diff --git a/recommendation_v4/scripts/run_streaming_e2e_local.sh b/recommendation_v4/scripts/run_streaming_e2e_local.sh new file mode 100755 index 000000000..85a817c8c --- /dev/null +++ b/recommendation_v4/scripts/run_streaming_e2e_local.sh @@ -0,0 +1,200 @@ +#!/bin/bash +# ============================================================================= +# run_streaming_e2e_local.sh — self-healing supervisor for a SINGLE-HOST +# (NON-SLURM) yambda-5b streaming train+eval run. Local analog of +# scripts/run_streaming_e2e.sh (the SLURM/sbatch supervisor). +# +# WHAT IT SUPERVISES +# The "job" is one foreground run of --submit-script (default +# scripts/launch_e2e_local.sh), which `docker exec`s the trainer in the +# container. The supervisor runs that submit-script in the BACKGROUND so: +# * its host PID == liveness (kill -0 / hang watchdog), and +# * `wait $PID` == the trainer's EXIT CODE (success vs. failure). +# On crash / nonzero-exit / hang it RELAUNCHES the same submit-script (same +# $CKPT_PATH/$LOG → resumes from the latest checkpoint), bounded by +# --max-relaunch. This is the SLURM supervisor's sacct/squeue/scancel control +# plane re-expressed with a local process + `docker exec` lifecycle. +# +# WHAT IT DETECTS (poll every --poll-s) +# * submit-script process exits -> read its exit code: +# 0 => run finished cleanly (success) +# != 0 => crash/OOM/die_at_step(42)/etc. => relaunch from latest ckpt +# * hang watchdog: process alive but $LOG frozen >= --stall-s AND no trainer +# process alive in the container (pgrep via docker exec) => kill + pkill in +# container + relaunch. A long eval / blocking ckpt save keeps the trainer +# process up, so it is NOT counted as a stall. +# * disk guard before each (re)launch: require --min-free-gib on the ckpt vol. +# +# USAGE +# nohup bash scripts/run_streaming_e2e_local.sh \ +# --submit-script scripts/launch_e2e_local.sh \ +# --log /home/chcai/yambda_5b_e2e//.log \ +# --ckpt-path /home/chcai/yambda_5b_e2e//ckpts \ +# --run-name \ +# > /home/chcai/yambda_5b_e2e//.supervisor.console.log 2>&1 & +# +# Per-run hyperparameters live in the --submit-script's env defaults (or are +# exported before invoking this supervisor), not here. +# +# EXIT CODES +# 0 run completed cleanly +# 1 exhausted --max-relaunch without completion (or launch failed) +# 3 disk guard tripped +# ============================================================================= +set -uo pipefail + +SUBMIT_SCRIPT="scripts/launch_e2e_local.sh" +LOG="" +CKPT_PATH="" +RUN_NAME="yambda_5b_e2e_local" +CONTAINER=${CONTAINER:-yambda_local} +DOCKER=${DOCKER:-sudo docker} +MAX_RELAUNCH=50 +MIN_FREE_GIB=700 # one full DMP ckpt (~600 GB) + headroom for the atomic + # .tmp written beside the retained one during a save. +STALL_S=2400 # 40 min frozen-log + no-trainer-proc => hung. +POLL_S=30 + +while [[ $# -gt 0 ]]; do + case $1 in + --submit-script) SUBMIT_SCRIPT="$2"; shift 2;; + --log) LOG="$2"; shift 2;; + --ckpt-path) CKPT_PATH="$2"; shift 2;; + --run-name) RUN_NAME="$2"; shift 2;; + --container) CONTAINER="$2"; shift 2;; + --docker) DOCKER="$2"; shift 2;; + --max-relaunch) MAX_RELAUNCH="$2"; shift 2;; + --min-free-gib) MIN_FREE_GIB="$2"; shift 2;; + --stall-s) STALL_S="$2"; shift 2;; + --poll-s) POLL_S="$2"; shift 2;; + *) echo "Unknown arg: $1"; exit 1;; + esac +done + +[[ -n "$SUBMIT_SCRIPT" && -f "$SUBMIT_SCRIPT" ]] || { echo "FATAL: --submit-script required and must exist ($SUBMIT_SCRIPT)"; exit 1; } +[[ -n "$LOG" ]] || { echo "FATAL: --log required"; exit 1; } + +SUP_LOG="${LOG%.log}.supervisor.log" +# Create the log + ckpt dirs up front so the disk guard's df has a real path to +# stat (df on a nonexistent dir returns 0 avail -> false "disk full" abort). +mkdir -p "$(dirname "$SUP_LOG")" +[[ -n "$CKPT_PATH" ]] && mkdir -p "$CKPT_PATH" +sup() { echo "[$(date '+%F %T')] [supervisor] $*" | tee -a "$SUP_LOG"; } + +# Any trainer process alive in the container? [g]enerative self-match guard +# avoids pgrep matching its own command line. +trainer_alive() { + local n + n=$($DOCKER exec "$CONTAINER" bash -lc 'set -f; pgrep -f "[g]enerative_recommenders" | wc -l' 2>/dev/null | tr -dc '0-9') + [[ "${n:-0}" -gt 0 ]] +} + +# Hard-kill any trainer processes left in the container AND wait for GPU HBM to +# actually drain before returning. A rank stuck in a HIP/RCCL collective sits in +# uninterruptible D-state and keeps its multi-hundred-GB embedding shard resident +# for many seconds after SIGKILL; relaunching before that frees makes the next +# attempt OOM on dirty GPUs (an OOM-crash -> dirty-GPU -> OOM cascade). So kill, +# then poll rocm-smi until every GPU is <5 GB (or give up after ~120s). +cleanup_container() { + $DOCKER exec "$CONTAINER" bash -lc \ + 'pkill -9 -f generative_recommenders 2>/dev/null; pkill -9 -f spawn_main 2>/dev/null; pkill -9 -f resource_tracker 2>/dev/null; true' \ + 2>/dev/null || true + local k busy + for k in $(seq 1 24); do # up to ~120s + busy=$($DOCKER exec "$CONTAINER" bash -lc \ + "rocm-smi --showmeminfo vram 2>/dev/null | awk '/Used/{if (\$NF+0 > 5e9) c++} END{print c+0}'" \ + 2>/dev/null | tr -dc '0-9') + busy=${busy:-0} + [[ "$busy" == "0" ]] && return 0 + sup "waiting for GPU HBM to drain ($busy GPU(s) still >5GB)…" + $DOCKER exec "$CONTAINER" bash -lc 'pkill -9 -f spawn_main 2>/dev/null; true' 2>/dev/null || true + sleep 5 + done + sup "WARNING: GPUs still show residual HBM after 120s — launching anyway." + return 0 +} + +disk_free_gib() { df -BG --output=avail "$CKPT_PATH" 2>/dev/null | tail -1 | tr -dc '0-9'; } + +disk_guard() { + [[ -z "$CKPT_PATH" ]] && return 0 + local free; free=$(disk_free_gib); free=${free:-0} + sup "disk guard: ${free} GiB free on $CKPT_PATH (min ${MIN_FREE_GIB})" + if (( free < MIN_FREE_GIB )); then + sup "FATAL: insufficient free space (${free} < ${MIN_FREE_GIB} GiB). Aborting." + return 1 + fi + return 0 +} + +# Run the submit-script in the FOREGROUND (its exit status == the trainer's). +# APPEND_LOG=1 preserves the metrics log across relaunches. This is invoked as +# `launch & PID=$!` from the main loop so the backgrounded copy is a DIRECT child +# of this shell — otherwise `wait $PID` can't reap it and always returns 127, +# making every clean completion look like a failure (infinite relaunch loop). +launch() { + APPEND_LOG=1 CONTAINER="$CONTAINER" DOCKER="$DOCKER" \ + RUN_NAME="$RUN_NAME" LOG="$LOG" CKPT_PATH="$CKPT_PATH" \ + bash "$SUBMIT_SCRIPT" >>"$SUP_LOG" 2>&1 +} + +sup "=== streaming e2e LOCAL supervisor start ===" +sup "run=$RUN_NAME submit=$SUBMIT_SCRIPT log=$LOG ckpt=$CKPT_PATH container=$CONTAINER" +sup "max_relaunch=$MAX_RELAUNCH min_free_gib=$MIN_FREE_GIB stall_s=$STALL_S poll_s=$POLL_S" + +attempt=1 +while (( attempt <= MAX_RELAUNCH )); do + if ! disk_guard; then exit 3; fi + sup "launching attempt $attempt/$MAX_RELAUNCH" + cleanup_container # ensure no stragglers from a prior attempt + launch & PID=$! # direct child => wait $PID reaps the real rc + sup "submit-script running as host pid=$PID" + + # --- monitor loop --- + last_size=0; stall_accum=0; hb=0; hung=0 + while kill -0 "$PID" 2>/dev/null; do + cur_size=$(stat -c %s "$LOG" 2>/dev/null || echo 0) + if [[ "$cur_size" == "$last_size" ]]; then + hb=$((hb + 1)) + # Re-check liveness only every 4 polls (cheap docker exec amortized). + if (( hb % 4 == 0 )); then + if trainer_alive; then + stall_accum=0 + else + stall_accum=$((stall_accum + POLL_S * 4)) + sup "log frozen + no trainer alive (${stall_accum}s/${STALL_S}s)" + if (( stall_accum >= STALL_S )); then + sup "STALL: hung run — killing pid=$PID + container trainer procs, will relaunch." + hung=1 + kill -9 "$PID" 2>/dev/null || true + cleanup_container + break + fi + fi + fi + else + stall_accum=0; last_size=$cur_size + fi + sleep "$POLL_S" + done + + # --- the submit-script has exited (or we killed it): decide --- + wait "$PID" 2>/dev/null; rc=$? + if (( hung )); then + sup "attempt $attempt ended via STALL recovery (rc=$rc) — relaunching from latest checkpoint." + elif (( rc == 0 )); then + sup "RUN COMPLETED CLEANLY on attempt $attempt." + sup "=== supervisor done (success) ===" + exit 0 + else + sup "attempt $attempt exited rc=$rc (crash/OOM/die_at_step) — relaunching from latest checkpoint." + fi + + if (( attempt >= MAX_RELAUNCH )); then break; fi + sleep $(( attempt < 5 ? 20 : 60 )) # small backoff + attempt=$((attempt + 1)) +done + +sup "FATAL: exhausted MAX_RELAUNCH=$MAX_RELAUNCH without completion." +sup "=== supervisor done (failure) ===" +exit 1 diff --git a/recommendation_v4/scripts/stitch_traces.py b/recommendation_v4/scripts/stitch_traces.py new file mode 100644 index 000000000..54d7963d6 --- /dev/null +++ b/recommendation_v4/scripts/stitch_traces.py @@ -0,0 +1,329 @@ +#!/usr/bin/env python3 +"""Stitch per-rank Chrome traces from a dlrm_v3 run into one merged file. + +When ``Profiler`` runs on multiple ranks, each rank writes its own file: + + /trace_step{step}_rank{rank}.json + +Each per-rank trace uses overlapping ``pid`` namespaces (CPU pid = OS pid; +GPU streams pid = 0..N), so concatenating the raw event lists would collapse +multiple ranks onto the same Perfetto track. This script: + +* Identifies each pid as ``CPU`` / ``GPU`` / ``Spans`` (and other torch.profiler + string-pid tracks) using the per-rank ``process_labels`` metadata events. +* Always drops the ``Spans`` track (low-signal in this codebase, large in + visual clutter). +* Optionally filters to just ``cpu`` or ``gpu`` events via ``--include``. +* Sorts the surviving tracks into contiguous Perfetto sections: + **all CPU tracks (rank 0..N) first, then all GPU tracks (rank 0..N, stream + 0..K)**. +* Remaps every event's ``pid`` and flow ``id`` so cross-rank events never + collide on the same track or flow arrow. + +Because torch.profiler emits ``baseTimeNanoseconds`` from the same node clock, +timestamps line up directly across ranks — no time-shift needed for single-node +runs (multi-node would need clock-skew correction, not implemented here). + +Examples +-------- +Stitch step 52, default (CPU + GPU, drop Spans), gzip output:: + + python scripts/stitch_traces.py --step 52 --gzip + +GPU-only view (skip CPU thread tree entirely — useful for kernel-level analysis):: + + python scripts/stitch_traces.py --step 52 --include gpu --gzip + +CPU-only view (host-side ops, profiler annotations, comm scheduling):: + + python scripts/stitch_traces.py --step 52 --include cpu --gzip +""" +from __future__ import annotations + +import argparse +import gzip +import json +import re +import sys +from collections import defaultdict +from pathlib import Path + +# trace_step52_rank3.json or trace_3_rank0.json (legacy filename) +_RANK_RE = re.compile(r"trace_(?:step)?(\d+)_rank(\d+)\.json$") +_KEY_RE = re.compile(r"trace_(.+?)_rank\d+\.json$") + +# Per-rank pid offset. Picked large enough that no real OS pid collides +# (Linux pids fit in 22 bits; 1e6 per rank gives ~10 ranks of headroom). +_PID_STRIDE = 1_000_000 + +# Per-rank flow-id offset. torch.profiler flow ids are int32/int64 — pack rank +# into the high bits so cross-rank flows can never link by accident. +_FLOW_ID_STRIDE = 1 << 40 + +# Sort-index sections in Perfetto. Lower = appears higher in the timeline UI. +# Each section reserves a wide range so within-section ordering (rank, stream) +# fits comfortably without overlapping the next section. +_SORT_BASE = { + "cpu": 0, + "gpu": 1_000_000, + "other": 10_000_000, # Traces / "" misc string-pid tracks +} + +# `Spans` carries no useful content in our workloads (one X event per trace) +# and clutters the timeline — always dropped. +_ALWAYS_DROP_PIDS_STR = {"Spans"} + + +def _classify_pid(pid_to_label: dict, pid_to_name: dict) -> dict: + """Map original pid -> ('cpu'|'gpu'|'spans'|'other', stream_idx_or_0). + + Classification order, first match wins: + 1. pid (as a string) is in the always-drop set -> 'spans' + 2. process_name is in the always-drop set -> 'spans' + 3. process_labels == 'CPU' -> 'cpu' + 4. process_labels starts with 'GPU ' -> 'gpu', stream id + 5. anything else (including unlabeled pids) -> 'other' + """ + all_pids = set(pid_to_label) | set(pid_to_name) + out: dict = {} + for pid in all_pids: + label = pid_to_label.get(pid, "") + name = pid_to_name.get(pid, "") + if isinstance(pid, str) and pid in _ALWAYS_DROP_PIDS_STR: + out[pid] = ("spans", 0) + continue + if name in _ALWAYS_DROP_PIDS_STR: + out[pid] = ("spans", 0) + continue + if label == "CPU": + out[pid] = ("cpu", 0) + elif label.startswith("GPU"): + try: + stream_idx = int(label.split()[1]) + except (IndexError, ValueError): + stream_idx = 0 + out[pid] = ("gpu", stream_idx) + else: + out[pid] = ("other", 0) + return out + + +def _scan_pid_metadata(events: list[dict]) -> tuple[dict, dict]: + """First pass: collect per-pid label and name from ``ph='M'`` events.""" + label: dict = {} + name: dict = {} + for e in events: + if e.get("ph") != "M": + continue + pid = e.get("pid") + if pid is None: + continue + if e.get("name") == "process_labels": + label[pid] = e.get("args", {}).get("labels", "") + elif e.get("name") == "process_name": + name[pid] = e.get("args", {}).get("name", "") + return label, name + + +def _new_sort_index(kind: str, rank: int, stream_idx: int) -> int: + """Compute Perfetto sort_index so tracks group as: CPU(rank0..N), GPU(rank0..N, stream0..K), other.""" + base = _SORT_BASE.get(kind, _SORT_BASE["other"]) + return base + rank * 100 + stream_idx + + +def _new_pid(orig_pid, rank: int) -> object: + """Remap a single pid into a per-rank namespace, preserving int vs str.""" + if isinstance(orig_pid, int): + return orig_pid + rank * _PID_STRIDE + if isinstance(orig_pid, str): + try: + return int(orig_pid) + rank * _PID_STRIDE + except ValueError: + return f"rank{rank}_{orig_pid}" if orig_pid else f"rank{rank}_misc" + return orig_pid + + +def _process_one_rank( + events: list[dict], + rank: int, + include: set[str], +) -> list[dict]: + """Filter + remap one rank's events. ``include`` is a subset of {'cpu','gpu','other'}.""" + label, name = _scan_pid_metadata(events) + classify = _classify_pid(label, name) + + out: list[dict] = [] + for e in events: + pid = e.get("pid") + if pid is None: + out.append(e) + continue + # Always-drop check on the raw pid value first - Spans events in our + # workloads have NO process_name/process_labels metadata, so the + # classifier table doesn't list them. Catch them here directly. + if isinstance(pid, str) and pid in _ALWAYS_DROP_PIDS_STR: + continue + kind, stream_idx = classify.get(pid, ("other", 0)) + if kind == "spans": # always dropped + continue + if kind not in include: # filtered by --include + continue + + # Remap pid + flow id (per-rank namespace). + e["pid"] = _new_pid(pid, rank) + if "id" in e and e.get("ph") in ("s", "t", "f"): + try: + e["id"] = int(e["id"]) + rank * _FLOW_ID_STRIDE + except (TypeError, ValueError): + pass + + # Rewrite metadata: section-aware sort_index + rank-prefixed name. + if e.get("ph") == "M": + args = e.setdefault("args", {}) + if e.get("name") == "process_sort_index": + args["sort_index"] = _new_sort_index(kind, rank, stream_idx) + elif e.get("name") == "process_name": + orig = args.get("name", "python") + args["name"] = f"[Rank {rank}] {orig}" + + out.append(e) + + return out + + +def _group_by_step(trace_dir: Path) -> dict[str, dict[int, Path]]: + """Map step-key (e.g. ``"step52"`` or ``"3"``) -> {rank: path}.""" + groups: dict[str, dict[int, Path]] = defaultdict(dict) + for p in sorted(trace_dir.glob("trace_*_rank*.json")): + m = _RANK_RE.search(p.name) + if not m: + continue + prefix_match = _KEY_RE.match(p.name) + key = prefix_match.group(1) if prefix_match else m.group(1) + groups[key][int(m.group(2))] = p + return dict(groups) + + +def stitch_one(rank_to_path: dict[int, Path], out_path: Path, *, + include: set[str], gzip_out: bool, verbose: bool) -> None: + """Merge one (step, rank->path) group into a single trace file.""" + merged_events: list[dict] = [] + base: dict | None = None + + for rank in sorted(rank_to_path): + path = rank_to_path[rank] + if verbose: + sz_mb = path.stat().st_size / (1 << 20) + print(f" rank {rank}: {path.name} ({sz_mb:.1f} MB)", file=sys.stderr) + with path.open() as f: + trace = json.load(f) + if base is None: + base = {k: v for k, v in trace.items() if k != "traceEvents"} + base["distributedInfo"] = { + **trace.get("distributedInfo", {}), + "stitched_ranks": sorted(rank_to_path), + "stitched_files": [p.name for p in rank_to_path.values()], + "stitched_include": sorted(include), + } + merged_events.extend( + _process_one_rank(trace.get("traceEvents", []), rank, include) + ) + + assert base is not None, "no input traces provided" + base["traceEvents"] = merged_events + + out_path.parent.mkdir(parents=True, exist_ok=True) + if gzip_out: + with gzip.open(out_path, "wt") as f: + json.dump(base, f) + else: + with out_path.open("w") as f: + json.dump(base, f) + if verbose: + sz_mb = out_path.stat().st_size / (1 << 20) + print( + f" -> {out_path} ({len(merged_events):,} events, {sz_mb:.1f} MB)", + file=sys.stderr, + ) + + +def main() -> int: + ap = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + ap.add_argument("trace_dir", type=Path, + help="Directory containing trace_*_rank*.json files.") + ap.add_argument("--step", type=str, default=None, + help="Stitch only the given step key (e.g. '52' or 'step52'). " + "Default: stitch every step group found.") + ap.add_argument("--out", type=Path, default=None, + help="Output path. Only valid when --step selects exactly " + "one group. Default: /trace_.json[.gz] " + "(or trace__cpu/_gpu when --include filters).") + ap.add_argument("--include", choices=("cpu", "gpu", "both"), default="both", + help="Which sections to keep: cpu-only tracks, gpu-only " + "tracks, or both (default). 'Spans' is always dropped.") + ap.add_argument("--gzip", action="store_true", + help="Write gzip-compressed JSON (Perfetto auto-detects).") + ap.add_argument("-q", "--quiet", action="store_true") + args = ap.parse_args() + + if not args.trace_dir.is_dir(): + print(f"error: {args.trace_dir} is not a directory", file=sys.stderr) + return 2 + + if args.include == "both": + # 'other' covers torch.profiler string-pid tracks (Traces / misc) that + # carry low-volume but legitimate annotations. Dropped under cpu/gpu + # so each filtered view is clean. + include = {"cpu", "gpu", "other"} + else: + include = {args.include} + + groups = _group_by_step(args.trace_dir) + if not groups: + print(f"error: no trace_*_rank*.json files under {args.trace_dir}", + file=sys.stderr) + return 2 + + if args.step is not None: + wanted = args.step if args.step.startswith("step") else f"step{args.step}" + if wanted not in groups and args.step in groups: + wanted = args.step + if wanted not in groups: + print( + f"error: step {args.step!r} not found. " + f"Available: {sorted(groups)}", + file=sys.stderr, + ) + return 2 + groups = {wanted: groups[wanted]} + + if args.out is not None and len(groups) != 1: + print("error: --out requires --step to select exactly one group", + file=sys.stderr) + return 2 + + for key, rank_map in sorted(groups.items()): + if not args.quiet: + print( + f"stitching {key} ({len(rank_map)} ranks, include={args.include}):", + file=sys.stderr, + ) + if args.out is not None: + out = args.out + else: + ext = ".json.gz" if args.gzip else ".json" + # Default mode ("both") gets the bare filename; explicit cpu/gpu + # filters tag the output so they can coexist in one directory. + suffix = "" if args.include == "both" else f"_{args.include}" + out = args.trace_dir / f"trace_{key}{suffix}{ext}" + stitch_one(rank_map, out, include=include, + gzip_out=args.gzip, verbose=not args.quiet) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/recommendation_v4/setup.py b/recommendation_v4/setup.py new file mode 100644 index 000000000..bdab528f4 --- /dev/null +++ b/recommendation_v4/setup.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-unsafe + +from setuptools import find_packages, setup + +setup( + name="generative_recommenders", + version="0.1.0", + description="Library for generative recommendation algorithms.", + packages=find_packages(exclude=["configs"]), + python_requires=">=3.10", + install_requires=[ + "torch>=2.6.0", + "fbgemm_gpu>=1.1.0", + "torchrec>=1.1.0", + "gin_config>=0.5.0", + "pandas>=2.2.0", + "tensorboard>=2.19.0", + "pybind11", + "click", + "pandas", + "matplotlib", + ], + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + url="https://github.com/meta-recsys/generative-recommenders", + license="Apache-2.0", +) diff --git a/recommendation_v4/verify_dataset.sh b/recommendation_v4/verify_dataset.sh new file mode 100755 index 000000000..839ccb91f --- /dev/null +++ b/recommendation_v4/verify_dataset.sh @@ -0,0 +1,64 @@ +#!/usr/bin/env bash +# MLPerf Training reference script: verify the preprocessed dataset. +# +# Checks the integrity of the preprocessed dataset under +# ${DLRM_DATA_PATH}/${PROCESSED_SUBDIR} +# against md5sums_yambda_5b_processed.txt (standard `md5sum -c` format). +# +# If the checksum file still contains placeholder hashes (TODO_GENERATE_HASH), +# the script falls back to an existence/layout check and warns that the +# canonical checksums have not been pinned yet. +# +# Usage: +# DLRM_DATA_PATH=/path/to/dlrm_data ./verify_dataset.sh +# +# Env: +# DLRM_DATA_PATH data root (required). +# PROCESSED_SUBDIR processed subdir under the data root (default: processed_5b). +set -euo pipefail + +: "${DLRM_DATA_PATH:?Set DLRM_DATA_PATH to the data root}" +PROCESSED_SUBDIR="${PROCESSED_SUBDIR:-processed_5b}" + +REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CHECKSUM_FILE="${REPO_ROOT}/md5sums_yambda_5b_processed.txt" +PROCESSED_DIR="${DLRM_DATA_PATH}/${PROCESSED_SUBDIR}" + +echo "[verify_dataset] processed dir: ${PROCESSED_DIR}" + +if [[ ! -d "${PROCESSED_DIR}" ]]; then + echo "[verify_dataset] ERROR: ${PROCESSED_DIR} does not exist. Run ./download_dataset.sh first." >&2 + exit 1 +fi + +EXPECTED_FILES=( + train_sessions.parquet + test_events.parquet + session_index.parquet + item_popularity.npy + split_meta.json +) + +# Detect whether the checksum file has real (32 hex char) hashes or placeholders. +if grep -qiE '^[0-9a-f]{32}[[:space:]]' "${CHECKSUM_FILE}"; then + echo "[verify_dataset] checking md5 checksums from ${CHECKSUM_FILE}" + (cd "${PROCESSED_DIR}" && md5sum -c "${CHECKSUM_FILE}") + echo "[verify_dataset] OK: all checksums match." +else + echo "[verify_dataset] WARNING: ${CHECKSUM_FILE} contains placeholder hashes;" >&2 + echo "[verify_dataset] falling back to existence/layout check only." >&2 + missing=0 + for f in "${EXPECTED_FILES[@]}"; do + if [[ -s "${PROCESSED_DIR}/${f}" ]]; then + echo " OK ${f}" + else + echo " MISS ${f}" >&2 + missing=1 + fi + done + if [[ "${missing}" -ne 0 ]]; then + echo "[verify_dataset] ERROR: one or more expected files are missing/empty." >&2 + exit 1 + fi + echo "[verify_dataset] layout OK (checksums NOT yet pinned -- see TODO in ${CHECKSUM_FILE})." +fi