본문 바로가기

인공지능

PyTorch로 GAN을 만들 때 detach()를 사용하는 이유

GAN을 구현할 때 netD(fake.detach())를 산다. 왜 이렇게 구현하는 것일까?

 

## 가짜 데이터들로 학습을 합니다
        # 생성자에 사용할 잠재공간 벡터를 생성합니다
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # G를 이용해 가짜 이미지를 생성합니다
        fake = netG(noise)
        label.fill_(fake_label)
        # D를 이용해 데이터의 진위를 판별합니다
        output = netD(fake.detach()).view(-1)
        # D의 손실값을 계산합니다
        errD_fake = criterion(output, label)
        # 역전파를 통해 변화도를 계산합니다. 이때 앞서 구한 변화도에 더합니다(accumulate)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # 가짜 이미지와 진짜 이미지 모두에서 구한 손실값들을 더합니다
        # 이때 errD는 역전파에서 사용되지 않고, 이후 학습 상태를 리포팅(reporting)할 때 사용합니다
        errD = errD_real + errD_fake
        # D를 업데이트 합니다
        optimizerD.step()

        ############################
        # (2) G 신경망을 업데이트 합니다: log(D(G(z)))를 최대화 합니다
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # 생성자의 손실값을 구하기 위해 진짜 라벨을 이용할 겁니다
        # 우리는 방금 D를 업데이트했기 때문에, D에 다시 가짜 데이터를 통과시킵니다.
        # 이때 G는 업데이트되지 않았지만, D가 업데이트 되었기 때문에 앞선 손실값가 다른 값이 나오게 됩니다
        output = netD(fake).view(-1)
        # G의 손실값을 구합니다
        errG = criterion(output, label)
        # G의 변화도를 계산합니다
        errG.backward()
        D_G_z2 = output.mean().item()
        # G를 업데이트 합니다
        optimizerG.step()

 

fake.detach()는 netD의 weights의 requires_grad에 아무런 영향을 주지 않는다.

errD_fake.backward()가 실행되어도 fake의 grad는 여전히 None이고, detach()에 의해 grad_fn도 None으로 바뀌었기 때문에 netG의 weights도 역전파할 수 없다.

(errD_fake는 fake를 netD에 넣고 얻은 결과이다)

 

학습하고자 하는 것은 오직 weights이기 때문에 weight Tensor의 requires_grad만 True이면 된다.

 

Output : errD_fake, F : Fake

 

위의 사진에서 fake.detach()를 하면 F의 grad_fn인 $\frac{\partial F}{\partial W}$는 다 None이 되고, errD_fake.backward()를 하면 원래 $\frac{\partial Out}{\partial F}$가 chain rule에 의해 계산이 되어야 하는데 F의 requires_grad가 False이니까 None이 된다. 즉, netG에 역전파가 안된다.

 

순전파 단계에서 computation graph를 만들고 각 텐서의 grand_fn(주황색 박스)을 저장했다가 역전파 단계에서 저장했던 grad_fn으로 각 텐서의 gradient를 chain rule로 계산한다.

Inference 단계에서 with torch.no_grad()나 detach()를 해서 메모리를 아낀다는 말은 Tensor의 grad_fn을 저장하지 않아 아낀다는 뜻이다. Training 단계와는 달리 역전파를 안 하니 grad_fn이 필요가 없다.

 

optimizerD.step()은 결국 netD의 파라미터만 업데이트 할 텐데 굳이 fake.detach()를 쓴 이유는?

netD를 학습하기 위해 쓴 fake를 netG를 학습하는 단계에서 재활용하기 때문이다. netG에서 만들어진 computation graph를 보존해야 한다.

 

 

참고

https://redstarhong.tistory.com/64

 

netD(fake.detach())의 이유

GAN을 공부하다가 tutorial 코드에 netD(fake.detach())가 어떤 원리인 지 이해가 안 갔다. 의도야 설명에 나온대로 netG에 backpropagation이 안되도록, 즉 첫번째 스텝에서는 netD만 학습하려는 것이라는 건 알

redstarhong.tistory.com

https://tutorials.pytorch.kr/beginner/dcgan_faces_tutorial.html

 

DCGAN 튜토리얼

저자: Nathan Inkawhich, 번역: 조민성,. 개요: 본 튜토리얼에서는 예제를 통해 DCGAN을 알아보겠습니다. 우리는 실제 유명인들의 사진들로 적대적 생성 신경망(GAN)을 학습시켜, 새로운 유명인의 사진을

tutorials.pytorch.kr