HOW TO USE keras.layers.Flatten()

Kevin McLean
3 min read · Aug 29, 2021

Keras is definitely one of the best free machine learning libraries. It acts as a high-level Python API for TensorFlow.

With Keras you can create deep neural networks much more easily. You can import pre-trained models, or quickly build one and train it yourself.

One of the most widely used layers in Keras is keras.layers.Flatten().

The problem with input tensors like image datasets

When working with input tensors like image datasets, we need to find a way to properly feed them into our input layer.

After all, your input data shape needs to match your input layer shape.

Let’s take a look at Fashion MNIST

Fashion MNIST has 70,000 images in 10 different fashion categories. Each image has a resolution of 28 * 28 pixels.

Let me just print out the first image of this dataset in Python.

[Figure: image tensor of the Fashion MNIST dataset printed in Python]

Wow, look at that! Now we have an issue feeding this multi-dimensional array, or tensor, into our input layer. Are we going to create 28 * 28 layers? Does that even make sense?

keras.layers.Flatten()

This is where Keras Flatten comes to save us. This layer converts multi-dimensional arrays into flattened, one-dimensional arrays.

It takes all the elements in the original tensor (multi-dimensional array) and puts them into a single-dimensional array. Note that this does not include the batch dimension: each sample in a batch is flattened on its own, so the number of samples stays the same.

Let’s understand Keras Flatten using the Fashion MNIST example.

Each image in the Fashion MNIST dataset is a multi-dimensional array of 28 arrays, each containing 28 elements. That gives us 28 * 28 = 784 elements in each tensor, that is, in each image.
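To make the numbers concrete, here is a minimal sketch using plain NumPy rather than Keras. The arrays are random stand-ins for Fashion MNIST images (an assumption, so the snippet runs without downloading anything), and the batch size of 32 is an arbitrary choice. The reshape calls only mimic what flattening does to the shape.

```python
import numpy as np

# Random stand-in for one Fashion MNIST image (not real data):
# 28 rows, each holding 28 pixel values.
image = np.random.randint(0, 256, size=(28, 28), dtype=np.uint8)
print(image.shape)  # (28, 28)
print(image.size)   # 784

# Flattening puts all 784 elements into a single one-dimensional array.
flat_image = image.reshape(-1)
print(flat_image.shape)  # (784,)

# Hypothetical batch of 32 images. Axis 0 is the batch dimension,
# so we keep it and flatten only the remaining axes.
batch = np.random.randint(0, 256, size=(32, 28, 28), dtype=np.uint8)
flat_batch = batch.reshape(batch.shape[0], -1)
print(flat_batch.shape)  # (32, 784)
```

Notice that the batch keeps its 32 samples. Only the 28 * 28 part of each sample collapses into 784 values.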

What Keras Flatten does is take all 784 elements and put them in a single array. Simple! We can do this and define the first layer of our model at the same time with the following single line of code.

keras.layers.Flatten(input_shape=(28, 28))
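Here is a small sketch of how that layer could sit at the front of a model. It assumes TensorFlow is installed and uses a random batch of 4 fake images in place of real Fashion MNIST data. It also declares the input shape with keras.Input instead of the input_shape argument, a choice made here because newer Keras versions warn when input_shape is passed directly to a layer.

```python
import numpy as np
from tensorflow import keras

# A tiny model whose only job is to flatten 28 x 28 inputs.
model = keras.Sequential([
    keras.Input(shape=(28, 28)),  # shape of one sample, without the batch dimension
    keras.layers.Flatten(),       # (28, 28) -> (784,)
])

# The leading None is the batch dimension, which Flatten leaves alone.
print(model.output_shape)  # (None, 784)

# Random stand-in batch of 4 images (not real Fashion MNIST data).
batch = np.random.rand(4, 28, 28).astype("float32")
flat = model.predict(batch, verbose=0)
print(flat.shape)  # (4, 784)
```

In a real network you would follow Flatten with Dense layers, which expect exactly this kind of one-dimensional input per sample.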

Importing TensorFlow, Keras, Fashion MNIST, creating a DNN, training, and evaluation

It’s one thing to understand the theory behind a concept and another to actually implement it in practice.

To use keras.layers.Flatten() and actually create a DNN, you can read the full tutorial at https://neuralnetlab.com/keras-flatten-dnn-example

That tutorial covers everything you need to know about Keras Flatten, from importing TensorFlow and building the DNN to training on Fashion MNIST and evaluating the final accuracy of the model.

Be sure to check out the main blog at https://neuralnetlab.com to learn more about machine learning and AI in Python with easy-to-understand tutorials.
