semantics/Bel_NN_C.ipynb
2024-09-26 17:23:23 -04:00

1000 lines
27 KiB
Text

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "c4460aee-ec58-454e-9304-82deb89942b4",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd \n",
"from matplotlib import pyplot as plt \n",
"\n",
"data = pd.read_csv('data/bel_data_test.csv')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f1ec44a2-7bb0-44a1-a417-69ab70e1b6f8",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"data = np.array(data)\n",
"\n",
"m,n = data.shape\n",
"data_train = data[1000:m].T\n",
"\n",
"Y_train = data_train[0].astype(int)\n",
"\n",
"X_train = data_train[1:n]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9223239b-35b9-43bc-acf9-dac6a19587c5",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def init_params():\n",
" W1 = np.random.rand(10,1024) - 0.5\n",
" b1 = np.random.rand(10,1) - 0.5\n",
" W2 = np.random.rand(61,10) - 0.5\n",
" b2 = np.random.rand(61,1) - 0.5\n",
" return W1, b1 , W2, b2"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "610cd614-b3c9-445e-8f3d-491596bc773c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def ReLU(Z):\n",
" return np.maximum(Z,0)\n",
"\n",
"def softmax(Z):\n",
" A = np.exp(Z) / sum(np.exp(Z))\n",
" return A\n",
"\n",
"def forward_prop(W1, b1, W2, b2, X):\n",
" Z1 = W1.dot(X) + b1\n",
" A1 = ReLU(Z1)\n",
" Z2 = W2.dot(A1) + b2\n",
" A2 = softmax(Z2)\n",
" return Z1, A1, Z2, A2\n",
"\n",
"def ReLU_deriv(Z):\n",
" return Z > 0\n",
"\n",
"def one_hot(Y):\n",
" one_hot_Y = np.zeros((Y.size, Y.max() + 1))\n",
" one_hot_Y[np.arange(Y.size), Y] = 1\n",
" one_hot_Y = one_hot_Y.T\n",
" return one_hot_Y\n",
"\n",
"def backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y):\n",
" one_hot_Y = one_hot(Y)\n",
" dZ2 = A2 - one_hot_Y\n",
" dW2 = 1 / m * dZ2.dot(A1.T)\n",
" db2 = 1 / m * np.sum(dZ2)\n",
" dZ1 = W2.T.dot(dZ2) * ReLU_deriv(Z1)\n",
" dW1 = 1 / m * dZ1.dot(X.T)\n",
" db1 = 1 / m * np.sum(dZ1)\n",
" return dW1, db1, dW2, db2\n",
"\n",
"def update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha):\n",
" W1 = W1 - alpha * dW1\n",
" b1 = b1 - alpha * db1 \n",
" W2 = W2 - alpha * dW2 \n",
" b2 = b2 - alpha * db2 \n",
" return W1, b1, W2, b2\n",
"\n",
"def get_predictions(A2):\n",
" return np.argmax(A2, 0)\n",
"\n",
"def get_accuracy(predictions, Y):\n",
" #print(predictions, Y)\n",
" return np.sum(predictions == Y) / Y.size"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "287c3525-e1a0-41c4-a1ae-9c24c805c7c5",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"acc_store = [] "
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "539820c9-dc35-4f24-bbff-7860f7442c19",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def gradient_descent(X, Y, alpha, iterations):\n",
" W1, b1, W2, b2 = init_params()\n",
" for i in range(iterations):\n",
" Z1, A1, Z2, A2 = forward_prop(W1, b1, W2, b2, X)\n",
" dW1, db1, dW2, db2 = backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y)\n",
" W1, b1, W2, b2 = update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha)\n",
" if i % 10 == 0:\n",
" print(\"Iteration: \", i)\n",
" predictions = get_predictions(A2)\n",
" pred = get_accuracy(predictions, Y)\n",
" print(pred)\n",
" acc_store.append(pred)\n",
" return W1, b1, W2, b2"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "06a5f975-472b-4b98-99e2-c1fff036ddc7",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration: 0\n",
"0.014825174825174826\n",
"Iteration: 10\n",
"0.06657342657342658\n",
"Iteration: 20\n",
"0.12055944055944055\n",
"Iteration: 30\n",
"0.1476923076923077\n",
"Iteration: 40\n",
"0.16895104895104895\n",
"Iteration: 50\n",
"0.19552447552447552\n",
"Iteration: 60\n",
"0.20251748251748253\n",
"Iteration: 70\n",
"0.2095104895104895\n",
"Iteration: 80\n",
"0.21818181818181817\n",
"Iteration: 90\n",
"0.2274125874125874\n",
"Iteration: 100\n",
"0.23608391608391607\n",
"Iteration: 110\n",
"0.24503496503496502\n",
"Iteration: 120\n",
"0.26153846153846155\n",
"Iteration: 130\n",
"0.2732867132867133\n",
"Iteration: 140\n",
"0.28895104895104895\n",
"Iteration: 150\n",
"0.3026573426573427\n",
"Iteration: 160\n",
"0.3144055944055944\n",
"Iteration: 170\n",
"0.3225174825174825\n",
"Iteration: 180\n",
"0.33146853146853145\n",
"Iteration: 190\n",
"0.34965034965034963\n",
"Iteration: 200\n",
"0.3655944055944056\n",
"Iteration: 210\n",
"0.41426573426573426\n",
"Iteration: 220\n",
"0.4483916083916084\n",
"Iteration: 230\n",
"0.4662937062937063\n",
"Iteration: 240\n",
"0.47944055944055947\n",
"Iteration: 250\n",
"0.4967832167832168\n",
"Iteration: 260\n",
"0.5096503496503496\n",
"Iteration: 270\n",
"0.5216783216783217\n",
"Iteration: 280\n",
"0.5342657342657343\n",
"Iteration: 290\n",
"0.544055944055944\n",
"Iteration: 300\n",
"0.5546853146853147\n",
"Iteration: 310\n",
"0.5613986013986014\n",
"Iteration: 320\n",
"0.5667132867132867\n",
"Iteration: 330\n",
"0.5728671328671329\n",
"Iteration: 340\n",
"0.579020979020979\n",
"Iteration: 350\n",
"0.5865734265734266\n",
"Iteration: 360\n",
"0.593006993006993\n",
"Iteration: 370\n",
"0.6008391608391609\n",
"Iteration: 380\n",
"0.6044755244755244\n",
"Iteration: 390\n",
"0.6092307692307692\n",
"Iteration: 400\n",
"0.6156643356643356\n",
"Iteration: 410\n",
"0.6223776223776224\n",
"Iteration: 420\n",
"0.6310489510489511\n",
"Iteration: 430\n",
"0.6391608391608392\n",
"Iteration: 440\n",
"0.6486713286713287\n",
"Iteration: 450\n",
"0.6573426573426573\n",
"Iteration: 460\n",
"0.6618181818181819\n",
"Iteration: 470\n",
"0.6665734265734266\n",
"Iteration: 480\n",
"0.6707692307692308\n",
"Iteration: 490\n",
"0.6738461538461539\n",
"Iteration: 500\n",
"0.676923076923077\n",
"Iteration: 510\n",
"0.6836363636363636\n",
"Iteration: 520\n",
"0.6900699300699301\n",
"Iteration: 530\n",
"0.6942657342657342\n",
"Iteration: 540\n",
"0.6987412587412587\n",
"Iteration: 550\n",
"0.7026573426573427\n",
"Iteration: 560\n",
"0.7048951048951049\n",
"Iteration: 570\n",
"0.7090909090909091\n",
"Iteration: 580\n",
"0.7113286713286713\n",
"Iteration: 590\n",
"0.7163636363636363\n",
"Iteration: 600\n",
"0.7230769230769231\n",
"Iteration: 610\n",
"0.7264335664335664\n",
"Iteration: 620\n",
"0.7297902097902098\n",
"Iteration: 630\n",
"0.7337062937062937\n",
"Iteration: 640\n",
"0.7367832167832168\n",
"Iteration: 650\n",
"0.7412587412587412\n",
"Iteration: 660\n",
"0.7446153846153846\n",
"Iteration: 670\n",
"0.7468531468531469\n",
"Iteration: 680\n",
"0.7510489510489511\n",
"Iteration: 690\n",
"0.7546853146853146\n",
"Iteration: 700\n",
"0.7560839160839161\n",
"Iteration: 710\n",
"0.76\n",
"Iteration: 720\n",
"0.7622377622377622\n",
"Iteration: 730\n",
"0.7647552447552447\n",
"Iteration: 740\n",
"0.7664335664335664\n",
"Iteration: 750\n",
"0.7695104895104895\n",
"Iteration: 760\n",
"0.7711888111888112\n",
"Iteration: 770\n",
"0.7731468531468532\n",
"Iteration: 780\n",
"0.777062937062937\n",
"Iteration: 790\n",
"0.7784615384615384\n",
"Iteration: 800\n",
"0.7801398601398601\n",
"Iteration: 810\n",
"0.7812587412587413\n",
"Iteration: 820\n",
"0.784055944055944\n",
"Iteration: 830\n",
"0.7865734265734265\n",
"Iteration: 840\n",
"0.7885314685314685\n",
"Iteration: 850\n",
"0.7904895104895104\n",
"Iteration: 860\n",
"0.7927272727272727\n",
"Iteration: 870\n",
"0.796083916083916\n",
"Iteration: 880\n",
"0.7974825174825175\n",
"Iteration: 890\n",
"0.7986013986013986\n",
"Iteration: 900\n",
"0.7994405594405595\n",
"Iteration: 910\n",
"0.8008391608391608\n",
"Iteration: 920\n",
"0.8022377622377622\n",
"Iteration: 930\n",
"0.8036363636363636\n",
"Iteration: 940\n",
"0.8041958041958042\n",
"Iteration: 950\n",
"0.8055944055944056\n",
"Iteration: 960\n",
"0.8072727272727273\n",
"Iteration: 970\n",
"0.8095104895104895\n",
"Iteration: 980\n",
"0.8109090909090909\n",
"Iteration: 990\n",
"0.8131468531468532\n",
"Iteration: 1000\n",
"0.813986013986014\n",
"Iteration: 1010\n",
"0.8145454545454546\n",
"Iteration: 1020\n",
"0.8156643356643357\n",
"Iteration: 1030\n",
"0.8170629370629371\n",
"Iteration: 1040\n",
"0.8198601398601398\n",
"Iteration: 1050\n",
"0.8204195804195804\n",
"Iteration: 1060\n",
"0.8215384615384616\n",
"Iteration: 1070\n",
"0.8226573426573427\n",
"Iteration: 1080\n",
"0.8234965034965035\n",
"Iteration: 1090\n",
"0.8237762237762237\n",
"Iteration: 1100\n",
"0.8251748251748252\n",
"Iteration: 1110\n",
"0.8254545454545454\n",
"Iteration: 1120\n",
"0.8265734265734266\n",
"Iteration: 1130\n",
"0.8274125874125874\n",
"Iteration: 1140\n",
"0.8282517482517483\n",
"Iteration: 1150\n",
"0.8285314685314685\n",
"Iteration: 1160\n",
"0.8293706293706293\n",
"Iteration: 1170\n",
"0.8296503496503497\n",
"Iteration: 1180\n",
"0.8302097902097902\n",
"Iteration: 1190\n",
"0.831048951048951\n",
"Iteration: 1200\n",
"0.8318881118881118\n",
"Iteration: 1210\n",
"0.8327272727272728\n",
"Iteration: 1220\n",
"0.833006993006993\n",
"Iteration: 1230\n",
"0.8346853146853147\n",
"Iteration: 1240\n",
"0.8355244755244755\n",
"Iteration: 1250\n",
"0.8366433566433567\n",
"Iteration: 1260\n",
"0.8374825174825175\n",
"Iteration: 1270\n",
"0.838041958041958\n",
"Iteration: 1280\n",
"0.8397202797202797\n",
"Iteration: 1290\n",
"0.8411188811188811\n",
"Iteration: 1300\n",
"0.8416783216783217\n",
"Iteration: 1310\n",
"0.8436363636363636\n",
"Iteration: 1320\n",
"0.845034965034965\n",
"Iteration: 1330\n",
"0.8458741258741259\n",
"Iteration: 1340\n",
"0.8467132867132867\n",
"Iteration: 1350\n",
"0.8483916083916084\n",
"Iteration: 1360\n",
"0.848951048951049\n",
"Iteration: 1370\n",
"0.8497902097902098\n",
"Iteration: 1380\n",
"0.850909090909091\n",
"Iteration: 1390\n",
"0.8514685314685315\n",
"Iteration: 1400\n",
"0.852027972027972\n",
"Iteration: 1410\n",
"0.8528671328671329\n",
"Iteration: 1420\n",
"0.8531468531468531\n",
"Iteration: 1430\n",
"0.8537062937062937\n",
"Iteration: 1440\n",
"0.8542657342657343\n",
"Iteration: 1450\n",
"0.8553846153846154\n",
"Iteration: 1460\n",
"0.8565034965034966\n",
"Iteration: 1470\n",
"0.8573426573426574\n",
"Iteration: 1480\n",
"0.8576223776223776\n",
"Iteration: 1490\n",
"0.8587412587412587\n",
"Iteration: 1500\n",
"0.8593006993006993\n",
"Iteration: 1510\n",
"0.8601398601398601\n",
"Iteration: 1520\n",
"0.8618181818181818\n",
"Iteration: 1530\n",
"0.8623776223776224\n",
"Iteration: 1540\n",
"0.8632167832167832\n",
"Iteration: 1550\n",
"0.8646153846153846\n",
"Iteration: 1560\n",
"0.8654545454545455\n",
"Iteration: 1570\n",
"0.866013986013986\n",
"Iteration: 1580\n",
"0.8665734265734266\n",
"Iteration: 1590\n",
"0.8676923076923077\n",
"Iteration: 1600\n",
"0.8676923076923077\n",
"Iteration: 1610\n",
"0.8682517482517482\n",
"Iteration: 1620\n",
"0.8690909090909091\n",
"Iteration: 1630\n",
"0.8693706293706294\n",
"Iteration: 1640\n",
"0.8699300699300699\n",
"Iteration: 1650\n",
"0.8713286713286713\n",
"Iteration: 1660\n",
"0.8721678321678321\n",
"Iteration: 1670\n",
"0.8724475524475525\n",
"Iteration: 1680\n",
"0.8732867132867133\n",
"Iteration: 1690\n",
"0.8735664335664336\n",
"Iteration: 1700\n",
"0.8738461538461538\n",
"Iteration: 1710\n",
"0.8755244755244755\n",
"Iteration: 1720\n",
"0.8763636363636363\n",
"Iteration: 1730\n",
"0.8772027972027973\n",
"Iteration: 1740\n",
"0.8777622377622377\n",
"Iteration: 1750\n",
"0.8780419580419581\n",
"Iteration: 1760\n",
"0.8786013986013986\n",
"Iteration: 1770\n",
"0.8791608391608392\n",
"Iteration: 1780\n",
"0.8797202797202798\n",
"Iteration: 1790\n",
"0.8802797202797202\n",
"Iteration: 1800\n",
"0.8805594405594406\n",
"Iteration: 1810\n",
"0.8811188811188811\n",
"Iteration: 1820\n",
"0.881958041958042\n",
"Iteration: 1830\n",
"0.8822377622377623\n",
"Iteration: 1840\n",
"0.8830769230769231\n",
"Iteration: 1850\n",
"0.8839160839160839\n",
"Iteration: 1860\n",
"0.8841958041958042\n",
"Iteration: 1870\n",
"0.8839160839160839\n",
"Iteration: 1880\n",
"0.8844755244755245\n",
"Iteration: 1890\n",
"0.8847552447552448\n",
"Iteration: 1900\n",
"0.885034965034965\n",
"Iteration: 1910\n",
"0.885034965034965\n",
"Iteration: 1920\n",
"0.8853146853146853\n",
"Iteration: 1930\n",
"0.8861538461538462\n",
"Iteration: 1940\n",
"0.886993006993007\n",
"Iteration: 1950\n",
"0.8875524475524476\n",
"Iteration: 1960\n",
"0.8881118881118881\n",
"Iteration: 1970\n",
"0.8886713286713287\n",
"Iteration: 1980\n",
"0.8897902097902098\n",
"Iteration: 1990\n",
"0.8900699300699301\n",
"Iteration: 2000\n",
"0.8900699300699301\n",
"Iteration: 2010\n",
"0.8909090909090909\n",
"Iteration: 2020\n",
"0.8911888111888112\n",
"Iteration: 2030\n",
"0.8914685314685314\n",
"Iteration: 2040\n",
"0.8917482517482518\n",
"Iteration: 2050\n",
"0.8925874125874126\n",
"Iteration: 2060\n",
"0.8928671328671328\n",
"Iteration: 2070\n",
"0.8931468531468532\n",
"Iteration: 2080\n",
"0.8934265734265734\n",
"Iteration: 2090\n",
"0.893986013986014\n",
"Iteration: 2100\n",
"0.8942657342657343\n",
"Iteration: 2110\n",
"0.8948251748251749\n",
"Iteration: 2120\n",
"0.8953846153846153\n",
"Iteration: 2130\n",
"0.8956643356643357\n",
"Iteration: 2140\n",
"0.8962237762237762\n",
"Iteration: 2150\n",
"0.8967832167832168\n",
"Iteration: 2160\n",
"0.897062937062937\n",
"Iteration: 2170\n",
"0.8973426573426574\n",
"Iteration: 2180\n",
"0.8976223776223776\n",
"Iteration: 2190\n",
"0.8979020979020979\n",
"Iteration: 2200\n",
"0.8984615384615384\n",
"Iteration: 2210\n",
"0.8993006993006993\n",
"Iteration: 2220\n",
"0.8995804195804196\n",
"Iteration: 2230\n",
"0.8998601398601399\n",
"Iteration: 2240\n",
"0.9001398601398601\n",
"Iteration: 2250\n",
"0.9001398601398601\n",
"Iteration: 2260\n",
"0.9004195804195804\n",
"Iteration: 2270\n",
"0.9012587412587413\n",
"Iteration: 2280\n",
"0.9015384615384615\n",
"Iteration: 2290\n",
"0.9020979020979021\n",
"Iteration: 2300\n",
"0.9020979020979021\n",
"Iteration: 2310\n",
"0.9020979020979021\n",
"Iteration: 2320\n",
"0.9023776223776224\n",
"Iteration: 2330\n",
"0.9032167832167832\n",
"Iteration: 2340\n",
"0.904055944055944\n",
"Iteration: 2350\n",
"0.9043356643356644\n",
"Iteration: 2360\n",
"0.9046153846153846\n",
"Iteration: 2370\n",
"0.9046153846153846\n",
"Iteration: 2380\n",
"0.9054545454545454\n",
"Iteration: 2390\n",
"0.906013986013986\n",
"Iteration: 2400\n",
"0.9071328671328671\n",
"Iteration: 2410\n",
"0.9074125874125875\n",
"Iteration: 2420\n",
"0.9079720279720279\n",
"Iteration: 2430\n",
"0.9082517482517483\n",
"Iteration: 2440\n",
"0.9085314685314685\n",
"Iteration: 2450\n",
"0.9085314685314685\n",
"Iteration: 2460\n",
"0.9085314685314685\n",
"Iteration: 2470\n",
"0.9082517482517483\n",
"Iteration: 2480\n",
"0.9093706293706294\n",
"Iteration: 2490\n",
"0.9093706293706294\n",
"Iteration: 2500\n",
"0.9096503496503496\n",
"Iteration: 2510\n",
"0.9096503496503496\n",
"Iteration: 2520\n",
"0.9102097902097902\n",
"Iteration: 2530\n",
"0.9102097902097902\n",
"Iteration: 2540\n",
"0.9102097902097902\n",
"Iteration: 2550\n",
"0.9107692307692308\n",
"Iteration: 2560\n",
"0.911048951048951\n",
"Iteration: 2570\n",
"0.9113286713286713\n",
"Iteration: 2580\n",
"0.9116083916083916\n",
"Iteration: 2590\n",
"0.9118881118881119\n",
"Iteration: 2600\n",
"0.9127272727272727\n",
"Iteration: 2610\n",
"0.9132867132867133\n",
"Iteration: 2620\n",
"0.9135664335664335\n",
"Iteration: 2630\n",
"0.9135664335664335\n",
"Iteration: 2640\n",
"0.9144055944055944\n",
"Iteration: 2650\n",
"0.9146853146853147\n",
"Iteration: 2660\n",
"0.9146853146853147\n",
"Iteration: 2670\n",
"0.9146853146853147\n",
"Iteration: 2680\n",
"0.9155244755244756\n",
"Iteration: 2690\n",
"0.9155244755244756\n",
"Iteration: 2700\n",
"0.9158041958041958\n",
"Iteration: 2710\n",
"0.9166433566433566\n",
"Iteration: 2720\n",
"0.916923076923077\n",
"Iteration: 2730\n",
"0.9172027972027972\n",
"Iteration: 2740\n",
"0.9177622377622378\n",
"Iteration: 2750\n",
"0.9177622377622378\n",
"Iteration: 2760\n",
"0.9183216783216783\n",
"Iteration: 2770\n",
"0.9186013986013986\n",
"Iteration: 2780\n",
"0.9188811188811189\n",
"Iteration: 2790\n",
"0.9194405594405595\n",
"Iteration: 2800\n",
"0.9194405594405595\n",
"Iteration: 2810\n",
"0.9197202797202797\n",
"Iteration: 2820\n",
"0.92\n",
"Iteration: 2830\n",
"0.9202797202797203\n",
"Iteration: 2840\n",
"0.9202797202797203\n",
"Iteration: 2850\n",
"0.9202797202797203\n",
"Iteration: 2860\n",
"0.9202797202797203\n",
"Iteration: 2870\n",
"0.9202797202797203\n",
"Iteration: 2880\n",
"0.9205594405594406\n",
"Iteration: 2890\n",
"0.9213986013986014\n",
"Iteration: 2900\n",
"0.9216783216783216\n",
"Iteration: 2910\n",
"0.9216783216783216\n",
"Iteration: 2920\n",
"0.9216783216783216\n",
"Iteration: 2930\n",
"0.9216783216783216\n",
"Iteration: 2940\n",
"0.9216783216783216\n",
"Iteration: 2950\n",
"0.921958041958042\n",
"Iteration: 2960\n",
"0.921958041958042\n",
"Iteration: 2970\n",
"0.9225174825174826\n",
"Iteration: 2980\n",
"0.9225174825174826\n",
"Iteration: 2990\n",
"0.9227972027972028\n",
"Iteration: 3000\n",
"0.9225174825174826\n",
"Iteration: 3010\n",
"0.9230769230769231\n",
"Iteration: 3020\n",
"0.9233566433566434\n",
"Iteration: 3030\n",
"0.9236363636363636\n",
"Iteration: 3040\n",
"0.9236363636363636\n",
"Iteration: 3050\n",
"0.9236363636363636\n",
"Iteration: 3060\n",
"0.9236363636363636\n",
"Iteration: 3070\n",
"0.9236363636363636\n",
"Iteration: 3080\n",
"0.9239160839160839\n",
"Iteration: 3090\n",
"0.9241958041958042\n",
"Iteration: 3100\n",
"0.9250349650349651\n",
"Iteration: 3110\n",
"0.9250349650349651\n",
"Iteration: 3120\n",
"0.9253146853146853\n",
"Iteration: 3130\n",
"0.9255944055944056\n",
"Iteration: 3140\n",
"0.9258741258741259\n",
"Iteration: 3150\n",
"0.9258741258741259\n",
"Iteration: 3160\n",
"0.9258741258741259\n",
"Iteration: 3170\n",
"0.9258741258741259\n",
"Iteration: 3180\n",
"0.9255944055944056\n",
"Iteration: 3190\n",
"0.9258741258741259\n",
"Iteration: 3200\n",
"0.9258741258741259\n",
"Iteration: 3210\n",
"0.9261538461538461\n",
"Iteration: 3220\n",
"0.9261538461538461\n",
"Iteration: 3230\n",
"0.9267132867132867\n",
"Iteration: 3240\n",
"0.9275524475524476\n",
"Iteration: 3250\n",
"0.9275524475524476\n",
"Iteration: 3260\n",
"0.9275524475524476\n",
"Iteration: 3270\n",
"0.9275524475524476\n",
"Iteration: 3280\n",
"0.9275524475524476\n",
"Iteration: 3290\n",
"0.9278321678321678\n",
"Iteration: 3300\n",
"0.9281118881118882\n",
"Iteration: 3310\n",
"0.9283916083916084\n",
"Iteration: 3320\n",
"0.9283916083916084\n",
"Iteration: 3330\n",
"0.9283916083916084\n",
"Iteration: 3340\n",
"0.9283916083916084\n",
"Iteration: 3350\n",
"0.9286713286713286\n",
"Iteration: 3360\n",
"0.9286713286713286\n",
"Iteration: 3370\n",
"0.928951048951049\n",
"Iteration: 3380\n",
"0.928951048951049\n",
"Iteration: 3390\n",
"0.928951048951049\n",
"Iteration: 3400\n",
"0.928951048951049\n",
"Iteration: 3410\n",
"0.928951048951049\n",
"Iteration: 3420\n",
"0.928951048951049\n",
"Iteration: 3430\n",
"0.9295104895104895\n",
"Iteration: 3440\n",
"0.9297902097902098\n",
"Iteration: 3450\n",
"0.9295104895104895\n",
"Iteration: 3460\n",
"0.9295104895104895\n",
"Iteration: 3470\n",
"0.9297902097902098\n",
"Iteration: 3480\n",
"0.9300699300699301\n",
"Iteration: 3490\n",
"0.9303496503496503\n",
"Iteration: 3500\n",
"0.9303496503496503\n",
"Iteration: 3510\n",
"0.9306293706293707\n",
"Iteration: 3520\n",
"0.9306293706293707\n",
"Iteration: 3530\n",
"0.9309090909090909\n",
"Iteration: 3540\n",
"0.9311888111888111\n",
"Iteration: 3550\n",
"0.9314685314685315\n",
"Iteration: 3560\n",
"0.9317482517482517\n",
"Iteration: 3570\n",
"0.9317482517482517\n",
"Iteration: 3580\n",
"0.932027972027972\n",
"Iteration: 3590\n",
"0.932027972027972\n",
"Iteration: 3600\n",
"0.9325874125874126\n",
"Iteration: 3610\n",
"0.9328671328671329\n",
"Iteration: 3620\n",
"0.9334265734265734\n",
"Iteration: 3630\n",
"0.9334265734265734\n",
"Iteration: 3640\n",
"0.933986013986014\n",
"Iteration: 3650\n",
"0.933986013986014\n",
"Iteration: 3660\n",
"0.933986013986014\n",
"Iteration: 3670\n",
"0.9342657342657342\n",
"Iteration: 3680\n",
"0.9342657342657342\n",
"Iteration: 3690\n",
"0.9345454545454546\n",
"Iteration: 3700\n",
"0.9345454545454546\n",
"Iteration: 3710\n",
"0.9345454545454546\n",
"Iteration: 3720\n",
"0.9353846153846154\n",
"Iteration: 3730\n",
"0.9353846153846154\n",
"Iteration: 3740\n",
"0.9356643356643357\n",
"Iteration: 3750\n",
"0.9365034965034965\n",
"Iteration: 3760\n",
"0.9370629370629371\n",
"Iteration: 3770\n",
"0.9376223776223777\n",
"Iteration: 3780\n",
"0.9379020979020979\n",
"Iteration: 3790\n",
"0.9381818181818182\n",
"Iteration: 3800\n",
"0.9381818181818182\n",
"Iteration: 3810\n",
"0.9381818181818182\n",
"Iteration: 3820\n",
"0.9381818181818182\n",
"Iteration: 3830\n",
"0.9381818181818182\n",
"Iteration: 3840\n",
"0.9381818181818182\n",
"Iteration: 3850\n",
"0.9381818181818182\n",
"Iteration: 3860\n",
"0.9384615384615385\n",
"Iteration: 3870\n",
"0.9384615384615385\n",
"Iteration: 3880\n",
"0.9384615384615385\n",
"Iteration: 3890\n",
"0.9387412587412587\n",
"Iteration: 3900\n",
"0.9387412587412587\n",
"Iteration: 3910\n",
"0.939020979020979\n",
"Iteration: 3920\n",
"0.939020979020979\n",
"Iteration: 3930\n",
"0.9387412587412587\n",
"Iteration: 3940\n",
"0.939020979020979\n",
"Iteration: 3950\n",
"0.9393006993006993\n",
"Iteration: 3960\n",
"0.9393006993006993\n",
"Iteration: 3970\n",
"0.9393006993006993\n",
"Iteration: 3980\n",
"0.9393006993006993\n",
"Iteration: 3990\n",
"0.9393006993006993\n"
]
}
],
"source": [
"W1, b1, W2, b2 = gradient_descent(X_train, Y_train, 0.10, 4000)\n",
"df = pd.DataFrame(acc_store)\n",
"df.to_csv('bt_acc.csv', index=False)\n",
"np.savez(\"bt_weights\", W1, b1, W2, b2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d3370366-9827-44db-85b9-512e39574ee7",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}