diff --git a/paddleapex/api_tracer/api_info.py b/paddleapex/api_tracer/api_info.py index c677845..cbc509e 100644 --- a/paddleapex/api_tracer/api_info.py +++ b/paddleapex/api_tracer/api_info.py @@ -107,13 +107,23 @@ def update_APIInfo(self, op_name, rank): self.op_name = op_name self.rank = rank - def update_real_data(self, inputs, kwargs): + + def get_extra_param(self): + init_param = cfg.cls_target_obj[self.op_name.split('*')[0]]['extra_param'] + return init_param + + def update_real_data(self, inputs, kwargs, param={}, api_type="op"): self.is_half_precision = False args_info_list = self.analyze_element(inputs) kwargs_info_dict = self.analyze_element(kwargs) + param_info_dict = self.analyze_element(param) + # print(param_info_dict) self.api_info_struct = { - self.op_name: {"args": args_info_list, "kwargs": kwargs_info_dict, "dout_list": ["Failed"]} + self.op_name: {"api_type": api_type, "args": args_info_list, "kwargs": kwargs_info_dict} } + self.api_info_struct[self.op_name].update(param_info_dict | {"dout_list": ["Failed"]}) + # print(self.api_info_struct) + dump_util.update_api_dict(self.api_info_struct, self.rank, self.is_half_precision) def record_dout(self, grad_value): diff --git a/paddleapex/api_tracer/config.py b/paddleapex/api_tracer/config.py index 6c40c73..6fd297c 100644 --- a/paddleapex/api_tracer/config.py +++ b/paddleapex/api_tracer/config.py @@ -26,8 +26,11 @@ def __init__(self) -> None: configs = yaml.load(f, Loader=yaml.FullLoader) self.dump_mode = configs["dump_mode"] self.op_target_pth = configs["op_target_path"] + self.cls_target_pth = configs["cls_target_path"] if self.op_target_pth == "None": self.op_target_pth = os.path.join(current_dir, "configs/op_target.yaml") + if self.cls_target_pth == "None": + self.cls_target_pth = os.path.join(current_dir, "configs/cls_target.yaml") self.dump_root_path = configs["dump_root_path"] self.target_step = configs["target_step"] self.remote_path = configs["remote_path"] @@ -42,8 +45,11 @@ def __init__(self) -> None: time.sleep(1) self.global_step = -1 self.dump_state = False + # When dump_class_state is true, dump_func needs to be temporarily disabled to avoid duplication. + self.disable_dump_func_state = False self.Op_count = {} self.prefix_op_name_ = None + self.cls_target_obj = None def new_step(self): self.global_step += 1 @@ -55,4 +61,5 @@ def new_step(self): self.dump_state = False + cfg = Config() diff --git a/paddleapex/api_tracer/configs/cls_target.yaml b/paddleapex/api_tracer/configs/cls_target.yaml new file mode 100644 index 0000000..9e74fd8 --- /dev/null +++ b/paddleapex/api_tracer/configs/cls_target.yaml @@ -0,0 +1,9 @@ +ignored_cls: + +target_cls: + paddle_xpu.layers.nn.Linear: + extra_param: + - self.weight + - self.bias + paddle_cls: + paddle.nn.Linear \ No newline at end of file diff --git a/paddleapex/api_tracer/configs/tool_config.yaml b/paddleapex/api_tracer/configs/tool_config.yaml index 30375b1..43c3b7b 100644 --- a/paddleapex/api_tracer/configs/tool_config.yaml +++ b/paddleapex/api_tracer/configs/tool_config.yaml @@ -2,6 +2,7 @@ # Target APIs. op_target_path: "None" +cls_target_path: "None" # If op_target_path is set to None, tool will set the default config in repository configs. dump_root_path: "./dump_info" diff --git a/paddleapex/api_tracer/wrap_op/OPTemplate.py b/paddleapex/api_tracer/wrap_op/OPTemplate.py index f9e8833..7126d0b 100644 --- a/paddleapex/api_tracer/wrap_op/OPTemplate.py +++ b/paddleapex/api_tracer/wrap_op/OPTemplate.py @@ -30,40 +30,101 @@ def __init__(self, op_name): cfg.prefix_op_name_ = self.op_name_ + "*" def forward(self, *args, **kwargs): - if self.op_name_ not in cfg.Op_count: - cfg.Op_count[self.op_name_] = 1 - cfg.prefix_op_name_ += "0" - else: - cfg.Op_count[self.op_name_] += 1 - cfg.prefix_op_name_ += str(cfg.Op_count[self.op_name_] - 1) - if cfg.dump_state: - api_recorder = API(cfg.dump_mode) - rank = dist.get_rank() - api_recorder.update_APIInfo(cfg.prefix_op_name_, rank) - api_recorder.update_real_data(args, kwargs) - output = getattr(HookOp, "wrap_" + str(self.op_name_))(*args, **kwargs) - try: - if isinstance(output, paddle.Tensor): - if not output.stop_gradient: - output.register_hook(api_recorder.record_dout) - api_recorder.output_num = 1 - else: - api_recorder.record_dout(None) - if isinstance(output, (list, tuple)): - need_record = False - for item in output: - if isinstance(item, paddle.Tensor) and not item.stop_gradient: - api_recorder.output_num += 1 - need_record = True - item.register_hook(api_recorder.record_dout) - if not need_record: - api_recorder.record_dout(None) - except Exception as e: - print(self.op_name_, " register hook failed. Due to :", e) - api_recorder.record_dout(None) + if not cfg.disable_dump_func_state: + if self.op_name_ not in cfg.Op_count: + cfg.Op_count[self.op_name_] = 1 + cfg.prefix_op_name_ += "0" + else: + cfg.Op_count[self.op_name_] += 1 + cfg.prefix_op_name_ += str(cfg.Op_count[self.op_name_] - 1) + if cfg.dump_state: + api_recorder = API(cfg.dump_mode) + rank = dist.get_rank() + api_recorder.update_APIInfo(cfg.prefix_op_name_, rank) + api_recorder.update_real_data(args, kwargs) + # print(self.op_name_) + output = getattr(HookOp, "wrap_" + str(self.op_name_))(*args, **kwargs) + try: + if isinstance(output, paddle.Tensor): + if not output.stop_gradient: + output.register_hook(api_recorder.record_dout) + api_recorder.output_num = 1 + else: + api_recorder.record_dout(None) + if isinstance(output, (list, tuple)): + need_record = False + for item in output: + if isinstance(item, paddle.Tensor) and not item.stop_gradient: + api_recorder.output_num += 1 + need_record = True + item.register_hook(api_recorder.record_dout) + if not need_record: + api_recorder.record_dout(None) + except Exception as e: + print(self.op_name_, " register hook failed. Due to :", e) + api_recorder.record_dout(None) + else: + output = getattr(HookOp, "wrap_" + str(self.op_name_))(*args, **kwargs) else: output = getattr(HookOp, "wrap_" + str(self.op_name_))(*args, **kwargs) return output def __call__(self, *inputs, **kwargs): return self.forward(*inputs, **kwargs) + + +def temp_init(self, *inputs, **kwargs): + # print("============init==================") + self.cls_all_name_ = self.__class__.__name__ + # print("self.__class__.__name__ = ", self.cls_all_name_) + cfg.prefix_op_name_ = self.cls_all_name_ + "*" + cfg.disable_dump_func_state = True + super(self.__class__, self).__init__(*inputs, **kwargs) + +def temp_forward(self, *inputs, **kwargs): + # print("============forward==================") + if self.cls_all_name_ not in cfg.Op_count: + cfg.Op_count[self.cls_all_name_] = 1 + cfg.prefix_op_name_ += "0" + else: + cfg.Op_count[self.cls_all_name_] += 1 + cfg.prefix_op_name_ += str(cfg.Op_count[self.cls_all_name_] - 1) + if cfg.dump_state: + api_recorder = API(cfg.dump_mode) + rank = dist.get_rank() + api_recorder.update_APIInfo(cfg.prefix_op_name_, rank) + + extra_param_str = api_recorder.get_extra_param() + extra_param = {param: getattr(self, param.split('.')[1]) for param in extra_param_str} + + # print("extra_param: \n", extra_param) + api_recorder.update_real_data(inputs, kwargs, extra_param, "class") + # Call the parent class function + # print(self.cls_all_name_ + '.forword') + output = super(self.__class__, self).forward(*inputs, **kwargs) + try: + if isinstance(output, paddle.Tensor): + if not output.stop_gradient: + output.register_hook(api_recorder.record_dout) + api_recorder.output_num = 1 + else: + api_recorder.record_dout(None) + if isinstance(output, (list, tuple)): + need_record = False + for item in output: + if isinstance(item, paddle.Tensor) and not item.stop_gradient: + api_recorder.output_num += 1 + need_record = True + item.register_hook(api_recorder.record_dout) + if not need_record: + api_recorder.record_dout(None) + except Exception as e: + print(self.cls_all_name_+ '.forword', " register hook failed. Due to :", e) + api_recorder.record_dout(None) + else: + output = super(self.__class__, self).forward(*inputs, **kwargs) + + cfg.disable_dump_func_state = False + # print(output) + return output + diff --git a/paddleapex/api_tracer/wrap_op/get_target_op.py b/paddleapex/api_tracer/wrap_op/get_target_op.py index 9b5c2f0..9a25845 100644 --- a/paddleapex/api_tracer/wrap_op/get_target_op.py +++ b/paddleapex/api_tracer/wrap_op/get_target_op.py @@ -44,7 +44,26 @@ def check_api_stack(self): def get_target_ops(self): self.api_to_catch = set(self.target_op) - set(self.ignored_op) + + # if profile mode is on, we will not catch max and min if cfg.profile_mode: self.api_to_catch -= set(["paddle.max", "paddle.min"]) self.check_api_stack() return self.api_to_catch + +class GetTargetCls(GetTargetOP): + def __init__(self, yaml_path): + with open(yaml_path, "r") as f: + Ops = yaml.safe_load(f) + self.target_op = Ops.get("target_cls") + cfg.cls_target_obj = self.target_op + # print(self.target_op) + self.ignored_op = Ops.get("ignored_cls") + f.close() + if self.ignored_op is None: + self.ignored_op = [] + self.api_to_catch = set(self.target_op.keys()) - set(self.ignored_op) + + def get_target_ops(self): + self.check_api_stack() + return self.api_to_catch diff --git a/paddleapex/api_tracer/wrap_op/hijack_tool.py b/paddleapex/api_tracer/wrap_op/hijack_tool.py index 8dad0f2..867425f 100644 --- a/paddleapex/api_tracer/wrap_op/hijack_tool.py +++ b/paddleapex/api_tracer/wrap_op/hijack_tool.py @@ -13,10 +13,11 @@ # limitations under the License. +import inspect from .. import config from ...utils import try_import -from .get_target_op import GetTargetOP -from .OPTemplate import OPTemplate, HookOp +from .get_target_op import GetTargetOP, GetTargetCls +from .OPTemplate import OPTemplate, HookOp, temp_init, temp_forward cfg = config.cfg @@ -27,22 +28,73 @@ def op_template(*args, **kwargs): return op_template +# def wrapped_cls(cls_name, cls): +# def op_template(*args, **kwargs): +# # 可以从HookCls中获取cls对象 +# wrap_class = type(f"{cls_name}", (cls,), {'__init__': temp_init, 'forward': temp_forward}) +# return wrap_class(cls_name, *args, **kwargs) + +# return op_template + +def add_class(cls, parent_package, class_name, whole_name): + # print(whole_name) + module = eval(parent_package) + + original_class = getattr(module, class_name, None) + if original_class: + print(f"Original class before replacement: {class_name} (id: {id(original_class)})") + else: + print(f"No original class named {class_name} found in {parent_package}") + + wrap_class = type(f"{whole_name}", (cls,), {'__init__': temp_init, 'forward': temp_forward}) + setattr(module, class_name, wrap_class) + + new_class = getattr(module, class_name, None) + if new_class: + print(f"New class after replacement: {class_name} (id: {id(new_class)})") + + if original_class != new_class: + print(f"Class {class_name} successfully replaced.") + else: + print(f"Class {class_name} replacement failed.") + + def hijack_api(): op = GetTargetOP(cfg.op_target_pth) + cls = GetTargetCls(cfg.cls_target_pth) target_op = op.get_target_ops() + target_cls = cls.get_target_ops() for op_name in target_op: parent_package, method_name = op_name.rsplit(".", maxsplit=1) try: pack = parent_package.split(".")[0] package_name, module = try_import(pack) globals()[package_name] = module - setattr( - HookOp, "wrap_" + op_name, getattr(eval(parent_package), method_name) - ) + target_object = getattr(eval(parent_package), method_name) + if inspect.isclass(target_object): + raise ValueError(f"{op_name} is a class") + else: + setattr( + HookOp, "wrap_" + op_name, target_object + ) except Exception as err: print(op_name, str(err)) + for cls_name in target_cls: + parent_package, method_name = cls_name.rsplit(".", maxsplit=1) + try: + pack = parent_package.split(".")[0] + package_name, module = try_import(pack) + globals()[package_name] = module + target_object = getattr(eval(parent_package), method_name) + if inspect.isclass(target_object): + add_class(target_object, parent_package, method_name, cls_name) + else: + raise ValueError(f"{cls} not a class") + except Exception as err: + print(cls_name, str(err)) + for attr_name in dir(HookOp): if attr_name.startswith("wrap_"): parent_package, method_name = attr_name[5:].rsplit(".", maxsplit=1)