dt-common/src/config/config_token_parser.rs (158 changes: 126 additions & 32 deletions)
@@ -4,20 +4,84 @@ use crate::{error::Error, utils::sql_util::SqlUtil};
 
 use super::config_enums::DbType;
 
+#[derive(Debug, Clone)]
+pub enum TokenEscapePair {
+    Char((char, char)),
+    String((String, String)),
+}
+
+impl From<(char, char)> for TokenEscapePair {
+    fn from(value: (char, char)) -> Self {
+        Self::Char(value)
+    }
+}
+
+impl From<(String, String)> for TokenEscapePair {
+    fn from(value: (String, String)) -> Self {
+        Self::String(value)
+    }
+}
+
+impl TokenEscapePair {
+    pub fn from_char_pairs(char_pairs: Vec<(char, char)>) -> Vec<Self> {
+        char_pairs.into_iter().map(|e| Self::from(e)).collect()
+    }
+
+    pub fn match_escape_left(&self, chars: &[char], start_index: usize) -> bool {
+        self.match_escape_side(chars, start_index, true)
+    }
+
+    pub fn match_escape_right(&self, chars: &[char], start_index: usize) -> bool {
+        self.match_escape_side(chars, start_index, false)
+    }
+
+    fn match_escape_side(&self, chars: &[char], start_index: usize, is_left: bool) -> bool {
+        match self {
+            TokenEscapePair::Char((escape_left, escape_right)) => {
+                let escape = if is_left { escape_left } else { escape_right };
+                if start_index >= chars.len() {
+                    return false;
+                }
+                if chars[start_index] != *escape {
+                    return false;
+                }
+                true
+            }
+            TokenEscapePair::String((escape_left, escape_right)) => {
+                let escape = if is_left { escape_left } else { escape_right };
+                if start_index + escape.len() > chars.len() {
+                    return false;
+                }
+                for (i, char_to_match) in escape.chars().enumerate() {
+                    if chars[start_index + i] != char_to_match {
+                        return false;
+                    }
+                }
+                true
+            }
+        }
+    }
+}
+
 pub struct ConfigTokenParser {}
 
 impl ConfigTokenParser {
     pub fn parse_config(
         config_str: &str,
         db_type: &DbType,
         delimiters: &[char],
+        custom_escape_pairs: Option<&[TokenEscapePair]>,
     ) -> anyhow::Result<Vec<String>> {
         if config_str.is_empty() {
             return Ok(Vec::new());
         }
 
         let escape_pairs = SqlUtil::get_escape_pairs(db_type);
-        let tokens = Self::parse(config_str, delimiters, &escape_pairs);
+        let mut token_escape_pairs = TokenEscapePair::from_char_pairs(escape_pairs.clone());
+        if let Some(pairs) = custom_escape_pairs {
+            token_escape_pairs.extend_from_slice(pairs);
+        }
+        let tokens = Self::parse(config_str, delimiters, &token_escape_pairs);
         for token in tokens.iter() {
             if !SqlUtil::is_valid_token(token, db_type, &escape_pairs) {
                 bail! {Error::ConfigError(format!(
@@ -29,7 +93,11 @@ impl ConfigTokenParser {
         Ok(tokens)
     }
 
-    pub fn parse(config: &str, delimiters: &[char], escape_pairs: &[(char, char)]) -> Vec<String> {
+    pub fn parse(
+        config: &str,
+        delimiters: &[char],
+        escape_pairs: &[TokenEscapePair],
+    ) -> Vec<String> {
         let chars: Vec<char> = config.chars().collect();
         let mut start_index = 0;
         let mut tokens = Vec::new();
@@ -58,16 +126,12 @@
         chars: &[char],
         start_index: usize,
         delimiters: &[char],
-        escape_pairs: &[(char, char)],
+        escape_pairs: &[TokenEscapePair],
     ) -> (String, usize) {
         // read token surrounded by escapes: `db.2`
-        for (escape_left, escape_right) in escape_pairs.iter() {
-            if chars[start_index] == *escape_left {
-                return Self::read_token_with_escape(
-                    chars,
-                    start_index,
-                    (*escape_left, *escape_right),
-                );
+        for e in escape_pairs.iter() {
+            if e.match_escape_left(chars, start_index) {
+                return Self::read_token_with_escape(chars, start_index, e);
             }
         }
         Self::read_token_to_delimiter(chars, start_index, delimiters)
@@ -96,23 +160,39 @@
     fn read_token_with_escape(
         chars: &[char],
         start_index: usize,
-        escape_pair: (char, char),
+        escape_pair: &TokenEscapePair,
     ) -> (String, usize) {
-        let mut start = false;
         let mut token = String::new();
         let mut read_count = 0;
-        for c in chars.iter().skip(start_index) {
-            if start && *c == escape_pair.1 {
-                token.push(*c);
-                read_count += 1;
-                break;
-            }
-            if *c == escape_pair.0 {
-                start = true;
+        match escape_pair {
+            TokenEscapePair::Char((escape_left, escape_right)) => {
+                let mut start = false;
+                for c in chars.iter().skip(start_index) {
+                    if start && *c == *escape_right {
+                        token.push(*c);
+                        read_count += 1;
+                        break;
+                    }
+                    if *c == *escape_left {
+                        start = true;
+                    }
+                    if start {
+                        token.push(*c);
+                        read_count += 1;
+                    }
+                }
             }
-            if start {
-                token.push(*c);
-                read_count += 1;
+            TokenEscapePair::String((escape_left, _)) => {
+                let prefix_len = escape_left.len();
+                for c in chars.iter().skip(start_index) {
+                    token.push(*c);
+                    read_count += 1;
+                    if read_count > prefix_len
+                        && escape_pair.match_escape_right(&chars, start_index + read_count - 1)
+                    {
+                        break;
+                    }
+                }
             }
         }
 
@@ -131,12 +211,15 @@ mod tests {
 
     #[test]
     fn test_parse_mysql_filter_config_tokens() {
-        let config = r#"db_1.tb_1,`db.2`.`tb.2`,`db"3`.tb_3,db_4.`tb"4`,db_5.*,`db.6`.*,db_7*.*,`db.8*`.*,*.*,`*`.`*`"#;
+        let config = r#"db_1.tb_1,`db.2`.`tb.2`,`db"3`.tb_3,db_4.`tb"4`,db_5.*,`db.6`.*,db_7*.*,`db.8*`.*,*.*,`*`.`*`,r#.*#.r#.?#,`r#.*#`.`r#.?#`"#;
         let delimiters = vec!['.', ','];
-        let escape_pairs = vec![('`', '`')];
+        let escape_pairs = vec![
+            TokenEscapePair::Char(('`', '`')),
+            TokenEscapePair::String(("r#".to_string(), '#'.to_string())),
+        ];
 
         let tokens = ConfigTokenParser::parse(config, &delimiters, &escape_pairs);
-        assert_eq!(tokens.len(), 20);
+        assert_eq!(tokens.len(), 24);
         assert_eq!(tokens[0], "db_1");
         assert_eq!(tokens[1], "tb_1");
         assert_eq!(tokens[2], "`db.2`");
@@ -157,13 +240,17 @@
         assert_eq!(tokens[17], "*");
         assert_eq!(tokens[18], "`*`");
         assert_eq!(tokens[19], "`*`");
+        assert_eq!(tokens[20], "r#.*#");
+        assert_eq!(tokens[21], "r#.?#");
+        assert_eq!(tokens[22], "`r#.*#`");
+        assert_eq!(tokens[23], "`r#.?#`");
     }
 
     #[test]
     fn test_parse_mysql_router_config_tokens() {
         let config = r#"db_1.tb_1:`db.2`.`tb.2`,`db"3`.tb_3:db_4.`tb"4`"#;
         let delimiters = vec!['.', ',', ':'];
-        let escape_pairs = vec![('`', '`')];
+        let escape_pairs = vec![TokenEscapePair::Char(('`', '`'))];
 
         let tokens = ConfigTokenParser::parse(config, &delimiters, &escape_pairs);
         assert_eq!(tokens.len(), 8);
@@ -179,12 +266,15 @@
 
     #[test]
     fn test_parse_pg_filter_config_tokens() {
-        let config = r#"db_1.tb_1,"db.2"."tb.2","db`3".tb_3,db_4."tb`4",db_5.*,"db.6".*,db_7*.*,"db.8*".*,*.*,"*"."*""#;
+        let config = r#"db_1.tb_1,"db.2"."tb.2","db`3".tb_3,db_4."tb`4",db_5.*,"db.6".*,db_7*.*,"db.8*".*,*.*,"*"."*",r#.*#.r#.?#,"r#.*#"."r#.?#""#;
         let delimiters = vec!['.', ','];
-        let escape_pairs = vec![('"', '"')];
+        let escape_pairs = vec![
+            TokenEscapePair::Char(('"', '"')),
+            TokenEscapePair::String(("r#".to_string(), '#'.to_string())),
+        ];
 
         let tokens = ConfigTokenParser::parse(config, &delimiters, &escape_pairs);
-        assert_eq!(tokens.len(), 20);
+        assert_eq!(tokens.len(), 24);
         assert_eq!(tokens[0], "db_1");
         assert_eq!(tokens[1], "tb_1");
         assert_eq!(tokens[2], r#""db.2""#);
@@ -205,13 +295,17 @@
         assert_eq!(tokens[17], "*");
         assert_eq!(tokens[18], r#""*""#);
         assert_eq!(tokens[19], r#""*""#);
+        assert_eq!(tokens[20], "r#.*#");
+        assert_eq!(tokens[21], "r#.?#");
+        assert_eq!(tokens[22], r#""r#.*#""#);
+        assert_eq!(tokens[23], r#""r#.?#""#);
     }
 
     #[test]
     fn test_parse_pg_router_config_tokens() {
         let config = r#"db_1.tb_1:"db.2"."tb.2","db`3".tb_3:db_4."tb`4""#;
         let delimiters = vec!['.', ',', ':'];
-        let escape_pairs = vec![('"', '"')];
+        let escape_pairs = vec![TokenEscapePair::Char(('"', '"'))];
 
         let tokens = ConfigTokenParser::parse(config, &delimiters, &escape_pairs);
         assert_eq!(tokens.len(), 8);
@@ -229,7 +323,7 @@
     fn test_parse_emoj_config_tokens() {
         let config = r#"SET "set_key_3_ 😀" "val_2_ 😀""#;
         let delimiters = vec![' '];
-        let escape_pairs = vec![('"', '"')];
+        let escape_pairs = vec![TokenEscapePair::Char(('"', '"'))];
         let tokens = ConfigTokenParser::parse(config, &delimiters, &escape_pairs);
         assert_eq!(tokens.len(), 3);
         assert_eq!(tokens[0], "SET");
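
Usage sketch (not part of the diff): a minimal illustration of combining the existing char escape pair with the new string escape pair when calling ConfigTokenParser::parse. The module path dt_common::config::config_token_parser is an assumption inferred from the file location; only the parse signature and TokenEscapePair variants shown above are taken from the PR.

// Hypothetical standalone example; names and paths outside this file are assumptions.
use dt_common::config::config_token_parser::{ConfigTokenParser, TokenEscapePair};

fn main() {
    // "`db.1`" is protected by the char escape pair; "r#tb_.*#" is protected by the
    // new string escape pair, so the '.' inside it is not treated as a delimiter.
    let config = "`db.1`.r#tb_.*#";
    let delimiters = vec!['.', ','];
    let escape_pairs = vec![
        TokenEscapePair::from(('`', '`')),
        TokenEscapePair::from(("r#".to_string(), "#".to_string())),
    ];
    let tokens = ConfigTokenParser::parse(config, &delimiters, &escape_pairs);
    assert_eq!(tokens, vec!["`db.1`", "r#tb_.*#"]);
}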