423 lines
10 KiB
Text
423 lines
10 KiB
Text
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "407c9473",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "14cccaae-d3b6-4ae5-a28a-5fed4b998783",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"def init_params(n_in=1024, n_hidden=10, n_out=61, rng=None):\n",
"    \"\"\"Randomly initialise the two-layer network's parameters.\n",
"\n",
"    Every entry is drawn uniformly from [-0.5, 0.5).\n",
"\n",
"    Args:\n",
"        n_in: input features per sample (default 1024).\n",
"        n_hidden: units in the hidden ReLU layer (default 10).\n",
"        n_out: number of output classes (default 61).\n",
"        rng: optional np.random.Generator for reproducible draws; when None\n",
"            the global np.random state is used, exactly as before.\n",
"\n",
"    Returns:\n",
"        (W1, b1, W2, b2) with shapes (n_hidden, n_in), (n_hidden, 1),\n",
"        (n_out, n_hidden) and (n_out, 1).\n",
"    \"\"\"\n",
"    def rand(*shape):\n",
"        # Global legacy RNG unless an explicit Generator was supplied.\n",
"        return np.random.rand(*shape) if rng is None else rng.random(shape)\n",
"\n",
"    W1 = rand(n_hidden, n_in) - 0.5\n",
"    b1 = rand(n_hidden, 1) - 0.5\n",
"    W2 = rand(n_out, n_hidden) - 0.5\n",
"    b2 = rand(n_out, 1) - 0.5\n",
"    return W1, b1, W2, b2\n",
|
"def ReLU(Z):\n",
"    \"\"\"Rectified linear unit: element-wise max(Z, 0).\"\"\"\n",
"    rectified = np.maximum(Z, 0)\n",
"    return rectified\n",
|
"def softmax(Z):\n",
"    \"\"\"Column-wise softmax: each column of Z becomes a probability vector.\n",
"\n",
"    The per-column maximum is subtracted before exponentiating. This leaves\n",
"    the result mathematically unchanged but keeps np.exp from overflowing to\n",
"    inf (and the division from producing NaN) for large logits.\n",
"    \"\"\"\n",
"    shifted = Z - np.max(Z, axis=0, keepdims=True)\n",
"    exps = np.exp(shifted)\n",
"    A = exps / np.sum(exps, axis=0, keepdims=True)\n",
"    return A\n",
|
"def forward_prop(W1, b1, W2, b2, X):\n",
"    \"\"\"Run the two-layer network forward on the columns of X.\n",
"\n",
"    Returns, in order: hidden pre-activation Z1, hidden activation\n",
"    A1 = ReLU(Z1), output logits Z2 and output probabilities softmax(Z2).\n",
"    \"\"\"\n",
"    Z1 = np.dot(W1, X) + b1\n",
"    A1 = ReLU(Z1)\n",
"    Z2 = np.dot(W2, A1) + b2\n",
"    return Z1, A1, Z2, softmax(Z2)\n",
|
"def ReLU_deriv(Z):\n",
"    \"\"\"Derivative of ReLU as a boolean mask: True where Z is strictly positive.\"\"\"\n",
"    mask = 0 < Z\n",
"    return mask\n",
|
"def one_hot(Y, num_classes=None):\n",
"    \"\"\"One-hot encode integer labels as a (num_classes, Y.size) matrix.\n",
"\n",
"    Column i holds a single 1 in row Y[i]. When num_classes is None the row\n",
"    count is inferred as Y.max() + 1 (the original behaviour). Pass the\n",
"    model's output size so the result always lines up with A2, even when a\n",
"    batch does not contain the highest class label.\n",
"    \"\"\"\n",
"    if num_classes is None:\n",
"        num_classes = Y.max() + 1\n",
"    one_hot_Y = np.zeros((Y.size, num_classes))\n",
"    one_hot_Y[np.arange(Y.size), Y] = 1\n",
"    return one_hot_Y.T\n",
|
"def backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y):\n",
"    \"\"\"Back-propagate the softmax cross-entropy loss through the network.\n",
"\n",
"    Args:\n",
"        Z1, A1, Z2, A2: cached values returned by forward_prop for X.\n",
"        W1, W2: current weights (Z2 and W1 are accepted for interface\n",
"            compatibility but are not needed by the gradient formulas).\n",
"        X: inputs, one sample per column.\n",
"        Y: integer class labels, one per column of X.\n",
"\n",
"    Returns:\n",
"        (dW1, db1, dW2, db2), each shaped like the matching parameter and\n",
"        averaged over the batch.\n",
"    \"\"\"\n",
"    # Batch size comes from the labels themselves, not a global `m`: the\n",
"    # dataset cells reassign `m` to the CSV row count.\n",
"    m = Y.size\n",
"    # One-hot targets built with A2's row count so they always line up with\n",
"    # the network output, even if the batch lacks the highest label.\n",
"    one_hot_Y = np.zeros_like(A2)\n",
"    one_hot_Y[Y, np.arange(m)] = 1\n",
"    dZ2 = A2 - one_hot_Y\n",
"    dW2 = dZ2.dot(A1.T) / m\n",
"    # Sum over samples only (axis=1) so each output unit gets its own bias gradient.\n",
"    db2 = np.sum(dZ2, axis=1, keepdims=True) / m\n",
"    dZ1 = W2.T.dot(dZ2) * ReLU_deriv(Z1)\n",
"    dW1 = dZ1.dot(X.T) / m\n",
"    db1 = np.sum(dZ1, axis=1, keepdims=True) / m\n",
"    return dW1, db1, dW2, db2\n",
|
"def update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha):\n",
"    \"\"\"Take one plain gradient-descent step of size alpha.\n",
"\n",
"    New arrays are returned (inputs are not modified in place), in the\n",
"    order W1, b1, W2, b2.\n",
"    \"\"\"\n",
"    params = (W1, b1, W2, b2)\n",
"    grads = (dW1, db1, dW2, db2)\n",
"    return tuple(p - alpha * g for p, g in zip(params, grads))\n",
|
"def get_predictions(A2):\n",
"    \"\"\"Row index of the largest entry in each column of A2 (the predicted class).\"\"\"\n",
"    predicted = np.argmax(A2, axis=0)\n",
"    return predicted\n",
|
"def get_accuracy(predictions, Y):\n",
"    \"\"\"Fraction of positions where predictions equals the true label Y.\"\"\"\n",
"    n_correct = np.sum(predictions == Y)\n",
"    return n_correct / Y.size"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "ec251927-46fc-413d-abb8-34fafb5a429d",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import pandas as pd\n",
|
|
"#import data_ready as dr\n",
|
|
"import os\n",
|
|
"import struct\n",
|
|
"import numpy as np\n",
|
|
"from matplotlib import pyplot as plt \n",
|
|
"\n",
|
"# A disabled loader for a single weights.npz (keys arr_0..arr_3) used to sit\n",
"# here as a no-op string literal; the per-dataset weights are loaded from\n",
"# bt_/cr_/gt_weights.npz in the cells below.\n",
"\n",
"def encode_image(X, W1, b1, W2, b2):\n",
"    \"\"\"Encode image(s) as the network's softmax output.\n",
"\n",
"    Args:\n",
"        X: input sample(s), one per column, as produced by\n",
"            X_train[:, i, None] in the cells below.\n",
"        W1, b1, W2, b2: trained parameters for the matching dataset.\n",
"\n",
"    Returns:\n",
"        A2, the output of forward_prop's final softmax layer (one column\n",
"        per sample in X).\n",
"    \"\"\"\n",
"    _, _, _, A2 = forward_prop(W1, b1, W2, b2, X)\n",
"    return A2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "ff4758bc-7e5d-47ba-a3d4-aa4afde6165f",
|
|
"metadata": {},
|
|
"source": [
|
"### Load the Weights"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "d323d295-5f29-4233-975c-1d5eab88a830",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"# Read each archive eagerly into a plain dict of arrays and close the file:\n",
"# np.load on an .npz returns a lazily-loading NpzFile that keeps its file\n",
"# handle open until it is closed. Later cells index these by 'arr_0' etc.\n",
"with np.load(\"bt_weights.npz\") as npz_file:\n",
"    bt_npz = {key: npz_file[key] for key in npz_file.files}\n",
"with np.load(\"cr_weights.npz\") as npz_file:\n",
"    cr_npz = {key: npz_file[key] for key in npz_file.files}\n",
"with np.load(\"gt_weights.npz\") as npz_file:\n",
"    gt_npz = {key: npz_file[key] for key in npz_file.files}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "62160a9e-f166-47d1-88af-cc767385d09f",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"# np.savez names positional arrays arr_0, arr_1, ...; presumably they were\n",
"# saved in the order (W1, b1, W2, b2) - confirm against the training code.\n",
"W1_bt, b1_bt, W2_bt, b2_bt = [np.array(bt_npz[key])\n",
"                              for key in ('arr_0', 'arr_1', 'arr_2', 'arr_3')]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "daa2d8e0-8f76-436d-9e3d-dfb711808e43",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"# np.savez names positional arrays arr_0, arr_1, ...; presumably they were\n",
"# saved in the order (W1, b1, W2, b2) - confirm against the training code.\n",
"W1_cr, b1_cr, W2_cr, b2_cr = [np.array(cr_npz[key])\n",
"                              for key in ('arr_0', 'arr_1', 'arr_2', 'arr_3')]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "17e3cd24-21b1-440c-8e38-f91597368771",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"# np.savez names positional arrays arr_0, arr_1, ...; presumably they were\n",
"# saved in the order (W1, b1, W2, b2) - confirm against the training code.\n",
"W1_gt, b1_gt, W2_gt, b2_gt = [np.array(gt_npz[key])\n",
"                              for key in ('arr_0', 'arr_1', 'arr_2', 'arr_3')]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "4ae55872-8e28-419d-85af-1ef185a254ce",
|
|
"metadata": {},
|
|
"source": [
|
"### Load the Datasets"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "0f230135-264d-4b89-b0a6-cab229ed0047",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"# Rows before index 1000 are skipped (presumably held out - confirm). After\n",
"# the transpose each column is one sample: row 0 is the label, the remaining\n",
"# rows are the features fed to the network.\n",
"# NOTE(review): features are used unscaled - confirm this matches the\n",
"# preprocessing used when the weights were trained.\n",
"datac = np.array(pd.read_csv('cro_data_test.csv'))\n",
"m, n = datac.shape\n",
"\n",
"data_trainc = datac[1000:].T\n",
"Y_trainc = data_trainc[0].astype(int)\n",
"X_trainc = data_trainc[1:n]\n",
"\n",
"current_image_c = X_trainc[:, 1:2]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "1f2a757a-4c29-4271-a895-03415765b105",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"# Rows before index 1000 are skipped (presumably held out - confirm). After\n",
"# the transpose each column is one sample: row 0 is the label, the remaining\n",
"# rows are the features fed to the network.\n",
"# NOTE(review): features are used unscaled - confirm this matches the\n",
"# preprocessing used when the weights were trained.\n",
"datag = np.array(pd.read_csv('gtsrb_data_test.csv'))\n",
"m, n = datag.shape\n",
"\n",
"data_traing = datag[1000:].T\n",
"Y_traing = data_traing[0].astype(int)\n",
"X_traing = data_traing[1:n]\n",
"\n",
"current_image_g = X_traing[:, 1:2]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "7026b28c-ef29-4706-9e9b-ed2417ab06eb",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"# Rows before index 1000 are skipped (presumably held out - confirm). After\n",
"# the transpose each column is one sample: row 0 is the label, the remaining\n",
"# rows are the features fed to the network.\n",
"# NOTE(review): features are used unscaled - confirm this matches the\n",
"# preprocessing used when the weights were trained.\n",
"datab = pd.read_csv('bel_data_test.csv')\n",
"datab = np.array(datab)\n",
"\n",
"m, n = datab.shape\n",
"data_trainb = datab[1000:m].T\n",
"\n",
"Y_trainb = data_trainb[0].astype(int)\n",
"X_trainb = data_trainb[1:n]\n",
"\n",
"# Take the sample from this (bel) dataset; it previously came from X_trainc.\n",
"current_image_b = X_trainb[:, 1, None]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "c0552a80-ab8f-4ac0-ba4b-bc3ab26a1944",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Encoding 1 Image"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "3daa565f-64f0-4294-9c36-bb87d1904ad0",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"# Encode one cro sample with the cr-trained weights.\n",
"c1 = encode_image(X=current_image_c, W1=W1_cr, b1=b1_cr, W2=W2_cr, b2=b2_cr)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "cd9d9223-d626-43e9-86d7-ef9bd24e0bc7",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"# NB: b1 here holds the encoding produced with the bt weights, not a bias\n",
"# vector (the biases are b1_bt / b1_cr / b1_gt).\n",
"b1 = encode_image(X=current_image_b, W1=W1_bt, b1=b1_bt, W2=W2_bt, b2=b2_bt)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "22443477-852f-4978-ae65-6883ed9d1e6b",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"# Encode one gtsrb sample with the gt-trained weights.\n",
"g1 = encode_image(X=current_image_g, W1=W1_gt, b1=b1_gt, W2=W2_gt, b2=b2_gt)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "b33a151f-d6b1-4386-8fa8-7e88da8951c8",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"# np.save does not create missing folders, so make sure the target exists.\n",
"os.makedirs('data/Single_Encoding', exist_ok=True)\n",
"np.save('data/Single_Encoding/pred_c', c1)\n",
"np.save('data/Single_Encoding/pred_b', b1)\n",
"np.save('data/Single_Encoding/pred_g', g1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "a0df95db-6f10-4f6e-a043-0a82b6e3763b",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Encoding 900 Images"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 44,
|
|
"id": "002bd5b5-aead-413b-8051-64326ebef595",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"# np.save does not create missing folders, so make sure the target exists.\n",
"os.makedirs('data/9_G', exist_ok=True)\n",
"for x in range(900):\n",
"    current_image_g = X_traing[:, x, None]\n",
"    g1 = encode_image(current_image_g, W1_gt, b1_gt, W2_gt, b2_gt)\n",
"    np.save('data/9_G/pred_g' + str(x), g1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 45,
|
|
"id": "6513e7f4-ef60-4d96-8576-540d4ef91c7b",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"# np.save does not create missing folders, so make sure the target exists.\n",
"os.makedirs('data/9_B', exist_ok=True)\n",
"for x in range(900):\n",
"    current_image_b = X_trainb[:, x, None]\n",
"    b1 = encode_image(current_image_b, W1_bt, b1_bt, W2_bt, b2_bt)\n",
"    np.save('data/9_B/pred_b' + str(x), b1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 46,
|
|
"id": "fc9144a0-b415-463d-bcc7-5ebaed0467ce",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"# np.save does not create missing folders, so make sure the target exists.\n",
"os.makedirs('data/9_C', exist_ok=True)\n",
"for x in range(900):\n",
"    current_image_c = X_trainc[:, x, None]\n",
"    c1 = encode_image(current_image_c, W1_cr, b1_cr, W2_cr, b2_cr)\n",
"    np.save('data/9_C/pred_c' + str(x), c1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "4fbd5222-e2b9-4c1d-82a0-6294136a2a71",
|
|
"metadata": {},
|
|
"source": [
|
"# Encoding 1800 Images"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 47,
|
|
"id": "00eba2e4-b5a1-4905-a20f-42d71eebaeff",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"# np.save does not create missing folders, so make sure the target exists.\n",
"os.makedirs('data/18_G', exist_ok=True)\n",
"for x in range(1800):\n",
"    current_image_g = X_traing[:, x, None]\n",
"    g1 = encode_image(current_image_g, W1_gt, b1_gt, W2_gt, b2_gt)\n",
"    np.save('data/18_G/pred_g' + str(x), g1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 48,
|
|
"id": "ee1bd756-09c1-401f-8301-95357ae52446",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"# np.save does not create missing folders, so make sure the target exists.\n",
"os.makedirs('data/18_B', exist_ok=True)\n",
"for x in range(1800):\n",
"    current_image_b = X_trainb[:, x, None]\n",
"    b1 = encode_image(current_image_b, W1_bt, b1_bt, W2_bt, b2_bt)\n",
"    np.save('data/18_B/pred_b' + str(x), b1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 49,
|
|
"id": "65dfac08-626e-47ac-9d35-15b794998f8d",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"# np.save does not create missing folders, so make sure the target exists.\n",
"os.makedirs('data/18_C', exist_ok=True)\n",
"for x in range(1800):\n",
"    current_image_c = X_trainc[:, x, None]\n",
"    c1 = encode_image(current_image_c, W1_cr, b1_cr, W2_cr, b2_cr)\n",
"    np.save('data/18_C/pred_c' + str(x), c1)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.16"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|