Attention please

[논문 리뷰] VQ-VAE: Neural Discrete Representation Learning (2018) 본문

논문 리뷰/Image generation

[논문 리뷰] VQ-VAE: Neural Discrete Representation Learning (2018)

Seongmin.C 2025. 3. 24. 17:51
728x90
반응형

이번에 리뷰할 논문은 Neural Discrete Representation Learning 입니다.

https://arxiv.org/abs/1711.00937

 

Neural Discrete Representation Learning

Learning useful representations without supervision remains a key challenge in machine learning. In this paper, we propose a simple yet powerful generative model that learns such discrete representations. Our model, the Vector Quantised-Variational AutoEnc

arxiv.org

 

 

 

 

 

 

Posterior Collapse

일반적인 VAE 모델은 latent variable을 학습하여 데이터의 중요한 feature들을 압축하여 표현하도록 설계되었습니다. VAE의 Encoder로부터 latent variable인 $z$를 만든 후 Decoder가 이를 기반으로 reconstruction을 수행하게 됩니다. Transformer와 같은 autoregressive 모델들이 Decoder로서 사용되기 시작하여 Decoder의 성능이 크게 향상되었지만, Decoder의 성능이 너무 좋아지게 되면서 Encoder의 latent variable $z$를 거의 사용하지 않고 데이터를 복원할 수 있게 되었습니다. 즉, 학습 과정에서 $z$ 가 무시되는 문제가 생기며, latent variable이 오히려 유용한 정보를 담지 못하는 문제가 발생하게 됩니다. 이런 문제를 Posterior Collapse라 부릅니다. 

 

실제로 Variational Lossy Autoencoder 연구에 따르면, log likelihood로 측정한 가장 좋은 생성 모델은 latent없이 PixelCNN과 같은 강력한 Decoder만을 사용하는 모델일 수도 있다고 합니다. 본 논문에서는 데이터의 중요한 특징들을 Latent space 내에 보존하면서, 동시에 최대우도를 최적하하는 모델을 설계하고자 하였으며, 저자는 discrete하고 유용한 latent variable을 학습할 수 있다고 주장합니다. 

 

지금까지는 continuous feature를 이용한 representation learning이 많은 연구의 중심이었지만, 다음과 같은 이유로 저자는 우리가 관심 갖는 많은 modality에서 discrete representation이 더 자연스러운 형태라고 말합니다. 

  • 언어는 본질적으로 discrete 이다. 
  • 음성 또한 sequence of symbols 로 나타내는 것이 일반적이다. 
  • 이미지는 종종 언어로 간결하게 설명될 수 있다. 

또한, discrete representation은 다음과 같은 복잡한 인지 문제에도 적합합니다. 

  • 추론 (reasoning)
  • 계획 (planning)
  • 예측 학습 (predictive learning)

본 논문에서는, 관측값으로부터 discrete latent variable 의 posterior distribution을 새롭게 파라미터화하여, VAE 프레임워크와 discrete latent represenation을 성공적으로 결합한 새로운 생성 모델 계열을 제안합니다. 해당 모델은 Vector Quantization (VQ)에 기반하여 학습이 간단하고, 큰 variance issue도 없으며, 많은 기존 VAE에서 문제가 되었던 posterior collapse 문제도 발생하지 않습니다. 

 

 

 

 

 

 

Main Contribution

1. VQ-VAE 모델 소개 : 단순하고, discrete latent space를 사용하여, posterior collapse 문제가 없고, variance issue도 없음.

 

2. VQ-VAE이 log-likelihood 측면에서 continueous model과 유사한 성능을 가짐을 보여줌.

 

3. 강력한 prior과 결합하면, 음성, 비디오 등 다양한 응용에서 고품질 샘플 생성 가능.

 

4. 완전한 unsupervised 방식으로, 음성을 통해 언어를 학습하고, unsupervised speaker conversion 가능함을 보임.

 

 

 

 

 

 

VQ-VAE

