flux_desugar/resolver/
refinement_resolver.rs

1use std::ops::ControlFlow;
2
3use flux_common::index::IndexGen;
4use flux_errors::Errors;
5use flux_middle::{
6    ResolverOutput,
7    fhir::{self, Res},
8};
9use flux_syntax::{
10    surface::{self, Ident, NodeId, visit::Visitor as _},
11    symbols::sym,
12    walk_list,
13};
14use rustc_data_structures::{
15    fx::{FxIndexMap, FxIndexSet, IndexEntry},
16    unord::UnordMap,
17};
18use rustc_hash::FxHashMap;
19use rustc_hir::def::{
20    DefKind,
21    Namespace::{TypeNS, ValueNS},
22};
23use rustc_middle::ty::TyCtxt;
24use rustc_span::{ErrorGuaranteed, Symbol};
25
26use super::{CrateResolver, Segment};
27
28type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
29
/// The kind of lexical scope opened while visiting surface syntax.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub(crate) enum ScopeKind {
    /// Scope opened around a whole function signature.
    FnInput,
    /// Scope opened around a function's output.
    FnOutput,
    /// Scope opened around an enum variant.
    Variant,
    /// Any other scope.
    Misc,
    /// Scope opened around the input of an `Fn*` trait reference.
    FnTraitInput,
}

impl ScopeKind {
    /// Whether name lookup must stop at a scope of this kind instead of
    /// continuing into the enclosing scopes.
    fn is_barrier(self) -> bool {
        match self {
            ScopeKind::FnInput | ScopeKind::Variant => true,
            ScopeKind::FnOutput | ScopeKind::Misc | ScopeKind::FnTraitInput => false,
        }
    }
}
44
45/// Parameters used during gathering.
46#[derive(Debug, Clone, Copy)]
47struct ParamRes(fhir::ParamKind, NodeId);
48
49impl ParamRes {
50    fn kind(self) -> fhir::ParamKind {
51        self.0
52    }
53
54    fn param_id(self) -> NodeId {
55        self.1
56    }
57}
58
/// A visitor over surface syntax that is notified when scopes are entered and
/// exited and when scope-relevant nodes are reached.
///
/// Implementors are driven through [`ScopedVisitorWrapper`] (see [`ScopedVisitor::wrap`]),
/// whose `Visitor` implementation decides where scopes start and end and
/// forwards to the `on_*` hooks below. Every hook defaults to a no-op.
pub(crate) trait ScopedVisitor: Sized {
    /// Whether `segment` resolves to `Box`. The wrapper visits the first
    /// generic argument of such a segment without opening a new scope.
    fn is_box(&self, segment: &surface::PathSegment) -> bool;
    /// Called before a scope of the given `kind` is visited. Returning
    /// `ControlFlow::Break` makes the wrapper skip the contents of the scope
    /// (and the matching `exit_scope` call).
    fn enter_scope(&mut self, kind: ScopeKind) -> ControlFlow<()>;
    /// Called after the contents of a scope have been visited.
    fn exit_scope(&mut self) {}

    /// Wraps this visitor so it can be used as a `surface::visit::Visitor`.
    fn wrap(self) -> ScopedVisitorWrapper<Self> {
        ScopedVisitorWrapper(self)
    }

