Refactor snapshot testing to include some IR

This commit is contained in:
Dzmitry Malyshau 2021-02-17 21:46:32 -05:00 committed by Dzmitry Malyshau
parent 4c5a1ba054
commit 5f21cf360f
5 changed files with 1835 additions and 71 deletions

View File

@ -11,6 +11,8 @@ use std::ops;
bitflags::bitflags! {
#[derive(Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ControlFlags: u8 {
/// The result (of an expression) is not dynamically uniform.
///
@ -57,18 +59,24 @@ bitflags::bitflags! {
}
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplingKey {
pub image: Handle<crate::GlobalVariable>,
pub sampler: Handle<crate::GlobalVariable>,
}
#[derive(Clone, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ExpressionInfo {
pub control_flags: ControlFlags,
pub ref_count: usize,
assignable_global: Option<Handle<crate::GlobalVariable>>,
}
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {
/// Accumulated control flags of this function.
pub control_flags: ControlFlags,
@ -459,6 +467,8 @@ impl FunctionInfo {
}
#[derive(Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct Analysis {
functions: Vec<FunctionInfo>,
entry_points: crate::FastHashMap<(crate::ShaderStage, String), FunctionInfo>,

View File

@ -0,0 +1,245 @@
---
source: tests/snapshots.rs
expression: output
---
(
functions: [
(
control_flags: (
bits: 5,
),
sampling_set: [],
global_uses: [
(
bits: 0,
),
(
bits: 0,
),
],
expressions: [
(
control_flags: (
bits: 1,
),
ref_count: 0,
assignable_global: Some(1),
),
(
control_flags: (
bits: 1,
),
ref_count: 0,
assignable_global: Some(2),
),
(
control_flags: (
bits: 1,
),
ref_count: 1,
assignable_global: None,
),
(
control_flags: (
bits: 1,
),
ref_count: 7,
assignable_global: None,
),
(
control_flags: (
bits: 0,
),
ref_count: 0,
assignable_global: None,
),
(
control_flags: (
bits: 1,
),
ref_count: 3,
assignable_global: None,
),
(
control_flags: (
bits: 0,
),
ref_count: 1,
assignable_global: None,
),
(
control_flags: (
bits: 1,
),
ref_count: 1,
assignable_global: None,
),
(
control_flags: (
bits: 0,
),
ref_count: 1,
assignable_global: None,
),
(
control_flags: (
bits: 1,
),
ref_count: 1,
assignable_global: None,
),
(
control_flags: (
bits: 0,
),
ref_count: 1,
assignable_global: None,
),
(
control_flags: (
bits: 1,
),
ref_count: 1,
assignable_global: None,
),
(
control_flags: (
bits: 0,
),
ref_count: 1,
assignable_global: None,
),
(
control_flags: (
bits: 1,
),
ref_count: 1,
assignable_global: None,
),
(
control_flags: (
bits: 0,
),
ref_count: 1,
assignable_global: None,
),
(
control_flags: (
bits: 1,
),
ref_count: 1,
assignable_global: None,
),
(
control_flags: (
bits: 0,
),
ref_count: 1,
assignable_global: None,
),
(
control_flags: (
bits: 1,
),
ref_count: 1,
assignable_global: None,
),
(
control_flags: (
bits: 0,
),
ref_count: 1,
assignable_global: None,
),
(
control_flags: (
bits: 1,
),
ref_count: 1,
assignable_global: None,
),
],
),
],
entry_points: {
(Compute, "main"): (
control_flags: (
bits: 5,
),
sampling_set: [],
global_uses: [
(
bits: 1,
),
(
bits: 3,
),
],
expressions: [
(
control_flags: (
bits: 1,
),
ref_count: 2,
assignable_global: Some(1),
),
(
control_flags: (
bits: 1,
),
ref_count: 2,
assignable_global: Some(2),
),
(
control_flags: (
bits: 1,
),
ref_count: 1,
assignable_global: Some(2),
),
(
control_flags: (
bits: 1,
),
ref_count: 1,
assignable_global: Some(1),
),
(
control_flags: (
bits: 1,
),
ref_count: 1,
assignable_global: Some(2),
),
(
control_flags: (
bits: 1,
),
ref_count: 1,
assignable_global: Some(2),
),
(
control_flags: (
bits: 1,
),
ref_count: 1,
assignable_global: Some(1),
),
(
control_flags: (
bits: 1,
),
ref_count: 1,
assignable_global: Some(2),
),
(
control_flags: (
bits: 5,
),
ref_count: 1,
assignable_global: None,
),
],
),
},
)

274
tests/out/collatz.ron.snap Normal file
View File

@ -0,0 +1,274 @@
---
source: tests/snapshots.rs
expression: output
---
(
types: [
(
name: None,
inner: Vector(
size: Tri,
kind: Uint,
width: 4,
),
),
(
name: None,
inner: Scalar(
kind: Uint,
width: 4,
),
),
(
name: None,
inner: Array(
base: 2,
size: Dynamic,
stride: Some(4),
),
),
(
name: Some("PrimeIndices"),
inner: Struct(
block: true,
members: [
(
name: Some("data"),
span: None,
ty: 3,
),
],
),
),
],
constants: [
(
name: None,
specialization: None,
inner: Scalar(
width: 4,
value: Uint(0),
),
),
(
name: None,
specialization: None,
inner: Scalar(
width: 4,
value: Uint(1),
),
),
(
name: None,
specialization: None,
inner: Scalar(
width: 4,
value: Uint(2),
),
),
(
name: None,
specialization: None,
inner: Scalar(
width: 4,
value: Uint(3),
),
),
],
global_variables: [
(
name: Some("global_id"),
class: Input,
binding: Some(BuiltIn(GlobalInvocationId)),
ty: 1,
init: None,
interpolation: None,
storage_access: (
bits: 0,
),
),
(
name: Some("v_indices"),
class: Storage,
binding: Some(Resource(
group: 0,
binding: 0,
)),
ty: 4,
init: None,
interpolation: None,
storage_access: (
bits: 3,
),
),
],
functions: [
(
name: Some("collatz_iterations"),
arguments: [
(
name: Some("n_base"),
ty: 2,
),
],
return_type: Some(2),
local_variables: [
(
name: Some("n"),
ty: 2,
init: None,
),
(
name: Some("i"),
ty: 2,
init: Some(1),
),
],
expressions: [
GlobalVariable(1),
GlobalVariable(2),
FunctionArgument(0),
LocalVariable(1),
Constant(1),
LocalVariable(2),
Constant(2),
Binary(
op: LessEqual,
left: 4,
right: 7,
),
Constant(3),
Binary(
op: Modulo,
left: 4,
right: 9,
),
Constant(1),
Binary(
op: Equal,
left: 10,
right: 11,
),
Constant(3),
Binary(
op: Divide,
left: 4,
right: 13,
),
Constant(4),
Binary(
op: Multiply,
left: 15,
right: 4,
),
Constant(2),
Binary(
op: Add,
left: 16,
right: 17,
),
Constant(2),
Binary(
op: Add,
left: 6,
right: 19,
),
],
body: [
Store(
pointer: 4,
value: 3,
),
Loop(
body: [
If(
condition: 8,
accept: [
Break,
],
reject: [],
),
If(
condition: 12,
accept: [
Store(
pointer: 4,
value: 14,
),
],
reject: [
Store(
pointer: 4,
value: 18,
),
],
),
Store(
pointer: 6,
value: 20,
),
],
continuing: [],
),
Return(
value: Some(6),
),
],
),
],
entry_points: {
(Compute, "main"): (
early_depth_test: None,
workgroup_size: (1, 1, 1),
function: (
name: Some("main"),
arguments: [],
return_type: None,
local_variables: [],
expressions: [
GlobalVariable(1),
GlobalVariable(2),
AccessIndex(
base: 2,
index: 0,
),
AccessIndex(
base: 1,
index: 0,
),
Access(
base: 3,
index: 4,
),
AccessIndex(
base: 2,
index: 0,
),
AccessIndex(
base: 1,
index: 0,
),
Access(
base: 6,
index: 7,
),
Call(
function: 1,
arguments: [
8,
],
),
],
body: [
Store(
pointer: 5,
value: 9,
),
Return(
value: None,
),
],
),
),
},
)

File diff suppressed because it is too large Load Diff

View File

@ -1,8 +1,10 @@
bitflags::bitflags! {
struct Language: u32 {
const SPIRV = 0x1;
const METAL = 0x2;
const GLSL = 0x4;
struct Targets: u32 {
const IR = 0x1;
const ANALYSIS = 0x2;
const SPIRV = 0x4;
const METAL = 0x8;
const GLSL = 0x10;
}
}
@ -57,6 +59,54 @@ fn with_snapshot_settings<F: FnOnce() -> ()>(snapshot_assertion: F) {
settings.bind(|| snapshot_assertion());
}
#[allow(unused_variables)]
fn check_targets(module: &naga::Module, name: &str, targets: Targets) {
let params = match std::fs::read_to_string(format!("tests/in/{}{}", name, ".param.ron")) {
Ok(string) => ron::de::from_str(&string).expect("Couldn't find param file"),
Err(_) => Parameters::default(),
};
let analysis = naga::proc::Validator::new().validate(module).unwrap();
#[cfg(feature = "serialize")]
{
if targets.contains(Targets::IR) {
let config = ron::ser::PrettyConfig::default().with_new_line("\n".to_string());
let output = ron::ser::to_string_pretty(module, config).unwrap();
with_snapshot_settings(|| {
insta::assert_snapshot!(format!("{}.ron", name), output);
});
}
if targets.contains(Targets::ANALYSIS) {
let config = ron::ser::PrettyConfig::default().with_new_line("\n".to_string());
let output = ron::ser::to_string_pretty(&analysis, config).unwrap();
with_snapshot_settings(|| {
insta::assert_snapshot!(format!("{}.info.ron", name), output);
});
}
}
#[cfg(feature = "spv-out")]
{
if targets.contains(Targets::SPIRV) {
check_output_spv(module, &analysis, name, &params);
}
}
#[cfg(feature = "msl-out")]
{
if targets.contains(Targets::METAL) {
check_output_msl(module, &analysis, name, &params);
}
}
#[cfg(feature = "glsl-out")]
{
if targets.contains(Targets::GLSL) {
for &(stage, ref ep_name) in module.entry_points.keys() {
check_output_glsl(module, &analysis, name, stage, ep_name);
}
}
}
}
#[cfg(feature = "spv-out")]
fn check_output_spv(
module: &naga::Module,
@ -152,115 +202,83 @@ fn check_output_glsl(
}
#[cfg(feature = "wgsl-in")]
fn convert_wgsl(name: &str, language: Language) {
let params = match std::fs::read_to_string(format!("tests/in/{}{}", name, ".param.ron")) {
Ok(string) => ron::de::from_str(&string).expect("Couldn't find param file"),
Err(_) => Parameters::default(),
};
fn convert_wgsl(name: &str, targets: Targets) {
let module = naga::front::wgsl::parse_str(
&std::fs::read_to_string(format!("tests/in/{}{}", name, ".wgsl"))
.expect("Couldn't find wgsl file"),
)
.unwrap();
let analysis = naga::proc::Validator::new().validate(&module).unwrap();
#[cfg(feature = "spv-out")]
{
if language.contains(Language::SPIRV) {
check_output_spv(&module, &analysis, name, &params);
}
}
#[cfg(feature = "msl-out")]
{
if language.contains(Language::METAL) {
check_output_msl(&module, &analysis, name, &params);
}
}
#[cfg(feature = "glsl-out")]
{
if language.contains(Language::GLSL) {
for &(stage, ref ep_name) in module.entry_points.keys() {
check_output_glsl(&module, &analysis, name, stage, ep_name);
}
}
}
check_targets(&module, name, targets);
}
#[cfg(feature = "wgsl-in")]
#[test]
fn convert_wgsl_quad() {
convert_wgsl("quad", Language::all());
convert_wgsl("quad", Targets::SPIRV | Targets::METAL | Targets::GLSL);
}
#[cfg(feature = "wgsl-in")]
#[test]
fn convert_wgsl_empty() {
convert_wgsl("empty", Language::all());
convert_wgsl("empty", Targets::SPIRV | Targets::METAL | Targets::GLSL);
}
#[cfg(feature = "wgsl-in")]
#[test]
fn convert_wgsl_boids() {
convert_wgsl("boids", Language::METAL | Language::SPIRV);
convert_wgsl("boids", Targets::SPIRV | Targets::METAL);
}
#[cfg(feature = "wgsl-in")]
#[test]
fn convert_wgsl_skybox() {
convert_wgsl("skybox", Language::all());
convert_wgsl("skybox", Targets::SPIRV | Targets::METAL | Targets::GLSL);
}
#[cfg(feature = "wgsl-in")]
#[test]
fn convert_wgsl_collatz() {
convert_wgsl("collatz", Language::METAL | Language::SPIRV);
convert_wgsl(
"collatz",
Targets::SPIRV | Targets::METAL | Targets::IR | Targets::ANALYSIS,
);
}
#[cfg(feature = "wgsl-in")]
#[test]
fn convert_wgsl_shadow() {
convert_wgsl("shadow", Language::METAL | Language::SPIRV);
convert_wgsl("shadow", Targets::SPIRV | Targets::METAL);
}
#[cfg(feature = "wgsl-in")]
#[test]
fn convert_wgsl_texture_array() {
convert_wgsl("texture-array", Language::SPIRV);
convert_wgsl("texture-array", Targets::SPIRV);
}
#[cfg(feature = "spv-in")]
fn convert_spv(name: &str) {
fn convert_spv(name: &str, targets: Targets) {
let module = naga::front::spv::parse_u8_slice(
&std::fs::read(format!("tests/in/{}{}", name, ".spv")).expect("Couldn't find spv file"),
&Default::default(),
)
.unwrap();
check_targets(&module, name, targets);
naga::proc::Validator::new().validate(&module).unwrap();
#[cfg(feature = "serialize")]
{
let config = ron::ser::PrettyConfig::default().with_new_line("\n".to_string());
let output = ron::ser::to_string_pretty(&module, config).unwrap();
with_snapshot_settings(|| {
insta::assert_snapshot!(format!("{}.ron", name), output);
});
}
}
#[cfg(feature = "spv-in")]
#[test]
fn convert_spv_shadow() {
convert_spv("shadow");
convert_spv("shadow", Targets::IR | Targets::ANALYSIS);
}
#[cfg(feature = "glsl-in")]
fn convert_glsl(name: &str, entry_points: naga::FastHashMap<String, naga::ShaderStage>) {
let params = match std::fs::read_to_string(format!("tests/in/{}{}", name, ".param.ron")) {
Ok(string) => ron::de::from_str(&string).expect("Couldn't find param file"),
Err(_) => Parameters::default(),
};
fn convert_glsl(
name: &str,
entry_points: naga::FastHashMap<String, naga::ShaderStage>,
targets: Targets,
) {
let module = naga::front::glsl::parse_str(
&std::fs::read_to_string(format!("tests/in/{}{}", name, ".glsl"))
.expect("Couldn't find glsl file"),
@ -270,20 +288,7 @@ fn convert_glsl(name: &str, entry_points: naga::FastHashMap<String, naga::Shader
},
)
.unwrap();
let analysis = naga::proc::Validator::new().validate(&module).unwrap();
#[cfg(feature = "spv-out")]
{
check_output_spv(&module, &analysis, name, &params);
}
#[cfg(feature = "serialize")]
{
let config = ron::ser::PrettyConfig::default().with_new_line("\n".to_string());
let output = ron::ser::to_string_pretty(&module, config).unwrap();
with_snapshot_settings(|| {
insta::assert_snapshot!(format!("{}.ron", name), output);
});
}
check_targets(&module, name, targets);
}
#[cfg(feature = "glsl-in")]
@ -292,5 +297,5 @@ fn convert_glsl_quad() {
let mut entry_points = naga::FastHashMap::default();
entry_points.insert("vert_main".to_string(), naga::ShaderStage::Vertex);
entry_points.insert("frag_main".to_string(), naga::ShaderStage::Fragment);
convert_glsl("quad-glsl", entry_points);
convert_glsl("quad-glsl", entry_points, Targets::SPIRV | Targets::IR);
}