[naga spv-out] Expand LocalType to permit pointers to matrices.

In `back::spv`:

- Factor out the numeric variants of `LocalType` into a
  new enum, `NumericType`.

- Split the `Value` variant into `Numeric` and `LocalPointer`
  variants, and let `LocalPointer` point to any numeric type,
  including matrices.

In subsequent commits, we'll need to spill matrices out into temporary
local variables. This means we'll need to generate SPIR-V
pointer-to-matrix types, so `LocalType` needs to be able to represent
that.
This commit is contained in:
Jim Blandy 2024-10-08 14:50:24 -07:00
parent 0392cb783d
commit 908e8353a8
6 changed files with 257 additions and 355 deletions

View File

@ -4,7 +4,7 @@ Implementations for `BlockContext` methods.
use super::{
helpers, index::BoundsCheckResult, selection::Selection, Block, BlockContext, Dimension, Error,
Instruction, LocalType, LookupType, ResultMember, Writer, WriterFlags,
Instruction, LocalType, LookupType, NumericType, ResultMember, Writer, WriterFlags,
};
use crate::{arena::Handle, proc::TypeResolution, Statement};
use spirv::Word;
@ -105,10 +105,9 @@ impl Writer {
position_id: Word,
body: &mut Vec<Instruction>,
) -> Result<(), Error> {
let float_ptr_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar: crate::Scalar::F32,
pointer_space: Some(spirv::StorageClass::Output),
let float_ptr_type_id = self.get_type_id(LookupType::Local(LocalType::LocalPointer {
base: NumericType::Scalar(crate::Scalar::F32),
class: spirv::StorageClass::Output,
}));
let index_y_id = self.get_index_constant(1);
let access_id = self.id_gen.next();
@ -119,11 +118,9 @@ impl Writer {
&[index_y_id],
));
let float_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar: crate::Scalar::F32,
pointer_space: None,
}));
let float_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(crate::Scalar::F32),
)));
let load_id = self.id_gen.next();
body.push(Instruction::load(float_type_id, load_id, access_id, None));
@ -145,11 +142,9 @@ impl Writer {
frag_depth_id: Word,
body: &mut Vec<Instruction>,
) -> Result<(), Error> {
let float_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar: crate::Scalar::F32,
pointer_space: None,
}));
let float_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(crate::Scalar::F32),
)));
let zero_scalar_id = self.get_constant_scalar(crate::Literal::F32(0.0));
let one_scalar_id = self.get_constant_scalar(crate::Literal::F32(1.0));
@ -830,12 +825,8 @@ impl<'w> BlockContext<'w> {
let mut arg2_id = self.writer.get_constant_scalar_with(1, scalar)?;
if let Some(size) = maybe_size {
let ty = LocalType::Value {
vector_size: Some(size),
scalar,
pointer_space: None,
}
.into();
let ty =
LocalType::Numeric(NumericType::Vector { size, scalar }).into();
self.temp_list.clear();
self.temp_list.resize(size as _, arg1_id);
@ -950,12 +941,9 @@ impl<'w> BlockContext<'w> {
&crate::TypeInner::Vector { size, .. },
&crate::TypeInner::Scalar(scalar),
) => {
let selector_type_id =
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: Some(size),
scalar,
pointer_space: None,
}));
let selector_type_id = self.get_type_id(LookupType::Local(
LocalType::Numeric(NumericType::Vector { size, scalar }),
));
self.temp_list.clear();
self.temp_list.resize(size as usize, arg2_id);
@ -998,12 +986,8 @@ impl<'w> BlockContext<'w> {
Mf::CountTrailingZeros => {
let uint_id = match *arg_ty {
crate::TypeInner::Vector { size, scalar } => {
let ty = LocalType::Value {
vector_size: Some(size),
scalar,
pointer_space: None,
}
.into();
let ty =
LocalType::Numeric(NumericType::Vector { size, scalar }).into();
self.temp_list.clear();
self.temp_list.resize(
@ -1040,12 +1024,8 @@ impl<'w> BlockContext<'w> {
Mf::CountLeadingZeros => {
let (int_type_id, int_id, width) = match *arg_ty {
crate::TypeInner::Vector { size, scalar } => {
let ty = LocalType::Value {
vector_size: Some(size),
scalar,
pointer_space: None,
}
.into();
let ty =
LocalType::Numeric(NumericType::Vector { size, scalar }).into();
self.temp_list.clear();
self.temp_list.resize(
@ -1061,11 +1041,9 @@ impl<'w> BlockContext<'w> {
)
}
crate::TypeInner::Scalar(scalar) => (
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar,
pointer_space: None,
})),
self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(scalar),
))),
self.writer
.get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
scalar.width,
@ -1130,14 +1108,9 @@ impl<'w> BlockContext<'w> {
.writer
.get_constant_scalar(crate::Literal::U32(bit_width as u32));
let u32_type = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar: crate::Scalar {
kind: crate::ScalarKind::Uint,
width: 4,
},
pointer_space: None,
}));
let u32_type = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(crate::Scalar::U32),
)));
// o = min(offset, w)
let offset_id = self.gen_id();
@ -1186,14 +1159,9 @@ impl<'w> BlockContext<'w> {
.writer
.get_constant_scalar(crate::Literal::U32(bit_width as u32));
let u32_type = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar: crate::Scalar {
kind: crate::ScalarKind::Uint,
width: 4,
},
pointer_space: None,
}));
let u32_type = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(crate::Scalar::U32),
)));
// o = min(offset, w)
let offset_id = self.gen_id();
@ -1259,23 +1227,16 @@ impl<'w> BlockContext<'w> {
Mf::Pack4xU8 => (crate::ScalarKind::Uint, false),
_ => unreachable!(),
};
let uint_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar: crate::Scalar {
kind: crate::ScalarKind::Uint,
width: 4,
},
pointer_space: None,
}));
let uint_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(crate::Scalar::U32),
)));
let int_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar: crate::Scalar {
let int_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(crate::Scalar {
kind: int_type,
width: 4,
},
pointer_space: None,
}));
}),
)));
let mut last_instruction = Instruction::new(spirv::Op::Nop);
@ -1352,24 +1313,17 @@ impl<'w> BlockContext<'w> {
_ => unreachable!(),
};
let sint_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar: crate::Scalar {
kind: crate::ScalarKind::Sint,
width: 4,
},
pointer_space: None,
}));
let sint_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(crate::Scalar::I32),
)));
let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
let int_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar: crate::Scalar {
let int_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(crate::Scalar {
kind: int_type,
width: 4,
},
pointer_space: None,
}));
}),
)));
block
.body
.reserve(usize::from(VEC_LENGTH) * 2 + usize::from(is_signed));
@ -1533,11 +1487,10 @@ impl<'w> BlockContext<'w> {
self.writer.get_constant_scalar_with(0, src_scalar)?;
let zero_id = match src_size {
Some(size) => {
let ty = LocalType::Value {
vector_size: Some(size),
let ty = LocalType::Numeric(NumericType::Vector {
size,
scalar: src_scalar,
pointer_space: None,
}
})
.into();
self.temp_list.clear();
@ -1562,11 +1515,10 @@ impl<'w> BlockContext<'w> {
self.writer.get_constant_scalar_with(1, dst_scalar)?;
let (accept_id, reject_id) = match src_size {
Some(size) => {
let ty = LocalType::Value {
vector_size: Some(size),
let ty = LocalType::Numeric(NumericType::Vector {
size,
scalar: dst_scalar,
pointer_space: None,
}
})
.into();
self.temp_list.clear();
@ -1704,12 +1656,12 @@ impl<'w> BlockContext<'w> {
self.temp_list.clear();
self.temp_list.resize(size as usize, condition_id);
let bool_vector_type_id =
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: Some(size),
let bool_vector_type_id = self.get_type_id(LookupType::Local(
LocalType::Numeric(NumericType::Vector {
size,
scalar: condition_scalar,
pointer_space: None,
}));
}),
));
let id = self.gen_id();
block.body.push(Instruction::composite_construct(
@ -2031,11 +1983,11 @@ impl<'w> BlockContext<'w> {
) {
self.temp_list.clear();
let vector_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: Some(rows),
scalar: crate::Scalar::float(width),
pointer_space: None,
}));
let vector_type_id =
self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
size: rows,
scalar: crate::Scalar::float(width),
})));
for index in 0..columns as u32 {
let column_id_left = self.gen_id();
@ -2737,20 +2689,15 @@ impl<'w> BlockContext<'w> {
crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
let scalar_type_id = match *value_inner {
crate::TypeInner::Scalar(scalar) => {
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar,
pointer_space: None,
}))
self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(scalar),
)))
}
_ => unimplemented!(),
};
let bool_type_id =
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar: crate::Scalar::BOOL,
pointer_space: None,
}));
let bool_type_id = self.get_type_id(LookupType::Local(
LocalType::Numeric(NumericType::Scalar(crate::Scalar::BOOL)),
));
let cas_result_id = self.gen_id();
let equality_result_id = self.gen_id();

