mirror of
https://github.com/gfx-rs/wgpu.git
synced 2024-11-24 07:43:49 +00:00
Subgroup Operations (#5301)
Co-authored-by: Jacob Hughes <j@distanthills.org> Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com> Co-authored-by: atlas dostal <rodol@rivalrebels.com>
This commit is contained in:
parent
0dc9dd6bec
commit
ea77d5674d
@ -138,6 +138,7 @@ Bottom level categories:
|
||||
### Bug Fixes
|
||||
|
||||
#### General
|
||||
- Add the `SUBGROUP`, `SUBGROUP_VERTEX`, and `SUBGROUP_BARRIER` features. By @exrook and @lichtso in [#5301](https://github.com/gfx-rs/wgpu/pull/5301)
|
||||
- Fix `serde` feature not compiling for `wgpu-types`. By @KirmesBude in [#5149](https://github.com/gfx-rs/wgpu/pull/5149)
|
||||
- Fix the validation of vertex and index ranges. By @nical in [#5144](https://github.com/gfx-rs/wgpu/pull/5144) and [#5156](https://github.com/gfx-rs/wgpu/pull/5156)
|
||||
- Fix panic when creating a surface while no backend is available. By @wumpf in [#5166](https://github.com/gfx-rs/wgpu/pull/5166)
|
||||
|
@ -424,6 +424,8 @@ fn run() -> Result<(), Box<dyn std::error::Error>> {
|
||||
|
||||
// Validate the IR before compaction.
|
||||
let info = match naga::valid::Validator::new(params.validation_flags, validation_caps)
|
||||
.subgroup_stages(naga::valid::ShaderStages::all())
|
||||
.subgroup_operations(naga::valid::SubgroupOperationSet::all())
|
||||
.validate(&module)
|
||||
{
|
||||
Ok(info) => Some(info),
|
||||
@ -760,6 +762,8 @@ fn bulk_validate(args: Args, params: &Parameters) -> Result<(), Box<dyn std::err
|
||||
|
||||
let mut validator =
|
||||
naga::valid::Validator::new(params.validation_flags, naga::valid::Capabilities::all());
|
||||
validator.subgroup_stages(naga::valid::ShaderStages::all());
|
||||
validator.subgroup_operations(naga::valid::SubgroupOperationSet::all());
|
||||
|
||||
if let Err(error) = validator.validate(&module) {
|
||||
invalid.push(input_path.clone());
|
||||
|
@ -279,6 +279,94 @@ impl StatementGraph {
|
||||
crate::RayQueryFunction::Terminate => "RayQueryTerminate",
|
||||
}
|
||||
}
|
||||
S::SubgroupBallot { result, predicate } => {
|
||||
if let Some(predicate) = predicate {
|
||||
self.dependencies.push((id, predicate, "predicate"));
|
||||
}
|
||||
self.emits.push((id, result));
|
||||
"SubgroupBallot"
|
||||
}
|
||||
S::SubgroupCollectiveOperation {
|
||||
op,
|
||||
collective_op,
|
||||
argument,
|
||||
result,
|
||||
} => {
|
||||
self.dependencies.push((id, argument, "arg"));
|
||||
self.emits.push((id, result));
|
||||
match (collective_op, op) {
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
|
||||
"SubgroupAll"
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
|
||||
"SubgroupAny"
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
|
||||
"SubgroupAdd"
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
|
||||
"SubgroupMul"
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
|
||||
"SubgroupMax"
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
|
||||
"SubgroupMin"
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
|
||||
"SubgroupAnd"
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
|
||||
"SubgroupOr"
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
|
||||
"SubgroupXor"
|
||||
}
|
||||
(
|
||||
crate::CollectiveOperation::ExclusiveScan,
|
||||
crate::SubgroupOperation::Add,
|
||||
) => "SubgroupExclusiveAdd",
|
||||
(
|
||||
crate::CollectiveOperation::ExclusiveScan,
|
||||
crate::SubgroupOperation::Mul,
|
||||
) => "SubgroupExclusiveMul",
|
||||
(
|
||||
crate::CollectiveOperation::InclusiveScan,
|
||||
crate::SubgroupOperation::Add,
|
||||
) => "SubgroupInclusiveAdd",
|
||||
(
|
||||
crate::CollectiveOperation::InclusiveScan,
|
||||
crate::SubgroupOperation::Mul,
|
||||
) => "SubgroupInclusiveMul",
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
S::SubgroupGather {
|
||||
mode,
|
||||
argument,
|
||||
result,
|
||||
} => {
|
||||
match mode {
|
||||
crate::GatherMode::BroadcastFirst => {}
|
||||
crate::GatherMode::Broadcast(index)
|
||||
| crate::GatherMode::Shuffle(index)
|
||||
| crate::GatherMode::ShuffleDown(index)
|
||||
| crate::GatherMode::ShuffleUp(index)
|
||||
| crate::GatherMode::ShuffleXor(index) => {
|
||||
self.dependencies.push((id, index, "index"))
|
||||
}
|
||||
}
|
||||
self.dependencies.push((id, argument, "arg"));
|
||||
self.emits.push((id, result));
|
||||
match mode {
|
||||
crate::GatherMode::BroadcastFirst => "SubgroupBroadcastFirst",
|
||||
crate::GatherMode::Broadcast(_) => "SubgroupBroadcast",
|
||||
crate::GatherMode::Shuffle(_) => "SubgroupShuffle",
|
||||
crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown",
|
||||
crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp",
|
||||
crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor",
|
||||
}
|
||||
}
|
||||
};
|
||||
// Set the last node to the merge node
|
||||
last_node = merge_id;
|
||||
@ -587,6 +675,8 @@ fn write_function_expressions(
|
||||
let ty = if committed { "Committed" } else { "Candidate" };
|
||||
(format!("rayQueryGet{}Intersection", ty).into(), 4)
|
||||
}
|
||||
E::SubgroupBallotResult => ("SubgroupBallotResult".into(), 4),
|
||||
E::SubgroupOperationResult { .. } => ("SubgroupOperationResult".into(), 4),
|
||||
};
|
||||
|
||||
// give uniform expressions an outline
|
||||
|
@ -50,6 +50,8 @@ bitflags::bitflags! {
|
||||
const INSTANCE_INDEX = 1 << 22;
|
||||
/// Sample specific LODs of cube / array shadow textures
|
||||
const TEXTURE_SHADOW_LOD = 1 << 23;
|
||||
/// Subgroup operations
|
||||
const SUBGROUP_OPERATIONS = 1 << 24;
|
||||
}
|
||||
}
|
||||
|
||||
@ -117,6 +119,7 @@ impl FeaturesManager {
|
||||
check_feature!(SAMPLE_VARIABLES, 400, 300);
|
||||
check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310);
|
||||
check_feature!(DUAL_SOURCE_BLENDING, 330, 300 /* with extension */);
|
||||
check_feature!(SUBGROUP_OPERATIONS, 430, 310);
|
||||
match version {
|
||||
Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300),
|
||||
_ => check_feature!(MULTI_VIEW, 140, 310),
|
||||
@ -259,6 +262,22 @@ impl FeaturesManager {
|
||||
writeln!(out, "#extension GL_EXT_texture_shadow_lod : require")?;
|
||||
}
|
||||
|
||||
if self.0.contains(Features::SUBGROUP_OPERATIONS) {
|
||||
// https://registry.khronos.org/OpenGL/extensions/KHR/KHR_shader_subgroup.txt
|
||||
writeln!(out, "#extension GL_KHR_shader_subgroup_basic : require")?;
|
||||
writeln!(out, "#extension GL_KHR_shader_subgroup_vote : require")?;
|
||||
writeln!(
|
||||
out,
|
||||
"#extension GL_KHR_shader_subgroup_arithmetic : require"
|
||||
)?;
|
||||
writeln!(out, "#extension GL_KHR_shader_subgroup_ballot : require")?;
|
||||
writeln!(out, "#extension GL_KHR_shader_subgroup_shuffle : require")?;
|
||||
writeln!(
|
||||
out,
|
||||
"#extension GL_KHR_shader_subgroup_shuffle_relative : require"
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -518,6 +537,10 @@ impl<'a, W> Writer<'a, W> {
|
||||
}
|
||||
}
|
||||
}
|
||||
Expression::SubgroupBallotResult |
|
||||
Expression::SubgroupOperationResult { .. } => {
|
||||
features.request(Features::SUBGROUP_OPERATIONS)
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
@ -2390,6 +2390,125 @@ impl<'a, W: Write> Writer<'a, W> {
|
||||
writeln!(self.out, ");")?;
|
||||
}
|
||||
Statement::RayQuery { .. } => unreachable!(),
|
||||
Statement::SubgroupBallot { result, predicate } => {
|
||||
write!(self.out, "{level}")?;
|
||||
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
|
||||
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
|
||||
self.write_value_type(res_ty)?;
|
||||
write!(self.out, " {res_name} = ")?;
|
||||
self.named_expressions.insert(result, res_name);
|
||||
|
||||
write!(self.out, "subgroupBallot(")?;
|
||||
match predicate {
|
||||
Some(predicate) => self.write_expr(predicate, ctx)?,
|
||||
None => write!(self.out, "true")?,
|
||||
}
|
||||
writeln!(self.out, ");")?;
|
||||
}
|
||||
Statement::SubgroupCollectiveOperation {
|
||||
op,
|
||||
collective_op,
|
||||
argument,
|
||||
result,
|
||||
} => {
|
||||
write!(self.out, "{level}")?;
|
||||
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
|
||||
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
|
||||
self.write_value_type(res_ty)?;
|
||||
write!(self.out, " {res_name} = ")?;
|
||||
self.named_expressions.insert(result, res_name);
|
||||
|
||||
match (collective_op, op) {
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
|
||||
write!(self.out, "subgroupAll(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
|
||||
write!(self.out, "subgroupAny(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
|
||||
write!(self.out, "subgroupAdd(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
|
||||
write!(self.out, "subgroupMul(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
|
||||
write!(self.out, "subgroupMax(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
|
||||
write!(self.out, "subgroupMin(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
|
||||
write!(self.out, "subgroupAnd(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
|
||||
write!(self.out, "subgroupOr(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
|
||||
write!(self.out, "subgroupXor(")?
|
||||
}
|
||||
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
|
||||
write!(self.out, "subgroupExclusiveAdd(")?
|
||||
}
|
||||
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
|
||||
write!(self.out, "subgroupExclusiveMul(")?
|
||||
}
|
||||
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
|
||||
write!(self.out, "subgroupInclusiveAdd(")?
|
||||
}
|
||||
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
|
||||
write!(self.out, "subgroupInclusiveMul(")?
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
self.write_expr(argument, ctx)?;
|
||||
writeln!(self.out, ");")?;
|
||||
}
|
||||
Statement::SubgroupGather {
|
||||
mode,
|
||||
argument,
|
||||
result,
|
||||
} => {
|
||||
write!(self.out, "{level}")?;
|
||||
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
|
||||
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
|
||||
self.write_value_type(res_ty)?;
|
||||
write!(self.out, " {res_name} = ")?;
|
||||
self.named_expressions.insert(result, res_name);
|
||||
|
||||
match mode {
|
||||
crate::GatherMode::BroadcastFirst => {
|
||||
write!(self.out, "subgroupBroadcastFirst(")?;
|
||||
}
|
||||
crate::GatherMode::Broadcast(_) => {
|
||||
write!(self.out, "subgroupBroadcast(")?;
|
||||
}
|
||||
crate::GatherMode::Shuffle(_) => {
|
||||
write!(self.out, "subgroupShuffle(")?;
|
||||
}
|
||||
crate::GatherMode::ShuffleDown(_) => {
|
||||
write!(self.out, "subgroupShuffleDown(")?;
|
||||
}
|
||||
crate::GatherMode::ShuffleUp(_) => {
|
||||
write!(self.out, "subgroupShuffleUp(")?;
|
||||
}
|
||||
crate::GatherMode::ShuffleXor(_) => {
|
||||
write!(self.out, "subgroupShuffleXor(")?;
|
||||
}
|
||||
}
|
||||
self.write_expr(argument, ctx)?;
|
||||
match mode {
|
||||
crate::GatherMode::BroadcastFirst => {}
|
||||
crate::GatherMode::Broadcast(index)
|
||||
| crate::GatherMode::Shuffle(index)
|
||||
| crate::GatherMode::ShuffleDown(index)
|
||||
| crate::GatherMode::ShuffleUp(index)
|
||||
| crate::GatherMode::ShuffleXor(index) => {
|
||||
write!(self.out, ", ")?;
|
||||
self.write_expr(index, ctx)?;
|
||||
}
|
||||
}
|
||||
writeln!(self.out, ");")?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@ -3658,7 +3777,9 @@ impl<'a, W: Write> Writer<'a, W> {
|
||||
Expression::CallResult(_)
|
||||
| Expression::AtomicResult { .. }
|
||||
| Expression::RayQueryProceedResult
|
||||
| Expression::WorkGroupUniformLoadResult { .. } => unreachable!(),
|
||||
| Expression::WorkGroupUniformLoadResult { .. }
|
||||
| Expression::SubgroupOperationResult { .. }
|
||||
| Expression::SubgroupBallotResult => unreachable!(),
|
||||
// `ArrayLength` is written as `expr.length()` and we convert it to a uint
|
||||
Expression::ArrayLength(expr) => {
|
||||
write!(self.out, "uint(")?;
|
||||
@ -4227,6 +4348,9 @@ impl<'a, W: Write> Writer<'a, W> {
|
||||
if flags.contains(crate::Barrier::WORK_GROUP) {
|
||||
writeln!(self.out, "{level}memoryBarrierShared();")?;
|
||||
}
|
||||
if flags.contains(crate::Barrier::SUB_GROUP) {
|
||||
writeln!(self.out, "{level}subgroupMemoryBarrier();")?;
|
||||
}
|
||||
writeln!(self.out, "{level}barrier();")?;
|
||||
Ok(())
|
||||
}
|
||||
@ -4496,6 +4620,11 @@ const fn glsl_built_in(built_in: crate::BuiltIn, options: VaryingOptions) -> &'s
|
||||
Bi::WorkGroupId => "gl_WorkGroupID",
|
||||
Bi::WorkGroupSize => "gl_WorkGroupSize",
|
||||
Bi::NumWorkGroups => "gl_NumWorkGroups",
|
||||
// subgroup
|
||||
Bi::NumSubgroups => "gl_NumSubgroups",
|
||||
Bi::SubgroupId => "gl_SubgroupID",
|
||||
Bi::SubgroupSize => "gl_SubgroupSize",
|
||||
Bi::SubgroupInvocationId => "gl_SubgroupInvocationID",
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -179,6 +179,11 @@ impl crate::BuiltIn {
|
||||
// to this field will get replaced with references to `SPECIAL_CBUF_VAR`
|
||||
// in `Writer::write_expr`.
|
||||
Self::NumWorkGroups => "SV_GroupID",
|
||||
// These builtins map to functions
|
||||
Self::SubgroupSize
|
||||
| Self::SubgroupInvocationId
|
||||
| Self::NumSubgroups
|
||||
| Self::SubgroupId => unreachable!(),
|
||||
Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => {
|
||||
return Err(Error::Unimplemented(format!("builtin {self:?}")))
|
||||
}
|
||||
|
@ -77,6 +77,19 @@ enum Io {
|
||||
Output,
|
||||
}
|
||||
|
||||
const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
|
||||
let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
|
||||
return false;
|
||||
};
|
||||
matches!(
|
||||
builtin,
|
||||
crate::BuiltIn::SubgroupSize
|
||||
| crate::BuiltIn::SubgroupInvocationId
|
||||
| crate::BuiltIn::NumSubgroups
|
||||
| crate::BuiltIn::SubgroupId
|
||||
)
|
||||
}
|
||||
|
||||
impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
pub fn new(out: W, options: &'a Options) -> Self {
|
||||
Self {
|
||||
@ -161,6 +174,19 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
}
|
||||
}
|
||||
}
|
||||
for statement in func.body.iter() {
|
||||
match *statement {
|
||||
crate::Statement::SubgroupCollectiveOperation {
|
||||
op: _,
|
||||
collective_op: crate::CollectiveOperation::InclusiveScan,
|
||||
argument,
|
||||
result: _,
|
||||
} => {
|
||||
self.need_bake_expressions.insert(argument);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write(
|
||||
@ -401,31 +427,32 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
// if they are struct, so that the `stage` argument here could be omitted.
|
||||
fn write_semantic(
|
||||
&mut self,
|
||||
binding: &crate::Binding,
|
||||
binding: &Option<crate::Binding>,
|
||||
stage: Option<(ShaderStage, Io)>,
|
||||
) -> BackendResult {
|
||||
match *binding {
|
||||
crate::Binding::BuiltIn(builtin) => {
|
||||
Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
|
||||
let builtin_str = builtin.to_hlsl_str()?;
|
||||
write!(self.out, " : {builtin_str}")?;
|
||||
}
|
||||
crate::Binding::Location {
|
||||
Some(crate::Binding::Location {
|
||||
second_blend_source: true,
|
||||
..
|
||||
} => {
|
||||
}) => {
|
||||
write!(self.out, " : SV_Target1")?;
|
||||
}
|
||||
crate::Binding::Location {
|
||||
Some(crate::Binding::Location {
|
||||
location,
|
||||
second_blend_source: false,
|
||||
..
|
||||
} => {
|
||||
}) => {
|
||||
if stage == Some((crate::ShaderStage::Fragment, Io::Output)) {
|
||||
write!(self.out, " : SV_Target{location}")?;
|
||||
} else {
|
||||
write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@ -446,17 +473,30 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
write!(self.out, "struct {struct_name}")?;
|
||||
writeln!(self.out, " {{")?;
|
||||
for m in members.iter() {
|
||||
if is_subgroup_builtin_binding(&m.binding) {
|
||||
continue;
|
||||
}
|
||||
write!(self.out, "{}", back::INDENT)?;
|
||||
if let Some(ref binding) = m.binding {
|
||||
self.write_modifier(binding)?;
|
||||
}
|
||||
self.write_type(module, m.ty)?;
|
||||
write!(self.out, " {}", &m.name)?;
|
||||
if let Some(ref binding) = m.binding {
|
||||
self.write_semantic(binding, Some(shader_stage))?;
|
||||
}
|
||||
self.write_semantic(&m.binding, Some(shader_stage))?;
|
||||
writeln!(self.out, ";")?;
|
||||
}
|
||||
if members.iter().any(|arg| {
|
||||
matches!(
|
||||
arg.binding,
|
||||
Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId))
|
||||
)
|
||||
}) {
|
||||
writeln!(
|
||||
self.out,
|
||||
"{}uint __local_invocation_index : SV_GroupIndex;",
|
||||
back::INDENT
|
||||
)?;
|
||||
}
|
||||
writeln!(self.out, "}};")?;
|
||||
writeln!(self.out)?;
|
||||
|
||||
@ -557,8 +597,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
}
|
||||
|
||||
/// Writes special interface structures for an entry point. The special structures have
|
||||
/// all the fields flattened into them and sorted by binding. They are only needed for
|
||||
/// VS outputs and FS inputs, so that these interfaces match.
|
||||
/// all the fields flattened into them and sorted by binding. They are needed to emulate
|
||||
/// subgroup built-ins and to make the interfaces between VS outputs and FS inputs match.
|
||||
fn write_ep_interface(
|
||||
&mut self,
|
||||
module: &Module,
|
||||
@ -567,7 +607,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
ep_name: &str,
|
||||
) -> Result<EntryPointInterface, Error> {
|
||||
Ok(EntryPointInterface {
|
||||
input: if !func.arguments.is_empty() && stage == ShaderStage::Fragment {
|
||||
input: if !func.arguments.is_empty()
|
||||
&& (stage == ShaderStage::Fragment
|
||||
|| func
|
||||
.arguments
|
||||
.iter()
|
||||
.any(|arg| is_subgroup_builtin_binding(&arg.binding)))
|
||||
{
|
||||
Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
|
||||
} else {
|
||||
None
|
||||
@ -581,6 +627,38 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
})
|
||||
}
|
||||
|
||||
fn write_ep_argument_initialization(
|
||||
&mut self,
|
||||
ep: &crate::EntryPoint,
|
||||
ep_input: &EntryPointBinding,
|
||||
fake_member: &EpStructMember,
|
||||
) -> BackendResult {
|
||||
match fake_member.binding {
|
||||
Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
|
||||
write!(self.out, "WaveGetLaneCount()")?
|
||||
}
|
||||
Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
|
||||
write!(self.out, "WaveGetLaneIndex()")?
|
||||
}
|
||||
Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
|
||||
self.out,
|
||||
"({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
|
||||
ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
|
||||
)?,
|
||||
Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
|
||||
write!(
|
||||
self.out,
|
||||
"{}.__local_invocation_index / WaveGetLaneCount()",
|
||||
ep_input.arg_name
|
||||
)?;
|
||||
}
|
||||
_ => {
|
||||
write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Write an entry point preface that initializes the arguments as specified in IR.
|
||||
fn write_ep_arguments_initialization(
|
||||
&mut self,
|
||||
@ -588,6 +666,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
func: &crate::Function,
|
||||
ep_index: u16,
|
||||
) -> BackendResult {
|
||||
let ep = &module.entry_points[ep_index as usize];
|
||||
let ep_input = match self.entry_point_io[ep_index as usize].input.take() {
|
||||
Some(ep_input) => ep_input,
|
||||
None => return Ok(()),
|
||||
@ -601,8 +680,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
match module.types[arg.ty].inner {
|
||||
TypeInner::Array { base, size, .. } => {
|
||||
self.write_array_size(module, base, size)?;
|
||||
let fake_member = fake_iter.next().unwrap();
|
||||
writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?;
|
||||
write!(self.out, " = ")?;
|
||||
self.write_ep_argument_initialization(
|
||||
ep,
|
||||
&ep_input,
|
||||
fake_iter.next().unwrap(),
|
||||
)?;
|
||||
writeln!(self.out, ";")?;
|
||||
}
|
||||
TypeInner::Struct { ref members, .. } => {
|
||||
write!(self.out, " = {{ ")?;
|
||||
@ -610,14 +694,22 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
if index != 0 {
|
||||
write!(self.out, ", ")?;
|
||||
}
|
||||
let fake_member = fake_iter.next().unwrap();
|
||||
write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
|
||||
self.write_ep_argument_initialization(
|
||||
ep,
|
||||
&ep_input,
|
||||
fake_iter.next().unwrap(),
|
||||
)?;
|
||||
}
|
||||
writeln!(self.out, " }};")?;
|
||||
}
|
||||
_ => {
|
||||
let fake_member = fake_iter.next().unwrap();
|
||||
writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?;
|
||||
write!(self.out, " = ")?;
|
||||
self.write_ep_argument_initialization(
|
||||
ep,
|
||||
&ep_input,
|
||||
fake_iter.next().unwrap(),
|
||||
)?;
|
||||
writeln!(self.out, ";")?;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -932,9 +1024,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref binding) = member.binding {
|
||||
self.write_semantic(binding, shader_stage)?;
|
||||
};
|
||||
self.write_semantic(&member.binding, shader_stage)?;
|
||||
writeln!(self.out, ";")?;
|
||||
}
|
||||
|
||||
@ -1147,7 +1237,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
}
|
||||
back::FunctionType::EntryPoint(ep_index) => {
|
||||
if let Some(ref ep_input) = self.entry_point_io[ep_index as usize].input {
|
||||
write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name,)?;
|
||||
write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
|
||||
} else {
|
||||
let stage = module.entry_points[ep_index as usize].stage;
|
||||
for (index, arg) in func.arguments.iter().enumerate() {
|
||||
@ -1164,17 +1254,16 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
self.write_array_size(module, base, size)?;
|
||||
}
|
||||
|
||||
if let Some(ref binding) = arg.binding {
|
||||
self.write_semantic(binding, Some((stage, Io::Input)))?;
|
||||
}
|
||||
self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
|
||||
}
|
||||
|
||||
if need_workgroup_variables_initialization {
|
||||
if !func.arguments.is_empty() {
|
||||
write!(self.out, ", ")?;
|
||||
}
|
||||
write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
|
||||
}
|
||||
if need_workgroup_variables_initialization {
|
||||
if self.entry_point_io[ep_index as usize].input.is_some()
|
||||
|| !func.arguments.is_empty()
|
||||
{
|
||||
write!(self.out, ", ")?;
|
||||
}
|
||||
write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1184,11 +1273,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
// Write the semantic if it is present
|
||||
if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
|
||||
let stage = module.entry_points[index as usize].stage;
|
||||
if let Some(crate::FunctionResult {
|
||||
binding: Some(ref binding),
|
||||
..
|
||||
}) = func.result
|
||||
{
|
||||
if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
|
||||
self.write_semantic(binding, Some((stage, Io::Output)))?;
|
||||
}
|
||||
}
|
||||
@ -1988,6 +2073,129 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
writeln!(self.out, "{level}}}")?
|
||||
}
|
||||
Statement::RayQuery { .. } => unreachable!(),
|
||||
Statement::SubgroupBallot { result, predicate } => {
|
||||
write!(self.out, "{level}")?;
|
||||
let name = format!("{}{}", back::BAKE_PREFIX, result.index());
|
||||
write!(self.out, "const uint4 {name} = ")?;
|
||||
self.named_expressions.insert(result, name);
|
||||
|
||||
write!(self.out, "WaveActiveBallot(")?;
|
||||
match predicate {
|
||||
Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
|
||||
None => write!(self.out, "true")?,
|
||||
}
|
||||
writeln!(self.out, ");")?;
|
||||
}
|
||||
Statement::SubgroupCollectiveOperation {
|
||||
op,
|
||||
collective_op,
|
||||
argument,
|
||||
result,
|
||||
} => {
|
||||
write!(self.out, "{level}")?;
|
||||
write!(self.out, "const ")?;
|
||||
let name = format!("{}{}", back::BAKE_PREFIX, result.index());
|
||||
match func_ctx.info[result].ty {
|
||||
proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
|
||||
proc::TypeResolution::Value(ref value) => {
|
||||
self.write_value_type(module, value)?
|
||||
}
|
||||
};
|
||||
write!(self.out, " {name} = ")?;
|
||||
self.named_expressions.insert(result, name);
|
||||
|
||||
match (collective_op, op) {
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
|
||||
write!(self.out, "WaveActiveAllTrue(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
|
||||
write!(self.out, "WaveActiveAnyTrue(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
|
||||
write!(self.out, "WaveActiveSum(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
|
||||
write!(self.out, "WaveActiveProduct(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
|
||||
write!(self.out, "WaveActiveMax(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
|
||||
write!(self.out, "WaveActiveMin(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
|
||||
write!(self.out, "WaveActiveBitAnd(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
|
||||
write!(self.out, "WaveActiveBitOr(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
|
||||
write!(self.out, "WaveActiveBitXor(")?
|
||||
}
|
||||
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
|
||||
write!(self.out, "WavePrefixSum(")?
|
||||
}
|
||||
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
|
||||
write!(self.out, "WavePrefixProduct(")?
|
||||
}
|
||||
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
|
||||
self.write_expr(module, argument, func_ctx)?;
|
||||
write!(self.out, " + WavePrefixSum(")?;
|
||||
}
|
||||
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
|
||||
self.write_expr(module, argument, func_ctx)?;
|
||||
write!(self.out, " * WavePrefixProduct(")?;
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
self.write_expr(module, argument, func_ctx)?;
|
||||
writeln!(self.out, ");")?;
|
||||
}
|
||||
Statement::SubgroupGather {
|
||||
mode,
|
||||
argument,
|
||||
result,
|
||||
} => {
|
||||
write!(self.out, "{level}")?;
|
||||
write!(self.out, "const ")?;
|
||||
let name = format!("{}{}", back::BAKE_PREFIX, result.index());
|
||||
match func_ctx.info[result].ty {
|
||||
proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
|
||||
proc::TypeResolution::Value(ref value) => {
|
||||
self.write_value_type(module, value)?
|
||||
}
|
||||
};
|
||||
write!(self.out, " {name} = ")?;
|
||||
self.named_expressions.insert(result, name);
|
||||
|
||||
if matches!(mode, crate::GatherMode::BroadcastFirst) {
|
||||
write!(self.out, "WaveReadLaneFirst(")?;
|
||||
self.write_expr(module, argument, func_ctx)?;
|
||||
} else {
|
||||
write!(self.out, "WaveReadLaneAt(")?;
|
||||
self.write_expr(module, argument, func_ctx)?;
|
||||
write!(self.out, ", ")?;
|
||||
match mode {
|
||||
crate::GatherMode::BroadcastFirst => unreachable!(),
|
||||
crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => {
|
||||
self.write_expr(module, index, func_ctx)?;
|
||||
}
|
||||
crate::GatherMode::ShuffleDown(index) => {
|
||||
write!(self.out, "WaveGetLaneIndex() + ")?;
|
||||
self.write_expr(module, index, func_ctx)?;
|
||||
}
|
||||
crate::GatherMode::ShuffleUp(index) => {
|
||||
write!(self.out, "WaveGetLaneIndex() - ")?;
|
||||
self.write_expr(module, index, func_ctx)?;
|
||||
}
|
||||
crate::GatherMode::ShuffleXor(index) => {
|
||||
write!(self.out, "WaveGetLaneIndex() ^ ")?;
|
||||
self.write_expr(module, index, func_ctx)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
writeln!(self.out, ");")?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@ -3134,7 +3342,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
Expression::CallResult(_)
|
||||
| Expression::AtomicResult { .. }
|
||||
| Expression::WorkGroupUniformLoadResult { .. }
|
||||
| Expression::RayQueryProceedResult => {}
|
||||
| Expression::RayQueryProceedResult
|
||||
| Expression::SubgroupBallotResult
|
||||
| Expression::SubgroupOperationResult { .. } => {}
|
||||
}
|
||||
|
||||
if !closing_bracket.is_empty() {
|
||||
@ -3201,6 +3411,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
if barrier.contains(crate::Barrier::WORK_GROUP) {
|
||||
writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
|
||||
}
|
||||
if barrier.contains(crate::Barrier::SUB_GROUP) {
|
||||
// Does not exist in DirectX
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -436,6 +436,11 @@ impl ResolvedBinding {
|
||||
Bi::WorkGroupId => "threadgroup_position_in_grid",
|
||||
Bi::WorkGroupSize => "dispatch_threads_per_threadgroup",
|
||||
Bi::NumWorkGroups => "threadgroups_per_grid",
|
||||
// subgroup
|
||||
Bi::NumSubgroups => "simdgroups_per_threadgroup",
|
||||
Bi::SubgroupId => "simdgroup_index_in_threadgroup",
|
||||
Bi::SubgroupSize => "threads_per_simdgroup",
|
||||
Bi::SubgroupInvocationId => "thread_index_in_simdgroup",
|
||||
Bi::CullDistance | Bi::ViewIndex => {
|
||||
return Err(Error::UnsupportedBuiltIn(built_in))
|
||||
}
|
||||
|
@ -2042,6 +2042,8 @@ impl<W: Write> Writer<W> {
|
||||
crate::Expression::CallResult(_)
|
||||
| crate::Expression::AtomicResult { .. }
|
||||
| crate::Expression::WorkGroupUniformLoadResult { .. }
|
||||
| crate::Expression::SubgroupBallotResult
|
||||
| crate::Expression::SubgroupOperationResult { .. }
|
||||
| crate::Expression::RayQueryProceedResult => {
|
||||
unreachable!()
|
||||
}
|
||||
@ -3145,6 +3147,121 @@ impl<W: Write> Writer<W> {
|
||||
}
|
||||
}
|
||||
}
|
||||
crate::Statement::SubgroupBallot { result, predicate } => {
|
||||
write!(self.out, "{level}")?;
|
||||
let name = self.namer.call("");
|
||||
self.start_baking_expression(result, &context.expression, &name)?;
|
||||
self.named_expressions.insert(result, name);
|
||||
write!(self.out, "uint4((uint64_t){NAMESPACE}::simd_ballot(")?;
|
||||
if let Some(predicate) = predicate {
|
||||
self.put_expression(predicate, &context.expression, true)?;
|
||||
} else {
|
||||
write!(self.out, "true")?;
|
||||
}
|
||||
writeln!(self.out, "), 0, 0, 0);")?;
|
||||
}
|
||||
crate::Statement::SubgroupCollectiveOperation {
|
||||
op,
|
||||
collective_op,
|
||||
argument,
|
||||
result,
|
||||
} => {
|
||||
write!(self.out, "{level}")?;
|
||||
let name = self.namer.call("");
|
||||
self.start_baking_expression(result, &context.expression, &name)?;
|
||||
self.named_expressions.insert(result, name);
|
||||
match (collective_op, op) {
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
|
||||
write!(self.out, "{NAMESPACE}::simd_all(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
|
||||
write!(self.out, "{NAMESPACE}::simd_any(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
|
||||
write!(self.out, "{NAMESPACE}::simd_sum(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
|
||||
write!(self.out, "{NAMESPACE}::simd_product(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
|
||||
write!(self.out, "{NAMESPACE}::simd_max(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
|
||||
write!(self.out, "{NAMESPACE}::simd_min(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
|
||||
write!(self.out, "{NAMESPACE}::simd_and(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
|
||||
write!(self.out, "{NAMESPACE}::simd_or(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
|
||||
write!(self.out, "{NAMESPACE}::simd_xor(")?
|
||||
}
|
||||
(
|
||||
crate::CollectiveOperation::ExclusiveScan,
|
||||
crate::SubgroupOperation::Add,
|
||||
) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
|
||||
(
|
||||
crate::CollectiveOperation::ExclusiveScan,
|
||||
crate::SubgroupOperation::Mul,
|
||||
) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
|
||||
(
|
||||
crate::CollectiveOperation::InclusiveScan,
|
||||
crate::SubgroupOperation::Add,
|
||||
) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
|
||||
(
|
||||
crate::CollectiveOperation::InclusiveScan,
|
||||
crate::SubgroupOperation::Mul,
|
||||
) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
self.put_expression(argument, &context.expression, true)?;
|
||||
writeln!(self.out, ");")?;
|
||||
}
|
||||
crate::Statement::SubgroupGather {
|
||||
mode,
|
||||
argument,
|
||||
result,
|
||||
} => {
|
||||
write!(self.out, "{level}")?;
|
||||
let name = self.namer.call("");
|
||||
self.start_baking_expression(result, &context.expression, &name)?;
|
||||
self.named_expressions.insert(result, name);
|
||||
match mode {
|
||||
crate::GatherMode::BroadcastFirst => {
|
||||
write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
|
||||
}
|
||||
crate::GatherMode::Broadcast(_) => {
|
||||
write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
|
||||
}
|
||||
crate::GatherMode::Shuffle(_) => {
|
||||
write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
|
||||
}
|
||||
crate::GatherMode::ShuffleDown(_) => {
|
||||
write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
|
||||
}
|
||||
crate::GatherMode::ShuffleUp(_) => {
|
||||
write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
|
||||
}
|
||||
crate::GatherMode::ShuffleXor(_) => {
|
||||
write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
|
||||
}
|
||||
}
|
||||
self.put_expression(argument, &context.expression, true)?;
|
||||
match mode {
|
||||
crate::GatherMode::BroadcastFirst => {}
|
||||
crate::GatherMode::Broadcast(index)
|
||||
| crate::GatherMode::Shuffle(index)
|
||||
| crate::GatherMode::ShuffleDown(index)
|
||||
| crate::GatherMode::ShuffleUp(index)
|
||||
| crate::GatherMode::ShuffleXor(index) => {
|
||||
write!(self.out, ", ")?;
|
||||
self.put_expression(index, &context.expression, true)?;
|
||||
}
|
||||
}
|
||||
writeln!(self.out, ");")?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -4492,6 +4609,12 @@ impl<W: Write> Writer<W> {
|
||||
"{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
|
||||
)?;
|
||||
}
|
||||
if flags.contains(crate::Barrier::SUB_GROUP) {
|
||||
writeln!(
|
||||
self.out,
|
||||
"{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -4762,8 +4885,8 @@ fn test_stack_size() {
|
||||
}
|
||||
let stack_size = addresses_end - addresses_start;
|
||||
// check the size (in debug only)
|
||||
// last observed macOS value: 19152 (CI)
|
||||
if !(9000..=20000).contains(&stack_size) {
|
||||
// last observed macOS value: 22256 (CI)
|
||||
if !(15000..=25000).contains(&stack_size) {
|
||||
panic!("`put_block` stack size {stack_size} has changed!");
|
||||
}
|
||||
}
|
||||
|
@ -522,7 +522,9 @@ fn adjust_expr(new_pos: &[Handle<Expression>], expr: &mut Expression) {
|
||||
ty: _,
|
||||
comparison: _,
|
||||
}
|
||||
| Expression::WorkGroupUniformLoadResult { ty: _ } => {}
|
||||
| Expression::WorkGroupUniformLoadResult { ty: _ }
|
||||
| Expression::SubgroupBallotResult
|
||||
| Expression::SubgroupOperationResult { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
@ -637,6 +639,41 @@ fn adjust_stmt(new_pos: &[Handle<Expression>], stmt: &mut Statement) {
|
||||
adjust(pointer);
|
||||
adjust(result);
|
||||
}
|
||||
Statement::SubgroupBallot {
|
||||
ref mut result,
|
||||
ref mut predicate,
|
||||
} => {
|
||||
if let Some(ref mut predicate) = *predicate {
|
||||
adjust(predicate);
|
||||
}
|
||||
adjust(result);
|
||||
}
|
||||
Statement::SubgroupCollectiveOperation {
|
||||
ref mut argument,
|
||||
ref mut result,
|
||||
..
|
||||
} => {
|
||||
adjust(argument);
|
||||
adjust(result);
|
||||
}
|
||||
Statement::SubgroupGather {
|
||||
ref mut mode,
|
||||
ref mut argument,
|
||||
ref mut result,
|
||||
} => {
|
||||
match *mode {
|
||||
crate::GatherMode::BroadcastFirst => {}
|
||||
crate::GatherMode::Broadcast(ref mut index)
|
||||
| crate::GatherMode::Shuffle(ref mut index)
|
||||
| crate::GatherMode::ShuffleDown(ref mut index)
|
||||
| crate::GatherMode::ShuffleUp(ref mut index)
|
||||
| crate::GatherMode::ShuffleXor(ref mut index) => {
|
||||
adjust(index);
|
||||
}
|
||||
}
|
||||
adjust(argument);
|
||||
adjust(result)
|
||||
}
|
||||
Statement::Call {
|
||||
ref mut arguments,
|
||||
ref mut result,
|
||||
|
@ -1279,7 +1279,9 @@ impl<'w> BlockContext<'w> {
|
||||
crate::Expression::CallResult(_)
|
||||
| crate::Expression::AtomicResult { .. }
|
||||
| crate::Expression::WorkGroupUniformLoadResult { .. }
|
||||
| crate::Expression::RayQueryProceedResult => self.cached[expr_handle],
|
||||
| crate::Expression::RayQueryProceedResult
|
||||
| crate::Expression::SubgroupBallotResult
|
||||
| crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle],
|
||||
crate::Expression::As {
|
||||
expr,
|
||||
kind,
|
||||
@ -2490,6 +2492,27 @@ impl<'w> BlockContext<'w> {
|
||||
crate::Statement::RayQuery { query, ref fun } => {
|
||||
self.write_ray_query_function(query, fun, &mut block);
|
||||
}
|
||||
crate::Statement::SubgroupBallot {
|
||||
result,
|
||||
ref predicate,
|
||||
} => {
|
||||
self.write_subgroup_ballot(predicate, result, &mut block)?;
|
||||
}
|
||||
crate::Statement::SubgroupCollectiveOperation {
|
||||
ref op,
|
||||
ref collective_op,
|
||||
argument,
|
||||
result,
|
||||
} => {
|
||||
self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?;
|
||||
}
|
||||
crate::Statement::SubgroupGather {
|
||||
ref mode,
|
||||
argument,
|
||||
result,
|
||||
} => {
|
||||
self.write_subgroup_gather(mode, argument, result, &mut block)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1073,6 +1073,73 @@ impl super::Instruction {
|
||||
instruction.add_operand(semantics_id);
|
||||
instruction
|
||||
}
|
||||
|
||||
// Group Instructions
|
||||
|
||||
pub(super) fn group_non_uniform_ballot(
|
||||
result_type_id: Word,
|
||||
id: Word,
|
||||
exec_scope_id: Word,
|
||||
predicate: Word,
|
||||
) -> Self {
|
||||
let mut instruction = Self::new(Op::GroupNonUniformBallot);
|
||||
instruction.set_type(result_type_id);
|
||||
instruction.set_result(id);
|
||||
instruction.add_operand(exec_scope_id);
|
||||
instruction.add_operand(predicate);
|
||||
|
||||
instruction
|
||||
}
|
||||
pub(super) fn group_non_uniform_broadcast_first(
|
||||
result_type_id: Word,
|
||||
id: Word,
|
||||
exec_scope_id: Word,
|
||||
value: Word,
|
||||
) -> Self {
|
||||
let mut instruction = Self::new(Op::GroupNonUniformBroadcastFirst);
|
||||
instruction.set_type(result_type_id);
|
||||
instruction.set_result(id);
|
||||
instruction.add_operand(exec_scope_id);
|
||||
instruction.add_operand(value);
|
||||
|
||||
instruction
|
||||
}
|
||||
pub(super) fn group_non_uniform_gather(
|
||||
op: Op,
|
||||
result_type_id: Word,
|
||||
id: Word,
|
||||
exec_scope_id: Word,
|
||||
value: Word,
|
||||
index: Word,
|
||||
) -> Self {
|
||||
let mut instruction = Self::new(op);
|
||||
instruction.set_type(result_type_id);
|
||||
instruction.set_result(id);
|
||||
instruction.add_operand(exec_scope_id);
|
||||
instruction.add_operand(value);
|
||||
instruction.add_operand(index);
|
||||
|
||||
instruction
|
||||
}
|
||||
pub(super) fn group_non_uniform_arithmetic(
|
||||
op: Op,
|
||||
result_type_id: Word,
|
||||
id: Word,
|
||||
exec_scope_id: Word,
|
||||
group_op: Option<spirv::GroupOperation>,
|
||||
value: Word,
|
||||
) -> Self {
|
||||
let mut instruction = Self::new(op);
|
||||
instruction.set_type(result_type_id);
|
||||
instruction.set_result(id);
|
||||
instruction.add_operand(exec_scope_id);
|
||||
if let Some(group_op) = group_op {
|
||||
instruction.add_operand(group_op as u32);
|
||||
}
|
||||
instruction.add_operand(value);
|
||||
|
||||
instruction
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::StorageFormat> for spirv::ImageFormat {
|
||||
|
@ -13,6 +13,7 @@ mod layout;
|
||||
mod ray;
|
||||
mod recyclable;
|
||||
mod selection;
|
||||
mod subgroup;
|
||||
mod writer;
|
||||
|
||||
pub use spirv::Capability;
|
||||
|
207
naga/src/back/spv/subgroup.rs
Normal file
207
naga/src/back/spv/subgroup.rs
Normal file
@ -0,0 +1,207 @@
|
||||
use super::{Block, BlockContext, Error, Instruction};
|
||||
use crate::{
|
||||
arena::Handle,
|
||||
back::spv::{LocalType, LookupType},
|
||||
TypeInner,
|
||||
};
|
||||
|
||||
impl<'w> BlockContext<'w> {
|
||||
pub(super) fn write_subgroup_ballot(
|
||||
&mut self,
|
||||
predicate: &Option<Handle<crate::Expression>>,
|
||||
result: Handle<crate::Expression>,
|
||||
block: &mut Block,
|
||||
) -> Result<(), Error> {
|
||||
self.writer.require_any(
|
||||
"GroupNonUniformBallot",
|
||||
&[spirv::Capability::GroupNonUniformBallot],
|
||||
)?;
|
||||
let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
|
||||
vector_size: Some(crate::VectorSize::Quad),
|
||||
scalar: crate::Scalar::U32,
|
||||
pointer_space: None,
|
||||
}));
|
||||
let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
|
||||
let predicate = if let Some(predicate) = *predicate {
|
||||
self.cached[predicate]
|
||||
} else {
|
||||
self.writer.get_constant_scalar(crate::Literal::Bool(true))
|
||||
};
|
||||
let id = self.gen_id();
|
||||
block.body.push(Instruction::group_non_uniform_ballot(
|
||||
vec4_u32_type_id,
|
||||
id,
|
||||
exec_scope_id,
|
||||
predicate,
|
||||
));
|
||||
self.cached[result] = id;
|
||||
Ok(())
|
||||
}
|
||||
pub(super) fn write_subgroup_operation(
|
||||
&mut self,
|
||||
op: &crate::SubgroupOperation,
|
||||
collective_op: &crate::CollectiveOperation,
|
||||
argument: Handle<crate::Expression>,
|
||||
result: Handle<crate::Expression>,
|
||||
block: &mut Block,
|
||||
) -> Result<(), Error> {
|
||||
use crate::SubgroupOperation as sg;
|
||||
match *op {
|
||||
sg::All | sg::Any => {
|
||||
self.writer.require_any(
|
||||
"GroupNonUniformVote",
|
||||
&[spirv::Capability::GroupNonUniformVote],
|
||||
)?;
|
||||
}
|
||||
_ => {
|
||||
self.writer.require_any(
|
||||
"GroupNonUniformArithmetic",
|
||||
&[spirv::Capability::GroupNonUniformArithmetic],
|
||||
)?;
|
||||
}
|
||||
}
|
||||
|
||||
let id = self.gen_id();
|
||||
let result_ty = &self.fun_info[result].ty;
|
||||
let result_type_id = self.get_expression_type_id(result_ty);
|
||||
let result_ty_inner = result_ty.inner_with(&self.ir_module.types);
|
||||
|
||||
let (is_scalar, scalar) = match *result_ty_inner {
|
||||
TypeInner::Scalar(kind) => (true, kind),
|
||||
TypeInner::Vector { scalar: kind, .. } => (false, kind),
|
||||
_ => unimplemented!(),
|
||||
};
|
||||
|
||||
use crate::ScalarKind as sk;
|
||||
let spirv_op = match (scalar.kind, *op) {
|
||||
(sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll,
|
||||
(sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny,
|
||||
(_, sg::All | sg::Any) => unimplemented!(),
|
||||
|
||||
(sk::Sint | sk::Uint, sg::Add) => spirv::Op::GroupNonUniformIAdd,
|
||||
(sk::Float, sg::Add) => spirv::Op::GroupNonUniformFAdd,
|
||||
(sk::Sint | sk::Uint, sg::Mul) => spirv::Op::GroupNonUniformIMul,
|
||||
(sk::Float, sg::Mul) => spirv::Op::GroupNonUniformFMul,
|
||||
(sk::Sint, sg::Max) => spirv::Op::GroupNonUniformSMax,
|
||||
(sk::Uint, sg::Max) => spirv::Op::GroupNonUniformUMax,
|
||||
(sk::Float, sg::Max) => spirv::Op::GroupNonUniformFMax,
|
||||
(sk::Sint, sg::Min) => spirv::Op::GroupNonUniformSMin,
|
||||
(sk::Uint, sg::Min) => spirv::Op::GroupNonUniformUMin,
|
||||
(sk::Float, sg::Min) => spirv::Op::GroupNonUniformFMin,
|
||||
(_, sg::Add | sg::Mul | sg::Min | sg::Max) => unimplemented!(),
|
||||
|
||||
(sk::Sint | sk::Uint, sg::And) => spirv::Op::GroupNonUniformBitwiseAnd,
|
||||
(sk::Sint | sk::Uint, sg::Or) => spirv::Op::GroupNonUniformBitwiseOr,
|
||||
(sk::Sint | sk::Uint, sg::Xor) => spirv::Op::GroupNonUniformBitwiseXor,
|
||||
(sk::Bool, sg::And) => spirv::Op::GroupNonUniformLogicalAnd,
|
||||
(sk::Bool, sg::Or) => spirv::Op::GroupNonUniformLogicalOr,
|
||||
(sk::Bool, sg::Xor) => spirv::Op::GroupNonUniformLogicalXor,
|
||||
(_, sg::And | sg::Or | sg::Xor) => unimplemented!(),
|
||||
};
|
||||
|
||||
let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
|
||||
|
||||
use crate::CollectiveOperation as c;
|
||||
let group_op = match *op {
|
||||
sg::All | sg::Any => None,
|
||||
_ => Some(match *collective_op {
|
||||
c::Reduce => spirv::GroupOperation::Reduce,
|
||||
c::InclusiveScan => spirv::GroupOperation::InclusiveScan,
|
||||
c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan,
|
||||
}),
|
||||
};
|
||||
|
||||
let arg_id = self.cached[argument];
|
||||
block.body.push(Instruction::group_non_uniform_arithmetic(
|
||||
spirv_op,
|
||||
result_type_id,
|
||||
id,
|
||||
exec_scope_id,
|
||||
group_op,
|
||||
arg_id,
|
||||
));
|
||||
self.cached[result] = id;
|
||||
Ok(())
|
||||
}
|
||||
pub(super) fn write_subgroup_gather(
|
||||
&mut self,
|
||||
mode: &crate::GatherMode,
|
||||
argument: Handle<crate::Expression>,
|
||||
result: Handle<crate::Expression>,
|
||||
block: &mut Block,
|
||||
) -> Result<(), Error> {
|
||||
self.writer.require_any(
|
||||
"GroupNonUniformBallot",
|
||||
&[spirv::Capability::GroupNonUniformBallot],
|
||||
)?;
|
||||
match *mode {
|
||||
crate::GatherMode::BroadcastFirst | crate::GatherMode::Broadcast(_) => {
|
||||
self.writer.require_any(
|
||||
"GroupNonUniformBallot",
|
||||
&[spirv::Capability::GroupNonUniformBallot],
|
||||
)?;
|
||||
}
|
||||
crate::GatherMode::Shuffle(_) | crate::GatherMode::ShuffleXor(_) => {
|
||||
self.writer.require_any(
|
||||
"GroupNonUniformShuffle",
|
||||
&[spirv::Capability::GroupNonUniformShuffle],
|
||||
)?;
|
||||
}
|
||||
crate::GatherMode::ShuffleDown(_) | crate::GatherMode::ShuffleUp(_) => {
|
||||
self.writer.require_any(
|
||||
"GroupNonUniformShuffleRelative",
|
||||
&[spirv::Capability::GroupNonUniformShuffleRelative],
|
||||
)?;
|
||||
}
|
||||
}
|
||||
|
||||
let id = self.gen_id();
|
||||
let result_ty = &self.fun_info[result].ty;
|
||||
let result_type_id = self.get_expression_type_id(result_ty);
|
||||
|
||||
let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
|
||||
|
||||
let arg_id = self.cached[argument];
|
||||
match *mode {
|
||||
crate::GatherMode::BroadcastFirst => {
|
||||
block
|
||||
.body
|
||||
.push(Instruction::group_non_uniform_broadcast_first(
|
||||
result_type_id,
|
||||
id,
|
||||
exec_scope_id,
|
||||
arg_id,
|
||||
));
|
||||
}
|
||||
crate::GatherMode::Broadcast(index)
|
||||
| crate::GatherMode::Shuffle(index)
|
||||
| crate::GatherMode::ShuffleDown(index)
|
||||
| crate::GatherMode::ShuffleUp(index)
|
||||
| crate::GatherMode::ShuffleXor(index) => {
|
||||
let index_id = self.cached[index];
|
||||
let op = match *mode {
|
||||
crate::GatherMode::BroadcastFirst => unreachable!(),
|
||||
// Use shuffle to emit broadcast to allow the index to
|
||||
// be dynamically uniform on Vulkan 1.1. The argument to
|
||||
// OpGroupNonUniformBroadcast must be a constant pre SPIR-V
|
||||
// 1.5 (vulkan 1.2)
|
||||
crate::GatherMode::Broadcast(_) => spirv::Op::GroupNonUniformShuffle,
|
||||
crate::GatherMode::Shuffle(_) => spirv::Op::GroupNonUniformShuffle,
|
||||
crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown,
|
||||
crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp,
|
||||
crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor,
|
||||
};
|
||||
block.body.push(Instruction::group_non_uniform_gather(
|
||||
op,
|
||||
result_type_id,
|
||||
id,
|
||||
exec_scope_id,
|
||||
arg_id,
|
||||
index_id,
|
||||
));
|
||||
}
|
||||
}
|
||||
self.cached[result] = id;
|
||||
Ok(())
|
||||
}
|
||||
}
|
@ -1310,7 +1310,11 @@ impl Writer {
|
||||
spirv::MemorySemantics::WORKGROUP_MEMORY,
|
||||
flags.contains(crate::Barrier::WORK_GROUP),
|
||||
);
|
||||
let exec_scope_id = self.get_index_constant(spirv::Scope::Workgroup as u32);
|
||||
let exec_scope_id = if flags.contains(crate::Barrier::SUB_GROUP) {
|
||||
self.get_index_constant(spirv::Scope::Subgroup as u32)
|
||||
} else {
|
||||
self.get_index_constant(spirv::Scope::Workgroup as u32)
|
||||
};
|
||||
let mem_scope_id = self.get_index_constant(memory_scope as u32);
|
||||
let semantics_id = self.get_index_constant(semantics.bits());
|
||||
block.body.push(Instruction::control_barrier(
|
||||
@ -1585,6 +1589,41 @@ impl Writer {
|
||||
Bi::WorkGroupId => BuiltIn::WorkgroupId,
|
||||
Bi::WorkGroupSize => BuiltIn::WorkgroupSize,
|
||||
Bi::NumWorkGroups => BuiltIn::NumWorkgroups,
|
||||
// Subgroup
|
||||
Bi::NumSubgroups => {
|
||||
self.require_any(
|
||||
"`num_subgroups` built-in",
|
||||
&[spirv::Capability::GroupNonUniform],
|
||||
)?;
|
||||
BuiltIn::NumSubgroups
|
||||
}
|
||||
Bi::SubgroupId => {
|
||||
self.require_any(
|
||||
"`subgroup_id` built-in",
|
||||
&[spirv::Capability::GroupNonUniform],
|
||||
)?;
|
||||
BuiltIn::SubgroupId
|
||||
}
|
||||
Bi::SubgroupSize => {
|
||||
self.require_any(
|
||||
"`subgroup_size` built-in",
|
||||
&[
|
||||
spirv::Capability::GroupNonUniform,
|
||||
spirv::Capability::SubgroupBallotKHR,
|
||||
],
|
||||
)?;
|
||||
BuiltIn::SubgroupSize
|
||||
}
|
||||
Bi::SubgroupInvocationId => {
|
||||
self.require_any(
|
||||
"`subgroup_invocation_id` built-in",
|
||||
&[
|
||||
spirv::Capability::GroupNonUniform,
|
||||
spirv::Capability::SubgroupBallotKHR,
|
||||
],
|
||||
)?;
|
||||
BuiltIn::SubgroupLocalInvocationId
|
||||
}
|
||||
};
|
||||
|
||||
self.decorate(id, Decoration::BuiltIn, &[built_in as u32]);
|
||||
|
@ -924,8 +924,124 @@ impl<W: Write> Writer<W> {
|
||||
if barrier.contains(crate::Barrier::WORK_GROUP) {
|
||||
writeln!(self.out, "{level}workgroupBarrier();")?;
|
||||
}
|
||||
|
||||
if barrier.contains(crate::Barrier::SUB_GROUP) {
|
||||
writeln!(self.out, "{level}subgroupBarrier();")?;
|
||||
}
|
||||
}
|
||||
Statement::RayQuery { .. } => unreachable!(),
|
||||
Statement::SubgroupBallot { result, predicate } => {
|
||||
write!(self.out, "{level}")?;
|
||||
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
|
||||
self.start_named_expr(module, result, func_ctx, &res_name)?;
|
||||
self.named_expressions.insert(result, res_name);
|
||||
|
||||
write!(self.out, "subgroupBallot(")?;
|
||||
if let Some(predicate) = predicate {
|
||||
self.write_expr(module, predicate, func_ctx)?;
|
||||
}
|
||||
writeln!(self.out, ");")?;
|
||||
}
|
||||
Statement::SubgroupCollectiveOperation {
|
||||
op,
|
||||
collective_op,
|
||||
argument,
|
||||
result,
|
||||
} => {
|
||||
write!(self.out, "{level}")?;
|
||||
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
|
||||
self.start_named_expr(module, result, func_ctx, &res_name)?;
|
||||
self.named_expressions.insert(result, res_name);
|
||||
|
||||
match (collective_op, op) {
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
|
||||
write!(self.out, "subgroupAll(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
|
||||
write!(self.out, "subgroupAny(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
|
||||
write!(self.out, "subgroupAdd(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
|
||||
write!(self.out, "subgroupMul(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
|
||||
write!(self.out, "subgroupMax(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
|
||||
write!(self.out, "subgroupMin(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
|
||||
write!(self.out, "subgroupAnd(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
|
||||
write!(self.out, "subgroupOr(")?
|
||||
}
|
||||
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
|
||||
write!(self.out, "subgroupXor(")?
|
||||
}
|
||||
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
|
||||
write!(self.out, "subgroupExclusiveAdd(")?
|
||||
}
|
||||
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
|
||||
write!(self.out, "subgroupExclusiveMul(")?
|
||||
}
|
||||
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
|
||||
write!(self.out, "subgroupInclusiveAdd(")?
|
||||
}
|
||||
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
|
||||
write!(self.out, "subgroupInclusiveMul(")?
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
self.write_expr(module, argument, func_ctx)?;
|
||||
writeln!(self.out, ");")?;
|
||||
}
|
||||
Statement::SubgroupGather {
|
||||
mode,
|
||||
argument,
|
||||
result,
|
||||
} => {
|
||||
write!(self.out, "{level}")?;
|
||||
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
|
||||
self.start_named_expr(module, result, func_ctx, &res_name)?;
|
||||
self.named_expressions.insert(result, res_name);
|
||||
|
||||
match mode {
|
||||
crate::GatherMode::BroadcastFirst => {
|
||||
write!(self.out, "subgroupBroadcastFirst(")?;
|
||||
}
|
||||
crate::GatherMode::Broadcast(_) => {
|
||||
write!(self.out, "subgroupBroadcast(")?;
|
||||
}
|
||||
crate::GatherMode::Shuffle(_) => {
|
||||
write!(self.out, "subgroupShuffle(")?;
|
||||
}
|
||||
crate::GatherMode::ShuffleDown(_) => {
|
||||
write!(self.out, "subgroupShuffleDown(")?;
|
||||
}
|
||||
crate::GatherMode::ShuffleUp(_) => {
|
||||
write!(self.out, "subgroupShuffleUp(")?;
|
||||
}
|
||||
crate::GatherMode::ShuffleXor(_) => {
|
||||
write!(self.out, "subgroupShuffleXor(")?;
|
||||
}
|
||||
}
|
||||
self.write_expr(module, argument, func_ctx)?;
|
||||
match mode {
|
||||
crate::GatherMode::BroadcastFirst => {}
|
||||
crate::GatherMode::Broadcast(index)
|
||||
| crate::GatherMode::Shuffle(index)
|
||||
| crate::GatherMode::ShuffleDown(index)
|
||||
| crate::GatherMode::ShuffleUp(index)
|
||||
| crate::GatherMode::ShuffleXor(index) => {
|
||||
write!(self.out, ", ")?;
|
||||
self.write_expr(module, index, func_ctx)?;
|
||||
}
|
||||
}
|
||||
writeln!(self.out, ");")?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@ -1698,6 +1814,8 @@ impl<W: Write> Writer<W> {
|
||||
Expression::CallResult(_)
|
||||
| Expression::AtomicResult { .. }
|
||||
| Expression::RayQueryProceedResult
|
||||
| Expression::SubgroupBallotResult
|
||||
| Expression::SubgroupOperationResult { .. }
|
||||
| Expression::WorkGroupUniformLoadResult { .. } => {}
|
||||
}
|
||||
|
||||
@ -1799,6 +1917,10 @@ fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> {
|
||||
Bi::SampleMask => "sample_mask",
|
||||
Bi::PrimitiveIndex => "primitive_index",
|
||||
Bi::ViewIndex => "view_index",
|
||||
Bi::NumSubgroups => "num_subgroups",
|
||||
Bi::SubgroupId => "subgroup_id",
|
||||
Bi::SubgroupSize => "subgroup_size",
|
||||
Bi::SubgroupInvocationId => "subgroup_invocation_id",
|
||||
Bi::BaseInstance
|
||||
| Bi::BaseVertex
|
||||
| Bi::ClipDistance
|
||||
|
@ -72,6 +72,7 @@ impl<'tracer> ExpressionTracer<'tracer> {
|
||||
| Ex::GlobalVariable(_)
|
||||
| Ex::LocalVariable(_)
|
||||
| Ex::CallResult(_)
|
||||
| Ex::SubgroupBallotResult
|
||||
| Ex::RayQueryProceedResult => {}
|
||||
|
||||
Ex::Constant(handle) => {
|
||||
@ -192,6 +193,7 @@ impl<'tracer> ExpressionTracer<'tracer> {
|
||||
Ex::AtomicResult { ty, comparison: _ } => self.types_used.insert(ty),
|
||||
Ex::WorkGroupUniformLoadResult { ty } => self.types_used.insert(ty),
|
||||
Ex::ArrayLength(expr) => self.expressions_used.insert(expr),
|
||||
Ex::SubgroupOperationResult { ty } => self.types_used.insert(ty),
|
||||
Ex::RayQueryGetIntersection {
|
||||
query,
|
||||
committed: _,
|
||||
@ -223,6 +225,7 @@ impl ModuleMap {
|
||||
| Ex::GlobalVariable(_)
|
||||
| Ex::LocalVariable(_)
|
||||
| Ex::CallResult(_)
|
||||
| Ex::SubgroupBallotResult
|
||||
| Ex::RayQueryProceedResult => {}
|
||||
|
||||
// All overrides are retained, so their handles never change.
|
||||
@ -353,6 +356,7 @@ impl ModuleMap {
|
||||
comparison: _,
|
||||
} => self.types.adjust(ty),
|
||||
Ex::WorkGroupUniformLoadResult { ref mut ty } => self.types.adjust(ty),
|
||||
Ex::SubgroupOperationResult { ref mut ty } => self.types.adjust(ty),
|
||||
Ex::ArrayLength(ref mut expr) => adjust(expr),
|
||||
Ex::RayQueryGetIntersection {
|
||||
ref mut query,
|
||||
|
@ -97,6 +97,39 @@ impl FunctionTracer<'_> {
|
||||
self.expressions_used.insert(query);
|
||||
self.trace_ray_query_function(fun);
|
||||
}
|
||||
St::SubgroupBallot { result, predicate } => {
|
||||
if let Some(predicate) = predicate {
|
||||
self.expressions_used.insert(predicate)
|
||||
}
|
||||
self.expressions_used.insert(result)
|
||||
}
|
||||
St::SubgroupCollectiveOperation {
|
||||
op: _,
|
||||
collective_op: _,
|
||||
argument,
|
||||
result,
|
||||
} => {
|
||||
self.expressions_used.insert(argument);
|
||||
self.expressions_used.insert(result)
|
||||
}
|
||||
St::SubgroupGather {
|
||||
mode,
|
||||
argument,
|
||||
result,
|
||||
} => {
|
||||
match mode {
|
||||
crate::GatherMode::BroadcastFirst => {}
|
||||
crate::GatherMode::Broadcast(index)
|
||||
| crate::GatherMode::Shuffle(index)
|
||||
| crate::GatherMode::ShuffleDown(index)
|
||||
| crate::GatherMode::ShuffleUp(index)
|
||||
| crate::GatherMode::ShuffleXor(index) => {
|
||||
self.expressions_used.insert(index)
|
||||
}
|
||||
}
|
||||
self.expressions_used.insert(argument);
|
||||
self.expressions_used.insert(result)
|
||||
}
|
||||
|
||||
// Trivial statements.
|
||||
St::Break
|
||||
@ -250,6 +283,40 @@ impl FunctionMap {
|
||||
adjust(query);
|
||||
self.adjust_ray_query_function(fun);
|
||||
}
|
||||
St::SubgroupBallot {
|
||||
ref mut result,
|
||||
ref mut predicate,
|
||||
} => {
|
||||
if let Some(ref mut predicate) = *predicate {
|
||||
adjust(predicate);
|
||||
}
|
||||
adjust(result);
|
||||
}
|
||||
St::SubgroupCollectiveOperation {
|
||||
op: _,
|
||||
collective_op: _,
|
||||
ref mut argument,
|
||||
ref mut result,
|
||||
} => {
|
||||
adjust(argument);
|
||||
adjust(result);
|
||||
}
|
||||
St::SubgroupGather {
|
||||
ref mut mode,
|
||||
ref mut argument,
|
||||
ref mut result,
|
||||
} => {
|
||||
match *mode {
|
||||
crate::GatherMode::BroadcastFirst => {}
|
||||
crate::GatherMode::Broadcast(ref mut index)
|
||||
| crate::GatherMode::Shuffle(ref mut index)
|
||||
| crate::GatherMode::ShuffleDown(ref mut index)
|
||||
| crate::GatherMode::ShuffleUp(ref mut index)
|
||||
| crate::GatherMode::ShuffleXor(ref mut index) => adjust(index),
|
||||
}
|
||||
adjust(argument);
|
||||
adjust(result);
|
||||
}
|
||||
|
||||
// Trivial statements.
|
||||
St::Break
|
||||
|
@ -153,6 +153,11 @@ pub(super) fn map_builtin(word: spirv::Word, invariant: bool) -> Result<crate::B
|
||||
Some(Bi::WorkgroupId) => crate::BuiltIn::WorkGroupId,
|
||||
Some(Bi::WorkgroupSize) => crate::BuiltIn::WorkGroupSize,
|
||||
Some(Bi::NumWorkgroups) => crate::BuiltIn::NumWorkGroups,
|
||||
// subgroup
|
||||
Some(Bi::NumSubgroups) => crate::BuiltIn::NumSubgroups,
|
||||
Some(Bi::SubgroupId) => crate::BuiltIn::SubgroupId,
|
||||
Some(Bi::SubgroupSize) => crate::BuiltIn::SubgroupSize,
|
||||
Some(Bi::SubgroupLocalInvocationId) => crate::BuiltIn::SubgroupInvocationId,
|
||||
_ => return Err(Error::UnsupportedBuiltIn(word)),
|
||||
})
|
||||
}
|
||||
|
@ -58,6 +58,8 @@ pub enum Error {
|
||||
UnknownBinaryOperator(spirv::Op),
|
||||
#[error("unknown relational function {0:?}")]
|
||||
UnknownRelationalFunction(spirv::Op),
|
||||
#[error("unsupported group operation %{0}")]
|
||||
UnsupportedGroupOperation(spirv::Word),
|
||||
#[error("invalid parameter {0:?}")]
|
||||
InvalidParameter(spirv::Op),
|
||||
#[error("invalid operand count {1} for {0:?}")]
|
||||
|
@ -3700,6 +3700,254 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
|
||||
},
|
||||
);
|
||||
}
|
||||
Op::GroupNonUniformBallot => {
|
||||
inst.expect(5)?;
|
||||
block.extend(emitter.finish(ctx.expressions));
|
||||
let result_type_id = self.next()?;
|
||||
let result_id = self.next()?;
|
||||
let exec_scope_id = self.next()?;
|
||||
let predicate_id = self.next()?;
|
||||
|
||||
let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
|
||||
let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
|
||||
.filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32)
|
||||
.ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
|
||||
|
||||
let predicate = if self
|
||||
.lookup_constant
|
||||
.lookup(predicate_id)
|
||||
.ok()
|
||||
.filter(|predicate_const| match predicate_const.inner {
|
||||
Constant::Constant(constant) => matches!(
|
||||
ctx.gctx().global_expressions[ctx.gctx().constants[constant].init],
|
||||
crate::Expression::Literal(crate::Literal::Bool(true)),
|
||||
),
|
||||
Constant::Override(_) => false,
|
||||
})
|
||||
.is_some()
|
||||
{
|
||||
None
|
||||
} else {
|
||||
let predicate_lookup = self.lookup_expression.lookup(predicate_id)?;
|
||||
let predicate_handle = get_expr_handle!(predicate_id, predicate_lookup);
|
||||
Some(predicate_handle)
|
||||
};
|
||||
|
||||
let result_handle = ctx
|
||||
.expressions
|
||||
.append(crate::Expression::SubgroupBallotResult, span);
|
||||
self.lookup_expression.insert(
|
||||
result_id,
|
||||
LookupExpression {
|
||||
handle: result_handle,
|
||||
type_id: result_type_id,
|
||||
block_id,
|
||||
},
|
||||
);
|
||||
|
||||
block.push(
|
||||
crate::Statement::SubgroupBallot {
|
||||
result: result_handle,
|
||||
predicate,
|
||||
},
|
||||
span,
|
||||
);
|
||||
emitter.start(ctx.expressions);
|
||||
}
|
||||
spirv::Op::GroupNonUniformAll
|
||||
| spirv::Op::GroupNonUniformAny
|
||||
| spirv::Op::GroupNonUniformIAdd
|
||||
| spirv::Op::GroupNonUniformFAdd
|
||||
| spirv::Op::GroupNonUniformIMul
|
||||
| spirv::Op::GroupNonUniformFMul
|
||||
| spirv::Op::GroupNonUniformSMax
|
||||
| spirv::Op::GroupNonUniformUMax
|
||||
| spirv::Op::GroupNonUniformFMax
|
||||
| spirv::Op::GroupNonUniformSMin
|
||||
| spirv::Op::GroupNonUniformUMin
|
||||
| spirv::Op::GroupNonUniformFMin
|
||||
| spirv::Op::GroupNonUniformBitwiseAnd
|
||||
| spirv::Op::GroupNonUniformBitwiseOr
|
||||
| spirv::Op::GroupNonUniformBitwiseXor
|
||||
| spirv::Op::GroupNonUniformLogicalAnd
|
||||
| spirv::Op::GroupNonUniformLogicalOr
|
||||
| spirv::Op::GroupNonUniformLogicalXor => {
|
||||
block.extend(emitter.finish(ctx.expressions));
|
||||
inst.expect(
|
||||
if matches!(
|
||||
inst.op,
|
||||
spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny
|
||||
) {
|
||||
5
|
||||
} else {
|
||||
6
|
||||
},
|
||||
)?;
|
||||
let result_type_id = self.next()?;
|
||||
let result_id = self.next()?;
|
||||
let exec_scope_id = self.next()?;
|
||||
let collective_op_id = match inst.op {
|
||||
spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny => {
|
||||
crate::CollectiveOperation::Reduce
|
||||
}
|
||||
_ => {
|
||||
let group_op_id = self.next()?;
|
||||
match spirv::GroupOperation::from_u32(group_op_id) {
|
||||
Some(spirv::GroupOperation::Reduce) => {
|
||||
crate::CollectiveOperation::Reduce
|
||||
}
|
||||
Some(spirv::GroupOperation::InclusiveScan) => {
|
||||
crate::CollectiveOperation::InclusiveScan
|
||||
}
|
||||
Some(spirv::GroupOperation::ExclusiveScan) => {
|
||||
crate::CollectiveOperation::ExclusiveScan
|
||||
}
|
||||
_ => return Err(Error::UnsupportedGroupOperation(group_op_id)),
|
||||
}
|
||||
}
|
||||
};
|
||||
let argument_id = self.next()?;
|
||||
|
||||
let argument_lookup = self.lookup_expression.lookup(argument_id)?;
|
||||
let argument_handle = get_expr_handle!(argument_id, argument_lookup);
|
||||
|
||||
let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
|
||||
let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
|
||||
.filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32)
|
||||
.ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
|
||||
|
||||
let op_id = match inst.op {
|
||||
spirv::Op::GroupNonUniformAll => crate::SubgroupOperation::All,
|
||||
spirv::Op::GroupNonUniformAny => crate::SubgroupOperation::Any,
|
||||
spirv::Op::GroupNonUniformIAdd | spirv::Op::GroupNonUniformFAdd => {
|
||||
crate::SubgroupOperation::Add
|
||||
}
|
||||
spirv::Op::GroupNonUniformIMul | spirv::Op::GroupNonUniformFMul => {
|
||||
crate::SubgroupOperation::Mul
|
||||
}
|
||||
spirv::Op::GroupNonUniformSMax
|
||||
| spirv::Op::GroupNonUniformUMax
|
||||
| spirv::Op::GroupNonUniformFMax => crate::SubgroupOperation::Max,
|
||||
spirv::Op::GroupNonUniformSMin
|
||||
| spirv::Op::GroupNonUniformUMin
|
||||
| spirv::Op::GroupNonUniformFMin => crate::SubgroupOperation::Min,
|
||||
spirv::Op::GroupNonUniformBitwiseAnd
|
||||
| spirv::Op::GroupNonUniformLogicalAnd => crate::SubgroupOperation::And,
|
||||
spirv::Op::GroupNonUniformBitwiseOr
|
||||
| spirv::Op::GroupNonUniformLogicalOr => crate::SubgroupOperation::Or,
|
||||
spirv::Op::GroupNonUniformBitwiseXor
|
||||
| spirv::Op::GroupNonUniformLogicalXor => crate::SubgroupOperation::Xor,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let result_type = self.lookup_type.lookup(result_type_id)?;
|
||||
|
||||
let result_handle = ctx.expressions.append(
|
||||
crate::Expression::SubgroupOperationResult {
|
||||
ty: result_type.handle,
|
||||
},
|
||||
span,
|
||||
);
|
||||
self.lookup_expression.insert(
|
||||
result_id,
|
||||
LookupExpression {
|
||||
handle: result_handle,
|
||||
type_id: result_type_id,
|
||||
block_id,
|
||||
},
|
||||
);
|
||||
|
||||
block.push(
|
||||
crate::Statement::SubgroupCollectiveOperation {
|
||||
result: result_handle,
|
||||
op: op_id,
|
||||
collective_op: collective_op_id,
|
||||
argument: argument_handle,
|
||||
},
|
||||
span,
|
||||
);
|
||||
emitter.start(ctx.expressions);
|
||||
}
|
||||
Op::GroupNonUniformBroadcastFirst
|
||||
| Op::GroupNonUniformBroadcast
|
||||
| Op::GroupNonUniformShuffle
|
||||
| Op::GroupNonUniformShuffleDown
|
||||
| Op::GroupNonUniformShuffleUp
|
||||
| Op::GroupNonUniformShuffleXor => {
|
||||
inst.expect(
|
||||
if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) {
|
||||
5
|
||||
} else {
|
||||
6
|
||||
},
|
||||
)?;
|
||||
block.extend(emitter.finish(ctx.expressions));
|
||||
let result_type_id = self.next()?;
|
||||
let result_id = self.next()?;
|
||||
let exec_scope_id = self.next()?;
|
||||
let argument_id = self.next()?;
|
||||
|
||||
let argument_lookup = self.lookup_expression.lookup(argument_id)?;
|
||||
let argument_handle = get_expr_handle!(argument_id, argument_lookup);
|
||||
|
||||
let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
|
||||
let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
|
||||
.filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32)
|
||||
.ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
|
||||
|
||||
let mode = if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) {
|
||||
crate::GatherMode::BroadcastFirst
|
||||
} else {
|
||||
let index_id = self.next()?;
|
||||
let index_lookup = self.lookup_expression.lookup(index_id)?;
|
||||
let index_handle = get_expr_handle!(index_id, index_lookup);
|
||||
match inst.op {
|
||||
spirv::Op::GroupNonUniformBroadcast => {
|
||||
crate::GatherMode::Broadcast(index_handle)
|
||||
}
|
||||
spirv::Op::GroupNonUniformShuffle => {
|
||||
crate::GatherMode::Shuffle(index_handle)
|
||||
}
|
||||
spirv::Op::GroupNonUniformShuffleDown => {
|
||||
crate::GatherMode::ShuffleDown(index_handle)
|
||||
}
|
||||
spirv::Op::GroupNonUniformShuffleUp => {
|
||||
crate::GatherMode::ShuffleUp(index_handle)
|
||||
}
|
||||
spirv::Op::GroupNonUniformShuffleXor => {
|
||||
crate::GatherMode::ShuffleXor(index_handle)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
};
|
||||
|
||||
let result_type = self.lookup_type.lookup(result_type_id)?;
|
||||
|
||||
let result_handle = ctx.expressions.append(
|
||||
crate::Expression::SubgroupOperationResult {
|
||||
ty: result_type.handle,
|
||||
},
|
||||
span,
|
||||
);
|
||||
self.lookup_expression.insert(
|
||||
result_id,
|
||||
LookupExpression {
|
||||
handle: result_handle,
|
||||
type_id: result_type_id,
|
||||
block_id,
|
||||
},
|
||||
);
|
||||
|
||||
block.push(
|
||||
crate::Statement::SubgroupGather {
|
||||
result: result_handle,
|
||||
mode,
|
||||
argument: argument_handle,
|
||||
},
|
||||
span,
|
||||
);
|
||||
emitter.start(ctx.expressions);
|
||||
}
|
||||
_ => return Err(Error::UnsupportedInstruction(self.state, inst.op)),
|
||||
}
|
||||
};
|
||||
@ -3824,7 +4072,10 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
|
||||
| S::Store { .. }
|
||||
| S::ImageStore { .. }
|
||||
| S::Atomic { .. }
|
||||
| S::RayQuery { .. } => {}
|
||||
| S::RayQuery { .. }
|
||||
| S::SubgroupBallot { .. }
|
||||
| S::SubgroupCollectiveOperation { .. }
|
||||
| S::SubgroupGather { .. } => {}
|
||||
S::Call {
|
||||
function: ref mut callee,
|
||||
ref arguments,
|
||||
|
@ -874,6 +874,29 @@ impl Texture {
|
||||
}
|
||||
}
|
||||
|
||||
/// The WGSL subgroup gather built-ins, identified by name only.
///
/// Unlike [`crate::GatherMode`], the variants carry no index expression: this
/// is the result of recognising the function name, before any argument has
/// been lowered.
enum SubgroupGather {
    BroadcastFirst,
    Broadcast,
    Shuffle,
    ShuffleDown,
    ShuffleUp,
    ShuffleXor,
}

impl SubgroupGather {
    /// Returns the gather built-in spelled `word`, or `None` if `word` is not
    /// the name of a subgroup gather built-in. The comparison is exact and
    /// case-sensitive.
    pub fn map(word: &str) -> Option<Self> {
        match word {
            "subgroupBroadcastFirst" => Some(Self::BroadcastFirst),
            "subgroupBroadcast" => Some(Self::Broadcast),
            "subgroupShuffle" => Some(Self::Shuffle),
            "subgroupShuffleDown" => Some(Self::ShuffleDown),
            "subgroupShuffleUp" => Some(Self::ShuffleUp),
            "subgroupShuffleXor" => Some(Self::ShuffleXor),
            _ => None,
        }
    }
}
|
||||
|
||||
pub struct Lowerer<'source, 'temp> {
|
||||
index: &'temp Index<'source>,
|
||||
layouter: Layouter,
|
||||
@ -2054,6 +2077,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
|
||||
}
|
||||
} else if let Some(fun) = Texture::map(function.name) {
|
||||
self.texture_sample_helper(fun, arguments, span, ctx)?
|
||||
} else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) {
|
||||
return Ok(Some(
|
||||
self.subgroup_operation_helper(span, op, cop, arguments, ctx)?,
|
||||
));
|
||||
} else if let Some(mode) = SubgroupGather::map(function.name) {
|
||||
return Ok(Some(
|
||||
self.subgroup_gather_helper(span, mode, arguments, ctx)?,
|
||||
));
|
||||
} else {
|
||||
match function.name {
|
||||
"select" => {
|
||||
@ -2221,6 +2252,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
|
||||
.push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span);
|
||||
return Ok(None);
|
||||
}
|
||||
"subgroupBarrier" => {
|
||||
ctx.prepare_args(arguments, 0, span).finish()?;
|
||||
|
||||
let rctx = ctx.runtime_expression_ctx(span)?;
|
||||
rctx.block
|
||||
.push(crate::Statement::Barrier(crate::Barrier::SUB_GROUP), span);
|
||||
return Ok(None);
|
||||
}
|
||||
"workgroupUniformLoad" => {
|
||||
let mut args = ctx.prepare_args(arguments, 1, span);
|
||||
let expr = args.next()?;
|
||||
@ -2428,6 +2467,22 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
|
||||
)?;
|
||||
return Ok(Some(handle));
|
||||
}
|
||||
"subgroupBallot" => {
|
||||
let mut args = ctx.prepare_args(arguments, 0, span);
|
||||
let predicate = if arguments.len() == 1 {
|
||||
Some(self.expression(args.next()?, ctx)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
args.finish()?;
|
||||
|
||||
let result = ctx
|
||||
.interrupt_emitter(crate::Expression::SubgroupBallotResult, span)?;
|
||||
let rctx = ctx.runtime_expression_ctx(span)?;
|
||||
rctx.block
|
||||
.push(crate::Statement::SubgroupBallot { result, predicate }, span);
|
||||
return Ok(Some(result));
|
||||
}
|
||||
_ => return Err(Error::UnknownIdent(function.span, function.name)),
|
||||
}
|
||||
};
|
||||
@ -2619,6 +2674,80 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
|
||||
})
|
||||
}
|
||||
|
||||
fn subgroup_operation_helper(
|
||||
&mut self,
|
||||
span: Span,
|
||||
op: crate::SubgroupOperation,
|
||||
collective_op: crate::CollectiveOperation,
|
||||
arguments: &[Handle<ast::Expression<'source>>],
|
||||
ctx: &mut ExpressionContext<'source, '_, '_>,
|
||||
) -> Result<Handle<crate::Expression>, Error<'source>> {
|
||||
let mut args = ctx.prepare_args(arguments, 1, span);
|
||||
|
||||
let argument = self.expression(args.next()?, ctx)?;
|
||||
args.finish()?;
|
||||
|
||||
let ty = ctx.register_type(argument)?;
|
||||
|
||||
let result =
|
||||
ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?;
|
||||
let rctx = ctx.runtime_expression_ctx(span)?;
|
||||
rctx.block.push(
|
||||
crate::Statement::SubgroupCollectiveOperation {
|
||||
op,
|
||||
collective_op,
|
||||
argument,
|
||||
result,
|
||||
},
|
||||
span,
|
||||
);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn subgroup_gather_helper(
|
||||
&mut self,
|
||||
span: Span,
|
||||
mode: SubgroupGather,
|
||||
arguments: &[Handle<ast::Expression<'source>>],
|
||||
ctx: &mut ExpressionContext<'source, '_, '_>,
|
||||
) -> Result<Handle<crate::Expression>, Error<'source>> {
|
||||
let mut args = ctx.prepare_args(arguments, 2, span);
|
||||
|
||||
let argument = self.expression(args.next()?, ctx)?;
|
||||
|
||||
use SubgroupGather as Sg;
|
||||
let mode = if let Sg::BroadcastFirst = mode {
|
||||
crate::GatherMode::BroadcastFirst
|
||||
} else {
|
||||
let index = self.expression(args.next()?, ctx)?;
|
||||
match mode {
|
||||
Sg::Broadcast => crate::GatherMode::Broadcast(index),
|
||||
Sg::Shuffle => crate::GatherMode::Shuffle(index),
|
||||
Sg::ShuffleDown => crate::GatherMode::ShuffleDown(index),
|
||||
Sg::ShuffleUp => crate::GatherMode::ShuffleUp(index),
|
||||
Sg::ShuffleXor => crate::GatherMode::ShuffleXor(index),
|
||||
Sg::BroadcastFirst => unreachable!(),
|
||||
}
|
||||
};
|
||||
|
||||
args.finish()?;
|
||||
|
||||
let ty = ctx.register_type(argument)?;
|
||||
|
||||
let result =
|
||||
ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?;
|
||||
let rctx = ctx.runtime_expression_ctx(span)?;
|
||||
rctx.block.push(
|
||||
crate::Statement::SubgroupGather {
|
||||
mode,
|
||||
argument,
|
||||
result,
|
||||
},
|
||||
span,
|
||||
);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn r#struct(
|
||||
&mut self,
|
||||
s: &ast::Struct<'source>,
|
||||
|
@ -35,6 +35,11 @@ pub fn map_built_in(word: &str, span: Span) -> Result<crate::BuiltIn, Error<'_>>
|
||||
"local_invocation_index" => crate::BuiltIn::LocalInvocationIndex,
|
||||
"workgroup_id" => crate::BuiltIn::WorkGroupId,
|
||||
"num_workgroups" => crate::BuiltIn::NumWorkGroups,
|
||||
// subgroup
|
||||
"num_subgroups" => crate::BuiltIn::NumSubgroups,
|
||||
"subgroup_id" => crate::BuiltIn::SubgroupId,
|
||||
"subgroup_size" => crate::BuiltIn::SubgroupSize,
|
||||
"subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId,
|
||||
_ => return Err(Error::UnknownBuiltin(span)),
|
||||
})
|
||||
}
|
||||
@ -260,3 +265,26 @@ pub fn map_conservative_depth(
|
||||
_ => Err(Error::UnknownConservativeDepth(span)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn map_subgroup_operation(
|
||||
word: &str,
|
||||
) -> Option<(crate::SubgroupOperation, crate::CollectiveOperation)> {
|
||||
use crate::CollectiveOperation as co;
|
||||
use crate::SubgroupOperation as sg;
|
||||
Some(match word {
|
||||
"subgroupAll" => (sg::All, co::Reduce),
|
||||
"subgroupAny" => (sg::Any, co::Reduce),
|
||||
"subgroupAdd" => (sg::Add, co::Reduce),
|
||||
"subgroupMul" => (sg::Mul, co::Reduce),
|
||||
"subgroupMin" => (sg::Min, co::Reduce),
|
||||
"subgroupMax" => (sg::Max, co::Reduce),
|
||||
"subgroupAnd" => (sg::And, co::Reduce),
|
||||
"subgroupOr" => (sg::Or, co::Reduce),
|
||||
"subgroupXor" => (sg::Xor, co::Reduce),
|
||||
"subgroupExclusiveAdd" => (sg::Add, co::ExclusiveScan),
|
||||
"subgroupExclusiveMul" => (sg::Mul, co::ExclusiveScan),
|
||||
"subgroupInclusiveAdd" => (sg::Add, co::InclusiveScan),
|
||||
"subgroupInclusiveMul" => (sg::Mul, co::InclusiveScan),
|
||||
_ => return None,
|
||||
})
|
||||
}
|
||||
|
@ -431,6 +431,11 @@ pub enum BuiltIn {
|
||||
WorkGroupId,
|
||||
WorkGroupSize,
|
||||
NumWorkGroups,
|
||||
// subgroup
|
||||
NumSubgroups,
|
||||
SubgroupId,
|
||||
SubgroupSize,
|
||||
SubgroupInvocationId,
|
||||
}
|
||||
|
||||
/// Number of bytes per scalar.
|
||||
@ -1277,6 +1282,51 @@ pub enum SwizzleComponent {
|
||||
W = 3,
|
||||
}
|
||||
|
||||
/// Selects the lane that a [`Statement::SubgroupGather`] statement gathers
/// its value from.
///
/// Every mode except [`BroadcastFirst`](Self::BroadcastFirst) carries the
/// handle of an expression giving the lane index, shift, or mask to use.
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
|
||||
pub enum GatherMode {
|
||||
/// All gather from the active lane with the smallest index
|
||||
BroadcastFirst,
|
||||
/// All gather from the same lane at the index given by the expression
|
||||
Broadcast(Handle<Expression>),
|
||||
/// Each gathers from a different lane at the index given by the expression
|
||||
Shuffle(Handle<Expression>),
|
||||
/// Each gathers from their lane plus the shift given by the expression
|
||||
ShuffleDown(Handle<Expression>),
|
||||
/// Each gathers from their lane minus the shift given by the expression
|
||||
ShuffleUp(Handle<Expression>),
|
||||
/// Each gathers from their lane xored with the value given by the expression
|
||||
ShuffleXor(Handle<Expression>),
|
||||
}
|
||||
|
||||
/// The operation computed by a [`Statement::SubgroupCollectiveOperation`].
///
/// The WGSL front end produces these from the `subgroup<Op>` built-ins
/// (and, for `Add` and `Mul`, from their `Exclusive` / `Inclusive` variants).
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
|
||||
pub enum SubgroupOperation {
|
||||
/// WGSL `subgroupAll`.
All = 0,
|
||||
/// WGSL `subgroupAny`.
Any = 1,
|
||||
/// WGSL `subgroupAdd`, `subgroupExclusiveAdd`, `subgroupInclusiveAdd`.
Add = 2,
|
||||
/// WGSL `subgroupMul`, `subgroupExclusiveMul`, `subgroupInclusiveMul`.
Mul = 3,
|
||||
/// WGSL `subgroupMin`.
Min = 4,
|
||||
/// WGSL `subgroupMax`.
Max = 5,
|
||||
/// WGSL `subgroupAnd`.
And = 6,
|
||||
/// WGSL `subgroupOr`.
Or = 7,
|
||||
/// WGSL `subgroupXor`.
Xor = 8,
|
||||
}
|
||||
|
||||
/// How a [`Statement::SubgroupCollectiveOperation`] combines the values
/// contributed by the subgroup.
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
|
||||
pub enum CollectiveOperation {
|
||||
/// Produced for the plain WGSL built-ins such as `subgroupAdd`.
Reduce = 0,
|
||||
/// Produced for the WGSL `subgroupInclusive*` built-ins.
InclusiveScan = 1,
|
||||
/// Produced for the WGSL `subgroupExclusive*` built-ins.
ExclusiveScan = 2,
|
||||
}
|
||||
|
||||
bitflags::bitflags! {
|
||||
/// Memory barrier flags.
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
@ -1285,9 +1335,11 @@ bitflags::bitflags! {
|
||||
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
|
||||
pub struct Barrier: u32 {
|
||||
/// Barrier affects all `AddressSpace::Storage` accesses.
|
||||
const STORAGE = 0x1;
|
||||
const STORAGE = 1 << 0;
|
||||
/// Barrier affects all `AddressSpace::WorkGroup` accesses.
|
||||
const WORK_GROUP = 0x2;
|
||||
const WORK_GROUP = 1 << 1;
|
||||
/// Barrier synchronizes execution across all invocations within a subgroup that execute this instruction.
|
||||
const SUB_GROUP = 1 << 2;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1588,6 +1640,15 @@ pub enum Expression {
|
||||
query: Handle<Expression>,
|
||||
committed: bool,
|
||||
},
|
||||
/// Result of a [`SubgroupBallot`] statement.
|
||||
///
|
||||
/// [`SubgroupBallot`]: Statement::SubgroupBallot
|
||||
SubgroupBallotResult,
|
||||
/// Result of a [`SubgroupCollectiveOperation`] or [`SubgroupGather`] statement.
|
||||
///
|
||||
/// [`SubgroupCollectiveOperation`]: Statement::SubgroupCollectiveOperation
|
||||
/// [`SubgroupGather`]: Statement::SubgroupGather
|
||||
SubgroupOperationResult { ty: Handle<Type> },
|
||||
}
|
||||
|
||||
pub use block::Block;
|
||||
@ -1872,6 +1933,39 @@ pub enum Statement {
|
||||
/// The specific operation we're performing on `query`.
|
||||
fun: RayQueryFunction,
|
||||
},
|
||||
/// Calculate a bitmask using a boolean from each active thread in the subgroup
|
||||
SubgroupBallot {
|
||||
/// The [`SubgroupBallotResult`] expression representing this ballot's result.
|
||||
///
|
||||
/// [`SubgroupBallotResult`]: Expression::SubgroupBallotResult
|
||||
result: Handle<Expression>,
|
||||
/// The value from this thread to store in the ballot
|
||||
predicate: Option<Handle<Expression>>,
|
||||
},
|
||||
/// Gather a value from another active thread in the subgroup
|
||||
SubgroupGather {
|
||||
/// Specifies which thread to gather from
|
||||
mode: GatherMode,
|
||||
/// The value to broadcast over
|
||||
argument: Handle<Expression>,
|
||||
/// The [`SubgroupOperationResult`] expression representing this gather's result.
|
||||
///
|
||||
/// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult
|
||||
result: Handle<Expression>,
|
||||
},
|
||||
/// Compute a collective operation across all active threads in the subgroup
|
||||
SubgroupCollectiveOperation {
|
||||
/// What operation to compute
|
||||
op: SubgroupOperation,
|
||||
/// How to combine the results
|
||||
collective_op: CollectiveOperation,
|
||||
/// The value to compute over
|
||||
argument: Handle<Expression>,
|
||||
/// The [`SubgroupOperationResult`] expression representing this operation's result.
|
||||
///
|
||||
/// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult
|
||||
result: Handle<Expression>,
|
||||
},
|
||||
}
|
||||
|
||||
/// A function argument.
|
||||
|
@ -476,6 +476,8 @@ pub enum ConstantEvaluatorError {
|
||||
ImageExpression,
|
||||
#[error("Constants don't support ray query expressions")]
|
||||
RayQueryExpression,
|
||||
#[error("Constants don't support subgroup expressions")]
|
||||
SubgroupExpression,
|
||||
#[error("Cannot access the type")]
|
||||
InvalidAccessBase,
|
||||
#[error("Cannot access at the index")]
|
||||
@ -884,6 +886,12 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => {
|
||||
Err(ConstantEvaluatorError::RayQueryExpression)
|
||||
}
|
||||
Expression::SubgroupBallotResult { .. } => {
|
||||
Err(ConstantEvaluatorError::SubgroupExpression)
|
||||
}
|
||||
Expression::SubgroupOperationResult { .. } => {
|
||||
Err(ConstantEvaluatorError::SubgroupExpression)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -37,6 +37,9 @@ pub fn ensure_block_returns(block: &mut crate::Block) {
|
||||
| S::RayQuery { .. }
|
||||
| S::Atomic { .. }
|
||||
| S::WorkGroupUniformLoad { .. }
|
||||
| S::SubgroupBallot { .. }
|
||||
| S::SubgroupCollectiveOperation { .. }
|
||||
| S::SubgroupGather { .. }
|
||||
| S::Barrier(_)),
|
||||
)
|
||||
| None => block.push(S::Return { value: None }, Default::default()),
|
||||
|
@ -598,6 +598,7 @@ impl<'a> ResolveContext<'a> {
|
||||
| crate::BinaryOperator::ShiftRight => past(left)?.clone(),
|
||||
},
|
||||
crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty),
|
||||
crate::Expression::SubgroupOperationResult { ty } => TypeResolution::Handle(ty),
|
||||
crate::Expression::WorkGroupUniformLoadResult { ty } => TypeResolution::Handle(ty),
|
||||
crate::Expression::Select { accept, .. } => past(accept)?.clone(),
|
||||
crate::Expression::Derivative { expr, .. } => past(expr)?.clone(),
|
||||
@ -885,6 +886,10 @@ impl<'a> ResolveContext<'a> {
|
||||
.ok_or(ResolveError::MissingSpecialType)?;
|
||||
TypeResolution::Handle(result)
|
||||
}
|
||||
crate::Expression::SubgroupBallotResult => TypeResolution::Value(Ti::Vector {
|
||||
scalar: crate::Scalar::U32,
|
||||
size: crate::VectorSize::Quad,
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -787,6 +787,14 @@ impl FunctionInfo {
|
||||
non_uniform_result: self.add_ref(query),
|
||||
requirements: UniformityRequirements::empty(),
|
||||
},
|
||||
E::SubgroupBallotResult => Uniformity {
|
||||
non_uniform_result: Some(handle),
|
||||
requirements: UniformityRequirements::empty(),
|
||||
},
|
||||
E::SubgroupOperationResult { .. } => Uniformity {
|
||||
non_uniform_result: Some(handle),
|
||||
requirements: UniformityRequirements::empty(),
|
||||
},
|
||||
};
|
||||
|
||||
let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
|
||||
@ -1029,6 +1037,42 @@ impl FunctionInfo {
|
||||
}
|
||||
FunctionUniformity::new()
|
||||
}
|
||||
S::SubgroupBallot {
|
||||
result: _,
|
||||
predicate,
|
||||
} => {
|
||||
if let Some(predicate) = predicate {
|
||||
let _ = self.add_ref(predicate);
|
||||
}
|
||||
FunctionUniformity::new()
|
||||
}
|
||||
S::SubgroupCollectiveOperation {
|
||||
op: _,
|
||||
collective_op: _,
|
||||
argument,
|
||||
result: _,
|
||||
} => {
|
||||
let _ = self.add_ref(argument);
|
||||
FunctionUniformity::new()
|
||||
}
|
||||
S::SubgroupGather {
|
||||
mode,
|
||||
argument,
|
||||
result: _,
|
||||
} => {
|
||||
let _ = self.add_ref(argument);
|
||||
match mode {
|
||||
crate::GatherMode::BroadcastFirst => {}
|
||||
crate::GatherMode::Broadcast(index)
|
||||
| crate::GatherMode::Shuffle(index)
|
||||
| crate::GatherMode::ShuffleDown(index)
|
||||
| crate::GatherMode::ShuffleUp(index)
|
||||
| crate::GatherMode::ShuffleXor(index) => {
|
||||
let _ = self.add_ref(index);
|
||||
}
|
||||
}
|
||||
FunctionUniformity::new()
|
||||
}
|
||||
};
|
||||
|
||||
disruptor = disruptor.or(uniformity.exit_disruptor());
|
||||
|
@ -1641,6 +1641,7 @@ impl super::Validator {
|
||||
return Err(ExpressionError::InvalidRayQueryType(query));
|
||||
}
|
||||
},
|
||||
E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages,
|
||||
};
|
||||
Ok(stages)
|
||||
}
|
||||
|
@ -47,6 +47,19 @@ pub enum AtomicError {
|
||||
ResultTypeMismatch(Handle<crate::Expression>),
|
||||
}
|
||||
|
||||
/// Reasons a subgroup statement is rejected by function validation.
///
/// Converted into [`FunctionError::InvalidSubgroup`] when it is reported.
#[derive(Clone, Debug, thiserror::Error)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
pub enum SubgroupError {
|
||||
/// The expression used as an operand (value, gather index, or ballot
/// predicate) has a type the statement does not accept.
#[error("Operand {0:?} has invalid type.")]
|
||||
InvalidOperand(Handle<crate::Expression>),
|
||||
/// The statement's result expression is not a `SubgroupOperationResult`
/// whose type equals the type of the statement's argument.
#[error("Result type for {0:?} doesn't match the statement")]
|
||||
ResultTypeMismatch(Handle<crate::Expression>),
|
||||
/// The validator's configured set of subgroup operations does not contain
/// the operations this statement requires.
#[error("Support for subgroup operation {0:?} is required")]
|
||||
UnsupportedOperation(super::SubgroupOperationSet),
|
||||
/// The combination of collective operation and subgroup operation is not
/// one the validator accepts.
#[error("Unknown operation")]
|
||||
UnknownOperation,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, thiserror::Error)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
pub enum LocalVariableError {
|
||||
@ -135,6 +148,8 @@ pub enum FunctionError {
|
||||
InvalidRayDescriptor(Handle<crate::Expression>),
|
||||
#[error("Ray Query {0:?} does not have a matching type")]
|
||||
InvalidRayQueryType(Handle<crate::Type>),
|
||||
#[error("Shader requires capability {0:?}")]
|
||||
MissingCapability(super::Capabilities),
|
||||
#[error(
|
||||
"Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
|
||||
)]
|
||||
@ -155,6 +170,8 @@ pub enum FunctionError {
|
||||
WorkgroupUniformLoadExpressionMismatch(Handle<crate::Expression>),
|
||||
#[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")]
|
||||
WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>),
|
||||
#[error("Subgroup operation is invalid")]
|
||||
InvalidSubgroup(#[from] SubgroupError),
|
||||
}
|
||||
|
||||
bitflags::bitflags! {
|
||||
@ -399,6 +416,127 @@ impl super::Validator {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
fn validate_subgroup_operation(
|
||||
&mut self,
|
||||
op: &crate::SubgroupOperation,
|
||||
collective_op: &crate::CollectiveOperation,
|
||||
argument: Handle<crate::Expression>,
|
||||
result: Handle<crate::Expression>,
|
||||
context: &BlockContext,
|
||||
) -> Result<(), WithSpan<FunctionError>> {
|
||||
let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
|
||||
|
||||
let (is_scalar, scalar) = match *argument_inner {
|
||||
crate::TypeInner::Scalar(scalar) => (true, scalar),
|
||||
crate::TypeInner::Vector { scalar, .. } => (false, scalar),
|
||||
_ => {
|
||||
log::error!("Subgroup operand type {:?}", argument_inner);
|
||||
return Err(SubgroupError::InvalidOperand(argument)
|
||||
.with_span_handle(argument, context.expressions)
|
||||
.into_other());
|
||||
}
|
||||
};
|
||||
|
||||
use crate::ScalarKind as sk;
|
||||
use crate::SubgroupOperation as sg;
|
||||
match (scalar.kind, *op) {
|
||||
(sk::Bool, sg::All | sg::Any) if is_scalar => {}
|
||||
(sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {}
|
||||
(sk::Sint | sk::Uint, sg::And | sg::Or | sg::Xor) => {}
|
||||
|
||||
(_, _) => {
|
||||
log::error!("Subgroup operand type {:?}", argument_inner);
|
||||
return Err(SubgroupError::InvalidOperand(argument)
|
||||
.with_span_handle(argument, context.expressions)
|
||||
.into_other());
|
||||
}
|
||||
};
|
||||
|
||||
use crate::CollectiveOperation as co;
|
||||
match (*collective_op, *op) {
|
||||
(
|
||||
co::Reduce,
|
||||
sg::All
|
||||
| sg::Any
|
||||
| sg::Add
|
||||
| sg::Mul
|
||||
| sg::Min
|
||||
| sg::Max
|
||||
| sg::And
|
||||
| sg::Or
|
||||
| sg::Xor,
|
||||
) => {}
|
||||
(co::InclusiveScan | co::ExclusiveScan, sg::Add | sg::Mul) => {}
|
||||
|
||||
(_, _) => {
|
||||
return Err(SubgroupError::UnknownOperation.with_span().into_other());
|
||||
}
|
||||
};
|
||||
|
||||
self.emit_expression(result, context)?;
|
||||
match context.expressions[result] {
|
||||
crate::Expression::SubgroupOperationResult { ty }
|
||||
if { &context.types[ty].inner == argument_inner } => {}
|
||||
_ => {
|
||||
return Err(SubgroupError::ResultTypeMismatch(result)
|
||||
.with_span_handle(result, context.expressions)
|
||||
.into_other())
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
fn validate_subgroup_gather(
|
||||
&mut self,
|
||||
mode: &crate::GatherMode,
|
||||
argument: Handle<crate::Expression>,
|
||||
result: Handle<crate::Expression>,
|
||||
context: &BlockContext,
|
||||
) -> Result<(), WithSpan<FunctionError>> {
|
||||
match *mode {
|
||||
crate::GatherMode::BroadcastFirst => {}
|
||||
crate::GatherMode::Broadcast(index)
|
||||
| crate::GatherMode::Shuffle(index)
|
||||
| crate::GatherMode::ShuffleDown(index)
|
||||
| crate::GatherMode::ShuffleUp(index)
|
||||
| crate::GatherMode::ShuffleXor(index) => {
|
||||
let index_ty = context.resolve_type(index, &self.valid_expression_set)?;
|
||||
match *index_ty {
|
||||
crate::TypeInner::Scalar(crate::Scalar::U32) => {}
|
||||
_ => {
|
||||
log::error!(
|
||||
"Subgroup gather index type {:?}, expected unsigned int",
|
||||
index_ty
|
||||
);
|
||||
return Err(SubgroupError::InvalidOperand(argument)
|
||||
.with_span_handle(index, context.expressions)
|
||||
.into_other());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
|
||||
if !matches!(*argument_inner,
|
||||
crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. }
|
||||
if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float)
|
||||
) {
|
||||
log::error!("Subgroup gather operand type {:?}", argument_inner);
|
||||
return Err(SubgroupError::InvalidOperand(argument)
|
||||
.with_span_handle(argument, context.expressions)
|
||||
.into_other());
|
||||
}
|
||||
|
||||
self.emit_expression(result, context)?;
|
||||
match context.expressions[result] {
|
||||
crate::Expression::SubgroupOperationResult { ty }
|
||||
if { &context.types[ty].inner == argument_inner } => {}
|
||||
_ => {
|
||||
return Err(SubgroupError::ResultTypeMismatch(result)
|
||||
.with_span_handle(result, context.expressions)
|
||||
.into_other())
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_block_impl(
|
||||
&mut self,
|
||||
@ -613,8 +751,30 @@ impl super::Validator {
|
||||
stages &= super::ShaderStages::FRAGMENT;
|
||||
finished = true;
|
||||
}
|
||||
S::Barrier(_) => {
|
||||
S::Barrier(barrier) => {
|
||||
stages &= super::ShaderStages::COMPUTE;
|
||||
if barrier.contains(crate::Barrier::SUB_GROUP) {
|
||||
if !self.capabilities.contains(
|
||||
super::Capabilities::SUBGROUP | super::Capabilities::SUBGROUP_BARRIER,
|
||||
) {
|
||||
return Err(FunctionError::MissingCapability(
|
||||
super::Capabilities::SUBGROUP
|
||||
| super::Capabilities::SUBGROUP_BARRIER,
|
||||
)
|
||||
.with_span_static(span, "missing capability for this operation"));
|
||||
}
|
||||
if !self
|
||||
.subgroup_operations
|
||||
.contains(super::SubgroupOperationSet::BASIC)
|
||||
{
|
||||
return Err(FunctionError::InvalidSubgroup(
|
||||
SubgroupError::UnsupportedOperation(
|
||||
super::SubgroupOperationSet::BASIC,
|
||||
),
|
||||
)
|
||||
.with_span_static(span, "support for this operation is not present"));
|
||||
}
|
||||
}
|
||||
}
|
||||
S::Store { pointer, value } => {
|
||||
let mut current = pointer;
|
||||
@ -904,6 +1064,86 @@ impl super::Validator {
|
||||
crate::RayQueryFunction::Terminate => {}
|
||||
}
|
||||
}
|
||||
S::SubgroupBallot { result, predicate } => {
|
||||
stages &= self.subgroup_stages;
|
||||
if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
|
||||
return Err(FunctionError::MissingCapability(
|
||||
super::Capabilities::SUBGROUP,
|
||||
)
|
||||
.with_span_static(span, "missing capability for this operation"));
|
||||
}
|
||||
if !self
|
||||
.subgroup_operations
|
||||
.contains(super::SubgroupOperationSet::BALLOT)
|
||||
{
|
||||
return Err(FunctionError::InvalidSubgroup(
|
||||
SubgroupError::UnsupportedOperation(
|
||||
super::SubgroupOperationSet::BALLOT,
|
||||
),
|
||||
)
|
||||
.with_span_static(span, "support for this operation is not present"));
|
||||
}
|
||||
if let Some(predicate) = predicate {
|
||||
let predicate_inner =
|
||||
context.resolve_type(predicate, &self.valid_expression_set)?;
|
||||
if !matches!(
|
||||
*predicate_inner,
|
||||
crate::TypeInner::Scalar(crate::Scalar::BOOL,)
|
||||
) {
|
||||
log::error!(
|
||||
"Subgroup ballot predicate type {:?} expected bool",
|
||||
predicate_inner
|
||||
);
|
||||
return Err(SubgroupError::InvalidOperand(predicate)
|
||||
.with_span_handle(predicate, context.expressions)
|
||||
.into_other());
|
||||
}
|
||||
}
|
||||
self.emit_expression(result, context)?;
|
||||
}
|
||||
S::SubgroupCollectiveOperation {
|
||||
ref op,
|
||||
ref collective_op,
|
||||
argument,
|
||||
result,
|
||||
} => {
|
||||
stages &= self.subgroup_stages;
|
||||
if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
|
||||
return Err(FunctionError::MissingCapability(
|
||||
super::Capabilities::SUBGROUP,
|
||||
)
|
||||
.with_span_static(span, "missing capability for this operation"));
|
||||
}
|
||||
let operation = op.required_operations();
|
||||
if !self.subgroup_operations.contains(operation) {
|
||||
return Err(FunctionError::InvalidSubgroup(
|
||||
SubgroupError::UnsupportedOperation(operation),
|
||||
)
|
||||
.with_span_static(span, "support for this operation is not present"));
|
||||
}
|
||||
self.validate_subgroup_operation(op, collective_op, argument, result, context)?;
|
||||
}
|
||||
S::SubgroupGather {
|
||||
ref mode,
|
||||
argument,
|
||||
result,
|
||||
} => {
|
||||
stages &= self.subgroup_stages;
|
||||
if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
|
||||
return Err(FunctionError::MissingCapability(
|
||||
super::Capabilities::SUBGROUP,
|
||||
)
|
||||
.with_span_static(span, "missing capability for this operation"));
|
||||
}
|
||||
let operation = mode.required_operations();
|
||||
if !self.subgroup_operations.contains(operation) {
|
||||
return Err(FunctionError::InvalidSubgroup(
|
||||
SubgroupError::UnsupportedOperation(operation),
|
||||
)
|
||||
.with_span_static(span, "support for this operation is not present"));
|
||||
}
|
||||
self.validate_subgroup_gather(mode, argument, result, context)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(BlockInfo { stages, finished })
|
||||
|
@ -420,6 +420,8 @@ impl super::Validator {
|
||||
}
|
||||
crate::Expression::AtomicResult { .. }
|
||||
| crate::Expression::RayQueryProceedResult
|
||||
| crate::Expression::SubgroupBallotResult
|
||||
| crate::Expression::SubgroupOperationResult { .. }
|
||||
| crate::Expression::WorkGroupUniformLoadResult { .. } => (),
|
||||
crate::Expression::ArrayLength(array) => {
|
||||
handle.check_dep(array)?;
|
||||
@ -565,6 +567,38 @@ impl super::Validator {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
crate::Statement::SubgroupBallot { result, predicate } => {
|
||||
validate_expr_opt(predicate)?;
|
||||
validate_expr(result)?;
|
||||
Ok(())
|
||||
}
|
||||
crate::Statement::SubgroupCollectiveOperation {
|
||||
op: _,
|
||||
collective_op: _,
|
||||
argument,
|
||||
result,
|
||||
} => {
|
||||
validate_expr(argument)?;
|
||||
validate_expr(result)?;
|
||||
Ok(())
|
||||
}
|
||||
crate::Statement::SubgroupGather {
|
||||
mode,
|
||||
argument,
|
||||
result,
|
||||
} => {
|
||||
validate_expr(argument)?;
|
||||
match mode {
|
||||
crate::GatherMode::BroadcastFirst => {}
|
||||
crate::GatherMode::Broadcast(index)
|
||||
| crate::GatherMode::Shuffle(index)
|
||||
| crate::GatherMode::ShuffleDown(index)
|
||||
| crate::GatherMode::ShuffleUp(index)
|
||||
| crate::GatherMode::ShuffleXor(index) => validate_expr(index)?,
|
||||
}
|
||||
validate_expr(result)?;
|
||||
Ok(())
|
||||
}
|
||||
crate::Statement::Break
|
||||
| crate::Statement::Continue
|
||||
| crate::Statement::Kill
|
||||
|
@ -77,6 +77,8 @@ pub enum VaryingError {
|
||||
location: u32,
|
||||
attribute: &'static str,
|
||||
},
|
||||
#[error("Workgroup size is multi dimensional, @builtin(subgroup_id) and @builtin(subgroup_invocation_id) are not supported.")]
|
||||
InvalidMultiDimensionalSubgroupBuiltIn,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, thiserror::Error)]
|
||||
@ -140,6 +142,7 @@ struct VaryingContext<'a> {
|
||||
impl VaryingContext<'_> {
|
||||
fn validate_impl(
|
||||
&mut self,
|
||||
ep: &crate::EntryPoint,
|
||||
ty: Handle<crate::Type>,
|
||||
binding: &crate::Binding,
|
||||
) -> Result<(), VaryingError> {
|
||||
@ -167,12 +170,24 @@ impl VaryingContext<'_> {
|
||||
Bi::PrimitiveIndex => Capabilities::PRIMITIVE_INDEX,
|
||||
Bi::ViewIndex => Capabilities::MULTIVIEW,
|
||||
Bi::SampleIndex => Capabilities::MULTISAMPLED_SHADING,
|
||||
Bi::NumSubgroups
|
||||
| Bi::SubgroupId
|
||||
| Bi::SubgroupSize
|
||||
| Bi::SubgroupInvocationId => Capabilities::SUBGROUP,
|
||||
_ => Capabilities::empty(),
|
||||
};
|
||||
if !self.capabilities.contains(required) {
|
||||
return Err(VaryingError::UnsupportedCapability(required));
|
||||
}
|
||||
|
||||
if matches!(
|
||||
built_in,
|
||||
crate::BuiltIn::SubgroupId | crate::BuiltIn::SubgroupInvocationId
|
||||
) && ep.workgroup_size[1..].iter().any(|&s| s > 1)
|
||||
{
|
||||
return Err(VaryingError::InvalidMultiDimensionalSubgroupBuiltIn);
|
||||
}
|
||||
|
||||
let (visible, type_good) = match built_in {
|
||||
Bi::BaseInstance | Bi::BaseVertex | Bi::InstanceIndex | Bi::VertexIndex => (
|
||||
self.stage == St::Vertex && !self.output,
|
||||
@ -254,6 +269,17 @@ impl VaryingContext<'_> {
|
||||
scalar: crate::Scalar::U32,
|
||||
},
|
||||
),
|
||||
Bi::NumSubgroups | Bi::SubgroupId => (
|
||||
self.stage == St::Compute && !self.output,
|
||||
*ty_inner == Ti::Scalar(crate::Scalar::U32),
|
||||
),
|
||||
Bi::SubgroupSize | Bi::SubgroupInvocationId => (
|
||||
match self.stage {
|
||||
St::Compute | St::Fragment => !self.output,
|
||||
St::Vertex => false,
|
||||
},
|
||||
*ty_inner == Ti::Scalar(crate::Scalar::U32),
|
||||
),
|
||||
};
|
||||
|
||||
if !visible {
|
||||
@ -354,13 +380,14 @@ impl VaryingContext<'_> {
|
||||
|
||||
fn validate(
|
||||
&mut self,
|
||||
ep: &crate::EntryPoint,
|
||||
ty: Handle<crate::Type>,
|
||||
binding: Option<&crate::Binding>,
|
||||
) -> Result<(), WithSpan<VaryingError>> {
|
||||
let span_context = self.types.get_span_context(ty);
|
||||
match binding {
|
||||
Some(binding) => self
|
||||
.validate_impl(ty, binding)
|
||||
.validate_impl(ep, ty, binding)
|
||||
.map_err(|e| e.with_span_context(span_context)),
|
||||
None => {
|
||||
match self.types[ty].inner {
|
||||
@ -377,7 +404,7 @@ impl VaryingContext<'_> {
|
||||
}
|
||||
}
|
||||
Some(ref binding) => self
|
||||
.validate_impl(member.ty, binding)
|
||||
.validate_impl(ep, member.ty, binding)
|
||||
.map_err(|e| e.with_span_context(span_context))?,
|
||||
}
|
||||
}
|
||||
@ -609,7 +636,7 @@ impl super::Validator {
|
||||
capabilities: self.capabilities,
|
||||
flags: self.flags,
|
||||
};
|
||||
ctx.validate(fa.ty, fa.binding.as_ref())
|
||||
ctx.validate(ep, fa.ty, fa.binding.as_ref())
|
||||
.map_err_inner(|e| EntryPointError::Argument(index as u32, e).with_span())?;
|
||||
}
|
||||
|
||||
@ -627,7 +654,7 @@ impl super::Validator {
|
||||
capabilities: self.capabilities,
|
||||
flags: self.flags,
|
||||
};
|
||||
ctx.validate(fr.ty, fr.binding.as_ref())
|
||||
ctx.validate(ep, fr.ty, fr.binding.as_ref())
|
||||
.map_err_inner(|e| EntryPointError::Result(e).with_span())?;
|
||||
if ctx.second_blend_source {
|
||||
// Only the first location may be used when dual source blending
|
||||
|
@ -77,7 +77,7 @@ bitflags::bitflags! {
|
||||
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
pub struct Capabilities: u16 {
|
||||
pub struct Capabilities: u32 {
|
||||
/// Support for [`AddressSpace::PushConstant`].
|
||||
const PUSH_CONSTANT = 0x1;
|
||||
/// Float values with width = 8.
|
||||
@ -110,6 +110,10 @@ bitflags::bitflags! {
|
||||
const CUBE_ARRAY_TEXTURES = 0x4000;
|
||||
/// Support for 64-bit signed and unsigned integers.
|
||||
const SHADER_INT64 = 0x8000;
|
||||
/// Support for subgroup operations.
|
||||
const SUBGROUP = 0x10000;
|
||||
/// Support for subgroup barriers.
|
||||
const SUBGROUP_BARRIER = 0x20000;
|
||||
}
|
||||
}
|
||||
|
||||
@ -119,6 +123,57 @@ impl Default for Capabilities {
|
||||
}
|
||||
}
|
||||
|
||||
// Bit set naming the families of subgroup operations a `Validator` can be
// configured to accept (stored in `Validator::subgroup_operations`). Each
// flag's doc comment lists the operations that belong to that family.
bitflags::bitflags! {
|
||||
/// Supported subgroup operations
|
||||
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
|
||||
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
|
||||
pub struct SubgroupOperationSet: u8 {
|
||||
/// Elect, Barrier
|
||||
const BASIC = 1 << 0;
|
||||
/// Any, All
|
||||
const VOTE = 1 << 1;
|
||||
/// reductions, scans
|
||||
const ARITHMETIC = 1 << 2;
|
||||
/// ballot, broadcast
|
||||
const BALLOT = 1 << 3;
|
||||
/// shuffle, shuffle xor
|
||||
const SHUFFLE = 1 << 4;
|
||||
/// shuffle up, down
|
||||
const SHUFFLE_RELATIVE = 1 << 5;
|
||||
// We don't support these operations yet
|
||||
// /// Clustered
|
||||
// const CLUSTERED = 1 << 6;
|
||||
// /// Quad supported
|
||||
// const QUAD_FRAGMENT_COMPUTE = 1 << 7;
|
||||
// /// Quad supported in all stages
|
||||
// NOTE: `1 << 8` does not fit in the `u8` backing type declared above;
|
||||
// the struct must be widened (e.g. to `u16`) before this flag is enabled.
|
||||
// const QUAD_ALL_STAGES = 1 << 8;
|
||||
}
|
||||
}
|
||||
|
||||
impl super::SubgroupOperation {
|
||||
const fn required_operations(&self) -> SubgroupOperationSet {
|
||||
use SubgroupOperationSet as S;
|
||||
match *self {
|
||||
Self::All | Self::Any => S::VOTE,
|
||||
Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
|
||||
S::ARITHMETIC
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl super::GatherMode {
|
||||
const fn required_operations(&self) -> SubgroupOperationSet {
|
||||
use SubgroupOperationSet as S;
|
||||
match *self {
|
||||
Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
|
||||
Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
|
||||
Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bitflags::bitflags! {
|
||||
/// Validation flags.
|
||||
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
|
||||
@ -166,6 +221,8 @@ impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
|
||||
pub struct Validator {
|
||||
flags: ValidationFlags,
|
||||
capabilities: Capabilities,
|
||||
subgroup_stages: ShaderStages,
|
||||
subgroup_operations: SubgroupOperationSet,
|
||||
types: Vec<r#type::TypeInfo>,
|
||||
layouter: Layouter,
|
||||
location_mask: BitSet,
|
||||
@ -317,6 +374,8 @@ impl Validator {
|
||||
Validator {
|
||||
flags,
|
||||
capabilities,
|
||||
subgroup_stages: ShaderStages::empty(),
|
||||
subgroup_operations: SubgroupOperationSet::empty(),
|
||||
types: Vec::new(),
|
||||
layouter: Layouter::default(),
|
||||
location_mask: BitSet::new(),
|
||||
@ -329,6 +388,16 @@ impl Validator {
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the shader stages associated with subgroup statements.
///
/// Block validation masks a block's stage set with this value
/// (`stages &= self.subgroup_stages`) when it encounters a subgroup ballot,
/// collective operation, or gather statement. The validator is constructed
/// with `ShaderStages::empty()` here.
///
/// Returns `self` so that configuration calls can be chained.
pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
|
||||
self.subgroup_stages = stages;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the families of subgroup operations this validator accepts.
///
/// During block validation, a subgroup statement whose required flag is not
/// contained in this set is rejected with
/// `SubgroupError::UnsupportedOperation`. The validator is constructed with
/// `SubgroupOperationSet::empty()` here.
///
/// Returns `self` so that configuration calls can be chained.
pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
|
||||
self.subgroup_operations = operations;
|
||||
self
|
||||
}
|
||||
|
||||
/// Reset the validator internals
|
||||
pub fn reset(&mut self) {
|
||||
self.types.clear();
|
||||
|
27
naga/tests/in/spv/subgroup-operations-s.param.ron
Normal file
27
naga/tests/in/spv/subgroup-operations-s.param.ron
Normal file
@ -0,0 +1,27 @@
|
||||
(
|
||||
god_mode: true,
|
||||
spv: (
|
||||
version: (1, 3),
|
||||
),
|
||||
msl: (
|
||||
lang_version: (2, 4),
|
||||
per_entry_point_map: {},
|
||||
inline_samplers: [],
|
||||
spirv_cross_compatibility: false,
|
||||
fake_missing_bindings: false,
|
||||
zero_initialize_workgroup_memory: true,
|
||||
),
|
||||
glsl: (
|
||||
version: Desktop(430),
|
||||
writer_flags: (""),
|
||||
binding_map: { },
|
||||
zero_initialize_workgroup_memory: true,
|
||||
),
|
||||
hlsl: (
|
||||
shader_model: V6_0,
|
||||
binding_map: {},
|
||||
fake_missing_bindings: true,
|
||||
special_constants_binding: None,
|
||||
zero_initialize_workgroup_memory: true,
|
||||
),
|
||||
)
|
BIN
naga/tests/in/spv/subgroup-operations-s.spv
Normal file
BIN
naga/tests/in/spv/subgroup-operations-s.spv
Normal file
Binary file not shown.
75
naga/tests/in/spv/subgroup-operations-s.spvasm
Normal file
75
naga/tests/in/spv/subgroup-operations-s.spvasm
Normal file
@ -0,0 +1,75 @@
|
||||
; SPIR-V
|
||||
; Version: 1.3
|
||||
; Generator: rspirv
|
||||
; Bound: 54
|
||||
OpCapability Shader
|
||||
OpCapability GroupNonUniform
|
||||
OpCapability GroupNonUniformBallot
|
||||
OpCapability GroupNonUniformVote
|
||||
OpCapability GroupNonUniformArithmetic
|
||||
OpCapability GroupNonUniformShuffle
|
||||
OpCapability GroupNonUniformShuffleRelative
|
||||
%1 = OpExtInstImport "GLSL.std.450"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %15 "main" %6 %9 %11 %13
|
||||
OpExecutionMode %15 LocalSize 1 1 1
|
||||
OpDecorate %6 BuiltIn NumSubgroups
|
||||
OpDecorate %9 BuiltIn SubgroupId
|
||||
OpDecorate %11 BuiltIn SubgroupSize
|
||||
OpDecorate %13 BuiltIn SubgroupLocalInvocationId
|
||||
%2 = OpTypeVoid
|
||||
%3 = OpTypeInt 32 0
|
||||
%4 = OpTypeBool
|
||||
%7 = OpTypePointer Input %3
|
||||
%6 = OpVariable %7 Input
|
||||
%9 = OpVariable %7 Input
|
||||
%11 = OpVariable %7 Input
|
||||
%13 = OpVariable %7 Input
|
||||
%16 = OpTypeFunction %2
|
||||
%17 = OpConstant %3 1
|
||||
%18 = OpConstant %3 0
|
||||
%19 = OpConstant %3 4
|
||||
%21 = OpConstant %3 3
|
||||
%22 = OpConstant %3 2
|
||||
%23 = OpConstant %3 8
|
||||
%26 = OpTypeVector %3 4
|
||||
%28 = OpConstantTrue %4
|
||||
%15 = OpFunction %2 None %16
|
||||
%5 = OpLabel
|
||||
%8 = OpLoad %3 %6
|
||||
%10 = OpLoad %3 %9
|
||||
%12 = OpLoad %3 %11
|
||||
%14 = OpLoad %3 %13
|
||||
OpBranch %20
|
||||
%20 = OpLabel
|
||||
OpControlBarrier %21 %22 %23
|
||||
%24 = OpBitwiseAnd %3 %14 %17
|
||||
%25 = OpIEqual %4 %24 %17
|
||||
%27 = OpGroupNonUniformBallot %26 %21 %25
|
||||
%29 = OpGroupNonUniformBallot %26 %21 %28
|
||||
%30 = OpINotEqual %4 %14 %18
|
||||
%31 = OpGroupNonUniformAll %4 %21 %30
|
||||
%32 = OpIEqual %4 %14 %18
|
||||
%33 = OpGroupNonUniformAny %4 %21 %32
|
||||
%34 = OpGroupNonUniformIAdd %3 %21 Reduce %14
|
||||
%35 = OpGroupNonUniformIMul %3 %21 Reduce %14
|
||||
%36 = OpGroupNonUniformUMin %3 %21 Reduce %14
|
||||
%37 = OpGroupNonUniformUMax %3 %21 Reduce %14
|
||||
%38 = OpGroupNonUniformBitwiseAnd %3 %21 Reduce %14
|
||||
%39 = OpGroupNonUniformBitwiseOr %3 %21 Reduce %14
|
||||
%40 = OpGroupNonUniformBitwiseXor %3 %21 Reduce %14
|
||||
%41 = OpGroupNonUniformIAdd %3 %21 ExclusiveScan %14
|
||||
%42 = OpGroupNonUniformIMul %3 %21 ExclusiveScan %14
|
||||
%43 = OpGroupNonUniformIAdd %3 %21 InclusiveScan %14
|
||||
%44 = OpGroupNonUniformIMul %3 %21 InclusiveScan %14
|
||||
%45 = OpGroupNonUniformBroadcastFirst %3 %21 %14
|
||||
%46 = OpGroupNonUniformBroadcast %3 %21 %14 %19
|
||||
%47 = OpISub %3 %12 %17
|
||||
%48 = OpISub %3 %47 %14
|
||||
%49 = OpGroupNonUniformShuffle %3 %21 %14 %48
|
||||
%50 = OpGroupNonUniformShuffleDown %3 %21 %14 %17
|
||||
%51 = OpGroupNonUniformShuffleUp %3 %21 %14 %17
|
||||
%52 = OpISub %3 %12 %17
|
||||
%53 = OpGroupNonUniformShuffleXor %3 %21 %14 %52
|
||||
OpReturn
|
||||
OpFunctionEnd
|
27
naga/tests/in/subgroup-operations.param.ron
Normal file
27
naga/tests/in/subgroup-operations.param.ron
Normal file
@ -0,0 +1,27 @@
|
||||
(
|
||||
god_mode: true,
|
||||
spv: (
|
||||
version: (1, 3),
|
||||
),
|
||||
msl: (
|
||||
lang_version: (2, 4),
|
||||
per_entry_point_map: {},
|
||||
inline_samplers: [],
|
||||
spirv_cross_compatibility: false,
|
||||
fake_missing_bindings: false,
|
||||
zero_initialize_workgroup_memory: true,
|
||||
),
|
||||
glsl: (
|
||||
version: Desktop(430),
|
||||
writer_flags: (""),
|
||||
binding_map: { },
|
||||
zero_initialize_workgroup_memory: true,
|
||||
),
|
||||
hlsl: (
|
||||
shader_model: V6_0,
|
||||
binding_map: {},
|
||||
fake_missing_bindings: true,
|
||||
special_constants_binding: None,
|
||||
zero_initialize_workgroup_memory: true,
|
||||
),
|
||||
)
|
37
naga/tests/in/subgroup-operations.wgsl
Normal file
37
naga/tests/in/subgroup-operations.wgsl
Normal file
@ -0,0 +1,37 @@
|
||||
struct Structure {
|
||||
@builtin(num_subgroups) num_subgroups: u32,
|
||||
@builtin(subgroup_size) subgroup_size: u32,
|
||||
};
|
||||
|
||||
@compute @workgroup_size(1)
|
||||
fn main(
|
||||
sizes: Structure,
|
||||
@builtin(subgroup_id) subgroup_id: u32,
|
||||
@builtin(subgroup_invocation_id) subgroup_invocation_id: u32,
|
||||
) {
|
||||
subgroupBarrier();
|
||||
|
||||
subgroupBallot((subgroup_invocation_id & 1u) == 1u);
|
||||
subgroupBallot();
|
||||
|
||||
subgroupAll(subgroup_invocation_id != 0u);
|
||||
subgroupAny(subgroup_invocation_id == 0u);
|
||||
subgroupAdd(subgroup_invocation_id);
|
||||
subgroupMul(subgroup_invocation_id);
|
||||
subgroupMin(subgroup_invocation_id);
|
||||
subgroupMax(subgroup_invocation_id);
|
||||
subgroupAnd(subgroup_invocation_id);
|
||||
subgroupOr(subgroup_invocation_id);
|
||||
subgroupXor(subgroup_invocation_id);
|
||||
subgroupExclusiveAdd(subgroup_invocation_id);
|
||||
subgroupExclusiveMul(subgroup_invocation_id);
|
||||
subgroupInclusiveAdd(subgroup_invocation_id);
|
||||
subgroupInclusiveMul(subgroup_invocation_id);
|
||||
|
||||
subgroupBroadcastFirst(subgroup_invocation_id);
|
||||
subgroupBroadcast(subgroup_invocation_id, 4u);
|
||||
subgroupShuffle(subgroup_invocation_id, sizes.subgroup_size - 1u - subgroup_invocation_id);
|
||||
subgroupShuffleDown(subgroup_invocation_id, 1u);
|
||||
subgroupShuffleUp(subgroup_invocation_id, 1u);
|
||||
subgroupShuffleXor(subgroup_invocation_id, sizes.subgroup_size - 1u);
|
||||
}
|
58
naga/tests/out/glsl/subgroup-operations-s.main.Compute.glsl
Normal file
58
naga/tests/out/glsl/subgroup-operations-s.main.Compute.glsl
Normal file
@ -0,0 +1,58 @@
|
||||
#version 430 core
|
||||
#extension GL_ARB_compute_shader : require
|
||||
#extension GL_KHR_shader_subgroup_basic : require
|
||||
#extension GL_KHR_shader_subgroup_vote : require
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : require
|
||||
#extension GL_KHR_shader_subgroup_ballot : require
|
||||
#extension GL_KHR_shader_subgroup_shuffle : require
|
||||
#extension GL_KHR_shader_subgroup_shuffle_relative : require
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
uint num_subgroups_1 = 0u;
|
||||
|
||||
uint subgroup_id_1 = 0u;
|
||||
|
||||
uint subgroup_size_1 = 0u;
|
||||
|
||||
uint subgroup_invocation_id_1 = 0u;
|
||||
|
||||
|
||||
void main_1() {
|
||||
uint _e5 = subgroup_size_1;
|
||||
uint _e6 = subgroup_invocation_id_1;
|
||||
uvec4 _e9 = subgroupBallot(((_e6 & 1u) == 1u));
|
||||
uvec4 _e10 = subgroupBallot(true);
|
||||
bool _e12 = subgroupAll((_e6 != 0u));
|
||||
bool _e14 = subgroupAny((_e6 == 0u));
|
||||
uint _e15 = subgroupAdd(_e6);
|
||||
uint _e16 = subgroupMul(_e6);
|
||||
uint _e17 = subgroupMin(_e6);
|
||||
uint _e18 = subgroupMax(_e6);
|
||||
uint _e19 = subgroupAnd(_e6);
|
||||
uint _e20 = subgroupOr(_e6);
|
||||
uint _e21 = subgroupXor(_e6);
|
||||
uint _e22 = subgroupExclusiveAdd(_e6);
|
||||
uint _e23 = subgroupExclusiveMul(_e6);
|
||||
uint _e24 = subgroupInclusiveAdd(_e6);
|
||||
uint _e25 = subgroupInclusiveMul(_e6);
|
||||
uint _e26 = subgroupBroadcastFirst(_e6);
|
||||
uint _e27 = subgroupBroadcast(_e6, 4u);
|
||||
uint _e30 = subgroupShuffle(_e6, ((_e5 - 1u) - _e6));
|
||||
uint _e31 = subgroupShuffleDown(_e6, 1u);
|
||||
uint _e32 = subgroupShuffleUp(_e6, 1u);
|
||||
uint _e34 = subgroupShuffleXor(_e6, (_e5 - 1u));
|
||||
return;
|
||||
}
|
||||
|
||||
void main() {
|
||||
uint num_subgroups = gl_NumSubgroups;
|
||||
uint subgroup_id = gl_SubgroupID;
|
||||
uint subgroup_size = gl_SubgroupSize;
|
||||
uint subgroup_invocation_id = gl_SubgroupInvocationID;
|
||||
num_subgroups_1 = num_subgroups;
|
||||
subgroup_id_1 = subgroup_id;
|
||||
subgroup_size_1 = subgroup_size;
|
||||
subgroup_invocation_id_1 = subgroup_invocation_id;
|
||||
main_1();
|
||||
}
|
||||
|
45
naga/tests/out/glsl/subgroup-operations.main.Compute.glsl
Normal file
45
naga/tests/out/glsl/subgroup-operations.main.Compute.glsl
Normal file
@ -0,0 +1,45 @@
|
||||
#version 430 core
|
||||
#extension GL_ARB_compute_shader : require
|
||||
#extension GL_KHR_shader_subgroup_basic : require
|
||||
#extension GL_KHR_shader_subgroup_vote : require
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : require
|
||||
#extension GL_KHR_shader_subgroup_ballot : require
|
||||
#extension GL_KHR_shader_subgroup_shuffle : require
|
||||
#extension GL_KHR_shader_subgroup_shuffle_relative : require
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
struct Structure {
|
||||
uint num_subgroups;
|
||||
uint subgroup_size;
|
||||
};
|
||||
|
||||
void main() {
|
||||
Structure sizes = Structure(gl_NumSubgroups, gl_SubgroupSize);
|
||||
uint subgroup_id = gl_SubgroupID;
|
||||
uint subgroup_invocation_id = gl_SubgroupInvocationID;
|
||||
subgroupMemoryBarrier();
|
||||
barrier();
|
||||
uvec4 _e7 = subgroupBallot(((subgroup_invocation_id & 1u) == 1u));
|
||||
uvec4 _e8 = subgroupBallot(true);
|
||||
bool _e11 = subgroupAll((subgroup_invocation_id != 0u));
|
||||
bool _e14 = subgroupAny((subgroup_invocation_id == 0u));
|
||||
uint _e15 = subgroupAdd(subgroup_invocation_id);
|
||||
uint _e16 = subgroupMul(subgroup_invocation_id);
|
||||
uint _e17 = subgroupMin(subgroup_invocation_id);
|
||||
uint _e18 = subgroupMax(subgroup_invocation_id);
|
||||
uint _e19 = subgroupAnd(subgroup_invocation_id);
|
||||
uint _e20 = subgroupOr(subgroup_invocation_id);
|
||||
uint _e21 = subgroupXor(subgroup_invocation_id);
|
||||
uint _e22 = subgroupExclusiveAdd(subgroup_invocation_id);
|
||||
uint _e23 = subgroupExclusiveMul(subgroup_invocation_id);
|
||||
uint _e24 = subgroupInclusiveAdd(subgroup_invocation_id);
|
||||
uint _e25 = subgroupInclusiveMul(subgroup_invocation_id);
|
||||
uint _e26 = subgroupBroadcastFirst(subgroup_invocation_id);
|
||||
uint _e28 = subgroupBroadcast(subgroup_invocation_id, 4u);
|
||||
uint _e33 = subgroupShuffle(subgroup_invocation_id, ((sizes.subgroup_size - 1u) - subgroup_invocation_id));
|
||||
uint _e35 = subgroupShuffleDown(subgroup_invocation_id, 1u);
|
||||
uint _e37 = subgroupShuffleUp(subgroup_invocation_id, 1u);
|
||||
uint _e41 = subgroupShuffleXor(subgroup_invocation_id, (sizes.subgroup_size - 1u));
|
||||
return;
|
||||
}
|
||||
|
50
naga/tests/out/hlsl/subgroup-operations-s.hlsl
Normal file
50
naga/tests/out/hlsl/subgroup-operations-s.hlsl
Normal file
@ -0,0 +1,50 @@
|
||||
static uint num_subgroups_1 = (uint)0;
|
||||
static uint subgroup_id_1 = (uint)0;
|
||||
static uint subgroup_size_1 = (uint)0;
|
||||
static uint subgroup_invocation_id_1 = (uint)0;
|
||||
|
||||
struct ComputeInput_main {
|
||||
uint __local_invocation_index : SV_GroupIndex;
|
||||
};
|
||||
|
||||
void main_1()
|
||||
{
|
||||
uint _expr5 = subgroup_size_1;
|
||||
uint _expr6 = subgroup_invocation_id_1;
|
||||
const uint4 _e9 = WaveActiveBallot(((_expr6 & 1u) == 1u));
|
||||
const uint4 _e10 = WaveActiveBallot(true);
|
||||
const bool _e12 = WaveActiveAllTrue((_expr6 != 0u));
|
||||
const bool _e14 = WaveActiveAnyTrue((_expr6 == 0u));
|
||||
const uint _e15 = WaveActiveSum(_expr6);
|
||||
const uint _e16 = WaveActiveProduct(_expr6);
|
||||
const uint _e17 = WaveActiveMin(_expr6);
|
||||
const uint _e18 = WaveActiveMax(_expr6);
|
||||
const uint _e19 = WaveActiveBitAnd(_expr6);
|
||||
const uint _e20 = WaveActiveBitOr(_expr6);
|
||||
const uint _e21 = WaveActiveBitXor(_expr6);
|
||||
const uint _e22 = WavePrefixSum(_expr6);
|
||||
const uint _e23 = WavePrefixProduct(_expr6);
|
||||
const uint _e24 = _expr6 + WavePrefixSum(_expr6);
|
||||
const uint _e25 = _expr6 * WavePrefixProduct(_expr6);
|
||||
const uint _e26 = WaveReadLaneFirst(_expr6);
|
||||
const uint _e27 = WaveReadLaneAt(_expr6, 4u);
|
||||
const uint _e30 = WaveReadLaneAt(_expr6, ((_expr5 - 1u) - _expr6));
|
||||
const uint _e31 = WaveReadLaneAt(_expr6, WaveGetLaneIndex() + 1u);
|
||||
const uint _e32 = WaveReadLaneAt(_expr6, WaveGetLaneIndex() - 1u);
|
||||
const uint _e34 = WaveReadLaneAt(_expr6, WaveGetLaneIndex() ^ (_expr5 - 1u));
|
||||
return;
|
||||
}
|
||||
|
||||
[numthreads(1, 1, 1)]
|
||||
void main(ComputeInput_main computeinput_main)
|
||||
{
|
||||
uint num_subgroups = (1u + WaveGetLaneCount() - 1u) / WaveGetLaneCount();
|
||||
uint subgroup_id = computeinput_main.__local_invocation_index / WaveGetLaneCount();
|
||||
uint subgroup_size = WaveGetLaneCount();
|
||||
uint subgroup_invocation_id = WaveGetLaneIndex();
|
||||
num_subgroups_1 = num_subgroups;
|
||||
subgroup_id_1 = subgroup_id;
|
||||
subgroup_size_1 = subgroup_size;
|
||||
subgroup_invocation_id_1 = subgroup_invocation_id;
|
||||
main_1();
|
||||
}
|
12
naga/tests/out/hlsl/subgroup-operations-s.ron
Normal file
12
naga/tests/out/hlsl/subgroup-operations-s.ron
Normal file
@ -0,0 +1,12 @@
|
||||
(
|
||||
vertex:[
|
||||
],
|
||||
fragment:[
|
||||
],
|
||||
compute:[
|
||||
(
|
||||
entry_point:"main",
|
||||
target_profile:"cs_6_0",
|
||||
),
|
||||
],
|
||||
)
|
38
naga/tests/out/hlsl/subgroup-operations.hlsl
Normal file
38
naga/tests/out/hlsl/subgroup-operations.hlsl
Normal file
@ -0,0 +1,38 @@
|
||||
struct Structure {
|
||||
uint num_subgroups;
|
||||
uint subgroup_size;
|
||||
};
|
||||
|
||||
struct ComputeInput_main {
|
||||
uint __local_invocation_index : SV_GroupIndex;
|
||||
};
|
||||
|
||||
[numthreads(1, 1, 1)]
|
||||
void main(ComputeInput_main computeinput_main)
|
||||
{
|
||||
Structure sizes = { (1u + WaveGetLaneCount() - 1u) / WaveGetLaneCount(), WaveGetLaneCount() };
|
||||
uint subgroup_id = computeinput_main.__local_invocation_index / WaveGetLaneCount();
|
||||
uint subgroup_invocation_id = WaveGetLaneIndex();
|
||||
const uint4 _e7 = WaveActiveBallot(((subgroup_invocation_id & 1u) == 1u));
|
||||
const uint4 _e8 = WaveActiveBallot(true);
|
||||
const bool _e11 = WaveActiveAllTrue((subgroup_invocation_id != 0u));
|
||||
const bool _e14 = WaveActiveAnyTrue((subgroup_invocation_id == 0u));
|
||||
const uint _e15 = WaveActiveSum(subgroup_invocation_id);
|
||||
const uint _e16 = WaveActiveProduct(subgroup_invocation_id);
|
||||
const uint _e17 = WaveActiveMin(subgroup_invocation_id);
|
||||
const uint _e18 = WaveActiveMax(subgroup_invocation_id);
|
||||
const uint _e19 = WaveActiveBitAnd(subgroup_invocation_id);
|
||||
const uint _e20 = WaveActiveBitOr(subgroup_invocation_id);
|
||||
const uint _e21 = WaveActiveBitXor(subgroup_invocation_id);
|
||||
const uint _e22 = WavePrefixSum(subgroup_invocation_id);
|
||||
const uint _e23 = WavePrefixProduct(subgroup_invocation_id);
|
||||
const uint _e24 = subgroup_invocation_id + WavePrefixSum(subgroup_invocation_id);
|
||||
const uint _e25 = subgroup_invocation_id * WavePrefixProduct(subgroup_invocation_id);
|
||||
const uint _e26 = WaveReadLaneFirst(subgroup_invocation_id);
|
||||
const uint _e28 = WaveReadLaneAt(subgroup_invocation_id, 4u);
|
||||
const uint _e33 = WaveReadLaneAt(subgroup_invocation_id, ((sizes.subgroup_size - 1u) - subgroup_invocation_id));
|
||||
const uint _e35 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() + 1u);
|
||||
const uint _e37 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() - 1u);
|
||||
const uint _e41 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() ^ (sizes.subgroup_size - 1u));
|
||||
return;
|
||||
}
|
12
naga/tests/out/hlsl/subgroup-operations.ron
Normal file
12
naga/tests/out/hlsl/subgroup-operations.ron
Normal file
@ -0,0 +1,12 @@
|
||||
(
|
||||
vertex:[
|
||||
],
|
||||
fragment:[
|
||||
],
|
||||
compute:[
|
||||
(
|
||||
entry_point:"main",
|
||||
target_profile:"cs_6_0",
|
||||
),
|
||||
],
|
||||
)
|
55
naga/tests/out/msl/subgroup-operations-s.msl
Normal file
55
naga/tests/out/msl/subgroup-operations-s.msl
Normal file
@ -0,0 +1,55 @@
|
||||
// language: metal2.4
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
using metal::uint;
|
||||
|
||||
|
||||
void main_1(
|
||||
thread uint& subgroup_size_1,
|
||||
thread uint& subgroup_invocation_id_1
|
||||
) {
|
||||
uint _e5 = subgroup_size_1;
|
||||
uint _e6 = subgroup_invocation_id_1;
|
||||
metal::uint4 unnamed = uint4((uint64_t)metal::simd_ballot((_e6 & 1u) == 1u), 0, 0, 0);
|
||||
metal::uint4 unnamed_1 = uint4((uint64_t)metal::simd_ballot(true), 0, 0, 0);
|
||||
bool unnamed_2 = metal::simd_all(_e6 != 0u);
|
||||
bool unnamed_3 = metal::simd_any(_e6 == 0u);
|
||||
uint unnamed_4 = metal::simd_sum(_e6);
|
||||
uint unnamed_5 = metal::simd_product(_e6);
|
||||
uint unnamed_6 = metal::simd_min(_e6);
|
||||
uint unnamed_7 = metal::simd_max(_e6);
|
||||
uint unnamed_8 = metal::simd_and(_e6);
|
||||
uint unnamed_9 = metal::simd_or(_e6);
|
||||
uint unnamed_10 = metal::simd_xor(_e6);
|
||||
uint unnamed_11 = metal::simd_prefix_exclusive_sum(_e6);
|
||||
uint unnamed_12 = metal::simd_prefix_exclusive_product(_e6);
|
||||
uint unnamed_13 = metal::simd_prefix_inclusive_sum(_e6);
|
||||
uint unnamed_14 = metal::simd_prefix_inclusive_product(_e6);
|
||||
uint unnamed_15 = metal::simd_broadcast_first(_e6);
|
||||
uint unnamed_16 = metal::simd_broadcast(_e6, 4u);
|
||||
uint unnamed_17 = metal::simd_shuffle(_e6, (_e5 - 1u) - _e6);
|
||||
uint unnamed_18 = metal::simd_shuffle_down(_e6, 1u);
|
||||
uint unnamed_19 = metal::simd_shuffle_up(_e6, 1u);
|
||||
uint unnamed_20 = metal::simd_shuffle_xor(_e6, _e5 - 1u);
|
||||
return;
|
||||
}
|
||||
|
||||
struct main_Input {
|
||||
};
|
||||
kernel void main_(
|
||||
uint num_subgroups [[simdgroups_per_threadgroup]]
|
||||
, uint subgroup_id [[simdgroup_index_in_threadgroup]]
|
||||
, uint subgroup_size [[threads_per_simdgroup]]
|
||||
, uint subgroup_invocation_id [[thread_index_in_simdgroup]]
|
||||
) {
|
||||
uint num_subgroups_1 = {};
|
||||
uint subgroup_id_1 = {};
|
||||
uint subgroup_size_1 = {};
|
||||
uint subgroup_invocation_id_1 = {};
|
||||
num_subgroups_1 = num_subgroups;
|
||||
subgroup_id_1 = subgroup_id;
|
||||
subgroup_size_1 = subgroup_size;
|
||||
subgroup_invocation_id_1 = subgroup_invocation_id;
|
||||
main_1(subgroup_size_1, subgroup_invocation_id_1);
|
||||
}
|
44
naga/tests/out/msl/subgroup-operations.msl
Normal file
44
naga/tests/out/msl/subgroup-operations.msl
Normal file
@ -0,0 +1,44 @@
|
||||
// language: metal2.4
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
using metal::uint;
|
||||
|
||||
struct Structure {
|
||||
uint num_subgroups;
|
||||
uint subgroup_size;
|
||||
};
|
||||
|
||||
struct main_Input {
|
||||
};
|
||||
kernel void main_(
|
||||
uint num_subgroups [[simdgroups_per_threadgroup]]
|
||||
, uint subgroup_size [[threads_per_simdgroup]]
|
||||
, uint subgroup_id [[simdgroup_index_in_threadgroup]]
|
||||
, uint subgroup_invocation_id [[thread_index_in_simdgroup]]
|
||||
) {
|
||||
const Structure sizes = { num_subgroups, subgroup_size };
|
||||
metal::simdgroup_barrier(metal::mem_flags::mem_threadgroup);
|
||||
metal::uint4 unnamed = uint4((uint64_t)metal::simd_ballot((subgroup_invocation_id & 1u) == 1u), 0, 0, 0);
|
||||
metal::uint4 unnamed_1 = uint4((uint64_t)metal::simd_ballot(true), 0, 0, 0);
|
||||
bool unnamed_2 = metal::simd_all(subgroup_invocation_id != 0u);
|
||||
bool unnamed_3 = metal::simd_any(subgroup_invocation_id == 0u);
|
||||
uint unnamed_4 = metal::simd_sum(subgroup_invocation_id);
|
||||
uint unnamed_5 = metal::simd_product(subgroup_invocation_id);
|
||||
uint unnamed_6 = metal::simd_min(subgroup_invocation_id);
|
||||
uint unnamed_7 = metal::simd_max(subgroup_invocation_id);
|
||||
uint unnamed_8 = metal::simd_and(subgroup_invocation_id);
|
||||
uint unnamed_9 = metal::simd_or(subgroup_invocation_id);
|
||||
uint unnamed_10 = metal::simd_xor(subgroup_invocation_id);
|
||||
uint unnamed_11 = metal::simd_prefix_exclusive_sum(subgroup_invocation_id);
|
||||
uint unnamed_12 = metal::simd_prefix_exclusive_product(subgroup_invocation_id);
|
||||
uint unnamed_13 = metal::simd_prefix_inclusive_sum(subgroup_invocation_id);
|
||||
uint unnamed_14 = metal::simd_prefix_inclusive_product(subgroup_invocation_id);
|
||||
uint unnamed_15 = metal::simd_broadcast_first(subgroup_invocation_id);
|
||||
uint unnamed_16 = metal::simd_broadcast(subgroup_invocation_id, 4u);
|
||||
uint unnamed_17 = metal::simd_shuffle(subgroup_invocation_id, (sizes.subgroup_size - 1u) - subgroup_invocation_id);
|
||||
uint unnamed_18 = metal::simd_shuffle_down(subgroup_invocation_id, 1u);
|
||||
uint unnamed_19 = metal::simd_shuffle_up(subgroup_invocation_id, 1u);
|
||||
uint unnamed_20 = metal::simd_shuffle_xor(subgroup_invocation_id, sizes.subgroup_size - 1u);
|
||||
return;
|
||||
}
|
81
naga/tests/out/spv/subgroup-operations.spvasm
Normal file
81
naga/tests/out/spv/subgroup-operations.spvasm
Normal file
@ -0,0 +1,81 @@
|
||||
; SPIR-V
|
||||
; Version: 1.3
|
||||
; Generator: rspirv
|
||||
; Bound: 58
|
||||
OpCapability Shader
|
||||
OpCapability GroupNonUniform
|
||||
OpCapability GroupNonUniformBallot
|
||||
OpCapability GroupNonUniformVote
|
||||
OpCapability GroupNonUniformArithmetic
|
||||
OpCapability GroupNonUniformShuffle
|
||||
OpCapability GroupNonUniformShuffleRelative
|
||||
%1 = OpExtInstImport "GLSL.std.450"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %17 "main" %8 %11 %13 %15
|
||||
OpExecutionMode %17 LocalSize 1 1 1
|
||||
OpMemberDecorate %4 0 Offset 0
|
||||
OpMemberDecorate %4 1 Offset 4
|
||||
OpDecorate %8 BuiltIn NumSubgroups
|
||||
OpDecorate %11 BuiltIn SubgroupSize
|
||||
OpDecorate %13 BuiltIn SubgroupId
|
||||
OpDecorate %15 BuiltIn SubgroupLocalInvocationId
|
||||
%2 = OpTypeVoid
|
||||
%3 = OpTypeInt 32 0
|
||||
%4 = OpTypeStruct %3 %3
|
||||
%5 = OpTypeBool
|
||||
%9 = OpTypePointer Input %3
|
||||
%8 = OpVariable %9 Input
|
||||
%11 = OpVariable %9 Input
|
||||
%13 = OpVariable %9 Input
|
||||
%15 = OpVariable %9 Input
|
||||
%18 = OpTypeFunction %2
|
||||
%19 = OpConstant %3 1
|
||||
%20 = OpConstant %3 0
|
||||
%21 = OpConstant %3 4
|
||||
%23 = OpConstant %3 3
|
||||
%24 = OpConstant %3 2
|
||||
%25 = OpConstant %3 8
|
||||
%28 = OpTypeVector %3 4
|
||||
%30 = OpConstantTrue %5
|
||||
%17 = OpFunction %2 None %18
|
||||
%6 = OpLabel
|
||||
%10 = OpLoad %3 %8
|
||||
%12 = OpLoad %3 %11
|
||||
%7 = OpCompositeConstruct %4 %10 %12
|
||||
%14 = OpLoad %3 %13
|
||||
%16 = OpLoad %3 %15
|
||||
OpBranch %22
|
||||
%22 = OpLabel
|
||||
OpControlBarrier %23 %24 %25
|
||||
%26 = OpBitwiseAnd %3 %16 %19
|
||||
%27 = OpIEqual %5 %26 %19
|
||||
%29 = OpGroupNonUniformBallot %28 %23 %27
|
||||
%31 = OpGroupNonUniformBallot %28 %23 %30
|
||||
%32 = OpINotEqual %5 %16 %20
|
||||
%33 = OpGroupNonUniformAll %5 %23 %32
|
||||
%34 = OpIEqual %5 %16 %20
|
||||
%35 = OpGroupNonUniformAny %5 %23 %34
|
||||
%36 = OpGroupNonUniformIAdd %3 %23 Reduce %16
|
||||
%37 = OpGroupNonUniformIMul %3 %23 Reduce %16
|
||||
%38 = OpGroupNonUniformUMin %3 %23 Reduce %16
|
||||
%39 = OpGroupNonUniformUMax %3 %23 Reduce %16
|
||||
%40 = OpGroupNonUniformBitwiseAnd %3 %23 Reduce %16
|
||||
%41 = OpGroupNonUniformBitwiseOr %3 %23 Reduce %16
|
||||
%42 = OpGroupNonUniformBitwiseXor %3 %23 Reduce %16
|
||||
%43 = OpGroupNonUniformIAdd %3 %23 ExclusiveScan %16
|
||||
%44 = OpGroupNonUniformIMul %3 %23 ExclusiveScan %16
|
||||
%45 = OpGroupNonUniformIAdd %3 %23 InclusiveScan %16
|
||||
%46 = OpGroupNonUniformIMul %3 %23 InclusiveScan %16
|
||||
%47 = OpGroupNonUniformBroadcastFirst %3 %23 %16
|
||||
%48 = OpGroupNonUniformShuffle %3 %23 %16 %21
|
||||
%49 = OpCompositeExtract %3 %7 1
|
||||
%50 = OpISub %3 %49 %19
|
||||
%51 = OpISub %3 %50 %16
|
||||
%52 = OpGroupNonUniformShuffle %3 %23 %16 %51
|
||||
%53 = OpGroupNonUniformShuffleDown %3 %23 %16 %19
|
||||
%54 = OpGroupNonUniformShuffleUp %3 %23 %16 %19
|
||||
%55 = OpCompositeExtract %3 %7 1
|
||||
%56 = OpISub %3 %55 %19
|
||||
%57 = OpGroupNonUniformShuffleXor %3 %23 %16 %56
|
||||
OpReturn
|
||||
OpFunctionEnd
|
40
naga/tests/out/wgsl/subgroup-operations-s.wgsl
Normal file
40
naga/tests/out/wgsl/subgroup-operations-s.wgsl
Normal file
@ -0,0 +1,40 @@
|
||||
var<private> num_subgroups_1: u32;
|
||||
var<private> subgroup_id_1: u32;
|
||||
var<private> subgroup_size_1: u32;
|
||||
var<private> subgroup_invocation_id_1: u32;
|
||||
|
||||
fn main_1() {
|
||||
let _e5 = subgroup_size_1;
|
||||
let _e6 = subgroup_invocation_id_1;
|
||||
let _e9 = subgroupBallot(((_e6 & 1u) == 1u));
|
||||
let _e10 = subgroupBallot();
|
||||
let _e12 = subgroupAll((_e6 != 0u));
|
||||
let _e14 = subgroupAny((_e6 == 0u));
|
||||
let _e15 = subgroupAdd(_e6);
|
||||
let _e16 = subgroupMul(_e6);
|
||||
let _e17 = subgroupMin(_e6);
|
||||
let _e18 = subgroupMax(_e6);
|
||||
let _e19 = subgroupAnd(_e6);
|
||||
let _e20 = subgroupOr(_e6);
|
||||
let _e21 = subgroupXor(_e6);
|
||||
let _e22 = subgroupExclusiveAdd(_e6);
|
||||
let _e23 = subgroupExclusiveMul(_e6);
|
||||
let _e24 = subgroupInclusiveAdd(_e6);
|
||||
let _e25 = subgroupInclusiveMul(_e6);
|
||||
let _e26 = subgroupBroadcastFirst(_e6);
|
||||
let _e27 = subgroupBroadcast(_e6, 4u);
|
||||
let _e30 = subgroupShuffle(_e6, ((_e5 - 1u) - _e6));
|
||||
let _e31 = subgroupShuffleDown(_e6, 1u);
|
||||
let _e32 = subgroupShuffleUp(_e6, 1u);
|
||||
let _e34 = subgroupShuffleXor(_e6, (_e5 - 1u));
|
||||
return;
|
||||
}
|
||||
|
||||
@compute @workgroup_size(1, 1, 1)
|
||||
fn main(@builtin(num_subgroups) num_subgroups: u32, @builtin(subgroup_id) subgroup_id: u32, @builtin(subgroup_size) subgroup_size: u32, @builtin(subgroup_invocation_id) subgroup_invocation_id: u32) {
|
||||
num_subgroups_1 = num_subgroups;
|
||||
subgroup_id_1 = subgroup_id;
|
||||
subgroup_size_1 = subgroup_size;
|
||||
subgroup_invocation_id_1 = subgroup_invocation_id;
|
||||
main_1();
|
||||
}
|
31
naga/tests/out/wgsl/subgroup-operations.wgsl
Normal file
31
naga/tests/out/wgsl/subgroup-operations.wgsl
Normal file
@ -0,0 +1,31 @@
|
||||
struct Structure {
|
||||
@builtin(num_subgroups) num_subgroups: u32,
|
||||
@builtin(subgroup_size) subgroup_size: u32,
|
||||
}
|
||||
|
||||
@compute @workgroup_size(1, 1, 1)
|
||||
fn main(sizes: Structure, @builtin(subgroup_id) subgroup_id: u32, @builtin(subgroup_invocation_id) subgroup_invocation_id: u32) {
|
||||
subgroupBarrier();
|
||||
let _e7 = subgroupBallot(((subgroup_invocation_id & 1u) == 1u));
|
||||
let _e8 = subgroupBallot();
|
||||
let _e11 = subgroupAll((subgroup_invocation_id != 0u));
|
||||
let _e14 = subgroupAny((subgroup_invocation_id == 0u));
|
||||
let _e15 = subgroupAdd(subgroup_invocation_id);
|
||||
let _e16 = subgroupMul(subgroup_invocation_id);
|
||||
let _e17 = subgroupMin(subgroup_invocation_id);
|
||||
let _e18 = subgroupMax(subgroup_invocation_id);
|
||||
let _e19 = subgroupAnd(subgroup_invocation_id);
|
||||
let _e20 = subgroupOr(subgroup_invocation_id);
|
||||
let _e21 = subgroupXor(subgroup_invocation_id);
|
||||
let _e22 = subgroupExclusiveAdd(subgroup_invocation_id);
|
||||
let _e23 = subgroupExclusiveMul(subgroup_invocation_id);
|
||||
let _e24 = subgroupInclusiveAdd(subgroup_invocation_id);
|
||||
let _e25 = subgroupInclusiveMul(subgroup_invocation_id);
|
||||
let _e26 = subgroupBroadcastFirst(subgroup_invocation_id);
|
||||
let _e28 = subgroupBroadcast(subgroup_invocation_id, 4u);
|
||||
let _e33 = subgroupShuffle(subgroup_invocation_id, ((sizes.subgroup_size - 1u) - subgroup_invocation_id));
|
||||
let _e35 = subgroupShuffleDown(subgroup_invocation_id, 1u);
|
||||
let _e37 = subgroupShuffleUp(subgroup_invocation_id, 1u);
|
||||
let _e41 = subgroupShuffleXor(subgroup_invocation_id, (sizes.subgroup_size - 1u));
|
||||
return;
|
||||
}
|
@ -269,10 +269,18 @@ fn check_targets(
|
||||
let params = input.read_parameters();
|
||||
let name = &input.file_name;
|
||||
|
||||
let capabilities = if params.god_mode {
|
||||
naga::valid::Capabilities::all()
|
||||
let (capabilities, subgroup_stages, subgroup_operations) = if params.god_mode {
|
||||
(
|
||||
naga::valid::Capabilities::all(),
|
||||
naga::valid::ShaderStages::all(),
|
||||
naga::valid::SubgroupOperationSet::all(),
|
||||
)
|
||||
} else {
|
||||
naga::valid::Capabilities::default()
|
||||
(
|
||||
naga::valid::Capabilities::default(),
|
||||
naga::valid::ShaderStages::empty(),
|
||||
naga::valid::SubgroupOperationSet::empty(),
|
||||
)
|
||||
};
|
||||
|
||||
#[cfg(feature = "serialize")]
|
||||
@ -285,6 +293,8 @@ fn check_targets(
|
||||
}
|
||||
|
||||
let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities)
|
||||
.subgroup_stages(subgroup_stages)
|
||||
.subgroup_operations(subgroup_operations)
|
||||
.validate(module)
|
||||
.unwrap_or_else(|err| {
|
||||
panic!(
|
||||
@ -308,6 +318,8 @@ fn check_targets(
|
||||
}
|
||||
|
||||
naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities)
|
||||
.subgroup_stages(subgroup_stages)
|
||||
.subgroup_operations(subgroup_operations)
|
||||
.validate(module)
|
||||
.unwrap_or_else(|err| {
|
||||
panic!(
|
||||
@ -850,6 +862,10 @@ fn convert_wgsl() {
|
||||
"int64",
|
||||
Targets::SPIRV | Targets::HLSL | Targets::WGSL | Targets::METAL,
|
||||
),
|
||||
(
|
||||
"subgroup-operations",
|
||||
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
|
||||
),
|
||||
(
|
||||
"overrides",
|
||||
Targets::IR
|
||||
@ -957,6 +973,11 @@ fn convert_spv_all() {
|
||||
);
|
||||
convert_spv("builtin-accessed-outside-entrypoint", true, Targets::WGSL);
|
||||
convert_spv("spec-constants", true, Targets::IR);
|
||||
convert_spv(
|
||||
"subgroup-operations-s",
|
||||
false,
|
||||
Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(feature = "glsl-in")]
|
||||
|
@ -33,6 +33,7 @@ mod scissor_tests;
|
||||
mod shader;
|
||||
mod shader_primitive_index;
|
||||
mod shader_view_format;
|
||||
mod subgroup_operations;
|
||||
mod texture_bounds;
|
||||
mod texture_view_creation;
|
||||
mod transfer;
|
||||
|
126
tests/tests/subgroup_operations/mod.rs
Normal file
126
tests/tests/subgroup_operations/mod.rs
Normal file
@ -0,0 +1,126 @@
|
||||
use std::{borrow::Cow, collections::HashMap, num::NonZeroU64};
|
||||
|
||||
use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters};
|
||||
|
||||
// Number of `u32` result slots read back from the GPU, one per shader invocation.
// Must match `@workgroup_size(128)` in `shader.wgsl`, which is dispatched as a single workgroup.
const THREAD_COUNT: u64 = 128;
|
||||
// Number of result bits every invocation is expected to set (checks 0 through 31 in `shader.wgsl`).
// The host builds the "all checks passed" mask from this value.
const TEST_COUNT: u32 = 32;
|
||||
|
||||
#[gpu_test]
|
||||
static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new()
|
||||
.parameters(
|
||||
TestParameters::default()
|
||||
.features(wgpu::Features::SUBGROUP)
|
||||
.limits(wgpu::Limits::downlevel_defaults())
|
||||
.expect_fail(wgpu_test::FailureCase::molten_vk())
|
||||
.expect_fail(
|
||||
// Expect metal to fail on tests involving operations in divergent control flow
|
||||
wgpu_test::FailureCase::backend(wgpu::Backends::METAL)
|
||||
.panic("thread 0 failed tests: 27, 29,\nthread 1 failed tests: 27, 28, 29,\n"),
|
||||
),
|
||||
)
|
||||
.run_sync(|ctx| {
|
||||
let device = &ctx.device;
|
||||
|
||||
let storage_buffer = device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: None,
|
||||
size: THREAD_COUNT * std::mem::size_of::<u32>() as u64,
|
||||
usage: wgpu::BufferUsages::STORAGE
|
||||
| wgpu::BufferUsages::COPY_DST
|
||||
| wgpu::BufferUsages::COPY_SRC,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
|
||||
label: Some("bind group layout"),
|
||||
entries: &[wgpu::BindGroupLayoutEntry {
|
||||
binding: 0,
|
||||
visibility: wgpu::ShaderStages::COMPUTE,
|
||||
ty: wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Storage { read_only: false },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: NonZeroU64::new(
|
||||
THREAD_COUNT * std::mem::size_of::<u32>() as u64,
|
||||
),
|
||||
},
|
||||
count: None,
|
||||
}],
|
||||
});
|
||||
|
||||
let cs_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
|
||||
label: None,
|
||||
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))),
|
||||
});
|
||||
|
||||
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
|
||||
label: Some("main"),
|
||||
bind_group_layouts: &[&bind_group_layout],
|
||||
push_constant_ranges: &[],
|
||||
});
|
||||
|
||||
let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||
label: None,
|
||||
layout: Some(&pipeline_layout),
|
||||
module: &cs_module,
|
||||
entry_point: "main",
|
||||
constants: &HashMap::default(),
|
||||
});
|
||||
|
||||
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
|
||||
entries: &[wgpu::BindGroupEntry {
|
||||
binding: 0,
|
||||
resource: storage_buffer.as_entire_binding(),
|
||||
}],
|
||||
layout: &bind_group_layout,
|
||||
label: Some("bind group"),
|
||||
});
|
||||
|
||||
let mut encoder =
|
||||
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
|
||||
{
|
||||
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
|
||||
label: None,
|
||||
timestamp_writes: None,
|
||||
});
|
||||
cpass.set_pipeline(&compute_pipeline);
|
||||
cpass.set_bind_group(0, &bind_group, &[]);
|
||||
cpass.dispatch_workgroups(1, 1, 1);
|
||||
}
|
||||
ctx.queue.submit(Some(encoder.finish()));
|
||||
|
||||
wgpu::util::DownloadBuffer::read_buffer(
|
||||
device,
|
||||
&ctx.queue,
|
||||
&storage_buffer.slice(..),
|
||||
|mapping_buffer_view| {
|
||||
let mapping_buffer_view = mapping_buffer_view.unwrap();
|
||||
let result: &[u32; THREAD_COUNT as usize] =
|
||||
bytemuck::from_bytes(&mapping_buffer_view);
|
||||
let expected_mask = (1u64 << (TEST_COUNT)) - 1; // generate full mask
|
||||
let expected_array = [expected_mask as u32; THREAD_COUNT as usize];
|
||||
if result != &expected_array {
|
||||
use std::fmt::Write;
|
||||
let mut msg = String::new();
|
||||
writeln!(
|
||||
&mut msg,
|
||||
"Got from GPU:\n{:x?}\n expected:\n{:x?}",
|
||||
result, &expected_array,
|
||||
)
|
||||
.unwrap();
|
||||
for (thread, (result, expected)) in result
|
||||
.iter()
|
||||
.zip(expected_array)
|
||||
.enumerate()
|
||||
.filter(|(_, (r, e))| *r != e)
|
||||
{
|
||||
write!(&mut msg, "thread {} failed tests:", thread).unwrap();
|
||||
let difference = result ^ expected;
|
||||
for i in (0..u32::BITS).filter(|i| (difference & (1 << i)) != 0) {
|
||||
write!(&mut msg, " {},", i).unwrap();
|
||||
}
|
||||
writeln!(&mut msg).unwrap();
|
||||
}
|
||||
panic!("{}", msg);
|
||||
}
|
||||
},
|
||||
);
|
||||
});
|
158
tests/tests/subgroup_operations/shader.wgsl
Normal file
158
tests/tests/subgroup_operations/shader.wgsl
Normal file
@ -0,0 +1,158 @@
|
||||
// Output of the test: `main` writes one `u32` result mask per invocation at index `global_id.x`.
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read_write> storage_buffer: array<u32>;
|
||||
|
||||
// Workgroup-shared scratch value; written by invocation 0 and read back by every
// invocation after a `workgroupBarrier()` in check 30 of `main`.
var<workgroup> workgroup_buffer: u32;
|
||||
|
||||
// Records the outcome of one check in a result mask: bit `index` of `*mask` is
// set when `value` is true; when `value` is false the mask is left unchanged.
fn add_result_to_mask(mask: ptr<function, u32>, index: u32, value: bool) {
    let bit = u32(value) << index;
    *mask = *mask | bit;
}
|
||||
|
||||
// Compute entry point of the subgroup-operations test.
//
// Each of the 128 invocations runs a numbered series of checks on the subgroup
// built-ins and operations, records the outcome of check `n` in bit `n` of `passed`
// (via `add_result_to_mask`), and finally stores `passed` at
// `storage_buffer[global_id.x]`.
//
// The shape of the control flow in checks 27-29 is itself what is being tested, so
// statements here should not be reordered or restructured.
@compute
|
||||
@workgroup_size(128)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_subgroups) num_subgroups: u32,
|
||||
@builtin(subgroup_id) subgroup_id: u32,
|
||||
@builtin(subgroup_size) subgroup_size: u32,
|
||||
@builtin(subgroup_invocation_id) subgroup_invocation_id: u32,
|
||||
) {
|
||||
// Bit mask of passed checks for this invocation.
var passed = 0u;
|
||||
// Scratch for the reference value each check compares against.
var expected: u32;
|
||||
|
||||
// Checks 0-2: the built-ins are consistent with 128 invocations (the workgroup size)
// being split into consecutive subgroups in `global_id.x` order.
add_result_to_mask(&passed, 0u, num_subgroups == 128u / subgroup_size);
|
||||
add_result_to_mask(&passed, 1u, subgroup_id == global_id.x / subgroup_size);
|
||||
add_result_to_mask(&passed, 2u, subgroup_invocation_id == global_id.x % subgroup_size);
|
||||
|
||||
// Check 3: ballot of the odd-numbered invocations. The reference ballot is built one
// bit per invocation; `global_id.x - subgroup_invocation_id` is the global id of this
// subgroup's invocation 0 (given check 2 holds), so adding `i` walks the subgroup.
var expected_ballot = vec4<u32>(0u);
|
||||
for(var i = 0u; i < subgroup_size; i += 1u) {
|
||||
expected_ballot[i / 32u] |= ((global_id.x - subgroup_invocation_id + i) & 1u) << (i % 32u);
|
||||
}
|
||||
// All four components of the ballot must match: the comparison yields a vec4<bool>,
// which is converted to 0/1 and summed with `dot`.
add_result_to_mask(&passed, 3u, dot(vec4<u32>(1u), vec4<u32>(subgroupBallot((subgroup_invocation_id & 1u) == 1u) == expected_ballot)) == 4u);
|
||||
|
||||
// Checks 4-5: subgroupAll is true for a uniformly true input and false when
// invocation 0 contributes false.
add_result_to_mask(&passed, 4u, subgroupAll(true));
|
||||
add_result_to_mask(&passed, 5u, !subgroupAll(subgroup_invocation_id != 0u));
|
||||
|
||||
// Checks 6-7: subgroupAny is true when invocation 0 contributes true and false for a
// uniformly false input.
add_result_to_mask(&passed, 6u, subgroupAny(subgroup_invocation_id == 0u));
|
||||
add_result_to_mask(&passed, 7u, !subgroupAny(false));
|
||||
|
||||
// Checks 8-14: reductions over `global_id.x + 1u`. Each reference value is computed
// by looping over every invocation of the subgroup and folding with the same operator.
// Check 8: sum.
expected = 0u;
|
||||
for(var i = 0u; i < subgroup_size; i += 1u) {
|
||||
expected += global_id.x - subgroup_invocation_id + i + 1u;
|
||||
}
|
||||
add_result_to_mask(&passed, 8u, subgroupAdd(global_id.x + 1u) == expected);
|
||||
|
||||
// Check 9: product.
expected = 1u;
|
||||
for(var i = 0u; i < subgroup_size; i += 1u) {
|
||||
expected *= global_id.x - subgroup_invocation_id + i + 1u;
|
||||
}
|
||||
add_result_to_mask(&passed, 9u, subgroupMul(global_id.x + 1u) == expected);
|
||||
|
||||
// Check 10: maximum.
expected = 0u;
|
||||
for(var i = 0u; i < subgroup_size; i += 1u) {
|
||||
expected = max(expected, global_id.x - subgroup_invocation_id + i + 1u);
|
||||
}
|
||||
add_result_to_mask(&passed, 10u, subgroupMax(global_id.x + 1u) == expected);
|
||||
|
||||
// Check 11: minimum (starts from the largest u32 so any value lowers it).
expected = 0xFFFFFFFFu;
|
||||
for(var i = 0u; i < subgroup_size; i += 1u) {
|
||||
expected = min(expected, global_id.x - subgroup_invocation_id + i + 1u);
|
||||
}
|
||||
add_result_to_mask(&passed, 11u, subgroupMin(global_id.x + 1u) == expected);
|
||||
|
||||
// Check 12: bitwise AND (starts from all ones).
expected = 0xFFFFFFFFu;
|
||||
for(var i = 0u; i < subgroup_size; i += 1u) {
|
||||
expected &= global_id.x - subgroup_invocation_id + i + 1u;
|
||||
}
|
||||
add_result_to_mask(&passed, 12u, subgroupAnd(global_id.x + 1u) == expected);
|
||||
|
||||
// Check 13: bitwise OR.
expected = 0u;
|
||||
for(var i = 0u; i < subgroup_size; i += 1u) {
|
||||
expected |= global_id.x - subgroup_invocation_id + i + 1u;
|
||||
}
|
||||
add_result_to_mask(&passed, 13u, subgroupOr(global_id.x + 1u) == expected);
|
||||
|
||||
// Check 14: bitwise XOR.
expected = 0u;
|
||||
for(var i = 0u; i < subgroup_size; i += 1u) {
|
||||
expected ^= global_id.x - subgroup_invocation_id + i + 1u;
|
||||
}
|
||||
add_result_to_mask(&passed, 14u, subgroupXor(global_id.x + 1u) == expected);
|
||||
|
||||
// Checks 15-18: prefix scans. The exclusive forms fold the invocations strictly
// before this one (`i < subgroup_invocation_id`); the inclusive forms also fold this
// invocation (`i <= subgroup_invocation_id`).
// Check 15: exclusive sum.
expected = 0u;
|
||||
for(var i = 0u; i < subgroup_invocation_id; i += 1u) {
|
||||
expected += global_id.x - subgroup_invocation_id + i + 1u;
|
||||
}
|
||||
add_result_to_mask(&passed, 15u, subgroupExclusiveAdd(global_id.x + 1u) == expected);
|
||||
|
||||
// Check 16: exclusive product.
expected = 1u;
|
||||
for(var i = 0u; i < subgroup_invocation_id; i += 1u) {
|
||||
expected *= global_id.x - subgroup_invocation_id + i + 1u;
|
||||
}
|
||||
add_result_to_mask(&passed, 16u, subgroupExclusiveMul(global_id.x + 1u) == expected);
|
||||
|
||||
// Check 17: inclusive sum.
expected = 0u;
|
||||
for(var i = 0u; i <= subgroup_invocation_id; i += 1u) {
|
||||
expected += global_id.x - subgroup_invocation_id + i + 1u;
|
||||
}
|
||||
add_result_to_mask(&passed, 17u, subgroupInclusiveAdd(global_id.x + 1u) == expected);
|
||||
|
||||
// Check 18: inclusive product.
expected = 1u;
|
||||
for(var i = 0u; i <= subgroup_invocation_id; i += 1u) {
|
||||
expected *= global_id.x - subgroup_invocation_id + i + 1u;
|
||||
}
|
||||
add_result_to_mask(&passed, 18u, subgroupInclusiveMul(global_id.x + 1u) == expected);
|
||||
|
||||
// Checks 19-21: broadcasts. Invocation 0 contributes 0 in check 19 and 1 in check 20;
// check 21 reads the invocation id held by invocation 1.
add_result_to_mask(&passed, 19u, subgroupBroadcastFirst(u32(subgroup_invocation_id != 0u)) == 0u);
|
||||
add_result_to_mask(&passed, 20u, subgroupBroadcastFirst(u32(subgroup_invocation_id == 0u)) == 1u);
|
||||
add_result_to_mask(&passed, 21u, subgroupBroadcast(subgroup_invocation_id, 1u) == 1u);
|
||||
// Checks 22-26: shuffles of the invocation id, so the value read equals the id of
// the source invocation. Check 22 reads from self, check 23 from the mirrored
// invocation. Checks 24 and 25 exempt the last / first invocation, which have no
// neighbour in that direction.
add_result_to_mask(&passed, 22u, subgroupShuffle(subgroup_invocation_id, subgroup_invocation_id) == subgroup_invocation_id);
|
||||
add_result_to_mask(&passed, 23u, subgroupShuffle(subgroup_invocation_id, subgroup_size - 1u - subgroup_invocation_id) == subgroup_size - 1u - subgroup_invocation_id);
|
||||
add_result_to_mask(&passed, 24u, subgroup_invocation_id == subgroup_size - 1u || subgroupShuffleDown(subgroup_invocation_id, 1u) == subgroup_invocation_id + 1u);
|
||||
add_result_to_mask(&passed, 25u, subgroup_invocation_id == 0u || subgroupShuffleUp(subgroup_invocation_id, 1u) == subgroup_invocation_id - 1u);
|
||||
add_result_to_mask(&passed, 26u, subgroupShuffleXor(subgroup_invocation_id, subgroup_size - 1u) == (subgroup_invocation_id ^ (subgroup_size - 1u)));
|
||||
|
||||
// Check 27: reduction inside an if/else that splits the subgroup by parity. Both
// branches are intentionally identical; each expects the sum to count only the half
// of the subgroup that took that branch.
var passed_27 = false;
|
||||
if subgroup_invocation_id % 2u == 0u {
|
||||
passed_27 |= subgroupAdd(1u) == (subgroup_size / 2u);
|
||||
} else {
|
||||
passed_27 |= subgroupAdd(1u) == (subgroup_size / 2u);
|
||||
}
|
||||
add_result_to_mask(&passed, 27u, passed_27);
|
||||
|
||||
// Check 28: broadcast-first inside a switch that splits the subgroup three ways by
// `id % 3`. Each case expects the value from the lowest invocation id that took that
// case (0, 1 and 2 respectively).
var passed_28 = false;
|
||||
switch subgroup_invocation_id % 3u {
|
||||
case 0u: {
|
||||
passed_28 = subgroupBroadcastFirst(subgroup_invocation_id) == 0u;
|
||||
}
|
||||
case 1u: {
|
||||
passed_28 = subgroupBroadcastFirst(subgroup_invocation_id) == 1u;
|
||||
}
|
||||
case 2u: {
|
||||
passed_28 = subgroupBroadcastFirst(subgroup_invocation_id) == 2u;
|
||||
}
|
||||
default { }
|
||||
}
|
||||
add_result_to_mask(&passed, 28u, passed_28);
|
||||
|
||||
// Check 29: reduction inside a loop that invocations leave at different iterations.
// `i` counts down from `subgroup_size`; an invocation breaks once `i` equals its own
// id, so at that iteration the invocations with ids 0..=i are still looping and the
// sum is expected to be `id + 1`.
// NOTE: `i >= 0u` is always true for a u32; the loop terminates only through `break`.
expected = 0u;
|
||||
for (var i = subgroup_size; i >= 0u; i -= 1u) {
|
||||
expected = subgroupAdd(1u);
|
||||
if i == subgroup_invocation_id {
|
||||
break;
|
||||
}
|
||||
}
|
||||
add_result_to_mask(&passed, 29u, expected == (subgroup_invocation_id + 1u));
|
||||
|
||||
// Check 30: a value written to workgroup memory by invocation 0 is read back by
// every invocation after a workgroup barrier.
if global_id.x == 0u {
|
||||
workgroup_buffer = subgroup_size;
|
||||
}
|
||||
workgroupBarrier();
|
||||
add_result_to_mask(&passed, 30u, workgroup_buffer == subgroup_size);
|
||||
|
||||
// Keep this test last, verify we are still convergent after running other tests
|
||||
add_result_to_mask(&passed, 31u, subgroupAdd(1u) == subgroup_size);
|
||||
|
||||
// Increment TEST_COUNT in subgroup_operations/mod.rs if adding more tests
|
||||
|
||||
// Publish this invocation's result mask for the host to verify.
storage_buffer[global_id.x] = passed;
|
||||
}
|
@ -1537,6 +1537,15 @@ impl<A: HalApi> Device<A> {
|
||||
.flags
|
||||
.contains(wgt::DownlevelFlags::CUBE_ARRAY_TEXTURES),
|
||||
);
|
||||
caps.set(
|
||||
Caps::SUBGROUP,
|
||||
self.features
|
||||
.intersects(wgt::Features::SUBGROUP | wgt::Features::SUBGROUP_VERTEX),
|
||||
);
|
||||
caps.set(
|
||||
Caps::SUBGROUP_BARRIER,
|
||||
self.features.intersects(wgt::Features::SUBGROUP_BARRIER),
|
||||
);
|
||||
|
||||
let debug_source =
|
||||
if self.instance_flags.contains(wgt::InstanceFlags::DEBUG) && !source.is_empty() {
|
||||
@ -1552,7 +1561,26 @@ impl<A: HalApi> Device<A> {
|
||||
None
|
||||
};
|
||||
|
||||
let mut subgroup_stages = naga::valid::ShaderStages::empty();
|
||||
subgroup_stages.set(
|
||||
naga::valid::ShaderStages::COMPUTE | naga::valid::ShaderStages::FRAGMENT,
|
||||
self.features.contains(wgt::Features::SUBGROUP),
|
||||
);
|
||||
subgroup_stages.set(
|
||||
naga::valid::ShaderStages::VERTEX,
|
||||
self.features.contains(wgt::Features::SUBGROUP_VERTEX),
|
||||
);
|
||||
|
||||
let subgroup_operations = if caps.contains(Caps::SUBGROUP) {
|
||||
use naga::valid::SubgroupOperationSet as S;
|
||||
S::BASIC | S::VOTE | S::ARITHMETIC | S::BALLOT | S::SHUFFLE | S::SHUFFLE_RELATIVE
|
||||
} else {
|
||||
naga::valid::SubgroupOperationSet::empty()
|
||||
};
|
||||
|
||||
let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), caps)
|
||||
.subgroup_stages(subgroup_stages)
|
||||
.subgroup_operations(subgroup_operations)
|
||||
.validate(&module)
|
||||
.map_err(|inner| {
|
||||
pipeline::CreateShaderModuleError::Validation(pipeline::ShaderError {
|
||||
|
@ -127,6 +127,11 @@ impl super::Adapter {
|
||||
)
|
||||
});
|
||||
|
||||
// If we don't have dxc, we reduce the max to 5.1
|
||||
if dxc_container.is_none() {
|
||||
shader_model_support.HighestShaderModel = d3d12_ty::D3D_SHADER_MODEL_5_1;
|
||||
}
|
||||
|
||||
let mut workarounds = super::Workarounds::default();
|
||||
|
||||
let info = wgt::AdapterInfo {
|
||||
@ -343,21 +348,28 @@ impl super::Adapter {
|
||||
bgra8unorm_storage_supported,
|
||||
);
|
||||
|
||||
let mut features1: d3d12_ty::D3D12_FEATURE_DATA_D3D12_OPTIONS1 = unsafe { mem::zeroed() };
|
||||
let hr = unsafe {
|
||||
device.CheckFeatureSupport(
|
||||
d3d12_ty::D3D12_FEATURE_D3D12_OPTIONS1,
|
||||
&mut features1 as *mut _ as *mut _,
|
||||
mem::size_of::<d3d12_ty::D3D12_FEATURE_DATA_D3D12_OPTIONS1>() as _,
|
||||
)
|
||||
};
|
||||
|
||||
// we must be using DXC because uint64_t was added with Shader Model 6
|
||||
// and FXC only supports up to 5.1
|
||||
let int64_shader_ops_supported = dxc_container.is_some() && {
|
||||
let mut features1: d3d12_ty::D3D12_FEATURE_DATA_D3D12_OPTIONS1 =
|
||||
unsafe { mem::zeroed() };
|
||||
let hr = unsafe {
|
||||
device.CheckFeatureSupport(
|
||||
d3d12_ty::D3D12_FEATURE_D3D12_OPTIONS1,
|
||||
&mut features1 as *mut _ as *mut _,
|
||||
mem::size_of::<d3d12_ty::D3D12_FEATURE_DATA_D3D12_OPTIONS1>() as _,
|
||||
)
|
||||
};
|
||||
hr == 0 && features1.Int64ShaderOps != 0
|
||||
};
|
||||
features.set(wgt::Features::SHADER_INT64, int64_shader_ops_supported);
|
||||
features.set(
|
||||
wgt::Features::SHADER_INT64,
|
||||
dxc_container.is_some() && hr == 0 && features1.Int64ShaderOps != 0,
|
||||
);
|
||||
|
||||
features.set(
|
||||
wgt::Features::SUBGROUP,
|
||||
shader_model_support.HighestShaderModel >= d3d12_ty::D3D_SHADER_MODEL_6_0
|
||||
&& hr == 0
|
||||
&& features1.WaveOps != 0,
|
||||
);
|
||||
|
||||
// float32-filterable should always be available on d3d12
|
||||
features.set(wgt::Features::FLOAT32_FILTERABLE, true);
|
||||
@ -425,6 +437,8 @@ impl super::Adapter {
|
||||
.min(crate::MAX_VERTEX_BUFFERS as u32),
|
||||
max_vertex_attributes: d3d12_ty::D3D12_IA_VERTEX_INPUT_RESOURCE_SLOT_COUNT,
|
||||
max_vertex_buffer_array_stride: d3d12_ty::D3D12_SO_BUFFER_MAX_STRIDE_IN_BYTES,
|
||||
min_subgroup_size: 4, // Not using `features1.WaveLaneCountMin` as it is unreliable
|
||||
max_subgroup_size: 128,
|
||||
// The push constants are part of the root signature which
|
||||
// has a limit of 64 DWORDS (256 bytes), but other resources
|
||||
// also share the root signature:
|
||||
|
@ -748,6 +748,8 @@ impl super::Adapter {
|
||||
} else {
|
||||
!0
|
||||
},
|
||||
min_subgroup_size: 0,
|
||||
max_subgroup_size: 0,
|
||||
max_push_constant_size: super::MAX_PUSH_CONSTANTS as u32 * 4,
|
||||
min_uniform_buffer_offset_alignment,
|
||||
min_storage_buffer_offset_alignment,
|
||||
|
@ -813,6 +813,10 @@ impl super::PrivateCapabilities {
|
||||
None
|
||||
},
|
||||
timestamp_query_support,
|
||||
supports_simd_scoped_operations: family_check
|
||||
&& (device.supports_family(MTLGPUFamily::Metal3)
|
||||
|| device.supports_family(MTLGPUFamily::Mac2)
|
||||
|| device.supports_family(MTLGPUFamily::Apple7)),
|
||||
}
|
||||
}
|
||||
|
||||
@ -898,6 +902,10 @@ impl super::PrivateCapabilities {
|
||||
features.set(F::RG11B10UFLOAT_RENDERABLE, self.format_rg11b10_all);
|
||||
features.set(F::SHADER_UNUSED_VERTEX_OUTPUT, true);
|
||||
|
||||
if self.supports_simd_scoped_operations {
|
||||
features.insert(F::SUBGROUP | F::SUBGROUP_BARRIER);
|
||||
}
|
||||
|
||||
features
|
||||
}
|
||||
|
||||
@ -952,6 +960,8 @@ impl super::PrivateCapabilities {
|
||||
max_vertex_buffers: self.max_vertex_buffers,
|
||||
max_vertex_attributes: 31,
|
||||
max_vertex_buffer_array_stride: base.max_vertex_buffer_array_stride,
|
||||
min_subgroup_size: 4,
|
||||
max_subgroup_size: 64,
|
||||
max_push_constant_size: 0x1000,
|
||||
min_uniform_buffer_offset_alignment: self.buffer_alignment as u32,
|
||||
min_storage_buffer_offset_alignment: self.buffer_alignment as u32,
|
||||
|
@ -269,6 +269,7 @@ struct PrivateCapabilities {
|
||||
supports_shader_primitive_index: bool,
|
||||
has_unified_memory: Option<bool>,
|
||||
timestamp_query_support: TimestampQuerySupport,
|
||||
supports_simd_scoped_operations: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
|
@ -101,6 +101,9 @@ pub struct PhysicalDeviceFeatures {
|
||||
/// to Vulkan 1.3.
|
||||
zero_initialize_workgroup_memory:
|
||||
Option<vk::PhysicalDeviceZeroInitializeWorkgroupMemoryFeatures>,
|
||||
|
||||
/// Features provided by `VK_EXT_subgroup_size_control`, promoted to Vulkan 1.3.
|
||||
subgroup_size_control: Option<vk::PhysicalDeviceSubgroupSizeControlFeatures>,
|
||||
}
|
||||
|
||||
// This is safe because the structs have `p_next: *mut c_void`, which we null out/never read.
|
||||
@ -148,6 +151,9 @@ impl PhysicalDeviceFeatures {
|
||||
if let Some(ref mut feature) = self.ray_query {
|
||||
info = info.push_next(feature);
|
||||
}
|
||||
if let Some(ref mut feature) = self.subgroup_size_control {
|
||||
info = info.push_next(feature);
|
||||
}
|
||||
info
|
||||
}
|
||||
|
||||
@ -434,6 +440,17 @@ impl PhysicalDeviceFeatures {
|
||||
} else {
|
||||
None
|
||||
},
|
||||
subgroup_size_control: if device_api_version >= vk::API_VERSION_1_3
|
||||
|| enabled_extensions.contains(&vk::ExtSubgroupSizeControlFn::name())
|
||||
{
|
||||
Some(
|
||||
vk::PhysicalDeviceSubgroupSizeControlFeatures::builder()
|
||||
.subgroup_size_control(true)
|
||||
.build(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -638,6 +655,34 @@ impl PhysicalDeviceFeatures {
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(ref subgroup) = caps.subgroup {
|
||||
if (caps.device_api_version >= vk::API_VERSION_1_3
|
||||
|| caps.supports_extension(vk::ExtSubgroupSizeControlFn::name()))
|
||||
&& subgroup.supported_operations.contains(
|
||||
vk::SubgroupFeatureFlags::BASIC
|
||||
| vk::SubgroupFeatureFlags::VOTE
|
||||
| vk::SubgroupFeatureFlags::ARITHMETIC
|
||||
| vk::SubgroupFeatureFlags::BALLOT
|
||||
| vk::SubgroupFeatureFlags::SHUFFLE
|
||||
| vk::SubgroupFeatureFlags::SHUFFLE_RELATIVE,
|
||||
)
|
||||
{
|
||||
features.set(
|
||||
F::SUBGROUP,
|
||||
subgroup
|
||||
.supported_stages
|
||||
.contains(vk::ShaderStageFlags::COMPUTE | vk::ShaderStageFlags::FRAGMENT),
|
||||
);
|
||||
features.set(
|
||||
F::SUBGROUP_VERTEX,
|
||||
subgroup
|
||||
.supported_stages
|
||||
.contains(vk::ShaderStageFlags::VERTEX),
|
||||
);
|
||||
features.insert(F::SUBGROUP_BARRIER);
|
||||
}
|
||||
}
|
||||
|
||||
let supports_depth_format = |format| {
|
||||
supports_format(
|
||||
instance,
|
||||
@ -773,6 +818,13 @@ pub struct PhysicalDeviceProperties {
|
||||
/// `VK_KHR_driver_properties` extension, promoted to Vulkan 1.2.
|
||||
driver: Option<vk::PhysicalDeviceDriverPropertiesKHR>,
|
||||
|
||||
/// Additional `vk::PhysicalDevice` properties from Vulkan 1.1.
|
||||
subgroup: Option<vk::PhysicalDeviceSubgroupProperties>,
|
||||
|
||||
/// Additional `vk::PhysicalDevice` properties from the
|
||||
/// `VK_EXT_subgroup_size_control` extension, promoted to Vulkan 1.3.
|
||||
subgroup_size_control: Option<vk::PhysicalDeviceSubgroupSizeControlProperties>,
|
||||
|
||||
/// The device API version.
|
||||
///
|
||||
/// Which is the version of Vulkan supported for device-level functionality.
|
||||
@ -888,6 +940,11 @@ impl PhysicalDeviceProperties {
|
||||
if self.supports_extension(vk::ExtImageRobustnessFn::name()) {
|
||||
extensions.push(vk::ExtImageRobustnessFn::name());
|
||||
}
|
||||
|
||||
// Require `VK_EXT_subgroup_size_control` if the associated feature was requested
|
||||
if requested_features.contains(wgt::Features::SUBGROUP) {
|
||||
extensions.push(vk::ExtSubgroupSizeControlFn::name());
|
||||
}
|
||||
}
|
||||
|
||||
// Optional `VK_KHR_swapchain_mutable_format`
|
||||
@ -987,6 +1044,14 @@ impl PhysicalDeviceProperties {
|
||||
.min(crate::MAX_VERTEX_BUFFERS as u32),
|
||||
max_vertex_attributes: limits.max_vertex_input_attributes,
|
||||
max_vertex_buffer_array_stride: limits.max_vertex_input_binding_stride,
|
||||
min_subgroup_size: self
|
||||
.subgroup_size_control
|
||||
.map(|subgroup_size| subgroup_size.min_subgroup_size)
|
||||
.unwrap_or(0),
|
||||
max_subgroup_size: self
|
||||
.subgroup_size_control
|
||||
.map(|subgroup_size| subgroup_size.max_subgroup_size)
|
||||
.unwrap_or(0),
|
||||
max_push_constant_size: limits.max_push_constants_size,
|
||||
min_uniform_buffer_offset_alignment: limits.min_uniform_buffer_offset_alignment as u32,
|
||||
min_storage_buffer_offset_alignment: limits.min_storage_buffer_offset_alignment as u32,
|
||||
@ -1042,6 +1107,9 @@ impl super::InstanceShared {
|
||||
let supports_driver_properties = capabilities.device_api_version
|
||||
>= vk::API_VERSION_1_2
|
||||
|| capabilities.supports_extension(vk::KhrDriverPropertiesFn::name());
|
||||
let supports_subgroup_size_control = capabilities.device_api_version
|
||||
>= vk::API_VERSION_1_3
|
||||
|| capabilities.supports_extension(vk::ExtSubgroupSizeControlFn::name());
|
||||
|
||||
let supports_acceleration_structure =
|
||||
capabilities.supports_extension(vk::KhrAccelerationStructureFn::name());
|
||||
@ -1075,6 +1143,20 @@ impl super::InstanceShared {
|
||||
builder = builder.push_next(next);
|
||||
}
|
||||
|
||||
if capabilities.device_api_version >= vk::API_VERSION_1_1 {
|
||||
let next = capabilities
|
||||
.subgroup
|
||||
.insert(vk::PhysicalDeviceSubgroupProperties::default());
|
||||
builder = builder.push_next(next);
|
||||
}
|
||||
|
||||
if supports_subgroup_size_control {
|
||||
let next = capabilities
|
||||
.subgroup_size_control
|
||||
.insert(vk::PhysicalDeviceSubgroupSizeControlProperties::default());
|
||||
builder = builder.push_next(next);
|
||||
}
|
||||
|
||||
let mut properties2 = builder.build();
|
||||
unsafe {
|
||||
get_device_properties.get_physical_device_properties2(phd, &mut properties2);
|
||||
@ -1190,6 +1272,16 @@ impl super::InstanceShared {
|
||||
builder = builder.push_next(next);
|
||||
}
|
||||
|
||||
// `VK_EXT_subgroup_size_control` was promoted to core in Vulkan 1.3
|
||||
if capabilities.device_api_version >= vk::API_VERSION_1_3
|
||||
|| capabilities.supports_extension(vk::ExtSubgroupSizeControlFn::name())
|
||||
{
|
||||
let next = features
|
||||
.subgroup_size_control
|
||||
.insert(vk::PhysicalDeviceSubgroupSizeControlFeatures::default());
|
||||
builder = builder.push_next(next);
|
||||
}
|
||||
|
||||
let mut features2 = builder.build();
|
||||
unsafe {
|
||||
get_device_properties.get_physical_device_features2(phd, &mut features2);
|
||||
@ -1382,6 +1474,9 @@ impl super::Instance {
|
||||
}),
|
||||
image_format_list: phd_capabilities.device_api_version >= vk::API_VERSION_1_2
|
||||
|| phd_capabilities.supports_extension(vk::KhrImageFormatListFn::name()),
|
||||
subgroup_size_control: phd_features
|
||||
.subgroup_size_control
|
||||
.map_or(false, |ext| ext.subgroup_size_control == vk::TRUE),
|
||||
};
|
||||
let capabilities = crate::Capabilities {
|
||||
limits: phd_capabilities.to_wgpu_limits(),
|
||||
@ -1581,6 +1676,15 @@ impl super::Adapter {
|
||||
capabilities.push(spv::Capability::Geometry);
|
||||
}
|
||||
|
||||
if features.intersects(wgt::Features::SUBGROUP | wgt::Features::SUBGROUP_VERTEX) {
|
||||
capabilities.push(spv::Capability::GroupNonUniform);
|
||||
capabilities.push(spv::Capability::GroupNonUniformVote);
|
||||
capabilities.push(spv::Capability::GroupNonUniformArithmetic);
|
||||
capabilities.push(spv::Capability::GroupNonUniformBallot);
|
||||
capabilities.push(spv::Capability::GroupNonUniformShuffle);
|
||||
capabilities.push(spv::Capability::GroupNonUniformShuffleRelative);
|
||||
}
|
||||
|
||||
if features.intersects(
|
||||
wgt::Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING
|
||||
| wgt::Features::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING,
|
||||
@ -1616,7 +1720,13 @@ impl super::Adapter {
|
||||
true, // could check `super::Workarounds::SEPARATE_ENTRY_POINTS`
|
||||
);
|
||||
spv::Options {
|
||||
lang_version: (1, 0),
|
||||
lang_version: if features
|
||||
.intersects(wgt::Features::SUBGROUP | wgt::Features::SUBGROUP_VERTEX)
|
||||
{
|
||||
(1, 3)
|
||||
} else {
|
||||
(1, 0)
|
||||
},
|
||||
flags,
|
||||
capabilities: Some(capabilities.iter().cloned().collect()),
|
||||
bounds_check_policies: naga::proc::BoundsCheckPolicies {
|
||||
|
@ -782,8 +782,14 @@ impl super::Device {
|
||||
}
|
||||
};
|
||||
|
||||
let mut flags = vk::PipelineShaderStageCreateFlags::empty();
|
||||
if self.shared.private_caps.subgroup_size_control {
|
||||
flags |= vk::PipelineShaderStageCreateFlags::ALLOW_VARYING_SUBGROUP_SIZE
|
||||
}
|
||||
|
||||
let entry_point = CString::new(stage.entry_point).unwrap();
|
||||
let create_info = vk::PipelineShaderStageCreateInfo::builder()
|
||||
.flags(flags)
|
||||
.stage(conv::map_shader_stage(stage_flags))
|
||||
.module(vk_module)
|
||||
.name(&entry_point)
|
||||
|
@ -238,6 +238,7 @@ struct PrivateCapabilities {
|
||||
robust_image_access2: bool,
|
||||
zero_initialize_workgroup_memory: bool,
|
||||
image_format_list: bool,
|
||||
subgroup_size_control: bool,
|
||||
}
|
||||
|
||||
bitflags::bitflags!(
|
||||
|
@ -143,6 +143,8 @@ fn print_adapter(output: &mut impl io::Write, report: &AdapterReport, idx: usize
|
||||
max_vertex_buffers,
|
||||
max_vertex_attributes,
|
||||
max_vertex_buffer_array_stride,
|
||||
min_subgroup_size,
|
||||
max_subgroup_size,
|
||||
max_push_constant_size,
|
||||
min_uniform_buffer_offset_alignment,
|
||||
min_storage_buffer_offset_alignment,
|
||||
@ -176,6 +178,8 @@ fn print_adapter(output: &mut impl io::Write, report: &AdapterReport, idx: usize
|
||||
writeln!(output, "\t\t Max Vertex Buffers: {max_vertex_buffers}")?;
|
||||
writeln!(output, "\t\t Max Vertex Attributes: {max_vertex_attributes}")?;
|
||||
writeln!(output, "\t\t Max Vertex Buffer Array Stride: {max_vertex_buffer_array_stride}")?;
|
||||
writeln!(output, "\t\t Min Subgroup Size: {min_subgroup_size}")?;
|
||||
writeln!(output, "\t\t Max Subgroup Size: {max_subgroup_size}")?;
|
||||
writeln!(output, "\t\t Max Push Constant Size: {max_push_constant_size}")?;
|
||||
writeln!(output, "\t\t Min Uniform Buffer Offset Alignment: {min_uniform_buffer_offset_alignment}")?;
|
||||
writeln!(output, "\t\t Min Storage Buffer Offset Alignment: {min_storage_buffer_offset_alignment}")?;
|
||||
|
@ -890,6 +890,30 @@ bitflags::bitflags! {
|
||||
///
|
||||
/// This is a native only feature.
|
||||
const SHADER_INT64 = 1 << 55;
|
||||
/// Allows compute and fragment shaders to use the subgroup operation built-ins
|
||||
///
|
||||
/// Supported Platforms:
|
||||
/// - Vulkan
|
||||
/// - DX12
|
||||
/// - Metal
|
||||
///
|
||||
/// This is a native only feature.
|
||||
const SUBGROUP = 1 << 56;
|
||||
/// Allows vertex shaders to use the subgroup operation built-ins
|
||||
///
|
||||
/// Supported Platforms:
|
||||
/// - Vulkan
|
||||
///
|
||||
/// This is a native only feature.
|
||||
const SUBGROUP_VERTEX = 1 << 57;
|
||||
/// Allows shaders to use the subgroup barrier
|
||||
///
|
||||
/// Supported Platforms:
|
||||
/// - Vulkan
|
||||
/// - Metal
|
||||
///
|
||||
/// This is a native only feature.
|
||||
const SUBGROUP_BARRIER = 1 << 58;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1136,6 +1160,11 @@ pub struct Limits {
|
||||
/// The maximum value for each dimension of a `ComputePass::dispatch(x, y, z)` operation.
|
||||
/// Defaults to 65535. Higher is "better".
|
||||
pub max_compute_workgroups_per_dimension: u32,
|
||||
|
||||
/// Minimal number of invocations in a subgroup. Higher is "better".
|
||||
pub min_subgroup_size: u32,
|
||||
/// Maximal number of invocations in a subgroup. Lower is "better".
|
||||
pub max_subgroup_size: u32,
|
||||
/// Amount of storage available for push constants in bytes. Defaults to 0. Higher is "better".
|
||||
/// Requesting more than 0 during device creation requires [`Features::PUSH_CONSTANTS`] to be enabled.
|
||||
///
|
||||
@ -1146,7 +1175,6 @@ pub struct Limits {
|
||||
/// - OpenGL doesn't natively support push constants; they are emulated with uniforms,
|
||||
/// so this number is less useful but likely 256.
|
||||
pub max_push_constant_size: u32,
|
||||
|
||||
/// Maximum number of live non-sampler bindings.
|
||||
///
|
||||
/// This limit only affects the d3d12 backend. Using a large number will allow the device
|
||||
@ -1187,6 +1215,8 @@ impl Default for Limits {
|
||||
max_compute_workgroup_size_y: 256,
|
||||
max_compute_workgroup_size_z: 64,
|
||||
max_compute_workgroups_per_dimension: 65535,
|
||||
min_subgroup_size: 0,
|
||||
max_subgroup_size: 0,
|
||||
max_push_constant_size: 0,
|
||||
max_non_sampler_bindings: 1_000_000,
|
||||
}
|
||||
@ -1218,6 +1248,8 @@ impl Limits {
|
||||
/// max_vertex_buffers: 8,
|
||||
/// max_vertex_attributes: 16,
|
||||
/// max_vertex_buffer_array_stride: 2048,
|
||||
/// min_subgroup_size: 0,
|
||||
/// max_subgroup_size: 0,
|
||||
/// max_push_constant_size: 0,
|
||||
/// min_uniform_buffer_offset_alignment: 256,
|
||||
/// min_storage_buffer_offset_alignment: 256,
|
||||
@ -1254,6 +1286,8 @@ impl Limits {
|
||||
max_vertex_buffers: 8,
|
||||
max_vertex_attributes: 16,
|
||||
max_vertex_buffer_array_stride: 2048,
|
||||
min_subgroup_size: 0,
|
||||
max_subgroup_size: 0,
|
||||
max_push_constant_size: 0,
|
||||
min_uniform_buffer_offset_alignment: 256,
|
||||
min_storage_buffer_offset_alignment: 256,
|
||||
@ -1296,6 +1330,8 @@ impl Limits {
|
||||
/// max_vertex_buffers: 8,
|
||||
/// max_vertex_attributes: 16,
|
||||
/// max_vertex_buffer_array_stride: 255, // +
|
||||
/// min_subgroup_size: 0,
|
||||
/// max_subgroup_size: 0,
|
||||
/// max_push_constant_size: 0,
|
||||
/// min_uniform_buffer_offset_alignment: 256,
|
||||
/// min_storage_buffer_offset_alignment: 256,
|
||||
@ -1326,6 +1362,8 @@ impl Limits {
|
||||
max_compute_workgroup_size_y: 0,
|
||||
max_compute_workgroup_size_z: 0,
|
||||
max_compute_workgroups_per_dimension: 0,
|
||||
min_subgroup_size: 0,
|
||||
max_subgroup_size: 0,
|
||||
|
||||
// Value supported by Intel Celeron B830 on Windows (OpenGL 3.1)
|
||||
max_inter_stage_shader_components: 31,
|
||||
@ -1418,6 +1456,10 @@ impl Limits {
|
||||
compare!(max_vertex_buffers, Less);
|
||||
compare!(max_vertex_attributes, Less);
|
||||
compare!(max_vertex_buffer_array_stride, Less);
|
||||
if self.min_subgroup_size > 0 && self.max_subgroup_size > 0 {
|
||||
compare!(min_subgroup_size, Greater);
|
||||
compare!(max_subgroup_size, Less);
|
||||
}
|
||||
compare!(max_push_constant_size, Less);
|
||||
compare!(min_uniform_buffer_offset_alignment, Greater);
|
||||
compare!(min_storage_buffer_offset_alignment, Greater);
|
||||
|
@ -737,6 +737,8 @@ fn map_wgt_limits(limits: webgpu_sys::GpuSupportedLimits) -> wgt::Limits {
|
||||
max_compute_workgroup_size_z: limits.max_compute_workgroup_size_z(),
|
||||
max_compute_workgroups_per_dimension: limits.max_compute_workgroups_per_dimension(),
|
||||
// The following are not part of WebGPU
|
||||
min_subgroup_size: wgt::Limits::default().min_subgroup_size,
|
||||
max_subgroup_size: wgt::Limits::default().max_subgroup_size,
|
||||
max_push_constant_size: wgt::Limits::default().max_push_constant_size,
|
||||
max_non_sampler_bindings: wgt::Limits::default().max_non_sampler_bindings,
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user