[hlsl-out] Fix accesses on zero value expressions (#5587)

This commit is contained in:
Imbris 2024-04-24 04:40:08 -04:00 committed by GitHub
parent edf1a86148
commit 82fa580152
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 220 additions and 29 deletions

View File

@ -164,6 +164,8 @@ Bottom level categories:
- GLSL 410 does not support layout(binding = ...), enable only for GLSL 420. By @bes in [#5357](https://github.com/gfx-rs/wgpu/pull/5357)
- In spv-out, check for acceleration and ray-query types when enabling ray-query extension to prevent validation error. By @Vecvec in [#5463](https://github.com/gfx-rs/wgpu/pull/5463)
- Add a limit for curly brace nesting in WGSL parsing, plus a note about stack size requirements. By @ErichDonGubler in [#5447](https://github.com/gfx-rs/wgpu/pull/5447).
- In hlsl-out, parenthesize output for `Expression::ZeroValue` (e.g. `(float4)0` -> `((float)0)`). This allows subsequent member access to parse correctly. By @Imberflur in [#5587](https://github.com/gfx-rs/wgpu/pull/5587).
#### Tests

View File

@ -70,6 +70,11 @@ pub(super) struct WrappedMath {
pub(super) components: Option<u32>,
}
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub(super) struct WrappedZeroValue {
pub(super) ty: Handle<crate::Type>,
}
/// HLSL backend requires its own `ImageQuery` enum.
///
/// It is used inside `WrappedImageQuery` and should be unique per ImageQuery function.
@ -359,7 +364,7 @@ impl<'a, W: Write> super::Writer<'a, W> {
}
/// Helper function that write wrapped function for `Expression::Compose` for structures.
pub(super) fn write_wrapped_constructor_function(
fn write_wrapped_constructor_function(
&mut self,
module: &crate::Module,
constructor: WrappedConstructor,
@ -862,6 +867,25 @@ impl<'a, W: Write> super::Writer<'a, W> {
Ok(())
}
// TODO: we could merge this with iteration in write_wrapped_compose_functions...
//
/// Helper function that writes zero value wrapped functions
pub(super) fn write_wrapped_zero_value_functions(
&mut self,
module: &crate::Module,
expressions: &crate::Arena<crate::Expression>,
) -> BackendResult {
for (handle, _) in expressions.iter() {
if let crate::Expression::ZeroValue(ty) = expressions[handle] {
let zero_value = WrappedZeroValue { ty };
if self.wrapped.zero_values.insert(zero_value) {
self.write_wrapped_zero_value_function(module, zero_value)?;
}
}
}
Ok(())
}
pub(super) fn write_wrapped_math_functions(
&mut self,
module: &crate::Module,
@ -1006,6 +1030,7 @@ impl<'a, W: Write> super::Writer<'a, W> {
) -> BackendResult {
self.write_wrapped_math_functions(module, func_ctx)?;
self.write_wrapped_compose_functions(module, func_ctx.expressions)?;
self.write_wrapped_zero_value_functions(module, func_ctx.expressions)?;
for (handle, _) in func_ctx.expressions.iter() {
match func_ctx.expressions[handle] {
@ -1283,4 +1308,71 @@ impl<'a, W: Write> super::Writer<'a, W> {
Ok(())
}
pub(super) fn write_wrapped_zero_value_function_name(
&mut self,
module: &crate::Module,
zero_value: WrappedZeroValue,
) -> BackendResult {
let name = crate::TypeInner::hlsl_type_id(zero_value.ty, module.to_ctx(), &self.names)?;
write!(self.out, "ZeroValue{name}")?;
Ok(())
}
/// Helper function that write wrapped function for `Expression::ZeroValue`
///
/// This is necessary since we might have a member access after the zero value expression, e.g.
/// `.y` (in practice this can come up when consuming SPIRV that's been produced by glslc).
///
/// So we can't just write `(float4)0` since `(float4)0.y` won't parse correctly.
///
/// Parenthesizing the expression like `((float4)0).y` would work... except DXC can't handle
/// cases like:
///
/// ```ignore
/// tests\out\hlsl\access.hlsl:183:41: error: cannot compile this l-value expression yet
/// t_1.am = (__mat4x2[2])((float4x2[2])0);
/// ^
/// ```
fn write_wrapped_zero_value_function(
&mut self,
module: &crate::Module,
zero_value: WrappedZeroValue,
) -> BackendResult {
use crate::back::INDENT;
const RETURN_VARIABLE_NAME: &str = "ret";
// Write function return type and name
if let crate::TypeInner::Array { base, size, .. } = module.types[zero_value.ty].inner {
write!(self.out, "typedef ")?;
self.write_type(module, zero_value.ty)?;
write!(self.out, " ret_")?;
self.write_wrapped_zero_value_function_name(module, zero_value)?;
self.write_array_size(module, base, size)?;
writeln!(self.out, ";")?;
write!(self.out, "ret_")?;
self.write_wrapped_zero_value_function_name(module, zero_value)?;
} else {
self.write_type(module, zero_value.ty)?;
}
write!(self.out, " ")?;
self.write_wrapped_zero_value_function_name(module, zero_value)?;
// Write function parameters (none) and start function body
writeln!(self.out, "() {{")?;
// Write `ZeroValue` function.
write!(self.out, "{INDENT}return ")?;
self.write_default_init(module, zero_value.ty)?;
writeln!(self.out, ";")?;
// End of function body
writeln!(self.out, "}}")?;
// Write extra new line
writeln!(self.out)?;
Ok(())
}
}

View File

@ -267,6 +267,7 @@ pub enum Error {
#[derive(Default)]
struct Wrapped {
zero_values: crate::FastHashSet<help::WrappedZeroValue>,
array_lengths: crate::FastHashSet<help::WrappedArrayLength>,
image_queries: crate::FastHashSet<help::WrappedImageQuery>,
constructors: crate::FastHashSet<help::WrappedConstructor>,

View File

@ -1,5 +1,8 @@
use super::{
help::{WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess},
help::{
WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
WrappedZeroValue,
},
storage::StoreValue,
BackendResult, Error, Options,
};
@ -264,6 +267,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_special_functions(module)?;
self.write_wrapped_compose_functions(module, &module.global_expressions)?;
self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;
// Write all named constants
let mut constants = module
@ -2251,7 +2255,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_const_expression(module, constant.init)?;
}
}
Expression::ZeroValue(ty) => self.write_default_init(module, ty)?,
Expression::ZeroValue(ty) => {
self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
write!(self.out, "()")?;
}
Expression::Compose { ty, ref components } => {
match module.types[ty].inner {
TypeInner::Struct { .. } | TypeInner::Array { .. } => {
@ -3394,7 +3401,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
/// Helper function that write default zero initialization
fn write_default_init(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
pub(super) fn write_default_init(
&mut self,
module: &Module,
ty: Handle<crate::Type>,
) -> BackendResult {
write!(self.out, "(")?;
self.write_type(module, ty)?;
if let TypeInner::Array { base, size, .. } = module.types[ty].inner {

View File

@ -158,10 +158,15 @@ MatCx2InArray ConstructMatCx2InArray(float4x2 arg0[2]) {
return ret;
}
typedef float4x2 ret_ZeroValuearray2_float4x2_[2];
ret_ZeroValuearray2_float4x2_ ZeroValuearray2_float4x2_() {
return (float4x2[2])0;
}
void test_matrix_within_array_within_struct_accesses()
{
int idx_1 = 1;
MatCx2InArray t_1 = ConstructMatCx2InArray((float4x2[2])0);
MatCx2InArray t_1 = ConstructMatCx2InArray(ZeroValuearray2_float4x2_());
int _expr3 = idx_1;
idx_1 = (_expr3 - 1);
@ -180,7 +185,7 @@ void test_matrix_within_array_within_struct_accesses()
float l7_ = __get_col_of_mat4x2(nested_mat_cx2_.am[0], _expr46)[_expr48];
int _expr55 = idx_1;
idx_1 = (_expr55 + 1);
t_1.am = (__mat4x2[2])(float4x2[2])0;
t_1.am = (__mat4x2[2])ZeroValuearray2_float4x2_();
t_1.am[0] = (__mat4x2)float4x2((8.0).xx, (7.0).xx, (6.0).xx, (5.0).xx);
t_1.am[0]._0 = (9.0).xx;
int _expr77 = idx_1;
@ -231,6 +236,11 @@ ret_Constructarray5_int_ Constructarray5_int_(int arg0, int arg1, int arg2, int
return ret;
}
typedef float ret_ZeroValuearray5_array10_float__[5][10];
ret_ZeroValuearray5_array10_float__ ZeroValuearray5_array10_float__() {
return (float[5][10])0;
}
typedef uint2 ret_Constructarray2_uint2_[2];
ret_Constructarray2_uint2_ Constructarray2_uint2_(uint2 arg0, uint2 arg1) {
uint2 ret[2] = { arg0, arg1 };
@ -262,10 +272,14 @@ float4 foo_vert(uint vi : SV_VertexID) : SV_Position
c2_ = Constructarray5_int_(a_1, int(b), 3, 4, 5);
c2_[(vi + 1u)] = 42;
int value = c2_[vi];
const float _e47 = test_arr_as_arg((float[5][10])0);
const float _e47 = test_arr_as_arg(ZeroValuearray5_array10_float__());
return float4(mul(float4((value).xxxx), _matrix), 2.0);
}
int2 ZeroValueint2() {
return (int2)0;
}
float4 foo_frag() : SV_Target0
{
bar.Store(8+16+0, asuint(1.0));
@ -282,7 +296,7 @@ float4 foo_frag() : SV_Target0
bar.Store2(144+8, asuint(_value2[1]));
}
bar.Store(0+8+160, asuint(1));
qux.Store2(0, asuint((int2)0));
qux.Store2(0, asuint(ZeroValueint2()));
return (0.0).xxxx;
}

View File

@ -18,17 +18,50 @@ ret_Constructarray4_int_ Constructarray4_int_(int arg0, int arg1, int arg2, int
return ret;
}
bool ZeroValuebool() {
return (bool)0;
}
int ZeroValueint() {
return (int)0;
}
uint ZeroValueuint() {
return (uint)0;
}
float ZeroValuefloat() {
return (float)0;
}
uint2 ZeroValueuint2() {
return (uint2)0;
}
float2x2 ZeroValuefloat2x2() {
return (float2x2)0;
}
typedef Foo ret_ZeroValuearray3_Foo_[3];
ret_ZeroValuearray3_Foo_ ZeroValuearray3_Foo_() {
return (Foo[3])0;
}
Foo ZeroValueFoo() {
return (Foo)0;
}
static const float3 const2_ = float3(0.0, 1.0, 2.0);
static const float2x2 const3_ = float2x2(float2(0.0, 1.0), float2(2.0, 3.0));
static const float2x2 const4_[1] = Constructarray1_float2x2_(float2x2(float2(0.0, 1.0), float2(2.0, 3.0)));
static const bool cz0_ = (bool)0;
static const int cz1_ = (int)0;
static const uint cz2_ = (uint)0;
static const float cz3_ = (float)0;
static const uint2 cz4_ = (uint2)0;
static const float2x2 cz5_ = (float2x2)0;
static const Foo cz6_[3] = (Foo[3])0;
static const Foo cz7_ = (Foo)0;
static const bool cz0_ = ZeroValuebool();
static const int cz1_ = ZeroValueint();
static const uint cz2_ = ZeroValueuint();
static const float cz3_ = ZeroValuefloat();
static const uint2 cz4_ = ZeroValueuint2();
static const float2x2 cz5_ = ZeroValuefloat2x2();
static const Foo cz6_[3] = ZeroValuearray3_Foo_();
static const Foo cz7_ = ZeroValueFoo();
static const int cp3_[4] = Constructarray4_int_(0, 1, 2, 3);
Foo ConstructFoo(float4 arg0, int arg1) {
@ -38,6 +71,10 @@ Foo ConstructFoo(float4 arg0, int arg1) {
return ret;
}
float2x3 ZeroValuefloat2x3() {
return (float2x3)0;
}
[numthreads(1, 1, 1)]
void main()
{

View File

@ -71,6 +71,10 @@ void test_msl_packed_vec3_as_arg(float3 arg)
return;
}
float3x3 ZeroValuefloat3x3() {
return (float3x3)0;
}
FooStruct ConstructFooStruct(float3 arg0, float arg1) {
FooStruct ret = (FooStruct)0;
ret.v3_ = arg0;
@ -91,8 +95,8 @@ void test_msl_packed_vec3_()
float3 l0_ = data.v3_;
float2 l1_ = data.v3_.zx;
test_msl_packed_vec3_as_arg(data.v3_);
float3 mvm0_ = mul((float3x3)0, data.v3_);
float3 mvm1_ = mul(data.v3_, (float3x3)0);
float3 mvm0_ = mul(ZeroValuefloat3x3(), data.v3_);
float3 mvm1_ = mul(data.v3_, ZeroValuefloat3x3());
float3 svm0_ = (data.v3_ * 2.0);
float3 svm1_ = (2.0 * data.v3_);
}

View File

@ -63,6 +63,10 @@ _frexp_result_vec4_f32_ naga_frexp(float4 arg) {
return result;
}
int2 ZeroValueint2() {
return (int2)0;
}
void main()
{
float4 v = (0.0).xxxx;
@ -74,7 +78,7 @@ void main()
float4 g = refract(v, v, 1.0);
int4 sign_b = int4(-1, -1, -1, -1);
float4 sign_d = float4(-1.0, -1.0, -1.0, -1.0);
int const_dot = dot((int2)0, (int2)0);
int const_dot = dot(ZeroValueint2(), ZeroValueint2());
uint first_leading_bit_abs = firstbithigh(0u);
int flb_a = asint(firstbithigh(-1));
int2 flb_b = asint(firstbithigh((-1).xx));

View File

@ -55,6 +55,18 @@ void logical()
bool4 bitwise_and1_ = ((true).xxxx & (false).xxxx);
}
float3x3 ZeroValuefloat3x3() {
return (float3x3)0;
}
float4x3 ZeroValuefloat4x3() {
return (float4x3)0;
}
float3x4 ZeroValuefloat3x4() {
return (float3x4)0;
}
void arithmetic()
{
float neg0_1 = -(1.0);
@ -122,13 +134,13 @@ void arithmetic()
float2 rem4_1 = fmod((2.0).xx, (1.0).xx);
float2 rem5_1 = fmod((2.0).xx, (1.0).xx);
}
float3x3 add = ((float3x3)0 + (float3x3)0);
float3x3 sub = ((float3x3)0 - (float3x3)0);
float3x3 mul_scalar0_ = mul(1.0, (float3x3)0);
float3x3 mul_scalar1_ = mul((float3x3)0, 2.0);
float3 mul_vector0_ = mul((1.0).xxxx, (float4x3)0);
float4 mul_vector1_ = mul((float4x3)0, (2.0).xxx);
float3x3 mul_ = mul((float3x4)0, (float4x3)0);
float3x3 add = (ZeroValuefloat3x3() + ZeroValuefloat3x3());
float3x3 sub = (ZeroValuefloat3x3() - ZeroValuefloat3x3());
float3x3 mul_scalar0_ = mul(1.0, ZeroValuefloat3x3());
float3x3 mul_scalar1_ = mul(ZeroValuefloat3x3(), 2.0);
float3 mul_vector0_ = mul((1.0).xxxx, ZeroValuefloat4x3());
float4 mul_vector1_ = mul(ZeroValuefloat4x3(), (2.0).xxx);
float3x3 mul_ = mul(ZeroValuefloat3x4(), ZeroValuefloat4x3());
}
void bit()
@ -199,10 +211,14 @@ void comparison()
bool4 gte5_ = ((2.0).xxxx >= (1.0).xxxx);
}
int3 ZeroValueint3() {
return (int3)0;
}
void assignment()
{
int a_1 = (int)0;
int3 vec0_ = (int3)0;
int3 vec0_ = ZeroValueint3();
a_1 = 1;
int _expr5 = a_1;

View File

@ -20,9 +20,14 @@ gl_PerVertex Constructgl_PerVertex(float4 arg0, float arg1, float arg2[1], float
return ret;
}
typedef float ret_ZeroValuearray1_float_[1];
ret_ZeroValuearray1_float_ ZeroValuearray1_float_() {
return (float[1])0;
}
static float2 v_uv = (float2)0;
static float2 a_uv_1 = (float2)0;
static gl_PerVertex unnamed = Constructgl_PerVertex(float4(0.0, 0.0, 0.0, 1.0), 1.0, (float[1])0, (float[1])0);
static gl_PerVertex unnamed = Constructgl_PerVertex(float4(0.0, 0.0, 0.0, 1.0), 1.0, ZeroValuearray1_float_(), ZeroValuearray1_float_());
static float2 a_pos_1 = (float2)0;
struct VertexOutput_main {

View File

@ -15,7 +15,12 @@ type_4 Constructtype_4(float4 arg0, float arg1, float arg2[1], float arg3[1]) {
return ret;
}
static type_4 global = Constructtype_4(float4(0.0, 0.0, 0.0, 1.0), 1.0, (float[1])0, (float[1])0);
typedef float ret_ZeroValuearray1_float_[1];
ret_ZeroValuearray1_float_ ZeroValuearray1_float_() {
return (float[1])0;
}
static type_4 global = Constructtype_4(float4(0.0, 0.0, 0.0, 1.0), 1.0, ZeroValuearray1_float_(), ZeroValuearray1_float_());
static int global_1 = (int)0;
void function()