mirror of
https://github.com/gfx-rs/wgpu.git
synced 2024-11-21 22:33:49 +00:00
feat: implement F16 support in shaders
Co-Authored-By: Erich Gubler <erichdongubler@gmail.com>
This commit is contained in:
parent
a8c9356023
commit
68bb221d40
17
Cargo.lock
generated
17
Cargo.lock
generated
@ -1444,11 +1444,14 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "half"
|
||||
version = "2.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
|
||||
source = "git+https://github.com/FL33TW00D/half-rs.git?branch=feature/arbitrary#6bc4bea632269b53ccb6666a8508edc25fba9f3e"
|
||||
dependencies = [
|
||||
"arbitrary",
|
||||
"bytemuck",
|
||||
"cfg-if",
|
||||
"crunchy",
|
||||
"num-traits",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1703,6 +1706,12 @@ dependencies = [
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libm"
|
||||
version = "0.2.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
|
||||
|
||||
[[package]]
|
||||
name = "libredox"
|
||||
version = "0.1.3"
|
||||
@ -1885,11 +1894,13 @@ dependencies = [
|
||||
"codespan-reporting",
|
||||
"diff",
|
||||
"env_logger",
|
||||
"half",
|
||||
"hexf-parse",
|
||||
"hlsl-snapshots",
|
||||
"indexmap",
|
||||
"itertools",
|
||||
"log",
|
||||
"num-traits",
|
||||
"petgraph",
|
||||
"pp-rs",
|
||||
"ron",
|
||||
@ -2026,6 +2037,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"libm",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -3643,6 +3655,7 @@ dependencies = [
|
||||
"flume",
|
||||
"getrandom",
|
||||
"glam",
|
||||
"half",
|
||||
"ktx2",
|
||||
"log",
|
||||
"nanorand",
|
||||
|
@ -193,6 +193,7 @@ ndk-sys = "0.5.0"
|
||||
#gpu-alloc = { path = "../gpu-alloc/gpu-alloc" }
|
||||
|
||||
[patch.crates-io]
|
||||
half = { git = "https://github.com/FL33TW00D/half-rs.git", branch = "feature/arbitrary" }
|
||||
#glow = { path = "../glow" }
|
||||
#web-sys = { path = "../wasm-bindgen/crates/web-sys" }
|
||||
#js-sys = { path = "../wasm-bindgen/crates/js-sys" }
|
||||
|
@ -35,6 +35,7 @@ encase = { workspace = true, features = ["glam"] }
|
||||
flume.workspace = true
|
||||
getrandom.workspace = true
|
||||
glam.workspace = true
|
||||
half = { version = "2.1.0", features = ["bytemuck"] }
|
||||
ktx2.workspace = true
|
||||
log.workspace = true
|
||||
nanorand.workspace = true
|
||||
|
@ -17,6 +17,7 @@ pub mod mipmap;
|
||||
pub mod msaa_line;
|
||||
pub mod render_to_texture;
|
||||
pub mod repeated_compute;
|
||||
pub mod shader_f16;
|
||||
pub mod shadow;
|
||||
pub mod skybox;
|
||||
pub mod srgb_blend;
|
||||
|
@ -146,6 +146,12 @@ const EXAMPLES: &[ExampleDesc] = &[
|
||||
webgl: false, // No RODS
|
||||
webgpu: true,
|
||||
},
|
||||
ExampleDesc {
|
||||
name: "shader-f16",
|
||||
function: wgpu_examples::shader_f16::main,
|
||||
webgl: false, // No RODS
|
||||
webgpu: true,
|
||||
},
|
||||
];
|
||||
|
||||
fn get_example_name() -> Option<String> {
|
||||
|
9
examples/src/shader_f16/README.md
Normal file
9
examples/src/shader_f16/README.md
Normal file
@ -0,0 +1,9 @@
|
||||
# shader-f16
|
||||
|
||||
Demonstrate the ability to perform compute in F16 using wgpu.
|
||||
|
||||
## To Run
|
||||
|
||||
```
|
||||
RUST_LOG=shader_f16 cargo run --bin wgpu-examples shader_f16
|
||||
```
|
189
examples/src/shader_f16/mod.rs
Normal file
189
examples/src/shader_f16/mod.rs
Normal file
@ -0,0 +1,189 @@
|
||||
use half::f16;
|
||||
use std::{borrow::Cow, str::FromStr};
|
||||
use wgpu::util::DeviceExt;
|
||||
|
||||
#[cfg_attr(test, allow(dead_code))]
|
||||
async fn run() {
|
||||
let numbers = if std::env::args().len() <= 2 {
|
||||
let default = vec![
|
||||
f16::from_f32(27.),
|
||||
f16::from_f32(7.),
|
||||
f16::from_f32(5.),
|
||||
f16::from_f32(3.),
|
||||
];
|
||||
println!("No numbers were provided, defaulting to {default:?}");
|
||||
default
|
||||
} else {
|
||||
std::env::args()
|
||||
.skip(2)
|
||||
.map(|s| f16::from_str(&s).expect("You must pass a list of positive integers!"))
|
||||
.collect()
|
||||
};
|
||||
|
||||
let steps = execute_gpu(&numbers).await.unwrap();
|
||||
println!("Steps: [{:?}]", steps);
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
log::info!("Steps: [{:?}]", steps);
|
||||
}
|
||||
|
||||
#[cfg_attr(test, allow(dead_code))]
|
||||
async fn execute_gpu(numbers: &[f16]) -> Option<Vec<f16>> {
|
||||
// Instantiates instance of WebGPU
|
||||
let instance = wgpu::Instance::default();
|
||||
|
||||
// `request_adapter` instantiates the general connection to the GPU
|
||||
let adapter = instance
|
||||
.request_adapter(&wgpu::RequestAdapterOptions::default())
|
||||
.await?;
|
||||
|
||||
// `request_device` instantiates the feature specific connection to the GPU, defining some parameters,
|
||||
// `features` being the available features.
|
||||
let (device, queue) = adapter
|
||||
.request_device(
|
||||
&wgpu::DeviceDescriptor {
|
||||
label: None,
|
||||
required_features: wgpu::Features::SHADER_F16,
|
||||
required_limits: wgpu::Limits::downlevel_defaults(),
|
||||
memory_hints: Default::default(),
|
||||
},
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
execute_gpu_inner(&device, &queue, numbers).await
|
||||
}
|
||||
|
||||
async fn execute_gpu_inner(
|
||||
device: &wgpu::Device,
|
||||
queue: &wgpu::Queue,
|
||||
numbers: &[f16],
|
||||
) -> Option<Vec<f16>> {
|
||||
// Loads the shader from WGSL
|
||||
let cs_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
|
||||
label: None,
|
||||
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))),
|
||||
});
|
||||
|
||||
// Gets the size in bytes of the buffer.
|
||||
let size = std::mem::size_of_val(numbers) as wgpu::BufferAddress;
|
||||
|
||||
// Instantiates buffer without data.
|
||||
// `usage` of buffer specifies how it can be used:
|
||||
// `BufferUsages::MAP_READ` allows it to be read (outside the shader).
|
||||
// `BufferUsages::COPY_DST` allows it to be the destination of the copy.
|
||||
let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: None,
|
||||
size,
|
||||
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
// Instantiates buffer with data (`numbers`).
|
||||
// Usage allowing the buffer to be:
|
||||
// A storage buffer (can be bound within a bind group and thus available to a shader).
|
||||
// The destination of a copy.
|
||||
// The source of a copy.
|
||||
let storage_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
|
||||
label: Some("Storage Buffer"),
|
||||
contents: bytemuck::cast_slice(numbers),
|
||||
usage: wgpu::BufferUsages::STORAGE
|
||||
| wgpu::BufferUsages::COPY_DST
|
||||
| wgpu::BufferUsages::COPY_SRC,
|
||||
});
|
||||
|
||||
// A bind group defines how buffers are accessed by shaders.
|
||||
// It is to WebGPU what a descriptor set is to Vulkan.
|
||||
// `binding` here refers to the `binding` of a buffer in the shader (`layout(set = 0, binding = 0) buffer`).
|
||||
|
||||
// A pipeline specifies the operation of a shader
|
||||
|
||||
// Instantiates the pipeline.
|
||||
let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||
label: None,
|
||||
layout: None,
|
||||
module: &cs_module,
|
||||
entry_point: None,
|
||||
compilation_options: Default::default(),
|
||||
cache: None,
|
||||
});
|
||||
|
||||
// Instantiates the bind group, once again specifying the binding of buffers.
|
||||
let bind_group_layout = compute_pipeline.get_bind_group_layout(0);
|
||||
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
|
||||
label: None,
|
||||
layout: &bind_group_layout,
|
||||
entries: &[wgpu::BindGroupEntry {
|
||||
binding: 0,
|
||||
resource: storage_buffer.as_entire_binding(),
|
||||
}],
|
||||
});
|
||||
|
||||
// A command encoder executes one or many pipelines.
|
||||
// It is to WebGPU what a command buffer is to Vulkan.
|
||||
let mut encoder =
|
||||
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
|
||||
{
|
||||
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
|
||||
label: None,
|
||||
timestamp_writes: None,
|
||||
});
|
||||
cpass.set_pipeline(&compute_pipeline);
|
||||
cpass.set_bind_group(0, Some(&bind_group), &[]);
|
||||
cpass.insert_debug_marker("compute collatz iterations");
|
||||
cpass.dispatch_workgroups(numbers.len() as u32, 1, 1); // Number of cells to run, the (x,y,z) size of item being processed
|
||||
}
|
||||
// Sets adds copy operation to command encoder.
|
||||
// Will copy data from storage buffer on GPU to staging buffer on CPU.
|
||||
encoder.copy_buffer_to_buffer(&storage_buffer, 0, &staging_buffer, 0, size);
|
||||
|
||||
// Submits command encoder for processing
|
||||
queue.submit(Some(encoder.finish()));
|
||||
|
||||
// Note that we're not calling `.await` here.
|
||||
let buffer_slice = staging_buffer.slice(..);
|
||||
// Sets the buffer up for mapping, sending over the result of the mapping back to us when it is finished.
|
||||
let (sender, receiver) = flume::bounded(1);
|
||||
buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
|
||||
|
||||
// Poll the device in a blocking manner so that our future resolves.
|
||||
// In an actual application, `device.poll(...)` should
|
||||
// be called in an event loop or on another thread.
|
||||
device.poll(wgpu::Maintain::wait()).panic_on_timeout();
|
||||
|
||||
// Awaits until `buffer_future` can be read from
|
||||
if let Ok(Ok(())) = receiver.recv_async().await {
|
||||
// Gets contents of buffer
|
||||
let data = buffer_slice.get_mapped_range();
|
||||
// Since contents are got in bytes, this converts these bytes back to u32
|
||||
let result = bytemuck::cast_slice(&data).to_vec();
|
||||
|
||||
// With the current interface, we have to make sure all mapped views are
|
||||
// dropped before we unmap the buffer.
|
||||
drop(data);
|
||||
staging_buffer.unmap(); // Unmaps buffer from memory
|
||||
// If you are familiar with C++ these 2 lines can be thought of similarly to:
|
||||
// delete myPointer;
|
||||
// myPointer = NULL;
|
||||
// It effectively frees the memory
|
||||
|
||||
// Returns data from buffer
|
||||
Some(result)
|
||||
} else {
|
||||
panic!("failed to run compute on gpu!")
|
||||
}
|
||||
}
|
||||
|
||||
pub fn main() {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
{
|
||||
env_logger::init();
|
||||
pollster::block_on(run());
|
||||
}
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
{
|
||||
std::panic::set_hook(Box::new(console_error_panic_hook::hook));
|
||||
console_log::init().expect("could not initialize logger");
|
||||
wasm_bindgen_futures::spawn_local(run());
|
||||
}
|
||||
}
|
9
examples/src/shader_f16/shader.wgsl
Normal file
9
examples/src/shader_f16/shader.wgsl
Normal file
@ -0,0 +1,9 @@
|
||||
enable f16;
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> values: array<vec4<f16>>; // this is used as both values and output for convenience
|
||||
|
||||
@compute @workgroup_size(1)
|
||||
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
values[global_id.x] = fma(values[0], values[0], values[0]);
|
||||
}
|
@ -41,8 +41,8 @@ msl-out = []
|
||||
## If you want to enable MSL output it regardless of the target platform, use `naga/msl-out`.
|
||||
msl-out-if-target-apple = []
|
||||
|
||||
serialize = ["dep:serde", "bitflags/serde", "indexmap/serde"]
|
||||
deserialize = ["dep:serde", "bitflags/serde", "indexmap/serde"]
|
||||
serialize = ["dep:serde", "bitflags/serde", "indexmap/serde", "half/serde"]
|
||||
deserialize = ["dep:serde", "bitflags/serde", "indexmap/serde", "half/serde"]
|
||||
arbitrary = ["dep:arbitrary", "bitflags/arbitrary", "indexmap/arbitrary"]
|
||||
spv-in = ["dep:petgraph", "dep:spirv"]
|
||||
spv-out = ["dep:spirv"]
|
||||
@ -82,6 +82,9 @@ petgraph = { version = "0.6", optional = true }
|
||||
pp-rs = { version = "0.2.1", optional = true }
|
||||
hexf-parse = { version = "0.2.1", optional = true }
|
||||
unicode-xid = { version = "0.2.6", optional = true }
|
||||
# TODO: remove `[patch]` entry in workspace `Cargo.toml` for `half` after we upstream `arbitrary` support
|
||||
half = { version = "2.4.1", features = ["arbitrary", "num-traits"] }
|
||||
num-traits = "0.2"
|
||||
|
||||
[build-dependencies]
|
||||
cfg_aliases.workspace = true
|
||||
|
@ -2647,6 +2647,9 @@ impl<'a, W: Write> Writer<'a, W> {
|
||||
// decimal part even it's zero which is needed for a valid glsl float constant
|
||||
crate::Literal::F64(value) => write!(self.out, "{value:?}LF")?,
|
||||
crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
|
||||
crate::Literal::F16(_) => {
|
||||
return Err(Error::Custom("GLSL has no 16-bit float type".into()));
|
||||
}
|
||||
// Unsigned integers need a `u` at the end
|
||||
//
|
||||
// While `core` doesn't necessarily need it, it's allowed and since `es` needs it we
|
||||
|
@ -2383,6 +2383,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
// decimal part even it's zero
|
||||
crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
|
||||
crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
|
||||
crate::Literal::F16(value) => write!(self.out, "{value:?}h")?,
|
||||
crate::Literal::U32(value) => write!(self.out, "{value}u")?,
|
||||
crate::Literal::I32(value) => write!(self.out, "{value}")?,
|
||||
crate::Literal::U64(value) => write!(self.out, "{value}uL")?,
|
||||
|
@ -6,6 +6,8 @@ use crate::{
|
||||
proc::{self, NameKey, TypeResolution},
|
||||
valid, FastHashMap, FastHashSet,
|
||||
};
|
||||
use half::f16;
|
||||
use num_traits::real::Real;
|
||||
#[cfg(test)]
|
||||
use std::ptr;
|
||||
use std::{
|
||||
@ -390,8 +392,12 @@ impl crate::Scalar {
|
||||
match self {
|
||||
Self {
|
||||
kind: Sk::Float,
|
||||
width: _,
|
||||
width: 4,
|
||||
} => "float",
|
||||
Self {
|
||||
kind: Sk::Float,
|
||||
width: 2,
|
||||
} => "half",
|
||||
Self {
|
||||
kind: Sk::Sint,
|
||||
width: 4,
|
||||
@ -1385,6 +1391,21 @@ impl<W: Write> Writer<W> {
|
||||
crate::Literal::F64(_) => {
|
||||
return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
|
||||
}
|
||||
crate::Literal::F16(value) => {
|
||||
if value.is_infinite() {
|
||||
let sign = if value.is_sign_negative() { "-" } else { "" };
|
||||
write!(self.out, "{sign}INFINITY")?;
|
||||
} else if value.is_nan() {
|
||||
write!(self.out, "NAN")?;
|
||||
} else {
|
||||
let suffix = if value.fract() == f16::from_f32(0.0) {
|
||||
".0h"
|
||||
} else {
|
||||
"h"
|
||||
};
|
||||
write!(self.out, "{value}{suffix}")?;
|
||||
}
|
||||
}
|
||||
crate::Literal::F32(value) => {
|
||||
if value.is_infinite() {
|
||||
let sign = if value.is_sign_negative() { "-" } else { "" };
|
||||
|
@ -406,6 +406,10 @@ impl super::Instruction {
|
||||
instruction
|
||||
}
|
||||
|
||||
pub(super) fn constant_16bit(result_type_id: Word, id: Word, low: Word) -> Self {
|
||||
Self::constant(result_type_id, id, &[low])
|
||||
}
|
||||
|
||||
pub(super) fn constant_32bit(result_type_id: Word, id: Word, value: Word) -> Self {
|
||||
Self::constant(result_type_id, id, &[value])
|
||||
}
|
||||
|
@ -799,6 +799,15 @@ impl Writer {
|
||||
if bits == 64 {
|
||||
self.capabilities_used.insert(spirv::Capability::Float64);
|
||||
}
|
||||
if bits == 16 {
|
||||
self.capabilities_used.insert(spirv::Capability::Float16);
|
||||
self.capabilities_used
|
||||
.insert(spirv::Capability::StorageBuffer16BitAccess);
|
||||
self.capabilities_used
|
||||
.insert(spirv::Capability::UniformAndStorageBuffer16BitAccess);
|
||||
self.capabilities_used
|
||||
.insert(spirv::Capability::StorageInputOutput16);
|
||||
}
|
||||
Instruction::type_float(id, bits)
|
||||
}
|
||||
Sk::Bool => Instruction::type_bool(id),
|
||||
@ -1148,6 +1157,10 @@ impl Writer {
|
||||
Instruction::constant_64bit(type_id, id, bits as u32, (bits >> 32) as u32)
|
||||
}
|
||||
crate::Literal::F32(value) => Instruction::constant_32bit(type_id, id, value.to_bits()),
|
||||
crate::Literal::F16(value) => {
|
||||
let low = value.to_bits();
|
||||
Instruction::constant_16bit(type_id, id, low as u32)
|
||||
}
|
||||
crate::Literal::U32(value) => Instruction::constant_32bit(type_id, id, value),
|
||||
crate::Literal::I32(value) => Instruction::constant_32bit(type_id, id, value as u32),
|
||||
crate::Literal::U64(value) => {
|
||||
|
@ -1232,6 +1232,7 @@ impl<W: Write> Writer<W> {
|
||||
|
||||
match expressions[expr] {
|
||||
Expression::Literal(literal) => match literal {
|
||||
crate::Literal::F16(value) => write!(self.out, "{value}h")?,
|
||||
crate::Literal::F32(value) => write!(self.out, "{value}f")?,
|
||||
crate::Literal::U32(value) => write!(self.out, "{value}u")?,
|
||||
crate::Literal::I32(value) => {
|
||||
@ -1995,6 +1996,10 @@ const fn scalar_kind_str(scalar: crate::Scalar) -> &'static str {
|
||||
kind: Sk::Float,
|
||||
width: 4,
|
||||
} => "f32",
|
||||
Scalar {
|
||||
kind: Sk::Float,
|
||||
width: 2,
|
||||
} => "f16",
|
||||
Scalar {
|
||||
kind: Sk::Sint,
|
||||
width: 4,
|
||||
|
@ -36,6 +36,7 @@ mod null;
|
||||
use convert::*;
|
||||
pub use error::Error;
|
||||
use function::*;
|
||||
use half::f16;
|
||||
use indexmap::IndexSet;
|
||||
|
||||
use crate::{
|
||||
@ -5484,6 +5485,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
|
||||
}) => {
|
||||
let low = self.next()?;
|
||||
match width {
|
||||
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#Literal
|
||||
// If a numeric type’s bit width is less than 32-bits, the value appears in the low-order bits of the word.
|
||||
2 => crate::Literal::F16(f16::from_bits(low as u16)),
|
||||
4 => crate::Literal::F32(f32::from_bits(low)),
|
||||
8 => {
|
||||
inst.expect(5)?;
|
||||
|
@ -144,8 +144,6 @@ pub enum NumberError {
|
||||
Invalid,
|
||||
#[error("numeric literal not representable by target type")]
|
||||
NotRepresentable,
|
||||
#[error("unimplemented f16 type")]
|
||||
UnimplementedF16,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq)]
|
||||
|
@ -1831,6 +1831,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
|
||||
let expr: Typed<crate::Expression> = match *expr {
|
||||
ast::Expression::Literal(literal) => {
|
||||
let literal = match literal {
|
||||
ast::Literal::Number(Number::F16(f)) => crate::Literal::F16(f),
|
||||
ast::Literal::Number(Number::F32(f)) => crate::Literal::F32(f),
|
||||
ast::Literal::Number(Number::I32(i)) => crate::Literal::I32(i),
|
||||
ast::Literal::Number(Number::U32(u)) => crate::Literal::U32(u),
|
||||
|
@ -114,7 +114,10 @@ pub fn map_storage_format(word: &str, span: Span) -> Result<crate::StorageFormat
|
||||
pub fn get_scalar_type(word: &str) -> Option<Scalar> {
|
||||
use crate::ScalarKind as Sk;
|
||||
match word {
|
||||
// "f16" => Some(Scalar { kind: Sk::Float, width: 2 }),
|
||||
"f16" => Some(Scalar {
|
||||
kind: Sk::Float,
|
||||
width: 2,
|
||||
}),
|
||||
"f32" => Some(Scalar {
|
||||
kind: Sk::Float,
|
||||
width: 4,
|
||||
|
@ -5,24 +5,29 @@ use crate::{front::wgsl::error::Error, Span};
|
||||
|
||||
/// Tracks the status of every enable-extension known to Naga.
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct EnableExtensions {}
|
||||
pub struct EnableExtensions {
|
||||
/// Whether `enable f16;` was written earlier in the shader module.
|
||||
f16: bool,
|
||||
}
|
||||
|
||||
impl EnableExtensions {
|
||||
pub(crate) const fn empty() -> Self {
|
||||
Self {}
|
||||
Self { f16: false }
|
||||
}
|
||||
|
||||
/// Add an enable-extension to the set requested by a module.
|
||||
#[allow(unreachable_code)]
|
||||
pub(crate) fn add(&mut self, ext: ImplementedEnableExtension) {
|
||||
let _field: &mut bool = match ext {};
|
||||
*_field = true;
|
||||
let field = match ext {
|
||||
ImplementedEnableExtension::F16 => &mut self.f16,
|
||||
};
|
||||
*field = true;
|
||||
}
|
||||
|
||||
/// Query whether an enable-extension tracked here has been requested.
|
||||
#[allow(unused)]
|
||||
pub(crate) const fn contains(&self, ext: ImplementedEnableExtension) -> bool {
|
||||
match ext {}
|
||||
match ext {
|
||||
ImplementedEnableExtension::F16 => self.f16,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -37,7 +42,6 @@ impl Default for EnableExtensions {
|
||||
/// WGSL spec.: <https://www.w3.org/TR/WGSL/#enable-extensions-sec>
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
|
||||
pub enum EnableExtension {
|
||||
#[allow(unused)]
|
||||
Implemented(ImplementedEnableExtension),
|
||||
Unimplemented(UnimplementedEnableExtension),
|
||||
}
|
||||
@ -50,7 +54,7 @@ impl EnableExtension {
|
||||
/// Convert from a sentinel word in WGSL into its associated [`EnableExtension`], if possible.
|
||||
pub(crate) fn from_ident(word: &str, span: Span) -> Result<Self, Error<'_>> {
|
||||
Ok(match word {
|
||||
Self::F16 => Self::Unimplemented(UnimplementedEnableExtension::F16),
|
||||
Self::F16 => Self::Implemented(ImplementedEnableExtension::F16),
|
||||
Self::CLIP_DISTANCES => {
|
||||
Self::Unimplemented(UnimplementedEnableExtension::ClipDistances)
|
||||
}
|
||||
@ -64,9 +68,10 @@ impl EnableExtension {
|
||||
/// Maps this [`EnableExtension`] into the sentinel word associated with it in WGSL.
|
||||
pub const fn to_ident(self) -> &'static str {
|
||||
match self {
|
||||
Self::Implemented(kind) => match kind {},
|
||||
Self::Implemented(kind) => match kind {
|
||||
ImplementedEnableExtension::F16 => Self::F16,
|
||||
},
|
||||
Self::Unimplemented(kind) => match kind {
|
||||
UnimplementedEnableExtension::F16 => Self::F16,
|
||||
UnimplementedEnableExtension::ClipDistances => Self::CLIP_DISTANCES,
|
||||
UnimplementedEnableExtension::DualSourceBlending => Self::DUAL_SOURCE_BLENDING,
|
||||
},
|
||||
@ -76,17 +81,24 @@ impl EnableExtension {
|
||||
|
||||
/// A variant of [`EnableExtension::Implemented`].
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
|
||||
pub enum ImplementedEnableExtension {}
|
||||
|
||||
/// A variant of [`EnableExtension::Unimplemented`].
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
|
||||
pub enum UnimplementedEnableExtension {
|
||||
pub enum ImplementedEnableExtension {
|
||||
/// Enables `f16`/`half` primitive support in all shader languages.
|
||||
///
|
||||
/// In the WGSL standard, this corresponds to [`enable f16;`].
|
||||
///
|
||||
/// [`enable f16;`]: https://www.w3.org/TR/WGSL/#extension-f16
|
||||
F16,
|
||||
}
|
||||
|
||||
impl From<ImplementedEnableExtension> for EnableExtension {
|
||||
fn from(value: ImplementedEnableExtension) -> Self {
|
||||
Self::Implemented(value)
|
||||
}
|
||||
}
|
||||
|
||||
/// A variant of [`EnableExtension::Unimplemented`].
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
|
||||
pub enum UnimplementedEnableExtension {
|
||||
/// Enables the `clip_distances` variable in WGSL.
|
||||
///
|
||||
/// In the WGSL standard, this corresponds to [`enable clip_distances;`].
|
||||
@ -104,7 +116,6 @@ pub enum UnimplementedEnableExtension {
|
||||
impl UnimplementedEnableExtension {
|
||||
pub(crate) const fn tracking_issue_num(self) -> u16 {
|
||||
match self {
|
||||
Self::F16 => 4384,
|
||||
Self::ClipDistances => 6236,
|
||||
Self::DualSourceBlending => 6402,
|
||||
}
|
||||
|
@ -1,6 +1,8 @@
|
||||
use super::{number::consume_number, Error, ExpectedToken};
|
||||
use crate::front::wgsl::error::NumberError;
|
||||
use crate::front::wgsl::parse::directive::enable_extension::EnableExtensions;
|
||||
use crate::front::wgsl::parse::directive::enable_extension::{
|
||||
EnableExtensions, ImplementedEnableExtension,
|
||||
};
|
||||
use crate::front::wgsl::parse::{conv, Number};
|
||||
use crate::front::wgsl::Scalar;
|
||||
use crate::Span;
|
||||
@ -395,14 +397,26 @@ impl<'a> Lexer<'a> {
|
||||
/// Parses a generic scalar type, for example `<f32>`.
|
||||
pub(in crate::front::wgsl) fn next_scalar_generic(&mut self) -> Result<Scalar, Error<'a>> {
|
||||
self.expect_generic_paren('<')?;
|
||||
let pair = match self.next() {
|
||||
(Token::Word(word), span) => {
|
||||
conv::get_scalar_type(word).ok_or(Error::UnknownScalarType(span))
|
||||
}
|
||||
let (scalar, span) = match self.next() {
|
||||
(Token::Word(word), span) => conv::get_scalar_type(word)
|
||||
.map(|scalar| (scalar, span))
|
||||
.ok_or(Error::UnknownScalarType(span)),
|
||||
(_, span) => Err(Error::UnknownScalarType(span)),
|
||||
}?;
|
||||
|
||||
if matches!(scalar, Scalar::F16)
|
||||
&& !self
|
||||
.enable_extensions
|
||||
.contains(ImplementedEnableExtension::F16)
|
||||
{
|
||||
return Err(Error::EnableExtensionNotEnabled {
|
||||
span,
|
||||
kind: ImplementedEnableExtension::F16.into(),
|
||||
});
|
||||
}
|
||||
|
||||
self.expect_generic_paren('>')?;
|
||||
Ok(pair)
|
||||
Ok(scalar)
|
||||
}
|
||||
|
||||
/// Parses a generic scalar type, for example `<f32>`.
|
||||
@ -412,14 +426,27 @@ impl<'a> Lexer<'a> {
|
||||
&mut self,
|
||||
) -> Result<(Scalar, Span), Error<'a>> {
|
||||
self.expect_generic_paren('<')?;
|
||||
let pair = match self.next() {
|
||||
|
||||
let (scalar, span) = match self.next() {
|
||||
(Token::Word(word), span) => conv::get_scalar_type(word)
|
||||
.map(|scalar| (scalar, span))
|
||||
.ok_or(Error::UnknownScalarType(span)),
|
||||
(_, span) => Err(Error::UnknownScalarType(span)),
|
||||
}?;
|
||||
|
||||
if matches!(scalar, Scalar::F16)
|
||||
&& !self
|
||||
.enable_extensions
|
||||
.contains(ImplementedEnableExtension::F16)
|
||||
{
|
||||
return Err(Error::EnableExtensionNotEnabled {
|
||||
span,
|
||||
kind: ImplementedEnableExtension::F16.into(),
|
||||
});
|
||||
}
|
||||
|
||||
self.expect_generic_paren('>')?;
|
||||
Ok(pair)
|
||||
Ok((scalar, span))
|
||||
}
|
||||
|
||||
pub(in crate::front::wgsl) fn next_storage_access(
|
||||
@ -477,6 +504,7 @@ fn sub_test(source: &str, expected_tokens: &[Token]) {
|
||||
|
||||
#[test]
|
||||
fn test_numbers() {
|
||||
use half::f16;
|
||||
// WGSL spec examples //
|
||||
|
||||
// decimal integer
|
||||
@ -501,14 +529,14 @@ fn test_numbers() {
|
||||
Token::Number(Ok(Number::AbstractFloat(0.01))),
|
||||
Token::Number(Ok(Number::AbstractFloat(12.34))),
|
||||
Token::Number(Ok(Number::F32(0.))),
|
||||
Token::Number(Err(NumberError::UnimplementedF16)),
|
||||
Token::Number(Ok(Number::F16(f16::from_f32(0.)))),
|
||||
Token::Number(Ok(Number::AbstractFloat(0.001))),
|
||||
Token::Number(Ok(Number::AbstractFloat(43.75))),
|
||||
Token::Number(Ok(Number::F32(16.))),
|
||||
Token::Number(Ok(Number::AbstractFloat(0.1875))),
|
||||
Token::Number(Err(NumberError::UnimplementedF16)),
|
||||
Token::Number(Ok(Number::F16(f16::from_f32(0.75)))),
|
||||
Token::Number(Ok(Number::AbstractFloat(0.12109375))),
|
||||
Token::Number(Err(NumberError::UnimplementedF16)),
|
||||
Token::Number(Ok(Number::F16(f16::from_f32(12.5)))),
|
||||
],
|
||||
);
|
||||
|
||||
|
@ -1,7 +1,5 @@
|
||||
use crate::front::wgsl::error::{Error, ExpectedToken};
|
||||
use crate::front::wgsl::parse::directive::enable_extension::{
|
||||
EnableExtension, EnableExtensions, UnimplementedEnableExtension,
|
||||
};
|
||||
use crate::front::wgsl::parse::directive::enable_extension::{EnableExtension, EnableExtensions};
|
||||
use crate::front::wgsl::parse::directive::language_extension::LanguageExtension;
|
||||
use crate::front::wgsl::parse::directive::DirectiveKind;
|
||||
use crate::front::wgsl::parse::lexer::{Lexer, Token};
|
||||
@ -670,15 +668,7 @@ impl Parser {
|
||||
}
|
||||
(Token::Number(res), span) => {
|
||||
let _ = lexer.next();
|
||||
let num = res.map_err(|err| match err {
|
||||
super::error::NumberError::UnimplementedF16 => {
|
||||
Error::EnableExtensionNotEnabled {
|
||||
kind: EnableExtension::Unimplemented(UnimplementedEnableExtension::F16),
|
||||
span,
|
||||
}
|
||||
}
|
||||
err => Error::BadNumber(span, err),
|
||||
})?;
|
||||
let num = res.map_err(|err| Error::BadNumber(span, err))?;
|
||||
ast::Expression::Literal(ast::Literal::Number(num))
|
||||
}
|
||||
(Token::Word("RAY_FLAG_NONE"), _) => {
|
||||
|
@ -1,5 +1,6 @@
|
||||
use crate::front::wgsl::error::NumberError;
|
||||
use crate::front::wgsl::parse::lexer::Token;
|
||||
use half::f16;
|
||||
|
||||
/// When using this type assume no Abstract Int/Float for now
|
||||
#[derive(Copy, Clone, Debug, PartialEq)]
|
||||
@ -16,6 +17,8 @@ pub enum Number {
|
||||
I64(i64),
|
||||
/// Concrete u64
|
||||
U64(u64),
|
||||
/// Concrete f16
|
||||
F16(f16),
|
||||
/// Concrete f32
|
||||
F32(f32),
|
||||
/// Concrete f64
|
||||
@ -362,7 +365,8 @@ fn parse_hex_float(input: &str, kind: Option<FloatKind>) -> Result<Number, Numbe
|
||||
// can only be ParseHexfErrorKind::Inexact but we can't check since it's private
|
||||
_ => Err(NumberError::NotRepresentable),
|
||||
},
|
||||
Some(FloatKind::F16) => Err(NumberError::UnimplementedF16),
|
||||
// TODO: f16 is not supported by hexf_parse
|
||||
Some(FloatKind::F16) => Err(NumberError::NotRepresentable),
|
||||
Some(FloatKind::F32) => match hexf_parse::parse_hexf32(input, false) {
|
||||
Ok(num) => Ok(Number::F32(num)),
|
||||
// can only be ParseHexfErrorKind::Inexact but we can't check since it's private
|
||||
@ -398,7 +402,12 @@ fn parse_dec_float(input: &str, kind: Option<FloatKind>) -> Result<Number, Numbe
|
||||
.then_some(Number::F64(num))
|
||||
.ok_or(NumberError::NotRepresentable)
|
||||
}
|
||||
Some(FloatKind::F16) => Err(NumberError::UnimplementedF16),
|
||||
Some(FloatKind::F16) => {
|
||||
let num = input.parse::<f16>().unwrap(); // will never fail
|
||||
num.is_finite()
|
||||
.then_some(Number::F16(num))
|
||||
.ok_or(NumberError::NotRepresentable)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -268,6 +268,7 @@ pub use crate::arena::{Arena, Handle, Range, UniqueArena};
|
||||
pub use crate::span::{SourceLocation, Span, SpanContext, WithSpan};
|
||||
#[cfg(feature = "arbitrary")]
|
||||
use arbitrary::Arbitrary;
|
||||
use half::f16;
|
||||
#[cfg(feature = "deserialize")]
|
||||
use serde::Deserialize;
|
||||
#[cfg(feature = "serialize")]
|
||||
@ -871,6 +872,7 @@ pub enum Literal {
|
||||
F64(f64),
|
||||
/// May not be NaN or infinity.
|
||||
F32(f32),
|
||||
F16(f16),
|
||||
U32(u32),
|
||||
I32(i32),
|
||||
U64(u64),
|
||||
|
@ -1,6 +1,8 @@
|
||||
use std::iter;
|
||||
|
||||
use arrayvec::ArrayVec;
|
||||
use half::f16;
|
||||
use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero};
|
||||
|
||||
use crate::{
|
||||
arena::{Arena, Handle, HandleVec, UniqueArena},
|
||||
@ -199,6 +201,7 @@ gen_component_wise_extractor! {
|
||||
literals: [
|
||||
AbstractFloat => AbstractFloat: f64,
|
||||
F32 => F32: f32,
|
||||
F16 => F16: f16,
|
||||
AbstractInt => AbstractInt: i64,
|
||||
U32 => U32: u32,
|
||||
I32 => I32: i32,
|
||||
@ -219,6 +222,7 @@ gen_component_wise_extractor! {
|
||||
literals: [
|
||||
AbstractFloat => Abstract: f64,
|
||||
F32 => F32: f32,
|
||||
F16 => F16: f16,
|
||||
],
|
||||
scalar_kinds: [
|
||||
Float,
|
||||
@ -244,6 +248,7 @@ gen_component_wise_extractor! {
|
||||
AbstractFloat => AbstractFloat: f64,
|
||||
AbstractInt => AbstractInt: i64,
|
||||
F32 => F32: f32,
|
||||
F16 => F16: f16,
|
||||
I32 => I32: i32,
|
||||
],
|
||||
scalar_kinds: [
|
||||
@ -1088,6 +1093,7 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
component_wise_scalar(self, span, [arg], |args| match args {
|
||||
Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
|
||||
Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
|
||||
Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
|
||||
Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.abs()])),
|
||||
Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
|
||||
Scalar::U32([e]) => Ok(Scalar::U32([e])), // TODO: just re-use the expression, ezpz
|
||||
@ -1119,9 +1125,13 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
}
|
||||
)
|
||||
}
|
||||
crate::MathFunction::Saturate => {
|
||||
component_wise_float!(self, span, [arg], |e| { Ok([e.clamp(0., 1.)]) })
|
||||
}
|
||||
crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
|
||||
Float::F16([e]) => Ok(Float::F16(
|
||||
[e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
|
||||
)),
|
||||
Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
|
||||
Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
|
||||
}),
|
||||
|
||||
// trigonometry
|
||||
crate::MathFunction::Cos => {
|
||||
@ -1198,6 +1208,9 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
component_wise_float(self, span, [arg], |e| match e {
|
||||
Float::Abstract([e]) => Ok(Float::Abstract([round_ties_even(e)])),
|
||||
Float::F32([e]) => Ok(Float::F32([(round_ties_even(e as f64) as f32)])),
|
||||
Float::F16([e]) => {
|
||||
Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
|
||||
}
|
||||
})
|
||||
}
|
||||
crate::MathFunction::Fract => {
|
||||
@ -1243,15 +1256,27 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
)
|
||||
}
|
||||
crate::MathFunction::Step => {
|
||||
component_wise_float!(self, span, [arg, arg1.unwrap()], |edge, x| {
|
||||
Ok([if edge <= x { 1.0 } else { 0.0 }])
|
||||
component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
|
||||
Float::Abstract([edge, x]) => {
|
||||
Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
|
||||
}
|
||||
Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
|
||||
Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
|
||||
f16::one()
|
||||
} else {
|
||||
f16::zero()
|
||||
}])),
|
||||
})
|
||||
}
|
||||
crate::MathFunction::Sqrt => {
|
||||
component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) })
|
||||
}
|
||||
crate::MathFunction::InverseSqrt => {
|
||||
component_wise_float!(self, span, [arg], |e| { Ok([1. / e.sqrt()]) })
|
||||
component_wise_float(self, span, [arg], |e| match e {
|
||||
Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
|
||||
Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
|
||||
Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])),
|
||||
})
|
||||
}
|
||||
|
||||
// bits
|
||||
@ -1529,6 +1554,7 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
Literal::I32(v) => v,
|
||||
Literal::U32(v) => v as i32,
|
||||
Literal::F32(v) => v as i32,
|
||||
Literal::F16(v) => f16::to_i32(&v).unwrap(), //Only None on NaN or Inf
|
||||
Literal::Bool(v) => v as i32,
|
||||
Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
|
||||
return make_error();
|
||||
@ -1540,6 +1566,7 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
Literal::I32(v) => v as u32,
|
||||
Literal::U32(v) => v,
|
||||
Literal::F32(v) => v as u32,
|
||||
Literal::F16(v) => f16::to_u32(&v).unwrap(), //Only None on NaN or Inf
|
||||
Literal::Bool(v) => v as u32,
|
||||
Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
|
||||
return make_error();
|
||||
@ -1555,6 +1582,7 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
Literal::F64(v) => v as i64,
|
||||
Literal::I64(v) => v,
|
||||
Literal::U64(v) => v as i64,
|
||||
Literal::F16(v) => f16::to_i64(&v).unwrap(), //Only None on NaN or Inf
|
||||
Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
|
||||
Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
|
||||
}),
|
||||
@ -1566,9 +1594,22 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
Literal::F64(v) => v as u64,
|
||||
Literal::I64(v) => v as u64,
|
||||
Literal::U64(v) => v,
|
||||
Literal::F16(v) => f16::to_u64(&v).unwrap(), //Only None on NaN or Inf
|
||||
Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
|
||||
Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
|
||||
}),
|
||||
Sc::F16 => Literal::F16(match literal {
|
||||
Literal::F16(v) => v,
|
||||
Literal::F32(v) => f16::from_f32(v),
|
||||
Literal::F64(v) => f16::from_f64(v),
|
||||
Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
|
||||
Literal::I64(v) => f16::from_i64(v).unwrap(),
|
||||
Literal::U64(v) => f16::from_u64(v).unwrap(),
|
||||
Literal::I32(v) => f16::from_i32(v).unwrap(),
|
||||
Literal::U32(v) => f16::from_u32(v).unwrap(),
|
||||
Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
|
||||
Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
|
||||
}),
|
||||
Sc::F32 => Literal::F32(match literal {
|
||||
Literal::I32(v) => v as f32,
|
||||
Literal::U32(v) => v as f32,
|
||||
@ -1577,12 +1618,14 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
|
||||
return make_error();
|
||||
}
|
||||
Literal::F16(v) => f16::to_f32(v),
|
||||
Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
|
||||
Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
|
||||
}),
|
||||
Sc::F64 => Literal::F64(match literal {
|
||||
Literal::I32(v) => v as f64,
|
||||
Literal::U32(v) => v as f64,
|
||||
Literal::F16(v) => f16::to_f64(v),
|
||||
Literal::F32(v) => v as f64,
|
||||
Literal::F64(v) => v,
|
||||
Literal::Bool(v) => v as u32 as f64,
|
||||
@ -1594,6 +1637,7 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
Literal::I32(v) => v != 0,
|
||||
Literal::U32(v) => v != 0,
|
||||
Literal::F32(v) => v != 0.0,
|
||||
Literal::F16(v) => v != f16::zero(),
|
||||
Literal::Bool(v) => v,
|
||||
Literal::F64(_)
|
||||
| Literal::I64(_)
|
||||
@ -1743,6 +1787,7 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
UnaryOperator::Negate => match value {
|
||||
Literal::I32(v) => Literal::I32(v.wrapping_neg()),
|
||||
Literal::F32(v) => Literal::F32(-v),
|
||||
Literal::F16(v) => Literal::F16(-v),
|
||||
Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
|
||||
Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
|
||||
_ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
|
||||
@ -1881,6 +1926,14 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
BinaryOperator::Modulo => a % b,
|
||||
_ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
|
||||
}),
|
||||
(Literal::F16(a), Literal::F16(b)) => Literal::F16(match op {
|
||||
BinaryOperator::Add => a + b,
|
||||
BinaryOperator::Subtract => a - b,
|
||||
BinaryOperator::Multiply => a * b,
|
||||
BinaryOperator::Divide => a / b,
|
||||
BinaryOperator::Modulo => a % b,
|
||||
_ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
|
||||
}),
|
||||
(Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
|
||||
Literal::AbstractInt(match op {
|
||||
BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
|
||||
@ -2450,6 +2503,32 @@ impl TryFromAbstract<f64> for u64 {
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFromAbstract<f64> for f16 {
|
||||
fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
|
||||
let f = f16::from_f64(value);
|
||||
if f.is_infinite() {
|
||||
return Err(ConstantEvaluatorError::AutomaticConversionLossy {
|
||||
value: format!("{value:?}"),
|
||||
to_type: "f16",
|
||||
});
|
||||
}
|
||||
Ok(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFromAbstract<i64> for f16 {
|
||||
fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
|
||||
let f = f16::from_i64(value);
|
||||
if f.is_none() {
|
||||
return Err(ConstantEvaluatorError::AutomaticConversionLossy {
|
||||
value: format!("{value:?}"),
|
||||
to_type: "f16",
|
||||
});
|
||||
}
|
||||
Ok(f.unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::vec;
|
||||
|
@ -90,6 +90,10 @@ impl super::Scalar {
|
||||
kind: crate::ScalarKind::Uint,
|
||||
width: 4,
|
||||
};
|
||||
pub const F16: Self = Self {
|
||||
kind: crate::ScalarKind::Float,
|
||||
width: 2,
|
||||
};
|
||||
pub const F32: Self = Self {
|
||||
kind: crate::ScalarKind::Float,
|
||||
width: 4,
|
||||
@ -157,6 +161,7 @@ impl super::Scalar {
|
||||
pub enum HashableLiteral {
|
||||
F64(u64),
|
||||
F32(u32),
|
||||
F16(u16),
|
||||
U32(u32),
|
||||
I32(i32),
|
||||
U64(u64),
|
||||
@ -171,6 +176,7 @@ impl From<crate::Literal> for HashableLiteral {
|
||||
match l {
|
||||
crate::Literal::F64(v) => Self::F64(v.to_bits()),
|
||||
crate::Literal::F32(v) => Self::F32(v.to_bits()),
|
||||
crate::Literal::F16(v) => Self::F16(v.to_bits()),
|
||||
crate::Literal::U32(v) => Self::U32(v),
|
||||
crate::Literal::I32(v) => Self::I32(v),
|
||||
crate::Literal::U64(v) => Self::U64(v),
|
||||
@ -209,6 +215,7 @@ impl crate::Literal {
|
||||
match *self {
|
||||
Self::F64(_) | Self::I64(_) | Self::U64(_) => 8,
|
||||
Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
|
||||
Self::F16(_) => 2,
|
||||
Self::Bool(_) => crate::BOOL_WIDTH,
|
||||
Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
|
||||
}
|
||||
@ -217,6 +224,7 @@ impl crate::Literal {
|
||||
match *self {
|
||||
Self::F64(_) => crate::Scalar::F64,
|
||||
Self::F32(_) => crate::Scalar::F32,
|
||||
Self::F16(_) => crate::Scalar::F16,
|
||||
Self::U32(_) => crate::Scalar::U32,
|
||||
Self::I32(_) => crate::Scalar::I32,
|
||||
Self::U64(_) => crate::Scalar::U64,
|
||||
|
@ -143,6 +143,8 @@ bitflags::bitflags! {
|
||||
const SHADER_INT64_ATOMIC_MIN_MAX = 0x80000;
|
||||
/// Support for all atomic operations on 64-bit integers.
|
||||
const SHADER_INT64_ATOMIC_ALL_OPS = 0x100000;
|
||||
/// Support for 16-bit floating-point types.
|
||||
const SHADER_FLOAT16 = 0x200000;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -243,8 +243,8 @@ impl super::Validator {
|
||||
pub(super) const fn check_width(&self, scalar: crate::Scalar) -> Result<(), WidthError> {
|
||||
let good = match scalar.kind {
|
||||
crate::ScalarKind::Bool => scalar.width == crate::BOOL_WIDTH,
|
||||
crate::ScalarKind::Float => {
|
||||
if scalar.width == 8 {
|
||||
crate::ScalarKind::Float => match scalar.width {
|
||||
8 => {
|
||||
if !self.capabilities.contains(Capabilities::FLOAT64) {
|
||||
return Err(WidthError::MissingCapability {
|
||||
name: "f64",
|
||||
@ -252,10 +252,18 @@ impl super::Validator {
|
||||
});
|
||||
}
|
||||
true
|
||||
} else {
|
||||
scalar.width == 4
|
||||
}
|
||||
}
|
||||
2 => {
|
||||
if !self.capabilities.contains(Capabilities::SHADER_FLOAT16) {
|
||||
return Err(WidthError::MissingCapability {
|
||||
name: "f16",
|
||||
flag: "FLOAT16",
|
||||
});
|
||||
}
|
||||
true
|
||||
}
|
||||
_ => scalar.width == 4,
|
||||
},
|
||||
crate::ScalarKind::Sint => {
|
||||
if scalar.width == 8 {
|
||||
if !self.capabilities.contains(Capabilities::SHADER_INT64) {
|
||||
|
22
naga/tests/in/float16.param.ron
Normal file
22
naga/tests/in/float16.param.ron
Normal file
@ -0,0 +1,22 @@
|
||||
(
|
||||
god_mode: true,
|
||||
spv: (
|
||||
version: (1, 0),
|
||||
),
|
||||
hlsl: (
|
||||
shader_model: V6_2,
|
||||
binding_map: {},
|
||||
fake_missing_bindings: true,
|
||||
special_constants_binding: Some((space: 1, register: 0)),
|
||||
push_constants_target: Some((space: 0, register: 0)),
|
||||
zero_initialize_workgroup_memory: true,
|
||||
),
|
||||
msl: (
|
||||
lang_version: (1, 0),
|
||||
per_entry_point_map: {},
|
||||
inline_samplers: [],
|
||||
spirv_cross_compatibility: false,
|
||||
fake_missing_bindings: true,
|
||||
zero_initialize_workgroup_memory: true,
|
||||
),
|
||||
)
|
87
naga/tests/in/float16.wgsl
Normal file
87
naga/tests/in/float16.wgsl
Normal file
@ -0,0 +1,87 @@
|
||||
enable f16;
|
||||
enable f16; //redundant directives are OK
|
||||
|
||||
var<private> private_variable: f16 = 1h;
|
||||
const constant_variable: f16 = f16(15.2);
|
||||
|
||||
struct UniformCompatible {
|
||||
// Other types
|
||||
val_u32: u32,
|
||||
val_i32: i32,
|
||||
val_f32: f32,
|
||||
|
||||
// f16
|
||||
val_f16: f16,
|
||||
val_f16_2: vec2<f16>,
|
||||
val_f16_3: vec3<f16>,
|
||||
val_f16_4: vec4<f16>,
|
||||
final_value: f16,
|
||||
}
|
||||
|
||||
struct StorageCompatible {
|
||||
val_f16_array_2: array<f16, 2>,
|
||||
val_f16_array_2: array<f16, 2>,
|
||||
}
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<uniform> input_uniform: UniformCompatible;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage> input_storage: UniformCompatible;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<storage> input_arrays: StorageCompatible;
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<storage, read_write> output: UniformCompatible;
|
||||
|
||||
@group(0) @binding(4)
|
||||
var<storage, read_write> output_arrays: StorageCompatible;
|
||||
|
||||
fn f16_function(x: f16) -> f16 {
|
||||
var val: f16 = f16(constant_variable);
|
||||
// A number too big for f16
|
||||
val += 1h - 33333h;
|
||||
// Constructing an f16 from an AbstractInt
|
||||
val += val + f16(5.);
|
||||
// Constructing a f16 from other types and other types from f16.
|
||||
val += f16(input_uniform.val_f32 + f32(val));
|
||||
// Constructing a vec3<i64> from a i64
|
||||
val += vec3<f16>(input_uniform.val_f16).z;
|
||||
|
||||
// Reading/writing to a uniform/storage buffer
|
||||
output.val_f16 = input_uniform.val_f16 + input_storage.val_f16;
|
||||
output.val_f16_2 = input_uniform.val_f16_2 + input_storage.val_f16_2;
|
||||
output.val_f16_3 = input_uniform.val_f16_3 + input_storage.val_f16_3;
|
||||
output.val_f16_4 = input_uniform.val_f16_4 + input_storage.val_f16_4;
|
||||
|
||||
output_arrays.val_f16_array_2 = input_arrays.val_f16_array_2;
|
||||
|
||||
// We make sure not to use 32 in these arguments, so it's clear in the results which are builtin
|
||||
// constants based on the size of the type, and which are arguments.
|
||||
|
||||
// Numeric functions
|
||||
val += abs(val);
|
||||
val += clamp(val, val, val);
|
||||
//val += countLeadingZeros(val);
|
||||
//val += countOneBits(val);
|
||||
//val += countTrailingZeros(val);
|
||||
val += dot(vec2(val), vec2(val));
|
||||
//val += extractBits(val, 15u, 28u);
|
||||
//val += firstLeadingBit(val);
|
||||
//val += firstTrailingBit(val);
|
||||
//val += insertBits(val, 12li, 15u, 28u);
|
||||
val += max(val, val);
|
||||
val += min(val, val);
|
||||
//val += reverseBits(val);
|
||||
val += sign(val); // only for i64
|
||||
|
||||
// Make sure all the variables are used.
|
||||
return f16(1.0);
|
||||
}
|
||||
|
||||
@compute @workgroup_size(1)
|
||||
fn main() {
|
||||
output.final_value = f16_function(2h);
|
||||
}
|
||||
|
107
naga/tests/out/hlsl/float16.hlsl
Normal file
107
naga/tests/out/hlsl/float16.hlsl
Normal file
@ -0,0 +1,107 @@
|
||||
struct NagaConstants {
|
||||
int first_vertex;
|
||||
int first_instance;
|
||||
uint other;
|
||||
};
|
||||
ConstantBuffer<NagaConstants> _NagaConstants: register(b0, space1);
|
||||
|
||||
struct UniformCompatible {
|
||||
uint val_u32_;
|
||||
int val_i32_;
|
||||
float val_f32_;
|
||||
half val_f16_;
|
||||
half2 val_f16_2_;
|
||||
int _pad5_0;
|
||||
half3 val_f16_3_;
|
||||
half4 val_f16_4_;
|
||||
half final_value;
|
||||
int _end_pad_0;
|
||||
};
|
||||
|
||||
struct StorageCompatible {
|
||||
half val_f16_array_2_[2];
|
||||
half val_f16_array_2_1[2];
|
||||
};
|
||||
|
||||
static const half constant_variable = 15.203125h;
|
||||
|
||||
static half private_variable = 1.0h;
|
||||
cbuffer input_uniform : register(b0) { UniformCompatible input_uniform; }
|
||||
ByteAddressBuffer input_storage : register(t1);
|
||||
ByteAddressBuffer input_arrays : register(t2);
|
||||
RWByteAddressBuffer output : register(u3);
|
||||
RWByteAddressBuffer output_arrays : register(u4);
|
||||
|
||||
typedef half ret_Constructarray2_half_[2];
|
||||
ret_Constructarray2_half_ Constructarray2_half_(half arg0, half arg1) {
|
||||
half ret[2] = { arg0, arg1 };
|
||||
return ret;
|
||||
}
|
||||
|
||||
half f16_function(half x)
|
||||
{
|
||||
half val = 15.203125h;
|
||||
|
||||
half _expr6 = val;
|
||||
val = (_expr6 + (1.0h - 33344.0h));
|
||||
half _expr8 = val;
|
||||
half _expr11 = val;
|
||||
val = (_expr11 + (_expr8 + 5.0h));
|
||||
float _expr15 = input_uniform.val_f32_;
|
||||
half _expr16 = val;
|
||||
half _expr20 = val;
|
||||
val = (_expr20 + half((_expr15 + float(_expr16))));
|
||||
half _expr24 = input_uniform.val_f16_;
|
||||
half _expr27 = val;
|
||||
val = (_expr27 + (_expr24).xxx.z);
|
||||
half _expr33 = input_uniform.val_f16_;
|
||||
half _expr36 = input_storage.Load<half>(12);
|
||||
output.Store(12, (_expr33 + _expr36));
|
||||
half2 _expr42 = input_uniform.val_f16_2_;
|
||||
half2 _expr45 = input_storage.Load<half2>(16);
|
||||
output.Store(16, (_expr42 + _expr45));
|
||||
half3 _expr51 = input_uniform.val_f16_3_;
|
||||
half3 _expr54 = input_storage.Load<half3>(24);
|
||||
output.Store(24, (_expr51 + _expr54));
|
||||
half4 _expr60 = input_uniform.val_f16_4_;
|
||||
half4 _expr63 = input_storage.Load<half4>(32);
|
||||
output.Store(32, (_expr60 + _expr63));
|
||||
half _expr69[2] = Constructarray2_half_(input_arrays.Load<half>(0+0), input_arrays.Load<half>(0+2));
|
||||
{
|
||||
half _value2[2] = _expr69;
|
||||
output_arrays.Store(0+0, _value2[0]);
|
||||
output_arrays.Store(0+2, _value2[1]);
|
||||
}
|
||||
half _expr70 = val;
|
||||
half _expr72 = val;
|
||||
val = (_expr72 + abs(_expr70));
|
||||
half _expr74 = val;
|
||||
half _expr75 = val;
|
||||
half _expr76 = val;
|
||||
half _expr78 = val;
|
||||
val = (_expr78 + clamp(_expr74, _expr75, _expr76));
|
||||
half _expr80 = val;
|
||||
half _expr82 = val;
|
||||
half _expr85 = val;
|
||||
val = (_expr85 + dot((_expr80).xx, (_expr82).xx));
|
||||
half _expr87 = val;
|
||||
half _expr88 = val;
|
||||
half _expr90 = val;
|
||||
val = (_expr90 + max(_expr87, _expr88));
|
||||
half _expr92 = val;
|
||||
half _expr93 = val;
|
||||
half _expr95 = val;
|
||||
val = (_expr95 + min(_expr92, _expr93));
|
||||
half _expr97 = val;
|
||||
half _expr99 = val;
|
||||
val = (_expr99 + sign(_expr97));
|
||||
return 1.0h;
|
||||
}
|
||||
|
||||
[numthreads(1, 1, 1)]
|
||||
void main()
|
||||
{
|
||||
const half _e3 = f16_function(2.0h);
|
||||
output.Store(40, _e3);
|
||||
return;
|
||||
}
|
12
naga/tests/out/hlsl/float16.ron
Normal file
12
naga/tests/out/hlsl/float16.ron
Normal file
@ -0,0 +1,12 @@
|
||||
(
|
||||
vertex:[
|
||||
],
|
||||
fragment:[
|
||||
],
|
||||
compute:[
|
||||
(
|
||||
entry_point:"main",
|
||||
target_profile:"cs_6_2",
|
||||
),
|
||||
],
|
||||
)
|
99
naga/tests/out/msl/float16.msl
Normal file
99
naga/tests/out/msl/float16.msl
Normal file
@ -0,0 +1,99 @@
|
||||
// language: metal1.0
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
using metal::uint;
|
||||
|
||||
struct UniformCompatible {
|
||||
uint val_u32_;
|
||||
int val_i32_;
|
||||
float val_f32_;
|
||||
half val_f16_;
|
||||
char _pad4[2];
|
||||
metal::half2 val_f16_2_;
|
||||
char _pad5[4];
|
||||
metal::half3 val_f16_3_;
|
||||
metal::half4 val_f16_4_;
|
||||
half final_value;
|
||||
};
|
||||
struct type_7 {
|
||||
half inner[2];
|
||||
};
|
||||
struct StorageCompatible {
|
||||
type_7 val_f16_array_2_;
|
||||
type_7 val_f16_array_2_1;
|
||||
};
|
||||
constant half constant_variable = 15.203125;
|
||||
|
||||
half f16_function(
|
||||
half x,
|
||||
constant UniformCompatible& input_uniform,
|
||||
device UniformCompatible const& input_storage,
|
||||
device StorageCompatible const& input_arrays,
|
||||
device UniformCompatible& output,
|
||||
device StorageCompatible& output_arrays
|
||||
) {
|
||||
half val = 15.203125;
|
||||
half _e6 = val;
|
||||
val = _e6 + (1.0 - 33344.0);
|
||||
half _e8 = val;
|
||||
half _e11 = val;
|
||||
val = _e11 + (_e8 + 5.0);
|
||||
float _e15 = input_uniform.val_f32_;
|
||||
half _e16 = val;
|
||||
half _e20 = val;
|
||||
val = _e20 + static_cast<half>(_e15 + static_cast<float>(_e16));
|
||||
half _e24 = input_uniform.val_f16_;
|
||||
half _e27 = val;
|
||||
val = _e27 + metal::half3(_e24).z;
|
||||
half _e33 = input_uniform.val_f16_;
|
||||
half _e36 = input_storage.val_f16_;
|
||||
output.val_f16_ = _e33 + _e36;
|
||||
metal::half2 _e42 = input_uniform.val_f16_2_;
|
||||
metal::half2 _e45 = input_storage.val_f16_2_;
|
||||
output.val_f16_2_ = _e42 + _e45;
|
||||
metal::half3 _e51 = input_uniform.val_f16_3_;
|
||||
metal::half3 _e54 = input_storage.val_f16_3_;
|
||||
output.val_f16_3_ = _e51 + _e54;
|
||||
metal::half4 _e60 = input_uniform.val_f16_4_;
|
||||
metal::half4 _e63 = input_storage.val_f16_4_;
|
||||
output.val_f16_4_ = _e60 + _e63;
|
||||
type_7 _e69 = input_arrays.val_f16_array_2_;
|
||||
output_arrays.val_f16_array_2_ = _e69;
|
||||
half _e70 = val;
|
||||
half _e72 = val;
|
||||
val = _e72 + metal::abs(_e70);
|
||||
half _e74 = val;
|
||||
half _e75 = val;
|
||||
half _e76 = val;
|
||||
half _e78 = val;
|
||||
val = _e78 + metal::clamp(_e74, _e75, _e76);
|
||||
half _e80 = val;
|
||||
half _e82 = val;
|
||||
half _e85 = val;
|
||||
val = _e85 + metal::dot(metal::half2(_e80), metal::half2(_e82));
|
||||
half _e87 = val;
|
||||
half _e88 = val;
|
||||
half _e90 = val;
|
||||
val = _e90 + metal::max(_e87, _e88);
|
||||
half _e92 = val;
|
||||
half _e93 = val;
|
||||
half _e95 = val;
|
||||
val = _e95 + metal::min(_e92, _e93);
|
||||
half _e97 = val;
|
||||
half _e99 = val;
|
||||
val = _e99 + metal::sign(_e97);
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
kernel void main_(
|
||||
constant UniformCompatible& input_uniform [[user(fake0)]]
|
||||
, device UniformCompatible const& input_storage [[user(fake0)]]
|
||||
, device StorageCompatible const& input_arrays [[user(fake0)]]
|
||||
, device UniformCompatible& output [[user(fake0)]]
|
||||
, device StorageCompatible& output_arrays [[user(fake0)]]
|
||||
) {
|
||||
half _e3 = f16_function(2.0, input_uniform, input_storage, input_arrays, output, output_arrays);
|
||||
output.final_value = _e3;
|
||||
return;
|
||||
}
|
220
naga/tests/out/spv/float16.spvasm
Normal file
220
naga/tests/out/spv/float16.spvasm
Normal file
@ -0,0 +1,220 @@
|
||||
; SPIR-V
|
||||
; Version: 1.0
|
||||
; Generator: rspirv
|
||||
; Bound: 157
|
||||
OpCapability Shader
|
||||
OpExtension "SPV_KHR_storage_buffer_storage_class"
|
||||
%1 = OpExtInstImport "GLSL.std.450"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %145 "main"
|
||||
OpExecutionMode %145 LocalSize 1 1 1
|
||||
OpMemberDecorate %10 0 Offset 0
|
||||
OpMemberDecorate %10 1 Offset 4
|
||||
OpMemberDecorate %10 2 Offset 8
|
||||
OpMemberDecorate %10 3 Offset 12
|
||||
OpMemberDecorate %10 4 Offset 16
|
||||
OpMemberDecorate %10 5 Offset 24
|
||||
OpMemberDecorate %10 6 Offset 32
|
||||
OpMemberDecorate %10 7 Offset 40
|
||||
OpDecorate %11 ArrayStride 2
|
||||
OpMemberDecorate %13 0 Offset 0
|
||||
OpMemberDecorate %13 1 Offset 4
|
||||
OpDecorate %18 DescriptorSet 0
|
||||
OpDecorate %18 Binding 0
|
||||
OpDecorate %19 Block
|
||||
OpMemberDecorate %19 0 Offset 0
|
||||
OpDecorate %21 NonWritable
|
||||
OpDecorate %21 DescriptorSet 0
|
||||
OpDecorate %21 Binding 1
|
||||
OpDecorate %22 Block
|
||||
OpMemberDecorate %22 0 Offset 0
|
||||
OpDecorate %24 NonWritable
|
||||
OpDecorate %24 DescriptorSet 0
|
||||
OpDecorate %24 Binding 2
|
||||
OpDecorate %25 Block
|
||||
OpMemberDecorate %25 0 Offset 0
|
||||
OpDecorate %27 DescriptorSet 0
|
||||
OpDecorate %27 Binding 3
|
||||
OpDecorate %28 Block
|
||||
OpMemberDecorate %28 0 Offset 0
|
||||
OpDecorate %30 DescriptorSet 0
|
||||
OpDecorate %30 Binding 4
|
||||
OpDecorate %31 Block
|
||||
OpMemberDecorate %31 0 Offset 0
|
||||
%2 = OpTypeVoid
|
||||
%3 = OpTypeFloat 16
|
||||
%4 = OpTypeInt 32 0
|
||||
%5 = OpTypeInt 32 1
|
||||
%6 = OpTypeFloat 32
|
||||
%7 = OpTypeVector %3 2
|
||||
%8 = OpTypeVector %3 3
|
||||
%9 = OpTypeVector %3 4
|
||||
%10 = OpTypeStruct %4 %5 %6 %3 %7 %8 %9 %3
|
||||
%12 = OpConstant %4 2
|
||||
%11 = OpTypeArray %3 %12
|
||||
%13 = OpTypeStruct %11 %11
|
||||
%14 = OpConstant %3 2.1524e-41
|
||||
%15 = OpConstant %3 2.7121e-41
|
||||
%17 = OpTypePointer Private %3
|
||||
%16 = OpVariable %17 Private %14
|
||||
%19 = OpTypeStruct %10
|
||||
%20 = OpTypePointer Uniform %19
|
||||
%18 = OpVariable %20 Uniform
|
||||
%22 = OpTypeStruct %10
|
||||
%23 = OpTypePointer StorageBuffer %22
|
||||
%21 = OpVariable %23 StorageBuffer
|
||||
%25 = OpTypeStruct %13
|
||||
%26 = OpTypePointer StorageBuffer %25
|
||||
%24 = OpVariable %26 StorageBuffer
|
||||
%28 = OpTypeStruct %10
|
||||
%29 = OpTypePointer StorageBuffer %28
|
||||
%27 = OpVariable %29 StorageBuffer
|
||||
%31 = OpTypeStruct %13
|
||||
%32 = OpTypePointer StorageBuffer %31
|
||||
%30 = OpVariable %32 StorageBuffer
|
||||
%36 = OpTypeFunction %3 %3
|
||||
%37 = OpTypePointer Uniform %10
|
||||
%38 = OpConstant %4 0
|
||||
%40 = OpTypePointer StorageBuffer %10
|
||||
%42 = OpTypePointer StorageBuffer %13
|
||||
%46 = OpConstant %3 4.3073e-41
|
||||
%47 = OpConstant %3 2.4753e-41
|
||||
%49 = OpTypePointer Function %3
|
||||
%58 = OpTypePointer Uniform %6
|
||||
%67 = OpTypePointer Uniform %3
|
||||
%68 = OpConstant %4 3
|
||||
%75 = OpTypePointer StorageBuffer %3
|
||||
%82 = OpTypePointer StorageBuffer %7
|
||||
%83 = OpTypePointer Uniform %7
|
||||
%84 = OpConstant %4 4
|
||||
%91 = OpTypePointer StorageBuffer %8
|
||||
%92 = OpTypePointer Uniform %8
|
||||
%93 = OpConstant %4 5
|
||||
%100 = OpTypePointer StorageBuffer %9
|
||||
%101 = OpTypePointer Uniform %9
|
||||
%102 = OpConstant %4 6
|
||||
%109 = OpTypePointer StorageBuffer %11
|
||||
%146 = OpTypeFunction %2
|
||||
%152 = OpConstant %3 2.2959e-41
|
||||
%155 = OpConstant %4 7
|
||||
%35 = OpFunction %3 None %36
|
||||
%34 = OpFunctionParameter %3
|
||||
%33 = OpLabel
|
||||
%48 = OpVariable %49 Function %15
|
||||
%39 = OpAccessChain %37 %18 %38
|
||||
%41 = OpAccessChain %40 %21 %38
|
||||
%43 = OpAccessChain %42 %24 %38
|
||||
%44 = OpAccessChain %40 %27 %38
|
||||
%45 = OpAccessChain %42 %30 %38
|
||||
OpBranch %50
|
||||
%50 = OpLabel
|
||||
%51 = OpFSub %3 %14 %46
|
||||
%52 = OpLoad %3 %48
|
||||
%53 = OpFAdd %3 %52 %51
|
||||
OpStore %48 %53
|
||||
%54 = OpLoad %3 %48
|
||||
%55 = OpFAdd %3 %54 %47
|
||||
%56 = OpLoad %3 %48
|
||||
%57 = OpFAdd %3 %56 %55
|
||||
OpStore %48 %57
|
||||
%59 = OpAccessChain %58 %39 %12
|
||||
%60 = OpLoad %6 %59
|
||||
%61 = OpLoad %3 %48
|
||||
%62 = OpFConvert %6 %61
|
||||
%63 = OpFAdd %6 %60 %62
|
||||
%64 = OpFConvert %3 %63
|
||||
%65 = OpLoad %3 %48
|
||||
%66 = OpFAdd %3 %65 %64
|
||||
OpStore %48 %66
|
||||
%69 = OpAccessChain %67 %39 %68
|
||||
%70 = OpLoad %3 %69
|
||||
%71 = OpCompositeConstruct %8 %70 %70 %70
|
||||
%72 = OpCompositeExtract %3 %71 2
|
||||
%73 = OpLoad %3 %48
|
||||
%74 = OpFAdd %3 %73 %72
|
||||
OpStore %48 %74
|
||||
%76 = OpAccessChain %67 %39 %68
|
||||
%77 = OpLoad %3 %76
|
||||
%78 = OpAccessChain %75 %41 %68
|
||||
%79 = OpLoad %3 %78
|
||||
%80 = OpFAdd %3 %77 %79
|
||||
%81 = OpAccessChain %75 %44 %68
|
||||
OpStore %81 %80
|
||||
%85 = OpAccessChain %83 %39 %84
|
||||
%86 = OpLoad %7 %85
|
||||
%87 = OpAccessChain %82 %41 %84
|
||||
%88 = OpLoad %7 %87
|
||||
%89 = OpFAdd %7 %86 %88
|
||||
%90 = OpAccessChain %82 %44 %84
|
||||
OpStore %90 %89
|
||||
%94 = OpAccessChain %92 %39 %93
|
||||
%95 = OpLoad %8 %94
|
||||
%96 = OpAccessChain %91 %41 %93
|
||||
%97 = OpLoad %8 %96
|
||||
%98 = OpFAdd %8 %95 %97
|
||||
%99 = OpAccessChain %91 %44 %93
|
||||
OpStore %99 %98
|
||||
%103 = OpAccessChain %101 %39 %102
|
||||
%104 = OpLoad %9 %103
|
||||
%105 = OpAccessChain %100 %41 %102
|
||||
%106 = OpLoad %9 %105
|
||||
%107 = OpFAdd %9 %104 %106
|
||||
%108 = OpAccessChain %100 %44 %102
|
||||
OpStore %108 %107
|
||||
%110 = OpAccessChain %109 %43 %38
|
||||
%111 = OpLoad %11 %110
|
||||
%112 = OpAccessChain %109 %45 %38
|
||||
OpStore %112 %111
|
||||
%113 = OpLoad %3 %48
|
||||
%114 = OpExtInst %3 %1 FAbs %113
|
||||
%115 = OpLoad %3 %48
|
||||
%116 = OpFAdd %3 %115 %114
|
||||
OpStore %48 %116
|
||||
%117 = OpLoad %3 %48
|
||||
%118 = OpLoad %3 %48
|
||||
%119 = OpLoad %3 %48
|
||||
%120 = OpExtInst %3 %1 FClamp %117 %118 %119
|
||||
%121 = OpLoad %3 %48
|
||||
%122 = OpFAdd %3 %121 %120
|
||||
OpStore %48 %122
|
||||
%123 = OpLoad %3 %48
|
||||
%124 = OpCompositeConstruct %7 %123 %123
|
||||
%125 = OpLoad %3 %48
|
||||
%126 = OpCompositeConstruct %7 %125 %125
|
||||
%127 = OpDot %3 %124 %126
|
||||
%128 = OpLoad %3 %48
|
||||
%129 = OpFAdd %3 %128 %127
|
||||
OpStore %48 %129
|
||||
%130 = OpLoad %3 %48
|
||||
%131 = OpLoad %3 %48
|
||||
%132 = OpExtInst %3 %1 FMax %130 %131
|
||||
%133 = OpLoad %3 %48
|
||||
%134 = OpFAdd %3 %133 %132
|
||||
OpStore %48 %134
|
||||
%135 = OpLoad %3 %48
|
||||
%136 = OpLoad %3 %48
|
||||
%137 = OpExtInst %3 %1 FMin %135 %136
|
||||
%138 = OpLoad %3 %48
|
||||
%139 = OpFAdd %3 %138 %137
|
||||
OpStore %48 %139
|
||||
%140 = OpLoad %3 %48
|
||||
%141 = OpExtInst %3 %1 FSign %140
|
||||
%142 = OpLoad %3 %48
|
||||
%143 = OpFAdd %3 %142 %141
|
||||
OpStore %48 %143
|
||||
OpReturnValue %14
|
||||
OpFunctionEnd
|
||||
%145 = OpFunction %2 None %146
|
||||
%144 = OpLabel
|
||||
%147 = OpAccessChain %37 %18 %38
|
||||
%148 = OpAccessChain %40 %21 %38
|
||||
%149 = OpAccessChain %42 %24 %38
|
||||
%150 = OpAccessChain %40 %27 %38
|
||||
%151 = OpAccessChain %42 %30 %38
|
||||
OpBranch %153
|
||||
%153 = OpLabel
|
||||
%154 = OpFunctionCall %3 %35 %152
|
||||
%156 = OpAccessChain %75 %150 %155
|
||||
OpStore %156 %154
|
||||
OpReturn
|
||||
OpFunctionEnd
|
91
naga/tests/out/wgsl/float16.wgsl
Normal file
91
naga/tests/out/wgsl/float16.wgsl
Normal file
@ -0,0 +1,91 @@
|
||||
struct UniformCompatible {
|
||||
val_u32_: u32,
|
||||
val_i32_: i32,
|
||||
val_f32_: f32,
|
||||
val_f16_: f16,
|
||||
val_f16_2_: vec2<f16>,
|
||||
val_f16_3_: vec3<f16>,
|
||||
val_f16_4_: vec4<f16>,
|
||||
final_value: f16,
|
||||
}
|
||||
|
||||
struct StorageCompatible {
|
||||
val_f16_array_2_: array<f16, 2>,
|
||||
val_f16_array_2_1: array<f16, 2>,
|
||||
}
|
||||
|
||||
const constant_variable: f16 = 15.203125h;
|
||||
|
||||
var<private> private_variable: f16 = 1h;
|
||||
@group(0) @binding(0)
|
||||
var<uniform> input_uniform: UniformCompatible;
|
||||
@group(0) @binding(1)
|
||||
var<storage> input_storage: UniformCompatible;
|
||||
@group(0) @binding(2)
|
||||
var<storage> input_arrays: StorageCompatible;
|
||||
@group(0) @binding(3)
|
||||
var<storage, read_write> output: UniformCompatible;
|
||||
@group(0) @binding(4)
|
||||
var<storage, read_write> output_arrays: StorageCompatible;
|
||||
|
||||
fn f16_function(x: f16) -> f16 {
|
||||
var val: f16 = 15.203125h;
|
||||
|
||||
let _e6 = val;
|
||||
val = (_e6 + (1h - 33344h));
|
||||
let _e8 = val;
|
||||
let _e11 = val;
|
||||
val = (_e11 + (_e8 + 5h));
|
||||
let _e15 = input_uniform.val_f32_;
|
||||
let _e16 = val;
|
||||
let _e20 = val;
|
||||
val = (_e20 + f16((_e15 + f32(_e16))));
|
||||
let _e24 = input_uniform.val_f16_;
|
||||
let _e27 = val;
|
||||
val = (_e27 + vec3(_e24).z);
|
||||
let _e33 = input_uniform.val_f16_;
|
||||
let _e36 = input_storage.val_f16_;
|
||||
output.val_f16_ = (_e33 + _e36);
|
||||
let _e42 = input_uniform.val_f16_2_;
|
||||
let _e45 = input_storage.val_f16_2_;
|
||||
output.val_f16_2_ = (_e42 + _e45);
|
||||
let _e51 = input_uniform.val_f16_3_;
|
||||
let _e54 = input_storage.val_f16_3_;
|
||||
output.val_f16_3_ = (_e51 + _e54);
|
||||
let _e60 = input_uniform.val_f16_4_;
|
||||
let _e63 = input_storage.val_f16_4_;
|
||||
output.val_f16_4_ = (_e60 + _e63);
|
||||
let _e69 = input_arrays.val_f16_array_2_;
|
||||
output_arrays.val_f16_array_2_ = _e69;
|
||||
let _e70 = val;
|
||||
let _e72 = val;
|
||||
val = (_e72 + abs(_e70));
|
||||
let _e74 = val;
|
||||
let _e75 = val;
|
||||
let _e76 = val;
|
||||
let _e78 = val;
|
||||
val = (_e78 + clamp(_e74, _e75, _e76));
|
||||
let _e80 = val;
|
||||
let _e82 = val;
|
||||
let _e85 = val;
|
||||
val = (_e85 + dot(vec2(_e80), vec2(_e82)));
|
||||
let _e87 = val;
|
||||
let _e88 = val;
|
||||
let _e90 = val;
|
||||
val = (_e90 + max(_e87, _e88));
|
||||
let _e92 = val;
|
||||
let _e93 = val;
|
||||
let _e95 = val;
|
||||
val = (_e95 + min(_e92, _e93));
|
||||
let _e97 = val;
|
||||
let _e99 = val;
|
||||
val = (_e99 + sign(_e97));
|
||||
return 1h;
|
||||
}
|
||||
|
||||
@compute @workgroup_size(1, 1, 1)
|
||||
fn main() {
|
||||
let _e3 = f16_function(2h);
|
||||
output.final_value = _e3;
|
||||
return;
|
||||
}
|
@ -905,6 +905,10 @@ fn convert_wgsl() {
|
||||
"int64",
|
||||
Targets::SPIRV | Targets::HLSL | Targets::WGSL | Targets::METAL,
|
||||
),
|
||||
(
|
||||
"float16",
|
||||
Targets::SPIRV | Targets::HLSL | Targets::WGSL | Targets::METAL,
|
||||
),
|
||||
(
|
||||
"subgroup-operations",
|
||||
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
|
||||
|
@ -481,6 +481,10 @@ pub fn create_validator(
|
||||
features.contains(wgt::Features::PUSH_CONSTANTS),
|
||||
);
|
||||
caps.set(Caps::FLOAT64, features.contains(wgt::Features::SHADER_F64));
|
||||
caps.set(
|
||||
Caps::SHADER_FLOAT16,
|
||||
features.contains(wgt::Features::SHADER_F16),
|
||||
);
|
||||
caps.set(
|
||||
Caps::PRIMITIVE_INDEX,
|
||||
features.contains(wgt::Features::SHADER_PRIMITIVE_INDEX),
|
||||
|
@ -1,5 +1,5 @@
|
||||
use std::{
|
||||
mem::{size_of, size_of_val},
|
||||
mem::{self, size_of, size_of_val},
|
||||
ptr,
|
||||
sync::Arc,
|
||||
thread,
|
||||
@ -378,6 +378,24 @@ impl super::Adapter {
|
||||
&& features1.Int64ShaderOps.as_bool(),
|
||||
);
|
||||
|
||||
let float16_supported = {
|
||||
let mut features4: Direct3D12::D3D12_FEATURE_DATA_D3D12_OPTIONS4 =
|
||||
unsafe { mem::zeroed() };
|
||||
let hr = unsafe {
|
||||
device.CheckFeatureSupport(
|
||||
Direct3D12::D3D12_FEATURE_D3D12_OPTIONS4, // https://learn.microsoft.com/en-us/windows/win32/api/d3d12/ne-d3d12-d3d12_feature#syntax
|
||||
ptr::from_mut(&mut features4).cast(),
|
||||
size_of::<Direct3D12::D3D12_FEATURE_DATA_D3D12_OPTIONS4>() as _,
|
||||
)
|
||||
};
|
||||
hr.is_ok() && features4.Native16BitShaderOpsSupported.as_bool()
|
||||
};
|
||||
|
||||
features.set(
|
||||
wgt::Features::SHADER_F16,
|
||||
shader_model >= naga::back::hlsl::ShaderModel::V6_2 && float16_supported,
|
||||
);
|
||||
|
||||
features.set(
|
||||
wgt::Features::SUBGROUP,
|
||||
shader_model >= naga::back::hlsl::ShaderModel::V6_0
|
||||
|
@ -382,6 +382,7 @@ impl PhysicalDeviceFeatures {
|
||||
vk::PhysicalDeviceShaderFloat16Int8Features::default().shader_float16(true),
|
||||
vk::PhysicalDevice16BitStorageFeatures::default()
|
||||
.storage_buffer16_bit_access(true)
|
||||
.storage_input_output16(true)
|
||||
.uniform_and_storage_buffer16_bit_access(true),
|
||||
))
|
||||
} else {
|
||||
@ -664,7 +665,8 @@ impl PhysicalDeviceFeatures {
|
||||
F::SHADER_F16,
|
||||
f16_i8.shader_float16 != 0
|
||||
&& bit16.storage_buffer16_bit_access != 0
|
||||
&& bit16.uniform_and_storage_buffer16_bit_access != 0,
|
||||
&& bit16.uniform_and_storage_buffer16_bit_access != 0
|
||||
&& bit16.storage_input_output16 != 0,
|
||||
);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user