render: implement fxaa (#8)
CI / build (push) Successful in 3m33s Details

This commit is contained in:
SeanOMik 2024-06-28 15:26:14 -04:00
parent 5ebbec8cf9
commit c4aebdb25d
Signed by: SeanOMik
GPG Key ID: FEC9E2FC15235964
4 changed files with 441 additions and 4 deletions

View File

@ -0,0 +1,171 @@
use std::{collections::HashMap, rc::Rc};
use lyra_game_derive::RenderGraphLabel;
use crate::render::{
graph::{Node, NodeDesc, NodeType},
resource::{FragmentState, PipelineDescriptor, RenderPipelineDescriptor, Shader, VertexState},
};
#[derive(Default, Debug, Clone, Copy, Hash, RenderGraphLabel)]
pub struct FxaaPassLabel;
#[derive(Debug, Default)]
pub struct FxaaPass {
target_sampler: Option<wgpu::Sampler>,
bgl: Option<Rc<wgpu::BindGroupLayout>>,
/// Store bind groups for the input textures.
/// The texture may change due to resizes, or changes to the view target chain
/// from other nodes.
bg_cache: HashMap<wgpu::Id, wgpu::BindGroup>,
}
impl FxaaPass {
pub fn new() -> Self {
Self::default()
}
}
impl Node for FxaaPass {
fn desc(
&mut self,
graph: &mut crate::render::graph::RenderGraph,
) -> crate::render::graph::NodeDesc {
let device = &graph.device;
let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("fxaa_bgl"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::FRAGMENT,
ty: wgpu::BindingType::Texture {
sample_type: wgpu::TextureSampleType::Float { filterable: true },
view_dimension: wgpu::TextureViewDimension::D2,
multisampled: false,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::FRAGMENT,
ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
count: None,
},
],
});
let bgl = Rc::new(bgl);
self.bgl = Some(bgl.clone());
self.target_sampler = Some(device.create_sampler(&wgpu::SamplerDescriptor {
label: Some("fxaa sampler"),
mag_filter: wgpu::FilterMode::Linear,
min_filter: wgpu::FilterMode::Linear,
mipmap_filter: wgpu::FilterMode::Linear,
..Default::default()
}));
let shader = Rc::new(Shader {
label: Some("fxaa_shader".into()),
source: include_str!("../../shaders/fxaa.wgsl").to_string(),
});
let vt = graph.view_target();
NodeDesc::new(
NodeType::Render,
Some(PipelineDescriptor::Render(RenderPipelineDescriptor {
label: Some("fxaa_pass".into()),
layouts: vec![bgl.clone()],
push_constant_ranges: vec![],
vertex: VertexState {
module: shader.clone(),
entry_point: "vs_main".into(),
buffers: vec![],
},
fragment: Some(FragmentState {
module: shader,
entry_point: "fs_main".into(),
targets: vec![Some(wgpu::ColorTargetState {
format: vt.format(),
blend: Some(wgpu::BlendState::REPLACE),
write_mask: wgpu::ColorWrites::ALL,
})],
}),
depth_stencil: None,
primitive: wgpu::PrimitiveState::default(),
multisample: wgpu::MultisampleState::default(),
multiview: None,
})),
vec![],
)
}
fn prepare(
&mut self,
_: &mut crate::render::graph::RenderGraph,
_: &mut lyra_ecs::World,
_: &mut crate::render::graph::RenderGraphContext,
) {
//todo!()
}
fn execute(
&mut self,
graph: &mut crate::render::graph::RenderGraph,
_: &crate::render::graph::NodeDesc,
context: &mut crate::render::graph::RenderGraphContext,
) {
let pipeline = graph
.pipeline(context.label.clone())
.expect("Failed to find pipeline for FxaaPass");
let mut vt = graph.view_target_mut();
let chain = vt.get_chain();
let source_view = chain.source.frame_view.as_ref().unwrap();
let dest_view = chain.dest.frame_view.as_ref().unwrap();
let bg = self
.bg_cache
.entry(source_view.global_id())
.or_insert_with(|| {
graph
.device()
.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("fxaa_bg"),
layout: self.bgl.as_ref().unwrap(),
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: wgpu::BindingResource::TextureView(source_view),
},
wgpu::BindGroupEntry {
binding: 1,
resource: wgpu::BindingResource::Sampler(
self.target_sampler.as_ref().unwrap(),
),
},
],
})
});
{
let encoder = context.encoder.as_mut().unwrap();
let mut pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
label: Some("fxaa_pass"),
color_attachments: &[Some(wgpu::RenderPassColorAttachment {
view: dest_view,
resolve_target: None,
ops: wgpu::Operations {
load: wgpu::LoadOp::Load,
store: true,
},
})],
depth_stencil_attachment: None,
});
pass.set_pipeline(pipeline.as_render());
pass.set_bind_group(0, bg, &[]);
pass.draw(0..3, 0..1);
}
}
}

