Fixes for fma function (#1580)

* [hlsl-out] Write `mad` intrinsic for `fma` function

- This should be enough because we only support f32 for now.
- Adds a new test for WGSL functions, in the spirit of operators.wgsl.
- Closes #1579

* Add FMA feature to glsl backend

- I think this is right. Just iterate all known expressions in all
  functions and entry points to locate any `fma` function call.
  Should not need to walk the statement DAG.

* Transform GLSL fma function into an arithmetic expression when necessary

* Add tests for GLSL fma function transformation

* Remove the hazard comment from the webgl test input

* Add helper method for fma function support checks

* Address review comment
This commit is contained in:
Jay Oster 2021-12-22 06:41:07 -08:00 committed by GitHub
parent 54178dede2
commit 924ab17b62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 222 additions and 4 deletions

View File

@ -1,7 +1,7 @@
use super::{BackendResult, Error, Version, Writer};
use crate::{
Binding, Bytes, Handle, ImageClass, ImageDimension, Interpolation, Sampling, ScalarKind,
ShaderStage, StorageClass, StorageFormat, Type, TypeInner,
Binding, Bytes, Expression, Handle, ImageClass, ImageDimension, Interpolation, MathFunction,
Sampling, ScalarKind, ShaderStage, StorageClass, StorageFormat, Type, TypeInner,
};
use std::fmt::Write;
@ -33,6 +33,8 @@ bitflags::bitflags! {
/// Arrays with a dynamic length
const DYNAMIC_ARRAY_SIZE = 1 << 16;
const MULTI_VIEW = 1 << 17;
/// Adds support for fused multiply-add
const FMA = 1 << 18;
}
}
@ -98,6 +100,7 @@ impl FeaturesManager {
check_feature!(SAMPLE_VARIABLES, 400, 300);
check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310);
check_feature!(MULTI_VIEW, 140, 310);
check_feature!(FMA, 400, 310);
// Return an error if there are missing features
if missing.is_empty() {
@ -201,6 +204,11 @@ impl FeaturesManager {
writeln!(out, "#extension GL_EXT_multiview : require")?;
}
if self.0.contains(Features::FMA) && version.is_es() {
// https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_gpu_shader5.txt
writeln!(out, "#extension GL_EXT_gpu_shader5 : require")?;
}
Ok(())
}
}
@ -347,6 +355,27 @@ impl<'a, W> Writer<'a, W> {
}
}
// The FMA feature is only requested when the target version can call `fma`
// directly; for older versions the writer emits a plain arithmetic
// expression instead, which needs no feature.
if self.options.version.supports_fma_function() {
    // Each function and each entry point carries its own expression list,
    // so both sets are scanned for a `MathFunction::Fma` call.
    let has_fma = self
        .module
        .functions
        .iter()
        .flat_map(|(_, f)| f.expressions.iter())
        .chain(
            self.module
                .entry_points
                .iter()
                .flat_map(|e| e.function.expressions.iter()),
        )
        .any(|(_, e)| {
            matches!(
                *e,
                Expression::Math {
                    fun: MathFunction::Fma,
                    ..
                }
            )
        });
    if has_fma {
        self.features.request(Features::FMA);
    }
}
self.features.check_availability(self.options.version)
}

View File

