diff --git a/Cargo.toml b/Cargo.toml
index e25ac18b09..6fb2f936fb 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -53,7 +53,6 @@ rustdoc-args = ["--cfg", "docsrs"]
 
 [features]
 default = ["any", "macros", "migrate", "json"]
-
 derive = ["sqlx-macros/derive"]
 macros = ["derive", "sqlx-macros/macros"]
 migrate = ["sqlx-core/migrate", "sqlx-macros?/migrate", "sqlx-mysql?/migrate", "sqlx-postgres?/migrate", "sqlx-sqlite?/migrate"]
@@ -69,7 +68,7 @@ _unstable-all-types = [
     "ipnetwork",
     "mac_address",
     "uuid",
-    "bit-vec",
+    "bit-vec"
 ]
 
 # Base runtime features without TLS
diff --git a/sqlx-postgres/src/type_checking.rs b/sqlx-postgres/src/type_checking.rs
index f30a737399..e41e59e5cc 100644
--- a/sqlx-postgres/src/type_checking.rs
+++ b/sqlx-postgres/src/type_checking.rs
@@ -207,6 +207,10 @@ impl_type_checking!(
         #[cfg(feature = "time")]
         Vec<sqlx::postgres::types::PgRange<sqlx::types::time::OffsetDateTime>> |
             &[sqlx::postgres::types::PgRange<sqlx::types::time::OffsetDateTime>],
+
+        // Full text search
+        sqlx::postgres::types::TsVector,
+        Vec<sqlx::postgres::types::TsVector> | &[sqlx::postgres::types::TsVector],
     },
     ParamChecking::Strong,
     feature-types: info => info.__type_feature_gate(),
