flux_infer/
projections.rs

1use std::iter;
2
3use flux_common::{bug, iter::IterExt, tracked_span_bug};
4use flux_middle::{
5    global_env::GlobalEnv,
6    queries::{QueryErr, QueryResult},
7    query_bug,
8    rty::{
9        self, AliasKind, AliasReft, AliasTy, BaseTy, Binder, Clause, ClauseKind, Const, ConstKind,
10        EarlyBinder, Expr, ExprKind, GenericArg, List, ProjectionPredicate, RefineArgs, Region,
11        Sort, SubsetTy, SubsetTyCtor, Ty, TyKind,
12        fold::{FallibleTypeFolder, TypeFoldable, TypeSuperFoldable, TypeVisitable},
13        refining::Refiner,
14        subst::{GenericsSubstDelegate, GenericsSubstFolder},
15    },
16};
17use flux_rustc_bridge::{ToRustc, lowering::Lower};
18use itertools::izip;
19use rustc_hir::def_id::DefId;
20use rustc_infer::traits::{ImplSourceUserDefinedData, Obligation, PredicateObligation};
21use rustc_middle::{
22    traits::{ImplSource, ObligationCause},
23    ty::{TyCtxt, Variance},
24};
25use rustc_trait_selection::{
26    solve::deeply_normalize,
27    traits::{FulfillmentError, SelectionContext},
28};
29use rustc_type_ir::TypeVisitableExt;
30
31use crate::{
32    fixpoint_encoding::KVarEncoding,
33    infer::{InferCtxtAt, InferResult},
34    refine_tree::Scope,
35};
36
/// Extension trait that adds alias normalization to every [`TypeFoldable`] value.
pub trait NormalizeExt: TypeFoldable {
    /// Folds `self` with a normalizer built from `infcx`, returning the normalized value.
    ///
    /// # Errors
    ///
    /// Returns a query error if building the normalizer or folding fails.
    fn normalize_projections(&self, infcx: &mut InferCtxtAt) -> QueryResult<Self>;

