Initial commit
This commit is contained in:
commit
1995df58ce
21 changed files with 6708 additions and 0 deletions
77
net/modules.py
Normal file
77
net/modules.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def load_data(file_path):
|
||||
data = pd.read_csv(file_path)
|
||||
data = np.array(data)
|
||||
m, n = data.shape
|
||||
|
||||
data_train = data[1000:m].T
|
||||
Y_train = data_train[0].astype(int)
|
||||
X_train = data_train[1:n]
|
||||
|
||||
data_test = data[0:1000].T
|
||||
Y_test = data_test[0].astype(int)
|
||||
X_test = data_test[1:n]
|
||||
|
||||
return X_train, Y_train, X_test, Y_test
|
||||
|
||||
def plot_accuracy(acc_store, save_path=None):
|
||||
"""
|
||||
Plot training and validation accuracy over iterations.
|
||||
|
||||
Parameters:
|
||||
acc_store (list of tuples): Each tuple contains (training_accuracy, validation_accuracy).
|
||||
save_path (str, optional): If provided, saves the plot to the specified path.
|
||||
"""
|
||||
# Unzip the accuracy data
|
||||
training_accuracy, validation_accuracy = zip(*acc_store)
|
||||
|
||||
# Plot
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(training_accuracy, label='Training Accuracy')
|
||||
plt.plot(validation_accuracy, label='Validation Accuracy')
|
||||
plt.title('Training and Validation Accuracy Over Iterations')
|
||||
plt.xlabel('Iterations (in steps of 10)')
|
||||
plt.ylabel('Accuracy')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
|
||||
# Save the plot if a path is provided
|
||||
if save_path:
|
||||
plt.savefig(save_path)
|
||||
print(f"Accuracy plot saved to {save_path}")
|
||||
|
||||
# Show the plot
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_loss(loss_store, save_path=None):
|
||||
"""
|
||||
Plot training and validation loss over iterations.
|
||||
|
||||
Parameters:
|
||||
loss_store (list of tuples): Each tuple contains (training_loss, validation_loss).
|
||||
save_path (str, optional): If provided, saves the plot to the specified path.
|
||||
"""
|
||||
# Unzip the loss data
|
||||
training_loss, validation_loss = zip(*loss_store)
|
||||
|
||||
# Plot
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(training_loss, label='Training Loss')
|
||||
plt.plot(validation_loss, label='Validation Loss')
|
||||
plt.title('Training and Validation Loss Over Iterations')
|
||||
plt.xlabel('Iterations (in steps of 10)')
|
||||
plt.ylabel('Loss')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
|
||||
# Save the plot if a path is provided
|
||||
if save_path:
|
||||
plt.savefig(save_path)
|
||||
print(f"Loss plot saved to {save_path}")
|
||||
|
||||
# Show the plot
|
||||
plt.show()
|
||||
Loading…
Add table
Add a link
Reference in a new issue