From e85adce6e3b08c344ef305b34bc421f9f8d0fdf0 Mon Sep 17 00:00:00 2001 From: SeanOMik Date: Sat, 3 Aug 2024 16:48:02 -0400 Subject: [PATCH] output includes recursively, find dependencies of imported functions and output those as well --- shaders/base.wgsl | 2 +- src/compiler.rs | 35 +++ src/main.rs | 464 +++++---------------------------------- src/preprocessor.rs | 512 ++++++++++++++++++++++++++++++++++++++++++++ src/wgsl.pest | 39 ++-- 5 files changed, 625 insertions(+), 427 deletions(-) create mode 100644 src/compiler.rs create mode 100644 src/preprocessor.rs diff --git a/shaders/base.wgsl b/shaders/base.wgsl index e11535e..746eb84 100644 --- a/shaders/base.wgsl +++ b/shaders/base.wgsl @@ -1,5 +1,5 @@ #define_module base -#import simple +#import simple::{do_something_cool} fn main() -> vec4 { let a = do_something_cool(10.0); diff --git a/src/compiler.rs b/src/compiler.rs new file mode 100644 index 0000000..7595ab9 --- /dev/null +++ b/src/compiler.rs @@ -0,0 +1,35 @@ +use std::collections::HashMap; + +use crate::{Definition, Module, PreprocessorError, Processor}; + +/// Compile a module including its imports into a single module. +#[derive(Default)] +pub struct Compiler { + preprocessor: Processor, +} + +impl Compiler { + /// Add a module to the compiler + /// + /// Returns `None` if the module does not define an include identifier. + pub fn add_module(&mut self, module_src: &str) -> Result, PreprocessorError> { + self.preprocessor.parse_module(module_src) + } + + pub fn compile_module(self, module_src: &str) -> Result { + todo!() + } +} + +/* pub struct Source { + +} */ + +pub struct ExpandableModule { + /// The name of the module. + name: String, + /// Constants that this module defines + pub constants: HashMap, + /// Functions that this module defines + pub functions: HashMap, +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 9b00ff0..b15c15e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,28 +9,15 @@ use itertools::Itertools; use pest::Parser; use pest_derive::Parser; +mod preprocessor; +pub use preprocessor::*; + +mod compiler; +pub use compiler::*; + #[derive(Parser)] #[grammar = "wgsl.pest"] -pub struct WgslParser; - -#[derive(Debug, thiserror::Error)] -pub enum PreprocessorError { - #[error("{0}")] - IoError(#[from] std::io::Error), - #[error("error parsing {path}: {err}")] - ParserError { - path: PathBuf, - err: pest::error::Error, - }, - #[error("failure formatting preprocessor output to string ({0})")] - FormatError(#[from] std::fmt::Error), - #[error("unknown module import '{module}', in {from_path}")] - UnknownModule { from_path: PathBuf, module: String }, - #[error("in {from_path}: unknown import from '{module}': `{item}`")] - UnknownImport { from_path: PathBuf, module: String, item: String }, - #[error("import usage from `{from_module}` conflicts with local variable/function: `{name}`")] - ConflictingImport { from_module: String, name: String }, -} +pub(crate) struct WgslParser; fn main() { /* let mut successful_parse = WgslParser::parse(Rule::command_line, "#define_module inner::some_include").unwrap(); @@ -40,11 +27,18 @@ fn main() { let mut p = Processor::new(); //let f = p.parse_modules("shaders", ["wgsl"]).unwrap(); //println!("Parsed {} modules:", f); - p.parse_module("shaders/inner_include.wgsl") + let inner_include_src = fs::read_to_string("shaders/inner_include.wgsl").unwrap(); + p.parse_module(&inner_include_src) .unwrap() .expect("failed to find module"); - p.parse_module("shaders/simple.wgsl") + let simple_include_src = fs::read_to_string("shaders/simple.wgsl").unwrap(); + p.parse_module(&simple_include_src) + .unwrap() + .expect("failed to find module"); + + let base_include_src = fs::read_to_string("shaders/base.wgsl").unwrap(); + let base_module_path = p.parse_module(&base_include_src) .unwrap() .expect("failed to find module"); @@ -60,16 +54,20 @@ fn main() { } for (name, def) in &module.functions { - println!(" fn {name}, {}-{}", def.start_pos, def.end_pos); + let requires: Vec = def.requirements.iter().map(|r| { + let pre = r.module.as_ref().map(|m| format!("{m}::")).unwrap_or_default(); + format!("{}{}", pre, r.name) + }).collect(); + println!(" fn {name}, {}-{}. requires: {:?}", def.start_pos, def.end_pos, requires); } println!(" imported modules: {:?}", module.module_imports); - if !module.type_imports.is_empty() { + if !module.item_imports.is_empty() { println!(" type imports:"); } - for (module, usages) in &module.type_imports { + for (module, usages) in &module.item_imports { println!(" {}: {:?}", module, usages.imports); } @@ -89,7 +87,8 @@ fn main() { } } - let out = p.process_file("shaders/simple.wgsl").unwrap(); + let base_include_src = fs::read_to_string("shaders/base.wgsl").unwrap(); + let out = p.process_file(&base_module_path, &base_include_src).unwrap(); fs::write("out.wgsl", out).unwrap(); } @@ -99,6 +98,7 @@ pub enum ExternalUsageType { Function, } +#[derive(Clone)] pub struct ExternalUsage { name: String, ty: ExternalUsageType, @@ -106,412 +106,60 @@ pub struct ExternalUsage { start_pos: usize, } +#[derive(Clone)] pub struct ImportUsage { module: String, imports: Vec, } +#[derive(Clone)] pub struct Import { module: String, imports: Vec, } +#[derive(Clone)] +pub struct DefRequirement { + /// None if the requirement is local + module: Option, + name: String, + ty: ExternalUsageType, +} + +#[derive(Clone)] pub struct Definition { name: String, + requirements: Vec, /// The start byte position as a `usize`. start_pos: usize, /// The end byte position as a `usize`. end_pos: usize, } -#[derive(Default)] +#[derive(Default, Clone)] pub struct Module { + /// The name of the module. name: String, - path: String, - constants: HashMap, - functions: HashMap, - module_imports: HashSet, - type_imports: HashMap, + /// The source code of the module, non-processed. + src: String, + /// Constants that this module defines + pub constants: HashMap, + /// Functions that this module defines + pub functions: HashMap, + /// Imports of things per module + /// ie `other_module::{scalar, do_math_func}` + item_imports: HashMap, + /// usages of imported things + /// ie `other_module::scalar` import_usages: HashMap, -} - -#[derive(Default)] -pub struct Processor { - modules: HashMap, -} - -impl Processor { - pub fn new() -> Self { - Self::default() - } - - /// Parse a module file to attempt to find the include identifier. - /// - /// Returns `None` if the module does not define an include identifier. - pub fn parse_module>( - &mut self, - path: P, - ) -> Result, PreprocessorError> { - let unparsed_file = fs::read_to_string(path.as_ref())?; - - // add a new line to the end of the input to make the grammar happy - //let unparsed_file = format!("{unparsed_file}\n"); - - let file = WgslParser::parse(Rule::file, &unparsed_file) - .map_err(|e| PreprocessorError::ParserError { - path: path.as_ref().to_path_buf(), - err: e, - })? - .next() - .unwrap(); // get and unwrap the `file` rule; never fails - - let mut module = Module::default(); - module.path = path.as_ref().to_str().unwrap().into(); - - for record in file.into_inner() { - match record.as_rule() { - Rule::command_line => { - // the parser has found a preprocessor command, figure out what it is - let mut pairs = record.into_inner(); - let command_line = pairs.next().unwrap(); - - match command_line.as_rule() { - Rule::import_types_command => { - let mut inner = command_line.into_inner(); - let import_module_command = inner.next().unwrap(); - - let mut import_module_command = import_module_command.into_inner(); - let module_name = import_module_command.next().unwrap().as_str(); - - let types: Vec = inner.map(|t| t.as_str().to_string()).collect(); - - println!("found import of types from `{}`: `{:?}`", module_name, types); - - // add these type imports to imports of the module - module.type_imports.entry(module_name.into()) - .or_insert_with(|| Import { - module: module_name.into(), - imports: vec![], - }) - .imports.extend(types.into_iter()); - }, - Rule::import_module_command => { - let mut inner = command_line.into_inner(); - let module_name = inner.next().unwrap().as_str(); - println!("found import of module: {}", module_name); - - module.module_imports.insert(module_name.into()); - }, - Rule::define_module_command => { - let mut shader_file_pairs = command_line.into_inner(); - let shader_file = shader_file_pairs.next().unwrap(); - let shader_file = shader_file.as_str().to_string(); - - module.name = shader_file; - } - _ => unreachable!(), - } - } - Rule::shader_code_line => { - for line in record.into_inner() { - let (pos_line, pos_col) = line.line_col(); - - match line.as_rule() { - Rule::shader_fn_def => { - let mut pairs = line.clone().into_inner(); - // shader_ident is the only pair for this rule - let fn_name = pairs.next().unwrap().as_str().to_string(); - println!("fn: {fn_name:?}"); - - /* let fn_args = pairs.next().unwrap().as_str(); - let ret_type = pairs.next().unwrap().as_str(); - let fn_body_pair = pairs.next().unwrap(); */ - - let line_span = line.as_span(); - let start_pos = line_span.start(); - let end_pos = line_span.end(); - - module.functions.insert( - fn_name.clone(), - Definition { - name: fn_name, - start_pos, - end_pos, - }, - ); - } - Rule::shader_const_def => { - let mut pairs = line.clone().into_inner(); - // shader_ident is the only pair for this rule - let const_name = pairs.next().unwrap().as_str().to_string(); - println!("const: {const_name:?}"); - - let line_span = line.as_span(); - let start_pos = line_span.start(); - let end_pos = line_span.end(); - - module.constants.insert( - const_name.clone(), - Definition { - name: const_name, - start_pos, - end_pos, - }, - ); - } - Rule::shader_external_fn => { - let mut pairs = line.into_inner(); - // shader_external_variable is the only pair for this rule - let ident_name = pairs.next().unwrap().as_str().to_string(); - - println!("external fn: {ident_name}"); - } - Rule::shader_external_variable => { - let pairs = line.into_inner(); - // shader_external_variable is the only pair for this rule - let ident_name = pairs.as_str(); - - println!("external var: {ident_name}"); - } - Rule::shader_code => { - println!("code: {}", line.as_str()); - } - Rule::cws => (), - Rule::newline => (), - _ => unimplemented!("ran into unhandled rule: {:?}", line.as_span()) - } - } - } - Rule::newline => (), - Rule::EOI => (), - _ => unimplemented!("ran into unhandled rule: {:?}", record.as_span()) - } - } - - if module.name.is_empty() { - Ok(None) - } else { - let name = module.name.clone(); - self.modules.insert(name.clone(), module); - - Ok(Some(name)) - } - } - - /// Find files recursively in `path` with an extension in `extensions`, and parse them. - /// - /// For each file that's found, [`Processor::parse_module`] is used to parse them. - /// - /// Parameters: - /// * `path` - The path to search for files in. - /// * `extensions` - The extensions that the discovered files must have. Make sure they have - /// no leading '.' - pub fn parse_modules, const N: usize>( - &mut self, - path: P, - extensions: [&str; N], - ) -> Result { - //debug_assert!(!extension.starts_with("."), "remove leading '.' from extension"); - - let files = recurse_files(path)?; - - let mut parsed = 0; - - for file in files { - if let Some(ext) = file.extension().and_then(|p| p.to_str()) { - if extensions.contains(&ext) { - self.parse_module(file)?; - parsed += 1; - } - } - } - - Ok(parsed) - } - - pub fn process_file>(&mut self, path: P) -> Result { - let unparsed_file = fs::read_to_string(path.as_ref())?; - // add a new line to the end of the input to make the grammar happy - let unparsed_file = format!("{unparsed_file}\n"); - - let mut out_string = String::new(); - - let file = WgslParser::parse(Rule::file, &unparsed_file) - .map_err(|e| PreprocessorError::ParserError { - path: path.as_ref().to_path_buf(), - err: e, - })? - .next() - .unwrap(); // get and unwrap the `file` rule; never fails - - for record in file.into_inner() { - match record.as_rule() { - Rule::command_line => { - // the parser has found a preprocessor command, figure out what it is - let mut pairs = record.into_inner(); - let command_line = pairs.next().unwrap(); - - match command_line.as_rule() { - Rule::import_module_command => { - let mut shader_file_pairs = command_line.into_inner(); - let shader_file = shader_file_pairs.next().unwrap(); - let shader_file = shader_file.as_str(); - - println!("found module import: {}", shader_file); - let imported_mod = self.modules.get(shader_file).ok_or_else(|| { - PreprocessorError::UnknownModule { - from_path: path.as_ref().to_path_buf(), - module: shader_file.into(), - } - })?; - - let included_file = self.process_file(imported_mod.path.clone())?; - - let start_header = - format!("// ==== START OF INCLUDE OF '{}' ====", shader_file); - let end_header = - format!("\n// ==== END OF INCLUDE OF '{}' ====\n", shader_file); - - out_string.write_str(&start_header)?; - out_string.write_str(&included_file)?; - out_string.write_str(&end_header)?; - } - Rule::import_types_command => { - let mut import_command = command_line.into_inner(); - let import_module_command = import_command.next().unwrap(); - let module_path = import_module_command.into_inner().next().unwrap(); - let module_path = module_path.as_str(); - - let importing_from_mod = self.modules.get(module_path).ok_or_else(|| { - PreprocessorError::UnknownModule { - from_path: path.as_ref().to_path_buf(), - module: module_path.into(), - } - })?; - - let module_raw_src = fs::read_to_string(&importing_from_mod.path)?; - - for import in import_command { - let import_ident = import.as_str(); - - let def = importing_from_mod.functions.get(import_ident) - .or_else(|| importing_from_mod.constants.get(import_ident)) - .ok_or_else(|| { - PreprocessorError::UnknownImport { - from_path: path.as_ref().to_path_buf(), - module: module_path.into(), - item: import_ident.to_string(), - } - })?; - - let import_text = &module_raw_src[def.start_pos..def.end_pos]; - //println!("must add:\n{import_text}"); - - out_string.write_fmt(format_args!("// START OF IMPORT ITEM {} FROM {}\n", import_ident, module_path))?; - out_string.write_str(import_text)?; - out_string.write_fmt(format_args!("\n// END OF IMPORT ITEM {} FROM {}\n\n", import_ident, module_path))?; - } - - //todo!(); - - /* importing_from_mod.functions.get() - - let imports: Vec<&str> = - shader_file_pairs.map(|i| i.as_str()).collect(); - - println!("found module import: {}", module_path); - println!("imports: {imports:?}"); - todo!(); */ - /* let imported_mod = self.modules.get(shader_file).ok_or_else(|| { - PreprocessorError::UnknownModule { - from_path: path.as_ref().to_path_buf(), - module: shader_file.into(), - } - })?; */ - } - Rule::define_module_command => (), - _ => unimplemented!("ran into unhandled rule: {:?}", command_line.as_span()), - } - } - Rule::cws => (), - Rule::shader_code_line => { - for line in record.into_inner() { - let (pos_line, pos_col) = line.line_col(); - - match line.as_rule() { - Rule::shader_external_fn => { - let mut pairs = line.into_inner(); - // shader_external_variable is the only pair for this rule - let ident_name = pairs.next().unwrap().as_str().to_string(); - - if let Some((module_name, ident)) = ident_name.rsplit_once("::") { - /* let usage = ExternalUsage { - name: ident.into(), - ty: ExternalUsageType::Function, - line: pos_line, - col: pos_col - }; */ - - out_string.write_str(ident)?; - } else { - // TODO: not really sure how this would get triggered - unimplemented!( - "this function is actually not external, i think" - ); - } - } - Rule::shader_external_variable => { - let pairs = line.into_inner(); - // shader_external_variable is the only pair for this rule - let ident_name = pairs.as_str(); - - if let Some((module_name, ident)) = ident_name.rsplit_once("::") { - /* let usage = ExternalUsage { - name: ident.into(), - ty: ExternalUsageType::Variable, - line: pos_line, - col: pos_col - }; */ - - out_string.write_str(ident)?; - } else { - // TODO: not really sure how this would get triggered - unimplemented!( - "this function is actually not external, i think" - ); - } - } - /* Rule::shader_fn_def => (), - Rule::shader_const_def => (), - Rule::shader_code => { */ - Rule::shader_code | Rule::shader_const_def | Rule::shader_fn_def => { - let input = line.as_str(); - out_string.write_str(&input)?; - } - Rule::cws => { - let input = line.as_str(); - out_string.write_str(&input)?; - } - Rule::newline => (), - _ => unimplemented!("ran into unhandled rule: {:?}", line.as_span()), - } - } - } - Rule::newline => { - let input = record.as_str(); - out_string.write_str(&input)?; - } - Rule::EOI => (), - _ => unimplemented!("ran into unhandled rule: {:?}", record.as_span()), - } - } - - Ok(out_string) - } + /// Imports of modules + /// + /// These modules are used along side `import_usages` + module_imports: HashSet, } /// Recursively find files in `path`. -fn recurse_files(path: impl AsRef) -> std::io::Result> { +pub(crate) fn recurse_files(path: impl AsRef) -> std::io::Result> { let mut buf = vec![]; let entries = fs::read_dir(path)?; diff --git a/src/preprocessor.rs b/src/preprocessor.rs new file mode 100644 index 0000000..c5e3a3e --- /dev/null +++ b/src/preprocessor.rs @@ -0,0 +1,512 @@ +use std::{collections::{HashMap, HashSet, VecDeque}, fmt::Write, fs, path::{Path, PathBuf}}; + +use pest::{iterators::Pair, Parser}; + +use crate::{recurse_files, DefRequirement, Definition, ExternalUsageType, Import, Module, Rule, WgslParser}; + +const RESERVED_WORDS: [&str; 171] = [ + "NULL", "Self", "abstract", "active", "alignas", "alignof", "as", "asm", "asm_fragment", + "async", "attribute", "auto", "await", "become", "binding_array", "cast", "catch", "class", + "co_await", "co_return", "co_yield", "coherent", "column_major", "common", "compile", + "compile_fragment", "concept", "const_cast", "consteval", "constexpr", "constinit", "crate", + "debugger", "decltype", "delete", "demote", "demote_to_helper", "do", "dynamic_cast", "enum", + "explicit", "export", "extends", "extern", "external", "fallthrough", "filter", "final", + "finally", "friend", "from", "fxgroup", "get", "goto", "groupshared", "highp", "impl", + "implements", "import", "inline", "instanceof", "interface", "layout", "lowp", "macro", + "macro_rules", "match", "mediump", "meta", "mod", "module", "move", "mut", "mutable", + "namespace", "new", "nil", "noexcept", "noinline", "nointerpolation", "noperspective", "null", + "nullptr", "of", "operator", "package", "packoffset", "partition", "pass", "patch", + "pixelfragment", "precise", "precision", "premerge", "priv", "protected", "pub", "public", + "readonly", "ref", "regardless", "register", "reinterpret_cast", "require", "resource", + "restrict", "self", "set", "shared", "sizeof", "smooth", "snorm", "static", "static_assert", + "static_cast", "std", "subroutine", "super", "target", "template", "this", "thread_local", + "throw", "trait", "try", "type", "typedef", "typeid", "typename", "typeof", "union", "unless", + "unorm", "unsafe", "unsized", "use", "using", "varying", "virtual", "volatile", "wgsl", + "where", "with", "writeonly", "yield", "alias", "break", "case", "const", "const_assert", + "continue", "continuing", "default", "diagnostic", "discard", "else", "enable", "false", "fn", + "for", "if", "let", "loop", "override", "requires", "return", "struct", "switch", "true", + "var", "while" +]; + +#[derive(Debug, thiserror::Error)] +pub enum PreprocessorError { + #[error("{0}")] + IoError(#[from] std::io::Error), + #[error("error parsing {0}")] + ParserError(#[from] pest::error::Error), + #[error("failure formatting preprocessor output to string ({0})")] + FormatError(#[from] std::fmt::Error), + #[error("unknown module import '{module}', in {from_module}")] + UnknownModule { from_module: String, module: String }, + #[error("in {from_module}: unknown import from '{module}': `{item}`")] + UnknownImport { from_module: String, module: String, item: String }, + #[error("import usage from `{from_module}` conflicts with local variable/function: `{name}`")] + ConflictingImport { from_module: String, name: String }, +} + +#[derive(Default)] +pub struct Processor { + pub modules: HashMap, +} + +impl Processor { + pub fn new() -> Self { + Self::default() + } + + fn get_imports_in_block(&mut self, module: &mut Module, block: Pair, found_requirements: &mut HashSet) -> Vec { + let mut requirements = vec![]; + + for code in block.into_inner() { + match code.as_rule() { + Rule::shader_code_block | Rule::shader_code => { + let reqs = self.get_imports_in_block(module, code, found_requirements); + requirements.extend(reqs.into_iter()); + }, + Rule::shader_code_fn_usage => { + let mut usage_inner = code.into_inner(); + let fn_name = usage_inner.next().unwrap().as_str(); + let fn_args: Vec<&str> = usage_inner.map(|a| a.as_str()).collect(); + + if found_requirements.contains(fn_name) { + continue; + } + found_requirements.insert(fn_name.to_string()); + + println!("Found call to {} with args: {:?}", fn_name, fn_args); + + // ignore reserved words + if RESERVED_WORDS.contains(&fn_name) { + continue; + } + + let req = DefRequirement { + // module is discovered later + module: None, + name: fn_name.to_string(), + ty: ExternalUsageType::Function, + }; + requirements.push(req); + }, + Rule::shader_external_fn => { + let mut pairs = code.into_inner(); + // shader_external_variable is the only pair for this rule + let ident_name = pairs.next().unwrap().as_str().to_string(); + + println!("external fn: {ident_name}"); + }, + Rule::shader_external_variable => { + let pairs = code.into_inner(); + // shader_external_variable is the only pair for this rule + let ident_name = pairs.as_str(); + + println!("external var: {ident_name}"); + }, + Rule::newline => (), + Rule::cws => (), + Rule::shader_code_char => (), + Rule::shader_ident => { + let ident = code.as_str(); + println!("Found usage of ident: {}", ident); + + if found_requirements.contains(ident) { + continue; + } + found_requirements.insert(ident.to_string()); + + // ignore reserved words + if RESERVED_WORDS.contains(&ident) { + continue; + } + + let req = DefRequirement { + // module is discovered later + module: None, + name: ident.to_string(), + ty: ExternalUsageType::Variable, + }; + requirements.push(req); + }, + Rule::shader_value => (), + _ => unimplemented!("ran into unhandled rule: {:?}, {:?}", code.as_rule(), code.as_span()) + } + } + + requirements + } + + /// Parse a module file to attempt to find the include identifier. + /// + /// Returns `None` if the module does not define an include identifier. + pub fn parse_module( + &mut self, + module_src: &str, + ) -> Result, PreprocessorError> { + //let unparsed_file = fs::read_to_string(path.as_ref())?; + + // add a new line to the end of the input to make the grammar happy + //let unparsed_file = format!("{unparsed_file}\n"); + + let file = WgslParser::parse(Rule::file, &module_src)? + .next() + .unwrap(); // get and unwrap the `file` rule; never fails + + let mut module = Module::default(); + module.src = module_src.to_string(); + + for record in file.into_inner() { + match record.as_rule() { + Rule::command_line => { + // the parser has found a preprocessor command, figure out what it is + let mut pairs = record.into_inner(); + let command_line = pairs.next().unwrap(); + + match command_line.as_rule() { + Rule::import_types_command => { + let mut inner = command_line.into_inner(); + let import_module_command = inner.next().unwrap(); + + let mut import_module_command = import_module_command.into_inner(); + let module_name = import_module_command.next().unwrap().as_str(); + + let types: Vec = inner.map(|t| t.as_str().to_string()).collect(); + + println!("found import of types from `{}`: `{:?}`", module_name, types); + + // add these type imports to imports of the module + module.item_imports.entry(module_name.into()) + .or_insert_with(|| Import { + module: module_name.into(), + imports: vec![], + }) + .imports.extend(types.into_iter()); + }, + Rule::import_module_command => { + let mut inner = command_line.into_inner(); + let module_name = inner.next().unwrap().as_str(); + println!("found import of module: {}", module_name); + + module.module_imports.insert(module_name.into()); + }, + Rule::define_module_command => { + let mut shader_file_pairs = command_line.into_inner(); + let shader_file = shader_file_pairs.next().unwrap(); + let shader_file = shader_file.as_str().to_string(); + + module.name = shader_file; + } + _ => unreachable!(), + } + } + Rule::shader_code_line => { + for line in record.into_inner() { + match line.as_rule() { + Rule::shader_fn_def => { + let mut pairs = line.clone().into_inner(); + // shader_ident is the only pair for this rule + let fn_name = pairs.next().unwrap().as_str().to_string(); + println!("fn def: {fn_name:?}"); + + let fn_body = pairs.skip(2).next().unwrap(); + let mut found_reqs = HashSet::default(); + let requirements = self.get_imports_in_block(&mut module, fn_body, &mut found_reqs); + + let line_span = line.as_span(); + let start_pos = line_span.start(); + let end_pos = line_span.end(); + + module.functions.insert( + fn_name.clone(), + Definition { + name: fn_name, + start_pos, + end_pos, + requirements + }, + ); + } + Rule::shader_const_def => { + let mut pairs = line.clone().into_inner(); + // shader_ident is the only pair for this rule + let const_name = pairs.next().unwrap().as_str().to_string(); + println!("const def: {const_name:?}"); + + let line_span = line.as_span(); + let start_pos = line_span.start(); + let end_pos = line_span.end(); + + module.constants.insert( + const_name.clone(), + Definition { + name: const_name, + start_pos, + end_pos, + requirements: vec![], + }, + ); + } + + Rule::shader_code => { + let mut shader_inner = line.clone().into_inner(); + let code_type = shader_inner.next().unwrap(); + + match code_type.as_rule() { + Rule::shader_code_block => { + let mut found_reqs = HashSet::default(); + self.get_imports_in_block(&mut module, code_type, &mut found_reqs); + }, + Rule::shader_code_fn_usage => { + todo!("cannot handle usage at this level"); + }, + Rule::shader_code_char => { + todo!("I think this can be ignored"); + }, + _ => unimplemented!("ran into unhandled rule: {:?}", line.as_span()) + } + + + println!("code: {}", line.as_str()); + } + Rule::cws => (), + Rule::newline => (), + _ => unimplemented!("ran into unhandled rule: {:?}", line.as_span()) + } + } + } + Rule::newline => (), + Rule::EOI => (), + _ => unimplemented!("ran into unhandled rule: {:?}", record.as_span()) + } + } + + if module.name.is_empty() { + Ok(None) + } else { + let name = module.name.clone(); + self.modules.insert(name.clone(), module); + + Ok(Some(name)) + } + } + + /// Find files recursively in `path` with an extension in `extensions`, and parse them. + /// + /// For each file that's found, [`Processor::parse_module`] is used to parse them. + /// + /// Parameters: + /// * `path` - The path to search for files in. + /// * `extensions` - The extensions that the discovered files must have. Make sure they have + /// no leading '.' + pub fn parse_modules, const N: usize>( + &mut self, + path: P, + extensions: [&str; N], + ) -> Result { + let files = recurse_files(path)?; + + let mut parsed = 0; + + for file in files { + if let Some(ext) = file.extension().and_then(|p| p.to_str()) { + if extensions.contains(&ext) { + let module_src = fs::read_to_string(file)?; + + self.parse_module(&module_src)?; + parsed += 1; + } + } + } + + Ok(parsed) + } + + fn generate_header(&mut self, module_path: &str) -> String { + let module = self.modules.get(module_path).unwrap(); + + let mut output = String::new(); + compile_definitions(&self.modules, module, &mut output); + + output + } + + fn output_shader_code_line(&self, shader_code_rule: Pair, output: &mut String) -> Result<(), std::fmt::Error> { + for line in shader_code_rule.into_inner() { + let (pos_line, pos_col) = line.line_col(); + + match line.as_rule() { + Rule::shader_external_fn | Rule::shader_external_variable => { + let mut pairs = line.into_inner(); + // shader_external_variable is the only pair for this rule + let ident_name = pairs.next().unwrap().as_str().to_string(); + + // remove the module from the identifier and write it to the output + if let Some((module_name, ident)) = ident_name.rsplit_once("::") { + output.write_str(ident)?; + } else { + // TODO: not really sure how this would get triggered + unimplemented!( + "this function is actually not external, i think" + ); + } + }, + Rule::shader_fn_def => { + let mut rule_inner = line.into_inner(); + + let fn_name = rule_inner.next().unwrap().as_str(); + let args = rule_inner.next().unwrap().as_str(); + let fn_ret = rule_inner.next().unwrap().as_str(); + let fn_body = rule_inner.next().unwrap(); + + let mut body_output = String::new(); + self.output_shader_code_line(fn_body, &mut body_output)?; + + // to escape {, must use two + output.write_fmt(format_args!("fn {}{} -> {} {{{}}}", + fn_name, args, fn_ret, body_output))?; + }, + Rule::shader_code | Rule::shader_const_def => { + self.output_shader_code_line(line, output)?; + }, + Rule::shader_code_fn_usage | Rule::shader_value | Rule::shader_ident | Rule::shader_code_char | Rule::cws => { + let input = line.as_str(); + output.write_str(&input)?; + }, + Rule::newline => { + output.write_str("\n")?; + }, + _ => unimplemented!("ran into unhandled rule: {:?}", line.as_rule()), + } + } + + Ok(()) + } + + pub fn process_file(&mut self, module_path: &str, module_src: &str) -> Result { + let mut out_string = String::new(); + // the output will be at least the length of module_src + out_string.reserve(module_src.len()); + + let file = WgslParser::parse(Rule::file, &module_src)? + .next() + .unwrap(); // get and unwrap the `file` rule; never fails + + let header = self.generate_header(module_path); + out_string.write_str("// START OF IMPORT HEADER\n")?; + out_string.write_str(&header)?; + out_string.write_str("// END OF IMPORT HEADER\n")?; + + for record in file.into_inner() { + match record.as_rule() { + Rule::command_line => { + // the parser has found a preprocessor command, figure out what it is + let mut pairs = record.into_inner(); + let command_line = pairs.next().unwrap(); + + match command_line.as_rule() { + Rule::import_module_command => (), + Rule::import_types_command => (), + Rule::define_module_command => (), + _ => unimplemented!("ran into unhandled rule: {:?}", command_line.as_span()), + } + }, + Rule::cws => (), + Rule::shader_code_line => { + self.output_shader_code_line(record, &mut out_string)?; + /* for line in record.into_inner() { + let (pos_line, pos_col) = line.line_col(); + + match line.as_rule() { + Rule::shader_external_fn | Rule::shader_external_variable => { + let mut pairs = line.into_inner(); + // shader_external_variable is the only pair for this rule + let ident_name = pairs.next().unwrap().as_str().to_string(); + + // remove the module from the identifier and write it to the output + if let Some((module_name, ident)) = ident_name.rsplit_once("::") { + out_string.write_str(ident)?; + } else { + // TODO: not really sure how this would get triggered + unimplemented!( + "this function is actually not external, i think" + ); + } + } + Rule::shader_code | Rule::shader_const_def => { + let input = line.as_str(); + out_string.write_str(&input)?; + }, + /* Rule::shader_fn_def => { + + }, */ + Rule::cws => { + let input = line.as_str(); + out_string.write_str(&input)?; + }, + Rule::newline => (), + _ => unimplemented!("ran into unhandled rule: {:?}", line.as_rule()), + } + } */ + }, + Rule::newline => { + let input = record.as_str(); + out_string.write_str(&input)?; + }, + Rule::EOI => (), + _ => unimplemented!("ran into unhandled rule: {:?}", record.as_rule()), + } + } + + Ok(out_string) + } +} + +fn try_find_requirement_module(module: &Module, req_name: &str) -> Option { + for import in module.item_imports.values() { + if import.imports.contains(&req_name.to_string()) { + return Some(import.module.clone()); + } + } + + None +} + +fn compile_definitions(modules: &HashMap, module: &Module, output: &mut String) -> Result<(), PreprocessorError> { + for (_, funcs) in &module.functions { + let mut requirements = VecDeque::from(funcs.requirements.clone()); + + while let Some(mut req) = requirements.pop_front() { + if req.module.is_none() { + let mod_name = try_find_requirement_module(&module, &req.name); + req.module = mod_name; + } + + if let Some(module_name) = &req.module { + let req_module = modules.get(module_name) + .unwrap_or_else(|| panic!("invalid module import: {}", module_name)); + + let req_def = req_module.functions.get(&req.name) + .or_else(|| req_module.constants.get(&req.name)) + .unwrap_or_else(|| panic!("invalid import: {} from {}", req.name, module_name)); + + let sub_req_names: Vec = req_def.requirements.iter().map(|r| r.name.clone()).collect(); + println!("got req: {}, subreqs: {:?}", req_def.name, sub_req_names); + + if !req_def.requirements.is_empty() { + let mut requirements_output = String::new(); + compile_definitions(modules, req_module, &mut requirements_output)?; + + output.write_fmt(format_args!("\n// REQUIREMENTS OF {}::{}\n", module_name, req.name))?; + output.push_str(&requirements_output); + output.push_str("\n"); + } + + let func_src = &req_module.src[req_def.start_pos..req_def.end_pos]; + output.write_fmt(format_args!("// SOURCE {}::{}\n", module_name, req.name))?; + output.push_str(func_src); + output.push_str("\n"); + } + } + } + + Ok(()) +} \ No newline at end of file diff --git a/src/wgsl.pest b/src/wgsl.pest index 33ddbf2..f75ca38 100644 --- a/src/wgsl.pest +++ b/src/wgsl.pest @@ -22,23 +22,11 @@ command_line = { preproc_prefix ~ (define_module_command | import_command) ~ new // all characters used by wgsl shader_code_char = { "@" | "-" | "+" | "*" | "/" | "=" | "(" | ")" | ">" | "<" | ";" | ":" | "." | "_" | "," } shader_code_block = { "{" ~ newline* ~ (cws* ~ (shader_actual_code_line ~ cws*)* ~ newline)+ ~ cws* ~ "}" } -shader_code = { shader_code_block | shader_code_char | ASCII_ALPHANUMERIC+ } - -shader_value_num = { ASCII_DIGIT* ~ ( "." ~ ASCII_DIGIT* )? } -shader_value_bool = { "true" | "false" } -shader_value = { shader_value_bool | shader_value_num } - -// defines type of something i.e., `: f32`, `: u32`, etc. -shader_var_type = { ":" ~ ws* ~ shader_type } -shader_const_def = { "const" ~ ws ~ shader_ident ~ (ws* ~ shader_var_type)? ~ ws* ~ "=" ~ ws* ~ shader_value ~ ";" } -shader_var_name_type = { shader_ident ~ shader_var_type } -shader_fn_args = { "(" ~ shader_var_name_type ~ (ws* ~ "," ~ ws* ~ shader_var_name_type)* ~ ")" } -// the body of a function, including the opening and closing brackets -shader_fn_body = { "{" ~ newline* ~ (cws* ~ (shader_actual_code_line ~ cws*)* ~ newline)+ ~ "}" } - -shader_fn_def = { - "fn" ~ ws ~ shader_ident ~ shader_fn_args ~ ws ~ "->" ~ ws ~ shader_type ~ ws ~ shader_fn_body -} +// an fn argument can be another function use +shader_code_fn_arg = _{ shader_code_fn_usage | shader_value | shader_ident } +shader_code_fn_usage = { shader_ident ~ "(" ~ shader_code_fn_arg ~ ("," ~ ws* ~ shader_code_fn_arg)* ~ ")" } +//shader_code_fn_usage = { shader_ident ~ "(in, 2.0)" } +shader_code = { shader_code_block | shader_code_fn_usage | shader_value | shader_ident | shader_code_char } // usages of code from another module shader_external_variable = { shader_ident ~ ( "::" ~ shader_ident)+ } @@ -46,7 +34,22 @@ shader_external_fn = { shader_external_variable ~ "(" ~ ANY* ~ ")" } shader_external_code = _{ shader_external_fn | shader_external_variable } shader_actual_code_line = _{ shader_external_code | shader_code } -//shader_actual_code_line = _{ cws* ~ ( (shader_external_code | shader_code) ~ cws*)* } + +shader_value_num = { ASCII_DIGIT+ ~ ( "." ~ ASCII_DIGIT+ )? } +shader_value_bool = { "true" | "false" } +shader_value = { shader_value_bool | shader_value_num } + +// defines type of something i.e., `: f32`, `: u32`, etc. +shader_var_type = { ":" ~ ws* ~ shader_type } +shader_const_def = { "const" ~ ws ~ shader_ident ~ (ws* ~ shader_var_type)? ~ ws* ~ "=" ~ ws* ~ shader_value ~ ";" } +shader_var_name_type = { shader_ident ~ shader_var_type } +shader_fn_args = { "(" ~ shader_var_name_type? ~ (ws* ~ "," ~ ws* ~ shader_var_name_type)* ~ ")" } +// the body of a function, including the opening and closing brackets +shader_fn_body = { "{" ~ newline* ~ (cws* ~ (shader_actual_code_line ~ cws*)* ~ newline)+ ~ "}" } + +shader_fn_def = { + "fn" ~ ws ~ shader_ident ~ shader_fn_args ~ ws ~ "->" ~ ws ~ shader_type ~ ws ~ shader_fn_body +} // a line of shader code, including white space shader_code_line = {