Add push constants example (#1072)

This commit is contained in:
Lucas Kent 2018-10-11 11:42:56 +11:00 committed by GitHub
parent 37de51eeef
commit 4633c72bab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 112 additions and 1 deletions

View File

@ -77,7 +77,7 @@ fn main() {
// We need to create the compute pipeline that describes our operation.
//
// If you are familiar with graphics pipeline, the principle is the same except that compute
// pipelines are much more simple to create.
// pipelines are much simpler to create.
let pipeline = Arc::new({
// TODO: explain
#[allow(dead_code)]

View File

@ -0,0 +1,111 @@
// Copyright (c) 2017 The vulkano developers
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT
// license <LICENSE-MIT or http://opensource.org/licenses/MIT>,
// at your option. All files in the project carrying such
// notice may not be copied, modified, or distributed except
// according to those terms.
// TODO: Give a paragraph about what push constants are and what problems they solve
extern crate vulkano;
#[macro_use]
extern crate vulkano_shader_derive;
use vulkano::buffer::BufferUsage;
use vulkano::buffer::CpuAccessibleBuffer;
use vulkano::command_buffer::AutoCommandBufferBuilder;
use vulkano::descriptor::descriptor_set::PersistentDescriptorSet;
use vulkano::device::Device;
use vulkano::device::DeviceExtensions;
use vulkano::instance::Instance;
use vulkano::instance::InstanceExtensions;
use vulkano::pipeline::ComputePipeline;
use vulkano::sync::now;
use vulkano::sync::GpuFuture;
use std::sync::Arc;
fn main() {
let instance = Instance::new(None, &InstanceExtensions::none(), None).unwrap();
let physical = vulkano::instance::PhysicalDevice::enumerate(&instance).next().unwrap();
let queue_family = physical.queue_families().find(|&q| q.supports_compute()).unwrap();
let (device, mut queues) = {
Device::new(physical, physical.supported_features(), &DeviceExtensions::none(),
[(queue_family, 0.5)].iter().cloned()).expect("failed to create device")
};
let queue = queues.next().unwrap();
mod cs {
#[derive(VulkanoShader)]
#[ty = "compute"]
#[src = "
#version 450
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout(push_constant) uniform PushConstantData {
int multiple;
float addend;
bool enable;
} pc;
layout(set = 0, binding = 0) buffer Data {
uint data[];
} data;
void main() {
uint idx = gl_GlobalInvocationID.x;
if (pc.enable) {
data.data[idx] *= pc.multiple;
data.data[idx] += uint(pc.addend);
}
}"
]
#[allow(dead_code)]
struct Dummy;
}
let shader = cs::Shader::load(device.clone())
.expect("failed to create shader module");
let pipeline = Arc::new(ComputePipeline::new(device.clone(), &shader.main_entry_point(), &()).unwrap());
let data_buffer = {
let data_iter = (0 .. 65536u32).map(|n| n);
CpuAccessibleBuffer::from_iter(device.clone(), BufferUsage::all(),
data_iter).expect("failed to create buffer")
};
let set = Arc::new(PersistentDescriptorSet::start(pipeline.clone(), 0)
.add_buffer(data_buffer.clone()).unwrap()
.build().unwrap()
);
// The `vulkano_shaders!` macro generates a struct with the correct representation of the push constants struct specified in the shader.
// Here we create an instance of the generated struct.
let push_constants = cs::ty::PushConstantData {
multiple: 1,
addend: 1.0,
enable: 1,
};
// For a compute pipeline, push constants are passed to the `dispatch` method.
// For a graphics pipeline, push constants are passed to the `draw` and `draw_indexed` methods.
// Note that there is no type safety for the push constants argument.
// So be careful to only pass an instance of the struct generated by the `vulkano_shaders!` macro.
let command_buffer = AutoCommandBufferBuilder::primary_one_time_submit(device.clone(), queue.family()).unwrap()
.dispatch([1024, 1, 1], pipeline.clone(), set.clone(), push_constants).unwrap()
.build().unwrap();
let future = now(device.clone())
.then_execute(queue.clone(), command_buffer).unwrap()
.then_signal_fence_and_flush().unwrap();
future.wait(None).unwrap();
let data_buffer_content = data_buffer.read().expect("failed to lock buffer for reading");
for n in 0 .. 65536u32 {
assert_eq!(data_buffer_content[n as usize], n * 1 + 1);
}
println!("Success");
}