From e6a95e852c8fd4761939993e58e6b35ca82adf3b Mon Sep 17 00:00:00 2001
From: Iago Bonnici <iago.bonnici@umontpellier.fr>
Date: Mon, 11 Mar 2024 16:25:23 +0100
Subject: [PATCH] Switch to AdamW optimizer, expose search params.

---
 src/bin/aphid/main.rs     |  2 +-
 src/config.rs             |  2 +-
 src/config/check.rs       | 17 +++++++++++++++++
 src/config/deserialize.rs | 32 ++++++++++++++++++++++++++++++--
 src/model.rs              | 10 +++++++---
 5 files changed, 56 insertions(+), 7 deletions(-)

diff --git a/src/bin/aphid/main.rs b/src/bin/aphid/main.rs
index 0416b47..b786a5f 100644
--- a/src/bin/aphid/main.rs
+++ b/src/bin/aphid/main.rs
@@ -490,7 +490,7 @@ fn run() -> Result<(), AphidError> {
     display_tree_lnl_detail(&parms, lnl);
 
     eprintln!("\nOptimizing ln-likelihood:");
-    let (opt, opt_lnl) = optimize_likelihood(&triplets, &parms);
+    let (opt, opt_lnl) = optimize_likelihood(&triplets, &parms, &config.search);
 
     eprintln!("\nOptimized ln-likelihood:");
     display_tree_lnl_detail(&opt, opt_lnl);
diff --git a/src/config.rs b/src/config.rs
index 7f83674..09301a7 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -9,4 +9,4 @@
 mod check;
 mod deserialize;
 
-pub use check::{Config, Filters, SpeciesTriplet, Taxa};
+pub use check::{Config, Filters, SpeciesTriplet, Taxa, Search};
diff --git a/src/config/check.rs b/src/config/check.rs
index 0a27096..34e6682 100644
--- a/src/config/check.rs
+++ b/src/config/check.rs
@@ -49,6 +49,9 @@ pub struct Config {
     // For this reason, *enforce* that it be small.
     // The value is given in *mutation* units, so branch length × sequence length.
     pub unresolved_length: Option<MeanNbBases>,
+
+    // Parametrize the search.
+    pub search: Search,
 }
 const MAX_UNRESOLVED_LENGTH: f64 = 0.5;
 
@@ -100,6 +103,11 @@ pub struct SpeciesTriplet {
     pub c: SpeciesSymbol,
 }
 
+pub struct Search {
+    pub(crate) learning_rate: f64,
+    pub(crate) max_iterations: u64,
+}
+
 macro_rules! err {
     ($($message:tt)*) => {{
         return Err(Error::InputConsistency(format!$($message)*));
@@ -130,6 +138,11 @@ impl Config {
             }
         }
 
+        let &lr = &raw.search.learning_rate;
+        if lr <= 0.0 {
+            err!(("Learning rate needs to be positive, not {lr}."));
+        }
+
         // Most checks implemented within `TryFrom` trait.
         Ok(Config {
             taxa: raw.taxa._try_into(interner)?,
@@ -141,6 +154,10 @@ impl Config {
             } else {
                 DEFAULT_FILTERS
             },
+            search: Search {
+                learning_rate: lr,
+                max_iterations: raw.search.max_iterations,
+            },
         })
     }
 
diff --git a/src/config/deserialize.rs b/src/config/deserialize.rs
index 43dab1e..c535adf 100644
--- a/src/config/deserialize.rs
+++ b/src/config/deserialize.rs
@@ -21,11 +21,13 @@ pub(super) type RawGeneFlowTimes = Vec<f64>;
 pub(super) struct RawConfig {
     pub(super) filters: Option<RawFilters>,
     pub(super) taxa: RawTaxa,
-    #[serde(rename = "init", default = "default_initial_parameters")]
-    pub(super) initial_parameters: InitialParameters,
     #[serde(default = "default_gf_times")]
     pub(super) gf_times: RawGeneFlowTimes,
     pub(super) unresolved_length: Option<f64>,
+    #[serde(rename = "init", default = "default_initial_parameters")]
+    pub(super) initial_parameters: InitialParameters,
+    #[serde(default = "default_search")]
+    pub(super) search: RawSearch,
 }
 
 impl RawConfig {
@@ -132,6 +134,32 @@ fn default_gf_times() -> Vec<f64> {
     vec![1.]
 }
 
+//--------------------------------------------------------------------------------------------------
+// Search configuration.
+#[derive(Deserialize)]
+#[serde(deny_unknown_fields)]
+pub(super) struct RawSearch {
+    #[serde(default = "default_learning_rate")]
+    pub(super) learning_rate: f64,
+    #[serde(default = "default_max_iter")]
+    pub(super) max_iterations: u64,
+}
+
+fn default_search() -> RawSearch {
+    RawSearch {
+        learning_rate: default_learning_rate(),
+        max_iterations: default_max_iter(),
+    }
+}
+
+fn default_learning_rate() -> f64 {
+    1e-1
+}
+
+fn default_max_iter() -> u64 {
+    1_000
+}
+
 //--------------------------------------------------------------------------------------------------
 // Parsing utils.
 
diff --git a/src/model.rs b/src/model.rs
index 424a984..9874d75 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -2,7 +2,7 @@
 
 use std::io::{self, Write};
 
-use crate::GeneTriplet;
+use crate::{config::Search, GeneTriplet};
 
 mod likelihood;
 pub(crate) mod parameters;
@@ -37,6 +37,7 @@ pub fn optimize_likelihood(
     // Receive in this format for better locality.
     triplets: &[GeneTriplet],
     start: &Parameters<f64>,
+    search: &Search,
 ) -> (Parameters<f64>, f64) {
     // Choose in-memory data location.
     let (kind, device) = (Kind::Double, Device::Cpu);
@@ -48,8 +49,11 @@ pub fn optimize_likelihood(
     let input = Tensor::zeros([], (kind, device));
 
     // Pick an optimizer and associated learning rate.
-    let mut opt = nn::Sgd::default().build(&vars, 1e-3).unwrap();
-    let n_steps = 1_000;
+    let mut opt = nn::AdamW::default()
+        .amsgrad(true)
+        .build(&vars, search.learning_rate)
+        .unwrap();
+    let n_steps = search.max_iterations;
 
     // Optimize.
     let display_step = |i, lnl: &Tensor| {
-- 
GitLab