Subgroup Operations (#5301)

Co-authored-by: Jacob Hughes <j@distanthills.org>
Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com>
Co-authored-by: atlas dostal <rodol@rivalrebels.com>
This commit is contained in:
Alexander Meißner 2024-04-17 21:25:52 +02:00 committed by GitHub
parent 0dc9dd6bec
commit ea77d5674d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
64 changed files with 3328 additions and 70 deletions

View File

@ -138,6 +138,7 @@ Bottom level categories:
### Bug Fixes
#### General
- Add `SUBGROUP`, `SUBGROUP_VERTEX`, and `SUBGROUP_BARRIER` features. By @exrook and @lichtso in [#5301](https://github.com/gfx-rs/wgpu/pull/5301)
- Fix `serde` feature not compiling for `wgpu-types`. By @KirmesBude in [#5149](https://github.com/gfx-rs/wgpu/pull/5149)
- Fix the validation of vertex and index ranges. By @nical in [#5144](https://github.com/gfx-rs/wgpu/pull/5144) and [#5156](https://github.com/gfx-rs/wgpu/pull/5156)
- Fix panic when creating a surface while no backend is available. By @wumpf in [#5166](https://github.com/gfx-rs/wgpu/pull/5166)

View File

@ -424,6 +424,8 @@ fn run() -> Result<(), Box<dyn std::error::Error>> {
// Validate the IR before compaction.
let info = match naga::valid::Validator::new(params.validation_flags, validation_caps)
.subgroup_stages(naga::valid::ShaderStages::all())
.subgroup_operations(naga::valid::SubgroupOperationSet::all())
.validate(&module)
{
Ok(info) => Some(info),
@ -760,6 +762,8 @@ fn bulk_validate(args: Args, params: &Parameters) -> Result<(), Box<dyn std::err
let mut validator =
naga::valid::Validator::new(params.validation_flags, naga::valid::Capabilities::all());
validator.subgroup_stages(naga::valid::ShaderStages::all());
validator.subgroup_operations(naga::valid::SubgroupOperationSet::all());
if let Err(error) = validator.validate(&module) {
invalid.push(input_path.clone());

View File

@ -279,6 +279,94 @@ impl StatementGraph {
crate::RayQueryFunction::Terminate => "RayQueryTerminate",
}
}
S::SubgroupBallot { result, predicate } => {
if let Some(predicate) = predicate {
self.dependencies.push((id, predicate, "predicate"));
}
self.emits.push((id, result));
"SubgroupBallot"
}
S::SubgroupCollectiveOperation {
op,
collective_op,
argument,
result,
} => {
self.dependencies.push((id, argument, "arg"));
self.emits.push((id, result));
match (collective_op, op) {
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
"SubgroupAll"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
"SubgroupAny"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
"SubgroupAdd"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
"SubgroupMul"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
"SubgroupMax"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
"SubgroupMin"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
"SubgroupAnd"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
"SubgroupOr"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
"SubgroupXor"
}
(
crate::CollectiveOperation::ExclusiveScan,
crate::SubgroupOperation::Add,
) => "SubgroupExclusiveAdd",
(
crate::CollectiveOperation::ExclusiveScan,
crate::SubgroupOperation::Mul,
) => "SubgroupExclusiveMul",
(
crate::CollectiveOperation::InclusiveScan,
crate::SubgroupOperation::Add,
) => "SubgroupInclusiveAdd",
(
crate::CollectiveOperation::InclusiveScan,
crate::SubgroupOperation::Mul,
) => "SubgroupInclusiveMul",
_ => unimplemented!(),
}
}
S::SubgroupGather {
mode,
argument,
result,
} => {
match mode {
crate::GatherMode::BroadcastFirst => {}
crate::GatherMode::Broadcast(index)
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
self.dependencies.push((id, index, "index"))
}
}
self.dependencies.push((id, argument, "arg"));
self.emits.push((id, result));
match mode {
crate::GatherMode::BroadcastFirst => "SubgroupBroadcastFirst",
crate::GatherMode::Broadcast(_) => "SubgroupBroadcast",
crate::GatherMode::Shuffle(_) => "SubgroupShuffle",
crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown",
crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp",
crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor",
}
}
};
// Set the last node to the merge node
last_node = merge_id;
@ -587,6 +675,8 @@ fn write_function_expressions(
let ty = if committed { "Committed" } else { "Candidate" };
(format!("rayQueryGet{}Intersection", ty).into(), 4)
}
E::SubgroupBallotResult => ("SubgroupBallotResult".into(), 4),
E::SubgroupOperationResult { .. } => ("SubgroupOperationResult".into(), 4),
};
// give uniform expressions an outline

View File

@ -50,6 +50,8 @@ bitflags::bitflags! {
const INSTANCE_INDEX = 1 << 22;
/// Sample specific LODs of cube / array shadow textures
const TEXTURE_SHADOW_LOD = 1 << 23;
/// Subgroup operations
const SUBGROUP_OPERATIONS = 1 << 24;
}
}
@ -117,6 +119,7 @@ impl FeaturesManager {
check_feature!(SAMPLE_VARIABLES, 400, 300);
check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310);
check_feature!(DUAL_SOURCE_BLENDING, 330, 300 /* with extension */);
check_feature!(SUBGROUP_OPERATIONS, 430, 310);
match version {
Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300),
_ => check_feature!(MULTI_VIEW, 140, 310),
@ -259,6 +262,22 @@ impl FeaturesManager {
writeln!(out, "#extension GL_EXT_texture_shadow_lod : require")?;
}
if self.0.contains(Features::SUBGROUP_OPERATIONS) {
// https://registry.khronos.org/OpenGL/extensions/KHR/KHR_shader_subgroup.txt
writeln!(out, "#extension GL_KHR_shader_subgroup_basic : require")?;
writeln!(out, "#extension GL_KHR_shader_subgroup_vote : require")?;
writeln!(
out,
"#extension GL_KHR_shader_subgroup_arithmetic : require"
)?;
writeln!(out, "#extension GL_KHR_shader_subgroup_ballot : require")?;
writeln!(out, "#extension GL_KHR_shader_subgroup_shuffle : require")?;
writeln!(
out,
"#extension GL_KHR_shader_subgroup_shuffle_relative : require"
)?;
}
Ok(())
}
}
@ -518,6 +537,10 @@ impl<'a, W> Writer<'a, W> {
}
}
}
Expression::SubgroupBallotResult |
Expression::SubgroupOperationResult { .. } => {
features.request(Features::SUBGROUP_OPERATIONS)
}
_ => {}
}
}

View File

@ -2390,6 +2390,125 @@ impl<'a, W: Write> Writer<'a, W> {
writeln!(self.out, ");")?;
}
Statement::RayQuery { .. } => unreachable!(),
Statement::SubgroupBallot { result, predicate } => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
self.write_value_type(res_ty)?;
write!(self.out, " {res_name} = ")?;
self.named_expressions.insert(result, res_name);
write!(self.out, "subgroupBallot(")?;
match predicate {
Some(predicate) => self.write_expr(predicate, ctx)?,
None => write!(self.out, "true")?,
}
writeln!(self.out, ");")?;
}
Statement::SubgroupCollectiveOperation {
op,
collective_op,
argument,
result,
} => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
self.write_value_type(res_ty)?;
write!(self.out, " {res_name} = ")?;
self.named_expressions.insert(result, res_name);
match (collective_op, op) {
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
write!(self.out, "subgroupAll(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
write!(self.out, "subgroupAny(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
write!(self.out, "subgroupAdd(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
write!(self.out, "subgroupMul(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
write!(self.out, "subgroupMax(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
write!(self.out, "subgroupMin(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
write!(self.out, "subgroupAnd(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
write!(self.out, "subgroupOr(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
write!(self.out, "subgroupXor(")?
}
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
write!(self.out, "subgroupExclusiveAdd(")?
}
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
write!(self.out, "subgroupExclusiveMul(")?
}
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
write!(self.out, "subgroupInclusiveAdd(")?
}
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
write!(self.out, "subgroupInclusiveMul(")?
}
_ => unimplemented!(),
}
self.write_expr(argument, ctx)?;
writeln!(self.out, ");")?;
}
Statement::SubgroupGather {
mode,
argument,
result,
} => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
self.write_value_type(res_ty)?;
write!(self.out, " {res_name} = ")?;
self.named_expressions.insert(result, res_name);
match mode {
crate::GatherMode::BroadcastFirst => {
write!(self.out, "subgroupBroadcastFirst(")?;
}
crate::GatherMode::Broadcast(_) => {
write!(self.out, "subgroupBroadcast(")?;
}
crate::GatherMode::Shuffle(_) => {
write!(self.out, "subgroupShuffle(")?;
}
crate::GatherMode::ShuffleDown(_) => {
write!(self.out, "subgroupShuffleDown(")?;
}
crate::GatherMode::ShuffleUp(_) => {
write!(self.out, "subgroupShuffleUp(")?;
}
crate::GatherMode::ShuffleXor(_) => {
write!(self.out, "subgroupShuffleXor(")?;
}
}
self.write_expr(argument, ctx)?;
match mode {
crate::GatherMode::BroadcastFirst => {}
crate::GatherMode::Broadcast(index)
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
write!(self.out, ", ")?;
self.write_expr(index, ctx)?;
}
}
writeln!(self.out, ");")?;
}
}
Ok(())
@ -3658,7 +3777,9 @@ impl<'a, W: Write> Writer<'a, W> {
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::RayQueryProceedResult
| Expression::WorkGroupUniformLoadResult { .. } => unreachable!(),
| Expression::WorkGroupUniformLoadResult { .. }
| Expression::SubgroupOperationResult { .. }
| Expression::SubgroupBallotResult => unreachable!(),
// `ArrayLength` is written as `expr.length()` and we convert it to a uint
Expression::ArrayLength(expr) => {
write!(self.out, "uint(")?;
@ -4227,6 +4348,9 @@ impl<'a, W: Write> Writer<'a, W> {
if flags.contains(crate::Barrier::WORK_GROUP) {
writeln!(self.out, "{level}memoryBarrierShared();")?;
}
if flags.contains(crate::Barrier::SUB_GROUP) {
writeln!(self.out, "{level}subgroupMemoryBarrier();")?;
}
writeln!(self.out, "{level}barrier();")?;
Ok(())
}
@ -4496,6 +4620,11 @@ const fn glsl_built_in(built_in: crate::BuiltIn, options: VaryingOptions) -> &'s
Bi::WorkGroupId => "gl_WorkGroupID",
Bi::WorkGroupSize => "gl_WorkGroupSize",
Bi::NumWorkGroups => "gl_NumWorkGroups",
// subgroup
Bi::NumSubgroups => "gl_NumSubgroups",
Bi::SubgroupId => "gl_SubgroupID",
Bi::SubgroupSize => "gl_SubgroupSize",
Bi::SubgroupInvocationId => "gl_SubgroupInvocationID",
}
}

View File

@ -179,6 +179,11 @@ impl crate::BuiltIn {
// to this field will get replaced with references to `SPECIAL_CBUF_VAR`
// in `Writer::write_expr`.
Self::NumWorkGroups => "SV_GroupID",
// These builtins map to functions
Self::SubgroupSize
| Self::SubgroupInvocationId
| Self::NumSubgroups
| Self::SubgroupId => unreachable!(),
Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => {
return Err(Error::Unimplemented(format!("builtin {self:?}")))
}

View File

@ -77,6 +77,19 @@ enum Io {
Output,
}
/// Returns `true` when `binding` is one of the subgroup built-ins
/// (`SubgroupSize`, `SubgroupInvocationId`, `NumSubgroups`, `SubgroupId`),
/// which have no HLSL input semantic and must be emulated by the writer.
const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
    matches!(
        *binding,
        Some(crate::Binding::BuiltIn(
            crate::BuiltIn::SubgroupSize
                | crate::BuiltIn::SubgroupInvocationId
                | crate::BuiltIn::NumSubgroups
                | crate::BuiltIn::SubgroupId
        ))
    )
}
impl<'a, W: fmt::Write> super::Writer<'a, W> {
pub fn new(out: W, options: &'a Options) -> Self {
Self {
@ -161,6 +174,19 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}
}
for statement in func.body.iter() {
match *statement {
crate::Statement::SubgroupCollectiveOperation {
op: _,
collective_op: crate::CollectiveOperation::InclusiveScan,
argument,
result: _,
} => {
self.need_bake_expressions.insert(argument);
}
_ => {}
}
}
}
pub fn write(
@ -401,31 +427,32 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// if they are struct, so that the `stage` argument here could be omitted.
fn write_semantic(
&mut self,
binding: &crate::Binding,
binding: &Option<crate::Binding>,
stage: Option<(ShaderStage, Io)>,
) -> BackendResult {
match *binding {
crate::Binding::BuiltIn(builtin) => {
Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
let builtin_str = builtin.to_hlsl_str()?;
write!(self.out, " : {builtin_str}")?;
}
crate::Binding::Location {
Some(crate::Binding::Location {
second_blend_source: true,
..
} => {
}) => {
write!(self.out, " : SV_Target1")?;
}
crate::Binding::Location {
Some(crate::Binding::Location {
location,
second_blend_source: false,
..
} => {
}) => {
if stage == Some((crate::ShaderStage::Fragment, Io::Output)) {
write!(self.out, " : SV_Target{location}")?;
} else {
write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
}
}
_ => {}
}
Ok(())
@ -446,17 +473,30 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, "struct {struct_name}")?;
writeln!(self.out, " {{")?;
for m in members.iter() {
if is_subgroup_builtin_binding(&m.binding) {
continue;
}
write!(self.out, "{}", back::INDENT)?;
if let Some(ref binding) = m.binding {
self.write_modifier(binding)?;
}
self.write_type(module, m.ty)?;
write!(self.out, " {}", &m.name)?;
if let Some(ref binding) = m.binding {
self.write_semantic(binding, Some(shader_stage))?;
}
self.write_semantic(&m.binding, Some(shader_stage))?;
writeln!(self.out, ";")?;
}
if members.iter().any(|arg| {
matches!(
arg.binding,
Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId))
)
}) {
writeln!(
self.out,
"{}uint __local_invocation_index : SV_GroupIndex;",
back::INDENT
)?;
}
writeln!(self.out, "}};")?;
writeln!(self.out)?;
@ -557,8 +597,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
/// Writes special interface structures for an entry point. The special structures have
/// all the fields flattened into them and sorted by binding. They are only needed for
/// VS outputs and FS inputs, so that these interfaces match.
/// all the fields flattened into them and sorted by binding. They are needed to emulate
/// subgroup built-ins and to make the interfaces between VS outputs and FS inputs match.
fn write_ep_interface(
&mut self,
module: &Module,
@ -567,7 +607,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
ep_name: &str,
) -> Result<EntryPointInterface, Error> {
Ok(EntryPointInterface {
input: if !func.arguments.is_empty() && stage == ShaderStage::Fragment {
input: if !func.arguments.is_empty()
&& (stage == ShaderStage::Fragment
|| func
.arguments
.iter()
.any(|arg| is_subgroup_builtin_binding(&arg.binding)))
{
Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
} else {
None
@ -581,6 +627,38 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
})
}
/// Writes the initialization expression for a single flattened entry-point
/// argument member.
///
/// The subgroup built-ins carry no HLSL input semantic, so their values are
/// synthesized from wave intrinsics here; every other member is read straight
/// from the generated input struct (`ep_input`).
fn write_ep_argument_initialization(
    &mut self,
    ep: &crate::EntryPoint,
    ep_input: &EntryPointBinding,
    fake_member: &EpStructMember,
) -> BackendResult {
    match fake_member.binding {
        // Lanes per wave.
        Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
            write!(self.out, "WaveGetLaneCount()")?;
        }
        // This thread's lane index within its wave.
        Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
            write!(self.out, "WaveGetLaneIndex()")?;
        }
        // ceil(workgroup thread count / wave size): the thread count is a
        // compile-time constant while the wave size is queried at runtime.
        Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => {
            let thread_count =
                ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2];
            write!(
                self.out,
                "({thread_count}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()"
            )?;
        }
        // Derived from the flat `__local_invocation_index : SV_GroupIndex`
        // member injected into the input struct.
        Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
            write!(
                self.out,
                "{}.__local_invocation_index / WaveGetLaneCount()",
                ep_input.arg_name
            )?;
        }
        // Everything else is passed through from the input struct verbatim.
        _ => {
            write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
        }
    }
    Ok(())
}
/// Write an entry point preface that initializes the arguments as specified in IR.
fn write_ep_arguments_initialization(
&mut self,
@ -588,6 +666,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
func: &crate::Function,
ep_index: u16,
) -> BackendResult {
let ep = &module.entry_points[ep_index as usize];
let ep_input = match self.entry_point_io[ep_index as usize].input.take() {
Some(ep_input) => ep_input,
None => return Ok(()),
@ -601,8 +680,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
match module.types[arg.ty].inner {
TypeInner::Array { base, size, .. } => {
self.write_array_size(module, base, size)?;
let fake_member = fake_iter.next().unwrap();
writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?;
write!(self.out, " = ")?;
self.write_ep_argument_initialization(
ep,
&ep_input,
fake_iter.next().unwrap(),
)?;
writeln!(self.out, ";")?;
}
TypeInner::Struct { ref members, .. } => {
write!(self.out, " = {{ ")?;
@ -610,14 +694,22 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
if index != 0 {
write!(self.out, ", ")?;
}
let fake_member = fake_iter.next().unwrap();
write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
self.write_ep_argument_initialization(
ep,
&ep_input,
fake_iter.next().unwrap(),
)?;
}
writeln!(self.out, " }};")?;
}
_ => {
let fake_member = fake_iter.next().unwrap();
writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?;
write!(self.out, " = ")?;
self.write_ep_argument_initialization(
ep,
&ep_input,
fake_iter.next().unwrap(),
)?;
writeln!(self.out, ";")?;
}
}
}
@ -932,9 +1024,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}
if let Some(ref binding) = member.binding {
self.write_semantic(binding, shader_stage)?;
};
self.write_semantic(&member.binding, shader_stage)?;
writeln!(self.out, ";")?;
}
@ -1147,7 +1237,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
back::FunctionType::EntryPoint(ep_index) => {
if let Some(ref ep_input) = self.entry_point_io[ep_index as usize].input {
write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name,)?;
write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
} else {
let stage = module.entry_points[ep_index as usize].stage;
for (index, arg) in func.arguments.iter().enumerate() {
@ -1164,17 +1254,16 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_array_size(module, base, size)?;
}
if let Some(ref binding) = arg.binding {
self.write_semantic(binding, Some((stage, Io::Input)))?;
}
self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
}
if need_workgroup_variables_initialization {
if !func.arguments.is_empty() {
write!(self.out, ", ")?;
}
write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
}
if need_workgroup_variables_initialization {
if self.entry_point_io[ep_index as usize].input.is_some()
|| !func.arguments.is_empty()
{
write!(self.out, ", ")?;
}
write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
}
}
}
@ -1184,11 +1273,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// Write semantic if it present
if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
let stage = module.entry_points[index as usize].stage;
if let Some(crate::FunctionResult {
binding: Some(ref binding),
..
}) = func.result
{
if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
self.write_semantic(binding, Some((stage, Io::Output)))?;
}
}
@ -1988,6 +2073,129 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out, "{level}}}")?
}
Statement::RayQuery { .. } => unreachable!(),
Statement::SubgroupBallot { result, predicate } => {
write!(self.out, "{level}")?;
let name = format!("{}{}", back::BAKE_PREFIX, result.index());
write!(self.out, "const uint4 {name} = ")?;
self.named_expressions.insert(result, name);
write!(self.out, "WaveActiveBallot(")?;
match predicate {
Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
None => write!(self.out, "true")?,
}
writeln!(self.out, ");")?;
}
Statement::SubgroupCollectiveOperation {
op,
collective_op,
argument,
result,
} => {
write!(self.out, "{level}")?;
write!(self.out, "const ")?;
let name = format!("{}{}", back::BAKE_PREFIX, result.index());
match func_ctx.info[result].ty {
proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
proc::TypeResolution::Value(ref value) => {
self.write_value_type(module, value)?
}
};
write!(self.out, " {name} = ")?;
self.named_expressions.insert(result, name);
match (collective_op, op) {
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
write!(self.out, "WaveActiveAllTrue(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
write!(self.out, "WaveActiveAnyTrue(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
write!(self.out, "WaveActiveSum(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
write!(self.out, "WaveActiveProduct(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
write!(self.out, "WaveActiveMax(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
write!(self.out, "WaveActiveMin(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
write!(self.out, "WaveActiveBitAnd(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
write!(self.out, "WaveActiveBitOr(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
write!(self.out, "WaveActiveBitXor(")?
}
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
write!(self.out, "WavePrefixSum(")?
}
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
write!(self.out, "WavePrefixProduct(")?
}
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
self.write_expr(module, argument, func_ctx)?;
write!(self.out, " + WavePrefixSum(")?;
}
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
self.write_expr(module, argument, func_ctx)?;
write!(self.out, " * WavePrefixProduct(")?;
}
_ => unimplemented!(),
}
self.write_expr(module, argument, func_ctx)?;
writeln!(self.out, ");")?;
}
Statement::SubgroupGather {
mode,
argument,
result,
} => {
write!(self.out, "{level}")?;
write!(self.out, "const ")?;
let name = format!("{}{}", back::BAKE_PREFIX, result.index());
match func_ctx.info[result].ty {
proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
proc::TypeResolution::Value(ref value) => {
self.write_value_type(module, value)?
}
};
write!(self.out, " {name} = ")?;
self.named_expressions.insert(result, name);
if matches!(mode, crate::GatherMode::BroadcastFirst) {
write!(self.out, "WaveReadLaneFirst(")?;
self.write_expr(module, argument, func_ctx)?;
} else {
write!(self.out, "WaveReadLaneAt(")?;
self.write_expr(module, argument, func_ctx)?;
write!(self.out, ", ")?;
match mode {
crate::GatherMode::BroadcastFirst => unreachable!(),
crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => {
self.write_expr(module, index, func_ctx)?;
}
crate::GatherMode::ShuffleDown(index) => {
write!(self.out, "WaveGetLaneIndex() + ")?;
self.write_expr(module, index, func_ctx)?;
}
crate::GatherMode::ShuffleUp(index) => {
write!(self.out, "WaveGetLaneIndex() - ")?;
self.write_expr(module, index, func_ctx)?;
}
crate::GatherMode::ShuffleXor(index) => {
write!(self.out, "WaveGetLaneIndex() ^ ")?;
self.write_expr(module, index, func_ctx)?;
}
}
}
writeln!(self.out, ");")?;
}
}
Ok(())
@ -3134,7 +3342,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::WorkGroupUniformLoadResult { .. }
| Expression::RayQueryProceedResult => {}
| Expression::RayQueryProceedResult
| Expression::SubgroupBallotResult
| Expression::SubgroupOperationResult { .. } => {}
}
if !closing_bracket.is_empty() {
@ -3201,6 +3411,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
if barrier.contains(crate::Barrier::WORK_GROUP) {
writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
}
if barrier.contains(crate::Barrier::SUB_GROUP) {
// Does not exist in DirectX
}
Ok(())
}
}