View File

@ -4,7 +4,7 @@ Generating SPIR-V for image operations.
use super::{
selection::{MergeTuple, Selection},
Block, BlockContext, Error, IdGenerator, Instruction, LocalType, LookupType,
Block, BlockContext, Error, IdGenerator, Instruction, LocalType, LookupType, NumericType,
};
use crate::arena::Handle;
use spirv::Word;
@ -126,11 +126,10 @@ impl Load {
// the right SPIR-V type for the access instruction here.
let type_id = match image_class {
crate::ImageClass::Depth { .. } => {
ctx.get_type_id(LookupType::Local(LocalType::Value {
vector_size: Some(crate::VectorSize::Quad),
ctx.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
size: crate::VectorSize::Quad,
scalar: crate::Scalar::F32,
pointer_space: None,
}))
})))
}
_ => result_type_id,
};
@ -292,15 +291,15 @@ impl<'w> BlockContext<'w> {
// Find the component type of `coordinates`, and figure out the size the
// combined coordinate vector will have.
let (component_scalar, size) = match *inner_ty {
Ti::Scalar(scalar @ crate::Scalar { width: 4, .. }) => (scalar, Some(Vs::Bi)),
Ti::Scalar(scalar @ crate::Scalar { width: 4, .. }) => (scalar, Vs::Bi),
Ti::Vector {
scalar: scalar @ crate::Scalar { width: 4, .. },
size: Vs::Bi,
} => (scalar, Some(Vs::Tri)),
} => (scalar, Vs::Tri),
Ti::Vector {
scalar: scalar @ crate::Scalar { width: 4, .. },
size: Vs::Tri,
} => (scalar, Some(Vs::Quad)),
} => (scalar, Vs::Quad),
Ti::Vector { size: Vs::Quad, .. } => {
return Err(Error::Validation("extending vec4 coordinate"));
}
@ -340,11 +339,9 @@ impl<'w> BlockContext<'w> {
}
};
let reconciled_array_index_id = if let Some(cast) = cast {
let component_ty_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar: component_scalar,
pointer_space: None,
}));
let component_ty_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(component_scalar),
)));
let reconciled_id = self.gen_id();
block.body.push(Instruction::unary(
cast,
@ -358,11 +355,11 @@ impl<'w> BlockContext<'w> {
};
// Find the SPIR-V type for the combined coordinates/index vector.
let type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: size,
scalar: component_scalar,
pointer_space: None,
}));
let type_id =
self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
size,
scalar: component_scalar,
})));
// Schmear the coordinates and index together.
let value_id = self.gen_id();
@ -374,7 +371,7 @@ impl<'w> BlockContext<'w> {
Ok(ImageCoordinates {
value_id,
type_id,
size,
size: Some(size),
})
}
@ -529,11 +526,9 @@ impl<'w> BlockContext<'w> {
&[spirv::Capability::ImageQuery],
)?;
let i32_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar: crate::Scalar::I32,
pointer_space: None,
}));
let i32_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(crate::Scalar::I32),
)));
// If `level` is `Some`, clamp it to fall within bounds. This must
// happen first, because we'll use it to query the image size for
@ -616,11 +611,9 @@ impl<'w> BlockContext<'w> {
)?;
let bool_type_id = self.writer.get_bool_type_id();
let i32_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar: crate::Scalar::I32,
pointer_space: None,
}));
let i32_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(crate::Scalar::I32),
)));
let null_id = access.out_of_bounds_value(self);
@ -683,11 +676,15 @@ impl<'w> BlockContext<'w> {
);
// Compare the coordinates against the bounds.
let coords_bool_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: coordinates.size,
scalar: crate::Scalar::BOOL,
pointer_space: None,
}));
let coords_numeric_type = match coordinates.size {
Some(size) => NumericType::Vector {
size,
scalar: crate::Scalar::BOOL,
},
None => NumericType::Scalar(crate::Scalar::BOOL),
};
let coords_bool_type_id =
self.get_type_id(LookupType::Local(LocalType::Numeric(coords_numeric_type)));
let coords_conds_id = self.gen_id();
selection.block().body.push(Instruction::binary(
spirv::Op::ULessThan,
@ -838,11 +835,10 @@ impl<'w> BlockContext<'w> {
_ => false,
};
let sample_result_type_id = if needs_sub_access {
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: Some(crate::VectorSize::Quad),
self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
size: crate::VectorSize::Quad,
scalar: crate::Scalar::F32,
pointer_space: None,
}))
})))
} else {
result_type_id
};
@ -1038,11 +1034,16 @@ impl<'w> BlockContext<'w> {
4 => Some(crate::VectorSize::Quad),
_ => None,
};
let extended_size_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size,
scalar: crate::Scalar::U32,
pointer_space: None,
}));
let vector_numeric_type = match vector_size {
Some(size) => NumericType::Vector {
size,
scalar: crate::Scalar::U32,
},
None => NumericType::Scalar(crate::Scalar::U32),
};
let extended_size_type_id =
self.get_type_id(LookupType::Local(LocalType::Numeric(vector_numeric_type)));
let (query_op, level_id) = match class {
Ic::Sampled { multi: true, .. }
@ -1108,11 +1109,11 @@ impl<'w> BlockContext<'w> {
Id::D2 | Id::Cube => crate::VectorSize::Tri,
Id::D3 => crate::VectorSize::Quad,
};
let extended_size_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: Some(vec_size),
scalar: crate::Scalar::U32,
pointer_space: None,
}));
let extended_size_type_id =
self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
size: vec_size,
scalar: crate::Scalar::U32,
})));
let id_extended = self.gen_id();
let mut inst = Instruction::image_query(
spirv::Op::ImageQuerySizeLod,

View File

@ -231,6 +231,21 @@ impl LocalImageType {
}
}
/// A numeric type, for use in [`LocalType`].
#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)]
enum NumericType {
Scalar(crate::Scalar),
Vector {
size: crate::VectorSize,
scalar: crate::Scalar,
},
Matrix {
columns: crate::VectorSize,
rows: crate::VectorSize,
scalar: crate::Scalar,
},
}
/// A SPIR-V type constructed during code generation.
///
/// This is the variant of [`LookupType`] used to represent types that might not
@ -276,19 +291,11 @@ impl LocalImageType {
/// [`TypeInner`]: crate::TypeInner
#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)]
enum LocalType {
/// A scalar, vector, or pointer to one of those.
Value {
/// If `None`, this represents a scalar type. If `Some`, this represents
/// a vector type of the given size.
vector_size: Option<crate::VectorSize>,
scalar: crate::Scalar,
pointer_space: Option<spirv::StorageClass>,
},
/// A matrix of floating-point values.
Matrix {
columns: crate::VectorSize,
rows: crate::VectorSize,
width: crate::Bytes,
/// A numeric type.
Numeric(NumericType),
LocalPointer {
base: NumericType,
class: spirv::StorageClass,
},
Pointer {
base: Handle<crate::Type>,
@ -361,38 +368,39 @@ impl LocalType {
fn from_inner(inner: &crate::TypeInner) -> Option<Self> {
Some(match *inner {
crate::TypeInner::Scalar(scalar) | crate::TypeInner::Atomic(scalar) => {
LocalType::Value {
vector_size: None,
scalar,
pointer_space: None,
}
LocalType::Numeric(NumericType::Scalar(scalar))
}
crate::TypeInner::Vector { size, scalar } => {
LocalType::Numeric(NumericType::Vector { size, scalar })
}
crate::TypeInner::Vector { size, scalar } => LocalType::Value {
vector_size: Some(size),
scalar,
pointer_space: None,
},
crate::TypeInner::Matrix {
columns,
rows,
scalar,
} => LocalType::Matrix {
} => LocalType::Numeric(NumericType::Matrix {
columns,
rows,
width: scalar.width,
},
scalar,
}),
crate::TypeInner::Pointer { base, space } => LocalType::Pointer {
base,
class: helpers::map_storage_class(space),
},
crate::TypeInner::ValuePointer {
size,
size: Some(size),
scalar,
space,
} => LocalType::Value {
vector_size: size,
} => LocalType::LocalPointer {
base: NumericType::Vector { size, scalar },
class: helpers::map_storage_class(space),
},
crate::TypeInner::ValuePointer {
size: None,
scalar,
pointer_space: Some(helpers::map_storage_class(space)),
space,
} => LocalType::LocalPointer {
base: NumericType::Scalar(scalar),
class: helpers::map_storage_class(space),
},
crate::TypeInner::Image {
dim,

View File

@ -2,7 +2,7 @@
Generating SPIR-V for ray query operations.
*/
use super::{Block, BlockContext, Instruction, LocalType, LookupType};
use super::{Block, BlockContext, Instruction, LocalType, LookupType, NumericType};
use crate::arena::Handle;
impl<'w> BlockContext<'w> {
@ -22,11 +22,9 @@ impl<'w> BlockContext<'w> {
let desc_id = self.cached[descriptor];
let acc_struct_id = self.get_handle_id(acceleration_structure);
let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar: crate::Scalar::U32,
pointer_space: None,
}));
let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(crate::Scalar::U32),
)));
let ray_flags_id = self.gen_id();
block.body.push(Instruction::composite_extract(
flag_type_id,
@ -42,11 +40,9 @@ impl<'w> BlockContext<'w> {
&[1],
));
let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar: crate::Scalar::F32,
pointer_space: None,
}));
let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(crate::Scalar::F32),
)));
let tmin_id = self.gen_id();
block.body.push(Instruction::composite_extract(
scalar_type_id,
@ -62,11 +58,11 @@ impl<'w> BlockContext<'w> {
&[3],
));
let vector_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: Some(crate::VectorSize::Tri),
scalar: crate::Scalar::F32,
pointer_space: None,
}));
let vector_type_id =
self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
size: crate::VectorSize::Tri,
scalar: crate::Scalar::F32,
})));
let ray_origin_id = self.gen_id();
block.body.push(Instruction::composite_extract(
vector_type_id,
@ -116,11 +112,9 @@ impl<'w> BlockContext<'w> {
spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _,
));
let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar: crate::Scalar::U32,
pointer_space: None,
}));
let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(crate::Scalar::U32),
)));
let kind_id = self.gen_id();
block.body.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionTypeKHR,
@ -170,11 +164,9 @@ impl<'w> BlockContext<'w> {
intersection_id,
));
let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar: crate::Scalar::F32,
pointer_space: None,
}));
let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(crate::Scalar::F32),
)));
let t_id = self.gen_id();
block.body.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionTKHR,
@ -184,11 +176,11 @@ impl<'w> BlockContext<'w> {
intersection_id,
));
let barycentrics_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: Some(crate::VectorSize::Bi),
scalar: crate::Scalar::F32,
pointer_space: None,
}));
let barycentrics_type_id =
self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
size: crate::VectorSize::Bi,
scalar: crate::Scalar::F32,
})));
let barycentrics_id = self.gen_id();
block.body.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionBarycentricsKHR,
@ -198,11 +190,9 @@ impl<'w> BlockContext<'w> {
intersection_id,
));
let bool_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar: crate::Scalar::BOOL,
pointer_space: None,
}));
let bool_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(crate::Scalar::BOOL),
)));
let front_face_id = self.gen_id();
block.body.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionFrontFaceKHR,
@ -212,11 +202,12 @@ impl<'w> BlockContext<'w> {
intersection_id,
));
let transform_type_id = self.get_type_id(LookupType::Local(LocalType::Matrix {
columns: crate::VectorSize::Quad,
rows: crate::VectorSize::Tri,
width: 4,
}));
let transform_type_id =
self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Matrix {
columns: crate::VectorSize::Quad,
rows: crate::VectorSize::Tri,
scalar: crate::Scalar::F32,
})));
let object_to_world_id = self.gen_id();
block.body.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionObjectToWorldKHR,

