From 1f4e31fe70c2071d374dd27952f2a9bfb763ac07 Mon Sep 17 00:00:00 2001 From: Murtadha Date: Thu, 26 Sep 2024 20:01:37 -0400 Subject: [PATCH] Change early stopping mechanism to be based on val accuracy --- bel_NN_dynamic.ipynb | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/bel_NN_dynamic.ipynb b/bel_NN_dynamic.ipynb index 03041c0..6a2066a 100644 --- a/bel_NN_dynamic.ipynb +++ b/bel_NN_dynamic.ipynb @@ -156,10 +156,9 @@ "metadata": {}, "outputs": [], "source": [ - "def gradient_descent(X_train, Y_train, X_val, Y_val, layer_dims, alpha, iterations, early_stop_patience=10):\n", + "def gradient_descent(X_train, Y_train, X_val, Y_val, layer_dims, alpha, iterations, accuracy_threshold=0.85):\n", " params = init_params(layer_dims)\n", " best_val_accuracy = 0\n", - " patience_counter = 0\n", " acc_store = []\n", " \n", " for i in range(iterations):\n", @@ -181,12 +180,10 @@ " if val_accuracy > best_val_accuracy:\n", " best_val_accuracy = val_accuracy\n", " best_params = params.copy()\n", - " patience_counter = 0\n", - " else:\n", - " patience_counter += 1\n", - " \n", - " if patience_counter >= early_stop_patience:\n", - " print(\"Early stopping triggered.\")\n", + " \n", + " # Early stopping condition based on validation accuracy threshold\n", + " if val_accuracy >= accuracy_threshold:\n", + " print(f\"Validation accuracy threshold of {accuracy_threshold:.2f} reached. Stopping training.\")\n", " break\n", "\n", " return best_params, best_val_accuracy, acc_store" @@ -198,13 +195,13 @@ "metadata": {}, "outputs": [], "source": [ - "def grid_search(X_train, Y_train, X_val, Y_val, layer_configs, alpha, iterations):\n", + "def grid_search(X_train, Y_train, X_val, Y_val, layer_configs, alpha, iterations, accuracy_threshold=0.85):\n", " results = []\n", " \n", " for layer_config in layer_configs:\n", " layer_dims = [input_size] + list(layer_config) + [output_size]\n", " print(f\"Training architecture: {layer_dims}\")\n", - " best_params, accuracy, acc_store = gradient_descent(X_train, Y_train, X_val, Y_val, layer_dims, alpha, iterations)\n", + " best_params, accuracy, acc_store = gradient_descent(X_train, Y_train, X_val, Y_val, layer_dims, alpha, iterations, accuracy_threshold)\n", " results.append((layer_config, accuracy, best_params, acc_store))\n", " print(f\"Architecture {layer_dims}: Best Validation Accuracy: {accuracy:.4f}\\n\")\n", " \n",