View File

@ -436,6 +436,11 @@ impl ResolvedBinding {
Bi::WorkGroupId => "threadgroup_position_in_grid",
Bi::WorkGroupSize => "dispatch_threads_per_threadgroup",
Bi::NumWorkGroups => "threadgroups_per_grid",
// subgroup
Bi::NumSubgroups => "simdgroups_per_threadgroup",
Bi::SubgroupId => "simdgroup_index_in_threadgroup",
Bi::SubgroupSize => "threads_per_simdgroup",
Bi::SubgroupInvocationId => "thread_index_in_simdgroup",
Bi::CullDistance | Bi::ViewIndex => {
return Err(Error::UnsupportedBuiltIn(built_in))
}

View File

@ -2042,6 +2042,8 @@ impl<W: Write> Writer<W> {
crate::Expression::CallResult(_)
| crate::Expression::AtomicResult { .. }
| crate::Expression::WorkGroupUniformLoadResult { .. }
| crate::Expression::SubgroupBallotResult
| crate::Expression::SubgroupOperationResult { .. }
| crate::Expression::RayQueryProceedResult => {
unreachable!()
}
@ -3145,6 +3147,121 @@ impl<W: Write> Writer<W> {
}
}
}
crate::Statement::SubgroupBallot { result, predicate } => {
write!(self.out, "{level}")?;
let name = self.namer.call("");
self.start_baking_expression(result, &context.expression, &name)?;
self.named_expressions.insert(result, name);
write!(self.out, "uint4((uint64_t){NAMESPACE}::simd_ballot(")?;
if let Some(predicate) = predicate {
self.put_expression(predicate, &context.expression, true)?;
} else {
write!(self.out, "true")?;
}
writeln!(self.out, "), 0, 0, 0);")?;
}
crate::Statement::SubgroupCollectiveOperation {
op,
collective_op,
argument,
result,
} => {
write!(self.out, "{level}")?;
let name = self.namer.call("");
self.start_baking_expression(result, &context.expression, &name)?;
self.named_expressions.insert(result, name);
match (collective_op, op) {
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
write!(self.out, "{NAMESPACE}::simd_all(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
write!(self.out, "{NAMESPACE}::simd_any(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
write!(self.out, "{NAMESPACE}::simd_sum(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
write!(self.out, "{NAMESPACE}::simd_product(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
write!(self.out, "{NAMESPACE}::simd_max(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
write!(self.out, "{NAMESPACE}::simd_min(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
write!(self.out, "{NAMESPACE}::simd_and(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
write!(self.out, "{NAMESPACE}::simd_or(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
write!(self.out, "{NAMESPACE}::simd_xor(")?
}
(
crate::CollectiveOperation::ExclusiveScan,
crate::SubgroupOperation::Add,
) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
(
crate::CollectiveOperation::ExclusiveScan,
crate::SubgroupOperation::Mul,
) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
(
crate::CollectiveOperation::InclusiveScan,
crate::SubgroupOperation::Add,
) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
(
crate::CollectiveOperation::InclusiveScan,
crate::SubgroupOperation::Mul,
) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
_ => unimplemented!(),
}
self.put_expression(argument, &context.expression, true)?;
writeln!(self.out, ");")?;
}
crate::Statement::SubgroupGather {
mode,
argument,
result,
} => {
write!(self.out, "{level}")?;
let name = self.namer.call("");
self.start_baking_expression(result, &context.expression, &name)?;
self.named_expressions.insert(result, name);
match mode {
crate::GatherMode::BroadcastFirst => {
write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
}
crate::GatherMode::Broadcast(_) => {
write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
}
crate::GatherMode::Shuffle(_) => {
write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
}
crate::GatherMode::ShuffleDown(_) => {
write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
}
crate::GatherMode::ShuffleUp(_) => {
write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
}
crate::GatherMode::ShuffleXor(_) => {
write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
}
}
self.put_expression(argument, &context.expression, true)?;
match mode {
crate::GatherMode::BroadcastFirst => {}
crate::GatherMode::Broadcast(index)
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
write!(self.out, ", ")?;
self.put_expression(index, &context.expression, true)?;
}
}
writeln!(self.out, ");")?;
}
}
}
@ -4492,6 +4609,12 @@ impl<W: Write> Writer<W> {
"{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
)?;
}
if flags.contains(crate::Barrier::SUB_GROUP) {
writeln!(
self.out,
"{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
)?;
}
Ok(())
}
}
@ -4762,8 +4885,8 @@ fn test_stack_size() {
}
let stack_size = addresses_end - addresses_start;
// check the size (in debug only)
// last observed macOS value: 19152 (CI)
if !(9000..=20000).contains(&stack_size) {
// last observed macOS value: 22256 (CI)
if !(15000..=25000).contains(&stack_size) {
panic!("`put_block` stack size {stack_size} has changed!");
}
}

View File

@ -522,7 +522,9 @@ fn adjust_expr(new_pos: &[Handle<Expression>], expr: &mut Expression) {
ty: _,
comparison: _,
}
| Expression::WorkGroupUniformLoadResult { ty: _ } => {}
| Expression::WorkGroupUniformLoadResult { ty: _ }
| Expression::SubgroupBallotResult
| Expression::SubgroupOperationResult { .. } => {}
}
}
@ -637,6 +639,41 @@ fn adjust_stmt(new_pos: &[Handle<Expression>], stmt: &mut Statement) {
adjust(pointer);
adjust(result);
}
Statement::SubgroupBallot {
ref mut result,
ref mut predicate,
} => {
if let Some(ref mut predicate) = *predicate {
adjust(predicate);
}
adjust(result);
}
Statement::SubgroupCollectiveOperation {
ref mut argument,
ref mut result,
..
} => {
adjust(argument);
adjust(result);
}
Statement::SubgroupGather {
ref mut mode,
ref mut argument,
ref mut result,
} => {
match *mode {
crate::GatherMode::BroadcastFirst => {}
crate::GatherMode::Broadcast(ref mut index)
| crate::GatherMode::Shuffle(ref mut index)
| crate::GatherMode::ShuffleDown(ref mut index)
| crate::GatherMode::ShuffleUp(ref mut index)
| crate::GatherMode::ShuffleXor(ref mut index) => {
adjust(index);
}
}
adjust(argument);
adjust(result)
}
Statement::Call {
ref mut arguments,
ref mut result,

View File

@ -1279,7 +1279,9 @@ impl<'w> BlockContext<'w> {
crate::Expression::CallResult(_)
| crate::Expression::AtomicResult { .. }
| crate::Expression::WorkGroupUniformLoadResult { .. }
| crate::Expression::RayQueryProceedResult => self.cached[expr_handle],
| crate::Expression::RayQueryProceedResult
| crate::Expression::SubgroupBallotResult
| crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle],
crate::Expression::As {
expr,
kind,
@ -2490,6 +2492,27 @@ impl<'w> BlockContext<'w> {
crate::Statement::RayQuery { query, ref fun } => {
self.write_ray_query_function(query, fun, &mut block);
}
crate::Statement::SubgroupBallot {
result,
ref predicate,
} => {
self.write_subgroup_ballot(predicate, result, &mut block)?;
}
crate::Statement::SubgroupCollectiveOperation {
ref op,
ref collective_op,
argument,
result,
} => {
self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?;
}
crate::Statement::SubgroupGather {
ref mode,
argument,
result,
} => {
self.write_subgroup_gather(mode, argument, result, &mut block)?;
}
}
}

View File

@ -1073,6 +1073,73 @@ impl super::Instruction {
instruction.add_operand(semantics_id);
instruction
}
// Group Instructions
/// Builds an `OpGroupNonUniformBallot` instruction.
///
/// Gathers the boolean `predicate` from every active invocation in the
/// scope named by `exec_scope_id` and produces the packed ballot as
/// result `id` with type `result_type_id`.
pub(super) fn group_non_uniform_ballot(
    result_type_id: Word,
    id: Word,
    exec_scope_id: Word,
    predicate: Word,
) -> Self {
    let mut inst = Self::new(Op::GroupNonUniformBallot);
    inst.set_type(result_type_id);
    inst.set_result(id);
    // Operand order per the SPIR-V spec: execution scope, then predicate.
    inst.add_operand(exec_scope_id);
    inst.add_operand(predicate);
    inst
}
/// Builds an `OpGroupNonUniformBroadcastFirst` instruction.
///
/// Every invocation in the scope named by `exec_scope_id` receives the
/// `value` held by the active invocation with the lowest id.
pub(super) fn group_non_uniform_broadcast_first(
    result_type_id: Word,
    id: Word,
    exec_scope_id: Word,
    value: Word,
) -> Self {
    let mut inst = Self::new(Op::GroupNonUniformBroadcastFirst);
    inst.set_type(result_type_id);
    inst.set_result(id);
    // Operand order per the SPIR-V spec: execution scope, then value.
    inst.add_operand(exec_scope_id);
    inst.add_operand(value);
    inst
}
/// Builds one of the `OpGroupNonUniformShuffle*` family of instructions.
///
/// The caller supplies the concrete opcode in `op`; the instruction reads
/// `value` from the lane selected by `index` (an id, delta, or mask,
/// depending on the opcode) within the scope named by `exec_scope_id`.
pub(super) fn group_non_uniform_gather(
    op: Op,
    result_type_id: Word,
    id: Word,
    exec_scope_id: Word,
    value: Word,
    index: Word,
) -> Self {
    let mut inst = Self::new(op);
    inst.set_type(result_type_id);
    inst.set_result(id);
    // Operand order per the SPIR-V spec: scope, value, then lane selector.
    inst.add_operand(exec_scope_id);
    inst.add_operand(value);
    inst.add_operand(index);
    inst
}
/// Builds a group non-uniform reduction instruction.
///
/// Covers both the vote opcodes (`OpGroupNonUniformAll`/`Any`), which carry
/// no group-operation literal, and the arithmetic/bitwise/logical opcodes,
/// which encode a `GroupOperation` (reduce or scan) before the value:
/// pass `None` for the former and `Some(..)` for the latter.
pub(super) fn group_non_uniform_arithmetic(
    op: Op,
    result_type_id: Word,
    id: Word,
    exec_scope_id: Word,
    group_op: Option<spirv::GroupOperation>,
    value: Word,
) -> Self {
    let mut inst = Self::new(op);
    inst.set_type(result_type_id);
    inst.set_result(id);
    inst.add_operand(exec_scope_id);
    // The group-operation literal is only present for non-vote opcodes.
    if let Some(group_op) = group_op {
        inst.add_operand(group_op as u32);
    }
    inst.add_operand(value);
    inst
}
}
impl From<crate::StorageFormat> for spirv::ImageFormat {

View File

@ -13,6 +13,7 @@ mod layout;
mod ray;
mod recyclable;
mod selection;
mod subgroup;
mod writer;
pub use spirv::Capability;

View File

@ -0,0 +1,207 @@
use super::{Block, BlockContext, Error, Instruction};
use crate::{
arena::Handle,
back::spv::{LocalType, LookupType},
TypeInner,
};
impl<'w> BlockContext<'w> {
    /// Emits an `OpGroupNonUniformBallot` for a [`SubgroupBallot`] statement.
    ///
    /// A missing `predicate` is lowered to the constant `true`, so every
    /// active invocation contributes a set bit. The ballot result is a
    /// `vec4<u32>`, cached under `result`.
    ///
    /// [`SubgroupBallot`]: crate::Statement::SubgroupBallot
    pub(super) fn write_subgroup_ballot(
        &mut self,
        predicate: &Option<Handle<crate::Expression>>,
        result: Handle<crate::Expression>,
        block: &mut Block,
    ) -> Result<(), Error> {
        self.writer.require_any(
            "GroupNonUniformBallot",
            &[spirv::Capability::GroupNonUniformBallot],
        )?;
        let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
            vector_size: Some(crate::VectorSize::Quad),
            scalar: crate::Scalar::U32,
            pointer_space: None,
        }));
        let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
        // An absent predicate means "ballot all active invocations".
        let predicate = if let Some(predicate) = *predicate {
            self.cached[predicate]
        } else {
            self.writer.get_constant_scalar(crate::Literal::Bool(true))
        };
        let id = self.gen_id();
        block.body.push(Instruction::group_non_uniform_ballot(
            vec4_u32_type_id,
            id,
            exec_scope_id,
            predicate,
        ));
        self.cached[result] = id;
        Ok(())
    }
    /// Emits the SPIR-V instruction for a [`SubgroupCollectiveOperation`]
    /// statement.
    ///
    /// Chooses the opcode from the IR operation and the scalar kind of the
    /// result type (e.g. `Add` maps to `OpGroupNonUniformIAdd` for integers
    /// and `OpGroupNonUniformFAdd` for floats), and declares the matching
    /// capability (`GroupNonUniformVote` for `All`/`Any`, otherwise
    /// `GroupNonUniformArithmetic`).
    ///
    /// [`SubgroupCollectiveOperation`]: crate::Statement::SubgroupCollectiveOperation
    pub(super) fn write_subgroup_operation(
        &mut self,
        op: &crate::SubgroupOperation,
        collective_op: &crate::CollectiveOperation,
        argument: Handle<crate::Expression>,
        result: Handle<crate::Expression>,
        block: &mut Block,
    ) -> Result<(), Error> {
        use crate::SubgroupOperation as sg;
        match *op {
            sg::All | sg::Any => {
                self.writer.require_any(
                    "GroupNonUniformVote",
                    &[spirv::Capability::GroupNonUniformVote],
                )?;
            }
            _ => {
                self.writer.require_any(
                    "GroupNonUniformArithmetic",
                    &[spirv::Capability::GroupNonUniformArithmetic],
                )?;
            }
        }

        let id = self.gen_id();
        let result_ty = &self.fun_info[result].ty;
        let result_type_id = self.get_expression_type_id(result_ty);
        let result_ty_inner = result_ty.inner_with(&self.ir_module.types);

        // Vector operands use the same opcode as their scalar kind; only
        // the vote ops are restricted to scalar booleans.
        let (is_scalar, scalar) = match *result_ty_inner {
            TypeInner::Scalar(kind) => (true, kind),
            TypeInner::Vector { scalar: kind, .. } => (false, kind),
            _ => unimplemented!(),
        };

        use crate::ScalarKind as sk;
        let spirv_op = match (scalar.kind, *op) {
            (sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll,
            (sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny,
            (_, sg::All | sg::Any) => unimplemented!(),

            (sk::Sint | sk::Uint, sg::Add) => spirv::Op::GroupNonUniformIAdd,
            (sk::Float, sg::Add) => spirv::Op::GroupNonUniformFAdd,
            (sk::Sint | sk::Uint, sg::Mul) => spirv::Op::GroupNonUniformIMul,
            (sk::Float, sg::Mul) => spirv::Op::GroupNonUniformFMul,
            (sk::Sint, sg::Max) => spirv::Op::GroupNonUniformSMax,
            (sk::Uint, sg::Max) => spirv::Op::GroupNonUniformUMax,
            (sk::Float, sg::Max) => spirv::Op::GroupNonUniformFMax,
            (sk::Sint, sg::Min) => spirv::Op::GroupNonUniformSMin,
            (sk::Uint, sg::Min) => spirv::Op::GroupNonUniformUMin,
            (sk::Float, sg::Min) => spirv::Op::GroupNonUniformFMin,
            (_, sg::Add | sg::Mul | sg::Min | sg::Max) => unimplemented!(),

            (sk::Sint | sk::Uint, sg::And) => spirv::Op::GroupNonUniformBitwiseAnd,
            (sk::Sint | sk::Uint, sg::Or) => spirv::Op::GroupNonUniformBitwiseOr,
            (sk::Sint | sk::Uint, sg::Xor) => spirv::Op::GroupNonUniformBitwiseXor,
            (sk::Bool, sg::And) => spirv::Op::GroupNonUniformLogicalAnd,
            (sk::Bool, sg::Or) => spirv::Op::GroupNonUniformLogicalOr,
            (sk::Bool, sg::Xor) => spirv::Op::GroupNonUniformLogicalXor,
            (_, sg::And | sg::Or | sg::Xor) => unimplemented!(),
        };

        let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);

        // Vote opcodes carry no group-operation literal; everything else
        // encodes reduce / inclusive scan / exclusive scan.
        use crate::CollectiveOperation as c;
        let group_op = match *op {
            sg::All | sg::Any => None,
            _ => Some(match *collective_op {
                c::Reduce => spirv::GroupOperation::Reduce,
                c::InclusiveScan => spirv::GroupOperation::InclusiveScan,
                c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan,
            }),
        };

        let arg_id = self.cached[argument];
        block.body.push(Instruction::group_non_uniform_arithmetic(
            spirv_op,
            result_type_id,
            id,
            exec_scope_id,
            group_op,
            arg_id,
        ));
        self.cached[result] = id;
        Ok(())
    }
    /// Emits the SPIR-V instruction for a [`SubgroupGather`] statement.
    ///
    /// [`SubgroupGather`]: crate::Statement::SubgroupGather
    pub(super) fn write_subgroup_gather(
        &mut self,
        mode: &crate::GatherMode,
        argument: Handle<crate::Expression>,
        result: Handle<crate::Expression>,
        block: &mut Block,
    ) -> Result<(), Error> {
        // Declare only the capability required by the opcode we actually
        // emit. Note that `Broadcast` is lowered to
        // `OpGroupNonUniformShuffle` (see below), which per the SPIR-V spec
        // requires the `GroupNonUniformShuffle` capability, not
        // `GroupNonUniformBallot`.
        match *mode {
            crate::GatherMode::BroadcastFirst => {
                self.writer.require_any(
                    "GroupNonUniformBallot",
                    &[spirv::Capability::GroupNonUniformBallot],
                )?;
            }
            crate::GatherMode::Broadcast(_)
            | crate::GatherMode::Shuffle(_)
            | crate::GatherMode::ShuffleXor(_) => {
                self.writer.require_any(
                    "GroupNonUniformShuffle",
                    &[spirv::Capability::GroupNonUniformShuffle],
                )?;
            }
            crate::GatherMode::ShuffleDown(_) | crate::GatherMode::ShuffleUp(_) => {
                self.writer.require_any(
                    "GroupNonUniformShuffleRelative",
                    &[spirv::Capability::GroupNonUniformShuffleRelative],
                )?;
            }
        }
        let id = self.gen_id();
        let result_ty = &self.fun_info[result].ty;
        let result_type_id = self.get_expression_type_id(result_ty);
        let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
        let arg_id = self.cached[argument];
        match *mode {
            crate::GatherMode::BroadcastFirst => {
                block
                    .body
                    .push(Instruction::group_non_uniform_broadcast_first(
                        result_type_id,
                        id,
                        exec_scope_id,
                        arg_id,
                    ));
            }
            crate::GatherMode::Broadcast(index)
            | crate::GatherMode::Shuffle(index)
            | crate::GatherMode::ShuffleDown(index)
            | crate::GatherMode::ShuffleUp(index)
            | crate::GatherMode::ShuffleXor(index) => {
                let index_id = self.cached[index];
                let op = match *mode {
                    crate::GatherMode::BroadcastFirst => unreachable!(),
                    // Use shuffle to emit broadcast to allow the index to
                    // be dynamically uniform on Vulkan 1.1. The argument to
                    // OpGroupNonUniformBroadcast must be a constant pre SPIR-V
                    // 1.5 (vulkan 1.2)
                    crate::GatherMode::Broadcast(_) => spirv::Op::GroupNonUniformShuffle,
                    crate::GatherMode::Shuffle(_) => spirv::Op::GroupNonUniformShuffle,
                    crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown,
                    crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp,
                    crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor,
                };
                block.body.push(Instruction::group_non_uniform_gather(
                    op,
                    result_type_id,
                    id,
                    exec_scope_id,
                    arg_id,
                    index_id,
                ));
            }
        }
        self.cached[result] = id;
        Ok(())
    }
}

