semantics/German_NN_C.ipynb
2024-09-26 17:23:23 -04:00

1098 lines
29 KiB
Text

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "ea3e9a09-3257-4b1a-9245-d42bbf88d06b",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd \n",
"from matplotlib import pyplot as plt \n",
"\n",
"# Read gtsrb_data_test.csv into a DataFrame. The next cell converts it to a NumPy array,\n",
"# transposes it, and takes row 0 as the labels and the remaining rows as the features.\n",
"# NOTE(review): the file name says 'test' but this is the data the network is trained on - confirm this is intended.\n",
"data = pd.read_csv('gtsrb_data_test.csv')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fe7852fc-5353-478d-9c7e-ad5fd251d963",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Convert to a plain NumPy array; m is the number of rows and n the number of columns.\n",
"data = np.array(data)\n",
"m, n = data.shape\n",
"\n",
"# Skip the first 1000 rows and transpose, so that row 0 holds column 0 of the CSV\n",
"# and every column of data_train is one sample.\n",
"data_train = data[1000:].T\n",
"\n",
"# Row 0 -> integer labels, remaining rows -> features.\n",
"Y_train, X_train = data_train[0].astype(int), data_train[1:]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "f3d23a27-d544-44ea-9291-5697fac2b8c2",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.5568628\n"
]
}
],
"source": []
},
{
"cell_type": "code",
"execution_count": 3,
"id": "42435b4b-6390-4c20-b04a-cec2730879fd",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def init_params(n_input=1024, n_hidden=10, n_output=43, seed=None):\n",
"    \"\"\"Create randomly initialised parameters for the two-layer network.\n",
"\n",
"    Every entry is drawn uniformly from [-0.5, 0.5).\n",
"\n",
"    Args:\n",
"        n_input: Number of input features per sample (default 1024).\n",
"        n_hidden: Number of hidden units (default 10).\n",
"        n_output: Number of output classes (default 43).\n",
"        seed: Optional seed. When given, a dedicated np.random.default_rng(seed)\n",
"            generator is used so the initialisation is reproducible. When None,\n",
"            the global NumPy random state is used.\n",
"\n",
"    Returns:\n",
"        Tuple (W1, b1, W2, b2) with shapes (n_hidden, n_input), (n_hidden, 1),\n",
"        (n_output, n_hidden) and (n_output, 1).\n",
"    \"\"\"\n",
"    if seed is None:\n",
"        uniform = np.random.rand\n",
"    else:\n",
"        rng = np.random.default_rng(seed)\n",
"\n",
"        def uniform(rows, cols):\n",
"            return rng.random((rows, cols))\n",
"\n",
"    # Draw order (W1, b1, W2, b2) is kept so a globally seeded run is unchanged.\n",
"    W1 = uniform(n_hidden, n_input) - 0.5\n",
"    b1 = uniform(n_hidden, 1) - 0.5\n",
"    W2 = uniform(n_output, n_hidden) - 0.5\n",
"    b2 = uniform(n_output, 1) - 0.5\n",
"    return W1, b1, W2, b2"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f9da376f-6b2a-4c7c-b392-c34f7dc7dee4",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def ReLU(Z):\n",
"    \"\"\"Rectified linear unit, applied element-wise.\n",
"\n",
"    Args:\n",
"        Z: Array (or scalar) of pre-activation values.\n",
"\n",
"    Returns:\n",
"        The element-wise maximum of Z and 0.\n",
"    \"\"\"\n",
"    rectified = np.maximum(Z, 0)\n",
"    return rectified"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6cd98c98-9f06-4362-a3cd-a0a9b168ea97",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def softmax(Z):\n",
"    \"\"\"Numerically stable softmax along axis 0.\n",
"\n",
"    Args:\n",
"        Z: Array of scores. Normalisation runs over axis 0, so for a 2-D input\n",
"            each column is normalised independently.\n",
"\n",
"    Returns:\n",
"        Array with the same shape as Z whose entries along axis 0 are\n",
"        non-negative and sum to 1.\n",
"    \"\"\"\n",
"    # Subtracting the per-column maximum does not change the result mathematically,\n",
"    # but keeps np.exp from overflowing to inf (which would yield inf / inf = nan).\n",
"    shifted = Z - np.max(Z, axis=0, keepdims=True)\n",
"    exp_scores = np.exp(shifted)\n",
"    return exp_scores / np.sum(exp_scores, axis=0, keepdims=True)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "039f4e5d-9d7c-4f59-81d6-f64310cc5b1f",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def forward_prop(W1, b1, W2, b2, X):\n",
"    \"\"\"Run one forward pass through the two-layer network.\n",
"\n",
"    Args:\n",
"        W1, b1: Weights and bias of the hidden layer.\n",
"        W2, b2: Weights and bias of the output layer.\n",
"        X: Input matrix; it is multiplied on the left by W1.\n",
"\n",
"    Returns:\n",
"        Tuple (Z1, A1, Z2, A2): hidden pre-activation, hidden ReLU activation,\n",
"        output pre-activation and output softmax activation.\n",
"    \"\"\"\n",
"    hidden_pre = np.dot(W1, X) + b1\n",
"    hidden_act = ReLU(hidden_pre)\n",
"    output_pre = np.dot(W2, hidden_act) + b2\n",
"    output_act = softmax(output_pre)\n",
"    return hidden_pre, hidden_act, output_pre, output_act"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "24588282-2ec1-4c64-9b9d-4e7d395449fb",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def ReLU_deriv(Z):\n",
"    \"\"\"Derivative of ReLU expressed as a mask.\n",
"\n",
"    Args:\n",
"        Z: Array of pre-activation values.\n",
"\n",
"    Returns:\n",
"        Result of the comparison Z > 0: True where Z is strictly positive,\n",
"        False elsewhere (including at exactly 0).\n",
"    \"\"\"\n",
"    is_active = Z > 0\n",
"    return is_active"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "d4977d9c-38c9-4758-9262-dd03b8bbd015",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def one_hot(Y, num_classes=None):\n",
"    \"\"\"One-hot encode integer labels, one column per sample.\n",
"\n",
"    Args:\n",
"        Y: 1-D integer array of class indices.\n",
"        num_classes: Number of rows (classes) in the result. When None it is\n",
"            inferred as Y.max() + 1. Pass it explicitly whenever the highest\n",
"            class may be absent from Y, otherwise the result has too few rows\n",
"            to line up with the network output.\n",
"\n",
"    Returns:\n",
"        Float array of shape (num_classes, Y.size) holding a 1 at\n",
"        [Y[i], i] for every sample i and 0 elsewhere.\n",
"    \"\"\"\n",
"    if num_classes is None:\n",
"        num_classes = Y.max() + 1\n",
"    one_hot_Y = np.zeros((Y.size, num_classes))\n",
"    # Fancy indexing sets exactly one entry per row (sample).\n",
"    one_hot_Y[np.arange(Y.size), Y] = 1\n",
"    # Transpose so that columns index samples, matching the layer activations.\n",
"    return one_hot_Y.T"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "941b2ec2-980f-4577-abe3-d13ac252e86b",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y):\n",
"    \"\"\"Compute the parameter gradients for one batch.\n",
"\n",
"    Args:\n",
"        Z1, A1: Hidden-layer pre-activation and activation from the forward pass.\n",
"        Z2, A2: Output-layer pre-activation and activation from the forward pass.\n",
"        W1: Hidden-layer weights. Not used here; kept so the signature is unchanged.\n",
"        W2: Output-layer weights.\n",
"        X: Input matrix used in the forward pass.\n",
"        Y: 1-D integer array of true class labels, one per sample.\n",
"\n",
"    Returns:\n",
"        Tuple (dW1, db1, dW2, db2). dW1 and dW2 have the shapes of W1 and W2;\n",
"        db1 and db2 are column vectors with one entry per unit.\n",
"    \"\"\"\n",
"    # Average over the samples actually in this batch rather than a notebook-level\n",
"    # global, which counts every row of the CSV including the ones left out of training.\n",
"    n_samples = Y.size\n",
"    one_hot_Y = one_hot(Y)\n",
"    dZ2 = A2 - one_hot_Y\n",
"    dW2 = 1 / n_samples * dZ2.dot(A1.T)\n",
"    # Sum over the sample axis only: each unit gets its own bias gradient instead of\n",
"    # one scalar shared by all units.\n",
"    db2 = 1 / n_samples * np.sum(dZ2, axis=1, keepdims=True)\n",
"    dZ1 = W2.T.dot(dZ2) * ReLU_deriv(Z1)\n",
"    dW1 = 1 / n_samples * dZ1.dot(X.T)\n",
"    db1 = 1 / n_samples * np.sum(dZ1, axis=1, keepdims=True)\n",
"    return dW1, db1, dW2, db2"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "20ab4235-86ac-4708-8b16-8d0169ecd97f",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha):\n",
"    \"\"\"Apply one gradient-descent step to every parameter.\n",
"\n",
"    Args:\n",
"        W1, b1, W2, b2: Current parameters.\n",
"        dW1, db1, dW2, db2: Matching gradients.\n",
"        alpha: Learning rate multiplying each gradient.\n",
"\n",
"    Returns:\n",
"        Tuple (W1, b1, W2, b2) where each entry is parameter - alpha * gradient.\n",
"    \"\"\"\n",
"    params = (W1, b1, W2, b2)\n",
"    grads = (dW1, db1, dW2, db2)\n",
"    stepped = tuple(param - alpha * grad for param, grad in zip(params, grads))\n",
"    return stepped"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "b82641d2-7ebb-4adb-9c40-3a6a28ccb49b",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def get_predictions(A2):\n",
"    \"\"\"Pick the highest-scoring class for every sample.\n",
"\n",
"    Args:\n",
"        A2: Output activations; axis 0 indexes the classes.\n",
"\n",
"    Returns:\n",
"        Indices of the maximum along axis 0.\n",
"    \"\"\"\n",
"    best_class = np.argmax(A2, axis=0)\n",
"    return best_class"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "2f38d2cc-316a-48f3-a275-57253249f132",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def get_accuracy(predictions, Y):\n",
"    \"\"\"Fraction of predictions that equal the true labels.\n",
"\n",
"    Args:\n",
"        predictions: Array of predicted class indices.\n",
"        Y: Array of true class indices; its size is the denominator.\n",
"\n",
"    Returns:\n",
"        Number of positions where predictions == Y, divided by Y.size.\n",
"    \"\"\"\n",
"    matches = predictions == Y\n",
"    n_correct = np.sum(matches)\n",
"    return n_correct / Y.size"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "c46ae58a-9362-4f7d-8ce4-d668230b1647",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Module-level accuracy history: gradient_descent appends each accuracy value it prints,\n",
"# and the training cell writes the list to gt_acc.csv.\n",
"acc_store = [] "
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "54095b38-ca0a-41e6-9e33-c7b484e73c86",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def gradient_descent(X, Y, alpha, iterations):\n",
"    \"\"\"Train the two-layer network with full-batch gradient descent.\n",
"\n",
"    Every 10th iteration the iteration number and the accuracy on (X, Y) are\n",
"    printed, and the accuracy is appended to the module-level acc_store list.\n",
"    The accuracy comes from the forward pass made before that iteration's update.\n",
"\n",
"    Args:\n",
"        X: Training inputs passed to forward_prop.\n",
"        Y: Integer training labels.\n",
"        alpha: Learning rate.\n",
"        iterations: Number of update steps to run.\n",
"\n",
"    Returns:\n",
"        Tuple (W1, b1, W2, b2) of trained parameters.\n",
"    \"\"\"\n",
"    W1, b1, W2, b2 = init_params()\n",
"    for step in range(iterations):\n",
"        Z1, A1, Z2, A2 = forward_prop(W1, b1, W2, b2, X)\n",
"        grads = backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y)\n",
"        W1, b1, W2, b2 = update_params(W1, b1, W2, b2, *grads, alpha)\n",
"        if step % 10 != 0:\n",
"            continue\n",
"        # Progress report for this logging step.\n",
"        print(\"Iteration: \", step)\n",
"        accuracy = get_accuracy(get_predictions(A2), Y)\n",
"        print(accuracy)\n",
"        acc_store.append(accuracy)\n",
"    return W1, b1, W2, b2"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "bfa77e8c-04e4-4d1e-85f0-ea6cdd71e924",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration: 0\n",
"0.025193465176268273\n",
"Iteration: 10\n",
"0.04428202923473775\n",
"Iteration: 20\n",
"0.04883920894239037\n",
"Iteration: 30\n",
"0.060791057609630265\n",
"Iteration: 40\n",
"0.07919174548581255\n",
"Iteration: 50\n",
"0.09501289767841789\n",
"Iteration: 60\n",
"0.10378331900257953\n",
"Iteration: 70\n",
"0.10902837489251935\n",
"Iteration: 80\n",
"0.1124677558039553\n",
"Iteration: 90\n",
"0.11401547721410146\n",
"Iteration: 100\n",
"0.11685296646603612\n",
"Iteration: 110\n",
"0.12158211521926053\n",
"Iteration: 120\n",
"0.12510748065348237\n",
"Iteration: 130\n",
"0.12768701633705934\n",
"Iteration: 140\n",
"0.1294926913155632\n",
"Iteration: 150\n",
"0.13181427343078245\n",
"Iteration: 160\n",
"0.13284608770421324\n",
"Iteration: 170\n",
"0.13404987102321583\n",
"Iteration: 180\n",
"0.13473774720550302\n",
"Iteration: 190\n",
"0.13680137575236456\n",
"Iteration: 200\n",
"0.13843508168529664\n",
"Iteration: 210\n",
"0.13938091143594153\n",
"Iteration: 220\n",
"0.13989681857265693\n",
"Iteration: 230\n",
"0.13989681857265693\n",
"Iteration: 240\n",
"0.1402407566638005\n",
"Iteration: 250\n",
"0.1404987102321582\n",
"Iteration: 260\n",
"0.1410146173688736\n",
"Iteration: 270\n",
"0.1419604471195185\n",
"Iteration: 280\n",
"0.14282029234737748\n",
"Iteration: 290\n",
"0.14462596732588134\n",
"Iteration: 300\n",
"0.1469475494411006\n",
"Iteration: 310\n",
"0.1476354256233878\n",
"Iteration: 320\n",
"0.14944110060189167\n",
"Iteration: 330\n",
"0.15150472914875324\n",
"Iteration: 340\n",
"0.1529664660361135\n",
"Iteration: 350\n",
"0.15657781599312123\n",
"Iteration: 360\n",
"0.15950128976784178\n",
"Iteration: 370\n",
"0.16362854686156492\n",
"Iteration: 380\n",
"0.16603611349957006\n",
"Iteration: 390\n",
"0.16913155631986243\n",
"Iteration: 400\n",
"0.17205503009458298\n",
"Iteration: 410\n",
"0.1761822871883061\n",
"Iteration: 420\n",
"0.17815993121238177\n",
"Iteration: 430\n",
"0.18185726569217542\n",
"Iteration: 440\n",
"0.18426483233018057\n",
"Iteration: 450\n",
"0.18538263112639725\n",
"Iteration: 460\n",
"0.18512467755803955\n",
"Iteration: 470\n",
"0.18512467755803955\n",
"Iteration: 480\n",
"0.1874462596732588\n",
"Iteration: 490\n",
"0.1935511607910576\n",
"Iteration: 500\n",
"0.19733447979363714\n",
"Iteration: 510\n",
"0.2006018916595013\n",
"Iteration: 520\n",
"0.20429922613929494\n",
"Iteration: 530\n",
"0.20834049871023216\n",
"Iteration: 540\n",
"0.21298366294067067\n",
"Iteration: 550\n",
"0.21496130696474636\n",
"Iteration: 560\n",
"0.21908856405846946\n",
"Iteration: 570\n",
"0.2238177128116939\n",
"Iteration: 580\n",
"0.2291487532244196\n",
"Iteration: 590\n",
"0.23293207222699913\n",
"Iteration: 600\n",
"0.23439380911435942\n",
"Iteration: 610\n",
"0.23843508168529665\n",
"Iteration: 620\n",
"0.24316423043852106\n",
"Iteration: 630\n",
"0.2465176268271711\n",
"Iteration: 640\n",
"0.2524505588993981\n",
"Iteration: 650\n",
"0.2578675838349097\n",
"Iteration: 660\n",
"0.2623387790197764\n",
"Iteration: 670\n",
"0.2658641444539983\n",
"Iteration: 680\n",
"0.27136715391229577\n",
"Iteration: 690\n",
"0.2776440240756664\n",
"Iteration: 700\n",
"0.2808254514187446\n",
"Iteration: 710\n",
"0.2844368013757524\n",
"Iteration: 720\n",
"0.2868443680137575\n",
"Iteration: 730\n",
"0.29191745485812554\n",
"Iteration: 740\n",
"0.294067067927773\n",
"Iteration: 750\n",
"0.2976784178847807\n",
"Iteration: 760\n",
"0.3004299226139295\n",
"Iteration: 770\n",
"0.30472914875322443\n",
"Iteration: 780\n",
"0.30902837489251933\n",
"Iteration: 790\n",
"0.3113499570077386\n",
"Iteration: 800\n",
"0.31478933791917457\n",
"Iteration: 810\n",
"0.3183147033533964\n",
"Iteration: 820\n",
"0.32089423903697334\n",
"Iteration: 830\n",
"0.3221840068787618\n",
"Iteration: 840\n",
"0.32493551160791057\n",
"Iteration: 850\n",
"0.32639724849527085\n",
"Iteration: 860\n",
"0.3266552020636285\n",
"Iteration: 870\n",
"0.32751504729148756\n",
"Iteration: 880\n",
"0.3289767841788478\n",
"Iteration: 890\n",
"0.3315563198624248\n",
"Iteration: 900\n",
"0.3339638865004299\n",
"Iteration: 910\n",
"0.33568357695614787\n",
"Iteration: 920\n",
"0.3374032674118659\n",
"Iteration: 930\n",
"0.33920894239036975\n",
"Iteration: 940\n",
"0.3423043852106621\n",
"Iteration: 950\n",
"0.3467755803955288\n",
"Iteration: 960\n",
"0.3480653482373173\n",
"Iteration: 970\n",
"0.35081685296646603\n",
"Iteration: 980\n",
"0.35537403267411866\n",
"Iteration: 990\n",
"0.358469475494411\n",
"Iteration: 1000\n",
"0.3609630266552021\n",
"Iteration: 1010\n",
"0.36328460877042135\n",
"Iteration: 1020\n",
"0.3650042992261393\n",
"Iteration: 1030\n",
"0.36723989681857266\n",
"Iteration: 1040\n",
"0.3699914015477214\n",
"Iteration: 1050\n",
"0.37265692175408427\n",
"Iteration: 1060\n",
"0.37463456577815996\n",
"Iteration: 1070\n",
"0.3761822871883061\n",
"Iteration: 1080\n",
"0.37850386930352536\n",
"Iteration: 1090\n",
"0.37962166809974207\n",
"Iteration: 1100\n",
"0.3816852966466036\n",
"Iteration: 1110\n",
"0.3840068787618229\n",
"Iteration: 1120\n",
"0.38564058469475493\n",
"Iteration: 1130\n",
"0.3885640584694755\n",
"Iteration: 1140\n",
"0.39165950128976784\n",
"Iteration: 1150\n",
"0.394926913155632\n",
"Iteration: 1160\n",
"0.39604471195184865\n",
"Iteration: 1170\n",
"0.39819432502149615\n",
"Iteration: 1180\n",
"0.40103181427343076\n",
"Iteration: 1190\n",
"0.40386930352536543\n",
"Iteration: 1200\n",
"0.404213241616509\n",
"Iteration: 1210\n",
"0.40739466895958726\n",
"Iteration: 1220\n",
"0.40937231298366294\n",
"Iteration: 1230\n",
"0.4111779879621668\n",
"Iteration: 1240\n",
"0.41496130696474637\n",
"Iteration: 1250\n",
"0.4168529664660361\n",
"Iteration: 1260\n",
"0.4199484092863285\n",
"Iteration: 1270\n",
"0.42115219260533104\n",
"Iteration: 1280\n",
"0.423731728288908\n",
"Iteration: 1290\n",
"0.42613929492691316\n",
"Iteration: 1300\n",
"0.4278589853826311\n",
"Iteration: 1310\n",
"0.4311263972484953\n",
"Iteration: 1320\n",
"0.4322441960447119\n",
"Iteration: 1330\n",
"0.4345657781599312\n",
"Iteration: 1340\n",
"0.435597592433362\n",
"Iteration: 1350\n",
"0.43697334479793637\n",
"Iteration: 1360\n",
"0.4391229578675838\n",
"Iteration: 1370\n",
"0.4410146173688736\n",
"Iteration: 1380\n",
"0.44273430782459156\n",
"Iteration: 1390\n",
"0.44428202923473775\n",
"Iteration: 1400\n",
"0.4460017196904557\n",
"Iteration: 1410\n",
"0.44823731728288907\n",
"Iteration: 1420\n",
"0.4504729148753224\n",
"Iteration: 1430\n",
"0.4522785898538263\n",
"Iteration: 1440\n",
"0.4543422184006879\n",
"Iteration: 1450\n",
"0.45649183147033534\n",
"Iteration: 1460\n",
"0.4585554600171969\n",
"Iteration: 1470\n",
"0.460189165950129\n",
"Iteration: 1480\n",
"0.46182287188306104\n",
"Iteration: 1490\n",
"0.46388650042992263\n",
"Iteration: 1500\n",
"0.466122098022356\n",
"Iteration: 1510\n",
"0.46732588134135855\n",
"Iteration: 1520\n",
"0.46844368013757526\n",
"Iteration: 1530\n",
"0.4696474634565778\n",
"Iteration: 1540\n",
"0.4706792777300086\n",
"Iteration: 1550\n",
"0.47231298366294067\n",
"Iteration: 1560\n",
"0.47386070507308686\n",
"Iteration: 1570\n",
"0.4760963026655202\n",
"Iteration: 1580\n",
"0.47807394668959585\n",
"Iteration: 1590\n",
"0.47910576096302665\n",
"Iteration: 1600\n",
"0.48134135855546\n",
"Iteration: 1610\n",
"0.4824591573516767\n",
"Iteration: 1620\n",
"0.4831470335339639\n",
"Iteration: 1630\n",
"0.48460877042132416\n",
"Iteration: 1640\n",
"0.4866723989681857\n",
"Iteration: 1650\n",
"0.4877042132416165\n",
"Iteration: 1660\n",
"0.48796216680997423\n",
"Iteration: 1670\n",
"0.48950988822012037\n",
"Iteration: 1680\n",
"0.4885640584694755\n",
"Iteration: 1690\n",
"0.458469475494411\n",
"Iteration: 1700\n",
"0.4817712811693895\n",
"Iteration: 1710\n",
"0.5031814273430782\n",
"Iteration: 1720\n",
"0.5036973344797936\n",
"Iteration: 1730\n",
"0.4990541702493551\n",
"Iteration: 1740\n",
"0.48658641444539985\n",
"Iteration: 1750\n",
"0.4834049871023216\n",
"Iteration: 1760\n",
"0.49234737747205504\n",
"Iteration: 1770\n",
"0.5004299226139295\n",
"Iteration: 1780\n",
"0.504213241616509\n",
"Iteration: 1790\n",
"0.5081685296646603\n",
"Iteration: 1800\n",
"0.5116938950988822\n",
"Iteration: 1810\n",
"0.5149613069647463\n",
"Iteration: 1820\n",
"0.5170249355116079\n",
"Iteration: 1830\n",
"0.5184866723989682\n",
"Iteration: 1840\n",
"0.5189165950128977\n",
"Iteration: 1850\n",
"0.5210662080825451\n",
"Iteration: 1860\n",
"0.5173688736027515\n",
"Iteration: 1870\n",
"0.5042992261392949\n",
"Iteration: 1880\n",
"0.4950988822012038\n",
"Iteration: 1890\n",
"0.48538263112639723\n",
"Iteration: 1900\n",
"0.4998280309544282\n",
"Iteration: 1910\n",
"0.5120378331900258\n",
"Iteration: 1920\n",
"0.5133276010318143\n",
"Iteration: 1930\n",
"0.5149613069647463\n",
"Iteration: 1940\n",
"0.5182287188306105\n",
"Iteration: 1950\n",
"0.5208942390369733\n",
"Iteration: 1960\n",
"0.5234737747205503\n",
"Iteration: 1970\n",
"0.5251074806534823\n",
"Iteration: 1980\n",
"0.5267411865864144\n",
"Iteration: 1990\n",
"0.5280309544282029\n",
"Iteration: 2000\n",
"0.530438521066208\n",
"Iteration: 2010\n",
"0.5311263972484953\n",
"Iteration: 2020\n",
"0.529664660361135\n",
"Iteration: 2030\n",
"0.525451418744626\n",
"Iteration: 2040\n",
"0.5149613069647463\n",
"Iteration: 2050\n",
"0.49140154772141015\n",
"Iteration: 2060\n",
"0.5261392949269131\n",
"Iteration: 2070\n",
"0.5496130696474635\n",
"Iteration: 2080\n",
"0.5528804815133276\n",
"Iteration: 2090\n",
"0.548237317282889\n",
"Iteration: 2100\n",
"0.5408426483233018\n",
"Iteration: 2110\n",
"0.5248495270851247\n",
"Iteration: 2120\n",
"0.5248495270851247\n",
"Iteration: 2130\n",
"0.5340498710232158\n",
"Iteration: 2140\n",
"0.5439380911435941\n",
"Iteration: 2150\n",
"0.550816852966466\n",
"Iteration: 2160\n",
"0.5544282029234737\n",
"Iteration: 2170\n",
"0.5510748065348238\n",
"Iteration: 2180\n",
"0.5514187446259673\n",
"Iteration: 2190\n",
"0.5385210662080826\n",
"Iteration: 2200\n",
"0.542304385210662\n",
"Iteration: 2210\n",
"0.5588134135855546\n",
"Iteration: 2220\n",
"0.5655202063628547\n",
"Iteration: 2230\n",
"0.5571797076526225\n",
"Iteration: 2240\n",
"0.5384350816852966\n",
"Iteration: 2250\n",
"0.5430782459157352\n",
"Iteration: 2260\n",
"0.5548581255374032\n",
"Iteration: 2270\n",
"0.5577815993121238\n",
"Iteration: 2280\n",
"0.55932932072227\n",
"Iteration: 2290\n",
"0.5602751504729149\n",
"Iteration: 2300\n",
"0.561049011177988\n",
"Iteration: 2310\n",
"0.5584694754944111\n",
"Iteration: 2320\n",
"0.5538263112639725\n",
"Iteration: 2330\n",
"0.5439380911435941\n",
"Iteration: 2340\n",
"0.5379191745485813\n",
"Iteration: 2350\n",
"0.5466895958727429\n",
"Iteration: 2360\n",
"0.5522785898538263\n",
"Iteration: 2370\n",
"0.5636285468615649\n",
"Iteration: 2380\n",
"0.574548581255374\n",
"Iteration: 2390\n",
"0.5803955288048152\n",
"Iteration: 2400\n",
"0.5831470335339639\n",
"Iteration: 2410\n",
"0.5859845227858985\n",
"Iteration: 2420\n",
"0.5879621668099742\n",
"Iteration: 2430\n",
"0.5848667239896819\n",
"Iteration: 2440\n",
"0.5764402407566638\n",
"Iteration: 2450\n",
"0.572828890799656\n",
"Iteration: 2460\n",
"0.5832330180567498\n",
"Iteration: 2470\n",
"0.5879621668099742\n",
"Iteration: 2480\n",
"0.576268271711092\n",
"Iteration: 2490\n",
"0.5756663800515908\n",
"Iteration: 2500\n",
"0.582201203783319\n",
"Iteration: 2510\n",
"0.5882201203783319\n",
"Iteration: 2520\n",
"0.5885640584694755\n",
"Iteration: 2530\n",
"0.5852106620808255\n",
"Iteration: 2540\n",
"0.5783319002579536\n",
"Iteration: 2550\n",
"0.5766122098022356\n",
"Iteration: 2560\n",
"0.5803095442820292\n",
"Iteration: 2570\n",
"0.585640584694755\n",
"Iteration: 2580\n",
"0.5918314703353397\n",
"Iteration: 2590\n",
"0.585640584694755\n",
"Iteration: 2600\n",
"0.5929492691315563\n",
"Iteration: 2610\n",
"0.605159071367154\n",
"Iteration: 2620\n",
"0.6088564058469476\n",
"Iteration: 2630\n",
"0.6080825451418744\n",
"Iteration: 2640\n",
"0.6078245915735168\n",
"Iteration: 2650\n",
"0.6065348237317283\n",
"Iteration: 2660\n",
"0.6094582975064489\n",
"Iteration: 2670\n",
"0.6084264832330181\n",
"Iteration: 2680\n",
"0.6133276010318143\n",
"Iteration: 2690\n",
"0.6077386070507309\n",
"Iteration: 2700\n",
"0.5948409286328461\n",
"Iteration: 2710\n",
"0.6030954428202924\n",
"Iteration: 2720\n",
"0.6112639724849527\n",
"Iteration: 2730\n",
"0.6117798796216681\n",
"Iteration: 2740\n",
"0.6086844368013757\n",
"Iteration: 2750\n",
"0.6091143594153052\n",
"Iteration: 2760\n",
"0.6121238177128117\n",
"Iteration: 2770\n",
"0.6153912295786759\n",
"Iteration: 2780\n",
"0.6151332760103182\n",
"Iteration: 2790\n",
"0.616938950988822\n",
"Iteration: 2800\n",
"0.6122957867583835\n",
"Iteration: 2810\n",
"0.6045571797076527\n",
"Iteration: 2820\n",
"0.6171109200343938\n",
"Iteration: 2830\n",
"0.6277730008598452\n",
"Iteration: 2840\n",
"0.6288907996560619\n",
"Iteration: 2850\n",
"0.6271711092003439\n",
"Iteration: 2860\n",
"0.6308684436801376\n",
"Iteration: 2870\n",
"0.6312983662940671\n",
"Iteration: 2880\n",
"0.6307824591573516\n",
"Iteration: 2890\n",
"0.6348237317282889\n",
"Iteration: 2900\n",
"0.6386070507308684\n",
"Iteration: 2910\n",
"0.6353396388650043\n",
"Iteration: 2920\n",
"0.6213241616509029\n",
"Iteration: 2930\n",
"0.6240756663800516\n",
"Iteration: 2940\n",
"0.6345657781599312\n",
"Iteration: 2950\n",
"0.6372312983662941\n",
"Iteration: 2960\n",
"0.6330180567497851\n",
"Iteration: 2970\n",
"0.6343938091143594\n",
"Iteration: 2980\n",
"0.6374892519346518\n",
"Iteration: 2990\n",
"0.6381771281169389\n",
"Iteration: 3000\n",
"0.6404987102321582\n",
"Iteration: 3010\n",
"0.6439380911435941\n",
"Iteration: 3020\n",
"0.6418744625967326\n",
"Iteration: 3030\n",
"0.6346517626827171\n",
"Iteration: 3040\n",
"0.6258813413585554\n",
"Iteration: 3050\n",
"0.6392949269131556\n",
"Iteration: 3060\n",
"0.6469475494411006\n",
"Iteration: 3070\n",
"0.64737747205503\n",
"Iteration: 3080\n",
"0.6445399828030954\n",
"Iteration: 3090\n",
"0.6468615649183147\n",
"Iteration: 3100\n",
"0.6502149613069648\n",
"Iteration: 3110\n",
"0.6527944969905417\n",
"Iteration: 3120\n",
"0.6552880481513328\n",
"Iteration: 3130\n",
"0.6583834909716251\n",
"Iteration: 3140\n",
"0.660103181427343\n",
"Iteration: 3150\n",
"0.6502149613069648\n",
"Iteration: 3160\n",
"0.6441960447119518\n",
"Iteration: 3170\n",
"0.6499570077386071\n",
"Iteration: 3180\n",
"0.6577815993121238\n",
"Iteration: 3190\n",
"0.654170249355116\n",
"Iteration: 3200\n",
"0.6540842648323302\n",
"Iteration: 3210\n",
"0.6580395528804815\n",
"Iteration: 3220\n",
"0.6615649183147033\n",
"Iteration: 3230\n",
"0.6634565778159931\n",
"Iteration: 3240\n",
"0.6682717110920035\n",
"Iteration: 3250\n",
"0.6688736027515048\n",
"Iteration: 3260\n",
"0.6673258813413585\n",
"Iteration: 3270\n",
"0.6582115219260533\n",
"Iteration: 3280\n",
"0.6464316423043852\n",
"Iteration: 3290\n",
"0.6527944969905417\n",
"Iteration: 3300\n",
"0.6649183147033534\n",
"Iteration: 3310\n",
"0.6658641444539983\n",
"Iteration: 3320\n",
"0.6637145313843508\n",
"Iteration: 3330\n",
"0.6682717110920035\n",
"Iteration: 3340\n",
"0.6712811693895099\n",
"Iteration: 3350\n",
"0.671969045571797\n",
"Iteration: 3360\n",
"0.6722269991401548\n",
"Iteration: 3370\n",
"0.6730008598452278\n",
"Iteration: 3380\n",
"0.6782459157351677\n",
"Iteration: 3390\n",
"0.6794496990541703\n",
"Iteration: 3400\n",
"0.6723129836629407\n",
"Iteration: 3410\n",
"0.6655202063628547\n",
"Iteration: 3420\n",
"0.6648323301805675\n",
"Iteration: 3430\n",
"0.677128116938951\n",
"Iteration: 3440\n",
"0.6760963026655202\n",
"Iteration: 3450\n",
"0.6699054170249356\n",
"Iteration: 3460\n",
"0.6776440240756664\n",
"Iteration: 3470\n",
"0.6831470335339639\n",
"Iteration: 3480\n",
"0.6845227858985382\n",
"Iteration: 3490\n",
"0.6829750644883921\n",
"Iteration: 3500\n",
"0.6809114359415305\n",
"Iteration: 3510\n",
"0.683061049011178\n",
"Iteration: 3520\n",
"0.6860705073086845\n",
"Iteration: 3530\n",
"0.6862424763542563\n",
"Iteration: 3540\n",
"0.6808254514187446\n",
"Iteration: 3550\n",
"0.6705932932072227\n",
"Iteration: 3560\n",
"0.6765262252794497\n",
"Iteration: 3570\n",
"0.6884780739466896\n",
"Iteration: 3580\n",
"0.6907996560619089\n",
"Iteration: 3590\n",
"0.6785898538263113\n",
"Iteration: 3600\n",
"0.6792777300085985\n",
"Iteration: 3610\n",
"0.6947549441100602\n",
"Iteration: 3620\n",
"0.7036113499570077\n",
"Iteration: 3630\n",
"0.7007738607050731\n",
"Iteration: 3640\n",
"0.6918314703353396\n",
"Iteration: 3650\n",
"0.6798796216680998\n",
"Iteration: 3660\n",
"0.68469475494411\n",
"Iteration: 3670\n",
"0.6921754084264833\n",
"Iteration: 3680\n",
"0.697506448839209\n",
"Iteration: 3690\n",
"0.6960447119518487\n",
"Iteration: 3700\n",
"0.6905417024935512\n",
"Iteration: 3710\n",
"0.6876182287188306\n",
"Iteration: 3720\n",
"0.697506448839209\n",
"Iteration: 3730\n",
"0.7049011177987962\n",
"Iteration: 3740\n",
"0.7035253654342218\n",
"Iteration: 3750\n",
"0.6846087704213242\n",
"Iteration: 3760\n",
"0.6799656061908856\n",
"Iteration: 3770\n",
"0.6952708512467756\n",
"Iteration: 3780\n",
"0.7100601891659502\n",
"Iteration: 3790\n",
"0.7192605331040413\n",
"Iteration: 3800\n",
"0.7217540842648323\n",
"Iteration: 3810\n",
"0.7165950128976785\n",
"Iteration: 3820\n",
"0.698280309544282\n",
"Iteration: 3830\n",
"0.6799656061908856\n",
"Iteration: 3840\n",
"0.678761822871883\n",
"Iteration: 3850\n",
"0.693207222699914\n",
"Iteration: 3860\n",
"0.7019776440240757\n",
"Iteration: 3870\n",
"0.7073946689595872\n",
"Iteration: 3880\n",
"0.7070507308684437\n",
"Iteration: 3890\n",
"0.7057609630266553\n",
"Iteration: 3900\n",
"0.7104041272570937\n",
"Iteration: 3910\n",
"0.7171109200343938\n",
"Iteration: 3920\n",
"0.7189165950128977\n",
"Iteration: 3930\n",
"0.7104041272570937\n",
"Iteration: 3940\n",
"0.691487532244196\n",
"Iteration: 3950\n",
"0.6880481513327601\n",
"Iteration: 3960\n",
"0.7095442820292347\n",
"Iteration: 3970\n",
"0.722957867583835\n",
"Iteration: 3980\n",
"0.7315563198624248\n",
"Iteration: 3990\n",
"0.7358555460017197\n"
]
}
],
"source": [
"# Start from an empty history so that re-running this cell does not append a second\n",
"# run's accuracies to the first and write the mixed list to gt_acc.csv.\n",
"acc_store.clear()\n",
"\n",
"# Train for 4000 iterations with learning rate 0.10.\n",
"W1, b1, W2, b2 = gradient_descent(X_train, Y_train, 0.10, 4000)\n",
"\n",
"# Persist the accuracy values recorded during training, one row per logged value.\n",
"df = pd.DataFrame(acc_store)\n",
"df.to_csv('gt_acc.csv', index=False)\n",
"\n",
"# Save the trained parameters; np.savez stores positional arrays as arr_0..arr_3\n",
"# and appends the .npz extension to the file name.\n",
"np.savez(\"gt_weights\", W1, b1, W2, b2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c8c9bb8d-0379-4b46-941b-410bf0da29ef",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
}
},
"nbformat": 4,
"nbformat_minor": 5
}