0%

PyTorch数据集构建

在Pytorch中,训练和验证过程一般都要用到DataLoader类来装载数据集,它会返回一个可迭代对象用于训练和验证。DataLoader需要的对象是Data类型,定义方法如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class MyDataset(Dataset):
'''
x: Features.
y: Targets, if none, do prediction.
'''
def __init__(self, x, y=None):
if y is None:
self.y = y
else:
self.y = torch.FloatTensor(y)
self.x = torch.FloatTensor(x)

def __getitem__(self, idx):
if self.y is None:
return self.x[idx]
else:
return self.x[idx], self.y[idx]

def __len__(self):
return len(self.x)

其中__getitem____len__是必须要实现的,对于可以以顺序表存储的数据都可以使用以上方式存储。而对于难以随机读取或者开销过大的数据(例如从数据库、远程服务器中读取数据),可以用IterableDataset