Warning
This library is in an experimental phase and may be buggy. Expect breaking changes between versions.
Rixa (Runtime Initialization by pmiX Adoption) is a high-performance library that provides a unified and efficient way to bootstrap distributed PyTorch jobs. It leverages PMIx 5.0 to seamlessly launch PyTorch workloads on large-scale HPC clusters, eliminating the need to manually specify the master IP address and port.
Standard PyTorch bootstrapping (TCP/File-store) is built for cloud portability but often
struggles with scale and reliability on bare-metal HPC clusters.
rixa bypasses these overheads by using PMIx as a native high-performance key-value store, providing:
- Zero-config Launching: No master IP/Port orchestration required.
- HPC Native: Leverages existing Slurm, Flux, or OpenMPI environments.
- ABI Compatibility: Built on PMIx 5.0+ (ABI stable), ensuring portability across different MPI/PMIx versions.
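For contrast with zero-config launching, here is a minimal sketch of what PyTorch's standard `env://` rendezvous expects every process to have set before it calls `torch.distributed.init_process_group`. The helper name `env_rendezvous_settings` and the host and port values are hypothetical and only illustrate the manual orchestration that rixa is meant to remove.

```python
import os


def env_rendezvous_settings(master_addr, master_port, rank, world_size):
    """Build the environment variables read by PyTorch's env:// init method.

    Hypothetical helper for illustration; it is not part of rixa or PyTorch.
    """
    return {
        "MASTER_ADDR": master_addr,
        "MASTER_PORT": str(master_port),
        "RANK": str(rank),
        "WORLD_SIZE": str(world_size),
    }


if __name__ == "__main__":
    # Assumed values: each of the 16 processes needs the same address and port,
    # plus its own rank, before PyTorch can rendezvous.
    os.environ.update(env_rendezvous_settings("node001", 29500, 0, 16))
    # A standard job would now call:
    # torch.distributed.init_process_group(backend="gloo", init_method="env://")
```

With the standard approach, something (a batch script or a launcher) has to compute and export these values on every node.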
Note
For detailed installation instructions, see the wiki.
To install rixa, specify which variant of the library you need. Two variants are currently supported: pytorch and nvshmem.
They can be installed with
# Install for PyTorch support
pip install "rixa[pytorch]"
# Install for NVSHMEM support
pip install "rixa[nvshmem]"
Note
For more use cases, see the wiki.
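After installing, a quick sanity check can confirm which optional backends are importable in the current environment. The sketch below uses only the standard library. The helper name `available_backends` is hypothetical, and the top-level module names `torch` and `nvshmem` are assumptions about what each extra pulls in.

```python
from importlib.util import find_spec


def available_backends(modules=("rixa", "torch", "nvshmem")):
    """Report which top-level modules can be found, without importing them.

    Hypothetical helper; the module names are assumptions about what the
    pytorch and nvshmem extras install.
    """
    return {name: find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, found in available_backends().items():
        print(f"{name}: {'found' if found else 'missing'}")
```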
With rixa, a distributed PyTorch job can be started with a single line:
import rixa
rixa.pytorch.init_process_group(pytorch_argument1, pytorch_argument2, keyword2=pytorch_parameter)
Jobs can be launched with any PMIx 5.0-compatible launcher: prrte, some MPI implementations (e.g. OpenMPI 5.0), or the native PMIx plugins of job launchers such as SLURM or Flux. Example:
prterun -n 16 python3 -c "import rixa; rixa.pytorch.init_process_group(); import torch; print(torch.distributed.get_rank())"
Remember to manually finalize the backend!
torch.distributed.destroy_process_group()
NVSHMEM usage is very similar to the PyTorch usage: use the thin wrapper that rixa provides around the native NVSHMEM init.
Example:
import rixa
from cuda.core import Device
store = rixa.PMIxStore(30)  # manually specify the PMIx backend with an explicit timeout; can come in handy for setting the device
dev = Device(0)  # first device, or choose one based on the use case
rixa.nvshmem.init(dev, store)  # device, store
Remember to manually finalize the backend!
nvshmem.finalize()
- Support for NVSHMEM (pytorch, cupy)
- Support for JAX
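Because both backends must be finalized manually, one way to make sure cleanup runs even when the job raises is to pair the init and finalize calls in a context manager. This is a minimal sketch, not part of rixa: the name `initialized` is hypothetical, and the two callables are passed in so the pattern works for either backend.

```python
from contextlib import contextmanager


@contextmanager
def initialized(init, finalize):
    """Call init() on entry and finalize() on exit, even if the body raises.

    Hypothetical helper for illustration; rixa does not provide it.
    """
    init()
    try:
        yield
    finally:
        finalize()


# Possible usage with the PyTorch backend (requires rixa and torch):
#
# import rixa, torch
# with initialized(rixa.pytorch.init_process_group,
#                  torch.distributed.destroy_process_group):
#     print(torch.distributed.get_rank())
```

Note that if `init()` itself raises, `finalize()` is not called, which matches the usual expectation that a backend that never initialized needs no teardown.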