일반적으로 VAE는 다음과 같은 구성 요소들로 이루어져 있습니다.

  • 입력 데이터 $x$가 주어졌을 때, discrete latent random variable $z$에 대한 posterior distribution $q(z|x)$ 를 파라미터화하는 Encoder Network 
  • Prior distribution $p(z)$
  • Latent variable $z$를 조건으로 하는 입력 데이터 분포 $p(x|z)$를 가지는 Decoder

본 논문에서는 vector quantisation (VQ)에서 영감을 받아 discrete latent variables를 사용하는 VQ-VAE를 제안합니다. 이 모델에서는 posterior과 prior 모두가 번주형 분포 (categorical distribution)이며, 이 분포에서 샘플링된 결과는 embedding table의 index로 사용됩니다. 이렇게 선택된 Embedding vector들은 decoder의 입력으로 사용되는 것이죠. 

 

 

 

Discrete Latent Variables

우선 잠재 임베딩 공간(Latent Embedding Space)는 실제로 Decoder의 입력으로 들어가는 Embedding vector들의 dictionary라 생각할 수 있으며, $e \in \mathbb{R}^{K \times D}$로 정의됩니다. 

  • $K$ : 이산 잠재 공간의 크기 (K-클래스 범주형 분포)
  • $D$ : 각 embedding vector $e_{i}$의 차원 수 

즉, Embedding Space에는 $K$개의 Embedding vector $e_{i} \in \mathbb{R}^{D}$가 존재하며, $i \in \{1,2,...,K\}$ 입니다.

 

 

위 그림과 같이, 모델은 입력 $x$를 받아 Encoder를 통과시켜 출력 $z_{e}(x)$를 생성합니다. 그 후, 이산 잠재 변수 $z$는 공유된 임베딩 공간 $e$를 이용해 가장 가까운 임베딩 벡터를 찾는 방식으로 계산됩니다. 이후, Decoder는 앞서 선택된 가장 가까운 임베딩 벡터 $e_{k}$를 입력으로 받아 이미지를 생성하게 됩니다. 

 

$$q(z = k \mid x) = \begin{cases} 1 & \text{if } k = \arg\min_j \| z_e(x) - e_j \|^2 \\ 0 & \text{otherwise} \end{cases}$$

 

위 식과 같이, $z_{e}(x)$는 Encoder 네트워크의 출력값이며, 이 모델의 posterior categorical distribution인 $q(z|x)$는 가장 가까운 embedding vector의 index인 $k$가 1이며, 나머지는 모두 0의 값을 가지는 one-hot vector로 정의됩니다. 

 

이 모델을 VAE의 한 형태로 간주하고, ELBO를 통해 $log\ p(x)$의 하한을 구할 수 있습니다. 또한 사후 분포 $q(z = k|x)$는 deterministic이며, $z$에 대해 단순한 균등 분포임을 통해 KL divergence가 $log\ K$와 같은 상수가 된다고 합니다.

 

VQ-VAE 의 ELBO 중 KL divergence가 상수 취급되는 과정은 다음과 같이 설명 가능합니다. 

 

[VAE - ELBO]

 

우선 log-likelihood 의 하한은 다음과 같습니다.

 

$$
\log p(x) \geq \mathbb{E}_{q(z \mid x)} [\log p(x \mid z)] - D_{\text{KL}}(q(z \mid x) \parallel p(z))
$$

 

고로 loss는 다음과 같이 정의됩니다. 

 

$$
\mathcal{L}(x) = -\mathbb{E}_{q(z \mid x)} [\log p(x \mid z)] + D_{\text{KL}}(q(z \mid x) \parallel p(z))
$$

 

위 손실함수의 1번째 항은 reconstruction loss로 MSE혹은 Cross Entropy로 계산되며, 2번째 항은 KL-divergence로 latent space를 regularization하는 용도로 사용됩니다. 

 

[VQ-VAE]

 

VQ-VAE의 역시 $ D_{\text{KL}}(q(z \mid x) \parallel p(z)) $ 형태의 KL-divergence를 가지며, 각 분포는 다음과 같이 정의됩니다.

  • $q(z|x)$ : 결정적 one-hot distribution (항상 1개만 선택)
  • $p(z)$ : 균등 분포 (모든 $z$가 같은 확률을 가짐)

