WIP SampleParams

This commit is contained in:
Sylvester Hesp 2023-04-20 12:57:54 +02:00
parent fd73e1b462
commit 061727bba0
4 changed files with 222 additions and 3 deletions

View File

@ -14,4 +14,4 @@ proc-macro = true
spirv-std-types.workspace = true
proc-macro2 = "1.0.24"
quote = "1.0.8"
syn = { version = "1.0.58", features = ["full"] }
syn = { version = "1.0.58", features = ["full", "visit-mut"] }

View File

@ -76,7 +76,7 @@ mod image;
use proc_macro::TokenStream;
use proc_macro2::{Delimiter, Group, Ident, Span, TokenTree};
use syn::{punctuated::Punctuated, spanned::Spanned, ItemFn, Token};
use syn::{punctuated::Punctuated, spanned::Spanned, visit_mut::VisitMut, ItemFn, Token};
use quote::{quote, ToTokens};
use std::fmt::Write;
@ -625,3 +625,132 @@ fn debug_printf_inner(input: DebugPrintfInput) -> TokenStream {
output.into()
}
const SAMPLE_PARAM_COUNT: usize = 3;
const SAMPLE_PARAM_TYPES: [&str; SAMPLE_PARAM_COUNT] = ["B", "L", "S"];
const SAMPLE_PARAM_OPERANDS: [&str; SAMPLE_PARAM_COUNT] = ["Bias", "Lod", "Sample"];
const SAMPLE_PARAM_NAMES: [&str; SAMPLE_PARAM_COUNT] = ["bias", "lod", "sample_index"];
struct SampleFnRewriter(usize);
impl SampleFnRewriter {
pub fn rewrite(mask: usize, f: &syn::ItemFn) -> syn::ItemFn {
let mut new_f = f.clone();
let mut ty = String::from("SampleParams<");
for i in 0..SAMPLE_PARAM_COUNT {
if mask & (1 << i) != 0 {
new_f.sig.generics.params.push(syn::GenericParam::Type(
syn::Ident::new(SAMPLE_PARAM_TYPES[i], Span::call_site()).into(),
));
ty.push_str(SAMPLE_PARAM_TYPES[i]);
} else {
ty.push_str("()");
}
ty.push(',');
}
ty.push('>');
if let Some(syn::FnArg::Typed(p)) = new_f.sig.inputs.last_mut() {
*p.ty.as_mut() = syn::parse(ty.parse().unwrap()).unwrap();
}
SampleFnRewriter(mask).visit_item_fn_mut(&mut new_f);
new_f
}
fn get_operands(&self) -> String {
let mut op = String::new();
for i in 0..SAMPLE_PARAM_COUNT {
if self.0 & (1 << i) != 0 {
op.push_str(SAMPLE_PARAM_OPERANDS[i]);
op.push_str(" %");
op.push_str(SAMPLE_PARAM_NAMES[i]);
op.push(' ');
}
}
op
}
fn add_loads(&self, t: &mut Vec<TokenTree>) {
for i in 0..SAMPLE_PARAM_COUNT {
if self.0 & (1 << i) != 0 {
let s = format!("%{0} = OpLoad _ {{{0}}}", SAMPLE_PARAM_NAMES[i]);
t.push(TokenTree::Literal(proc_macro2::Literal::string(s.as_str())));
t.push(TokenTree::Punct(proc_macro2::Punct::new(
',',
proc_macro2::Spacing::Alone,
)))
}
}
}
fn add_regs(&self, t: &mut Vec<TokenTree>) {
for i in 0..SAMPLE_PARAM_COUNT {
if self.0 & (1 << i) != 0 {
let s = format!("{0} = in(reg) &param.{0}", SAMPLE_PARAM_NAMES[i]);
t.push(TokenTree::Literal(proc_macro2::Literal::string(s.as_str())));
t.push(TokenTree::Punct(proc_macro2::Punct::new(
',',
proc_macro2::Spacing::Alone,
)))
}
}
}
}
impl syn::visit_mut::VisitMut for SampleFnRewriter {
fn visit_macro_mut(&mut self, m: &mut syn::Macro) {
if m.path.is_ident("asm") {
let t = m.tokens.clone();
let mut new_t = Vec::new();
let mut altered = false;
for tt in t {
match tt {
TokenTree::Literal(l) => {
if let Ok(l) = syn::parse::<syn::LitStr>(l.to_token_stream().into()) {
let s = l.value();
if s.contains("$PARAMS") {
altered = true;
self.add_loads(&mut new_t);
let s = s.replace("$PARAMS", &self.get_operands());
new_t.push(TokenTree::Literal(proc_macro2::Literal::string(
s.as_str(),
)));
} else {
new_t.push(TokenTree::Literal(l.token()));
}
} else {
new_t.push(TokenTree::Literal(l));
}
}
_ => {
new_t.push(tt);
}
}
}
if altered {
self.add_regs(&mut new_t);
}
m.tokens = new_t.into_iter().collect();
}
}
}
/// Generates permutations of a sampling function containing inline assembly with an image
/// instruction ending with a placeholder `$PARAMS` operand. The last parameter must be named
/// `params`, its type will be rewritten. Relevant generic arguments are added to the function
/// signature. See `SAMPLE_PARAM_TYPES` for a list of names you cannot use as generic arguments.
#[proc_macro_attribute]
#[doc(hidden)]
pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> TokenStream {
let item_fn = syn::parse_macro_input!(item as syn::ItemFn);
let mut fns = Vec::new();
for m in 1..((1 << SAMPLE_PARAM_COUNT) - 1) {
fns.push(SampleFnRewriter::rewrite(m, &item_fn));
}
quote! { #(#fns)* }.into()
}

View File

@ -10,7 +10,7 @@ use core::arch::asm;
#[rustfmt::skip]
mod params;
pub use self::params::{ImageCoordinate, ImageCoordinateSubpassData, SampleType};
pub use self::params::{ImageCoordinate, ImageCoordinateSubpassData, SampleParams, SampleType};
pub use crate::macros::Image;
pub use spirv_std_types::image_params::{
AccessQualifier, Arrayed, Dimensionality, ImageDepth, ImageFormat, Multisampled, Sampled,
@ -158,6 +158,33 @@ impl<
}
result.truncate_into()
}
#[crate::macros::gen_sample_param_permutations]
#[crate::macros::gpu_only]
#[doc(alias = "OpImageFetch")]
pub fn fetch_with<I>(
&self,
coordinate: impl ImageCoordinate<I, DIM, ARRAYED>,
params: SampleParams,
) -> SampledType::SampleResult
where
I: Integer,
B: Integer,
{
let mut result = SampledType::Vec4::default();
unsafe {
asm! {
"%image = OpLoad _ {this}",
"%coordinate = OpLoad _ {coordinate}",
"%result = OpImageFetch typeof*{result} %image %coordinate $PARAMS",
"OpStore {result} %result",
result = in(reg) &mut result,
this = in(reg) self,
coordinate = in(reg) &coordinate,
}
}
result.truncate_into()
}
}
impl<

