Knowledge Distillation with Keras*

By Ujjwal Upadhyay, Published: 08/09/2018, Last Updated: 08/09/2018

The problem that we are facing right now is that we have built sophisticated models that can perform complex tasks, but the question is, how do we deploy such bulky models on our mobile devices for instant usage? Obviously, we can deploy our model to the cloud and can call it whenever we need its service, but this would require a reliable Internet connection and hence, it becomes a constraint in production. So, what we need is a model that can run on our mobile devices.

So what’s the problem?

We can train a small network that runs within the limited computational resources of a mobile device. But there is a problem with this approach. Small models can’t extract many of the complex features that are handy for generating predictions, unless you devise some elegant algorithm to do so. An ensemble of small models gives good results, but making predictions with a whole ensemble is cumbersome and may be too computationally expensive to deploy to a large number of users. In this case, we resort to either of these two techniques:

  • Knowledge distillation
  • Model compression

In this article, we look at knowledge distillation. We cover model compression in an upcoming article.

Knowledge distillation is a simple way to improve the performance of deep learning models on mobile devices. In this process, we train a large and complex network or an ensemble model that can extract important features from the given data and can therefore produce better predictions. Then, we train a small network with the help of the cumbersome model. This small network will be able to produce comparable results, and in some cases, it can even be made capable of replicating the results of the cumbersome network.

For example, GoogLeNet is a very cumbersome (deep and complex) network: its depth gives it the ability to extract complex features, and its complexity gives it the power to remain accurate. But the model is so heavy that it needs a large amount of memory and a powerful graphics processing unit (GPU) to perform its large and complex calculations. That’s why we need to transfer the knowledge learned by this model to a much smaller model that can easily be used on mobile devices.

About Cumbersome Models

Cumbersome models learn to discriminate between a large number of classes. The normal training objective is to maximize the average log probability of the correct answer, and it assigns probability to all the classes, with some classes given small probabilities with respect to others. The relative probabilities of incorrect answers tell us a lot about how this complex model tends to generalize. An image of a car, for example, may only have a very small chance of being mistaken for a truck, but that mistake is still many times more probable than mistaking it for a cat.

Note: The objective function should be chosen so that the model generalizes well to new data. When selecting an objective function, keep in mind that it shouldn’t be chosen merely because it optimizes well on the training data.

Running a cumbersome model would be too heavy for a mobile device at inference time, so we have to transfer its knowledge to a small model that can be easily exported to mobile devices. To achieve this, we treat the cumbersome model as the teacher network and our new small model as the student network.

Teacher and Student

You can distill a large and complex network into another, much smaller network, and the smaller network does a reasonable job of approximating the original function learned by the deep network.

However, there is a catch. The distilled model (student) is trained to mimic the output of the larger network (teacher) instead of being trained on the raw data directly. This has something to do with how the deeper network learns hierarchical abstractions of the features.

So How is This Transfer of Knowledge Done?

The generalization ability of the cumbersome model can be transferred to a small model by using the class probabilities produced by the cumbersome model as soft targets for training the small model. For this transfer stage, we can use either the same training set that was used for the cumbersome model or a separate transfer set. When the cumbersome model is a large ensemble of simpler models, we can use an arithmetic or geometric mean of their individual predictive distributions as the soft targets. When the soft targets have high entropy, they provide much more information per training case than hard targets and much less variance in the gradient between training cases, so the small model can often be trained on much less data than the original cumbersome model while using a much higher learning rate.
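As a rough illustration of the ensemble case, the sketch below averages the predictive distributions of several models into one set of soft targets. The function name, the array layout, and the example probabilities are assumptions made for this illustration only.

```python
import numpy as np

def ensemble_soft_targets(member_probs, mean="arithmetic"):
    """Combine per-model class probabilities into one soft-target distribution.

    member_probs has shape (n_models, n_samples, n_classes) and every
    per-sample row is assumed to sum to 1.
    """
    member_probs = np.asarray(member_probs, dtype=np.float64)
    if mean == "arithmetic":
        combined = member_probs.mean(axis=0)
    elif mean == "geometric":
        # Average in log space; the result is renormalized below.
        log_mean = np.log(np.clip(member_probs, 1e-12, 1.0)).mean(axis=0)
        combined = np.exp(log_mean)
    else:
        raise ValueError("mean must be 'arithmetic' or 'geometric'")
    return combined / combined.sum(axis=-1, keepdims=True)

# Two hypothetical ensemble members scoring one image over (car, truck, cat).
probs = np.array([
    [[0.90, 0.09, 0.01]],
    [[0.70, 0.25, 0.05]],
])
arith = ensemble_soft_targets(probs, "arithmetic")
geom = ensemble_soft_targets(probs, "geometric")
```

The geometric mean has to be renormalized because the product of probabilities no longer sums to 1; the arithmetic mean already does, so the final division leaves it unchanged.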

The soft outputs

Much of the information about the learned function resides in the ratios of very small probabilities in the soft targets. This is valuable information that defines a rich similarity structure over the data (that is, it says which 2s look like 3s and which look like 7s, or which Golden Retriever looks like a Labrador) but it has very little influence on the cross-entropy cost function during the transfer stage because the probabilities are so close to zero.


For distilling the learned knowledge, we use logits (the inputs to the final softmax function). Logits can be used for learning the small model, and this can be done by minimizing the squared difference between the logits produced by the cumbersome model and the logits produced by the small model.
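A minimal sketch of this logit-matching objective, assuming the teacher and student logits are available as arrays of shape (batch, classes). The factor of one half and the averaging over the batch are conventions chosen for this example, not something the article specifies.

```python
import numpy as np

def logit_matching_loss(teacher_logits, student_logits):
    """Half the squared difference between logits, averaged over the batch."""
    diff = np.asarray(student_logits, dtype=np.float64) - np.asarray(teacher_logits, dtype=np.float64)
    return 0.5 * np.mean(np.sum(diff ** 2, axis=1))

# Hypothetical logits for a single example with three classes.
teacher_logits = np.array([[5.0, 2.0, -1.0]])
student_logits = np.array([[4.0, 2.0, 0.0]])
loss = logit_matching_loss(teacher_logits, student_logits)
```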

Softmax with Temperature

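Softmax with temperature converts each logit z_i into a probability q_i by dividing the logits by a temperature T before normalizing:

$$q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

With T = 1 this is the ordinary softmax. For high temperatures (T -> inf), all classes have nearly the same probability. At lower temperatures (T -> 0), the largest logits dominate, and the probability of the class with the highest logit tends to 1.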

In distillation, we raise the temperature of the final softmax until the cumbersome model produces a suitably soft set of targets. We then use the same high temperature when training the small model to match these soft targets.
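The effect of the temperature can be seen in a few lines of NumPy. This is a minimal sketch; the function name and the example logits are made up for illustration.

```python
import numpy as np

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over the last axis after dividing the logits by the temperature."""
    z = np.asarray(logits, dtype=np.float64) / temperature
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# Hypothetical teacher logits for three classes.
logits = np.array([8.0, 3.0, 0.0])
for T in (1.0, 5.0, 20.0):
    print(T, np.round(softmax_with_temperature(logits, T), 3))
```

As T grows, the printed distributions flatten: the top class keeps the largest share, but the other classes receive noticeably more probability mass, which is what makes the targets "soft".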

Objective Function

The student is trained on a weighted combination of two objective functions.

The first objective function is the cross entropy with the soft targets. This cross entropy is computed using the same high temperature in the softmax of the distilled model as was used for generating the soft targets from the cumbersome model.

The second objective function is the cross entropy with the correct labels. It is computed using exactly the same logits in the softmax of the distilled model, but at a temperature of 1.

Training Ensembles of Specialists

Training an ensemble of models is a very simple way to take advantage of parallel computation. The usual objection is that an ensemble requires too much computation at test time, but distillation deals with this objection: the knowledge of the ensemble is distilled into a single small model, and only that model is used at test time.

Specialist Models

One generalist model together with several specialist models makes up our cumbersome model. The generalist model is trained on all the training data, while each specialist model focuses on a different confusable subset of the classes and combines all the remaining classes into a single dustbin class, which reduces the total amount of computation required to learn an ensemble. The main problem with specialists is that they overfit very easily, but this overfitting may be prevented by using soft targets.

Reduce Overfitting in Specialist Models

To reduce overfitting and share the work of learning lower-level feature detectors, each specialist model is initialized with the weights of the generalist model. These weights are then slightly modified by training the specialist, with half its examples coming from its special subset, and half sampled at random from the remainder of the training set. After training, we can correct for the biased training set by incrementing the logit of the dustbin class by the log of the proportion by which the specialist class is oversampled.
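The logit correction described above can be sketched as follows. The function name, the position of the dustbin class, and the oversampling value of 10 are assumptions for illustration; how the oversampling proportion is measured depends on how the specialist training set was built.

```python
import numpy as np

def correct_dustbin_logit(logits, dustbin_index, oversampling_ratio):
    """Return a copy of the logits with log(oversampling_ratio) added to the dustbin logit."""
    corrected = np.array(logits, dtype=np.float64, copy=True)
    corrected[..., dustbin_index] += np.log(oversampling_ratio)
    return corrected

# Hypothetical specialist output: two specialist classes followed by the dustbin class.
specialist_logits = np.array([2.0, 1.0, 0.5])
# Assumed value: the specialist classes were oversampled by a factor of 10.
corrected = correct_dustbin_logit(specialist_logits, dustbin_index=2, oversampling_ratio=10.0)
```

Raising the dustbin logit lowers the probability the specialist gives to its own classes, which compensates for having seen them far more often during training than they occur in the full training set.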

Assign Classes to Specialists

We apply a clustering algorithm to the covariance matrix of the predictions of our generalist model, so that a set of classes S^m that are often predicted together is used as the targets for one of our specialist models, m. Specifically, we apply K-means clustering to the columns of the covariance matrix to obtain the required clusters of classes.

Assign a score to an ordered covariance matrix. High correlations within a cluster improve the score. High correlations between clusters decrease the score.

Covariance/Correlation clustering provides a method for clustering a set of objects into the optimum number of clusters without specifying that number in advance.
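The class-assignment step can be sketched with a small batch K-means written in NumPy. Everything here is illustrative: the synthetic "generalist predictions", the deterministic farthest-point initialization, and the choice of two clusters are assumptions, and a real setup would use the generalist's predictions on actual data.

```python
import numpy as np

def kmeans_columns(matrix, k, n_iter=50):
    """Plain batch K-means on the columns of `matrix`; returns one cluster id per column."""
    points = np.asarray(matrix, dtype=np.float64).T  # one point per column
    # Deterministic farthest-point initialization (an illustrative choice).
    centers = [points[0]]
    for _ in range(1, k):
        d = np.min([np.sum((points - c) ** 2, axis=1) for c in centers], axis=0)
        centers.append(points[np.argmax(d)])
    centers = np.array(centers)
    for _ in range(n_iter):
        dists = ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = dists.argmin(axis=1)
        new_centers = np.array([
            points[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels

# Synthetic generalist predictions over 4 classes: classes 0 and 1 tend to be
# predicted together, and so do classes 2 and 3.
rng = np.random.default_rng(0)
n = 200
group = np.arange(n) % 2
base = np.where(group[:, None] == 0,
                [0.45, 0.45, 0.05, 0.05],
                [0.05, 0.05, 0.45, 0.45])
preds = base + rng.uniform(-0.02, 0.02, size=base.shape)
preds = preds / preds.sum(axis=1, keepdims=True)

cov = np.cov(preds, rowvar=False)   # 4 x 4 covariance between class predictions
labels = kmeans_columns(cov, k=2)   # each cluster becomes one specialist's class subset
```

Columns of the covariance matrix that look alike belong to classes the generalist tends to predict together, so each resulting cluster is a candidate subset for one specialist.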

Performing Inference

  1. For each test case, we find the n most probable classes according to the generalist model. Call this set of classes k.
  2. We then take all the specialist models, m, whose special subset of confusable classes, S^m, has a non-empty intersection with k, and call this the active set of specialists A_k (note that this set may be empty). We then find the full probability distribution q over all the classes that minimizes:

$$KL(p^g, q) + \sum_{m \in A_k} KL(p^m, q)$$

Here KL denotes the KL divergence, and p^m and p^g denote the probability distribution of a specialist model and of the generalist full model, respectively. The distribution p^m is over all the specialist classes of m plus a single dustbin class, so when computing its KL divergence from the full q distribution, we sum all of the probabilities that the full q distribution assigns to all the classes in m’s dustbin.
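One way to carry out this minimization is to parameterize q as the softmax of a logit vector and run gradient descent on those logits. The sketch below does this in NumPy under that assumption; the function names, step size, iteration count, and example distributions are all hypothetical.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def kl(p, q, eps=1e-12):
    p = np.asarray(p, dtype=np.float64)
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps))))

def specialist_view(q, classes):
    """Collapse the full distribution q onto a specialist's classes plus one dustbin entry."""
    inside = q[classes]
    return np.append(inside, 1.0 - inside.sum())

def objective(q, p_g, specialists):
    """KL(p_g, q) plus the sum of KL(p_m, q) over the active specialists."""
    total = kl(p_g, q)
    for classes, p_m in specialists:
        total += kl(p_m, specialist_view(q, classes))
    return total

def combine(p_g, specialists, steps=500, lr=0.5):
    """Gradient descent on logits z, with q = softmax(z), starting from the generalist."""
    p_g = np.asarray(p_g, dtype=np.float64)
    z = np.log(p_g)
    n = len(p_g)
    for _ in range(steps):
        q = softmax(z)
        grad = q - p_g  # gradient of KL(p_g, q) with respect to z
        for classes, p_m in specialists:
            p_m = np.asarray(p_m, dtype=np.float64)
            mask = np.zeros(n, dtype=bool)
            mask[classes] = True
            target = np.zeros(n)
            target[classes] = p_m[:-1]
            # The dustbin mass is spread over the outside classes in proportion to q.
            target[~mask] = p_m[-1] * q[~mask] / q[~mask].sum()
            grad += q - target  # gradient of KL(p_m, q) with respect to z
        z -= lr * grad
    return softmax(z)

# Hypothetical 4-class problem with one active specialist covering classes 0 and 1.
p_g = np.array([0.40, 0.35, 0.15, 0.10])
specialists = [([0, 1], np.array([0.10, 0.85, 0.05]))]  # last entry is the dustbin
q = combine(p_g, specialists)
```

With an empty list of specialists the result stays at the generalist distribution; with the specialist above, q shifts toward class 1, which the specialist strongly prefers.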

Code for Loss Function

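from keras import backend as K
# logloss is taken to be categorical cross-entropy, matching the objective described above
from keras.losses import categorical_crossentropy as logloss

def knowledge_distillation_loss(y_true, y_pred, lambda_const):
    # temperature used to soften both the teacher and the student distributions
    temperature = 5.0
    # y_true packs two tensors along the last axis:
    #   [:, :256] one-hot hard true targets
    #   [:, 256:] logits from xception (the teacher)
    y_true, logits = y_true[:, :256], y_true[:, 256:]
    # convert logits to soft targets
    y_soft = K.softmax(logits / temperature)
    # y_pred packs two tensors along the last axis:
    #   [:, :256] usual output probabilities (temperature 1)
    #   [:, 256:] probabilities made softer with temperature
    y_pred, y_pred_soft = y_pred[:, :256], y_pred[:, 256:]
    # weighted hard-label loss plus soft-target loss
    return lambda_const * logloss(y_true, y_pred) + logloss(y_soft, y_pred_soft)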
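The loss above expects the hard labels and teacher logits to be packed into one array, and the student to emit its normal and softened probabilities side by side. The NumPy sketch below shows one way that packing could look, together with a reference version of the same arithmetic. The helper names and the random data are assumptions for illustration and are not taken from the original code.

```python
import numpy as np

NUM_CLASSES = 256   # matches the slicing used in the loss above
TEMPERATURE = 5.0   # matches the temperature used in the loss above

def softmax(z, temperature=1.0):
    z = np.asarray(z, dtype=np.float64) / temperature
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def pack_targets(one_hot_labels, teacher_logits):
    """y_true layout: [one-hot hard targets | teacher logits]."""
    return np.concatenate([one_hot_labels, teacher_logits], axis=1)

def pack_student_outputs(student_logits, temperature=TEMPERATURE):
    """y_pred layout: [probabilities at T=1 | probabilities softened with T]."""
    return np.concatenate(
        [softmax(student_logits), softmax(student_logits, temperature)], axis=1)

def cross_entropy(target, pred, eps=1e-7):
    return -np.sum(target * np.log(np.clip(pred, eps, 1.0)), axis=1)

def kd_loss_reference(y_true, y_pred, lambda_const, temperature=TEMPERATURE):
    """NumPy mirror of the Keras loss, returning one value per example."""
    hard, teacher_logits = y_true[:, :NUM_CLASSES], y_true[:, NUM_CLASSES:]
    soft = softmax(teacher_logits, temperature)
    pred, pred_soft = y_pred[:, :NUM_CLASSES], y_pred[:, NUM_CLASSES:]
    return lambda_const * cross_entropy(hard, pred) + cross_entropy(soft, pred_soft)

# Random stand-ins for a batch of labels, teacher logits, and student logits.
rng = np.random.default_rng(0)
batch = 4
labels = rng.integers(0, NUM_CLASSES, size=batch)
one_hot = np.eye(NUM_CLASSES)[labels]
teacher_logits = rng.normal(size=(batch, NUM_CLASSES))
student_logits = rng.normal(size=(batch, NUM_CLASSES))

y_true = pack_targets(one_hot, teacher_logits)
y_pred = pack_student_outputs(student_logits)
loss = kd_loss_reference(y_true, y_pred, lambda_const=0.5)
```

Because a Keras loss is called with only y_true and y_pred, the three-argument function above would typically be wrapped, for example in a lambda that fixes lambda_const, before being passed to model.compile.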

Inference Graphs (for MobileNet* Model)


These results are obtained after training our student model using pretrained teacher models.

Ensemble Model as Teacher

If the teacher is an ensemble model then we can take the following steps in order to improve our learning curve and accuracy of the distilled model:

  1. To improve the performance of the student network, use temperature to balance different teachers.
  2. If student networks are provided with multiple streams of information via the various teacher distributions, the student will observe various views of the data and will be able to generalize better, while at the same time it can capture different information provided by teacher networks. To facilitate this strategy we can try creating multiple copies of data with corresponding soft output targets from various teachers.
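The idea of creating multiple copies of the data, one per teacher, can be sketched as below. The function name, the per-teacher temperatures, and the random inputs are assumptions made for this example; the article does not prescribe how the temperatures should be chosen.

```python
import numpy as np

def softmax(z, temperature=1.0):
    z = np.asarray(z, dtype=np.float64) / temperature
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def build_multi_teacher_set(inputs, teacher_logits_list, temperatures):
    """Repeat the inputs once per teacher and pair each copy with that teacher's soft targets."""
    xs, ys = [], []
    for logits, temperature in zip(teacher_logits_list, temperatures):
        xs.append(inputs)
        ys.append(softmax(logits, temperature))
    return np.concatenate(xs, axis=0), np.concatenate(ys, axis=0)

# Random stand-ins: 5 examples with 8 features, and two teachers scoring 3 classes.
rng = np.random.default_rng(0)
x = rng.normal(size=(5, 8))
teacher_a = rng.normal(size=(5, 3))
teacher_b = rng.normal(size=(5, 3))
x_all, y_all = build_multi_teacher_set(x, [teacher_a, teacher_b], temperatures=[2.0, 5.0])
```

Each input now appears once per teacher, so the student sees the same example paired with different soft targets, and a separate temperature per teacher is one possible way to balance teachers whose logits differ in scale.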

Soft Targets as Regularizers

Soft targets, or labels predicted from a model, contain more information than binary hard labels because they encode similarity measures between the classes.

Incorrect labels tagged by the model describe co-label similarities, and these similarities should be evident in future stages of learning, even if the effect is diminished. For example, imagine training a deep neural net on a classification dataset of various dog breeds. In the initial few stages of learning, the model will not accurately distinguish between similar dog breeds such as a Belgian Shepherd and a German Shepherd. This same effect, although not so exaggerated, should appear in later stages of training. If, given an image of a German Shepherd, the model predicts the class German Shepherd with high probability, the next-highest predicted dog should still be a Belgian Shepherd, or a similar-looking dog. Over-fitting starts to occur when the majority of these co-label effects begin to disappear. By forcing the model to retain these effects in the later stages of training, we reduce the amount of over-fitting.

Using soft targets as regularizers is not considered very effective.


In this article, we showed how to learn from bulky models and create lighter models by using knowledge distillation to transfer knowledge from Xception to MobileNet 0.25 and SqueezeNet v1.1.

The results are shown below:


Model                      Accuracy (%)   Top-5 accuracy (%)   Log loss
Xception                   82.3           94.7                 0.705
MobileNet-0.25             64.6           84.9                 1.455
MobileNet-0.25 with KD     66.2           86.7                 1.464
SqueezeNet v1.1            67.2           86.5                 1.555
SqueezeNet v1.1 with KD    68.9           87.4                 1.297

The paper Distilling the Knowledge in a Neural Network was written by Geoffrey Hinton, Oriol Vinyals, and Jeff Dean.

Intel Development Tools Used

This project made use of Jupyter* Notebook on the Intel® AI DevCloud (using Intel® Xeon® Scalable processors) to write the code. We trained the student model on Intel® AI DevCloud and developed it with the Keras* framework, using Intel® Optimization for TensorFlow* as its backend. Information from the Intel® AI Developer Program forum and blogs also helped a lot. The code can be found in this GitHub* Repository. The code is self-explanatory, but if you find a bug or don’t understand part of it, please create an issue.