Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions src/deepquantum/photonic/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class PhotonLoss(Channel):
nmode: The number of modes that the quantum operation acts on. Default: 1
wires: The indices of the modes that the quantum operation acts on. Default: ``None``
cutoff: The Fock space truncation. Default: ``None``
convention: The convention of the input parameter, including ``'theta'``, ``'t'`` and ``'db'``.
Default: ``'theta'``
requires_grad: Whether the parameter is ``nn.Parameter`` or ``buffer``.
Default: ``False`` (which means ``buffer``)
"""
Expand All @@ -35,12 +37,15 @@ def __init__(
nmode: int = 1,
wires: int | list[int] | None = None,
cutoff: int | None = None,
convention: str = 'theta',
requires_grad: bool = False,
) -> None:
super().__init__(name='PhotonLoss', nmode=nmode, wires=wires, cutoff=cutoff)
assert convention in ('theta', 't', 'db'), 'Invalid convention'
self.convention = convention
self.requires_grad = requires_grad
self.gate = BeamSplitterSingle(
inputs=inputs,
inputs=self._inputs_to_theta(inputs),
nmode=self.nmode + 1,
wires=self.wires + [self.nmode],
cutoff=cutoff,
Expand All @@ -60,6 +65,11 @@ def t(self):
"""Transmittance."""
return torch.cos(self.theta / 2) ** 2

@property
def db(self):
"""Photon loss in dB."""
return -10 * torch.log10(self.t)

def update_matrix_state(self) -> torch.Tensor:
"""Update the local Kraus matrices acting on Fock state density matrices."""
return self.get_matrix_state(self.theta)
Expand All @@ -74,7 +84,23 @@ def get_matrix_state(self, theta: Any) -> torch.Tensor:

def init_para(self, inputs: Any = None) -> None:
"""Initialize the parameters."""
self.gate.init_para(inputs)
self.gate.init_para(self._inputs_to_theta(inputs))

def _inputs_to_theta(self, inputs: Any = None) -> torch.Tensor | None:
"""Convert inputs under the chosen convention to the internal loss theta."""
while isinstance(inputs, list):
inputs = inputs[0]
if inputs is None:
return None
if not isinstance(inputs, torch.Tensor):
inputs = torch.tensor(inputs, dtype=torch.float)
if self.convention == 'theta':
return inputs
elif self.convention == 't':
return torch.arccos(inputs**0.5) * 2
else:
t = 10 ** (-inputs / 10)
return torch.arccos(t**0.5) * 2

def update_transform_xy(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Update the local transformation matrices X and Y acting on Gaussian states.
Expand Down
37 changes: 28 additions & 9 deletions src/deepquantum/photonic/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2818,10 +2818,20 @@ def loss_t(self, wires: int, inputs: Any = None, encode: bool = False) -> None:
requires_grad = not encode
if inputs is not None:
requires_grad = False
if not isinstance(inputs, torch.Tensor):
inputs = torch.tensor(inputs, dtype=torch.float)
theta = torch.arccos(inputs**0.5) * 2
loss = PhotonLoss(inputs=theta, nmode=self.nmode, wires=wires, cutoff=self.cutoff, requires_grad=requires_grad)
if isinstance(inputs, torch.Tensor) and inputs.requires_grad and inputs.is_leaf:
warnings.warn(
'loss_t() does not optimize leaf transmittance inputs directly; use loss() or encode=True.',
UserWarning,
stacklevel=2,
)
loss = PhotonLoss(
inputs=inputs,
nmode=self.nmode,
wires=wires,
cutoff=self.cutoff,
convention='t',
requires_grad=requires_grad,
)
self.add(loss, encode=encode)

def loss_db(self, wires: int, inputs: Any = None, encode: bool = False) -> None:
Expand All @@ -2836,11 +2846,20 @@ def loss_db(self, wires: int, inputs: Any = None, encode: bool = False) -> None:
requires_grad = not encode
if inputs is not None:
requires_grad = False
if not isinstance(inputs, torch.Tensor):
inputs = torch.tensor(inputs, dtype=torch.float)
t = 10 ** (-inputs / 10)
theta = torch.arccos(t**0.5) * 2
loss = PhotonLoss(inputs=theta, nmode=self.nmode, wires=wires, cutoff=self.cutoff, requires_grad=requires_grad)
if isinstance(inputs, torch.Tensor) and inputs.requires_grad and inputs.is_leaf:
warnings.warn(
'loss_db() does not optimize leaf dB inputs directly; use loss() or encode=True.',
UserWarning,
stacklevel=2,
)
loss = PhotonLoss(
inputs=inputs,
nmode=self.nmode,
wires=wires,
cutoff=self.cutoff,
convention='db',
requires_grad=requires_grad,
)
self.add(loss, encode=encode)

def barrier(self, wires: int | list[int] | None = None) -> None:
Expand Down