View File

@ -194,3 +194,66 @@ impl<V: Vector<S, 4>, S: Scalar>
pub trait ImageCoordinateSubpassData<T, const ARRAYED: u32> {}
impl<V: Vector<I, 2>, I: Integer> ImageCoordinateSubpassData<I, { Arrayed::False as u32 }> for V {}
impl<V: Vector<I, 3>, I: Integer> ImageCoordinateSubpassData<I, { Arrayed::True as u32 }> for V {}
/// Helper struct that allows building image operands. Start with a global function that returns this
/// struct, and then chain additional calls.
/// Example: `image.sample_with(coords, params::bias(3.0).sample_index(1))`
pub struct SampleParams<B, L, S> {
bias: B,
lod: L,
sample_index: S,
}
pub fn bias<B>(bias: B) -> SampleParams<B, (), ()> {
SampleParams {
bias,
lod: (),
sample_index: (),
}
}
pub fn lod<L>(lod: L) -> SampleParams<(), L, ()> {
SampleParams {
bias: (),
lod,
sample_index: (),
}
}
pub fn sample_index<S>(sample_index: S) -> SampleParams<(), (), S> {
SampleParams {
bias: (),
lod: (),
sample_index,
}
}
impl<L, S> SampleParams<(), L, S> {
pub fn bias<B>(self, bias: B) -> SampleParams<B, L, S> {
SampleParams {
bias,
lod: self.lod,
sample_index: self.sample_index,
}
}
}
impl<B, S> SampleParams<B, (), S> {
pub fn lod<L>(self, lod: L) -> SampleParams<B, L, S> {
SampleParams {
bias: self.bias,
lod: lod,
sample_index: self.sample_index,
}
}
}
impl<B, L> SampleParams<B, L, ()> {
pub fn sample_index<S>(self, sample_index: S) -> SampleParams<B, L, S> {
SampleParams {
bias: self.bias,
lod: self.lod,
sample_index: sample_index,
}
}
}