博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
pytorch训练cifar10数据集查看各个种类图片的准确率
阅读量:2136 次
发布时间:2019-04-30

本文共 2363 字,大约阅读时间需要 7 分钟。

以vgg19为例

import osimport torchfrom torch.utils.data import DataLoaderfrom torchvision import datasetsfrom torchvision import transformsfrom torch import nn,optim#from lenet5 import Lenet5from vgg import VGG19from vgg import VGG34 def main():    batchsz = 32     cifar_train = datasets.CIFAR10('dataset/', train=True, transform=transforms.Compose([        transforms.Resize((32, 32)),        transforms.ToTensor(),        transforms.Normalize(mean=[0.485,0.456,0.406],        					 std=[0.229,0.224,0.225])    ]), download=True)    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)    cifar_test = datasets.CIFAR10('dataset/', train=False, transform=transforms.Compose([        transforms.Resize((32, 32)),        transforms.ToTensor(),        transforms.Normalize(mean=[0.485,0.456,0.406],        					 std=[0.229,0.224,0.225])    ]), download=True)    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)     x, label = iter(cifar_train).next()    print('x:', x.shape, 'label:', label.shape)     device = torch.device('cuda')    model = VGG19().to(device)     criton = nn.CrossEntropyLoss().to(device)  #包含了softmax    optimizer = optim.Adam(model.parameters(),lr=1e-3)    print(model)        if os.path.exists('model.pkl'):        model.load_state_dict(torch.load('model.pkl'))        print('model loaded from model.pkl')         classes = ('plane', 'car', 'bird', 'cat',        'deer', 'dog', 'frog', 'horse', 'ship', 'truck')    N_CLASSES = 10    class_correct = list(0. for i in range(N_CLASSES))    class_total = list(0. for i in range(N_CLASSES))    model.eval()    #test    total_correct = 0    total_num = 0    for x, label in cifar_test:        x,label = x.to(device) ,label.to(device)        logits = model(x)        pred = logits.argmax(dim=1)        total_correct += torch.eq(pred,label).float().sum().item()        total_num += x.size(0)  #即batch_size              c = (pred == label).squeeze()                for i in range(len(label)):            _label = label[i]            class_correct[_label] += c[i].item()            class_total[_label] += 1            acc = total_correct / total_num    print('acc: ',acc)    for i in range(N_CLASSES):        print('Accuracy of %5s : %2d %%' % (            classes[i], 100 * class_correct[i] / class_total[i]))  if __name__ == '__main__':    main()

 

而训练了vgg34则会是这样

 

 

参考:

转载地址:http://slygf.baihongyu.com/

你可能感兴趣的文章
Leetcode C++《热题 Hot 100-43》94.二叉树的中序遍历
查看>>
Leetcode C++ 《第175场周赛-1 》5332.检查整数及其两倍数是否存在
查看>>
Leetcode C++ 《第175场周赛-2 》5333.制造字母异位词的最小步骤数
查看>>
Leetcode C++ 《第175场周赛-3》1348. 推文计数
查看>>
Leetcode C++《热题 Hot 100-44》102.二叉树的层次遍历
查看>>
Leetcode C++《热题 Hot 100-45》338.比特位计数
查看>>
读书摘要系列之《kubernetes权威指南·第四版》第一章:kubernetes入门
查看>>
Leetcode C++《热题 Hot 100-46》739.每日温度
查看>>
Leetcode C++《热题 Hot 100-47》236.二叉树的最近公共祖先
查看>>
Leetcode C++《热题 Hot 100-48》406.根据身高重建队列
查看>>
《kubernetes权威指南·第四版》第二章:kubernetes安装配置指南
查看>>
Leetcode C++《热题 Hot 100-49》399.除法求值
查看>>
Leetcode C++《热题 Hot 100-51》152. 乘积最大子序列
查看>>
[Kick Start 2020] Round A 1.Allocation
查看>>
Leetcode C++ 《第181场周赛-1》 5364. 按既定顺序创建目标数组
查看>>
Leetcode C++ 《第181场周赛-2》 1390. 四因数
查看>>
阿里云《云原生》公开课笔记 第一章 云原生启蒙
查看>>
阿里云《云原生》公开课笔记 第二章 容器基本概念
查看>>
阿里云《云原生》公开课笔记 第三章 kubernetes核心概念
查看>>
阿里云《云原生》公开课笔记 第四章 理解Pod和容器设计模式
查看>>