# Patch summary (commentary only; patch tools skip text before the first
# "diff --git" header).
#
# paddleapex/api_tracer/Tracer.py
#   * Imports print_info_log from paddleapex.apex.utils.
#   * Adds Tracer.start_in_training(cur_step, acc):
#       - stores acc on the instance and derives
#         global_step = cur_step // acc and inner_step = cur_step % acc;
#       - when inner_step == 0, calls cfg.new_step_in_training(global_step)
#         and logs "Starting tracing step:<global_step>" if it returns a
#         truthy value.
#       - presumably acc is the number of gradient-accumulation micro-steps
#         per global step -- TODO confirm against the caller.
#   * Adds Tracer.stop_in_training():
#       - when inner_step == acc - 1, calls
#         cfg.reset_step_in_training(global_step); if that returns a truthy
#         value, logs "Stopping tracing step:<global_step>" and calls
#         dump_util.dump().
#
# paddleapex/api_tracer/config.py
#   * Adds new_step_in_training(global_step): if global_step is in
#     self.target_step, clears self.Op_count, sets self.dump_state = True
#     and returns True; otherwise returns False without touching state.
#   * Adds reset_step_in_training(global_step): if global_step is in
#     self.target_step, records it in self.global_step, sets
#     self.dump_state = False and returns True; otherwise returns False.
#
# NOTE(review):
#   * stop_in_training reads self.inner_step, self.acc and self.global_step,
#     which this patch only assigns in start_in_training. Whether they are
#     also initialised elsewhere (e.g. __init__) is not visible here --
#     confirm, otherwise calling stop_in_training first raises AttributeError.
#   * acc == 0 makes "cur_step // acc" raise ZeroDivisionError.
#   * The new Tracer.py ends without a trailing newline (see the
#     "\ No newline at end of file" marker below).
#   * This patch was received with its line breaks collapsed. Line structure
#     and indentation were reconstructed from the hunk line counts and Python
#     syntax; in particular, confirm that dump_util.dump() belongs inside
#     "if dump_signal:" in stop_in_training.
diff --git a/paddleapex/api_tracer/Tracer.py b/paddleapex/api_tracer/Tracer.py
index 9efe197..4570726 100644
--- a/paddleapex/api_tracer/Tracer.py
+++ b/paddleapex/api_tracer/Tracer.py
@@ -16,6 +16,7 @@
 from paddleapex.api_tracer.Dump import dump_util
 from paddleapex.api_tracer.wrap_op.hijack_tool import hijack_api
 from paddleapex.api_tracer.config import cfg
+from paddleapex.apex.utils import print_info_log
 
 
 class Tracer:
@@ -32,3 +33,19 @@ def start(self):
     def stop(self):
         if cfg.dump_state:
             dump_util.dump()
+
+    def start_in_training(self, cur_step, acc):
+        self.acc = acc
+        self.global_step = cur_step // acc
+        self.inner_step = cur_step % acc
+        if self.inner_step == 0:
+            dump_signal = cfg.new_step_in_training(self.global_step)
+            if dump_signal:
+                print_info_log(f"Starting tracing step:{self.global_step}")
+
+    def stop_in_training(self):
+        if self.inner_step == self.acc - 1:
+            dump_signal = cfg.reset_step_in_training(self.global_step)
+            if dump_signal:
+                print_info_log(f"Stopping tracing step:{self.global_step}")
+                dump_util.dump()
\ No newline at end of file
diff --git a/paddleapex/api_tracer/config.py b/paddleapex/api_tracer/config.py
index 6c40c73..8f0e5ef 100644
--- a/paddleapex/api_tracer/config.py
+++ b/paddleapex/api_tracer/config.py
@@ -53,6 +53,20 @@ def new_step(self):
         else:
             self.Op_count = {}
             self.dump_state = False
+
+    def new_step_in_training(self, global_step):
+        if global_step in self.target_step:
+            self.Op_count = {}
+            self.dump_state = True
+            return True
+        return False
+
+    def reset_step_in_training(self, global_step):
+        if global_step in self.target_step:
+            self.global_step = global_step
+            self.dump_state = False
+            return True
+        return False
 
 
 cfg = Config()