{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5b8e2f0a-6c1d-4e7a-9f3b-2d4c6e8a0b1c",
   "metadata": {},
   "source": [
    "# Two-layer NumPy MLP\n",
    "\n",
    "Trains a 1024 -> 10 -> 61 fully connected network (ReLU hidden layer, softmax output) with full-batch gradient descent on `data/bel_data_test.csv`.\n",
    "\n",
    "- The first 1000 rows of the CSV are excluded from training.\n",
    "- Training accuracy is logged every 10 iterations and saved to `bt_acc.csv`; the final weights are saved to `bt_weights.npz`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c4460aee-ec58-454e-9304-82deb89942b4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Fixed seed so the random weight initialisation, and hence the training curve, is reproducible.\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "\n",
    "data = pd.read_csv('data/bel_data_test.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f1ec44a2-7bb0-44a1-a417-69ab70e1b6f8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "data = np.array(data)\n",
    "\n",
    "m, n = data.shape\n",
    "# Rows [0, 1000) are left out of training (presumably a held-out split; unused in this notebook).\n",
    "# NOTE(review): rows are not shuffled before this split - confirm the CSV is not ordered by label.\n",
    "data_train = data[1000:m].T\n",
    "\n",
    "# After the transpose each column is one example: row 0 is the label, rows 1..n-1 are the features.\n",
    "Y_train = data_train[0].astype(int)\n",
    "X_train = data_train[1:n]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9223239b-35b9-43bc-acf9-dac6a19587c5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def init_params(n_in=1024, n_hidden=10, n_out=61):\n",
    "    \"\"\"Return uniformly initialised parameters for an n_in -> n_hidden -> n_out MLP.\n",
    "\n",
    "    Every entry is drawn from [-0.5, 0.5). The defaults reproduce the original\n",
    "    1024 -> 10 -> 61 architecture.\n",
    "\n",
    "    Returns:\n",
    "        W1 (n_hidden, n_in), b1 (n_hidden, 1), W2 (n_out, n_hidden), b2 (n_out, 1)\n",
    "    \"\"\"\n",
    "    W1 = np.random.rand(n_hidden, n_in) - 0.5\n",
    "    b1 = np.random.rand(n_hidden, 1) - 0.5\n",
    "    W2 = np.random.rand(n_out, n_hidden) - 0.5\n",
    "    b2 = np.random.rand(n_out, 1) - 0.5\n",
    "    return W1, b1, W2, b2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "610cd614-b3c9-445e-8f3d-491596bc773c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def ReLU(Z):\n",
    "    \"\"\"Element-wise rectified linear unit, max(Z, 0).\"\"\"\n",
    "    return np.maximum(Z, 0)\n",
    "\n",
    "\n",
    "def softmax(Z):\n",
    "    \"\"\"Column-wise softmax; each column of Z is one example.\n",
    "\n",
    "    The per-column maximum is subtracted before exponentiating so np.exp\n",
    "    cannot overflow for large logits; the result is mathematically unchanged.\n",
    "    \"\"\"\n",
    "    expZ = np.exp(Z - np.max(Z, axis=0, keepdims=True))\n",
    "    return expZ / np.sum(expZ, axis=0, keepdims=True)\n",
    "\n",
    "\n",
    "def forward_prop(W1, b1, W2, b2, X):\n",
    "    \"\"\"Forward pass: X (features x examples) -> ReLU hidden layer -> softmax output.\n",
    "\n",
    "    Returns the pre-activations and activations of both layers (Z1, A1, Z2, A2).\n",
    "    \"\"\"\n",
    "    Z1 = W1.dot(X) + b1\n",
    "    A1 = ReLU(Z1)\n",
    "    Z2 = W2.dot(A1) + b2\n",
    "    A2 = softmax(Z2)\n",
    "    return Z1, A1, Z2, A2\n",
    "\n",
    "\n",
    "def ReLU_deriv(Z):\n",
    "    \"\"\"Derivative of ReLU as a boolean mask (True where Z > 0).\"\"\"\n",
    "    return Z > 0\n",
    "\n",
    "\n",
    "def one_hot(Y, num_classes=None):\n",
    "    \"\"\"One-hot encode integer labels Y as a (num_classes, Y.size) matrix.\n",
    "\n",
    "    num_classes defaults to Y.max() + 1. Pass it explicitly when Y might not\n",
    "    contain the highest class; otherwise the matrix has fewer rows than the\n",
    "    network has outputs and A2 - one_hot(Y) fails to broadcast.\n",
    "    \"\"\"\n",
    "    if num_classes is None:\n",
    "        num_classes = Y.max() + 1\n",
    "    one_hot_Y = np.zeros((Y.size, num_classes))\n",
    "    one_hot_Y[np.arange(Y.size), Y] = 1\n",
    "    return one_hot_Y.T\n",
    "\n",
    "\n",
    "def backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y):\n",
    "    \"\"\"Gradients of the softmax cross-entropy loss w.r.t. W1, b1, W2, b2.\n",
    "\n",
    "    Gradients are averaged over the examples actually passed in (Y.size), not\n",
    "    over the global row count m, which also counts the 1000 rows excluded\n",
    "    from training. Bias gradients are summed per unit (axis=1, keepdims) so\n",
    "    they keep the (units, 1) shape of b1/b2 instead of collapsing to a single\n",
    "    scalar shared by every bias.\n",
    "    \"\"\"\n",
    "    batch_size = Y.size\n",
    "    one_hot_Y = one_hot(Y, A2.shape[0])\n",
    "    dZ2 = A2 - one_hot_Y\n",
    "    dW2 = dZ2.dot(A1.T) / batch_size\n",
    "    db2 = np.sum(dZ2, axis=1, keepdims=True) / batch_size\n",
    "    dZ1 = W2.T.dot(dZ2) * ReLU_deriv(Z1)\n",
    "    dW1 = dZ1.dot(X.T) / batch_size\n",
    "    db1 = np.sum(dZ1, axis=1, keepdims=True) / batch_size\n",
    "    return dW1, db1, dW2, db2\n",
    "\n",
    "\n",
    "def update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha):\n",
    "    \"\"\"Take one plain gradient-descent step with learning rate alpha.\"\"\"\n",
    "    W1 = W1 - alpha * dW1\n",
    "    b1 = b1 - alpha * db1\n",
    "    W2 = W2 - alpha * dW2\n",
    "    b2 = b2 - alpha * db2\n",
    "    return W1, b1, W2, b2\n",
    "\n",
    "\n",
    "def get_predictions(A2):\n",
    "    \"\"\"Predicted class per example: index of the largest output in each column.\"\"\"\n",
    "    return np.argmax(A2, 0)\n",
    "\n",
    "\n",
    "def get_accuracy(predictions, Y):\n",
    "    \"\"\"Fraction of predictions that equal the true labels Y.\"\"\"\n",
    "    return np.sum(predictions == Y) / Y.size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "287c3525-e1a0-41c4-a1ae-9c24c805c7c5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Accuracy history appended to by gradient_descent (one value every 10 iterations).\n",
    "acc_store = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "539820c9-dc35-4f24-bbff-7860f7442c19",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def gradient_descent(X, Y, alpha, iterations):\n",
    "    \"\"\"Train the 2-layer network on (X, Y) with full-batch gradient descent.\n",
    "\n",
    "    Every 10 iterations the training accuracy is printed and appended to the\n",
    "    global acc_store. It is measured on that iteration's forward pass, i.e.\n",
    "    with the parameters from before that iteration's update. acc_store is\n",
    "    cleared first so repeated calls do not mix histories from several runs.\n",
    "\n",
    "    Returns:\n",
    "        The trained W1, b1, W2, b2.\n",
    "    \"\"\"\n",
    "    acc_store.clear()\n",
    "    W1, b1, W2, b2 = init_params(n_in=X.shape[0])\n",
    "    for i in range(iterations):\n",
    "        Z1, A1, Z2, A2 = forward_prop(W1, b1, W2, b2, X)\n",
    "        dW1, db1, dW2, db2 = backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y)\n",
    "        W1, b1, W2, b2 = update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha)\n",
    "        if i % 10 == 0:\n",
    "            print(\"Iteration: \", i)\n",
    "            predictions = get_predictions(A2)\n",
    "            pred = get_accuracy(predictions, Y)\n",
    "            print(pred)\n",
    "            acc_store.append(pred)\n",
    "    return W1, b1, W2, b2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "06a5f975-472b-4b98-99e2-c1fff036ddc7",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
"Iteration: 0\n", "0.014825174825174826\n", "Iteration: 10\n", "0.06657342657342658\n", "Iteration: 20\n", "0.12055944055944055\n", "Iteration: 30\n", "0.1476923076923077\n", "Iteration: 40\n", "0.16895104895104895\n", "Iteration: 50\n", "0.19552447552447552\n", "Iteration: 60\n", "0.20251748251748253\n", "Iteration: 70\n", "0.2095104895104895\n", "Iteration: 80\n", "0.21818181818181817\n", "Iteration: 90\n", "0.2274125874125874\n", "Iteration: 100\n", "0.23608391608391607\n", "Iteration: 110\n", "0.24503496503496502\n", "Iteration: 120\n", 
"0.26153846153846155\n", "Iteration: 130\n", "0.2732867132867133\n", "Iteration: 140\n", "0.28895104895104895\n", "Iteration: 150\n", "0.3026573426573427\n", "Iteration: 160\n", "0.3144055944055944\n", "Iteration: 170\n", "0.3225174825174825\n", "Iteration: 180\n", "0.33146853146853145\n", "Iteration: 190\n", "0.34965034965034963\n", "Iteration: 200\n", "0.3655944055944056\n", "Iteration: 210\n", "0.41426573426573426\n", "Iteration: 220\n", "0.4483916083916084\n", "Iteration: 230\n", "0.4662937062937063\n", "Iteration: 240\n", "0.47944055944055947\n", "Iteration: 250\n", "0.4967832167832168\n", "Iteration: 260\n", "0.5096503496503496\n", "Iteration: 270\n", "0.5216783216783217\n", "Iteration: 280\n", "0.5342657342657343\n", "Iteration: 290\n", "0.544055944055944\n", "Iteration: 300\n", "0.5546853146853147\n", "Iteration: 310\n", "0.5613986013986014\n", "Iteration: 320\n", "0.5667132867132867\n", "Iteration: 330\n", "0.5728671328671329\n", "Iteration: 340\n", "0.579020979020979\n", "Iteration: 350\n", "0.5865734265734266\n", "Iteration: 360\n", "0.593006993006993\n", "Iteration: 370\n", "0.6008391608391609\n", "Iteration: 380\n", "0.6044755244755244\n", "Iteration: 390\n", "0.6092307692307692\n", "Iteration: 400\n", "0.6156643356643356\n", "Iteration: 410\n", "0.6223776223776224\n", "Iteration: 420\n", "0.6310489510489511\n", "Iteration: 430\n", "0.6391608391608392\n", "Iteration: 440\n", "0.6486713286713287\n", "Iteration: 450\n", "0.6573426573426573\n", "Iteration: 460\n", "0.6618181818181819\n", "Iteration: 470\n", "0.6665734265734266\n", "Iteration: 480\n", "0.6707692307692308\n", "Iteration: 490\n", "0.6738461538461539\n", "Iteration: 500\n", "0.676923076923077\n", "Iteration: 510\n", "0.6836363636363636\n", "Iteration: 520\n", "0.6900699300699301\n", "Iteration: 530\n", "0.6942657342657342\n", "Iteration: 540\n", "0.6987412587412587\n", "Iteration: 550\n", "0.7026573426573427\n", "Iteration: 560\n", "0.7048951048951049\n", "Iteration: 570\n", 
"0.7090909090909091\n", "Iteration: 580\n", "0.7113286713286713\n", "Iteration: 590\n", "0.7163636363636363\n", "Iteration: 600\n", "0.7230769230769231\n", "Iteration: 610\n", "0.7264335664335664\n", "Iteration: 620\n", "0.7297902097902098\n", "Iteration: 630\n", "0.7337062937062937\n", "Iteration: 640\n", "0.7367832167832168\n", "Iteration: 650\n", "0.7412587412587412\n", "Iteration: 660\n", "0.7446153846153846\n", "Iteration: 670\n", "0.7468531468531469\n", "Iteration: 680\n", "0.7510489510489511\n", "Iteration: 690\n", "0.7546853146853146\n", "Iteration: 700\n", "0.7560839160839161\n", "Iteration: 710\n", "0.76\n", "Iteration: 720\n", "0.7622377622377622\n", "Iteration: 730\n", "0.7647552447552447\n", "Iteration: 740\n", "0.7664335664335664\n", "Iteration: 750\n", "0.7695104895104895\n", "Iteration: 760\n", "0.7711888111888112\n", "Iteration: 770\n", "0.7731468531468532\n", "Iteration: 780\n", "0.777062937062937\n", "Iteration: 790\n", "0.7784615384615384\n", "Iteration: 800\n", "0.7801398601398601\n", "Iteration: 810\n", "0.7812587412587413\n", "Iteration: 820\n", "0.784055944055944\n", "Iteration: 830\n", "0.7865734265734265\n", "Iteration: 840\n", "0.7885314685314685\n", "Iteration: 850\n", "0.7904895104895104\n", "Iteration: 860\n", "0.7927272727272727\n", "Iteration: 870\n", "0.796083916083916\n", "Iteration: 880\n", "0.7974825174825175\n", "Iteration: 890\n", "0.7986013986013986\n", "Iteration: 900\n", "0.7994405594405595\n", "Iteration: 910\n", "0.8008391608391608\n", "Iteration: 920\n", "0.8022377622377622\n", "Iteration: 930\n", "0.8036363636363636\n", "Iteration: 940\n", "0.8041958041958042\n", "Iteration: 950\n", "0.8055944055944056\n", "Iteration: 960\n", "0.8072727272727273\n", "Iteration: 970\n", "0.8095104895104895\n", "Iteration: 980\n", "0.8109090909090909\n", "Iteration: 990\n", "0.8131468531468532\n", "Iteration: 1000\n", "0.813986013986014\n", "Iteration: 1010\n", "0.8145454545454546\n", "Iteration: 1020\n", "0.8156643356643357\n", 
"Iteration: 1030\n", "0.8170629370629371\n", "Iteration: 1040\n", "0.8198601398601398\n", "Iteration: 1050\n", "0.8204195804195804\n", "Iteration: 1060\n", "0.8215384615384616\n", "Iteration: 1070\n", "0.8226573426573427\n", "Iteration: 1080\n", "0.8234965034965035\n", "Iteration: 1090\n", "0.8237762237762237\n", "Iteration: 1100\n", "0.8251748251748252\n", "Iteration: 1110\n", "0.8254545454545454\n", "Iteration: 1120\n", "0.8265734265734266\n", "Iteration: 1130\n", "0.8274125874125874\n", "Iteration: 1140\n", "0.8282517482517483\n", "Iteration: 1150\n", "0.8285314685314685\n", "Iteration: 1160\n", "0.8293706293706293\n", "Iteration: 1170\n", "0.8296503496503497\n", "Iteration: 1180\n", "0.8302097902097902\n", "Iteration: 1190\n", "0.831048951048951\n", "Iteration: 1200\n", "0.8318881118881118\n", "Iteration: 1210\n", "0.8327272727272728\n", "Iteration: 1220\n", "0.833006993006993\n", "Iteration: 1230\n", "0.8346853146853147\n", "Iteration: 1240\n", "0.8355244755244755\n", "Iteration: 1250\n", "0.8366433566433567\n", "Iteration: 1260\n", "0.8374825174825175\n", "Iteration: 1270\n", "0.838041958041958\n", "Iteration: 1280\n", "0.8397202797202797\n", "Iteration: 1290\n", "0.8411188811188811\n", "Iteration: 1300\n", "0.8416783216783217\n", "Iteration: 1310\n", "0.8436363636363636\n", "Iteration: 1320\n", "0.845034965034965\n", "Iteration: 1330\n", "0.8458741258741259\n", "Iteration: 1340\n", "0.8467132867132867\n", "Iteration: 1350\n", "0.8483916083916084\n", "Iteration: 1360\n", "0.848951048951049\n", "Iteration: 1370\n", "0.8497902097902098\n", "Iteration: 1380\n", "0.850909090909091\n", "Iteration: 1390\n", "0.8514685314685315\n", "Iteration: 1400\n", "0.852027972027972\n", "Iteration: 1410\n", "0.8528671328671329\n", "Iteration: 1420\n", "0.8531468531468531\n", "Iteration: 1430\n", "0.8537062937062937\n", "Iteration: 1440\n", "0.8542657342657343\n", "Iteration: 1450\n", "0.8553846153846154\n", "Iteration: 1460\n", "0.8565034965034966\n", "Iteration: 1470\n", 
"0.8573426573426574\n", "Iteration: 1480\n", "0.8576223776223776\n", "Iteration: 1490\n", "0.8587412587412587\n", "Iteration: 1500\n", "0.8593006993006993\n", "Iteration: 1510\n", "0.8601398601398601\n", "Iteration: 1520\n", "0.8618181818181818\n", "Iteration: 1530\n", "0.8623776223776224\n", "Iteration: 1540\n", "0.8632167832167832\n", "Iteration: 1550\n", "0.8646153846153846\n", "Iteration: 1560\n", "0.8654545454545455\n", "Iteration: 1570\n", "0.866013986013986\n", "Iteration: 1580\n", "0.8665734265734266\n", "Iteration: 1590\n", "0.8676923076923077\n", "Iteration: 1600\n", "0.8676923076923077\n", "Iteration: 1610\n", "0.8682517482517482\n", "Iteration: 1620\n", "0.8690909090909091\n", "Iteration: 1630\n", "0.8693706293706294\n", "Iteration: 1640\n", "0.8699300699300699\n", "Iteration: 1650\n", "0.8713286713286713\n", "Iteration: 1660\n", "0.8721678321678321\n", "Iteration: 1670\n", "0.8724475524475525\n", "Iteration: 1680\n", "0.8732867132867133\n", "Iteration: 1690\n", "0.8735664335664336\n", "Iteration: 1700\n", "0.8738461538461538\n", "Iteration: 1710\n", "0.8755244755244755\n", "Iteration: 1720\n", "0.8763636363636363\n", "Iteration: 1730\n", "0.8772027972027973\n", "Iteration: 1740\n", "0.8777622377622377\n", "Iteration: 1750\n", "0.8780419580419581\n", "Iteration: 1760\n", "0.8786013986013986\n", "Iteration: 1770\n", "0.8791608391608392\n", "Iteration: 1780\n", "0.8797202797202798\n", "Iteration: 1790\n", "0.8802797202797202\n", "Iteration: 1800\n", "0.8805594405594406\n", "Iteration: 1810\n", "0.8811188811188811\n", "Iteration: 1820\n", "0.881958041958042\n", "Iteration: 1830\n", "0.8822377622377623\n", "Iteration: 1840\n", "0.8830769230769231\n", "Iteration: 1850\n", "0.8839160839160839\n", "Iteration: 1860\n", "0.8841958041958042\n", "Iteration: 1870\n", "0.8839160839160839\n", "Iteration: 1880\n", "0.8844755244755245\n", "Iteration: 1890\n", "0.8847552447552448\n", "Iteration: 1900\n", "0.885034965034965\n", "Iteration: 1910\n", "0.885034965034965\n", 
"Iteration: 1920\n", "0.8853146853146853\n", "Iteration: 1930\n", "0.8861538461538462\n", "Iteration: 1940\n", "0.886993006993007\n", "Iteration: 1950\n", "0.8875524475524476\n", "Iteration: 1960\n", "0.8881118881118881\n", "Iteration: 1970\n", "0.8886713286713287\n", "Iteration: 1980\n", "0.8897902097902098\n", "Iteration: 1990\n", "0.8900699300699301\n", "Iteration: 2000\n", "0.8900699300699301\n", "Iteration: 2010\n", "0.8909090909090909\n", "Iteration: 2020\n", "0.8911888111888112\n", "Iteration: 2030\n", "0.8914685314685314\n", "Iteration: 2040\n", "0.8917482517482518\n", "Iteration: 2050\n", "0.8925874125874126\n", "Iteration: 2060\n", "0.8928671328671328\n", "Iteration: 2070\n", "0.8931468531468532\n", "Iteration: 2080\n", "0.8934265734265734\n", "Iteration: 2090\n", "0.893986013986014\n", "Iteration: 2100\n", "0.8942657342657343\n", "Iteration: 2110\n", "0.8948251748251749\n", "Iteration: 2120\n", "0.8953846153846153\n", "Iteration: 2130\n", "0.8956643356643357\n", "Iteration: 2140\n", "0.8962237762237762\n", "Iteration: 2150\n", "0.8967832167832168\n", "Iteration: 2160\n", "0.897062937062937\n", "Iteration: 2170\n", "0.8973426573426574\n", "Iteration: 2180\n", "0.8976223776223776\n", "Iteration: 2190\n", "0.8979020979020979\n", "Iteration: 2200\n", "0.8984615384615384\n", "Iteration: 2210\n", "0.8993006993006993\n", "Iteration: 2220\n", "0.8995804195804196\n", "Iteration: 2230\n", "0.8998601398601399\n", "Iteration: 2240\n", "0.9001398601398601\n", "Iteration: 2250\n", "0.9001398601398601\n", "Iteration: 2260\n", "0.9004195804195804\n", "Iteration: 2270\n", "0.9012587412587413\n", "Iteration: 2280\n", "0.9015384615384615\n", "Iteration: 2290\n", "0.9020979020979021\n", "Iteration: 2300\n", "0.9020979020979021\n", "Iteration: 2310\n", "0.9020979020979021\n", "Iteration: 2320\n", "0.9023776223776224\n", "Iteration: 2330\n", "0.9032167832167832\n", "Iteration: 2340\n", "0.904055944055944\n", "Iteration: 2350\n", "0.9043356643356644\n", "Iteration: 2360\n", 
"0.9046153846153846\n", "Iteration: 2370\n", "0.9046153846153846\n", "Iteration: 2380\n", "0.9054545454545454\n", "Iteration: 2390\n", "0.906013986013986\n", "Iteration: 2400\n", "0.9071328671328671\n", "Iteration: 2410\n", "0.9074125874125875\n", "Iteration: 2420\n", "0.9079720279720279\n", "Iteration: 2430\n", "0.9082517482517483\n", "Iteration: 2440\n", "0.9085314685314685\n", "Iteration: 2450\n", "0.9085314685314685\n", "Iteration: 2460\n", "0.9085314685314685\n", "Iteration: 2470\n", "0.9082517482517483\n", "Iteration: 2480\n", "0.9093706293706294\n", "Iteration: 2490\n", "0.9093706293706294\n", "Iteration: 2500\n", "0.9096503496503496\n", "Iteration: 2510\n", "0.9096503496503496\n", "Iteration: 2520\n", "0.9102097902097902\n", "Iteration: 2530\n", "0.9102097902097902\n", "Iteration: 2540\n", "0.9102097902097902\n", "Iteration: 2550\n", "0.9107692307692308\n", "Iteration: 2560\n", "0.911048951048951\n", "Iteration: 2570\n", "0.9113286713286713\n", "Iteration: 2580\n", "0.9116083916083916\n", "Iteration: 2590\n", "0.9118881118881119\n", "Iteration: 2600\n", "0.9127272727272727\n", "Iteration: 2610\n", "0.9132867132867133\n", "Iteration: 2620\n", "0.9135664335664335\n", "Iteration: 2630\n", "0.9135664335664335\n", "Iteration: 2640\n", "0.9144055944055944\n", "Iteration: 2650\n", "0.9146853146853147\n", "Iteration: 2660\n", "0.9146853146853147\n", "Iteration: 2670\n", "0.9146853146853147\n", "Iteration: 2680\n", "0.9155244755244756\n", "Iteration: 2690\n", "0.9155244755244756\n", "Iteration: 2700\n", "0.9158041958041958\n", "Iteration: 2710\n", "0.9166433566433566\n", "Iteration: 2720\n", "0.916923076923077\n", "Iteration: 2730\n", "0.9172027972027972\n", "Iteration: 2740\n", "0.9177622377622378\n", "Iteration: 2750\n", "0.9177622377622378\n", "Iteration: 2760\n", "0.9183216783216783\n", "Iteration: 2770\n", "0.9186013986013986\n", "Iteration: 2780\n", "0.9188811188811189\n", "Iteration: 2790\n", "0.9194405594405595\n", "Iteration: 2800\n", 
"0.9194405594405595\n", "Iteration: 2810\n", "0.9197202797202797\n", "Iteration: 2820\n", "0.92\n", "Iteration: 2830\n", "0.9202797202797203\n", "Iteration: 2840\n", "0.9202797202797203\n", "Iteration: 2850\n", "0.9202797202797203\n", "Iteration: 2860\n", "0.9202797202797203\n", "Iteration: 2870\n", "0.9202797202797203\n", "Iteration: 2880\n", "0.9205594405594406\n", "Iteration: 2890\n", "0.9213986013986014\n", "Iteration: 2900\n", "0.9216783216783216\n", "Iteration: 2910\n", "0.9216783216783216\n", "Iteration: 2920\n", "0.9216783216783216\n", "Iteration: 2930\n", "0.9216783216783216\n", "Iteration: 2940\n", "0.9216783216783216\n", "Iteration: 2950\n", "0.921958041958042\n", "Iteration: 2960\n", "0.921958041958042\n", "Iteration: 2970\n", "0.9225174825174826\n", "Iteration: 2980\n", "0.9225174825174826\n", "Iteration: 2990\n", "0.9227972027972028\n", "Iteration: 3000\n", "0.9225174825174826\n", "Iteration: 3010\n", "0.9230769230769231\n", "Iteration: 3020\n", "0.9233566433566434\n", "Iteration: 3030\n", "0.9236363636363636\n", "Iteration: 3040\n", "0.9236363636363636\n", "Iteration: 3050\n", "0.9236363636363636\n", "Iteration: 3060\n", "0.9236363636363636\n", "Iteration: 3070\n", "0.9236363636363636\n", "Iteration: 3080\n", "0.9239160839160839\n", "Iteration: 3090\n", "0.9241958041958042\n", "Iteration: 3100\n", "0.9250349650349651\n", "Iteration: 3110\n", "0.9250349650349651\n", "Iteration: 3120\n", "0.9253146853146853\n", "Iteration: 3130\n", "0.9255944055944056\n", "Iteration: 3140\n", "0.9258741258741259\n", "Iteration: 3150\n", "0.9258741258741259\n", "Iteration: 3160\n", "0.9258741258741259\n", "Iteration: 3170\n", "0.9258741258741259\n", "Iteration: 3180\n", "0.9255944055944056\n", "Iteration: 3190\n", "0.9258741258741259\n", "Iteration: 3200\n", "0.9258741258741259\n", "Iteration: 3210\n", "0.9261538461538461\n", "Iteration: 3220\n", "0.9261538461538461\n", "Iteration: 3230\n", "0.9267132867132867\n", "Iteration: 3240\n", "0.9275524475524476\n", "Iteration: 
3250\n", "0.9275524475524476\n", "Iteration: 3260\n", "0.9275524475524476\n", "Iteration: 3270\n", "0.9275524475524476\n", "Iteration: 3280\n", "0.9275524475524476\n", "Iteration: 3290\n", "0.9278321678321678\n", "Iteration: 3300\n", "0.9281118881118882\n", "Iteration: 3310\n", "0.9283916083916084\n", "Iteration: 3320\n", "0.9283916083916084\n", "Iteration: 3330\n", "0.9283916083916084\n", "Iteration: 3340\n", "0.9283916083916084\n", "Iteration: 3350\n", "0.9286713286713286\n", "Iteration: 3360\n", "0.9286713286713286\n", "Iteration: 3370\n", "0.928951048951049\n", "Iteration: 3380\n", "0.928951048951049\n", "Iteration: 3390\n", "0.928951048951049\n", "Iteration: 3400\n", "0.928951048951049\n", "Iteration: 3410\n", "0.928951048951049\n", "Iteration: 3420\n", "0.928951048951049\n", "Iteration: 3430\n", "0.9295104895104895\n", "Iteration: 3440\n", "0.9297902097902098\n", "Iteration: 3450\n", "0.9295104895104895\n", "Iteration: 3460\n", "0.9295104895104895\n", "Iteration: 3470\n", "0.9297902097902098\n", "Iteration: 3480\n", "0.9300699300699301\n", "Iteration: 3490\n", "0.9303496503496503\n", "Iteration: 3500\n", "0.9303496503496503\n", "Iteration: 3510\n", "0.9306293706293707\n", "Iteration: 3520\n", "0.9306293706293707\n", "Iteration: 3530\n", "0.9309090909090909\n", "Iteration: 3540\n", "0.9311888111888111\n", "Iteration: 3550\n", "0.9314685314685315\n", "Iteration: 3560\n", "0.9317482517482517\n", "Iteration: 3570\n", "0.9317482517482517\n", "Iteration: 3580\n", "0.932027972027972\n", "Iteration: 3590\n", "0.932027972027972\n", "Iteration: 3600\n", "0.9325874125874126\n", "Iteration: 3610\n", "0.9328671328671329\n", "Iteration: 3620\n", "0.9334265734265734\n", "Iteration: 3630\n", "0.9334265734265734\n", "Iteration: 3640\n", "0.933986013986014\n", "Iteration: 3650\n", "0.933986013986014\n", "Iteration: 3660\n", "0.933986013986014\n", "Iteration: 3670\n", "0.9342657342657342\n", "Iteration: 3680\n", "0.9342657342657342\n", "Iteration: 3690\n", 
"0.9345454545454546\n", "Iteration: 3700\n", "0.9345454545454546\n", "Iteration: 3710\n", "0.9345454545454546\n", "Iteration: 3720\n", "0.9353846153846154\n", "Iteration: 3730\n", "0.9353846153846154\n", "Iteration: 3740\n", "0.9356643356643357\n", "Iteration: 3750\n", "0.9365034965034965\n", "Iteration: 3760\n", "0.9370629370629371\n", "Iteration: 3770\n", "0.9376223776223777\n", "Iteration: 3780\n", "0.9379020979020979\n", "Iteration: 3790\n", "0.9381818181818182\n", "Iteration: 3800\n", "0.9381818181818182\n", "Iteration: 3810\n", "0.9381818181818182\n", "Iteration: 3820\n", "0.9381818181818182\n", "Iteration: 3830\n", "0.9381818181818182\n", "Iteration: 3840\n", "0.9381818181818182\n", "Iteration: 3850\n", "0.9381818181818182\n", "Iteration: 3860\n", "0.9384615384615385\n", "Iteration: 3870\n", "0.9384615384615385\n", "Iteration: 3880\n", "0.9384615384615385\n", "Iteration: 3890\n", "0.9387412587412587\n", "Iteration: 3900\n", "0.9387412587412587\n", "Iteration: 3910\n", "0.939020979020979\n", "Iteration: 3920\n", "0.939020979020979\n", "Iteration: 3930\n", "0.9387412587412587\n", "Iteration: 3940\n", "0.939020979020979\n", "Iteration: 3950\n", "0.9393006993006993\n", "Iteration: 3960\n", "0.9393006993006993\n", "Iteration: 3970\n", "0.9393006993006993\n", "Iteration: 3980\n", "0.9393006993006993\n", "Iteration: 3990\n", "0.9393006993006993\n"
     ]
    }
   ],
   "source": [
    "# Hyper-parameters for the full-batch training run.\n",
    "LEARNING_RATE = 0.10\n",
    "N_ITERATIONS = 4000\n",
    "\n",
    "# Start from an empty history so re-running this cell does not append a second\n",
    "# run's accuracies to those already in acc_store before they are written to disk.\n",
    "acc_store.clear()\n",
    "W1, b1, W2, b2 = gradient_descent(X_train, Y_train, LEARNING_RATE, N_ITERATIONS)\n",
    "\n",
    "# One row per logged accuracy value; the index column is not written.\n",
    "df = pd.DataFrame(acc_store)\n",
    "df.to_csv('bt_acc.csv', index=False)\n",
    "# Positional arrays are stored under the keys arr_0..arr_3 (W1, b1, W2, b2) in bt_weights.npz.\n",
    "np.savez(\"bt_weights\", W1, b1, W2, b2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}