diff --git a/Cargo.lock b/Cargo.lock index 49143908..94b591f3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3189,6 +3189,7 @@ dependencies = [ "pgt_suppressions", "pgt_test_utils", "pgt_text_size", + "pgt_tokenizer", "pgt_typecheck", "pgt_workspace_macros", "rustc-hash 2.1.0", diff --git a/crates/pgt_lexer/src/lexer.rs b/crates/pgt_lexer/src/lexer.rs index ad6db297..3e691229 100644 --- a/crates/pgt_lexer/src/lexer.rs +++ b/crates/pgt_lexer/src/lexer.rs @@ -111,6 +111,7 @@ impl<'a> Lexer<'a> { pgt_tokenizer::TokenKind::Tilde => SyntaxKind::TILDE, pgt_tokenizer::TokenKind::Question => SyntaxKind::QUESTION, pgt_tokenizer::TokenKind::Colon => SyntaxKind::COLON, + pgt_tokenizer::TokenKind::DoubleColon => SyntaxKind::DOUBLE_COLON, pgt_tokenizer::TokenKind::Eq => SyntaxKind::EQ, pgt_tokenizer::TokenKind::Bang => SyntaxKind::BANG, pgt_tokenizer::TokenKind::Lt => SyntaxKind::L_ANGLE, diff --git a/crates/pgt_lexer_codegen/src/syntax_kind.rs b/crates/pgt_lexer_codegen/src/syntax_kind.rs index c671e451..3a005437 100644 --- a/crates/pgt_lexer_codegen/src/syntax_kind.rs +++ b/crates/pgt_lexer_codegen/src/syntax_kind.rs @@ -37,6 +37,7 @@ const PUNCT: &[(&str, &str)] = &[ ("_", "UNDERSCORE"), (".", "DOT"), (":", "COLON"), + ("::", "DOUBLE_COLON"), ("=", "EQ"), ("!", "BANG"), ("-", "MINUS"), diff --git a/crates/pgt_tokenizer/src/lib.rs b/crates/pgt_tokenizer/src/lib.rs index 83b9ba44..16093db8 100644 --- a/crates/pgt_tokenizer/src/lib.rs +++ b/crates/pgt_tokenizer/src/lib.rs @@ -144,32 +144,37 @@ impl Cursor<'_> { } } ':' => { - // Named parameters in psql with different substitution styles. - // - // https://www.postgresql.org/docs/current/app-psql.html#APP-PSQL-INTERPOLATION - match self.first() { - '\'' => { - // Named parameter with colon prefix and single quotes. 
- self.bump(); - let terminated = self.single_quoted_string(); - let kind = NamedParamKind::ColonString { terminated }; - TokenKind::NamedParam { kind } - } - '"' => { - // Named parameter with colon prefix and double quotes. - self.bump(); - let terminated = self.double_quoted_string(); - let kind = NamedParamKind::ColonIdentifier { terminated }; - TokenKind::NamedParam { kind } - } - c if is_ident_start(c) => { - // Named parameter with colon prefix. - self.eat_while(is_ident_cont); - TokenKind::NamedParam { - kind: NamedParamKind::ColonRaw, + if self.first() == ':' { + self.bump(); + TokenKind::DoubleColon + } else { + // Named parameters in psql with different substitution styles. + // + // https://www.postgresql.org/docs/current/app-psql.html#APP-PSQL-INTERPOLATION + match self.first() { + '\'' => { + // Named parameter with colon prefix and single quotes. + self.bump(); + let terminated = self.single_quoted_string(); + let kind = NamedParamKind::ColonString { terminated }; + TokenKind::NamedParam { kind } + } + '"' => { + // Named parameter with colon prefix and double quotes. + self.bump(); + let terminated = self.double_quoted_string(); + let kind = NamedParamKind::ColonIdentifier { terminated }; + TokenKind::NamedParam { kind } + } + c if is_ident_start(c) => { + // Named parameter with colon prefix. + self.eat_while(is_ident_cont); + TokenKind::NamedParam { + kind: NamedParamKind::ColonRaw, + } } + _ => TokenKind::Colon, } - _ => TokenKind::Colon, } } // One-symbol tokens. 
@@ -675,6 +680,23 @@ mod tests { assert_debug_snapshot!(result); } + #[test] + fn debug_simple_cast() { + let result = lex("::test"); + assert_debug_snapshot!(result, @r###" + [ + "::" @ DoubleColon, + "test" @ Ident, + ] + "###); + } + + #[test] + fn named_param_colon_raw_vs_cast() { + let result = lex("select 1 from c where id::test = :id;"); + assert_debug_snapshot!(result); + } + #[test] fn named_param_colon_string() { let result = lex("select 1 from c where id = :'id';"); diff --git a/crates/pgt_tokenizer/src/snapshots/pgt_tokenizer__tests__named_param_colon_raw_vs_cast.snap b/crates/pgt_tokenizer/src/snapshots/pgt_tokenizer__tests__named_param_colon_raw_vs_cast.snap new file mode 100644 index 00000000..ecfd4821 --- /dev/null +++ b/crates/pgt_tokenizer/src/snapshots/pgt_tokenizer__tests__named_param_colon_raw_vs_cast.snap @@ -0,0 +1,25 @@ +--- +source: crates/pgt_tokenizer/src/lib.rs +expression: result +snapshot_kind: text +--- +[ + "select" @ Ident, + " " @ Space, + "1" @ Literal { kind: Int { base: Decimal, empty_int: false } }, + " " @ Space, + "from" @ Ident, + " " @ Space, + "c" @ Ident, + " " @ Space, + "where" @ Ident, + " " @ Space, + "id" @ Ident, + "::" @ DoubleColon, + "test" @ Ident, + " " @ Space, + "=" @ Eq, + " " @ Space, + ":id" @ NamedParam { kind: ColonRaw }, + ";" @ Semi, +] diff --git a/crates/pgt_tokenizer/src/token.rs b/crates/pgt_tokenizer/src/token.rs index da98a229..1312773d 100644 --- a/crates/pgt_tokenizer/src/token.rs +++ b/crates/pgt_tokenizer/src/token.rs @@ -46,6 +46,8 @@ pub enum TokenKind { Minus, /// `:` Colon, + /// `::` + DoubleColon, /// `.` Dot, /// `=` diff --git a/crates/pgt_workspace/Cargo.toml b/crates/pgt_workspace/Cargo.toml index efded47c..860b5133 100644 --- a/crates/pgt_workspace/Cargo.toml +++ b/crates/pgt_workspace/Cargo.toml @@ -33,6 +33,7 @@ pgt_schema_cache = { workspace = true } pgt_statement_splitter = { workspace = true } pgt_suppressions = { workspace = true } pgt_text_size.workspace = true 
+pgt_tokenizer = { workspace = true } pgt_typecheck = { workspace = true } pgt_workspace_macros = { workspace = true } rustc-hash = { workspace = true } diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index f0a39dbf..49c306f2 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -14,6 +14,7 @@ use document::{ TypecheckDiagnosticsMapper, }; use futures::{StreamExt, stream}; +use pg_query::convert_to_positional_params; use pgt_analyse::{AnalyserOptions, AnalysisFilter}; use pgt_analyser::{Analyser, AnalyserConfig, AnalyserParams}; use pgt_diagnostics::{ @@ -468,7 +469,7 @@ impl Workspace for WorkspaceServer { // Type checking let typecheck_result = pgt_typecheck::check_sql(TypecheckParams { conn: &pool, - sql: id.content(), + sql: convert_to_positional_params(id.content()).as_str(), ast: &ast, tree: &cst, schema_cache: schema_cache.as_ref(), diff --git a/crates/pgt_workspace/src/workspace/server.tests.rs b/crates/pgt_workspace/src/workspace/server.tests.rs index ef5ba267..894d1042 100644 --- a/crates/pgt_workspace/src/workspace/server.tests.rs +++ b/crates/pgt_workspace/src/workspace/server.tests.rs @@ -277,3 +277,57 @@ async fn test_dedupe_diagnostics(test_db: PgPool) { Some(TextRange::new(115.into(), 210.into())) ); } + +#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] +async fn test_positional_params(test_db: PgPool) { + let mut conf = PartialConfiguration::init(); + conf.merge_with(PartialConfiguration { + db: Some(PartialDatabaseConfiguration { + database: Some( + test_db + .connect_options() + .get_database() + .unwrap() + .to_string(), + ), + ..Default::default() + }), + ..Default::default() + }); + + let workspace = get_test_workspace(Some(conf)).expect("Unable to create test workspace"); + + let path = PgTPath::new("test.sql"); + + let setup_sql = r" + create table users ( + id serial primary key, + name text not null, + email text not 
null + ); + "; + test_db.execute(setup_sql).await.expect("setup sql failed"); + + let content = r#"select * from users where id = @one and name = :two and email = :'three';"#; + + workspace + .open_file(OpenFileParams { + path: path.clone(), + content: content.into(), + version: 1, + }) + .expect("Unable to open test file"); + + let diagnostics = workspace + .pull_diagnostics(crate::workspace::PullDiagnosticsParams { + path: path.clone(), + categories: RuleCategories::all(), + max_diagnostics: 100, + only: vec![], + skip: vec![], + }) + .expect("Unable to pull diagnostics") + .diagnostics; + + assert_eq!(diagnostics.len(), 0, "Expected no diagnostic"); +} diff --git a/crates/pgt_workspace/src/workspace/server/pg_query.rs b/crates/pgt_workspace/src/workspace/server/pg_query.rs index 05f1425d..bd9ffdfc 100644 --- a/crates/pgt_workspace/src/workspace/server/pg_query.rs +++ b/crates/pgt_workspace/src/workspace/server/pg_query.rs @@ -1,9 +1,11 @@ +use std::collections::HashMap; use std::num::NonZeroUsize; use std::sync::{Arc, Mutex}; use lru::LruCache; use pgt_query_ext::diagnostics::*; use pgt_text_size::TextRange; +use pgt_tokenizer::tokenize; use super::statement_identifier::StatementId; @@ -37,7 +39,7 @@ impl PgQueryStore { } let r = Arc::new( - pgt_query::parse(statement.content()) + pgt_query::parse(&convert_to_positional_params(statement.content())) .map_err(SyntaxDiagnostic::from) .and_then(|ast| { ast.into_root().ok_or_else(|| { @@ -87,10 +89,79 @@ impl PgQueryStore { } } +/// Converts named parameters in a SQL query string to positional parameters. +/// +/// This function scans the input SQL string for named parameters (e.g., `@param`, `:param`, `:'param'`) +/// and replaces them with positional parameters (e.g., `$1`, `$2`, etc.). +/// +/// It maintains the original spacing of the named parameters in the output string. +/// +/// Useful for preparing SQL queries for parsing or execution where named parameters are not supported. 
+pub fn convert_to_positional_params(text: &str) -> String { + let mut result = String::with_capacity(text.len()); + let mut param_mapping: HashMap<&str, usize> = HashMap::new(); + let mut param_index = 1; + let mut position = 0; + + for token in tokenize(text) { + let token_len = token.len as usize; + let token_text = &text[position..position + token_len]; + + if matches!(token.kind, pgt_tokenizer::TokenKind::NamedParam { .. }) { + let idx = match param_mapping.get(token_text) { + Some(&index) => index, + None => { + let index = param_index; + param_mapping.insert(token_text, index); + param_index += 1; + index + } + }; + + let replacement = format!("${}", idx); + let original_len = token_text.len(); + let replacement_len = replacement.len(); + + result.push_str(&replacement); + + // maintain original spacing + if replacement_len < original_len { + result.push_str(&" ".repeat(original_len - replacement_len)); + } + } else { + result.push_str(token_text); + } + + position += token_len; + } + + result +} + #[cfg(test)] mod tests { use super::*; + #[test] + fn test_convert_to_positional_params() { + let input = "select * from users where id = @one and name = :two and email = :'three';"; + let result = convert_to_positional_params(input); + assert_eq!( + result, + "select * from users where id = $1 and name = $2 and email = $3 ;" + ); + } + + #[test] + fn test_convert_to_positional_params_with_duplicates() { + let input = "select * from users where first_name = @one and starts_with(email, @one) and created_at > @two;"; + let result = convert_to_positional_params(input); + assert_eq!( + result, + "select * from users where first_name = $1 and starts_with(email, $1 ) and created_at > $2 ;" + ); + } + #[test] fn test_plpgsql_syntax_error() { let input = "