Add a collatz snapshot test

This commit is contained in:
Dzmitry Malyshau 2021-01-21 23:56:18 -05:00 committed by Dzmitry Malyshau
parent 04b0f2443e
commit c758399354
6 changed files with 92 additions and 2 deletions

View File

@ -176,3 +176,9 @@ fn converts_wgsl_boids() {
fn converts_wgsl_skybox() {
convert_wgsl("skybox", Language::all());
}
#[cfg(feature = "wgsl-in")]
#[test]
fn converts_wgsl_collatz() {
convert_wgsl("collatz", Language::METAL);
}

View File

@ -84,7 +84,7 @@ fn main() {
var pos : vec2<f32>;
var vel : vec2<f32>;
var i : u32 = 0;
var i : u32 = u32(0);
loop {
if (i >= u32(5)) {
break;

View File

@ -0,0 +1,7 @@
(
spv_flow_dump_prefix: "",
spv_capabilities: [ Shader ],
mtl_bindings: {
(stage: Compute, group: 0, binding: 0): (buffer: Some(0), mutable: true),
}
)

View File

@ -0,0 +1,37 @@
[[builtin(global_invocation_id)]]
var global_id: vec3<u32>;
struct PrimeIndices {
data: array<u32>;
}; // this is used as both input and output for convenience
[[group(0), binding(0)]]
var<storage> v_indices: [[access(read_write)]] PrimeIndices;
// The Collatz Conjecture states that for any integer n:
// If n is even, n = n/2
// If n is odd, n = 3n+1
// And repeat this process for each new n, you will always eventually reach 1.
// Though the conjecture has not been proven, no counterexample has ever been found.
// This function returns how many times this recurrence needs to be applied to reach 1.
fn collatz_iterations(n: u32) -> u32{
var i: u32 = u32(0);
loop {
if (n <= u32(1)) {
break;
}
if (n % u32(2) == u32(0)) {
n = n / u32(2);
}
else {
n = u32(3) * n + i32(1);
}
i = i + u32(1);
}
return i;
}
[[stage(compute), workgroup_size(1)]]
fn main() {
v_indices.data[global_id.x] = collatz_iterations(v_indices.data[global_id.x]);
}

View File

@ -71,7 +71,7 @@ kernel void main3(
type6 cVelCount = 0;
type pos1;
type vel1;
type5 i = 0;
type5 i;
if (gl_GlobalInvocationID.x >= static_cast<uint>(5)) {
}
vPos = particlesA.particles[gl_GlobalInvocationID.x].pos;
@ -79,6 +79,7 @@ kernel void main3(
cMass = metal::float2(0.0, 0.0);
cVel = metal::float2(0.0, 0.0);
colVel = metal::float2(0.0, 0.0);
i = static_cast<uint>(0);
while(true) {
if (i >= static_cast<uint>(5)) {
break;

View File

@ -0,0 +1,39 @@
---
source: tests/snapshots.rs
expression: msl
---
#include <metal_stdlib>
#include <simd/simd.h>
typedef metal::uint3 type;
typedef uint type1;
typedef type1 type2[1];
struct PrimeIndices {
type2 data;
};
type1 collatz_iterations(
type1 n
) {
type1 i;
i = static_cast<uint>(0);
while(true) {
if (n <= static_cast<uint>(1)) {
break;
}
if (n % static_cast<uint>(2) == static_cast<uint>(0)) {
n = n / static_cast<uint>(2);
} else {
n = static_cast<uint>(3) * n + static_cast<int>(1);
}
i = i + static_cast<uint>(1);
}
return i;
}
kernel void main1(
type global_id [[thread_position_in_grid]],
device PrimeIndices& v_indices [[buffer(0)]]
) {
v_indices.data[global_id.x] = collatz_iterations(v_indices.data[global_id.x]);
}