Change early stopping mechanism to be based on val accuracy
This commit is contained in:
parent
99640d88e8
commit
1f4e31fe70
1 changed files with 7 additions and 10 deletions
|
|
@ -156,10 +156,9 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def gradient_descent(X_train, Y_train, X_val, Y_val, layer_dims, alpha, iterations, early_stop_patience=10):\n",
|
"def gradient_descent(X_train, Y_train, X_val, Y_val, layer_dims, alpha, iterations, accuracy_threshold=0.85):\n",
|
||||||
" params = init_params(layer_dims)\n",
|
" params = init_params(layer_dims)\n",
|
||||||
" best_val_accuracy = 0\n",
|
" best_val_accuracy = 0\n",
|
||||||
" patience_counter = 0\n",
|
|
||||||
" acc_store = []\n",
|
" acc_store = []\n",
|
||||||
" \n",
|
" \n",
|
||||||
" for i in range(iterations):\n",
|
" for i in range(iterations):\n",
|
||||||
|
|
@ -181,12 +180,10 @@
|
||||||
" if val_accuracy > best_val_accuracy:\n",
|
" if val_accuracy > best_val_accuracy:\n",
|
||||||
" best_val_accuracy = val_accuracy\n",
|
" best_val_accuracy = val_accuracy\n",
|
||||||
" best_params = params.copy()\n",
|
" best_params = params.copy()\n",
|
||||||
" patience_counter = 0\n",
|
" \n",
|
||||||
" else:\n",
|
" # Early stopping condition based on validation accuracy threshold\n",
|
||||||
" patience_counter += 1\n",
|
" if val_accuracy >= accuracy_threshold:\n",
|
||||||
" \n",
|
" print(f\"Validation accuracy threshold of {accuracy_threshold:.2f} reached. Stopping training.\")\n",
|
||||||
" if patience_counter >= early_stop_patience:\n",
|
|
||||||
" print(\"Early stopping triggered.\")\n",
|
|
||||||
" break\n",
|
" break\n",
|
||||||
"\n",
|
"\n",
|
||||||
" return best_params, best_val_accuracy, acc_store"
|
" return best_params, best_val_accuracy, acc_store"
|
||||||
|
|
@ -198,13 +195,13 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def grid_search(X_train, Y_train, X_val, Y_val, layer_configs, alpha, iterations):\n",
|
"def grid_search(X_train, Y_train, X_val, Y_val, layer_configs, alpha, iterations, accuracy_threshold=0.85):\n",
|
||||||
" results = []\n",
|
" results = []\n",
|
||||||
" \n",
|
" \n",
|
||||||
" for layer_config in layer_configs:\n",
|
" for layer_config in layer_configs:\n",
|
||||||
" layer_dims = [input_size] + list(layer_config) + [output_size]\n",
|
" layer_dims = [input_size] + list(layer_config) + [output_size]\n",
|
||||||
" print(f\"Training architecture: {layer_dims}\")\n",
|
" print(f\"Training architecture: {layer_dims}\")\n",
|
||||||
" best_params, accuracy, acc_store = gradient_descent(X_train, Y_train, X_val, Y_val, layer_dims, alpha, iterations)\n",
|
" best_params, accuracy, acc_store = gradient_descent(X_train, Y_train, X_val, Y_val, layer_dims, alpha, iterations, accuracy_threshold)\n",
|
||||||
" results.append((layer_config, accuracy, best_params, acc_store))\n",
|
" results.append((layer_config, accuracy, best_params, acc_store))\n",
|
||||||
" print(f\"Architecture {layer_dims}: Best Validation Accuracy: {accuracy:.4f}\\n\")\n",
|
" print(f\"Architecture {layer_dims}: Best Validation Accuracy: {accuracy:.4f}\\n\")\n",
|
||||||
" \n",
|
" \n",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue