mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-21 22:34:34 +00:00
WIP SampleParams
This commit is contained in:
parent
fd73e1b462
commit
061727bba0
@ -14,4 +14,4 @@ proc-macro = true
|
||||
spirv-std-types.workspace = true
|
||||
proc-macro2 = "1.0.24"
|
||||
quote = "1.0.8"
|
||||
syn = { version = "1.0.58", features = ["full"] }
|
||||
syn = { version = "1.0.58", features = ["full", "visit-mut"] }
|
||||
|
@ -76,7 +76,7 @@ mod image;
|
||||
use proc_macro::TokenStream;
|
||||
use proc_macro2::{Delimiter, Group, Ident, Span, TokenTree};
|
||||
|
||||
use syn::{punctuated::Punctuated, spanned::Spanned, ItemFn, Token};
|
||||
use syn::{punctuated::Punctuated, spanned::Spanned, visit_mut::VisitMut, ItemFn, Token};
|
||||
|
||||
use quote::{quote, ToTokens};
|
||||
use std::fmt::Write;
|
||||
@ -625,3 +625,132 @@ fn debug_printf_inner(input: DebugPrintfInput) -> TokenStream {
|
||||
|
||||
output.into()
|
||||
}
|
||||
|
||||
const SAMPLE_PARAM_COUNT: usize = 3;
|
||||
const SAMPLE_PARAM_TYPES: [&str; SAMPLE_PARAM_COUNT] = ["B", "L", "S"];
|
||||
const SAMPLE_PARAM_OPERANDS: [&str; SAMPLE_PARAM_COUNT] = ["Bias", "Lod", "Sample"];
|
||||
const SAMPLE_PARAM_NAMES: [&str; SAMPLE_PARAM_COUNT] = ["bias", "lod", "sample_index"];
|
||||
|
||||
struct SampleFnRewriter(usize);
|
||||
|
||||
impl SampleFnRewriter {
|
||||
pub fn rewrite(mask: usize, f: &syn::ItemFn) -> syn::ItemFn {
|
||||
let mut new_f = f.clone();
|
||||
let mut ty = String::from("SampleParams<");
|
||||
|
||||
for i in 0..SAMPLE_PARAM_COUNT {
|
||||
if mask & (1 << i) != 0 {
|
||||
new_f.sig.generics.params.push(syn::GenericParam::Type(
|
||||
syn::Ident::new(SAMPLE_PARAM_TYPES[i], Span::call_site()).into(),
|
||||
));
|
||||
ty.push_str(SAMPLE_PARAM_TYPES[i]);
|
||||
} else {
|
||||
ty.push_str("()");
|
||||
}
|
||||
ty.push(',');
|
||||
}
|
||||
ty.push('>');
|
||||
if let Some(syn::FnArg::Typed(p)) = new_f.sig.inputs.last_mut() {
|
||||
*p.ty.as_mut() = syn::parse(ty.parse().unwrap()).unwrap();
|
||||
}
|
||||
SampleFnRewriter(mask).visit_item_fn_mut(&mut new_f);
|
||||
new_f
|
||||
}
|
||||
|
||||
fn get_operands(&self) -> String {
|
||||
let mut op = String::new();
|
||||
for i in 0..SAMPLE_PARAM_COUNT {
|
||||
if self.0 & (1 << i) != 0 {
|
||||
op.push_str(SAMPLE_PARAM_OPERANDS[i]);
|
||||
op.push_str(" %");
|
||||
op.push_str(SAMPLE_PARAM_NAMES[i]);
|
||||
op.push(' ');
|
||||
}
|
||||
}
|
||||
op
|
||||
}
|
||||
|
||||
fn add_loads(&self, t: &mut Vec<TokenTree>) {
|
||||
for i in 0..SAMPLE_PARAM_COUNT {
|
||||
if self.0 & (1 << i) != 0 {
|
||||
let s = format!("%{0} = OpLoad _ {{{0}}}", SAMPLE_PARAM_NAMES[i]);
|
||||
t.push(TokenTree::Literal(proc_macro2::Literal::string(s.as_str())));
|
||||
t.push(TokenTree::Punct(proc_macro2::Punct::new(
|
||||
',',
|
||||
proc_macro2::Spacing::Alone,
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn add_regs(&self, t: &mut Vec<TokenTree>) {
|
||||
for i in 0..SAMPLE_PARAM_COUNT {
|
||||
if self.0 & (1 << i) != 0 {
|
||||
let s = format!("{0} = in(reg) ¶m.{0}", SAMPLE_PARAM_NAMES[i]);
|
||||
t.push(TokenTree::Literal(proc_macro2::Literal::string(s.as_str())));
|
||||
t.push(TokenTree::Punct(proc_macro2::Punct::new(
|
||||
',',
|
||||
proc_macro2::Spacing::Alone,
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl syn::visit_mut::VisitMut for SampleFnRewriter {
|
||||
fn visit_macro_mut(&mut self, m: &mut syn::Macro) {
|
||||
if m.path.is_ident("asm") {
|
||||
let t = m.tokens.clone();
|
||||
let mut new_t = Vec::new();
|
||||
let mut altered = false;
|
||||
|
||||
for tt in t {
|
||||
match tt {
|
||||
TokenTree::Literal(l) => {
|
||||
if let Ok(l) = syn::parse::<syn::LitStr>(l.to_token_stream().into()) {
|
||||
let s = l.value();
|
||||
if s.contains("$PARAMS") {
|
||||
altered = true;
|
||||
self.add_loads(&mut new_t);
|
||||
let s = s.replace("$PARAMS", &self.get_operands());
|
||||
new_t.push(TokenTree::Literal(proc_macro2::Literal::string(
|
||||
s.as_str(),
|
||||
)));
|
||||
} else {
|
||||
new_t.push(TokenTree::Literal(l.token()));
|
||||
}
|
||||
} else {
|
||||
new_t.push(TokenTree::Literal(l));
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
new_t.push(tt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if altered {
|
||||
self.add_regs(&mut new_t);
|
||||
}
|
||||
|
||||
m.tokens = new_t.into_iter().collect();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates permutations of a sampling function containing inline assembly with an image
|
||||
/// instruction ending with a placeholder `$PARAMS` operand. The last parameter must be named
|
||||
/// `params`, its type will be rewritten. Relevant generic arguments are added to the function
|
||||
/// signature. See `SAMPLE_PARAM_TYPES` for a list of names you cannot use as generic arguments.
|
||||
#[proc_macro_attribute]
|
||||
#[doc(hidden)]
|
||||
pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
let item_fn = syn::parse_macro_input!(item as syn::ItemFn);
|
||||
let mut fns = Vec::new();
|
||||
|
||||
for m in 1..((1 << SAMPLE_PARAM_COUNT) - 1) {
|
||||
fns.push(SampleFnRewriter::rewrite(m, &item_fn));
|
||||
}
|
||||
|
||||
quote! { #(#fns)* }.into()
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ use core::arch::asm;
|
||||
#[rustfmt::skip]
|
||||
mod params;
|
||||
|
||||
pub use self::params::{ImageCoordinate, ImageCoordinateSubpassData, SampleType};
|
||||
pub use self::params::{ImageCoordinate, ImageCoordinateSubpassData, SampleParams, SampleType};
|
||||
pub use crate::macros::Image;
|
||||
pub use spirv_std_types::image_params::{
|
||||
AccessQualifier, Arrayed, Dimensionality, ImageDepth, ImageFormat, Multisampled, Sampled,
|
||||
@ -158,6 +158,33 @@ impl<
|
||||
}
|
||||
result.truncate_into()
|
||||
}
|
||||
|
||||
#[crate::macros::gen_sample_param_permutations]
|
||||
#[crate::macros::gpu_only]
|
||||
#[doc(alias = "OpImageFetch")]
|
||||
pub fn fetch_with<I>(
|
||||
&self,
|
||||
coordinate: impl ImageCoordinate<I, DIM, ARRAYED>,
|
||||
params: SampleParams,
|
||||
) -> SampledType::SampleResult
|
||||
where
|
||||
I: Integer,
|
||||
B: Integer,
|
||||
{
|
||||
let mut result = SampledType::Vec4::default();
|
||||
unsafe {
|
||||
asm! {
|
||||
"%image = OpLoad _ {this}",
|
||||
"%coordinate = OpLoad _ {coordinate}",
|
||||
"%result = OpImageFetch typeof*{result} %image %coordinate $PARAMS",
|
||||
"OpStore {result} %result",
|
||||
result = in(reg) &mut result,
|
||||
this = in(reg) self,
|
||||
coordinate = in(reg) &coordinate,
|
||||
}
|
||||
}
|
||||
result.truncate_into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
|
@ -194,3 +194,66 @@ impl<V: Vector<S, 4>, S: Scalar>
|
||||
pub trait ImageCoordinateSubpassData<T, const ARRAYED: u32> {}
|
||||
impl<V: Vector<I, 2>, I: Integer> ImageCoordinateSubpassData<I, { Arrayed::False as u32 }> for V {}
|
||||
impl<V: Vector<I, 3>, I: Integer> ImageCoordinateSubpassData<I, { Arrayed::True as u32 }> for V {}
|
||||
|
||||
/// Helper struct that allows building image operands. Start with a global function that returns this
|
||||
/// struct, and then chain additional calls.
|
||||
/// Example: `image.sample_with(coords, params::bias(3.0).sample_index(1))`
|
||||
pub struct SampleParams<B, L, S> {
|
||||
bias: B,
|
||||
lod: L,
|
||||
sample_index: S,
|
||||
}
|
||||
|
||||
pub fn bias<B>(bias: B) -> SampleParams<B, (), ()> {
|
||||
SampleParams {
|
||||
bias,
|
||||
lod: (),
|
||||
sample_index: (),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn lod<L>(lod: L) -> SampleParams<(), L, ()> {
|
||||
SampleParams {
|
||||
bias: (),
|
||||
lod,
|
||||
sample_index: (),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sample_index<S>(sample_index: S) -> SampleParams<(), (), S> {
|
||||
SampleParams {
|
||||
bias: (),
|
||||
lod: (),
|
||||
sample_index,
|
||||
}
|
||||
}
|
||||
|
||||
impl<L, S> SampleParams<(), L, S> {
|
||||
pub fn bias<B>(self, bias: B) -> SampleParams<B, L, S> {
|
||||
SampleParams {
|
||||
bias,
|
||||
lod: self.lod,
|
||||
sample_index: self.sample_index,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, S> SampleParams<B, (), S> {
|
||||
pub fn lod<L>(self, lod: L) -> SampleParams<B, L, S> {
|
||||
SampleParams {
|
||||
bias: self.bias,
|
||||
lod: lod,
|
||||
sample_index: self.sample_index,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, L> SampleParams<B, L, ()> {
|
||||
pub fn sample_index<S>(self, sample_index: S) -> SampleParams<B, L, S> {
|
||||
SampleParams {
|
||||
bias: self.bias,
|
||||
lod: self.lod,
|
||||
sample_index: sample_index,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user