output includes recursively, find dependencies of imported functions and output those as well

This commit is contained in:
SeanOMik 2024-08-03 16:48:02 -04:00
parent 5ee537763f
commit e85adce6e3
5 changed files with 625 additions and 427 deletions

View File

@ -1,5 +1,5 @@
#define_module base #define_module base
#import simple #import simple::{do_something_cool}
fn main() -> vec4<f32> { fn main() -> vec4<f32> {
let a = do_something_cool(10.0); let a = do_something_cool(10.0);

35
src/compiler.rs Normal file
View File

@ -0,0 +1,35 @@
use std::collections::HashMap;
use crate::{Definition, Module, PreprocessorError, Processor};
/// Compile a module including its imports into a single module.
#[derive(Default)]
pub struct Compiler {
preprocessor: Processor,
}
impl Compiler {
/// Add a module to the compiler
///
/// Returns `None` if the module does not define an include identifier.
pub fn add_module(&mut self, module_src: &str) -> Result<Option<String>, PreprocessorError> {
self.preprocessor.parse_module(module_src)
}
pub fn compile_module(self, module_src: &str) -> Result<String, PreprocessorError> {
todo!()
}
}
/* pub struct Source {
} */
pub struct ExpandableModule {
/// The name of the module.
name: String,
/// Constants that this module defines
pub constants: HashMap<String, String>,
/// Functions that this module defines
pub functions: HashMap<String, String>,
}

View File

@ -9,28 +9,15 @@ use itertools::Itertools;
use pest::Parser; use pest::Parser;
use pest_derive::Parser; use pest_derive::Parser;
mod preprocessor;
pub use preprocessor::*;
mod compiler;
pub use compiler::*;
#[derive(Parser)] #[derive(Parser)]
#[grammar = "wgsl.pest"] #[grammar = "wgsl.pest"]
pub struct WgslParser; pub(crate) struct WgslParser;
#[derive(Debug, thiserror::Error)]
pub enum PreprocessorError {
#[error("{0}")]
IoError(#[from] std::io::Error),
#[error("error parsing {path}: {err}")]
ParserError {
path: PathBuf,
err: pest::error::Error<Rule>,
},
#[error("failure formatting preprocessor output to string ({0})")]
FormatError(#[from] std::fmt::Error),
#[error("unknown module import '{module}', in {from_path}")]
UnknownModule { from_path: PathBuf, module: String },
#[error("in {from_path}: unknown import from '{module}': `{item}`")]
UnknownImport { from_path: PathBuf, module: String, item: String },
#[error("import usage from `{from_module}` conflicts with local variable/function: `{name}`")]
ConflictingImport { from_module: String, name: String },
}
fn main() { fn main() {
/* let mut successful_parse = WgslParser::parse(Rule::command_line, "#define_module inner::some_include").unwrap(); /* let mut successful_parse = WgslParser::parse(Rule::command_line, "#define_module inner::some_include").unwrap();
@ -40,11 +27,18 @@ fn main() {
let mut p = Processor::new(); let mut p = Processor::new();
//let f = p.parse_modules("shaders", ["wgsl"]).unwrap(); //let f = p.parse_modules("shaders", ["wgsl"]).unwrap();
//println!("Parsed {} modules:", f); //println!("Parsed {} modules:", f);
p.parse_module("shaders/inner_include.wgsl") let inner_include_src = fs::read_to_string("shaders/inner_include.wgsl").unwrap();
p.parse_module(&inner_include_src)
.unwrap() .unwrap()
.expect("failed to find module"); .expect("failed to find module");
p.parse_module("shaders/simple.wgsl") let simple_include_src = fs::read_to_string("shaders/simple.wgsl").unwrap();
p.parse_module(&simple_include_src)
.unwrap()
.expect("failed to find module");
let base_include_src = fs::read_to_string("shaders/base.wgsl").unwrap();
let base_module_path = p.parse_module(&base_include_src)
.unwrap() .unwrap()
.expect("failed to find module"); .expect("failed to find module");
@ -60,16 +54,20 @@ fn main() {
} }
for (name, def) in &module.functions { for (name, def) in &module.functions {
println!(" fn {name}, {}-{}", def.start_pos, def.end_pos); let requires: Vec<String> = def.requirements.iter().map(|r| {
let pre = r.module.as_ref().map(|m| format!("{m}::")).unwrap_or_default();
format!("{}{}", pre, r.name)
}).collect();
println!(" fn {name}, {}-{}. requires: {:?}", def.start_pos, def.end_pos, requires);
} }
println!(" imported modules: {:?}", module.module_imports); println!(" imported modules: {:?}", module.module_imports);
if !module.type_imports.is_empty() { if !module.item_imports.is_empty() {
println!(" type imports:"); println!(" type imports:");
} }
for (module, usages) in &module.type_imports { for (module, usages) in &module.item_imports {
println!(" {}: {:?}", module, usages.imports); println!(" {}: {:?}", module, usages.imports);
} }
@ -89,7 +87,8 @@ fn main() {
} }
} }
let out = p.process_file("shaders/simple.wgsl").unwrap(); let base_include_src = fs::read_to_string("shaders/base.wgsl").unwrap();
let out = p.process_file(&base_module_path, &base_include_src).unwrap();
fs::write("out.wgsl", out).unwrap(); fs::write("out.wgsl", out).unwrap();
} }
@ -99,6 +98,7 @@ pub enum ExternalUsageType {
Function, Function,
} }
#[derive(Clone)]
pub struct ExternalUsage { pub struct ExternalUsage {
name: String, name: String,
ty: ExternalUsageType, ty: ExternalUsageType,
@ -106,412 +106,60 @@ pub struct ExternalUsage {
start_pos: usize, start_pos: usize,
} }
#[derive(Clone)]
pub struct ImportUsage { pub struct ImportUsage {
module: String, module: String,
imports: Vec<ExternalUsage>, imports: Vec<ExternalUsage>,
} }
#[derive(Clone)]
pub struct Import { pub struct Import {
module: String, module: String,
imports: Vec<String>, imports: Vec<String>,
} }
#[derive(Clone)]
pub struct DefRequirement {
/// None if the requirement is local
module: Option<String>,
name: String,
ty: ExternalUsageType,
}
#[derive(Clone)]
pub struct Definition { pub struct Definition {
name: String, name: String,
requirements: Vec<DefRequirement>,
/// The start byte position as a `usize`. /// The start byte position as a `usize`.
start_pos: usize, start_pos: usize,
/// The end byte position as a `usize`. /// The end byte position as a `usize`.
end_pos: usize, end_pos: usize,
} }
#[derive(Default)] #[derive(Default, Clone)]
pub struct Module { pub struct Module {
/// The name of the module.
name: String, name: String,
path: String, /// The source code of the module, non-processed.
constants: HashMap<String, Definition>, src: String,
functions: HashMap<String, Definition>, /// Constants that this module defines
module_imports: HashSet<String>, pub constants: HashMap<String, Definition>,
type_imports: HashMap<String, Import>, /// Functions that this module defines
pub functions: HashMap<String, Definition>,
/// Imports of things per module
/// ie `other_module::{scalar, do_math_func}`
item_imports: HashMap<String, Import>,
/// usages of imported things
/// ie `other_module::scalar`
import_usages: HashMap<String, ImportUsage>, import_usages: HashMap<String, ImportUsage>,
} /// Imports of modules
#[derive(Default)]
pub struct Processor {
modules: HashMap<String, Module>,
}
impl Processor {
pub fn new() -> Self {
Self::default()
}
/// Parse a module file to attempt to find the include identifier.
/// ///
/// Returns `None` if the module does not define an include identifier. /// These modules are used along side `import_usages`
pub fn parse_module<P: AsRef<Path>>( module_imports: HashSet<String>,
&mut self,
path: P,
) -> Result<Option<String>, PreprocessorError> {
let unparsed_file = fs::read_to_string(path.as_ref())?;
// add a new line to the end of the input to make the grammar happy
//let unparsed_file = format!("{unparsed_file}\n");
let file = WgslParser::parse(Rule::file, &unparsed_file)
.map_err(|e| PreprocessorError::ParserError {
path: path.as_ref().to_path_buf(),
err: e,
})?
.next()
.unwrap(); // get and unwrap the `file` rule; never fails
let mut module = Module::default();
module.path = path.as_ref().to_str().unwrap().into();
for record in file.into_inner() {
match record.as_rule() {
Rule::command_line => {
// the parser has found a preprocessor command, figure out what it is
let mut pairs = record.into_inner();
let command_line = pairs.next().unwrap();
match command_line.as_rule() {
Rule::import_types_command => {
let mut inner = command_line.into_inner();
let import_module_command = inner.next().unwrap();
let mut import_module_command = import_module_command.into_inner();
let module_name = import_module_command.next().unwrap().as_str();
let types: Vec<String> = inner.map(|t| t.as_str().to_string()).collect();
println!("found import of types from `{}`: `{:?}`", module_name, types);
// add these type imports to imports of the module
module.type_imports.entry(module_name.into())
.or_insert_with(|| Import {
module: module_name.into(),
imports: vec![],
})
.imports.extend(types.into_iter());
},
Rule::import_module_command => {
let mut inner = command_line.into_inner();
let module_name = inner.next().unwrap().as_str();
println!("found import of module: {}", module_name);
module.module_imports.insert(module_name.into());
},
Rule::define_module_command => {
let mut shader_file_pairs = command_line.into_inner();
let shader_file = shader_file_pairs.next().unwrap();
let shader_file = shader_file.as_str().to_string();
module.name = shader_file;
}
_ => unreachable!(),
}
}
Rule::shader_code_line => {
for line in record.into_inner() {
let (pos_line, pos_col) = line.line_col();
match line.as_rule() {
Rule::shader_fn_def => {
let mut pairs = line.clone().into_inner();
// shader_ident is the only pair for this rule
let fn_name = pairs.next().unwrap().as_str().to_string();
println!("fn: {fn_name:?}");
/* let fn_args = pairs.next().unwrap().as_str();
let ret_type = pairs.next().unwrap().as_str();
let fn_body_pair = pairs.next().unwrap(); */
let line_span = line.as_span();
let start_pos = line_span.start();
let end_pos = line_span.end();
module.functions.insert(
fn_name.clone(),
Definition {
name: fn_name,
start_pos,
end_pos,
},
);
}
Rule::shader_const_def => {
let mut pairs = line.clone().into_inner();
// shader_ident is the only pair for this rule
let const_name = pairs.next().unwrap().as_str().to_string();
println!("const: {const_name:?}");
let line_span = line.as_span();
let start_pos = line_span.start();
let end_pos = line_span.end();
module.constants.insert(
const_name.clone(),
Definition {
name: const_name,
start_pos,
end_pos,
},
);
}
Rule::shader_external_fn => {
let mut pairs = line.into_inner();
// shader_external_variable is the only pair for this rule
let ident_name = pairs.next().unwrap().as_str().to_string();
println!("external fn: {ident_name}");
}
Rule::shader_external_variable => {
let pairs = line.into_inner();
// shader_external_variable is the only pair for this rule
let ident_name = pairs.as_str();
println!("external var: {ident_name}");
}
Rule::shader_code => {
println!("code: {}", line.as_str());
}
Rule::cws => (),
Rule::newline => (),
_ => unimplemented!("ran into unhandled rule: {:?}", line.as_span())
}
}
}
Rule::newline => (),
Rule::EOI => (),
_ => unimplemented!("ran into unhandled rule: {:?}", record.as_span())
}
}
if module.name.is_empty() {
Ok(None)
} else {
let name = module.name.clone();
self.modules.insert(name.clone(), module);
Ok(Some(name))
}
}
/// Find files recursively in `path` with an extension in `extensions`, and parse them.
///
/// For each file that's found, [`Processor::parse_module`] is used to parse them.
///
/// Parameters:
/// * `path` - The path to search for files in.
/// * `extensions` - The extensions that the discovered files must have. Make sure they have
/// no leading '.'
pub fn parse_modules<P: AsRef<Path>, const N: usize>(
&mut self,
path: P,
extensions: [&str; N],
) -> Result<usize, PreprocessorError> {
//debug_assert!(!extension.starts_with("."), "remove leading '.' from extension");
let files = recurse_files(path)?;
let mut parsed = 0;
for file in files {
if let Some(ext) = file.extension().and_then(|p| p.to_str()) {
if extensions.contains(&ext) {
self.parse_module(file)?;
parsed += 1;
}
}
}
Ok(parsed)
}
pub fn process_file<P: AsRef<Path>>(&mut self, path: P) -> Result<String, PreprocessorError> {
let unparsed_file = fs::read_to_string(path.as_ref())?;
// add a new line to the end of the input to make the grammar happy
let unparsed_file = format!("{unparsed_file}\n");
let mut out_string = String::new();
let file = WgslParser::parse(Rule::file, &unparsed_file)
.map_err(|e| PreprocessorError::ParserError {
path: path.as_ref().to_path_buf(),
err: e,
})?
.next()
.unwrap(); // get and unwrap the `file` rule; never fails
for record in file.into_inner() {
match record.as_rule() {
Rule::command_line => {
// the parser has found a preprocessor command, figure out what it is
let mut pairs = record.into_inner();
let command_line = pairs.next().unwrap();
match command_line.as_rule() {
Rule::import_module_command => {
let mut shader_file_pairs = command_line.into_inner();
let shader_file = shader_file_pairs.next().unwrap();
let shader_file = shader_file.as_str();
println!("found module import: {}", shader_file);
let imported_mod = self.modules.get(shader_file).ok_or_else(|| {
PreprocessorError::UnknownModule {
from_path: path.as_ref().to_path_buf(),
module: shader_file.into(),
}
})?;
let included_file = self.process_file(imported_mod.path.clone())?;
let start_header =
format!("// ==== START OF INCLUDE OF '{}' ====", shader_file);
let end_header =
format!("\n// ==== END OF INCLUDE OF '{}' ====\n", shader_file);
out_string.write_str(&start_header)?;
out_string.write_str(&included_file)?;
out_string.write_str(&end_header)?;
}
Rule::import_types_command => {
let mut import_command = command_line.into_inner();
let import_module_command = import_command.next().unwrap();
let module_path = import_module_command.into_inner().next().unwrap();
let module_path = module_path.as_str();
let importing_from_mod = self.modules.get(module_path).ok_or_else(|| {
PreprocessorError::UnknownModule {
from_path: path.as_ref().to_path_buf(),
module: module_path.into(),
}
})?;
let module_raw_src = fs::read_to_string(&importing_from_mod.path)?;
for import in import_command {
let import_ident = import.as_str();
let def = importing_from_mod.functions.get(import_ident)
.or_else(|| importing_from_mod.constants.get(import_ident))
.ok_or_else(|| {
PreprocessorError::UnknownImport {
from_path: path.as_ref().to_path_buf(),
module: module_path.into(),
item: import_ident.to_string(),
}
})?;
let import_text = &module_raw_src[def.start_pos..def.end_pos];
//println!("must add:\n{import_text}");
out_string.write_fmt(format_args!("// START OF IMPORT ITEM {} FROM {}\n", import_ident, module_path))?;
out_string.write_str(import_text)?;
out_string.write_fmt(format_args!("\n// END OF IMPORT ITEM {} FROM {}\n\n", import_ident, module_path))?;
}
//todo!();
/* importing_from_mod.functions.get()
let imports: Vec<&str> =
shader_file_pairs.map(|i| i.as_str()).collect();
println!("found module import: {}", module_path);
println!("imports: {imports:?}");
todo!(); */
/* let imported_mod = self.modules.get(shader_file).ok_or_else(|| {
PreprocessorError::UnknownModule {
from_path: path.as_ref().to_path_buf(),
module: shader_file.into(),
}
})?; */
}
Rule::define_module_command => (),
_ => unimplemented!("ran into unhandled rule: {:?}", command_line.as_span()),
}
}
Rule::cws => (),
Rule::shader_code_line => {
for line in record.into_inner() {
let (pos_line, pos_col) = line.line_col();
match line.as_rule() {
Rule::shader_external_fn => {
let mut pairs = line.into_inner();
// shader_external_variable is the only pair for this rule
let ident_name = pairs.next().unwrap().as_str().to_string();
if let Some((module_name, ident)) = ident_name.rsplit_once("::") {
/* let usage = ExternalUsage {
name: ident.into(),
ty: ExternalUsageType::Function,
line: pos_line,
col: pos_col
}; */
out_string.write_str(ident)?;
} else {
// TODO: not really sure how this would get triggered
unimplemented!(
"this function is actually not external, i think"
);
}
}
Rule::shader_external_variable => {
let pairs = line.into_inner();
// shader_external_variable is the only pair for this rule
let ident_name = pairs.as_str();
if let Some((module_name, ident)) = ident_name.rsplit_once("::") {
/* let usage = ExternalUsage {
name: ident.into(),
ty: ExternalUsageType::Variable,
line: pos_line,
col: pos_col
}; */
out_string.write_str(ident)?;
} else {
// TODO: not really sure how this would get triggered
unimplemented!(
"this function is actually not external, i think"
);
}
}
/* Rule::shader_fn_def => (),
Rule::shader_const_def => (),
Rule::shader_code => { */
Rule::shader_code | Rule::shader_const_def | Rule::shader_fn_def => {
let input = line.as_str();
out_string.write_str(&input)?;
}
Rule::cws => {
let input = line.as_str();
out_string.write_str(&input)?;
}
Rule::newline => (),
_ => unimplemented!("ran into unhandled rule: {:?}", line.as_span()),
}
}
}
Rule::newline => {
let input = record.as_str();
out_string.write_str(&input)?;
}
Rule::EOI => (),
_ => unimplemented!("ran into unhandled rule: {:?}", record.as_span()),
}
}
Ok(out_string)
}
} }
/// Recursively find files in `path`. /// Recursively find files in `path`.
fn recurse_files(path: impl AsRef<Path>) -> std::io::Result<Vec<PathBuf>> { pub(crate) fn recurse_files(path: impl AsRef<Path>) -> std::io::Result<Vec<PathBuf>> {
let mut buf = vec![]; let mut buf = vec![];
let entries = fs::read_dir(path)?; let entries = fs::read_dir(path)?;

512
src/preprocessor.rs Normal file
View File

@ -0,0 +1,512 @@
use std::{collections::{HashMap, HashSet, VecDeque}, fmt::Write, fs, path::{Path, PathBuf}};
use pest::{iterators::Pair, Parser};
use crate::{recurse_files, DefRequirement, Definition, ExternalUsageType, Import, Module, Rule, WgslParser};
const RESERVED_WORDS: [&str; 171] = [
"NULL", "Self", "abstract", "active", "alignas", "alignof", "as", "asm", "asm_fragment",
"async", "attribute", "auto", "await", "become", "binding_array", "cast", "catch", "class",
"co_await", "co_return", "co_yield", "coherent", "column_major", "common", "compile",
"compile_fragment", "concept", "const_cast", "consteval", "constexpr", "constinit", "crate",
"debugger", "decltype", "delete", "demote", "demote_to_helper", "do", "dynamic_cast", "enum",
"explicit", "export", "extends", "extern", "external", "fallthrough", "filter", "final",
"finally", "friend", "from", "fxgroup", "get", "goto", "groupshared", "highp", "impl",
"implements", "import", "inline", "instanceof", "interface", "layout", "lowp", "macro",
"macro_rules", "match", "mediump", "meta", "mod", "module", "move", "mut", "mutable",
"namespace", "new", "nil", "noexcept", "noinline", "nointerpolation", "noperspective", "null",
"nullptr", "of", "operator", "package", "packoffset", "partition", "pass", "patch",
"pixelfragment", "precise", "precision", "premerge", "priv", "protected", "pub", "public",
"readonly", "ref", "regardless", "register", "reinterpret_cast", "require", "resource",
"restrict", "self", "set", "shared", "sizeof", "smooth", "snorm", "static", "static_assert",
"static_cast", "std", "subroutine", "super", "target", "template", "this", "thread_local",
"throw", "trait", "try", "type", "typedef", "typeid", "typename", "typeof", "union", "unless",
"unorm", "unsafe", "unsized", "use", "using", "varying", "virtual", "volatile", "wgsl",
"where", "with", "writeonly", "yield", "alias", "break", "case", "const", "const_assert",
"continue", "continuing", "default", "diagnostic", "discard", "else", "enable", "false", "fn",
"for", "if", "let", "loop", "override", "requires", "return", "struct", "switch", "true",
"var", "while"
];
#[derive(Debug, thiserror::Error)]
pub enum PreprocessorError {
#[error("{0}")]
IoError(#[from] std::io::Error),
#[error("error parsing {0}")]
ParserError(#[from] pest::error::Error<Rule>),
#[error("failure formatting preprocessor output to string ({0})")]
FormatError(#[from] std::fmt::Error),
#[error("unknown module import '{module}', in {from_module}")]
UnknownModule { from_module: String, module: String },
#[error("in {from_module}: unknown import from '{module}': `{item}`")]
UnknownImport { from_module: String, module: String, item: String },
#[error("import usage from `{from_module}` conflicts with local variable/function: `{name}`")]
ConflictingImport { from_module: String, name: String },
}
#[derive(Default)]
pub struct Processor {
pub modules: HashMap<String, Module>,
}
impl Processor {
pub fn new() -> Self {
Self::default()
}
fn get_imports_in_block(&mut self, module: &mut Module, block: Pair<Rule>, found_requirements: &mut HashSet<String>) -> Vec<DefRequirement> {
let mut requirements = vec![];
for code in block.into_inner() {
match code.as_rule() {
Rule::shader_code_block | Rule::shader_code => {
let reqs = self.get_imports_in_block(module, code, found_requirements);
requirements.extend(reqs.into_iter());
},
Rule::shader_code_fn_usage => {
let mut usage_inner = code.into_inner();
let fn_name = usage_inner.next().unwrap().as_str();
let fn_args: Vec<&str> = usage_inner.map(|a| a.as_str()).collect();
if found_requirements.contains(fn_name) {
continue;
}
found_requirements.insert(fn_name.to_string());
println!("Found call to {} with args: {:?}", fn_name, fn_args);
// ignore reserved words
if RESERVED_WORDS.contains(&fn_name) {
continue;
}
let req = DefRequirement {
// module is discovered later
module: None,
name: fn_name.to_string(),
ty: ExternalUsageType::Function,
};
requirements.push(req);
},
Rule::shader_external_fn => {
let mut pairs = code.into_inner();
// shader_external_variable is the only pair for this rule
let ident_name = pairs.next().unwrap().as_str().to_string();
println!("external fn: {ident_name}");
},
Rule::shader_external_variable => {
let pairs = code.into_inner();
// shader_external_variable is the only pair for this rule
let ident_name = pairs.as_str();
println!("external var: {ident_name}");
},
Rule::newline => (),
Rule::cws => (),
Rule::shader_code_char => (),
Rule::shader_ident => {
let ident = code.as_str();
println!("Found usage of ident: {}", ident);
if found_requirements.contains(ident) {
continue;
}
found_requirements.insert(ident.to_string());
// ignore reserved words
if RESERVED_WORDS.contains(&ident) {
continue;
}
let req = DefRequirement {
// module is discovered later
module: None,
name: ident.to_string(),
ty: ExternalUsageType::Variable,
};
requirements.push(req);
},
Rule::shader_value => (),
_ => unimplemented!("ran into unhandled rule: {:?}, {:?}", code.as_rule(), code.as_span())
}
}
requirements
}
/// Parse a module file to attempt to find the include identifier.
///
/// Returns `None` if the module does not define an include identifier.
pub fn parse_module(
&mut self,
module_src: &str,
) -> Result<Option<String>, PreprocessorError> {
//let unparsed_file = fs::read_to_string(path.as_ref())?;
// add a new line to the end of the input to make the grammar happy
//let unparsed_file = format!("{unparsed_file}\n");
let file = WgslParser::parse(Rule::file, &module_src)?
.next()
.unwrap(); // get and unwrap the `file` rule; never fails
let mut module = Module::default();
module.src = module_src.to_string();
for record in file.into_inner() {
match record.as_rule() {
Rule::command_line => {
// the parser has found a preprocessor command, figure out what it is
let mut pairs = record.into_inner();
let command_line = pairs.next().unwrap();
match command_line.as_rule() {
Rule::import_types_command => {
let mut inner = command_line.into_inner();
let import_module_command = inner.next().unwrap();
let mut import_module_command = import_module_command.into_inner();
let module_name = import_module_command.next().unwrap().as_str();
let types: Vec<String> = inner.map(|t| t.as_str().to_string()).collect();
println!("found import of types from `{}`: `{:?}`", module_name, types);
// add these type imports to imports of the module
module.item_imports.entry(module_name.into())
.or_insert_with(|| Import {
module: module_name.into(),
imports: vec![],
})
.imports.extend(types.into_iter());
},
Rule::import_module_command => {
let mut inner = command_line.into_inner();
let module_name = inner.next().unwrap().as_str();
println!("found import of module: {}", module_name);
module.module_imports.insert(module_name.into());
},
Rule::define_module_command => {
let mut shader_file_pairs = command_line.into_inner();
let shader_file = shader_file_pairs.next().unwrap();
let shader_file = shader_file.as_str().to_string();
module.name = shader_file;
}
_ => unreachable!(),
}
}
Rule::shader_code_line => {
for line in record.into_inner() {
match line.as_rule() {
Rule::shader_fn_def => {
let mut pairs = line.clone().into_inner();
// shader_ident is the only pair for this rule
let fn_name = pairs.next().unwrap().as_str().to_string();
println!("fn def: {fn_name:?}");
let fn_body = pairs.skip(2).next().unwrap();
let mut found_reqs = HashSet::default();
let requirements = self.get_imports_in_block(&mut module, fn_body, &mut found_reqs);
let line_span = line.as_span();
let start_pos = line_span.start();
let end_pos = line_span.end();
module.functions.insert(
fn_name.clone(),
Definition {
name: fn_name,
start_pos,
end_pos,
requirements
},
);
}
Rule::shader_const_def => {
let mut pairs = line.clone().into_inner();
// shader_ident is the only pair for this rule
let const_name = pairs.next().unwrap().as_str().to_string();
println!("const def: {const_name:?}");
let line_span = line.as_span();
let start_pos = line_span.start();
let end_pos = line_span.end();
module.constants.insert(
const_name.clone(),
Definition {
name: const_name,
start_pos,
end_pos,
requirements: vec![],
},
);
}
Rule::shader_code => {
let mut shader_inner = line.clone().into_inner();
let code_type = shader_inner.next().unwrap();
match code_type.as_rule() {
Rule::shader_code_block => {
let mut found_reqs = HashSet::default();
self.get_imports_in_block(&mut module, code_type, &mut found_reqs);
},
Rule::shader_code_fn_usage => {
todo!("cannot handle usage at this level");
},
Rule::shader_code_char => {
todo!("I think this can be ignored");
},
_ => unimplemented!("ran into unhandled rule: {:?}", line.as_span())
}
println!("code: {}", line.as_str());
}
Rule::cws => (),
Rule::newline => (),
_ => unimplemented!("ran into unhandled rule: {:?}", line.as_span())
}
}
}
Rule::newline => (),
Rule::EOI => (),
_ => unimplemented!("ran into unhandled rule: {:?}", record.as_span())
}
}
if module.name.is_empty() {
Ok(None)
} else {
let name = module.name.clone();
self.modules.insert(name.clone(), module);
Ok(Some(name))
}
}
/// Find files recursively in `path` with an extension in `extensions`, and parse them.
///
/// For each file that's found, [`Processor::parse_module`] is used to parse them.
///
/// Parameters:
/// * `path` - The path to search for files in.
/// * `extensions` - The extensions that the discovered files must have. Make sure they have
/// no leading '.'
pub fn parse_modules<P: AsRef<Path>, const N: usize>(
&mut self,
path: P,
extensions: [&str; N],
) -> Result<usize, PreprocessorError> {
let files = recurse_files(path)?;
let mut parsed = 0;
for file in files {
if let Some(ext) = file.extension().and_then(|p| p.to_str()) {
if extensions.contains(&ext) {
let module_src = fs::read_to_string(file)?;
self.parse_module(&module_src)?;
parsed += 1;
}
}
}
Ok(parsed)
}
fn generate_header(&mut self, module_path: &str) -> String {
let module = self.modules.get(module_path).unwrap();
let mut output = String::new();
compile_definitions(&self.modules, module, &mut output);
output
}
fn output_shader_code_line(&self, shader_code_rule: Pair<Rule>, output: &mut String) -> Result<(), std::fmt::Error> {
for line in shader_code_rule.into_inner() {
let (pos_line, pos_col) = line.line_col();
match line.as_rule() {
Rule::shader_external_fn | Rule::shader_external_variable => {
let mut pairs = line.into_inner();
// shader_external_variable is the only pair for this rule
let ident_name = pairs.next().unwrap().as_str().to_string();
// remove the module from the identifier and write it to the output
if let Some((module_name, ident)) = ident_name.rsplit_once("::") {
output.write_str(ident)?;
} else {
// TODO: not really sure how this would get triggered
unimplemented!(
"this function is actually not external, i think"
);
}
},
Rule::shader_fn_def => {
let mut rule_inner = line.into_inner();
let fn_name = rule_inner.next().unwrap().as_str();
let args = rule_inner.next().unwrap().as_str();
let fn_ret = rule_inner.next().unwrap().as_str();
let fn_body = rule_inner.next().unwrap();
let mut body_output = String::new();
self.output_shader_code_line(fn_body, &mut body_output)?;
// to escape {, must use two
output.write_fmt(format_args!("fn {}{} -> {} {{{}}}",
fn_name, args, fn_ret, body_output))?;
},
Rule::shader_code | Rule::shader_const_def => {
self.output_shader_code_line(line, output)?;
},
Rule::shader_code_fn_usage | Rule::shader_value | Rule::shader_ident | Rule::shader_code_char | Rule::cws => {
let input = line.as_str();
output.write_str(&input)?;
},
Rule::newline => {
output.write_str("\n")?;
},
_ => unimplemented!("ran into unhandled rule: {:?}", line.as_rule()),
}
}
Ok(())
}
pub fn process_file(&mut self, module_path: &str, module_src: &str) -> Result<String, PreprocessorError> {
let mut out_string = String::new();
// the output will be at least the length of module_src
out_string.reserve(module_src.len());
let file = WgslParser::parse(Rule::file, &module_src)?
.next()
.unwrap(); // get and unwrap the `file` rule; never fails
let header = self.generate_header(module_path);
out_string.write_str("// START OF IMPORT HEADER\n")?;
out_string.write_str(&header)?;
out_string.write_str("// END OF IMPORT HEADER\n")?;
for record in file.into_inner() {
match record.as_rule() {
Rule::command_line => {
// the parser has found a preprocessor command, figure out what it is
let mut pairs = record.into_inner();
let command_line = pairs.next().unwrap();
match command_line.as_rule() {
Rule::import_module_command => (),
Rule::import_types_command => (),
Rule::define_module_command => (),
_ => unimplemented!("ran into unhandled rule: {:?}", command_line.as_span()),
}
},
Rule::cws => (),
Rule::shader_code_line => {
self.output_shader_code_line(record, &mut out_string)?;
/* for line in record.into_inner() {
let (pos_line, pos_col) = line.line_col();
match line.as_rule() {
Rule::shader_external_fn | Rule::shader_external_variable => {
let mut pairs = line.into_inner();
// shader_external_variable is the only pair for this rule
let ident_name = pairs.next().unwrap().as_str().to_string();
// remove the module from the identifier and write it to the output
if let Some((module_name, ident)) = ident_name.rsplit_once("::") {
out_string.write_str(ident)?;
} else {
// TODO: not really sure how this would get triggered
unimplemented!(
"this function is actually not external, i think"
);
}
}
Rule::shader_code | Rule::shader_const_def => {
let input = line.as_str();
out_string.write_str(&input)?;
},
/* Rule::shader_fn_def => {
}, */
Rule::cws => {
let input = line.as_str();
out_string.write_str(&input)?;
},
Rule::newline => (),
_ => unimplemented!("ran into unhandled rule: {:?}", line.as_rule()),
}
} */
},
Rule::newline => {
let input = record.as_str();
out_string.write_str(&input)?;
},
Rule::EOI => (),
_ => unimplemented!("ran into unhandled rule: {:?}", record.as_rule()),
}
}
Ok(out_string)
}
}
fn try_find_requirement_module(module: &Module, req_name: &str) -> Option<String> {
for import in module.item_imports.values() {
if import.imports.contains(&req_name.to_string()) {
return Some(import.module.clone());
}
}
None
}
fn compile_definitions(modules: &HashMap<String, Module>, module: &Module, output: &mut String) -> Result<(), PreprocessorError> {
for (_, funcs) in &module.functions {
let mut requirements = VecDeque::from(funcs.requirements.clone());
while let Some(mut req) = requirements.pop_front() {
if req.module.is_none() {
let mod_name = try_find_requirement_module(&module, &req.name);
req.module = mod_name;
}
if let Some(module_name) = &req.module {
let req_module = modules.get(module_name)
.unwrap_or_else(|| panic!("invalid module import: {}", module_name));
let req_def = req_module.functions.get(&req.name)
.or_else(|| req_module.constants.get(&req.name))
.unwrap_or_else(|| panic!("invalid import: {} from {}", req.name, module_name));
let sub_req_names: Vec<String> = req_def.requirements.iter().map(|r| r.name.clone()).collect();
println!("got req: {}, subreqs: {:?}", req_def.name, sub_req_names);
if !req_def.requirements.is_empty() {
let mut requirements_output = String::new();
compile_definitions(modules, req_module, &mut requirements_output)?;
output.write_fmt(format_args!("\n// REQUIREMENTS OF {}::{}\n", module_name, req.name))?;
output.push_str(&requirements_output);
output.push_str("\n");
}
let func_src = &req_module.src[req_def.start_pos..req_def.end_pos];
output.write_fmt(format_args!("// SOURCE {}::{}\n", module_name, req.name))?;
output.push_str(func_src);
output.push_str("\n");
}
}
}
Ok(())
}

View File

@ -22,23 +22,11 @@ command_line = { preproc_prefix ~ (define_module_command | import_command) ~ new
// all characters used by wgsl // all characters used by wgsl
shader_code_char = { "@" | "-" | "+" | "*" | "/" | "=" | "(" | ")" | ">" | "<" | ";" | ":" | "." | "_" | "," } shader_code_char = { "@" | "-" | "+" | "*" | "/" | "=" | "(" | ")" | ">" | "<" | ";" | ":" | "." | "_" | "," }
shader_code_block = { "{" ~ newline* ~ (cws* ~ (shader_actual_code_line ~ cws*)* ~ newline)+ ~ cws* ~ "}" } shader_code_block = { "{" ~ newline* ~ (cws* ~ (shader_actual_code_line ~ cws*)* ~ newline)+ ~ cws* ~ "}" }
shader_code = { shader_code_block | shader_code_char | ASCII_ALPHANUMERIC+ } // an fn argument can be another function use
shader_code_fn_arg = _{ shader_code_fn_usage | shader_value | shader_ident }
shader_value_num = { ASCII_DIGIT* ~ ( "." ~ ASCII_DIGIT* )? } shader_code_fn_usage = { shader_ident ~ "(" ~ shader_code_fn_arg ~ ("," ~ ws* ~ shader_code_fn_arg)* ~ ")" }
shader_value_bool = { "true" | "false" } //shader_code_fn_usage = { shader_ident ~ "(in, 2.0)" }
shader_value = { shader_value_bool | shader_value_num } shader_code = { shader_code_block | shader_code_fn_usage | shader_value | shader_ident | shader_code_char }
// defines type of something i.e., `: f32`, `: u32`, etc.
shader_var_type = { ":" ~ ws* ~ shader_type }
shader_const_def = { "const" ~ ws ~ shader_ident ~ (ws* ~ shader_var_type)? ~ ws* ~ "=" ~ ws* ~ shader_value ~ ";" }
shader_var_name_type = { shader_ident ~ shader_var_type }
shader_fn_args = { "(" ~ shader_var_name_type ~ (ws* ~ "," ~ ws* ~ shader_var_name_type)* ~ ")" }
// the body of a function, including the opening and closing brackets
shader_fn_body = { "{" ~ newline* ~ (cws* ~ (shader_actual_code_line ~ cws*)* ~ newline)+ ~ "}" }
shader_fn_def = {
"fn" ~ ws ~ shader_ident ~ shader_fn_args ~ ws ~ "->" ~ ws ~ shader_type ~ ws ~ shader_fn_body
}
// usages of code from another module // usages of code from another module
shader_external_variable = { shader_ident ~ ( "::" ~ shader_ident)+ } shader_external_variable = { shader_ident ~ ( "::" ~ shader_ident)+ }
@ -46,7 +34,22 @@ shader_external_fn = { shader_external_variable ~ "(" ~ ANY* ~ ")" }
shader_external_code = _{ shader_external_fn | shader_external_variable } shader_external_code = _{ shader_external_fn | shader_external_variable }
shader_actual_code_line = _{ shader_external_code | shader_code } shader_actual_code_line = _{ shader_external_code | shader_code }
//shader_actual_code_line = _{ cws* ~ ( (shader_external_code | shader_code) ~ cws*)* }
shader_value_num = { ASCII_DIGIT+ ~ ( "." ~ ASCII_DIGIT+ )? }
shader_value_bool = { "true" | "false" }
shader_value = { shader_value_bool | shader_value_num }
// defines type of something i.e., `: f32`, `: u32`, etc.
shader_var_type = { ":" ~ ws* ~ shader_type }
shader_const_def = { "const" ~ ws ~ shader_ident ~ (ws* ~ shader_var_type)? ~ ws* ~ "=" ~ ws* ~ shader_value ~ ";" }
shader_var_name_type = { shader_ident ~ shader_var_type }
shader_fn_args = { "(" ~ shader_var_name_type? ~ (ws* ~ "," ~ ws* ~ shader_var_name_type)* ~ ")" }
// the body of a function, including the opening and closing brackets
shader_fn_body = { "{" ~ newline* ~ (cws* ~ (shader_actual_code_line ~ cws*)* ~ newline)+ ~ "}" }
shader_fn_def = {
"fn" ~ ws ~ shader_ident ~ shader_fn_args ~ ws ~ "->" ~ ws ~ shader_type ~ ws ~ shader_fn_body
}
// a line of shader code, including white space // a line of shader code, including white space
shader_code_line = { shader_code_line = {