    /// Called for binders introduced implicitly: function inputs that bind a
    /// name or a location, and `@`/`#` binders in refinement arguments.
    fn on_implicit_param(&mut self, _ident: Ident, _kind: fhir::ParamKind, _node_id: NodeId) {}
    /// Called for every generic parameter before it is walked.
    fn on_generic_param(&mut self, _param: &surface::GenericParam) {}
    /// Called for every refinement parameter before it is walked, and for the
    /// binder of an existential type.
    fn on_refine_param(&mut self, _param: &surface::RefineParam) {}
    /// Called right after entering the scope of an enum variant.
    fn on_enum_variant(&mut self, _variant: &surface::VariantDef) {}
    /// Called right after entering the input scope of an `Fn*` trait
    /// reference; `_node_id` is the node id of the trait reference.
    fn on_fn_trait_input(&mut self, _in_arg: &surface::GenericArg, _node_id: NodeId) {}
    /// Called right after entering the scope of a function signature.
    fn on_fn_sig(&mut self, _fn_sig: &surface::FnSig) {}
    /// Called right after entering the scope of a function output.
    fn on_fn_output(&mut self, _output: &surface::FnOutput) {}
    /// Called for the location mentioned in an `Ensures::Type` clause.
    fn on_loc(&mut self, _loc: Ident, _node_id: NodeId) {}
    /// Called for every path appearing in an expression.
    fn on_path(&mut self, _path: &surface::ExprPath) {}
    /// Called for every base sort before it is walked.
    fn on_base_sort(&mut self, _sort: &surface::BaseSort) {}
}
79
80pub(crate) struct ScopedVisitorWrapper<V>(V);
81
82impl<V: ScopedVisitor> ScopedVisitorWrapper<V> {
83    fn with_scope(&mut self, kind: ScopeKind, f: impl FnOnce(&mut Self)) {
84        let scope = self.0.enter_scope(kind);
85        if let ControlFlow::Continue(_) = scope {
86            f(self);
87            self.0.exit_scope();
88        }
89    }
90}
91
92impl<V> std::ops::Deref for ScopedVisitorWrapper<V> {
93    type Target = V;
94
95    fn deref(&self) -> &Self::Target {
96        &self.0
97    }
98}
99impl<V> std::ops::DerefMut for ScopedVisitorWrapper<V> {
100    fn deref_mut(&mut self) -> &mut Self::Target {
101        &mut self.0
102    }
103}
104
105impl<V: ScopedVisitor> surface::visit::Visitor for ScopedVisitorWrapper<V> {
106    fn visit_trait_assoc_reft(&mut self, assoc_reft: &surface::TraitAssocReft) {
107        self.with_scope(ScopeKind::Misc, |this| {
108            surface::visit::walk_trait_assoc_reft(this, assoc_reft);
109        });
110    }
111
112    fn visit_impl_assoc_reft(&mut self, assoc_reft: &surface::ImplAssocReft) {
113        self.with_scope(ScopeKind::Misc, |this| {
114            surface::visit::walk_impl_assoc_reft(this, assoc_reft);
115        });
116    }
117
118    fn visit_qualifier(&mut self, qualifier: &surface::Qualifier) {
119        self.with_scope(ScopeKind::Misc, |this| {
120            surface::visit::walk_qualifier(this, qualifier);
121        });
122    }
123
124    fn visit_defn(&mut self, defn: &surface::SpecFunc) {
125        self.with_scope(ScopeKind::Misc, |this| {
126            surface::visit::walk_defn(this, defn);
127        });
128    }
129
130    fn visit_primop_prop(&mut self, prop: &surface::PrimOpProp) {
131        self.with_scope(ScopeKind::Misc, |this| {
132            surface::visit::walk_primop_prop(this, prop);
133        });
134    }
135
136    fn visit_generic_param(&mut self, param: &surface::GenericParam) {
137        self.on_generic_param(param);
138        surface::visit::walk_generic_param(self, param);
139    }
140
141    fn visit_refine_param(&mut self, param: &surface::RefineParam) {
142        self.on_refine_param(param);
143        surface::visit::walk_refine_param(self, param);
144    }
145
146    fn visit_ty_alias(&mut self, ty_alias: &surface::TyAlias) {
147        self.with_scope(ScopeKind::Misc, |this| {
148            surface::visit::walk_ty_alias(this, ty_alias);
149        });
150    }
151
152    fn visit_struct_def(&mut self, struct_def: &surface::StructDef) {
153        self.with_scope(ScopeKind::Misc, |this| {
154            surface::visit::walk_struct_def(this, struct_def);
155        });
156    }
157
158    fn visit_enum_def(&mut self, enum_def: &surface::EnumDef) {
159        self.with_scope(ScopeKind::Misc, |this| {
160            surface::visit::walk_enum_def(this, enum_def);
161        });
162    }
163
164    fn visit_variant(&mut self, variant: &surface::VariantDef) {
165        self.with_scope(ScopeKind::Variant, |this| {
166            this.on_enum_variant(variant);
167            surface::visit::walk_variant(this, variant);
168        });
169    }
170
171    fn visit_trait_ref(&mut self, trait_ref: &surface::TraitRef) {
172        match trait_ref.as_fn_trait_ref() {
173            Some((in_arg, out_arg)) => {
174                self.with_scope(ScopeKind::FnTraitInput, |this| {
175                    this.on_fn_trait_input(in_arg, trait_ref.node_id);
176                    surface::visit::walk_generic_arg(this, in_arg);
177                    this.with_scope(ScopeKind::Misc, |this| {
178                        surface::visit::walk_generic_arg(this, out_arg);
179                    });
180                });
181            }
182            None => {
183                self.with_scope(ScopeKind::Misc, |this| {
184                    surface::visit::walk_trait_ref(this, trait_ref);
185                });
186            }
187        }
188    }
189
190    fn visit_variant_ret(&mut self, ret: &surface::VariantRet) {
191        self.with_scope(ScopeKind::Misc, |this| {
192            surface::visit::walk_variant_ret(this, ret);
193        });
194    }
195
196    fn visit_generics(&mut self, generics: &surface::Generics) {
197        self.with_scope(ScopeKind::Misc, |this| {
198            surface::visit::walk_generics(this, generics);
199        });
200    }
201
202    fn visit_fn_sig(&mut self, fn_sig: &surface::FnSig) {
203        self.with_scope(ScopeKind::FnInput, |this| {
204            this.on_fn_sig(fn_sig);
205            surface::visit::walk_fn_sig(this, fn_sig);
206        });
207    }
208
209    fn visit_fn_output(&mut self, output: &surface::FnOutput) {
210        self.with_scope(ScopeKind::FnOutput, |this| {
211            this.on_fn_output(output);
212            surface::visit::walk_fn_output(this, output);
213        });
214    }
215
216    fn visit_fn_input(&mut self, arg: &surface::FnInput) {
217        match arg {
218            surface::FnInput::Constr(bind, _, _, node_id) => {
219                self.on_implicit_param(*bind, fhir::ParamKind::Colon, *node_id);
220            }
221            surface::FnInput::StrgRef(loc, _, node_id) => {
222                self.on_implicit_param(*loc, fhir::ParamKind::Loc, *node_id);
223            }
224            surface::FnInput::Ty(bind, ty, node_id) => {
225                if let &Some(bind) = bind {
226                    let param_kind = if let surface::TyKind::Base(_) = &ty.kind {
227                        fhir::ParamKind::Colon
228                    } else {
229                        fhir::ParamKind::Error
230                    };
231                    self.on_implicit_param(bind, param_kind, *node_id);
232                }
233            }
234        }
235        surface::visit::walk_fn_input(self, arg);
236    }
237
238    fn visit_ensures(&mut self, constraint: &surface::Ensures) {
239        if let surface::Ensures::Type(loc, _, node_id) = constraint {
240            self.on_loc(*loc, *node_id);
241        }
242        surface::visit::walk_ensures(self, constraint);
243    }
244
245    fn visit_refine_arg(&mut self, arg: &surface::RefineArg) {
246        match arg {
247            surface::RefineArg::Bind(ident, kind, _, node_id) => {
248                let kind = match kind {
249                    surface::BindKind::At => fhir::ParamKind::At,
250                    surface::BindKind::Pound => fhir::ParamKind::Pound,
251                };
252                self.on_implicit_param(*ident, kind, *node_id);
253            }
254            surface::RefineArg::Abs(..) => {
255                self.with_scope(ScopeKind::Misc, |this| {
256                    surface::visit::walk_refine_arg(this, arg);
257                });
258            }
259            surface::RefineArg::Expr(expr) => self.visit_expr(expr),
260        }
261    }
262
263    fn visit_path(&mut self, path: &surface::Path) {
264        for arg in &path.refine {
265            self.with_scope(ScopeKind::Misc, |this| this.visit_refine_arg(arg));
266        }
267        walk_list!(self, visit_path_segment, &path.segments);
268    }
269
270    fn visit_path_segment(&mut self, segment: &surface::PathSegment) {
271        let is_box = self.is_box(segment);
272        for (i, arg) in segment.args.iter().enumerate() {
273            if is_box && i == 0 {
274                self.visit_generic_arg(arg);
275            } else {
276                self.with_scope(ScopeKind::Misc, |this| this.visit_generic_arg(arg));
277            }
278        }
279    }
280
281    fn visit_ty(&mut self, ty: &surface::Ty) {
282        let node_id = ty.node_id;
283        match &ty.kind {
284            surface::TyKind::Exists { bind, .. } => {
285                self.with_scope(ScopeKind::Misc, |this| {
286                    let param = surface::RefineParam {
287                        ident: *bind,
288                        mode: None,
289                        sort: surface::Sort::Infer,
290                        node_id,
291                        span: bind.span,
292                    };
293                    this.on_refine_param(&param);
294                    surface::visit::walk_ty(this, ty);
295                });
296            }
297            surface::TyKind::GeneralExists { .. } => {
298                self.with_scope(ScopeKind::Misc, |this| {
299                    surface::visit::walk_ty(this, ty);
300                });
301            }
302            surface::TyKind::Array(..) => {
303                self.with_scope(ScopeKind::Misc, |this| {
304                    surface::visit::walk_ty(this, ty);
305                });
306            }
307            _ => surface::visit::walk_ty(self, ty),
308        }
309    }
310
311    fn visit_bty(&mut self, bty: &surface::BaseTy) {
312        match &bty.kind {
313            surface::BaseTyKind::Slice(_) => {
314                self.with_scope(ScopeKind::Misc, |this| {
315                    surface::visit::walk_bty(this, bty);
316                });
317            }
318            surface::BaseTyKind::Path(..) => {
319                surface::visit::walk_bty(self, bty);
320            }
321        }
322    }
323
324    fn visit_path_expr(&mut self, path: &surface::ExprPath) {
325        self.on_path(path);
326    }
327
328    fn visit_base_sort(&mut self, bsort: &surface::BaseSort) {
329        self.on_base_sort(bsort);
330        surface::visit::walk_base_sort(self, bsort);
331    }
332}
333
334struct ImplicitParamCollector<'a, 'tcx> {
335    tcx: TyCtxt<'tcx>,
336    path_res_map: &'a UnordMap<surface::NodeId, fhir::PartialRes>,
337    kind: ScopeKind,
338    params: Vec<(Ident, fhir::ParamKind, NodeId)>,
339}
340
341impl<'a, 'tcx> ImplicitParamCollector<'a, 'tcx> {
342    fn new(
343        tcx: TyCtxt<'tcx>,
344        path_res_map: &'a UnordMap<surface::NodeId, fhir::PartialRes>,
345        kind: ScopeKind,
346    ) -> Self {
347        Self { tcx, path_res_map, kind, params: vec![] }
348    }
349
350    fn run(
351        self,
352        f: impl FnOnce(&mut ScopedVisitorWrapper<Self>),
353    ) -> Vec<(Ident, fhir::ParamKind, NodeId)> {
354        let mut wrapped = self.wrap();
355        f(&mut wrapped);
356        wrapped.0.params
357    }
358}
359
360impl ScopedVisitor for ImplicitParamCollector<'_, '_> {
361    fn is_box(&self, segment: &surface::PathSegment) -> bool {
362        self.path_res_map
363            .get(&segment.node_id)
364            .map(|r| r.is_box(self.tcx))
365            .unwrap_or(false)
366    }
367
368    fn enter_scope(&mut self, kind: ScopeKind) -> ControlFlow<()> {
369        if self.kind == kind { ControlFlow::Continue(()) } else { ControlFlow::Break(()) }
370    }
371
372    fn on_implicit_param(&mut self, ident: Ident, param: fhir::ParamKind, node_id: NodeId) {
373        self.params.push((ident, param, node_id));
374    }
375}
376
377struct Scope {
378    kind: ScopeKind,
379    bindings: FxIndexMap<Ident, ParamRes>,
380}
381
382impl Scope {
383    fn new(kind: ScopeKind) -> Self {
384        Self { kind, bindings: Default::default() }
385    }
386}
387
/// The definition of a refinement parameter recorded by `RefinementResolver`.
#[derive(Clone, Copy)]
struct ParamDef {
    /// Name the parameter was declared with.
    ident: Ident,
    /// How the parameter was declared.
    kind: fhir::ParamKind,
    /// For implicit parameters, the node (function signature, function output,
    /// variant or `Fn*` trait reference) under which the parameter is listed in
    /// the resolver output. `None` for explicit parameters.
    scope: Option<NodeId>,
}
394
/// Resolves the names appearing in refinements (parameters, expression paths
/// and sorts) of a single item and records the results in the output of the
/// underlying `CrateResolver`.
pub(crate) struct RefinementResolver<'a, 'genv, 'tcx> {
    /// Stack of currently open scopes; the innermost scope is last.
    scopes: Vec<Scope>,
    /// Sort parameters in scope for the item being resolved.
    sort_params: FxIndexSet<Symbol>,
    /// Every parameter defined so far, keyed by the parameter's node id.
    param_defs: FxIndexMap<NodeId, ParamDef>,
    resolver: &'a mut CrateResolver<'genv, 'tcx>,
    /// Resolutions of expression paths; parameters are still referred to by
    /// node id and are only assigned their final ids in `finish`.
    path_res_map: FxHashMap<NodeId, Res<NodeId>>,
    /// Errors accumulated during resolution.
    errors: Errors<'genv>,
}
403
404impl<'a, 'genv, 'tcx> RefinementResolver<'a, 'genv, 'tcx> {
405    pub(crate) fn for_flux_item(
406        resolver: &'a mut CrateResolver<'genv, 'tcx>,
407        sort_params: &[Ident],
408    ) -> Self {
409        Self::new(resolver, sort_params.iter().map(|ident| ident.name).collect())
410    }
411
412    pub(crate) fn for_rust_item(resolver: &'a mut CrateResolver<'genv, 'tcx>) -> Self {
413        Self::new(resolver, Default::default())
414    }
415
416    pub(crate) fn resolve_qualifier(
417        resolver: &'a mut CrateResolver<'genv, 'tcx>,
418        qualifier: &surface::Qualifier,
419    ) -> Result {
420        Self::for_flux_item(resolver, &[]).run(|r| r.visit_qualifier(qualifier))
421    }
422
423    pub(crate) fn resolve_defn(
424        resolver: &'a mut CrateResolver<'genv, 'tcx>,
425        defn: &surface::SpecFunc,
426    ) -> Result {
427        Self::for_flux_item(resolver, &defn.sort_vars).run(|r| r.visit_defn(defn))
428    }
429
430    pub(crate) fn resolve_primop_prop(
431        resolver: &'a mut CrateResolver<'genv, 'tcx>,
432        prop: &surface::PrimOpProp,
433    ) -> Result {
434        Self::for_flux_item(resolver, &[]).run(|r| r.visit_primop_prop(prop))
435    }
436
437    pub(crate) fn resolve_fn_sig(
438        resolver: &'a mut CrateResolver<'genv, 'tcx>,
439        fn_sig: &surface::FnSig,
440    ) -> Result {
441        IllegalBinderVisitor::new(resolver).run(|vis| vis.visit_fn_sig(fn_sig))?;
442        Self::for_rust_item(resolver).run(|vis| vis.visit_fn_sig(fn_sig))
443    }
444
445    pub(crate) fn resolve_struct_def(
446        resolver: &'a mut CrateResolver<'genv, 'tcx>,
447        struct_def: &surface::StructDef,
448    ) -> Result {
449        IllegalBinderVisitor::new(resolver).run(|vis| vis.visit_struct_def(struct_def))?;
450        Self::for_rust_item(resolver).run(|vis| vis.visit_struct_def(struct_def))
451    }
452
453    pub(crate) fn resolve_enum_def(
454        resolver: &'a mut CrateResolver<'genv, 'tcx>,
455        enum_def: &surface::EnumDef,
456    ) -> Result {
457        IllegalBinderVisitor::new(resolver).run(|vis| vis.visit_enum_def(enum_def))?;
458        Self::for_rust_item(resolver).run(|vis| vis.visit_enum_def(enum_def))
459    }
460
461    pub(crate) fn resolve_constant(
462        resolver: &'a mut CrateResolver<'genv, 'tcx>,
463        constant_info: &surface::ConstantInfo,
464    ) -> Result {
465        IllegalBinderVisitor::new(resolver).run(|vis| vis.visit_constant(constant_info))?;
466        Self::for_rust_item(resolver).run(|vis| vis.visit_constant(constant_info))
467    }
468
469    pub(crate) fn resolve_ty_alias(
470        resolver: &'a mut CrateResolver<'genv, 'tcx>,
471        ty_alias: &surface::TyAlias,
472    ) -> Result {
473        IllegalBinderVisitor::new(resolver).run(|vis| vis.visit_ty_alias(ty_alias))?;
474        Self::for_rust_item(resolver).run(|vis| vis.visit_ty_alias(ty_alias))
475    }
476
477    pub(crate) fn resolve_impl(
478        resolver: &'a mut CrateResolver<'genv, 'tcx>,
479        impl_: &surface::Impl,
480    ) -> Result {
481        IllegalBinderVisitor::new(resolver).run(|vis| vis.visit_impl(impl_))?;
482        Self::for_rust_item(resolver).run(|vis| vis.visit_impl(impl_))
483    }
484
485    pub(crate) fn resolve_trait(
486        resolver: &'a mut CrateResolver<'genv, 'tcx>,
487        trait_: &surface::Trait,
488    ) -> Result {
489        IllegalBinderVisitor::new(resolver).run(|vis| vis.visit_trait(trait_))?;
490        Self::for_rust_item(resolver).run(|vis| vis.visit_trait(trait_))
491    }
492
493    fn new(resolver: &'a mut CrateResolver<'genv, 'tcx>, sort_params: FxIndexSet<Symbol>) -> Self {
494        let errors = Errors::new(resolver.genv.sess());
495        Self {
496            resolver,
497            sort_params,
498            param_defs: Default::default(),
499            scopes: Default::default(),
500            path_res_map: Default::default(),
501            errors,
502        }
503    }
504
505    fn run(self, f: impl FnOnce(&mut ScopedVisitorWrapper<Self>)) -> Result {
506        let mut wrapper = self.wrap();
507        f(&mut wrapper);
508        wrapper.0.finish()
509    }
510
511    fn define_param(
512        &mut self,
513        ident: Ident,
514        kind: fhir::ParamKind,
515        param_id: NodeId,
516        scope: Option<NodeId>,
517    ) {
518        self.param_defs
519            .insert(param_id, ParamDef { ident, kind, scope });
520
521        let scope = self.scopes.last_mut().unwrap();
522        match scope.bindings.entry(ident) {
523            IndexEntry::Occupied(entry) => {
524                let param_def = self.param_defs[&entry.get().param_id()];
525                self.errors
526                    .emit(errors::DuplicateParam::new(param_def.ident, ident));
527            }
528            IndexEntry::Vacant(entry) => {
529                entry.insert(ParamRes(kind, param_id));
530            }
531        }
532    }
533
534    fn find(&mut self, ident: Ident) -> Option<ParamRes> {
535        for scope in self.scopes.iter().rev() {
536            if let Some(res) = scope.bindings.get(&ident) {
537                return Some(*res);
538            }
539
540            if scope.kind.is_barrier() {
541                return None;
542            }
543        }
544        None
545    }
546
547    fn resolve_path(&mut self, path: &surface::ExprPath) {
548        if let [segment] = &path.segments[..]
549            && let Some(res) = self.try_resolve_param(segment.ident)
550        {
551            self.path_res_map.insert(path.node_id, res);
552            return;
553        }
554        if let Some(res) = self.try_resolve_expr_with_ribs(&path.segments) {
555            self.path_res_map.insert(path.node_id, res);
556            return;
557        }
558        // TODO(nilehmann) move this to resolve_with_ribs
559        if let [typ, name] = &path.segments[..]
560            && let Some(res) = resolve_num_const(typ.ident, name.ident)
561        {
562            self.path_res_map.insert(path.node_id, res);
563            return;
564        }
565        if let [segment] = &path.segments[..]
566            && let Some(res) = self.try_resolve_global_func(segment.ident)
567        {
568            self.path_res_map.insert(path.node_id, res);
569            return;
570        }
571
572        self.errors.emit(errors::UnresolvedVar::from_path(path));
573    }
574
575    fn resolve_ident(&mut self, ident: Ident, node_id: NodeId) {
576        if let Some(res) = self.try_resolve_param(ident) {
577            self.path_res_map.insert(node_id, res);
578            return;
579        }
580        if let Some(res) = self.try_resolve_expr_with_ribs(&[ident]) {
581            self.path_res_map.insert(node_id, res);
582            return;
583        }
584        if let Some(res) = self.try_resolve_global_func(ident) {
585            self.path_res_map.insert(node_id, res);
586            return;
587        }
588        self.errors.emit(errors::UnresolvedVar::from_ident(ident));
589    }
590
591    fn try_resolve_expr_with_ribs<S: Segment>(&mut self, segments: &[S]) -> Option<Res<NodeId>> {
592        if let Some(partial_res) = self.resolver.resolve_path_with_ribs(segments, ValueNS) {
593            return partial_res.full_res().map(|r| r.map_param_id(|p| p));
594        }
595
596        self.resolver
597            .resolve_path_with_ribs(segments, TypeNS)?
598            .full_res()
599            .map(|r| r.map_param_id(|p| p))
600    }
601
602    fn try_resolve_param(&mut self, ident: Ident) -> Option<Res<NodeId>> {
603        let res = self.find(ident)?;
604
605        if let fhir::ParamKind::Error = res.kind() {
606            self.errors.emit(errors::InvalidUnrefinedParam::new(ident));
607        }
608        Some(Res::Param(res.kind(), res.param_id()))
609    }
610
611    fn try_resolve_global_func(&mut self, ident: Ident) -> Option<Res<NodeId>> {
612        let kind = self.resolver.func_decls.get(&ident.name)?;
613        Some(Res::GlobalFunc(*kind))
614    }
615
616    fn resolve_sort_path(&mut self, path: &surface::SortPath) {
617        let res = self
618            .try_resolve_sort_param(path)
619            .or_else(|| self.try_resolve_sort_with_ribs(path))
620            .or_else(|| self.try_resolve_user_sort(path))
621            .or_else(|| self.try_resolve_prim_sort(path));
622
623        if let Some(res) = res {
624            self.resolver
625                .output
626                .sort_path_res_map
627                .insert(path.node_id, res);
628        } else {
629            self.errors.emit(errors::UnresolvedSort::new(path));
630        }
631    }
632
633    fn try_resolve_sort_param(&self, path: &surface::SortPath) -> Option<fhir::SortRes> {
634        let [segment] = &path.segments[..] else { return None };
635        self.sort_params
636            .get_index_of(&segment.name)
637            .map(fhir::SortRes::SortParam)
638    }
639
640    fn try_resolve_sort_with_ribs(&mut self, path: &surface::SortPath) -> Option<fhir::SortRes> {
641        let partial_res = self
642            .resolver
643            .resolve_path_with_ribs(&path.segments, TypeNS)?;
644        match (partial_res.base_res(), partial_res.unresolved_segments()) {
645            (fhir::Res::Def(DefKind::Struct | DefKind::Enum, def_id), 0) => {
646                Some(fhir::SortRes::Adt(def_id))
647            }
648            (fhir::Res::Def(DefKind::TyParam, def_id), 0) => Some(fhir::SortRes::TyParam(def_id)),
649            (fhir::Res::SelfTyParam { trait_ }, 0) => {
650                Some(fhir::SortRes::SelfParam { trait_id: trait_ })
651            }
652            (fhir::Res::SelfTyParam { trait_ }, 1) => {
653                let ident = *path.segments.last().unwrap();
654                Some(fhir::SortRes::SelfParamAssoc { trait_id: trait_, ident })
655            }
656            (fhir::Res::SelfTyAlias { alias_to, .. }, 0) => {
657                Some(fhir::SortRes::SelfAlias { alias_to })
658            }
659            _ => None,
660        }
661    }
662
663    fn try_resolve_user_sort(&self, path: &surface::SortPath) -> Option<fhir::SortRes> {
664        let [segment] = &path.segments[..] else { return None };
665        self.resolver
666            .sort_decls
667            .get(&segment.name)
668            .map(|decl| fhir::SortRes::User { name: decl.name })
669    }
670
671    fn try_resolve_prim_sort(&self, path: &surface::SortPath) -> Option<fhir::SortRes> {
672        let [segment] = &path.segments[..] else { return None };
673        if segment.name == sym::int {
674            Some(fhir::SortRes::PrimSort(fhir::PrimSort::Int))
675        } else if segment.name == sym::bool {
676            Some(fhir::SortRes::PrimSort(fhir::PrimSort::Bool))
677        } else if segment.name == sym::char {
678            Some(fhir::SortRes::PrimSort(fhir::PrimSort::Char))
679        } else if segment.name == sym::real {
680            Some(fhir::SortRes::PrimSort(fhir::PrimSort::Real))
681        } else if segment.name == sym::Set {
682            Some(fhir::SortRes::PrimSort(fhir::PrimSort::Set))
683        } else if segment.name == sym::Map {
684            Some(fhir::SortRes::PrimSort(fhir::PrimSort::Map))
685        } else if segment.name == sym::str {
686            Some(fhir::SortRes::PrimSort(fhir::PrimSort::Str))
687        } else {
688            None
689        }
690    }
691
692    pub(crate) fn finish(self) -> Result {
693        let param_id_gen = IndexGen::new();
694        let mut params = FxIndexMap::default();
695
696        // Create an `fhir::ParamId` for all parameters used in a path before iterating over
697        // `param_defs` such that we can skip `fhir::ParamKind::Colon` if the param wasn't used
698        for (node_id, res) in self.path_res_map {
699            let res = res.map_param_id(|param_id| {
700                *params
701                    .entry(param_id)
702                    .or_insert_with(|| param_id_gen.fresh())
703            });
704            self.resolver.output.expr_path_res_map.insert(node_id, res);
705        }
706
707        // At this point, the `params` map contains all parameters that were used in an expression,
708        // so we can safely skip `ParamKind::Colon` if there's no entry for it.
709        for (param_id, param_def) in self.param_defs {
710            let name = match param_def.kind {
711                fhir::ParamKind::Colon => {
712                    let Some(name) = params.get(&param_id) else { continue };
713                    *name
714                }
715                fhir::ParamKind::Error => continue,
716                _ => {
717                    params
718                        .get(&param_id)
719                        .copied()
720                        .unwrap_or_else(|| param_id_gen.fresh())
721                }
722            };
723            let output = &mut self.resolver.output;
724            output
725                .param_res_map
726                .insert(param_id, (name, param_def.kind));
727
728            if let Some(scope) = param_def.scope {
729                output
730                    .implicit_params
731                    .entry(scope)
732                    .or_default()
733                    .push((param_def.ident, param_id));
734            }
735        }
736        self.errors.into_result()
737    }
738
739    fn resolver_output(&self) -> &ResolverOutput {
740        &self.resolver.output
741    }
742}
743
744impl ScopedVisitor for RefinementResolver<'_, '_, '_> {
745    fn is_box(&self, segment: &surface::PathSegment) -> bool {
746        self.resolver_output()
747            .path_res_map
748            .get(&segment.node_id)
749            .map(|r| r.is_box(self.resolver.genv.tcx()))
750            .unwrap_or(false)
751    }
752
753    fn enter_scope(&mut self, kind: ScopeKind) -> ControlFlow<()> {
754        self.scopes.push(Scope::new(kind));
755        ControlFlow::Continue(())
756    }
757
758    fn exit_scope(&mut self) {
759        self.scopes.pop();
760    }
761
762    fn on_fn_trait_input(&mut self, in_arg: &surface::GenericArg, trait_node_id: NodeId) {
763        let params = ImplicitParamCollector::new(
764            self.resolver.genv.tcx(),
765            &self.resolver.output.path_res_map,
766            ScopeKind::FnTraitInput,
767        )
768        .run(|vis| vis.visit_generic_arg(in_arg));
769        for (ident, kind, node_id) in params {
770            self.define_param(ident, kind, node_id, Some(trait_node_id));
771        }
772    }
773
774    fn on_enum_variant(&mut self, variant: &surface::VariantDef) {
775        let params = ImplicitParamCollector::new(
776            self.resolver.genv.tcx(),
777            &self.resolver.output.path_res_map,
778            ScopeKind::Variant,
779        )
780        .run(|vis| vis.visit_variant(variant));
781        for (ident, kind, node_id) in params {
782            self.define_param(ident, kind, node_id, Some(variant.node_id));
783        }
784    }
785
786    fn on_fn_sig(&mut self, fn_sig: &surface::FnSig) {
787        let params = ImplicitParamCollector::new(
788            self.resolver.genv.tcx(),
789            &self.resolver.output.path_res_map,
790            ScopeKind::FnInput,
791        )
792        .run(|vis| vis.visit_fn_sig(fn_sig));
793        for (ident, kind, param_id) in params {
794            self.define_param(ident, kind, param_id, Some(fn_sig.node_id));
795        }
796    }
797
798    fn on_fn_output(&mut self, output: &surface::FnOutput) {
799        let params = ImplicitParamCollector::new(
800            self.resolver.genv.tcx(),
801            &self.resolver.output.path_res_map,
802            ScopeKind::FnOutput,
803        )
804        .run(|vis| vis.visit_fn_output(output));
805        for (ident, kind, param_id) in params {
806            self.define_param(ident, kind, param_id, Some(output.node_id));
807        }
808    }
809
810    fn on_refine_param(&mut self, param: &surface::RefineParam) {
811        self.define_param(param.ident, fhir::ParamKind::Explicit(param.mode), param.node_id, None);
812    }
813
814    fn on_loc(&mut self, loc: Ident, node_id: NodeId) {
815        self.resolve_ident(loc, node_id);
816    }
817
818    fn on_path(&mut self, path: &surface::ExprPath) {
819        self.resolve_path(path);
820    }
821
822    fn on_base_sort(&mut self, sort: &surface::BaseSort) {
823        match sort {
824            surface::BaseSort::Path(path) => {
825                self.resolve_sort_path(path);
826            }
827            surface::BaseSort::BitVec(_) => {}
828            surface::BaseSort::SortOf(..) => {}
829        }
830    }
831}
832
// Generates `resolve_num_const`, which resolves a two-part name `<typ>::<name>`
// to a numeric constant. For each integer type passed to the macro, the names
// `MAX` and `MIN` resolve to `Res::NumConst` holding that type's corresponding
// constant; anything else yields `None`.
macro_rules! define_resolve_num_const {
    ($($typ:ident),*) => {
        fn resolve_num_const(typ: surface::Ident, name: surface::Ident) -> Option<Res<NodeId>> {
            match typ.name.as_str() {
                $(
                    // One arm per integer type, matched by the type's name.
                    stringify!($typ) => {
                        match name.name.as_str() {
                            // The conversion into the payload of `NumConst` is
                            // unwrapped, so it panics if the value does not fit.
                            "MAX" => Some(Res::NumConst($typ::MAX.try_into().unwrap())),
                            "MIN" => Some(Res::NumConst($typ::MIN.try_into().unwrap())),
                            _ => None,
                        }
                    },
                )*
                _ => None
            }
        }
    };
}

