{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0b6f2c1e-7a4d-4e38-9c51-1f0e2d3a4b5c",
   "metadata": {},
   "source": [
    "# GTSRB traffic-sign classifier: a two-layer NumPy network\n",
    "\n",
    "A from-scratch fully connected network (1024 inputs -> 10 ReLU units -> 43 softmax outputs) trained with full-batch gradient descent.\n",
    "\n",
    "- Data: `gtsrb_data_test.csv`; each row is an integer class label followed by 1024 feature values (presumably flattened 32x32 pixels - confirm).\n",
    "- The first `DEV_SIZE` rows are held out as `X_dev` / `Y_dev`; the remaining rows are used for training.\n",
    "- Outputs: `gt_acc.csv` (training-accuracy history) and `gt_weights.npz` (trained parameters)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ea3e9a09-3257-4b1a-9245-d42bbf88d06b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd \n",
    "from matplotlib import pyplot as plt \n",
    "\n",
    "DATA_PATH = 'gtsrb_data_test.csv'\n",
    "DEV_SIZE = 1000  # number of leading rows held out from training\n",
    "SEED = 0\n",
    "np.random.seed(SEED)  # init_params draws from NumPy's global RNG; seed it for reproducible runs\n",
    "\n",
    "data_df = pd.read_csv(DATA_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fe7852fc-5353-478d-9c7e-ad5fd251d963",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "data = data_df.to_numpy()\n",
    "m, n = data.shape  # m examples (rows), n columns (label + features)\n",
    "\n",
    "# NOTE(review): rows are not shuffled before the split; if the CSV is ordered by\n",
    "# class, the held-out slice is not representative - confirm the file's ordering.\n",
    "data_dev = data[:DEV_SIZE].T  # transpose: one example per column\n",
    "Y_dev = data_dev[0].astype(int)\n",
    "X_dev = data_dev[1:n]\n",
    "\n",
    "data_train = data[DEV_SIZE:m].T\n",
    "Y_train = data_train[0].astype(int)\n",
    "X_train = data_train[1:n]\n",
    "\n",
    "assert X_train.shape[0] == 1024, 'expected 1024 feature columns after the label column'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "42435b4b-6390-4c20-b04a-cec2730879fd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def init_params(n_in=1024, n_hidden=10, n_out=43):\n",
    "    \"\"\"Return randomly initialised weights and biases for the two-layer network.\n",
    "\n",
    "    Every entry is drawn uniformly from [-0.5, 0.5) using NumPy's global RNG.\n",
    "\n",
    "    Args:\n",
    "        n_in: number of input features (rows of X).\n",
    "        n_hidden: number of hidden ReLU units.\n",
    "        n_out: number of output classes.\n",
    "\n",
    "    Returns:\n",
    "        W1 (n_hidden, n_in), b1 (n_hidden, 1), W2 (n_out, n_hidden), b2 (n_out, 1).\n",
    "    \"\"\"\n",
    "    W1 = np.random.rand(n_hidden, n_in) - 0.5\n",
    "    b1 = np.random.rand(n_hidden, 1) - 0.5\n",
    "    W2 = np.random.rand(n_out, n_hidden) - 0.5\n",
    "    b2 = np.random.rand(n_out, 1) - 0.5\n",
    "    return W1, b1, W2, b2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f9da376f-6b2a-4c7c-b392-c34f7dc7dee4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def ReLU(Z):\n",
    "    \"\"\"Element-wise rectified linear unit: max(Z, 0).\"\"\"\n",
    "    return np.maximum(Z, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6cd98c98-9f06-4362-a3cd-a0a9b168ea97",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def softmax(Z):\n",
    "    \"\"\"Column-wise softmax: each column of Z becomes a probability vector.\n",
    "\n",
    "    The column maximum is subtracted before exponentiating so that large logits\n",
    "    cannot overflow to inf (inf / inf gives NaN probabilities); the result is\n",
    "    mathematically unchanged.\n",
    "    \"\"\"\n",
    "    expZ = np.exp(Z - np.max(Z, axis=0, keepdims=True))\n",
    "    return expZ / np.sum(expZ, axis=0, keepdims=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "039f4e5d-9d7c-4f59-81d6-f64310cc5b1f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def forward_prop(W1, b1, W2, b2, X):\n",
    "    \"\"\"Forward pass: X -> affine -> ReLU -> affine -> softmax.\n",
    "\n",
    "    X holds one example per column. Returns (Z1, A1, Z2, A2): the\n",
    "    pre-activations and activations of both layers; A2 holds class\n",
    "    probabilities, one column per example.\n",
    "    \"\"\"\n",
    "    Z1 = W1.dot(X) + b1\n",
    "    A1 = ReLU(Z1)\n",
    "    Z2 = W2.dot(A1) + b2\n",
    "    A2 = softmax(Z2)\n",
    "    return Z1, A1, Z2, A2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": 
"24588282-2ec1-4c64-9b9d-4e7d395449fb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def ReLU_deriv(Z):\n",
    "    \"\"\"Derivative of ReLU: boolean mask that is True where Z > 0.\"\"\"\n",
    "    return Z > 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d4977d9c-38c9-4758-9262-dd03b8bbd015",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def one_hot(Y, num_classes=None):\n",
    "    \"\"\"One-hot encode integer labels, one example per column.\n",
    "\n",
    "    Args:\n",
    "        Y: 1-D array of non-negative integer class labels.\n",
    "        num_classes: number of rows in the result. Defaults to Y.max() + 1;\n",
    "            pass it explicitly when Y might not contain the highest class so the\n",
    "            encoding still lines up with the network's output layer.\n",
    "\n",
    "    Returns:\n",
    "        Array of shape (num_classes, Y.size) with a single 1 in each column.\n",
    "    \"\"\"\n",
    "    if num_classes is None:\n",
    "        num_classes = Y.max() + 1\n",
    "    one_hot_Y = np.zeros((num_classes, Y.size))\n",
    "    one_hot_Y[Y, np.arange(Y.size)] = 1\n",
    "    return one_hot_Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "941b2ec2-980f-4577-abe3-d13ac252e86b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y):\n",
    "    \"\"\"Gradients of the mean cross-entropy loss with respect to every parameter.\n",
    "\n",
    "    Args:\n",
    "        Z1, A1, Z2, A2: cached values from forward_prop.\n",
    "        W1, W2: current weights (W1 is not used; kept for a uniform signature).\n",
    "        X: inputs, one example per column.\n",
    "        Y: 1-D integer labels, one per column of X.\n",
    "\n",
    "    Returns:\n",
    "        dW1, db1, dW2, db2 with the same shapes as W1, b1, W2, b2.\n",
    "    \"\"\"\n",
    "    # Average over the examples actually in X rather than a global row count.\n",
    "    n_samples = X.shape[1]\n",
    "    # Size the encoding from the output layer so it matches A2 even if a class is absent from Y.\n",
    "    one_hot_Y = one_hot(Y, num_classes=A2.shape[0])\n",
    "    dZ2 = A2 - one_hot_Y\n",
    "    dW2 = dZ2.dot(A1.T) / n_samples\n",
    "    # Sum over examples only (axis=1) so each unit gets its own bias gradient;\n",
    "    # summing over all axes would collapse it to one scalar shared by every unit.\n",
    "    db2 = np.sum(dZ2, axis=1, keepdims=True) / n_samples\n",
    "    dZ1 = W2.T.dot(dZ2) * ReLU_deriv(Z1)\n",
    "    dW1 = dZ1.dot(X.T) / n_samples\n",
    "    db1 = np.sum(dZ1, axis=1, keepdims=True) / n_samples\n",
    "    return dW1, db1, dW2, db2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "20ab4235-86ac-4708-8b16-8d0169ecd97f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha):\n",
    "    \"\"\"One gradient-descent step: return each parameter minus alpha times its gradient.\"\"\"\n",
    "    W1 = W1 - alpha * dW1\n",
    "    b1 = b1 - alpha * db1\n",
    "    W2 = W2 - alpha * dW2\n",
    "    b2 = b2 - alpha * db2\n",
    "    return W1, b1, W2, b2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b82641d2-7ebb-4adb-9c40-3a6a28ccb49b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_predictions(A2):\n",
    "    \"\"\"Predicted class per example: row index of the largest value in each column.\"\"\"\n",
    "    return np.argmax(A2, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "2f38d2cc-316a-48f3-a275-57253249f132",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_accuracy(predictions, Y):\n",
    "    \"\"\"Fraction of positions where predictions equal the labels Y.\"\"\"\n",
    "    return np.sum(predictions == Y) / Y.size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c46ae58a-9362-4f7d-8ce4-d668230b1647",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "acc_store = []  # training-accuracy checkpoints, appended to by gradient_descent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": 
