From 8a0cb6ae7d3fd7ec3e3fd9986a5f90b91d03f5a4 Mon Sep 17 00:00:00 2001
From: lcnr <rust@lcnr.de>
Date: Thu, 2 Jun 2022 12:02:30 +0200
Subject: [PATCH] `BoundVarReplacer` remove `Option`

---
 compiler/rustc_middle/src/ty/fold.rs | 81 +++++++++++++---------------
 1 file changed, 36 insertions(+), 45 deletions(-)

diff --git a/compiler/rustc_middle/src/ty/fold.rs b/compiler/rustc_middle/src/ty/fold.rs
index 5469aeb4c2c..794c0ddf97d 100644
--- a/compiler/rustc_middle/src/ty/fold.rs
+++ b/compiler/rustc_middle/src/ty/fold.rs
@@ -656,17 +656,17 @@ struct BoundVarReplacer<'a, 'tcx> {
     /// the ones we have visited.
     current_index: ty::DebruijnIndex,
 
-    fld_r: Option<&'a mut (dyn FnMut(ty::BoundRegion) -> ty::Region<'tcx> + 'a)>,
-    fld_t: Option<&'a mut (dyn FnMut(ty::BoundTy) -> Ty<'tcx> + 'a)>,
-    fld_c: Option<&'a mut (dyn FnMut(ty::BoundVar, Ty<'tcx>) -> ty::Const<'tcx> + 'a)>,
+    fld_r: &'a mut (dyn FnMut(ty::BoundRegion) -> ty::Region<'tcx> + 'a),
+    fld_t: &'a mut (dyn FnMut(ty::BoundTy) -> Ty<'tcx> + 'a),
+    fld_c: &'a mut (dyn FnMut(ty::BoundVar, Ty<'tcx>) -> ty::Const<'tcx> + 'a),
 }
 
 impl<'a, 'tcx> BoundVarReplacer<'a, 'tcx> {
     fn new(
         tcx: TyCtxt<'tcx>,
-        fld_r: Option<&'a mut (dyn FnMut(ty::BoundRegion) -> ty::Region<'tcx> + 'a)>,
-        fld_t: Option<&'a mut (dyn FnMut(ty::BoundTy) -> Ty<'tcx> + 'a)>,
-        fld_c: Option<&'a mut (dyn FnMut(ty::BoundVar, Ty<'tcx>) -> ty::Const<'tcx> + 'a)>,
+        fld_r: &'a mut (dyn FnMut(ty::BoundRegion) -> ty::Region<'tcx> + 'a),
+        fld_t: &'a mut (dyn FnMut(ty::BoundTy) -> Ty<'tcx> + 'a),
+        fld_c: &'a mut (dyn FnMut(ty::BoundVar, Ty<'tcx>) -> ty::Const<'tcx> + 'a),
     ) -> Self {
         BoundVarReplacer { tcx, current_index: ty::INNERMOST, fld_r, fld_t, fld_c }
     }
