feat: implement F16 support in shaders

Co-Authored-By: Erich Gubler <erichdongubler@gmail.com>
This commit is contained in:
FL33TW00D 2024-10-04 10:12:04 -04:00 committed by Erich Gubler
parent a8c9356023
commit 68bb221d40
39 changed files with 1151 additions and 63 deletions

17
Cargo.lock generated
View File

@ -1444,11 +1444,14 @@ dependencies = [
[[package]]
name = "half"
version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
source = "git+https://github.com/FL33TW00D/half-rs.git?branch=feature/arbitrary#6bc4bea632269b53ccb6666a8508edc25fba9f3e"
dependencies = [
"arbitrary",
"bytemuck",
"cfg-if",
"crunchy",
"num-traits",
"serde",
]
[[package]]
@ -1703,6 +1706,12 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "libm"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]]
name = "libredox"
version = "0.1.3"
@ -1885,11 +1894,13 @@ dependencies = [
"codespan-reporting",
"diff",
"env_logger",
"half",
"hexf-parse",
"hlsl-snapshots",
"indexmap",
"itertools",
"log",
"num-traits",
"petgraph",
"pp-rs",
"ron",
@ -2026,6 +2037,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
"libm",
]
[[package]]
@ -3643,6 +3655,7 @@ dependencies = [
"flume",
"getrandom",
"glam",
"half",
"ktx2",
"log",
"nanorand",

View File

@ -193,6 +193,7 @@ ndk-sys = "0.5.0"
#gpu-alloc = { path = "../gpu-alloc/gpu-alloc" }
[patch.crates-io]
half = { git = "https://github.com/FL33TW00D/half-rs.git", branch = "feature/arbitrary" }
#glow = { path = "../glow" }
#web-sys = { path = "../wasm-bindgen/crates/web-sys" }
#js-sys = { path = "../wasm-bindgen/crates/js-sys" }

View File

@ -35,6 +35,7 @@ encase = { workspace = true, features = ["glam"] }
flume.workspace = true
getrandom.workspace = true
glam.workspace = true
half = { version = "2.1.0", features = ["bytemuck"] }
ktx2.workspace = true
log.workspace = true
nanorand.workspace = true

View File

@ -17,6 +17,7 @@ pub mod mipmap;
pub mod msaa_line;
pub mod render_to_texture;
pub mod repeated_compute;
pub mod shader_f16;
pub mod shadow;
pub mod skybox;
pub mod srgb_blend;

View File

@ -146,6 +146,12 @@ const EXAMPLES: &[ExampleDesc] = &[
webgl: false, // No RODS
webgpu: true,
},
ExampleDesc {
name: "shader-f16",
function: wgpu_examples::shader_f16::main,
webgl: false, // No RODS
webgpu: true,
},
];
fn get_example_name() -> Option<String> {

View File

@ -0,0 +1,9 @@
# shader-f16
Demonstrate the ability to perform compute in F16 using wgpu.
## To Run
```
RUST_LOG=shader_f16 cargo run --bin wgpu-examples shader_f16
```

View File

@ -0,0 +1,189 @@
use half::f16;
use std::{borrow::Cow, str::FromStr};
use wgpu::util::DeviceExt;
#[cfg_attr(test, allow(dead_code))]
async fn run() {
let numbers = if std::env::args().len() <= 2 {
let default = vec![
f16::from_f32(27.),
f16::from_f32(7.),
f16::from_f32(5.),
f16::from_f32(3.),
];
println!("No numbers were provided, defaulting to {default:?}");
default
} else {
std::env::args()
.skip(2)
.map(|s| f16::from_str(&s).expect("You must pass a list of positive integers!"))
.collect()
};
let steps = execute_gpu(&numbers).await.unwrap();
println!("Steps: [{:?}]", steps);
#[cfg(target_arch = "wasm32")]
log::info!("Steps: [{:?}]", steps);
}
#[cfg_attr(test, allow(dead_code))]
async fn execute_gpu(numbers: &[f16]) -> Option<Vec<f16>> {
// Instantiates instance of WebGPU
let instance = wgpu::Instance::default();
// `request_adapter` instantiates the general connection to the GPU
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions::default())
.await?;
// `request_device` instantiates the feature specific connection to the GPU, defining some parameters,
// `features` being the available features.
let (device, queue) = adapter
.request_device(
&wgpu::DeviceDescriptor {
label: None,
required_features: wgpu::Features::SHADER_F16,
required_limits: wgpu::Limits::downlevel_defaults(),
memory_hints: Default::default(),
},
None,
)
.await
.unwrap();
execute_gpu_inner(&device, &queue, numbers).await
}
async fn execute_gpu_inner(
device: &wgpu::Device,
queue: &wgpu::Queue,
numbers: &[f16],
) -> Option<Vec<f16>> {
// Loads the shader from WGSL
let cs_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))),
});
// Gets the size in bytes of the buffer.
let size = std::mem::size_of_val(numbers) as wgpu::BufferAddress;
// Instantiates buffer without data.
// `usage` of buffer specifies how it can be used:
// `BufferUsages::MAP_READ` allows it to be read (outside the shader).
// `BufferUsages::COPY_DST` allows it to be the destination of the copy.
let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
// Instantiates buffer with data (`numbers`).
// Usage allowing the buffer to be:
// A storage buffer (can be bound within a bind group and thus available to a shader).
// The destination of a copy.
// The source of a copy.
let storage_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Storage Buffer"),
contents: bytemuck::cast_slice(numbers),
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_DST
| wgpu::BufferUsages::COPY_SRC,
});
// A bind group defines how buffers are accessed by shaders.
// It is to WebGPU what a descriptor set is to Vulkan.
// `binding` here refers to the `binding` of a buffer in the shader (`layout(set = 0, binding = 0) buffer`).
// A pipeline specifies the operation of a shader
// Instantiates the pipeline.
let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: None,
module: &cs_module,
entry_point: None,
compilation_options: Default::default(),
cache: None,
});
// Instantiates the bind group, once again specifying the binding of buffers.
let bind_group_layout = compute_pipeline.get_bind_group_layout(0);
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &bind_group_layout,
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: storage_buffer.as_entire_binding(),
}],
});
// A command encoder executes one or many pipelines.
// It is to WebGPU what a command buffer is to Vulkan.
let mut encoder =
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
{
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: None,
timestamp_writes: None,
});
cpass.set_pipeline(&compute_pipeline);
cpass.set_bind_group(0, Some(&bind_group), &[]);
cpass.insert_debug_marker("compute collatz iterations");
cpass.dispatch_workgroups(numbers.len() as u32, 1, 1); // Number of cells to run, the (x,y,z) size of item being processed
}
// Sets adds copy operation to command encoder.
// Will copy data from storage buffer on GPU to staging buffer on CPU.
encoder.copy_buffer_to_buffer(&storage_buffer, 0, &staging_buffer, 0, size);
// Submits command encoder for processing
queue.submit(Some(encoder.finish()));
// Note that we're not calling `.await` here.
let buffer_slice = staging_buffer.slice(..);
// Sets the buffer up for mapping, sending over the result of the mapping back to us when it is finished.
let (sender, receiver) = flume::bounded(1);
buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
// Poll the device in a blocking manner so that our future resolves.
// In an actual application, `device.poll(...)` should
// be called in an event loop or on another thread.
device.poll(wgpu::Maintain::wait()).panic_on_timeout();
// Awaits until `buffer_future` can be read from
if let Ok(Ok(())) = receiver.recv_async().await {
// Gets contents of buffer
let data = buffer_slice.get_mapped_range();
// Since contents are got in bytes, this converts these bytes back to u32
let result = bytemuck::cast_slice(&data).to_vec();
// With the current interface, we have to make sure all mapped views are
// dropped before we unmap the buffer.
drop(data);
staging_buffer.unmap(); // Unmaps buffer from memory
// If you are familiar with C++ these 2 lines can be thought of similarly to:
// delete myPointer;
// myPointer = NULL;
// It effectively frees the memory
// Returns data from buffer
Some(result)
} else {
panic!("failed to run compute on gpu!")
}
}
pub fn main() {
#[cfg(not(target_arch = "wasm32"))]
{
env_logger::init();
pollster::block_on(run());
}
#[cfg(target_arch = "wasm32")]
{
std::panic::set_hook(Box::new(console_error_panic_hook::hook));
console_log::init().expect("could not initialize logger");
wasm_bindgen_futures::spawn_local(run());
}
}

