Compare commits

...

10 Commits

15 changed files with 1353 additions and 107 deletions

1
.gitignore vendored
View File

@ -1 +1,2 @@
/target
out.wgsl

62
.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,62 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "lldb",
"request": "launch",
"name": "Debug shader_prepoc",
"cargo": {
"args": [
"build",
//"--manifest-path", "${workspaceFolder}/examples/testbed/Cargo.toml"
"--bin=shader_prepoc",
],
"filter": {
"name": "shader_prepoc",
"kind": "bin"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug unit tests in executable 'shader_prepoc'",
"cargo": {
"args": [
"test",
"--no-run",
"--bin=shader_prepoc",
"--package=shader_prepoc"
],
"filter": {
"name": "shader_prepoc",
"kind": "bin"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
/* {
"type": "lldb",
"request": "launch",
"name": "Test shader_prepoc",
"cargo": {
"args": [
"test",
"--bin=shader_prepoc",
],
"filter": {
"name": "shader_prepoc",
"kind": "bin"
}
},
"args": [],
"cwd": "${workspaceFolder}"
} */
]
}

194
Cargo.lock generated
View File

@ -2,6 +2,15 @@
# It is not intended for manual editing.
version = 3
[[package]]
name = "aho-corasick"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
dependencies = [
"memchr",
]
[[package]]
name = "block-buffer"
version = "0.10.4"
@ -56,24 +65,52 @@ dependencies = [
"version_check",
]
[[package]]
name = "lazy_static"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "libc"
version = "0.2.155"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
[[package]]
name = "log"
version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
[[package]]
name = "memchr"
version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
dependencies = [
"overload",
"winapi",
]
[[package]]
name = "once_cell"
version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]]
name = "overload"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
name = "pest"
version = "2.7.11"
@ -119,6 +156,12 @@ dependencies = [
"sha2",
]
[[package]]
name = "pin-project-lite"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02"
[[package]]
name = "proc-macro2"
version = "1.0.86"
@ -137,6 +180,35 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "regex"
version = "1.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata",
"regex-syntax",
]
[[package]]
name = "regex-automata"
version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
[[package]]
name = "sha2"
version = "0.10.8"
@ -149,13 +221,20 @@ dependencies = [
]
[[package]]
name = "shader_prepoc"
version = "0.1.0"
name = "sharded-slab"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
dependencies = [
"pest",
"pest_derive",
"lazy_static",
]
[[package]]
name = "smallvec"
version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
[[package]]
name = "syn"
version = "2.0.72"
@ -187,6 +266,73 @@ dependencies = [
"syn",
]
[[package]]
name = "thread_local"
version = "1.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c"
dependencies = [
"cfg-if",
"once_cell",
]
[[package]]
name = "tracing"
version = "0.1.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
dependencies = [
"pin-project-lite",
"tracing-attributes",
"tracing-core",
]
[[package]]
name = "tracing-attributes"
version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tracing-core"
version = "0.1.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
dependencies = [
"once_cell",
"valuable",
]
[[package]]
name = "tracing-log"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
dependencies = [
"log",
"once_cell",
"tracing-core",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b"
dependencies = [
"nu-ansi-term",
"sharded-slab",
"smallvec",
"thread_local",
"tracing-core",
"tracing-log",
]
[[package]]
name = "typenum"
version = "1.17.0"
@ -205,8 +351,48 @@ version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
[[package]]
name = "valuable"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"
[[package]]
name = "version_check"
version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
[[package]]
name = "wgsl_preprocessor"
version = "0.1.0"
dependencies = [
"pest",
"pest_derive",
"regex",
"thiserror",
"tracing",
"tracing-subscriber",
]
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"

View File

@ -1,8 +1,12 @@
[package]
name = "shader_prepoc"
name = "wgsl_preprocessor"
version = "0.1.0"
edition = "2021"
[dependencies]
pest = "2.7.11"
pest_derive = "2.7.11"
regex = "1.10.6"
thiserror = "1.0.63"
tracing = "0.1.40"
tracing-subscriber = "0.3.18"

83
README.md Normal file
View File

@ -0,0 +1,83 @@
# WGSL Preprocessor
This crate was created for my 3d game engine, Lyra Engine, which uses this as a preprocessor.
## Features
* Modules import other modules via defined module paths
More to come, check out issues in this repo.
## How-To
### Preprocessing modules
To process modules, they first must be parsed to find the module paths of each file:
```rust
let mut p = Processor::new();
let inner_include_src = fs::read_to_string("shaders/inner_include.wgsl").unwrap();
let inner_module_path = p.parse_module(&inner_include_src)
.unwrap().expect("failed to find module");
let simple_include_src = fs::read_to_string("shaders/simple.wgsl").unwrap();
let simple_module_path = p.parse_module(&simple_include_src)
.unwrap().expect("failed to find module");
let base_include_src = fs::read_to_string("shaders/base.wgsl").unwrap();
let base_module_path = p.parse_module(&base_include_src)
.unwrap().expect("failed to find module");
```
Then, after all the modules are parsed, you can preprocess the module. This is where all the imported modules and items are compiled into a single output shader:
```rust
let out = p.process_file(&base_module_path, &base_include_src).unwrap();
fs::write("out.wgsl", out).unwrap();
```
### Importing modules
To import a module, that module must have its path defined. To do so, at the top of the file, use the preprocessor command `#define_module <module_path>`, with `module_path` being the identifier of the module and how it's imported:
```wgsl
/// ==== shadows.wgsl
#define_module engine::shadows
const PCF_SAMPLES_NUM: u32 = 32;
/// ==== pbr.wgsl
#define_module engine::pbr
#import engine::shadows::{PCF_SAMPLES_NUM}
fn fs_main(in: VertexOutput) -> vec4<f32> {
// ...
let samples_num = PCF_SAMPLES_NUM;
// ...
}
```
The above example imports an item from the `engine::shadows` module. Keep in mind that you cannot have conflicting names of variables, functions, types, and other imported things. Instead of importing an item from the module, you can import the module and use items from it (still no conflicting names):
```wgsl
/// ==== shadows.wgsl
#define_module engine::shadows
const PCF_SAMPLES_NUM: u32 = 32;
/// ==== pbr.wgsl
#define_module engine::pbr
// not importing an item, but only a module
#import engine::shadows
fn fs_main(in: VertexOutput) -> vec4<f32> {
// ...
// use an item from the module here
let samples_num = shadows::PCF_SAMPLES_NUM;
// ...
}
```

7
examples/base.wgsl Normal file
View File

@ -0,0 +1,7 @@
#define_module base
#import simple
fn main() -> vec4<f32> {
let a = simple::simple_scalar * simple::do_something_cool(10.0);
return vec4<f32>(vec3<f32>(a), 1.0);
}

View File

@ -0,0 +1,8 @@
#define_module inner::some_include
const scalar: f32 = 5.0;
fn mult_some_nums(a: f32, b: f32) -> f32 {
let c = a * b;
return c;
}

8
examples/simple.wgsl Normal file
View File

@ -0,0 +1,8 @@
#define_module simple
#import inner::some_include::{scalar, mult_some_nums}
const simple_scalar: f32 = 50.0;
fn do_something_cool(in: f32) -> f32 {
return scalar * mult_some_nums(in, 2.0);
}

View File

@ -1,6 +0,0 @@
#import "simple.wgsl"
fn main() -> vec4<f32> {
let a = do_something_cool(10.0);
return vec4<f32>(vec3<f32>(a), 1.0);
}

View File

@ -1,10 +0,0 @@
// ==== START OF INCLUDE OF 'shaders/simple.wgsl' ====
fn do_something_cool(in: f32) -> f32 {
return in * 2.0;
}
// ==== END OF INCLUDE OF 'shaders/simple.wgsl' ====
fn main() -> vec4<f32> {
let a = do_something_cool(10.0);
return vec4<f32>(vec3<f32>(a), 1.0);
}

View File

@ -1,3 +0,0 @@
fn do_something_cool(in: f32) -> f32 {
return in * 2.0;
}

161
src/lib.rs Normal file
View File

@ -0,0 +1,161 @@
use std::{
collections::{HashMap, HashSet},
fs,
path::{Path, PathBuf},
};
use pest_derive::Parser;
mod preprocessor;
pub use preprocessor::*;
/// Pest parser generated by `pest_derive` from the grammar file `wgsl.pest`.
#[derive(Parser)]
#[grammar = "wgsl.pest"]
pub(crate) struct WgslParser;
/// Distinguishes between a variable and a function.
///
/// NOTE(review): presumably describes how an item from another module is used, but nothing
/// in the visible sources references this type -- confirm it is still needed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ExternalUsageType {
/// The item is a variable.
Variable,
/// The item is a function.
Function,
}
/// The items imported from a single module, i.e. `other_module::{scalar, do_math_func}`.
#[derive(Clone)]
pub struct Import {
/// Path of the module the items are imported from.
module: String,
/// Names of the items imported from `module`.
imports: Vec<String>,
}
/// Something a definition refers to by name, optionally qualified with a module path.
#[derive(Clone)]
pub struct DefRequirement {
/// None if the requirement is local, or if the module it comes from has not been
/// looked up yet.
module: Option<String>,
/// Name of the required item, without any module path.
name: String,
}
/// A named function or constant found in a module, with its location in the parsed source.
#[derive(Clone)]
pub struct Definition {
/// Name of the defined item.
name: String,
/// Names this definition refers to. Left empty for constants.
requirements: Vec<DefRequirement>,
/// The start byte position as a `usize`.
start_pos: usize,
/// The end byte position as a `usize`.
end_pos: usize,
}
/// Everything recorded about one parsed module.
#[derive(Default, Clone)]
pub struct Module {
/// The name of the module, as given by its `#define_module` command.
name: String,
/// The source code of the module. `Processor::parse_module` stores the text returned by
/// `remove_comments` here, not the raw input.
src: String,
/// Constants that this module defines, keyed by name
pub constants: HashMap<String, Definition>,
/// Functions that this module defines, keyed by name
pub functions: HashMap<String, Definition>,
/// Imports of things per module, keyed by the module they are imported from
/// ie `other_module::{scalar, do_math_func}`
item_imports: HashMap<String, Import>,
/// Imports of modules
///
/// NOTE(review): this comment used to refer to `import_usages`, which is not a field of
/// this struct -- the reference looks stale.
module_imports: HashSet<String>,
}
/// Collects every regular file found under `path`, descending into subdirectories.
///
/// Entries that are neither a directory nor a regular file according to
/// `DirEntry::metadata` are left out. The first I/O error aborts the walk and is
/// returned to the caller.
pub(crate) fn recurse_files(path: impl AsRef<Path>) -> std::io::Result<Vec<PathBuf>> {
    let mut found = Vec::new();

    for dir_entry in fs::read_dir(path)? {
        let dir_entry = dir_entry?;
        let kind = dir_entry.metadata()?;
        let entry_path = dir_entry.path();

        if kind.is_dir() {
            // flatten the files of the subdirectory into this level's result
            found.extend(recurse_files(&entry_path)?);
        } else if kind.is_file() {
            found.push(entry_path);
        }
    }

    Ok(found)
}
#[cfg(test)]
mod tests {
    use crate::Processor;

