diff --git a/bel_semantics.ipynb b/bel_semantics.ipynb
index f0924cc..4083ea5 100644
--- a/bel_semantics.ipynb
+++ b/bel_semantics.ipynb
@@ -10,6 +10,7 @@
     "\n",
     "import torch.nn as nn\n",
     "import torch.optim as optim\n",
+    "import torch.nn.functional as F\n",
     "import numpy as np\n",
     "import pandas as pd\n",
     "import matplotlib.pyplot as plt\n",
@@ -74,50 +75,47 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Define the MLP\n",
-    "class MLP(nn.Module):\n",
-    "    def __init__(self):\n",
-    "        super(MLP, self).__init__()\n",
-    "        self.input_layer = nn.Linear(1024, 512)\n",
-    "        self.h1_layer = nn.Linear(512, 64)\n",
-    "        self.h2_layer = nn.Linear(64, 62)\n",
-    "        self.relu = nn.ReLU()\n",
+    "class SemanticsMLP(nn.Module):\n",
+    "    def __init__(self, input_size=1024, hidden_sizes=[512, 256, 128], num_classes=62):\n",
+    "        super(SemanticsMLP, self).__init__()\n",
+    "        self.input_size = input_size\n",
+    "        self.hidden_sizes = hidden_sizes\n",
+    "        self.num_classes = num_classes\n",
     "\n",
-    "    def forward(self, x):\n",
-    "        x = self.relu(self.input_layer(x))\n",
-    "        x = self.h1_layer(x)\n",
-    "        x = self.h2_layer(x)\n",
+    "        # Encoder (feature extractor)\n",
+    "        self.encoder_layers = nn.ModuleList()\n",
+    "        prev_size = input_size\n",
+    "        for hidden_size in hidden_sizes:\n",
+    "            self.encoder_layers.append(nn.Linear(prev_size, hidden_size))\n",
+    "            prev_size = hidden_size\n",
+    "\n",
+    "        # Classifier\n",
+    "        self.classifier = nn.Linear(hidden_sizes[-1], num_classes)\n",
+    "\n",
+    "        # Decoder\n",
+    "        self.decoder_layers = nn.ModuleList()\n",
+    "        reversed_hidden_sizes = list(reversed(hidden_sizes))\n",
+    "        prev_size = hidden_sizes[-1]\n",
+    "        for hidden_size in reversed_hidden_sizes[1:] + [input_size]:\n",
+    "            self.decoder_layers.append(nn.Linear(prev_size, hidden_size))\n",
+    "            prev_size = hidden_size\n",
+    "\n",
+    "    def encode(self, x):\n",
+    "        for layer in self.encoder_layers:\n",
+    "            x = F.relu(layer(x))\n",
     "        return x\n",
     "\n",
-    "# Define the Decoder\n",
-    "class Decoder(nn.Module):\n",
-    "    def __init__(self):\n",
-    "        super(Decoder, self).__init__()\n",
-    "        self.h2_h1 = nn.Linear(64, 512)\n",
-    "        self.h1_input = nn.Linear(512, 1024)\n",
-    "        self.relu = nn.ReLU()\n",
-    "\n",
-    "    def forward(self, x):\n",
-    "        x = self.relu(self.h2_h1(x))\n",
-    "        x = self.h1_input(x)\n",
+    "    def decode(self, x):\n",
+    "        for layer in self.decoder_layers[:-1]:\n",
+    "            x = F.relu(layer(x))\n",
+    "        x = self.decoder_layers[-1](x)  # No activation on the final layer\n",
     "        return x\n",
     "\n",
-    "class MLPWithDecoder(nn.Module):\n",
-    "    def __init__(self):\n",
-    "        super(MLPWithDecoder, self).__init__()\n",
-    "        self.mlp = MLP()\n",
-    "        self.decoder = Decoder()\n",
-    "\n",
     "    def forward(self, x):\n",
-    "        # MLP forward pass\n",
-    "        h1 = self.mlp.relu(self.mlp.input_layer(x))\n",
-    "        h2 = self.mlp.relu(self.mlp.h1_layer(h1))\n",
-    "        output = self.mlp.h2_layer(h2)\n",
-    "        \n",
-    "        # Reconstruction\n",
-    "        reconstruction = self.decoder(h2)\n",
-    "        \n",
-    "        return output, reconstruction"
+    "        encoded = self.encode(x)\n",
+    "        logits = self.classifier(encoded)\n",
+    "        reconstructed = self.decode(encoded)\n",
+    "        return logits, reconstructed"
    ]
   },
   {
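As a sanity check on the refactor above, which folds the old MLP/Decoder pair into a single module, here is a minimal sketch, assuming the SemanticsMLP class and the notebook's torch imports are in scope; the random batch is purely illustrative.

    import torch

    # Hypothetical smoke test: a batch of 4 flattened 1024-dim inputs.
    model = SemanticsMLP(input_size=1024, hidden_sizes=[512, 256, 128], num_classes=62)
    x = torch.randn(4, 1024)
    logits, reconstructed = model(x)

    # The classifier head emits one logit per class; the decoder mirrors the
    # encoder (128 -> 256 -> 512 -> 1024) back to the input dimensionality.
    assert logits.shape == (4, 62)
    assert reconstructed.shape == (4, 1024)

Note that with a single hidden layer (hidden_sizes=[10], as instantiated later in this diff), reversed_hidden_sizes[1:] is empty, so the decoder collapses to one Linear(10, 1024) layer and decode() applies no ReLU at all, consistent with the "No activation on the final layer" comment.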
[], "source": [ "def show_image_comparison(original, reconstructed, label, prediction):\n", " \"\"\"\n", @@ -176,19 +160,14 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ - "# Initialize the model, loss function, and optimizer\n", - "model = MLPWithDecoder()\n", + "model = SemanticsMLP(input_size=1024, hidden_sizes=[10], num_classes=62).to(device)\n", "criterion = nn.CrossEntropyLoss()\n", "reconstruction_criterion = nn.MSELoss()\n", - "optimizer = optim.Adam(model.parameters())\n", - "\n", - "model = model.to(device)\n", - "criterion = criterion.to(device)\n", - "reconstruction_criterion = reconstruction_criterion.to(device)" + "optimizer = optim.Adam(model.parameters(), lr=0.001)" ] }, { @@ -199,10 +178,9 @@ "source": [ "num_epochs = 250\n", "for epoch in range(num_epochs):\n", - " model.train() # Set the model to training mode\n", + " model.train()\n", " running_loss = 0.0\n", " \n", - " # Use tqdm for a progress bar\n", " with tqdm(train_loader, unit=\"batch\") as tepoch:\n", " for images, labels in tepoch:\n", " tepoch.set_description(f\"Epoch {epoch+1}\")\n", @@ -211,10 +189,10 @@ " \n", " optimizer.zero_grad()\n", " \n", - " outputs, reconstructions = model(images)\n", + " logits, reconstructed = model(images)\n", " \n", - " classification_loss = criterion(outputs, labels)\n", - " reconstruction_loss = reconstruction_criterion(reconstructions, images)\n", + " classification_loss = criterion(logits, labels)\n", + " reconstruction_loss = reconstruction_criterion(reconstructed, images)\n", " total_loss = classification_loss + reconstruction_loss\n", " \n", " total_loss.backward()\n", @@ -225,7 +203,6 @@ " tepoch.set_postfix(loss=total_loss.item())\n", " \n", " epoch_loss = running_loss / len(train_loader)\n", - " \n", " # print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')" ] }, @@ -235,40 +212,28 @@ "metadata": {}, "outputs": [], "source": [ - "model.eval() # Set the model to evaluation mode\n", + "model.eval()\n", "with torch.no_grad():\n", - " try:\n", - " # Get a batch of test data\n", - " images, labels = next(iter(test_loader))\n", - " \n", - " # Move data to the same device as the model\n", - " images = images.to(device)\n", - " labels = labels.to(device)\n", - " \n", - " # Forward pass through the model\n", - " outputs, reconstructions = model(images)\n", - " \n", - " # Get predicted labels\n", - " _, predicted = torch.max(outputs.data, 1)\n", - " \n", - " # Display the first few images in the batch\n", - " num_images_to_show = min(5, len(images))\n", - " for i in range(num_images_to_show):\n", - " show_image_comparison(\n", - " images[i], \n", - " reconstructions[i], \n", - " labels[i].item(), \n", - " predicted[i].item()\n", - " )\n", - " \n", - " # Calculate and print accuracy\n", - " correct = (predicted == labels).sum().item()\n", - " total = labels.size(0)\n", - " accuracy = 100 * correct / total\n", - " print(f'Test Accuracy: {accuracy:.2f}%')\n", - " \n", - " except Exception as e:\n", - " print(f\"An error occurred during evaluation: {str(e)}\")" + " images, labels = next(iter(test_loader))\n", + " images, labels = images.to(device), labels.to(device)\n", + " \n", + " logits, reconstructed = model(images)\n", + " \n", + " _, predicted = torch.max(logits.data, 1)\n", + " \n", + " num_images_to_show = min(5, len(images))\n", + " for i in range(num_images_to_show):\n", + " show_image_comparison(\n", + " images[i], \n", + " reconstructed[i], \n", + " labels[i].item(), \n", + " 
@@ -235,40 +212,28 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "model.eval()  # Set the model to evaluation mode\n",
+    "model.eval()\n",
     "with torch.no_grad():\n",
-    "    try:\n",
-    "        # Get a batch of test data\n",
-    "        images, labels = next(iter(test_loader))\n",
-    "        \n",
-    "        # Move data to the same device as the model\n",
-    "        images = images.to(device)\n",
-    "        labels = labels.to(device)\n",
-    "        \n",
-    "        # Forward pass through the model\n",
-    "        outputs, reconstructions = model(images)\n",
-    "        \n",
-    "        # Get predicted labels\n",
-    "        _, predicted = torch.max(outputs.data, 1)\n",
-    "        \n",
-    "        # Display the first few images in the batch\n",
-    "        num_images_to_show = min(5, len(images))\n",
-    "        for i in range(num_images_to_show):\n",
-    "            show_image_comparison(\n",
-    "                images[i], \n",
-    "                reconstructions[i], \n",
-    "                labels[i].item(), \n",
-    "                predicted[i].item()\n",
-    "            )\n",
-    "        \n",
-    "        # Calculate and print accuracy\n",
-    "        correct = (predicted == labels).sum().item()\n",
-    "        total = labels.size(0)\n",
-    "        accuracy = 100 * correct / total\n",
-    "        print(f'Test Accuracy: {accuracy:.2f}%')\n",
-    "        \n",
-    "    except Exception as e:\n",
-    "        print(f\"An error occurred during evaluation: {str(e)}\")"
+    "    images, labels = next(iter(test_loader))\n",
+    "    images, labels = images.to(device), labels.to(device)\n",
+    "    \n",
+    "    logits, reconstructed = model(images)\n",
+    "    \n",
+    "    _, predicted = torch.max(logits.data, 1)\n",
+    "    \n",
+    "    num_images_to_show = min(5, len(images))\n",
+    "    for i in range(num_images_to_show):\n",
+    "        show_image_comparison(\n",
+    "            images[i], \n",
+    "            reconstructed[i], \n",
+    "            labels[i].item(), \n",
+    "            predicted[i].item()\n",
+    "        )\n",
+    "    \n",
+    "    correct = (predicted == labels).sum().item()\n",
+    "    total = labels.size(0)\n",
+    "    accuracy = 100 * correct / total\n",
+    "    print(f'Test Accuracy: {accuracy:.2f}%')"
    ]
   },
   {
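The rewritten evaluation cell scores only the first batch from test_loader, so the printed accuracy is a single-batch estimate. A sketch of full test-set accuracy, reusing model, test_loader, and device from the cells above:

    # Full test-set accuracy (sketch); the reconstruction output is ignored.
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            logits, _ = model(images)
            predicted = logits.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    print(f'Full test accuracy: {100 * correct / total:.2f}%')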