View File

@ -0,0 +1,9 @@
enable f16;
@group(0) @binding(0)
var<storage, read_write> values: array<vec4<f16>>; // this is used as both values and output for convenience
@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
values[global_id.x] = fma(values[0], values[0], values[0]);
}

View File

@ -41,8 +41,8 @@ msl-out = []
## If you want to enable MSL output it regardless of the target platform, use `naga/msl-out`.
msl-out-if-target-apple = []
serialize = ["dep:serde", "bitflags/serde", "indexmap/serde"]
deserialize = ["dep:serde", "bitflags/serde", "indexmap/serde"]
serialize = ["dep:serde", "bitflags/serde", "indexmap/serde", "half/serde"]
deserialize = ["dep:serde", "bitflags/serde", "indexmap/serde", "half/serde"]
arbitrary = ["dep:arbitrary", "bitflags/arbitrary", "indexmap/arbitrary"]
spv-in = ["dep:petgraph", "dep:spirv"]
spv-out = ["dep:spirv"]
@ -82,6 +82,9 @@ petgraph = { version = "0.6", optional = true }
pp-rs = { version = "0.2.1", optional = true }
hexf-parse = { version = "0.2.1", optional = true }
unicode-xid = { version = "0.2.6", optional = true }
# TODO: remove `[patch]` entry in workspace `Cargo.toml` for `half` after we upstream `arbitrary` support
half = { version = "2.4.1", features = ["arbitrary", "num-traits"] }
num-traits = "0.2"
[build-dependencies]
cfg_aliases.workspace = true

View File

