On an HPC cluster, each term in a mean squared error loss can be computed independently of the others, so the loss is embarrassingly parallel.
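Here is a minimal sketch of that per-term structure. It assumes a hypothetical dict-based stand-in for `Molecule` and a hypothetical `predict` model; neither is the project's actual API. Each squared error depends on one molecule only, which is why the terms could be farmed out to separate workers. In this sketch they are simply evaluated in a Python loop.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for Molecule objects: each one carries its own grid
# and 1-RDM, and the shapes differ from molecule to molecule.
molecules = [
    {"grid": jnp.ones((5, 3)), "rdm1": jnp.eye(2), "target": 1.0},
    {"grid": jnp.ones((8, 3)), "rdm1": jnp.eye(3), "target": 2.5},
]

def predict(params, grid, rdm1):
    # Hypothetical per-molecule model: any function of one molecule's data.
    return params * jnp.trace(rdm1) + jnp.mean(grid)

# jax.jit compiles a separate version for each distinct set of input shapes,
# so ragged molecules trigger one compilation per shape.
@jax.jit
def squared_error(params, mol):
    # One term of the loss; it depends on a single molecule only.
    return (predict(params, mol["grid"], mol["rdm1"]) - mol["target"]) ** 2

def mse_loss(params, mols):
    # The terms are independent, so each could run on a different worker;
    # here they are evaluated one after another.
    terms = [squared_error(params, m) for m in mols]
    return sum(terms) / len(terms)

loss = mse_loss(1.0, molecules)
```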
Unfortunately, the standard way of doing this in `jax` (using `jax.vmap` and `jax.pmap`) is not compatible with the input we must parallelize over: the `Molecule` object. Both transformations map over a shared leading array axis, but `Molecule` data is stored in a "ragged" structure. The dimensions of the grid for one molecule very often differ from those of another, and the same holds for the dimensions of the 1-RDM, so `jnp.array([rdm1_1, rdm1_2])` will not work.
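The failure can be reproduced with two illustrative identity matrices standing in for 1-RDMs of different sizes. The names `rdm1_1` and `rdm1_2` follow the text; `try_stack` is a hypothetical helper. The exact exception type and message vary across JAX versions, so the sketch catches both `ValueError` and `TypeError`.

```python
import jax.numpy as jnp

# Illustrative 1-RDMs for two molecules with different basis-set sizes.
rdm1_1 = jnp.eye(2)
rdm1_2 = jnp.eye(3)

def try_stack(arrays):
    """Return (stacked_array, None), or (None, error_message) if shapes are ragged."""
    try:
        return jnp.array(arrays), None
    except (ValueError, TypeError) as err:
        return None, str(err)

stacked, error = try_stack([rdm1_1, rdm1_2])
print(error)  # wording depends on the JAX version

# Same-shaped inputs stack into a (batch, n, n) array: one leading axis to
# map over, which is the layout jax.vmap and jax.pmap expect.
ok, _ = try_stack([jnp.eye(2), 2.0 * jnp.eye(2)])
print(ok.shape)  # (2, 2, 2)
```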
This means that loss parallelism needs a different approach. Sharding may be the way forward, but it requires more thought. A good reference is the JAX tutorial on distributed arrays and automatic parallelization: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
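One possible way to combine sharding with ragged data is to zero-pad every molecule to a common size and mask out the padding. This is a swapped-in workaround, not an approach these notes settle on. The sketch assumes only 1-RDMs matter, uses a hypothetical `squared_error` based on the trace (which zero padding leaves unchanged), and splits the leading molecule axis across whatever devices `jax.devices()` reports. `pad_batch` is likewise a hypothetical helper.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical ragged inputs: one 1-RDM per molecule, each with its own size.
rdm1s = [jnp.eye(2), 2.0 * jnp.eye(3), jnp.eye(4)]
targets = [2.0, 5.0, 4.0]

def pad_batch(arrays, values, n_devices):
    """Zero-pad every matrix to the largest size and the batch to a multiple
    of n_devices; return padded matrices, padded values and a weight mask."""
    n_max = max(a.shape[0] for a in arrays)
    n_batch = -(-len(arrays) // n_devices) * n_devices  # round up
    padded = np.zeros((n_batch, n_max, n_max), dtype=np.float32)
    padded_values = np.zeros((n_batch,), dtype=np.float32)
    weights = np.zeros((n_batch,), dtype=np.float32)
    for i, (a, v) in enumerate(zip(arrays, values)):
        n = a.shape[0]
        padded[i, :n, :n] = np.asarray(a)
        padded_values[i] = v
        weights[i] = 1.0  # real molecule; dummy entries keep weight 0
    return padded, padded_values, weights

def squared_error(rdm1, target):
    # Hypothetical per-molecule term; zero padding does not change the trace.
    return (jnp.trace(rdm1) - target) ** 2

@jax.jit
def masked_mse(rdm1_batch, target_batch, weights):
    terms = jax.vmap(squared_error)(rdm1_batch, target_batch)
    # Dummy molecules have weight 0, so they do not contribute to the mean.
    return jnp.sum(weights * terms) / jnp.sum(weights)

devices = jax.devices()
mesh = Mesh(np.array(devices), axis_names=("batch",))
batch_sharding = NamedSharding(mesh, P("batch"))

padded, padded_targets, weights = pad_batch(rdm1s, targets, len(devices))

# Shard the leading (molecule) axis across devices; jit follows the input
# sharding, so each device evaluates its own slice of the loss terms.
args = [jax.device_put(x, batch_sharding) for x in (padded, padded_targets, weights)]
loss = masked_mse(*args)
```

Padding wastes memory and compute when molecule sizes vary widely. Grouping molecules of similar size into separate batches may reduce that overhead, at the cost of one compilation per batch shape.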
I don't think we will get around to solving this problem before our release deadline, but if we want to do something with HPC, getting this right is non-negotiable.