flux_desugar/resolver/
refinement_resolver.rs

1use std::ops::ControlFlow;
2
3use flux_common::index::IndexGen;
4use flux_errors::Errors;
5use flux_middle::{
6    ResolverOutput,
7    fhir::{self, ExprRes},
8};
9use flux_syntax::{
10    surface::{self, Ident, NodeId, visit::Visitor as _},
11    walk_list,
12};
13use rustc_data_structures::{
14    fx::{FxIndexMap, FxIndexSet, IndexEntry},
15    unord::UnordMap,
16};
17use rustc_hash::FxHashMap;
18use rustc_hir::def::{
19    DefKind,
20    Namespace::{TypeNS, ValueNS},
21};
22use rustc_middle::ty::TyCtxt;
23use rustc_span::{ErrorGuaranteed, Symbol, sym};
24
25use super::{CrateResolver, Segment};
26
27type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
28
/// Kinds of scopes opened while walking surface syntax.
///
/// `FnInput` and `Variant` act as lookup barriers: `RefinementResolver::find` stops searching
/// outward once it has inspected a scope of one of these kinds.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub(crate) enum ScopeKind {
    /// Opened by `visit_fn_sig` around a whole function signature.
    FnInput,
    /// Opened by `visit_fn_output` around a function's output.
    FnOutput,
    /// Opened by `visit_variant` around an enum variant.
    Variant,
    /// Generic nested scope (item definitions, existentials, generic args, abstractions, ...).
    Misc,
    /// Opened by `visit_trait_ref` around the input argument of an `Fn`-like trait reference.
    FnTraitInput,
}

