diff --git a/compiler/rustc_smir/src/rustc_internal/mod.rs b/compiler/rustc_smir/src/rustc_internal/mod.rs index 1c5a0924f4a..26127c5eb85 100644 --- a/compiler/rustc_smir/src/rustc_internal/mod.rs +++ b/compiler/rustc_smir/src/rustc_internal/mod.rs @@ -196,6 +196,7 @@ where { args: Vec, callback: fn(TyCtxt<'_>) -> T, + after_analysis: Compilation, result: Option, } @@ -205,16 +206,27 @@ where { /// Creates a new `StableMir` instance, with given test_function and arguments. pub fn new(args: Vec, callback: fn(TyCtxt<'_>) -> T) -> Self { - StableMir { args, callback, result: None } + StableMir { args, callback, result: None, after_analysis: Compilation::Stop } + } + + /// Configure object to stop compilation after callback is called. + pub fn stop_compilation(&mut self) -> &mut Self { + self.after_analysis = Compilation::Stop; + self + } + + /// Configure object to continue compilation after callback is called. + pub fn continue_compilation(&mut self) -> &mut Self { + self.after_analysis = Compilation::Continue; + self } /// Runs the compiler against given target and tests it with `test_function` - pub fn run(mut self) -> Result { - let compiler_result = rustc_driver::catch_fatal_errors(|| { - RunCompiler::new(&self.args.clone(), &mut self).run() - }); + pub fn run(&mut self) -> Result { + let compiler_result = + rustc_driver::catch_fatal_errors(|| RunCompiler::new(&self.args.clone(), self).run()); match compiler_result { - Ok(Ok(())) => Ok(self.result.unwrap()), + Ok(Ok(())) => Ok(self.result.take().unwrap()), Ok(Err(_)) => Err(CompilerError::CompilationFailed), Err(_) => Err(CompilerError::ICE), } @@ -238,7 +250,7 @@ where self.result = Some((self.callback)(tcx)); }); }); - // No need to keep going. - Compilation::Stop + // Let users define if they want to stop compilation. + self.after_analysis } }