-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodule.py
More file actions
104 lines (75 loc) · 2.99 KB
/
Copy pathmodule.py
File metadata and controls
104 lines (75 loc) · 2.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import collections
import mysql.connector
class BlockSkipGramModel(nn.Module):
    """Embedding lookup followed by a linear projection and a softmax.

    Args:
        num_embedding: Number of rows in the embedding table; the linear
            layer also produces this many output features.
        embedding_dim: Width of each embedding vector.
    """

    def __init__(self, num_embedding=1162, embedding_dim=128):
        super().__init__()
        self.embed = nn.Embedding(num_embedding, embedding_dim)
        self.linear = nn.Linear(embedding_dim, num_embedding)
        # The softmax is always taken over dim 1 of the linear output.
        self.activation = nn.Softmax(dim=1)
        # Overwrite the default embedding initialisation with U(-1, 1).
        nn.init.uniform_(self.embed.weight, -1, 1)

    def forward(self, x):
        """Return ``softmax(linear(embed(x)))`` for the index tensor ``x``.

        NOTE(review): presumably ``x`` is a 1-D batch of indices, which makes
        dim 1 the ``num_embedding``-sized axis -- confirm against callers.
        """
        return self.activation(self.linear(self.embed(x)))
class BlockSkipGramModelDataSet(Dataset):
    """Dataset backed by the ``blocks`` table of a MySQL database.

    ``__getitem__`` does not read any row itself; it only hands back the
    open cursor and the requested index.  The rows of a whole batch are then
    fetched with a single query in :meth:`collate_fn`, which must therefore
    be passed to the ``DataLoader`` as its ``collate_fn``.

    Args:
        user: Database user name.
        password: Password for ``user``.
    """

    def __init__(self, user, password) -> None:
        super().__init__()
        # A short-lived connection is used only to count the rows; it is
        # closed again so it does not linger for the dataset's lifetime.
        connection, cursor = self.connect_to_db(user, password)
        try:
            cursor.execute("SELECT count(*) from blocks")
            self.data_nums = cursor.fetchone()[0]
        finally:
            cursor.close()
            connection.close()
        self.user = user
        self.password = password
        # The connection used for reading samples is opened lazily on the
        # first __getitem__ call (see _init_connection).
        self._init = False

    @staticmethod
    def connect_to_db(user, password):
        """Open a connection to the local MySQL server.

        Returns:
            A ``(connection, cursor)`` tuple.
        """
        connection = mysql.connector.connect(
            host="127.0.0.1",
            port=3188,
            user=user,
            password=password,
            database="defaults",
            auth_plugin="mysql_native_password",
        )
        cursor = connection.cursor()
        # These are GLOBAL server settings and are re-issued on every
        # connect.  NOTE(review): changing global variables presumably needs
        # an account with elevated privileges -- confirm.
        cursor.execute('set global max_connections=1000')
        cursor.execute('set global max_allowed_packet=1048576000')
        return connection, cursor

    def _init_connection(self):
        """Open the sample-reading connection once; later calls are no-ops."""
        if self._init:
            return
        self.connection, self.cursor = self.connect_to_db(self.user, self.password)
        self._init = True

    def __getitem__(self, index):
        """Return ``(cursor, str(index))`` for use by :meth:`collate_fn`."""
        self._init_connection()
        # The row lookup is deferred to collate_fn so that a whole batch is
        # read with one query instead of one query per sample.
        return (self.cursor, str(index))

    def __len__(self):
        """Return the row count of ``blocks`` measured at construction time."""
        return self.data_nums

    @staticmethod
    def collate_fn(batch_indices):
        """Fetch the ``center``/``target`` columns for a batch of ids.

        Args:
            batch_indices: Non-empty list of ``(cursor, index_string)``
                tuples as produced by ``__getitem__``.  The cursor of the
                first item is used for the query.

        Returns:
            A ``(data, label)`` pair of 1-D tensors built from the
            ``center`` and ``target`` columns of the returned rows.  The
            query has no ``ORDER BY``, so rows come back in whatever order
            the server returns them.
        """
        cursor = batch_indices[0][0]
        id_list = ','.join(index for _, index in batch_indices)
        cursor.execute('SELECT center, target FROM blocks where id in ({})'.format(id_list))
        rows = torch.tensor(cursor.fetchall())
        data, label = rows.split(split_size=1, dim=1)
        return data.squeeze(1), label.squeeze(1)
def tokenizer(vocab_path='vocab.txt', encoding='utf-8'):
    """Load a vocabulary file into a token -> index mapping.

    Each line of the file is one token; trailing whitespace (including the
    newline) is stripped and the token is mapped to its zero-based line
    number.  A token that appears more than once keeps the index of its
    last occurrence.

    Args:
        vocab_path: Path of the vocabulary file.  Defaults to ``vocab.txt``
            in the current working directory.
        encoding: Text encoding used to read the file.  It is set
            explicitly so the result does not depend on the platform's
            default encoding.

    Returns:
        A ``collections.OrderedDict`` mapping each token to its index.
    """
    vocab = collections.OrderedDict()
    with open(vocab_path, mode='r', encoding=encoding) as f:
        # Iterate the file object directly so lines are streamed rather
        # than loaded into a list first.
        for index, line in enumerate(f):
            vocab[line.rstrip()] = index
    return vocab