Running the training script fails with a JAX TracerBoolConversionError. The method select_HF_omegas in GradDFT-main/grad_dft/molecule.py attempts a Python boolean conversion of a JAX-traced array while the training kernel is being traced for jit. The full traceback is:
Traceback (most recent call last):
  File "/home/yjiao/DeepRSH/module/train.py", line 657, in <module>
    main()
  File "/home/yjiao/DeepRSH/module/train.py", line 488, in main
    state, _, epoch_metrics = train_epoch(state, kernel, dataset)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yjiao/DeepRSH/module/train.py", line 418, in train_epoch
    params, opt_state, cost_val, metrics = kernel(params, opt_state, system[1], system[1].energy)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yjiao/GradDFT-main/grad_dft/train.py", line 352, in kernel
    (cost_value, predictedenergy), grads = loss(params, atoms, ground_truth_energy)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yjiao/DeepRSH/module/train.py", line 394, in loss
    predicted_energy, fock = compute_energy(params, molecule)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yjiao/GradDFT-main/grad_dft/train.py", line 147, in predict
    Exc, fock_xc = xc_energy_and_grads(params, atoms.rdm1, atoms, *args)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yjiao/GradDFT-main/grad_dft/train.py", line 115, in xc_energy_and_grads
    densities = functional.compute_densities(atoms, *args, **functional_kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yjiao/GradDFT-main/grad_dft/functional.py", line 176, in compute_densities
    nograd_densities = stop_gradient(self.nograd_densities(atoms, *args, **kwargs))
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yjiao/DeepRSH/module/train.py", line 376, in <lambda>
    nograd_densities=lambda molecule, *_, **__: molecule.HF_energy_density(omegas),
                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yjiao/GradDFT-main/grad_dft/molecule.py", line 203, in HF_energy_density
    chi = self.select_HF_omegas(omegas)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yjiao/GradDFT-main/grad_dft/molecule.py", line 177, in select_HF_omegas
    if o not in self.omegas:
       ^^^^^^^^^^^^^^^^^^^^
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function kernel at /home/yjiao/GradDFT-main/grad_dft/train.py:328 for jit. This concrete value was not available in Python because it depends on the value of the argument atoms.omegas.
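The failing line is a plain Python membership test. Under jit the omegas arrive as a tracer, so evaluating `o not in ...` needs a concrete bool that does not exist yet. The minimal sketch below reproduces the same error outside GradDFT; lookup, OMEGA and the static-argument variant are illustrative assumptions, not GradDFT code.

```python
import jax
import jax.numpy as jnp

OMEGA = 0.4  # hypothetical requested omega


def lookup(available_omegas):
    # Same pattern as select_HF_omegas: a Python `in` test on the omegas.
    if OMEGA not in available_omegas:
        raise ValueError(f"omega {OMEGA} is not available")
    return jnp.asarray(OMEGA)


def reproduce():
    """Call the jitted lookup with an array argument and return the error."""
    try:
        jax.jit(lookup)(jnp.array([0.0, 0.4]))
    except jax.errors.ConcretizationTypeError as err:
        # TracerBoolConversionError is a subclass of ConcretizationTypeError.
        return err
    return None


# Marking the omegas as a static (hashable) argument keeps them concrete,
# so the Python membership test runs once, at trace time.
lookup_static = jax.jit(lookup, static_argnums=0)

err = reproduce()
value = lookup_static((0.0, 0.4))
```

The static route only suits omegas that are fixed Python values for the whole run, since every new tuple triggers a recompilation. Applying it here would require atoms.omegas to reach select_HF_omegas as concrete values rather than as part of the traced atoms argument.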
Expected behavior: the script should run without JAX boolean conversion errors, even when it operates on traced arrays inside a jitted context.
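One jit-compatible direction is to replace the Python membership test with array operations. The sketch below assumes, hypothetically, that the precomputed HF quantities (chi) are stored with the omega axis first, aligned with the available omegas; select_omegas_traced is an illustrative name, not the GradDFT API. Because a traced condition cannot drive a Python if or raise, a missing omega yields NaN rows instead of an exception.

```python
import jax
import jax.numpy as jnp


def select_omegas_traced(available_omegas, requested_omegas, chi):
    """Hypothetical jit-safe replacement for a Python `in` check.

    Assumes chi stores one slice per available omega along axis 0.
    Rows for requested omegas that are not available are filled with NaN,
    because a traced condition cannot drive a Python `if` or `raise`.
    """
    available = jnp.asarray(available_omegas)
    requested = jnp.asarray(requested_omegas)
    # matches[i, j] is True when requested[i] is close to available[j].
    matches = jnp.isclose(requested[:, None], available[None, :])
    found = jnp.any(matches, axis=1)
    # Index of the first match per requested omega (0 when there is none).
    idx = jnp.argmax(matches.astype(jnp.int32), axis=1)
    selected = jnp.take(chi, idx, axis=0)
    # Broadcast the per-omega mask over the remaining axes of chi.
    mask = found.reshape((-1,) + (1,) * (chi.ndim - 1))
    return jnp.where(mask, selected, jnp.nan)


select_jit = jax.jit(select_omegas_traced)

chi = jnp.arange(6.0).reshape(3, 2)
available = jnp.array([0.0, 0.4, 1.0])
picked = select_jit(available, jnp.array([1.0, 0.4]), chi)
missing = select_jit(available, jnp.array([0.7]), chi)
```

If a hard failure on a missing omega is preferred inside jit, jax.experimental.checkify can turn such a condition into a runtime check instead of silently producing NaN.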