    /// Leaf module: defines a constant and a function, imports nothing.
    const INNER_MODULE: &str =
r#"#define_module engine::inner
const scalar: f32 = 5.0;
fn mult_some_nums(a: f32, b: f32) -> f32 {
let c = a * b;
return c;
}"#;

    /// Imports individual items from `engine::inner`.
    const SIMPLE_MODULE: &str =
r#"#define_module simple
#import engine::inner::{scalar, mult_some_nums}
const simple_scalar: f32 = 50.0;
fn do_something_cool(in: f32) -> f32 {
return scalar * mult_some_nums(in, 2.0);
}"#;

    /// Imports the `simple` module as a whole and uses its items through `simple::`.
    const MAIN_MODULE: &str =
r#"
#define_module base
#import simple
fn main() -> vec4<f32> {
let a = simple::simple_scalar * simple::do_something_cool(10.0);
return vec4<f32>(vec3<f32>(a), 1.0);
}"#;

    /// Items imported by name must be copied into the output of the importing module.
    #[test]
    fn single_layer_import() {
        let mut p = Processor::new();
        p.parse_module(INNER_MODULE).unwrap()
            .expect("failed to find module");
        let simple_path = p.parse_module(SIMPLE_MODULE).unwrap()
            .expect("failed to find module");

        let out = p.process_file(&simple_path, SIMPLE_MODULE).unwrap();
        assert!(out.contains("const scalar"), "definition of imported `const scalar` is missing!");
        assert!(out.contains("fn mult_some_nums"), "definition of imported `fn mult_some_nums` is missing!");
        assert!(out.contains("fn do_something_cool("), "definition of `fn do_something_cool` is missing!");
    }

