Multi-Head Attention Graph Network for Few-Shot Learning

Abstract: The majority of existing graph-network-based few-shot models focus on a node-similarity update mode. The lack of adequate information intensifies the risk of overtraining. In this paper, we propose a novel Multi-head Attention Graph Network (MAGN) to excavate discriminative relations and achieve effective information propagation. For the edge update, node-level attention evaluates the similarity between two nodes, while distribution-level attention extracts deeper global relations; the cooperation between these two parts provides a discriminative and comprehensive expression of the edge features. For the node update, we introduce label-level attention to soften the noise of irrelevant nodes and optimize the update direction. The proposed model is verified through extensive experiments on two few-shot benchmarks, MiniImageNet and CIFAR-FS. The results suggest that our method has strong noise immunity and quick convergence, and its classification accuracy outperforms most state-of-the-art approaches.

similarity between labeled and unlabeled data [12,13]; learn-to-model methods generate and update parameters by collaborating with proven networks [14,15]; learn-to-optimize methods fine-tune a base learner for fast adaptation [16]. Despite their diversity and efficacy, mainstream meta-learning models mostly concentrate on generalizing to unseen tasks with transferable knowledge, and few explore the inherent structured relations and regularities [17].
To remedy the drawback above, another line of work has focused on graph networks, which adopt structural representations to support relational reasoning for few-shot learning [17]. Early work constructed a complete graph to represent each task, where label information was propagated by updating node features through neighborhood aggregation [18]. Thereafter, more and more graph methods have been devoted to few-shot learning, such as the edge-labeling framework EGNN [19], the transductive inference method TPN [20], and the distribution propagation method DPGN [21]. With various features involved in the graph update, limited label information is converted into multiple forms and then repeatedly counted and aggregated, entailing many otherwise unnecessary costs [22]. Consequently, how to find discriminative information and realize effective propagation is a problem that urgently needs to be settled.

Figure 1: The overall framework of the MAGN model. In this figure, we present a 3-way 1-shot problem as an example. After the feature embedding module f_emb (details in Section 4.2.1), samples and their relations generate the initial graph. There are L generations in the GNN module (we show one of them for simplicity). Each generation consists of a node feature update and an edge feature update, with cooperation among the node attention, distribution attention and label attention. Solid circles represent support samples and hollow circles represent query samples. Squares indicate edge features, and the darkness of color denotes the value: the darker the color, the larger the value. The detailed process is described in Section 3.

In this paper, we propose a novel Multi-head Attention Graph Network (MAGN) to address the problem stated above, as shown in Fig. 1. In the process of updating the graph network, different weights are assigned to different neighbor nodes. Compared to the node-similarity-based weights of existing methods, we provide new insights into a multi-level fusion similarity mechanism with distribution features and label information to improve discriminative performance. More specifically, for the node update, we treat the label information as an initial adjacency matrix to soften the noise of irrelevant nodes, thereby constraining the update direction. For the edge update, we excavate the distribution feature by calculating the edge-level similarity over all samples; as feedback of global information, it reveals deeper relations. Together with the regular node-level attention, more valuable and discriminative relations are involved in the process of knowledge transfer. Furthermore, we verify the effectiveness of our method through extensive experiments on the MiniImageNet and CIFAR-FS datasets. The results show that MAGN achieves fast convergence and robustness while maintaining competitive accuracy.

Meta-Learning
Meta-learning, also known as "learning to learn," plays an essential role in addressing the issue of few-shot learning. According to the content of the learning system, it can be divided into three categories. Learn-to-measure methods, based on metric learning, employ an attention-based nearest-neighbor classifier using the similarity between labeled and unlabeled data: Matching Networks adopts a cosine similarity [15], and Prototypical Networks [12] establishes a prototype for each class and utilizes Euclidean distance as the metric. Differing from the above, Relation Net [13] devises a CNN-based relation metric network. Learn-to-optimize methods fine-tune a base learner for fast adaptation. MAML [16] is a typical approach that learns a good initialization for rapid generalization. Thereafter, various models were derived from MAML, such as the first-order gradient method Reptile [23], the task-agnostic method TAML [24], and the Bayesian method BMAML [25]. Learn-to-model methods generate and update parameters on the basis of proven networks. Meta-LSTM [26] embraces an LSTM network to update the meta-learner parameters. VERSA [27] builds a probabilistic amortization network to obtain softmax-layer weights. To predict weights, MetaOptNet [28] advocates an SVM, R2-D2 adopts a ridge regression layer [29], and Dynamic Net [30] uses a memory module.

Graph Attention Network
The attention mechanism is essential for a wide range of technologies, such as sequence learning, feature extraction, and signal enhancement [31]. Its core objective is to select, from abundant information, the information most critical to the current task. Early GCN works were limited by the Fourier-transform derivation, which made it challenging to deal with directed graphs and assigned indiscriminate, equal weights [32]. Given that, Yoshua Bengio and colleagues equipped the graph network with a masked self-attention mechanism [33]. During information propagation, it assigns a different weight to each node according to the neighbor distribution. Benefiting from this strategy, GAT can filter noisy neighbors and improve the performance of the graph framework. This idea was adopted and enhanced by GAAN [34], which combines two mechanisms: multi-head attention to extract diverse information and self-attention to gather it.

Model
In this section, we first summarize the preliminaries of few-shot classification following previous work and then describe our method in more technical detail.

Preliminaries
Few-shot learning: The goal of FSL is to train a reliable model with the capability of learning and generalizing from few samples. A common setting is the N-way K-shot classification task. Specifically, each task T consists of a support set S and a query set Q. There are N × K labeled samples in the support set, where N is the number of classes and K is the number of samples per class. Samples in the query set are unlabeled, but they belong to the N classes of the support set. The learning algorithm aims to produce a mapping from query samples to their labels.

Meta-Learning:
One of the main obstacles in FSL is overfitting caused by limited labeled data. Meta-learning adopts an episodic training strategy to make up for this, which increases generalization ability through extensive training on similar tasks. Given a training set D_train and a test set D_test with D_train ∩ D_test = ∅, each task T is randomly sampled from a task distribution P(T), i.e., T ∼ P(T). In the training stage, plenty of N-way K-shot classification tasks are sampled from D_train; through numerous training episodes on these tasks, a feasible classifier can be obtained. In the testing stage, samples of each task stem from D_test. Since tasks in D_train and D_test follow the same distribution P(T), such a classifier can generalize well to tasks sampled from D_test.
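As a rough illustration of this episodic setup, the following sketch samples an N-way K-shot task from a class-indexed image collection; the function name sample_task and the dictionary layout class_to_images are hypothetical conveniences, not part of the original method.

```python
import random

def sample_task(class_to_images, n_way=5, k_shot=1, n_query=15):
    """Sample one N-way K-shot episode from a dict {class_id: [image, ...]}."""
    classes = random.sample(list(class_to_images.keys()), n_way)
    support, query = [], []
    for label, cls in enumerate(classes):
        images = random.sample(class_to_images[cls], k_shot + n_query)
        support += [(img, label) for img in images[:k_shot]]
        query += [(img, label) for img in images[k_shot:]]
    return support, query
```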

Initialized GNN
Graph Neural Networks: In this section, we describe the overall framework of the proposed GNN, as shown in Fig. 1. First, we utilize an embedding module to extract features (details in Section 4.2.1); after that, each task is expressed as a fully connected graph. Through L layers of graph updates, the GNN realizes information transfer and relational reasoning. Specifically, the task T is formed into a graph G = (V, E), where each node v_i ∈ V denotes the embedded sample x_i in task T, and each edge e_{i,j} ∈ E corresponds to the relationship between two connected nodes v_i and v_j, where i, j = 1, 2, ..., F and F is the number of all samples in T, F = N × K + T.

Initial graph features: In the graph G = (V, E), node features are initialized as the output of the feature embedding module, v_i^0 = f_emb(x_i; θ_emb), where θ_emb is the parameter set of the embedding module f_emb. Edge features indicate the degree of correlation between two connected nodes, e_{i,j} ∈ [0, 1]. Given the label information, we set the edge features of labeled samples to the two extremes of intra-class and inter-class relations, while the edge features of unlabeled samples share the same relation to all others. The edge features are therefore initialized as in Eq. (1).
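Eq. (1) is not reproduced here, but following the description above (and the convention used in EGNN/DPGN-style models), a plausible initialization sets intra-class edges between labeled samples to 1, inter-class edges to 0, and any edge involving an unlabeled query to a uniform value such as 0.5. The sketch below assumes that convention.

```python
import torch

def init_edges(labels, num_support):
    """Assumed edge initialization E^0 for F samples.

    labels: LongTensor of shape (F,); only the first num_support entries
    are treated as labeled support samples, the rest as unlabeled queries.
    """
    num_total = labels.size(0)
    edges = torch.full((num_total, num_total), 0.5)              # edges touching a query: uniform value
    same_class = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
    is_support = torch.zeros(num_total, dtype=torch.bool)
    is_support[:num_support] = True
    labeled_pair = is_support.unsqueeze(0) & is_support.unsqueeze(1)
    edges[labeled_pair] = same_class[labeled_pair]               # 1 intra-class, 0 inter-class
    return edges
```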

Multi-Head Attention
The majority of existing few-shot graph models focus on a node-attention update mode, which adopts node similarity to control neighborhood aggregation. This mode ignores the inherent relationships among samples, which may lead to the risk of overtraining. Therefore, we propose a multi-head attention mechanism with distribution features and label information to enhance the model's capability.

Node-Level Attention
Like existing methods such as EGNN and DPGN, the node-level attention is based on the similarity between two nodes. Since each node has a different neighborhood, we apply a normalization operation over nodes in the same neighborhood to obtain more discriminative and comparable results. The node-level attention with node similarity is defined in Eqs. (2) and (3). In detail, given nodes v_i^k and v_j^k from the k-th layer, Att is a metric network with four Conv-BN-ReLU blocks that calculates the primary similarity of the two nodes. In Eq. (3), N(i) denotes the neighbor set of node v_i. We then apply a local normalization by softmax and obtain the final node similarity n_{i,j}^k.
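A minimal sketch of the node-level attention, assuming the Att metric network acts on the absolute difference between two node features (a common choice in EGNN-style models, not stated explicitly above) and is approximated here by a small MLP instead of the four Conv-BN-ReLU blocks; the softmax runs over each node's neighborhood as in Eq. (3).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NodeAttention(nn.Module):
    """Node-level attention: pairwise similarity followed by a per-row softmax."""
    def __init__(self, dim):
        super().__init__()
        # Stand-in for the Att metric network (four Conv-BN-ReLU blocks in the paper).
        self.att = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1))

    def forward(self, nodes):                                    # nodes: (F, dim)
        diff = (nodes.unsqueeze(1) - nodes.unsqueeze(0)).abs()   # (F, F, dim) pairwise |v_i - v_j|
        sim = self.att(diff).squeeze(-1)                         # raw similarity scores
        return F.softmax(sim, dim=-1)                            # n^k_{i,j}, normalized over N(i)
```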

Distribution-Level Attention
The node-level attention relies on the local relationship of node similarity, while the global relationship has not yet been fully investigated. To mine more discriminative information, we extract the global distribution feature by aggregating the edge features of all samples and then evaluate the similarity of the distribution features, as defined in Eqs. (4) and (5).
Here D_i^k is the distribution feature of node v_i^k at the k-th layer; it consists of all the edge features of v_i^k. Similarly, we obtain the distribution feature D_j^k of node v_j^k. Both are then sent to the Att network to assess the distribution similarity, and the same softmax operation is applied.
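In the same hedged spirit, the sketch below treats each node's distribution feature D_i^k as its row of edge values and scores pairs of distribution features with a small MLP standing in for the Att network; the exact form of Eqs. (4) and (5) may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistributionAttention(nn.Module):
    """Distribution-level attention over per-node edge-feature rows."""
    def __init__(self, num_samples):
        super().__init__()
        # Stand-in for the Att network applied to distribution features.
        self.att = nn.Sequential(nn.Linear(num_samples, num_samples), nn.ReLU(),
                                 nn.Linear(num_samples, 1))

    def forward(self, edges):                                    # edges: (F, F), row i is D^k_i
        diff = (edges.unsqueeze(1) - edges.unsqueeze(0)).abs()   # (F, F, F) pairwise |D_i - D_j|
        sim = self.att(diff).squeeze(-1)                         # raw distribution similarity
        return F.softmax(sim, dim=-1)                            # d^k_{i,j}, normalized per row
```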

Label-Level Attention
In previous work, although the aggregation scope is each node's neighborhood, it extends beyond the node's own class. Moreover, the graph-network update is a process of information interaction and fusion, which amplifies the noise from nodes of different classes. We therefore introduce an adjacency matrix to filter irrelevant information and constrain the update direction, as shown in Eq. (6).
Here A^k is the adjacency matrix at the k-th layer, A is the label adjacency matrix whose element a_{i,j} equals one when v_i and v_j have the same label and zero otherwise, and E^k is the matrix of edge features. Eq. (6) combines long-term label information with short-term updated edge features in a recurrent neural network. This operation prunes useless information from inter-class samples and distills useful information from intra-class samples.
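Eq. (6) is not reproduced above, so the sketch below only illustrates the stated idea: a recurrent cell that blends the static label adjacency A (long-term memory) with the current edge matrix E^k (short-term input). The GRU form and the row-wise treatment are assumptions.

```python
import torch.nn as nn

class LabelAttention(nn.Module):
    """Blend the label adjacency A with current edges E^k via a recurrent cell (assumed form of Eq. (6))."""
    def __init__(self, num_samples):
        super().__init__()
        self.cell = nn.GRUCell(num_samples, num_samples)

    def forward(self, edges, label_adj):          # both: (F, F)
        # Each row is treated as one element of a batch: edge rows are the input,
        # the corresponding label-adjacency rows act as the long-term hidden state.
        return self.cell(edges, label_adj)        # A^k, the updated adjacency matrix
```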

Feature Update
Information transmission is facilitated through the alternate update of node features and edge features. In particular, the update of a node feature depends on neighborhood aggregation, where edge features cooperate with label information to control the relation transformation, while the edge features of MAGN are determined by node similarity and the neighborhood distribution.
Based on the above update rules, the edge features at the (k+1)-th layer can be formulated as in Eq. (7), where conca/ave denotes how the two attention mechanisms are combined: conca means concatenation and ave means averaging. Here n_{i,j}^k is the node similarity from Eq. (3) and d_{i,j}^k is the distribution similarity from Eq. (5). The node vectors at the (k+1)-th layer can be formulated as in Eq. (8), where MLP_v is the node update network with two Conv-BN-ReLU blocks and a_{i,j}^{k+1} is the adjacency status of v_i and v_j at the (k+1)-th layer. It aggregates the node features of the neighbor set with the multi-head attention mechanism shown in Fig. 2.
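Eqs. (7) and (8) are not reproduced above, so the following sketch shows one plausible reading: the edge update fuses the two attention maps either by averaging or by a small learned fusion over their concatenation, and the node update aggregates neighbor features with weights a^{k+1}_{i,j} formed from the new edges and the label-level adjacency. The helpers mix and mlp_v are placeholders for the paper's convolutional blocks.

```python
import torch

def update_edges(node_sim, dist_sim, mode, mix=None):
    """Fuse node-level and distribution-level attention into new edges (assumed form of Eq. (7))."""
    if mode == "ave":
        return 0.5 * (node_sim + dist_sim)
    # "conca": stack the two maps and let a small learned layer (e.g. nn.Linear(2, 1)) fuse them.
    stacked = torch.stack([node_sim, dist_sim], dim=-1)          # (F, F, 2)
    return mix(stacked).squeeze(-1)                              # (F, F)

def update_nodes(nodes, edges, label_att, mlp_v):
    """Aggregate neighbors weighted by edges and the label-level adjacency (assumed form of Eq. (8))."""
    weights = edges * label_att                                  # a^{k+1}_{i,j}
    weights = weights / weights.sum(dim=-1, keepdim=True).clamp(min=1e-8)
    aggregated = weights @ nodes                                 # neighborhood aggregation
    return mlp_v(torch.cat([nodes, aggregated], dim=-1))         # MLP_v stands in for two Conv-BN-ReLU blocks
```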

Prediction
After L layers of node and edge feature updates, the classification result for node x_i is obtained from the prediction probability of the corresponding edge features at the final layer, e_{i,j}^L, via a softmax function, as in Eq. (9). In Eq. (9), δ(y_j = n) is the Kronecker delta, which outputs one if y_j = n and zero otherwise, and P(ŷ_i = n | v_i) stands for the predicted probability that v_i belongs to the n-th category.
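A small sketch of Eq. (9) under the assumption that only edges to labeled support samples contribute to the class scores, which matches the role of the Kronecker delta described above.

```python
import torch
import torch.nn.functional as F

def predict(final_edges, support_labels, n_way, num_support):
    """Class probabilities per node from final-layer edges e^L_{i,j} (assumed form of Eq. (9))."""
    one_hot = F.one_hot(support_labels, n_way).float()           # (num_support, N): delta(y_j = n)
    scores = final_edges[:, :num_support] @ one_hot              # sum_j e^L_{i,j} * delta(y_j = n)
    return F.softmax(scores, dim=-1)                             # P(yhat_i = n | v_i)
```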

Training
During episodic training, the parameters of the proposed GNN are trained in an end-to-end manner. The final objective is to minimize the total loss computed over all layers, as shown in Eq. (10), where λ_k is the weight of the k-th layer, L_E is the cross-entropy loss function, P(ŷ_i | v_i^k) is the probability prediction for sample x_i at the k-th layer, and y_i is the ground-truth label.
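A minimal sketch of the layer-weighted objective in Eq. (10); it assumes the per-layer outputs are already probabilities (as produced by the softmax in Eq. (9)), so the cross-entropy is computed from their logs.

```python
import torch
import torch.nn.functional as F

def total_loss(per_layer_probs, targets, layer_weights):
    """Weighted sum of cross-entropy losses over all L layers (assumed form of Eq. (10))."""
    loss = torch.zeros(())
    for probs, lam in zip(per_layer_probs, layer_weights):
        # probs: (num_query, N) predicted probabilities at layer k; targets: (num_query,) labels.
        loss = loss + lam * F.nll_loss(torch.log(probs.clamp(min=1e-8)), targets)
    return loss
```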

Experiments
For a fair comparison, we evaluate our method on two standard few-shot learning datasets following the experimental settings of EGNN and compare against state-of-the-art approaches.

Datasets
MiniImageNet is a typical few-shot benchmark. As a subset of ImageNet, it is composed of 60,000 images uniformly distributed over 100 classes. All images are RGB with size 84 × 84 × 3. Following the splits provided by [26], we randomly select 64 classes for training, 16 for validation, and 20 for testing.
CIFAR-FS is derived from the CIFAR-100 dataset. Like MiniImageNet, it consists of 100 classes with 600 images each, split into 64, 16, and 20 classes for training, validation, and testing, respectively. In particular, the low resolution (32 × 32) and high inter-class similarity make the classification task technically challenging.
Before training, both datasets undergo data augmentation with transformations such as horizontal flip, random crop, and color jitter (brightness, contrast, and saturation).
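A possible torchvision pipeline matching the listed transformations; the crop padding and jitter strengths are assumptions not given in the text (use a 32-pixel crop for CIFAR-FS).

```python
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomCrop(84, padding=8),                                   # random crop (84 for MiniImageNet)
    transforms.RandomHorizontalFlip(),                                      # horizontal flip
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),   # color jitter
    transforms.ToTensor(),
])
```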

Embedding Network
We adopt ConvNet and ResNet12 as the backbone embedding modules. Following the settings used in [19,23], the ConvNet architecture contains four convolutional blocks, each composed of a 3 × 3 convolution, batch normalization, 2 × 2 max-pooling, and a LeakyReLU activation. Similar to ConvNet, ResNet12 also has four blocks, with each convolutional block replaced by a residual block.
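A compact sketch of the ConvNet backbone described above; the channel widths are an assumption (the text does not specify them), and the ResNet12 variant is omitted.

```python
import torch.nn as nn

def conv_block(in_ch, out_ch):
    """One ConvNet block: 3x3 convolution, batch norm, 2x2 max-pooling, LeakyReLU."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.MaxPool2d(2),
        nn.LeakyReLU(0.2),
    )

class ConvNetEmbedding(nn.Module):
    """Four-block ConvNet embedding; 64/96/128/256 channels are assumed, not specified in the text."""
    def __init__(self):
        super().__init__()
        self.blocks = nn.Sequential(
            conv_block(3, 64), conv_block(64, 96),
            conv_block(96, 128), conv_block(128, 256),
        )

    def forward(self, x):                        # x: (B, 3, 84, 84) for MiniImageNet
        return self.blocks(x).flatten(1)         # flattened per-image embedding
```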

Parameter Settings
We evaluate MAGN on 5-way 1-shot and 5-way 5-shot classification tasks on both benchmarks. The proposed GNN model has three layers. In the meta-training stage, each batch consists of 60 tasks, while in the meta-testing stage, each batch contains ten tasks. During training, we adopt the Adam optimizer with an initial learning rate of 5 × 10^{-4} and a weight decay of 10^{-6}. The dropout rate is set to 0.3, and the loss coefficient is 1. The results of our proposed model are obtained over 100k iterations on MiniImageNet and CIFAR-FS.
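The training configuration above can be summarized in code as follows; `model` is a placeholder for the full MAGN network.

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)   # placeholder standing in for the full MAGN model
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-6)

num_layers = 3            # generations in the GNN module
dropout_rate = 0.3        # applied inside the update networks
loss_coefficient = 1.0    # lambda_k for every layer
tasks_per_train_batch = 60
tasks_per_test_batch = 10
total_iterations = 100_000
```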

Main Results
We compare our approach with recent state-of-the-art models; the main results are listed in Tabs. 1 and 2. According to the embedding architecture, the backbones are divided into ConvNet, ResNet12, ResNet18, and WRN28, the major difference being the number of residual blocks. In addition, GNN-based methods are listed separately for the sake of intuition. Extensive results show that MAGN yields strong performance on both datasets. For example, among all ConvNet-architecture methods, MAGN is substantially better than the others. Although the results are slightly lower than DPGN, we still rank second with a narrow gap for both backbones. Note, however, that some common graph network methods such as EGNN and DPGN train and test with labels in a consistent order; for example, in the 5-way 1-shot task, the labels run from the support set (0, 1, 2, 3, 4) to the query set (0, 1, 2, 3, 4). The learning system may thus learn the order of the task rather than the relations among samples. To avoid this effect, we shuffle the label order of the support set and query set. This setup makes our results less than optimal, but it is closer to real-world conditions. The proposed MAGN acquires a robust result that is not biased by the noise of label order.

Ablation Study
Effect of data shuffling mode: There are three ways to scramble the data: shuffle the support set, shuffle the query set, or shuffle both sets. We conduct a 5-way 1-shot trial with label-node attention on MiniImageNet. The comparative results are shown in Tab. 3. As we can see, the data shuffling mode has little effect on the accuracy rate, but it makes a difference to the convergence time, which is consistent with the nature of random selection. To further explore the convergence performance of the model, the default setting is to shuffle the order of both sets.
Effect of different attention: The main ablation results for the different attention components are shown in Fig. 3. All variants are evaluated on the 5-way 1-shot classification task of MiniImageNet. The baseline adopts only node attention ("NodeAtt"). On this basis, the variant "DisNode" adds distribution-level attention to assist the edge update. For samples in the same class, the surrounding neighborhoods follow a similar distribution, so the "DisNode" model can mine more discriminative relationships between two nodes and obtains an improvement in accuracy. Besides, concatenating aggregation is superior to average aggregation; this advantage extends to the final combination of all three attentions, with the gain rising from 0.49 ("CatDisNode" vs. "AveDisNode") to 0.85 ("Cat3Att" vs. "Ave3Att"). The variant "LabNode" equips the node update with label-level attention, leading to a considerable improvement in convergence iterations, from 89k to 63k. We attribute this to the filtering capability of the label adjacency matrix, which constrains the update direction and realizes fast convergence.

Effect of Layers:
In a GNN, the depth of the network influences feature extraction and information transmission. To explore this, we perform 5-way 1-shot experiments with different numbers of layers. As shown in Tab. 4, the accuracy rate and convergence time improve steadily as the network deepens. To manage the trade-off between convergence and accuracy, a 3-layer GNN is configured for our models.

Conclusion
In this paper, we propose a Multi-head Attention Graph Network for few-shot learning. The multiple attention mechanism includes three parts: node-level attention explores the similarity between two nodes, and distribution-level attention extracts deeper global relations; the cooperation between these two parts provides a discriminative expression of the edge features. Meanwhile, label-level attention, serving as a filter, weakens the noise of inter-class information during the node update and accelerates the convergence process. Furthermore, we scramble the training data of the support set and query set to guarantee the transfer of order-agnostic knowledge. Extensive experiments on few-shot benchmark datasets validate the accuracy and efficiency of the proposed method.

Conflicts of Interest:
The authors declare that they have no conflicts of interest to report regarding the present study.