semantics/Cro_NN_C.ipynb
2024-09-26 17:23:23 -04:00

992 lines
27 KiB
Text

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "25c0d153-288c-4ee8-a968-915f853b8157",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from matplotlib import pyplot as plt\n",
"\n",
"# Seed NumPy's global RNG so the random weight initialisation (np.random.rand\n",
"# in init_params), and therefore every training run, is reproducible under\n",
"# Restart & Run All.\n",
"SEED = 0\n",
"np.random.seed(SEED)\n",
"\n",
"# NOTE(review): the training split below is taken from this file despite its\n",
"# '_test' name -- confirm this is the intended dataset.\n",
"data = pd.read_csv('cro_data_test.csv')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "962cacc2-c818-4c5b-bdab-2ee46c6de511",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Each CSV row is one example: column 0 is the integer class label and the\n",
"# remaining columns are the input features.\n",
"data_array = np.array(data)\n",
"m, n = data_array.shape\n",
"\n",
"# The first N_DEV rows are held out from training. Keep them as a dev set\n",
"# (X_dev, Y_dev) instead of silently discarding them.\n",
"N_DEV = 1000\n",
"data_dev = data_array[:N_DEV].T\n",
"Y_dev = data_dev[0].astype(int)\n",
"X_dev = data_dev[1:n]\n",
"\n",
"# Transpose so that every column is one training example.\n",
"data_train = data_array[N_DEV:m].T\n",
"Y_train = data_train[0].astype(int)\n",
"X_train = data_train[1:n]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e863fe3b-3ee6-42f3-b716-4fcda6a850af",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def init_params(n_in=1024, n_hidden=10, n_out=5):\n",
"    \"\"\"Randomly initialise the parameters of a two-layer network.\n",
"\n",
"    Every weight and bias is drawn uniformly from [-0.5, 0.5) via\n",
"    np.random.rand. The defaults reproduce the original 1024 -> 10 -> 5\n",
"    architecture.\n",
"\n",
"    Returns (W1, b1, W2, b2) with shapes (n_hidden, n_in), (n_hidden, 1),\n",
"    (n_out, n_hidden) and (n_out, 1).\n",
"    \"\"\"\n",
"    W1 = np.random.rand(n_hidden, n_in) - 0.5\n",
"    b1 = np.random.rand(n_hidden, 1) - 0.5\n",
"    W2 = np.random.rand(n_out, n_hidden) - 0.5\n",
"    b2 = np.random.rand(n_out, 1) - 0.5\n",
"    return W1, b1, W2, b2"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "64dd0fba-a49e-4f13-b534-e074350b5f42",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def ReLU(Z):\n",
"    \"\"\"Element-wise rectified linear unit, max(Z, 0).\"\"\"\n",
"    return np.maximum(Z, 0)\n",
"\n",
"\n",
"def softmax(Z):\n",
"    \"\"\"Column-wise softmax: turn each column of logits Z into probabilities.\n",
"\n",
"    The column maximum is subtracted before exponentiating. The result is\n",
"    mathematically unchanged, but np.exp can no longer overflow to inf (which\n",
"    would turn the probabilities into NaN) when the logits are large.\n",
"    \"\"\"\n",
"    exp_Z = np.exp(Z - np.max(Z, axis=0, keepdims=True))\n",
"    return exp_Z / np.sum(exp_Z, axis=0, keepdims=True)\n",
"\n",
"\n",
"def forward_prop(W1, b1, W2, b2, X):\n",
"    \"\"\"Forward pass of the two-layer network on X (one example per column).\n",
"\n",
"    Returns (Z1, A1, Z2, A2): hidden pre-activations, hidden ReLU activations,\n",
"    output logits and output softmax probabilities.\n",
"    \"\"\"\n",
"    Z1 = W1.dot(X) + b1\n",
"    A1 = ReLU(Z1)\n",
"    Z2 = W2.dot(A1) + b2\n",
"    A2 = softmax(Z2)\n",
"    return Z1, A1, Z2, A2\n",
"\n",
"\n",
"def ReLU_deriv(Z):\n",
"    \"\"\"Derivative of ReLU as a boolean mask (True where Z > 0).\"\"\"\n",
"    return Z > 0\n",
"\n",
"\n",
"def one_hot(Y, num_classes=None):\n",
"    \"\"\"Encode integer labels Y as a (num_classes, Y.size) one-hot matrix.\n",
"\n",
"    num_classes defaults to Y.max() + 1. Pass it explicitly when Y may not\n",
"    contain the highest class label; otherwise the matrix has too few rows\n",
"    to line up with the network output.\n",
"    \"\"\"\n",
"    if num_classes is None:\n",
"        num_classes = Y.max() + 1\n",
"    one_hot_Y = np.zeros((num_classes, Y.size))\n",
"    one_hot_Y[Y, np.arange(Y.size)] = 1\n",
"    return one_hot_Y\n",
"\n",
"\n",
"def backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y):\n",
"    \"\"\"Gradients of the mean softmax cross-entropy loss w.r.t. W1, b1, W2, b2.\n",
"\n",
"    X holds one example per column and Y the matching integer labels.\n",
"    Returns (dW1, db1, dW2, db2) with the same shapes as (W1, b1, W2, b2).\n",
"    \"\"\"\n",
"    # Average over the examples actually passed in rather than the global `m`,\n",
"    # which is the row count of the whole CSV (held-out rows included).\n",
"    batch_size = Y.size\n",
"    # Size the targets from A2 so a batch lacking the top class still lines up.\n",
"    one_hot_Y = one_hot(Y, A2.shape[0])\n",
"    dZ2 = A2 - one_hot_Y\n",
"    dW2 = dZ2.dot(A1.T) / batch_size\n",
"    # Sum over examples only (axis=1) so every bias unit gets its own gradient\n",
"    # in the (units, 1) shape of b2, not one scalar shared by all units.\n",
"    db2 = np.sum(dZ2, axis=1, keepdims=True) / batch_size\n",
"    dZ1 = W2.T.dot(dZ2) * ReLU_deriv(Z1)\n",
"    dW1 = dZ1.dot(X.T) / batch_size\n",
"    db1 = np.sum(dZ1, axis=1, keepdims=True) / batch_size\n",
"    return dW1, db1, dW2, db2\n",
"\n",
"\n",
"def update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha):\n",
"    \"\"\"One plain gradient-descent step with learning rate alpha.\n",
"\n",
"    Returns new arrays; the arguments are not modified in place.\n",
"    \"\"\"\n",
"    W1 = W1 - alpha * dW1\n",
"    b1 = b1 - alpha * db1\n",
"    W2 = W2 - alpha * dW2\n",
"    b2 = b2 - alpha * db2\n",
"    return W1, b1, W2, b2\n",
"\n",
"\n",
"def get_predictions(A2):\n",
"    \"\"\"Predicted class per example: row index of the largest entry in each column.\"\"\"\n",
"    return np.argmax(A2, 0)\n",
"\n",
"\n",
"def get_accuracy(predictions, Y):\n",
"    \"\"\"Fraction of examples whose prediction equals the label.\"\"\"\n",
"    return np.sum(predictions == Y) / Y.size"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e7ef6234-254e-47f6-ac29-ddd92d363e9e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# History of training accuracies; gradient_descent appends to this list.\n",
"acc_store = list()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "d24bdd4d-1d57-40b1-a95b-3cc33e02312d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def gradient_descent(X, Y, alpha, iterations, print_every=10):\n",
"    \"\"\"Train the network with full-batch gradient descent.\n",
"\n",
"    X holds one example per column and Y the matching integer labels; alpha is\n",
"    the learning rate. Every `print_every` iterations (must be positive) the\n",
"    training accuracy of that iteration's forward pass is printed and appended\n",
"    to the global `acc_store`. The list is cleared first, so re-running the\n",
"    training cell does not concatenate the histories of several runs.\n",
"\n",
"    Returns the trained parameters (W1, b1, W2, b2).\n",
"    \"\"\"\n",
"    acc_store.clear()\n",
"    W1, b1, W2, b2 = init_params()\n",
"    for i in range(iterations):\n",
"        Z1, A1, Z2, A2 = forward_prop(W1, b1, W2, b2, X)\n",
"        dW1, db1, dW2, db2 = backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y)\n",
"        W1, b1, W2, b2 = update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha)\n",
"        if i % print_every == 0:\n",
"            print(\"Iteration: \", i)\n",
"            accuracy = get_accuracy(get_predictions(A2), Y)\n",
"            print(accuracy)\n",
"            acc_store.append(accuracy)\n",
"    return W1, b1, W2, b2"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "d266b8d3-8f15-4d89-a896-a728215b048d",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration: 0\n",
"0.2016308376575241\n",
"Iteration: 10\n",
"0.3700889547813195\n",
"Iteration: 20\n",
"0.45978502594514453\n",
"Iteration: 30\n",
"0.519644180874722\n",
"Iteration: 40\n",
"0.4792438843587843\n",
"Iteration: 50\n",
"0.49833209785025945\n",
"Iteration: 60\n",
"0.544477390659748\n",
"Iteration: 70\n",
"0.5804299481097109\n",
"Iteration: 80\n",
"0.6250926612305412\n",
"Iteration: 90\n",
"0.6653076352853966\n",
"Iteration: 100\n",
"0.6955151964418087\n",
"Iteration: 110\n",
"0.7168272794662713\n",
"Iteration: 120\n",
"0.7294292068198666\n",
"Iteration: 130\n",
"0.7366567828020756\n",
"Iteration: 140\n",
"0.7462935507783544\n",
"Iteration: 150\n",
"0.7544477390659748\n",
"Iteration: 160\n",
"0.758524833209785\n",
"Iteration: 170\n",
"0.7626019273535952\n",
"Iteration: 180\n",
"0.7677909562638991\n",
"Iteration: 190\n",
"0.7729799851742031\n",
"Iteration: 200\n",
"0.7755744996293551\n",
"Iteration: 210\n",
"0.7792809488510007\n",
"Iteration: 220\n",
"0.7824314306893996\n",
"Iteration: 230\n",
"0.7857672349888807\n",
"Iteration: 240\n",
"0.7889177168272795\n",
"Iteration: 250\n",
"0.7909562638991846\n",
"Iteration: 260\n",
"0.7935507783543365\n",
"Iteration: 270\n",
"0.7968865826538176\n",
"Iteration: 280\n",
"0.7996664195700519\n",
"Iteration: 290\n",
"0.8011489992587102\n",
"Iteration: 300\n",
"0.8026315789473685\n",
"Iteration: 310\n",
"0.8052260934025204\n",
"Iteration: 320\n",
"0.8081912527798369\n",
"Iteration: 330\n",
"0.8096738324684952\n",
"Iteration: 340\n",
"0.8118977020014826\n",
"Iteration: 350\n",
"0.8139362490733877\n",
"Iteration: 360\n",
"0.8148628613787992\n",
"Iteration: 370\n",
"0.8169014084507042\n",
"Iteration: 380\n",
"0.8181986656782803\n",
"Iteration: 390\n",
"0.8191252779836916\n",
"Iteration: 400\n",
"0.8204225352112676\n",
"Iteration: 410\n",
"0.8215344699777613\n",
"Iteration: 420\n",
"0.8224610822831727\n",
"Iteration: 430\n",
"0.825240919199407\n",
"Iteration: 440\n",
"0.8259822090437361\n",
"Iteration: 450\n",
"0.8280207561156412\n",
"Iteration: 460\n",
"0.8293180133432172\n",
"Iteration: 470\n",
"0.8306152705707932\n",
"Iteration: 480\n",
"0.8317272053372868\n",
"Iteration: 490\n",
"0.8313565604151223\n",
"Iteration: 500\n",
"0.8317272053372868\n",
"Iteration: 510\n",
"0.8320978502594515\n",
"Iteration: 520\n",
"0.8335804299481097\n",
"Iteration: 530\n",
"0.8348776871756857\n",
"Iteration: 540\n",
"0.8356189770200149\n",
"Iteration: 550\n",
"0.836360266864344\n",
"Iteration: 560\n",
"0.8367309117865085\n",
"Iteration: 570\n",
"0.8372868791697554\n",
"Iteration: 580\n",
"0.8395107487027428\n",
"Iteration: 590\n",
"0.8400667160859896\n",
"Iteration: 600\n",
"0.8404373610081541\n",
"Iteration: 610\n",
"0.8393254262416605\n",
"Iteration: 620\n",
"0.8400667160859896\n",
"Iteration: 630\n",
"0.8408080059303188\n",
"Iteration: 640\n",
"0.8408080059303188\n",
"Iteration: 650\n",
"0.8409933283914011\n",
"Iteration: 660\n",
"0.8421052631578947\n",
"Iteration: 670\n",
"0.8432171979243884\n",
"Iteration: 680\n",
"0.8432171979243884\n",
"Iteration: 690\n",
"0.843587842846553\n",
"Iteration: 700\n",
"0.8441438102297999\n",
"Iteration: 710\n",
"0.8445144551519644\n",
"Iteration: 720\n",
"0.8445144551519644\n",
"Iteration: 730\n",
"0.8445144551519644\n",
"Iteration: 740\n",
"0.844885100074129\n",
"Iteration: 750\n",
"0.844885100074129\n",
"Iteration: 760\n",
"0.8452557449962935\n",
"Iteration: 770\n",
"0.8459970348406227\n",
"Iteration: 780\n",
"0.8465530022238695\n",
"Iteration: 790\n",
"0.8467383246849518\n",
"Iteration: 800\n",
"0.8471089696071163\n",
"Iteration: 810\n",
"0.8476649369903633\n",
"Iteration: 820\n",
"0.8491475166790216\n",
"Iteration: 830\n",
"0.8502594514455152\n",
"Iteration: 840\n",
"0.8502594514455152\n",
"Iteration: 850\n",
"0.8511860637509266\n",
"Iteration: 860\n",
"0.8517420311341735\n",
"Iteration: 870\n",
"0.8517420311341735\n",
"Iteration: 880\n",
"0.8519273535952557\n",
"Iteration: 890\n",
"0.8519273535952557\n",
"Iteration: 900\n",
"0.8532246108228317\n",
"Iteration: 910\n",
"0.8539659006671608\n",
"Iteration: 920\n",
"0.85470719051149\n",
"Iteration: 930\n",
"0.85470719051149\n",
"Iteration: 940\n",
"0.8548925129725723\n",
"Iteration: 950\n",
"0.8556338028169014\n",
"Iteration: 960\n",
"0.8563750926612306\n",
"Iteration: 970\n",
"0.8565604151223128\n",
"Iteration: 980\n",
"0.8571163825055597\n",
"Iteration: 990\n",
"0.8567457375833951\n",
"Iteration: 1000\n",
"0.8578576723498889\n",
"Iteration: 1010\n",
"0.8580429948109711\n",
"Iteration: 1020\n",
"0.8582283172720534\n",
"Iteration: 1030\n",
"0.8585989621942179\n",
"Iteration: 1040\n",
"0.8587842846553002\n",
"Iteration: 1050\n",
"0.8591549295774648\n",
"Iteration: 1060\n",
"0.8595255744996294\n",
"Iteration: 1070\n",
"0.8595255744996294\n",
"Iteration: 1080\n",
"0.8597108969607117\n",
"Iteration: 1090\n",
"0.860637509266123\n",
"Iteration: 1100\n",
"0.8626760563380281\n",
"Iteration: 1110\n",
"0.8628613787991104\n",
"Iteration: 1120\n",
"0.8630467012601928\n",
"Iteration: 1130\n",
"0.863232023721275\n",
"Iteration: 1140\n",
"0.8630467012601928\n",
"Iteration: 1150\n",
"0.863232023721275\n",
"Iteration: 1160\n",
"0.8641586360266864\n",
"Iteration: 1170\n",
"0.8648999258710156\n",
"Iteration: 1180\n",
"0.8647146034099333\n",
"Iteration: 1190\n",
"0.8654558932542624\n",
"Iteration: 1200\n",
"0.8650852483320979\n",
"Iteration: 1210\n",
"0.8652705707931801\n",
"Iteration: 1220\n",
"0.8661971830985915\n",
"Iteration: 1230\n",
"0.8667531504818384\n",
"Iteration: 1240\n",
"0.8669384729429207\n",
"Iteration: 1250\n",
"0.8665678280207562\n",
"Iteration: 1260\n",
"0.8665678280207562\n",
"Iteration: 1270\n",
"0.8663825055596739\n",
"Iteration: 1280\n",
"0.865826538176427\n",
"Iteration: 1290\n",
"0.8665678280207562\n",
"Iteration: 1300\n",
"0.8682357301704967\n",
"Iteration: 1310\n",
"0.8700889547813195\n",
"Iteration: 1320\n",
"0.8721275018532246\n",
"Iteration: 1330\n",
"0.8723128243143069\n",
"Iteration: 1340\n",
"0.8723128243143069\n",
"Iteration: 1350\n",
"0.8702742772424018\n",
"Iteration: 1360\n",
"0.8699036323202373\n",
"Iteration: 1370\n",
"0.8680504077094143\n",
"Iteration: 1380\n",
"0.8680504077094143\n",
"Iteration: 1390\n",
"0.8691623424759081\n",
"Iteration: 1400\n",
"0.8713862120088954\n",
"Iteration: 1410\n",
"0.873054114158636\n",
"Iteration: 1420\n",
"0.874351371386212\n",
"Iteration: 1430\n",
"0.8758339510748703\n",
"Iteration: 1440\n",
"0.8763899184581171\n",
"Iteration: 1450\n",
"0.8763899184581171\n",
"Iteration: 1460\n",
"0.8762045959970348\n",
"Iteration: 1470\n",
"0.8745366938472943\n",
"Iteration: 1480\n",
"0.8724981467753892\n",
"Iteration: 1490\n",
"0.8702742772424018\n",
"Iteration: 1500\n",
"0.8710155670867309\n",
"Iteration: 1510\n",
"0.873054114158636\n",
"Iteration: 1520\n",
"0.8736100815418829\n",
"Iteration: 1530\n",
"0.8739807264640475\n",
"Iteration: 1540\n",
"0.8747220163083765\n",
"Iteration: 1550\n",
"0.8750926612305412\n",
"Iteration: 1560\n",
"0.8752779836916235\n",
"Iteration: 1570\n",
"0.8752779836916235\n",
"Iteration: 1580\n",
"0.8750926612305412\n",
"Iteration: 1590\n",
"0.8750926612305412\n",
"Iteration: 1600\n",
"0.8760192735359525\n",
"Iteration: 1610\n",
"0.876945885841364\n",
"Iteration: 1620\n",
"0.8775018532246108\n",
"Iteration: 1630\n",
"0.8778724981467754\n",
"Iteration: 1640\n",
"0.8784284655300222\n",
"Iteration: 1650\n",
"0.8782431430689399\n",
"Iteration: 1660\n",
"0.876945885841364\n",
"Iteration: 1670\n",
"0.8765752409191994\n",
"Iteration: 1680\n",
"0.8773165307635286\n",
"Iteration: 1690\n",
"0.8778724981467754\n",
"Iteration: 1700\n",
"0.8793550778354337\n",
"Iteration: 1710\n",
"0.8797257227575982\n",
"Iteration: 1720\n",
"0.8808376575240919\n",
"Iteration: 1730\n",
"0.8810229799851742\n",
"Iteration: 1740\n",
"0.8812083024462565\n",
"Iteration: 1750\n",
"0.8810229799851742\n",
"Iteration: 1760\n",
"0.8821349147516679\n",
"Iteration: 1770\n",
"0.8825055596738325\n",
"Iteration: 1780\n",
"0.8826908821349148\n",
"Iteration: 1790\n",
"0.882876204595997\n",
"Iteration: 1800\n",
"0.8830615270570793\n",
"Iteration: 1810\n",
"0.8832468495181616\n",
"Iteration: 1820\n",
"0.8834321719792438\n",
"Iteration: 1830\n",
"0.8825055596738325\n",
"Iteration: 1840\n",
"0.8821349147516679\n",
"Iteration: 1850\n",
"0.8817642698295033\n",
"Iteration: 1860\n",
"0.8826908821349148\n",
"Iteration: 1870\n",
"0.8843587842846553\n",
"Iteration: 1880\n",
"0.8851000741289844\n",
"Iteration: 1890\n",
"0.8852853965900667\n",
"Iteration: 1900\n",
"0.8856560415122313\n",
"Iteration: 1910\n",
"0.8856560415122313\n",
"Iteration: 1920\n",
"0.8858413639733136\n",
"Iteration: 1930\n",
"0.8865826538176427\n",
"Iteration: 1940\n",
"0.8871386212008896\n",
"Iteration: 1950\n",
"0.8869532987398072\n",
"Iteration: 1960\n",
"0.8873239436619719\n",
"Iteration: 1970\n",
"0.888065233506301\n",
"Iteration: 1980\n",
"0.8882505559673832\n",
"Iteration: 1990\n",
"0.8871386212008896\n",
"Iteration: 2000\n",
"0.885470719051149\n",
"Iteration: 2010\n",
"0.8865826538176427\n",
"Iteration: 2020\n",
"0.8878799110452187\n",
"Iteration: 2030\n",
"0.8888065233506302\n",
"Iteration: 2040\n",
"0.8888065233506302\n",
"Iteration: 2050\n",
"0.889362490733877\n",
"Iteration: 2060\n",
"0.8891771682727947\n",
"Iteration: 2070\n",
"0.888065233506301\n",
"Iteration: 2080\n",
"0.8886212008895478\n",
"Iteration: 2090\n",
"0.8891771682727947\n",
"Iteration: 2100\n",
"0.8904744255003706\n",
"Iteration: 2110\n",
"0.8910303928836175\n",
"Iteration: 2120\n",
"0.8908450704225352\n",
"Iteration: 2130\n",
"0.8910303928836175\n",
"Iteration: 2140\n",
"0.8914010378057821\n",
"Iteration: 2150\n",
"0.8917716827279466\n",
"Iteration: 2160\n",
"0.8915863602668643\n",
"Iteration: 2170\n",
"0.890659747961453\n",
"Iteration: 2180\n",
"0.8908450704225352\n",
"Iteration: 2190\n",
"0.8914010378057821\n",
"Iteration: 2200\n",
"0.8919570051890289\n",
"Iteration: 2210\n",
"0.8921423276501111\n",
"Iteration: 2220\n",
"0.8926982950333581\n",
"Iteration: 2230\n",
"0.8932542624166049\n",
"Iteration: 2240\n",
"0.8932542624166049\n",
"Iteration: 2250\n",
"0.8938102297998517\n",
"Iteration: 2260\n",
"0.8939955522609341\n",
"Iteration: 2270\n",
"0.8938102297998517\n",
"Iteration: 2280\n",
"0.8926982950333581\n",
"Iteration: 2290\n",
"0.8928836174944403\n",
"Iteration: 2300\n",
"0.8934395848776872\n",
"Iteration: 2310\n",
"0.8943661971830986\n",
"Iteration: 2320\n",
"0.8956634544106745\n",
"Iteration: 2330\n",
"0.89529280948851\n",
"Iteration: 2340\n",
"0.8954781319495922\n",
"Iteration: 2350\n",
"0.8954781319495922\n",
"Iteration: 2360\n",
"0.8960340993328392\n",
"Iteration: 2370\n",
"0.8964047442550037\n",
"Iteration: 2380\n",
"0.8964047442550037\n",
"Iteration: 2390\n",
"0.8964047442550037\n",
"Iteration: 2400\n",
"0.8964047442550037\n",
"Iteration: 2410\n",
"0.896590066716086\n",
"Iteration: 2420\n",
"0.896590066716086\n",
"Iteration: 2430\n",
"0.8969607116382505\n",
"Iteration: 2440\n",
"0.8969607116382505\n",
"Iteration: 2450\n",
"0.8971460340993328\n",
"Iteration: 2460\n",
"0.8973313565604151\n",
"Iteration: 2470\n",
"0.8977020014825797\n",
"Iteration: 2480\n",
"0.8980726464047443\n",
"Iteration: 2490\n",
"0.8982579688658265\n",
"Iteration: 2500\n",
"0.8986286137879911\n",
"Iteration: 2510\n",
"0.899184581171238\n",
"Iteration: 2520\n",
"0.899184581171238\n",
"Iteration: 2530\n",
"0.8989992587101556\n",
"Iteration: 2540\n",
"0.8989992587101556\n",
"Iteration: 2550\n",
"0.899184581171238\n",
"Iteration: 2560\n",
"0.8995552260934025\n",
"Iteration: 2570\n",
"0.8997405485544848\n",
"Iteration: 2580\n",
"0.8997405485544848\n",
"Iteration: 2590\n",
"0.9004818383988139\n",
"Iteration: 2600\n",
"0.9006671608598962\n",
"Iteration: 2610\n",
"0.9010378057820608\n",
"Iteration: 2620\n",
"0.9010378057820608\n",
"Iteration: 2630\n",
"0.9008524833209784\n",
"Iteration: 2640\n",
"0.9004818383988139\n",
"Iteration: 2650\n",
"0.9008524833209784\n",
"Iteration: 2660\n",
"0.9012231282431431\n",
"Iteration: 2670\n",
"0.9017790956263899\n",
"Iteration: 2680\n",
"0.9023350630096367\n",
"Iteration: 2690\n",
"0.902520385470719\n",
"Iteration: 2700\n",
"0.9027057079318014\n",
"Iteration: 2710\n",
"0.9028910303928837\n",
"Iteration: 2720\n",
"0.9030763528539659\n",
"Iteration: 2730\n",
"0.9030763528539659\n",
"Iteration: 2740\n",
"0.9030763528539659\n",
"Iteration: 2750\n",
"0.9027057079318014\n",
"Iteration: 2760\n",
"0.9015937731653076\n",
"Iteration: 2770\n",
"0.9015937731653076\n",
"Iteration: 2780\n",
"0.9014084507042254\n",
"Iteration: 2790\n",
"0.9015937731653076\n",
"Iteration: 2800\n",
"0.9030763528539659\n",
"Iteration: 2810\n",
"0.903817642698295\n",
"Iteration: 2820\n",
"0.9047442550037065\n",
"Iteration: 2830\n",
"0.9056708673091178\n",
"Iteration: 2840\n",
"0.906412157153447\n",
"Iteration: 2850\n",
"0.906412157153447\n",
"Iteration: 2860\n",
"0.9047442550037065\n",
"Iteration: 2870\n",
"0.9040029651593773\n",
"Iteration: 2880\n",
"0.9030763528539659\n",
"Iteration: 2890\n",
"0.903817642698295\n",
"Iteration: 2900\n",
"0.9049295774647887\n",
"Iteration: 2910\n",
"0.9047442550037065\n",
"Iteration: 2920\n",
"0.9053002223869533\n",
"Iteration: 2930\n",
"0.9058561897702001\n",
"Iteration: 2940\n",
"0.9065974796145293\n",
"Iteration: 2950\n",
"0.9073387694588584\n",
"Iteration: 2960\n",
"0.9086360266864344\n",
"Iteration: 2970\n",
"0.9091919940696812\n",
"Iteration: 2980\n",
"0.9091919940696812\n",
"Iteration: 2990\n",
"0.9091919940696812\n",
"Iteration: 3000\n",
"0.9095626389918459\n",
"Iteration: 3010\n",
"0.9073387694588584\n",
"Iteration: 3020\n",
"0.9027057079318014\n",
"Iteration: 3030\n",
"0.9023350630096367\n",
"Iteration: 3040\n",
"0.9049295774647887\n",
"Iteration: 3050\n",
"0.9062268346923648\n",
"Iteration: 3060\n",
"0.9080800593031876\n",
"Iteration: 3070\n",
"0.9080800593031876\n",
"Iteration: 3080\n",
"0.9082653817642699\n",
"Iteration: 3090\n",
"0.9082653817642699\n",
"Iteration: 3100\n",
"0.9088213491475167\n",
"Iteration: 3110\n",
"0.9097479614529281\n",
"Iteration: 3120\n",
"0.911045218680504\n",
"Iteration: 3130\n",
"0.9119718309859155\n",
"Iteration: 3140\n",
"0.91234247590808\n",
"Iteration: 3150\n",
"0.9125277983691623\n",
"Iteration: 3160\n",
"0.9117865085248332\n",
"Iteration: 3170\n",
"0.9086360266864344\n",
"Iteration: 3180\n",
"0.905114899925871\n",
"Iteration: 3190\n",
"0.9032616753150482\n",
"Iteration: 3200\n",
"0.9054855448480356\n",
"Iteration: 3210\n",
"0.9075240919199407\n",
"Iteration: 3220\n",
"0.9106745737583395\n",
"Iteration: 3230\n",
"0.9121571534469978\n",
"Iteration: 3240\n",
"0.9134544106745738\n",
"Iteration: 3250\n",
"0.9141957005189029\n",
"Iteration: 3260\n",
"0.9140103780578206\n",
"Iteration: 3270\n",
"0.9141957005189029\n",
"Iteration: 3280\n",
"0.9132690882134915\n",
"Iteration: 3290\n",
"0.9093773165307635\n",
"Iteration: 3300\n",
"0.9071534469977761\n",
"Iteration: 3310\n",
"0.9090066716085989\n",
"Iteration: 3320\n",
"0.9106745737583395\n",
"Iteration: 3330\n",
"0.911045218680504\n",
"Iteration: 3340\n",
"0.9127131208302446\n",
"Iteration: 3350\n",
"0.9141957005189029\n",
"Iteration: 3360\n",
"0.9145663454410674\n",
"Iteration: 3370\n",
"0.9147516679021498\n",
"Iteration: 3380\n",
"0.9149369903632321\n",
"Iteration: 3390\n",
"0.913639733135656\n",
"Iteration: 3400\n",
"0.9117865085248332\n",
"Iteration: 3410\n",
"0.9106745737583395\n",
"Iteration: 3420\n",
"0.9106745737583395\n",
"Iteration: 3430\n",
"0.911601186063751\n",
"Iteration: 3440\n",
"0.9141957005189029\n",
"Iteration: 3450\n",
"0.9153076352853966\n",
"Iteration: 3460\n",
"0.9169755374351372\n",
"Iteration: 3470\n",
"0.9171608598962194\n",
"Iteration: 3480\n",
"0.9173461823573017\n",
"Iteration: 3490\n",
"0.9153076352853966\n",
"Iteration: 3500\n",
"0.9132690882134915\n",
"Iteration: 3510\n",
"0.9119718309859155\n",
"Iteration: 3520\n",
"0.9121571534469978\n",
"Iteration: 3530\n",
"0.9140103780578206\n",
"Iteration: 3540\n",
"0.9149369903632321\n",
"Iteration: 3550\n",
"0.9167902149740549\n",
"Iteration: 3560\n",
"0.917531504818384\n",
"Iteration: 3570\n",
"0.9191994069681245\n",
"Iteration: 3580\n",
"0.9195700518902891\n",
"Iteration: 3590\n",
"0.9195700518902891\n",
"Iteration: 3600\n",
"0.9191994069681245\n",
"Iteration: 3610\n",
"0.9186434395848777\n",
"Iteration: 3620\n",
"0.9164195700518903\n",
"Iteration: 3630\n",
"0.911045218680504\n",
"Iteration: 3640\n",
"0.9106745737583395\n",
"Iteration: 3650\n",
"0.9128984432913269\n",
"Iteration: 3660\n",
"0.9160489251297257\n",
"Iteration: 3670\n",
"0.9184581171237954\n",
"Iteration: 3680\n",
"0.9190140845070423\n",
"Iteration: 3690\n",
"0.9204966641957005\n",
"Iteration: 3700\n",
"0.9206819866567828\n",
"Iteration: 3710\n",
"0.9203113417346183\n",
"Iteration: 3720\n",
"0.9203113417346183\n",
"Iteration: 3730\n",
"0.9197553743513713\n",
"Iteration: 3740\n",
"0.9195700518902891\n",
"Iteration: 3750\n",
"0.9156782802075611\n",
"Iteration: 3760\n",
"0.91234247590808\n",
"Iteration: 3770\n",
"0.9127131208302446\n",
"Iteration: 3780\n",
"0.9158636026686434\n",
"Iteration: 3790\n",
"0.9179021497405485\n",
"Iteration: 3800\n",
"0.9190140845070423\n",
"Iteration: 3810\n",
"0.920126019273536\n",
"Iteration: 3820\n",
"0.9206819866567828\n",
"Iteration: 3830\n",
"0.9214232765011119\n",
"Iteration: 3840\n",
"0.9212379540400296\n",
"Iteration: 3850\n",
"0.9212379540400296\n",
"Iteration: 3860\n",
"0.9210526315789473\n",
"Iteration: 3870\n",
"0.9193847294292068\n",
"Iteration: 3880\n",
"0.9169755374351372\n",
"Iteration: 3890\n",
"0.9154929577464789\n",
"Iteration: 3900\n",
"0.9164195700518903\n",
"Iteration: 3910\n",
"0.9180874722016308\n",
"Iteration: 3920\n",
"0.920126019273536\n",
"Iteration: 3930\n",
"0.9219792438843588\n",
"Iteration: 3940\n",
"0.9225352112676056\n",
"Iteration: 3950\n",
"0.9225352112676056\n",
"Iteration: 3960\n",
"0.9223498888065234\n",
"Iteration: 3970\n",
"0.9208673091178651\n",
"Iteration: 3980\n",
"0.9195700518902891\n",
"Iteration: 3990\n",
"0.9186434395848777\n"
]
}
],
"source": [
"# Train with learning rate 0.10 for 4000 full-batch iterations.\n",
"W1, b1, W2, b2 = gradient_descent(X_train, Y_train, alpha=0.10, iterations=4000)\n",
"\n",
"# Persist the accuracy history and the trained parameters. np.savez stores the\n",
"# positional arrays under the keys arr_0..arr_3 in cr_weights.npz.\n",
"df = pd.DataFrame(acc_store)\n",
"df.to_csv('cr_acc.csv', index=False)\n",
"np.savez(\"cr_weights\", W1, b1, W2, b2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11203f4e-4adf-4a47-a6e2-a8847f27f0cc",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
}
},
"nbformat": 4,
"nbformat_minor": 5
}