[naga] Handle comparison operands in pipeline constant evaluation.

Properly adjust `AtomicFunction::Exchange::compare` after pipeline
constant evaluation.
This commit is contained in:
Jim Blandy 2024-03-25 18:19:17 -07:00 committed by Teodor Tanasoaia
parent a7d8ee999d
commit 8a2bc07f11
7 changed files with 344 additions and 1 deletions

View File

@ -571,11 +571,26 @@ fn adjust_stmt(new_pos: &[Handle<Expression>], stmt: &mut Statement) {
ref mut pointer,
ref mut value,
ref mut result,
..
ref mut fun,
} => {
adjust(pointer);
adjust(value);
adjust(result);
match *fun {
crate::AtomicFunction::Exchange {
compare: Some(ref mut compare),
} => {
adjust(compare);
}
crate::AtomicFunction::Add
| crate::AtomicFunction::Subtract
| crate::AtomicFunction::And
| crate::AtomicFunction::ExclusiveOr
| crate::AtomicFunction::InclusiveOr
| crate::AtomicFunction::Min
| crate::AtomicFunction::Max
| crate::AtomicFunction::Exchange { compare: None } => {}
}
}
Statement::WorkGroupUniformLoad {
ref mut pointer,

View File

@ -0,0 +1,9 @@
(
spv: (
version: (1, 0),
separate_entry_points: true,
),
pipeline_constants: {
"o": 2.0
}
)

View File

@ -0,0 +1,7 @@
override o: i32;
var<workgroup> a: atomic<u32>;
@compute @workgroup_size(1)
fn f() {
atomicCompareExchangeWeak(&a, u32(o), 1u);
}

View File

@ -0,0 +1,128 @@
(
types: [
(
name: None,
inner: Scalar((
kind: Sint,
width: 4,
)),
),
(
name: None,
inner: Atomic((
kind: Uint,
width: 4,
)),
),
(
name: None,
inner: Scalar((
kind: Uint,
width: 4,
)),
),
(
name: None,
inner: Scalar((
kind: Bool,
width: 1,
)),
),
(
name: Some("__atomic_compare_exchange_result<Uint,4>"),
inner: Struct(
members: [
(
name: Some("old_value"),
ty: 3,
binding: None,
offset: 0,
),
(
name: Some("exchanged"),
ty: 4,
binding: None,
offset: 4,
),
],
span: 8,
),
),
],
special_types: (
ray_desc: None,
ray_intersection: None,
predeclared_types: {
AtomicCompareExchangeWeakResult((
kind: Uint,
width: 4,
)): 5,
},
),
constants: [],
overrides: [
(
name: Some("o"),
id: None,
ty: 1,
init: None,
),
],
global_variables: [
(
name: Some("a"),
space: WorkGroup,
binding: None,
ty: 2,
init: None,
),
],
global_expressions: [],
functions: [],
entry_points: [
(
name: "f",
stage: Compute,
early_depth_test: None,
workgroup_size: (1, 1, 1),
function: (
name: Some("f"),
arguments: [],
result: None,
local_variables: [],
expressions: [
GlobalVariable(1),
Override(1),
As(
expr: 2,
kind: Uint,
convert: Some(4),
),
Literal(U32(1)),
AtomicResult(
ty: 5,
comparison: true,
),
],
named_expressions: {},
body: [
Emit((
start: 2,
end: 3,
)),
Atomic(
pointer: 1,
fun: Exchange(
compare: Some(3),
),
value: 4,
result: 5,
),
Return(
value: None,
),
],
),
),
],
)

View File

@ -0,0 +1,128 @@
(
types: [
(
name: None,
inner: Scalar((
kind: Sint,
width: 4,
)),
),
(
name: None,
inner: Atomic((
kind: Uint,
width: 4,
)),
),
(
name: None,
inner: Scalar((
kind: Uint,
width: 4,
)),
),
(
name: None,
inner: Scalar((
kind: Bool,
width: 1,
)),
),
(
name: Some("__atomic_compare_exchange_result<Uint,4>"),
inner: Struct(
members: [
(
name: Some("old_value"),
ty: 3,
binding: None,
offset: 0,
),
(
name: Some("exchanged"),
ty: 4,
binding: None,
offset: 4,
),
],
span: 8,
),
),
],
special_types: (
ray_desc: None,
ray_intersection: None,
predeclared_types: {
AtomicCompareExchangeWeakResult((
kind: Uint,
width: 4,
)): 5,
},
),
constants: [],
overrides: [
(
name: Some("o"),
id: None,
ty: 1,
init: None,
),
],
global_variables: [
(
name: Some("a"),
space: WorkGroup,
binding: None,
ty: 2,
init: None,
),
],
global_expressions: [],
functions: [],
entry_points: [
(
name: "f",
stage: Compute,
early_depth_test: None,
workgroup_size: (1, 1, 1),
function: (
name: Some("f"),
arguments: [],
result: None,
local_variables: [],
expressions: [
GlobalVariable(1),
Override(1),
As(
expr: 2,
kind: Uint,
convert: Some(4),
),
Literal(U32(1)),
AtomicResult(
ty: 5,
comparison: true,
),
],
named_expressions: {},
body: [
Emit((
start: 2,
end: 3,
)),
Atomic(
pointer: 1,
fun: Exchange(
compare: Some(3),
),
value: 4,
result: 5,
),
Return(
value: None,
),
],
),
),
],
)

View File

@ -0,0 +1,52 @@
; SPIR-V
; Version: 1.0
; Generator: rspirv
; Bound: 33
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %11 "f" %18
OpExecutionMode %11 LocalSize 1 1 1
OpMemberDecorate %6 0 Offset 0
OpMemberDecorate %6 1 Offset 4
OpDecorate %18 BuiltIn LocalInvocationId
%2 = OpTypeVoid
%3 = OpTypeInt 32 1
%4 = OpTypeInt 32 0
%5 = OpTypeBool
%6 = OpTypeStruct %4 %5
%7 = OpConstant %3 2
%9 = OpTypePointer Workgroup %4
%8 = OpVariable %9 Workgroup
%12 = OpTypeFunction %2
%13 = OpConstant %4 2
%14 = OpConstant %4 1
%16 = OpConstantNull %4
%17 = OpTypeVector %4 3
%19 = OpTypePointer Input %17
%18 = OpVariable %19 Input
%21 = OpConstantNull %17
%22 = OpTypeVector %5 3
%27 = OpConstant %4 264
%30 = OpConstant %4 256
%11 = OpFunction %2 None %12
%10 = OpLabel
OpBranch %15
%15 = OpLabel
%20 = OpLoad %17 %18
%23 = OpIEqual %22 %20 %21
%24 = OpAll %5 %23
OpSelectionMerge %25 None
OpBranchConditional %24 %26 %25
%26 = OpLabel
OpStore %8 %16
OpBranch %25
%25 = OpLabel
OpControlBarrier %13 %13 %27
OpBranch %28
%28 = OpLabel
%31 = OpAtomicCompareExchange %4 %8 %7 %30 %30 %14 %13
%32 = OpIEqual %5 %31 %13
%29 = OpCompositeConstruct %6 %31 %32
OpReturn
OpFunctionEnd

View File

@ -853,6 +853,10 @@ fn convert_wgsl() {
"overrides",
Targets::IR | Targets::ANALYSIS | Targets::SPIRV | Targets::METAL | Targets::HLSL,
),
(
"overrides-atomicCompareExchangeWeak",
Targets::IR | Targets::SPIRV,
),
];
for &(name, targets) in inputs.iter() {