Demonstrate parallel execution of a loss function #50

Description

@jackbaker1001

On an HPC cluster, each term in a mean-square loss is independent of the others, so the terms can be computed with embarrassingly parallel logic.

Unfortunately, the native way of doing this with JAX (using jax.vmap and jax.pmap) is not compatible with the input we must parallelize over: the Molecule object. This is because its data is stored in a "ragged" structure. That is, the grid dimensions for one molecule are very often different from those for another, and the same is true of the 1-RDM dimensions, so jnp.array([rdm1_1, rdm1_2]) will not work.
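To make the raggedness problem concrete, here is a minimal sketch. The identity matrices are hypothetical stand-ins for two 1-RDMs with different basis sizes, and `squared_error` is a toy per-molecule term (not the project's actual loss). It shows that stacking fails, and that a plain Python loop over molecules still works, just without any parallelism.

```python
import jax.numpy as jnp

# Hypothetical stand-ins for the 1-RDMs of two molecules with different basis sizes.
rdm1_1 = jnp.eye(2)
rdm1_2 = jnp.eye(3)

# Stacking ragged arrays fails, so they cannot be fed to jax.vmap or jax.pmap as one batch.
try:
    jnp.stack([rdm1_1, rdm1_2])
    stacked_ok = True
except (ValueError, TypeError):
    stacked_ok = False

def squared_error(rdm1, target):
    # Toy per-molecule term (an assumption for illustration): the "prediction"
    # is simply the trace of the 1-RDM.
    return (jnp.trace(rdm1) - target) ** 2

# Serial fallback: loop over molecules one at a time and average the terms.
rdms = [rdm1_1, rdm1_2]
targets = [1.0, 5.0]
loss = sum(squared_error(r, t) for r, t in zip(rdms, targets)) / len(rdms)
```

Each term in the loop is independent of the others, which is exactly the work that should be spread across devices.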

This means that for loss parallelism, we need to think differently. Sharding may be the way forward, but this requires more thought. A good reference is the JAX notebook on distributed arrays and automatic parallelization: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
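One possible direction, sketched below under strong assumptions, is to zero-pad every molecule to a common size and shard the padded batch across devices with `jax.sharding`. This is only an illustration, not a proposed design: `pad_square`, `per_molecule_error`, and the toy trace-based loss are hypothetical, and padding can waste a lot of memory when molecule sizes differ widely. Dummy entries with zero weight are appended so the batch divides evenly across devices.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def pad_square(mat, size):
    # Zero-pad a square matrix up to (size, size).
    n = mat.shape[0]
    return jnp.pad(mat, ((0, size - n), (0, size - n)))

def per_molecule_error(rdm1_padded, target):
    # Toy term (an assumption): zero padding leaves the trace unchanged.
    return (jnp.trace(rdm1_padded) - target) ** 2

@jax.jit
def mean_square_loss(batch, targets, weights):
    errors = jax.vmap(per_molecule_error)(batch, targets)
    # Weights of zero mask out the dummy entries added for divisibility.
    return jnp.sum(weights * errors) / jnp.sum(weights)

# Hypothetical ragged inputs: three 1-RDM stand-ins of different sizes.
rdms = [jnp.eye(2), jnp.eye(3), jnp.eye(4)]
targets = [1.0, 5.0, 4.0]

n_dev = jax.device_count()
n_pad = (-len(rdms)) % n_dev  # dummy entries needed for an even split
size = max(r.shape[0] for r in rdms)

batch = jnp.stack(
    [pad_square(r, size) for r in rdms] + [jnp.zeros((size, size))] * n_pad
)
targets_arr = jnp.array(targets + [0.0] * n_pad)
weights = jnp.array([1.0] * len(rdms) + [0.0] * n_pad)

# One-dimensional device mesh; the leading (molecule) axis is split across it.
mesh = Mesh(np.array(jax.devices()), axis_names=("mol",))
sharding = NamedSharding(mesh, P("mol"))
batch = jax.device_put(batch, sharding)
targets_arr = jax.device_put(targets_arr, sharding)
weights = jax.device_put(weights, sharding)

loss = mean_square_loss(batch, targets_arr, weights)
```

On a single device this reduces to an ordinary `jax.vmap` over the padded batch. Whether padding is acceptable for the grid and 1-RDM data of real Molecule objects is exactly the open question in this issue.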

I don't think we will get around to solving this problem before our release deadline, but if we want to do something with HPC, getting this right is non-negotiable.