View File

@ -1,4 +1,4 @@
use super::{Block, BlockContext, Error, Instruction};
use super::{Block, BlockContext, Error, Instruction, NumericType};
use crate::{
arena::Handle,
back::spv::{LocalType, LookupType},
@ -16,11 +16,11 @@ impl<'w> BlockContext<'w> {
"GroupNonUniformBallot",
&[spirv::Capability::GroupNonUniformBallot],
)?;
let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: Some(crate::VectorSize::Quad),
scalar: crate::Scalar::U32,
pointer_space: None,
}));
let vec4_u32_type_id =
self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
size: crate::VectorSize::Quad,
scalar: crate::Scalar::U32,
})));
let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
let predicate = if let Some(predicate) = *predicate {
self.cached[predicate]

View File

@ -3,8 +3,8 @@ use super::{
helpers::{contains_builtin, global_needs_wrapper, map_storage_class},
Block, BlockContext, CachedConstant, CachedExpressions, DebugInfo, EntryPointContext, Error,
Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction, LocalType, LocalVariable,
LogicalLayout, LookupFunctionType, LookupType, Options, PhysicalLayout, PipelineOptions,
ResultMember, Writer, WriterFlags, BITS_PER_BYTE,
LogicalLayout, LookupFunctionType, LookupType, NumericType, Options, PhysicalLayout,
PipelineOptions, ResultMember, Writer, WriterFlags, BITS_PER_BYTE,
};
use crate::{
arena::{Handle, HandleVec, UniqueArena},
@ -291,86 +291,55 @@ impl Writer {
}
pub(super) fn get_uint_type_id(&mut self) -> Word {
let local_type = LocalType::Value {
vector_size: None,
scalar: crate::Scalar::U32,
pointer_space: None,
};
let local_type = LocalType::Numeric(NumericType::Scalar(crate::Scalar::U32));
self.get_type_id(local_type.into())
}
pub(super) fn get_float_type_id(&mut self) -> Word {
let local_type = LocalType::Value {
vector_size: None,
scalar: crate::Scalar::F32,
pointer_space: None,
};
let local_type = LocalType::Numeric(NumericType::Scalar(crate::Scalar::F32));
self.get_type_id(local_type.into())
}
pub(super) fn get_uint3_type_id(&mut self) -> Word {
let local_type = LocalType::Value {
vector_size: Some(crate::VectorSize::Tri),
let local_type = LocalType::Numeric(NumericType::Vector {
size: crate::VectorSize::Tri,
scalar: crate::Scalar::U32,
pointer_space: None,
};
});
self.get_type_id(local_type.into())
}
pub(super) fn get_float_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
let lookup_type = LookupType::Local(LocalType::Value {
vector_size: None,
scalar: crate::Scalar::F32,
pointer_space: Some(class),
});
if let Some(&id) = self.lookup_type.get(&lookup_type) {
id
} else {
let id = self.id_gen.next();
let ty_id = self.get_float_type_id();
let instruction = Instruction::type_pointer(id, class, ty_id);
instruction.to_words(&mut self.logical_layout.declarations);
self.lookup_type.insert(lookup_type, id);
id
}
}
pub(super) fn get_uint3_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
let lookup_type = LookupType::Local(LocalType::Value {
vector_size: Some(crate::VectorSize::Tri),
scalar: crate::Scalar::U32,
pointer_space: Some(class),
});
if let Some(&id) = self.lookup_type.get(&lookup_type) {
id
} else {
let id = self.id_gen.next();
let ty_id = self.get_uint3_type_id();
let instruction = Instruction::type_pointer(id, class, ty_id);
instruction.to_words(&mut self.logical_layout.declarations);
self.lookup_type.insert(lookup_type, id);
id
}
}
pub(super) fn get_bool_type_id(&mut self) -> Word {
let local_type = LocalType::Value {
vector_size: None,
scalar: crate::Scalar::BOOL,
pointer_space: None,
let local_type = LocalType::LocalPointer {
base: NumericType::Scalar(crate::Scalar::F32),
class,
};
self.get_type_id(local_type.into())
}
pub(super) fn get_bool3_type_id(&mut self) -> Word {
let local_type = LocalType::Value {
vector_size: Some(crate::VectorSize::Tri),
scalar: crate::Scalar::BOOL,
pointer_space: None,
pub(super) fn get_uint3_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
let local_type = LocalType::LocalPointer {
base: NumericType::Vector {
size: crate::VectorSize::Tri,
scalar: crate::Scalar::U32,
},
class,
};
self.get_type_id(local_type.into())
}
pub(super) fn get_bool_type_id(&mut self) -> Word {
let local_type = LocalType::Numeric(NumericType::Scalar(crate::Scalar::BOOL));
self.get_type_id(local_type.into())
}
pub(super) fn get_bool3_type_id(&mut self) -> Word {
let local_type = LocalType::Numeric(NumericType::Vector {
size: crate::VectorSize::Tri,
scalar: crate::Scalar::BOOL,
});
self.get_type_id(local_type.into())
}
pub(super) fn decorate(&mut self, id: Word, decoration: spirv::Decoration, operands: &[Word]) {
self.annotations
.push(Instruction::decorate(id, decoration, operands));
@ -935,62 +904,50 @@ impl Writer {
Ok(())
}
fn write_numeric_type_declaration_local(&mut self, id: Word, numeric: NumericType) {
let instruction =
match numeric {
NumericType::Scalar(scalar) => self.make_scalar(id, scalar),
NumericType::Vector { size, scalar } => {
let scalar_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(scalar),
)));
Instruction::type_vector(id, scalar_id, size)
}
NumericType::Matrix {
columns,
rows,
scalar,
} => {
let column_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Vector { size: rows, scalar },
)));
Instruction::type_matrix(id, column_id, columns)
}
};
instruction.to_words(&mut self.logical_layout.declarations);
}
fn write_type_declaration_local(&mut self, id: Word, local_ty: LocalType) {
let instruction = match local_ty {
LocalType::Value {
vector_size: None,
scalar,
pointer_space: None,
} => self.make_scalar(id, scalar),
LocalType::Value {
vector_size: Some(size),
scalar,
pointer_space: None,
} => {
let scalar_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar,
pointer_space: None,
}));
Instruction::type_vector(id, scalar_id, size)
LocalType::Numeric(numeric) => {
self.write_numeric_type_declaration_local(id, numeric);
return;
}
LocalType::Matrix {
columns,
rows,
width,
} => {
let vector_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: Some(rows),
scalar: crate::Scalar::float(width),
pointer_space: None,
}));
Instruction::type_matrix(id, vector_id, columns)
LocalType::LocalPointer { base, class } => {
let base_id = self.get_type_id(LookupType::Local(LocalType::Numeric(base)));
Instruction::type_pointer(id, class, base_id)
}
LocalType::Pointer { base, class } => {
let type_id = self.get_type_id(LookupType::Handle(base));
Instruction::type_pointer(id, class, type_id)
}
LocalType::Value {
vector_size,
scalar,
pointer_space: Some(class),
} => {
let type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size,
scalar,
pointer_space: None,
}));
Instruction::type_pointer(id, class, type_id)
}
LocalType::Image(image) => {
let local_type = LocalType::Value {
vector_size: None,
scalar: crate::Scalar {
kind: image.sampled_type,
width: 4,
},
pointer_space: None,
};
let local_type = LocalType::Numeric(NumericType::Scalar(crate::Scalar {
kind: image.sampled_type,
width: 4,
}));
let type_id = self.get_type_id(LookupType::Local(local_type));
Instruction::type_image(id, type_id, image.dim, image.flags, image.image_format)
}
@ -1224,11 +1181,9 @@ impl Writer {
self.debugs.push(Instruction::name(id, name));
}
}
let type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
scalar: value.scalar(),
pointer_space: None,
}));
let type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Scalar(
value.scalar(),
))));
let instruction = match *value {
crate::Literal::F64(value) => {
let bits = value.to_bits();