From 504782b6a991ca5c393ce17c10e80330741d5e71 Mon Sep 17 00:00:00 2001 From: sansiro77 <30891481+sansiro77@users.noreply.github.com> Date: Wed, 17 Jun 2026 16:57:34 +0800 Subject: [PATCH] Fix PhotonLoss input conventions --- src/deepquantum/photonic/channel.py | 30 +++++++++++++++++++++-- src/deepquantum/photonic/circuit.py | 37 ++++++++++++++++++++++------- 2 files changed, 56 insertions(+), 11 deletions(-) diff --git a/src/deepquantum/photonic/channel.py b/src/deepquantum/photonic/channel.py index 5187b548..9ab3e9c3 100644 --- a/src/deepquantum/photonic/channel.py +++ b/src/deepquantum/photonic/channel.py @@ -25,6 +25,8 @@ class PhotonLoss(Channel): nmode: The number of modes that the quantum operation acts on. Default: 1 wires: The indices of the modes that the quantum operation acts on. Default: ``None`` cutoff: The Fock space truncation. Default: ``None`` + convention: The convention of the input parameter, including ``'theta'``, ``'t'`` and ``'db'``. + Default: ``'theta'`` requires_grad: Whether the parameter is ``nn.Parameter`` or ``buffer``. Default: ``False`` (which means ``buffer``) """ @@ -35,12 +37,15 @@ def __init__( nmode: int = 1, wires: int | list[int] | None = None, cutoff: int | None = None, + convention: str = 'theta', requires_grad: bool = False, ) -> None: super().__init__(name='PhotonLoss', nmode=nmode, wires=wires, cutoff=cutoff) + assert convention in ('theta', 't', 'db'), 'Invalid convention' + self.convention = convention self.requires_grad = requires_grad self.gate = BeamSplitterSingle( - inputs=inputs, + inputs=self._inputs_to_theta(inputs), nmode=self.nmode + 1, wires=self.wires + [self.nmode], cutoff=cutoff, @@ -60,6 +65,11 @@ def t(self): """Transmittance.""" return torch.cos(self.theta / 2) ** 2 + @property + def db(self): + """Photon loss in dB.""" + return -10 * torch.log10(self.t) + def update_matrix_state(self) -> torch.Tensor: """Update the local Kraus matrices acting on Fock state density matrices.""" return self.get_matrix_state(self.theta) @@ -74,7 +84,23 @@ def get_matrix_state(self, theta: Any) -> torch.Tensor: def init_para(self, inputs: Any = None) -> None: """Initialize the parameters.""" - self.gate.init_para(inputs) + self.gate.init_para(self._inputs_to_theta(inputs)) + + def _inputs_to_theta(self, inputs: Any = None) -> torch.Tensor | None: + """Convert inputs under the chosen convention to the internal loss theta.""" + while isinstance(inputs, list): + inputs = inputs[0] + if inputs is None: + return None + if not isinstance(inputs, torch.Tensor): + inputs = torch.tensor(inputs, dtype=torch.float) + if self.convention == 'theta': + return inputs + elif self.convention == 't': + return torch.arccos(inputs**0.5) * 2 + else: + t = 10 ** (-inputs / 10) + return torch.arccos(t**0.5) * 2 def update_transform_xy(self) -> tuple[torch.Tensor, torch.Tensor]: """Update the local transformation matrices X and Y acting on Gaussian states. diff --git a/src/deepquantum/photonic/circuit.py b/src/deepquantum/photonic/circuit.py index f148bf0d..62bedff8 100644 --- a/src/deepquantum/photonic/circuit.py +++ b/src/deepquantum/photonic/circuit.py @@ -2818,10 +2818,20 @@ def loss_t(self, wires: int, inputs: Any = None, encode: bool = False) -> None: requires_grad = not encode if inputs is not None: requires_grad = False - if not isinstance(inputs, torch.Tensor): - inputs = torch.tensor(inputs, dtype=torch.float) - theta = torch.arccos(inputs**0.5) * 2 - loss = PhotonLoss(inputs=theta, nmode=self.nmode, wires=wires, cutoff=self.cutoff, requires_grad=requires_grad) + if isinstance(inputs, torch.Tensor) and inputs.requires_grad and inputs.is_leaf: + warnings.warn( + 'loss_t() does not optimize leaf transmittance inputs directly; use loss() or encode=True.', + UserWarning, + stacklevel=2, + ) + loss = PhotonLoss( + inputs=inputs, + nmode=self.nmode, + wires=wires, + cutoff=self.cutoff, + convention='t', + requires_grad=requires_grad, + ) self.add(loss, encode=encode) def loss_db(self, wires: int, inputs: Any = None, encode: bool = False) -> None: @@ -2836,11 +2846,20 @@ def loss_db(self, wires: int, inputs: Any = None, encode: bool = False) -> None: requires_grad = not encode if inputs is not None: requires_grad = False - if not isinstance(inputs, torch.Tensor): - inputs = torch.tensor(inputs, dtype=torch.float) - t = 10 ** (-inputs / 10) - theta = torch.arccos(t**0.5) * 2 - loss = PhotonLoss(inputs=theta, nmode=self.nmode, wires=wires, cutoff=self.cutoff, requires_grad=requires_grad) + if isinstance(inputs, torch.Tensor) and inputs.requires_grad and inputs.is_leaf: + warnings.warn( + 'loss_db() does not optimize leaf dB inputs directly; use loss() or encode=True.', + UserWarning, + stacklevel=2, + ) + loss = PhotonLoss( + inputs=inputs, + nmode=self.nmode, + wires=wires, + cutoff=self.cutoff, + convention='db', + requires_grad=requires_grad, + ) self.add(loss, encode=encode) def barrier(self, wires: int | list[int] | None = None) -> None: