From 70daf320827f64b325a77718df07177d74d7ea58 Mon Sep 17 00:00:00 2001 From: SeanOMik Date: Sat, 14 Sep 2024 20:03:37 -0400 Subject: [PATCH] Fix many many problems that led to invalid shader code output for lyra engine --- src/lib.rs | 570 +++++++++++++++++++++++++++++++-- src/preprocessor.rs | 751 ++++++++++++++++++++++++++++++++------------ src/wgsl.pest | 66 ++-- 3 files changed, 1139 insertions(+), 248 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index a2dc40c..1a4d71d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,6 +25,13 @@ pub struct Import { imports: Vec, } +/// An import of a single item from a module. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct ImportItemFrom { + module: String, + import: String, +} + #[derive(Clone, PartialEq, Eq, Hash)] pub struct DefRequirement { /// None if the requirement is local @@ -138,7 +145,7 @@ fn main() -> vec4 { let simple_path = p.parse_module(&SIMPLE_MODULE).unwrap() .expect("failed to find module"); - let out = p.process_file(&simple_path, &SIMPLE_MODULE).unwrap(); + let out = p.preprocess_module(&simple_path).unwrap(); assert!(out.contains("const scalar"), "definition of imported `const scalar` is missing!"); assert!(out.contains("fn mult_some_nums"), "definition of imported `fn mult_some_nums` is missing!"); @@ -150,11 +157,11 @@ fn main() -> vec4 { /// i.e., `simple::some_const`. #[test] fn double_layer_import_indirect_imports() { - /* tracing_subscriber::fmt() + tracing_subscriber::fmt() // enable everything .with_max_level(tracing::Level::TRACE) // sets this to be the default, global collector for this application. - .init(); */ + .init(); let mut p = Processor::new(); p.parse_module(&INNER_MODULE).unwrap() @@ -164,18 +171,18 @@ fn main() -> vec4 { let main_path = p.parse_module(&MAIN_MODULE).unwrap() .expect("failed to find module"); - let out = p.process_file(&main_path, &MAIN_MODULE).unwrap(); + let out = p.preprocess_module(&main_path).unwrap(); - assert!(out.contains("const simple_scalar"), "definition of imported `const simple_scalar` is missing!"); - assert!(out.contains("fn do_something_cool("), "definition of imported `fn do_something_cool` is missing!"); + assert!(out.contains("const simple_scalar"), "definition of imported `const simple_scalar` is missing!\n{}", out); + assert!(out.contains("fn do_something_cool("), "definition of imported `fn do_something_cool` is missing!\n{}", out); - assert!(out.contains("const scalar"), "definition of imported dependency, `const scalar` is missing! {}", out); - assert!(out.contains("fn mult_some_nums("), "definition of imported dependency, `fn mult_some_nums` is missing!"); + assert!(out.contains("const scalar"), "definition of imported dependency, `const scalar` is missing!\n{}", out); + assert!(out.contains("fn mult_some_nums("), "definition of imported dependency, `fn mult_some_nums` is missing!\n{}", out); - assert!(out.contains("fn do_something_cool("), "definition of imported `fn do_something_cool` is missing!"); + assert!(out.contains("fn do_something_cool("), "definition of imported `fn do_something_cool` is missing!\n{}", out); - assert!(out.contains("simple_scalar * do_something_cool"), "indirect imports of `simple_scalar * do_something_cool` is missing!"); - assert!(out.contains("fn main("), "definition of `fn main` is missing!"); + assert!(out.contains("simple_scalar * do_something_cool"), "indirect imports of `simple_scalar * do_something_cool` is missing!\n{}", out); + assert!(out.contains("fn main("), "definition of `fn main` is missing!\n{}", out); } #[test] @@ -211,7 +218,7 @@ fn main() -> vec4 { let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap() .expect("failed to find main module def"); - let out = p.process_file(&importing_mod, &IMPORTING_MODULE).unwrap(); + let out = p.preprocess_module(&importing_mod).unwrap(); assert!(out.contains("var shadows_atlas"), "missing shadows_atlas binding: {}", out); assert!(out.contains("var lights"), "missing lights binding: {}", out); @@ -252,7 +259,7 @@ fn main() -> vec4 { let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap() .expect("failed to find main module def"); - let out = p.process_file(&importing_mod, &IMPORTING_MODULE).unwrap(); + let out = p.preprocess_module(&importing_mod).unwrap(); assert!(out.contains("var shadows_atlas"), "missing shadows_atlas binding: {}", out); assert!(out.contains("var lights"), "missing lights binding: {}", out); @@ -302,7 +309,7 @@ fn main() -> vec4 { let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap() .expect("failed to find main module def"); - let out = p.process_file(&importing_mod, &IMPORTING_MODULE).unwrap(); + let out = p.preprocess_module(&importing_mod).unwrap(); assert!(out.contains("var use_anyway"), "missing use_anyway binding: {}", out); assert!(out.contains("var shadows_atlas"), "missing shadows_atlas binding: {}", out); @@ -348,7 +355,7 @@ fn main() -> vec4 { let lights_mod = p.modules.get("engine::lights").unwrap(); assert!(lights_mod.structs.contains_key("Light"), "`Light` struct definition was not parsed"); - let out = p.process_file(&importing_mod, &IMPORTING_MODULE).unwrap(); + let out = p.preprocess_module(&importing_mod).unwrap(); assert!(out.contains("struct Light"), "missing `Light` struct definition: {}", out); } @@ -398,7 +405,7 @@ fn main() -> vec4 { let lights_mod = p.modules.get("engine::lights").unwrap(); assert!(lights_mod.structs.contains_key("Light"), "`Light` struct definition was not parsed"); - let out = p.process_file(&importing_mod, &IMPORTING_MODULE).unwrap(); + let out = p.preprocess_module(&importing_mod).unwrap(); assert!(out.contains("struct Light"), "missing `Light` struct definition: {}", out); assert!(out.contains("struct Something"), "missing `Something` struct definition: {}", out); } @@ -433,7 +440,7 @@ fn main() -> vec4 { let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap() .expect("failed to find main module def"); - let out = p.process_file(&importing_mod, &IMPORTING_MODULE).unwrap(); + let out = p.preprocess_module(&importing_mod).unwrap(); assert!(out.contains("alias LightIntensity"), "missing `LightIntensity` type alias: {}", out); } @@ -475,8 +482,535 @@ fn main() -> vec4 { let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap() .expect("failed to find main module def"); - let out = p.process_file(&importing_mod, &IMPORTING_MODULE).unwrap(); + let out = p.preprocess_module(&importing_mod).unwrap(); assert!(out.contains("alias LightIntensity"), "missing `LightIntensity` type alias: {}", out); assert!(out.contains("alias Number"), "missing `Number` type alias: {}", out); } + + /// Tests comments and gaps in structs + #[test] + fn comments_and_gaps_in_structs() { + const TYPE_MODULE: &'static str = +r#"#define_module engine::types + +struct TextureAtlasFrame { + /*offset: vec2, + size: vec2,*/ + x: u32, + y: u32, + + width: u32, + height: u32, +}"#; + + /* tracing_subscriber::fmt() + // enable everything + .with_max_level(tracing::Level::TRACE) + // sets this to be the default, global collector for this application. + .init(); */ + + let mut p = Processor::new(); + let type_mod = p.parse_module(TYPE_MODULE).unwrap() + .expect("failed to find bindings module def"); + + let out = p.preprocess_module(&type_mod).unwrap(); + assert!(out.contains("struct TextureAtlasFrame"), "missing `TextureAtlasFrame` struct definition: {}", out); + assert!(!out.contains("offset: vec2,"), "`offset` struct field was included in output: {}", out); + assert!(!out.contains("size: vec2,"), "`size` struct field was included in output: {}", out); + assert!(out.contains("x: u32,"), "missing `x` struct field: {}", out); + assert!(out.contains("y: u32,"), "missing `y` struct field: {}", out); + assert!(out.contains("width: u32,"), "missing `width` struct field: {}", out); + assert!(out.contains("height: u32,"), "missing `height` struct field: {}", out); + } + + #[test] + fn built_in_generics() { + const TYPE_MODULE: &'static str = +r#"#define_module engine::types + +@group(5) @binding(5) +var point_array: array>;"#; + + /* tracing_subscriber::fmt() + // enable everything + .with_max_level(tracing::Level::TRACE) + // sets this to be the default, global collector for this application. + .init(); */ + + let mut p = Processor::new(); + let type_mod = p.parse_module(TYPE_MODULE).unwrap() + .expect("failed to find bindings module def"); + + let out = p.preprocess_module(&type_mod).unwrap(); + assert!(out.contains("var point_array"), "missing `point_array` storage variable: {}", out); + } + + #[test] + fn complex_code_blocks() { + const MAIN_MODULE: &'static str = +r#"#define_module engine::main + +fn main() -> vec4 { + let abs_x = 5.0; + let abs_y = 2.0; + let abs_z = 1.0; + + if (abs_x >= abs_y && abs_x >= abs_z) { + let some_thing = 0.5; + + return vec4(some_thing); + } + + return vec4(1.0); +}"#; + + /* tracing_subscriber::fmt() + // enable everything + .with_max_level(tracing::Level::TRACE) + // sets this to be the default, global collector for this application. + .init(); */ + + let mut p = Processor::new(); + let module = p.parse_module(MAIN_MODULE).unwrap() + .expect("failed to find bindings module def"); + + let out = p.preprocess_module(&module).unwrap(); + assert!(out.contains("abs_x >= abs_y && abs_x >= abs_z"), "missing if statement condition: {}", out); + assert!(out.contains("let some_thing = 0.5;"), "missing definition of `some_thing`: {}", out); + assert!(out.contains("return vec4(some_thing);"), "missing return of `some_thing` as vec4: {}", out); + } + + #[test] + fn code_blocks_using_arrays_nested() { + const MAIN_MODULE: &'static str = +r#"#define_module engine::main + +fn main() -> vec4 { + calc_shadow_dir_light(vec3(1.0), vec3(1.0), vec3(1.0), light_example_var); + + return vec4(1.0); +} + +fn calc_shadow_dir_light(world_pos: vec3, world_normal: vec3, light_dir: vec3, light: Light) -> f32 { + let map_data: LightShadowMapUniform = u_light_shadow[light.light_shadow_uniform_index[0]]; + let frag_pos_light_space = map_data.light_space_matrix * vec4(world_pos, 1.0); + + return 1.0; +} + +fn pcf_point_light(tex_coords: vec3, test_depth: f32, shadow_us: array, samples_num: u32, uv_radius: f32) -> f32 { + return 1.0; +}"#; + + /* tracing_subscriber::fmt() + // enable everything + .with_max_level(tracing::Level::TRACE) + // sets this to be the default, global collector for this application. + .init(); */ + + let mut p = Processor::new(); + let module = p.parse_module(MAIN_MODULE).unwrap() + .expect("failed to find bindings module def"); + + let out = p.preprocess_module(&module).unwrap(); + assert!(out.contains("map_data: LightShadowMapUniform"), "missing `map_data` var: {}", out); + //println!("{out}"); + } + + #[test] + fn struct_with_attributes() { + const MAIN_MODULE: &'static str = +r#"#define_module engine::main + +struct VertexOutput { + @builtin(position) clip_position: vec4, + @location(0) tex_coords: vec2, + @location(1) world_position: vec3, + @location(2) world_normal: vec3, + @location(3) frag_pos_light_space: vec4, +} + +struct VertexOutput { + @builtin(position) + clip_position: vec4, + @location(0) + tex_coords: vec2, +}"#; + + /* tracing_subscriber::fmt() + // enable everything + .with_max_level(tracing::Level::TRACE) + // sets this to be the default, global collector for this application. + .init(); */ + + let mut p = Processor::new(); + let module = p.parse_module(MAIN_MODULE).unwrap() + .expect("failed to find bindings module def"); + + let out = p.preprocess_module(&module).unwrap(); + assert!(out.contains("struct VertexOutput"), "missing `VertexOutput` struct: {}", out); + assert!(out.contains("@location(0) tex_coords: vec2"), "missing `tex_coords` struct field: {}", out); + //println!("{out}"); + } + + #[test] + fn function_annotations() { + const MAIN_MODULE: &'static str = +r#"#define_module engine::main + +@vertex +fn vs_main( + @builtin(vertex_index) in_vertex_index: u32, +) -> VertexOutput { + var out: VertexOutput; + let x = f32(1 - i32(in_vertex_index)) * 0.5; + let y = f32(i32(in_vertex_index & 1u) * 2 - 1) * 0.5; + out.clip_position = vec4(x, y, 0.0, 1.0); + return out; +}"#; + + /* tracing_subscriber::fmt() + // enable everything + .with_max_level(tracing::Level::TRACE) + // sets this to be the default, global collector for this application. + .init(); */ + + let mut p = Processor::new(); + let module = p.parse_module(MAIN_MODULE).unwrap() + .expect("failed to find bindings module def"); + + let out = p.preprocess_module(&module).unwrap(); + assert!(out.contains("@vertex"), "missing `@vertex` annotation: {}", out); + assert!(out.contains("fn vs_main("), "missing `vs_main` entry point: {}", out); + assert!(out.contains("@builtin(vertex_index) in_vertex_index: u32"), "missing `in_vertex_index` function argument: {}", out); + //println!("{out}"); + } + + #[test] + fn import_binding_type() { + const STRUCTS_MODULE: &'static str = +r#"#define_module engine::shadows::structs + +struct TextureAtlasFrame { + x: u32, + y: u32, + width: u32, + height: u32, +} + +struct LightShadowMapUniform { + light_space_matrix: mat4x4, + atlas_frame: TextureAtlasFrame, + near_plane: f32, + far_plane: f32, + light_size_uv: f32, + light_pos: vec3, + /// boolean casted as u32 + has_shadow_settings: u32, + pcf_samples_num: u32, + pcss_blocker_search_samples: u32, + constant_depth_bias: f32, +}"#; + + const MAIN_MODULE: &'static str = +r#"#define_module engine::shadows::depth_pass +#import engine::shadows::structs::{LightShadowMapUniform} + +struct TransformData { + transform: mat4x4, + normal_matrix: mat4x4, +} + +@group(0) @binding(0) +var u_light_shadow: array; +@group(1) @binding(0) +var u_model_transform_data: TransformData; + +struct VertexOutput { + @builtin(position) + clip_position: vec4, + @location(0) world_pos: vec3, + @location(1) instance_index: u32, +} + +@vertex +fn vs_main( + @location(0) position: vec3, + @builtin(instance_index) instance_index: u32, +) -> VertexOutput { + let world_pos = u_model_transform_data.transform * vec4(position, 1.0); + let pos = u_light_shadow[instance_index].light_space_matrix * world_pos; + return VertexOutput(pos, world_pos.xyz, instance_index); +}"#; + + /* tracing_subscriber::fmt() + // enable everything + .with_max_level(tracing::Level::TRACE) + // sets this to be the default, global collector for this application. + .init(); */ + + let mut p = Processor::new(); + p.parse_module(STRUCTS_MODULE).unwrap() + .expect("failed to find bindings module def"); + let module = p.parse_module(MAIN_MODULE).unwrap() + .expect("failed to find bindings module def"); + + let out = p.preprocess_module(&module).unwrap(); + assert!(out.contains("struct LightShadowMapUniform"), "missing `LightShadowMapUniform` \ + struct definition: {}", out); + assert!(out.contains("struct TextureAtlasFrame"), "missing `TextureAtlasFrame` struct \ + definition, requirement of `LightShadowMapUniform` struct: {}", out); + assert!(out.contains("u_light_shadow: array"), "missing \ + `u_light_shadow` binding: {}", out); + println!("{out}"); + } + + #[test] + fn array_construct() { + const STRUCTS_MODULE: &'static str = +r#"#define_module engine::shadows::structs + +struct TextureAtlasFrame { + x: u32, + y: u32, + width: u32, + height: u32, +} + +struct LightShadowMapUniform { + light_space_matrix: mat4x4, + atlas_frame: TextureAtlasFrame, + near_plane: f32, + far_plane: f32, + light_size_uv: f32, + light_pos: vec3, + /// boolean casted as u32 + has_shadow_settings: u32, + pcf_samples_num: u32, + pcss_blocker_search_samples: u32, + constant_depth_bias: f32, +}"#; + + const MAIN_MODULE: &'static str = +r#"#define_module engine::shadows::depth_pass +#import engine::shadows::structs::{LightShadowMapUniform} + +fn main() -> f32 { + let uniforms = array( + u_light_shadow[indices[0]], + u_light_shadow[indices[1]], + u_light_shadow[indices[2]], + u_light_shadow[indices[3]], + u_light_shadow[indices[4]], + u_light_shadow[indices[5]] + ); + + return 1.0; +}"#; + + /* tracing_subscriber::fmt() + // enable everything + .with_max_level(tracing::Level::TRACE) + // sets this to be the default, global collector for this application. + .init(); */ + + let mut p = Processor::new(); + p.parse_module(STRUCTS_MODULE).unwrap() + .expect("failed to find bindings module def"); + let module = p.parse_module(MAIN_MODULE).unwrap() + .expect("failed to find bindings module def"); + + let out = p.preprocess_module(&module).unwrap(); + assert!(out.contains("struct LightShadowMapUniform"), "missing `LightShadowMapUniform` \ + struct definition: {}", out); + assert!(out.contains("struct TextureAtlasFrame"), "missing `TextureAtlasFrame` struct \ + definition, requirement of `LightShadowMapUniform` struct: {}", out); + assert!(out.contains("array("), "missing creation of uniform \ + array: {}", out); + println!("{out}"); + } + + #[test] + fn generics_with_two_args() { + const MAIN_MODULE: &'static str = +r#"#define_module engine::shadows::depth_pass + +@group(0) @binding(0) +var t_light_grid: texture_storage_2d; + +fn main() -> f32 { + return 1.0; +}"#; + + /* tracing_subscriber::fmt() + // enable everything + .with_max_level(tracing::Level::TRACE) + // sets this to be the default, global collector for this application. + .init(); */ + + let mut p = Processor::new(); + let module = p.parse_module(MAIN_MODULE).unwrap() + .expect("failed to find bindings module def"); + + let out = p.preprocess_module(&module).unwrap(); + assert!(out.contains("t_light_grid: texture_storage_2d"), "missing `t_light_grid` \ + binding: {}", out); + println!("{out}"); + } + + #[test] + fn function_return_attributes() { + const MAIN_MODULE: &'static str = +r#"#define_module engine::shadows::depth_pass + +fn main() -> @location(0) vec4 { + return vec4(1.0); +}"#; + + /* tracing_subscriber::fmt() + // enable everything + .with_max_level(tracing::Level::TRACE) + // sets this to be the default, global collector for this application. + .init(); */ + + let mut p = Processor::new(); + let module = p.parse_module(MAIN_MODULE).unwrap() + .expect("failed to find bindings module def"); + + let out = p.preprocess_module(&module).unwrap(); + assert!(out.contains("@location(0) vec4"), + "missing return value with attribute: {}", out); + println!("{out}"); + } + + #[test] + fn import_binding_use() { + const STRUCTS_MODULE: &'static str = +r#"#define_module engine::shadows::structs + +struct TextureAtlasFrame { + x: u32, + y: u32, + width: u32, + height: u32, +} + +struct LightShadowMapUniform { + light_space_matrix: mat4x4, + atlas_frame: TextureAtlasFrame, + near_plane: f32, + far_plane: f32, + light_size_uv: f32, + light_pos: vec3, + /// boolean casted as u32 + has_shadow_settings: u32, + pcf_samples_num: u32, + pcss_blocker_search_samples: u32, + constant_depth_bias: f32, +} + +@group(0) @binding(0) +var u_light_shadow: array;"#; + + const MAIN_MODULE: &'static str = +r#"#define_module engine::shadows::depth_pass +#import engine::shadows::structs::{u_light_shadow} + +fn main() -> @location(0) vec4 { + let shadow_u: LightShadowMapUniform = u_light_shadow[light.light_shadow_uniform_index[0]]; + + return vec4(1.0); +}"#; + + /* tracing_subscriber::fmt() + // enable everything + .with_max_level(tracing::Level::TRACE) + // sets this to be the default, global collector for this application. + .init(); */ + + let mut p = Processor::new(); + p.parse_module(STRUCTS_MODULE).unwrap() + .expect("failed to find bindings module def"); + let module = p.parse_module(MAIN_MODULE).unwrap() + .expect("failed to find bindings module def"); + + let out = p.preprocess_module(&module).unwrap(); + assert!(out.contains("u_light_shadow: array"), + "missing `u_light_shadow` binding:\n{}", out); + println!("{out}"); + } + + /// Tests using the same import multiple times + #[test] + fn multi_import_use() { + const STRUCTS_MODULE: &'static str = +r#"#define_module engine::shadows::structs + +struct TextureAtlasFrame { + x: u32, + y: u32, + width: u32, + height: u32, +} + +struct LightShadowMapUniform { + light_space_matrix: mat4x4, + atlas_frame: TextureAtlasFrame, + near_plane: f32, + far_plane: f32, + light_size_uv: f32, + light_pos: vec3, + /// boolean casted as u32 + has_shadow_settings: u32, + pcf_samples_num: u32, + pcss_blocker_search_samples: u32, + constant_depth_bias: f32, +} + +@group(0) @binding(0) +var u_light_shadow: array;"#; + + const FN_MODULE: &'static str = +r#"#define_module engine::shadows::funcs +#import engine::shadows::structs::{LightShadowMapUniform, u_light_shadow} + +fn do_light_thing(u: LightShadowMapUniform) -> f32 { + let s = u_light_shadow[light.light_shadow_uniform_index[0]]; + return 1.0; +}"#; + + const MAIN_MODULE: &'static str = +r#"#define_module engine::shadows::depth_pass +#import engine::shadows::structs::{u_light_shadow} +#import engine::shadows::funcs::{do_light_thing} + +fn main() -> @location(0) vec4 { + let shadow_u: LightShadowMapUniform = u_light_shadow[light.light_shadow_uniform_index[0]]; + let f = do_light_thing(shadow_u); + + return vec4(f); +}"#; + + /* tracing_subscriber::fmt() + // enable everything + .with_max_level(tracing::Level::TRACE) + // sets this to be the default, global collector for this application. + .init(); */ + + let mut p = Processor::new(); + p.parse_module(STRUCTS_MODULE).unwrap() + .expect("failed to find bindings module def"); + p.parse_module(FN_MODULE).unwrap() + .expect("failed to find bindings module def"); + let module = p.parse_module(MAIN_MODULE).unwrap() + .expect("failed to find bindings module def"); + + let out = p.preprocess_module(&module).unwrap(); + assert!(out.contains("u_light_shadow: array"), + "missing `u_light_shadow` binding:\n{}", out); + let v: Vec<_> = out.match_indices("struct LightShadowMapUniform").collect(); + assert_eq!(v.len(), 1, + "expected one `LightShadowMapUniform` struct definition:\n{}", out); + println!("{out}"); + } } diff --git a/src/preprocessor.rs b/src/preprocessor.rs index 657cae4..65ecbb1 100644 --- a/src/preprocessor.rs +++ b/src/preprocessor.rs @@ -1,10 +1,10 @@ -use std::{collections::{HashMap, HashSet, VecDeque}, fmt::Write, fs, path::Path}; +use std::{collections::{HashMap, HashSet}, fmt::Write, fs, path::Path}; use pest::{iterators::Pair, Parser}; use regex::Regex; use tracing::{debug, debug_span, instrument, trace}; -use crate::{recurse_files, DefRequirement, Definition, Import, Module, Rule, WgslParser}; +use crate::{recurse_files, DefRequirement, Definition, Import, ImportItemFrom, Module, Rule, WgslParser}; const RESERVED_WORDS: [&str; 201] = [ "NULL", "Self", "abstract", "active", "alignas", "alignof", "as", "asm", "asm_fragment", @@ -48,59 +48,81 @@ const RESERVED_WORDS: [&str; 201] = [ ]; #[derive(Debug, thiserror::Error)] -pub enum PreprocessorError { +pub enum Error { #[error("{0}")] IoError(#[from] std::io::Error), #[error("error parsing {0}")] ParserError(#[from] pest::error::Error), #[error("failure formatting preprocessor output to string ({0})")] FormatError(#[from] std::fmt::Error), + #[error("could not find module from path '{0}'")] + UnknownModulePath(String), #[error("unknown module import '{module}', in {from_module}")] - UnknownModule { from_module: String, module: String }, + UnknownModuleImport { from_module: String, module: String }, #[error("in {from_module}: unknown import from '{module}': `{item}`")] - UnknownImport { from_module: String, module: String, item: String }, + UnknownItemImport { from_module: String, module: String, item: String }, #[error("import usage from `{from_module}` conflicts with local variable/function: `{name}`")] ConflictingImport { from_module: String, name: String }, } -fn get_external_var_requirement(module: &Module, pair: Pair) -> Option { +fn get_external_var_requirement(module: &mut Module, pair: Pair) -> Option { let var_name = pair.as_str(); - if pair.as_rule() == Rule::shader_external_variable { - let (module_name, name) = var_name.rsplit_once("::").unwrap(); + let r = pair.as_rule(); + match r { + Rule::module_path => { + let (module_name, name) = var_name.rsplit_once("::").unwrap(); - if RESERVED_WORDS.contains(&name) { + if RESERVED_WORDS.contains(&name) { + return None; + } + + // Find the module that this variable is from. + // Starts by checking if the full module path was used. + // If it was not, then find the full path of the module. + let module_full_name = add_external_as_import(module, module_name, name); + + debug!("Binding module is {:?}", module_full_name); + + Some(DefRequirement { + module: module_full_name, + name: name.to_string(), + }) + }, + Rule::type_generic => { + // generics are always builtins return None; - } + }, + Rule::type_array => { + let mut inner = pair.into_inner(); + let inner = inner.next().unwrap(); + assert_eq!(inner.as_rule(), Rule::r#type); + // inner is `Rule::type` + get_external_var_requirement(module, inner) + }, + Rule::ident => { + if RESERVED_WORDS.contains(&var_name) { + return None; + } - // Find the module that this variable is from. - // Starts by checking if the full module path was used. - // If it was not, then find the full path of the module. - let module_full_name = if module.module_imports.contains(module_name) { - Some(module_name.to_string()) - } else { - module.module_imports.iter() - .find(|m| { - m.ends_with(module_name) - }).cloned() - }; - - debug!("Binding module is {:?}", module_full_name); - - Some(DefRequirement { - module: module_full_name, - name: name.to_string(), - }) - } else { // The only possibility is `Rule::shader_type` - if RESERVED_WORDS.contains(&var_name) { - return None; - } - - Some(DefRequirement { - module: None, - name: var_name.to_string(), - }) + Some(DefRequirement { + module: None, + name: var_name.to_string(), + }) + }, + Rule::r#type => { + let mut inner = pair.into_inner(); + let inner = inner.next().unwrap(); + get_external_var_requirement(module, inner) + }, + _ => unimplemented!("ran into unhandled rule: {:?}, {:?}", r, pair.as_str()) } + + /* if pair.as_rule() == Rule::module_path { + + } else { // The only possibility is `Rule::shader_type` + + } */ } #[derive(Default)] @@ -157,15 +179,16 @@ impl Processor { let (fn_module, fn_name) = fn_path.rsplit_once("::").unwrap(); - debug!("Found call to `{}::{}`", fn_module, fn_name); + debug!("Found call to external `{}::{}`", fn_module, fn_name); + let full_module = add_external_as_import(module, fn_module, fn_name); let req = DefRequirement { - module: Some(fn_module.into()), + module: full_module, name: fn_name.to_string(), }; requirements.push(req); }, - Rule::shader_external_variable => { + Rule::module_path => { let pairs = code.into_inner(); // shader_external_variable is the only pair for this rule let ident_path = pairs.as_str(); @@ -177,23 +200,28 @@ impl Processor { let (ident_module, ident_name) = ident_path.rsplit_once("::").unwrap(); - debug!("Found call to `{}`", ident_name); - // ignore reserved words if RESERVED_WORDS.contains(&ident_name) { continue; } - + + let module_full = add_external_as_import(module, ident_module, ident_name); + let req = DefRequirement { - module: Some(ident_module.into()), + module: module_full, name: ident_name.to_string(), }; requirements.push(req); }, + Rule::type_array | Rule::r#type => { + let req = self.get_imports_in_block(module, code, found_requirements); + requirements.extend(req.into_iter()); + }, Rule::newline => (), Rule::cws => (), + Rule::number => (), Rule::shader_code_char => (), - Rule::shader_ident => { + Rule::ident => { let ident = code.as_str(); if found_requirements.contains(ident) { @@ -213,6 +241,8 @@ impl Processor { }; requirements.push(req); }, + // generic types are only built-in types + Rule::type_generic => (), Rule::shader_value => (), _ => unimplemented!("ran into unhandled rule: {:?}, {:?}", code.as_rule(), code.as_span()) } @@ -221,6 +251,45 @@ impl Processor { requirements } + /// Get the requirement of a type defined as a pair. + /// + /// Returns `None` when the type is built-in + fn get_type_requirement(module: &mut Module, mut type_pair: Pair<'_, Rule>) -> Option { + let mut r = type_pair.as_rule(); + + // the inner rule of `shader_var_type_array` is + // `shader_external_variable` or `shader_type` + while r == Rule::type_array { + let mut inner = type_pair.into_inner(); + type_pair = inner.next().unwrap(); + r = type_pair.as_rule(); + } + + if r == Rule::module_path { + let external_str = type_pair.as_str(); + let (module_name, name) = external_str.rsplit_once("::").unwrap(); + + if RESERVED_WORDS.contains(&name) { + return None; + } + + let module_full_name = add_external_as_import(module, module_name, name); + + Some(DefRequirement { + module: module_full_name, + name: name.to_string(), + }) + } else { // The only possibility is `Rule::shader_type` + let name = type_pair.as_str(); + + Some(DefRequirement { + module: None, + name: name.to_string(), + }) + } + + } + /// Parse a module file to attempt to find the include identifier. /// /// Returns `None` if the module does not define an include identifier. @@ -228,7 +297,7 @@ impl Processor { pub fn parse_module( &mut self, module_src: &str, - ) -> Result, PreprocessorError> { + ) -> Result, Error> { let e = debug_span!("parse_module", module = tracing::field::Empty).entered(); let module_src = remove_comments(module_src); @@ -293,10 +362,20 @@ impl Processor { Rule::shader_fn_def => { let mut pairs = line.clone().into_inner(); // shader_ident is the only pair for this rule - let fn_name = pairs.next().unwrap().as_str().to_string(); + let fn_name = { + let next = pairs.next().unwrap(); + let r = next.as_rule(); + + if r == Rule::shader_fn_attribute { + pairs.next().unwrap() + } else { + next + } + }; + let fn_name = fn_name.as_str().to_string(); debug!("Found function def: {fn_name}"); - let fn_body = pairs.skip(2).next().unwrap(); + let fn_body = pairs.last().unwrap();//pairs.skip(3).next().unwrap(); let mut found_reqs = HashSet::default(); let requirements = self.get_imports_in_block(&mut module, fn_body, &mut found_reqs); @@ -324,54 +403,20 @@ impl Processor { if name.as_rule() == Rule::shader_binding_var_constraint { trace!("Skipping binding constraint"); name = pairs.next().unwrap(); - assert_eq!(name.as_rule(), Rule::shader_ident); + assert_eq!(name.as_rule(), Rule::ident); } let name = name.as_str(); let var_type = pairs.next().unwrap(); + let mut var_type_inner = var_type.into_inner(); + let inner_type = var_type_inner.next().unwrap(); - let mut requirements = vec![]; - { - let mut vt_inner = var_type.into_inner(); - let inner_type = vt_inner.next().unwrap(); + let requirement_vec = Self::get_type_requirement(&mut module, inner_type) + .map(|r| vec![r]) + .unwrap_or_default(); - if inner_type.as_rule() == Rule::shader_external_variable { - let external_str = inner_type.as_str(); - let (module_name, name) = external_str.rsplit_once("::").unwrap(); - - if RESERVED_WORDS.contains(&name) { - continue; - } - - // Find the module that this variable is from. - // Starts by checking if the full module path was used. - // If it was not, then find the full path of the module. - let module_full_name = if module.module_imports.contains(module_name) { - Some(module_name.to_string()) - } else { - module.module_imports.iter() - .find(|m| { - m.ends_with(module_name) - }).cloned() - }; - - debug!("Binding module is {:?}", module_full_name); - - requirements.push(DefRequirement { - module: module_full_name, - name: name.to_string(), - }); - } else { // The only possibility is `Rule::shader_type` - let name = inner_type.as_str(); - - requirements.push(DefRequirement { - module: None, - name: name.to_string(), - }); - } - } - - debug!("Found binding def: `{name}`"); + let req_names = requirement_vec.first().map(|r| r.name.clone()); + debug!("Found binding def: `{}` with requirement: {:?}", name, req_names); let line_span = line.as_span(); let start_pos = line_span.start(); @@ -383,7 +428,7 @@ impl Processor { name: name.to_string(), start_pos, end_pos, - requirements: vec![], + requirements: requirement_vec, }, ); }, @@ -393,6 +438,19 @@ impl Processor { let const_name = pairs.next().unwrap().as_str().to_string(); debug!("Found const def: `{const_name}`"); + let mut requirements = vec![]; + + // if the type of the const was specified, find the requirement + let next = pairs.next().unwrap(); + if next.as_rule() == Rule::r#type { + let mut next_inner = next.into_inner(); + let var_type = next_inner.next().unwrap(); + + if let Some(req) = Self::get_type_requirement(&mut module, var_type) { + requirements.push(req); + } + } + let line_span = line.as_span(); let start_pos = line_span.start(); let end_pos = line_span.end(); @@ -403,7 +461,7 @@ impl Processor { name: const_name, start_pos, end_pos, - requirements: vec![], + requirements, }, ); }, @@ -431,19 +489,46 @@ impl Processor { // iterate through fields in struct let mut requirements = vec![]; - while let (Some(_), Some(field_ty)) = (struct_inner.next(), struct_inner.next()) { - //panic!("got {} of type {}", ident.as_str(), field_ty.as_str()); + loop { + let (_field_attr, _field_name) = { + let mut temp = struct_inner.next(); + + // ignore whitespace + while let Some(t) = &temp { + if t.as_rule() == Rule::cws { + temp = struct_inner.next(); + } else { + break; + } + } + + if let Some(next) = temp { + let r = next.as_rule(); + + if r == Rule::shader_struct_field_attribute { + (Some(next), struct_inner.next().unwrap()) + } else { + (None, next) + } + } else { + break; + } + }; + //todo!("field name is: {:?}", _field_name.as_str()); + + let field_ty = struct_inner.next().unwrap(); + let field_ty_rule = field_ty.as_rule(); let mut ty_inner = field_ty.into_inner(); let ty_inner = ty_inner.next().unwrap(); - - match ty_inner.as_rule() { - Rule::shader_external_variable | Rule::shader_type => { - //requirements.push(value) - if let Some(req) = get_external_var_requirement(&module, ty_inner) { + //todo!("ty inner is: {:?}", ty_inner.as_str()); + + match field_ty_rule { + Rule::r#type => { + if let Some(req) = get_external_var_requirement(&mut module, ty_inner) { requirements.push(req); } }, - _ => unimplemented!("ran into unhandled rule: ({:?}) {:?}", line.as_rule(), line.as_span()) + _ => unimplemented!("ran into unhandled rule: ({:?}) {:?}", field_ty_rule, ty_inner.as_str()) } } @@ -467,7 +552,7 @@ impl Processor { let alias_to = inner.next().unwrap(); // get the requirements of the alias if there are any - let requirements = if let Some(req) = get_external_var_requirement(&module, alias_to) { + let requirements = if let Some(req) = get_external_var_requirement(&mut module, alias_to) { debug!("Type req: {}", req.name); vec![req] } else { vec![] }; @@ -521,7 +606,7 @@ impl Processor { &mut self, path: P, extensions: [&str; N], - ) -> Result { + ) -> Result { let files = recurse_files(path)?; let mut parsed = 0; @@ -541,18 +626,93 @@ impl Processor { } #[instrument(skip(self))] - fn generate_header(&mut self, module_path: &str) -> Result { + fn generate_header(&mut self, module_path: &str) -> Result { let module = self.modules.get(module_path).unwrap(); let mut output = String::new(); let mut compiled = HashSet::new(); - compile_definitions(&self.modules, module, &mut compiled, &mut output)?; + let mut compiled_imports = HashSet::new(); + + /* for (module_name, import) in &module.item_imports { + for item in &import.imports { + let item_mod = self.modules.get(&import.module) + .ok_or(Error::UnknownItemImport { from_module: module_path.into(), module: import.module.clone(), item: item.clone() })?; + let item_def = item_mod.get_definition(item) + .ok_or(Error::UnknownItemImport { from_module: module_path.into(), module: import.module.clone(), item: item.clone() })?; + + compile_requirements(self.modules, item_mod, &mut compiled, item_def, &mut output)?; + + let func_src = &item_mod.src[item_def.start_pos..item_def.end_pos]; + output.write_fmt(format_args!("// SOURCE {}::{} (import)\n", module_name, item_def.name))?; + output.push_str(func_src); + output.push_str("\n"); + } + } */ + compile_imports(&self.modules, module, &mut compiled_imports, &mut compiled, &mut output)?; + + //compile_definitions(&self.modules, module, &mut compiled, &mut output)?; Ok(output) } + #[instrument(skip(self, type_, output))] + fn output_type(&self, type_: Pair, output: &mut String) -> Result<(), std::fmt::Error> { + let r = type_.as_rule(); + match r { + Rule::type_array => { + let mut inner = type_.into_inner(); + let array_ty = inner.next().unwrap(); + let array_size = inner.next() + .map(|p| format!(", {}", p.as_str())) + .unwrap_or_default(); + + output.write_str("array<")?; + self.output_type(array_ty, output)?; + output.write_str(&array_size)?; + output.write_str(">")?; + }, + Rule::module_path => { + let pairs = type_.clone().into_inner(); + // the last pair is the name of the variable + let ident_name = pairs.last().unwrap().as_str(); + + output.write_str(ident_name)?; + }, + Rule::type_generic => { + let mut pairs = type_.clone().into_inner(); + // the last pair is the name of the variable + let ident_name = pairs.next().unwrap().as_str(); + + output.write_str(ident_name)?; + output.write_str("<")?; + self.output_type(pairs.next().unwrap(), output)?; + + // if there's a second generic field, output it as well + if let Some(second_generic) = pairs.next() { + output.write_str(", ")?; + self.output_type(second_generic, output)?; + } + + output.write_str(">")?; + }, + Rule::ident => { + output.write_str(type_.as_str())?; + }, + Rule::r#type => { + let mut pairs = type_.into_inner(); + let ty = pairs.next().unwrap(); + self.output_type(ty, output)?; + }, + _ => unimplemented!("ran into unhandled rule: {:?}, {:?}", r, type_.as_str()), + } + + Ok(()) + } + #[instrument(skip(self, shader_code_rule, output))] fn output_shader_code_line(&self, shader_code_rule: Pair, output: &mut String) -> Result<(), std::fmt::Error> { for line in shader_code_rule.into_inner() { + //debug!("output rule: {:?}, '{}'", line.as_rule(), line.as_str()); + match line.as_rule() { Rule::shader_external_fn => { let mut pairs = line.clone().into_inner(); @@ -569,33 +729,139 @@ impl Processor { panic!("Unknown error, rule input: {}", line.as_str()); } }, - Rule::shader_external_variable => { - let pairs = line.clone().into_inner(); - // the last pair is the name of the variable - let ident_name = pairs.last().unwrap().as_str(); - - output.write_str(ident_name)?; + Rule::type_array | Rule::module_path => { + self.output_type(line, output)?; }, - //shader_external_variable Rule::shader_fn_def => { let mut rule_inner = line.into_inner(); - let fn_name = rule_inner.next().unwrap().as_str(); - let args = rule_inner.next().unwrap().as_str(); - let fn_ret = rule_inner.next().unwrap().as_str(); + // attempt to get the optional attribute, right after it is the name + let mut attribute = None; + let fn_name = { + // get the fn name, s + let next = rule_inner.next().unwrap(); + let r = next.as_rule(); + + if r == Rule::shader_fn_attribute { + attribute = Some(next); + rule_inner.next().unwrap() + } else { + next + } + }; + let attribute = attribute + .map(|a| a.as_str().to_string()) + .unwrap_or_default(); + let fn_name = fn_name.as_str(); + + let args = rule_inner.next().unwrap().as_str(); // TODO: external types in args + let mut fn_ret_attr = ""; + let fn_ret = { + let next = rule_inner.next().unwrap(); + let r = next.as_rule(); + + if r == Rule::shader_fn_attribute { + fn_ret_attr = next.as_str(); + rule_inner.next().unwrap().as_str() + } else { + next.as_str() + } + }; let fn_body = rule_inner.next().unwrap(); let mut body_output = String::new(); self.output_shader_code_line(fn_body, &mut body_output)?; // to escape {, must use two - output.write_fmt(format_args!("fn {}{} -> {} {{{}}}", - fn_name, args, fn_ret, body_output))?; + output.write_fmt(format_args!("{}fn {}{} -> {}{} {{{}}}", + attribute, fn_name, args, fn_ret_attr, fn_ret, body_output))?; + }, + Rule::shader_code_block => { + output.write_str("{")?; + // this call will run line.into_inner() + self.output_shader_code_line(line, output)?; + output.write_str("}")?; + }, + Rule::shader_struct_def => { + let mut rule_inner = line.into_inner(); + + let ident = rule_inner.next().unwrap().as_str(); + output.write_fmt(format_args!("struct {} {{\n", ident))?; + + // process fields + loop { + let mut space_count = 0; + let (field_attr, field_name) = { + let mut temp = rule_inner.next(); + + // find the white space amount for the field + while let Some(t) = &temp { + if t.as_rule() == Rule::cws { + space_count += 1; + temp = rule_inner.next(); + } else { + // stop when a non whitespace rule is found + break; + } + } + + if let Some(next) = temp { + let r = next.as_rule(); + + if r == Rule::shader_struct_field_attribute { + (Some(next), rule_inner.next().unwrap()) + } else { + (None, next) + } + } else { + break; + } + }; + + // write indentation + output.write_str(&" ".repeat(space_count))?; + + // write field attribute with space to separate from field name + if let Some(field_attr) = field_attr { + output.write_str(field_attr.as_str())?; + output.write_str(" ")?; + } + + output.write_str(field_name.as_str())?; + output.write_str(": ")?; + + // get field type inner + let field_ty = rule_inner.next().unwrap(); + self.output_type(field_ty, output)?; + output.write_str(",\n")?; + } + + output.write_str("}")?; + }, + Rule::shader_binding_def => { + let rule_inner = line.clone().into_inner(); + + // the last pair is the type of the binding + let var_type = rule_inner.last().unwrap(); + let var_type_span = var_type.as_span(); + + // this captures everything up to the type of the var + let pairs_start = line.as_span().start(); + let var_type_start = var_type_span.start(); + let in_end = var_type_start - pairs_start; + let mut var_def = line.as_str()[..in_end].to_string(); + + // this recursive call will output the code for the binding type + self.output_type(var_type, &mut var_def)?; + output.write_str(&var_def)?; + output.write_str(";")?; }, Rule::shader_code => { self.output_shader_code_line(line, output)?; }, - Rule::shader_code_fn_usage | Rule::shader_value | Rule::shader_const_def | Rule::shader_ident | Rule::shader_code_char | Rule::cws => { + Rule::shader_code_fn_usage | Rule::shader_value | Rule::r#type | + Rule::shader_const_def | Rule::ident | Rule::shader_code_char | + Rule::cws | Rule::type_generic => { let input = line.as_str(); output.write_str(&input)?; }, @@ -609,24 +875,26 @@ impl Processor { Ok(()) } - #[instrument(skip(self, module_src))] - pub fn process_file(&mut self, module_path: &str, module_src: &str) -> Result { - let module_src = remove_comments(module_src); + #[instrument(skip(self))] + pub fn preprocess_module(&mut self, module_path: &str) -> Result { + let module = self.modules.get(module_path) + .ok_or(Error::UnknownModulePath(module_path.to_string()))?; + let module_src = module.src.clone(); + + let module_pair = WgslParser::parse(Rule::file, &module_src)? + .next() + .unwrap(); // get and unwrap the `file` rule; never fails let mut out_string = String::new(); // the output will be at least the length of module_src out_string.reserve(module_src.len()); - let file = WgslParser::parse(Rule::file, &module_src)? - .next() - .unwrap(); // get and unwrap the `file` rule; never fails - let header = self.generate_header(module_path)?; out_string.write_str("// START OF IMPORT HEADER\n")?; out_string.write_str(&header)?; out_string.write_str("// END OF IMPORT HEADER\n")?; - for record in file.into_inner() { + for record in module_pair.into_inner() { match record.as_rule() { Rule::command_line => { // the parser has found a preprocessor command, figure out what it is @@ -658,14 +926,10 @@ impl Processor { } fn try_find_requirement_module(module: &Module, req_name: &str) -> Option { - debug!("looking for '{}' in module '{}' that imports: {:?}", req_name, module.name, module.item_imports.values()); - //debug!("Module: {}, req: {}", module.name, req_name); + trace!("looking for '{}' in module '{}' that imports: {:?}", req_name, module.name, module.item_imports.values()); for import in module.item_imports.values() { - //debug!("Import: {:?}", import); - if import.imports.contains(&req_name.to_string()) { - //debug!("has it"); return Some(import.module.clone()); } } @@ -673,90 +937,135 @@ fn try_find_requirement_module(module: &Module, req_name: &str) -> Option, module: &Module, compiled_definitions: &mut HashSet, output: &mut String) -> Result<(), PreprocessorError> { - let e = debug_span!("compile_definitions", module = module.name, require = tracing::field::Empty, sub_require = tracing::field::Empty).entered(); +#[inline(always)] +fn is_import_compiled(compiled_imports: &mut HashSet, import_mod: &str, import_name: &str) -> bool { + let compiled_import = ImportItemFrom { + module: import_mod.to_string(), + import: import_name.to_string(), + }; - let defs = module.functions.iter().chain(module.vars.iter()).chain(module.structs.iter()).chain(module.aliases.iter()); - for (_, funcs) in defs { //&module.functions { - let mut requirements = VecDeque::from(funcs.requirements.clone()); + !compiled_imports.insert(compiled_import) +} + +#[inline(always)] +fn is_def_compiled(compiled_definitions: &mut HashSet, def_module: &str, def_name: &str) -> bool { + let item_def_req = DefRequirement { + module: Some(def_module.to_string()), + name: def_name.to_string(), + }; + + !compiled_definitions.insert(item_def_req) +} + +fn compile_imports(modules: &HashMap, module: &Module, compiled_imports: &mut HashSet, compiled_definitions: &mut HashSet, output: &mut String) -> Result<(), Error> { + let module_path = &module.name; + + for (module_name, import) in &module.item_imports { + let item_mod = modules.get(module_name) + .ok_or(Error::UnknownModulePath(module_name.clone()))?; + compile_imports(modules, item_mod, compiled_imports, compiled_definitions, output)?; + + for item in &import.imports { + if is_import_compiled(compiled_imports, &import.module, item) { + continue; + } else if is_def_compiled(compiled_definitions, &module_name, item) { + continue; + } + + let item_def = item_mod.get_definition(item) + .ok_or(Error::UnknownItemImport { from_module: module_path.into(), module: import.module.clone(), item: item.clone() })?; + + compile_requirements(modules, item_mod, compiled_definitions, item_def, output)?; + + let func_src = &item_mod.src[item_def.start_pos..item_def.end_pos]; + output.write_fmt(format_args!("// SOURCE {}::{} (import)\n", module_name, item_def.name))?; + output.push_str(func_src); + output.push_str("\n"); + } + } + + Ok(()) +} + +fn compile_requirements(modules: &HashMap, module: &Module, compiled_definitions: &mut HashSet, def: &Definition, output: &mut String) -> Result<(), Error> { + let e = debug_span!("compile_requirements", module = module.name, + require = def.name, sub_require = tracing::field::Empty).entered(); + + trace!("Compiling {} requirements of {}", def.requirements.len(), def.name); + + for req in &def.requirements { + e.record("sub_require", req.name.clone()); + + let req_module = if module.get_definition(&req.name).is_some() { + // if the requirement is local, use the current module + Some(Ok(module)) + } else { + req.module.clone() + // try to find the module name the requirement is in using the imported module + // if it was not supplied + .or_else(|| try_find_requirement_module(&module, &req.name)) + // get the module, panic if the name was incorrect. + .map(|m| modules.get(&m) + .ok_or(Error::UnknownModuleImport { from_module: module.name.clone(), module: m }) + //.unwrap_or_else(|| panic!("invalid module import: {}", m)) + ) + }; - while let Some(mut req) = requirements.pop_front() { - e.record("require", req.name.clone()); + if let Some(req_module) = req_module { + let req_module = req_module?; + let module_name = &req_module.name; - if req.module.is_none() { - let mod_name = try_find_requirement_module(&module, &req.name); - req.module = mod_name; - debug!("req module: {:?}", req.module); + let mut r = req.clone(); + r.module = Some(module_name.clone()); + if compiled_definitions.contains(&r) { + debug!("Definition ({}) already compiled, skipping...", req.name); + continue; } + compiled_definitions.insert(r); + + let req_def = req_module.get_definition(&req.name) + .unwrap_or_else(|| panic!("invalid requirement of {}::{}: '{}'", module_name, def.name, req.name)); + + if !req_def.requirements.is_empty() { + output.write_fmt(format_args!("\n// REQUIREMENTS OF {}::{}\n", module_name, req_def.name))?; - if let Some(module_name) = &req.module { - let req_module = modules.get(module_name) - .unwrap_or_else(|| panic!("invalid module import: {}", module_name)); - - // get the definition from the module that defines the import - let req_def = req_module.get_definition(&req.name) - .unwrap_or_else(|| panic!("invalid import: {} from {}", req.name, module_name)); - - if !req_def.requirements.is_empty() { - for mut sub in req_def.requirements.clone() { - e.record("sub_require", sub.name.clone()); - - println!("searching for sub requirement module in module: '{}'", req_module.name); - - if sub.module.is_none() { - sub.module = try_find_requirement_module(&req_module, &sub.name); - } - - if let Some(sub_mod) = &sub.module { - if !compiled_definitions.contains(&sub) { - compiled_definitions.insert(sub.clone()); - - let sub_mod = modules.get(sub_mod).expect("unknown module from import"); - debug!("found sub requirement module: {}", sub_mod.name); - let module_name = &sub_mod.name; - - let mut requirements_output = String::new(); - compile_definitions(modules, sub_mod, compiled_definitions, &mut requirements_output)?; - - output.write_fmt(format_args!("\n// REQUIREMENTS OF {}::{}\n", module_name, req.name))?; - output.push_str(&requirements_output); - output.push_str("\n"); - - // find the definition of the sub requirement - let sub_def = sub_mod.get_definition(&sub.name) - .unwrap_or_else(|| panic!("invalid import: {} from {}", req.name, module_name)); - - // write the source of the sub requirement to the output - let req_src = &sub_mod.src[sub_def.start_pos..sub_def.end_pos]; - output.write_fmt(format_args!("// SOURCE {}::{}\n", module_name, sub.name))?; - output.push_str(req_src); - output.push_str("\n"); - } else { - debug!("Requirement module is already compiled, skipping..."); - } - } else { - debug!("Requirement has no module, assuming its local..."); - } + for mut sub in req_def.requirements.clone() { + if sub.module.is_none() { + sub.module = try_find_requirement_module(&req_module, &sub.name); } - /* let sub_req_names: Vec = req_def.requirements.iter().map(|r| r.name.clone()).collect(); - debug!("Found requirement: {}, with the following sub-requirements: {:?} from {}", req_def.name, sub_req_names, req_module.name); - let mut requirements_output = String::new(); - compile_definitions(modules, req_module, &mut requirements_output)?; + let sub_module = sub.module.as_ref() + .map(|m| modules.get(m) + .unwrap_or_else(|| panic!("invalid module import: {}", m))) + .unwrap_or(req_module); + + let mut s = sub.clone(); + s.module = Some(sub_module.name.clone()); + if compiled_definitions.contains(&s) { + debug!("Sub-requirement ({}) already compiled, skipping...", s.name); + continue; + } + compiled_definitions.insert(s); - output.write_fmt(format_args!("\n// REQUIREMENTS OF {}::{}\n", module_name, req.name))?; - output.push_str(&requirements_output); - output.push_str("\n"); */ - } + if let Some(local) = sub_module.get_definition(&sub.name) { + if !local.requirements.is_empty() { + output.write_fmt(format_args!("\n// REQUIREMENTS OF {}::{}\n", module_name, local.name))?; + compile_requirements(modules, sub_module, compiled_definitions, local, output)?; + output.push_str("\n"); + } - let func_src = &req_module.src[req_def.start_pos..req_def.end_pos]; - output.write_fmt(format_args!("// SOURCE {}::{}\n", module_name, req.name))?; - output.push_str(func_src); - output.push_str("\n"); - } else { - debug!("Could not find module for `{}`, assuming its local", req.name); + let req_src = &sub_module.src[local.start_pos..local.end_pos]; + output.write_fmt(format_args!("// SOURCE {}::{} (subsubrequirement)\n", module_name, sub.name))?; + output.push_str(req_src); + output.push_str("\n\n"); + } + } } + + let req_src = &req_module.src[req_def.start_pos..req_def.end_pos]; + output.write_fmt(format_args!("// SOURCE {}::{} (subrequirement)\n", module_name, req.name))?; + output.push_str(req_src); + output.push_str("\n"); } } @@ -820,6 +1129,40 @@ fn remove_comments(text: &str) -> String { } } +/// Adds an external function or variable usage as an import. +/// +/// Returns the full module name, if it was found. +#[inline(always)] +fn add_external_as_import(module: &mut Module, ident_module: &str, ident_name: &str) -> Option { + // Find the module that this variable is from. + // Starts by checking if the full module path was used. + // If it was not, then find the full path of the module. + let module_full_name = if module.module_imports.contains(ident_module) { + Some(ident_module.to_string()) + } else { + module.module_imports.iter() + .find(|m| { + m.ends_with(ident_module) + }).cloned() + }; + + // Add the ident as an import from the module that was specified. + if let Some(module_from) = &module_full_name { + debug!("adding item to imports from {}", module_from); + let imports = module.item_imports.entry(module_from.clone()) + .or_insert_with(|| Import { module: module_from.clone(), imports: vec![] }); + + let name = ident_name.to_string(); + if !imports.imports.contains(&name) { + imports.imports.push(name); + } + } else { + debug!("Could not found full module name"); + } + + module_full_name +} + #[cfg(test)] mod tests { use crate::preprocessor::remove_comments; diff --git a/src/wgsl.pest b/src/wgsl.pest index beeb1d4..a41830f 100644 --- a/src/wgsl.pest +++ b/src/wgsl.pest @@ -1,11 +1,13 @@ -shader_ident = { (ASCII_ALPHANUMERIC | "_")+ } +ident = { (ASCII_ALPHANUMERIC | "_")+ } +type_generic = { ident ~ ("<" ~ type ~ ("," ~ ws* ~ type)? ~ ">") } +module_path = { ident ~ ( "::" ~ ident)+ } +type_array = { "array<" ~ type ~ ("," ~ ws ~ number)? ~ ">" } +type = { type_array | module_path | type_generic | ident } -shader_type = { shader_ident ~ ("<" ~ shader_ident ~ ">")? } - -shader_module = { shader_ident ~ ( "::" ~ shader_ident)* } +shader_module = { ident ~ ( "::" ~ ident)* } import_module_command = { "import" ~ ws ~ shader_module } -import_list = _{ "{" ~ shader_ident ~ (ws* ~ "," ~ NEWLINE? ~ ws* ~ shader_ident)* ~ "}" } +import_list = _{ "{" ~ ident ~ (ws* ~ "," ~ NEWLINE? ~ ws* ~ ident)* ~ "}" } import_types_command = { import_module_command ~ "::" ~ import_list } import_command = _{ import_types_command | import_module_command } @@ -16,33 +18,34 @@ preproc_prefix = _{ "#" } command_line = { preproc_prefix ~ (define_module_command | import_command) ~ newline } // all characters used by wgsl -shader_code_char = { "@" | "-" | "+" | "*" | "/" | "=" | "(" | ")" | ">" | "<" | ";" | ":" | "." | "_" | "," } +shader_code_char = { "@" | "-" | "+" | "*" | "/" | "=" | "(" | ")" | ">" | "<" | ";" | ":" | "." | "_" | "," | "&" | "|" | "[" | "]" | ASCII_ALPHA } shader_code_block = { "{" ~ newline* ~ (cws* ~ (shader_actual_code_line ~ cws*)* ~ newline)+ ~ cws* ~ "}" } // an fn argument can be another function use -shader_code_fn_arg = _{ shader_code_fn_usage | shader_value | shader_ident } -shader_code_fn_usage = { shader_ident ~ "(" ~ shader_code_fn_arg ~ ("," ~ ws* ~ shader_code_fn_arg)* ~ ")" } -//shader_code_fn_usage = { shader_ident ~ "(in, 2.0)" } -shader_code = { shader_code_block | shader_code_fn_usage | shader_value | shader_ident | shader_code_char } +shader_code_fn_arg = _{ shader_code_fn_usage | shader_value | ident } +shader_code_fn_usage = { ident ~ "(" ~ shader_code_fn_arg ~ ("," ~ ws* ~ shader_code_fn_arg)* ~ ")" } +shader_code = { shader_code_block | shader_code_fn_usage | shader_value | type_array | type_generic | ident | shader_code_char } // usages of code from another module -shader_external_variable = { shader_ident ~ ( "::" ~ shader_ident)+ } -shader_external_fn = { shader_external_variable ~ "(" ~ shader_code_fn_arg ~ ("," ~ ws* ~ shader_code_fn_arg)* ~ ")" } -shader_external_code = _{ shader_external_fn | shader_external_variable } +shader_external_fn = { module_path ~ "(" ~ shader_code_fn_arg ~ ("," ~ ws* ~ shader_code_fn_arg)* ~ ")" } +shader_external_code = _{ shader_external_fn | module_path } shader_actual_code_line = _{ shader_external_code | shader_code } -shader_value_num = { ASCII_DIGIT+ ~ ( "." ~ ASCII_DIGIT+ )? } +shader_value_num = { ASCII_DIGIT+ ~ ( "." ~ ASCII_DIGIT+ )? ~ ("u" | "i")? } shader_value_bool = { "true" | "false" } shader_value = { shader_value_bool | shader_value_num } -shader_var_type = { ":" ~ ws* ~ (shader_external_variable | shader_type) } -shader_const_def = { "const" ~ ws ~ shader_ident ~ (ws* ~ shader_var_type)? ~ ws* ~ "=" ~ ws* ~ shader_value ~ ";" } -shader_type_alias_def = { ("alias" | "type") ~ ws ~ shader_ident ~ ws* ~ "=" ~ ws* ~ (shader_external_variable | shader_type) ~ ";" } +shader_const_def = { "const" ~ ws ~ ident ~ (ws* ~ ":" ~ ws* ~ type)? ~ ws* ~ "=" ~ ws* ~ shader_value ~ ";"? } +shader_type_alias_def = { ("alias" | "type") ~ ws ~ ident ~ ws* ~ "=" ~ ws* ~ type ~ ";" } +//shader_struct_field_attribute = { "@" ~ ("builtin(" ~ ASCII_ALPHA ~ ")") | ("location(" ~ NUMBER ~ ")") } +shader_struct_field_attribute = { "@" ~ ( ("builtin(" ~ ident ~ ")") | ("location(" ~ NUMBER+ ~ ")") ) } +shader_struct_field = _{ + ( shader_struct_field_attribute ~ (NEWLINE? ~ ws*) )? ~ ident ~ ":" ~ ws* ~ type +} shader_struct_def = { - "struct" ~ ws ~ shader_ident ~ ws* ~ "{" ~ NEWLINE ~ - ws* ~ shader_ident ~ shader_var_type ~ - ("," ~ NEWLINE ~ ws* ~ shader_ident ~ shader_var_type)* ~ - ","? ~ NEWLINE? ~ + "struct" ~ ws ~ ident ~ ws* ~ "{" ~ NEWLINE ~ + (ws* ~ NEWLINE)* ~ + (cws* ~ (NEWLINE | shader_struct_field) ~ ("," ~ NEWLINE)?)* ~ "}" } @@ -52,15 +55,24 @@ shader_binding_var_constraint = { ( "uniform" | "private" | "workgroup" | ("storage" ~ ("," ~ ws? ~ "read" ~ "_write"? )?) ) ~ ">" } -shader_binding_def = { shader_group_binding ~ (NEWLINE | ws)? ~ "var" ~ shader_binding_var_constraint? ~ ws ~ shader_ident ~ ws* ~ shader_var_type ~ ";" } +shader_binding_def = { shader_group_binding ~ (NEWLINE | ws)? ~ "var" ~ shader_binding_var_constraint? ~ ws ~ ident ~ ws* ~ ":" ~ ws* ~ type ~ ";" } -shader_var_name_type = { shader_ident ~ shader_var_type } -shader_fn_args = { "(" ~ shader_var_name_type? ~ (ws* ~ "," ~ ws* ~ shader_var_name_type)* ~ ")" } +shader_var_name_type = { ident ~ ":" ~ ws* ~ type } +shader_fn_args = { "(" ~ + newline_or_ws? ~ + // arguments are optional, so attempt to get one here + //shader_fn_attribute? ~ shader_var_name_type? ~ newline_or_ws? ~ + //(newline_or_ws? ~ "," ~ newline_or_ws? ~ shader_var_name_type)* + ( shader_fn_attribute? ~ shader_var_name_type ~ newline_or_ws? ~ ","? ~ newline_or_ws? )* + ~ newline_or_ws? ~ ")" +} // the body of a function, including the opening and closing brackets shader_fn_body = { "{" ~ newline* ~ (cws* ~ (shader_actual_code_line ~ cws*)* ~ newline)+ ~ "}" } +shader_fn_attribute = { "@" ~ ASCII_ALPHA+ ~ ("(" ~ ident ~ ")")? ~ (NEWLINE? ~ ws*) } shader_fn_def = { - "fn" ~ ws ~ shader_ident ~ shader_fn_args ~ ws ~ "->" ~ ws ~ shader_type ~ ws ~ shader_fn_body + shader_fn_attribute? ~ + "fn" ~ ws ~ ident ~ shader_fn_args ~ ws* ~ "->" ~ ws* ~ shader_fn_attribute? ~ type ~ ws* ~ shader_fn_body } // a line of shader code, including white space @@ -81,4 +93,6 @@ ws = _{ " " | "\t" } // capturing white space cws = { " " | "\t" } -newline = { "\n" | "\r\n" | "\r" } \ No newline at end of file +newline = { "\n" | "\r\n" | "\r" } +newline_or_ws = _{ (NEWLINE ~ ws*) | ws* } +number = { NUMBER } \ No newline at end of file