Thursday, August 5, 2021

Beginning Deep Learning - Understanding Back Propagation

My gosh! The last time I spent this much time learning a particular topic was when I wanted to learn more about buffer overflows. Learning about back propagation, which is a critical component of deep learning, really required me to dig deep (pun intended). This post reflects my understanding of back propagation. My hope is that with this understanding, I can now move forward with what is hopefully easier learning (pun intended again ;-) ) as I learn more about deep learning.

Here is my topology:


The network is fed 2 and 9 as inputs, and we expect a value of 92 as output. To keep it simple, and to fit the 0 to 1 output range of the sigmoid used below, 92 is scaled to 0.92 (i.e. 92%).

Feeding forward!
Here is the topology with inputs, weights, bias, etc. As shown below, our target is 0.92 or 92%. Our network predicted 0.6445 or 64%. Let's see how we got this prediction.


Note: Later in the diagrams I accidentally dropped the biases, but they are b0=0.5, b1=0.5, b2=0.5, b3=0.2.

The first part of understanding back propagation is to understand the feed forward process. I believe this is relatively easy. Let's see this in action. Let's first find z0, then O0. This will be followed by z1, then O1, and I then wrap up the hidden layer by doing z2 and O2. Note that zx represents the weighted sum of the inputs plus the bias, while Ox represents the output after the activation function has been applied.

Hidden Layer Node 0:
z0  = (x0 * w0) + (x1 * w1) + b0
z0  = (2 * 0.15) + (9 * 0.23) + 0.5
    = 0.3 + 2.07 + 0.5
    = 2.87

With the z0 value found, it is time to compute O0 by applying an activation function. Throughout this post, I am using the sigmoid activation function, sigmoid(z) = 1 / (1 + e^(-z)), which squashes any input into the range 0 to 1. Check out Wikipedia, Casio or redcrab-software in the references if you are not sure how to compute the sigmoid.

O0 = sigmoid(z0)
O0 = sigmoid(2.87) = 0.9463
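If you would rather let a computer do the sigmoid, here is a minimal Python sketch of hidden node 0. It uses only the standard library's `math.exp`. The names `sigmoid`, `z0` and `o0` are my own choices, not taken from any library.

```python
import math

def sigmoid(z: float) -> float:
    """Logistic sigmoid: 1 / (1 + e^-z), squashes any real z into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

# Inputs, weights and bias for hidden node 0 (values from the topology above)
x0, x1 = 2.0, 9.0
w0, w1, b0 = 0.15, 0.23, 0.5

z0 = x0 * w0 + x1 * w1 + b0   # weighted sum plus bias
o0 = sigmoid(z0)              # activation

print(f"z0 = {z0:.2f}, O0 = {o0:.4f}")
```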

Perform a similar process for Hidden Layer Node 1:

z1  = (x0 * w2) + (x1 * w3) + b1
z1  = (2 * 0.5) + (9 * 0.8) + 0.5
    = 1 + 7.2 + 0.5
    = 8.7

O1 = sigmoid(z1)
O1 = sigmoid(8.7) = 0.9998

Completing the hidden layer, Hidden Layer Node 2:

z2  = (x0 * w4) + (x1 * w5) + b2
z2  = (2 * 0.05) + (9 * -0.05) + 0.5
    = 0.1 + (-0.45) + 0.5
    = 0.15

O2 = sigmoid(z2)
O2 = sigmoid(0.15) = 0.5374
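All three hidden nodes repeat the same pattern: a weighted sum plus bias, then the sigmoid. The Python sketch below expresses that as a loop. Grouping the parameters into `(weights, bias)` pairs per node in `hidden_params` is my own arrangement for illustration.

```python
import math

def sigmoid(z: float) -> float:
    """Logistic sigmoid: 1 / (1 + e^-z)."""
    return 1.0 / (1.0 + math.exp(-z))

x = [2.0, 9.0]  # inputs x0, x1

# One (weights, bias) pair per hidden node
hidden_params = [
    ([0.15, 0.23], 0.5),   # node 0: w0, w1, b0
    ([0.5, 0.8], 0.5),     # node 1: w2, w3, b1
    ([0.05, -0.05], 0.5),  # node 2: w4, w5, b2
]

# Weighted sum plus bias for each node, then the activation
z_hidden = [sum(xi * wi for xi, wi in zip(x, w)) + b for w, b in hidden_params]
o_hidden = [sigmoid(z) for z in z_hidden]

for i, (z, o) in enumerate(zip(z_hidden, o_hidden)):
    print(f"z{i} = {z:.2f}, O{i} = {o:.4f}")
```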

Moving on to the output layer.

Output Layer Node 0:

z3  = (O0 * w6) + (O1 * w7) + (O2 * w8) + b3
z3  = (0.9463 * 0.9) + (0.9998 * -0.5) + (0.5374 * 0.08) + 0.2
    = 0.8517 + (-0.4999) + (0.0430) + 0.2
    = 0.5948

O3 = sigmoid(z3)
O3 = sigmoid(0.5948) = 0.6445

Ok. After the first pass, this network produced 0.6445 or 64%. I needed it to produce 0.92 or 92%.

Setting up the Cost function, using Mean Squared Error (MSE)
Cost = (actual - predicted)**2

Since there is only one training example here, the "mean" is just this single squared error. Another way of writing it, and the notation I will use, is:

Cost = (y_true - y_predicted)**2

Using this formula, I now plug our values in to get the loss:

Cost = (0.92 - 0.6445) ** 2
    = (0.2755) ** 2
    = 0.0759
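Here is the same cost as a small Python sketch. The `mse` helper is my own. It averages over a batch, so with a single sample it reduces to the one squared error above.

```python
def mse(y_true: list[float], y_pred: list[float]) -> float:
    """Mean squared error over a batch; with one sample it is just (y - y_hat)**2."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

cost = mse([0.92], [0.6445])
print(f"Cost = {cost:.4f}")
```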

At this point, we have a cost of 0.0759. Moving on ...

The real headache (for me) Back Propagation!
I know there are many out there who would say this is easy stuff to learn. I am not mad at you. However, this took me some time, and many people may not even bother to invest that time. You might also be saying that since I was able to figure it out, it might not have been as hard as I thought. I guess the end does justify the means.

Let's get down to the meat of the matter. What we need to figure out is how the weights (w0, w1, w2, w3, w4, w5, w6, w7, w8) and biases (b0, b1, b2, b3) impact the cost. To do that, I need to take advantage of the chain rule and the computation of partial derivatives.

I will use 'd' to represent derivatives. Therefore dCost/dO3 means the derivative of the cost with respect to output 3 (O3). In other words, how does a change in output 3 impact the cost? Let's dig in.

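Quick simple note on the chain rule. If the cost depends on O3, and O3 in turn depends on z3, the chain rule says dCost/dz3 is the same as dCost/dO3 * dO3/dz3. Loosely speaking, you can treat these derivatives like fractions of small changes, where dO3 cancels out.

Breaking it down further, let's say (with made-up numbers) dCost is 3 and dz3 is 10. This is equal to:

dCost/dz3
= 3/10
= 0.3

Now let's say dO3 is 2. Note, any nonzero number chosen for dO3 should produce the same result, because it cancels out.

This means:

dCost/dz3 = dCost/dO3 * dO3/dz3
0.3 = 3/2 * 2/10
0.3 = 6/20
0.3 = 0.3

As you can see above, we got 0.3 on both sides of the equals sign.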
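To see the chain rule hold on the real network rather than made-up numbers, here is a small Python check. It is my own sketch, not from the original walkthrough. It estimates each derivative with central finite differences around the output node's z3 = 0.5948 and the target 0.92. The helper names (`cost_of_o`, `cost_of_z`) and the step size `h` are assumptions. The check uses the full squared error, so the value comes out around -0.126. That is roughly twice the -0.0631 found in Step 3, because the hand calculation in Step 1 uses O3 - y_true rather than 2 * (O3 - y_true).

```python
import math

def sigmoid(z: float) -> float:
    """Logistic sigmoid: 1 / (1 + e^-z)."""
    return 1.0 / (1.0 + math.exp(-z))

y_true = 0.92
z3 = 0.5948   # output node's weighted sum from the forward pass
h = 1e-6      # small step for central finite differences (assumed value)

def cost_of_o(o: float) -> float:
    """Cost as a function of the output O3 (full squared error)."""
    return (y_true - o) ** 2

def cost_of_z(z: float) -> float:
    """Cost as a function of z3, going through the sigmoid."""
    return cost_of_o(sigmoid(z))

o3 = sigmoid(z3)
dcost_do3 = (cost_of_o(o3 + h) - cost_of_o(o3 - h)) / (2 * h)
do3_dz3 = (sigmoid(z3 + h) - sigmoid(z3 - h)) / (2 * h)
dcost_dz3 = (cost_of_z(z3 + h) - cost_of_z(z3 - h)) / (2 * h)

print(f"direct:     dCost/dz3           = {dcost_dz3:.6f}")
print(f"chain rule: dCost/dO3 * dO3/dz3 = {dcost_do3 * do3_dz3:.6f}")
```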

With that basic understanding, time to build further.

Here are the first 6 steps, where ultimately I solve for how w6, w7 and w8 impact the cost. Looking at it from a different perspective, I am finding dCost/dw6, dCost/dw7 and dCost/dw8.



Let's find dCost/dz3 = dCost/dO3 * dO3/dz3

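First, step 1. Strictly, the derivative of (y_true - O3)**2 with respect to O3 is 2 * (O3 - y_true). Like many tutorials, I drop the constant 2 here, which is equivalent to using Cost = 0.5 * (y_true - y_predicted)**2. This only scales every gradient by the same factor, which the learning rate absorbs.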
dCost/dO3 = O3 - y_true
    = 0.6445 - 0.92 
    = -0.2755


Step 2:
dO3/dz3 = O3 * (1 - O3)
    = 0.6445 * (1 - 0.6445)
    = 0.6445 * 0.3555
    = 0.2291

With the above calculated, finding dCost/dz3 is simply a matter of multiplying dCost/dO3 * dO3/dz3.

Step 3:
dCost/dz3 = dCost/dO3 * dO3/dz3
dCost/dz3 = -0.2755 * 0.2291
    = -0.0631
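Steps 1 to 3 as a Python sketch, starting from the rounded output 0.6445. Like the text, it uses O3 - y_true for Step 1, so the constant factor 2 is dropped. The variable name `delta3` is my own shorthand for dCost/dz3.

```python
y_true = 0.92
o3 = 0.6445  # network output from the forward pass

dcost_do3 = o3 - y_true         # Step 1 (constant factor 2 dropped, as in the text)
do3_dz3 = o3 * (1 - o3)         # Step 2: sigmoid derivative written via its output
delta3 = dcost_do3 * do3_dz3    # Step 3: chain rule, dCost/dz3

print(f"dCost/dO3 = {dcost_do3:.4f}")
print(f"dO3/dz3   = {do3_dz3:.4f}")
print(f"dCost/dz3 = {delta3:.4f}")
```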

This value also represents dCost/db3, because z3 changes one-for-one with b3 (dz3/db3 = 1):
dCost/db3 = -0.0631

With those out of the way, time to compute how w6, w7 and w8 impact the cost. 

Next, step 4. Find dCost/dw6

Weight 6:
dCost/dw6 = dCost/dz3 * dz3/dw6
    = dCost/dz3 * O0
    = -0.0631 * 0.9463
    = -0.0597


Step 5. Find dCost/dw7

Weight 7:
dCost/dw7 = dCost/dz3 * dz3/dw7
    = dCost/dz3 * O1
    = -0.0631 * 0.9998
    = -0.0631


Step 6. Find dCost/dw8

Weight 8:
dCost/dw8 = dCost/dz3 * dz3/dw8
    = dCost/dz3 * O2
    = -0.0631 * 0.5374
    = -0.0339
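Steps 4 to 6 share one rule. Since dz3/dw6 = O0, dz3/dw7 = O1 and dz3/dw8 = O2, each output-weight gradient is dCost/dz3 times the hidden output that the weight multiplies. Here is a short Python sketch using the rounded values from above. The dictionary `grads_output` is my own naming.

```python
delta3 = -0.0631                     # dCost/dz3 from Step 3
o_hidden = [0.9463, 0.9998, 0.5374]  # O0, O1, O2 from the forward pass

# dCost/dw(6+i) = dCost/dz3 * O_i
grads_output = {f"w{6 + i}": delta3 * o for i, o in enumerate(o_hidden)}

for name, g in grads_output.items():
    print(f"dCost/d{name} = {g:.4f}")
```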

Next up: how do w0, w1, w2, w3, w4 and w5 impact the cost? To reach them, the chain rule has to pass back through z3 and each hidden node's output.

Step 7 : Finding dCost/dO0
dCost/dO0 = dCost/dz3 * dz3/dO0
    =  dCost/dz3 * w6
    = -0.0631 * 0.9
    = -0.0568


Step 8 : Finding dO0/dz0
    dO0/dz0 = O0 * (1 - O0)
        = 0.9463 * (1 - 0.9463)
        = 0.9463 * 0.0537
        = 0.0508


Step 9 : Finding dCost/dz0
    dCost/dz0 = dCost/dO0 * dO0/dz0
        = -0.0568 * 0.0508
        = -0.0029

This value is also dCost/db0
dCost/db0 = -0.0029


Step 10 : Finding dCost/dw0
    dCost/dw0 = dCost/dz0 * dz0/dw0
        = dCost/dz0 * Input0
        = -0.0029 * 2
        = -0.0058


Step 11 : Finding dCost/dw1
    dCost/dw1 = dCost/dz0 * dz0/dw1
        = dCost/dz0 * Input1
        = -0.0029 * 9
        = -0.0261
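Steps 7 to 11 for hidden node 0, as a Python sketch. This version does not round between steps, so dCost/dw1 comes out at about -0.0260 rather than the hand-computed -0.0261. The difference comes only from rounding dCost/dz0 to -0.0029 before multiplying by 9. The name `delta0` is my own shorthand for dCost/dz0.

```python
delta3 = -0.0631   # dCost/dz3 from Step 3
w6 = 0.9           # weight from hidden node 0 to the output node
o0 = 0.9463        # O0 from the forward pass
x0, x1 = 2.0, 9.0  # network inputs

dcost_do0 = delta3 * w6       # Step 7: dz3/dO0 = w6
do0_dz0 = o0 * (1 - o0)       # Step 8: sigmoid derivative
delta0 = dcost_do0 * do0_dz0  # Step 9: dCost/dz0 (also dCost/db0)
dcost_dw0 = delta0 * x0       # Step 10: dz0/dw0 = x0
dcost_dw1 = delta0 * x1       # Step 11: dz0/dw1 = x1

print(f"dCost/dz0 = {delta0:.4f}, dCost/dw0 = {dcost_dw0:.4f}, dCost/dw1 = {dcost_dw1:.4f}")
```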




Step 12 : Finding dCost/dO1
dCost/dO1 = dCost/dz3 * dz3/dO1
    =  dCost/dz3 * w7
    = -0.0631 * -0.5
    = 0.0316



Step 13 : Finding dO1/dz1
    dO1/dz1 = O1 * (1 - O1)
        = 0.9998 * (1 - 0.9998)
        = 0.9998 * 0.0002
        = 0.0002


Step 14 : Finding dCost/dz1
    dCost/dz1 = dCost/dO1 * dO1/dz1
        = 0.0316 * 0.0002
        = 0.0000


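This value is also dCost/db1. It is essentially zero because O1 is so close to 1 that the sigmoid is saturated, so dO1/dz1 is only about 0.0002.
dCost/db1 = 0.0000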


Step 15 : Finding dCost/dw2
    dCost/dw2 = dCost/dz1 * dz1/dw2
        = dCost/dz1 * Input0
        = 0.0000 * 2
        = 0.0000


Step 16 : Finding dCost/dw3
    dCost/dw3 = dCost/dz1 * dz1/dw3
        = dCost/dz1 * Input1
        = 0.0000 * 9
        = 0.0000
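The near-zero gradients for node 1 are a sigmoid saturation effect. The Python sketch below redoes Steps 12 to 16 with O1 = sigmoid(8.7) at full precision instead of the rounded 0.9998. The gradients turn out tiny and positive, not exactly zero. Printing in scientific notation is just my choice so that the small values stay visible.

```python
import math

def sigmoid(z: float) -> float:
    """Logistic sigmoid: 1 / (1 + e^-z)."""
    return 1.0 / (1.0 + math.exp(-z))

delta3 = -0.0631   # dCost/dz3 from Step 3
w7 = -0.5          # weight from hidden node 1 to the output node
x0, x1 = 2.0, 9.0  # network inputs

o1 = sigmoid(8.7)             # O1 at full precision
dcost_do1 = delta3 * w7       # Step 12: positive, because w7 is negative
do1_dz1 = o1 * (1 - o1)       # Step 13: tiny, because the sigmoid is saturated
delta1 = dcost_do1 * do1_dz1  # Step 14: dCost/dz1 (also dCost/db1)
dcost_dw2 = delta1 * x0       # Step 15
dcost_dw3 = delta1 * x1       # Step 16

print(f"dO1/dz1   = {do1_dz1:.2e}")
print(f"dCost/dz1 = {delta1:.2e}")
print(f"dCost/dw2 = {dcost_dw2:.2e}, dCost/dw3 = {dcost_dw3:.2e}")
```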




Step 17 : Finding dCost/dO2
dCost/dO2 = dCost/dz3 * dz3/dO2
    =  dCost/dz3 * w8
    = -0.0631 * 0.08
    = -0.0050


Step 18 : Finding dO2/dz2
    dO2/dz2 = O2 * (1 - O2)
        = 0.5374 * (1 - 0.5374)
        = 0.5374 * 0.4626
        = 0.2486


Step 19 : Finding dCost/dz2
    dCost/dz2 = dCost/dO2 * dO2/dz2
     =  -0.0050 * 0.2486
        = -0.0012
        
This value is also dCost/db2
dCost/db2 = -0.0012


Step 20 : Finding dCost/dw4
    dCost/dw4 = dCost/dz2 * dz2/dw4
        = dCost/dz2 * Input0
        = -0.0012 * 2
        = -0.0024


Step 21 : Finding dCost/dw5
    dCost/dw5 = dCost/dz2 * dz2/dw5
        = dCost/dz2 * Input1
        = -0.0012 * 9
        = -0.0108
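A common way to make sure hand-derived gradients like Steps 1 to 21 are right is a numerical gradient check. Nudge each parameter up and down, measure how the cost changes, and compare against the analytic formulas. The Python sketch below is my own. The function names (`forward`, `backprop`, `numerical_grads`) and the step size are assumptions. It uses Cost = 0.5 * (y_true - O3)**2, whose derivative is exactly the O3 - y_true used in Step 1, so both sides should agree to many decimal places.

```python
import math

def sigmoid(z: float) -> float:
    """Logistic sigmoid: 1 / (1 + e^-z)."""
    return 1.0 / (1.0 + math.exp(-z))

X = [2.0, 9.0]
Y_TRUE = 0.92
# Parameters named as in the post: w0..w8 and b0..b3
params = {"w0": 0.15, "w1": 0.23, "w2": 0.5, "w3": 0.8, "w4": 0.05, "w5": -0.05,
          "w6": 0.9, "w7": -0.5, "w8": 0.08, "b0": 0.5, "b1": 0.5, "b2": 0.5, "b3": 0.2}

def forward(p):
    """Return (hidden outputs [O0, O1, O2], output O3)."""
    o = [sigmoid(X[0] * p[f"w{2*i}"] + X[1] * p[f"w{2*i+1}"] + p[f"b{i}"]) for i in range(3)]
    o3 = sigmoid(sum(o[i] * p[f"w{6+i}"] for i in range(3)) + p["b3"])
    return o, o3

def cost(p):
    """Half squared error, so dCost/dO3 = O3 - y_true as in Step 1."""
    return 0.5 * (Y_TRUE - forward(p)[1]) ** 2

def backprop(p):
    """Analytic gradients following Steps 1 to 21."""
    o, o3 = forward(p)
    delta3 = (o3 - Y_TRUE) * o3 * (1 - o3)                  # Steps 1-3: dCost/dz3
    grads = {"b3": delta3}
    for i in range(3):
        grads[f"w{6+i}"] = delta3 * o[i]                     # Steps 4-6
        delta_i = delta3 * p[f"w{6+i}"] * o[i] * (1 - o[i])  # dCost/dz_i
        grads[f"b{i}"] = delta_i
        grads[f"w{2*i}"] = delta_i * X[0]
        grads[f"w{2*i+1}"] = delta_i * X[1]
    return grads

def numerical_grads(p, h=1e-6):
    """Central finite differences, one parameter at a time."""
    grads = {}
    for name in p:
        plus, minus = dict(p), dict(p)
        plus[name] += h
        minus[name] -= h
        grads[name] = (cost(plus) - cost(minus)) / (2 * h)
    return grads

analytic = backprop(params)
numeric = numerical_grads(params)
max_diff = max(abs(analytic[k] - numeric[k]) for k in params)
for k in params:
    print(f"{k}: analytic {analytic[k]: .6f}  numeric {numeric[k]: .6f}")
print(f"max difference: {max_diff:.2e}")
```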


Here is the final diagram.



Ahhhhhhhhhhhhh! The heavy lifting has been completed. Next up, calculate the new weights and biases.

The formula for the new weights and biases is as follows:
new_weight = old_weight - learning_rate * (dCost/dw_n), where dw_n represents dw0, dw1, dw2, ..., dw8. Biases are updated the same way using dCost/db_n. The learning rate is typically a small positive value, often between 0 and 1. I will use 0.5. Hence the new weights are:
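The update rule is the same subtraction applied to every parameter. Here is a minimal Python sketch of it, assuming parameters and gradients are stored in dictionaries keyed by name. The function name `gradient_descent_step` is hypothetical.

```python
def gradient_descent_step(params: dict[str, float], grads: dict[str, float],
                          learning_rate: float = 0.5) -> dict[str, float]:
    """new_value = old_value - learning_rate * gradient, for every parameter."""
    return {name: value - learning_rate * grads[name] for name, value in params.items()}

# Example with w0 from the walkthrough: dCost/dw0 = -0.0058
updated = gradient_descent_step({"w0": 0.15}, {"w0": -0.0058})
print(updated)
```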

new_w0 = old_w0 - 0.5(dCost/dw0)
    = 0.15 - 0.5(-0.0058) 
    = 0.15 - (-0.0029)
    = 0.1529

new_w1 = old_w1 - 0.5(dCost/dw1)
    = 0.23 - 0.5(-0.0261) 
    = 0.23 - (-0.01305)
    = 0.2431

new_w2 = old_w2 - 0.5(dCost/dw2)
    = 0.5 - 0.5(0.0000)
    = 0.5 - (0)
    = 0.5


new_w3 = old_w3 - 0.5(dCost/dw3)
    = 0.8 - 0.5(0.0000)
    = 0.8 - (0)
    = 0.8


new_w4 = old_w4 - 0.5(dCost/dw4)
    = 0.05 - 0.5(-0.0024) 
    = 0.05 - (-0.0012)
    = 0.0512


new_w5 = old_w5 - 0.5(dCost/dw5)
    = -0.05 - 0.5(-0.0108) 
    = -0.05 - (-0.0054)
    = -0.0446


new_w6 = old_w6 - 0.5(dCost/dw6)
    = 0.9 - 0.5(-0.0597) 
    = 0.9 - (-0.02985)
    = 0.9299


new_w7 = old_w7 - 0.5(dCost/dw7)
    = -0.5 - 0.5(-0.0631) 
    = -0.5 - (-0.03155)
    = -0.4685


new_w8 = old_w8 - 0.5(dCost/dw8)
    = 0.08 - 0.5(-0.0339) 
    = 0.08 - (-0.01695)
    = 0.09695


new_b0 = old_b0 - 0.5(dCost/db0)
    = 0.5 - 0.5(-0.0029) 
    = 0.5 - (-0.00145)
    = 0.50145


new_b1 = old_b1 - 0.5(dCost/db1)
    = 0.5 - 0.5(0.0000)
    = 0.5 - 0
    = 0.5


new_b2 = old_b2 - 0.5(dCost/db2)
    = 0.5 - 0.5(-0.0012)
    = 0.5 - (-0.0006)
    = 0.5006


new_b3 = old_b3 - 0.5(dCost/db3)
    = 0.2 - 0.5(-0.0631) 
    = 0.2 - (-0.03155)
    = 0.23155
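A quick way to confirm that the updates moved things in the right direction is to run the forward pass again with the new parameters. The Python sketch below applies the update to all 13 parameters using the gradients from Steps 1 to 21, with dCost/db1 taken as 0 and dCost/db2 as -0.0012 from Steps 14 and 19. It then compares prediction and cost before and after. The `predict` helper is my own, and the cost printed is the post's full squared error.

```python
import math

def sigmoid(z: float) -> float:
    """Logistic sigmoid: 1 / (1 + e^-z)."""
    return 1.0 / (1.0 + math.exp(-z))

def predict(p: dict[str, float], x0: float = 2.0, x1: float = 9.0) -> float:
    """Forward pass through the 2-3-1 network described in the post."""
    o = [sigmoid(x0 * p[f"w{2*i}"] + x1 * p[f"w{2*i+1}"] + p[f"b{i}"]) for i in range(3)]
    return sigmoid(sum(o[i] * p[f"w{6+i}"] for i in range(3)) + p["b3"])

old = {"w0": 0.15, "w1": 0.23, "w2": 0.5, "w3": 0.8, "w4": 0.05, "w5": -0.05,
       "w6": 0.9, "w7": -0.5, "w8": 0.08, "b0": 0.5, "b1": 0.5, "b2": 0.5, "b3": 0.2}
# Gradients from Steps 1-21, rounded to 4 decimals as in the hand calculation
grads = {"w0": -0.0058, "w1": -0.0261, "w2": 0.0, "w3": 0.0, "w4": -0.0024, "w5": -0.0108,
         "w6": -0.0597, "w7": -0.0631, "w8": -0.0339,
         "b0": -0.0029, "b1": 0.0, "b2": -0.0012, "b3": -0.0631}
learning_rate = 0.5

# new_value = old_value - learning_rate * gradient
new = {k: v - learning_rate * grads[k] for k, v in old.items()}

y_true = 0.92
for label, p in (("before", old), ("after", new)):
    y_pred = predict(p)
    print(f"{label}: prediction {y_pred:.4f}, cost {(y_true - y_pred) ** 2:.4f}")
```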

Sighhhhhh!!!! This was a tedious process and really took some patience on my part. I'm happy that I was able to learn this. With this completed, I believe my learning process (pun intended) becomes a lot easier.

See you in the next post, where I code this up to solidify my basic understanding.


References
