1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
# CIFAR-10 training split with no transform attached (downloaded to ../data
# if it is not already there).
all_images = torchvision.datasets.CIFAR10(
    root="../data", train=True, download=True)

# Training-time pipeline: random horizontal flip, then conversion to a tensor.
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
])

# Test-time pipeline: tensor conversion only, no random augmentation.
test_augs = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])
def load_cifar10(is_train, augs, batch_size):
    """Build a DataLoader over one split of CIFAR-10.

    Args:
        is_train: Passed as ``train=`` to the dataset and as ``shuffle=`` to
            the loader, so the training split is shuffled and the test split
            is not.
        augs: Transform handed to the dataset as ``transform=``.
        batch_size: Number of examples per minibatch.

    Returns:
        A ``torch.utils.data.DataLoader`` over the requested split. The data
        is stored under ``../data`` and downloaded on demand.
    """
    cifar = torchvision.datasets.CIFAR10(
        root="../data", train=is_train, transform=augs, download=True)
    return torch.utils.data.DataLoader(
        cifar,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=d2l.get_dataloader_workers(),
    )
def train_batch_ch13(net, X, y, loss, trainer, devices):
    """Train on one minibatch using multiple GPUs.

    Args:
        net: Model to train; it is switched to training mode here.
        X: Input tensor, or a list of input tensors.
        y: Target tensor.
        loss: Callable returning a loss tensor for ``(pred, y)``; the result
            is summed before back-propagation.
        trainer: Optimizer whose gradients are zeroed and which is stepped
            once.
        devices: Sequence of devices; inputs and targets are moved to
            ``devices[0]``.

    Returns:
        A pair ``(summed loss, d2l.accuracy(pred, y))`` for this minibatch.
    """
    primary = devices[0]
    X = [x.to(primary) for x in X] if isinstance(X, list) else X.to(primary)
    y = y.to(primary)

    net.train()
    trainer.zero_grad()
    outputs = net(X)
    batch_loss = loss(outputs, y)
    batch_loss.sum().backward()
    trainer.step()

    return batch_loss.sum(), d2l.accuracy(outputs, y)
def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
               devices=d2l.try_all_gpus()):
    """Train a model on multiple GPUs, plotting progress as it goes.

    Args:
        net: Model to train; it is wrapped in ``nn.DataParallel`` over
            ``devices`` and moved to ``devices[0]``.
        train_iter: Iterable of ``(features, labels)`` minibatches that also
            supports ``len()``.
        test_iter: Data iterator passed to ``d2l.evaluate_accuracy_gpu``
            after every epoch.
        loss: Loss callable forwarded to ``train_batch_ch13``.
        trainer: Optimizer forwarded to ``train_batch_ch13``.
        num_epochs: Number of passes over ``train_iter``.
        devices: Devices to train on. The default is evaluated once, when
            the function is defined.

    Returns:
        None. Progress is drawn with ``d2l.Animator`` and a two-line summary
        is printed after the last epoch.
    """
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['train loss', 'train acc', 'test acc'])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    # Plot about five times per epoch. Clamp the interval to 1 so a loader
    # with fewer than five batches does not cause a modulo-by-zero.
    plot_every = max(1, num_batches // 5)
    for epoch in range(num_epochs):
        # Running sums: loss, accuracy count, labels.shape[0], labels.numel().
        metric = d2l.Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = train_batch_ch13(
                net, features, labels, loss, trainer, devices)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
            if (i + 1) % plot_every == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[3],
                              None))
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    # Final summary uses the last epoch's accumulators. Throughput scales the
    # last epoch's example count by num_epochs against the total timed work.
    print(f'loss {metric[0] / metric[2]:.3f}, train acc '
          f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on '
          f'{str(devices)}')
# Module-level training configuration; train_with_data_aug reads batch_size
# and devices from here.
batch_size = 256
devices = d2l.try_all_gpus()
# NOTE(review): the arguments are presumably (num_classes=10, in_channels=3)
# -- confirm against d2l.resnet18.
net = d2l.resnet18(10, 3)
def init_weights(m):
    """Xavier-uniform initialise the weight of a linear or 2-D conv layer.

    Only modules whose exact type is ``nn.Linear`` or ``nn.Conv2d`` are
    touched; subclasses and all other module types are left as they are.
    Biases are not modified.

    Args:
        m: The module to (possibly) re-initialise in place.
    """
    # Exact type match on purpose: subclasses must not be re-initialised.
    if type(m) not in (nn.Linear, nn.Conv2d):
        return
    nn.init.xavier_uniform_(m.weight)
# Run init_weights over net via its apply() method (for a torch.nn.Module this
# visits every submodule as well as the module itself), so the linear and
# 2-D conv layers get Xavier-uniform weights before training starts.
net.apply(init_weights)
def train_with_data_aug(train_augs, test_augs, net, lr=0.001, num_epochs=10):
    """Train ``net`` on CIFAR-10 with the given augmentation pipelines.

    Uses the module-level ``batch_size`` and ``devices`` settings.

    Args:
        train_augs: Transform applied to the training split.
        test_augs: Transform applied to the test split.
        net: Model to train; its parameters are optimised with Adam.
        lr: Learning rate for the Adam optimizer.
        num_epochs: Number of training epochs. Defaults to 10, the value
            that used to be hard-coded.
    """
    train_iter = load_cifar10(True, train_augs, batch_size)
    test_iter = load_cifar10(False, test_augs, batch_size)
    # reduction="none" keeps the per-example losses; the training step sums
    # them itself before calling backward().
    loss = nn.CrossEntropyLoss(reduction="none")
    trainer = torch.optim.Adam(net.parameters(), lr=lr)
    train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)
# Start training: random horizontal flips on the training split, plain tensor
# conversion on the test split, and the default learning rate.
train_with_data_aug(train_augs, test_augs, net)
|