mirror of
https://github.com/gfx-rs/wgpu.git
synced 2024-11-27 01:03:41 +00:00
Fixes for fma
function (#1580)
* [hlsl-out] Write `mad` intrinsic for `fma` function - This should be enough because we only support f32 for now. - Adds a new test for WGSL functions, in the spirit of operators.wgsl. - Closes #1579 * Add FMA feature to glsl backend - I think this is right. Just iterate all known expressions in all functions and entry points to locate any `fma` function call. Should not need to walk the statement DAG. * Transform GLSL fma function into an arithmetic expression when necessary * Add tests for GLSL fma function transformation * Remove the hazard comment from the webgl test input * Add helper method for fma function support checks * Address review comment
This commit is contained in:
parent
54178dede2
commit
924ab17b62
@ -1,7 +1,7 @@
|
||||
use super::{BackendResult, Error, Version, Writer};
|
||||
use crate::{
|
||||
Binding, Bytes, Handle, ImageClass, ImageDimension, Interpolation, Sampling, ScalarKind,
|
||||
ShaderStage, StorageClass, StorageFormat, Type, TypeInner,
|
||||
Binding, Bytes, Expression, Handle, ImageClass, ImageDimension, Interpolation, MathFunction,
|
||||
Sampling, ScalarKind, ShaderStage, StorageClass, StorageFormat, Type, TypeInner,
|
||||
};
|
||||
use std::fmt::Write;
|
||||
|
||||
@ -33,6 +33,8 @@ bitflags::bitflags! {
|
||||
/// Arrays with a dynamic length
|
||||
const DYNAMIC_ARRAY_SIZE = 1 << 16;
|
||||
const MULTI_VIEW = 1 << 17;
|
||||
/// Adds support for fused multiply-add
|
||||
const FMA = 1 << 18;
|
||||
}
|
||||
}
|
||||
|
||||
@ -98,6 +100,7 @@ impl FeaturesManager {
|
||||
check_feature!(SAMPLE_VARIABLES, 400, 300);
|
||||
check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310);
|
||||
check_feature!(MULTI_VIEW, 140, 310);
|
||||
check_feature!(FMA, 400, 310);
|
||||
|
||||
// Return an error if there are missing features
|
||||
if missing.is_empty() {
|
||||
@ -201,6 +204,11 @@ impl FeaturesManager {
|
||||
writeln!(out, "#extension GL_EXT_multiview : require")?;
|
||||
}
|
||||
|
||||
if self.0.contains(Features::FMA) && version.is_es() {
|
||||
// https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_gpu_shader5.txt
|
||||
writeln!(out, "#extension GL_EXT_gpu_shader5 : require")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -347,6 +355,27 @@ impl<'a, W> Writer<'a, W> {
|
||||
}
|
||||
}
|
||||
|
||||
if self.options.version.supports_fma_function() {
|
||||
let has_fma = self
|
||||
.module
|
||||
.functions
|
||||
.iter()
|
||||
.flat_map(|(_, f)| f.expressions.iter())
|
||||
.chain(
|
||||
self.module
|
||||
.entry_points
|
||||
.iter()
|
||||
.flat_map(|e| e.function.expressions.iter()),
|
||||
)
|
||||
.any(|(_, e)| match *e {
|
||||
Expression::Math { fun, .. } if fun == MathFunction::Fma => true,
|
||||
_ => false,
|
||||
});
|
||||
if has_fma {
|
||||
self.features.request(Features::FMA);
|
||||
}
|
||||
}
|
||||
|
||||
self.features.check_availability(self.options.version)
|
||||
}
|
||||
|
||||
|
@ -160,6 +160,10 @@ impl Version {
|
||||
/// Reports whether this version is at least `Desktop(430)` or `Embedded(310)`.
///
/// NOTE(review): judging by the name, this gates use of the std430 buffer
/// layout; the callers are not visible in this hunk, so confirm there.
fn supports_std430_layout(&self) -> bool {
|
||||
*self >= Version::Desktop(430) || *self >= Version::Embedded(310)
|
||||
}
|
||||
|
||||
/// Reports whether this version is at least `Desktop(400)` or `Embedded(310)`.
///
/// The GLSL writer consults this to choose between emitting a call to the
/// `fma` built-in and expanding the call into a plain `(a * b + c)`
/// arithmetic expression; it also guards whether the `FMA` feature is
/// requested at all.
///
/// NOTE(review): on ES targets the feature manager additionally emits
/// `#extension GL_EXT_gpu_shader5 : require` when `FMA` is requested, so
/// `Embedded(310)` support presumably depends on that extension being
/// available on the driver; confirm against the Khronos extension spec.
fn supports_fma_function(&self) -> bool {
|
||||
// NOTE(review): this relies on the `PartialOrd` impl for `Version`;
// presumably a Desktop/Embedded cross-comparison is never `>=`; verify there.
*self >= Version::Desktop(400) || *self >= Version::Embedded(310)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for Version {
|
||||
@ -2471,7 +2475,30 @@ impl<'a, W: Write> Writer<'a, W> {
|
||||
Mf::Refract => "refract",
|
||||
// computational
|
||||
Mf::Sign => "sign",
|
||||
Mf::Fma => "fma",
|
||||
Mf::Fma => {
|
||||
if self.options.version.supports_fma_function() {
|
||||
// Use the fma function when available
|
||||
"fma"
|
||||
} else {
|
||||
// No fma support. Transform the function call into an arithmetic expression
|
||||
write!(self.out, "(")?;
|
||||
|
||||
self.write_expr(arg, ctx)?;
|
||||
write!(self.out, " * ")?;
|
||||
|
||||
let arg1 =
|
||||
arg1.ok_or_else(|| Error::Custom("Missing fma arg1".to_owned()))?;
|
||||
self.write_expr(arg1, ctx)?;
|
||||
write!(self.out, " + ")?;
|
||||
|
||||
let arg2 =
|
||||
arg2.ok_or_else(|| Error::Custom("Missing fma arg2".to_owned()))?;
|
||||
self.write_expr(arg2, ctx)?;
|
||||
write!(self.out, ")")?;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
Mf::Mix => "mix",
|
||||
Mf::Step => "step",
|
||||
Mf::SmoothStep => "smoothstep",
|
||||
|
@ -1862,7 +1862,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
Mf::Refract => Function::Regular("refract"),
|
||||
// computational
|
||||
Mf::Sign => Function::Regular("sign"),
|
||||
Mf::Fma => Function::Regular("fma"),
|
||||
Mf::Fma => Function::Regular("mad"),
|
||||
Mf::Mix => Function::Regular("lerp"),
|
||||
Mf::Step => Function::Regular("step"),
|
||||
Mf::SmoothStep => Function::Regular("smoothstep"),
|
||||
|
7
tests/in/functions-webgl.param.ron
Normal file
7
tests/in/functions-webgl.param.ron
Normal file
@ -0,0 +1,7 @@
|
||||
(
|
||||
glsl: (
|
||||
version: Embedded(300),
|
||||
writer_flags: (bits: 0),
|
||||
binding_map: {},
|
||||
),
|
||||
)
|
13
tests/in/functions-webgl.wgsl
Normal file
13
tests/in/functions-webgl.wgsl
Normal file
@ -0,0 +1,13 @@
|
||||
fn test_fma() -> vec2<f32> {
|
||||
let a = vec2<f32>(2.0, 2.0);
|
||||
let b = vec2<f32>(0.5, 0.5);
|
||||
let c = vec2<f32>(0.5, 0.5);
|
||||
|
||||
return fma(a, b, c);
|
||||
}
|
||||
|
||||
|
||||
[[stage(vertex)]]
|
||||
fn main() {
|
||||
let a = test_fma();
|
||||
}
|
2
tests/in/functions.param.ron
Normal file
2
tests/in/functions.param.ron
Normal file
@ -0,0 +1,2 @@
|
||||
(
|
||||
)
|
15
tests/in/functions.wgsl
Normal file
15
tests/in/functions.wgsl
Normal file
@ -0,0 +1,15 @@
|
||||
fn test_fma() -> vec2<f32> {
|
||||
let a = vec2<f32>(2.0, 2.0);
|
||||
let b = vec2<f32>(0.5, 0.5);
|
||||
let c = vec2<f32>(0.5, 0.5);
|
||||
|
||||
// Hazard: HLSL needs a different intrinsic function for f32 and f64
|
||||
// See: https://github.com/gfx-rs/naga/issues/1579
|
||||
return fma(a, b, c);
|
||||
}
|
||||
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
let a = test_fma();
|
||||
}
|
18
tests/out/glsl/functions-webgl.main.Vertex.glsl
Normal file
18
tests/out/glsl/functions-webgl.main.Vertex.glsl
Normal file
@ -0,0 +1,18 @@
|
||||
#version 300 es
|
||||
|
||||
precision highp float;
|
||||
precision highp int;
|
||||
|
||||
|
||||
vec2 test_fma() {
|
||||
vec2 a = vec2(2.0, 2.0);
|
||||
vec2 b = vec2(0.5, 0.5);
|
||||
vec2 c = vec2(0.5, 0.5);
|
||||
return (a * b + c);
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec2 _e0 = test_fma();
|
||||
return;
|
||||
}
|
||||
|
21
tests/out/glsl/functions.main.Compute.glsl
Normal file
21
tests/out/glsl/functions.main.Compute.glsl
Normal file
@ -0,0 +1,21 @@
|
||||
#version 310 es
|
||||
#extension GL_EXT_gpu_shader5 : require
|
||||
|
||||
precision highp float;
|
||||
precision highp int;
|
||||
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
|
||||
vec2 test_fma() {
|
||||
vec2 a = vec2(2.0, 2.0);
|
||||
vec2 b = vec2(0.5, 0.5);
|
||||
vec2 c = vec2(0.5, 0.5);
|
||||
return fma(a, b, c);
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec2 _e0 = test_fma();
|
||||
return;
|
||||
}
|
||||
|
15
tests/out/hlsl/functions.hlsl
Normal file
15
tests/out/hlsl/functions.hlsl
Normal file
@ -0,0 +1,15 @@
|
||||
|
||||
float2 test_fma()
|
||||
{
|
||||
float2 a = float2(2.0, 2.0);
|
||||
float2 b = float2(0.5, 0.5);
|
||||
float2 c = float2(0.5, 0.5);
|
||||
return mad(a, b, c);
|
||||
}
|
||||
|
||||
[numthreads(1, 1, 1)]
|
||||
void main()
|
||||
{
|
||||
const float2 _e0 = test_fma();
|
||||
return;
|
||||
}
|
3
tests/out/hlsl/functions.hlsl.config
Normal file
3
tests/out/hlsl/functions.hlsl.config
Normal file
@ -0,0 +1,3 @@
|
||||
vertex=()
|
||||
fragment=()
|
||||
compute=(main:cs_5_1 )
|
18
tests/out/msl/functions.msl
Normal file
18
tests/out/msl/functions.msl
Normal file
@ -0,0 +1,18 @@
|
||||
// language: metal1.1
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
|
||||
metal::float2 test_fma(
|
||||
) {
|
||||
metal::float2 a = metal::float2(2.0, 2.0);
|
||||
metal::float2 b = metal::float2(0.5, 0.5);
|
||||
metal::float2 c = metal::float2(0.5, 0.5);
|
||||
return metal::fma(a, b, c);
|
||||
}
|
||||
|
||||
kernel void main_(
|
||||
) {
|
||||
metal::float2 _e0 = test_fma();
|
||||
return;
|
||||
}
|
33
tests/out/spv/functions.spvasm
Normal file
33
tests/out/spv/functions.spvasm
Normal file
@ -0,0 +1,33 @@
|
||||
; SPIR-V
|
||||
; Version: 1.1
|
||||
; Generator: rspirv
|
||||
; Bound: 20
|
||||
OpCapability Shader
|
||||
%1 = OpExtInstImport "GLSL.std.450"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %16 "main"
|
||||
OpExecutionMode %16 LocalSize 1 1 1
|
||||
%2 = OpTypeVoid
|
||||
%4 = OpTypeFloat 32
|
||||
%3 = OpConstant %4 2.0
|
||||
%5 = OpConstant %4 0.5
|
||||
%6 = OpTypeVector %4 2
|
||||
%9 = OpTypeFunction %6
|
||||
%17 = OpTypeFunction %2
|
||||
%8 = OpFunction %6 None %9
|
||||
%7 = OpLabel
|
||||
OpBranch %10
|
||||
%10 = OpLabel
|
||||
%11 = OpCompositeConstruct %6 %3 %3
|
||||
%12 = OpCompositeConstruct %6 %5 %5
|
||||
%13 = OpCompositeConstruct %6 %5 %5
|
||||
%14 = OpExtInst %6 %1 Fma %11 %12 %13
|
||||
OpReturnValue %14
|
||||
OpFunctionEnd
|
||||
%16 = OpFunction %2 None %17
|
||||
%15 = OpLabel
|
||||
OpBranch %18
|
||||
%18 = OpLabel
|
||||
%19 = OpFunctionCall %6 %8
|
||||
OpReturn
|
||||
OpFunctionEnd
|
12
tests/out/wgsl/functions.wgsl
Normal file
12
tests/out/wgsl/functions.wgsl
Normal file
@ -0,0 +1,12 @@
|
||||
fn test_fma() -> vec2<f32> {
|
||||
let a = vec2<f32>(2.0, 2.0);
|
||||
let b = vec2<f32>(0.5, 0.5);
|
||||
let c = vec2<f32>(0.5, 0.5);
|
||||
return fma(a, b, c);
|
||||
}
|
||||
|
||||
[[stage(compute), workgroup_size(1, 1, 1)]]
|
||||
fn main() {
|
||||
let _e0 = test_fma();
|
||||
return;
|
||||
}
|
@ -443,6 +443,11 @@ fn convert_wgsl() {
|
||||
"operators",
|
||||
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
|
||||
),
|
||||
(
|
||||
"functions",
|
||||
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
|
||||
),
|
||||
("functions-webgl", Targets::GLSL),
|
||||
(
|
||||
"interpolate",
|
||||
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
|
||||
|
Loading…
Reference in New Issue
Block a user