    /// Tests importing things that imports depend on, and indirect usage of things,
    /// i.e., `simple::some_const`.
    #[test]
    fn double_layer_import_indirect_imports() {
        let mut p = Processor::new();
        p.parse_module(INNER_MODULE).unwrap()
            .expect("failed to find module");
        p.parse_module(SIMPLE_MODULE).unwrap()
            .expect("failed to find module");
        let main_path = p.parse_module(MAIN_MODULE).unwrap()
            .expect("failed to find module");

        let out = p.process_file(&main_path, MAIN_MODULE).unwrap();
        // items imported through the `simple` module
        assert!(out.contains("const simple_scalar"), "definition of imported `const simple_scalar` is missing!");
        assert!(out.contains("fn do_something_cool("), "definition of imported `fn do_something_cool` is missing!");
        // items that the imported items depend on
        assert!(out.contains("const scalar"), "definition of imported dependency, `const scalar` is missing!");
        assert!(out.contains("fn mult_some_nums("), "definition of imported dependency, `fn mult_some_nums` is missing!");
        // the `simple::` qualifier must be stripped at the usage site
        assert!(out.contains("simple_scalar * do_something_cool"), "indirect imports of `simple_scalar * do_something_cool` is missing!");
        assert!(out.contains("fn main("), "definition of `fn main` is missing!");
    }
}

View File

@ -1,74 +0,0 @@
use std::{fs, io::Write};
use pest::Parser;
use pest_derive::Parser;
#[derive(Parser)]
#[grammar = "wgsl.pest"]
pub struct WgslParser;
/// Expands `#import "<file>"` lines of `shaders/base.wgsl` by inlining the named file from
/// the `shaders/` directory, and writes the result to `shaders/out.wgsl`.
///
/// Shader code lines are copied through unchanged. Panics if a file cannot be read or
/// written, or if the input does not match the grammar.
fn main() {
    let unparsed_file = fs::read_to_string("shaders/base.wgsl")
        .expect("cannot read file");
    // add a new line to the end of the input to make the grammar happy
    let unparsed_file = format!("{unparsed_file}\n");

    let mut out_file = fs::File::create("shaders/out.wgsl").unwrap();

    let file = WgslParser::parse(Rule::file, &unparsed_file)
        .expect("unsuccessful parse") // unwrap the parse result
        .next().unwrap(); // get and unwrap the `file` rule; never fails

    for record in file.into_inner() {
        match record.as_rule() {
            Rule::command_line => {
                for command_line in record.into_inner() {
                    match command_line.as_rule() {
                        Rule::preproc_command => {},
                        Rule::import_command => {
                            let mut shader_file_pairs = command_line.into_inner();
                            let shader_file = shader_file_pairs.next().unwrap();
                            let shader_file = shader_file.as_str();
                            // remove surrounding quotes
                            let shader_file = &shader_file[1..shader_file.len() - 1];
                            println!("found include for file: {}", shader_file);

                            let path = format!("shaders/{}", shader_file);
                            let included_file = fs::read(&path)
                                .expect("cannot read file");

                            let start_header = format!("// ==== START OF INCLUDE OF '{}' ====\n", path);
                            let end_header = format!("\n// ==== END OF INCLUDE OF '{}' ====\n", path);

                            // `write` may stop after a partial write; `write_all` keeps going
                            // until the whole buffer is written or an error occurs
                            out_file.write_all(start_header.as_bytes()).unwrap();
                            out_file.write_all(&included_file).unwrap();
                            out_file.write_all(end_header.as_bytes()).unwrap();
                        },
                        _ => unreachable!()
                    }
                }
                println!("found line");
            },
            Rule::shader_code_line => {
                let input = record.as_str();
                println!("in: {}", input);
                let input = format!("{input}\n");
                out_file.write_all(input.as_bytes()).unwrap();
            },
            Rule::EOI => (),
            Rule::eol => (),
            _ => unreachable!(),
        }
    }
}

765
src/preprocessor.rs Normal file
View File

