sqlparser/ast/visitor.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Recursive visitors for ast Nodes. See [`Visitor`] for more details.
19
20use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor, Value};
21use core::ops::ControlFlow;
22
23/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
24/// recursively visiting parsed SQL statements.
25///
26/// # Note
27///
28/// This trait should be automatically derived for sqlparser AST nodes
29/// using the [Visit](sqlparser_derive::Visit) proc macro.
30///
31/// ```text
32/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
33/// ```
34pub trait Visit {
35 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break>;
36}
37
38/// A type that can be visited by a [`VisitorMut`]. See [`VisitorMut`] for
39/// recursively visiting parsed SQL statements.
40///
41/// # Note
42///
43/// This trait should be automatically derived for sqlparser AST nodes
44/// using the [VisitMut](sqlparser_derive::VisitMut) proc macro.
45///
46/// ```text
47/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
48/// ```
49pub trait VisitMut {
50 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break>;
51}
52
53impl<T: Visit> Visit for Option<T> {
54 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
55 if let Some(s) = self {
56 s.visit(visitor)?;
57 }
58 ControlFlow::Continue(())
59 }
60}
61
62impl<T: Visit> Visit for Vec<T> {
63 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
64 for v in self {
65 v.visit(visitor)?;
66 }
67 ControlFlow::Continue(())
68 }
69}
70
71impl<T: Visit> Visit for Box<T> {
72 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
73 T::visit(self, visitor)
74 }
75}
76
77impl<T: VisitMut> VisitMut for Option<T> {
78 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
79 if let Some(s) = self {
80 s.visit(visitor)?;
81 }
82 ControlFlow::Continue(())
83 }
84}
85
86impl<T: VisitMut> VisitMut for Vec<T> {
87 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
88 for v in self {
89 v.visit(visitor)?;
90 }
91 ControlFlow::Continue(())
92 }
93}
94
95impl<T: VisitMut> VisitMut for Box<T> {
96 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
97 T::visit(self, visitor)
98 }
99}
100
101macro_rules! visit_noop {
102 ($($t:ty),+) => {
103 $(impl Visit for $t {
104 fn visit<V: Visitor>(&self, _visitor: &mut V) -> ControlFlow<V::Break> {
105 ControlFlow::Continue(())
106 }
107 })+
108 $(impl VisitMut for $t {
109 fn visit<V: VisitorMut>(&mut self, _visitor: &mut V) -> ControlFlow<V::Break> {
110 ControlFlow::Continue(())
111 }
112 })+
113 };
114}
115
116visit_noop!(u8, u16, u32, u64, i8, i16, i32, i64, char, bool, String);
117
118#[cfg(feature = "bigdecimal")]
119visit_noop!(bigdecimal::BigDecimal);
120
121/// A visitor that can be used to walk an AST tree.
122///
123/// `pre_visit_` methods are invoked before visiting all children of the
124/// node and `post_visit_` methods are invoked after visiting all
125/// children of the node.
126///
127/// # See also
128///
129/// These methods provide a more concise way of visiting nodes of a certain type:
130/// * [visit_relations]
131/// * [visit_expressions]
132/// * [visit_statements]
133///
134/// # Example
135/// ```
136/// # use sqlparser::parser::Parser;
137/// # use sqlparser::dialect::GenericDialect;
138/// # use sqlparser::ast::{Visit, Visitor, ObjectName, Expr};
139/// # use core::ops::ControlFlow;
140/// // A structure that records statements and relations
141/// #[derive(Default)]
142/// struct V {
143/// visited: Vec<String>,
144/// }
145///
146/// // Visit relations and exprs before children are visited (depth first walk)
147/// // Note you can also visit statements and visit exprs after children have been visited
148/// impl Visitor for V {
149/// type Break = ();
150///
151/// fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
152/// self.visited.push(format!("PRE: RELATION: {}", relation));
153/// ControlFlow::Continue(())
154/// }
155///
156/// fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
157/// self.visited.push(format!("PRE: EXPR: {}", expr));
158/// ControlFlow::Continue(())
159/// }
160/// }
161///
162/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
163/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
164/// .unwrap();
165///
166/// // Drive the visitor through the AST
167/// let mut visitor = V::default();
168/// statements.visit(&mut visitor);
169///
170/// // The visitor has visited statements and expressions in pre-traversal order
171/// let expected : Vec<_> = [
172/// "PRE: EXPR: a",
173/// "PRE: RELATION: foo",
174/// "PRE: EXPR: x IN (SELECT y FROM bar)",
175/// "PRE: EXPR: x",
176/// "PRE: EXPR: y",
177/// "PRE: RELATION: bar",
178/// ]
179/// .into_iter().map(|s| s.to_string()).collect();
180///
181/// assert_eq!(visitor.visited, expected);
182/// ```
183pub trait Visitor {
184 /// Type returned when the recursion returns early.
185 type Break;
186
187 /// Invoked for any queries that appear in the AST before visiting children
188 fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
189 ControlFlow::Continue(())
190 }
191
192 /// Invoked for any queries that appear in the AST after visiting children
193 fn post_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
194 ControlFlow::Continue(())
195 }
196
197 /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
198 fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
199 ControlFlow::Continue(())
200 }
201
202 /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
203 fn post_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
204 ControlFlow::Continue(())
205 }
206
207 /// Invoked for any table factors that appear in the AST before visiting children
208 fn pre_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
209 ControlFlow::Continue(())
210 }
211
212 /// Invoked for any table factors that appear in the AST after visiting children
213 fn post_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
214 ControlFlow::Continue(())
215 }
216
217 /// Invoked for any expressions that appear in the AST before visiting children
218 fn pre_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
219 ControlFlow::Continue(())
220 }
221
222 /// Invoked for any expressions that appear in the AST
223 fn post_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
224 ControlFlow::Continue(())
225 }
226
227 /// Invoked for any statements that appear in the AST before visiting children
228 fn pre_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
229 ControlFlow::Continue(())
230 }
231
232 /// Invoked for any statements that appear in the AST after visiting children
233 fn post_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
234 ControlFlow::Continue(())
235 }
236
237 /// Invoked for any Value that appear in the AST before visiting children
238 fn pre_visit_value(&mut self, _value: &Value) -> ControlFlow<Self::Break> {
239 ControlFlow::Continue(())
240 }
241
242 /// Invoked for any Value that appear in the AST after visiting children
243 fn post_visit_value(&mut self, _value: &Value) -> ControlFlow<Self::Break> {
244 ControlFlow::Continue(())
245 }
246}
247
248/// A visitor that can be used to mutate an AST tree.
249///
250/// `pre_visit_` methods are invoked before visiting all children of the
251/// node and `post_visit_` methods are invoked after visiting all
252/// children of the node.
253///
254/// # See also
255///
256/// These methods provide a more concise way of visiting nodes of a certain type:
257/// * [visit_relations_mut]
258/// * [visit_expressions_mut]
259/// * [visit_statements_mut]
260///
261/// # Example
262/// ```
263/// # use sqlparser::parser::Parser;
264/// # use sqlparser::dialect::GenericDialect;
265/// # use sqlparser::ast::{VisitMut, VisitorMut, ObjectName, Expr, Ident};
266/// # use core::ops::ControlFlow;
267///
268/// // A visitor that replaces "to_replace" with "replaced" in all expressions
269/// struct Replacer;
270///
271/// // Visit each expression after its children have been visited
272/// impl VisitorMut for Replacer {
273/// type Break = ();
274///
275/// fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
276/// if let Expr::Identifier(Ident{ value, ..}) = expr {
277/// *value = value.replace("to_replace", "replaced")
278/// }
279/// ControlFlow::Continue(())
280/// }
281/// }
282///
283/// let sql = "SELECT to_replace FROM foo where to_replace IN (SELECT to_replace FROM bar)";
284/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
285///
286/// // Drive the visitor through the AST
287/// statements.visit(&mut Replacer);
288///
289/// assert_eq!(statements[0].to_string(), "SELECT replaced FROM foo WHERE replaced IN (SELECT replaced FROM bar)");
290/// ```
291pub trait VisitorMut {
292 /// Type returned when the recursion returns early.
293 type Break;
294
295 /// Invoked for any queries that appear in the AST before visiting children
296 fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
297 ControlFlow::Continue(())
298 }
299
300 /// Invoked for any queries that appear in the AST after visiting children
301 fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
302 ControlFlow::Continue(())
303 }
304
305 /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
306 fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
307 ControlFlow::Continue(())
308 }
309
310 /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
311 fn post_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
312 ControlFlow::Continue(())
313 }
314
315 /// Invoked for any table factors that appear in the AST before visiting children
316 fn pre_visit_table_factor(
317 &mut self,
318 _table_factor: &mut TableFactor,
319 ) -> ControlFlow<Self::Break> {
320 ControlFlow::Continue(())
321 }
322
323 /// Invoked for any table factors that appear in the AST after visiting children
324 fn post_visit_table_factor(
325 &mut self,
326 _table_factor: &mut TableFactor,
327 ) -> ControlFlow<Self::Break> {
328 ControlFlow::Continue(())
329 }
330
331 /// Invoked for any expressions that appear in the AST before visiting children
332 fn pre_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
333 ControlFlow::Continue(())
334 }
335
336 /// Invoked for any expressions that appear in the AST
337 fn post_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
338 ControlFlow::Continue(())
339 }
340
341 /// Invoked for any statements that appear in the AST before visiting children
342 fn pre_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
343 ControlFlow::Continue(())
344 }
345
346 /// Invoked for any statements that appear in the AST after visiting children
347 fn post_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
348 ControlFlow::Continue(())
349 }
350
351 /// Invoked for any value that appear in the AST before visiting children
352 fn pre_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
353 ControlFlow::Continue(())
354 }
355
356 /// Invoked for any statements that appear in the AST after visiting children
357 fn post_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
358 ControlFlow::Continue(())
359 }
360}
361
362struct RelationVisitor<F>(F);
363
364impl<E, F: FnMut(&ObjectName) -> ControlFlow<E>> Visitor for RelationVisitor<F> {
365 type Break = E;
366
367 fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
368 self.0(relation)
369 }
370}
371
372impl<E, F: FnMut(&mut ObjectName) -> ControlFlow<E>> VisitorMut for RelationVisitor<F> {
373 type Break = E;
374
375 fn post_visit_relation(&mut self, relation: &mut ObjectName) -> ControlFlow<Self::Break> {
376 self.0(relation)
377 }
378}
379
380/// Invokes the provided closure on all relations (e.g. table names) present in `v`
381///
382/// # Example
383/// ```
384/// # use sqlparser::parser::Parser;
385/// # use sqlparser::dialect::GenericDialect;
386/// # use sqlparser::ast::{visit_relations};
387/// # use core::ops::ControlFlow;
388/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
389/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
390/// .unwrap();
391///
392/// // visit statements, capturing relations (table names)
393/// let mut visited = vec![];
394/// visit_relations(&statements, |relation| {
395/// visited.push(format!("RELATION: {}", relation));
396/// ControlFlow::<()>::Continue(())
397/// });
398///
399/// let expected : Vec<_> = [
400/// "RELATION: foo",
401/// "RELATION: bar",
402/// ]
403/// .into_iter().map(|s| s.to_string()).collect();
404///
405/// assert_eq!(visited, expected);
406/// ```
407pub fn visit_relations<V, E, F>(v: &V, f: F) -> ControlFlow<E>
408where
409 V: Visit,
410 F: FnMut(&ObjectName) -> ControlFlow<E>,
411{
412 let mut visitor = RelationVisitor(f);
413 v.visit(&mut visitor)?;
414 ControlFlow::Continue(())
415}
416
417/// Invokes the provided closure with a mutable reference to all relations (e.g. table names)
418/// present in `v`.
419///
420/// When the closure mutates its argument, the new mutated relation will not be visited again.
421///
422/// # Example
423/// ```
424/// # use sqlparser::parser::Parser;
425/// # use sqlparser::dialect::GenericDialect;
426/// # use sqlparser::ast::{ObjectName, ObjectNamePart, Ident, visit_relations_mut};
427/// # use core::ops::ControlFlow;
428/// let sql = "SELECT a FROM foo";
429/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql)
430/// .unwrap();
431///
432/// // visit statements, renaming table foo to bar
433/// visit_relations_mut(&mut statements, |table| {
434/// table.0[0] = ObjectNamePart::Identifier(Ident::new("bar"));
435/// ControlFlow::<()>::Continue(())
436/// });
437///
438/// assert_eq!(statements[0].to_string(), "SELECT a FROM bar");
439/// ```
440pub fn visit_relations_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
441where
442 V: VisitMut,
443 F: FnMut(&mut ObjectName) -> ControlFlow<E>,
444{
445 let mut visitor = RelationVisitor(f);
446 v.visit(&mut visitor)?;
447 ControlFlow::Continue(())
448}
449
450struct ExprVisitor<F>(F);
451
452impl<E, F: FnMut(&Expr) -> ControlFlow<E>> Visitor for ExprVisitor<F> {
453 type Break = E;
454
455 fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
456 self.0(expr)
457 }
458}
459
460impl<E, F: FnMut(&mut Expr) -> ControlFlow<E>> VisitorMut for ExprVisitor<F> {
461 type Break = E;
462
463 fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
464 self.0(expr)
465 }
466}
467
468/// Invokes the provided closure on all expressions (e.g. `1 + 2`) present in `v`
469///
470/// # Example
471/// ```
472/// # use sqlparser::parser::Parser;
473/// # use sqlparser::dialect::GenericDialect;
474/// # use sqlparser::ast::{visit_expressions};
475/// # use core::ops::ControlFlow;
476/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
477/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
478/// .unwrap();
479///
480/// // visit all expressions
481/// let mut visited = vec![];
482/// visit_expressions(&statements, |expr| {
483/// visited.push(format!("EXPR: {}", expr));
484/// ControlFlow::<()>::Continue(())
485/// });
486///
487/// let expected : Vec<_> = [
488/// "EXPR: a",
489/// "EXPR: x IN (SELECT y FROM bar)",
490/// "EXPR: x",
491/// "EXPR: y",
492/// ]
493/// .into_iter().map(|s| s.to_string()).collect();
494///
495/// assert_eq!(visited, expected);
496/// ```
497pub fn visit_expressions<V, E, F>(v: &V, f: F) -> ControlFlow<E>
498where
499 V: Visit,
500 F: FnMut(&Expr) -> ControlFlow<E>,
501{
502 let mut visitor = ExprVisitor(f);
503 v.visit(&mut visitor)?;
504 ControlFlow::Continue(())
505}
506
507/// Invokes the provided closure iteratively with a mutable reference to all expressions
508/// present in `v`.
509///
510/// This performs a depth-first search, so if the closure mutates the expression
511///
512/// # Example
513///
514/// ## Remove all select limits in sub-queries
515/// ```
516/// # use sqlparser::parser::Parser;
517/// # use sqlparser::dialect::GenericDialect;
518/// # use sqlparser::ast::{Expr, visit_expressions_mut, visit_statements_mut};
519/// # use core::ops::ControlFlow;
520/// let sql = "SELECT (SELECT y FROM z LIMIT 9) FROM t LIMIT 3";
521/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
522///
523/// // Remove all select limits in sub-queries
524/// visit_expressions_mut(&mut statements, |expr| {
525/// if let Expr::Subquery(q) = expr {
526/// q.limit_clause = None;
527/// }
528/// ControlFlow::<()>::Continue(())
529/// });
530///
531/// assert_eq!(statements[0].to_string(), "SELECT (SELECT y FROM z) FROM t LIMIT 3");
532/// ```
533///
534/// ## Wrap column name in function call
535///
536/// This demonstrates how to effectively replace an expression with another more complicated one
537/// that references the original. This example avoids unnecessary allocations by using the
538/// [`std::mem`] family of functions.
539///
540/// ```
541/// # use sqlparser::parser::Parser;
542/// # use sqlparser::dialect::GenericDialect;
543/// # use sqlparser::ast::*;
544/// # use core::ops::ControlFlow;
545/// let sql = "SELECT x, y FROM t";
546/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
547///
548/// visit_expressions_mut(&mut statements, |expr| {
549/// if matches!(expr, Expr::Identifier(col_name) if col_name.value == "x") {
550/// let old_expr = std::mem::replace(expr, Expr::value(Value::Null));
551/// *expr = Expr::Function(Function {
552/// name: ObjectName::from(vec![Ident::new("f")]),
553/// uses_odbc_syntax: false,
554/// args: FunctionArguments::List(FunctionArgumentList {
555/// duplicate_treatment: None,
556/// args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(old_expr))],
557/// clauses: vec![],
558/// }),
559/// null_treatment: None,
560/// filter: None,
561/// over: None,
562/// parameters: FunctionArguments::None,
563/// within_group: vec![],
564/// });
565/// }
566/// ControlFlow::<()>::Continue(())
567/// });
568///
569/// assert_eq!(statements[0].to_string(), "SELECT f(x), y FROM t");
570/// ```
571pub fn visit_expressions_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
572where
573 V: VisitMut,
574 F: FnMut(&mut Expr) -> ControlFlow<E>,
575{
576 v.visit(&mut ExprVisitor(f))?;
577 ControlFlow::Continue(())
578}
579
580struct StatementVisitor<F>(F);
581
582impl<E, F: FnMut(&Statement) -> ControlFlow<E>> Visitor for StatementVisitor<F> {
583 type Break = E;
584
585 fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
586 self.0(statement)
587 }
588}
589
590impl<E, F: FnMut(&mut Statement) -> ControlFlow<E>> VisitorMut for StatementVisitor<F> {
591 type Break = E;
592
593 fn post_visit_statement(&mut self, statement: &mut Statement) -> ControlFlow<Self::Break> {
594 self.0(statement)
595 }
596}
597
598/// Invokes the provided closure iteratively with a mutable reference to all statements
599/// present in `v` (e.g. `SELECT`, `CREATE TABLE`, etc).
600///
601/// # Example
602/// ```
603/// # use sqlparser::parser::Parser;
604/// # use sqlparser::dialect::GenericDialect;
605/// # use sqlparser::ast::{visit_statements};
606/// # use core::ops::ControlFlow;
607/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar); CREATE TABLE baz(q int)";
608/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
609/// .unwrap();
610///
611/// // visit all statements
612/// let mut visited = vec![];
613/// visit_statements(&statements, |stmt| {
614/// visited.push(format!("STATEMENT: {}", stmt));
615/// ControlFlow::<()>::Continue(())
616/// });
617///
618/// let expected : Vec<_> = [
619/// "STATEMENT: SELECT a FROM foo WHERE x IN (SELECT y FROM bar)",
620/// "STATEMENT: CREATE TABLE baz (q INT)"
621/// ]
622/// .into_iter().map(|s| s.to_string()).collect();
623///
624/// assert_eq!(visited, expected);
625/// ```
626pub fn visit_statements<V, E, F>(v: &V, f: F) -> ControlFlow<E>
627where
628 V: Visit,
629 F: FnMut(&Statement) -> ControlFlow<E>,
630{
631 let mut visitor = StatementVisitor(f);
632 v.visit(&mut visitor)?;
633 ControlFlow::Continue(())
634}
635
636/// Invokes the provided closure on all statements (e.g. `SELECT`, `CREATE TABLE`, etc) present in `v`
637///
638/// # Example
639/// ```
640/// # use sqlparser::parser::Parser;
641/// # use sqlparser::dialect::GenericDialect;
642/// # use sqlparser::ast::{Statement, visit_statements_mut};
643/// # use core::ops::ControlFlow;
644/// let sql = "SELECT x FROM foo LIMIT 9+$limit; SELECT * FROM t LIMIT f()";
645/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
646///
647/// // Remove all select limits in outer statements (not in sub-queries)
648/// visit_statements_mut(&mut statements, |stmt| {
649/// if let Statement::Query(q) = stmt {
650/// q.limit_clause = None;
651/// }
652/// ControlFlow::<()>::Continue(())
653/// });
654///
655/// assert_eq!(statements[0].to_string(), "SELECT x FROM foo");
656/// assert_eq!(statements[1].to_string(), "SELECT * FROM t");
657/// ```
658pub fn visit_statements_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
659where
660 V: VisitMut,
661 F: FnMut(&mut Statement) -> ControlFlow<E>,
662{
663 v.visit(&mut StatementVisitor(f))?;
664 ControlFlow::Continue(())
665}
666
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Statement;
    use crate::dialect::GenericDialect;
    use crate::parser::Parser;
    use crate::tokenizer::Tokenizer;

    /// Records every hook invocation as a `"PRE: ..."` / `"POST: ..."` line so
    /// tests can assert the exact traversal order.
    #[derive(Default)]
    struct TestVisitor {
        visited: Vec<String>,
    }

    impl Visitor for TestVisitor {
        type Break = ();

        /// Invoked for any queries that appear in the AST before visiting children
        fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: QUERY: {query}"));
            ControlFlow::Continue(())
        }

        /// Invoked for any queries that appear in the AST after visiting children
        fn post_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: QUERY: {query}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: RELATION: {relation}"));
            ControlFlow::Continue(())
        }

        fn post_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: RELATION: {relation}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_table_factor(
            &mut self,
            table_factor: &TableFactor,
        ) -> ControlFlow<Self::Break> {
            self.visited
                .push(format!("PRE: TABLE FACTOR: {table_factor}"));
            ControlFlow::Continue(())
        }

        fn post_visit_table_factor(
            &mut self,
            table_factor: &TableFactor,
        ) -> ControlFlow<Self::Break> {
            self.visited
                .push(format!("POST: TABLE FACTOR: {table_factor}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: EXPR: {expr}"));
            ControlFlow::Continue(())
        }

        fn post_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: EXPR: {expr}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: STATEMENT: {statement}"));
            ControlFlow::Continue(())
        }

        fn post_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: STATEMENT: {statement}"));
            ControlFlow::Continue(())
        }
    }

    /// Parses a single statement from `sql` and drives `visitor` over it.
    fn do_visit<V: Visitor>(sql: &str, visitor: &mut V) -> Statement {
        let dialect = GenericDialect {};
        let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
        let s = Parser::new(&dialect)
            .with_tokens(tokens)
            .parse_statement()
            .unwrap();

        s.visit(visitor);
        s
    }

    #[test]
    fn test_sql() {
        let tests = vec![
            (
                "SELECT * from table_name as my_table",
                vec![
                    "PRE: STATEMENT: SELECT * FROM table_name AS my_table",
                    "PRE: QUERY: SELECT * FROM table_name AS my_table",
                    "PRE: TABLE FACTOR: table_name AS my_table",
                    "PRE: RELATION: table_name",
                    "POST: RELATION: table_name",
                    "POST: TABLE FACTOR: table_name AS my_table",
                    "POST: QUERY: SELECT * FROM table_name AS my_table",
                    "POST: STATEMENT: SELECT * FROM table_name AS my_table",
                ],
            ),
            (
                "SELECT * from t1 join t2 on t1.id = t2.t1_id",
                vec![
                    "PRE: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                    "PRE: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                    "PRE: TABLE FACTOR: t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: TABLE FACTOR: t1",
                    "PRE: TABLE FACTOR: t2",
                    "PRE: RELATION: t2",
                    "POST: RELATION: t2",
                    "POST: TABLE FACTOR: t2",
                    "PRE: EXPR: t1.id = t2.t1_id",
                    "PRE: EXPR: t1.id",
                    "POST: EXPR: t1.id",
                    "PRE: EXPR: t2.t1_id",
                    "POST: EXPR: t2.t1_id",
                    "POST: EXPR: t1.id = t2.t1_id",
                    "POST: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                    "POST: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                ],
            ),
            (
                "SELECT * from t1 where EXISTS(SELECT column from t2)",
                vec![
                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                    "PRE: TABLE FACTOR: t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: TABLE FACTOR: t1",
                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
                    "PRE: QUERY: SELECT column FROM t2",
                    "PRE: EXPR: column",
                    "POST: EXPR: column",
                    "PRE: TABLE FACTOR: t2",
                    "PRE: RELATION: t2",
                    "POST: RELATION: t2",
                    "POST: TABLE FACTOR: t2",
                    "POST: QUERY: SELECT column FROM t2",
                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                ],
            ),
            (
                "SELECT * from t1 where EXISTS(SELECT column from t2) UNION SELECT * from t3",
                vec![
                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                    "PRE: TABLE FACTOR: t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: TABLE FACTOR: t1",
                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
                    "PRE: QUERY: SELECT column FROM t2",
                    "PRE: EXPR: column",
                    "POST: EXPR: column",
                    "PRE: TABLE FACTOR: t2",
                    "PRE: RELATION: t2",
                    "POST: RELATION: t2",
                    "POST: TABLE FACTOR: t2",
                    "POST: QUERY: SELECT column FROM t2",
                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
                    "PRE: TABLE FACTOR: t3",
                    "PRE: RELATION: t3",
                    "POST: RELATION: t3",
                    "POST: TABLE FACTOR: t3",
                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                ],
            ),
            (
                concat!(
                    "SELECT * FROM monthly_sales ",
                    "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
                    "ORDER BY EMPID"
                ),
                vec![
                    "PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                    "PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                    "PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
                    "PRE: TABLE FACTOR: monthly_sales",
                    "PRE: RELATION: monthly_sales",
                    "POST: RELATION: monthly_sales",
                    "POST: TABLE FACTOR: monthly_sales",
                    "PRE: EXPR: SUM(a.amount)",
                    "PRE: EXPR: a.amount",
                    "POST: EXPR: a.amount",
                    "POST: EXPR: SUM(a.amount)",
                    "PRE: EXPR: 'JAN'",
                    "POST: EXPR: 'JAN'",
                    "PRE: EXPR: 'FEB'",
                    "POST: EXPR: 'FEB'",
                    "PRE: EXPR: 'MAR'",
                    "POST: EXPR: 'MAR'",
                    "PRE: EXPR: 'APR'",
                    "POST: EXPR: 'APR'",
                    "POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
                    "PRE: EXPR: EMPID",
                    "POST: EXPR: EMPID",
                    "POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                    "POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                ]
            ),
            (
                "SHOW COLUMNS FROM t1",
                vec![
                    "PRE: STATEMENT: SHOW COLUMNS FROM t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: STATEMENT: SHOW COLUMNS FROM t1",
                ],
            ),
        ];
        for (sql, expected) in tests {
            let mut visitor = TestVisitor::default();
            let _ = do_visit(sql, &mut visitor);
            let actual: Vec<_> = visitor.visited.iter().map(|x| x.as_str()).collect();
            assert_eq!(actual, expected)
        }
    }

    struct QuickVisitor; // [`TestVisitor`] is too slow to iterate over thousands of nodes

    impl Visitor for QuickVisitor {
        type Break = ();
    }

    #[test]
    fn overflow() {
        let cond = (0..1000)
            .map(|n| format!("X = {}", n))
            .collect::<Vec<_>>()
            .join(" OR ");
        let sql = format!("SELECT x where {0}", cond);

        let dialect = GenericDialect {};
        let tokens = Tokenizer::new(&dialect, sql.as_str()).tokenize().unwrap();
        let s = Parser::new(&dialect)
            .with_tokens(tokens)
            .parse_statement()
            .unwrap();

        let mut visitor = QuickVisitor {};
        s.visit(&mut visitor);
    }

    // A `Break` returned by the closure must end the walk and be handed back
    // to the caller: only the first relation is seen.
    #[test]
    fn visit_relations_stops_at_first_break() {
        let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
        let statements = Parser::parse_sql(&GenericDialect {}, sql).unwrap();

        let mut visited = vec![];
        let result = visit_relations(&statements, |relation| {
            visited.push(relation.to_string());
            ControlFlow::Break(relation.to_string())
        });

        assert_eq!(result, ControlFlow::Break("foo".to_string()));
        assert_eq!(visited, vec!["foo".to_string()]);
    }
}
944
#[cfg(test)]
mod visit_mut_tests {
    use crate::ast::{Statement, Value, VisitMut, VisitorMut};
    use crate::dialect::GenericDialect;
    use crate::parser::Parser;
    use crate::tokenizer::Tokenizer;
    use core::ops::ControlFlow;

