From 0afe6ec0b76db4faa8b230684abc349c9d2e600d Mon Sep 17 00:00:00 2001 From: cangtianhuang <1903374751@qq.com> Date: Thu, 18 Jun 2026 11:00:09 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Update=20CopsFull=5FRule?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tester/paddle_to_torch/rules.py | 46 +++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/tester/paddle_to_torch/rules.py b/tester/paddle_to_torch/rules.py index c8ff49ce..c090296c 100644 --- a/tester/paddle_to_torch/rules.py +++ b/tester/paddle_to_torch/rules.py @@ -7138,10 +7138,52 @@ class CopsFull_Rule(BaseRule): def apply(self, paddle_api: str) -> ConvertResult: core = """ +shape = locals().get('shape') +fill_value = locals().get("value", 0.0) +dtype = locals().get('dtype') x = locals().get("x") -value = locals().get("value", 0.0) + +# handle shape +def convert_to_list(shape): + if isinstance(shape, torch.Tensor): + return shape.tolist() + elif isinstance(shape, (list, tuple)): + shape_list = [] + for item in shape: + if isinstance(item, torch.Tensor): + if item.shape == torch.Size([]): + shape_list.append(item.item()) + else: + shape_list.extend(item.tolist()) + else: + shape_list.append(item) + return shape_list + elif isinstance(shape, int): + return [shape] + else: + return shape + +# handle fill_value +def convert_to_scalar(fill_value): + if isinstance(fill_value, torch.Tensor): + return fill_value.item() + # example: "-inf", "3.5" + elif isinstance(fill_value, str): + return float(fill_value) + else: + return fill_value + +shape = convert_to_list(shape) +fill_value = convert_to_scalar(fill_value) + +if dtype is None and not isinstance(fill_value, bool): + if isinstance(fill_value, complex): + dtype = torch.complex128 + else: + dtype = torch.float32 +tmp = torch.full(size=shape, fill_value=fill_value, dtype=dtype) with torch.no_grad(): - x.fill_(float(value)) + x.set_(tmp) result = x """ code = Code(core=core.splitlines())