diff --git a/rust/lance/src/dataset/builder.rs b/rust/lance/src/dataset/builder.rs index 3d463ce6ca4..61959d3599a 100644 --- a/rust/lance/src/dataset/builder.rs +++ b/rust/lance/src/dataset/builder.rs @@ -234,7 +234,7 @@ impl DatasetBuilder { /// Sets `version` for the builder using a tag pub fn with_tag(mut self, tag: &str) -> Self { - self.version = Some(Ref::from(tag)); + self.version = Some(Ref::Tag(tag.to_string())); self } diff --git a/rust/lance/src/dataset/refs.rs b/rust/lance/src/dataset/refs.rs index 4044da9f60e..863b4ff4fc1 100644 --- a/rust/lance/src/dataset/refs.rs +++ b/rust/lance/src/dataset/refs.rs @@ -25,7 +25,7 @@ use std::io::ErrorKind; pub const MAIN_BRANCH: &str = "main"; /// Lance Ref -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum Ref { // Version number points of the current branch VersionNumber(u64), @@ -45,10 +45,51 @@ impl From for Ref { impl From<&str> for Ref { fn from(reference: &str) -> Self { + let reference = reference.trim(); + if let Some((branch, version)) = reference.split_once(':') { + if branch.is_empty() { + return Tag(reference.to_string()); + } + + let branch = if branch == MAIN_BRANCH { + None + } else if check_valid_branch(branch).is_ok() { + Some(branch.to_string()) + } else { + return Tag(reference.to_string()); + }; + + let version = match version { + "" | "latest" => None, + v => match v.parse::() { + Ok(v) => Some(v), + Err(_) => return Tag(reference.to_string()), + }, + }; + + return Version(branch, version); + } + + if reference == MAIN_BRANCH { + return Version(None, None); + } + + if reference.chars().all(|c| c.is_ascii_digit()) { + if let Ok(v) = reference.parse::() { + return VersionNumber(v); + } + } + Tag(reference.to_string()) } } +impl From for Ref { + fn from(reference: String) -> Self { + Self::from(reference.as_str()) + } +} + impl From<(&str, u64)> for Ref { fn from(reference: (&str, u64)) -> Self { Version(standardize_branch(reference.0), Some(reference.1)) @@ -221,6 +262,13 @@ impl Tags<'_> { message: format!("tag {} already exists", tag), }); } + + let branch_file = branch_contents_path(&root_location.path, tag); + if self.object_store().exists(&branch_file).await? { + return Err(Error::RefConflict { + message: format!("tag {} conflicts with existing branch", tag), + }); + } let tag_contents = self.build_tag_content_by_ref(reference).await?; self.object_store() @@ -257,6 +305,13 @@ impl Tags<'_> { message: format!("tag {} does not exist", tag), }); } + + let branch_file = branch_contents_path(&root_location.path, tag); + if self.object_store().exists(&branch_file).await? { + return Err(Error::RefConflict { + message: format!("tag {} conflicts with existing branch", tag), + }); + } let tag_contents = self.build_tag_content_by_ref(reference).await?; self.object_store() @@ -393,6 +448,13 @@ impl Branches<'_> { let source_branch = source_branch.and_then(standardize_branch); let root_location = self.refs.root()?; + + let tag_file = tag_path(&root_location.path, branch_name); + if self.object_store().exists(&tag_file).await? { + return Err(Error::RefConflict { + message: format!("branch {} conflicts with existing tag", branch_name), + }); + } let branch_file = branch_contents_path(&root_location.path, branch_name); if self.object_store().exists(&branch_file).await? { return Err(Error::RefConflict { @@ -793,6 +855,41 @@ mod tests { use rstest::rstest; + #[rstest] + fn test_parse_ref_ok( + #[values( + ("0", Ref::VersionNumber(0)), + ("42", Ref::VersionNumber(42)), + ("main", Ref::Version(None, None)), + ("main:", Ref::Version(None, None)), + ("main:latest", Ref::Version(None, None)), + ("main:10", Ref::Version(None, Some(10))), + ("feature/a:", Ref::Version(Some("feature/a".to_string()), None)), + ("feature/a:latest", Ref::Version(Some("feature/a".to_string()), None)), + ("feature/a:10", Ref::Version(Some("feature/a".to_string()), Some(10))), + ("tag1", Ref::Tag("tag1".to_string())) + )] + (input, expected): (&str, Ref), + ) { + assert_eq!(Ref::from(input), expected); + } + + #[rstest] + fn test_ref_from_str_invalid_syntax_falls_back_to_tag( + #[values( + "", + ":", + ":10", + "main:bad", + "feature/a", + "/start-with-slash", + "feature//double-slash" + )] + input: &str, + ) { + assert_eq!(Ref::from(input), Ref::Tag(input.trim().to_string())); + } + #[rstest] fn test_ok_ref( #[values( diff --git a/rust/lance/src/dataset/tests/dataset_versioning.rs b/rust/lance/src/dataset/tests/dataset_versioning.rs index 2e2fcdf6601..0f793e32f10 100644 --- a/rust/lance/src/dataset/tests/dataset_versioning.rs +++ b/rust/lance/src/dataset/tests/dataset_versioning.rs @@ -366,6 +366,59 @@ async fn test_tag( assert_eq!(dataset.manifest.version, 1); } +#[tokio::test] +async fn test_tag_branch_name_conflict() { + let test_uri = TempStrDir::default(); + + let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( + "i", + DataType::UInt32, + false, + )])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(UInt32Array::from_iter_values(0..10))], + ) + .unwrap(); + let reader = RecordBatchIterator::new(vec![data].into_iter().map(Ok), schema); + let mut dataset = Dataset::write(reader, &test_uri, None).await.unwrap(); + + dataset + .create_branch("conflict", dataset.version().version, None) + .await + .unwrap(); + + let err = dataset + .tags() + .create("conflict", dataset.version().version) + .await + .err() + .unwrap() + .to_string(); + assert_eq!( + err, + "Ref conflict error: tag conflict conflicts with existing branch" + ); + + dataset + .tags() + .create("tag_conflict", dataset.version().version) + .await + .unwrap(); + + let err = dataset + .create_branch("tag_conflict", dataset.version().version, None) + .await + .err() + .unwrap() + .to_string(); + assert_eq!( + err, + "Ref conflict error: branch tag_conflict conflicts with existing tag" + ); +} + #[rstest] #[tokio::test] async fn test_fragment_id_zero_not_reused() {