1use std::{
2    fs,
3    io::{self, Write as _},
4    sync::{Mutex, atomic::AtomicU32},
5    time::{Duration, Instant},
6};
7
8use flux_config as config;
9use itertools::Itertools;
10use rustc_hir::def_id::{DefId, LOCAL_CRATE, LocalDefId};
11use rustc_middle::ty::TyCtxt;
12use serde::Serialize;
13
14use crate::FixpointQueryKind;
15
/// Bold style used for the `summary.` label printed by `print_summary`.
const BOLD: anstyle::Style = anstyle::Style::new().bold();
/// Bright-black ("grey") foreground style.
// NOTE(review): in the visible code this is only emitted as `{GREY:#}` (its reset sequence) at
// the end of the summary line, with no matching `{GREY}` opener — confirm whether part of the
// summary was meant to be rendered grey.
const GREY: anstyle::Style = anstyle::AnsiColor::BrightBlack.on_default();
18
19pub fn print_summary(total_time: Duration) -> io::Result<()> {
20    let mut stderr = anstream::Stderr::always(std::io::stderr());
21    writeln!(
22        &mut stderr,
23        "{BOLD}summary.{BOLD:#} {} functions processed: {} checked; {} trusted; {} ignored; {} cached; {} trivial. {} constraints generated: {} errors. Finished in {}{GREY:#}",
24        METRICS.get(Metric::FnTotal),
25        METRICS.get(Metric::FnChecked),
26        METRICS.get(Metric::FnTrusted),
27        METRICS.get(Metric::FnIgnored),
28        METRICS.get(Metric::FnCached),
29        METRICS.get(Metric::FnTrivial),
30        METRICS.get(Metric::CsTotal),
31        METRICS.get(Metric::CsError),
32        fmt_duration(total_time),
33    )
34}
35
/// Global counters reported by `print_summary`.
static METRICS: Metrics = Metrics::new();

/// A counter tracked in [`METRICS`]. The discriminant is used directly as the index into
/// `Metrics::counts`.
#[repr(u8)]
pub enum Metric {
    /// Reported in the summary as "functions processed".
    FnTotal,
    /// Reported in the summary as "trusted".
    FnTrusted,
    /// Reported in the summary as "ignored".
    FnIgnored,
    /// Reported in the summary as "checked".
    FnChecked,
    /// Reported in the summary as "cached".
    FnCached,
    /// Reported in the summary as "trivial".
    FnTrivial,
    /// Reported in the summary as "constraints generated".
    CsTotal,
    /// Reported in the summary as "errors" (next to the constraint count).
    CsError,
}

impl Metric {
    /// Number of `Metric` variants, i.e. the length of `Metrics::counts`.
    ///
    /// Derived from the last variant so the array size cannot silently drift from the enum:
    /// `CsError` must stay the last variant (add new variants before it, or update this).
    const COUNT: usize = Metric::CsError as usize + 1;
}

/// One relaxed atomic counter per [`Metric`] variant.
struct Metrics {
    counts: [AtomicU32; Metric::COUNT],
}

impl Metrics {
    /// Creates a set of counters, all starting at zero. `const` so it can initialize a `static`.
    const fn new() -> Self {
        Self { counts: [const { AtomicU32::new(0) }; Metric::COUNT] }
    }

    /// Adds `val` to the counter for `metric` (wrapping on overflow, as `fetch_add` does).
    fn incr(&self, metric: Metric, val: u32) {
        self.counts[metric as usize].fetch_add(val, std::sync::atomic::Ordering::Relaxed);
    }

