Implement a Render Graph #16
|
@ -1,89 +0,0 @@
|
||||||
use std::collections::{HashMap, VecDeque};
|
|
||||||
|
|
||||||
use rustc_hash::{FxHashMap, FxHashSet};
|
|
||||||
|
|
||||||
use super::RenderGraphPassDesc;
|
|
||||||
|
|
||||||
pub struct GraphExecutionPath {
|
|
||||||
/// Queue of the path, top is the first to be executed.
|
|
||||||
/// Each element is the handle of a pass.
|
|
||||||
pub queue: VecDeque<u64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GraphExecutionPath {
|
|
||||||
pub fn new(pass_descriptions: Vec<&RenderGraphPassDesc>) -> Self {
|
|
||||||
// collect all the output slots
|
|
||||||
let mut total_outputs = HashMap::new();
|
|
||||||
total_outputs.reserve(pass_descriptions.len());
|
|
||||||
|
|
||||||
for desc in pass_descriptions.iter() {
|
|
||||||
for slot in desc.output_slots() {
|
|
||||||
total_outputs.insert(slot.name.clone(), SlotOwnerPair {
|
|
||||||
pass: desc.id,
|
|
||||||
slot: slot.id,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut nodes = FxHashMap::<u64, Node>::default();
|
|
||||||
for desc in pass_descriptions.iter() {
|
|
||||||
// find the node inputs
|
|
||||||
let mut inputs = vec![];
|
|
||||||
for slot in desc.input_slots() {
|
|
||||||
let inp = total_outputs.get(&slot.name)
|
|
||||||
.expect(&format!("failed to find slot: '{}', ensure that there is a pass outputting it", slot.name));
|
|
||||||
inputs.push(*inp);
|
|
||||||
}
|
|
||||||
|
|
||||||
let node = Node {
|
|
||||||
id: desc.id,
|
|
||||||
desc: (*desc),
|
|
||||||
slot_inputs: inputs
|
|
||||||
};
|
|
||||||
nodes.insert(node.id, node);
|
|
||||||
}
|
|
||||||
|
|
||||||
// sort the graph
|
|
||||||
let mut stack = VecDeque::new();
|
|
||||||
let mut visited = FxHashSet::default();
|
|
||||||
for (_, no) in nodes.iter() {
|
|
||||||
Self::topological_sort(&nodes, &mut stack, &mut visited, no);
|
|
||||||
}
|
|
||||||
|
|
||||||
Self {
|
|
||||||
queue: stack,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn topological_sort(graph: &FxHashMap<u64, Node>, stack: &mut VecDeque<u64>, visited: &mut FxHashSet<u64>, node: &Node) {
|
|
||||||
if !visited.contains(&node.id) {
|
|
||||||
visited.insert(node.id);
|
|
||||||
|
|
||||||
for depend in &node.slot_inputs {
|
|
||||||
let depend_node = graph.get(&depend.pass)
|
|
||||||
.expect("could not find dependent node");
|
|
||||||
|
|
||||||
if !visited.contains(&depend.pass) {
|
|
||||||
Self::topological_sort(graph, stack, visited, depend_node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
stack.push_back(node.id);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
struct SlotOwnerPair {
|
|
||||||
pass: u64,
|
|
||||||
slot: u64,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
struct Node<'a> {
|
|
||||||
id: u64,
|
|
||||||
desc: &'a RenderGraphPassDesc,
|
|
||||||
slot_inputs: Vec<SlotOwnerPair>,
|
|
||||||
}
|
|
|
@ -16,17 +16,12 @@ pub use passes::*;
|
||||||
mod slot_desc;
|
mod slot_desc;
|
||||||
pub use slot_desc::*;
|
pub use slot_desc::*;
|
||||||
|
|
||||||
mod execution_path;
|
|
||||||
|
|
||||||
use rustc_hash::FxHashMap;
|
use rustc_hash::FxHashMap;
|
||||||
use tracing::{debug_span, instrument, trace, warn};
|
use tracing::{debug_span, instrument, trace, warn};
|
||||||
use wgpu::ComputePass;
|
use wgpu::ComputePass;
|
||||||
|
|
||||||
use self::execution_path::GraphExecutionPath;
|
|
||||||
|
|
||||||
use super::resource::{ComputePipeline, Pipeline, RenderPipeline};
|
use super::resource::{ComputePipeline, Pipeline, RenderPipeline};
|
||||||
|
|
||||||
//#[derive(Clone)]
|
|
||||||
struct PassEntry {
|
struct PassEntry {
|
||||||
inner: Arc<RefCell<dyn RenderGraphPass>>,
|
inner: Arc<RefCell<dyn RenderGraphPass>>,
|
||||||
desc: Arc<RenderGraphPassDesc>,
|
desc: Arc<RenderGraphPassDesc>,
|
||||||
|
@ -45,7 +40,6 @@ pub struct BindGroupEntry {
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
struct ResourcedSlot {
|
struct ResourcedSlot {
|
||||||
name: String,
|
name: String,
|
||||||
//slot: RenderPassSlot,
|
|
||||||
ty: SlotType,
|
ty: SlotType,
|
||||||
value: SlotValue,
|
value: SlotValue,
|
||||||
}
|
}
|
||||||
|
@ -74,13 +68,13 @@ pub struct RenderGraph {
|
||||||
passes: FxHashMap<u64, PassEntry>,
|
passes: FxHashMap<u64, PassEntry>,
|
||||||
// TODO: Use a SlotMap
|
// TODO: Use a SlotMap
|
||||||
bind_groups: FxHashMap<u64, BindGroupEntry>,
|
bind_groups: FxHashMap<u64, BindGroupEntry>,
|
||||||
bind_group_names: FxHashMap<String, u64>,
|
bind_group_names: HashMap<String, u64>,
|
||||||
// TODO: make pipelines a `type` parameter in RenderPasses,
|
// TODO: make pipelines a `type` parameter in RenderPasses,
|
||||||
// then the pipelines can be retrieved via TypeId to the pass.
|
// then the pipelines can be retrieved via TypeId to the pass.
|
||||||
pipelines: FxHashMap<u64, PipelineResource>,
|
pipelines: FxHashMap<u64, PipelineResource>,
|
||||||
current_id: u64,
|
current_id: u64,
|
||||||
exec_path: Option<GraphExecutionPath>,
|
/// A directed graph describing the execution path of the RenderGraph
|
||||||
new_path: petgraph::matrix_graph::DiMatrix<u64, (), Option<()>, usize>,
|
execution_graph: petgraph::matrix_graph::DiMatrix<u64, (), Option<()>, usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RenderGraph {
|
impl RenderGraph {
|
||||||
|
@ -95,8 +89,7 @@ impl RenderGraph {
|
||||||
bind_group_names: Default::default(),
|
bind_group_names: Default::default(),
|
||||||
pipelines: Default::default(),
|
pipelines: Default::default(),
|
||||||
current_id: 1,
|
current_id: 1,
|
||||||
exec_path: None,
|
execution_graph: Default::default(),
|
||||||
new_path: Default::default(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -132,7 +125,8 @@ impl RenderGraph {
|
||||||
|
|
||||||
trace!(
|
trace!(
|
||||||
"Found existing slot for {}, changing id to {}",
|
"Found existing slot for {}, changing id to {}",
|
||||||
slot.name, id
|
slot.name,
|
||||||
|
id
|
||||||
);
|
);
|
||||||
|
|
||||||
// if there is a slot of the same name
|
// if there is a slot of the same name
|
||||||
|
@ -163,7 +157,7 @@ impl RenderGraph {
|
||||||
self.bind_group_names.insert(name.clone(), bg_id);
|
self.bind_group_names.insert(name.clone(), bg_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
let index = self.new_path.add_node(desc.id);
|
let index = self.execution_graph.add_node(desc.id);
|
||||||
|
|
||||||
self.passes.insert(
|
self.passes.insert(
|
||||||
desc.id,
|
desc.id,
|
||||||
|
@ -182,14 +176,18 @@ impl RenderGraph {
|
||||||
for pass in self.passes.values() {
|
for pass in self.passes.values() {
|
||||||
if let Some(pipeline_desc) = &pass.desc.pipeline_desc {
|
if let Some(pipeline_desc) = &pass.desc.pipeline_desc {
|
||||||
let pipeline = match pass.desc.pass_type {
|
let pipeline = match pass.desc.pass_type {
|
||||||
RenderPassType::Render => {
|
RenderPassType::Render => Pipeline::Render(RenderPipeline::create(
|
||||||
Pipeline::Render(RenderPipeline::create(device, pipeline_desc.as_render_pipeline_descriptor()
|
device,
|
||||||
.expect("got compute pipeline descriptor in a render pass")))
|
pipeline_desc
|
||||||
},
|
.as_render_pipeline_descriptor()
|
||||||
RenderPassType::Compute => {
|
.expect("got compute pipeline descriptor in a render pass"),
|
||||||
Pipeline::Compute(ComputePipeline::create(device, pipeline_desc.as_compute_pipeline_descriptor()
|
)),
|
||||||
.expect("got render pipeline descriptor in a compute pass")))
|
RenderPassType::Compute => Pipeline::Compute(ComputePipeline::create(
|
||||||
},
|
device,
|
||||||
|
pipeline_desc
|
||||||
|
.as_compute_pipeline_descriptor()
|
||||||
|
.expect("got render pipeline descriptor in a compute pass"),
|
||||||
|
)),
|
||||||
RenderPassType::Presenter | RenderPassType::Node => {
|
RenderPassType::Presenter | RenderPassType::Node => {
|
||||||
panic!("Present or Node RenderGraph passes should not have a pipeline descriptor!");
|
panic!("Present or Node RenderGraph passes should not have a pipeline descriptor!");
|
||||||
}
|
}
|
||||||
|
@ -234,25 +232,19 @@ impl RenderGraph {
|
||||||
self.queue.write_buffer(buf, bufwr.offset, &bufwr.bytes);
|
self.queue.write_buffer(buf, bufwr.offset, &bufwr.bytes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// create the execution path for the graph. This will be executed in `RenderGraph::render`
|
|
||||||
let descs = self.passes.values().map(|p| &*p.desc).collect();
|
|
||||||
let path = GraphExecutionPath::new(descs);
|
|
||||||
trace!(
|
|
||||||
"Found {} steps in the rendergraph to execute",
|
|
||||||
path.queue.len()
|
|
||||||
);
|
|
||||||
|
|
||||||
self.exec_path = Some(path);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub fn render(&mut self) {
|
pub fn render(&mut self) {
|
||||||
let mut sorted: VecDeque<u64> = petgraph::algo::toposort(&self.new_path, None)
|
let mut sorted: VecDeque<u64> = petgraph::algo::toposort(&self.execution_graph, None)
|
||||||
.expect("RenderGraph had cycled!")
|
.expect("RenderGraph had cycled!")
|
||||||
.iter().map(|i| self.new_path[i.clone()])
|
.iter()
|
||||||
|
.map(|i| self.execution_graph[i.clone()])
|
||||||
.collect();
|
.collect();
|
||||||
let path_names = sorted.iter().map(|i| self.pass(*i).unwrap().name.clone()).collect_vec();
|
let path_names = sorted
|
||||||
|
.iter()
|
||||||
|
.map(|i| self.pass(*i).unwrap().name.clone())
|
||||||
|
.collect_vec();
|
||||||
trace!("Render graph execution order: {:?}", path_names);
|
trace!("Render graph execution order: {:?}", path_names);
|
||||||
|
|
||||||
let mut encoders = Vec::with_capacity(self.passes.len() / 2);
|
let mut encoders = Vec::with_capacity(self.passes.len() / 2);
|
||||||
|
@ -293,9 +285,12 @@ impl RenderGraph {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !encoders.is_empty() {
|
if !encoders.is_empty() {
|
||||||
warn!("{} encoders were not submitted in the same render cycle they were created. \
|
warn!(
|
||||||
|
"{} encoders were not submitted in the same render cycle they were created. \
|
||||||
Make sure there is a presenting pass at the end. You may still see something, \
|
Make sure there is a presenting pass at the end. You may still see something, \
|
||||||
however it will be delayed a render cycle.", encoders.len());
|
however it will be delayed a render cycle.",
|
||||||
|
encoders.len()
|
||||||
|
);
|
||||||
self.queue.submit(encoders.into_iter());
|
self.queue.submit(encoders.into_iter());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -344,13 +339,59 @@ impl RenderGraph {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_edge(&mut self, from: &str, to: &str) {
|
pub fn add_edge(&mut self, from: &str, to: &str) {
|
||||||
let from_idx = self.passes.iter().find(|p| p.1.desc.name == from).map(|p| p.1.graph_index)
|
let from_idx = self
|
||||||
|
.passes
|
||||||
|
.iter()
|
||||||
|
.find(|p| p.1.desc.name == from)
|
||||||
|
.map(|p| p.1.graph_index)
|
||||||
.expect("Failed to find from pass");
|
.expect("Failed to find from pass");
|
||||||
let to_idx = self.passes.iter().find(|p| p.1.desc.name == to).map(|p| p.1.graph_index)
|
let to_idx = self
|
||||||
|
.passes
|
||||||
|
.iter()
|
||||||
|
.find(|p| p.1.desc.name == to)
|
||||||
|
.map(|p| p.1.graph_index)
|
||||||
.expect("Failed to find to pass");
|
.expect("Failed to find to pass");
|
||||||
|
|
||||||
self.new_path.add_edge(from_idx, to_idx, ());
|
self.execution_graph.add_edge(from_idx, to_idx, ());
|
||||||
//self.new_path.add_edge(NodeIndex::new(from_id as usize), NodeIndex::new(to_id as usize), ());
|
}
|
||||||
|
|
||||||
|
/// Utility method for setting the bind groups for a pass.
|
||||||
|
///
|
||||||
|
/// The parameter `bind_groups` can be used to specify the labels of a bind group, and the
|
||||||
|
/// index of the bind group in the pipeline for the pass. If a bind group of the provided
|
||||||
|
/// name is not found in the graph, a panic will occur.
|
||||||
|
///
|
||||||
|
/// # Example:
|
||||||
|
/// ```rust,nobuild
|
||||||
|
/// graph.set_bind_groups(
|
||||||
|
/// &mut pass,
|
||||||
|
/// &[
|
||||||
|
/// // retrieves the "depth_texture" bind group and sets the index 0 in the
|
||||||
|
/// // pass to it.
|
||||||
|
/// ("depth_texture", 0),
|
||||||
|
/// ("camera", 1),
|
||||||
|
/// ("light_buffers", 2),
|
||||||
|
/// ("light_indices_grid", 3),
|
||||||
|
/// ("screen_size", 4),
|
||||||
|
/// ],
|
||||||
|
/// );
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
/// Panics if a bind group of a provided name is not found.
|
||||||
|
pub fn set_bind_groups<'a>(
|
||||||
|
&'a self,
|
||||||
|
pass: &mut ComputePass<'a>,
|
||||||
|
bind_groups: &[(&str, u32)],
|
||||||
|
) {
|
||||||
|
for (name, index) in bind_groups {
|
||||||
|
let bg = self
|
||||||
|
.bind_group_id(name)
|
||||||
|
.map(|bgi| self.bind_group(bgi))
|
||||||
|
.expect(&format!("Could not find bind group '{}'", name));
|
||||||
|
|
||||||
|
pass.set_bind_group(*index, bg, &[]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -428,21 +469,4 @@ impl<'a> RenderGraphContext<'a> {
|
||||||
) {
|
) {
|
||||||
self.queue_buffer_write(target_slot, offset, bytemuck::bytes_of(&bytes));
|
self.queue_buffer_write(target_slot, offset, bytemuck::bytes_of(&bytes));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_bind_groups<'b>(&self, graph: &'b RenderGraph, bind_group_names: &[&str]) -> Vec<Option<&'b Rc<wgpu::BindGroup>>> {
|
|
||||||
bind_group_names
|
|
||||||
.iter()
|
|
||||||
.map(|name| graph.bind_group_id(name))
|
|
||||||
.map(|bgi| bgi.map(|bgi| graph.bind_group(bgi)))
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn set_bind_groups<'b, 'c>(graph: &'b RenderGraph, pass: &'c mut ComputePass<'b>, bind_groups: &[(&str, u32)]) {
|
|
||||||
for (name, index) in bind_groups {
|
|
||||||
let bg = graph.bind_group_id(name).map(|bgi| graph.bind_group(bgi))
|
|
||||||
.expect(&format!("Could not find bind group '{}'", name));
|
|
||||||
|
|
||||||
pass.set_bind_group(*index, bg, &[]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -236,8 +236,7 @@ impl RenderGraphPass for LightCullComputePass {
|
||||||
pass.set_bind_group(3, grid_bg, &[]);
|
pass.set_bind_group(3, grid_bg, &[]);
|
||||||
pass.set_bind_group(4, screen_size_bg, &[]); */
|
pass.set_bind_group(4, screen_size_bg, &[]); */
|
||||||
|
|
||||||
RenderGraphContext::set_bind_groups(
|
graph.set_bind_groups(
|
||||||
graph,
|
|
||||||
&mut pass,
|
&mut pass,
|
||||||
&[
|
&[
|
||||||
("depth_texture", 0),
|
("depth_texture", 0),
|
||||||
|
|
Loading…
Reference in New Issue