Fix many many problems that led to invalid shader code output for lyra engine

This commit is contained in:
SeanOMik 2024-09-14 20:03:37 -04:00
parent 2899b1c3d3
commit 70daf32082
Signed by: SeanOMik
GPG Key ID: FEC9E2FC15235964
3 changed files with 1139 additions and 248 deletions

View File

@ -25,6 +25,13 @@ pub struct Import {
imports: Vec<String>, imports: Vec<String>,
} }
/// An import of a single item from a module.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ImportItemFrom {
module: String,
import: String,
}
#[derive(Clone, PartialEq, Eq, Hash)] #[derive(Clone, PartialEq, Eq, Hash)]
pub struct DefRequirement { pub struct DefRequirement {
/// None if the requirement is local /// None if the requirement is local
@ -138,7 +145,7 @@ fn main() -> vec4<f32> {
let simple_path = p.parse_module(&SIMPLE_MODULE).unwrap() let simple_path = p.parse_module(&SIMPLE_MODULE).unwrap()
.expect("failed to find module"); .expect("failed to find module");
let out = p.process_file(&simple_path, &SIMPLE_MODULE).unwrap(); let out = p.preprocess_module(&simple_path).unwrap();
assert!(out.contains("const scalar"), "definition of imported `const scalar` is missing!"); assert!(out.contains("const scalar"), "definition of imported `const scalar` is missing!");
assert!(out.contains("fn mult_some_nums"), "definition of imported `fn mult_some_nums` is missing!"); assert!(out.contains("fn mult_some_nums"), "definition of imported `fn mult_some_nums` is missing!");
@ -150,11 +157,11 @@ fn main() -> vec4<f32> {
/// i.e., `simple::some_const`. /// i.e., `simple::some_const`.
#[test] #[test]
fn double_layer_import_indirect_imports() { fn double_layer_import_indirect_imports() {
/* tracing_subscriber::fmt() tracing_subscriber::fmt()
// enable everything // enable everything
.with_max_level(tracing::Level::TRACE) .with_max_level(tracing::Level::TRACE)
// sets this to be the default, global collector for this application. // sets this to be the default, global collector for this application.
.init(); */ .init();
let mut p = Processor::new(); let mut p = Processor::new();
p.parse_module(&INNER_MODULE).unwrap() p.parse_module(&INNER_MODULE).unwrap()
@ -164,18 +171,18 @@ fn main() -> vec4<f32> {
let main_path = p.parse_module(&MAIN_MODULE).unwrap() let main_path = p.parse_module(&MAIN_MODULE).unwrap()
.expect("failed to find module"); .expect("failed to find module");
let out = p.process_file(&main_path, &MAIN_MODULE).unwrap(); let out = p.preprocess_module(&main_path).unwrap();
assert!(out.contains("const simple_scalar"), "definition of imported `const simple_scalar` is missing!"); assert!(out.contains("const simple_scalar"), "definition of imported `const simple_scalar` is missing!\n{}", out);
assert!(out.contains("fn do_something_cool("), "definition of imported `fn do_something_cool` is missing!"); assert!(out.contains("fn do_something_cool("), "definition of imported `fn do_something_cool` is missing!\n{}", out);
assert!(out.contains("const scalar"), "definition of imported dependency, `const scalar` is missing! {}", out); assert!(out.contains("const scalar"), "definition of imported dependency, `const scalar` is missing!\n{}", out);
assert!(out.contains("fn mult_some_nums("), "definition of imported dependency, `fn mult_some_nums` is missing!"); assert!(out.contains("fn mult_some_nums("), "definition of imported dependency, `fn mult_some_nums` is missing!\n{}", out);
assert!(out.contains("fn do_something_cool("), "definition of imported `fn do_something_cool` is missing!"); assert!(out.contains("fn do_something_cool("), "definition of imported `fn do_something_cool` is missing!\n{}", out);
assert!(out.contains("simple_scalar * do_something_cool"), "indirect imports of `simple_scalar * do_something_cool` is missing!"); assert!(out.contains("simple_scalar * do_something_cool"), "indirect imports of `simple_scalar * do_something_cool` is missing!\n{}", out);
assert!(out.contains("fn main("), "definition of `fn main` is missing!"); assert!(out.contains("fn main("), "definition of `fn main` is missing!\n{}", out);
} }
#[test] #[test]
@ -211,7 +218,7 @@ fn main() -> vec4<f32> {
let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap() let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap()
.expect("failed to find main module def"); .expect("failed to find main module def");
let out = p.process_file(&importing_mod, &IMPORTING_MODULE).unwrap(); let out = p.preprocess_module(&importing_mod).unwrap();
assert!(out.contains("var<uniform> shadows_atlas"), "missing shadows_atlas binding: {}", out); assert!(out.contains("var<uniform> shadows_atlas"), "missing shadows_atlas binding: {}", out);
assert!(out.contains("var<uniform> lights"), "missing lights binding: {}", out); assert!(out.contains("var<uniform> lights"), "missing lights binding: {}", out);
@ -252,7 +259,7 @@ fn main() -> vec4<f32> {
let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap() let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap()
.expect("failed to find main module def"); .expect("failed to find main module def");
let out = p.process_file(&importing_mod, &IMPORTING_MODULE).unwrap(); let out = p.preprocess_module(&importing_mod).unwrap();
assert!(out.contains("var<uniform> shadows_atlas"), "missing shadows_atlas binding: {}", out); assert!(out.contains("var<uniform> shadows_atlas"), "missing shadows_atlas binding: {}", out);
assert!(out.contains("var<uniform> lights"), "missing lights binding: {}", out); assert!(out.contains("var<uniform> lights"), "missing lights binding: {}", out);
@ -302,7 +309,7 @@ fn main() -> vec4<f32> {
let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap() let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap()
.expect("failed to find main module def"); .expect("failed to find main module def");
let out = p.process_file(&importing_mod, &IMPORTING_MODULE).unwrap(); let out = p.preprocess_module(&importing_mod).unwrap();
assert!(out.contains("var use_anyway"), "missing use_anyway binding: {}", out); assert!(out.contains("var use_anyway"), "missing use_anyway binding: {}", out);
assert!(out.contains("var<uniform> shadows_atlas"), "missing shadows_atlas binding: {}", out); assert!(out.contains("var<uniform> shadows_atlas"), "missing shadows_atlas binding: {}", out);
@ -348,7 +355,7 @@ fn main() -> vec4<f32> {
let lights_mod = p.modules.get("engine::lights").unwrap(); let lights_mod = p.modules.get("engine::lights").unwrap();
assert!(lights_mod.structs.contains_key("Light"), "`Light` struct definition was not parsed"); assert!(lights_mod.structs.contains_key("Light"), "`Light` struct definition was not parsed");
let out = p.process_file(&importing_mod, &IMPORTING_MODULE).unwrap(); let out = p.preprocess_module(&importing_mod).unwrap();
assert!(out.contains("struct Light"), "missing `Light` struct definition: {}", out); assert!(out.contains("struct Light"), "missing `Light` struct definition: {}", out);
} }
@ -398,7 +405,7 @@ fn main() -> vec4<f32> {
let lights_mod = p.modules.get("engine::lights").unwrap(); let lights_mod = p.modules.get("engine::lights").unwrap();
assert!(lights_mod.structs.contains_key("Light"), "`Light` struct definition was not parsed"); assert!(lights_mod.structs.contains_key("Light"), "`Light` struct definition was not parsed");
let out = p.process_file(&importing_mod, &IMPORTING_MODULE).unwrap(); let out = p.preprocess_module(&importing_mod).unwrap();
assert!(out.contains("struct Light"), "missing `Light` struct definition: {}", out); assert!(out.contains("struct Light"), "missing `Light` struct definition: {}", out);
assert!(out.contains("struct Something"), "missing `Something` struct definition: {}", out); assert!(out.contains("struct Something"), "missing `Something` struct definition: {}", out);
} }
@ -433,7 +440,7 @@ fn main() -> vec4<f32> {
let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap() let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap()
.expect("failed to find main module def"); .expect("failed to find main module def");
let out = p.process_file(&importing_mod, &IMPORTING_MODULE).unwrap(); let out = p.preprocess_module(&importing_mod).unwrap();
assert!(out.contains("alias LightIntensity"), "missing `LightIntensity` type alias: {}", out); assert!(out.contains("alias LightIntensity"), "missing `LightIntensity` type alias: {}", out);
} }
@ -475,8 +482,535 @@ fn main() -> vec4<f32> {
let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap() let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap()
.expect("failed to find main module def"); .expect("failed to find main module def");
let out = p.process_file(&importing_mod, &IMPORTING_MODULE).unwrap(); let out = p.preprocess_module(&importing_mod).unwrap();
assert!(out.contains("alias LightIntensity"), "missing `LightIntensity` type alias: {}", out); assert!(out.contains("alias LightIntensity"), "missing `LightIntensity` type alias: {}", out);
assert!(out.contains("alias Number"), "missing `Number` type alias: {}", out); assert!(out.contains("alias Number"), "missing `Number` type alias: {}", out);
} }
/// Tests comments and gaps in structs
#[test]
fn comments_and_gaps_in_structs() {
const TYPE_MODULE: &'static str =
r#"#define_module engine::types
struct TextureAtlasFrame {
/*offset: vec2<u32>,
size: vec2<u32>,*/
x: u32,
y: u32,
width: u32,
height: u32,
}"#;
/* tracing_subscriber::fmt()
// enable everything
.with_max_level(tracing::Level::TRACE)
// sets this to be the default, global collector for this application.
.init(); */
let mut p = Processor::new();
let type_mod = p.parse_module(TYPE_MODULE).unwrap()
.expect("failed to find bindings module def");
let out = p.preprocess_module(&type_mod).unwrap();
assert!(out.contains("struct TextureAtlasFrame"), "missing `TextureAtlasFrame` struct definition: {}", out);
assert!(!out.contains("offset: vec2<u32>,"), "`offset` struct field was included in output: {}", out);
assert!(!out.contains("size: vec2<u32>,"), "`size` struct field was included in output: {}", out);
assert!(out.contains("x: u32,"), "missing `x` struct field: {}", out);
assert!(out.contains("y: u32,"), "missing `y` struct field: {}", out);
assert!(out.contains("width: u32,"), "missing `width` struct field: {}", out);
assert!(out.contains("height: u32,"), "missing `height` struct field: {}", out);
}
#[test]
fn built_in_generics() {
const TYPE_MODULE: &'static str =
r#"#define_module engine::types
@group(5) @binding(5)
var<storage, read> point_array: array<vec2<f32>>;"#;
/* tracing_subscriber::fmt()
// enable everything
.with_max_level(tracing::Level::TRACE)
// sets this to be the default, global collector for this application.
.init(); */
let mut p = Processor::new();
let type_mod = p.parse_module(TYPE_MODULE).unwrap()
.expect("failed to find bindings module def");
let out = p.preprocess_module(&type_mod).unwrap();
assert!(out.contains("var<storage, read> point_array"), "missing `point_array` storage variable: {}", out);
}
#[test]
fn complex_code_blocks() {
const MAIN_MODULE: &'static str =
r#"#define_module engine::main
fn main() -> vec4<f32> {
let abs_x = 5.0;
let abs_y = 2.0;
let abs_z = 1.0;
if (abs_x >= abs_y && abs_x >= abs_z) {
let some_thing = 0.5;
return vec4<f32>(some_thing);
}
return vec4<f32>(1.0);
}"#;
/* tracing_subscriber::fmt()
// enable everything
.with_max_level(tracing::Level::TRACE)
// sets this to be the default, global collector for this application.
.init(); */
let mut p = Processor::new();
let module = p.parse_module(MAIN_MODULE).unwrap()
.expect("failed to find bindings module def");
let out = p.preprocess_module(&module).unwrap();
assert!(out.contains("abs_x >= abs_y && abs_x >= abs_z"), "missing if statement condition: {}", out);
assert!(out.contains("let some_thing = 0.5;"), "missing definition of `some_thing`: {}", out);
assert!(out.contains("return vec4<f32>(some_thing);"), "missing return of `some_thing` as vec4: {}", out);
}
#[test]
fn code_blocks_using_arrays_nested() {
const MAIN_MODULE: &'static str =
r#"#define_module engine::main
fn main() -> vec4<f32> {
calc_shadow_dir_light(vec3<f32>(1.0), vec3<f32>(1.0), vec3<f32>(1.0), light_example_var);
return vec4<f32>(1.0);
}
fn calc_shadow_dir_light(world_pos: vec3<f32>, world_normal: vec3<f32>, light_dir: vec3<f32>, light: Light) -> f32 {
let map_data: LightShadowMapUniform = u_light_shadow[light.light_shadow_uniform_index[0]];
let frag_pos_light_space = map_data.light_space_matrix * vec4<f32>(world_pos, 1.0);
return 1.0;
}
fn pcf_point_light(tex_coords: vec3<f32>, test_depth: f32, shadow_us: array<LightShadowMapUniform, 6>, samples_num: u32, uv_radius: f32) -> f32 {
return 1.0;
}"#;
/* tracing_subscriber::fmt()
// enable everything
.with_max_level(tracing::Level::TRACE)
// sets this to be the default, global collector for this application.
.init(); */
let mut p = Processor::new();
let module = p.parse_module(MAIN_MODULE).unwrap()
.expect("failed to find bindings module def");
let out = p.preprocess_module(&module).unwrap();
assert!(out.contains("map_data: LightShadowMapUniform"), "missing `map_data` var: {}", out);
//println!("{out}");
}
#[test]
fn struct_with_attributes() {
const MAIN_MODULE: &'static str =
r#"#define_module engine::main
struct VertexOutput {
@builtin(position) clip_position: vec4<f32>,
@location(0) tex_coords: vec2<f32>,
@location(1) world_position: vec3<f32>,
@location(2) world_normal: vec3<f32>,
@location(3) frag_pos_light_space: vec4<f32>,
}
struct VertexOutput {
@builtin(position)
clip_position: vec4<f32>,
@location(0)
tex_coords: vec2<f32>,
}"#;
/* tracing_subscriber::fmt()
// enable everything
.with_max_level(tracing::Level::TRACE)
// sets this to be the default, global collector for this application.
.init(); */
let mut p = Processor::new();
let module = p.parse_module(MAIN_MODULE).unwrap()
.expect("failed to find bindings module def");
let out = p.preprocess_module(&module).unwrap();
assert!(out.contains("struct VertexOutput"), "missing `VertexOutput` struct: {}", out);
assert!(out.contains("@location(0) tex_coords: vec2<f32>"), "missing `tex_coords` struct field: {}", out);
//println!("{out}");
}
#[test]
fn function_annotations() {
const MAIN_MODULE: &'static str =
r#"#define_module engine::main
@vertex
fn vs_main(
@builtin(vertex_index) in_vertex_index: u32,
) -> VertexOutput {
var out: VertexOutput;
let x = f32(1 - i32(in_vertex_index)) * 0.5;
let y = f32(i32(in_vertex_index & 1u) * 2 - 1) * 0.5;
out.clip_position = vec4<f32>(x, y, 0.0, 1.0);
return out;
}"#;
/* tracing_subscriber::fmt()
// enable everything
.with_max_level(tracing::Level::TRACE)
// sets this to be the default, global collector for this application.
.init(); */
let mut p = Processor::new();
let module = p.parse_module(MAIN_MODULE).unwrap()
.expect("failed to find bindings module def");
let out = p.preprocess_module(&module).unwrap();
assert!(out.contains("@vertex"), "missing `@vertex` annotation: {}", out);
assert!(out.contains("fn vs_main("), "missing `vs_main` entry point: {}", out);
assert!(out.contains("@builtin(vertex_index) in_vertex_index: u32"), "missing `in_vertex_index` function argument: {}", out);
//println!("{out}");
}
#[test]
fn import_binding_type() {
const STRUCTS_MODULE: &'static str =
r#"#define_module engine::shadows::structs
struct TextureAtlasFrame {
x: u32,
y: u32,
width: u32,
height: u32,
}
struct LightShadowMapUniform {
light_space_matrix: mat4x4<f32>,
atlas_frame: TextureAtlasFrame,
near_plane: f32,
far_plane: f32,
light_size_uv: f32,
light_pos: vec3<f32>,
/// boolean casted as u32
has_shadow_settings: u32,
pcf_samples_num: u32,
pcss_blocker_search_samples: u32,
constant_depth_bias: f32,
}"#;
const MAIN_MODULE: &'static str =
r#"#define_module engine::shadows::depth_pass
#import engine::shadows::structs::{LightShadowMapUniform}
struct TransformData {
transform: mat4x4<f32>,
normal_matrix: mat4x4<f32>,
}
@group(0) @binding(0)
var<storage, read> u_light_shadow: array<LightShadowMapUniform>;
@group(1) @binding(0)
var<uniform> u_model_transform_data: TransformData;
struct VertexOutput {
@builtin(position)
clip_position: vec4<f32>,
@location(0) world_pos: vec3<f32>,
@location(1) instance_index: u32,
}
@vertex
fn vs_main(
@location(0) position: vec3<f32>,
@builtin(instance_index) instance_index: u32,
) -> VertexOutput {
let world_pos = u_model_transform_data.transform * vec4<f32>(position, 1.0);
let pos = u_light_shadow[instance_index].light_space_matrix * world_pos;
return VertexOutput(pos, world_pos.xyz, instance_index);
}"#;
/* tracing_subscriber::fmt()
// enable everything
.with_max_level(tracing::Level::TRACE)
// sets this to be the default, global collector for this application.
.init(); */
let mut p = Processor::new();
p.parse_module(STRUCTS_MODULE).unwrap()
.expect("failed to find bindings module def");
let module = p.parse_module(MAIN_MODULE).unwrap()
.expect("failed to find bindings module def");
let out = p.preprocess_module(&module).unwrap();
assert!(out.contains("struct LightShadowMapUniform"), "missing `LightShadowMapUniform` \
struct definition: {}", out);
assert!(out.contains("struct TextureAtlasFrame"), "missing `TextureAtlasFrame` struct \
definition, requirement of `LightShadowMapUniform` struct: {}", out);
assert!(out.contains("u_light_shadow: array<LightShadowMapUniform>"), "missing \
`u_light_shadow` binding: {}", out);
println!("{out}");
}
#[test]
fn array_construct() {
const STRUCTS_MODULE: &'static str =
r#"#define_module engine::shadows::structs
struct TextureAtlasFrame {
x: u32,
y: u32,
width: u32,
height: u32,
}
struct LightShadowMapUniform {
light_space_matrix: mat4x4<f32>,
atlas_frame: TextureAtlasFrame,
near_plane: f32,
far_plane: f32,
light_size_uv: f32,
light_pos: vec3<f32>,
/// boolean casted as u32
has_shadow_settings: u32,
pcf_samples_num: u32,
pcss_blocker_search_samples: u32,
constant_depth_bias: f32,
}"#;
const MAIN_MODULE: &'static str =
r#"#define_module engine::shadows::depth_pass
#import engine::shadows::structs::{LightShadowMapUniform}
fn main() -> f32 {
let uniforms = array<LightShadowMapUniform, 6>(
u_light_shadow[indices[0]],
u_light_shadow[indices[1]],
u_light_shadow[indices[2]],
u_light_shadow[indices[3]],
u_light_shadow[indices[4]],
u_light_shadow[indices[5]]
);
return 1.0;
}"#;
/* tracing_subscriber::fmt()
// enable everything
.with_max_level(tracing::Level::TRACE)
// sets this to be the default, global collector for this application.
.init(); */
let mut p = Processor::new();
p.parse_module(STRUCTS_MODULE).unwrap()
.expect("failed to find bindings module def");
let module = p.parse_module(MAIN_MODULE).unwrap()
.expect("failed to find bindings module def");
let out = p.preprocess_module(&module).unwrap();
assert!(out.contains("struct LightShadowMapUniform"), "missing `LightShadowMapUniform` \
struct definition: {}", out);
assert!(out.contains("struct TextureAtlasFrame"), "missing `TextureAtlasFrame` struct \
definition, requirement of `LightShadowMapUniform` struct: {}", out);
assert!(out.contains("array<LightShadowMapUniform, 6>("), "missing creation of uniform \
array: {}", out);
println!("{out}");
}
#[test]
fn generics_with_two_args() {
const MAIN_MODULE: &'static str =
r#"#define_module engine::shadows::depth_pass
@group(0) @binding(0)
var t_light_grid: texture_storage_2d<rg32uint, read_write>;
fn main() -> f32 {
return 1.0;
}"#;
/* tracing_subscriber::fmt()
// enable everything
.with_max_level(tracing::Level::TRACE)
// sets this to be the default, global collector for this application.
.init(); */
let mut p = Processor::new();
let module = p.parse_module(MAIN_MODULE).unwrap()
.expect("failed to find bindings module def");
let out = p.preprocess_module(&module).unwrap();
assert!(out.contains("t_light_grid: texture_storage_2d<rg32uint, read_write>"), "missing `t_light_grid` \
binding: {}", out);
println!("{out}");
}
#[test]
fn function_return_attributes() {
const MAIN_MODULE: &'static str =
r#"#define_module engine::shadows::depth_pass
fn main() -> @location(0) vec4<f32> {
return vec4<f32>(1.0);
}"#;
/* tracing_subscriber::fmt()
// enable everything
.with_max_level(tracing::Level::TRACE)
// sets this to be the default, global collector for this application.
.init(); */
let mut p = Processor::new();
let module = p.parse_module(MAIN_MODULE).unwrap()
.expect("failed to find bindings module def");
let out = p.preprocess_module(&module).unwrap();
assert!(out.contains("@location(0) vec4<f32>"),
"missing return value with attribute: {}", out);
println!("{out}");
}
#[test]
fn import_binding_use() {
const STRUCTS_MODULE: &'static str =
r#"#define_module engine::shadows::structs
struct TextureAtlasFrame {
x: u32,
y: u32,
width: u32,
height: u32,
}
struct LightShadowMapUniform {
light_space_matrix: mat4x4<f32>,
atlas_frame: TextureAtlasFrame,
near_plane: f32,
far_plane: f32,
light_size_uv: f32,
light_pos: vec3<f32>,
/// boolean casted as u32
has_shadow_settings: u32,
pcf_samples_num: u32,
pcss_blocker_search_samples: u32,
constant_depth_bias: f32,
}
@group(0) @binding(0)
var<storage, read> u_light_shadow: array<LightShadowMapUniform>;"#;
const MAIN_MODULE: &'static str =
r#"#define_module engine::shadows::depth_pass
#import engine::shadows::structs::{u_light_shadow}
fn main() -> @location(0) vec4<f32> {
let shadow_u: LightShadowMapUniform = u_light_shadow[light.light_shadow_uniform_index[0]];
return vec4<f32>(1.0);
}"#;
/* tracing_subscriber::fmt()
// enable everything
.with_max_level(tracing::Level::TRACE)
// sets this to be the default, global collector for this application.
.init(); */
let mut p = Processor::new();
p.parse_module(STRUCTS_MODULE).unwrap()
.expect("failed to find bindings module def");
let module = p.parse_module(MAIN_MODULE).unwrap()
.expect("failed to find bindings module def");
let out = p.preprocess_module(&module).unwrap();
assert!(out.contains("u_light_shadow: array<LightShadowMapUniform>"),
"missing `u_light_shadow` binding:\n{}", out);
println!("{out}");
}
/// Tests using the same import multiple times
#[test]
fn multi_import_use() {
const STRUCTS_MODULE: &'static str =
r#"#define_module engine::shadows::structs
struct TextureAtlasFrame {
x: u32,
y: u32,
width: u32,
height: u32,
}
struct LightShadowMapUniform {
light_space_matrix: mat4x4<f32>,
atlas_frame: TextureAtlasFrame,
near_plane: f32,
far_plane: f32,
light_size_uv: f32,
light_pos: vec3<f32>,
/// boolean casted as u32
has_shadow_settings: u32,
pcf_samples_num: u32,
pcss_blocker_search_samples: u32,
constant_depth_bias: f32,
}
@group(0) @binding(0)
var<storage, read> u_light_shadow: array<LightShadowMapUniform>;"#;
const FN_MODULE: &'static str =
r#"#define_module engine::shadows::funcs
#import engine::shadows::structs::{LightShadowMapUniform, u_light_shadow}
fn do_light_thing(u: LightShadowMapUniform) -> f32 {
let s = u_light_shadow[light.light_shadow_uniform_index[0]];
return 1.0;
}"#;
const MAIN_MODULE: &'static str =
r#"#define_module engine::shadows::depth_pass
#import engine::shadows::structs::{u_light_shadow}
#import engine::shadows::funcs::{do_light_thing}
fn main() -> @location(0) vec4<f32> {
let shadow_u: LightShadowMapUniform = u_light_shadow[light.light_shadow_uniform_index[0]];
let f = do_light_thing(shadow_u);
return vec4<f32>(f);
}"#;
/* tracing_subscriber::fmt()
// enable everything
.with_max_level(tracing::Level::TRACE)
// sets this to be the default, global collector for this application.
.init(); */
let mut p = Processor::new();
p.parse_module(STRUCTS_MODULE).unwrap()
.expect("failed to find bindings module def");
p.parse_module(FN_MODULE).unwrap()
.expect("failed to find bindings module def");
let module = p.parse_module(MAIN_MODULE).unwrap()
.expect("failed to find bindings module def");
let out = p.preprocess_module(&module).unwrap();
assert!(out.contains("u_light_shadow: array<LightShadowMapUniform>"),
"missing `u_light_shadow` binding:\n{}", out);
let v: Vec<_> = out.match_indices("struct LightShadowMapUniform").collect();
assert_eq!(v.len(), 1,
"expected one `LightShadowMapUniform` struct definition:\n{}", out);
println!("{out}");
}
} }

View File

@ -1,10 +1,10 @@
use std::{collections::{HashMap, HashSet, VecDeque}, fmt::Write, fs, path::Path}; use std::{collections::{HashMap, HashSet}, fmt::Write, fs, path::Path};
use pest::{iterators::Pair, Parser}; use pest::{iterators::Pair, Parser};
use regex::Regex; use regex::Regex;
use tracing::{debug, debug_span, instrument, trace}; use tracing::{debug, debug_span, instrument, trace};
use crate::{recurse_files, DefRequirement, Definition, Import, Module, Rule, WgslParser}; use crate::{recurse_files, DefRequirement, Definition, Import, ImportItemFrom, Module, Rule, WgslParser};
const RESERVED_WORDS: [&str; 201] = [ const RESERVED_WORDS: [&str; 201] = [
"NULL", "Self", "abstract", "active", "alignas", "alignof", "as", "asm", "asm_fragment", "NULL", "Self", "abstract", "active", "alignas", "alignof", "as", "asm", "asm_fragment",
@ -48,59 +48,81 @@ const RESERVED_WORDS: [&str; 201] = [
]; ];
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum PreprocessorError { pub enum Error {
#[error("{0}")] #[error("{0}")]
IoError(#[from] std::io::Error), IoError(#[from] std::io::Error),
#[error("error parsing {0}")] #[error("error parsing {0}")]
ParserError(#[from] pest::error::Error<Rule>), ParserError(#[from] pest::error::Error<Rule>),
#[error("failure formatting preprocessor output to string ({0})")] #[error("failure formatting preprocessor output to string ({0})")]
FormatError(#[from] std::fmt::Error), FormatError(#[from] std::fmt::Error),
#[error("could not find module from path '{0}'")]
UnknownModulePath(String),
#[error("unknown module import '{module}', in {from_module}")] #[error("unknown module import '{module}', in {from_module}")]
UnknownModule { from_module: String, module: String }, UnknownModuleImport { from_module: String, module: String },
#[error("in {from_module}: unknown import from '{module}': `{item}`")] #[error("in {from_module}: unknown import from '{module}': `{item}`")]
UnknownImport { from_module: String, module: String, item: String }, UnknownItemImport { from_module: String, module: String, item: String },
#[error("import usage from `{from_module}` conflicts with local variable/function: `{name}`")] #[error("import usage from `{from_module}` conflicts with local variable/function: `{name}`")]
ConflictingImport { from_module: String, name: String }, ConflictingImport { from_module: String, name: String },
} }
fn get_external_var_requirement(module: &Module, pair: Pair<Rule>) -> Option<DefRequirement> { fn get_external_var_requirement(module: &mut Module, pair: Pair<Rule>) -> Option<DefRequirement> {
let var_name = pair.as_str(); let var_name = pair.as_str();
if pair.as_rule() == Rule::shader_external_variable { let r = pair.as_rule();
let (module_name, name) = var_name.rsplit_once("::").unwrap(); match r {
Rule::module_path => {
let (module_name, name) = var_name.rsplit_once("::").unwrap();
if RESERVED_WORDS.contains(&name) { if RESERVED_WORDS.contains(&name) {
return None;
}
// Find the module that this variable is from.
// Starts by checking if the full module path was used.
// If it was not, then find the full path of the module.
let module_full_name = add_external_as_import(module, module_name, name);
debug!("Binding module is {:?}", module_full_name);
Some(DefRequirement {
module: module_full_name,
name: name.to_string(),
})
},
Rule::type_generic => {
// generics are always builtins
return None; return None;
} },
Rule::type_array => {
let mut inner = pair.into_inner();
let inner = inner.next().unwrap();
assert_eq!(inner.as_rule(), Rule::r#type);
// inner is `Rule::type`
get_external_var_requirement(module, inner)
},
Rule::ident => {
if RESERVED_WORDS.contains(&var_name) {
return None;
}
// Find the module that this variable is from. Some(DefRequirement {
// Starts by checking if the full module path was used. module: None,
// If it was not, then find the full path of the module. name: var_name.to_string(),
let module_full_name = if module.module_imports.contains(module_name) { })
Some(module_name.to_string()) },
} else { Rule::r#type => {
module.module_imports.iter() let mut inner = pair.into_inner();
.find(|m| { let inner = inner.next().unwrap();
m.ends_with(module_name) get_external_var_requirement(module, inner)
}).cloned() },
}; _ => unimplemented!("ran into unhandled rule: {:?}, {:?}", r, pair.as_str())
debug!("Binding module is {:?}", module_full_name);
Some(DefRequirement {
module: module_full_name,
name: name.to_string(),
})
} else { // The only possibility is `Rule::shader_type`
if RESERVED_WORDS.contains(&var_name) {
return None;
}
Some(DefRequirement {
module: None,
name: var_name.to_string(),
})
} }
/* if pair.as_rule() == Rule::module_path {
} else { // The only possibility is `Rule::shader_type`
} */
} }
#[derive(Default)] #[derive(Default)]
@ -157,15 +179,16 @@ impl Processor {
let (fn_module, fn_name) = fn_path.rsplit_once("::").unwrap(); let (fn_module, fn_name) = fn_path.rsplit_once("::").unwrap();
debug!("Found call to `{}::{}`", fn_module, fn_name); debug!("Found call to external `{}::{}`", fn_module, fn_name);
let full_module = add_external_as_import(module, fn_module, fn_name);
let req = DefRequirement { let req = DefRequirement {
module: Some(fn_module.into()), module: full_module,
name: fn_name.to_string(), name: fn_name.to_string(),
}; };
requirements.push(req); requirements.push(req);
}, },
Rule::shader_external_variable => { Rule::module_path => {
let pairs = code.into_inner(); let pairs = code.into_inner();
// shader_external_variable is the only pair for this rule // shader_external_variable is the only pair for this rule
let ident_path = pairs.as_str(); let ident_path = pairs.as_str();
@ -177,23 +200,28 @@ impl Processor {
let (ident_module, ident_name) = ident_path.rsplit_once("::").unwrap(); let (ident_module, ident_name) = ident_path.rsplit_once("::").unwrap();
debug!("Found call to `{}`", ident_name);
// ignore reserved words // ignore reserved words
if RESERVED_WORDS.contains(&ident_name) { if RESERVED_WORDS.contains(&ident_name) {
continue; continue;
} }
let module_full = add_external_as_import(module, ident_module, ident_name);
let req = DefRequirement { let req = DefRequirement {
module: Some(ident_module.into()), module: module_full,
name: ident_name.to_string(), name: ident_name.to_string(),
}; };
requirements.push(req); requirements.push(req);
}, },
Rule::type_array | Rule::r#type => {
let req = self.get_imports_in_block(module, code, found_requirements);
requirements.extend(req.into_iter());
},
Rule::newline => (), Rule::newline => (),
Rule::cws => (), Rule::cws => (),
Rule::number => (),
Rule::shader_code_char => (), Rule::shader_code_char => (),
Rule::shader_ident => { Rule::ident => {
let ident = code.as_str(); let ident = code.as_str();
if found_requirements.contains(ident) { if found_requirements.contains(ident) {
@ -213,6 +241,8 @@ impl Processor {
}; };
requirements.push(req); requirements.push(req);
}, },
// generic types are only built-in types
Rule::type_generic => (),
Rule::shader_value => (), Rule::shader_value => (),
_ => unimplemented!("ran into unhandled rule: {:?}, {:?}", code.as_rule(), code.as_span()) _ => unimplemented!("ran into unhandled rule: {:?}, {:?}", code.as_rule(), code.as_span())
} }
@ -221,6 +251,45 @@ impl Processor {
requirements requirements
} }
/// Get the requirement of a type defined as a pair.
///
/// Returns `None` when the type is built-in
fn get_type_requirement(module: &mut Module, mut type_pair: Pair<'_, Rule>) -> Option<DefRequirement> {
let mut r = type_pair.as_rule();
// the inner rule of `shader_var_type_array` is
// `shader_external_variable` or `shader_type`
while r == Rule::type_array {
let mut inner = type_pair.into_inner();
type_pair = inner.next().unwrap();
r = type_pair.as_rule();
}
if r == Rule::module_path {
let external_str = type_pair.as_str();
let (module_name, name) = external_str.rsplit_once("::").unwrap();
if RESERVED_WORDS.contains(&name) {
return None;
}
let module_full_name = add_external_as_import(module, module_name, name);
Some(DefRequirement {
module: module_full_name,
name: name.to_string(),
})
} else { // The only possibility is `Rule::shader_type`
let name = type_pair.as_str();
Some(DefRequirement {
module: None,
name: name.to_string(),
})
}
}
/// Parse a module file to attempt to find the include identifier. /// Parse a module file to attempt to find the include identifier.
/// ///
/// Returns `None` if the module does not define an include identifier. /// Returns `None` if the module does not define an include identifier.
@ -228,7 +297,7 @@ impl Processor {
pub fn parse_module( pub fn parse_module(
&mut self, &mut self,
module_src: &str, module_src: &str,
) -> Result<Option<String>, PreprocessorError> { ) -> Result<Option<String>, Error> {
let e = debug_span!("parse_module", module = tracing::field::Empty).entered(); let e = debug_span!("parse_module", module = tracing::field::Empty).entered();
let module_src = remove_comments(module_src); let module_src = remove_comments(module_src);
@ -293,10 +362,20 @@ impl Processor {
Rule::shader_fn_def => { Rule::shader_fn_def => {
let mut pairs = line.clone().into_inner(); let mut pairs = line.clone().into_inner();
// shader_ident is the only pair for this rule // shader_ident is the only pair for this rule
let fn_name = pairs.next().unwrap().as_str().to_string(); let fn_name = {
let next = pairs.next().unwrap();
let r = next.as_rule();
if r == Rule::shader_fn_attribute {
pairs.next().unwrap()
} else {
next
}
};
let fn_name = fn_name.as_str().to_string();
debug!("Found function def: {fn_name}"); debug!("Found function def: {fn_name}");
let fn_body = pairs.skip(2).next().unwrap(); let fn_body = pairs.last().unwrap();//pairs.skip(3).next().unwrap();
let mut found_reqs = HashSet::default(); let mut found_reqs = HashSet::default();
let requirements = self.get_imports_in_block(&mut module, fn_body, &mut found_reqs); let requirements = self.get_imports_in_block(&mut module, fn_body, &mut found_reqs);
@ -324,54 +403,20 @@ impl Processor {
if name.as_rule() == Rule::shader_binding_var_constraint { if name.as_rule() == Rule::shader_binding_var_constraint {
trace!("Skipping binding constraint"); trace!("Skipping binding constraint");
name = pairs.next().unwrap(); name = pairs.next().unwrap();
assert_eq!(name.as_rule(), Rule::shader_ident); assert_eq!(name.as_rule(), Rule::ident);
} }
let name = name.as_str(); let name = name.as_str();
let var_type = pairs.next().unwrap(); let var_type = pairs.next().unwrap();
let mut var_type_inner = var_type.into_inner();
let inner_type = var_type_inner.next().unwrap();
let mut requirements = vec![]; let requirement_vec = Self::get_type_requirement(&mut module, inner_type)
{ .map(|r| vec![r])
let mut vt_inner = var_type.into_inner(); .unwrap_or_default();
let inner_type = vt_inner.next().unwrap();
if inner_type.as_rule() == Rule::shader_external_variable { let req_names = requirement_vec.first().map(|r| r.name.clone());
let external_str = inner_type.as_str(); debug!("Found binding def: `{}` with requirement: {:?}", name, req_names);
let (module_name, name) = external_str.rsplit_once("::").unwrap();
if RESERVED_WORDS.contains(&name) {
continue;
}
// Find the module that this variable is from.
// Starts by checking if the full module path was used.
// If it was not, then find the full path of the module.
let module_full_name = if module.module_imports.contains(module_name) {
Some(module_name.to_string())
} else {
module.module_imports.iter()
.find(|m| {
m.ends_with(module_name)
}).cloned()
};
debug!("Binding module is {:?}", module_full_name);
requirements.push(DefRequirement {
module: module_full_name,
name: name.to_string(),
});
} else { // The only possibility is `Rule::shader_type`
let name = inner_type.as_str();
requirements.push(DefRequirement {
module: None,
name: name.to_string(),
});
}
}
debug!("Found binding def: `{name}`");
let line_span = line.as_span(); let line_span = line.as_span();
let start_pos = line_span.start(); let start_pos = line_span.start();
@ -383,7 +428,7 @@ impl Processor {
name: name.to_string(), name: name.to_string(),
start_pos, start_pos,
end_pos, end_pos,
requirements: vec![], requirements: requirement_vec,
}, },
); );
}, },
@ -393,6 +438,19 @@ impl Processor {
let const_name = pairs.next().unwrap().as_str().to_string(); let const_name = pairs.next().unwrap().as_str().to_string();
debug!("Found const def: `{const_name}`"); debug!("Found const def: `{const_name}`");
let mut requirements = vec![];
// if the type of the const was specified, find the requirement
let next = pairs.next().unwrap();
if next.as_rule() == Rule::r#type {
let mut next_inner = next.into_inner();
let var_type = next_inner.next().unwrap();
if let Some(req) = Self::get_type_requirement(&mut module, var_type) {
requirements.push(req);
}
}
let line_span = line.as_span(); let line_span = line.as_span();
let start_pos = line_span.start(); let start_pos = line_span.start();
let end_pos = line_span.end(); let end_pos = line_span.end();
@ -403,7 +461,7 @@ impl Processor {
name: const_name, name: const_name,
start_pos, start_pos,
end_pos, end_pos,
requirements: vec![], requirements,
}, },
); );
}, },
@ -431,19 +489,46 @@ impl Processor {
// iterate through fields in struct // iterate through fields in struct
let mut requirements = vec![]; let mut requirements = vec![];
while let (Some(_), Some(field_ty)) = (struct_inner.next(), struct_inner.next()) { loop {
//panic!("got {} of type {}", ident.as_str(), field_ty.as_str()); let (_field_attr, _field_name) = {
let mut temp = struct_inner.next();
// ignore whitespace
while let Some(t) = &temp {
if t.as_rule() == Rule::cws {
temp = struct_inner.next();
} else {
break;
}
}
if let Some(next) = temp {
let r = next.as_rule();
if r == Rule::shader_struct_field_attribute {
(Some(next), struct_inner.next().unwrap())
} else {
(None, next)
}
} else {
break;
}
};
//todo!("field name is: {:?}", _field_name.as_str());
let field_ty = struct_inner.next().unwrap();
let field_ty_rule = field_ty.as_rule();
let mut ty_inner = field_ty.into_inner(); let mut ty_inner = field_ty.into_inner();
let ty_inner = ty_inner.next().unwrap(); let ty_inner = ty_inner.next().unwrap();
//todo!("ty inner is: {:?}", ty_inner.as_str());
match ty_inner.as_rule() {
Rule::shader_external_variable | Rule::shader_type => { match field_ty_rule {
//requirements.push(value) Rule::r#type => {
if let Some(req) = get_external_var_requirement(&module, ty_inner) { if let Some(req) = get_external_var_requirement(&mut module, ty_inner) {
requirements.push(req); requirements.push(req);
} }
}, },
_ => unimplemented!("ran into unhandled rule: ({:?}) {:?}", line.as_rule(), line.as_span()) _ => unimplemented!("ran into unhandled rule: ({:?}) {:?}", field_ty_rule, ty_inner.as_str())
} }
} }
@ -467,7 +552,7 @@ impl Processor {
let alias_to = inner.next().unwrap(); let alias_to = inner.next().unwrap();
// get the requirements of the alias if there are any // get the requirements of the alias if there are any
let requirements = if let Some(req) = get_external_var_requirement(&module, alias_to) { let requirements = if let Some(req) = get_external_var_requirement(&mut module, alias_to) {
debug!("Type req: {}", req.name); debug!("Type req: {}", req.name);
vec![req] vec![req]
} else { vec![] }; } else { vec![] };
@ -521,7 +606,7 @@ impl Processor {
&mut self, &mut self,
path: P, path: P,
extensions: [&str; N], extensions: [&str; N],
) -> Result<usize, PreprocessorError> { ) -> Result<usize, Error> {
let files = recurse_files(path)?; let files = recurse_files(path)?;
let mut parsed = 0; let mut parsed = 0;
@ -541,18 +626,93 @@ impl Processor {
} }
#[instrument(skip(self))] #[instrument(skip(self))]
fn generate_header(&mut self, module_path: &str) -> Result<String, PreprocessorError> { fn generate_header(&mut self, module_path: &str) -> Result<String, Error> {
let module = self.modules.get(module_path).unwrap(); let module = self.modules.get(module_path).unwrap();
let mut output = String::new(); let mut output = String::new();
let mut compiled = HashSet::new(); let mut compiled = HashSet::new();
compile_definitions(&self.modules, module, &mut compiled, &mut output)?; let mut compiled_imports = HashSet::new();
/* for (module_name, import) in &module.item_imports {
for item in &import.imports {
let item_mod = self.modules.get(&import.module)
.ok_or(Error::UnknownItemImport { from_module: module_path.into(), module: import.module.clone(), item: item.clone() })?;
let item_def = item_mod.get_definition(item)
.ok_or(Error::UnknownItemImport { from_module: module_path.into(), module: import.module.clone(), item: item.clone() })?;
compile_requirements(self.modules, item_mod, &mut compiled, item_def, &mut output)?;
let func_src = &item_mod.src[item_def.start_pos..item_def.end_pos];
output.write_fmt(format_args!("// SOURCE {}::{} (import)\n", module_name, item_def.name))?;
output.push_str(func_src);
output.push_str("\n");
}
} */
compile_imports(&self.modules, module, &mut compiled_imports, &mut compiled, &mut output)?;
//compile_definitions(&self.modules, module, &mut compiled, &mut output)?;
Ok(output) Ok(output)
} }
#[instrument(skip(self, type_, output))]
fn output_type(&self, type_: Pair<Rule>, output: &mut String) -> Result<(), std::fmt::Error> {
let r = type_.as_rule();
match r {
Rule::type_array => {
let mut inner = type_.into_inner();
let array_ty = inner.next().unwrap();
let array_size = inner.next()
.map(|p| format!(", {}", p.as_str()))
.unwrap_or_default();
output.write_str("array<")?;
self.output_type(array_ty, output)?;
output.write_str(&array_size)?;
output.write_str(">")?;
},
Rule::module_path => {
let pairs = type_.clone().into_inner();
// the last pair is the name of the variable
let ident_name = pairs.last().unwrap().as_str();
output.write_str(ident_name)?;
},
Rule::type_generic => {
let mut pairs = type_.clone().into_inner();
// the last pair is the name of the variable
let ident_name = pairs.next().unwrap().as_str();
output.write_str(ident_name)?;
output.write_str("<")?;
self.output_type(pairs.next().unwrap(), output)?;
// if there's a second generic field, output it as well
if let Some(second_generic) = pairs.next() {
output.write_str(", ")?;
self.output_type(second_generic, output)?;
}
output.write_str(">")?;
},
Rule::ident => {
output.write_str(type_.as_str())?;
},
Rule::r#type => {
let mut pairs = type_.into_inner();
let ty = pairs.next().unwrap();
self.output_type(ty, output)?;
},
_ => unimplemented!("ran into unhandled rule: {:?}, {:?}", r, type_.as_str()),
}
Ok(())
}
#[instrument(skip(self, shader_code_rule, output))] #[instrument(skip(self, shader_code_rule, output))]
fn output_shader_code_line(&self, shader_code_rule: Pair<Rule>, output: &mut String) -> Result<(), std::fmt::Error> { fn output_shader_code_line(&self, shader_code_rule: Pair<Rule>, output: &mut String) -> Result<(), std::fmt::Error> {
for line in shader_code_rule.into_inner() { for line in shader_code_rule.into_inner() {
//debug!("output rule: {:?}, '{}'", line.as_rule(), line.as_str());
match line.as_rule() { match line.as_rule() {
Rule::shader_external_fn => { Rule::shader_external_fn => {
let mut pairs = line.clone().into_inner(); let mut pairs = line.clone().into_inner();
@ -569,33 +729,139 @@ impl Processor {
panic!("Unknown error, rule input: {}", line.as_str()); panic!("Unknown error, rule input: {}", line.as_str());
} }
}, },
Rule::shader_external_variable => { Rule::type_array | Rule::module_path => {
let pairs = line.clone().into_inner(); self.output_type(line, output)?;
// the last pair is the name of the variable
let ident_name = pairs.last().unwrap().as_str();
output.write_str(ident_name)?;
}, },
//shader_external_variable
Rule::shader_fn_def => { Rule::shader_fn_def => {
let mut rule_inner = line.into_inner(); let mut rule_inner = line.into_inner();
let fn_name = rule_inner.next().unwrap().as_str(); // attempt to get the optional attribute, right after it is the name
let args = rule_inner.next().unwrap().as_str(); let mut attribute = None;
let fn_ret = rule_inner.next().unwrap().as_str(); let fn_name = {
// get the fn name, s
let next = rule_inner.next().unwrap();
let r = next.as_rule();
if r == Rule::shader_fn_attribute {
attribute = Some(next);
rule_inner.next().unwrap()
} else {
next
}
};
let attribute = attribute
.map(|a| a.as_str().to_string())
.unwrap_or_default();
let fn_name = fn_name.as_str();
let args = rule_inner.next().unwrap().as_str(); // TODO: external types in args
let mut fn_ret_attr = "";
let fn_ret = {
let next = rule_inner.next().unwrap();
let r = next.as_rule();
if r == Rule::shader_fn_attribute {
fn_ret_attr = next.as_str();
rule_inner.next().unwrap().as_str()
} else {
next.as_str()
}
};
let fn_body = rule_inner.next().unwrap(); let fn_body = rule_inner.next().unwrap();
let mut body_output = String::new(); let mut body_output = String::new();
self.output_shader_code_line(fn_body, &mut body_output)?; self.output_shader_code_line(fn_body, &mut body_output)?;
// to escape {, must use two // to escape {, must use two
output.write_fmt(format_args!("fn {}{} -> {} {{{}}}", output.write_fmt(format_args!("{}fn {}{} -> {}{} {{{}}}",
fn_name, args, fn_ret, body_output))?; attribute, fn_name, args, fn_ret_attr, fn_ret, body_output))?;
},
Rule::shader_code_block => {
output.write_str("{")?;
// this call will run line.into_inner()
self.output_shader_code_line(line, output)?;
output.write_str("}")?;
},
Rule::shader_struct_def => {
let mut rule_inner = line.into_inner();
let ident = rule_inner.next().unwrap().as_str();
output.write_fmt(format_args!("struct {} {{\n", ident))?;
// process fields
loop {
let mut space_count = 0;
let (field_attr, field_name) = {
let mut temp = rule_inner.next();
// find the white space amount for the field
while let Some(t) = &temp {
if t.as_rule() == Rule::cws {
space_count += 1;
temp = rule_inner.next();
} else {
// stop when a non whitespace rule is found
break;
}
}
if let Some(next) = temp {
let r = next.as_rule();
if r == Rule::shader_struct_field_attribute {
(Some(next), rule_inner.next().unwrap())
} else {
(None, next)
}
} else {
break;
}
};
// write indentation
output.write_str(&" ".repeat(space_count))?;
// write field attribute with space to separate from field name
if let Some(field_attr) = field_attr {
output.write_str(field_attr.as_str())?;
output.write_str(" ")?;
}
output.write_str(field_name.as_str())?;
output.write_str(": ")?;
// get field type inner
let field_ty = rule_inner.next().unwrap();
self.output_type(field_ty, output)?;
output.write_str(",\n")?;
}
output.write_str("}")?;
},
Rule::shader_binding_def => {
let rule_inner = line.clone().into_inner();
// the last pair is the type of the binding
let var_type = rule_inner.last().unwrap();
let var_type_span = var_type.as_span();
// this captures everything up to the type of the var
let pairs_start = line.as_span().start();
let var_type_start = var_type_span.start();
let in_end = var_type_start - pairs_start;
let mut var_def = line.as_str()[..in_end].to_string();
// this recursive call will output the code for the binding type
self.output_type(var_type, &mut var_def)?;
output.write_str(&var_def)?;
output.write_str(";")?;
}, },
Rule::shader_code => { Rule::shader_code => {
self.output_shader_code_line(line, output)?; self.output_shader_code_line(line, output)?;
}, },
Rule::shader_code_fn_usage | Rule::shader_value | Rule::shader_const_def | Rule::shader_ident | Rule::shader_code_char | Rule::cws => { Rule::shader_code_fn_usage | Rule::shader_value | Rule::r#type |
Rule::shader_const_def | Rule::ident | Rule::shader_code_char |
Rule::cws | Rule::type_generic => {
let input = line.as_str(); let input = line.as_str();
output.write_str(&input)?; output.write_str(&input)?;
}, },
@ -609,24 +875,26 @@ impl Processor {
Ok(()) Ok(())
} }
#[instrument(skip(self, module_src))] #[instrument(skip(self))]
pub fn process_file(&mut self, module_path: &str, module_src: &str) -> Result<String, PreprocessorError> { pub fn preprocess_module(&mut self, module_path: &str) -> Result<String, Error> {
let module_src = remove_comments(module_src); let module = self.modules.get(module_path)
.ok_or(Error::UnknownModulePath(module_path.to_string()))?;
let module_src = module.src.clone();
let module_pair = WgslParser::parse(Rule::file, &module_src)?
.next()
.unwrap(); // get and unwrap the `file` rule; never fails
let mut out_string = String::new(); let mut out_string = String::new();
// the output will be at least the length of module_src // the output will be at least the length of module_src
out_string.reserve(module_src.len()); out_string.reserve(module_src.len());
let file = WgslParser::parse(Rule::file, &module_src)?
.next()
.unwrap(); // get and unwrap the `file` rule; never fails
let header = self.generate_header(module_path)?; let header = self.generate_header(module_path)?;
out_string.write_str("// START OF IMPORT HEADER\n")?; out_string.write_str("// START OF IMPORT HEADER\n")?;
out_string.write_str(&header)?; out_string.write_str(&header)?;
out_string.write_str("// END OF IMPORT HEADER\n")?; out_string.write_str("// END OF IMPORT HEADER\n")?;
for record in file.into_inner() { for record in module_pair.into_inner() {
match record.as_rule() { match record.as_rule() {
Rule::command_line => { Rule::command_line => {
// the parser has found a preprocessor command, figure out what it is // the parser has found a preprocessor command, figure out what it is
@ -658,14 +926,10 @@ impl Processor {
} }
fn try_find_requirement_module(module: &Module, req_name: &str) -> Option<String> { fn try_find_requirement_module(module: &Module, req_name: &str) -> Option<String> {
debug!("looking for '{}' in module '{}' that imports: {:?}", req_name, module.name, module.item_imports.values()); trace!("looking for '{}' in module '{}' that imports: {:?}", req_name, module.name, module.item_imports.values());
//debug!("Module: {}, req: {}", module.name, req_name);
for import in module.item_imports.values() { for import in module.item_imports.values() {
//debug!("Import: {:?}", import);
if import.imports.contains(&req_name.to_string()) { if import.imports.contains(&req_name.to_string()) {
//debug!("has it");
return Some(import.module.clone()); return Some(import.module.clone());
} }
} }
@ -673,90 +937,135 @@ fn try_find_requirement_module(module: &Module, req_name: &str) -> Option<String
None None
} }
//#[instrument(fields(module = module.name, require=tracing::field::Empty), skip_all)] #[inline(always)]
fn compile_definitions(modules: &HashMap<String, Module>, module: &Module, compiled_definitions: &mut HashSet<DefRequirement>, output: &mut String) -> Result<(), PreprocessorError> { fn is_import_compiled(compiled_imports: &mut HashSet<ImportItemFrom>, import_mod: &str, import_name: &str) -> bool {
let e = debug_span!("compile_definitions", module = module.name, require = tracing::field::Empty, sub_require = tracing::field::Empty).entered(); let compiled_import = ImportItemFrom {
module: import_mod.to_string(),
import: import_name.to_string(),
};
let defs = module.functions.iter().chain(module.vars.iter()).chain(module.structs.iter()).chain(module.aliases.iter()); !compiled_imports.insert(compiled_import)
for (_, funcs) in defs { //&module.functions { }
let mut requirements = VecDeque::from(funcs.requirements.clone());
#[inline(always)]
fn is_def_compiled(compiled_definitions: &mut HashSet<DefRequirement>, def_module: &str, def_name: &str) -> bool {
let item_def_req = DefRequirement {
module: Some(def_module.to_string()),
name: def_name.to_string(),
};
!compiled_definitions.insert(item_def_req)
}
fn compile_imports(modules: &HashMap<String, Module>, module: &Module, compiled_imports: &mut HashSet<ImportItemFrom>, compiled_definitions: &mut HashSet<DefRequirement>, output: &mut String) -> Result<(), Error> {
let module_path = &module.name;
for (module_name, import) in &module.item_imports {
let item_mod = modules.get(module_name)
.ok_or(Error::UnknownModulePath(module_name.clone()))?;
compile_imports(modules, item_mod, compiled_imports, compiled_definitions, output)?;
for item in &import.imports {
if is_import_compiled(compiled_imports, &import.module, item) {
continue;
} else if is_def_compiled(compiled_definitions, &module_name, item) {
continue;
}
let item_def = item_mod.get_definition(item)
.ok_or(Error::UnknownItemImport { from_module: module_path.into(), module: import.module.clone(), item: item.clone() })?;
compile_requirements(modules, item_mod, compiled_definitions, item_def, output)?;
let func_src = &item_mod.src[item_def.start_pos..item_def.end_pos];
output.write_fmt(format_args!("// SOURCE {}::{} (import)\n", module_name, item_def.name))?;
output.push_str(func_src);
output.push_str("\n");
}
}
Ok(())
}
fn compile_requirements(modules: &HashMap<String, Module>, module: &Module, compiled_definitions: &mut HashSet<DefRequirement>, def: &Definition, output: &mut String) -> Result<(), Error> {
let e = debug_span!("compile_requirements", module = module.name,
require = def.name, sub_require = tracing::field::Empty).entered();
trace!("Compiling {} requirements of {}", def.requirements.len(), def.name);
for req in &def.requirements {
e.record("sub_require", req.name.clone());
let req_module = if module.get_definition(&req.name).is_some() {
// if the requirement is local, use the current module
Some(Ok(module))
} else {
req.module.clone()
// try to find the module name the requirement is in using the imported module
// if it was not supplied
.or_else(|| try_find_requirement_module(&module, &req.name))
// get the module, panic if the name was incorrect.
.map(|m| modules.get(&m)
.ok_or(Error::UnknownModuleImport { from_module: module.name.clone(), module: m })
//.unwrap_or_else(|| panic!("invalid module import: {}", m))
)
};
while let Some(mut req) = requirements.pop_front() { if let Some(req_module) = req_module {
e.record("require", req.name.clone()); let req_module = req_module?;
let module_name = &req_module.name;
if req.module.is_none() { let mut r = req.clone();
let mod_name = try_find_requirement_module(&module, &req.name); r.module = Some(module_name.clone());
req.module = mod_name; if compiled_definitions.contains(&r) {
debug!("req module: {:?}", req.module); debug!("Definition ({}) already compiled, skipping...", req.name);
continue;
} }
compiled_definitions.insert(r);
let req_def = req_module.get_definition(&req.name)
.unwrap_or_else(|| panic!("invalid requirement of {}::{}: '{}'", module_name, def.name, req.name));
if !req_def.requirements.is_empty() {
output.write_fmt(format_args!("\n// REQUIREMENTS OF {}::{}\n", module_name, req_def.name))?;
if let Some(module_name) = &req.module { for mut sub in req_def.requirements.clone() {
let req_module = modules.get(module_name) if sub.module.is_none() {
.unwrap_or_else(|| panic!("invalid module import: {}", module_name)); sub.module = try_find_requirement_module(&req_module, &sub.name);
// get the definition from the module that defines the import
let req_def = req_module.get_definition(&req.name)
.unwrap_or_else(|| panic!("invalid import: {} from {}", req.name, module_name));
if !req_def.requirements.is_empty() {
for mut sub in req_def.requirements.clone() {
e.record("sub_require", sub.name.clone());
println!("searching for sub requirement module in module: '{}'", req_module.name);
if sub.module.is_none() {
sub.module = try_find_requirement_module(&req_module, &sub.name);
}
if let Some(sub_mod) = &sub.module {
if !compiled_definitions.contains(&sub) {
compiled_definitions.insert(sub.clone());
let sub_mod = modules.get(sub_mod).expect("unknown module from import");
debug!("found sub requirement module: {}", sub_mod.name);
let module_name = &sub_mod.name;
let mut requirements_output = String::new();
compile_definitions(modules, sub_mod, compiled_definitions, &mut requirements_output)?;
output.write_fmt(format_args!("\n// REQUIREMENTS OF {}::{}\n", module_name, req.name))?;
output.push_str(&requirements_output);
output.push_str("\n");
// find the definition of the sub requirement
let sub_def = sub_mod.get_definition(&sub.name)
.unwrap_or_else(|| panic!("invalid import: {} from {}", req.name, module_name));
// write the source of the sub requirement to the output
let req_src = &sub_mod.src[sub_def.start_pos..sub_def.end_pos];
output.write_fmt(format_args!("// SOURCE {}::{}\n", module_name, sub.name))?;
output.push_str(req_src);
output.push_str("\n");
} else {
debug!("Requirement module is already compiled, skipping...");
}
} else {
debug!("Requirement has no module, assuming its local...");
}
} }
/* let sub_req_names: Vec<String> = req_def.requirements.iter().map(|r| r.name.clone()).collect();
debug!("Found requirement: {}, with the following sub-requirements: {:?} from {}", req_def.name, sub_req_names, req_module.name);
let mut requirements_output = String::new(); let sub_module = sub.module.as_ref()
compile_definitions(modules, req_module, &mut requirements_output)?; .map(|m| modules.get(m)
.unwrap_or_else(|| panic!("invalid module import: {}", m)))
.unwrap_or(req_module);
let mut s = sub.clone();
s.module = Some(sub_module.name.clone());
if compiled_definitions.contains(&s) {
debug!("Sub-requirement ({}) already compiled, skipping...", s.name);
continue;
}
compiled_definitions.insert(s);
output.write_fmt(format_args!("\n// REQUIREMENTS OF {}::{}\n", module_name, req.name))?; if let Some(local) = sub_module.get_definition(&sub.name) {
output.push_str(&requirements_output); if !local.requirements.is_empty() {
output.push_str("\n"); */ output.write_fmt(format_args!("\n// REQUIREMENTS OF {}::{}\n", module_name, local.name))?;
} compile_requirements(modules, sub_module, compiled_definitions, local, output)?;
output.push_str("\n");
}
let func_src = &req_module.src[req_def.start_pos..req_def.end_pos]; let req_src = &sub_module.src[local.start_pos..local.end_pos];
output.write_fmt(format_args!("// SOURCE {}::{}\n", module_name, req.name))?; output.write_fmt(format_args!("// SOURCE {}::{} (subsubrequirement)\n", module_name, sub.name))?;
output.push_str(func_src); output.push_str(req_src);
output.push_str("\n"); output.push_str("\n\n");
} else { }
debug!("Could not find module for `{}`, assuming its local", req.name); }
} }
let req_src = &req_module.src[req_def.start_pos..req_def.end_pos];
output.write_fmt(format_args!("// SOURCE {}::{} (subrequirement)\n", module_name, req.name))?;
output.push_str(req_src);
output.push_str("\n");
} }
} }
@ -820,6 +1129,40 @@ fn remove_comments(text: &str) -> String {
} }
} }
/// Adds an external function or variable usage as an import.
///
/// Returns the full module name, if it was found.
#[inline(always)]
fn add_external_as_import(module: &mut Module, ident_module: &str, ident_name: &str) -> Option<String> {
// Find the module that this variable is from.
// Starts by checking if the full module path was used.
// If it was not, then find the full path of the module.
let module_full_name = if module.module_imports.contains(ident_module) {
Some(ident_module.to_string())
} else {
module.module_imports.iter()
.find(|m| {
m.ends_with(ident_module)
}).cloned()
};
// Add the ident as an import from the module that was specified.
if let Some(module_from) = &module_full_name {
debug!("adding item to imports from {}", module_from);
let imports = module.item_imports.entry(module_from.clone())
.or_insert_with(|| Import { module: module_from.clone(), imports: vec![] });
let name = ident_name.to_string();
if !imports.imports.contains(&name) {
imports.imports.push(name);
}
} else {
debug!("Could not found full module name");
}
module_full_name
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::preprocessor::remove_comments; use crate::preprocessor::remove_comments;

View File

@ -1,11 +1,13 @@
shader_ident = { (ASCII_ALPHANUMERIC | "_")+ } ident = { (ASCII_ALPHANUMERIC | "_")+ }
type_generic = { ident ~ ("<" ~ type ~ ("," ~ ws* ~ type)? ~ ">") }
module_path = { ident ~ ( "::" ~ ident)+ }
type_array = { "array<" ~ type ~ ("," ~ ws ~ number)? ~ ">" }
type = { type_array | module_path | type_generic | ident }
shader_type = { shader_ident ~ ("<" ~ shader_ident ~ ">")? } shader_module = { ident ~ ( "::" ~ ident)* }
shader_module = { shader_ident ~ ( "::" ~ shader_ident)* }
import_module_command = { "import" ~ ws ~ shader_module } import_module_command = { "import" ~ ws ~ shader_module }
import_list = _{ "{" ~ shader_ident ~ (ws* ~ "," ~ NEWLINE? ~ ws* ~ shader_ident)* ~ "}" } import_list = _{ "{" ~ ident ~ (ws* ~ "," ~ NEWLINE? ~ ws* ~ ident)* ~ "}" }
import_types_command = { import_module_command ~ "::" ~ import_list } import_types_command = { import_module_command ~ "::" ~ import_list }
import_command = _{ import_types_command | import_module_command } import_command = _{ import_types_command | import_module_command }
@ -16,33 +18,34 @@ preproc_prefix = _{ "#" }
command_line = { preproc_prefix ~ (define_module_command | import_command) ~ newline } command_line = { preproc_prefix ~ (define_module_command | import_command) ~ newline }
// all characters used by wgsl // all characters used by wgsl
shader_code_char = { "@" | "-" | "+" | "*" | "/" | "=" | "(" | ")" | ">" | "<" | ";" | ":" | "." | "_" | "," } shader_code_char = { "@" | "-" | "+" | "*" | "/" | "=" | "(" | ")" | ">" | "<" | ";" | ":" | "." | "_" | "," | "&" | "|" | "[" | "]" | ASCII_ALPHA }
shader_code_block = { "{" ~ newline* ~ (cws* ~ (shader_actual_code_line ~ cws*)* ~ newline)+ ~ cws* ~ "}" } shader_code_block = { "{" ~ newline* ~ (cws* ~ (shader_actual_code_line ~ cws*)* ~ newline)+ ~ cws* ~ "}" }
// an fn argument can be another function use // an fn argument can be another function use
shader_code_fn_arg = _{ shader_code_fn_usage | shader_value | shader_ident } shader_code_fn_arg = _{ shader_code_fn_usage | shader_value | ident }
shader_code_fn_usage = { shader_ident ~ "(" ~ shader_code_fn_arg ~ ("," ~ ws* ~ shader_code_fn_arg)* ~ ")" } shader_code_fn_usage = { ident ~ "(" ~ shader_code_fn_arg ~ ("," ~ ws* ~ shader_code_fn_arg)* ~ ")" }
//shader_code_fn_usage = { shader_ident ~ "(in, 2.0)" } shader_code = { shader_code_block | shader_code_fn_usage | shader_value | type_array | type_generic | ident | shader_code_char }
shader_code = { shader_code_block | shader_code_fn_usage | shader_value | shader_ident | shader_code_char }
// usages of code from another module // usages of code from another module
shader_external_variable = { shader_ident ~ ( "::" ~ shader_ident)+ } shader_external_fn = { module_path ~ "(" ~ shader_code_fn_arg ~ ("," ~ ws* ~ shader_code_fn_arg)* ~ ")" }
shader_external_fn = { shader_external_variable ~ "(" ~ shader_code_fn_arg ~ ("," ~ ws* ~ shader_code_fn_arg)* ~ ")" } shader_external_code = _{ shader_external_fn | module_path }
shader_external_code = _{ shader_external_fn | shader_external_variable }
shader_actual_code_line = _{ shader_external_code | shader_code } shader_actual_code_line = _{ shader_external_code | shader_code }
shader_value_num = { ASCII_DIGIT+ ~ ( "." ~ ASCII_DIGIT+ )? } shader_value_num = { ASCII_DIGIT+ ~ ( "." ~ ASCII_DIGIT+ )? ~ ("u" | "i")? }
shader_value_bool = { "true" | "false" } shader_value_bool = { "true" | "false" }
shader_value = { shader_value_bool | shader_value_num } shader_value = { shader_value_bool | shader_value_num }
shader_var_type = { ":" ~ ws* ~ (shader_external_variable | shader_type) } shader_const_def = { "const" ~ ws ~ ident ~ (ws* ~ ":" ~ ws* ~ type)? ~ ws* ~ "=" ~ ws* ~ shader_value ~ ";"? }
shader_const_def = { "const" ~ ws ~ shader_ident ~ (ws* ~ shader_var_type)? ~ ws* ~ "=" ~ ws* ~ shader_value ~ ";" } shader_type_alias_def = { ("alias" | "type") ~ ws ~ ident ~ ws* ~ "=" ~ ws* ~ type ~ ";" }
shader_type_alias_def = { ("alias" | "type") ~ ws ~ shader_ident ~ ws* ~ "=" ~ ws* ~ (shader_external_variable | shader_type) ~ ";" } //shader_struct_field_attribute = { "@" ~ ("builtin(" ~ ASCII_ALPHA ~ ")") | ("location(" ~ NUMBER ~ ")") }
shader_struct_field_attribute = { "@" ~ ( ("builtin(" ~ ident ~ ")") | ("location(" ~ NUMBER+ ~ ")") ) }
shader_struct_field = _{
( shader_struct_field_attribute ~ (NEWLINE? ~ ws*) )? ~ ident ~ ":" ~ ws* ~ type
}
shader_struct_def = { shader_struct_def = {
"struct" ~ ws ~ shader_ident ~ ws* ~ "{" ~ NEWLINE ~ "struct" ~ ws ~ ident ~ ws* ~ "{" ~ NEWLINE ~
ws* ~ shader_ident ~ shader_var_type ~ (ws* ~ NEWLINE)* ~
("," ~ NEWLINE ~ ws* ~ shader_ident ~ shader_var_type)* ~ (cws* ~ (NEWLINE | shader_struct_field) ~ ("," ~ NEWLINE)?)* ~
","? ~ NEWLINE? ~
"}" "}"
} }
@ -52,15 +55,24 @@ shader_binding_var_constraint = {
( "uniform" | "private" | "workgroup" | ("storage" ~ ("," ~ ws? ~ "read" ~ "_write"? )?) ) ~ ( "uniform" | "private" | "workgroup" | ("storage" ~ ("," ~ ws? ~ "read" ~ "_write"? )?) ) ~
">" ">"
} }
shader_binding_def = { shader_group_binding ~ (NEWLINE | ws)? ~ "var" ~ shader_binding_var_constraint? ~ ws ~ shader_ident ~ ws* ~ shader_var_type ~ ";" } shader_binding_def = { shader_group_binding ~ (NEWLINE | ws)? ~ "var" ~ shader_binding_var_constraint? ~ ws ~ ident ~ ws* ~ ":" ~ ws* ~ type ~ ";" }
shader_var_name_type = { shader_ident ~ shader_var_type } shader_var_name_type = { ident ~ ":" ~ ws* ~ type }
shader_fn_args = { "(" ~ shader_var_name_type? ~ (ws* ~ "," ~ ws* ~ shader_var_name_type)* ~ ")" } shader_fn_args = { "(" ~
newline_or_ws? ~
// arguments are optional, so attempt to get one here
//shader_fn_attribute? ~ shader_var_name_type? ~ newline_or_ws? ~
//(newline_or_ws? ~ "," ~ newline_or_ws? ~ shader_var_name_type)*
( shader_fn_attribute? ~ shader_var_name_type ~ newline_or_ws? ~ ","? ~ newline_or_ws? )*
~ newline_or_ws? ~ ")"
}
// the body of a function, including the opening and closing brackets // the body of a function, including the opening and closing brackets
shader_fn_body = { "{" ~ newline* ~ (cws* ~ (shader_actual_code_line ~ cws*)* ~ newline)+ ~ "}" } shader_fn_body = { "{" ~ newline* ~ (cws* ~ (shader_actual_code_line ~ cws*)* ~ newline)+ ~ "}" }
shader_fn_attribute = { "@" ~ ASCII_ALPHA+ ~ ("(" ~ ident ~ ")")? ~ (NEWLINE? ~ ws*) }
shader_fn_def = { shader_fn_def = {
"fn" ~ ws ~ shader_ident ~ shader_fn_args ~ ws ~ "->" ~ ws ~ shader_type ~ ws ~ shader_fn_body shader_fn_attribute? ~
"fn" ~ ws ~ ident ~ shader_fn_args ~ ws* ~ "->" ~ ws* ~ shader_fn_attribute? ~ type ~ ws* ~ shader_fn_body
} }
// a line of shader code, including white space // a line of shader code, including white space
@ -81,4 +93,6 @@ ws = _{ " " | "\t" }
// capturing white space // capturing white space
cws = { " " | "\t" } cws = { " " | "\t" }
newline = { "\n" | "\r\n" | "\r" } newline = { "\n" | "\r\n" | "\r" }
newline_or_ws = _{ (NEWLINE ~ ws*) | ws* }
number = { NUMBER }