impl ScopeKind {
    /// Returns `true` when name lookup must not continue past a scope of this kind.
    fn is_barrier(self) -> bool {
        match self {
            ScopeKind::FnInput | ScopeKind::Variant => true,
            ScopeKind::FnOutput | ScopeKind::Misc | ScopeKind::FnTraitInput => false,
        }
    }
}
43
44/// Parameters used during gathering.
45#[derive(Debug, Clone, Copy)]
46struct ParamRes(fhir::ParamKind, NodeId);
47
48impl ParamRes {
49    fn kind(self) -> fhir::ParamKind {
50        self.0
51    }
52
53    fn param_id(self) -> NodeId {
54        self.1
55    }
56}
57
58pub(crate) trait ScopedVisitor: Sized {
59    fn is_box(&self, segment: &surface::PathSegment) -> bool;
60    fn enter_scope(&mut self, kind: ScopeKind) -> ControlFlow<()>;
61    fn exit_scope(&mut self) {}
62
63    fn wrap(self) -> ScopedVisitorWrapper<Self> {
64        ScopedVisitorWrapper(self)
65    }
66
67    fn on_implicit_param(&mut self, _ident: Ident, _kind: fhir::ParamKind, _node_id: NodeId) {}
68    fn on_generic_param(&mut self, _param: &surface::GenericParam) {}
69    fn on_refine_param(&mut self, _param: &surface::RefineParam) {}
70    fn on_enum_variant(&mut self, _variant: &surface::VariantDef) {}
71    fn on_fn_trait_input(&mut self, _in_arg: &surface::GenericArg, _node_id: NodeId) {}
72    fn on_fn_sig(&mut self, _fn_sig: &surface::FnSig) {}
73    fn on_fn_output(&mut self, _output: &surface::FnOutput) {}
74    fn on_loc(&mut self, _loc: Ident, _node_id: NodeId) {}
75    fn on_path(&mut self, _path: &surface::ExprPath) {}
76    fn on_base_sort(&mut self, _sort: &surface::BaseSort) {}
77}
78
79pub(crate) struct ScopedVisitorWrapper<V>(V);
80
81impl<V: ScopedVisitor> ScopedVisitorWrapper<V> {
82    fn with_scope(&mut self, kind: ScopeKind, f: impl FnOnce(&mut Self)) {
83        let scope = self.0.enter_scope(kind);
84        if let ControlFlow::Continue(_) = scope {
85            f(self);
86            self.0.exit_scope();
87        }
88    }
89}
90
91impl<V> std::ops::Deref for ScopedVisitorWrapper<V> {
92    type Target = V;
93
94    fn deref(&self) -> &Self::Target {
95        &self.0
96    }
97}
98impl<V> std::ops::DerefMut for ScopedVisitorWrapper<V> {
99    fn deref_mut(&mut self) -> &mut Self::Target {
100        &mut self.0
101    }
102}
103
104impl<V: ScopedVisitor> surface::visit::Visitor for ScopedVisitorWrapper<V> {
105    fn visit_trait_assoc_reft(&mut self, assoc_reft: &surface::TraitAssocReft) {
106        self.with_scope(ScopeKind::Misc, |this| {
107            surface::visit::walk_trait_assoc_reft(this, assoc_reft);
108        });
109    }
110
111    fn visit_impl_assoc_reft(&mut self, assoc_reft: &surface::ImplAssocReft) {
112        self.with_scope(ScopeKind::Misc, |this| {
113            surface::visit::walk_impl_assoc_reft(this, assoc_reft);
114        });
115    }
116
117    fn visit_qualifier(&mut self, qualifier: &surface::Qualifier) {
118        self.with_scope(ScopeKind::Misc, |this| {
119            surface::visit::walk_qualifier(this, qualifier);
120        });
121    }
122
123    fn visit_defn(&mut self, defn: &surface::SpecFunc) {
124        self.with_scope(ScopeKind::Misc, |this| {
125            surface::visit::walk_defn(this, defn);
126        });
127    }
128
129    fn visit_prim_prop(&mut self, prop: &surface::PrimOpProp) {
130        self.with_scope(ScopeKind::Misc, |this| {
131            surface::visit::walk_prim_prop(this, prop);
132        });
133    }
134
135    fn visit_generic_param(&mut self, param: &surface::GenericParam) {
136        self.on_generic_param(param);
137        surface::visit::walk_generic_param(self, param);
138    }
139
140    fn visit_refine_param(&mut self, param: &surface::RefineParam) {
141        self.on_refine_param(param);
142        surface::visit::walk_refine_param(self, param);
143    }
144
145    fn visit_ty_alias(&mut self, ty_alias: &surface::TyAlias) {
146        self.with_scope(ScopeKind::Misc, |this| {
147            surface::visit::walk_ty_alias(this, ty_alias);
148        });
149    }
150
151    fn visit_struct_def(&mut self, struct_def: &surface::StructDef) {
152        self.with_scope(ScopeKind::Misc, |this| {
153            surface::visit::walk_struct_def(this, struct_def);
154        });
155    }
156
157    fn visit_enum_def(&mut self, enum_def: &surface::EnumDef) {
158        self.with_scope(ScopeKind::Misc, |this| {
159            surface::visit::walk_enum_def(this, enum_def);
160        });
161    }
162
163    fn visit_variant(&mut self, variant: &surface::VariantDef) {
164        self.with_scope(ScopeKind::Variant, |this| {
165            this.on_enum_variant(variant);
166            surface::visit::walk_variant(this, variant);
167        });
168    }
169
170    fn visit_trait_ref(&mut self, trait_ref: &surface::TraitRef) {
171        match trait_ref.as_fn_trait_ref() {
172            Some((in_arg, out_arg)) => {
173                self.with_scope(ScopeKind::FnTraitInput, |this| {
174                    this.on_fn_trait_input(in_arg, trait_ref.node_id);
175                    surface::visit::walk_generic_arg(this, in_arg);
176                    this.with_scope(ScopeKind::Misc, |this| {
177                        surface::visit::walk_generic_arg(this, out_arg);
178                    });
179                });
180            }
181            None => {
182                self.with_scope(ScopeKind::Misc, |this| {
183                    surface::visit::walk_trait_ref(this, trait_ref);
184                });
185            }
186        }
187    }
188
189    fn visit_variant_ret(&mut self, ret: &surface::VariantRet) {
190        self.with_scope(ScopeKind::Misc, |this| {
191            surface::visit::walk_variant_ret(this, ret);
192        });
193    }
194
195    fn visit_generics(&mut self, generics: &surface::Generics) {
196        self.with_scope(ScopeKind::Misc, |this| {
197            surface::visit::walk_generics(this, generics);
198        });
199    }
200
201    fn visit_fn_sig(&mut self, fn_sig: &surface::FnSig) {
202        self.with_scope(ScopeKind::FnInput, |this| {
203            this.on_fn_sig(fn_sig);
204            surface::visit::walk_fn_sig(this, fn_sig);
205        });
206    }
207
208    fn visit_fn_output(&mut self, output: &surface::FnOutput) {
209        self.with_scope(ScopeKind::FnOutput, |this| {
210            this.on_fn_output(output);
211            surface::visit::walk_fn_output(this, output);
212        });
213    }
214
215    fn visit_fn_input(&mut self, arg: &surface::FnInput) {
216        match arg {
217            surface::FnInput::Constr(bind, _, _, node_id) => {
218                self.on_implicit_param(*bind, fhir::ParamKind::Colon, *node_id);
219            }
220            surface::FnInput::StrgRef(loc, _, node_id) => {
221                self.on_implicit_param(*loc, fhir::ParamKind::Loc, *node_id);
222            }
223            surface::FnInput::Ty(bind, ty, node_id) => {
224                if let &Some(bind) = bind {
225                    let param_kind = if let surface::TyKind::Base(_) = &ty.kind {
226                        fhir::ParamKind::Colon
227                    } else {
228                        fhir::ParamKind::Error
229                    };
230                    self.on_implicit_param(bind, param_kind, *node_id);
231                }
232            }
233        }
234        surface::visit::walk_fn_input(self, arg);
235    }
236
237    fn visit_ensures(&mut self, constraint: &surface::Ensures) {
238        if let surface::Ensures::Type(loc, _, node_id) = constraint {
239            self.on_loc(*loc, *node_id);
240        }
241        surface::visit::walk_ensures(self, constraint);
242    }
243
244    fn visit_refine_arg(&mut self, arg: &surface::RefineArg) {
245        match arg {
246            surface::RefineArg::Bind(ident, kind, _, node_id) => {
247                let kind = match kind {
248                    surface::BindKind::At => fhir::ParamKind::At,
249                    surface::BindKind::Pound => fhir::ParamKind::Pound,
250                };
251                self.on_implicit_param(*ident, kind, *node_id);
252            }
253            surface::RefineArg::Abs(..) => {
254                self.with_scope(ScopeKind::Misc, |this| {
255                    surface::visit::walk_refine_arg(this, arg);
256                });
257            }
258            surface::RefineArg::Expr(expr) => self.visit_expr(expr),
259        }
260    }
261
262    fn visit_path(&mut self, path: &surface::Path) {
263        for arg in &path.refine {
264            self.with_scope(ScopeKind::Misc, |this| this.visit_refine_arg(arg));
265        }
266        walk_list!(self, visit_path_segment, &path.segments);
267    }
268
269    fn visit_path_segment(&mut self, segment: &surface::PathSegment) {
270        let is_box = self.is_box(segment);
271        for (i, arg) in segment.args.iter().enumerate() {
272            if is_box && i == 0 {
273                self.visit_generic_arg(arg);
274            } else {
275                self.with_scope(ScopeKind::Misc, |this| this.visit_generic_arg(arg));
276            }
277        }
278    }
279
280    fn visit_ty(&mut self, ty: &surface::Ty) {
281        let node_id = ty.node_id;
282        match &ty.kind {
283            surface::TyKind::Exists { bind, .. } => {
284                self.with_scope(ScopeKind::Misc, |this| {
285                    let param = surface::RefineParam {
286                        ident: *bind,
287                        mode: None,
288                        sort: surface::Sort::Infer,
289                        node_id,
290                        span: bind.span,
291                    };
292                    this.on_refine_param(&param);
293                    surface::visit::walk_ty(this, ty);
294                });
295            }
296            surface::TyKind::GeneralExists { .. } => {
297                self.with_scope(ScopeKind::Misc, |this| {
298                    surface::visit::walk_ty(this, ty);
299                });
300            }
301            surface::TyKind::Array(..) => {
302                self.with_scope(ScopeKind::Misc, |this| {
303                    surface::visit::walk_ty(this, ty);
304                });
305            }
306            _ => surface::visit::walk_ty(self, ty),
307        }
308    }
309
310    fn visit_bty(&mut self, bty: &surface::BaseTy) {
311        match &bty.kind {
312            surface::BaseTyKind::Slice(_) => {
313                self.with_scope(ScopeKind::Misc, |this| {
314                    surface::visit::walk_bty(this, bty);
315                });
316            }
317            surface::BaseTyKind::Path(..) => {
318                surface::visit::walk_bty(self, bty);
319            }
320        }
321    }
322
323    fn visit_path_expr(&mut self, path: &surface::ExprPath) {
324        self.on_path(path);
325    }
326
327    fn visit_base_sort(&mut self, bsort: &surface::BaseSort) {
328        self.on_base_sort(bsort);
329        surface::visit::walk_base_sort(self, bsort);
330    }
331}
332
333struct ImplicitParamCollector<'a, 'tcx> {
334    tcx: TyCtxt<'tcx>,
335    path_res_map: &'a UnordMap<surface::NodeId, fhir::PartialRes>,
336    kind: ScopeKind,
337    params: Vec<(Ident, fhir::ParamKind, NodeId)>,
338}
339
340impl<'a, 'tcx> ImplicitParamCollector<'a, 'tcx> {
341    fn new(
342        tcx: TyCtxt<'tcx>,
343        path_res_map: &'a UnordMap<surface::NodeId, fhir::PartialRes>,
344        kind: ScopeKind,
345    ) -> Self {
346        Self { tcx, path_res_map, kind, params: vec![] }
347    }
348
349    fn run(
350        self,
351        f: impl FnOnce(&mut ScopedVisitorWrapper<Self>),
352    ) -> Vec<(Ident, fhir::ParamKind, NodeId)> {
353        let mut wrapped = self.wrap();
354        f(&mut wrapped);
355        wrapped.0.params
356    }
357}
358
359impl ScopedVisitor for ImplicitParamCollector<'_, '_> {
360    fn is_box(&self, segment: &surface::PathSegment) -> bool {
361        self.path_res_map
362            .get(&segment.node_id)
363            .map(|r| r.is_box(self.tcx))
364            .unwrap_or(false)
365    }
366
367    fn enter_scope(&mut self, kind: ScopeKind) -> ControlFlow<()> {
368        if self.kind == kind { ControlFlow::Continue(()) } else { ControlFlow::Break(()) }
369    }
370
371    fn on_implicit_param(&mut self, ident: Ident, param: fhir::ParamKind, node_id: NodeId) {
372        self.params.push((ident, param, node_id));
373    }
374}
375
376struct Scope {
377    kind: ScopeKind,
378    bindings: FxIndexMap<Ident, ParamRes>,
379}
380
381impl Scope {
382    fn new(kind: ScopeKind) -> Self {
383        Self { kind, bindings: Default::default() }
384    }
385}
386
387#[derive(Clone, Copy)]
388struct ParamDef {
389    ident: Ident,
390    kind: fhir::ParamKind,
391    scope: Option<NodeId>,
392}
393
394pub(crate) struct RefinementResolver<'a, 'genv, 'tcx> {
395    scopes: Vec<Scope>,
396    sort_params: FxIndexSet<Symbol>,
397    param_defs: FxIndexMap<NodeId, ParamDef>,
398    resolver: &'a mut CrateResolver<'genv, 'tcx>,
399    path_res_map: FxHashMap<NodeId, ExprRes<NodeId>>,
400    errors: Errors<'genv>,
401}
402
403impl<'a, 'genv, 'tcx> RefinementResolver<'a, 'genv, 'tcx> {
404    pub(crate) fn for_flux_item(
405        resolver: &'a mut CrateResolver<'genv, 'tcx>,
406        sort_params: &[Ident],
407    ) -> Self {
408        Self::new(resolver, sort_params.iter().map(|ident| ident.name).collect())
409    }
410
411    pub(crate) fn for_rust_item(resolver: &'a mut CrateResolver<'genv, 'tcx>) -> Self {
412        Self::new(resolver, Default::default())
413    }
414
415    pub(crate) fn resolve_qualifier(
416        resolver: &'a mut CrateResolver<'genv, 'tcx>,
417        qualifier: &surface::Qualifier,
418    ) -> Result {
419        Self::for_flux_item(resolver, &[]).run(|r| r.visit_qualifier(qualifier))
420    }
421
422    pub(crate) fn resolve_defn(
423        resolver: &'a mut CrateResolver<'genv, 'tcx>,
424        defn: &surface::SpecFunc,
425    ) -> Result {
426        Self::for_flux_item(resolver, &defn.sort_vars).run(|r| r.visit_defn(defn))
427    }
428
429    pub(crate) fn resolve_prim_prop(
430        resolver: &'a mut CrateResolver<'genv, 'tcx>,
431        prop: &surface::PrimOpProp,
432    ) -> Result {
433        Self::for_flux_item(resolver, &[]).run(|r| r.visit_prim_prop(prop))
434    }
435
436    pub(crate) fn resolve_fn_sig(
437        resolver: &'a mut CrateResolver<'genv, 'tcx>,
438        fn_sig: &surface::FnSig,
439    ) -> Result {
440        IllegalBinderVisitor::new(resolver).run(|vis| vis.visit_fn_sig(fn_sig))?;
441        Self::for_rust_item(resolver).run(|vis| vis.visit_fn_sig(fn_sig))
442    }
443
444    pub(crate) fn resolve_struct_def(
445        resolver: &'a mut CrateResolver<'genv, 'tcx>,
446        struct_def: &surface::StructDef,
447    ) -> Result {
448        IllegalBinderVisitor::new(resolver).run(|vis| vis.visit_struct_def(struct_def))?;
449        Self::for_rust_item(resolver).run(|vis| vis.visit_struct_def(struct_def))
450    }
451
452    pub(crate) fn resolve_enum_def(
453        resolver: &'a mut CrateResolver<'genv, 'tcx>,
454        enum_def: &surface::EnumDef,
455    ) -> Result {
456        IllegalBinderVisitor::new(resolver).run(|vis| vis.visit_enum_def(enum_def))?;
457        Self::for_rust_item(resolver).run(|vis| vis.visit_enum_def(enum_def))
458    }
459
460    pub(crate) fn resolve_constant(
461        resolver: &'a mut CrateResolver<'genv, 'tcx>,
462        constant_info: &surface::ConstantInfo,
463    ) -> Result {
464        IllegalBinderVisitor::new(resolver).run(|vis| vis.visit_constant(constant_info))?;
465        Self::for_rust_item(resolver).run(|vis| vis.visit_constant(constant_info))
466    }
467
468    pub(crate) fn resolve_ty_alias(
469        resolver: &'a mut CrateResolver<'genv, 'tcx>,
470        ty_alias: &surface::TyAlias,
471    ) -> Result {
472        IllegalBinderVisitor::new(resolver).run(|vis| vis.visit_ty_alias(ty_alias))?;
473        Self::for_rust_item(resolver).run(|vis| vis.visit_ty_alias(ty_alias))
474    }
475
476    pub(crate) fn resolve_impl(
477        resolver: &'a mut CrateResolver<'genv, 'tcx>,
478        impl_: &surface::Impl,
479    ) -> Result {
480        IllegalBinderVisitor::new(resolver).run(|vis| vis.visit_impl(impl_))?;
481        Self::for_rust_item(resolver).run(|vis| vis.visit_impl(impl_))
482    }
483
484    pub(crate) fn resolve_trait(
485        resolver: &'a mut CrateResolver<'genv, 'tcx>,
486        trait_: &surface::Trait,
487    ) -> Result {
488        IllegalBinderVisitor::new(resolver).run(|vis| vis.visit_trait(trait_))?;
489        Self::for_rust_item(resolver).run(|vis| vis.visit_trait(trait_))
490    }
491
492    fn new(resolver: &'a mut CrateResolver<'genv, 'tcx>, sort_params: FxIndexSet<Symbol>) -> Self {
493        let errors = Errors::new(resolver.genv.sess());
494        Self {
495            resolver,
496            sort_params,
497            param_defs: Default::default(),
498            scopes: Default::default(),
499            path_res_map: Default::default(),
500            errors,
501        }
502    }
503
504    fn run(self, f: impl FnOnce(&mut ScopedVisitorWrapper<Self>)) -> Result {
505        let mut wrapper = self.wrap();
506        f(&mut wrapper);
507        wrapper.0.finish()
508    }
509
510    fn define_param(
511        &mut self,
512        ident: Ident,
513        kind: fhir::ParamKind,
514        param_id: NodeId,
515        scope: Option<NodeId>,
516    ) {
517        self.param_defs
518            .insert(param_id, ParamDef { ident, kind, scope });
519
520        let scope = self.scopes.last_mut().unwrap();
521        match scope.bindings.entry(ident) {
522            IndexEntry::Occupied(entry) => {
523                let param_def = self.param_defs[&entry.get().param_id()];
524                self.errors
525                    .emit(errors::DuplicateParam::new(param_def.ident, ident));
526            }
527            IndexEntry::Vacant(entry) => {
528                entry.insert(ParamRes(kind, param_id));
529            }
530        }
531    }
532
533    fn find(&mut self, ident: Ident) -> Option<ParamRes> {
534        for scope in self.scopes.iter().rev() {
535            if let Some(res) = scope.bindings.get(&ident) {
536                return Some(*res);
537            }
538
539            if scope.kind.is_barrier() {
540                return None;
541            }
542        }
543        None
544    }
545
546    fn try_resolve_enum_variant(&mut self, typ: Ident, variant: Ident) -> Option<ExprRes<NodeId>> {
547        if let fhir::Res::Def(_, enum_def_id) =
548            self.resolver.resolve_ident_with_ribs(typ, TypeNS)?
549        {
550            let enum_variants = self.resolver.enum_variants.get(&enum_def_id)?;
551            let variant_def_id = enum_variants.variants.get(&variant.name)?;
552            return Some(ExprRes::Variant(*variant_def_id));
553        }
554        None
555    }
556
557    fn resolve_path(&mut self, path: &surface::ExprPath) {
558        if let [segment] = &path.segments[..]
559            && let Some(res) = self.try_resolve_param(segment.ident)
560        {
561            self.path_res_map.insert(path.node_id, res);
562            return;
563        }
564        if let Some(res) = self.try_resolve_expr_with_ribs(&path.segments) {
565            self.path_res_map.insert(path.node_id, res);
566            return;
567        }
568        // TODO(nilehmann) move this to resolve_with_ribs
569        if let [typ, name] = &path.segments[..]
570            && let Some(res) = resolve_num_const(typ.ident, name.ident)
571        {
572            self.path_res_map.insert(path.node_id, res);
573            return;
574        }
575        if let [typ, name] = &path.segments[..]
576            && let Some(res) = self.try_resolve_enum_variant(typ.ident, name.ident)
577        {
578            self.path_res_map.insert(path.node_id, res);
579            return;
580        }
581        if let [segment] = &path.segments[..]
582            && let Some(res) = self.try_resolve_global_func(segment.ident)
583        {
584            self.path_res_map.insert(path.node_id, res);
585            return;
586        }
587
588        self.errors.emit(errors::UnresolvedVar::from_path(path));
589    }
590
591    fn resolve_ident(&mut self, ident: Ident, node_id: NodeId) {
592        if let Some(res) = self.try_resolve_param(ident) {
593            self.path_res_map.insert(node_id, res);
594            return;
595        }
596        if let Some(res) = self.try_resolve_expr_with_ribs(&[ident]) {
597            self.path_res_map.insert(node_id, res);
598            return;
599        }
600        if let Some(res) = self.try_resolve_global_func(ident) {
601            self.path_res_map.insert(node_id, res);
602            return;
603        }
604        self.errors.emit(errors::UnresolvedVar::from_ident(ident));
605    }
606
607    fn try_resolve_expr_with_ribs<S: Segment>(
608        &mut self,
609        segments: &[S],
610    ) -> Option<ExprRes<NodeId>> {
611        let path = self.resolver.resolve_path_with_ribs(segments, ValueNS);
612
613        let res = match path {
614            Some(r) => r.full_res()?,
615            _ => {
616                self.resolver
617                    .resolve_path_with_ribs(segments, TypeNS)?
618                    .full_res()?
619            }
620        };
621        match res {
622            fhir::Res::Def(DefKind::ConstParam, def_id) => Some(ExprRes::ConstGeneric(def_id)),
623            fhir::Res::Def(DefKind::Const, def_id) => Some(ExprRes::Const(def_id)),
624            fhir::Res::Def(DefKind::Struct | DefKind::Enum, def_id) => Some(ExprRes::Ctor(def_id)),
625            fhir::Res::Def(DefKind::Variant, def_id) => Some(ExprRes::Variant(def_id)),
626            _ => None,
627        }
628    }
629
630    fn try_resolve_param(&mut self, ident: Ident) -> Option<ExprRes<NodeId>> {
631        let res = self.find(ident)?;
632
633        if let fhir::ParamKind::Error = res.kind() {
634            self.errors.emit(errors::InvalidUnrefinedParam::new(ident));
635        }
636        Some(ExprRes::Param(res.kind(), res.param_id()))
637    }
638
639    fn try_resolve_global_func(&mut self, ident: Ident) -> Option<ExprRes<NodeId>> {
640        let kind = self.resolver.func_decls.get(&ident.name)?;
641        Some(ExprRes::GlobalFunc(*kind))
642    }
643
644    fn resolve_sort_path(&mut self, path: &surface::SortPath) {
645        let res = self
646            .try_resolve_sort_param(path)
647            .or_else(|| self.try_resolve_sort_with_ribs(path))
648            .or_else(|| self.try_resolve_user_sort(path))
649            .or_else(|| self.try_resolve_prim_sort(path));
650
651        if let Some(res) = res {
652            self.resolver
653                .output
654                .sort_path_res_map
655                .insert(path.node_id, res);
656        } else {
657            self.errors.emit(errors::UnresolvedSort::new(path));
658        }
659    }
660
661    fn try_resolve_sort_param(&self, path: &surface::SortPath) -> Option<fhir::SortRes> {
662        let [segment] = &path.segments[..] else { return None };
663        self.sort_params
664            .get_index_of(&segment.name)
665            .map(fhir::SortRes::SortParam)
666    }
667
668    fn try_resolve_sort_with_ribs(&mut self, path: &surface::SortPath) -> Option<fhir::SortRes> {
669        let partial_res = self
670            .resolver
671            .resolve_path_with_ribs(&path.segments, TypeNS)?;
672        match (partial_res.base_res(), partial_res.unresolved_segments()) {
673            (fhir::Res::Def(DefKind::Struct | DefKind::Enum, def_id), 0) => {
674                Some(fhir::SortRes::Adt(def_id))
675            }
676            (fhir::Res::Def(DefKind::TyParam, def_id), 0) => Some(fhir::SortRes::TyParam(def_id)),
677            (fhir::Res::SelfTyParam { trait_ }, 0) => {
678                Some(fhir::SortRes::SelfParam { trait_id: trait_ })
679            }
680            (fhir::Res::SelfTyParam { trait_ }, 1) => {
681                let ident = *path.segments.last().unwrap();
682                Some(fhir::SortRes::SelfParamAssoc { trait_id: trait_, ident })
683            }
684            (fhir::Res::SelfTyAlias { alias_to, .. }, 0) => {
685                Some(fhir::SortRes::SelfAlias { alias_to })
686            }
687            _ => None,
688        }
689    }
690
691    fn try_resolve_user_sort(&self, path: &surface::SortPath) -> Option<fhir::SortRes> {
692        let [segment] = &path.segments[..] else { return None };
693        self.resolver
694            .sort_decls
695            .get(&segment.name)
696            .map(|decl| fhir::SortRes::User { name: decl.name })
697    }
698
699    fn try_resolve_prim_sort(&self, path: &surface::SortPath) -> Option<fhir::SortRes> {
700        let [segment] = &path.segments[..] else { return None };
701        if segment.name == SORTS.int {
702            Some(fhir::SortRes::PrimSort(fhir::PrimSort::Int))
703        } else if segment.name == sym::bool {
704            Some(fhir::SortRes::PrimSort(fhir::PrimSort::Bool))
705        } else if segment.name == sym::char {
706            Some(fhir::SortRes::PrimSort(fhir::PrimSort::Char))
707        } else if segment.name == SORTS.real {
708            Some(fhir::SortRes::PrimSort(fhir::PrimSort::Real))
709        } else if segment.name == SORTS.set {
710            Some(fhir::SortRes::PrimSort(fhir::PrimSort::Set))
711        } else if segment.name == SORTS.map {
712            Some(fhir::SortRes::PrimSort(fhir::PrimSort::Map))
713        } else {
714            None
715        }
716    }
717
718    pub(crate) fn finish(self) -> Result {
719        let param_id_gen = IndexGen::new();
720        let mut params = FxIndexMap::default();
721
722        // Create an `fhir::ParamId` for all parameters used in a path before iterating over
723        // `param_defs` such that we can skip `fhir::ParamKind::Colon` if the param wasn't used
724        for (node_id, res) in self.path_res_map {
725            let res = res.map_param_id(|param_id| {
726                *params
727                    .entry(param_id)
728                    .or_insert_with(|| param_id_gen.fresh())
729            });
730            self.resolver.output.expr_path_res_map.insert(node_id, res);
731        }
732
733        // At this point, the `params` map contains all parameters that were used in an expression,
734        // so we can safely skip `ParamKind::Colon` if there's no entry for it.
735        for (param_id, param_def) in self.param_defs {
736            let name = match param_def.kind {
737                fhir::ParamKind::Colon => {
738                    let Some(name) = params.get(&param_id) else { continue };
739                    *name
740                }
741                fhir::ParamKind::Error => continue,
742                _ => {
743                    params
744                        .get(&param_id)
745                        .copied()
746                        .unwrap_or_else(|| param_id_gen.fresh())
747                }
748            };
749            let output = &mut self.resolver.output;
750            output
751                .param_res_map
752                .insert(param_id, (name, param_def.kind));
753
754            if let Some(scope) = param_def.scope {
755                output
756                    .implicit_params
757                    .entry(scope)
758                    .or_default()
759                    .push((param_def.ident, param_id));
760            }
761        }
762        self.errors.into_result()
763    }
764
765    fn resolver_output(&self) -> &ResolverOutput {
766        &self.resolver.output
767    }
768}
769
770impl ScopedVisitor for RefinementResolver<'_, '_, '_> {
771    fn is_box(&self, segment: &surface::PathSegment) -> bool {
772        self.resolver_output()
773            .path_res_map
774            .get(&segment.node_id)
775            .map(|r| r.is_box(self.resolver.genv.tcx()))
776            .unwrap_or(false)
777    }
778
779    fn enter_scope(&mut self, kind: ScopeKind) -> ControlFlow<()> {
780        self.scopes.push(Scope::new(kind));
781        ControlFlow::Continue(())
782    }
783
784    fn exit_scope(&mut self) {
785        self.scopes.pop();
786    }
787
788    fn on_fn_trait_input(&mut self, in_arg: &surface::GenericArg, trait_node_id: NodeId) {
789        let params = ImplicitParamCollector::new(
790            self.resolver.genv.tcx(),
791            &self.resolver.output.path_res_map,
792            ScopeKind::FnTraitInput,
793        )
794        .run(|vis| vis.visit_generic_arg(in_arg));
795        for (ident, kind, node_id) in params {
796            self.define_param(ident, kind, node_id, Some(trait_node_id));
797        }
798    }
799
800    fn on_enum_variant(&mut self, variant: &surface::VariantDef) {
801        let params = ImplicitParamCollector::new(
802            self.resolver.genv.tcx(),
803            &self.resolver.output.path_res_map,
804            ScopeKind::Variant,
805        )
806        .run(|vis| vis.visit_variant(variant));
807        for (ident, kind, node_id) in params {
808            self.define_param(ident, kind, node_id, Some(variant.node_id));
809        }
810    }
811
812    fn on_fn_sig(&mut self, fn_sig: &surface::FnSig) {
813        let params = ImplicitParamCollector::new(
814            self.resolver.genv.tcx(),
815            &self.resolver.output.path_res_map,
816            ScopeKind::FnInput,
817        )
818        .run(|vis| vis.visit_fn_sig(fn_sig));
819        for (ident, kind, param_id) in params {
820            self.define_param(ident, kind, param_id, Some(fn_sig.node_id));
821        }
822    }
823
824    fn on_fn_output(&mut self, output: &surface::FnOutput) {
825        let params = ImplicitParamCollector::new(
826            self.resolver.genv.tcx(),
827            &self.resolver.output.path_res_map,
828            ScopeKind::FnOutput,
829        )
830        .run(|vis| vis.visit_fn_output(output));
831        for (ident, kind, param_id) in params {
832            self.define_param(ident, kind, param_id, Some(output.node_id));
833        }
834    }
835
836    fn on_refine_param(&mut self, param: &surface::RefineParam) {
837        self.define_param(param.ident, fhir::ParamKind::Explicit(param.mode), param.node_id, None);
838    }
839
840    fn on_loc(&mut self, loc: Ident, node_id: NodeId) {
841        self.resolve_ident(loc, node_id);
842    }
843
844    fn on_path(&mut self, path: &surface::ExprPath) {
845        self.resolve_path(path);
846    }
847
848    fn on_base_sort(&mut self, sort: &surface::BaseSort) {
849        match sort {
850            surface::BaseSort::Path(path) => {
851                self.resolve_sort_path(path);
852            }
853            surface::BaseSort::BitVec(_) => {}
854            surface::BaseSort::SortOf(..) => {}
855        }
856    }
857}
858
859macro_rules! define_resolve_num_const {
860    ($($typ:ident),*) => {
861        fn resolve_num_const(typ: surface::Ident, name: surface::Ident) -> Option<ExprRes<NodeId>> {
862            match typ.name.as_str() {
863                $(
864                    stringify!($typ) => {
865                        match name.name.as_str() {
866                            "MAX" => Some(ExprRes::NumConst($typ::MAX.try_into().unwrap())),
867                            "MIN" => Some(ExprRes::NumConst($typ::MIN.try_into().unwrap())),
868                            _ => None,
869                        }
870                    },
871                )*
872                _ => None
873            }
874        }
875    };
876}
877
878define_resolve_num_const!(i8, i16, i32, i64, isize, u8, u16, u32, u64, usize);
879
880pub(crate) struct Sorts {
881    pub int: Symbol,
882    pub real: Symbol,
883    pub set: Symbol,
884    pub map: Symbol,
885}
886
887pub(crate) static SORTS: std::sync::LazyLock<Sorts> = std::sync::LazyLock::new(|| {
888    Sorts {
889        int: Symbol::intern("int"),
890        real: Symbol::intern("real"),
891        set: Symbol::intern("Set"),
892        map: Symbol::intern("Map"),
893    }
894});
895
896struct IllegalBinderVisitor<'a, 'genv, 'tcx> {
897    scopes: Vec<ScopeKind>,
898    resolver: &'a CrateResolver<'genv, 'tcx>,
899    errors: Errors<'genv>,
900}
901
902impl<'a, 'genv, 'tcx> IllegalBinderVisitor<'a, 'genv, 'tcx> {
903    fn new(resolver: &'a mut CrateResolver<'genv, 'tcx>) -> Self {
904        let errors = Errors::new(resolver.genv.sess());
905        Self { scopes: vec![], resolver, errors }
906    }
907
908    fn run(self, f: impl FnOnce(&mut ScopedVisitorWrapper<Self>)) -> Result {
909        let mut vis = self.wrap();
910        f(&mut vis);
911        vis.0.errors.into_result()
912    }
913}
914
915impl ScopedVisitor for IllegalBinderVisitor<'_, '_, '_> {
916    fn is_box(&self, segment: &surface::PathSegment) -> bool {
917        self.resolver
918            .output
919            .path_res_map
920            .get(&segment.node_id)
921            .map(|r| r.is_box(self.resolver.genv.tcx()))
922            .unwrap_or(false)
923    }
924
925    fn enter_scope(&mut self, kind: ScopeKind) -> ControlFlow<()> {
926        self.scopes.push(kind);
927        ControlFlow::Continue(())
928    }
929
930    fn exit_scope(&mut self) {
931        self.scopes.pop();
932    }
933
934    fn on_implicit_param(&mut self, ident: Ident, param_kind: fhir::ParamKind, _: NodeId) {
935        let Some(scope_kind) = self.scopes.last() else { return };
936        let (allowed, bind_kind) = match param_kind {
937            fhir::ParamKind::At => {
938                (
939                    matches!(
940                        scope_kind,
941                        ScopeKind::FnInput | ScopeKind::FnTraitInput | ScopeKind::Variant
942                    ),
943                    surface::BindKind::At,
944                )
945            }
946            fhir::ParamKind::Pound => {
947                (matches!(scope_kind, ScopeKind::FnOutput), surface::BindKind::Pound)
948            }
949            fhir::ParamKind::Colon
950            | fhir::ParamKind::Loc
951            | fhir::ParamKind::Error
952            | fhir::ParamKind::Explicit(..) => return,
953        };
954        if !allowed {
955            self.errors
956                .emit(errors::IllegalBinder::new(ident.span, bind_kind));
957        }
958    }
959}
960
961mod errors {
962    use flux_errors::E0999;
963    use flux_macros::Diagnostic;
964    use flux_syntax::surface;
965    use itertools::Itertools;
966    use rustc_span::{Span, Symbol, symbol::Ident};
967
968    #[derive(Diagnostic)]
969    #[diag(desugar_duplicate_param, code = E0999)]
970    pub(super) struct DuplicateParam {
971        #[primary_span]
972        #[label]
973        span: Span,
974        name: Symbol,
975        #[label(desugar_first_use)]
976        first_use: Span,
977    }
978
979    impl DuplicateParam {
980        pub(super) fn new(old_ident: Ident, new_ident: Ident) -> Self {
981            debug_assert_eq!(old_ident.name, new_ident.name);
982            Self { span: new_ident.span, name: new_ident.name, first_use: old_ident.span }
983        }
984    }
985
986    #[derive(Diagnostic)]
987    #[diag(desugar_unresolved_sort, code = E0999)]
988    pub(super) struct UnresolvedSort {
989        #[primary_span]
990        #[label]
991        span: Span,
992        name: String,
993    }
994
995    impl UnresolvedSort {
996        pub(super) fn new(path: &surface::SortPath) -> Self {
997            Self {
998                span: path
999                    .segments
1000                    .iter()
1001                    .map(|ident| ident.span)
1002                    .reduce(Span::to)
1003                    .unwrap_or_default(),
1004                name: format!("{}", path.segments.iter().format("::")),
1005            }
1006        }
1007    }
1008
1009    #[derive(Diagnostic)]
1010    #[diag(desugar_unresolved_var, code = E0999)]
1011    pub(super) struct UnresolvedVar {
1012        #[primary_span]
1013        #[label]
1014        span: Span,
1015        var: String,
1016    }
1017
1018    impl UnresolvedVar {
1019        pub(super) fn from_path(path: &surface::ExprPath) -> Self {
1020            Self {
1021                span: path.span,
1022                var: format!(
1023                    "{}",
1024                    path.segments
1025                        .iter()
1026                        .format_with("::", |s, f| f(&s.ident.name))
1027                ),
1028            }
1029        }
1030
1031        pub(super) fn from_ident(ident: Ident) -> Self {
1032            Self { span: ident.span, var: format!("{ident}") }
1033        }
1034    }
1035
1036    #[derive(Diagnostic)]
1037    #[diag(desugar_invalid_unrefined_param, code = E0999)]
1038    pub(super) struct InvalidUnrefinedParam {
1039        #[primary_span]
1040        #[label]
1041        span: Span,
1042        var: Ident,
1043    }
1044
1045    impl InvalidUnrefinedParam {
1046        pub(super) fn new(var: Ident) -> Self {
1047            Self { var, span: var.span }
1048        }
1049    }
1050
1051    #[derive(Diagnostic)]
1052    #[diag(desugar_illegal_binder, code = E0999)]
1053    pub(super) struct IllegalBinder {
1054        #[primary_span]
1055        #[label]
1056        span: Span,
1057        kind: &'static str,
1058    }
1059
1060    impl IllegalBinder {
1061        pub(super) fn new(span: Span, kind: surface::BindKind) -> Self {
1062            Self { span, kind: kind.token_str() }
1063        }
1064    }
1065}