View File

@ -1310,7 +1310,11 @@ impl Writer {
spirv::MemorySemantics::WORKGROUP_MEMORY,
flags.contains(crate::Barrier::WORK_GROUP),
);
let exec_scope_id = self.get_index_constant(spirv::Scope::Workgroup as u32);
let exec_scope_id = if flags.contains(crate::Barrier::SUB_GROUP) {
self.get_index_constant(spirv::Scope::Subgroup as u32)
} else {
self.get_index_constant(spirv::Scope::Workgroup as u32)
};
let mem_scope_id = self.get_index_constant(memory_scope as u32);
let semantics_id = self.get_index_constant(semantics.bits());
block.body.push(Instruction::control_barrier(
@ -1585,6 +1589,41 @@ impl Writer {
Bi::WorkGroupId => BuiltIn::WorkgroupId,
Bi::WorkGroupSize => BuiltIn::WorkgroupSize,
Bi::NumWorkGroups => BuiltIn::NumWorkgroups,
// Subgroup
Bi::NumSubgroups => {
self.require_any(
"`num_subgroups` built-in",
&[spirv::Capability::GroupNonUniform],
)?;
BuiltIn::NumSubgroups
}
Bi::SubgroupId => {
self.require_any(
"`subgroup_id` built-in",
&[spirv::Capability::GroupNonUniform],
)?;
BuiltIn::SubgroupId
}
Bi::SubgroupSize => {
self.require_any(
"`subgroup_size` built-in",
&[
spirv::Capability::GroupNonUniform,
spirv::Capability::SubgroupBallotKHR,
],
)?;
BuiltIn::SubgroupSize
}
Bi::SubgroupInvocationId => {
self.require_any(
"`subgroup_invocation_id` built-in",
&[
spirv::Capability::GroupNonUniform,
spirv::Capability::SubgroupBallotKHR,
],
)?;
BuiltIn::SubgroupLocalInvocationId
}
};
self.decorate(id, Decoration::BuiltIn, &[built_in as u32]);

View File

@ -924,8 +924,124 @@ impl<W: Write> Writer<W> {
if barrier.contains(crate::Barrier::WORK_GROUP) {
writeln!(self.out, "{level}workgroupBarrier();")?;
}
if barrier.contains(crate::Barrier::SUB_GROUP) {
writeln!(self.out, "{level}subgroupBarrier();")?;
}
}
Statement::RayQuery { .. } => unreachable!(),
Statement::SubgroupBallot { result, predicate } => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
self.start_named_expr(module, result, func_ctx, &res_name)?;
self.named_expressions.insert(result, res_name);
write!(self.out, "subgroupBallot(")?;
if let Some(predicate) = predicate {
self.write_expr(module, predicate, func_ctx)?;
}
writeln!(self.out, ");")?;
}
Statement::SubgroupCollectiveOperation {
op,
collective_op,
argument,
result,
} => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
self.start_named_expr(module, result, func_ctx, &res_name)?;
self.named_expressions.insert(result, res_name);
match (collective_op, op) {
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
write!(self.out, "subgroupAll(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
write!(self.out, "subgroupAny(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
write!(self.out, "subgroupAdd(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
write!(self.out, "subgroupMul(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
write!(self.out, "subgroupMax(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
write!(self.out, "subgroupMin(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
write!(self.out, "subgroupAnd(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
write!(self.out, "subgroupOr(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
write!(self.out, "subgroupXor(")?
}
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
write!(self.out, "subgroupExclusiveAdd(")?
}
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
write!(self.out, "subgroupExclusiveMul(")?
}
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
write!(self.out, "subgroupInclusiveAdd(")?
}
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
write!(self.out, "subgroupInclusiveMul(")?
}
_ => unimplemented!(),
}
self.write_expr(module, argument, func_ctx)?;
writeln!(self.out, ");")?;
}
Statement::SubgroupGather {
mode,
argument,
result,
} => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
self.start_named_expr(module, result, func_ctx, &res_name)?;
self.named_expressions.insert(result, res_name);
match mode {
crate::GatherMode::BroadcastFirst => {
write!(self.out, "subgroupBroadcastFirst(")?;
}
crate::GatherMode::Broadcast(_) => {
write!(self.out, "subgroupBroadcast(")?;
}
crate::GatherMode::Shuffle(_) => {
write!(self.out, "subgroupShuffle(")?;
}
crate::GatherMode::ShuffleDown(_) => {
write!(self.out, "subgroupShuffleDown(")?;
}
crate::GatherMode::ShuffleUp(_) => {
write!(self.out, "subgroupShuffleUp(")?;
}
crate::GatherMode::ShuffleXor(_) => {
write!(self.out, "subgroupShuffleXor(")?;
}
}
self.write_expr(module, argument, func_ctx)?;
match mode {
crate::GatherMode::BroadcastFirst => {}
crate::GatherMode::Broadcast(index)
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
write!(self.out, ", ")?;
self.write_expr(module, index, func_ctx)?;
}
}
writeln!(self.out, ");")?;
}
}
Ok(())
@ -1698,6 +1814,8 @@ impl<W: Write> Writer<W> {
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::RayQueryProceedResult
| Expression::SubgroupBallotResult
| Expression::SubgroupOperationResult { .. }
| Expression::WorkGroupUniformLoadResult { .. } => {}
}
@ -1799,6 +1917,10 @@ fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> {
Bi::SampleMask => "sample_mask",
Bi::PrimitiveIndex => "primitive_index",
Bi::ViewIndex => "view_index",
Bi::NumSubgroups => "num_subgroups",
Bi::SubgroupId => "subgroup_id",
Bi::SubgroupSize => "subgroup_size",
Bi::SubgroupInvocationId => "subgroup_invocation_id",
Bi::BaseInstance
| Bi::BaseVertex
| Bi::ClipDistance

View File

@ -72,6 +72,7 @@ impl<'tracer> ExpressionTracer<'tracer> {
| Ex::GlobalVariable(_)
| Ex::LocalVariable(_)
| Ex::CallResult(_)
| Ex::SubgroupBallotResult
| Ex::RayQueryProceedResult => {}
Ex::Constant(handle) => {
@ -192,6 +193,7 @@ impl<'tracer> ExpressionTracer<'tracer> {
Ex::AtomicResult { ty, comparison: _ } => self.types_used.insert(ty),
Ex::WorkGroupUniformLoadResult { ty } => self.types_used.insert(ty),
Ex::ArrayLength(expr) => self.expressions_used.insert(expr),
Ex::SubgroupOperationResult { ty } => self.types_used.insert(ty),
Ex::RayQueryGetIntersection {
query,
committed: _,
@ -223,6 +225,7 @@ impl ModuleMap {
| Ex::GlobalVariable(_)
| Ex::LocalVariable(_)
| Ex::CallResult(_)
| Ex::SubgroupBallotResult
| Ex::RayQueryProceedResult => {}
// All overrides are retained, so their handles never change.
@ -353,6 +356,7 @@ impl ModuleMap {
comparison: _,
} => self.types.adjust(ty),
Ex::WorkGroupUniformLoadResult { ref mut ty } => self.types.adjust(ty),
Ex::SubgroupOperationResult { ref mut ty } => self.types.adjust(ty),
Ex::ArrayLength(ref mut expr) => adjust(expr),
Ex::RayQueryGetIntersection {
ref mut query,

View File

@ -97,6 +97,39 @@ impl FunctionTracer<'_> {
self.expressions_used.insert(query);
self.trace_ray_query_function(fun);
}
St::SubgroupBallot { result, predicate } => {
if let Some(predicate) = predicate {
self.expressions_used.insert(predicate)
}
self.expressions_used.insert(result)
}
St::SubgroupCollectiveOperation {
op: _,
collective_op: _,
argument,
result,
} => {
self.expressions_used.insert(argument);
self.expressions_used.insert(result)
}
St::SubgroupGather {
mode,
argument,
result,
} => {
match mode {
crate::GatherMode::BroadcastFirst => {}
crate::GatherMode::Broadcast(index)
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
self.expressions_used.insert(index)
}
}
self.expressions_used.insert(argument);
self.expressions_used.insert(result)
}
// Trivial statements.
St::Break
@ -250,6 +283,40 @@ impl FunctionMap {
adjust(query);
self.adjust_ray_query_function(fun);
}
St::SubgroupBallot {
ref mut result,
ref mut predicate,
} => {
if let Some(ref mut predicate) = *predicate {
adjust(predicate);
}
adjust(result);
}
St::SubgroupCollectiveOperation {
op: _,
collective_op: _,
ref mut argument,
ref mut result,
} => {
adjust(argument);
adjust(result);
}
St::SubgroupGather {
ref mut mode,
ref mut argument,
ref mut result,
} => {
match *mode {
crate::GatherMode::BroadcastFirst => {}
crate::GatherMode::Broadcast(ref mut index)
| crate::GatherMode::Shuffle(ref mut index)
| crate::GatherMode::ShuffleDown(ref mut index)
| crate::GatherMode::ShuffleUp(ref mut index)
| crate::GatherMode::ShuffleXor(ref mut index) => adjust(index),
}
adjust(argument);
adjust(result);
}
// Trivial statements.
St::Break

View File

@ -153,6 +153,11 @@ pub(super) fn map_builtin(word: spirv::Word, invariant: bool) -> Result<crate::B
Some(Bi::WorkgroupId) => crate::BuiltIn::WorkGroupId,
Some(Bi::WorkgroupSize) => crate::BuiltIn::WorkGroupSize,
Some(Bi::NumWorkgroups) => crate::BuiltIn::NumWorkGroups,
// subgroup
Some(Bi::NumSubgroups) => crate::BuiltIn::NumSubgroups,
Some(Bi::SubgroupId) => crate::BuiltIn::SubgroupId,
Some(Bi::SubgroupSize) => crate::BuiltIn::SubgroupSize,
Some(Bi::SubgroupLocalInvocationId) => crate::BuiltIn::SubgroupInvocationId,
_ => return Err(Error::UnsupportedBuiltIn(word)),
})
}

View File

@ -58,6 +58,8 @@ pub enum Error {
UnknownBinaryOperator(spirv::Op),
#[error("unknown relational function {0:?}")]
UnknownRelationalFunction(spirv::Op),
#[error("unsupported group operation %{0}")]
UnsupportedGroupOperation(spirv::Word),
#[error("invalid parameter {0:?}")]
InvalidParameter(spirv::Op),
#[error("invalid operand count {1} for {0:?}")]

View File

@ -3700,6 +3700,254 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
},
);
}
Op::GroupNonUniformBallot => {
inst.expect(5)?;
block.extend(emitter.finish(ctx.expressions));
let result_type_id = self.next()?;
let result_id = self.next()?;
let exec_scope_id = self.next()?;
let predicate_id = self.next()?;
let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
.filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32)
.ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
let predicate = if self
.lookup_constant
.lookup(predicate_id)
.ok()
.filter(|predicate_const| match predicate_const.inner {
Constant::Constant(constant) => matches!(
ctx.gctx().global_expressions[ctx.gctx().constants[constant].init],
crate::Expression::Literal(crate::Literal::Bool(true)),
),
Constant::Override(_) => false,
})
.is_some()
{
None
} else {
let predicate_lookup = self.lookup_expression.lookup(predicate_id)?;
let predicate_handle = get_expr_handle!(predicate_id, predicate_lookup);
Some(predicate_handle)
};
let result_handle = ctx
.expressions
.append(crate::Expression::SubgroupBallotResult, span);
self.lookup_expression.insert(
result_id,
LookupExpression {
handle: result_handle,
type_id: result_type_id,
block_id,
},
);
block.push(
crate::Statement::SubgroupBallot {
result: result_handle,
predicate,
},
span,
);
emitter.start(ctx.expressions);
}
spirv::Op::GroupNonUniformAll
| spirv::Op::GroupNonUniformAny
| spirv::Op::GroupNonUniformIAdd
| spirv::Op::GroupNonUniformFAdd
| spirv::Op::GroupNonUniformIMul
| spirv::Op::GroupNonUniformFMul
| spirv::Op::GroupNonUniformSMax
| spirv::Op::GroupNonUniformUMax
| spirv::Op::GroupNonUniformFMax
| spirv::Op::GroupNonUniformSMin
| spirv::Op::GroupNonUniformUMin
| spirv::Op::GroupNonUniformFMin
| spirv::Op::GroupNonUniformBitwiseAnd
| spirv::Op::GroupNonUniformBitwiseOr
| spirv::Op::GroupNonUniformBitwiseXor
| spirv::Op::GroupNonUniformLogicalAnd
| spirv::Op::GroupNonUniformLogicalOr
| spirv::Op::GroupNonUniformLogicalXor => {
block.extend(emitter.finish(ctx.expressions));
inst.expect(
if matches!(
inst.op,
spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny
) {
5
} else {
6
},
)?;
let result_type_id = self.next()?;
let result_id = self.next()?;
let exec_scope_id = self.next()?;
let collective_op_id = match inst.op {
spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny => {
crate::CollectiveOperation::Reduce
}
_ => {
let group_op_id = self.next()?;
match spirv::GroupOperation::from_u32(group_op_id) {
Some(spirv::GroupOperation::Reduce) => {
crate::CollectiveOperation::Reduce
}
Some(spirv::GroupOperation::InclusiveScan) => {
crate::CollectiveOperation::InclusiveScan
}
Some(spirv::GroupOperation::ExclusiveScan) => {
crate::CollectiveOperation::ExclusiveScan
}
_ => return Err(Error::UnsupportedGroupOperation(group_op_id)),
}
}
};
let argument_id = self.next()?;
let argument_lookup = self.lookup_expression.lookup(argument_id)?;
let argument_handle = get_expr_handle!(argument_id, argument_lookup);
let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
.filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32)
.ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
let op_id = match inst.op {
spirv::Op::GroupNonUniformAll => crate::SubgroupOperation::All,
spirv::Op::GroupNonUniformAny => crate::SubgroupOperation::Any,
spirv::Op::GroupNonUniformIAdd | spirv::Op::GroupNonUniformFAdd => {
crate::SubgroupOperation::Add
}
spirv::Op::GroupNonUniformIMul | spirv::Op::GroupNonUniformFMul => {
crate::SubgroupOperation::Mul
}
spirv::Op::GroupNonUniformSMax
| spirv::Op::GroupNonUniformUMax
| spirv::Op::GroupNonUniformFMax => crate::SubgroupOperation::Max,
spirv::Op::GroupNonUniformSMin
| spirv::Op::GroupNonUniformUMin
| spirv::Op::GroupNonUniformFMin => crate::SubgroupOperation::Min,
spirv::Op::GroupNonUniformBitwiseAnd
| spirv::Op::GroupNonUniformLogicalAnd => crate::SubgroupOperation::And,
spirv::Op::GroupNonUniformBitwiseOr
| spirv::Op::GroupNonUniformLogicalOr => crate::SubgroupOperation::Or,
spirv::Op::GroupNonUniformBitwiseXor
| spirv::Op::GroupNonUniformLogicalXor => crate::SubgroupOperation::Xor,
_ => unreachable!(),
};
let result_type = self.lookup_type.lookup(result_type_id)?;
let result_handle = ctx.expressions.append(
crate::Expression::SubgroupOperationResult {
ty: result_type.handle,
},
span,
);
self.lookup_expression.insert(
result_id,
LookupExpression {
handle: result_handle,
type_id: result_type_id,
block_id,
},
);
block.push(
crate::Statement::SubgroupCollectiveOperation {
result: result_handle,
op: op_id,
collective_op: collective_op_id,
argument: argument_handle,
},
span,
);
emitter.start(ctx.expressions);
}
Op::GroupNonUniformBroadcastFirst
| Op::GroupNonUniformBroadcast
| Op::GroupNonUniformShuffle
| Op::GroupNonUniformShuffleDown
| Op::GroupNonUniformShuffleUp
| Op::GroupNonUniformShuffleXor => {
inst.expect(
if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) {
5
} else {
6
},
)?;
block.extend(emitter.finish(ctx.expressions));
let result_type_id = self.next()?;
let result_id = self.next()?;
let exec_scope_id = self.next()?;
let argument_id = self.next()?;
let argument_lookup = self.lookup_expression.lookup(argument_id)?;
let argument_handle = get_expr_handle!(argument_id, argument_lookup);
let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
.filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32)
.ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
let mode = if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) {
crate::GatherMode::BroadcastFirst
} else {
let index_id = self.next()?;
let index_lookup = self.lookup_expression.lookup(index_id)?;
let index_handle = get_expr_handle!(index_id, index_lookup);
match inst.op {
spirv::Op::GroupNonUniformBroadcast => {
crate::GatherMode::Broadcast(index_handle)
}
spirv::Op::GroupNonUniformShuffle => {
crate::GatherMode::Shuffle(index_handle)
}
spirv::Op::GroupNonUniformShuffleDown => {
crate::GatherMode::ShuffleDown(index_handle)
}
spirv::Op::GroupNonUniformShuffleUp => {
crate::GatherMode::ShuffleUp(index_handle)
}
spirv::Op::GroupNonUniformShuffleXor => {
crate::GatherMode::ShuffleXor(index_handle)
}
_ => unreachable!(),
}
};
let result_type = self.lookup_type.lookup(result_type_id)?;
let result_handle = ctx.expressions.append(
crate::Expression::SubgroupOperationResult {
ty: result_type.handle,
},
span,
);
self.lookup_expression.insert(
result_id,
LookupExpression {
handle: result_handle,
type_id: result_type_id,
block_id,
},
);
block.push(
crate::Statement::SubgroupGather {
result: result_handle,
mode,
argument: argument_handle,
},
span,
);
emitter.start(ctx.expressions);
}
_ => return Err(Error::UnsupportedInstruction(self.state, inst.op)),
}
};
@ -3824,7 +4072,10 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
| S::Store { .. }
| S::ImageStore { .. }
| S::Atomic { .. }
| S::RayQuery { .. } => {}
| S::RayQuery { .. }
| S::SubgroupBallot { .. }
| S::SubgroupCollectiveOperation { .. }
| S::SubgroupGather { .. } => {}
S::Call {
function: ref mut callee,
ref arguments,

View File

@ -874,6 +874,29 @@ impl Texture {
}
}
/// Which subgroup gather builtin was named in the WGSL source.
enum SubgroupGather {
    BroadcastFirst,
    Broadcast,
    Shuffle,
    ShuffleDown,
    ShuffleUp,
    ShuffleXor,
}