@ -160,6 +160,10 @@ impl Version {
fn supports_std430_layout(&self) -> bool {
*self >= Version::Desktop(430) || *self >= Version::Embedded(310)
}
/// Reports whether this GLSL version is recent enough for the writer to emit
/// a direct `fma` call: desktop GLSL 4.00 or newer, or GLSL ES 3.10 or newer.
fn supports_fma_function(&self) -> bool {
    // Desktop targets are checked first; anything else falls through to the
    // embedded (ES) threshold.
    if *self >= Version::Desktop(400) {
        return true;
    }
    *self >= Version::Embedded(310)
}
}
impl PartialOrd for Version {
@ -2471,7 +2475,30 @@ impl<'a, W: Write> Writer<'a, W> {
Mf::Refract => "refract",
// computational
Mf::Sign => "sign",
Mf::Fma => "fma",
// `fma` is handled separately from the other math functions because it is only
// emitted as a call when `supports_fma_function()` reports the version allows it.
Mf::Fma => {
if self.options.version.supports_fma_function() {
// Use the fma function when available
"fma"
} else {
// No fma support. Transform the function call into an arithmetic expression
// of the form `(arg * arg1 + arg2)`, written straight to the output.
write!(self.out, "(")?;
self.write_expr(arg, ctx)?;
write!(self.out, " * ")?;
// A missing second operand is reported as `Error::Custom`; note that the
// opening parenthesis and first operand have already been written by then.
let arg1 =
arg1.ok_or_else(|| Error::Custom("Missing fma arg1".to_owned()))?;
self.write_expr(arg1, ctx)?;
write!(self.out, " + ")?;
// Same handling for a missing third operand.
let arg2 =
arg2.ok_or_else(|| Error::Custom("Missing fma arg2".to_owned()))?;
self.write_expr(arg2, ctx)?;
write!(self.out, ")")?;
// The complete expression is already in the output, so return early
// rather than yielding a function name as the neighbouring arms do.
return Ok(());
}
}
Mf::Mix => "mix",
Mf::Step => "step",
Mf::SmoothStep => "smoothstep",

View File

@ -1862,7 +1862,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Mf::Refract => Function::Regular("refract"),
// computational
Mf::Sign => Function::Regular("sign"),
Mf::Fma => Function::Regular("fma"),
Mf::Fma => Function::Regular("mad"),
Mf::Mix => Function::Regular("lerp"),
Mf::Step => Function::Regular("step"),
Mf::SmoothStep => Function::Regular("smoothstep"),

View File

@ -0,0 +1,7 @@
(
glsl: (
version: Embedded(300),
writer_flags: (bits: 0),
binding_map: {},
),
)

View File

@ -0,0 +1,13 @@
fn test_fma() -> vec2<f32> {
let a = vec2<f32>(2.0, 2.0);
let b = vec2<f32>(0.5, 0.5);
let c = vec2<f32>(0.5, 0.5);
return fma(a, b, c);
}
[[stage(vertex)]]
fn main() {
let a = test_fma();
}

View File

@ -0,0 +1,2 @@
(
)

15
tests/in/functions.wgsl Normal file
View File

@ -0,0 +1,15 @@
fn test_fma() -> vec2<f32> {
let a = vec2<f32>(2.0, 2.0);
let b = vec2<f32>(0.5, 0.5);
let c = vec2<f32>(0.5, 0.5);
// Hazard: HLSL needs a different intrinsic function for f32 and f64
// See: https://github.com/gfx-rs/naga/issues/1579
return fma(a, b, c);
}
[[stage(compute), workgroup_size(1)]]
fn main() {
let a = test_fma();
}

View File

@ -0,0 +1,18 @@
#version 300 es
precision highp float;
precision highp int;
vec2 test_fma() {
vec2 a = vec2(2.0, 2.0);
vec2 b = vec2(0.5, 0.5);
vec2 c = vec2(0.5, 0.5);
return (a * b + c);
}
void main() {
vec2 _e0 = test_fma();
return;
}

View File

@ -0,0 +1,21 @@
#version 310 es
#extension GL_EXT_gpu_shader5 : require
precision highp float;
precision highp int;
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
vec2 test_fma() {
vec2 a = vec2(2.0, 2.0);
vec2 b = vec2(0.5, 0.5);
vec2 c = vec2(0.5, 0.5);
return fma(a, b, c);
}
void main() {
vec2 _e0 = test_fma();
return;
}

View File

@ -0,0 +1,15 @@
float2 test_fma()
{
float2 a = float2(2.0, 2.0);
float2 b = float2(0.5, 0.5);
float2 c = float2(0.5, 0.5);
return mad(a, b, c);
}
[numthreads(1, 1, 1)]
void main()
{
const float2 _e0 = test_fma();
return;
}

View File

@ -0,0 +1,3 @@
vertex=()
fragment=()
compute=(main:cs_5_1 )

View File

@ -0,0 +1,18 @@
// language: metal1.1
#include <metal_stdlib>
#include <simd/simd.h>
metal::float2 test_fma(
) {
metal::float2 a = metal::float2(2.0, 2.0);
metal::float2 b = metal::float2(0.5, 0.5);
metal::float2 c = metal::float2(0.5, 0.5);
return metal::fma(a, b, c);
}
kernel void main_(
) {
metal::float2 _e0 = test_fma();
return;
}

View File

@ -0,0 +1,33 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 20
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %16 "main"
OpExecutionMode %16 LocalSize 1 1 1
%2 = OpTypeVoid
%4 = OpTypeFloat 32
%3 = OpConstant %4 2.0
%5 = OpConstant %4 0.5
%6 = OpTypeVector %4 2
%9 = OpTypeFunction %6
%17 = OpTypeFunction %2
%8 = OpFunction %6 None %9
%7 = OpLabel
OpBranch %10
%10 = OpLabel
%11 = OpCompositeConstruct %6 %3 %3
%12 = OpCompositeConstruct %6 %5 %5
%13 = OpCompositeConstruct %6 %5 %5
%14 = OpExtInst %6 %1 Fma %11 %12 %13
OpReturnValue %14
OpFunctionEnd
%16 = OpFunction %2 None %17
%15 = OpLabel
OpBranch %18
%18 = OpLabel
%19 = OpFunctionCall %6 %8
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,12 @@
fn test_fma() -> vec2<f32> {
let a = vec2<f32>(2.0, 2.0);
let b = vec2<f32>(0.5, 0.5);
let c = vec2<f32>(0.5, 0.5);
return fma(a, b, c);
}
[[stage(compute), workgroup_size(1, 1, 1)]]
fn main() {
let _e0 = test_fma();
return;
}

View File

@ -443,6 +443,11 @@ fn convert_wgsl() {
"operators",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
(
"functions",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
("functions-webgl", Targets::GLSL),
(
"interpolate",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,