@ -2647,6 +2647,9 @@ impl<'a, W: Write> Writer<'a, W> {
// decimal part even it's zero which is needed for a valid glsl float constant
crate::Literal::F64(value) => write!(self.out, "{value:?}LF")?,
crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
crate::Literal::F16(_) => {
return Err(Error::Custom("GLSL has no 16-bit float type".into()));
}
// Unsigned integers need a `u` at the end
//
// While `core` doesn't necessarily need it, it's allowed and since `es` needs it we

View File

@ -2383,6 +2383,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// decimal part even it's zero
crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
crate::Literal::U32(value) => write!(self.out, "{value}u")?,
crate::Literal::I32(value) => write!(self.out, "{value}")?,
crate::Literal::U64(value) => write!(self.out, "{value}uL")?,

View File

@ -6,6 +6,8 @@ use crate::{
proc::{self, NameKey, TypeResolution},
valid, FastHashMap, FastHashSet,
};
use half::f16;
use num_traits::real::Real;
#[cfg(test)]
use std::ptr;
use std::{
@ -390,8 +392,12 @@ impl crate::Scalar {
match self {
Self {
kind: Sk::Float,
width: _,
width: 4,
} => "float",
Self {
kind: Sk::Float,
width: 2,
} => "half",
Self {
kind: Sk::Sint,
width: 4,
@ -1385,6 +1391,21 @@ impl<W: Write> Writer<W> {
crate::Literal::F64(_) => {
return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
}
crate::Literal::F16(value) => {
if value.is_infinite() {
let sign = if value.is_sign_negative() { "-" } else { "" };
write!(self.out, "{sign}INFINITY")?;
} else if value.is_nan() {
write!(self.out, "NAN")?;
} else {
let suffix = if value.fract() == f16::from_f32(0.0) {
".0h"
} else {
"h"
};
write!(self.out, "{value}{suffix}")?;
}
}
crate::Literal::F32(value) => {
if value.is_infinite() {
let sign = if value.is_sign_negative() { "-" } else { "" };

View File

@ -406,6 +406,10 @@ impl super::Instruction {
instruction
}
pub(super) fn constant_16bit(result_type_id: Word, id: Word, low: Word) -> Self {
Self::constant(result_type_id, id, &[low])
}
pub(super) fn constant_32bit(result_type_id: Word, id: Word, value: Word) -> Self {
Self::constant(result_type_id, id, &[value])
}

View File

@ -799,6 +799,15 @@ impl Writer {
if bits == 64 {
self.capabilities_used.insert(spirv::Capability::Float64);
}
if bits == 16 {
self.capabilities_used.insert(spirv::Capability::Float16);
self.capabilities_used
.insert(spirv::Capability::StorageBuffer16BitAccess);
self.capabilities_used
.insert(spirv::Capability::UniformAndStorageBuffer16BitAccess);
self.capabilities_used
.insert(spirv::Capability::StorageInputOutput16);
}
Instruction::type_float(id, bits)
}
Sk::Bool => Instruction::type_bool(id),
@ -1148,6 +1157,10 @@ impl Writer {
Instruction::constant_64bit(type_id, id, bits as u32, (bits >> 32) as u32)
}
crate::Literal::F32(value) => Instruction::constant_32bit(type_id, id, value.to_bits()),
crate::Literal::F16(value) => {
let low = value.to_bits();
Instruction::constant_16bit(type_id, id, low as u32)
}
crate::Literal::U32(value) => Instruction::constant_32bit(type_id, id, value),
crate::Literal::I32(value) => Instruction::constant_32bit(type_id, id, value as u32),
crate::Literal::U64(value) => {

View File

@ -1232,6 +1232,7 @@ impl<W: Write> Writer<W> {
match expressions[expr] {
Expression::Literal(literal) => match literal {
crate::Literal::F16(value) => write!(self.out, "{value}h")?,
crate::Literal::F32(value) => write!(self.out, "{value}f")?,
crate::Literal::U32(value) => write!(self.out, "{value}u")?,
crate::Literal::I32(value) => {
@ -1995,6 +1996,10 @@ const fn scalar_kind_str(scalar: crate::Scalar) -> &'static str {
kind: Sk::Float,
width: 4,
} => "f32",
Scalar {
kind: Sk::Float,
width: 2,
} => "f16",
Scalar {
kind: Sk::Sint,
width: 4,

View File

@ -36,6 +36,7 @@ mod null;
use convert::*;
pub use error::Error;
use function::*;
use half::f16;
use indexmap::IndexSet;
use crate::{
@ -5484,6 +5485,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
}) => {
let low = self.next()?;
match width {
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#Literal
// If a numeric types bit width is less than 32-bits, the value appears in the low-order bits of the word.
2 => crate::Literal::F16(f16::from_bits(low as u16)),
4 => crate::Literal::F32(f32::from_bits(low)),
8 => {
inst.expect(5)?;

View File

@ -144,8 +144,6 @@ pub enum NumberError {
Invalid,
#[error("numeric literal not representable by target type")]
NotRepresentable,
#[error("unimplemented f16 type")]
UnimplementedF16,
}
#[derive(Copy, Clone, Debug, PartialEq)]

View File

@ -1831,6 +1831,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let expr: Typed<crate::Expression> = match *expr {
ast::Expression::Literal(literal) => {
let literal = match literal {
ast::Literal::Number(Number::F16(f)) => crate::Literal::F16(f),
ast::Literal::Number(Number::F32(f)) => crate::Literal::F32(f),
ast::Literal::Number(Number::I32(i)) => crate::Literal::I32(i),
ast::Literal::Number(Number::U32(u)) => crate::Literal::U32(u),

View File

@ -114,7 +114,10 @@ pub fn map_storage_format(word: &str, span: Span) -> Result<crate::StorageFormat
pub fn get_scalar_type(word: &str) -> Option<Scalar> {
use crate::ScalarKind as Sk;
match word {
// "f16" => Some(Scalar { kind: Sk::Float, width: 2 }),
"f16" => Some(Scalar {
kind: Sk::Float,
width: 2,
}),
"f32" => Some(Scalar {
kind: Sk::Float,
width: 4,

View File

@ -5,24 +5,29 @@ use crate::{front::wgsl::error::Error, Span};
/// Tracks the status of every enable-extension known to Naga.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EnableExtensions {}
pub struct EnableExtensions {
/// Whether `enable f16;` was written earlier in the shader module.
f16: bool,
}
impl EnableExtensions {
pub(crate) const fn empty() -> Self {
Self {}
Self { f16: false }
}
/// Add an enable-extension to the set requested by a module.
#[allow(unreachable_code)]
pub(crate) fn add(&mut self, ext: ImplementedEnableExtension) {
let _field: &mut bool = match ext {};
*_field = true;
let field = match ext {
ImplementedEnableExtension::F16 => &mut self.f16,
};
*field = true;
}
/// Query whether an enable-extension tracked here has been requested.
#[allow(unused)]
pub(crate) const fn contains(&self, ext: ImplementedEnableExtension) -> bool {
match ext {}
match ext {
ImplementedEnableExtension::F16 => self.f16,
}
}
}
@ -37,7 +42,6 @@ impl Default for EnableExtensions {
/// WGSL spec.: <https://www.w3.org/TR/WGSL/#enable-extensions-sec>
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
pub enum EnableExtension {
#[allow(unused)]
Implemented(ImplementedEnableExtension),
Unimplemented(UnimplementedEnableExtension),
}
@ -50,7 +54,7 @@ impl EnableExtension {
/// Convert from a sentinel word in WGSL into its associated [`EnableExtension`], if possible.
pub(crate) fn from_ident(word: &str, span: Span) -> Result<Self, Error<'_>> {
Ok(match word {
Self::F16 => Self::Unimplemented(UnimplementedEnableExtension::F16),
Self::F16 => Self::Implemented(ImplementedEnableExtension::F16),
Self::CLIP_DISTANCES => {
Self::Unimplemented(UnimplementedEnableExtension::ClipDistances)
}
@ -64,9 +68,10 @@ impl EnableExtension {
/// Maps this [`EnableExtension`] into the sentinel word associated with it in WGSL.
pub const fn to_ident(self) -> &'static str {
match self {
Self::Implemented(kind) => match kind {},
Self::Implemented(kind) => match kind {
ImplementedEnableExtension::F16 => Self::F16,
},
Self::Unimplemented(kind) => match kind {
UnimplementedEnableExtension::F16 => Self::F16,
UnimplementedEnableExtension::ClipDistances => Self::CLIP_DISTANCES,
UnimplementedEnableExtension::DualSourceBlending => Self::DUAL_SOURCE_BLENDING,
},
@ -76,17 +81,24 @@ impl EnableExtension {
/// A variant of [`EnableExtension::Implemented`].
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
pub enum ImplementedEnableExtension {}
/// A variant of [`EnableExtension::Unimplemented`].
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
pub enum UnimplementedEnableExtension {
pub enum ImplementedEnableExtension {
/// Enables `f16`/`half` primitive support in all shader languages.
///
/// In the WGSL standard, this corresponds to [`enable f16;`].
///
/// [`enable f16;`]: https://www.w3.org/TR/WGSL/#extension-f16
F16,
}
impl From<ImplementedEnableExtension> for EnableExtension {
fn from(value: ImplementedEnableExtension) -> Self {
Self::Implemented(value)
}
}
/// A variant of [`EnableExtension::Unimplemented`].
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
pub enum UnimplementedEnableExtension {
/// Enables the `clip_distances` variable in WGSL.
///
/// In the WGSL standard, this corresponds to [`enable clip_distances;`].
@ -104,7 +116,6 @@ pub enum UnimplementedEnableExtension {
impl UnimplementedEnableExtension {
pub(crate) const fn tracking_issue_num(self) -> u16 {
match self {
Self::F16 => 4384,
Self::ClipDistances => 6236,
Self::DualSourceBlending => 6402,
}

View File

@ -1,6 +1,8 @@
use super::{number::consume_number, Error, ExpectedToken};
use crate::front::wgsl::error::NumberError;
use crate::front::wgsl::parse::directive::enable_extension::EnableExtensions;
use crate::front::wgsl::parse::directive::enable_extension::{
EnableExtensions, ImplementedEnableExtension,
};
use crate::front::wgsl::parse::{conv, Number};
use crate::front::wgsl::Scalar;
use crate::Span;
@ -395,14 +397,26 @@ impl<'a> Lexer<'a> {
/// Parses a generic scalar type, for example `<f32>`.
pub(in crate::front::wgsl) fn next_scalar_generic(&mut self) -> Result<Scalar, Error<'a>> {
self.expect_generic_paren('<')?;
let pair = match self.next() {
(Token::Word(word), span) => {
conv::get_scalar_type(word).ok_or(Error::UnknownScalarType(span))
}
let (scalar, span) = match self.next() {
(Token::Word(word), span) => conv::get_scalar_type(word)
.map(|scalar| (scalar, span))
.ok_or(Error::UnknownScalarType(span)),
(_, span) => Err(Error::UnknownScalarType(span)),
}?;
if matches!(scalar, Scalar::F16)
&& !self
.enable_extensions
.contains(ImplementedEnableExtension::F16)
{
return Err(Error::EnableExtensionNotEnabled {
span,
kind: ImplementedEnableExtension::F16.into(),
});
}
self.expect_generic_paren('>')?;
Ok(pair)
Ok(scalar)
}
/// Parses a generic scalar type, for example `<f32>`.
@ -412,14 +426,27 @@ impl<'a> Lexer<'a> {
&mut self,
) -> Result<(Scalar, Span), Error<'a>> {
self.expect_generic_paren('<')?;
let pair = match self.next() {
let (scalar, span) = match self.next() {
(Token::Word(word), span) => conv::get_scalar_type(word)
.map(|scalar| (scalar, span))
.ok_or(Error::UnknownScalarType(span)),
(_, span) => Err(Error::UnknownScalarType(span)),
}?;
if matches!(scalar, Scalar::F16)
&& !self
.enable_extensions
.contains(ImplementedEnableExtension::F16)
{
return Err(Error::EnableExtensionNotEnabled {
span,
kind: ImplementedEnableExtension::F16.into(),
});
}
self.expect_generic_paren('>')?;
Ok(pair)
Ok((scalar, span))
}
pub(in crate::front::wgsl) fn next_storage_access(
@ -477,6 +504,7 @@ fn sub_test(source: &str, expected_tokens: &[Token]) {
#[test]
fn test_numbers() {
use half::f16;
// WGSL spec examples //
// decimal integer
@ -501,14 +529,14 @@ fn test_numbers() {
Token::Number(Ok(Number::AbstractFloat(0.01))),
Token::Number(Ok(Number::AbstractFloat(12.34))),
Token::Number(Ok(Number::F32(0.))),
Token::Number(Err(NumberError::UnimplementedF16)),
Token::Number(Ok(Number::F16(f16::from_f32(0.)))),
Token::Number(Ok(Number::AbstractFloat(0.001))),
Token::Number(Ok(Number::AbstractFloat(43.75))),
Token::Number(Ok(Number::F32(16.))),
Token::Number(Ok(Number::AbstractFloat(0.1875))),
Token::Number(Err(NumberError::UnimplementedF16)),
Token::Number(Ok(Number::F16(f16::from_f32(0.75)))),
Token::Number(Ok(Number::AbstractFloat(0.12109375))),
Token::Number(Err(NumberError::UnimplementedF16)),
Token::Number(Ok(Number::F16(f16::from_f32(12.5)))),
],
);

View File

@ -1,7 +1,5 @@
use crate::front::wgsl::error::{Error, ExpectedToken};
use crate::front::wgsl::parse::directive::enable_extension::{
EnableExtension, EnableExtensions, UnimplementedEnableExtension,
};
use crate::front::wgsl::parse::directive::enable_extension::{EnableExtension, EnableExtensions};
use crate::front::wgsl::parse::directive::language_extension::LanguageExtension;
use crate::front::wgsl::parse::directive::DirectiveKind;
use crate::front::wgsl::parse::lexer::{Lexer, Token};
@ -670,15 +668,7 @@ impl Parser {
}
(Token::Number(res), span) => {
let _ = lexer.next();
let num = res.map_err(|err| match err {
super::error::NumberError::UnimplementedF16 => {
Error::EnableExtensionNotEnabled {
kind: EnableExtension::Unimplemented(UnimplementedEnableExtension::F16),
span,
}
}
err => Error::BadNumber(span, err),
})?;
let num = res.map_err(|err| Error::BadNumber(span, err))?;
ast::Expression::Literal(ast::Literal::Number(num))
}
(Token::Word("RAY_FLAG_NONE"), _) => {

View File

@ -1,5 +1,6 @@
use crate::front::wgsl::error::NumberError;
use crate::front::wgsl::parse::lexer::Token;
use half::f16;
/// When using this type assume no Abstract Int/Float for now
#[derive(Copy, Clone, Debug, PartialEq)]
@ -16,6 +17,8 @@ pub enum Number {
I64(i64),
/// Concrete u64
U64(u64),
/// Concrete f16
F16(f16),
/// Concrete f32
F32(f32),
/// Concrete f64
@ -362,7 +365,8 @@ fn parse_hex_float(input: &str, kind: Option<FloatKind>) -> Result<Number, Numbe
// can only be ParseHexfErrorKind::Inexact but we can't check since it's private
_ => Err(NumberError::NotRepresentable),
},
Some(FloatKind::F16) => Err(NumberError::UnimplementedF16),
// TODO: f16 is not supported by hexf_parse
Some(FloatKind::F16) => Err(NumberError::NotRepresentable),
Some(FloatKind::F32) => match hexf_parse::parse_hexf32(input, false) {
Ok(num) => Ok(Number::F32(num)),
// can only be ParseHexfErrorKind::Inexact but we can't check since it's private
@ -398,7 +402,12 @@ fn parse_dec_float(input: &str, kind: Option<FloatKind>) -> Result<Number, Numbe
.then_some(Number::F64(num))
.ok_or(NumberError::NotRepresentable)
}
Some(FloatKind::F16) => Err(NumberError::UnimplementedF16),
Some(FloatKind::F16) => {
let num = input.parse::<f16>().unwrap(); // will never fail
num.is_finite()
.then_some(Number::F16(num))
.ok_or(NumberError::NotRepresentable)
}
}
}

View File

@ -268,6 +268,7 @@ pub use crate::arena::{Arena, Handle, Range, UniqueArena};
pub use crate::span::{SourceLocation, Span, SpanContext, WithSpan};
#[cfg(feature = "arbitrary")]
use arbitrary::Arbitrary;
use half::f16;
#[cfg(feature = "deserialize")]
use serde::Deserialize;
#[cfg(feature = "serialize")]
@ -871,6 +872,7 @@ pub enum Literal {
F64(f64),
/// May not be NaN or infinity.
F32(f32),
F16(f16),
U32(u32),
I32(i32),
U64(u64),

View File

@ -1,6 +1,8 @@
use std::iter;
use arrayvec::ArrayVec;
use half::f16;
use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero};
use crate::{
arena::{Arena, Handle, HandleVec, UniqueArena},
@ -199,6 +201,7 @@ gen_component_wise_extractor! {
literals: [
AbstractFloat => AbstractFloat: f64,
F32 => F32: f32,
F16 => F16: f16,
AbstractInt => AbstractInt: i64,
U32 => U32: u32,
I32 => I32: i32,
@ -219,6 +222,7 @@ gen_component_wise_extractor! {
literals: [
AbstractFloat => Abstract: f64,
F32 => F32: f32,
F16 => F16: f16,
],
scalar_kinds: [
Float,
@ -244,6 +248,7 @@ gen_component_wise_extractor! {
AbstractFloat => AbstractFloat: f64,
AbstractInt => AbstractInt: i64,
F32 => F32: f32,
F16 => F16: f16,
I32 => I32: i32,
],
scalar_kinds: [
@ -1088,6 +1093,7 @@ impl<'a> ConstantEvaluator<'a> {
component_wise_scalar(self, span, [arg], |args| match args {
Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.abs()])),
Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
Scalar::U32([e]) => Ok(Scalar::U32([e])), // TODO: just re-use the expression, ezpz
@ -1119,9 +1125,13 @@ impl<'a> ConstantEvaluator<'a> {
}
)
}
crate::MathFunction::Saturate => {
component_wise_float!(self, span, [arg], |e| { Ok([e.clamp(0., 1.)]) })
}
crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
Float::F16([e]) => Ok(Float::F16(
[e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
)),
Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
}),
// trigonometry
crate::MathFunction::Cos => {
@ -1198,6 +1208,9 @@ impl<'a> ConstantEvaluator<'a> {
component_wise_float(self, span, [arg], |e| match e {
Float::Abstract([e]) => Ok(Float::Abstract([round_ties_even(e)])),
Float::F32([e]) => Ok(Float::F32([(round_ties_even(e as f64) as f32)])),
Float::F16([e]) => {
Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
}
})
}
crate::MathFunction::Fract => {
@ -1243,15 +1256,27 @@ impl<'a> ConstantEvaluator<'a> {
)
}
crate::MathFunction::Step => {
component_wise_float!(self, span, [arg, arg1.unwrap()], |edge, x| {
Ok([if edge <= x { 1.0 } else { 0.0 }])
component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
Float::Abstract([edge, x]) => {
Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
}
Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
f16::one()
} else {
f16::zero()
}])),
})
}
crate::MathFunction::Sqrt => {
component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) })
}
crate::MathFunction::InverseSqrt => {
component_wise_float!(self, span, [arg], |e| { Ok([1. / e.sqrt()]) })
component_wise_float(self, span, [arg], |e| match e {
Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])),
})
}
// bits
@ -1529,6 +1554,7 @@ impl<'a> ConstantEvaluator<'a> {
Literal::I32(v) => v,
Literal::U32(v) => v as i32,
Literal::F32(v) => v as i32,
Literal::F16(v) => f16::to_i32(&v).unwrap(), //Only None on NaN or Inf
Literal::Bool(v) => v as i32,
Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
return make_error();
@ -1540,6 +1566,7 @@ impl<'a> ConstantEvaluator<'a> {
Literal::I32(v) => v as u32,
Literal::U32(v) => v,
Literal::F32(v) => v as u32,
Literal::F16(v) => f16::to_u32(&v).unwrap(), //Only None on NaN or Inf
Literal::Bool(v) => v as u32,
Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
return make_error();
@ -1555,6 +1582,7 @@ impl<'a> ConstantEvaluator<'a> {
Literal::F64(v) => v as i64,
Literal::I64(v) => v,
Literal::U64(v) => v as i64,
Literal::F16(v) => f16::to_i64(&v).unwrap(), //Only None on NaN or Inf
Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
}),
@ -1566,9 +1594,22 @@ impl<'a> ConstantEvaluator<'a> {
Literal::F64(v) => v as u64,
Literal::I64(v) => v as u64,
Literal::U64(v) => v,
Literal::F16(v) => f16::to_u64(&v).unwrap(), //Only None on NaN or Inf
Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
}),
Sc::F16 => Literal::F16(match literal {
Literal::F16(v) => v,
Literal::F32(v) => f16::from_f32(v),
Literal::F64(v) => f16::from_f64(v),
Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
Literal::I64(v) => f16::from_i64(v).unwrap(),
Literal::U64(v) => f16::from_u64(v).unwrap(),
Literal::I32(v) => f16::from_i32(v).unwrap(),
Literal::U32(v) => f16::from_u32(v).unwrap(),
Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
}),
Sc::F32 => Literal::F32(match literal {
Literal::I32(v) => v as f32,
Literal::U32(v) => v as f32,
@ -1577,12 +1618,14 @@ impl<'a> ConstantEvaluator<'a> {
Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
return make_error();
}
Literal::F16(v) => f16::to_f32(v),
Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
}),
Sc::F64 => Literal::F64(match literal {
Literal::I32(v) => v as f64,
Literal::U32(v) => v as f64,
Literal::F16(v) => f16::to_f64(v),
Literal::F32(v) => v as f64,
Literal::F64(v) => v,
Literal::Bool(v) => v as u32 as f64,
@ -1594,6 +1637,7 @@ impl<'a> ConstantEvaluator<'a> {
Literal::I32(v) => v != 0,
Literal::U32(v) => v != 0,
Literal::F32(v) => v != 0.0,
Literal::F16(v) => v != f16::zero(),
Literal::Bool(v) => v,
Literal::F64(_)
| Literal::I64(_)
@ -1743,6 +1787,7 @@ impl<'a> ConstantEvaluator<'a> {
UnaryOperator::Negate => match value {
Literal::I32(v) => Literal::I32(v.wrapping_neg()),
Literal::F32(v) => Literal::F32(-v),
Literal::F16(v) => Literal::F16(-v),
Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
_ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
@ -1881,6 +1926,14 @@ impl<'a> ConstantEvaluator<'a> {
BinaryOperator::Modulo => a % b,
_ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
}),
(Literal::F16(a), Literal::F16(b)) => Literal::F16(match op {
BinaryOperator::Add => a + b,
BinaryOperator::Subtract => a - b,
BinaryOperator::Multiply => a * b,
BinaryOperator::Divide => a / b,
BinaryOperator::Modulo => a % b,
_ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
}),
(Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
Literal::AbstractInt(match op {
BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
@ -2450,6 +2503,32 @@ impl TryFromAbstract<f64> for u64 {
}
}
impl TryFromAbstract<f64> for f16 {
fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
let f = f16::from_f64(value);
if f.is_infinite() {
return Err(ConstantEvaluatorError::AutomaticConversionLossy {
value: format!("{value:?}"),
to_type: "f16",
});
}
Ok(f)
}
}
impl TryFromAbstract<i64> for f16 {
fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
let f = f16::from_i64(value);
if f.is_none() {
return Err(ConstantEvaluatorError::AutomaticConversionLossy {
value: format!("{value:?}"),
to_type: "f16",
});
}
Ok(f.unwrap())
}
}
#[cfg(test)]
mod tests {
use std::vec;

View File

@ -90,6 +90,10 @@ impl super::Scalar {
kind: crate::ScalarKind::Uint,
width: 4,
};
pub const F16: Self = Self {
kind: crate::ScalarKind::Float,
width: 2,
};
pub const F32: Self = Self {
kind: crate::ScalarKind::Float,
width: 4,
@ -157,6 +161,7 @@ impl super::Scalar {
pub enum HashableLiteral {
F64(u64),
F32(u32),
F16(u16),
U32(u32),
I32(i32),
U64(u64),
@ -171,6 +176,7 @@ impl From<crate::Literal> for HashableLiteral {
match l {
crate::Literal::F64(v) => Self::F64(v.to_bits()),
crate::Literal::F32(v) => Self::F32(v.to_bits()),
crate::Literal::F16(v) => Self::F16(v.to_bits()),
crate::Literal::U32(v) => Self::U32(v),
crate::Literal::I32(v) => Self::I32(v),
crate::Literal::U64(v) => Self::U64(v),
@ -209,6 +215,7 @@ impl crate::Literal {
match *self {
Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
Self::F16(_) => 2,
Self::Bool(_) => crate::BOOL_WIDTH,
Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
}
@ -217,6 +224,7 @@ impl crate::Literal {
match *self {
Self::F64(_) => crate::Scalar::F64,
Self::F32(_) => crate::Scalar::F32,
Self::F16(_) => crate::Scalar::F16,
Self::U32(_) => crate::Scalar::U32,
Self::I32(_) => crate::Scalar::I32,
Self::U64(_) => crate::Scalar::U64,

View File

@ -143,6 +143,8 @@ bitflags::bitflags! {
const SHADER_INT64_ATOMIC_MIN_MAX = 0x80000;
/// Support for all atomic operations on 64-bit integers.
const SHADER_INT64_ATOMIC_ALL_OPS = 0x100000;
/// Support for 16-bit floating-point types.
const SHADER_FLOAT16 = 0x200000;
}
}

View File

@ -243,8 +243,8 @@ impl super::Validator {
pub(super) const fn check_width(&self, scalar: crate::Scalar) -> Result<(), WidthError> {
let good = match scalar.kind {
crate::ScalarKind::Bool => scalar.width == crate::BOOL_WIDTH,
crate::ScalarKind::Float => {
if scalar.width == 8 {
crate::ScalarKind::Float => match scalar.width {
8 => {
if !self.capabilities.contains(Capabilities::FLOAT64) {
return Err(WidthError::MissingCapability {
name: "f64",
@ -252,10 +252,18 @@ impl super::Validator {
});
}
true
} else {
scalar.width == 4
}
}
2 => {
if !self.capabilities.contains(Capabilities::SHADER_FLOAT16) {
return Err(WidthError::MissingCapability {
name: "f16",
flag: "FLOAT16",
});
}
true
}
_ => scalar.width == 4,
},
crate::ScalarKind::Sint => {
if scalar.width == 8 {
if !self.capabilities.contains(Capabilities::SHADER_INT64) {

View File

@ -0,0 +1,22 @@
(
god_mode: true,
spv: (
version: (1, 0),
),
hlsl: (
shader_model: V6_2,
binding_map: {},
fake_missing_bindings: true,
special_constants_binding: Some((space: 1, register: 0)),
push_constants_target: Some((space: 0, register: 0)),
zero_initialize_workgroup_memory: true,
),
msl: (
lang_version: (1, 0),
per_entry_point_map: {},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: true,
zero_initialize_workgroup_memory: true,
),
)

View File

@ -0,0 +1,87 @@
enable f16;
enable f16; //redundant directives are OK
var<private> private_variable: f16 = 1h;
const constant_variable: f16 = f16(15.2);
struct UniformCompatible {
// Other types
val_u32: u32,
val_i32: i32,
val_f32: f32,
// f16
val_f16: f16,
val_f16_2: vec2<f16>,
val_f16_3: vec3<f16>,
val_f16_4: vec4<f16>,
final_value: f16,
}
struct StorageCompatible {
val_f16_array_2: array<f16, 2>,
val_f16_array_2: array<f16, 2>,
}
@group(0) @binding(0)
var<uniform> input_uniform: UniformCompatible;
@group(0) @binding(1)
var<storage> input_storage: UniformCompatible;
@group(0) @binding(2)
var<storage> input_arrays: StorageCompatible;
@group(0) @binding(3)
var<storage, read_write> output: UniformCompatible;
@group(0) @binding(4)
var<storage, read_write> output_arrays: StorageCompatible;
fn f16_function(x: f16) -> f16 {
var val: f16 = f16(constant_variable);
// A number too big for f16
val += 1h - 33333h;
// Constructing an f16 from an AbstractInt
val += val + f16(5.);
// Constructing a f16 from other types and other types from f16.
val += f16(input_uniform.val_f32 + f32(val));
// Constructing a vec3<i64> from a i64
val += vec3<f16>(input_uniform.val_f16).z;
// Reading/writing to a uniform/storage buffer
output.val_f16 = input_uniform.val_f16 + input_storage.val_f16;
output.val_f16_2 = input_uniform.val_f16_2 + input_storage.val_f16_2;
output.val_f16_3 = input_uniform.val_f16_3 + input_storage.val_f16_3;
output.val_f16_4 = input_uniform.val_f16_4 + input_storage.val_f16_4;
output_arrays.val_f16_array_2 = input_arrays.val_f16_array_2;
// We make sure not to use 32 in these arguments, so it's clear in the results which are builtin
// constants based on the size of the type, and which are arguments.
// Numeric functions
val += abs(val);
val += clamp(val, val, val);
//val += countLeadingZeros(val);
//val += countOneBits(val);
//val += countTrailingZeros(val);
val += dot(vec2(val), vec2(val));
//val += extractBits(val, 15u, 28u);
//val += firstLeadingBit(val);
//val += firstTrailingBit(val);
//val += insertBits(val, 12li, 15u, 28u);
val += max(val, val);
val += min(val, val);
//val += reverseBits(val);
val += sign(val); // only for i64
// Make sure all the variables are used.
return f16(1.0);
}
@compute @workgroup_size(1)
fn main() {
output.final_value = f16_function(2h);
}

View File

@ -0,0 +1,107 @@
struct NagaConstants {
int first_vertex;
int first_instance;
uint other;
};
ConstantBuffer<NagaConstants> _NagaConstants: register(b0, space1);
struct UniformCompatible {
uint val_u32_;
int val_i32_;
float val_f32_;
half val_f16_;
half2 val_f16_2_;
int _pad5_0;
half3 val_f16_3_;
half4 val_f16_4_;
half final_value;
int _end_pad_0;
};
struct StorageCompatible {
half val_f16_array_2_[2];
half val_f16_array_2_1[2];
};
static const half constant_variable = 15.203125h;
static half private_variable = 1.0h;
cbuffer input_uniform : register(b0) { UniformCompatible input_uniform; }
ByteAddressBuffer input_storage : register(t1);
ByteAddressBuffer input_arrays : register(t2);
RWByteAddressBuffer output : register(u3);
RWByteAddressBuffer output_arrays : register(u4);
typedef half ret_Constructarray2_half_[2];
ret_Constructarray2_half_ Constructarray2_half_(half arg0, half arg1) {
half ret[2] = { arg0, arg1 };
return ret;
}
half f16_function(half x)
{
half val = 15.203125h;
half _expr6 = val;
val = (_expr6 + (1.0h - 33344.0h));
half _expr8 = val;
half _expr11 = val;
val = (_expr11 + (_expr8 + 5.0h));
float _expr15 = input_uniform.val_f32_;
half _expr16 = val;
half _expr20 = val;
val = (_expr20 + half((_expr15 + float(_expr16))));
half _expr24 = input_uniform.val_f16_;
half _expr27 = val;
val = (_expr27 + (_expr24).xxx.z);
half _expr33 = input_uniform.val_f16_;
half _expr36 = input_storage.Load<half>(12);
output.Store(12, (_expr33 + _expr36));
half2 _expr42 = input_uniform.val_f16_2_;
half2 _expr45 = input_storage.Load<half2>(16);
output.Store(16, (_expr42 + _expr45));
half3 _expr51 = input_uniform.val_f16_3_;
half3 _expr54 = input_storage.Load<half3>(24);
output.Store(24, (_expr51 + _expr54));
half4 _expr60 = input_uniform.val_f16_4_;
half4 _expr63 = input_storage.Load<half4>(32);
output.Store(32, (_expr60 + _expr63));
half _expr69[2] = Constructarray2_half_(input_arrays.Load<half>(0+0), input_arrays.Load<half>(0+2));
{
half _value2[2] = _expr69;
output_arrays.Store(0+0, _value2[0]);
output_arrays.Store(0+2, _value2[1]);
}
half _expr70 = val;
half _expr72 = val;
val = (_expr72 + abs(_expr70));
half _expr74 = val;
half _expr75 = val;
half _expr76 = val;
half _expr78 = val;
val = (_expr78 + clamp(_expr74, _expr75, _expr76));
half _expr80 = val;
half _expr82 = val;
half _expr85 = val;
val = (_expr85 + dot((_expr80).xx, (_expr82).xx));
half _expr87 = val;
half _expr88 = val;
half _expr90 = val;
val = (_expr90 + max(_expr87, _expr88));
half _expr92 = val;
half _expr93 = val;
half _expr95 = val;
val = (_expr95 + min(_expr92, _expr93));
half _expr97 = val;
half _expr99 = val;
val = (_expr99 + sign(_expr97));
return 1.0h;
}
[numthreads(1, 1, 1)]
void main()
{
const half _e3 = f16_function(2.0h);
output.Store(40, _e3);
return;
}

View File

@ -0,0 +1,12 @@
(
vertex:[
],
fragment:[
],
compute:[
(
entry_point:"main",
target_profile:"cs_6_2",
),
],
)

View File

@ -0,0 +1,99 @@
// language: metal1.0
#include <metal_stdlib>
#include <simd/simd.h>
using metal::uint;
struct UniformCompatible {
uint val_u32_;
int val_i32_;
float val_f32_;
half val_f16_;
char _pad4[2];
metal::half2 val_f16_2_;
char _pad5[4];
metal::half3 val_f16_3_;
metal::half4 val_f16_4_;
half final_value;
};
struct type_7 {
half inner[2];
};
struct StorageCompatible {
type_7 val_f16_array_2_;
type_7 val_f16_array_2_1;
};
constant half constant_variable = 15.203125;
half f16_function(
half x,
constant UniformCompatible& input_uniform,
device UniformCompatible const& input_storage,
device StorageCompatible const& input_arrays,
device UniformCompatible& output,
device StorageCompatible& output_arrays
) {
half val = 15.203125;
half _e6 = val;
val = _e6 + (1.0 - 33344.0);
half _e8 = val;
half _e11 = val;
val = _e11 + (_e8 + 5.0);
float _e15 = input_uniform.val_f32_;
half _e16 = val;
half _e20 = val;
val = _e20 + static_cast<half>(_e15 + static_cast<float>(_e16));
half _e24 = input_uniform.val_f16_;
half _e27 = val;
val = _e27 + metal::half3(_e24).z;
half _e33 = input_uniform.val_f16_;
half _e36 = input_storage.val_f16_;
output.val_f16_ = _e33 + _e36;
metal::half2 _e42 = input_uniform.val_f16_2_;
metal::half2 _e45 = input_storage.val_f16_2_;
output.val_f16_2_ = _e42 + _e45;
metal::half3 _e51 = input_uniform.val_f16_3_;
metal::half3 _e54 = input_storage.val_f16_3_;
output.val_f16_3_ = _e51 + _e54;
metal::half4 _e60 = input_uniform.val_f16_4_;
metal::half4 _e63 = input_storage.val_f16_4_;
output.val_f16_4_ = _e60 + _e63;
type_7 _e69 = input_arrays.val_f16_array_2_;
output_arrays.val_f16_array_2_ = _e69;
half _e70 = val;
half _e72 = val;
val = _e72 + metal::abs(_e70);
half _e74 = val;
half _e75 = val;
half _e76 = val;
half _e78 = val;
val = _e78 + metal::clamp(_e74, _e75, _e76);
half _e80 = val;
half _e82 = val;
half _e85 = val;
val = _e85 + metal::dot(metal::half2(_e80), metal::half2(_e82));
half _e87 = val;
half _e88 = val;
half _e90 = val;
val = _e90 + metal::max(_e87, _e88);
half _e92 = val;
half _e93 = val;
half _e95 = val;
val = _e95 + metal::min(_e92, _e93);
half _e97 = val;
half _e99 = val;
val = _e99 + metal::sign(_e97);
return 1.0;
}
kernel void main_(
constant UniformCompatible& input_uniform [[user(fake0)]]
, device UniformCompatible const& input_storage [[user(fake0)]]
, device StorageCompatible const& input_arrays [[user(fake0)]]
, device UniformCompatible& output [[user(fake0)]]
, device StorageCompatible& output_arrays [[user(fake0)]]
) {
half _e3 = f16_function(2.0, input_uniform, input_storage, input_arrays, output, output_arrays);
output.final_value = _e3;
return;
}

View File

@ -0,0 +1,220 @@
; SPIR-V
; Version: 1.0
; Generator: rspirv
; Bound: 157
OpCapability Shader
OpExtension "SPV_KHR_storage_buffer_storage_class"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %145 "main"
OpExecutionMode %145 LocalSize 1 1 1
OpMemberDecorate %10 0 Offset 0
OpMemberDecorate %10 1 Offset 4
OpMemberDecorate %10 2 Offset 8
OpMemberDecorate %10 3 Offset 12
OpMemberDecorate %10 4 Offset 16
OpMemberDecorate %10 5 Offset 24
OpMemberDecorate %10 6 Offset 32
OpMemberDecorate %10 7 Offset 40
OpDecorate %11 ArrayStride 2
OpMemberDecorate %13 0 Offset 0
OpMemberDecorate %13 1 Offset 4
OpDecorate %18 DescriptorSet 0
OpDecorate %18 Binding 0
OpDecorate %19 Block
OpMemberDecorate %19 0 Offset 0
OpDecorate %21 NonWritable
OpDecorate %21 DescriptorSet 0
OpDecorate %21 Binding 1
OpDecorate %22 Block
OpMemberDecorate %22 0 Offset 0
OpDecorate %24 NonWritable
OpDecorate %24 DescriptorSet 0
OpDecorate %24 Binding 2
OpDecorate %25 Block
OpMemberDecorate %25 0 Offset 0
OpDecorate %27 DescriptorSet 0
OpDecorate %27 Binding 3
OpDecorate %28 Block
OpMemberDecorate %28 0 Offset 0
OpDecorate %30 DescriptorSet 0
OpDecorate %30 Binding 4
OpDecorate %31 Block
OpMemberDecorate %31 0 Offset 0
%2 = OpTypeVoid
%3 = OpTypeFloat 16
%4 = OpTypeInt 32 0
%5 = OpTypeInt 32 1
%6 = OpTypeFloat 32
%7 = OpTypeVector %3 2
%8 = OpTypeVector %3 3
%9 = OpTypeVector %3 4
%10 = OpTypeStruct %4 %5 %6 %3 %7 %8 %9 %3
%12 = OpConstant %4 2
%11 = OpTypeArray %3 %12
%13 = OpTypeStruct %11 %11
%14 = OpConstant %3 2.1524e-41
%15 = OpConstant %3 2.7121e-41
%17 = OpTypePointer Private %3
%16 = OpVariable %17 Private %14
%19 = OpTypeStruct %10
%20 = OpTypePointer Uniform %19
%18 = OpVariable %20 Uniform
%22 = OpTypeStruct %10
%23 = OpTypePointer StorageBuffer %22
%21 = OpVariable %23 StorageBuffer
%25 = OpTypeStruct %13
%26 = OpTypePointer StorageBuffer %25
%24 = OpVariable %26 StorageBuffer
%28 = OpTypeStruct %10
%29 = OpTypePointer StorageBuffer %28
%27 = OpVariable %29 StorageBuffer
%31 = OpTypeStruct %13
%32 = OpTypePointer StorageBuffer %31
%30 = OpVariable %32 StorageBuffer
%36 = OpTypeFunction %3 %3
%37 = OpTypePointer Uniform %10
%38 = OpConstant %4 0
%40 = OpTypePointer StorageBuffer %10
%42 = OpTypePointer StorageBuffer %13
%46 = OpConstant %3 4.3073e-41
%47 = OpConstant %3 2.4753e-41
%49 = OpTypePointer Function %3
%58 = OpTypePointer Uniform %6
%67 = OpTypePointer Uniform %3
%68 = OpConstant %4 3
%75 = OpTypePointer StorageBuffer %3
%82 = OpTypePointer StorageBuffer %7
%83 = OpTypePointer Uniform %7
%84 = OpConstant %4 4
%91 = OpTypePointer StorageBuffer %8
%92 = OpTypePointer Uniform %8
%93 = OpConstant %4 5
%100 = OpTypePointer StorageBuffer %9
%101 = OpTypePointer Uniform %9
%102 = OpConstant %4 6
%109 = OpTypePointer StorageBuffer %11
%146 = OpTypeFunction %2
%152 = OpConstant %3 2.2959e-41
%155 = OpConstant %4 7
%35 = OpFunction %3 None %36
%34 = OpFunctionParameter %3
%33 = OpLabel
%48 = OpVariable %49 Function %15
%39 = OpAccessChain %37 %18 %38
%41 = OpAccessChain %40 %21 %38
%43 = OpAccessChain %42 %24 %38
%44 = OpAccessChain %40 %27 %38
%45 = OpAccessChain %42 %30 %38
OpBranch %50
%50 = OpLabel
%51 = OpFSub %3 %14 %46
%52 = OpLoad %3 %48
%53 = OpFAdd %3 %52 %51
OpStore %48 %53
%54 = OpLoad %3 %48
%55 = OpFAdd %3 %54 %47
%56 = OpLoad %3 %48
%57 = OpFAdd %3 %56 %55
OpStore %48 %57
%59 = OpAccessChain %58 %39 %12
%60 = OpLoad %6 %59
%61 = OpLoad %3 %48
%62 = OpFConvert %6 %61
%63 = OpFAdd %6 %60 %62
%64 = OpFConvert %3 %63
%65 = OpLoad %3 %48
%66 = OpFAdd %3 %65 %64
OpStore %48 %66
%69 = OpAccessChain %67 %39 %68
%70 = OpLoad %3 %69
%71 = OpCompositeConstruct %8 %70 %70 %70
%72 = OpCompositeExtract %3 %71 2
%73 = OpLoad %3 %48
%74 = OpFAdd %3 %73 %72
OpStore %48 %74
%76 = OpAccessChain %67 %39 %68
%77 = OpLoad %3 %76
%78 = OpAccessChain %75 %41 %68
%79 = OpLoad %3 %78
%80 = OpFAdd %3 %77 %79
%81 = OpAccessChain %75 %44 %68
OpStore %81 %80
%85 = OpAccessChain %83 %39 %84
%86 = OpLoad %7 %85
%87 = OpAccessChain %82 %41 %84
%88 = OpLoad %7 %87
%89 = OpFAdd %7 %86 %88
%90 = OpAccessChain %82 %44 %84
OpStore %90 %89
%94 = OpAccessChain %92 %39 %93
%95 = OpLoad %8 %94
%96 = OpAccessChain %91 %41 %93
%97 = OpLoad %8 %96
%98 = OpFAdd %8 %95 %97
%99 = OpAccessChain %91 %44 %93
OpStore %99 %98
%103 = OpAccessChain %101 %39 %102
%104 = OpLoad %9 %103
%105 = OpAccessChain %100 %41 %102
%106 = OpLoad %9 %105
%107 = OpFAdd %9 %104 %106
%108 = OpAccessChain %100 %44 %102
OpStore %108 %107
%110 = OpAccessChain %109 %43 %38
%111 = OpLoad %11 %110
%112 = OpAccessChain %109 %45 %38
OpStore %112 %111
%113 = OpLoad %3 %48
%114 = OpExtInst %3 %1 FAbs %113
%115 = OpLoad %3 %48
%116 = OpFAdd %3 %115 %114
OpStore %48 %116
%117 = OpLoad %3 %48
%118 = OpLoad %3 %48
%119 = OpLoad %3 %48
%120 = OpExtInst %3 %1 FClamp %117 %118 %119
%121 = OpLoad %3 %48
%122 = OpFAdd %3 %121 %120
OpStore %48 %122
%123 = OpLoad %3 %48
%124 = OpCompositeConstruct %7 %123 %123
%125 = OpLoad %3 %48
%126 = OpCompositeConstruct %7 %125 %125
%127 = OpDot %3 %124 %126
%128 = OpLoad %3 %48
%129 = OpFAdd %3 %128 %127
OpStore %48 %129
%130 = OpLoad %3 %48
%131 = OpLoad %3 %48
%132 = OpExtInst %3 %1 FMax %130 %131
%133 = OpLoad %3 %48
%134 = OpFAdd %3 %133 %132
OpStore %48 %134
%135 = OpLoad %3 %48
%136 = OpLoad %3 %48
%137 = OpExtInst %3 %1 FMin %135 %136
%138 = OpLoad %3 %48
%139 = OpFAdd %3 %138 %137
OpStore %48 %139
%140 = OpLoad %3 %48
%141 = OpExtInst %3 %1 FSign %140
%142 = OpLoad %3 %48
%143 = OpFAdd %3 %142 %141
OpStore %48 %143
OpReturnValue %14
OpFunctionEnd
%145 = OpFunction %2 None %146
%144 = OpLabel
%147 = OpAccessChain %37 %18 %38
%148 = OpAccessChain %40 %21 %38
%149 = OpAccessChain %42 %24 %38
%150 = OpAccessChain %40 %27 %38
%151 = OpAccessChain %42 %30 %38
OpBranch %153
%153 = OpLabel
%154 = OpFunctionCall %3 %35 %152
%156 = OpAccessChain %75 %150 %155
OpStore %156 %154
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,91 @@
struct UniformCompatible {
val_u32_: u32,
val_i32_: i32,
val_f32_: f32,
val_f16_: f16,
val_f16_2_: vec2<f16>,
val_f16_3_: vec3<f16>,
val_f16_4_: vec4<f16>,
final_value: f16,
}
struct StorageCompatible {
val_f16_array_2_: array<f16, 2>,
val_f16_array_2_1: array<f16, 2>,
}
const constant_variable: f16 = 15.203125h;
var<private> private_variable: f16 = 1h;
@group(0) @binding(0)
var<uniform> input_uniform: UniformCompatible;
@group(0) @binding(1)
var<storage> input_storage: UniformCompatible;
@group(0) @binding(2)
var<storage> input_arrays: StorageCompatible;
@group(0) @binding(3)
var<storage, read_write> output: UniformCompatible;
@group(0) @binding(4)
var<storage, read_write> output_arrays: StorageCompatible;
fn f16_function(x: f16) -> f16 {
var val: f16 = 15.203125h;
let _e6 = val;
val = (_e6 + (1h - 33344h));
let _e8 = val;
let _e11 = val;
val = (_e11 + (_e8 + 5h));
let _e15 = input_uniform.val_f32_;
let _e16 = val;
let _e20 = val;
val = (_e20 + f16((_e15 + f32(_e16))));
let _e24 = input_uniform.val_f16_;
let _e27 = val;
val = (_e27 + vec3(_e24).z);
let _e33 = input_uniform.val_f16_;
let _e36 = input_storage.val_f16_;
output.val_f16_ = (_e33 + _e36);
let _e42 = input_uniform.val_f16_2_;
let _e45 = input_storage.val_f16_2_;
output.val_f16_2_ = (_e42 + _e45);
let _e51 = input_uniform.val_f16_3_;
let _e54 = input_storage.val_f16_3_;
output.val_f16_3_ = (_e51 + _e54);
let _e60 = input_uniform.val_f16_4_;
let _e63 = input_storage.val_f16_4_;
output.val_f16_4_ = (_e60 + _e63);
let _e69 = input_arrays.val_f16_array_2_;
output_arrays.val_f16_array_2_ = _e69;
let _e70 = val;
let _e72 = val;
val = (_e72 + abs(_e70));
let _e74 = val;
let _e75 = val;
let _e76 = val;
let _e78 = val;
val = (_e78 + clamp(_e74, _e75, _e76));
let _e80 = val;
let _e82 = val;
let _e85 = val;
val = (_e85 + dot(vec2(_e80), vec2(_e82)));
let _e87 = val;
let _e88 = val;
let _e90 = val;
val = (_e90 + max(_e87, _e88));
let _e92 = val;
let _e93 = val;
let _e95 = val;
val = (_e95 + min(_e92, _e93));
let _e97 = val;
let _e99 = val;
val = (_e99 + sign(_e97));
return 1h;
}
@compute @workgroup_size(1, 1, 1)
fn main() {
let _e3 = f16_function(2h);
output.final_value = _e3;
return;
}

View File

@ -905,6 +905,10 @@ fn convert_wgsl() {
"int64",
Targets::SPIRV | Targets::HLSL | Targets::WGSL | Targets::METAL,
),
(
"float16",
Targets::SPIRV | Targets::HLSL | Targets::WGSL | Targets::METAL,
),
(
"subgroup-operations",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,

View File

@ -481,6 +481,10 @@ pub fn create_validator(
features.contains(wgt::Features::PUSH_CONSTANTS),
);
caps.set(Caps::FLOAT64, features.contains(wgt::Features::SHADER_F64));
caps.set(
Caps::SHADER_FLOAT16,
features.contains(wgt::Features::SHADER_F16),
);
caps.set(
Caps::PRIMITIVE_INDEX,
features.contains(wgt::Features::SHADER_PRIMITIVE_INDEX),

View File

@ -1,5 +1,5 @@
use std::{
mem::{size_of, size_of_val},
mem::{self, size_of, size_of_val},
ptr,
sync::Arc,
thread,
@ -378,6 +378,24 @@ impl super::Adapter {
&& features1.Int64ShaderOps.as_bool(),
);
let float16_supported = {
let mut features4: Direct3D12::D3D12_FEATURE_DATA_D3D12_OPTIONS4 =
unsafe { mem::zeroed() };
let hr = unsafe {
device.CheckFeatureSupport(
Direct3D12::D3D12_FEATURE_D3D12_OPTIONS4, // https://learn.microsoft.com/en-us/windows/win32/api/d3d12/ne-d3d12-d3d12_feature#syntax
ptr::from_mut(&mut features4).cast(),
size_of::<Direct3D12::D3D12_FEATURE_DATA_D3D12_OPTIONS4>() as _,
)
};
hr.is_ok() && features4.Native16BitShaderOpsSupported.as_bool()
};
features.set(
wgt::Features::SHADER_F16,
shader_model >= naga::back::hlsl::ShaderModel::V6_2 && float16_supported,
);
features.set(
wgt::Features::SUBGROUP,
shader_model >= naga::back::hlsl::ShaderModel::V6_0

View File

@ -382,6 +382,7 @@ impl PhysicalDeviceFeatures {
vk::PhysicalDeviceShaderFloat16Int8Features::default().shader_float16(true),
vk::PhysicalDevice16BitStorageFeatures::default()
.storage_buffer16_bit_access(true)
.storage_input_output16(true)
.uniform_and_storage_buffer16_bit_access(true),
))
} else {
@ -664,7 +665,8 @@ impl PhysicalDeviceFeatures {
F::SHADER_F16,
f16_i8.shader_float16 != 0
&& bit16.storage_buffer16_bit_access != 0
&& bit16.uniform_and_storage_buffer16_bit_access != 0,
&& bit16.uniform_and_storage_buffer16_bit_access != 0
&& bit16.storage_input_output16 != 0,
);
}