[naga msl-out] Defeat the MSL compiler's infinite loop analysis.

See the comments in the code for details.

This patch emits the definition of the macro only when the first loop
is encountered. This does make that first loop's code look a bit odd:
it would be more natural to define the macro at the top of the
file. (See the modified files in `naga/tests/out/msl`.)
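
The emit-on-first-loop behavior can be sketched in isolation as follows. This is a minimal illustration, not the backend's code: `SketchWriter`, `put_loop`, and `render_two_loops` are hypothetical names, and the macro and flag names are hard-coded here, whereas the patch obtains them from the backend's namer.

```rust
use std::fmt::Write;

/// Hypothetical, stripped-down stand-in for the MSL `Writer`.
struct SketchWriter {
    out: String,
    /// Empty until the macro definition has been written.
    loop_reachable_macro_name: String,
}

impl SketchWriter {
    fn new() -> Self {
        Self {
            out: String::new(),
            loop_reachable_macro_name: String::new(),
        }
    }

    /// Write the `#define` on the first call; do nothing on later calls.
    fn emit_loop_reachable_macro(&mut self) -> std::fmt::Result {
        if !self.loop_reachable_macro_name.is_empty() {
            return Ok(());
        }
        // Assumption: fixed names instead of namer-generated ones.
        self.loop_reachable_macro_name = "LOOP_IS_REACHABLE".to_string();
        let flag = "unpredictable_jump_over_loop";
        writeln!(
            self.out,
            "#define {} if (volatile bool {} = true; {})",
            self.loop_reachable_macro_name, flag, flag
        )
    }

    /// Emit an empty loop, prefixed by the macro invocation.
    fn put_loop(&mut self, indent: &str) -> std::fmt::Result {
        self.emit_loop_reachable_macro()?;
        writeln!(
            self.out,
            "{indent}{} while(true) {{",
            self.loop_reachable_macro_name
        )?;
        writeln!(self.out, "{indent}}}")
    }
}

/// Render two loops; only the first triggers the `#define`.
fn render_two_loops() -> String {
    let mut w = SketchWriter::new();
    w.put_loop("    ").unwrap();
    w.put_loop("    ").unwrap();
    w.out
}
```

Calling `render_two_loops()` yields one `#define` line followed by two loops that each start with the macro invocation, which is the shape visible in the updated snapshots.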

Rejected alternatives:

- We could emit the macro definition unconditionally at the top of the
  file. But this changes every MSL snapshot output file, whereas only
  eight of them actually contain loops.

- We could have the validator flag modules that contain loops. But the
  required changes are sizable and spread across the validator, which
  seems disproportionate for a single consumer. If other code needed
  this information, it might make sense.

- We could change the MSL backend to allow text to be generated out of
  order, so that we can decide whether to define the macro after we've
  generated all the function bodies. But at the moment this seems like
  unnecessary complexity, although it might be worth doing in the
  future if we had additional uses for it - say, to conditionally emit
  helper function definitions.
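
For comparison, the out-of-order alternative in the last bullet might look roughly like the sketch below. It is purely illustrative of the rejected design: `DeferredOutput` and its methods are hypothetical, and nothing like it exists in this patch. Function bodies are buffered, and the decision to emit the macro is made only when the final text is assembled.

```rust
/// Hypothetical two-phase emitter: bodies are buffered first, and the
/// prelude (here, just the loop macro) is decided at the very end.
struct DeferredOutput {
    bodies: String,
    needs_loop_macro: bool,
}

impl DeferredOutput {
    fn new() -> Self {
        Self {
            bodies: String::new(),
            needs_loop_macro: false,
        }
    }

    /// Buffer an ordinary statement.
    fn put_statement(&mut self, text: &str) {
        self.bodies.push_str("    ");
        self.bodies.push_str(text);
        self.bodies.push('\n');
    }

    /// Buffer an empty loop and remember that the macro is required.
    fn put_loop(&mut self) {
        self.needs_loop_macro = true;
        self.bodies
            .push_str("    LOOP_IS_REACHABLE while(true) {\n    }\n");
    }

    /// Assemble the final text: the macro definition goes first, but only
    /// if some buffered body actually used it.
    fn finish(self) -> String {
        let mut out = String::new();
        if self.needs_loop_macro {
            out.push_str(
                "#define LOOP_IS_REACHABLE if (volatile bool \
                 unpredictable_jump_over_loop = true; \
                 unpredictable_jump_over_loop)\n",
            );
        }
        out.push_str(&self.bodies);
        out
    }
}
```

This places the definition at the top of the output without touching loop-free snapshots, at the cost of buffering every body before anything is written.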

Fixes #4972.
Jim Blandy 2024-09-16 12:56:58 -07:00 committed by Erich Gubler
parent c3ab12aa29
commit 3fda684eb9
9 changed files with 162 additions and 19 deletions

@@ -376,6 +376,11 @@ pub struct Writer<W> {
     /// Set of (struct type, struct field index) denoting which fields require
     /// padding inserted **before** them (i.e. between fields at index - 1 and index)
     struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
+    /// Name of the loop reachability macro.
+    ///
+    /// See `emit_loop_reachable_macro` for details.
+    loop_reachable_macro_name: String,
 }
 impl crate::Scalar {
@@ -665,6 +670,7 @@ impl<W: Write> Writer<W> {
             #[cfg(test)]
             put_block_stack_pointers: Default::default(),
             struct_member_pads: FastHashSet::default(),
+            loop_reachable_macro_name: String::default(),
         }
     }
@@ -675,6 +681,125 @@ impl<W: Write> Writer<W> {
         self.out
     }
+    /// Define a macro to invoke before loops, to defeat MSL infinite loop
+    /// reasoning.
+    ///
+    /// If we haven't done so already, emit the definition of a preprocessor
+    /// macro to be invoked before each loop in the generated MSL, to ensure
+    /// that the MSL compiler's optimizations do not remove bounds checks.
+    ///
+    /// Only the first call to this function for a given module actually causes
+    /// the macro definition to be written. Subsequent loops can simply use the
+    /// prior macro definition, since macros aren't block-scoped.
+    ///
+    /// # What is this trying to solve?
+    ///
+    /// In Metal Shading Language, an infinite loop has undefined behavior.
+    /// (This rule is inherited from C++14.) This means that, if the MSL
+    /// compiler determines that a given loop will never exit, it may assume
+    /// that it is never reached. It may thus assume that any conditions
+    /// sufficient to cause the loop to be reached must be false. Like many
+    /// optimizing compilers, the MSL compiler uses this kind of analysis to
+    /// establish limits on the range of values variables involved in those
+    /// conditions might hold.
+    ///
+    /// For example, suppose the MSL compiler sees the code:
+    ///
+    /// ```ignore
+    /// if (i >= 10) {
+    ///     while (true) { }
+    /// }
+    /// ```
+    ///
+    /// It will recognize that the `while` loop will never terminate, conclude
+    /// that it must be unreachable, and thus infer that, if this code is
+    /// reached, then `i < 10` at that point.
+    ///
+    /// Now suppose that, at some point where `i` has the same value as above,
+    /// the compiler sees the code:
+    ///
+    /// ```ignore
+    /// if (i < 10) {
+    ///     a[i] = 1;
+    /// }
+    /// ```
+    ///
+    /// Because the compiler is confident that `i < 10`, it will make the
+    /// assignment to `a[i]` unconditional, rewriting this code as, simply:
+    ///
+    /// ```ignore
+    /// a[i] = 1;
+    /// ```
+    ///
+    /// If that `if` condition was injected by Naga to implement a bounds check,
+    /// the MSL compiler's optimizations could allow out-of-bounds array
+    /// accesses to occur.
+    ///
+    /// Naga cannot feasibly anticipate whether the MSL compiler will determine
+    /// that a loop is infinite, so an attacker could craft a Naga module
+    /// containing an infinite loop protected by conditions that cause the Metal
+    /// compiler to remove bounds checks that Naga injected elsewhere in the
+    /// function.
+    ///
+    /// This rewrite could occur even if the conditional assignment appears
+    /// *before* the `while` loop, as long as `i < 10` by the time the loop is
+    /// reached. This would allow the attacker to save the results of
+    /// unauthorized reads somewhere accessible before entering the infinite
+    /// loop. But even worse, the MSL compiler has been observed to simply
+    /// delete the infinite loop entirely, so that even code dominated by the
+    /// loop becomes reachable. This would make the attack even more flexible,
+    /// since shaders that would appear to never terminate would actually exit
+    /// nicely, after having stolen data from elsewhere in the GPU address
+    /// space.
+    ///
+    /// Ideally, Naga would prevent UB entirely via some means that persuades
+    /// the MSL compiler that no loop Naga generates is infinite. One approach
+    /// would be to add inline assembly to each loop that is annotated as
+    /// potentially branching out of the loop, but which in fact generates no
+    /// instructions. Unfortunately, inline assembly is not handled correctly by
+    /// some Metal device drivers. Further experimentation hasn't produced a
+    /// satisfactory approach.
+    ///
+    /// Instead, we accept that the MSL compiler may determine that some loops
+    /// are infinite, and focus instead on preventing the range analysis from
+    /// being affected. We transform *every* loop into something like this:
+    ///
+    /// ```ignore
+    /// if (volatile bool unpredictable = true; unpredictable)
+    ///     while (true) { }
+    /// ```
+    ///
+    /// Since the `volatile` qualifier prevents the compiler from assuming that
+    /// the `if` condition is true, it cannot be sure the infinite loop is
+    /// reached, and thus it cannot assume the entire structure is unreachable.
+    /// This prevents the range analysis impact described above.
+    ///
+    /// Unfortunately, what makes this a kludge, not a hack, is that this
+    /// solution leaves the GPU executing a pointless conditional branch, at
+    /// runtime, before each loop. There's no part of the system that has a
+    /// global enough view to be sure that `unpredictable` is true, and remove
+    /// it from the code.
+    ///
+    /// To make our output a bit more legible, we pull the condition out into a
+    /// preprocessor macro, defined just before the first loop in the module.
+    fn emit_loop_reachable_macro(&mut self) -> BackendResult {
+        if !self.loop_reachable_macro_name.is_empty() {
+            return Ok(());
+        }
+        self.loop_reachable_macro_name = self.namer.call("LOOP_IS_REACHABLE");
+        let loop_reachable_volatile_name = self.namer.call("unpredictable_jump_over_loop");
+        writeln!(
+            self.out,
+            "#define {} if (volatile bool {} = true; {})",
+            self.loop_reachable_macro_name,
+            loop_reachable_volatile_name,
+            loop_reachable_volatile_name,
+        )?;
+        Ok(())
+    }
     fn put_call_parameters(
         &mut self,
         parameters: impl Iterator<Item = Handle<crate::Expression>>,
@@ -2924,10 +3049,15 @@ impl<W: Write> Writer<W> {
                     ref continuing,
                     break_if,
                 } => {
+                    self.emit_loop_reachable_macro()?;
                     if !continuing.is_empty() || break_if.is_some() {
                         let gate_name = self.namer.call("loop_init");
                         writeln!(self.out, "{level}bool {gate_name} = true;")?;
-                        writeln!(self.out, "{level}while(true) {{")?;
+                        writeln!(
+                            self.out,
+                            "{level}{} while(true) {{",
+                            self.loop_reachable_macro_name,
+                        )?;
                         let lif = level.next();
                         let lcontinuing = lif.next();
                         writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
@@ -2942,7 +3072,11 @@ impl<W: Write> Writer<W> {
                         writeln!(self.out, "{lif}}}")?;
                         writeln!(self.out, "{lif}{gate_name} = false;")?;
                     } else {
-                        writeln!(self.out, "{level}while(true) {{")?;
+                        writeln!(
+                            self.out,
+                            "{level}{} while(true) {{",
+                            self.loop_reachable_macro_name,
+                        )?;
                     }
                     self.put_block(level.next(), body, context)?;
                     writeln!(self.out, "{level}}}")?;
@@ -3379,6 +3513,7 @@ impl<W: Write> Writer<W> {
             &[CLAMPED_LOD_LOAD_PREFIX],
             &mut self.names,
         );
+        self.loop_reachable_macro_name.clear();
         self.struct_member_pads.clear();
         writeln!(

@@ -55,8 +55,9 @@ kernel void main_(
     vPos = _e8;
     metal::float2 _e14 = particlesSrc.particles[index].vel;
     vVel = _e14;
+#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
     bool loop_init = true;
-    while(true) {
+    LOOP_IS_REACHABLE while(true) {
         if (!loop_init) {
             uint _e91 = i;
             i = _e91 + 1u;

@@ -7,8 +7,9 @@ using metal::uint;
 void breakIfEmpty(
 ) {
+#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
     bool loop_init = true;
-    while(true) {
+    LOOP_IS_REACHABLE while(true) {
         if (!loop_init) {
             if (true) {
                 break;
@@ -25,7 +26,7 @@ void breakIfEmptyBody(
     bool b = {};
     bool c = {};
     bool loop_init_1 = true;
-    while(true) {
+    LOOP_IS_REACHABLE while(true) {
         if (!loop_init_1) {
             b = a;
             bool _e2 = b;
@@ -46,7 +47,7 @@ void breakIf(
     bool d = {};
     bool e = {};
     bool loop_init_2 = true;
-    while(true) {
+    LOOP_IS_REACHABLE while(true) {
         if (!loop_init_2) {
             bool _e5 = e;
             if (a_1 == e) {
@@ -65,7 +66,7 @@ void breakIfSeparateVariable(
 ) {
     uint counter = 0u;
     bool loop_init_3 = true;
-    while(true) {
+    LOOP_IS_REACHABLE while(true) {
         if (!loop_init_3) {
             uint _e5 = counter;
             if (counter == 5u) {

@@ -19,7 +19,8 @@ uint collatz_iterations(
     uint n = {};
     uint i = 0u;
     n = n_base;
-    while(true) {
+#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
+    LOOP_IS_REACHABLE while(true) {
         uint _e4 = n;
         if (_e4 > 1u) {
         } else {

@@ -31,7 +31,8 @@ void switch_case_break(
 void loop_switch_continue(
     int x
 ) {
-    while(true) {
+#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
+    LOOP_IS_REACHABLE while(true) {
         switch(x) {
             case 1: {
                 continue;
@@ -49,7 +50,7 @@ void loop_switch_continue_nesting(
     int y,
     int z
 ) {
-    while(true) {
+    LOOP_IS_REACHABLE while(true) {
         switch(x_1) {
             case 1: {
                 continue;
@@ -60,7 +61,7 @@ void loop_switch_continue_nesting(
                         continue;
                     }
                     default: {
-                        while(true) {
+                        LOOP_IS_REACHABLE while(true) {
                             switch(z) {
                                 case 1: {
                                     continue;
@@ -85,7 +86,7 @@ void loop_switch_continue_nesting(
             }
         }
     }
-    while(true) {
+    LOOP_IS_REACHABLE while(true) {
         switch(y) {
             case 1:
             default: {
@@ -108,7 +109,7 @@ void loop_switch_omit_continue_variable_checks(
     int w
 ) {
     int pos_1 = 0;
-    while(true) {
+    LOOP_IS_REACHABLE while(true) {
         switch(x_2) {
             case 1: {
                 pos_1 = 1;
@@ -119,7 +120,7 @@ void loop_switch_omit_continue_variable_checks(
             }
         }
     }
-    while(true) {
+    LOOP_IS_REACHABLE while(true) {
         switch(x_2) {
             case 1: {
                 break;

@@ -8,8 +8,9 @@ using metal::uint;
 void fb1_(
     thread bool& cond
 ) {
+#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
     bool loop_init = true;
-    while(true) {
+    LOOP_IS_REACHABLE while(true) {
         if (!loop_init) {
             bool _e1 = cond;
             if (!(cond)) {

@@ -33,7 +33,8 @@ kernel void main_(
     rq.intersector.force_opacity((desc.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (desc.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none);
     rq.intersector.accept_any_intersection((desc.flags & 4) != 0);
     rq.intersection = rq.intersector.intersect(metal::raytracing::ray(desc.origin, desc.dir, desc.tmin, desc.tmax), acc_struct, desc.cull_mask); rq.ready = true;
-    while(true) {
+#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
+    LOOP_IS_REACHABLE while(true) {
         bool _e31 = rq.ready;
         rq.ready = false;
         if (_e31) {

@@ -53,7 +53,8 @@ RayIntersection query_loop(
     rq.intersector.force_opacity((_e8.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (_e8.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none);
     rq.intersector.accept_any_intersection((_e8.flags & 4) != 0);
     rq.intersection = rq.intersector.intersect(metal::raytracing::ray(_e8.origin, _e8.dir, _e8.tmin, _e8.tmax), acs, _e8.cull_mask); rq.ready = true;
-    while(true) {
+#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
+    LOOP_IS_REACHABLE while(true) {
         bool _e9 = rq.ready;
         rq.ready = false;
         if (_e9) {

@@ -100,8 +100,9 @@ fragment fs_mainOutput fs_main(
     metal::float3 color = c_ambient;
     uint i = 0u;
     metal::float3 normal_1 = metal::normalize(in.world_normal);
+#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
     bool loop_init = true;
-    while(true) {
+    LOOP_IS_REACHABLE while(true) {
         if (!loop_init) {
             uint _e40 = i;
             i = _e40 + 1u;
@@ -151,7 +152,7 @@ fragment fs_main_without_storageOutput fs_main_without_storage(
     uint i_1 = 0u;
     metal::float3 normal_2 = metal::normalize(in_1.world_normal);
     bool loop_init_1 = true;
-    while(true) {
+    LOOP_IS_REACHABLE while(true) {
         if (!loop_init_1) {
             uint _e40 = i_1;
             i_1 = _e40 + 1u;