[glsl-in] Convert bool -> scalar cast to Select

This commit is contained in:
Jasper St. Pierre 2021-07-01 19:57:51 -07:00 committed by Dzmitry Malyshau
parent a92f7689f2
commit 78e1304d42
4 changed files with 122 additions and 9 deletions

View File

@ -1,13 +1,35 @@
use crate::{
proc::ensure_block_returns, Arena, BinaryOperator, Block, EntryPoint, Expression, Function,
proc::ensure_block_returns, Arena, BinaryOperator, Block, Constant, ConstantInner, EntryPoint, Expression, Function,
FunctionArgument, FunctionResult, Handle, ImageQuery, LocalVariable, MathFunction,
RelationalFunction, SampleLevel, ScalarKind, Statement, StructMember, SwizzleComponent, Type,
RelationalFunction, SampleLevel, ScalarKind, ScalarValue, Statement, StructMember, SwizzleComponent, Type,
TypeInner, VectorSize,
};
use super::{ast::*, error::ErrorKind, SourceMetadata};
impl Program<'_> {
fn add_constant_value(
&mut self,
scalar_kind: ScalarKind,
value: u64,
) -> Handle<Constant> {
let value = match scalar_kind {
ScalarKind::Uint => ScalarValue::Uint(value),
ScalarKind::Sint => ScalarValue::Sint(value as i64),
ScalarKind::Float => ScalarValue::Float(value as f64),
_ => unreachable!(),
};
self.module.constants.fetch_or_append(Constant {
name: None,
specialization: None,
inner: ConstantInner::Scalar {
width: 4,
value,
},
})
}
pub fn function_call(
&mut self,
ctx: &mut Context,
@ -24,13 +46,40 @@ impl Program<'_> {
match fc {
FunctionCallKind::TypeConstructor(ty) => {
let h = if args.len() == 1 {
let is_vec = match *self.resolve_type(ctx, args[0].0, args[0].1)? {
TypeInner::Vector { .. } => true,
_ => false,
let expr_type = self.resolve_type(ctx, args[0].0, args[0].1)?;
let vector_size = match *expr_type {
TypeInner::Vector{ size, .. } => Some(size),
_ => None,
};
// Special case: if casting from a bool, we need to use Select and not As.
match self.module.types[ty].inner.scalar_kind() {
Some(result_scalar_kind) if expr_type.scalar_kind() == Some(ScalarKind::Bool) && result_scalar_kind != ScalarKind::Bool => {
let c0 = self.add_constant_value(result_scalar_kind, 0u64);
let c1 = self.add_constant_value(result_scalar_kind, 1u64);
let mut reject = ctx.add_expression(Expression::Constant(c0), body);
let mut accept = ctx.add_expression(Expression::Constant(c1), body);
ctx.implicit_splat(self, &mut reject, meta, vector_size)?;
ctx.implicit_splat(self, &mut accept, meta, vector_size)?;
let h = ctx.add_expression(
Expression::Select {
accept,
reject,
condition: args[0].0,
},
body,
);
return Ok(Some(h));
}
_ => {}
}
match self.module.types[ty].inner {
TypeInner::Vector { size, kind, .. } if !is_vec => {
TypeInner::Vector { size, kind, .. } if vector_size.is_none() => {
let (mut value, meta) = args[0];
ctx.implicit_conversion(self, &mut value, meta, kind)?;

View File

@ -1159,13 +1159,15 @@ impl super::Validator {
.resolve(expr)?
.scalar_kind()
.ok_or(ExpressionError::InvalidCastArgument)?;
if prev_kind == Sk::Bool || kind == Sk::Bool {
return Err(ExpressionError::InvalidCastArgument);
}
match convert {
Some(width) if !self.check_width(kind, width) => {
return Err(ExpressionError::InvalidCastArgument)
}
None if prev_kind == Sk::Bool || kind == Sk::Bool => {
return Err(ExpressionError::InvalidCastArgument)
}
_ => {}
}
ShaderStages::all()

View File

@ -0,0 +1,17 @@
#version 440 core
precision highp float;
layout(location = 0) out vec4 o_color;
float TevPerCompGT(float a, float b) {
return float(a > b);
}
vec3 TevPerCompGT(vec3 a, vec3 b) {
return vec3(greaterThan(a, b));
}
void main() {
o_color.rgb = TevPerCompGT(vec3(3.0), vec3(5.0));
o_color.a = TevPerCompGT(3.0, 5.0);
}

View File

@ -0,0 +1,45 @@
struct FragmentOutput {
[[location(0), interpolate(perspective)]] o_color: vec4<f32>;
};
var<private> o_color: vec4<f32>;
fn TevPerCompGT(a: f32, b: f32) -> f32 {
var a1: f32;
var b1: f32;
a1 = a;
b1 = b;
let _e5: f32 = a1;
let _e6: f32 = b1;
return select(1.0, 0.0, (_e5 > _e6));
}
fn TevPerCompGT1(a2: vec3<f32>, b2: vec3<f32>) -> vec3<f32> {
var a3: vec3<f32>;
var b3: vec3<f32>;
a3 = a2;
b3 = b2;
let _e5: vec3<f32> = a3;
let _e6: vec3<f32> = b3;
return select(vec3<f32>(1.0), vec3<f32>(0.0), (_e5 > _e6));
}
fn main1() {
let _e1: vec4<f32> = o_color;
let _e11: vec3<f32> = TevPerCompGT1(vec3<f32>(3.0), vec3<f32>(5.0));
o_color.x = _e11.x;
o_color.y = _e11.y;
o_color.z = _e11.z;
let _e23: f32 = TevPerCompGT(3.0, 5.0);
o_color.w = _e23;
return;
}
[[stage(fragment)]]
fn main() -> FragmentOutput {
main1();
let _e1: vec4<f32> = o_color;
return FragmentOutput(_e1);
}