add tests, fix removing comments

This commit is contained in:
SeanOMik 2024-08-08 21:29:52 -04:00
parent 7940cbdba9
commit 40b9581b46
6 changed files with 380 additions and 65 deletions

36
.vscode/launch.json vendored
View File

@ -21,6 +21,42 @@
}, },
"args": [], "args": [],
"cwd": "${workspaceFolder}" "cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug unit tests in executable 'shader_prepoc'",
"cargo": {
"args": [
"test",
"--no-run",
"--bin=shader_prepoc",
"--package=shader_prepoc"
],
"filter": {
"name": "shader_prepoc",
"kind": "bin"
} }
},
"args": [],
"cwd": "${workspaceFolder}"
},
/* {
"type": "lldb",
"request": "launch",
"name": "Test shader_prepoc",
"cargo": {
"args": [
"test",
"--bin=shader_prepoc",
],
"filter": {
"name": "shader_prepoc",
"kind": "bin"
}
},
"args": [],
"cwd": "${workspaceFolder}"
} */
] ]
} }

39
Cargo.lock generated
View File

@ -2,6 +2,15 @@
# It is not intended for manual editing. # It is not intended for manual editing.
version = 3 version = 3
[[package]]
name = "aho-corasick"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
dependencies = [
"memchr",
]
[[package]] [[package]]
name = "block-buffer" name = "block-buffer"
version = "0.10.4" version = "0.10.4"
@ -171,6 +180,35 @@ dependencies = [
"proc-macro2", "proc-macro2",
] ]
[[package]]
name = "regex"
version = "1.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata",
"regex-syntax",
]
[[package]]
name = "regex-automata"
version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
[[package]] [[package]]
name = "sha2" name = "sha2"
version = "0.10.8" version = "0.10.8"
@ -188,6 +226,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"pest", "pest",
"pest_derive", "pest_derive",
"regex",
"thiserror", "thiserror",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",

View File

@ -6,6 +6,7 @@ edition = "2021"
[dependencies] [dependencies]
pest = "2.7.11" pest = "2.7.11"
pest_derive = "2.7.11" pest_derive = "2.7.11"
regex = "1.10.6"
thiserror = "1.0.63" thiserror = "1.0.63"
tracing = "0.1.40" tracing = "0.1.40"
tracing-subscriber = "0.3.18" tracing-subscriber = "0.3.18"

View File

