Initial commit
This commit is contained in:
commit
1995df58ce
21 changed files with 6708 additions and 0 deletions
185
bel_semantics.ipynb
Normal file
185
bel_semantics.ipynb
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import net.modules\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"from net.transcoder import Transcoder"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"filepath = 'data/bel_data_test.csv'\n",
|
||||
"train_loader, test_loader, input_size = load_and_prepare_data(file_path=filepath)\n",
|
||||
"\n",
|
||||
"print(\"X_train shape:\", input_size.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# input_size = X_train.shape[0]\n",
|
||||
"# hidden_size = 128\n",
|
||||
"# output_size = 61\n",
|
||||
"\n",
|
||||
"architecture = [input_size, [128], 61]\n",
|
||||
"activations = ['leaky_relu','softmax']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Initialize transcoder"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# bl_transcoder = Transcoder(input_size, hidden_size, output_size, 'leaky_relu', 'softmax')\n",
|
||||
"bl_transcoder = Transcoder(architecture, hidden_activation='relu', output_activation='softmax')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Train Encoders and save weights\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# # Train the encoder if need\n",
|
||||
"\n",
|
||||
"bl_transcoder.train_model(train_loader, test_loader, learning_rate=0.001, epochs=1000)\n",
|
||||
"# bl_transcoder.train_with_validation(X_train, Y_train, alpha=0.1, iterations=1000)\n",
|
||||
"bl_transcoder.save_results('bt_1h128n')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Load weights"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"bl_transcoder.load_weights('weights/bt_1h128n_leaky_relu_weights.pth')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Analysis"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Plot learning curves\n",
|
||||
"bl_transcoder.plot_learning_curves()\n",
|
||||
"\n",
|
||||
"# Visualize encoded space\n",
|
||||
"bl_transcoder.plot_encoded_space(X_test, Y_test)\n",
|
||||
"\n",
|
||||
"print(X_test.shape)\n",
|
||||
"print(X_train.shape)\n",
|
||||
"# Check reconstructions\n",
|
||||
"bl_transcoder.plot_reconstructions(X_test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Transcode images"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"num_images = 2\n",
|
||||
"indices = np.random.choice(X_test.shape[1], num_images, replace=False)\n",
|
||||
"\n",
|
||||
"for idx in indices:\n",
|
||||
" original_image = X_test[:, idx]\n",
|
||||
" \n",
|
||||
" # Encode the image\n",
|
||||
" encoded = bl_transcoder.encode_image(original_image.reshape(-1, 1))\n",
|
||||
" \n",
|
||||
" # Decode the image\n",
|
||||
" decoded = bl_transcoder.decode_image(encoded)\n",
|
||||
"\n",
|
||||
" # Visualize original, encoded, and decoded images\n",
|
||||
" visualize_transcoding(original_image, encoded, decoded, idx)\n",
|
||||
"\n",
|
||||
" print(f\"Image {idx}:\")\n",
|
||||
" print(\"Original shape:\", original_image.shape)\n",
|
||||
" print(\"Encoded shape:\", encoded.shape)\n",
|
||||
" print(\"Decoded shape:\", decoded.shape)\n",
|
||||
" print(\"Encoded vector:\", encoded.flatten()) # Print flattened encoded vector\n",
|
||||
" print(\"\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "semantics",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue