flux_infer/
projections.rs

1use std::iter;
2
3use flux_common::{bug, iter::IterExt, tracked_span_bug};
4use flux_middle::{
5    global_env::GlobalEnv,
6    queries::{QueryErr, QueryResult},
7    query_bug,
8    rty::{
9        self, AliasKind, AliasReft, AliasTy, BaseTy, Binder, Clause, ClauseKind, Const, ConstKind,
10        EarlyBinder, Expr, ExprKind, GenericArg, List, ProjectionPredicate, RefineArgs, Region,
11        Sort, SubsetTy, SubsetTyCtor, Ty, TyKind,
12        fold::{FallibleTypeFolder, TypeFoldable, TypeSuperFoldable, TypeVisitable},
13        refining::Refiner,
14        subst::{GenericsSubstDelegate, GenericsSubstFolder},
15    },
16};
17use flux_rustc_bridge::{ToRustc, lowering::Lower};
18use itertools::izip;
19use rustc_hir::def_id::DefId;
20use rustc_infer::traits::Obligation;
21use rustc_middle::{
22    traits::{ImplSource, ObligationCause},
23    ty::{TyCtxt, Variance},
24};
25use rustc_trait_selection::{
26    solve::deeply_normalize,
27    traits::{FulfillmentError, SelectionContext},
28};
29use rustc_type_ir::TypeVisitableExt;
30
31use crate::{
32    fixpoint_encoding::KVarEncoding,
33    infer::{InferCtxtAt, InferResult},
34    refine_tree::Scope,
35};
36
37pub trait NormalizeExt: TypeFoldable {
38    fn normalize_projections(&self, infcx: &mut InferCtxtAt) -> QueryResult<Self>;
39
40    /// Normalize projections but only inside sorts
41    fn normalize_sorts<'tcx>(
42        &self,
43        def_id: DefId,
44        genv: GlobalEnv<'_, 'tcx>,
45        infcx: &rustc_infer::infer::InferCtxt<'tcx>,
46    ) -> QueryResult<Self>;
47}
48
49impl<T: TypeFoldable> NormalizeExt for T {
50    fn normalize_projections(&self, infcx: &mut InferCtxtAt) -> QueryResult<Self> {
51        let span = infcx.span;
52        let infcx_orig = &mut infcx.infcx;
53        let mut infcx = infcx_orig.branch();
54        let infcx = infcx.at(span);
55        let mut normalizer = Normalizer::new(infcx)?;
56        self.erase_regions().try_fold_with(&mut normalizer)
57    }
58
59    fn normalize_sorts<'tcx>(
60        &self,
61        def_id: DefId,
62        genv: GlobalEnv<'_, 'tcx>,
63        infcx: &rustc_infer::infer::InferCtxt<'tcx>,
64    ) -> QueryResult<Self> {
65        let mut normalizer = SortNormalizer::new(def_id, genv, infcx);
66        self.erase_regions().try_fold_with(&mut normalizer)
67    }
68}
69
70struct Normalizer<'a, 'infcx, 'genv, 'tcx> {
71    infcx: InferCtxtAt<'a, 'infcx, 'genv, 'tcx>,
72    selcx: SelectionContext<'infcx, 'tcx>,
73    param_env: List<Clause>,
74    scope: Scope,
75}
76
77impl<'a, 'infcx, 'genv, 'tcx> Normalizer<'a, 'infcx, 'genv, 'tcx> {
78    fn new(infcx: InferCtxtAt<'a, 'infcx, 'genv, 'tcx>) -> QueryResult<Self> {
79        let predicates = infcx.genv.predicates_of(infcx.def_id)?;
80        let param_env = predicates.instantiate_identity().predicates.clone();
81        let selcx = SelectionContext::new(infcx.region_infcx);
82        let scope = infcx.cursor().marker().scope().unwrap();
83        Ok(Normalizer { infcx, selcx, param_env, scope })
84    }
85
86    fn get_impl_id_of_alias_reft(&mut self, alias_reft: &AliasReft) -> QueryResult<Option<DefId>> {
87        let tcx = self.tcx();
88        let def_id = self.def_id();
89        let selcx = &mut self.selcx;
90        let trait_ref = alias_reft.to_rustc_trait_ref(tcx);
91        let trait_ref = tcx.erase_regions(trait_ref);
92        let trait_pred =
93            Obligation::new(tcx, ObligationCause::dummy(), tcx.param_env(def_id), trait_ref);
94        match selcx.select(&trait_pred) {
95            Ok(Some(ImplSource::UserDefined(impl_data))) => Ok(Some(impl_data.impl_def_id)),
96            Ok(_) => Ok(None),
97            Err(e) => bug!("error selecting {trait_pred:?}: {e:?}"),
98        }
99    }
100
101    fn alias_reft_is_final(&mut self, alias_reft: &AliasReft) -> QueryResult<bool> {
102        let reft_id = alias_reft.assoc_id;
103        let assoc_refinements = self.genv().assoc_refinements_of(reft_id.parent())?;
104        Ok(assoc_refinements.get(reft_id).final_)
105    }
106
107    fn normalize_alias_reft(
108        &mut self,
109        alias_reft: &AliasReft,
110        refine_args: &RefineArgs,
111    ) -> QueryResult<Expr> {
112        if self.alias_reft_is_final(alias_reft)? {
113            self.genv()
114                .default_assoc_refinement_body(alias_reft.assoc_id)?
115                .unwrap_or_else(|| {
116                    bug!("final associated refinement without body - should be caught in desugar")
117                })
118                .instantiate(self.genv().tcx(), &alias_reft.args, &[])
119                .apply(refine_args)
120                .try_fold_with(self)
121        } else if let Some(impl_def_id) = self.get_impl_id_of_alias_reft(alias_reft)? {
122            let impl_trait_ref = self
123                .genv()
124                .impl_trait_ref(impl_def_id)?
125                .unwrap()
126                .skip_binder();
127            let generics = self.tcx().generics_of(impl_def_id);
128            let mut subst = TVarSubst::new(generics);
129            let alias_reft_args = alias_reft.args.try_fold_with(self)?;
130            // TODO(RJ): The code below duplicates the logic in `confirm_candidates` ...
131            for (a, b) in iter::zip(&impl_trait_ref.args, &alias_reft_args) {
132                subst.generic_args(a, b);
133            }
134            self.resolve_projection_predicates(&mut subst, impl_def_id)?;
135
136            let args = subst.finish(self.tcx(), generics)?;
137            self.genv()
138                .assoc_refinement_body_for_impl(alias_reft.assoc_id, impl_def_id)?
139                .instantiate(self.genv().tcx(), &args, &[])
140                .apply(refine_args)
141                .try_fold_with(self)
142        } else {
143            Ok(Expr::alias(alias_reft.clone(), refine_args.clone()))
144        }
145    }
146
147    fn normalize_projection_ty(
148        &mut self,
149        obligation: &AliasTy,
150    ) -> QueryResult<(bool, SubsetTyCtor)> {
151        let mut candidates = vec![];
152        self.assemble_candidates_from_param_env(obligation, &mut candidates);
153        self.assemble_candidates_from_trait_def(obligation, &mut candidates)
154            .unwrap_or_else(|err| tracked_span_bug!("{err:?}"));
155        self.assemble_candidates_from_impls(obligation, &mut candidates)?;
156        if candidates.is_empty() {
157            // TODO: This is a temporary hack that uses rustc's trait selection when FLUX fails;
158            //       The correct thing, e.g for `trait09.rs` is to make sure FLUX's param_env mirrors RUSTC,
159            //       by suitably chasing down the super-trait predicates,
160            //       see https://github.com/flux-rs/flux/issues/737
161            let (changed, ty_ctor) = normalize_projection_ty_with_rustc(
162                self.genv(),
163                self.def_id(),
164                self.infcx.region_infcx,
165                obligation,
166            )?;
167            return Ok((changed, ty_ctor));
168        }
169        if candidates.len() > 1 {
170            bug!("ambiguity when resolving `{obligation:?}` in {:?}", self.def_id());
171        }
172        let ctor = self.confirm_candidate(candidates.pop().unwrap(), obligation)?;
173        Ok((true, ctor))
174    }
175
176    fn find_resolved_predicates(
177        &self,
178        subst: &mut TVarSubst,
179        preds: Vec<EarlyBinder<ProjectionPredicate>>,
180    ) -> (Vec<ProjectionPredicate>, Vec<EarlyBinder<ProjectionPredicate>>) {
181        let mut resolved = vec![];
182        let mut unresolved = vec![];
183        for pred in preds {
184            let term = pred.clone().skip_binder().term;
185            let alias_ty = pred.clone().map(|p| p.projection_ty);
186            match subst.instantiate_partial(alias_ty) {
187                Some(projection_ty) => {
188                    let pred = ProjectionPredicate { projection_ty, term };
189                    resolved.push(pred);
190                }
191                None => unresolved.push(pred.clone()),
192            }
193        }
194        (resolved, unresolved)
195    }
196
197    // See issue-829*.rs for an example of what this function is for.
198    fn resolve_projection_predicates(
199        &mut self,
200        subst: &mut TVarSubst,
201        impl_def_id: DefId,
202    ) -> QueryResult {
203        let mut projection_preds: Vec<_> = self
204            .genv()
205            .predicates_of(impl_def_id)?
206            .skip_binder()
207            .predicates
208            .iter()
209            .filter_map(|pred| {
210                if let ClauseKind::Projection(pred) = pred.kind_skipping_binder() {
211                    Some(EarlyBinder(pred.clone()))
212                } else {
213                    None
214                }
215            })
216            .collect();
217
218        while !projection_preds.is_empty() {
219            let (resolved, unresolved) = self.find_resolved_predicates(subst, projection_preds);
220
221            if resolved.is_empty() {
222                break; // failed: there is some unresolved projection pred!
223            }
224            for p in resolved {
225                let obligation = &p.projection_ty;
226                let (_, ctor) = self.normalize_projection_ty(obligation)?;
227                subst.subset_tys(&p.term, &ctor);
228            }
229            projection_preds = unresolved;
230        }
231        Ok(())
232    }
233
234    fn confirm_candidate(
235        &mut self,
236        candidate: Candidate,
237        obligation: &AliasTy,
238    ) -> QueryResult<SubsetTyCtor> {
239        let tcx = self.tcx();
240        match candidate {
241            Candidate::ParamEnv(pred) | Candidate::TraitDef(pred) => {
242                let rustc_obligation = obligation.to_rustc(tcx);
243                let parent_id = rustc_obligation.trait_ref(tcx).def_id;
244                // Do fn-subtyping if the candidate was a fn-trait
245                if tcx.is_fn_trait(parent_id) {
246                    let res = self
247                        .fn_subtype_projection_ty(pred, obligation)
248                        .unwrap_or_else(|err| tracked_span_bug!("{err:?}"));
249                    Ok(res)
250                } else {
251                    Ok(pred.skip_binder().term)
252                }
253            }
254            Candidate::UserDefinedImpl(impl_def_id) => {
255                // Given a projection obligation
256                //     <IntoIter<{v. i32[v] | v > 0}, Global> as Iterator>::Item
257                // and the id of a rust impl block
258                //     impl<T, A: Allocator> Iterator for IntoIter<T, A>
259
260                // 1. MATCH the self type of the rust impl block and the flux self type of the obligation
261                //    to infer a substitution
262                //        IntoIter<{v. i32[v] | v > 0}, Global> MATCH IntoIter<T, A>
263                //            => {T -> {v. i32[v] | v > 0}, A -> Global}
264
265                let impl_trait_ref = self
266                    .genv()
267                    .impl_trait_ref(impl_def_id)?
268                    .unwrap()
269                    .skip_binder();
270
271                let generics = self.tcx().generics_of(impl_def_id);
272
273                let mut subst = TVarSubst::new(generics);
274                for (a, b) in iter::zip(&impl_trait_ref.args, &obligation.args) {
275                    subst.generic_args(a, b);
276                }
277
278                // 2. Gather the ProjectionPredicates and solve them see issue-808.rs
279                self.resolve_projection_predicates(&mut subst, impl_def_id)?;
280
281                let args = subst.finish(self.tcx(), generics)?;
282
283                // 3. Get the associated type in the impl block and apply the substitution to it
284                let assoc_type_id = tcx
285                    .associated_items(impl_def_id)
286                    .in_definition_order()
287                    .find(|item| item.trait_item_def_id == Some(obligation.def_id))
288                    .map(|item| item.def_id)
289                    .ok_or_else(|| {
290                        query_bug!("no associated type for {obligation:?} in impl {impl_def_id:?}")
291                    })?;
292                Ok(self
293                    .genv()
294                    .type_of(assoc_type_id)?
295                    .instantiate(tcx, &args, &[])
296                    .expect_subset_ty_ctor())
297            }
298        }
299    }
300
301    fn fn_subtype_projection_ty(
302        &mut self,
303        actual: Binder<ProjectionPredicate>,
304        oblig: &AliasTy,
305    ) -> InferResult<SubsetTyCtor> {
306        // Step 1: bs <- unpack(b1...)
307        let obligs: Vec<_> = oblig
308            .args
309            .iter()
310            .map(|arg| {
311                match arg {
312                    GenericArg::Ty(ty) => GenericArg::Ty(self.infcx.unpack(ty)),
313                    GenericArg::Base(ctor) => GenericArg::Ty(self.infcx.unpack(&ctor.to_ty())),
314                    _ => arg.clone(),
315                }
316            })
317            .collect();
318
319        let span = self.infcx.span;
320        let mut infcx = self.infcx.at(span);
321
322        let actual = infcx.ensure_resolved_evars(|infcx| {
323            // Step 2: as <- fresh(a1...)
324            let actual = actual
325                .replace_bound_vars(
326                    |_| rty::ReErased,
327                    |sort, mode| infcx.fresh_infer_var(sort, mode),
328                )
329                .normalize_projections(infcx)?;
330
331            let actuals = actual.projection_ty.args.iter().map(|arg| {
332                match arg {
333                    GenericArg::Base(ctor) => GenericArg::Ty(ctor.to_ty()),
334                    _ => arg.clone(),
335                }
336            });
337
338            // Step 3: bs <: as
339            for (a, b) in izip!(actuals.skip(1), obligs.iter().skip(1)) {
340                infcx.subtyping_generic_args(
341                    Variance::Contravariant,
342                    &a,
343                    b,
344                    crate::infer::ConstrReason::Predicate,
345                )?;
346            }
347            Ok(actual)
348        })?;
349        // Step 4: check all evars are solved, plug back into ProjectionPredicate
350        let actual = infcx.fully_resolve_evars(&actual);
351
352        // Step 5: generate "fresh" type for actual.term,
353        let oblig_term = actual.term.with_holes().replace_holes(|binders, kind| {
354            assert!(kind == rty::HoleKind::Pred);
355            let scope = &self.scope;
356            infcx.fresh_kvar_in_scope(binders, scope, KVarEncoding::Conj)
357        });
358
359        // Step 6: subtyping obligation on output
360        infcx.subtyping(
361            &actual.term.to_ty(),
362            &oblig_term.to_ty(),
363            crate::infer::ConstrReason::Predicate,
364        )?;
365        // Ok(ProjectionPredicate { projection_ty: actual.projection_ty, term: oblig_term })
366        Ok(oblig_term)
367    }
368
369    fn assemble_candidates_from_predicates(
370        &mut self,
371        predicates: &List<Clause>,
372        obligation: &AliasTy,
373        ctor: fn(Binder<ProjectionPredicate>) -> Candidate,
374        candidates: &mut Vec<Candidate>,
375    ) {
376        let tcx = self.tcx();
377        let rustc_obligation = obligation.to_rustc(tcx);
378
379        for predicate in predicates {
380            if let Some(pred) = predicate.as_projection_clause()
381                && pred.skip_binder_ref().projection_ty.to_rustc(tcx) == rustc_obligation
382            {
383                candidates.push(ctor(pred));
384            }
385        }
386    }
387
388    fn assemble_candidates_from_param_env(
389        &mut self,
390        obligation: &AliasTy,
391        candidates: &mut Vec<Candidate>,
392    ) {
393        let predicates = self.param_env.clone();
394        self.assemble_candidates_from_predicates(
395            &predicates,
396            obligation,
397            Candidate::ParamEnv,
398            candidates,
399        );
400    }
401
402    fn assemble_candidates_from_trait_def(
403        &mut self,
404        obligation: &AliasTy,
405        candidates: &mut Vec<Candidate>,
406    ) -> InferResult {
407        if let GenericArg::Base(ctor) = &obligation.args[0]
408            && let BaseTy::Alias(AliasKind::Opaque, alias_ty) = ctor.as_bty_skipping_binder()
409        {
410            debug_assert!(!alias_ty.has_escaping_bvars());
411            let bounds = self.genv().item_bounds(alias_ty.def_id)?.instantiate(
412                self.tcx(),
413                &alias_ty.args,
414                &alias_ty.refine_args,
415            );
416            self.assemble_candidates_from_predicates(
417                &bounds,
418                obligation,
419                Candidate::TraitDef,
420                candidates,
421            );
422        }
423        Ok(())
424    }
425
426    fn assemble_candidates_from_impls(
427        &mut self,
428        obligation: &AliasTy,
429        candidates: &mut Vec<Candidate>,
430    ) -> QueryResult {
431        let trait_ref = obligation.to_rustc(self.tcx()).trait_ref(self.tcx());
432        let trait_ref = self.tcx().erase_regions(trait_ref);
433        let trait_pred = Obligation::new(
434            self.tcx(),
435            ObligationCause::dummy(),
436            self.rustc_param_env(),
437            trait_ref,
438        );
439        // FIXME(nilehmann) This is a patch to not panic inside rustc so we are
440        // able to catch the bug
441        if trait_pred.has_escaping_bound_vars() {
442            tracked_span_bug!();
443        }
444        match self.selcx.select(&trait_pred) {
445            Ok(Some(ImplSource::UserDefined(impl_data))) => {
446                candidates.push(Candidate::UserDefinedImpl(impl_data.impl_def_id));
447            }
448            Ok(_) => (),
449            Err(e) => bug!("error selecting {trait_pred:?}: {e:?}"),
450        }
451        Ok(())
452    }
453
454    fn def_id(&self) -> DefId {
455        self.infcx.def_id
456    }
457
458    fn genv(&self) -> GlobalEnv<'genv, 'tcx> {
459        self.infcx.genv
460    }
461
462    fn tcx(&self) -> TyCtxt<'tcx> {
463        self.selcx.tcx()
464    }
465
466    fn rustc_param_env(&self) -> rustc_middle::ty::ParamEnv<'tcx> {
467        self.selcx.tcx().param_env(self.def_id())
468    }
469}
470
471impl FallibleTypeFolder for Normalizer<'_, '_, '_, '_> {
472    type Error = QueryErr;
473
474    fn try_fold_sort(&mut self, sort: &Sort) -> Result<Sort, Self::Error> {
475        match sort {
476            Sort::Alias(AliasKind::Free, alias_ty) => {
477                self.genv()
478                    .normalize_free_alias_sort(alias_ty)?
479                    .try_fold_with(self)
480            }
481            Sort::Alias(AliasKind::Projection, alias_ty) => {
482                let (changed, ctor) = self.normalize_projection_ty(alias_ty)?;
483                let sort = ctor.sort();
484                if changed { sort.try_fold_with(self) } else { Ok(sort) }
485            }
486            _ => sort.try_super_fold_with(self),
487        }
488    }
489
490    // As shown in https://github.com/flux-rs/flux/issues/711 one round of `normalize_projections`
491    // can replace one projection e.g. `<Rev<Iter<[i32]> as Iterator>::Item` with another e.g.
492    // `<Iter<[i32]> as Iterator>::Item` We want to compute a "fixpoint" i.e. keep going until no
493    // change, so that e.g. the above is normalized all the way to `i32`, which is what the `changed`
494    // is for.
495    fn try_fold_ty(&mut self, ty: &Ty) -> Result<Ty, Self::Error> {
496        match ty.kind() {
497            TyKind::Indexed(BaseTy::Alias(AliasKind::Free, alias_ty), idx) => {
498                Ok(self
499                    .genv()
500                    .type_of(alias_ty.def_id)?
501                    .instantiate(self.tcx(), &alias_ty.args, &alias_ty.refine_args)
502                    .expect_ctor()
503                    .replace_bound_reft(idx))
504            }
505            TyKind::Indexed(BaseTy::Alias(AliasKind::Projection, alias_ty), idx) => {
506                let (changed, ctor) = self.normalize_projection_ty(alias_ty)?;
507                let ty = ctor.replace_bound_reft(idx).to_ty();
508                if changed { ty.try_fold_with(self) } else { Ok(ty) }
509            }
510            _ => ty.try_super_fold_with(self),
511        }
512    }
513
514    fn try_fold_subset_ty(&mut self, sty: &SubsetTy) -> Result<SubsetTy, Self::Error> {
515        match &sty.bty {
516            BaseTy::Alias(AliasKind::Free, _alias_ty) => {
517                // Weak aliases are always expanded during conversion. We could in theory normalize
518                // them here but we don't guaranatee that type aliases expand to a subset ty. If we
519                // ever stop expanding aliases during conv we would need to guarantee that aliases
520                // used as a generic base expand to a subset type.
521                tracked_span_bug!()
522            }
523            BaseTy::Alias(AliasKind::Projection, alias_ty) => {
524                let (changed, ctor) = self.normalize_projection_ty(alias_ty)?;
525                let ty = ctor.replace_bound_reft(&sty.idx).strengthen(&sty.pred);
526                if changed { ty.try_fold_with(self) } else { Ok(ty) }
527            }
528            _ => sty.try_super_fold_with(self),
529        }
530    }
531
532    fn try_fold_expr(&mut self, expr: &Expr) -> Result<Expr, Self::Error> {
533        if let ExprKind::Alias(alias_pred, refine_args) = expr.kind() {
534            self.normalize_alias_reft(alias_pred, refine_args)
535        } else {
536            expr.try_super_fold_with(self)
537        }
538    }
539
540    fn try_fold_const(&mut self, c: &Const) -> Result<Const, Self::Error> {
541        let c = c.to_rustc(self.tcx());
542        rustc_trait_selection::traits::evaluate_const(self.selcx.infcx, c, self.rustc_param_env())
543            .lower(self.tcx())
544            .map_err(|e| QueryErr::unsupported(self.def_id(), e.into_err()))
545    }
546}
547
548#[derive(Debug)]
549pub enum Candidate {
550    UserDefinedImpl(DefId),
551    ParamEnv(Binder<ProjectionPredicate>),
552    TraitDef(Binder<ProjectionPredicate>),
553}
554
555#[derive(Debug)]
556struct TVarSubst {
557    args: Vec<Option<GenericArg>>,
558}
559
560impl GenericsSubstDelegate for &TVarSubst {
561    type Error = ();
562
563    fn ty_for_param(&mut self, param_ty: rustc_middle::ty::ParamTy) -> Result<Ty, Self::Error> {
564        match self.args.get(param_ty.index as usize) {
565            Some(Some(GenericArg::Ty(ty))) => Ok(ty.clone()),
566            Some(None) => Err(()),
567            arg => tracked_span_bug!("expected type for generic parameter, found `{arg:?}`"),
568        }
569    }
570
571    fn sort_for_param(&mut self, param_ty: rustc_middle::ty::ParamTy) -> Result<Sort, Self::Error> {
572        match self.args.get(param_ty.index as usize) {
573            Some(Some(GenericArg::Base(ctor))) => Ok(ctor.sort()),
574            Some(None) => Err(()),
575            arg => tracked_span_bug!("expected type for generic parameter, found `{arg:?}`"),
576        }
577    }
578
579    fn ctor_for_param(
580        &mut self,
581        param_ty: rustc_middle::ty::ParamTy,
582    ) -> Result<SubsetTyCtor, Self::Error> {
583        match self.args.get(param_ty.index as usize) {
584            Some(Some(GenericArg::Base(ctor))) => Ok(ctor.clone()),
585            Some(None) => Err(()),
586            arg => tracked_span_bug!("expected type for generic parameter, found `{arg:?}`"),
587        }
588    }
589
590    fn region_for_param(&mut self, _ebr: rustc_middle::ty::EarlyParamRegion) -> Region {
591        tracked_span_bug!()
592    }
593
594    fn expr_for_param_const(&self, _param_const: rustc_middle::ty::ParamConst) -> Expr {
595        tracked_span_bug!()
596    }
597
598    fn const_for_param(&mut self, _param: &Const) -> Const {
599        tracked_span_bug!()
600    }
601}
602
603struct SortNormalizer<'infcx, 'genv, 'tcx> {
604    def_id: DefId,
605    infcx: &'infcx rustc_infer::infer::InferCtxt<'tcx>,
606    genv: GlobalEnv<'genv, 'tcx>,
607}
608impl<'infcx, 'genv, 'tcx> SortNormalizer<'infcx, 'genv, 'tcx> {
609    fn new(
610        def_id: DefId,
611        genv: GlobalEnv<'genv, 'tcx>,
612        infcx: &'infcx rustc_infer::infer::InferCtxt<'tcx>,
613    ) -> Self {
614        Self { def_id, infcx, genv }
615    }
616}
617
618impl FallibleTypeFolder for SortNormalizer<'_, '_, '_> {
619    type Error = QueryErr;
620    fn try_fold_sort(&mut self, sort: &Sort) -> Result<Sort, Self::Error> {
621        match sort {
622            Sort::Alias(AliasKind::Free, alias_ty) => {
623                self.genv
624                    .normalize_free_alias_sort(alias_ty)?
625                    .try_fold_with(self)
626            }
627            Sort::Alias(AliasKind::Projection, alias_ty) => {
628                let (changed, ctor) = normalize_projection_ty_with_rustc(
629                    self.genv,
630                    self.def_id,
631                    self.infcx,
632                    alias_ty,
633                )?;
634                let sort = ctor.sort();
635                if changed { sort.try_fold_with(self) } else { Ok(sort) }
636            }
637            _ => sort.try_super_fold_with(self),
638        }
639    }
640}
641
642impl TVarSubst {
643    fn new(generics: &rustc_middle::ty::Generics) -> Self {
644        Self { args: vec![None; generics.count()] }
645    }
646
647    fn instantiate_partial<T: TypeFoldable>(&mut self, pred: EarlyBinder<T>) -> Option<T> {
648        let mut folder = GenericsSubstFolder::new(&*self, &[]);
649        pred.skip_binder().try_fold_with(&mut folder).ok()
650    }
651
652    fn finish<'tcx>(
653        self,
654        tcx: TyCtxt<'tcx>,
655        generics: &'tcx rustc_middle::ty::Generics,
656    ) -> QueryResult<Vec<GenericArg>> {
657        self.args
658            .into_iter()
659            .enumerate()
660            .map(|(idx, arg)| {
661                if let Some(arg) = arg {
662                    Ok(arg)
663                } else {
664                    let param = generics.param_at(idx, tcx);
665                    Err(QueryErr::bug(
666                        None,
667                        format!("cannot infer substitution for {param:?} at index {idx}"),
668                    ))
669                }
670            })
671            .try_collect_vec()
672    }
673
674    fn generic_args(&mut self, a: &GenericArg, b: &GenericArg) {
675        match (a, b) {
676            (GenericArg::Ty(a), GenericArg::Ty(b)) => self.tys(a, b),
677            (GenericArg::Lifetime(a), GenericArg::Lifetime(b)) => self.regions(*a, *b),
678            (GenericArg::Base(a), GenericArg::Base(b)) => {
679                self.subset_tys(a, b);
680            }
681            (GenericArg::Const(a), GenericArg::Const(b)) => self.consts(a, b),
682            _ => {}
683        }
684    }
685
686    fn tys(&mut self, a: &Ty, b: &Ty) {
687        if let TyKind::Param(param_ty) = a.kind() {
688            if !b.has_escaping_bvars() {
689                self.insert_generic_arg(param_ty.index, GenericArg::Ty(b.clone()));
690            }
691        } else {
692            let Some(a_bty) = a.as_bty_skipping_existentials() else { return };
693            let Some(b_bty) = b.as_bty_skipping_existentials() else { return };
694            self.btys(a_bty, b_bty);
695        }
696    }
697
698    fn subset_tys(&mut self, a: &SubsetTyCtor, b: &SubsetTyCtor) {
699        let bty_a = a.as_bty_skipping_binder();
700        let bty_b = b.as_bty_skipping_binder();
701        if let BaseTy::Param(param_ty) = bty_a {
702            if !b.has_escaping_bvars() {
703                self.insert_generic_arg(param_ty.index, GenericArg::Base(b.clone()));
704            }
705        } else {
706            self.btys(bty_a, bty_b);
707        }
708    }
709
710    fn btys(&mut self, a: &BaseTy, b: &BaseTy) {
711        match (a, b) {
712            (BaseTy::Param(param_ty), _) => {
713                if !b.has_escaping_bvars() {
714                    let sort = b.sort();
715                    let ctor =
716                        Binder::bind_with_sort(SubsetTy::trivial(b.clone(), Expr::nu()), sort);
717                    self.insert_generic_arg(param_ty.index, GenericArg::Base(ctor));
718                }
719            }
720            (BaseTy::Adt(_, a_args), BaseTy::Adt(_, b_args)) => {
721                debug_assert_eq!(a_args.len(), b_args.len());
722                for (a_arg, b_arg) in iter::zip(a_args, b_args) {
723                    self.generic_args(a_arg, b_arg);
724                }
725            }
726            (BaseTy::Array(a_ty, a_n), BaseTy::Array(b_ty, b_n)) => {
727                self.tys(a_ty, b_ty);
728                self.consts(a_n, b_n);
729            }
730            (BaseTy::Tuple(a_tys), BaseTy::Tuple(b_tys)) => {
731                debug_assert_eq!(a_tys.len(), b_tys.len());
732                for (a_ty, b_ty) in iter::zip(a_tys, b_tys) {
733                    self.tys(a_ty, b_ty);
734                }
735            }
736            (BaseTy::Ref(a_re, a_ty, _), BaseTy::Ref(b_re, b_ty, _)) => {
737                self.regions(*a_re, *b_re);
738                self.tys(a_ty, b_ty);
739            }
740            (BaseTy::Slice(a_ty), BaseTy::Slice(b_ty)) => {
741                self.tys(a_ty, b_ty);
742            }
743            _ => {}
744        }
745    }
746
747    fn regions(&mut self, a: Region, b: Region) {
748        if let Region::ReEarlyParam(ebr) = a {
749            self.insert_generic_arg(ebr.index, GenericArg::Lifetime(b));
750        }
751    }
752
753    fn consts(&mut self, a: &Const, b: &Const) {
754        if let ConstKind::Param(param_const) = a.kind {
755            self.insert_generic_arg(param_const.index, GenericArg::Const(b.clone()));
756        }
757    }
758
759    fn insert_generic_arg(&mut self, idx: u32, arg: GenericArg) {
760        if let Some(old) = &self.args[idx as usize]
761            && old != &arg
762        {
763            tracked_span_bug!("ambiguous substitution: old=`{old:?}`, new: `{arg:?}`");
764        }
765        self.args[idx as usize].replace(arg);
766    }
767}
768
769/// Normalize an [`rty::AliasTy`] by converting it to rustc, normalizing it using rustc api, and
770/// then mapping the result back to `rty`. This will lose refinements and it should only be used
771/// to normalize sorts because they should only contain unrefined types. However, we are also using
772/// it as a hack to normalize types in cases where we fail to collect a candidate, this is unsound
773/// and should be removed.
774///
775/// [`rty::AliasTy`]: AliasTy
776fn normalize_projection_ty_with_rustc<'tcx>(
777    genv: GlobalEnv<'_, 'tcx>,
778    def_id: DefId,
779    infcx: &rustc_infer::infer::InferCtxt<'tcx>,
780    obligation: &AliasTy,
781) -> QueryResult<(bool, SubsetTyCtor)> {
782    let tcx = genv.tcx();
783    let projection_ty = obligation.to_rustc(tcx);
784    let projection_ty = tcx.erase_regions(projection_ty);
785    let cause = ObligationCause::dummy();
786    let param_env = tcx.param_env(def_id);
787
788    let pre_ty = projection_ty.to_ty(tcx);
789    let at = infcx.at(&cause, param_env);
790    let ty = deeply_normalize::<rustc_middle::ty::Ty<'tcx>, FulfillmentError>(at, pre_ty)
791        .map_err(|err| query_bug!("{err:?}"))?;
792
793    let changed = pre_ty != ty;
794    let rustc_ty = ty.lower(tcx).map_err(|reason| query_bug!("{reason:?}"))?;
795
796    Ok((
797        changed,
798        Refiner::default_for_item(genv, def_id)?
799            .refine_ty_or_base(&rustc_ty)?
800            .expect_base(),
801    ))
802}