- GPU设置
# 确定GPU是否可用以及GPU数量
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_gpu = torch.cuda.device_count()
net = NET() # 实例化网络
- 尝试单个batch过拟合,确认网络在工作
first_batch = next(iter(train_loader))
for batch_idx, (data) in enumerate([first_batch] * 50):
# train