GAIN: Missing Data Imputation using Generative Adversarial Nets

论文:GAIN: Missing Data Imputation using Generative Adversarial Nets

这篇论文使用生成式对抗网络,利用缺失数据集,学习数据的真实分布,进而得到数据集的缺失值。

问题设定

考虑一个 $d$ 维空间 $\mathcal{X}$, 设数据向量 $\mathbf{X}=\left(X_{1}, \dots, X_{d}\right) \in \mathcal{X}$ 是一个连续的或者离散的随机变量. 令$\mathrm M=(M_1,…,M_D)$ 是一个随机变量, 取值于 ${(0,1)}^{d}$, 称为掩码向量.

对每一个 $i \in\{1, \ldots, d\}$, 定义一个新的空间 $\tilde{\mathcal{X}}_{i}=\mathcal{X}_{i} \cup \{*\}$, ${*}$ 代表缺失值, 令 $\tilde{\mathcal{X}}=\tilde{\mathcal{X}}_{1} \times \ldots \times \tilde{\mathcal{X}}_{d}$, 定义一个新的随机变量 $\tilde{\mathbf{X}}=\left(\tilde{X}_{1}, \ldots, \tilde{X}_{d}\right) \in \tilde{\mathcal{X}}$, 且

$$\tilde{X}_{i}=\left\{\begin{array}{ll}{X_{i},} & {\text { if } M_{i}=1} \\ {*,} & {\text { otherwise }}\end{array}\right.$$

掩码向量 $\mathrm M$ 中元素 0 对应着缺失值的位置.

生成器 $G$ 的输入为 $\tilde{\mathbf{X}}, \mathbf{M}, \mathbf{Z}$, 其中 $\mathbf{Z}$ 是随机的隐变量, 输出为 $\overline{\mathbf{X}}$, 定义随机变量 $\overline{\mathbf{X}}, \hat{\mathbf{X}} \in \mathcal{X}$ 为:

$$\begin{array}{l}{\mathbf{ \bar{X}}=G(\tilde{\mathbf{X}}, \mathbf{M},(\mathbf{1}-\mathbf{M}) \odot \mathbf{Z})} \\ {{\mathbf{\hat{X}}}=\mathbf{M} \odot \tilde{\mathbf{X}}+(\mathbf{1}-\mathbf{M}) \odot \bar{\mathbf{X}}}\end{array}$$

本文还引入了暗示机制(Hint Mechanism), 简单来说, 暗示矩阵 $\mathbf{H}$ 在数据不缺失的位置随机赋值为 $1$ 或 $0.5$, 在数据缺失的位置随机赋值为 $0$ 或 $0.5$, 注意这种随机不是等可能的.

价值函数

$$\begin{aligned} V(D, G)=& \mathbb{E}_{\hat{\mathbf{X}}, \mathrm{M}, \mathbf{H}}\left[\mathbf{M}^{T} \log D(\hat{\mathbf{X}}, \mathbf{H})\right.\\ &+(\mathbf{1}-\mathbf{M})^{T} \log (\mathbf{1}-D(\hat{\mathbf{X}}, \mathbf{H})) ] \end{aligned}$$

就像标准的 GANs 网络一样, GAIN 的优化目标是一个极小化极大值问题

$$\min _{G} \max _{D} V(D, G)$$

定义损失函数 $\mathcal{L} :\{0,1\}^{d} \times[0,1]^{d} \rightarrow \mathbb{R}$ 为

$$\mathcal{L}(\mathbf{a}, \mathbf{b})=\sum_{i=1}^{d}\left[a_{i} \log \left(b_{i}\right)+\left(1-a_{i}\right) \log \left(1-b_{i}\right)\right]$$

记 $\hat{\mathbf{M}}=D(\hat{\mathbf{X}}, \mathbf{H})$, 则优化问题可以改写为

$$\min _{G} \max _{D} \mathbb{E}[\mathcal{L}(\mathbf{M}, \hat{\mathbf{M}})]$$

GAIN 算法

GAIN 算法的训练方法与 Goodfellow 的 GANs 网络的方法相似. 生成器 ${G}$ 和判别器 ${D}$ 都是全连接网络. 首先固定 ${G}$ 训练 ${D}$, 然后利用训练好的 ${G}$ 训练 ${D}$.

定义判别器 ${D}$ 的损失函数 $\mathcal{L}_{D} :\{0,1\}^{d} \times[0,1]^{d} \times\{0,1\}^{d} \rightarrow \mathbb{R}$

$$\begin{aligned} \mathcal{L}_{D}(\mathbf{m}, \hat{\mathbf{m}}, \mathbf{b})=& \sum_{i : b_{i}=0}\left[m_{i} \log \left(\hat{m}_{i}\right)\right.\\ &+\left(1-m_{i}\right) \log \left(1-\hat{m}_{i}\right) \end{aligned}$$

生成器 $G$ 的损失由两部分组成. 不仅要确保 $G$ 输出的插补值能够”骗过”判别器, 还要保证对于非缺失值的生成尽可能和真实值接近.

定义 $\mathcal{L}_{G} :{0,1}^{d} \times[0,1]^{d} \times{0,1}^{d} \rightarrow \mathbb{R}$,

$$\mathcal{L}_{G}(\mathbf{m}, \hat{\mathbf{m}}, \mathbf{b})=-\sum_{i : b_{i}=0}\left(1-m_{i}\right) \log \left(\hat{m}_{i}\right)$$

定义 $\mathcal{L}_{M} : \mathbb{R}^{d} \times \mathbb{R}^{d} \rightarrow \mathbb{R}$,

$$\mathcal{L}_{M}\left(\mathbf{x}, \mathbf{x}^{\prime}\right)=\sum_{i=1}^{d} m_{i} L_{M}\left(x_{i}, x_{i}^{\prime}\right)$$

其中

$$L_{M}\left(x_{i}, x_{i}^{\prime}\right)=\left\{\begin{array}{ll}{\left(x_{i}^{\prime}-x_{i}\right)^{2},} & {\text { if } x_{i} \text { is continuous, }} \\ {-x_{i} \log \left(x_{i}^{\prime}\right),} & {\text { if } x_{i} \text { is binary. }}\end{array}\right.$$