    /// Overwrites every visited value with `'REDACTED_<n>'`, numbering the
    /// replacements from 1 in visiting order.
    #[derive(Default)]
    struct MutatorVisitor {
        index: u64,
    }

    impl VisitorMut for MutatorVisitor {
        type Break = ();

        fn pre_visit_value(&mut self, value: &mut Value) -> ControlFlow<Self::Break> {
            self.index += 1;
            let replacement = Value::SingleQuotedString(format!("REDACTED_{}", self.index));
            *value = replacement;
            ControlFlow::Continue(())
        }

        fn post_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
            ControlFlow::Continue(())
        }
    }

    /// Parses a single statement from `sql` and drives `visitor` over it mutably.
    fn do_visit_mut<V: VisitorMut>(sql: &str, visitor: &mut V) -> Statement {
        let dialect = GenericDialect {};
        let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
        let mut statement = Parser::new(&dialect)
            .with_tokens(tokens)
            .parse_statement()
            .unwrap();

        statement.visit(visitor);
        statement
    }

    #[test]
    fn test_value_redact() {
        let cases = [(
            concat!(
                "SELECT * FROM monthly_sales ",
                "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
                "ORDER BY EMPID"
            ),
            concat!(
                "SELECT * FROM monthly_sales ",
                "PIVOT(SUM(a.amount) FOR a.MONTH IN ('REDACTED_1', 'REDACTED_2', 'REDACTED_3', 'REDACTED_4')) AS p (c, d) ",
                "ORDER BY EMPID"
            ),
        )];

        for (sql, expected) in cases {
            let mut visitor = MutatorVisitor::default();
            let redacted = do_visit_mut(sql, &mut visitor);
            assert_eq!(redacted.to_string(), expected);
        }
    }
}