고로 VQ-VAE의 KL-divergence는 다음과 같이 정리됩니다.

 

$$ D_{\text{KL}}(q(z \mid x) \parallel p(z)) $$

$$ = \log \frac{q(z|x)}{p(z)} $$

$$ = \log q(z|x) - \log p(z) $$

$$ = \log k - \log 1 $$

$$ = \log k $$

 

위와 같이 $\log k$는 항상 상수 취급되며, 학습에서 무시되는 것이 가능합니다.

 

$$
z_q(x) = e_k, \quad \text{where } k = \arg\min_j \| z_e(x) - e_j \|^2
$$

 

또한 위 식은 앞서 말했던 것처럼 Encoder의 출력인 $z_{e}(x)$이 양자화 병목(discretisation bottlenect)을 거쳐 embedding space인 $e$에서 가장 가까운 벡터로 mapping되는 것을 의미합니다. 

 

 

 

Learning

앞서 봤던 수식을 다시 보면,

 

$$
z_q(x) = e_k, \quad \text{where } k = \arg\min_j \| z_e(x) - e_j \|^2
$$

 

에서는 정의된 real gradient는 존재하지 않습니다. 하지만 VQ-VAE에서는 straight-through estimator과 유사한 방식으로 gradient를 근사시킵니다. 아주 간단하게 이를 해결하는데, 단순히 Decoder의 입력인 $z_{q}(x)$의 gradient를 Encoder의 출력인 $z_{e}(x)$에 그대로 복사하는 것으로 진행됩니다.

 

다음으로, forward pass에서는 가장 가까운 embedding $z_{q}(x)$가 Decoder에 전달되고, backward pass에서는 gradient $\triangledown_{2} L$가 변형 없이 Encoder로 전달됩니다. 이는 Encoder의 출력 표현과 Decoder의 입력이 동일한 $D$차원의 공간을 공유하기 때문에, 이 gradient는 Encoder가 Reconstruction loss를 줄이기 위해 출력을 어떻게 변경해야 하는지에 대한 유용한 정보를 담고 있습니다. 

 

 

위 그림을 보시면, gradient는 Encoder의 출력을 밀어내어, 다음 forward pass 에서 다른 embedding으로 양자화될 수 있도록 만드려는 것을 확인할 수 있습니다. 즉, 

 

$$q(z = k \mid x) = \begin{cases} 1 & \text{if } k = \arg\min_j \| z_e(x) - e_j \|^2 \\ 0 & \text{otherwise} \end{cases}$$

 

에서의 Embedding 선택이 바뀔 수 있음을 의미합니다. 

 

 

 

Loss Function

다음으로는 VQ-VAE의 전체 손실 함수에 대해 설명드리겠습니다.

 

$$
\mathcal{L} = \log p(x \mid z_q(x)) + \| \text{sg}[z_e(x)] - e_k \|_2^2 + \beta \| z_e(x) - \text{sg}[e] \|_2^2
$$

 

3가지의 항으로 구성되어있으며, 각각

  • 1 term : Reconstruction Loss
  • 2 term : $l_{2}$ - 거리 기반 손실 (k-means)
  • 3 term : Commitment Loss

로 사용됩니다. 

 

위 손실함수의 각 항에 대해 하나하나 뜯어 보도록 합시다.

 

<1 term : Reconstruction Loss>

 

$$\log p(x \mid z_q(x))$$

 

위 수식은 형태 그대로 log-likelihood 즉, reconstruction을 위한 loss입니다. 다만, 한가지 유의깊게 봐야할 점은 Encoder의 출력은 $z_{e}(x)$ 였지만, straight through gradient를 통해 $z_{e}(x) \rightarrow z_{q}(x)$로의 mapping이 수행되었기 때문에, embedding $e_{i}$는 reconstruction loss인 $ \log p(x \mid z_q(x)) $ 로부터 gradient를 받지 않습니다. 

 

 

 

<2 term : $l_{2}$ - 거리 기반 손실 (k-means)>

 

$$ \| \text{sg}[z_e(x)] - e_k \|_2^2 $$

 

