Basic support for WGSL

This commit is contained in:
Dzmitry Malyshau 2020-06-17 13:27:36 -04:00
parent fc2dd481b2
commit 35a1dc3076
7 changed files with 81 additions and 56 deletions

5
Cargo.lock generated
View File

@ -484,8 +484,7 @@ dependencies = [
[[package]]
name = "gfx-memory"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2eed6cda674d9cd4d92229102dbd544292124533d236904f987e9afab456137"
source = "git+https://github.com/gfx-rs/gfx-extras?rev=438353c3f75368c12024ad2fc03cbeb15f351fd9#438353c3f75368c12024ad2fc03cbeb15f351fd9"
dependencies = [
"fxhash",
"gfx-hal",
@ -737,7 +736,7 @@ dependencies = [
[[package]]
name = "naga"
version = "0.1.0"
source = "git+https://github.com/gfx-rs/naga?rev=bce6358eb1026c13d2f1c6d365af37afe8869a86#bce6358eb1026c13d2f1c6d365af37afe8869a86"
source = "git+https://github.com/gfx-rs/naga?rev=e3aea9619865b16a24164d46ab29cca36ad7daf2#e3aea9619865b16a24164d46ab29cca36ad7daf2"
dependencies = [
"bitflags",
"fxhash",

View File

@ -314,12 +314,7 @@ impl GlobalExt for wgc::hub::Global<IdentityPassThroughFactory> {
.collect::<Vec<_>>();
self.device_create_shader_module::<B>(
device,
&wgc::pipeline::ShaderModuleDescriptor {
code: wgc::U32Array {
bytes: spv.as_ptr(),
length: spv.len(),
},
},
wgc::pipeline::ShaderModuleSource::SpirV(&spv),
id,
);
}

View File

@ -41,7 +41,7 @@ vec_map = "0.8.1"
[dependencies.naga]
git = "https://github.com/gfx-rs/naga"
rev = "bce6358eb1026c13d2f1c6d365af37afe8869a86"
rev = "e3aea9619865b16a24164d46ab29cca36ad7daf2"
[dependencies.gfx-memory]
git = "https://github.com/gfx-rs/gfx-extras"

View File

@ -1749,22 +1749,26 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
pub fn device_create_shader_module<B: GfxBackend>(
&self,
device_id: id::DeviceId,
desc: &pipeline::ShaderModuleDescriptor,
source: pipeline::ShaderModuleSource,
id_in: Input<G, id::ShaderModuleId>,
) -> id::ShaderModuleId {
let hub = B::hub(self);
let mut token = Token::root();
let (device_guard, mut token) = hub.devices.read(&mut token);
let device = &device_guard[device_id];
let spv_owned;
let spv_flags = if cfg!(debug_assertions) {
naga::back::spv::WriterFlags::DEBUG
} else {
naga::back::spv::WriterFlags::empty()
};
let spv = unsafe { slice::from_raw_parts(desc.code.bytes, desc.code.length) };
let raw = unsafe { device.raw.create_shader_module(spv).unwrap() };
let (spv, naga) = match source {
pipeline::ShaderModuleSource::SpirV(spv) => {
let module = if device.private_features.shader_validation {
// Parse the given shader code and store its representation.
let spv_iter = spv.into_iter().cloned();
let mut parser = naga::front::spirv::Parser::new(spv_iter);
parser
naga::front::spv::Parser::new(spv_iter)
.parse()
.map_err(|err| {
log::warn!("Failed to parse shader SPIR-V code: {:?}", err);
@ -1774,13 +1778,40 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
} else {
None
};
(spv, module)
}
pipeline::ShaderModuleSource::Wgsl(code) => {
let module = naga::front::wgsl::parse_str(code).unwrap();
spv_owned = naga::back::spv::Writer::new(&module.header, spv_flags).write(&module);
(
spv_owned.as_slice(),
if device.private_features.shader_validation {
Some(module)
} else {
None
},
)
}
pipeline::ShaderModuleSource::Naga(module) => {
spv_owned = naga::back::spv::Writer::new(&module.header, spv_flags).write(&module);
(
spv_owned.as_slice(),
if device.private_features.shader_validation {
Some(module)
} else {
None
},
)
}
};
let shader = pipeline::ShaderModule {
raw,
raw: unsafe { device.raw.create_shader_module(spv).unwrap() },
device_id: Stored {
value: device_id,
ref_count: device.life_guard.add_ref(),
},
module,
module: naga,
};
let id = hub
@ -1791,7 +1822,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
Some(ref trace) => {
let mut trace = trace.lock();
let data = trace.make_binary("spv", unsafe {
slice::from_raw_parts(desc.code.bytes as *const u8, desc.code.length * 4)
slice::from_raw_parts(spv.as_ptr() as *const u8, spv.len() * 4)
});
trace.add(trace::Action::CreateShaderModule { id, data });
}

View File

@ -163,13 +163,6 @@ struct Stored<T> {
ref_count: RefCount,
}
#[repr(C)]
#[derive(Debug)]
pub struct U32Array {
pub bytes: *const u32,
pub length: usize,
}
#[derive(Clone, Copy, Debug)]
struct PrivateFeatures {
shader_validation: bool,

View File

@ -6,7 +6,7 @@ use crate::{
device::RenderPassContext,
id::{DeviceId, PipelineLayoutId, ShaderModuleId},
validation::StageError,
LifeGuard, RawString, RefCount, Stored, U32Array,
LifeGuard, RawString, RefCount, Stored,
};
use std::borrow::Borrow;
use wgt::{
@ -33,8 +33,10 @@ pub struct VertexStateDescriptor {
#[repr(C)]
#[derive(Debug)]
pub struct ShaderModuleDescriptor {
pub code: U32Array,
pub enum ShaderModuleSource<'a> {
SpirV(&'a [u32]),
Wgsl(&'a str),
Naga(naga::Module),
}
#[derive(Debug)]

View File

@ -23,10 +23,12 @@ pub enum BindingError {
WrongTextureViewDimension { dim: spirv::Dim, is_array: bool },
/// The component type of a sampled texture doesn't match the shader.
WrongTextureComponentType(Option<naga::ScalarKind>),
/// Texture sampling capability doesn't match with the shader.
/// Texture sampling capability doesn't match the shader.
WrongTextureSampled,
/// The multisampled flag doesn't match.
/// The multisampled flag doesn't match the shader.
WrongTextureMultisampled,
/// The comparison flag doesn't match the shader.
WrongSamplerComparison,
}
#[derive(Clone, Debug)]
@ -57,12 +59,12 @@ pub enum StageError {
fn get_aligned_type_size(
module: &naga::Module,
inner: &naga::TypeInner,
handle: naga::Handle<naga::Type>,
allow_unbound: bool,
) -> wgt::BufferAddress {
use naga::TypeInner as Ti;
//TODO: take alignment into account!
match *inner {
match module.types[handle].inner {
Ti::Scalar { kind: _, width } => width as wgt::BufferAddress / 8,
Ti::Vector {
size,
@ -82,14 +84,15 @@ fn get_aligned_type_size(
Ti::Array {
base,
size: naga::ArraySize::Static(size),
} => {
size as wgt::BufferAddress
* get_aligned_type_size(module, &module.types[base].inner, false)
}
} => size as wgt::BufferAddress * get_aligned_type_size(module, base, false),
Ti::Array {
base,
size: naga::ArraySize::Dynamic,
} if allow_unbound => get_aligned_type_size(module, &module.types[base].inner, false),
} if allow_unbound => get_aligned_type_size(module, base, false),
Ti::Struct { ref members } => members
.iter()
.map(|member| get_aligned_type_size(module, member.ty, false))
.sum(),
_ => panic!("Unexpected struct field"),
}
}
@ -128,11 +131,7 @@ fn check_binding(
};
let mut actual_size = 0;
for (i, member) in members.iter().enumerate() {
actual_size += get_aligned_type_size(
module,
&module.types[member.ty].inner,
i + 1 == members.len(),
);
actual_size += get_aligned_type_size(module, member.ty, i + 1 == members.len());
}
match min_size {
Some(non_zero) if non_zero.get() < actual_size => {
@ -142,8 +141,14 @@ fn check_binding(
}
allowed_usage
}
naga::TypeInner::Sampler => match entry.ty {
BindingType::Sampler { .. } => naga::GlobalUse::empty(),
naga::TypeInner::Sampler { comparison } => match entry.ty {
BindingType::Sampler { comparison: cmp } => {
if cmp == comparison {
naga::GlobalUse::empty()
} else {
return Err(BindingError::WrongSamplerComparison);
}
}
_ => return Err(BindingError::WrongType),
},
naga::TypeInner::Image { base, dim, flags } => {