    /// Normalize projections but only inside sorts.
    ///
    /// `def_id` is the item whose rustc `ParamEnv` is used for normalization, and `infcx` is
    /// the rustc inference context the normalization runs in.
    fn normalize_sorts<'tcx>(
        &self,
        def_id: DefId,
        genv: GlobalEnv<'_, 'tcx>,
        infcx: &rustc_infer::infer::InferCtxt<'tcx>,
    ) -> QueryResult<Self>;
}
48
49impl<T: TypeFoldable> NormalizeExt for T {
50    fn normalize_projections(&self, infcx: &mut InferCtxtAt) -> QueryResult<Self> {
51        let span = infcx.span;
52        let infcx_orig = &mut infcx.infcx;
53        let mut infcx = infcx_orig.branch();
54        let infcx = infcx.at(span);
55        let mut normalizer = Normalizer::new(infcx)?;
56        self.erase_regions().try_fold_with(&mut normalizer)
57    }
58
59    fn normalize_sorts<'tcx>(
60        &self,
61        def_id: DefId,
62        genv: GlobalEnv<'_, 'tcx>,
63        infcx: &rustc_infer::infer::InferCtxt<'tcx>,
64    ) -> QueryResult<Self> {
65        let mut normalizer = SortNormalizer::new(def_id, genv, infcx);
66        self.erase_regions().try_fold_with(&mut normalizer)
67    }
68}
69
/// Type folder that normalizes alias sorts, alias types, alias refinements and constants.
struct Normalizer<'a, 'infcx, 'genv, 'tcx> {
    /// Flux inference context (at a span) used for queries and for emitting subtyping
    /// constraints.
    infcx: InferCtxtAt<'a, 'infcx, 'genv, 'tcx>,
    /// rustc selection context, created from `infcx.region_infcx`, used to select impls.
    selcx: SelectionContext<'infcx, 'tcx>,
    /// Predicates of `infcx.def_id`, instantiated with the identity substitution.
    param_env: List<Clause>,
    /// Scope taken from the inference context's cursor when the normalizer is created; fresh
    /// kvars are generated in this scope.
    scope: Scope,
}
76
77impl<'a, 'infcx, 'genv, 'tcx> Normalizer<'a, 'infcx, 'genv, 'tcx> {
78    fn new(infcx: InferCtxtAt<'a, 'infcx, 'genv, 'tcx>) -> QueryResult<Self> {
79        let predicates = infcx.genv.predicates_of(infcx.def_id)?;
80        let param_env = predicates.instantiate_identity().predicates.clone();
81        let selcx = SelectionContext::new(infcx.region_infcx);
82        let scope = infcx.cursor().marker().scope().unwrap();
83        Ok(Normalizer { infcx, selcx, param_env, scope })
84    }
85
86    fn normalize_projection_ty(
87        &mut self,
88        obligation: &AliasTy,
89    ) -> QueryResult<(bool, SubsetTyCtor)> {
90        let mut candidates = vec![];
91        self.assemble_candidates_from_param_env(obligation, &mut candidates);
92        self.assemble_candidates_from_trait_def(obligation, &mut candidates)
93            .unwrap_or_else(|err| tracked_span_bug!("{err:?}"));
94        self.assemble_candidates_from_impls(obligation, &mut candidates)?;
95        if candidates.is_empty() {
96            // TODO: This is a temporary hack that uses rustc's trait selection when FLUX fails;
97            //       The correct thing, e.g for `trait09.rs` is to make sure FLUX's param_env mirrors RUSTC,
98            //       by suitably chasing down the super-trait predicates,
99            //       see https://github.com/flux-rs/flux/issues/737
100            let (changed, ty_ctor) = normalize_projection_ty_with_rustc(
101                self.genv(),
102                self.def_id(),
103                self.infcx.region_infcx,
104                obligation,
105            )?;
106            return Ok((changed, ty_ctor));
107        }
108        if candidates.len() > 1 {
109            bug!("ambiguity when resolving `{obligation:?}` in {:?}", self.def_id());
110        }
111        let ctor = self.confirm_candidate(candidates.pop().unwrap(), obligation)?;
112        Ok((true, ctor))
113    }
114
115    fn find_resolved_predicates(
116        &self,
117        subst: &mut TVarSubst,
118        preds: Vec<EarlyBinder<ProjectionPredicate>>,
119    ) -> (Vec<ProjectionPredicate>, Vec<EarlyBinder<ProjectionPredicate>>) {
120        let mut resolved = vec![];
121        let mut unresolved = vec![];
122        for pred in preds {
123            let term = pred.clone().skip_binder().term;
124            let alias_ty = pred.clone().map(|p| p.projection_ty);
125            match subst.instantiate_partial(alias_ty) {
126                Some(projection_ty) => {
127                    let pred = ProjectionPredicate { projection_ty, term };
128                    resolved.push(pred);
129                }
130                None => unresolved.push(pred.clone()),
131            }
132        }
133        (resolved, unresolved)
134    }
135
136    // See issue-829*.rs for an example of what this function is for.
137    fn resolve_projection_predicates(
138        &mut self,
139        subst: &mut TVarSubst,
140        impl_def_id: DefId,
141    ) -> QueryResult {
142        let mut projection_preds: Vec<_> = self
143            .genv()
144            .predicates_of(impl_def_id)?
145            .skip_binder()
146            .predicates
147            .iter()
148            .filter_map(|pred| {
149                if let ClauseKind::Projection(pred) = pred.kind_skipping_binder() {
150                    Some(EarlyBinder(pred.clone()))
151                } else {
152                    None
153                }
154            })
155            .collect();
156
157        while !projection_preds.is_empty() {
158            let (resolved, unresolved) = self.find_resolved_predicates(subst, projection_preds);
159
160            if resolved.is_empty() {
161                break; // failed: there is some unresolved projection pred!
162            }
163            for p in resolved {
164                let obligation = &p.projection_ty;
165                let (_, ctor) = self.normalize_projection_ty(obligation)?;
166                subst.subset_tys(&p.term, &ctor);
167            }
168            projection_preds = unresolved;
169        }
170        Ok(())
171    }
172
173    fn confirm_candidate(
174        &mut self,
175        candidate: Candidate,
176        obligation: &AliasTy,
177    ) -> QueryResult<SubsetTyCtor> {
178        let tcx = self.tcx();
179        match candidate {
180            Candidate::ParamEnv(pred) | Candidate::TraitDef(pred) => {
181                let rustc_obligation = obligation.to_rustc(tcx);
182                let parent_id = rustc_obligation.trait_ref(tcx).def_id;
183                // Do fn-subtyping if the candidate was a fn-trait
184                if tcx.is_fn_trait(parent_id) {
185                    let res = self
186                        .fn_subtype_projection_ty(pred, obligation)
187                        .unwrap_or_else(|err| tracked_span_bug!("{err:?}"));
188                    Ok(res)
189                } else {
190                    Ok(pred.skip_binder().term)
191                }
192            }
193            Candidate::UserDefinedImpl(impl_def_id) => {
194                // Given a projection obligation
195                //     <IntoIter<{v. i32[v] | v > 0}, Global> as Iterator>::Item
196                // and the id of a rust impl block
197                //     impl<T, A: Allocator> Iterator for IntoIter<T, A>
198
199                // 1. MATCH the self type of the rust impl block and the flux self type of the obligation
200                //    to infer a substitution
201                //        IntoIter<{v. i32[v] | v > 0}, Global> MATCH IntoIter<T, A>
202                //            => {T -> {v. i32[v] | v > 0}, A -> Global}
203
204                let impl_trait_ref = self
205                    .genv()
206                    .impl_trait_ref(impl_def_id)?
207                    .unwrap()
208                    .skip_binder();
209
210                let generics = self.tcx().generics_of(impl_def_id);
211
212                let mut subst = TVarSubst::new(generics);
213                for (a, b) in iter::zip(&impl_trait_ref.args, &obligation.args) {
214                    subst.generic_args(a, b);
215                }
216
217                // 2. Gather the ProjectionPredicates and solve them see issue-808.rs
218                self.resolve_projection_predicates(&mut subst, impl_def_id)?;
219
220                let args = subst.finish(self.tcx(), generics)?;
221
222                // 3. Get the associated type in the impl block and apply the substitution to it
223                let assoc_type_id = tcx
224                    .associated_items(impl_def_id)
225                    .in_definition_order()
226                    .find(|item| item.trait_item_def_id == Some(obligation.def_id))
227                    .map(|item| item.def_id)
228                    .ok_or_else(|| {
229                        query_bug!("no associated type for {obligation:?} in impl {impl_def_id:?}")
230                    })?;
231                Ok(self
232                    .genv()
233                    .type_of(assoc_type_id)?
234                    .instantiate(tcx, &args, &[])
235                    .expect_subset_ty_ctor())
236            }
237        }
238    }
239
    /// Computes the term of a projection whose candidate predicate `actual` comes from a
    /// fn-trait (see the caller, `confirm_candidate`).
    ///
    /// Instead of returning `actual.term` directly, this generates subtyping constraints
    /// between the generic arguments of `oblig` and those of `actual`, and between
    /// `actual.term` and a fresh template for it. The template is what gets returned.
    ///
    /// # Errors
    ///
    /// Fails if normalization of `actual` or any of the generated subtyping constraints fails.
    fn fn_subtype_projection_ty(
        &mut self,
        actual: Binder<ProjectionPredicate>,
        oblig: &AliasTy,
    ) -> InferResult<SubsetTyCtor> {
        // Step 1: bs <- unpack(b1...)
        // Type and base arguments of the obligation are unpacked; other arguments are copied.
        let obligs: Vec<_> = oblig
            .args
            .iter()
            .map(|arg| {
                match arg {
                    GenericArg::Ty(ty) => GenericArg::Ty(self.infcx.unpack(ty)),
                    GenericArg::Base(ctor) => GenericArg::Ty(self.infcx.unpack(&ctor.to_ty())),
                    _ => arg.clone(),
                }
            })
            .collect();

        let span = self.infcx.span;
        let mut infcx = self.infcx.at(span);

        let actual = infcx.ensure_resolved_evars(|infcx| {
            // Step 2: as <- fresh(a1...)
            // Bound regions become `ReErased`; bound refinement variables become fresh
            // inference variables.
            let actual = actual
                .replace_bound_vars(
                    |_| rty::ReErased,
                    |sort, mode| infcx.fresh_infer_var(sort, mode),
                )
                .normalize_projections(infcx)?;

            let actuals = actual.projection_ty.args.iter().map(|arg| {
                match arg {
                    GenericArg::Base(ctor) => GenericArg::Ty(ctor.to_ty()),
                    _ => arg.clone(),
                }
            });

            // Step 3: bs <: as
            // The first argument on each side is skipped (presumably the `Self` type — TODO
            // confirm); the remaining ones are related contravariantly.
            for (a, b) in izip!(actuals.skip(1), obligs.iter().skip(1)) {
                infcx.subtyping_generic_args(
                    Variance::Contravariant,
                    &a,
                    b,
                    crate::infer::ConstrReason::Predicate,
                )?;
            }
            Ok(actual)
        })?;
        // Step 4: check all evars are solved, plug back into ProjectionPredicate
        let actual = infcx.fully_resolve_evars(&actual);

        // Step 5: generate "fresh" type for actual.term: every predicate hole is replaced by
        // a fresh kvar created in the scope captured when the normalizer was built.
        let oblig_term = actual.term.with_holes().replace_holes(|binders, kind| {
            assert!(kind == rty::HoleKind::Pred);
            let scope = &self.scope;
            infcx.fresh_kvar_in_scope(binders, scope, KVarEncoding::Conj)
        });

        // Step 6: subtyping obligation on output
        infcx.subtyping(
            &actual.term.to_ty(),
            &oblig_term.to_ty(),
            crate::infer::ConstrReason::Predicate,
        )?;
        Ok(oblig_term)
    }
307
308    fn assemble_candidates_from_predicates(
309        &mut self,
310        predicates: &List<Clause>,
311        obligation: &AliasTy,
312        ctor: fn(Binder<ProjectionPredicate>) -> Candidate,
313        candidates: &mut Vec<Candidate>,
314    ) {
315        let tcx = self.tcx();
316        let rustc_obligation = obligation.to_rustc(tcx);
317
318        for predicate in predicates {
319            if let Some(pred) = predicate.as_projection_clause()
320                && pred.skip_binder_ref().projection_ty.to_rustc(tcx) == rustc_obligation
321            {
322                candidates.push(ctor(pred));
323            }
324        }
325    }
326
327    fn assemble_candidates_from_param_env(
328        &mut self,
329        obligation: &AliasTy,
330        candidates: &mut Vec<Candidate>,
331    ) {
332        let predicates = self.param_env.clone();
333        self.assemble_candidates_from_predicates(
334            &predicates,
335            obligation,
336            Candidate::ParamEnv,
337            candidates,
338        );
339    }
340
341    fn assemble_candidates_from_trait_def(
342        &mut self,
343        obligation: &AliasTy,
344        candidates: &mut Vec<Candidate>,
345    ) -> InferResult {
346        if let GenericArg::Base(ctor) = &obligation.args[0]
347            && let BaseTy::Alias(AliasKind::Opaque, alias_ty) = ctor.as_bty_skipping_binder()
348        {
349            debug_assert!(!alias_ty.has_escaping_bvars());
350            let bounds = self.genv().item_bounds(alias_ty.def_id)?.instantiate(
351                self.tcx(),
352                &alias_ty.args,
353                &alias_ty.refine_args,
354            );
355            self.assemble_candidates_from_predicates(
356                &bounds,
357                obligation,
358                Candidate::TraitDef,
359                candidates,
360            );
361        }
362        Ok(())
363    }
364
365    fn assemble_candidates_from_impls(
366        &mut self,
367        obligation: &AliasTy,
368        candidates: &mut Vec<Candidate>,
369    ) -> QueryResult {
370        let trait_ref = obligation.to_rustc(self.tcx()).trait_ref(self.tcx());
371        let trait_ref = self.tcx().erase_regions(trait_ref);
372        let trait_pred = Obligation::new(
373            self.tcx(),
374            ObligationCause::dummy(),
375            self.rustc_param_env(),
376            trait_ref,
377        );
378        // FIXME(nilehmann) This is a patch to not panic inside rustc so we are
379        // able to catch the bug
380        if trait_pred.has_escaping_bound_vars() {
381            tracked_span_bug!();
382        }
383        match self.selcx.select(&trait_pred) {
384            Ok(Some(ImplSource::UserDefined(impl_data))) => {
385                candidates.push(Candidate::UserDefinedImpl(impl_data.impl_def_id));
386            }
387            Ok(_) => (),
388            Err(e) => bug!("error selecting {trait_pred:?}: {e:?}"),
389        }
390        Ok(())
391    }
392
    /// Returns the `def_id` stored in the inference context.
    fn def_id(&self) -> DefId {
        self.infcx.def_id
    }
396
    /// Returns the global environment stored in the inference context.
    fn genv(&self) -> GlobalEnv<'genv, 'tcx> {
        self.infcx.genv
    }
400
    /// Returns the rustc type context, obtained from the selection context.
    fn tcx(&self) -> TyCtxt<'tcx> {
        self.selcx.tcx()
    }
