From 32c2ea58bc61d436b8899ebae57c04cd7d50b4ec Mon Sep 17 00:00:00 2001 From: charles-r-earp Date: Tue, 10 Nov 2020 00:10:21 -0800 Subject: [PATCH] Compute Shaders + ExecutionMode's (#195) * Created examples/ wgpu-example-compute-runner + wgpu-example-compute-shader. * Working compute example, can compile and run, set local_size. Validated changes do not break rendering example. * Added complete list of ExecutionMode's to be specified underneath ExecutionModel. Replaced SpirvAttribute::Entry ExecutionModel with an Entry struct, which includes a Vec of ExecutionMode, ExecutionModeExtra to be submitted in entry.rs. Compute example runs. Passes all tests. * Changed Cargo license info for compute examples. Simplified compute runner to be more similar to other wgpu example. Split of entry logic in symbol.rs to separate function. Fixed issue in builder/mod.rs. * Pulled in reorganization changes to crates + examples. In symbols.rs moved really_unsafe_ignore_bitcasts to its own Symbol match. In entry.rs, entry stubs now return the fn_id, so that entry_stub() can add the execution modes in one place. Passed all tests. * cargo fmt * Removed duplicate examples. Fixed cargo fmt bug in compute runner. --- Cargo.lock | 32 +- Cargo.toml | 2 + .../src/codegen_cx/declare.rs | 10 +- .../src/codegen_cx/entry.rs | 34 +- crates/rustc_codegen_spirv/src/symbols.rs | 446 +++++++++++++++--- examples/runners/wgpu-compute/Cargo.toml | 23 + examples/runners/wgpu-compute/build.rs | 8 + examples/runners/wgpu-compute/src/main.rs | 77 +++ examples/shaders/compute-shader/Cargo.toml | 12 + examples/shaders/compute-shader/src/lib.rs | 16 + 10 files changed, 575 insertions(+), 85 deletions(-) create mode 100644 examples/runners/wgpu-compute/Cargo.toml create mode 100644 examples/runners/wgpu-compute/build.rs create mode 100644 examples/runners/wgpu-compute/src/main.rs create mode 100644 examples/shaders/compute-shader/Cargo.toml create mode 100644 examples/shaders/compute-shader/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index ec4b5905a2..228cd656fb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -271,6 +271,13 @@ dependencies = [ "objc", ] +[[package]] +name = "compute-shader" +version = "0.1.0" +dependencies = [ + "spirv-std", +] + [[package]] name = "console_error_panic_hook" version = "0.1.6" @@ -584,6 +591,17 @@ dependencies = [ "winit", ] +[[package]] +name = "example-runner-wgpu-compute" +version = "0.1.0" +dependencies = [ + "futures", + "rspirv 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", + "spirv-builder", + "wasm-bindgen-futures", + "wgpu", +] + [[package]] name = "filetime" version = "0.2.12" @@ -1853,6 +1871,18 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "rspirv" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31b11de7481ced3d22182a2be7670323028b38a7cc3457e470f91d144947fba4" +dependencies = [ + "derive_more", + "fxhash", + "num-traits", + "spirv_headers 1.5.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "rspirv" version = "0.7.0" @@ -1870,7 +1900,7 @@ version = "0.1.0" dependencies = [ "bimap", "pretty_assertions", - "rspirv", + "rspirv 0.7.0 (git+https://github.com/gfx-rs/rspirv.git?rev=f11f8797bd4df2d1d22cf10767b39a5119c57551)", "spirv-tools", "tar", "tempfile", diff --git a/Cargo.toml b/Cargo.toml index ba6c57d754..b5310cd7e1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,8 +3,10 @@ members = [ "examples/runners/cpu", "examples/runners/ash", "examples/runners/wgpu", + "examples/runners/wgpu-compute", "examples/shaders/sky-shader", "examples/shaders/simplest-shader", + "examples/shaders/compute-shader", "crates/rustc_codegen_spirv", "crates/spirv-builder", "crates/spirv-std", diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs index 1bf8f976c4..9e24a029eb 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs @@ -211,13 +211,9 @@ impl<'tcx> PreDefineMethods<'tcx> for CodegenCx<'tcx> { for attr in parse_attrs(self, self.tcx.get_attrs(instance.def_id())) { match attr { - SpirvAttribute::Entry(execution_model) => self.entry_stub( - &instance, - &fn_abi, - declared, - human_name.clone(), - execution_model, - ), + SpirvAttribute::Entry(entry) => { + self.entry_stub(&instance, &fn_abi, declared, human_name.clone(), entry) + } SpirvAttribute::ReallyUnsafeIgnoreBitcasts => { self.really_unsafe_ignore_bitcasts .borrow_mut() diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs index 39910dae3f..4968d0c8d4 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs @@ -1,11 +1,9 @@ use super::CodegenCx; use crate::builder_spirv::SpirvValue; use crate::spirv_type::SpirvType; -use crate::symbols::{parse_attrs, SpirvAttribute}; +use crate::symbols::{parse_attrs, Entry, SpirvAttribute}; use rspirv::dr::Operand; -use rspirv::spirv::{ - Decoration, ExecutionMode, ExecutionModel, FunctionControl, StorageClass, Word, -}; +use rspirv::spirv::{Decoration, ExecutionModel, FunctionControl, StorageClass, Word}; use rustc_hir::Param; use rustc_middle::ty::{Instance, Ty}; use rustc_target::abi::call::{FnAbi, PassMode}; @@ -22,7 +20,7 @@ impl<'tcx> CodegenCx<'tcx> { fn_abi: &FnAbi<'_, Ty<'_>>, entry_func: SpirvValue, name: String, - execution_model: ExecutionModel, + entry: Entry, ) { let local_id = match instance.def_id().as_local() { Some(id) => id, @@ -54,11 +52,19 @@ impl<'tcx> CodegenCx<'tcx> { ), ) } - if execution_model == ExecutionModel::Kernel { - self.kernel_entry_stub(entry_func, name, execution_model); + let execution_model = entry.execution_model; + let fn_id = if execution_model == ExecutionModel::Kernel { + self.kernel_entry_stub(entry_func, name, execution_model) } else { - self.shader_entry_stub(entry_func, body.params, name, execution_model); - } + self.shader_entry_stub(entry_func, body.params, name, execution_model) + }; + let mut emit = self.emit_global(); + entry + .execution_modes + .iter() + .for_each(|(execution_mode, execution_mode_extra)| { + emit.execution_mode(fn_id, *execution_mode, execution_mode_extra); + }); } fn shader_entry_stub( @@ -67,7 +73,7 @@ impl<'tcx> CodegenCx<'tcx> { hir_params: &[Param<'tcx>], name: String, execution_model: ExecutionModel, - ) { + ) -> Word { let void = SpirvType::Void.def(self); let fn_void_void = SpirvType::Function { return_type: void, @@ -120,10 +126,7 @@ impl<'tcx> CodegenCx<'tcx> { .collect() }; emit.entry_point(execution_model, fn_id, name, interface); - if execution_model == ExecutionModel::Fragment { - // TODO: Make this configurable. - emit.execution_mode(fn_id, ExecutionMode::OriginUpperLeft, &[]); - } + fn_id } fn declare_parameter( @@ -205,7 +208,7 @@ impl<'tcx> CodegenCx<'tcx> { entry_func: SpirvValue, name: String, execution_model: ExecutionModel, - ) { + ) -> Word { let (entry_func_return, entry_func_args) = match self.lookup_type(entry_func.ty) { SpirvType::Function { return_type, @@ -241,5 +244,6 @@ impl<'tcx> CodegenCx<'tcx> { emit.end_function().unwrap(); emit.entry_point(execution_model, fn_id, name, &[]); + fn_id } } diff --git a/crates/rustc_codegen_spirv/src/symbols.rs b/crates/rustc_codegen_spirv/src/symbols.rs index 1f2f00bbdc..cd7b7a05a6 100644 --- a/crates/rustc_codegen_spirv/src/symbols.rs +++ b/crates/rustc_codegen_spirv/src/symbols.rs @@ -1,7 +1,7 @@ use crate::codegen_cx::CodegenCx; -use rspirv::spirv::{BuiltIn, ExecutionModel, StorageClass}; +use rspirv::spirv::{BuiltIn, ExecutionMode, ExecutionModel, StorageClass}; use rustc_ast::ast::{AttrKind, Attribute, Lit, LitIntType, LitKind, NestedMetaItem}; -use rustc_span::symbol::Symbol; +use rustc_span::symbol::{Ident, Symbol}; use std::collections::HashMap; /// Various places in the codebase (mostly attribute parsing) need to compare rustc Symbols to particular keywords. @@ -24,7 +24,9 @@ pub struct Symbols { pub spirv15: Symbol, descriptor_set: Symbol, binding: Symbol, + really_unsafe_ignore_bitcasts: Symbol, attributes: HashMap, + execution_modes: HashMap, } const BUILTINS: &[(&str, BuiltIn)] = { @@ -189,6 +191,118 @@ const EXECUTION_MODELS: &[(&str, ExecutionModel)] = { ] }; +#[derive(Copy, Clone, Debug)] +enum ExecutionModeExtraDim { + None, + Value, + X, + Y, + Z, +} + +const EXECUTION_MODES: &[(&str, ExecutionMode, ExecutionModeExtraDim)] = { + use ExecutionMode::*; + use ExecutionModeExtraDim::*; + &[ + ("invocations", Invocations, Value), + ("spacing_equal", SpacingEqual, None), + ("spacing_fraction_even", SpacingFractionalEven, None), + ("spacing_fraction_odd", SpacingFractionalOdd, None), + ("vertex_order_cw", VertexOrderCw, None), + ("vertex_order_ccw", VertexOrderCcw, None), + ("pixel_center_integer", PixelCenterInteger, None), + ("orgin_upper_left", OriginUpperLeft, None), + ("origin_lower_left", OriginLowerLeft, None), + ("early_fragment_tests", EarlyFragmentTests, None), + ("point_mode", PointMode, None), + ("xfb", Xfb, None), + ("depth_replacing", DepthReplacing, None), + ("depth_greater", DepthGreater, None), + ("depth_less", DepthLess, None), + ("depth_unchanged", DepthUnchanged, None), + ("local_size_x", LocalSize, X), + ("local_size_y", LocalSize, Y), + ("local_size_z", LocalSize, Z), + ("local_size_hint_x", LocalSizeHint, X), + ("local_size_hint_y", LocalSizeHint, Y), + ("local_size_hint_z", LocalSizeHint, Z), + ("input_points", InputPoints, None), + ("input_lines", InputLines, None), + ("input_lines_adjacency", InputLinesAdjacency, None), + ("triangles", Triangles, None), + ("input_triangles_adjacency", InputTrianglesAdjacency, None), + ("quads", Quads, None), + ("isolines", Isolines, None), + ("output_vertices", OutputVertices, Value), + ("output_points", OutputPoints, None), + ("output_line_strip", OutputLineStrip, None), + ("output_triangle_strip", OutputTriangleStrip, None), + ("vec_type_hint", VecTypeHint, Value), + ("contraction_off", ContractionOff, None), + ("initializer", Initializer, None), + ("finalizer", Finalizer, None), + ("subgroup_size", SubgroupSize, Value), + ("subgroups_per_workgroup", SubgroupsPerWorkgroup, Value), + ("subgroups_per_workgroup_id", SubgroupsPerWorkgroupId, Value), + ("local_size_id_x", LocalSizeId, X), + ("local_size_id_y", LocalSizeId, Y), + ("local_size_id_z", LocalSizeId, Z), + ("local_size_hint_id", LocalSizeHintId, Value), + ("post_depth_coverage", PostDepthCoverage, None), + ("denorm_preserve", DenormPreserve, None), + ("denorm_flush_to_zero", DenormFlushToZero, Value), + ( + "signed_zero_inf_nan_preserve", + SignedZeroInfNanPreserve, + Value, + ), + ("rounding_mode_rte", RoundingModeRTE, Value), + ("rounding_mode_rtz", RoundingModeRTZ, Value), + ("stencil_ref_replacing_ext", StencilRefReplacingEXT, None), + ("output_lines_nv", OutputLinesNV, None), + ("output_primitives_nv", OutputPrimitivesNV, Value), + ("derivative_group_quads_nv", DerivativeGroupQuadsNV, None), + ("output_triangles_nv", OutputTrianglesNV, None), + ( + "pixel_interlock_ordered_ext", + PixelInterlockOrderedEXT, + None, + ), + ( + "pixel_interlock_unordered_ext", + PixelInterlockUnorderedEXT, + None, + ), + ( + "sample_interlock_ordered_ext", + SampleInterlockOrderedEXT, + None, + ), + ( + "sample_interlock_unordered_ext", + SampleInterlockUnorderedEXT, + None, + ), + ( + "shading_rate_interlock_ordered_ext", + ShadingRateInterlockOrderedEXT, + None, + ), + ( + "shading_rate_interlock_unordered_ext", + ShadingRateInterlockUnorderedEXT, + None, + ), + // Reserved + /*("max_workgroup_size_intel_x", MaxWorkgroupSizeINTEL, X), + ("max_workgroup_size_intel_y", MaxWorkgroupSizeINTEL, Y), + ("max_workgroup_size_intel_z", MaxWorkgroupSizeINTEL, Z), + ("max_work_dim_intel", MaxWorkDimINTEL, Value), + ("no_global_offset_intel", NoGlobalOffsetINTEL, None), + ("num_simd_workitems_intel", NumSIMDWorkitemsINTEL, Value),*/ + ] +}; + impl Symbols { pub fn new() -> Self { let builtins = BUILTINS @@ -199,23 +313,23 @@ impl Symbols { .map(|&(a, b)| (a, SpirvAttribute::StorageClass(b))); let execution_models = EXECUTION_MODELS .iter() - .map(|&(a, b)| (a, SpirvAttribute::Entry(b))); - let custom = std::iter::once(( - "really_unsafe_ignore_bitcasts", - SpirvAttribute::ReallyUnsafeIgnoreBitcasts, - )); + .map(|&(a, b)| (a, SpirvAttribute::Entry(b.into()))); let attributes_iter = builtins .chain(storage_classes) .chain(execution_models) - .chain(custom) .map(|(a, b)| (Symbol::intern(a), b)); let mut attributes = HashMap::new(); - for attr in attributes_iter { - let old = attributes.insert(attr.0, attr.1); + attributes_iter.for_each(|(a, b)| { + let old = attributes.insert(a, b); // `.collect()` into a HashMap does not error on duplicates, so manually write out the // loop here to error on duplicates. assert!(old.is_none()); - } + }); + let mut execution_modes = HashMap::new(); + EXECUTION_MODES.iter().for_each(|(key, mode, dim)| { + let old = execution_modes.insert(Symbol::intern(key), (*mode, *dim)); + assert!(old.is_none()); + }); Self { spirv: Symbol::intern("spirv"), spirv_std: Symbol::intern("spirv_std"), @@ -231,7 +345,46 @@ impl Symbols { spirv15: Symbol::intern("spirv1.5"), descriptor_set: Symbol::intern("descriptor_set"), binding: Symbol::intern("binding"), + really_unsafe_ignore_bitcasts: Symbol::intern("really_unsafe_ignore_bitcasts"), attributes, + execution_modes, + } + } +} + +#[derive(Copy, Clone, Debug)] +pub struct ExecutionModeExtra { + args: [u32; 3], + len: u8, +} + +impl ExecutionModeExtra { + fn new(args: impl AsRef<[u32]>) -> Self { + let _args = args.as_ref(); + let mut args = [0; 3]; + args[.._args.len()].copy_from_slice(_args); + let len = _args.len() as u8; + Self { args, len } + } +} + +impl AsRef<[u32]> for ExecutionModeExtra { + fn as_ref(&self) -> &[u32] { + &self.args[..self.len as _] + } +} + +#[derive(Clone, Debug)] +pub struct Entry { + pub execution_model: ExecutionModel, + pub execution_modes: Vec<(ExecutionMode, ExecutionModeExtra)>, +} + +impl From for Entry { + fn from(execution_model: ExecutionModel) -> Self { + Self { + execution_model, + execution_modes: Vec::new(), } } } @@ -240,7 +393,7 @@ impl Symbols { pub enum SpirvAttribute { Builtin(BuiltIn), StorageClass(StorageClass), - Entry(ExecutionModel), + Entry(Entry), DescriptorSet(u32), Binding(u32), ReallyUnsafeIgnoreBitcasts, @@ -254,61 +407,70 @@ pub fn parse_attrs( cx: &CodegenCx<'_>, attrs: &[Attribute], ) -> impl Iterator { - let result = attrs.iter().flat_map(|attr| { - let is_spirv = match attr.kind { - AttrKind::Normal(ref item) => { - // TODO: We ignore the rest of the path. Is this right? - let last = item.path.segments.last(); - last.map_or(false, |seg| seg.ident.name == cx.sym.spirv) - } - AttrKind::DocComment(..) => false, - }; - let args = if !is_spirv { - // Use an empty vec here to return empty - Vec::new() - } else if let Some(args) = attr.meta_item_list() { - args - } else { - cx.tcx.sess.span_err( - attr.span, - "#[spirv(..)] attribute must have at least one argument", - ); - Vec::new() - }; - args.into_iter().filter_map(move |ref arg| { - if arg.has_name(cx.sym.descriptor_set) { - match parse_attr_int_value(cx, arg) { - Some(x) => Some(SpirvAttribute::DescriptorSet(x)), - None => None, - } - } else if arg.has_name(cx.sym.binding) { - match parse_attr_int_value(cx, arg) { - Some(x) => Some(SpirvAttribute::Binding(x)), - None => None, + let result = + attrs.iter().flat_map(|attr| { + let is_spirv = match attr.kind { + AttrKind::Normal(ref item) => { + // TODO: We ignore the rest of the path. Is this right? + let last = item.path.segments.last(); + last.map_or(false, |seg| seg.ident.name == cx.sym.spirv) } + AttrKind::DocComment(..) => false, + }; + let args = if !is_spirv { + // Use an empty vec here to return empty + Vec::new() + } else if let Some(args) = attr.meta_item_list() { + args } else { - let name = match arg.ident() { - Some(i) => i, - None => { - cx.tcx.sess.span_err( - arg.span(), - "#[spirv(..)] attribute argument must be single identifier", - ); - return None; + cx.tcx.sess.span_err( + attr.span, + "#[spirv(..)] attribute must have at least one argument", + ); + Vec::new() + }; + args.into_iter().filter_map(move |ref arg| { + if arg.has_name(cx.sym.really_unsafe_ignore_bitcasts) { + Some(SpirvAttribute::ReallyUnsafeIgnoreBitcasts) + } else if arg.has_name(cx.sym.descriptor_set) { + match parse_attr_int_value(cx, arg) { + Some(x) => Some(SpirvAttribute::DescriptorSet(x)), + None => None, } - }; - match cx.sym.attributes.get(&name.name) { - Some(a) => Some(a.clone()), - None => { - cx.tcx - .sess - .span_err(name.span, "unknown argument to spirv attribute"); - None + } else if arg.has_name(cx.sym.binding) { + match parse_attr_int_value(cx, arg) { + Some(x) => Some(SpirvAttribute::Binding(x)), + None => None, } + } else { + let name = match arg.ident() { + Some(i) => i, + None => { + cx.tcx.sess.span_err( + arg.span(), + "#[spirv(..)] attribute argument must be single identifier", + ); + return None; + } + }; + cx.sym + .attributes + .get(&name.name) + .map(|a| match a { + SpirvAttribute::Entry(entry) => SpirvAttribute::Entry( + parse_entry_attrs(cx, arg, &name, entry.execution_model), + ), + _ => a.clone(), + }) + .or_else(|| { + cx.tcx + .sess + .span_err(name.span, "unknown argument to spirv attribute"); + None + }) } - } - }) - }); + }) + }); // lifetimes are hard :( result.collect::>().into_iter() } @@ -336,3 +498,163 @@ fn parse_attr_int_value(cx: &CodegenCx<'_>, arg: &NestedMetaItem) -> Option } } } + +// for a given entry, gather up the additional attributes +// in this case ExecutionMode's, some have extra arguments +// others are specified with x, y, or z components +// ie #[spirv(fragment(origin_lower_left))] or #[spirv(gl_compute(local_size_x=64, local_size_y=8))] +fn parse_entry_attrs( + cx: &CodegenCx<'_>, + arg: &NestedMetaItem, + name: &Ident, + execution_model: ExecutionModel, +) -> Entry { + use ExecutionMode::*; + use ExecutionModel::*; + let mut entry = Entry::from(execution_model); + let mut origin_mode: Option = None; + let mut local_size: Option<[u32; 3]> = None; + let mut local_size_hint: Option<[u32; 3]> = None; + // Reserved + //let mut max_workgroup_size_intel: Option<[u32; 3]> = None; + if let Some(attrs) = arg.meta_item_list() { + attrs.iter().for_each(|attr| { + if let Some(attr_name) = attr.ident() { + if let Some((execution_mode, extra_dim)) = + cx.sym.execution_modes.get(&attr_name.name) + { + use ExecutionModeExtraDim::*; + let val = match extra_dim { + None => Option::None, + _ => parse_attr_int_value(cx, attr), + }; + match execution_mode { + OriginUpperLeft | OriginLowerLeft => { + origin_mode.replace(*execution_mode); + } + LocalSize => { + let val = val.unwrap(); + if local_size.is_none() { + local_size.replace([1, 1, 1]); + } + let local_size = local_size.as_mut().unwrap(); + match extra_dim { + X => { + local_size[0] = val; + } + Y => { + local_size[1] = val; + } + Z => { + local_size[2] = val; + } + _ => unreachable!(), + } + } + LocalSizeHint => { + let val = val.unwrap(); + if local_size_hint.is_none() { + local_size_hint.replace([1, 1, 1]); + } + let local_size_hint = local_size_hint.as_mut().unwrap(); + match extra_dim { + X => { + local_size_hint[0] = val; + } + Y => { + local_size_hint[1] = val; + } + Z => { + local_size_hint[2] = val; + } + _ => unreachable!(), + } + } + // Reserved + /*MaxWorkgroupSizeINTEL => { + let val = val.unwrap(); + if max_workgroup_size_intel.is_none() { + max_workgroup_size_intel.replace([1, 1, 1]); + } + let max_workgroup_size_intel = max_workgroup_size_intel.as_mut() + .unwrap(); + match extra_dim { + X => { + max_workgroup_size_intel[0] = val; + }, + Y => { + max_workgroup_size_intel[1] = val; + }, + Z => { + max_workgroup_size_intel[2] = val; + }, + _ => unreachable!(), + } + },*/ + _ => { + if let Some(val) = val { + entry + .execution_modes + .push((*execution_mode, ExecutionModeExtra::new([val]))); + } else { + entry + .execution_modes + .push((*execution_mode, ExecutionModeExtra::new([]))); + } + } + } + } else { + cx.tcx.sess.span_err( + attr_name.span, + &format!( + "#[spirv({}(..))] unknown attribute argument {}", + name.name.to_ident_string(), + attr_name.name.to_ident_string() + ), + ); + } + } else { + cx.tcx.sess.span_err( + arg.span(), + &format!( + "#[spirv({}(..))] attribute argument must be single identifier", + name.name.to_ident_string() + ), + ); + } + }); + } + match entry.execution_model { + Fragment => { + let origin_mode = origin_mode.unwrap_or(OriginUpperLeft); + entry + .execution_modes + .push((origin_mode, ExecutionModeExtra::new([]))); + } + GLCompute => { + let local_size = local_size.unwrap_or([1, 1, 1]); + entry + .execution_modes + .push((LocalSize, ExecutionModeExtra::new(local_size))); + } + Kernel => { + if let Some(local_size) = local_size { + entry + .execution_modes + .push((LocalSize, ExecutionModeExtra::new(local_size))); + } + if let Some(local_size_hint) = local_size_hint { + entry + .execution_modes + .push((LocalSizeHint, ExecutionModeExtra::new(local_size_hint))); + } + // Reserved + /*if let Some(max_workgroup_size_intel) = max_workgroup_size_intel { + entry.execution_modes.push((MaxWorkgroupSizeINTEL, ExecutionModeExtra::new(max_workgroup_size_intel))); + }*/ + } + //TODO: Cover more defaults + _ => {} + } + entry +} diff --git a/examples/runners/wgpu-compute/Cargo.toml b/examples/runners/wgpu-compute/Cargo.toml new file mode 100644 index 0000000000..a4b03017d3 --- /dev/null +++ b/examples/runners/wgpu-compute/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "example-runner-wgpu-compute" +version = "0.1.0" +authors = ["Embark "] +edition = "2018" +license = "MIT OR Apache-2.0" + +# See rustc_codegen_spirv/Cargo.toml for details on these features +[features] +default = ["use-compiled-tools"] +use-installed-tools = ["spirv-builder/use-installed-tools"] +use-compiled-tools = ["spirv-builder/use-compiled-tools"] + +[dependencies] +wgpu = "0.6.0" +futures = { version = "0.3", default-features = false, features = ["std", "executor"] } +rspirv = "0.7.0" + +[build-dependencies] +spirv-builder = { path = "../../../crates/spirv-builder" } + +[target.'cfg(target_arch = "wasm32")'.dependencies] +wasm-bindgen-futures = "0.4.18" diff --git a/examples/runners/wgpu-compute/build.rs b/examples/runners/wgpu-compute/build.rs new file mode 100644 index 0000000000..dd10208fde --- /dev/null +++ b/examples/runners/wgpu-compute/build.rs @@ -0,0 +1,8 @@ +use spirv_builder::SpirvBuilder; +use std::error::Error; + +fn main() -> Result<(), Box> { + // This will set the env var `compute_shader.spv` to a spir-v file that can be include!()'d + SpirvBuilder::new("../../shaders/compute-shader").build()?; + Ok(()) +} diff --git a/examples/runners/wgpu-compute/src/main.rs b/examples/runners/wgpu-compute/src/main.rs new file mode 100644 index 0000000000..c62f0189ff --- /dev/null +++ b/examples/runners/wgpu-compute/src/main.rs @@ -0,0 +1,77 @@ +fn create_device_queue() -> (wgpu::Device, wgpu::Queue) { + async fn create_device_queue_async() -> (wgpu::Device, wgpu::Queue) { + let instance = wgpu::Instance::new(wgpu::BackendBit::PRIMARY); + let adapter = instance + .request_adapter(&wgpu::RequestAdapterOptions { + power_preference: wgpu::PowerPreference::default(), + compatible_surface: None, + }) + .await + .expect("Failed to find an appropriate adapter"); + + adapter + .request_device( + &wgpu::DeviceDescriptor { + features: wgpu::Features::empty(), + limits: wgpu::Limits::default(), + shader_validation: true, + }, + None, + ) + .await + .expect("Failed to create device") + } + #[cfg(not(target_arch = "wasm32"))] + { + return futures::executor::block_on(create_device_queue_async()); + }; + #[cfg(target_arch = "wasm32")] + { + return wasm_bindgen_futures::spawn_local(create_device_queue_async()); + }; +} + +fn main() { + let (device, queue) = create_device_queue(); + + // Load the shaders from disk + let module = device.create_shader_module(wgpu::include_spirv!(env!("compute_shader.spv"))); + + let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + label: None, + entries: &[], + }); + + let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: None, + bind_group_layouts: &[&bind_group_layout], + push_constant_ranges: &[], + }); + + let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: Some(&pipeline_layout), + compute_stage: wgpu::ProgrammableStageDescriptor { + module: &module, + entry_point: "main_cs", + }, + }); + + let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &bind_group_layout, + entries: &[], + }); + + let mut encoder = + device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + + { + let mut cpass = encoder.begin_compute_pass(); + cpass.set_bind_group(0, &bind_group, &[]); + cpass.set_pipeline(&compute_pipeline); + cpass.dispatch(1, 1, 1); + } + + queue.submit(Some(encoder.finish())); +} diff --git a/examples/shaders/compute-shader/Cargo.toml b/examples/shaders/compute-shader/Cargo.toml new file mode 100644 index 0000000000..746178e8a7 --- /dev/null +++ b/examples/shaders/compute-shader/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "compute-shader" +version = "0.1.0" +authors = ["Embark "] +edition = "2018" +license = "MIT OR Apache-2.0" + +[lib] +crate-type = ["dylib"] + +[dependencies] +spirv-std = { path = "../../../crates/spirv-std" } diff --git a/examples/shaders/compute-shader/src/lib.rs b/examples/shaders/compute-shader/src/lib.rs new file mode 100644 index 0000000000..7d4cd860a7 --- /dev/null +++ b/examples/shaders/compute-shader/src/lib.rs @@ -0,0 +1,16 @@ +#![cfg_attr(target_arch = "spirv", no_std)] +#![feature(lang_items)] +#![feature(register_attr)] +#![register_attr(spirv)] + +extern crate spirv_std; + +#[cfg(all(not(test), target_arch = "spirv"))] +#[panic_handler] +fn panic(_: &core::panic::PanicInfo) -> ! { + loop {} +} + +#[allow(unused_attributes)] +#[spirv(gl_compute)] +pub fn main_cs() {}