From 706e9230dba087d4c6cc0c7db1d4f1de8b9491ae Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Fri, 31 May 2024 14:16:32 +0800 Subject: [PATCH 1/8] update run_ut.py comment3 --- Acc/run_ut/run_ut.py | 28 ++++++---------------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/Acc/run_ut/run_ut.py b/Acc/run_ut/run_ut.py index 2d38c9e..73900a1 100644 --- a/Acc/run_ut/run_ut.py +++ b/Acc/run_ut/run_ut.py @@ -1,5 +1,4 @@ import argparse -import json import os import sys import time @@ -11,9 +10,9 @@ import paddle import paddle.nn.functional as F from utils import Const, print_warn_log, api_info_preprocess, get_json_contents, print_info_log, create_directory, print_error_log, check_path_before_create, seed_all -from data_generate import gen_api_params, gen_args +from data_generate import gen_api_params from run_ut_utils import hf_32_standard_api, Backward_Message -from file_check_util import FileOpen, FileCheckConst, FileChecker, check_link, change_mode, check_file_suffix +from file_check_util import FileOpen, FileCheckConst, FileChecker, check_link, check_file_suffix # from compare.compare import Comparator seed_all() @@ -23,7 +22,7 @@ current_time = time.strftime("%Y%m%d%H%M%S") RESULT_FILE_NAME = f"accuracy_checking_result_" + current_time + ".csv" DETAILS_FILE_NAME = f"accuracy_checking_details_" + current_time + ".csv" -RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path', +RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'result_csv_path', 'details_csv_path', 'save_error_data', 'is_continue_run_ut', 'real_data_path']) tqdm_params = { @@ -227,12 +226,6 @@ def run_ut_command_save(args): forward_file = os.path.realpath(args.forward_input_file) check_file_suffix(forward_file, FileCheckConst.JSON_SUFFIX) forward_content = get_json_contents(forward_file) - backward_content = {} - if args.backward_input_file: - check_link(args.backward_input_file) - backward_file = os.path.realpath(args.backward_input_file) - check_file_suffix(backward_file, FileCheckConst.JSON_SUFFIX) - backward_content = get_json_contents(backward_file) result_csv_path = os.path.join(out_path, RESULT_FILE_NAME) details_csv_path = os.path.join(out_path, DETAILS_FILE_NAME) if args.result_csv_path: @@ -243,7 +236,7 @@ def run_ut_command_save(args): time_info = result_csv_path.split('.')[0].split('_')[-1] global UT_ERROR_DATA_DIR UT_ERROR_DATA_DIR = 'ut_error_data' + time_info - run_ut_config = RunUTConfig(forward_content, backward_content, result_csv_path, details_csv_path, save_error_data, + run_ut_config = RunUTConfig(forward_content, result_csv_path, details_csv_path, save_error_data, args.result_csv_path, args.real_data_path) run_ut_save(run_ut_config) @@ -391,10 +384,6 @@ def _run_ut_parser(parser): help=" The api param tool forward result file: generate from api param tool, " "a json file.", required=True) - parser.add_argument("-backward", "--backward", dest="backward_input_file", default="", type=str, - help=" The api param tool backward result file: generate from api param tool, " - "a json file.", - required=False) parser.add_argument("-o", "--dump_path", dest="out_path", default="", type=str, help=" The ut task result out path.", required=False) @@ -493,12 +482,6 @@ def run_ut_command(args): forward_file = os.path.realpath(args.forward_input_file) check_file_suffix(forward_file, FileCheckConst.JSON_SUFFIX) forward_content = get_json_contents(forward_file) - backward_content = {} - if args.backward_input_file: - check_link(args.backward_input_file) - backward_file = os.path.realpath(args.backward_input_file) - check_file_suffix(backward_file, FileCheckConst.JSON_SUFFIX) - backward_content = get_json_contents(backward_file) result_csv_path = os.path.join(out_path, RESULT_FILE_NAME) details_csv_path = os.path.join(out_path, DETAILS_FILE_NAME) if args.result_csv_path: @@ -509,7 +492,7 @@ def run_ut_command(args): time_info = result_csv_path.split('.')[0].split('_')[-1] global UT_ERROR_DATA_DIR UT_ERROR_DATA_DIR = 'ut_error_data' + time_info - run_ut_config = RunUTConfig(forward_content, backward_content, result_csv_path, details_csv_path, save_error_data, + run_ut_config = RunUTConfig(forward_content, result_csv_path, details_csv_path, save_error_data, args.result_csv_path, args.real_data_path) run_ut(run_ut_config) @@ -525,6 +508,7 @@ def __init__(self, bench_grad, device_grad, device_output, bench_output, in_fwd_ self.backward_message = backward_message self.rank = rank + if __name__ == "__main__": _run_ut() print_info_log("UT task completed") From 60b0097decc035d8f2784fea1a4bc4f7d5e2e070 Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Fri, 31 May 2024 14:24:14 +0800 Subject: [PATCH 2/8] update run_ut.py comment --- Acc/run_ut/run_ut.py | 1 - 1 file changed, 1 deletion(-) diff --git a/Acc/run_ut/run_ut.py b/Acc/run_ut/run_ut.py index 73900a1..8c1697d 100644 --- a/Acc/run_ut/run_ut.py +++ b/Acc/run_ut/run_ut.py @@ -70,7 +70,6 @@ def recursive_arg_to_cpu(arg_in, to_detach, raise_dtype=None): if need_backward and not arg_in.stop_gradient: arg_in = deal_detach(raise_bench_data_dtype(api_name, arg_in.clone(), raise_dtype), to_detach) arg_in.stop_gradient = False - return arg_in else: return deal_detach(raise_bench_data_dtype(api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach) From bde6ac2e4f78d2618f51bee7edcf22200b663190 Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Fri, 31 May 2024 14:48:27 +0800 Subject: [PATCH 3/8] update run_ut.py comment --- Acc/run_ut/run_ut.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/Acc/run_ut/run_ut.py b/Acc/run_ut/run_ut.py index 8c1697d..bb0ef14 100644 --- a/Acc/run_ut/run_ut.py +++ b/Acc/run_ut/run_ut.py @@ -329,7 +329,6 @@ def need_to_backward(grad_index, out): def run_backward(args, grad_index, out): - if grad_index is not None: out[grad_index].backward() else: @@ -339,7 +338,6 @@ def run_backward(args, grad_index, out): if isinstance(arg, paddle.Tensor): args_grad.append(arg.grad) grad_out = args_grad - return grad_out @@ -424,10 +422,8 @@ def preprocess_forward_content(forward_content): processed_content = {} base_keys_variants = {} arg_cache = {} - for key, value in forward_content.items(): base_key = key.rsplit(Const.DELIMITER, 1)[0] - if key not in arg_cache: new_args = value['args'] new_kwargs = value['kwargs'] @@ -436,9 +432,7 @@ def preprocess_forward_content(forward_content): for arg in new_args if isinstance(arg, dict) ] arg_cache[key] = (filtered_new_args, new_kwargs) - filtered_new_args, new_kwargs = arg_cache[key] - if base_key not in base_keys_variants: processed_content[key] = value base_keys_variants[base_key] = {key} @@ -449,11 +443,9 @@ def preprocess_forward_content(forward_content): if existing_args == filtered_new_args and existing_kwargs == new_kwargs: is_duplicate = True break - if not is_duplicate: processed_content[key] = value base_keys_variants[base_key].add(key) - return processed_content From bfca2bb49a9c4d3f90c7e6cbbf4599d300564a99 Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Fri, 31 May 2024 16:07:21 +0800 Subject: [PATCH 4/8] update save --- Acc/run_ut/run_ut.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/Acc/run_ut/run_ut.py b/Acc/run_ut/run_ut.py index bb0ef14..d73b588 100644 --- a/Acc/run_ut/run_ut.py +++ b/Acc/run_ut/run_ut.py @@ -245,7 +245,8 @@ def run_ut_save(config): for i, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)): try: print(api_full_name) - run_paddle_api_save(api_full_name, config.real_data_path, api_info_dict) + dump_path = os.path.dirname(config.result_csv_path) + run_paddle_api_save(api_full_name, config.real_data_path, api_info_dict, dump_path) print("*"*200) except Exception as err: [_, api_name, _] = api_full_name.split("*") @@ -258,7 +259,7 @@ def run_ut_save(config): gc.collect() -def run_paddle_api_save(api_full_name, real_data_path, api_info_dict): +def run_paddle_api_save(api_full_name, real_data_path, api_info_dict, dump_path): in_fwd_data_list = [] backward_message = '' [api_type, api_name, _] = api_full_name.split('*') @@ -285,8 +286,13 @@ def run_paddle_api_save(api_full_name, real_data_path, api_info_dict): output_folder = "npu_output" else: output_folder = "gpu_output" - current_dir = os.path.dirname(os.path.abspath(__file__)) - output_dir = os.path.abspath(os.path.join(current_dir, "..", output_folder)) + if dump_path == os.path.dirname(os.path.abspath(__file__)): + output_dir = os.path.abspath(os.path.join(dump_path, "..", output_folder)) + output_dir_back = os.path.abspath(os.path.join(dump_path, "..", output_folder + "_backward")) + else: + current_dir = dump_path + output_dir = os.path.abspath(os.path.join(current_dir, output_folder)) + output_dir_back = os.path.abspath(os.path.join(current_dir, output_folder + "_backward")) os.makedirs(output_dir, exist_ok=True) output_path = output_dir + '/' + f'{api_full_name}' paddle.save(device_out, output_path) @@ -306,9 +312,9 @@ def run_paddle_api_save(api_full_name, real_data_path, api_info_dict): else: backward_message += Backward_Message.MULTIPLE_BACKWARD_MESSAGE - output_dir = os.path.abspath(os.path.join(current_dir, "..", output_folder + "_backward")) - os.makedirs(output_dir, exist_ok=True) - output_path = output_dir + '/' + f'{api_full_name}' + # output_dir = os.path.abspath(os.path.join(current_dir, "..", output_folder + "_backward")) + os.makedirs(output_dir_back, exist_ok=True) + output_path = output_dir_back + '/' + f'{api_full_name}' paddle.save(device_grad_out, output_path) return From 3cdd90e7e6170f7f0edb6037ec24ad048dc4c442 Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Fri, 31 May 2024 17:10:08 +0800 Subject: [PATCH 5/8] clp real_data_path delete --- Acc/run_ut/data_generate.py | 26 ++++++++++++-------------- Acc/run_ut/run_ut.py | 26 +++++++++++--------------- 2 files changed, 23 insertions(+), 29 deletions(-) diff --git a/Acc/run_ut/data_generate.py b/Acc/run_ut/data_generate.py index 4402c83..6a34636 100644 --- a/Acc/run_ut/data_generate.py +++ b/Acc/run_ut/data_generate.py @@ -24,7 +24,7 @@ def hook(grad): tensor.register_hook(hook) -def gen_data(info, api_name, need_grad, convert_type, real_data_path=None): +def gen_data(info, api_name, need_grad, convert_type): """ Function Description: Based on arg basic information, generate arg data @@ -221,7 +221,7 @@ def gen_bool_tensor(low, high, shape): return data -def gen_args(args_info, api_name, need_grad=True, convert_type=None, real_data_path=None): +def gen_args(args_info, api_name, need_grad=True, convert_type=None): """ Function Description: Based on API basic information, generate input parameters: args, for API forward running @@ -230,15 +230,14 @@ def gen_args(args_info, api_name, need_grad=True, convert_type=None, real_data_p api_name: API name need_grad: set Tensor grad for backward convert_type: convert ori_type to dist_type flag. - real_data_path: the root directory for storing real data. """ check_object_type(args_info, list) args_result = [] for arg in args_info: if isinstance(arg, (list, tuple)): - data = gen_args(arg, api_name, need_grad, convert_type, real_data_path) + data = gen_args(arg, api_name, need_grad, convert_type) elif isinstance(arg, dict): - data = gen_data(arg, api_name, need_grad, convert_type, real_data_path) + data = gen_data(arg, api_name, need_grad, convert_type) else: print_warn_log(f'Warning: {arg} is not supported') raise NotImplementedError() @@ -246,14 +245,13 @@ def gen_args(args_info, api_name, need_grad=True, convert_type=None, real_data_p return args_result -def gen_kwargs(api_info, convert_type=None, real_data_path=None): +def gen_kwargs(api_info, convert_type=None): """ Function Description: Based on API basic information, generate input parameters: kwargs, for API forward running Parameter: api_info: API basic information. Dict convert_type: convert ori_type to dist_type flag. - real_data_path: the root directory for storing real data. """ check_object_type(api_info, dict) kwargs_params = api_info.get("kwargs") @@ -261,9 +259,9 @@ def gen_kwargs(api_info, convert_type=None, real_data_path=None): if value is None: continue if isinstance(value, (list, tuple)): - kwargs_params[key] = gen_list_kwargs(value, convert_type, real_data_path) + kwargs_params[key] = gen_list_kwargs(value, convert_type) elif value.get('type') in TENSOR_DATA_LIST_PADDLE or value.get('type').startswith("numpy"): - kwargs_params[key] = gen_data(value, True, convert_type, real_data_path) + kwargs_params[key] = gen_data(value, True, convert_type) elif value.get('type') in PADDLE_TYPE: gen_paddle_kwargs(kwargs_params, key, value) else: @@ -276,7 +274,7 @@ def gen_paddle_kwargs(kwargs_params, key, value): kwargs_params[key] = eval(value.get('value')) -def gen_list_kwargs(kwargs_item_value, convert_type, real_data_path=None): +def gen_list_kwargs(kwargs_item_value, convert_type): """ Function Description: When kwargs value is list, generate the list of kwargs result @@ -287,14 +285,14 @@ def gen_list_kwargs(kwargs_item_value, convert_type, real_data_path=None): kwargs_item_result = [] for item in kwargs_item_value: if item.get('type') in TENSOR_DATA_LIST_PADDLE: - item_value = gen_data(item, False, convert_type, real_data_path) + item_value = gen_data(item, False, convert_type) else: item_value = item.get('value') kwargs_item_result.append(item_value) return kwargs_item_result -def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_data_path=None): +def gen_api_params(api_info, api_name, need_grad=True, convert_type=None): """ Function Description: Based on API basic information, generate input parameters: args, kwargs, for API forward running @@ -308,9 +306,9 @@ def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_d if convert_type and convert_type not in Const.CONVERT: error_info = f"convert_type params not support {convert_type}." raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info) - kwargs_params = gen_kwargs(api_info, convert_type, real_data_path) + kwargs_params = gen_kwargs(api_info, convert_type) if api_info.get("args"): - args_params = gen_args(api_info.get("args"), api_name, need_grad, convert_type, real_data_path) + args_params = gen_args(api_info.get("args"), api_name, need_grad, convert_type) else: print_warn_log(f'Warning: No args in {api_info} ') args_params = [] diff --git a/Acc/run_ut/run_ut.py b/Acc/run_ut/run_ut.py index d73b588..65810dd 100644 --- a/Acc/run_ut/run_ut.py +++ b/Acc/run_ut/run_ut.py @@ -23,7 +23,7 @@ RESULT_FILE_NAME = f"accuracy_checking_result_" + current_time + ".csv" DETAILS_FILE_NAME = f"accuracy_checking_details_" + current_time + ".csv" RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'result_csv_path', 'details_csv_path', - 'save_error_data', 'is_continue_run_ut', 'real_data_path']) + 'save_error_data', 'is_continue_run_ut']) tqdm_params = { 'smoothing': 0, # 平滑进度条的预计剩余时间,取值范围0到1 @@ -141,7 +141,7 @@ def run_ut(config): for i, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)): try: print(api_full_name) - data_info = run_paddle_api(api_full_name, config.real_data_path, api_info_dict) + data_info = run_paddle_api(api_full_name, api_info_dict) # is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, # data_info.bench_output, # data_info.device_output, @@ -158,11 +158,11 @@ def run_ut(config): gc.collect() -def run_paddle_api(api_full_name, real_data_path, api_info_dict): +def run_paddle_api(api_full_name, api_info_dict): in_fwd_data_list = [] backward_message = '' [api_type, api_name, _] = api_full_name.split('*') - args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path) + args, kwargs, need_grad = get_api_info(api_info_dict, api_name) in_fwd_data_list.append(args) in_fwd_data_list.append(kwargs) need_backward = True @@ -236,7 +236,7 @@ def run_ut_command_save(args): global UT_ERROR_DATA_DIR UT_ERROR_DATA_DIR = 'ut_error_data' + time_info run_ut_config = RunUTConfig(forward_content, result_csv_path, details_csv_path, save_error_data, - args.result_csv_path, args.real_data_path) + args.result_csv_path) run_ut_save(run_ut_config) @@ -246,7 +246,7 @@ def run_ut_save(config): try: print(api_full_name) dump_path = os.path.dirname(config.result_csv_path) - run_paddle_api_save(api_full_name, config.real_data_path, api_info_dict, dump_path) + run_paddle_api_save(api_full_name, api_info_dict, dump_path) print("*"*200) except Exception as err: [_, api_name, _] = api_full_name.split("*") @@ -259,11 +259,11 @@ def run_ut_save(config): gc.collect() -def run_paddle_api_save(api_full_name, real_data_path, api_info_dict, dump_path): +def run_paddle_api_save(api_full_name, api_info_dict, dump_path): in_fwd_data_list = [] backward_message = '' [api_type, api_name, _] = api_full_name.split('*') - args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path) + args, kwargs, need_grad = get_api_info(api_info_dict, api_name) in_fwd_data_list.append(args) in_fwd_data_list.append(kwargs) need_backward = True @@ -319,12 +319,12 @@ def run_paddle_api_save(api_full_name, real_data_path, api_info_dict, dump_path) return -def get_api_info(api_info_dict, api_name, real_data_path): +def get_api_info(api_info_dict, api_name): convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict) need_grad = True if api_info_dict.get("kwargs") and "out" in api_info_dict.get("kwargs"): need_grad = False - args, kwargs = gen_api_params(api_info_dict, api_name, need_grad, convert_type, real_data_path) + args, kwargs = gen_api_params(api_info_dict, api_name, need_grad, convert_type) return args, kwargs, need_grad @@ -418,10 +418,6 @@ def __call__(self, parser, namespace, values, option_string=None): help=" The path of accuracy_checking_result_{timestamp}.csv, " "when run ut is interrupted, enter the file path to continue run ut.", required=False) - parser.add_argument("-real_data_path", dest="real_data_path", nargs="?", const="", default="", type=str, - help=" In real data mode, the root directory for storing real data " - "must be configured.", - required=False) def preprocess_forward_content(forward_content): @@ -490,7 +486,7 @@ def run_ut_command(args): global UT_ERROR_DATA_DIR UT_ERROR_DATA_DIR = 'ut_error_data' + time_info run_ut_config = RunUTConfig(forward_content, result_csv_path, details_csv_path, save_error_data, - args.result_csv_path, args.real_data_path) + args.result_csv_path) run_ut(run_ut_config) From 012d012e64f95b999fcdcce70be086eac0a34373 Mon Sep 17 00:00:00 2001 From: zhangjian <1032674385@qq.com> Date: Fri, 31 May 2024 17:29:04 +0800 Subject: [PATCH 6/8] add compare in run_ut.py --- Acc/run_ut/run_ut.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/Acc/run_ut/run_ut.py b/Acc/run_ut/run_ut.py index 65810dd..331be3a 100644 --- a/Acc/run_ut/run_ut.py +++ b/Acc/run_ut/run_ut.py @@ -13,7 +13,7 @@ from data_generate import gen_api_params from run_ut_utils import hf_32_standard_api, Backward_Message from file_check_util import FileOpen, FileCheckConst, FileChecker, check_link, check_file_suffix -# from compare.compare import Comparator +from compare.compare import Comparator seed_all() not_raise_dtype_set = {'type_as'} @@ -133,20 +133,19 @@ def run_ut(config): print_info_log("start UT test") print_info_log(f"UT task result will be saved in {config.result_csv_path}") print_info_log(f"UT task details will be saved in {config.details_csv_path}") - # compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut) - # with FileOpen(config.result_csv_path, 'r') as file: - # csv_reader = csv.reader(file) - # next(csv_reader) - # api_name_set = {row[0] for row in csv_reader} + compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut) + with FileOpen(config.result_csv_path, 'r') as file: + csv_reader = csv.reader(file) + next(csv_reader) for i, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)): try: print(api_full_name) data_info = run_paddle_api(api_full_name, api_info_dict) - # is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, - # data_info.bench_output, - # data_info.device_output, - # data_info.bench_grad, - # data_info.device_grad) + is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, + data_info.bench_output, + data_info.device_output, + data_info.bench_grad, + data_info.device_grad) except Exception as err: [_, api_name, _] = api_full_name.split("*") if "expected scalar type Long" in str(err): From 88ea3028ef0b34dced73e10217077d61f1087c68 Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Mon, 3 Jun 2024 16:15:40 +0800 Subject: [PATCH 7/8] run_ut_utils merged --- Acc/run_ut/data_generate.py | 4 ++-- Acc/run_ut/run_ut.py | 4 ++-- Acc/run_ut/utils.py | 9 +++++++++ 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/Acc/run_ut/data_generate.py b/Acc/run_ut/data_generate.py index 6a34636..47ece64 100644 --- a/Acc/run_ut/data_generate.py +++ b/Acc/run_ut/data_generate.py @@ -2,8 +2,8 @@ import os import numpy import math -from utils import check_object_type, Const, CompareException, print_error_log, print_warn_log, check_file_or_directory_path, get_full_data_path -from run_ut_utils import hf_32_standard_api +from utils import check_object_type, Const, CompareException, print_error_log, print_warn_log, \ + check_file_or_directory_path, hf_32_standard_api TENSOR_DATA_LIST_PADDLE = ["paddle.Tensor", "paddle.create_parameter"] diff --git a/Acc/run_ut/run_ut.py b/Acc/run_ut/run_ut.py index 331be3a..2fb3649 100644 --- a/Acc/run_ut/run_ut.py +++ b/Acc/run_ut/run_ut.py @@ -9,9 +9,9 @@ from tqdm import tqdm import paddle import paddle.nn.functional as F -from utils import Const, print_warn_log, api_info_preprocess, get_json_contents, print_info_log, create_directory, print_error_log, check_path_before_create, seed_all +from utils import Const, print_warn_log, api_info_preprocess, get_json_contents, print_info_log, create_directory, \ + print_error_log, check_path_before_create, seed_all, hf_32_standard_api, Backward_Message from data_generate import gen_api_params -from run_ut_utils import hf_32_standard_api, Backward_Message from file_check_util import FileOpen, FileCheckConst, FileChecker, check_link, check_file_suffix from compare.compare import Comparator diff --git a/Acc/run_ut/utils.py b/Acc/run_ut/utils.py index e235fd1..c5a3826 100644 --- a/Acc/run_ut/utils.py +++ b/Acc/run_ut/utils.py @@ -11,6 +11,15 @@ from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +hf_32_standard_api = ["conv1d", "conv2d"] + + +class Backward_Message: + MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported." + UNSUPPORT_BACKWARD_MESSAGE = "function with out=... arguments don't support automatic differentiation, skip backward." + NO_BACKWARD_RESULT_MESSAGE = "function backward result is None, skip backward." + + class Const: """ Class for const From d7dc1f0f0d0ff987ce17c6c43b715246d276cc6e Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Mon, 3 Jun 2024 16:23:21 +0800 Subject: [PATCH 8/8] run_ut_utils.py deleted --- Acc/run_ut/run_ut_utils.py | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 Acc/run_ut/run_ut_utils.py diff --git a/Acc/run_ut/run_ut_utils.py b/Acc/run_ut/run_ut_utils.py deleted file mode 100644 index d78642f..0000000 --- a/Acc/run_ut/run_ut_utils.py +++ /dev/null @@ -1,7 +0,0 @@ -hf_32_standard_api = ["conv1d", "conv2d"] - - -class Backward_Message: - MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported." - UNSUPPORT_BACKWARD_MESSAGE = "function with out=... arguments don't support automatic differentiation, skip backward." - NO_BACKWARD_RESULT_MESSAGE = "function backward result is None, skip backward." \ No newline at end of file