Implement a Render Graph #16

Merged
SeanOMik merged 20 commits from feature/render-graph into main 2024-06-15 22:54:47 +00:00
3 changed files with 82 additions and 148 deletions
Showing only changes of commit 7f5a1cd953 - Show all commits

View File

@ -1,89 +0,0 @@
use std::collections::{HashMap, VecDeque};
use rustc_hash::{FxHashMap, FxHashSet};
use super::RenderGraphPassDesc;
pub struct GraphExecutionPath {
/// Queue of the path, top is the first to be executed.
/// Each element is the handle of a pass.
pub queue: VecDeque<u64>,
}
impl GraphExecutionPath {
pub fn new(pass_descriptions: Vec<&RenderGraphPassDesc>) -> Self {
// collect all the output slots
let mut total_outputs = HashMap::new();
total_outputs.reserve(pass_descriptions.len());
for desc in pass_descriptions.iter() {
for slot in desc.output_slots() {
total_outputs.insert(slot.name.clone(), SlotOwnerPair {
pass: desc.id,
slot: slot.id,
});
}
}
let mut nodes = FxHashMap::<u64, Node>::default();
for desc in pass_descriptions.iter() {
// find the node inputs
let mut inputs = vec![];
for slot in desc.input_slots() {
let inp = total_outputs.get(&slot.name)
.expect(&format!("failed to find slot: '{}', ensure that there is a pass outputting it", slot.name));
inputs.push(*inp);
}
let node = Node {
id: desc.id,
desc: (*desc),
slot_inputs: inputs
};
nodes.insert(node.id, node);
}
// sort the graph
let mut stack = VecDeque::new();
let mut visited = FxHashSet::default();
for (_, no) in nodes.iter() {
Self::topological_sort(&nodes, &mut stack, &mut visited, no);
}
Self {
queue: stack,
}
}
fn topological_sort(graph: &FxHashMap<u64, Node>, stack: &mut VecDeque<u64>, visited: &mut FxHashSet<u64>, node: &Node) {
if !visited.contains(&node.id) {
visited.insert(node.id);
for depend in &node.slot_inputs {
let depend_node = graph.get(&depend.pass)
.expect("could not find dependent node");
if !visited.contains(&depend.pass) {
Self::topological_sort(graph, stack, visited, depend_node);
}
}
stack.push_back(node.id);
}
}
}
#[allow(dead_code)]
#[derive(Debug, Clone, Copy)]
struct SlotOwnerPair {
pass: u64,
slot: u64,
}
#[allow(dead_code)]
struct Node<'a> {
id: u64,
desc: &'a RenderGraphPassDesc,
slot_inputs: Vec<SlotOwnerPair>,
}

View File

