diff --git a/tester/paddle_to_torch/rules.py b/tester/paddle_to_torch/rules.py index c8ff49ce..c090296c 100644 --- a/tester/paddle_to_torch/rules.py +++ b/tester/paddle_to_torch/rules.py @@ -7138,10 +7138,52 @@ class CopsFull_Rule(BaseRule): def apply(self, paddle_api: str) -> ConvertResult: core = """ +shape = locals().get('shape') +fill_value = locals().get("value", 0.0) +dtype = locals().get('dtype') x = locals().get("x") -value = locals().get("value", 0.0) + +# handle shape +def convert_to_list(shape): + if isinstance(shape, torch.Tensor): + return shape.tolist() + elif isinstance(shape, (list, tuple)): + shape_list = [] + for item in shape: + if isinstance(item, torch.Tensor): + if item.shape == torch.Size([]): + shape_list.append(item.item()) + else: + shape_list.extend(item.tolist()) + else: + shape_list.append(item) + return shape_list + elif isinstance(shape, int): + return [shape] + else: + return shape + +# handle fill_value +def convert_to_scalar(fill_value): + if isinstance(fill_value, torch.Tensor): + return fill_value.item() + # example: "-inf", "3.5" + elif isinstance(fill_value, str): + return float(fill_value) + else: + return fill_value + +shape = convert_to_list(shape) +fill_value = convert_to_scalar(fill_value) + +if dtype is None and not isinstance(fill_value, bool): + if isinstance(fill_value, complex): + dtype = torch.complex128 + else: + dtype = torch.float32 +tmp = torch.full(size=shape, fill_value=fill_value, dtype=dtype) with torch.no_grad(): - x.fill_(float(value)) + x.set_(tmp) result = x """ code = Code(core=core.splitlines())