Distilling the Knowledge in a Neural Network (1503.02531v1)
Abstract: A very simple way to improve the performance of almost any machine learning algorithm is to train many different models on the same data and then to average their predictions. Unfortunately, making predictions using a whole ensemble of models is cumbersome and may be too computationally expensive to allow deployment to a large number of users, especially if the individual models are large neural nets. Caruana and his collaborators have shown that it is possible to compress the knowledge in an ensemble into a single model which is much easier to deploy and we develop this approach further using a different compression technique. We achieve some surprising results on MNIST and we show that we can significantly improve the acoustic model of a heavily used commercial system by distilling the knowledge in an ensemble of models into a single model. We also introduce a new type of ensemble composed of one or more full models and many specialist models which learn to distinguish fine-grained classes that the full models confuse. Unlike a mixture of experts, these specialist models can be trained rapidly and in parallel.
Explain it Like I'm 14
What is this paper about?
This paper introduces a simple, smart way to take what a big, powerful “teacher” neural network (or a committee of many networks) has learned and pass that knowledge to a smaller, faster “student” network. The goal is to keep most of the teacher’s smarts while making the student quick and cheap enough to use in real-world apps like voice search.
What questions are the researchers asking?
- Can we “distill” (compress and transfer) the knowledge from a big model or an ensemble of models into a smaller model that is easier to use?
- How should we train the small model so it learns the teacher’s way of generalizing, not just the right answers?
- Does this work on real tasks, like recognizing handwritten digits and speech?
- For huge image datasets, can we also train “specialist” models for confusing categories and get benefits quickly?
- Do “soft targets” (teacher’s graded hints) prevent overfitting and let us learn well even with much less data?
How does the method work? (Plain-language explanation)
Think of a big model as a very experienced teacher and a small model as a student.
- Normally, we train the student using “hard targets” (just the correct answer: one label, 100% confidence).
- In distillation, we use the teacher’s “soft targets” instead. A soft target is the teacher’s full set of probabilities for all classes. It says not only what the teacher thinks is right, but also how wrong answers differ. For example, the teacher might say an image is 99.9% likely to be a “2”, one in a million likely to be a “3”, and one in a billion likely to be a “7”; those tiny differences tell the student which mistakes are more plausible.
To make these soft targets more informative, the teacher uses a softmax “temperature” dial:
- Softmax is the function that turns raw scores (called “logits”) into probabilities.
- Temperature T > 1 makes the probabilities “softer” (more spread out), so small differences between wrong classes become more visible.
- The student is trained to match these softened probabilities (using the same T during training), and may also get some weight on matching the true hard labels at normal temperature (T = 1).
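To make this concrete, here is a minimal sketch of that training objective in PyTorch (not the authors' code; the tensor names, the `alpha` weighting, and the default `T = 20` are illustrative assumptions). The `T**2` factor on the soft term follows the paper's note that soft-target gradients shrink as 1/T², so scaling them back up keeps the two terms comparable as the temperature changes.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, hard_labels, T=20.0, alpha=0.5):
    """Weighted sum of a soft-target term (teacher at temperature T)
    and a hard-target term (true labels at T = 1).

    student_logits, teacher_logits: (batch, num_classes) raw scores.
    hard_labels: (batch,) integer class labels.
    """
    # Teacher's softened probabilities and the student's softened log-probabilities.
    soft_targets = F.softmax(teacher_logits / T, dim=1)
    log_soft_preds = F.log_softmax(student_logits / T, dim=1)

    # Cross-entropy between the softened distributions. The T**2 factor compensates
    # for soft-target gradients scaling as 1/T^2, so the balance between the two
    # terms stays roughly the same if the temperature is changed.
    soft_loss = -(soft_targets * log_soft_preds).sum(dim=1).mean() * (T ** 2)

    # Ordinary cross-entropy against the true labels at temperature 1.
    hard_loss = F.cross_entropy(student_logits, hard_labels)

    return alpha * soft_loss + (1.0 - alpha) * hard_loss
```

The paper reports that the best results came from giving the hard-label term a considerably lower weight than the soft-target term (a relatively large `alpha` in this sketch).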
A helpful analogy:
- Hard targets are like an answer key with only right/wrong.
- Soft targets are like the teacher’s graded hints: “This looks a little like a 3, and a tiny bit like a 7.” Those hints help the student learn the teacher’s way of thinking.
A technical note, simplified:
- “Logits” are the model’s raw scores before probabilities.
- Matching logits directly (as done in earlier work) is actually a special case of distillation when the temperature is high. The authors show that in this high-temperature limit, training to match softened probabilities is mathematically similar to matching logits.
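Spelled out in the paper's notation (student logits z_i, teacher logits v_i, softened probabilities q_i and p_i, N classes), the gradient of the soft-target cross-entropy C with respect to a student logit is shown below; the last step assumes the temperature is large compared to the logits and that the logits have zero mean.

```latex
\frac{\partial C}{\partial z_i}
  = \frac{1}{T}\,(q_i - p_i)
  = \frac{1}{T}\left( \frac{e^{z_i/T}}{\sum_j e^{z_j/T}}
                    - \frac{e^{v_i/T}}{\sum_j e^{v_j/T}} \right)
  \approx \frac{1}{T}\left( \frac{1 + z_i/T}{N + \sum_j z_j/T}
                          - \frac{1 + v_i/T}{N + \sum_j v_j/T} \right)
  \approx \frac{1}{N T^{2}}\,(z_i - v_i)
```

Up to the constant 1/(NT²), this is the gradient of minimizing ½(z_i - v_i)², i.e., of matching the logits directly.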
What did they do and find?
Here are the main experiments and why they matter.
- Handwritten digits (MNIST):
- A big, regularized teacher network made only 67 mistakes.
- A small, unregularized network made 146 mistakes when trained the usual way.
- The small network trained with distillation (soft targets, temperature ~20) made just 74 mistakes—almost as good as the big one.
- Even if the transfer set had zero examples of the digit “3”, the distilled model still recognized 98.6% of the test 3’s after a simple bias adjustment. This shows soft targets transfer a sense of similarity across classes: the model learns what 3’s look like by how other digits relate to 3’s in the teacher’s “hinted” probabilities.
- Speech recognition (a commercial-scale acoustic model):
- Baseline single model: 58.9% frame accuracy, 10.9% word error rate (WER).
- Ensemble of 10 models: 61.1% frame accuracy, 10.7% WER.
- Distilled single model (same size as baseline): 60.8% frame accuracy, 10.7% WER.
- Distillation captured over 80% of the ensemble’s accuracy gain while being as easy to deploy as a single model. That’s a big win in practice.
- Huge image dataset (JFT, 100M images, 15,000 labels):
- They trained fast “specialist” models focusing on confusable groups (e.g., different types of bridges or specific car models).
- Each specialist was initialized from the generalist model’s weights, which reduced overfitting; the authors also suggest training specialists with soft targets as a further safeguard (a rough code sketch of the specialist label setup follows this list).
- With 61 specialists, top-1 test accuracy improved from 25.0% to 26.1% (about 4.4% relative improvement). The more specialists covered a class, the larger the gains—promising because specialists are easy to train in parallel.
- Soft targets as a regularizer (preventing overfitting):
- Using only 3% of the speech training data:
- Hard-target training overfit badly (test accuracy 44.5%).
- Soft-target training reached 57.0% test accuracy, close to the full-data baseline (58.9%).
- Takeaway: soft targets carry rich information that helps the student generalize, even with little data.
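Returning to the specialist models above: in the paper, each specialist keeps only its own cluster of confusable classes plus a single “dustbin” class that lumps together everything else, so targets defined over all 15,000 labels have to be collapsed into that smaller label space. Here is a rough sketch of that bookkeeping, assuming PyTorch (the function and variable names are illustrative, not from the paper):

```python
import torch

def collapse_to_specialist(full_probs, special_classes):
    """Map a distribution over all classes onto a specialist's label space:
    its own classes plus one 'dustbin' class for everything else.

    full_probs: (batch, num_all_classes) targets, either one-hot hard labels
                or the generalist's soft targets.
    special_classes: 1-D LongTensor of class indices this specialist covers.
    Returns a (batch, len(special_classes) + 1) tensor, dustbin last.
    """
    special = full_probs[:, special_classes]                            # keep the covered classes
    dustbin = (1.0 - special.sum(dim=1, keepdim=True)).clamp(min=0.0)   # lump the rest together
    return torch.cat([special, dustbin], dim=1)

# Example with hypothetical class indices:
# targets = collapse_to_specialist(generalist_soft_targets, torch.tensor([3, 41, 977]))
```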
Why is this important?
- Makes deployment easier: You get near-ensemble performance with just one small model, saving memory, time, and energy.
- Better learning with fewer labels or data: Soft targets convey “how to generalize,” not just “what’s right,” which reduces overfitting.
- Scales to big problems: Specialists can boost accuracy on massive datasets without the cost of training full ensembles end-to-end.
- Practical and flexible: You can distill from a big single model, an ensemble, or a mixture of generalist plus specialists. Training is easy to parallelize.
Simple implications and future impact
- Apps like voice assistants, photo recognition, and other AI services can run faster and cheaper while staying accurate.
- Teams can train powerful, complex models offline to extract patterns, then distill these patterns into compact models for phones, browsers, or embedded devices.
- In education terms: the “teacher-student” setup becomes standard—train a strong teacher, then teach a small student with hints, not just answers.
- For research: soft targets are a powerful regularizer and a bridge to using unlabeled or limited data. Distillation also opens the door to combining generalists and specialists smoothly.
Overall, the paper shows that “knowledge distillation” is a practical, effective way to keep the brain of a big AI model while giving the body of a small, efficient one.