@ -16,17 +16,12 @@ pub use passes::*;
mod slot_desc; mod slot_desc;
pub use slot_desc::*; pub use slot_desc::*;
mod execution_path;
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use tracing::{debug_span, instrument, trace, warn}; use tracing::{debug_span, instrument, trace, warn};
use wgpu::ComputePass; use wgpu::ComputePass;
use self::execution_path::GraphExecutionPath;
use super::resource::{ComputePipeline, Pipeline, RenderPipeline}; use super::resource::{ComputePipeline, Pipeline, RenderPipeline};
//#[derive(Clone)]
struct PassEntry { struct PassEntry {
inner: Arc<RefCell<dyn RenderGraphPass>>, inner: Arc<RefCell<dyn RenderGraphPass>>,
desc: Arc<RenderGraphPassDesc>, desc: Arc<RenderGraphPassDesc>,
@ -45,7 +40,6 @@ pub struct BindGroupEntry {
#[allow(dead_code)] #[allow(dead_code)]
struct ResourcedSlot { struct ResourcedSlot {
name: String, name: String,
//slot: RenderPassSlot,
ty: SlotType, ty: SlotType,
value: SlotValue, value: SlotValue,
} }
@ -74,13 +68,13 @@ pub struct RenderGraph {
passes: FxHashMap<u64, PassEntry>, passes: FxHashMap<u64, PassEntry>,
// TODO: Use a SlotMap // TODO: Use a SlotMap
bind_groups: FxHashMap<u64, BindGroupEntry>, bind_groups: FxHashMap<u64, BindGroupEntry>,
bind_group_names: FxHashMap<String, u64>, bind_group_names: HashMap<String, u64>,
// TODO: make pipelines a `type` parameter in RenderPasses, // TODO: make pipelines a `type` parameter in RenderPasses,
// then the pipelines can be retrieved via TypeId to the pass. // then the pipelines can be retrieved via TypeId to the pass.
pipelines: FxHashMap<u64, PipelineResource>, pipelines: FxHashMap<u64, PipelineResource>,
current_id: u64, current_id: u64,
exec_path: Option<GraphExecutionPath>, /// A directed graph describing the execution path of the RenderGraph
new_path: petgraph::matrix_graph::DiMatrix<u64, (), Option<()>, usize>, execution_graph: petgraph::matrix_graph::DiMatrix<u64, (), Option<()>, usize>,
} }
impl RenderGraph { impl RenderGraph {
@ -95,8 +89,7 @@ impl RenderGraph {
bind_group_names: Default::default(), bind_group_names: Default::default(),
pipelines: Default::default(), pipelines: Default::default(),
current_id: 1, current_id: 1,
exec_path: None, execution_graph: Default::default(),
new_path: Default::default(),
} }
} }
@ -132,7 +125,8 @@ impl RenderGraph {
trace!( trace!(
"Found existing slot for {}, changing id to {}", "Found existing slot for {}, changing id to {}",
slot.name, id slot.name,
id
); );
// if there is a slot of the same name // if there is a slot of the same name
@ -163,7 +157,7 @@ impl RenderGraph {
self.bind_group_names.insert(name.clone(), bg_id); self.bind_group_names.insert(name.clone(), bg_id);
} }
let index = self.new_path.add_node(desc.id); let index = self.execution_graph.add_node(desc.id);
self.passes.insert( self.passes.insert(
desc.id, desc.id,
@ -182,14 +176,18 @@ impl RenderGraph {
for pass in self.passes.values() { for pass in self.passes.values() {
if let Some(pipeline_desc) = &pass.desc.pipeline_desc { if let Some(pipeline_desc) = &pass.desc.pipeline_desc {
let pipeline = match pass.desc.pass_type { let pipeline = match pass.desc.pass_type {
RenderPassType::Render => { RenderPassType::Render => Pipeline::Render(RenderPipeline::create(
Pipeline::Render(RenderPipeline::create(device, pipeline_desc.as_render_pipeline_descriptor() device,
.expect("got compute pipeline descriptor in a render pass"))) pipeline_desc
}, .as_render_pipeline_descriptor()
RenderPassType::Compute => { .expect("got compute pipeline descriptor in a render pass"),
Pipeline::Compute(ComputePipeline::create(device, pipeline_desc.as_compute_pipeline_descriptor() )),
.expect("got render pipeline descriptor in a compute pass"))) RenderPassType::Compute => Pipeline::Compute(ComputePipeline::create(
}, device,
pipeline_desc
.as_compute_pipeline_descriptor()
.expect("got render pipeline descriptor in a compute pass"),
)),
RenderPassType::Presenter | RenderPassType::Node => { RenderPassType::Presenter | RenderPassType::Node => {
panic!("Present or Node RenderGraph passes should not have a pipeline descriptor!"); panic!("Present or Node RenderGraph passes should not have a pipeline descriptor!");
} }
@ -234,25 +232,19 @@ impl RenderGraph {
self.queue.write_buffer(buf, bufwr.offset, &bufwr.bytes); self.queue.write_buffer(buf, bufwr.offset, &bufwr.bytes);
} }
} }
// create the execution path for the graph. This will be executed in `RenderGraph::render`
let descs = self.passes.values().map(|p| &*p.desc).collect();
let path = GraphExecutionPath::new(descs);
trace!(
"Found {} steps in the rendergraph to execute",
path.queue.len()
);
self.exec_path = Some(path);
} }
#[instrument(skip(self))] #[instrument(skip(self))]
pub fn render(&mut self) { pub fn render(&mut self) {
let mut sorted: VecDeque<u64> = petgraph::algo::toposort(&self.new_path, None) let mut sorted: VecDeque<u64> = petgraph::algo::toposort(&self.execution_graph, None)
.expect("RenderGraph had cycled!") .expect("RenderGraph had cycled!")
.iter().map(|i| self.new_path[i.clone()]) .iter()
.map(|i| self.execution_graph[i.clone()])
.collect(); .collect();
let path_names = sorted.iter().map(|i| self.pass(*i).unwrap().name.clone()).collect_vec(); let path_names = sorted
.iter()
.map(|i| self.pass(*i).unwrap().name.clone())
.collect_vec();
trace!("Render graph execution order: {:?}", path_names); trace!("Render graph execution order: {:?}", path_names);
let mut encoders = Vec::with_capacity(self.passes.len() / 2); let mut encoders = Vec::with_capacity(self.passes.len() / 2);
@ -293,9 +285,12 @@ impl RenderGraph {
} }
if !encoders.is_empty() { if !encoders.is_empty() {
warn!("{} encoders were not submitted in the same render cycle they were created. \ warn!(
"{} encoders were not submitted in the same render cycle they were created. \
Make sure there is a presenting pass at the end. You may still see something, \ Make sure there is a presenting pass at the end. You may still see something, \
however it will be delayed a render cycle.", encoders.len()); however it will be delayed a render cycle.",
encoders.len()
);
self.queue.submit(encoders.into_iter()); self.queue.submit(encoders.into_iter());
} }
} }
@ -344,13 +339,59 @@ impl RenderGraph {
} }
pub fn add_edge(&mut self, from: &str, to: &str) { pub fn add_edge(&mut self, from: &str, to: &str) {
let from_idx = self.passes.iter().find(|p| p.1.desc.name == from).map(|p| p.1.graph_index) let from_idx = self
.passes
.iter()
.find(|p| p.1.desc.name == from)
.map(|p| p.1.graph_index)
.expect("Failed to find from pass"); .expect("Failed to find from pass");
let to_idx = self.passes.iter().find(|p| p.1.desc.name == to).map(|p| p.1.graph_index) let to_idx = self
.passes
.iter()
.find(|p| p.1.desc.name == to)
.map(|p| p.1.graph_index)
.expect("Failed to find to pass"); .expect("Failed to find to pass");
self.new_path.add_edge(from_idx, to_idx, ()); self.execution_graph.add_edge(from_idx, to_idx, ());
//self.new_path.add_edge(NodeIndex::new(from_id as usize), NodeIndex::new(to_id as usize), ()); }
/// Utility method for setting the bind groups for a pass.
///
/// The parameter `bind_groups` can be used to specify the labels of a bind group, and the
/// index of the bind group in the pipeline for the pass. If a bind group of the provided
/// name is not found in the graph, a panic will occur.
///
/// # Example:
/// ```rust,nobuild
/// graph.set_bind_groups(
/// &mut pass,
/// &[
/// // retrieves the "depth_texture" bind group and sets the index 0 in the
/// // pass to it.
/// ("depth_texture", 0),
/// ("camera", 1),
/// ("light_buffers", 2),
/// ("light_indices_grid", 3),
/// ("screen_size", 4),
/// ],
/// );
/// ```
///
/// # Panics
/// Panics if a bind group of a provided name is not found.
pub fn set_bind_groups<'a>(
&'a self,
pass: &mut ComputePass<'a>,
bind_groups: &[(&str, u32)],
) {
for (name, index) in bind_groups {
let bg = self
.bind_group_id(name)
.map(|bgi| self.bind_group(bgi))
.expect(&format!("Could not find bind group '{}'", name));
pass.set_bind_group(*index, bg, &[]);
}
} }
} }
@ -428,21 +469,4 @@ impl<'a> RenderGraphContext<'a> {
) { ) {
self.queue_buffer_write(target_slot, offset, bytemuck::bytes_of(&bytes)); self.queue_buffer_write(target_slot, offset, bytemuck::bytes_of(&bytes));
} }
pub fn get_bind_groups<'b>(&self, graph: &'b RenderGraph, bind_group_names: &[&str]) -> Vec<Option<&'b Rc<wgpu::BindGroup>>> {
bind_group_names
.iter()
.map(|name| graph.bind_group_id(name))
.map(|bgi| bgi.map(|bgi| graph.bind_group(bgi)))
.collect()
}
pub fn set_bind_groups<'b, 'c>(graph: &'b RenderGraph, pass: &'c mut ComputePass<'b>, bind_groups: &[(&str, u32)]) {
for (name, index) in bind_groups {
let bg = graph.bind_group_id(name).map(|bgi| graph.bind_group(bgi))
.expect(&format!("Could not find bind group '{}'", name));
pass.set_bind_group(*index, bg, &[]);
}
}
} }

View File

@ -236,8 +236,7 @@ impl RenderGraphPass for LightCullComputePass {
pass.set_bind_group(3, grid_bg, &[]); pass.set_bind_group(3, grid_bg, &[]);
pass.set_bind_group(4, screen_size_bg, &[]); */ pass.set_bind_group(4, screen_size_bg, &[]); */
RenderGraphContext::set_bind_groups( graph.set_bind_groups(
graph,
&mut pass, &mut pass,
&[ &[
("depth_texture", 0), ("depth_texture", 0),