How to use a GAN to color Pokemon

In the previous demo, we used a conditional GAN to generate images of handwritten digits. What else can neural networks do besides generating digit images?

In this demo, we use a neural network to color Pokemon line drawings (wireframes).

Step 1: import the libraries

from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf

import numpy as np
import pandas as pd

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

Training the Pokemon coloring model needs a large amount of GPU memory. To keep the model running smoothly on an RTX 2070, we cap its GPU memory usage at 90% to avoid out-of-memory errors.

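# Cap this process at 90% of the GPU memory.
# Note: this is a TF1-style (tf.compat.v1) session setting; under TF2 eager
# execution, GPU memory is usually configured through tf.config instead.
config = tf.compat.v1.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.9
session = tf.compat.v1.Session(config=config)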

Define the constants to use.

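PATH = 'dataset/'       # expects dataset/train/*.jpg and dataset/test/*.jpg
LAMBDA = 100            # weight of the L1 term in the generator loss
IMG_WIDTH = 256         # images are cropped to 256x256 for training
IMG_HEIGHT = 256
OUTPUT_CHANNELS = 3     # RGB output
BUFFER_SIZE = 400       # assumed shuffle buffer size; not given in the original post
EPOCHS = 150            # assumed number of epochs; not given in the original post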

Step 2: define helper functions

The image loading function uses TensorFlow's io API to read an image file, decodes it, and splits it into an input tensor and a target tensor for later use.

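def load(image_file):
    # Read the raw bytes and decode the JPEG into a uint8 tensor of shape (H, W, 3)
    image = tf.io.read_file(image_file)
    image = tf.image.decode_jpeg(image)

    # Each file holds two images side by side: split it down the middle
    w = tf.shape(image)[1]

    w = w // 2
    input_image = image[:, :w, :]   # left half: line drawing (input)
    real_image = image[:, w:, :]    # right half: colored image (target)

    input_image = tf.cast(input_image, tf.float32)
    real_image = tf.cast(real_image, tf.float32)

    return input_image, real_image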
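The split in load is plain array slicing along the width axis. A minimal NumPy sketch of the same idea on a tiny synthetic image (the helper name split_pair is ours, not part of the original code):

```python
import numpy as np

def split_pair(pair):
    """Split an (H, W, C) side-by-side image into its left and right halves.

    Same slicing as `load`: left half = line drawing (input),
    right half = colored image (target).
    """
    w = pair.shape[1] // 2
    input_image = pair[:, :w, :].astype(np.float32)
    real_image = pair[:, w:, :].astype(np.float32)
    return input_image, real_image

# A tiny synthetic 2x4 RGB "pair": left half filled with 0, right half with 255
pair = np.zeros((2, 4, 3), dtype=np.uint8)
pair[:, 2:, :] = 255
inp, tar = split_pair(pair)
print(inp.shape, tar.shape)
```

If the width is odd, the right half gets the extra column, exactly as with w // 2 in load.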

Converting a tensor object into a NumPy array

During training, we visualize some results and intermediate images. Rather than passing TensorFlow tensors to Matplotlib directly, we convert them to NumPy arrays with a small helper.

def tensor_to_array(tensor1):
    return tensor1.numpy()

Step 3: Data Visualization

Let's first look at what the training data looks like.
Each image has two halves: the left half is the line drawing, which we use as input, and the right half is the colored image, which we use as the training target.
We load one image with the load function defined above:

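input, real = load(PATH+'train/114.jpg')

# Pixel values are floats in [0, 255]; divide by 255 so imshow treats them as [0, 1]
plt.figure()
plt.imshow(tensor_to_array(input) / 255.0)
plt.figure()
plt.imshow(tensor_to_array(real) / 255.0)
plt.show()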
Step 4: data augmentation

Because we don't have much training data, we use data augmentation to create more samples, which helps a small dataset achieve better results.

We apply the following augmentations:

  1. Resizing: scale the input image to a specified size
  2. Random cropping
  3. Normalization
  4. Random horizontal flipping

def resize(input_image, real_image, height, width):
    # Resize both images to (height, width) with nearest-neighbor sampling
    input_image = tf.image.resize(input_image, [height, width], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    real_image = tf.image.resize(real_image, [height, width], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    return input_image, real_image

def random_crop(input_image, real_image):
    # Stack the two images so one random crop window is applied to both,
    # which keeps the input and the target aligned
    stacked_image = tf.stack([input_image, real_image], axis=0)
    cropped_image = tf.image.random_crop(stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])

    return cropped_image[0], cropped_image[1]

def normalize(input_image, real_image):
    # Scale pixel values from [0, 255] to [-1, 1], the range of the generator's tanh output
    input_image = (input_image / 127.5) - 1
    real_image = (real_image / 127.5) - 1

    return input_image, real_image
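The normalization step rescales pixel values from [0, 255] to [-1, 1], the range of the generator's tanh output; generate_images later maps predictions back to [0, 1] with * 0.5 + 0.5. A minimal round-trip sketch, assuming the usual pix2pix formula x / 127.5 - 1 (the helper names are hypothetical):

```python
def normalize_pixel(v):
    # Assumed formula: [0, 255] -> [-1, 1]
    return v / 127.5 - 1

def to_display(x):
    # [-1, 1] -> [0, 1], the inverse applied before plt.imshow in generate_images
    return x * 0.5 + 0.5

for v in (0, 127.5, 255):
    x = normalize_pixel(v)
    print(v, x, to_display(x))
```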

We combine these steps into a single function, random_jitter: it resizes both images to 286x286, randomly crops them back to 256x256, and flips them horizontally with probability 0.5.

def random_jitter(input_image, real_image):
    input_image, real_image = resize(input_image, real_image, 286, 286)
    input_image, real_image = random_crop(input_image, real_image)

    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)
        real_image = tf.image.flip_left_right(real_image)

    return input_image, real_image
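The key point of random_jitter is that the input and the target receive the same random resize, crop window, and flip, so they stay pixel-aligned. A NumPy sketch of that idea, assuming a simple index-sampling nearest-neighbor resize as a stand-in for tf.image.resize (the helpers nearest_resize and jitter_pair are hypothetical):

```python
import numpy as np

def nearest_resize(img, height, width):
    # Nearest-neighbor resize by index sampling (a simple stand-in for
    # tf.image.resize(..., method=NEAREST_NEIGHBOR); exact sampling may differ)
    rows = np.arange(height) * img.shape[0] // height
    cols = np.arange(width) * img.shape[1] // width
    return img[rows][:, cols]

def jitter_pair(inp, tar, rng, size=256, upscale=286):
    inp = nearest_resize(inp, upscale, upscale)
    tar = nearest_resize(tar, upscale, upscale)
    # Draw ONE crop offset and apply it to both images so they stay aligned
    top = rng.integers(0, upscale - size + 1)
    left = rng.integers(0, upscale - size + 1)
    inp = inp[top:top + size, left:left + size]
    tar = tar[top:top + size, left:left + size]
    # Flip both images together with probability 0.5
    if rng.random() > 0.5:
        inp, tar = inp[:, ::-1], tar[:, ::-1]
    return inp, tar

rng = np.random.default_rng(0)
pair_in = rng.random((256, 256, 3)).astype(np.float32)
pair_tar = pair_in.copy()  # identical images make the alignment easy to check
a, b = jitter_pair(pair_in, pair_tar, rng)
print(a.shape, np.array_equal(a, b))
```

Stacking the two images before tf.image.random_crop, as random_crop does, achieves the same shared crop window in TensorFlow.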

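Let's see what the augmentation does to the sample image:

plt.figure(figsize=(6, 6))
for i in range(4):
    input_image, real_image = random_jitter(input, real)
    plt.subplot(2, 2, i+1)
    plt.imshow(tensor_to_array(input_image) / 255.0)
    plt.axis('off')
plt.show()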

Step 5: prepare the training data

Define the functions that load the training and test data. Only the training images are augmented; the test images are just resized and normalized.

def load_image_train(image_file):
    input_image, real_image = load(image_file)
    input_image, real_image = random_jitter(input_image, real_image)
    input_image, real_image = normalize(input_image, real_image)

    return input_image, real_image
def load_image_test(image_file):
    input_image, real_image = load(image_file)
    input_image, real_image = resize(input_image, real_image, IMG_HEIGHT, IMG_WIDTH)
    input_image, real_image = normalize(input_image, real_image)

    return input_image, real_image

Use TensorFlow's tf.data.Dataset API to build the training and test dataset objects:

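train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg')
train_dataset = train_dataset.map(load_image_train,
                                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
# Note: cache() after map() stores the first pass of augmented images,
# so the random jitter is not redrawn in later epochs
train_dataset = train_dataset.cache().shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(1)

test_dataset = tf.data.Dataset.list_files(PATH+'test/*.jpg')
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(1)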

Step 6: define the model

For Pokemon coloring we train a GAN. This model is more complex than the conditional GAN we used to generate handwritten digits.
Let's first look at the overall structure of the generator and the discriminator.

Generator network

The generator is based on the U-Net architecture. Each block in the encoder (downsampling) path is convolution > BN > LeakyReLU; the first encoder block skips BN. Each block in the decoder (upsampling) path is transposed convolution > BN > Dropout > ReLU, where Dropout is used only in the first three of the seven decoder blocks. The output of each encoder block is also concatenated with the output of the decoder block at the same resolution; for details, see the skip connections in U-Net.

Define the encoder block

def downsample(filters, size, apply_batchnorm=True):
    initializer = tf.random_normal_initializer(0., 0.02)

    result = tf.keras.Sequential()
    # Stride-2 convolution halves the spatial resolution
    result.add(tf.keras.layers.Conv2D(filters, size, strides=2, padding='same', kernel_initializer=initializer, use_bias=False))

    if apply_batchnorm:
        result.add(tf.keras.layers.BatchNormalization())

    result.add(tf.keras.layers.LeakyReLU())

    return result

down_model = downsample(3, 4)

Define the decoder block

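def upsample(filters, size, apply_dropout=False):
    initializer = tf.random_normal_initializer(0., 0.02)

    result = tf.keras.Sequential()
    # Stride-2 transposed convolution doubles the spatial resolution
    result.add(tf.keras.layers.Conv2DTranspose(filters, size, strides=2, padding='same', kernel_initializer=initializer, use_bias=False))

    result.add(tf.keras.layers.BatchNormalization())

    if apply_dropout:
        # Dropout rate 0.5, as in the pix2pix paper
        result.add(tf.keras.layers.Dropout(0.5))

    result.add(tf.keras.layers.ReLU())

    return result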

up_model = upsample(3, 4)

Define the generator model

def Generator():
    down_stack = [
        downsample(64, 4, apply_batchnorm=False), # (bs, 128, 128, 64)
        downsample(128, 4), # (bs, 64, 64, 128)
        downsample(256, 4), # (bs, 32, 32, 256)
        downsample(512, 4), # (bs, 16, 16, 512)
        downsample(512, 4), # (bs, 8, 8, 512)
        downsample(512, 4), # (bs, 4, 4, 512)
        downsample(512, 4), # (bs, 2, 2, 512)
        downsample(512, 4), # (bs, 1, 1, 512)
    ]

    up_stack = [
        upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
        upsample(512, 4), # (bs, 16, 16, 1024)
        upsample(256, 4), # (bs, 32, 32, 512)
        upsample(128, 4), # (bs, 64, 64, 256)
        upsample(64, 4), # (bs, 128, 128, 128)
    ]

    initializer = tf.random_normal_initializer(0., 0.02)
    last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                           strides=2,
                                           padding='same',
                                           kernel_initializer=initializer,
                                           activation='tanh') # (bs, 256, 256, 3)

    concat = tf.keras.layers.Concatenate()

    inputs = tf.keras.layers.Input(shape=[None,None,3])
    x = inputs

    # Encoder: downsample step by step, remembering each block's output
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # Drop the 1x1 bottleneck output and pair the rest with the decoder in reverse order
    skips = reversed(skips[:-1])

    # Decoder: upsample and concatenate the matching encoder output (skip connection)
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = concat([x, skip])

    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

generator = Generator()
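To see where the channel counts in the Generator comments come from, the sketch below traces spatial sizes and channels through the U-Net with plain arithmetic, assuming a 256x256 input. It copies the filter counts from down_stack and up_stack; unet_shapes is a hypothetical helper, not part of the model:

```python
def unet_shapes(size=256):
    # Filter counts copied from down_stack / up_stack in Generator()
    down_filters = [64, 128, 256, 512, 512, 512, 512, 512]
    up_filters = [512, 512, 512, 512, 256, 128, 64]

    # Encoder: each stride-2 'same' convolution halves the spatial size
    skips = []
    s = size
    for f in down_filters:
        s = (s + 1) // 2  # ceil(s / 2)
        skips.append((s, f))

    # Decoder: each stride-2 transposed convolution doubles the size, and
    # concatenating the mirrored encoder output adds that block's channels
    s, _ = skips[-1]
    decoder = []
    for f, (skip_s, skip_f) in zip(up_filters, reversed(skips[:-1])):
        s *= 2
        assert s == skip_s, "spatial sizes must match to concatenate"
        decoder.append((s, f + skip_f))

    # Final transposed convolution: back to the input size with 3 channels
    return skips, decoder, (s * 2, 3)

enc, dec, out = unet_shapes()
print(enc[-1], dec[0], dec[-1], out)
```

The decoder channel counts double because each upsampled output is concatenated with an encoder output that has the same number of filters at the same resolution.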

Discriminator network

We use PatchGAN, also known as a Markovian discriminator. Many traditional CNN-based classifiers end with a fully connected layer that outputs the classification result. PatchGAN is different: it is built entirely from convolutions, and its final output is an N x N matrix whose mean gives the overall real/fake decision. Intuitively, each element of the output matrix corresponds to a receptive field of the model on the original image, that is, a region of the input called a patch. This is why a GAN with this structure is called PatchGAN.

Each block in PatchGAN consists of a convolution layer > BN layer > LeakyReLU.

In our model, the final output has shape (batch size, 30, 30, 1), where 1 is the number of channels.

Each of the 30x30 outputs corresponds to a 70x70 patch of the input image. For the detailed structure, see the pix2pix paper (Isola et al., "Image-to-Image Translation with Conditional Adversarial Networks").
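The 70x70 figure follows from the kernel sizes and strides alone. A short sketch of the standard receptive-field recursion, assuming the layer list below matches the Discriminator code (three 4x4 stride-2 blocks, then two 4x4 stride-1 convolutions); zero padding shifts positions but does not change the receptive field size:

```python
def receptive_field(layers):
    """Receptive field of one output unit for a stack of conv layers.

    `layers` is a list of (kernel_size, stride) from input to output.
    Walking backwards from a single output unit, each layer maps a
    window of r units to r * stride + (kernel_size - stride) input units.
    """
    r = 1
    for k, s in reversed(layers):
        r = r * s + (k - s)
    return r

# Kernel sizes and strides of the PatchGAN discriminator
patchgan = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]
print(receptive_field(patchgan))
```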

def Discriminator():
    initializer = tf.random_normal_initializer(0., 0.02)

    inp = tf.keras.layers.Input(shape=[None, None, 3], name='input_image')
    tar = tf.keras.layers.Input(shape=[None, None, 3], name='target_image')

    # (batch size, 256, 256, channels*2)
    x = tf.keras.layers.concatenate([inp, tar])

    # (batch size, 128, 128, 64)
    down1 = downsample(64, 4, False)(x)
    # (batch size, 64, 64, 128)
    down2 = downsample(128, 4)(down1)
    # (batch size, 32, 32, 256)
    down3 = downsample(256, 4)(down2)

    # (batch size, 34, 34, 256)
    zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)
    # (batch size, 31, 31, 512)
    conv = tf.keras.layers.Conv2D(512, 4, strides=1, kernel_initializer=initializer, use_bias=False)(zero_pad1) 

    batchnorm1 = tf.keras.layers.BatchNormalization()(conv)

    leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)

    # (batch size, 33, 33, 512)
    zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)

    # (batch size, 30, 30, 1)
    last = tf.keras.layers.Conv2D(1, 4, strides=1, kernel_initializer=initializer)(zero_pad2)

    return tf.keras.Model(inputs=[inp, tar], outputs=last)

discriminator = Discriminator()
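Tracing the spatial size through the Discriminator shows where the 30x30 output comes from for a 256x256 input. This is plain arithmetic, assuming padding='same' for the downsample blocks and the Keras default padding='valid' for the two stride-1 convolutions:

```python
import math

def same_conv(n, stride):
    # padding='same': output size is ceil(n / stride)
    return math.ceil(n / stride)

def valid_conv(n, kernel, stride=1):
    # no padding: output size is floor((n - kernel) / stride) + 1
    return (n - kernel) // stride + 1

n = 256
sizes = [n]
for _ in range(3):        # down1..down3: stride-2 'same' convolutions
    n = same_conv(n, 2)
    sizes.append(n)
n += 2                    # ZeroPadding2D() adds 1 pixel on each side
sizes.append(n)
n = valid_conv(n, 4)      # Conv2D(512, 4, strides=1)
sizes.append(n)
n += 2                    # second ZeroPadding2D()
sizes.append(n)
n = valid_conv(n, 4)      # final Conv2D(1, 4, strides=1)
sizes.append(n)
print(sizes)
```

The printed list matches the shape comments in the Discriminator code.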

Step 7: define the loss functions and optimizers


# The discriminator outputs raw logits, hence from_logits=True
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def discriminator_loss(disc_real_output, disc_generated_output):
    # Real (input, target) pairs should be classified as 1, generated pairs as 0
    real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)

    total_disc_loss = real_loss + generated_loss

    return total_disc_loss

def generator_loss(disc_generated_output, gen_output, target):
    # GAN term: try to make the discriminator classify generated images as real
    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    # L1 term: mean absolute pixel difference between the output and the target
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

    total_gen_loss = gan_loss + (LAMBDA * l1_loss)

    return total_gen_loss

# Adam with learning rate 2e-4 and beta_1 = 0.5 for both networks
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
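The generator loss combines two terms: a GAN term that rewards fooling the discriminator, and an L1 term weighted by LAMBDA = 100 that keeps the colors close to the target. A pure-Python sketch of both losses, assuming the Keras losses reduce to a mean over every element of the 30x30 output map; the _sketch functions are hypothetical stand-ins, not the TensorFlow code above:

```python
import math

def bce_with_logits(labels, logits):
    # Numerically stable sigmoid cross-entropy, averaged over all elements
    total = 0.0
    for z, x in zip(labels, logits):
        total += max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))
    return total / len(logits)

def discriminator_loss_sketch(disc_real_output, disc_generated_output):
    # Real pairs should score 1, generated pairs 0
    real_loss = bce_with_logits([1.0] * len(disc_real_output), disc_real_output)
    generated_loss = bce_with_logits([0.0] * len(disc_generated_output), disc_generated_output)
    return real_loss + generated_loss

def generator_loss_sketch(disc_generated_output, gen_output, target, lam=100):
    # GAN term: the generator wants its outputs to be scored as real (1)
    gan_loss = bce_with_logits([1.0] * len(disc_generated_output), disc_generated_output)
    # L1 term: mean absolute pixel difference to the target
    l1_loss = sum(abs(t - g) for t, g in zip(target, gen_output)) / len(target)
    return gan_loss + lam * l1_loss

# A flattened 30x30 map of zero logits: the discriminator is undecided everywhere
patch = [0.0] * 900
print(discriminator_loss_sketch(patch, patch))             # 2 * log(2)
print(generator_loss_sketch(patch, [0.1] * 4, [0.2] * 4))  # log(2) + 100 * 0.1
```

With LAMBDA = 100, an average pixel error of only 0.1 adds 10 to the generator loss, far more than the GAN term of an undecided discriminator (log 2 ≈ 0.693).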

Step 8: define checkpoints

Training takes a long time, so we save intermediate training states that can be loaded later to resume training.

checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

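If results from a previous training run have been saved, we load them and use the latest saved model to generate images for the test data.

def generate_images(model, test_input, tar):
    # training=True: dropout stays active and batch norm uses batch statistics
    prediction = model(test_input, training=True)
    plt.figure(figsize=(15, 5))

    display_list = [test_input[0], tar[0], prediction[0]]
    title = ['Input', 'Target', 'Predicted']

    for i in range(3):
        plt.subplot(1, 3, i+1)
        plt.title(title[i])
        # Map pixel values from [-1, 1] back to [0, 1] for display
        plt.imshow(tensor_to_array(display_list[i]) * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

# Keep at most the two most recent checkpoints in the current directory
ckpt_manager = tf.train.CheckpointManager(checkpoint, "./", max_to_keep=2)

if ckpt_manager.latest_checkpoint:
    checkpoint.restore(ckpt_manager.latest_checkpoint)

for inp, tar in test_dataset.take(20):
    generate_images(generator, inp, tar)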
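Because fit (below) saves only when (epoch + 1) % 20 == 0 and the manager keeps at most two checkpoints, you can work out in advance which checkpoints will exist after a run. A small sketch of that bookkeeping; checkpoint_schedule is a hypothetical helper, and 150 epochs is an assumed example value:

```python
def checkpoint_schedule(epochs, every=20, max_to_keep=2):
    """Return (epochs that trigger a save, epochs whose checkpoints remain).

    Mirrors `(epoch + 1) % 20 == 0` in `fit` and `max_to_keep=2` in the
    CheckpointManager: older checkpoints are deleted once more than
    `max_to_keep` have been written.
    """
    saved = [epoch + 1 for epoch in range(epochs) if (epoch + 1) % every == 0]
    return saved, saved[-max_to_keep:]

saved, kept = checkpoint_schedule(150)  # 150 epochs is an assumed example
print(saved, kept)
```

With saves every 20 epochs, an interrupted run loses at most 19 epochs of work.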

Step 9: Training

During training, after every epoch we display the prediction for the first test image, so you can watch the results improve.
A checkpoint is saved every 20 epochs.

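# tf.function compiles the step into a graph, which typically speeds up training
@tf.function
def train_step(input_image, target):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)

        disc_real_output = discriminator([input_image, target], training=True)
        disc_generated_output = discriminator([input_image, gen_output], training=True)

        gen_loss = generator_loss(disc_generated_output, gen_output, target)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    # Each network's gradients come from its own loss
    generator_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    # Update the generator and discriminator weights
    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))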

def fit(train_ds, epochs, test_ds):
    for epoch in range(epochs):
        start = time.time()

        for input_image, target in train_ds:
            train_step(input_image, target)

        # Replace the previous plot with the current prediction for the first test image
        clear_output(wait=True)
        for example_input, example_target in test_ds.take(1):
            generate_images(generator, example_input, example_target)

        # Save a checkpoint every 20 epochs
        if (epoch + 1) % 20 == 0:
            ckpt_save_path = ckpt_manager.save()
            print('Saved checkpoint for epoch {} at {}\n'.format(epoch + 1, ckpt_save_path))

        print('Time taken for epoch {} is {:.2f} sec\n'.format(epoch + 1, time.time() - start))

fit(train_dataset, EPOCHS, test_dataset)

The training time for the eighth epoch was 51.33 seconds.

Step 10: color the test data and check the results

for input, target in test_dataset.take(20):
    generate_images(generator, input, target)

Matpool Cloud now offers a ready-made "Pokemon coloring" image; if you are interested, you can try it from the "Jupyter tutorial Demo" image on the Matpool website.


Posted on Thu, 12 Mar 2020 20:37:02 -0700 by tkmk