Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions paddleapex/api_tracer/api_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,23 @@ def update_APIInfo(self, op_name, rank):
self.op_name = op_name
self.rank = rank

def update_real_data(self, inputs, kwargs):

def get_extra_param(self):
init_param = cfg.cls_target_obj[self.op_name.split('*')[0]]['extra_param']
return init_param

def update_real_data(self, inputs, kwargs, param={}, api_type="op"):
self.is_half_precision = False
args_info_list = self.analyze_element(inputs)
kwargs_info_dict = self.analyze_element(kwargs)
param_info_dict = self.analyze_element(param)
# print(param_info_dict)
self.api_info_struct = {
self.op_name: {"args": args_info_list, "kwargs": kwargs_info_dict, "dout_list": ["Failed"]}
self.op_name: {"api_type": api_type, "args": args_info_list, "kwargs": kwargs_info_dict}
}
self.api_info_struct[self.op_name].update(param_info_dict | {"dout_list": ["Failed"]})
# print(self.api_info_struct)

dump_util.update_api_dict(self.api_info_struct, self.rank, self.is_half_precision)

def record_dout(self, grad_value):
Expand Down
7 changes: 7 additions & 0 deletions paddleapex/api_tracer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ def __init__(self) -> None:
configs = yaml.load(f, Loader=yaml.FullLoader)
self.dump_mode = configs["dump_mode"]
self.op_target_pth = configs["op_target_path"]
self.cls_target_pth = configs["cls_target_path"]
if self.op_target_pth == "None":
self.op_target_pth = os.path.join(current_dir, "configs/op_target.yaml")
if self.cls_target_pth == "None":
self.cls_target_pth = os.path.join(current_dir, "configs/cls_target.yaml")
self.dump_root_path = configs["dump_root_path"]
self.target_step = configs["target_step"]
self.remote_path = configs["remote_path"]
Expand All @@ -42,8 +45,11 @@ def __init__(self) -> None:
time.sleep(1)
self.global_step = -1
self.dump_state = False
# When dump_class_state is true, dump_func needs to be temporarily disabled to avoid duplication.
self.disable_dump_func_state = False
self.Op_count = {}
self.prefix_op_name_ = None
self.cls_target_obj = None

def new_step(self):
self.global_step += 1
Expand All @@ -55,4 +61,5 @@ def new_step(self):
self.dump_state = False



cfg = Config()
9 changes: 9 additions & 0 deletions paddleapex/api_tracer/configs/cls_target.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
ignored_cls:

target_cls:
paddle_xpu.layers.nn.Linear:
extra_param:
- self.weight
- self.bias
paddle_cls:
paddle.nn.Linear
1 change: 1 addition & 0 deletions paddleapex/api_tracer/configs/tool_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# Target APIs.
op_target_path: "None"
cls_target_path: "None"
# If op_target_path is set to None, tool will set the default config in repository configs.
dump_root_path: "./dump_info"

Expand Down
123 changes: 92 additions & 31 deletions paddleapex/api_tracer/wrap_op/OPTemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,40 +30,101 @@ def __init__(self, op_name):
cfg.prefix_op_name_ = self.op_name_ + "*"

