{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from net.modules import *\n",
"from net.decoder import *\n",
"from net.encoder import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Set up parameters (modify these as needed)\n",
"file_path = \"data/bel_data_test.csv\" # Replace with your actual file path\n",
"encoder_type = \"pca\" # Choose \"regular\" or \"pca\"\n",
"load_weights = False # Set to True if you want to load pre-trained weights\n",
"weight_file = \"weights/bel_weights.npz\" # Only used if load_weights is True\n",
"dataset = \"bel\"\n",
"\n",
"# Load and prepare the data\n",
"X_train, Y_train, X_test, Y_test = load_and_prepare_data(file_path)"
]
},
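{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick sanity check (sketch): confirm the loader returned feature-major arrays,\n",
"# i.e. shape (n_features, n_samples), since the cells below use X_train.shape[0]\n",
"# as the input size. This assumes load_and_prepare_data follows that convention.\n",
"print(f\"X_train: {X_train.shape}, Y_train: {Y_train.shape}\")\n",
"print(f\"X_test:  {X_test.shape}, Y_test:  {Y_test.shape}\")"
]
},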
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Set hyperparameters\n",
"input_size = X_train.shape[0]\n",
"hidden_size = 128\n",
"output_size = 61 # Number of classes (also the size of the encoded representation)\n",
"pca_components = 50\n",
"\n",
"alpha = 0.01 # Learning rate\n",
"iterations = 2000 # Training iterations per trial\n",
"num_trials = 5 # Random restarts evaluated by best_params"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Choose encoder type\n",
"if encoder_type == \"regular\":\n",
"    EncoderClass = Encoder\n",
"    encoder_name = \"Regular Encoder\"\n",
"elif encoder_type == \"pca\":\n",
"    EncoderClass = PCAEncoder\n",
"    encoder_name = \"PCAEncoder\"\n",
"else:\n",
"    raise ValueError(f\"Invalid encoder type: {encoder_type!r} (expected 'regular' or 'pca')\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training PCAEncoder\n",
"Trial 1/5\n"
]
},
{
"ename": "ValueError",
"evalue": "X has 50 features, but StandardScaler is expecting 1024 features as input.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[5], line 9\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTraining \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mencoder_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 9\u001b[0m best_weights, best_accuracy, all_accuracies \u001b[38;5;241m=\u001b[39m \u001b[43mbest_params\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43mEncoderClass\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mY_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhidden_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 11\u001b[0m \u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43miterations\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_trials\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpca_components\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpca_components\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mEncoderClass\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mPCAEncoder\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\n\u001b[1;32m 12\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBest accuracy achieved with \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mencoder_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mbest_accuracy\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 15\u001b[0m \u001b[38;5;66;03m# Create encoder with best weights\u001b[39;00m\n",
"File \u001b[0;32m~/Projects/School/Sem_Imp/net/modules.py:15\u001b[0m, in \u001b[0;36mbest_params\u001b[0;34m(EncoderClass, X, Y, input_size, hidden_size, output_size, alpha, iterations, num_trials, **kwargs)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTrial \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrial\u001b[38;5;250m \u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_trials\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 14\u001b[0m encoder \u001b[38;5;241m=\u001b[39m EncoderClass(input_size, hidden_size, output_size, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m---> 15\u001b[0m accuracies \u001b[38;5;241m=\u001b[39m \u001b[43mencoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mY\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43miterations\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 16\u001b[0m final_accuracy \u001b[38;5;241m=\u001b[39m accuracies[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m 17\u001b[0m all_accuracies\u001b[38;5;241m.\u001b[39mappend(accuracies)\n",
"File \u001b[0;32m~/Projects/School/Sem_Imp/net/encoder.py:112\u001b[0m, in \u001b[0;36mPCAEncoder.train\u001b[0;34m(self, X, Y, iterations, alpha)\u001b[0m\n\u001b[1;32m 110\u001b[0m X_scaled \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscaler\u001b[38;5;241m.\u001b[39mtransform(X\u001b[38;5;241m.\u001b[39mT)\u001b[38;5;241m.\u001b[39mT\n\u001b[1;32m 111\u001b[0m X_pca \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpca\u001b[38;5;241m.\u001b[39mtransform(X_scaled\u001b[38;5;241m.\u001b[39mT)\u001b[38;5;241m.\u001b[39mT\n\u001b[0;32m--> 112\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_pca\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mY\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43miterations\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Projects/School/Sem_Imp/net/encoder.py:64\u001b[0m, in \u001b[0;36mEncoder.train\u001b[0;34m(self, X, Y, iterations, alpha)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mupdate_params(dW1, db1, dW2, db2, alpha)\n\u001b[1;32m 63\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m%\u001b[39m \u001b[38;5;241m100\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m---> 64\u001b[0m accuracy \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_accuracy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mY\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 65\u001b[0m accuracies\u001b[38;5;241m.\u001b[39mappend(accuracy)\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIteration \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Accuracy: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00maccuracy\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/Projects/School/Sem_Imp/net/encoder.py:115\u001b[0m, in \u001b[0;36mPCAEncoder.get_accuracy\u001b[0;34m(self, X, Y)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_accuracy\u001b[39m(\u001b[38;5;28mself\u001b[39m, X, Y):\n\u001b[0;32m--> 115\u001b[0m X_scaled \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscaler\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransform\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mT\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mT\n\u001b[1;32m 116\u001b[0m X_pca \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpca\u001b[38;5;241m.\u001b[39mtransform(X_scaled\u001b[38;5;241m.\u001b[39mT)\u001b[38;5;241m.\u001b[39mT\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39mget_accuracy(X_pca, Y)\n",
"File \u001b[0;32m~/.pyenv/versions/3.11.6/envs/tf/lib/python3.11/site-packages/sklearn/utils/_set_output.py:313\u001b[0m, in \u001b[0;36m_wrap_method_output.<locals>.wrapped\u001b[0;34m(self, X, *args, **kwargs)\u001b[0m\n\u001b[1;32m 311\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(f)\n\u001b[1;32m 312\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapped\u001b[39m(\u001b[38;5;28mself\u001b[39m, X, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 313\u001b[0m data_to_wrap \u001b[38;5;241m=\u001b[39m \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 314\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data_to_wrap, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 315\u001b[0m \u001b[38;5;66;03m# only wrap the first output for cross decomposition\u001b[39;00m\n\u001b[1;32m 316\u001b[0m return_tuple \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 317\u001b[0m _wrap_data_with_container(method, data_to_wrap[\u001b[38;5;241m0\u001b[39m], X, \u001b[38;5;28mself\u001b[39m),\n\u001b[1;32m 318\u001b[0m \u001b[38;5;241m*\u001b[39mdata_to_wrap[\u001b[38;5;241m1\u001b[39m:],\n\u001b[1;32m 319\u001b[0m )\n",
"File \u001b[0;32m~/.pyenv/versions/3.11.6/envs/tf/lib/python3.11/site-packages/sklearn/preprocessing/_data.py:1045\u001b[0m, in \u001b[0;36mStandardScaler.transform\u001b[0;34m(self, X, copy)\u001b[0m\n\u001b[1;32m 1042\u001b[0m check_is_fitted(\u001b[38;5;28mself\u001b[39m)\n\u001b[1;32m 1044\u001b[0m copy \u001b[38;5;241m=\u001b[39m copy \u001b[38;5;28;01mif\u001b[39;00m copy \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcopy\n\u001b[0;32m-> 1045\u001b[0m X \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_data\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1046\u001b[0m \u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1047\u001b[0m \u001b[43m \u001b[49m\u001b[43mreset\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1048\u001b[0m \u001b[43m \u001b[49m\u001b[43maccept_sparse\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcsr\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1049\u001b[0m \u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1050\u001b[0m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mFLOAT_DTYPES\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1051\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_writeable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1052\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_all_finite\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mallow-nan\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1053\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1055\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m sparse\u001b[38;5;241m.\u001b[39missparse(X):\n\u001b[1;32m 1056\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwith_mean:\n",
"File \u001b[0;32m~/.pyenv/versions/3.11.6/envs/tf/lib/python3.11/site-packages/sklearn/base.py:654\u001b[0m, in \u001b[0;36mBaseEstimator._validate_data\u001b[0;34m(self, X, y, reset, validate_separately, cast_to_ndarray, **check_params)\u001b[0m\n\u001b[1;32m 651\u001b[0m out \u001b[38;5;241m=\u001b[39m X, y\n\u001b[1;32m 653\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m no_val_X \u001b[38;5;129;01mand\u001b[39;00m check_params\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mensure_2d\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mTrue\u001b[39;00m):\n\u001b[0;32m--> 654\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_check_n_features\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreset\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 656\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n",
"File \u001b[0;32m~/.pyenv/versions/3.11.6/envs/tf/lib/python3.11/site-packages/sklearn/base.py:443\u001b[0m, in \u001b[0;36mBaseEstimator._check_n_features\u001b[0;34m(self, X, reset)\u001b[0m\n\u001b[1;32m 440\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 442\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m n_features \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_features_in_:\n\u001b[0;32m--> 443\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 444\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX has \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mn_features\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m features, but \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 445\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mis expecting \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_features_in_\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m features as input.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 446\u001b[0m )\n",
"\u001b[0;31mValueError\u001b[0m: X has 50 features, but StandardScaler is expecting 1024 features as input."
]
}
],
"source": [
"# Load pre-trained weights or train a new model\n",
"# Pass pca_components only to the PCA encoder; the regular encoder does not take it\n",
"encoder_kwargs = {\"pca_components\": pca_components} if EncoderClass == PCAEncoder else {}\n",
"\n",
"if load_weights:\n",
"    encoder = EncoderClass(input_size, hidden_size, output_size, **encoder_kwargs)\n",
"    encoder.load_weights(weight_file)\n",
"    print(f\"Weights loaded for {encoder_name}\")\n",
"else:\n",
"    print(f\"Training {encoder_name}\")\n",
"    best_weights, best_accuracy, all_accuracies = best_params(\n",
"        EncoderClass, X_train, Y_train, input_size, hidden_size, output_size,\n",
"        alpha, iterations, num_trials, **encoder_kwargs\n",
"    )\n",
"    print(f\"Best accuracy achieved with {encoder_name}: {best_accuracy}\")\n",
"\n",
"    # Rebuild an encoder with the best weights\n",
"    encoder = EncoderClass(input_size, hidden_size, output_size, **encoder_kwargs)\n",
"    encoder.W1, encoder.b1, encoder.W2, encoder.b2 = best_weights\n",
"\n",
"    # Save the best weights\n",
"    weight_file = f\"{dataset}_{encoder_name.lower().replace(' ', '_')}_weights.npz\"\n",
"    encoder.save_weights(weight_file)\n",
"    print(f\"Best weights saved to {weight_file}\")\n"
]
},
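{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Why the ValueError above occurs: PCAEncoder.train projects X with the fitted\n",
"# scaler + PCA and then hands the projected data to the base Encoder.train, which\n",
"# periodically calls self.get_accuracy. PCAEncoder.get_accuracy projects again, so\n",
"# the StandardScaler fitted on 1024 raw features receives the 50 PCA components.\n",
"#\n",
"# The sketch below is only an illustration of one possible guard (the helper name\n",
"# safe_get_accuracy is hypothetical, not part of net.encoder): re-project only when\n",
"# the input still has the raw feature count the scaler was fitted on.\n",
"\n",
"def safe_get_accuracy(pca_encoder, X, Y):\n",
"    if X.shape[0] == pca_encoder.scaler.n_features_in_:\n",
"        X_scaled = pca_encoder.scaler.transform(X.T).T\n",
"        X = pca_encoder.pca.transform(X_scaled.T).T\n",
"    return Encoder.get_accuracy(pca_encoder, X, Y)"
]
},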
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize the appropriate Decoder\n",
"decoder = PCADecoder(encoder) if EncoderClass == PCAEncoder else Decoder(encoder)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test the encoding-decoding process\n",
"num_test_images = 5\n",
"for i in range(num_test_images):\n",
"    original_image = X_test[:, i].reshape(input_size, 1)\n",
"    encoded = encoder.encode(original_image)\n",
"    reconstructed = decoder.decode(encoded)\n",
"\n",
"    # Visualize the results\n",
"    img_dim = int(np.sqrt(input_size))\n",
"    plt.figure(figsize=(15, 5))\n",
"\n",
"    plt.subplot(1, 3, 1)\n",
"    plt.imshow(original_image.reshape(img_dim, img_dim), cmap='gray')\n",
"    plt.title(f\"Original Image {i+1}\")\n",
"\n",
"    plt.subplot(1, 3, 2)\n",
"    plt.imshow(encoded.reshape(output_size, 1), cmap='viridis', aspect='auto')\n",
"    plt.title(f\"Encoded Image {i+1}\")\n",
"\n",
"    plt.subplot(1, 3, 3)\n",
"    plt.imshow(reconstructed.reshape(img_dim, img_dim), cmap='gray')\n",
"    plt.title(f\"Reconstructed Image {i+1}\")\n",
"\n",
"    plt.tight_layout()\n",
"    plt.show()\n",
"\n",
"    # Calculate and print the mean squared error\n",
"    mse = np.mean((original_image - reconstructed) ** 2)\n",
"    print(f\"Mean Squared Error for Image {i+1}: {mse}\")\n",
"\n",
"print(\"Encoding-decoding process completed.\")"
]
},
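{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Aggregate check (sketch): average reconstruction MSE over a small batch of test\n",
"# columns. This assumes encoder.encode and decoder.decode accept single\n",
"# (input_size, 1) columns, as demonstrated in the loop above.\n",
"num_eval = min(20, X_test.shape[1])\n",
"mses = []\n",
"for j in range(num_eval):\n",
"    col = X_test[:, j].reshape(input_size, 1)\n",
"    rec = decoder.decode(encoder.encode(col))\n",
"    mses.append(float(np.mean((col - rec) ** 2)))\n",
"print(f\"Mean reconstruction MSE over {num_eval} test images: {np.mean(mses):.6f}\")"
]
},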
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tf",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}