Imbalanced classes in classification problems in deep learning with Keras

Upasana | August 31, 2019 | 1 min read | 116 views


When working on real-world problems, we often find that the classes in the data we have collected are imbalanced. A while ago, I faced this problem myself while working on a mood detection model. The main issue with imbalanced classes is that the trained model tends to give biased results, favouring the majority classes.

Here we will see how to overcome this problem when building a classification model with deep learning in Keras.

The fit method of a Keras model has a parameter named class_weight: a dict mapping each class index (an integer) to a weight that scales that class's contribution to the loss during training.
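Conceptually, class_weight multiplies each training sample's loss by the weight of its true class, so mistakes on a rare class cost more. The sketch below illustrates that idea for binary cross-entropy in plain Python. The helper weighted_binary_crossentropy and the toy numbers are made up for illustration; Keras's actual internals (for example how the weighted losses are reduced) may differ between versions.

```python
import math


def weighted_binary_crossentropy(y_true, y_prob, class_weight):
    """Mean binary cross-entropy in which each sample's loss is multiplied
    by the weight of its true class (hypothetical helper, not a Keras API)."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        # Standard per-sample binary cross-entropy
        loss = -(y * math.log(p) + (1 - y) * math.log(1 - p))
        # Scale it by the weight assigned to this sample's class
        total += class_weight[y] * loss
    # Average over the number of samples in the batch
    return total / len(y_true)


# Toy batch: three majority-class samples (0) and one minority sample (1)
y_true = [0, 0, 0, 1]
y_prob = [0.1, 0.2, 0.1, 0.4]  # predicted probability of class 1

unweighted = weighted_binary_crossentropy(y_true, y_prob, {0: 1.0, 1: 1.0})
weighted = weighted_binary_crossentropy(y_true, y_prob, {0: 1.0, 1: 3.0})
print(f"unweighted={unweighted:.4f} weighted={weighted:.4f}")
```

With a weight of 3.0 on class 1, the single minority sample's error makes up a much larger share of the loss, which pushes the model to pay more attention to that class.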

Here we will compute those weights with the class_weight module from sklearn (scikit-learn).
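With the 'balanced' option, scikit-learn documents the weight of each class as n_samples / (n_classes * np.bincount(y)), so rarer classes get proportionally larger weights. Below is a plain-Python sketch of that formula; balanced_class_weights is a hypothetical helper and the mood labels are invented example data, not taken from the original model.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Hypothetical re-implementation of the 'balanced' heuristic:
    weight = n_samples / (n_classes * count_of_class)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    # sorted() lists the classes in the same order np.unique would
    return {cls: n_samples / (n_classes * counts[cls]) for cls in sorted(counts)}


# Invented, imbalanced mood labels: 6 happy, 3 sad, 1 angry
moods = ["happy"] * 6 + ["sad"] * 3 + ["angry"]
weights = balanced_class_weights(moods)
print(weights)
```

In this scheme every class ends up with the same total weight (count times weight equals n_samples / n_classes), which is what "balanced" refers to.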

Importing Libraries
import numpy as np
from sklearn.utils import class_weight
from sklearn.preprocessing import LabelEncoder

labels is the pandas Series obtained by selecting only the label column from our data. The LabelEncoder has to be fitted on these labels before it can be used:

le = LabelEncoder()
labels = data[label_column]
le.fit(labels)  # learns the sorted list of classes in le.classes_
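LabelEncoder assigns each distinct label an integer in sorted order, so le.classes_ is sorted and le.transform(le.classes_) is simply 0, 1, ..., n_classes - 1. Those integers are the class indices that the keys of the class_weight dict must match. The plain-Python stand-in below (encode_labels is a hypothetical name) mimics that behaviour only to make the mapping visible; in practice you would keep using LabelEncoder.

```python
def encode_labels(labels):
    """Hypothetical stand-in for LabelEncoder.fit + transform:
    classes are sorted, then each label is replaced by its class index."""
    classes = sorted(set(labels))
    index = {cls: i for i, cls in enumerate(classes)}
    return classes, [index[label] for label in labels]


classes, encoded = encode_labels(["sad", "happy", "angry", "happy"])
print(classes)  # sorted class names
print(encoded)  # integer class index of each label
```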

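Next, compute a weight for each class, build a dict that maps each encoded class index to its weight, and pass it to model.fit. y_train must be encoded with the same LabelEncoder (as integers, or one-hot in the same class order) so that the indices match: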

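# classes and y are keyword-only arguments in recent scikit-learn versions
class_weights = class_weight.compute_class_weight(class_weight='balanced',
                                                  classes=np.unique(labels),
                                                  y=labels)
# le.classes_ is sorted like np.unique, so index i gets the weight of class i
class_weights_dict = dict(zip(le.transform(le.classes_),
                              class_weights))

# validation_split must be passed by keyword (the third positional argument is batch_size)
model.fit(x_train, y_train, validation_split=validation_split,
          class_weight=class_weights_dict)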
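When Keras applies class_weight, it looks up each sample's weight by its integer class index: the label itself for integer (sparse) targets, or the argmax of each row for one-hot targets. Because the encoded indices run from 0 to n_classes - 1, the same dict can also be built with dict(enumerate(class_weights)). The sketch below mimics that lookup in plain Python; per_sample_weights is a hypothetical helper and the three weights are made-up values, so treat it as an illustration rather than Keras's actual code.

```python
def per_sample_weights(y, class_weights_dict):
    """Hypothetical helper roughly mimicking how Keras turns class_weight
    into one weight per sample: integer targets are used directly as the
    class index, one-hot rows map to the index of their largest entry."""
    weights = []
    for target in y:
        if isinstance(target, (list, tuple)):
            # One-hot row: argmax gives the class index
            class_index = max(range(len(target)), key=target.__getitem__)
        else:
            # Sparse integer label is already the class index
            class_index = int(target)
        weights.append(class_weights_dict[class_index])
    return weights


# Made-up weights, listed in sorted class order (same as le.classes_)
class_weights = [10 / 3, 10 / 18, 10 / 9]
class_weights_dict = dict(enumerate(class_weights))  # keys 0, 1, 2

sparse_y = [2, 1, 0, 1]
one_hot_y = [[0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 1, 0]]
print(per_sample_weights(sparse_y, class_weights_dict))
print(per_sample_weights(one_hot_y, class_weights_dict))
```

Both target formats yield the same per-sample weights, which is why y_train only needs to follow the class order used by the LabelEncoder.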

The full code looks something like this:

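import numpy as np
from sklearn.utils import class_weight
from sklearn.preprocessing import LabelEncoder

# data, label_column, model, x_train, y_train and validation_split
# come from your own pipeline
labels = data[label_column]

# Fit the encoder so that le.classes_ holds the sorted class labels
le = LabelEncoder()
le.fit(labels)

# 'balanced' weight per class: n_samples / (n_classes * class_count)
class_weights = class_weight.compute_class_weight(class_weight='balanced',
                                                  classes=np.unique(labels),
                                                  y=labels)
# Map each encoded class index (0 .. n_classes - 1) to its weight
class_weights_dict = dict(zip(le.transform(le.classes_),
                              class_weights))

# y_train must use the same encoding as le (integer or one-hot)
model.fit(x_train, y_train, validation_split=validation_split,
          class_weight=class_weights_dict)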

I hope this works for your problem. Thanks for reading.

