RNN 筆記 - 損失函數與反向傳遞演算法
RNN 筆記 - 損失函數與反向傳遞演算法
Recurrent Neural Network (RNN)
Ref 吳尚鴻教授上課影片 from Youtube: https://www.youtube.com/watch?v=2btuy_-Fw3c&list=PLlPcwHqLqJDkVO0zHMqswX1jA9Xw7OSOK
Ref 吳尚鴻教授上課影片 from Youtube: https://www.youtube.com/watch?v=2btuy_-Fw3c&list=PLlPcwHqLqJDkVO0zHMqswX1jA9Xw7OSOK
Ref LaTeX Math Symbols: http://web.ift.uib.no/Teori/KURS/WRK/TeX/symALL.html
本篇討論 RNN 是怎麼 train 的!
Cost Function of Vanilla RNNs
- Maximum likelihood:
- depends only on
C 代表 Cost 作為神經網路的 Loss function
這個 Loss function 除了看不同的 data point “n” 以外,也要看時間維度上不同的 “t” 來計算
這個 Loss function 除了看不同的 data point “n” 以外,也要看時間維度上不同的 “t” 來計算
P 代表 Probability 是一個 on given Weight 的條件下的條件機率
Cost function 跟一般的 NN 使用的差異不大,只是多了一個時間維度 t
Backpropagation Through Time (BPTT)
SGD: 先隨機猜參數 “Theta” 然後去計算 partial sum of Losses 乘以 Learning Rate “eta” 再去更新參數
為求簡化 notation of Loss 的表示法,改以 c 來表示如下:
目標是去計算
and
複習一下參數的定義:
U 影響的是時間維度的 gradient
W 影響的是神經網路"層與層之間"空間維度的 gradient
U 影響的是時間維度的 gradient
W 影響的是神經網路"層與層之間"空間維度的 gradient

W 部分跟一般的神經網路相同,先不贅述
這次我們要特別來探討 U 上面的 error signal
在 forward pass 的階段:
除了原本就有的網路架構上 forward pass 之外,多了 “Forward Pass Through Time”

除了原本就有的網路架構上 forward pass 之外,多了 “Forward Pass Through Time”

We can get all second term starting from the most shallow layer and earliest time
同理
Backward Pass 的階段也多了 “Backward Pass Through Time”
Backward Pass 的階段也多了 “Backward Pass Through Time”

上圖中
- 藍色部分: 目前層的 a 會影響到深一層的所有 z
會影響到 - 紅色部分: 目前層的 a 會影響到同一層,但在下一個時間的所有 z
會影響到
化簡式子:

結論:

結論:
- 把所有深層的 W 跟 error signal 算好,然後從深層走回淺層
- 把所有最新時間點的 U 跟 error signal 算好,然後從最新時間點往回走
把以上兩者結果相加後,經過 activation function 微分後的 gradient 值,即為該神經元更新參數之依據
實務上,對於同一個 sequence n 來講
forward pass 可以 shared

forward pass: 從左下走到右上
對於每一個時間點的 Loss 都可以分別算 error signal
時間點 1 的 error signal 沿著綠色走
時間點 2 的 error signal 沿著藍色走
時間點 3 的 error signal 沿著紅色走
backward pass 全部算出來後:
對於 W(1) 來講,把三個 error signal 全部加起來,乘上 forward pass 算出的結果,得到對於整個 sequence 需更新的參數組合
forward pass 可以 shared

forward pass: 從左下走到右上
對於每一個時間點的 Loss 都可以分別算 error signal
時間點 1 的 error signal 沿著綠色走
時間點 2 的 error signal 沿著藍色走
時間點 3 的 error signal 沿著紅色走
backward pass 全部算出來後:
對於 W(1) 來講,把三個 error signal 全部加起來,乘上 forward pass 算出的結果,得到對於整個 sequence 需更新的參數組合
留言
張貼留言