Friday, November 28, 2025

Knowledge Distillation

Knowledge distillation from a trained, large Teacher model to a smaller Student model is a very popular technique in ML. Distillation trains a Student model which, despite being much smaller and more compressed, shows performance comparable to the larger Teacher model.

The other advantage of Distillation is that the Student model requires a much smaller set of labelled training data (<10%), since it is essentially trying to match the output of the Teacher during training. The Distillation loss is a function of the difference between the prediction of the Student (y_pred) and the prediction of the Teacher (teacher_pred) for every training input (x). The Kullback-Leibler divergence (KLDivergence) loss between y_pred and teacher_pred is a common choice for the Distillation loss.
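
Below is a minimal sketch of such a combined loss in Keras, assuming both models emit logits over the output classes; the temperature and the alpha weighting between the hard-label term and the KL term are illustrative values, not taken from TextClassificationDistillation.py.

import keras
from keras import ops

def distillation_loss(y_true, student_logits, teacher_logits,
                      temperature=2.0, alpha=0.5):
    # Usual supervised loss on the small labelled subset
    hard_loss = keras.losses.sparse_categorical_crossentropy(
        y_true, student_logits, from_logits=True
    )
    # KL divergence between temperature-softened Teacher and Student outputs
    soft_loss = keras.losses.KLDivergence()(
        ops.softmax(teacher_logits / temperature, axis=-1),
        ops.softmax(student_logits / temperature, axis=-1),
    )
    # Weighted combination of the hard-label term and the distillation term
    return alpha * ops.mean(hard_loss) + (1.0 - alpha) * soft_loss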

For a working example of Distillation, refer to TextClassificationDistillation.py, which distills a Keras Text Classification model in Torch. The original Text Classification Teacher model had several Convolution layers, which have been replaced by a single Dense layer in the Student. The Input Embedding layer's output dimension has also been reduced from 128 to 32.
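
As a rough sketch of what the Student looks like under those changes (the vocabulary size, pooling layer, hidden width and binary sigmoid head below are assumptions for illustration, not the exact values in TextClassificationDistillation.py):

import keras
from keras import layers

max_features = 5000   # assumed vocabulary size, shared with the Teacher
embedding_dim = 32    # reduced from the Teacher's 128

inputs = keras.Input(shape=(None,), dtype="int64")
# Smaller Embedding layer (output dimension 32 instead of 128)
x = layers.Embedding(max_features, embedding_dim)(inputs)
# Pool over tokens instead of convolving over them
x = layers.GlobalAveragePooling1D()(x)
# Single Dense layer standing in for the Teacher's Convolution stack
x = layers.Dense(32, activation="relu")(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(1, activation="sigmoid")(x)
student = keras.Model(inputs, outputs, name="student")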

The original Text Classification model (Teacher) had ~8.9 Lakh parameters and was trained with 25K data samples. The distilled Student model has only ~1.6 Lakh parameters (~18%) and was trained with 2.5K samples (~10%). In terms of saved model size, the Teacher comes to 10.2 MB vs 0.6 MB for the Student. Only a marginal 4% drop in accuracy was seen with the Student model on the held-out test data.

Fig 1: Keras Text Classification - Teacher Model

Fig 2: Keras Text Classification - Student Model