@ -2,16 +2,6 @@
const scalar: f32 = 5.0; const scalar: f32 = 5.0;
// test to ignore comments
/*
some test ig
/* inner comment */
*/
/* c-style comment */
fn mult_some_nums(a: f32, b: f32) -> f32 { fn mult_some_nums(a: f32, b: f32) -> f32 {
let c = a * b; let c = a * b;
return c; return c;

View File

@ -21,8 +21,6 @@ fn main() {
.init(); .init();
let mut p = Processor::new(); let mut p = Processor::new();
//let f = p.parse_modules("shaders", ["wgsl"]).unwrap();
//println!("Parsed {} modules:", f);
let inner_include_src = fs::read_to_string("shaders/inner_include.wgsl").unwrap(); let inner_include_src = fs::read_to_string("shaders/inner_include.wgsl").unwrap();
p.parse_module(&inner_include_src) p.parse_module(&inner_include_src)
.unwrap() .unwrap()
@ -168,3 +166,80 @@ pub(crate) fn recurse_files(path: impl AsRef<Path>) -> std::io::Result<Vec<PathB
Ok(buf) Ok(buf)
} }
#[cfg(test)]
mod tests {
use crate::Processor;
const INNER_MODULE: &'static str =
r#"#define_module engine::inner
const scalar: f32 = 5.0;
fn mult_some_nums(a: f32, b: f32) -> f32 {
let c = a * b;
return c;
}"#;
const SIMPLE_MODULE: &'static str =
r#"#define_module simple
#import engine::inner::{scalar, mult_some_nums}
const simple_scalar: f32 = 50.0;
fn do_something_cool(in: f32) -> f32 {
return scalar * mult_some_nums(in, 2.0);
}"#;
const MAIN_MODULE: &'static str =
r#"
#define_module base
#import simple
fn main() -> vec4<f32> {
let a = simple::simple_scalar * simple::do_something_cool(10.0);
return vec4<f32>(vec3<f32>(a), 1.0);
}"#;
#[test]
fn single_layer_import() {
let mut p = Processor::new();
p.parse_module(&INNER_MODULE).unwrap()
.expect("failed to find module");
let simple_path = p.parse_module(&SIMPLE_MODULE).unwrap()
.expect("failed to find module");
let out = p.process_file(&simple_path, &SIMPLE_MODULE).unwrap();
assert!(out.contains("const scalar"), "definition of imported `const scalar` is missing!");
assert!(out.contains("fn mult_some_nums"), "definition of imported `fn mult_some_nums` is missing!");
assert!(out.contains("fn do_something_cool("), "definition of `fn do_something_cool` is missing!");
}
/// Tests importing things that imports depend on, and indirect usage of things,
/// i.e., `simple::some_const`.
#[test]
fn double_layer_import_indirect_imports() {
let mut p = Processor::new();
p.parse_module(&INNER_MODULE).unwrap()
.expect("failed to find module");
p.parse_module(&SIMPLE_MODULE).unwrap()
.expect("failed to find module");
let main_path = p.parse_module(&MAIN_MODULE).unwrap()
.expect("failed to find module");
let out = p.process_file(&main_path, &MAIN_MODULE).unwrap();
assert!(out.contains("const simple_scalar"), "definition of imported `const simple_scalar` is missing!");
assert!(out.contains("fn do_something_cool("), "definition of imported `fn do_something_cool` is missing!");
assert!(out.contains("const scalar"), "definition of imported dependency, `const scalar` is missing!");
assert!(out.contains("fn mult_some_nums("), "definition of imported dependency, `fn mult_some_nums` is missing!");
assert!(out.contains("fn do_something_cool("), "definition of imported `fn do_something_cool` is missing!");
assert!(out.contains("simple_scalar * do_something_cool"), "indirect imports of `simple_scalar * do_something_cool` is missing!");
assert!(out.contains("fn main("), "definition of `fn main` is missing!");
}
}

View File

@ -1,6 +1,7 @@
use std::{collections::{HashMap, HashSet, VecDeque}, fmt::Write, fs, path::Path}; use std::{collections::{HashMap, HashSet, VecDeque}, fmt::Write, fs, path::Path};
use pest::{iterators::Pair, Parser}; use pest::{iterators::Pair, Parser};
use regex::Regex;
use tracing::{debug, debug_span, instrument}; use tracing::{debug, debug_span, instrument};
use crate::{recurse_files, DefRequirement, Definition, Import, Module, Rule, WgslParser}; use crate::{recurse_files, DefRequirement, Definition, Import, Module, Rule, WgslParser};
@ -72,7 +73,8 @@ impl Processor {
Self::default() Self::default()
} }
fn add_func_requirements(&self, found_requirements: &mut HashSet<String>, requirements: &mut Vec<DefRequirement>, fn_name: &str) { //
/* fn add_func_requirements(&self, found_requirements: &mut HashSet<String>, requirements: &mut Vec<DefRequirement>, fn_name: &str) {
if found_requirements.contains(fn_name) { if found_requirements.contains(fn_name) {
return; return;
} }
@ -91,7 +93,7 @@ impl Processor {
name: fn_name.to_string(), name: fn_name.to_string(),
}; };
requirements.push(req); requirements.push(req);
} } */
#[instrument(fields(module = module.name), skip_all)] #[instrument(fields(module = module.name), skip_all)]
fn get_imports_in_block(&mut self, module: &mut Module, block: Pair<Rule>, found_requirements: &mut HashSet<String>) -> Vec<DefRequirement> { fn get_imports_in_block(&mut self, module: &mut Module, block: Pair<Rule>, found_requirements: &mut HashSet<String>) -> Vec<DefRequirement> {
@ -107,7 +109,24 @@ impl Processor {
let mut usage_inner = code.into_inner(); let mut usage_inner = code.into_inner();
let fn_name = usage_inner.next().unwrap().as_str(); let fn_name = usage_inner.next().unwrap().as_str();
self.add_func_requirements(found_requirements, &mut requirements, fn_name); if found_requirements.contains(fn_name) {
continue;
}
found_requirements.insert(fn_name.to_string());
debug!("Found call to `{}`", fn_name);
// ignore reserved words
if RESERVED_WORDS.contains(&fn_name) {
continue;
}
let req = DefRequirement {
// module is discovered later
module: None,
name: fn_name.to_string(),
};
requirements.push(req);
}, },
Rule::shader_external_fn => { Rule::shader_external_fn => {
let mut usage_inner = code.into_inner(); let mut usage_inner = code.into_inner();
@ -192,11 +211,9 @@ impl Processor {
&mut self, &mut self,
module_src: &str, module_src: &str,
) -> Result<Option<String>, PreprocessorError> { ) -> Result<Option<String>, PreprocessorError> {
//let current_span = Span::current();
//let e = current_span.entered();
let e = debug_span!("parse_module", module = tracing::field::Empty).entered(); let e = debug_span!("parse_module", module = tracing::field::Empty).entered();
let module_src = remove_comments(module_src)?; let module_src = remove_comments(module_src);
let file = WgslParser::parse(Rule::file, &module_src)? let file = WgslParser::parse(Rule::file, &module_src)?
.next() .next()
@ -423,17 +440,17 @@ impl Processor {
output.write_fmt(format_args!("fn {}{} -> {} {{{}}}", output.write_fmt(format_args!("fn {}{} -> {} {{{}}}",
fn_name, args, fn_ret, body_output))?; fn_name, args, fn_ret, body_output))?;
}, },
Rule::shader_code | Rule::shader_const_def => { Rule::shader_code => {
self.output_shader_code_line(line, output)?; self.output_shader_code_line(line, output)?;
}, },
Rule::shader_code_fn_usage | Rule::shader_value | Rule::shader_ident | Rule::shader_code_char | Rule::cws => { Rule::shader_code_fn_usage | Rule::shader_value | Rule::shader_const_def | Rule::shader_ident | Rule::shader_code_char | Rule::cws => {
let input = line.as_str(); let input = line.as_str();
output.write_str(&input)?; output.write_str(&input)?;
}, },
Rule::newline => { Rule::newline => {
output.write_str("\n")?; output.write_str("\n")?;
}, },
_ => unimplemented!("ran into unhandled rule: {:?}", line.as_rule()), _ => unimplemented!("ran into unhandled rule: {:?}, {:?}", line.as_rule(), line.as_str()),
} }
} }
@ -442,7 +459,7 @@ impl Processor {
#[instrument(skip(self, module_src))] #[instrument(skip(self, module_src))]
pub fn process_file(&mut self, module_path: &str, module_src: &str) -> Result<String, PreprocessorError> { pub fn process_file(&mut self, module_path: &str, module_src: &str) -> Result<String, PreprocessorError> {
let module_src = remove_comments(module_src)?; let module_src = remove_comments(module_src);
let mut out_string = String::new(); let mut out_string = String::new();
// the output will be at least the length of module_src // the output will be at least the length of module_src
@ -543,49 +560,206 @@ fn compile_definitions(modules: &HashMap<String, Module>, module: &Module, outpu
} }
#[instrument(skip(text))] #[instrument(skip(text))]
fn remove_comments(text: &str) -> Result<String, std::fmt::Error> { fn remove_comments(text: &str) -> String {
let mut output = String::new(); let mut output = String::new();
output.reserve(text.len());
let mut comment_layers = 0; let comments_regex = Regex::new(r"(//|/\*|\*/)").unwrap();
for line in text.lines() { let mut comments = comments_regex.captures_iter(&text)
.map(|c| c.get(0).unwrap())
.peekable();
if let Some(line_comment_start) = line.find("//") { if comments.peek().is_none() {
let uncommented = &line[..line_comment_start]; return text.to_string();
output.write_str(uncommented)?;
continue;
} }
let (block_start, block_end) = (line.find("/*"), line.find("*/")); let mut block_start = 0;
if let (Some(start), Some(end)) = (block_start, block_end) { let mut scope_depth: u32 = 0;
if comment_layers == 0 {
let before = &line[..start];
output.write_str(before)?;
let after = &line[end + 2..]; loop {
output.write_str(after)?; let mut next = comments.next();
} let mut end = next.map(|m| m.start()).unwrap_or(text.len());
} else if let Some(block_comment_start) = block_start {
if comment_layers == 0 { while next.is_some() && block_start > end {
let uncommented = &line[..block_comment_start]; next = comments.next();
output.write_str(uncommented)?; end = next.map(|m| m.start()).unwrap_or(text.len());
} }
comment_layers += 1; if scope_depth == 0 {
} else if let Some(block_uncomment_start) = block_end { output.push_str(&text[block_start..end]);
if comment_layers == 1 {
let uncommented = &line[block_uncomment_start + 2..];
output.write_str(uncommented)?;
} }
comment_layers -= 1; match next {
} else if comment_layers == 0 { None => return output,
output.write_str(line)?; Some(com) => {
output.write_str("\n")?; match com.as_str() {
"//" => {
if scope_depth == 0 {
let line = &text[com.end()..];
let line_end = line.find("\n")
.unwrap_or(line.len());
// find the end of the line and continue
block_start = com.end() + line_end;
}
}
"/*" => {
scope_depth += 1;
block_start = com.end();
},
"*/" => {
scope_depth = scope_depth.saturating_sub(1);
block_start = com.end();
},
_ => unreachable!()
}
}
}
}
} }
#[cfg(test)]
mod tests {
use crate::preprocessor::remove_comments;
#[test]
fn test_remove_line_comments() {
const INPUT: &'static str =
r#"
// PLEASE REMOVE ME
DONT TOUCH ME
"#;
let out = remove_comments(&INPUT);
assert!(!out.contains("PLEASE REMOVE ME"));
assert!(out.contains("DONT TOUCH ME"));
} }
Ok(output) #[test]
fn test_remove_block_comments() {
const INPUT: &'static str =
r#"
/*
PLEASE REMOVE ME
*/
DONT TOUCH ME
"#;
let out = remove_comments(&INPUT);
assert!(!out.contains("PLEASE REMOVE ME"));
assert!(out.contains("DONT TOUCH ME"));
}
#[test]
fn test_remove_block_comments_single_line() {
const INPUT: &'static str =
r#"
/* PLEASE REMOVE ME */
/* PLEASE REMOVE ME
*/
DONT TOUCH ME
"#;
let out = remove_comments(&INPUT);
assert!(!out.contains("PLEASE REMOVE ME"));
assert!(out.contains("DONT TOUCH ME"));
}
#[test]
fn test_remove_block_comments_nested() {
const INPUT: &'static str =
r#"
/*
PLEASE REMOVE ME
/*
PLEASE REMOVE ME
*/
*/
DONT TOUCH ME
"#;
let out = remove_comments(&INPUT);
assert!(!out.contains("PLEASE REMOVE ME"));
assert!(out.contains("DONT TOUCH ME"));
}
#[test]
fn test_remove_block_comments_nested_same_line() {
const INPUT: &'static str =
r#"
/* PLEASE REMOVE ME
/* PLEASE REMOVE ME */
/* PLEASE REMOVE ME
*/
*/
DONT TOUCH ME
"#;
let out = remove_comments(&INPUT);
assert!(!out.contains("PLEASE REMOVE ME"), "{}", out);
assert!(out.contains("DONT TOUCH ME"));
}
#[test]
fn test_remove_block_comments_nested_same_line_cursed() {
const INPUT: &'static str =
r#"
/* PLEASE REMOVE ME
/* PLEASE REMOVE ME */
/* PLEASE REMOVE ME
*/
*/
/* PLEASE REMOVE ME /* PLEASE REMOVE ME */ /* PLEASE REMOVE ME */ */
//PLEASE REMOVE ME
DONT TOUCH ME
"#;
let out = remove_comments(&INPUT);
assert!(!out.contains("PLEASE REMOVE ME"), "{}", out);
assert!(out.contains("DONT TOUCH ME"), "{}", out);
}
#[test]
fn test_remove_several_comments() {
const INPUT: &'static str =
r#"// PLEASE REMOVE ME
// PLEASE REMOVE ME
const simple_scalar: f32 = 50.0;
// PLEASE REMOVE ME
// PLEASE REMOVE ME
const scalar: f32 = 5.0;
// PLEASE REMOVE ME
// PLEASE REMOVE ME
fn mult_some_nums(a: f32, b: f32) -> f32 {
let c = a * b;
return c;
}
// PLEASE REMOVE ME
fn do_something_cool(in: f32) -> f32 {
return scalar * mult_some_nums(in, 2.0);
}
// PLEASE REMOVE ME
/* PLEASE REMOVE ME /* PLEASE REMOVE ME */ /* PLEASE REMOVE ME */ */
"#;
let out = remove_comments(&INPUT);
assert!(!out.contains("PLEASE REMOVE ME"), "{}", out);
assert!(out.contains("const simple_scalar: f32 = 50.0;"), "{}", out);
assert!(out.contains("const scalar: f32 = 5.0;"), "{}", out);
assert!(out.contains("fn mult_some_nums(a: f32, b: f32) -> f32 {"), "{}", out);
assert!(out.contains("fn do_something_cool(in: f32) -> f32 {"), "{}", out);
}
} }