From e6a95e852c8fd4761939993e58e6b35ca82adf3b Mon Sep 17 00:00:00 2001
From: Iago Bonnici <iago.bonnici@umontpellier.fr>
Date: Mon, 11 Mar 2024 16:25:23 +0100
Subject: [PATCH] Switch to AdamW optimizer, expose search parms.

---
 src/bin/aphid/main.rs     |  2 +-
 src/config.rs             |  2 +-
 src/config/check.rs       | 17 +++++++++++++++++
 src/config/deserialize.rs | 32 ++++++++++++++++++++++++++++++--
 src/model.rs              | 10 +++++++---
 5 files changed, 56 insertions(+), 7 deletions(-)

diff --git a/src/bin/aphid/main.rs b/src/bin/aphid/main.rs
index 0416b47..b786a5f 100644
--- a/src/bin/aphid/main.rs
+++ b/src/bin/aphid/main.rs
@@ -490,7 +490,7 @@ fn run() -> Result<(), AphidError> {
     display_tree_lnl_detail(&parms, lnl);
 
     eprintln!("\nOptimizing ln-likelihood:");
-    let (opt, opt_lnl) = optimize_likelihood(&triplets, &parms);
+    let (opt, opt_lnl) = optimize_likelihood(&triplets, &parms, &config.search);
 
     eprintln!("\nOptimized ln-likelihood:");
     display_tree_lnl_detail(&opt, opt_lnl);
diff --git a/src/config.rs b/src/config.rs
index 7f83674..09301a7 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -9,4 +9,4 @@
 mod check;
 mod deserialize;
 
-pub use check::{Config, Filters, SpeciesTriplet, Taxa};
+pub use check::{Config, Filters, SpeciesTriplet, Taxa, Search};
diff --git a/src/config/check.rs b/src/config/check.rs
index 0a27096..34e6682 100644
--- a/src/config/check.rs
+++ b/src/config/check.rs
@@ -49,6 +49,9 @@ pub struct Config {
     // For this reason, *enforce* that it be small.
     // The value is given in *mutation* units, so branch length × sequence length.
     pub unresolved_length: Option<MeanNbBases>,
+
+    // Parametrize the search.
+    pub search: Search,
 }
 
 const MAX_UNRESOLVED_LENGTH: f64 = 0.5;
@@ -100,6 +103,11 @@ pub struct SpeciesTriplet {
     pub c: SpeciesSymbol,
 }
+pub struct Search {
+    pub(crate) learning_rate: f64,
+    pub(crate) max_iterations: u64,
+}
+
 macro_rules! err {
     ($($message:tt)*) => {{
         return Err(Error::InputConsistency(format!$($message)*));
     }}
@@ -130,6 +138,11 @@ impl Config {
             }
         }
 
+        let &lr = &raw.search.learning_rate;
+        if lr <= 0.0 {
+            err!(("Learning rate needs to be positive, not {lr}."));
+        }
+
         // Most checks implemented within `TryFrom` trait.
         Ok(Config {
             taxa: raw.taxa._try_into(interner)?,
@@ -141,6 +154,10 @@ impl Config {
             } else {
                 DEFAULT_FILTERS
             },
+            search: Search {
+                learning_rate: lr,
+                max_iterations: raw.search.max_iterations,
+            },
         })
     }
 
diff --git a/src/config/deserialize.rs b/src/config/deserialize.rs
index 43dab1e..c535adf 100644
--- a/src/config/deserialize.rs
+++ b/src/config/deserialize.rs
@@ -21,11 +21,13 @@ pub(super) type RawGeneFlowTimes = Vec<f64>;
 pub(super) struct RawConfig {
     pub(super) filters: Option<RawFilters>,
     pub(super) taxa: RawTaxa,
-    #[serde(rename = "init", default = "default_initial_parameters")]
-    pub(super) initial_parameters: InitialParameters,
     #[serde(default = "default_gf_times")]
     pub(super) gf_times: RawGeneFlowTimes,
     pub(super) unresolved_length: Option<f64>,
+    #[serde(rename = "init", default = "default_initial_parameters")]
+    pub(super) initial_parameters: InitialParameters,
+    #[serde(default = "default_search")]
+    pub(super) search: RawSearch,
 }
 
 impl RawConfig {
@@ -132,6 +134,32 @@ fn default_gf_times() -> Vec<f64> {
     vec![1.]
 }
 
+//--------------------------------------------------------------------------------------------------
+// Search configuration.
+#[derive(Deserialize)]
+#[serde(deny_unknown_fields)]
+pub(super) struct RawSearch {
+    #[serde(default = "default_learning_rate")]
+    pub(super) learning_rate: f64,
+    #[serde(default = "default_max_iter")]
+    pub(super) max_iterations: u64,
+}
+
+fn default_search() -> RawSearch {
+    RawSearch {
+        learning_rate: default_learning_rate(),
+        max_iterations: default_max_iter(),
+    }
+}
+
+fn default_learning_rate() -> f64 {
+    1e-1
+}
+
+fn default_max_iter() -> u64 {
+    1_000
+}
+
 //--------------------------------------------------------------------------------------------------
 // Parsing utils.
 
diff --git a/src/model.rs b/src/model.rs
index 424a984..9874d75 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -2,7 +2,7 @@
 
 use std::io::{self, Write};
 
-use crate::GeneTriplet;
+use crate::{config::Search, GeneTriplet};
 
 mod likelihood;
 pub(crate) mod parameters;
@@ -37,6 +37,7 @@ pub fn optimize_likelihood(
     // Receive in this format for better locality.
     triplets: &[GeneTriplet],
     start: &Parameters<f64>,
+    search: &Search,
 ) -> (Parameters<f64>, f64) {
     // Choose in-memory data location.
     let (kind, device) = (Kind::Double, Device::Cpu);
@@ -48,8 +49,11 @@ pub fn optimize_likelihood(
     let input = Tensor::zeros([], (kind, device));
 
     // Pick an optimizer and associated learning rate.
-    let mut opt = nn::Sgd::default().build(&vars, 1e-3).unwrap();
-    let n_steps = 1_000;
+    let mut opt = nn::AdamW::default()
+        .amsgrad(true)
+        .build(&vars, search.learning_rate)
+        .unwrap();
+    let n_steps = search.max_iterations;
 
     // Optimize.
     let display_step = |i, lnl: &Tensor| {
-- 
GitLab
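
Usage sketch (illustrative, not part of the patch): with this change the search
settings become an optional `search` section of the aphid configuration file,
next to `taxa`, `filters`, `gf_times`, `unresolved_length` and `init`. The
snippet below assumes the configuration is written in TOML, which the patch does
not show; the key names come from RawSearch above and the values are arbitrary
examples:

    # Hypothetical example; the concrete config format is an assumption here.
    # Omitted keys fall back to the defaults in deserialize.rs
    # (learning_rate = 1e-1, max_iterations = 1000).
    [search]
    learning_rate = 0.05     # must be strictly positive, enforced in config/check.rs
    max_iterations = 2000

Leaving the whole section out keeps existing configuration files valid, but the
optimization itself still changes: the likelihood search now uses AdamW with
amsgrad instead of plain SGD, and the default learning rate moves from 1e-3 to
1e-1.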