GAN을 구현할 때 netD(fake.detach())를 산다. 왜 이렇게 구현하는 것일까?
## 가짜 데이터들로 학습을 합니다
# 생성자에 사용할 잠재공간 벡터를 생성합니다
noise = torch.randn(b_size, nz, 1, 1, device=device)
# G를 이용해 가짜 이미지를 생성합니다
fake = netG(noise)
label.fill_(fake_label)
# D를 이용해 데이터의 진위를 판별합니다
output = netD(fake.detach()).view(-1)
# D의 손실값을 계산합니다
errD_fake = criterion(output, label)
# 역전파를 통해 변화도를 계산합니다. 이때 앞서 구한 변화도에 더합니다(accumulate)
errD_fake.backward()
D_G_z1 = output.mean().item()
# 가짜 이미지와 진짜 이미지 모두에서 구한 손실값들을 더합니다
# 이때 errD는 역전파에서 사용되지 않고, 이후 학습 상태를 리포팅(reporting)할 때 사용합니다
errD = errD_real + errD_fake
# D를 업데이트 합니다
optimizerD.step()
############################
# (2) G 신경망을 업데이트 합니다: log(D(G(z)))를 최대화 합니다
###########################
netG.zero_grad()
label.fill_(real_label) # 생성자의 손실값을 구하기 위해 진짜 라벨을 이용할 겁니다
# 우리는 방금 D를 업데이트했기 때문에, D에 다시 가짜 데이터를 통과시킵니다.
# 이때 G는 업데이트되지 않았지만, D가 업데이트 되었기 때문에 앞선 손실값가 다른 값이 나오게 됩니다
output = netD(fake).view(-1)
# G의 손실값을 구합니다
errG = criterion(output, label)
# G의 변화도를 계산합니다
errG.backward()
D_G_z2 = output.mean().item()
# G를 업데이트 합니다
optimizerG.step()
fake.detach()는 netD의 weights의 requires_grad에 아무런 영향을 주지 않는다.
errD_fake.backward()가 실행되어도 fake의 grad는 여전히 None이고, detach()에 의해 grad_fn도 None으로 바뀌었기 때문에 netG의 weights도 역전파할 수 없다.
(errD_fake는 fake를 netD에 넣고 얻은 결과이다)
학습하고자 하는 것은 오직 weights이기 때문에 weight Tensor의 requires_grad만 True이면 된다.
위의 사진에서 fake.detach()를 하면 F의 grad_fn인 $\frac{\partial F}{\partial W}$는 다 None이 되고, errD_fake.backward()를 하면 원래 $\frac{\partial Out}{\partial F}$가 chain rule에 의해 계산이 되어야 하는데 F의 requires_grad가 False이니까 None이 된다. 즉, netG에 역전파가 안된다.
순전파 단계에서 computation graph를 만들고 각 텐서의 grand_fn(주황색 박스)을 저장했다가 역전파 단계에서 저장했던 grad_fn으로 각 텐서의 gradient를 chain rule로 계산한다.
Inference 단계에서 with torch.no_grad()나 detach()를 해서 메모리를 아낀다는 말은 Tensor의 grad_fn을 저장하지 않아 아낀다는 뜻이다. Training 단계와는 달리 역전파를 안 하니 grad_fn이 필요가 없다.
optimizerD.step()은 결국 netD의 파라미터만 업데이트 할 텐데 굳이 fake.detach()를 쓴 이유는?
netD를 학습하기 위해 쓴 fake를 netG를 학습하는 단계에서 재활용하기 때문이다. netG에서 만들어진 computation graph를 보존해야 한다.
참고
https://redstarhong.tistory.com/64
https://tutorials.pytorch.kr/beginner/dcgan_faces_tutorial.html
'인공지능' 카테고리의 다른 글
Gradient 추적과 그것을 멈춰야 하는 이유 (0) | 2022.10.27 |
---|---|
초거대 AI란 무엇인가? (1) | 2022.10.27 |
Multi-Modal (멀티 모달) AI (1) | 2022.10.26 |
Semantic segmentation에서 입력 데이터 전처리 하는 방법 (1) | 2022.10.25 |
Convolutional Neural Networks 기초 (0) | 2022.10.25 |