Move Decoder, MLP and MLPWithDecoder to one class SemanticsMLP

This commit is contained in:
Murtadha 2024-09-27 11:28:58 -04:00
parent 0fb1f69b1f
commit fa15717c84

View file

@ -10,6 +10,7 @@
"\n", "\n",
"import torch.nn as nn\n", "import torch.nn as nn\n",
"import torch.optim as optim\n", "import torch.optim as optim\n",
"import torch.nn.functional as F\n",
"import numpy as np\n", "import numpy as np\n",
"import pandas as pd\n", "import pandas as pd\n",
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
@ -74,50 +75,47 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Define the MLP\n", "class SemanticsMLP(nn.Module):\n",
"class MLP(nn.Module):\n", " def __init__(self, input_size=1024, hidden_sizes=[512, 256, 128], num_classes=62):\n",
" def __init__(self):\n", " super(SemanticsMLP, self).__init__()\n",
" super(MLP, self).__init__()\n", " self.input_size = input_size\n",
" self.input_layer = nn.Linear(1024, 512)\n", " self.hidden_sizes = hidden_sizes\n",
" self.h1_layer = nn.Linear(512, 64)\n", " self.num_classes = num_classes\n",
" self.h2_layer = nn.Linear(64, 62)\n",
" self.relu = nn.ReLU()\n",
"\n", "\n",
" def forward(self, x):\n", " # Encoder (feature extractor)\n",
" x = self.relu(self.input_layer(x))\n", " self.encoder_layers = nn.ModuleList()\n",
" x = self.h1_layer(x)\n", " prev_size = input_size\n",
" x = self.h2_layer(x)\n", " for hidden_size in hidden_sizes:\n",
" self.encoder_layers.append(nn.Linear(prev_size, hidden_size))\n",
" prev_size = hidden_size\n",
"\n",
" # Classifier\n",
" self.classifier = nn.Linear(hidden_sizes[-1], num_classes)\n",
"\n",
" # Decoder\n",
" self.decoder_layers = nn.ModuleList()\n",
" reversed_hidden_sizes = list(reversed(hidden_sizes))\n",
" prev_size = hidden_sizes[-1]\n",
" for hidden_size in reversed_hidden_sizes[1:] + [input_size]:\n",
" self.decoder_layers.append(nn.Linear(prev_size, hidden_size))\n",
" prev_size = hidden_size\n",
"\n",
" def encode(self, x):\n",
" for layer in self.encoder_layers:\n",
" x = F.relu(layer(x))\n",
" return x\n", " return x\n",
"\n", "\n",
"# Define the Decoder\n", " def decode(self, x):\n",
"class Decoder(nn.Module):\n", " for layer in self.decoder_layers[:-1]:\n",
" def __init__(self):\n", " x = F.relu(layer(x))\n",
" super(Decoder, self).__init__()\n", " x = self.decoder_layers[-1](x) # No activation on the final layer\n",
" self.h2_h1 = nn.Linear(64, 512)\n",
" self.h1_input = nn.Linear(512, 1024)\n",
" self.relu = nn.ReLU()\n",
"\n",
" def forward(self, x):\n",
" x = self.relu(self.h2_h1(x))\n",
" x = self.h1_input(x)\n",
" return x\n", " return x\n",
"\n", "\n",
"class MLPWithDecoder(nn.Module):\n",
" def __init__(self):\n",
" super(MLPWithDecoder, self).__init__()\n",
" self.mlp = MLP()\n",
" self.decoder = Decoder()\n",
"\n",
" def forward(self, x):\n", " def forward(self, x):\n",
" # MLP forward pass\n", " encoded = self.encode(x)\n",
" h1 = self.mlp.relu(self.mlp.input_layer(x))\n", " logits = self.classifier(encoded)\n",
" h2 = self.mlp.relu(self.mlp.h1_layer(h1))\n", " reconstructed = self.decode(encoded)\n",
" output = self.mlp.h2_layer(h2)\n", " return logits, reconstructed"
" \n",
" # Reconstruction\n",
" reconstruction = self.decoder(h2)\n",
" \n",
" return output, reconstruction"
] ]
}, },
{ {
@ -125,20 +123,6 @@
"execution_count": 6, "execution_count": 6,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [
"# Function to reconstruct an image\n",
"def reconstruct_image(model, image):\n",
" model.eval()\n",
" with torch.no_grad():\n",
" _, reconstruction = model(image.unsqueeze(0))\n",
" return reconstruction.squeeze(0)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [ "source": [
"def show_image_comparison(original, reconstructed, label, prediction):\n", "def show_image_comparison(original, reconstructed, label, prediction):\n",
" \"\"\"\n", " \"\"\"\n",
@ -176,19 +160,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 7,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Initialize the model, loss function, and optimizer\n", "model = SemanticsMLP(input_size=1024, hidden_sizes=[10], num_classes=62).to(device)\n",
"model = MLPWithDecoder()\n",
"criterion = nn.CrossEntropyLoss()\n", "criterion = nn.CrossEntropyLoss()\n",
"reconstruction_criterion = nn.MSELoss()\n", "reconstruction_criterion = nn.MSELoss()\n",
"optimizer = optim.Adam(model.parameters())\n", "optimizer = optim.Adam(model.parameters(), lr=0.001)"
"\n",
"model = model.to(device)\n",
"criterion = criterion.to(device)\n",
"reconstruction_criterion = reconstruction_criterion.to(device)"
] ]
}, },
{ {
@ -199,10 +178,9 @@
"source": [ "source": [
"num_epochs = 250\n", "num_epochs = 250\n",
"for epoch in range(num_epochs):\n", "for epoch in range(num_epochs):\n",
" model.train() # Set the model to training mode\n", " model.train()\n",
" running_loss = 0.0\n", " running_loss = 0.0\n",
" \n", " \n",
" # Use tqdm for a progress bar\n",
" with tqdm(train_loader, unit=\"batch\") as tepoch:\n", " with tqdm(train_loader, unit=\"batch\") as tepoch:\n",
" for images, labels in tepoch:\n", " for images, labels in tepoch:\n",
" tepoch.set_description(f\"Epoch {epoch+1}\")\n", " tepoch.set_description(f\"Epoch {epoch+1}\")\n",
@ -211,10 +189,10 @@
" \n", " \n",
" optimizer.zero_grad()\n", " optimizer.zero_grad()\n",
" \n", " \n",
" outputs, reconstructions = model(images)\n", " logits, reconstructed = model(images)\n",
" \n", " \n",
" classification_loss = criterion(outputs, labels)\n", " classification_loss = criterion(logits, labels)\n",
" reconstruction_loss = reconstruction_criterion(reconstructions, images)\n", " reconstruction_loss = reconstruction_criterion(reconstructed, images)\n",
" total_loss = classification_loss + reconstruction_loss\n", " total_loss = classification_loss + reconstruction_loss\n",
" \n", " \n",
" total_loss.backward()\n", " total_loss.backward()\n",
@ -225,7 +203,6 @@
" tepoch.set_postfix(loss=total_loss.item())\n", " tepoch.set_postfix(loss=total_loss.item())\n",
" \n", " \n",
" epoch_loss = running_loss / len(train_loader)\n", " epoch_loss = running_loss / len(train_loader)\n",
" \n",
" # print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')" " # print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')"
] ]
}, },
@ -235,40 +212,28 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"model.eval() # Set the model to evaluation mode\n", "model.eval()\n",
"with torch.no_grad():\n", "with torch.no_grad():\n",
" try:\n", " images, labels = next(iter(test_loader))\n",
" # Get a batch of test data\n", " images, labels = images.to(device), labels.to(device)\n",
" images, labels = next(iter(test_loader))\n", " \n",
" \n", " logits, reconstructed = model(images)\n",
" # Move data to the same device as the model\n", " \n",
" images = images.to(device)\n", " _, predicted = torch.max(logits.data, 1)\n",
" labels = labels.to(device)\n", " \n",
" \n", " num_images_to_show = min(5, len(images))\n",
" # Forward pass through the model\n", " for i in range(num_images_to_show):\n",
" outputs, reconstructions = model(images)\n", " show_image_comparison(\n",
" \n", " images[i], \n",
" # Get predicted labels\n", " reconstructed[i], \n",
" _, predicted = torch.max(outputs.data, 1)\n", " labels[i].item(), \n",
" \n", " predicted[i].item()\n",
" # Display the first few images in the batch\n", " )\n",
" num_images_to_show = min(5, len(images))\n", " \n",
" for i in range(num_images_to_show):\n", " correct = (predicted == labels).sum().item()\n",
" show_image_comparison(\n", " total = labels.size(0)\n",
" images[i], \n", " accuracy = 100 * correct / total\n",
" reconstructions[i], \n", " print(f'Test Accuracy: {accuracy:.2f}%')"
" labels[i].item(), \n",
" predicted[i].item()\n",
" )\n",
" \n",
" # Calculate and print accuracy\n",
" correct = (predicted == labels).sum().item()\n",
" total = labels.size(0)\n",
" accuracy = 100 * correct / total\n",
" print(f'Test Accuracy: {accuracy:.2f}%')\n",
" \n",
" except Exception as e:\n",
" print(f\"An error occurred during evaluation: {str(e)}\")"
] ]
}, },
{ {