diff --git a/Cargo.toml b/Cargo.toml index ae57867..ce9554f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,4 +9,6 @@ pest_derive = "2.7.11" regex = "1.10.6" thiserror = "1.0.63" tracing = "0.1.40" + +[dev-dependencies] tracing-subscriber = "0.3.18" diff --git a/README.md b/README.md index 7b645c9..83798d0 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ This crate was created for my 3d game engine, Lyra Engine, which uses this as a ## Features -* Modules import other modules via defined module paths +* Import constants, functions, and bindings from other modules More to come, check out issues in this repo. ## How-To diff --git a/src/lib.rs b/src/lib.rs index 3784983..3afc903 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,7 +49,7 @@ pub struct Module { /// The source code of the module, non-processed. src: String, /// Constants that this module defines - pub constants: HashMap, + pub vars: HashMap, /// Functions that this module defines pub functions: HashMap, /// Imports of things per module @@ -158,4 +158,139 @@ fn main() -> vec4 { assert!(out.contains("simple_scalar * do_something_cool"), "indirect imports of `simple_scalar * do_something_cool` is missing!"); assert!(out.contains("fn main("), "definition of `fn main` is missing!"); } + + #[test] + fn binding_imports() { + const BINDINGS_MODULE: &'static str = +r#"#define_module engine::bindings + +@group(0) @binding(0) var shadows_atlas: texture_depth_2d; +// this type is undefined, but that is fine for a test +@group(0) @binding(1) var lights: Lights;"#; + + const IMPORTING_MODULE: &'static str = +r#"#define_module engine::main +#import engine::bindings::{shadows_atlas, lights} + +fn main() -> vec4 { + // imports must be used to be generated + let atlas = shadows_atlas; + let num = lights.num; + + return vec4(1.0); +}"#; + + tracing_subscriber::fmt() + // enable everything + .with_max_level(tracing::Level::TRACE) + // sets this to be the default, global collector for this application. + .init(); + + let mut p = Processor::new(); + p.parse_module(BINDINGS_MODULE).unwrap() + .expect("failed to find bindings module def"); + let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap() + .expect("failed to find main module def"); + + let out = p.process_file(&importing_mod, &IMPORTING_MODULE).unwrap(); + + assert!(out.contains("var shadows_atlas"), "missing shadows_atlas binding: {}", out); + assert!(out.contains("var lights"), "missing lights binding: {}", out); + } + + #[test] + fn binding_imports_multi_line() { + const BINDINGS_MODULE: &'static str = +r#"#define_module engine::bindings + +@group(0) @binding(0) +var shadows_atlas: texture_depth_2d; +// this type is undefined, but that is fine for a test +@group(0) @binding(1) +var lights: Lights;"#; + + const IMPORTING_MODULE: &'static str = +r#"#define_module engine::main +#import engine::bindings::{shadows_atlas, lights} + +fn main() -> vec4 { + // imports must be used to be generated + let atlas = shadows_atlas; + let num = lights.num; + + return vec4(1.0); +}"#; + + tracing_subscriber::fmt() + // enable everything + .with_max_level(tracing::Level::TRACE) + // sets this to be the default, global collector for this application. + .init(); + + let mut p = Processor::new(); + p.parse_module(BINDINGS_MODULE).unwrap() + .expect("failed to find bindings module def"); + let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap() + .expect("failed to find main module def"); + + let out = p.process_file(&importing_mod, &IMPORTING_MODULE).unwrap(); + + assert!(out.contains("var shadows_atlas"), "missing shadows_atlas binding: {}", out); + assert!(out.contains("var lights"), "missing lights binding: {}", out); + } + + // ensure bindings of all constraints are parsed + #[test] + fn binding_imports_all_constraints() { + const BINDINGS_MODULE: &'static str = +r#"#define_module engine::bindings + +@group(0) @binding(0) var use_anyway: texture_depth_2d; +@group(0) @binding(1) var shadows_atlas: texture_depth_2d; +// this type is undefined, but that is fine for a test +@group(0) @binding(2) var lights: Lights; +@group(0) @binding(3) var light_grid: LightGrid; +@group(0) @binding(4) var some_storage: array; +@group(0) @binding(5) var light_indices: array; +@group(0) @binding(6) var some_buffer: array;"#; + + const IMPORTING_MODULE: &'static str = +r#"#define_module engine::main +#import engine::bindings::{use_anyway, shadows_atlas, lights, light_grid, some_storage, light_indices, some_buffer} + +fn main() -> vec4 { + // imports must be used to be generated + let anyw = use_anyway; + let atlas = shadows_atlas; + let num = lights.num; + let grid = light_grid; + let storage = some_storage; + let indices = light_indices; + let buf = some_buffer; + + return vec4(1.0); +}"#; + + tracing_subscriber::fmt() + // enable everything + .with_max_level(tracing::Level::TRACE) + // sets this to be the default, global collector for this application. + .init(); + + let mut p = Processor::new(); + p.parse_module(BINDINGS_MODULE).unwrap() + .expect("failed to find bindings module def"); + let importing_mod = p.parse_module(IMPORTING_MODULE).unwrap() + .expect("failed to find main module def"); + + let out = p.process_file(&importing_mod, &IMPORTING_MODULE).unwrap(); + + assert!(out.contains("var use_anyway"), "missing use_anyway binding: {}", out); + assert!(out.contains("var shadows_atlas"), "missing shadows_atlas binding: {}", out); + assert!(out.contains("var lights"), "missing lights binding: {}", out); + assert!(out.contains("var light_grid"), "missing light_grid binding: {}", out); + assert!(out.contains("var some_storage"), "missing some_storage binding: {}", out); + assert!(out.contains("var light_indices"), "missing light_indices binding: {}", out); + assert!(out.contains("var some_buffer"), "missing some_buffer binding: {}", out); + } } \ No newline at end of file diff --git a/src/preprocessor.rs b/src/preprocessor.rs index 3fd0bda..38ec8ac 100644 --- a/src/preprocessor.rs +++ b/src/preprocessor.rs @@ -2,7 +2,7 @@ use std::{collections::{HashMap, HashSet, VecDeque}, fmt::Write, fs, path::Path} use pest::{iterators::Pair, Parser}; use regex::Regex; -use tracing::{debug, debug_span, instrument}; +use tracing::{debug, debug_span, instrument, trace}; use crate::{recurse_files, DefRequirement, Definition, Import, Module, Rule, WgslParser}; @@ -295,7 +295,80 @@ impl Processor { requirements }, ); - } + }, + Rule::shader_binding_def => { + let mut pairs = line.clone().into_inner(); + let _group_binding = pairs.next(); + + // the next pair could be the constraint of the binding, or + // the name. We don't need the constraint, so skip it if its there. + let mut name = pairs.next().unwrap(); + if name.as_rule() == Rule::shader_binding_var_constraint { + trace!("Skipping binding constraint"); + name = pairs.next().unwrap(); + assert_eq!(name.as_rule(), Rule::shader_ident); + } + let name = name.as_str(); + + let var_type = pairs.next().unwrap(); + + let mut requirements = vec![]; + { + let mut vt_inner = var_type.into_inner(); + let inner_type = vt_inner.next().unwrap(); + + if inner_type.as_rule() == Rule::shader_external_variable { + let external_str = inner_type.as_str(); + let (module_name, name) = external_str.rsplit_once("::").unwrap(); + + if RESERVED_WORDS.contains(&name) { + continue; + } + + // Find the module that this variable is from. + // Starts by checking if the full module path was used. + // If it was not, then find the full path of the module. + let module_full_name = if module.module_imports.contains(module_name) { + Some(module_name.to_string()) + } else { + module.module_imports.iter() + .find(|m| { + m.ends_with(module_name) + }).cloned() + }; + + debug!("Binding module is {:?}", module_full_name); + + requirements.push(DefRequirement { + module: module_full_name, + name: name.to_string(), + }); + } else { // The only possibility is `Rule::shader_type` + let name = inner_type.as_str(); + + requirements.push(DefRequirement { + module: None, + name: name.to_string(), + }); + } + } + + debug!("Found binding def: `{name}`"); + + let line_span = line.as_span(); + let start_pos = line_span.start(); + let end_pos = line_span.end(); + + module.vars.insert( + name.to_string(), + Definition { + name: name.to_string(), + start_pos, + end_pos, + requirements: vec![], + }, + ); + }, Rule::shader_const_def => { let mut pairs = line.clone().into_inner(); // shader_ident is the only pair for this rule @@ -306,7 +379,7 @@ impl Processor { let start_pos = line_span.start(); let end_pos = line_span.end(); - module.constants.insert( + module.vars.insert( const_name.clone(), Definition { name: const_name, @@ -315,8 +388,7 @@ impl Processor { requirements: vec![], }, ); - } - + }, Rule::shader_code => { let mut shader_inner = line.clone().into_inner(); let code_type = shader_inner.next().unwrap(); @@ -334,7 +406,7 @@ impl Processor { }, _ => unimplemented!("ran into unhandled rule: {:?}", line.as_span()) } - } + }, Rule::cws => (), Rule::newline => (), _ => unimplemented!("ran into unhandled rule: ({:?}) {:?}", line.as_rule(), line.as_span()) @@ -531,7 +603,7 @@ fn compile_definitions(modules: &HashMap, module: &Module, outpu .unwrap_or_else(|| panic!("invalid module import: {}", module_name)); let req_def = req_module.functions.get(&req.name) - .or_else(|| req_module.constants.get(&req.name)) + .or_else(|| req_module.vars.get(&req.name)) .unwrap_or_else(|| panic!("invalid import: {} from {}", req.name, module_name)); if !req_def.requirements.is_empty() { diff --git a/src/wgsl.pest b/src/wgsl.pest index d7150b2..392e697 100644 --- a/src/wgsl.pest +++ b/src/wgsl.pest @@ -1,16 +1,12 @@ shader_ident = { (ASCII_ALPHANUMERIC | "_")+ } -// a shader generic could have multiple generics, i.e., vec4 -//shader_generic_type = { shader_ident ~ "<" ~ shader_ident ~ ">"+ } -//shader_type = { shader_generic_type | shader_ident } shader_type = { shader_ident ~ ("<" ~ shader_ident ~ ">")? } shader_module = { shader_ident ~ ( "::" ~ shader_ident)* } import_module_command = { "import" ~ ws ~ shader_module } -import_list = _{ "{" ~ shader_ident ~ (ws* ~ "," ~ ws* ~ shader_ident)* ~ "}" } +import_list = _{ "{" ~ shader_ident ~ (ws* ~ "," ~ NEWLINE? ~ ws* ~ shader_ident)* ~ "}" } import_types_command = { import_module_command ~ "::" ~ import_list } -//import_types_command = { "import" ~ shader_ident ~ ( "::" ~ shader_ident)* ~ "::" ~ import_list } import_command = _{ import_types_command | import_module_command } define_module_command = { "define_module" ~ ws ~ shader_module } @@ -30,7 +26,6 @@ shader_code = { shader_code_block | shader_code_fn_usage | shader_value | shader // usages of code from another module shader_external_variable = { shader_ident ~ ( "::" ~ shader_ident)+ } -//shader_fn_args2 = { shader_var_name_type? ~ (ws* ~ "," ~ ws* ~ shader_var_name_type)* } shader_external_fn = { shader_external_variable ~ "(" ~ shader_code_fn_arg ~ ("," ~ ws* ~ shader_code_fn_arg)* ~ ")" } shader_external_code = _{ shader_external_fn | shader_external_variable } @@ -40,9 +35,17 @@ shader_value_num = { ASCII_DIGIT+ ~ ( "." ~ ASCII_DIGIT+ )? } shader_value_bool = { "true" | "false" } shader_value = { shader_value_bool | shader_value_num } -// defines type of something i.e., `: f32`, `: u32`, etc. -shader_var_type = { ":" ~ ws* ~ shader_type } +shader_var_type = { ":" ~ ws* ~ (shader_external_variable | shader_type) } shader_const_def = { "const" ~ ws ~ shader_ident ~ (ws* ~ shader_var_type)? ~ ws* ~ "=" ~ ws* ~ shader_value ~ ";" } + +shader_group_binding = { "@group(" ~ NUMBER ~ ") @binding(" ~ NUMBER ~ ")" } +shader_binding_var_constraint = { + "<" ~ + ( "uniform" | "private" | "workgroup" | ("storage" ~ ("," ~ ws? ~ "read" ~ "_write"? )?) ) ~ + ">" +} +shader_binding_def = { shader_group_binding ~ (NEWLINE | ws)? ~ "var" ~ shader_binding_var_constraint? ~ ws ~ shader_ident ~ ws* ~ shader_var_type ~ ";" } + shader_var_name_type = { shader_ident ~ shader_var_type } shader_fn_args = { "(" ~ shader_var_name_type? ~ (ws* ~ "," ~ ws* ~ shader_var_name_type)* ~ ")" } // the body of a function, including the opening and closing brackets @@ -55,7 +58,8 @@ shader_fn_def = { // a line of shader code, including white space shader_code_line = { shader_fn_def | - (shader_const_def ~ newline) | + (shader_const_def ~ newline?) | + (shader_binding_def ~ newline?) | ws* ~ newline } //shader_code_line = { shader_fn_def | shader_const_def | (ws* ~ (shader_external_code | shader_code)* ~ ws*) }