diff --git a/week06/spotter_home_task_abdrakhimov.ipynb b/week06/spotter_home_task_abdrakhimov.ipynb new file mode 100755 index 0000000..bba1c18 --- /dev/null +++ b/week06/spotter_home_task_abdrakhimov.ipynb @@ -0,0 +1,860 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Копия блокнота \"Untitled9.ipynb\"", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "DLqUa76aAHIg", + "outputId": "e311b2b4-d36a-4518-c7e1-650e8b04ccf6" + }, + "source": [ + "!apt install sox" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Reading package lists... Done\n", + "Building dependency tree \n", + "Reading state information... Done\n", + "The following additional packages will be installed:\n", + " libmagic-mgc libmagic1 libopencore-amrnb0 libopencore-amrwb0 libsox-fmt-alsa\n", + " libsox-fmt-base libsox3\n", + "Suggested packages:\n", + " file libsox-fmt-all\n", + "The following NEW packages will be installed:\n", + " libmagic-mgc libmagic1 libopencore-amrnb0 libopencore-amrwb0 libsox-fmt-alsa\n", + " libsox-fmt-base libsox3 sox\n", + "0 upgraded, 8 newly installed, 0 to remove and 37 not upgraded.\n", + "Need to get 760 kB of archives.\n", + "After this operation, 6,717 kB of additional disk space will be used.\n", + "Get:1 http://archive.ubuntu.com/ubuntu bionic/universe amd64 libopencore-amrnb0 amd64 0.1.3-2.1 [92.0 kB]\n", + "Get:2 http://archive.ubuntu.com/ubuntu bionic/universe amd64 libopencore-amrwb0 amd64 0.1.3-2.1 [45.8 kB]\n", + "Get:3 http://archive.ubuntu.com/ubuntu bionic-updates/main amd64 libmagic-mgc amd64 1:5.32-2ubuntu0.4 [184 kB]\n", + "Get:4 http://archive.ubuntu.com/ubuntu bionic-updates/main amd64 libmagic1 amd64 1:5.32-2ubuntu0.4 [68.6 kB]\n", + "Get:5 http://archive.ubuntu.com/ubuntu bionic-updates/universe amd64 libsox3 amd64 14.4.2-3ubuntu0.18.04.1 [226 kB]\n", + "Get:6 http://archive.ubuntu.com/ubuntu bionic-updates/universe amd64 libsox-fmt-alsa amd64 14.4.2-3ubuntu0.18.04.1 [10.6 kB]\n", + "Get:7 http://archive.ubuntu.com/ubuntu bionic-updates/universe amd64 libsox-fmt-base amd64 14.4.2-3ubuntu0.18.04.1 [32.1 kB]\n", + "Get:8 http://archive.ubuntu.com/ubuntu bionic-updates/universe amd64 sox amd64 14.4.2-3ubuntu0.18.04.1 [101 kB]\n", + "Fetched 760 kB in 2s (321 kB/s)\n", + "Selecting previously unselected package libopencore-amrnb0:amd64.\n", + "(Reading database ... 155047 files and directories currently installed.)\n", + "Preparing to unpack .../0-libopencore-amrnb0_0.1.3-2.1_amd64.deb ...\n", + "Unpacking libopencore-amrnb0:amd64 (0.1.3-2.1) ...\n", + "Selecting previously unselected package libopencore-amrwb0:amd64.\n", + "Preparing to unpack .../1-libopencore-amrwb0_0.1.3-2.1_amd64.deb ...\n", + "Unpacking libopencore-amrwb0:amd64 (0.1.3-2.1) ...\n", + "Selecting previously unselected package libmagic-mgc.\n", + "Preparing to unpack .../2-libmagic-mgc_1%3a5.32-2ubuntu0.4_amd64.deb ...\n", + "Unpacking libmagic-mgc (1:5.32-2ubuntu0.4) ...\n", + "Selecting previously unselected package libmagic1:amd64.\n", + "Preparing to unpack .../3-libmagic1_1%3a5.32-2ubuntu0.4_amd64.deb ...\n", + "Unpacking libmagic1:amd64 (1:5.32-2ubuntu0.4) ...\n", + "Selecting previously unselected package libsox3:amd64.\n", + "Preparing to unpack .../4-libsox3_14.4.2-3ubuntu0.18.04.1_amd64.deb ...\n", + "Unpacking libsox3:amd64 (14.4.2-3ubuntu0.18.04.1) ...\n", + "Selecting previously unselected package libsox-fmt-alsa:amd64.\n", + "Preparing to unpack .../5-libsox-fmt-alsa_14.4.2-3ubuntu0.18.04.1_amd64.deb ...\n", + "Unpacking libsox-fmt-alsa:amd64 (14.4.2-3ubuntu0.18.04.1) ...\n", + "Selecting previously unselected package libsox-fmt-base:amd64.\n", + "Preparing to unpack .../6-libsox-fmt-base_14.4.2-3ubuntu0.18.04.1_amd64.deb ...\n", + "Unpacking libsox-fmt-base:amd64 (14.4.2-3ubuntu0.18.04.1) ...\n", + "Selecting previously unselected package sox.\n", + "Preparing to unpack .../7-sox_14.4.2-3ubuntu0.18.04.1_amd64.deb ...\n", + "Unpacking sox (14.4.2-3ubuntu0.18.04.1) ...\n", + "Setting up libmagic-mgc (1:5.32-2ubuntu0.4) ...\n", + "Setting up libmagic1:amd64 (1:5.32-2ubuntu0.4) ...\n", + "Setting up libopencore-amrnb0:amd64 (0.1.3-2.1) ...\n", + "Setting up libopencore-amrwb0:amd64 (0.1.3-2.1) ...\n", + "Setting up libsox3:amd64 (14.4.2-3ubuntu0.18.04.1) ...\n", + "Setting up libsox-fmt-base:amd64 (14.4.2-3ubuntu0.18.04.1) ...\n", + "Setting up libsox-fmt-alsa:amd64 (14.4.2-3ubuntu0.18.04.1) ...\n", + "Setting up sox (14.4.2-3ubuntu0.18.04.1) ...\n", + "Processing triggers for libc-bin (2.27-3ubuntu1.3) ...\n", + "/sbin/ldconfig.real: /usr/local/lib/python3.7/dist-packages/ideep4py/lib/libmkldnn.so.0 is not a symbolic link\n", + "\n", + "Processing triggers for man-db (2.8.3-2ubuntu0.1) ...\n", + "Processing triggers for mime-support (3.60ubuntu1) ...\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eX_UvnL9FOB2" + }, + "source": [ + "### Baseline commands recognition (2-5 points)\n", + "\n", + "We're now going to train a classifier to recognize voice. More specifically, we'll use the [Speech Commands Dataset] that contains around 30 different words with a few thousand voice records each." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "FvHkw2rfY9k7", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "c761e2b2-ce65-46ba-f702-af37cb996294" + }, + "source": [ + "import os\n", + "from IPython.display import display, Audio\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "import numpy as np\n", + "import librosa\n", + "import torch\n", + "from torch.utils.data import TensorDataset, DataLoader\n", + "\n", + "datadir = \"speech_commands\"\n", + "\n", + "!wget http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz -O speech_commands_v0.01.tar.gz\n", + "# alternative url: https://www.dropbox.com/s/j95n278g48bcbta/speech_commands_v0.01.tar.gz?dl=1\n", + "!mkdir {datadir} && tar -C {datadir} -xvzf speech_commands_v0.01.tar.gz 1> log\n", + "\n", + "samples_by_target = {\n", + " cls: [os.path.join(datadir, cls, name) for name in os.listdir(\"./speech_commands/{}\".format(cls))]\n", + " for cls in os.listdir(datadir)\n", + " if os.path.isdir(os.path.join(datadir, cls))\n", + "}\n", + "print('Classes:', ', '.join(sorted(samples_by_target.keys())[1:]))" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "--2021-10-26 09:50:12-- http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz\n", + "Resolving download.tensorflow.org (download.tensorflow.org)... 142.250.157.128, 2404:6800:4008:c07::80\n", + "Connecting to download.tensorflow.org (download.tensorflow.org)|142.250.157.128|:80... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 1489096277 (1.4G) [application/gzip]\n", + "Saving to: ‘speech_commands_v0.01.tar.gz’\n", + "\n", + "speech_commands_v0. 100%[===================>] 1.39G 66.9MB/s in 25s \n", + "\n", + "2021-10-26 09:50:38 (55.8 MB/s) - ‘speech_commands_v0.01.tar.gz’ saved [1489096277/1489096277]\n", + "\n", + "Classes: bed, bird, cat, dog, down, eight, five, four, go, happy, house, left, marvin, nine, no, off, on, one, right, seven, sheila, six, stop, three, tree, two, up, wow, yes, zero\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "ME4cVShQ916w", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "e60311ea-a518-4c35-a93c-4d689837d251" + }, + "source": [ + "!sox --info speech_commands/bed/00176480_nohash_0.wav" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "Input File : 'speech_commands/bed/00176480_nohash_0.wav'\n", + "Channels : 1\n", + "Sample Rate : 16000\n", + "Precision : 16-bit\n", + "Duration : 00:00:01.00 = 16000 samples ~ 75 CDDA sectors\n", + "File Size : 32.0k\n", + "Bit Rate : 256k\n", + "Sample Encoding: 16-bit Signed Integer PCM\n", + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cvF5l-PCyd8z", + "outputId": "aec0e205-2386-4343-f6d1-b23221b58d3c" + }, + "source": [ + "from sklearn.model_selection import train_test_split\n", + "from itertools import chain\n", + "from tqdm import tqdm\n", + "import joblib as jl\n", + "\n", + "classes = (\"left\", \"right\", \"up\", \"down\", \"stop\")\n", + "\n", + "def preprocess_sample(filepath, max_length=150):\n", + " amplitudes, sr = librosa.core.load(filepath)\n", + " spectrogram = librosa.feature.melspectrogram(amplitudes, sr=sr)[:, :max_length]\n", + " spectrogram = np.pad(spectrogram, [[0, 0], [0, max(0, max_length - spectrogram.shape[1])]], mode='constant')\n", + " target = classes.index(filepath.split(os.sep)[-2])\n", + " return np.float32(spectrogram), np.int64(target)\n", + "\n", + "all_files = chain(*(samples_by_target[cls] for cls in classes))\n", + "spectrograms_and_targets = jl.Parallel(n_jobs=-1)(tqdm(list(map(jl.delayed(preprocess_sample), all_files))))\n", + "X, y = map(np.stack, zip(*spectrograms_and_targets))\n", + "X = X.transpose([0, 2, 1]) # to [batch, time, channels]\n", + "\n", + "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 11834/11834 [07:47<00:00, 25.32it/s]\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "n_G6hip4A8Uz" + }, + "source": [ + "X_train = X_train[:, None, :, :]\n", + "X_test = X_test[:, None, :, :]" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "7Ol6sywTG_Y9" + }, + "source": [ + "device = 'cuda'\n", + "batch_size = 16\n", + "\n", + "tensor_x = torch.Tensor(X_train)\n", + "tensor_y = torch.LongTensor(y_train)\n", + "\n", + "train_dataset = TensorDataset(tensor_x, tensor_y)\n", + "\n", + "tensor_x = torch.Tensor(X_test) # transform to torch tensor\n", + "tensor_y = torch.LongTensor(y_test)\n", + "\n", + "test_dataset = TensorDataset(tensor_x, tensor_y)\n", + "\n", + "\n", + "trainloader = DataLoader(train_dataset, batch_size=batch_size,\n", + " shuffle=True, num_workers=2)\n", + "testloader = DataLoader(test_dataset, batch_size=batch_size,\n", + " shuffle=False, num_workers=2)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "tI9Iyis_5gIU", + "outputId": "d772d55c-c9e0-4e13-afb5-70daa86d4d0b" + }, + "source": [ + "tensor_x[0].size()" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "torch.Size([1, 150, 128])" + ] + }, + "metadata": {}, + "execution_count": 8 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "-qr8t6wCF8vT" + }, + "source": [ + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "\n", + "class Net(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " # TODO: define your layers here\n", + " self.conv1 = nn.Conv2d(1, 64, 3)\n", + " self.bn1 = nn.BatchNorm2d(64)\n", + " self.dropout1 = nn.Dropout(0.2)\n", + " self.conv2 = nn.Conv2d(64, 128, 5)\n", + " self.bn2 = nn.BatchNorm2d(128)\n", + " self.pool1 = nn.MaxPool2d(2)\n", + "\n", + " self.conv3 = nn.Conv2d(128, 256, 3)\n", + " self.bn3 = nn.BatchNorm2d(256)\n", + " self.pool2 = nn.MaxPool2d(4)\n", + "\n", + " self.conv4 = nn.Conv2d(256, 256, 5)\n", + " self.bn4 = nn.BatchNorm2d(256)\n", + " self.pool3 = nn.MaxPool2d(6)\n", + "\n", + "\n", + " self.flatten = nn.Flatten()\n", + " self.dense1 = nn.Linear(512, 256)\n", + " self.dense2 = nn.Linear(256, 5)\n", + "\n", + "\n", + " def forward(self, x):\n", + " # TODO: apply your layers here\n", + " x = self.conv1(x)\n", + " x = self.bn1(x)\n", + " x = F.relu(x)\n", + " x = self.dropout1(x)\n", + "\n", + " x = self.conv2(x)\n", + " x = self.bn2(x)\n", + " x = F.relu(x)\n", + " x = self.pool1(x)\n", + "\n", + " x = self.conv3(x)\n", + " x = self.bn3(x)\n", + " x = F.relu(x)\n", + " x = self.pool2(x)\n", + "\n", + " x = self.conv4(x)\n", + " x = self.bn4(x)\n", + " x = F.relu(x)\n", + " x = self.pool3(x)\n", + "\n", + " x = self.flatten(x)\n", + " x = self.dense1(x)\n", + " x = F.relu(x)\n", + " x = self.dense2(x)\n", + " x = F.softmax(x)\n", + " return x\n", + "\n", + "\n", + "net = Net().to(device)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "UU5L1G05M85R" + }, + "source": [ + "class VGG11(nn.Module):\n", + " def __init__(self, in_channels, num_classes=1000):\n", + " super(VGG11, self).__init__()\n", + " self.in_channels = in_channels\n", + " self.num_classes = num_classes\n", + " # convolutional layers \n", + " self.conv_layers = nn.Sequential(\n", + " nn.Conv2d(self.in_channels, 64, kernel_size=3, padding=1),\n", + " nn.ReLU(),\n", + " nn.MaxPool2d(kernel_size=2, stride=2),\n", + " nn.Conv2d(64, 128, kernel_size=3, padding=1),\n", + " nn.ReLU(),\n", + " nn.MaxPool2d(kernel_size=2, stride=2),\n", + " nn.Conv2d(128, 256, kernel_size=3, padding=1),\n", + " nn.ReLU(),\n", + " nn.Conv2d(256, 256, kernel_size=3, padding=1),\n", + " nn.ReLU(),\n", + " nn.MaxPool2d(kernel_size=2, stride=2),\n", + " nn.Conv2d(256, 512, kernel_size=3, padding=1),\n", + " nn.ReLU(),\n", + " nn.Conv2d(512, 512, kernel_size=3, padding=1),\n", + " nn.ReLU(),\n", + " nn.MaxPool2d(kernel_size=2, stride=2),\n", + " nn.Conv2d(512, 512, kernel_size=3, padding=1),\n", + " nn.ReLU(),\n", + " nn.Conv2d(512, 512, kernel_size=3, padding=1),\n", + " nn.ReLU(),\n", + " nn.MaxPool2d(kernel_size=2, stride=2)\n", + " )\n", + " # fully connected linear layers\n", + " self.linear_layers = nn.Sequential(\n", + " nn.Linear(in_features=8192, out_features=4096),\n", + " nn.ReLU(),\n", + " nn.Dropout2d(0.5),\n", + " nn.Linear(in_features=4096, out_features=4096),\n", + " nn.ReLU(),\n", + " nn.Dropout2d(0.5),\n", + " nn.Linear(in_features=4096, out_features=self.num_classes)\n", + " )\n", + " def forward(self, x):\n", + " x = self.conv_layers(x)\n", + " \n", + " # flatten to prepare for the fully connected layers\n", + " x = x.view(x.size(0), -1)\n", + " # print(x.size())\n", + " x = self.linear_layers(x)\n", + " return x" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "6p3Yq8vSM9Bw" + }, + "source": [ + "vgg11 = VGG11(in_channels=1, num_classes=5).to(device)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "dUuwnaV6NSjM" + }, + "source": [ + "import torch.optim as optim\n", + "\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.Adam(vgg11.parameters(), lr=0.001)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vHDI51TXNLN7", + "outputId": "2bd51460-cf8b-4872-e2b3-b413ef5ca511" + }, + "source": [ + "from tqdm import tqdm\n", + "for epoch in tqdm(range(500)): # loop over the dataset multiple times\n", + "\n", + " running_loss = 0.0\n", + " for i, data in enumerate(trainloader, 0):\n", + " # get the inputs; data is a list of [inputs, labels]\n", + " inputs, labels = data\n", + " inputs = inputs.to(device)\n", + " labels = labels.to(device)\n", + " # zero the parameter gradients\n", + " optimizer.zero_grad()\n", + "\n", + " # forward + backward + optimize\n", + " outputs = vgg11(inputs)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # print statistics\n", + " running_loss += loss.item()\n", + " if i % 2000 == 1999: # print every 2000 mini-batches\n", + " print('[%d, %5d] loss: %.3f' %\n", + " (epoch + 1, i + 1, running_loss / 2000))\n", + " running_loss = 0.0\n", + "\n", + "print('Finished Training')" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + " 37%|███▋ | 185/500 [4:33:59<7:42:59, 88.19s/it]" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "FbDU9xhQXaBP", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 244 + }, + "outputId": "1876114c-9ffb-4b56-e256-a4be9bfe9603" + }, + "source": [ + "correct_pred = {classname: 0 for classname in classes}\n", + "total_pred = {classname: 0 for classname in classes}\n", + "\n", + "with torch.no_grad():\n", + " for i, data in enumerate(testloader, 0):\n", + " images, labels = data \n", + "\n", + " images = images.to(device)\n", + " labels = labels.to(device)\n", + "\n", + " outputs = vgg11(images) \n", + " _, predictions = torch.max(outputs, 1)\n", + " # collect the correct predictions for each class\n", + " for label, prediction in zip(labels, predictions):\n", + " if label == prediction:\n", + " correct_pred[classes[label]] += 1\n", + " total_pred[classes[label]] += 1\n", + "\n", + " \n", + "# print accuracy for each class\n", + "for classname, correct_count in correct_pred.items():\n", + " accuracy = 100 * float(correct_count) / total_pred[classname]\n", + " print(\"Accuracy for class {:5s} is: {:.1f} %\".format(classname, \n", + " accuracy))" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "error", + "ename": "NameError", + "evalue": "ignored", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mcorrect_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mclassname\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mclassname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mclasses\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mtotal_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mclassname\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mclassname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mclasses\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtestloader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'classes' is not defined" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "3jpYUcwWXaDz" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "0i8cJvJsXaGZ" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "qdJhetQpXaI4" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "Qr2pqtnuNLnn" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "U7NIhlaqNLqA" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "MCZ7MkvsF9gs" + }, + "source": [ + "import torch.optim as optim\n", + "\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "xR1uxQ-GGGLr", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 431 + }, + "outputId": "223a2835-9677-4473-f4a0-2767ce0272ea" + }, + "source": [ + "from tqdm import tqdm\n", + "for epoch in tqdm(range(300)): # loop over the dataset multiple times\n", + "\n", + " running_loss = 0.0\n", + " for i, data in enumerate(trainloader, 0):\n", + " # get the inputs; data is a list of [inputs, labels]\n", + " inputs, labels = data\n", + " inputs = inputs.to(device)\n", + " labels = labels.to(device)\n", + " # zero the parameter gradients\n", + " optimizer.zero_grad()\n", + "\n", + " # forward + backward + optimize\n", + " outputs = net(inputs)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # print statistics\n", + " running_loss += loss.item()\n", + " if i % 2000 == 1999: # print every 2000 mini-batches\n", + " print('[%d, %5d] loss: %.3f' %\n", + " (epoch + 1, i + 1, running_loss / 2000))\n", + " running_loss = 0.0\n", + "\n", + "print('Finished Training')" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + " 0%| | 0/300 [00:00\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 17\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 253\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 254\u001b[0m inputs=inputs)\n\u001b[0;32m--> 255\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 256\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 257\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 147\u001b[0m Variable._execution_engine.run_backward(\n\u001b[1;32m 148\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 149\u001b[0;31m allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag\n\u001b[0m\u001b[1;32m 150\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 151\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vfYEpn9zIft6", + "outputId": "c3a3f951-d1a4-4c08-a0ab-d928d4d6e085" + }, + "source": [ + "net" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "Net(\n", + " (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))\n", + " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (dropout1): Dropout(p=0.2, inplace=False)\n", + " (conv2): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1))\n", + " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (conv3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))\n", + " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (pool2): MaxPool2d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)\n", + " (conv4): Conv2d(256, 256, kernel_size=(5, 5), stride=(1, 1))\n", + " (bn4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (pool3): MaxPool2d(kernel_size=6, stride=6, padding=0, dilation=1, ceil_mode=False)\n", + " (flatten): Flatten(start_dim=1, end_dim=-1)\n", + " (dense1): Linear(in_features=512, out_features=256, bias=True)\n", + " (dense2): Linear(in_features=256, out_features=5, bias=True)\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 21 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "_I3VxkteI67a" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "taaoVcBhIKiq", + "outputId": "bbdc837e-c9a4-4287-cae4-66419d49f4dc" + }, + "source": [ + "correct_pred = {classname: 0 for classname in classes}\n", + "total_pred = {classname: 0 for classname in classes}\n", + "\n", + "with torch.no_grad():\n", + " for i, data in enumerate(testloader, 0):\n", + " images, labels = data \n", + "\n", + " images = images.to(device)\n", + " labels = labels.to(device)\n", + "\n", + " outputs = net(images) \n", + " _, predictions = torch.max(outputs, 1)\n", + " # collect the correct predictions for each class\n", + " for label, prediction in zip(labels, predictions):\n", + " if label == prediction:\n", + " correct_pred[classes[label]] += 1\n", + " total_pred[classes[label]] += 1\n", + "\n", + " \n", + "# print accuracy for each class\n", + "for classname, correct_count in correct_pred.items():\n", + " accuracy = 100 * float(correct_count) / total_pred[classname]\n", + " print(\"Accuracy for class {:5s} is: {:.1f} %\".format(classname, \n", + " accuracy))" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:56: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Accuracy for class left is: 40.0 %\n", + "Accuracy for class right is: 2.1 %\n", + "Accuracy for class up is: 30.8 %\n", + "Accuracy for class down is: 45.6 %\n", + "Accuracy for class stop is: 66.6 %\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ojEuXjx5DlDW" + }, + "source": [ + "Train a model: finally, lets' build and train a classifier neural network. You can use any library you like. If in doubt, consult the model & training tips below." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hwgnOrZy1E8p" + }, + "source": [ + "__Training tips:__ here's what you can try:\n", + "* __Layers:__ 1d or 2d convolutions, perhaps with some batch normalization in between;\n", + "* __Architecture:__ VGG-like, residual, highway, densely-connected, MatchboxNet, Dilated convs - you name it :)\n", + "* __Batch size matters:__ smaller batches usually train slower but better. Try to find the one that suits you best.\n", + "* __Data augmentation:__ add background noise, faster/slower, change pitch;\n", + "* __Average checkpoints:__ you can make model more stable with [this simple technique (arxiv)](https://arxiv.org/abs/1803.05407)\n", + "* __For full scale stage:__ make sure you're not losing too much data due to max_length in the pre-processing stage!\n", + "\n", + "These are just recommendations. As long as your model works, you're not required to follow them." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Fvf8UCsPDvj2" + }, + "source": [ + "### Full scale commands recognition (3+ points)\n", + "\n", + "Your final task is to train a full-scale voice command spotter and apply it to a video:\n", + "1. Build the dataset with all 30+ classes (directions, digits, names, etc.)\n", + " * __Optional:__ include a special \"noise\" class that contains random unrelated sounds\n", + " * You can download youtube videos with [`youtube-dl`](https://ytdl-org.github.io/youtube-dl/index.html) library.\n", + "2. Train a model on this full dataset. Kudos for tuning its accuracy :)\n", + "3. Apply it to a audio/video of your choice to spot the occurences of each keyword\n", + " * Here's one [video about primes](https://www.youtube.com/watch?v=EK32jo7i5LQ) that you can try. It should be full of numbers :)\n", + " * There are multiple ways you can analyze the performance of your network, e.g. plot probabilities predicted for every time-step. Chances are you'll discover something useful about how to improve your model :)\n", + "\n", + "\n", + "Please briefly describe what you did in a short informal report." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "16Ux38uFD2g-" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git "a/week10/hometask/\320\232\320\276\320\277\320\270\321\217_\320\261\320\273\320\276\320\272\320\275\320\276\321\202\320\260__\320\237\321\201\320\265\320\262\320\264\320\276\321\200\320\260\320\267\320\274\320\265\321\202\320\272\320\260_\320\241\320\265\320\274\320\270\320\275\320\260\321\200_\320\270_\320\224\320\227__.ipynb" "b/week10/hometask/\320\232\320\276\320\277\320\270\321\217_\320\261\320\273\320\276\320\272\320\275\320\276\321\202\320\260__\320\237\321\201\320\265\320\262\320\264\320\276\321\200\320\260\320\267\320\274\320\265\321\202\320\272\320\260_\320\241\320\265\320\274\320\270\320\275\320\260\321\200_\320\270_\320\224\320\227__.ipynb" new file mode 100644 index 0000000..b5aeeb2 --- /dev/null +++ "b/week10/hometask/\320\232\320\276\320\277\320\270\321\217_\320\261\320\273\320\276\320\272\320\275\320\276\321\202\320\260__\320\237\321\201\320\265\320\262\320\264\320\276\321\200\320\260\320\267\320\274\320\265\321\202\320\272\320\260_\320\241\320\265\320\274\320\270\320\275\320\260\321\200_\320\270_\320\224\320\227__.ipynb" @@ -0,0 +1,2347 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "Копия блокнота \"Псевдоразметка. Семинар и ДЗ.\"", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "5d37c505ee9445b9be75c0cd53fa065c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_727d85fbd65d47ae8ea735c6088d922f", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_68867cee67014c349b1fcba806d00392", + "IPY_MODEL_28a98bfea00e4f28bfb88ba12f3579f3", + "IPY_MODEL_eb5ebd75e88a4bd79eedfa4b9e1dfcab" + ] + } + }, + "727d85fbd65d47ae8ea735c6088d922f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "68867cee67014c349b1fcba806d00392": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_7294333aa47f4893a095b7f535ef79eb", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": "", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_52128d07e5d94268a5ce4adbae32d085" + } + }, + "28a98bfea00e4f28bfb88ba12f3579f3": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_83ae185b32c141449d620fad690f3153", + "_dom_classes": [], + "description": "", + "_model_name": "FloatProgressModel", + "bar_style": "success", + "max": 9912422, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 9912422, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_35aa4c00872b4cde83cf30c188b0aaf8" + } + }, + "eb5ebd75e88a4bd79eedfa4b9e1dfcab": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_90a835f4ea274fe19a3ed1cef72072d1", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 9913344/? [00:00<00:00, 42507098.24it/s]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_61fbd0b3aba3479d89676184971c5b72" + } + }, + "7294333aa47f4893a095b7f535ef79eb": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "52128d07e5d94268a5ce4adbae32d085": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "83ae185b32c141449d620fad690f3153": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "35aa4c00872b4cde83cf30c188b0aaf8": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "90a835f4ea274fe19a3ed1cef72072d1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "61fbd0b3aba3479d89676184971c5b72": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "5755c92608c641a8804c2141bc2f1277": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_d14ddb5d55c0440084f4af37604ebe91", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_3035584bcd524c3884a45f05c5ca50c1", + "IPY_MODEL_4bf8d9eed12c4cf0ad3ddcf05c6f8b40", + "IPY_MODEL_835ee1d357fc4faca2848170687d2385" + ] + } + }, + "d14ddb5d55c0440084f4af37604ebe91": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "3035584bcd524c3884a45f05c5ca50c1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_b06fd0778378484f953beff77dfb667b", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": "", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_53515cba82b348ca8e1b3ba408945b1c" + } + }, + "4bf8d9eed12c4cf0ad3ddcf05c6f8b40": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_29ff014885c94d89b70abd94efac8420", + "_dom_classes": [], + "description": "", + "_model_name": "FloatProgressModel", + "bar_style": "success", + "max": 28881, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 28881, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_526fad5b957f4d5bbfdb2e1fd385a075" + } + }, + "835ee1d357fc4faca2848170687d2385": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_e92b7c31ddcc4d348aec2000e00cb9f9", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 29696/? [00:00<00:00, 713948.64it/s]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_adb2e39ae6f74a40b5c67641a330181c" + } + }, + "b06fd0778378484f953beff77dfb667b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "53515cba82b348ca8e1b3ba408945b1c": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "29ff014885c94d89b70abd94efac8420": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "526fad5b957f4d5bbfdb2e1fd385a075": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "e92b7c31ddcc4d348aec2000e00cb9f9": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "adb2e39ae6f74a40b5c67641a330181c": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "708e3b49100449408db3ee4c061a6b3e": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_d9e5d5abdb834de09d471842a64e50d8", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_d657ac78dccf459abfd42607cfc7fa81", + "IPY_MODEL_2241e58cb330409797507a7682d0bee5", + "IPY_MODEL_e5283a4c4eba45639c87850bf553f875" + ] + } + }, + "d9e5d5abdb834de09d471842a64e50d8": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "d657ac78dccf459abfd42607cfc7fa81": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_53fdcfade3d24928a7da41e9917ea368", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": "", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_eac31c264856434aa4c5588bb2b5020f" + } + }, + "2241e58cb330409797507a7682d0bee5": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_b444011e318a404cb3b8ece8c2a7f157", + "_dom_classes": [], + "description": "", + "_model_name": "FloatProgressModel", + "bar_style": "success", + "max": 1648877, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 1648877, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_65e66ea8bfb6474796f65621d85ab031" + } + }, + "e5283a4c4eba45639c87850bf553f875": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_54f6179d5b784ea3b82a5e910b66eca4", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 1649664/? [00:00<00:00, 17955604.81it/s]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_1d99a271cd544ebdade737eff6d04dd9" + } + }, + "53fdcfade3d24928a7da41e9917ea368": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "eac31c264856434aa4c5588bb2b5020f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "b444011e318a404cb3b8ece8c2a7f157": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "65e66ea8bfb6474796f65621d85ab031": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "54f6179d5b784ea3b82a5e910b66eca4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "1d99a271cd544ebdade737eff6d04dd9": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "743b13d278b7460192361e986b1bacf2": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_4286828ed6864f8b835f6691824b5281", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_6d39e9ea791b4854992d0fe40595f2f3", + "IPY_MODEL_7a2e413cfd6e4349913e43c6649c57a2", + "IPY_MODEL_edb8a71bc807448282f4ecce56278f2f" + ] + } + }, + "4286828ed6864f8b835f6691824b5281": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "6d39e9ea791b4854992d0fe40595f2f3": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_d337d48436074daf90d9e716d2d765a4", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": "", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_d7264da7f1de4fe4b2f8a9f9f649742c" + } + }, + "7a2e413cfd6e4349913e43c6649c57a2": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_b2832d02a5c14b6db210371a62cabdf3", + "_dom_classes": [], + "description": "", + "_model_name": "FloatProgressModel", + "bar_style": "success", + "max": 4542, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 4542, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_0557be959c5a434cbb43b8e9fec73f80" + } + }, + "edb8a71bc807448282f4ecce56278f2f": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_9ecc8de1054e4d6e9ff37839b255835a", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 5120/? [00:00<00:00, 138688.71it/s]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_bfda249d54734c84b7076d49329beb1b" + } + }, + "d337d48436074daf90d9e716d2d765a4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "d7264da7f1de4fe4b2f8a9f9f649742c": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "b2832d02a5c14b6db210371a62cabdf3": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "0557be959c5a434cbb43b8e9fec73f80": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "9ecc8de1054e4d6e9ff37839b255835a": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "bfda249d54734c84b7076d49329beb1b": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + } + } + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "N1-8AS8I5x-H" + }, + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VHDmcbfN5rgg" + }, + "source": [ + "**Использование псевдоразметки. Семинар.**" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "sxnb33uad1d5" + }, + "source": [ + "import argparse\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "from torchvision import datasets, transforms\n", + "from torch.autograd import Variable\n", + "import random\n", + "import numpy as np" + ], + "execution_count": 1, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "A1YpI_9a5rgo" + }, + "source": [ + "torch.manual_seed(123)\n", + "torch.cuda.manual_seed(123)\n", + "np.random.seed(123)\n", + "random.seed(123)\n", + "torch.backends.cudnn.deterministic = True" + ], + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "auP4TzJO5rgo" + }, + "source": [ + "Начнем с загрузки датасета. Речевые данные (и модели, обучаемые на них) очень тяжелые, поэтому мы обойдемся чем-нибудь попроще." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "wKvD1Q8gdWkL", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 435, + "referenced_widgets": [ + "5d37c505ee9445b9be75c0cd53fa065c", + "727d85fbd65d47ae8ea735c6088d922f", + "68867cee67014c349b1fcba806d00392", + "28a98bfea00e4f28bfb88ba12f3579f3", + "eb5ebd75e88a4bd79eedfa4b9e1dfcab", + "7294333aa47f4893a095b7f535ef79eb", + "52128d07e5d94268a5ce4adbae32d085", + "83ae185b32c141449d620fad690f3153", + "35aa4c00872b4cde83cf30c188b0aaf8", + "90a835f4ea274fe19a3ed1cef72072d1", + "61fbd0b3aba3479d89676184971c5b72", + "5755c92608c641a8804c2141bc2f1277", + "d14ddb5d55c0440084f4af37604ebe91", + "3035584bcd524c3884a45f05c5ca50c1", + "4bf8d9eed12c4cf0ad3ddcf05c6f8b40", + "835ee1d357fc4faca2848170687d2385", + "b06fd0778378484f953beff77dfb667b", + "53515cba82b348ca8e1b3ba408945b1c", + "29ff014885c94d89b70abd94efac8420", + "526fad5b957f4d5bbfdb2e1fd385a075", + "e92b7c31ddcc4d348aec2000e00cb9f9", + "adb2e39ae6f74a40b5c67641a330181c", + "708e3b49100449408db3ee4c061a6b3e", + "d9e5d5abdb834de09d471842a64e50d8", + "d657ac78dccf459abfd42607cfc7fa81", + "2241e58cb330409797507a7682d0bee5", + "e5283a4c4eba45639c87850bf553f875", + "53fdcfade3d24928a7da41e9917ea368", + "eac31c264856434aa4c5588bb2b5020f", + "b444011e318a404cb3b8ece8c2a7f157", + "65e66ea8bfb6474796f65621d85ab031", + "54f6179d5b784ea3b82a5e910b66eca4", + "1d99a271cd544ebdade737eff6d04dd9", + "743b13d278b7460192361e986b1bacf2", + "4286828ed6864f8b835f6691824b5281", + "6d39e9ea791b4854992d0fe40595f2f3", + "7a2e413cfd6e4349913e43c6649c57a2", + "edb8a71bc807448282f4ecce56278f2f", + "d337d48436074daf90d9e716d2d765a4", + "d7264da7f1de4fe4b2f8a9f9f649742c", + "b2832d02a5c14b6db210371a62cabdf3", + "0557be959c5a434cbb43b8e9fec73f80", + "9ecc8de1054e4d6e9ff37839b255835a", + "bfda249d54734c84b7076d49329beb1b" + ] + }, + "outputId": "1c26c4a1-84f6-4657-a0ea-03d5eeb539b8" + }, + "source": [ + "train_dataset = \\\n", + " datasets.MNIST('./data', train=True, download=True,\n", + " transform=transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.1307,), (0.3081,))\n", + " ]))\n", + "test_dataset = \\\n", + " datasets.MNIST('./data', train=False, transform=transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.1307,), (0.3081,))\n", + " ]))" + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz\n" + ] + }, + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5d37c505ee9445b9be75c0cd53fa065c", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + " 0%| | 0/9912422 [00:00 7:\n", + " return True\n", + " else:\n", + " return False" + ], + "execution_count": 13, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "scjxRQfR5rgr", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "1a6eb17f-627d-4056-b746-168d443a9e77" + }, + "source": [ + "sampling_iteration = 0\n", + "while True:\n", + " labeled_train_dataset, unlabeled_train_dataset = torch.utils.data.random_split(train_dataset, [100, 59900])\n", + " if check_dataset(labeled_train_dataset):\n", + " break\n", + " sampling_iteration += 1\n", + "print(f'Split the dataset after {sampling_iteration} resamplings')" + ], + "execution_count": 18, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Split the dataset after 0 resamplings\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "9XDTdYXunhve" + }, + "source": [ + "test_loader = torch.utils.data.DataLoader(\n", + " test_dataset, batch_size=64, shuffle=False)\n", + "labeled_train_loader = torch.utils.data.DataLoader(\n", + " labeled_train_dataset, batch_size=64, shuffle=True)\n", + "unlabeled_train_loader = torch.utils.data.DataLoader(\n", + " unlabeled_train_dataset, batch_size=64, shuffle=False)" + ], + "execution_count": 19, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Xyb9Uk505rgs" + }, + "source": [ + "Теперь, когда мы получили данные, определим архитектуру сети. Возьмем простую сверточную сетку с droupout'ом." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "BXHh9QBMi1md" + }, + "source": [ + "class Net(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.conv1 = nn.Conv2d(1, 20, kernel_size=5)\n", + " self.conv2 = nn.Conv2d(20, 40, kernel_size=5)\n", + " self.dropout = nn.Dropout2d(p=0.5)\n", + " self.fc1 = nn.Linear(640, 150)\n", + " self.fc2 = nn.Linear(150, 10)\n", + " self.log_softmax = nn.LogSoftmax(dim=1)\n", + "\n", + " def forward(self, x):\n", + " x = x.view(-1, 1, 28, 28)\n", + " x = F.relu(F.max_pool2d(self.conv1(x), 2))\n", + " x = F.relu(F.max_pool2d(self.dropout(self.conv2(x)), 2))\n", + " x = x.view(-1, 640)\n", + " x = F.relu(self.fc1(x))\n", + " x = self.dropout(x)\n", + " x = F.relu(self.fc2(x))\n", + " x = self.log_softmax(x)\n", + " return x" + ], + "execution_count": 20, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HaG1bxhI5rgt" + }, + "source": [ + "Опишем вспомогательные функции." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "UwbKihvcjsIl" + }, + "source": [ + "def train(epoch_idx, model, optimizer, train_loader, loss_func=F.nll_loss):\n", + " model.train()\n", + " for batch_idx, (x, target) in enumerate(train_loader):\n", + " x, target = x.cuda(), target.cuda()\n", + " optimizer.zero_grad()\n", + " output = model(x)\n", + " loss = loss_func(output, target)\n", + " loss.backward()\n", + " optimizer.step()" + ], + "execution_count": 21, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "esbjg0S0lF-a" + }, + "source": [ + "def test(epoch_idx, model, test_loader):\n", + " model.eval()\n", + " test_loss = 0.0\n", + " correct = 0\n", + " with torch.no_grad():\n", + " for x, target in test_loader:\n", + " x, target = x.cuda(), target.cuda()\n", + " output = model(x)\n", + " test_loss += F.nll_loss(output, target, size_average=False).item()\n", + " pred = output.data.max(1, keepdim=True)[1]\n", + " correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()\n", + "\n", + " test_loss /= len(test_loader.dataset)\n", + " print('Epoch {}: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(\n", + " epoch_idx, test_loss, correct, len(test_loader.dataset),\n", + " 100. * correct / len(test_loader.dataset)))" + ], + "execution_count": 22, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "WzVXnGtKc_pf" + }, + "source": [ + "def predict(model, loader):\n", + " model.eval()\n", + " result = []\n", + " with torch.no_grad():\n", + " for x, _ in loader:\n", + " result.append(model(x.cuda()))\n", + " return torch.cat(result)" + ], + "execution_count": 23, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ga7bO-Sx5rgu" + }, + "source": [ + "Создадим модель и обучим ее на нашем размеченном датасете." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "X831vbLVuAjY" + }, + "source": [ + "model = Net().cuda()\n", + "optimizer = torch.optim.SGD(model.parameters(), lr=0.1)" + ], + "execution_count": 24, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "h_jXa_zUuGbi", + "scrolled": true, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "b94b947f-8a9d-41c7-ba6a-1eac9a450a8c" + }, + "source": [ + "for i in range(400):\n", + " train(i, model, optimizer, labeled_train_loader)\n", + " if i % 10 == 0:\n", + " test(i, model, test_loader)" + ], + "execution_count": 25, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.7/dist-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n", + " warnings.warn(warning.format(ret))\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch 0: Average loss: 2.2950, Accuracy: 1253/10000 (13%)\n", + "Epoch 10: Average loss: 1.8219, Accuracy: 4334/10000 (43%)\n", + "Epoch 20: Average loss: 0.9842, Accuracy: 7412/10000 (74%)\n", + "Epoch 30: Average loss: 0.7605, Accuracy: 7487/10000 (75%)\n", + "Epoch 40: Average loss: 0.5439, Accuracy: 8425/10000 (84%)\n", + "Epoch 50: Average loss: 0.5683, Accuracy: 8399/10000 (84%)\n", + "Epoch 60: Average loss: 0.4262, Accuracy: 8818/10000 (88%)\n", + "Epoch 70: Average loss: 0.4740, Accuracy: 8735/10000 (87%)\n", + "Epoch 80: Average loss: 0.4628, Accuracy: 8794/10000 (88%)\n", + "Epoch 90: Average loss: 0.4893, Accuracy: 8789/10000 (88%)\n", + "Epoch 100: Average loss: 0.4700, Accuracy: 8871/10000 (89%)\n", + "Epoch 110: Average loss: 0.4666, Accuracy: 8882/10000 (89%)\n", + "Epoch 120: Average loss: 0.4790, Accuracy: 8878/10000 (89%)\n", + "Epoch 130: Average loss: 0.5363, Accuracy: 8788/10000 (88%)\n", + "Epoch 140: Average loss: 0.5816, Accuracy: 8787/10000 (88%)\n", + "Epoch 150: Average loss: 0.5435, Accuracy: 8812/10000 (88%)\n", + "Epoch 160: Average loss: 0.5552, Accuracy: 8779/10000 (88%)\n", + "Epoch 170: Average loss: 0.5199, Accuracy: 8889/10000 (89%)\n", + "Epoch 180: Average loss: 0.5467, Accuracy: 8853/10000 (89%)\n", + "Epoch 190: Average loss: 0.5821, Accuracy: 8840/10000 (88%)\n", + "Epoch 200: Average loss: 0.5401, Accuracy: 8898/10000 (89%)\n", + "Epoch 210: Average loss: 0.5006, Accuracy: 8984/10000 (90%)\n", + "Epoch 220: Average loss: 0.5559, Accuracy: 8884/10000 (89%)\n", + "Epoch 230: Average loss: 0.5846, Accuracy: 8852/10000 (89%)\n", + "Epoch 240: Average loss: 0.5949, Accuracy: 8868/10000 (89%)\n", + "Epoch 250: Average loss: 0.5607, Accuracy: 8915/10000 (89%)\n", + "Epoch 260: Average loss: 0.5854, Accuracy: 8893/10000 (89%)\n", + "Epoch 270: Average loss: 0.5632, Accuracy: 8917/10000 (89%)\n", + "Epoch 280: Average loss: 0.5880, Accuracy: 8907/10000 (89%)\n", + "Epoch 290: Average loss: 0.6756, Accuracy: 8801/10000 (88%)\n", + "Epoch 300: Average loss: 0.6404, Accuracy: 8873/10000 (89%)\n", + "Epoch 310: Average loss: 0.6383, Accuracy: 8851/10000 (89%)\n", + "Epoch 320: Average loss: 0.6970, Accuracy: 8766/10000 (88%)\n", + "Epoch 330: Average loss: 0.6428, Accuracy: 8870/10000 (89%)\n", + "Epoch 340: Average loss: 0.6107, Accuracy: 8865/10000 (89%)\n", + "Epoch 350: Average loss: 0.6697, Accuracy: 8766/10000 (88%)\n", + "Epoch 360: Average loss: 0.6730, Accuracy: 8789/10000 (88%)\n", + "Epoch 370: Average loss: 0.5835, Accuracy: 8907/10000 (89%)\n", + "Epoch 380: Average loss: 0.6172, Accuracy: 8868/10000 (89%)\n", + "Epoch 390: Average loss: 0.5891, Accuracy: 8910/10000 (89%)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RRFYDeIg5rgv" + }, + "source": [ + "Теперь попробуем побить этот результат с помощью псевдолейблов. Напишем функцию, которая принимает модель и возращает DataLoader с хард-лейблами, и запустим обучение." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "f1FwrOLe5rgw" + }, + "source": [ + "def get_pseudo_loader(model):\n", + " dataset = list(unlabeled_train_dataset)\n", + " soft_labels = predict(model, dataset)\n", + " hard_labels = torch.argmax(soft_labels, 1)\n", + " for idx, i in enumerate(dataset):\n", + " # print(i[1], hard_labels[idx])\n", + " dataset[idx] = (i[0], hard_labels[idx])\n", + " return torch.utils.data.DataLoader(\n", + " dataset, batch_size=64, shuffle=True)" + ], + "execution_count": 59, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "lbFpYgqC5rgw" + }, + "source": [ + "model_hard = Net().cuda()\n", + "model_hard.load_state_dict(model.state_dict())\n", + "optimizer_hard = torch.optim.SGD(model_hard.parameters(), lr=0.1)" + ], + "execution_count": 60, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "i4WjNPrM2qZi", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "7b7f3826-4d58-44d1-a005-9d46be95ad17" + }, + "source": [ + "hard_labeled_loader = get_pseudo_loader(model)\n", + "for i in range(10):\n", + " train(i, model_hard, optimizer_hard, hard_labeled_loader)\n", + " train(i, model_hard, optimizer_hard, labeled_train_loader)\n", + " test(i, model_hard, test_loader)" + ], + "execution_count": 61, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.7/dist-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n", + " warnings.warn(warning.format(ret))\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch 0: Average loss: 0.3560, Accuracy: 9014/10000 (90%)\n", + "Epoch 1: Average loss: 0.4231, Accuracy: 8994/10000 (90%)\n", + "Epoch 2: Average loss: 0.3778, Accuracy: 9105/10000 (91%)\n", + "Epoch 3: Average loss: 0.4300, Accuracy: 9056/10000 (91%)\n", + "Epoch 4: Average loss: 0.4270, Accuracy: 9007/10000 (90%)\n", + "Epoch 5: Average loss: 0.4094, Accuracy: 9032/10000 (90%)\n", + "Epoch 6: Average loss: 0.3992, Accuracy: 9053/10000 (91%)\n", + "Epoch 7: Average loss: 0.4301, Accuracy: 9017/10000 (90%)\n", + "Epoch 8: Average loss: 0.4489, Accuracy: 9050/10000 (90%)\n", + "Epoch 9: Average loss: 0.4518, Accuracy: 9077/10000 (91%)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7gcvAXOY5rgx" + }, + "source": [ + "**Итеративная псевдоразметка.**\n", + "\n", + "Мы уже видим небольшое улучшение, но можно пойти дальше." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "oZ00dzMK5rgx" + }, + "source": [ + "model_hard_iter = Net().cuda()\n", + "model_hard_iter.load_state_dict(model.state_dict())\n", + "optimizer_hard_iter = torch.optim.SGD(model_hard_iter.parameters(), lr=0.1)" + ], + "execution_count": 62, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "REl_v-Sf5rgx", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "72387cd3-d19b-4550-a441-17fba81dda4f" + }, + "source": [ + "for i in range(20):\n", + " hard_labeled_loader = get_pseudo_loader(model_hard_iter)\n", + " train(i, model_hard_iter, optimizer_hard_iter, hard_labeled_loader)\n", + " train(i, model_hard_iter, optimizer_hard_iter, labeled_train_loader)\n", + " test(i, model_hard_iter, test_loader)" + ], + "execution_count": 63, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.7/dist-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n", + " warnings.warn(warning.format(ret))\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch 0: Average loss: 0.3581, Accuracy: 9062/10000 (91%)\n", + "Epoch 1: Average loss: 0.3643, Accuracy: 9105/10000 (91%)\n", + "Epoch 2: Average loss: 0.3638, Accuracy: 9127/10000 (91%)\n", + "Epoch 3: Average loss: 0.3718, Accuracy: 9217/10000 (92%)\n", + "Epoch 4: Average loss: 0.3325, Accuracy: 9247/10000 (92%)\n", + "Epoch 5: Average loss: 0.3360, Accuracy: 9296/10000 (93%)\n", + "Epoch 6: Average loss: 0.3014, Accuracy: 9322/10000 (93%)\n", + "Epoch 7: Average loss: 0.2760, Accuracy: 9367/10000 (94%)\n", + "Epoch 8: Average loss: 0.2848, Accuracy: 9373/10000 (94%)\n", + "Epoch 9: Average loss: 0.3037, Accuracy: 9376/10000 (94%)\n", + "Epoch 10: Average loss: 0.3252, Accuracy: 9364/10000 (94%)\n", + "Epoch 11: Average loss: 0.3257, Accuracy: 9374/10000 (94%)\n", + "Epoch 12: Average loss: 0.2874, Accuracy: 9427/10000 (94%)\n", + "Epoch 13: Average loss: 0.2856, Accuracy: 9434/10000 (94%)\n", + "Epoch 14: Average loss: 0.2901, Accuracy: 9423/10000 (94%)\n", + "Epoch 15: Average loss: 0.2860, Accuracy: 9437/10000 (94%)\n", + "Epoch 16: Average loss: 0.2794, Accuracy: 9466/10000 (95%)\n", + "Epoch 17: Average loss: 0.2736, Accuracy: 9478/10000 (95%)\n", + "Epoch 18: Average loss: 0.2820, Accuracy: 9475/10000 (95%)\n", + "Epoch 19: Average loss: 0.2833, Accuracy: 9454/10000 (95%)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ri3o1-Mz5rgy" + }, + "source": [ + "**Домашнее задание.**\n", + "\n", + "Модифицировать функцию `get_pseudo_loader`, чтобы она могла возвращать софт-лейблы (2 балла).\n", + "\n", + "Правильно запустить обучение - в качестве лосса используем KL-дивергенцию. Получить accuracy 90% или выше. (+5 баллов).\n", + "\n", + "Интуитивно кажется, что модель не должна ничему учиться, т.к. ее выход будет полностью совпадать с софт-лейблами. Напишите (текстом), почему тем не менее удается сильно выиграть относительно бейзлайна. (+3 балла)." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "BXezhZ325rgy" + }, + "source": [ + "model_soft_iter = Net().cuda()\n", + "model_soft_iter.load_state_dict(model.state_dict())\n", + "optimizer_soft_iter = torch.optim.SGD(model_soft_iter.parameters(), lr=0.1)\n", + "KL_loss = torch.nn.KLDivLoss(log_target=True, reduction='batchmean')" + ], + "execution_count": 109, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "UA5DmEKV5rgz" + }, + "source": [ + "def get_pseudo_loader(model, soft=False):\n", + " dataset = list(unlabeled_train_dataset)\n", + " soft_labels = predict(model, dataset)\n", + " if soft == True:\n", + " for idx, i in enumerate(dataset):\n", + " # print(i[1], hard_labels[idx])\n", + " dataset[idx] = (i[0], soft_labels[idx])\n", + " else:\n", + " hard_labels = torch.argmax(soft_labels, 1)\n", + " for idx, i in enumerate(dataset):\n", + " # print(i[1], hard_labels[idx])\n", + " dataset[idx] = (i[0], hard_labels[idx])\n", + " return torch.utils.data.DataLoader(\n", + " dataset, batch_size=64, shuffle=True)" + ], + "execution_count": 102, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "def test(epoch_idx, model, test_loader):\n", + " model.eval()\n", + " test_loss = 0.0\n", + " correct = 0\n", + " with torch.no_grad():\n", + " for x, target in test_loader:\n", + " x, target = x.cuda(), target.cuda()\n", + " output = model(x)\n", + " test_loss += F.nll_loss(output, target, size_average=False).item()\n", + " pred = output.data.max(1, keepdim=True)[1]\n", + " # print(pred)\n", + " correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()\n", + "\n", + " test_loss /= len(test_loader.dataset)\n", + " print('Epoch {}: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(\n", + " epoch_idx, test_loss, correct, len(test_loader.dataset),\n", + " 100. * correct / len(test_loader.dataset)))" + ], + "metadata": { + "id": "-eWO33Lfta00" + }, + "execution_count": 103, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "def train(epoch_idx, model, optimizer, train_loader, loss_func=F.nll_loss):\n", + " model.train()\n", + " for batch_idx, (x, target) in enumerate(train_loader):\n", + " x, target = x.cuda(), target.cuda()\n", + " optimizer.zero_grad()\n", + " output = model(x)\n", + " loss = loss_func(output, target)\n", + " loss.backward()\n", + " optimizer.step()" + ], + "metadata": { + "id": "A_oqSzkiucNz" + }, + "execution_count": 105, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "first = torch.Tensor([-3.2228e+01, -2.4380e+01, -2.5240e+01, -2.0968e+01, -3.2228e+01,-3.2228e+01, -3.2228e+01, 0.0000e+00, -3.1462e+01, -3.2108e+01])\n", + "second = torch.Tensor([-3.8978e+01, -3.6266e+01, -3.6124e+01, -3.1969e+01, -3.8322e+01, -3.8978e+01, -3.8978e+01, 0.0000e+00, -3.7471e+01, -3.3333e+01])" + ], + "metadata": { + "id": "_watUdFcy3W2" + }, + "execution_count": 96, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "KL_loss(first, second)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "OHbV-3u9zGt_", + "outputId": "d354e52f-c182-4f25-a2b2-e0c056d0cb08" + }, + "execution_count": 110, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor(-1.5284e-14)" + ] + }, + "metadata": {}, + "execution_count": 110 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "KD-uX-Cl5rgz", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "2088f056-b519-4dc9-a93a-3b9205c1698c" + }, + "source": [ + "for i in range(20):\n", + " soft_labeled_loader = get_pseudo_loader(model_soft_iter, soft=True)\n", + " train(i, model_soft_iter, optimizer_soft_iter, soft_labeled_loader, loss_func=torch.nn.KLDivLoss(log_target=True, reduction='batchmean'))\n", + " train(i, model_soft_iter, optimizer_soft_iter, labeled_train_loader)\n", + " test(i, model_soft_iter, test_loader)\n" + ], + "execution_count": 111, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.7/dist-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n", + " warnings.warn(warning.format(ret))\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch 0: Average loss: 0.3404, Accuracy: 9095/10000 (91%)\n", + "Epoch 1: Average loss: 0.2865, Accuracy: 9202/10000 (92%)\n", + "Epoch 2: Average loss: 0.2748, Accuracy: 9226/10000 (92%)\n", + "Epoch 3: Average loss: 0.2846, Accuracy: 9194/10000 (92%)\n", + "Epoch 4: Average loss: 0.2725, Accuracy: 9221/10000 (92%)\n", + "Epoch 5: Average loss: 0.2653, Accuracy: 9237/10000 (92%)\n", + "Epoch 6: Average loss: 0.2595, Accuracy: 9275/10000 (93%)\n", + "Epoch 7: Average loss: 0.2522, Accuracy: 9290/10000 (93%)\n", + "Epoch 8: Average loss: 0.2471, Accuracy: 9284/10000 (93%)\n", + "Epoch 9: Average loss: 0.2415, Accuracy: 9310/10000 (93%)\n", + "Epoch 10: Average loss: 0.2400, Accuracy: 9316/10000 (93%)\n", + "Epoch 11: Average loss: 0.2422, Accuracy: 9301/10000 (93%)\n", + "Epoch 12: Average loss: 0.2397, Accuracy: 9313/10000 (93%)\n", + "Epoch 13: Average loss: 0.2305, Accuracy: 9335/10000 (93%)\n", + "Epoch 14: Average loss: 0.2330, Accuracy: 9312/10000 (93%)\n", + "Epoch 15: Average loss: 0.2357, Accuracy: 9328/10000 (93%)\n", + "Epoch 16: Average loss: 0.2351, Accuracy: 9319/10000 (93%)\n", + "Epoch 17: Average loss: 0.2361, Accuracy: 9353/10000 (94%)\n", + "Epoch 18: Average loss: 0.2298, Accuracy: 9358/10000 (94%)\n", + "Epoch 19: Average loss: 0.2206, Accuracy: 9392/10000 (94%)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "" + ], + "metadata": { + "id": "KPte1ZGloPI4" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "" + ], + "metadata": { + "id": "6Sawg2j72n9m" + } + } + ] +} \ No newline at end of file