diff --git a/audinterface/core/process.py b/audinterface/core/process.py index c7e9977..55364bc 100644 --- a/audinterface/core/process.py +++ b/audinterface/core/process.py @@ -1,3 +1,4 @@ +from collections.abc import Iterable import errno import inspect import itertools @@ -17,7 +18,7 @@ from audinterface.core.typing import Timestamps -def identity(signal, sampling_rate) -> np.ndarray: +def identity(signal, sampling_rate=None) -> np.ndarray: r"""Default processing function. This function is used, @@ -104,6 +105,13 @@ class Process: multiprocessing multiprocessing: use multiprocessing instead of multithreading verbose: show debug messages + read_func: function to read in signals/data. When specified, + it needs to be able to read signals signal data as well + as text data. + Per default, :func:`audinterface.utils.read_audio` + will be used for signal file(s), and + :func:`audinterface.utils.read_text` for files + with ``.json`` or ``text``extensions. Raises: ValueError: if ``resample = True``, but ``sampling_rate = None`` @@ -171,6 +179,7 @@ def __init__( num_workers: typing.Optional[int] = 1, multiprocessing: bool = False, verbose: bool = False, + read_func: typing.Callable[..., typing.Any] = None, ): if channels is not None: channels = audeer.to_list(channels) @@ -236,6 +245,14 @@ def __init__( self.win_dur = win_dur r"""Window duration.""" + # set read_audio and read_text methods + if read_func is None: + setattr(self.__class__, "read_audio", staticmethod(utils.read_audio)) + setattr(self.__class__, "read_text", staticmethod(utils.read_text)) + else: + setattr(self.__class__, "read_audio", staticmethod(read_func)) + setattr(self.__class__, "read_text", staticmethod(read_func)) + def _process_file( self, file: str, @@ -274,7 +291,7 @@ def _process_file( # Text files if ext in ["json", "txt"]: - data = utils.read_text(file, root=root) + data = self.read_text(file, root=root) y, file = self._process_data( data, idx=idx, @@ -288,7 +305,7 @@ def _process_file( # Audio/video files else: - signal, sampling_rate = utils.read_audio( + signal, sampling_rate = self.read_audio( file, start=start, end=end, @@ -489,22 +506,7 @@ def process_files( ) self.verbose = verbose - y = list(itertools.chain.from_iterable([x[0] for x in xs])) - files = list(itertools.chain.from_iterable([x[1] for x in xs])) - starts = list(itertools.chain.from_iterable([x[2] for x in xs])) - ends = list(itertools.chain.from_iterable([x[3] for x in xs])) - - if ( - len(audeer.unique(starts)) == 1 - and audeer.unique(starts)[0] is None - and len(audeer.unique(ends)) == 1 - and audeer.unique(ends)[0] is None - ): - index = audformat.filewise_index(files) - else: - index = audformat.segmented_index(files, starts, ends) - y = pd.Series(y, index) - + y = self._postprocess_xs(xs) return y def process_folder( @@ -601,10 +603,69 @@ def _process_index_wo_segment( task_description=f"Process {len(index)} segments", ) - y = list(itertools.chain.from_iterable([x[0] for x in xs])) + y = self._postprocess_xs(xs) + return y + + @staticmethod + def _postprocess_xs(xs): + """Postprocesses a list of tuples containing processed data, + files, starts, and ends, and returns a pandas Series. + + This is mainly factored into a separate method as it + is used in multiple places: + + - :meth:`process._process_index_wo_segment` + - :meth:`process._postprocess_xs` + + I find it hard to come up with less inelegance + + Parameters: + xs (list): A list of tuples containing processed data, + files, starts, and ends. + index (pd.Index): The index of the resulting pandas Series. + + Returns: + pd.Series: A pandas Series containing the postprocessed data. + """ + ys = [x[0] for x in xs] + # TODO: put into single list comprehension for all these three diagnostics + all_dict = all(map(lambda x: isinstance(x, dict), [x[0] for x in xs])) + all_iterable = all(map(lambda x: isinstance(x, Iterable), [x[0] for x in xs])) + all_text = all(map(lambda x: isinstance(x, str), [x[0] for x in xs])) + + if all_dict: + # prevent pd.Series from converting0 to list of values + keys = list(itertools.chain.from_iterable([x.keys() for x in ys])) + values = list(itertools.chain.from_iterable([x.values() for x in ys])) + y = [{x: y} for (x, y) in zip(keys, values)] + else: + # if all text, need to pack into a list in order to avoid flattening + # and the resulting dimension problems + if all_iterable and all_text: + y = list(itertools.chain.from_iterable([[x[0]] for x in xs])) + else: + y = list(itertools.chain.from_iterable([x[0] for x in xs])) + files = list(itertools.chain.from_iterable([x[1] for x in xs])) - starts = list(itertools.chain.from_iterable([x[2] for x in xs])) - ends = list(itertools.chain.from_iterable([x[3] for x in xs])) + + # avoid 'NoneType' object is not iterable error in itertools.chain + # for starts: this happens when all entries are None + try: + starts = list(itertools.chain.from_iterable([x[2] for x in xs])) + except TypeError: + pass + starts_non_iterable = [x for x in filter(None, [x[2] for x in xs])] == [] + assert starts_non_iterable, "unknown problem" + starts = [x[2] for x in xs] + + # same as for starts + try: + ends = list(itertools.chain.from_iterable([x[3] for x in xs])) + except TypeError: + pass + ends_non_iterable = [x for x in filter(None, [x[3] for x in xs])] == [] + assert ends_non_iterable, "unknown problem" + ends = [x[3] for x in xs] if ( len(audeer.unique(starts)) == 1 @@ -1147,6 +1208,8 @@ def _call_data( process_func_args = process_func_args or self.process_func_args special_args = self._special_args(idx, root, file, process_func_args) y = self.process_func(data, **special_args, **process_func_args) + # ensure non-scalar answer + y = [y] if len(y) == 1 else y return y def _special_args( diff --git a/audinterface/utils/__init__.py b/audinterface/utils/__init__.py index 0679a8b..0d9ed08 100644 --- a/audinterface/utils/__init__.py +++ b/audinterface/utils/__init__.py @@ -1,4 +1,5 @@ from audinterface.core.utils import read_audio +from audinterface.core.utils import read_text from audinterface.core.utils import signal_index from audinterface.core.utils import sliding_window from audinterface.core.utils import to_timedelta diff --git a/tests/test_process_text.py b/tests/test_process_text.py index 4238ace..79d61e9 100644 --- a/tests/test_process_text.py +++ b/tests/test_process_text.py @@ -16,6 +16,10 @@ def identity(data): return data +def data_identity(data): + return data + + def length(data): return len(data) @@ -73,6 +77,7 @@ def test_process_file( # test absolute path y = process.process_file(path) + expected_series = pd.Series( [expected_data], index=audformat.filewise_index(path), @@ -131,19 +136,28 @@ def test_process_files( paths.append(path) # test absolute paths + index = audformat.filewise_index(paths) + if num_files == 0: + index = pd.RangeIndex(0, 0, 1) + y = process.process_files(paths) expected_y = pd.Series( expected_output, - index=audformat.filewise_index(paths), + index=index, ) pd.testing.assert_series_equal(y, expected_y) # test relative paths + index = audformat.filewise_index(files) + if num_files == 0: + index = pd.RangeIndex(0, 0, 1) + y = process.process_files(files, root=root) expected_y = pd.Series( expected_output, - index=audformat.filewise_index(files), + index=index, ) + pd.testing.assert_series_equal(y, expected_y) @@ -175,7 +189,7 @@ def test_process_folder( files = [os.path.join(root, f"file{n}.{file_format}") for n in range(num_files)] for file in files: write_text_file(file, data) - y = process.process_folder(root) + y = process.process_folder(root, filetype=file_format) pd.testing.assert_series_equal( y, process.process_files(files), @@ -191,21 +205,63 @@ def test_process_folder( pd.testing.assert_series_equal(y, pd.Series(dtype=object)) +def _get_idx_type(preserve_index, segment_is_None, idx): + """Get expected index type. + + preserve_index: if ``True`` + and :attr:`audinterface.Process.segment` is ``None`` + the returned index + will be of same type + as the original one. + Otherwise it will be a segmented index + if any audio/video files are processed, + or a filewise index otherwise + """ + if preserve_index and segment_is_None: + idx_type = "segmented" if audformat.is_segmented_index(idx) else "filewise" + return idx_type + + extensions = [os.path.splitext(x)[-1] for x in idx.get_level_values(0).tolist()] + # we only use wav in fixtures so this is ok + any_media = any(["wav" in x for x in extensions]) + + if any_media: + idx_type = "segmented" + else: + idx_type = "filewise" + + return idx_type + + +def _series_generator(y, index_type: str): + for idx, value in y.items(): + if index_type == "filewise": + file = idx + yield file, value + elif index_type == "segmented": + (file, _, _) = idx + yield file, value + else: + raise ValueError("index type invalid") + + @pytest.mark.parametrize("num_workers", [1, 2, None]) @pytest.mark.parametrize("file_format", ["json", "txt"]) @pytest.mark.parametrize("multiprocessing", [False, True]) @pytest.mark.parametrize("preserve_index", [False, True]) +@pytest.mark.parametrize("process_func", [data_identity, None, identity]) def test_process_index( tmpdir, num_workers, file_format, multiprocessing, preserve_index, + process_func, ): cache_root = os.path.join(tmpdir, "cache") process = audinterface.Process( - process_func=None, + process_func=process_func, num_workers=num_workers, multiprocessing=multiprocessing, verbose=False, @@ -233,17 +289,26 @@ def test_process_index( starts=[0, 0, 1, 2], ends=[None, 1, 2, 3], ) + y = process.process_index( index, preserve_index=preserve_index, ) + if preserve_index: pd.testing.assert_index_equal(y.index, index) - for (path, _, _), value in y.items(): + + expected_idx_type = _get_idx_type(preserve_index, process.segment is None, index) + + for path, value in _series_generator(y, expected_idx_type): assert audinterface.utils.read_text(path) == data assert value == data - # Segmented index with relative paths + # for (path, _, _), value in y.items(): + # assert audinterface.utils.read_text(path) == data + # assert value == data + + # # Segmented index with relative paths index = audformat.segmented_index( [file] * 4, starts=[0, 0, 1, 2], @@ -256,25 +321,37 @@ def test_process_index( ) if preserve_index: pd.testing.assert_index_equal(y.index, index) - for (file, _, _), value in y.items(): + + for file, value in _series_generator(y, expected_idx_type): assert audinterface.utils.read_text(file, root=root) == data assert value == data + # for (file, _, _), value in y.items(): + # assert audinterface.utils.read_text(file, root=root) == data + # assert value == data + # Filewise index with absolute paths index = audformat.filewise_index(path) y = process.process_index( index, preserve_index=preserve_index, ) + if preserve_index: pd.testing.assert_index_equal(y.index, index) - for path, value in y.items(): + expected_idx_type = _get_idx_type( + preserve_index, process.segment is None, index + ) + for path, value in _series_generator(y, expected_idx_type): assert audinterface.utils.read_text(path) == data assert value == data else: + expected_idx_type = _get_idx_type( + preserve_index, process.segment is None, index + ) expected_index = audformat.filewise_index(files=list(index)) pd.testing.assert_index_equal(y.index, expected_index) - for (path, _, _), value in y.items(): + for path, value in _series_generator(y, "filewise"): assert audinterface.utils.read_text(path) == data assert value == data @@ -287,13 +364,19 @@ def test_process_index( ) if preserve_index: pd.testing.assert_index_equal(y.index, index) - for file, value in y.items(): + for file, value in _series_generator(y, "filewise"): assert audinterface.utils.read_text(file, root=root) == data assert value == data + # for file, value in y.items(): + # assert audinterface.utils.read_text(file, root=root) == data + # assert value == data else: - for (file, _, _), value in y.items(): + for file, value in _series_generator(y, "filewise"): assert audinterface.utils.read_text(file, root=root) == data assert value == data + # for (file, _, _), value in y.items(): + # assert audinterface.utils.read_text(file, root=root) == data + # assert value == data # Cache result y = process.process_index( @@ -302,10 +385,10 @@ def test_process_index( root=root, cache_root=cache_root, ) - os.remove(path) + os.remove(path) # Fails because second file does not exist - with pytest.raises(RuntimeError): + with pytest.raises(FileNotFoundError): process.process_index( index, preserve_index=preserve_index, @@ -346,7 +429,7 @@ def test_process_data( process_func_args=process_func_args, verbose=False, ) - x = process.process_signal(data, file=file) + x = process.process_data(data, file=file) if file is None: y = pd.Series([expected_signal]) @@ -493,4 +576,4 @@ def test_read_data(tmpdir, data): file = audeer.path(tmpdir, "media.txt") with open(file, "w") as fp: fp.write(data) - assert audinterface.utils.read_data(file) == data + assert audinterface.utils.read_text(file) == data