render: create a depth map for the directional light

This commit is contained in:
SeanOMik 2024-06-30 19:33:51 -04:00
parent 3a80c069c9
commit e8974bbd44
Signed by: SeanOMik
GPG Key ID: FEC9E2FC15235964
5 changed files with 361 additions and 3 deletions

View File

@ -22,8 +22,8 @@ pub use tint::*;
mod fxaa; mod fxaa;
pub use fxaa::*; pub use fxaa::*;
/* mod shadow_maps; mod shadows;
pub use shadow_maps::*; */ pub use shadows::*;
mod mesh_prepare; mod mesh_prepare;
pub use mesh_prepare::*; pub use mesh_prepare::*;

View File

@ -0,0 +1,312 @@
use std::{mem, num::NonZeroU64, rc::Rc, sync::Arc};
use lyra_ecs::{query::Entities, AtomicRef, Entity, ResourceData};
use lyra_game_derive::RenderGraphLabel;
use lyra_math::{Transform, OPENGL_TO_WGPU_MATRIX};
use rustc_hash::FxHashMap;
use tracing::{debug, warn};
use wgpu::util::DeviceExt;
use crate::render::{
graph::{Node, NodeDesc, NodeType},
light::directional::DirectionalLight,
resource::{FragmentState, PipelineDescriptor, RenderPipeline, RenderPipelineDescriptor, Shader, VertexState},
transform_buffer_storage::TransformBuffers,
vertex::Vertex,
};
use super::{MeshBufferStorage, RenderAssets, RenderMeshes};
const SHADOW_SIZE: glam::UVec2 = glam::UVec2::new(1024, 1024);
#[derive(Debug, Clone, Hash, PartialEq, RenderGraphLabel)]
pub struct ShadowMapsPassLabel;
struct LightDepthMap {
light_projection_buffer: wgpu::Buffer,
texture: wgpu::Texture,
view: wgpu::TextureView,
sampler: wgpu::Sampler,
bindgroup: wgpu::BindGroup,
}
pub struct ShadowMapsPass {
bgl: Arc<wgpu::BindGroupLayout>,
/// depth maps for a light owned by an entity.
depth_maps: FxHashMap<Entity, LightDepthMap>,
// TODO: find a better way to extract these resources from the main world to be used in the
// render stage.
transform_buffers: Option<ResourceData>,
render_meshes: Option<ResourceData>,
mesh_buffers: Option<ResourceData>,
pipeline: Option<RenderPipeline>,
}
impl ShadowMapsPass {
pub fn new(device: &wgpu::Device) -> Self {
let bgl = Arc::new(device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("shadows_bgl"),
entries: &[wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::VERTEX,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: Some(
NonZeroU64::new(mem::size_of::<glam::Mat4>() as _).unwrap(),
),
},
count: None,
}],
}));
Self {
bgl,
depth_maps: Default::default(),
transform_buffers: None,
render_meshes: None,
mesh_buffers: None,
pipeline: None,
}
}
fn create_depth_map(&mut self, device: &wgpu::Device, entity: Entity, light_pos: Transform) {
let tex = device.create_texture(&wgpu::TextureDescriptor {
label: Some("texture_shadow_map_directional_light"),
size: wgpu::Extent3d {
width: SHADOW_SIZE.x,
height: SHADOW_SIZE.y,
depth_or_array_layers: 1,
},
mip_level_count: 1,
sample_count: 1,
dimension: wgpu::TextureDimension::D2,
format: wgpu::TextureFormat::Depth32Float,
usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
view_formats: &[],
});
let view = tex.create_view(&wgpu::TextureViewDescriptor {
label: Some("shadows_map_view"),
..Default::default()
});
let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
label: Some("sampler_light_depth_map"),
address_mode_u: wgpu::AddressMode::ClampToEdge,
address_mode_v: wgpu::AddressMode::ClampToEdge,
address_mode_w: wgpu::AddressMode::ClampToEdge,
mag_filter: wgpu::FilterMode::Linear,
min_filter: wgpu::FilterMode::Linear,
mipmap_filter: wgpu::FilterMode::Linear,
border_color: Some(wgpu::SamplerBorderColor::OpaqueWhite),
..Default::default()
});
const NEAR_PLANE: f32 = 0.1;
const FAR_PLANE: f32 = 80.0;
let ortho_proj =
glam::Mat4::orthographic_rh_gl(-20.0, 20.0, -20.0, 20.0, NEAR_PLANE, FAR_PLANE);
let look_view = glam::Mat4::look_to_rh(
light_pos.translation,
light_pos.forward(),
light_pos.up()
);
let light_proj = OPENGL_TO_WGPU_MATRIX * (ortho_proj * look_view);
let light_projection_buffer =
device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("shadows_light_view_mat_buffer"),
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
contents: bytemuck::bytes_of(&light_proj),
});
let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("shadows_bind_group"),
layout: &self.bgl,
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
buffer: &light_projection_buffer,
offset: 0,
size: None,
}),
}],
});
self.depth_maps.insert(
entity,
LightDepthMap {
light_projection_buffer,
texture: tex,
view,
sampler,
bindgroup: bg,
},
);
}
fn transform_buffers(&self) -> AtomicRef<TransformBuffers> {
self.transform_buffers.as_ref().unwrap().get()
}
fn render_meshes(&self) -> AtomicRef<RenderMeshes> {
self.render_meshes.as_ref().unwrap().get()
}
fn mesh_buffers(&self) -> AtomicRef<RenderAssets<MeshBufferStorage>> {
self.mesh_buffers.as_ref().unwrap().get()
}
}
impl Node for ShadowMapsPass {
fn desc(
&mut self,
graph: &mut crate::render::graph::RenderGraph,
) -> crate::render::graph::NodeDesc {
NodeDesc::new(NodeType::Render, None, vec![])
}
fn prepare(
&mut self,
graph: &mut crate::render::graph::RenderGraph,
world: &mut lyra_ecs::World,
context: &mut crate::render::graph::RenderGraphContext,
) {
self.render_meshes = world.try_get_resource_data::<RenderMeshes>();
self.transform_buffers = world.try_get_resource_data::<TransformBuffers>();
self.mesh_buffers = world.try_get_resource_data::<RenderAssets<MeshBufferStorage>>();
for (entity, pos, light) in world.view_iter::<(Entities, &Transform, &DirectionalLight)>() {
if !self.depth_maps.contains_key(&entity) {
self.create_depth_map(graph.device(), entity, *pos);
debug!("Created depth map for {:?} light entity", entity);
}
}
if self.pipeline.is_none() {
let shader = Rc::new(Shader {
label: Some("shader_shadows".into()),
source: include_str!("../../shaders/shadows.wgsl").to_string(),
});
let bgl = self.bgl.clone();
let transforms = self.transform_buffers().bindgroup_layout.clone();
self.pipeline = Some(RenderPipeline::create(
&graph.device,
&RenderPipelineDescriptor {
label: Some("pipeline_shadows".into()),
layouts: vec![
bgl,
transforms,
],
push_constant_ranges: vec![],
vertex: VertexState {
module: shader.clone(),
entry_point: "vs_main".into(),
buffers: vec![Vertex::position_desc().into()],
},
fragment: None, /* Some(FragmentState {
module: shader,
entry_point: "fs_main".into(),
targets: vec![],
}), */
depth_stencil: Some(wgpu::DepthStencilState {
format: wgpu::TextureFormat::Depth32Float,
depth_write_enabled: true,
depth_compare: wgpu::CompareFunction::Less,
stencil: wgpu::StencilState::default(),
bias: wgpu::DepthBiasState::default(),
}),
primitive: wgpu::PrimitiveState::default(),
multisample: wgpu::MultisampleState::default(),
multiview: None,
}
));
/* */
}
}
fn execute(
&mut self,
graph: &mut crate::render::graph::RenderGraph,
desc: &crate::render::graph::NodeDesc,
context: &mut crate::render::graph::RenderGraphContext,
) {
let encoder = context.encoder.as_mut().unwrap();
let pipeline = self.pipeline.as_ref().unwrap();
let render_meshes = self.render_meshes();
let mesh_buffers = self.mesh_buffers();
let transforms = self.transform_buffers();
debug_assert_eq!(
self.depth_maps.len(),
1,
"shadows map pass only supports 1 light"
);
let (_, dir_depth_map) = self
.depth_maps
.iter()
.next()
.expect("missing directional light in scene");
{
let mut pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
label: Some("pass_shadow_map"),
color_attachments: &[],
depth_stencil_attachment: Some(wgpu::RenderPassDepthStencilAttachment {
view: &dir_depth_map.view,
depth_ops: Some(wgpu::Operations {
load: wgpu::LoadOp::Clear(1.0),
store: true,
}),
stencil_ops: None,
}),
});
pass.set_pipeline(&pipeline);
for job in render_meshes.iter() {
// get the mesh (containing vertices) and the buffers from storage
let buffers = mesh_buffers.get(&job.mesh_uuid);
if buffers.is_none() {
warn!("Skipping job since its mesh is missing {:?}", job.mesh_uuid);
continue;
}
let buffers = buffers.unwrap();
pass.set_bind_group(0, &dir_depth_map.bindgroup, &[]);
// Get the bindgroup for job's transform and bind to it using an offset.
let bindgroup = transforms.bind_group(job.transform_id);
let offset = transforms.buffer_offset(job.transform_id);
pass.set_bind_group(1, bindgroup, &[offset]);
// if this mesh uses indices, use them to draw the mesh
if let Some((idx_type, indices)) = buffers.buffer_indices.as_ref() {
let indices_len = indices.count() as u32;
pass.set_vertex_buffer(
buffers.buffer_vertex.slot(),
buffers.buffer_vertex.buffer().slice(..),
);
pass.set_index_buffer(indices.buffer().slice(..), *idx_type);
pass.draw_indexed(0..indices_len, 0, 0..1);
} else {
let vertex_count = buffers.buffer_vertex.count();
pass.set_vertex_buffer(
buffers.buffer_vertex.slot(),
buffers.buffer_vertex.buffer().slice(..),
);
pass.draw(0..vertex_count as u32, 0..1);
}
}
}
}
}