impl SubgroupGather {
    /// Maps a WGSL builtin function name to its gather flavor, or `None`
    /// if `word` is not a subgroup gather builtin.
    pub fn map(word: &str) -> Option<Self> {
        match word {
            "subgroupBroadcastFirst" => Some(Self::BroadcastFirst),
            "subgroupBroadcast" => Some(Self::Broadcast),
            "subgroupShuffle" => Some(Self::Shuffle),
            "subgroupShuffleDown" => Some(Self::ShuffleDown),
            "subgroupShuffleUp" => Some(Self::ShuffleUp),
            "subgroupShuffleXor" => Some(Self::ShuffleXor),
            _ => None,
        }
    }
}
pub struct Lowerer<'source, 'temp> {
index: &'temp Index<'source>,
layouter: Layouter,
@ -2054,6 +2077,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
} else if let Some(fun) = Texture::map(function.name) {
self.texture_sample_helper(fun, arguments, span, ctx)?
} else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) {
return Ok(Some(
self.subgroup_operation_helper(span, op, cop, arguments, ctx)?,
));
} else if let Some(mode) = SubgroupGather::map(function.name) {
return Ok(Some(
self.subgroup_gather_helper(span, mode, arguments, ctx)?,
));
} else {
match function.name {
"select" => {
@ -2221,6 +2252,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span);
return Ok(None);
}
"subgroupBarrier" => {
ctx.prepare_args(arguments, 0, span).finish()?;
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block
.push(crate::Statement::Barrier(crate::Barrier::SUB_GROUP), span);
return Ok(None);
}
"workgroupUniformLoad" => {
let mut args = ctx.prepare_args(arguments, 1, span);
let expr = args.next()?;
@ -2428,6 +2467,22 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
)?;
return Ok(Some(handle));
}
"subgroupBallot" => {
let mut args = ctx.prepare_args(arguments, 0, span);
let predicate = if arguments.len() == 1 {
Some(self.expression(args.next()?, ctx)?)
} else {
None
};
args.finish()?;
let result = ctx
.interrupt_emitter(crate::Expression::SubgroupBallotResult, span)?;
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block
.push(crate::Statement::SubgroupBallot { result, predicate }, span);
return Ok(Some(result));
}
_ => return Err(Error::UnknownIdent(function.span, function.name)),
}
};
@ -2619,6 +2674,80 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
})
}
/// Lowers a subgroup collective builtin call (e.g. `subgroupAdd`,
/// `subgroupInclusiveMul`) into a `SubgroupCollectiveOperation` statement.
///
/// Expects exactly one call argument — the value this invocation
/// contributes — and returns the handle of the `SubgroupOperationResult`
/// expression carrying the combined value.
fn subgroup_operation_helper(
    &mut self,
    span: Span,
    op: crate::SubgroupOperation,
    collective_op: crate::CollectiveOperation,
    arguments: &[Handle<ast::Expression<'source>>],
    ctx: &mut ExpressionContext<'source, '_, '_>,
) -> Result<Handle<crate::Expression>, Error<'source>> {
    // Collective operations take exactly one argument.
    let mut unpacked = ctx.prepare_args(arguments, 1, span);
    let argument = self.expression(unpacked.next()?, ctx)?;
    unpacked.finish()?;

    // The result has the same type as the per-invocation argument.
    let ty = ctx.register_type(argument)?;
    let result =
        ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?;

    let rctx = ctx.runtime_expression_ctx(span)?;
    rctx.block.push(
        crate::Statement::SubgroupCollectiveOperation {
            op,
            collective_op,
            argument,
            result,
        },
        span,
    );
    Ok(result)
}
fn subgroup_gather_helper(
&mut self,
span: Span,
mode: SubgroupGather,
arguments: &[Handle<ast::Expression<'source>>],
ctx: &mut ExpressionContext<'source, '_, '_>,
) -> Result<Handle<crate::Expression>, Error<'source>> {
let mut args = ctx.prepare_args(arguments, 2, span);
let argument = self.expression(args.next()?, ctx)?;
use SubgroupGather as Sg;
let mode = if let Sg::BroadcastFirst = mode {
crate::GatherMode::BroadcastFirst
} else {
let index = self.expression(args.next()?, ctx)?;
match mode {
Sg::Broadcast => crate::GatherMode::Broadcast(index),
Sg::Shuffle => crate::GatherMode::Shuffle(index),
Sg::ShuffleDown => crate::GatherMode::ShuffleDown(index),
Sg::ShuffleUp => crate::GatherMode::ShuffleUp(index),
Sg::ShuffleXor => crate::GatherMode::ShuffleXor(index),
Sg::BroadcastFirst => unreachable!(),
}
};
args.finish()?;
let ty = ctx.register_type(argument)?;
let result =
ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?;
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block.push(
crate::Statement::SubgroupGather {
mode,
argument,
result,
},
span,
);
Ok(result)
}
fn r#struct(
&mut self,
s: &ast::Struct<'source>,

View File

@ -35,6 +35,11 @@ pub fn map_built_in(word: &str, span: Span) -> Result<crate::BuiltIn, Error<'_>>
"local_invocation_index" => crate::BuiltIn::LocalInvocationIndex,
"workgroup_id" => crate::BuiltIn::WorkGroupId,
"num_workgroups" => crate::BuiltIn::NumWorkGroups,
// subgroup
"num_subgroups" => crate::BuiltIn::NumSubgroups,
"subgroup_id" => crate::BuiltIn::SubgroupId,
"subgroup_size" => crate::BuiltIn::SubgroupSize,
"subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId,
_ => return Err(Error::UnknownBuiltin(span)),
})
}
@ -260,3 +265,26 @@ pub fn map_conservative_depth(
_ => Err(Error::UnknownConservativeDepth(span)),
}
}
/// Maps a WGSL subgroup collective builtin name to its IR operation pair,
/// or `None` if `word` is not a subgroup collective builtin.
pub fn map_subgroup_operation(
    word: &str,
) -> Option<(crate::SubgroupOperation, crate::CollectiveOperation)> {
    use crate::CollectiveOperation as Co;
    use crate::SubgroupOperation as Sg;
    let pair = match word {
        "subgroupAll" => (Sg::All, Co::Reduce),
        "subgroupAny" => (Sg::Any, Co::Reduce),
        "subgroupAdd" => (Sg::Add, Co::Reduce),
        "subgroupMul" => (Sg::Mul, Co::Reduce),
        "subgroupMin" => (Sg::Min, Co::Reduce),
        "subgroupMax" => (Sg::Max, Co::Reduce),
        "subgroupAnd" => (Sg::And, Co::Reduce),
        "subgroupOr" => (Sg::Or, Co::Reduce),
        "subgroupXor" => (Sg::Xor, Co::Reduce),
        "subgroupExclusiveAdd" => (Sg::Add, Co::ExclusiveScan),
        "subgroupExclusiveMul" => (Sg::Mul, Co::ExclusiveScan),
        "subgroupInclusiveAdd" => (Sg::Add, Co::InclusiveScan),
        "subgroupInclusiveMul" => (Sg::Mul, Co::InclusiveScan),
        _ => return None,
    };
    Some(pair)
}

View File

@ -431,6 +431,11 @@ pub enum BuiltIn {
WorkGroupId,
WorkGroupSize,
NumWorkGroups,
// subgroup
NumSubgroups,
SubgroupId,
SubgroupSize,
SubgroupInvocationId,
}
/// Number of bytes per scalar.
@ -1277,6 +1282,51 @@ pub enum SwizzleComponent {
W = 3,
}
/// How invocations select the lane to read from in a subgroup gather.
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum GatherMode {
    /// All gather from the active lane with the smallest index.
    BroadcastFirst,
    /// All gather from the same lane at the index given by the expression.
    Broadcast(Handle<Expression>),
    /// Each gathers from a different lane at the index given by the expression.
    Shuffle(Handle<Expression>),
    /// Each gathers from their lane plus the shift given by the expression.
    ShuffleDown(Handle<Expression>),
    /// Each gathers from their lane minus the shift given by the expression.
    ShuffleUp(Handle<Expression>),
    /// Each gathers from their lane XORed with the mask given by the expression.
    ShuffleXor(Handle<Expression>),
}
/// The arithmetic or logical operation a subgroup collective computes
/// across all active invocations.
///
/// Paired with a [`CollectiveOperation`] that selects reduce vs. scan.
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum SubgroupOperation {
    /// True when the predicate is true on every active lane (vote).
    All = 0,
    /// True when the predicate is true on any active lane (vote).
    Any = 1,
    /// Sum of the operands.
    Add = 2,
    /// Product of the operands.
    Mul = 3,
    /// Minimum of the operands.
    Min = 4,
    /// Maximum of the operands.
    Max = 5,
    /// Bitwise AND of the operands (integer only).
    And = 6,
    /// Bitwise OR of the operands (integer only).
    Or = 7,
    /// Bitwise XOR of the operands (integer only).
    Xor = 8,
}
/// How the per-lane results of a [`SubgroupOperation`] are combined:
/// a single whole-subgroup reduction, or a prefix scan.
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub enum CollectiveOperation {
    /// Every lane receives the result over all active lanes.
    Reduce = 0,
    /// Each lane receives the result over lanes up to and including itself.
    InclusiveScan = 1,
    /// Each lane receives the result over lanes strictly before itself.
    ExclusiveScan = 2,
}
bitflags::bitflags! {
/// Memory barrier flags.
#[cfg_attr(feature = "serialize", derive(Serialize))]
@ -1285,9 +1335,11 @@ bitflags::bitflags! {
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct Barrier: u32 {
/// Barrier affects all `AddressSpace::Storage` accesses.
const STORAGE = 0x1;
const STORAGE = 1 << 0;
/// Barrier affects all `AddressSpace::WorkGroup` accesses.
const WORK_GROUP = 0x2;
const WORK_GROUP = 1 << 1;
/// Barrier synchronizes execution across all invocations within a subgroup that execute this instruction.
const SUB_GROUP = 1 << 2;
}
}
@ -1588,6 +1640,15 @@ pub enum Expression {
query: Handle<Expression>,
committed: bool,
},
/// Result of a [`SubgroupBallot`] statement.
///
/// [`SubgroupBallot`]: Statement::SubgroupBallot
SubgroupBallotResult,
/// Result of a [`SubgroupCollectiveOperation`] or [`SubgroupGather`] statement.
///
/// [`SubgroupCollectiveOperation`]: Statement::SubgroupCollectiveOperation
/// [`SubgroupGather`]: Statement::SubgroupGather
SubgroupOperationResult { ty: Handle<Type> },
}
pub use block::Block;
@ -1872,6 +1933,39 @@ pub enum Statement {
/// The specific operation we're performing on `query`.
fun: RayQueryFunction,
},
/// Calculate a bitmask using a boolean from each active thread in the subgroup
SubgroupBallot {
/// The [`SubgroupBallotResult`] expression representing this load's result.
///
/// [`SubgroupBallotResult`]: Expression::SubgroupBallotResult
result: Handle<Expression>,
/// The value from this thread to store in the ballot
predicate: Option<Handle<Expression>>,
},
/// Gather a value from another active thread in the subgroup
SubgroupGather {
/// Specifies which thread to gather from
mode: GatherMode,
/// The value to broadcast over
argument: Handle<Expression>,
/// The [`SubgroupOperationResult`] expression representing this load's result.
///
/// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult
result: Handle<Expression>,
},
/// Compute a collective operation across all active threads in the subgroup
SubgroupCollectiveOperation {
/// What operation to compute
op: SubgroupOperation,
/// How to combine the results
collective_op: CollectiveOperation,
/// The value to compute over
argument: Handle<Expression>,
/// The [`SubgroupOperationResult`] expression representing this load's result.
///
/// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult
result: Handle<Expression>,
},
}
/// A function argument.

View File

@ -476,6 +476,8 @@ pub enum ConstantEvaluatorError {
ImageExpression,
#[error("Constants don't support ray query expressions")]
RayQueryExpression,
#[error("Constants don't support subgroup expressions")]
SubgroupExpression,
#[error("Cannot access the type")]
InvalidAccessBase,
#[error("Cannot access at the index")]
@ -884,6 +886,12 @@ impl<'a> ConstantEvaluator<'a> {
Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => {
Err(ConstantEvaluatorError::RayQueryExpression)
}
Expression::SubgroupBallotResult { .. } => {
Err(ConstantEvaluatorError::SubgroupExpression)
}
Expression::SubgroupOperationResult { .. } => {
Err(ConstantEvaluatorError::SubgroupExpression)
}
}
}

View File

@ -37,6 +37,9 @@ pub fn ensure_block_returns(block: &mut crate::Block) {
| S::RayQuery { .. }
| S::Atomic { .. }
| S::WorkGroupUniformLoad { .. }
| S::SubgroupBallot { .. }
| S::SubgroupCollectiveOperation { .. }
| S::SubgroupGather { .. }
| S::Barrier(_)),
)
| None => block.push(S::Return { value: None }, Default::default()),

View File

@ -598,6 +598,7 @@ impl<'a> ResolveContext<'a> {
| crate::BinaryOperator::ShiftRight => past(left)?.clone(),
},
crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty),
crate::Expression::SubgroupOperationResult { ty } => TypeResolution::Handle(ty),
crate::Expression::WorkGroupUniformLoadResult { ty } => TypeResolution::Handle(ty),
crate::Expression::Select { accept, .. } => past(accept)?.clone(),
crate::Expression::Derivative { expr, .. } => past(expr)?.clone(),
@ -885,6 +886,10 @@ impl<'a> ResolveContext<'a> {
.ok_or(ResolveError::MissingSpecialType)?;
TypeResolution::Handle(result)
}
crate::Expression::SubgroupBallotResult => TypeResolution::Value(Ti::Vector {
scalar: crate::Scalar::U32,
size: crate::VectorSize::Quad,
}),
})
}
}

View File

@ -787,6 +787,14 @@ impl FunctionInfo {
non_uniform_result: self.add_ref(query),
requirements: UniformityRequirements::empty(),
},
E::SubgroupBallotResult => Uniformity {
non_uniform_result: Some(handle),
requirements: UniformityRequirements::empty(),
},
E::SubgroupOperationResult { .. } => Uniformity {
non_uniform_result: Some(handle),
requirements: UniformityRequirements::empty(),
},
};
let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
@ -1029,6 +1037,42 @@ impl FunctionInfo {
}
FunctionUniformity::new()
}
S::SubgroupBallot {
result: _,
predicate,
} => {
if let Some(predicate) = predicate {
let _ = self.add_ref(predicate);
}
FunctionUniformity::new()
}
S::SubgroupCollectiveOperation {
op: _,
collective_op: _,
argument,
result: _,
} => {
let _ = self.add_ref(argument);
FunctionUniformity::new()
}
S::SubgroupGather {
mode,
argument,
result: _,
} => {
let _ = self.add_ref(argument);
match mode {
crate::GatherMode::BroadcastFirst => {}
crate::GatherMode::Broadcast(index)
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
let _ = self.add_ref(index);
}
}
FunctionUniformity::new()
}
};
disruptor = disruptor.or(uniformity.exit_disruptor());

View File

@ -1641,6 +1641,7 @@ impl super::Validator {
return Err(ExpressionError::InvalidRayQueryType(query));
}
},
E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages,
};
Ok(stages)
}

View File

@ -47,6 +47,19 @@ pub enum AtomicError {
ResultTypeMismatch(Handle<crate::Expression>),
}
/// Errors produced while validating subgroup statements
/// (ballot, gather, and collective operations).
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum SubgroupError {
    /// An argument, predicate, or gather-index expression has a type the
    /// operation does not accept.
    #[error("Operand {0:?} has invalid type.")]
    InvalidOperand(Handle<crate::Expression>),
    /// The result expression's declared type differs from the argument type.
    #[error("Result type for {0:?} doesn't match the statement")]
    ResultTypeMismatch(Handle<crate::Expression>),
    /// The target/validator configuration lacks the required
    /// `SubgroupOperationSet` family for this statement.
    #[error("Support for subgroup operation {0:?} is required")]
    UnsupportedOperation(super::SubgroupOperationSet),
    /// The (operation, collective-op) combination is not a valid pairing.
    #[error("Unknown operation")]
    UnknownOperation,
}
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum LocalVariableError {
@ -135,6 +148,8 @@ pub enum FunctionError {
InvalidRayDescriptor(Handle<crate::Expression>),
#[error("Ray Query {0:?} does not have a matching type")]
InvalidRayQueryType(Handle<crate::Type>),
#[error("Shader requires capability {0:?}")]
MissingCapability(super::Capabilities),
#[error(
"Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
)]
@ -155,6 +170,8 @@ pub enum FunctionError {
WorkgroupUniformLoadExpressionMismatch(Handle<crate::Expression>),
#[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")]
WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>),
#[error("Subgroup operation is invalid")]
InvalidSubgroup(#[from] SubgroupError),
}
bitflags::bitflags! {
@ -399,6 +416,127 @@ impl super::Validator {
}
Ok(())
}
/// Validates a `SubgroupCollectiveOperation` statement.
///
/// Checks, in order: that the argument is a scalar or vector of a kind the
/// operation accepts, that `op` and `collective_op` form a valid pairing,
/// and that `result` is a `SubgroupOperationResult` whose type matches the
/// argument. Emits the result expression as a side effect.
fn validate_subgroup_operation(
    &mut self,
    op: &crate::SubgroupOperation,
    collective_op: &crate::CollectiveOperation,
    argument: Handle<crate::Expression>,
    result: Handle<crate::Expression>,
    context: &BlockContext,
) -> Result<(), WithSpan<FunctionError>> {
    // Only scalars and vectors may be operated on; remember whether the
    // argument was a scalar since bool votes (All/Any) require a scalar.
    let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
    let (is_scalar, scalar) = match *argument_inner {
        crate::TypeInner::Scalar(scalar) => (true, scalar),
        crate::TypeInner::Vector { scalar, .. } => (false, scalar),
        _ => {
            log::error!("Subgroup operand type {:?}", argument_inner);
            return Err(SubgroupError::InvalidOperand(argument)
                .with_span_handle(argument, context.expressions)
                .into_other());
        }
    };
    use crate::ScalarKind as sk;
    use crate::SubgroupOperation as sg;
    // Scalar-kind / operation compatibility: votes need scalar bool,
    // arithmetic works on any numeric kind, bitwise only on integers.
    match (scalar.kind, *op) {
        (sk::Bool, sg::All | sg::Any) if is_scalar => {}
        (sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {}
        (sk::Sint | sk::Uint, sg::And | sg::Or | sg::Xor) => {}
        (_, _) => {
            log::error!("Subgroup operand type {:?}", argument_inner);
            return Err(SubgroupError::InvalidOperand(argument)
                .with_span_handle(argument, context.expressions)
                .into_other());
        }
    };
    use crate::CollectiveOperation as co;
    // Collective-op pairing: everything may be reduced, but only Add/Mul
    // have scan variants.
    match (*collective_op, *op) {
        (
            co::Reduce,
            sg::All
            | sg::Any
            | sg::Add
            | sg::Mul
            | sg::Min
            | sg::Max
            | sg::And
            | sg::Or
            | sg::Xor,
        ) => {}
        (co::InclusiveScan | co::ExclusiveScan, sg::Add | sg::Mul) => {}
        (_, _) => {
            return Err(SubgroupError::UnknownOperation.with_span().into_other());
        }
    };
    // Mark the result expression live, then require it to be a
    // SubgroupOperationResult whose type equals the argument's type.
    self.emit_expression(result, context)?;
    match context.expressions[result] {
        crate::Expression::SubgroupOperationResult { ty }
            if { &context.types[ty].inner == argument_inner } => {}
        _ => {
            return Err(SubgroupError::ResultTypeMismatch(result)
                .with_span_handle(result, context.expressions)
                .into_other())
        }
    }
    Ok(())
}
/// Validates a `SubgroupGather` statement.
///
/// For modes carrying an index/shift expression, requires that expression
/// to be a scalar `u32`; requires the gathered argument to be a numeric
/// scalar or vector; and requires `result` to be a
/// `SubgroupOperationResult` whose type matches the argument.
fn validate_subgroup_gather(
    &mut self,
    mode: &crate::GatherMode,
    argument: Handle<crate::Expression>,
    result: Handle<crate::Expression>,
    context: &BlockContext,
) -> Result<(), WithSpan<FunctionError>> {
    // BroadcastFirst has no index operand; all other modes carry one,
    // which must resolve to a scalar u32.
    match *mode {
        crate::GatherMode::BroadcastFirst => {}
        crate::GatherMode::Broadcast(index)
        | crate::GatherMode::Shuffle(index)
        | crate::GatherMode::ShuffleDown(index)
        | crate::GatherMode::ShuffleUp(index)
        | crate::GatherMode::ShuffleXor(index) => {
            let index_ty = context.resolve_type(index, &self.valid_expression_set)?;
            match *index_ty {
                crate::TypeInner::Scalar(crate::Scalar::U32) => {}
                _ => {
                    log::error!(
                        "Subgroup gather index type {:?}, expected unsigned int",
                        index_ty
                    );
                    // NOTE(review): the error names `argument` but the span
                    // points at `index` — confirm this is intentional.
                    return Err(SubgroupError::InvalidOperand(argument)
                        .with_span_handle(index, context.expressions)
                        .into_other());
                }
            }
        }
    }
    // Gathers move numeric data: scalar or vector of uint/sint/float only.
    let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
    if !matches!(*argument_inner,
        crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. }
        if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float)
    ) {
        log::error!("Subgroup gather operand type {:?}", argument_inner);
        return Err(SubgroupError::InvalidOperand(argument)
            .with_span_handle(argument, context.expressions)
            .into_other());
    }
    // Mark the result live and require its declared type to equal the
    // argument's type.
    self.emit_expression(result, context)?;
    match context.expressions[result] {
        crate::Expression::SubgroupOperationResult { ty }
            if { &context.types[ty].inner == argument_inner } => {}
        _ => {
            return Err(SubgroupError::ResultTypeMismatch(result)
                .with_span_handle(result, context.expressions)
                .into_other())
        }
    }
    Ok(())
}
fn validate_block_impl(
&mut self,
@ -613,8 +751,30 @@ impl super::Validator {
stages &= super::ShaderStages::FRAGMENT;
finished = true;
}
S::Barrier(_) => {
S::Barrier(barrier) => {
stages &= super::ShaderStages::COMPUTE;
if barrier.contains(crate::Barrier::SUB_GROUP) {
if !self.capabilities.contains(
super::Capabilities::SUBGROUP | super::Capabilities::SUBGROUP_BARRIER,
) {
return Err(FunctionError::MissingCapability(
super::Capabilities::SUBGROUP
| super::Capabilities::SUBGROUP_BARRIER,
)
.with_span_static(span, "missing capability for this operation"));
}
if !self
.subgroup_operations
.contains(super::SubgroupOperationSet::BASIC)
{
return Err(FunctionError::InvalidSubgroup(
SubgroupError::UnsupportedOperation(
super::SubgroupOperationSet::BASIC,
),
)
.with_span_static(span, "support for this operation is not present"));
}
}
}
S::Store { pointer, value } => {
let mut current = pointer;
@ -904,6 +1064,86 @@ impl super::Validator {
crate::RayQueryFunction::Terminate => {}
}
}
S::SubgroupBallot { result, predicate } => {
stages &= self.subgroup_stages;
if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
return Err(FunctionError::MissingCapability(
super::Capabilities::SUBGROUP,
)
.with_span_static(span, "missing capability for this operation"));
}
if !self
.subgroup_operations
.contains(super::SubgroupOperationSet::BALLOT)
{
return Err(FunctionError::InvalidSubgroup(
SubgroupError::UnsupportedOperation(
super::SubgroupOperationSet::BALLOT,
),
)
.with_span_static(span, "support for this operation is not present"));
}
if let Some(predicate) = predicate {
let predicate_inner =
context.resolve_type(predicate, &self.valid_expression_set)?;
if !matches!(
*predicate_inner,
crate::TypeInner::Scalar(crate::Scalar::BOOL,)
) {
log::error!(
"Subgroup ballot predicate type {:?} expected bool",
predicate_inner
);
return Err(SubgroupError::InvalidOperand(predicate)
.with_span_handle(predicate, context.expressions)
.into_other());
}
}
self.emit_expression(result, context)?;
}
S::SubgroupCollectiveOperation {
ref op,
ref collective_op,
argument,
result,
} => {
stages &= self.subgroup_stages;
if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
return Err(FunctionError::MissingCapability(
super::Capabilities::SUBGROUP,
)
.with_span_static(span, "missing capability for this operation"));
}
let operation = op.required_operations();
if !self.subgroup_operations.contains(operation) {
return Err(FunctionError::InvalidSubgroup(
SubgroupError::UnsupportedOperation(operation),
)
.with_span_static(span, "support for this operation is not present"));
}
self.validate_subgroup_operation(op, collective_op, argument, result, context)?;
}
S::SubgroupGather {
ref mode,
argument,
result,
} => {
stages &= self.subgroup_stages;
if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
return Err(FunctionError::MissingCapability(
super::Capabilities::SUBGROUP,
)
.with_span_static(span, "missing capability for this operation"));
}
let operation = mode.required_operations();
if !self.subgroup_operations.contains(operation) {
return Err(FunctionError::InvalidSubgroup(
SubgroupError::UnsupportedOperation(operation),
)
.with_span_static(span, "support for this operation is not present"));
}
self.validate_subgroup_gather(mode, argument, result, context)?;
}
}
}
Ok(BlockInfo { stages, finished })

