一种基于Transformer的生成对抗网络方法
未命名
07-27
阅读:111
评论:0
一种基于transformer的生成对抗网络方法技术领域
1.本发明涉及高光谱图像分类方法,特别是涉及一种基于transformer的生成对抗网络(generative adversarial network,gan)方法,属于遥感信息处理技术领域。
背景技术:
2.随着科技的发展,高光谱图像分类(hyperspectral image classification,hic)在许多方面得到了广泛的应用。近年来,深度学习(dl)模型已经应用到hic领域。
3.随着深度学习的发展和模型参数的增加,过拟合问题成为了一个巨大的挑战。为了缓解这个问题,zhang等人致力于开发一个简单的网络。他们提出了一种易于实现、比普通3d卷积更轻的1d胶囊网络。但mou等人认为一维卷积在表示高光谱像素时可能会造成像素信息的丢失,因此他们提出了一种新颖的循环神经网络(recurrent neural network,rnn )结构。然而rnn在处理序列信息时存在效率低的问题。在处理顺序数据时,相对于rnn,具有注意力机制的transformer能够更好地解决处理序列效率低的问题。目前,将transformer与cnn结合起来学习图像特征是一种比较常用的方式。然而,transformer的参数量较大,对于hsi这样的小样本训练时非常容易出现过拟合的现象。缓解过拟合的一个重要方法是增加训练数据。许多研究人员通过增加数据来缓解这种情况。这具体包括数据翻转、裁剪、平移和生成模型。生成式模型是通过生成高质量的样本来缓解这个问题。gan是典型的生成模型,主要由生成器g和判别器d两部分组成,gan可以从根本上解决数据样本少的问题,进而解决过拟合的问题。因此,更多的研究者设计gan来缓解样本不足的问题。zhu等人使用1d gan作为光谱分类器,3d gan作为空间分类器。此外,许多研究人员还将gan与其他技术相结合。然而,gan总是存在训练数据不平衡和模式崩溃的问题。为了解决训练数据不平衡的问题,wang等人将d适应为单个分类器,并提出了自适应dropblock正则化方法来解决模式崩溃问题。
4.gan具有不稳定的缺点,大多数研究人员一直致力于解决这个问题,所以很多人引入各种正则化方法,但很少改变它的网络结构。对于cnn来说,卷积算子有一个局部接受域,所以cnn无法处理远程依赖。然而,hsi有更多的谱序列信息。于是本方法使用了transformer做基础框架,它更适合处理全局信息,也擅长处理序列信息。目前在hic领域,还没有人将transformer引入gan中。因此,本方法结合了transformer和gan的思想,提出了带有残差升级模块的生成对抗网络(transformer with residual upscale gan,trug)。
技术实现要素:
5.本发明将transformer引入gan,并提出了用于hic的基于transformer的带有残差升级模块的生成对抗网络(transformer with residual upscale gan,trug)。trug包含一个生成器g和一个鉴别器d。在g中,我们提出了残差升级模块(residual upscale,ru),ru可以提高生成图像的分辨率。在d中,我们采用规模逐步递减的transformer块,并在第一层中使用网格自注意机制,以便于更好地提取图像特征。此外,gan容易出现训练不稳定的问题,
image; (b) ground truth; (c) svm; (d) cnn; (e) 3d cnn; (f) hybridsn; (g) dprn; (h) transformer; (i) vit;(j)trug。
13.图6为up数据集通过不同方法获得的分类图的可视化比较;(a) false color image; ground truth; (c) svm; (d) cnn; (e) 3d cnn; (f) hybridsn; (g) dprn; (h) transformer; (i) vit;(j)trug。
具体实施方式
14.为了使本技术领域的人员更好地理解本发明方案,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分的实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动的前提下所获得的所有其他实施例,都应当属于本发明保护的范围。
15.图1为本发明trug的框架图。
16.我们选择了两个公开的hsi数据集,分别是indian pines(ip),university of pavia (up),来验证所提方法的有效性。
17.所有的数据集分为两部分,即训练集和测试集。由于gan对小样本非常敏感,我们对每一类样本进行分类,并从每一类样本中选取10%进行训练。实验结果主要有三个评价标准,总体准确率(oa)、平均准确率(aa)、kappa系数(kappa)。此外,为了避免有偏见的估计,在一台配备了英特尔酷睿i5处理器和rtx3090gpu的计算机上,使用pytorch进行了10次独立测试。
18.每次测试的具体步骤如下: s1:将原始数据通过pca进行降维得到x
pca
,并将x
pca
输入到鉴别器d里学习其真实样本的特征;s2:在鉴别器d中将x
pca
分为几个patch, 并对其进行embedding;s3:将embedding之后的数据输入到transformer的block中,学习其特征,并随后对得到的特征进行降采样使其尺寸减小,重复该步骤三次得到最后的辨别特征;s4:向生成器g中输入一维随机噪声z∈r
b*l
和类标签c,通过多层感知器( multi-layer perceptron, mlp)将噪声z重构为分辨率为(h
×
w)的特征图x∈r
b*h*w*c
,并且将到的特征图x输入到transformer block进一步提取特征;s5:将s4得到的特征通过残差升级模块(residual upscale,ru)来提高特征图的分辨率,残差升级模块的具体步骤为:在模块前的特征图x和模块后的特征图x
new
之间做一个kronecker积,生成高分辨率的x
up
,具体公式如下:x
⊗
x
new
=x
up
s6:将s5得到的特征图x
up
输入到swin transformer(st)中进一步提取其不同窗口之间的特征xst,并将得到的特征图xst通过ru模块进一步提高其分辨率得到特征x
stnew
;s7:将x
stnew
的通道维度压缩到与x
pca
的通道维度一致得到假样本fake
data
∈r
b*m*n*c
;s8:将生成的假样本fake
data
与真实样本x
pca
共同输入到鉴别器d中,将s3得到的辨别特征输入到softmax中进行分类以及辨别真假,得到最后的分类结果,同时将辨别真假以
及分类结果的loss回传给生成器使其不断学习生成更高质量的样本。
19.为了检验本发明的有效性,分别进行了烧蚀试验和对比实验。
20.a.烧蚀实验(1)生成样本量和视觉分析生成图像的大小是重要的参数。在实验中,我们用数据集生成了不同大小的图像,分别为16,32。具体实验生成的图像特征的可视化如图2所示。特征分析一直是一个难题,尤其是其直观性分析。对于生成样本的质量评价,最直观的是与真实图像进行对比以进行视觉分析。图2中(a)从上到下显示了数据集pu生成的真假图像从训练早期到中后期的比较。而图2中(b)是in数据集生成的32
×
32张假图像的可视化,按照训练时间的长度从上到下依次显示。从图中可以看出,上述训练前期的真假图像明显差别很大,中后期出现了相似的部分。在训练过程中仍然可以看到图像的学习过程。
21.对于生成样本的大小,我们分别对两个数据集进行了参数实验,如图3所示,发现样本大小越大分类效果越好,但由于硬件条件的限制,我们无法继续增加大小进行实验。从图中可以看出,对于不同的数据集,生成样本的大小对实验结果有很大的影响。对于数据集ip,当样本量为64时,可以获得oa为94.56 %的实验结果,比样本量为16时的o a高14.49%。当样本量为16时,获得了最高的分类准确率96.76%。说明不同的数据集对于不同的规模具有不同的分类精度。因此,在后续的实验中,我们对ip采用64的大小,对pu采用16的大小 。
22.(2) ru的分析:我们选择transgan作为对比实验,该实验使用了传统的提高图像分辨率的方法,即upscaling。从图4可以看出,使用ru模块的 gan分类精度明显比使用传统方法要高得多。实验结果表明,ru模块确实具有一定的效果。
23.b、对比实验我们提供了在ip和up数据集上通过不同方法获得的分类精度。比较方法包括svm,cnn,3d cnn,hybridsn,deep pyramidal residual networks (dprn)以及最近出现的transformer和vit。在实验中,我们使用了训练集的10%。
24.可以从表1看到,trug优于其他所有方法。对于数据集ip,我们提出的方法的oa分别远高于svm、cnn、3d cnn和vit,比transformer、hybridsn和dprn获得的oa高4.43%左右。trug的oa、aa和kappa分别为97.85%、97.67%和97.55%。说明trug对hsi的分类效果较好。此外,我们还可以观察到,对于up数据集,本发明提出的方法获得了最佳的性能(即oa=99.83%,aa=99.67%,kappa=99.77%),比其他基于深度学习的方法提高了约1%,比传统方法提高了4%到5%。不同数据集的分类图如图5和图6所示。本发明所提方法得到的分类图比其他方法得到的分类图更清晰。
25.表1
以上仅是本技术的具体实施方式而已,并非对本技术做任何形式上的限定,凡是依据本技术的技术实质对以上实施方式所做的任意简单修改、等同变化或修饰,均仍属于本技术技术方案的保护范围。
技术特征:
1.一种基于transformer的生成对抗网络(generative adversarial network,gan)方法,其特征在于,包括以下步骤:s1:将原始数据通过pca进行降维得到xpca,并将xpca输入到鉴别器d里学习其真实样本的特征;s2:在鉴别器d中将xpca分为几个patch,并对其进行embedding;s3:将embedding之后的数据输入到transformer的block中,学习其特征,并随后对得到的特征进行降采样使其尺寸减小,重复该步骤s3三次,得到最后的辨别特征;s4:向生成器g中输入一维随机噪声z∈r
b*l
和类标签c,通过多层感知器( multi-layer perceptron, mlp)将噪声z重构为分辨率为(h
×
w)的特征图x∈r
b*h*w*c
,并且将得到的特征图x输入到transformer block进一步提取特征; s5:将s4得到的特征通过残差升级模块(residual upscale,ru)来提高特征图的分辨率,残差升级模块的具体步骤为:在模块前的特征图x和模块后的特征图xnew之间做一个kronecker积,生成高分辨率的xup;s6:将s5得到的特征图xup输入到swintransformer(st)中进一步提取其不同窗口之间的特征xst,并将得到的特征图xst通过ru模块进一步提高其分辨率得到特征xstnew;s7:将xstnew的通道维度压缩到与xpca的通道维度一致得到假样本fake
data
∈r
b*m*n*c
;s8:将生成的假样本fake
data
与真实样本xpca共同输入到鉴别器d中,将s3得到的辨别特征输入到softmax中进行分类以及辨别真假,得到最后的分类结果,同时将辨别真假以及分类结果的loss回传给生成器使其不断学习生成更高质量的样本。
技术总结
本发明针对高光谱图像分类(Hyperspectral Image Classification,HIC)领域,公开了一种基于Transformer的生成对抗网络(Generative Adversarial Network,GAN)方法。该方法将Transformer引入到GAN中,并提出了用于HIC的基于Transformer的带有残差升级模块的生成对抗网络(Transformer with residual upscale GAN,TRUG)。TRUG中包含一个生成器G和一个鉴别器D。在G中,我们提出了残差升级模块(Residual Upscale,RU),RU可以提高生成图像的分辨率。在D中,我们采用规模逐步递减的Transformer Block,并在第一层中使用网格自注意机制,以便于更好地提取图像特征。此外,GAN容易出现训练不稳定的问题,为了解决这个问题,我们改进了归一化算法,增加了相对位置编码。TRUG是第一个应用于HIC的基于Transformer的GAN。Transformer的GAN。Transformer的GAN。
技术研发人员:郝思媛 翟世杰 夏裕凤
受保护的技术使用者:青岛理工大学
技术研发日:2023.04.27
技术公布日:2023/7/25
版权声明
本文仅代表作者观点,不代表航空之家立场。
本文系作者授权航家号发表,未经原创作者书面授权,任何单位或个人不得引用、复制、转载、摘编、链接或以其他任何方式复制发表。任何单位或个人在获得书面授权使用航空之家内容时,须注明作者及来源 “航空之家”。如非法使用航空之家的部分或全部内容的,航空之家将依法追究其法律责任。(航空之家官方QQ:2926969996)
飞行汽车 https://www.autovtol.com/