    /// Reads the current value of the counter for `metric`.
    fn get(&self, metric: Metric) -> u32 {
        self.counts[metric as usize].load(std::sync::atomic::Ordering::Relaxed)
    }
}
86
87pub fn incr_metric(metric: Metric, val: u32) {
88    METRICS.incr(metric, val);
89}
90
91pub fn incr_metric_if(cond: bool, metric: Metric) {
92    if cond {
93        incr_metric(metric, 1);
94    }
95}
96
/// Timing entries appended by `time_it` and drained (via `std::mem::take`) by
/// `print_and_dump_timings`.
static TIMINGS: Mutex<Vec<Entry>> = Mutex::new(Vec::new());
98
/// What a recorded timing entry measured; decides how `print_and_dump_timings` reports it.
pub enum TimingKind {
    /// Reported as the total running time. If recorded more than once, the last entry wins.
    Total,
    /// Checking of a local function, reported under its def path (`tcx.def_path_str`).
    CheckFn(LocalDefId),
    /// A fixpoint query, reported under the key returned by `FixpointQueryKind::task_key`.
    FixpointQuery(DefId, FixpointQueryKind),
}
107
/// JSON payload written by `dump_timings` to `<log_dir>/<crate>-timings.json`.
#[derive(Serialize)]
struct TimingsDump {
    /// Duration of the `TimingKind::Total` entry in milliseconds (zero if none was recorded).
    total: ms,
    /// Per-function checking times; built slowest-first by `print_and_dump_timings`.
    functions: Vec<FuncTiming>,
    /// Per-query fixpoint times; built slowest-first by `print_and_dump_timings`.
    queries: Vec<QueryTiming>,
}
117
/// Time spent checking a single function, as serialized in the timings dump.
#[derive(Serialize)]
struct FuncTiming {
    /// The function's def path as rendered by `tcx.def_path_str`.
    def_path: String,
    /// Elapsed time in whole milliseconds.
    time_ms: ms,
}
123
/// Time spent on a single fixpoint query, as serialized in the timings dump.
#[derive(Serialize)]
struct QueryTiming {
    /// Key identifying the query, from `FixpointQueryKind::task_key`.
    task_key: String,
    /// Elapsed time in whole milliseconds.
    time_ms: ms,
}
129
/// Returns (a copy of) the second component of a borrowed pair. Used as a sort key and to
/// project durations out of `(name, duration)` pairs.
fn snd<A, B: Copy>(pair: &(A, B)) -> B {
    pair.1
}
133
134pub fn print_and_dump_timings(tcx: TyCtxt) -> io::Result<()> {
135    if !config::timings() {
136        return Ok(());
137    }
138
139    let timings = std::mem::take(&mut *TIMINGS.lock().unwrap());
140    let mut functions = vec![];
141    let mut queries = vec![];
142    let mut total = Duration::from_secs(0);
143    for timing in timings {
144        match timing.kind {
145            TimingKind::CheckFn(local_def_id) => {
146                let def_path = tcx.def_path_str(local_def_id);
147                functions.push((def_path, timing.duration));
148            }
149            TimingKind::FixpointQuery(def_id, kind) => {
150                let key = kind.task_key(tcx, def_id);
151                queries.push((key, timing.duration));
152            }
153            TimingKind::Total => {
154                total = timing.duration;
156            }
157        }
158    }
159    functions.sort_by_key(snd);
160    functions.reverse();
161
162    queries.sort_by_key(snd);
163    queries.reverse();
164
165    print_report(&functions, total);
166    dump_timings(
167        tcx,
168        TimingsDump {
169            total: ms(total),
170            functions: functions
171                .into_iter()
172                .map(|(def_path, time)| FuncTiming { def_path, time_ms: ms(time) })
173                .collect(),
174            queries: queries
175                .into_iter()
176                .map(|(task_key, time)| QueryTiming { task_key, time_ms: ms(time) })
177                .collect(),
178        },
179    )
180}
181
182fn print_report(functions: &[(String, Duration)], total: Duration) {
183    let stats = stats(&functions.iter().map(snd).collect_vec());
184    eprintln!();
185    eprintln!("───────────────────── Timing Report ────────────────────────");
186    eprintln!("Total running time: {:>40}", fmt_duration(total));
187    eprintln!("Functions checked:  {:>40}", stats.count);
188    eprintln!("Min:                {:>40}", fmt_duration(stats.min));
189    eprintln!("Max:                {:>40}", fmt_duration(stats.max));
190    eprintln!("Mean:               {:>40}", fmt_duration(stats.mean));
191    eprintln!("Std. Dev.:          {:>40}", fmt_duration(stats.standard_deviation));
192
193    let top5 = functions.iter().take(5).cloned().collect_vec();
194    if !top5.is_empty() {
195        eprintln!("────────────────────────────────────────────────────────────");
196        eprintln!("Top 5 Functions ");
197        for (def_path, duration) in top5 {
198            let len = def_path.len();
199            if len > 46 {
200                eprintln!(
201                    "• …{} {:>width$}",
202                    &def_path[len - 46..],
203                    fmt_duration(duration),
204                    width = 10
205                );
206            } else {
207                eprintln!(
208                    "• {def_path} {:>width$}",
209                    fmt_duration(duration),
210                    width = 60 - def_path.len() - 3
211                );
212            }
213        }
214    }
215    eprintln!("────────────────────────────────────────────────────────────");
216}
217
218fn dump_timings(tcx: TyCtxt, timings: TimingsDump) -> io::Result<()> {
219    let crate_name = tcx.crate_name(LOCAL_CRATE);
220    fs::create_dir_all(config::log_dir())?;
221    let path = config::log_dir().join(format!("{crate_name}-timings.json"));
222    let mut file = fs::File::create(path)?;
223    serde_json::to_writer(&mut file, &timings)?;
224    Ok(())
225}
226
227pub fn time_it<R>(kind: TimingKind, f: impl FnOnce() -> R) -> R {
228    if !config::timings() {
229        return f();
230    }
231    let start = Instant::now();
232    let r = f();
233    TIMINGS
234        .lock()
235        .unwrap()
236        .push(Entry { duration: start.elapsed(), kind });
237    r
238}
239
/// Computes count, min, max, mean and population standard deviation of `durations`.
/// Returns all-zero stats for an empty slice.
fn stats(durations: &[Duration]) -> TimingStats {
    let (Some(&min), Some(&max)) = (durations.iter().min(), durations.iter().max()) else {
        return TimingStats::default();
    };
    // More than `u32::MAX` samples is not expected here.
    let count = durations.len() as u32;
    let mean = durations.iter().sum::<Duration>() / count;

    // Compute the variance in fractional seconds rather than truncated whole milliseconds, so
    // that sub-millisecond spreads (common for small functions) are not lost.
    let mean_secs = mean.as_secs_f64();
    let variance = durations
        .iter()
        .map(|d| {
            let diff = d.as_secs_f64() - mean_secs;
            diff * diff
        })
        .sum::<f64>()
        / f64::from(count);
    let standard_deviation = Duration::from_secs_f64(variance.sqrt());

    TimingStats { count, max, min, mean, standard_deviation }
}

