185 lines
4.2 KiB
Text
185 lines
4.2 KiB
Text
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import net.modules\n",
|
|
"\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"from net.transcoder import Transcoder"
|
|
]
|
|
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Load the dataset and build the train/test loaders.\n",
  "# NOTE(review): load_and_prepare_data is not imported anywhere in this notebook;\n",
  "# add its import to the imports cell (its defining module is not visible here).\n",
  "filepath = 'data/bel_data_test.csv'\n",
  "train_loader, test_loader, input_size = load_and_prepare_data(file_path=filepath)\n",
  "\n",
  "# input_size is used below as the width of the input layer, so show its value;\n",
  "# it is not X_train, and a plain int has no `.shape`.\n",
  "print(\"Input size:\", input_size)"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Layer layout: [input width, [hidden layer widths], output width].\n",
  "HIDDEN_SIZES = [128]\n",
  "OUTPUT_SIZE = 61  # NOTE(review): presumably the number of target classes -- confirm against the data\n",
  "\n",
  "architecture = [input_size, HIDDEN_SIZES, OUTPUT_SIZE]\n",
  "\n",
  "# Activation names, in order: hidden layers, output layer.\n",
  "activations = ['leaky_relu', 'softmax']"
 ]
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Initialize transcoder"
|
|
]
|
|
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Take the activations from `activations` instead of hard-coding them, so the hidden\n",
  "# layers use leaky_relu -- matching the 'bt_1h128n_leaky_relu_weights.pth' file loaded below.\n",
  "hidden_activation, output_activation = activations\n",
  "bl_transcoder = Transcoder(\n",
  "    architecture,\n",
  "    hidden_activation=hidden_activation,\n",
  "    output_activation=output_activation,\n",
  ")"
 ]
},
|
|
{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "## Train the transcoder and save its weights"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Training runs for 1000 epochs. Set TRAIN_MODEL = False on a re-run to skip it and\n",
  "# rely on the weights loaded in the \"Load weights\" section instead.\n",
  "TRAIN_MODEL = True\n",
  "\n",
  "# Passed to save_results; the load cell reads 'weights/bt_1h128n_leaky_relu_weights.pth',\n",
  "# so keep the two names in sync.\n",
  "RUN_NAME = 'bt_1h128n'\n",
  "\n",
  "if TRAIN_MODEL:\n",
  "    bl_transcoder.train_model(train_loader, test_loader, learning_rate=0.001, epochs=1000)\n",
  "    bl_transcoder.save_results(RUN_NAME)"
 ]
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Load weights"
|
|
]
|
|
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Load the saved weights file into the transcoder.\n",
  "weights_path = 'weights/bt_1h128n_leaky_relu_weights.pth'\n",
  "bl_transcoder.load_weights(weights_path)"
 ]
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Analysis"
|
|
]
|
|
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# NOTE(review): X_test and Y_test are never defined in this notebook (the data cell\n",
  "# only returns train_loader, test_loader and input_size), so this cell depends on\n",
  "# state left over from an earlier session. Define them explicitly before this point.\n",
  "\n",
  "# Plot learning curves\n",
  "bl_transcoder.plot_learning_curves()\n",
  "\n",
  "# Visualize encoded space\n",
  "bl_transcoder.plot_encoded_space(X_test, Y_test)\n",
  "\n",
  "# Check reconstructions\n",
  "bl_transcoder.plot_reconstructions(X_test)"
 ]
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Transcode images"
|
|
]
|
|
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# NOTE(review): X_test and visualize_transcoding are not defined or imported in this\n",
  "# notebook; they must be provided before this cell can run on a fresh kernel.\n",
  "SEED = 0\n",
  "num_images = 2\n",
  "\n",
  "# Samples lie along axis 1 of X_test (X_test[:, idx] is one image).\n",
  "n_samples = X_test.shape[1]\n",
  "rng = np.random.default_rng(SEED)  # seeded so the same images are picked on every run\n",
  "# Cap the draw: sampling without replacement fails if num_images > n_samples.\n",
  "indices = rng.choice(n_samples, size=min(num_images, n_samples), replace=False)\n",
  "\n",
  "for idx in indices:\n",
  "    original_image = X_test[:, idx]\n",
  "\n",
  "    # Encode the image as a single column vector, then decode it back\n",
  "    encoded = bl_transcoder.encode_image(original_image.reshape(-1, 1))\n",
  "    decoded = bl_transcoder.decode_image(encoded)\n",
  "\n",
  "    # Visualize original, encoded, and decoded images\n",
  "    visualize_transcoding(original_image, encoded, decoded, idx)\n",
  "\n",
  "    print(f\"Image {idx}:\")\n",
  "    print(\"Original shape:\", original_image.shape)\n",
  "    print(\"Encoded shape:\", encoded.shape)\n",
  "    print(\"Decoded shape:\", decoded.shape)\n",
  "    print(\"Encoded vector:\", encoded.flatten())  # flattened for a one-line display\n",
  "    print()  # blank line between images"
 ]
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "semantics",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.1"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|