diff --git a/projects/.gitignore b/projects/.gitignore new file mode 100644 index 0000000..f10c0ff --- /dev/null +++ b/projects/.gitignore @@ -0,0 +1,52 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +bin/ +build/ +develop-eggs/ +dist/ +eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +.tox/ +.coverage +.cache +nosetests.xml +coverage.xml + +# Translations +*.mo + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# Rope +.ropeproject + +# Django stuff: +*.log +*.pot + +# Sphinx documentation +docs/_build/ + +.venv \ No newline at end of file diff --git a/projects/vad/files/SI499.wav b/projects/vad/files/SI499.wav new file mode 100644 index 0000000..6538a81 Binary files /dev/null and b/projects/vad/files/SI499.wav differ diff --git a/projects/vad/files/checkpoint_19 b/projects/vad/files/checkpoint_19 new file mode 100644 index 0000000..81ebb0b Binary files /dev/null and b/projects/vad/files/checkpoint_19 differ diff --git a/projects/vad/inference.py b/projects/vad/inference.py new file mode 100644 index 0000000..9cbd8d6 --- /dev/null +++ b/projects/vad/inference.py @@ -0,0 +1,73 @@ +import torch +import argparse +import numpy as np + +from model.models import DNN +from utils.audio_processing import signal_to_melspec +from utils.utils import load_wav, get_config +from utils.dataloader import process_neighbor + + +def load_checkpoint(model, checkpoint_path): + checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') + model.load_state_dict(checkpoint_dict['state_dict_model']) + return model + +def get_sample_all_positions(config, mel_file): + neighbors = process_neighbor(config["p"], config["f"]) + + if isinstance(mel_file, str): + mel = np.load(mel_file) + else: + mel = mel_file + + mel = np.pad(mel, ((0, 0), (config["p"], config["f"]))) + C, L = mel.shape + + all_positions = np.arange(config["p"], L-config["f"]) + mel_neighbors = [[mel[:, position+n] for n in neighbors] for position in all_positions] + mel_neighbors = np.asarray(mel_neighbors).reshape(-1, C*len(neighbors)) + mel = torch.from_numpy(mel_neighbors) + + return mel.to(config["device"]) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-c', '--yaml-dir', type=str, default='./config/base.yaml', + help="YAML file for config") + parser.add_argument('-cp', '--check-point', type=str, default='./files/checkpoint_19', + help="Checkpoint file for model") + parser.add_argument('-i', '--input-audio', type=str, default='./files/SI499.wav', + help="Audio file for input") + args = parser.parse_args() + + model_config = get_config(args.yaml_dir, "model") + preprocess_config = get_config(args.yaml_dir, "preprocess") + load_checkpoint_file = args.check_point + input_audio_file = args.input_audio + + model = DNN(model_config['in_features'], model_config['hidden_features_list'], model_config['dropout']).to(model_config['device']) + + if load_checkpoint_file is not None: + model = load_checkpoint(model, load_checkpoint_file) + + + signal_wav = load_wav(input_audio_file, target_sr=preprocess_config['sr_model']) + mel = signal_to_melspec(signal_wav, + sr=preprocess_config['sr_model'], + n_fft=preprocess_config['n_fft'], + hop_length=preprocess_config['hop_length'], + win_length=preprocess_config['win_length'], + window=preprocess_config['fn_window'], + n_mel_channels=preprocess_config['n_mel_channels'], + mel_fmin=preprocess_config['mel_fmin'], + mel_fmax=preprocess_config['mel_fmax']) + + signal_wav = signal_wav[:mel.shape[-1]*preprocess_config['hop_length']] + mel = mel[:signal_wav.shape[-1]//preprocess_config['hop_length']] + + mels = get_sample_all_positions(model_config, mel) + scores, preds = model.infer(mels) + TP = TN = FP = FN = 0 + for pred in zip(preds.reshape(-1)): + print(pred) \ No newline at end of file