404
405    fn rustc_param_env(&self) -> rustc_middle::ty::ParamEnv<'tcx> {
406        self.selcx.tcx().param_env(self.def_id())
407    }
408}
409
410impl FallibleTypeFolder for Normalizer<'_, '_, '_, '_> {
411    type Error = QueryErr;
412
413    fn try_fold_sort(&mut self, sort: &Sort) -> Result<Sort, Self::Error> {
414        match sort {
415            Sort::Alias(AliasKind::Free, alias_ty) => {
416                self.genv()
417                    .normalize_free_alias_sort(alias_ty)?
418                    .try_fold_with(self)
419            }
420            Sort::Alias(AliasKind::Projection, alias_ty) => {
421                let (changed, ctor) = self.normalize_projection_ty(alias_ty)?;
422                let sort = ctor.sort();
423                if changed { sort.try_fold_with(self) } else { Ok(sort) }
424            }
425            _ => sort.try_super_fold_with(self),
426        }
427    }
428
429    // As shown in https://github.com/flux-rs/flux/issues/711 one round of `normalize_projections`
430    // can replace one projection e.g. `<Rev<Iter<[i32]> as Iterator>::Item` with another e.g.
431    // `<Iter<[i32]> as Iterator>::Item` We want to compute a "fixpoint" i.e. keep going until no
432    // change, so that e.g. the above is normalized all the way to `i32`, which is what the `changed`
433    // is for.
434    fn try_fold_ty(&mut self, ty: &Ty) -> Result<Ty, Self::Error> {
435        match ty.kind() {
436            TyKind::Indexed(BaseTy::Alias(AliasKind::Free, alias_ty), idx) => {
437                Ok(self
438                    .genv()
439                    .type_of(alias_ty.def_id)?
440                    .instantiate(self.tcx(), &alias_ty.args, &alias_ty.refine_args)
441                    .expect_ctor()
442                    .replace_bound_reft(idx))
443            }
444            TyKind::Indexed(BaseTy::Alias(AliasKind::Projection, alias_ty), idx) => {
445                let (changed, ctor) = self.normalize_projection_ty(alias_ty)?;
446                let ty = ctor.replace_bound_reft(idx).to_ty();
447                if changed { ty.try_fold_with(self) } else { Ok(ty) }
448            }
449            _ => ty.try_super_fold_with(self),
450        }
451    }
452
453    fn try_fold_subset_ty(&mut self, sty: &SubsetTy) -> Result<SubsetTy, Self::Error> {
454        match &sty.bty {
455            BaseTy::Alias(AliasKind::Free, _alias_ty) => {
456                // Weak aliases are always expanded during conversion. We could in theory normalize
457                // them here but we don't guaranatee that type aliases expand to a subset ty. If we
458                // ever stop expanding aliases during conv we would need to guarantee that aliases
459                // used as a generic base expand to a subset type.
460                tracked_span_bug!()
461            }
462            BaseTy::Alias(AliasKind::Projection, alias_ty) => {
463                let (changed, ctor) = self.normalize_projection_ty(alias_ty)?;
464                let ty = ctor.replace_bound_reft(&sty.idx).strengthen(&sty.pred);
465                if changed { ty.try_fold_with(self) } else { Ok(ty) }
466            }
467            _ => sty.try_super_fold_with(self),
468        }
469    }
470
471    fn try_fold_expr(&mut self, expr: &Expr) -> Result<Expr, Self::Error> {
472        if let ExprKind::Alias(alias_pred, refine_args) = expr.kind() {
473            let (changed, e) = normalize_alias_reft(
474                self.genv(),
475                self.def_id(),
476                self.selcx.infcx,
477                alias_pred,
478                refine_args,
479            )?;
480            if changed { e.try_fold_with(self) } else { Ok(e) }
481        } else {
482            expr.try_super_fold_with(self)
483        }
484    }
485
486    fn try_fold_const(&mut self, c: &Const) -> Result<Const, Self::Error> {
487        let c = c.to_rustc(self.tcx());
488        rustc_trait_selection::traits::evaluate_const(self.selcx.infcx, c, self.rustc_param_env())
489            .lower(self.tcx())
490            .map_err(|e| QueryErr::unsupported(self.def_id(), e.into_err()))
491    }
492}
493
/// A way of resolving a projection obligation.
#[derive(Debug)]
pub enum Candidate {
    /// A user-defined impl selected by rustc's trait selection, identified by its `DefId`.
    UserDefinedImpl(DefId),
    /// A projection predicate found in the param env of the current item.
    ParamEnv(Binder<ProjectionPredicate>),
    /// A projection predicate found in the item bounds of an opaque type.
    TraitDef(Binder<ProjectionPredicate>),
}
500
/// A partial substitution for the generic parameters of an impl, inferred by matching.
#[derive(Debug)]
struct TVarSubst {
    /// One slot per generic parameter, indexed by parameter index; `None` means the
    /// parameter has not been inferred yet.
    args: Vec<Option<GenericArg>>,
}
505
506impl GenericsSubstDelegate for &TVarSubst {
507    type Error = ();
508
509    fn ty_for_param(&mut self, param_ty: rustc_middle::ty::ParamTy) -> Result<Ty, Self::Error> {
510        match self.args.get(param_ty.index as usize) {
511            Some(Some(GenericArg::Ty(ty))) => Ok(ty.clone()),
512            Some(None) => Err(()),
513            arg => tracked_span_bug!("expected type for generic parameter, found `{arg:?}`"),
514        }
515    }
516
517    fn sort_for_param(&mut self, param_ty: rustc_middle::ty::ParamTy) -> Result<Sort, Self::Error> {
518        match self.args.get(param_ty.index as usize) {
519            Some(Some(GenericArg::Base(ctor))) => Ok(ctor.sort()),
520            Some(None) => Err(()),
521            arg => tracked_span_bug!("expected type for generic parameter, found `{arg:?}`"),
522        }
523    }
524
525    fn ctor_for_param(
526        &mut self,
527        param_ty: rustc_middle::ty::ParamTy,
528    ) -> Result<SubsetTyCtor, Self::Error> {
529        match self.args.get(param_ty.index as usize) {
530            Some(Some(GenericArg::Base(ctor))) => Ok(ctor.clone()),
531            Some(None) => Err(()),
532            arg => tracked_span_bug!("expected type for generic parameter, found `{arg:?}`"),
533        }
534    }
535
536    fn region_for_param(&mut self, _ebr: rustc_middle::ty::EarlyParamRegion) -> Region {
537        tracked_span_bug!()
538    }
539
540    fn expr_for_param_const(&self, _param_const: rustc_middle::ty::ParamConst) -> Expr {
541        tracked_span_bug!()
542    }
543
544    fn const_for_param(&mut self, _param: &Const) -> Const {
545        tracked_span_bug!()
546    }
547}
548
/// Type folder that normalizes aliases only inside sorts, using rustc to normalize
/// projections.
struct SortNormalizer<'infcx, 'genv, 'tcx> {
    /// Item whose rustc `ParamEnv` is used when normalizing projections.
    def_id: DefId,
    /// rustc inference context in which projections are normalized.
    infcx: &'infcx rustc_infer::infer::InferCtxt<'tcx>,
    /// Global environment, used to expand free alias sorts.
    genv: GlobalEnv<'genv, 'tcx>,
}
impl<'infcx, 'genv, 'tcx> SortNormalizer<'infcx, 'genv, 'tcx> {
    /// Creates a sort normalizer for `def_id` that uses `genv` and the rustc inference
    /// context `infcx`.
    fn new(
        def_id: DefId,
        genv: GlobalEnv<'genv, 'tcx>,
        infcx: &'infcx rustc_infer::infer::InferCtxt<'tcx>,
    ) -> Self {
        Self { def_id, infcx, genv }
    }
}
563
564impl FallibleTypeFolder for SortNormalizer<'_, '_, '_> {
565    type Error = QueryErr;
566    fn try_fold_sort(&mut self, sort: &Sort) -> Result<Sort, Self::Error> {
567        match sort {
568            Sort::Alias(AliasKind::Free, alias_ty) => {
569                self.genv
570                    .normalize_free_alias_sort(alias_ty)?
571                    .try_fold_with(self)
572            }
573            Sort::Alias(AliasKind::Projection, alias_ty) => {
574                let (changed, ctor) = normalize_projection_ty_with_rustc(
575                    self.genv,
576                    self.def_id,
577                    self.infcx,
578                    alias_ty,
579                )?;
580                let sort = ctor.sort();
581                if changed { sort.try_fold_with(self) } else { Ok(sort) }
582            }
583            _ => sort.try_super_fold_with(self),
584        }
585    }
586}
587
588impl TVarSubst {
    /// Creates an empty substitution with one unset slot per parameter counted by
    /// `generics.count()`.
    fn new(generics: &rustc_middle::ty::Generics) -> Self {
        Self { args: vec![None; generics.count()] }
    }
592
593    fn instantiate_partial<T: TypeFoldable>(&mut self, pred: EarlyBinder<T>) -> Option<T> {
594        let mut folder = GenericsSubstFolder::new(&*self, &[]);
595        pred.skip_binder().try_fold_with(&mut folder).ok()
596    }
597
598    fn finish<'tcx>(
599        self,
600        tcx: TyCtxt<'tcx>,
601        generics: &'tcx rustc_middle::ty::Generics,
602    ) -> QueryResult<Vec<GenericArg>> {
603        self.args
604            .into_iter()
605            .enumerate()
606            .map(|(idx, arg)| {
607                if let Some(arg) = arg {
608                    Ok(arg)
609                } else {
610                    let param = generics.param_at(idx, tcx);
611                    Err(QueryErr::bug(
612                        None,
613                        format!("cannot infer substitution for {param:?} at index {idx}"),
614                    ))
615                }
616            })
617            .try_collect_vec()
618    }
619
620    fn generic_args(&mut self, a: &GenericArg, b: &GenericArg) {
621        match (a, b) {
622            (GenericArg::Ty(a), GenericArg::Ty(b)) => self.tys(a, b),
623            (GenericArg::Lifetime(a), GenericArg::Lifetime(b)) => self.regions(*a, *b),
624            (GenericArg::Base(a), GenericArg::Base(b)) => {
625                self.subset_tys(a, b);
626            }
627            (GenericArg::Const(a), GenericArg::Const(b)) => self.consts(a, b),
628            _ => {}
629        }
630    }
631
632    fn tys(&mut self, a: &Ty, b: &Ty) {
633        if let TyKind::Param(param_ty) = a.kind() {
634            if !b.has_escaping_bvars() {
635                self.insert_generic_arg(param_ty.index, GenericArg::Ty(b.clone()));
636            }
637        } else {
638            let Some(a_bty) = a.as_bty_skipping_existentials() else { return };
639            let Some(b_bty) = b.as_bty_skipping_existentials() else { return };
640            self.btys(a_bty, b_bty);
641        }
642    }
643
644    fn subset_tys(&mut self, a: &SubsetTyCtor, b: &SubsetTyCtor) {
645        let bty_a = a.as_bty_skipping_binder();
646        let bty_b = b.as_bty_skipping_binder();
647        if let BaseTy::Param(param_ty) = bty_a {
648            if !b.has_escaping_bvars() {
649                self.insert_generic_arg(param_ty.index, GenericArg::Base(b.clone()));
650            }
651        } else {
652            self.btys(bty_a, bty_b);
653        }
654    }
655
656    fn btys(&mut self, a: &BaseTy, b: &BaseTy) {
657        match (a, b) {
658            (BaseTy::Param(param_ty), _) => {
659                if !b.has_escaping_bvars() {
660                    let sort = b.sort();
661                    let ctor =
662                        Binder::bind_with_sort(SubsetTy::trivial(b.clone(), Expr::nu()), sort);
663                    self.insert_generic_arg(param_ty.index, GenericArg::Base(ctor));
664                }
665            }
666            (BaseTy::Adt(_, a_args), BaseTy::Adt(_, b_args)) => {
667                debug_assert_eq!(a_args.len(), b_args.len());
668                for (a_arg, b_arg) in iter::zip(a_args, b_args) {
669                    self.generic_args(a_arg, b_arg);
670                }
671            }
672            (BaseTy::Array(a_ty, a_n), BaseTy::Array(b_ty, b_n)) => {
673                self.tys(a_ty, b_ty);
674                self.consts(a_n, b_n);
675            }
676            (BaseTy::Tuple(a_tys), BaseTy::Tuple(b_tys)) => {
677                debug_assert_eq!(a_tys.len(), b_tys.len());
678                for (a_ty, b_ty) in iter::zip(a_tys, b_tys) {
679                    self.tys(a_ty, b_ty);
680                }
681            }
682            (BaseTy::Ref(a_re, a_ty, _), BaseTy::Ref(b_re, b_ty, _)) => {
683                self.regions(*a_re, *b_re);
684                self.tys(a_ty, b_ty);
685            }
686            (BaseTy::Slice(a_ty), BaseTy::Slice(b_ty)) => {
687                self.tys(a_ty, b_ty);
688            }
689            _ => {}
690        }
691    }
692
693    fn regions(&mut self, a: Region, b: Region) {
694        if let Region::ReEarlyParam(ebr) = a {
695            self.insert_generic_arg(ebr.index, GenericArg::Lifetime(b));
696        }
697    }
698
699    fn consts(&mut self, a: &Const, b: &Const) {
700        if let ConstKind::Param(param_const) = a.kind {
701            self.insert_generic_arg(param_const.index, GenericArg::Const(b.clone()));
702        }
703    }
704
705    fn insert_generic_arg(&mut self, idx: u32, arg: GenericArg) {
706        if let Some(old) = &self.args[idx as usize]
707            && old != &arg
708        {
709            tracked_span_bug!("ambiguous substitution: old=`{old:?}`, new: `{arg:?}`");
710        }
711        self.args[idx as usize].replace(arg);
712    }
713}
714
715/// Normalize an [`rty::AliasTy`] by converting it to rustc, normalizing it using rustc api, and
716/// then mapping the result back to `rty`. This will lose refinements and it should only be used
717/// to normalize sorts because they should only contain unrefined types. However, we are also using
718/// it as a hack to normalize types in cases where we fail to collect a candidate, this is unsound
719/// and should be removed.
720///
721/// [`rty::AliasTy`]: AliasTy
722fn normalize_projection_ty_with_rustc<'tcx>(
723    genv: GlobalEnv<'_, 'tcx>,
724    def_id: DefId,
725    infcx: &rustc_infer::infer::InferCtxt<'tcx>,
726    obligation: &AliasTy,
727) -> QueryResult<(bool, SubsetTyCtor)> {
728    let tcx = genv.tcx();
729    let projection_ty = obligation.to_rustc(tcx);
730    let projection_ty = tcx.erase_regions(projection_ty);
731    let cause = ObligationCause::dummy();
732    let param_env = tcx.param_env(def_id);
733
734    let pre_ty = projection_ty.to_ty(tcx);
735    let at = infcx.at(&cause, param_env);
736    let ty = deeply_normalize::<rustc_middle::ty::Ty<'tcx>, FulfillmentError>(at, pre_ty)
737        .map_err(|err| query_bug!("{err:?}"))?;
738
739    let changed = pre_ty != ty;
740    let rustc_ty = ty.lower(tcx).map_err(|reason| query_bug!("{reason:?}"))?;
741
742    Ok((
743        changed,
744        Refiner::default_for_item(genv, def_id)?
745            .refine_ty_or_base(&rustc_ty)?
746            .expect_base(),
747    ))
748}
749
750/// Do one step of normalization, unfolding associated refinements if they are concrete.
751///
752/// Use this if you are about to match structurally on an [`ExprKind`] and you need associated
753/// refinements to be normalized.
754pub fn structurally_normalize_expr(
755    genv: GlobalEnv,
756    def_id: DefId,
757    infcx: &rustc_infer::infer::InferCtxt,
758    expr: &Expr,
759) -> QueryResult<Expr> {
760    if let ExprKind::Alias(alias_pred, refine_args) = expr.kind() {
761        let (_, e) = normalize_alias_reft(genv, def_id, infcx, alias_pred, refine_args)?;
762        Ok(e)
763    } else {
764        Ok(expr.clone())
765    }
766}
767
768/// Normalizes an [`AliasReft`]. This uses the trait solver to find the [`ImplSourceUserDefinedData`]
769/// and uses the `args` there, which we map back to Flux via refining. This loses refinements,
770/// but that's fine because [`AliasReft`] should not rely on refinements for trait solving.
771fn normalize_alias_reft(
772    genv: GlobalEnv,
773    def_id: DefId,
774    infcx: &rustc_infer::infer::InferCtxt,
775    alias_reft: &AliasReft,
776    refine_args: &RefineArgs,
777) -> QueryResult<(bool, Expr)> {
778    let is_final = genv.assoc_refinement(alias_reft.assoc_id)?.final_;
779    if is_final {
780        let e = genv
781            .default_assoc_refinement_body(alias_reft.assoc_id)?
782            .unwrap_or_else(|| {
783                bug!("final associated refinement without body - should be caught in desugar")
784            })
785            .instantiate(genv.tcx(), &alias_reft.args, &[])
786            .apply(refine_args);
787        Ok((true, e))
788    } else if let Some(impl_data) = get_impl_data_for_alias_reft(infcx, def_id, alias_reft)? {
789        let tcx = infcx.tcx;
790        let impl_def_id = impl_data.impl_def_id;
791        let args = Refiner::default_for_item(genv, def_id)?.refine_generic_args(
792            impl_def_id,
793            &impl_data
794                .args
795                .lower(tcx)
796                .map_err(|reason| query_bug!("{reason:?}"))?,
797        )?;
798        let e = genv
799            .assoc_refinement_body_for_impl(alias_reft.assoc_id, impl_def_id)?
800            .instantiate(tcx, &args, &[])
801            .apply(refine_args);
802        Ok((true, e))
803    } else {
804        Ok((false, Expr::alias(alias_reft.clone(), refine_args.clone())))
805    }
806}
807
808fn get_impl_data_for_alias_reft<'tcx>(
809    infcx: &rustc_infer::infer::InferCtxt<'tcx>,
810    def_id: DefId,
811    alias_reft: &AliasReft,
812) -> QueryResult<Option<ImplSourceUserDefinedData<'tcx, PredicateObligation<'tcx>>>> {
813    let tcx = infcx.tcx;
814    let mut selcx = SelectionContext::new(infcx);
815    let trait_ref = alias_reft.to_rustc_trait_ref(tcx);
816    let trait_ref = tcx.erase_regions(trait_ref);
817    let trait_pred =
818        Obligation::new(tcx, ObligationCause::dummy(), tcx.param_env(def_id), trait_ref);
819    match selcx.select(&trait_pred) {
820        Ok(Some(ImplSource::UserDefined(impl_data))) => Ok(Some(impl_data)),
821        Ok(_) => Ok(None),
822        Err(e) => bug!("error selecting {trait_pred:?}: {e:?}"),
823    }
824}