View File

@ -420,6 +420,8 @@ impl super::Validator {
}
crate::Expression::AtomicResult { .. }
| crate::Expression::RayQueryProceedResult
| crate::Expression::SubgroupBallotResult
| crate::Expression::SubgroupOperationResult { .. }
| crate::Expression::WorkGroupUniformLoadResult { .. } => (),
crate::Expression::ArrayLength(array) => {
handle.check_dep(array)?;
@ -565,6 +567,38 @@ impl super::Validator {
}
Ok(())
}
crate::Statement::SubgroupBallot { result, predicate } => {
validate_expr_opt(predicate)?;
validate_expr(result)?;
Ok(())
}
crate::Statement::SubgroupCollectiveOperation {
op: _,
collective_op: _,
argument,
result,
} => {
validate_expr(argument)?;
validate_expr(result)?;
Ok(())
}
crate::Statement::SubgroupGather {
mode,
argument,
result,
} => {
validate_expr(argument)?;
match mode {
crate::GatherMode::BroadcastFirst => {}
crate::GatherMode::Broadcast(index)
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => validate_expr(index)?,
}
validate_expr(result)?;
Ok(())
}
crate::Statement::Break
| crate::Statement::Continue
| crate::Statement::Kill

View File

@ -77,6 +77,8 @@ pub enum VaryingError {
location: u32,
attribute: &'static str,
},
#[error("Workgroup size is multi dimensional, @builtin(subgroup_id) and @builtin(subgroup_invocation_id) are not supported.")]
InvalidMultiDimensionalSubgroupBuiltIn,
}
#[derive(Clone, Debug, thiserror::Error)]
@ -140,6 +142,7 @@ struct VaryingContext<'a> {
impl VaryingContext<'_> {
fn validate_impl(
&mut self,
ep: &crate::EntryPoint,
ty: Handle<crate::Type>,
binding: &crate::Binding,
) -> Result<(), VaryingError> {
@ -167,12 +170,24 @@ impl VaryingContext<'_> {
Bi::PrimitiveIndex => Capabilities::PRIMITIVE_INDEX,
Bi::ViewIndex => Capabilities::MULTIVIEW,
Bi::SampleIndex => Capabilities::MULTISAMPLED_SHADING,
Bi::NumSubgroups
| Bi::SubgroupId
| Bi::SubgroupSize
| Bi::SubgroupInvocationId => Capabilities::SUBGROUP,
_ => Capabilities::empty(),
};
if !self.capabilities.contains(required) {
return Err(VaryingError::UnsupportedCapability(required));
}
if matches!(
built_in,
crate::BuiltIn::SubgroupId | crate::BuiltIn::SubgroupInvocationId
) && ep.workgroup_size[1..].iter().any(|&s| s > 1)
{
return Err(VaryingError::InvalidMultiDimensionalSubgroupBuiltIn);
}
let (visible, type_good) = match built_in {
Bi::BaseInstance | Bi::BaseVertex | Bi::InstanceIndex | Bi::VertexIndex => (
self.stage == St::Vertex && !self.output,
@ -254,6 +269,17 @@ impl VaryingContext<'_> {
scalar: crate::Scalar::U32,
},
),
Bi::NumSubgroups | Bi::SubgroupId => (
self.stage == St::Compute && !self.output,
*ty_inner == Ti::Scalar(crate::Scalar::U32),
),
Bi::SubgroupSize | Bi::SubgroupInvocationId => (
match self.stage {
St::Compute | St::Fragment => !self.output,
St::Vertex => false,
},
*ty_inner == Ti::Scalar(crate::Scalar::U32),
),
};
if !visible {
@ -354,13 +380,14 @@ impl VaryingContext<'_> {
fn validate(
&mut self,
ep: &crate::EntryPoint,
ty: Handle<crate::Type>,
binding: Option<&crate::Binding>,
) -> Result<(), WithSpan<VaryingError>> {
let span_context = self.types.get_span_context(ty);
match binding {
Some(binding) => self
.validate_impl(ty, binding)
.validate_impl(ep, ty, binding)
.map_err(|e| e.with_span_context(span_context)),
None => {
match self.types[ty].inner {
@ -377,7 +404,7 @@ impl VaryingContext<'_> {
}
}
Some(ref binding) => self
.validate_impl(member.ty, binding)
.validate_impl(ep, member.ty, binding)
.map_err(|e| e.with_span_context(span_context))?,
}
}
@ -609,7 +636,7 @@ impl super::Validator {
capabilities: self.capabilities,
flags: self.flags,
};
ctx.validate(fa.ty, fa.binding.as_ref())
ctx.validate(ep, fa.ty, fa.binding.as_ref())
.map_err_inner(|e| EntryPointError::Argument(index as u32, e).with_span())?;
}
@ -627,7 +654,7 @@ impl super::Validator {
capabilities: self.capabilities,
flags: self.flags,
};
ctx.validate(fr.ty, fr.binding.as_ref())
ctx.validate(ep, fr.ty, fr.binding.as_ref())
.map_err_inner(|e| EntryPointError::Result(e).with_span())?;
if ctx.second_blend_source {
// Only the first location may be used when dual source blending

View File

@ -77,7 +77,7 @@ bitflags::bitflags! {
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Capabilities: u16 {
pub struct Capabilities: u32 {
/// Support for [`AddressSpace:PushConstant`].
const PUSH_CONSTANT = 0x1;
/// Float values with width = 8.
@ -110,6 +110,10 @@ bitflags::bitflags! {
const CUBE_ARRAY_TEXTURES = 0x4000;
/// Support for 64-bit signed and unsigned integers.
const SHADER_INT64 = 0x8000;
/// Support for subgroup operations.
const SUBGROUP = 0x10000;
/// Support for subgroup barriers.
const SUBGROUP_BARRIER = 0x20000;
}
}
@ -119,6 +123,57 @@ impl Default for Capabilities {
}
}
bitflags::bitflags! {
    /// Supported subgroup operations
    ///
    /// Families of subgroup instructions a target supports. The validator
    /// rejects any subgroup statement whose required family (as reported by
    /// `required_operations`) is not present in this set.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
    pub struct SubgroupOperationSet: u8 {
        /// Elect, Barrier
        const BASIC = 1 << 0;
        /// Any, All
        const VOTE = 1 << 1;
        /// Reductions and scans (add, mul, min, max, and, or, xor)
        const ARITHMETIC = 1 << 2;
        /// Ballot, broadcast, broadcast-first
        const BALLOT = 1 << 3;
        /// Shuffle, shuffle xor
        const SHUFFLE = 1 << 4;
        /// Shuffle up, shuffle down
        const SHUFFLE_RELATIVE = 1 << 5;
        // We don't support these operations yet
        // /// Clustered
        // const CLUSTERED = 1 << 6;
        // /// Quad supported
        // const QUAD_FRAGMENT_COMPUTE = 1 << 7;
        // /// Quad supported in all stages
        // const QUAD_ALL_STAGES = 1 << 8;
    }
}
impl super::SubgroupOperation {
    /// Returns the [`SubgroupOperationSet`] family that must be supported
    /// for this operation to pass validation.
    const fn required_operations(&self) -> SubgroupOperationSet {
        use SubgroupOperationSet as S;
        match *self {
            // Boolean votes.
            Self::All | Self::Any => S::VOTE,
            // Numeric/bitwise reductions and scans.
            Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
                S::ARITHMETIC
            }
        }
    }
}
impl super::GatherMode {
    /// Returns the [`SubgroupOperationSet`] family that must be supported
    /// for this gather mode to pass validation.
    const fn required_operations(&self) -> SubgroupOperationSet {
        use SubgroupOperationSet as S;
        match *self {
            // Broadcasts are grouped with ballot support.
            Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
            Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
            Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
        }
    }
}
bitflags::bitflags! {
/// Validation flags.
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
@ -166,6 +221,8 @@ impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
pub struct Validator {
flags: ValidationFlags,
capabilities: Capabilities,
subgroup_stages: ShaderStages,
subgroup_operations: SubgroupOperationSet,
types: Vec<r#type::TypeInfo>,
layouter: Layouter,
location_mask: BitSet,
@ -317,6 +374,8 @@ impl Validator {
Validator {
flags,
capabilities,
subgroup_stages: ShaderStages::empty(),
subgroup_operations: SubgroupOperationSet::empty(),
types: Vec::new(),
layouter: Layouter::default(),
location_mask: BitSet::new(),
@ -329,6 +388,16 @@ impl Validator {
}
}
/// Sets the shader stages in which subgroup operations and subgroup
/// builtins are permitted. Defaults to no stages; builder-style,
/// returns `&mut self` for chaining.
pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
    self.subgroup_stages = stages;
    self
}
/// Sets the supported [`SubgroupOperationSet`] families. Defaults to the
/// empty set; builder-style, returns `&mut self` for chaining.
pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
    self.subgroup_operations = operations;
    self
}
/// Reset the validator internals
pub fn reset(&mut self) {
self.types.clear();

View File

@ -0,0 +1,27 @@
// Snapshot-test parameters: backend versions/options used when translating
// the matching .wgsl input. `god_mode` presumably enables all capabilities
// for validation — TODO confirm against the test harness.
(
    god_mode: true,
    spv: (
        version: (1, 3),
    ),
    msl: (
        lang_version: (2, 4),
        per_entry_point_map: {},
        inline_samplers: [],
        spirv_cross_compatibility: false,
        fake_missing_bindings: false,
        zero_initialize_workgroup_memory: true,
    ),
    glsl: (
        version: Desktop(430),
        writer_flags: (""),
        binding_map: { },
        zero_initialize_workgroup_memory: true,
    ),
    hlsl: (
        shader_model: V6_0,
        binding_map: {},
        fake_missing_bindings: true,
        special_constants_binding: None,
        zero_initialize_workgroup_memory: true,
    ),
)

Binary file not shown.

View File

@ -0,0 +1,75 @@
; SPIR-V
; Version: 1.3
; Generator: rspirv
; Bound: 54
; NOTE(review): generated snapshot output of the subgroup-operations test —
; regenerate via the test harness rather than hand-editing (TODO confirm tooling).
OpCapability Shader
OpCapability GroupNonUniform
OpCapability GroupNonUniformBallot
OpCapability GroupNonUniformVote
OpCapability GroupNonUniformArithmetic
OpCapability GroupNonUniformShuffle
OpCapability GroupNonUniformShuffleRelative
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %15 "main" %6 %9 %11 %13
OpExecutionMode %15 LocalSize 1 1 1
OpDecorate %6 BuiltIn NumSubgroups
OpDecorate %9 BuiltIn SubgroupId
OpDecorate %11 BuiltIn SubgroupSize
OpDecorate %13 BuiltIn SubgroupLocalInvocationId
%2 = OpTypeVoid
%3 = OpTypeInt 32 0
%4 = OpTypeBool
%7 = OpTypePointer Input %3
%6 = OpVariable %7 Input
%9 = OpVariable %7 Input
%11 = OpVariable %7 Input
%13 = OpVariable %7 Input
%16 = OpTypeFunction %2
%17 = OpConstant %3 1
%18 = OpConstant %3 0
%19 = OpConstant %3 4
%21 = OpConstant %3 3
%22 = OpConstant %3 2
%23 = OpConstant %3 8
%26 = OpTypeVector %3 4
%28 = OpConstantTrue %4
%15 = OpFunction %2 None %16
%5 = OpLabel
%8 = OpLoad %3 %6
%10 = OpLoad %3 %9
%12 = OpLoad %3 %11
%14 = OpLoad %3 %13
OpBranch %20
%20 = OpLabel
OpControlBarrier %21 %22 %23
%24 = OpBitwiseAnd %3 %14 %17
%25 = OpIEqual %4 %24 %17
%27 = OpGroupNonUniformBallot %26 %21 %25
%29 = OpGroupNonUniformBallot %26 %21 %28
%30 = OpINotEqual %4 %14 %18
%31 = OpGroupNonUniformAll %4 %21 %30
%32 = OpIEqual %4 %14 %18
%33 = OpGroupNonUniformAny %4 %21 %32
%34 = OpGroupNonUniformIAdd %3 %21 Reduce %14
%35 = OpGroupNonUniformIMul %3 %21 Reduce %14
%36 = OpGroupNonUniformUMin %3 %21 Reduce %14
%37 = OpGroupNonUniformUMax %3 %21 Reduce %14
%38 = OpGroupNonUniformBitwiseAnd %3 %21 Reduce %14
%39 = OpGroupNonUniformBitwiseOr %3 %21 Reduce %14
%40 = OpGroupNonUniformBitwiseXor %3 %21 Reduce %14
%41 = OpGroupNonUniformIAdd %3 %21 ExclusiveScan %14
%42 = OpGroupNonUniformIMul %3 %21 ExclusiveScan %14
%43 = OpGroupNonUniformIAdd %3 %21 InclusiveScan %14
%44 = OpGroupNonUniformIMul %3 %21 InclusiveScan %14
%45 = OpGroupNonUniformBroadcastFirst %3 %21 %14
%46 = OpGroupNonUniformBroadcast %3 %21 %14 %19
%47 = OpISub %3 %12 %17
%48 = OpISub %3 %47 %14
%49 = OpGroupNonUniformShuffle %3 %21 %14 %48
%50 = OpGroupNonUniformShuffleDown %3 %21 %14 %17
%51 = OpGroupNonUniformShuffleUp %3 %21 %14 %17
%52 = OpISub %3 %12 %17
%53 = OpGroupNonUniformShuffleXor %3 %21 %14 %52
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,27 @@
// Snapshot-test parameters: backend versions/options used when translating
// the matching .wgsl input (same settings as the sibling subgroup test).
(
    god_mode: true,
    spv: (
        version: (1, 3),
    ),
    msl: (
        lang_version: (2, 4),
        per_entry_point_map: {},
        inline_samplers: [],
        spirv_cross_compatibility: false,
        fake_missing_bindings: false,
        zero_initialize_workgroup_memory: true,
    ),
    glsl: (
        version: Desktop(430),
        writer_flags: (""),
        binding_map: { },
        zero_initialize_workgroup_memory: true,
    ),
    hlsl: (
        shader_model: V6_0,
        binding_map: {},
        fake_missing_bindings: true,
        special_constants_binding: None,
        zero_initialize_workgroup_memory: true,
    ),
)

View File

@ -0,0 +1,37 @@
// Test fixture exercising every subgroup builtin value and every subgroup
// operation the front end accepts. Results are deliberately discarded —
// presumably only translation/validation is under test (snapshot input).
struct Structure {
    @builtin(num_subgroups) num_subgroups: u32,
    @builtin(subgroup_size) subgroup_size: u32,
};

@compute @workgroup_size(1)
fn main(
    sizes: Structure,
    @builtin(subgroup_id) subgroup_id: u32,
    @builtin(subgroup_invocation_id) subgroup_invocation_id: u32,
) {
    subgroupBarrier();
    // Ballot with and without a predicate.
    subgroupBallot((subgroup_invocation_id & 1u) == 1u);
    subgroupBallot();
    // Votes.
    subgroupAll(subgroup_invocation_id != 0u);
    subgroupAny(subgroup_invocation_id == 0u);
    // Reductions.
    subgroupAdd(subgroup_invocation_id);
    subgroupMul(subgroup_invocation_id);
    subgroupMin(subgroup_invocation_id);
    subgroupMax(subgroup_invocation_id);
    subgroupAnd(subgroup_invocation_id);
    subgroupOr(subgroup_invocation_id);
    subgroupXor(subgroup_invocation_id);
    // Scans.
    subgroupExclusiveAdd(subgroup_invocation_id);
    subgroupExclusiveMul(subgroup_invocation_id);
    subgroupInclusiveAdd(subgroup_invocation_id);
    subgroupInclusiveMul(subgroup_invocation_id);
    // Gathers.
    subgroupBroadcastFirst(subgroup_invocation_id);
    subgroupBroadcast(subgroup_invocation_id, 4u);
    subgroupShuffle(subgroup_invocation_id, sizes.subgroup_size - 1u - subgroup_invocation_id);
    subgroupShuffleDown(subgroup_invocation_id, 1u);
    subgroupShuffleUp(subgroup_invocation_id, 1u);
    subgroupShuffleXor(subgroup_invocation_id, sizes.subgroup_size - 1u);
}

View File

@ -0,0 +1,58 @@
#version 430 core
#extension GL_ARB_compute_shader : require
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_vote : require
#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_KHR_shader_subgroup_ballot : require
#extension GL_KHR_shader_subgroup_shuffle : require
#extension GL_KHR_shader_subgroup_shuffle_relative : require
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
// NOTE(review): generated GLSL backend snapshot output — regenerate via the
// test harness rather than hand-editing (TODO confirm tooling).

uint num_subgroups_1 = 0u;
uint subgroup_id_1 = 0u;
uint subgroup_size_1 = 0u;
uint subgroup_invocation_id_1 = 0u;

void main_1() {
    uint _e5 = subgroup_size_1;
    uint _e6 = subgroup_invocation_id_1;
    uvec4 _e9 = subgroupBallot(((_e6 & 1u) == 1u));
    uvec4 _e10 = subgroupBallot(true);
    bool _e12 = subgroupAll((_e6 != 0u));
    bool _e14 = subgroupAny((_e6 == 0u));
    uint _e15 = subgroupAdd(_e6);
    uint _e16 = subgroupMul(_e6);
    uint _e17 = subgroupMin(_e6);
    uint _e18 = subgroupMax(_e6);
    uint _e19 = subgroupAnd(_e6);
    uint _e20 = subgroupOr(_e6);
    uint _e21 = subgroupXor(_e6);
    uint _e22 = subgroupExclusiveAdd(_e6);
    uint _e23 = subgroupExclusiveMul(_e6);
    uint _e24 = subgroupInclusiveAdd(_e6);
    uint _e25 = subgroupInclusiveMul(_e6);
    uint _e26 = subgroupBroadcastFirst(_e6);
    uint _e27 = subgroupBroadcast(_e6, 4u);
    uint _e30 = subgroupShuffle(_e6, ((_e5 - 1u) - _e6));
    uint _e31 = subgroupShuffleDown(_e6, 1u);
    uint _e32 = subgroupShuffleUp(_e6, 1u);
    uint _e34 = subgroupShuffleXor(_e6, (_e5 - 1u));
    return;
}

void main() {
    uint num_subgroups = gl_NumSubgroups;
    uint subgroup_id = gl_SubgroupID;
    uint subgroup_size = gl_SubgroupSize;
    uint subgroup_invocation_id = gl_SubgroupInvocationID;
    num_subgroups_1 = num_subgroups;
    subgroup_id_1 = subgroup_id;
    subgroup_size_1 = subgroup_size;
    subgroup_invocation_id_1 = subgroup_invocation_id;
    main_1();
}

View File

@ -0,0 +1,45 @@
#version 430 core
#extension GL_ARB_compute_shader : require
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_vote : require
#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_KHR_shader_subgroup_ballot : require
#extension GL_KHR_shader_subgroup_shuffle : require
#extension GL_KHR_shader_subgroup_shuffle_relative : require
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
// NOTE(review): generated GLSL backend snapshot output — regenerate via the
// test harness rather than hand-editing (TODO confirm tooling).

struct Structure {
    uint num_subgroups;
    uint subgroup_size;
};

void main() {
    Structure sizes = Structure(gl_NumSubgroups, gl_SubgroupSize);
    uint subgroup_id = gl_SubgroupID;
    uint subgroup_invocation_id = gl_SubgroupInvocationID;
    subgroupMemoryBarrier();
    barrier();
    uvec4 _e7 = subgroupBallot(((subgroup_invocation_id & 1u) == 1u));
    uvec4 _e8 = subgroupBallot(true);
    bool _e11 = subgroupAll((subgroup_invocation_id != 0u));
    bool _e14 = subgroupAny((subgroup_invocation_id == 0u));
    uint _e15 = subgroupAdd(subgroup_invocation_id);
    uint _e16 = subgroupMul(subgroup_invocation_id);
    uint _e17 = subgroupMin(subgroup_invocation_id);
    uint _e18 = subgroupMax(subgroup_invocation_id);
    uint _e19 = subgroupAnd(subgroup_invocation_id);
    uint _e20 = subgroupOr(subgroup_invocation_id);
    uint _e21 = subgroupXor(subgroup_invocation_id);
    uint _e22 = subgroupExclusiveAdd(subgroup_invocation_id);
    uint _e23 = subgroupExclusiveMul(subgroup_invocation_id);
    uint _e24 = subgroupInclusiveAdd(subgroup_invocation_id);
    uint _e25 = subgroupInclusiveMul(subgroup_invocation_id);
    uint _e26 = subgroupBroadcastFirst(subgroup_invocation_id);
    uint _e28 = subgroupBroadcast(subgroup_invocation_id, 4u);
    uint _e33 = subgroupShuffle(subgroup_invocation_id, ((sizes.subgroup_size - 1u) - subgroup_invocation_id));
    uint _e35 = subgroupShuffleDown(subgroup_invocation_id, 1u);
    uint _e37 = subgroupShuffleUp(subgroup_invocation_id, 1u);
    uint _e41 = subgroupShuffleXor(subgroup_invocation_id, (sizes.subgroup_size - 1u));
    return;
}

View File

@ -0,0 +1,50 @@
// NOTE(review): generated HLSL backend snapshot output (Wave intrinsics) —
// regenerate via the test harness rather than hand-editing (TODO confirm tooling).
static uint num_subgroups_1 = (uint)0;
static uint subgroup_id_1 = (uint)0;
static uint subgroup_size_1 = (uint)0;
static uint subgroup_invocation_id_1 = (uint)0;

struct ComputeInput_main {
    uint __local_invocation_index : SV_GroupIndex;
};

void main_1()
{
    uint _expr5 = subgroup_size_1;
    uint _expr6 = subgroup_invocation_id_1;
    const uint4 _e9 = WaveActiveBallot(((_expr6 & 1u) == 1u));
    const uint4 _e10 = WaveActiveBallot(true);
    const bool _e12 = WaveActiveAllTrue((_expr6 != 0u));
    const bool _e14 = WaveActiveAnyTrue((_expr6 == 0u));
    const uint _e15 = WaveActiveSum(_expr6);
    const uint _e16 = WaveActiveProduct(_expr6);
    const uint _e17 = WaveActiveMin(_expr6);
    const uint _e18 = WaveActiveMax(_expr6);
    const uint _e19 = WaveActiveBitAnd(_expr6);
    const uint _e20 = WaveActiveBitOr(_expr6);
    const uint _e21 = WaveActiveBitXor(_expr6);
    const uint _e22 = WavePrefixSum(_expr6);
    const uint _e23 = WavePrefixProduct(_expr6);
    // Inclusive scans are emulated: exclusive prefix plus/times own value.
    const uint _e24 = _expr6 + WavePrefixSum(_expr6);
    const uint _e25 = _expr6 * WavePrefixProduct(_expr6);
    const uint _e26 = WaveReadLaneFirst(_expr6);
    const uint _e27 = WaveReadLaneAt(_expr6, 4u);
    const uint _e30 = WaveReadLaneAt(_expr6, ((_expr5 - 1u) - _expr6));
    // Relative/xor shuffles are emulated via WaveReadLaneAt on a computed lane.
    const uint _e31 = WaveReadLaneAt(_expr6, WaveGetLaneIndex() + 1u);
    const uint _e32 = WaveReadLaneAt(_expr6, WaveGetLaneIndex() - 1u);
    const uint _e34 = WaveReadLaneAt(_expr6, WaveGetLaneIndex() ^ (_expr5 - 1u));
    return;
}

[numthreads(1, 1, 1)]
void main(ComputeInput_main computeinput_main)
{
    uint num_subgroups = (1u + WaveGetLaneCount() - 1u) / WaveGetLaneCount();
    uint subgroup_id = computeinput_main.__local_invocation_index / WaveGetLaneCount();
    uint subgroup_size = WaveGetLaneCount();
    uint subgroup_invocation_id = WaveGetLaneIndex();
    num_subgroups_1 = num_subgroups;
    subgroup_id_1 = subgroup_id;
    subgroup_size_1 = subgroup_size;
    subgroup_invocation_id_1 = subgroup_invocation_id;
    main_1();
}

View File

@ -0,0 +1,12 @@
(
vertex:[
],
fragment:[
],
compute:[
(
entry_point:"main",
target_profile:"cs_6_0",
),
],
)

View File

@ -0,0 +1,38 @@
// Two subgroup built-ins packed into one struct, mirroring the WGSL entry
// point's `Structure` argument.
struct Structure {
uint num_subgroups;
uint subgroup_size;
};
struct ComputeInput_main {
uint __local_invocation_index : SV_GroupIndex;
};
// Entry point exercising the wave (subgroup) intrinsics directly on the
// built-in values; every result is computed and discarded.
[numthreads(1, 1, 1)]
void main(ComputeInput_main computeinput_main)
{
// num_subgroups = ceil(workgroup size / lane count); workgroup size is 1 here.
Structure sizes = { (1u + WaveGetLaneCount() - 1u) / WaveGetLaneCount(), WaveGetLaneCount() };
uint subgroup_id = computeinput_main.__local_invocation_index / WaveGetLaneCount();
uint subgroup_invocation_id = WaveGetLaneIndex();
// NOTE(review): the corresponding WGSL source calls subgroupBarrier() at
// this point, but no barrier is emitted here — confirm that is intended
// (HLSL exposes no direct wave-barrier intrinsic).
// Ballots, votes, and reductions over the lane index.
const uint4 _e7 = WaveActiveBallot(((subgroup_invocation_id & 1u) == 1u));
const uint4 _e8 = WaveActiveBallot(true);
const bool _e11 = WaveActiveAllTrue((subgroup_invocation_id != 0u));
const bool _e14 = WaveActiveAnyTrue((subgroup_invocation_id == 0u));
const uint _e15 = WaveActiveSum(subgroup_invocation_id);
const uint _e16 = WaveActiveProduct(subgroup_invocation_id);
const uint _e17 = WaveActiveMin(subgroup_invocation_id);
const uint _e18 = WaveActiveMax(subgroup_invocation_id);
const uint _e19 = WaveActiveBitAnd(subgroup_invocation_id);
const uint _e20 = WaveActiveBitOr(subgroup_invocation_id);
const uint _e21 = WaveActiveBitXor(subgroup_invocation_id);
// Inclusive scans (_e24/_e25) are emulated as value op exclusive-scan.
const uint _e22 = WavePrefixSum(subgroup_invocation_id);
const uint _e23 = WavePrefixProduct(subgroup_invocation_id);
const uint _e24 = subgroup_invocation_id + WavePrefixSum(subgroup_invocation_id);
const uint _e25 = subgroup_invocation_id * WavePrefixProduct(subgroup_invocation_id);
// Broadcast/shuffle; up/down/xor variants are built from WaveReadLaneAt.
const uint _e26 = WaveReadLaneFirst(subgroup_invocation_id);
const uint _e28 = WaveReadLaneAt(subgroup_invocation_id, 4u);
const uint _e33 = WaveReadLaneAt(subgroup_invocation_id, ((sizes.subgroup_size - 1u) - subgroup_invocation_id));
const uint _e35 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() + 1u);
const uint _e37 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() - 1u);
const uint _e41 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() ^ (sizes.subgroup_size - 1u));
return;
}

