My gosh! The last time I spent this much time learning a particular topic was when I wanted to learn more about buffer overflows. Learning about back propagation, which is a critical component of deep learning, really required me to dig deep (pun intended). This post reflects my understanding of back propagation. My hope is that with this understanding, I can now move forward with what is hopefully easier learning (pun intended again ;-) ) as I learn more about deep learning.
Here is my topology:
The network is being fed 2 and 9 as inputs, and we are expecting a value of 92 as output. This 92 can be expressed as a percentage, 0.92, to keep things simple.
Feeding forward!
Here is the topology with inputs, weights, bias, etc. As shown below, our target is 0.92 or 92%. Our network predicted 0.6445 or 64%. Let's see how we got this prediction.
Note: Later in the diagrams I accidentally dropped the biases, but they are b0=0.5, b1=0.5, b2=0.5, b3=0.2.
The first part of understanding back propagation is to understand the feed forward process. I believe this is relatively easy. Let's see this in action. Let's first find z0, then O0. This will be followed by z1 then O1, and I will then wrap up the hidden layer by doing z2 and O2. Note that zx represents the weighted sum, while Ox represents the output after the activation function has been applied.
Hidden Layer Node 0:
z0 = (x0 * w0) + (x1 * w1) + b
z0 = (2 * 0.15) + (9 * 0.23) + 0.5 = 0.3 + 2.07 + 0.5 = 2.87
With the z0 value found, it's time to compute O0 by applying an activation function. Throughout this post, I am using the Sigmoid activation function. Check out the Wikipedia, Casio, or RedCrab Software links in the references if you are not sure how to compute the sigmoid.
O0 = sigmoid(z0)
O0 = sigmoid(2.87) = 0.9463
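If you are following along, a couple of lines of Python will reproduce these numbers (a minimal sketch; the function and variable names are my own):

```python
import math

def sigmoid(z):
    # Sigmoid activation: squashes any real value into the range (0, 1)
    return 1 / (1 + math.exp(-z))

z0 = (2 * 0.15) + (9 * 0.23) + 0.5  # weighted sum for hidden node 0
O0 = sigmoid(z0)
print(z0, round(O0, 4))  # 2.87 0.9463
```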
Perform a similar process for Hidden Layer Node 1:
z1 = (x0 * w2) + (x1 * w3) + b
z1 = (2 * 0.5) + (9 * 0.8) + 0.5 = 1 + 7.2 + 0.5 = 8.7
O1 = sigmoid(z1)
O1 = sigmoid(8.7) = 0.9998
Completing the hidden layer, here is Hidden Layer Node 2:
z2 = (x0 * w4) + (x1 * w5) + b
z2 = (2 * 0.05) + (9 * -0.05) + 0.5 = 0.1 + (-0.45) + 0.5 = 0.15
O2 = sigmoid(z2)
O2 = sigmoid(0.15) = 0.5374
Moving on to the output layer.
Output Layer Node 0:
z3 = (O0 * w6) + (O1 * w7) + (O2 * w8) + b
z3 = (0.9463 * 0.9) + (0.9998 * -0.5) + (0.5374 * 0.08) + 0.2 = 0.8517 + (-0.4999) + 0.0430 + 0.2 = 0.5948
O3 = sigmoid(z3)
O3 = sigmoid(0.5948) = 0.6445
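Continuing the Python sketch, the whole feed-forward pass fits in a few lines (the weight and bias lists are my own bookkeeping for the values in the diagram):

```python
# Inputs, weights, and biases from the topology diagram
x0, x1 = 2, 9
w = [0.15, 0.23, 0.5, 0.8, 0.05, -0.05, 0.9, -0.5, 0.08]  # w0..w8
b = [0.5, 0.5, 0.5, 0.2]                                  # b0..b3

# Hidden layer
O0 = sigmoid(x0 * w[0] + x1 * w[1] + b[0])  # 0.9463
O1 = sigmoid(x0 * w[2] + x1 * w[3] + b[1])  # 0.9998
O2 = sigmoid(x0 * w[4] + x1 * w[5] + b[2])  # 0.5374

# Output layer
O3 = sigmoid(O0 * w[6] + O1 * w[7] + O2 * w[8] + b[3])
print(round(O3, 4))  # 0.6445
```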
Ok. After the first pass, this network produced 0.6445 or 64%. I needed it to produce 0.92 or 92%.
Setting up the Cost function, using Mean Squared Error (MSE)
Cost = (actual - predicted)**2
Another way of looking at this, and what I will use, is:
Cost = (y_true - y_predicted)**2
Using this formula, I now plug our values in to get the loss:
Cost = (0.92 - 0.6445) ** 2 = (0.2755) ** 2 = 0.0759
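Continuing the sketch, the cost for this single example is one line of Python:

```python
y_true = 0.92
cost = (y_true - O3) ** 2  # squared error for this one training example
print(round(cost, 4))      # 0.0759
```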
At this point, we have a cost of 0.0759. Moving on ...
The real headache (for me): Back Propagation!
I know there are many out there who would say this is easy stuff to learn. I am not mad at you. However, this took me some time, and many people may not even bother to invest that time. I know you might also be saying that since I was able to figure it out, it might not have been as hard as I thought. I guess the end does justify the means.
Let's get down to the meat of the matter. What we need to figure out is how the weights (w0, w1, w2, w3, w4, w5, w6, w7, w8) impact the cost. To do that, I need to take advantage of the chain rule and the computation of partial derivatives.
I will use 'd' to represent derivatives. Therefore dCost/dO3 means the derivative of the cost with respect to output 3 (O3). In other words, how does a change in output 3 impact the cost? Let's dig in.
Breaking it down further, let's say dCost is 3 and dO3 is 2. This is equal to:
dCost/dO3 = 3/2 = 1.5
Now let's say dw0 is 10. Note that any number chosen should produce the same result. This means:
dCost/dO3 = dCost/dw0 * dw0/dO3
1.5 = 3/10 * 10/2
1.5 = 30/20
1.5 = 1.5
As you can see above, we got 1.5 on both sides of the equation.
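A quick way to convince yourself of that cancellation is to check it in Python with the same made-up numbers:

```python
dCost, dO3, dw0 = 3, 2, 10  # the made-up quantities from the analogy above

lhs = dCost / dO3                  # 1.5
rhs = (dCost / dw0) * (dw0 / dO3)  # (3/10) * (10/2) = 1.5
print(lhs, rhs)  # 1.5 1.5 -- the dw0 terms cancel out
```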
With that basic understanding, time to build further.
Here are the first 6 steps, where ultimately I am solving for w6, w7 and w8. Looking at it from a different perspective: dCost/dw6, dCost/dw7 and dCost/dw8.
Let's find dCost/dz3 = dCost/dO3 * dO3/dz3
First, step 1:
dCost/dO3 = O3 - y_true = 0.6445 - 0.92 = -0.2755
(Strictly speaking, the derivative of the squared error is 2 * (O3 - y_true); dropping the constant 2 is a common simplification and does not change the direction of the update.)
Step 2:
dO3/dz3 = O3 * (1 - O3) = 0.6445 * (1 - 0.6445) = 0.6445 * 0.3555 = 0.2291
With the above calculated, finding dCost/dz3 is simply a matter of multiplying dCost/dO3 by dO3/dz3.
Step 3:
dCost/dz3 = dCost/dO3 * dO3/dz3
dCost/dz3 = -0.2755 * 0.2291 = -0.0631
This value also represents dCost/db3:
dCost/db3 = -0.0631
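Continuing the Python sketch, steps 1 through 3 look like this (`sigmoid_prime` is my own helper name for the sigmoid derivative):

```python
def sigmoid_prime(o):
    # Derivative of the sigmoid, written in terms of the sigmoid's output o
    return o * (1 - o)

dCost_dO3 = O3 - y_true          # step 1: -0.2755
dO3_dz3 = sigmoid_prime(O3)      # step 2:  0.2291
dCost_dz3 = dCost_dO3 * dO3_dz3  # step 3: -0.0631 (this is also dCost/db3)
```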
With those out of the way, time to compute how w6, w7 and w8 impact the cost.
Next, step 4. Find dCost/dw6
Weight 6: dCost/dw6 = dCost/dz3 * dz3/dw6 = dCost/dz3 * O0 = -0.0631 * 0.9463 = -0.0597
Step 5. Find dCost/dw7
Weight 7: dCost/dw7 = dCost/dz3 * dz3/dw7 = dCost/dz3 * O1 = -0.0631 * 0.9998 = -0.0631
Step 6. Find dCost/dw8
Weight 8: dCost/dw8 = dCost/dz3 * dz3/dw8 = dCost/dz3 * O2 = -0.0631 * 0.5374 = -0.0339
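In code, these three gradients are just dCost_dz3 scaled by each hidden node's output:

```python
dCost_dw6 = dCost_dz3 * O0   # step 4: -0.0597
dCost_dw7 = dCost_dz3 * O1   # step 5: -0.0631
dCost_dw8 = dCost_dz3 * O2   # step 6: -0.0339
```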
Step 7: Finding dCost/dO0
dCost/dO0 = dCost/dz3 * dz3/dO0 = dCost/dz3 * w6 = -0.0631 * 0.9 = -0.0568
Step 8: Finding dO0/dz0
dO0/dz0 = O0 * (1 - O0) = 0.9463 * (1 - 0.9463) = 0.9463 * 0.0537 = 0.0508
Step 9: Finding dCost/dz0
dCost/dz0 = dCost/dO0 * dO0/dz0 = -0.0568 * 0.0508 = -0.0029
This value is also dCost/db0: dCost/db0 = -0.0029
Step 10: Finding dCost/dw0
dCost/dw0 = dCost/dz0 * dz0/dw0 = dCost/dz0 * Input0 = -0.0029 * 2 = -0.0058
Step 11: Finding dCost/dw1
dCost/dw1 = dCost/dz0 * dz0/dw1 = dCost/dz0 * Input1 = -0.0029 * 9 = -0.0261
Step 12: Finding dCost/dO1
dCost/dO1 = dCost/dz3 * dz3/dO1 = dCost/dz3 * w7 = -0.0631 * -0.5 = 0.0316
Step 13: Finding dO1/dz1
dO1/dz1 = O1 * (1 - O1) = 0.9998 * (1 - 0.9998) = 0.9998 * 0.0002 = 0.0002
Step 14: Finding dCost/dz1
dCost/dz1 = dCost/dO1 * dO1/dz1 = 0.0316 * 0.0002 = 0.0000 (effectively zero)
This value is also dCost/db1: dCost/db1 = 0.0000
Step 15: Finding dCost/dw2
dCost/dw2 = dCost/dz1 * dz1/dw2 = dCost/dz1 * Input0 = 0.0000 * 2 = 0.0000
Step 16: Finding dCost/dw3
dCost/dw3 = dCost/dz1 * dz1/dw3 = dCost/dz1 * Input1 = 0.0000 * 9 = 0.0000
Step 17: Finding dCost/dO2
dCost/dO2 = dCost/dz3 * dz3/dO2 = dCost/dz3 * w8 = -0.0631 * 0.08 = -0.0050
Step 18: Finding dO2/dz2
dO2/dz2 = O2 * (1 - O2) = 0.5374 * (1 - 0.5374) = 0.5374 * 0.4626 = 0.2486
Step 19: Finding dCost/dz2
dCost/dz2 = dCost/dO2 * dO2/dz2 = -0.0050 * 0.2486 = -0.0012
This value is also dCost/db2: dCost/db2 = -0.0012
Step 20: Finding dCost/dw4
dCost/dw4 = dCost/dz2 * dz2/dw4 = dCost/dz2 * Input0 = -0.0012 * 2 = -0.0024
Step 21: Finding dCost/dw5
dCost/dw5 = dCost/dz2 * dz2/dw5 = dCost/dz2 * Input1 = -0.0012 * 9 = -0.0108
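Notice that steps 7-21 repeat the same pattern for each hidden node, so in the Python sketch they collapse into a loop (the list and index bookkeeping are my own):

```python
# One entry per hidden node: (node output, weight from that node to the output layer)
hidden = [(O0, w[6]), (O1, w[7]), (O2, w[8])]

grads = []
for O_h, w_out in hidden:
    dCost_dOh = dCost_dz3 * w_out               # steps 7, 12, 17
    dCost_dzh = dCost_dOh * sigmoid_prime(O_h)  # steps 9, 14, 19 (also dCost/db for the node)
    # dCost/dw for each incoming weight, plus the bias gradient
    grads.append((dCost_dzh * x0, dCost_dzh * x1, dCost_dzh))

# grads[0] ≈ (-0.0058, -0.026, -0.0029), i.e. dCost/dw0, dCost/dw1, dCost/db0
```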
Here is the final diagram.
Ahhhhhhhhhhhhh! The heavy lifting has been completed. Next up, calculate the new weights and biases.
The formulas for the new weights and biases are as follows:
new_weight = old_weight - learning_rate * (dCost/dw_n), where dw_n represents dw0, dw1, dw2, ..., dw8 (and likewise for the biases). The learning rate is typically a value between 0 and 1; I will use 0.5. Hence the new weights and biases are:
new_w0 = old_w0 - 0.5(dCost/dw0) = 0.15 - 0.5(-0.0058) = 0.15 - (-0.0029) = 0.1529
new_w1 = old_w1 - 0.5(dCost/dw1) = 0.23 - 0.5(-0.0261) = 0.23 - (-0.01305) = 0.2431
new_w2 = old_w2 - 0.5(dCost/dw2) = 0.5 - 0.5(0.0000) = 0.5 - 0 = 0.5
new_w3 = old_w3 - 0.5(dCost/dw3) = 0.8 - 0.5(0.0000) = 0.8 - 0 = 0.8
new_w4 = old_w4 - 0.5(dCost/dw4) = 0.05 - 0.5(-0.0024) = 0.05 - (-0.0012) = 0.0512
new_w5 = old_w5 - 0.5(dCost/dw5) = -0.05 - 0.5(-0.0108) = -0.05 - (-0.0054) = -0.0446
new_w6 = old_w6 - 0.5(dCost/dw6) = 0.9 - 0.5(-0.0597) = 0.9 - (-0.02985) = 0.9299
new_w7 = old_w7 - 0.5(dCost/dw7) = -0.5 - 0.5(-0.0631) = -0.5 - (-0.03155) = -0.4685
new_w8 = old_w8 - 0.5(dCost/dw8) = 0.08 - 0.5(-0.0339) = 0.08 - (-0.01695) = 0.09695
new_b0 = old_b0 - 0.5(dCost/db0) = 0.5 - 0.5(-0.0029) = 0.5 - (-0.00145) = 0.50145
new_b1 = old_b1 - 0.5(dCost/db1) = 0.5 - 0.5(0.0000) = 0.5 - 0 = 0.5
new_b2 = old_b2 - 0.5(dCost/db2) = 0.5 - 0.5(-0.0012) = 0.5 - (-0.0006) = 0.5006
new_b3 = old_b3 - 0.5(dCost/db3) = 0.2 - 0.5(-0.0631) = 0.2 - (-0.03155) = 0.23155
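As a final piece of the Python sketch, here is the same update rule applied to every parameter (reusing the gradients computed above):

```python
lr = 0.5  # learning rate

# Output layer updates
w[6] -= lr * dCost_dw6   # 0.9   -> ~0.9299
w[7] -= lr * dCost_dw7   # -0.5  -> ~-0.4685
w[8] -= lr * dCost_dw8   # 0.08  -> ~0.0970
b[3] -= lr * dCost_dz3   # 0.2   -> ~0.2316

# Hidden layer updates: one (dCost/dw, dCost/dw, dCost/db) triple per node
for i, (dw_a, dw_b, db_h) in enumerate(grads):
    w[2 * i]     -= lr * dw_a   # w0, w2, w4
    w[2 * i + 1] -= lr * dw_b   # w1, w3, w5
    b[i]         -= lr * db_h   # b0, b1, b2
```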
Sighhhhhh!!!! This was a tedious process and really took some patience on my part. I'm happy that I was able to learn this. With this completed, I believe my learning process (pun intended) becomes a lot easier.
See you in the next post, where I code this just to get a basic understanding.
References