-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdata.py
More file actions
69 lines (49 loc) · 2.12 KB
/
Copy pathdata.py
File metadata and controls
69 lines (49 loc) · 2.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import json
from dataclasses import dataclass
from typing import List, Any, Optional

import pandas as pd
from torch.utils.data import Dataset
@dataclass
class CSVPromptBatch:
    """A collated batch of prompt rows, as built by ``CSVPromptDataset.collate_fn``.

    Every field is a parallel list holding one entry per example. Only a batch
    of exactly one example is accepted for now; any other length is rejected
    in ``__post_init__``.
    """

    idx: List[int]  # dataset row index of each example
    story: List[str]  # value of the 'story' column for each example
    question: List[str]  # value of the 'question' column for each example
    full_user_prompt: List[str]  # prompt text built from each row

    def __post_init__(self):
        # Guard clause: a single-example batch is the only supported shape.
        n_examples = len(self.idx)
        if n_examples == 1:
            return
        raise NotImplementedError("Haven't figured out batching for SFT yet. Need to a) mask loss for padding, and b) reshaped sampled prescribed early exits correctly")
class CSVPromptDataset(Dataset):
    """Map-style dataset of ``story``/``question`` rows read from a delimited text file.

    ``__getitem__`` returns ``(row_index, row_dict)``. ``collate_fn`` turns a list
    of such items into a ``CSVPromptBatch`` whose ``full_user_prompt`` entries are
    built by ``generate_prompt``.

    The prompt-template strings ``system_prompt``, ``task_context`` and
    ``prefiller`` come from an optional JSON config. They are ``None`` when no
    config is supplied.
    """

    # Prompt-template strings loaded from the JSON config. The class-level None
    # defaults mean the attributes always exist, even without a config.
    system_prompt: Optional[str] = None
    task_context: Optional[str] = None
    prefiller: Optional[str] = None

    # Exact header (names and order) the data file must have.
    _EXPECTED_COLUMNS = ('story', 'question')
    # Keys that the JSON config must provide.
    _CONFIG_KEYS = ('system_prompt', 'task_context', 'prefiller')

    def __init__(self, tsv_path: str, json_path: Optional[str] = None, sep: str = ","):
        """Load the data table and, optionally, the prompt-template config.

        Args:
            tsv_path: Path to the data file. Its header row must be exactly
                ``story`` then ``question``.
            json_path: Optional path to a JSON object containing the keys
                ``system_prompt``, ``task_context`` and ``prefiller``.
            sep: Field delimiter passed to ``pandas.read_csv``. The default is
                a comma; pass a tab character for tab-separated files.

        Raises:
            ValueError: If the file's columns are not exactly ``['story', 'question']``.
            KeyError: If the JSON config lacks one of the required keys.
        """
        self.df = pd.read_csv(tsv_path, header=0, sep=sep)
        self.columns = self.df.columns.tolist()
        # Raise explicitly: an `assert` would be stripped under `python -O`.
        if tuple(self.columns) != self._EXPECTED_COLUMNS:
            raise ValueError(
                "Need this dataset structure right now! Expected columns "
                f"{list(self._EXPECTED_COLUMNS)}, got {self.columns}"
            )
        if json_path:
            self._load_config(json_path)

    def _load_config(self, json_path: str) -> None:
        """Read and validate the prompt-template strings from the JSON file at *json_path*."""
        with open(json_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        missing = [key for key in self._CONFIG_KEYS if key not in config]
        if missing:
            raise KeyError(f"Config {json_path!r} is missing required key(s): {missing}")
        self.system_prompt = config['system_prompt']
        self.task_context = config['task_context']
        self.prefiller = config['prefiller']

    def __len__(self):
        """Return the number of rows in the data file."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(idx, row_dict)`` for row *idx*; ``row_dict`` maps column name to value."""
        return idx, self.df.iloc[idx].to_dict()

    def generate_prompt(self, batch_item: dict) -> str:
        """Build the user prompt for one row.

        The task context (when configured), the story and the question are
        separated by blank lines. Without a config, the prompt is just the story
        and the question.
        """
        body = f"{batch_item['story']}\n\n{batch_item['question']}"
        if self.task_context is None:
            return body
        return f"{self.task_context}\n\n{body}"

    def collate_fn(self, batch) -> "CSVPromptBatch":
        """Collate ``(idx, row_dict)`` items from ``__getitem__`` into a ``CSVPromptBatch``.

        Meant to be passed as ``collate_fn`` to a ``DataLoader``.
        """
        rows = [row for _, row in batch]
        # Transpose the list of row dicts into a dict of per-column lists.
        batch_dict = {col: [row[col] for row in rows] for col in self.columns}
        batch_dict['idx'] = [idx for idx, _ in batch]
        batch_dict['full_user_prompt'] = [self.generate_prompt(row) for row in rows]
        return CSVPromptBatch(**batch_dict)
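# Example usage (also requires `from torch.utils.data import DataLoader`):
# # CSV file: data.csv (header: story,question)
# # JSON file: config.json (keys: system_prompt, task_context, prefiller)
# dataset = CSVPromptDataset("data.csv", json_path="config.json")
# print(dataset.system_prompt)
# print(dataset.prefiller)
# # batch_size must be 1 for now: CSVPromptBatch rejects larger batches.
# dataloader = DataLoader(dataset, batch_size=1, collate_fn=dataset.collate_fn)
# for batch in dataloader:
#     print(batch)  # each CSVPromptBatch field (batch.story, batch.idx, ...) is a list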