diff --git a/.vscode/launch.json b/.vscode/launch.json index 565c275..2196853 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -7,15 +7,15 @@ { "type": "lldb", "request": "launch", - "name": "Debug shader_prepoc", + "name": "Debug wgsl_preprocessor", "cargo": { "args": [ "build", //"--manifest-path", "${workspaceFolder}/examples/testbed/Cargo.toml" - "--bin=shader_prepoc", + "--bin=wgsl_preprocessor", ], "filter": { - "name": "shader_prepoc", + "name": "wgsl_preprocessor", "kind": "bin" } }, @@ -25,33 +25,55 @@ { "type": "lldb", "request": "launch", - "name": "Debug unit tests in executable 'shader_prepoc'", + "name": "Debug unit tests in executable 'wgsl_preprocessor'", "cargo": { "args": [ "test", "--no-run", - "--bin=shader_prepoc", - "--package=shader_prepoc" + "--bin=wgsl_preprocessor", + "--package=wgsl_preprocessor" ], "filter": { - "name": "shader_prepoc", + "name": "wgsl_preprocessor", "kind": "bin" } }, "args": [], "cwd": "${workspaceFolder}" }, + { + "type": "lldb", + "request": "launch", + "name": "Debug specific test in executable 'wgsl_preprocessor'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=wgsl_preprocessor", + "tests::import_struct_with_imported_fields", + "--", + "--exact --nocapture" + ], + "filter": { + "name": "wgsl_preprocessor", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, /* { "type": "lldb", "request": "launch", - "name": "Test shader_prepoc", + "name": "Test wgsl_preprocessor", "cargo": { "args": [ "test", - "--bin=shader_prepoc", + "--bin=wgsl_preprocessor", ], "filter": { - "name": "shader_prepoc", + "name": "wgsl_preprocessor", "kind": "bin" } }, diff --git a/src/lib.rs b/src/lib.rs index 3afc903..a2dc40c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,13 +19,13 @@ pub enum ExternalUsageType { Function, } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Import { module: String, imports: Vec, } -#[derive(Clone)] +#[derive(Clone, PartialEq, Eq, Hash)] pub struct DefRequirement { /// None if the requirement is local module: Option, @@ -34,6 +34,7 @@ pub struct DefRequirement { #[derive(Clone)] pub struct Definition { + #[allow(dead_code)] name: String, requirements: Vec, /// The start byte position as a `usize`. @@ -50,6 +51,8 @@ pub struct Module { src: String, /// Constants that this module defines pub vars: HashMap, + pub structs: HashMap, + pub aliases: HashMap, /// Functions that this module defines pub functions: HashMap, /// Imports of things per module @@ -61,6 +64,16 @@ pub struct Module { module_imports: HashSet, } +impl Module { + /// Get a definition of any type from this module. + pub fn get_definition(&self, name: &str) -> Option<&Definition> { + self.functions.get(name) + .or_else(|| self.vars.get(name)) + .or_else(|| self.structs.get(name)) + .or_else(|| self.aliases.get(name)) + } +} + /// Recursively find files in `path`. pub(crate) fn recurse_files(path: impl AsRef) -> std::io::Result> { let mut buf = vec![]; @@ -137,6 +150,12 @@ fn main() -> vec4 { /// i.e., `simple::some_const`. #[test] fn double_layer_import_indirect_imports() { + /* tracing_subscriber::fmt() + // enable everything + .with_max_level(tracing::Level::TRACE) + // sets this to be the default, global collector for this application. + .init(); */ + let mut p = Processor::new(); p.parse_module(&INNER_MODULE).unwrap() .expect("failed to find module"); @@ -150,7 +169,7 @@ fn main() -> vec4 { assert!(out.contains("const simple_scalar"), "definition of imported `const simple_scalar` is missing!"); assert!(out.contains("fn do_something_cool("), "definition of imported `fn do_something_cool` is missing!"); - assert!(out.contains("const scalar"), "definition of imported dependency, `const scalar` is missing!"); + assert!(out.contains("const scalar"), "definition of imported dependency, `const scalar` is missing! {}", out); assert!(out.contains("fn mult_some_nums("), "definition of imported dependency, `fn mult_some_nums` is missing!"); assert!(out.contains("fn do_something_cool("), "definition of imported `fn do_something_cool` is missing!"); @@ -180,11 +199,11 @@ fn main() -> vec4 { return vec4(1.0); }"#; - tracing_subscriber::fmt() + /* tracing_subscriber::fmt() // enable everything .with_max_level(tracing::Level::TRACE) // sets this to be the default, global collector for this application. - .init(); + .init(); */ let mut p = Processor::new(); p.parse_module(BINDINGS_MODULE).unwrap() @@ -221,11 +240,11 @@ fn main() -> vec4 { return vec4(1.0); }"#; - tracing_subscriber::fmt() + /* tracing_subscriber::fmt() // enable everything .with_max_level(tracing::Level::TRACE) // sets this to be the default, global collector for this application. - .init(); + .init(); */ let mut p = Processor::new(); p.parse_module(BINDINGS_MODULE).unwrap() @@ -271,11 +290,11 @@ fn main() -> vec4 { return vec4(1.0); }"#; - tracing_subscriber::fmt() + /* tracing_subscriber::fmt() // enable everything .with_max_level(tracing::Level::TRACE) // sets this to be the default, global collector for this application. - .init(); + .init(); */ let mut p = Processor::new(); p.parse_module(BINDINGS_MODULE).unwrap() @@ -293,4 +312,171 @@ fn main() -> vec4 { assert!(out.contains("var light_indices"), "missing light_indices binding: {}", out); assert!(out.contains("var some_buffer"), "missing some_buffer binding: {}", out); } -} \ No newline at end of file + + #[test] + fn import_struct() { + const BINDINGS_MODULE: &'static str = +r#"#define_module engine::lights + +struct Light { + intensity: f32 +}"#; + + const IMPORTING_MODULE: &'static str = +r#"#define_module engine::main +#import engine::lights::{Light} + +fn main() -> vec4 { + // imports must be used to be generated + let base_light = Light(1.0); + + return vec4(1.0); +}"#; + + /* tracing_subscriber::fmt() + // enable everything + .with_max_level(tracing::Level::TRACE) + // sets this to be the default, global collector for this application. + .init(); */ + + let mut p = Processor::new(); + p.parse_module(BINDINGS_MODULE).unwrap() + .expect("failed to find bindings module def"); + let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap() + .expect("failed to find main module def"); + + let lights_mod = p.modules.get("engine::lights").unwrap(); + assert!(lights_mod.structs.contains_key("Light"), "`Light` struct definition was not parsed"); + + let out = p.process_file(&importing_mod, &IMPORTING_MODULE).unwrap(); + assert!(out.contains("struct Light"), "missing `Light` struct definition: {}", out); + } + + #[test] + fn import_struct_with_imported_fields() { + const TYPE_MODULE: &'static str = +r#"#define_module engine::types + +struct Something { + val: f32 +}"#; + + const BINDINGS_MODULE: &'static str = +r#"#define_module engine::lights +#import engine::types::{Something} + +struct Light { + some: Something, + intensity: f32 +}"#; + + const IMPORTING_MODULE: &'static str = +r#"#define_module engine::main +#import engine::lights::{Light} + +fn main() -> vec4 { + // imports must be used to be generated + let base_light = Light(1.0); + + return vec4(1.0); +}"#; + + /* tracing_subscriber::fmt() + // enable everything + .with_max_level(tracing::Level::TRACE) + // sets this to be the default, global collector for this application. + .init(); */ + + let mut p = Processor::new(); + p.parse_module(TYPE_MODULE).unwrap() + .expect("failed to find bindings module def"); + p.parse_module(BINDINGS_MODULE).unwrap() + .expect("failed to find bindings module def"); + let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap() + .expect("failed to find main module def"); + + let lights_mod = p.modules.get("engine::lights").unwrap(); + assert!(lights_mod.structs.contains_key("Light"), "`Light` struct definition was not parsed"); + + let out = p.process_file(&importing_mod, &IMPORTING_MODULE).unwrap(); + assert!(out.contains("struct Light"), "missing `Light` struct definition: {}", out); + assert!(out.contains("struct Something"), "missing `Something` struct definition: {}", out); + } + + #[test] + fn import_type_alias() { + const TYPE_MODULE: &'static str = +r#"#define_module engine::types + +alias LightIntensity = f32;"#; + + const IMPORTING_MODULE: &'static str = +r#"#define_module engine::main +#import engine::types::{LightIntensity} + +fn main() -> vec4 { + // imports must be used to be generated + let intensity: LightIntensity = 1.0; + + return vec4(1.0); +}"#; + + /* tracing_subscriber::fmt() + // enable everything + .with_max_level(tracing::Level::TRACE) + // sets this to be the default, global collector for this application. + .init(); */ + + let mut p = Processor::new(); + p.parse_module(TYPE_MODULE).unwrap() + .expect("failed to find bindings module def"); + let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap() + .expect("failed to find main module def"); + + let out = p.process_file(&importing_mod, &IMPORTING_MODULE).unwrap(); + assert!(out.contains("alias LightIntensity"), "missing `LightIntensity` type alias: {}", out); + } + + #[test] + fn import_type_alias_with_imports() { + const TYPE_MODULE: &'static str = +r#"#define_module engine::types + +alias Number = f32;"#; + + const LIGHT_MODULE: &'static str = +r#"#define_module engine::light +#import engine::types::{Number} + +alias LightIntensity = Number;"#; + + const IMPORTING_MODULE: &'static str = +r#"#define_module engine::main +#import engine::light::{LightIntensity} + +fn main() -> vec4 { + // imports must be used to be generated + let intensity: LightIntensity = 1.0; + + return vec4(1.0); +}"#; + + /* tracing_subscriber::fmt() + // enable everything + .with_max_level(tracing::Level::TRACE) + // sets this to be the default, global collector for this application. + .init(); */ + + let mut p = Processor::new(); + p.parse_module(TYPE_MODULE).unwrap() + .expect("failed to find bindings module def"); + p.parse_module(LIGHT_MODULE).unwrap() + .expect("failed to find bindings module def"); + let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap() + .expect("failed to find main module def"); + + let out = p.process_file(&importing_mod, &IMPORTING_MODULE).unwrap(); + assert!(out.contains("alias LightIntensity"), "missing `LightIntensity` type alias: {}", out); + assert!(out.contains("alias Number"), "missing `Number` type alias: {}", out); + } +} diff --git a/src/preprocessor.rs b/src/preprocessor.rs index 38ec8ac..657cae4 100644 --- a/src/preprocessor.rs +++ b/src/preprocessor.rs @@ -63,6 +63,46 @@ pub enum PreprocessorError { ConflictingImport { from_module: String, name: String }, } +fn get_external_var_requirement(module: &Module, pair: Pair) -> Option { + let var_name = pair.as_str(); + + if pair.as_rule() == Rule::shader_external_variable { + let (module_name, name) = var_name.rsplit_once("::").unwrap(); + + if RESERVED_WORDS.contains(&name) { + return None; + } + + // Find the module that this variable is from. + // Starts by checking if the full module path was used. + // If it was not, then find the full path of the module. + let module_full_name = if module.module_imports.contains(module_name) { + Some(module_name.to_string()) + } else { + module.module_imports.iter() + .find(|m| { + m.ends_with(module_name) + }).cloned() + }; + + debug!("Binding module is {:?}", module_full_name); + + Some(DefRequirement { + module: module_full_name, + name: name.to_string(), + }) + } else { // The only possibility is `Rule::shader_type` + if RESERVED_WORDS.contains(&var_name) { + return None; + } + + Some(DefRequirement { + module: None, + name: var_name.to_string(), + }) + } +} + #[derive(Default)] pub struct Processor { pub modules: HashMap, @@ -73,28 +113,6 @@ impl Processor { Self::default() } - // - /* fn add_func_requirements(&self, found_requirements: &mut HashSet, requirements: &mut Vec, fn_name: &str) { - if found_requirements.contains(fn_name) { - return; - } - found_requirements.insert(fn_name.to_string()); - - debug!("Found call to `{}`", fn_name); - - // ignore reserved words - if RESERVED_WORDS.contains(&fn_name) { - return; - } - - let req = DefRequirement { - // module is discovered later - module: None, - name: fn_name.to_string(), - }; - requirements.push(req); - } */ - #[instrument(fields(module = module.name), skip_all)] fn get_imports_in_block(&mut self, module: &mut Module, block: Pair, found_requirements: &mut HashSet) -> Vec { let mut requirements = vec![]; @@ -407,6 +425,67 @@ impl Processor { _ => unimplemented!("ran into unhandled rule: {:?}", line.as_span()) } }, + Rule::shader_struct_def => { + let mut struct_inner = line.clone().into_inner(); + let struct_name = struct_inner.next().unwrap(); + + // iterate through fields in struct + let mut requirements = vec![]; + while let (Some(_), Some(field_ty)) = (struct_inner.next(), struct_inner.next()) { + //panic!("got {} of type {}", ident.as_str(), field_ty.as_str()); + let mut ty_inner = field_ty.into_inner(); + let ty_inner = ty_inner.next().unwrap(); + + match ty_inner.as_rule() { + Rule::shader_external_variable | Rule::shader_type => { + //requirements.push(value) + if let Some(req) = get_external_var_requirement(&module, ty_inner) { + requirements.push(req); + } + }, + _ => unimplemented!("ran into unhandled rule: ({:?}) {:?}", line.as_rule(), line.as_span()) + } + } + + let sname = struct_name.as_str().to_string(); + let line_span = line.as_span(); + let start_pos = line_span.start(); + let end_pos = line_span.end(); + + module.structs.insert(sname.clone(), Definition { + name: sname, + requirements, + start_pos, + end_pos, + }); + }, + Rule::shader_type_alias_def => { + let mut inner = line.clone().into_inner(); + + // get rules + let alias_name = inner.next().unwrap(); + let alias_to = inner.next().unwrap(); + + // get the requirements of the alias if there are any + let requirements = if let Some(req) = get_external_var_requirement(&module, alias_to) { + debug!("Type req: {}", req.name); + vec![req] + } else { vec![] }; + + let aname = alias_name.as_str().to_string(); + debug!("Type alias `{}`", aname); + + let line_span = line.as_span(); + let start_pos = line_span.start(); + let end_pos = line_span.end(); + + module.aliases.insert(aname.clone(), Definition { + name: aname, + requirements, + start_pos, + end_pos, + }); + }, Rule::cws => (), Rule::newline => (), _ => unimplemented!("ran into unhandled rule: ({:?}) {:?}", line.as_rule(), line.as_span()) @@ -466,7 +545,8 @@ impl Processor { let module = self.modules.get(module_path).unwrap(); let mut output = String::new(); - compile_definitions(&self.modules, module, &mut output)?; + let mut compiled = HashSet::new(); + compile_definitions(&self.modules, module, &mut compiled, &mut output)?; Ok(output) } @@ -578,8 +658,14 @@ impl Processor { } fn try_find_requirement_module(module: &Module, req_name: &str) -> Option { + debug!("looking for '{}' in module '{}' that imports: {:?}", req_name, module.name, module.item_imports.values()); + //debug!("Module: {}, req: {}", module.name, req_name); + for import in module.item_imports.values() { + //debug!("Import: {:?}", import); + if import.imports.contains(&req_name.to_string()) { + //debug!("has it"); return Some(import.module.clone()); } } @@ -587,35 +673,81 @@ fn try_find_requirement_module(module: &Module, req_name: &str) -> Option, module: &Module, output: &mut String) -> Result<(), PreprocessorError> { - for (_, funcs) in &module.functions { +//#[instrument(fields(module = module.name, require=tracing::field::Empty), skip_all)] +fn compile_definitions(modules: &HashMap, module: &Module, compiled_definitions: &mut HashSet, output: &mut String) -> Result<(), PreprocessorError> { + let e = debug_span!("compile_definitions", module = module.name, require = tracing::field::Empty, sub_require = tracing::field::Empty).entered(); + + let defs = module.functions.iter().chain(module.vars.iter()).chain(module.structs.iter()).chain(module.aliases.iter()); + for (_, funcs) in defs { //&module.functions { let mut requirements = VecDeque::from(funcs.requirements.clone()); while let Some(mut req) = requirements.pop_front() { + e.record("require", req.name.clone()); + if req.module.is_none() { let mod_name = try_find_requirement_module(&module, &req.name); req.module = mod_name; + debug!("req module: {:?}", req.module); } if let Some(module_name) = &req.module { let req_module = modules.get(module_name) .unwrap_or_else(|| panic!("invalid module import: {}", module_name)); - let req_def = req_module.functions.get(&req.name) - .or_else(|| req_module.vars.get(&req.name)) + // get the definition from the module that defines the import + let req_def = req_module.get_definition(&req.name) .unwrap_or_else(|| panic!("invalid import: {} from {}", req.name, module_name)); if !req_def.requirements.is_empty() { - let sub_req_names: Vec = req_def.requirements.iter().map(|r| r.name.clone()).collect(); - debug!("Found requirement: {}, with the following sub-requirements: {:?}", req_def.name, sub_req_names); + for mut sub in req_def.requirements.clone() { + e.record("sub_require", sub.name.clone()); + + println!("searching for sub requirement module in module: '{}'", req_module.name); + + if sub.module.is_none() { + sub.module = try_find_requirement_module(&req_module, &sub.name); + } + + if let Some(sub_mod) = &sub.module { + if !compiled_definitions.contains(&sub) { + compiled_definitions.insert(sub.clone()); + + let sub_mod = modules.get(sub_mod).expect("unknown module from import"); + debug!("found sub requirement module: {}", sub_mod.name); + let module_name = &sub_mod.name; + + let mut requirements_output = String::new(); + compile_definitions(modules, sub_mod, compiled_definitions, &mut requirements_output)?; + + output.write_fmt(format_args!("\n// REQUIREMENTS OF {}::{}\n", module_name, req.name))?; + output.push_str(&requirements_output); + output.push_str("\n"); + + // find the definition of the sub requirement + let sub_def = sub_mod.get_definition(&sub.name) + .unwrap_or_else(|| panic!("invalid import: {} from {}", req.name, module_name)); + + // write the source of the sub requirement to the output + let req_src = &sub_mod.src[sub_def.start_pos..sub_def.end_pos]; + output.write_fmt(format_args!("// SOURCE {}::{}\n", module_name, sub.name))?; + output.push_str(req_src); + output.push_str("\n"); + } else { + debug!("Requirement module is already compiled, skipping..."); + } + } else { + debug!("Requirement has no module, assuming its local..."); + } + } + /* let sub_req_names: Vec = req_def.requirements.iter().map(|r| r.name.clone()).collect(); + debug!("Found requirement: {}, with the following sub-requirements: {:?} from {}", req_def.name, sub_req_names, req_module.name); let mut requirements_output = String::new(); compile_definitions(modules, req_module, &mut requirements_output)?; output.write_fmt(format_args!("\n// REQUIREMENTS OF {}::{}\n", module_name, req.name))?; output.push_str(&requirements_output); - output.push_str("\n"); + output.push_str("\n"); */ } let func_src = &req_module.src[req_def.start_pos..req_def.end_pos]; diff --git a/src/wgsl.pest b/src/wgsl.pest index 392e697..beeb1d4 100644 --- a/src/wgsl.pest +++ b/src/wgsl.pest @@ -37,6 +37,14 @@ shader_value = { shader_value_bool | shader_value_num } shader_var_type = { ":" ~ ws* ~ (shader_external_variable | shader_type) } shader_const_def = { "const" ~ ws ~ shader_ident ~ (ws* ~ shader_var_type)? ~ ws* ~ "=" ~ ws* ~ shader_value ~ ";" } +shader_type_alias_def = { ("alias" | "type") ~ ws ~ shader_ident ~ ws* ~ "=" ~ ws* ~ (shader_external_variable | shader_type) ~ ";" } +shader_struct_def = { + "struct" ~ ws ~ shader_ident ~ ws* ~ "{" ~ NEWLINE ~ + ws* ~ shader_ident ~ shader_var_type ~ + ("," ~ NEWLINE ~ ws* ~ shader_ident ~ shader_var_type)* ~ + ","? ~ NEWLINE? ~ + "}" +} shader_group_binding = { "@group(" ~ NUMBER ~ ") @binding(" ~ NUMBER ~ ")" } shader_binding_var_constraint = { @@ -58,8 +66,10 @@ shader_fn_def = { // a line of shader code, including white space shader_code_line = { shader_fn_def | + (shader_struct_def ~ newline?) | (shader_const_def ~ newline?) | (shader_binding_def ~ newline?) | + (shader_type_alias_def ~ newline?) | ws* ~ newline } //shader_code_line = { shader_fn_def | shader_const_def | (ws* ~ (shader_external_code | shader_code)* ~ ws*) }