在Pytorch中,训练和验证过程一般都要用到DataLoader
类来装载数据集,它会返回一个可迭代对象用于训练和验证。DataLoader
需要的对象是Data
类型,定义方法如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| class MyDataset(Dataset): ''' x: Features. y: Targets, if none, do prediction. ''' def __init__(self, x, y=None): if y is None: self.y = y else: self.y = torch.FloatTensor(y) self.x = torch.FloatTensor(x)
def __getitem__(self, idx): if self.y is None: return self.x[idx] else: return self.x[idx], self.y[idx]
def __len__(self): return len(self.x)
|
其中__getitem__
和__len__
是必须要实现的,对于可以以顺序表存储的数据都可以使用以上方式存储。而对于难以随机读取或者开销过大的数据(例如从数据库、远程服务器中读取数据),可以用IterableDataset
。