From 07a5f46e73c874ce27baab689ca73f85a588653c Mon Sep 17 00:00:00 2001 From: ZibinGuo <1009134431@qq.com> Date: Wed, 9 Oct 2024 06:50:13 +0000 Subject: [PATCH] [XPU]Fix the bug in the three-way comparison. --- paddleapex/apex/utils/data_generate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddleapex/apex/utils/data_generate.py b/paddleapex/apex/utils/data_generate.py index 25c09ce..2463edb 100644 --- a/paddleapex/apex/utils/data_generate.py +++ b/paddleapex/apex/utils/data_generate.py @@ -169,7 +169,7 @@ def gen_common_tensor(low_info, high_info, shape, data_dtype): if math.isnan(high) or math.isnan(low) or math.isinf(high) or math.isinf(low): tensor = generate_random_tensor(shape, 0, 1) tensor = paddle.to_tensor( - tensor, dtype=eval(REAL_TYPE_PADDLE.get(data_dtype)) + tensor, dtype=eval(REAL_TYPE_PADDLE.get("FP32")) ) return tensor else: @@ -177,7 +177,7 @@ def gen_common_tensor(low_info, high_info, shape, data_dtype): shape = [1] tensor = generate_random_tensor(shape, low, high).astype(numpy.float32) tensor = paddle.to_tensor( - tensor, dtype=eval(REAL_TYPE_PADDLE.get(data_dtype)) + tensor, dtype=eval(REAL_TYPE_PADDLE.get("FP32")) ) return tensor elif ( @@ -292,7 +292,7 @@ def rand_like(data, seed=1234): if isinstance(data, paddle.Tensor): if data.dtype.name in ["BF16", "FP16"]: random_normals = numpy.random.randn(*data.shape) - x = paddle.to_tensor(random_normals, dtype=data.dtype) + x = paddle.to_tensor(random_normals, dtype='float32') return x elif data.dtype.name in ["FP32", "FP64"]: random_normals = numpy.random.randn(*data.shape)