From 6ee51426a9714a9f989a2a9c549d52b3dfc18c49 Mon Sep 17 00:00:00 2001
From: Nadrieril <nadrieril+git@gmail.com>
Date: Tue, 31 Oct 2023 00:40:41 +0100
Subject: [PATCH] Respect `split` invariants for `Opaque`s

---
 .../src/thir/pattern/deconstruct_pat.rs       | 84 +++++++++++++------
 1 file changed, 59 insertions(+), 25 deletions(-)

diff --git a/compiler/rustc_mir_build/src/thir/pattern/deconstruct_pat.rs b/compiler/rustc_mir_build/src/thir/pattern/deconstruct_pat.rs
index 479f6c0a3ca..cf9a0af33ee 100644
--- a/compiler/rustc_mir_build/src/thir/pattern/deconstruct_pat.rs
+++ b/compiler/rustc_mir_build/src/thir/pattern/deconstruct_pat.rs
@@ -41,6 +41,13 @@
 //! or-patterns; instead we just try the alternatives one-by-one. For details on splitting
 //! wildcards, see [`Constructor::split`]; for integer ranges, see
 //! [`IntRange::split`]; for slices, see [`Slice::split`].
+//!
+//! ## Opaque patterns
+//!
+//! Some patterns, such as TODO, cannot be inspected, which we handle with `Constructor::Opaque`.
+//! Since we know nothing of these patterns, we assume they never cover each other. In order to
+//! respect the invariants of [`SplitConstructorSet`], we give each `Opaque` constructor a unique id
+//! so we can recognize it.
 
 use std::cell::Cell;
 use std::cmp::{self, max, min, Ordering};
@@ -617,6 +624,18 @@ impl Slice {
     }
 }
 
+/// A globally unique id to distinguish `Opaque` patterns.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub(super) struct OpaqueId(u32);
+
+impl OpaqueId {
+    fn new() -> Self {
+        use std::sync::atomic::{AtomicU32, Ordering};
+        static OPAQUE_ID: AtomicU32 = AtomicU32::new(0);
+        OpaqueId(OPAQUE_ID.fetch_add(1, Ordering::SeqCst))
+    }
+}
+
 /// A value can be decomposed into a constructor applied to some fields. This struct represents
 /// the constructor. See also `Fields`.
 ///
