77 lines
No EOL
2.2 KiB
Python
77 lines
No EOL
2.2 KiB
Python
import numpy as np
|
|
import pandas as pd
|
|
import matplotlib.pyplot as plt
|
|
|
|
def load_data(file_path, test_size=1000):
    """
    Load a labelled CSV and split it into train/test feature and label arrays.

    The first column of every row is treated as the class label (cast to int)
    and the remaining columns as features. The first ``test_size`` rows form
    the test split and all remaining rows form the training split; rows are
    NOT shuffled, so any ordering in the file carries into the splits.

    Feature arrays are returned transposed (feature-major): ``X_*`` has shape
    ``(n_features, n_samples)`` and ``Y_*`` has shape ``(n_samples,)``.

    Parameters:
        file_path: Path or file-like object accepted by ``pandas.read_csv``.
            The first line is consumed as the header row.
        test_size (int, optional): Number of leading rows held out as the test
            split. Defaults to 1000 (the previously hard-coded value).

    Returns:
        tuple: ``(X_train, Y_train, X_test, Y_test)``.

    Raises:
        ValueError: If ``test_size`` is negative, or is not smaller than the
            number of data rows (which would leave the training split empty).
    """
    data = pd.read_csv(file_path).to_numpy()
    num_rows = data.shape[0]

    # A negative size would slice from the end and silently produce nonsense
    # splits; a size >= num_rows would silently produce an empty training set.
    if not 0 <= test_size < num_rows:
        raise ValueError(
            f"test_size must be in [0, {num_rows}) for a file with "
            f"{num_rows} data rows, got {test_size}"
        )

    # Transpose so each column is one sample: row 0 = labels, rows 1.. = features.
    data_train = data[test_size:].T
    Y_train = data_train[0].astype(int)
    X_train = data_train[1:]

    data_test = data[:test_size].T
    Y_test = data_test[0].astype(int)
    X_test = data_test[1:]

    return X_train, Y_train, X_test, Y_test
|
|
|
|
def plot_accuracy(acc_store, save_path=None, *, log_interval=10):
    """
    Plot training and validation accuracy over iterations.

    Parameters:
        acc_store (iterable of pairs): Each entry contains
            (training_accuracy, validation_accuracy). Points are plotted
            against their position in this sequence.
        save_path (str, optional): If provided, saves the plot to the
            specified path before it is shown.
        log_interval (int, optional): Number of training iterations between
            consecutive entries of ``acc_store``. Used only in the x-axis
            label. Defaults to 10.

    Raises:
        ValueError: If ``acc_store`` is empty.
    """
    # Materialise once so generators work and so emptiness can be checked
    # without relying on truthiness (ambiguous for numpy arrays).
    pairs = list(acc_store)
    if not pairs:
        raise ValueError("acc_store is empty; nothing to plot")

    # Unzip the accuracy data
    training_accuracy, validation_accuracy = zip(*pairs)

    # Plot
    plt.figure(figsize=(10, 6))
    plt.plot(training_accuracy, label='Training Accuracy')
    plt.plot(validation_accuracy, label='Validation Accuracy')
    plt.title('Training and Validation Accuracy Over Iterations')
    plt.xlabel(f'Iterations (in steps of {log_interval})')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)

    # Save before show(): saving after a blocking show() can yield an empty image.
    if save_path:
        plt.savefig(save_path)
        print(f"Accuracy plot saved to {save_path}")

    # Show the plot
    plt.show()
|
|
|
|
|
|
def plot_loss(loss_store, save_path=None, *, log_interval=10):
    """
    Plot training and validation loss over iterations.

    Parameters:
        loss_store (iterable of pairs): Each entry contains
            (training_loss, validation_loss). Points are plotted against
            their position in this sequence.
        save_path (str, optional): If provided, saves the plot to the
            specified path before it is shown.
        log_interval (int, optional): Number of training iterations between
            consecutive entries of ``loss_store``. Used only in the x-axis
            label. Defaults to 10.

    Raises:
        ValueError: If ``loss_store`` is empty.
    """
    # Materialise once so generators work and so emptiness can be checked
    # without relying on truthiness (ambiguous for numpy arrays).
    pairs = list(loss_store)
    if not pairs:
        raise ValueError("loss_store is empty; nothing to plot")

    # Unzip the loss data
    training_loss, validation_loss = zip(*pairs)

    # Plot
    plt.figure(figsize=(10, 6))
    plt.plot(training_loss, label='Training Loss')
    plt.plot(validation_loss, label='Validation Loss')
    plt.title('Training and Validation Loss Over Iterations')
    plt.xlabel(f'Iterations (in steps of {log_interval})')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # Save before show(): saving after a blocking show() can yield an empty image.
    if save_path:
        plt.savefig(save_path)
        print(f"Loss plot saved to {save_path}")

    # Show the plot
    plt.show()