{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "25c0d153-288c-4ee8-a968-915f853b8157", "metadata": { "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd \n", "from matplotlib import pyplot as plt \n", "\n", "data = pd.read_csv('cro_data_test.csv')" ] }, { "cell_type": "code", "execution_count": 2, "id": "962cacc2-c818-4c5b-bdab-2ee46c6de511", "metadata": { "tags": [] }, "outputs": [], "source": [ "data = np.array(data)\n", "\n", "m,n = data.shape\n", "data_train = data[1000:m].T\n", "\n", "Y_train = data_train[0].astype(int)\n", "\n", "X_train = data_train[1:n]" ] }, { "cell_type": "code", "execution_count": 3, "id": "e863fe3b-3ee6-42f3-b716-4fcda6a850af", "metadata": { "tags": [] }, "outputs": [], "source": [ "def init_params():\n", " W1 = np.random.rand(10,1024) - 0.5\n", " b1 = np.random.rand(10,1) - 0.5\n", " W2 = np.random.rand(5,10) - 0.5\n", " b2 = np.random.rand(5,1) - 0.5\n", " return W1, b1 , W2, b2" ] }, { "cell_type": "code", "execution_count": 4, "id": "64dd0fba-a49e-4f13-b534-e074350b5f42", "metadata": { "tags": [] }, "outputs": [], "source": [ "def ReLU(Z):\n", " return np.maximum(Z,0)\n", "def softmax(Z):\n", " A = np.exp(Z) / sum(np.exp(Z))\n", " return A\n", "def forward_prop(W1, b1, W2, b2, X):\n", " Z1 = W1.dot(X) + b1\n", " A1 = ReLU(Z1)\n", " Z2 = W2.dot(A1) + b2\n", " A2 = softmax(Z2)\n", " return Z1, A1, Z2, A2\n", "def ReLU_deriv(Z):\n", " return Z > 0\n", "def one_hot(Y):\n", " one_hot_Y = np.zeros((Y.size, Y.max() + 1))\n", " one_hot_Y[np.arange(Y.size), Y] = 1\n", " one_hot_Y = one_hot_Y.T\n", " return one_hot_Y\n", "def backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y):\n", " one_hot_Y = one_hot(Y)\n", " dZ2 = A2 - one_hot_Y\n", " dW2 = 1 / m * dZ2.dot(A1.T)\n", " db2 = 1 / m * np.sum(dZ2)\n", " dZ1 = W2.T.dot(dZ2) * ReLU_deriv(Z1)\n", " dW1 = 1 / m * dZ1.dot(X.T)\n", " db1 = 1 / m * np.sum(dZ1)\n", " return dW1, db1, dW2, db2\n", "def update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha):\n", " W1 = W1 - alpha * dW1\n", " b1 = b1 - alpha * db1 \n", " W2 = W2 - alpha * dW2 \n", " b2 = b2 - alpha * db2 \n", " return W1, b1, W2, b2\n", "def get_predictions(A2):\n", " return np.argmax(A2, 0)\n", "def get_accuracy(predictions, Y):\n", " #print(predictions, Y)\n", " return np.sum(predictions == Y) / Y.size" ] }, { "cell_type": "code", "execution_count": 5, "id": "e7ef6234-254e-47f6-ac29-ddd92d363e9e", "metadata": { "tags": [] }, "outputs": [], "source": [ "acc_store = [] " ] }, { "cell_type": "code", "execution_count": 6, "id": "d24bdd4d-1d57-40b1-a95b-3cc33e02312d", "metadata": { "tags": [] }, "outputs": [], "source": [ "def gradient_descent(X, Y, alpha, iterations):\n", " W1, b1, W2, b2 = init_params()\n", " for i in range(iterations):\n", " Z1, A1, Z2, A2 = forward_prop(W1, b1, W2, b2, X)\n", " dW1, db1, dW2, db2 = backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y)\n", " W1, b1, W2, b2 = update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha)\n", " if i % 10 == 0:\n", " print(\"Iteration: \", i)\n", " predictions = get_predictions(A2)\n", " pred = get_accuracy(predictions, Y)\n", " print(pred)\n", " acc_store.append(pred)\n", " return W1, b1, W2, b2" ] }, { "cell_type": "code", "execution_count": 7, "id": "d266b8d3-8f15-4d89-a896-a728215b048d", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration: 0\n", "0.2016308376575241\n", "Iteration: 10\n", "0.3700889547813195\n", "Iteration: 20\n", "0.45978502594514453\n", "Iteration: 30\n", "0.519644180874722\n", "Iteration: 40\n", "0.4792438843587843\n", "Iteration: 50\n", "0.49833209785025945\n", "Iteration: 60\n", "0.544477390659748\n", "Iteration: 70\n", "0.5804299481097109\n", "Iteration: 80\n", "0.6250926612305412\n", "Iteration: 90\n", "0.6653076352853966\n", "Iteration: 100\n", "0.6955151964418087\n", "Iteration: 110\n", "0.7168272794662713\n", "Iteration: 120\n", "0.7294292068198666\n", "Iteration: 130\n", "0.7366567828020756\n", "Iteration: 140\n", "0.7462935507783544\n", "Iteration: 150\n", "0.7544477390659748\n", "Iteration: 160\n", "0.758524833209785\n", "Iteration: 170\n", "0.7626019273535952\n", "Iteration: 180\n", "0.7677909562638991\n", "Iteration: 190\n", "0.7729799851742031\n", "Iteration: 200\n", "0.7755744996293551\n", "Iteration: 210\n", "0.7792809488510007\n", "Iteration: 220\n", "0.7824314306893996\n", "Iteration: 230\n", "0.7857672349888807\n", "Iteration: 240\n", "0.7889177168272795\n", "Iteration: 250\n", "0.7909562638991846\n", "Iteration: 260\n", "0.7935507783543365\n", "Iteration: 270\n", "0.7968865826538176\n", "Iteration: 280\n", "0.7996664195700519\n", "Iteration: 290\n", "0.8011489992587102\n", "Iteration: 300\n", "0.8026315789473685\n", "Iteration: 310\n", "0.8052260934025204\n", "Iteration: 320\n", "0.8081912527798369\n", "Iteration: 330\n", "0.8096738324684952\n", "Iteration: 340\n", "0.8118977020014826\n", "Iteration: 350\n", "0.8139362490733877\n", "Iteration: 360\n", "0.8148628613787992\n", "Iteration: 370\n", "0.8169014084507042\n", "Iteration: 380\n", "0.8181986656782803\n", "Iteration: 390\n", "0.8191252779836916\n", "Iteration: 400\n", "0.8204225352112676\n", "Iteration: 410\n", "0.8215344699777613\n", "Iteration: 420\n", "0.8224610822831727\n", "Iteration: 430\n", "0.825240919199407\n", "Iteration: 440\n", "0.8259822090437361\n", "Iteration: 450\n", "0.8280207561156412\n", "Iteration: 460\n", "0.8293180133432172\n", "Iteration: 470\n", "0.8306152705707932\n", "Iteration: 480\n", "0.8317272053372868\n", "Iteration: 490\n", "0.8313565604151223\n", "Iteration: 500\n", "0.8317272053372868\n", "Iteration: 510\n", "0.8320978502594515\n", "Iteration: 520\n", "0.8335804299481097\n", "Iteration: 530\n", "0.8348776871756857\n", "Iteration: 540\n", "0.8356189770200149\n", "Iteration: 550\n", "0.836360266864344\n", "Iteration: 560\n", "0.8367309117865085\n", "Iteration: 570\n", "0.8372868791697554\n", "Iteration: 580\n", "0.8395107487027428\n", "Iteration: 590\n", "0.8400667160859896\n", "Iteration: 600\n", "0.8404373610081541\n", "Iteration: 610\n", "0.8393254262416605\n", "Iteration: 620\n", "0.8400667160859896\n", "Iteration: 630\n", "0.8408080059303188\n", "Iteration: 640\n", "0.8408080059303188\n", "Iteration: 650\n", "0.8409933283914011\n", "Iteration: 660\n", "0.8421052631578947\n", "Iteration: 670\n", "0.8432171979243884\n", "Iteration: 680\n", "0.8432171979243884\n", "Iteration: 690\n", "0.843587842846553\n", "Iteration: 700\n", "0.8441438102297999\n", "Iteration: 710\n", "0.8445144551519644\n", "Iteration: 720\n", "0.8445144551519644\n", "Iteration: 730\n", "0.8445144551519644\n", "Iteration: 740\n", "0.844885100074129\n", "Iteration: 750\n", "0.844885100074129\n", "Iteration: 760\n", "0.8452557449962935\n", "Iteration: 770\n", "0.8459970348406227\n", "Iteration: 780\n", "0.8465530022238695\n", "Iteration: 790\n", "0.8467383246849518\n", "Iteration: 800\n", "0.8471089696071163\n", "Iteration: 810\n", "0.8476649369903633\n", "Iteration: 820\n", "0.8491475166790216\n", "Iteration: 830\n", "0.8502594514455152\n", "Iteration: 840\n", "0.8502594514455152\n", "Iteration: 850\n", "0.8511860637509266\n", "Iteration: 860\n", "0.8517420311341735\n", "Iteration: 870\n", "0.8517420311341735\n", "Iteration: 880\n", "0.8519273535952557\n", "Iteration: 890\n", "0.8519273535952557\n", "Iteration: 900\n", "0.8532246108228317\n", "Iteration: 910\n", "0.8539659006671608\n", "Iteration: 920\n", "0.85470719051149\n", "Iteration: 930\n", "0.85470719051149\n", "Iteration: 940\n", "0.8548925129725723\n", "Iteration: 950\n", "0.8556338028169014\n", "Iteration: 960\n", "0.8563750926612306\n", "Iteration: 970\n", "0.8565604151223128\n", "Iteration: 980\n", "0.8571163825055597\n", "Iteration: 990\n", "0.8567457375833951\n", "Iteration: 1000\n", "0.8578576723498889\n", "Iteration: 1010\n", "0.8580429948109711\n", "Iteration: 1020\n", "0.8582283172720534\n", "Iteration: 1030\n", "0.8585989621942179\n", "Iteration: 1040\n", "0.8587842846553002\n", "Iteration: 1050\n", "0.8591549295774648\n", "Iteration: 1060\n", "0.8595255744996294\n", "Iteration: 1070\n", "0.8595255744996294\n", "Iteration: 1080\n", "0.8597108969607117\n", "Iteration: 1090\n", "0.860637509266123\n", "Iteration: 1100\n", "0.8626760563380281\n", "Iteration: 1110\n", "0.8628613787991104\n", "Iteration: 1120\n", "0.8630467012601928\n", "Iteration: 1130\n", "0.863232023721275\n", "Iteration: 1140\n", "0.8630467012601928\n", "Iteration: 1150\n", "0.863232023721275\n", "Iteration: 1160\n", "0.8641586360266864\n", "Iteration: 1170\n", "0.8648999258710156\n", "Iteration: 1180\n", "0.8647146034099333\n", "Iteration: 1190\n", "0.8654558932542624\n", "Iteration: 1200\n", "0.8650852483320979\n", "Iteration: 1210\n", "0.8652705707931801\n", "Iteration: 1220\n", "0.8661971830985915\n", "Iteration: 1230\n", "0.8667531504818384\n", "Iteration: 1240\n", "0.8669384729429207\n", "Iteration: 1250\n", "0.8665678280207562\n", "Iteration: 1260\n", "0.8665678280207562\n", "Iteration: 1270\n", "0.8663825055596739\n", "Iteration: 1280\n", "0.865826538176427\n", "Iteration: 1290\n", "0.8665678280207562\n", "Iteration: 1300\n", "0.8682357301704967\n", "Iteration: 1310\n", "0.8700889547813195\n", "Iteration: 1320\n", "0.8721275018532246\n", "Iteration: 1330\n", "0.8723128243143069\n", "Iteration: 1340\n", "0.8723128243143069\n", "Iteration: 1350\n", "0.8702742772424018\n", "Iteration: 1360\n", "0.8699036323202373\n", "Iteration: 1370\n", "0.8680504077094143\n", "Iteration: 1380\n", "0.8680504077094143\n", "Iteration: 1390\n", "0.8691623424759081\n", "Iteration: 1400\n", "0.8713862120088954\n", "Iteration: 1410\n", "0.873054114158636\n", "Iteration: 1420\n", "0.874351371386212\n", "Iteration: 1430\n", "0.8758339510748703\n", "Iteration: 1440\n", "0.8763899184581171\n", "Iteration: 1450\n", "0.8763899184581171\n", "Iteration: 1460\n", "0.8762045959970348\n", "Iteration: 1470\n", "0.8745366938472943\n", "Iteration: 1480\n", "0.8724981467753892\n", "Iteration: 1490\n", "0.8702742772424018\n", "Iteration: 1500\n", "0.8710155670867309\n", "Iteration: 1510\n", "0.873054114158636\n", "Iteration: 1520\n", "0.8736100815418829\n", "Iteration: 1530\n", "0.8739807264640475\n", "Iteration: 1540\n", "0.8747220163083765\n", "Iteration: 1550\n", "0.8750926612305412\n", "Iteration: 1560\n", "0.8752779836916235\n", "Iteration: 1570\n", "0.8752779836916235\n", "Iteration: 1580\n", "0.8750926612305412\n", "Iteration: 1590\n", "0.8750926612305412\n", "Iteration: 1600\n", "0.8760192735359525\n", "Iteration: 1610\n", "0.876945885841364\n", "Iteration: 1620\n", "0.8775018532246108\n", "Iteration: 1630\n", "0.8778724981467754\n", "Iteration: 1640\n", "0.8784284655300222\n", "Iteration: 1650\n", "0.8782431430689399\n", "Iteration: 1660\n", "0.876945885841364\n", "Iteration: 1670\n", "0.8765752409191994\n", "Iteration: 1680\n", "0.8773165307635286\n", "Iteration: 1690\n", "0.8778724981467754\n", "Iteration: 1700\n", "0.8793550778354337\n", "Iteration: 1710\n", "0.8797257227575982\n", "Iteration: 1720\n", "0.8808376575240919\n", "Iteration: 1730\n", "0.8810229799851742\n", "Iteration: 1740\n", "0.8812083024462565\n", "Iteration: 1750\n", "0.8810229799851742\n", "Iteration: 1760\n", "0.8821349147516679\n", "Iteration: 1770\n", "0.8825055596738325\n", "Iteration: 1780\n", "0.8826908821349148\n", "Iteration: 1790\n", "0.882876204595997\n", "Iteration: 1800\n", "0.8830615270570793\n", "Iteration: 1810\n", "0.8832468495181616\n", "Iteration: 1820\n", "0.8834321719792438\n", "Iteration: 1830\n", "0.8825055596738325\n", "Iteration: 1840\n", "0.8821349147516679\n", "Iteration: 1850\n", "0.8817642698295033\n", "Iteration: 1860\n", "0.8826908821349148\n", "Iteration: 1870\n", "0.8843587842846553\n", "Iteration: 1880\n", "0.8851000741289844\n", "Iteration: 1890\n", "0.8852853965900667\n", "Iteration: 1900\n", "0.8856560415122313\n", "Iteration: 1910\n", "0.8856560415122313\n", "Iteration: 1920\n", "0.8858413639733136\n", "Iteration: 1930\n", "0.8865826538176427\n", "Iteration: 1940\n", "0.8871386212008896\n", "Iteration: 1950\n", "0.8869532987398072\n", "Iteration: 1960\n", "0.8873239436619719\n", "Iteration: 1970\n", "0.888065233506301\n", "Iteration: 1980\n", "0.8882505559673832\n", "Iteration: 1990\n", "0.8871386212008896\n", "Iteration: 2000\n", "0.885470719051149\n", "Iteration: 2010\n", "0.8865826538176427\n", "Iteration: 2020\n", "0.8878799110452187\n", "Iteration: 2030\n", "0.8888065233506302\n", "Iteration: 2040\n", "0.8888065233506302\n", "Iteration: 2050\n", "0.889362490733877\n", "Iteration: 2060\n", "0.8891771682727947\n", "Iteration: 2070\n", "0.888065233506301\n", "Iteration: 2080\n", "0.8886212008895478\n", "Iteration: 2090\n", "0.8891771682727947\n", "Iteration: 2100\n", "0.8904744255003706\n", "Iteration: 2110\n", "0.8910303928836175\n", "Iteration: 2120\n", "0.8908450704225352\n", "Iteration: 2130\n", "0.8910303928836175\n", "Iteration: 2140\n", "0.8914010378057821\n", "Iteration: 2150\n", "0.8917716827279466\n", "Iteration: 2160\n", "0.8915863602668643\n", "Iteration: 2170\n", "0.890659747961453\n", "Iteration: 2180\n", "0.8908450704225352\n", "Iteration: 2190\n", "0.8914010378057821\n", "Iteration: 2200\n", "0.8919570051890289\n", "Iteration: 2210\n", "0.8921423276501111\n", "Iteration: 2220\n", "0.8926982950333581\n", "Iteration: 2230\n", "0.8932542624166049\n", "Iteration: 2240\n", "0.8932542624166049\n", "Iteration: 2250\n", "0.8938102297998517\n", "Iteration: 2260\n", "0.8939955522609341\n", "Iteration: 2270\n", "0.8938102297998517\n", "Iteration: 2280\n", "0.8926982950333581\n", "Iteration: 2290\n", "0.8928836174944403\n", "Iteration: 2300\n", "0.8934395848776872\n", "Iteration: 2310\n", "0.8943661971830986\n", "Iteration: 2320\n", "0.8956634544106745\n", "Iteration: 2330\n", "0.89529280948851\n", "Iteration: 2340\n", "0.8954781319495922\n", "Iteration: 2350\n", "0.8954781319495922\n", "Iteration: 2360\n", "0.8960340993328392\n", "Iteration: 2370\n", "0.8964047442550037\n", "Iteration: 2380\n", "0.8964047442550037\n", "Iteration: 2390\n", "0.8964047442550037\n", "Iteration: 2400\n", "0.8964047442550037\n", "Iteration: 2410\n", "0.896590066716086\n", "Iteration: 2420\n", "0.896590066716086\n", "Iteration: 2430\n", "0.8969607116382505\n", "Iteration: 2440\n", "0.8969607116382505\n", "Iteration: 2450\n", "0.8971460340993328\n", "Iteration: 2460\n", "0.8973313565604151\n", "Iteration: 2470\n", "0.8977020014825797\n", "Iteration: 2480\n", "0.8980726464047443\n", "Iteration: 2490\n", "0.8982579688658265\n", "Iteration: 2500\n", "0.8986286137879911\n", "Iteration: 2510\n", "0.899184581171238\n", "Iteration: 2520\n", "0.899184581171238\n", "Iteration: 2530\n", "0.8989992587101556\n", "Iteration: 2540\n", "0.8989992587101556\n", "Iteration: 2550\n", "0.899184581171238\n", "Iteration: 2560\n", "0.8995552260934025\n", "Iteration: 2570\n", "0.8997405485544848\n", "Iteration: 2580\n", "0.8997405485544848\n", "Iteration: 2590\n", "0.9004818383988139\n", "Iteration: 2600\n", "0.9006671608598962\n", "Iteration: 2610\n", "0.9010378057820608\n", "Iteration: 2620\n", "0.9010378057820608\n", "Iteration: 2630\n", "0.9008524833209784\n", "Iteration: 2640\n", "0.9004818383988139\n", "Iteration: 2650\n", "0.9008524833209784\n", "Iteration: 2660\n", "0.9012231282431431\n", "Iteration: 2670\n", "0.9017790956263899\n", "Iteration: 2680\n", "0.9023350630096367\n", "Iteration: 2690\n", "0.902520385470719\n", "Iteration: 2700\n", "0.9027057079318014\n", "Iteration: 2710\n", "0.9028910303928837\n", "Iteration: 2720\n", "0.9030763528539659\n", "Iteration: 2730\n", "0.9030763528539659\n", "Iteration: 2740\n", "0.9030763528539659\n", "Iteration: 2750\n", "0.9027057079318014\n", "Iteration: 2760\n", "0.9015937731653076\n", "Iteration: 2770\n", "0.9015937731653076\n", "Iteration: 2780\n", "0.9014084507042254\n", "Iteration: 2790\n", "0.9015937731653076\n", "Iteration: 2800\n", "0.9030763528539659\n", "Iteration: 2810\n", "0.903817642698295\n", "Iteration: 2820\n", "0.9047442550037065\n", "Iteration: 2830\n", "0.9056708673091178\n", "Iteration: 2840\n", "0.906412157153447\n", "Iteration: 2850\n", "0.906412157153447\n", "Iteration: 2860\n", "0.9047442550037065\n", "Iteration: 2870\n", "0.9040029651593773\n", "Iteration: 2880\n", "0.9030763528539659\n", "Iteration: 2890\n", "0.903817642698295\n", "Iteration: 2900\n", "0.9049295774647887\n", "Iteration: 2910\n", "0.9047442550037065\n", "Iteration: 2920\n", "0.9053002223869533\n", "Iteration: 2930\n", "0.9058561897702001\n", "Iteration: 2940\n", "0.9065974796145293\n", "Iteration: 2950\n", "0.9073387694588584\n", "Iteration: 2960\n", "0.9086360266864344\n", "Iteration: 2970\n", "0.9091919940696812\n", "Iteration: 2980\n", "0.9091919940696812\n", "Iteration: 2990\n", "0.9091919940696812\n", "Iteration: 3000\n", "0.9095626389918459\n", "Iteration: 3010\n", "0.9073387694588584\n", "Iteration: 3020\n", "0.9027057079318014\n", "Iteration: 3030\n", "0.9023350630096367\n", "Iteration: 3040\n", "0.9049295774647887\n", "Iteration: 3050\n", "0.9062268346923648\n", "Iteration: 3060\n", "0.9080800593031876\n", "Iteration: 3070\n", "0.9080800593031876\n", "Iteration: 3080\n", "0.9082653817642699\n", "Iteration: 3090\n", "0.9082653817642699\n", "Iteration: 3100\n", "0.9088213491475167\n", "Iteration: 3110\n", "0.9097479614529281\n", "Iteration: 3120\n", "0.911045218680504\n", "Iteration: 3130\n", "0.9119718309859155\n", "Iteration: 3140\n", "0.91234247590808\n", "Iteration: 3150\n", "0.9125277983691623\n", "Iteration: 3160\n", "0.9117865085248332\n", "Iteration: 3170\n", "0.9086360266864344\n", "Iteration: 3180\n", "0.905114899925871\n", "Iteration: 3190\n", "0.9032616753150482\n", "Iteration: 3200\n", "0.9054855448480356\n", "Iteration: 3210\n", "0.9075240919199407\n", "Iteration: 3220\n", "0.9106745737583395\n", "Iteration: 3230\n", "0.9121571534469978\n", "Iteration: 3240\n", "0.9134544106745738\n", "Iteration: 3250\n", "0.9141957005189029\n", "Iteration: 3260\n", "0.9140103780578206\n", "Iteration: 3270\n", "0.9141957005189029\n", "Iteration: 3280\n", "0.9132690882134915\n", "Iteration: 3290\n", "0.9093773165307635\n", "Iteration: 3300\n", "0.9071534469977761\n", "Iteration: 3310\n", "0.9090066716085989\n", "Iteration: 3320\n", "0.9106745737583395\n", "Iteration: 3330\n", "0.911045218680504\n", "Iteration: 3340\n", "0.9127131208302446\n", "Iteration: 3350\n", "0.9141957005189029\n", "Iteration: 3360\n", "0.9145663454410674\n", "Iteration: 3370\n", "0.9147516679021498\n", "Iteration: 3380\n", "0.9149369903632321\n", "Iteration: 3390\n", "0.913639733135656\n", "Iteration: 3400\n", "0.9117865085248332\n", "Iteration: 3410\n", "0.9106745737583395\n", "Iteration: 3420\n", "0.9106745737583395\n", "Iteration: 3430\n", "0.911601186063751\n", "Iteration: 3440\n", "0.9141957005189029\n", "Iteration: 3450\n", "0.9153076352853966\n", "Iteration: 3460\n", "0.9169755374351372\n", "Iteration: 3470\n", "0.9171608598962194\n", "Iteration: 3480\n", "0.9173461823573017\n", "Iteration: 3490\n", "0.9153076352853966\n", "Iteration: 3500\n", "0.9132690882134915\n", "Iteration: 3510\n", "0.9119718309859155\n", "Iteration: 3520\n", "0.9121571534469978\n", "Iteration: 3530\n", "0.9140103780578206\n", "Iteration: 3540\n", "0.9149369903632321\n", "Iteration: 3550\n", "0.9167902149740549\n", "Iteration: 3560\n", "0.917531504818384\n", "Iteration: 3570\n", "0.9191994069681245\n", "Iteration: 3580\n", "0.9195700518902891\n", "Iteration: 3590\n", "0.9195700518902891\n", "Iteration: 3600\n", "0.9191994069681245\n", "Iteration: 3610\n", "0.9186434395848777\n", "Iteration: 3620\n", "0.9164195700518903\n", "Iteration: 3630\n", "0.911045218680504\n", "Iteration: 3640\n", "0.9106745737583395\n", "Iteration: 3650\n", "0.9128984432913269\n", "Iteration: 3660\n", "0.9160489251297257\n", "Iteration: 3670\n", "0.9184581171237954\n", "Iteration: 3680\n", "0.9190140845070423\n", "Iteration: 3690\n", "0.9204966641957005\n", "Iteration: 3700\n", "0.9206819866567828\n", "Iteration: 3710\n", "0.9203113417346183\n", "Iteration: 3720\n", "0.9203113417346183\n", "Iteration: 3730\n", "0.9197553743513713\n", "Iteration: 3740\n", "0.9195700518902891\n", "Iteration: 3750\n", "0.9156782802075611\n", "Iteration: 3760\n", "0.91234247590808\n", "Iteration: 3770\n", "0.9127131208302446\n", "Iteration: 3780\n", "0.9158636026686434\n", "Iteration: 3790\n", "0.9179021497405485\n", "Iteration: 3800\n", "0.9190140845070423\n", "Iteration: 3810\n", "0.920126019273536\n", "Iteration: 3820\n", "0.9206819866567828\n", "Iteration: 3830\n", "0.9214232765011119\n", "Iteration: 3840\n", "0.9212379540400296\n", "Iteration: 3850\n", "0.9212379540400296\n", "Iteration: 3860\n", "0.9210526315789473\n", "Iteration: 3870\n", "0.9193847294292068\n", "Iteration: 3880\n", "0.9169755374351372\n", "Iteration: 3890\n", "0.9154929577464789\n", "Iteration: 3900\n", "0.9164195700518903\n", "Iteration: 3910\n", "0.9180874722016308\n", "Iteration: 3920\n", "0.920126019273536\n", "Iteration: 3930\n", "0.9219792438843588\n", "Iteration: 3940\n", "0.9225352112676056\n", "Iteration: 3950\n", "0.9225352112676056\n", "Iteration: 3960\n", "0.9223498888065234\n", "Iteration: 3970\n", "0.9208673091178651\n", "Iteration: 3980\n", "0.9195700518902891\n", "Iteration: 3990\n", "0.9186434395848777\n" ] } ], "source": [ "W1, b1, W2, b2 = gradient_descent(X_train, Y_train, 0.10, 4000)\n", "df = pd.DataFrame(acc_store)\n", "df.to_csv('cr_acc.csv', index=False)\n", "np.savez(\"cr_weights\", W1, b1, W2, b2)" ] }, { "cell_type": "code", "execution_count": null, "id": "11203f4e-4adf-4a47-a6e2-a8847f27f0cc", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.17" } }, "nbformat": 4, "nbformat_minor": 5 }