@ -0,0 +1,765 @@
use std::{collections::{HashMap, HashSet, VecDeque}, fmt::Write, fs, path::Path};
use pest::{iterators::Pair, Parser};
use regex::Regex;
use tracing::{debug, debug_span, instrument};
use crate::{recurse_files, DefRequirement, Definition, Import, Module, Rule, WgslParser};
/// Names that are never recorded as requirements when scanning shader code: keywords and
/// reserved words first, then built-in type names.
///
/// The array length in the type has to be kept in sync when entries are added or removed.
const RESERVED_WORDS: [&str; 201] = [
"NULL", "Self", "abstract", "active", "alignas", "alignof", "as", "asm", "asm_fragment",
"async", "attribute", "auto", "await", "become", "binding_array", "cast", "catch", "class",
"co_await", "co_return", "co_yield", "coherent", "column_major", "common", "compile",
"compile_fragment", "concept", "const_cast", "consteval", "constexpr", "constinit", "crate",
"debugger", "decltype", "delete", "demote", "demote_to_helper", "do", "dynamic_cast", "enum",
"explicit", "export", "extends", "extern", "external", "fallthrough", "filter", "final",
"finally", "friend", "from", "fxgroup", "get", "goto", "groupshared", "highp", "impl",
"implements", "import", "inline", "instanceof", "interface", "layout", "lowp", "macro",
"macro_rules", "match", "mediump", "meta", "mod", "module", "move", "mut", "mutable",
"namespace", "new", "nil", "noexcept", "noinline", "nointerpolation", "noperspective", "null",
"nullptr", "of", "operator", "package", "packoffset", "partition", "pass", "patch",
"pixelfragment", "precise", "precision", "premerge", "priv", "protected", "pub", "public",
"readonly", "ref", "regardless", "register", "reinterpret_cast", "require", "resource",
"restrict", "self", "set", "shared", "sizeof", "smooth", "snorm", "static", "static_assert",
"static_cast", "std", "subroutine", "super", "target", "template", "this", "thread_local",
"throw", "trait", "try", "type", "typedef", "typeid", "typename", "typeof", "union", "unless",
"unorm", "unsafe", "unsized", "use", "using", "varying", "virtual", "volatile", "wgsl",
"where", "with", "writeonly", "yield", "alias", "break", "case", "const", "const_assert",
"continue", "continuing", "default", "diagnostic", "discard", "else", "enable", "false", "fn",
"for", "if", "let", "loop", "override", "requires", "return", "struct", "switch", "true",
"var", "while",
// types
"vec4", "vec3", "vec2",
"mat2x2", "mat2x3", "mat2x4",
"mat3x2", "mat3x3", "mat3x4",
"mat4x2", "mat4x3", "mat4x4",
"f16", "f32", "i32", "u32",
"bool", "array", "atomic",
// texture types
"texture_1d",
"texture_2d", "texture_2d_array",
"texture_depth_2d", "texture_depth_2d_array",
"texture_depth_cube", "texture_depth_cube_array",
"texture_3d",
"texture_cube_array",
"sampler", "sampler_comparison",
];
/// Errors reported by the preprocessor.
#[derive(Debug, thiserror::Error)]
pub enum PreprocessorError {
/// An I/O operation failed, e.g. while reading a module file.
#[error("{0}")]
IoError(#[from] std::io::Error),
/// The source text was rejected by the pest grammar.
#[error("error parsing {0}")]
ParserError(#[from] pest::error::Error<Rule>),
/// Writing to the output string failed.
#[error("failure formatting preprocessor output to string ({0})")]
FormatError(#[from] std::fmt::Error),
/// `from_module` imports a module named `module` that is not known.
#[error("unknown module import '{module}', in {from_module}")]
UnknownModule { from_module: String, module: String },
/// `from_module` imports `item` from `module`, but `item` is not known there.
#[error("in {from_module}: unknown import from '{module}': `{item}`")]
UnknownImport { from_module: String, module: String, item: String },
/// An imported name collides with a local variable or function called `name`.
#[error("import usage from `{from_module}` conflicts with local variable/function: `{name}`")]
ConflictingImport { from_module: String, name: String },
}
/// The preprocessor: stores every module registered through `parse_module` and produces the
/// combined shader output in `process_file`.
#[derive(Default)]
pub struct Processor {
/// Registered modules, keyed by the module name they define.
pub modules: HashMap<String, Module>,
}
impl Processor {
pub fn new() -> Self {
Self::default()
}
//
/* fn add_func_requirements(&self, found_requirements: &mut HashSet<String>, requirements: &mut Vec<DefRequirement>, fn_name: &str) {
if found_requirements.contains(fn_name) {
return;
}
found_requirements.insert(fn_name.to_string());
debug!("Found call to `{}`", fn_name);
// ignore reserved words
if RESERVED_WORDS.contains(&fn_name) {
return;
}
let req = DefRequirement {
// module is discovered later
module: None,
name: fn_name.to_string(),
};
requirements.push(req);
} */
/// Walks the pairs inside `block` and collects the names it refers to as requirements.
///
/// Nested `shader_code_block` / `shader_code` pairs are visited recursively. Each distinct
/// name (or `module::name` path) is reported at most once: `found_requirements` remembers
/// what has been seen so far and is shared across the recursion. Plain identifiers,
/// unqualified function calls and the names of external variables are dropped when they
/// appear in `RESERVED_WORDS`; external function calls are not checked against that list.
///
/// `module` is only used to label the tracing span.
///
/// # Panics
/// Panics via `unimplemented!` on a rule that is not handled here, and via `unwrap` if the
/// text of an external function or variable contains no `::`.
#[instrument(fields(module = module.name), skip_all)]
fn get_imports_in_block(&mut self, module: &mut Module, block: Pair<Rule>, found_requirements: &mut HashSet<String>) -> Vec<DefRequirement> {
let mut requirements = vec![];
for code in block.into_inner() {
match code.as_rule() {
// nested code: recurse and merge whatever is found inside
Rule::shader_code_block | Rule::shader_code => {
let reqs = self.get_imports_in_block(module, code, found_requirements);
requirements.extend(reqs.into_iter());
},
// unqualified function call; the first inner pair is the function name
Rule::shader_code_fn_usage => {
let mut usage_inner = code.into_inner();
let fn_name = usage_inner.next().unwrap().as_str();
if found_requirements.contains(fn_name) {
continue;
}
found_requirements.insert(fn_name.to_string());
debug!("Found call to `{}`", fn_name);
// ignore reserved words
if RESERVED_WORDS.contains(&fn_name) {
continue;
}
let req = DefRequirement {
// module is discovered later
module: None,
name: fn_name.to_string(),
};
requirements.push(req);
},
// function call written as `module::name(...)`
Rule::shader_external_fn => {
let mut usage_inner = code.into_inner();
let fn_path = usage_inner.next().unwrap().as_str();
if found_requirements.contains(fn_path) {
continue;
}
found_requirements.insert(fn_path.to_string());
// everything before the last `::` is the module path
let (fn_module, fn_name) = fn_path.rsplit_once("::").unwrap();
debug!("Found call to `{}::{}`", fn_module, fn_name);
let req = DefRequirement {
module: Some(fn_module.into()),
name: fn_name.to_string(),
};
requirements.push(req);
},
// variable written as `module::name`
Rule::shader_external_variable => {
let pairs = code.into_inner();
// shader_external_variable is the only pair for this rule
let ident_path = pairs.as_str();
if found_requirements.contains(ident_path) {
continue;
}
found_requirements.insert(ident_path.to_string());
// everything before the last `::` is the module path
let (ident_module, ident_name) = ident_path.rsplit_once("::").unwrap();
debug!("Found call to `{}`", ident_name);
// ignore reserved words
if RESERVED_WORDS.contains(&ident_name) {
continue;
}
let req = DefRequirement {
module: Some(ident_module.into()),
name: ident_name.to_string(),
};
requirements.push(req);
},
Rule::newline => (),
Rule::cws => (),
Rule::shader_code_char => (),
// plain identifier; may turn out to be an imported item
Rule::shader_ident => {
let ident = code.as_str();
if found_requirements.contains(ident) {
continue;
}
found_requirements.insert(ident.to_string());
// ignore reserved words
if RESERVED_WORDS.contains(&ident) {
continue;
}
let req = DefRequirement {
// module is discovered later
module: None,
name: ident.to_string(),
};
requirements.push(req);
},
Rule::shader_value => (),
_ => unimplemented!("ran into unhandled rule: {:?}, {:?}", code.as_rule(), code.as_span())
}
}
requirements
}
/// Parses the source of one module, records its imports and definitions, and registers it
/// in `self.modules` under the name given by its `#define_module` command.
///
/// Returns `Ok(Some(name))` with the module name, or `Ok(None)` if no module name was
/// defined; in that case nothing is registered.
///
/// # Errors
/// Returns [`PreprocessorError::ParserError`] if the source does not match the grammar.
///
/// # Panics
/// Panics via `todo!` / `unimplemented!` on grammar rules that are not handled yet.
//#[instrument(skip(self, module_src))]
pub fn parse_module(
&mut self,
module_src: &str,
) -> Result<Option<String>, PreprocessorError> {
// the `module` field of the span is filled in once `#define_module` is seen
let e = debug_span!("parse_module", module = tracing::field::Empty).entered();
let module_src = remove_comments(module_src);
let file = WgslParser::parse(Rule::file, &module_src)?
.next()
.unwrap(); // get and unwrap the `file` rule; never fails
let mut module = Module::default();
module.src = module_src.to_string();
for record in file.into_inner() {
match record.as_rule() {
Rule::command_line => {
// the parser has found a preprocessor command, figure out what it is
let mut pairs = record.into_inner();
let command_line = pairs.next().unwrap();
match command_line.as_rule() {
Rule::import_types_command => {
let mut inner = command_line.into_inner();
// the first pair holds the module path, the remaining ones are the imported items
let import_module_command = inner.next().unwrap();
let mut import_module_command = import_module_command.into_inner();
let module_name = import_module_command.next().unwrap().as_str();
let types: Vec<String> = inner.map(|t| t.as_str().to_string()).collect();
debug!("Found imports from `{}`: `{:?}`", module_name, types);
// add these type imports to imports of the module
module.item_imports.entry(module_name.into())
.or_insert_with(|| Import {
module: module_name.into(),
imports: vec![],
})
.imports.extend(types.into_iter());
},
Rule::import_module_command => {
let mut inner = command_line.into_inner();
let module_name = inner.next().unwrap().as_str();
debug!("Found module import: `{}`", module_name);
module.module_imports.insert(module_name.into());
},
Rule::define_module_command => {
let mut shader_file_pairs = command_line.into_inner();
let module_name = shader_file_pairs.next().unwrap();
let module_name = module_name.as_str();
e.record("module", module_name);
debug!("Defined module as `{}`", module_name);
module.name = module_name.into();
}
_ => unreachable!(),
}
}
Rule::shader_code_line => {
for line in record.into_inner() {
match line.as_rule() {
Rule::shader_fn_def => {
let mut pairs = line.clone().into_inner();
// the first pair is the name of the function
let fn_name = pairs.next().unwrap().as_str().to_string();
debug!("Found function def: {fn_name}");
// skip the two pairs after the name; the next one is the function body
let fn_body = pairs.skip(2).next().unwrap();
let mut found_reqs = HashSet::default();
let requirements = self.get_imports_in_block(&mut module, fn_body, &mut found_reqs);
let line_span = line.as_span();
let start_pos = line_span.start();
let end_pos = line_span.end();
module.functions.insert(
fn_name.clone(),
Definition {
name: fn_name,
start_pos,
end_pos,
requirements
},
);
}
Rule::shader_const_def => {
let mut pairs = line.clone().into_inner();
// the first pair is the name of the constant
let const_name = pairs.next().unwrap().as_str().to_string();
debug!("Found const def: `{const_name}`");
let line_span = line.as_span();
let start_pos = line_span.start();
let end_pos = line_span.end();
module.constants.insert(
const_name.clone(),
Definition {
name: const_name,
start_pos,
end_pos,
// constants are recorded without requirements
requirements: vec![],
},
);
}
Rule::shader_code => {
let mut shader_inner = line.clone().into_inner();
let code_type = shader_inner.next().unwrap();
match code_type.as_rule() {
Rule::shader_code_block => {
// the block is scanned, but the requirements found in it are discarded
let mut found_reqs = HashSet::default();
self.get_imports_in_block(&mut module, code_type, &mut found_reqs);
},
Rule::shader_code_fn_usage => {
todo!("cannot handle usage at this level");
},
Rule::shader_code_char => {
todo!("I think this can be ignored");
},
_ => unimplemented!("ran into unhandled rule: {:?}", line.as_span())
}
}
Rule::cws => (),
Rule::newline => (),
_ => unimplemented!("ran into unhandled rule: ({:?}) {:?}", line.as_rule(), line.as_span())
}
}
}
Rule::newline => (),
Rule::EOI => (),
_ => unimplemented!("ran into unhandled rule: {:?}", record.as_span())
}
}
// a source without a module name cannot be imported, so it is not registered
if module.name.is_empty() {
Ok(None)
} else {
let name = module.name.clone();
self.modules.insert(name.clone(), module);
Ok(Some(name))
}
}
/// Recursively scans `path` and hands every file with a matching extension to
/// [`Processor::parse_module`].
///
/// Parameters:
/// * `path` - The directory to search for files in.
/// * `extensions` - The extensions a file must have to be parsed, written without a
/// leading '.'
///
/// Returns the number of files that were read and parsed. The first I/O or parse error
/// stops the scan and is returned.
#[instrument(skip(self, path, extensions))]
pub fn parse_modules<P: AsRef<Path>, const N: usize>(
    &mut self,
    path: P,
    extensions: [&str; N],
) -> Result<usize, PreprocessorError> {
    let mut count = 0;

    for candidate in recurse_files(path)? {
        // files without an extension, or with a non-UTF-8 one, never match
        let wanted = candidate
            .extension()
            .and_then(|ext| ext.to_str())
            .map_or(false, |ext| extensions.contains(&ext));
        if !wanted {
            continue;
        }

        let source = fs::read_to_string(candidate)?;
        self.parse_module(&source)?;
        count += 1;
    }

    Ok(count)
}
/// Builds the import header for the module registered as `module_path` by running
/// `compile_definitions` over it and returning the text that was written.
///
/// # Panics
/// Panics if no module is registered under `module_path`.
#[instrument(skip(self))]
fn generate_header(&mut self, module_path: &str) -> Result<String, PreprocessorError> {
let module = self.modules.get(module_path).unwrap();
let mut output = String::new();
compile_definitions(&self.modules, module, &mut output)?;
Ok(output)
}
/// Writes the text of the pairs inside `shader_code_rule` to `output`, dropping module
/// qualifiers on the way.
///
/// Function calls and variables written with a module path (`module::name`) are emitted as
/// just `name`. Function definitions are rebuilt with their body processed recursively.
/// Every other handled pair is copied as-is.
///
/// # Panics
/// Panics via `unimplemented!` on a rule that is not handled here, and if the path of an
/// external function contains no `::`.
#[instrument(skip(self, shader_code_rule, output))]
fn output_shader_code_line(&self, shader_code_rule: Pair<Rule>, output: &mut String) -> Result<(), std::fmt::Error> {
for line in shader_code_rule.into_inner() {
match line.as_rule() {
Rule::shader_external_fn => {
let mut pairs = line.clone().into_inner();
let fn_path = pairs.next().unwrap().as_str().to_string();
// the rest of the pairs are the arguments for the fn
let fn_args = pairs.as_str();
// remove the module from the identifier and write it to the output
if let Some((_, fn_name)) = fn_path.rsplit_once("::") {
output.write_str(fn_name)?;
output.write_fmt(format_args!("({})", fn_args))?;
} else {
// TODO: not really sure how this would get triggered
panic!("Unknown error, rule input: {}", line.as_str());
}
},
Rule::shader_external_variable => {
let pairs = line.clone().into_inner();
// the last pair is the name of the variable
let ident_name = pairs.last().unwrap().as_str();
output.write_str(ident_name)?;
},
//shader_external_variable
Rule::shader_fn_def => {
// inner pairs, in order: name, arguments, return type, body
let mut rule_inner = line.into_inner();
let fn_name = rule_inner.next().unwrap().as_str();
let args = rule_inner.next().unwrap().as_str();
let fn_ret = rule_inner.next().unwrap().as_str();
let fn_body = rule_inner.next().unwrap();
// process the body separately so qualifiers inside it are stripped as well
let mut body_output = String::new();
self.output_shader_code_line(fn_body, &mut body_output)?;
// to escape {, must use two
output.write_fmt(format_args!("fn {}{} -> {} {{{}}}",
fn_name, args, fn_ret, body_output))?;
},
Rule::shader_code => {
self.output_shader_code_line(line, output)?;
},
// nothing to rewrite in these, copy the text through unchanged
Rule::shader_code_fn_usage | Rule::shader_value | Rule::shader_const_def | Rule::shader_ident | Rule::shader_code_char | Rule::cws => {
let input = line.as_str();
output.write_str(&input)?;
},
Rule::newline => {
output.write_str("\n")?;
},
_ => unimplemented!("ran into unhandled rule: {:?}, {:?}", line.as_rule(), line.as_str()),
}
}
Ok(())
}
/// Produces the final shader text for the module registered as `module_path`.
///
/// The output starts with the generated import header, wrapped in
/// `// START OF IMPORT HEADER` / `// END OF IMPORT HEADER` marker lines. It is followed by
/// the code of `module_src`, with preprocessor command lines left out.
///
/// # Panics
/// Panics via `unimplemented!` on grammar rules that are not handled here.
#[instrument(skip(self, module_src))]
pub fn process_file(&mut self, module_path: &str, module_src: &str) -> Result<String, PreprocessorError> {
    let stripped_src = remove_comments(module_src);

    // the output will be at least the length of the comment-free source
    let mut result = String::with_capacity(stripped_src.len());

    let parsed_file = WgslParser::parse(Rule::file, &stripped_src)?
        .next()
        .unwrap(); // the `file` rule is always present after a successful parse

    let import_header = self.generate_header(module_path)?;
    write!(result, "// START OF IMPORT HEADER\n{}// END OF IMPORT HEADER\n", import_header)?;

    for record in parsed_file.into_inner() {
        match record.as_rule() {
            Rule::command_line => {
                // preprocessor commands are consumed here and never reach the output
                let command = record.into_inner().next().unwrap();
                match command.as_rule() {
                    Rule::import_module_command
                    | Rule::import_types_command
                    | Rule::define_module_command => (),
                    _ => unimplemented!("ran into unhandled rule: {:?}", command.as_span()),
                }
            }
            Rule::shader_code_line => self.output_shader_code_line(record, &mut result)?,
            Rule::newline => result.write_str(record.as_str())?,
            Rule::cws | Rule::EOI => (),
            other => unimplemented!("ran into unhandled rule: {:?}", other),
        }
    }

    Ok(result)
}
}
/// Looks up which imported module provides the item named `req_name`.
///
/// Walks the item imports of `module` and returns a copy of the `module` field of the first
/// import (in the iteration order of `item_imports`) whose `imports` list contains
/// `req_name`. Returns `None` when no import lists that name.
fn try_find_requirement_module(module: &Module, req_name: &str) -> Option<String> {
    // `contains` needs an owned `String` to compare against, so build it once
    let wanted = req_name.to_string();
    module.item_imports
        .values()
        .find(|import| import.imports.contains(&wanted))
        .map(|import| import.module.clone())
}
/// Appends to `output` the source of every definition required by the functions of `module`.
///
/// For each requirement of each function in `module`:
/// * a requirement without an explicit module is resolved through the module's item imports;
///   if that also fails it is assumed to be local and nothing is emitted for it;
/// * the required definition is looked up in the providing module's functions first, then
///   in its constants;
/// * when that definition has requirements of its own, the providing module is compiled
///   recursively with this same function and the result is emitted first, under a
///   `REQUIREMENTS OF` banner;
/// * finally the definition's source slice (`start_pos..end_pos` of the providing module's
///   `src`) is emitted under a `SOURCE` banner.
///
/// Panics when a requirement names a module missing from `modules`, or a definition that the
/// named module does not contain. Returns an error when writing to `output` fails.
#[instrument(fields(module = module.name), skip_all)]
fn compile_definitions(modules: &HashMap<String, Module>, module: &Module, output: &mut String) -> Result<(), PreprocessorError> {
    for func in module.functions.values() {
        let mut pending = VecDeque::from(func.requirements.clone());

        while let Some(mut req) = pending.pop_front() {
            // requirements that do not name their module are resolved via the item imports
            if req.module.is_none() {
                req.module = try_find_requirement_module(module, &req.name);
            }

            let module_name = match &req.module {
                Some(name) => name,
                None => {
                    debug!("Could not find module for `{}`, assuming its local", req.name);
                    continue;
                }
            };

            let req_module = match modules.get(module_name) {
                Some(found) => found,
                None => panic!("invalid module import: {}", module_name),
            };

            // a requirement may refer to either a function or a constant
            let lookup = req_module.functions.get(&req.name)
                .or_else(|| req_module.constants.get(&req.name));
            let req_def = match lookup {
                Some(def) => def,
                None => panic!("invalid import: {} from {}", req.name, module_name),
            };

            if !req_def.requirements.is_empty() {
                let sub_req_names: Vec<String> = req_def.requirements.iter().map(|r| r.name.clone()).collect();
                debug!("Found requirement: {}, with the following sub-requirements: {:?}", req_def.name, sub_req_names);

                // compile into a scratch buffer so the banner can be written before the result
                let mut nested_output = String::new();
                compile_definitions(modules, req_module, &mut nested_output)?;

                write!(output, "\n// REQUIREMENTS OF {}::{}\n", module_name, req.name)?;
                output.push_str(&nested_output);
                output.push('\n');
            }

            let def_src = &req_module.src[req_def.start_pos..req_def.end_pos];
            write!(output, "// SOURCE {}::{}\n", module_name, req.name)?;
            output.push_str(def_src);
            output.push('\n');
        }
    }

    Ok(())
}
/// Returns a copy of `text` with `//` line comments and `/* ... */` block comments removed.
///
/// Block comments may be nested; text is only copied while the nesting depth is zero. A line
/// comment is removed up to, but not including, the newline that ends it. Comment markers
/// that sit inside an already removed line comment are ignored, and a `//` inside a block
/// comment has no effect. Text without any comment marker is returned unchanged.
#[instrument(skip(text))]
fn remove_comments(text: &str) -> String {
    let marker_regex = Regex::new(r"(//|/\*|\*/)").unwrap();
    let mut markers = marker_regex.find_iter(text).peekable();
    if markers.peek().is_none() {
        return text.to_string();
    }

    let mut stripped = String::new();
    // offset at which copying resumes once the current comment is over
    let mut resume_at = 0;
    let mut depth: u32 = 0;

    loop {
        // take the next marker, skipping any that lie before `resume_at`; those were
        // swallowed by a line comment that has already been handled
        let marker = loop {
            match markers.next() {
                Some(m) if m.start() < resume_at => continue,
                found => break found,
            }
        };

        // copy everything up to the marker (or to the end of the text) when outside comments
        let stop = marker.map_or(text.len(), |m| m.start());
        if depth == 0 {
            stripped.push_str(&text[resume_at..stop]);
        }

        let marker = match marker {
            Some(m) => m,
            None => return stripped,
        };

        match marker.as_str() {
            "//" if depth == 0 => {
                // resume at the end of the line; the newline itself is kept
                let rest = &text[marker.end()..];
                resume_at = marker.end() + rest.find('\n').unwrap_or(rest.len());
            },
            // a line comment marker inside a block comment changes nothing
            "//" => (),
            "/*" => {
                depth += 1;
                resume_at = marker.end();
            },
            "*/" => {
                // a stray closer at depth zero is dropped without going negative
                depth = depth.saturating_sub(1);
                resume_at = marker.end();
            },
            _ => unreachable!(),
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::preprocessor::remove_comments;

    /// Text that every fixture places inside comments; it must never survive stripping.
    const COMMENT_TEXT: &str = "PLEASE REMOVE ME";

    /// Runs `remove_comments` on `input` and checks that no comment text is left while every
    /// snippet in `kept` is still present. The stripped output is printed on failure.
    #[track_caller]
    fn assert_only_comments_removed(input: &str, kept: &[&str]) {
        let out = remove_comments(input);
        assert!(!out.contains(COMMENT_TEXT), "{}", out);
        for snippet in kept {
            assert!(out.contains(snippet), "{}", out);
        }
    }

    /// A single `//` line comment is removed and the following line is kept.
    #[test]
    fn test_remove_line_comments() {
        const INPUT: &str = r#"
// PLEASE REMOVE ME
DONT TOUCH ME
"#;
        assert_only_comments_removed(INPUT, &["DONT TOUCH ME"]);
    }

    /// A block comment spanning several lines is removed.
    #[test]
    fn test_remove_block_comments() {
        const INPUT: &str = r#"
/*
PLEASE REMOVE ME
*/
DONT TOUCH ME
"#;
        assert_only_comments_removed(INPUT, &["DONT TOUCH ME"]);
    }

    /// Block comments that open (and possibly close) on the same line as their text are removed.
    #[test]
    fn test_remove_block_comments_single_line() {
        const INPUT: &str = r#"
/* PLEASE REMOVE ME */
/* PLEASE REMOVE ME
*/
DONT TOUCH ME
"#;
        assert_only_comments_removed(INPUT, &["DONT TOUCH ME"]);
    }

    /// A block comment nested inside another block comment is removed together with it.
    #[test]
    fn test_remove_block_comments_nested() {
        const INPUT: &str = r#"
/*
PLEASE REMOVE ME
/*
PLEASE REMOVE ME
*/
*/
DONT TOUCH ME
"#;
        assert_only_comments_removed(INPUT, &["DONT TOUCH ME"]);
    }

    /// Nested block comments whose markers share lines with comment text are removed.
    #[test]
    fn test_remove_block_comments_nested_same_line() {
        const INPUT: &str = r#"
/* PLEASE REMOVE ME
/* PLEASE REMOVE ME */
/* PLEASE REMOVE ME
*/
*/
DONT TOUCH ME
"#;
        assert_only_comments_removed(INPUT, &["DONT TOUCH ME"]);
    }

    /// Mixes multi-line nesting, several nested comments on one line and a line comment.
    #[test]
    fn test_remove_block_comments_nested_same_line_cursed() {
        const INPUT: &str = r#"
/* PLEASE REMOVE ME
/* PLEASE REMOVE ME */
/* PLEASE REMOVE ME
*/
*/
/* PLEASE REMOVE ME /* PLEASE REMOVE ME */ /* PLEASE REMOVE ME */ */
//PLEASE REMOVE ME
DONT TOUCH ME
"#;
        assert_only_comments_removed(INPUT, &["DONT TOUCH ME"]);
    }

    /// Comments interleaved with shader code are removed while the code itself is kept.
    #[test]
    fn test_remove_several_comments() {
        const INPUT: &str = r#"// PLEASE REMOVE ME
// PLEASE REMOVE ME
const simple_scalar: f32 = 50.0;
// PLEASE REMOVE ME
// PLEASE REMOVE ME
const scalar: f32 = 5.0;
// PLEASE REMOVE ME
// PLEASE REMOVE ME
fn mult_some_nums(a: f32, b: f32) -> f32 {
let c = a * b;
return c;
}
// PLEASE REMOVE ME
fn do_something_cool(in: f32) -> f32 {
return scalar * mult_some_nums(in, 2.0);
}
// PLEASE REMOVE ME
/* PLEASE REMOVE ME /* PLEASE REMOVE ME */ /* PLEASE REMOVE ME */ */
"#;
        assert_only_comments_removed(INPUT, &[
            "const simple_scalar: f32 = 50.0;",
            "const scalar: f32 = 5.0;",
            "fn mult_some_nums(a: f32, b: f32) -> f32 {",
            "fn do_something_cool(in: f32) -> f32 {",
        ]);
    }
}

View File

@ -1,16 +1,70 @@
// NOTE(review): `import_command`, `preproc_prefix`, `command_line`, `shader_code`,
// `shader_code_line` and `file` are each defined twice in this listing. It looks like an
// old and a new revision of the grammar shown together; confirm which definition is current.
// a quoted shader file name: ASCII letters, a ".", then ASCII letters
shader_file = { "\"" ~ ASCII_ALPHA* ~ "." ~ ASCII_ALPHA* ~ "\"" }
// `import` followed by a quoted shader file name
import_command = { "import" ~ ws ~ shader_file }
// the character that starts a preprocessor command
preproc_prefix = { "#" }
// an identifier: one or more ASCII letters, digits or underscores
shader_ident = { (ASCII_ALPHANUMERIC | "_")+ }
// a shader generic could have multiple generics, e.g., vec4<f32>
//shader_generic_type = { shader_ident ~ "<" ~ shader_ident ~ ">"+ }
//shader_type = { shader_generic_type | shader_ident }
// a type name, optionally followed by one generic argument in angle brackets
shader_type = { shader_ident ~ ("<" ~ shader_ident ~ ">")? }
// a module path: identifiers separated by "::"
shader_module = { shader_ident ~ ( "::" ~ shader_ident)* }
// `import` followed by a module path
import_module_command = { "import" ~ ws ~ shader_module }
// a brace-enclosed, comma-separated list of identifiers (silent rule)
import_list = _{ "{" ~ shader_ident ~ (ws* ~ "," ~ ws* ~ shader_ident)* ~ "}" }
// a module import followed by "::" and the list of items to import from that module
import_types_command = { import_module_command ~ "::" ~ import_list }
//import_types_command = { "import" ~ shader_ident ~ ( "::" ~ shader_ident)* ~ "::" ~ import_list }
// either form of import; the item-list form is tried first (silent rule)
import_command = _{ import_types_command | import_module_command }
// `define_module` followed by a module path
define_module_command = { "define_module" ~ ws ~ shader_module }
preproc_prefix = _{ "#" }
// a line of preprocessor commands
command_line = { preproc_prefix ~ import_command }
command_line = { preproc_prefix ~ (define_module_command | import_command) ~ newline }
// all characters used by wgsl
shader_code = { ASCII_ALPHANUMERIC | "{" | "}" | "@" | "-" | "+" | "=" | "(" | ")" | ">" | "<" | ";" | "." | "_" | "," }
// a line of shader code, including white space
shader_code_line = { (ws* ~ shader_code ~ ws*)* }
// a single punctuation or operator character that may appear in shader code
shader_code_char = { "@" | "-" | "+" | "*" | "/" | "=" | "(" | ")" | ">" | "<" | ";" | ":" | "." | "_" | "," }
// a brace-delimited block made of one or more newline-terminated lines of code
shader_code_block = { "{" ~ newline* ~ (cws* ~ (shader_actual_code_line ~ cws*)* ~ newline)+ ~ cws* ~ "}" }
// an fn argument can be another function use
shader_code_fn_arg = _{ shader_code_fn_usage | shader_value | shader_ident }
// a call of a local function: identifier plus a parenthesised list of one or more arguments
shader_code_fn_usage = { shader_ident ~ "(" ~ shader_code_fn_arg ~ ("," ~ ws* ~ shader_code_fn_arg)* ~ ")" }
//shader_code_fn_usage = { shader_ident ~ "(in, 2.0)" }
// one piece of shader code; the alternatives are tried in the order listed
shader_code = { shader_code_block | shader_code_fn_usage | shader_value | shader_ident | shader_code_char }
file = { SOI ~ ( (command_line | shader_code_line) ~ NEWLINE)* ~ EOI }
// usages of code from another module
shader_external_variable = { shader_ident ~ ( "::" ~ shader_ident)+ }
//shader_fn_args2 = { shader_var_name_type? ~ (ws* ~ "," ~ ws* ~ shader_var_name_type)* }
// a call of a function addressed through a "::" path, with one or more arguments
shader_external_fn = { shader_external_variable ~ "(" ~ shader_code_fn_arg ~ ("," ~ ws* ~ shader_code_fn_arg)* ~ ")" }
// external usages; the call form is tried before the plain path (silent rule)
shader_external_code = _{ shader_external_fn | shader_external_variable }
// code inside a body or block: external usages are tried before ordinary code (silent rule)
shader_actual_code_line = _{ shader_external_code | shader_code }
// a numeric literal: digits with an optional fractional part
shader_value_num = { ASCII_DIGIT+ ~ ( "." ~ ASCII_DIGIT+ )? }
shader_value_bool = { "true" | "false" }
// a literal value: boolean or number
shader_value = { shader_value_bool | shader_value_num }
// defines type of something e.g., `: f32`, `: u32`, etc.
shader_var_type = { ":" ~ ws* ~ shader_type }
// a constant definition: `const NAME`, an optional type, `=`, a literal value and `;`
shader_const_def = { "const" ~ ws ~ shader_ident ~ (ws* ~ shader_var_type)? ~ ws* ~ "=" ~ ws* ~ shader_value ~ ";" }
// a name immediately followed by its type annotation
shader_var_name_type = { shader_ident ~ shader_var_type }
// a parenthesised, comma-separated list of typed parameters; may be empty
shader_fn_args = { "(" ~ shader_var_name_type? ~ (ws* ~ "," ~ ws* ~ shader_var_name_type)* ~ ")" }
// the body of a function, including the opening and closing brackets
shader_fn_body = { "{" ~ newline* ~ (cws* ~ (shader_actual_code_line ~ cws*)* ~ newline)+ ~ "}" }
// a function definition: `fn NAME(args) -> TYPE { body }`; the return type is mandatory here
shader_fn_def = {
"fn" ~ ws ~ shader_ident ~ shader_fn_args ~ ws ~ "->" ~ ws ~ shader_type ~ ws ~ shader_fn_body
}
// a line of shader code, including white space
shader_code_line = {
shader_fn_def |
(shader_const_def ~ newline) |
ws* ~ newline
}
//shader_code_line = { shader_fn_def | shader_const_def | (ws* ~ (shader_external_code | shader_code)* ~ ws*) }
// a whole file: any number of command lines and shader code lines up to the end of input
file = { SOI ~ ( (command_line | shader_code_line) )* ~ EOI }
// whitespace
ws = _{ " " | "\t" }
// capturing white space
cws = { " " | "\t" }
newline = { "\n" | "\r\n" | "\r" }