def forward(self, *args, **kwargs):
if self.op_name_ not in cfg.Op_count:
cfg.Op_count[self.op_name_] = 1
cfg.prefix_op_name_ += "0"
else:
cfg.Op_count[self.op_name_] += 1
cfg.prefix_op_name_ += str(cfg.Op_count[self.op_name_] - 1)
if cfg.dump_state:
api_recorder = API(cfg.dump_mode)
rank = dist.get_rank()
api_recorder.update_APIInfo(cfg.prefix_op_name_, rank)
api_recorder.update_real_data(args, kwargs)
output = getattr(HookOp, "wrap_" + str(self.op_name_))(*args, **kwargs)
try:
if isinstance(output, paddle.Tensor):
if not output.stop_gradient:
output.register_hook(api_recorder.record_dout)
api_recorder.output_num = 1
else:
api_recorder.record_dout(None)
if isinstance(output, (list, tuple)):
need_record = False
for item in output:
if isinstance(item, paddle.Tensor) and not item.stop_gradient:
api_recorder.output_num += 1
need_record = True
item.register_hook(api_recorder.record_dout)
if not need_record:
api_recorder.record_dout(None)
except Exception as e:
print(self.op_name_, " register hook failed. Due to :", e)
api_recorder.record_dout(None)
if not cfg.disable_dump_func_state:
if self.op_name_ not in cfg.Op_count:
cfg.Op_count[self.op_name_] = 1
cfg.prefix_op_name_ += "0"
else:
cfg.Op_count[self.op_name_] += 1
cfg.prefix_op_name_ += str(cfg.Op_count[self.op_name_] - 1)
if cfg.dump_state:
api_recorder = API(cfg.dump_mode)
rank = dist.get_rank()
api_recorder.update_APIInfo(cfg.prefix_op_name_, rank)
api_recorder.update_real_data(args, kwargs)
# print(self.op_name_)
output = getattr(HookOp, "wrap_" + str(self.op_name_))(*args, **kwargs)
try:
if isinstance(output, paddle.Tensor):
if not output.stop_gradient:
output.register_hook(api_recorder.record_dout)
api_recorder.output_num = 1
else:
api_recorder.record_dout(None)
if isinstance(output, (list, tuple)):
need_record = False
for item in output:
if isinstance(item, paddle.Tensor) and not item.stop_gradient:
api_recorder.output_num += 1
need_record = True
item.register_hook(api_recorder.record_dout)
if not need_record:
api_recorder.record_dout(None)
except Exception as e:
print(self.op_name_, " register hook failed. Due to :", e)
api_recorder.record_dout(None)
else:
output = getattr(HookOp, "wrap_" + str(self.op_name_))(*args, **kwargs)
else:
output = getattr(HookOp, "wrap_" + str(self.op_name_))(*args, **kwargs)
return output

def __call__(self, *inputs, **kwargs):
return self.forward(*inputs, **kwargs)


def temp_init(self, *inputs, **kwargs):
# print("============init==================")
self.cls_all_name_ = self.__class__.__name__
# print("self.__class__.__name__ = ", self.cls_all_name_)
cfg.prefix_op_name_ = self.cls_all_name_ + "*"
cfg.disable_dump_func_state = True
super(self.__class__, self).__init__(*inputs, **kwargs)

def temp_forward(self, *inputs, **kwargs):
# print("============forward==================")
if self.cls_all_name_ not in cfg.Op_count:
cfg.Op_count[self.cls_all_name_] = 1
cfg.prefix_op_name_ += "0"
else:
cfg.Op_count[self.cls_all_name_] += 1
cfg.prefix_op_name_ += str(cfg.Op_count[self.cls_all_name_] - 1)
if cfg.dump_state:
api_recorder = API(cfg.dump_mode)
rank = dist.get_rank()
api_recorder.update_APIInfo(cfg.prefix_op_name_, rank)

extra_param_str = api_recorder.get_extra_param()
extra_param = {param: getattr(self, param.split('.')[1]) for param in extra_param_str}

# print("extra_param: \n", extra_param)
api_recorder.update_real_data(inputs, kwargs, extra_param, "class")
# Call the parent class function
# print(self.cls_all_name_ + '.forword')
output = super(self.__class__, self).forward(*inputs, **kwargs)
try:
if isinstance(output, paddle.Tensor):
if not output.stop_gradient:
output.register_hook(api_recorder.record_dout)
api_recorder.output_num = 1
else:
api_recorder.record_dout(None)
if isinstance(output, (list, tuple)):
need_record = False
for item in output:
if isinstance(item, paddle.Tensor) and not item.stop_gradient:
api_recorder.output_num += 1
need_record = True
item.register_hook(api_recorder.record_dout)
if not need_record:
api_recorder.record_dout(None)
except Exception as e:
print(self.cls_all_name_+ '.forword', " register hook failed. Due to :", e)
api_recorder.record_dout(None)
else:
output = super(self.__class__, self).forward(*inputs, **kwargs)

