Move Decoder, MLP and MLPWithDecoder to one class SemanticsMLP
This commit is contained in:
parent 0fb1f69b1f
commit fa15717c84
1 changed file with 64 additions and 99 deletions
@@ -10,6 +10,7 @@
"\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torch.nn.functional as F\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
@@ -74,50 +75,47 @@
"metadata": {},
"outputs": [],
"source": [
"# Define the MLP\n",
"class MLP(nn.Module):\n",
" def __init__(self):\n",
" super(MLP, self).__init__()\n",
" self.input_layer = nn.Linear(1024, 512)\n",
" self.h1_layer = nn.Linear(512, 64)\n",
" self.h2_layer = nn.Linear(64, 62)\n",
" self.relu = nn.ReLU()\n",
"class SemanticsMLP(nn.Module):\n",
" def __init__(self, input_size=1024, hidden_sizes=[512, 256, 128], num_classes=62):\n",
" super(SemanticsMLP, self).__init__()\n",
" self.input_size = input_size\n",
" self.hidden_sizes = hidden_sizes\n",
" self.num_classes = num_classes\n",
"\n",
" def forward(self, x):\n",
" x = self.relu(self.input_layer(x))\n",
" x = self.h1_layer(x)\n",
" x = self.h2_layer(x)\n",
" # Encoder (feature extractor)\n",
" self.encoder_layers = nn.ModuleList()\n",
" prev_size = input_size\n",
" for hidden_size in hidden_sizes:\n",
" self.encoder_layers.append(nn.Linear(prev_size, hidden_size))\n",
" prev_size = hidden_size\n",
"\n",
" # Classifier\n",
" self.classifier = nn.Linear(hidden_sizes[-1], num_classes)\n",
"\n",
" # Decoder\n",
" self.decoder_layers = nn.ModuleList()\n",
" reversed_hidden_sizes = list(reversed(hidden_sizes))\n",
" prev_size = hidden_sizes[-1]\n",
" for hidden_size in reversed_hidden_sizes[1:] + [input_size]:\n",
" self.decoder_layers.append(nn.Linear(prev_size, hidden_size))\n",
" prev_size = hidden_size\n",
"\n",
" def encode(self, x):\n",
" for layer in self.encoder_layers:\n",
" x = F.relu(layer(x))\n",
" return x\n",
"\n",
"# Define the Decoder\n",
"class Decoder(nn.Module):\n",
" def __init__(self):\n",
" super(Decoder, self).__init__()\n",
" self.h2_h1 = nn.Linear(64, 512)\n",
" self.h1_input = nn.Linear(512, 1024)\n",
" self.relu = nn.ReLU()\n",
"\n",
" def forward(self, x):\n",
" x = self.relu(self.h2_h1(x))\n",
" x = self.h1_input(x)\n",
" def decode(self, x):\n",
" for layer in self.decoder_layers[:-1]:\n",
" x = F.relu(layer(x))\n",
" x = self.decoder_layers[-1](x) # No activation on the final layer\n",
" return x\n",
"\n",
"class MLPWithDecoder(nn.Module):\n",
" def __init__(self):\n",
" super(MLPWithDecoder, self).__init__()\n",
" self.mlp = MLP()\n",
" self.decoder = Decoder()\n",
"\n",
" def forward(self, x):\n",
" # MLP forward pass\n",
" h1 = self.mlp.relu(self.mlp.input_layer(x))\n",
" h2 = self.mlp.relu(self.mlp.h1_layer(h1))\n",
" output = self.mlp.h2_layer(h2)\n",
" \n",
" # Reconstruction\n",
" reconstruction = self.decoder(h2)\n",
" \n",
" return output, reconstruction"
" encoded = self.encode(x)\n",
" logits = self.classifier(encoded)\n",
" reconstructed = self.decode(encoded)\n",
" return logits, reconstructed"
]
},
{
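Taken together, the added lines in the hunk above define SemanticsMLP as one module with an encode/classify/decode path whose forward pass returns a (logits, reconstruction) pair. A minimal sketch of how it would be exercised, assuming the class definition exactly as shown in this hunk; the batch size of 8 and the random input are purely illustrative:

import torch

# Illustrative shape check (not part of the notebook): the defaults are an
# encoder 1024 -> [512, 256, 128], a 62-way classifier, and a mirrored decoder.
model = SemanticsMLP()
x = torch.randn(8, 1024)          # arbitrary batch of 1024-dimensional inputs
logits, reconstructed = model(x)
print(logits.shape)               # expected: torch.Size([8, 62])
print(reconstructed.shape)        # expected: torch.Size([8, 1024])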
@@ -125,20 +123,6 @@
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Function to reconstruct an image\n",
"def reconstruct_image(model, image):\n",
" model.eval()\n",
" with torch.no_grad():\n",
" _, reconstruction = model(image.unsqueeze(0))\n",
" return reconstruction.squeeze(0)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def show_image_comparison(original, reconstructed, label, prediction):\n",
" \"\"\"\n",
@@ -176,19 +160,14 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Initialize the model, loss function, and optimizer\n",
"model = MLPWithDecoder()\n",
"model = SemanticsMLP(input_size=1024, hidden_sizes=[10], num_classes=62).to(device)\n",
"criterion = nn.CrossEntropyLoss()\n",
"reconstruction_criterion = nn.MSELoss()\n",
"optimizer = optim.Adam(model.parameters())\n",
"\n",
"model = model.to(device)\n",
"criterion = criterion.to(device)\n",
"reconstruction_criterion = reconstruction_criterion.to(device)"
"optimizer = optim.Adam(model.parameters(), lr=0.001)"
]
},
{
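Worth noting for the cell above: it instantiates the new class with hidden_sizes=[10], so the encoder collapses to a single Linear(1024, 10), the classifier is Linear(10, 62), and, because reversed_hidden_sizes[1:] is empty, the decoder is a single Linear(10, 1024). A quick, hedged way to confirm this, assuming SemanticsMLP from the earlier hunk is in scope:

# Illustrative check of the hidden_sizes=[10] configuration used in this cell;
# print(model) relies only on nn.Module's built-in repr.
model = SemanticsMLP(input_size=1024, hidden_sizes=[10], num_classes=62)
print(model)
# Expected submodules:
#   encoder_layers: ModuleList with Linear(in_features=1024, out_features=10)
#   classifier:     Linear(in_features=10, out_features=62)
#   decoder_layers: ModuleList with Linear(in_features=10, out_features=1024)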
@@ -199,10 +178,9 @@
"source": [
"num_epochs = 250\n",
"for epoch in range(num_epochs):\n",
" model.train() # Set the model to training mode\n",
" model.train()\n",
" running_loss = 0.0\n",
" \n",
" # Use tqdm for a progress bar\n",
" with tqdm(train_loader, unit=\"batch\") as tepoch:\n",
" for images, labels in tepoch:\n",
" tepoch.set_description(f\"Epoch {epoch+1}\")\n",
@ -211,10 +189,10 @@
|
|||
" \n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" \n",
|
||||
" outputs, reconstructions = model(images)\n",
|
||||
" logits, reconstructed = model(images)\n",
|
||||
" \n",
|
||||
" classification_loss = criterion(outputs, labels)\n",
|
||||
" reconstruction_loss = reconstruction_criterion(reconstructions, images)\n",
|
||||
" classification_loss = criterion(logits, labels)\n",
|
||||
" reconstruction_loss = reconstruction_criterion(reconstructed, images)\n",
|
||||
" total_loss = classification_loss + reconstruction_loss\n",
|
||||
" \n",
|
||||
" total_loss.backward()\n",
|
||||
|
|
@ -225,7 +203,6 @@
|
|||
" tepoch.set_postfix(loss=total_loss.item())\n",
|
||||
" \n",
|
||||
" epoch_loss = running_loss / len(train_loader)\n",
|
||||
" \n",
|
||||
" # print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')"
|
||||
]
|
||||
},
|
||||
|
|
@ -235,40 +212,28 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.eval() # Set the model to evaluation mode\n",
|
||||
"model.eval()\n",
|
||||
"with torch.no_grad():\n",
|
||||
" try:\n",
|
||||
" # Get a batch of test data\n",
|
||||
" images, labels = next(iter(test_loader))\n",
|
||||
" \n",
|
||||
" # Move data to the same device as the model\n",
|
||||
" images = images.to(device)\n",
|
||||
" labels = labels.to(device)\n",
|
||||
" \n",
|
||||
" # Forward pass through the model\n",
|
||||
" outputs, reconstructions = model(images)\n",
|
||||
" \n",
|
||||
" # Get predicted labels\n",
|
||||
" _, predicted = torch.max(outputs.data, 1)\n",
|
||||
" \n",
|
||||
" # Display the first few images in the batch\n",
|
||||
" num_images_to_show = min(5, len(images))\n",
|
||||
" for i in range(num_images_to_show):\n",
|
||||
" show_image_comparison(\n",
|
||||
" images[i], \n",
|
||||
" reconstructions[i], \n",
|
||||
" labels[i].item(), \n",
|
||||
" predicted[i].item()\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" # Calculate and print accuracy\n",
|
||||
" correct = (predicted == labels).sum().item()\n",
|
||||
" total = labels.size(0)\n",
|
||||
" accuracy = 100 * correct / total\n",
|
||||
" print(f'Test Accuracy: {accuracy:.2f}%')\n",
|
||||
" \n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"An error occurred during evaluation: {str(e)}\")"
|
||||
" images, labels = next(iter(test_loader))\n",
|
||||
" images, labels = images.to(device), labels.to(device)\n",
|
||||
" \n",
|
||||
" logits, reconstructed = model(images)\n",
|
||||
" \n",
|
||||
" _, predicted = torch.max(logits.data, 1)\n",
|
||||
" \n",
|
||||
" num_images_to_show = min(5, len(images))\n",
|
||||
" for i in range(num_images_to_show):\n",
|
||||
" show_image_comparison(\n",
|
||||
" images[i], \n",
|
||||
" reconstructed[i], \n",
|
||||
" labels[i].item(), \n",
|
||||
" predicted[i].item()\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" correct = (predicted == labels).sum().item()\n",
|
||||
" total = labels.size(0)\n",
|
||||
" accuracy = 100 * correct / total\n",
|
||||
" print(f'Test Accuracy: {accuracy:.2f}%')"
|
||||
]
|
||||
},
|
||||
{