Change early stopping mechanism to be based on val accuracy
This commit is contained in:
parent
99640d88e8
commit
1f4e31fe70
1 changed file with 7 additions and 10 deletions
|
|
@@ -156,10 +156,9 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def gradient_descent(X_train, Y_train, X_val, Y_val, layer_dims, alpha, iterations, early_stop_patience=10):\n",
|
||||
"def gradient_descent(X_train, Y_train, X_val, Y_val, layer_dims, alpha, iterations, accuracy_threshold=0.85):\n",
|
||||
" params = init_params(layer_dims)\n",
|
||||
" best_val_accuracy = 0\n",
|
||||
" patience_counter = 0\n",
|
||||
" acc_store = []\n",
|
||||
" \n",
|
||||
" for i in range(iterations):\n",
|
||||
|
|
@@ -181,12 +180,10 @@
|
|||
" if val_accuracy > best_val_accuracy:\n",
|
||||
" best_val_accuracy = val_accuracy\n",
|
||||
" best_params = params.copy()\n",
|
||||
" patience_counter = 0\n",
|
||||
" else:\n",
|
||||
" patience_counter += 1\n",
|
||||
" \n",
|
||||
" if patience_counter >= early_stop_patience:\n",
|
||||
" print(\"Early stopping triggered.\")\n",
|
||||
" \n",
|
||||
" # Early stopping condition based on validation accuracy threshold\n",
|
||||
" if val_accuracy >= accuracy_threshold:\n",
|
||||
" print(f\"Validation accuracy threshold of {accuracy_threshold:.2f} reached. Stopping training.\")\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
" return best_params, best_val_accuracy, acc_store"
|
||||
|
|
@@ -198,13 +195,13 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def grid_search(X_train, Y_train, X_val, Y_val, layer_configs, alpha, iterations):\n",
|
||||
"def grid_search(X_train, Y_train, X_val, Y_val, layer_configs, alpha, iterations, accuracy_threshold=0.85):\n",
|
||||
" results = []\n",
|
||||
" \n",
|
||||
" for layer_config in layer_configs:\n",
|
||||
" layer_dims = [input_size] + list(layer_config) + [output_size]\n",
|
||||
" print(f\"Training architecture: {layer_dims}\")\n",
|
||||
" best_params, accuracy, acc_store = gradient_descent(X_train, Y_train, X_val, Y_val, layer_dims, alpha, iterations)\n",
|
||||
" best_params, accuracy, acc_store = gradient_descent(X_train, Y_train, X_val, Y_val, layer_dims, alpha, iterations, accuracy_threshold)\n",
|
||||
" results.append((layer_config, accuracy, best_params, acc_store))\n",
|
||||
" print(f\"Architecture {layer_dims}: Best Validation Accuracy: {accuracy:.4f}\\n\")\n",
|
||||
" \n",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue