From e79853f829108086f7e62048fd085f3f97ede576 Mon Sep 17 00:00:00 2001 From: julianmack Date: Fri, 31 Jan 2020 11:22:28 +0000 Subject: [PATCH] Updated csv logger so that it no longer overwrites old csv logs --- src/myrtlespeech/run/callbacks/csv_logger.py | 24 +++++++++++++------- src/myrtlespeech/run/run.py | 4 +--- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/myrtlespeech/run/callbacks/csv_logger.py b/src/myrtlespeech/run/callbacks/csv_logger.py index d72e7769..7d6222f2 100644 --- a/src/myrtlespeech/run/callbacks/csv_logger.py +++ b/src/myrtlespeech/run/callbacks/csv_logger.py @@ -9,7 +9,11 @@ class CSVLogger(Callback): - r"""Logs at the end of an epoch to the file at ``path`` in CSV format. + r"""Logs at the end of an epoch in CSV format. + + This callback saves the CSV to ``log_dir`` with file name + ``log_.csv`` when + :py:meth:`.CallbackHandler.on_train_begin` is called. The first entry in each row of the CSV file denotes the date and time the entry was logged in ISO 8601 format (:py:meth:`datetime.isoformat`). The @@ -17,8 +21,8 @@ class CSVLogger(Callback): current stage. This will either be ``train`` or ``eval``. Args: - path: Path of a file to write values to. The file will be truncated - when :py:meth:`.CallbackHandler.on_train_begin` is called. + log_dir: A pathlike object representing the directory to which results + will be logged. keys: Keys in :py:data:`.CallbackHandler.state_dict`, passed as ``**kwargs`` to the ``CSVLogger``, whose values will be logged. If @@ -36,13 +40,14 @@ class CSVLogger(Callback): Example: >>> # imports >>> import tempfile + >>> import glob >>> from myrtlespeech.run.callbacks.callback import CallbackHandler >>> >>> # create file to write to - >>> temp = tempfile.NamedTemporaryFile() + >>> temp = tempfile.TemporaryDirectory() >>> >>> # initialize CSVLogger and CallbackHandler - >>> csv_logger = CSVLogger(path=temp.name, keys=["epoch"]) + >>> csv_logger = CSVLogger(log_dir=temp.name, keys=["epoch"]) >>> cb_handler = CallbackHandler(callbacks=[csv_logger]) >>> >>> # simulate training and eval example @@ -56,7 +61,8 @@ class CSVLogger(Callback): >>> # note the first entry in each CSV row is the date >>> # this will change per each doctest invocation so it is ignored >>> # using ellipsis (...) - >>> csv_contents = open(temp.name, 'r').read() + >>> csv_name = glob.glob(f'{temp.name}/*.csv')[0] + >>> csv_contents = open(csv_name, 'r').read() >>> print("example_output:\n", csv_contents) # doctest:+ELLIPSIS example_output: ...,stage,epoch @@ -67,12 +73,14 @@ class CSVLogger(Callback): def __init__( self, - path: Union[str, pathlib.Path], + log_dir: Union[str, pathlib.Path], keys: Optional[Container] = None, exclude: Optional[Container] = None, ): super().__init__() - self.path = pathlib.Path(path) + self.path = ( + pathlib.Path(log_dir) / f"log_{datetime.now().isoformat()}.csv" + ) self.keys = keys self.exclude = set() if exclude is None else exclude diff --git a/src/myrtlespeech/run/run.py b/src/myrtlespeech/run/run.py index ec0b77ad..2fa6fdf9 100644 --- a/src/myrtlespeech/run/run.py +++ b/src/myrtlespeech/run/run.py @@ -229,9 +229,7 @@ def run() -> None: if args.stop_epoch_after is not None: callbacks.append(StopEpochAfter(epoch_batches=args.stop_epoch_after)) - callbacks.extend( - [CSVLogger(log_dir.joinpath("log.csv")), Saver(log_dir, seq_to_seq)] - ) + callbacks.extend([CSVLogger(log_dir), Saver(log_dir, seq_to_seq)]) # train and evaluate fit(