Creating custom Keras callbacks in Python

Upasana | December 07, 2019 | 6 min read | 633 views


In this tutorial I discuss how to create custom callbacks, e.g. for logging batch results to stdout, streaming batch results to a CSV file, and terminating training on a NaN loss.

I was working with deep learning models using Keras in Python. Since the results did not vary much from epoch to epoch, I wanted to see the results per batch. That is when I came across NBatchLogger, which goes something like this:

NBatchLogger:
from keras.callbacks import Callback
class NBatchLogger(Callback):
    """
    A logger that logs average performance every `display` steps.
    """
    def __init__(self, display):
        super(NBatchLogger, self).__init__()
        self.step = 0
        self.display = display
        self.metric_cache = {}
    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.step += 1
        # Accumulate every metric in logs; older Keras versions also put
        # 'batch' and 'size' there, which are bookkeeping, not metrics
        for k, v in logs.items():
            if k in ('batch', 'size'):
                continue
            self.metric_cache[k] = self.metric_cache.get(k, 0) + v
        if self.step % self.display == 0:
            metrics_log = ''
            for (k, v) in self.metric_cache.items():
                val = v / self.display  # average over the last `display` batches
                if abs(val) > 1e-3:
                    metrics_log += ' - %s: %.4f' % (k, val)
                else:
                    metrics_log += ' - %s: %.4e' % (k, val)
            # .get() avoids a KeyError when 'steps' is not in params
            print('step: {}/{} ... {}'.format(self.step,
                                              self.params.get('steps'),
                                              metrics_log))
            self.metric_cache.clear()

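You can copy this into the callbacks.py file of the Keras package installed on your system, but it is simpler to define the class in your own script or module, since Callback can be subclassed anywhere. Then you can use it in fit like this: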

Using NBatchLogger while fitting a model
nbatch_logging = NBatchLogger(display=1)
model.fit(X_train, y_train, validation_split=val_split, verbose=0,
          epochs=num_epochs, batch_size=batch_size, callbacks=[nbatch_logging])

Just make sure that verbose is set to 0 so that Keras's built-in progress logs and the per-batch logs don't overlap. Now I could see the logs per batch. And then I became greedy: I wanted more.
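To see what the averaging in NBatchLogger does without running Keras, here is a small pure-Python sketch that applies the same accumulate-then-average logic to a hand-made list of per-batch logs. The function name average_every and the sample numbers are hypothetical, chosen only for illustration.

```python
def average_every(batch_logs, display):
    """Mimic NBatchLogger: sum each metric over `display` batches, then average.

    batch_logs: list of dicts, one per batch, e.g. {'loss': 0.5, 'acc': 0.8}.
    Returns one dict of averaged metrics per completed group of `display` batches.
    """
    cache = {}
    averages = []
    for step, logs in enumerate(batch_logs, start=1):
        for k, v in logs.items():
            cache[k] = cache.get(k, 0) + v
        if step % display == 0:
            averages.append({k: v / display for k, v in cache.items()})
            cache.clear()
    return averages


# Hypothetical per-batch losses for four batches, averaged every 2 batches
logs = [{'loss': 1.0}, {'loss': 0.5}, {'loss': 0.25}, {'loss': 0.25}]
print(average_every(logs, display=2))  # [{'loss': 0.75}, {'loss': 0.25}]
```

As in the callback, a trailing group of fewer than display batches is never reported.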

Keras provides an abstract base class named Callback that we can extend to create a custom callback implementation. Here is the class diagram for it.

[Figure: Class Diagram for Keras Callback]

I also wanted to save these logs, to stop training early based on per-batch results if possible, and then to use the saved logs to make graphs. In the Keras callbacks module, there are six important methods to pay attention to when making a custom callback:

Methods in the base Callback class
def on_epoch_begin(self, epoch, logs=None):
    pass
def on_epoch_end(self, epoch, logs=None):
    pass
def on_batch_begin(self, batch, logs=None):
    pass
def on_batch_end(self, batch, logs=None):
    pass
def on_train_begin(self, logs=None):
    pass
def on_train_end(self, logs=None):
    pass
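To make the order of these hooks concrete, here is a toy, pure-Python loop that calls them roughly the way Keras's fit does: train begin, then for each epoch an epoch-begin, a begin/end pair per batch and an epoch-end, then train end. RecordingCallback and run_fake_fit are hypothetical stand-ins, not Keras code; the real fit also passes metrics in logs, and it checks model.stop_training after each batch, which the toy loop imitates with a plain attribute.

```python
class RecordingCallback:
    """Hypothetical stand-in for a Keras Callback that records each hook call."""
    def __init__(self):
        self.calls = []
        self.stop_training = False  # stands in for self.model.stop_training

    def on_train_begin(self, logs=None):
        self.calls.append('train_begin')

    def on_epoch_begin(self, epoch, logs=None):
        self.calls.append('epoch_begin %d' % epoch)

    def on_batch_begin(self, batch, logs=None):
        self.calls.append('batch_begin %d' % batch)

    def on_batch_end(self, batch, logs=None):
        self.calls.append('batch_end %d' % batch)

    def on_epoch_end(self, epoch, logs=None):
        self.calls.append('epoch_end %d' % epoch)

    def on_train_end(self, logs=None):
        self.calls.append('train_end')


def run_fake_fit(callback, epochs, batches_per_epoch):
    """Toy loop that only calls the hooks in order; no real training happens."""
    callback.on_train_begin()
    for epoch in range(epochs):
        callback.on_epoch_begin(epoch)
        for batch in range(batches_per_epoch):
            callback.on_batch_begin(batch)
            callback.on_batch_end(batch)
            if callback.stop_training:
                break  # a callback asked to stop: skip the remaining batches
        callback.on_epoch_end(epoch)
        if callback.stop_training:
            break
    callback.on_train_end()


cb = RecordingCallback()
run_fake_fit(cb, epochs=2, batches_per_epoch=2)
print(cb.calls)
```

A callback only overrides the hooks it needs; batch-level behaviour such as NaN checks or batch early stopping belongs in on_batch_end.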

So we only need to focus on these methods and decide which of them our case needs. As an example, let me walk through a callback that already ships with Keras:

TerminateOnNaN: this callback terminates training when the loss becomes NaN (or infinite)
import numpy as np
from keras.callbacks import Callback
class TerminateOnNaN(Callback):
    """Callback that terminates training when a NaN loss is encountered.
    """
    def __init__(self):
        super(TerminateOnNaN, self).__init__()
    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        loss = logs.get('loss')
        if loss is not None:
            # Treat both NaN and +/-inf as a diverged loss
            if np.isnan(loss) or np.isinf(loss):
                print('Batch %d: Invalid loss, terminating training' % (batch))
                self.model.stop_training = True

This callback makes sure that the model stops training as soon as the loss becomes NaN. Logically, the check should happen after every batch rather than only at the end of each epoch; otherwise the rest of the epoch would be wasted on training that has already diverged. That is exactly what it does: in on_batch_end it reads the loss from logs, uses NumPy to check whether it is NaN or infinite, and if so sets self.model.stop_training = True; otherwise training proceeds normally.
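The check itself is small enough to try outside Keras. Below is a pure-Python sketch of the same condition TerminateOnNaN applies, run over a hypothetical sequence of per-batch losses. It uses math.isnan and math.isinf from the standard library instead of NumPy, and the function name first_invalid_batch is made up for this example.

```python
import math


def first_invalid_batch(losses):
    """Return the index of the first NaN or infinite loss, or None if all are fine.

    Mirrors the condition in TerminateOnNaN.on_batch_end: at that batch the
    callback would set model.stop_training = True. Missing (None) losses are skipped.
    """
    for batch, loss in enumerate(losses):
        if loss is not None and (math.isnan(loss) or math.isinf(loss)):
            return batch
    return None


# Hypothetical losses: training diverges at the fourth batch (index 3)
losses = [0.9, 0.7, 0.65, float('nan'), 0.6]
print(first_invalid_batch(losses))  # 3
```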

To make a batch-level early stopping callback, I read Keras's EarlyStopping callback class and adapted it. It worked on the second try.

BatchEarlyStopping
import warnings
import numpy as np
from keras.callbacks import Callback
class BatchEarlyStopping(Callback):
    """Stop training when a monitored metric stops improving, checked after every batch."""
    def __init__(self, monitor='loss',
                 min_delta=0, patience=0, verbose=0, mode='auto'):
        super(BatchEarlyStopping, self).__init__()
        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.min_delta = min_delta
        self.wait = 0
        self.stopped_batch = 0
        if mode not in ['auto', 'min', 'max']:
            warnings.warn('BatchEarlyStopping mode %s is unknown, '
                          'fallback to auto mode.' % mode,
                          RuntimeWarning)
            mode = 'auto'
        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            # 'auto': maximize accuracy-like metrics, minimize everything else
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less
        # Sign min_delta so an improvement must beat `best` by more than min_delta
        if self.monitor_op == np.greater:
            self.min_delta *= 1
        else:
            self.min_delta *= -1
    def on_train_begin(self, logs=None):
        self.wait = 0
        self.stopped_batch = 0
        # np.inf, not np.Inf: the capitalized alias was removed in NumPy 2.0
        self.best = np.inf if self.monitor_op == np.less else -np.inf
    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn(
                'Batch Early stopping conditioned on metric `%s` '
                'which is not available. Available metrics are: %s' %
                (self.monitor, ','.join(list(logs.keys()))), RuntimeWarning
            )
            return
        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
        else:
            # No improvement on this batch; stop once patience runs out
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_batch = batch
                self.model.stop_training = True
    def on_train_end(self, logs=None):
        if self.stopped_batch > 0 and self.verbose > 0:
            print('Batch %05d: early stopping' % (self.stopped_batch + 1))

Replacing a few things worked like a charm for me. We can use it by creating an instance like the following before calling fit:

batch_early_callback = BatchEarlyStopping(patience=500, monitor='loss')
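To see how patience and min_delta interact in BatchEarlyStopping, here is a pure-Python sketch of the same counter logic for mode='min' (what 'auto' picks for monitor='loss'), applied to a hypothetical list of per-batch losses. The function name stopping_batch and the numbers are illustrative assumptions; the real callback reads each loss from logs in on_batch_end.

```python
def stopping_batch(losses, patience, min_delta=0.0):
    """Return the batch index at which training would stop, or None.

    Same rule as BatchEarlyStopping in 'min' mode: a batch counts as an
    improvement only if its loss is below the best so far by more than
    min_delta; otherwise the wait counter grows, and training stops once
    wait >= patience.
    """
    best = float('inf')
    wait = 0
    for batch, current in enumerate(losses):
        if current < best - min_delta:
            best = current
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return batch
    return None


# Hypothetical per-batch losses that stop improving after batch 2
losses = [1.0, 0.8, 0.6, 0.61, 0.62, 0.6, 0.63]
print(stopping_batch(losses, patience=3))  # 5
```

Note that a loss equal to the best (0.6 at batch 5) does not count as an improvement, because the comparison is strict.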

Note: using BatchEarlyStopping is a bit tricky, because patience is counted in batches, so a sensible value depends on the batch size and the number of training samples (i.e. how many batches make up one epoch).
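Because patience is counted in batches, it helps to convert it into epochs. A quick calculation, assuming hypothetical numbers (60,000 training samples, batch size 32, patience=500, no validation split); batches_per_epoch is a helper made up for this example:

```python
import math


def batches_per_epoch(num_samples, batch_size):
    """Number of batches per epoch when fitting on arrays; the last one may be smaller."""
    return math.ceil(num_samples / batch_size)


steps = batches_per_epoch(60000, 32)
patience = 500
print(steps)                        # 1875
print(round(patience / steps, 3))   # 0.267
```

So patience=500 would stop training after roughly a quarter of an epoch without improvement. With validation_split=0.2, only 48,000 samples are used for training, giving 1,500 batches per epoch, so the same patience covers a third of an epoch.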

Next is saving the batch logs to a file.

NBatchCSVLogger: this saves the per-batch logs to a CSV file so that the improvement can be inspected and plotted later
import csv
import os
from collections import OrderedDict
from collections.abc import Iterable
import numpy as np
from keras.callbacks import Callback
class NBatchCSVLogger(Callback):
    """Callback that streams every batch result to a CSV file.
    """
    def __init__(self, filename, separator=',', append=False):
        super(NBatchCSVLogger, self).__init__()
        self.sep = separator
        self.filename = filename
        self.append = append
        self.writer = None
        self.keys = None
        self.append_header = True
    def on_train_begin(self, logs=None):
        if self.append:
            # Only write a header if the existing file is empty
            if os.path.exists(self.filename):
                with open(self.filename, 'r') as f:
                    self.append_header = not bool(len(f.readline()))
            mode = 'a'
        else:
            mode = 'w'
        # newline='' is what the csv module expects for files it writes (Python 3)
        self.csv_file = open(self.filename, mode, newline='')
    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        def handle_value(k):
            # Write iterables (except strings and 0-d arrays) as a quoted list
            is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
            if isinstance(k, str):
                return k
            elif isinstance(k, Iterable) and not is_zero_dim_ndarray:
                return '"[%s]"' % (', '.join(map(str, k)))
            else:
                return k
        if self.keys is None:
            # 'batch' gets its own column below, so don't list it twice
            self.keys = sorted(k for k in logs.keys() if k != 'batch')
        if self.model.stop_training:
            # Fill metrics missing from the final logs with 'NA'
            logs = dict([(k, logs[k]) if k in logs else (k, 'NA') for k in self.keys])
        if not self.writer:
            class CustomDialect(csv.excel):
                delimiter = self.sep
            self.writer = csv.DictWriter(self.csv_file,
                                         fieldnames=['batch'] + self.keys,
                                         dialect=CustomDialect)
            if self.append_header:
                self.writer.writeheader()
        row_dict = OrderedDict({'batch': batch})
        row_dict.update((key, handle_value(logs[key])) for key in self.keys)
        self.writer.writerow(row_dict)
        self.csv_file.flush()  # keep the file current even if training crashes
    def on_train_end(self, logs=None):
        self.csv_file.close()
        self.writer = None

This callback saves the per-batch logs to a file and can come in handy for diagnosing the variance in results.

It can be created like this before calling fit:

Creating NBatchCSVLogger before fitting the model
batch_log_saving = NBatchCSVLogger("batch_logs.csv", separator=',', append=False)
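Once the CSV exists, diagnosing the variance only needs the standard library. This sketch assumes the layout NBatchCSVLogger writes (a batch column followed by one column per metric, such as loss) and uses a small hypothetical in-memory file in place of batch_logs.csv; summarize_column is a made-up helper name.

```python
import csv
import io
import statistics


def summarize_column(csv_text, column):
    """Return (mean, population variance) of one numeric column of a batch-log CSV."""
    reader = csv.DictReader(io.StringIO(csv_text))
    # Skip 'NA' placeholders, which the logger writes for missing values
    values = [float(row[column]) for row in reader if row[column] != 'NA']
    return statistics.mean(values), statistics.pvariance(values)


# Hypothetical contents of batch_logs.csv
sample = "batch,loss\n0,0.5\n1,0.7\n2,0.3\n3,0.5\n"
mean, var = summarize_column(sample, 'loss')
print(round(mean, 4), round(var, 4))  # 0.5 0.02
```

For a real file, read it with open('batch_logs.csv', newline='') and pass its contents in. Keep in mind that the batch column restarts at 0 every epoch, since Keras passes the batch index within the current epoch; use the row order, not that column, as the x-axis when plotting.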

This is how I defined these custom callbacks, along with a few others. I hope this helps others who are writing their own callbacks.


So we just need to work out what we actually need and which method gets us that result. By default we only see results when an epoch ends, but sometimes our objective is something else. Let's say we want to save the best weights once validation accuracy and validation loss stop changing. For that, we could also wrap model training in a try block and raise an exception from a callback, so that the model stops training when the loss is almost constant across batches. Thanks for reading this article. I hope you found it useful.

