Implement a Render Graph #16
|
@ -1,89 +0,0 @@
|
|||
use std::collections::{HashMap, VecDeque};
|
||||
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
|
||||
use super::RenderGraphPassDesc;
|
||||
|
||||
pub struct GraphExecutionPath {
|
||||
/// Queue of the path, top is the first to be executed.
|
||||
/// Each element is the handle of a pass.
|
||||
pub queue: VecDeque<u64>,
|
||||
}
|
||||
|
||||
impl GraphExecutionPath {
|
||||
pub fn new(pass_descriptions: Vec<&RenderGraphPassDesc>) -> Self {
|
||||
// collect all the output slots
|
||||
let mut total_outputs = HashMap::new();
|
||||
total_outputs.reserve(pass_descriptions.len());
|
||||
|
||||
for desc in pass_descriptions.iter() {
|
||||
for slot in desc.output_slots() {
|
||||
total_outputs.insert(slot.name.clone(), SlotOwnerPair {
|
||||
pass: desc.id,
|
||||
slot: slot.id,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let mut nodes = FxHashMap::<u64, Node>::default();
|
||||
for desc in pass_descriptions.iter() {
|
||||
// find the node inputs
|
||||
let mut inputs = vec![];
|
||||
for slot in desc.input_slots() {
|
||||
let inp = total_outputs.get(&slot.name)
|
||||
.expect(&format!("failed to find slot: '{}', ensure that there is a pass outputting it", slot.name));
|
||||
inputs.push(*inp);
|
||||
}
|
||||
|
||||
let node = Node {
|
||||
id: desc.id,
|
||||
desc: (*desc),
|
||||
slot_inputs: inputs
|
||||
};
|
||||
nodes.insert(node.id, node);
|
||||
}
|
||||
|
||||
// sort the graph
|
||||
let mut stack = VecDeque::new();
|
||||
let mut visited = FxHashSet::default();
|
||||
for (_, no) in nodes.iter() {
|
||||
Self::topological_sort(&nodes, &mut stack, &mut visited, no);
|
||||
}
|
||||
|
||||
Self {
|
||||
queue: stack,
|
||||
}
|
||||
}
|
||||
|
||||
fn topological_sort(graph: &FxHashMap<u64, Node>, stack: &mut VecDeque<u64>, visited: &mut FxHashSet<u64>, node: &Node) {
|
||||
if !visited.contains(&node.id) {
|
||||
visited.insert(node.id);
|
||||
|
||||
for depend in &node.slot_inputs {
|
||||
let depend_node = graph.get(&depend.pass)
|
||||
.expect("could not find dependent node");
|
||||
|
||||
if !visited.contains(&depend.pass) {
|
||||
Self::topological_sort(graph, stack, visited, depend_node);
|
||||
}
|
||||
}
|
||||
|
||||
stack.push_back(node.id);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct SlotOwnerPair {
|
||||
pass: u64,
|
||||
slot: u64,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
struct Node<'a> {
|
||||
id: u64,
|
||||
desc: &'a RenderGraphPassDesc,
|
||||
slot_inputs: Vec<SlotOwnerPair>,
|
||||
}
|
|
@ -16,17 +16,12 @@ pub use passes::*;
|
|||
mod slot_desc;
|
||||
pub use slot_desc::*;
|
||||
|
||||
mod execution_path;
|
||||
|
||||
use rustc_hash::FxHashMap;
|
||||
use tracing::{debug_span, instrument, trace, warn};
|
||||
use wgpu::ComputePass;
|
||||
|
||||
use self::execution_path::GraphExecutionPath;
|
||||
|
||||
use super::resource::{ComputePipeline, Pipeline, RenderPipeline};
|
||||
|
||||
//#[derive(Clone)]
|
||||
struct PassEntry {
|
||||
inner: Arc<RefCell<dyn RenderGraphPass>>,
|
||||
desc: Arc<RenderGraphPassDesc>,
|
||||
|
@ -45,7 +40,6 @@ pub struct BindGroupEntry {
|
|||
#[allow(dead_code)]
|
||||
struct ResourcedSlot {
|
||||
name: String,
|
||||
//slot: RenderPassSlot,
|
||||
ty: SlotType,
|
||||
value: SlotValue,
|
||||
}
|
||||
|
@ -74,13 +68,13 @@ pub struct RenderGraph {
|
|||
passes: FxHashMap<u64, PassEntry>,
|
||||
// TODO: Use a SlotMap
|
||||
bind_groups: FxHashMap<u64, BindGroupEntry>,
|
||||
bind_group_names: FxHashMap<String, u64>,
|
||||
bind_group_names: HashMap<String, u64>,
|
||||
// TODO: make pipelines a `type` parameter in RenderPasses,
|
||||
// then the pipelines can be retrieved via TypeId to the pass.
|
||||
pipelines: FxHashMap<u64, PipelineResource>,
|
||||
current_id: u64,
|
||||
exec_path: Option<GraphExecutionPath>,
|
||||
new_path: petgraph::matrix_graph::DiMatrix<u64, (), Option<()>, usize>,
|
||||
/// A directed graph describing the execution path of the RenderGraph
|
||||
execution_graph: petgraph::matrix_graph::DiMatrix<u64, (), Option<()>, usize>,
|
||||
}
|
||||
|
||||
impl RenderGraph {
|
||||
|
@ -95,8 +89,7 @@ impl RenderGraph {
|
|||
bind_group_names: Default::default(),
|
||||
pipelines: Default::default(),
|
||||
current_id: 1,
|
||||
exec_path: None,
|
||||
new_path: Default::default(),
|
||||
execution_graph: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -132,7 +125,8 @@ impl RenderGraph {
|
|||
|
||||
trace!(
|
||||
"Found existing slot for {}, changing id to {}",
|
||||
slot.name, id
|
||||
slot.name,
|
||||
id
|
||||
);
|
||||
|
||||
// if there is a slot of the same name
|
||||
|
@ -163,7 +157,7 @@ impl RenderGraph {
|
|||
self.bind_group_names.insert(name.clone(), bg_id);
|
||||
}
|
||||
|
||||
let index = self.new_path.add_node(desc.id);
|
||||
let index = self.execution_graph.add_node(desc.id);
|
||||
|
||||
self.passes.insert(
|
||||
desc.id,
|
||||
|
@ -182,14 +176,18 @@ impl RenderGraph {
|
|||
for pass in self.passes.values() {
|
||||
if let Some(pipeline_desc) = &pass.desc.pipeline_desc {
|
||||
let pipeline = match pass.desc.pass_type {
|
||||
RenderPassType::Render => {
|
||||
Pipeline::Render(RenderPipeline::create(device, pipeline_desc.as_render_pipeline_descriptor()
|
||||
.expect("got compute pipeline descriptor in a render pass")))
|
||||
},
|
||||
RenderPassType::Compute => {
|
||||
Pipeline::Compute(ComputePipeline::create(device, pipeline_desc.as_compute_pipeline_descriptor()
|
||||
.expect("got render pipeline descriptor in a compute pass")))
|
||||
},
|
||||
RenderPassType::Render => Pipeline::Render(RenderPipeline::create(
|
||||
device,
|
||||
pipeline_desc
|
||||
.as_render_pipeline_descriptor()
|
||||
.expect("got compute pipeline descriptor in a render pass"),
|
||||
)),
|
||||
RenderPassType::Compute => Pipeline::Compute(ComputePipeline::create(
|
||||
device,
|
||||
pipeline_desc
|
||||
.as_compute_pipeline_descriptor()
|
||||
.expect("got render pipeline descriptor in a compute pass"),
|
||||
)),
|
||||
RenderPassType::Presenter | RenderPassType::Node => {
|
||||
panic!("Present or Node RenderGraph passes should not have a pipeline descriptor!");
|
||||
}
|
||||
|
@ -234,25 +232,19 @@ impl RenderGraph {
|
|||
self.queue.write_buffer(buf, bufwr.offset, &bufwr.bytes);
|
||||
}
|
||||
}
|
||||
|
||||
// create the execution path for the graph. This will be executed in `RenderGraph::render`
|
||||
let descs = self.passes.values().map(|p| &*p.desc).collect();
|
||||
let path = GraphExecutionPath::new(descs);
|
||||
trace!(
|
||||
"Found {} steps in the rendergraph to execute",
|
||||
path.queue.len()
|
||||
);
|
||||
|
||||
self.exec_path = Some(path);
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
pub fn render(&mut self) {
|
||||
let mut sorted: VecDeque<u64> = petgraph::algo::toposort(&self.new_path, None)
|
||||
let mut sorted: VecDeque<u64> = petgraph::algo::toposort(&self.execution_graph, None)
|
||||
.expect("RenderGraph had cycled!")
|
||||
.iter().map(|i| self.new_path[i.clone()])
|
||||
.iter()
|
||||
.map(|i| self.execution_graph[i.clone()])
|
||||
.collect();
|
||||
let path_names = sorted.iter().map(|i| self.pass(*i).unwrap().name.clone()).collect_vec();
|
||||
let path_names = sorted
|
||||
.iter()
|
||||
.map(|i| self.pass(*i).unwrap().name.clone())
|
||||
.collect_vec();
|
||||
trace!("Render graph execution order: {:?}", path_names);
|
||||
|
||||
let mut encoders = Vec::with_capacity(self.passes.len() / 2);
|
||||
|
@ -293,9 +285,12 @@ impl RenderGraph {
|
|||
}
|
||||
|
||||
if !encoders.is_empty() {
|
||||
warn!("{} encoders were not submitted in the same render cycle they were created. \
|
||||
warn!(
|
||||
"{} encoders were not submitted in the same render cycle they were created. \
|
||||
Make sure there is a presenting pass at the end. You may still see something, \
|
||||
however it will be delayed a render cycle.", encoders.len());
|
||||
however it will be delayed a render cycle.",
|
||||
encoders.len()
|
||||
);
|
||||
self.queue.submit(encoders.into_iter());
|
||||
}
|
||||
}
|
||||
|
@ -344,13 +339,59 @@ impl RenderGraph {
|
|||
}
|
||||
|
||||
pub fn add_edge(&mut self, from: &str, to: &str) {
|
||||
let from_idx = self.passes.iter().find(|p| p.1.desc.name == from).map(|p| p.1.graph_index)
|
||||
let from_idx = self
|
||||
.passes
|
||||
.iter()
|
||||
.find(|p| p.1.desc.name == from)
|
||||
.map(|p| p.1.graph_index)
|
||||
.expect("Failed to find from pass");
|
||||
let to_idx = self.passes.iter().find(|p| p.1.desc.name == to).map(|p| p.1.graph_index)
|
||||
let to_idx = self
|
||||
.passes
|
||||
.iter()
|
||||
.find(|p| p.1.desc.name == to)
|
||||
.map(|p| p.1.graph_index)
|
||||
.expect("Failed to find to pass");
|
||||
|
||||
self.new_path.add_edge(from_idx, to_idx, ());
|
||||
//self.new_path.add_edge(NodeIndex::new(from_id as usize), NodeIndex::new(to_id as usize), ());
|
||||
self.execution_graph.add_edge(from_idx, to_idx, ());
|
||||
}
|
||||
|
||||
/// Utility method for setting the bind groups for a pass.
|
||||
///
|
||||
/// The parameter `bind_groups` can be used to specify the labels of a bind group, and the
|
||||
/// index of the bind group in the pipeline for the pass. If a bind group of the provided
|
||||
/// name is not found in the graph, a panic will occur.
|
||||
///
|
||||
/// # Example:
|
||||
/// ```rust,nobuild
|
||||
/// graph.set_bind_groups(
|
||||
/// &mut pass,
|
||||
/// &[
|
||||
/// // retrieves the "depth_texture" bind group and sets the index 0 in the
|
||||
/// // pass to it.
|
||||
/// ("depth_texture", 0),
|
||||
/// ("camera", 1),
|
||||
/// ("light_buffers", 2),
|
||||
/// ("light_indices_grid", 3),
|
||||
/// ("screen_size", 4),
|
||||
/// ],
|
||||
/// );
|
||||
/// ```
|
||||
///
|
||||
/// # Panics
|
||||
/// Panics if a bind group of a provided name is not found.
|
||||
pub fn set_bind_groups<'a>(
|
||||
&'a self,
|
||||
pass: &mut ComputePass<'a>,
|
||||
bind_groups: &[(&str, u32)],
|
||||
) {
|
||||
for (name, index) in bind_groups {
|
||||
let bg = self
|
||||
.bind_group_id(name)
|
||||
.map(|bgi| self.bind_group(bgi))
|
||||
.expect(&format!("Could not find bind group '{}'", name));
|
||||
|
||||
pass.set_bind_group(*index, bg, &[]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -428,21 +469,4 @@ impl<'a> RenderGraphContext<'a> {
|
|||
) {
|
||||
self.queue_buffer_write(target_slot, offset, bytemuck::bytes_of(&bytes));
|
||||
}
|
||||
|
||||
pub fn get_bind_groups<'b>(&self, graph: &'b RenderGraph, bind_group_names: &[&str]) -> Vec<Option<&'b Rc<wgpu::BindGroup>>> {
|
||||
bind_group_names
|
||||
.iter()
|
||||
.map(|name| graph.bind_group_id(name))
|
||||
.map(|bgi| bgi.map(|bgi| graph.bind_group(bgi)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn set_bind_groups<'b, 'c>(graph: &'b RenderGraph, pass: &'c mut ComputePass<'b>, bind_groups: &[(&str, u32)]) {
|
||||
for (name, index) in bind_groups {
|
||||
let bg = graph.bind_group_id(name).map(|bgi| graph.bind_group(bgi))
|
||||
.expect(&format!("Could not find bind group '{}'", name));
|
||||
|
||||
pass.set_bind_group(*index, bg, &[]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -236,8 +236,7 @@ impl RenderGraphPass for LightCullComputePass {
|
|||
pass.set_bind_group(3, grid_bg, &[]);
|
||||
pass.set_bind_group(4, screen_size_bg, &[]); */
|
||||
|
||||
RenderGraphContext::set_bind_groups(
|
||||
graph,
|
||||
graph.set_bind_groups(
|
||||
&mut pass,
|
||||
&[
|
||||
("depth_texture", 0),
|
||||
|
|
Loading…
Reference in New Issue