View File

@ -0,0 +1,12 @@
(
vertex:[
],
fragment:[
],
compute:[
(
entry_point:"main",
target_profile:"cs_6_0",
),
],
)

View File

@ -0,0 +1,55 @@
// language: metal2.4
#include <metal_stdlib>
#include <simd/simd.h>
using metal::uint;
// Exercises each simd-group operation once; the subgroup built-ins are
// passed in by reference from the kernel entry point below.
void main_1(
thread uint& subgroup_size_1,
thread uint& subgroup_invocation_id_1
) {
uint _e5 = subgroup_size_1;
uint _e6 = subgroup_invocation_id_1;
// Ballot results are cast through uint64_t into the x component of a uint4.
// NOTE(review): only the low 32 ballot bits survive the uint64_t -> uint
// conversion — verify behavior for simdgroups wider than 32 lanes.
metal::uint4 unnamed = uint4((uint64_t)metal::simd_ballot((_e6 & 1u) == 1u), 0, 0, 0);
metal::uint4 unnamed_1 = uint4((uint64_t)metal::simd_ballot(true), 0, 0, 0);
// Votes and arithmetic/bitwise reductions.
bool unnamed_2 = metal::simd_all(_e6 != 0u);
bool unnamed_3 = metal::simd_any(_e6 == 0u);
uint unnamed_4 = metal::simd_sum(_e6);
uint unnamed_5 = metal::simd_product(_e6);
uint unnamed_6 = metal::simd_min(_e6);
uint unnamed_7 = metal::simd_max(_e6);
uint unnamed_8 = metal::simd_and(_e6);
uint unnamed_9 = metal::simd_or(_e6);
uint unnamed_10 = metal::simd_xor(_e6);
// Metal has native exclusive and inclusive prefix scans.
uint unnamed_11 = metal::simd_prefix_exclusive_sum(_e6);
uint unnamed_12 = metal::simd_prefix_exclusive_product(_e6);
uint unnamed_13 = metal::simd_prefix_inclusive_sum(_e6);
uint unnamed_14 = metal::simd_prefix_inclusive_product(_e6);
// Broadcasts and shuffles map one-to-one onto simd_* intrinsics.
uint unnamed_15 = metal::simd_broadcast_first(_e6);
uint unnamed_16 = metal::simd_broadcast(_e6, 4u);
uint unnamed_17 = metal::simd_shuffle(_e6, (_e5 - 1u) - _e6);
uint unnamed_18 = metal::simd_shuffle_down(_e6, 1u);
uint unnamed_19 = metal::simd_shuffle_up(_e6, 1u);
uint unnamed_20 = metal::simd_shuffle_xor(_e6, _e5 - 1u);
return;
}
struct main_Input {
};
// Kernel entry point: copies the simd-group built-ins into locals acting as
// the module's private variables, then runs main_1 on two of them.
kernel void main_(
uint num_subgroups [[simdgroups_per_threadgroup]]
, uint subgroup_id [[simdgroup_index_in_threadgroup]]
, uint subgroup_size [[threads_per_simdgroup]]
, uint subgroup_invocation_id [[thread_index_in_simdgroup]]
) {
uint num_subgroups_1 = {};
uint subgroup_id_1 = {};
uint subgroup_size_1 = {};
uint subgroup_invocation_id_1 = {};
num_subgroups_1 = num_subgroups;
subgroup_id_1 = subgroup_id;
subgroup_size_1 = subgroup_size;
subgroup_invocation_id_1 = subgroup_invocation_id;
main_1(subgroup_size_1, subgroup_invocation_id_1);
}

View File

@ -0,0 +1,44 @@
// language: metal2.4
#include <metal_stdlib>
#include <simd/simd.h>
using metal::uint;
// num_subgroups / subgroup_size built-ins packed into a struct, mirroring
// the WGSL entry point's `Structure` argument.
struct Structure {
uint num_subgroups;
uint subgroup_size;
};
struct main_Input {
};
// Kernel exercising each simd-group operation once on the built-in values;
// every result is computed and discarded.
kernel void main_(
uint num_subgroups [[simdgroups_per_threadgroup]]
, uint subgroup_size [[threads_per_simdgroup]]
, uint subgroup_id [[simdgroup_index_in_threadgroup]]
, uint subgroup_invocation_id [[thread_index_in_simdgroup]]
) {
const Structure sizes = { num_subgroups, subgroup_size };
// subgroupBarrier(): simdgroup barrier with a threadgroup-memory fence.
metal::simdgroup_barrier(metal::mem_flags::mem_threadgroup);
// NOTE(review): the ballot result is narrowed from uint64_t to the uint x
// component — verify lanes >= 32 are not silently dropped.
metal::uint4 unnamed = uint4((uint64_t)metal::simd_ballot((subgroup_invocation_id & 1u) == 1u), 0, 0, 0);
metal::uint4 unnamed_1 = uint4((uint64_t)metal::simd_ballot(true), 0, 0, 0);
// Votes and reductions.
bool unnamed_2 = metal::simd_all(subgroup_invocation_id != 0u);
bool unnamed_3 = metal::simd_any(subgroup_invocation_id == 0u);
uint unnamed_4 = metal::simd_sum(subgroup_invocation_id);
uint unnamed_5 = metal::simd_product(subgroup_invocation_id);
uint unnamed_6 = metal::simd_min(subgroup_invocation_id);
uint unnamed_7 = metal::simd_max(subgroup_invocation_id);
uint unnamed_8 = metal::simd_and(subgroup_invocation_id);
uint unnamed_9 = metal::simd_or(subgroup_invocation_id);
uint unnamed_10 = metal::simd_xor(subgroup_invocation_id);
// Native exclusive and inclusive prefix scans.
uint unnamed_11 = metal::simd_prefix_exclusive_sum(subgroup_invocation_id);
uint unnamed_12 = metal::simd_prefix_exclusive_product(subgroup_invocation_id);
uint unnamed_13 = metal::simd_prefix_inclusive_sum(subgroup_invocation_id);
uint unnamed_14 = metal::simd_prefix_inclusive_product(subgroup_invocation_id);
// Broadcasts and shuffles.
uint unnamed_15 = metal::simd_broadcast_first(subgroup_invocation_id);
uint unnamed_16 = metal::simd_broadcast(subgroup_invocation_id, 4u);
uint unnamed_17 = metal::simd_shuffle(subgroup_invocation_id, (sizes.subgroup_size - 1u) - subgroup_invocation_id);
uint unnamed_18 = metal::simd_shuffle_down(subgroup_invocation_id, 1u);
uint unnamed_19 = metal::simd_shuffle_up(subgroup_invocation_id, 1u);
uint unnamed_20 = metal::simd_shuffle_xor(subgroup_invocation_id, sizes.subgroup_size - 1u);
return;
}

View File

@ -0,0 +1,81 @@
; SPIR-V
; Version: 1.3
; Generator: rspirv
; Bound: 58
; One GroupNonUniform* capability per subgroup-operation family used below.
OpCapability Shader
OpCapability GroupNonUniform
OpCapability GroupNonUniformBallot
OpCapability GroupNonUniformVote
OpCapability GroupNonUniformArithmetic
OpCapability GroupNonUniformShuffle
OpCapability GroupNonUniformShuffleRelative
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %17 "main" %8 %11 %13 %15
OpExecutionMode %17 LocalSize 1 1 1
OpMemberDecorate %4 0 Offset 0
OpMemberDecorate %4 1 Offset 4
; The four subgroup built-in inputs consumed by the entry point.
OpDecorate %8 BuiltIn NumSubgroups
OpDecorate %11 BuiltIn SubgroupSize
OpDecorate %13 BuiltIn SubgroupId
OpDecorate %15 BuiltIn SubgroupLocalInvocationId
%2 = OpTypeVoid
%3 = OpTypeInt 32 0
%4 = OpTypeStruct %3 %3
%5 = OpTypeBool
%9 = OpTypePointer Input %3
%8 = OpVariable %9 Input
%11 = OpVariable %9 Input
%13 = OpVariable %9 Input
%15 = OpVariable %9 Input
%18 = OpTypeFunction %2
%19 = OpConstant %3 1
%20 = OpConstant %3 0
%21 = OpConstant %3 4
; Per the SPIR-V enums: %23 = 3 (Subgroup scope), %24 = 2 (Workgroup scope),
; %25 = 8 (AcquireRelease memory semantics) — used by the barrier below.
%23 = OpConstant %3 3
%24 = OpConstant %3 2
%25 = OpConstant %3 8
%28 = OpTypeVector %3 4
%30 = OpConstantTrue %5
%17 = OpFunction %2 None %18
%6 = OpLabel
; Load num_subgroups/subgroup_size into the two-member struct (%7), plus the
; subgroup id (%14) and lane index (%16).
%10 = OpLoad %3 %8
%12 = OpLoad %3 %11
%7 = OpCompositeConstruct %4 %10 %12
%14 = OpLoad %3 %13
%16 = OpLoad %3 %15
OpBranch %22
%22 = OpLabel
; subgroupBarrier(): control barrier using the scope/semantics constants above.
OpControlBarrier %23 %24 %25
; Ballot of (lane & 1) == 1, then an unconditional (true) ballot.
%26 = OpBitwiseAnd %3 %16 %19
%27 = OpIEqual %5 %26 %19
%29 = OpGroupNonUniformBallot %28 %23 %27
%31 = OpGroupNonUniformBallot %28 %23 %30
; Votes, reductions, and scans over the lane index (%16).
%32 = OpINotEqual %5 %16 %20
%33 = OpGroupNonUniformAll %5 %23 %32
%34 = OpIEqual %5 %16 %20
%35 = OpGroupNonUniformAny %5 %23 %34
%36 = OpGroupNonUniformIAdd %3 %23 Reduce %16
%37 = OpGroupNonUniformIMul %3 %23 Reduce %16
%38 = OpGroupNonUniformUMin %3 %23 Reduce %16
%39 = OpGroupNonUniformUMax %3 %23 Reduce %16
%40 = OpGroupNonUniformBitwiseAnd %3 %23 Reduce %16
%41 = OpGroupNonUniformBitwiseOr %3 %23 Reduce %16
%42 = OpGroupNonUniformBitwiseXor %3 %23 Reduce %16
%43 = OpGroupNonUniformIAdd %3 %23 ExclusiveScan %16
%44 = OpGroupNonUniformIMul %3 %23 ExclusiveScan %16
%45 = OpGroupNonUniformIAdd %3 %23 InclusiveScan %16
%46 = OpGroupNonUniformIMul %3 %23 InclusiveScan %16
; Broadcasts and shuffles; the shuffle indices are recomputed each time from
; struct member 1 (subgroup_size) minus one.
%47 = OpGroupNonUniformBroadcastFirst %3 %23 %16
%48 = OpGroupNonUniformShuffle %3 %23 %16 %21
%49 = OpCompositeExtract %3 %7 1
%50 = OpISub %3 %49 %19
%51 = OpISub %3 %50 %16
%52 = OpGroupNonUniformShuffle %3 %23 %16 %51
%53 = OpGroupNonUniformShuffleDown %3 %23 %16 %19
%54 = OpGroupNonUniformShuffleUp %3 %23 %16 %19
%55 = OpCompositeExtract %3 %7 1
%56 = OpISub %3 %55 %19
%57 = OpGroupNonUniformShuffleXor %3 %23 %16 %56
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,40 @@
// Module-private copies of the subgroup built-ins, written by the entry
// point below before main_1 runs.
var<private> num_subgroups_1: u32;
var<private> subgroup_id_1: u32;
var<private> subgroup_size_1: u32;
var<private> subgroup_invocation_id_1: u32;
// Exercises every subgroup builtin function once; all results are unused.
fn main_1() {
let _e5 = subgroup_size_1;
let _e6 = subgroup_invocation_id_1;
// Ballot with a predicate, then the zero-argument (always-true) form.
let _e9 = subgroupBallot(((_e6 & 1u) == 1u));
let _e10 = subgroupBallot();
// Votes.
let _e12 = subgroupAll((_e6 != 0u));
let _e14 = subgroupAny((_e6 == 0u));
// Reductions.
let _e15 = subgroupAdd(_e6);
let _e16 = subgroupMul(_e6);
let _e17 = subgroupMin(_e6);
let _e18 = subgroupMax(_e6);
let _e19 = subgroupAnd(_e6);
let _e20 = subgroupOr(_e6);
let _e21 = subgroupXor(_e6);
// Exclusive and inclusive scans.
let _e22 = subgroupExclusiveAdd(_e6);
let _e23 = subgroupExclusiveMul(_e6);
let _e24 = subgroupInclusiveAdd(_e6);
let _e25 = subgroupInclusiveMul(_e6);
// Broadcasts and shuffles.
let _e26 = subgroupBroadcastFirst(_e6);
let _e27 = subgroupBroadcast(_e6, 4u);
let _e30 = subgroupShuffle(_e6, ((_e5 - 1u) - _e6));
let _e31 = subgroupShuffleDown(_e6, 1u);
let _e32 = subgroupShuffleUp(_e6, 1u);
let _e34 = subgroupShuffleXor(_e6, (_e5 - 1u));
return;
}
// Entry point: captures the four subgroup built-ins into the private vars.
@compute @workgroup_size(1, 1, 1)
fn main(@builtin(num_subgroups) num_subgroups: u32, @builtin(subgroup_id) subgroup_id: u32, @builtin(subgroup_size) subgroup_size: u32, @builtin(subgroup_invocation_id) subgroup_invocation_id: u32) {
num_subgroups_1 = num_subgroups;
subgroup_id_1 = subgroup_id;
subgroup_size_1 = subgroup_size;
subgroup_invocation_id_1 = subgroup_invocation_id;
main_1();
}

