利用高斯核卷积对MINIST数据集进行去噪

更新时间:2020-07-24 09:50:30点击次数:197次
import torch

import torchvision

from torch.autograd import Variable

from torchvision import datasets, transforms

from torch.utils.data import DataLoader

import cv2

from torch import nn

import numpy as np

import torch.nn.functional as F

import advertorch.defenses as defenses 

from numpy import *

seed = 2014

 

torch.manual_seed(seed)

np.random.seed(seed)  # Numpy module.

random.seed(seed)  # Python random module.

torch.manual_seed(seed)

 

train_dataset =   datasets.FashionMNIST('./fashionmnist_data/', train=True, download=True,

                       transform=transforms.Compose([

                           transforms.ToTensor(),

                       ]))

 

train_loader = DataLoader(dataset = train_dataset, batch_size = 500, shuffle = True)

 

test_loader = torch.utils.data.DataLoader(

        datasets.FashionMNIST('./fashionmnist_data/', train=False, transform=transforms.Compose([

        transforms.ToTensor(),

        ])),batch_size=1, shuffle=True)

 

epoch = 12


 

class Linear_cliassifer(torch.nn.Module):

    def __init__(self) :

        super(Linear_cliassifer, self).__init__()


 

        self.Gs = defenses.GaussianSmoothing2D(3, 1, 3)

        self.Line1 = torch.nn.Linear(28 * 28, 10)

 

    def forward(self, x):


 

        x = self.Gs(x)

        x = self.Line1(x.view(-1, 28 * 28))

 

        return x


 

net = Linear_cliassifer()

cost = torch.nn.CrossEntropyLoss()



 

optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

 

for k in range(epoch):

    sum_loss = 0.0

    train_correct = 0

    for i, data in enumerate(train_loader, 0):

        inputs, labels = data

        optimizer.zero_grad()

        outputs = net(inputs)

 

        loss = cost(outputs, labels)

        loss.backward()

        optimizer.step()

 

        print(loss)

        _, id = torch.max(outputs.data, 1) 

        sum_loss += loss.data

        train_correct += torch.sum(id == labels.data)

        #print('[%d,%d] loss:%.03f' % (k + 1, k, sum_loss / len(train_loader)))

    print('        correct:%.03f%%' % (100 * train_correct / len(train_dataset)))

    torch.save(net.state_dict(), 'model/fasion_BL.pt')

本站文章版权归原作者及原出处所有 。内容为作者个人观点, 并不代表本站赞同其观点和对其真实性负责,本站只提供参考并不构成任何投资及应用建议。本站是一个个人学习交流的平台,网站上部分文章为转载,并不用于任何商业目的,我们已经尽可能的对作者和来源进行了通告,但是能力有限或疏忽,造成漏登,请及时联系我们,我们将根据著作权人的要求,立即更正或者删除有关内容。本站拥有对此声明的最终解释权。

  • 项目经理 点击这里给我发消息
  • 项目经理 点击这里给我发消息
  • 项目经理 点击这里给我发消息
  • 项目经理 点击这里给我发消息