diff --git a/sqlx-postgres/src/type_info.rs b/sqlx-postgres/src/type_info.rs
index 5952291e6e..56387f4861 100644
--- a/sqlx-postgres/src/type_info.rs
+++ b/sqlx-postgres/src/type_info.rs
@@ -100,6 +100,8 @@ pub enum PgType {
     RecordArray,
     Uuid,
     UuidArray,
+    TsVector,
+    TsVectorArray,
     Jsonb,
     JsonbArray,
     Int4Range,
@@ -333,6 +335,8 @@ impl PgType {
             2287 => PgType::RecordArray,
             2950 => PgType::Uuid,
             2951 => PgType::UuidArray,
+            3614 => PgType::TsVector,
+            3643 => PgType::TsVectorArray,
             3802 => PgType::Jsonb,
             3807 => PgType::JsonbArray,
             3904 => PgType::Int4Range,
@@ -441,6 +445,8 @@ impl PgType {
             PgType::RecordArray => Oid(2287),
             PgType::Uuid => Oid(2950),
             PgType::UuidArray => Oid(2951),
+            PgType::TsVector => Oid(3614),
+            PgType::TsVectorArray => Oid(3643),
             PgType::Jsonb => Oid(3802),
             PgType::JsonbArray => Oid(3807),
             PgType::Int4Range => Oid(3904),
@@ -542,6 +548,8 @@ impl PgType {
             PgType::RecordArray => "RECORD[]",
             PgType::Uuid => "UUID",
             PgType::UuidArray => "UUID[]",
+            PgType::TsVector => "TSVECTOR",
+            PgType::TsVectorArray => "TSVECTOR[]",
             PgType::Jsonb => "JSONB",
             PgType::JsonbArray => "JSONB[]",
             PgType::Int4Range => "INT4RANGE",
@@ -642,6 +650,8 @@ impl PgType {
             PgType::RecordArray => "_record",
             PgType::Uuid => "uuid",
             PgType::UuidArray => "_uuid",
+            PgType::TsVector => "tsvector",
+            PgType::TsVectorArray => "_tsvector",
             PgType::Jsonb => "jsonb",
             PgType::JsonbArray => "_jsonb",
             PgType::Int4Range => "int4range",
@@ -742,6 +752,8 @@ impl PgType {
             PgType::RecordArray => &PgTypeKind::Array(PgTypeInfo(PgType::Record)),
             PgType::Uuid => &PgTypeKind::Simple,
             PgType::UuidArray => &PgTypeKind::Array(PgTypeInfo(PgType::Uuid)),
+            PgType::TsVector => &PgTypeKind::Simple,
+            PgType::TsVectorArray => &PgTypeKind::Array(PgTypeInfo(PgType::TsVector)),
             PgType::Jsonb => &PgTypeKind::Simple,
             PgType::JsonbArray => &PgTypeKind::Array(PgTypeInfo(PgType::Jsonb)),
             PgType::Int4Range => &PgTypeKind::Range(PgTypeInfo::INT4),
@@ -855,6 +867,8 @@ impl PgType {
             PgType::RecordArray => Some(Cow::Owned(PgTypeInfo(PgType::Record))),
             PgType::Uuid => None,
             PgType::UuidArray => Some(Cow::Owned(PgTypeInfo(PgType::Uuid))),
+            PgType::TsVector => None,
+            PgType::TsVectorArray => Some(Cow::Owned(PgTypeInfo(PgType::TsVector))),
             PgType::Jsonb => None,
             PgType::JsonbArray => Some(Cow::Owned(PgTypeInfo(PgType::Jsonb))),
             PgType::Int4Range => None,
@@ -928,6 +942,10 @@ impl PgTypeInfo {
     pub(crate) const UUID: Self = Self(PgType::Uuid);
     pub(crate) const UUID_ARRAY: Self = Self(PgType::UuidArray);
 
+    // tsvector
+    pub(crate) const TS_VECTOR: Self = Self(PgType::TsVector);
+    pub(crate) const TS_VECTOR_ARRAY: Self = Self(PgType::TsVectorArray);
+
     // record
     pub(crate) const RECORD: Self = Self(PgType::Record);
     pub(crate) const RECORD_ARRAY: Self = Self(PgType::RecordArray);
diff --git a/sqlx-postgres/src/types/mod.rs b/sqlx-postgres/src/types/mod.rs
index d68d9b9178..4476c4eb05 100644
--- a/sqlx-postgres/src/types/mod.rs
+++ b/sqlx-postgres/src/types/mod.rs
@@ -234,6 +234,8 @@ mod mac_address;
 #[cfg(feature = "bit-vec")]
 mod bit_vec;
 
+mod ts_vector;
+
 pub use array::PgHasArrayType;
 pub use citext::PgCiText;
 pub use interval::PgInterval;
@@ -251,6 +253,8 @@ pub use range::PgRange;
 #[cfg(any(feature = "chrono", feature = "time"))]
 pub use time_tz::PgTimeTz;
 
+pub use ts_vector::{Lexeme, LexemeMeta, ParseLexemeMetaError, TsVector};
+
 // used in derive(Type) for `struct`
 // but the interface is not considered part of the public API
 #[doc(hidden)]
diff --git a/sqlx-postgres/src/types/ts_vector.rs b/sqlx-postgres/src/types/ts_vector.rs
new file mode 100644
index 0000000000..97e14c0257
--- /dev/null
+++ b/sqlx-postgres/src/types/ts_vector.rs
@@ -0,0 +1,365 @@
+use crate::decode::Decode;
+use crate::encode::{Encode, IsNull};
+use crate::types::Type;
+use crate::{
+    error::BoxDynError, PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef,
+    Postgres,
+};
+use byteorder::{BigEndian, ReadBytesExt};
+use core::fmt;
+use std::error::Error;
+use std::fmt::{Display, Formatter};
+use std::io::{BufRead, Cursor};
+use std::iter::{self, Peekable};
+use std::num::{IntErrorKind, ParseIntError};
+use std::str::{Chars, FromStr};
+
+/// Largest position PostgreSQL stores for a lexeme; it occupies the low 14 bits of an entry.
+const MAX_POSITION: u16 = 0x3fff;
+
+/// Number of bits the weight is shifted left by inside a packed position entry.
+const WEIGHT_SHIFT: u16 = 14;
+
+/// Position and weight of one occurrence of a [`Lexeme`] in a document.
+///
+/// In the binary format every occurrence is a single `u16`: the two high bits hold the
+/// weight (`3` = `A`, `2` = `B`, `1` = `C`, `0` = `D`) and the low 14 bits hold the position.
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub struct LexemeMeta {
+    position: u16,
+    weight: u16,
+}
+
+impl LexemeMeta {
+    /// Position of this occurrence within the document.
+    pub fn position(&self) -> u16 {
+        self.position
+    }
+
+    /// Numeric weight of this occurrence: `3` = `A`, `2` = `B`, `1` = `C`, `0` = `D`.
+    pub fn weight(&self) -> u16 {
+        self.weight
+    }
+}
+
+impl From<u16> for LexemeMeta {
+    /// Unpacks a position entry as found in the binary `tsvector` format.
+    fn from(value: u16) -> Self {
+        Self {
+            weight: value >> WEIGHT_SHIFT,
+            position: value & MAX_POSITION,
+        }
+    }
+}
+
+impl From<&LexemeMeta> for u16 {
+    /// Packs the weight into the two high bits and the position into the low 14 bits.
+    fn from(meta: &LexemeMeta) -> Self {
+        ((meta.weight & 0b11) << WEIGHT_SHIFT) | (meta.position & MAX_POSITION)
+    }
+}
+
+/// Error returned when a position in the text form of a [`TsVector`] is not a valid integer.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ParseLexemeMetaError {
+    kind: IntErrorKind,
+}
+
+impl ParseLexemeMetaError {
+    /// Human readable description of the underlying integer parse failure.
+    fn message(&self) -> &'static str {
+        match self.kind {
+            IntErrorKind::Empty => "cannot parse integer from empty string",
+            IntErrorKind::InvalidDigit => "invalid digit found in string",
+            IntErrorKind::PosOverflow => "number too large to fit in target type",
+            IntErrorKind::NegOverflow => "number too small to fit in target type",
+            IntErrorKind::Zero => "number would be zero for non-zero type",
+            _ => "unknown",
+        }
+    }
+}
+
+impl From<ParseIntError> for ParseLexemeMetaError {
+    fn from(value: ParseIntError) -> Self {
+        Self {
+            kind: value.kind().clone(),
+        }
+    }
+}
+
+impl Display for ParseLexemeMetaError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        f.pad(self.message())
+    }
+}
+
+impl Error for ParseLexemeMetaError {
+    // Kept so callers of the deprecated accessor still get the specific message.
+    fn description(&self) -> &str {
+        self.message()
+    }
+}
+
+impl FromStr for LexemeMeta {
+    type Err = ParseLexemeMetaError;
+
+    /// Parses one position entry such as `3`, `3A` or `3b`.
+    ///
+    /// A missing weight letter means the default weight `D`. Positions above 16383 are
+    /// clamped to 16383, which is what PostgreSQL documents for `tsvector` input.
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let weight = match s.chars().last() {
+            Some('A' | 'a') => Some(3),
+            Some('B' | 'b') => Some(2),
+            Some('C' | 'c') => Some(1),
+            Some('D' | 'd') => Some(0),
+            _ => None,
+        };
+        // The weight letter is ASCII, so dropping the final byte stays on a char boundary.
+        let digits = if weight.is_some() {
+            &s[..s.len() - 1]
+        } else {
+            s
+        };
+
+        Ok(Self {
+            weight: weight.unwrap_or(0),
+            position: digits.parse::<u16>()?.min(MAX_POSITION),
+        })
+    }
+}
+
+/// A normalized word together with every place it occurs in the document.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct Lexeme {
+    word: String,
+    positions: Vec<LexemeMeta>,
+}
+
+impl Lexeme {
+    /// The lexeme text, without quoting or escaping.
+    pub fn word(&self) -> &str {
+        self.word.as_str()
+    }
+
+    /// Positions (with weights) of this lexeme; empty when the vector carries none.
+    pub fn positions(&self) -> &[LexemeMeta] {
+        &self.positions
+    }
+}
+
+/// The PostgreSQL `tsvector` full text search type: a list of lexemes with optional positions.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct TsVector {
+    words: Vec<Lexeme>,
+}
+
+impl TsVector {
+    /// The lexemes of this vector, in the order they were decoded or parsed.
+    pub fn words(&self) -> &Vec<Lexeme> {
+        &self.words
+    }
+}
+
+impl Display for TsVector {
+    /// Writes the vector in PostgreSQL's text format, e.g. `'brown':3 'fox':4A`.
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        use fmt::Write;
+
+        for (index, lexeme) in self.words.iter().enumerate() {
+            if index > 0 {
+                f.write_char(' ')?;
+            }
+
+            f.write_char('\'')?;
+            for c in lexeme.word.chars() {
+                // Quotes and backslashes are doubled so the output parses back unchanged.
+                if c == '\'' || c == '\\' {
+                    f.write_char(c)?;
+                }
+                f.write_char(c)?;
+            }
+            f.write_char('\'')?;
+
+            for (n, meta) in lexeme.positions.iter().enumerate() {
+                f.write_char(if n == 0 { ':' } else { ',' })?;
+                write!(f, "{}", meta.position)?;
+
+                match meta.weight {
+                    3 => f.write_char('A')?,
+                    2 => f.write_char('B')?,
+                    1 => f.write_char('C')?,
+                    // 'D' is the default weight and is not displayed.
+                    _ => {}
+                }
+            }
+        }
+
+        Ok(())
+    }
+}
+
+impl TryFrom<&[u8]> for TsVector {
+    type Error = BoxDynError;
+
+    /// Decode binary data into [`TsVector`] based on the binary data format defined in
+    /// <https://github.com/postgres/postgres/blob/252dcb32397f64a5e1ceac05b29a271ab19aa960/src/backend/utils/adt/tsvector.c#L399>
+    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
+        let mut reader = Cursor::new(bytes);
+        let num_lexemes = reader.read_u32::<BigEndian>()?;
+        let mut words = Vec::new();
+
+        for _ in 0..num_lexemes {
+            let mut word = Vec::new();
+            reader.read_until(b'\0', &mut word)?;
+
+            // `read_until` keeps the delimiter; its absence means the input was truncated.
+            if word.pop() != Some(b'\0') {
+                return Err("tsvector lexeme is missing its NUL terminator".into());
+            }
+
+            let num_positions = reader.read_u16::<BigEndian>()?;
+            let positions = (0..num_positions)
+                .map(|_| reader.read_u16::<BigEndian>().map(LexemeMeta::from))
+                .collect::<Result<Vec<_>, _>>()?;
+
+            words.push(Lexeme {
+                word: String::from_utf8(word)?,
+                positions,
+            });
+        }
+
+        Ok(Self { words })
+    }
+}
+
+impl TryFrom<&TsVector> for Vec<u8> {
+    type Error = BoxDynError;
+
+    /// Encodes the vector in the binary format read by `TryFrom<&[u8]> for TsVector`.
+    ///
+    /// Fails when a count does not fit its wire integer or a lexeme contains a NUL byte.
+    fn try_from(ts_vector: &TsVector) -> Result<Self, Self::Error> {
+        let mut buf = Vec::new();
+
+        buf.extend_from_slice(&u32::try_from(ts_vector.words.len())?.to_be_bytes());
+
+        for lexeme in &ts_vector.words {
+            // Lexemes are NUL-terminated on the wire, so an interior NUL cannot be represented.
+            if lexeme.word.as_bytes().contains(&b'\0') {
+                return Err("tsvector lexeme contains a NUL byte".into());
+            }
+
+            buf.extend_from_slice(lexeme.word.as_bytes());
+            buf.push(b'\0');
+            buf.extend_from_slice(&u16::try_from(lexeme.positions.len())?.to_be_bytes());
+
+            for meta in &lexeme.positions {
+                buf.extend_from_slice(&u16::from(meta).to_be_bytes());
+            }
+        }
+
+        Ok(buf)
+    }
+}
+
+/// Reads one lexeme from the text format and returns it with quoting and escaping removed.
+///
+/// A lexeme that starts with `'` runs to the matching closing quote, with `''` standing for
+/// a literal quote. Any other lexeme ends before the next whitespace or `:`. In both forms
+/// a backslash makes the following character literal.
+fn read_word(chars: &mut Peekable<Chars<'_>>) -> String {
+    let mut word = String::new();
+    let quoted = chars.next_if_eq(&'\'').is_some();
+
+    while let Some(c) = chars.next_if(|&c| quoted || !(c == ':' || c.is_whitespace())) {
+        match c {
+            '\\' => word.extend(chars.next()),
+            '\'' if quoted => {
+                if chars.next_if_eq(&'\'').is_none() {
+                    // Closing quote: the lexeme is complete.
+                    break;
+                }
+                word.push('\'');
+            }
+            c => word.push(c),
+        }
+    }
+
+    word
+}
+
+/// Reads the comma separated position list that follows the `:` after a lexeme.
+fn read_positions(
+    chars: &mut Peekable<Chars<'_>>,
+) -> Result<Vec<LexemeMeta>, ParseLexemeMetaError> {
+    let list: String = iter::from_fn(|| chars.next_if(|c| !c.is_whitespace())).collect();
+
+    list.split(',').map(|entry| entry.parse::<LexemeMeta>()).collect()
+}
+
+impl FromStr for TsVector {
+    type Err = ParseLexemeMetaError;
+
+    /// Parses PostgreSQL's text representation of a `tsvector`.
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let mut chars = s.chars().peekable();
+        let mut words = Vec::new();
+
+        loop {
+            // Lexemes are separated by any amount of whitespace.
+            while chars.next_if(|c| c.is_whitespace()).is_some() {}
+
+            if chars.peek().is_none() {
+                break;
+            }
+
+            let word = read_word(&mut chars);
+            let positions = if chars.next_if_eq(&':').is_some() {
+                read_positions(&mut chars)?
+            } else {
+                Vec::new()
+            };
+
+            words.push(Lexeme { word, positions });
+        }
+
+        Ok(Self { words })
+    }
+}
+
+impl Type<Postgres> for TsVector {
+    fn type_info() -> PgTypeInfo {
+        PgTypeInfo::TS_VECTOR
+    }
+}
+
+impl PgHasArrayType for TsVector {
+    fn array_type_info() -> PgTypeInfo {
+        PgTypeInfo::TS_VECTOR_ARRAY
+    }
+}
+
+impl Encode<'_, Postgres> for TsVector {
+    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
+        match Vec::<u8>::try_from(self) {
+            Ok(encoded) => {
+                buf.extend_from_slice(&encoded);
+
+                IsNull::No
+            }
+            // NOTE(review): this signature cannot report the failure, so the value is bound
+            // as NULL; nothing has been written to `buf` at this point.
+            Err(_) => IsNull::Yes,
+        }
+    }
+}
+
+impl Decode<'_, Postgres> for TsVector {
+    fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
+        match value.format() {
+            PgValueFormat::Binary => Self::try_from(value.as_bytes()?),
+            PgValueFormat::Text => Ok(value.as_str()?.parse::<Self>()?),
+        }
+    }
+}
diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs
index 184007ce4b..d5c10cae56 100644
--- a/tests/postgres/types.rs
+++ b/tests/postgres/types.rs
@@ -442,6 +442,89 @@ mod json {
     }
 }
 
