mirror of
https://github.com/gfx-rs/wgpu.git
synced 2024-11-22 14:55:05 +00:00
[glsl-out] Convert modulo operator on float to SPIR-V OpFRem equivalent function (#1452)
This commit is contained in:
parent
2e7d629aef
commit
943235cd5e
@ -311,6 +311,16 @@ pub enum Error {
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
/// Binary operation with a different logic on the GLSL side
|
||||
enum BinaryOperation {
|
||||
/// Vector comparison should use the function like `greaterThan()`, etc.
|
||||
VectorCompare,
|
||||
/// GLSL `%` is SPIR-V `OpUMod/OpSMod` and `mod()` is `OpFMod`, but [`BinaryOperator::Modulo`](crate::BinaryOperator::Modulo) is `OpFRem`
|
||||
Modulo,
|
||||
/// Any plain operation. No additional logic required
|
||||
Other,
|
||||
}
|
||||
|
||||
/// Main structure of the glsl backend responsible for all code generation
|
||||
pub struct Writer<'a, W> {
|
||||
// Inputs
|
||||
@ -2214,36 +2224,81 @@ impl<'a, W: Write> Writer<'a, W> {
|
||||
let right_inner = ctx.info[right].ty.inner_with(&self.module.types);
|
||||
|
||||
let function = match (left_inner, right_inner) {
|
||||
(&Ti::Vector { .. }, &Ti::Vector { .. }) => match op {
|
||||
Bo::Less => Some("lessThan"),
|
||||
Bo::LessEqual => Some("lessThanEqual"),
|
||||
Bo::Greater => Some("greaterThan"),
|
||||
Bo::GreaterEqual => Some("greaterThanEqual"),
|
||||
Bo::Equal => Some("equal"),
|
||||
Bo::NotEqual => Some("notEqual"),
|
||||
_ => None,
|
||||
(
|
||||
&Ti::Vector {
|
||||
kind: left_kind, ..
|
||||
},
|
||||
&Ti::Vector {
|
||||
kind: right_kind, ..
|
||||
},
|
||||
) => match op {
|
||||
Bo::Less
|
||||
| Bo::LessEqual
|
||||
| Bo::Greater
|
||||
| Bo::GreaterEqual
|
||||
| Bo::Equal
|
||||
| Bo::NotEqual => BinaryOperation::VectorCompare,
|
||||
Bo::Modulo => match (left_kind, right_kind) {
|
||||
(Sk::Float, _) | (_, Sk::Float) => match op {
|
||||
Bo::Modulo => BinaryOperation::Modulo,
|
||||
_ => BinaryOperation::Other,
|
||||
},
|
||||
_ => BinaryOperation::Other,
|
||||
},
|
||||
_ => BinaryOperation::Other,
|
||||
},
|
||||
_ => match (left_inner.scalar_kind(), right_inner.scalar_kind()) {
|
||||
(Some(Sk::Float), _) | (_, Some(Sk::Float)) => match op {
|
||||
Bo::Modulo => Some("mod"),
|
||||
_ => None,
|
||||
Bo::Modulo => BinaryOperation::Modulo,
|
||||
_ => BinaryOperation::Other,
|
||||
},
|
||||
_ => None,
|
||||
_ => BinaryOperation::Other,
|
||||
},
|
||||
};
|
||||
|
||||
write!(self.out, "{}(", function.unwrap_or(""))?;
|
||||
self.write_expr(left, ctx)?;
|
||||
match function {
|
||||
BinaryOperation::VectorCompare => {
|
||||
let op_str = match op {
|
||||
Bo::Less => "lessThan(",
|
||||
Bo::LessEqual => "lessThanEqual(",
|
||||
Bo::Greater => "greaterThan(",
|
||||
Bo::GreaterEqual => "greaterThanEqual(",
|
||||
Bo::Equal => "equal(",
|
||||
Bo::NotEqual => "notEqual(",
|
||||
_ => unreachable!(),
|
||||
};
|
||||
write!(self.out, "{}", op_str)?;
|
||||
self.write_expr(left, ctx)?;
|
||||
write!(self.out, ", ")?;
|
||||
self.write_expr(right, ctx)?;
|
||||
write!(self.out, ")")?;
|
||||
}
|
||||
BinaryOperation::Modulo => {
|
||||
write!(self.out, "(")?;
|
||||
|
||||
if function.is_some() {
|
||||
write!(self.out, ",")?
|
||||
} else {
|
||||
write!(self.out, " {} ", super::binary_operation_str(op))?;
|
||||
// write `e1 - e2 * trunc(e1 / e2)`
|
||||
self.write_expr(left, ctx)?;
|
||||
write!(self.out, " - ")?;
|
||||
self.write_expr(right, ctx)?;
|
||||
write!(self.out, " * ")?;
|
||||
write!(self.out, "trunc(")?;
|
||||
self.write_expr(left, ctx)?;
|
||||
write!(self.out, " / ")?;
|
||||
self.write_expr(right, ctx)?;
|
||||
write!(self.out, ")")?;
|
||||
|
||||
write!(self.out, ")")?;
|
||||
}
|
||||
BinaryOperation::Other => {
|
||||
write!(self.out, "(")?;
|
||||
|
||||
self.write_expr(left, ctx)?;
|
||||
write!(self.out, " {} ", super::binary_operation_str(op))?;
|
||||
self.write_expr(right, ctx)?;
|
||||
|
||||
write!(self.out, ")")?;
|
||||
}
|
||||
}
|
||||
|
||||
self.write_expr(right, ctx)?;
|
||||
|
||||
write!(self.out, ")")?
|
||||
}
|
||||
// `Select` is written as `condition ? accept : reject`
|
||||
// We wrap everything in parentheses to avoid precedence issues
|
||||
|
@ -813,6 +813,7 @@ pub enum BinaryOperator {
|
||||
Subtract,
|
||||
Multiply,
|
||||
Divide,
|
||||
/// Equivalent of the WGSL's `%` operator or SPIR-V's `OpFRem`
|
||||
Modulo,
|
||||
Equal,
|
||||
NotEqual,
|
||||
|
@ -44,10 +44,19 @@ fn constructors() -> f32 {
|
||||
return foo.a.x;
|
||||
}
|
||||
|
||||
fn modulo() {
|
||||
// Modulo operator on float scalar or vector must be converted to mod function for GLSL
|
||||
let a = 1 % 1;
|
||||
let b = 1.0 % 1.0;
|
||||
let c = vec3<i32>(1) % vec3<i32>(1);
|
||||
let d = vec3<f32>(1.0) % vec3<f32>(1.0);
|
||||
}
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
let a = builtins();
|
||||
let b = splat();
|
||||
let c = unary();
|
||||
let d = constructors();
|
||||
modulo();
|
||||
}
|
||||
|
@ -44,11 +44,19 @@ float constructors() {
|
||||
return _e11;
|
||||
}
|
||||
|
||||
void modulo() {
|
||||
int a1 = (1 % 1);
|
||||
float b1 = (1.0 - 1.0 * trunc(1.0 / 1.0));
|
||||
ivec3 c = (ivec3(1) % ivec3(1));
|
||||
vec3 d = (vec3(1.0) - vec3(1.0) * trunc(vec3(1.0) / vec3(1.0)));
|
||||
}
|
||||
|
||||
void main() {
|
||||
vec4 _e4 = builtins();
|
||||
vec4 _e5 = splat();
|
||||
int _e6 = unary();
|
||||
float _e7 = constructors();
|
||||
modulo();
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -53,6 +53,14 @@ float constructors()
|
||||
return _expr11;
|
||||
}
|
||||
|
||||
void modulo()
|
||||
{
|
||||
int a1 = (1 % 1);
|
||||
float b1 = (1.0 % 1.0);
|
||||
int3 c = (int3(1.xxx) % int3(1.xxx));
|
||||
float3 d = (float3(1.0.xxx) % float3(1.0.xxx));
|
||||
}
|
||||
|
||||
[numthreads(1, 1, 1)]
|
||||
void main()
|
||||
{
|
||||
@ -60,5 +68,6 @@ void main()
|
||||
const float4 _e5 = splat();
|
||||
const int _e6 = unary();
|
||||
const float _e7 = constructors();
|
||||
modulo();
|
||||
return;
|
||||
}
|
||||
|
@ -48,11 +48,20 @@ float constructors(
|
||||
return _e11;
|
||||
}
|
||||
|
||||
void modulo(
|
||||
) {
|
||||
int a1 = 1 % 1;
|
||||
float b1 = metal::fmod(1.0, 1.0);
|
||||
metal::int3 c = metal::int3(1) % metal::int3(1);
|
||||
metal::float3 d = metal::fmod(metal::float3(1.0), metal::float3(1.0));
|
||||
}
|
||||
|
||||
kernel void main1(
|
||||
) {
|
||||
metal::float4 _e4 = builtins();
|
||||
metal::float4 _e5 = splat();
|
||||
int _e6 = unary();
|
||||
float _e7 = constructors();
|
||||
modulo();
|
||||
return;
|
||||
}
|
||||
|
@ -1,12 +1,12 @@
|
||||
; SPIR-V
|
||||
; Version: 1.1
|
||||
; Generator: rspirv
|
||||
; Bound: 101
|
||||
; Bound: 115
|
||||
OpCapability Shader
|
||||
%1 = OpExtInstImport "GLSL.std.450"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %94 "main"
|
||||
OpExecutionMode %94 LocalSize 1 1 1
|
||||
OpEntryPoint GLCompute %108 "main"
|
||||
OpExecutionMode %108 LocalSize 1 1 1
|
||||
OpMemberDecorate %22 0 Offset 0
|
||||
OpMemberDecorate %22 1 Offset 16
|
||||
%2 = OpTypeVoid
|
||||
@ -45,6 +45,8 @@ OpMemberDecorate %22 1 Offset 16
|
||||
%90 = OpTypeInt 32 0
|
||||
%89 = OpConstant %90 0
|
||||
%95 = OpTypeFunction %2
|
||||
%99 = OpTypeVector %8 3
|
||||
%103 = OpTypeVector %4 3
|
||||
%28 = OpFunction %19 None %29
|
||||
%27 = OpLabel
|
||||
OpBranch %30
|
||||
@ -122,9 +124,24 @@ OpFunctionEnd
|
||||
%93 = OpLabel
|
||||
OpBranch %96
|
||||
%96 = OpLabel
|
||||
%97 = OpFunctionCall %19 %28
|
||||
%98 = OpFunctionCall %19 %53
|
||||
%99 = OpFunctionCall %8 %70
|
||||
%100 = OpFunctionCall %4 %82
|
||||
%97 = OpSMod %8 %7 %7
|
||||
%98 = OpFMod %4 %3 %3
|
||||
%100 = OpCompositeConstruct %99 %7 %7 %7
|
||||
%101 = OpCompositeConstruct %99 %7 %7 %7
|
||||
%102 = OpSMod %99 %100 %101
|
||||
%104 = OpCompositeConstruct %103 %3 %3 %3
|
||||
%105 = OpCompositeConstruct %103 %3 %3 %3
|
||||
%106 = OpFMod %103 %104 %105
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%108 = OpFunction %2 None %95
|
||||
%107 = OpLabel
|
||||
OpBranch %109
|
||||
%109 = OpLabel
|
||||
%110 = OpFunctionCall %19 %28
|
||||
%111 = OpFunctionCall %19 %53
|
||||
%112 = OpFunctionCall %8 %70
|
||||
%113 = OpFunctionCall %4 %82
|
||||
%114 = OpFunctionCall %2 %94
|
||||
OpReturn
|
||||
OpFunctionEnd
|
@ -41,11 +41,19 @@ fn constructors() -> f32 {
|
||||
return e11;
|
||||
}
|
||||
|
||||
fn modulo() {
|
||||
let a1: i32 = (1 % 1);
|
||||
let b1: f32 = (1.0 % 1.0);
|
||||
let c: vec3<i32> = (vec3<i32>(1) % vec3<i32>(1));
|
||||
let d: vec3<f32> = (vec3<f32>(1.0) % vec3<f32>(1.0));
|
||||
}
|
||||
|
||||
[[stage(compute), workgroup_size(1, 1, 1)]]
|
||||
fn main() {
|
||||
let e4: vec4<f32> = builtins();
|
||||
let e5: vec4<f32> = splat();
|
||||
let e6: i32 = unary();
|
||||
let e7: f32 = constructors();
|
||||
modulo();
|
||||
return;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user