268 lines
8.6 KiB
Text
268 lines
8.6 KiB
Text
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"\n",
|
|
"import torch.nn as nn\n",
|
|
"import torch.optim as optim\n",
|
|
"import torch.nn.functional as F\n",
|
|
"import numpy as np\n",
|
|
"import pandas as pd\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"\n",
|
|
"from torch.utils.data import DataLoader, TensorDataset\n",
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"from tqdm import tqdm"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Select the GPU when available, otherwise fall back to the CPU\n",
"use_cuda = torch.cuda.is_available()\n",
"device = torch.device(\"cuda\") if use_cuda else torch.device(\"cpu\")\n",
"print(f\"Using device: {device}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Load the data\n",
"data = pd.read_csv('data/bel_data_test.csv')\n",
"# NOTE(review): the file name says 'test' but it is split into\n",
"# train/test below - confirm this is the intended dataset\n",
"data = np.array(data)\n",
"\n",
"# Split features and labels\n",
"X = data[:, 1:] # All columns except the first one\n",
"y = data[:, 0].astype(int) # First column as labels\n",
"\n",
"# Split the data into training and testing sets\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Convert to PyTorch tensors. torch.tensor(..., dtype=...) is the\n",
"# documented replacement for the legacy FloatTensor/LongTensor\n",
"# constructors and produces identical values here.\n",
"X_train_tensor = torch.tensor(X_train, dtype=torch.float32)\n",
"y_train_tensor = torch.tensor(y_train, dtype=torch.long)\n",
"X_test_tensor = torch.tensor(X_test, dtype=torch.float32)\n",
"y_test_tensor = torch.tensor(y_test, dtype=torch.long)\n",
"\n",
"# Wrap the tensors in datasets so DataLoader can batch them\n",
"train_dataset = TensorDataset(X_train_tensor, y_train_tensor)\n",
"test_dataset = TensorDataset(X_test_tensor, y_test_tensor)\n",
"\n",
"# Shuffle only the training data; evaluation order stays stable\n",
"train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
"test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class SemanticsMLP(nn.Module):\n",
"    \"\"\"MLP with a shared encoder, a classification head, and a\n",
"    reconstruction (decoder) head.\n",
"\n",
"    forward(x) returns a (logits, reconstructed) tuple so the two\n",
"    objectives can be trained jointly.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_size=1024, hidden_sizes=(512, 256, 128), num_classes=62):\n",
"        super().__init__()\n",
"        # Copy into a list so a caller-supplied tuple works below and\n",
"        # the (now immutable) default can never be shared mutably\n",
"        # between instances.\n",
"        hidden_sizes = list(hidden_sizes)\n",
"        self.input_size = input_size\n",
"        self.hidden_sizes = hidden_sizes\n",
"        self.num_classes = num_classes\n",
"\n",
"        # Encoder (feature extractor): input_size -> h1 -> ... -> hN\n",
"        self.encoder_layers = nn.ModuleList()\n",
"        prev_size = input_size\n",
"        for hidden_size in hidden_sizes:\n",
"            self.encoder_layers.append(nn.Linear(prev_size, hidden_size))\n",
"            prev_size = hidden_size\n",
"\n",
"        # Classifier head on top of the bottleneck features\n",
"        self.classifier = nn.Linear(hidden_sizes[-1], num_classes)\n",
"\n",
"        # Decoder: mirror of the encoder, back up to input_size\n",
"        self.decoder_layers = nn.ModuleList()\n",
"        reversed_hidden_sizes = list(reversed(hidden_sizes))\n",
"        prev_size = hidden_sizes[-1]\n",
"        for hidden_size in reversed_hidden_sizes[1:] + [input_size]:\n",
"            self.decoder_layers.append(nn.Linear(prev_size, hidden_size))\n",
"            prev_size = hidden_size\n",
"\n",
"    def encode(self, x):\n",
"        \"\"\"Map inputs to bottleneck features (ReLU after every layer).\"\"\"\n",
"        for layer in self.encoder_layers:\n",
"            x = F.relu(layer(x))\n",
"        return x\n",
"\n",
"    def decode(self, x):\n",
"        \"\"\"Reconstruct inputs from bottleneck features.\"\"\"\n",
"        for layer in self.decoder_layers[:-1]:\n",
"            x = F.relu(layer(x))\n",
"        x = self.decoder_layers[-1](x)  # No activation on the final layer\n",
"        return x\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return (class logits, reconstruction) for a batch x.\"\"\"\n",
"        encoded = self.encode(x)\n",
"        logits = self.classifier(encoded)\n",
"        reconstructed = self.decode(encoded)\n",
"        return logits, reconstructed"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def show_image_comparison(original, reconstructed, label, prediction):\n",
"    \"\"\"\n",
"    Display the original and reconstructed images side by side.\n",
"\n",
"    :param original: Original image (1D tensor of 1024 elements)\n",
"    :param reconstructed: Reconstructed image (1D tensor of 1024 elements)\n",
"    :param label: True label of the image\n",
"    :param prediction: Predicted label of the image\n",
"    \"\"\"\n",
"    # Move to CPU, convert to numpy, and reshape each 1D vector to 32x32\n",
"    panels = [t.cpu().numpy().reshape(32, 32) for t in (original, reconstructed)]\n",
"    titles = [f'Original (Label: {label})', f'Reconstructed (Predicted: {prediction})']\n",
"\n",
"    # One row, two panels: original on the left, reconstruction on the right\n",
"    fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n",
"    for ax, img, title in zip(axes, panels, titles):\n",
"        ax.imshow(img, cmap='gray')\n",
"        ax.set_title(title)\n",
"        ax.axis('off')\n",
"\n",
"    plt.tight_layout()\n",
"    plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Model with a narrow 10-unit bottleneck, moved onto the chosen device\n",
"model = SemanticsMLP(input_size=1024, hidden_sizes=[10], num_classes=62).to(device)\n",
"\n",
"# Joint objectives: classification (cross-entropy) + reconstruction (MSE)\n",
"criterion = nn.CrossEntropyLoss()\n",
"reconstruction_criterion = nn.MSELoss()\n",
"\n",
"# Adam over all model parameters\n",
"optimizer = optim.Adam(model.parameters(), lr=0.001)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"num_epochs = 250\n",
"for epoch in range(num_epochs):\n",
"    model.train()\n",
"    running_loss = 0.0\n",
"\n",
"    with tqdm(train_loader, unit=\"batch\") as tepoch:\n",
"        for batch_idx, (images, labels) in enumerate(tepoch, start=1):\n",
"            tepoch.set_description(f\"Epoch {epoch+1}\")\n",
"\n",
"            images, labels = images.to(device), labels.to(device)\n",
"\n",
"            optimizer.zero_grad()\n",
"\n",
"            # Joint objective: classification loss + reconstruction loss\n",
"            logits, reconstructed = model(images)\n",
"            classification_loss = criterion(logits, labels)\n",
"            reconstruction_loss = reconstruction_criterion(reconstructed, images)\n",
"            total_loss = classification_loss + reconstruction_loss\n",
"\n",
"            total_loss.backward()\n",
"            optimizer.step()\n",
"\n",
"            running_loss += total_loss.item()\n",
"\n",
"            # Show the current batch loss alongside the running epoch\n",
"            # average (replaces the previously dead epoch_loss variable\n",
"            # and its commented-out print)\n",
"            tepoch.set_postfix(loss=total_loss.item(), avg_loss=running_loss / batch_idx)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"model.eval()\n",
"with torch.no_grad():\n",
"    # Visualise a few reconstructions from the first test batch\n",
"    images, labels = next(iter(test_loader))\n",
"    images, labels = images.to(device), labels.to(device)\n",
"\n",
"    logits, reconstructed = model(images)\n",
"    # .data access dropped: it is deprecated and redundant under no_grad()\n",
"    _, predicted = torch.max(logits, 1)\n",
"\n",
"    num_images_to_show = min(5, len(images))\n",
"    for i in range(num_images_to_show):\n",
"        show_image_comparison(\n",
"            images[i], \n",
"            reconstructed[i], \n",
"            labels[i].item(), \n",
"            predicted[i].item()\n",
"        )\n",
"\n",
"    # Accuracy over the WHOLE test set. The previous version scored only\n",
"    # the first batch (64 samples), which mis-states test accuracy.\n",
"    correct = 0\n",
"    total = 0\n",
"    for images, labels in test_loader:\n",
"        images, labels = images.to(device), labels.to(device)\n",
"        logits, _ = model(images)\n",
"        _, predicted = torch.max(logits, 1)\n",
"        correct += (predicted == labels).sum().item()\n",
"        total += labels.size(0)\n",
"\n",
"    accuracy = 100 * correct / total\n",
"    print(f'Test Accuracy: {accuracy:.2f}%')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "semantics",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.1"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|