@@ -690,55 +690,42 @@ impl<'a, 'tcx> TypeFolder<'tcx> for BoundVarReplacer<'a, 'tcx> {
     fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> {
         match *t.kind() {
             ty::Bound(debruijn, bound_ty) if debruijn == self.current_index => {
-                if let Some(fld_t) = self.fld_t.as_mut() {
-                    let ty = fld_t(bound_ty);
-                    return ty::fold::shift_vars(self.tcx, ty, self.current_index.as_u32());
-                }
+                let ty = (self.fld_t)(bound_ty);
+                ty::fold::shift_vars(self.tcx, ty, self.current_index.as_u32())
             }
-            _ if t.has_vars_bound_at_or_above(self.current_index) => {
-                return t.super_fold_with(self);
-            }
-            _ => {}
+            _ if t.has_vars_bound_at_or_above(self.current_index) => t.super_fold_with(self),
+            _ => t,
         }
-        t
     }
 
     fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
         match *r {
             ty::ReLateBound(debruijn, br) if debruijn == self.current_index => {
-                if let Some(fld_r) = self.fld_r.as_mut() {
-                    let region = fld_r(br);
-                    return if let ty::ReLateBound(debruijn1, br) = *region {
-                        // If the callback returns a late-bound region,
-                        // that region should always use the INNERMOST
-                        // debruijn index. Then we adjust it to the
-                        // correct depth.
-                        assert_eq!(debruijn1, ty::INNERMOST);
-                        self.tcx.mk_region(ty::ReLateBound(debruijn, br))
-                    } else {
-                        region
-                    };
+                let region = (self.fld_r)(br);
+                if let ty::ReLateBound(debruijn1, br) = *region {
+                    // If the callback returns a late-bound region,
+                    // that region should always use the INNERMOST
+                    // debruijn index. Then we adjust it to the
+                    // correct depth.
+                    assert_eq!(debruijn1, ty::INNERMOST);
+                    self.tcx.mk_region(ty::ReLateBound(debruijn, br))
+                } else {
+                    region
                 }
             }
-            _ => {}
+            _ => r,
         }
-        r
     }
 
     fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
         match ct.val() {
             ty::ConstKind::Bound(debruijn, bound_const) if debruijn == self.current_index => {
-                if let Some(fld_c) = self.fld_c.as_mut() {
-                    let ct = fld_c(bound_const, ct.ty());
-                    return ty::fold::shift_vars(self.tcx, ct, self.current_index.as_u32());
-                }
+                let ct = (self.fld_c)(bound_const, ct.ty());
+                ty::fold::shift_vars(self.tcx, ct, self.current_index.as_u32())
             }
-            _ if ct.has_vars_bound_at_or_above(self.current_index) => {
-                return ct.super_fold_with(self);
-            }
-            _ => {}
+            _ if ct.has_vars_bound_at_or_above(self.current_index) => ct.super_fold_with(self),
+            _ => ct,
         }
-        ct
     }
 }
 
@@ -752,8 +739,10 @@ impl<'tcx> TyCtxt<'tcx> {
     /// returned at the end with each bound region and the free region
     /// that replaced it.
     ///
-    /// This method only replaces late bound regions and the result may still
-    /// contain escaping bound types.
+    /// # Panics
+    ///
+    /// This method only replaces late bound regions. Any types or
+    /// constants bound by `value` will cause an ICE.
     pub fn replace_late_bound_regions<T, F>(
         self,
         value: Binder<'tcx, T>,
@@ -766,11 +755,14 @@ impl<'tcx> TyCtxt<'tcx> {
         let mut region_map = BTreeMap::new();
         let mut real_fld_r =
             |br: ty::BoundRegion| *region_map.entry(br).or_insert_with(|| fld_r(br));
+        let mut fld_t = |b| bug!("unexpected bound ty in binder: {b:?}");
+        let mut fld_c = |b, ty| bug!("unexpected bound ct in binder: {b:?} {ty}");
+
         let value = value.skip_binder();
         let value = if !value.has_escaping_bound_vars() {
             value
         } else {
-            let mut replacer = BoundVarReplacer::new(self, Some(&mut real_fld_r), None, None);
+            let mut replacer = BoundVarReplacer::new(self, &mut real_fld_r, &mut fld_t, &mut fld_c);
             value.fold_with(&mut replacer)
         };
         (value, region_map)
@@ -795,15 +787,14 @@ impl<'tcx> TyCtxt<'tcx> {
         if !value.has_escaping_bound_vars() {
             value
         } else {
-            let mut replacer =
-                BoundVarReplacer::new(self, Some(&mut fld_r), Some(&mut fld_t), Some(&mut fld_c));
+            let mut replacer = BoundVarReplacer::new(self, &mut fld_r, &mut fld_t, &mut fld_c);
             value.fold_with(&mut replacer)
         }
     }
 
     /// Replaces all types or regions bound by the given `Binder`. The `fld_r`
-    /// closure replaces bound regions while the `fld_t` closure replaces bound
-    /// types.
+    /// closure replaces bound regions, the `fld_t` closure replaces bound
+    /// types, and `fld_c` replaces bound constants.
     pub fn replace_bound_vars<T, F, G, H>(
         self,
         value: Binder<'tcx, T>,