学习合成数据:解决语义分割中的域偏移问题
会议:CVPR2018
这篇文章的研究动机是什么?
本文旨在减少源域和目标域之间的域间差异,使得基于源域图片和标签训练的分割模型能够应用于目标域的分割。Domain Adaptation Segmentation方法的研究任务是什么?
对于游戏场景的街道图片数据集(包含图片和标签)和真实街道的图片数据集(只有图片没有标签),由于游戏场景图片和真实图片风格不同,Domain Adaptation Segmentation的目标是利用已有数据训练出能够应用于真实场景图片分割的模型。
什么是域间差异(Domain Shift)?
不同的数据集具有不同的数据分布,通常训练的模型只能应用于与训练数据集分布相似的数据集。当用于与训练数据集分布不同的数据集时,会产生显著的性能差异。
如下图所示,如果仅使用游戏场景的数据(即Synthetic Data)训练模型,该模型在游戏场景图片分割上表现良好(Fr),但在真实场景图片上的表现(Fs)会相对较差。Fr和Fs之间的差距实际上是由两个数据集之间的Domain Shift造成的。
本文的目标便是为了减小这个Domain Shift,从而训练出能够应用于真实街道场景图片分割的模型,并期望获得更好的性能,如图中的Fours所示。
本文如何降低Domain Shift?
网络整体架构:
对于Source image,通过F Network提取特征F(xs)。F(xs)通过两个通道进行处理:1)F(xs)通过G network重构生成Fake Source image,将Fake Source image和Real Source image送入判别器D network,以判别image的真实性并构建辅助分割网络进行分割;2)F(xs)通过C network预测每个像素的类别标签作为分割结果。训练过程中的Loss已标注。
对于Real Target image,由于没有label,因此网络中关于预测分割类别的地方均不使用。Real Target image通过F Network提取特征F(xr),F(xr)通过G network重构生成Fake Target image,将Fake Target image和Real Target image送入判别器D network,以判别image的真实性。
上述两图中的F network、G network以及D network是共享参数的。
网络的训练损失函数有哪些?
- GAN Loss
- 分割预测损失函数Lseg和辅助分割预测损失函数Laux,这两个都是像素级交叉熵函数(pixel-wise cross entropy loss)
- 图像重构损失函数Lrec,本质上是输入图片和重构图片之间的L1 Loss
如何利用这些损失函数优化网络?
- 判别器D的优化函数:旨在提高判别器的判别能力。
- 生成器G的优化函数:旨在生成更逼真的图像以欺骗判别器。
- 特征提取器F的优化函数:旨在提取具有域不变性的特征。
实验结果:
使用Source Domain生成场景数据集SYNTHIA和真实街道场景数据集CITYSCAPES为训练数据,在CITYSCAPES数据集上的测试分割结果。同样地,使用Source Domain生成场景数据集GTA5和真实街道场景数据集CITYSCAPES为训练数据,在CITYSCAPES数据集上的测试分割结果也显示了该方法的有效性。
为什么使用Domain Adaptation Segmentation进行分割,而不是直接利用真实数据和标签进行训练?
这是由于数据集标签获取困难。众所周知,分割数据集的标签是像素级的,每个像素都进行了分类。而数据集的获取方式目前仍依赖人工收集,因此使用全监督方式训练分割网络成本巨大。因此,无监督或弱监督语义分割方法应运而生。这篇论文已开源,有兴趣的读者可以进行实践验证。