Domain Generalization via Adversarially Learned Novel Domains

This study focuses on the domain generalization task, which aims to learn a model that generalizes to unseen domains by utilizing multiple training domains. More specifically, we follow the idea of adversarial data augmentation, which aims to synthesize and augment training data with “hard” domains to improve the model’s domain generalization ability. However, previous studies augmented training data only with samples similar to the training data, resulting in limited generalization ability. To alleviate this issue, we propose a novel adversarial data augmentation method, termed GADA (generative adversarial domain augmentation), which employs an image-to-image translation model to obtain a distribution of novel domains that are semantically different from the training domains, and, at the same time, hard to classify. Evaluation and further analysis suggest that GADA fits our expectation; adversarial data augmentation with semantically different samples leads to better domain generalization performance.

are available [6], [7], [8], [9]. Although samples from the target domain cannot be obtained, zero-shot domain adaptation methods generally assume that part of the information about the target domain can be obtained. They then utilize such information to simulate the target domain and focus on training a model that can achieve good performance on a certain target domain. Assume we have an autonomous driving model trained on sunny days, and we want this model to perform well in different weather conditions. In this situation, a zero-shot domain adaptation algorithm can improve the model's performance on the target domain by knowing only the weather of the target domain, such as rainy weather [9]. However, a certain aspect of the target domain may not be easily obtained in reality. For example, an autonomous driving model will encounter various types of road conditions in practice. It is often difficult to specify the road conditions that will be faced in the future. In this case, applying zero-shot domain adaptation is difficult.

The domain generalization (DG) task is designed to solve the above issue. Specifically, the goal of DG is to improve the model's generalization ability to data drawn from unseen domains.

VOLUME 10, 2022. This work is licensed under a Creative Commons Attribution 4.0 License. For more information, see https://creativecommons.org/licenses/by/4.0/

generative model. They report that their proposal, termed L2A-OT, can improve the model's generalization ability. [5] proposed a generative model named the domain flow generator, termed DLOW, which can synthesize intermediate domains among semantically different domains. They report that augmenting training data with samples uniformly drawn from the synthesized domains is beneficial to improving the classification model's domain generalization ability.
In principle, these augmentation strategies are designed to augment the training samples with samples that widely cover the mixture of the given multiple domains. This strategy is expected to improve generalization ability even for unseen domains that are semantically different from the original training domains. However, these works have never considered data augmentation with the worst-case distribution using generative models.

A recent work [12] proposed a framework called MBDG. This work considered the same training objective as adversarial data augmentation: improving the classification model's performance on the worst-case domain. However, unlike [10], which attempted to solve a minimax optimization problem, [12] made several assumptions on the data generation process and the labeling mechanism to transform the minimax optimization problem into a relaxed problem that removes the inner maximization. They then proposed an algorithm to solve the relaxed problem by using a generative model that simulates the assumed data generation process. This work is related to both adversarial data augmentation and generative model-based augmentation. It utilizes samples generated by a generative model to improve the model's performance over the worst-case distribution. Empirically, they report that their algorithm improves the classification model's domain generalization ability on several datasets. However, their experimental results also show that MBDG leads to good domain generalization ability only in certain domains, such as the sketch domain in the PACS dataset, while in other domains the method does not perform well. We speculate that this is because MBDG makes strong assumptions about the data generation process and the main classifier. If the target domain and the classifier satisfy their assumptions, then the improvement in domain generalization ability is significant; otherwise, the improvement is marginal. Moreover, the training objective of the main classifier needs to satisfy several special assumptions. Hence, this method cannot be directly combined with other DG algorithms.
In contrast, our work is a flexible data augmentation method that can be combined with most DG algorithms, and our method does not require strong assumptions on the target domain or the main classifier.

This paper is an extended version of a preliminary conference report [13]. In this paper, we present a novel data augmentation method, termed GADA (Generative Adversarial Domain Augmentation). As already discussed, adversarial data augmentation aims to improve the model's performance

the domain-code distribution; the second is using a two-layer neural network, termed the domain-code NN, which takes Gaussian noise as input and whose output represents the domain-code distribution. Then, we optimize the parameters of the Dirichlet distribution or the parameters of the domain-code NN so that its output represents a distribution that gives the lowest predictive performance with the current classification model. Additionally, we train the classification model with training data augmented with samples drawn from the worst-case distribution represented by the learnable Dirichlet distribution or the domain-code NN. We denote our framework with a learnable Dirichlet distribution as GADA-D and the other as GADA-NN.

Experimental results show that GADA-D and GADA-NN surpass current data augmentation methods in three DG tasks with multiple training domains in most cases. Furthermore, ablation studies and visualization results show that the proposed adversarial training strategy can learn a 'hard' domain-code distribution, which helps achieve good domain generalization ability.

The novelty of this study is summarized as follows.
1) We propose a flexible framework that utilizes the worst-case data-generating distribution represented by a generative model. More specifically, we propose utilizing domain codes to control the behavior of the generative model. The data-generating distribution is consequently adapted by optimizing the domain-code distribution so that it helps to improve domain generalization ability.
2) Our proposal serves as a data augmentation-based method. Due to the flexibility of the data augmentation methodology, our framework can be easily combined with other domain generalization algorithms. We remark that, to the best of our knowledge, our proposal is the first data augmentation-based method that exploits the worst-case data-generating distribution characterized by generative models that can represent semantically different domains.

3) The experimental results suggest that the proposed method shows superior domain generalization ability.

proposed augmenting training data with features synthesized by linear combinations of features learned by the classifier. We will compare our proposal with this method in the experimental section.

Several works attempt to augment training data without distinguishable features [17], [18].

There are several works that follow generative model-based data augmentation. In addition to those discussed above, [19], [20] proposed jointly training the generative model for data augmentation and the classification model to improve the classification model's domain generalization ability. During the training process, [19] attempted to train the generator so that the mutual information between the synthesized images and the source images is maximized. Reference [20] attempted to train the generator so that the divergence between features extracted from synthesized images and from the original training data is maximized. Although these works considered different strategies to train the generator, they both aimed to improve the diversity of the generated images. We also make certain modifications to the generative model to encourage the diversity of synthesized images, but we additionally aim to augment training data with the worst-case distribution represented by the generative model.

Our previous work [13] (GADA) combined the advantages of adversarial data augmentation and generative model-based data augmentation. GADA utilizes a neural network, termed the domain-code NN, to control the data-generating distribution of a generative model. Then, by optimizing the domain-code NN, the generative model provides hard samples to the main classifier. This paper is an expanded version of [13]. In this work, we consider an additional strategy, which utilizes a trainable Dirichlet distribution to control the data-generating distribution of a generative model. This new strategy enables more efficient optimization. More details are given in Section IV-B. In terms of experiments, we compare our proposal with more advanced DG methods [12], [19], [20] in this paper. For a better comparison, we also report the performance of our proposal under an additional experimental setting, which assumes that we have sufficient computational resources. More details are given in Section V.

Another common approach to tackling the DG problem is domain-invariant feature learning [4], [23], [24], [25], [26]. The intuition of these works is that invariant features learned from different domains can generalize to unseen domains. For instance, a popular technique in this line is domain adversarial learning [4] and its variants [23], [24], which attempt to learn a representation that cannot be distinguished among multiple domains. In addition, a method that has recently gained attention is invariant risk minimization (IRM) [25], [26], which aims to learn a representation such that the optimal linear predictor on top of this representation

Each element z_k represents the degree of relevance to the target domain T_k. By specifying z at generation time, DLOW can convert an image in the source domain into an image with a style specified by z, which realizes interpolation among all training domains. One illustration of the domain codes is shown in Figure 1.

The objective function of the DLOW model contains an adversarial loss term and a cycle consistency loss, and DLOW results in two maps: G_ST, which learns a mapping from the source domain S to the target domains T_1, ..., T_K, and G_TS, which learns a mapping from T_1, ..., T_K to S. G_TS is used to assist the training of G_ST by preserving categorical semantics, which is reflected in the cycle consistency loss. We only utilize G_ST in our proposal, so we take G_ST as an example to explain DLOW. For more details, please refer to [5].

For K target domains, the DLOW model introduces K discriminators, D_{T_1}, ..., D_{T_K}, that distinguish generated images from images of domains T_1, ..., T_K, respectively. Additionally, the discriminator D_S for the source domain S distinguishes generated images from images of the source domain. Then, for each k, the adversarial loss between S and T_k can be formulated as

L_adv^k = E_{x'~T_k}[ log D_{T_k}(x') ] + E_{x~S}[ log(1 - D_{T_k}(G_ST(x, z))) ].  (1)

The full adversarial loss weights each term by the corresponding component of the domain code:

L_adv = Σ_{k=1}^{K} z_k L_adv^k.  (2)

The cycle consistency loss between S and T_k is defined as

L_cyc = E_{x~S}[ || G_TS(G_ST(x, z), z) - x ||_1 ].  (3)

Then the full objective L is defined as L = L_adv + λ L_cyc, where λ is a hyperparameter used to balance the two losses.
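The domain-code-weighted adversarial loss above can be sketched as follows. This is a minimal numpy illustration with toy discriminator outputs; the function names, shapes, and toy values are our own assumptions, not DLOW's actual implementation.

```python
import numpy as np

def adversarial_loss_k(d_real, d_fake):
    """Per-domain GAN loss: E[log D_Tk(x')] + E[log(1 - D_Tk(G_ST(x, z)))]."""
    eps = 1e-8  # numerical safety for log
    return np.mean(np.log(d_real + eps)) + np.mean(np.log(1.0 - d_fake + eps))

def full_adversarial_loss(z, d_real_per_domain, d_fake_per_domain):
    """Domain-code-weighted sum over the K target domains: sum_k z_k * L_adv^k."""
    assert np.isclose(np.sum(z), 1.0), "domain code z must lie on the simplex"
    return sum(zk * adversarial_loss_k(dr, df)
               for zk, dr, df in zip(z, d_real_per_domain, d_fake_per_domain))

# Toy discriminator outputs (probabilities) for K = 2 target domains.
z = np.array([0.7, 0.3])
d_real = [np.array([0.9, 0.8]), np.array([0.6, 0.7])]
d_fake = [np.array([0.2, 0.1]), np.array([0.3, 0.4])]
loss = full_adversarial_loss(z, d_real, d_fake)
```

Because z lies on the simplex, a domain code concentrated on one T_k makes the generator compete mainly against that domain's discriminator.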

An overview of GADA with the domain-code NN is shown in Figure 2. As discussed above, we propose controlling the domain-code distribution to control the behavior of the generative model so that we can find the worst-case distribution. Hence, we need a generative model that is conditioned on domain codes. The DLOW model already satisfies our requirements. Empirically, however, we found that the images synthesized by DLOW lacked diversity: the style of the generated images does not necessarily change when the domain code changes.

[28] showed that the cycle consistency loss limits the diversity of the generated images. The cycle consistency loss forces generated images to be recoverable to the original image, so the generated images usually look like the original images. Hence, [28] proposed a weak version of the cycle consistency loss. Instead of forcing the generated images themselves to be recoverable, they enforce that the features extracted by the discriminator can be recovered. Inspired by this weak version of the cycle consistency loss, we replace the cycle consistency loss of the original DLOW model with the weak cycle consistency loss. We call this new model DLOW-B (domain flow generator with better cycle). The weak cycle consistency loss is

L_wcyc = E_{x~S}[ || f_{D_S}(G_TS(G_ST(x, z), z)) - f_{D_S}(x) ||_1 + γ || G_TS(G_ST(x, z), z) - x ||_1 ].  (4)

Here, f_{D_S}(·) is the feature extractor of the source domain's discriminator, and γ is a hyperparameter used to balance the pixel-level and feature-level cycle consistency terms.
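As a concrete sketch of the weak cycle consistency term, the following numpy code computes a feature-level L1 term plus a γ-weighted pixel-level L1 term. Here `feat_fn` merely stands in for the discriminator feature extractor f_{D_S}, and the exact arrangement of the two terms is our assumption.

```python
import numpy as np

def weak_cycle_loss(x, x_rec, feat_fn, gamma=0.1):
    """Weak cycle consistency (sketch of Eq. (4)): an L1 term on discriminator
    features plus a gamma-weighted pixel-level L1 term. x_rec stands for the
    reconstruction G_TS(G_ST(x, z), z)."""
    feature_term = np.mean(np.abs(feat_fn(x_rec) - feat_fn(x)))
    pixel_term = np.mean(np.abs(x_rec - x))
    return feature_term + gamma * pixel_term

# Toy example: a random "image" and a slightly perturbed reconstruction,
# with column averaging standing in for f_D_S.
rng = np.random.default_rng(0)
x = rng.uniform(size=(8, 8))
x_rec = x + 0.01 * rng.normal(size=(8, 8))
loss = weak_cycle_loss(x, x_rec, feat_fn=lambda im: im.mean(axis=0))
```

A small γ keeps the pixel-level constraint weak, which is the point of DLOW-B: reconstructions only need to match in the discriminator's feature space, leaving more freedom for stylistic diversity.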

The comparison of the image quality of DLOW-B and DLOW is shown in Figure 3.

In this study, we aim to train the classifier with the min-max optimization objective (5). More specifically, during the training process, we aim to optimize P_G, the data-generating distribution of the generative model G, so that the loss function is maximized. Additionally, we optimize the parameters of the classification model F so that the loss of the classification task is minimized. We can interpret this optimization as distributionally robust optimization [29]:

min_{θ_F} sup_{P_G} E_{(x,y)~P_G}[ ℓ(θ_F; (x, y)) ].  (5)
Here, θ_F denotes the parameters of the classification model F.

Furthermore, considering that the Dirichlet distribution may not be able to represent a complex distribution over the simplex, we provide another strategy. The success of deep learning in recent years has shown that neural networks have great representation ability. In particular, GAN [30] showed that a neural network can learn a complex distribution from finite samples. Hence, we propose to use a neural network g, termed the domain-code NN, to represent the distribution of domain codes. Inspired by [30], we introduce a Gaussian prior c ~ N(0, 1) to inject randomness into the neural network. We design the last layer of the domain-code NN as a softmax layer so that it outputs a probability vector. Therefore, by choosing c randomly, g(c) acts as a distribution over the domain-code space. Although the number of parameters of the domain-code NN is large, the domain-code NN has better representation ability than a Dirichlet distribution. When the multiple training domains are complex, this strategy can better represent the domain-code distribution. We call our framework with this strategy GADA-NN.
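A minimal numpy sketch of the domain-code NN g follows: Gaussian noise is pushed through two layers and a softmax, so each draw of c yields a point on the simplex. The layer sizes, noise dimension, and weight scale are illustrative assumptions, not choices from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
K = 3            # number of training domains
noise_dim = 8    # dimension of the Gaussian prior (illustrative choice)

# Two-layer "domain-code NN" g with a softmax output layer.
W1, b1 = 0.1 * rng.normal(size=(noise_dim, 16)), np.zeros(16)
W2, b2 = 0.1 * rng.normal(size=(16, K)), np.zeros(K)

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

def g(c):
    h = np.tanh(c @ W1 + b1)       # hidden layer
    return softmax(h @ W2 + b2)    # probability vector: a domain code z

# Drawing c ~ N(0, I) and mapping it through g induces a distribution over
# domain codes; training would adjust W1, b1, W2, b2 adversarially.
z = g(rng.normal(size=noise_dim))
```

Because g is an arbitrary nonlinear map, the induced distribution over the simplex can be multimodal or concentrated in ways a single Dirichlet cannot express, which is the motivation given for GADA-NN.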

C. GENERATIVE ADVERSARIAL DOMAIN AUGMENTATION
In this section, we propose an algorithm to realize the min-max optimization problem (5) with the DLOW-B model G and the idea of utilizing domain codes. We denote a Dirichlet distribution as B(·) and its parameters as α. Here, α is a K-dimensional vector, α = [α_1, ..., α_K] with α_k ≥ 0, where K is the number of given training domains. Then, by utilizing the distribution of domain codes to represent the data-generating distribution, the objective of (5) becomes

min_{θ_F} max_{α} E_{(x,y)~D_source} E_{z~B(α)}[ ℓ(θ_F; (G(x, z), y)) ],  (6)

where (x, y) is an input-label pair drawn from the source domain D_source (DLOW-B is designed to take samples from one training domain as input; we denote that training domain as the source domain, which is randomly selected among all training domains), F is the main classifier, and ℓ is the cross-entropy loss. For GADA-NN, the domain code is instead given by z = g(c) with c ~ N(0, 1).
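For GADA-D, the domain-code distribution B(α) can be sampled as in the following numpy sketch. Gradient-based updates of α via the implicit reparameterization of [31] require an autodiff framework (e.g., a reparameterized Dirichlet sampler) and are omitted here; the α values below are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
K = 3
alpha = np.ones(K)  # learnable Dirichlet parameters, one per training domain

# With alpha = (1, ..., 1), domain codes are uniform over the simplex.
codes = rng.dirichlet(alpha, size=1000)

# Increasing alpha_k shifts mass toward the k-th training domain: the mean
# of Dirichlet(alpha) is alpha / alpha.sum(). In GADA-D, gradient ascent on
# the classification loss moves alpha toward "hard" regions of the simplex.
alpha_hard = np.array([0.5, 0.5, 4.0])
mean_code = rng.dirichlet(alpha_hard, size=1000).mean(axis=0)
```

With only K parameters, this distribution is far cheaper to optimize than the domain-code NN, which matches the efficiency argument made for GADA-D.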

In the inner maximization of (6), we directly optimize the parameters α of the Dirichlet distribution to represent the worst-case distribution of domain codes. To optimize the parameters of the Dirichlet distribution, we utilize the implicit reparameterization technique [31]. With the outer minimization, F is trained to minimize the classification loss on the hardest samples.

Empirically, we find that this leads to better domain generalization ability, as in [10]. Therefore, the actual objective function is given by

min_{θ_F} E_{(x,y)~D_source}[ ℓ(θ_F; (x, y)) ] + λ E_{(x,y)~D_source, c~N(0,1)}[ ℓ(θ_F; (G(x, g(c)), y)) ].  (9)

Here, λ is used to balance the weight of the augmented data and the original training data.

We evaluate our proposed method on three benchmarks.

The first benchmark is digits-DG, which contains four digit datasets.

For evaluation, we followed the leave-one-domain-out principle in [22], which chooses one domain as the test domain while the others are used as training domains. For a fair comparison, the architecture of the classification models followed previous works [11]. We also followed [11]'s strategy of setting the hyperparameters under the assumption that we have only limited computational resources. All general hyperparameters, such as the type of optimizer, the learning rate, and the total number of iterations, are consistent across all methods. For the only hyperparameter that our method needs, λ, we tuned it with limited trials (fewer than 5). The experiments in Table 2

with a learning rate of 0.0001. We set the batch size to 16 and the total number of iterations to 10,000. The value of λ is set based on the performance on the training validation set, and the search space is λ ∈ {0.5, 1, 1.5}. Moreover, [22] showed that hyperparameter selection is important for domain generalization algorithms. We also show the performance of our algorithm with sufficient hyperparameter selection in Section V-D.

Table 2, Table 3, and Table 4 show the prediction accuracy

and office-home dataset. Therefore, the results in Table 3 and Table 4 suggest that, compared with Yang2021, our proposal is better at handling difficult domains. This is reasonable since our method is designed to improve the model's performance in the most difficult domain, whereas Yang2021 does not have such a characteristic. In contrast, our proposal gives only limited domain generalization ability on relatively easy domains, such as the photo domain in the PACS task and the real-world domain in the office-home task. In these relatively simple domains, the performance of our method is worse than that of most competitors.
We speculate that this phenomenon occurs because we focus only on the most challenging domain and ignore the model's performance in relatively simple domains. If we can find a good balance between relatively easy domains and difficult domains, our proposal may lead to better domain generalization ability. We attempt to give a preliminary solution in Section V-D.

Moreover, we can observe that MBDG achieves the highest accuracy in the sketch domain of the PACS task. MBDG also aims to improve the model's domain generalization ability over the worst-case domain; hence, it shows excellent performance in the difficult domain. However, their algorithm makes a strong assumption about the classifier: they require that the main classifier gives consistent predictions for samples in different training domains. It is difficult to train such a classifier, especially when the training domains have significant differences. In the PACS dataset, the sketch domain has a very different visual appearance from the other domains. When MBDG takes the sketch domain as one of the training domains, i.e., when the test domains are the cartoon, photo, and art domains, we can observe that MBDG is not as powerful as it is on the sketch domain. In contrast, we do not make substantial requirements on the target domain; hence, compared with MBDG, our method works on more test domains, and our method outperforms MBDG in the cartoon and art domains. Moreover, we would like to note that the experimental results of MBDG in Table 3 cannot be fairly compared with the other results, since we took the MBDG results from [12], where more computational resources were spent on hyperparameter selection. We make a fair comparison between our proposal and MBDG in Section V-D.

When we further compare GADA-D and GADA-NN, we find that GADA-D performs better than GADA-NN on the digits task, and the two methods achieve similar performance on the PACS and office-home tasks. Empirically, we find that GADA-D is easier to optimize, since the Dirichlet distribution has far fewer parameters than the neural network. We speculate that this is why GADA-D achieves better performance than GADA-NN in some cases. We provide more evidence in Section V-B. GADA-NN also demonstrates clear advantages over GADA-D in certain domains, such as the sketch domain in the PACS task. As discussed in Section IV-C, a neural network can represent a more complex distribution than a Dirichlet distribution. Considering that the sketch domain is the most challenging in the PACS task and that images in this domain show a clear difference from the other three domains, we expect the sketch domain to be the most complex in the PACS task. In such a situation, we expect the better representation ability of the neural network

we can see that the generated domain codes concentrate around the center of the simplex at the beginning. As the learning progresses, the distribution of domain codes gradually shifts away from the lower right corner (the MNIST domain). This tendency is also reflected in the generated images; we find that the generated images become increasingly colorful. Intuitively, considering that the MNIST domain is the easiest domain to classify, the observation that the style of the generated samples looks less and less like MNIST indicates that our adversarial search for the worst-case distribution works successfully. For GADA-D, since we set the initial parameters of the Dirichlet distribution to (1.0, 1.0, 1.0), the distribution of domain codes is equivalent to a uniform distribution over the two-dimensional simplex.
We can see that the generated domain codes shift away from the right corner during the training, and the generated images become

We also have other methods for generating augmented data.

For example, [10] utilizes a pixel-level adversarial attack to generate augmented samples. We compare our proposal with this method, and the results are shown in the third row of Table 5 and Table 6. The results confirm that using a generative model can achieve better accuracy than using a pixel-level

Table 6) and the digits dataset (the fourth and fifth rows in

We refer to this objective as the nondistributional method. The nondistributional method can be considered a simple combination of adversarial data augmentation methods and generative model-based data augmentation methods. Compared with the nondistributional method, our method samples domain codes from a learned distribution, which improves the randomness of the domain codes. Hence, the diversity of the augmented samples also increases. Experimental results on the PACS (the sixth row in Table 6) and digits datasets (the sixth row in Table 5) show that using a domain-code NN or a learnable Dirichlet distribution to represent the distribution of domain codes performs better than the nondistributional method.

The experimental results in Table 2, Table 3, and Table 4 suggest that our method is not good at handling relatively easy domains. We speculate that this is because of insufficient hyperparameter tuning. In our training objectives, (9) and (8), we use only one hyperparameter λ to balance the trade-off between the worst-case domain and the given training domains, and this hyperparameter is selected from a small space, λ ∈ {0.5, 1, 1.5}; hence, it is possible that in the previous experiments the value of λ was not properly set. Moreover, considering that different training domains have different characteristics, giving different training domains different weights is more appropriate. Hence, we modify our training objectives so that the training-data term carries a separate weight λ_k for each training domain:

Σ_{k=1}^{K} λ_k E_{(x_k, y_k)~D_k}[ ℓ(θ_F; (x_k, y_k)) ].  (12)

In (11) and (12), we introduce more hyperparameters for finding a better balance among all training domains. To find appropriate values for these hyperparameters, we follow the strategy of [38]. We first randomly set the values of all hyperparameters from a uniform distribution and then train a model with our proposal. We repeat this process 20 times, evaluate the 20 resulting models on the training validation set, and choose the hyperparameters that achieve the best validation accuracy as our final hyperparameters. We denote the modified GADA-D and GADA-NN as GADA-D-F and GADA-NN-F (F refers to fine-tuned). We would like to note that MBDG also followed this strategy to tune its hyperparameters. Table 7 reports the performance of GADA-NN-F and GADA-D-F.

The results in Table 7 are not comparable to most results in Table 2 and Table 3
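The random search borrowed from [38] amounts to the following loop. This is a sketch: `train_and_validate` is a placeholder (here a toy scoring function) for training a model with the sampled per-domain weights and evaluating it on the training validation set, and the sampling range [0, 2] is an illustrative assumption.

```python
import random

random.seed(0)
K = 3  # one weight lambda_k per training domain

def train_and_validate(lambdas):
    # Placeholder for "train GADA-D-F / GADA-NN-F with these weights and
    # return validation accuracy"; here a toy function peaking at all-ones.
    return 1.0 - sum((l - 1.0) ** 2 for l in lambdas) / len(lambdas)

best_acc, best_lambdas = -1.0, None
for _ in range(20):  # 20 random trials, as in the strategy of [38]
    lambdas = [random.uniform(0.0, 2.0) for _ in range(K)]
    acc = train_and_validate(lambdas)
    if acc > best_acc:
        best_acc, best_lambdas = acc, lambdas
# best_lambdas now holds the configuration with the best validation accuracy
```

Random search scales gracefully as the number of weights K grows, which is why it is preferred here over an exhaustive grid like the earlier λ ∈ {0.5, 1, 1.5}.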