<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="4.3.4">Jekyll</generator><link href="https://princeemensah.github.io/feed.xml" rel="self" type="application/atom+xml" /><link href="https://princeemensah.github.io/" rel="alternate" type="text/html" /><updated>2026-05-29T08:31:36+00:00</updated><id>https://princeemensah.github.io/feed.xml</id><title type="html">Prince Mensah</title><subtitle>Passionate about developing end-to-end AI solutions and solving real-world problems through intelligent automation.
</subtitle><entry><title type="html">Implementing Neural Network from scratch-Part 2 (Softmax Classification)</title><link href="https://princeemensah.github.io/blog/2024/08/neural-net2.html" rel="alternate" type="text/html" title="Implementing Neural Network from scratch-Part 2 (Softmax Classification)" /><published>2024-08-16T09:46:13+00:00</published><updated>2024-08-16T09:46:13+00:00</updated><id>https://princeemensah.github.io/blog/2024/08/neural-net2</id><content type="html" xml:base="https://princeemensah.github.io/blog/2024/08/neural-net2.html"><![CDATA[<h2 id="introduction">Introduction</h2>

<p>In a <a href="https://princeemensah.github.io/blog/2024/neural-net/">previous post on binary classification</a>, we explored how to build a neural network from scratch using the MNIST dataset, focusing on distinguishing between two digits. If you followed that guide, you should now be familiar with key concepts such as forward and backward propagation, as well as the use of the sigmoid activation function for binary outputs.</p>

<p>In this tutorial, we’ll expand on that foundation by modifying our neural network to handle multi-class classification. While binary classification involves only two possible outcomes, multi-class classification requires our model to choose from multiple classes—in this case, the digits 0 through 9. To achieve this, we’ll replace the sigmoid activation in the output layer with the softmax function, which will allow our network to output a probability distribution across all classes.</p>

<p>If you’re new to this series, I recommend checking out the <a href="https://princeemensah.github.io/blog/2024/neural-net/">previous tutorial on binary classification</a> to get a solid understanding of the basics before diving into multi-class classification. For those who are already familiar, let’s jump right into extending our neural network to handle multiple classes!</p>

<h2 id="data-preprocessing">Data Preprocessing</h2>

<p>Before we can train our neural network on the MNIST dataset, we need to put the data into the right format. That means flattening each 28 × 28 image into a 784-element vector, scaling the pixel values from the 0 to 255 range down to [0, 1], and one-hot encoding the labels so that each label becomes a 10-element vector with a single 1 at the index of its digit.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
</pre></td><td class="rouge-code"><pre><span class="kn">import</span> <span class="n">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="n">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>
<span class="kn">from</span> <span class="n">sklearn.preprocessing</span> <span class="kn">import</span> <span class="n">OneHotEncoder</span>

<span class="k">def</span> <span class="nf">pre_process_data</span><span class="p">(</span><span class="n">train_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">,</span> <span class="n">test_x</span><span class="p">,</span> <span class="n">test_y</span><span class="p">):</span>
    <span class="c1"># Flatten each 28x28 image into a 784-element vector and scale pixels to [0, 1]
</span>    <span class="n">train_x</span> <span class="o">=</span> <span class="n">train_x</span><span class="p">.</span><span class="nf">reshape</span><span class="p">(</span><span class="n">train_x</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mf">255.</span>
    <span class="n">test_x</span> <span class="o">=</span> <span class="n">test_x</span><span class="p">.</span><span class="nf">reshape</span><span class="p">(</span><span class="n">test_x</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mf">255.</span>

    <span class="c1"># One-hot encode the labels (scikit-learn 1.2 renamed sparse to sparse_output)
</span>    <span class="n">enc</span> <span class="o">=</span> <span class="nc">OneHotEncoder</span><span class="p">(</span><span class="n">sparse_output</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">categories</span><span class="o">=</span><span class="sh">'</span><span class="s">auto</span><span class="sh">'</span><span class="p">)</span>
    <span class="n">train_y</span> <span class="o">=</span> <span class="n">enc</span><span class="p">.</span><span class="nf">fit_transform</span><span class="p">(</span><span class="n">train_y</span><span class="p">.</span><span class="nf">reshape</span><span class="p">(</span><span class="nf">len</span><span class="p">(</span><span class="n">train_y</span><span class="p">),</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span>
    <span class="n">test_y</span> <span class="o">=</span> <span class="n">enc</span><span class="p">.</span><span class="nf">transform</span><span class="p">(</span><span class="n">test_y</span><span class="p">.</span><span class="nf">reshape</span><span class="p">(</span><span class="nf">len</span><span class="p">(</span><span class="n">test_y</span><span class="p">),</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span>

    <span class="k">return</span> <span class="n">train_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">,</span> <span class="n">test_x</span><span class="p">,</span> <span class="n">test_y</span>
</pre></td></tr></tbody></table></code></pre></div></div>
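<p>If you’d rather not depend on scikit-learn just for the labels, the same one-hot encoding can be sketched with NumPy alone: indexing the 10 × 10 identity matrix with the integer labels picks out one row per label. This is an alternative to the <code class="language-plaintext highlighter-rouge">OneHotEncoder</code> call above, not what the post uses, and the helper name <code class="language-plaintext highlighter-rouge">one_hot</code> is hypothetical. It assumes the labels are integers from 0 to 9, as they are in MNIST.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

def one_hot(labels, num_classes=10):
    # Hypothetical helper: row i of the identity matrix is the one-hot vector
    # for class i, so indexing with the label array gives one row per example.
    return np.eye(num_classes)[labels]

labels = np.array([3, 0, 9])
encoded = one_hot(labels)
print(encoded.shape)           # (3, 10)
print(encoded.argmax(axis=1))  # [3 0 9]
</code></pre></div></div>

<p>On the MNIST labels this should give the same matrix as the <code class="language-plaintext highlighter-rouge">OneHotEncoder</code> version, because <code class="language-plaintext highlighter-rouge">categories='auto'</code> sorts the categories it finds and the training labels contain all ten digits.</p>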
<p><strong>Checking the Data Shape</strong></p>

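<p>Next, we load MNIST through Keras, run it through <code class="language-plaintext highlighter-rouge">pre_process_data</code>, and print the shapes of the preprocessed training and test inputs to confirm the preprocessing worked. You should see <code class="language-plaintext highlighter-rouge">(60000, 784)</code> for <code class="language-plaintext highlighter-rouge">train_x</code> and <code class="language-plaintext highlighter-rouge">(10000, 784)</code> for <code class="language-plaintext highlighter-rouge">test_x</code>: 60,000 training images and 10,000 test images, each flattened to 28 × 28 = 784 pixel values.</p>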

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
</pre></td><td class="rouge-code"><pre><span class="p">(</span><span class="n">train_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">),</span> <span class="p">(</span><span class="n">test_x</span><span class="p">,</span> <span class="n">test_y</span><span class="p">)</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">datasets</span><span class="p">.</span><span class="n">mnist</span><span class="p">.</span><span class="nf">load_data</span><span class="p">()</span>
<span class="n">train_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">,</span> <span class="n">test_x</span><span class="p">,</span> <span class="n">test_y</span> <span class="o">=</span> <span class="nf">pre_process_data</span><span class="p">(</span><span class="n">train_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">,</span> <span class="n">test_x</span><span class="p">,</span> <span class="n">test_y</span><span class="p">)</span>

<span class="nf">print</span><span class="p">(</span><span class="sh">"</span><span class="s">train_x</span><span class="sh">'</span><span class="s">s shape: </span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">train_x</span><span class="p">.</span><span class="n">shape</span><span class="p">))</span>
<span class="nf">print</span><span class="p">(</span><span class="sh">"</span><span class="s">test_x</span><span class="sh">'</span><span class="s">s shape: </span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">test_x</span><span class="p">.</span><span class="n">shape</span><span class="p">))</span>
</pre></td></tr></tbody></table></code></pre></div></div>
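<p>If you want to check the flatten-and-normalize step without downloading MNIST, one quick sanity check is to apply the same transformation to random 28 × 28 <code class="language-plaintext highlighter-rouge">uint8</code> arrays standing in for real images. The fake data below is made up purely for this check.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

rng = np.random.default_rng(0)
# Fake "images": 5 random 28x28 arrays using MNIST's uint8 pixel range (0 to 255).
fake_images = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)

# Same transformation as pre_process_data: flatten each image, then scale to [0, 1].
flat = fake_images.reshape(fake_images.shape[0], -1) / 255.

print(flat.shape)                                # (5, 784)
print(flat.dtype)                                # float64
print(np.array_equal(flat, np.clip(flat, 0.0, 1.0)))  # True: every value lies in [0, 1]
</code></pre></div></div>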

<h2 id="defining-the-neural-network">Defining the Neural Network</h2>

<p>With our data preprocessed and ready, the next step is to define the architecture of our neural network. We’ll do this by creating a <code class="language-plaintext highlighter-rouge">NeuralNetwork</code> class that will handle everything from parameter initialization to training and prediction.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
</pre></td><td class="rouge-code"><pre><span class="k">class</span> <span class="nc">NeuralNetwork</span><span class="p">:</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">layers_size</span><span class="p">):</span>
        <span class="n">self</span><span class="p">.</span><span class="n">layers_size</span> <span class="o">=</span> <span class="n">layers_size</span>
        <span class="n">self</span><span class="p">.</span><span class="n">parameters</span> <span class="o">=</span> <span class="p">{}</span>
        <span class="n">self</span><span class="p">.</span><span class="n">length</span> <span class="o">=</span> <span class="nf">len</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">layers_size</span><span class="p">)</span>
        <span class="n">self</span><span class="p">.</span><span class="n">n</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="n">self</span><span class="p">.</span><span class="n">costs</span> <span class="o">=</span> <span class="p">[]</span>
</pre></td></tr></tbody></table></code></pre></div></div>

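<p>The constructor only records the network’s configuration. <code class="language-plaintext highlighter-rouge">layers_size</code> lists the number of units in each hidden and output layer (the <code class="language-plaintext highlighter-rouge">fit</code> method later prepends the input size), <code class="language-plaintext highlighter-rouge">parameters</code> will hold the weights and biases (W1, b1, W2, b2, and so on), <code class="language-plaintext highlighter-rouge">length</code> is the number of layers that have weights, <code class="language-plaintext highlighter-rouge">n</code> will store the number of training examples, and <code class="language-plaintext highlighter-rouge">costs</code> is a list for tracking the cost during training. Forward propagation, backpropagation, and the parameter updates all build on these attributes.</p>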
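<p>The <code class="language-plaintext highlighter-rouge">fit</code> method calls <code class="language-plaintext highlighter-rouge">self.initialize_parameters()</code>, whose body isn’t shown in this section. Below is a minimal sketch of what such a method might look like, written as a standalone function so it runs on its own. It assumes Gaussian weights scaled by 1/√(fan-in) and zero biases; that scaling is an assumption on my part, not necessarily what the original implementation uses.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

def initialize_parameters(layers_size, seed=1):
    # Hypothetical standalone version of NeuralNetwork.initialize_parameters.
    # layers_size starts with the input size, e.g. [784, 50, 10], matching
    # the list after fit() prepends X.shape[1].
    rng = np.random.default_rng(seed)
    parameters = {}
    for l in range(1, len(layers_size)):
        fan_in = layers_size[l - 1]
        # W_l maps the previous layer's activations (one column per example)
        # to layer l, so its shape is (units in layer l, units in layer l-1).
        parameters["W" + str(l)] = rng.standard_normal((layers_size[l], fan_in)) / np.sqrt(fan_in)
        # One bias per unit, broadcast across all examples (columns).
        parameters["b" + str(l)] = np.zeros((layers_size[l], 1))
    return parameters

params = initialize_parameters([784, 50, 10])
for name in sorted(params):
    print(name, params[name].shape)
# W1 (50, 784)
# W2 (10, 50)
# b1 (50, 1)
# b2 (10, 1)
</code></pre></div></div>

<p>Inside the class, the same loop would read <code class="language-plaintext highlighter-rouge">self.layers_size</code> and write into <code class="language-plaintext highlighter-rouge">self.parameters</code>; these shapes are what <code class="language-plaintext highlighter-rouge">W.dot(A) + b</code> in the forward pass expects.</p>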

<h2 id="activation-functions">Activation Functions</h2>

<p>Activation functions are essential because they introduce non-linearity into the model, allowing it to learn complex patterns in the data. Here we use two of them: <code class="language-plaintext highlighter-rouge">sigmoid</code> for the hidden layers and <code class="language-plaintext highlighter-rouge">softmax</code> for the output layer. The functions below are methods of the <code class="language-plaintext highlighter-rouge">NeuralNetwork</code> class, which is why they take <code class="language-plaintext highlighter-rouge">self</code> as their first argument; <code class="language-plaintext highlighter-rouge">sigmoid_derivative</code> is used later during backpropagation.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">sigmoid</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">Z</span><span class="p">):</span>
    <span class="k">return</span> <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="nf">exp</span><span class="p">(</span><span class="o">-</span><span class="n">Z</span><span class="p">))</span>

<span class="k">def</span> <span class="nf">sigmoid_derivative</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">Z</span><span class="p">):</span>
    <span class="n">s</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="nf">exp</span><span class="p">(</span><span class="o">-</span><span class="n">Z</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">s</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">s</span><span class="p">)</span>
    
<span class="k">def</span> <span class="nf">softmax</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">Z</span><span class="p">):</span>
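    <span class="n">expZ</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">exp</span><span class="p">(</span><span class="n">Z</span> <span class="o">-</span> <span class="n">np</span><span class="p">.</span><span class="nf">max</span><span class="p">(</span><span class="n">Z</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">))</span>  <span class="c1"># subtract each column's max for numerical stability</span>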
    <span class="k">return</span> <span class="n">expZ</span> <span class="o">/</span> <span class="n">expZ</span><span class="p">.</span><span class="nf">sum</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

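<p>The <strong>softmax function</strong> turns the output layer’s raw scores into values that are non-negative and sum to 1 for each example, so they can be read as probabilities over the 10 classes of the MNIST dataset. For one example’s column of scores z, softmax(z)<sub>i</sub> = exp(z<sub>i</sub>) / Σ<sub>j</sub> exp(z<sub>j</sub>). Subtracting the largest score before exponentiating leaves the result unchanged, because the common factor cancels between numerator and denominator, but it keeps <code class="language-plaintext highlighter-rouge">np.exp</code> from overflowing on large scores.</p>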

<div class="row mt-3 justify-content-center">
    <div class="col-sm-8 mt-3 mt-md-0">
        <picture>
            <img src="/images/blog/softmax.png" width="100%" alt="Softmax activation function plot" />
        </picture>
    </div>
</div>
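<p>Because each column of Z holds one example’s scores, softmax normalizes along axis 0. Subtracting a maximum before exponentiating is what keeps <code class="language-plaintext highlighter-rouge">np.exp</code> from overflowing; the remaining question is which maximum. The standalone sketch below (the two function names are hypothetical, not part of the class) compares subtracting a single global maximum with subtracting each column’s own maximum, on made-up scores where one example’s scores are far below another’s.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

def softmax_global_max(Z):
    # Shift every entry by one scalar: the largest value anywhere in Z.
    expZ = np.exp(Z - np.max(Z))
    return expZ / expZ.sum(axis=0, keepdims=True)

def softmax_column_max(Z):
    # Shift each column (example) by its own maximum, so every column contains
    # an exp(0) = 1 term and its sum can never be zero.
    expZ = np.exp(Z - np.max(Z, axis=0, keepdims=True))
    return expZ / expZ.sum(axis=0, keepdims=True)

# Three classes (rows), two examples (columns); column 2's scores are far below column 1's.
Z = np.array([[1000.0,  0.0],
              [ 999.0, -1.0],
              [ 998.0, -2.0]])

with np.errstate(invalid="ignore"):
    print(softmax_global_max(Z)[:, 1])   # [nan nan nan]: every exp underflows to 0, then 0/0

P = softmax_column_max(Z)
print(P.sum(axis=0))                     # [1. 1.]
print(np.allclose(P[:, 0], P[:, 1]))     # True: both columns have the same relative scores
</code></pre></div></div>

<p>Both versions are mathematically equal to the plain formula; the per-column shift is simply the one that stays finite when examples in the same batch have very different score ranges.</p>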

<h2 id="forward-pass">Forward Pass</h2>

<p>With our activation functions defined, we can now implement forward propagation, where the input data passes through the network layer by layer to produce the final output. At each layer we compute the weighted sum Z = W · A + b, apply the activation function (sigmoid for the hidden layers, softmax for the output layer), and cache Z, A, and W in <code class="language-plaintext highlighter-rouge">save</code> because backpropagation needs them.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">X</span><span class="p">):</span>
    <span class="n">save</span> <span class="o">=</span> <span class="p">{}</span>
    <span class="n">A</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">T</span>  <span class="c1"># X is already flattened; transpose to (features, examples) so each column is one example
</span>    <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
        <span class="n">Z</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)].</span><span class="nf">dot</span><span class="p">(</span><span class="n">A</span><span class="p">)</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">b</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span>
        <span class="n">A</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">sigmoid</span><span class="p">(</span><span class="n">Z</span><span class="p">)</span>
        <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span> <span class="o">=</span> <span class="n">A</span>
        <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span>
        <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">Z</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span> <span class="o">=</span> <span class="n">Z</span>

    <span class="n">Z</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)].</span><span class="nf">dot</span><span class="p">(</span><span class="n">A</span><span class="p">)</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">b</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span>
    <span class="n">A</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">softmax</span><span class="p">(</span><span class="n">Z</span><span class="p">)</span>
    <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span> <span class="o">=</span> <span class="n">A</span>
    <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span>
    <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">Z</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span> <span class="o">=</span> <span class="n">Z</span>

    <span class="k">return</span> <span class="n">A</span><span class="p">,</span> <span class="n">save</span>
</pre></td></tr></tbody></table></code></pre></div></div>
<p>By passing the input through each layer, the network turns raw pixel values into its final output: a matrix with one column per example, where each column is a probability distribution over the 10 digit classes.</p>
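<p>To see the shapes flowing through the forward pass, here is a standalone sketch of the same computation for a tiny, made-up network (4 inputs, 5 hidden units, 3 classes, 2 examples) with random parameters. It follows the method’s conventions: the input is transposed so columns are examples, the hidden layer uses sigmoid, and the output layer uses softmax. The sizes and the 0.1 weight scale are arbitrary choices for illustration.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

def sigmoid(Z):
    return 1 / (1 + np.exp(-Z))

def softmax(Z):
    expZ = np.exp(Z - np.max(Z, axis=0, keepdims=True))
    return expZ / expZ.sum(axis=0, keepdims=True)

rng = np.random.default_rng(0)
sizes = [4, 5, 3]                       # hypothetical: input, hidden, output units
params = {}
for l in range(1, len(sizes)):
    params["W" + str(l)] = rng.standard_normal((sizes[l], sizes[l - 1])) * 0.1
    params["b" + str(l)] = np.zeros((sizes[l], 1))

X = rng.standard_normal((2, 4))         # 2 examples in rows, like train_x
A = X.T                                 # one column per example
print("A0", A.shape)                    # A0 (4, 2)
A = sigmoid(params["W1"].dot(A) + params["b1"])   # hidden layer
print("A1", A.shape)                    # A1 (5, 2)
A = softmax(params["W2"].dot(A) + params["b2"])   # output layer
print("A2", A.shape)                    # A2 (3, 2)
print(A.sum(axis=0))                    # [1. 1.]: each column is a probability distribution
</code></pre></div></div>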

<h2 id="backward-pass">Backward Pass</h2>

<p>After the forward pass has produced the network’s output, the next step is the backward pass (backpropagation). Here we compute the gradient of the cost with respect to every parameter (weights and biases); the training loop then uses these gradients to update the parameters and reduce the prediction error.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">save</span><span class="p">):</span>
    
    <span class="n">gradients</span> <span class="o">=</span> <span class="p">{}</span>
    
    <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A0</span><span class="sh">"</span><span class="p">]</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">T</span>
    
    <span class="n">A</span> <span class="o">=</span> <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span>
    <span class="n">dZ</span> <span class="o">=</span> <span class="n">A</span> <span class="o">-</span> <span class="n">Y</span><span class="p">.</span><span class="n">T</span>
    
    <span class="n">dW</span> <span class="o">=</span> <span class="n">dZ</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)].</span><span class="n">T</span><span class="p">)</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span>
    <span class="n">db</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">sum</span><span class="p">(</span><span class="n">dZ</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span>
    <span class="n">dAPrev</span> <span class="o">=</span> <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)].</span><span class="n">T</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">dZ</span><span class="p">)</span>
    
    <span class="n">gradients</span><span class="p">[</span><span class="sh">"</span><span class="s">dW</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span> <span class="o">=</span> <span class="n">dW</span>
    <span class="n">gradients</span><span class="p">[</span><span class="sh">"</span><span class="s">db</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span> <span class="o">=</span> <span class="n">db</span>
    
    <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">):</span>
        <span class="n">dZ</span> <span class="o">=</span> <span class="n">dAPrev</span> <span class="o">*</span> <span class="n">self</span><span class="p">.</span><span class="nf">sigmoid_derivative</span><span class="p">(</span><span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">Z</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)])</span>
        <span class="n">dW</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span> <span class="o">*</span> <span class="n">dZ</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)].</span><span class="n">T</span><span class="p">)</span>
        <span class="n">db</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="nf">sum</span><span class="p">(</span><span class="n">dZ</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">layer</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
            <span class="n">dAPrev</span> <span class="o">=</span> <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)].</span><span class="n">T</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">dZ</span><span class="p">)</span>
    
        <span class="n">gradients</span><span class="p">[</span><span class="sh">"</span><span class="s">dW</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">dW</span>
        <span class="n">gradients</span><span class="p">[</span><span class="sh">"</span><span class="s">db</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">db</span>
    
    <span class="k">return</span> <span class="n">gradients</span>
</pre></td></tr></tbody></table></code></pre></div></div>

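<p>Backpropagation is the core mechanism that lets a neural network learn from data: by measuring how much each weight and bias contributes to the overall error, it tells the training loop how to adjust those parameters to reduce that error. The output layer is the simplest case. For softmax combined with the cross-entropy cost, the gradient of the cost with respect to the output layer’s Z simplifies to A − Y, which is why <code class="language-plaintext highlighter-rouge">backward</code> starts with <code class="language-plaintext highlighter-rouge">dZ = A - Y.T</code> (Y is transposed to match A’s classes-by-examples layout). For each hidden layer, the incoming gradient <code class="language-plaintext highlighter-rouge">dAPrev</code> is multiplied element-wise by the sigmoid derivative before that layer’s dW and db are computed.</p>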
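<p>A finite-difference gradient check is a cheap way to gain confidence in a backward pass, and in particular in the <code class="language-plaintext highlighter-rouge">dZ = A - Y.T</code> shortcut for the output layer. The sketch below uses a single softmax layer (no hidden layers) on random, made-up data. It computes dW analytically the way <code class="language-plaintext highlighter-rouge">backward</code> does and compares it with central differences of the cross-entropy cost averaged over n examples. The sizes, the helper names, and <code class="language-plaintext highlighter-rouge">eps</code> are assumptions for illustration.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

def softmax(Z):
    expZ = np.exp(Z - np.max(Z, axis=0, keepdims=True))
    return expZ / expZ.sum(axis=0, keepdims=True)

def cost(W, b, X, Y):
    # Cross-entropy averaged over the n examples (rows of X and Y).
    A = softmax(W.dot(X.T) + b)
    return -np.sum(Y * np.log(A.T)) / X.shape[0]

rng = np.random.default_rng(0)
n, d, k = 6, 4, 3                            # examples, features, classes (made up)
X = rng.standard_normal((n, d))
Y = np.eye(k)[rng.integers(0, k, size=n)]    # one-hot labels, shape (n, k)
W = rng.standard_normal((k, d)) * 0.1
b = np.zeros((k, 1))

# Analytic gradient, as in backward(): dZ = A - Y.T, dW = dZ.dot(A_prev.T) / n,
# where A_prev = X.T here, so A_prev.T is just X.
A = softmax(W.dot(X.T) + b)
dW = (A - Y.T).dot(X) / n

# Numerical gradient via central differences, one weight at a time.
eps = 1e-6
num_dW = np.zeros_like(W)
for i in range(k):
    for j in range(d):
        W_plus, W_minus = W.copy(), W.copy()
        W_plus[i, j] += eps
        W_minus[i, j] -= eps
        num_dW[i, j] = (cost(W_plus, b, X, Y) - cost(W_minus, b, X, Y)) / (2 * eps)

print(np.max(np.abs(dW - num_dW)))           # should be very close to zero
</code></pre></div></div>

<p>If the difference is not close to zero, either the analytic gradient or the cost it is checked against has a bug. The same loop can be pointed at the biases or at hidden-layer weights, but it costs two forward passes per parameter, so it is best run on tiny networks rather than on full MNIST.</p>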

<h2 id="training-the-neural-network">Training the Neural Network</h2>

<p>Once we’ve set up the forward and backward propagation methods, the next step is to train the neural network. Training involves repeatedly passing the training data through the network, calculating the error, and then adjusting the network’s parameters to reduce this error.</p>
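<p>Here, “calculating the error” means computing the cross-entropy cost: for each example, take the negative log of the probability the network assigns to the correct class, then average over the n examples. Because each row of Y is one-hot, multiplying Y element-wise by the log-probabilities and summing picks out exactly those terms. The small worked example below uses made-up probabilities <code class="language-plaintext highlighter-rouge">P</code> laid out like <code class="language-plaintext highlighter-rouge">A.T</code> in <code class="language-plaintext highlighter-rouge">fit</code>, with one row per example.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

# Two examples, three classes (made-up numbers). Rows are examples, as in Y.
Y = np.array([[1, 0, 0],
              [0, 0, 1]])
P = np.array([[0.50, 0.30, 0.20],    # correct class gets probability 0.50
              [0.25, 0.50, 0.25]])   # correct class gets probability 0.25

n = Y.shape[0]
cost = -np.sum(Y * np.log(P)) / n
print(cost)                               # (ln 2 + ln 4) / 2 = 1.5 ln 2, about 1.0397
print(np.isclose(cost, 1.5 * np.log(2)))  # True
</code></pre></div></div>

<p>Averaging with <code class="language-plaintext highlighter-rouge">np.mean</code> over every entry of the (n, 10) array would divide by 10n rather than n, shrinking the reported cost by a factor of 10 without changing the gradients. The <code class="language-plaintext highlighter-rouge">1e-8</code> added inside the log in <code class="language-plaintext highlighter-rouge">fit</code> guards against taking log(0).</p>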

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">fit</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">n_iterations</span><span class="o">=</span><span class="mi">2500</span><span class="p">):</span>
    <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">seed</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    
    <span class="n">self</span><span class="p">.</span><span class="n">n</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    
    <span class="n">self</span><span class="p">.</span><span class="n">layers_size</span><span class="p">.</span><span class="nf">insert</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
    
    <span class="n">self</span><span class="p">.</span><span class="nf">initialize_parameters</span><span class="p">()</span>
    <span class="k">for</span> <span class="n">loop</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">n_iterations</span><span class="p">):</span>
        <span class="n">A</span><span class="p">,</span> <span class="n">save</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">forward</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
        <span class="n">cost</span> <span class="o">=</span> <span class="o">-</span><span class="n">np</span><span class="p">.</span><span class="nf">sum</span><span class="p">(</span><span class="n">Y</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="nf">log</span><span class="p">(</span><span class="n">A</span><span class="p">.</span><span class="n">T</span> <span class="o">+</span> <span class="mf">1e-8</span><span class="p">))</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span>  <span class="c1"># cross-entropy averaged over examples, matching the 1/n in backward</span>
        <span class="n">gradients</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">backward</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">save</span><span class="p">)</span>
    
        <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">self</span><span class="p">.</span><span class="n">length</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span>
            <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">-</span> <span class="n">learning_rate</span> <span class="o">*</span> <span class="n">gradients</span><span class="p">[</span><span class="sh">"</span><span class="s">dW</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span>
            <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">b</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">b</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">-</span> <span class="n">learning_rate</span> <span class="o">*</span> <span class="n">gradients</span><span class="p">[</span><span class="sh">"</span><span class="s">db</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span>
    
        <span class="k">if</span> <span class="n">loop</span> <span class="o">%</span> <span class="mi">10</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
            <span class="nf">print</span><span class="p">(</span><span class="sh">"</span><span class="s">Cost: </span><span class="sh">"</span><span class="p">,</span> <span class="n">cost</span><span class="p">,</span> <span class="sh">"</span><span class="s">Train Accuracy:</span><span class="sh">"</span><span class="p">,</span> <span class="n">self</span><span class="p">.</span><span class="nf">predict</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">))</span>
    
        <span class="c1"># Record the cost at every iteration for plot_cost
</span>        <span class="n">self</span><span class="p">.</span><span class="n">costs</span><span class="p">.</span><span class="nf">append</span><span class="p">(</span><span class="n">cost</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>By repeating the forward pass, backward pass, and parameter update over many iterations, the network gradually reduces the cross-entropy cost, improving its ability to make accurate predictions.</p>
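<p>To see the repeated forward pass, backward pass, and update in isolation, here is a minimal sketch of the same loop for a single-layer softmax classifier trained on synthetic data. The toy data, the helper names (<code class="language-plaintext highlighter-rouge">train_softmax_regression</code>, <code class="language-plaintext highlighter-rouge">cross_entropy</code>), and the hyperparameters are illustrative assumptions, not part of the original code. The cost sums over classes and averages over samples. Its gradient with respect to the logits is <code class="language-plaintext highlighter-rouge">(A - Y.T) / n</code>, which matches the <code class="language-plaintext highlighter-rouge">dZ</code> used in <code class="language-plaintext highlighter-rouge">backward</code>.</p>

```python
import numpy as np


def softmax(Z):
    # Subtract each column's max before exponentiating (columns are samples)
    expZ = np.exp(Z - Z.max(axis=0, keepdims=True))
    return expZ / expZ.sum(axis=0, keepdims=True)


def cross_entropy(A, Y):
    # A: (K, n) predicted probabilities; Y: (n, K) one-hot labels.
    # Sum over classes, then average over samples.
    n = Y.shape[0]
    return -np.sum(Y * np.log(A.T + 1e-8)) / n


def train_softmax_regression(X, Y, learning_rate=0.1, n_iterations=200, seed=0):
    """Full-batch gradient descent for a single-layer softmax classifier (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    K = Y.shape[1]
    W = rng.standard_normal((K, d)) / np.sqrt(d)
    b = np.zeros((K, 1))
    costs = []
    for _ in range(n_iterations):
        A = softmax(W @ X.T + b)        # forward pass: (K, n) probabilities
        costs.append(cross_entropy(A, Y))
        dZ = A - Y.T                    # n times the gradient of the cost w.r.t. Z
        W -= learning_rate * (dZ @ X) / n
        b -= learning_rate * dZ.sum(axis=1, keepdims=True) / n
    return W, b, costs


# Hypothetical toy data: three Gaussian blobs in 2-D, 50 points per class
rng = np.random.default_rng(42)
centers = np.array([[0.0, 3.0], [3.0, 0.0], [-3.0, -3.0]])
labels = np.repeat(np.arange(3), 50)
X = centers[labels] + rng.standard_normal((150, 2))
Y = np.eye(3)[labels]                   # one-hot labels, shape (150, 3)

W, b, costs = train_softmax_regression(X, Y)
print(f"first cost: {costs[0]:.3f}, last cost: {costs[-1]:.3f}")
```

<p>On this convex toy problem the recorded cost falls steadily over the iterations. This is the same kind of curve that <code class="language-plaintext highlighter-rouge">plot_cost</code> draws from <code class="language-plaintext highlighter-rouge">self.costs</code>.</p>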

<h2 id="evaluating-the-model">Evaluating the Model</h2>

<p>After training the neural network, the next step is to evaluate its performance on both the training and test datasets. Let’s implement two methods: the <code class="language-plaintext highlighter-rouge">predict</code> method, which makes predictions and returns the model’s accuracy as a percentage, and the <code class="language-plaintext highlighter-rouge">plot_cost</code> method, which plots the cost recorded at each training iteration.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">):</span>
    <span class="n">A</span><span class="p">,</span> <span class="n">cache</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">forward</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
    <span class="n">y_hat</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">argmax</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
    <span class="n">Y</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">argmax</span><span class="p">(</span><span class="n">Y</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">accuracy</span> <span class="o">=</span> <span class="p">(</span><span class="n">y_hat</span> <span class="o">==</span> <span class="n">Y</span><span class="p">).</span><span class="nf">mean</span><span class="p">()</span>
    <span class="k">return</span> <span class="n">accuracy</span> <span class="o">*</span> <span class="mi">100</span>

<span class="k">def</span> <span class="nf">plot_cost</span><span class="p">(</span><span class="n">self</span><span class="p">):</span>
    <span class="n">plt</span><span class="p">.</span><span class="nf">figure</span><span class="p">()</span>
    <span class="n">plt</span><span class="p">.</span><span class="nf">plot</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nf">arange</span><span class="p">(</span><span class="nf">len</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">costs</span><span class="p">)),</span> <span class="n">self</span><span class="p">.</span><span class="n">costs</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="nf">xlabel</span><span class="p">(</span><span class="sh">"</span><span class="s">epochs</span><span class="sh">"</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="nf">ylabel</span><span class="p">(</span><span class="sh">"</span><span class="s">cost</span><span class="sh">"</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="nf">show</span><span class="p">()</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>By calculating the accuracy of the model on the training and test datasets, we can assess how well the network has learned and how effectively it can generalize to new data.</p>
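<p>The shape bookkeeping in <code class="language-plaintext highlighter-rouge">predict</code> is easy to get wrong. The softmax output <code class="language-plaintext highlighter-rouge">A</code> has one column per sample, while the one-hot labels have one row per sample, so the two <code class="language-plaintext highlighter-rouge">argmax</code> calls use different axes. The standalone sketch below isolates that logic. The function <code class="language-plaintext highlighter-rouge">accuracy_percent</code> and the small hand-written arrays are illustrative assumptions, not part of the original class. The same function applies unchanged to either the training split or the test split.</p>

```python
import numpy as np


def accuracy_percent(A, Y_onehot):
    """Percentage of samples whose most probable class matches the label (hypothetical helper).

    A: (K, n) softmax outputs, one column per sample.
    Y_onehot: (n, K) one-hot labels, one row per sample.
    """
    y_hat = np.argmax(A, axis=0)          # predicted class for each column (sample)
    y_true = np.argmax(Y_onehot, axis=1)  # true class for each row (sample)
    return (y_hat == y_true).mean() * 100


# Four samples, three classes; the last sample is misclassified
A = np.array([[0.7, 0.1, 0.2, 0.3],
              [0.2, 0.8, 0.3, 0.3],
              [0.1, 0.1, 0.5, 0.4]])
Y_onehot = np.eye(3)[[0, 1, 2, 0]]
print(accuracy_percent(A, Y_onehot))  # 75.0
```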

<h2 id="full-code-implementation">Full Code Implementation</h2>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-code"><pre><span class="kn">import</span> <span class="n">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="n">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>  <span class="c1"># Used only to download the dataset
</span><span class="kn">import</span> <span class="n">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">from</span> <span class="n">sklearn.preprocessing</span> <span class="kn">import</span> <span class="n">OneHotEncoder</span>


<span class="k">class</span> <span class="nc">NeuralNetwork</span><span class="p">:</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">layers_size</span><span class="p">):</span>
        <span class="n">self</span><span class="p">.</span><span class="n">layers_size</span> <span class="o">=</span> <span class="n">layers_size</span>
        <span class="n">self</span><span class="p">.</span><span class="n">parameters</span> <span class="o">=</span> <span class="p">{}</span>
        <span class="n">self</span><span class="p">.</span><span class="n">length</span> <span class="o">=</span> <span class="nf">len</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">layers_size</span><span class="p">)</span>
        <span class="n">self</span><span class="p">.</span><span class="n">n</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="n">self</span><span class="p">.</span><span class="n">costs</span> <span class="o">=</span> <span class="p">[]</span>

    <span class="k">def</span> <span class="nf">sigmoid</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">Z</span><span class="p">):</span>
        <span class="k">return</span> <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="nf">exp</span><span class="p">(</span><span class="o">-</span><span class="n">Z</span><span class="p">))</span>
    
    <span class="k">def</span> <span class="nf">softmax</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">Z</span><span class="p">):</span>
        <span class="n">expZ</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">exp</span><span class="p">(</span><span class="n">Z</span> <span class="o">-</span> <span class="n">np</span><span class="p">.</span><span class="nf">max</span><span class="p">(</span><span class="n">Z</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">))</span>  <span class="c1"># subtract each column max for numerical stability
</span>        <span class="k">return</span> <span class="n">expZ</span> <span class="o">/</span> <span class="n">expZ</span><span class="p">.</span><span class="nf">sum</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
    
    <span class="k">def</span> <span class="nf">initialize_parameters</span><span class="p">(</span><span class="n">self</span><span class="p">):</span>
        <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">seed</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    
        <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nf">len</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">layers_size</span><span class="p">)):</span>
            <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">randn</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">layers_size</span><span class="p">[</span><span class="n">layer</span><span class="p">],</span> <span class="n">self</span><span class="p">.</span><span class="n">layers_size</span><span class="p">[</span><span class="n">layer</span> <span class="o">-</span> <span class="mi">1</span><span class="p">])</span> <span class="o">/</span> <span class="n">np</span><span class="p">.</span><span class="nf">sqrt</span><span class="p">(</span>
                <span class="n">self</span><span class="p">.</span><span class="n">layers_size</span><span class="p">[</span><span class="n">layer</span> <span class="o">-</span> <span class="mi">1</span><span class="p">])</span>
            <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">b</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">zeros</span><span class="p">((</span><span class="n">self</span><span class="p">.</span><span class="n">layers_size</span><span class="p">[</span><span class="n">layer</span><span class="p">],</span> <span class="mi">1</span><span class="p">))</span>
    
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">X</span><span class="p">):</span>
        <span class="n">save</span> <span class="o">=</span> <span class="p">{}</span>
        <span class="n">A</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">T</span>  <span class="c1"># X is already flattened, so no further reshaping needed
</span>        <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
            <span class="n">Z</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)].</span><span class="nf">dot</span><span class="p">(</span><span class="n">A</span><span class="p">)</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">b</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span>
            <span class="n">A</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">sigmoid</span><span class="p">(</span><span class="n">Z</span><span class="p">)</span>
            <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span> <span class="o">=</span> <span class="n">A</span>
            <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span>
            <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">Z</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span> <span class="o">=</span> <span class="n">Z</span>

        <span class="n">Z</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)].</span><span class="nf">dot</span><span class="p">(</span><span class="n">A</span><span class="p">)</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">b</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span>
        <span class="n">A</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">softmax</span><span class="p">(</span><span class="n">Z</span><span class="p">)</span>
        <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span> <span class="o">=</span> <span class="n">A</span>
        <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span>
        <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">Z</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span> <span class="o">=</span> <span class="n">Z</span>

        <span class="k">return</span> <span class="n">A</span><span class="p">,</span> <span class="n">save</span>

    
    <span class="k">def</span> <span class="nf">sigmoid_derivative</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">Z</span><span class="p">):</span>
        <span class="n">s</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="nf">exp</span><span class="p">(</span><span class="o">-</span><span class="n">Z</span><span class="p">))</span>
        <span class="k">return</span> <span class="n">s</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">s</span><span class="p">)</span>
    
    <span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">save</span><span class="p">):</span>
    
        <span class="n">gradients</span> <span class="o">=</span> <span class="p">{}</span>
    
        <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A0</span><span class="sh">"</span><span class="p">]</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">T</span>
    
        <span class="n">A</span> <span class="o">=</span> <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span>
        <span class="n">dZ</span> <span class="o">=</span> <span class="n">A</span> <span class="o">-</span> <span class="n">Y</span><span class="p">.</span><span class="n">T</span>
    
        <span class="n">dW</span> <span class="o">=</span> <span class="n">dZ</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)].</span><span class="n">T</span><span class="p">)</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span>
        <span class="n">db</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">sum</span><span class="p">(</span><span class="n">dZ</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span>
        <span class="n">dAPrev</span> <span class="o">=</span> <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)].</span><span class="n">T</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">dZ</span><span class="p">)</span>
    
        <span class="n">gradients</span><span class="p">[</span><span class="sh">"</span><span class="s">dW</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span> <span class="o">=</span> <span class="n">dW</span>
        <span class="n">gradients</span><span class="p">[</span><span class="sh">"</span><span class="s">db</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span> <span class="o">=</span> <span class="n">db</span>
    
        <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">):</span>
            <span class="n">dZ</span> <span class="o">=</span> <span class="n">dAPrev</span> <span class="o">*</span> <span class="n">self</span><span class="p">.</span><span class="nf">sigmoid_derivative</span><span class="p">(</span><span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">Z</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)])</span>
            <span class="n">dW</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span> <span class="o">*</span> <span class="n">dZ</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)].</span><span class="n">T</span><span class="p">)</span>
            <span class="n">db</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="nf">sum</span><span class="p">(</span><span class="n">dZ</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
            <span class="k">if</span> <span class="n">layer</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
                <span class="n">dAPrev</span> <span class="o">=</span> <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)].</span><span class="n">T</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">dZ</span><span class="p">)</span>
    
            <span class="n">gradients</span><span class="p">[</span><span class="sh">"</span><span class="s">dW</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">dW</span>
            <span class="n">gradients</span><span class="p">[</span><span class="sh">"</span><span class="s">db</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">db</span>
    
        <span class="k">return</span> <span class="n">gradients</span>
    
    <span class="k">def</span> <span class="nf">fit</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">n_iterations</span><span class="o">=</span><span class="mi">2500</span><span class="p">):</span>
        <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">seed</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    
        <span class="n">self</span><span class="p">.</span><span class="n">n</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    
        <span class="n">self</span><span class="p">.</span><span class="n">layers_size</span><span class="p">.</span><span class="nf">insert</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
    
        <span class="n">self</span><span class="p">.</span><span class="nf">initialize_parameters</span><span class="p">()</span>
        <span class="k">for</span> <span class="n">loop</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">n_iterations</span><span class="p">):</span>
            <span class="n">A</span><span class="p">,</span> <span class="n">save</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">forward</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
            <span class="n">cost</span> <span class="o">=</span> <span class="o">-</span><span class="n">np</span><span class="p">.</span><span class="nf">mean</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nf">sum</span><span class="p">(</span><span class="n">Y</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="nf">log</span><span class="p">(</span><span class="n">A</span><span class="p">.</span><span class="n">T</span> <span class="o">+</span> <span class="mf">1e-8</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span>
            <span class="n">gradients</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">backward</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">save</span><span class="p">)</span>
    
            <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">self</span><span class="p">.</span><span class="n">length</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span>
                <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">-</span> <span class="n">learning_rate</span> <span class="o">*</span> <span class="n">gradients</span><span class="p">[</span>
                    <span class="sh">"</span><span class="s">dW</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span>
                <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">b</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">b</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">-</span> <span class="n">learning_rate</span> <span class="o">*</span> <span class="n">gradients</span><span class="p">[</span>
                    <span class="sh">"</span><span class="s">db</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span>
    
            <span class="k">if</span> <span class="n">loop</span> <span class="o">%</span> <span class="mi">10</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
                <span class="nf">print</span><span class="p">(</span><span class="sh">"</span><span class="s">Cost: </span><span class="sh">"</span><span class="p">,</span> <span class="n">cost</span><span class="p">,</span> <span class="sh">"</span><span class="s">Train Accuracy:</span><span class="sh">"</span><span class="p">,</span> <span class="n">self</span><span class="p">.</span><span class="nf">predict</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">))</span>
    
            <span class="k">if</span> <span class="n">loop</span> <span class="o">%</span> <span class="mi">1</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
                <span class="n">self</span><span class="p">.</span><span class="n">costs</span><span class="p">.</span><span class="nf">append</span><span class="p">(</span><span class="n">cost</span><span class="p">)</span>
    
    <span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">):</span>
        <span class="n">A</span><span class="p">,</span> <span class="n">cache</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">forward</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
        <span class="n">y_hat</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">argmax</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
        <span class="n">Y</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">argmax</span><span class="p">(</span><span class="n">Y</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">accuracy</span> <span class="o">=</span> <span class="p">(</span><span class="n">y_hat</span> <span class="o">==</span> <span class="n">Y</span><span class="p">).</span><span class="nf">mean</span><span class="p">()</span>
        <span class="k">return</span> <span class="n">accuracy</span> <span class="o">*</span> <span class="mi">100</span>
    
    <span class="k">def</span> <span class="nf">plot_cost</span><span class="p">(</span><span class="n">self</span><span class="p">):</span>
        <span class="n">plt</span><span class="p">.</span><span class="nf">figure</span><span class="p">()</span>
        <span class="n">plt</span><span class="p">.</span><span class="nf">plot</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nf">arange</span><span class="p">(</span><span class="nf">len</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">costs</span><span class="p">)),</span> <span class="n">self</span><span class="p">.</span><span class="n">costs</span><span class="p">)</span>
        <span class="n">plt</span><span class="p">.</span><span class="nf">xlabel</span><span class="p">(</span><span class="sh">"</span><span class="s">epochs</span><span class="sh">"</span><span class="p">)</span>
        <span class="n">plt</span><span class="p">.</span><span class="nf">ylabel</span><span class="p">(</span><span class="sh">"</span><span class="s">cost</span><span class="sh">"</span><span class="p">)</span>
        <span class="n">plt</span><span class="p">.</span><span class="nf">show</span><span class="p">()</span>


<span class="k">def</span> <span class="nf">pre_process_data</span><span class="p">(</span><span class="n">train_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">,</span> <span class="n">test_x</span><span class="p">,</span> <span class="n">test_y</span><span class="p">):</span>
    <span class="c1"># Flatten the input images
</span>    <span class="n">train_x</span> <span class="o">=</span> <span class="n">train_x</span><span class="p">.</span><span class="nf">reshape</span><span class="p">(</span><span class="n">train_x</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mf">255.</span>  <span class="c1"># Flatten and normalize
</span>    <span class="n">test_x</span> <span class="o">=</span> <span class="n">test_x</span><span class="p">.</span><span class="nf">reshape</span><span class="p">(</span><span class="n">test_x</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mf">255.</span>  <span class="c1"># Flatten and normalize
</span>
    <span class="n">enc</span> <span class="o">=</span> <span class="nc">OneHotEncoder</span><span class="p">(</span><span class="n">sparse_output</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">categories</span><span class="o">=</span><span class="sh">'</span><span class="s">auto</span><span class="sh">'</span><span class="p">)</span>  <span class="c1"># scikit-learn 1.2+; older versions use sparse=False</span>
    <span class="n">train_y</span> <span class="o">=</span> <span class="n">enc</span><span class="p">.</span><span class="nf">fit_transform</span><span class="p">(</span><span class="n">train_y</span><span class="p">.</span><span class="nf">reshape</span><span class="p">(</span><span class="nf">len</span><span class="p">(</span><span class="n">train_y</span><span class="p">),</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span>
    <span class="n">test_y</span> <span class="o">=</span> <span class="n">enc</span><span class="p">.</span><span class="nf">transform</span><span class="p">(</span><span class="n">test_y</span><span class="p">.</span><span class="nf">reshape</span><span class="p">(</span><span class="nf">len</span><span class="p">(</span><span class="n">test_y</span><span class="p">),</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span>

    <span class="k">return</span> <span class="n">train_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">,</span> <span class="n">test_x</span><span class="p">,</span> <span class="n">test_y</span>



<span class="k">if</span> <span class="n">__name__</span> <span class="o">==</span> <span class="sh">'</span><span class="s">__main__</span><span class="sh">'</span><span class="p">:</span>
    <span class="p">(</span><span class="n">train_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">),</span> <span class="p">(</span><span class="n">test_x</span><span class="p">,</span> <span class="n">test_y</span><span class="p">)</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">datasets</span><span class="p">.</span><span class="n">mnist</span><span class="p">.</span><span class="nf">load_data</span><span class="p">()</span>

    <span class="n">train_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">,</span> <span class="n">test_x</span><span class="p">,</span> <span class="n">test_y</span> <span class="o">=</span> <span class="nf">pre_process_data</span><span class="p">(</span><span class="n">train_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">,</span> <span class="n">test_x</span><span class="p">,</span> <span class="n">test_y</span><span class="p">)</span>
    
    <span class="nf">print</span><span class="p">(</span><span class="sh">"</span><span class="s">train_x</span><span class="sh">'</span><span class="s">s shape: </span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">train_x</span><span class="p">.</span><span class="n">shape</span><span class="p">))</span>
    <span class="nf">print</span><span class="p">(</span><span class="sh">"</span><span class="s">test_x</span><span class="sh">'</span><span class="s">s shape: </span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">test_x</span><span class="p">.</span><span class="n">shape</span><span class="p">))</span>
    
    <span class="n">dims_of_layer</span> <span class="o">=</span> <span class="p">[</span><span class="mi">50</span><span class="p">,</span> <span class="mi">10</span><span class="p">]</span>
    
    <span class="n">model</span> <span class="o">=</span> <span class="nc">NeuralNetwork</span><span class="p">(</span><span class="n">dims_of_layer</span><span class="p">)</span>
    <span class="n">model</span><span class="p">.</span><span class="nf">fit</span><span class="p">(</span><span class="n">train_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">,</span> <span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">n_iterations</span><span class="o">=</span><span class="mi">100</span><span class="p">)</span>
    <span class="nf">print</span><span class="p">(</span><span class="sh">"</span><span class="s">Train Accuracy:</span><span class="sh">"</span><span class="p">,</span> <span class="n">model</span><span class="p">.</span><span class="nf">predict</span><span class="p">(</span><span class="n">train_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">))</span>
    <span class="nf">print</span><span class="p">(</span><span class="sh">"</span><span class="s">Test Accuracy:</span><span class="sh">"</span><span class="p">,</span> <span class="n">model</span><span class="p">.</span><span class="nf">predict</span><span class="p">(</span><span class="n">test_x</span><span class="p">,</span> <span class="n">test_y</span><span class="p">))</span>
    <span class="n">model</span><span class="p">.</span><span class="nf">plot_cost</span><span class="p">()</span>
</pre></td></tr></tbody></table></code></pre></div></div>
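<p>A note on the cost line in <code class="language-plaintext highlighter-rouge">fit</code>: <code class="language-plaintext highlighter-rouge">np.mean</code> averages over every entry of the (examples × classes) array, so the logged value equals the usual categorical cross-entropy (summed over classes, averaged over examples) divided by the number of classes. The cost is only printed and appended to <code class="language-plaintext highlighter-rouge">self.costs</code>, so this rescales the plotted curve without changing the parameter updates. The sketch below checks that factor on random data. The helper names are hypothetical, and <code class="language-plaintext highlighter-rouge">np.eye(K)[labels]</code> is used here as a NumPy-only stand-in for <code class="language-plaintext highlighter-rouge">OneHotEncoder</code>.</p>

```python
import numpy as np

def mean_over_all_entries(Y, A):
    # Mirrors the cost line in fit(): averages over every (example, class) entry.
    # Y: one-hot targets of shape (m, K); A: softmax outputs of shape (K, m).
    return -np.mean(Y * np.log(A.T + 1e-8))

def categorical_cross_entropy(Y, A):
    # Standard form: sum the log-loss over classes, then average over m examples.
    m = Y.shape[0]
    return -np.sum(Y * np.log(A.T + 1e-8)) / m

rng = np.random.default_rng(0)
m, K = 5, 10
labels = rng.integers(0, K, size=m)
Y = np.eye(K)[labels]                      # one-hot rows, shape (m, K)
logits = rng.normal(size=(K, m))
A = np.exp(logits) / np.exp(logits).sum(axis=0, keepdims=True)  # softmax per column

# The ratio is approximately K (10.0 here).
print(categorical_cross_entropy(Y, A) / mean_over_all_entries(Y, A))
```

<p>Because the two values differ only by the constant factor K, either one decreases whenever the other does: the learning curve has the same shape, and only the scale of its y-axis changes.</p>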

<h2 id="conclusion">Conclusion</h2>

<p>In this post, we built a neural network from scratch to perform multi-class classification on the MNIST dataset. We preprocessed the data, defined the network architecture, and implemented forward and backward propagation. Training the network with gradient descent reduced the cost and improved its ability to classify handwritten digits.</p>

<p>We also implemented methods to measure the model’s accuracy and to plot the cost over training iterations, which gives insight into how the network learns. These foundational concepts prepare you to tackle more complex problems and to refine your models for better accuracy and efficiency. If you have any questions, feel free to leave them in the comment section.</p>

<div class="row mt-3 justify-content-center">
    <div class="col-sm-8 mt-3 mt-md-0">
        <picture>
            <img src="/images/blog/multiclass_loss.png" width="100%" alt="Categorical cross entropy loss plot" />
        </picture>
    </div>
</div>

<div class="row mt-3 justify-content-center">
    <div class="col-sm-8 mt-3 mt-md-0">
        <picture>
            <img src="/images/blog/multiclass_accuracy.png" width="100%" alt="Training and validation accuracy plot" />
        </picture>
    </div>
</div>]]></content><author><name></name></author><category term="neural-network" /><category term="backpropagation" /><category term="multiclass-classification" /><category term="forward-pass" /><category term="softmax" /><summary type="html"><![CDATA[In this post, we implemented a neural network from scratch to perform multi-class classification on the MNIST dataset. We started by preprocessing the data, defining the network architecture, and implementing key components such as forward and backward propagation. By training the network, we minimized the error and improved its ability to classify handwritten digits accurately.]]></summary></entry><entry><title type="html">Implementing Neural Network from scratch-Part 1 (Binary Classification)</title><link href="https://princeemensah.github.io/blog/2024/08/neural-net1.html" rel="alternate" type="text/html" title="Implementing Neural Network from scratch-Part 1 (Binary Classification)" /><published>2024-08-14T09:46:13+00:00</published><updated>2024-08-14T09:46:13+00:00</updated><id>https://princeemensah.github.io/blog/2024/08/neural-net1</id><content type="html" xml:base="https://princeemensah.github.io/blog/2024/08/neural-net1.html"><![CDATA[<h2 id="introduction">Introduction</h2>

<p>Neural networks have become a powerful tool, forming the backbone of modern deep learning and powering applications from computer vision to natural language processing. Although it is much simpler to build and train neural networks with pre-built libraries like PyTorch or TensorFlow, I think it’s important to understand how these models fundamentally work. In this blog post, we will build a very simple neural network from scratch with NumPy and use it to perform binary classification on the MNIST dataset.</p>

<p>We’ll focus on distinguishing between two digits: <code class="language-plaintext highlighter-rouge">1</code> and <code class="language-plaintext highlighter-rouge">2</code>. Before we dive into building the model, let’s start by downloading the MNIST dataset and performing the preprocessing needed to train the model.</p>

<h2 id="data-loading-and-preprocessing">Data Loading and Preprocessing</h2>

<p>We’ll begin by loading the MNIST dataset using TensorFlow, which provides a convenient method to download and load the data. The MNIST dataset is a collection of 70,000 images of handwritten digits, each 28x28 pixels in size. After loading the data, we’ll filter it to only include the classes <code class="language-plaintext highlighter-rouge">1</code> and <code class="language-plaintext highlighter-rouge">2</code>.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
</pre></td><td class="rouge-code"><pre><span class="kn">import</span> <span class="n">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="n">tensorflow</span> <span class="k">as</span> <span class="n">tf</span> <span class="c1"># Used to download the data
</span><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span> <span class="c1"># Reproducibility.
</span></pre></td></tr></tbody></table></code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">dataset</span><span class="p">():</span>
    <span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">datasets</span><span class="p">.</span><span class="n">mnist</span><span class="p">.</span><span class="nf">load_data</span><span class="p">()</span>

    <span class="c1"># Filter training data for classes 1 and 2
</span>    <span class="n">index_1</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">where</span><span class="p">(</span><span class="n">y_train</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">index_2</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">where</span><span class="p">(</span><span class="n">y_train</span> <span class="o">==</span> <span class="mi">2</span><span class="p">)</span>

    <span class="n">index</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">concatenate</span><span class="p">([</span><span class="n">index_1</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">index_2</span><span class="p">[</span><span class="mi">0</span><span class="p">]])</span>
    <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">shuffle</span><span class="p">(</span><span class="n">index</span><span class="p">)</span>

    <span class="n">train_x</span> <span class="o">=</span> <span class="n">x_train</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
    <span class="n">train_y</span> <span class="o">=</span> <span class="n">y_train</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>

    <span class="n">train_y</span><span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="nf">where</span><span class="p">(</span><span class="n">train_y</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)]</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="n">train_y</span><span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="nf">where</span><span class="p">(</span><span class="n">train_y</span> <span class="o">==</span> <span class="mi">2</span><span class="p">)]</span> <span class="o">=</span> <span class="mi">1</span>
    
    <span class="c1"># Filter test data for classes 1 and 2
</span>    <span class="n">index_1</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">where</span><span class="p">(</span><span class="n">y_test</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">index_2</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">where</span><span class="p">(</span><span class="n">y_test</span> <span class="o">==</span> <span class="mi">2</span><span class="p">)</span>

    <span class="n">index</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">concatenate</span><span class="p">([</span><span class="n">index_1</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">index_2</span><span class="p">[</span><span class="mi">0</span><span class="p">]])</span>
    <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">shuffle</span><span class="p">(</span><span class="n">index</span><span class="p">)</span>

    <span class="n">test_y</span> <span class="o">=</span> <span class="n">y_test</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
    <span class="n">test_x</span> <span class="o">=</span> <span class="n">x_test</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>

    <span class="n">test_y</span><span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="nf">where</span><span class="p">(</span><span class="n">test_y</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)]</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="n">test_y</span><span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="nf">where</span><span class="p">(</span><span class="n">test_y</span> <span class="o">==</span> <span class="mi">2</span><span class="p">)]</span> <span class="o">=</span> <span class="mi">1</span>

    <span class="k">return</span> <span class="n">train_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">,</span> <span class="n">test_x</span><span class="p">,</span> <span class="n">test_y</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>In the above code, we loaded the dataset, used NumPy to select only the images labeled <code class="language-plaintext highlighter-rouge">1</code> or <code class="language-plaintext highlighter-rouge">2</code>, and shuffled them. Finally, we relabeled the data so that <code class="language-plaintext highlighter-rouge">1</code> becomes <code class="language-plaintext highlighter-rouge">0</code> and <code class="language-plaintext highlighter-rouge">2</code> becomes <code class="language-plaintext highlighter-rouge">1</code>, making this a binary classification problem.</p>
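<p>The same filtering, shuffling, and relabeling can be written more compactly with a boolean mask. The sketch below assumes the labels are a 1-D NumPy integer array. <code class="language-plaintext highlighter-rouge">filter_binary</code> is a hypothetical helper, and small synthetic arrays stand in for MNIST so it runs without a download. It shuffles with <code class="language-plaintext highlighter-rouge">np.random.default_rng</code>, so the example order will differ from the <code class="language-plaintext highlighter-rouge">np.random.shuffle</code> version above.</p>

```python
import numpy as np

def filter_binary(x, y, neg=1, pos=2, rng=None):
    # Keep only the examples whose label is `neg` or `pos`.
    mask = np.isin(y, [neg, pos])
    x, y = x[mask], y[mask]
    # Shuffle images and labels with the same permutation.
    rng = np.random.default_rng(42) if rng is None else rng
    perm = rng.permutation(len(y))
    x, y = x[perm], y[perm]
    # Relabel: `pos` becomes 1 and `neg` becomes 0.
    return x, (y == pos).astype(np.int64)

# Synthetic stand-in for MNIST: 8 "images" of size 28x28 with labels 0 to 3.
y_demo = np.array([0, 1, 2, 3, 1, 2, 2, 0])
x_demo = np.arange(8 * 28 * 28).reshape(8, 28, 28)
x_bin, y_bin = filter_binary(x_demo, y_demo)
print(x_bin.shape, y_bin)
```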

<h3 id="preprocessing-the-data">Preprocessing the Data</h3>

<p>Next, we normalize the data: MNIST pixel values range from 0 to 255, so dividing by 255 scales them to the range 0 to 1. In addition, since our neural network is fully connected (dense), we need to flatten each 28x28 image into a 784-dimensional vector.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">data_preprocessing</span><span class="p">(</span><span class="n">train_x</span><span class="p">,</span> <span class="n">test_x</span><span class="p">):</span>
    <span class="c1"># Normalize the pixel values to [0, 1]
</span>    <span class="n">train_x</span> <span class="o">=</span> <span class="n">train_x</span> <span class="o">/</span> <span class="mf">255.</span>
    <span class="n">test_x</span> <span class="o">=</span> <span class="n">test_x</span> <span class="o">/</span> <span class="mf">255.</span>

    <span class="c1"># Flatten the images from 28x28 to 784
</span>    <span class="n">train_x</span> <span class="o">=</span> <span class="n">train_x</span><span class="p">.</span><span class="nf">reshape</span><span class="p">(</span><span class="n">train_x</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">test_x</span> <span class="o">=</span> <span class="n">test_x</span><span class="p">.</span><span class="nf">reshape</span><span class="p">(</span><span class="n">test_x</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">train_x</span><span class="p">,</span> <span class="n">test_x</span>
</pre></td></tr></tbody></table></code></pre></div></div>

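<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
</pre></td><td class="rouge-code"><pre><span class="n">train_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">,</span> <span class="n">test_x</span><span class="p">,</span> <span class="n">test_y</span> <span class="o">=</span> <span class="nf">dataset</span><span class="p">()</span>
<span class="n">train_x</span><span class="p">,</span> <span class="n">test_x</span> <span class="o">=</span> <span class="nf">data_preprocessing</span><span class="p">(</span><span class="n">train_x</span><span class="p">,</span> <span class="n">test_x</span><span class="p">)</span>
<span class="nf">print</span><span class="p">(</span><span class="sh">"</span><span class="s">train_x</span><span class="sh">'</span><span class="s">s shape: </span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">train_x</span><span class="p">.</span><span class="n">shape</span><span class="p">))</span>
<span class="nf">print</span><span class="p">(</span><span class="sh">"</span><span class="s">test_x</span><span class="sh">'</span><span class="s">s shape: </span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">test_x</span><span class="p">.</span><span class="n">shape</span><span class="p">))</span>
</pre></td></tr></tbody></table></code></pre></div></div>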

<p>Output</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
</pre></td><td class="rouge-code"><pre>train_x's shape: (12700, 784)
test_x's shape: (2167, 784)
</pre></td></tr></tbody></table></code></pre></div></div>
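<p>Each row of <code class="language-plaintext highlighter-rouge">train_x</code> is one flattened image, so the matrix has shape (examples, 784). The forward pass defined below transposes it (<code class="language-plaintext highlighter-rouge">A = X.T</code>) so that each column is one example. A weight matrix of shape (neurons in the layer, neurons in the previous layer) can then multiply it directly, and the bias column is broadcast across all examples. Here is a quick shape check on random data, assuming a hypothetical hidden layer of 50 neurons:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
m, n_in, n_hidden = 4, 784, 50           # hypothetical sizes for illustration

X = rng.random((m, 28, 28)).reshape(m, -1)             # flattened images, shape (m, 784)
W1 = rng.normal(size=(n_hidden, n_in)) / np.sqrt(n_in)  # shape (50, 784)
b1 = np.zeros((n_hidden, 1))                           # shape (50, 1)

A0 = X.T                   # shape (784, m): one column per example
Z1 = W1.dot(A0) + b1       # (50, 784) times (784, m) gives (50, m); b1 broadcasts over columns
print(X.shape, A0.shape, Z1.shape)  # (4, 784) (784, 4) (50, 4)
```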

<h2 id="implementing-the-neural-network">Implementing The Neural Network</h2>

<p>Now, let’s dive into the core of this project, starting with initializing the network and then moving through the forward pass, backward pass, training, and prediction phases.</p>

<h3 id="initializing-the-neural-network">Initializing the Neural Network</h3>

<p>The first step in building our neural network is to define its structure and initialize some key components. This is done in the <code class="language-plaintext highlighter-rouge">__init__</code> method of the neural network class.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
</pre></td><td class="rouge-code"><pre><span class="k">class</span> <span class="nc">NeuralNet</span><span class="p">:</span>
  <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">size_of_layers</span><span class="p">):</span>
    <span class="n">self</span><span class="p">.</span><span class="n">size_of_layers</span> <span class="o">=</span> <span class="n">size_of_layers</span>
    <span class="n">self</span><span class="p">.</span><span class="n">parameters</span> <span class="o">=</span> <span class="p">{}</span>
    <span class="n">self</span><span class="p">.</span><span class="n">length</span> <span class="o">=</span> <span class="nf">len</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">size_of_layers</span><span class="p">)</span> <span class="c1"># number of layers
</span>    <span class="n">self</span><span class="p">.</span><span class="n">n</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># number of training examples
</span>    <span class="n">self</span><span class="p">.</span><span class="n">costs</span> <span class="o">=</span> <span class="p">[]</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>With this initialization, we’ve set up the basic structure of our neural network. In the next steps, we’ll define how the network initializes its weights, performs forward passes, and updates its parameters during training.</p>

<h3 id="initializing-the-network-parameters">Initializing the Network Parameters</h3>

<p>Once we have defined the structure of our neural network, the next step is to initialize its parameters, specifically the weights and biases for each layer.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">initialize_parameters</span><span class="p">(</span><span class="n">self</span><span class="p">):</span>
  <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
  <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nf">len</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">size_of_layers</span><span class="p">)):</span>
    <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">randn</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">size_of_layers</span><span class="p">[</span><span class="n">layer</span><span class="p">],</span> <span class="n">self</span><span class="p">.</span><span class="n">size_of_layers</span><span class="p">[</span><span class="n">layer</span> <span class="o">-</span> <span class="mi">1</span><span class="p">])</span> <span class="o">/</span> <span class="n">np</span><span class="p">.</span><span class="nf">sqrt</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">size_of_layers</span><span class="p">[</span><span class="n">layer</span> <span class="o">-</span> <span class="mi">1</span><span class="p">])</span>
    <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">b</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">zeros</span><span class="p">((</span><span class="n">self</span><span class="p">.</span><span class="n">size_of_layers</span><span class="p">[</span><span class="n">layer</span><span class="p">],</span> <span class="mi">1</span><span class="p">))</span>
</pre></td></tr></tbody></table></code></pre></div></div>

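<p>We initialize each weight matrix <code class="language-plaintext highlighter-rouge">W</code> with samples from a standard Gaussian distribution. Its shape is (number of neurons in the current layer, number of neurons in the previous layer). The weights are then scaled by the inverse square root of the number of neurons in the previous layer, which gives them a variance of <code class="language-plaintext highlighter-rouge">1/n_in</code>. This scheme is usually called LeCun initialization and is closely related to Xavier (Glorot) initialization. He initialization, which is common with ReLU activations, would instead scale by <code class="language-plaintext highlighter-rouge">sqrt(2/n_in)</code>. The biases <code class="language-plaintext highlighter-rouge">b</code> for each layer are initialized to zeros.</p>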
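<p>The following standalone sketch shows why the <code class="language-plaintext highlighter-rouge">1/sqrt(n_in)</code> scaling matters. It is not part of the <code class="language-plaintext highlighter-rouge">NeuralNet</code> class. The helper <code class="language-plaintext highlighter-rouge">init_layer</code>, the layer sizes 784 and 196, and the unit-variance random inputs are all illustrative assumptions. The sketch initializes one layer the same way <code class="language-plaintext highlighter-rouge">initialize_parameters</code> does. It then checks that the pre-activations keep a standard deviation of roughly 1, which keeps the sigmoid out of its flat, saturated regions early in training.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

def init_layer(n_out, n_in, rng):
    """Standard-normal weights scaled by 1/sqrt(n_in), zero biases (hypothetical helper)."""
    W = rng.standard_normal((n_out, n_in)) / np.sqrt(n_in)
    b = np.zeros((n_out, 1))
    return W, b

rng = np.random.default_rng(0)
# Assumed sizes: 784 inputs (flattened 28x28 images), 196 hidden units.
W1, b1 = init_layer(196, 784, rng)

# The weights have a standard deviation close to 1/sqrt(784) = 1/28.
print(W1.shape, b1.shape, round(float(W1.std()) * 28, 2))

# With unit-variance inputs, the pre-activations Z = W1 X + b1 keep a
# standard deviation close to 1 instead of growing with the layer width.
X = rng.standard_normal((784, 1000))   # hypothetical inputs, one column per example
Z = W1 @ X + b1
print(round(float(Z.std()), 2))
</code></pre></div></div>
<p>Without the division by <code class="language-plaintext highlighter-rouge">sqrt(n_in)</code>, each pre-activation would sum 784 unit-variance terms, giving a standard deviation near 28. Most sigmoid outputs would then sit very close to 0 or 1.</p>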

<h3 id="forward-pass-feeding-data-through-the-network">Forward Pass: Feeding Data Through the Network</h3>

<p>After initializing the parameters of our neural network, the next step is to define the forward pass. This is where we pass our preprocessed data through the network to generate predictions. In this step, the input data is transformed layer by layer until we reach the final output.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">forward_pass</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">X</span><span class="p">):</span>
    <span class="n">save</span> <span class="o">=</span> <span class="p">{}</span>

    <span class="n">A</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">T</span>
    <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
        <span class="n">Z</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)].</span><span class="nf">dot</span><span class="p">(</span><span class="n">A</span><span class="p">)</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">b</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span>
        <span class="n">A</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">sigmoid</span><span class="p">(</span><span class="n">Z</span><span class="p">)</span>
        <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span> <span class="o">=</span> <span class="n">A</span>
        <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span>
        <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">Z</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span> <span class="o">=</span> <span class="n">Z</span>

    <span class="n">Z</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)].</span><span class="nf">dot</span><span class="p">(</span><span class="n">A</span><span class="p">)</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">b</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span>
    <span class="n">A</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">sigmoid</span><span class="p">(</span><span class="n">Z</span><span class="p">)</span>
    <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span> <span class="o">=</span> <span class="n">A</span>
    <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span>
    <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">Z</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span> <span class="o">=</span> <span class="n">Z</span>

    <span class="k">return</span> <span class="n">A</span><span class="p">,</span> <span class="n">save</span>
</pre></td></tr></tbody></table></code></pre></div></div>

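<p>In the forward pass, the network transforms the input one layer at a time. Each layer computes <code class="language-plaintext highlighter-rouge">Z = W.dot(A) + b</code> from the previous layer's activations and applies the sigmoid to produce new activations <code class="language-plaintext highlighter-rouge">A</code>. The output layer then produces the prediction. The network also stores the intermediate values (<code class="language-plaintext highlighter-rouge">A</code>, <code class="language-plaintext highlighter-rouge">W</code>, and <code class="language-plaintext highlighter-rouge">Z</code> for every layer) in <code class="language-plaintext highlighter-rouge">save</code>, so the backward pass can reuse them to compute gradients instead of recomputing them.</p>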
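<p>To see the shapes involved, the sketch below runs the same computation as a standalone function rather than the class method. It uses a randomly initialized 784-196-1 network and five random input rows, both assumptions for illustration, and prints the shape of every cached array. It folds the separate output-layer step into the loop. That is equivalent here because every layer, including the output, uses the sigmoid.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def forward_pass(parameters, X, length):
    """Standalone sketch of forward_pass. X has one row per example."""
    save = {}
    A = X.T                                    # (features, n): one column per example
    for layer in range(1, length + 1):
        W = parameters["W" + str(layer)]
        b = parameters["b" + str(layer)]
        Z = W.dot(A) + b                       # b has shape (units, 1) and broadcasts over columns
        A = sigmoid(Z)
        save["A" + str(layer)] = A
        save["W" + str(layer)] = W
        save["Z" + str(layer)] = Z
    return A, save

rng = np.random.default_rng(0)
sizes = [784, 196, 1]     # assumed: input size already prepended, as fit() does
parameters = {}
for layer in range(1, len(sizes)):
    parameters["W" + str(layer)] = rng.standard_normal((sizes[layer], sizes[layer - 1])) / np.sqrt(sizes[layer - 1])
    parameters["b" + str(layer)] = np.zeros((sizes[layer], 1))

X = rng.random((5, 784))  # 5 hypothetical examples with pixel values in [0, 1)
A, save = forward_pass(parameters, X, length=len(sizes) - 1)
for key in sorted(save):
    print(key, save[key].shape)
</code></pre></div></div>
<p>Each column of <code class="language-plaintext highlighter-rouge">A2</code> is the predicted probability for one example, so <code class="language-plaintext highlighter-rouge">A2</code> has shape (1, 5) here.</p>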

<h3 id="backward-pass-updating-parameters-through-backpropagation">Backward Pass: Updating Parameters through Backpropagation</h3>

<p>After implementing the forward pass and making predictions, the next important step is the backward pass, also known as backpropagation. Here the network computes the gradient of the cost with respect to every weight and bias, working backward from the output layer. The <code class="language-plaintext highlighter-rouge">fit</code> method then uses these gradients to adjust the parameters and reduce the prediction error.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">backward_pass</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">save</span><span class="p">):</span>
    <span class="n">save_gradients</span> <span class="o">=</span> <span class="p">{}</span> 
    <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A0</span><span class="sh">"</span><span class="p">]</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">T</span>

    <span class="n">A</span> <span class="o">=</span> <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span>
    <span class="n">dA</span> <span class="o">=</span> <span class="o">-</span><span class="n">np</span><span class="p">.</span><span class="nf">divide</span><span class="p">(</span><span class="n">Y</span><span class="p">,</span> <span class="n">A</span><span class="p">)</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="nf">divide</span><span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">Y</span><span class="p">,</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">A</span><span class="p">)</span>

    <span class="n">dZ</span> <span class="o">=</span> <span class="n">dA</span> <span class="o">*</span> <span class="n">self</span><span class="p">.</span><span class="nf">sigmoid_derivative</span><span class="p">(</span><span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">Z</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)])</span>
    <span class="n">dW</span> <span class="o">=</span> <span class="n">dZ</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)].</span><span class="n">T</span><span class="p">)</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span>
    <span class="n">db</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">sum</span><span class="p">(</span><span class="n">dZ</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span>
    <span class="n">dAPrev</span> <span class="o">=</span> <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)].</span><span class="n">T</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">dZ</span><span class="p">)</span>

    <span class="n">save_gradients</span><span class="p">[</span><span class="sh">"</span><span class="s">dW</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span> <span class="o">=</span> <span class="n">dW</span>
    <span class="n">save_gradients</span><span class="p">[</span><span class="sh">"</span><span class="s">db</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span> <span class="o">=</span> <span class="n">db</span>

    <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">):</span>
        <span class="n">dZ</span> <span class="o">=</span> <span class="n">dAPrev</span> <span class="o">*</span> <span class="n">self</span><span class="p">.</span><span class="nf">sigmoid_derivative</span><span class="p">(</span><span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">Z</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)])</span>
        <span class="n">dW</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span> <span class="o">*</span> <span class="n">dZ</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)].</span><span class="n">T</span><span class="p">)</span>
        <span class="n">db</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="nf">sum</span><span class="p">(</span><span class="n">dZ</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">layer</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
            <span class="n">dAPrev</span> <span class="o">=</span> <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)].</span><span class="n">T</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">dZ</span><span class="p">)</span>

        <span class="n">save_gradients</span><span class="p">[</span><span class="sh">"</span><span class="s">dW</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">dW</span>
        <span class="n">save_gradients</span><span class="p">[</span><span class="sh">"</span><span class="s">db</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">db</span>

    <span class="k">return</span> <span class="n">save_gradients</span>
</pre></td></tr></tbody></table></code></pre></div></div>

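<p>The <code class="language-plaintext highlighter-rouge">backward_pass</code> method applies the chain rule layer by layer, starting from the output. For each layer it computes <code class="language-plaintext highlighter-rouge">dZ</code>, then the gradients <code class="language-plaintext highlighter-rouge">dW</code> and <code class="language-plaintext highlighter-rouge">db</code> averaged over the <code class="language-plaintext highlighter-rouge">n</code> training examples, and finally <code class="language-plaintext highlighter-rouge">dAPrev</code>, the gradient passed back to the previous layer. These gradients measure how much each weight and bias contributes to the overall error. Training can therefore adjust each parameter in the direction that reduces the error, and repeating this over many iterations gradually improves the network's predictions.</p>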
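<p>A common way to test a backward pass is gradient checking: you compare the analytic gradients with central finite differences of the cost. The sketch below does this for a tiny hypothetical 4-3-1 sigmoid network, written as standalone functions rather than the class methods.</p>
<p>It also uses a standard simplification for the output layer. The output-layer gradient is <code class="language-plaintext highlighter-rouge">dA = -Y/A + (1 - Y)/(1 - A)</code>, and the sigmoid derivative is <code class="language-plaintext highlighter-rouge">A(1 - A)</code> (assuming that is what <code class="language-plaintext highlighter-rouge">sigmoid_derivative</code> returns). Their product, <code class="language-plaintext highlighter-rouge">dA * sigmoid_derivative(Z)</code>, then reduces to <code class="language-plaintext highlighter-rouge">A - Y</code>. This form avoids dividing by values of <code class="language-plaintext highlighter-rouge">A</code> close to 0 or 1.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def forward(params, X, L):
    """Forward pass that also caches A0 = X.T, as backward_pass does."""
    cache = {"A0": X.T}
    A = X.T
    for l in range(1, L + 1):
        Z = params[f"W{l}"].dot(A) + params[f"b{l}"]
        A = sigmoid(Z)
        cache[f"Z{l}"], cache[f"A{l}"] = Z, A
    return A, cache

def cost(params, X, Y, L):
    """Binary cross-entropy averaged over the n examples (same formula as fit)."""
    A, _ = forward(params, X, L)
    return float(-np.sum(Y * np.log(A) + (1 - Y) * np.log(1 - A)) / X.shape[0])

def backward(params, X, Y, L):
    n = X.shape[0]
    A, cache = forward(params, X, L)
    grads = {}
    dZ = A - Y                          # simplified dA * sigmoid'(Z) at the output layer
    for l in range(L, 0, -1):
        grads[f"dW{l}"] = dZ.dot(cache[f"A{l - 1}"].T) / n
        grads[f"db{l}"] = np.sum(dZ, axis=1, keepdims=True) / n
        if l == 1:
            break
        A_prev = cache[f"A{l - 1}"]
        dZ = params[f"W{l}"].T.dot(dZ) * A_prev * (1 - A_prev)   # sigmoid'(Z) = A(1 - A)
    return grads

def numerical_grad(params, X, Y, L, key, eps=1e-6):
    """Central finite differences, one parameter entry at a time."""
    grad = np.zeros_like(params[key])
    for idx in np.ndindex(*params[key].shape):
        original = params[key][idx]
        params[key][idx] = original + eps
        plus = cost(params, X, Y, L)
        params[key][idx] = original - eps
        minus = cost(params, X, Y, L)
        params[key][idx] = original                # restore the entry
        grad[idx] = (plus - minus) / (2 * eps)
    return grad

rng = np.random.default_rng(1)
sizes = [4, 3, 1]                                  # hypothetical tiny network
L = len(sizes) - 1
params = {}
for l in range(1, L + 1):
    params[f"W{l}"] = rng.standard_normal((sizes[l], sizes[l - 1])) / np.sqrt(sizes[l - 1])
    params[f"b{l}"] = 0.1 * rng.standard_normal((sizes[l], 1))
X = rng.standard_normal((6, 4))                    # 6 examples, one per row
Y = rng.integers(0, 2, size=(1, 6)).astype(float)  # binary labels, shape (1, n)

grads = backward(params, X, Y, L)
max_diff = max(np.abs(grads["d" + key] - numerical_grad(params, X, Y, L, key)).max() for key in params)
print(f"largest mismatch: {max_diff:.1e}")
</code></pre></div></div>
<p>A very small mismatch suggests the analytic gradients are correct. A large one usually points to a transposed matrix or a missing factor of <code class="language-plaintext highlighter-rouge">1/n</code>.</p>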

<h3 id="training-the-neural-network">Training the Neural Network</h3>

<p>Let’s now start training the neural network. The training process involves iteratively updating the network’s parameters (weights and biases) to minimize the prediction error.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">fit</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">n_iterations</span><span class="o">=</span><span class="mi">3000</span><span class="p">):</span>
    <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
    <span class="n">self</span><span class="p">.</span><span class="n">n</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    <span class="n">self</span><span class="p">.</span><span class="n">size_of_layers</span><span class="p">.</span><span class="nf">insert</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>

    <span class="n">self</span><span class="p">.</span><span class="nf">initialize_parameters</span><span class="p">()</span>
    <span class="k">for</span> <span class="n">loop</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">n_iterations</span><span class="p">):</span>
        <span class="n">A</span><span class="p">,</span> <span class="n">save</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">forward_pass</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
        <span class="n">cost</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">squeeze</span><span class="p">(</span><span class="o">-</span><span class="p">(</span><span class="n">Y</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nf">log</span><span class="p">(</span><span class="n">A</span><span class="p">.</span><span class="n">T</span><span class="p">))</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">Y</span><span class="p">).</span><span class="nf">dot</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nf">log</span><span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">A</span><span class="p">.</span><span class="n">T</span><span class="p">)))</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span><span class="p">)</span>
        <span class="n">gradients</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">backward_pass</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">save</span><span class="p">)</span>

        <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">self</span><span class="p">.</span><span class="n">length</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span>
            <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">-</span> <span class="n">learning_rate</span> <span class="o">*</span> <span class="n">gradients</span><span class="p">[</span><span class="sh">"</span><span class="s">dW</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span>
            <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">b</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">b</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">-</span> <span class="n">learning_rate</span> <span class="o">*</span> <span class="n">gradients</span><span class="p">[</span><span class="sh">"</span><span class="s">db</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span>

        <span class="k">if</span> <span class="n">loop</span> <span class="o">%</span> <span class="mi">10</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
            <span class="nf">print</span><span class="p">(</span><span class="n">cost</span><span class="p">)</span>
            <span class="n">self</span><span class="p">.</span><span class="n">costs</span><span class="p">.</span><span class="nf">append</span><span class="p">(</span><span class="n">cost</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

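<p>The <code class="language-plaintext highlighter-rouge">fit</code> method implements the training loop as batch gradient descent. It first prepends the number of input features to <code class="language-plaintext highlighter-rouge">size_of_layers</code> and initializes the parameters. On every iteration, it runs a forward pass over the whole training set, computes the binary cross-entropy cost, and backpropagates. It then moves each <code class="language-plaintext highlighter-rouge">W</code> and <code class="language-plaintext highlighter-rouge">b</code> a step of size <code class="language-plaintext highlighter-rouge">learning_rate</code> against its gradient. Every 10 iterations it prints the cost and records it in <code class="language-plaintext highlighter-rouge">self.costs</code>. By the end of training, the network should have learned parameters that give a low cost on the training data.</p>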
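<p>The same loop can be seen in miniature below. This sketch trains a single sigmoid unit (logistic regression) with batch gradient descent on hypothetical two-blob toy data. It uses the forward pass, cost, gradients, update rule, and every-10-iterations logging that <code class="language-plaintext highlighter-rouge">fit</code> uses. With only one layer, backpropagation reduces to the output-layer step.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

rng = np.random.default_rng(0)
# Hypothetical toy data: two Gaussian blobs in 2-D, one example per row as in the post.
n = 200
X = np.vstack([rng.normal(-1.0, 1.0, (n // 2, 2)), rng.normal(1.0, 1.0, (n // 2, 2))])
Y = np.concatenate([np.zeros(n // 2), np.ones(n // 2)]).reshape(1, n)

# A single sigmoid layer keeps the loop short; fit() does the same steps for every layer.
W = rng.standard_normal((1, 2)) / np.sqrt(2)
b = np.zeros((1, 1))
learning_rate, costs = 0.1, []

for loop in range(300):
    A = 1 / (1 + np.exp(-(W.dot(X.T) + b)))      # forward pass, shape (1, n)
    cost = -np.mean(Y * np.log(A) + (1 - Y) * np.log(1 - A))
    dZ = A - Y                                    # backward pass for sigmoid + cross-entropy
    dW = dZ.dot(X) / n                            # equals dZ.dot(A0.T) / n with A0 = X.T
    db = np.sum(dZ, axis=1, keepdims=True) / n
    W -= learning_rate * dW                       # gradient-descent update
    b -= learning_rate * db
    if loop % 10 == 0:                            # log every 10 iterations, like fit()
        costs.append(float(cost))

print(f"cost went from {costs[0]:.3f} to {costs[-1]:.3f} over {len(costs)} logged points")
</code></pre></div></div>
<p>For a single sigmoid unit the cost is convex. With a small enough learning rate, the logged cost should therefore fall steadily. With hidden layers the cost is no longer convex, so the curve recorded by <code class="language-plaintext highlighter-rouge">fit</code> may be less smooth.</p>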

<h3 id="making-predictions">Making Predictions</h3>

<p>After training the neural network, the next step is to use it to make predictions on new data. The <code class="language-plaintext highlighter-rouge">predict</code> method runs a forward pass and labels each example as 1 if its output probability is greater than 0.5, or as 0 otherwise. It then prints the accuracy of these predictions against the true labels.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">):</span>
    <span class="n">A</span><span class="p">,</span> <span class="n">cache</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">forward_pass</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
    <span class="n">n</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    <span class="n">pred</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="n">n</span><span class="p">))</span>

    <span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">A</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]):</span> 
        <span class="k">if</span> <span class="n">A</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">idx</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mf">0.5</span><span class="p">:</span>
            <span class="n">pred</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">pred</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>

    <span class="nf">print</span><span class="p">(</span><span class="sh">"</span><span class="s">Accuracy: </span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nf">sum</span><span class="p">((</span><span class="n">pred</span> <span class="o">==</span> <span class="n">Y</span><span class="p">)</span> <span class="o">/</span> <span class="n">n</span><span class="p">)))</span>

<span class="k">def</span> <span class="nf">plot_cost</span><span class="p">(</span><span class="n">self</span><span class="p">):</span>
    <span class="n">plt</span><span class="p">.</span><span class="nf">figure</span><span class="p">()</span>
    <span class="n">plt</span><span class="p">.</span><span class="nf">plot</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nf">arange</span><span class="p">(</span><span class="nf">len</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">costs</span><span class="p">))</span> <span class="o">*</span> <span class="mi">10</span><span class="p">,</span> <span class="n">self</span><span class="p">.</span><span class="n">costs</span><span class="p">)</span>  <span class="c1"># costs are recorded every 10 iterations</span>
    <span class="n">plt</span><span class="p">.</span><span class="nf">xlabel</span><span class="p">(</span><span class="sh">"</span><span class="s">iteration</span><span class="sh">"</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="nf">ylabel</span><span class="p">(</span><span class="sh">"</span><span class="s">cost</span><span class="sh">"</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="nf">show</span><span class="p">()</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>The <code class="language-plaintext highlighter-rouge">predict</code> method lets us measure how well the trained model performs on any labeled dataset. Running it on the test set, which the model never saw during training, is important for checking how well the network generalizes beyond the training data. Finally, the <code class="language-plaintext highlighter-rouge">plot_cost</code> method plots the cost values recorded during training (one every 10 iterations), so we can see how the model's learning progresses over time.</p>
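<p>The thresholding loop in <code class="language-plaintext highlighter-rouge">predict</code> can also be written without an explicit Python loop. This sketch uses hypothetical helper names (<code class="language-plaintext highlighter-rouge">predict_labels</code>, <code class="language-plaintext highlighter-rouge">accuracy</code>) and made-up outputs. It applies the same rule as <code class="language-plaintext highlighter-rouge">predict</code>: label 1 only when the output is strictly greater than 0.5. It computes accuracy as the mean of correct predictions, which equals the <code class="language-plaintext highlighter-rouge">np.sum((pred == Y) / n)</code> expression in <code class="language-plaintext highlighter-rouge">predict</code>.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

def predict_labels(A, threshold=0.5):
    """Label 1 where the output probability exceeds the threshold, else 0."""
    return np.greater(A, threshold).astype(int)

def accuracy(pred, Y):
    """Fraction of examples whose predicted label matches the true label."""
    return float(np.mean(pred == Y))

A = np.array([[0.9, 0.2, 0.51, 0.4]])   # hypothetical sigmoid outputs, shape (1, n)
Y = np.array([[1, 0, 0, 0]])            # hypothetical true labels
pred = predict_labels(A)
print(pred, accuracy(pred, Y))          # [[1 0 1 0]] 0.75
</code></pre></div></div>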

<h3 id="putting-it-all-together">Putting It All Together</h3>

<p>With the neural network class fully implemented, we can now put everything together to train the model, make predictions, and evaluate its performance.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-code"><pre><span class="n">size_of_layers</span> <span class="o">=</span> <span class="p">[</span><span class="mi">196</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span>

<span class="n">model</span> <span class="o">=</span> <span class="nc">NeuralNet</span><span class="p">(</span><span class="n">size_of_layers</span><span class="p">)</span>
<span class="n">model</span><span class="p">.</span><span class="nf">fit</span><span class="p">(</span><span class="n">train_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">,</span> <span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">n_iterations</span><span class="o">=</span><span class="mi">100</span><span class="p">)</span>
<span class="n">model</span><span class="p">.</span><span class="nf">predict</span><span class="p">(</span><span class="n">train_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">)</span>
<span class="n">model</span><span class="p">.</span><span class="nf">predict</span><span class="p">(</span><span class="n">test_x</span><span class="p">,</span> <span class="n">test_y</span><span class="p">)</span>
<span class="n">model</span><span class="p">.</span><span class="nf">plot_cost</span><span class="p">()</span>
</pre></td></tr></tbody></table></code></pre></div></div>

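<p>This final step defines the structure of the network: one hidden layer of 196 neurons followed by a single sigmoid output neuron, with <code class="language-plaintext highlighter-rouge">fit</code> prepending the input size. It then trains the network for 100 iterations with a learning rate of 0.1, evaluates its accuracy on both the training and test sets, and plots the cost curve.</p>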
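<p>To make the architecture concrete, this sketch lists the shapes of the parameters that <code class="language-plaintext highlighter-rouge">initialize_parameters</code> creates for <code class="language-plaintext highlighter-rouge">size_of_layers = [196, 1]</code>, and counts them. It assumes the input has 784 features (flattened 28x28 MNIST images), which is the value <code class="language-plaintext highlighter-rouge">fit</code> would prepend from <code class="language-plaintext highlighter-rouge">X.shape[1]</code>.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Assumption: 784 input features (flattened 28x28 images), which fit() would
# prepend to size_of_layers via insert(0, X.shape[1]).
size_of_layers = [196, 1]
sizes = [784] + size_of_layers          # builds a new list instead of mutating size_of_layers

shapes = {}
for layer in range(1, len(sizes)):
    shapes["W" + str(layer)] = (sizes[layer], sizes[layer - 1])
    shapes["b" + str(layer)] = (sizes[layer], 1)

for name, shape in shapes.items():
    print(name, shape)

total = sum(rows * cols for rows, cols in shapes.values())
print("trainable parameters:", total)   # 154057
</code></pre></div></div>
<p>Almost all of these parameters sit in <code class="language-plaintext highlighter-rouge">W1</code>, the 196 x 784 matrix that connects the pixels to the hidden layer.</p>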

<h3 id="full-code-implementation">Full Code Implementation</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-code"><pre><span class="kn">import</span> <span class="n">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="n">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="n">tensorflow</span> <span class="k">as</span> <span class="n">tf</span> <span class="c1"># used only to download the MNIST data
</span><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span> <span class="c1"># reproducibility
</span>
<span class="k">class</span> <span class="nc">NeuralNet</span><span class="p">:</span>
  <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">size_of_layers</span><span class="p">):</span>
    <span class="n">self</span><span class="p">.</span><span class="n">size_of_layers</span> <span class="o">=</span> <span class="n">size_of_layers</span>
    <span class="n">self</span><span class="p">.</span><span class="n">parameters</span> <span class="o">=</span> <span class="p">{}</span>
    <span class="n">self</span><span class="p">.</span><span class="n">length</span> <span class="o">=</span> <span class="nf">len</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">size_of_layers</span><span class="p">)</span>
    <span class="n">self</span><span class="p">.</span><span class="n">n</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="n">self</span><span class="p">.</span><span class="n">costs</span> <span class="o">=</span> <span class="p">[]</span>


  <span class="k">def</span> <span class="nf">sigmoid</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">z</span><span class="p">):</span>
    <span class="k">return</span> <span class="mi">1</span><span class="o">/</span><span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="nf">exp</span><span class="p">(</span><span class="o">-</span><span class="n">z</span><span class="p">))</span>
    

  <span class="k">def</span> <span class="nf">sigmoid_derivative</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">z</span><span class="p">):</span>
    <span class="n">sigma</span> <span class="o">=</span> <span class="mi">1</span><span class="o">/</span><span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="nf">exp</span><span class="p">(</span><span class="o">-</span><span class="n">z</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">sigma</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">sigma</span><span class="p">)</span>
    

  <span class="k">def</span> <span class="nf">initialize_parameters</span><span class="p">(</span><span class="n">self</span><span class="p">):</span>
    <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span> <span class="c1"># reproducibility
</span>    <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nf">len</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">size_of_layers</span><span class="p">)):</span>
      <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">randn</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">size_of_layers</span><span class="p">[</span><span class="n">layer</span><span class="p">],</span> <span class="n">self</span><span class="p">.</span><span class="n">size_of_layers</span><span class="p">[</span><span class="n">layer</span> <span class="o">-</span> <span class="mi">1</span><span class="p">])</span><span class="o">/</span><span class="n">np</span><span class="p">.</span><span class="nf">sqrt</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">size_of_layers</span><span class="p">[</span><span class="n">layer</span> <span class="o">-</span> <span class="mi">1</span><span class="p">])</span>
      <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">b</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">zeros</span><span class="p">((</span><span class="n">self</span><span class="p">.</span><span class="n">size_of_layers</span><span class="p">[</span><span class="n">layer</span><span class="p">],</span> <span class="mi">1</span><span class="p">))</span>

  <span class="c1"># forward pass
</span>  <span class="k">def</span> <span class="nf">forward_pass</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">X</span><span class="p">):</span>
    <span class="n">save</span> <span class="o">=</span> <span class="p">{}</span>

    <span class="n">A</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">T</span>
    <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
      <span class="n">Z</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)].</span><span class="nf">dot</span><span class="p">(</span><span class="n">A</span><span class="p">)</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">b</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span>
      <span class="n">A</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">sigmoid</span><span class="p">(</span><span class="n">Z</span><span class="p">)</span>
      <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span> <span class="o">=</span> <span class="n">A</span>
      <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span>
      <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">Z</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)]</span> <span class="o">=</span> <span class="n">Z</span>

    <span class="n">Z</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)].</span><span class="nf">dot</span><span class="p">(</span><span class="n">A</span><span class="p">)</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">b</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span>
    <span class="n">A</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">sigmoid</span><span class="p">(</span><span class="n">Z</span><span class="p">)</span>
    <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span> <span class="o">=</span> <span class="n">A</span>
    <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span>
    <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">Z</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span> <span class="o">=</span> <span class="n">Z</span>

    <span class="k">return</span> <span class="n">A</span><span class="p">,</span> <span class="n">save</span>

  <span class="c1"># backward pass
</span>  <span class="k">def</span> <span class="nf">backward_pass</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">save</span><span class="p">):</span>
      <span class="n">save_gradients</span> <span class="o">=</span> <span class="p">{}</span> 
      <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A0</span><span class="sh">"</span><span class="p">]</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">T</span>

      <span class="n">A</span> <span class="o">=</span> <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span>
      <span class="n">dA</span> <span class="o">=</span> <span class="o">-</span><span class="n">np</span><span class="p">.</span><span class="nf">divide</span><span class="p">(</span><span class="n">Y</span><span class="p">,</span> <span class="n">A</span><span class="p">)</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="nf">divide</span><span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">Y</span><span class="p">,</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">A</span><span class="p">)</span>

      <span class="n">dZ</span> <span class="o">=</span> <span class="n">dA</span> <span class="o">*</span> <span class="n">self</span><span class="p">.</span><span class="nf">sigmoid_derivative</span><span class="p">(</span><span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">Z</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)])</span>
      <span class="n">dW</span> <span class="o">=</span> <span class="n">dZ</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)].</span><span class="n">T</span><span class="p">)</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span>
      <span class="n">db</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">sum</span><span class="p">(</span><span class="n">dZ</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span>
      <span class="n">dAPrev</span> <span class="o">=</span> <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)].</span><span class="n">T</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">dZ</span><span class="p">)</span>

      <span class="n">save_gradients</span><span class="p">[</span><span class="sh">"</span><span class="s">dW</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span> <span class="o">=</span> <span class="n">dW</span>
      <span class="n">save_gradients</span><span class="p">[</span><span class="sh">"</span><span class="s">db</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span><span class="p">)]</span> <span class="o">=</span> <span class="n">db</span>

      <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">):</span>
          <span class="n">dZ</span> <span class="o">=</span> <span class="n">dAPrev</span> <span class="o">*</span> <span class="n">self</span><span class="p">.</span><span class="nf">sigmoid_derivative</span><span class="p">(</span><span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">Z</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)])</span>
          <span class="n">dW</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span> <span class="o">*</span> <span class="n">dZ</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">A</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)].</span><span class="n">T</span><span class="p">)</span>
          <span class="n">db</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="nf">sum</span><span class="p">(</span><span class="n">dZ</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
          <span class="k">if</span> <span class="n">layer</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
              <span class="n">dAPrev</span> <span class="o">=</span> <span class="n">save</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)].</span><span class="n">T</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">dZ</span><span class="p">)</span>

          <span class="n">save_gradients</span><span class="p">[</span><span class="sh">"</span><span class="s">dW</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">dW</span>
          <span class="n">save_gradients</span><span class="p">[</span><span class="sh">"</span><span class="s">db</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">db</span>

      <span class="k">return</span> <span class="n">save_gradients</span>

  <span class="k">def</span> <span class="nf">fit</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">n_iterations</span><span class="o">=</span><span class="mi">3000</span><span class="p">):</span>
      <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span> <span class="c1"># reproducibility
</span>      <span class="n">self</span><span class="p">.</span><span class="n">n</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
      <span class="n">self</span><span class="p">.</span><span class="n">size_of_layers</span><span class="p">.</span><span class="nf">insert</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>

      <span class="n">self</span><span class="p">.</span><span class="nf">initialize_parameters</span><span class="p">()</span>
      <span class="k">for</span> <span class="n">loop</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">n_iterations</span><span class="p">):</span>
          <span class="n">A</span><span class="p">,</span> <span class="n">save</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">forward_pass</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
          <span class="n">cost</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">squeeze</span><span class="p">(</span><span class="o">-</span><span class="p">(</span><span class="n">Y</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nf">log</span><span class="p">(</span><span class="n">A</span><span class="p">.</span><span class="n">T</span><span class="p">))</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">Y</span><span class="p">).</span><span class="nf">dot</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nf">log</span><span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">A</span><span class="p">.</span><span class="n">T</span><span class="p">)))</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span><span class="p">)</span>
          <span class="n">gradients</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">backward_pass</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">,</span> <span class="n">save</span><span class="p">)</span>

          <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">self</span><span class="p">.</span><span class="n">length</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span>
              <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">W</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">-</span> <span class="n">learning_rate</span> <span class="o">*</span> <span class="n">gradients</span><span class="p">[</span>
                  <span class="sh">"</span><span class="s">dW</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span>
              <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">b</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">[</span><span class="sh">"</span><span class="s">b</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span> <span class="o">-</span> <span class="n">learning_rate</span> <span class="o">*</span> <span class="n">gradients</span><span class="p">[</span>
                  <span class="sh">"</span><span class="s">db</span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">layer</span><span class="p">)]</span>

          <span class="k">if</span> <span class="n">loop</span> <span class="o">%</span> <span class="mi">100</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
              <span class="nf">print</span><span class="p">(</span><span class="n">cost</span><span class="p">)</span>
              <span class="n">self</span><span class="p">.</span><span class="n">costs</span><span class="p">.</span><span class="nf">append</span><span class="p">(</span><span class="n">cost</span><span class="p">)</span>

  <span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">Y</span><span class="p">):</span>
      <span class="n">A</span><span class="p">,</span> <span class="n">cache</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">forward_pass</span><span class="p">(</span><span class="n">X</span><span class="p">)</span>
      <span class="n">n</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
      <span class="n">pred</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="n">n</span><span class="p">))</span>

      <span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">A</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]):</span> 
          <span class="k">if</span> <span class="n">A</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">idx</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mf">0.5</span><span class="p">:</span>
              <span class="n">pred</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
          <span class="k">else</span><span class="p">:</span>
              <span class="n">pred</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>

      <span class="nf">print</span><span class="p">(</span><span class="sh">"</span><span class="s">Accuracy: </span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nf">sum</span><span class="p">((</span><span class="n">pred</span> <span class="o">==</span> <span class="n">Y</span><span class="p">)</span> <span class="o">/</span> <span class="n">n</span><span class="p">)))</span>

  <span class="k">def</span> <span class="nf">plot_cost</span><span class="p">(</span><span class="n">self</span><span class="p">):</span>
      <span class="n">plt</span><span class="p">.</span><span class="nf">figure</span><span class="p">()</span>
      <span class="n">plt</span><span class="p">.</span><span class="nf">plot</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nf">arange</span><span class="p">(</span><span class="nf">len</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">costs</span><span class="p">)),</span> <span class="n">self</span><span class="p">.</span><span class="n">costs</span><span class="p">)</span>
      <span class="n">plt</span><span class="p">.</span><span class="nf">xlabel</span><span class="p">(</span><span class="sh">"</span><span class="s">iterations (per hundred)</span><span class="sh">"</span><span class="p">)</span>
      <span class="n">plt</span><span class="p">.</span><span class="nf">ylabel</span><span class="p">(</span><span class="sh">"</span><span class="s">cost</span><span class="sh">"</span><span class="p">)</span>
      <span class="n">plt</span><span class="p">.</span><span class="nf">show</span><span class="p">()</span>

<span class="k">def</span> <span class="nf">dataset</span><span class="p">():</span>
  <span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">datasets</span><span class="p">.</span><span class="n">mnist</span><span class="p">.</span><span class="nf">load_data</span><span class="p">()</span>

  <span class="n">index_1</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">where</span><span class="p">(</span><span class="n">y_train</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)</span>
  <span class="n">index_2</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">where</span><span class="p">(</span><span class="n">y_train</span> <span class="o">==</span> <span class="mi">2</span><span class="p">)</span>

  <span class="n">index</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">concatenate</span><span class="p">([</span><span class="n">index_1</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">index_2</span><span class="p">[</span><span class="mi">0</span><span class="p">]])</span>
  <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
  <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">shuffle</span><span class="p">(</span><span class="n">index</span><span class="p">)</span>

  <span class="n">train_x</span> <span class="o">=</span> <span class="n">x_train</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
  <span class="n">train_y</span> <span class="o">=</span> <span class="n">y_train</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>

  <span class="n">train_y</span><span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="nf">where</span><span class="p">(</span><span class="n">train_y</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)]</span> <span class="o">=</span> <span class="mi">0</span>
  <span class="n">train_y</span><span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="nf">where</span><span class="p">(</span><span class="n">train_y</span> <span class="o">==</span> <span class="mi">2</span><span class="p">)]</span> <span class="o">=</span> <span class="mi">1</span>
  
  <span class="n">index_1</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">where</span><span class="p">(</span><span class="n">y_test</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)</span>
  <span class="n">index_2</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">where</span><span class="p">(</span><span class="n">y_test</span> <span class="o">==</span> <span class="mi">2</span><span class="p">)</span>

  <span class="n">index</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">concatenate</span><span class="p">([</span><span class="n">index_1</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">index_2</span><span class="p">[</span><span class="mi">0</span><span class="p">]])</span>
  <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">shuffle</span><span class="p">(</span><span class="n">index</span><span class="p">)</span>

  <span class="n">test_y</span> <span class="o">=</span> <span class="n">y_test</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
  <span class="n">test_x</span> <span class="o">=</span> <span class="n">x_test</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>

  <span class="n">test_y</span><span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="nf">where</span><span class="p">(</span><span class="n">test_y</span> <span class="o">==</span> <span class="mi">1</span><span class="p">)]</span> <span class="o">=</span> <span class="mi">0</span>
  <span class="n">test_y</span><span class="p">[</span><span class="n">np</span><span class="p">.</span><span class="nf">where</span><span class="p">(</span><span class="n">test_y</span> <span class="o">==</span> <span class="mi">2</span><span class="p">)]</span> <span class="o">=</span> <span class="mi">1</span>

  <span class="k">return</span> <span class="n">train_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">,</span> <span class="n">test_x</span><span class="p">,</span> <span class="n">test_y</span>

<span class="k">def</span> <span class="nf">data_preprocessing</span><span class="p">(</span><span class="n">train_x</span><span class="p">,</span> <span class="n">test_x</span><span class="p">):</span>
    <span class="c1"># Normalize
</span>    <span class="n">train_x</span> <span class="o">=</span> <span class="n">train_x</span> <span class="o">/</span> <span class="mf">255.</span>
    <span class="n">test_x</span> <span class="o">=</span> <span class="n">test_x</span> <span class="o">/</span> <span class="mf">255.</span>

    <span class="c1"># Flatten the images
</span>    <span class="n">train_x</span> <span class="o">=</span> <span class="n">train_x</span><span class="p">.</span><span class="nf">reshape</span><span class="p">(</span><span class="n">train_x</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">test_x</span> <span class="o">=</span> <span class="n">test_x</span><span class="p">.</span><span class="nf">reshape</span><span class="p">(</span><span class="n">test_x</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">train_x</span><span class="p">,</span> <span class="n">test_x</span>

<span class="n">train_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">,</span> <span class="n">test_x</span><span class="p">,</span> <span class="n">test_y</span> <span class="o">=</span> <span class="nf">dataset</span><span class="p">()</span>
<span class="n">train_x</span><span class="p">,</span> <span class="n">test_x</span> <span class="o">=</span> <span class="nf">data_preprocessing</span><span class="p">(</span><span class="n">train_x</span><span class="p">,</span> <span class="n">test_x</span><span class="p">)</span>

<span class="nf">print</span><span class="p">(</span><span class="sh">"</span><span class="s">train_x</span><span class="sh">'</span><span class="s">s shape: </span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">train_x</span><span class="p">.</span><span class="n">shape</span><span class="p">))</span>
<span class="nf">print</span><span class="p">(</span><span class="sh">"</span><span class="s">test_x</span><span class="sh">'</span><span class="s">s shape: </span><span class="sh">"</span> <span class="o">+</span> <span class="nf">str</span><span class="p">(</span><span class="n">test_x</span><span class="p">.</span><span class="n">shape</span><span class="p">))</span> 

<span class="n">size_of_layers</span> <span class="o">=</span> <span class="p">[</span><span class="mi">196</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span>

<span class="n">model</span> <span class="o">=</span> <span class="nc">NeuralNet</span><span class="p">(</span><span class="n">size_of_layers</span><span class="p">)</span>
<span class="n">model</span><span class="p">.</span><span class="nf">fit</span><span class="p">(</span><span class="n">train_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">,</span> <span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">n_iterations</span><span class="o">=</span><span class="mi">1000</span><span class="p">)</span>
<span class="n">model</span><span class="p">.</span><span class="nf">predict</span><span class="p">(</span><span class="n">train_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">)</span>
<span class="n">model</span><span class="p">.</span><span class="nf">predict</span><span class="p">(</span><span class="n">test_x</span><span class="p">,</span> <span class="n">test_y</span><span class="p">)</span>
<span class="n">model</span><span class="p">.</span><span class="nf">plot_cost</span><span class="p">()</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<h2 id="conclusion">Conclusion</h2>

<p>I hope this tutorial gives you a detailed walkthrough of building a neural network from scratch. Understanding core components like forward and backward propagation is crucial, since they form the backbone of any neural network. From here, we can explore various optimizations to improve accuracy and speed up computation. In the next steps, we’ll look at how to implement similar neural networks using popular frameworks like TensorFlow and PyTorch, which offer powerful tools for more advanced applications.</p>

<div class="row mt-3 justify-content-center">
    <div class="col-sm-8 mt-3 mt-md-0">
        <picture>
            <img src="/images/blog/sigmoid.png" width="100%" alt="Sigmoid activation function plot" />
        </picture>
    </div>
</div>

<div class="row mt-3 justify-content-center">
    <div class="col-sm-8 mt-3 mt-md-0">
        <picture>
            <img src="/images/blog/binary_loss.png" width="100%" alt="Binary cross entropy loss plot" />
        </picture>
    </div>
</div>

<div class="row mt-3 justify-content-center">
    <div class="col-sm-8 mt-3 mt-md-0">
        <picture>
            <img src="/images/blog/binary_accuracy.png" width="100%" alt="Training and validation accuracy plot" />
        </picture>
    </div>
</div>]]></content><author><name></name></author><category term="neural-network" /><category term="backpropagation" /><category term="binary-classification" /><category term="forward-pass" /><category term="sigmoid" /><summary type="html"><![CDATA[In this blog, we explored the process of building a neural network from scratch using Python and the MNIST dataset. By focusing on binary classification, we covered the essential components of neural networks, including data preprocessing, parameter initialization, forward pass, backpropagation, and training the network.]]></summary></entry><entry><title type="html">Implementing Stochastic Gradient Descent and variants from scratch.</title><link href="https://princeemensah.github.io/blog/2024/08/stochastic-gradient-descent.html" rel="alternate" type="text/html" title="Implementing Stochastic Gradient Descent and variants from scratch." /><published>2024-08-09T10:33:13+00:00</published><updated>2024-08-09T10:33:13+00:00</updated><id>https://princeemensah.github.io/blog/2024/08/stochastic-gradient-descent</id><content type="html" xml:base="https://princeemensah.github.io/blog/2024/08/stochastic-gradient-descent.html"><![CDATA[<p>Welcome! In this post, we’ll implement two important optimization techniques from scratch: Gradient Descent (GD) and Stochastic Gradient Descent (SGD), both essential methods for training machine learning models. Whether you’re new to these concepts or looking to refine your understanding, this post aims to make them clear and practical.</p>

<p>We’ll walk through various SGD variants like constant and shrinking step sizes, momentum, and averaging, comparing how each one impacts the speed and accuracy of the model’s convergence. Along the way, we’ll discuss when to use each technique, what makes them effective, and how to balance computational cost with performance.</p>

<p>Let’s dive in together and discover the best method for training your machine learning model!</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-code"><pre><span class="c1"># The following libraries will be essential for our implementation.
</span><span class="kn">import</span> <span class="n">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">from</span> <span class="n">numpy</span> <span class="kn">import</span> <span class="n">linalg</span> <span class="k">as</span> <span class="n">la</span>
<span class="kn">from</span> <span class="n">scipy.linalg</span> <span class="kn">import</span> <span class="n">norm</span>
<span class="kn">import</span> <span class="n">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">from</span> <span class="n">numba</span> <span class="kn">import</span> <span class="n">njit</span><span class="p">,</span> <span class="n">jit</span>  <span class="c1"># A just in time compiler to speed things up!
</span><span class="o">%</span><span class="n">matplotlib</span> <span class="n">inline</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<h2 id="linear-regression-with-ridge-penalization">Linear Regression with Ridge Penalization</h2>

<p>In our linear regression model with Ridge penalization, the goal is to find the weight vector \(w\) that minimizes the following objective function:</p>

<p>\begin{equation}
\label{eq:linear-regression}
f(w) = \frac{1}{2n} \|Xw - y\|^2 + \frac{\lambda}{2} \|w\|^2,
\end{equation}</p>

<p>where \(X\) is our feature matrix, \(y\) is the vector of true values, and \(\lambda\) is the regularization parameter that controls the strength of the penalty on the size of the weights.</p>

<p>To optimize this objective function using gradient-based methods, we need its gradient: the negative gradient points in the direction in which the function decreases most rapidly. The gradient of the objective function \(f(w)\) is:</p>

<p>\begin{equation}
\label{eq:gradient} 
\nabla f(w) = \frac{1}{n} X^T(Xw - y) + \lambda w, 
\end{equation}</p>

<p>where the first term \(\frac{1}{n} X^T(Xw - y)\) represents the gradient of the least-squares loss, while the second term  \(\lambda w\) accounts for the regularization.</p>
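<p>A quick way to convince yourself that the gradient formula above is right is to compare it against a central finite-difference approximation. The sketch below does this on small random data; the helper names (<code class="language-plaintext highlighter-rouge">ridge_objective</code>, <code class="language-plaintext highlighter-rouge">ridge_gradient</code>, <code class="language-plaintext highlighter-rouge">numerical_gradient</code>) and the data sizes are illustrative choices, not part of the original code.</p>

```python
import numpy as np

def ridge_objective(X, y, w, lbda):
    # f(w) = ||Xw - y||^2 / (2n) + (lbda / 2) ||w||^2
    n = X.shape[0]
    residual = X @ w - y
    return residual @ residual / (2 * n) + lbda * (w @ w) / 2

def ridge_gradient(X, y, w, lbda):
    # Analytic gradient: X^T (Xw - y) / n + lbda * w
    n = X.shape[0]
    return X.T @ (X @ w - y) / n + lbda * w

def numerical_gradient(fun, w, eps=1e-6):
    # Central finite differences, one coordinate at a time
    grad = np.zeros_like(w)
    for j in range(w.size):
        e = np.zeros_like(w)
        e[j] = eps
        grad[j] = (fun(w + e) - fun(w - e)) / (2 * eps)
    return grad

# Small random problem (sizes and lbda are arbitrary)
rng = np.random.default_rng(0)
n, d, lbda = 50, 5, 0.1
X = rng.standard_normal((n, d))
y = rng.standard_normal(n)
w = rng.standard_normal(d)

analytic = ridge_gradient(X, y, w, lbda)
numeric = numerical_gradient(lambda v: ridge_objective(X, y, v, lbda), w)
max_abs_error = np.max(np.abs(analytic - numeric))
print(max_abs_error)  # close to zero if the formula is right
```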

<p>For stochastic gradient descent (SGD), we often update the weights using the gradient calculated from a single data point rather than the entire dataset. The gradient for a single data point \(i\) is given by:</p>

<p>\begin{equation}
\label{eq:sgd}
\nabla f_i(w) = (X_i w - y_i) X_i + \lambda w
\end{equation}</p>

<p>We will implement this as well, since it lets us perform cheap, efficient updates in each iteration of SGD.</p>
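<p>Because \(f(w) = \frac{1}{n}\sum_{i=1}^{n} f_i(w)\) with \(f_i(w) = \frac{1}{2}(X_i w - y_i)^2 + \frac{\lambda}{2}\|w\|^2\), averaging the per-sample gradients over all data points gives back the full gradient. That is why a single, uniformly sampled \(\nabla f_i\) is an unbiased stand-in for \(\nabla f\) in SGD. The following sketch checks this numerically on random data; the function names and the step size of the single SGD step are illustrative assumptions.</p>

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, lbda = 40, 4, 0.05
X = rng.standard_normal((n, d))
y = rng.standard_normal(n)
w = rng.standard_normal(d)

def full_gradient(w):
    # grad f(w) = X^T (Xw - y) / n + lbda * w
    return X.T @ (X @ w - y) / n + lbda * w

def sample_gradient(i, w):
    # grad f_i(w) = (X_i w - y_i) X_i + lbda * w
    return (X[i] @ w - y[i]) * X[i] + lbda * w

# Averaging the per-sample gradients over all i recovers the full gradient
averaged = np.mean([sample_gradient(i, w) for i in range(n)], axis=0)
gap = np.max(np.abs(averaged - full_gradient(w)))

# One SGD step with a uniformly sampled index (step size chosen arbitrarily)
step_size = 0.01
i = rng.integers(n)
w_next = w - step_size * sample_gradient(i, w)
print(gap)
```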

<p>To ensure stable and efficient updates in gradient-based methods, it’s important to set an appropriate step size. The Lipschitz constant \(L\) provides an upper bound on the gradient’s rate of change and helps in choosing this step size:</p>

<p>\begin{equation}
\label{eq:step-size}
L = \frac{\|X\|_2^2}{n} + \lambda,
\end{equation}</p>

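<p>where \(\|X\|_2\) denotes the spectral norm of \(X\) (its largest singular value). A standard choice is the step size \(1/L\), which guarantees that no gradient-descent step increases \(f\) and therefore prevents overshooting.</p>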
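<p>For this quadratic objective the Hessian is the constant matrix \(X^\top X / n + \lambda I\), so \(L\) is simply its largest eigenvalue. The sketch below, on random data with illustrative sizes, computes \(L\) both from the spectral norm and from the Hessian, then runs a few gradient-descent steps with step size \(1/L\) to check that \(f\) never increases.</p>

```python
import numpy as np

rng = np.random.default_rng(2)
n, d, lbda = 60, 6, 0.1
X = rng.standard_normal((n, d))
y = rng.standard_normal(n)

def f(w):
    r = X @ w - y
    return r @ r / (2 * n) + lbda * (w @ w) / 2

def grad(w):
    return X.T @ (X @ w - y) / n + lbda * w

# L = ||X||_2^2 / n + lbda, with ||X||_2 the largest singular value of X
L = np.linalg.norm(X, ord=2) ** 2 / n + lbda

# Equivalently, the largest eigenvalue of the Hessian X^T X / n + lbda * I
hessian = X.T @ X / n + lbda * np.eye(d)
L_from_hessian = np.linalg.eigvalsh(hessian).max()

# Gradient descent with step size 1/L, recording f after each step
w = np.zeros(d)
values = [f(w)]
for _ in range(20):
    w = w - grad(w) / L
    values.append(f(w))
print(L, L_from_hessian)
```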

<p>In stochastic gradient descent, each update uses the gradient of a single term \(f_i\), so the step size must be safe for every individual data point. The relevant quantity is the largest of the per-sample smoothness constants:</p>

<p>\begin{equation}
\label{eq:lmax}
L_{\text{max}} = \max_{1 \le i \le n} \|X_i\|^2 + \lambda
\end{equation}</p>

<p>This constant ensures that the step size is appropriately scaled, even for the most ‘difficult’ data points, preventing instability in the updates.</p>
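<p>Each \(f_i\) has Hessian \(X_i^\top X_i + \lambda I\), whose largest eigenvalue is \(\|X_i\|^2 + \lambda\); \(L_{\max}\) is the largest of these per-sample constants. It is never smaller than \(L\), because \(\|X\|_2^2 \le \sum_i \|X_i\|^2 \le n \max_i \|X_i\|^2\). A small sketch on random data (sizes are arbitrary) confirms the inequality.</p>

```python
import numpy as np

rng = np.random.default_rng(3)
n, d, lbda = 80, 5, 0.1
X = rng.standard_normal((n, d))

# Smoothness constant of each f_i: ||X_i||^2 + lbda
L_i = np.sum(X ** 2, axis=1) + lbda
L_max = L_i.max()

# Smoothness constant of the full objective f
L = np.linalg.norm(X, ord=2) ** 2 / n + lbda

# ||X||_2^2 is at most sum_i ||X_i||^2, which is at most n * max_i ||X_i||^2,
# so L never exceeds L_max and a step of 1 / L_max is the more conservative choice.
print(L, L_max)
```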

<p>Lastly, when dealing with strongly convex functions, the strong convexity constant \(\mu\) provides a lower bound on the curvature of the objective function.</p>

<p>\begin{equation}
\label{eq:muconstant}
\mu = \frac{\min(\text{eigenvalues}(X^TX))}{n} + \lambda
\end{equation}</p>

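<p>Together with \(L\), the strong convexity constant determines the condition number \(\kappa = L/\mu\), which controls how fast gradient descent converges: the closer \(\kappa\) is to 1, the faster the convergence. The constant \(\mu\) also appears in the decreasing step-size schedules commonly used for SGD on strongly convex problems.</p>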
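<p>For ridge regression, \(\mu\) and \(L\) are the smallest and largest eigenvalues of the constant Hessian \(X^\top X / n + \lambda I\). With step size \(1/L\), each gradient-descent iteration shrinks the distance to the minimizer by a factor of at least \(1 - \mu/L\). The sketch below checks this on random data against the closed-form ridge solution; data sizes and the iteration count are arbitrary choices.</p>

```python
import numpy as np

rng = np.random.default_rng(4)
n, d, lbda = 100, 5, 0.1
X = rng.standard_normal((n, d))
y = rng.standard_normal(n)

hessian = X.T @ X / n + lbda * np.eye(d)
eigs = np.linalg.eigvalsh(hessian)   # real eigenvalues in ascending order
mu, L = eigs[0], eigs[-1]
kappa = L / mu                       # condition number

# Closed-form ridge solution: (X^T X / n + lbda I) w* = X^T y / n
w_star = np.linalg.solve(hessian, X.T @ y / n)

# Gradient descent with step 1/L, tracking the distance to w*
w = np.zeros(d)
dist = [np.linalg.norm(w - w_star)]
for _ in range(30):
    w = w - (X.T @ (X @ w - y) / n + lbda * w) / L
    dist.append(np.linalg.norm(w - w_star))
print(mu, L, kappa)
```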

<p>Let’s now put all the pieces we’ve discussed above into a <code class="language-plaintext highlighter-rouge">LinearRegression</code> class, which we will rely on for our optimization experiments.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-code"><pre><span class="kn">from</span> <span class="n">scipy.linalg</span> <span class="kn">import</span> <span class="n">svd</span>

<span class="k">class</span> <span class="nc">LinearRegression</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">lbda</span><span class="p">):</span>
        <span class="n">self</span><span class="p">.</span><span class="n">X</span> <span class="o">=</span> <span class="n">X</span>
        <span class="n">self</span><span class="p">.</span><span class="n">y</span> <span class="o">=</span> <span class="n">y</span>
        <span class="n">self</span><span class="p">.</span><span class="n">n</span><span class="p">,</span> <span class="n">self</span><span class="p">.</span><span class="n">d</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span>
        <span class="n">self</span><span class="p">.</span><span class="n">lbda</span> <span class="o">=</span> <span class="n">lbda</span>  
    <span class="k">def</span> <span class="nf">grad</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">w</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">self</span><span class="p">.</span><span class="n">X</span><span class="p">.</span><span class="n">T</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">X</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">w</span><span class="p">)</span> <span class="o">-</span> <span class="n">self</span><span class="p">.</span><span class="n">y</span><span class="p">)</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="n">lbda</span> <span class="o">*</span> <span class="n">w</span>
    
    <span class="k">def</span> <span class="nf">f_i</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">w</span><span class="p">):</span>
        <span class="k">return</span> <span class="nf">norm</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">X</span><span class="p">[</span><span class="n">i</span><span class="p">].</span><span class="nf">dot</span><span class="p">(</span><span class="n">w</span><span class="p">)</span> <span class="o">-</span> <span class="n">self</span><span class="p">.</span><span class="n">y</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">/</span> <span class="p">(</span><span class="mf">2.</span><span class="p">)</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="n">lbda</span> <span class="o">*</span> <span class="nf">norm</span><span class="p">(</span><span class="n">w</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">/</span> <span class="mf">2.</span>  
    
    <span class="k">def</span> <span class="nf">f</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">w</span><span class="p">):</span>
        <span class="k">return</span> <span class="nf">norm</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">X</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">w</span><span class="p">)</span> <span class="o">-</span> <span class="n">self</span><span class="p">.</span><span class="n">y</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">/</span> <span class="p">(</span><span class="mf">2.</span> <span class="o">*</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span><span class="p">)</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="n">lbda</span> <span class="o">*</span> <span class="nf">norm</span><span class="p">(</span><span class="n">w</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">/</span> <span class="mf">2.</span>

    <span class="k">def</span> <span class="nf">grad_i</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">w</span><span class="p">):</span>
        <span class="n">x_i</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">X</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
        <span class="nf">return </span><span class="p">(</span><span class="n">x_i</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">w</span><span class="p">)</span> <span class="o">-</span> <span class="n">self</span><span class="p">.</span><span class="n">y</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="o">*</span> <span class="n">x_i</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="n">lbda</span> <span class="o">*</span> <span class="n">w</span>

    <span class="k">def</span> <span class="nf">lipschitz_constant</span><span class="p">(</span><span class="n">self</span><span class="p">):</span>
        <span class="n">L</span> <span class="o">=</span> <span class="nf">norm</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">X</span><span class="p">,</span> <span class="nb">ord</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="n">lbda</span>
        <span class="k">return</span> <span class="n">L</span>
    
    <span class="k">def</span> <span class="nf">L_max_constant</span><span class="p">(</span><span class="n">self</span><span class="p">):</span>
        <span class="n">L_max</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">max</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nf">sum</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">X</span> <span class="o">**</span> <span class="mi">2</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="n">lbda</span>
        <span class="k">return</span> <span class="n">L_max</span> 
    
    <span class="k">def</span> <span class="nf">mu_constant</span><span class="p">(</span><span class="n">self</span><span class="p">):</span>
        <span class="n">mu</span> <span class="o">=</span>  <span class="nf">min</span><span class="p">(</span><span class="nf">abs</span><span class="p">(</span><span class="n">la</span><span class="p">.</span><span class="nf">eigvals</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">X</span><span class="p">.</span><span class="n">T</span><span class="p">,</span><span class="n">self</span><span class="p">.</span><span class="n">X</span><span class="p">))))</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="n">lbda</span>
        <span class="k">return</span> <span class="n">mu</span>     
</pre></td></tr></tbody></table></code></pre></div></div>

<p>Whether you’re using full-batch gradient descent or stochastic methods, this class forms the backbone of our optimization experiments, enabling us to test and compare different techniques effectively.</p>

<h2 id="logistic-regression-with-ridge-penalization">Logistic Regression with Ridge Penalization</h2>

<p>Similarly, in logistic regression, our goal is to find the weight vector \(w\) that minimizes the following objective function, which includes both the logistic loss and an L2 regularization term:</p>

<p>\begin{equation}
\label{eq:logistic-regression}
f(w) = \frac{1}{n} \sum_{i=1}^{n} \log\left(1 + \exp(-y_i \cdot X_i w)\right) + \frac{\lambda}{2} ||w||^2,
\end{equation}</p>

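<p>where \(X\) is the feature matrix, \(y \in \{-1, +1\}^n\) is the vector of binary labels, and \(\lambda\) is the regularization parameter that controls the penalty on the magnitude of the weights. Note that this form of the loss assumes labels in \(\{-1, +1\}\) rather than \(\{0, 1\}\).</p>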

<p>To minimize this objective function using gradient-based methods, we again need its gradient, whose negative points in the direction of steepest descent. The gradient of \(f(w)\) is:</p>

<p>\begin{equation}
\label{eq:log_grad}
\nabla f(w) = -\frac{1}{n} X^T \left(\frac{y}{1 + \exp(y \cdot Xw)}\right) + \lambda w .
\end{equation}</p>

<p>The first term represents the gradient of the logistic loss, and the second term \(\lambda w\) is the gradient of the L2 regularization.</p>

<p>For stochastic gradient descent, where we update the weights based on one data point at a time, we use the gradient calculated from that individual data point. The gradient for a single data point \(i\) is:</p>

<p>\begin{equation}
\label{eq:log_sgdgrad}
\nabla f_i(w) = -\frac{y_i \cdot X_i}{1 + \exp(y_i \cdot X_i w)} + \lambda w .
\end{equation}</p>

<p>This allows us to perform efficient updates during each iteration of SGD.</p>
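<p>One practical caveat: evaluating \(\log(1 + \exp(-y_i X_i w))\) literally can overflow when a margin \(y_i X_i w\) is large and negative. The sketch below is a swapped-in, numerically stable variant, not the class used in this post. It writes the loss with NumPy’s <code class="language-plaintext highlighter-rouge">np.logaddexp(0, -m)</code> (which equals \(\log(1 + e^{-m})\)) and the factor \(1/(1 + \exp(m))\) with SciPy’s <code class="language-plaintext highlighter-rouge">scipy.special.expit(-m)</code>, then checks that the per-sample gradients average to the full gradient. The helper names and random data are illustrative.</p>

```python
import numpy as np
from scipy.special import expit  # stable sigmoid: expit(x) = 1 / (1 + exp(-x))

def logistic_f(X, y, w, lbda):
    # mean of log(1 + exp(-m_i)) computed as logaddexp(0, -m_i) to avoid overflow
    margins = y * (X @ w)
    return np.mean(np.logaddexp(0.0, -margins)) + lbda * (w @ w) / 2

def logistic_grad(X, y, w, lbda):
    # 1 / (1 + exp(m)) is exactly expit(-m)
    margins = y * (X @ w)
    return -X.T @ (y * expit(-margins)) / X.shape[0] + lbda * w

def logistic_grad_i(X, y, i, w, lbda):
    margin = y[i] * (X[i] @ w)
    return -y[i] * expit(-margin) * X[i] + lbda * w

rng = np.random.default_rng(5)
n, d, lbda = 30, 4, 0.1
X = rng.standard_normal((n, d))
y = rng.choice([-1.0, 1.0], size=n)  # labels must be in {-1, +1}
w = rng.standard_normal(d)

# The average of the per-sample gradients equals the full gradient
averaged = np.mean([logistic_grad_i(X, y, i, w, lbda) for i in range(n)], axis=0)
gap = np.max(np.abs(averaged - logistic_grad(X, y, w, lbda)))

# With very large weights the naive formula would overflow; this one stays finite
big_loss = logistic_f(X, y, 1e3 * w, lbda)
print(gap, np.isfinite(big_loss))
```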

<p>To ensure that our gradient-based methods converge efficiently, we need to carefully choose the step size. The Lipschitz constant \(L\) gives us an upper bound on how much the gradient can change, helping us set a stable step size:</p>

<p>\begin{equation}
\label{eq:log_L}
L = \frac{||X||_2^2}{4n} + \lambda .
\end{equation}</p>

<p>As in the linear case, choosing the step size \(1/L\) keeps gradient descent from overshooting.</p>
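<p>The factor \(4\) comes from the curvature of the logistic function. The second derivative of \(\log(1 + e^{-m})\) with respect to \(m\) is \(\sigma(m)(1 - \sigma(m))\), where \(\sigma\) is the sigmoid, and this never exceeds \(1/4\). The Hessian of \(f\) is therefore bounded by \(X^\top X/(4n) + \lambda I\), whose largest eigenvalue is exactly \(L\). The sketch below checks both facts numerically; the random data and the sampled weight scales are arbitrary.</p>

```python
import numpy as np
from scipy.special import expit

rng = np.random.default_rng(6)
n, d, lbda = 50, 4, 0.1
X = rng.standard_normal((n, d))
y = rng.choice([-1.0, 1.0], size=n)

# sigma'(z) = sigma(z) (1 - sigma(z)) peaks at z = 0 with value 1/4
z = np.linspace(-10, 10, 2001)
sig_prime = expit(z) * (1 - expit(z))

L = np.linalg.norm(X, ord=2) ** 2 / (4 * n) + lbda

def hessian(w):
    # X^T diag(s) X / n + lbda I, with s_i = sigma(m_i)(1 - sigma(m_i)) and m_i = y_i X_i w
    s = expit(y * (X @ w))
    s = s * (1 - s)
    return X.T @ (s[:, None] * X) / n + lbda * np.eye(d)

# The largest Hessian eigenvalue stays at or below L for every w we try
largest = [np.linalg.eigvalsh(hessian(rng.standard_normal(d) * scale)).max()
           for scale in (0.0, 0.1, 1.0, 10.0)]
print(sig_prime.max(), L, largest)
```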

<p>For stochastic gradient descent, each update again uses a single term \(f_i\), so the relevant quantity is the largest of the per-sample smoothness constants:</p>

<p>\begin{equation}
\label{eq:log_Lmax}
L_{\text{max}} = \frac{\max_{1 \le i \le n} \|X_i\|^2}{4} + \lambda
\end{equation}</p>

<p>This constant ensures that our step sizes are appropriately scaled, even for the most challenging data points.</p>

<p>In strongly convex problems, the strong convexity constant \(\mu\) governs how fast gradient methods converge. The logistic loss on its own is convex but not strongly convex, because its curvature can become arbitrarily small, so for our regularized problem the strong convexity comes from the ridge term alone:</p>

<p>\begin{equation}
\label{eq:log_mu}
\mu = \lambda
\end{equation}</p>

<p>This constant reflects the curvature of our loss function, helping us fine-tune our optimization algorithms for faster convergence.</p>
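<p>To see why \(\lambda\) is the right constant, note that the Hessian of \(f\) is \(X^\top \operatorname{diag}(s) X / n + \lambda I\) with \(s_i = \sigma(m_i)(1 - \sigma(m_i))\) and \(m_i = y_i X_i w\). The data term is positive semidefinite, so every eigenvalue is at least \(\lambda\), but its contribution fades as the margins grow, since \(s_i \to 0\) when \(|m_i| \to \infty\). The sketch below evaluates the smallest Hessian eigenvalue along one random direction at increasing scales; the data and scales are arbitrary choices for illustration.</p>

```python
import numpy as np
from scipy.special import expit

rng = np.random.default_rng(7)
n, d, lbda = 50, 4, 0.1
X = rng.standard_normal((n, d))
y = rng.choice([-1.0, 1.0], size=n)

def hessian(w):
    # X^T diag(s) X / n + lbda I, with s_i = sigma(m_i)(1 - sigma(m_i)) and m_i = y_i X_i w
    s = expit(y * (X @ w))
    s = s * (1 - s)
    return X.T @ (s[:, None] * X) / n + lbda * np.eye(d)

direction = rng.standard_normal(d)
# Smallest Hessian eigenvalue minus lbda: the extra curvature from the data term
curvature_gap = [np.linalg.eigvalsh(hessian(scale * direction)).min() - lbda
                 for scale in (0.0, 1.0, 10.0, 1000.0)]
print(curvature_gap)
```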

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-code"><pre><span class="k">class</span> <span class="nc">LogisticRegression</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">lbda</span><span class="p">):</span>
        <span class="n">self</span><span class="p">.</span><span class="n">X</span> <span class="o">=</span> <span class="n">X</span>
        <span class="n">self</span><span class="p">.</span><span class="n">y</span> <span class="o">=</span> <span class="n">y</span>
        <span class="n">self</span><span class="p">.</span><span class="n">n</span><span class="p">,</span> <span class="n">self</span><span class="p">.</span><span class="n">d</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span>
        <span class="n">self</span><span class="p">.</span><span class="n">lbda</span> <span class="o">=</span> <span class="n">lbda</span>
 
    <span class="k">def</span> <span class="nf">grad</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">w</span><span class="p">):</span>
        <span class="n">bAx</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">y</span> <span class="o">*</span> <span class="n">self</span><span class="p">.</span><span class="n">X</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">w</span><span class="p">)</span>
        <span class="n">temp</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">/</span> <span class="p">(</span><span class="mf">1.</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="nf">exp</span><span class="p">(</span><span class="n">bAx</span><span class="p">))</span>
        <span class="n">grad</span> <span class="o">=</span> <span class="o">-</span> <span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">X</span><span class="p">.</span><span class="n">T</span><span class="p">).</span><span class="nf">dot</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">y</span> <span class="o">*</span> <span class="n">temp</span><span class="p">)</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="n">lbda</span> <span class="o">*</span> <span class="n">w</span>
        <span class="k">return</span> <span class="n">grad</span>
    
    <span class="k">def</span> <span class="nf">f_i</span><span class="p">(</span><span class="n">self</span><span class="p">,</span><span class="n">i</span><span class="p">,</span> <span class="n">w</span><span class="p">):</span>
        <span class="n">bAx_i</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">y</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">X</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">w</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="nf">log</span><span class="p">(</span><span class="mf">1.</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="nf">exp</span><span class="p">(</span><span class="o">-</span> <span class="n">bAx_i</span><span class="p">))</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="n">lbda</span> <span class="o">*</span> <span class="nf">norm</span><span class="p">(</span><span class="n">w</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">/</span> <span class="mf">2.</span>
    
    <span class="k">def</span> <span class="nf">f</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">w</span><span class="p">):</span>
        <span class="n">bAx</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">y</span> <span class="o">*</span> <span class="n">self</span><span class="p">.</span><span class="n">X</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">w</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="nf">mean</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nf">log</span><span class="p">(</span><span class="mf">1.</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="nf">exp</span><span class="p">(</span><span class="o">-</span> <span class="n">bAx</span><span class="p">)))</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="n">lbda</span> <span class="o">*</span> <span class="nf">norm</span><span class="p">(</span><span class="n">w</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">/</span> <span class="mf">2.</span>

    <span class="k">def</span> <span class="nf">grad_i</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">w</span><span class="p">):</span>
        <span class="n">grad</span> <span class="o">=</span> <span class="o">-</span> <span class="n">self</span><span class="p">.</span><span class="n">X</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="n">self</span><span class="p">.</span><span class="n">y</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">/</span> <span class="p">(</span><span class="mf">1.</span> <span class="o">+</span> <span class="n">np</span><span class="p">.</span><span class="nf">exp</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">y</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> 
                                                      <span class="o">*</span> <span class="n">self</span><span class="p">.</span><span class="n">X</span><span class="p">[</span><span class="n">i</span><span class="p">].</span><span class="nf">dot</span><span class="p">(</span><span class="n">w</span><span class="p">)))</span>
        <span class="n">grad</span> <span class="o">+=</span> <span class="n">self</span><span class="p">.</span><span class="n">lbda</span> <span class="o">*</span> <span class="n">w</span>
        <span class="k">return</span> <span class="n">grad</span>

    <span class="k">def</span> <span class="nf">lipschitz_constant</span><span class="p">(</span><span class="n">self</span><span class="p">):</span>
        <span class="n">L</span> <span class="o">=</span> <span class="nf">norm</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">X</span><span class="p">,</span> <span class="nb">ord</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>  <span class="o">/</span> <span class="p">(</span><span class="mf">4.</span> <span class="o">*</span> <span class="n">self</span><span class="p">.</span><span class="n">n</span><span class="p">)</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="n">lbda</span>
        <span class="k">return</span> <span class="n">L</span>
    <span class="k">def</span> <span class="nf">L_max_constant</span><span class="p">(</span><span class="n">self</span><span class="p">):</span>
        <span class="n">L_max</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">max</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nf">sum</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">X</span> <span class="o">**</span> <span class="mi">2</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span><span class="o">/</span><span class="mi">4</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="n">lbda</span>
        <span class="k">return</span> <span class="n">L_max</span> 
    
    <span class="k">def</span> <span class="nf">mu_constant</span><span class="p">(</span><span class="n">self</span><span class="p">):</span>
        <span class="n">mu</span> <span class="o">=</span>  <span class="n">self</span><span class="p">.</span><span class="n">lbda</span>
        <span class="k">return</span> <span class="n">mu</span>    
</pre></td></tr></tbody></table></code></pre></div></div>
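<p>Whether you use full-batch gradient descent, stochastic gradient descent, momentum or iterate averaging, this class provides what the solvers need: the full objective and gradient (<code class="language-plaintext highlighter-rouge">f</code>, <code class="language-plaintext highlighter-rouge">grad</code>), their per-sample versions (<code class="language-plaintext highlighter-rouge">f_i</code>, <code class="language-plaintext highlighter-rouge">grad_i</code>), and the constants \(L\), \(L_{\max}\) and \(\mu\) used to choose step sizes.</p>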
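<p>To make the role of the three constants concrete, here is a minimal sketch on hypothetical random data (not the post's dataset) that evaluates the same formulas as <code class="language-plaintext highlighter-rouge">lipschitz_constant</code>, <code class="language-plaintext highlighter-rouge">L_max_constant</code> and <code class="language-plaintext highlighter-rouge">mu_constant</code>, and checks that \(L \le L_{\max}\). The step sizes \(1/L\) and \(1/L_{\max}\) are common textbook choices, assumed here for illustration only.</p>

```python
import numpy as np
from numpy.linalg import norm

# Hypothetical standalone data; in the post these come from model.X and model.lbda.
rng = np.random.default_rng(0)
n, d = 200, 10
X = rng.standard_normal((n, d))
lbda = 1.0 / np.sqrt(n)

# Same formulas as lipschitz_constant, L_max_constant and mu_constant.
L = norm(X, ord=2) ** 2 / (4.0 * n) + lbda
L_max = np.max(np.sum(X ** 2, axis=1)) / 4.0 + lbda
mu = lbda

# ||X||_2^2 is at most ||X||_F^2, which is at most n * max_i ||x_i||^2, hence L is at most L_max.
print(f"L = {L:.3f}, L_max = {L_max:.3f}, mu = {mu:.3f}")
print(f"condition number L / mu = {L / mu:.1f}")

# Common step-size choices (an assumption of this sketch, not the post's prescription).
step_gd = 1.0 / L        # full-batch gradient descent
step_sgd = 1.0 / L_max   # constant-step SGD
```

<p>The per-sample constant \(L_{\max}\) is never smaller than \(L\), which is why constant-step SGD typically has to use a smaller step than full-batch gradient descent.</p>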

<h2 id="data-functions">Data Functions</h2>

<p>To test and compare our optimization methods, we first need datasets that simulate least-squares and logistic regression tasks. The code block below defines a function called <code class="language-plaintext highlighter-rouge">simulate_linreg</code>, which generates such a dataset for the linear regression model: correlated Gaussian features \(X\) and noisy targets \(y = Xw + \text{noise}\).</p>

<h3 id="data-simulation-for-linear-regression">Data simulation for linear regression</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
</pre></td><td class="rouge-code"><pre><span class="kn">from</span> <span class="n">numpy.random</span> <span class="kn">import</span> <span class="n">multivariate_normal</span><span class="p">,</span> <span class="n">randn</span>
<span class="kn">from</span> <span class="n">scipy.linalg</span> <span class="kn">import</span> <span class="n">toeplitz</span>

    
<span class="k">def</span> <span class="nf">simulate_linreg</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">1.</span><span class="p">,</span> <span class="n">corr</span><span class="o">=</span><span class="mf">0.5</span><span class="p">):</span>
    <span class="sh">"""</span><span class="s">
    Simulation of the least-squares problem
    
    Parameters
    ----------
    w : np.ndarray, shape=(d,)
        The coefficients of the model
    
    n : int
        Sample size
    
    std : float, default=1.
        Standard-deviation of the noise

    corr : float, default=0.5
        Correlation of the features matrix
    </span><span class="sh">"""</span>    
    <span class="n">d</span> <span class="o">=</span> <span class="n">w</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    <span class="n">cov</span> <span class="o">=</span> <span class="nf">toeplitz</span><span class="p">(</span><span class="n">corr</span> <span class="o">**</span> <span class="n">np</span><span class="p">.</span><span class="nf">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">d</span><span class="p">))</span>
    <span class="n">X</span> <span class="o">=</span> <span class="nf">multivariate_normal</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nf">zeros</span><span class="p">(</span><span class="n">d</span><span class="p">),</span> <span class="n">cov</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">n</span><span class="p">)</span>
    <span class="n">noise</span> <span class="o">=</span> <span class="n">std</span> <span class="o">*</span> <span class="nf">randn</span><span class="p">(</span><span class="n">n</span><span class="p">)</span>
    <span class="n">y</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">w</span><span class="p">)</span> <span class="o">+</span> <span class="n">noise</span>
    <span class="k">return</span> <span class="n">X</span><span class="p">,</span> <span class="n">y</span>
</pre></td></tr></tbody></table></code></pre></div></div>
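<p>To see what the <code class="language-plaintext highlighter-rouge">corr</code> parameter does, the sketch below (a standalone check, with an arbitrary \(d = 5\) and sample size) prints the Toeplitz covariance, whose \((j, k)\) entry is \(\text{corr}^{|j-k|}\), so neighbouring features are the most correlated. It then compares that matrix with the empirical covariance of simulated features.</p>

```python
import numpy as np
from numpy.random import multivariate_normal
from scipy.linalg import toeplitz

d, corr = 5, 0.5
# First column is corr**0, corr**1, ..., corr**(d-1); toeplitz makes cov[j, k] = corr**|j - k|.
cov = toeplitz(corr ** np.arange(0, d))
print(np.round(cov, 3))

# With many samples, the empirical covariance of the simulated features approaches cov.
np.random.seed(0)
X = multivariate_normal(np.zeros(d), cov, size=100_000)
emp_cov = np.cov(X, rowvar=False)
print(np.round(emp_cov, 2))
```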

<h3 id="data-simulation-for-linear-regression-1">Data simulation for logistic regression</h3>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">simulate_logreg</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">1.</span><span class="p">,</span> <span class="n">corr</span><span class="o">=</span><span class="mf">0.5</span><span class="p">):</span>
    <span class="sh">"""</span><span class="s">
    Simulation of the logistic regression problem
    
    Parameters
    ----------
    w : np.ndarray, shape=(d,)
        The coefficients of the model
    
    n : int
        Sample size
    
    std : float, default=1.
        Standard-deviation of the noise

    corr : float, default=0.5
        Correlation of the features matrix
    </span><span class="sh">"""</span>    
    <span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="nf">simulate_linreg</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="n">std</span><span class="p">,</span> <span class="n">corr</span><span class="o">=</span><span class="n">corr</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">X</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="nf">sign</span><span class="p">(</span><span class="n">y</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

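<p>Both functions produce controlled datasets: because we choose the true coefficients, the noise level <code class="language-plaintext highlighter-rouge">std</code> and the feature correlation <code class="language-plaintext highlighter-rouge">corr</code>, we know what the solvers should recover and can vary how hard the problem is.</p>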
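<p>Since <code class="language-plaintext highlighter-rouge">simulate_logreg</code> takes the sign of a noisy linear response, <code class="language-plaintext highlighter-rouge">std</code> controls how many labels disagree with the noiseless rule \(\text{sign}(Xw)\). The sketch below measures that flip rate with a hypothetical helper, <code class="language-plaintext highlighter-rouge">label_flip_rate</code>; to stay short it uses uncorrelated features rather than the Toeplitz design.</p>

```python
import numpy as np

def label_flip_rate(w, n, std, seed=0):
    """Fraction of labels sign(Xw + noise) that disagree with the noiseless sign(Xw).

    Hypothetical helper: uses independent standard-normal features for simplicity,
    instead of the correlated design used by simulate_logreg.
    """
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n, w.shape[0]))
    clean = np.sign(X @ w)
    noisy = np.sign(X @ w + std * rng.standard_normal(n))
    return np.mean(clean != noisy)

w = (-1) ** np.arange(10) * np.exp(-np.arange(10) / 10.0)
for std in [0.0, 0.5, 1.0, 3.0]:
    print(f"std={std}: flip rate {label_flip_rate(w, 10_000, std):.3f}")
```

<p>Because the same seed is reused, each call sees the same features and the same noise direction, so the flip rate can only grow as <code class="language-plaintext highlighter-rouge">std</code> increases.</p>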

<h3 id="generating-the-dataset">Generating the Dataset</h3>

<p>In this step, we create the dataset that will be used to test our linear and logistic regression models.</p>

<p><strong>Define Dimensions</strong></p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
</pre></td><td class="rouge-code"><pre><span class="n">d</span> <span class="o">=</span> <span class="mi">50</span>
<span class="n">n</span> <span class="o">=</span> <span class="mi">1000</span>
</pre></td></tr></tbody></table></code></pre></div></div>
<p>We set the number of features \(d = 50\) and the number of data points \(n = 1000\), so the feature matrix \(X\) will have shape \(1000 \times 50\): 1000 samples with 50 features each.</p>

<p><strong>Setting Up Ground Truth Coefficients</strong></p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
</pre></td><td class="rouge-code"><pre><span class="n">idx</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">arange</span><span class="p">(</span><span class="n">d</span><span class="p">)</span>
<span class="n">w_model_truth</span> <span class="o">=</span> <span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">**</span><span class="n">idx</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="nf">exp</span><span class="p">(</span><span class="o">-</span><span class="n">idx</span> <span class="o">/</span> <span class="mf">10.</span><span class="p">)</span>

<span class="n">plt</span><span class="p">.</span><span class="nf">stem</span><span class="p">(</span><span class="n">w_model_truth</span><span class="p">);</span> 
</pre></td></tr></tbody></table></code></pre></div></div>
<p>We create the true coefficients <code class="language-plaintext highlighter-rouge">w_model_truth</code> that the model will try to recover. Coefficient \(j\) is \((-1)^j e^{-j/10}\) for \(j = 0, \dots, d-1\): the signs alternate from one feature to the next and the magnitudes decay exponentially.</p>
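<p>A quick numerical check of the decay: every block of ten features shrinks the magnitude by a factor of \(e^{-1} \approx 0.368\). The snippet recomputes the coefficients exactly as above and inspects the first few values.</p>

```python
import numpy as np

d = 50
idx = np.arange(d)
# w_j = (-1)**j * exp(-j / 10): alternating signs, magnitude shrinking by exp(-1) every 10 features.
w_model_truth = (-1) ** idx * np.exp(-idx / 10.0)

print(w_model_truth[:4].round(3))                 # 1, -0.905, 0.819, -0.741
print(abs(w_model_truth[10] / w_model_truth[0]))  # exp(-1), about 0.368
```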

<p><strong>Generate the Dataset</strong></p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
</pre></td><td class="rouge-code"><pre><span class="c1">#X, y = simulate_linreg(w_model_truth, n, std=1., corr=0.1)
</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="nf">simulate_logreg</span><span class="p">(</span><span class="n">w_model_truth</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">1.</span><span class="p">,</span> <span class="n">corr</span><span class="o">=</span><span class="mf">0.7</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>
<p>Using the <code class="language-plaintext highlighter-rouge">simulate_logreg</code> function, we generate the feature matrix \(X\) and the labels \(y \in \{-1, +1\}\). We pass a noise level of <code class="language-plaintext highlighter-rouge">std=1.0</code> and a feature correlation of <code class="language-plaintext highlighter-rouge">corr=0.7</code>.</p>

<p>This dataset simulates a realistic logistic regression problem, providing the data we need to test and refine our optimization algorithms. <strong><em>Please note that we will not use the linear regression dataset in this task, which is why the <code class="language-plaintext highlighter-rouge">simulate_linreg</code> call is commented out.</em></strong></p>

<h3 id="selecting-the-model">Selecting the Model</h3>

<p>In this step, we choose the model that will be used for our optimization experiments. Here’s what the code does:</p>

<p><strong>Set the Regularization Parameter</strong></p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
</pre></td><td class="rouge-code"><pre><span class="n">lbda</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">/</span> <span class="n">n</span> <span class="o">**</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>
<p>We define the regularization parameter \(\lambda\) as \(1 / \sqrt{n}\), where \(n\) is the number of data points, so the penalty becomes weaker as the sample size grows. The L2 penalty discourages large weights, which helps prevent overfitting.</p>
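<p>For \(n = 1000\) this gives \(\lambda = 1/\sqrt{1000} \approx 0.0316\). The objective computed by the <code class="language-plaintext highlighter-rouge">f</code> method of the logistic regression class is written out below. The Hessian of the logistic term is positive semidefinite and bounded above by \(\|X\|_2^2/(4n)\), so adding the penalty makes the problem \(\lambda\)-strongly convex; these two bounds are exactly what <code class="language-plaintext highlighter-rouge">mu_constant</code> and <code class="language-plaintext highlighter-rouge">lipschitz_constant</code> return:
\begin{equation}
f(w) = \frac{1}{n} \sum_{i=1}^{n} \log\left(1 + e^{-y_i x_i^\top w}\right) + \frac{\lambda}{2} \|w\|^2,
\qquad
\lambda I \preceq \nabla^2 f(w) \preceq \left(\frac{\|X\|_2^2}{4n} + \lambda\right) I .
\end{equation}</p>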

<p><strong>Choose the Model</strong></p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
</pre></td><td class="rouge-code"><pre><span class="c1">#model = LinearRegression(X, y, lbda)
</span><span class="n">model</span> <span class="o">=</span> <span class="nc">LogisticRegression</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">lbda</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>
<p>As before, I chose the logistic regression model with L2 regularization for this task. You can instead use the linear regression model with Ridge penalization by uncommenting the first line (and generating the data with <code class="language-plaintext highlighter-rouge">simulate_linreg</code>).</p>

<p>This choice determines whether you’ll be performing regression (with <code class="language-plaintext highlighter-rouge">LinearRegression</code>) or classification (with <code class="language-plaintext highlighter-rouge">LogisticRegression</code>). Depending on the dataset you’ve generated (<code class="language-plaintext highlighter-rouge">X</code>, <code class="language-plaintext highlighter-rouge">y</code>), you’ll select the appropriate model for the task.</p>

<h3 id="gradient-verification">Gradient Verification</h3>
<p>We want to ensure that the analytical gradient \(\nabla f_i(w)\) computed by the model's <code class="language-plaintext highlighter-rouge">grad_i</code> method matches a numerical approximation obtained by finite differences of the objective \(f_i(w)\).</p>

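<p>Because \(w\) is a vector, we check the gradient along a random direction \(v\) (the variable <code class="language-plaintext highlighter-rouge">vec</code> in the code). For a small step \(\epsilon\), the forward finite difference approximates the directional derivative:
\begin{equation}
\label{eq:num-grad}
\frac{f_i(w + \epsilon v) - f_i(w)}{\epsilon} \approx \nabla f_i(w)^\top v
\end{equation}</p>

<p>We then compute the analytical directional derivative \(\nabla f_i(w)^\top v\) from <code class="language-plaintext highlighter-rouge">grad_i</code> and check the difference:
\begin{equation}
\label{eq:ana-grad}
\text{grad error} = \frac{f_i(w + \epsilon v) - f_i(w)}{\epsilon} - \nabla f_i(w)^\top v
\end{equation}</p>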

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
</pre></td><td class="rouge-code"><pre><span class="n">grad_error</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
    <span class="n">ind</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">choice</span><span class="p">(</span><span class="n">n</span><span class="p">,</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">w</span> <span class="o">=</span>  <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">randn</span><span class="p">(</span><span class="n">d</span><span class="p">)</span>
    <span class="n">vec</span> <span class="o">=</span>  <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">randn</span><span class="p">(</span><span class="n">d</span><span class="p">)</span>
    <span class="n">eps</span> <span class="o">=</span> <span class="nf">pow</span><span class="p">(</span><span class="mf">10.0</span><span class="p">,</span> <span class="o">-</span><span class="mf">7.0</span><span class="p">)</span>
    <span class="c1"># Forward difference along vec minus the analytical directional derivative</span>
    <span class="n">grad_error</span><span class="p">.</span><span class="nf">append</span><span class="p">((</span><span class="n">model</span><span class="p">.</span><span class="nf">f_i</span><span class="p">(</span> <span class="n">ind</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">w</span><span class="o">+</span><span class="n">eps</span><span class="o">*</span><span class="n">vec</span><span class="p">)</span> <span class="o">-</span> <span class="n">model</span><span class="p">.</span><span class="nf">f_i</span><span class="p">(</span> <span class="n">ind</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">w</span><span class="p">))</span><span class="o">/</span><span class="n">eps</span> <span class="o">-</span> <span class="n">np</span><span class="p">.</span><span class="nf">dot</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="nf">grad_i</span><span class="p">(</span><span class="n">ind</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span><span class="n">w</span><span class="p">),</span><span class="n">vec</span><span class="p">))</span>
<span class="nf">print</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nf">mean</span><span class="p">(</span><span class="n">grad_error</span><span class="p">))</span>
</pre></td></tr></tbody></table></code></pre></div></div>
<p>Output:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
</pre></td><td class="rouge-code"><pre><span class="mf">2.7469189607901637e-06</span>
</pre></td></tr></tbody></table></code></pre></div></div>
<p>The mean error of <code class="language-plaintext highlighter-rouge">2.7469189607901637e-06</code> is small, so the gradients computed by the model agree closely with the finite-difference estimates. This check matters because every optimization algorithm below relies on accurate gradients to update the model weights.</p>
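<p>The check above averages signed errors, so errors of opposite sign can partially cancel. A stricter variant, sketched below on hypothetical random data with standalone copies of <code class="language-plaintext highlighter-rouge">f_i</code> and <code class="language-plaintext highlighter-rouge">grad_i</code>, takes the absolute error for each sample and uses a central difference, whose truncation error shrinks like \(\epsilon^2\) rather than \(\epsilon\). The <code class="language-plaintext highlighter-rouge">np.logaddexp</code> call is a numerically stable substitute for \(\log(1 + e^{z})\) chosen for this sketch; it is not part of the class above.</p>

```python
import numpy as np

# Hypothetical standalone versions of f_i and grad_i for L2-regularised logistic
# regression, mirroring the methods of the class above.
rng = np.random.default_rng(0)
n, d, lbda = 100, 20, 0.1
X = rng.standard_normal((n, d))
y = np.sign(rng.standard_normal(n))

def f_i(i, w):
    return np.logaddexp(0.0, -y[i] * X[i].dot(w)) + lbda * w.dot(w) / 2.0

def grad_i(i, w):
    return -X[i] * y[i] / (1.0 + np.exp(y[i] * X[i].dot(w))) + lbda * w

def directional_error(i, w, v, eps=1e-6):
    # Central difference: truncation error O(eps**2) instead of O(eps) for a forward difference.
    numerical = (f_i(i, w + eps * v) - f_i(i, w - eps * v)) / (2.0 * eps)
    return abs(numerical - grad_i(i, w).dot(v))

errors = [directional_error(rng.integers(n), rng.standard_normal(d), rng.standard_normal(d))
          for _ in range(1000)]
print(f"max |error| = {max(errors):.2e}")
```

<p>Reporting the maximum absolute error rather than the signed mean means a single bad sample cannot hide behind the others.</p>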

<p>Alternatively, we can use the <code class="language-plaintext highlighter-rouge">check_grad</code> function from the <code class="language-plaintext highlighter-rouge">scipy.optimize</code> module to verify the gradient calculations in our <code class="language-plaintext highlighter-rouge">LinearRegression</code> and <code class="language-plaintext highlighter-rouge">LogisticRegression</code> models. It returns the 2-norm of the difference between the analytical gradient and a finite-difference approximation of it, so values close to zero indicate agreement.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
</pre></td><td class="rouge-code"><pre><span class="kn">from</span> <span class="n">scipy.optimize</span> <span class="kn">import</span> <span class="n">check_grad</span>
<span class="n">modellin</span> <span class="o">=</span> <span class="nc">LinearRegression</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">lbda</span><span class="p">)</span>
<span class="nf">check_grad</span><span class="p">(</span><span class="n">modellin</span><span class="p">.</span><span class="n">f</span><span class="p">,</span> <span class="n">modellin</span><span class="p">.</span><span class="n">grad</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">randn</span><span class="p">(</span><span class="n">d</span><span class="p">))</span>
</pre></td></tr></tbody></table></code></pre></div></div>
<p>Output:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
</pre></td><td class="rouge-code"><pre><span class="mf">1.2288105629057588e-06</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
</pre></td><td class="rouge-code"><pre><span class="n">modellog</span> <span class="o">=</span> <span class="nc">LogisticRegression</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">lbda</span><span class="p">)</span>
<span class="nf">check_grad</span><span class="p">(</span><span class="n">modellog</span><span class="p">.</span><span class="n">f</span><span class="p">,</span> <span class="n">modellog</span><span class="p">.</span><span class="n">grad</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">randn</span><span class="p">(</span><span class="n">d</span><span class="p">))</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>Output:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
</pre></td><td class="rouge-code"><pre><span class="mf">1.8667365426265916e-07</span>
</pre></td></tr></tbody></table></code></pre></div></div>
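<p>To see what a failing check looks like, the sketch below runs <code class="language-plaintext highlighter-rouge">check_grad</code> on a hypothetical ridge-regression objective with random data, once with a correct gradient and once with a gradient that deliberately drops the penalty term.</p>

```python
import numpy as np
from scipy.optimize import check_grad

# Hypothetical ridge-regression objective on random data (not the post's dataset).
rng = np.random.default_rng(0)
n, d, lbda = 200, 10, 0.1
X = rng.standard_normal((n, d))
y = rng.standard_normal(n)

def f(w):
    r = X.dot(w) - y
    return r.dot(r) / (2.0 * n) + lbda * w.dot(w) / 2.0

def grad(w):
    return X.T.dot(X.dot(w) - y) / n + lbda * w

def grad_buggy(w):
    # Deliberate bug: the gradient of the penalty term is missing.
    return X.T.dot(X.dot(w) - y) / n

w0 = rng.standard_normal(d)
print(check_grad(f, grad, w0))        # small: gradient agrees with finite differences
print(check_grad(f, grad_buggy, w0))  # close to norm(lbda * w0), the missing term
```

<p>The buggy version's error is roughly the size of the missing term \(\lambda w_0\), several orders of magnitude above the correct version's, which is the kind of gap that makes a gradient bug easy to spot.</p>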

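<p>Next, we use the L-BFGS (Limited-memory Broyden–Fletcher–Goldfarb–Shanno) algorithm, through SciPy's <code class="language-plaintext highlighter-rouge">fmin_l_bfgs_b</code>, to compute a highly accurate minimizer <code class="language-plaintext highlighter-rouge">w_min</code>. It serves as a benchmark: the SGD iterates we implement next are evaluated by their relative distance to it.</p>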

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
</pre></td><td class="rouge-code"><pre><span class="kn">from</span> <span class="n">scipy.optimize</span> <span class="kn">import</span> <span class="n">fmin_l_bfgs_b</span>
<span class="n">w_init</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">zeros</span><span class="p">(</span><span class="n">d</span><span class="p">)</span>
<span class="n">w_min</span><span class="p">,</span> <span class="n">obj_min</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="nf">fmin_l_bfgs_b</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">f</span><span class="p">,</span> <span class="n">w_init</span><span class="p">,</span> <span class="n">model</span><span class="p">.</span><span class="n">grad</span><span class="p">,</span> <span class="n">args</span><span class="o">=</span><span class="p">(),</span> <span class="n">pgtol</span><span class="o">=</span><span class="mf">1e-30</span><span class="p">,</span> <span class="n">factr</span> <span class="o">=</span><span class="mf">1e-30</span><span class="p">)</span>

<span class="nf">print</span><span class="p">(</span><span class="n">obj_min</span><span class="p">)</span>
<span class="nf">print</span><span class="p">(</span><span class="nf">norm</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="nf">grad</span><span class="p">(</span><span class="n">w_min</span><span class="p">)))</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>Output:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
</pre></td><td class="rouge-code"><pre><span class="mf">0.2736626885606007</span>
<span class="mf">7.144141131678549e-09</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>In the output, <code class="language-plaintext highlighter-rouge">obj_min = 0.2736626885606007</code> is the value of the objective function at the minimizer found by L-BFGS, and <code class="language-plaintext highlighter-rouge">norm(model.grad(w_min)) = 7.144141131678549e-09</code> shows that the gradient there is nearly zero, so <code class="language-plaintext highlighter-rouge">w_min</code> is a highly accurate solution.</p>
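<p><code class="language-plaintext highlighter-rouge">fmin_l_bfgs_b</code> is SciPy's legacy interface; the same solver is also available through <code class="language-plaintext highlighter-rouge">scipy.optimize.minimize</code> with <code class="language-plaintext highlighter-rouge">method="L-BFGS-B"</code>. The sketch below uses that interface on a hypothetical small logistic regression problem (tolerances chosen for illustration) and turns the final gradient norm into a distance guarantee: since the regularized objective is \(\lambda\)-strongly convex, \(\|w - w^\star\| \le \|\nabla f(w)\| / \lambda\).</p>

```python
import numpy as np
from numpy.linalg import norm
from scipy.optimize import minimize

# Hypothetical L2-regularised logistic regression on random data.
rng = np.random.default_rng(0)
n, d = 500, 20
X = rng.standard_normal((n, d))
y = np.sign(X.dot(rng.standard_normal(d)) + rng.standard_normal(n))
lbda = 1.0 / np.sqrt(n)

def f(w):
    return np.mean(np.logaddexp(0.0, -y * X.dot(w))) + lbda * w.dot(w) / 2.0

def grad(w):
    temp = 1.0 / (1.0 + np.exp(y * X.dot(w)))
    return -X.T.dot(y * temp) / n + lbda * w

res = minimize(f, np.zeros(d), jac=grad, method="L-BFGS-B",
               options={"ftol": 1e-15, "gtol": 1e-12, "maxiter": 10_000})
w_min = res.x
g = norm(grad(w_min))
# Strong convexity with mu = lbda: ||w_min - w*|| is at most ||grad f(w_min)|| / mu.
print(res.fun, g, g / lbda)
```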

<h2 id="implementing-stochastic-gradient-descent">Implementing Stochastic Gradient Descent</h2>

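<p>Unlike full-batch gradient descent, which computes the gradient over the entire dataset at every iteration, SGD updates the model parameters using the gradient of a single, randomly selected data point at each iteration. Each step therefore costs one per-sample gradient instead of \(n\).</p>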

<p>The update rule for SGD is:
\begin{equation}
\label{eq:sgd-update}
w^{(t+1)} = w^{(t)} - \gamma^{(t)} \nabla f_{i_t}(w^{(t)})
\end{equation}</p>

<p>where \(\gamma^{(t)}\) is the learning rate at iteration \(t\), and \(\nabla f_{i_t}(w^{(t)})\) is the gradient with respect to the randomly chosen data point \(i_t\).</p>
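<p>Two properties of this update are easy to check numerically: because \(i_t\) is drawn uniformly, \(\frac{1}{n}\sum_i \nabla f_i(w) = \nabla f(w)\), so each stochastic gradient is an unbiased estimate of the full gradient; and a few passes of cheap single-sample steps already decrease the objective. The sketch uses a hypothetical ridge-regression problem and a constant step \(1/L_{\max}\), an illustrative choice rather than the schedule used later in the post.</p>

```python
import numpy as np

# Hypothetical ridge problem: f(w) = mean_i f_i(w), f_i(w) = (x_i.w - y_i)**2 / 2 + lbda * ||w||**2 / 2.
rng = np.random.default_rng(0)
n, d, lbda = 300, 5, 0.1
X = rng.standard_normal((n, d))
y = X.dot(rng.standard_normal(d)) + 0.1 * rng.standard_normal(n)

def f(w):
    return np.mean((X.dot(w) - y) ** 2) / 2.0 + lbda * w.dot(w) / 2.0

def grad(w):
    return X.T.dot(X.dot(w) - y) / n + lbda * w

def grad_i(i, w):
    return (X[i].dot(w) - y[i]) * X[i] + lbda * w

w = np.zeros(d)
# Uniform sampling of i_t makes grad_i an unbiased estimate of grad.
mean_of_grad_i = np.mean([grad_i(i, w) for i in range(n)], axis=0)
print(np.allclose(mean_of_grad_i, grad(w)))  # True

# Plain SGD with a constant step 1 / L_max (L_max = max_i ||x_i||^2 + lbda for this loss).
step = 1.0 / (np.max(np.sum(X ** 2, axis=1)) + lbda)
f_start = f(w)
for t in range(5 * n):
    i_t = rng.integers(n)
    w = w - step * grad_i(i_t, w)
print(f_start, f(w))
```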

<p>To further enhance this, we can add a momentum term (the heavy-ball method), which helps accelerate convergence:
\begin{equation}
\label{eq:momentum}
w^{(t+1)} = w^{(t)} - \gamma^{(t)} \nabla f_{i_t}(w^{(t)}) + \text{momentum} \times \left(w^{(t)} - w^{(t-1)}\right)
\end{equation}
where \(\text{momentum}\) is a hyperparameter that controls the influence of the previous step.</p>
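<p>In code, the momentum update needs only the previous iterate. The sketch below uses a hypothetical helper, <code class="language-plaintext highlighter-rouge">heavy_ball_step</code> (not the post's <code class="language-plaintext highlighter-rouge">sgd</code> function), on a one-dimensional quadratic with exact gradients, so the only difference between the two runs is the momentum term. With the small curvature \(a = 0.05\), plain gradient steps shrink \(w\) slowly, while momentum 0.5 ends much closer to the minimizer at 0 after the same 50 steps.</p>

```python
import numpy as np

def heavy_ball_step(w, w_prev, g, step, momentum):
    """One update w_next = w - step * g + momentum * (w - w_prev)."""
    return w - step * g + momentum * (w - w_prev)

# Minimise the hypothetical 1-D quadratic f(w) = 0.5 * a * w**2 (gradient a * w)
# with exact gradients, to isolate the effect of the momentum term.
a, step = 0.05, 1.0
final = {}
for momentum in [0.0, 0.5]:
    w_prev = w = np.array([1.0])
    for _ in range(50):
        # The right-hand side uses the old w, which then becomes w_prev.
        w, w_prev = heavy_ball_step(w, w_prev, a * w, step, momentum), w
    final[momentum] = abs(w[0])
print(final)
```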

<p>Additionally, we can use <strong>iterate averaging</strong> to improve the stability and convergence of the algorithm. After a certain number of iterations, we start averaging the iterates:
\begin{equation}
\label{eq:sgd_averaging}
w_{\text{avg}}^{(t)} = \frac{1}{t - t_0 + 1} \sum_{j=t_0}^{t} w^{(j)}
\end{equation}
where \(t_0\) is the iteration at which we begin averaging. Averaging is particularly useful in the later stages of optimization, where it smooths out the noise introduced by stochastic updates.</p>
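<p>The averaged iterate does not require storing past iterates: it can be updated online with \(w_{\text{avg}}^{(t)} = w_{\text{avg}}^{(t-1)} + \left(w^{(t)} - w_{\text{avg}}^{(t-1)}\right) / (t - t_0 + 1)\). Here is a minimal sketch of that recursion as a hypothetical helper (the post's <code class="language-plaintext highlighter-rouge">sgd</code> function may organize it differently), checked against a direct mean.</p>

```python
import numpy as np

def running_average(iterates, t0=0):
    """Online average of iterates[t0:], updated one iterate at a time."""
    w_avg = None
    for t, w in enumerate(iterates):
        if t == t0:
            w_avg = np.array(w, dtype=float)  # first averaged iterate
        elif t > t0:
            k = t - t0 + 1                     # number of iterates averaged so far
            w_avg = w_avg + (w - w_avg) / k
    return w_avg

rng = np.random.default_rng(0)
iterates = [rng.standard_normal(3) for _ in range(20)]
print(np.allclose(running_average(iterates, t0=5), np.mean(iterates[5:], axis=0)))  # True
```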

<p>Now, let's implement SGD with options for momentum, iterate averaging and custom step sizes.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">sgd</span><span class="p">(</span><span class="n">w0</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">steps</span><span class="p">,</span> <span class="n">w_min</span><span class="p">,</span> <span class="n">n_iter</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">averaging_on</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">momentum</span> <span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">start_late_averaging</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span>
    <span class="n">w</span> <span class="o">=</span> <span class="n">w0</span><span class="p">.</span><span class="nf">copy</span><span class="p">()</span>
    <span class="n">w_new</span> <span class="o">=</span> <span class="n">w0</span><span class="p">.</span><span class="nf">copy</span><span class="p">()</span>
    <span class="n">n_samples</span><span class="p">,</span> <span class="n">n_features</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">X</span><span class="p">.</span><span class="n">shape</span>
    <span class="n">w_average</span> <span class="o">=</span> <span class="n">w0</span><span class="p">.</span><span class="nf">copy</span><span class="p">()</span>
    <span class="n">w_test</span> <span class="o">=</span> <span class="n">w0</span><span class="p">.</span><span class="nf">copy</span><span class="p">()</span>
    <span class="n">w_old</span> <span class="o">=</span> <span class="n">w0</span><span class="p">.</span><span class="nf">copy</span><span class="p">()</span>
    <span class="n">errors</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">err</span> <span class="o">=</span> <span class="mf">1.0</span>
    <span class="n">objectives</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="c1"># Current estimation error
</span>    <span class="k">if</span> <span class="n">np</span><span class="p">.</span><span class="nf">any</span><span class="p">(</span><span class="n">w_min</span><span class="p">):</span>
        <span class="n">err</span> <span class="o">=</span> <span class="nf">norm</span><span class="p">(</span><span class="n">w</span> <span class="o">-</span> <span class="n">w_min</span><span class="p">)</span> <span class="o">/</span> <span class="nf">norm</span><span class="p">(</span><span class="n">w_min</span><span class="p">)</span>
        <span class="n">errors</span><span class="p">.</span><span class="nf">append</span><span class="p">(</span><span class="n">err</span><span class="p">)</span>
    <span class="c1"># Current objective
</span>    <span class="n">obj</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="nf">f</span><span class="p">(</span><span class="n">w</span><span class="p">)</span> 
    <span class="n">objectives</span><span class="p">.</span><span class="nf">append</span><span class="p">(</span><span class="n">obj</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">verbose</span><span class="p">:</span>
        <span class="nf">print</span><span class="p">(</span><span class="sh">"</span><span class="s">Launching SGD solver...</span><span class="sh">"</span><span class="p">)</span>
        <span class="nf">print</span><span class="p">(</span><span class="sh">'</span><span class="s"> | </span><span class="sh">'</span><span class="p">.</span><span class="nf">join</span><span class="p">([</span><span class="n">name</span><span class="p">.</span><span class="nf">center</span><span class="p">(</span><span class="mi">8</span><span class="p">)</span> <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="p">[</span><span class="sh">"</span><span class="s">it</span><span class="sh">"</span><span class="p">,</span> <span class="sh">"</span><span class="s">obj</span><span class="sh">"</span><span class="p">,</span> <span class="sh">"</span><span class="s">err</span><span class="sh">"</span><span class="p">]]))</span>
    <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">n_iter</span><span class="p">):</span>
        <span class="n">w_new</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">w</span> <span class="o">-</span> <span class="n">steps</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">*</span> <span class="n">model</span><span class="p">.</span><span class="nf">grad_i</span><span class="p">(</span><span class="n">indices</span><span class="p">[</span><span class="n">k</span><span class="p">],</span> <span class="n">w</span><span class="p">)</span> <span class="o">+</span> <span class="n">momentum</span> <span class="o">*</span> <span class="p">(</span><span class="n">w</span> <span class="o">-</span> <span class="n">w_old</span><span class="p">)</span>  <span class="c1"># heavy-ball momentum update</span>
        <span class="n">w_old</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">w</span>
        <span class="n">w</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">w_new</span>
        <span class="k">if</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">start_late_averaging</span><span class="p">:</span>
            <span class="n">w_average</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">w</span>
        <span class="k">else</span><span class="p">:</span>    
            <span class="n">k_new</span> <span class="o">=</span> <span class="n">k</span><span class="o">-</span><span class="n">start_late_averaging</span>
            <span class="n">w_average</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">k_new</span> <span class="o">/</span> <span class="p">(</span><span class="n">k_new</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">w_average</span> <span class="o">+</span> <span class="n">w</span> <span class="o">/</span> <span class="p">(</span><span class="n">k_new</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span>
            
        <span class="k">if</span> <span class="n">averaging_on</span><span class="p">:</span>
            <span class="n">w_test</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">w_average</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">w_test</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">w</span>
        <span class="n">obj</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="nf">f</span><span class="p">(</span><span class="n">w_test</span><span class="p">)</span> 
        <span class="k">if</span> <span class="n">np</span><span class="p">.</span><span class="nf">any</span><span class="p">(</span><span class="n">w_min</span><span class="p">):</span>
            <span class="n">err</span> <span class="o">=</span> <span class="nf">norm</span><span class="p">(</span><span class="n">w_test</span> <span class="o">-</span> <span class="n">w_min</span><span class="p">)</span> <span class="o">/</span> <span class="nf">norm</span><span class="p">(</span><span class="n">w_min</span><span class="p">)</span>
            <span class="n">errors</span><span class="p">.</span><span class="nf">append</span><span class="p">(</span><span class="n">err</span><span class="p">)</span>
        <span class="n">objectives</span><span class="p">.</span><span class="nf">append</span><span class="p">(</span><span class="n">obj</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">k</span> <span class="o">%</span> <span class="n">n_samples</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">verbose</span><span class="p">:</span>
            <span class="k">if</span> <span class="n">np</span><span class="p">.</span><span class="nf">any</span><span class="p">(</span><span class="n">w_min</span><span class="p">):</span>
                <span class="nf">print</span><span class="p">(</span><span class="sh">'</span><span class="s"> | </span><span class="sh">'</span><span class="p">.</span><span class="nf">join</span><span class="p">([(</span><span class="sh">"</span><span class="s">%d</span><span class="sh">"</span> <span class="o">%</span> <span class="n">k</span><span class="p">).</span><span class="nf">rjust</span><span class="p">(</span><span class="mi">8</span><span class="p">),</span> 
                              <span class="p">(</span><span class="sh">"</span><span class="s">%.2e</span><span class="sh">"</span> <span class="o">%</span> <span class="n">obj</span><span class="p">).</span><span class="nf">rjust</span><span class="p">(</span><span class="mi">8</span><span class="p">),</span> 
                              <span class="p">(</span><span class="sh">"</span><span class="s">%.2e</span><span class="sh">"</span> <span class="o">%</span> <span class="n">err</span><span class="p">).</span><span class="nf">rjust</span><span class="p">(</span><span class="mi">8</span><span class="p">)]))</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="nf">print</span><span class="p">(</span><span class="sh">'</span><span class="s"> | </span><span class="sh">'</span><span class="p">.</span><span class="nf">join</span><span class="p">([(</span><span class="sh">"</span><span class="s">%d</span><span class="sh">"</span> <span class="o">%</span> <span class="n">k</span><span class="p">).</span><span class="nf">rjust</span><span class="p">(</span><span class="mi">8</span><span class="p">),</span> 
                              <span class="p">(</span><span class="sh">"</span><span class="s">%.2e</span><span class="sh">"</span> <span class="o">%</span> <span class="n">obj</span><span class="p">).</span><span class="nf">rjust</span><span class="p">(</span><span class="mi">8</span><span class="p">)]))</span>
    <span class="k">if</span> <span class="n">averaging_on</span><span class="p">:</span>
        <span class="n">w_output</span> <span class="o">=</span> <span class="n">w_average</span><span class="p">.</span><span class="nf">copy</span><span class="p">()</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">w_output</span> <span class="o">=</span> <span class="n">w</span><span class="p">.</span><span class="nf">copy</span><span class="p">()</span>    
    <span class="k">return</span> <span class="n">w_output</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="nf">array</span><span class="p">(</span><span class="n">objectives</span><span class="p">),</span> <span class="n">np</span><span class="p">.</span><span class="nf">array</span><span class="p">(</span><span class="n">errors</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>
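<p>This function provides a flexible framework for experimenting with different variants of SGD: it lets us test the effects of momentum, iterate averaging (optionally started late via <code class="language-plaintext highlighter-rouge">start_late_averaging</code>), and various step size schedules.</p>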
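<p>The <code class="language-plaintext highlighter-rouge">sgd</code> function calls a few methods on <code class="language-plaintext highlighter-rouge">model</code>: <code class="language-plaintext highlighter-rouge">f</code> and <code class="language-plaintext highlighter-rouge">grad_i</code>, plus <code class="language-plaintext highlighter-rouge">L_max_constant</code> and <code class="language-plaintext highlighter-rouge">mu_constant</code> in the experiments below. As a minimal sketch of what such an object could look like, here is a hypothetical ridge-regression model; the class name <code class="language-plaintext highlighter-rouge">RidgeModel</code>, its constructor, and the extra <code class="language-plaintext highlighter-rouge">grad</code> method are assumptions for illustration, not the model used in this post. The last lines check the property SGD relies on: averaging the stochastic gradients over all samples recovers the full gradient.</p>

```python
import numpy as np


class RidgeModel:
    """Hypothetical ridge-regression model exposing the methods sgd() calls.

    f(w)   = 1/(2n) * ||A w - b||^2 + lbda/2 * ||w||^2
    f_i(w) = 1/2 * (a_i^T w - b_i)^2 + lbda/2 * ||w||^2, so f is the mean of the f_i
    """

    def __init__(self, A, b, lbda):
        self.A, self.b, self.lbda = A, b, lbda
        self.n, self.d = A.shape

    def f(self, w):
        # Full objective
        residual = self.A @ w - self.b
        return 0.5 * np.mean(residual ** 2) + 0.5 * self.lbda * (w @ w)

    def grad(self, w):
        # Full gradient (not needed by sgd, used only for the check below)
        return self.A.T @ (self.A @ w - self.b) / self.n + self.lbda * w

    def grad_i(self, i, w):
        # Gradient of the single term f_i, called as grad_i(index, w) like in sgd()
        a_i = self.A[i]
        return (a_i @ w - self.b[i]) * a_i + self.lbda * w

    def L_max_constant(self):
        # Largest smoothness constant among the f_i: max_i ||a_i||^2 + lbda
        return np.max(np.sum(self.A ** 2, axis=1)) + self.lbda

    def mu_constant(self):
        # Strong convexity of f: smallest eigenvalue of A^T A / n, plus lbda
        return np.linalg.eigvalsh(self.A.T @ self.A / self.n)[0] + self.lbda


rng = np.random.default_rng(0)
A = rng.standard_normal((50, 3))
b = rng.standard_normal(50)
model = RidgeModel(A, b, lbda=0.1)

w = rng.standard_normal(3)
avg_grad = np.mean([model.grad_i(i, w) for i in range(model.n)], axis=0)
print(np.allclose(avg_grad, model.grad(w)))   # True
```

<p>Because each \(f_i\) is sampled uniformly, each stochastic gradient is an unbiased estimate of the full gradient; the regularization term also guarantees \(\mu &gt; 0\), so the condition number \(L_{\max}/\mu\) used later is finite for this toy model.</p>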

<h2 id="constant-and-shrinking-step-sizes-with-replacement">Constant and Shrinking Step Sizes (With Replacement)</h2>
<p>Now that we’ve implemented our SGD function, it’s time to show how different step sizes impact the optimization process. Specifically, we’ll implement and compare <strong>SGD with a constant step size</strong> and <strong>SGD with a shrinking step size</strong>, both using sampling <strong>with replacement</strong>.</p>

<p>First, let’s set up the number of iterations:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
</pre></td><td class="rouge-code"><pre><span class="n">datapasses</span> <span class="o">=</span> <span class="mi">30</span> 
<span class="n">n_iter</span> <span class="o">=</span> <span class="nf">int</span><span class="p">(</span><span class="n">datapasses</span> <span class="o">*</span> <span class="n">n</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p><code class="language-plaintext highlighter-rouge">datapasses</code> is the number of complete passes over the dataset. The total number of iterations, <code class="language-plaintext highlighter-rouge">n_iter</code>, is obtained by multiplying the number of data points \(n\) by the number of passes. Since each iteration uses a single data point, each point is used about <code class="language-plaintext highlighter-rouge">datapasses</code> times on average during training; with replacement, the exact count varies from point to point.</p>
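<p>Because the indices are drawn <strong>with replacement</strong>, 30 passes is only an average: some points are drawn more often than others, and a point can be skipped entirely in a given pass. The sketch below, with toy sizes chosen for illustration, counts how often each index is drawn and contrasts this with sampling without replacement, where each pass is a fresh random permutation. It uses NumPy’s <code class="language-plaintext highlighter-rouge">default_rng</code> Generator API rather than the legacy <code class="language-plaintext highlighter-rouge">np.random.choice</code> used in the post; the sampling scheme is the same.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n, datapasses = 1000, 30              # toy sizes, for illustration only
n_iter = datapasses * n

# With replacement: every iteration draws an index independently
idx_with = rng.choice(n, n_iter, replace=True)
counts_with = np.bincount(idx_with, minlength=n)

# Without replacement: a fresh random permutation of 0..n-1 for each pass
idx_without = np.concatenate([rng.permutation(n) for _ in range(datapasses)])
counts_without = np.bincount(idx_without, minlength=n)

print(counts_with.mean())                     # 30.0 (on average, 30 draws per point)
print(counts_with.min(), counts_with.max())   # individual counts spread around 30
print(np.all(counts_without == datapasses))   # True (exactly 30 draws per point)
```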

<p><strong>Constant Step Size (With Replacement)</strong></p>

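<p>In our first approach, we use a constant step size \(\gamma_t = \frac{1}{2L_{\max}}\) throughout the optimization, where \(L_{\max}\) is the maximum smoothness constant returned by <code class="language-plaintext highlighter-rouge">model.L_max_constant()</code>:</p>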

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
</pre></td><td class="rouge-code"><pre><span class="n">Lmax</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="nc">L_max_constant</span><span class="p">()</span>

<span class="n">indices</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">choice</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">n_iter</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">replace</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">steps</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">ones</span><span class="p">(</span><span class="n">n_iter</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="n">Lmax</span><span class="p">)</span>
<span class="n">w0</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">zeros</span><span class="p">(</span><span class="n">d</span><span class="p">)</span>
<span class="n">w_sgdcr</span><span class="p">,</span> <span class="n">obj_sgdcr</span><span class="p">,</span> <span class="n">err_sgdcr</span> <span class="o">=</span> <span class="nf">sgd</span><span class="p">(</span><span class="n">w0</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">steps</span><span class="p">,</span> <span class="n">w_min</span><span class="p">,</span> <span class="n">n_iter</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p><strong>Shrinking Step Size (With Replacement)</strong></p>

<p>Next, we implement SGD with the shrinking step size \(\gamma_t = \frac{2}{L_{\max}\sqrt{t+1}}\), which decays like \(1/\sqrt{t}\):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
</pre></td><td class="rouge-code"><pre><span class="n">Lmax</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="nc">L_max_constant</span><span class="p">()</span>

<span class="n">indices</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">choice</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">n_iter</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">replace</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">steps</span> <span class="o">=</span>  <span class="mi">2</span><span class="o">/</span><span class="p">(</span><span class="n">Lmax</span><span class="o">*</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nf">sqrt</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nf">arange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">n_iter</span> <span class="o">+</span> <span class="mi">2</span><span class="p">))))</span>
<span class="n">w_sgdsr</span><span class="p">,</span> <span class="n">obj_sgdsr</span><span class="p">,</span> <span class="n">err_sgdsr</span> <span class="o">=</span> <span class="nf">sgd</span><span class="p">(</span><span class="n">w0</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">steps</span><span class="p">,</span> <span class="n">w_min</span><span class="p">,</span> <span class="n">n_iter</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>
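<p>It is worth comparing the two schedules directly. The constant schedule is \(\gamma_t = \frac{1}{2L_{\max}}\) and the shrinking one is \(\gamma_t = \frac{2}{L_{\max}\sqrt{t+1}}\), so the shrinking step starts four times larger and only falls below the constant step once \(\sqrt{t+1} &gt; 4\), that is, for \(t &gt; 15\), independently of \(L_{\max}\). The sketch below checks this with a hypothetical value <code class="language-plaintext highlighter-rouge">Lmax = 4.0</code>.</p>

```python
import numpy as np

Lmax = 4.0                 # hypothetical smoothness constant, for illustration only
n_iter = 100
t = np.arange(n_iter + 1)

steps_const = np.ones(n_iter + 1) / (2 * Lmax)   # 1 / (2 Lmax)
steps_shrink = 2 / (Lmax * np.sqrt(t + 1))       # 2 / (Lmax sqrt(t + 1))

# The shrinking schedule starts four times larger than the constant one
print(steps_shrink[0] / steps_const[0])          # 4.0

# First iteration at which the shrinking step is strictly smaller
first_smaller = np.argmax(steps_const > steps_shrink)
print(first_smaller)                             # 16
```

<p>So during the first few iterations the shrinking schedule is actually the more aggressive of the two; its benefit appears later, once its steps are much smaller than the constant ones.</p>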
<h3 id="comparing-sgd-with-constant-and-shrinking-step-sizes">Comparing SGD with Constant and Shrinking Step Sizes</h3>

<p>Let’s now compare SGD with constant and shrinking step sizes and observe their rates of convergence.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
</pre></td><td class="rouge-code"><pre><span class="c1"># Error of objective on a logarithmic scale
</span><span class="n">plt</span><span class="p">.</span><span class="nf">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">obj_sgdcr</span> <span class="o">-</span> <span class="n">obj_min</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD const</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">obj_sgdsr</span> <span class="o">-</span> <span class="n">obj_min</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD shrink</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">title</span><span class="p">(</span><span class="sh">"</span><span class="s">Convergence plot</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">16</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">xlabel</span><span class="p">(</span><span class="sh">"</span><span class="s">#iterations</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">ylabel</span><span class="p">(</span><span class="sh">"</span><span class="s">Error of objective</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">legend</span><span class="p">()</span>
<span class="c1"># Distance to the minimizer on a logarithmic scale
</span><span class="n">plt</span><span class="p">.</span><span class="nf">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">yscale</span><span class="p">(</span><span class="sh">"</span><span class="s">log</span><span class="sh">"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">err_sgdcr</span> <span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD const</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">err_sgdsr</span> <span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD shrink</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">title</span><span class="p">(</span><span class="sh">"</span><span class="s">Convergence plot</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">16</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">xlabel</span><span class="p">(</span><span class="sh">"</span><span class="s">#iterations</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">ylabel</span><span class="p">(</span><span class="sh">"</span><span class="s">Distance to the minimum</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">legend</span><span class="p">()</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<div class="row mt-3 justify-content-center">
    <div class="col-sm-8 mt-3 mt-md-0">
        <picture>
            <img src="/images/blog/const_shrink1.png" width="100%" alt="Convergence plot comparing constant and shrinking step sizes" />
        </picture>
    </div>
</div>
<div class="row mt-3 justify-content-center">
    <div class="col-sm-8 mt-3 mt-md-0">
        <picture>
            <img src="/images/blog/const_shrink2.png" width="100%" alt="Distance to minimum plot comparing constant and shrinking step sizes" />
        </picture>
    </div>
</div>
<div class="caption text-center">
    Convergence of SGD with constant and shrinking step sizes: error of the objective (top) and relative distance to the minimizer (bottom).
</div>

<p>Comparing the two methods, we see that the constant step size may make faster progress initially, but it then oscillates around the minimum as the iterations increase, because the stochastic gradients remain noisy even near the minimizer. The shrinking step size provides a more reliable path to convergence, which makes it the preferred choice when stability and accuracy are critical.</p>
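<p>The reason a constant step size stalls is that the individual stochastic gradients do not vanish at the minimizer, even though their average does. A minimal illustration, assuming a hypothetical one-dimensional least-squares problem with \(f_i(w) = \frac{1}{2}(w - b_i)^2\) (so every \(f_i\) is 1-smooth and \(L_{\max} = 1\)), shows that a single constant-step update started exactly at the minimizer moves the iterate away from it, whichever sample is drawn.</p>

```python
import numpy as np

# Hypothetical 1-D least squares: f(w) = mean_i of f_i(w) = (w - b_i)^2 / 2
b = np.arange(10, dtype=float)      # toy data: 0, 1, ..., 9
w_star = b.mean()                   # minimizer of f: 4.5

# Gradient of each f_i at the minimizer
grads_at_min = w_star - b
print(grads_at_min.mean())          # 0.0 -> the full gradient vanishes
print(np.abs(grads_at_min).min())   # 0.5 -> but no individual gradient does

# Constant step 1 / (2 Lmax), with Lmax = 1 for these f_i
gamma = 0.5
w_next = w_star - gamma * grads_at_min   # the update for each possible sample i
print(np.abs(w_next - w_star).min())     # 0.25 -> every update leaves w_star
```

<p>The displacement is proportional to the step size, which is why letting \(\gamma_t\) shrink allows the iterates to settle closer to the minimizer.</p>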

<h2 id="sgd-with-switching-to-shrinking-step-sizes">SGD with Switching to Shrinking Step Sizes</h2>

<p>It’s often beneficial to start with a larger, constant step size for faster convergence early on, and then transition to smaller, shrinking step sizes to fine-tune the solution.</p>

<p><strong>Constant Step Size (Early Iterations)</strong></p>

<p>For iterations \(t \leq t^*\), we use a constant step size:
\begin{equation}
\label{eq:const_to_switch}
\gamma_t = \frac{1}{2L_{\max}}
\end{equation}
This allows rapid progress toward minimizing the objective function.</p>

<p><strong>Switching to Shrinking Step Sizes (Later Iterations)</strong></p>

<p>After \(t^*\), we switch to a shrinking step size:
\begin{equation}
\gamma_t = \frac{2t + 1}{(t + 1)^2 \mu},
\end{equation}
where \(\mu\) is the strong convexity constant of the objective. The shrinking step size makes the updates more conservative as the algorithm nears the optimal solution, which helps reduce oscillations and improves stability.</p>
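<p>To see why the schedule \(\gamma_t = \frac{2t+1}{(t+1)^2\mu}\) damps the noise, consider again a hypothetical one-dimensional problem with \(f_i(w) = \frac{1}{2}(w - b_i)^2\), for which \(\mu = 1\) and \(\nabla f_i(w) = w - b_i\). Since \(1 - \gamma_t = \frac{t^2}{(t+1)^2}\), the SGD update can be rewritten as \((t+1)^2 (w_{t+1} - w^*) = t^2 (w_t - w^*) + (2t+1)(b_{i_t} - w^*)\). Telescoping from \(t = 0\) to \(T-1\) gives \(w_T = \frac{1}{T^2}\sum_{t=0}^{T-1} (2t+1)\, b_{i_t}\): the iterate is a weighted average of the sampled data, with weights growing linearly in \(t\), so the noise is averaged out rather than accumulated. The sketch below checks this identity numerically; the toy setting is an illustration, not the model used in the post.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
b = rng.standard_normal(20)         # hypothetical 1-D data
mu = 1.0                            # each f_i(w) = (w - b_i)^2 / 2 has curvature 1
T = 500
idx = rng.choice(len(b), T, replace=True)

w = 10.0                            # arbitrary start: gamma_0 = 1 erases it
for t in range(T):
    gamma = (2 * t + 1) / ((t + 1) ** 2 * mu)
    w -= gamma * (w - b[idx[t]])    # SGD step, since grad f_i(w) = w - b_i

# Closed form: weight (2t + 1) / T^2 on the sample drawn at iteration t
weights = (2 * np.arange(T) + 1) / T**2
print(np.isclose(weights.sum(), 1.0))       # True
print(np.isclose(w, weights @ b[idx]))      # True
```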

<p><strong>Switch Point</strong></p>

<p>The switch occurs at the iteration index \(t^*\), given by
\begin{equation}
t^* = 4 \times \lceil \kappa \rceil,
\end{equation}
where \(\kappa = \frac{L_{\max}}{\mu}\) is the condition number of the problem. This choice balances fast initial progress against the need for more precision as we get closer to the solution.</p>
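<p>As a quick sanity check on this choice of \(t^*\), the shrinking step is never larger than the constant one once the switch has happened, so the schedule does not jump up at \(t^*\). For any \(t \geq t^*\), using \(2t + 1 &lt; 2(t + 1)\), \(t + 1 &gt; t^*\), and \(t^* \mu \geq 4\kappa\mu = 4L_{\max}\):
\begin{equation}
\gamma_t = \frac{2t + 1}{(t + 1)^2 \mu} &lt; \frac{2}{(t + 1)\mu} &lt; \frac{2}{t^* \mu} \leq \frac{2}{4L_{\max}} = \frac{1}{2L_{\max}}.
\end{equation}
For large \(\kappa\) the two step sizes are also close at the switch, so the transition between the two regimes is smooth.</p>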

<p>Let’s now implement the above.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
</pre></td><td class="rouge-code"><pre><span class="n">mu</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="nf">mu_constant</span><span class="p">()</span>
<span class="n">Kappa</span> <span class="o">=</span> <span class="n">Lmax</span><span class="o">/</span><span class="n">mu</span>
<span class="n">tstar</span> <span class="o">=</span> <span class="mi">4</span> <span class="o">*</span> <span class="nf">int</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nf">ceil</span><span class="p">(</span><span class="n">Kappa</span><span class="p">))</span>

<span class="n">steps_switch</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">zeros</span><span class="p">(</span><span class="n">n_iter</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">n_iter</span><span class="p">):</span>
    <span class="k">if</span> <span class="n">i</span> <span class="o">&lt;=</span> <span class="n">tstar</span><span class="p">:</span>
        <span class="n">steps_switch</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">Lmax</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">steps_switch</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="p">((</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">mu</span><span class="p">)</span>

<span class="c1"># Sample one data index per iteration, with replacement
</span><span class="n">indices</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">choice</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">n_iter</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">replace</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">w_sgdss</span><span class="p">,</span> <span class="n">obj_sgdss</span><span class="p">,</span> <span class="n">err_sgdss</span> <span class="o">=</span> <span class="nf">sgd</span><span class="p">(</span><span class="n">w0</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">steps_switch</span><span class="p">,</span> <span class="n">w_min</span><span class="p">,</span> <span class="n">n_iter</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>This switching approach combines the advantages of both schedules. The constant step size in the early iterations allows quick progress in reducing the objective function. As we approach the minimum, the full gradient becomes small while the individual stochastic gradients do not, and the shrinking step sizes damp this noise so that the updates do not keep overshooting the minimum.</p>
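<p>The schedule loop above can also be written in vectorized form. The sketch below is a hypothetical helper, <code class="language-plaintext highlighter-rouge">switching_steps</code>, that builds the same schedule with <code class="language-plaintext highlighter-rouge">np.where</code> (constant for \(t \leq t^*\), shrinking afterwards, matching the <code class="language-plaintext highlighter-rouge">i &lt;= tstar</code> test) and checks with toy constants that the schedule never increases. One small difference: the loop only fills the first <code class="language-plaintext highlighter-rouge">n_iter</code> entries and leaves <code class="language-plaintext highlighter-rouge">steps_switch[n_iter]</code> at zero, which is harmless because <code class="language-plaintext highlighter-rouge">sgd</code> only reads <code class="language-plaintext highlighter-rouge">steps[k]</code> for <code class="language-plaintext highlighter-rouge">k &lt; n_iter</code>; the vectorized version fills every entry.</p>

```python
import numpy as np


def switching_steps(n_iter, Lmax, mu):
    """Hypothetical helper: vectorized version of the switching schedule.

    Constant 1 / (2 Lmax) for t <= t*, then (2t + 1) / ((t + 1)^2 mu),
    with t* = 4 * ceil(Lmax / mu).
    """
    tstar = 4 * int(np.ceil(Lmax / mu))
    t = np.arange(n_iter + 1)
    shrinking = (2 * t + 1) / ((t + 1) ** 2 * mu)
    steps = np.where(t <= tstar, 1 / (2 * Lmax), shrinking)
    return steps, tstar


# Toy constants, for illustration only: kappa = 10, so t* = 40
steps, tstar = switching_steps(n_iter=1000, Lmax=5.0, mu=0.5)
print(tstar)                          # 40
print(np.all(np.diff(steps) <= 0))    # True: the schedule never increases
```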

<h3 id="comparing-sgd-with-constant-to-switching-step-sizes">Comparing SGD with Constant and Switching Step Sizes</h3>

<p>Let’s now compare SGD with a constant step size to SGD with the switching schedule and observe their rates of convergence. The dashed orange line marks the switch point \(t^*\).</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
</pre></td><td class="rouge-code"><pre><span class="c1"># Plotting to compare with constant stepsize
</span><span class="n">plt</span><span class="p">.</span><span class="nf">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">obj_sgdcr</span> <span class="o">-</span> <span class="n">obj_min</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD const</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">obj_sgdss</span> <span class="o">-</span> <span class="n">obj_min</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD switch</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">title</span><span class="p">(</span><span class="sh">"</span><span class="s">Convergence plot</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">16</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">xlabel</span><span class="p">(</span><span class="sh">"</span><span class="s">#iterations</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">ylabel</span><span class="p">(</span><span class="sh">"</span><span class="s">Error of objective</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">axvline</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">tstar</span><span class="p">,</span> <span class="n">color</span> <span class="o">=</span> <span class="sh">"</span><span class="s">orange</span><span class="sh">"</span><span class="p">,</span> <span class="n">linestyle</span><span class="o">=</span><span class="sh">'</span><span class="s">dashed</span><span class="sh">'</span><span class="p">)</span>

<span class="c1"># Distance to the minimizer on a logarithmic scale
</span><span class="n">plt</span><span class="p">.</span><span class="nf">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">yscale</span><span class="p">(</span><span class="sh">"</span><span class="s">log</span><span class="sh">"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">err_sgdcr</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD const</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">err_sgdss</span> <span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD switch</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">title</span><span class="p">(</span><span class="sh">"</span><span class="s">Convergence plot</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">16</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">xlabel</span><span class="p">(</span><span class="sh">"</span><span class="s">#iterations</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">ylabel</span><span class="p">(</span><span class="sh">"</span><span class="s">Distance to the minimum</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">axvline</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">tstar</span><span class="p">,</span>  <span class="n">color</span> <span class="o">=</span> <span class="sh">"</span><span class="s">orange</span><span class="sh">"</span><span class="p">,</span> <span class="n">linestyle</span><span class="o">=</span><span class="sh">'</span><span class="s">dashed</span><span class="sh">'</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<div class="row mt-3 justify-content-center">
    <div class="col-sm-8 mt-3 mt-md-0">
        <picture>
            <img src="/images/blog/switch1.png" width="100%" alt="Convergence plot comparing constant and switching step sizes" />
        </picture>
    </div>
</div>
<div class="row mt-3 justify-content-center">
    <div class="col-sm-8 mt-3 mt-md-0">
        <picture>
            <img src="/images/blog/switch2.png" width="100%" alt="Distance to minimum plot comparing constant and switching step sizes" />
        </picture>
    </div>
</div>
<div class="caption text-center">
    A plot showing the difference between constant step size and switching step size in terms of convergence.
</div>

<p>The plots show that switching to shrinking step sizes outperforms the constant step size: it reduces the oscillations and gives a smoother convergence towards the minimum.</p>
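<p>The exact <code class="language-plaintext highlighter-rouge">steps_switch</code> schedule is defined earlier in the post and is not reproduced here. As a hypothetical illustration only, the sketch below keeps a constant step \(\gamma_0\) up to iteration \(t^\star\) and then shrinks it like \(\gamma_0 t^\star / t\). The function name <code class="language-plaintext highlighter-rouge">switching_steps</code> and this particular decay are assumptions, not necessarily the formula used for <code class="language-plaintext highlighter-rouge">steps_switch</code>.</p>

```python
import numpy as np

def switching_steps(n_iter, gamma0, tstar):
    """Hypothetical switching schedule (illustration only).

    Returns n_iter + 1 step sizes: constant gamma0 for t <= tstar,
    then gamma0 * tstar / t, which shrinks like 1/t and equals gamma0 at t = tstar.
    """
    if tstar < 1:
        raise ValueError("tstar must be at least 1")
    t = np.arange(n_iter + 1)
    steps = np.full(n_iter + 1, gamma0, dtype=float)
    late = t > tstar                      # iterations after the switch
    steps[late] = gamma0 * tstar / t[late]
    return steps

steps = switching_steps(n_iter=100, gamma0=0.1, tstar=25)
print(steps[23:29])  # constant up to t = 25, decreasing afterwards
```

<p>The array has <code class="language-plaintext highlighter-rouge">n_iter + 1</code> entries so that it can be indexed the same way as the sampled <code class="language-plaintext highlighter-rouge">indices</code>.</p>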

<h2 id="sgd-with-averaging">SGD With Averaging</h2>

<p>One powerful technique that can enhance the performance of SGD is averaging. Instead of returning the last iterate, averaging returns the mean of the iterates computed towards the end of the optimization process.</p>

<p>Here, we start averaging the iterates only in the last quarter of the total iterations, so that the early iterates, which are still far from the minimum, do not pull the average away from it. Let’s implement it.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
</pre></td><td class="rouge-code"><pre><span class="c1"># Implementing averaging with SGD
</span><span class="n">indices</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">choice</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">n_iter</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">replace</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">start_late_averaging</span> <span class="o">=</span> <span class="mi">3</span><span class="o">*</span><span class="n">n_iter</span><span class="o">/</span><span class="mi">4</span>
<span class="n">averaging_on</span> <span class="o">=</span> <span class="bp">True</span> 

<span class="n">w_sgdar</span><span class="p">,</span> <span class="n">obj_sgdar</span><span class="p">,</span> <span class="n">err_sgdar</span> <span class="o">=</span> <span class="nf">sgd</span><span class="p">(</span><span class="n">w0</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">steps_switch</span><span class="p">,</span> <span class="n">w_min</span><span class="p">,</span> <span class="n">n_iter</span><span class="p">,</span> <span class="n">averaging_on</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="bp">True</span><span class="p">,</span> <span class="n">start_late_averaging</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>
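<p>The <code class="language-plaintext highlighter-rouge">sgd</code> helper itself is not shown in this section. A minimal sketch of how a late average can be maintained inside an SGD loop, assuming an integer start index, is a running mean that only needs the current iterate at each step. The helper name <code class="language-plaintext highlighter-rouge">late_average</code> is hypothetical.</p>

```python
import numpy as np

def late_average(iterates, start):
    """Running mean of iterates[start:], updated one iterate at a time.

    Before `start` the average simply tracks the current iterate; from `start`
    on it is updated as w_avg = m/(m+1) * w_avg + w/(m+1), where m is the
    number of iterates already averaged.
    """
    w_avg = np.array(iterates[0], dtype=float)
    for k, w in enumerate(iterates):
        if k < start:
            w_avg = np.array(w, dtype=float)
        else:
            m = k - start
            w_avg = m / (m + 1) * w_avg + w / (m + 1)
    return w_avg

rng = np.random.default_rng(0)
iterates = rng.normal(size=(100, 3))      # stand-in for 100 SGD iterates in R^3
start = (3 * len(iterates)) // 4          # average over the last quarter only
print(np.allclose(late_average(iterates, start), iterates[start:].mean(axis=0)))  # True
```

<p>The running-mean form avoids storing all iterates, which matters when the number of iterations is large.</p>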

<h3 id="comparing-the-results">Comparing the Results</h3>

<p>Let’s now compare SGD with a constant step size, with switching step sizes, and with switching step sizes plus late averaging, and observe their rates of convergence.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
</pre></td><td class="rouge-code"><pre><span class="c1"># Plotting to compare constant stepsize, switching, switching + averaging
</span><span class="n">plt</span><span class="p">.</span><span class="nf">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">obj_sgdcr</span> <span class="o">-</span> <span class="n">obj_min</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD const</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">obj_sgdss</span> <span class="o">-</span> <span class="n">obj_min</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD switch</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">obj_sgdar</span> <span class="o">-</span> <span class="n">obj_min</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD average end</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">title</span><span class="p">(</span><span class="sh">"</span><span class="s">Convergence plot</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">16</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">xlabel</span><span class="p">(</span><span class="sh">"</span><span class="s">#iterations</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">ylabel</span><span class="p">(</span><span class="sh">"</span><span class="s">Loss function</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">axvline</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">tstar</span><span class="p">,</span> <span class="n">color</span> <span class="o">=</span> <span class="sh">"</span><span class="s">orange</span><span class="sh">"</span><span class="p">,</span> <span class="n">linestyle</span><span class="o">=</span><span class="sh">'</span><span class="s">dashed</span><span class="sh">'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">axvline</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">start_late_averaging</span><span class="p">,</span> <span class="n">color</span> <span class="o">=</span> <span class="sh">"</span><span class="s">green</span><span class="sh">"</span><span class="p">,</span> <span class="n">linestyle</span><span class="o">=</span><span class="sh">'</span><span class="s">dashed</span><span class="sh">'</span><span class="p">)</span>

<span class="c1"># Distance to the minimizer on a logarithmic scale
</span><span class="n">plt</span><span class="p">.</span><span class="nf">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">err_sgdcr</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD const</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">err_sgdss</span> <span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD switch</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">err_sgdar</span> <span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD average end</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">title</span><span class="p">(</span><span class="sh">"</span><span class="s">Convergence plot</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">16</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">xlabel</span><span class="p">(</span><span class="sh">"</span><span class="s">#iterations</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">ylabel</span><span class="p">(</span><span class="sh">"</span><span class="s">Distance to the minimum</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">axvline</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">tstar</span><span class="p">,</span> <span class="n">color</span> <span class="o">=</span> <span class="sh">"</span><span class="s">orange</span><span class="sh">"</span><span class="p">,</span> <span class="n">linestyle</span><span class="o">=</span><span class="sh">'</span><span class="s">dashed</span><span class="sh">'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">axvline</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">start_late_averaging</span><span class="p">,</span> <span class="n">color</span> <span class="o">=</span> <span class="sh">"</span><span class="s">green</span><span class="sh">"</span><span class="p">,</span> <span class="n">linestyle</span><span class="o">=</span><span class="sh">'</span><span class="s">dashed</span><span class="sh">'</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<div class="row mt-3 justify-content-center">
    <div class="col-sm-8 mt-3 mt-md-0">
        <picture>
            <img src="/images/blog/averaging1.png" width="100%" alt="Convergence plot comparing constant, switching and averaging step sizes" />
        </picture>
    </div>
</div>
<div class="row mt-3 justify-content-center">
    <div class="col-sm-8 mt-3 mt-md-0">
        <picture>
            <img src="/images/blog/averaging2.png" width="100%" alt="Distance to minimum plot comparing constant, switching and averaging step sizes" />
        </picture>
    </div>
</div>
<div class="caption text-center">
    Comparison of SGD with a constant step size, switching step sizes, and switching step sizes with late averaging.
</div>

<p>We can see that the averaging technique (green line) stabilizes the objective function, especially towards the end of the optimization process. Averaging cancels out much of the noise in the late stochastic iterates, so the averaged point fluctuates far less around the minimum than the last iterate does.</p>
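<p>To illustrate this variance reduction in isolation, here is a small, self-contained experiment on a made-up noisy least-squares problem (not the post's <code class="language-plaintext highlighter-rouge">model</code>). It runs constant-step SGD and compares the distance to the exact minimizer of the last iterate and of the average over the last quarter of the iterations. The problem size, noise level and step size are arbitrary choices for the sketch.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 200, 5
X = rng.normal(size=(n, d))
y = X @ rng.normal(size=d) + 0.5 * rng.normal(size=n)   # noisy linear targets
w_star = np.linalg.lstsq(X, y, rcond=None)[0]           # exact least-squares minimizer

n_iter, step = 20_000, 0.02
start = (3 * n_iter) // 4        # average only over the last quarter
w = np.zeros(d)
w_avg = np.zeros(d)
for k in range(n_iter):
    i = rng.integers(n)                       # sample one point with replacement
    w -= step * (X[i] @ w - y[i]) * X[i]      # SGD step on 0.5 * (x_i^T w - y_i)^2
    if k >= start:
        m = k - start                          # iterates already in the average
        w_avg = m / (m + 1) * w_avg + w / (m + 1)

err_last = np.linalg.norm(w - w_star)
err_avg = np.linalg.norm(w_avg - w_star)
print(f"last iterate: {err_last:.4f}, late average: {err_avg:.4f}")
```

<p>With a constant step size the last iterate keeps bouncing around the minimizer, while the late average settles much closer to it.</p>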

<h2 id="sgd-with-momentum">SGD with Momentum</h2>

<p>Momentum is a technique used to accelerate convergence, especially when the gradients oscillate. By adding a fraction of the previous update to the current update, the method dampens these oscillations and can lead to faster convergence. <strong><em>Note that the update rule for SGD with momentum was already given and explained in \(\eqref{eq:momentum}\).</em></strong></p>

<p>Now let’s implement SGD with momentum:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
</pre></td><td class="rouge-code"><pre><span class="n">indices</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">choice</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">n_iter</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">replace</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">averaging_on</span> <span class="o">=</span> <span class="bp">True</span>
<span class="n">start_averaging_sgdm</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="n">momentum</span> <span class="o">=</span> <span class="mf">1.0</span>
<span class="n">w_sgdm</span><span class="p">,</span> <span class="n">obj_sgdm</span><span class="p">,</span> <span class="n">err_sgdm</span> <span class="o">=</span> <span class="nf">sgd</span><span class="p">(</span><span class="n">w0</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">steps_switch</span><span class="p">,</span> <span class="n">w_min</span><span class="p">,</span> <span class="n">n_iter</span><span class="p">,</span> <span class="n">averaging_on</span><span class="p">,</span> <span class="n">momentum</span><span class="p">,</span> <span class="bp">True</span><span class="p">,</span> <span class="n">start_averaging_sgdm</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<p>For simplicity, we set the momentum parameter to \(1\) and keep averaging switched on from the very first iteration in this run. You can try other values of the momentum parameter to see which one works best.</p>
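<p>To see the effect of the momentum term on its own, here is a minimal sketch of the heavy-ball update \(w_{t+1} = w_t - \gamma \nabla f(w_t) + \beta (w_t - w_{t-1})\), run with exact gradients on an ill-conditioned quadratic so that the effect is not masked by sampling noise. The helper name <code class="language-plaintext highlighter-rouge">heavy_ball</code>, the quadratic and the values \(\gamma = 0.01\), \(\beta = 0.9\) are assumptions for illustration; the exact form of \(\eqref{eq:momentum}\) and the behaviour of the <code class="language-plaintext highlighter-rouge">momentum</code> argument of <code class="language-plaintext highlighter-rouge">sgd</code> may differ.</p>

```python
import numpy as np

def heavy_ball(grad, w0, step, beta, tol=1e-6, max_iter=10_000):
    """Heavy-ball iteration w_{t+1} = w_t - step * grad(w_t) + beta * (w_t - w_{t-1}).

    beta = 0 recovers plain gradient descent. Stops when the gradient norm
    drops below tol and returns the final iterate and the iteration count.
    """
    w_prev = np.array(w0, dtype=float)
    w = w_prev.copy()
    for t in range(max_iter):
        g = grad(w)
        if np.linalg.norm(g) < tol:
            return w, t
        w_next = w - step * g + beta * (w - w_prev)  # gradient step plus momentum term
        w_prev, w = w, w_next
    return w, max_iter

# Ill-conditioned quadratic f(w) = 0.5 * w^T A w, whose gradient is A w
A = np.diag([1.0, 100.0])
_, gd_iters = heavy_ball(lambda w: A @ w, np.ones(2), step=0.01, beta=0.0)
_, hb_iters = heavy_ball(lambda w: A @ w, np.ones(2), step=0.01, beta=0.9)
print(gd_iters, hb_iters)  # momentum needs far fewer iterations here
```

<p>On this problem plain gradient descent shrinks the error along the flat direction by only a factor of \(0.99\) per iteration, while the momentum term lets the iterates build up speed along it.</p>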

<h2 id="comparing-the-results-1">Comparing the Results</h2>

<p>Let’s now compare the performance of SGD with constant step size, switching step size, switching step size with averaging, and SGD with momentum.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
</pre></td><td class="rouge-code"><pre><span class="c1"># Plotting to compare constant stepsize, switching, switching + averaging, and momentum
</span><span class="n">plt</span><span class="p">.</span><span class="nf">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">obj_sgdcr</span> <span class="o">-</span> <span class="n">obj_min</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD const</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">obj_sgdss</span> <span class="o">-</span> <span class="n">obj_min</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD switch</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">obj_sgdar</span> <span class="o">-</span> <span class="n">obj_min</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD average end</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">obj_sgdm</span> <span class="o">-</span> <span class="n">obj_min</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGDm</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">title</span><span class="p">(</span><span class="sh">"</span><span class="s">Convergence plot</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">16</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">xlabel</span><span class="p">(</span><span class="sh">"</span><span class="s">#iterations</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">ylabel</span><span class="p">(</span><span class="sh">"</span><span class="s">Loss function</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">axvline</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">tstar</span><span class="p">,</span> <span class="n">color</span> <span class="o">=</span> <span class="sh">"</span><span class="s">orange</span><span class="sh">"</span><span class="p">,</span> <span class="n">linestyle</span><span class="o">=</span><span class="sh">'</span><span class="s">dashed</span><span class="sh">'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">axvline</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">start_late_averaging</span><span class="p">,</span> <span class="n">color</span> <span class="o">=</span> <span class="sh">"</span><span class="s">purple</span><span class="sh">"</span><span class="p">,</span> <span class="n">linestyle</span><span class="o">=</span><span class="sh">'</span><span class="s">dashed</span><span class="sh">'</span><span class="p">)</span>

<span class="c1"># Distance to the minimizer on a logarithmic scale
</span><span class="n">plt</span><span class="p">.</span><span class="nf">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">err_sgdcr</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD const</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">err_sgdss</span> <span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD switch</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">err_sgdar</span> <span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD average end</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">err_sgdm</span> <span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGDm</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">title</span><span class="p">(</span><span class="sh">"</span><span class="s">Convergence plot</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">16</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">xlabel</span><span class="p">(</span><span class="sh">"</span><span class="s">#iterations</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">ylabel</span><span class="p">(</span><span class="sh">"</span><span class="s">Distance to the minimum</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">legend</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">axvline</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">tstar</span><span class="p">,</span> <span class="n">color</span> <span class="o">=</span> <span class="sh">"</span><span class="s">orange</span><span class="sh">"</span><span class="p">,</span> <span class="n">linestyle</span><span class="o">=</span><span class="sh">'</span><span class="s">dashed</span><span class="sh">'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">axvline</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">start_late_averaging</span><span class="p">,</span> <span class="n">color</span> <span class="o">=</span> <span class="sh">"</span><span class="s">purple</span><span class="sh">"</span><span class="p">,</span> <span class="n">linestyle</span><span class="o">=</span><span class="sh">'</span><span class="s">dashed</span><span class="sh">'</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<div class="row mt-3 justify-content-center">
    <div class="col-sm-8 mt-3 mt-md-0">
        <picture>
            <img src="/images/blog/momentum1.png" width="100%" alt="Convergence plot comparing constant step size, switching step sizes, late averaging and momentum" />
        </picture>
    </div>
</div>
<div class="row mt-3 justify-content-center">
    <div class="col-sm-8 mt-3 mt-md-0">
        <picture>
            <img src="/images/blog/momentum2.png" width="100%" alt="Distance to minimum plot comparing constant step size, switching step sizes, late averaging and momentum" />
        </picture>
    </div>
</div>
<div class="caption text-center">
    Comparison of SGD with a constant step size, switching step sizes, switching step sizes with late averaging, and momentum.
</div>

<p>We can observe that SGD with momentum (red curve) converges fastest, outperforming the other methods in both loss reduction and distance to the minimum. The orange dashed line marks the iteration at which the step size switches (<code class="language-plaintext highlighter-rouge">tstar</code>), and the purple dashed line marks where late averaging begins.</p>

<h2 id="sgd-without-replacement">SGD without Replacement</h2>

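<p>SGD without replacement visits each data point exactly once per pass (epoch) over the dataset, so the model sees the entire dataset in every pass. In the code below, a single random permutation of the indices is drawn and the same order is reused for every pass.</p>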

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
</pre></td><td class="rouge-code"><pre><span class="n">niters</span> <span class="o">=</span> <span class="nf">int</span><span class="p">(</span><span class="n">datapasses</span> <span class="o">*</span> <span class="n">n</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
<span class="c1"># One random permutation of the n indices, repeated for every data pass
</span><span class="n">indices</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">tile</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">choice</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">replace</span><span class="o">=</span><span class="bp">False</span><span class="p">),</span> <span class="n">datapasses</span><span class="p">)</span>
<span class="n">w_sgdsw</span><span class="p">,</span> <span class="n">obj_sgdsw</span><span class="p">,</span> <span class="n">err_sgdsw</span> <span class="o">=</span> <span class="nf">sgd</span><span class="p">(</span><span class="n">w0</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">steps_switch</span><span class="p">,</span> <span class="n">w_min</span><span class="p">,</span> <span class="n">niters</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>
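<p>The code above reuses one permutation for all passes. A common variant, often called random reshuffling, draws a fresh permutation for every pass. The sketch below builds either index sequence; the helper name <code class="language-plaintext highlighter-rouge">pass_indices</code> is hypothetical, and it uses NumPy's <code class="language-plaintext highlighter-rouge">default_rng</code> generator rather than the global <code class="language-plaintext highlighter-rouge">np.random</code> state used in the post.</p>

```python
import numpy as np

def pass_indices(n, datapasses, rng, reshuffle=True):
    """Index sequence for SGD without replacement (hypothetical helper).

    Each block of n consecutive indices is a permutation of range(n), so every
    point is visited exactly once per pass. With reshuffle=True a fresh
    permutation is drawn for every pass; with reshuffle=False one permutation
    is reused for all passes, as in the code above.
    """
    if reshuffle:
        return np.concatenate([rng.permutation(n) for _ in range(datapasses)])
    return np.tile(rng.permutation(n), datapasses)

rng = np.random.default_rng(0)
idx = pass_indices(5, 3, rng)
print(idx.reshape(3, 5))  # each row is one pass over the 5 data points
```

<p>By contrast, \(n\) draws with replacement visit on average only about \(1 - 1/e \approx 63\%\) of the points, so some points are seen several times in a pass and others not at all.</p>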

<h3 id="compare-result">Comparing the Results</h3>
<p>Let’s now compare the performance of SGD with replacement and without replacement.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
</pre></td><td class="rouge-code"><pre><span class="c1"># Error of objective on a logarithmic scale
</span><span class="n">plt</span><span class="p">.</span><span class="nf">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">yscale</span><span class="p">(</span><span class="sh">"</span><span class="s">log</span><span class="sh">"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">plot</span><span class="p">(</span><span class="n">obj_sgdss</span> <span class="o">-</span> <span class="n">obj_min</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD with replacement</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">plot</span><span class="p">(</span><span class="n">obj_sgdsw</span> <span class="o">-</span> <span class="n">obj_min</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD without replacement</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">title</span><span class="p">(</span><span class="sh">"</span><span class="s">Convergence plot</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">16</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">xlabel</span><span class="p">(</span><span class="sh">"</span><span class="s">#iterations</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">ylabel</span><span class="p">(</span><span class="sh">"</span><span class="s">Loss function</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">legend</span><span class="p">()</span>

<span class="c1"># Distance to the minimizer on a logarithmic scale
</span><span class="n">plt</span><span class="p">.</span><span class="nf">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">yscale</span><span class="p">(</span><span class="sh">"</span><span class="s">log</span><span class="sh">"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">plot</span><span class="p">(</span><span class="n">err_sgdss</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD with replacement</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">plot</span><span class="p">(</span><span class="n">err_sgdsw</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGD without replacement</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">title</span><span class="p">(</span><span class="sh">"</span><span class="s">Convergence plot</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">16</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">xlabel</span><span class="p">(</span><span class="sh">"</span><span class="s">#iterations</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">ylabel</span><span class="p">(</span><span class="sh">"</span><span class="s">Distance to the minimum</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">legend</span><span class="p">()</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<div class="row mt-3 justify-content-center">
    <div class="col-sm-8 mt-3 mt-md-0">
        <picture>
            <img src="/images/blog/without_replace1.png" width="400" alt="Convergence plot comparing SGD with and without replacement" />
        </picture>
    </div>
    <div class="col-sm mt-3 mt-md-0">
        <picture>
            <img src="/images/blog/without_replace2.png" width="400" alt="Distance to minimum plot comparing SGD with and without replacement" />
        </picture>
    </div>
</div>
<div class="caption">
    SGD with and without replacement: objective gap (left) and distance to the minimum on a logarithmic scale (right).
</div>

<p>In these plots, SGD without replacement gets closer to the minimum than SGD with replacement. A likely reason is that sampling without replacement uses every example exactly once per pass. As a result, no example is over- or under-represented within an epoch, which avoids redundant updates and tends to speed up convergence.</p>
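<p>You can see the difference between the two sampling schemes directly in the indices they produce. The sketch below is a minimal illustration with hypothetical sizes (<code class="language-plaintext highlighter-rouge">n = 1000</code>, three passes), using NumPy's <code class="language-plaintext highlighter-rouge">default_rng</code>. Unlike the experiment code in this post, which reuses a single permutation for every pass, the sketch draws a fresh permutation for each pass.</p>

```python
import numpy as np


def indices_with_replacement(n, n_passes, rng):
    """Draw n * n_passes sample indices independently and uniformly."""
    return rng.integers(0, n, size=n * n_passes)


def indices_without_replacement(n, n_passes, rng):
    """Concatenate one fresh random permutation of range(n) per pass."""
    return np.concatenate([rng.permutation(n) for _ in range(n_passes)])


rng = np.random.default_rng(0)
n, n_passes = 1000, 3  # hypothetical sizes
idx_wr = indices_with_replacement(n, n_passes, rng)
idx_wo = indices_without_replacement(n, n_passes, rng)

# Fraction of distinct samples visited during the first pass (first n steps)
print("with replacement:   ", len(np.unique(idx_wr[:n])) / n)  # about 1 - 1/e, roughly 0.63
print("without replacement:", len(np.unique(idx_wo[:n])) / n)  # exactly 1.0
```

<p>With replacement, roughly a third of the samples go unvisited in a given pass, while others are visited several times. Without replacement, every sample contributes exactly one gradient per pass.</p>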

<h2 id="comparing-gradient-descent-with-stochastic-gradient-descent">Comparing Gradient Descent with Stochastic Gradient Descent</h2>

<p>After looking at various forms of Stochastic Gradient Descent (SGD), it’s important to compare these results with the traditional <strong>Gradient Descent (GD)</strong> method.</p>

<h3 id="gradient-descent-implementation">Gradient Descent Implementation</h3>
<p>In each iteration of the gradient descent algorithm, the gradient is computed using the entire dataset, and the model’s weights are updated accordingly.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
</pre></td><td class="rouge-code"><pre><span class="k">def</span> <span class="nf">gd</span><span class="p">(</span><span class="n">w0</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span> <span class="n">w_min</span><span class="o">=</span><span class="p">[],</span> <span class="n">n_iter</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="bp">True</span><span class="p">):</span>
    <span class="sh">"""</span><span class="s">Gradient descent: one full-batch gradient step per iteration
    </span><span class="sh">"""</span>
    <span class="n">w</span> <span class="o">=</span> <span class="n">w0</span><span class="p">.</span><span class="nf">copy</span><span class="p">()</span>
    <span class="c1"># estimation error history
</span>    <span class="n">errors</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">err</span> <span class="o">=</span> <span class="mf">1.</span>
    <span class="c1"># objective history
</span>    <span class="n">objectives</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="c1"># Current estimation error
</span>    <span class="k">if</span> <span class="n">np</span><span class="p">.</span><span class="nf">any</span><span class="p">(</span><span class="n">w_min</span><span class="p">):</span>
        <span class="n">err</span> <span class="o">=</span> <span class="nf">norm</span><span class="p">(</span><span class="n">w</span> <span class="o">-</span> <span class="n">w_min</span><span class="p">)</span> <span class="o">/</span> <span class="nf">norm</span><span class="p">(</span><span class="n">w_min</span><span class="p">)</span>
        <span class="n">errors</span><span class="p">.</span><span class="nf">append</span><span class="p">(</span><span class="n">err</span><span class="p">)</span>
    <span class="c1"># Current objective
</span>    <span class="n">obj</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="nf">f</span><span class="p">(</span><span class="n">w</span><span class="p">)</span>
    <span class="n">objectives</span><span class="p">.</span><span class="nf">append</span><span class="p">(</span><span class="n">obj</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">verbose</span><span class="p">:</span>
        <span class="nf">print</span><span class="p">(</span><span class="sh">"</span><span class="s">Launching GD solver...</span><span class="sh">"</span><span class="p">)</span>
        <span class="nf">print</span><span class="p">(</span><span class="sh">'</span><span class="s"> | </span><span class="sh">'</span><span class="p">.</span><span class="nf">join</span><span class="p">([</span><span class="n">name</span><span class="p">.</span><span class="nf">center</span><span class="p">(</span><span class="mi">8</span><span class="p">)</span> <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="p">[</span><span class="sh">"</span><span class="s">it</span><span class="sh">"</span><span class="p">,</span> <span class="sh">"</span><span class="s">obj</span><span class="sh">"</span><span class="p">,</span> <span class="sh">"</span><span class="s">err</span><span class="sh">"</span><span class="p">]]))</span>
    <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">n_iter</span><span class="p">):</span>
        <span class="n">w</span><span class="p">[:]</span> <span class="o">=</span> <span class="n">w</span> <span class="o">-</span> <span class="n">step</span> <span class="o">*</span> <span class="n">model</span><span class="p">.</span><span class="nf">grad</span><span class="p">(</span><span class="n">w</span><span class="p">)</span>
        <span class="n">obj</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="nf">f</span><span class="p">(</span><span class="n">w</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">np</span><span class="p">.</span><span class="nf">any</span><span class="p">(</span><span class="n">w_min</span><span class="p">):</span>
            <span class="n">err</span> <span class="o">=</span> <span class="nf">norm</span><span class="p">(</span><span class="n">w</span> <span class="o">-</span> <span class="n">w_min</span><span class="p">)</span> <span class="o">/</span> <span class="nf">norm</span><span class="p">(</span><span class="n">w_min</span><span class="p">)</span>
            <span class="n">errors</span><span class="p">.</span><span class="nf">append</span><span class="p">(</span><span class="n">err</span><span class="p">)</span>
        <span class="n">objectives</span><span class="p">.</span><span class="nf">append</span><span class="p">(</span><span class="n">obj</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">verbose</span><span class="p">:</span>
            <span class="nf">print</span><span class="p">(</span><span class="sh">'</span><span class="s"> | </span><span class="sh">'</span><span class="p">.</span><span class="nf">join</span><span class="p">([(</span><span class="sh">"</span><span class="s">%d</span><span class="sh">"</span> <span class="o">%</span> <span class="n">k</span><span class="p">).</span><span class="nf">rjust</span><span class="p">(</span><span class="mi">8</span><span class="p">),</span>
                              <span class="p">(</span><span class="sh">"</span><span class="s">%.2e</span><span class="sh">"</span> <span class="o">%</span> <span class="n">obj</span><span class="p">).</span><span class="nf">rjust</span><span class="p">(</span><span class="mi">8</span><span class="p">),</span>
                              <span class="p">(</span><span class="sh">"</span><span class="s">%.2e</span><span class="sh">"</span> <span class="o">%</span> <span class="n">err</span><span class="p">).</span><span class="nf">rjust</span><span class="p">(</span><span class="mi">8</span><span class="p">)]))</span>
    <span class="k">return</span> <span class="n">w</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="nf">array</span><span class="p">(</span><span class="n">objectives</span><span class="p">),</span> <span class="n">np</span><span class="p">.</span><span class="nf">array</span><span class="p">(</span><span class="n">errors</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>
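<p>The <code class="language-plaintext highlighter-rouge">gd</code> function only assumes that <code class="language-plaintext highlighter-rouge">model</code> exposes <code class="language-plaintext highlighter-rouge">f(w)</code>, <code class="language-plaintext highlighter-rouge">grad(w)</code> and <code class="language-plaintext highlighter-rouge">lipschitz_constant()</code>. The model used in the post is not shown in this section. As a self-contained stand-in, here is a hypothetical least-squares model with that interface. It also includes an assumed per-sample gradient <code class="language-plaintext highlighter-rouge">grad_i</code>, the kind SGD needs, and a finite-difference check of the full gradient.</p>

```python
import numpy as np


class LeastSquares:
    """Hypothetical model f(w) = 1/(2n) * ||A w - b||^2 with the interface gd() expects."""

    def __init__(self, A, b):
        self.A, self.b = A, b
        self.n = A.shape[0]

    def f(self, w):
        r = self.A @ w - self.b
        return 0.5 * (r @ r) / self.n

    def grad(self, w):
        # Full-batch gradient (1/n) A^T (A w - b): costs n per-sample gradients
        return self.A.T @ (self.A @ w - self.b) / self.n

    def grad_i(self, i, w):
        # Gradient of the i-th sample's loss 0.5 * (a_i^T w - b_i)^2 (hypothetical name)
        a_i = self.A[i]
        return (a_i @ w - self.b[i]) * a_i

    def lipschitz_constant(self):
        # Largest eigenvalue of the Hessian A^T A / n, i.e. squared spectral norm of A over n
        return np.linalg.norm(self.A, ord=2) ** 2 / self.n


rng = np.random.default_rng(0)
A = rng.standard_normal((200, 5))
b = A @ rng.standard_normal(5) + 0.1 * rng.standard_normal(200)
model = LeastSquares(A, b)

# Central finite-difference check of grad() at a random point
w = rng.standard_normal(5)
eps = 1e-6
fd = np.array([(model.f(w + eps * e) - model.f(w - eps * e)) / (2 * eps) for e in np.eye(5)])
print(np.max(np.abs(fd - model.grad(w))))  # should be close to 0
```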

<p>To ensure stable convergence in Gradient Descent, we select the step size (<code class="language-plaintext highlighter-rouge">step</code>) as the inverse of the Lipschitz constant of the gradient:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
</pre></td><td class="rouge-code"><pre><span class="n">step</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">/</span> <span class="n">model</span><span class="p">.</span><span class="nf">lipschitz_constant</span><span class="p">()</span>
<span class="n">w_gd</span><span class="p">,</span> <span class="n">obj_gd</span><span class="p">,</span> <span class="n">err_gd</span> <span class="o">=</span> <span class="nf">gd</span><span class="p">(</span><span class="n">w0</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span> <span class="n">w_min</span><span class="p">,</span> <span class="n">datapasses</span><span class="p">)</span>
<span class="nf">print</span><span class="p">(</span><span class="n">obj_gd</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>
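<p>Why choose 1/L? For a convex function whose gradient is L-Lipschitz, a step of 1/L guarantees that the objective never increases. On a quadratic, by contrast, any constant step larger than 2/L makes the iterates diverge along the direction of largest curvature. The sketch below checks both facts on a hypothetical random least-squares problem (not the post's data), where L is the largest eigenvalue of AᵀA / n.</p>

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((300, 10))
b = rng.standard_normal(300)
n = A.shape[0]


def f(w):
    r = A @ w - b
    return 0.5 * (r @ r) / n


def grad(w):
    return A.T @ (A @ w - b) / n


L = np.linalg.norm(A, ord=2) ** 2 / n  # Lipschitz constant of grad f


def run_gd(step, n_iter=50):
    """Plain GD from w = 0, returning the objective after every iteration."""
    w = np.zeros(A.shape[1])
    objs = [f(w)]
    for _ in range(n_iter):
        w = w - step * grad(w)
        objs.append(f(w))
    return np.array(objs)


obj_safe = run_gd(1.0 / L)   # step 1/L: objective is non-increasing
obj_large = run_gd(2.5 / L)  # step above 2/L: iterates blow up along the top eigenvector
print(obj_safe[[0, -1]], obj_large[[0, -1]])
```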

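<p>To compare GD and SGD fairly, we measure cost in per-sample gradient evaluations rather than in iterations. One SGD iteration evaluates a single sample gradient, while each GD iteration evaluates all <code class="language-plaintext highlighter-rouge">n</code> of them. GD iterate <code class="language-plaintext highlighter-rouge">k</code> is therefore plotted at <code class="language-plaintext highlighter-rouge">k * n</code> on the shared x-axis:</p>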

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
</pre></td><td class="rouge-code"><pre><span class="n">complexityofGD</span> <span class="o">=</span> <span class="n">n</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="nf">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">datapasses</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></div></div>
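<p><code class="language-plaintext highlighter-rouge">gd</code> records the objective once before the loop and once per iteration. So <code class="language-plaintext highlighter-rouge">gd(w0, model, step, w_min, datapasses)</code> returns <code class="language-plaintext highlighter-rouge">datapasses + 1</code> values, exactly the length of <code class="language-plaintext highlighter-rouge">complexityofGD</code>. A small sketch with hypothetical sizes shows how the two methods line up on the shared axis:</p>

```python
import numpy as np

n, datapasses = 1000, 3  # hypothetical sizes

# Each GD iteration evaluates all n per-sample gradients,
# so GD iterate k sits at x = k * n on the shared axis
complexityofGD = n * np.arange(0, datapasses + 1)
print(complexityofGD.tolist())  # [0, 1000, 2000, 3000]

# SGD evaluates one sample gradient per iteration, so its iterate k sits at x = k
n_iter_sgd = datapasses * n
x_sgd = np.arange(0, n_iter_sgd + 1)
print(x_sgd[-1] == complexityofGD[-1])  # True: both end at the same budget
```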

<h3 id="compare-results">Compare Results</h3>

<p>Let’s now compare the performance of SGD with GD.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
</pre></td><td class="rouge-code"><pre><span class="c1"># Error of objective on a logarithmic scale
</span><span class="n">plt</span><span class="p">.</span><span class="nf">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">complexityofGD</span><span class="p">,</span> <span class="n">obj_gd</span> <span class="o">-</span> <span class="n">obj_min</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">gd</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">obj_sgdss</span> <span class="o">-</span> <span class="n">obj_min</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">sgd switch</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">obj_sgdm</span> <span class="o">-</span> <span class="n">obj_min</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">sgdm</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">title</span><span class="p">(</span><span class="sh">"</span><span class="s">Convergence plot</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">16</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">xlabel</span><span class="p">(</span><span class="sh">"</span><span class="s"># SGD iterations</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">ylabel</span><span class="p">(</span><span class="sh">"</span><span class="s">Loss function</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">legend</span><span class="p">()</span>

<span class="c1"># Distance to the minimum on a logarithmic scale
</span><span class="n">plt</span><span class="p">.</span><span class="nf">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">complexityofGD</span><span class="p">,</span> <span class="n">err_gd</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">gd</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">err_sgdss</span> <span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">sgd switch</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">err_sgdm</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">sgdm</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">title</span><span class="p">(</span><span class="sh">"</span><span class="s">Convergence plot</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">16</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">xlabel</span><span class="p">(</span><span class="sh">"</span><span class="s"># SGD iterations</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">ylabel</span><span class="p">(</span><span class="sh">"</span><span class="s">Distance to the minimum</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">legend</span><span class="p">()</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<div class="row mt-3">
    <div class="col-sm mt-3 mt-md-0">
        <picture>
            <img src="/images/blog/gd1.png" width="400" alt="Convergence plot comparing SGD and GD" />
        </picture>
    </div>
    <div class="col-sm mt-3 mt-md-0">
        <picture>
            <img src="/images/blog/gd2.png" width="400" alt="Distance to minimum plot comparing SGD and GD" />
        </picture>
    </div>
</div>
<div class="caption">
    SGD variants vs. GD: objective gap (left) and distance to the minimum (right), plotted against the number of per-sample gradient evaluations.
</div>

<p>In these plots, the SGD variants reach a given loss with fewer per-sample gradient evaluations than GD. The gap is largest during the first few passes, which is what matters most on large datasets. GD converges more smoothly, but each of its steps costs a full pass over the data.</p>
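<p>The same effect shows up on a toy problem. This minimal sketch rests on stated assumptions: a hypothetical, well-conditioned least-squares problem, compared as one full GD step with step 1/L versus one pass of plain SGD with a conservative constant step. Both cost <code class="language-plaintext highlighter-rouge">n</code> per-sample gradients. It is an illustration, not a re-run of the experiment above.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 1000, 10
A = rng.standard_normal((n, d))
w_true = rng.standard_normal(d)
b = A @ w_true + 0.1 * rng.standard_normal(n)


def f(w):
    r = A @ w - b
    return 0.5 * (r @ r) / n


w_star = np.linalg.lstsq(A, b, rcond=None)[0]  # exact minimizer, for reference
f_star = f(w_star)

# GD: one full-batch step with step 1/L costs n sample gradients
L = np.linalg.norm(A, ord=2) ** 2 / n
w_gd = np.zeros(d)
w_gd = w_gd - (1.0 / L) * (A.T @ (A @ w_gd - b) / n)

# SGD: n single-sample steps over one random permutation, same budget
step = 1.0 / (2 * np.max(np.sum(A**2, axis=1)))  # conservative constant step
w_sgd = np.zeros(d)
for i in rng.permutation(n):
    w_sgd = w_sgd - step * (A[i] @ w_sgd - b[i]) * A[i]

gap_gd, gap_sgd = f(w_gd) - f_star, f(w_sgd) - f_star
print(f"after one pass: GD gap {gap_gd:.2e}, SGD gap {gap_sgd:.2e}")
```

<p>With a constant step, SGD eventually stalls at a noise floor, while GD keeps converging linearly. On a well-conditioned problem like this one, GD typically catches up after a few more passes.</p>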

<h2 id="comparing-test-error-gradient-descent-vs-sgd-with-momentum">Comparing Test Error: Gradient Descent vs. SGD with Momentum</h2>

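<p>In this final comparison, we focus on the test error, which indicates how well each model generalizes to unseen data. Here the error is measured against the true model weights <code class="language-plaintext highlighter-rouge">w_model_truth</code> rather than the empirical minimizer <code class="language-plaintext highlighter-rouge">w_min</code>. For example, <code class="language-plaintext highlighter-rouge">gd</code> reports the relative distance <code class="language-plaintext highlighter-rouge">norm(w - w_model_truth) / norm(w_model_truth)</code> at every step.</p>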

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
</pre></td><td class="rouge-code"><pre><span class="n">datapasses</span> <span class="o">=</span> <span class="mi">30</span>
<span class="n">n_iter</span> <span class="o">=</span> <span class="nf">int</span><span class="p">(</span><span class="n">datapasses</span> <span class="o">*</span> <span class="n">n</span><span class="p">)</span>
<span class="c1"># Without replacement: one random permutation of the n samples, repeated for every pass</span>
<span class="n">perm</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">choice</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">replace</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="n">indices_wo</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nf">tile</span><span class="p">(</span><span class="n">perm</span><span class="p">,</span> <span class="n">datapasses</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)[:</span><span class="n">n_iter</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span>
<span class="c1"># Decreasing step sizes 0.25 / sqrt(k) for k = 1, ..., n_iter + 1</span>
<span class="n">steps</span> <span class="o">=</span> <span class="mf">0.25</span> <span class="o">/</span> <span class="n">np</span><span class="p">.</span><span class="nf">sqrt</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nf">arange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">n_iter</span> <span class="o">+</span> <span class="mi">2</span><span class="p">))</span>

<span class="c1"># With replacement: n_iter + 1 indices drawn independently and uniformly</span>
<span class="n">indices</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="nf">choice</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">n_iter</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">replace</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">w_sgdar</span><span class="p">,</span> <span class="n">obj_sgdar</span><span class="p">,</span> <span class="n">err_sgdart</span> <span class="o">=</span> <span class="nf">sgd</span><span class="p">(</span><span class="n">w0</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">steps_switch</span><span class="p">,</span> <span class="n">w_model_truth</span><span class="p">,</span> <span class="n">n_iter</span><span class="p">,</span> <span class="bp">True</span><span class="p">,</span> <span class="bp">False</span><span class="p">,</span> <span class="mi">3</span><span class="o">*</span><span class="n">n_iter</span><span class="o">/</span><span class="mi">4</span><span class="p">)</span> <span class="c1"># (datapasses-5)*n</span>

<span class="n">w_sgdsw</span><span class="p">,</span> <span class="n">obj_sgdsw</span><span class="p">,</span> <span class="n">err_sgdswt</span> <span class="o">=</span> <span class="nf">sgd</span><span class="p">(</span><span class="n">w0</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">indices_wo</span><span class="p">,</span> <span class="n">steps</span><span class="p">,</span> <span class="n">w_model_truth</span><span class="p">,</span> <span class="n">n_iter</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>

<span class="c1">## GD with constant step size 1/L</span>
<span class="n">step</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">/</span> <span class="n">model</span><span class="p">.</span><span class="nf">lipschitz_constant</span><span class="p">()</span>
<span class="n">w_gd</span><span class="p">,</span> <span class="n">obj_gd</span><span class="p">,</span> <span class="n">err_gdt</span> <span class="o">=</span> <span class="nf">gd</span><span class="p">(</span><span class="n">w0</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">step</span><span class="p">,</span> <span class="n">w_model_truth</span><span class="p">,</span> <span class="n">datapasses</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="n">complexityofGD</span> <span class="o">=</span> <span class="n">n</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="nf">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">datapasses</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>

<span class="c1">## SGD with momentum</span>
<span class="n">averaging_on</span> <span class="o">=</span> <span class="bp">True</span>
<span class="n">start_late_averaging</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="n">momentum</span> <span class="o">=</span> <span class="mf">1.0</span>
<span class="n">w_sgdm</span><span class="p">,</span> <span class="n">obj_sgdm</span><span class="p">,</span> <span class="n">err_sgdmt</span> <span class="o">=</span> <span class="nf">sgd</span><span class="p">(</span><span class="n">w0</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">steps_switch</span><span class="p">,</span> <span class="n">w_model_truth</span><span class="p">,</span> <span class="n">n_iter</span><span class="p">,</span> <span class="n">averaging_on</span><span class="p">,</span> <span class="n">momentum</span><span class="p">,</span> <span class="bp">True</span><span class="p">,</span> <span class="n">start_late_averaging</span><span class="p">)</span> <span class="c1"># (datapasses-5)*n</span>
</pre></td></tr></tbody></table></code></pre></div></div>
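<p>This section does not show how the <code class="language-plaintext highlighter-rouge">sgd</code> function implements its momentum option. For reference, the sketch below shows one common formulation, heavy-ball momentum (v &lt;- βv + g, w &lt;- w - γv), on a hypothetical least-squares problem. It tracks the same kind of relative error to the true weights that the test-error plot shows. The step size, β = 0.9 and the problem sizes are assumptions chosen for illustration and may differ from the post's settings.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 500, 5
A = rng.standard_normal((n, d))
w_true = rng.standard_normal(d)
b = A @ w_true + 0.1 * rng.standard_normal(n)


def rel_error(w):
    # Relative distance to the true weights, the quantity plotted as "test error"
    return np.linalg.norm(w - w_true) / np.linalg.norm(w_true)


def sgd_momentum(step, beta, n_passes):
    """Heavy-ball SGD: v <- beta * v + g_i(w), then w <- w - step * v."""
    w, v = np.zeros(d), np.zeros(d)
    errors = [rel_error(w)]
    for _ in range(n_passes):
        for i in rng.permutation(n):
            g = (A[i] @ w - b[i]) * A[i]  # single-sample least-squares gradient
            v = beta * v + g
            w = w - step * v
            errors.append(rel_error(w))
    return w, np.array(errors)


w_m, err_m = sgd_momentum(step=0.002, beta=0.9, n_passes=5)
print(f"relative error: start {err_m[0]:.2f}, end {err_m[-1]:.3f}")
```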

<h3 id="compare-result-1">Compare Result</h3>

<p>Let’s compare the test error of Gradient Descent (GD) and Stochastic Gradient Descent with Momentum (SGDm).</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><table class="rouge-table"><tbody><tr><td class="rouge-gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
</pre></td><td class="rouge-code"><pre><span class="c1"># Relative distance to the true weights w_model_truth (test error), log scale</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">yscale</span><span class="p">(</span><span class="sh">"</span><span class="s">log</span><span class="sh">"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">complexityofGD</span><span class="p">,</span> <span class="n">err_gdt</span> <span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">GD</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="c1"># plt.semilogy(err_sgdswt, label="SGD without replacement", lw=2)
# plt.semilogy(err_sgdart , label="SGD averaging end", lw=2)
</span><span class="n">plt</span><span class="p">.</span><span class="nf">semilogy</span><span class="p">(</span><span class="n">err_sgdmt</span><span class="p">,</span>  <span class="n">label</span><span class="o">=</span><span class="sh">"</span><span class="s">SGDm</span><span class="sh">"</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">title</span><span class="p">(</span><span class="sh">"</span><span class="s">Convergence plot</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">16</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">xlabel</span><span class="p">(</span><span class="sh">"</span><span class="s">#iterations</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">ylabel</span><span class="p">(</span><span class="sh">"</span><span class="s">Test error</span><span class="sh">"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="nf">legend</span><span class="p">()</span>
</pre></td></tr></tbody></table></code></pre></div></div>

<div class="row mt-3">
    <div class="col-sm mt-3 mt-md-0">
        <picture>
            <img src="/images/blog/sgd_gd.png" width="400" alt="Test error comparison between SGD and GD" />
        </picture>
    </div>
</div>
<div class="caption">
    Test error of GD and SGD with momentum (SGDm).
</div>

<p>In this experiment, SGDm converges faster than GD and also reaches a lower final test error. This suggests that SGDm generalizes better on this problem, which makes it an attractive choice when test performance matters.</p>

<h2 id="conclusion">Conclusion</h2>

<p>Comparing these methods with Gradient Descent highlights the practical advantages of SGD, particularly on large-scale datasets where computational efficiency is key. In our final test-error comparison, SGD with momentum converged faster than GD and reached a lower final error, making it the strongest of the methods compared here.</p>]]></content><author><name></name></author><category term="optimization" /><category term="stochastic-gradient" /><category term="gradient-descent" /><category term="momentum" /><summary type="html"><![CDATA[In this post, we implemented stochastic gradient descent in Python, one of the most widely used methods for training ML models. The implementation covers several SGD variants, including constant and shrinking step sizes, momentum, and averaging, and compares how each one affects the speed and accuracy of the model's convergence.]]></summary></entry></feed>