[glsl-out] Convert modulo operator on float to SPIR-V OpFRem equivalent function (#1452)

This commit is contained in:
Igor Shaposhnik 2021-10-07 23:59:39 +03:00 committed by GitHub
parent 2e7d629aef
commit 943235cd5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 144 additions and 28 deletions

View File

@ -311,6 +311,16 @@ pub enum Error {
Custom(String),
}
/// Binary operation with a different logic on the GLSL side
enum BinaryOperation {
/// Vector comparison should use the function like `greaterThan()`, etc.
VectorCompare,
/// GLSL `%` is SPIR-V `OpUMod/OpSMod` and `mod()` is `OpFMod`, but [`BinaryOperator::Modulo`](crate::BinaryOperator::Modulo) is `OpFRem`
Modulo,
/// Any plain operation. No additional logic required
Other,
}
/// Main structure of the glsl backend responsible for all code generation
pub struct Writer<'a, W> {
// Inputs
@ -2214,36 +2224,81 @@ impl<'a, W: Write> Writer<'a, W> {
let right_inner = ctx.info[right].ty.inner_with(&self.module.types);
let function = match (left_inner, right_inner) {
(&Ti::Vector { .. }, &Ti::Vector { .. }) => match op {
Bo::Less => Some("lessThan"),
Bo::LessEqual => Some("lessThanEqual"),
Bo::Greater => Some("greaterThan"),
Bo::GreaterEqual => Some("greaterThanEqual"),
Bo::Equal => Some("equal"),
Bo::NotEqual => Some("notEqual"),
_ => None,
(
&Ti::Vector {
kind: left_kind, ..
},
&Ti::Vector {
kind: right_kind, ..
},
) => match op {
Bo::Less
| Bo::LessEqual
| Bo::Greater
| Bo::GreaterEqual
| Bo::Equal
| Bo::NotEqual => BinaryOperation::VectorCompare,
Bo::Modulo => match (left_kind, right_kind) {
(Sk::Float, _) | (_, Sk::Float) => match op {
Bo::Modulo => BinaryOperation::Modulo,
_ => BinaryOperation::Other,
},
_ => BinaryOperation::Other,
},
_ => BinaryOperation::Other,
},
_ => match (left_inner.scalar_kind(), right_inner.scalar_kind()) {
(Some(Sk::Float), _) | (_, Some(Sk::Float)) => match op {
Bo::Modulo => Some("mod"),
_ => None,
Bo::Modulo => BinaryOperation::Modulo,
_ => BinaryOperation::Other,
},
_ => None,
_ => BinaryOperation::Other,
},
};
write!(self.out, "{}(", function.unwrap_or(""))?;
self.write_expr(left, ctx)?;
match function {
BinaryOperation::VectorCompare => {
let op_str = match op {
Bo::Less => "lessThan(",
Bo::LessEqual => "lessThanEqual(",
Bo::Greater => "greaterThan(",
Bo::GreaterEqual => "greaterThanEqual(",
Bo::Equal => "equal(",
Bo::NotEqual => "notEqual(",
_ => unreachable!(),
};
write!(self.out, "{}", op_str)?;
self.write_expr(left, ctx)?;
write!(self.out, ", ")?;
self.write_expr(right, ctx)?;
write!(self.out, ")")?;
}
BinaryOperation::Modulo => {
write!(self.out, "(")?;
if function.is_some() {
write!(self.out, ",")?
} else {
write!(self.out, " {} ", super::binary_operation_str(op))?;
// write `e1 - e2 * trunc(e1 / e2)`
self.write_expr(left, ctx)?;
write!(self.out, " - ")?;
self.write_expr(right, ctx)?;
write!(self.out, " * ")?;
write!(self.out, "trunc(")?;
self.write_expr(left, ctx)?;
write!(self.out, " / ")?;
self.write_expr(right, ctx)?;
write!(self.out, ")")?;
write!(self.out, ")")?;
}
BinaryOperation::Other => {
write!(self.out, "(")?;
self.write_expr(left, ctx)?;
write!(self.out, " {} ", super::binary_operation_str(op))?;
self.write_expr(right, ctx)?;
write!(self.out, ")")?;
}
}
self.write_expr(right, ctx)?;
write!(self.out, ")")?
}
// `Select` is written as `condition ? accept : reject`
// We wrap everything in parentheses to avoid precedence issues

View File

@ -813,6 +813,7 @@ pub enum BinaryOperator {
Subtract,
Multiply,
Divide,
/// Equivalent of the WGSL's `%` operator or SPIR-V's `OpFRem`
Modulo,
Equal,
NotEqual,

View File

@ -44,10 +44,19 @@ fn constructors() -> f32 {
return foo.a.x;
}
fn modulo() {
// Modulo operator on float scalar or vector must be converted to mod function for GLSL
let a = 1 % 1;
let b = 1.0 % 1.0;
let c = vec3<i32>(1) % vec3<i32>(1);
let d = vec3<f32>(1.0) % vec3<f32>(1.0);
}
[[stage(compute), workgroup_size(1)]]
fn main() {
let a = builtins();
let b = splat();
let c = unary();
let d = constructors();
modulo();
}

View File

@ -44,11 +44,19 @@ float constructors() {
return _e11;
}
void modulo() {
int a1 = (1 % 1);
float b1 = (1.0 - 1.0 * trunc(1.0 / 1.0));
ivec3 c = (ivec3(1) % ivec3(1));
vec3 d = (vec3(1.0) - vec3(1.0) * trunc(vec3(1.0) / vec3(1.0)));
}
void main() {
vec4 _e4 = builtins();
vec4 _e5 = splat();
int _e6 = unary();
float _e7 = constructors();
modulo();
return;
}

View File

@ -53,6 +53,14 @@ float constructors()
return _expr11;
}
void modulo()
{
int a1 = (1 % 1);
float b1 = (1.0 % 1.0);
int3 c = (int3(1.xxx) % int3(1.xxx));
float3 d = (float3(1.0.xxx) % float3(1.0.xxx));
}
[numthreads(1, 1, 1)]
void main()
{
@ -60,5 +68,6 @@ void main()
const float4 _e5 = splat();
const int _e6 = unary();
const float _e7 = constructors();
modulo();
return;
}

View File

@ -48,11 +48,20 @@ float constructors(
return _e11;
}
void modulo(
) {
int a1 = 1 % 1;
float b1 = metal::fmod(1.0, 1.0);
metal::int3 c = metal::int3(1) % metal::int3(1);
metal::float3 d = metal::fmod(metal::float3(1.0), metal::float3(1.0));
}
kernel void main1(
) {
metal::float4 _e4 = builtins();
metal::float4 _e5 = splat();
int _e6 = unary();
float _e7 = constructors();
modulo();
return;
}

View File

@ -1,12 +1,12 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 101
; Bound: 115
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %94 "main"
OpExecutionMode %94 LocalSize 1 1 1
OpEntryPoint GLCompute %108 "main"
OpExecutionMode %108 LocalSize 1 1 1
OpMemberDecorate %22 0 Offset 0
OpMemberDecorate %22 1 Offset 16
%2 = OpTypeVoid
@ -45,6 +45,8 @@ OpMemberDecorate %22 1 Offset 16
%90 = OpTypeInt 32 0
%89 = OpConstant %90 0
%95 = OpTypeFunction %2
%99 = OpTypeVector %8 3
%103 = OpTypeVector %4 3
%28 = OpFunction %19 None %29
%27 = OpLabel
OpBranch %30
@ -122,9 +124,24 @@ OpFunctionEnd
%93 = OpLabel
OpBranch %96
%96 = OpLabel
%97 = OpFunctionCall %19 %28
%98 = OpFunctionCall %19 %53
%99 = OpFunctionCall %8 %70
%100 = OpFunctionCall %4 %82
%97 = OpSMod %8 %7 %7
%98 = OpFMod %4 %3 %3
%100 = OpCompositeConstruct %99 %7 %7 %7
%101 = OpCompositeConstruct %99 %7 %7 %7
%102 = OpSMod %99 %100 %101
%104 = OpCompositeConstruct %103 %3 %3 %3
%105 = OpCompositeConstruct %103 %3 %3 %3
%106 = OpFMod %103 %104 %105
OpReturn
OpFunctionEnd
%108 = OpFunction %2 None %95
%107 = OpLabel
OpBranch %109
%109 = OpLabel
%110 = OpFunctionCall %19 %28
%111 = OpFunctionCall %19 %53
%112 = OpFunctionCall %8 %70
%113 = OpFunctionCall %4 %82
%114 = OpFunctionCall %2 %94
OpReturn
OpFunctionEnd

View File

@ -41,11 +41,19 @@ fn constructors() -> f32 {
return e11;
}
fn modulo() {
let a1: i32 = (1 % 1);
let b1: f32 = (1.0 % 1.0);
let c: vec3<i32> = (vec3<i32>(1) % vec3<i32>(1));
let d: vec3<f32> = (vec3<f32>(1.0) % vec3<f32>(1.0));
}
[[stage(compute), workgroup_size(1, 1, 1)]]
fn main() {
let e4: vec4<f32> = builtins();
let e5: vec4<f32> = splat();
let e6: i32 = unary();
let e7: f32 = constructors();
modulo();
return;
}