diff --git a/src/deepquantum/photonic/circuit.py b/src/deepquantum/photonic/circuit.py index f148bf0d..53af7540 100644 --- a/src/deepquantum/photonic/circuit.py +++ b/src/deepquantum/photonic/circuit.py @@ -1168,7 +1168,8 @@ def _get_prob_gaussian_base( sub_mat = sub_gamma else: sub_mat[torch.arange(len(sub_gamma)), torch.arange(len(sub_gamma))] = sub_gamma - haf = abs(hafnian(sub_mat, loop=loop)) ** 2 if purity else hafnian(sub_mat, loop=loop) + temp_haf = hafnian(sub_mat, loop=loop) + haf = temp_haf.real.square() + temp_haf.imag.square() if purity else temp_haf prob = p_vac * haf / product_factorial(final_state).to(haf.device, haf.dtype) elif detector == 'threshold': final_state_double = torch.cat([final_state, final_state]) diff --git a/src/deepquantum/photonic/hafnian_.py b/src/deepquantum/photonic/hafnian_.py index 19b560d6..e38f5581 100644 --- a/src/deepquantum/photonic/hafnian_.py +++ b/src/deepquantum/photonic/hafnian_.py @@ -49,7 +49,9 @@ def get_submat_haf(a: torch.Tensor, z: torch.Tensor) -> torch.Tensor: return submat -def poly_lambda(submat: torch.Tensor, int_partition: list, power: int, loop: bool = False) -> torch.Tensor: +def poly_lambda( + submat: torch.Tensor, int_partition: list, power: int, loop: bool = False, threshold: float = 1e-30 +) -> torch.Tensor: """Get the coefficient of the polynomial. See https://arxiv.org/abs/1805.12498 Eq.(3.26) (noting that Eq.(3.26) contains a typo) and @@ -117,7 +119,9 @@ def hafnian(matrix: torch.Tensor, loop: bool = False) -> torch.Tensor: z_sets = torch.tensor(powerset[i], device=matrix.device) num_z = len(z_sets[0]) submats = torch.vmap(get_submat_haf, in_dims=(None, 0))(matrix, z_sets) + submats = submats.to(torch.cdouble) coeff = torch.vmap(poly_lambda, in_dims=(0, None, None, None))(submats, int_partition, power, loop) + coeff = coeff.to(matrix.dtype) coeff_sum = (-1) ** (power - num_z) * coeff.sum() haf += coeff_sum return haf