flux_desugar/resolver/
refinement_resolver.rs

1use std::ops::ControlFlow;
2
3use flux_common::index::IndexGen;
4use flux_errors::Errors;
5use flux_middle::{
6    ResolverOutput,
7    fhir::{self, ExprRes},
8};
9use flux_syntax::{
10    surface::{self, Ident, NodeId, visit::Visitor as _},
11    walk_list,
12};
13use rustc_data_structures::{
14    fx::{FxIndexMap, FxIndexSet, IndexEntry},
15    unord::UnordMap,
16};
17use rustc_hash::FxHashMap;
18use rustc_hir::def::{
19    DefKind,
20    Namespace::{TypeNS, ValueNS},
21};
22use rustc_middle::ty::TyCtxt;
23use rustc_span::{ErrorGuaranteed, Symbol, sym};
24
25use super::{CrateResolver, Segment};
26
27type Result<T = ()> = std::result::Result<T, ErrorGuaranteed>;
28
/// The kind of syntactic construct that opened a resolution scope.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub(crate) enum ScopeKind {
    FnInput,
    FnOutput,
    Variant,
    Misc,
    FnTraitInput,
}

impl ScopeKind {
    /// Returns `true` for scopes that end the outward search for a parameter: bindings of a
    /// barrier scope itself are still visible, but nothing outside of it is.
    fn is_barrier(self) -> bool {
        match self {
            ScopeKind::FnInput | ScopeKind::Variant => true,
            ScopeKind::FnOutput | ScopeKind::Misc | ScopeKind::FnTraitInput => false,
        }
    }
}
43
44/// Parameters used during gathering.
45#[derive(Debug, Clone, Copy)]
46struct ParamRes(fhir::ParamKind, NodeId);
47
48impl ParamRes {
49    fn kind(self) -> fhir::ParamKind {
50        self.0
51    }
52
53    fn param_id(self) -> NodeId {
54        self.1
55    }
56}
57
58pub(crate) trait ScopedVisitor: Sized {
59    fn is_box(&self, segment: &surface::PathSegment) -> bool;
60    fn enter_scope(&mut self, kind: ScopeKind) -> ControlFlow<()>;
61    fn exit_scope(&mut self) {}
62
63    fn wrap(self) -> ScopedVisitorWrapper<Self> {
64        ScopedVisitorWrapper(self)
65    }
66
67    fn on_implicit_param(&mut self, _ident: Ident, _kind: fhir::ParamKind, _node_id: NodeId) {}
68    fn on_generic_param(&mut self, _param: &surface::GenericParam) {}
69    fn on_refine_param(&mut self, _param: &surface::RefineParam) {}
70    fn on_enum_variant(&mut self, _variant: &surface::VariantDef) {}
71    fn on_fn_trait_input(&mut self, _in_arg: &surface::GenericArg, _node_id: NodeId) {}
72    fn on_fn_sig(&mut self, _fn_sig: &surface::FnSig) {}
73    fn on_fn_output(&mut self, _output: &surface::FnOutput) {}
74    fn on_loc(&mut self, _loc: Ident, _node_id: NodeId) {}
75    fn on_path(&mut self, _path: &surface::ExprPath) {}
76    fn on_base_sort(&mut self, _sort: &surface::BaseSort) {}
77}
78
79pub(crate) struct ScopedVisitorWrapper<V>(V);
80
81impl<V: ScopedVisitor> ScopedVisitorWrapper<V> {
82    fn with_scope(&mut self, kind: ScopeKind, f: impl FnOnce(&mut Self)) {
83        let scope = self.0.enter_scope(kind);
84        if let ControlFlow::Continue(_) = scope {
85            f(self);
86            self.0.exit_scope();
87        }
88    }
89}
90
91impl<V> std::ops::Deref for ScopedVisitorWrapper<V> {
92    type Target = V;
93
94    fn deref(&self) -> &Self::Target {
95        &self.0
96    }
97}
98impl<V> std::ops::DerefMut for ScopedVisitorWrapper<V> {
99    fn deref_mut(&mut self) -> &mut Self::Target {
100        &mut self.0
101    }
102}
103
104impl<V: ScopedVisitor> surface::visit::Visitor for ScopedVisitorWrapper<V> {
105    fn visit_trait_assoc_reft(&mut self, assoc_reft: &surface::TraitAssocReft) {
106        self.with_scope(ScopeKind::Misc, |this| {
107            surface::visit::walk_trait_assoc_reft(this, assoc_reft);
108        });
109    }
110
111    fn visit_impl_assoc_reft(&mut self, assoc_reft: &surface::ImplAssocReft) {
112        self.with_scope(ScopeKind::Misc, |this| {
113            surface::visit::walk_impl_assoc_reft(this, assoc_reft);
114        });
115    }
116
117    fn visit_qualifier(&mut self, qualifier: &surface::Qualifier) {
118        self.with_scope(ScopeKind::Misc, |this| {
119            surface::visit::walk_qualifier(this, qualifier);
120        });
121    }
122
123    fn visit_defn(&mut self, defn: &surface::SpecFunc) {
124        self.with_scope(ScopeKind::Misc, |this| {
125            surface::visit::walk_defn(this, defn);
126        });
127    }
128
129    fn visit_generic_param(&mut self, param: &surface::GenericParam) {
130        self.on_generic_param(param);
131        surface::visit::walk_generic_param(self, param);
132    }
133
134    fn visit_refine_param(&mut self, param: &surface::RefineParam) {
135        self.on_refine_param(param);
136        surface::visit::walk_refine_param(self, param);
137    }
138
139    fn visit_ty_alias(&mut self, ty_alias: &surface::TyAlias) {
140        self.with_scope(ScopeKind::Misc, |this| {
141            surface::visit::walk_ty_alias(this, ty_alias);
142        });
143    }
144
145    fn visit_struct_def(&mut self, struct_def: &surface::StructDef) {
146        self.with_scope(ScopeKind::Misc, |this| {
147            surface::visit::walk_struct_def(this, struct_def);
148        });
149    }
150
151    fn visit_enum_def(&mut self, enum_def: &surface::EnumDef) {
152        self.with_scope(ScopeKind::Misc, |this| {
153            surface::visit::walk_enum_def(this, enum_def);
154        });
155    }
156
157    fn visit_variant(&mut self, variant: &surface::VariantDef) {
158        self.with_scope(ScopeKind::Variant, |this| {
159            this.on_enum_variant(variant);
160            surface::visit::walk_variant(this, variant);
161        });
162    }
163
164    fn visit_trait_ref(&mut self, trait_ref: &surface::TraitRef) {
165        match trait_ref.as_fn_trait_ref() {
166            Some((in_arg, out_arg)) => {
167                self.with_scope(ScopeKind::FnTraitInput, |this| {
168                    this.on_fn_trait_input(in_arg, trait_ref.node_id);
169                    surface::visit::walk_generic_arg(this, in_arg);
170                    this.with_scope(ScopeKind::Misc, |this| {
171                        surface::visit::walk_generic_arg(this, out_arg);
172                    });
173                });
174            }
175            None => {
176                self.with_scope(ScopeKind::Misc, |this| {
177                    surface::visit::walk_trait_ref(this, trait_ref);
178                });
179            }
180        }
181    }
182
183    fn visit_variant_ret(&mut self, ret: &surface::VariantRet) {
184        self.with_scope(ScopeKind::Misc, |this| {
185            surface::visit::walk_variant_ret(this, ret);
186        });
187    }
188
189    fn visit_generics(&mut self, generics: &surface::Generics) {
190        self.with_scope(ScopeKind::Misc, |this| {
191            surface::visit::walk_generics(this, generics);
192        });
193    }
194
195    fn visit_fn_sig(&mut self, fn_sig: &surface::FnSig) {
196        self.with_scope(ScopeKind::FnInput, |this| {
197            this.on_fn_sig(fn_sig);
198            surface::visit::walk_fn_sig(this, fn_sig);
199        });
200    }
201
202    fn visit_fn_output(&mut self, output: &surface::FnOutput) {
203        self.with_scope(ScopeKind::FnOutput, |this| {
204            this.on_fn_output(output);
205            surface::visit::walk_fn_output(this, output);
206        });
207    }
208
209    fn visit_fn_input(&mut self, arg: &surface::FnInput) {
210        match arg {
211            surface::FnInput::Constr(bind, _, _, node_id) => {
212                self.on_implicit_param(*bind, fhir::ParamKind::Colon, *node_id);
213            }
214            surface::FnInput::StrgRef(loc, _, node_id) => {
215                self.on_implicit_param(*loc, fhir::ParamKind::Loc, *node_id);
216            }
217            surface::FnInput::Ty(bind, ty, node_id) => {
218                if let &Some(bind) = bind {
219                    let param_kind = if let surface::TyKind::Base(_) = &ty.kind {
220                        fhir::ParamKind::Colon
221                    } else {
222                        fhir::ParamKind::Error
223                    };
224                    self.on_implicit_param(bind, param_kind, *node_id);
225                }
226            }
227        }
228        surface::visit::walk_fn_input(self, arg);
229    }
230
231    fn visit_ensures(&mut self, constraint: &surface::Ensures) {
232        if let surface::Ensures::Type(loc, _, node_id) = constraint {
233            self.on_loc(*loc, *node_id);
234        }
235        surface::visit::walk_ensures(self, constraint);
236    }
237
238    fn visit_refine_arg(&mut self, arg: &surface::RefineArg) {
239        match arg {
240            surface::RefineArg::Bind(ident, kind, _, node_id) => {
241                let kind = match kind {
242                    surface::BindKind::At => fhir::ParamKind::At,
243                    surface::BindKind::Pound => fhir::ParamKind::Pound,
244                };
245                self.on_implicit_param(*ident, kind, *node_id);
246            }
247            surface::RefineArg::Abs(..) => {
248                self.with_scope(ScopeKind::Misc, |this| {
249                    surface::visit::walk_refine_arg(this, arg);
250                });
251            }
252            surface::RefineArg::Expr(expr) => self.visit_expr(expr),
253        }
254    }
255
256    fn visit_path(&mut self, path: &surface::Path) {
257        for arg in &path.refine {
258            self.with_scope(ScopeKind::Misc, |this| this.visit_refine_arg(arg));
259        }
260        walk_list!(self, visit_path_segment, &path.segments);
261    }
262
263    fn visit_path_segment(&mut self, segment: &surface::PathSegment) {
264        let is_box = self.is_box(segment);
265        for (i, arg) in segment.args.iter().enumerate() {
266            if is_box && i == 0 {
267                self.visit_generic_arg(arg);
268            } else {
269                self.with_scope(ScopeKind::Misc, |this| this.visit_generic_arg(arg));
270            }
271        }
272    }
273
274    fn visit_ty(&mut self, ty: &surface::Ty) {
275        let node_id = ty.node_id;
276        match &ty.kind {
277            surface::TyKind::Exists { bind, .. } => {
278                self.with_scope(ScopeKind::Misc, |this| {
279                    let param = surface::RefineParam {
280                        ident: *bind,
281                        mode: None,
282                        sort: surface::Sort::Infer,
283                        node_id,
284                        span: bind.span,
285                    };
286                    this.on_refine_param(&param);
287                    surface::visit::walk_ty(this, ty);
288                });
289            }
290            surface::TyKind::GeneralExists { .. } => {
291                self.with_scope(ScopeKind::Misc, |this| {
292                    surface::visit::walk_ty(this, ty);
293                });
294            }
295            surface::TyKind::Array(..) => {
296                self.with_scope(ScopeKind::Misc, |this| {
297                    surface::visit::walk_ty(this, ty);
298                });
299            }
300            _ => surface::visit::walk_ty(self, ty),
301        }
302    }
303
304    fn visit_bty(&mut self, bty: &surface::BaseTy) {
305        match &bty.kind {
306            surface::BaseTyKind::Slice(_) => {
307                self.with_scope(ScopeKind::Misc, |this| {
308                    surface::visit::walk_bty(this, bty);
309                });
310            }
311            surface::BaseTyKind::Path(..) => {
312                surface::visit::walk_bty(self, bty);
313            }
314        }
315    }
316
317    fn visit_path_expr(&mut self, path: &surface::ExprPath) {
318        self.on_path(path);
319    }
320
321    fn visit_base_sort(&mut self, bsort: &surface::BaseSort) {
322        self.on_base_sort(bsort);
323        surface::visit::walk_base_sort(self, bsort);
324    }
325}
326
327struct ImplicitParamCollector<'a, 'tcx> {
328    tcx: TyCtxt<'tcx>,
329    path_res_map: &'a UnordMap<surface::NodeId, fhir::PartialRes>,
330    kind: ScopeKind,
331    params: Vec<(Ident, fhir::ParamKind, NodeId)>,
332}
333
334impl<'a, 'tcx> ImplicitParamCollector<'a, 'tcx> {
335    fn new(
336        tcx: TyCtxt<'tcx>,
337        path_res_map: &'a UnordMap<surface::NodeId, fhir::PartialRes>,
338        kind: ScopeKind,
339    ) -> Self {
340        Self { tcx, path_res_map, kind, params: vec![] }
341    }
342
343    fn run(
344        self,
345        f: impl FnOnce(&mut ScopedVisitorWrapper<Self>),
346    ) -> Vec<(Ident, fhir::ParamKind, NodeId)> {
347        let mut wrapped = self.wrap();
348        f(&mut wrapped);
349        wrapped.0.params
350    }
351}
352
353impl ScopedVisitor for ImplicitParamCollector<'_, '_> {
354    fn is_box(&self, segment: &surface::PathSegment) -> bool {
355        self.path_res_map
356            .get(&segment.node_id)
357            .map(|r| r.is_box(self.tcx))
358            .unwrap_or(false)
359    }
360
361    fn enter_scope(&mut self, kind: ScopeKind) -> ControlFlow<()> {
362        if self.kind == kind { ControlFlow::Continue(()) } else { ControlFlow::Break(()) }
363    }
364
365    fn on_implicit_param(&mut self, ident: Ident, param: fhir::ParamKind, node_id: NodeId) {
366        self.params.push((ident, param, node_id));
367    }
368}
369
370struct Scope {
371    kind: ScopeKind,
372    bindings: FxIndexMap<Ident, ParamRes>,
373}
374
375impl Scope {
376    fn new(kind: ScopeKind) -> Self {
377        Self { kind, bindings: Default::default() }
378    }
379}
380
381#[derive(Clone, Copy)]
382struct ParamDef {
383    ident: Ident,
384    kind: fhir::ParamKind,
385    scope: Option<NodeId>,
386}
387
388pub(crate) struct RefinementResolver<'a, 'genv, 'tcx> {
389    scopes: Vec<Scope>,
390    sort_params: FxIndexSet<Symbol>,
391    param_defs: FxIndexMap<NodeId, ParamDef>,
392    resolver: &'a mut CrateResolver<'genv, 'tcx>,
393    path_res_map: FxHashMap<NodeId, ExprRes<NodeId>>,
394    errors: Errors<'genv>,
395}
396
397impl<'a, 'genv, 'tcx> RefinementResolver<'a, 'genv, 'tcx> {
398    pub(crate) fn for_flux_item(
399        resolver: &'a mut CrateResolver<'genv, 'tcx>,
400        sort_params: &[Ident],
401    ) -> Self {
402        Self::new(resolver, sort_params.iter().map(|ident| ident.name).collect())
403    }
404
405    pub(crate) fn for_rust_item(resolver: &'a mut CrateResolver<'genv, 'tcx>) -> Self {
406        Self::new(resolver, Default::default())
407    }
408
409    pub(crate) fn resolve_qualifier(
410        resolver: &'a mut CrateResolver<'genv, 'tcx>,
411        qualifier: &surface::Qualifier,
412    ) -> Result {
413        Self::for_flux_item(resolver, &[]).run(|r| r.visit_qualifier(qualifier))
414    }
415
416    pub(crate) fn resolve_defn(
417        resolver: &'a mut CrateResolver<'genv, 'tcx>,
418        defn: &surface::SpecFunc,
419    ) -> Result {
420        Self::for_flux_item(resolver, &defn.sort_vars).run(|r| r.visit_defn(defn))
421    }
422
423    pub(crate) fn resolve_fn_sig(
424        resolver: &'a mut CrateResolver<'genv, 'tcx>,
425        fn_sig: &surface::FnSig,
426    ) -> Result {
427        IllegalBinderVisitor::new(resolver).run(|vis| vis.visit_fn_sig(fn_sig))?;
428        Self::for_rust_item(resolver).run(|vis| vis.visit_fn_sig(fn_sig))
429    }
430
431    pub(crate) fn resolve_struct_def(
432        resolver: &'a mut CrateResolver<'genv, 'tcx>,
433        struct_def: &surface::StructDef,
434    ) -> Result {
435        IllegalBinderVisitor::new(resolver).run(|vis| vis.visit_struct_def(struct_def))?;
436        Self::for_rust_item(resolver).run(|vis| vis.visit_struct_def(struct_def))
437    }
438
439    pub(crate) fn resolve_enum_def(
440        resolver: &'a mut CrateResolver<'genv, 'tcx>,
441        enum_def: &surface::EnumDef,
442    ) -> Result {
443        IllegalBinderVisitor::new(resolver).run(|vis| vis.visit_enum_def(enum_def))?;
444        Self::for_rust_item(resolver).run(|vis| vis.visit_enum_def(enum_def))
445    }
446
447    pub(crate) fn resolve_constant(
448        resolver: &'a mut CrateResolver<'genv, 'tcx>,
449        constant_info: &surface::ConstantInfo,
450    ) -> Result {
451        IllegalBinderVisitor::new(resolver).run(|vis| vis.visit_constant(constant_info))?;
452        Self::for_rust_item(resolver).run(|vis| vis.visit_constant(constant_info))
453    }
454
455    pub(crate) fn resolve_ty_alias(
456        resolver: &'a mut CrateResolver<'genv, 'tcx>,
457        ty_alias: &surface::TyAlias,
458    ) -> Result {
459        IllegalBinderVisitor::new(resolver).run(|vis| vis.visit_ty_alias(ty_alias))?;
460        Self::for_rust_item(resolver).run(|vis| vis.visit_ty_alias(ty_alias))
461    }
462
463    pub(crate) fn resolve_impl(
464        resolver: &'a mut CrateResolver<'genv, 'tcx>,
465        impl_: &surface::Impl,
466    ) -> Result {
467        IllegalBinderVisitor::new(resolver).run(|vis| vis.visit_impl(impl_))?;
468        Self::for_rust_item(resolver).run(|vis| vis.visit_impl(impl_))
469    }
470
471    pub(crate) fn resolve_trait(
472        resolver: &'a mut CrateResolver<'genv, 'tcx>,
473        trait_: &surface::Trait,
474    ) -> Result {
475        IllegalBinderVisitor::new(resolver).run(|vis| vis.visit_trait(trait_))?;
476        Self::for_rust_item(resolver).run(|vis| vis.visit_trait(trait_))
477    }
478
479    fn new(resolver: &'a mut CrateResolver<'genv, 'tcx>, sort_params: FxIndexSet<Symbol>) -> Self {
480        let errors = Errors::new(resolver.genv.sess());
481        Self {
482            resolver,
483            sort_params,
484            param_defs: Default::default(),
485            scopes: Default::default(),
486            path_res_map: Default::default(),
487            errors,
488        }
489    }
490
491    fn run(self, f: impl FnOnce(&mut ScopedVisitorWrapper<Self>)) -> Result {
492        let mut wrapper = self.wrap();
493        f(&mut wrapper);
494        wrapper.0.finish()
495    }
496
497    fn define_param(
498        &mut self,
499        ident: Ident,
500        kind: fhir::ParamKind,
501        param_id: NodeId,
502        scope: Option<NodeId>,
503    ) {
504        self.param_defs
505            .insert(param_id, ParamDef { ident, kind, scope });
506
507        let scope = self.scopes.last_mut().unwrap();
508        match scope.bindings.entry(ident) {
509            IndexEntry::Occupied(entry) => {
510                let param_def = self.param_defs[&entry.get().param_id()];
511                self.errors
512                    .emit(errors::DuplicateParam::new(param_def.ident, ident));
513            }
514            IndexEntry::Vacant(entry) => {
515                entry.insert(ParamRes(kind, param_id));
516            }
517        }
518    }
519
520    fn find(&mut self, ident: Ident) -> Option<ParamRes> {
521        for scope in self.scopes.iter().rev() {
522            if let Some(res) = scope.bindings.get(&ident) {
523                return Some(*res);
524            }
525
526            if scope.kind.is_barrier() {
527                return None;
528            }
529        }
530        None
531    }
532
533    fn try_resolve_enum_variant(&mut self, typ: Ident, variant: Ident) -> Option<ExprRes<NodeId>> {
534        if let fhir::Res::Def(_, enum_def_id) =
535            self.resolver.resolve_ident_with_ribs(typ, TypeNS)?
536        {
537            let enum_variants = self.resolver.enum_variants.get(&enum_def_id)?;
538            let variant_def_id = enum_variants.variants.get(&variant.name)?;
539            return Some(ExprRes::Variant(*variant_def_id));
540        }
541        None
542    }
543
544    fn resolve_path(&mut self, path: &surface::ExprPath) {
545        if let [segment] = &path.segments[..]
546            && let Some(res) = self.try_resolve_param(segment.ident)
547        {
548            self.path_res_map.insert(path.node_id, res);
549            return;
550        }
551        if let Some(res) = self.try_resolve_expr_with_ribs(&path.segments) {
552            self.path_res_map.insert(path.node_id, res);
553            return;
554        }
555        // TODO(nilehmann) move this to resolve_with_ribs
556        if let [typ, name] = &path.segments[..]
557            && let Some(res) = resolve_num_const(typ.ident, name.ident)
558        {
559            self.path_res_map.insert(path.node_id, res);
560            return;
561        }
562        if let [typ, name] = &path.segments[..]
563            && let Some(res) = self.try_resolve_enum_variant(typ.ident, name.ident)
564        {
565            self.path_res_map.insert(path.node_id, res);
566            return;
567        }
568        if let [segment] = &path.segments[..]
569            && let Some(res) = self.try_resolve_global_func(segment.ident)
570        {
571            self.path_res_map.insert(path.node_id, res);
572            return;
573        }
574
575        self.errors.emit(errors::UnresolvedVar::from_path(path));
576    }
577
578    fn resolve_ident(&mut self, ident: Ident, node_id: NodeId) {
579        if let Some(res) = self.try_resolve_param(ident) {
580            self.path_res_map.insert(node_id, res);
581            return;
582        }
583        if let Some(res) = self.try_resolve_expr_with_ribs(&[ident]) {
584            self.path_res_map.insert(node_id, res);
585            return;
586        }
587        if let Some(res) = self.try_resolve_global_func(ident) {
588            self.path_res_map.insert(node_id, res);
589            return;
590        }
591        self.errors.emit(errors::UnresolvedVar::from_ident(ident));
592    }
593
594    fn try_resolve_expr_with_ribs<S: Segment>(
595        &mut self,
596        segments: &[S],
597    ) -> Option<ExprRes<NodeId>> {
598        let path = self.resolver.resolve_path_with_ribs(segments, ValueNS);
599
600        let res = match path {
601            Some(r) => r.full_res()?,
602            _ => {
603                self.resolver
604                    .resolve_path_with_ribs(segments, TypeNS)?
605                    .full_res()?
606            }
607        };
608        match res {
609            fhir::Res::Def(DefKind::ConstParam, def_id) => Some(ExprRes::ConstGeneric(def_id)),
610            fhir::Res::Def(DefKind::Const, def_id) => Some(ExprRes::Const(def_id)),
611            fhir::Res::Def(DefKind::Struct | DefKind::Enum, def_id) => Some(ExprRes::Ctor(def_id)),
612            fhir::Res::Def(DefKind::Variant, def_id) => Some(ExprRes::Variant(def_id)),
613            _ => None,
614        }
615    }
616
617    fn try_resolve_param(&mut self, ident: Ident) -> Option<ExprRes<NodeId>> {
618        let res = self.find(ident)?;
619
620        if let fhir::ParamKind::Error = res.kind() {
621            self.errors.emit(errors::InvalidUnrefinedParam::new(ident));
622        }
623        Some(ExprRes::Param(res.kind(), res.param_id()))
624    }
625
626    fn try_resolve_global_func(&mut self, ident: Ident) -> Option<ExprRes<NodeId>> {
627        let kind = self.resolver.func_decls.get(&ident.name)?;
628        Some(ExprRes::GlobalFunc(*kind))
629    }
630
631    fn resolve_sort_path(&mut self, path: &surface::SortPath) {
632        let res = self
633            .try_resolve_sort_param(path)
634            .or_else(|| self.try_resolve_sort_with_ribs(path))
635            .or_else(|| self.try_resolve_user_sort(path))
636            .or_else(|| self.try_resolve_prim_sort(path));
637
638        if let Some(res) = res {
639            self.resolver
640                .output
641                .sort_path_res_map
642                .insert(path.node_id, res);
643        } else {
644            self.errors.emit(errors::UnresolvedSort::new(path));
645        }
646    }
647
648    fn try_resolve_sort_param(&self, path: &surface::SortPath) -> Option<fhir::SortRes> {
649        let [segment] = &path.segments[..] else { return None };
650        self.sort_params
651            .get_index_of(&segment.name)
652            .map(fhir::SortRes::SortParam)
653    }
654
655    fn try_resolve_sort_with_ribs(&mut self, path: &surface::SortPath) -> Option<fhir::SortRes> {
656        let partial_res = self
657            .resolver
658            .resolve_path_with_ribs(&path.segments, TypeNS)?;
659        match (partial_res.base_res(), partial_res.unresolved_segments()) {
660            (fhir::Res::Def(DefKind::Struct | DefKind::Enum, def_id), 0) => {
661                Some(fhir::SortRes::Adt(def_id))
662            }
663            (fhir::Res::Def(DefKind::TyParam, def_id), 0) => Some(fhir::SortRes::TyParam(def_id)),
664            (fhir::Res::SelfTyParam { trait_ }, 0) => {
665                Some(fhir::SortRes::SelfParam { trait_id: trait_ })
666            }
667            (fhir::Res::SelfTyParam { trait_ }, 1) => {
668                let ident = *path.segments.last().unwrap();
669                Some(fhir::SortRes::SelfParamAssoc { trait_id: trait_, ident })
670            }
671            (fhir::Res::SelfTyAlias { alias_to, .. }, 0) => {
672                Some(fhir::SortRes::SelfAlias { alias_to })
673            }
674            _ => None,
675        }
676    }
677
678    fn try_resolve_user_sort(&self, path: &surface::SortPath) -> Option<fhir::SortRes> {
679        let [segment] = &path.segments[..] else { return None };
680        self.resolver
681            .sort_decls
682            .get(&segment.name)
683            .map(|decl| fhir::SortRes::User { name: decl.name })
684    }
685
686    fn try_resolve_prim_sort(&self, path: &surface::SortPath) -> Option<fhir::SortRes> {
687        let [segment] = &path.segments[..] else { return None };
688        if segment.name == SORTS.int {
689            Some(fhir::SortRes::PrimSort(fhir::PrimSort::Int))
690        } else if segment.name == sym::bool {
691            Some(fhir::SortRes::PrimSort(fhir::PrimSort::Bool))
692        } else if segment.name == SORTS.real {
693            Some(fhir::SortRes::PrimSort(fhir::PrimSort::Real))
694        } else if segment.name == SORTS.set {
695            Some(fhir::SortRes::PrimSort(fhir::PrimSort::Set))
696        } else if segment.name == SORTS.map {
697            Some(fhir::SortRes::PrimSort(fhir::PrimSort::Map))
698        } else {
699            None
700        }
701    }
702
703    pub(crate) fn finish(self) -> Result {
704        let param_id_gen = IndexGen::new();
705        let mut params = FxIndexMap::default();
706
707        // Create an `fhir::ParamId` for all parameters used in a path before iterating over
708        // `param_defs` such that we can skip `fhir::ParamKind::Colon` if the param wasn't used
709        for (node_id, res) in self.path_res_map {
710            let res = res.map_param_id(|param_id| {
711                *params
712                    .entry(param_id)
713                    .or_insert_with(|| param_id_gen.fresh())
714            });
715            self.resolver.output.expr_path_res_map.insert(node_id, res);
716        }
717
718        // At this point, the `params` map contains all parameters that were used in an expression,
719        // so we can safely skip `ParamKind::Colon` if there's no entry for it.
720        for (param_id, param_def) in self.param_defs {
721            let name = match param_def.kind {
722                fhir::ParamKind::Colon => {
723                    let Some(name) = params.get(&param_id) else { continue };
724                    *name
725                }
726                fhir::ParamKind::Error => continue,
727                _ => {
728                    params
729                        .get(&param_id)
730                        .copied()
731                        .unwrap_or_else(|| param_id_gen.fresh())
732                }
733            };
734            let output = &mut self.resolver.output;
735            output
736                .param_res_map
737                .insert(param_id, (name, param_def.kind));
738
739            if let Some(scope) = param_def.scope {
740                output
741                    .implicit_params
742                    .entry(scope)
743                    .or_default()
744                    .push((param_def.ident, param_id));
745            }
746        }
747        self.errors.into_result()
748    }
749
750    fn resolver_output(&self) -> &ResolverOutput {
751        &self.resolver.output
752    }
753}
754
755impl ScopedVisitor for RefinementResolver<'_, '_, '_> {
756    fn is_box(&self, segment: &surface::PathSegment) -> bool {
757        self.resolver_output()
758            .path_res_map
759            .get(&segment.node_id)
760            .map(|r| r.is_box(self.resolver.genv.tcx()))
761            .unwrap_or(false)
762    }
763
764    fn enter_scope(&mut self, kind: ScopeKind) -> ControlFlow<()> {
765        self.scopes.push(Scope::new(kind));
766        ControlFlow::Continue(())
767    }
768
769    fn exit_scope(&mut self) {
770        self.scopes.pop();
771    }
772
773    fn on_fn_trait_input(&mut self, in_arg: &surface::GenericArg, trait_node_id: NodeId) {
774        let params = ImplicitParamCollector::new(
775            self.resolver.genv.tcx(),
776            &self.resolver.output.path_res_map,
777            ScopeKind::FnTraitInput,
778        )
779        .run(|vis| vis.visit_generic_arg(in_arg));
780        for (ident, kind, node_id) in params {
781            self.define_param(ident, kind, node_id, Some(trait_node_id));
782        }
783    }
784
785    fn on_enum_variant(&mut self, variant: &surface::VariantDef) {
786        let params = ImplicitParamCollector::new(
787            self.resolver.genv.tcx(),
788            &self.resolver.output.path_res_map,
789            ScopeKind::Variant,
790        )
791        .run(|vis| vis.visit_variant(variant));
792        for (ident, kind, node_id) in params {
793            self.define_param(ident, kind, node_id, Some(variant.node_id));
794        }
795    }
796
797    fn on_fn_sig(&mut self, fn_sig: &surface::FnSig) {
798        let params = ImplicitParamCollector::new(
799            self.resolver.genv.tcx(),
800            &self.resolver.output.path_res_map,
801            ScopeKind::FnInput,
802        )
803        .run(|vis| vis.visit_fn_sig(fn_sig));
804        for (ident, kind, param_id) in params {
805            self.define_param(ident, kind, param_id, Some(fn_sig.node_id));
806        }
807    }
808
809    fn on_fn_output(&mut self, output: &surface::FnOutput) {
810        let params = ImplicitParamCollector::new(
811            self.resolver.genv.tcx(),
812            &self.resolver.output.path_res_map,
813            ScopeKind::FnOutput,
814        )
815        .run(|vis| vis.visit_fn_output(output));
816        for (ident, kind, param_id) in params {
817            self.define_param(ident, kind, param_id, Some(output.node_id));
818        }
819    }
820
821    fn on_refine_param(&mut self, param: &surface::RefineParam) {
822        self.define_param(param.ident, fhir::ParamKind::Explicit(param.mode), param.node_id, None);
823    }
824
825    fn on_loc(&mut self, loc: Ident, node_id: NodeId) {
826        self.resolve_ident(loc, node_id);
827    }
828
829    fn on_path(&mut self, path: &surface::ExprPath) {
830        self.resolve_path(path);
831    }
832
833    fn on_base_sort(&mut self, sort: &surface::BaseSort) {
834        match sort {
835            surface::BaseSort::Path(path) => {
836                self.resolve_sort_path(path);
837            }
838            surface::BaseSort::BitVec(_) => {}
839            surface::BaseSort::SortOf(..) => {}
840        }
841    }
842}
843
844macro_rules! define_resolve_num_const {
845    ($($typ:ident),*) => {
846        fn resolve_num_const(typ: surface::Ident, name: surface::Ident) -> Option<ExprRes<NodeId>> {
847            match typ.name.as_str() {
848                $(
849                    stringify!($typ) => {
850                        match name.name.as_str() {
851                            "MAX" => Some(ExprRes::NumConst($typ::MAX.try_into().unwrap())),
852                            "MIN" => Some(ExprRes::NumConst($typ::MIN.try_into().unwrap())),
853                            _ => None,
854                        }
855                    },
856                )*
857                _ => None
858            }
859        }
860    };
861}
862
863define_resolve_num_const!(i8, i16, i32, i64, isize, u8, u16, u32, u64, usize);
864
865pub(crate) struct Sorts {
866    pub int: Symbol,
867    pub real: Symbol,
868    pub set: Symbol,
869    pub map: Symbol,
870}
871
872pub(crate) static SORTS: std::sync::LazyLock<Sorts> = std::sync::LazyLock::new(|| {
873    Sorts {
874        int: Symbol::intern("int"),
875        real: Symbol::intern("real"),
876        set: Symbol::intern("Set"),
877        map: Symbol::intern("Map"),
878    }
879});
880
881struct IllegalBinderVisitor<'a, 'genv, 'tcx> {
882    scopes: Vec<ScopeKind>,
883    resolver: &'a CrateResolver<'genv, 'tcx>,
884    errors: Errors<'genv>,
885}
886
887impl<'a, 'genv, 'tcx> IllegalBinderVisitor<'a, 'genv, 'tcx> {
888    fn new(resolver: &'a mut CrateResolver<'genv, 'tcx>) -> Self {
889        let errors = Errors::new(resolver.genv.sess());
890        Self { scopes: vec![], resolver, errors }
891    }
892
893    fn run(self, f: impl FnOnce(&mut ScopedVisitorWrapper<Self>)) -> Result {
894        let mut vis = self.wrap();
895        f(&mut vis);
896        vis.0.errors.into_result()
897    }
898}
899
900impl ScopedVisitor for IllegalBinderVisitor<'_, '_, '_> {
901    fn is_box(&self, segment: &surface::PathSegment) -> bool {
902        self.resolver
903            .output
904            .path_res_map
905            .get(&segment.node_id)
906            .map(|r| r.is_box(self.resolver.genv.tcx()))
907            .unwrap_or(false)
908    }
909
910    fn enter_scope(&mut self, kind: ScopeKind) -> ControlFlow<()> {
911        self.scopes.push(kind);
912        ControlFlow::Continue(())
913    }
914
915    fn exit_scope(&mut self) {
916        self.scopes.pop();
917    }
918
919    fn on_implicit_param(&mut self, ident: Ident, param_kind: fhir::ParamKind, _: NodeId) {
920        let Some(scope_kind) = self.scopes.last() else { return };
921        let (allowed, bind_kind) = match param_kind {
922            fhir::ParamKind::At => {
923                (
924                    matches!(
925                        scope_kind,
926                        ScopeKind::FnInput | ScopeKind::FnTraitInput | ScopeKind::Variant
927                    ),
928                    surface::BindKind::At,
929                )
930            }
931            fhir::ParamKind::Pound => {
932                (matches!(scope_kind, ScopeKind::FnOutput), surface::BindKind::Pound)
933            }
934            fhir::ParamKind::Colon
935            | fhir::ParamKind::Loc
936            | fhir::ParamKind::Error
937            | fhir::ParamKind::Explicit(..) => return,
938        };
939        if !allowed {
940            self.errors
941                .emit(errors::IllegalBinder::new(ident.span, bind_kind));
942        }
943    }
944}
945
946mod errors {
947    use flux_errors::E0999;
948    use flux_macros::Diagnostic;
949    use flux_syntax::surface;
950    use itertools::Itertools;
951    use rustc_span::{Span, Symbol, symbol::Ident};
952
953    #[derive(Diagnostic)]
954    #[diag(desugar_duplicate_param, code = E0999)]
955    pub(super) struct DuplicateParam {
956        #[primary_span]
957        #[label]
958        span: Span,
959        name: Symbol,
960        #[label(desugar_first_use)]
961        first_use: Span,
962    }
963
964    impl DuplicateParam {
965        pub(super) fn new(old_ident: Ident, new_ident: Ident) -> Self {
966            debug_assert_eq!(old_ident.name, new_ident.name);
967            Self { span: new_ident.span, name: new_ident.name, first_use: old_ident.span }
968        }
969    }
970
971    #[derive(Diagnostic)]
972    #[diag(desugar_unresolved_sort, code = E0999)]
973    pub(super) struct UnresolvedSort {
974        #[primary_span]
975        #[label]
976        span: Span,
977        name: String,
978    }
979
980    impl UnresolvedSort {
981        pub(super) fn new(path: &surface::SortPath) -> Self {
982            Self {
983                span: path
984                    .segments
985                    .iter()
986                    .map(|ident| ident.span)
987                    .reduce(Span::to)
988                    .unwrap_or_default(),
989                name: format!("{}", path.segments.iter().format("::")),
990            }
991        }
992    }
993
994    #[derive(Diagnostic)]
995    #[diag(desugar_unresolved_var, code = E0999)]
996    pub(super) struct UnresolvedVar {
997        #[primary_span]
998        #[label]
999        span: Span,
1000        var: String,
1001    }
1002
1003    impl UnresolvedVar {
1004        pub(super) fn from_path(path: &surface::ExprPath) -> Self {
1005            Self {
1006                span: path.span,
1007                var: format!(
1008                    "{}",
1009                    path.segments
1010                        .iter()
1011                        .format_with("::", |s, f| f(&s.ident.name))
1012                ),
1013            }
1014        }
1015
1016        pub(super) fn from_ident(ident: Ident) -> Self {
1017            Self { span: ident.span, var: format!("{ident}") }
1018        }
1019    }
1020
1021    #[derive(Diagnostic)]
1022    #[diag(desugar_invalid_unrefined_param, code = E0999)]
1023    pub(super) struct InvalidUnrefinedParam {
1024        #[primary_span]
1025        #[label]
1026        span: Span,
1027        var: Ident,
1028    }
1029
1030    impl InvalidUnrefinedParam {
1031        pub(super) fn new(var: Ident) -> Self {
1032            Self { var, span: var.span }
1033        }
1034    }
1035
1036    #[derive(Diagnostic)]
1037    #[diag(desugar_illegal_binder, code = E0999)]
1038    pub(super) struct IllegalBinder {
1039        #[primary_span]
1040        #[label]
1041        span: Span,
1042        kind: &'static str,
1043    }
1044
1045    impl IllegalBinder {
1046        pub(super) fn new(span: Span, kind: surface::BindKind) -> Self {
1047            Self { span, kind: kind.token_str() }
1048        }
1049    }
1050}