Transfer learning is a machine learning technique where knowledge gained from a pre-trained model on one task is reused as a starting point for training a model on a related, but potentially different, task.
In traditional machine learning, a common assumption is that both the training and test data are drawn from the same underlying distribution, meaning they belong to the same domain — implying that the input feature space and data distribution characteristics are identical for both sets. However, in reality, this assumption is often violated in scenarios such as:
Data drift: Over time, the data distribution can change, leading to differences between training and test data.
Domain shift: Test data may come from a different domain altogether, such as using a model trained on images of cats to identify dogs.
Data quality issues: Training data might contain errors or biases not present in test data.
Consider using transfer learning for the following scenarios:
Limited data: Acquiring extensive labeled data is often challenging and costly. Transfer learning enables us to use pre-trained models, reducing the dependency on large datasets.
Enhanced performance: Starting with a pre-trained model, which has already learned from substantial data, allows for faster and more accurate results on new tasks — ideal for applications needing high accuracy and efficiency.
Time and cost efficiency: Transfer learning shortens training time and conserves resources by utilizing existing models, eliminating the need for training from scratch.
Adaptability: Models trained on one task can be fine-tuned for related tasks, making transfer learning versatile for various applications, from image recognition to natural language processing.
In a CNN, transfer learning is achieved by 'freezing' the early convolutional layers of the pre-trained model, which have learned general features, and then fine-tuning the later layers, which are more task-specific. To describe the process in greater detail:
Freezing early layers: The initial convolutional layers of a CNN, which typically learn basic features like edges and patterns, are often frozen (their weights are not updated during fine-tuning). This is because these features are generally useful across different tasks.
Fine-tuning later layers: The later layers of the CNN, particularly the fully connected layers, are more specific to the original training task. These layers are unfrozen and fine-tuned using the new, smaller dataset, allowing the model to adapt its knowledge to the new task.
Retraining output layers: The final layers of the CNN that make predictions are retrained or fine-tuned to adapt to the new dataset's classes or labels.
This approach significantly reduces training time and data requirements, especially when the new task has limited labeled data. By building on the pre-trained model's knowledge, the model can achieve good performance even with a smaller dataset.
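As a concrete illustration, here is a minimal PyTorch sketch of that recipe, assuming a recent torchvision and a hypothetical 5-class target task: the backbone of an ImageNet-pretrained ResNet-18 is frozen, the last block is optionally unfrozen for fine-tuning, and the output layer is replaced.

```python
import torch
import torch.nn as nn
from torchvision import models

# Load an ImageNet-pretrained backbone (the weights enum requires torchvision >= 0.13).
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Freeze early layers: their general features (edges, textures) are reused as-is.
for param in model.parameters():
    param.requires_grad = False

# Retrain the output layer: replace the head with one sized for the new task.
num_classes = 5  # hypothetical number of target classes
model.fc = nn.Linear(model.fc.in_features, num_classes)  # new layer is trainable by default

# Optionally unfreeze the last convolutional block for fine-tuning.
for param in model.layer4.parameters():
    param.requires_grad = True

# Only the parameters left trainable are handed to the optimizer.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4
)
```

Because most parameters stay frozen, each training step updates only a small fraction of the network, which is what keeps training fast and data-efficient on the small target dataset.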
Transfer learning can be broken down into three core stages, each with a 'source' side (blue) and a 'target' side (red) — linked by the 'knowledge' that is carried over (green):
Source domain: Data you do have in abundance (e.g. 10 million English movie reviews). Defines input distribution P_S(X).
Source task: What you train on that data (e.g. binary sentiment: positive vs. negative). Defines label space Y_S and predictive function f_S.
Source system/model: Actual model you train (e.g. large Transformer fine-tuned for English sentiment). Learns parameters θ_S that capture both general language features and task-specific decision boundaries.
Knowledge (green block) is the distilled essence of θ_S, which will be transferred to the target side. It can contain:
Feature representations (e.g. pretrained embeddings, hidden-layer activations)
Weights (e.g. first few layers of a CNN or Transformer)
Hyperparameters or architectural motifs (e.g. residual blocks, attention heads)
On the target side:
Target domain: New data distribution you care about but have less of (e.g. 1 million Chinese clothing reviews). Defines P_T(X), which differs from P_S(X).
Target task: Label space or prediction goal you actually need (e.g. 0-10 sentiment rating). Can be the same as source (sentiment) but with different granularity, or even a different task entirely.
Target system/model: You initialize (or 'warm start') your model from the transferred knowledge. The feature extractor weights are copied (and possibly frozen or given a low learning rate), while the task-specific head (e.g. a 5-class rating classifier) is randomly initialized. Then you fine-tune on the smaller target dataset, adapting both transferred and new parameters to (X_T, Y_T); see the warm-start sketch after the next list.
By utilizing this source-target architecture:
Data efficiency: You leverage source data scale to learn robust features, then only need modest target labels to specialize.
Domain adaptation: By keeping core representations and adjusting only the final layers (or fine-tuning gently), you bridge the gap between P_S(X) and P_T(X).
Task flexibility: You can transfer to tasks that are similar (binary to multi-class sentiment) or even somewhat different, so long as the underlying features remain useful.
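Here is a minimal sketch of that warm start (the toy architecture, vocabulary size, and learning rates are all assumptions): the encoder weights learned on the source task are copied into the target model, a new task-specific head is initialized, and the transferred layers are fine-tuned with a lower learning rate.

```python
import torch
import torch.nn as nn

class SentimentModel(nn.Module):
    def __init__(self, vocab_size: int, hidden: int, num_labels: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.encoder = nn.LSTM(hidden, hidden, batch_first=True)  # feature extractor
        self.head = nn.Linear(hidden, num_labels)                 # task-specific head

    def forward(self, x):
        _, (h, _) = self.encoder(self.embed(x))
        return self.head(h[-1])

# Source model: assumed to be already trained on abundant source reviews (binary labels).
source = SentimentModel(vocab_size=30000, hidden=256, num_labels=2)

# Target model: same backbone, but a 5-class head for the target rating task
# (a shared tokenizer/vocabulary is assumed here for simplicity).
target = SentimentModel(vocab_size=30000, hidden=256, num_labels=5)
target.embed.load_state_dict(source.embed.state_dict())       # transfer the knowledge
target.encoder.load_state_dict(source.encoder.state_dict())

# Fine-tune gently: low learning rate for transferred layers, higher for the new head.
optimizer = torch.optim.Adam([
    {"params": target.embed.parameters(), "lr": 1e-5},
    {"params": target.encoder.parameters(), "lr": 1e-5},
    {"params": target.head.parameters(), "lr": 1e-3},
])
```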
Three types of transfer learning methods include:
Feature-based transfer learning: Used when the source and target domains lack overlap at the instance level. Learns a pair of mapping functions {φ_s(·), φ_t(·)} that map data from the source and target domains into a common feature space, where the difference between domains can be reduced, facilitating effective knowledge transfer. Operates in an abstract feature space rather than the raw input space.
Maximum mean discrepancy (MMD): A nonparametric criterion (one that makes no assumptions about the form or parameters of the underlying distribution) for measuring the distance between distributions. It takes the form MMD(X_S, X_T) = ‖ (1/n_S) Σ_i φ(x_i^S) − (1/n_T) Σ_j φ(x_j^T) ‖_H, i.e. the distance between the mean feature embeddings of the source and target samples, where φ(·) is a kernel-induced feature map into a reproducing kernel Hilbert space H (a short numerical sketch follows this list).
Model-based transfer learning: For applications where the source and target tasks share common knowledge at the model level. The transferred knowledge is encoded in model parameters or model architectures. A well-trained model in the source domain has learned a great deal of structural information, so if the tasks are related, they may share similar parameters or structures.
Instance-based transfer learning: The transferred knowledge corresponds to individual source instances.
Idea: train a model on the source data, then fine-tune it on the target data.
Use the model trained on the source data as the initial model, and continue training it with the target data.
Challenge: only limited target data is available, so be careful about overfitting.
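To make the MMD criterion mentioned above concrete, here is a minimal numerical sketch in PyTorch; it uses a linear kernel, so the statistic reduces to the distance between mean feature embeddings, and random tensors stand in for the mapped source and target features.

```python
import torch

def mmd(source_feats: torch.Tensor, target_feats: torch.Tensor) -> torch.Tensor:
    # Linear-kernel MMD: distance between the mean embeddings of the two samples.
    return (source_feats.mean(dim=0) - target_feats.mean(dim=0)).norm()

xs = torch.randn(100, 64)        # mapped source features, shape (n_S, d)
xt = torch.randn(80, 64) + 0.5   # mapped target features with a shifted mean
print(mmd(xs, xt))               # larger values indicate a larger domain gap
```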
Conservation learning aims to retain as much as possible of the knowledge or features learned in the source model when transferring it to a new task. This can be achieved through various techniques, including freezing or partially freezing layers of the pre-trained model, fine-tuning specific layers, or using techniques like transfer learning with domain adaptation.
This approach encourages the model to learn more robust and transferable features, reducing its tendency to memorize details specific to the training data that may not generalize well to new, unseen data. Additionally, by incorporating information from related tasks, conservation learning acts as a form of regularization, discouraging the model from memorizing noise or irrelevant details in the training data.
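One simple way to realize this idea is an explicit penalty that keeps the fine-tuned weights close to the pretrained ones. The sketch below assumes PyTorch and a hypothetical lambda_reg hyperparameter; the same penalty also helps with the overfitting risk on small target datasets noted earlier.

```python
import torch

def conservation_penalty(model: torch.nn.Module,
                         source_params: dict,
                         lambda_reg: float = 1e-3) -> torch.Tensor:
    # Penalize fine-tuned parameters for drifting away from the pretrained values.
    penalty = torch.zeros(())
    for name, param in model.named_parameters():
        if name in source_params:
            penalty = penalty + (param - source_params[name]).pow(2).sum()
    return lambda_reg * penalty

# Snapshot the pretrained weights once, before fine-tuning begins:
#   source_params = {n: p.detach().clone() for n, p in model.named_parameters()}
# Then, inside each training step:
#   loss = task_loss + conservation_penalty(model, source_params)
```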
Layer transfer is a form of transfer learning in which one or more entire layers (or blocks of layers) from a pretrained neural network are copied ('transferred') into a new model for a different but related task.
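A minimal sketch of layer transfer, assuming torchvision's ResNet-18 as both the source and target architecture: the early blocks are copied wholesale from the pretrained network into a freshly initialized one.

```python
import torch
from torchvision import models

pretrained = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
new_model = models.resnet18(weights=None)  # same architecture, random weights

# Copy the early, general-purpose blocks layer by layer (names as in torchvision).
for block_name in ("conv1", "bn1", "layer1", "layer2"):
    getattr(new_model, block_name).load_state_dict(
        getattr(pretrained, block_name).state_dict()
    )
# The copied layers can then be frozen or fine-tuned as described earlier.
```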
Domain-adversarial training (DAT) is a technique for unsupervised domain adaptation (tackles domain shift) that encourages a model to learn domain-invariant features by pitting two objectives against each other.
DAT works in the following order:
A feature extractor maps both source and target inputs into a shared feature space.
A label predictor is trained on these features using the labeled source data.
A domain discriminator is trained to tell whether a feature vector came from the source or the target domain.
Through gradient reversal, the feature extractor is updated to fool the discriminator, pushing it toward domain-invariant features.
For example, DAT forces a dog-detecting CNN’s internal features for source images (e.g., web-sourced dog photos) and target images (e.g., smartphone-captured dog photos) to look indistinguishable to a domain discriminator. That means said dog detector can better generalize from one style of dog photo to another.
By training the feature extractor on the source task and against the domain discriminator simultaneously, the model produces representations that remain useful for the source task. And by making the domains indistinguishable to the discriminator, users end up with features that generalize from source to target, despite distribution shifts.
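A minimal sketch of that adversarial objective in PyTorch is shown below; the feature extractor, label predictor, and domain discriminator are assumed user-defined modules, and the gradient reversal layer is what pits the two objectives against each other.

```python
import torch
import torch.nn as nn

class GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse (and scale) the gradient flowing back into the feature extractor.
        return -ctx.alpha * grad_output, None

def dat_loss(feature_extractor, label_predictor, domain_discriminator,
             x_src, y_src, x_tgt, alpha=1.0):
    feats_src = feature_extractor(x_src)
    feats_tgt = feature_extractor(x_tgt)

    # Objective 1: do well on the labeled source task.
    task_loss = nn.functional.cross_entropy(label_predictor(feats_src), y_src)

    # Objective 2: make source/target features indistinguishable (0 = source, 1 = target).
    feats = torch.cat([GradReverse.apply(feats_src, alpha),
                       GradReverse.apply(feats_tgt, alpha)])
    domains = torch.cat([torch.zeros(len(x_src), dtype=torch.long),
                         torch.ones(len(x_tgt), dtype=torch.long)])
    domain_loss = nn.functional.cross_entropy(domain_discriminator(feats), domains)

    return task_loss + domain_loss
```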
In cases where training data and test data involve different tasks, the goal of zero-shot learning (ZSL) is to help models generalize to new, unseen categories based on learned relationships and features.
For example, someone who has seen a horse, but does not know zebras, could likely recognize one knowing that a zebra looks like a horse with black-and-white stripes. ZSL assumes a semantic relationship between the seen and unseen classes.
ZSL uses zero labeled examples of the unseen classes (hence the name), relying on semantic descriptors to classify them via semantic matching.
Few-shot learning (FSL), unlike ZSL, learns from a small number K (e.g. 1-5) of labeled examples per novel class. It may use semantic information, but the primary signal comes from those few examples, whereas ZSL relies on semantic descriptors (e.g., attributes, word embeddings, textual definitions) that link the seen and unseen classes.
For example, FSL can learn to recognize a new breed of rabbit from just 3 labeled photos. Therefore, consider using FSL over ZSL when you can obtain a handful of labeled examples per new class and/or want a model that quickly adapts to new categories with minimal annotation effort.
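Here is a minimal sketch of the semantic-matching step in ZSL: an input's feature vector is compared with each class's semantic descriptor, and the closest class wins. The embeddings would normally come from pretrained encoders; random tensors and the toy class list stand in for them here.

```python
import torch
import torch.nn.functional as F

class_names = ["horse", "zebra", "rabbit"]      # suppose "zebra" has no training images
class_descriptors = torch.randn(3, 128)         # semantic descriptors (attributes, text)
image_features = torch.randn(5, 128)            # features for 5 test images

# Cosine similarity between every image and every class descriptor.
img_n = F.normalize(image_features, dim=-1)
cls_n = F.normalize(class_descriptors, dim=-1)
scores = img_n @ cls_n.T                        # shape (5, 3)

predictions = [class_names[i] for i in scores.argmax(dim=1)]
print(predictions)                              # each image mapped to its closest class
```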
Transfer learning with unlabeled data means first using self-supervised objectives (e.g., masking) to teach a model about language (or images, audio, etc.) in a label-free way, then transferring that distilled knowledge via fine-tuning to any downstream task where labels are available.
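A toy sketch of such a self-supervised masking objective, with an assumed six-word vocabulary and a tiny stand-in model: a token is hidden and the model is trained to predict it from context, so the text itself supplies the labels.

```python
import torch
import torch.nn as nn

vocab = {"<mask>": 0, "the": 1, "cat": 2, "sat": 3, "on": 4, "mat": 5}
sentence = torch.tensor([[1, 2, 3, 4, 1, 5]])      # "the cat sat on the mat"

masked = sentence.clone()
masked[0, 2] = vocab["<mask>"]                     # hide "sat"

model = nn.Sequential(                             # tiny stand-in for a Transformer encoder
    nn.Embedding(len(vocab), 32),
    nn.Linear(32, len(vocab)),
)
logits = model(masked)                             # (1, seq_len, vocab_size)

# The training signal comes from the text itself: predict the original hidden token.
loss = nn.functional.cross_entropy(
    logits[0, 2].unsqueeze(0), sentence[0, 2].unsqueeze(0)
)
loss.backward()
```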
For example, in language modeling with word embeddings, transfer learning from unlabeled data excels (compared to transfer from labeled data) in areas such as:
Abundant data: Developers can scrape billions of words from the web without costly annotation.
Broad coverage: Captures syntax, semantics, and world knowledge (e.g. analogies in word2vec such as "king – man + woman ≈ queen"; see the sketch after this list).
Domain agnostic: Trained on news, fiction, code, tweets—all in one. The embeddings generalize across many downstream tasks (QA, translation, sentiment).
Rich, hierarchical features: Lower layers learn word-level co-occurrences (embeddings); middle layers learn phrase patterns; higher layers learn discourse-level signals.
Data efficiency for downstream: A 1 million-example classification task only needs fine-tuning; you are not teaching the model "what a verb is" from scratch.
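The analogy mentioned in the list can be reproduced with pretrained vectors, for example via gensim's downloader (the model is fetched on first use); these embeddings were learned from unlabeled text alone.

```python
import gensim.downloader as api

# Load publicly available GloVe vectors trained on unlabeled web text.
vectors = api.load("glove-wiki-gigaword-100")

# "king - man + woman" should rank "queen" near the top of the results.
print(vectors.most_similar(positive=["king", "woman"], negative=["man"], topn=3))
```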