View File

@ -0,0 +1,31 @@
// The num_subgroups / subgroup_size built-ins, delivered via a struct
// entry-point argument instead of individual parameters.
struct Structure {
@builtin(num_subgroups) num_subgroups: u32,
@builtin(subgroup_size) subgroup_size: u32,
}
// Exercises subgroupBarrier plus every subgroup builtin function once; all
// results are unused.
@compute @workgroup_size(1, 1, 1)
fn main(sizes: Structure, @builtin(subgroup_id) subgroup_id: u32, @builtin(subgroup_invocation_id) subgroup_invocation_id: u32) {
subgroupBarrier();
// Ballot with a predicate, then the zero-argument (always-true) form.
let _e7 = subgroupBallot(((subgroup_invocation_id & 1u) == 1u));
let _e8 = subgroupBallot();
// Votes.
let _e11 = subgroupAll((subgroup_invocation_id != 0u));
let _e14 = subgroupAny((subgroup_invocation_id == 0u));
// Reductions.
let _e15 = subgroupAdd(subgroup_invocation_id);
let _e16 = subgroupMul(subgroup_invocation_id);
let _e17 = subgroupMin(subgroup_invocation_id);
let _e18 = subgroupMax(subgroup_invocation_id);
let _e19 = subgroupAnd(subgroup_invocation_id);
let _e20 = subgroupOr(subgroup_invocation_id);
let _e21 = subgroupXor(subgroup_invocation_id);
// Exclusive and inclusive scans.
let _e22 = subgroupExclusiveAdd(subgroup_invocation_id);
let _e23 = subgroupExclusiveMul(subgroup_invocation_id);
let _e24 = subgroupInclusiveAdd(subgroup_invocation_id);
let _e25 = subgroupInclusiveMul(subgroup_invocation_id);
// Broadcasts and shuffles.
let _e26 = subgroupBroadcastFirst(subgroup_invocation_id);
let _e28 = subgroupBroadcast(subgroup_invocation_id, 4u);
let _e33 = subgroupShuffle(subgroup_invocation_id, ((sizes.subgroup_size - 1u) - subgroup_invocation_id));
let _e35 = subgroupShuffleDown(subgroup_invocation_id, 1u);
let _e37 = subgroupShuffleUp(subgroup_invocation_id, 1u);
let _e41 = subgroupShuffleXor(subgroup_invocation_id, (sizes.subgroup_size - 1u));
return;
}

View File

@ -269,10 +269,18 @@ fn check_targets(
let params = input.read_parameters();
let name = &input.file_name;
let capabilities = if params.god_mode {
naga::valid::Capabilities::all()
let (capabilities, subgroup_stages, subgroup_operations) = if params.god_mode {
(
naga::valid::Capabilities::all(),
naga::valid::ShaderStages::all(),
naga::valid::SubgroupOperationSet::all(),
)
} else {
naga::valid::Capabilities::default()
(
naga::valid::Capabilities::default(),
naga::valid::ShaderStages::empty(),
naga::valid::SubgroupOperationSet::empty(),
)
};
#[cfg(feature = "serialize")]
@ -285,6 +293,8 @@ fn check_targets(
}
let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities)
.subgroup_stages(subgroup_stages)
.subgroup_operations(subgroup_operations)
.validate(module)
.unwrap_or_else(|err| {
panic!(
@ -308,6 +318,8 @@ fn check_targets(
}
naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities)
.subgroup_stages(subgroup_stages)
.subgroup_operations(subgroup_operations)
.validate(module)
.unwrap_or_else(|err| {
panic!(
@ -850,6 +862,10 @@ fn convert_wgsl() {
"int64",
Targets::SPIRV | Targets::HLSL | Targets::WGSL | Targets::METAL,
),
(
"subgroup-operations",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
(
"overrides",
Targets::IR
@ -957,6 +973,11 @@ fn convert_spv_all() {
);
convert_spv("builtin-accessed-outside-entrypoint", true, Targets::WGSL);
convert_spv("spec-constants", true, Targets::IR);
convert_spv(
"subgroup-operations-s",
false,
Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
);
}
#[cfg(feature = "glsl-in")]

View File

@ -33,6 +33,7 @@ mod scissor_tests;
mod shader;
mod shader_primitive_index;
mod shader_view_format;
mod subgroup_operations;
mod texture_bounds;
mod texture_view_creation;
mod transfer;

View File

@ -0,0 +1,126 @@
use std::{borrow::Cow, collections::HashMap, num::NonZeroU64};
use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters};
/// Invocations per workgroup; must match `@workgroup_size(128)` in
/// `shader.wgsl`. Each invocation reports one `u32` result mask.
const THREAD_COUNT: u64 = 128;
/// Number of sub-tests the shader runs; one result bit per test.
/// Keep in sync with the highest test index used in `shader.wgsl`.
const TEST_COUNT: u32 = 32;
/// Dispatches the subgroup-operations compute shader once and checks that
/// every invocation reports all `TEST_COUNT` sub-tests as passed.
#[gpu_test]
static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.features(wgpu::Features::SUBGROUP)
.limits(wgpu::Limits::downlevel_defaults())
.expect_fail(wgpu_test::FailureCase::molten_vk())
.expect_fail(
// Expect metal to fail on tests involving operations in divergent control flow
wgpu_test::FailureCase::backend(wgpu::Backends::METAL)
.panic("thread 0 failed tests: 27, 29,\nthread 1 failed tests: 27, 28, 29,\n"),
),
)
.run_sync(|ctx| {
let device = &ctx.device;
// One u32 result slot per invocation; COPY_SRC so the host can read it back.
let storage_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: THREAD_COUNT * std::mem::size_of::<u32>() as u64,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_DST
| wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
// Single read-write storage binding; min_binding_size pins the full array.
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("bind group layout"),
entries: &[wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: NonZeroU64::new(
THREAD_COUNT * std::mem::size_of::<u32>() as u64,
),
},
count: None,
}],
});
// The WGSL shader that actually exercises the subgroup operations.
let cs_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))),
});
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("main"),
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[],
});
let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: Some(&pipeline_layout),
module: &cs_module,
entry_point: "main",
constants: &HashMap::default(),
});
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: storage_buffer.as_entire_binding(),
}],
layout: &bind_group_layout,
label: Some("bind group"),
});
let mut encoder =
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
{
// One workgroup of THREAD_COUNT invocations (see @workgroup_size).
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: None,
timestamp_writes: None,
});
cpass.set_pipeline(&compute_pipeline);
cpass.set_bind_group(0, &bind_group, &[]);
cpass.dispatch_workgroups(1, 1, 1);
}
ctx.queue.submit(Some(encoder.finish()));
// Read the per-invocation masks back and diff them against the full mask.
wgpu::util::DownloadBuffer::read_buffer(
device,
&ctx.queue,
&storage_buffer.slice(..),
|mapping_buffer_view| {
let mapping_buffer_view = mapping_buffer_view.unwrap();
let result: &[u32; THREAD_COUNT as usize] =
bytemuck::from_bytes(&mapping_buffer_view);
let expected_mask = (1u64 << (TEST_COUNT)) - 1; // generate full mask
let expected_array = [expected_mask as u32; THREAD_COUNT as usize];
if result != &expected_array {
use std::fmt::Write;
let mut msg = String::new();
writeln!(
&mut msg,
"Got from GPU:\n{:x?}\n expected:\n{:x?}",
result, &expected_array,
)
.unwrap();
// For each failing thread, list the indices of its failed sub-tests
// (the bits that differ from the expected mask).
for (thread, (result, expected)) in result
.iter()
.zip(expected_array)
.enumerate()
.filter(|(_, (r, e))| *r != e)
{
write!(&mut msg, "thread {} failed tests:", thread).unwrap();
let difference = result ^ expected;
for i in (0..u32::BITS).filter(|i| (difference & (1 << i)) != 0) {
write!(&mut msg, " {},", i).unwrap();
}
writeln!(&mut msg).unwrap();
}
panic!("{}", msg);
}
},
);
});

View File

@ -0,0 +1,158 @@
// Output: one 32-bit pass/fail mask per invocation, read back by the host test.
@group(0)
@binding(0)
var<storage, read_write> storage_buffer: array<u32>;
// Workgroup scratch word for the barrier round-trip check (test 30 below).
var<workgroup> workgroup_buffer: u32;
// Records one sub-test outcome: sets bit `index` of `*mask` when `value` is
// true, and leaves the mask unchanged otherwise.
fn add_result_to_mask(mask: ptr<function, u32>, index: u32, value: bool) {
let bit = select(0u, 1u, value);
*mask = (*mask) | (bit << index);
}
// Each invocation runs the numbered sub-tests below and records one pass/fail
// bit per test in `passed`, which is written to storage_buffer at the end.
@compute
@workgroup_size(128)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_subgroups) num_subgroups: u32,
@builtin(subgroup_id) subgroup_id: u32,
@builtin(subgroup_size) subgroup_size: u32,
@builtin(subgroup_invocation_id) subgroup_invocation_id: u32,
) {
var passed = 0u;
var expected: u32;
// Tests 0-2: built-in consistency within the single 128-wide workgroup.
add_result_to_mask(&passed, 0u, num_subgroups == 128u / subgroup_size);
add_result_to_mask(&passed, 1u, subgroup_id == global_id.x / subgroup_size);
add_result_to_mask(&passed, 2u, subgroup_invocation_id == global_id.x % subgroup_size);
// Test 3: ballot of the odd lanes against a mask built serially lane by
// lane; the dot-product trick checks all four vector components at once.
var expected_ballot = vec4<u32>(0u);
for(var i = 0u; i < subgroup_size; i += 1u) {
expected_ballot[i / 32u] |= ((global_id.x - subgroup_invocation_id + i) & 1u) << (i % 32u);
}
add_result_to_mask(&passed, 3u, dot(vec4<u32>(1u), vec4<u32>(subgroupBallot((subgroup_invocation_id & 1u) == 1u) == expected_ballot)) == 4u);
// Tests 4-7: vote operations, covering both outcomes of all/any.
add_result_to_mask(&passed, 4u, subgroupAll(true));
add_result_to_mask(&passed, 5u, !subgroupAll(subgroup_invocation_id != 0u));
add_result_to_mask(&passed, 6u, subgroupAny(subgroup_invocation_id == 0u));
add_result_to_mask(&passed, 7u, !subgroupAny(false));
// Tests 8-14: reductions; `expected` is recomputed serially from the value
// (global_id.x + 1) held by each lane of this subgroup.
expected = 0u;
for(var i = 0u; i < subgroup_size; i += 1u) {
expected += global_id.x - subgroup_invocation_id + i + 1u;
}
add_result_to_mask(&passed, 8u, subgroupAdd(global_id.x + 1u) == expected);
expected = 1u;
for(var i = 0u; i < subgroup_size; i += 1u) {
expected *= global_id.x - subgroup_invocation_id + i + 1u;
}
add_result_to_mask(&passed, 9u, subgroupMul(global_id.x + 1u) == expected);
expected = 0u;
for(var i = 0u; i < subgroup_size; i += 1u) {
expected = max(expected, global_id.x - subgroup_invocation_id + i + 1u);
}
add_result_to_mask(&passed, 10u, subgroupMax(global_id.x + 1u) == expected);
expected = 0xFFFFFFFFu;
for(var i = 0u; i < subgroup_size; i += 1u) {
expected = min(expected, global_id.x - subgroup_invocation_id + i + 1u);
}
add_result_to_mask(&passed, 11u, subgroupMin(global_id.x + 1u) == expected);
expected = 0xFFFFFFFFu;
for(var i = 0u; i < subgroup_size; i += 1u) {
expected &= global_id.x - subgroup_invocation_id + i + 1u;
}
add_result_to_mask(&passed, 12u, subgroupAnd(global_id.x + 1u) == expected);
expected = 0u;
for(var i = 0u; i < subgroup_size; i += 1u) {
expected |= global_id.x - subgroup_invocation_id + i + 1u;
}
add_result_to_mask(&passed, 13u, subgroupOr(global_id.x + 1u) == expected);
expected = 0u;
for(var i = 0u; i < subgroup_size; i += 1u) {
expected ^= global_id.x - subgroup_invocation_id + i + 1u;
}
add_result_to_mask(&passed, 14u, subgroupXor(global_id.x + 1u) == expected);
// Tests 15-18: exclusive scans fold lanes [0, id), inclusive scans [0, id].
expected = 0u;
for(var i = 0u; i < subgroup_invocation_id; i += 1u) {
expected += global_id.x - subgroup_invocation_id + i + 1u;
}
add_result_to_mask(&passed, 15u, subgroupExclusiveAdd(global_id.x + 1u) == expected);
expected = 1u;
for(var i = 0u; i < subgroup_invocation_id; i += 1u) {
expected *= global_id.x - subgroup_invocation_id + i + 1u;
}
add_result_to_mask(&passed, 16u, subgroupExclusiveMul(global_id.x + 1u) == expected);
expected = 0u;
for(var i = 0u; i <= subgroup_invocation_id; i += 1u) {
expected += global_id.x - subgroup_invocation_id + i + 1u;
}
add_result_to_mask(&passed, 17u, subgroupInclusiveAdd(global_id.x + 1u) == expected);
expected = 1u;
for(var i = 0u; i <= subgroup_invocation_id; i += 1u) {
expected *= global_id.x - subgroup_invocation_id + i + 1u;
}
add_result_to_mask(&passed, 18u, subgroupInclusiveMul(global_id.x + 1u) == expected);
// Tests 19-21: broadcasts — lane 0 contributes 0 resp. 1 to the two
// BroadcastFirst checks; broadcasting the lane index from lane 1 yields 1.
add_result_to_mask(&passed, 19u, subgroupBroadcastFirst(u32(subgroup_invocation_id != 0u)) == 0u);
add_result_to_mask(&passed, 20u, subgroupBroadcastFirst(u32(subgroup_invocation_id == 0u)) == 1u);
add_result_to_mask(&passed, 21u, subgroupBroadcast(subgroup_invocation_id, 1u) == 1u);
// Tests 22-26: shuffles — identity, reversal, down/up (edge lanes exempted
// since their source lane is out of range), and xor with subgroup_size - 1.
add_result_to_mask(&passed, 22u, subgroupShuffle(subgroup_invocation_id, subgroup_invocation_id) == subgroup_invocation_id);
add_result_to_mask(&passed, 23u, subgroupShuffle(subgroup_invocation_id, subgroup_size - 1u - subgroup_invocation_id) == subgroup_size - 1u - subgroup_invocation_id);
add_result_to_mask(&passed, 24u, subgroup_invocation_id == subgroup_size - 1u || subgroupShuffleDown(subgroup_invocation_id, 1u) == subgroup_invocation_id + 1u);
add_result_to_mask(&passed, 25u, subgroup_invocation_id == 0u || subgroupShuffleUp(subgroup_invocation_id, 1u) == subgroup_invocation_id - 1u);
add_result_to_mask(&passed, 26u, subgroupShuffleXor(subgroup_invocation_id, subgroup_size - 1u) == (subgroup_invocation_id ^ (subgroup_size - 1u)));
// Test 27: subgroupAdd in divergent if/else — each branch is taken by half
// the lanes (even vs. odd ids), so the active-lane count is size / 2.
var passed_27 = false;
if subgroup_invocation_id % 2u == 0u {
passed_27 |= subgroupAdd(1u) == (subgroup_size / 2u);
} else {
passed_27 |= subgroupAdd(1u) == (subgroup_size / 2u);
}
add_result_to_mask(&passed, 27u, passed_27);
// Test 28: subgroupBroadcastFirst inside a switch — the first active lane
// in each case is the lowest id whose id % 3 selects that case.
var passed_28 = false;
switch subgroup_invocation_id % 3u {
case 0u: {
passed_28 = subgroupBroadcastFirst(subgroup_invocation_id) == 0u;
}
case 1u: {
passed_28 = subgroupBroadcastFirst(subgroup_invocation_id) == 1u;
}
case 2u: {
passed_28 = subgroupBroadcastFirst(subgroup_invocation_id) == 2u;
}
default { }
}
add_result_to_mask(&passed, 28u, passed_28);
// Test 29: lanes peel off the loop one per iteration (lane k breaks when
// i == k), so lane k's final subgroupAdd(1u) counts k + 1 active lanes.
// NOTE(review): `i >= 0u` is always true for u32 — the loop exits only via
// the break, which always fires since subgroup_invocation_id < subgroup_size.
expected = 0u;
for (var i = subgroup_size; i >= 0u; i -= 1u) {
expected = subgroupAdd(1u);
if i == subgroup_invocation_id {
break;
}
}
add_result_to_mask(&passed, 29u, expected == (subgroup_invocation_id + 1u));
// Test 30: workgroup-memory round trip across a workgroupBarrier.
if global_id.x == 0u {
workgroup_buffer = subgroup_size;
}
workgroupBarrier();
add_result_to_mask(&passed, 30u, workgroup_buffer == subgroup_size);
// Keep this test last, verify we are still convergent after running other tests
add_result_to_mask(&passed, 31u, subgroupAdd(1u) == subgroup_size);
// Increment TEST_COUNT in subgroup_operations/mod.rs if adding more tests
storage_buffer[global_id.x] = passed;
}

View File