View File

@ -18,3 +18,6 @@ pub use init::*;
mod tint;
pub use tint::*;
mod fxaa;
pub use fxaa::*;

View File

@ -9,7 +9,7 @@ use lyra_game_derive::RenderGraphLabel;
use tracing::{debug, instrument, warn};
use winit::window::Window;
use crate::render::graph::{BasePass, BasePassLabel, BasePassSlots, LightBasePass, LightBasePassLabel, LightCullComputePass, LightCullComputePassLabel, MeshPass, MeshesPassLabel, PresentPass, PresentPassLabel, RenderGraphLabelValue, RenderTarget, SubGraphNode, TintPass, TintPassLabel, ViewTarget};
use crate::render::graph::{BasePass, BasePassLabel, BasePassSlots, FxaaPass, FxaaPassLabel, LightBasePass, LightBasePassLabel, LightCullComputePass, LightCullComputePassLabel, MeshPass, MeshesPassLabel, PresentPass, PresentPassLabel, RenderGraphLabelValue, RenderTarget, SubGraphNode, ViewTarget};
use super::graph::RenderGraph;
use super::{resource::RenderPipeline, render_job::RenderJob};
@ -164,8 +164,8 @@ impl BasicRenderer {
));
}
main_graph.add_node(TintPassLabel, TintPass::default());
main_graph.add_edge(TestSubGraphLabel, TintPassLabel);
main_graph.add_node(FxaaPassLabel, FxaaPass::default());
main_graph.add_edge(TestSubGraphLabel, FxaaPassLabel);
//let present_pass_label = PresentPassLabel::new(BasePassSlots::Frame);//TintPassSlots::Frame);
let p = PresentPass;

View File