cfg.disable_dump_func_state = False
# print(output)
return output

19 changes: 19 additions & 0 deletions paddleapex/api_tracer/wrap_op/get_target_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,26 @@ def check_api_stack(self):

def get_target_ops(self):
self.api_to_catch = set(self.target_op) - set(self.ignored_op)

# if profile mode is on, we will not catch max and min
if cfg.profile_mode:
self.api_to_catch -= set(["paddle.max", "paddle.min"])
self.check_api_stack()
return self.api_to_catch

class GetTargetCls(GetTargetOP):
def __init__(self, yaml_path):
with open(yaml_path, "r") as f:
Ops = yaml.safe_load(f)
self.target_op = Ops.get("target_cls")
cfg.cls_target_obj = self.target_op
# print(self.target_op)
self.ignored_op = Ops.get("ignored_cls")
f.close()
if self.ignored_op is None:
self.ignored_op = []
self.api_to_catch = set(self.target_op.keys()) - set(self.ignored_op)

def get_target_ops(self):
self.check_api_stack()
return self.api_to_catch
62 changes: 57 additions & 5 deletions paddleapex/api_tracer/wrap_op/hijack_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
# limitations under the License.


import inspect
from .. import config
from ...utils import try_import
from .get_target_op import GetTargetOP
from .OPTemplate import OPTemplate, HookOp
from .get_target_op import GetTargetOP, GetTargetCls
from .OPTemplate import OPTemplate, HookOp, temp_init, temp_forward

cfg = config.cfg

Expand All @@ -27,22 +28,73 @@ def op_template(*args, **kwargs):

return op_template

# def wrapped_cls(cls_name, cls):
# def op_template(*args, **kwargs):
# # 可以从HookCls中获取cls对象
# wrap_class = type(f"{cls_name}", (cls,), {'__init__': temp_init, 'forward': temp_forward})
# return wrap_class(cls_name, *args, **kwargs)

# return op_template

def add_class(cls, parent_package, class_name, whole_name):
# print(whole_name)
module = eval(parent_package)

original_class = getattr(module, class_name, None)
if original_class:
print(f"Original class before replacement: {class_name} (id: {id(original_class)})")
else:
print(f"No original class named {class_name} found in {parent_package}")

wrap_class = type(f"{whole_name}", (cls,), {'__init__': temp_init, 'forward': temp_forward})
setattr(module, class_name, wrap_class)

new_class = getattr(module, class_name, None)
if new_class:
print(f"New class after replacement: {class_name} (id: {id(new_class)})")

if original_class != new_class:
print(f"Class {class_name} successfully replaced.")
else:
print(f"Class {class_name} replacement failed.")



def hijack_api():
op = GetTargetOP(cfg.op_target_pth)
cls = GetTargetCls(cfg.cls_target_pth)
target_op = op.get_target_ops()
target_cls = cls.get_target_ops()
for op_name in target_op:
parent_package, method_name = op_name.rsplit(".", maxsplit=1)
try:
pack = parent_package.split(".")[0]
package_name, module = try_import(pack)
globals()[package_name] = module
setattr(
HookOp, "wrap_" + op_name, getattr(eval(parent_package), method_name)
)
target_object = getattr(eval(parent_package), method_name)
if inspect.isclass(target_object):
raise ValueError(f"{op_name} is a class")
else:
setattr(
HookOp, "wrap_" + op_name, target_object
)
except Exception as err:
print(op_name, str(err))

for cls_name in target_cls:
parent_package, method_name = cls_name.rsplit(".", maxsplit=1)
try:
pack = parent_package.split(".")[0]
package_name, module = try_import(pack)
globals()[package_name] = module
target_object = getattr(eval(parent_package), method_name)
if inspect.isclass(target_object):
add_class(target_object, parent_package, method_name, cls_name)
else:
raise ValueError(f"{cls} not a class")
except Exception as err:
print(cls_name, str(err))

for attr_name in dir(HookOp):
if attr_name.startswith("wrap_"):
parent_package, method_name = attr_name[5:].rsplit(".", maxsplit=1)
Expand Down