@ -1537,6 +1537,15 @@ impl<A: HalApi> Device<A> {
.flags
.contains(wgt::DownlevelFlags::CUBE_ARRAY_TEXTURES),
);
caps.set(
Caps::SUBGROUP,
self.features
.intersects(wgt::Features::SUBGROUP | wgt::Features::SUBGROUP_VERTEX),
);
caps.set(
Caps::SUBGROUP_BARRIER,
self.features.intersects(wgt::Features::SUBGROUP_BARRIER),
);
let debug_source =
if self.instance_flags.contains(wgt::InstanceFlags::DEBUG) && !source.is_empty() {
@ -1552,7 +1561,26 @@ impl<A: HalApi> Device<A> {
None
};
let mut subgroup_stages = naga::valid::ShaderStages::empty();
subgroup_stages.set(
naga::valid::ShaderStages::COMPUTE | naga::valid::ShaderStages::FRAGMENT,
self.features.contains(wgt::Features::SUBGROUP),
);
subgroup_stages.set(
naga::valid::ShaderStages::VERTEX,
self.features.contains(wgt::Features::SUBGROUP_VERTEX),
);
let subgroup_operations = if caps.contains(Caps::SUBGROUP) {
use naga::valid::SubgroupOperationSet as S;
S::BASIC | S::VOTE | S::ARITHMETIC | S::BALLOT | S::SHUFFLE | S::SHUFFLE_RELATIVE
} else {
naga::valid::SubgroupOperationSet::empty()
};
let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), caps)
.subgroup_stages(subgroup_stages)
.subgroup_operations(subgroup_operations)
.validate(&module)
.map_err(|inner| {
pipeline::CreateShaderModuleError::Validation(pipeline::ShaderError {

View File

@ -127,6 +127,11 @@ impl super::Adapter {
)
});
// If we don't have dxc, we reduce the max to 5.1
if dxc_container.is_none() {
shader_model_support.HighestShaderModel = d3d12_ty::D3D_SHADER_MODEL_5_1;
}
let mut workarounds = super::Workarounds::default();
let info = wgt::AdapterInfo {
@ -343,21 +348,28 @@ impl super::Adapter {
bgra8unorm_storage_supported,
);
let mut features1: d3d12_ty::D3D12_FEATURE_DATA_D3D12_OPTIONS1 = unsafe { mem::zeroed() };
let hr = unsafe {
device.CheckFeatureSupport(
d3d12_ty::D3D12_FEATURE_D3D12_OPTIONS1,
&mut features1 as *mut _ as *mut _,
mem::size_of::<d3d12_ty::D3D12_FEATURE_DATA_D3D12_OPTIONS1>() as _,
)
};
// we must be using DXC because uint64_t was added with Shader Model 6
// and FXC only supports up to 5.1
let int64_shader_ops_supported = dxc_container.is_some() && {
let mut features1: d3d12_ty::D3D12_FEATURE_DATA_D3D12_OPTIONS1 =
unsafe { mem::zeroed() };
let hr = unsafe {
device.CheckFeatureSupport(
d3d12_ty::D3D12_FEATURE_D3D12_OPTIONS1,
&mut features1 as *mut _ as *mut _,
mem::size_of::<d3d12_ty::D3D12_FEATURE_DATA_D3D12_OPTIONS1>() as _,
)
};
hr == 0 && features1.Int64ShaderOps != 0
};
features.set(wgt::Features::SHADER_INT64, int64_shader_ops_supported);
features.set(
wgt::Features::SHADER_INT64,
dxc_container.is_some() && hr == 0 && features1.Int64ShaderOps != 0,
);
features.set(
wgt::Features::SUBGROUP,
shader_model_support.HighestShaderModel >= d3d12_ty::D3D_SHADER_MODEL_6_0
&& hr == 0
&& features1.WaveOps != 0,
);
// float32-filterable should always be available on d3d12
features.set(wgt::Features::FLOAT32_FILTERABLE, true);
@ -425,6 +437,8 @@ impl super::Adapter {
.min(crate::MAX_VERTEX_BUFFERS as u32),
max_vertex_attributes: d3d12_ty::D3D12_IA_VERTEX_INPUT_RESOURCE_SLOT_COUNT,
max_vertex_buffer_array_stride: d3d12_ty::D3D12_SO_BUFFER_MAX_STRIDE_IN_BYTES,
min_subgroup_size: 4, // Not using `features1.WaveLaneCountMin` as it is unreliable
max_subgroup_size: 128,
// The push constants are part of the root signature which
// has a limit of 64 DWORDS (256 bytes), but other resources
// also share the root signature:

View File

@ -748,6 +748,8 @@ impl super::Adapter {
} else {
!0
},
min_subgroup_size: 0,
max_subgroup_size: 0,
max_push_constant_size: super::MAX_PUSH_CONSTANTS as u32 * 4,
min_uniform_buffer_offset_alignment,
min_storage_buffer_offset_alignment,

View File

@ -813,6 +813,10 @@ impl super::PrivateCapabilities {
None
},
timestamp_query_support,
supports_simd_scoped_operations: family_check
&& (device.supports_family(MTLGPUFamily::Metal3)
|| device.supports_family(MTLGPUFamily::Mac2)
|| device.supports_family(MTLGPUFamily::Apple7)),
}
}
@ -898,6 +902,10 @@ impl super::PrivateCapabilities {
features.set(F::RG11B10UFLOAT_RENDERABLE, self.format_rg11b10_all);
features.set(F::SHADER_UNUSED_VERTEX_OUTPUT, true);
if self.supports_simd_scoped_operations {
features.insert(F::SUBGROUP | F::SUBGROUP_BARRIER);
}
features
}
@ -952,6 +960,8 @@ impl super::PrivateCapabilities {
max_vertex_buffers: self.max_vertex_buffers,
max_vertex_attributes: 31,
max_vertex_buffer_array_stride: base.max_vertex_buffer_array_stride,
min_subgroup_size: 4,
max_subgroup_size: 64,
max_push_constant_size: 0x1000,
min_uniform_buffer_offset_alignment: self.buffer_alignment as u32,
min_storage_buffer_offset_alignment: self.buffer_alignment as u32,

View File

@ -269,6 +269,7 @@ struct PrivateCapabilities {
supports_shader_primitive_index: bool,
has_unified_memory: Option<bool>,
timestamp_query_support: TimestampQuerySupport,
supports_simd_scoped_operations: bool,
}
#[derive(Clone, Debug)]

View File

@ -101,6 +101,9 @@ pub struct PhysicalDeviceFeatures {
/// to Vulkan 1.3.
zero_initialize_workgroup_memory:
Option<vk::PhysicalDeviceZeroInitializeWorkgroupMemoryFeatures>,
/// Features provided by `VK_EXT_subgroup_size_control`, promoted to Vulkan 1.3.
subgroup_size_control: Option<vk::PhysicalDeviceSubgroupSizeControlFeatures>,
}
// This is safe because the structs have `p_next: *mut c_void`, which we null out/never read.
@ -148,6 +151,9 @@ impl PhysicalDeviceFeatures {
if let Some(ref mut feature) = self.ray_query {
info = info.push_next(feature);
}
if let Some(ref mut feature) = self.subgroup_size_control {
info = info.push_next(feature);
}
info
}
@ -434,6 +440,17 @@ impl PhysicalDeviceFeatures {
} else {
None
},
subgroup_size_control: if device_api_version >= vk::API_VERSION_1_3
|| enabled_extensions.contains(&vk::ExtSubgroupSizeControlFn::name())
{
Some(
vk::PhysicalDeviceSubgroupSizeControlFeatures::builder()
.subgroup_size_control(true)
.build(),
)
} else {
None
},
}
}
@ -638,6 +655,34 @@ impl PhysicalDeviceFeatures {
);
}
if let Some(ref subgroup) = caps.subgroup {
if (caps.device_api_version >= vk::API_VERSION_1_3
|| caps.supports_extension(vk::ExtSubgroupSizeControlFn::name()))
&& subgroup.supported_operations.contains(
vk::SubgroupFeatureFlags::BASIC
| vk::SubgroupFeatureFlags::VOTE
| vk::SubgroupFeatureFlags::ARITHMETIC
| vk::SubgroupFeatureFlags::BALLOT
| vk::SubgroupFeatureFlags::SHUFFLE
| vk::SubgroupFeatureFlags::SHUFFLE_RELATIVE,
)
{
features.set(
F::SUBGROUP,
subgroup
.supported_stages
.contains(vk::ShaderStageFlags::COMPUTE | vk::ShaderStageFlags::FRAGMENT),
);
features.set(
F::SUBGROUP_VERTEX,
subgroup
.supported_stages
.contains(vk::ShaderStageFlags::VERTEX),
);
features.insert(F::SUBGROUP_BARRIER);
}
}
let supports_depth_format = |format| {
supports_format(
instance,
@ -773,6 +818,13 @@ pub struct PhysicalDeviceProperties {
/// `VK_KHR_driver_properties` extension, promoted to Vulkan 1.2.
driver: Option<vk::PhysicalDeviceDriverPropertiesKHR>,
/// Additional `vk::PhysicalDevice` properties from Vulkan 1.1.
subgroup: Option<vk::PhysicalDeviceSubgroupProperties>,
/// Additional `vk::PhysicalDevice` properties from the
/// `VK_EXT_subgroup_size_control` extension, promoted to Vulkan 1.3.
subgroup_size_control: Option<vk::PhysicalDeviceSubgroupSizeControlProperties>,
/// The device API version.
///
/// Which is the version of Vulkan supported for device-level functionality.
@ -888,6 +940,11 @@ impl PhysicalDeviceProperties {
if self.supports_extension(vk::ExtImageRobustnessFn::name()) {
extensions.push(vk::ExtImageRobustnessFn::name());
}
// Require `VK_EXT_subgroup_size_control` if the associated feature was requested
if requested_features.contains(wgt::Features::SUBGROUP) {
extensions.push(vk::ExtSubgroupSizeControlFn::name());
}
}
// Optional `VK_KHR_swapchain_mutable_format`
@ -987,6 +1044,14 @@ impl PhysicalDeviceProperties {
.min(crate::MAX_VERTEX_BUFFERS as u32),
max_vertex_attributes: limits.max_vertex_input_attributes,
max_vertex_buffer_array_stride: limits.max_vertex_input_binding_stride,
min_subgroup_size: self
.subgroup_size_control
.map(|subgroup_size| subgroup_size.min_subgroup_size)
.unwrap_or(0),
max_subgroup_size: self
.subgroup_size_control
.map(|subgroup_size| subgroup_size.max_subgroup_size)
.unwrap_or(0),
max_push_constant_size: limits.max_push_constants_size,
min_uniform_buffer_offset_alignment: limits.min_uniform_buffer_offset_alignment as u32,
min_storage_buffer_offset_alignment: limits.min_storage_buffer_offset_alignment as u32,
@ -1042,6 +1107,9 @@ impl super::InstanceShared {
let supports_driver_properties = capabilities.device_api_version
>= vk::API_VERSION_1_2
|| capabilities.supports_extension(vk::KhrDriverPropertiesFn::name());
let supports_subgroup_size_control = capabilities.device_api_version
>= vk::API_VERSION_1_3
|| capabilities.supports_extension(vk::ExtSubgroupSizeControlFn::name());
let supports_acceleration_structure =
capabilities.supports_extension(vk::KhrAccelerationStructureFn::name());
@ -1075,6 +1143,20 @@ impl super::InstanceShared {
builder = builder.push_next(next);
}
if capabilities.device_api_version >= vk::API_VERSION_1_1 {
let next = capabilities
.subgroup
.insert(vk::PhysicalDeviceSubgroupProperties::default());
builder = builder.push_next(next);
}
if supports_subgroup_size_control {
let next = capabilities
.subgroup_size_control
.insert(vk::PhysicalDeviceSubgroupSizeControlProperties::default());
builder = builder.push_next(next);
}
let mut properties2 = builder.build();
unsafe {
get_device_properties.get_physical_device_properties2(phd, &mut properties2);
@ -1190,6 +1272,16 @@ impl super::InstanceShared {
builder = builder.push_next(next);
}
// `VK_EXT_subgroup_size_control` was promoted to core in Vulkan 1.3
if capabilities.device_api_version >= vk::API_VERSION_1_3
|| capabilities.supports_extension(vk::ExtSubgroupSizeControlFn::name())
{
let next = features
.subgroup_size_control
.insert(vk::PhysicalDeviceSubgroupSizeControlFeatures::default());
builder = builder.push_next(next);
}
let mut features2 = builder.build();
unsafe {
get_device_properties.get_physical_device_features2(phd, &mut features2);
@ -1382,6 +1474,9 @@ impl super::Instance {
}),
image_format_list: phd_capabilities.device_api_version >= vk::API_VERSION_1_2
|| phd_capabilities.supports_extension(vk::KhrImageFormatListFn::name()),
subgroup_size_control: phd_features
.subgroup_size_control
.map_or(false, |ext| ext.subgroup_size_control == vk::TRUE),
};
let capabilities = crate::Capabilities {
limits: phd_capabilities.to_wgpu_limits(),
@ -1581,6 +1676,15 @@ impl super::Adapter {
capabilities.push(spv::Capability::Geometry);
}
if features.intersects(wgt::Features::SUBGROUP | wgt::Features::SUBGROUP_VERTEX) {
capabilities.push(spv::Capability::GroupNonUniform);
capabilities.push(spv::Capability::GroupNonUniformVote);
capabilities.push(spv::Capability::GroupNonUniformArithmetic);
capabilities.push(spv::Capability::GroupNonUniformBallot);
capabilities.push(spv::Capability::GroupNonUniformShuffle);
capabilities.push(spv::Capability::GroupNonUniformShuffleRelative);
}
if features.intersects(
wgt::Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING
| wgt::Features::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING,
@ -1616,7 +1720,13 @@ impl super::Adapter {
true, // could check `super::Workarounds::SEPARATE_ENTRY_POINTS`
);
spv::Options {
lang_version: (1, 0),
lang_version: if features
.intersects(wgt::Features::SUBGROUP | wgt::Features::SUBGROUP_VERTEX)
{
(1, 3)
} else {
(1, 0)
},
flags,
capabilities: Some(capabilities.iter().cloned().collect()),
bounds_check_policies: naga::proc::BoundsCheckPolicies {

View File

@ -782,8 +782,14 @@ impl super::Device {
}
};
let mut flags = vk::PipelineShaderStageCreateFlags::empty();
if self.shared.private_caps.subgroup_size_control {
flags |= vk::PipelineShaderStageCreateFlags::ALLOW_VARYING_SUBGROUP_SIZE
}
let entry_point = CString::new(stage.entry_point).unwrap();
let create_info = vk::PipelineShaderStageCreateInfo::builder()
.flags(flags)
.stage(conv::map_shader_stage(stage_flags))
.module(vk_module)
.name(&entry_point)

View File

@ -238,6 +238,7 @@ struct PrivateCapabilities {
robust_image_access2: bool,
zero_initialize_workgroup_memory: bool,
image_format_list: bool,
subgroup_size_control: bool,
}
bitflags::bitflags!(

View File

@ -143,6 +143,8 @@ fn print_adapter(output: &mut impl io::Write, report: &AdapterReport, idx: usize
max_vertex_buffers,
max_vertex_attributes,
max_vertex_buffer_array_stride,
min_subgroup_size,
max_subgroup_size,
max_push_constant_size,
min_uniform_buffer_offset_alignment,
min_storage_buffer_offset_alignment,
@ -176,6 +178,8 @@ fn print_adapter(output: &mut impl io::Write, report: &AdapterReport, idx: usize
writeln!(output, "\t\t Max Vertex Buffers: {max_vertex_buffers}")?;
writeln!(output, "\t\t Max Vertex Attributes: {max_vertex_attributes}")?;
writeln!(output, "\t\t Max Vertex Buffer Array Stride: {max_vertex_buffer_array_stride}")?;
writeln!(output, "\t\t Min Subgroup Size: {min_subgroup_size}")?;
writeln!(output, "\t\t Max Subgroup Size: {max_subgroup_size}")?;
writeln!(output, "\t\t Max Push Constant Size: {max_push_constant_size}")?;
writeln!(output, "\t\t Min Uniform Buffer Offset Alignment: {min_uniform_buffer_offset_alignment}")?;
writeln!(output, "\t\t Min Storage Buffer Offset Alignment: {min_storage_buffer_offset_alignment}")?;

View File

@ -890,6 +890,30 @@ bitflags::bitflags! {
///
/// This is a native only feature.
const SHADER_INT64 = 1 << 55;
/// Allows compute and fragment shaders to use the subgroup operation built-ins
///
/// Supported Platforms:
/// - Vulkan
/// - DX12
/// - Metal
///
/// This is a native only feature.
const SUBGROUP = 1 << 56;
/// Allows vertex shaders to use the subgroup operation built-ins
///
/// Supported Platforms:
/// - Vulkan
///
/// This is a native only feature.
const SUBGROUP_VERTEX = 1 << 57;
/// Allows shaders to use the subgroup barrier
///
/// Supported Platforms:
/// - Vulkan
/// - Metal
///
/// This is a native only feature.
const SUBGROUP_BARRIER = 1 << 58;
}
}
@ -1136,6 +1160,11 @@ pub struct Limits {
/// The maximum value for each dimension of a `ComputePass::dispatch(x, y, z)` operation.
/// Defaults to 65535. Higher is "better".
pub max_compute_workgroups_per_dimension: u32,
/// Minimum number of invocations in a subgroup. Higher is "better".
pub min_subgroup_size: u32,
/// Maximum number of invocations in a subgroup. Lower is "better".
pub max_subgroup_size: u32,
/// Amount of storage available for push constants in bytes. Defaults to 0. Higher is "better".
/// Requesting more than 0 during device creation requires [`Features::PUSH_CONSTANTS`] to be enabled.
///
@ -1146,7 +1175,6 @@ pub struct Limits {
/// - OpenGL doesn't natively support push constants, and are emulated with uniforms,
/// so this number is less useful but likely 256.
pub max_push_constant_size: u32,
/// Maximum number of live non-sampler bindings.
///
/// This limit only affects the d3d12 backend. Using a large number will allow the device
@ -1187,6 +1215,8 @@ impl Default for Limits {
max_compute_workgroup_size_y: 256,
max_compute_workgroup_size_z: 64,
max_compute_workgroups_per_dimension: 65535,
min_subgroup_size: 0,
max_subgroup_size: 0,
max_push_constant_size: 0,
max_non_sampler_bindings: 1_000_000,
}
@ -1218,6 +1248,8 @@ impl Limits {
/// max_vertex_buffers: 8,
/// max_vertex_attributes: 16,
/// max_vertex_buffer_array_stride: 2048,
/// min_subgroup_size: 0,
/// max_subgroup_size: 0,
/// max_push_constant_size: 0,
/// min_uniform_buffer_offset_alignment: 256,
/// min_storage_buffer_offset_alignment: 256,
@ -1254,6 +1286,8 @@ impl Limits {
max_vertex_buffers: 8,
max_vertex_attributes: 16,
max_vertex_buffer_array_stride: 2048,
min_subgroup_size: 0,
max_subgroup_size: 0,
max_push_constant_size: 0,
min_uniform_buffer_offset_alignment: 256,
min_storage_buffer_offset_alignment: 256,
@ -1296,6 +1330,8 @@ impl Limits {
/// max_vertex_buffers: 8,
/// max_vertex_attributes: 16,
/// max_vertex_buffer_array_stride: 255, // +
/// min_subgroup_size: 0,
/// max_subgroup_size: 0,
/// max_push_constant_size: 0,
/// min_uniform_buffer_offset_alignment: 256,
/// min_storage_buffer_offset_alignment: 256,
@ -1326,6 +1362,8 @@ impl Limits {
max_compute_workgroup_size_y: 0,
max_compute_workgroup_size_z: 0,
max_compute_workgroups_per_dimension: 0,
min_subgroup_size: 0,
max_subgroup_size: 0,
// Value supported by Intel Celeron B830 on Windows (OpenGL 3.1)
max_inter_stage_shader_components: 31,
@ -1418,6 +1456,10 @@ impl Limits {
compare!(max_vertex_buffers, Less);
compare!(max_vertex_attributes, Less);
compare!(max_vertex_buffer_array_stride, Less);
if self.min_subgroup_size > 0 && self.max_subgroup_size > 0 {
compare!(min_subgroup_size, Greater);
compare!(max_subgroup_size, Less);
}
compare!(max_push_constant_size, Less);
compare!(min_uniform_buffer_offset_alignment, Greater);
compare!(min_storage_buffer_offset_alignment, Greater);

View File

@ -737,6 +737,8 @@ fn map_wgt_limits(limits: webgpu_sys::GpuSupportedLimits) -> wgt::Limits {
max_compute_workgroup_size_z: limits.max_compute_workgroup_size_z(),
max_compute_workgroups_per_dimension: limits.max_compute_workgroups_per_dimension(),
// The following are not part of WebGPU
min_subgroup_size: wgt::Limits::default().min_subgroup_size,
max_subgroup_size: wgt::Limits::default().max_subgroup_size,
max_push_constant_size: wgt::Limits::default().max_push_constant_size,
max_non_sampler_bindings: wgt::Limits::default().max_non_sampler_bindings,
}