// Integer types whose `MAX`/`MIN` constants can be used in refinements.
define_resolve_num_const!(i8, i16, i32, i64, isize, u8, u16, u32, u64, usize);
853
854struct IllegalBinderVisitor<'a, 'genv, 'tcx> {
855    scopes: Vec<ScopeKind>,
856    resolver: &'a CrateResolver<'genv, 'tcx>,
857    errors: Errors<'genv>,
858}
859
860impl<'a, 'genv, 'tcx> IllegalBinderVisitor<'a, 'genv, 'tcx> {
861    fn new(resolver: &'a mut CrateResolver<'genv, 'tcx>) -> Self {
862        let errors = Errors::new(resolver.genv.sess());
863        Self { scopes: vec![], resolver, errors }
864    }
865
866    fn run(self, f: impl FnOnce(&mut ScopedVisitorWrapper<Self>)) -> Result {
867        let mut vis = self.wrap();
868        f(&mut vis);
869        vis.0.errors.into_result()
870    }
871}
872
873impl ScopedVisitor for IllegalBinderVisitor<'_, '_, '_> {
874    fn is_box(&self, segment: &surface::PathSegment) -> bool {
875        self.resolver
876            .output
877            .path_res_map
878            .get(&segment.node_id)
879            .map(|r| r.is_box(self.resolver.genv.tcx()))
880            .unwrap_or(false)
881    }
882
883    fn enter_scope(&mut self, kind: ScopeKind) -> ControlFlow<()> {
884        self.scopes.push(kind);
885        ControlFlow::Continue(())
886    }
887
888    fn exit_scope(&mut self) {
889        self.scopes.pop();
890    }
891
892    fn on_implicit_param(&mut self, ident: Ident, param_kind: fhir::ParamKind, _: NodeId) {
893        let Some(scope_kind) = self.scopes.last() else { return };
894        let (allowed, bind_kind) = match param_kind {
895            fhir::ParamKind::At => {
896                (
897                    matches!(
898                        scope_kind,
899                        ScopeKind::FnInput | ScopeKind::FnTraitInput | ScopeKind::Variant
900                    ),
901                    surface::BindKind::At,
902                )
903            }
904            fhir::ParamKind::Pound => {
905                (matches!(scope_kind, ScopeKind::FnOutput), surface::BindKind::Pound)
906            }
907            fhir::ParamKind::Colon
908            | fhir::ParamKind::Loc
909            | fhir::ParamKind::Error
910            | fhir::ParamKind::Explicit(..) => return,
911        };
912        if !allowed {
913            self.errors
914                .emit(errors::IllegalBinder::new(ident.span, bind_kind));
915        }
916    }
917}
918
919mod errors {
920    use flux_errors::E0999;
921    use flux_macros::Diagnostic;
922    use flux_syntax::surface;
923    use itertools::Itertools;
924    use rustc_span::{Span, Symbol, symbol::Ident};
925
926    #[derive(Diagnostic)]
927    #[diag(desugar_duplicate_param, code = E0999)]
928    pub(super) struct DuplicateParam {
929        #[primary_span]
930        #[label]
931        span: Span,
932        name: Symbol,
933        #[label(desugar_first_use)]
934        first_use: Span,
935    }
936
937    impl DuplicateParam {
938        pub(super) fn new(old_ident: Ident, new_ident: Ident) -> Self {
939            debug_assert_eq!(old_ident.name, new_ident.name);
940            Self { span: new_ident.span, name: new_ident.name, first_use: old_ident.span }
941        }
942    }
943
944    #[derive(Diagnostic)]
945    #[diag(desugar_unresolved_sort, code = E0999)]
946    pub(super) struct UnresolvedSort {
947        #[primary_span]
948        #[label]
949        span: Span,
950        name: String,
951    }
952
953    impl UnresolvedSort {
954        pub(super) fn new(path: &surface::SortPath) -> Self {
955            Self {
956                span: path
957                    .segments
958                    .iter()
959                    .map(|ident| ident.span)
960                    .reduce(Span::to)
961                    .unwrap_or_default(),
962                name: format!("{}", path.segments.iter().format("::")),
963            }
964        }
965    }
966
967    #[derive(Diagnostic)]
968    #[diag(desugar_unresolved_var, code = E0999)]
969    pub(super) struct UnresolvedVar {
970        #[primary_span]
971        #[label]
972        span: Span,
973        var: String,
974    }
975
976    impl UnresolvedVar {
977        pub(super) fn from_path(path: &surface::ExprPath) -> Self {
978            Self {
979                span: path.span,
980                var: format!(
981                    "{}",
982                    path.segments
983                        .iter()
984                        .format_with("::", |s, f| f(&s.ident.name))
985                ),
986            }
987        }
988
989        pub(super) fn from_ident(ident: Ident) -> Self {
990            Self { span: ident.span, var: format!("{ident}") }
991        }
992    }
993
994    #[derive(Diagnostic)]
995    #[diag(desugar_invalid_unrefined_param, code = E0999)]
996    pub(super) struct InvalidUnrefinedParam {
997        #[primary_span]
998        #[label]
999        span: Span,
1000        var: Ident,
1001    }
1002
1003    impl InvalidUnrefinedParam {
1004        pub(super) fn new(var: Ident) -> Self {
1005            Self { var, span: var.span }
1006        }
1007    }
1008
1009    #[derive(Diagnostic)]
1010    #[diag(desugar_illegal_binder, code = E0999)]
1011    pub(super) struct IllegalBinder {
1012        #[primary_span]
1013        #[label]
1014        span: Span,
1015        kind: &'static str,
1016    }
1017
1018    impl IllegalBinder {
1019        pub(super) fn new(span: Span, kind: surface::BindKind) -> Self {
1020            Self { span, kind: kind.token_str() }
1021        }
1022    }
1023}