|
507 | 507 | " loss.backward()\n",
|
508 | 508 | " optimizer.step()\n",
|
509 | 509 | " # 记录误差\n",
|
510 | | - " train_loss += loss.data[0]\n", |
| 510 | + " train_loss += loss.item()\n", |
511 | 511 | " # 计算分类的准确率\n",
|
512 | 512 | " _, pred = out.max(1)\n",
|
513 | | - " num_correct = (pred == label).sum().data[0]\n", |
| 513 | + " num_correct = (pred == label).sum().item()\n", |
514 | 514 | " acc = num_correct / im.shape[0]\n",
|
515 | 515 | " train_acc += acc\n",
|
516 | 516 | " \n",
|
|
526 | 526 | " out = net(im)\n",
|
527 | 527 | " loss = criterion(out, label)\n",
|
528 | 528 | " # 记录误差\n",
|
529 | | - " eval_loss += loss.data[0]\n", |
| 529 | + " eval_loss += loss.item()\n", |
530 | 530 | " # 记录准确率\n",
|
531 | 531 | " _, pred = out.max(1)\n",
|
532 | | - " num_correct = (pred == label).sum().data[0]\n", |
| 532 | + " num_correct = (pred == label).sum().item()\n", |
533 | 533 | " acc = num_correct / im.shape[0]\n",
|
534 | 534 | " eval_acc += acc\n",
|
535 | 535 | " \n",
|
|
0 commit comments