flux_driver/collector/
annot_stats.rs

1use std::{collections::HashMap, fs, io};
2
3use flux_config as config;
4use rustc_ast::{DelimArgs, tokenstream::TokenTree};
5use rustc_hir::{AttrArgs, def_id::LOCAL_CRATE};
6use rustc_middle::ty::TyCtxt;
7use rustc_span::{Span, source_map::SourceMap};
8use serde::Serialize;
9
/// Aggregated statistics about the attributes recorded through [`Stats::add`].
///
/// The struct derives `Serialize`; [`Stats::save`] writes it out as JSON.
#[derive(Default, Serialize)]
pub struct Stats {
    /// The number of times an attribute appears, keyed by attribute name, e.g., how many times
    /// `flux::trusted` appears.
    attr_count: HashMap<String, usize>,
    /// The number of lines taken by each type of attribute, keyed by attribute name, e.g., the
    /// sum of all lines taken by `flux::sig` annotations.
    loc_per_attr: HashMap<String, usize>,
    /// The sum over all values in `loc_per_attr`, maintained incrementally as lines are added.
    loc: usize,
}
20
21impl Stats {
22    pub fn save(&self, tcx: TyCtxt) -> io::Result<()> {
23        fs::create_dir_all(config::log_dir())?;
24        let crate_name = tcx.crate_name(LOCAL_CRATE);
25        let path = config::log_dir().join(format!("{crate_name}-annots.json"));
26        let mut file = fs::File::create(path)?;
27        serde_json::to_writer(&mut file, self)?;
28        Ok(())
29    }
30
31    pub fn add(&mut self, tcx: TyCtxt, name: &str, args: &AttrArgs) {
32        let sm = tcx.sess.source_map();
33        self.increase_count(name);
34        match args {
35            AttrArgs::Empty => {
36                self.increase_loc(name, 1);
37            }
38            AttrArgs::Delimited(delim_args) => {
39                self.increase_loc(name, count_lines(sm, delim_args));
40            }
41            AttrArgs::Eq { .. } => {}
42        }
43    }
44
45    fn increase_count(&mut self, name: &str) {
46        self.attr_count
47            .raw_entry_mut()
48            .from_key(name)
49            .and_modify(|_, v| *v += 1)
50            .or_insert_with(|| (name.to_string(), 1));
51    }
52
53    fn increase_loc(&mut self, name: &str, loc: usize) {
54        self.loc_per_attr
55            .raw_entry_mut()
56            .from_key(name)
57            .and_modify(|_, v| *v += loc)
58            .or_insert_with(|| (name.to_string(), loc));
59        self.loc += loc;
60    }
61}
62
63fn count_lines(sm: &SourceMap, delim_args: &DelimArgs) -> usize {
64    fn go<'a>(
65        sm: &SourceMap,
66        line_set: &mut IntervalSet,
67        tokens: impl Iterator<Item = &'a TokenTree>,
68    ) {
69        for t in tokens {
70            match t {
71                TokenTree::Token(token, _) => {
72                    let info = get_lines(sm, token.span);
73                    line_set.insert(info.start_line, info.end_line);
74                }
75                TokenTree::Delimited(delim_span, _, _, token_stream) => {
76                    let open_info = get_lines(sm, delim_span.open);
77                    let close_info = get_lines(sm, delim_span.close);
78                    line_set.insert(open_info.start_line, open_info.end_line);
79                    line_set.insert(close_info.start_line, close_info.end_line);
80                    go(sm, line_set, token_stream.iter());
81                }
82            }
83        }
84    }
85    let mut line_set = IntervalSet::new();
86    go(sm, &mut line_set, delim_args.tokens.iter());
87    let mut lines = 0;
88    for (start, end) in line_set.iter_intervals() {
89        lines += end - start + 1;
90    }
91    lines
92}
93
/// Line and column extent of a span, as produced by `get_lines`.
// Within this file only `start_line` and `end_line` are read; the column fields are stored but
// never used, which is why `dead_code` is expected here.
#[expect(dead_code)]
struct LineInfo {
    /// Line on which the span starts.
    start_line: usize,
    /// Column at which the span starts.
    start_col: usize,
    /// Line on which the span ends.
    end_line: usize,
    /// Column at which the span ends.
    end_col: usize,
}
101
102fn get_lines(sm: &SourceMap, span: Span) -> LineInfo {
103    let lines = sm.span_to_location_info(span);
104    LineInfo { start_line: lines.1, start_col: lines.2, end_line: lines.3, end_col: lines.4 }
105}
106
107struct IntervalSet {
108    map: Vec<(usize, usize)>,
109}
110
111impl IntervalSet {
112    fn new() -> Self {
113        Self { map: vec![] }
114    }
115
116    fn insert(&mut self, start: usize, end: usize) {
117        if start > end {
118            return;
119        }
120
121        // This condition looks a bit weird, but actually makes sense.
122        //
123        // if r.0 == end + 1, then we're actually adjacent, so we want to
124        // continue to the next range. We're looking here for the first
125        // range which starts *non-adjacently* to our end.
126        let next = self.map.partition_point(|r| r.0 <= end + 1);
127
128        if let Some(right) = next.checked_sub(1) {
129            let (prev_start, prev_end) = self.map[right];
130            if prev_end + 1 >= start {
131                // If the start for the inserted range is adjacent to the
132                // end of the previous, we can extend the previous range.
133                if start < prev_start {
134                    // The first range which ends *non-adjacently* to our start.
135                    // And we can ensure that left <= right.
136
137                    let left = self.map.partition_point(|l| l.1 + 1 < start);
138                    let min = std::cmp::min(self.map[left].0, start);
139                    let max = std::cmp::max(prev_end, end);
140                    self.map[right] = (min, max);
141                    if left != right {
142                        self.map.drain(left..right);
143                    }
144                } else {
145                    // We overlap with the previous range, increase it to
146                    // include us.
147                    //
148                    // Make sure we're actually going to *increase* it though --
149                    // it may be that end is just inside the previously existing
150                    // set.
151                    if end > prev_end {
152                        self.map[right].1 = end;
153                    }
154                }
155            } else {
156                // Otherwise, we don't overlap, so just insert
157                self.map.insert(right + 1, (start, end));
158            }
159        } else if self.map.is_empty() {
160            // Quite common in practice, and expensive to call memcpy
161            // with length zero.
162            self.map.push((start, end));
163        } else {
164            self.map.insert(next, (start, end));
165        }
166    }
167
168    fn iter_intervals(&self) -> impl Iterator<Item = (usize, usize)> {
169        self.map.iter().copied()
170    }
171}