Change early stopping mechanism to be based on val accuracy

This commit is contained in:
Murtadha 2024-09-26 20:01:37 -04:00
parent 99640d88e8
commit 1f4e31fe70

View file

@ -156,10 +156,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"def gradient_descent(X_train, Y_train, X_val, Y_val, layer_dims, alpha, iterations, early_stop_patience=10):\n", "def gradient_descent(X_train, Y_train, X_val, Y_val, layer_dims, alpha, iterations, accuracy_threshold=0.85):\n",
" params = init_params(layer_dims)\n", " params = init_params(layer_dims)\n",
" best_val_accuracy = 0\n", " best_val_accuracy = 0\n",
" patience_counter = 0\n",
" acc_store = []\n", " acc_store = []\n",
" \n", " \n",
" for i in range(iterations):\n", " for i in range(iterations):\n",
@ -181,12 +180,10 @@
" if val_accuracy > best_val_accuracy:\n", " if val_accuracy > best_val_accuracy:\n",
" best_val_accuracy = val_accuracy\n", " best_val_accuracy = val_accuracy\n",
" best_params = params.copy()\n", " best_params = params.copy()\n",
" patience_counter = 0\n",
" else:\n",
" patience_counter += 1\n",
" \n", " \n",
" if patience_counter >= early_stop_patience:\n", " # Early stopping condition based on validation accuracy threshold\n",
" print(\"Early stopping triggered.\")\n", " if val_accuracy >= accuracy_threshold:\n",
" print(f\"Validation accuracy threshold of {accuracy_threshold:.2f} reached. Stopping training.\")\n",
" break\n", " break\n",
"\n", "\n",
" return best_params, best_val_accuracy, acc_store" " return best_params, best_val_accuracy, acc_store"
@ -198,13 +195,13 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"def grid_search(X_train, Y_train, X_val, Y_val, layer_configs, alpha, iterations):\n", "def grid_search(X_train, Y_train, X_val, Y_val, layer_configs, alpha, iterations, accuracy_threshold=0.85):\n",
" results = []\n", " results = []\n",
" \n", " \n",
" for layer_config in layer_configs:\n", " for layer_config in layer_configs:\n",
" layer_dims = [input_size] + list(layer_config) + [output_size]\n", " layer_dims = [input_size] + list(layer_config) + [output_size]\n",
" print(f\"Training architecture: {layer_dims}\")\n", " print(f\"Training architecture: {layer_dims}\")\n",
" best_params, accuracy, acc_store = gradient_descent(X_train, Y_train, X_val, Y_val, layer_dims, alpha, iterations)\n", " best_params, accuracy, acc_store = gradient_descent(X_train, Y_train, X_val, Y_val, layer_dims, alpha, iterations, accuracy_threshold)\n",
" results.append((layer_config, accuracy, best_params, acc_store))\n", " results.append((layer_config, accuracy, best_params, acc_store))\n",
" print(f\"Architecture {layer_dims}: Best Validation Accuracy: {accuracy:.4f}\\n\")\n", " print(f\"Architecture {layer_dims}: Best Validation Accuracy: {accuracy:.4f}\\n\")\n",
" \n", " \n",