diff --git a/src/headers_gen.rs b/src/headers_gen.rs
index c8a2335..eb98992 100644
--- a/src/headers_gen.rs
+++ b/src/headers_gen.rs
@@ -28,16 +28,16 @@ pub(super) fn merge_defines<'a>(dst: &mut Vec<&'a [lexer::Token<'a>]>, src: &[&'
 
 /// Returns an error if there are any duplicate definitions
 /// Otherwise, adds all definitions in `src` to `dst`
-pub(super) fn merge_structs<'a>(dst: &mut Vec<&'a [lexer::Token<'a>]>, src: &[&'a [lexer::Token<'a>]]) -> Result<()> {
+pub(super) fn merge_udts<'a>(dst: &mut Vec<&'a [lexer::Token<'a>]>, src: &[&'a [lexer::Token<'a>]]) -> Result<()> {
     let mut dst_set = HashSet::new();
 
     for &tokens in dst.iter() {
-        let s = lexer::get_struct_name(tokens);
+        let s = lexer::get_udt_name(tokens);
         dst_set.insert(s);
     }
 
     for &tokens in src.iter() {
-        let s = lexer::get_struct_name(tokens);
+        let s = lexer::get_udt_name(tokens);
         if dst_set.contains(&s) {
-            return Err(anyhow!("Duplicate struct definitions for {}", s));
+            return Err(anyhow!("Duplicate UDT definitions for {}", s));
         }
@@ -46,4 +46,39 @@ pub(super) fn merge_structs<'a>(dst: &mut Vec<&'a [lexer::Token<'a>]>, src: &[&'
 
     dst.extend_from_slice(src);
     Ok(())
 }
+
+/// Expects raw source code and an include path (in the form `"../include/filename.h"`).
+/// Returns `code` unchanged if the include statement already exists; otherwise inserts
+/// one after the last existing include statement, or at the top if there are none.
+pub(super) fn insert_self_include(code: String, include: &str) -> String {
+    let mut code_lines: Vec<&str> = code.lines().collect();
+
+    let contains_include = code_lines.iter().any(|&line| {
+        line.trim().starts_with("#") &&
+        line.contains("include") &&
+        line.contains(include)
+    });
+
+    if contains_include {
+        return code;
+    }
+
+    // Find the last existing include; insert after it, or at line 0 if none exist.
+    let mut insert_at: usize = 0;
+
+    for (i, &line) in code_lines.iter().enumerate() {
+        let is_include_statement = line.trim().starts_with("#") &&
+            line.contains("include") &&
+            (line.contains("<") || line.contains("\""));
+
+        if is_include_statement {
+            insert_at = i + 1;
+        }
+    }
+
+    let include_line = format!("#include {}", include);
+
+    code_lines.insert(insert_at, &include_line);
+
+    code_lines.join("\n")
+}
\ No newline at end of file
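
Note: a minimal usage sketch for `insert_self_include` (hypothetical snippet, not part of the patch). The new include lands directly after the last existing include, or at the top of the file when there are none:

    let src = "#include <stdio.h>\n\nint foo(void);\n".to_string();
    let out = insert_self_include(src, "\"../include/foo.h\"");
    // Inserted right after `#include <stdio.h>`. Note that the `lines()`/`join("\n")`
    // round trip drops the trailing newline of the input.
    assert_eq!(out, "#include <stdio.h>\n#include \"../include/foo.h\"\n\nint foo(void);");
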
diff --git a/src/lexer.rs b/src/lexer.rs
index 3b4fd45..c581d33 100644
--- a/src/lexer.rs
+++ b/src/lexer.rs
@@ -1,5 +1,7 @@
 use anyhow::{anyhow, Result};
 
+const UDT_KEYWORDS: &[&str] = &["struct", "enum", "union"];
+
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub enum Token<'a> {
     Object(&'a str),
@@ -36,10 +38,22 @@ pub enum Token<'a> {
     NewLine,
 }
 
+/// Returns true if `token` is one of the keywords that can start a user-defined type.
+#[inline]
+fn is_udt_keyword(token: &Token) -> bool {
+    match token {
+        Token::Object(obj) => UDT_KEYWORDS.contains(obj),
+        _ => false,
+    }
+}
+
 impl<'a> Token<'a> {
     pub fn tokens_to_string(tokens: &[Token]) -> String {
-        if tokens.len() >= 2 && (tokens[0] == Self::Object("struct") || tokens[1] == Self::Object("struct")) {
-            return Self::struct_tokens_to_string(tokens);
+        if tokens.len() >= 2 && (is_udt_keyword(&tokens[0]) || is_udt_keyword(&tokens[1])) {
+            return Self::udt_tokens_to_string(tokens);
         }
 
         let space_after = [Token::Comma, Token::Asterisk];
@@ -73,7 +87,7 @@ impl<'a> Token<'a> {
         string
     }
 
-    fn struct_tokens_to_string(tokens: &[Token]) -> String {
+    fn udt_tokens_to_string(tokens: &[Token]) -> String {
         let mut output = String::new();
         let mut indent_level = 0;
         let mut start_of_line = true;
@@ -161,7 +175,6 @@ impl<'a> Token<'a> {
 
 pub fn clean_source_code(code: String) -> String {
     // TODO: skip over any `//` or `/*` that are in string literals
-
     let mut cleaned = String::with_capacity(code.len());
     let mut in_block_comment = false; // whether we're inside /* ... */
@@ -244,6 +257,51 @@ pub fn tokenize(code: &str) -> Result<Vec<Token>> {
     Ok(tokens)
 }
 
+/// Tokenizes raw (uncleaned) source, skipping comments and whitespace inline,
+/// and returns the tokens along with each token's starting byte offset.
+pub fn tokenize_unclean(code: &str) -> Result<(Vec<Token>, Vec<usize>)> {
+    let code_bytes = code.as_bytes();
+    let mut tokens = Vec::with_capacity(4096);
+    let mut byte_idx = Vec::with_capacity(4096);
+
+    let mut idx: usize = 0;
+    while idx < code.len() {
+        // Skip line and block comments without copying the source.
+        if code[idx..].starts_with("//") {
+            idx += code[idx..].find('\n').unwrap_or(code.len()) + 1;
+            continue;
+        } else if code[idx..].starts_with("/*") {
+            idx += code[idx..].find("*/").unwrap_or(code.len()) + 2;
+            continue;
+        }
+
+        if code_bytes[idx] == b' ' || code_bytes[idx] == b'\t' {
+            idx += 1;
+            continue;
+        }
+        if let Some(sym) = is_symbol(&code[idx..]) {
+            tokens.push(sym);
+            byte_idx.push(idx);
+            idx += 1;
+            continue;
+        }
+        if code_bytes[idx] == b'"' {
+            let len = find_len_stringliteral(&code_bytes[idx..])?;
+            let val = &code[idx..(idx + len)];
+            tokens.push(Token::Literal(val));
+            byte_idx.push(idx);
+            idx += len;
+            continue;
+        }
+        let new_idx = find_len_object(code_bytes, idx);
+        let val = &code[idx..new_idx];
+        tokens.push(Token::Object(val));
+        byte_idx.push(idx);
+        idx = new_idx;
+    }
+
+    Ok((tokens, byte_idx))
+}
+
 #[inline]
 fn is_symbol(code: &str) -> Option<Token> {
     let char = code.chars().next();
@@ -288,6 +346,69 @@ fn find_len_stringliteral(code_bytes: &[u8]) -> Result<usize> {
     Err(anyhow!("String literal not closed"))
 }
 
+/// Gets the ranges (as `[start, end]` byte offsets) of the source to keep,
+/// excluding the token slices in `exclude_tokens`.
+///
+/// Both `tokens` and `byte_idx` are assumed to be parallel; i.e. the ith element of `byte_idx`
+/// gives the starting offset of the ith token in `tokens`. In an ideal setup, `byte_idx` has
+/// one extra element (the file length) to mark the end of the last token.
+pub(super) fn get_inclusion_ranges(
+    tokens: &Vec<Token>,
+    byte_idx: &Vec<usize>,
+    exclude_tokens: &[&[Token]]
+) -> Vec<[usize; 2]> {
+    let mut inclusion_ranges = Vec::new();
+    // current_inclusion_start marks the index (in tokens) where the current "keep" region began.
+    let mut current_inclusion_start: usize = 0;
+    let mut i = 0;
+
+    while i < tokens.len() {
+        let mut matched_exclusion = None;
+        // See if any of the exclusion slices match starting at token index i.
+        for &excl in exclude_tokens {
+            if excl.is_empty() {
+                continue;
+            }
+            // If there are enough tokens left and the slice matches...
+            if i + excl.len() <= tokens.len() && &tokens[i..(i + excl.len())] == excl {
+                matched_exclusion = Some(excl.len());
+                break;
+            }
+        }
+
+        if let Some(skip_len) = matched_exclusion {
+            // End the current inclusion region (if nonempty) at the beginning of the exclusion.
+            if current_inclusion_start < i {
+                inclusion_ranges.push([
+                    byte_idx[current_inclusion_start],
+                    byte_idx[i]
+                ]);
+            }
+            // Skip over the excluded tokens.
+            i += skip_len;
+            current_inclusion_start = i;
+        } else {
+            // No exclusion match here; move on.
+            i += 1;
+        }
+    }
+
+    // If there is any trailing inclusion region after the last exclusion, add it.
+    if current_inclusion_start < tokens.len() {
+        // For the end offset, use the next byte offset if available.
+        // (Ideally, byte_idx has length tokens.len() + 1.)
+        let end = if tokens.len() < byte_idx.len() {
+            byte_idx[tokens.len()]
+        } else {
+            // Fallback: the last token's start offset (truncates the final token).
+            *byte_idx.last().unwrap()
+        };
+        inclusion_ranges.push([byte_idx[current_inclusion_start], end]);
+    }
+
+    inclusion_ranges
+}
+
 // Maps characters' ASCII codes to their token
 const TOKEN_MAPPING: [Option<Token<'static>>; 128] = [
     None,
@@ -501,16 +622,19 @@ pub fn get_includes<'a>(tokens: &'a Vec<Token<'a>>) -> Vec<&'a [Token<'a>]> {
     includes
 }
 
-pub fn get_structs<'a>(tokens: &'a Vec<Token<'a>>) -> Vec<&'a [Token<'a>]> {
-    let mut structs = vec![];
+/// Extracts the user-defined types (UDTs): structs, enums, and unions
+pub fn get_udts<'a>(tokens: &'a Vec<Token<'a>>) -> Vec<&'a [Token<'a>]> {
+    let mut udts = vec![];
     if tokens.len() < 3 {
-        return structs;
+        return udts;
     }
 
     let mut idx: usize = 0;
     while idx < tokens.len() - 2 {
         if let Token::Object(obj) = tokens[idx] {
-            if !["typedef", "struct"].contains(&obj) {
+            if obj != "typedef" && !UDT_KEYWORDS.contains(&obj) {
                 idx += 1;
                 continue;
             } else if "typedef" == obj {
@@ -519,12 +643,12 @@ pub fn get_structs<'a>(tokens: &'a Vec<Token<'a>>) -> Vec<&'a [Token<'a>]> {
                 } else {
                     "-"
                 };
-                if obj_2 != "struct" {
+                if !UDT_KEYWORDS.contains(&obj_2) {
                     idx += 1;
                     continue;
                 }
             }
-            let length = match struct_len(&tokens[idx..]) {
+            let length = match udt_len(&tokens[idx..]) {
                 Some(l) => l,
                 None => {
                     idx += 1;
@@ -545,14 +669,14 @@ pub fn get_structs<'a>(tokens: &'a Vec<Token<'a>>) -> Vec<&'a [Token<'a>]> {
                 continue;
             }
 
-            structs.push(&tokens[idx..end]);
+            udts.push(&tokens[idx..end]);
 
             idx = end - 1;
         } else {
            idx += 1;
        }
    }
 
-    structs
+    udts
 }
 
 pub fn get_defines<'a>(tokens: &'a Vec<Token<'a>>) -> Vec<&'a [Token<'a>]> {
@@ -586,9 +710,9 @@ pub fn get_defines<'a>(tokens: &'a Vec<Token<'a>>) -> Vec<&'a [Token<'a>]> {
 
-/// Gets the name of the struct
+/// Gets the name of the user-defined type
 /// Ex) for `struct Point {...}`, this would return "Point"
-pub(super) fn get_struct_name<'a>(tokens: &'a [Token]) -> &'a str {
+pub(super) fn get_udt_name<'a>(tokens: &'a [Token]) -> &'a str {
     if tokens.len() < 3 {
-        unreachable!("Token string is not a valid struct definition");
+        unreachable!("Token string is not a valid user-defined type definition");
     }
 
     match &tokens[0] {
@@ -603,7 +727,7 @@ pub(super) fn get_struct_name<'a>(tokens: &'a [Token]) -> &'a str {
             let semicolon_index = tokens
                 .iter()
                 .rposition(|t| *t == Token::Semicolon)
-                .expect("Missing semicolon in typedef struct definition");
+                .expect("Missing semicolon in user-defined type definition");
 
             // Iterate backwards from the token just before the semicolon to find the typedef alias.
             for token in tokens[..semicolon_index].iter().rev() {
@@ -611,18 +735,20 @@ pub(super) fn get_struct_name<'a>(tokens: &'a [Token]) -> &'a str {
                     return name;
                 }
             }
-            unreachable!("No valid struct name found in typedef struct definition");
+            unreachable!("No valid name found in typedef user-defined type definition");
         }
 
-        // Handle regular struct definitions
-        Token::Object("struct") => {
-            // Expect the struct name to immediately follow the "struct" keyword.
+        // Handle regular (non-typedef) definitions
+        Token::Object("struct")
+        | Token::Object("enum")
+        | Token::Object("union") => {
+            // Expect the name to immediately follow the keyword.
             if let Token::Object(name) = tokens[1] {
                 return name;
             } else {
-                unreachable!("Expected struct name after 'struct' keyword");
+                unreachable!("Expected a name after the 'struct'/'enum'/'union' keyword");
            }
        }
-        _ => unreachable!("Token string is not a valid struct definition"),
+        _ => unreachable!("Token string is not a valid user-defined type definition"),
     }
 }
@@ -641,7 +767,7 @@ pub(super) fn get_define_name<'a>(tokens: &'a [Token]) -> &'a str {
     unreachable!("Token string is not a valid define macro");
 }
 
-fn struct_len(tokens: &[Token]) -> Option<usize> {
+fn udt_len(tokens: &[Token]) -> Option<usize> {
     let mut num_brackets = 0;
     let mut contains_brackets = false;
@@ -710,8 +836,7 @@ mod lexer_tests {
     #[test]
     fn test_get_defines() {
         let s = fs::read_to_string("tests/lexer-define.c").unwrap();
-        let s = clean_source_code(s);
-        let tokens = tokenize(&s).unwrap();
+        let (tokens, _) = tokenize_unclean(&s).unwrap();
 
         let defines = get_defines(&tokens);
 
@@ -727,12 +852,11 @@ mod lexer_tests {
     }
 
     #[test]
-    fn test_get_structs() {
-        let s = fs::read_to_string("tests/lexer-struct.c").unwrap();
-        let s = clean_source_code(s);
-        let tokens = tokenize(&s).unwrap();
+    fn test_get_udts() {
+        let s = fs::read_to_string("tests/lexer-UDT.c").unwrap();
+        let (tokens, _) = tokenize_unclean(&s).unwrap();
 
-        let defines = get_structs(&tokens);
+        let udts = get_udts(&tokens);
 
         let mut log_dump = "".to_string();
-        for &def in &defines {
+        for &def in &udts {
@@ -740,14 +864,13 @@ mod lexer_tests {
             log_dump.push_str(&x);
         }
 
-        fs::write("tests/lexer.test_get_structs.log", format!("{}", log_dump)).unwrap();
+        fs::write("tests/lexer.test_get_udts.log", &log_dump).unwrap();
     }
 
     #[test]
     fn test_get_define_name() {
         let s = fs::read_to_string("tests/lexer-define.c").unwrap();
-        let s = clean_source_code(s);
-        let tokens = tokenize(&s).unwrap();
+        let (tokens, _) = tokenize_unclean(&s).unwrap();
 
         let defines = get_defines(&tokens);
 
@@ -764,33 +887,31 @@ mod lexer_tests {
     }
 
     #[test]
-    fn test_get_struct_name() {
-        let s = fs::read_to_string("tests/lexer-struct.c").unwrap();
-        let s = clean_source_code(s);
-        let tokens = tokenize(&s).unwrap();
+    fn test_get_udt_name() {
+        let s = fs::read_to_string("tests/lexer-UDT.c").unwrap();
+        let (tokens, _) = tokenize_unclean(&s).unwrap();
 
-        let structs = get_structs(&tokens);
+        let udts = get_udts(&tokens);
 
         let mut names = vec![];
-        for &d in &structs {
-            names.push(get_struct_name(d));
+        for &d in &udts {
+            names.push(get_udt_name(d));
         }
 
-        fs::write("tests/lexer.test_get_struct_name.log", format!("{:#?}", names))
+        fs::write("tests/lexer.test_get_udt_name.log", format!("{:#?}", names))
             .unwrap();
     }
 
     #[test]
-    fn test_struct_tokens_to_string() {
-        let s = fs::read_to_string("tests/lexer-struct.c").unwrap();
-        let s = clean_source_code(s);
-        let tokens = tokenize(&s).unwrap();
+    fn test_udt_tokens_to_string() {
+        let s = fs::read_to_string("tests/lexer-UDT.c").unwrap();
+        let (tokens, _) = tokenize_unclean(&s).unwrap();
 
-        let structs = get_structs(&tokens);
+        let udts = get_udts(&tokens);
 
         let mut log_dump = "".to_string();
-        for &d in &structs {
-            let s = Token::struct_tokens_to_string(d);
+        for &d in &udts {
+            let s = Token::udt_tokens_to_string(d);
             let s_exact = format!("{:?}", &s);
             log_dump.push_str(&s);
             log_dump.push_str("\n");
@@ -800,7 +921,7 @@ mod lexer_tests {
         }
 
-        fs::write("tests/lexer.test_struct_tokens_to_string.log", &log_dump)
+        fs::write("tests/lexer.test_udt_tokens_to_string.log", &log_dump)
             .unwrap();
     }
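
Note: a quick worked example of the `tokenize_unclean` / `get_inclusion_ranges` contract (hypothetical snippet, not part of the patch). The caller appends a sentinel end offset so the full extent of the last token is known; see the main.rs hunk below:

    let code = "int x;";
    let (tokens, mut byte_idx) = tokenize_unclean(code).unwrap(); // offsets [0, 4, 5]
    byte_idx.push(code.len());                                    // sentinel: [0, 4, 5, 6]

    // Excluding the identifier token leaves two ranges that splice back to "int ;".
    let exclude: &[&[Token]] = &[&[Token::Object("x")]];
    let ranges = get_inclusion_ranges(&tokens, &byte_idx, exclude);
    assert_eq!(ranges, vec![[0, 4], [5, 6]]);
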
diff --git a/src/main.rs b/src/main.rs
index c3ef2af..a3a33e4 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -341,46 +341,49 @@ fn handle_gen_headers(config: &Config) -> Result<()> {
     let src_dir = cwd.join(src_dir);
     let inc_dir = cwd.join(inc_dir);
 
-    for file in fs::read_dir(src_dir).unwrap() {
+    for file in fs::read_dir(&src_dir).unwrap() {
         if let Ok(file) = file {
             let raw_name = file.file_name();
-            let raw_name = raw_name.to_str().unwrap().rsplit_once(".").unwrap().0;
-            if raw_name == "main" {
+            // Skip main.c and anything that isn't a .c file.
+            let Some((raw_name, file_ext)) = raw_name.to_str().unwrap().rsplit_once('.') else {
+                continue;
+            };
+            if raw_name == "main" || file_ext != "c" {
                 continue;
             }
             let header_name = format!("{}.h", raw_name);
 
-            let mut code = fs::read_to_string(file.path())?;
-            code = lexer::clean_source_code(code);
-            let tokens = lexer::tokenize(&code)?;
+            let code = fs::read_to_string(file.path())?;
+            let (tokens, mut byte_idx) = lexer::tokenize_unclean(&code)?;
+            // Sentinel end offset so get_inclusion_ranges knows where the last token ends.
+            byte_idx.push(code.len());
 
-            let mut code_h =
+            let code_h =
                 fs::read_to_string(inc_dir.join(&header_name)).unwrap_or("".to_string());
-            code_h = lexer::clean_source_code(code_h);
-            let tokens_h = lexer::tokenize(&code_h)?;
+            let (tokens_h, _) = lexer::tokenize_unclean(&code_h)?;
 
             let mut defines_h = lexer::get_defines(&tokens_h);
-            let mut sturcts_h = lexer::get_structs(&tokens_h);
+            let mut udts_h = lexer::get_udts(&tokens_h);
 
             let fn_defs = lexer::get_fn_def(&tokens);
             let includes = lexer::get_includes(&tokens);
             let defines = lexer::get_defines(&tokens);
-            let structs = lexer::get_structs(&tokens);
+            let udts = lexer::get_udts(&tokens);
+
+            // Drop the existing header's `#define NAME_H` include guard so it
+            // isn't merged as a real define.
+            if !defines_h.is_empty() {
+                defines_h.remove(0);
+            }
 
-            let res = headers_gen::merge_defines(&mut defines_h, &defines[1..]); // Skip the first definition to skip the #ifndef NAME_H #define NAME_H
+            let res = headers_gen::merge_defines(&mut defines_h, &defines);
             if let Err(e) = res {
                 eprintln!("Error: {}", e);
                 process::exit(1);
             }
 
-            let res = headers_gen::merge_structs(&mut sturcts_h, &structs);
+            let res = headers_gen::merge_udts(&mut udts_h, &udts);
             if let Err(e) = res {
                 eprintln!("Error: {}", e);
                 process::exit(1);
             }
 
-            sturcts_h.extend_from_slice(&structs);
-
             let mut headers = String::new();
 
             headers.push_str(&format!("#ifndef {}_H\n", raw_name.to_uppercase()));
@@ -400,10 +403,12 @@ fn handle_gen_headers(config: &Config) -> Result<()> {
             }
             headers.push('\n');
 
-            for &struc in &sturcts_h {
+            for &struc in &udts_h {
                 headers.push_str(&lexer::Token::tokens_to_string(struc).trim());
                 headers.push_str("\n\n");
             }
+            headers.push('\n');
+
             for &func in &fn_defs {
                 let s = lexer::Token::tokens_to_string(func);
                 headers.push_str(&s);
@@ -412,7 +417,30 @@ fn handle_gen_headers(config: &Config) -> Result<()> {
             headers.push('\n');
             headers.push_str(&format!("#endif // {}_H", raw_name.to_uppercase()));
 
-            fs::write(inc_dir.join(header_name), headers)?;
+            fs::write(inc_dir.join(&header_name), headers)?;
+
+            // Rewrite the original C file without the definitions that now live in the header.
+            let mut new_code = "".to_string();
+
+            let mut exclude_tokens = udts;
+            exclude_tokens.extend_from_slice(&defines);
+
+            let inclusion_ranges = lexer::get_inclusion_ranges(&tokens, &byte_idx, &exclude_tokens);
+
+            for range in &inclusion_ranges {
+                new_code.push_str(&code[range[0]..range[1]]);
+            }
+
+            let header_inc_path = format!("\"../include/{}\"", header_name);
+
+            new_code = headers_gen::insert_self_include(new_code, &header_inc_path);
+
+            let new_file = format!("{}.c", raw_name);
+            let new_filepath = src_dir.join(&new_file);
+
+            fs::write(new_filepath, new_code)?;
         }
     }
 
     Ok(())
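
Note: a minimal sketch of the dedupe contract exercised above (hypothetical snippet, not part of the patch). `merge_udts` fails on a name collision instead of silently keeping one copy, which is why `handle_gen_headers` aborts on error:

    let src = "struct Point { int x; int y; };";
    let (tokens, _) = lexer::tokenize_unclean(src).unwrap();
    let udts = lexer::get_udts(&tokens);

    let mut merged = udts.clone();
    // Merging the same definitions again collides on the name "Point".
    assert!(headers_gen::merge_udts(&mut merged, &udts).is_err());
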
diff --git a/tests/lexer-struct.c b/tests/lexer-UDT.c
similarity index 53%
rename from tests/lexer-struct.c
rename to tests/lexer-UDT.c
index 7b24c6f..9ab5ecf 100644
--- a/tests/lexer-struct.c
+++ b/tests/lexer-UDT.c
@@ -70,3 +70,75 @@ typedef struct Car {
     int wheels;
     float engine_power;
 } Vehicle;
+
+// Test case 12: Simple enum definition with a tag
+enum Color {
+    RED,
+    GREEN,
+    BLUE
+};
+
+// Test case 13: Typedef enum without a tag
+typedef enum {
+    CIRCLE,
+    SQUARE,
+    TRIANGLE
+} Shape;
+
+// Test case 14: Enum with explicit values
+enum Direction {
+    NORTH = 0,
+    EAST = 90,
+    SOUTH = 180,
+    WEST = 270
+};
+
+// Test case 15: Typedef enum with a tag
+typedef enum Day {
+    MONDAY,
+    TUESDAY,
+    WEDNESDAY,
+    THURSDAY,
+    FRIDAY,
+    SATURDAY,
+    SUNDAY
+} Day;
+
+// Test case 16: Simple union definition with a tag
+union Data {
+    int i;
+    float f;
+    char *s;
+};
+
+// Test case 17: Typedef union without a tag
+typedef union {
+    int i;
+    float f;
+    char *s;
+} Data;
+
+// Test case 18: Union with a nested struct
+union Mixed {
+    struct {
+        int a;
+        int b;
+    } pair;
+    float f;
+};
+
+// Test case 19: Anonymous union within a struct (C11, or as a compiler extension)
+struct Container {
+    int id;
+    union {
+        int i;
+        float f;
+    }; // Note: this union is anonymous; its members become members of Container.
+};
+
+// Test case 20: Typedef enum with a trailing comma after the last enumerator (currently fails; the only structural difference from test case 15)
+typedef enum BlockType {
+    BlockTypeCreate,
+    BlockTypeUpdate,
+    BlockTypeDelete,
+} BlockType;
\ No newline at end of file
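
Note on test case 20: the trailing comma after the last enumerator is the only structural difference from test case 15, so a focused unit test (hypothetical, not part of the patch) should pin the failure down:

    #[test]
    fn test_trailing_comma_enum() {
        let src = "typedef enum BlockType {\n    BlockTypeCreate,\n} BlockType;";
        let (tokens, _) = tokenize_unclean(src).unwrap();

        let udts = get_udts(&tokens);
        assert_eq!(udts.len(), 1);
        assert_eq!(get_udt_name(udts[0]), "BlockType");
    }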