flux_infer/
projections.rs

1use std::iter;
2
3use flux_common::{bug, iter::IterExt, tracked_span_bug};
4use flux_middle::{
5    global_env::GlobalEnv,
6    queries::{QueryErr, QueryResult},
7    query_bug,
8    rty::{
9        self, AliasKind, AliasReft, AliasTy, BaseTy, Binder, Clause, ClauseKind, Const, ConstKind,
10        EarlyBinder, Expr, ExprKind, GenericArg, List, ProjectionPredicate, RefineArgs, Region,
11        Sort, SubsetTy, SubsetTyCtor, Ty, TyKind, TyOrBase,
12        fold::{FallibleTypeFolder, TypeFoldable, TypeSuperFoldable, TypeVisitable},
13        refining::Refiner,
14        subst::{GenericsSubstDelegate, GenericsSubstFolder},
15    },
16};
17use flux_rustc_bridge::{ToRustc, lowering::Lower};
18use itertools::izip;
19use rustc_hir::def_id::DefId;
20use rustc_infer::traits::{BuiltinImplSource, Obligation};
21use rustc_middle::{
22    traits::{ImplSource, ObligationCause},
23    ty::{TyCtxt, Variance},
24};
25use rustc_trait_selection::{
26    solve::deeply_normalize,
27    traits::{FulfillmentError, SelectionContext},
28};
29use rustc_type_ir::TypeVisitableExt;
30
31use crate::{
32    fixpoint_encoding::KVarEncoding,
33    infer::{InferCtxtAt, InferResult},
34    refine_tree::Scope,
35};
36
37pub trait NormalizeExt: TypeFoldable {
38    fn deeply_normalize(&self, infcx: &mut InferCtxtAt) -> QueryResult<Self>;
39
40    /// Deeply normalize projections but only inside sorts
41    fn deeply_normalize_sorts<'tcx>(
42        &self,
43        def_id: DefId,
44        genv: GlobalEnv<'_, 'tcx>,
45        infcx: &rustc_infer::infer::InferCtxt<'tcx>,
46    ) -> QueryResult<Self>;
47}
48
49impl<T: TypeFoldable> NormalizeExt for T {
50    fn deeply_normalize(&self, infcx: &mut InferCtxtAt) -> QueryResult<Self> {
51        let span = infcx.span;
52        let infcx_orig = &mut infcx.infcx;
53        let mut infcx = infcx_orig.branch();
54        let infcx = infcx.at(span);
55        let mut normalizer = Normalizer::new(infcx)?;
56        self.erase_regions().try_fold_with(&mut normalizer)
57    }
58
59    fn deeply_normalize_sorts<'tcx>(
60        &self,
61        def_id: DefId,
62        genv: GlobalEnv<'_, 'tcx>,
63        infcx: &rustc_infer::infer::InferCtxt<'tcx>,
64    ) -> QueryResult<Self> {
65        let mut normalizer = SortNormalizer::new(def_id, genv, infcx);
66        self.erase_regions().try_fold_with(&mut normalizer)
67    }
68}
69
70struct Normalizer<'a, 'infcx, 'genv, 'tcx> {
71    infcx: InferCtxtAt<'a, 'infcx, 'genv, 'tcx>,
72    selcx: SelectionContext<'infcx, 'tcx>,
73    param_env: List<Clause>,
74    scope: Scope,
75}
76
77impl<'a, 'infcx, 'genv, 'tcx> Normalizer<'a, 'infcx, 'genv, 'tcx> {
78    fn new(infcx: InferCtxtAt<'a, 'infcx, 'genv, 'tcx>) -> QueryResult<Self> {
79        let predicates = infcx.genv.predicates_of(infcx.def_id)?;
80        let param_env = predicates.instantiate_identity().predicates.clone();
81        let selcx = SelectionContext::new(infcx.region_infcx);
82        let scope = infcx.cursor().marker().scope().unwrap();
83        Ok(Normalizer { infcx, selcx, param_env, scope })
84    }
85
86    fn normalize_projection_ty(
87        &mut self,
88        obligation: &AliasTy,
89    ) -> QueryResult<(bool, SubsetTyCtor)> {
90        // First we must recursively (i.e., deeply) normalize projection types before proceeding.
91        // For example, in `issue-1449.rs` when normalizing `<<MyChoice as Choice>::Session as FromState>::Role`
92        // we first recursively normalize to get `<End<B> as FromState>::Role`
93        let obligation = &obligation.try_fold_with(self)?;
94
95        let mut candidates = vec![];
96        self.assemble_candidates_from_param_env(obligation, &mut candidates);
97        self.assemble_candidates_from_trait_def(obligation, &mut candidates)
98            .unwrap_or_else(|err| tracked_span_bug!("{err:?}"));
99        self.assemble_candidates_from_impls(obligation, &mut candidates)?;
100        if candidates.is_empty() {
101            // TODO: This is a temporary hack that uses rustc's trait selection when FLUX fails;
102            //       The correct thing, e.g for `trait09.rs` is to make sure FLUX's param_env mirrors RUSTC,
103            //       by suitably chasing down the super-trait predicates,
104            //       see https://github.com/flux-rs/flux/issues/737
105            let (changed, ty_ctor) = normalize_projection_ty_with_rustc(
106                self.genv(),
107                self.def_id(),
108                self.infcx.region_infcx,
109                obligation,
110            )?;
111            return Ok((changed, ty_ctor));
112        }
113        if candidates.len() > 1 {
114            bug!("ambiguity when resolving `{obligation:?}` in {:?}", self.def_id());
115        }
116        let ctor = self.confirm_candidate(candidates.pop().unwrap(), obligation)?;
117        Ok((true, ctor))
118    }
119
120    fn find_resolved_predicates(
121        &self,
122        subst: &mut TVarSubst,
123        preds: Vec<EarlyBinder<ProjectionPredicate>>,
124    ) -> (Vec<ProjectionPredicate>, Vec<EarlyBinder<ProjectionPredicate>>) {
125        let mut resolved = vec![];
126        let mut unresolved = vec![];
127        for pred in preds {
128            let term = pred.clone().skip_binder().term;
129            let alias_ty = pred.clone().map(|p| p.projection_ty);
130            match subst.instantiate_partial(alias_ty) {
131                Some(projection_ty) => {
132                    let pred = ProjectionPredicate { projection_ty, term };
133                    resolved.push(pred);
134                }
135                None => unresolved.push(pred.clone()),
136            }
137        }
138        (resolved, unresolved)
139    }
140
141    // See issue-829*.rs for an example of what this function is for.
142    fn resolve_projection_predicates(
143        &mut self,
144        subst: &mut TVarSubst,
145        impl_def_id: DefId,
146    ) -> QueryResult {
147        let mut projection_preds: Vec<_> = self
148            .genv()
149            .predicates_of(impl_def_id)?
150            .skip_binder()
151            .predicates
152            .iter()
153            .filter_map(|pred| {
154                if let ClauseKind::Projection(pred) = pred.kind_skipping_binder() {
155                    Some(EarlyBinder(pred.clone()))
156                } else {
157                    None
158                }
159            })
160            .collect();
161
162        while !projection_preds.is_empty() {
163            let (resolved, unresolved) = self.find_resolved_predicates(subst, projection_preds);
164
165            if resolved.is_empty() {
166                break; // failed: there is some unresolved projection pred!
167            }
168            for p in resolved {
169                let obligation = &p.projection_ty;
170                let (_, ctor) = self.normalize_projection_ty(obligation)?;
171                subst.subset_tys(&p.term, &ctor);
172            }
173            projection_preds = unresolved;
174        }
175        Ok(())
176    }
177
178    fn confirm_candidate(
179        &mut self,
180        candidate: Candidate,
181        obligation: &AliasTy,
182    ) -> QueryResult<SubsetTyCtor> {
183        let tcx = self.tcx();
184        match candidate {
185            Candidate::ParamEnv(pred) | Candidate::TraitDef(pred) => {
186                let rustc_obligation = obligation.to_rustc(tcx);
187                let parent_id = rustc_obligation.trait_ref(tcx).def_id;
188                // Do fn-subtyping if the candidate was a fn-trait
189                if tcx.is_fn_trait(parent_id) {
190                    let res = self
191                        .fn_subtype_projection_ty(pred, obligation)
192                        .unwrap_or_else(|err| tracked_span_bug!("{err:?}"));
193                    Ok(res)
194                } else {
195                    Ok(pred.skip_binder().term)
196                }
197            }
198            Candidate::UserDefinedImpl(impl_def_id) => {
199                // Given a projection obligation
200                //     <IntoIter<{v. i32[v] | v > 0}, Global> as Iterator>::Item
201                // and the id of a rust impl block
202                //     impl<T, A: Allocator> Iterator for IntoIter<T, A>
203
204                // 1. MATCH the self type of the rust impl block and the flux self type of the obligation
205                //    to infer a substitution
206                //        IntoIter<{v. i32[v] | v > 0}, Global> MATCH IntoIter<T, A>
207                //            => {T -> {v. i32[v] | v > 0}, A -> Global}
208
209                let impl_trait_ref = self.genv().impl_trait_ref(impl_def_id)?.skip_binder();
210
211                let generics = self.tcx().generics_of(impl_def_id);
212
213                let mut subst = TVarSubst::new(generics);
214                for (a, b) in iter::zip(&impl_trait_ref.args, &obligation.args) {
215                    subst.generic_args(a, b);
216                }
217
218                // 2. Gather the ProjectionPredicates and solve them see issue-808.rs
219                self.resolve_projection_predicates(&mut subst, impl_def_id)?;
220
221                let args = subst.finish(self.tcx(), generics)?;
222
223                // 3. Get the associated type in the impl block and apply the substitution to it
224                let assoc_type_id = tcx
225                    .associated_items(impl_def_id)
226                    .in_definition_order()
227                    .find(|item| item.trait_item_def_id() == Some(obligation.def_id))
228                    .map(|item| item.def_id)
229                    .ok_or_else(|| {
230                        query_bug!("no associated type for {obligation:?} in impl {impl_def_id:?}")
231                    })?;
232                Ok(self
233                    .genv()
234                    .type_of(assoc_type_id)?
235                    .instantiate(tcx, &args, &[])
236                    .expect_subset_ty_ctor())
237            }
238        }
239    }
240
241    fn fn_subtype_projection_ty(
242        &mut self,
243        actual: Binder<ProjectionPredicate>,
244        oblig: &AliasTy,
245    ) -> InferResult<SubsetTyCtor> {
246        // Step 1: bs <- unpack(b1...)
247        let obligs: Vec<_> = oblig
248            .args
249            .iter()
250            .map(|arg| {
251                match arg {
252                    GenericArg::Ty(ty) => GenericArg::Ty(self.infcx.unpack(ty)),
253                    GenericArg::Base(ctor) => GenericArg::Ty(self.infcx.unpack(&ctor.to_ty())),
254                    _ => arg.clone(),
255                }
256            })
257            .collect();
258
259        let span = self.infcx.span;
260        let mut infcx = self.infcx.at(span);
261
262        let actual = infcx.ensure_resolved_evars(|infcx| {
263            // Step 2: as <- fresh(a1...)
264            let actual = actual
265                .replace_bound_vars(
266                    |_| rty::ReErased,
267                    |sort, mode, _| infcx.fresh_infer_var(sort, mode),
268                )
269                .deeply_normalize(infcx)?;
270
271            let actuals = actual.projection_ty.args.iter().map(|arg| {
272                match arg {
273                    GenericArg::Base(ctor) => GenericArg::Ty(ctor.to_ty()),
274                    _ => arg.clone(),
275                }
276            });
277
278            // Step 3: bs <: as
279            for (a, b) in izip!(actuals.skip(1), obligs.iter().skip(1)) {
280                infcx.subtyping_generic_args(
281                    Variance::Contravariant,
282                    &a,
283                    b,
284                    crate::infer::ConstrReason::Predicate,
285                )?;
286            }
287            Ok(actual)
288        })?;
289        // Step 4: check all evars are solved, plug back into ProjectionPredicate
290        let actual = infcx.fully_resolve_evars(&actual);
291
292        // Step 5: generate "fresh" type for actual.term,
293        let oblig_term = actual.term.with_holes().replace_holes(|binders, kind| {
294            assert!(kind == rty::HoleKind::Pred);
295            let scope = &self.scope;
296            infcx.fresh_kvar_in_scope(binders, scope, KVarEncoding::Conj)
297        });
298
299        // Step 6: subtyping obligation on output
300        infcx.subtyping(
301            &actual.term.to_ty(),
302            &oblig_term.to_ty(),
303            crate::infer::ConstrReason::Predicate,
304        )?;
305        // Ok(ProjectionPredicate { projection_ty: actual.projection_ty, term: oblig_term })
306        Ok(oblig_term)
307    }
308
309    fn assemble_candidates_from_predicates(
310        &mut self,
311        predicates: &List<Clause>,
312        obligation: &AliasTy,
313        ctor: fn(Binder<ProjectionPredicate>) -> Candidate,
314        candidates: &mut Vec<Candidate>,
315    ) {
316        let tcx = self.tcx();
317        let rustc_obligation = obligation.to_rustc(tcx);
318
319        for predicate in predicates {
320            if let Some(pred) = predicate.as_projection_clause()
321                && pred.skip_binder_ref().projection_ty.to_rustc(tcx) == rustc_obligation
322            {
323                candidates.push(ctor(pred));
324            }
325        }
326    }
327
328    fn assemble_candidates_from_param_env(
329        &mut self,
330        obligation: &AliasTy,
331        candidates: &mut Vec<Candidate>,
332    ) {
333        let predicates = self.param_env.clone();
334        self.assemble_candidates_from_predicates(
335            &predicates,
336            obligation,
337            Candidate::ParamEnv,
338            candidates,
339        );
340    }
341
342    fn assemble_candidates_from_trait_def(
343        &mut self,
344        obligation: &AliasTy,
345        candidates: &mut Vec<Candidate>,
346    ) -> InferResult {
347        if let GenericArg::Base(ctor) = &obligation.args[0]
348            && let BaseTy::Alias(AliasKind::Opaque, alias_ty) = ctor.as_bty_skipping_binder()
349        {
350            debug_assert!(!alias_ty.has_escaping_bvars());
351            let bounds = self.genv().item_bounds(alias_ty.def_id)?.instantiate(
352                self.tcx(),
353                &alias_ty.args,
354                &alias_ty.refine_args,
355            );
356            self.assemble_candidates_from_predicates(
357                &bounds,
358                obligation,
359                Candidate::TraitDef,
360                candidates,
361            );
362        }
363        Ok(())
364    }
365
366    fn assemble_candidates_from_impls(
367        &mut self,
368        obligation: &AliasTy,
369        candidates: &mut Vec<Candidate>,
370    ) -> QueryResult {
371        let trait_ref = obligation.to_rustc(self.tcx()).trait_ref(self.tcx());
372        let trait_ref = self.tcx().erase_and_anonymize_regions(trait_ref);
373        let trait_pred = Obligation::new(
374            self.tcx(),
375            ObligationCause::dummy(),
376            self.rustc_param_env(),
377            trait_ref,
378        );
379        // FIXME(nilehmann) This is a patch to not panic inside rustc so we are
380        // able to catch the bug
381        if trait_pred.has_escaping_bound_vars() {
382            tracked_span_bug!();
383        }
384        match self.selcx.select(&trait_pred) {
385            Ok(Some(ImplSource::UserDefined(impl_data))) => {
386                candidates.push(Candidate::UserDefinedImpl(impl_data.impl_def_id));
387            }
388            Ok(_) => (),
389            Err(e) => bug!("error selecting {trait_pred:?}: {e:?}"),
390        }
391        Ok(())
392    }
393
394    fn def_id(&self) -> DefId {
395        self.infcx.def_id
396    }
397
398    fn genv(&self) -> GlobalEnv<'genv, 'tcx> {
399        self.infcx.genv
400    }
401
402    fn tcx(&self) -> TyCtxt<'tcx> {
403        self.selcx.tcx()
404    }
405
406    fn rustc_param_env(&self) -> rustc_middle::ty::ParamEnv<'tcx> {
407        self.selcx.tcx().param_env(self.def_id())
408    }
409}
410
411impl FallibleTypeFolder for Normalizer<'_, '_, '_, '_> {
412    type Error = QueryErr;
413
414    fn try_fold_sort(&mut self, sort: &Sort) -> Result<Sort, Self::Error> {
415        match sort {
416            Sort::Alias(AliasKind::Free, alias_ty) => {
417                self.genv()
418                    .normalize_free_alias_sort(alias_ty)?
419                    .try_fold_with(self)
420            }
421            Sort::Alias(AliasKind::Projection, alias_ty) => {
422                let (changed, ctor) = self.normalize_projection_ty(alias_ty)?;
423                let sort = ctor.sort();
424                if changed { sort.try_fold_with(self) } else { Ok(sort) }
425            }
426            _ => sort.try_super_fold_with(self),
427        }
428    }
429
430    // As shown in https://github.com/flux-rs/flux/issues/711 one round of `normalize_projections`
431    // can replace one projection e.g. `<Rev<Iter<[i32]> as Iterator>::Item` with another e.g.
432    // `<Iter<[i32]> as Iterator>::Item` We want to compute a "fixpoint" i.e. keep going until no
433    // change, so that e.g. the above is normalized all the way to `i32`, which is what the `changed`
434    // is for.
435    fn try_fold_ty(&mut self, ty: &Ty) -> Result<Ty, Self::Error> {
436        match ty.kind() {
437            TyKind::Indexed(BaseTy::Alias(AliasKind::Free, alias_ty), idx) => {
438                Ok(self
439                    .genv()
440                    .type_of(alias_ty.def_id)?
441                    .instantiate(self.tcx(), &alias_ty.args, &alias_ty.refine_args)
442                    .expect_ctor()
443                    .replace_bound_reft(idx))
444            }
445            TyKind::Indexed(BaseTy::Alias(AliasKind::Projection, alias_ty), idx) => {
446                let (changed, ctor) = self.normalize_projection_ty(alias_ty)?;
447                let ty = ctor.replace_bound_reft(idx).to_ty();
448                if changed { ty.try_fold_with(self) } else { Ok(ty) }
449            }
450            _ => ty.try_super_fold_with(self),
451        }
452    }
453
454    fn try_fold_subset_ty(&mut self, sty: &SubsetTy) -> Result<SubsetTy, Self::Error> {
455        match &sty.bty {
456            BaseTy::Alias(AliasKind::Free, _alias_ty) => {
457                // Weak aliases are always expanded during conversion. We could in theory normalize
458                // them here but we don't guaranatee that type aliases expand to a subset ty. If we
459                // ever stop expanding aliases during conv we would need to guarantee that aliases
460                // used as a generic base expand to a subset type.
461                tracked_span_bug!()
462            }
463            BaseTy::Alias(AliasKind::Projection, alias_ty) => {
464                let (changed, ctor) = self.normalize_projection_ty(alias_ty)?;
465                let ty = ctor.replace_bound_reft(&sty.idx).strengthen(&sty.pred);
466                if changed { ty.try_fold_with(self) } else { Ok(ty) }
467            }
468            _ => sty.try_super_fold_with(self),
469        }
470    }
471
472    fn try_fold_expr(&mut self, expr: &Expr) -> Result<Expr, Self::Error> {
473        if let ExprKind::Alias(alias_pred, refine_args) = expr.kind() {
474            let (changed, e) = normalize_alias_reft(
475                self.genv(),
476                self.def_id(),
477                self.selcx.infcx,
478                alias_pred,
479                refine_args,
480            )?;
481            if changed { e.try_fold_with(self) } else { Ok(e) }
482        } else {
483            expr.try_super_fold_with(self)
484        }
485    }
486
487    fn try_fold_const(&mut self, c: &Const) -> Result<Const, Self::Error> {
488        let c = c.to_rustc(self.tcx());
489        rustc_trait_selection::traits::evaluate_const(self.selcx.infcx, c, self.rustc_param_env())
490            .lower(self.tcx())
491            .map_err(|e| QueryErr::unsupported(self.def_id(), e.into_err()))
492    }
493}
494
495#[derive(Debug)]
496pub enum Candidate {
497    UserDefinedImpl(DefId),
498    ParamEnv(Binder<ProjectionPredicate>),
499    TraitDef(Binder<ProjectionPredicate>),
500}
501
502#[derive(Debug)]
503struct TVarSubst {
504    args: Vec<Option<GenericArg>>,
505}
506
507impl GenericsSubstDelegate for &TVarSubst {
508    type Error = ();
509
510    fn ty_for_param(&mut self, param_ty: rustc_middle::ty::ParamTy) -> Result<Ty, Self::Error> {
511        match self.args.get(param_ty.index as usize) {
512            Some(Some(GenericArg::Ty(ty))) => Ok(ty.clone()),
513            Some(None) => Err(()),
514            arg => tracked_span_bug!("expected type for generic parameter, found `{arg:?}`"),
515        }
516    }
517
518    fn sort_for_param(&mut self, param_ty: rustc_middle::ty::ParamTy) -> Result<Sort, Self::Error> {
519        match self.args.get(param_ty.index as usize) {
520            Some(Some(GenericArg::Base(ctor))) => Ok(ctor.sort()),
521            Some(None) => Err(()),
522            arg => tracked_span_bug!("expected type for generic parameter, found `{arg:?}`"),
523        }
524    }
525
526    fn ctor_for_param(
527        &mut self,
528        param_ty: rustc_middle::ty::ParamTy,
529    ) -> Result<SubsetTyCtor, Self::Error> {
530        match self.args.get(param_ty.index as usize) {
531            Some(Some(GenericArg::Base(ctor))) => Ok(ctor.clone()),
532            Some(None) => Err(()),
533            arg => tracked_span_bug!("expected type for generic parameter, found `{arg:?}`"),
534        }
535    }
536
537    fn region_for_param(
538        &mut self,
539        ebr: rustc_middle::ty::EarlyParamRegion,
540    ) -> Result<Region, Self::Error> {
541        match self.args.get(ebr.index as usize) {
542            Some(Some(GenericArg::Lifetime(region))) => Ok(*region),
543            Some(None) => Err(()),
544            arg => tracked_span_bug!("expected region for generic parameter, found `{arg:?}`"),
545        }
546    }
547
548    fn expr_for_param_const(&self, _param_const: rustc_middle::ty::ParamConst) -> Expr {
549        tracked_span_bug!()
550    }
551
552    fn const_for_param(&mut self, _param: &Const) -> Const {
553        tracked_span_bug!()
554    }
555}
556
557struct SortNormalizer<'infcx, 'genv, 'tcx> {
558    def_id: DefId,
559    infcx: &'infcx rustc_infer::infer::InferCtxt<'tcx>,
560    genv: GlobalEnv<'genv, 'tcx>,
561}
562
563impl<'infcx, 'genv, 'tcx> SortNormalizer<'infcx, 'genv, 'tcx> {
564    fn new(
565        def_id: DefId,
566        genv: GlobalEnv<'genv, 'tcx>,
567        infcx: &'infcx rustc_infer::infer::InferCtxt<'tcx>,
568    ) -> Self {
569        Self { def_id, infcx, genv }
570    }
571}
572
573impl FallibleTypeFolder for SortNormalizer<'_, '_, '_> {
574    type Error = QueryErr;
575    fn try_fold_sort(&mut self, sort: &Sort) -> Result<Sort, Self::Error> {
576        match sort {
577            Sort::Alias(AliasKind::Free, alias_ty) => {
578                self.genv
579                    .normalize_free_alias_sort(alias_ty)?
580                    .try_fold_with(self)
581            }
582            Sort::Alias(AliasKind::Projection, alias_ty) => {
583                let (changed, ctor) = normalize_projection_ty_with_rustc(
584                    self.genv,
585                    self.def_id,
586                    self.infcx,
587                    alias_ty,
588                )?;
589                let sort = ctor.sort();
590                if changed { sort.try_fold_with(self) } else { Ok(sort) }
591            }
592            _ => sort.try_super_fold_with(self),
593        }
594    }
595}
596
597impl TVarSubst {
598    fn new(generics: &rustc_middle::ty::Generics) -> Self {
599        Self { args: vec![None; generics.count()] }
600    }
601
602    fn instantiate_partial<T: TypeFoldable>(&mut self, pred: EarlyBinder<T>) -> Option<T> {
603        let mut folder = GenericsSubstFolder::new(&*self, &[]);
604        pred.skip_binder().try_fold_with(&mut folder).ok()
605    }
606
607    fn finish<'tcx>(
608        self,
609        tcx: TyCtxt<'tcx>,
610        generics: &'tcx rustc_middle::ty::Generics,
611    ) -> QueryResult<Vec<GenericArg>> {
612        self.args
613            .into_iter()
614            .enumerate()
615            .map(|(idx, arg)| {
616                if let Some(arg) = arg {
617                    Ok(arg)
618                } else {
619                    let param = generics.param_at(idx, tcx);
620                    Err(QueryErr::bug(
621                        None,
622                        format!("cannot infer substitution for {param:?} at index {idx}"),
623                    ))
624                }
625            })
626            .try_collect_vec()
627    }
628
629    fn generic_args(&mut self, a: &GenericArg, b: &GenericArg) {
630        match (a, b) {
631            (GenericArg::Ty(a), GenericArg::Ty(b)) => self.tys(a, b),
632            (GenericArg::Lifetime(a), GenericArg::Lifetime(b)) => self.regions(*a, *b),
633            (GenericArg::Base(a), GenericArg::Base(b)) => {
634                self.subset_tys(a, b);
635            }
636            (GenericArg::Const(a), GenericArg::Const(b)) => self.consts(a, b),
637            _ => {}
638        }
639    }
640
641    fn tys(&mut self, a: &Ty, b: &Ty) {
642        if let TyKind::Param(param_ty) = a.kind() {
643            if !b.has_escaping_bvars() {
644                self.insert_generic_arg(param_ty.index, GenericArg::Ty(b.clone()));
645            }
646        } else {
647            let a = a.shallow_canonicalize().as_ty_or_base();
648            let b = b.shallow_canonicalize().as_ty_or_base();
649            if let (TyOrBase::Base(a_ctor), TyOrBase::Base(b_ctor)) = (a, b) {
650                self.subset_tys(&a_ctor, &b_ctor);
651            }
652        }
653    }
654
655    fn subset_tys(&mut self, a: &SubsetTyCtor, b: &SubsetTyCtor) {
656        let bty_a = a.as_bty_skipping_binder();
657        let bty_b = b.as_bty_skipping_binder();
658        if let BaseTy::Param(param_ty) = bty_a {
659            if !b.has_escaping_bvars() {
660                self.insert_generic_arg(param_ty.index, GenericArg::Base(b.clone()));
661            }
662        } else {
663            self.btys(bty_a, bty_b);
664        }
665    }
666
667    fn btys(&mut self, a: &BaseTy, b: &BaseTy) {
668        match (a, b) {
669            (BaseTy::Param(param_ty), _) => {
670                if !b.has_escaping_bvars() {
671                    let sort = b.sort();
672                    let ctor =
673                        Binder::bind_with_sort(SubsetTy::trivial(b.clone(), Expr::nu()), sort);
674                    self.insert_generic_arg(param_ty.index, GenericArg::Base(ctor));
675                }
676            }
677            (BaseTy::Adt(_, a_args), BaseTy::Adt(_, b_args)) => {
678                debug_assert_eq!(a_args.len(), b_args.len());
679                for (a_arg, b_arg) in iter::zip(a_args, b_args) {
680                    self.generic_args(a_arg, b_arg);
681                }
682            }
683            (BaseTy::Array(a_ty, a_n), BaseTy::Array(b_ty, b_n)) => {
684                self.tys(a_ty, b_ty);
685                self.consts(a_n, b_n);
686            }
687            (BaseTy::Tuple(a_tys), BaseTy::Tuple(b_tys)) => {
688                debug_assert_eq!(a_tys.len(), b_tys.len());
689                for (a_ty, b_ty) in iter::zip(a_tys, b_tys) {
690                    self.tys(a_ty, b_ty);
691                }
692            }
693            (BaseTy::Ref(a_re, a_ty, _), BaseTy::Ref(b_re, b_ty, _)) => {
694                self.regions(*a_re, *b_re);
695                self.tys(a_ty, b_ty);
696            }
697            (BaseTy::Slice(a_ty), BaseTy::Slice(b_ty)) => {
698                self.tys(a_ty, b_ty);
699            }
700            _ => {}
701        }
702    }
703
704    fn regions(&mut self, a: Region, b: Region) {
705        if let Region::ReEarlyParam(ebr) = a {
706            self.insert_generic_arg(ebr.index, GenericArg::Lifetime(b));
707        }
708    }
709
710    fn consts(&mut self, a: &Const, b: &Const) {
711        if let ConstKind::Param(param_const) = a.kind {
712            self.insert_generic_arg(param_const.index, GenericArg::Const(b.clone()));
713        }
714    }
715
716    fn insert_generic_arg(&mut self, idx: u32, arg: GenericArg) {
717        if let Some(old) = &self.args[idx as usize]
718            && old != &arg
719        {
720            tracked_span_bug!("ambiguous substitution: old=`{old:?}`, new: `{arg:?}`");
721        }
722        self.args[idx as usize].replace(arg);
723    }
724}
725
726/// Normalize an [`rty::AliasTy`] by converting it to rustc, normalizing it using rustc api, and
727/// then mapping the result back to `rty`. This will lose refinements and it should only be used
728/// to normalize sorts because they should only contain unrefined types. However, we are also using
729/// it as a hack to normalize types in cases where we fail to collect a candidate, this is unsound
730/// and should be removed.
731///
732/// [`rty::AliasTy`]: AliasTy
733fn normalize_projection_ty_with_rustc<'tcx>(
734    genv: GlobalEnv<'_, 'tcx>,
735    def_id: DefId,
736    infcx: &rustc_infer::infer::InferCtxt<'tcx>,
737    obligation: &AliasTy,
738) -> QueryResult<(bool, SubsetTyCtor)> {
739    let tcx = genv.tcx();
740    let projection_ty = obligation.to_rustc(tcx);
741    let projection_ty = tcx.erase_and_anonymize_regions(projection_ty);
742    let cause = ObligationCause::dummy();
743    let param_env = tcx.param_env(def_id);
744
745    let pre_ty = projection_ty.to_ty(tcx);
746    let at = infcx.at(&cause, param_env);
747    let ty = deeply_normalize::<rustc_middle::ty::Ty<'tcx>, FulfillmentError>(at, pre_ty)
748        .map_err(|err| query_bug!("{err:?}"))?;
749
750    let changed = pre_ty != ty;
751    let rustc_ty = ty.lower(tcx).map_err(|reason| query_bug!("{reason:?}"))?;
752
753    Ok((
754        changed,
755        Refiner::default_for_item(genv, def_id)?
756            .refine_ty_or_base(&rustc_ty)?
757            .expect_base(),
758    ))
759}
760
761/// Do one step of normalization, unfolding associated refinements if they are concrete.
762///
763/// Use this if you are about to match structurally on an [`ExprKind`] and you need associated
764/// refinements to be normalized.
765pub fn structurally_normalize_expr<'tcx>(
766    genv: GlobalEnv<'_, 'tcx>,
767    def_id: DefId,
768    infcx: &rustc_infer::infer::InferCtxt<'tcx>,
769    expr: &Expr,
770) -> QueryResult<Expr> {
771    if let ExprKind::Alias(alias_pred, refine_args) = expr.kind() {
772        let (_, e) = normalize_alias_reft(genv, def_id, infcx, alias_pred, refine_args)?;
773        Ok(e)
774    } else {
775        Ok(expr.clone())
776    }
777}
778
779/// Normalizes an [`AliasReft`]. This uses the trait solver to find the [`ImplSourceUserDefinedData`]
780/// and uses the `args` there, which we map back to Flux via refining. This loses refinements,
781/// but that's fine because [`AliasReft`] should not rely on refinements for trait solving.
782fn normalize_alias_reft<'tcx>(
783    genv: GlobalEnv<'_, 'tcx>,
784    def_id: DefId,
785    infcx: &rustc_infer::infer::InferCtxt<'tcx>,
786    alias_reft: &AliasReft,
787    refine_args: &RefineArgs,
788) -> QueryResult<(bool, Expr)> {
789    let tcx = genv.tcx();
790
791    let is_final = genv.assoc_refinement(alias_reft.assoc_id)?.final_;
792    if is_final {
793        let e = genv
794            .default_assoc_refinement_body(alias_reft.assoc_id)?
795            .unwrap_or_else(|| {
796                bug!("final associated refinement without body - should be caught in desugar")
797            })
798            .instantiate(genv.tcx(), &alias_reft.args, &[])
799            .apply(refine_args);
800        return Ok((true, e));
801    }
802
803    // Get impl source
804    let mut selcx = SelectionContext::new(infcx);
805    let param_env = tcx.param_env(def_id);
806    let trait_ref = alias_reft.to_rustc_trait_ref(tcx);
807    let trait_ref = tcx.erase_and_anonymize_regions(trait_ref);
808    let trait_pred = Obligation::new(tcx, ObligationCause::dummy(), param_env, trait_ref);
809
810    let impl_source = selcx
811        .select(&trait_pred)
812        .map_err(|e| query_bug!("error selecting {trait_pred:?}: {e:?}"))?;
813
814    match impl_source {
815        Some(ImplSource::UserDefined(impl_data)) => {
816            let impl_def_id = impl_data.impl_def_id;
817            let args = Refiner::default_for_item(genv, def_id)?.refine_generic_args(
818                impl_def_id,
819                &impl_data
820                    .args
821                    .lower(tcx)
822                    .map_err(|reason| query_bug!("{reason:?}"))?,
823            )?;
824            let e = genv
825                .assoc_refinement_body_for_impl(alias_reft.assoc_id, impl_def_id)?
826                .instantiate(tcx, &args, &[])
827                .apply(refine_args);
828            Ok((true, e))
829        }
830        Some(ImplSource::Builtin(BuiltinImplSource::Misc | BuiltinImplSource::Trivial, _)) => {
831            let e = genv
832                .builtin_assoc_reft_body(infcx.typing_env(param_env), alias_reft)
833                .apply(refine_args);
834            Ok((true, e))
835        }
836        _ => Ok((false, Expr::alias(alias_reft.clone(), refine_args.clone()))),
837    }
838}