semantics/net/modules.py
2024-09-26 17:23:23 -04:00

77 lines
No EOL
2.2 KiB
Python

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
def load_data(file_path, n_test=1000):
    """
    Load a labelled CSV file and split it into training and test sets.
    The first column of each row is the label (cast to int) and the
    remaining columns are the features. The first ``n_test`` rows form the
    test set; every row after them forms the training set. Feature arrays
    are transposed so that each column is one sample.
    Parameters:
    file_path (str or file-like): Anything accepted by ``pandas.read_csv``
        (the first line is consumed as the header row).
    n_test (int, optional): Number of leading rows held out as the test set.
        Defaults to 1000, the value that used to be hard-coded.
    Returns:
    tuple: ``(X_train, Y_train, X_test, Y_test)`` where each ``X_*`` has
    shape ``(n_features, n_samples)`` and each ``Y_*`` has shape
    ``(n_samples,)`` with an integer dtype.
    Raises:
    ValueError: If ``n_test`` is negative.
    """
    if n_test < 0:
        raise ValueError(f"n_test must be non-negative, got {n_test}")
    data = pd.read_csv(file_path).to_numpy()
    # NOTE: if the file has no more than n_test data rows, the training split
    # is empty (shape (n_columns, 0)) rather than an error.
    train = data[n_test:].T
    test = data[:n_test].T
    # After transposing, row 0 holds the labels and rows 1.. hold the features.
    X_train, Y_train = train[1:], train[0].astype(int)
    X_test, Y_test = test[1:], test[0].astype(int)
    return X_train, Y_train, X_test, Y_test
def _plot_train_val(history, metric, save_path, eval_every):
    """
    Plot paired training/validation series for one metric, optionally save it, and show it.
    Parameters:
    history (iterable of tuples): Each tuple is (training_value, validation_value).
    metric (str): Metric name used in the labels, title and y-axis (e.g. 'Loss').
    save_path (str, optional): If truthy, the figure is saved to this path before showing.
    eval_every (int): Number of training iterations between consecutive entries;
        only used in the x-axis label.
    Raises:
    ValueError: If ``history`` is empty.
    """
    # Materialize first so generators work and emptiness can be checked
    # (truthiness of a numpy array would be ambiguous).
    pairs = list(history)
    if not pairs:
        raise ValueError(
            f"No {metric.lower()} history to plot: expected a non-empty "
            "sequence of (training, validation) pairs."
        )
    training, validation = zip(*pairs)
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(training, label=f"Training {metric}")
    ax.plot(validation, label=f"Validation {metric}")
    ax.set_title(f"Training and Validation {metric} Over Iterations")
    ax.set_xlabel(f"Iterations (in steps of {eval_every})")
    ax.set_ylabel(metric)
    ax.legend()
    ax.grid(True)
    # Save before show(): some backends clear the figure once its window is closed.
    if save_path:
        fig.savefig(save_path)
        print(f"{metric} plot saved to {save_path}")
    plt.show()
    # Under a non-interactive backend show() does not consume the figure, so
    # repeated calls would keep piling up open figures; release it explicitly.
    # In interactive mode the figure is left for the user/front-end to manage.
    if not plt.isinteractive():
        plt.close(fig)


def plot_accuracy(acc_store, save_path=None, *, eval_every=10):
    """
    Plot training and validation accuracy over iterations.
    Parameters:
    acc_store (list of tuples): Each tuple contains (training_accuracy, validation_accuracy).
    save_path (str, optional): If provided, saves the plot to the specified path.
    eval_every (int, optional): Iterations between recorded entries, shown in the
        x-axis label. Defaults to 10.
    Raises:
    ValueError: If ``acc_store`` is empty.
    """
    _plot_train_val(acc_store, "Accuracy", save_path, eval_every)


def plot_loss(loss_store, save_path=None, *, eval_every=10):
    """
    Plot training and validation loss over iterations.
    Parameters:
    loss_store (list of tuples): Each tuple contains (training_loss, validation_loss).
    save_path (str, optional): If provided, saves the plot to the specified path.
    eval_every (int, optional): Iterations between recorded entries, shown in the
        x-axis label. Defaults to 10.
    Raises:
    ValueError: If ``loss_store`` is empty.
    """
    _plot_train_val(loss_store, "Loss", save_path, eval_every)