// flux_infer/projections.rs

1use std::iter;
2
3use flux_common::{bug, iter::IterExt, tracked_span_bug};
4use flux_middle::{
5    global_env::GlobalEnv,
6    queries::{QueryErr, QueryResult},
7    query_bug,
8    rty::{
9        self, AliasKind, AliasReft, AliasTy, BaseTy, Binder, Clause, ClauseKind, Const, ConstKind,
10        EarlyBinder, Expr, ExprKind, GenericArg, List, ProjectionPredicate, RefineArgs, Region,
11        Sort, SubsetTy, SubsetTyCtor, Ty, TyKind,
12        fold::{FallibleTypeFolder, TypeFoldable, TypeSuperFoldable, TypeVisitable},
13        refining::Refiner,
14        subst::{GenericsSubstDelegate, GenericsSubstFolder},
15    },
16};
17use flux_rustc_bridge::{ToRustc, lowering::Lower};
18use itertools::izip;
19use rustc_hir::def_id::DefId;
20use rustc_infer::traits::Obligation;
21use rustc_middle::{
22    traits::{ImplSource, ObligationCause},
23    ty::{TyCtxt, Variance},
24};
25use rustc_trait_selection::{
26    solve::deeply_normalize,
27    traits::{FulfillmentError, SelectionContext},
28};
29
30use crate::{
31    fixpoint_encoding::KVarEncoding,
32    infer::{InferCtxtAt, InferResult},
33    refine_tree::Scope,
34};
35
36pub trait NormalizeExt: TypeFoldable {
37    fn normalize_projections(&self, infcx: &mut InferCtxtAt) -> QueryResult<Self>;
38
39    /// Normalize projections but only inside sorts
40    fn normalize_sorts<'tcx>(
41        &self,
42        def_id: DefId,
43        genv: GlobalEnv<'_, 'tcx>,
44        infcx: &rustc_infer::infer::InferCtxt<'tcx>,
45    ) -> QueryResult<Self>;
46}
47
48impl<T: TypeFoldable> NormalizeExt for T {
49    fn normalize_projections(&self, infcx: &mut InferCtxtAt) -> QueryResult<Self> {
50        let span = infcx.span;
51        let infcx_orig = &mut infcx.infcx;
52        let mut infcx = infcx_orig.branch();
53        let infcx = infcx.at(span);
54        let mut normalizer = Normalizer::new(infcx)?;
55        self.erase_regions().try_fold_with(&mut normalizer)
56    }
57
58    fn normalize_sorts<'tcx>(
59        &self,
60        def_id: DefId,
61        genv: GlobalEnv<'_, 'tcx>,
62        infcx: &rustc_infer::infer::InferCtxt<'tcx>,
63    ) -> QueryResult<Self> {
64        let mut normalizer = SortNormalizer::new(def_id, genv, infcx);
65        self.erase_regions().try_fold_with(&mut normalizer)
66    }
67}
68
69struct Normalizer<'a, 'infcx, 'genv, 'tcx> {
70    infcx: InferCtxtAt<'a, 'infcx, 'genv, 'tcx>,
71    selcx: SelectionContext<'infcx, 'tcx>,
72    param_env: List<Clause>,
73    scope: Scope,
74}
75
76impl<'a, 'infcx, 'genv, 'tcx> Normalizer<'a, 'infcx, 'genv, 'tcx> {
77    fn new(infcx: InferCtxtAt<'a, 'infcx, 'genv, 'tcx>) -> QueryResult<Self> {
78        let predicates = infcx.genv.predicates_of(infcx.def_id)?;
79        let param_env = predicates.instantiate_identity().predicates.clone();
80        let selcx = SelectionContext::new(infcx.region_infcx);
81        let scope = infcx.cursor().marker().scope().unwrap();
82        Ok(Normalizer { infcx, selcx, param_env, scope })
83    }
84
85    fn get_impl_id_of_alias_reft(&mut self, alias_reft: &AliasReft) -> QueryResult<Option<DefId>> {
86        let tcx = self.tcx();
87        let def_id = self.def_id();
88        let selcx = &mut self.selcx;
89
90        let trait_pred = Obligation::new(
91            tcx,
92            ObligationCause::dummy(),
93            tcx.param_env(def_id),
94            alias_reft.to_rustc_trait_ref(tcx),
95        );
96        match selcx.select(&trait_pred) {
97            Ok(Some(ImplSource::UserDefined(impl_data))) => Ok(Some(impl_data.impl_def_id)),
98            Ok(_) => Ok(None),
99            Err(e) => bug!("error selecting {trait_pred:?}: {e:?}"),
100        }
101    }
102
103    fn normalize_alias_reft(
104        &mut self,
105        alias_reft: &AliasReft,
106        refine_args: &RefineArgs,
107    ) -> QueryResult<Expr> {
108        if let Some(impl_def_id) = self.get_impl_id_of_alias_reft(alias_reft)? {
109            let impl_trait_ref = self
110                .genv()
111                .impl_trait_ref(impl_def_id)?
112                .unwrap()
113                .skip_binder();
114            let generics = self.tcx().generics_of(impl_def_id);
115            let mut subst = TVarSubst::new(generics);
116            let alias_reft_args = alias_reft.args.try_fold_with(self)?;
117            // TODO(RJ): The code below duplicates the logic in `confirm_candidates` ...
118            for (a, b) in iter::zip(&impl_trait_ref.args, &alias_reft_args) {
119                subst.generic_args(a, b);
120            }
121            self.resolve_projection_predicates(&mut subst, impl_def_id)?;
122
123            let args = subst.finish(self.tcx(), generics)?;
124            self.genv()
125                .assoc_refinement_body_for_impl(alias_reft.assoc_id, impl_def_id)?
126                .instantiate(self.genv().tcx(), &args, &[])
127                .apply(refine_args)
128                .try_fold_with(self)
129        } else {
130            Ok(Expr::alias(alias_reft.clone(), refine_args.clone()))
131        }
132    }
133
134    fn normalize_projection_ty(
135        &mut self,
136        obligation: &AliasTy,
137    ) -> QueryResult<(bool, SubsetTyCtor)> {
138        let mut candidates = vec![];
139        self.assemble_candidates_from_param_env(obligation, &mut candidates);
140        self.assemble_candidates_from_trait_def(obligation, &mut candidates)
141            .unwrap_or_else(|err| tracked_span_bug!("{err:?}"));
142        self.assemble_candidates_from_impls(obligation, &mut candidates)?;
143        if candidates.is_empty() {
144            // TODO: This is a temporary hack that uses rustc's trait selection when FLUX fails;
145            //       The correct thing, e.g for `trait09.rs` is to make sure FLUX's param_env mirrors RUSTC,
146            //       by suitably chasing down the super-trait predicates,
147            //       see https://github.com/flux-rs/flux/issues/737
148            let (changed, ty_ctor) = normalize_projection_ty_with_rustc(
149                self.genv(),
150                self.def_id(),
151                self.infcx.region_infcx,
152                obligation,
153            )?;
154            return Ok((changed, ty_ctor));
155        }
156        if candidates.len() > 1 {
157            bug!("ambiguity when resolving `{obligation:?}` in {:?}", self.def_id());
158        }
159        let ctor = self.confirm_candidate(candidates.pop().unwrap(), obligation)?;
160        Ok((true, ctor))
161    }
162
163    fn find_resolved_predicates(
164        &self,
165        subst: &mut TVarSubst,
166        preds: Vec<EarlyBinder<ProjectionPredicate>>,
167    ) -> (Vec<ProjectionPredicate>, Vec<EarlyBinder<ProjectionPredicate>>) {
168        let mut resolved = vec![];
169        let mut unresolved = vec![];
170        for pred in preds {
171            let term = pred.clone().skip_binder().term;
172            let alias_ty = pred.clone().map(|p| p.projection_ty);
173            match subst.instantiate_partial(alias_ty) {
174                Some(projection_ty) => {
175                    let pred = ProjectionPredicate { projection_ty, term };
176                    resolved.push(pred);
177                }
178                None => unresolved.push(pred.clone()),
179            }
180        }
181        (resolved, unresolved)
182    }
183
184    // See issue-829*.rs for an example of what this function is for.
185    fn resolve_projection_predicates(
186        &mut self,
187        subst: &mut TVarSubst,
188        impl_def_id: DefId,
189    ) -> QueryResult {
190        let mut projection_preds: Vec<_> = self
191            .genv()
192            .predicates_of(impl_def_id)?
193            .skip_binder()
194            .predicates
195            .iter()
196            .filter_map(|pred| {
197                if let ClauseKind::Projection(pred) = pred.kind_skipping_binder() {
198                    Some(EarlyBinder(pred.clone()))
199                } else {
200                    None
201                }
202            })
203            .collect();
204
205        while !projection_preds.is_empty() {
206            let (resolved, unresolved) = self.find_resolved_predicates(subst, projection_preds);
207
208            if resolved.is_empty() {
209                break; // failed: there is some unresolved projection pred!
210            }
211            for p in resolved {
212                let obligation = &p.projection_ty;
213                let (_, ctor) = self.normalize_projection_ty(obligation)?;
214                subst.subset_tys(&p.term, &ctor);
215            }
216            projection_preds = unresolved;
217        }
218        Ok(())
219    }
220
221    fn confirm_candidate(
222        &mut self,
223        candidate: Candidate,
224        obligation: &AliasTy,
225    ) -> QueryResult<SubsetTyCtor> {
226        let tcx = self.tcx();
227        match candidate {
228            Candidate::ParamEnv(pred) | Candidate::TraitDef(pred) => {
229                let rustc_obligation = obligation.to_rustc(tcx);
230                let parent_id = rustc_obligation.trait_ref(tcx).def_id;
231                // Do fn-subtyping if the candidate was a fn-trait
232                if tcx.is_fn_trait(parent_id) {
233                    let res = self
234                        .fn_subtype_projection_ty(pred, obligation)
235                        .unwrap_or_else(|err| tracked_span_bug!("{err:?}"));
236                    Ok(res)
237                } else {
238                    Ok(pred.skip_binder().term)
239                }
240            }
241            Candidate::UserDefinedImpl(impl_def_id) => {
242                // Given a projection obligation
243                //     <IntoIter<{v. i32[v] | v > 0}, Global> as Iterator>::Item
244                // and the id of a rust impl block
245                //     impl<T, A: Allocator> Iterator for IntoIter<T, A>
246
247                // 1. MATCH the self type of the rust impl block and the flux self type of the obligation
248                //    to infer a substitution
249                //        IntoIter<{v. i32[v] | v > 0}, Global> MATCH IntoIter<T, A>
250                //            => {T -> {v. i32[v] | v > 0}, A -> Global}
251
252                let impl_trait_ref = self
253                    .genv()
254                    .impl_trait_ref(impl_def_id)?
255                    .unwrap()
256                    .skip_binder();
257
258                let generics = self.tcx().generics_of(impl_def_id);
259
260                let mut subst = TVarSubst::new(generics);
261                for (a, b) in iter::zip(&impl_trait_ref.args, &obligation.args) {
262                    subst.generic_args(a, b);
263                }
264
265                // 2. Gather the ProjectionPredicates and solve them see issue-808.rs
266                self.resolve_projection_predicates(&mut subst, impl_def_id)?;
267
268                let args = subst.finish(self.tcx(), generics)?;
269
270                // 3. Get the associated type in the impl block and apply the substitution to it
271                let assoc_type_id = tcx
272                    .associated_items(impl_def_id)
273                    .in_definition_order()
274                    .find(|item| item.trait_item_def_id == Some(obligation.def_id))
275                    .map(|item| item.def_id)
276                    .unwrap();
277
278                Ok(self
279                    .genv()
280                    .type_of(assoc_type_id)?
281                    .instantiate(tcx, &args, &[])
282                    .expect_subset_ty_ctor())
283            }
284        }
285    }
286
287    fn fn_subtype_projection_ty(
288        &mut self,
289        actual: Binder<ProjectionPredicate>,
290        oblig: &AliasTy,
291    ) -> InferResult<SubsetTyCtor> {
292        // Step 1: bs <- unpack(b1...)
293        let obligs: Vec<_> = oblig
294            .args
295            .iter()
296            .map(|arg| {
297                match arg {
298                    GenericArg::Ty(ty) => GenericArg::Ty(self.infcx.unpack(ty)),
299                    GenericArg::Base(ctor) => GenericArg::Ty(self.infcx.unpack(&ctor.to_ty())),
300                    _ => arg.clone(),
301                }
302            })
303            .collect();
304
305        let span = self.infcx.span;
306        let mut infcx = self.infcx.at(span);
307
308        let actual = infcx.ensure_resolved_evars(|infcx| {
309            // Step 2: as <- fresh(a1...)
310            let actual = actual
311                .replace_bound_vars(
312                    |_| rty::ReErased,
313                    |sort, mode| infcx.fresh_infer_var(sort, mode),
314                )
315                .normalize_projections(infcx)?;
316
317            let actuals = actual.projection_ty.args.iter().map(|arg| {
318                match arg {
319                    GenericArg::Base(ctor) => GenericArg::Ty(ctor.to_ty()),
320                    _ => arg.clone(),
321                }
322            });
323
324            // Step 3: bs <: as
325            for (a, b) in izip!(actuals.skip(1), obligs.iter().skip(1)) {
326                infcx.subtyping_generic_args(
327                    Variance::Contravariant,
328                    &a,
329                    b,
330                    crate::infer::ConstrReason::Predicate,
331                )?;
332            }
333            Ok(actual)
334        })?;
335        // Step 4: check all evars are solved, plug back into ProjectionPredicate
336        let actual = infcx.fully_resolve_evars(&actual);
337
338        // Step 5: generate "fresh" type for actual.term,
339        let oblig_term = actual.term.with_holes().replace_holes(|binders, kind| {
340            assert!(kind == rty::HoleKind::Pred);
341            let scope = &self.scope;
342            infcx.fresh_kvar_in_scope(binders, scope, KVarEncoding::Conj)
343        });
344
345        // Step 6: subtyping obligation on output
346        infcx.subtyping(
347            &actual.term.to_ty(),
348            &oblig_term.to_ty(),
349            crate::infer::ConstrReason::Predicate,
350        )?;
351        // Ok(ProjectionPredicate { projection_ty: actual.projection_ty, term: oblig_term })
352        Ok(oblig_term)
353    }
354
355    fn assemble_candidates_from_predicates(
356        &mut self,
357        predicates: &List<Clause>,
358        obligation: &AliasTy,
359        ctor: fn(Binder<ProjectionPredicate>) -> Candidate,
360        candidates: &mut Vec<Candidate>,
361    ) {
362        let tcx = self.tcx();
363        let rustc_obligation = obligation.to_rustc(tcx);
364
365        for predicate in predicates {
366            if let Some(pred) = predicate.as_projection_clause()
367                && pred.skip_binder_ref().projection_ty.to_rustc(tcx) == rustc_obligation
368            {
369                candidates.push(ctor(pred));
370            }
371        }
372    }
373
374    fn assemble_candidates_from_param_env(
375        &mut self,
376        obligation: &AliasTy,
377        candidates: &mut Vec<Candidate>,
378    ) {
379        let predicates = self.param_env.clone();
380        self.assemble_candidates_from_predicates(
381            &predicates,
382            obligation,
383            Candidate::ParamEnv,
384            candidates,
385        );
386    }
387
388    fn assemble_candidates_from_trait_def(
389        &mut self,
390        obligation: &AliasTy,
391        candidates: &mut Vec<Candidate>,
392    ) -> InferResult {
393        if let GenericArg::Base(ctor) = &obligation.args[0]
394            && let BaseTy::Alias(AliasKind::Opaque, alias_ty) = ctor.as_bty_skipping_binder()
395        {
396            debug_assert!(!alias_ty.has_escaping_bvars());
397            let bounds = self.genv().item_bounds(alias_ty.def_id)?.instantiate(
398                self.tcx(),
399                &alias_ty.args,
400                &alias_ty.refine_args,
401            );
402            self.assemble_candidates_from_predicates(
403                &bounds,
404                obligation,
405                Candidate::TraitDef,
406                candidates,
407            );
408        }
409        Ok(())
410    }
411
412    fn assemble_candidates_from_impls(
413        &mut self,
414        obligation: &AliasTy,
415        candidates: &mut Vec<Candidate>,
416    ) -> QueryResult {
417        let trait_pred = Obligation::new(
418            self.tcx(),
419            ObligationCause::dummy(),
420            self.rustc_param_env(),
421            obligation.to_rustc(self.tcx()).trait_ref(self.tcx()),
422        );
423        match self.selcx.select(&trait_pred) {
424            Ok(Some(ImplSource::UserDefined(impl_data))) => {
425                candidates.push(Candidate::UserDefinedImpl(impl_data.impl_def_id));
426            }
427            Ok(_) => (),
428            Err(e) => bug!("error selecting {trait_pred:?}: {e:?}"),
429        }
430        Ok(())
431    }
432
433    fn def_id(&self) -> DefId {
434        self.infcx.def_id
435    }
436
437    fn genv(&self) -> GlobalEnv<'genv, 'tcx> {
438        self.infcx.genv
439    }
440
441    fn tcx(&self) -> TyCtxt<'tcx> {
442        self.selcx.tcx()
443    }
444
445    fn rustc_param_env(&self) -> rustc_middle::ty::ParamEnv<'tcx> {
446        self.selcx.tcx().param_env(self.def_id())
447    }
448}
449
450impl FallibleTypeFolder for Normalizer<'_, '_, '_, '_> {
451    type Error = QueryErr;
452
453    fn try_fold_sort(&mut self, sort: &Sort) -> Result<Sort, Self::Error> {
454        match sort {
455            Sort::Alias(AliasKind::Weak, alias_ty) => {
456                self.genv()
457                    .normalize_weak_alias_sort(alias_ty)?
458                    .try_fold_with(self)
459            }
460            Sort::Alias(AliasKind::Projection, alias_ty) => {
461                let (changed, ctor) = self.normalize_projection_ty(alias_ty)?;
462                let sort = ctor.sort();
463                if changed { sort.try_fold_with(self) } else { Ok(sort) }
464            }
465            _ => sort.try_super_fold_with(self),
466        }
467    }
468
469    // As shown in https://github.com/flux-rs/flux/issues/711 one round of `normalize_projections`
470    // can replace one projection e.g. `<Rev<Iter<[i32]> as Iterator>::Item` with another e.g.
471    // `<Iter<[i32]> as Iterator>::Item` We want to compute a "fixpoint" i.e. keep going until no
472    // change, so that e.g. the above is normalized all the way to `i32`, which is what the `changed`
473    // is for.
474    fn try_fold_ty(&mut self, ty: &Ty) -> Result<Ty, Self::Error> {
475        match ty.kind() {
476            TyKind::Indexed(BaseTy::Alias(AliasKind::Weak, alias_ty), idx) => {
477                Ok(self
478                    .genv()
479                    .type_of(alias_ty.def_id)?
480                    .instantiate(self.tcx(), &alias_ty.args, &alias_ty.refine_args)
481                    .expect_ctor()
482                    .replace_bound_reft(idx))
483            }
484            TyKind::Indexed(BaseTy::Alias(AliasKind::Projection, alias_ty), idx) => {
485                let (changed, ctor) = self.normalize_projection_ty(alias_ty)?;
486                let ty = ctor.replace_bound_reft(idx).to_ty();
487                if changed { ty.try_fold_with(self) } else { Ok(ty) }
488            }
489            _ => ty.try_super_fold_with(self),
490        }
491    }
492
493    fn try_fold_subset_ty(&mut self, sty: &SubsetTy) -> Result<SubsetTy, Self::Error> {
494        match &sty.bty {
495            BaseTy::Alias(AliasKind::Weak, _alias_ty) => {
496                // Weak aliases are always expanded during conversion. We could in theory normalize
497                // them here but we don't guaranatee that type aliases expand to a subset ty. If we
498                // ever stop expanding aliases during conv we would need to guarantee that aliases
499                // used as a generic base expand to a subset type.
500                tracked_span_bug!()
501            }
502            BaseTy::Alias(AliasKind::Projection, alias_ty) => {
503                let (changed, ctor) = self.normalize_projection_ty(alias_ty)?;
504                let ty = ctor.replace_bound_reft(&sty.idx).strengthen(&sty.pred);
505                if changed { ty.try_fold_with(self) } else { Ok(ty) }
506            }
507            _ => sty.try_super_fold_with(self),
508        }
509    }
510
511    fn try_fold_expr(&mut self, expr: &Expr) -> Result<Expr, Self::Error> {
512        if let ExprKind::Alias(alias_pred, refine_args) = expr.kind() {
513            self.normalize_alias_reft(alias_pred, refine_args)
514        } else {
515            expr.try_super_fold_with(self)
516        }
517    }
518
519    fn try_fold_const(&mut self, c: &Const) -> Result<Const, Self::Error> {
520        let c = c.to_rustc(self.tcx());
521        rustc_trait_selection::traits::evaluate_const(self.selcx.infcx, c, self.rustc_param_env())
522            .lower(self.tcx())
523            .map_err(|e| QueryErr::unsupported(self.def_id(), e.into_err()))
524    }
525}
526
527#[derive(Debug)]
528pub enum Candidate {
529    UserDefinedImpl(DefId),
530    ParamEnv(Binder<ProjectionPredicate>),
531    TraitDef(Binder<ProjectionPredicate>),
532}
533
534#[derive(Debug)]
535struct TVarSubst {
536    args: Vec<Option<GenericArg>>,
537}
538
539impl GenericsSubstDelegate for &TVarSubst {
540    type Error = ();
541
542    fn ty_for_param(&mut self, param_ty: rustc_middle::ty::ParamTy) -> Result<Ty, Self::Error> {
543        match self.args.get(param_ty.index as usize) {
544            Some(Some(GenericArg::Ty(ty))) => Ok(ty.clone()),
545            Some(None) => Err(()),
546            arg => tracked_span_bug!("expected type for generic parameter, found `{arg:?}`"),
547        }
548    }
549
550    fn sort_for_param(&mut self, param_ty: rustc_middle::ty::ParamTy) -> Result<Sort, Self::Error> {
551        match self.args.get(param_ty.index as usize) {
552            Some(Some(GenericArg::Base(ctor))) => Ok(ctor.sort()),
553            Some(None) => Err(()),
554            arg => tracked_span_bug!("expected type for generic parameter, found `{arg:?}`"),
555        }
556    }
557
558    fn ctor_for_param(
559        &mut self,
560        param_ty: rustc_middle::ty::ParamTy,
561    ) -> Result<SubsetTyCtor, Self::Error> {
562        match self.args.get(param_ty.index as usize) {
563            Some(Some(GenericArg::Base(ctor))) => Ok(ctor.clone()),
564            Some(None) => Err(()),
565            arg => tracked_span_bug!("expected type for generic parameter, found `{arg:?}`"),
566        }
567    }
568
569    fn region_for_param(&mut self, _ebr: rustc_middle::ty::EarlyParamRegion) -> Region {
570        tracked_span_bug!()
571    }
572
573    fn expr_for_param_const(&self, _param_const: rustc_middle::ty::ParamConst) -> Expr {
574        tracked_span_bug!()
575    }
576
577    fn const_for_param(&mut self, _param: &Const) -> Const {
578        tracked_span_bug!()
579    }
580}
581
582struct SortNormalizer<'infcx, 'genv, 'tcx> {
583    def_id: DefId,
584    infcx: &'infcx rustc_infer::infer::InferCtxt<'tcx>,
585    genv: GlobalEnv<'genv, 'tcx>,
586}
587impl<'infcx, 'genv, 'tcx> SortNormalizer<'infcx, 'genv, 'tcx> {
588    fn new(
589        def_id: DefId,
590        genv: GlobalEnv<'genv, 'tcx>,
591        infcx: &'infcx rustc_infer::infer::InferCtxt<'tcx>,
592    ) -> Self {
593        Self { def_id, infcx, genv }
594    }
595}
596
597impl FallibleTypeFolder for SortNormalizer<'_, '_, '_> {
598    type Error = QueryErr;
599    fn try_fold_sort(&mut self, sort: &Sort) -> Result<Sort, Self::Error> {
600        match sort {
601            Sort::Alias(AliasKind::Weak, alias_ty) => {
602                self.genv
603                    .normalize_weak_alias_sort(alias_ty)?
604                    .try_fold_with(self)
605            }
606            Sort::Alias(AliasKind::Projection, alias_ty) => {
607                let (changed, ctor) = normalize_projection_ty_with_rustc(
608                    self.genv,
609                    self.def_id,
610                    self.infcx,
611                    alias_ty,
612                )?;
613                let sort = ctor.sort();
614                if changed { sort.try_fold_with(self) } else { Ok(sort) }
615            }
616            _ => sort.try_super_fold_with(self),
617        }
618    }
619}
620
621impl TVarSubst {
622    fn new(generics: &rustc_middle::ty::Generics) -> Self {
623        Self { args: vec![None; generics.count()] }
624    }
625
626    fn instantiate_partial<T: TypeFoldable>(&mut self, pred: EarlyBinder<T>) -> Option<T> {
627        let mut folder = GenericsSubstFolder::new(&*self, &[]);
628        pred.skip_binder().try_fold_with(&mut folder).ok()
629    }
630
631    fn finish<'tcx>(
632        self,
633        tcx: TyCtxt<'tcx>,
634        generics: &'tcx rustc_middle::ty::Generics,
635    ) -> QueryResult<Vec<GenericArg>> {
636        self.args
637            .into_iter()
638            .enumerate()
639            .map(|(idx, arg)| {
640                if let Some(arg) = arg {
641                    Ok(arg)
642                } else {
643                    let param = generics.param_at(idx, tcx);
644                    Err(QueryErr::bug(
645                        None,
646                        format!("cannot infer substitution for {param:?} at index {idx}"),
647                    ))
648                }
649            })
650            .try_collect_vec()
651    }
652
653    fn generic_args(&mut self, a: &GenericArg, b: &GenericArg) {
654        match (a, b) {
655            (GenericArg::Ty(a), GenericArg::Ty(b)) => self.tys(a, b),
656            (GenericArg::Lifetime(a), GenericArg::Lifetime(b)) => self.regions(*a, *b),
657            (GenericArg::Base(a), GenericArg::Base(b)) => {
658                self.subset_tys(a, b);
659            }
660            (GenericArg::Const(a), GenericArg::Const(b)) => self.consts(a, b),
661            _ => {}
662        }
663    }
664
665    fn tys(&mut self, a: &Ty, b: &Ty) {
666        if let TyKind::Param(param_ty) = a.kind() {
667            if !b.has_escaping_bvars() {
668                self.insert_generic_arg(param_ty.index, GenericArg::Ty(b.clone()));
669            }
670        } else {
671            let Some(a_bty) = a.as_bty_skipping_existentials() else { return };
672            let Some(b_bty) = b.as_bty_skipping_existentials() else { return };
673            self.btys(a_bty, b_bty);
674        }
675    }
676
677    fn subset_tys(&mut self, a: &SubsetTyCtor, b: &SubsetTyCtor) {
678        let bty_a = a.as_bty_skipping_binder();
679        let bty_b = b.as_bty_skipping_binder();
680        if let BaseTy::Param(param_ty) = bty_a {
681            if !b.has_escaping_bvars() {
682                self.insert_generic_arg(param_ty.index, GenericArg::Base(b.clone()));
683            }
684        } else {
685            self.btys(bty_a, bty_b);
686        }
687    }
688
689    fn btys(&mut self, a: &BaseTy, b: &BaseTy) {
690        match (a, b) {
691            (BaseTy::Param(param_ty), _) => {
692                if !b.has_escaping_bvars() {
693                    let sort = b.sort();
694                    let ctor =
695                        Binder::bind_with_sort(SubsetTy::trivial(b.clone(), Expr::nu()), sort);
696                    self.insert_generic_arg(param_ty.index, GenericArg::Base(ctor));
697                }
698            }
699            (BaseTy::Adt(_, a_args), BaseTy::Adt(_, b_args)) => {
700                debug_assert_eq!(a_args.len(), b_args.len());
701                for (a_arg, b_arg) in iter::zip(a_args, b_args) {
702                    self.generic_args(a_arg, b_arg);
703                }
704            }
705            (BaseTy::Array(a_ty, a_n), BaseTy::Array(b_ty, b_n)) => {
706                self.tys(a_ty, b_ty);
707                self.consts(a_n, b_n);
708            }
709            (BaseTy::Tuple(a_tys), BaseTy::Tuple(b_tys)) => {
710                debug_assert_eq!(a_tys.len(), b_tys.len());
711                for (a_ty, b_ty) in iter::zip(a_tys, b_tys) {
712                    self.tys(a_ty, b_ty);
713                }
714            }
715            (BaseTy::Ref(a_re, a_ty, _), BaseTy::Ref(b_re, b_ty, _)) => {
716                self.regions(*a_re, *b_re);
717                self.tys(a_ty, b_ty);
718            }
719            (BaseTy::Slice(a_ty), BaseTy::Slice(b_ty)) => {
720                self.tys(a_ty, b_ty);
721            }
722            _ => {}
723        }
724    }
725
726    fn regions(&mut self, a: Region, b: Region) {
727        if let Region::ReEarlyParam(ebr) = a {
728            self.insert_generic_arg(ebr.index, GenericArg::Lifetime(b));
729        }
730    }
731
732    fn consts(&mut self, a: &Const, b: &Const) {
733        if let ConstKind::Param(param_const) = a.kind {
734            self.insert_generic_arg(param_const.index, GenericArg::Const(b.clone()));
735        }
736    }
737
738    fn insert_generic_arg(&mut self, idx: u32, arg: GenericArg) {
739        if let Some(old) = &self.args[idx as usize]
740            && old != &arg
741        {
742            tracked_span_bug!("ambiguous substitution: old=`{old:?}`, new: `{arg:?}`");
743        }
744        self.args[idx as usize].replace(arg);
745    }
746}
747
748/// Normalize an [`rty::AliasTy`] by converting it to rustc, normalizing it using rustc api, and
749/// then mapping the result back to `rty`. This will lose refinements and it should only be used
750/// to normalize sorts because they should only contain unrefined types. However, we are also using
751/// it as a hack to normalize types in cases where we fail to collect a candidate, this is unsound
752/// and should be removed.
753///
754/// [`rty::AliasTy`]: AliasTy
755fn normalize_projection_ty_with_rustc<'tcx>(
756    genv: GlobalEnv<'_, 'tcx>,
757    def_id: DefId,
758    infcx: &rustc_infer::infer::InferCtxt<'tcx>,
759    obligation: &AliasTy,
760) -> QueryResult<(bool, SubsetTyCtor)> {
761    let tcx = genv.tcx();
762    let projection_ty = obligation.to_rustc(tcx);
763    let cause = ObligationCause::dummy();
764    let param_env = tcx.param_env(def_id);
765
766    let pre_ty = projection_ty.to_ty(tcx);
767    let at = infcx.at(&cause, param_env);
768    let ty = deeply_normalize::<rustc_middle::ty::Ty<'tcx>, FulfillmentError>(at, pre_ty)
769        .map_err(|err| query_bug!("{err:?}"))?;
770
771    let changed = pre_ty != ty;
772    let rustc_ty = ty.lower(tcx).map_err(|reason| query_bug!("{reason:?}"))?;
773
774    Ok((
775        changed,
776        Refiner::default_for_item(genv, def_id)?
777            .refine_ty_or_base(&rustc_ty)?
778            .expect_base(),
779    ))
780}