Embedding table을 학습하기 위한 term이며, embedding vector $e_{i}$를 encoder의 output으로 끌어당기는 $l_{2}$ 거리 기반 손실을 사용합니다. 해당 term은 오직 embedding만을 학습하는데 사용된다는 특징이 있습니다.

 

 

 

<3 term : Commitment Loss>

 

$$ \beta \| z_e(x) - \text{sg}[e] \|_2^2 $$

 

Encoder는 $z_{e}(x)$를 아무 방향으로 막 출력해도 결국 Decoder로는 가장 가까운 $e_{k}$만 입력되게 됩니다. 즉, Encoder가 Embedding table에 commit하지 않아도 잘 학습됨을 의미합니다. 

 

이는 Embedding과 Encoder output이 점점 멀어져 학습을 불안정하게 만들며, latent의 의미를 사라지게 만듭니다. 즉, 해당 term을 통해 Encoder의 output과 Embedding table의 embedding vector간 거리를 좁혀주는 것 즉, commit하도록 만들어주는 것이죠. 

 

다시 말해, "그 Embedding vector 사용할거면, 거기에 맞춰서 나와!" 라고 말하는 것으로 이해하면 직관적입니다.

 

 

정리하자면,

 

1. Decoder는 1 번째 항으로 최적화

2. Encoder는 1, 3 번째 항으로 최적화

3. Embedding은 2 번째 항으로 최적화

 

입니다.

 

추가로, 본 논문에서는 실험을 통해 위 알고리즘이 $\beta$에 대해 robust 하다는 것을 발견하였으며, 실제로 0.1 ~ 2.0 까지 바꾸어도 결과가 크게 변하지 않았다고 합니다. 실험에서는 $\beta = 0.25$ 를 사용하였습니다.

 

또한, 위에서 설명했듯이 $z$에 대해 uniform prior를 가정하기 때문에, ELBO에서 일반적으로 나타나는 KL-divergence 항은 Encoder 파라미터에 대해 상수처리가 되며, 학습 시 무시할 수 있습니다. 

 

실험에서는 $N$개의 이산 latent를 정의합니다. 이미지의 경우 여러개의 패치로 나누어 각각 정의될 수 있죠.

(e.g. ImageNet의 경우 32 x 32 / CIFAR10의 경우 8 x 8 x 10)

 

위 경우에도, 전체 손실 $L$의 구조는 동일하지만, k-means 손실commitment 손실$N$개 위치에 대해 평균을 낸 형태로 계산합니다. 즉, 각 latent마다 하나씩 계산 후 평균값을 취하게 됩니다. reconstruction 손실의 경우 전체 이미지 단위로 있기 때문에 그대로 두고 나머지 두 항에 대해서만 spatial grid 전체에 대해 평균값을 취하는 것입니다.

 

 

 

Log likelihood : $\log p(x)$

전제 모델의 log likelihood $\log p(x)$는 다음과 같이 평가할 수 있습니다.

 

$$
\log p(x) = \log \sum_k p(x \mid z_k) \, p(z_k)
$$

 

하지만 여기서 Decoder는 항상 Encoder가 선택한 embedding인 $z = z_{q}(x)$만을 보고 학습되었습니다. 그 외에 $z$들에 대해서는 Decoder가 확률을 할당하지 않죠. 즉, $z_{q}(x)$외에 모두 $p(x | z) \approx 0$ 이 됩니다.

 

$$
\log p(x) \approx \log \left( p(x \mid z_q(x)) \cdot p(z_q(x)) \right)
$$

 

결론적으로, 위와 같이 근사하는 것이 가능합니다. 즉, 전체 $z$에 대해 sum을 하지 않고, MAP(Most A Posteriori) 근사에 나온 딱 하나의 $z_{q}(x)$만 가지고 계산하게 됩니다. 

 

$$
\log p(x) \geq \log \left( p(x \mid z_q(x)) \cdot p(z_q(x)) \right)
$$

 

또한, 위 식과 같이 Jensen's inequality에 의해 위와 같은 부등식도 성립됩니다. 

 

728x90
반응형
Comments