TypeError: Attempted boolean conversion of traced array with shape bool[] in select_HF_omegas (JAX TracerBoolConversionError) #99

@STOKES-DOT

Bug Description

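Running the training script fails with a jax.errors.TracerBoolConversionError raised in the method select_HF_omegas (GradDFT-main/grad_dft/molecule.py, line 177). The failing statement is the Python membership test "if o not in self.omegas:", which needs a concrete boolean. Inside the jitted kernel, atoms.omegas is a traced value, so the conversion is rejected. The full traceback is as follows: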

Traceback (most recent call last):
  File "/home/yjiao/DeepRSH/module/train.py", line 657, in <module>
    main()
  File "/home/yjiao/DeepRSH/module/train.py", line 488, in main
    state, _, epoch_metrics = train_epoch(state, kernel, dataset)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yjiao/DeepRSH/module/train.py", line 418, in train_epoch
    params, opt_state, cost_val, metrics = kernel(params, opt_state, system[1], system[1].energy)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yjiao/GradDFT-main/grad_dft/train.py", line 352, in kernel
    (cost_value, predictedenergy), grads = loss(params, atoms, ground_truth_energy)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yjiao/DeepRSH/module/train.py", line 394, in loss
    predicted_energy, fock = compute_energy(params, molecule)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yjiao/GradDFT-main/grad_dft/train.py", line 147, in predict
    Exc, fock_xc = xc_energy_and_grads(params, atoms.rdm1, atoms, *args)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yjiao/GradDFT-main/grad_dft/train.py", line 115, in xc_energy_and_grads
    densities = functional.compute_densities(atoms, *args, **functional_kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yjiao/GradDFT-main/grad_dft/functional.py", line 176, in compute_densities
    nograd_densities = stop_gradient(self.nograd_densities(atoms, *args, **kwargs))
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yjiao/DeepRSH/module/train.py", line 376, in <lambda>
    nograd_densities=lambda molecule, *_, **__: molecule.HF_energy_density(omegas),
                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yjiao/GradDFT-main/grad_dft/molecule.py", line 203, in HF_energy_density
    chi = self.select_HF_omegas(omegas)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yjiao/GradDFT-main/grad_dft/molecule.py", line 177, in select_HF_omegas
    if o not in self.omegas:
       ^^^^^^^^^^^^^^^^^^^^
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function kernel at /home/yjiao/GradDFT-main/grad_dft/train.py:328 for jit. This concrete value was not available in Python because it depends on the value of the argument atoms.omegas.
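
For reference, the same class of error can be triggered without GradDFT. The snippet below is a minimal sketch, not the library's code: contains_python and the array values are hypothetical stand-ins for the membership test in select_HF_omegas. It runs the same "not in" check once eagerly and once under jax.jit.

```python
import jax
import jax.numpy as jnp


def contains_python(omegas, o):
    # Hypothetical stand-in for the check in select_HF_omegas.
    # Python's `in` needs a concrete True/False, so it ends up calling
    # bool() on the result of comparing array values.
    if o not in omegas:
        return jnp.zeros(())
    return jnp.ones(())


available = jnp.array([0.0, 0.5])

# Outside jit the array holds concrete values, so the test works.
eager_result = contains_python(available, 0.5)

# Under jit, `omegas` and `o` are tracers and the same line fails.
try:
    jax.jit(contains_python)(available, 0.5)
    jit_error = None
except jax.errors.TracerBoolConversionError as err:
    jit_error = err

print("eager result:", float(eager_result))
print("error under jit:", type(jit_error).__name__)
```

This suggests the problem is the Python-level membership test on a traced value rather than anything specific to the training data.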

Steps to Reproduce

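  1. Run the training script (DeepRSH/module/train.py), which calls the jitted kernel defined in GradDFT-main/grad_dft/train.py.
  2. The traceback above is raised while kernel is being traced, at the call to select_HF_omegas.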

Expected behavior

The script should run without JAX boolean conversion errors: select_HF_omegas should work when it is called on traced arrays or values inside a jitted context.
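
Two possible directions, sketched under assumptions and not taken from GradDFT itself: (1) keep the omega values as plain Python floats and mark them static, so the membership test runs on concrete values at trace time; (2) replace the Python test with an array expression that stays traced. The names select_static, contains_traced, chi, available and requested are hypothetical, and chi is assumed to have one leading row per available omega.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("available", "requested"))
def select_static(chi, available, requested):
    # `available` and `requested` are tuples of Python floats. Because they
    # are static arguments, they are concrete while tracing, so ordinary
    # Python membership tests and exceptions work here.
    for o in requested:
        if o not in available:
            raise ValueError(f"omega {o} is not among {available}")
    idx = [available.index(o) for o in requested]
    # Pick the rows of chi that belong to the requested omegas.
    return chi[jnp.array(idx)]


@jax.jit
def contains_traced(omegas, o):
    # Traced-friendly membership: returns a boolean array, never a Python bool.
    # It cannot raise, but the result can drive jnp.where or a mask.
    return jnp.any(jnp.isclose(omegas, o))


chi = jnp.arange(6.0).reshape(2, 3)  # toy data: 2 omegas, 3 values each
selected = select_static(chi, available=(0.0, 0.5), requested=(0.5,))
found = contains_traced(jnp.array([0.0, 0.5]), 0.5)
missing = contains_traced(jnp.array([0.0, 0.5]), 0.3)
```

With static arguments, jit retraces whenever the tuples change, which is usually acceptable if the set of omegas is fixed for a training run. The second form avoids retracing but cannot raise an error for a missing omega.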