/// Summary statistics over a set of durations, as computed by [`stats`].
#[derive(Default)]
struct TimingStats {
    /// Number of samples.
    count: u32,
    /// Largest sample.
    max: Duration,
    /// Smallest sample.
    min: Duration,
    /// Arithmetic mean of the samples.
    mean: Duration,
    /// Population standard deviation (divides by `count`, not `count - 1`).
    standard_deviation: Duration,
}
271
/// One measurement recorded by `time_it`.
struct Entry {
    /// Elapsed time, measured with `Instant::elapsed`.
    duration: Duration,
    /// What was measured.
    kind: TimingKind,
}
276
/// A `Duration` that serializes as its whole number of milliseconds (a `u128`, via the
/// `From<ms> for u128` impl selected by `#[serde(into = "u128")]`).
// Lowercase name presumably chosen so fields read like a unit (`time_ms: ms`); hence the allow.
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Serialize)]
#[serde(into = "u128")]
struct ms(Duration);
281
282impl From<ms> for u128 {
283    fn from(value: ms) -> Self {
284        value.0.as_millis()
285    }
286}
287
/// Formats a duration for humans, choosing the unit by magnitude.
///
/// Below one minute, a single unit is used: whole nanoseconds, or µs/ms/s with two decimals.
/// From one minute on, the whole seconds are split into `m`/`h`/`d` components. Leading
/// components that would be zero are omitted (e.g. `1m 5s`, `2h 0m 7s`, `1d 3h 0m 0s`).
fn fmt_duration(duration: Duration) -> String {
    let nanos = duration.as_nanos();
    match nanos {
        0..=999 => format!("{nanos}ns"),
        1_000..=999_999 => format!("{:.2}µs", nanos as f64 / 1_000.0),
        1_000_000..=999_999_999 => format!("{:.2}ms", nanos as f64 / 1_000_000.0),
        1_000_000_000..=59_999_999_999 => format!("{:.2}s", nanos as f64 / 1_000_000_000.0),
        _ => {
            let seconds = duration.as_secs();
            let (minutes, seconds_remainder) = (seconds / 60, seconds % 60);
            let (hours, minutes_remainder) = (minutes / 60, minutes % 60);
            let (days, hours_remainder) = (hours / 24, hours % 24);
            if minutes < 60 {
                format!("{minutes}m {seconds_remainder}s")
            } else if hours < 24 {
                format!("{hours}h {minutes_remainder}m {seconds_remainder}s")
            } else {
                format!("{days}d {hours_remainder}h {minutes_remainder}m {seconds_remainder}s")
            }
        }
    }
}