[wgsl-in] Handle all(bool) and any(bool) (#2445)

Fixes #1911.
This commit is contained in:
Fredrik Fornwall 2023-08-29 21:34:55 +02:00 committed by GitHub
parent 1192588486
commit 3bd2834b4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 120 additions and 61 deletions

View File

@ -1696,7 +1696,26 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let argument = self.expression(args.next()?, ctx.reborrow())?;
args.finish()?;
crate::Expression::Relational { fun, argument }
// Check for no-op all(bool) and any(bool):
let argument_unmodified = matches!(
fun,
crate::RelationalFunction::All | crate::RelationalFunction::Any
) && {
ctx.grow_types(argument)?;
matches!(
ctx.resolved_inner(argument),
&crate::TypeInner::Scalar {
kind: crate::ScalarKind::Bool,
..
}
)
};
if argument_unmodified {
return Ok(Some(argument));
} else {
crate::Expression::Relational { fun, argument }
}
} else if let Some((axis, ctrl)) = conv::map_derivative(function.name) {
let mut args = ctx.prepare_args(arguments, 1, span);
let expr = self.expression(args.next()?, ctx.reborrow())?;

View File

@ -1,5 +1,11 @@
// Standard functions.
fn test_any_and_all_for_bool() -> bool {
let a = any(true);
return all(a);
}
@fragment
fn derivatives(@builtin(position) foo: vec4<f32>) -> @location(0) vec4<f32> {
var x = dpdxCoarse(foo);
@ -14,5 +20,7 @@ fn derivatives(@builtin(position) foo: vec4<f32>) -> @location(0) vec4<f32> {
y = dpdy(foo);
z = fwidth(foo);
let a = test_any_and_all_for_bool();
return (x + y) * z;
}

View File

@ -5,6 +5,10 @@ precision highp int;
layout(location = 0) out vec4 _fs2p_location0;
bool test_any_and_all_for_bool() {
return true;
}
void main() {
vec4 foo = gl_FragCoord;
vec4 x = vec4(0.0);
@ -28,10 +32,11 @@ void main() {
y = _e11;
vec4 _e12 = fwidth(foo);
z = _e12;
vec4 _e13 = x;
vec4 _e14 = y;
vec4 _e16 = z;
_fs2p_location0 = ((_e13 + _e14) * _e16);
bool _e13 = test_any_and_all_for_bool();
vec4 _e14 = x;
vec4 _e15 = y;
vec4 _e17 = z;
_fs2p_location0 = ((_e14 + _e15) * _e17);
return;
}

View File

@ -2,6 +2,11 @@ struct FragmentInput_derivatives {
float4 foo_1 : SV_Position;
};
bool test_any_and_all_for_bool()
{
return true;
}
float4 derivatives(FragmentInput_derivatives fragmentinput_derivatives) : SV_Target0
{
float4 foo = fragmentinput_derivatives.foo_1;
@ -27,8 +32,9 @@ float4 derivatives(FragmentInput_derivatives fragmentinput_derivatives) : SV_Tar
y = _expr11;
float4 _expr12 = fwidth(foo);
z = _expr12;
float4 _expr13 = x;
float4 _expr14 = y;
float4 _expr16 = z;
return ((_expr13 + _expr14) * _expr16);
const bool _e13 = test_any_and_all_for_bool();
float4 _expr14 = x;
float4 _expr15 = y;
float4 _expr17 = z;
return ((_expr14 + _expr15) * _expr17);
}

View File

@ -5,6 +5,11 @@
using metal::uint;
bool test_any_and_all_for_bool(
) {
return true;
}
struct derivativesInput {
};
struct derivativesOutput {
@ -34,8 +39,9 @@ fragment derivativesOutput derivatives(
y = _e11;
metal::float4 _e12 = metal::fwidth(foo);
z = _e12;
metal::float4 _e13 = x;
metal::float4 _e14 = y;
metal::float4 _e16 = z;
return derivativesOutput { (_e13 + _e14) * _e16 };
bool _e13 = test_any_and_all_for_bool();
metal::float4 _e14 = x;
metal::float4 _e15 = y;
metal::float4 _e17 = z;
return derivativesOutput { (_e14 + _e15) * _e17 };
}

View File

@ -1,56 +1,66 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 33
; Bound: 40
OpCapability Shader
OpCapability DerivativeControl
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %16 "derivatives" %11 %14
OpExecutionMode %16 OriginUpperLeft
OpDecorate %11 BuiltIn FragCoord
OpDecorate %14 Location 0
OpEntryPoint Fragment %22 "derivatives" %17 %20
OpExecutionMode %22 OriginUpperLeft
OpDecorate %17 BuiltIn FragCoord
OpDecorate %20 Location 0
%2 = OpTypeVoid
%4 = OpTypeFloat 32
%3 = OpTypeVector %4 4
%6 = OpTypePointer Function %3
%7 = OpConstantNull %3
%12 = OpTypePointer Input %3
%11 = OpVariable %12 Input
%15 = OpTypePointer Output %3
%14 = OpVariable %15 Output
%17 = OpTypeFunction %2
%16 = OpFunction %2 None %17
%3 = OpTypeBool
%5 = OpTypeFloat 32
%4 = OpTypeVector %5 4
%8 = OpTypeFunction %3
%9 = OpConstantTrue %3
%12 = OpTypePointer Function %4
%13 = OpConstantNull %4
%18 = OpTypePointer Input %4
%17 = OpVariable %18 Input
%21 = OpTypePointer Output %4
%20 = OpVariable %21 Output
%23 = OpTypeFunction %2
%7 = OpFunction %3 None %8
%6 = OpLabel
OpBranch %10
%10 = OpLabel
%5 = OpVariable %6 Function %7
%8 = OpVariable %6 Function %7
%9 = OpVariable %6 Function %7
%13 = OpLoad %3 %11
OpBranch %18
%18 = OpLabel
%19 = OpDPdxCoarse %3 %13
OpStore %5 %19
%20 = OpDPdyCoarse %3 %13
OpStore %8 %20
%21 = OpFwidthCoarse %3 %13
OpStore %9 %21
%22 = OpDPdxFine %3 %13
OpStore %5 %22
%23 = OpDPdyFine %3 %13
OpStore %8 %23
%24 = OpFwidthFine %3 %13
OpStore %9 %24
%25 = OpDPdx %3 %13
OpStore %5 %25
%26 = OpDPdy %3 %13
OpStore %8 %26
%27 = OpFwidth %3 %13
OpStore %9 %27
%28 = OpLoad %3 %5
%29 = OpLoad %3 %8
%30 = OpFAdd %3 %28 %29
%31 = OpLoad %3 %9
%32 = OpFMul %3 %30 %31
OpReturnValue %9
OpFunctionEnd
%22 = OpFunction %2 None %23
%16 = OpLabel
%11 = OpVariable %12 Function %13
%14 = OpVariable %12 Function %13
%15 = OpVariable %12 Function %13
%19 = OpLoad %4 %17
OpBranch %24
%24 = OpLabel
%25 = OpDPdxCoarse %4 %19
OpStore %11 %25
%26 = OpDPdyCoarse %4 %19
OpStore %14 %26
%27 = OpFwidthCoarse %4 %19
OpStore %15 %27
%28 = OpDPdxFine %4 %19
OpStore %11 %28
%29 = OpDPdyFine %4 %19
OpStore %14 %29
%30 = OpFwidthFine %4 %19
OpStore %15 %30
%31 = OpDPdx %4 %19
OpStore %11 %31
%32 = OpDPdy %4 %19
OpStore %14 %32
%33 = OpFwidth %4 %19
OpStore %15 %33
%34 = OpFunctionCall %3 %7
%35 = OpLoad %4 %11
%36 = OpLoad %4 %14
%37 = OpFAdd %4 %35 %36
%38 = OpLoad %4 %15
%39 = OpFMul %4 %37 %38
OpStore %20 %39
OpReturn
OpFunctionEnd

View File

@ -1,3 +1,7 @@
fn test_any_and_all_for_bool() -> bool {
return true;
}
@fragment
fn derivatives(@builtin(position) foo: vec4<f32>) -> @location(0) vec4<f32> {
var x: vec4<f32>;
@ -22,8 +26,9 @@ fn derivatives(@builtin(position) foo: vec4<f32>) -> @location(0) vec4<f32> {
y = _e11;
let _e12 = fwidth(foo);
z = _e12;
let _e13 = x;
let _e14 = y;
let _e16 = z;
return ((_e13 + _e14) * _e16);
let _e13 = test_any_and_all_for_bool();
let _e14 = x;
let _e15 = y;
let _e17 = z;
return ((_e14 + _e15) * _e17);
}