diff --git a/src/preprocessor.rs b/src/preprocessor.rs index 65ecbb1..3863edc 100644 --- a/src/preprocessor.rs +++ b/src/preprocessor.rs @@ -393,6 +393,35 @@ impl Processor { }, ); }, + Rule::shader_workgroup_var => { + let mut pairs = line.clone().into_inner(); + let name = pairs.next().unwrap().as_str(); + + let var_type = pairs.next().unwrap(); + let mut var_type_inner = var_type.into_inner(); + let inner_type = var_type_inner.next().unwrap(); + + let requirement_vec = Self::get_type_requirement(&mut module, inner_type) + .map(|r| vec![r]) + .unwrap_or_default(); + + let req_names = requirement_vec.first().map(|r| r.name.clone()); + debug!("Found binding def: `{}` with requirement: {:?}", name, req_names); + + let line_span = line.as_span(); + let start_pos = line_span.start(); + let end_pos = line_span.end(); + + module.vars.insert( + name.into(), + Definition { + name: name.into(), + start_pos, + end_pos, + requirements: requirement_vec, + }, + ); + }, Rule::shader_binding_def => { let mut pairs = line.clone().into_inner(); let _group_binding = pairs.next(); diff --git a/src/wgsl.pest b/src/wgsl.pest index a41830f..a16e11f 100644 --- a/src/wgsl.pest +++ b/src/wgsl.pest @@ -18,8 +18,9 @@ preproc_prefix = _{ "#" } command_line = { preproc_prefix ~ (define_module_command | import_command) ~ newline } // all characters used by wgsl -shader_code_char = { "@" | "-" | "+" | "*" | "/" | "=" | "(" | ")" | ">" | "<" | ";" | ":" | "." | "_" | "," | "&" | "|" | "[" | "]" | ASCII_ALPHA } -shader_code_block = { "{" ~ newline* ~ (cws* ~ (shader_actual_code_line ~ cws*)* ~ newline)+ ~ cws* ~ "}" } +shader_code_char = { "@" | "-" | "+" | "*" | "/" | "=" | "(" | ")" | ">" | "<" | ";" | ":" | "." | "_" | "," | "&" | "|" | "[" | "]" | "!" | ASCII_ALPHA } +//shader_code_block = { "{" ~ newline* ~ (cws* ~ (shader_actual_code_line ~ cws*)* ~ newline)+ ~ cws* ~ "}" } +shader_code_block = { "{" ~ newline* ~ (cws* ~ (shader_actual_code_line | newline)+ ~ cws* ~ newline*)* ~ cws* ~ "}" } // an fn argument can be another function use shader_code_fn_arg = _{ shader_code_fn_usage | shader_value | ident } shader_code_fn_usage = { ident ~ "(" ~ shader_code_fn_arg ~ ("," ~ ws* ~ shader_code_fn_arg)* ~ ")" } @@ -46,16 +47,19 @@ shader_struct_def = { "struct" ~ ws ~ ident ~ ws* ~ "{" ~ NEWLINE ~ (ws* ~ NEWLINE)* ~ (cws* ~ (NEWLINE | shader_struct_field) ~ ("," ~ NEWLINE)?)* ~ - "}" + "}" ~ ";"? } shader_group_binding = { "@group(" ~ NUMBER ~ ") @binding(" ~ NUMBER ~ ")" } shader_binding_var_constraint = { "<" ~ - ( "uniform" | "private" | "workgroup" | ("storage" ~ ("," ~ ws? ~ "read" ~ "_write"? )?) ) ~ + ( "uniform" | "private" | ("storage" ~ ("," ~ ws? ~ "read" ~ "_write"? )?) ) ~ ">" } -shader_binding_def = { shader_group_binding ~ (NEWLINE | ws)? ~ "var" ~ shader_binding_var_constraint? ~ ws ~ ident ~ ws* ~ ":" ~ ws* ~ type ~ ";" } +shader_workgroup_var = { "var" ~ ws ~ ident ~ ws* ~ ":" ~ ws* ~ type ~ ";" } +shader_binding_def = { + (shader_group_binding ~ (NEWLINE | ws)? ~ "var" ~ shader_binding_var_constraint? ~ ws ~ ident ~ ws* ~ ":" ~ ws* ~ type ~ ";") +} shader_var_name_type = { ident ~ ":" ~ ws* ~ type } shader_fn_args = { "(" ~ @@ -64,15 +68,19 @@ shader_fn_args = { "(" ~ //shader_fn_attribute? ~ shader_var_name_type? ~ newline_or_ws? ~ //(newline_or_ws? ~ "," ~ newline_or_ws? ~ shader_var_name_type)* ( shader_fn_attribute? ~ shader_var_name_type ~ newline_or_ws? ~ ","? ~ newline_or_ws? )* - ~ newline_or_ws? ~ ")" + ~ ")" } // the body of a function, including the opening and closing brackets shader_fn_body = { "{" ~ newline* ~ (cws* ~ (shader_actual_code_line ~ cws*)* ~ newline)+ ~ "}" } -shader_fn_attribute = { "@" ~ ASCII_ALPHA+ ~ ("(" ~ ident ~ ")")? ~ (NEWLINE? ~ ws*) } +shader_fn_workgroup_size_attribute = { "@workgroup_size" ~ "(" ~ ( ASCII_DIGIT+ ~ ","? ~ cws*){3} ~ ")" ~ (NEWLINE? ~ ws*) } +shader_fn_attribute = { + shader_fn_workgroup_size_attribute | + (cws* ~ "@" ~ ASCII_ALPHA+ ~ ("(" ~ ident ~ ")")? ~ (NEWLINE? ~ ws*)) +} shader_fn_def = { - shader_fn_attribute? ~ - "fn" ~ ws ~ ident ~ shader_fn_args ~ ws* ~ "->" ~ ws* ~ shader_fn_attribute? ~ type ~ ws* ~ shader_fn_body + shader_fn_attribute{0,2} ~ + "fn" ~ ws ~ ident ~ shader_fn_args ~ ws* ~ ("->" ~ ws* ~ shader_fn_attribute? ~ type ~ ws*)? ~ shader_fn_body } // a line of shader code, including white space @@ -80,6 +88,7 @@ shader_code_line = { shader_fn_def | (shader_struct_def ~ newline?) | (shader_const_def ~ newline?) | + (shader_workgroup_var ~ newline?) | (shader_binding_def ~ newline?) | (shader_type_alias_def ~ newline?) | ws* ~ newline