diff --git a/.rustfmt.toml b/.rustfmt.toml index 876a42ac475cdb1c8940c0ea85b60c20a287ac79..f5325cde975428545c0984e25e408c48a923334e 100644 --- a/.rustfmt.toml +++ b/.rustfmt.toml @@ -1,6 +1,6 @@ -single_line_if_else_max_width = 100 - unstable_features = true - +single_line_if_else_max_width = 100 imports_granularity = "Crate" group_imports = "StdExternalCrate" +merge_imports = true +struct_lit_width = 50 diff --git a/Cargo.toml b/Cargo.toml index c71be9acccc4955add6e9b15466925e2b8eba114..28e1e30f913a09680010a7e0d852972bff807ae2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,6 @@ itertools = "0.12.1" lexical-sort = "0.3.1" num-traits = "0.2.18" paste = "1.0.14" -rand = "0.8.5" regex = "1.10.3" serde = { version = "1.0.196", features = ["derive"] } serde_with = "3.6.1" @@ -27,6 +26,9 @@ tch = { git = "https://github.com/LaurentMazare/tch-rs", rev = "d068b18", versio toml = "0.8.10" unicode-width = "0.1.11" float_eq = { version = "1.0.1", features = ["derive"] } +snafu = "0.8.2" [dev-dependencies] -itertools = "0.12.1" +rand = "0.8.5" +rand_distr = "0.4.3" + diff --git a/plot_traces.py b/plot_traces.py new file mode 100644 index 0000000000000000000000000000000000000000..245ffad0be3e87c804d064cbb51f489e421cb319 --- /dev/null +++ b/plot_traces.py @@ -0,0 +1,243 @@ +"""Debug plotting of optimisation traces. +""" + +from matplotlib import colormaps +import matplotlib.pyplot as plt +import numpy as np +import polars as ps +import scipy as sc + +bfgs = "./target/bfgs.csv" +wolfe = "./target/bfgs_linsearch.csv" +grid = "./target/grid.csv" +# Theoretical min. +optx, opty = (0, -1) + +# Rough'n brutal 'interactive animation' procedure: +# dismiss any data after this step. +MAX_STEP = 4 +MAX_STEP = float("+inf") + +bfgs = ps.read_csv(bfgs) +wolfe = ps.read_csv(wolfe) +try: + grid = ps.read_csv( + grid, + has_header=False, + new_columns=["x", "y", "z"], + separator="\t", + ) +except FileNotFoundError: + grid = None + + +def group_colnames(cols: list[str]) -> list[str | list[str]]: + """Extract columns into clusters by name. + >>> group_colnames(["a", "b_1", "b_2", "c", "d_1", "e_1", "e_2"]) + ['a', ['b_1', 'b_2'], 'c', ['d_1'], ['e_1', 'e_2']] + """ + res = [] + previous_name = None + for c in cols: + try: + name, number = c.rsplit("_", 1) + number = int(number) + except: + name, number = c, None + if number is None: + res.append(c) + else: + if previous_name is None or previous_name != name: + res.append([c]) + else: + res[-1].append(c) + previous_name = name + return res + + +def group_columns(d: ps.DataFrame): + """Second step to extract values from columns.""" + res = [] + for c in group_colnames(d.columns): + if type(c) is list: + res.append([d[name] for name in c]) + else: + res.append(d[c]) + return res + + +bfgs = bfgs.filter(bfgs["step"] <= MAX_STEP) +wolfe = wolfe.filter(wolfe["id"] <= MAX_STEP) + +# --------------------------------------------------------------------------------------- +# BFGS plots: +# First panel is the loss and step size. +# Next panels are the non-scalar values, +# with a color for every optimized input variable. +step, loss, step_size, vars, grad, direction = group_columns(bfgs) + +fig, [ax, *vec_axes] = plt.subplots(4) +fig.subplots_adjust( + right=0.75, + left=0.05, + top=0.95, + bottom=0.01, +) + +# Basic loss and step size. 
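+# Loss (left axis) and step size (right, twinned axis) share the step x-axis;
+# both y scales are logarithmic.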
+loss_ax, size_ax = [ax, ax.twinx()] +loss_c, size_c, grad_c = ("blue", "red", "green") +loss_ax.plot(step, loss, c=loss_c, marker=".") +loss_ax.set_ylabel("Loss", color=loss_c) +size_ax.plot(step, step_size, c=size_c, marker=".") +size_ax.set_ylabel("Step size", color=size_c) +size_ax.set_yscale("log") +loss_ax.set_yscale("log") +loss_ax.set_title("BFGS") + +# Non-scalar plots. +cm = colormaps["viridis"] +palette = [cm(i) for i in np.linspace(0, 1, len(vars))] +for vec, ax, name in zip( + (vars, grad, direction), vec_axes, ("variables", "gradient", "direction") +): + ax.set_title(name) + ax.get_yaxis().set_visible(False) + for i, (y, col) in enumerate(zip(vec, palette)): + twin = ax.twinx() + pos = 1 + i / 20 + twin.spines["right"].set_position(("axes", pos)) + twin.get_yaxis().get_offset_text().set_position((pos, 1.1)) + twin.set_frame_on(True) + twin.patch.set_visible(False) + twin.set_ylabel(i, rotation=0) + twin.tick_params(axis="y", colors=col) + twin.plot(step, y, color=col, label=i, marker=".") + twin.set_yscale("symlog") + +bfgs_vars = vars # Useful in the next + +# --------------------------------------------------------------------------------------- +# If a grid is available, draw the 2D trajectory. +if grid is not None: + # Convert to image map. + n = round(np.sqrt(len(grid))) + x, y, z = grid["x"], grid["y"], grid["z"] + minx, miny, maxx, maxy = (f(i) for f in (min, max) for i in (x, y)) + assert max(abs(x.unique().to_numpy() - np.linspace(minx, maxx, n))) < 1e-14 + assert max(abs(y.unique().to_numpy() - np.linspace(miny, maxy, n))) < 1e-14 + x = x[0:-1:n] + y = y[0:n] + z = z.to_numpy().reshape((n, n)).transpose() + # Find minima within the matrix. + shifts = [ + (range(1, n), range(n - 1, n)), + (range(0, 0), range(n)), + (range(0, 1), range(0, n - 1)), + ] + cat = np.concatenate + minima = np.ones((n, n), bool) + for ia, ib in shifts: + a, b = z[:, ia], z[:, ib] + for jc, jd in shifts: + ac, ad = a[jc], a[jd] + bc, bd = b[jc], b[jd] + shifted = cat([cat([ac, ad]), cat([bc, bd])], 1) + minima &= z <= shifted + argminy, argminx = sc.signal.argrelmax(minima) + low_x, low_y = x[argminx], y[argminy] + az = int(np.argmin(z)) + globmin_x, globmin_y = x[az % n], y[az // n] + # Plot. + fig, ax = plt.subplots() + ax.pcolormesh(x, y, z, norm="symlog", vmin=0) + # Display minima. + ax.scatter(low_x, low_y, c="red", zorder=2, s=0.1) + ax.scatter(globmin_x, globmin_y, c="pink", zorder=2) + [ + f(v, color="black", alpha=0.5, zorder=2, lw=0.5) + for f, v in [(ax.axvline, optx), (ax.axhline, opty)] + ] + # Superimpose trajectory. + x, y = vars + ax.plot(x, y, c="black", linewidth=1) + ax.scatter(x, y, c=range(len(x)), zorder=2) + ax2d = ax # Remember to later add linear search. + +( + id, + alpha, + loss, + phigrad, + vars, + grad, +) = group_columns(wolfe) + +# --------------------------------------------------------------------------------------- +# Wolfe searches plots. +# Vertical bars delimit successive searches, +# major x ticks show the BFGS steps identifiers, +# minor x ticks show the search duration. +# Plot the loss progress during every search. +# Also plot the step size search, wrapped by its supposedly shrinking search range. 
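+# Each BFGS step id owns one Wolfe search: grouping the trace by "id" (in
+# insertion order) yields the number of samples in each search, and the
+# cumulative sum minus one gives the index of the last sample of every search,
+# used below to place the vertical separators and the major x ticks.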
+search_sizes = ps.DataFrame(id).group_by("id", maintain_order=True).len()["len"] +steps = search_sizes.cum_sum() - 1 +fig, (axa, axl, axg) = plt.subplots(3) +axa.axhline(1, color="black", lw=1) +for x in (axa, axg): + x.set_xticks(steps, labels=range(len(steps))) + size_ticks = steps - search_sizes / 2 + x.set_xticks(size_ticks, search_sizes, minor=True, color="gray") +[x.set_yscale("log") for x in (axa, axl, axg)] +axa.set_ylabel("Step size", c=size_c) +axl.plot(loss, color=loss_c) +axl.scatter(range(len(loss)), loss, color=loss_c, s=5) +axl.set_ylabel("Loss", color=loss_c) +axg.set_ylabel("Step derivative", c=grad_c) +prev = 0 +for i_step, s in enumerate(steps): + x = range(prev, s + 1) + axa.plot(x, alpha[x], color=size_c) + axa.scatter(x, alpha[x], color=size_c, s=5, zorder=10) + axg.plot(x, phigrad[x], color=grad_c) + axg.scatter(x, phigrad[x], color=grad_c, s=5) + [x.axvline(s, color="black", lw=1, alpha=0.5) for x in (axa, axl, axg)] + prev = x.stop + if grid is not None: + # Append linear search steps. + x, y = (i[x] for i in vars) + startx, starty = [v[i_step] for v in bfgs_vars] + x, y = (np.concatenate([[s], v]) for s, v in [(startx, x), (starty, y)]) + (ax2d, minx, miny, maxx, maxy) = (ax2d, minx, miny, maxx, maxy) # type: ignore + opacity = 0.2 if i_step < len(steps) - 1 else 1 + ax2d.plot(x, y, c="blue", linewidth=0.5, zorder=1, alpha=opacity) + ax2d.scatter( + x, + y, + c=range(len(x)), + cmap=colormaps["plasma"], + s=2, + zorder=3, + alpha=opacity, + ) + dx, dy = np.diff(x), np.diff(y) + ax2d.quiver( + x[:-1], + y[:-1], + 0.1 * dx, + 0.1 * dy, + range(len(x) - 1), + cmap=colormaps["Reds"], + alpha=opacity, + angles="xy", + scale_units="xy", + scale=1, + units="dots", + width=2, + zorder=2, + ) + ax2d.scatter(x[-1], y[-1], marker="x", c="red", alpha=opacity) + # [pseudo-animation]: Only extend the limits if the "current step" is off. + ax2d.set_xlim(min(minx, min(x)), max(maxx, max(x))) + ax2d.set_ylim(min(miny, min(y)), max(maxy, max(y))) +plt.show() diff --git a/src/bin/aphid/main.rs b/src/bin/aphid/main.rs index b786a5fd5e5a97f570cb29f2e125bd8eed967cd7..5603aefa5a392d97d79076156f3245ba80c358dc 100644 --- a/src/bin/aphid/main.rs +++ b/src/bin/aphid/main.rs @@ -8,7 +8,7 @@ use aphid::{ it_mean::SummedMean, ln_likelihood, optimize_likelihood, topological_analysis::{SectionAnalysis, SectionLca}, - Config, Error as AphidError, GeneTree, GeneTriplet, GenesForest, TopologicalAnalysis, VERSION, + Config, GeneTree, GeneTriplet, GenesForest, TopologicalAnalysis, VERSION, }; use arrayvec::ArrayVec; use clap::Parser; @@ -20,6 +20,7 @@ mod justify; use display_tree_analysis::display_topological_tree_analysis; use float_eq::float_eq; use justify::{justify, terminal_width}; +use snafu::{ensure, Snafu}; use crate::display_tree_analysis::display_geometrical_tree_analysis; @@ -40,14 +41,13 @@ fn main() { } Err(e) => { eprintln!("{} {e}", "🗙".bold().red()); - process::exit(e.code()); + process::exit(1); } } } #[allow(clippy::too_many_lines)] // TODO: factorize later: ongoing experiment. -fn run() -> Result<(), AphidError> { - use AphidError as E; +fn run() -> Result<(), Error> { // Parse command line arguments. let args = Args::parse(); @@ -81,16 +81,19 @@ fn run() -> Result<(), AphidError> { not_found.remove(&sp); } } - if !not_found.is_empty() { - return Err(E::InputConsistency(format!( - "Species {1:?} cannot be found in input gene trees. 
{0} mispelled?", - if not_found.len() == 1 { "Is it" } else { "Are they" }, - not_found - .iter() - .map(|&sp| ResolvedSymbol::new(sp, interner)) - .collect::<Vec<_>>(), - ))); - } + ensure!( + not_found.is_empty(), + InputConsistencyErr { + mess: format!( + "Species {1:?} cannot be found in input gene trees. {0} mispelled?", + if not_found.len() == 1 { "Is it" } else { "Are they" }, + not_found + .iter() + .map(|&sp| ResolvedSymbol::new(sp, interner)) + .collect::<Vec<_>>(), + ) + } + ); eprintln!("Echo input:"); eprintln!("{:#?}", config.for_display(interner)); @@ -461,10 +464,7 @@ fn run() -> Result<(), AphidError> { let &relative_mutation_rate = mutation_rates[i].as_ref().unwrap(); // Upgrade local triplets to extended versions, aware of the whole dataset. - let triplet = GeneTriplet { - local, - relative_mutation_rate, - }; + let triplet = GeneTriplet { local, relative_mutation_rate }; kept_ids.push(ids[i]); triplets.push(triplet); @@ -472,11 +472,12 @@ fn run() -> Result<(), AphidError> { } drop(passed_triplets); // Has been consumed into. - eprintln!("Starting point: {:#?}", config.init_parms); + eprintln!("Starting point: {:#?}", config.search.init_parms); - let parms = config.init_parms; + let parms = &config.search.init_parms; eprintln!("\nInitial ln-likelihood:"); - let display_tree_lnl_detail = |parms, lnl| { + #[cfg(test)] + let display_tree_lnl_detail = |parms| { for (&id, trip) in kept_ids.iter().zip(triplets.iter()) { eprintln!( " {}\t{}", @@ -484,16 +485,22 @@ fn run() -> Result<(), AphidError> { interner.resolve(id).unwrap(), ); } - eprintln!("total: {lnl}"); }; - let lnl = ln_likelihood(&triplets, &parms); - display_tree_lnl_detail(&parms, lnl); + let lnl = ln_likelihood(&triplets, parms); + #[cfg(test)] + display_tree_lnl_detail(parms); + eprintln!("total: {lnl}"); eprintln!("\nOptimizing ln-likelihood:"); - let (opt, opt_lnl) = optimize_likelihood(&triplets, &parms, &config.search); + let (opt, opt_lnl) = + optimize_likelihood(&triplets, parms, &config.search).unwrap_or_else(|e| { + panic!("TODO: once learning works, remove all error types in favour of snafu.\n{e}") + }); eprintln!("\nOptimized ln-likelihood:"); - display_tree_lnl_detail(&opt, opt_lnl); + #[cfg(test)] + display_tree_lnl_detail(&opt); + eprintln!("total: {opt_lnl}"); eprintln!("\nEnding point: {opt:#?}\n"); @@ -520,3 +527,16 @@ impl ToString for ColoredString { format!("{self}") } } + +#[derive(Debug, Snafu)] +#[snafu(context(suffix(Err)))] +enum Error { + #[snafu(transparent)] + Check { source: aphid::config::check::Error }, + #[snafu(transparent)] + ForestParse { + source: aphid::genes_forest::parse::Error, + }, + #[snafu(display("Inconsistent input:\n{mess}"))] + InputConsistency { mess: String }, +} diff --git a/src/config.rs b/src/config.rs index 09301a777dbebe743c46148fb8455c47a57dc03a..5caa52d96ba6d849947bd67b86c2fe7f93efc439 100644 --- a/src/config.rs +++ b/src/config.rs @@ -6,7 +6,108 @@ // Once parsed, it is checked for consistency // and transformed into the actual, useful config data. -mod check; +use std::path::PathBuf; + +use crate::{ + gene_tree::MeanNbBases, + interner::SpeciesSymbol, + optim::{bfgs::Config as BfgsConfig, gd::Config as GdConfig}, + BranchLength, Parameters, +}; + +pub mod check; +pub(crate) mod defaults; mod deserialize; -pub use check::{Config, Filters, SpeciesTriplet, Taxa, Search}; +// The final value handed out to user. +pub struct Config { + // Path to the gene trees data file. + pub trees: PathBuf, + + // Names of the species of interest. 
+ pub taxa: Taxa, + + // Parametrize the filters. + pub filters: Filters, + + // If set, when the internal triplet branch length "d" + // is less or equal to this value, + // consider that its topology is not resolved enough + // to exclude discordant scenarios. + // In this situation, + // every scenario contributes to the likelihood + // instead of only the concordant ones. + // The possible internal branch discordance + // between 'actual' and 'expected' length + // is neglected because the the actual length is small. + // For this reason, *enforce* that it be small. + // The value is given in *mutation* units, so branch length × sequence length. + pub unresolved_length: Option<MeanNbBases>, + + // Parametrize the search. + pub search: Search, +} +const MAX_UNRESOLVED_LENGTH: f64 = 0.5; + +#[derive(Debug)] +pub struct Filters { + // When raised, filter out trees if some species in 'other' + // branche between LCA(outgroup) and the root, + // because they may should have been considered 'outgroup' instead. + pub triplet_other_monophyly: bool, + + // Let `mt` be the mean branch length for gene tree triplet species, + // and `mr` ("rest") be the mean branch length for outgroup/other species. + // The "shape" of the gene tree is the ratio `q = mt/mr`. + // The "global shape" or "mean shape" of all gene trees is `Q = smt/smr`, + // where `smt` is the sum of all `mt` and `smr` of all `mr` + // for all trees *with accepted topology*. + // If set, this parameter reject trees whose shape differ from the mean shape + // by more than a factor, say `2`, or anything specified with this parameter. + // The motivation is that the hypothesis that mutation rate is constant + // accross the tree(s) is too audacious when, eg. `q / Q ⩾ max_clock_ratio`. + // Values need to be greater than 1 + // because they quantify "absolute" ratio or "imbalance": `max(q/Q, Q/q)`. + pub max_clock_ratio: Option<BranchLength>, +} + +// Contains only non-duplicated interned strings. +// Every section contains distinct species. +pub struct Taxa { + pub triplet: SpeciesTriplet, + pub(crate) outgroup: Vec<SpeciesSymbol>, + pub(crate) other: Vec<SpeciesSymbol>, +} + +// The triplet is importantly *oriented*: ((A, B), C). +// Accept several input forms provided they are unambiguous: +// triplet = [["U", "V"], "W"] # A: U, B: V, C: W +// triplet = ["U", ["V", "W"]] # A: V, B: W, C: U +// Or, shorter without special chars in species names: +// triplet = "((U, V), W)" +// triplet = "(U, (V, W))" +pub struct SpeciesTriplet { + pub a: SpeciesSymbol, + pub b: SpeciesSymbol, + pub c: SpeciesSymbol, +} + +#[derive(Debug)] +pub struct Search { + // Starting point in the search space. + pub init_parms: Parameters<f64>, + + // Dataset reduction factor + // when searching for non-degenerated starting point. + pub(crate) init_data_reduction_factor: f64, // Above 1. + + // Gradient search configuration + // when learning from a dataset reduction. + pub(crate) init_descent: GdConfig, + + // Main learning phase configuration. + pub(crate) bfgs: BfgsConfig, + + // Possible final gradient descent. + pub(crate) final_descent: Option<GdConfig>, +} diff --git a/src/config/check.rs b/src/config/check.rs index 34e668220e2edb4dbc4fed28e075e6d22d2d5f47..14e1abcb7b93ad03e7bd21c1665ef250b58d6a26 100644 --- a/src/config/check.rs +++ b/src/config/check.rs @@ -1,117 +1,37 @@ // Verify raw config consistency as we construct actual config from it. // Take this opportunity to intern all species names and replace them with their symbols. 
-use core::fmt; use std::{ cmp::Ordering, collections::HashSet, + fmt, path::{Path, PathBuf}, }; use arrayvec::ArrayVec; use regex::Regex; +use snafu::{ensure, ResultExt, Snafu}; -use super::deserialize::{ - InitialParameters, NodeInput, RawConfig, RawFilters, RawGeneFlowTimes, RawTaxa, RawTriplet, -}; use crate::{ - gene_tree::{BranchLength, MeanNbBases}, + config::{ + defaults, + deserialize::{self as raw, NodeInput}, + Search, SpeciesTriplet, Taxa, MAX_UNRESOLVED_LENGTH, + }, interner::{Interner, ResolvedSymbol, SpeciesSymbol}, - io::read_file, + io::{self, read_file}, model::parameters::{GeneFlowTimes, MAX_N_GF_TIMES}, - Error as E, Error, Parameters, -}; - -// The final value handed out to user. -pub struct Config { - // Path to the gene trees data file. - pub trees: PathBuf, - - // Names of the species of interest. - pub taxa: Taxa, - - // Starting point in the search space. - pub init_parms: Parameters<f64>, - - // Parametrize the filters. - pub filters: Filters, - - // If set, when the internal triplet branch length "d" - // is less or equal to this value, - // consider that its topology is not resolved enough - // to exclude discordant scenarios. - // In this situation, - // every scenario contributes to the likelihood - // instead of only the concordant ones. - // The possible internal branch discordance - // between 'actual' and 'expected' length - // is neglected because the the actual length is small. - // For this reason, *enforce* that it be small. - // The value is given in *mutation* units, so branch length × sequence length. - pub unresolved_length: Option<MeanNbBases>, - - // Parametrize the search. - pub search: Search, -} -const MAX_UNRESOLVED_LENGTH: f64 = 0.5; - -#[derive(Debug)] -pub struct Filters { - // When raised, filter out trees if some species in 'other' - // branche between LCA(outgroup) and the root, - // because they may should have been considered 'outgroup' instead. - pub triplet_other_monophyly: bool, - - // Let `mt` be the mean branch length for gene tree triplet species, - // and `mr` ("rest") be the mean branch length for outgroup/other species. - // The "shape" of the gene tree is the ratio `q = mt/mr`. - // The "global shape" or "mean shape" of all gene trees is `Q = smt/smr`, - // where `smt` is the sum of all `mt` and `smr` of all `mr` - // for all trees *with accepted topology*. - // If set, this parameter reject trees whose shape differ from the mean shape - // by more than a factor, say `2`, or anything specified with this parameter. - // The motivation is that the hypothesis that mutation rate is constant - // accross the tree(s) is too audacious when, eg. `q / Q ⩾ max_clock_ratio`. - // Values need to be greater than 1 - // because they quantify "absolute" ratio or "imbalance": `max(q/Q, Q/q)`. - pub max_clock_ratio: Option<BranchLength>, -} - -const DEFAULT_FILTERS: Filters = Filters { - max_clock_ratio: None, - triplet_other_monophyly: false, + optim::{ + self, bfgs::Config as BfgsConfig, gd::Config as GdConfig, + wolfe_search::Config as WolfeConfig, SlopeTrackingConfig, + }, + Config, Filters, Parameters, }; -// Contains only non-duplicated interned strings. -// Every section contains distinct species. -pub struct Taxa { - pub triplet: SpeciesTriplet, - pub(crate) outgroup: Vec<SpeciesSymbol>, - pub(crate) other: Vec<SpeciesSymbol>, -} - -// The triplet is importantly *oriented*: ((A, B), C). 
-// Accept several input forms provided they are unambiguous: -// triplet = [["U", "V"], "W"] # A: U, B: V, C: W -// triplet = ["U", ["V", "W"]] # A: V, B: W, C: U -// Or, shorter without special chars in species names: -// triplet = "((U, V), W)" -// triplet = "(U, (V, W))" -pub struct SpeciesTriplet { - pub a: SpeciesSymbol, - pub b: SpeciesSymbol, - pub c: SpeciesSymbol, -} - -pub struct Search { - pub(crate) learning_rate: f64, - pub(crate) max_iterations: u64, -} - macro_rules! err { - ($($message:tt)*) => {{ - return Err(Error::InputConsistency(format!$($message)*)); - }}; + ($fmt:tt) => { + ConfigErr { mess: format!$fmt } + }; } impl Config { @@ -119,55 +39,44 @@ impl Config { pub fn from_file(path: &Path, interner: &mut Interner) -> Result<Self, Error> { // Parse raw TOML file. let input = read_file(path)?; - let raw = RawConfig::parse(&input)?; + let raw = raw::Config::parse(&input)?; if let Some(ul) = raw.unresolved_length { - if ul < 0. { + ensure!( + ul >= 0., err!( ("The minimum internal branch length for tree considered resolved \ cannot be negative. Received 'unresolved_length = {ul}'.") ) - } - if ul > MAX_UNRESOLVED_LENGTH { + ); + ensure!( + ul <= MAX_UNRESOLVED_LENGTH, err!( ("The minimum internal branch length for tree considered resolved \ is assumed to be small, \ so the program forbids that it be superior to {MAX_UNRESOLVED_LENGTH}. \ Received 'unresolved_length = {ul}'.") ) - } - } - - let &lr = &raw.search.learning_rate; - if lr <= 0.0 { - err!(("Learning rate needs to be positive, not {lr}.")); + ); } // Most checks implemented within `TryFrom` trait. Ok(Config { + search: (&raw).try_into()?, taxa: raw.taxa._try_into(interner)?, trees: raw.taxa.trees, - init_parms: Parameters::try_from(&raw.initial_parameters, &raw.gf_times)?, unresolved_length: raw.unresolved_length, filters: if let Some(ref raw) = raw.filters { raw.try_into()? } else { - DEFAULT_FILTERS - }, - search: Search { - learning_rate: lr, - max_iterations: raw.search.max_iterations, + defaults::filters() }, }) } // The union of all sections constitutes the set of "relevant" species. pub fn relevant_species(&self) -> impl Iterator<Item = SpeciesSymbol> + '_ { - let Taxa { - triplet, - outgroup, - other, - } = &self.taxa; + let Taxa { triplet, outgroup, other } = &self.taxa; triplet .iter() .chain(outgroup.iter().copied()) @@ -177,23 +86,191 @@ impl Config { //-------------------------------------------------------------------------------------------------- // Check filter parameters. -impl TryFrom<&RawFilters> for Filters { +impl TryFrom<&raw::Filters> for Filters { type Error = Error; - fn try_from(raw: &RawFilters) -> Result<Self, Self::Error> { + fn try_from(raw: &raw::Filters) -> Result<Self, Self::Error> { if let Some(mcr) = raw.max_clock_ratio { - if mcr < 1. { + ensure!( + mcr >= 1., err!( ("The maximum branch length ratio between triplet and outgroup \ cannot be lower than 1. Received 'max_clock_ratio = {mcr}'.") ) - } + ); } Ok(Filters { max_clock_ratio: raw.max_clock_ratio, triplet_other_monophyly: raw .triplet_other_monophyly - .unwrap_or(DEFAULT_FILTERS.triplet_other_monophyly), + .unwrap_or(defaults::filters().triplet_other_monophyly), + }) + } +} + +//-------------------------------------------------------------------------------------------------- +// Check search configuration. 
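+// Raw (deserialized) search settings are converted into their checked
+// counterparts here, enforcing the bounds stated in the error messages:
+// the data reduction factor must exceed 1, the Wolfe constants must satisfy
+// 0 < c1 < c2 < 1, the step decrease/increase factors must shrink/grow as
+// expected, and any trace path must resolve to an existing folder.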
+ +impl TryFrom<&'_ raw::Config> for Search { + type Error = Error; + fn try_from(raw: &'_ raw::Config) -> Result<Self, Error> { + let init_data_reduction_factor = raw.search.init_data_reduction_factor; + if init_data_reduction_factor <= 1.0 { + err!( + ("The data reduction factor to search for non-degenerated starting point \ + must be greater than 1. Received: {init_data_reduction_factor}.") + ); + } + Ok(Self { + init_parms: Parameters::try_from(&raw.init, &raw.gf_times)?, + init_data_reduction_factor, + init_descent: (&raw.search.init_descent).try_into()?, + bfgs: (&raw.search.bfgs).try_into()?, + final_descent: raw + .search + .final_descent + .as_ref() + .map(TryInto::try_into) + .transpose()?, + }) + } +} + +impl TryFrom<&'_ raw::GdConfig> for GdConfig { + type Error = Error; + fn try_from(raw: &raw::GdConfig) -> Result<Self, Self::Error> { + let &raw::GdConfig { max_iter, learning_rate, ref slope_tracking } = raw; + Ok(Self { + max_iter, + step_size: learning_rate, + slope_tracking: slope_tracking + .as_ref() + .map(|sl| SlopeTrackingConfig::new(sl.sample_size, sl.threshold, sl.grain)) + .transpose()?, + }) + } +} + +impl TryFrom<&'_ raw::BfgsConfig> for BfgsConfig { + type Error = Error; + fn try_from(raw: &raw::BfgsConfig) -> Result<Self, Self::Error> { + let &raw::BfgsConfig { + max_iter, + ref wolfe_search, + step_size_threshold, + ref slope_tracking, + ref main_trace_path, + ref linsearch_trace_path, + } = raw; + ensure!( + 0. <= step_size_threshold, + err!( + ("The threshold for defining small search steps \ + must be null or positive. Received: {step_size_threshold}.") + ) + ); + let check_path = |path: &Option<PathBuf>| -> Result<_, Self::Error> { + let Some(path) = path else { + return Ok(None); + }; + let path = path + .canonicalize() + .with_context(|_| CanonicalizeErr { path })?; + if path.is_file() { + eprintln!("Will override {:?}.", path.to_string_lossy()); + }; + if let Some(parent) = path.parent() { + ensure!(parent.exists(), NoSuchFolderErr { path }); + } + Ok(Some(path)) + }; + Ok(Self { + max_iter, + slope_tracking: slope_tracking + .as_ref() + .map(|sl| SlopeTrackingConfig::new(sl.sample_size, sl.threshold, sl.grain)) + .transpose()?, + wolfe: wolfe_search.try_into()?, + small_step: step_size_threshold, + // TODO: expose. + main_trace_path: check_path(main_trace_path)?, + linsearch_trace_path: check_path(linsearch_trace_path)?, + }) + } +} + +impl TryFrom<&raw::WolfeSearchConfig> for WolfeConfig { + type Error = Error; + fn try_from(raw: &raw::WolfeSearchConfig) -> Result<Self, Self::Error> { + let &raw::WolfeSearchConfig { + c1, + c2, + init_step_size, + step_decrease_factor, + step_increase_factor, + flat_gradient, + bisection_threshold, + } = raw; + for (c, which, name) in [(c1, "upper", "c1"), (c2, "lower", "c2")] { + ensure!( + 0. < c && c < 1., + err!( + ("The Wolfe constant for checking {which} bound \ + must stand between 0 and 1 (excluded). \ + Received {name}={c1}.") + ) + ); + } + ensure!( + c1 < c2, + err!( + ("The Wolfe constant c1 and c2 must verify c1 < c2. \ + Received: c1 = {c1} ⩾ {c2} = c2.") + ) + ); + ensure!( + init_step_size > 0., + err!( + ("Initial step size for Wolfe search must be positive. \ + Received {init_step_size}.") + ) + ); + ensure!( + 0. < step_decrease_factor && step_decrease_factor < 1., + err!( + ("Decrease factor for Wolfe search must stand between 0 and 1 (excluded). \ + Received: {step_decrease_factor}.") + ) + ); + ensure!( + 1. < step_increase_factor, + err!( + ("Increase factor for Wolfe search must stand above 1. 
\ + Received: {step_increase_factor}.") + ) + ); + ensure!( + 0. <= flat_gradient, + err!( + ("Flat gradient threshold for Wolfe search must be positive or null. \ + Received: {flat_gradient}.") + ) + ); + ensure!( + (0. ..1.).contains(&bisection_threshold), + err!( + ("The zoom bisection threshold for Wolfe search must stand between 0 and 1/2. \ + Received: {bisection_threshold}.") + ) + ); + Ok(Self { + c1, + c2, + init_step: init_step_size, + step_decrease: step_decrease_factor, + step_increase: step_increase_factor, + flat_gradient, + bisection_threshold, }) } } @@ -202,9 +279,12 @@ impl TryFrom<&RawFilters> for Filters { // Check initial parameters. impl Parameters<f64> { - fn try_from(init: &InitialParameters, raw_gft: &RawGeneFlowTimes) -> Result<Self, Error> { + fn try_from( + init: &raw::InitialParameters, + raw_gft: &raw::GeneFlowTimes, + ) -> Result<Self, Error> { use Ordering as O; - let &InitialParameters { + let &raw::InitialParameters { theta, tau_1, tau_2, @@ -217,23 +297,24 @@ impl Parameters<f64> { macro_rules! check_positive { ($($name:ident),+) => {$({ let value = init.$name; - if value < 0. { + ensure!(value >= 0., err!(( "Model parameter '{}' must be positive, not {value}.", stringify!($name) )) - } + ) })+}; } macro_rules! check_prob { ($($name:ident),+) => {$({ let value = init.$name; - if !(0. ..=1.).contains(&value) { + ensure!((0. ..=1.).contains(&value), err!(( - "Model parameter '{}' must be a probability (between 0 and 1), not {value}.", + "Model parameter '{}' must be a probability (between 0 and 1), \ + not {value}.", stringify!($name) )) - } + ) })+}; } @@ -242,27 +323,30 @@ impl Parameters<f64> { match tau_2.total_cmp(&tau_1) { O::Less => { - err!( + return err!( ("The second coalescence time must be older than the first: \ - maybe tau_1 ({tau_1}) and tau_2 ({tau_2}) were meant to be reversed?") + maybe tau_1 ({tau_1}) and tau_2 ({tau_2}) were meant to be reversed?") ) + .fail() } O::Equal => { - err!( + return err!( ("The two coalescence times cannot be identical: \ - here tau_1 == tau_2 == {tau_1}.") + here tau_1 == tau_2 == {tau_1}.") ) + .fail() } O::Greater => {} }; let s = p_ab + p_bc + p_ac; - if 1. < s { + ensure!( + s <= 1.0, err!( ("The sum of gene flow transfer probabilities must not be larger than 1. \ p_ab + p_bc + p_ac = {p_ab} + {p_bc} + {p_ac} = {s} > 1.") ) - } + ); Ok(Parameters { theta, @@ -280,9 +364,9 @@ impl Parameters<f64> { //-------------------------------------------------------------------------------------------------- // Check gene flow times. -impl TryFrom<&RawGeneFlowTimes> for GeneFlowTimes<f64> { +impl TryFrom<&raw::GeneFlowTimes> for GeneFlowTimes<f64> { type Error = Error; - fn try_from(raw: &RawGeneFlowTimes) -> Result<Self, Self::Error> { + fn try_from(raw: &raw::GeneFlowTimes) -> Result<Self, Self::Error> { if raw.is_empty() { err!(("No gene flow time specified.")); } @@ -311,7 +395,7 @@ impl TryFrom<&RawGeneFlowTimes> for GeneFlowTimes<f64> { .0 .binary_search_by(|&t| t.total_cmp(&time).reverse()) { - Ok(_) => err!(("Gene flow time '{time}' specified twice.")), + Ok(_) => return err!(("Gene flow time '{time}' specified twice.")).fail(), Err(i) => gf_times.0.insert(i, time), } } @@ -320,29 +404,27 @@ impl TryFrom<&RawGeneFlowTimes> for GeneFlowTimes<f64> { } //-------------------------------------------------------------------------------------------------- -impl RawTaxa { +impl raw::Taxa { fn _try_into(&self, interner: &mut Interner) -> Result<Taxa, Error> { // Check taxa consistency. 
- let Self { - triplet: raw_triplet, - outgroup, - other, - .. - } = self; - let triplet: [&str; 3] = raw_triplet.try_into().map_err(E::InputConsistency)?; + let Self { triplet: raw_triplet, outgroup, other, .. } = self; + let triplet: [&str; 3] = raw_triplet + .try_into() + .map_err(|mess| Error::Config { mess })?; if outgroup.is_empty() { err!(("No outgroup species specified.")); } // No duplicates should be found among the taxa explicitly given. let mut relevant_species = HashSet::new(); - let mut new_species = move |name, section| { - if relevant_species.contains(&name) { + let mut new_species = move |name, section| -> Result<_, Error> { + ensure!( + !relevant_species.contains(&name), err!( ("Species {name:?} found in '{section}' section \ is given twice within the [taxa] table.") ) - } + ); let symbol = interner.get_or_intern(name); relevant_species.insert(name); // Also record into the interner. @@ -351,11 +433,7 @@ impl RawTaxa { let [a, b, c] = triplet.map(|s| new_species(s, "triplet")); Ok(Taxa { - triplet: SpeciesTriplet { - a: a?, - b: b?, - c: c?, - }, + triplet: SpeciesTriplet { a: a?, b: b?, c: c? }, outgroup: outgroup .iter() .map(|s| new_species(s, "outgroup")) @@ -371,15 +449,15 @@ impl RawTaxa { //-------------------------------------------------------------------------------------------------- // Check triplet and extract strings in canonical order. -impl<'t> TryFrom<&'t RawTriplet> for [&'t str; 3] { +impl<'t> TryFrom<&'t raw::Triplet> for [&'t str; 3] { type Error = String; // Work towards canonicalized ((a, b), c), // (x, y) refer to input nodes still undetermined. #[allow(clippy::many_single_char_names)] - fn try_from(input: &'t RawTriplet) -> Result<Self, Self::Error> { + fn try_from(input: &'t raw::Triplet) -> Result<Self, Self::Error> { + use raw::Triplet as T; use NodeInput as N; - use RawTriplet as T; macro_rules! err { ($($message:tt)+) => {{ return Err(format!$($message)+); @@ -468,7 +546,7 @@ impl SpeciesTriplet { } } -//-------------------------------------------------------------------------------------------------- +//================================================================================================== // Display. 
// We cannot simply derive(Debug) for the config @@ -494,20 +572,16 @@ impl<'i> fmt::Debug for ConfigDisplay<'i> { f.debug_struct("Config") .field("trees", &self.config.trees) .field("taxa", &self.config.taxa.for_display(self.interner)) - .field("init_parms", &self.config.init_parms) .field("filters", &self.config.filters) .field("unresolved_length", &self.config.unresolved_length) + .field("search", &self.config.search) .finish() } } impl<'i> fmt::Debug for TaxaDisplay<'i> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let Taxa { - triplet, - outgroup, - other, - } = self.taxa; + let Taxa { triplet, outgroup, other } = self.taxa; let resolve = |slice: &[_]| { slice .iter() @@ -536,43 +610,60 @@ impl<'i> fmt::Debug for TripletDisplay<'i> { impl Config { pub fn for_display<'d>(&'d self, interner: &'d Interner) -> ConfigDisplay<'d> { - ConfigDisplay { - config: self, - interner, - } + ConfigDisplay { config: self, interner } } } impl Taxa { fn for_display<'d>(&'d self, interner: &'d Interner) -> TaxaDisplay<'d> { - TaxaDisplay { - taxa: self, - interner, - } + TaxaDisplay { taxa: self, interner } } } impl SpeciesTriplet { fn for_display<'d>(&'d self, interner: &'d Interner) -> TripletDisplay<'d> { - TripletDisplay { - triplet: self, - interner, - } + TripletDisplay { triplet: self, interner } } } -//------------------------------------------------------------------------------------------------- +//================================================================================================== +// Errors. + +#[derive(Debug, Snafu)] +#[snafu(context(suffix(Err)))] +pub enum Error { + #[snafu(transparent)] + Io { source: io::Error }, + #[snafu(transparent)] + Parse { source: raw::Error }, + #[snafu(display("Configuration error: {mess}"))] + Config { mess: String }, + #[snafu(transparent)] + Optim { source: optim::Error }, + #[snafu(display("Could not canonicalize path: {:?}.", path.to_string_lossy()))] + Canonicalize { + source: std::io::Error, + path: PathBuf, + }, + #[snafu(display("Could not find folder: {:?}", path.to_string_lossy()))] + NoSuchFolder { path: PathBuf }, +} + +//================================================================================================== +// Test. + #[cfg(test)] mod tests { use serde::Deserialize; use super::*; + #[test] #[allow(clippy::too_many_lines)] fn triplet_parsing() { #[derive(Deserialize)] struct Single { - raw: RawTriplet, + raw: raw::Triplet, } let success = |input, expected| { println!("Checking success for: {input}."); @@ -707,7 +798,7 @@ mod tests { r#" |\n"#, r#"1 | raw = [["U", ["V", "W"]]]\n"#, r#" | ^^^^^^^^^^^^^^^^^^^\n"#, - r#"data did not match any variant of untagged enum RawTriplet\n"#, + r#"data did not match any variant of untagged enum Triplet\n"#, ) .replace(r"\n", "\n"), ); diff --git a/src/config/defaults.rs b/src/config/defaults.rs new file mode 100644 index 0000000000000000000000000000000000000000..dbe968142e8dc1dca4fc2f066dad8b09bd072bdd --- /dev/null +++ b/src/config/defaults.rs @@ -0,0 +1,87 @@ +// Decisions for anything not user-provided. + +use crate::{ + config::deserialize::{ + BfgsConfig, GdConfig, InitialParameters, Search, SlopeTrackingConfig, WolfeSearchConfig, + }, + Filters, +}; + +pub(crate) fn initial_parameters() -> InitialParameters { + InitialParameters { + theta: 0.005, + tau_1: 0.002, + tau_2: 0.005, + p_ab: 0.3, + p_ac: 0.3, + p_bc: 0.3, + p_ancient_gf: 0.7, + } +} + +pub(crate) fn gf_times() -> Vec<f64> { + vec![1.] 
+} + +pub(crate) fn filters() -> Filters { + Filters { + max_clock_ratio: None, + triplet_other_monophyly: false, + } +} + +pub(crate) fn search() -> Search { + Search { + init_data_reduction_factor: init_data_reduction_factor(), + init_descent: init_descent(), + bfgs: bfgs(), + final_descent: None, + } +} + +pub(crate) fn init_data_reduction_factor() -> f64 { + 2.0 +} + +pub(crate) fn init_descent() -> GdConfig { + GdConfig { + max_iter: 1_000, + learning_rate: 1e-2, + slope_tracking: Some(SlopeTrackingConfig { sample_size: 10, threshold: 1e-1, grain: 5 }), + } +} + +pub(crate) fn bfgs() -> BfgsConfig { + BfgsConfig { + max_iter: 1_000, + wolfe_search: wolfe_search(), + step_size_threshold: 1e-9, + slope_tracking: Some(SlopeTrackingConfig { sample_size: 20, threshold: 1e-3, grain: 5 }), + main_trace_path: None, + linsearch_trace_path: None, + } +} + +pub(crate) fn wolfe_search() -> WolfeSearchConfig { + // Taken from Nocedal and Wright 2006. + let c1 = 1e-4; + let c2 = 0.1; + WolfeSearchConfig { + c1, + c2, + init_step_size: 1.0, + step_decrease_factor: c2, + step_increase_factor: 10., // (custom addition) + flat_gradient: 1e-20, + bisection_threshold: 1e-1, + } +} + +#[allow(dead_code)] // Not the default anymore. +pub(crate) fn final_descent() -> GdConfig { + GdConfig { + max_iter: 1_000, + learning_rate: 1e-5, + slope_tracking: Some(SlopeTrackingConfig { sample_size: 100, threshold: 1e-3, grain: 50 }), + } +} diff --git a/src/config/deserialize.rs b/src/config/deserialize.rs index c535adf6d01a4808b567a7c118c65478fabb8d92..d5d51d17e3957f9f941c9c75f8a081055c710c73 100644 --- a/src/config/deserialize.rs +++ b/src/config/deserialize.rs @@ -8,33 +8,34 @@ use std::{array, fmt, path::PathBuf}; use serde::Deserialize; use serde_with::{serde_as, FromInto}; +use snafu::{ResultExt, Snafu}; -use crate::Error; +use crate::config::defaults; //-------------------------------------------------------------------------------------------------- // Global config structure and defaults. -pub(super) type RawGeneFlowTimes = Vec<f64>; +pub(crate) type GeneFlowTimes = Vec<f64>; #[derive(Deserialize)] #[serde(deny_unknown_fields)] -pub(super) struct RawConfig { - pub(super) filters: Option<RawFilters>, - pub(super) taxa: RawTaxa, - #[serde(default = "default_gf_times")] - pub(super) gf_times: RawGeneFlowTimes, - pub(super) unresolved_length: Option<f64>, - #[serde(rename = "init", default = "default_initial_parameters")] - pub(super) initial_parameters: InitialParameters, - #[serde(default = "default_search")] - pub(super) search: RawSearch, -} - -impl RawConfig { +pub(crate) struct Config { + pub(crate) filters: Option<Filters>, + pub(crate) taxa: Taxa, + #[serde(default = "defaults::gf_times")] + pub(crate) gf_times: GeneFlowTimes, + pub(crate) unresolved_length: Option<f64>, + #[serde(default = "defaults::initial_parameters")] + pub(crate) init: InitialParameters, + #[serde(default = "defaults::search")] + pub(crate) search: Search, +} + +impl Config { // Parse config file successfully // or abort the program with a useful error. - pub(super) fn parse(input: &str) -> Result<Self, Error> { - Ok(toml::from_str(input)?) 
+ pub(crate) fn parse(input: &str) -> Result<Self, Error> { + toml::from_str(input).context(DeserializeErr {}) } } @@ -44,9 +45,9 @@ impl RawConfig { #[serde_as] #[derive(Deserialize)] #[serde(deny_unknown_fields)] -pub(super) struct RawFilters { - pub(super) max_clock_ratio: Option<f64>, - pub(super) triplet_other_monophyly: Option<bool>, +pub(crate) struct Filters { + pub(crate) max_clock_ratio: Option<f64>, + pub(crate) triplet_other_monophyly: Option<bool>, } //-------------------------------------------------------------------------------------------------- @@ -55,37 +56,37 @@ pub(super) struct RawFilters { #[serde_as] #[derive(Deserialize)] #[serde(deny_unknown_fields)] -pub(super) struct RawTaxa { +pub(crate) struct Taxa { // Where to find the gene trees. - pub(super) trees: PathBuf, + pub(crate) trees: PathBuf, - pub(super) triplet: RawTriplet, + pub(crate) triplet: Triplet, #[serde_as(as = "FromInto<ListOfStrings>")] - pub(super) outgroup: Vec<String>, + pub(crate) outgroup: Vec<String>, #[serde_as(as = "FromInto<ListOfStrings>")] - pub(super) other: Vec<String>, + pub(crate) other: Vec<String>, } // Guard against invalid inputs. #[derive(Deserialize)] #[serde(untagged)] -pub(super) enum RawTriplet { +pub(crate) enum Triplet { Arrays(Vec<NodeInput>), Parenthesized(String), } #[derive(Deserialize)] #[serde(untagged)] -pub(super) enum NodeInput { +pub(crate) enum NodeInput { Leaf(String), Internal(Vec<String>), } -impl fmt::Debug for RawTriplet { +impl fmt::Debug for Triplet { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - use RawTriplet as T; + use Triplet as T; match self { T::Arrays(ars) => write!(f, "{ars:?}"), T::Parenthesized(s) => write!(f, "{s:?}"), @@ -108,59 +109,72 @@ impl fmt::Debug for NodeInput { #[derive(Deserialize)] #[serde(deny_unknown_fields)] -pub(super) struct InitialParameters { - pub(super) theta: f64, - pub(super) tau_1: f64, - pub(super) tau_2: f64, - pub(super) p_ab: f64, - pub(super) p_ac: f64, - pub(super) p_bc: f64, - pub(super) p_ancient_gf: f64, -} - -fn default_initial_parameters() -> InitialParameters { - InitialParameters { - theta: 0.005, - tau_1: 0.002, - tau_2: 0.005, - p_ab: 0.3, - p_ac: 0.3, - p_bc: 0.3, - p_ancient_gf: 0.7, - } -} - -fn default_gf_times() -> Vec<f64> { - vec![1.] +pub(crate) struct InitialParameters { + pub(crate) theta: f64, + pub(crate) tau_1: f64, + pub(crate) tau_2: f64, + pub(crate) p_ab: f64, + pub(crate) p_ac: f64, + pub(crate) p_bc: f64, + pub(crate) p_ancient_gf: f64, } //-------------------------------------------------------------------------------------------------- // Search configuration. 
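+// For illustration only, a `[search]` table along these lines (hypothetical
+// values, mirroring the defaults module) deserializes into the structures
+// below; omitted fields fall back to `defaults::*`:
+//
+//   [search]
+//   init_data_reduction_factor = 2.0
+//
+//   [search.bfgs]
+//   max_iter = 1000
+//   step_size_threshold = 1e-9
+//
+//   [search.bfgs.slope_tracking]
+//   sample_size = 20
+//   threshold = 1e-3
+//   grain = 5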
+ #[derive(Deserialize)] #[serde(deny_unknown_fields)] -pub(super) struct RawSearch { - #[serde(default = "default_learning_rate")] - pub(super) learning_rate: f64, - #[serde(default = "default_max_iter")] - pub(super) max_iterations: u64, +pub(crate) struct Search { + #[serde(default = "defaults::init_data_reduction_factor")] + pub(crate) init_data_reduction_factor: f64, + #[serde(default = "defaults::init_descent")] + pub(crate) init_descent: GdConfig, + #[serde(default = "defaults::bfgs")] + pub(crate) bfgs: BfgsConfig, + pub(crate) final_descent: Option<GdConfig>, } -fn default_search() -> RawSearch { - RawSearch { - learning_rate: default_learning_rate(), - max_iterations: default_max_iter(), - } +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +pub(crate) struct BfgsConfig { + pub(crate) max_iter: u64, + #[serde(default = "defaults::wolfe_search")] + pub(crate) wolfe_search: WolfeSearchConfig, + pub(crate) step_size_threshold: f64, + pub(crate) slope_tracking: Option<SlopeTrackingConfig>, + pub(crate) main_trace_path: Option<PathBuf>, + pub(crate) linsearch_trace_path: Option<PathBuf>, +} + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +pub(crate) struct WolfeSearchConfig { + pub(crate) c1: f64, + pub(crate) c2: f64, + pub(crate) init_step_size: f64, + pub(crate) step_decrease_factor: f64, + pub(crate) step_increase_factor: f64, + pub(crate) flat_gradient: f64, + pub(crate) bisection_threshold: f64, } -fn default_learning_rate() -> f64 { - 1e-1 +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +pub(crate) struct GdConfig { + pub(crate) max_iter: u64, + pub(crate) learning_rate: f64, + pub(crate) slope_tracking: Option<SlopeTrackingConfig>, } -fn default_max_iter() -> u64 { - 1_000 +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +pub(crate) struct SlopeTrackingConfig { + pub(crate) sample_size: usize, + pub(crate) threshold: f64, + pub(crate) grain: u64, } -//-------------------------------------------------------------------------------------------------- +//================================================================================================== // Parsing utils. // Lists of strings can either be specified as regular arrays like in: @@ -206,3 +220,38 @@ impl<const N: usize> TryFrom<ListOfStrings> for [String; N] { } } } + +//================================================================================================== +// Errors. + +#[derive(Debug, Snafu)] +#[snafu(context(suffix(Err)))] +pub enum Error { + #[snafu(display( + "Error while reading config file:\n{}", + complete_deserialize_error(source) + ))] + Deserialize { source: toml::de::Error }, +} + +// Best-attempt to improve over https://github.com/serde-rs/serde/issues/773. +fn complete_deserialize_error(e: &toml::de::Error) -> String { + let mut message = format!("{e}"); + let pattern = "data did not match any variant of untagged enum "; + if message.contains(pattern) { + let name = message.trim().rsplit(' ').next().unwrap(); + let accepted_types = match name { + // TODO: generate these with procedural macros? 
+ "ListOfStrings" => vec!["Array", "String"], + _ => panic!("Error in source code: enum {name} has not been handled."), + }; + message = message.replace( + &format!("{pattern}{name}"), + &format!( + "invalid type: expected one of [{}]", + accepted_types.join(", ") + ), + ); + }; + message +} diff --git a/src/errors.rs b/src/errors.rs deleted file mode 100644 index 6075e105dcfc48e02ced7ed9b9a5e23939a819e1..0000000000000000000000000000000000000000 --- a/src/errors.rs +++ /dev/null @@ -1,115 +0,0 @@ -// Coordinate toplevel failure modes. - -use std::{ - fmt::{self, Display}, - io::Error as IOError, - string::FromUtf8Error, -}; - -use colored::Colorize; -use toml::de::Error as TomlError; - -use crate::lexer::Error as LexerError; - -// Toplevel errors. -#[cfg_attr(test, derive(Debug))] -pub enum Error { - // Errors while reading the text in input files. - InputIo(IOError), - Utf8(FromUtf8Error), - // Error when parsing the main config file. - ConfigParsing(TomlError), - // Error when checking input consistency. - InputConsistency(String), - // Error while reading gene trees, located within the file. - GeneTree { line: usize, error: LexerError }, - // Fallback/ general purpose error. - Any(String), -} - -// Delegate display to contained errors. -impl Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let prefix = format!("{}:", self.prefix()); - write!(f, "{} {}", prefix.bold().red(), self.message()) - } -} - -impl Error { - // Pick an error prefix for each. - fn prefix(&self) -> &str { - use Error as E; - match self { - E::InputIo(_) => "Input error", - E::Utf8(_) => "Encoding error", - E::ConfigParsing(_) => "Config parsing error", - E::InputConsistency(_) => "Inconsistent input", - E::GeneTree { .. } => "Gene tree error", - E::Any { .. } => "Error", - } - } - - // Pick an error code for each. - pub fn code(&self) -> i32 { - use Error as E; - match self { - E::InputIo(_) => 1, - E::Utf8(_) => 2, - E::ConfigParsing(_) => 3, - E::InputConsistency(_) => 4, - E::GeneTree { .. } => 5, - E::Any { .. } => 6, - } - } - - // Extract owned message string from whatever underlying error. - fn message(&self) -> String { - use Error as E; - #[allow(clippy::match_same_arms)] - match self { - E::InputIo(e) => format!("{e}"), - E::Utf8(e) => format!("{e}"), - E::InputConsistency(e) => e.to_string(), - E::Any(e) => e.to_string(), - E::GeneTree { line, error: e } => format!("Line {line}: {e}"), - E::ConfigParsing(e) => { - // Best-attempt to improve over https://github.com/serde-rs/serde/issues/773. - let mut message = format!("{e}"); - let pattern = "data did not match any variant of untagged enum "; - if message.contains(pattern) { - let name = message.trim().rsplit(' ').next().unwrap(); - let accepted_types = match name { - // TODO: generate these with procedural macros? - "ListOfStrings" => vec!["Array", "String"], - _ => panic!("Error in source code: enum {name} has not been handled."), - }; - message = message.replace( - &format!("{pattern}{name}"), - &format!( - "invalid type: expected one of [{}]", - accepted_types.join(", ") - ), - ); - }; - message - } - } - } -} - -// Trivial conversions. -macro_rules! from { - ($($From:ident -> $To:ident)*) => {$( - impl From<$From> for Error { - fn from(e: $From) -> Self { - Error::$To(e) - } - } - )*}; -} - -from! 
{ - IOError -> InputIo - FromUtf8Error -> Utf8 - TomlError -> ConfigParsing -} diff --git a/src/gene_tree.rs b/src/gene_tree.rs index 1d20af7b228199069f191a95e498f6ffe4fb0296..58acfa95b429184ccf08236eae1890ee38dab760 100644 --- a/src/gene_tree.rs +++ b/src/gene_tree.rs @@ -88,11 +88,7 @@ impl GeneTree { tname }, ); - GeneTree { - tree: pruned, - index, - ..*self - } + GeneTree { tree: pruned, index, ..*self } } // Iterate over terminal nodes names diff --git a/src/gene_tree/parse.rs b/src/gene_tree/parse.rs index cb4f62fd07c9f6ebade139ac3c1e1a370fbd249c..a53f4672c4f529ba5672c06ef0a123126b2e0e2c 100644 --- a/src/gene_tree/parse.rs +++ b/src/gene_tree/parse.rs @@ -98,26 +98,13 @@ impl<'i> Lexer<'i> { let sequence_length = read .parse::<NbBases>() .map_err(|e| lexerr!("Invalid sequence length: {read:?}: {e}"))?; - Ok(GeneTree { - tree, - index: reader.index, - sequence_length, - }) - } -} - -impl GeneTree { - pub(crate) fn from_input(input: &str, interner: &mut Interner) -> Result<Self, LexerError> { - Lexer::new(input).read_gene_tree(interner) + Ok(GeneTree { tree, index: reader.index, sequence_length }) } } impl<'n> GeneTreeDataReader<'n> { pub fn new(interner: &'n mut Interner) -> Self { - Self { - index: HashMap::new(), - interner, - } + Self { index: HashMap::new(), interner } } } @@ -125,6 +112,12 @@ impl<'n> GeneTreeDataReader<'n> { mod tests { use super::*; + impl GeneTree { + pub(crate) fn from_input(input: &str, interner: &mut Interner) -> Result<Self, LexerError> { + Lexer::new(input).read_gene_tree(interner) + } + } + #[test] fn parse_lengths() { let mut interner = Interner::new(); @@ -190,7 +183,7 @@ mod tests { println!("Check parse failure of:\n {input}"); let input = format!("{input} 0"); // Append dummy sequence length. match GeneTree::from_input(&input, &mut interner) { - Err(e) => assert_eq!(message, e.ref_message()), + Err(e) => assert_eq!(message, format!("{e}")), Ok(_) => panic!("Unexpected successful parse."), } }; diff --git a/src/gene_tree/topological_analysis.rs b/src/gene_tree/topological_analysis.rs index 62643f26d5a295ea07324cee2e4e28be105a1b84..722764722d1bce2bec43451334e592988b9c6415 100644 --- a/src/gene_tree/topological_analysis.rs +++ b/src/gene_tree/topological_analysis.rs @@ -66,12 +66,8 @@ impl GeneTree { // Check that they coalesce at the root. let top = if let ( - Some(&SectionLca { - id: lca_triplet, .. - }), - Some(&SectionLca { - id: lca_outgroup, .. - }), + Some(&SectionLca { id: lca_triplet, .. }), + Some(&SectionLca { id: lca_outgroup, .. }), ) = (triplet.lca(), outgroup.lca()) { let (lca, is_triplet_first) = self.tree.common_parent(lca_triplet, lca_outgroup); @@ -95,26 +91,15 @@ impl GeneTree { for o in self.leaves_excluding(lca_outgroup, outgroup_side) { species.push(o); } - InternalSpecies { - species, - first_outgroup, - } + InternalSpecies { species, first_outgroup } }); - Some(TreeTop { - lca, - external, - internal, - }) + Some(TreeTop { lca, external, internal }) } else { None // No treetop analysis if triplet or outgroup is empty. }; // Analysis completed. 
- TopologicalAnalysis { - triplet, - outgroup, - top, - } + TopologicalAnalysis { triplet, outgroup, top } } fn section_analysis(&self, species: impl Iterator<Item = SpeciesSymbol>) -> SectionAnalysis { @@ -141,10 +126,7 @@ impl GeneTree { (!required.contains(&sp)).then_some(sp) }) .collect::<Vec<_>>(); - let lca = SectionLca { - id: lca, - paraphyletic, - }; + let lca = SectionLca { id: lca, paraphyletic }; if missing.is_empty() { S::AllFound(lca) } else { diff --git a/src/gene_tree/triplet.rs b/src/gene_tree/triplet.rs index 8565b4ceea2e3a4f07055bd09d100d72f3a1296b..b6f4026fbc5b0fc92e161c70e5cc1247daea2857 100644 --- a/src/gene_tree/triplet.rs +++ b/src/gene_tree/triplet.rs @@ -3,6 +3,7 @@ use std::fmt::{self, Display}; +use super::{MeanNbBases, NbBases}; use crate::{ config::SpeciesTriplet, gene_tree::nb_mutations, @@ -11,10 +12,8 @@ use crate::{ GeneTree, TopologicalAnalysis, }; -use super::{MeanNbBases, NbBases}; - // Matches the gene triplet topology onto the expected species topology. -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[allow(clippy::upper_case_acronyms)] pub enum Topology { ABC, // ((A,B),C) -> Concordant triplets. @@ -161,11 +160,7 @@ mod tests { let mut g = |c| interner.get_or_intern(c); let taxa = Taxa { - triplet: SpeciesTriplet { - a: g("T"), - b: g("U"), - c: g("V"), - }, + triplet: SpeciesTriplet { a: g("T"), b: g("U"), c: g("V") }, outgroup: vec![g("O"), g("P"), g("Q")], other: vec![g("X"), g("Y"), g("Z")], }; diff --git a/src/genes_forest.rs b/src/genes_forest.rs index 94c8afd706de484a6f5f3cee3359b7b9546310c9..0411afc203968d4299a2579087217ccdf049bab3 100644 --- a/src/genes_forest.rs +++ b/src/genes_forest.rs @@ -12,13 +12,13 @@ use crate::{interner::TreeSymbol, GeneTree}; // Move special methods implementations in dedicated modules. pub(crate) mod filter; -pub(crate) mod parse; +pub mod parse; // Bound to the input file lifetime. pub struct GenesForest { trees: Vec<GeneTree>, ids: Vec<TreeSymbol>, - index: HashMap<TreeSymbol, usize>, + _index: HashMap<TreeSymbol, usize>, } impl GenesForest { diff --git a/src/genes_forest/filter.rs b/src/genes_forest/filter.rs index d12e840da14cae096927224a1ce1ef62a782e20c..475c37d6072dca76768b8bce6d3dd1868b74e599 100644 --- a/src/genes_forest/filter.rs +++ b/src/genes_forest/filter.rs @@ -107,6 +107,7 @@ pub fn imbalance(a: f64, b: f64) -> (bool, f64) { #[cfg(test)] mod tests { + use float_eq::float_eq; use itertools::{EitherOrBoth, Itertools}; use crate::{ @@ -116,8 +117,6 @@ mod tests { GeneTree, TopologicalAnalysis, }; - use float_eq::float_eq; - #[test] #[allow(clippy::too_many_lines)] // It does take vertical space to test many trees. 
fn analysis_then_filter() { @@ -163,11 +162,7 @@ mod tests { // Test analysis on a few trees: let taxa = Taxa { - triplet: SpeciesTriplet { - a: g("T"), - b: g("U"), - c: g("V"), - }, + triplet: SpeciesTriplet { a: g("T"), b: g("U"), c: g("V") }, outgroup: vec![g("O"), g("P"), g("Q")], other: vec![g("X"), g("Y"), g("Z")], }; @@ -624,11 +619,7 @@ mod tests { let mut interner = Interner::new(); let mut g = |c| interner.get_or_intern(c); let taxa = Taxa { - triplet: SpeciesTriplet { - a: g("T"), - b: g("U"), - c: g("V"), - }, + triplet: SpeciesTriplet { a: g("T"), b: g("U"), c: g("V") }, outgroup: vec![g("O"), g("P"), g("Q")], other: vec![g("X"), g("Y"), g("Z")], }; diff --git a/src/genes_forest/parse.rs b/src/genes_forest/parse.rs index 4090db2584bdf4cd990b61aac34d261eb1c248a5..94171495cab6cdbf35c82b44457e19bfd66c29a6 100644 --- a/src/genes_forest/parse.rs +++ b/src/genes_forest/parse.rs @@ -2,12 +2,13 @@ use std::{collections::HashMap, path::Path}; +use snafu::{ResultExt, Snafu}; + use super::GenesForest; use crate::{ - errors::Error, interner::Interner, - io::read_file, - lexer::{lexerr, Error as LexerError, Lexer}, + io::{self, read_file}, + lexer::{Error as LexerError, Lexer}, }; impl GenesForest { @@ -23,12 +24,14 @@ impl GenesForest { // Count lines to contextualize potential errors. let mut l = 1; while !lexer.trim_start().consumed() { - let gtree = lexer.read_gene_tree(interner).map_err(with_line(l))?; - let name = lexer.trim_start_on_line().read_block().ok_or_else(|| { - with_line(l)(lexerr!( - "Unexpected end of line while reading gene tree name." - )) - })?; + let gtree = lexer.read_gene_tree(interner).context(LexErr { line: l })?; + let name = lexer + .trim_start_on_line() + .read_block() + .ok_or_else(|| Error::Parse { + line: l, + mess: "Unexpected end of line while reading gene tree name.".into(), + })?; lexer .trim_start_on_line() .strip_char(['\n']) @@ -42,32 +45,34 @@ impl GenesForest { trees.push(gtree); l += 1; } - Ok(GenesForest { trees, ids, index }) + Ok(GenesForest { trees, ids, _index: index }) } } -// Error conversions that require additional context. - -// For use with .map_err(). -fn with_line(line: usize) -> impl FnOnce(LexerError) -> Error { - move |error| Error::GeneTree { line, error } -} - -// For use from scratch. fn duplicated_tree_name(line: usize, name: &str) -> Error { - Error::GeneTree { + Error::Parse { line, - error: LexerError::new(format!( - "Tree name {name:?} already given earlier in the file." - )), + mess: format!("Tree name {name:?} already given earlier in the file."), } } + fn unexpected_data_on_line(line: usize, lexer: &mut Lexer) -> Error { - Error::GeneTree { + Error::Parse { line, - error: LexerError::new(format!( + mess: format!( "Unexpected trailing data: {}.", lexer.read_until(['\n']).unwrap().0 - )), + ), } } + +#[derive(Debug, Snafu)] +#[snafu(context(suffix(Err)))] +pub enum Error { + #[snafu(display("Lexing error while parsing gene forest, line {line}:\n{source}"))] + Lex { line: usize, source: LexerError }, + #[snafu(display("Error while parsing gene forest:\n{mess}"))] + Parse { line: usize, mess: String }, + #[snafu(transparent)] + Read { source: io::Error }, +} diff --git a/src/io.rs b/src/io.rs index 20e8286a20dc0c353311ac1912900ee4cd1e7fbe..955fe8b4d8feac9bd7710030ea6ccc59c1eb27e0 100644 --- a/src/io.rs +++ b/src/io.rs @@ -1,13 +1,32 @@ // Reading from / writing to files on disk. 
-use std::{fs::File, io::prelude::*, path::Path};
+use std::{
+    fs::File,
+    io::{self, Read},
+    path::{Path, PathBuf},
+    string::FromUtf8Error,
+};
 
-use crate::Error;
+use snafu::{ResultExt, Snafu};
 
 // Read into owned data, emitting dedicated crate errors on failure.
 pub(crate) fn read_file(path: &Path) -> Result<String, Error> {
-    let mut file = File::open(path)?;
+    let tobuf = || -> PathBuf { path.into() };
+    let mut file = File::open(path).context(IoErr { path: tobuf() })?;
     let mut buf = Vec::new();
-    file.read_to_end(&mut buf)?;
-    Ok(String::from_utf8(buf)?)
+    file.read_to_end(&mut buf)
+        .context(IoErr { path: tobuf() })?;
+    String::from_utf8(buf).context(Utf8Err { path: tobuf() })
+}
+
+#[derive(Debug, Snafu)]
+#[snafu(context(suffix(Err)))]
+pub enum Error {
+    #[snafu(display("IO error while reading file {}:\n{source}", path.display()))]
+    Io { path: PathBuf, source: io::Error },
+    #[snafu(display("UTF-8 encoding error in file {}:\n{source}", path.display()))]
+    Utf8 {
+        path: PathBuf,
+        source: FromUtf8Error,
+    },
 }
diff --git a/src/learn.rs b/src/learn.rs
new file mode 100644
index 0000000000000000000000000000000000000000..58c6dbfb283e50f164632ad00694a698301e1a4e
--- /dev/null
+++ b/src/learn.rs
@@ -0,0 +1,248 @@
+// Use homebrew gradient descent and BFGS search
+// to find score values yielding maximum log-likelihood.
+// Model log-likelihood as -F(X, P),
+// with 'X' the 'data', constant throughout the optimisation,
+// and 'P' the 'parameters' to optimize
+// and with respect to which we wish to differentiate F.
+
+use snafu::{ensure, Snafu};
+use tch::{Device, Tensor};
+
+use crate::{
+    config::Search,
+    model::{
+        likelihood::{data_tensors, ln_likelihood_tensors},
+        parameters::GeneFlowTimes,
+        scenarios::Scenario,
+        scores::Scores,
+    },
+    optim::{Error as OptimError, Optim, OptimResult, OptimTensor},
+    GeneTriplet, Parameters,
+};
+
+// Return both optimized inputs and output.
+#[allow(clippy::similar_names)] // p_ab/p_ac is okay.
+#[allow(clippy::too_many_lines)] // TODO: Fix once stabilized.
+pub fn optimize_likelihood(
+    // Receive in this format for better locality.
+    triplets: &[GeneTriplet],
+    start: &Parameters<f64>,
+    search: &Search,
+) -> Result<(Parameters<f64>, f64), Error> {
+    // Tensors are small and control flow depends a lot on their values,
+    // so round trips to a GPU are not worth it
+    // in terms of performance.
+    let device = Device::Cpu;
+    let n_triplets = triplets.len();
+
+    // Extract the parameter tensors 'P', tracked for the gradient.
+    // A few "fixed parameters", part of 'X', can be retrieved within the scores.
+    let (mut p, scores) = leaves_tensors(start, device);
+
+    // Extract all remaining data 'X' from the triplets.
+    let ngf = start.n_gf_times();
+    let scenarios = Scenario::iter(ngf).collect::<Vec<_>>();
+    let n_scenarios = scenarios.len();
+    let x = data_tensors(triplets, n_scenarios, &scenarios, device);
+
+    let mut n_eval = 0;
+    let mut n_diff = 0;
+
+    // Check starting point.
+    let first_lnl = ln_likelihood_tensors(&x, &scores(&p)).to_double();
+    if !first_lnl.is_finite() {
+        // When there is a lot of data,
+        // the starting point parameters 'P0' may lead
+        // to numerically non-finite likelihood values.
+        // If this happens, perform the optimization
+        // on a smaller batch of the data first until we get a finite likelihood.
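+        // For instance (hypothetical numbers, assuming `init_data_reduction_factor` is 10):
+        // try the first tenth of the triplets, then the first hundredth, etc.,
+        // until the log-likelihood becomes finite; run a cheap gradient descent there,
+        // then retry the full dataset from the parameters learned on that subset.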
+ let sc: Scores<f64> = (&scores(&p)).into(); + eprintln!( + "The chosen starting point yields non-finite log-likelihood ({first_lnl}) \ + on the whole dataset ({n_triplets} triplets):\n{sc}\n{start:?}", + ); + + // Numerical conversion shenanigans to divide dataset by a floating factor. + let reduce = |n: usize| { + let n: u64 = n as u64; + #[allow(clippy::cast_precision_loss)] + let n: f64 = n as f64; + let r = n / search.init_data_reduction_factor; + #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)] + let r: usize = r as usize; + r + }; + + // Loop to search a starting point that works for all data. + 'p: loop { + let mut n_samples = reduce(triplets.len()); + let (mut sub_triplets, mut sub_x); + // Loop to search a data subset that works for the current starting point. + 'x: loop { + ensure!( + n_samples > 0, + NonFiniteLikelihoodErr { lnl: first_lnl, parms: start.clone() } + ); + eprintln!( + "-- Try obtaining finite log-likelihood \ + with the {n_samples} first triplets." + ); + sub_triplets = &triplets[0..n_samples]; + sub_x = data_tensors(sub_triplets, n_scenarios, &scenarios, device); + let lnl = ln_likelihood_tensors(&sub_x, &scores(&p)).to_double(); + n_eval += 1; + if !lnl.is_finite() { + eprintln!("---- Failure: obtained log-likelihood: {lnl}."); + // Reduce the data again. + n_samples = reduce(n_samples); + continue 'x; + } + // With this working (sub_X, P), perform a learning pass + // to produce a (hopefully better) candidate P. + eprintln!( + "---- Success: obtained log-likelihood: {lnl}.\n\ + -- Learn on this subsample with simple gradient descent." + ); + let f = |p: &Tensor| -ln_likelihood_tensors(&sub_x, &scores(p)); + match search.init_descent.minimize(f, &p, 0) { + Err(e) => { + eprintln!("---- Learning failed:\n{e}"); + n_samples /= 2; + continue 'x; + } + Ok(opt) => { + n_eval += opt.n_eval(); + n_diff += opt.n_diff(); + p = opt.best_vars().copy(); + let sc: Scores<f64> = (&scores(&p)).into(); + let pm: Parameters<f64> = (&sc).into(); + eprintln!( + "---- Success: parameters learned on the subsample:\n{sc}\n{pm:?}" + ); + break 'x; + } + } + } + eprintln!("-- Try with this new starting point."); + let lnl = ln_likelihood_tensors(&x, &scores(&p)).to_double(); + n_eval += 1; + if !lnl.is_finite() { + eprintln!( + "---- Failure: obtained non-finite log-likelihood again ({lnl}). \ + Try subsampling again from the new starting point." + ); + continue 'p; + } + eprintln!( + "---- Success: obtained finite log-likelihood on the whole dataset ({lnl}).\n\ + -- Start learning from this new starting point." + ); + break 'p; + } + } + + // Learn over the whole dataset. + let f = |p: &Tensor| -ln_likelihood_tensors(&x, &scores(p)); + + println!("-- BFGS learning."); + let mut opt: Box<dyn OptimResult>; + opt = Box::new(search.bfgs.minimize(f, &p, 0)?); + let p = opt.best_vars(); + n_eval += opt.n_eval(); + n_diff += opt.n_diff(); + + // Then simple gradient descent to refine. + if let Some(gd) = &search.final_descent { + println!("-- Refine with simple gradient descent."); + opt = Box::new(gd.minimize(f, p, 0)?); + } + n_eval += opt.n_eval(); + n_diff += opt.n_diff(); + + // Final value. 
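+    // (Gradient tracking is re-enabled on the best point so that the single
+    // forward/backward pass below can report the gradient at the optimum.)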
+ let p = opt.best_vars().set_requires_grad(true); + let loss = f(&p); + loss.backward(); + let grad = p.grad(); + let scores = scores(&p); + let grad = scores.with_grads(&grad); + println!("-- Terminate heuristic after {n_eval} evaluations and {n_diff} differentiations."); + println!("-- Final gradient:\n{grad:#}"); + + let parms: Parameters<f64> = (&scores).into(); + let opt_lnl = -opt.best_loss(); + Ok((parms, opt_lnl)) +} + +// Construct all gradient-tracked parameters, +// gathered into a full 'P' tensor. +// Also construct a closure +// to extract them into a `Scores` value +// used as input to the likelihood formula. +// The Scores gather (aliased- = non-leaves-)gradient-tracked parameters +// and the untracked parameters, +// which are then conceptually part of the data 'X'. +fn leaves_tensors( + start: &Parameters<f64>, + dev: Device, +) -> (Tensor, impl Fn(&Tensor) -> Scores<Tensor>) { + // Calculate free-ranging scores from the desired starting point. + let s: Scores<_> = start.into(); + let Scores { + theta, + tau_1, + delta_tau, + gf, + ac, + bc, + ancient, + gf_times, + } = s; + + // TODO: make it configurable which ones are tracked. + let tracked = + Tensor::from_slice(&[theta, tau_1, delta_tau, gf, ac, bc, ancient]).to_device(dev); + // Scores `gf_times` are not tracked yet. + + ( + tracked, + // Specify how to extract typed, mixed '(P, X)' scores from the pure 'P' tensor. + move |p| { + let mut it = (0..).map(|i| p.get(i)); + let mut next = || it.next().unwrap(); + + Scores { + // /!\ Order here must match order within the tensor: match it! + theta: next(), + tau_1: next(), + delta_tau: next(), + gf: next(), + ac: next(), + bc: next(), + ancient: next(), + gf_times: GeneFlowTimes( + gf_times + .0 + .iter() + .map(|&t| Tensor::from(t).to_device(dev)) + .collect(), + ), + } + }, + ) +} + +//================================================================================================== +// Errors. + +#[derive(Debug, Snafu)] +#[snafu(context(suffix(Err)))] +pub enum Error { + #[snafu(display( + "The following initial parameters yielded non-finite log-likelihood ({lnl}) \ + even when (naively) subsampling the data:\n{parms:#?}" + ))] + NonFiniteLikelihood { lnl: f64, parms: Parameters<f64> }, + #[snafu(transparent)] + Optim { source: OptimError }, +} diff --git a/src/lexer.rs b/src/lexer.rs index a8bec8446ba1dff002faae4cb699f4c2cfb3a34f..bbeae18b833afeeb8ced18505c6db991ad23b199 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -4,17 +4,10 @@ // so that the "meaningful bits" can be anything needed, // even sophisticated structures and "Newick tree"s in particular. -use std::fmt::{self, Display}; - pub(crate) struct Lexer<'i> { input: &'i str, } -// Error out on parsing failure. -#[cfg_attr(test, derive(Debug))] -pub struct Error { - message: String, -} macro_rules! lexerr { ($($message:tt)*) => { crate::lexer::Error::new(format!($($message)*)) @@ -27,6 +20,7 @@ macro_rules! errout { } pub(crate) use errout; pub(crate) use lexerr; +use snafu::Snafu; // https://doc.rust-lang.org/reference/whitespace.html const WHITESPACE: [char; 11] = [ @@ -50,6 +44,7 @@ impl<'i> Lexer<'i> { } // Have a peek what's next to consume. + #[cfg(test)] pub(crate) fn rest(&self) -> &str { self.input } @@ -138,22 +133,16 @@ impl<'i> Lexer<'i> { } } -// Boilerplate. 
+#[derive(Debug, Snafu)] +#[snafu(display("{message}"))] +pub struct Error { + message: String, +} + impl Error { pub(crate) fn new(message: String) -> Self { Self { message } } - fn message(&self) -> String { - self.message.to_string() - } - pub(crate) fn ref_message(&self) -> &str { - &self.message - } -} -impl Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.message) - } } //================================================================================================== diff --git a/src/lib.rs b/src/lib.rs index e10cfca274b54147ec70ec10269460c4d9cb2f9d..2c6c5dabfffbd4a7bed5e00b53afaefbebb624e7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,16 +1,16 @@ pub mod config; -mod errors; mod gene_tree; -mod genes_forest; +pub mod genes_forest; pub mod interner; mod io; pub mod it_mean; +mod learn; mod lexer; mod model; +mod optim; mod tree; pub use config::{Config, Filters}; -pub use errors::Error; pub use gene_tree::{ extract_local_triplet, topological_analysis::{self, TopologicalAnalysis}, @@ -18,6 +18,7 @@ pub use gene_tree::{ BranchLength, GeneTree, }; pub use genes_forest::{filter::imbalance, GenesForest}; -pub use model::{ln_likelihood, optimize_likelihood, Parameters}; +pub use learn::optimize_likelihood; +pub use model::{likelihood::ln_likelihood, Parameters}; pub const VERSION: &str = env!("CARGO_PKG_VERSION"); diff --git a/src/model.rs b/src/model.rs index 9874d75f8ec4bf214ef4908b690a454f30b0655e..e34c0482ce34084d0386123724d8824be33afbfa 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,79 +1,8 @@ // The GLS/ILF model, associated parameters and likelihood calculations. -use std::io::{self, Write}; - -use crate::{config::Search, GeneTriplet}; - -mod likelihood; +pub(crate) mod likelihood; pub(crate) mod parameters; -mod scenarios; -mod scores; +pub(crate) mod scenarios; +pub(crate) mod scores; pub use parameters::Parameters; -use tch::{ - nn::{self, Module, OptimizerConfig}, - Device, Kind, Tensor, -}; - -use self::likelihood::likelihood_tensors; - -// Log-probability that all triplets in the data occur. -pub fn ln_likelihood( - // Receive in this format for better locality. - triplets: &[GeneTriplet], - p: &Parameters<f64>, -) -> f64 { - let mut ln_prob = 0.; - for gtrip in triplets { - ln_prob += gtrip.likelihood(p).ln(); - } - ln_prob -} - -// Different implementation with torch tensors. -// Return both optimized input and output. -#[allow(clippy::similar_names)] // p_ab/p_ac is okay. -pub fn optimize_likelihood( - // Receive in this format for better locality. - triplets: &[GeneTriplet], - start: &Parameters<f64>, - search: &Search, -) -> (Parameters<f64>, f64) { - // Choose in-memory data location. - let (kind, device) = (Kind::Double, Device::Cpu); - let vars = nn::VarStore::new(device); - - // Construct the likelihood function from fixed data and initial trainable parameters. - let (scores, ln_likelihood) = likelihood_tensors(&vars.root(), triplets, start, (kind, device)); - // Dummy input for this function: the data does not change during optimisation. - let input = Tensor::zeros([], (kind, device)); - - // Pick an optimizer and associated learning rate. - let mut opt = nn::AdamW::default() - .amsgrad(true) - .build(&vars, search.learning_rate) - .unwrap(); - let n_steps = search.max_iterations; - - // Optimize. 
- let display_step = |i, lnl: &Tensor| { - eprint!("step {i:03}: {}\r", lnl.double_value(&[])); - io::stdout().flush().unwrap(); - }; - for i in 0..n_steps { - opt.zero_grad(); - let lnl = ln_likelihood.forward(&input); - if i % 97 == 0 || i == n_steps { - display_step(i, &lnl); - } - opt.backward_step(&-lnl); - } - - // Final value. - let lnl = ln_likelihood.forward(&input); - display_step(n_steps, &lnl); - eprintln!(); - - let parms: Parameters<f64> = (&scores).into(); - (parms, lnl.double_value(&[])) -} diff --git a/src/model/likelihood.rs b/src/model/likelihood.rs index cafab336fcbe0f2c397e7282d83639c06f942713..1a352cb24a78e730b69261b7639d5d7ad52a69a2 100644 --- a/src/model/likelihood.rs +++ b/src/model/likelihood.rs @@ -1,22 +1,58 @@ // Calculate chances that the given tree happens given parameters values. +// Two implementations of likelihood are intertwined here: +// - One straighforward implementation with rust vectors and floats. +// - One tensor-based implementation with torch, +// which can be automatically differentiated. use std::iter; use itertools::izip; use paste::paste; use statrs::distribution::{Discrete, Poisson}; -use tch::{nn, Device, IndexOp, Kind, Tensor}; - -use super::{ - parameters::GeneFlowTimes, - scenarios::{GeneFlowScenario, GeneFlowTopology, Scenario, N_ILS_SCENARIOS}, - scores::Scores, - Parameters, +use tch::{Device, IndexOp, Tensor}; + +use crate::{ + gene_tree::{nb_mutations, triplet::GeneTriplet, TripletTopology}, + model::{ + scenarios::{GeneFlowScenario, GeneFlowTopology, Scenario, N_ILS_SCENARIOS}, + scores::Scores, + Parameters, + }, }; -use crate::gene_tree::{nb_mutations, triplet::GeneTriplet, TripletTopology}; -//-------------------------------------------------------------------------------------------------- +//================================================================================================== +// Basic vectors + floats implementation. + +// Entrypoint: log-probability that all triplets in the data occur. +#[allow(clippy::module_name_repetitions)] +pub fn ln_likelihood( + // Receive in this format for better locality. + triplets: &[GeneTriplet], + p: &Parameters<f64>, +) -> f64 { + let mut ln_prob = 0.; + for gtrip in triplets { + ln_prob += gtrip.likelihood(p).ln(); + } + ln_prob +} + +// Probability that the given gene triplet occurs, +// integrating over all possible scenarios. impl GeneTriplet { + pub fn likelihood(&self, p: &Parameters<f64>) -> f64 { + let mut prob = 0.; + for scenario in Scenario::iter(p.gf_times.0.len()) { + if self.contributes(&scenario) { + let prob_scenario = scenario.probability(p); + let expected_lengths = scenario.branches_lengths(p); + let prob_branches = self.branches_lengths_likelihood(expected_lengths); + prob += prob_scenario * prob_branches; + } + } + prob + } + // Decide whether contribution to the overall likelihood is relevant. // Only include concordant scenarios, // unless the triplet is not resolved enough. @@ -27,36 +63,38 @@ impl GeneTriplet { } //-------------------------------------------------------------------------------------------------- +// Expressions for probability of having ILS and of not having gene flow. + +// Use macro to ease abstraction over both floats and tensors. macro_rules! probabilities_ils_nongf { ($parameters:expr, $Ty:ty) => {{ - let Parameters { - theta, - tau_1, - tau_2, - p_ab, - p_ac, - p_bc, - .. - } = $parameters; + let Parameters { theta, tau_1, tau_2, p_ab, p_ac, p_bc, .. 
} = $parameters; let p_ils = ((2.0 * (tau_1 - tau_2) / theta) as $Ty).exp(); let p_ngf = 1. - p_ab - p_ac - p_bc; [p_ils, p_ngf] }}; } + +// (not actually used in this implementation) impl Parameters<f64> { - fn probabilities_ils_nongf(&self) -> [f64; 2] { + fn _probabilities_ils_nongf(&self) -> [f64; 2] { probabilities_ils_nongf!(self, f64) } } impl Parameters<Tensor> { - fn probabilities_ils_nongf(&self) -> [Tensor; 2] { + fn _probabilities_ils_nongf(&self) -> [Tensor; 2] { probabilities_ils_nongf!(self, Tensor) } } +//-------------------------------------------------------------------------------------------------- +// Prior probability that the scenario occurs, +// with respect to other possible scenarios. + macro_rules! scenario_probability { ($scenario:expr, $parameters:expr - $(, $Ty:ident)? // Weak hack to ease factorization in this very particular (f64, Tensor) abstraction. + // Weak hack to ease factorization in this very particular (f64, Tensor) abstraction. + $(, $Ty:ident)? ) => { paste! {{ let p = $parameters; @@ -84,8 +122,6 @@ macro_rules! scenario_probability { }; } -// Prior probability that the scenario occurs, -// with respect to other possible scenarios. impl Scenario { pub(crate) fn probability(&self, p: &Parameters<f64>) -> f64 { scenario_probability!(self, p) @@ -95,16 +131,13 @@ impl Scenario { } } +//-------------------------------------------------------------------------------------------------- +// Calculate gene flow probability depending on topology. + macro_rules! geneflow_probability { ($geneflow:expr, $parameters:expr) => {{ let Self { topology, time } = $geneflow; - let Parameters { - p_ab, - p_ac, - p_bc, - p_ancient, - .. - } = $parameters; + let Parameters { p_ab, p_ac, p_bc, p_ancient, .. } = $parameters; let ngf = $parameters.n_gf_times(); { use GeneFlowTopology as GT; @@ -122,6 +155,7 @@ macro_rules! geneflow_probability { } }}; } + impl GeneFlowScenario { pub(crate) fn probability(&self, p: &Parameters<f64>) -> f64 { geneflow_probability!(self, p) @@ -132,6 +166,8 @@ impl GeneFlowScenario { } //-------------------------------------------------------------------------------------------------- +// Calculate expected branches lengths. + // Expected branches lengths for the triplet, according to the given scenario: // Return only pair of short/long branches lengths [S, L]. // @@ -143,23 +179,14 @@ impl GeneFlowScenario { macro_rules! compact_branches_lengths { ($scenario:expr, $parameters:expr) => {{ use Scenario as S; - let Parameters { - theta, - tau_1, - tau_2, - gf_times, - .. - } = $parameters; + let Parameters { theta, tau_1, tau_2, gf_times, .. } = $parameters; let gf_times = |i| &gf_times.0[i]; match $scenario { S::NoEvent => [tau_1 + 0.5 * theta, tau_2 + 0.5 * theta], S::IncompleteLineageSorting(_) => { [tau_2 + (1. / 6.) * theta, tau_2 + (2. / 3.) * theta] } - &S::GeneFlow(GeneFlowScenario { - time: i, - ref topology, - }) => { + &S::GeneFlow(GeneFlowScenario { time: i, ref topology }) => { use GeneFlowTopology as G; let tau_g = tau_1 * gf_times(i); [ @@ -202,14 +229,14 @@ impl Scenario { pub(crate) fn branches_lengths_tensor(&self, parms: &Parameters<Tensor>) -> Tensor { let sl = compact_branches_lengths!(self, parms); let ([a, b, c], ref d) = expand_lengths!(self, sl); - Tensor::concatenate(&[a, b, c, d], 0) + Tensor::stack(&[a, b, c, d], 0) } } +//-------------------------------------------------------------------------------------------------- // Probability that the given gene triplet occurs // given the expected branches lengths. 
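+// Concretely, for every branch the observed number of mutations is compared
+// to a Poisson draw whose mean scales the expected branch length
+// by the triplet's relative mutation rate and sequence length
+// (as suggested by `alphaseq` in the tensor implementation below),
+// with log-density the usual ln P(n | λ) = n·ln(λ) − λ − ln(n!).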
impl GeneTriplet { - #[allow(clippy::many_single_char_names)] fn branches_lengths_likelihood(&self, expected_lengths: [f64; 4]) -> f64 { // (locally match the paper notations) let alpha_i = self.relative_mutation_rate; @@ -234,154 +261,30 @@ fn ln_poisson_density(n: u64, lambda: f64) -> f64 { Poisson::new(lambda).unwrap().ln_pmf(n) } -//-------------------------------------------------------------------------------------------------- -// Probability that the given gene triplet occurs, -// integrating over all possible scenarios. -impl GeneTriplet { - pub fn likelihood(&self, p: &Parameters<f64>) -> f64 { - let mut prob = 0.; - for scenario in Scenario::iter(p.gf_times.0.len()) { - if self.contributes(&scenario) { - let prob_scenario = scenario.probability(p); - let expected_lengths = scenario.branches_lengths(p); - let prob_branches = self.branches_lengths_likelihood(expected_lengths); - prob += prob_scenario * prob_branches; - } - } - prob - } -} - //================================================================================================== -// Alternate implementation as a torch module, for autograd + optimisation. +// Alternate implementation with torch tensors, +// useful for autograd + optimisation. +// Represent log-likelihood by lnl = F(X, P) +// with 'X' fixed data +// and 'P' user parameters to optimize. -pub(crate) fn likelihood_tensors( - p: &nn::Path, +// Construct all non(-potentially)-gradient-tracked tensors, +// referring to the immutable input data 'X', +// and used to calculate the eventual likelihood. +// Some of it, the untracked scores, +// is actually standing within the Scores struct. +pub(crate) fn data_tensors<'s>( triplets: &[GeneTriplet], - start: &Parameters<f64>, - (kind, device): (Kind, Device), -) -> (Scores<Tensor>, impl nn::Module) { + n_scenarios: usize, + scenarios: &'s [Scenario], + dev: Device, +) -> DataTensors<'s> { let n_trees = triplets.len(); - let ngf = start.n_gf_times(); - let scenarios = Scenario::iter(ngf).collect::<Vec<_>>(); - let n_scenarios = scenarios.len(); - let n_trees_i: i64 = n_trees .try_into() .unwrap_or_else(|_| panic!("Too many trees ({n_trees}) to be indexed with i64 by torch.")); - // Model log-likelihood as F(X, P) with X the 'data', constant throughout the optimisation, - // and 'P' the 'parameters' to optimize, and that we wish to derive F with respect to. - - // Prepare all parameters 'P' from the desired initial value. - let scores = leaves_tensors(p, start, device); - - // Extract all data 'X' from the triplets. - let DataTensors { - alphaseq_per_scenario, - n_per_scenario, - minus_ln_factorial_n_per_scenario, - columns_indices, - } = data_tensors(triplets, n_scenarios, &scenarios, device); - - // Calculate likelihood. - ( - scores.alias(), - nn::func(move |_| { - // Calculate parameters for the current scores values. - let parms: Parameters<Tensor> = (&scores).into(); - - let mut columns = Vec::with_capacity(n_scenarios); - - for (scenario, alphaseq, n, minus_ln_factorial_n) in izip!( - &scenarios, - &alphaseq_per_scenario, - &n_per_scenario, - &minus_ln_factorial_n_per_scenario, - ) { - // Scenario prior probability. - let prior = scenario.probability_tensors(&parms); - - // Expected branch lengths. - let expected_abcd = scenario.branches_lengths_tensor(&parms); - - // Log-densities of poisson distribution. 
- let lambda = Tensor::outer(alphaseq, &expected_abcd); - let ln_poisson = minus_ln_factorial_n + n * lambda.log() - lambda; - - // Sum over the branches and exponentiate - // to get the product of poisson distributions. - let poisson_product = ln_poisson.sum_dim_intlist(1, false, kind).exp(); - - // Down to the final contribution of every tree - // to the likelihood of this scenario. - let contributions = prior * poisson_product; - - columns.push(contributions); - } - - // Sum over all scenarios to get the total probability of every tree. - // This is a sum over the rows of the conceptual sparse matrix. - let columns = Tensor::cat(&columns, 0); - let trees_prob = - Tensor::zeros(n_trees_i, (kind, device)).scatter_add(0, &columns_indices, &columns); - - // And finally integrate this all into the total likelihood. - trees_prob.log().sum(None) - }), - ) -} - -// Construct all (potentially) gradient-tracked parameters, -// sources of the calculation graph towards the eventual likelihood. -fn leaves_tensors(p: &nn::Path<'_>, start: &Parameters<f64>, dev: Device) -> Scores<Tensor> { - let single = |v, track| { - Tensor::from_slice(&[v]) - .set_requires_grad(track) - .to_device(dev) - }; - let var = |v, name: &str, track| { - if track { - p.var_copy(name, &single(v, true)) - } else { - single(v, false) - } - }; - - // Calculate free-ranging scores from the desired starting point. - let s: Scores<_> = start.into(); - let scores = Scores::<Tensor> { - theta: var(s.theta, "theta", true), - tau_1: var(s.tau_1, "tau_1", true), - delta_tau: var(s.delta_tau, "delta_tau", true), - gf: var(s.gf, "gf", true), - ac: var(s.ac, "ac", true), - bc: var(s.bc, "bc", true), - ancient: var(s.ancient, "ancient", true), - gf_times: GeneFlowTimes( - s.gf_times - .0 - .iter() - .enumerate() - .map(|(i, &v)| var(v, &format!("gf_time_{i:02}"), false)) - .collect(), - ), - }; - scores -} - -// Construct all non(-potentially)-gradient-tracked tensors, -// referring to the immutable input data, -// and used to calculate the eventual likelihood. -fn data_tensors( - triplets: &[GeneTriplet], - n_scenarios: usize, - scenarios: &[Scenario], - dev: Device, -) -> DataTensors { - let mut n = Vec::new(); - // /!\ - // Generic name for branch lengths a/b/*d*/c. + let mut n = Vec::new(); // Generic name for branch lengths a/b/c/d. let mut alpha = Vec::new(); let mut seqlen = Vec::new(); for trip in triplets { @@ -432,17 +335,80 @@ fn data_tensors( // For every scenario, collect contribution of every tree to its likelihood, // in a sparse fashion corresponding to the columns of a conceptual scenario × gene trees matrix. // This matrix remains conceptual and is never actually reified into a tensor. - let columns_indices = Tensor::cat(&scenario_trees_index, 0); + let columns_indices = Tensor::cat(&scenario_trees_index, 0).to_device(dev); + DataTensors { + n_trees_i, + scenarios, + alphaseq_per_scenario, + n_per_scenario, + minus_ln_factorial_n_per_scenario, + columns_indices, + } +} + +// Actual implementation of lnl = F(X, P). +pub(crate) fn ln_likelihood_tensors( + data: &DataTensors, // 'X' + scores: &Scores<Tensor>, // 'P' (some may be *fixed* by user so actually part of 'X') +) -> Tensor { + let DataTensors { + n_trees_i, + scenarios, alphaseq_per_scenario, n_per_scenario, minus_ln_factorial_n_per_scenario, columns_indices, + } = data; + let (kind, device) = scores.kind_device(); + + // Calculate parameters for the current scores values. 
+ let parms: Parameters<Tensor> = scores.into(); + + let mut columns = Vec::with_capacity(scenarios.len()); + + for (scenario, alphaseq, n, minus_ln_factorial_n) in izip!( + *scenarios, + alphaseq_per_scenario, + n_per_scenario, + minus_ln_factorial_n_per_scenario, + ) { + // Scenario prior probability. + let prior = scenario.probability_tensors(&parms); + + // Expected branch lengths. + let expected_abcd = scenario.branches_lengths_tensor(&parms); + + // Log-densities of poisson distribution. + let lambda = Tensor::outer(alphaseq, &expected_abcd); + let ln_poisson = minus_ln_factorial_n + n * lambda.log() - lambda; + + // Sum over the branches and exponentiate + // to get the product of poisson distributions. + let poisson_product = ln_poisson.sum_dim_intlist(1, false, kind).exp(); + + // Down to the final contribution of every tree + // to the likelihood of this scenario. + let contributions = prior * poisson_product; + + columns.push(contributions); } + + // Sum over all scenarios to get the total probability of every tree. + // This is a sum over the rows of the conceptual sparse matrix. + let columns = Tensor::cat(&columns, 0); + let trees_prob = + Tensor::zeros(*n_trees_i, (kind, device)).scatter_add(0, columns_indices, &columns); + + // And finally integrate this all into the total likelihood. + trees_prob.log().sum(None) } -// Return value of the above. -struct DataTensors { +// Constant tensors +// supposed not to be changed during the whole optimisation procedure. +pub(crate) struct DataTensors<'s> { + n_trees_i: i64, + scenarios: &'s [Scenario], alphaseq_per_scenario: Vec<Tensor>, n_per_scenario: Vec<Tensor>, minus_ln_factorial_n_per_scenario: Vec<Tensor>, diff --git a/src/model/parameters.rs b/src/model/parameters.rs index aee7379e9414fce701f5b6476043e83b9dc6bfa9..bf16c12276a93cf331fc8d62d88ef10b16cb7e92 100644 --- a/src/model/parameters.rs +++ b/src/model/parameters.rs @@ -6,15 +6,24 @@ use std::fmt::{self, Display}; use arrayvec::ArrayVec; use tch::Tensor; -use super::scenarios::Scenario; +// The structure is generic among float and tensors +// to allow both likelihood implementations. +// NB: the required bounds to abstract over both in this program +// is very complicated: (f64 - Num, Num / f64, Num.exp(), etc.) +// and it needs to be repeated on every method implementation. +// As a consequence, either duplicate the code or use macros, +// and keep genericity just for factorizing over the parameters type. +pub trait Num: Sized {} +impl Num for f64 {} +impl Num for Tensor {} +impl Num for ValGrad {} // Constraints to uphold: // - all positive // - tau_1 <= tau_2 // - p_* <= 1 // - p_ab + p_ac + p_bc <= 1 -#[allow(clippy::module_name_repetitions)] -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Parameters<F: Num> { // Population size. pub(crate) theta: F, @@ -32,37 +41,17 @@ pub struct Parameters<F: Num> { // Parameters are stack-allocated, so don't overuse possible GF times. pub(crate) const MAX_N_GF_TIMES: usize = 3; -#[derive(Debug)] -#[cfg_attr(test, derive(PartialEq))] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct GeneFlowTimes<F: Num>(pub(crate) ArrayVec<F, MAX_N_GF_TIMES>); -// Abstract over both f64 and Tensor. -// NB: the required bounds to abstract over them in this program -// is very complicated: (f64 - Num, Num / f64, Num.exp(), etc.) -// and it needs to be repeated on every method implementation. -// As a consequence, either duplicate the code or use macros, -// and keep genericity just for factorizing over the above structure. 
-pub trait Num: Sized {} -impl Num for f64 {} -impl Num for Tensor {} -impl Num for ValGrad {} - // Convenience bundle. pub(crate) struct ValGrad { pub(crate) value: f64, pub(crate) gradient: Option<f64>, } -//================================================================================================== - -impl<F: Num> Parameters<F> { - pub(crate) fn n_gf_times(&self) -> usize { - self.gf_times.0.len() - } - pub(crate) fn n_scenarios(&self) -> usize { - Scenario::len(self.n_gf_times()) - } -} +//-------------------------------------------------------------------------------------------------- +// Minor methods. impl<F: Num> GeneFlowTimes<F> { pub(crate) fn new() -> Self { @@ -70,7 +59,11 @@ impl<F: Num> GeneFlowTimes<F> { } } -//================================================================================================== +impl<F: Num> Parameters<F> { + pub(crate) fn n_gf_times(&self) -> usize { + self.gf_times.0.len() + } +} // Drop gradient information impl From<&ValGrad> for f64 { @@ -79,6 +72,9 @@ impl From<&ValGrad> for f64 { } } +//================================================================================================== +// Display. + impl<F: Num + Display> Display for GeneFlowTimes<F> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_list() diff --git a/src/model/scenarios.rs b/src/model/scenarios.rs index eae254f1bba38dce51a0e1ff6ef260b49ea49d68..cdb9e5c1719289a080f158f5827ca8a925f324b2 100644 --- a/src/model/scenarios.rs +++ b/src/model/scenarios.rs @@ -4,34 +4,30 @@ use std::iter; use crate::gene_tree::TripletTopology; -#[derive(Debug)] -#[cfg_attr(test, derive(PartialEq))] -pub(super) struct IncompleteLineageSortingScenario(TripletTopology); -pub(super) const N_ILS_SCENARIOS: usize = 3; +#[derive(Debug, PartialEq, Eq)] +pub(crate) struct IncompleteLineageSortingScenario(TripletTopology); +pub(crate) const N_ILS_SCENARIOS: usize = 3; -#[derive(Debug)] -#[cfg_attr(test, derive(PartialEq))] +#[derive(Debug, PartialEq, Eq)] #[allow(clippy::upper_case_acronyms)] -pub(super) enum GeneFlowTopology { +pub(crate) enum GeneFlowTopology { ABC, // ((A <> B), C) // (undistinguishable directions) ACB, // ((A => C), B) CAB, // ((C => A), B) BCA, // ((B => C), A) CBA, // ((C => B), A) } -pub(super) const N_GF_TOPOLOGIES: usize = 5; +#[cfg(test)] +pub(crate) const N_GF_TOPOLOGIES: usize = 5; -#[derive(Debug)] -#[cfg_attr(test, derive(PartialEq))] -pub(super) struct GeneFlowScenario { - pub(super) topology: GeneFlowTopology, - pub(super) time: usize, // Index into the vector of gene flow times. +#[derive(Debug, PartialEq, Eq)] +pub(crate) struct GeneFlowScenario { + pub(crate) topology: GeneFlowTopology, + pub(crate) time: usize, // Index into the vector of gene flow times. 
} -#[derive(Debug)] -#[cfg_attr(test, derive(PartialEq))] -#[allow(clippy::upper_case_acronyms)] -pub(super) enum Scenario { +#[derive(Debug, PartialEq, Eq)] +pub(crate) enum Scenario { NoEvent, IncompleteLineageSorting(IncompleteLineageSortingScenario), GeneFlow(GeneFlowScenario), @@ -74,10 +70,12 @@ impl GeneFlowTopology { } } +#[cfg(test)] +pub(crate) fn n_scenarios(n_gf_times: usize) -> usize { + 1 + N_ILS_SCENARIOS + n_gf_times * N_GF_TOPOLOGIES +} + impl Scenario { - pub(crate) fn len(n_gf_times: usize) -> usize { - 1 + N_ILS_SCENARIOS + n_gf_times * N_GF_TOPOLOGIES - } pub(crate) fn iter(n_gf_times: usize) -> impl Iterator<Item = Self> { use Scenario as S; @@ -93,7 +91,8 @@ impl Scenario { } } -//-------------------------------------------------------------------------------------------------- +//================================================================================================== + #[cfg(test)] mod tests { use super::*; @@ -108,7 +107,7 @@ mod tests { let check = |n, scenarios: &[_]| { assert_eq!(Scenario::iter(n).collect::<Vec<_>>(), scenarios); - assert_eq!(Scenario::len(n), scenarios.len()); + assert_eq!(n_scenarios(n), scenarios.len()); }; // With no gene flow scenarios. diff --git a/src/model/scores.rs b/src/model/scores.rs index 8d4873d077634703878f2c45475eb13601df42c5..2ccebc9bce348619c032b37ede0f001addb64725 100644 --- a/src/model/scores.rs +++ b/src/model/scores.rs @@ -7,23 +7,24 @@ use std::fmt::{self, Display}; use arrayvec::ArrayVec; -use tch::Tensor; +use tch::{Device, Kind, Tensor}; -use crate::Parameters; +use crate::{ + model::parameters::{GeneFlowTimes, Num, ValGrad}, + Parameters, +}; -use super::parameters::{GeneFlowTimes, Num, ValGrad}; - -#[cfg_attr(test, derive(PartialEq))] +#[derive(PartialEq, Eq)] pub(crate) struct Scores<F: Num> { pub(crate) theta: F, // Encodes theta. pub(crate) tau_1: F, // Encodes tau_1. pub(crate) delta_tau: F, // Encodes positive difference between tau_2 and tau_1. pub(crate) gf: F, // Encodes p_ab + p_ac + p_bc. - // (implicit) ab = 1: weigths p_ab. + // (implicit) ab = 1, // Weigths p_ab. pub(crate) ac: F, // Weights p_ac. pub(crate) bc: F, // Weights p_bc. pub(crate) ancient: F, // Encodes p_ancient. - // (implicit) max_gft = 1; + // (implicit) max_gft = 1, pub(crate) gf_times: GeneFlowTimes<F>, // Encode fraction of max_gft or previous greater gft. } // Implicit constants. @@ -32,6 +33,7 @@ const MAX_GFT: f64 = 1.0; //-------------------------------------------------------------------------------------------------- // Check whether all constraints are enforced. + #[cfg(test)] impl Parameters<f64> { fn valid(&self) -> bool { @@ -84,7 +86,9 @@ impl Sigmoid for &Tensor { } } +//-------------------------------------------------------------------------------------------------- // Calculate parameters from the scores. + macro_rules! scores_to_parms { ($score:expr, $Ty:ident) => {{ let s = $score; @@ -127,11 +131,13 @@ macro_rules! scores_to_parms { } }}; } + impl From<&Scores<f64>> for Parameters<f64> { fn from(s: &Scores<f64>) -> Self { scores_to_parms!(s, f64) } } + impl From<&Scores<Tensor>> for Parameters<Tensor> { #[allow(clippy::similar_names)] fn from(s: &Scores<Tensor>) -> Self { @@ -139,19 +145,13 @@ impl From<&Scores<Tensor>> for Parameters<Tensor> { } } +//-------------------------------------------------------------------------------------------------- // Calculate scores from parameters. + macro_rules! 
parms_to_score { ($parms:expr, $ln:ident) => {{ let p = $parms; - let Parameters { - p_ab, - p_ac, - p_bc, - p_ancient, - tau_1, - tau_2, - .. - } = p; + let Parameters { p_ab, p_ac, p_bc, p_ancient, tau_1, tau_2, .. } = p; let theta = p.theta.$ln(); let ancient = p_ancient.rev_sigmoid(); @@ -197,6 +197,7 @@ impl From<&Parameters<f64>> for Scores<f64> { parms_to_score!(p, ln) } } + impl From<&Parameters<Tensor>> for Scores<Tensor> { fn from(p: &Parameters<Tensor>) -> Self { parms_to_score!(p, log) @@ -204,8 +205,8 @@ impl From<&Parameters<Tensor>> for Scores<Tensor> { } //-------------------------------------------------------------------------------------------------- +// Extract float values and/or gradient information from tensors. -// Extract float values from tensors. macro_rules! into_scores { ($scores:expr, $get:ident) => {{ let Scores { @@ -238,14 +239,24 @@ impl From<&Scores<Tensor>> for Scores<f64> { } } -// Extract gradient info from tensors. -impl From<&Scores<Tensor>> for Scores<ValGrad> { - fn from(scores: &Scores<Tensor>) -> Self { - let getgrad = |t: &Tensor| ValGrad { +impl Scores<Tensor> { + pub(crate) fn with_grads(&self, grad: &Tensor) -> Scores<ValGrad> { + // TODO: this will require user config info + // when some parameters tracking become opt-out. + // For now, this assumes that only the *last* values (gf_times) are not tracked. + let mut it = (0..grad.size1().unwrap()).map(|i| grad.get(i)); + let mut getgrad = |t: &Tensor| ValGrad { value: t.double_value(&[]), - gradient: t.requires_grad().then(|| t.grad().double_value(&[])), + gradient: it.next().map(|g| g.double_value(&[])), }; - into_scores!(scores, getgrad) + into_scores!(self, getgrad) + } + + // Assuming all tensors have same kind and device, + // pick any to return which they are. + pub(crate) fn kind_device(&self) -> (Kind, Device) { + let p = &self.theta; + (p.kind(), p.device()) } } @@ -275,32 +286,6 @@ impl From<&Scores<ValGrad>> for Scores<f64> { } } -impl Scores<Tensor> { - // Duplicate without duplicating underlying tensors. - pub(crate) fn alias(&self) -> Self { - let Scores { - theta, - tau_1, - delta_tau, - gf, - ac, - bc, - ancient, - gf_times, - } = self; - Scores { - theta: theta.alias(), - tau_1: tau_1.alias(), - delta_tau: delta_tau.alias(), - gf: gf.alias(), - ac: ac.alias(), - bc: bc.alias(), - ancient: ancient.alias(), - gf_times: GeneFlowTimes(gf_times.0.iter().map(Tensor::alias).collect()), - } - } -} - macro_rules! score_to_parms { ($F:ident) => { impl From<&Scores<$F>> for Parameters<f64> { @@ -315,6 +300,8 @@ score_to_parms!(ValGrad); score_to_parms!(Tensor); //================================================================================================== +// Display. + impl<F: Num + Display + fmt::Debug> fmt::Display for Scores<F> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let Scores { @@ -341,15 +328,18 @@ impl<F: Num + Display + fmt::Debug> fmt::Display for Scores<F> { } //================================================================================================== +// Test. 
+
 #[cfg(test)]
 mod tests {
     use std::iter;
 
-    use super::*;
     use float_eq::float_eq;
     use rand::prelude::*;
     use tch::{Device, Kind, Tensor};
 
+    use super::*;
+
     #[test]
     fn parameters_scores_random_roundtrips() {
         let n_tests = 1000;
diff --git a/src/optim.rs b/src/optim.rs
new file mode 100644
index 0000000000000000000000000000000000000000..b4f9503a5ba7b2813d8904924f5f888fa9285f90
--- /dev/null
+++ b/src/optim.rs
@@ -0,0 +1,589 @@
+// Dissatisfied with tch-rs optimisers,
+// here is a collection of homebrew optimisation procedures
+// built on tch-rs Tensors and autograd.
+
+pub(crate) mod bfgs;
+pub(crate) mod gd;
+pub(crate) mod tensor;
+pub(crate) mod wolfe_search;
+
+use std::num::NonZeroUsize;
+
+use paste::paste;
+use snafu::{ensure, Snafu};
+use tch::Tensor;
+use tensor::Loggable;
+pub(crate) use tensor::OptimTensor;
+
+// Abstract trait shared by optimisers.
+pub(crate) trait Optim {
+    type Error: Into<Error>;
+    type Result: OptimResult;
+    // Minimize the given scalar function y = f(x),
+    // assuming that the tensor 'x'
+    // contains all and only its gradient-tracking leaves.
+    // The 'log' parameter sets the number of steps to wait before displaying one.
+    // Set to 0 to not log anything.
+    fn minimize(
+        &self,
+        f: impl Fn(&Tensor) -> Tensor,
+        init: &Tensor,
+        log: u64,
+    ) -> Result<Self::Result, Error>;
+}
+
+// Abstract over the possible return values of the minimisation function.
+pub(crate) trait OptimResult {
+    fn best_vars(&self) -> &Tensor;
+    fn best_loss(&self) -> f64;
+    fn n_eval(&self) -> u64;
+    fn n_diff(&self) -> u64;
+}
+
+// Ease implementation of the above for structs with explicit field names.
+macro_rules! simple_optim_result_impl {
+    ($Name:ident) => {
+        impl crate::optim::OptimResult for $Name {
+            fn best_vars(&self) -> &Tensor {
+                &self.vars
+            }
+            fn best_loss(&self) -> f64 {
+                self.loss
+            }
+            fn n_eval(&self) -> u64 {
+                self.n_eval
+            }
+            fn n_diff(&self) -> u64 {
+                self.n_diff
+            }
+        }
+    };
+}
+use simple_optim_result_impl;
+
+//==================================================================================================
+// Keep track of the best candidate found.
+
+pub(crate) struct Best {
+    loss: f64,
+    vars: Tensor,
+}
+
+impl Best {
+    fn new(loss: f64, vars: &Tensor) -> Self {
+        Self { loss, vars: vars.detach().copy() }
+    }
+
+    fn update(&mut self, loss: f64, vars: &Tensor) {
+        if loss <= self.loss {
+            self.loss = loss;
+            self.vars = vars.detach().copy();
+        }
+    }
+}
+
+//==================================================================================================
+// Keep track of the latest search history.
+
+struct History {
+    slope: Box<dyn Fn(&Tensor) -> f64>, // Calculate slope from ordered trace data.
+    // Rotating log.
+    trace: Vec<f64>,
+    size: usize,
+    next_up: usize, // Rotates around the trace.
+    full: bool,     // Raised on the first wrap-around.
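+    //
+    // NB: the `slope` closure built in `History::new` performs an ordinary
+    // least-squares fit of the trace against the implicit abscissa 0..size,
+    // i.e. slope = Σ(x − x̄)(y − ȳ) / Σ(x − x̄)²;
+    // the x-deviations are precomputed once since the abscissa never changes,
+    // so only the y-deviations need recomputing on each call.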
+} + +impl History { + fn new(size: NonZeroUsize, (kind, device): (tch::Kind, tch::Device)) -> Self { + let size: usize = size.into(); + let n: i64 = size.try_into().unwrap_or_else(|e| { + panic!( + "Too much history required \ + to fit in tensor length integer type: {size}:\n{e}" + ) + }); + let x = Tensor::linspace(0, n - 1, n, (kind, device)); + #[allow(clippy::cast_precision_loss)] + let mean_x = 0.5 * (n as f64 - 1.0); + let dev_x = x - mean_x; + let dev_x_square = dev_x.square().sum(kind).to_double(); + Self { + slope: Box::new(move |y| { + let mean_y = y.mean(kind); + let dev_y = y - mean_y; + let slope = dev_x.dot(&dev_y) / dev_x_square; + slope.to_double() + }), + trace: vec![0.0; size], + size, + next_up: 0, + full: false, + } + } + + fn update(&mut self, loss: f64) { + self.trace[self.next_up] = loss; + self.next_up += 1; + if self.next_up == self.size { + self.next_up = 0; + self.full = true; + } + } + + // Calculate mean slope over the history. + // Don't return if history hasn't been filled yet. + fn slope(&self) -> Option<f64> { + self.full.then(|| { + // Reform a new ordered tensor from the rotated one. + let trace = &self.trace; + let i = self.next_up; + let (latest, oldest) = trace.split_at(i); + let [latest, oldest] = [latest, oldest].map(Tensor::from_slice); + let y = Tensor::cat(&[oldest, latest], 0); + + (*self.slope)(&y) + }) + } +} + +// Useful to optimizers using history +// to periodically check slope and stop when it reaches some threshold. +#[derive(Debug)] +pub(crate) struct SlopeTrackingConfig { + history_size: NonZeroUsize, + threshold: f64, // Positive. + grain: u64, // Wait this number of steps before checking slope (0 to never check) +} + +impl SlopeTrackingConfig { + pub(crate) fn new(history_size: usize, threshold: f64, grain: u64) -> Result<Self, Error> { + ensure!( + history_size > 1, + cerr!( + ("Two history points at least are required to calculate slope. \ + Received {history_size}.") + ) + ); + let history_size = NonZeroUsize::new(history_size).unwrap(); + ensure!( + threshold >= 0.0, + cerr!(("Slope threshold must be positive. Received: {threshold}.")), + ); + Ok(Self { history_size, threshold, grain }) + } +} + +struct SlopeTracker<'c> { + config: &'c SlopeTrackingConfig, + history: History, +} + +impl<'c> SlopeTracker<'c> { + fn new(config: &'c SlopeTrackingConfig, kindev: (tch::Kind, tch::Device)) -> Self { + Self { + history: History::new(config.history_size, kindev), + config, + } + } + + // Return slope value if its magnitude is low enough. + fn low_slope(&mut self, loss: f64, n_iter: u64) -> Option<f64> { + self.history.update(loss); + let sg = self.config.grain; + if sg > 0 && n_iter % sg == 0 { + if let Some(slope) = self.history.slope() { + if slope.abs() < self.config.threshold { + return Some(slope); + } + } + } + None + } +} + +//================================================================================================== +// Errors. + +// Not exactly sure how to best handle errors polymorphism among optimizers? +// Here, an explicit list of implementors is required. +#[derive(Debug, Snafu)] +#[snafu(context(suffix(Err)))] +pub enum Error { + #[snafu(display("Configuration error:\n{mess}"))] + Config { mess: String }, + #[snafu(display("Gradient descent failure:\n{source}"))] + Gd { source: gd::Error }, + #[snafu(display("Failure of BFGS algorithm:\n{source}"))] + Bfgs { source: bfgs::Error }, +} + +macro_rules! cerr { + ($fmt:tt) => { + crate::optim::ConfigErr { mess: format!$fmt } + }; +} +use cerr; + +macro_rules! 
error_convert { + ($($mod:ident),+$(,)?) => {$( + paste! { + impl From<$mod::Error> for Error { + fn from(e: $mod::Error) -> Self { + Error::[< $mod:camel >] { source: e } + } + } + } + )+}; +} +error_convert! { + gd, + bfgs, +} + +//================================================================================================== +// Tests. + +#[cfg(test)] +mod tests { + use std::{fs::File, io::Write, path::PathBuf}; + + use float_eq::float_eq; + use rand::{ + distributions::{Distribution, Uniform}, + rngs::StdRng, + SeedableRng, + }; + use rand_distr::StandardNormal; + use tch::{Device, Kind, Tensor}; + + use crate::{ + config::defaults, + optim::{ + bfgs::Config as BfgsConfig, gd::Config as GdConfig, History, Optim, OptimResult, + OptimTensor, SlopeTrackingConfig, + }, + }; + + // Generate files to debug simple optimisation tests. + const SAMPLE_GRID: bool = false; + fn export_2d_loss_landscape( + loss: impl Fn(&Tensor) -> Tensor, + [a_min, a_max]: [f64; 2], + [b_min, b_max]: [f64; 2], + samples: i64, + filename: &str, + ) { + if !SAMPLE_GRID { + return; + } + println!("Dense evaluation on 2D grid.."); + let (kind, device) = (Kind::Double, Device::Cpu); + + // Export loss function grid. + let a_grid = Tensor::linspace(a_min, a_max, samples, (kind, device)); + let b_grid = Tensor::linspace(b_min, b_max, samples, (kind, device)); + let mut file = File::create(filename).unwrap(); + for a in a_grid.iter::<f64>().unwrap() { + for b in b_grid.iter::<f64>().unwrap() { + let y = loss(&Tensor::from_slice(&[a, b])).to_double(); + writeln!(file, "{a:?}\t{b:?}\t{y:?}").unwrap(); + } + } + } + + fn bfgs_config() -> BfgsConfig { + let mut bfgs: BfgsConfig = (&defaults::bfgs()).try_into().unwrap(); + bfgs.main_trace_path = Some(PathBuf::from("./target/bfgs.csv")); + bfgs.linsearch_trace_path = Some(PathBuf::from("./target/bfgs_linsearch.csv")); + bfgs + } + + #[test] + fn linear_regression() { + // Test bfgs on an easy linear regression case. + let mut rng = StdRng::seed_from_u64(12); + let norm = StandardNormal; + let (kind, device) = (Kind::Double, Device::Cpu); + + // Generate noisy data. + let (a, b) = (5.0, 8.0); + let n_data = 10; + let x = Tensor::linspace(0, 100, n_data, (kind, device)); + let y = a * &x + b; + let err = (0..n_data) + .map(|_| norm.sample(&mut rng)) + .collect::<Vec<f64>>(); + let err = 0.05 * Tensor::from_slice(&err); + let y = y + err; + + // Start from naive estimate. + let init = Tensor::from_slice(&[0., 0.]); + + // Wrap model and data into a single 'loss' function to minimize. + let loss = |p: &Tensor| { + let (a, b) = (p.get(0), p.get(1)); + let prediction = a * &x + b; + // Least squares error. + let errors = &y - prediction; + errors.square().sum(None) + }; + export_2d_loss_landscape(loss, [-10., 10.], [-10., 10.], 1024, "./target/grid.csv"); + + // Fit to data. + let bfgs = bfgs_config(); + let bfgs = bfgs.minimize(loss, &init, 1).unwrap(); + + // Check that we get close enough. + let [oa, ob] = [0, 1].map(|i| bfgs.best_vars().get(i).to_double()); + println!("Found with BFGS: ({oa}, {ob})."); + assert!( + float_eq!(oa, a, abs <= 5e-2), + "{oa} != {a} (𝛥 = {:e})", + a - oa + ); + assert!( + float_eq!(ob, b, abs <= 5e-2), + "{ob} != {b} (𝛥 = {:e})", + b - ob + ); + + // Seek similar results with simple gradient descent. + // Use less naive initial point. 
+ let init = bfgs.best_vars() + 1; + let sl = SlopeTrackingConfig::new(100, 1e-5, 50).unwrap(); + let gd = GdConfig { + max_iter: 20_000, + step_size: 2e-5, + slope_tracking: Some(sl), + }; + let gd = gd.minimize(loss, &init, 1).unwrap(); + let [oa, ob] = [0, 1].map(|i| gd.best_vars().get(i).to_double()); + println!("Found with gradient descent: ({oa}, {ob})."); + assert!( + float_eq!(oa, a, abs <= 1e-2), // (les strict) + "{oa} != {a} (𝛥 = {:e})", + a - oa + ); + assert!( + float_eq!(ob, b, abs <= 2e-1), // (even less strict) + "{ob} != {b} (𝛥 = {:e})", + b - ob + ); + + // Hard-test evaluation counts just to trigger attention whenever the algorithm changes. + assert_eq!((bfgs.n_eval(), bfgs.n_diff()), (25, 25)); // With BFGS instead of DFP. + assert_eq!((gd.n_eval(), gd.n_diff()), (18102, 18101)); // One-off from weak slope stop. + assert!(bfgs.best_loss() < gd.best_loss()); + } + + // A few test cases taken from: + // https://en.wikipedia.org/wiki/Test_functions_for_optimization + fn sphere(kind: Kind) -> impl Fn(&Tensor) -> Tensor { + move |x: &Tensor| (x.square()).sum(kind) + } + + fn rosenbrock(kind: Kind) -> impl Fn(&Tensor) -> Tensor { + move |x: &Tensor| { + let n = x.size1().unwrap(); + let base = x.slice(0, 0, n - 1, 1); + let shift = x.slice(0, 1, n, 1); + let lhs: Tensor = 100. * (shift - base.square()); + let rhs: Tensor = 1. - base; + (lhs.square() + rhs.square()).sum(kind) + } + } + + fn beale() -> impl Fn(&Tensor) -> Tensor { + move |x: &Tensor| { + let (x, y) = (&x.get(0), &x.get(1)); + let aa: Tensor = 1.5 - x + x * y; + let bb: Tensor = 2.25 - x + x * y.square(); + let cc: Tensor = 2.625 - x + x * y * y.square(); + aa.square() + bb.square() + cc.square() + } + } + + fn goldstein_price() -> impl Fn(&Tensor) -> Tensor { + move |x: &Tensor| { + let (x, y) = (&x.get(0), &x.get(1)); + let xs = &x.square(); + let ys = &y.square(); + let xy = &(x * y); + let left: Tensor = 1. + + (x + y + 1.).square() * (19. - 14. * x + 3. * xs - 14. * y + 6. * xy + 3. * ys); + let right = 30. + + (x * 2. - y * 3.).square() + * (18. - 32. * x + 12. * xs + 48. * y - 36. * xy + 27. * ys); + left * right + } + } + + fn levi() -> impl Fn(&Tensor) -> Tensor { + move |x: &Tensor| { + let (x, y) = (&x.get(0), &x.get(1)); + let pi = std::f64::consts::PI; + (3. * pi * x).sin().square() + + (x - 1.).square() * (1. + (3. * pi * y).sin().square()) + + (y - 1.).square() * (1. + (2. * pi * y).sin().square()) + } + } + + fn ackley() -> impl Fn(&Tensor) -> Tensor { + move |x: &Tensor| { + let (x, y) = (&x.get(0), &x.get(1)); + let tau = std::f64::consts::TAU; + let e = std::f64::consts::E; + (((x.square() + y.square()) * 0.5).sqrt() * -0.2).exp() * -20. + - (((tau * x).cos() + (tau * y).cos()) * 0.5).exp() + + e + + 20. + } + } + + fn test_2d( + loss: impl Fn(&Tensor) -> Tensor, + (opt_x, opt_y): ([f64; 2], f64), + init: [f64; 2], + [x_range, y_range]: [[f64; 2]; 2], + [n_eval, n_diff]: [u64; 2], + ) -> impl OptimResult { + // Possibly export landscape to debug. + export_2d_loss_landscape(&loss, x_range, y_range, 1024, "./target/grid.csv"); + + let init = Tensor::from_slice(&init); + let bfgs = bfgs_config(); + let bfgs = bfgs.minimize(&loss, &init, 1).unwrap(); + + // Extract optimum found. 
+ for (actual, expected) in (0..=1) + .map(|i| bfgs.best_vars().get(i).to_double()) + .zip(opt_x) + { + assert!( + float_eq!(expected, actual, abs <= 1e-6), + "{actual} != {expected} (𝛥 = {:e})", + actual - expected + ); + } + let bl = bfgs.best_loss(); + assert!(float_eq!(opt_y, bl, abs <= 1e-6), "{bl:e} != {opt_y}"); + println!( + "Optimisation success: {} evals and {} diffs.", + bfgs.n_eval(), + bfgs.n_diff() + ); + // Also hard-test these numbers to trigger attention on algorithm change. + assert_eq!( + [bfgs.n_eval(), bfgs.n_diff()], + [n_eval, n_diff], + "New number of evaluations/diffs?" + ); + bfgs + } + + #[test] + fn sphere_2d() { + test_2d( + sphere(Kind::Double), + ([0., 0.], 0.), + [1., 1.5], + [[-2., 2.], [-2., 2.]], + [4, 4], + ); + } + + #[test] + fn rosenbrock_2d() { + test_2d( + rosenbrock(Kind::Double), + ([1., 1.], 0.), + [-1., 2.5], + [[-2., 2.], [-1., 3.]], + [330, 330], + ); + } + + #[test] + fn beale_2d() { + test_2d( + beale(), + ([3., 0.5], 0.), + [-1., -2.], + [[-4.5, 4.5], [-4.5, 4.5]], + [62, 62], + ); + } + + #[test] + fn goldstein_price_2d() { + test_2d( + goldstein_price(), + ([0., -1.], 3.), + [-1., -2.], + [[-2., 2.], [-3., 1.]], + [99, 99], + ); + } + + #[test] + fn levi_2d() { + let tau = std::f64::consts::TAU; + test_2d( + levi(), + ([1., 1.], 0.), + [-4., -3.], + [[-tau, tau], [-tau, tau]], + [13, 13], + ); + } + + #[test] + fn ackley_2d() { + let tau = std::f64::consts::TAU; + test_2d( + ackley(), + ([0., 0.], 0.), + [-4.5, -3.], + [[-tau, tau], [-tau, tau]], + [107, 107], + ); + } + + #[test] + fn rosenbrock_20d() { + let mut rng = StdRng::seed_from_u64(12); + let unif = Uniform::new(-5., 5.); + let kind = Kind::Double; + let loss = rosenbrock(kind); + let init = (0..20).map(|_| unif.sample(&mut rng)).collect::<Vec<_>>(); + let init = Tensor::from_slice(&init); + let bfgs = bfgs_config(); + let bfgs = bfgs.minimize(&loss, &init, 1).unwrap(); + for i in 0..20 { + let oi = bfgs.best_vars().get(i).to_double(); + assert!( + float_eq!(oi, 1., abs <= 1e-3), + "variable {i}: {oi} != 1 (𝛥 = {:e})", + 1. - oi + ); + } + assert!(float_eq!(bfgs.best_loss(), 0., abs <= 1e-6)); + } + + #[test] + fn slope_calculation() { + let kindev = (Kind::Double, Device::Cpu); + let (a, b) = (5., 8.); + let n_data: usize = 100; + #[allow(clippy::cast_precision_loss)] + let x = Tensor::linspace(0., n_data as f64 - 1., n_data.try_into().unwrap(), kindev); + let y: Tensor = a * x + b; + let mut history = History::new(n_data.try_into().unwrap(), kindev); + for value in y.iter::<f64>().unwrap() { + history.update(value); + } + let estimate = history.slope().unwrap(); + assert!(float_eq!(estimate, a, ulps <= 1)); + } +} diff --git a/src/optim/bfgs.rs b/src/optim/bfgs.rs new file mode 100644 index 0000000000000000000000000000000000000000..01550d7dcd01c36f5e4415c843bdf67f99d8c251 --- /dev/null +++ b/src/optim/bfgs.rs @@ -0,0 +1,345 @@ +// Implementation of BFGS algorithm with tch-rs, +// taken from Nocedal & Wright (2006). +// Stops on its own when the mean slope of the history +// becomes flat, provided the history is full. + +use std::{ + fs::File, + io::{self, Write}, + path::PathBuf, +}; + +use snafu::{ensure, ResultExt, Snafu}; +use tch::Tensor; + +use crate::optim::{ + simple_optim_result_impl, + wolfe_search::{ + self, Config as WolfeSearchConfig, Error as WolfeSearchError, LastStep, Location, + Summary as WolfeSearchSummary, + }, + Best, Error as OptimError, Optim, OptimTensor, SlopeTracker, SlopeTrackingConfig, +}; + +// The exposed config. 
+#[derive(Debug)] +pub(crate) struct Config { + pub(crate) max_iter: u64, + pub(crate) wolfe: WolfeSearchConfig, + // If required. + pub(crate) slope_tracking: Option<SlopeTrackingConfig>, + // Logging. + pub(crate) main_trace_path: Option<PathBuf>, + pub(crate) linsearch_trace_path: Option<PathBuf>, + // When gradient from one step to the next happen to be equal, + // the hessian approximation breaks with a division by zero. + // In this situation, consider that the search is over + // if the step norm was shorter than this threshold value. + pub(crate) small_step: f64, + // Paths for recording optimisation details. +} + +#[derive(Debug)] +pub(crate) struct BfgsResult { + loss: f64, + vars: Tensor, + n_eval: u64, + n_diff: u64, +} +simple_optim_result_impl!(BfgsResult); + +// The actual search state. +struct BfgsSearch<'c, F: Fn(&Tensor) -> Tensor> { + // Search configuration. + cf: &'c Config, + + // Objective function, + fun: F, + n_vars: usize, // Cached: expensive to query. + + // Current variables values. + vars: Tensor, + // Current loss. + loss: f64, + // Current grad. + grad: Tensor, + + // Running estimate of the *inverse* hessian matrix. + hess: Tensor, + // (initial value + recycled constant) + eye: Tensor, + + // Current search direction. + dir: Tensor, + + // Iteration number. + n_steps: u64, + // Keep track of total number of function evaluation and differentiation. + n_eval: u64, + n_diff: u64, + + // Keep track of the best point found so far. + best: Best, + + // If desired. + slope_tracker: Option<SlopeTracker<'c>>, + + // Log files to write to. + main_file: Option<File>, // Main search steps. + lin_file: Option<File>, // Linear search on every step. +} + +impl Optim for Config { + type Error = Error; + type Result = BfgsResult; + fn minimize( + &self, + fun: impl Fn(&Tensor) -> Tensor, + init: &Tensor, + log: u64, + ) -> Result<Self::Result, OptimError> { + let n_parms = init.size1().unwrap_or_else(|e| { + panic!( + "Parameters must be single-dimensional. Received dimensions {:?} instead:\n{e}", + init.size(), + ) + }); + let (kind, device) = (init.kind(), init.device()); + + // Evaluate initial gradient value. + let mut vars = init.set_requires_grad(true); + vars.zero_grad(); + let loss_t = fun(&vars); + let loss = loss_t.to_double(); + ensure!(loss.is_finite(), NonFiniteLossErr { loss, vars }); + + loss_t.backward(); + let grad = vars.grad().copy(); + ensure!(grad.is_all_finite(), NonFiniteGradErr { grad, vars, loss }); + + let best = Best::new(loss, &vars); + + // Initial hessian estimate is plain identity. + let eye = Tensor::eye(n_parms, (kind, device)); + let hess = eye.copy(); + + // First direction to search. + let dir = (-&hess).mv(&grad); + + // Will only be differentiated as `x + alpha * p` during linear searches. + vars = vars.detach(); + + // Initialize slope tracking if desired. + let slope_tracker = self.slope_tracking.as_ref().map(|s| { + let mut t = SlopeTracker::new(s, (kind, device)); + assert!(t.low_slope(loss, 0).is_none()); + t + }); + + let create_file = |path: &Option<PathBuf>| -> Result<_, Error> { + Ok(if let Some(path) = path { + Some(File::create(path).context(TraceErr { path })?) + } else { + ensure!(log == 0, NoTracePathErr { log }); + None + }) + }; + + // Spin actual search steps. 
+ let mut search = BfgsSearch { + fun, + n_vars: vars.size1().unwrap().try_into().unwrap(), + vars, + loss, + grad, + hess, + eye, + dir, + n_steps: 0, + n_eval: 1, + n_diff: 1, + main_file: create_file(&self.main_trace_path)?, + lin_file: create_file(&self.linsearch_trace_path)?, + cf: self, + slope_tracker, + best, + }; + + Ok(search.run_search(log)?) + } +} + +impl<F: Fn(&Tensor) -> Tensor> BfgsSearch<'_, F> { + fn run_search(&mut self, log: u64) -> Result<BfgsResult, Error> { + let cf = self.cf; + if cf.max_iter == 0 { + println!("No optimisation step asked for."); + return Ok(self.result()); + } + + if log > 0 { + self.write_header()?; + wolfe_search::write_header(self.lin_file.as_mut().unwrap(), self.n_vars)?; + self.log(0.)?; // Fake initial zero step. + } + + loop { + let log_this_one = log > 0 && self.n_steps % log == 0; + + // Pick a search direction. + self.dir = (-&self.hess).mv(&self.grad); + + // Linear-search for an acceptable step size. + let WolfeSearchSummary { last_step: res, n_eval, n_diff } = wolfe_search::search( + Location::new(&self.fun, &self.dir, &self.vars, self.loss, &self.grad), + &self.cf.wolfe, + &mut self.best, + ( + if log_this_one { log } else { 0 }, + self.lin_file.as_mut(), + self.n_steps, + ), + )?; + self.n_eval += n_eval; + self.n_diff += n_diff; + let &mut Self { n_steps, n_eval, n_diff, .. } = self; + let Some(LastStep { + step_size, + step, + vars_after_step, + loss_after_step, + grad_after_step, + }) = res + else { + println!( + "Reached null step size after {n_steps} steps: \ + {n_eval} evaluations and {n_diff} differentiations." + ); + break Ok(self.result()); + }; + + // Check step norm. + let norm = step.dot(&step).to_double(); + if norm <= 0.0 { + println!( + "Reach silent step after {n_steps} steps: \ + {n_eval} evaluations and {n_diff} differentiations." + ); + break Ok(self.result()); + } + if norm < self.cf.small_step && self.grad == grad_after_step { + println!( + "Reach silent grad step after {n_steps} steps: \ + {n_eval} evaluations and {n_diff} differentiations." + ); + break Ok(self.result()); + } + + // Accept step. + self.vars = vars_after_step; + self.n_steps += 1; + let n_steps = self.n_steps; + + // Check slope. + if let Some(ref mut tracker) = self.slope_tracker { + if let Some(lowslope) = tracker.low_slope(self.loss, n_steps) { + println!("Weak loss slope ({lowslope:e}) on iteration {n_steps}: stopping."); + break Ok(self.result()); + } + } + + // Update hessian approximation (eq. 6.17). + let Self { ref grad, ref hess, .. } = *self; + let diff = &grad_after_step - grad; + let rho = 1. / &diff.dot(&step); + let left = &self.eye - &rho * step.outer(&diff); + let right = &self.eye - &rho * diff.outer(&step); + self.hess = left.mm(&hess.mm(&right)) + rho * step.outer(&step); + + self.loss = loss_after_step; + self.grad = grad_after_step; + + if log_this_one { + self.log(step_size)?; + } + + if n_steps >= cf.max_iter { + println!("Max iteration reached: {n_steps}."); + break Ok(self.result()); + } + } + } + + fn result(&self) -> BfgsResult { + let &Self { ref best, n_eval, n_diff, .. } = self; + BfgsResult { + vars: best.vars.detach().copy(), + loss: best.loss, + n_eval, + n_diff, + } + } + + // Traces. + fn write_header(&mut self) -> Result<(), Error> { + // Main headers. 
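+ // For a 2-variable problem the resulting header reads:
+ //   step,loss,step_size,vars_0,vars_1,grad_0,grad_1,dir_0,dir_1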
+ let mut header = ["step", "loss", "step_size"] + .into_iter() + .map(ToString::to_string) + .collect::<Vec<_>>(); + for vec in ["vars", "grad", "dir"] { + for i in 0..self.n_vars { + header.push(format!("{vec}_{i}")); + } + } + let file = self.main_file.as_mut().unwrap(); + let path = self.cf.main_trace_path.as_ref().unwrap(); + writeln!(file, "{}", header.join(",")).with_context(|_| TraceErr { path })?; + + Ok(()) + } + + fn log(&mut self, step_size: f64) -> Result<(), Error> { + macro_rules! w { + ($fmt:literal $(, $val:expr)?) => {{ + let file = self.main_file.as_mut().unwrap(); + let path = self.cf.main_trace_path.as_ref().unwrap(); + write!(file, $fmt $(, $val)?).with_context(|_| TraceErr{path}) + }}; + } + w!("{}", self.n_steps)?; + w!(",{}", self.loss)?; + w!(",{}", step_size)?; + for t in [&self.vars, &self.grad, &self.dir] { + for f in t.iter::<f64>().unwrap() { + w!(",{f}")?; + } + } + w!("\n") + } +} + +#[derive(Debug, Snafu)] +#[snafu(context(suffix(Err)))] +pub enum Error { + #[snafu(display("Obtained non-finite loss ({loss}) with these variables:\n{vars:?}"))] + NonFiniteLoss { loss: f64, vars: Tensor }, + #[snafu(display( + "Obtained non-finite gradient for loss ({loss}):\ngradient: {grad:?}\nvariables: {vars:?}" + ))] + NonFiniteGrad { + grad: Tensor, + vars: Tensor, + loss: f64, + }, + #[snafu(transparent)] + WolfeSearch { source: WolfeSearchError }, + #[snafu(display("Error writing to trace file {:?}:\n{source}", + path.as_os_str().to_string_lossy()))] + Trace { path: PathBuf, source: io::Error }, + #[snafu(display( + "A log level of {log} is required, \ + but no file path has been choosen to record traces." + ))] + NoTracePath { log: u64 }, +} diff --git a/src/optim/gd.rs b/src/optim/gd.rs new file mode 100644 index 0000000000000000000000000000000000000000..9c03a6ce7cd68ea2dcdfc14f81f0a6076c0917c9 --- /dev/null +++ b/src/optim/gd.rs @@ -0,0 +1,172 @@ +// Implementation of simple, naive gradient descent. +// Keeps a history of the last few loss values visited +// to return only the best one found so far (latest in case of ex-aequo). +// Stops on its own when the mean slope of the history +// becomes flat, provided the history is full. + +use snafu::Snafu; +use tch::Tensor; + +use crate::optim::{ + simple_optim_result_impl, Best, Error as OptimError, Loggable, Optim, OptimTensor, + SlopeTracker, SlopeTrackingConfig, +}; + +#[derive(Debug)] +pub(crate) struct Config { + pub(crate) max_iter: u64, + pub(crate) step_size: f64, // Above 0. + pub(crate) slope_tracking: Option<SlopeTrackingConfig>, +} + +#[derive(Debug)] +pub(crate) struct GdResult { + loss: f64, + vars: Tensor, + n_eval: u64, + n_diff: u64, +} +simple_optim_result_impl!(GdResult); + +impl Optim for Config { + type Error = Error; + type Result = GdResult; + fn minimize( + &self, + fun: impl Fn(&Tensor) -> Tensor, + init: &Tensor, + log: u64, + ) -> Result<Self::Result, OptimError> { + let &Self { max_iter, step_size, ref slope_tracking } = self; + let kindev = (init.kind(), init.device()); + let mut n_eval = 0; + let mut n_diff = 0; + + // Factorize common checks. + macro_rules! check_finite_loss { + ($loss:ident, $vars:ident, $step:expr) => { + snafu::ensure!( + $loss.is_finite(), + NonFiniteLossErr { + loss: $loss, + vars: $vars.detach().copy(), + step: $step + } + ); + }; + } + macro_rules! 
check_finite_grad { + ($grad:ident, $loss:ident, $vars:ident, $step:expr) => { + snafu::ensure!( + $grad.is_all_finite(), + NonFiniteGradErr { + grad: $grad.copy(), + loss: $loss, + vars: $vars.detach().copy(), + step: $step, + } + ); + }; + } + + // Evaluate initial loss value. + let mut x = init.set_requires_grad(true); + x.zero_grad(); + let mut y_t = fun(&x); + let mut y = y_t.to_double(); + n_eval += 1; + check_finite_loss!(y, x, 0u64); + + // Keep track of the best (loss, parameters) pair found so far. + let mut best = Best::new(y_t.to_double(), init); + + // Initialize slope tracking if desired. + let mut slope_tracker = slope_tracking.as_ref().map(|s| { + let mut t = SlopeTracker::new(s, kindev); + assert!(t.low_slope(y, 0).is_none()); + t + }); + + // Evaluate intial gradient value. + y_t.backward(); + let mut grad = x.grad().copy(); + n_diff += 1; + check_finite_grad!(grad, y, x, 0u64); + + if max_iter == 0 { + println!("No optimisation step asked for."); + return Ok(GdResult { vars: best.vars, loss: best.loss, n_eval, n_diff }); + } + + let mut n_steps = 0; + loop { + // Update. + let step = -step_size * &grad; + tch::no_grad(|| { + let _ = x.g_add_(&step); + }); + + // Re-evaluate loss in new location. + x.zero_grad(); + y_t = fun(&x); + y = y_t.to_double(); + n_eval += 1; + check_finite_loss!(y, x, n_steps); + best.update(y, &x); + + // Check slope. + if let Some(ref mut tracker) = slope_tracker { + if let Some(lowslope) = tracker.low_slope(y, n_steps) { + println!("Weak loss slope ({lowslope:e}) on iteration {n_steps}: stopping."); + break Ok(GdResult { loss: best.loss, vars: best.vars, n_eval, n_diff }); + } + } + + // Re-evaluate gradient in new location. + y_t.backward(); + grad = x.grad().copy(); + n_diff += 1; + check_finite_grad!(grad, y, x, n_steps); + + if log > 0 && n_steps % log == 0 { + println!("step {n_steps:<3} y: {} x: {}", y.format(), x.format()); + println!(" {} x': {}", None.format(), grad.format()); + } + + // Move on to next step. + n_steps += 1; + if n_steps >= max_iter { + if let Some(slope_threshold) = slope_tracking.as_ref().map(|s| s.threshold) { + println!( + "Max iteration reached ({n_steps}) \ + without finding a loss slope magnitude \ + lower than {slope_threshold:e}." + ); + } else { + println!("Max iteration reached ({n_steps})."); + } + break Ok(GdResult { loss: best.loss, vars: best.vars, n_eval, n_diff }); + } + } + } +} + +#[derive(Debug, Snafu)] +#[snafu(context(suffix(Err)))] +pub enum Error { + #[snafu(display( + "Obtained non-finite loss ({loss}) \ + with these variables on step {step}:\n{vars:?}" + ))] + NonFiniteLoss { loss: f64, vars: Tensor, step: u64 }, + #[snafu(display( + "Obtained non-finite gradient for loss ({loss}) on step {step}:\ + \ngradient: {grad:?}\nvariables: {vars:?}" + ))] + NonFiniteGrad { + grad: Tensor, + vars: Tensor, + loss: f64, + step: u64, + }, +} diff --git a/src/optim/tensor.rs b/src/optim/tensor.rs new file mode 100644 index 0000000000000000000000000000000000000000..3da983d3d70f90a0c0b5f8ff3f0e4ae25237a4c1 --- /dev/null +++ b/src/optim/tensor.rs @@ -0,0 +1,80 @@ +// Convenience trait for tensors, as used in this optimisation module. + +use tch::Tensor; + +//-------------------------------------------------------------------------------------------------- +// Pulls value back from device. 
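+// `to_double` is only ever called on scalar results in this module (losses and
+// dot products); `is_all_finite` checks every entry of a tensor at once.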
+ +pub(crate) trait OptimTensor { + fn to_double(&self) -> f64; + fn is_all_finite(&self) -> bool; +} + +impl OptimTensor for Tensor { + fn to_double(&self) -> f64 { + self.double_value(&[]) + } + fn is_all_finite(&self) -> bool { + self.isfinite().all().to_double() > 0.5 + } +} + +//-------------------------------------------------------------------------------------------------- +// Debug logging features, formatting values so they (mostly) fit into columns. + +pub(super) trait Loggable { + fn format(&self) -> String; +} + +impl Loggable for f64 { + fn format(&self) -> String { + let prec = 16; // Large to get fine precision displayed. + let mut width = 23; // Large to have logs correctly align. + let leading_space = if *self >= 0. { + " " + } else { + width += 1; + "" + }; + format!("{leading_space}{self:<width$.prec$e}") + } +} + +impl Loggable for u64 { + fn format(&self) -> String { + let width = 23; // For consistency with floats. + format!(" {self:<width$}") + } +} + +impl Loggable for Option<f64> { + fn format(&self) -> String { + let width = 23; + if let Some(v) = self { + v.format() + } else { + format!("{:<width$}", "") + } + } +} + +impl Loggable for Tensor { + fn format(&self) -> String { + let mut res = String::new(); + res.push('['); + let mut it = self.iter::<f64>().unwrap(); + if let Some(mut i) = it.next() { + loop { + res.push_str(&i.format()); + if let Some(n) = it.next() { + res.push_str(", "); + i = n; + } else { + break; + } + } + } + res.push(']'); + res + } +} diff --git a/src/optim/wolfe_search.rs b/src/optim/wolfe_search.rs new file mode 100644 index 0000000000000000000000000000000000000000..902356021c98e04643451ecb954445580af5b086 --- /dev/null +++ b/src/optim/wolfe_search.rs @@ -0,0 +1,624 @@ +// Given a local evaluation (y, grad) = (f(x), ∇f(x)) and a search direction p, +// perform a linear search to figure +// an acceptable positive step size 'alpha' in this direction, +// with respect to the strong Wolfe criteria. +// The algorithm implemented here is taken from Nocedal & Wright (2006), +// and the code uses their notations. +// +// There is no predefined value for alpha_max, +// but we ASSUME that all values of phi(alpha) are finite and non-NaN +// up to some horizon value h = alpha_max, +// and that acceptable step values do exist within this horizon. +// +// NaN and/or infinite values +// v v v +// 0 h +// a = alpha |----------|🗙-🗙-🗙-> ∞ +// ^^ ^^^ +// acceptable steps +// +// If the tested value for alpha is non-finite, +// it is reduced by a constant factor until we find a < h. +// If the bracketing phase requires increasing alpha again, +// we perform a binary search to find a < a_next < h. + +use std::{ + fs::File, + io::{self, Write}, + ops::ControlFlow, +}; + +use snafu::{ensure, Snafu}; +use tch::Tensor; + +use crate::optim::{tensor::OptimTensor, Best}; + +#[derive(Debug)] +pub(crate) struct Config { + // 0 < c1 < c2 < 1 + pub(crate) c1: f64, + pub(crate) c2: f64, + + // Starting point. + pub(crate) init_step: f64, // Above 0. + + // Scaling factor to reduce alpha in search for a < h. + pub(crate) step_decrease: f64, // (0 < · < 1) + + // Scaling factor to increase alpha during the bracketing phase, + // unless some h < a was already found and we use binary search instead. + pub(crate) step_increase: f64, // (1 < ·) + + // When f and f' evaluate to the same values + // at both ends of the zooming interval + // or in the sampled step size within it, + // stop the research if the gradient is under this threshold. 
+ pub(crate) flat_gradient: f64, // (0 ⩽ ·) + + // If the next zoom candidate issued from cubic interpolation + // is too close from the zoom interval + // in terms of this fraction of the interval, + // then use trivial bisection instead. + pub(crate) bisection_threshold: f64, // (0 ⩽ · ⩽ 1/2) +} + +pub(crate) struct Search<'c, 's, F: Fn(&Tensor) -> Tensor> { + cf: &'c Config, + + // Borrow situation from the calling search step. + loc: Location<'s, F>, + + // Log level. + n_log: u64, + + // Evolvable search state. + best: &'s mut Best, // Keep track of the best loss found so far. + + // Keep track of number of evaluations + differentiations performed. + n_eval: u64, + n_diff: u64, + + // Log search trace, given a search identifier as first column. + trace: Option<&'s mut File>, + id: u64, + n_vars: usize, +} + +// Starting point of the linear search, supposed immutable in this context. +pub(crate) struct Location<'s, F: Fn(&Tensor) -> Tensor> { + pub(crate) f: &'s F, // Objective function. + pub(crate) p: &'s Tensor, // Search direction. + pub(crate) x: &'s Tensor, // Search starting point. + pub(crate) y: f64, // Search starting loss value = phi(0). + pub(crate) phigrad: f64, // Derivative = phi'(0) = p . ∇f(x). +} + +// Pack search variables in various types depending on the search phase. +// Field names: +// alpha = current tested step size. +// step = alpha * p. +// x_stepped = xs = x + step. +// y_stepped = ys = phis = f(xs) = phi(alpha). +// grad_stepped = gs = ∇f(xs). +// phigrad_stepped = phigs = phi'(alpha) = p . ∇f(xs). + +// Initial phase: evaluate phi(alpha), lowering alpha untill a finite value is found. +struct SearchFinite { + alpha: f64, // To be evaluated. + lowest_nonfinite: Option<f64>, // Update to the lowest 'h < a' found, if any. +} +// First iteration into the bracketing loop. +struct Bracketing { + alpha: f64, + sample: Sample, + previous_step: Step, // Initially set for alpha=0. + lowest_nonfinite: Option<f64>, // Inherited/updated from previous state. +} +// The 'zoom' phase. +struct Zoom { + lo: Step, + hi: Step, + last: LastStep, // Useful when terminating. +} + +// Bundle the values obtained after 'sampling' one value of 'alpha'. +struct Sample { + step: Tensor, // Intermediate computation = alpha * p. + x: Tensor, // The one gradient-tracked = x + step. + y: Tensor, // The one to invoke backpropagation on = f(xs) = phi(alpha). + phi: f64, // Local copy of the above. +} + +// Bundle useful values from a 'previous' or 'current' search step. +struct Step { + alpha: f64, // alpha + phi: f64, // phi(alpha) + grad: f64, // phi'(alpha) +} + +impl<'s, F: Fn(&Tensor) -> Tensor> Location<'s, F> { + pub(crate) fn new(f: &'s F, p: &'s Tensor, x: &'s Tensor, y: f64, grad: &'s Tensor) -> Self { + Self { f, p, x, y, phigrad: p.dot(grad).to_double() } + } +} + +#[derive(Debug, PartialEq, Clone, Copy)] +enum BinaryStepDirection { + Up, + Down, +} + +impl<'c, 'binsearch, F: Fn(&Tensor) -> Tensor> Search<'c, 'binsearch, F> { + // Specify Wolfe criteria, assuming loss is finite. + + // Armijo's rule: check that the current step candidate + // would "decrease f sufficiently". + fn sufficient_decrease(&self, alpha: f64, phi: f64) -> bool { + let &Self { + loc: Location { y: phi_zero, phigrad: phigrad_zero, .. }, + cf: Config { c1, .. }, + .. + } = self; + phi <= phi_zero + c1 * alpha * phigrad_zero + } + + // Strong curvature condition: check that the current step candidate + // would "reduce slope sufficiently". 
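+ // i.e. |phi'(alpha)| <= c2 * |phi'(0)|, the strong Wolfe curvature condition.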
+ fn weak_slope(&self, phigrad: f64) -> bool { + let Self { + loc: Location { phigrad: phigrad_zero, .. }, + cf: Config { c2, .. }, + .. + } = self; + phigrad.abs() <= c2 * phigrad_zero.abs() + } + + fn run_search(mut self) -> Result<Summary, Error> { + use ControlFlow as C; + + let bracket = self.search_first_finite(&SearchFinite { + alpha: self.cf.init_step, + lowest_nonfinite: None, + })?; + + match self.bracket(bracket)? { + C::Break(summary) => Ok(summary), + C::Continue(zoom) => self.zoom(zoom), + } + } + + fn search_first_finite(&mut self, search: &SearchFinite) -> Result<Bracketing, Error> { + use Error as E; + let SearchFinite { mut alpha, mut lowest_nonfinite } = search; + loop { + match self.sample(alpha) { + Err(E::NonFiniteLoss { loss, vars }) => { + // Try again with a lower value for alpha. + self.log(alpha, loss, None, &vars, None)?; + let lower = alpha * self.cf.step_decrease; + ensure!( + lower > 0., + NoStepSizeYieldingFiniteLossErr { + x: self.loc.x.detach().copy(), + p: self.loc.p.detach().copy(), + } + ); + lowest_nonfinite = Some(alpha); + alpha = lower; + continue; + } + Ok(sample) => { + return Ok(Bracketing { + alpha, + sample, + previous_step: Step { + alpha: 0.0, + phi: self.loc.y, + grad: self.loc.phigrad, + }, + lowest_nonfinite, + }) + } + Err(e) => return Err(e), + } + } + } + + fn bracket(&mut self, mut bracket: Bracketing) -> Result<ControlFlow<Summary, Zoom>, Error> { + use BinaryStepDirection as D; + use ControlFlow as C; + use Error as E; + let mut first = true; + 'bracketing: loop { + let Bracketing { + alpha, + sample: Sample { step, x, y, phi }, + previous_step, + lowest_nonfinite, + } = bracket; + let (grad, phigrad) = self.differentiate(&x, &y, phi)?; + let this_step = || Step { alpha, phi, grad: phigrad }; + self.log(alpha, phi, Some(phigrad), &x, Some(&grad))?; + let last = || Self::last(alpha, step, &x, phi, &grad); + + if !self.sufficient_decrease(alpha, phi) || (!first && phi >= previous_step.phi) { + return Ok(C::Continue(Zoom { + lo: previous_step, + hi: this_step(), + last: last(), + })); + } + if self.weak_slope(phigrad) { + return Ok(C::Break(self.summary(last()))); + } + if phigrad >= 0.0 { + return Ok(C::Continue(Zoom { + lo: this_step(), + hi: previous_step, + last: last(), + })); + } + // Unfruitful: pick a larger candidate. + let over_horizon = if let Some(ln) = lowest_nonfinite { + ln + } else { + // No non-finite loss has been observed yet: increase candidate naively. + let new = alpha * self.cf.step_increase; + match self.sample(new) { + Ok(sample) => { + // Success: keep bracketing. + first = false; + bracket = Bracketing { + alpha: new, + sample, + previous_step: this_step(), + lowest_nonfinite, + }; + continue 'bracketing; + } + // Or a new candidate above horizon has been found. + Err(E::NonFiniteLoss { loss, vars }) => { + self.log(new, loss, None, &vars, None)?; + new + } + Err(e) => return Err(e), + } + }; + // Watch the horizon to not pick a candidate yielding non-finite loss. + // Bisect up first, then bisect down again if we overshoot. + let lo = alpha; // The largest finite candidate found so far. + let mut hi = over_horizon; // Reduce as we search. + let mut dir = D::Up; // Flip if we overshoot. + let mut current = alpha; + 'binsearch: loop { + let new = lo + 0.5 * (hi - lo); + // The binary search dies against floating point precision + // if no progress has been made. 
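+ // (i.e. the proposed midpoint rounds back onto one of the interval endpoints)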
+ ensure!( + if dir == D::Up { + current < new && new < hi + } else { + lo < new && new < current + }, + DeadBinarySearchErr { alpha: new, summary: self.summary(last()) } + ); + match self.sample(new) { + Ok(sample) => { + // Increase succeeded. + first = false; + bracket = Bracketing { + alpha: new, + sample, + previous_step, + lowest_nonfinite: Some(hi), + }; + continue 'bracketing; + } + Err(E::NonFiniteLoss { loss, vars }) => { + // This candidate was an overshoot, bisect down from now on. + dir = D::Down; + hi = new; + current = new; + self.log(new, loss, None, &vars, None)?; + continue 'binsearch; + } + Err(e) => return Err(e), + } + } + } + } + + fn zoom(&mut self, zoom: Zoom) -> Result<Summary, Error> { + let Zoom { mut lo, mut hi, mut last } = zoom; + loop { + // Cubic interpolation from eq (3.59). + // The minimizer of the cubic is proven to exist within the interval + // containing lo and hi, or it is either endpoint. + let d1 = lo.grad + hi.grad - 3.0 * (lo.phi - hi.phi) / (lo.alpha - hi.alpha); + let d2 = (hi.alpha - lo.alpha).signum() * (d1 * d1 - lo.grad * hi.grad).sqrt(); + let mut a = hi.alpha + - (hi.alpha - lo.alpha) * (hi.grad + d2 - d1) / (hi.grad - lo.grad + 2.0 * d2); + + // Check that this did fall within the range. + let (l, h) = (lo.alpha, hi.alpha); + let (l, h) = if l < h { (l, h) } else { (h, l) }; + ensure!(l <= a && a <= h, BrokenZoomErr { alpha: a, lo: l, hi: h }); + + // If it falls too close from a boundary, bisect instead. + let d = f64::min(a - l, h - a); + if d / (h - l) <= self.cf.bisection_threshold { + a = l + 0.5 * (h - l); + } + ensure!( + l < a && a < h, + DeadZoomErr { alpha: a, summary: self.summary(last) } + ); + + // Evaluate. + let alpha = a; + let Sample { step, x, y, phi } = self.sample(alpha)?; + let (ygrad, grad) = self.differentiate(&x, &y, phi)?; + self.log(alpha, phi, Some(grad), &x, Some(&ygrad))?; + last = Self::last(alpha, step, &x, phi, &ygrad); + + // Catch degeneration. + #[allow(clippy::float_cmp)] + if grad <= self.cf.flat_gradient + && grad == hi.grad + && grad == lo.grad + && hi.phi == lo.phi + && lo.phi == phi + { + return FlatZoomErr { alpha, grad, summary: self.summary(last) }.fail(); + } + + // Decide where to zoom next if necessary. + if phi >= lo.phi || !self.sufficient_decrease(alpha, phi) { + hi = Step { alpha, phi, grad }; + continue; + } + if self.weak_slope(grad) { + return Ok(self.summary(last)); + } + if grad * (hi.alpha - lo.alpha) >= 0.0 { + hi = Step { ..lo }; + } + lo = Step { alpha, phi, grad }; + } + } + + // Sample the objective function at distance 'alpha' in direction 'p'. + // Return all relevant tensors. + // And in particular, the tensor on which to query the backprop + // to possibly get 'phigrad_stepped' later. + // Errors if the sampled loss is non-finite. + fn sample(&mut self, alpha: f64) -> Result<Sample, Error> { + let &mut Self { loc: Location { f, x: x_zero, p, .. }, .. } = self; + + let step = alpha * p; + let x = (x_zero + &step).set_requires_grad(true); + let y = f(&x); + self.n_eval += 1; + let phi = y.to_double(); + + ensure!( + phi.is_finite(), + NonFiniteLossErr { loss: phi, vars: x.detach().copy() }, + ); + + self.best.update(phi, &x); + + Ok(Sample { step, x, y, phi }) + } + + // Differentiate the objective function at distande 'alpha' in direction 'p', + // assuming the object function has just been evaluated on x_stepped. + // Return 'grad_stepped' and 'phigrad_stepped'. + // Error if the gradient is non-finite. 
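+ // Assumes `x` was produced by `sample`, i.e. created with
+ // `set_requires_grad(true)`, so the backward pass populates `x.grad()`.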
+ fn differentiate( + &mut self, + x: &Tensor, + y: &Tensor, + phi: f64, // Just to avoid extracting it again from y. + ) -> Result<(Tensor, f64), Error> { + let Self { loc: Location { p, .. }, .. } = self; + y.backward(); + self.n_diff += 1; + let grad = x.grad(); + ensure!( + grad.is_all_finite(), + NonFiniteGradErr { + grad: grad.detach().copy(), + vars: x.detach().copy(), + loss: phi, + } + ); + let phigrad = p.dot(&grad).to_double(); // ( phi'(alpha) = p.∇f ) + Ok((grad, phigrad)) + } + + fn last(alpha: f64, step: Tensor, x: &Tensor, phi: f64, grad: &Tensor) -> LastStep { + LastStep { + step_size: alpha, + step, + vars_after_step: x.detach(), // Detach to not use it again in subsequent grad calc. + loss_after_step: phi, + grad_after_step: grad.detach(), + } + } + + fn summary(&self, last: LastStep) -> Summary { + Summary { + last_step: Some(last), + n_eval: self.n_eval, + n_diff: self.n_diff, + } + } + + // Log one line on every candidate evaluated, + // typically before getting ready to either change it or yield it. + fn log( + &mut self, + alpha: f64, + phi: f64, // = phi(alpha) + phigs: Option<f64>, // = phi'(alpha) + x: &Tensor, // short for xs = x + alpha * p + grad: Option<&Tensor>, // = ∇f(xs) + ) -> Result<(), Error> { + if self.n_log == 0 || self.id % self.n_log > 0 { + return Ok(()); + } + let file = self.trace.as_mut().unwrap(); + macro_rules! w { + ($fmt:literal, $value:expr) => { + write!(file, $fmt, $value)?; + }; + ($fmt:literal if $opt:expr; $alt:literal) => { + if let Some(value) = $opt { + write!(file, $fmt, value)?; + } else { + write!(file, $alt)?; + } + }; + } + w!("{}", self.id); + w!(",{}", alpha); + w!(",{}", phi); + w!(",{}" if phigs; ","); + for t in [Some(x), grad] { + if let Some(t) = t { + for f in t.iter::<f64>().unwrap() { + w!(",{}", f); + } + } else { + for _ in 0..self.n_vars { + write!(file, ",")?; + } + } + } + writeln!(file)?; + Ok(()) + } +} + +pub(crate) fn write_header(file: &mut File, n_vars: usize) -> Result<(), Error> { + let mut header = ["id", "alpha", "loss", "phigrad"] + .into_iter() + .map(ToString::to_string) + .collect::<Vec<_>>(); + for vec in ["vars", "grad"] { + for i in 0..n_vars { + header.push(format!("{vec}_{i}")); + } + } + writeln!(file, "{}", header.join(","))?; + Ok(()) +} + +// Perform the seach. +// Return step size and step vector meeting the criteria +// along with the evaluation + gradient value at the point stepped to +// and the number of evaluations/differentiations used to find the result. +// Don't return values at the point stepped to +// if the best step size found is null. +pub(crate) fn search<'s, F: Fn(&Tensor) -> Tensor>( + loc: Location<'s, F>, + cf: &Config, + best: &'s mut Best, + (n_log, trace, id): (u64, Option<&'s mut File>, u64), +) -> Result<Summary, Error> { + use Error as E; + // Initialize. + let search = Search { + n_vars: loc.p.size1().unwrap().try_into().unwrap(), + loc, + n_log, + best, + cf, + n_eval: 0, + n_diff: 0, + // Not yet calculated. + trace, + id, + }; + match search.run_search() { + Err(E::DeadBinarySearch { alpha, summary } | E::DeadZoom { alpha, summary }) => { + if alpha <= 0. { + println!("Null step size reached: {alpha}."); + Ok(summary) + } else { + println!("(step {id}) Wolfe range too small for ulps around {alpha}?"); + Ok(summary) + } + } + Err(e @ E::FlatZoom { .. }) => { + println!("{e}"); + let E::FlatZoom { summary, .. 
} = e else { + unreachable!() + }; + Ok(summary) + } + res => res, + } +} + +#[derive(Debug)] +pub struct Summary { + pub(crate) last_step: Option<LastStep>, + pub(crate) n_eval: u64, + pub(crate) n_diff: u64, +} + +#[derive(Debug)] +pub(crate) struct LastStep { + pub(crate) step_size: f64, + pub(crate) step: Tensor, + pub(crate) vars_after_step: Tensor, + pub(crate) loss_after_step: f64, + pub(crate) grad_after_step: Tensor, +} + +#[derive(Debug, Snafu)] +#[snafu(context(suffix(Err)))] +pub enum Error { + #[snafu(display( + "Could not find a step size yielding a finite loss value, \ + starting from point:\n{x}\nin direction:\n{p}" + ))] + NoStepSizeYieldingFiniteLoss { x: Tensor, p: Tensor }, + #[snafu(display( + "Could not find a step size meeting Wolfe lower bound, \ + starting from point:\n{x}\nin direction:\n{p}" + ))] + NoStepLargeEnough { x: Tensor, p: Tensor }, + #[snafu(display("Obtained non-finite loss ({loss}) with these variables:\n{vars:?}"))] + NonFiniteLoss { loss: f64, vars: Tensor }, + #[snafu(display( + "Obtained non-finite gradient for loss ({loss}):\ngradient: {grad:?}\nvariables: {vars:?}" + ))] + NonFiniteGrad { + grad: Tensor, + vars: Tensor, + loss: f64, + }, + #[snafu(display("Binary search could not reduce the search range further around {alpha}."))] + DeadBinarySearch { alpha: f64, summary: Summary }, + #[snafu(display( + "Cubic interpolation yielded a minimizer outside the zooming range:\ + (lo): {lo} :: {alpha} :: {hi} (hi)" + ))] + BrokenZoom { alpha: f64, hi: f64, lo: f64 }, + #[snafu(display("Zoom phase could not reduce the search range further around {alpha}."))] + DeadZoom { alpha: f64, summary: Summary }, + #[snafu(display( + "The zoom step length interval has possibly become \ + numerically flat around {alpha} (gradient: {grad:e})." + ))] + FlatZoom { + alpha: f64, + grad: f64, + summary: Summary, + }, + #[snafu(transparent)] + Trace { source: io::Error }, +} diff --git a/src/tree.rs b/src/tree.rs index 76db8b30616601ca9096489d453b3d25229a78e4..4eea031bae3cd17d41ad42dae37ca4b9569b6336 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -95,16 +95,6 @@ impl<B, IN, TN> Node<B, IN, TN> { N::Terminal(n) => &n.branch, } } - fn is_leaf(&self) -> bool { - use Node as N; - match self { - N::Internal(_) => false, - N::Terminal(_) => true, - } - } - fn is_internal(&self) -> bool { - !self.is_leaf() - } pub(crate) fn children(&self) -> Option<&[usize; 2]> { use Node as N; match self { diff --git a/src/tree/filter_node_to_leaves.rs b/src/tree/filter_node_to_leaves.rs index f35591f51e12cf0e96016a341b21a77d15ebedcf..2f916199c4b9e389274d62397af62438566f1b90 100644 --- a/src/tree/filter_node_to_leaves.rs +++ b/src/tree/filter_node_to_leaves.rs @@ -18,6 +18,7 @@ impl<B, IN, TN> Tree<B, IN, TN> { FilterNodeToLeaves::new(self, i_start, init, accumulator, terminator) } + #[cfg(test)] pub(crate) fn filter_root_to_leaves<'n, A, T, I, V>( &'n self, init: I, diff --git a/src/tree/node_to_leaves.rs b/src/tree/node_to_leaves.rs index 484bdaf9056c8b56f3799ddacd953414f5bed52d..96a29faf47c27887dcc5ff4aab164d5e5c369237 100644 --- a/src/tree/node_to_leaves.rs +++ b/src/tree/node_to_leaves.rs @@ -34,6 +34,7 @@ impl<B, IN, TN> Tree<B, IN, TN> { } // Start from the root. + #[cfg(test)] pub(crate) fn root_to_leaves<'t, A, T, I, V>( &'t self, init: I, @@ -47,6 +48,7 @@ impl<B, IN, TN> Tree<B, IN, TN> { { self.node_to_leaves(self.root(), init, accumulator, terminator) } + } // Accumulating iterator from root to leaves. 
diff --git a/src/tree/parse.rs b/src/tree/parse.rs index 856bc2620630261460554f6eb196b2da8632d875..1df2668c412bbada2404c9562e446528b6fd716b 100644 --- a/src/tree/parse.rs +++ b/src/tree/parse.rs @@ -97,10 +97,7 @@ impl<'i> Lexer<'i> { branch: B::default(), // Resolved later after the closing parenthesis. parent: i_parent, })); - unresolved.push(Unresolved { - i_parent: i_new_node, - i_child: 0, - }); + unresolved.push(Unresolved { i_parent: i_new_node, i_child: 0 }); continue 'n; } @@ -251,21 +248,22 @@ impl<'i> Lexer<'i> { } } -impl<B, IN, TN> Tree<B, IN, TN> -where - B: Default, - IN: Default, -{ - pub(crate) fn from_input<'i>( - input: &'i str, - reader: &mut impl TreeDataReader<'i, B, IN, TN>, - ) -> Result<Self, LexerError> { - Lexer::new(input).read_tree(reader) - } -} - #[cfg(test)] pub(crate) mod tests { + + impl<B, IN, TN> Tree<B, IN, TN> + where + B: Default, + IN: Default, + { + pub(crate) fn from_input<'i>( + input: &'i str, + reader: &mut impl TreeDataReader<'i, B, IN, TN>, + ) -> Result<Self, LexerError> { + Lexer::new(input).read_tree(reader) + } + } + use std::{fmt::Display, str::FromStr}; use super::*; @@ -328,7 +326,7 @@ pub(crate) mod tests { panic!("Unexpected successful parsing while expecting the following error: \ \n {}", $error); } - Err(e) => { assert_eq!($error, e.ref_message()); } + Err(e) => { assert_eq!($error, &format!("{e}")); } } };