[Python] Introduction to CNN with PyTorch MNIST

Introduction

I looked at various sites for implementing MNIST training with a CNN in PyTorch, but I could not find one that explained CNN + MNIST together, and I thought that would make things difficult for a first-timer, so I decided to write this article. In it, I will focus on how to handle PyTorch and MNIST and how to implement a CNN.

See this well-known article for a theoretical explanation of CNNs: https://qiita.com/icoxfog417/items/5fd55fad152231d706c2

This article assumes that you have read the article above or otherwise have some knowledge of CNNs.

Target person

・Those who understand Python
・Those who understand the theory but wonder how to actually implement it
・Those who learned the outline in a university class but do not know how to implement it in practice

Environment

Python 3.8.8
PyTorch 1.7.0

What to import

import.py


import torch
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

Explanation of prerequisite knowledge

What is MNIST

A research dataset for handwritten-digit recognition; the name is short for the Modified National Institute of Standards and Technology database. Each sample is a 28×28-pixel, 256-level grayscale image of a handwritten digit from 0 to 9, and each image carries one of ten correct-answer labels (0 through 9). 60,000 such images are available for training and 10,000 for testing. Because everyone uses the same dataset, the relative merits of algorithms are easy to compare, and because the dataset is compact, training time is short; it has therefore become established as a standard dataset for image recognition. (From Kotobank)

Simply put, it is a **collection of data prepared for research and learning**.

What is torchvision

There are many ways to obtain MNIST, but you can download it easily with PyTorch's torchvision library. torchvision offers a variety of other datasets and useful modules as well, so it is worth learning; at first it is fine to think of it as **a library that provides data**. Since we wrote from torchvision.datasets import MNIST at import time, the dataset can be created as follows. root selects the destination path for the dataset. Setting train to True selects the training data; conversely, setting it to False selects the test data. If download is set to True the data will be downloaded; if False, it will not be. If you are studying machine learning you will use this dataset many times, so downloading it is recommended. transform=transforms.ToTensor() converts each image into a PyTorch tensor as it is read.

torchvision.py


data = MNIST(root="./MNIST", train=True, download=True, transform=transforms.ToTensor())
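
As a quick sanity check (a minimal sketch reusing the data object above; the values in the comments assume the standard MNIST download), you can confirm the dataset size and the shape of one sample:

check_dataset.py


print(len(data))    # 60000 training images
image, label = data[0]
print(image.shape)  # torch.Size([1, 28, 28]): 1 channel, 28x28 pixels
print(label)        # the correct digit for this image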

What is DataLoader?

Next, DataLoader. DataLoader is a utility in torch.utils.data that easily splits a dataset into mini-batches. The split here is not the train/test split used for evaluation; it is for **mini-batch learning**. To put it simply, this is a method of **feeding the training data to the network in small chunks instead of all at once**. There are several reasons for splitting the data, but for now it is enough to know that computing over the whole dataset at once is expensive when there is a lot of data. Let's see how DataLoader is used below. The first argument, data, selects which dataset to split. batch_size is 64 this time, which means **each mini-batch contains 64 samples** (not that the data is split into 64 pieces). If shuffle is set to True the data is drawn in random order; if False, in order from the top of the dataset.

dataloader.py


data_loader = DataLoader(data, batch_size=64, shuffle=True)
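
To see what DataLoader actually yields, here is a minimal sketch that pulls a single mini-batch from the data_loader above and checks its shape:

check_batch.py


images, labels = next(iter(data_loader))
print(images.shape)  # torch.Size([64, 1, 28, 28]): 64 images per mini-batch
print(labels.shape)  # torch.Size([64]): one correct label per image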

Data preparation

I downloaded MNIST training data and test data by the method introduced earlier, and set batch_size to **16** for each loader.

dataload.py


train_data = MNIST(root="./MNIST", train=True, download=True, transform=transforms.ToTensor())
train_data_loader = DataLoader(train_data, batch_size=16, shuffle=True)
test_data = MNIST(root="./MNIST", train=False, download=True, transform=transforms.ToTensor())
test_data_loader = DataLoader(test_data, batch_size=16, shuffle=False)
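
Before training, it can help to peek at a sample. Here is a small sketch that displays the first training image using the matplotlib imported at the top:

peek.py


image, label = train_data[0]
plt.imshow(image.numpy().reshape(28, 28), cmap="gray")  # drop the channel dimension for display
plt.title("label: {0}".format(label))
plt.show()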

Building a neural network

Finally, we get to the real fun: building the network. Before assembling the whole thing, let's work out the parameters of the convolutional part with some simple arithmetic. Since the input data is MNIST, a **28x28** image is fed in. With a slide width (stride) of 1, (image size) - (filter size) + (slide width) = (image size after convolution), so convolving with a filter size of 5 gives 28 - 5 + 1 = 24, and the **28x28** input is convolved down to **24x24**. Passing this through a **2x2** pooling layer halves the overall image size, giving **12x12**. A further convolution with a filter size of 3 gives 12 - 3 + 1 = 10, an image size of **10x10**, and one final pass through a **2x2** pooling layer leaves a final image size of **5x5**. All that remains is to connect the flattened output, (number of channels) x (pixels per channel) = 256 x 5 x 5 = 6400 values, to the fully connected layers. Please refer to the simple diagram of this series of steps below. (Only the layers are drawn, without activation functions etc.)

[Figure: CNN_ex1.png, a diagram of the layer sizes]
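
You can verify this arithmetic by pushing a dummy batch through layers of the same sizes (a minimal sketch reusing the torch, nn, and F imports from the top; the freshly created layers have random weights, but only the shapes matter here):

shape_check.py


x = torch.zeros(1, 1, 28, 28)  # dummy batch: 1 image, 1 channel, 28x28
x = F.max_pool2d(F.relu(nn.Conv2d(1, 128, 5)(x)), (2, 2))
print(x.shape)  # torch.Size([1, 128, 12, 12]): 28 -> 24 -> 12
x = F.max_pool2d(F.relu(nn.Conv2d(128, 256, 3)(x)), (2, 2))
print(x.shape)  # torch.Size([1, 256, 5, 5]): 12 -> 10 -> 5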

__init__

Conv2d takes the arguments Conv2d(number of input channels, number of output channels, filter size). It is fine to think of the input/output channels here as similar to the units of an ordinary intermediate layer (see the official documentation for the exact definition). In other words, self.conv1 means "128 outputs for 1 input". As mentioned in the article linked at the beginning, the filter takes a small area on the image and compresses (= convolves) it into one feature value. For self.conv1, **5** is passed as the filter argument, so the filter (the small area) is a 5x5 square.

Linear is the same as in an ordinary NN (neural network); only the number of input and output units is specified.

Forward

forward describes the flow from the input layer to the output layer, just as in an ordinary NN (neural network). The max_pool2d that appears here pools the input (something like compression). Both pooling layers are set to 2x2 this time, so each halves its input. Writing F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) therefore produces an x that has been convolved and then further compressed. The activation function relu sits in between; I won't explain it here, but if you don't know it, it is fine to think of it as processing required between layers. Finally, since this is multi-class classification (three or more kinds of correct labels), the log_softmax function is applied at the end. Note that because forward already ends in log_softmax, we will pair the network with nn.NLLLoss later (log_softmax + NLLLoss together compute the cross entropy).
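
As a small illustration of log_softmax (a sketch with made-up scores), exponentiating its output recovers probabilities that sum to 1, and the largest entry marks the predicted class:

softmax_demo.py


scores = torch.tensor([[1.0, 3.0, 0.5]])  # made-up scores for 3 classes
log_probs = F.log_softmax(scores, dim=1)
print(torch.exp(log_probs))            # probabilities; each row sums to 1
print(torch.argmax(log_probs, dim=1))  # predicted class -> tensor([1])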

network.py


import torch.nn as nn
import torch.nn.functional as F
import torch

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 128, 5)    # 1 input channel -> 128 output channels, 5x5 filter
        self.conv2 = nn.Conv2d(128, 256, 3)  # 128 -> 256 channels, 3x3 filter

        self.fc1 = nn.Linear(256 * 5 * 5, 150)  # flattened 256 channels of 5x5 -> 150 units
        self.fc2 = nn.Linear(150, 120)
        self.fc3 = nn.Linear(120, 10)            # 10 outputs, one per digit

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))  # 28x28 -> 24x24 -> 12x12
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))  # 12x12 -> 10x10 -> 5x5
        x = x.view(-1, 256 * 5 * 5)                       # flatten for the fully connected layers
        x = self.fc1(x)
        x = self.fc2(x)
        x = F.log_softmax(self.fc3(x), dim=1)             # log-probabilities for the 10 classes
        return x
    
net = Net()
print(net)

-Execution result-

When executed, an outline of the created network is printed as shown below. Check that the items you specified are correct; some items are set by default, so you can ignore any unfamiliar words. kernel_size represents the filter size.

Net(
  (conv1): Conv2d(1, 128, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=6400, out_features=150, bias=True)
  (fc2): Linear(in_features=150, out_features=120, bias=True)
  (fc3): Linear(in_features=120, out_features=10, bias=True)
)
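
If you are curious how large this network is, a one-line sketch counts its trainable parameters:

param_count.py


print(sum(p.numel() for p in net.parameters()))  # total number of trainable weights and biases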

Definition of optimization and loss functions

This time we use cross entropy for the loss function and SGD for optimization. One caveat: because the network's forward already applies log_softmax, we pair it with nn.NLLLoss rather than nn.CrossEntropyLoss (log_softmax followed by NLLLoss computes exactly the cross entropy, while nn.CrossEntropyLoss applies log_softmax internally and would therefore apply it twice).

optimizer.py


import torch.optim as optim

criterion = nn.NLLLoss()  # log_softmax (in forward) + NLLLoss = cross entropy
optimizer = optim.SGD(net.parameters(), lr=0.01)
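
To see why this pairing works, here is a small sketch with made-up values: NLLLoss simply picks out the log-probability assigned to the correct class for each sample and averages the negated values.

loss_demo.py


log_probs = F.log_softmax(torch.randn(4, 10), dim=1)  # fake network output for a batch of 4
targets = torch.tensor([3, 7, 0, 1])                  # fake correct labels
print(criterion(log_probs, targets))                  # mean of -log_probs[i, targets[i]]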

Execution of learning

This time the number of epochs is set to 3 in the code below, but you can change it to whatever value works best. When computing on a GPU with CUDA etc., training time drops dramatically, which makes much larger batch sizes and epoch counts practical (for example, batch size 256 and 150 epochs; see the GPU sketch after the results below). The loop itself is the same as for an ordinary NN (neural network), so I will omit the explanation.

learn.py


epochs = 3
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(train_data_loader):
        train_data, teacher_labels = data         # unpack images and correct labels
        optimizer.zero_grad()                     # clear gradients from the previous step
        outputs = net(train_data)
        loss = criterion(outputs, teacher_labels)
        loss.backward()                           # backpropagate
        optimizer.step()                          # update the weights
        running_loss += loss.item()

        if i % 2000 == 1999:                      # report the average loss every 2000 batches
            print("Learning progress: [{0},{1}]Learning loss: {2}".format(epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
print("End of learning")

-Execution result-

Learning progress: [1,2000]Learning loss: 0.595951408088862
Learning progress: [1,4000]Learning loss: 0.1649219517123347
Learning progress: [1,6000]Learning loss: 0.11794542767963322
Learning progress: [1,8000]Learning loss: 0.10531693481299771
Learning progress: [1,10000]Learning loss: 0.09142379220648672
Learning progress: [1,12000]Learning loss: 0.07094955143502511
Learning progress: [1,14000]Learning loss: 0.07791419489397276
Learning progress: [2,2000]Learning loss: 0.053966304615655415
Learning progress: [2,4000]Learning loss: 0.06185422517460961
Learning progress: [2,6000]Learning loss: 0.0553153860775642
Learning progress: [2,8000]Learning loss: 0.048434725125255054
Learning progress: [2,10000]Learning loss: 0.050712744873740806
Learning progress: [2,12000]Learning loss: 0.04531467049338937
Learning progress: [2,14000]Learning loss: 0.04466106877903383
Learning progress: [3,2000]Learning loss: 0.032806222373120246
Learning progress: [3,4000]Learning loss: 0.03608645232143033
Learning progress: [3,6000]Learning loss: 0.03568348353291483
Learning progress: [3,8000]Learning loss: 0.04021910582008564
Learning progress: [3,10000]Learning loss: 0.038381989276232556
Learning progress: [3,12000]Learning loss: 0.038247817724080646
Learning progress: [3,14000]Learning loss: 0.03719676028518558
End of learning
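
As mentioned above, a GPU can shorten training considerably. Here is a minimal sketch of the same loop adapted for CUDA (assuming a CUDA-capable machine; everything else is unchanged):

learn_gpu.py


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)  # move the model parameters to the GPU

for epoch in range(epochs):
    for i, data in enumerate(train_data_loader):
        train_data, teacher_labels = data[0].to(device), data[1].to(device)  # move each batch
        optimizer.zero_grad()
        outputs = net(train_data)
        loss = criterion(outputs, teacher_labels)
        loss.backward()
        optimizer.step()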

Evaluation with test data

At the beginning we split the data into train_data and test_data, so let's now measure accuracy on the held-out test data. count is the number of predictions from the trained model that match the correct labels, and total is the total number of test labels. Dividing count by total gives the correct answer rate.

evaluation.py


count = 0
total = 0

net.eval()  # switch the network to evaluation mode
with torch.no_grad():  # gradients are not needed for evaluation
    for data in test_data_loader:
        test_data, teacher_labels = data[0], data[1]
        results = net(test_data)
        _, predicted = torch.max(results.data, 1)  # index of the largest output = predicted digit
        count += (predicted == teacher_labels).sum().item()
        total += teacher_labels.size(0)

print("Correct answer rate: {0}/{1} -> {2}".format(count, total, (int(count) / int(total)) * 100))

-Execution result-

Since we are using a CNN, the accuracy is as high as 99%.

Correct answer rate: 9906/10000 -> 99.06
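
If you want a finer view than a single number, here is a sketch that tallies the accuracy for each digit separately (an optional extension reusing the trained net and test_data_loader above):

per_class.py


correct = [0] * 10
counts = [0] * 10

with torch.no_grad():
    for images, teacher_labels in test_data_loader:
        _, predicted = torch.max(net(images), 1)
        for label, pred in zip(teacher_labels, predicted):
            counts[int(label)] += 1           # total samples of this digit
            correct[int(label)] += int(label == pred)  # correctly predicted samples

for digit in range(10):
    print("{0}: {1}/{2}".format(digit, correct[digit], counts[digit]))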

Checking predictions on the test data

It is hard to get a real feel for the results from an accuracy number alone, so let's visualize some predictions. test_data_loader is an iterable PyTorch object, so we first create an iterator from it and take one batch, splitting it into inputs and labels. (Of course, looping over it with a for statement is also fine.) After that, we obtain the predicted labels in the same way as in the evaluation above. Since MNIST images are **28x28** handwritten digits, we convert a test image to a 28x28 NumPy array with reshape; PyTorch's numpy() method performs the tensor-to-NumPy conversion, so no separate NumPy import is needed here.

result.py


test_itr = iter(test_data_loader)
test_data, labels = next(test_itr)  # take one batch of images and labels
results = net(test_data)

_, predicted = torch.max(results.data, 1)  # predicted digit for each image in the batch

plt.imshow(test_data[0].numpy().reshape(28, 28), cmap="inferno", interpolation="bicubic")
print("label: {0}".format(predicted[0]))

-Execution result-

label: 7

[Image: the plotted test digit, predicted as 7]
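
If you would like to see several predictions at once, here is a small sketch that plots the first eight images of the batch with their predicted labels (reusing test_data and predicted from above):

result_grid.py


fig, axes = plt.subplots(1, 8, figsize=(12, 2))
for idx, ax in enumerate(axes):
    ax.imshow(test_data[idx].numpy().reshape(28, 28), cmap="inferno")
    ax.set_title("pred: {0}".format(int(predicted[idx])))
    ax.axis("off")
plt.show()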

In conclusion

Thank you for reading this far. CNNs are a basic category of neural network, and once you understand them they are easy to apply, so a deep understanding is recommended. This time we trained on a prepared dataset, MNIST, but it can be interesting to train on data you have prepared yourself. In that case there is work to do before training, such as adjusting the size of the data and doing some preprocessing, but using your own data will deepen your understanding considerably, so I recommend trying it if you are interested. If there are any mistakes in the article, I would appreciate it if you let me know; I will correct them.
