Change early stopping mechanism to be based on val accuracy

This commit is contained in:
Murtadha 2024-09-26 20:01:37 -04:00
parent 99640d88e8
commit 1f4e31fe70

View file

@ -156,10 +156,9 @@
"metadata": {},
"outputs": [],
"source": [
"def gradient_descent(X_train, Y_train, X_val, Y_val, layer_dims, alpha, iterations, early_stop_patience=10):\n",
"def gradient_descent(X_train, Y_train, X_val, Y_val, layer_dims, alpha, iterations, accuracy_threshold=0.85):\n",
" params = init_params(layer_dims)\n",
" best_val_accuracy = 0\n",
" patience_counter = 0\n",
" acc_store = []\n",
" \n",
" for i in range(iterations):\n",
@ -181,12 +180,10 @@
" if val_accuracy > best_val_accuracy:\n",
" best_val_accuracy = val_accuracy\n",
" best_params = params.copy()\n",
" patience_counter = 0\n",
" else:\n",
" patience_counter += 1\n",
" \n",
" if patience_counter >= early_stop_patience:\n",
" print(\"Early stopping triggered.\")\n",
" # Early stopping condition based on validation accuracy threshold\n",
" if val_accuracy >= accuracy_threshold:\n",
" print(f\"Validation accuracy threshold of {accuracy_threshold:.2f} reached. Stopping training.\")\n",
" break\n",
"\n",
" return best_params, best_val_accuracy, acc_store"
@ -198,13 +195,13 @@
"metadata": {},
"outputs": [],
"source": [
"def grid_search(X_train, Y_train, X_val, Y_val, layer_configs, alpha, iterations):\n",
"def grid_search(X_train, Y_train, X_val, Y_val, layer_configs, alpha, iterations, accuracy_threshold=0.85):\n",
" results = []\n",
" \n",
" for layer_config in layer_configs:\n",
" layer_dims = [input_size] + list(layer_config) + [output_size]\n",
" print(f\"Training architecture: {layer_dims}\")\n",
" best_params, accuracy, acc_store = gradient_descent(X_train, Y_train, X_val, Y_val, layer_dims, alpha, iterations)\n",
" best_params, accuracy, acc_store = gradient_descent(X_train, Y_train, X_val, Y_val, layer_dims, alpha, iterations, accuracy_threshold)\n",
" results.append((layer_config, accuracy, best_params, acc_store))\n",
" print(f\"Architecture {layer_dims}: Best Validation Accuracy: {accuracy:.4f}\\n\")\n",
" \n",