{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "from net.modules import *\n", "from net.decoder import *\n", "from net.encoder import *" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Set up parameters (modify these as needed)\n", "file_path = \"data/bel_data_test.csv\" # Replace with your actual file path\n", "encoder_type = \"pca\" # Choose \"regular\" or \"pca\"\n", "load_weights = False # Set to True if you want to load pre-trained weights\n", "weight_file = \"weights/bel_weights.npz\" # Only used if load_weights is True\n", "dataset=\"bel\"\n", "\n", "# Load and prepare the data\n", "X_train, Y_train, X_test, Y_test = load_and_prepare_data(file_path)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Set hyperparameters\n", "input_size = X_train.shape[0]\n", "hidden_size = 128\n", "output_size = 61 # Number of classes\n", "pca_components = 50\n", "\n", "alpha = 0.01\n", "iterations = 2000\n", "num_trials = 5" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Choose encoder type\n", "if encoder_type == \"regular\":\n", " EncoderClass = Encoder\n", " encoder_name = \"Regular Encoder\"\n", "elif encoder_type == \"pca\":\n", " EncoderClass = PCAEncoder\n", " encoder_name = \"PCAEncoder\"\n", "else:\n", " raise ValueError(\"Invalid encoder type selected\")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training PCAEncoder\n", "Trial 1/5\n" ] }, { "ename": "ValueError", "evalue": "X has 50 features, but StandardScaler is expecting 1024 features as input.", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[5], line 9\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTraining \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mencoder_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 9\u001b[0m best_weights, best_accuracy, all_accuracies \u001b[38;5;241m=\u001b[39m \u001b[43mbest_params\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43mEncoderClass\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mY_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhidden_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 11\u001b[0m \u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43miterations\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_trials\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpca_components\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpca_components\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mEncoderClass\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mPCAEncoder\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\n\u001b[1;32m 12\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBest accuracy achieved with \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mencoder_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mbest_accuracy\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 15\u001b[0m \u001b[38;5;66;03m# Create encoder with best weights\u001b[39;00m\n", "File \u001b[0;32m~/Projects/School/Sem_Imp/net/modules.py:15\u001b[0m, in \u001b[0;36mbest_params\u001b[0;34m(EncoderClass, X, Y, input_size, hidden_size, output_size, alpha, iterations, num_trials, **kwargs)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTrial \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrial\u001b[38;5;250m \u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_trials\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 14\u001b[0m encoder \u001b[38;5;241m=\u001b[39m EncoderClass(input_size, hidden_size, output_size, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m---> 15\u001b[0m accuracies \u001b[38;5;241m=\u001b[39m \u001b[43mencoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mY\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43miterations\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 16\u001b[0m final_accuracy \u001b[38;5;241m=\u001b[39m accuracies[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m 17\u001b[0m all_accuracies\u001b[38;5;241m.\u001b[39mappend(accuracies)\n", "File \u001b[0;32m~/Projects/School/Sem_Imp/net/encoder.py:112\u001b[0m, in \u001b[0;36mPCAEncoder.train\u001b[0;34m(self, X, Y, iterations, alpha)\u001b[0m\n\u001b[1;32m 110\u001b[0m X_scaled \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscaler\u001b[38;5;241m.\u001b[39mtransform(X\u001b[38;5;241m.\u001b[39mT)\u001b[38;5;241m.\u001b[39mT\n\u001b[1;32m 111\u001b[0m X_pca \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpca\u001b[38;5;241m.\u001b[39mtransform(X_scaled\u001b[38;5;241m.\u001b[39mT)\u001b[38;5;241m.\u001b[39mT\n\u001b[0;32m--> 112\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_pca\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mY\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43miterations\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/Projects/School/Sem_Imp/net/encoder.py:64\u001b[0m, in \u001b[0;36mEncoder.train\u001b[0;34m(self, X, Y, iterations, alpha)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mupdate_params(dW1, db1, dW2, db2, alpha)\n\u001b[1;32m 63\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m%\u001b[39m \u001b[38;5;241m100\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m---> 64\u001b[0m accuracy \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_accuracy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mY\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 65\u001b[0m accuracies\u001b[38;5;241m.\u001b[39mappend(accuracy)\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIteration \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Accuracy: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00maccuracy\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", "File \u001b[0;32m~/Projects/School/Sem_Imp/net/encoder.py:115\u001b[0m, in \u001b[0;36mPCAEncoder.get_accuracy\u001b[0;34m(self, X, Y)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_accuracy\u001b[39m(\u001b[38;5;28mself\u001b[39m, X, Y):\n\u001b[0;32m--> 115\u001b[0m X_scaled \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscaler\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransform\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mT\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mT\n\u001b[1;32m 116\u001b[0m X_pca \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpca\u001b[38;5;241m.\u001b[39mtransform(X_scaled\u001b[38;5;241m.\u001b[39mT)\u001b[38;5;241m.\u001b[39mT\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39mget_accuracy(X_pca, Y)\n", "File \u001b[0;32m~/.pyenv/versions/3.11.6/envs/tf/lib/python3.11/site-packages/sklearn/utils/_set_output.py:313\u001b[0m, in \u001b[0;36m_wrap_method_output..wrapped\u001b[0;34m(self, X, *args, **kwargs)\u001b[0m\n\u001b[1;32m 311\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(f)\n\u001b[1;32m 312\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapped\u001b[39m(\u001b[38;5;28mself\u001b[39m, X, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 313\u001b[0m data_to_wrap \u001b[38;5;241m=\u001b[39m \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 314\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data_to_wrap, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 315\u001b[0m \u001b[38;5;66;03m# only wrap the first output for cross decomposition\u001b[39;00m\n\u001b[1;32m 316\u001b[0m return_tuple \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 317\u001b[0m _wrap_data_with_container(method, data_to_wrap[\u001b[38;5;241m0\u001b[39m], X, \u001b[38;5;28mself\u001b[39m),\n\u001b[1;32m 318\u001b[0m \u001b[38;5;241m*\u001b[39mdata_to_wrap[\u001b[38;5;241m1\u001b[39m:],\n\u001b[1;32m 319\u001b[0m )\n", "File \u001b[0;32m~/.pyenv/versions/3.11.6/envs/tf/lib/python3.11/site-packages/sklearn/preprocessing/_data.py:1045\u001b[0m, in \u001b[0;36mStandardScaler.transform\u001b[0;34m(self, X, copy)\u001b[0m\n\u001b[1;32m 1042\u001b[0m check_is_fitted(\u001b[38;5;28mself\u001b[39m)\n\u001b[1;32m 1044\u001b[0m copy \u001b[38;5;241m=\u001b[39m copy \u001b[38;5;28;01mif\u001b[39;00m copy \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcopy\n\u001b[0;32m-> 1045\u001b[0m X \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_data\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1046\u001b[0m \u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1047\u001b[0m \u001b[43m \u001b[49m\u001b[43mreset\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1048\u001b[0m \u001b[43m \u001b[49m\u001b[43maccept_sparse\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcsr\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1049\u001b[0m \u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1050\u001b[0m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mFLOAT_DTYPES\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1051\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_writeable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1052\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_all_finite\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mallow-nan\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1053\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1055\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m sparse\u001b[38;5;241m.\u001b[39missparse(X):\n\u001b[1;32m 1056\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwith_mean:\n", "File \u001b[0;32m~/.pyenv/versions/3.11.6/envs/tf/lib/python3.11/site-packages/sklearn/base.py:654\u001b[0m, in \u001b[0;36mBaseEstimator._validate_data\u001b[0;34m(self, X, y, reset, validate_separately, cast_to_ndarray, **check_params)\u001b[0m\n\u001b[1;32m 651\u001b[0m out \u001b[38;5;241m=\u001b[39m X, y\n\u001b[1;32m 653\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m no_val_X \u001b[38;5;129;01mand\u001b[39;00m check_params\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mensure_2d\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mTrue\u001b[39;00m):\n\u001b[0;32m--> 654\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_check_n_features\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreset\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreset\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 656\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n", "File \u001b[0;32m~/.pyenv/versions/3.11.6/envs/tf/lib/python3.11/site-packages/sklearn/base.py:443\u001b[0m, in \u001b[0;36mBaseEstimator._check_n_features\u001b[0;34m(self, X, reset)\u001b[0m\n\u001b[1;32m 440\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 442\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m n_features \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_features_in_:\n\u001b[0;32m--> 443\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 444\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX has \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mn_features\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m features, but \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 445\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mis expecting \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_features_in_\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m features as input.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 446\u001b[0m )\n", "\u001b[0;31mValueError\u001b[0m: X has 50 features, but StandardScaler is expecting 1024 features as input." ] } ], "source": [ "# Load pre-trained weights or train a new model\n", "if load_weights:\n", " encoder = EncoderClass(input_size, hidden_size, output_size, pca_components=pca_components)\n", " encoder.load_weights(weight_file)\n", " print(f\"Weights loaded for {encoder_name}\")\n", "else:\n", " print(f\"Training {encoder_name}\")\n", " best_weights, best_accuracy, all_accuracies = best_params(\n", " EncoderClass, X_train, Y_train, input_size, hidden_size, output_size, \n", " alpha, iterations, num_trials, pca_components=pca_components\n", " )\n", " print(f\"Best accuracy achieved with {encoder_name}: {best_accuracy}\")\n", "\n", " # Create encoder with best weights\n", " encoder = EncoderClass(input_size, hidden_size, output_size, pca_components=pca_components)\n", " encoder.W1, encoder.b1, encoder.W2, encoder.b2 = best_weights\n", "\n", " # Save the best weights\n", " weight_file = f\"{dataset}_{encoder_name.lower().replace(' ', '_')}_weights.npz\"\n", " encoder.save_weights(weight_file)\n", " print(f\"Best weights saved to {weight_file}\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Initialize the appropriate Decoder\n", "decoder = PCADecoder(encoder) if EncoderClass == PCAEncoder else Decoder(encoder)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Test the encoding-decoding process\n", "num_test_images = 5\n", "for i in range(num_test_images):\n", " original_image = X_test[:, i].reshape(input_size, 1)\n", " encoded = encoder.encode(original_image)\n", " reconstructed = decoder.decode(encoded)\n", "\n", " # Visualize the results\n", " img_dim = int(np.sqrt(input_size))\n", " plt.figure(figsize=(15, 5))\n", " \n", " plt.subplot(1, 3, 1)\n", " plt.imshow(original_image.reshape(img_dim, img_dim), cmap='gray')\n", " plt.title(f\"Original Image {i+1}\")\n", " \n", " plt.subplot(1, 3, 2)\n", " plt.imshow(encoded.reshape(output_size, 1), cmap='viridis', aspect='auto')\n", " plt.title(f\"Encoded Image {i+1}\")\n", " \n", " plt.subplot(1, 3, 3)\n", " plt.imshow(reconstructed.reshape(img_dim, img_dim), cmap='gray')\n", " plt.title(f\"Reconstructed Image {i+1}\")\n", " \n", " plt.tight_layout()\n", " plt.show()\n", "\n", " # Calculate and print the mean squared error\n", " mse = np.mean((original_image - reconstructed) ** 2)\n", " print(f\"Mean Squared Error for Image {i+1}: {mse}\")\n", "\n", "print(\"Encoding-decoding process completed.\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "tf", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 2 }