@@ -642,10 +661,12 @@ pub(super) enum Constructor<'tcx> {
     Str(mir::Const<'tcx>),
     /// Array and slice patterns.
     Slice(Slice),
-    /// Constants that must not be matched structurally. They are treated as black
-    /// boxes for the purposes of exhaustiveness: we must not inspect them, and they
-    /// don't count towards making a match exhaustive.
-    Opaque,
+    /// Constants that must not be matched structurally. They are treated as black boxes for the
+    /// purposes of exhaustiveness: we must not inspect them, and they don't count towards making a
+    /// match exhaustive.
+    /// Carries an id that must be unique within a match. We need this to ensure the invariants of
+    /// [`SplitConstructorSet`].
+    Opaque(OpaqueId),
     /// Or-pattern.
     Or,
     /// Wildcard pattern.
@@ -663,6 +684,9 @@ pub(super) enum Constructor<'tcx> {
 }
 
 impl<'tcx> Constructor<'tcx> {
+    pub(super) fn is_wildcard(&self) -> bool {
+        matches!(self, Wildcard)
+    }
     pub(super) fn is_non_exhaustive(&self) -> bool {
         matches!(self, NonExhaustive)
     }
@@ -728,7 +752,7 @@ impl<'tcx> Constructor<'tcx> {
             | F32Range(..)
             | F64Range(..)
             | Str(..)
-            | Opaque
+            | Opaque(..)
             | NonExhaustive
             | Hidden
             | Missing { .. }
@@ -869,8 +893,10 @@ impl<'tcx> Constructor<'tcx> {
             }
             (Slice(self_slice), Slice(other_slice)) => self_slice.is_covered_by(*other_slice),
 
-            // We are trying to inspect an opaque constant. Thus we skip the row.
-            (Opaque, _) | (_, Opaque) => false,
+            // Opaque constructors don't interact with anything unless they come from the
+            // syntactically identical pattern.
+            (Opaque(self_id), Opaque(other_id)) => self_id == other_id,
+            (Opaque(..), _) | (_, Opaque(..)) => false,
 
             _ => span_bug!(
                 pcx.span,
@@ -1083,18 +1109,26 @@ impl ConstructorSet {
     {
         let mut present: SmallVec<[_; 1]> = SmallVec::new();
         let mut missing = Vec::new();
-        // Constructors in `ctors`, except wildcards.
-        let mut seen = ctors.filter(|c| !(matches!(c, Opaque | Wildcard)));
+        // Constructors in `ctors`, except wildcards and opaques.
+        let mut seen = Vec::new();
+        for ctor in ctors.cloned() {
+            if let Constructor::Opaque(..) = ctor {
+                present.push(ctor);
+            } else if !ctor.is_wildcard() {
+                seen.push(ctor);
+            }
+        }
+
         match self {
             ConstructorSet::Single => {
-                if seen.next().is_none() {
+                if seen.is_empty() {
                     missing.push(Single);
                 } else {
                     present.push(Single);
                 }
             }
             ConstructorSet::Variants { visible_variants, hidden_variants, non_exhaustive } => {
-                let seen_set: FxHashSet<_> = seen.map(|c| c.as_variant().unwrap()).collect();
+                let seen_set: FxHashSet<_> = seen.iter().map(|c| c.as_variant().unwrap()).collect();
                 let mut skipped_a_hidden_variant = false;
 
                 for variant in visible_variants {
@@ -1125,7 +1159,7 @@ impl ConstructorSet {
             ConstructorSet::Bool => {
                 let mut seen_false = false;
                 let mut seen_true = false;
-                for b in seen.map(|ctor| ctor.as_bool().unwrap()) {
+                for b in seen.iter().map(|ctor| ctor.as_bool().unwrap()) {
                     if b {
                         seen_true = true;
                     } else {
@@ -1145,7 +1179,7 @@ impl ConstructorSet {
             }
             ConstructorSet::Integers { range_1, range_2 } => {
                 let seen_ranges: Vec<_> =
-                    seen.map(|ctor| ctor.as_int_range().unwrap().clone()).collect();
+                    seen.iter().map(|ctor| ctor.as_int_range().unwrap().clone()).collect();
                 for (seen, splitted_range) in range_1.split(seen_ranges.iter().cloned()) {
                     match seen {
                         Presence::Unseen => missing.push(IntRange(splitted_range)),
@@ -1162,7 +1196,7 @@ impl ConstructorSet {
                 }
             }
             &ConstructorSet::Slice(array_len) => {
-                let seen_slices = seen.map(|c| c.as_slice().unwrap());
+                let seen_slices = seen.iter().map(|c| c.as_slice().unwrap());
                 let base_slice = Slice::new(array_len, VarLen(0, 0));
                 for (seen, splitted_slice) in base_slice.split(seen_slices) {
                     let ctor = Slice(splitted_slice);
@@ -1178,7 +1212,7 @@ impl ConstructorSet {
                 // unreachable if length != 0.
                 // We still gather the seen constructors in `present`, but the only slice that can
                 // go in `missing` is `[]`.
-                let seen_slices = seen.map(|c| c.as_slice().unwrap());
+                let seen_slices = seen.iter().map(|c| c.as_slice().unwrap());
                 let base_slice = Slice::new(None, VarLen(0, 0));
                 for (seen, splitted_slice) in base_slice.split(seen_slices) {
                     let ctor = Slice(splitted_slice);
@@ -1194,7 +1228,7 @@ impl ConstructorSet {
             ConstructorSet::Unlistable => {
                 // Since we can't list constructors, we take the ones in the column. This might list
                 // some constructors several times but there's not much we can do.
-                present.extend(seen.cloned());
+                present.extend(seen);
                 missing.push(NonExhaustive);
             }
             // If `exhaustive_patterns` is disabled and our scrutinee is an empty type, we cannot
@@ -1339,7 +1373,7 @@ impl<'p, 'tcx> Fields<'p, 'tcx> {
             | F32Range(..)
             | F64Range(..)
             | Str(..)
-            | Opaque
+            | Opaque(..)
             | NonExhaustive
             | Hidden
             | Missing { .. }
@@ -1470,14 +1504,14 @@ impl<'p, 'tcx> DeconstructedPat<'p, 'tcx> {
                     ty::Bool => {
                         ctor = match value.try_eval_bool(cx.tcx, cx.param_env) {
                             Some(b) => Bool(b),
-                            None => Opaque,
+                            None => Opaque(OpaqueId::new()),
                         };
                         fields = Fields::empty();
                     }
                     ty::Char | ty::Int(_) | ty::Uint(_) => {
                         ctor = match value.try_eval_bits(cx.tcx, cx.param_env) {
                             Some(bits) => IntRange(IntRange::from_bits(cx.tcx, pat.ty, bits)),
-                            None => Opaque,
+                            None => Opaque(OpaqueId::new()),
                         };
                         fields = Fields::empty();
                     }
@@ -1488,7 +1522,7 @@ impl<'p, 'tcx> DeconstructedPat<'p, 'tcx> {
                                 let value = rustc_apfloat::ieee::Single::from_bits(bits);
                                 F32Range(value, value, RangeEnd::Included)
                             }
-                            None => Opaque,
+                            None => Opaque(OpaqueId::new()),
                         };
                         fields = Fields::empty();
                     }
@@ -1499,7 +1533,7 @@ impl<'p, 'tcx> DeconstructedPat<'p, 'tcx> {
                                 let value = rustc_apfloat::ieee::Double::from_bits(bits);
                                 F64Range(value, value, RangeEnd::Included)
                             }
-                            None => Opaque,
+                            None => Opaque(OpaqueId::new()),
                         };
                         fields = Fields::empty();
                     }
@@ -1520,7 +1554,7 @@ impl<'p, 'tcx> DeconstructedPat<'p, 'tcx> {
                     // into the corresponding `Pat`s by `const_to_pat`. Constants that remain are
                     // opaque.
                     _ => {
-                        ctor = Opaque;
+                        ctor = Opaque(OpaqueId::new());
                         fields = Fields::empty();
                     }
                 }
@@ -1581,7 +1615,7 @@ impl<'p, 'tcx> DeconstructedPat<'p, 'tcx> {
                 fields = Fields::from_iter(cx, pats.into_iter().map(mkpat));
             }
             PatKind::Error(_) => {
-                ctor = Opaque;
+                ctor = Opaque(OpaqueId::new());
                 fields = Fields::empty();
             }
         }
@@ -1768,7 +1802,7 @@ impl<'p, 'tcx> fmt::Debug for DeconstructedPat<'p, 'tcx> {
             F32Range(lo, hi, end) => write!(f, "{lo}{end}{hi}"),
             F64Range(lo, hi, end) => write!(f, "{lo}{end}{hi}"),
             Str(value) => write!(f, "{value}"),
-            Opaque => write!(f, "<constant pattern>"),
+            Opaque(..) => write!(f, "<constant pattern>"),
             Or => {
                 for pat in self.iter_fields() {
                     write!(f, "{}{:?}", start_or_continue(" | "), pat)?;
@@ -1898,7 +1932,7 @@ impl<'tcx> WitnessPat<'tcx> {
                 "trying to convert a `Missing` constructor into a `Pat`; this is probably a bug,
                 `Missing` should have been processed in `apply_constructors`"
             ),
-            F32Range(..) | F64Range(..) | Opaque | Or => {
+            F32Range(..) | F64Range(..) | Opaque(..) | Or => {
                 bug!("can't convert to pattern: {:?}", self)
             }
         };