semantics/enc_dec.ipynb
2024-09-26 17:23:23 -04:00

203 lines
19 KiB
Text

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from net.modules import *\n",
"from net.decoder import *\n",
"from net.encoder import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# --- Run configuration (edit these values as needed) ---\n",
"dataset = \"bel\"  # prefix of the file name used when weights are saved after training\n",
"file_path = \"data/bel_data_test.csv\"  # input file passed to load_and_prepare_data\n",
"encoder_type = \"pca\"  # either \"regular\" or \"pca\"\n",
"load_weights = False  # True: load saved weights instead of training\n",
"weight_file = \"weights/bel_weights.npz\"  # read only when load_weights is True\n",
"\n",
"# Load the data and unpack the train/test arrays\n",
"X_train, Y_train, X_test, Y_test = load_and_prepare_data(file_path)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Network dimensions\n",
"input_size = X_train.shape[0]  # length of the first axis of X_train\n",
"hidden_size = 128\n",
"output_size = 61  # number of classes\n",
"pca_components = 50\n",
"\n",
"# Training settings passed to best_params\n",
"alpha, iterations, num_trials = 0.01, 2000, 5"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Map each supported encoder_type to its class and display name\n",
"_ENCODER_CHOICES = {\n",
"    \"regular\": (Encoder, \"Regular Encoder\"),\n",
"    \"pca\": (PCAEncoder, \"PCAEncoder\"),\n",
"}\n",
"\n",
"# Reject anything that is not a known key before unpacking the choice\n",
"if encoder_type not in _ENCODER_CHOICES:\n",
"    raise ValueError(\"Invalid encoder type selected\")\n",
"EncoderClass, encoder_name = _ENCODER_CHOICES[encoder_type]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training PCAEncoder\n",
"Trial 1/5\n"
]
},
{
"ename": "ValueError",
"evalue": "X has 50 features, but StandardScaler is expecting 1024 features as input.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[5], line 9\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTraining \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mencoder_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 9\u001b[0m best_weights, best_accuracy, all_accuracies \u001b[38;5;241m=\u001b[39m \u001b[43mbest_params\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43mEncoderClass\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mY_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhidden_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 11\u001b[0m \u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43miterations\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_trials\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpca_components\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpca_components\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mEncoderClass\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mPCAEncoder\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\n\u001b[1;32m 12\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBest accuracy achieved with 
\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mencoder_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mbest_accuracy\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 15\u001b[0m \u001b[38;5;66;03m# Create encoder with best weights\u001b[39;00m\n",
"File \u001b[0;32m~/Projects/School/Sem_Imp/net/modules.py:15\u001b[0m, in \u001b[0;36mbest_params\u001b[0;34m(EncoderClass, X, Y, input_size, hidden_size, output_size, alpha, iterations, num_trials, **kwargs)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTrial \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrial\u001b[38;5;250m \u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_trials\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 14\u001b[0m encoder \u001b[38;5;241m=\u001b[39m EncoderClass(input_size, hidden_size, output_size, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m---> 15\u001b[0m accuracies \u001b[38;5;241m=\u001b[39m \u001b[43mencoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mY\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43miterations\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 16\u001b[0m final_accuracy \u001b[38;5;241m=\u001b[39m accuracies[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m 17\u001b[0m all_accuracies\u001b[38;5;241m.\u001b[39mappend(accuracies)\n",
"File \u001b[0;32m~/Projects/School/Sem_Imp/net/encoder.py:112\u001b[0m, in \u001b[0;36mPCAEncoder.train\u001b[0;34m(self, X, Y, iterations, alpha)\u001b[0m\n\u001b[1;32m 110\u001b[0m X_scaled \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscaler\u001b[38;5;241m.\u001b[39mtransform(X\u001b[38;5;241m.\u001b[39mT)\u001b[38;5;241m.\u001b[39mT\n\u001b[1;32m 111\u001b[0m X_pca \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpca\u001b[38;5;241m.\u001b[39mtransform(X_scaled\u001b[38;5;241m.\u001b[39mT)\u001b[38;5;241m.\u001b[39mT\n\u001b[0;32m--> 112\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_pca\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mY\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43miterations\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Projects/School/Sem_Imp/net/encoder.py:64\u001b[0m, in \u001b[0;36mEncoder.train\u001b[0;34m(self, X, Y, iterations, alpha)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mupdate_params(dW1, db1, dW2, db2, alpha)\n\u001b[1;32m 63\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m%\u001b[39m \u001b[38;5;241m100\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m---> 64\u001b[0m accuracy \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_accuracy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mY\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 65\u001b[0m accuracies\u001b[38;5;241m.\u001b[39mappend(accuracy)\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIteration \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Accuracy: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00maccuracy\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/Projects/School/Sem_Imp/net/encoder.py:115\u001b[0m, in \u001b[0;36mPCAEncoder.get_accuracy\u001b[0;34m(self, X, Y)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_accuracy\u001b[39m(\u001b[38;5;28mself\u001b[39m, X, Y):\n\u001b[0;32m--> 115\u001b[0m X_scaled \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscaler\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransform\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mT\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mT\n\u001b[1;32m 116\u001b[0m X_pca \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpca\u001b[38;5;241m.\u001b[39mtransform(X_scaled\u001b[38;5;241m.\u001b[39mT)\u001b[38;5;241m.\u001b[39mT\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39mget_accuracy(X_pca, Y)\n",
"File \u001b[0;32m~/.pyenv/versions/3.11.6/envs/tf/lib/python3.11/site-packages/sklearn/utils/_set_output.py:313\u001b[0m, in \u001b[0;36m_wrap_method_output.<locals>.wrapped\u001b[0;34m(self, X, *args, **kwargs)\u001b[0m\n\u001b[1;32m 311\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(f)\n\u001b[1;32m 312\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapped\u001b[39m(\u001b[38;5;28mself\u001b[39m, X, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 313\u001b[0m data_to_wrap \u001b[38;5;241m=\u001b[39m \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 314\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data_to_wrap, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 315\u001b[0m \u001b[38;5;66;03m# only wrap the first output for cross decomposition\u001b[39;00m\n\u001b[1;32m 316\u001b[0m return_tuple \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 317\u001b[0m _wrap_data_with_container(method, data_to_wrap[\u001b[38;5;241m0\u001b[39m], X, \u001b[38;5;28mself\u001b[39m),\n\u001b[1;32m 318\u001b[0m \u001b[38;5;241m*\u001b[39mdata_to_wrap[\u001b[38;5;241m1\u001b[39m:],\n\u001b[1;32m 319\u001b[0m )\n",
"File \u001b[0;32m~/.pyenv/versions/3.11.6/envs/tf/lib/python3.11/site-packages/sklearn/preprocessing/_data.py:1045\u001b[0m, in \u001b[0;36mStandardScaler.transform\u001b[0;34m(self, X, copy)\u001b[0m\n\u001b[1;32m 1042\u001b[0m check_is_fitted(\u001b[38;5;28mself\u001b[39m)\n\u001b[1;32m 1044\u001b[0m copy \u001b[38;5;241m=\u001b[39m copy \u001b[38;5;28;01mif\u001b[39;00m copy \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcopy\n\u001b[0;32m-> 1045\u001b[0m X \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_data\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1046\u001b[0m \u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1047\u001b[0m \u001b[43m \u001b[49m\u001b[43mreset\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1048\u001b[0m \u001b[43m \u001b[49m\u001b[43maccept_sparse\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcsr\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1049\u001b[0m \u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1050\u001b[0m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mFLOAT_DTYPES\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1051\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_writeable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1052\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mforce_all_finite\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mallow-nan\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1053\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1055\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m sparse\u001b[38;5;241m.\u001b[39missparse(X):\n\u001b[1;32m 1056\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwith_mean:\n",
"File \u001b[0;32m~/.pyenv/versions/3.11.6/envs/tf/lib/python3.11/site-packages/sklearn/base.py:654\u001b[0m, in \u001b[0;36mBaseEstimator._validate_data\u001b[0;34m(self, X, y, reset, validate_separately, cast_to_ndarray, **check_params)\u001b[0m\n\u001b[1;32m 651\u001b[0m out \u001b[38;5;241m=\u001b[39m X, y\n\u001b[1;32m 653\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m no_val_X \u001b[38;5;129;01mand\u001b[39;00m check_params\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mensure_2d\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mTrue\u001b[39;00m):\n\u001b[0;32m--> 654\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_check_n_features\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreset\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 656\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n",
"File \u001b[0;32m~/.pyenv/versions/3.11.6/envs/tf/lib/python3.11/site-packages/sklearn/base.py:443\u001b[0m, in \u001b[0;36mBaseEstimator._check_n_features\u001b[0;34m(self, X, reset)\u001b[0m\n\u001b[1;32m 440\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 442\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m n_features \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_features_in_:\n\u001b[0;32m--> 443\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 444\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX has \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mn_features\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m features, but \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 445\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mis expecting \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_features_in_\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m features as input.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 446\u001b[0m )\n",
"\u001b[0;31mValueError\u001b[0m: X has 50 features, but StandardScaler is expecting 1024 features as input."
]
}
],
"source": [
"# Load pre-trained weights or train a new model.\n",
"# Only PCAEncoder is handed `pca_components`; the regular Encoder is built from the\n",
"# three layer sizes alone, so switching encoder_type does not pass it a PCA-only keyword.\n",
"encoder_kwargs = {\"pca_components\": pca_components} if EncoderClass == PCAEncoder else {}\n",
"\n",
"if load_weights:\n",
"    encoder = EncoderClass(input_size, hidden_size, output_size, **encoder_kwargs)\n",
"    encoder.load_weights(weight_file)\n",
"    print(f\"Weights loaded for {encoder_name}\")\n",
"else:\n",
"    print(f\"Training {encoder_name}\")\n",
"    # best_params trains num_trials encoders and returns the best weights,\n",
"    # the best accuracy and the per-trial accuracy lists.\n",
"    best_weights, best_accuracy, all_accuracies = best_params(\n",
"        EncoderClass, X_train, Y_train, input_size, hidden_size, output_size,\n",
"        alpha, iterations, num_trials, **encoder_kwargs\n",
"    )\n",
"    print(f\"Best accuracy achieved with {encoder_name}: {best_accuracy}\")\n",
"\n",
"    # Create encoder with best weights.\n",
"    # NOTE(review): this fresh instance only receives W1/b1/W2/b2. If PCAEncoder fits its\n",
"    # scaler/PCA inside train(), that fitted state is not carried over - confirm before encoding.\n",
"    encoder = EncoderClass(input_size, hidden_size, output_size, **encoder_kwargs)\n",
"    encoder.W1, encoder.b1, encoder.W2, encoder.b2 = best_weights\n",
"\n",
"    # Save the best weights; weight_file is rebound to the newly written file\n",
"    weight_file = f\"{dataset}_{encoder_name.lower().replace(' ', '_')}_weights.npz\"\n",
"    encoder.save_weights(weight_file)\n",
"    print(f\"Best weights saved to {weight_file}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Build the decoder that pairs with the encoder class in use\n",
"if EncoderClass == PCAEncoder:\n",
"    decoder = PCADecoder(encoder)\n",
"else:\n",
"    decoder = Decoder(encoder)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test the encoding-decoding process on the first few test samples.\n",
"# NOTE(review): `plt` is not imported in this notebook's import cell; presumably it comes\n",
"# in through one of the `from net.* import *` lines - confirm.\n",
"num_test_images = 5\n",
"num_shown = min(num_test_images, X_test.shape[1])  # never index past the last test column\n",
"img_dim = int(np.sqrt(input_size))  # side length used to reshape inputs into square images\n",
"\n",
"for i in range(num_shown):\n",
"    original_image = X_test[:, i].reshape(input_size, 1)\n",
"    encoded = encoder.encode(original_image)\n",
"    reconstructed = decoder.decode(encoded)\n",
"\n",
"    # Visualize the results: original, encoded vector, reconstruction\n",
"    fig = plt.figure(figsize=(15, 5))\n",
"\n",
"    plt.subplot(1, 3, 1)\n",
"    plt.imshow(original_image.reshape(img_dim, img_dim), cmap='gray')\n",
"    plt.title(f\"Original Image {i+1}\")\n",
"\n",
"    plt.subplot(1, 3, 2)\n",
"    plt.imshow(encoded.reshape(output_size, 1), cmap='viridis', aspect='auto')\n",
"    plt.title(f\"Encoded Image {i+1}\")\n",
"\n",
"    plt.subplot(1, 3, 3)\n",
"    plt.imshow(reconstructed.reshape(img_dim, img_dim), cmap='gray')\n",
"    plt.title(f\"Reconstructed Image {i+1}\")\n",
"\n",
"    plt.tight_layout()\n",
"    plt.show()\n",
"    plt.close(fig)  # release the figure so repeated iterations do not accumulate open figures\n",
"\n",
"    # Calculate and print the mean squared error.\n",
"    # Reshape first so the subtraction is element-wise and cannot broadcast into a matrix.\n",
"    mse = np.mean((original_image - reconstructed.reshape(original_image.shape)) ** 2)\n",
"    print(f\"Mean Squared Error for Image {i+1}: {mse}\")\n",
"\n",
"print(\"Encoding-decoding process completed.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tf",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}