diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index e169ed1d0..b50238b97 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -16,6 +16,8 @@ use std::ops; pub type NonUniformResult = Option>; +const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true; + bitflags::bitflags! { /// Kinds of expressions that require uniform control flow. #[cfg_attr(feature = "serialize", derive(serde::Serialize))] @@ -23,8 +25,8 @@ bitflags::bitflags! { #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub struct UniformityRequirements: u8 { const WORK_GROUP_BARRIER = 0x1; - const DERIVATIVE = 0x2; - const IMPLICIT_LEVEL = 0x4; + const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 }; + const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 }; } } @@ -1350,52 +1352,56 @@ fn uniform_control_flow() { &expressions, &Arena::new(), ); - assert_eq!( - block_info, - Err(FunctionError::NonUniformControlFlow( - UniformityRequirements::DERIVATIVE, - derivative_expr, - UniformityDisruptor::Expression(non_uniform_global_expr) - ) - .with_span()), - ); - assert_eq!(info[derivative_expr].ref_count, 1); + if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { + assert_eq!(info[derivative_expr].ref_count, 2); + } else { + assert_eq!( + block_info, + Err(FunctionError::NonUniformControlFlow( + UniformityRequirements::DERIVATIVE, + derivative_expr, + UniformityDisruptor::Expression(non_uniform_global_expr) + ) + .with_span()), + ); + assert_eq!(info[derivative_expr].ref_count, 1); - // Test that the same thing passes when we disable the `derivative_uniformity` - let mut diagnostic_filters = Arena::new(); - let diagnostic_filter_leaf = diagnostic_filters.append( - DiagnosticFilterNode { - inner: crate::diagnostic_filter::DiagnosticFilter { - new_severity: crate::diagnostic_filter::Severity::Off, - triggering_rule: FilterableTriggeringRule::DerivativeUniformity, + // Test that the same thing passes when we disable the `derivative_uniformity` + let mut diagnostic_filters = Arena::new(); + let diagnostic_filter_leaf = diagnostic_filters.append( + DiagnosticFilterNode { + inner: crate::diagnostic_filter::DiagnosticFilter { + new_severity: crate::diagnostic_filter::Severity::Off, + triggering_rule: FilterableTriggeringRule::DerivativeUniformity, + }, + parent: None, }, - parent: None, - }, - crate::Span::default(), - ); - let mut info = FunctionInfo { - diagnostic_filter_leaf: Some(diagnostic_filter_leaf), - ..info.clone() - }; + crate::Span::default(), + ); + let mut info = FunctionInfo { + diagnostic_filter_leaf: Some(diagnostic_filter_leaf), + ..info.clone() + }; - let block_info = info.process_block( - &vec![stmt_emit2, stmt_if_non_uniform].into(), - &[], - None, - &expressions, - &diagnostic_filters, - ); - assert_eq!( - block_info, - Ok(FunctionUniformity { - result: Uniformity { - non_uniform_result: None, - requirements: UniformityRequirements::DERIVATIVE, - }, - exit: ExitFlags::empty() - }), - ); - assert_eq!(info[derivative_expr].ref_count, 2); + let block_info = info.process_block( + &vec![stmt_emit2, stmt_if_non_uniform].into(), + &[], + None, + &expressions, + &diagnostic_filters, + ); + assert_eq!( + block_info, + Ok(FunctionUniformity { + result: Uniformity { + non_uniform_result: None, + requirements: UniformityRequirements::DERIVATIVE, + }, + exit: ExitFlags::empty() + }), + ); + assert_eq!(info[derivative_expr].ref_count, 2); + } } assert_eq!(info[non_uniform_global], GlobalUse::READ);