Sangmun

RNN BPTT(Back Propagation Through Time) 본문

네이버 AI 부스트캠프 4기

RNN BPTT(Back Propagation Through Time)

상상2 2022. 9. 25. 20:07

두번째 심화과제인 Rnn BPTT의 구현이다. 현재의 포스팅에서는 수식만 다룰 예정이다.

 

먼저 Rnn BPTT를 이해하는데 도움이 되는 공식이다.

위공식에서 Loss function을 미분하려면 결국은 θ_1에 대하여 미분을 하여야 한다.

그리고 θ_1에 대하여 미분을 하는 경로는 h_2,1,h_2,2두가지 경로가 있는데 θ_1에 대하여 미분을 실시하려면 h_2,1,h_2,2을 미분하여 나온값을 더해줘야 된다는 말이다.

 

1) Many to one

 

먼저 하나의 출력값을 가지는 Rnn의 BPTT를 다뤄보려고 한다. 본 예제에서 타임스탭은 4이다.

Many to one Rnn
타임스탭은 4이다.

 

y_hat은 Rnn의 출력값에서 분류를 위한 layer를 하나더 거친 값을 의미한다.

h_t 와 y_hat 그리고  Loss function

 

θ는 Φ(파이)와 ψ(프사이)로 구성되어있고 Φ(파이)는 Rnn의 출력값에서 분류를 위한 추가적인 Layer의 weight를 의미한다.

ψ(프사이)는 Rnn 네트워크의 weight를 의미한다.

 

Loss Function에 대한 미분공식 Φ(파이)로 미분한다.

 

Rnn 네트워크의 backpropagation ψ(프사이)로 미분한다.

 

2) Many to Many

 

다음은 Many to Many의 케이스에 대한 BPTT이다.

h_t와 y_hat, θ에 대한 내용은 many to one 네트워크와 동일하며 Loss function과 Rnn 네트워크에 대한 BPTT 내용은 다르게 적용된다.

Many to many 네트워크의 Loss function

 

Many to many 네트워크의 backpropagation

 

Comments