Skip to content
This repository was archived by the owner on Apr 29, 2021. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions src/myrtlespeech/run/callbacks/csv_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,20 @@


class CSVLogger(Callback):
r"""Logs at the end of an epoch to the file at ``path`` in CSV format.
r"""Logs at the end of an epoch in CSV format.

This callback saves the CSV to ``log_dir`` with file name
``log_<datetime.isoformat>.csv`` when
:py:meth:`.CallbackHandler.on_train_begin` is called.

The first entry in each row of the CSV file denotes the date and time the
entry was logged in ISO 8601 format (:py:meth:`datetime.isoformat`). The
second entry in each row of the CSV file will be a string denoting the
current stage. This will either be ``train`` or ``eval``.

Args:
path: Path of a file to write values to. The file will be truncated
when :py:meth:`.CallbackHandler.on_train_begin` is called.
log_dir: A pathlike object representing the directory to which results
will be logged.

keys: Keys in :py:data:`.CallbackHandler.state_dict`, passed as
``**kwargs`` to the ``CSVLogger``, whose values will be logged. If
Expand All @@ -36,13 +40,14 @@ class CSVLogger(Callback):
Example:
>>> # imports
>>> import tempfile
>>> import glob
>>> from myrtlespeech.run.callbacks.callback import CallbackHandler
>>>
>>> # create file to write to
>>> temp = tempfile.NamedTemporaryFile()
>>> temp = tempfile.TemporaryDirectory()
>>>
>>> # initialize CSVLogger and CallbackHandler
>>> csv_logger = CSVLogger(path=temp.name, keys=["epoch"])
>>> csv_logger = CSVLogger(log_dir=temp.name, keys=["epoch"])
>>> cb_handler = CallbackHandler(callbacks=[csv_logger])
>>>
>>> # simulate training and eval example
Expand All @@ -56,7 +61,8 @@ class CSVLogger(Callback):
>>> # note the first entry in each CSV row is the date
>>> # this will change per each doctest invocation so it is ignored
>>> # using ellipsis (...)
>>> csv_contents = open(temp.name, 'r').read()
>>> csv_name = glob.glob(f'{temp.name}/*.csv')[0]
>>> csv_contents = open(csv_name, 'r').read()
>>> print("example_output:\n", csv_contents) # doctest:+ELLIPSIS
example_output:
...,stage,epoch
Expand All @@ -67,12 +73,14 @@ class CSVLogger(Callback):

def __init__(
self,
path: Union[str, pathlib.Path],
log_dir: Union[str, pathlib.Path],
keys: Optional[Container] = None,
exclude: Optional[Container] = None,
):
super().__init__()
self.path = pathlib.Path(path)
self.path = (
pathlib.Path(log_dir) / f"log_{datetime.now().isoformat()}.csv"
)
self.keys = keys
self.exclude = set() if exclude is None else exclude

Expand Down
4 changes: 1 addition & 3 deletions src/myrtlespeech/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,7 @@ def run() -> None:
if args.stop_epoch_after is not None:
callbacks.append(StopEpochAfter(epoch_batches=args.stop_epoch_after))

callbacks.extend(
[CSVLogger(log_dir.joinpath("log.csv")), Saver(log_dir, seq_to_seq)]
)
callbacks.extend([CSVLogger(log_dir), Saver(log_dir, seq_to_seq)])

# train and evaluate
fit(
Expand Down