Autoencoder with Chainer (usage notes + trainer)


Last time I tried CIFAR-10 image classification with trainer, a new feature in Chainer. My machine was not powerful enough to confirm that it actually worked, and that was where it ended. So this time, I will check how to use trainer by building an Autoencoder on MNIST.

For the Autoencoder itself, I referred to these articles:

- [Deep learning] Try Autoencoder with Chainer and visualize the result
- Try making a Deep Autoencoder with Chainer


I build a network that takes 1000 MNIST handwritten digit images as input, passes them through one hidden layer, and learns to produce an output equal to the input. The full code is available at https://github.com/trtd56/Autoencoder.

Network part

The hidden layer has 64 units. Calling the model with hidden=True returns the hidden-layer activations instead of the reconstruction.

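import chainer
import chainer.functions as F
import chainer.links as L


class Autoencoder(chainer.Chain):
    def __init__(self):
        super(Autoencoder, self).__init__(
                encoder=L.Linear(784, 64),   # 784 pixels -> 64 hidden units
                decoder=L.Linear(64, 784))   # 64 hidden units -> 784 pixels

    def __call__(self, x, hidden=False):
        h = F.relu(self.encoder(x))
        if hidden:
            return h  # hidden-layer activations only
        return F.relu(self.decoder(h))  # reconstruction of the input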
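The following minimal NumPy sketch makes the shapes concrete. It runs the same 784-64-784 forward pass, including the hidden=True shortcut. It is not Chainer code. The weight names (W_enc, b_enc, W_dec, b_dec), the random initialization, and the random input batch are assumptions for illustration. The sketch only mirrors how L.Linear stores its weight with shape (out_size, in_size) and computes x @ W.T + b.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical parameters mirroring L.Linear(784, 64) and L.Linear(64, 784):
# the weight has shape (out_size, in_size) and the layer computes x @ W.T + b.
W_enc = rng.normal(0.0, 0.01, size=(64, 784)).astype(np.float32)
b_enc = np.zeros(64, dtype=np.float32)
W_dec = rng.normal(0.0, 0.01, size=(784, 64)).astype(np.float32)
b_dec = np.zeros(784, dtype=np.float32)


def relu(a):
    return np.maximum(a, 0.0)


def autoencoder_forward(x, hidden=False):
    """Same control flow as Autoencoder.__call__, written with NumPy."""
    h = relu(x @ W_enc.T + b_enc)        # (batch, 64) hidden code
    if hidden:
        return h
    return relu(h @ W_dec.T + b_dec)     # (batch, 784) reconstruction


# Assumption: a random mini-batch stands in for 100 MNIST images.
x = rng.random((100, 784), dtype=np.float32)
print(autoencoder_forward(x).shape)               # (100, 784)
print(autoencoder_forward(x, hidden=True).shape)  # (100, 64)
```

The decoder output also passes through ReLU, so reconstructions are never negative. That suits MNIST pixels, which get_mnist scales to [0, 1].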

Data creation part

Load the MNIST data and create the training (teacher) data and the test data. The training data needs no labels, and the target output is the same as the input. So I reshape the data a little: each training example becomes an (image, image) pair.

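from chainer.datasets import tuple_dataset

# Read MNIST data (each image is a float32 vector of length 784, scaled to [0, 1])
train, test = chainer.datasets.get_mnist()

# Teacher data: keep only the images of the first 1000 examples,
# and use each image as its own target
train = train[0:1000]
train = [i[0] for i in train]
train = tuple_dataset.TupleDataset(train, train)
train_iter = chainer.iterators.SerialIterator(train, 100)  # mini-batches of 100

# Test data: the first 25 (image, label) pairs; labels are kept for plotting
test = test[0:25]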
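TupleDataset(train, train) zips the two lists, so example i is the pair (image_i, image_i). SerialIterator then hands these pairs out in mini-batches of 100. Below is a rough pure-Python sketch of that behaviour. The helper names make_pairs and iterate_batches are hypothetical. The real SerialIterator also repeats over epochs and tracks epoch boundaries for trainer.

```python
import random


def make_pairs(images):
    """Mimic TupleDataset(images, images): each example is (input, target)."""
    return list(zip(images, images))


def iterate_batches(dataset, batch_size, shuffle=True, seed=0):
    """One epoch of mini-batches, roughly what SerialIterator yields per epoch."""
    order = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [dataset[i] for i in order[start:start + batch_size]]


# Assumption: 1000 dummy "images" (plain integers) stand in for the MNIST vectors.
pairs = make_pairs(list(range(1000)))
batches = list(iterate_batches(pairs, 100))
print(len(batches))                                        # 10
print(all(x == t for batch in batches for x, t in batch))  # True
```

With 1000 examples and a batch size of 100, one epoch therefore amounts to 10 parameter updates.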


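model = L.Classifier(Autoencoder(), lossfun=F.mean_squared_error)
model.compute_accuracy = False  # no class labels, so skip accuracy
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)  # register the model's parameters with the optimizer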

There are two points to note here.

  1. Loss function: when a model is defined with L.Classifier, the loss function defaults to softmax_cross_entropy. This time I want mean_squared_error, so I pass it with lossfun.

  2. No accuracy calculation: the teacher data has no class labels, so accuracy means nothing here. Set compute_accuracy to False. Otherwise Classifier tries to compute accuracy (F.accuracy by default), which expects integer class labels rather than images as targets.
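F.mean_squared_error averages the squared difference over every element of the batch. The loss therefore compares each reconstructed pixel with the corresponding input pixel. Here is a small pure-Python sketch of the same formula, applied to two made-up 4-pixel "images". It is not Chainer's implementation.

```python
def mean_squared_error(pred, target):
    """Mean of squared differences over all elements (batch x pixels)."""
    diffs = [p - t
             for pred_row, target_row in zip(pred, target)
             for p, t in zip(pred_row, target_row)]
    return sum(d * d for d in diffs) / len(diffs)


# Batch of 2 "images" with 4 pixels each (made-up numbers).
target = [[0.0, 1.0, 1.0, 0.0],
          [1.0, 0.0, 0.0, 1.0]]
pred = [[0.0, 0.5, 1.0, 0.0],
        [1.0, 0.0, 0.5, 1.0]]
print(mean_squared_error(pred, target))  # 0.0625
```

softmax_cross_entropy expects integer class labels. This loss instead takes a float target with the same shape as the output, which is why it fits an autoencoder.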

Learning part

This part needs little explanation. Since trainer became available, this part has become much easier to write, which helps me a lot ^^

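from chainer import training
from chainer.training import extensions

N_EPOCH = 100  # number of epochs; the results below use 1, 5, 10, 20 and 100

updater = training.StandardUpdater(train_iter, optimizer, device=-1)  # device=-1: CPU
trainer = training.Trainer(updater, (N_EPOCH, 'epoch'), out="result")
trainer.extend(extensions.LogReport())  # PrintReport reads its values from LogReport
trainer.extend(extensions.PrintReport(['epoch', 'main/loss']))
trainer.run()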
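Roughly speaking, trainer repeats one loop for N_EPOCH epochs: fetch a mini-batch, compute the loss, backpropagate, let the optimizer update the parameters, and report main/loss. The sketch below writes that loop by hand in NumPy for the same 784-64-784 ReLU autoencoder, trained with mean squared error and a standard Adam update. It is only an illustration of the loop. It rests on three assumptions: random data replaces MNIST, the gradients are derived by hand instead of by Chainer's automatic differentiation, and the parameter names and epoch count are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumption: random "images" in [0, 1) stand in for the 1000 MNIST training vectors.
X = rng.random((1000, 784))
N_EPOCH, BATCH = 10, 100

# Hypothetical parameters of a 784-64-784 ReLU autoencoder (x @ W.T + b, as in L.Linear).
params = {
    "W1": rng.normal(0.0, 0.01, (64, 784)), "b1": np.zeros(64),
    "W2": rng.normal(0.0, 0.01, (784, 64)), "b2": np.zeros(784),
}
# Adam state; these hyperparameter values match Chainer's Adam defaults.
alpha, beta1, beta2, eps = 0.001, 0.9, 0.999, 1e-8
m = {k: np.zeros_like(p) for k, p in params.items()}
s = {k: np.zeros_like(p) for k, p in params.items()}


def loss_and_grads(x):
    """Forward pass, mean squared error against the input, hand-derived gradients."""
    a1 = x @ params["W1"].T + params["b1"]
    h = np.maximum(a1, 0.0)                 # hidden layer (ReLU)
    a2 = h @ params["W2"].T + params["b2"]
    y = np.maximum(a2, 0.0)                 # reconstruction (ReLU)
    loss = np.mean((y - x) ** 2)            # the target is the input itself
    dy = 2.0 * (y - x) / y.size             # d loss / d y
    da2 = dy * (a2 > 0)                     # back through the output ReLU
    dh = da2 @ params["W2"]
    da1 = dh * (a1 > 0)                     # back through the hidden ReLU
    grads = {"W2": da2.T @ h, "b2": da2.sum(axis=0),
             "W1": da1.T @ x, "b1": da1.sum(axis=0)}
    return loss, grads


history = []  # mean loss per epoch, similar to 'main/loss' in PrintReport
t = 0
for epoch in range(1, N_EPOCH + 1):
    order = rng.permutation(len(X))         # reshuffle every epoch
    losses = []
    for start in range(0, len(X), BATCH):
        loss, grads = loss_and_grads(X[order[start:start + BATCH]])
        t += 1
        for k in params:                    # Adam update for each parameter
            m[k] = beta1 * m[k] + (1 - beta1) * grads[k]
            s[k] = beta2 * s[k] + (1 - beta2) * grads[k] ** 2
            m_hat = m[k] / (1 - beta1 ** t)
            s_hat = s[k] / (1 - beta2 ** t)
            params[k] -= alpha * m_hat / (np.sqrt(s_hat) + eps)
        losses.append(loss)
    history.append(float(np.mean(losses)))
    print(epoch, history[-1])
```

On random data the network can only learn coarse statistics such as the average pixel value, so the point here is the loop structure, not the reconstructions. With trainer, all of this bookkeeping is handled by the updater, the optimizer, and the report extensions.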


Check the result

I wrote a function that plots the results with matplotlib. The original label is shown in red above each image. I didn't adjust the subplot spacing, so some titles overlap the images.

By the way, if you pass the test data to this function as is, it plots the original images.

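import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np


def plot_mnist_data(samples):
    for index, (data, label) in enumerate(samples):
        plt.subplot(5, 5, index + 1)  # 5x5 grid for the 25 test images
        plt.imshow(data.reshape(28, 28), cmap=cm.gray_r, interpolation='nearest')
        n = int(label)
        plt.title(n, color='red')  # original label in red above the image

pred_list = []
for (data, label) in test:
    # Reconstruct one image at a time; the predictor expects a batch, hence [data]
    pred_data = model.predictor(np.array([data]).astype(np.float32)).data
    pred_list.append((pred_data, label))

plot_mnist_data(pred_list)
plt.show()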
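The prediction loop above calls the predictor once per test image. The predictor works on batches, so the 25 images could also be reconstructed in a single call by stacking them into a (25, 784) array. The following NumPy sketch shows that alternative and checks that both approaches give the same outputs. It uses a hypothetical fake_predictor and random stand-in test data in place of model.predictor and MNIST.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(0.0, 0.01, (784, 784))


def fake_predictor(x):
    """Hypothetical stand-in for model.predictor: maps a (batch, 784) array row by row."""
    return np.maximum(x @ W.T, 0.0)


# Assumption: random (image, label) pairs shaped like get_mnist() examples.
test = [(rng.random(784, dtype=np.float32), np.int32(i % 10)) for i in range(25)]

# One predictor call per image, as in the loop; each result has shape (1, 784).
one_by_one = [fake_predictor(np.array([data]))[0] for data, _ in test]

# A single predictor call on the stacked (25, 784) batch.
batch = np.stack([data for data, _ in test])
batched = fake_predictor(batch)
pred_list = [(pred, label) for pred, (_, label) in zip(batched, test)]

print(batch.shape, np.allclose(np.stack(one_by_one), batched))  # (25, 784) True
```

Each row of the batched result still has 784 values, so plot_mnist_data can reshape it to 28 x 28 exactly as before.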


Let's see how the output changes as the number of epochs increases.

Original image


The 16 images include every digit from 0 to 9. Let's follow how these 16 change.

epoch = 1


It looks like static on a TV screen; at this point you can't tell what anything is.

epoch = 5


Something like digits has finally appeared, but they are still unrecognizable.

epoch = 10


The shapes of 0, 1, 3, and others are gradually becoming visible. The 6 in the second row is still smeared, and I can't make it out.

epoch = 20


The digits are almost recognizable now.

epoch = 100


I jumped straight to 100 epochs. The 6 in the second row, which was almost smeared out before, is now visible. Training for more epochs would make the digits clearer, but I'll stop here this time.

In conclusion

It was fun to watch the network come to recognize numbers as numbers. ~~trainer is convenient, but be careful because various parts such as the loss function are determined automatically.~~ (Corrected on 2016.08.10) It was the Classifier spec, not trainer, that made the loss function default to softmax_cross_entropy. The updater used by trainer can be given a loss function explicitly (the loss_func argument of StandardUpdater). By default, though, it calls the model registered with the optimizer, here the Classifier, so the Classifier's lossfun is what gets used.
