109 lines
No EOL
4.5 KiB
Python
109 lines
No EOL
4.5 KiB
Python
import numpy as np
|
|
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
from net.mlp import MLP
|
|
from net.modules import calculate_loss, calculate_accuracy, plot_learning_curves, plot_encoded_space, plot_reconstructions
|
|
|
|
class Transcoder(MLP):
    """MLP that can also map its outputs back towards input (image) space.

    On top of what ``MLP`` provides, this class adds:

    * ``train_with_validation``: gradient-descent training with a held-out
      validation split and per-iteration loss/accuracy histories;
    * ``encode_image`` / ``decode_image`` / ``transcode``: a forward pass and
      an approximate inverse of that pass (pseudo-inverse of each affine
      layer plus the inverse activation functions supplied by ``MLP``);
    * thin plotting wrappers around helpers from ``net.modules``.
    """

    def __init__(self, input_size, hidden_size, output_size,
                 hidden_activation='leaky_relu', output_activation='softmax',
                 alpha=0.01):
        """Build the underlying MLP and initialise empty metric histories.

        All arguments are forwarded unchanged to ``MLP.__init__``. ``alpha``
        is later read back as ``self.alpha`` (assumed to be stored by ``MLP``)
        and passed to ``inverse_hidden_activation`` in ``decode_image``.
        NOTE(review): presumably the leaky-ReLU slope; confirm against MLP.
        """
        super().__init__(input_size, hidden_size, output_size,
                         hidden_activation, output_activation, alpha)
        # Per-iteration histories, appended to (never reset) by
        # train_with_validation, so repeated calls extend the curves.
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []
        # 2-D shape used to turn a flattened input vector back into an image.
        self.image_shape = self.determine_image_shape(input_size)

    @staticmethod
    def determine_image_shape(input_size):
        """Infer a 2-D display shape for a flattened input of length ``input_size``.

        Returns:
            ``(side, side)`` when ``input_size`` is a perfect square, otherwise
            ``(input_size, 1)``, i.e. a column vector.
        """
        side = int(np.sqrt(input_size))
        if side ** 2 == input_size:
            return (side, side)
        return (input_size, 1)

    def _to_feature_major(self, X):
        """Return ``X`` laid out as (n_features, n_samples).

        ``X`` is transposed when its first dimension is not ``input_size``;
        otherwise it is returned unchanged.
        """
        return X.T if X.shape[0] != self.input_size else X

    def encode_image(self, X):
        """Run a forward pass and return the network output ``A2`` for ``X``."""
        _, _, _, A2 = self.forward_prop(X)
        return A2

    def decode_image(self, A2, all_samples=False):
        """Approximately invert the network, mapping an encoding back to an image.

        The forward steps are undone in reverse order:

        1. the output activation, via ``inverse_output_activation``;
        2. the second affine layer, via the Moore-Penrose pseudo-inverse of ``W2``;
        3. the hidden activation, via ``inverse_hidden_activation``;
        4. the first affine layer, via the pseudo-inverse of ``W1``.

        The pseudo-inverse is a least-squares inverse, so the result is a
        reconstruction rather than an exact recovery whenever a weight matrix
        is not square and full-rank.

        Args:
            A2: Encoding as returned by ``encode_image``. It is either a single
                sample as a 1-D vector or a 2-D array with one sample per column.
            all_samples: If False (default, the original behaviour), only the
                first sample is decoded and one image of shape
                ``self.image_shape`` is returned. If True, every sample is
                decoded and an array of shape ``(n_samples, *self.image_shape)``
                is returned.

        Returns:
            The reconstructed image(s), as described for ``all_samples``.
        """
        # Step 1: undo the output activation.
        Z2 = self.inverse_output_activation(A2)
        # Treat a single 1-D sample as one column so it follows the same path
        # as a batch; otherwise a 1-D vector minus a column bias would
        # broadcast into a square matrix.
        if np.ndim(Z2) == 1:
            Z2 = np.reshape(Z2, (-1, 1))

        # Step 2: undo the second affine layer. The bias is reshaped to a
        # column so the subtraction is per-row whether it is stored as (n,)
        # or (n, 1).
        A1 = np.linalg.pinv(self.W2).dot(Z2 - np.reshape(self.b2, (-1, 1)))

        # Step 3: undo the hidden activation.
        Z1 = self.inverse_hidden_activation(A1, self.alpha)

        # Step 4: undo the first affine layer. The result is flattened inputs
        # with shape (input_size, n_samples).
        X_flat = np.linalg.pinv(self.W1).dot(Z1 - np.reshape(self.b1, (-1, 1)))

        if all_samples:
            # One row per sample, each reshaped to the image layout.
            return X_flat.T.reshape((-1,) + tuple(self.image_shape))
        return X_flat[:, 0].reshape(self.image_shape)

    def transcode(self, X):
        """Encode ``X`` and decode the encoding back to image space.

        Returns:
            ``(encoded, decoded)``, where ``decoded`` is the reconstruction of
            the first sample only (see ``decode_image``).
        """
        encoded = self.encode_image(X)
        decoded = self.decode_image(encoded)
        return encoded, decoded

    def train_with_validation(self, X, Y, alpha, iterations, val_split=0.2,
                              random_state=42, print_every=100):
        """Train with full-batch gradient descent while tracking validation metrics.

        After every update step, loss and accuracy on both splits are computed
        with ``calculate_loss`` / ``calculate_accuracy`` and appended to the
        instance's history lists.

        Args:
            X: Inputs, either (n_features, n_samples) or (n_samples, n_features).
                X is transposed to the former when its first dimension is not
                ``input_size``.
            Y: Labels. They are flattened to 1-D if multi-dimensional.
            alpha: Step size passed to ``update_params``. This is unrelated to
                ``self.alpha`` given to the constructor.
            iterations: Number of gradient steps.
            val_split: Fraction of samples held out for validation. It is passed
                to ``train_test_split`` as ``test_size``.
            random_state: Seed for the train/validation split (default 42).
            print_every: Print progress every this many iterations. 0 or None
                disables printing.
        """
        X = self._to_feature_major(X)
        if Y.ndim > 1:
            Y = Y.ravel()

        # train_test_split splits along axis 0, so give it samples-major data
        # and transpose the pieces back to (n_features, n_samples).
        X_train, X_val, Y_train, Y_val = train_test_split(
            X.T, Y, test_size=val_split, random_state=random_state)
        X_train, X_val = X_train.T, X_val.T

        for i in range(iterations):
            # One full-batch gradient step.
            Z1, A1, Z2, A2 = self.forward_prop(X_train)
            dW1, db1, dW2, db2 = self.backward_prop(Z1, A1, Z2, A2, X_train, Y_train)
            self.update_params(dW1, db1, dW2, db2, alpha)

            # Metrics are measured after the update.
            train_loss = calculate_loss(self, X_train, Y_train)
            val_loss = calculate_loss(self, X_val, Y_val)
            train_accuracy = calculate_accuracy(self, X_train, Y_train)
            val_accuracy = calculate_accuracy(self, X_val, Y_val)

            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.train_accuracies.append(train_accuracy)
            self.val_accuracies.append(val_accuracy)

            if print_every and i % print_every == 0:
                print(f"Iteration {i}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}, "
                      f"Train Accuracy = {train_accuracy:.4f}, Val Accuracy = {val_accuracy:.4f}")

    def plot_learning_curves(self):
        """Plot the recorded training/validation loss and accuracy histories."""
        # Inside a method the bare name resolves to the module-level function
        # imported from net.modules, not to this method.
        plot_learning_curves(self.train_losses, self.val_losses,
                             self.train_accuracies, self.val_accuracies)

    def plot_encoded_space(self, X, Y):
        """Plot the encoded representation of ``X`` coloured by ``Y`` (delegates to net.modules)."""
        plot_encoded_space(self, self._to_feature_major(X), Y)

    def plot_reconstructions(self, X, num_images=5):
        """Plot reconstructions of up to ``num_images`` inputs (delegates to net.modules)."""
        plot_reconstructions(self, self._to_feature_major(X), num_images)