- The paper introduces Adversarial Feature Alignment (AFA), a novel method using adversarial and MMD-based feature alignment to prevent catastrophic forgetting in incremental lifelong learning.
- AFA employs a two-stream network architecture and aligns features at different levels using a GAN-like adversarial game for low-level visual features and Maximum Mean Discrepancy (MMD) for high-level semantic features.
- Experiments demonstrate that AFA significantly mitigates catastrophic forgetting on old tasks while achieving high accuracy on new tasks, often outperforming state-of-the-art methods like LwF, EWC, SI, and MAS.
The paper introduces a novel activation regularization method, termed Adversarial Feature Alignment (AFA), to mitigate catastrophic forgetting in incremental multi-task image classification scenarios. The core idea is to use intermediate activations of a pre-trained model, which encapsulates knowledge from previous tasks, as soft targets to guide the training process when adapting to new data.
AFA's framework comprises a two-stream model representing the old and new networks. Beyond the cross-entropy loss for the new task and the distillation loss between classification probabilities, the method leverages both low-level visual features and high-level semantic features as soft targets. This multilevel feature alignment is intended to provide more comprehensive supervisory information about the old tasks.
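As a rough illustration of this setup (not the authors' code; the function and variable names are assumptions), the two streams can be built as a frozen copy of the previously trained model plus a trainable copy that is adapted to the new task:

```python
import copy
import torch.nn as nn

def make_two_stream(trained_model: nn.Module):
    """Build the old (frozen) and new (trainable) streams from a model trained
    on the previous tasks. Both streams see the same new-task images; the old
    stream only supplies soft targets."""
    old_model = copy.deepcopy(trained_model)
    for p in old_model.parameters():
        p.requires_grad_(False)               # never updated during new-task training
    old_model.eval()                          # fixed reference network
    new_model = copy.deepcopy(trained_model)  # initialized from old weights, then fine-tuned
    return old_model, new_model
```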
The method aligns convolutional visual features by introducing a trainable discriminator network that plays a GAN-like minimax game with the feature extractors of the old and new models. The discriminator tries to tell apart the latent representations of the old and new networks, where each representation is obtained by applying an activation-based mapping function to the network's convolutional feature maps. The mapping function $F_{att}$ takes a 3D tensor $A \in \mathbb{R}^{C \times H \times W}$ as input and outputs a spatial attention map; it is defined as
$$F_{att}(A) = \sum_{ch=1}^{C} \lvert A_{ch} \rvert^{2}, \quad A_{ch} \in \mathbb{R}^{H \times W},$$

where $A_{ch}$ is the $ch$-th feature map of the activation tensor $A$.
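In code, this mapping is simply a channel-wise sum of squared activations; a minimal sketch (the batched shape handling is an assumption for convenience):

```python
import torch

def f_att(activation: torch.Tensor) -> torch.Tensor:
    """F_att: collapse a (B, C, H, W) activation tensor into a (B, H, W) spatial
    attention map by summing the squared absolute values over the channel axis."""
    return activation.abs().pow(2).sum(dim=1)
```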
The discriminator is optimized via:
$$\mathcal{L}_{adv_D} = \max_{D}\; \mathbb{E}_{z^{*} \sim Z^{*}}\!\left[\log D(z^{*})\right] + \mathbb{E}_{z \sim Z}\!\left[\log\left(1 - D(z)\right)\right]$$
where $Z^{*}$ and $Z$ are the latent representations from the old and new feature extractors, respectively.
The feature extractor $F$ is updated by playing a minimax game with the discriminator $D$ via

$$\mathcal{L}_{adv_F} = \min_{F}\; -\mathbb{E}_{z \sim Z}\!\left[\log D(z)\right].$$
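A hedged sketch of the two adversarial objectives, written as binary cross-entropy losses to be minimized (equivalent to the max/min formulations above). Here `discriminator` is any small trainable network producing one logit per sample, and `z_old` / `z_new` are the attention-map representations of the old and new feature extractors; these names and shapes are illustrative assumptions rather than the paper's implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def discriminator_loss(discriminator: nn.Module,
                       z_old: torch.Tensor,
                       z_new: torch.Tensor) -> torch.Tensor:
    """D step: maximize E[log D(z*)] + E[log(1 - D(z))], implemented as
    minimizing the corresponding binary cross-entropy."""
    logit_old = discriminator(z_old.detach())   # "real" samples from the old network
    logit_new = discriminator(z_new.detach())   # "fake" samples from the new network
    loss_real = F.binary_cross_entropy_with_logits(logit_old, torch.ones_like(logit_old))
    loss_fake = F.binary_cross_entropy_with_logits(logit_new, torch.zeros_like(logit_new))
    return loss_real + loss_fake

def feature_extractor_adv_loss(discriminator: nn.Module,
                               z_new: torch.Tensor) -> torch.Tensor:
    """F step: minimize -E[log D(z)], i.e. push the new network's attention
    maps to be indistinguishable from the old network's."""
    logit_new = discriminator(z_new)            # gradients flow into the new extractor
    return F.binary_cross_entropy_with_logits(logit_new, torch.ones_like(logit_new))
```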
The paper aligns high-level semantic features using Maximum Mean Discrepancy (MMD). MMD is expressed as the distance between the means of two data distributions $P$ and $Q$ after mapping into a reproducing kernel Hilbert space (RKHS):

$$\mathrm{MMD}^{2}(P, Q) = \left\lVert \mathbb{E}_{p \sim P}[\phi(p)] - \mathbb{E}_{q \sim Q}[\phi(q)] \right\rVert^{2}$$

where $\phi(\cdot)$ denotes the mapping into the RKHS.
An unbiased estimator of $\mathrm{MMD}^{2}$ is given by:

$$\mathcal{L}_{mmd}(P, Q) = \mathbb{E}_{p, p' \sim P,\; q, q' \sim Q}\left[k(p, p') + k(q, q') - 2\,k(p, q)\right]$$

where $k(p, q) = \langle \phi(p), \phi(q) \rangle$ is the kernel function and $p'$, $q'$ are independent samples drawn from $P$ and $Q$.
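For concreteness, a small PyTorch sketch of this estimator over two batches of fully connected features, using a Gaussian RBF kernel (the kernel choice and bandwidth are assumptions, not details taken from the paper):

```python
import torch

def mmd2_unbiased(p: torch.Tensor, q: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Unbiased empirical MMD^2 between feature batches p (old network) and
    q (new network), each of shape (batch, dim), with an RBF kernel."""
    def rbf(x, y):
        dist2 = torch.cdist(x, y).pow(2)               # pairwise squared distances
        return torch.exp(-dist2 / (2.0 * sigma ** 2))  # Gaussian kernel matrix
    k_pp, k_qq, k_pq = rbf(p, p), rbf(q, q), rbf(p, q)
    n, m = p.size(0), q.size(0)
    # drop diagonal entries so E[k(p, p')] is estimated from distinct samples
    term_pp = (k_pp.sum() - k_pp.diagonal().sum()) / (n * (n - 1))
    term_qq = (k_qq.sum() - k_qq.diagonal().sum()) / (m * (m - 1))
    return term_pp + term_qq - 2.0 * k_pq.mean()
```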
The overall loss function is a weighted sum:
$$\mathcal{L} = \mathcal{L}_{cls} + \lambda_{1}\,\mathcal{L}_{dist} + \lambda_{2}\,\mathcal{L}_{adv_F} + \lambda_{3}\,\mathcal{L}_{fc}$$

where $\mathcal{L}_{cls}$ is the cross-entropy loss, $\mathcal{L}_{dist}$ is the distillation loss, $\mathcal{L}_{adv_F}$ is the adversarial loss for feature alignment, and $\mathcal{L}_{fc}$ is the MMD loss for high-level feature alignment.
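Putting the pieces together, one training step's objective might be assembled as below, reusing the helper functions sketched earlier; the distillation formulation (temperature-scaled KL divergence) and the default weights are assumptions rather than values reported in the paper. The discriminator itself is updated in a separate step with `discriminator_loss`.

```python
import torch.nn.functional as F

def afa_total_loss(new_task_logits, labels,
                   old_head_logits_new, old_head_logits_old,
                   z_new, fc_new, fc_old,
                   discriminator,
                   lambdas=(1.0, 1.0, 1.0), temperature=2.0):
    """Weighted sum of the four AFA loss terms for one batch."""
    l_cls = F.cross_entropy(new_task_logits, labels)
    # distillation on the old-task outputs of the two streams (temperature-scaled KL)
    l_dist = F.kl_div(F.log_softmax(old_head_logits_new / temperature, dim=1),
                      F.softmax(old_head_logits_old / temperature, dim=1),
                      reduction="batchmean") * temperature ** 2
    l_adv_f = feature_extractor_adv_loss(discriminator, z_new)   # low-level alignment
    l_fc = mmd2_unbiased(fc_old, fc_new)                         # high-level alignment
    lam1, lam2, lam3 = lambdas
    return l_cls + lam1 * l_dist + lam2 * l_adv_f + lam3 * l_fc
```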
The paper details experiments in incremental task scenarios, including two-task settings (starting from the ImageNet or Oxford Flowers dataset) and a five-task setting (Scenes, Birds, Flowers, Aircraft, and Cars datasets), comparing AFA to joint training, fine-tuning, Learning without Forgetting (LwF), Encoder-Based Lifelong Learning (EBLL), Elastic Weight Consolidation (EWC), Synaptic Intelligence (SI), and Memory Aware Synapses (MAS).
Key findings include:
- AFA generally suffers the least performance drop on old tasks while achieving high accuracy on new tasks.
- AFA and LwF outperform joint training when the new tasks have much smaller datasets than the initial dataset (ImageNet), likely because the regularization toward the old model helps prevent overfitting on the small new datasets.
- Parameter regularization strategies (EWC, SI, MAS) can struggle when tasks have different output domains or start from small datasets.
Ablation studies demonstrate the individual contributions of adversarial attention alignment and MMD-based high-level feature alignment. The paper also explores alternative constraints for visual and fully connected features, such as L2 regularization, and provides implementation details, including network architecture, training parameters, and hyperparameter selection.