+mod full_text_search {
+    use super::*;
+    use sqlx::postgres::types::TsVector;
+    use sqlx::postgres::PgRow;
+    use sqlx::{Executor, Row};
+    use sqlx_core::statement::Statement;
+    use sqlx_test::new;
+
+    #[sqlx_macros::test]
+    async fn test_ts_vector() -> anyhow::Result<()> {
+        let mut conn = new::<Postgres>().await?;
+
+        // unprepared, text API
+        let row: PgRow = conn
+            .fetch_one("SELECT to_tsvector('english', 'A quick brown fox')")
+            .await?;
+        let value: TsVector = row.try_get(0)?;
+
+        assert_eq!(value.to_string(), "'brown':3 'fox':4 'quick':2");
+
+        // prepared, binary API
+        let row: PgRow = conn
+            .fetch_one(sqlx::query("SELECT to_tsvector('english', 'A quick brown fox')"))
+            .await?;
+
+        let value: TsVector = row.try_get(0)?;
+
+        assert_eq!(value.to_string(), "'brown':3 'fox':4 'quick':2");
+
+        // with weights
+        let row = conn
+            .fetch_one("SELECT 'text:1A,2B,3C,4,5D'::tsvector")
+            .await?;
+
+        let value: TsVector = row.try_get(0)?;
+
+        assert_eq!(value.to_string(), "'text':1A,2B,3C,4,5");
+
+        // with no positions
+        let row = conn.fetch_one("SELECT 'text'::tsvector").await?;
+        let value: TsVector = row.try_get(0)?;
+
+        assert_eq!(value.to_string(), "'text'");
+
+        let row = conn.fetch_one(r#"SELECT $$' A'$$::tsvector;"#).await?;
+        let value = row.get::<TsVector, _>(0).to_string();
+        assert_eq!(value, "' A'");
+
+        let row = conn
+            .fetch_one(r#"SELECT $$'Joe''s' cat$$::tsvector;"#)
+            .await?;
+        let value = row.get::<TsVector, _>(0).to_string();
+        assert_eq!(value, "'Joe''s' 'cat'");
+
+        let sql = r#"SELECT $$'Joe''s' cat$$::tsvector;"#;
+        let row = conn.fetch_one(sql).await.unwrap();
+        let cell = row.get::<TsVector, _>(0);
+        assert_eq!(cell.words()[0].word(), "Joe's");
+
+        let sql = r#"SELECT $$'Joe''s' cat$$::tsvector;"#;
+        let statement = conn.prepare(sql).await.unwrap();
+        let row = statement.query().fetch_one(&mut conn).await.unwrap();
+        let cell = row.get::<TsVector, _>(0);
+        assert_eq!(cell.to_string(), "'Joe''s' 'cat'");
+
+        // quoted lexemes may contain `:`; the text decoder must not read it as a position list
+        let row = conn.fetch_one("SELECT $$'a:b':1$$::tsvector").await?;
+        let value: TsVector = row.try_get(0)?;
+        assert_eq!(value.to_string(), "'a:b':1");
+
+        // bound parameters go through the binary encoder and must keep weights and positions
+        let input: TsVector = "'brown':3B 'fox':4 'quick':2A".parse()?;
+        let row = sqlx::query("SELECT $1")
+            .bind(&input)
+            .fetch_one(&mut conn)
+            .await?;
+        let value: TsVector = row.try_get(0)?;
+        assert_eq!(value.to_string(), "'brown':3B 'fox':4 'quick':2A");
+
+        Ok(())
+    }
+}
+
 #[cfg(feature = "bigdecimal")]
 test_type!(bigdecimal<sqlx::types::BigDecimal>(Postgres,
 

AltStyle によって変換されたページ (->オリジナル) /