View File

@ -9,7 +9,7 @@ use lyra_game_derive::RenderGraphLabel;
use tracing::{debug, instrument, warn}; use tracing::{debug, instrument, warn};
use winit::window::Window; use winit::window::Window;
use crate::render::graph::{BasePass, BasePassLabel, BasePassSlots, FxaaPass, FxaaPassLabel, LightBasePass, LightBasePassLabel, LightCullComputePass, LightCullComputePassLabel, MeshPass, MeshPrepNode, MeshPrepNodeLabel, MeshesPassLabel, PresentPass, PresentPassLabel, RenderGraphLabelValue, RenderTarget, SubGraphNode, ViewTarget}; use crate::render::graph::{BasePass, BasePassLabel, BasePassSlots, FxaaPass, FxaaPassLabel, LightBasePass, LightBasePassLabel, LightCullComputePass, LightCullComputePassLabel, MeshPass, MeshPrepNode, MeshPrepNodeLabel, MeshesPassLabel, PresentPass, PresentPassLabel, RenderGraphLabelValue, RenderTarget, ShadowMapsPass, ShadowMapsPassLabel, SubGraphNode, ViewTarget};
use super::graph::RenderGraph; use super::graph::RenderGraph;
use super::{resource::RenderPipeline, render_job::RenderJob}; use super::{resource::RenderPipeline, render_job::RenderJob};
@ -152,9 +152,15 @@ impl BasicRenderer {
forward_plus_graph.add_node(MeshPrepNodeLabel, mesh_prep); forward_plus_graph.add_node(MeshPrepNodeLabel, mesh_prep);
forward_plus_graph.add_node(MeshesPassLabel, MeshPass::new(material_bgl)); forward_plus_graph.add_node(MeshesPassLabel, MeshPass::new(material_bgl));
forward_plus_graph.add_node(ShadowMapsPassLabel, ShadowMapsPass::new(&device));
forward_plus_graph.add_edge(LightBasePassLabel, LightCullComputePassLabel); forward_plus_graph.add_edge(LightBasePassLabel, LightCullComputePassLabel);
forward_plus_graph.add_edge(MeshPrepNodeLabel, MeshesPassLabel); forward_plus_graph.add_edge(MeshPrepNodeLabel, MeshesPassLabel);
// run ShadowMapsPass after MeshPrep and before MeshesPass
forward_plus_graph.add_edge(MeshPrepNodeLabel, ShadowMapsPassLabel);
forward_plus_graph.add_edge(ShadowMapsPassLabel, MeshesPassLabel);
main_graph.add_sub_graph(TestSubGraphLabel, forward_plus_graph); main_graph.add_sub_graph(TestSubGraphLabel, forward_plus_graph);
main_graph.add_node(TestSubGraphLabel, SubGraphNode::new(TestSubGraphLabel, main_graph.add_node(TestSubGraphLabel, SubGraphNode::new(TestSubGraphLabel,
vec![ vec![

View File

@ -0,0 +1,23 @@
struct TransformData {
transform: mat4x4<f32>,
normal_matrix: mat4x4<f32>,
}
@group(0) @binding(0)
var<uniform> u_light_space_matrix: mat4x4<f32>;
@group(1) @binding(0)
var<uniform> u_model_transform_data: TransformData;
struct VertexOutput {
@builtin(position)
clip_position: vec4<f32>,
}
@vertex
fn vs_main(
@location(0) position: vec3<f32>
) -> VertexOutput {
let pos = u_light_space_matrix * u_model_transform_data.transform * vec4<f32>(position, 1.0);
return VertexOutput(pos);
}

View File

@ -15,6 +15,23 @@ impl Vertex {
position, tex_coords, normals position, tex_coords, normals
} }
} }
/// Returns a [`wgpu::VertexBufferLayout`] with only the position as a vertex attribute.
///
/// The stride is still `std::mem::size_of::<Vertex>()`, but only position is included.
pub fn position_desc<'a>() -> wgpu::VertexBufferLayout<'a> {
wgpu::VertexBufferLayout {
array_stride: std::mem::size_of::<Vertex>() as wgpu::BufferAddress,
step_mode: wgpu::VertexStepMode::Vertex,
attributes: &[
wgpu::VertexAttribute {
offset: 0,
shader_location: 0,
format: wgpu::VertexFormat::Float32x3, // Vec3
},
]
}
}
} }
impl DescVertexBufferLayout for Vertex { impl DescVertexBufferLayout for Vertex {