Diffusion Models for Adversarial Purification
1. Introduction
-
Adversarial attack์ ๋ง๋ ๋ฐฉ๋ฒ์๋ ํฌ๊ฒ ๋๊ฐ์ง๊ฐ ์๋ค.
- Adversarial training: Adversarial sample์ ํ์ต์์ผ ํด๋น sample์ ํน์ง์ ๋ชจ๋ธ์ด ํ์ต ํ๋๋ก ๋ง๋๋ ๋ฐฉ๋ฒ์ด๋ค.
- ์ฅ์ : ๊ฑฐ์ ๋๋ถ๋ถ์ ๊ฒฝ์ฐ์ SOTA๋ฅผ ๋ณด์ด๊ณ ์๋ ๋ชจ๋ธ์ด๋ค
- ์ฅ์ : ํ์ต์ ์ฌ์ฉ๋ adversarial example๋ง ๋ง์ ์ ์๋ค. ๋ํ ๋ณ๋์ ํ๋ จ๊ณผ์ ์ด ํ์ํ๋ฏ๋ก computationlly expensive ํ๋ค.
- Adversarial purification: Adversarial example์ ์์ฑ ๋ชจ๋ธ์ด ํ๋ฒ purify ํด์ฃผ๊ณ purified๋ sample์ ๋ถ๋ฅํ๋ค.
- ์ฅ์ : plug-and-play manner์ผ๋ก ์๋ํ๋ฏ๋ก adversarial training๊ณผ๋ ๋ฌ๋ฆฌ ๋ณ๋์ ํ๋ จ๊ณผ์ ์ ์๊ตฌํ์ง ์๋๋ค.
- ๋จ์ : ์์ฑ๋ชจ๋ธ์ด ๊ฐ์ง๊ณ ์๋ ์ฌ๋ฌ ๋ฌธ์ (GAN์ ๊ฒฝ์ฐ์๋ mode collapse) ๋๋ฌธ์ adversarial training๋ณด๋ค ์ฑ๋ฅ์ด ์ข์ง ์๋ค.
- Adversarial training: Adversarial sample์ ํ์ต์์ผ ํด๋น sample์ ํน์ง์ ๋ชจ๋ธ์ด ํ์ต ํ๋๋ก ๋ง๋๋ ๋ฐฉ๋ฒ์ด๋ค.
-
Diffusion model(์ดํ DM)์ likelihood based model๋ก, GAN๊ณผ ๋น๊ตํ์ฌ์ ์ฌ๋ฌ ์ฅ์ ์ ๊ฐ์ง๊ณ ์๋ค. ์ด๋ฌํ ์ฅ์ ๋๋ฌธ์ purification ์ฑ๋ฅ์ด ๋ ์ข์ ๊ฒ์์ด ๊ธฐ๋๋๋ค.
-
ํด๋น ๋ ผ๋ฌธ์์๋ ๋ค์๊ณผ ๊ฐ์ contribution์ ํ๊ณ ์๋ค.
- pretrained๋ DM์ forward and reverse process๋ฅผ ํตํด adversarial purification์ ์ํํ๋ DiffPure์ ์ ์ํ๋ค.
- DM์ forward and reverse process๊ฐ ์ด๋ป๊ฒ ํจ๊ณผ์ ์ผ๋ก adversarial perturbation์ ์ ๊ฑฐํ๋ ๋์์ label semantics๋ฅผ ์ ์งํ๋์ง ์์๋ณธ๋ค.
- Adaptive attack์ ๋ง๊ธฐ ์ํด reverse process์์ ํจ๊ณผ์ ์ผ๋ก full gradient๋ฅผ ๊ตฌํ๋ ๋ฐฉ๋ฒ์ ์ ์ํ๋ค.
- Adaptive attack ๋ฒค์น๋งํฌ์์ SOTA๋ฅผ ๋ณด์ด๊ณ ์์์ ์ฌ๋ฌ ์คํ์ ํตํด ์ฆ๋ช ํ๋ค.
2. Background
ํด๋น ์น์ ์์๋ โScore-based generative modeling through stochastic differential equations.โ ๋ ผ๋ฌธ์์ ์ ์ํ๊ณ ์๋ continuous-time diffusion model์ ๋ํด์ ๊ฐ๋จํ๊ฒ ๋ฆฌ๋ทฐํ๊ณ ์๋ค. ๋จผ์ foward diffusion process๋ ๋ค์๊ณผ ๊ฐ์ด ๋ํ๋ธ๋ค.
DM์๋ ํฌ๊ฒ ๋๊ฐ์ง ์ข ๋ฅ๊ฐ ์๋๋ฐ, VE-SDE์ VP-SDE ์ด๋ค. ์ด ๋ ผ๋ฌธ์์๋ VP-SDE๋ฅผ purification model๋ก ์ฌ์ฉํ๋ค.
๋ค์์ผ๋ก sample generation์ ๋ค์๊ณผ ๊ฐ์ด reverse-time SDE๋ฅผ ํตํ์ฌ ์ํํ๋ค.
DM์ ํ๋ จ์ ๋ค์๊ณผ ๊ฐ์ด score function์ธ $s_\theta (\tilde{x}, t)$์ transition probability์ ์ต๋ํ ๊ฐ๊น์ ์ง๋๋ก ๋ง๋๋ denosing score matching ๋ฐฉ๋ฒ์ ์ด์ฉํ๋ค.
3. Method
3.1 Diffusion purification
ํด๋น ์น์ ์์๋ adversarial example์ธ $x_a$์ foward process๋ฅผ ํตํด ๋ ธ์ด์ฆ๋ฅผ ์ถ๊ฐํ๊ณ ๋ค์ reverse process๋ฅผ ํตํด ๋ ธ์ด์ฆ๋ฅผ ์ ๊ฑฐํ๋ฉด example์ด ๊ฐ์ง๊ณ ์๋ adversarial feature๋ค์ด ์ฌ๋ผ์ง๊ณ clean image๋ก ๋ณ๊ฒฝ๋๋ค๊ณ ์ฃผ์ฅํ๊ณ ์๋ค. ์ฆ, ๋ค์๊ณผ ๊ฐ์ ์ ๋ฆฌ๋ฅผ ๋ด์ธ์ ๋ค. (์ฌ๊ธฐ์ ์ ๋ฆฌ์ ๋ํ ์ฆ๋ช ์ ๊ธธ์ด์ ์๋ตํ๋๋ก ํ๊ฒ ๋ค.)
${x(t)}_{t\in[0,1]}$๊ฐ foward SDE, $p_t$๊ฐ clean data ๋ถํฌ, $q_t$๊ฐ adversarial sample ๋ถํฌ์ผ ๋ ์ฆ๋ช
์ ํตํด ๋ค์๊ณผ ๊ฐ์ด forward SDE๊ฐ ์ผ์ด๋ ์๋ก $p_t$์ $q_t$๋ ์๋ก ๊ฐ๊น์์ง๋ค.
DM์ forward, reverse process๋ฅผ ํตํด adversarial example์ด clean image์ ๊ฐ๊น์์ง๋ค๋ ๊ฒ์ ์ฆ๋ช ํ์ผ๋ฏ๋ก forward, reverse process๊ฐ ์ผ์ด๋๋ ๊ณผ์ ์ ๋ํด์ ์ดํด๋ณด๋ฉด ๊ฐ๊ฐ ๋ค์๊ณผ ๊ฐ๋ค.
- Forward process: ์ด ๋ ผ๋ฌธ์์๋ DM์ผ๋ก VP-SDE๋ฅผ ์ฌ์ฉํ๋ค. VP-SDE๋ ๋ค์๊ณผ ๊ฐ์ ๊ณต์์ผ๋ก timestep $t^{\ast}$์ ๋ ธ์ด์ฆ๊ฐ ์ถ๊ฐ๋ ์ด๋ฏธ์ง๋ฅผ ๊ตฌํ ์ ์๋ค.
- Reverse process: Euler-Maruyama solver๋ก reverse-time SDE๋ฅผ ๋ค์๊ณผ ๊ฐ์ด ํ ์ ์๋ค.
์ ์์์ sdeint์ ์ ๋ ฅ๊ฐ์ ์ฐจ๋ก๋๋ก initial value, drift coefficient, diffusion coefficient, Wiener process, initial time, ์ end time ์ด๋ค. $f_{rev}$์ $g_{rev}$๋ ๋ค์๊ณผ ๊ฐ๋ค.
์ฌ๊ธฐ์ ํต์ฌ์ diffusion์ ํ๋ timestep์ธ $t^{\ast}$๋ฅผ ์ ํ๋ ๊ฒ์ด๋ค. ์ต์ ์ $t^{\ast}$์ ๋ค์๊ณผ ๊ฐ์ ์ ๋ฆฌ๋ฅผ ํตํด์ ๊ตฌํ ์ ์๋ค.
๋ง์ฝ score function์ด $\lVert s_\theta (x,t) \rVert \leq \frac{1}{2}C_s$๋ฅผ ๋ง์กฑํ๋ค๊ณ ๊ฐ์ ํ๋ฉด clean data $x$์ purified data $\hat{x}(0)$์์ L2 ๊ฑฐ๋ฆฌ์ upper bound๋ ๋ค์๊ณผ ๊ฐ์ด ๊ตฌํ ์ ์๋ค.
purified image๊ฐ ์ต๋ํ clean image์ ์ ์ฌํ label semantics๋ฅผ ๊ฐ์ ธ์ผ ํ๋ฏ๋ก $\lVert \tilde{x}(0)-x \rVert$๋ฅผ ์๊ฒ ํ๋ $t^{\ast}$๋ฅผ ์ฐพ์์ผ ํ๋ค. ๊ทธ๋ฌ๊ธฐ ์ํด์๋ $\gamma(t^{\ast})$๊ฐ ์์์ผ ํ๊ณ , ์ด๋ ๊ณง $t^{\ast}$๊ฐ ์์์ผ ํ๋ค๋ ์๋ฏธ์ด๋ฏ๋ก $t^{\ast}$๋ฅผ ์ต๋ํ ์๊ฒ ์ค์ ํด์ผ ํ๋ค.
์ฒซ๋ฒ์งธ์ ๋๋ฒ์งธ ์ ๋ฆฌ๋ฅผ ํตํด $t^{\ast}$๊ฐ์ ๋ฐ๋ผ adversarial sample์์ perturbation์ purifyํ๋ ์ ๋์ purified image๊ฐ ๊ฐ์ง๋ label semantics๊ฐ์๋ trade-off๊ฐ ์๋ค๋ ๊ฒ์ ์ ์ ์๋ค. ๋ฐ๋ผ์ ์ต์ ์ $t^{\ast}$๋ฅผ ์ ์ ํํ๋ ๊ฒ์ด ์ค์ํ๋ค. ์ด ๋ ผ๋ฌธ์์๋ perturbation์ด ์์ ๊ฒฝ์ฐ๊ฐ ๋ง๋ค๊ณ ํ์๊ธฐ ๋๋ฌธ์ $t^{\ast}$๊ฐ์ ์๊ฒ ์ค์ ํ์๋ค.
3.2 Adaptive attack to diffusion purification
๊ฐํ adaptive attack ๊ฐ์ ๊ฒฝ์ฐ๋ SDE solver์ full gradient๋ฅผ ๊ณ์ฐํ๋ค. ๊ทธ๋ฌ๋ ์ด๋ฌํ ๊ณผ์ ์ ๋งค์ฐ computationally expensiveํ๋ฏ๋ก ํด๋น ๋ ผ๋ฌธ์์๋ adjoint method๋ฅผ ์ด์ฉํ์ฌ $O(1)$์ ๊ณต๊ฐ ๋ณต์ก๋๋ก ์ฐ์ฐ์ ์ํํ์๋ค. ํด๋น ๋ฐฉ๋ฒ์ ์์ธํ ๋ด์ฉ์ ๋ค์๊ณผ ๊ฐ๋ค.
Adjoint method์ ๋ํด์๋ โscalable gradients for stochastic differential equationsโ๋ ผ๋ฌธ์ ์์ธํ๊ฒ ๋์์๋ค.
4. Experiments
๋ชจ๋ธ์ robustness๋ฅผ ํ๊ฐํ๋ ๋ํ์ ์ธ ๋ฒค์น๋งํฌ์ธ RobustBench์ SOTA๋ชจ๋ธ๋ค๊ณผ ๋น๊ต๋ฅผ ํ์๋ค.
Experimental settings
- Dataset์ CIFAR-10, CelebA-HQ, ImageNet์ ์ฌ์ฉํ์๋ค.
- Classifier์ ResNet, WideResNet, ViT๋ฅผ ์ฌ์ฉํ์๋ค.
- Adversarial attack ๋ฐฉ๋ฒ์๋ AutoAttack์ $l_\infty $, $l_2$์ StAdv, BPDA+EOT๋ฅผ ์ฌ์ฉํ์๋ค.
- Evaluation metric์ผ๋ก๋ clean data์ ๋ํ ํ๊ฐ์ธ standard accuracy, purified data์ ๋ํ ํ๊ฐ์ธ robust accuracy๋ฅผ ์ฌ์ฉํ์๋ค.
Comparision with the sate-of-the-art
๋จผ์ AutoAttack $l_\infty (\epsilon = 8/255) $๋ฅผ CIFAR-10์ ์ ์ฉํ์ ๋ ๊ฒฐ๊ณผ์ด๋ค. DiffPure์ด SOTA๋ฅผ ๋ชจ๋ ํ๊ฐ์งํ์์ ๋ฌ์ฑํ๊ณ ์๋ค๋ ์ฌ์ค์ ์ ์ ์๋ค.
๋ค์์ CIFAR-10์ $l_2 (\epsilon = 0.5)$ AutoAttack์ ๊ฐ๊ธฐ ๋ค๋ฅธ classifier์ ์ ์ฉํ์ ๋ ๊ฒฐ๊ณผ์ด๋ค. ์ถ๊ฐ์ ์ธ ๋ฐ์ดํฐ์ ํ์ต์ด ์์ด๋ SOTA๋ชจ๋ธ๊ณผ ์ฑ๋ฅ์ด ๋น์ทํ๊ฑฐ๋ ๋ ์ข๋ค๋ ๊ฒ์ ์ ์ ์๋ค.
๋ง์ง๋ง์ AutoAttack $l_\infty (\epsilon = 4/255) $์ ImageNet์ ์ ์ฉํ์ ๋ ๊ฒฐ๊ณผ์ด๋ค. ๋ง์ฐฌ๊ฐ์ง๋ก ๋ชจ๋ classifier์์ SOTA์์ ํ์ธํ ์ ์๋ค.
Defense against unseen threats
Adversarial training์ ์ฃผ์ ๋จ์ ์ด๋ผ๊ณ ํ๋ค๋ฉด ํ๋ จ๋์ง ์์ ๊ณต๊ฒฉ๋ฐฉ๋ฒ์๋ ์ทจ์ฝํ๋ค๋ ๊ฒ์ด๋ค. ์๋ ํ๋ CIFAR-10์ adversarial trainingํ ๋ชจ๋ธ๊ณผ์ ์ฑ๋ฅ์ ๋น๊ตํ ํ์ธ๋ฐ, ํ์๊ธ์๋ ๋ชจ๋ธ์ด ํ๋ จ๋ ๊ณต๊ฒฉ ๋ฐฉ๋ฒ์ด๋ค. ๋น๊ต ๊ฒฐ๊ณผ ๋ชจ๋ ๊ณต๊ฒฉ๋ฐฉ๋ฒ์ ๋ํ์ฌ diffpure์ด ๋์ ์ ํ๋๋ฅผ ๋ณด์ด๊ณ ์์์ ์ ์ ์๋ค.
Comparison with other purfication model
Diffpure์ GAN์ ์ฌ์ฉํ ๋ค๋ฅธ purification ๋ฐฉ๋ฒ๊ณผ์ ๋น๊ต ๋ํ ์ํํ์๋ค. ์ฌ๊ธฐ์ AutoAttack๊ณผ ๊ฐ์ white-box adaptive attack์ ๊ฐํ๋ฉด ์ต์ ํ ๋๋ sampling loop ๋ฌธ์ ๊ฐ ๋ํ๋ฌ๊ธฐ ๋๋ฌธ์ BPDA+EOT ๊ณต๊ฒฉ ๋ฐฉ๋ฒ์ ์ฌ์ฉํ์๋ค. ๋ฐ์ดํฐ์ ์ CelebA-HQ์ CIFAR-10์ ์ฌ์ฉํ์๊ณ , ๋น๊ต์ ๋ํ ๊ฒฐ๊ณผ๋ ์๋ํ์ ๋์์๋ค.
์ฌ๊ธฐ์ ENC์ OPT๋ GAN์ optimization-based, encoder-based inversion์ ์ง์นญํ๋ค. ๋ ๋ฐ์ดํฐ์ ์์ Robust accuracy๊ฐ ๋ชจ๋ SOTA ์ฑ๋ฅ์ ๋ณด์ด๋ ๊ฒ์ ์ ์ ์๋ค.
Conclusion
-
Adversarial example์ ํ๋ฒ ๋ถ๋ฅ๊ธฐ๋ก ๋ถ๋ฅํ๊ธฐ ์ ์ ์ ํ์ํค๋ Diffpure ๋ฐฉ๋ฒ์ ์๋กญ๊ฒ ์ ์ํ์๋ค.
-
White-box adaptive attack์ ๋ํ robustness๋ฅผ ํ๊ฐํ๊ธฐ ์ํด์๋ full gradient๋ฅผ ๊ตฌํ๋ ์์ ์ด ํ์ํ๋ฐ, ์ด๋ฌํ ์์ ์ adjoint method๋ฅผ ์ฌ์ฉํ์๋ค.
-
Diffpure์ ๋ค๋ฅธ SOTA ๋ฐฉ๋ฒ๋ค๊ณผ ๋น๊ตํ๊ธฐ ์ํด ์คํ์ ์งํํ์๋ค. ๋ฐ์ดํฐ์ ์ CIFAR-10, ImageNet, CelebA-HQ๋ฅผ ์ฌ์ฉํ์๊ณ classifier๋ ResNet, WideResNet, ViT ๋ฅผ ์ฌ์ฉํ์๋ค. ์คํ ๊ฒฐ๊ณผ robust accuracy์์ ์ด๋ค์ ๋ชจ๋ ๋ฐ์ด๋์๋ค.
-
๊ทธ๋ฌ๋ ๋ค์๊ณผ ๊ฐ์ ๋ ๊ฐ์ง ํ๊ณ์ ์ ๋ดํฌํ๊ณ ์๋ค.
- Purification์ ํ๋ ์๊ฐ์ด ๋๋ฌด ์ค๋ ์์๋๋ค.
- Diffusion model์ ์ด๋ฏธ์ง์ ์์ ๊ต์ฅํ ๋ฏผ๊ฐํ๋ฏ๋ก ์์ ๋ํ ๊ณต๊ฒฉ์ ํจ๊ณผ์ ์ผ๋ก ๋ฐฉ์ดํ์ง ๋ชปํ๋ค.
Comments