[naga] Implement quantizeToF16 (#6519)

Implement WGSL frontend and WGSL, SPIR-V, HLSL, MSL, and GLSL
backends. WGSL and SPIR-V backends natively support the instruction.
MSL and HLSL emulate it by casting to f16 and back to f32. GLSL does
similar but must (mis)use (un)pack2x16 to do so.
This commit is contained in:
Jamie Nicol 2024-11-12 11:05:19 +00:00 committed by GitHub
parent 6a60458790
commit cffc7933fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 206 additions and 76 deletions

View File

@ -47,6 +47,7 @@ Bottom level categories:
- Parse `diagnostic(…)` directives, but don't implement any triggering rules yet. By @ErichDonGubler in [#6456](https://github.com/gfx-rs/wgpu/pull/6456).
- Fix an issue where `naga` CLI would incorrectly skip the first positional argument when `--stdin-file-path` was specified. By @ErichDonGubler in [#6480](https://github.com/gfx-rs/wgpu/pull/6480).
- Fix textureNumLevels in the GLSL backend. By @magcius in [#6483](https://github.com/gfx-rs/wgpu/pull/6483).
- Implement `quantizeToF16()` for WGSL frontend, and WGSL, SPIR-V, HLSL, MSL, and GLSL backends. By @jamienicol in [#6519](https://github.com/gfx-rs/wgpu/pull/6519).
#### General

View File

@ -1332,7 +1332,8 @@ impl<'a, W: Write> Writer<'a, W> {
crate::MathFunction::Pack4xI8
| crate::MathFunction::Pack4xU8
| crate::MathFunction::Unpack4xI8
| crate::MathFunction::Unpack4xU8 => {
| crate::MathFunction::Unpack4xU8
| crate::MathFunction::QuantizeToF16 => {
self.need_bake_expressions.insert(arg);
}
crate::MathFunction::ExtractBits => {
@ -3495,6 +3496,48 @@ impl<'a, W: Write> Writer<'a, W> {
Mf::Inverse => "inverse",
Mf::Transpose => "transpose",
Mf::Determinant => "determinant",
Mf::QuantizeToF16 => match *ctx.resolve_type(arg, &self.module.types) {
crate::TypeInner::Scalar { .. } => {
write!(self.out, "unpackHalf2x16(packHalf2x16(vec2(")?;
self.write_expr(arg, ctx)?;
write!(self.out, "))).x")?;
return Ok(());
}
crate::TypeInner::Vector {
size: crate::VectorSize::Bi,
..
} => {
write!(self.out, "unpackHalf2x16(packHalf2x16(")?;
self.write_expr(arg, ctx)?;
write!(self.out, "))")?;
return Ok(());
}
crate::TypeInner::Vector {
size: crate::VectorSize::Tri,
..
} => {
write!(self.out, "vec3(unpackHalf2x16(packHalf2x16(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ".xy)), unpackHalf2x16(packHalf2x16(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ".zz)).x)")?;
return Ok(());
}
crate::TypeInner::Vector {
size: crate::VectorSize::Quad,
..
} => {
write!(self.out, "vec4(unpackHalf2x16(packHalf2x16(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ".xy)), unpackHalf2x16(packHalf2x16(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ".zw)))")?;
return Ok(());
}
_ => unreachable!(
"Correct TypeInner for QuantizeToF16 should be already validated"
),
},
// bits
Mf::CountTrailingZeros => {
match *ctx.resolve_type(arg, &self.module.types) {

View File

@ -3036,6 +3036,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Unpack4x8unorm,
Unpack4xI8,
Unpack4xU8,
QuantizeToF16,
Regular(&'static str),
MissingIntOverload(&'static str),
MissingIntReturnType(&'static str),
@ -3102,6 +3103,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
//Mf::Inverse =>,
Mf::Transpose => Function::Regular("transpose"),
Mf::Determinant => Function::Regular("determinant"),
Mf::QuantizeToF16 => Function::QuantizeToF16,
// bits
Mf::CountTrailingZeros => Function::CountTrailingZeros,
Mf::CountLeadingZeros => Function::CountLeadingZeros,
@ -3303,6 +3305,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 24) << 24 >> 24")?;
}
Function::QuantizeToF16 => {
write!(self.out, "f16tof32(f32tof16(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
}
Function::Regular(fun_name) => {
write!(self.out, "{fun_name}(")?;
self.write_expr(module, arg, func_ctx)?;

View File

@ -1936,6 +1936,7 @@ impl<W: Write> Writer<W> {
Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
Mf::Transpose => "transpose",
Mf::Determinant => "determinant",
Mf::QuantizeToF16 => "",
// bits
Mf::CountTrailingZeros => "ctz",
Mf::CountLeadingZeros => "clz",
@ -2144,6 +2145,22 @@ impl<W: Write> Writer<W> {
self.put_expression(arg, context, true)?;
write!(self.out, " >> 24) << 24 >> 24")?;
}
Mf::QuantizeToF16 => {
match *context.resolve_type(arg) {
crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
crate::TypeInner::Vector { size, .. } => write!(
self.out,
"{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
size = back::vector_size_str(size),
)?,
_ => unreachable!(
"Correct TypeInner for QuantizeToF16 should be already validated"
),
};
self.put_expression(arg, context, true)?;
write!(self.out, "))")?;
}
_ => {
write!(self.out, "{NAMESPACE}::{fun_name}")?;
self.put_call_parameters(

View File

@ -1032,6 +1032,12 @@ impl<'w> BlockContext<'w> {
arg0_id,
)),
Mf::Determinant => MathOp::Ext(spirv::GLOp::Determinant),
Mf::QuantizeToF16 => MathOp::Custom(Instruction::unary(
spirv::Op::QuantizeToF16,
result_type_id,
id,
arg0_id,
)),
Mf::ReverseBits => MathOp::Custom(Instruction::unary(
spirv::Op::BitReverse,
result_type_id,

View File

@ -1723,6 +1723,7 @@ impl<W: Write> Writer<W> {
Mf::InverseSqrt => Function::Regular("inverseSqrt"),
Mf::Transpose => Function::Regular("transpose"),
Mf::Determinant => Function::Regular("determinant"),
Mf::QuantizeToF16 => Function::Regular("quantizeToF16"),
// bits
Mf::CountTrailingZeros => Function::Regular("countTrailingZeros"),
Mf::CountLeadingZeros => Function::Regular("countLeadingZeros"),

View File

@ -230,6 +230,7 @@ pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> {
"inverseSqrt" => Mf::InverseSqrt,
"transpose" => Mf::Transpose,
"determinant" => Mf::Determinant,
"quantizeToF16" => Mf::QuantizeToF16,
// bits
"countTrailingZeros" => Mf::CountTrailingZeros,
"countLeadingZeros" => Mf::CountLeadingZeros,

View File

@ -1199,6 +1199,7 @@ pub enum MathFunction {
Inverse,
Transpose,
Determinant,
QuantizeToF16,
// bits
CountTrailingZeros,
CountLeadingZeros,

View File

@ -478,6 +478,7 @@ impl super::MathFunction {
Self::Inverse => 1,
Self::Transpose => 1,
Self::Determinant => 1,
Self::QuantizeToF16 => 1,
// bits
Self::CountTrailingZeros => 1,
Self::CountLeadingZeros => 1,

View File

@ -665,7 +665,8 @@ impl<'a> ResolveContext<'a> {
| Mf::Exp2
| Mf::Log
| Mf::Log2
| Mf::Pow => res_arg.clone(),
| Mf::Pow
| Mf::QuantizeToF16 => res_arg.clone(),
Mf::Modf | Mf::Frexp => {
let (size, width) = match res_arg.inner_with(types) {
&Ti::Scalar(crate::Scalar {

View File

@ -1363,6 +1363,26 @@ impl super::Validator {
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
Mf::QuantizeToF16 => {
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
match *arg_ty {
Ti::Scalar(Sc {
kind: Sk::Float,
width: 4,
})
| Ti::Vector {
scalar:
Sc {
kind: Sk::Float,
width: 4,
},
..
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
// Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276
Mf::CountLeadingZeros
| Mf::CountTrailingZeros

View File

@ -45,4 +45,8 @@ fn main() {
let frexp_b = frexp(1.5).fract;
let frexp_c: i32 = frexp(1.5).exp;
let frexp_d: i32 = frexp(vec4(1.5, 1.5, 1.5, 1.5)).exp.x;
let quantizeToF16_a: f32 = quantizeToF16(1.0);
let quantizeToF16_b: vec2<f32> = quantizeToF16(vec2(1.0, 1.0));
let quantizeToF16_c: vec3<f32> = quantizeToF16(vec3(1.0, 1.0, 1.0));
let quantizeToF16_d: vec4<f32> = quantizeToF16(vec4(1.0, 1.0, 1.0, 1.0));
}

View File

@ -87,5 +87,12 @@ void main() {
float frexp_b = naga_frexp(1.5).fract_;
int frexp_c = naga_frexp(1.5).exp_;
int frexp_d = naga_frexp(vec4(1.5, 1.5, 1.5, 1.5)).exp_.x;
float quantizeToF16_a = unpackHalf2x16(packHalf2x16(vec2(1.0))).x;
vec2 _e120 = vec2(1.0, 1.0);
vec2 quantizeToF16_b = unpackHalf2x16(packHalf2x16(_e120));
vec3 _e125 = vec3(1.0, 1.0, 1.0);
vec3 quantizeToF16_c = vec3(unpackHalf2x16(packHalf2x16(_e125.xy)), unpackHalf2x16(packHalf2x16(_e125.zz)).x);
vec4 _e131 = vec4(1.0, 1.0, 1.0, 1.0);
vec4 quantizeToF16_d = vec4(unpackHalf2x16(packHalf2x16(_e131.xy)), unpackHalf2x16(packHalf2x16(_e131.zw)));
}

View File

@ -101,4 +101,8 @@ void main()
float frexp_b = naga_frexp(1.5).fract;
int frexp_c = naga_frexp(1.5).exp_;
int frexp_d = naga_frexp(float4(1.5, 1.5, 1.5, 1.5)).exp_.x;
float quantizeToF16_a = f16tof32(f32tof16(1.0));
float2 quantizeToF16_b = f16tof32(f32tof16(float2(1.0, 1.0)));
float3 quantizeToF16_c = f16tof32(f32tof16(float3(1.0, 1.0, 1.0)));
float4 quantizeToF16_d = f16tof32(f32tof16(float4(1.0, 1.0, 1.0, 1.0)));
}

View File

@ -89,4 +89,8 @@ fragment void main_(
float frexp_b = naga_frexp(1.5).fract;
int frexp_c = naga_frexp(1.5).exp;
int frexp_d = naga_frexp(metal::float4(1.5, 1.5, 1.5, 1.5)).exp.x;
float quantizeToF16_a = float(half(1.0));
metal::float2 quantizeToF16_b = metal::float2(metal::half2(metal::float2(1.0, 1.0)));
metal::float3 quantizeToF16_c = metal::float3(metal::half3(metal::float3(1.0, 1.0, 1.0)));
metal::float4 quantizeToF16_d = metal::float4(metal::half4(metal::float4(1.0, 1.0, 1.0, 1.0)));
}

View File

@ -1,12 +1,12 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 87
; Bound: 95
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %17 "main"
OpExecutionMode %17 OriginUpperLeft
OpEntryPoint Fragment %18 "main"
OpExecutionMode %18 OriginUpperLeft
OpMemberDecorate %11 0 Offset 0
OpMemberDecorate %11 1 Offset 4
OpMemberDecorate %12 0 Offset 0
@ -31,77 +31,85 @@ OpMemberDecorate %15 1 Offset 16
%13 = OpTypeStruct %4 %4
%14 = OpTypeStruct %3 %6
%15 = OpTypeStruct %4 %5
%18 = OpTypeFunction %2
%19 = OpConstant %3 1.0
%20 = OpConstant %3 0.0
%21 = OpConstantComposite %4 %20 %20 %20 %20
%22 = OpConstant %6 -1
%23 = OpConstantComposite %5 %22 %22 %22 %22
%24 = OpConstant %3 -1.0
%25 = OpConstantComposite %4 %24 %24 %24 %24
%26 = OpConstantNull %7
%27 = OpConstant %9 4294967295
%28 = OpConstantComposite %7 %22 %22
%29 = OpConstant %9 0
%30 = OpConstantComposite %8 %29 %29
%31 = OpConstant %6 0
%32 = OpConstantComposite %7 %31 %31
%33 = OpConstant %9 32
%34 = OpConstant %6 32
%35 = OpConstantComposite %8 %33 %33
%36 = OpConstantComposite %7 %34 %34
%37 = OpConstant %9 31
%38 = OpConstantComposite %8 %37 %37
%39 = OpConstant %6 2
%40 = OpConstant %3 2.0
%41 = OpConstantComposite %10 %19 %40
%42 = OpConstant %6 3
%43 = OpConstant %6 4
%44 = OpConstantComposite %7 %42 %43
%45 = OpConstant %3 1.5
%46 = OpConstantComposite %10 %45 %45
%47 = OpConstantComposite %4 %45 %45 %45 %45
%54 = OpConstantComposite %4 %19 %19 %19 %19
%57 = OpConstantNull %6
%17 = OpFunction %2 None %18
%16 = OpLabel
OpBranch %48
%48 = OpLabel
%49 = OpExtInst %3 %1 Degrees %19
%50 = OpExtInst %3 %1 Radians %19
%51 = OpExtInst %4 %1 Degrees %21
%52 = OpExtInst %4 %1 Radians %21
%53 = OpExtInst %4 %1 FClamp %21 %21 %54
%55 = OpExtInst %4 %1 Refract %21 %21 %19
%58 = OpCompositeExtract %6 %26 0
%59 = OpCompositeExtract %6 %26 0
%60 = OpIMul %6 %58 %59
%61 = OpIAdd %6 %57 %60
%62 = OpCompositeExtract %6 %26 1
%63 = OpCompositeExtract %6 %26 1
%16 = OpTypeVector %3 3
%19 = OpTypeFunction %2
%20 = OpConstant %3 1.0
%21 = OpConstant %3 0.0
%22 = OpConstantComposite %4 %21 %21 %21 %21
%23 = OpConstant %6 -1
%24 = OpConstantComposite %5 %23 %23 %23 %23
%25 = OpConstant %3 -1.0
%26 = OpConstantComposite %4 %25 %25 %25 %25
%27 = OpConstantNull %7
%28 = OpConstant %9 4294967295
%29 = OpConstantComposite %7 %23 %23
%30 = OpConstant %9 0
%31 = OpConstantComposite %8 %30 %30
%32 = OpConstant %6 0
%33 = OpConstantComposite %7 %32 %32
%34 = OpConstant %9 32
%35 = OpConstant %6 32
%36 = OpConstantComposite %8 %34 %34
%37 = OpConstantComposite %7 %35 %35
%38 = OpConstant %9 31
%39 = OpConstantComposite %8 %38 %38
%40 = OpConstant %6 2
%41 = OpConstant %3 2.0
%42 = OpConstantComposite %10 %20 %41
%43 = OpConstant %6 3
%44 = OpConstant %6 4
%45 = OpConstantComposite %7 %43 %44
%46 = OpConstant %3 1.5
%47 = OpConstantComposite %10 %46 %46
%48 = OpConstantComposite %4 %46 %46 %46 %46
%49 = OpConstantComposite %10 %20 %20
%50 = OpConstantComposite %16 %20 %20 %20
%51 = OpConstantComposite %4 %20 %20 %20 %20
%58 = OpConstantComposite %4 %20 %20 %20 %20
%61 = OpConstantNull %6
%18 = OpFunction %2 None %19
%17 = OpLabel
OpBranch %52
%52 = OpLabel
%53 = OpExtInst %3 %1 Degrees %20
%54 = OpExtInst %3 %1 Radians %20
%55 = OpExtInst %4 %1 Degrees %22
%56 = OpExtInst %4 %1 Radians %22
%57 = OpExtInst %4 %1 FClamp %22 %22 %58
%59 = OpExtInst %4 %1 Refract %22 %22 %20
%62 = OpCompositeExtract %6 %27 0
%63 = OpCompositeExtract %6 %27 0
%64 = OpIMul %6 %62 %63
%56 = OpIAdd %6 %61 %64
%65 = OpExtInst %3 %1 Ldexp %19 %39
%66 = OpExtInst %10 %1 Ldexp %41 %44
%67 = OpExtInst %11 %1 ModfStruct %45
%68 = OpExtInst %11 %1 ModfStruct %45
%69 = OpCompositeExtract %3 %68 0
%70 = OpExtInst %11 %1 ModfStruct %45
%71 = OpCompositeExtract %3 %70 1
%72 = OpExtInst %12 %1 ModfStruct %46
%73 = OpExtInst %13 %1 ModfStruct %47
%74 = OpCompositeExtract %4 %73 1
%75 = OpCompositeExtract %3 %74 0
%76 = OpExtInst %12 %1 ModfStruct %46
%77 = OpCompositeExtract %10 %76 0
%78 = OpCompositeExtract %3 %77 1
%79 = OpExtInst %14 %1 FrexpStruct %45
%80 = OpExtInst %14 %1 FrexpStruct %45
%81 = OpCompositeExtract %3 %80 0
%82 = OpExtInst %14 %1 FrexpStruct %45
%83 = OpCompositeExtract %6 %82 1
%84 = OpExtInst %15 %1 FrexpStruct %47
%85 = OpCompositeExtract %5 %84 1
%86 = OpCompositeExtract %6 %85 0
%65 = OpIAdd %6 %61 %64
%66 = OpCompositeExtract %6 %27 1
%67 = OpCompositeExtract %6 %27 1
%68 = OpIMul %6 %66 %67
%60 = OpIAdd %6 %65 %68
%69 = OpExtInst %3 %1 Ldexp %20 %40
%70 = OpExtInst %10 %1 Ldexp %42 %45
%71 = OpExtInst %11 %1 ModfStruct %46
%72 = OpExtInst %11 %1 ModfStruct %46
%73 = OpCompositeExtract %3 %72 0
%74 = OpExtInst %11 %1 ModfStruct %46
%75 = OpCompositeExtract %3 %74 1
%76 = OpExtInst %12 %1 ModfStruct %47
%77 = OpExtInst %13 %1 ModfStruct %48
%78 = OpCompositeExtract %4 %77 1
%79 = OpCompositeExtract %3 %78 0
%80 = OpExtInst %12 %1 ModfStruct %47
%81 = OpCompositeExtract %10 %80 0
%82 = OpCompositeExtract %3 %81 1
%83 = OpExtInst %14 %1 FrexpStruct %46
%84 = OpExtInst %14 %1 FrexpStruct %46
%85 = OpCompositeExtract %3 %84 0
%86 = OpExtInst %14 %1 FrexpStruct %46
%87 = OpCompositeExtract %6 %86 1
%88 = OpExtInst %15 %1 FrexpStruct %48
%89 = OpCompositeExtract %5 %88 1
%90 = OpCompositeExtract %6 %89 0
%91 = OpQuantizeToF16 %3 %20
%92 = OpQuantizeToF16 %10 %49
%93 = OpQuantizeToF16 %16 %50
%94 = OpQuantizeToF16 %4 %51
OpReturn
OpFunctionEnd

View File

@ -32,4 +32,8 @@ fn main() {
let frexp_b = frexp(1.5f).fract;
let frexp_c = frexp(1.5f).exp;
let frexp_d = frexp(vec4<f32>(1.5f, 1.5f, 1.5f, 1.5f)).exp.x;
let quantizeToF16_a = quantizeToF16(1f);
let quantizeToF16_b = quantizeToF16(vec2<f32>(1f, 1f));
let quantizeToF16_c = quantizeToF16(vec3<f32>(1f, 1f, 1f));
let quantizeToF16_d = quantizeToF16(vec4<f32>(1f, 1f, 1f, 1f));
}