Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 102 additions & 10 deletions src/connector/src/source/filesystem/s3/enumerator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,56 @@

use std::collections::HashMap;

use anyhow::Context;
use async_trait::async_trait;
use aws_sdk_s3::client::Client;
use globset::{Glob, GlobMatcher};
use itertools::Itertools;

use crate::aws_utils::{default_conn_config, s3_client, AwsConfigV2};
use crate::source::filesystem::file_common::FsSplit;
use crate::source::filesystem::s3::S3Properties;
use crate::source::SplitEnumerator;

/// Returns the literal prefix of a glob pattern: the leading run of characters
/// before the first unescaped glob metacharacter.
///
/// The result is used as the `prefix` of an S3 listing to narrow the number of
/// objects returned. It must therefore only contain characters that every key
/// matched by the glob is guaranteed to start with.
///
/// Rules:
/// - The metacharacters that end the prefix are `*`, `?`, `[` and `{`.
/// - A backslash escapes the following character, which is kept literally
///   (without the backslash).
/// - A dangling backslash at the very end of the pattern is dropped.
///
/// Examples: `"a/**"` -> `"a/"`, `"a/file?.csv"` -> `"a/file"`,
/// `r"a/\{b}"` -> `"a/{b}"`, `"[ab]*"` -> `""`.
fn get_prefix(glob: &str) -> String {
    let mut prefix = String::with_capacity(glob.len());
    let mut chars = glob.chars();
    // `while let` rather than `for`: an escape needs to pull the next char itself.
    while let Some(c) = chars.next() {
        match c {
            // Each of these can match more than one literal string, so nothing
            // from here on is guaranteed to appear verbatim in a matching key.
            // `?` (any single character) must stop the prefix too. Otherwise
            // `file?.csv` would list only keys literally starting "file?.csv".
            '*' | '?' | '[' | '{' => break,
            '\\' => match chars.next() {
                Some(escaped) => prefix.push(escaped),
                // Dangling escape at the end of the pattern: nothing to keep.
                None => break,
            },
            _ => prefix.push(c),
        }
    }
    prefix
}

/// Split enumerator that lists the objects of a single S3 bucket and turns each
/// listed (and, if a pattern is configured, glob-matched) object into an
/// [`FsSplit`].
#[derive(Debug, Clone)]
pub struct S3SplitEnumerator {
    /// Name of the bucket whose objects are listed.
    bucket_name: String,
    /// Literal key prefix derived from `matcher`'s glob. It is passed to the
    /// listing request to reduce the number of objects S3 returns. `None` when
    /// no match pattern is configured.
    prefix: Option<String>,
    /// Compiled `match_pattern` glob. Objects whose key does not match are
    /// dropped; `None` keeps every listed object.
    matcher: Option<GlobMatcher>,
    /// AWS SDK S3 client used to issue the listing requests.
    client: Client,
}

Expand All @@ -38,8 +76,19 @@ impl SplitEnumerator for S3SplitEnumerator {
let config = AwsConfigV2::from(HashMap::from(properties.clone()));
let sdk_config = config.load_config(None).await;
let s3_client = s3_client(&sdk_config, Some(default_conn_config()));
let matcher = if let Some(pattern) = properties.match_pattern.as_ref() {
let glob = Glob::new(pattern)
.with_context(|| format!("Invalid match_pattern: {}", pattern))?;
Some(glob.compile_matcher())
} else {
None
};
let prefix = matcher.as_ref().map(|m| get_prefix(m.glob().glob()));

Ok(S3SplitEnumerator {
bucket_name: properties.bucket_name,
matcher,
prefix,
client: s3_client,
})
}
Expand All @@ -49,20 +98,63 @@ impl SplitEnumerator for S3SplitEnumerator {
.client
.list_objects_v2()
.bucket(&self.bucket_name)
.set_prefix(self.prefix.clone())
.send()
.await?;

let objects = list_obj_out.contents();
let splits = objects
.map(|objs| {
objs.iter()
.map(|obj| {
let obj_name = obj.key().unwrap().to_string();
FsSplit::new(obj_name, 0, obj.size() as usize)
})
.collect_vec()
})
.unwrap_or_else(Vec::default);
let splits = if let Some(objs) = objects {
let matched_objs = objs
.iter()
.filter(|obj| obj.key().is_some())
.filter(|obj| {
self.matcher
.as_ref()
.map(|m| m.is_match(obj.key().unwrap()))
.unwrap_or(true)
})
.collect_vec();

matched_objs
.into_iter()
.map(|obj| FsSplit::new(obj.key().unwrap().to_owned(), 0, obj.size() as usize))
.collect_vec()
} else {
Vec::new()
};
Ok(splits)
}
}

#[cfg(test)]
mod tests {
    use super::*;

    /// `get_prefix` keeps the literal head of a glob, stopping at the first
    /// unescaped metacharacter and unescaping anything before it.
    #[test]
    fn test_get_prefix() {
        let cases = [
            ("a/", "a/"),
            ("a/**", "a/"),
            ("[ab]*", ""),
            ("a/{a,b}*", "a/"),
            (r"a/\{a,b}", "a/{a,b}"),
            (r"a/\[ab]", "a/[ab]"),
        ];
        for (glob, expected) in cases {
            assert_eq!(get_prefix(glob), expected);
        }
    }

    /// End-to-end listing against a real bucket; needs network access and
    /// credentials, hence `#[ignore]`.
    #[tokio::test]
    #[ignore]
    async fn test_s3_split_enumerator() {
        let properties = S3Properties {
            region_name: "ap-southeast-1".to_owned(),
            bucket_name: "mingchao-s3-source".to_owned(),
            match_pattern: Some("happy[0-9].csv".to_owned()),
            access: None,
            secret: None,
        };
        let mut enumerator = S3SplitEnumerator::new(properties).await.unwrap();
        let split_names: Vec<_> = enumerator
            .list_splits()
            .await
            .unwrap()
            .into_iter()
            .map(|split| split.name)
            .collect();

        assert_eq!(split_names.len(), 2);
        for expected in ["happy1.csv", "happy2.csv"] {
            assert!(split_names.contains(&expected.to_owned()));
        }
    }
}