"54095b38-ca0a-41e6-9e33-c7b484e73c86",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def gradient_descent(X, Y, alpha, iterations, log_every=10):\n",
    "    \"\"\"Train the network with full-batch gradient descent.\n",
    "\n",
    "    Every `log_every` iterations the training accuracy of that iteration's\n",
    "    forward pass (computed before the parameter update) is printed and\n",
    "    appended to the global `acc_store` list.\n",
    "\n",
    "    Args:\n",
    "        X: training inputs, one example per column.\n",
    "        Y: 1-D integer labels for the columns of X.\n",
    "        alpha: learning rate.\n",
    "        iterations: number of gradient steps.\n",
    "        log_every: logging/recording interval in iterations (default 10).\n",
    "\n",
    "    Returns:\n",
    "        The trained parameters W1, b1, W2, b2.\n",
    "    \"\"\"\n",
    "    W1, b1, W2, b2 = init_params()\n",
    "    for i in range(iterations):\n",
    "        Z1, A1, Z2, A2 = forward_prop(W1, b1, W2, b2, X)\n",
    "        dW1, db1, dW2, db2 = backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y)\n",
    "        W1, b1, W2, b2 = update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha)\n",
    "        if i % log_every == 0:\n",
    "            print(\"Iteration: \", i)\n",
    "            # A2 is from this iteration's forward pass, i.e. before the update above.\n",
    "            pred = get_accuracy(get_predictions(A2), Y)\n",
    "            print(pred)\n",
    "            acc_store.append(pred)\n",
    "    return W1, b1, W2, b2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e7c9d14-5b2a-4f60-8a1d-6c0b9e2f4a71",
   "metadata": {},
   "source": [
    "## Train and save\n",
    "\n",
    "A previous run of the cell below (`alpha = 0.10`, 4000 iterations) printed training accuracy every 10 iterations. Accuracy rose from about 2.5% at iteration 0 to about 73.6% at iteration 3990, with intermittent dips along the way (e.g. 48.9% at iteration 1680 to 45.8% at iteration 1690). That 400-checkpoint printed log was cleared from the saved notebook to keep it readable; re-run the cell to regenerate it. The same history is written to `gt_acc.csv`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfa77e8c-04e4-4d1e-85f0-ea6cdd71e924",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "ALPHA = 0.10\n",
    "ITERATIONS = 4000\n",
    "\n",
    "acc_store.clear()  # start a fresh history so re-running this cell does not append to the previous run\n",
    "W1, b1, W2, b2 = gradient_descent(X_train, Y_train, ALPHA, ITERATIONS)\n",
    "\n",
    "df = pd.DataFrame(acc_store)\n",
    "df.to_csv('gt_acc.csv', index=False)\n",
    "# np.savez stores positional arrays under the keys arr_0..arr_3 (W1, b1, W2, b2 in that order).\n",
    "np.savez(\"gt_weights\", W1, b1, W2, b2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}