@ -0,0 +1,263 @@
// Largely based off of https://blog.simonrodriguez.fr/articles/2016/07/implementing_fxaa.html
const EDGE_THRESHOLD_MIN: f32 = 0.0312;
const EDGE_THRESHOLD_MAX: f32 = 0.125;
const ITERATIONS: i32 = 12;
const SUBPIXEL_QUALITY: f32 = 0.75;
@group(0) @binding(0)
var t_screen: texture_2d<f32>;
@group(0) @binding(1)
var s_screen: sampler;
struct VertexOutput {
@builtin(position)
clip_position: vec4<f32>,
@location(0)
tex_coords: vec2<f32>,
}
fn QUALITY(q: i32) -> f32 {
switch (q) {
default: { return 1.0; }
case 5: { return 1.5; }
case 6, 7, 8, 9: { return 2.0; }
case 10: { return 4.0; }
case 11: { return 8.0; }
}
}
fn rgb2luma(rgb: vec3<f32>) -> f32 {
return sqrt(dot(rgb, vec3<f32>(0.299, 0.587, 0.114)));
}
@vertex
fn vs_main(
@builtin(vertex_index) vertex_index: u32,
) -> VertexOutput {
let tex_coords = vec2<f32>(f32(vertex_index >> 1u), f32(vertex_index & 1u)) * 2.0;
let clip_position = vec4<f32>(tex_coords * vec2<f32>(2.0, -2.0) + vec2<f32>(-1.0, 1.0), 0.0, 1.0);
return VertexOutput(clip_position, tex_coords);
}
fn texture_offset(tex: texture_2d<f32>, samp: sampler, point: vec2<f32>, offset: vec2<i32>) -> vec3<f32> {
var tex_coords = point + vec2<f32>(offset);
return textureSample(tex, samp, tex_coords).xyz;
}
@fragment
fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
let resolution = vec2<f32>(textureDimensions(t_screen));
let inverse_screen_size = 1.0 / resolution.xy;
let tex_coords = in.clip_position.xy * inverse_screen_size;
var color_center: vec3<f32> = textureSampleLevel(t_screen, s_screen, tex_coords, 0.0).xyz;
// Luma at the current fragment
let luma_center = rgb2luma(color_center);
// Luma at the four direct neighbours of the current fragment.
let luma_down = rgb2luma(textureSampleLevel(t_screen, s_screen, tex_coords, 0.0, vec2<i32>(0, -1)).xyz);
let luma_up = rgb2luma(textureSampleLevel(t_screen, s_screen, tex_coords, 0.0, vec2<i32>(0, 1)).xyz);
let luma_left = rgb2luma(textureSampleLevel(t_screen, s_screen, tex_coords, 0.0, vec2<i32>(-1, 0)).xyz);
let luma_right = rgb2luma(textureSampleLevel(t_screen, s_screen, tex_coords, 0.0, vec2<i32>(1, 0)).xyz);
// Find the maximum and minimum luma around the current fragment.
let luma_min = min(luma_center, min(min(luma_down, luma_up), min(luma_left, luma_right)));
let luma_max = max(luma_center, max(max(luma_down, luma_up), max(luma_left, luma_right)));
// Compute the delta
let luma_range = luma_max - luma_min;
// If the luma variation is lower that a threshold (or if we are in a really dark area),
// we are not on an edge, don't perform any AA.
if (luma_range < max(EDGE_THRESHOLD_MIN, luma_max * EDGE_THRESHOLD_MAX)) {
return vec4<f32>(color_center, 1.0);
}
// Query the 4 remaining corners lumas
let luma_down_left = rgb2luma(textureSampleLevel(t_screen, s_screen, tex_coords, 0.0, vec2<i32>(-1, -1)).xyz);
let luma_up_right = rgb2luma(textureSampleLevel(t_screen, s_screen, tex_coords, 0.0, vec2<i32>(1, 1)).xyz);
let luma_up_left = rgb2luma(textureSampleLevel(t_screen, s_screen, tex_coords, 0.0, vec2<i32>(-1, 1)).xyz);
let luma_down_right = rgb2luma(textureSampleLevel(t_screen, s_screen, tex_coords, 0.0, vec2<i32>(1, -1)).xyz);
// Combine the four edges lumas (using intermediary variables for future computations with the same values).
let luma_down_up = luma_down + luma_up;
let luma_left_right = luma_left + luma_right;
// Same for corners
let luma_left_corners = luma_down_left + luma_up_left;
let luma_down_corners = luma_down_left + luma_down_right;
let luma_right_corners = luma_down_right + luma_up_right;
let luma_up_corners = luma_up_right + luma_up_left;
// Compute an estimation of the gradient along the horizontal and verical axis.
let edge_horizontal = abs(-2.0 * luma_left + luma_left_corners)
+ abs(-2.0 * luma_center + luma_down_up) * 2.0
+ abs(-2.0 * luma_right + luma_right_corners);
let edge_vertical = abs(-2.0 * luma_up + luma_up_corners)
+ abs(-2.0 * luma_center + luma_left_right) * 2.0
+ abs(-2.0 * luma_down + luma_down_corners);
// Is the local edge horizontal or vertical?
let is_horizontal = edge_horizontal >= edge_vertical;
// Select the two neighboring texels lumas in the opposite direction to the local edge.
let luma1 = select(luma_left, luma_down, is_horizontal);
let luma2 = select(luma_right, luma_up, is_horizontal);
// Compute gradients in this direction
let gradient1 = luma1 - luma_center;
let gradient2 = luma2 - luma_center;
// Which direction is the steepest?
let is_1_steepest = abs(gradient1) >= abs(gradient2);
// Gradient in the corresponding direction, normalized
let gradient_scaled = 0.25 * max(abs(gradient1), abs(gradient2));
// Choose the step size (one pixel) according to the edge direction.
var step_length: f32;
if (is_horizontal) {
step_length = inverse_screen_size.y;
} else {
step_length = inverse_screen_size.x;
}
// Average luma in the correct direction.
var luma_local_average = 0.0;
if (is_1_steepest) {
// Switch the direction
step_length = -step_length;
luma_local_average = 0.5 * (luma1 + luma_center);
} else {
luma_local_average = 0.5 * (luma2 + luma_center);
}
// Shift UV in the correct direction by half a pixel.
var current_uv = tex_coords;
if (is_horizontal) {
current_uv.y += step_length * 0.5;
} else {
current_uv.x += step_length * 0.5;
}
// Compute offset (for each iteration step) in the right direction.
var offset: vec2<f32>;
if (is_horizontal) {
offset = vec2<f32>(inverse_screen_size.x, 0.0);
} else {
offset = vec2<f32>(0.0, inverse_screen_size.y);
}
// Compute UVs to explore on each side of the edge, orthogonally. The QUALITY allows us to
// step faster.
var uv1 = current_uv - offset;
var uv2 = current_uv + offset;
// Read the lumas at both current extremities of the exploration segment, and compute the
// delta wrt to the local average luma.
var luma_end1 = rgb2luma(textureSampleLevel(t_screen, s_screen, uv1, 0.0).xyz);
var luma_end2 = rgb2luma(textureSampleLevel(t_screen, s_screen, uv2, 0.0).xyz);
luma_end1 -= luma_local_average;
luma_end2 -= luma_local_average;
// If the luma deltas at the current extremities are larger than the local gradient, we have
// reached the side of the edge.
var reached1 = abs(luma_end1) >= gradient_scaled;
var reached2 = abs(luma_end2) >= gradient_scaled;
var reached_both = reached1 && reached2;
// If the side is not reached, we continue to explore in this direction.
if (!reached1) {
uv1 -= offset;
}
if (!reached2) {
uv2 += offset;
}
if (!reached_both) {
for (var i = 2; i < ITERATIONS; i++) {
// If needed, read luma in 1st direction, compute delta.
if (!reached1) {
luma_end1 = rgb2luma(textureSampleLevel(t_screen, s_screen, uv1, 0.0).xyz);
luma_end1 = luma_end1 - luma_local_average;
}
// If needed, read luma in opposite direction, compute delta.
if (!reached2) {
luma_end2 = rgb2luma(textureSampleLevel(t_screen, s_screen, uv2, 0.0).xyz);
luma_end2 = luma_end2 - luma_local_average;
}
// If the luma deltas at the current extremities is larger than the local gradient, we have reached the side of the edge.
reached1 = abs(luma_end1) >= gradient_scaled;
reached2 = abs(luma_end2) >= gradient_scaled;
reached_both = reached1 && reached2;
// If the side is not reached, we continue to explore in this direction, with a variable quality.
if (!reached1) {
uv1 -= offset * QUALITY(i);
}
if (!reached2) {
uv2 += offset * QUALITY(i);
}
// If both sides have been reached, stop the exploration
if (reached_both) {
break;
}
}
}
// Compute the distances to each extremity of the edge.
var distance1 = select(tex_coords.y - uv1.y, tex_coords.x - uv1.x, is_horizontal);
var distance2 = select(uv2.y - tex_coords.y, uv2.x - tex_coords.x, is_horizontal);
// In which direction is the extremity of the edge closer?
let is_direction1 = distance1 < distance2;
let distance_final = min(distance1, distance2);
// Length of the edge.
let edge_thickness = (distance1 + distance2);
// UV offset: read in the direction of the closest side of the edge.
let pixel_offset = -distance_final / edge_thickness + 0.5;
// Is the luma at center smaller than the local average?
let is_luma_center_smaller = luma_center < luma_local_average;
// If the luma at center is smaller than at its neighbour, the delta luma at each end should
// be positive (same variation). (in the direction of the closer side of the edge.)
var direction_luma_end: f32;
if (is_direction1) {
direction_luma_end = luma_end1;
} else {
direction_luma_end = luma_end2;
}
let correct_variation = (direction_luma_end < 0.0) != is_luma_center_smaller;
// If the luma variation is incorrect, do not offset.
var final_offset = select(0.0, pixel_offset, correct_variation);
// Sub-pixel shifting
// Full weighted average of the luma over the 3x3 neighborhood.
let luma_average = (1.0 / 12.0) * (2.0 * (luma_down_up + luma_left_right) + luma_left_corners + luma_right_corners);
// Ratio of the delta between the global average and the center luma, over the luma range
// in the 3x3 neighborhood.
let sub_pixel_offset1 = clamp(abs(luma_average - luma_center) / luma_range, 0.0, 1.0);
let sub_pixel_offset2 = (-2.0 * sub_pixel_offset1 + 3.0) * sub_pixel_offset1 * sub_pixel_offset1;
// Compute a sub-pixel offset based on this delta.
let sub_pixel_offset_final = sub_pixel_offset2 * sub_pixel_offset2 * SUBPIXEL_QUALITY;
// Pick the biggest of the two offsets.
final_offset = max(final_offset, sub_pixel_offset_final);
var final_uv = tex_coords;
if (is_horizontal) {
final_uv.y += final_offset * step_length;
} else {
final_uv.x += final_offset * step_length;
}
let color = textureSampleLevel(t_screen, s_screen, final_uv, 0.0).xyz;
return vec4<f32>(color, 1.0);
}