Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 14 additions & 16 deletions Acc/run_ut/data_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import os
import numpy
import math
from utils import check_object_type, Const, CompareException, print_error_log, print_warn_log, check_file_or_directory_path, get_full_data_path
from run_ut_utils import hf_32_standard_api
from utils import check_object_type, Const, CompareException, print_error_log, print_warn_log, \
check_file_or_directory_path, hf_32_standard_api


TENSOR_DATA_LIST_PADDLE = ["paddle.Tensor", "paddle.create_parameter"]
Expand All @@ -24,7 +24,7 @@ def hook(grad):
tensor.register_hook(hook)


def gen_data(info, api_name, need_grad, convert_type, real_data_path=None):
def gen_data(info, api_name, need_grad, convert_type):
"""
Function Description:
Based on arg basic information, generate arg data
Expand Down Expand Up @@ -221,7 +221,7 @@ def gen_bool_tensor(low, high, shape):
return data


def gen_args(args_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
def gen_args(args_info, api_name, need_grad=True, convert_type=None):
"""
Function Description:
Based on API basic information, generate input parameters: args, for API forward running
Expand All @@ -230,40 +230,38 @@ def gen_args(args_info, api_name, need_grad=True, convert_type=None, real_data_p
api_name: API name
need_grad: set Tensor grad for backward
convert_type: convert ori_type to dist_type flag.
real_data_path: the root directory for storing real data.
"""
check_object_type(args_info, list)
args_result = []
for arg in args_info:
if isinstance(arg, (list, tuple)):
data = gen_args(arg, api_name, need_grad, convert_type, real_data_path)
data = gen_args(arg, api_name, need_grad, convert_type)
elif isinstance(arg, dict):
data = gen_data(arg, api_name, need_grad, convert_type, real_data_path)
data = gen_data(arg, api_name, need_grad, convert_type)
else:
print_warn_log(f'Warning: {arg} is not supported')
raise NotImplementedError()
args_result.append(data)
return args_result


def gen_kwargs(api_info, convert_type=None, real_data_path=None):
def gen_kwargs(api_info, convert_type=None):
"""
Function Description:
Based on API basic information, generate input parameters: kwargs, for API forward running
Parameter:
api_info: API basic information. Dict
convert_type: convert ori_type to dist_type flag.
real_data_path: the root directory for storing real data.
"""
check_object_type(api_info, dict)
kwargs_params = api_info.get("kwargs")
for key, value in kwargs_params.items():
if value is None:
continue
if isinstance(value, (list, tuple)):
kwargs_params[key] = gen_list_kwargs(value, convert_type, real_data_path)
kwargs_params[key] = gen_list_kwargs(value, convert_type)
elif value.get('type') in TENSOR_DATA_LIST_PADDLE or value.get('type').startswith("numpy"):
kwargs_params[key] = gen_data(value, True, convert_type, real_data_path)
kwargs_params[key] = gen_data(value, True, convert_type)
elif value.get('type') in PADDLE_TYPE:
gen_paddle_kwargs(kwargs_params, key, value)
else:
Expand All @@ -276,7 +274,7 @@ def gen_paddle_kwargs(kwargs_params, key, value):
kwargs_params[key] = eval(value.get('value'))


def gen_list_kwargs(kwargs_item_value, convert_type, real_data_path=None):
def gen_list_kwargs(kwargs_item_value, convert_type):
"""
Function Description:
When kwargs value is list, generate the list of kwargs result
Expand All @@ -287,14 +285,14 @@ def gen_list_kwargs(kwargs_item_value, convert_type, real_data_path=None):
kwargs_item_result = []
for item in kwargs_item_value:
if item.get('type') in TENSOR_DATA_LIST_PADDLE:
item_value = gen_data(item, False, convert_type, real_data_path)
item_value = gen_data(item, False, convert_type)
else:
item_value = item.get('value')
kwargs_item_result.append(item_value)
return kwargs_item_result


def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
def gen_api_params(api_info, api_name, need_grad=True, convert_type=None):
"""
Function Description:
Based on API basic information, generate input parameters: args, kwargs, for API forward running
Expand All @@ -308,9 +306,9 @@ def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_d
if convert_type and convert_type not in Const.CONVERT:
error_info = f"convert_type params not support {convert_type}."
raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
kwargs_params = gen_kwargs(api_info, convert_type, real_data_path)
kwargs_params = gen_kwargs(api_info, convert_type)
if api_info.get("args"):
args_params = gen_args(api_info.get("args"), api_name, need_grad, convert_type, real_data_path)
args_params = gen_args(api_info.get("args"), api_name, need_grad, convert_type)
else:
print_warn_log(f'Warning: No args in {api_info} ')
args_params = []
Expand Down
104 changes: 40 additions & 64 deletions Acc/run_ut/run_ut.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import argparse
import json
import os
import sys
import time
Expand All @@ -10,11 +9,11 @@
from tqdm import tqdm
import paddle
import paddle.nn.functional as F
from utils import Const, print_warn_log, api_info_preprocess, get_json_contents, print_info_log, create_directory, print_error_log, check_path_before_create, seed_all
from data_generate import gen_api_params, gen_args
from run_ut_utils import hf_32_standard_api, Backward_Message
from file_check_util import FileOpen, FileCheckConst, FileChecker, check_link, change_mode, check_file_suffix
# from compare.compare import Comparator
from utils import Const, print_warn_log, api_info_preprocess, get_json_contents, print_info_log, create_directory, \
print_error_log, check_path_before_create, seed_all, hf_32_standard_api, Backward_Message
from data_generate import gen_api_params
from file_check_util import FileOpen, FileCheckConst, FileChecker, check_link, check_file_suffix
from compare.compare import Comparator

seed_all()
not_raise_dtype_set = {'type_as'}
Expand All @@ -23,8 +22,8 @@
current_time = time.strftime("%Y%m%d%H%M%S")
RESULT_FILE_NAME = f"accuracy_checking_result_" + current_time + ".csv"
DETAILS_FILE_NAME = f"accuracy_checking_details_" + current_time + ".csv"
RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
'save_error_data', 'is_continue_run_ut', 'real_data_path'])
RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'result_csv_path', 'details_csv_path',
'save_error_data', 'is_continue_run_ut'])

tqdm_params = {
'smoothing': 0, # 平滑进度条的预计剩余时间,取值范围0到1
Expand Down Expand Up @@ -71,7 +70,6 @@ def recursive_arg_to_cpu(arg_in, to_detach, raise_dtype=None):
if need_backward and not arg_in.stop_gradient:
arg_in = deal_detach(raise_bench_data_dtype(api_name, arg_in.clone(), raise_dtype), to_detach)
arg_in.stop_gradient = False

return arg_in
else:
return deal_detach(raise_bench_data_dtype(api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach)
Expand Down Expand Up @@ -135,20 +133,19 @@ def run_ut(config):
print_info_log("start UT test")
print_info_log(f"UT task result will be saved in {config.result_csv_path}")
print_info_log(f"UT task details will be saved in {config.details_csv_path}")
# compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut)
# with FileOpen(config.result_csv_path, 'r') as file:
# csv_reader = csv.reader(file)
# next(csv_reader)
# api_name_set = {row[0] for row in csv_reader}
compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut)
with FileOpen(config.result_csv_path, 'r') as file:
csv_reader = csv.reader(file)
next(csv_reader)
for i, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)):
try:
print(api_full_name)
data_info = run_paddle_api(api_full_name, config.real_data_path, api_info_dict)
# is_fwd_success, is_bwd_success = compare.compare_output(api_full_name,
# data_info.bench_output,
# data_info.device_output,
# data_info.bench_grad,
# data_info.device_grad)
data_info = run_paddle_api(api_full_name, api_info_dict)
is_fwd_success, is_bwd_success = compare.compare_output(api_full_name,
data_info.bench_output,
data_info.device_output,
data_info.bench_grad,
data_info.device_grad)
except Exception as err:
[_, api_name, _] = api_full_name.split("*")
if "expected scalar type Long" in str(err):
Expand All @@ -160,11 +157,11 @@ def run_ut(config):
gc.collect()


def run_paddle_api(api_full_name, real_data_path, api_info_dict):
def run_paddle_api(api_full_name, api_info_dict):
in_fwd_data_list = []
backward_message = ''
[api_type, api_name, _] = api_full_name.split('*')
args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path)
args, kwargs, need_grad = get_api_info(api_info_dict, api_name)
in_fwd_data_list.append(args)
in_fwd_data_list.append(kwargs)
need_backward = True
Expand Down Expand Up @@ -227,12 +224,6 @@ def run_ut_command_save(args):
forward_file = os.path.realpath(args.forward_input_file)
check_file_suffix(forward_file, FileCheckConst.JSON_SUFFIX)
forward_content = get_json_contents(forward_file)
backward_content = {}
if args.backward_input_file:
check_link(args.backward_input_file)
backward_file = os.path.realpath(args.backward_input_file)
check_file_suffix(backward_file, FileCheckConst.JSON_SUFFIX)
backward_content = get_json_contents(backward_file)
result_csv_path = os.path.join(out_path, RESULT_FILE_NAME)
details_csv_path = os.path.join(out_path, DETAILS_FILE_NAME)
if args.result_csv_path:
Expand All @@ -243,8 +234,8 @@ def run_ut_command_save(args):
time_info = result_csv_path.split('.')[0].split('_')[-1]
global UT_ERROR_DATA_DIR
UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
run_ut_config = RunUTConfig(forward_content, backward_content, result_csv_path, details_csv_path, save_error_data,
args.result_csv_path, args.real_data_path)
run_ut_config = RunUTConfig(forward_content, result_csv_path, details_csv_path, save_error_data,
args.result_csv_path)
run_ut_save(run_ut_config)


Expand All @@ -253,7 +244,8 @@ def run_ut_save(config):
for i, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)):
try:
print(api_full_name)
run_paddle_api_save(api_full_name, config.real_data_path, api_info_dict)
dump_path = os.path.dirname(config.result_csv_path)
run_paddle_api_save(api_full_name, api_info_dict, dump_path)
print("*"*200)
except Exception as err:
[_, api_name, _] = api_full_name.split("*")
Expand All @@ -266,11 +258,11 @@ def run_ut_save(config):
gc.collect()


def run_paddle_api_save(api_full_name, real_data_path, api_info_dict):
def run_paddle_api_save(api_full_name, api_info_dict, dump_path):
in_fwd_data_list = []
backward_message = ''
[api_type, api_name, _] = api_full_name.split('*')
args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path)
args, kwargs, need_grad = get_api_info(api_info_dict, api_name)
in_fwd_data_list.append(args)
in_fwd_data_list.append(kwargs)
need_backward = True
Expand All @@ -293,8 +285,13 @@ def run_paddle_api_save(api_full_name, real_data_path, api_info_dict):
output_folder = "npu_output"
else:
output_folder = "gpu_output"
current_dir = os.path.dirname(os.path.abspath(__file__))
output_dir = os.path.abspath(os.path.join(current_dir, "..", output_folder))
if dump_path == os.path.dirname(os.path.abspath(__file__)):
output_dir = os.path.abspath(os.path.join(dump_path, "..", output_folder))
output_dir_back = os.path.abspath(os.path.join(dump_path, "..", output_folder + "_backward"))
else:
current_dir = dump_path
output_dir = os.path.abspath(os.path.join(current_dir, output_folder))
output_dir_back = os.path.abspath(os.path.join(current_dir, output_folder + "_backward"))
os.makedirs(output_dir, exist_ok=True)
output_path = output_dir + '/' + f'{api_full_name}'
paddle.save(device_out, output_path)
Expand All @@ -314,19 +311,19 @@ def run_paddle_api_save(api_full_name, real_data_path, api_info_dict):
else:
backward_message += Backward_Message.MULTIPLE_BACKWARD_MESSAGE

output_dir = os.path.abspath(os.path.join(current_dir, "..", output_folder + "_backward"))
os.makedirs(output_dir, exist_ok=True)
output_path = output_dir + '/' + f'{api_full_name}'
# output_dir = os.path.abspath(os.path.join(current_dir, "..", output_folder + "_backward"))
os.makedirs(output_dir_back, exist_ok=True)
output_path = output_dir_back + '/' + f'{api_full_name}'
paddle.save(device_grad_out, output_path)
return


def get_api_info(api_info_dict, api_name, real_data_path):
def get_api_info(api_info_dict, api_name):
convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict)
need_grad = True
if api_info_dict.get("kwargs") and "out" in api_info_dict.get("kwargs"):
need_grad = False
args, kwargs = gen_api_params(api_info_dict, api_name, need_grad, convert_type, real_data_path)
args, kwargs = gen_api_params(api_info_dict, api_name, need_grad, convert_type)
return args, kwargs, need_grad


Expand All @@ -337,7 +334,6 @@ def need_to_backward(grad_index, out):


def run_backward(args, grad_index, out):

if grad_index is not None:
out[grad_index].backward()
else:
Expand All @@ -347,7 +343,6 @@ def run_backward(args, grad_index, out):
if isinstance(arg, paddle.Tensor):
args_grad.append(arg.grad)
grad_out = args_grad

return grad_out


Expand Down Expand Up @@ -391,10 +386,6 @@ def _run_ut_parser(parser):
help="<Optional> The api param tool forward result file: generate from api param tool, "
"a json file.",
required=True)
parser.add_argument("-backward", "--backward", dest="backward_input_file", default="", type=str,
help="<Optional> The api param tool backward result file: generate from api param tool, "
"a json file.",
required=False)
parser.add_argument("-o", "--dump_path", dest="out_path", default="", type=str,
help="<optional> The ut task result out path.",
required=False)
Expand Down Expand Up @@ -426,20 +417,14 @@ def __call__(self, parser, namespace, values, option_string=None):
help="<optional> The path of accuracy_checking_result_{timestamp}.csv, "
"when run ut is interrupted, enter the file path to continue run ut.",
required=False)
parser.add_argument("-real_data_path", dest="real_data_path", nargs="?", const="", default="", type=str,
help="<optional> In real data mode, the root directory for storing real data "
"must be configured.",
required=False)


def preprocess_forward_content(forward_content):
processed_content = {}
base_keys_variants = {}
arg_cache = {}

for key, value in forward_content.items():
base_key = key.rsplit(Const.DELIMITER, 1)[0]

if key not in arg_cache:
new_args = value['args']
new_kwargs = value['kwargs']
Expand All @@ -448,9 +433,7 @@ def preprocess_forward_content(forward_content):
for arg in new_args if isinstance(arg, dict)
]
arg_cache[key] = (filtered_new_args, new_kwargs)

filtered_new_args, new_kwargs = arg_cache[key]

if base_key not in base_keys_variants:
processed_content[key] = value
base_keys_variants[base_key] = {key}
Expand All @@ -461,11 +444,9 @@ def preprocess_forward_content(forward_content):
if existing_args == filtered_new_args and existing_kwargs == new_kwargs:
is_duplicate = True
break

if not is_duplicate:
processed_content[key] = value
base_keys_variants[base_key].add(key)

return processed_content


Expand Down Expand Up @@ -493,12 +474,6 @@ def run_ut_command(args):
forward_file = os.path.realpath(args.forward_input_file)
check_file_suffix(forward_file, FileCheckConst.JSON_SUFFIX)
forward_content = get_json_contents(forward_file)
backward_content = {}
if args.backward_input_file:
check_link(args.backward_input_file)
backward_file = os.path.realpath(args.backward_input_file)
check_file_suffix(backward_file, FileCheckConst.JSON_SUFFIX)
backward_content = get_json_contents(backward_file)
result_csv_path = os.path.join(out_path, RESULT_FILE_NAME)
details_csv_path = os.path.join(out_path, DETAILS_FILE_NAME)
if args.result_csv_path:
Expand All @@ -509,8 +484,8 @@ def run_ut_command(args):
time_info = result_csv_path.split('.')[0].split('_')[-1]
global UT_ERROR_DATA_DIR
UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
run_ut_config = RunUTConfig(forward_content, backward_content, result_csv_path, details_csv_path, save_error_data,
args.result_csv_path, args.real_data_path)
run_ut_config = RunUTConfig(forward_content, result_csv_path, details_csv_path, save_error_data,
args.result_csv_path)
run_ut(run_ut_config)


Expand All @@ -525,6 +500,7 @@ def __init__(self, bench_grad, device_grad, device_output, bench_output, in_fwd_
self.backward_message = backward_message
self.rank = rank


if __name__ == "__main__":
_run_ut()
print_info_log("UT task completed")
7 changes: 0 additions & 7 deletions Acc/run_ut/run_ut_utils.py

This file was deleted.

Loading