sqlparser/ast/
visitor.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Recursive visitors for ast Nodes. See [`Visitor`] for more details.
19
20use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor, Value};
21use core::ops::ControlFlow;
22
23/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
24/// recursively visiting parsed SQL statements.
25///
26/// # Note
27///
28/// This trait should be automatically derived for sqlparser AST nodes
29/// using the [Visit](sqlparser_derive::Visit) proc macro.
30///
31/// ```text
32/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
33/// ```
34pub trait Visit {
35    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break>;
36}
37
38/// A type that can be visited by a [`VisitorMut`]. See [`VisitorMut`] for
39/// recursively visiting parsed SQL statements.
40///
41/// # Note
42///
43/// This trait should be automatically derived for sqlparser AST nodes
44/// using the [VisitMut](sqlparser_derive::VisitMut) proc macro.
45///
46/// ```text
47/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
48/// ```
49pub trait VisitMut {
50    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break>;
51}
52
53impl<T: Visit> Visit for Option<T> {
54    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
55        if let Some(s) = self {
56            s.visit(visitor)?;
57        }
58        ControlFlow::Continue(())
59    }
60}
61
62impl<T: Visit> Visit for Vec<T> {
63    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
64        for v in self {
65            v.visit(visitor)?;
66        }
67        ControlFlow::Continue(())
68    }
69}
70
71impl<T: Visit> Visit for Box<T> {
72    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
73        T::visit(self, visitor)
74    }
75}
76
77impl<T: VisitMut> VisitMut for Option<T> {
78    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
79        if let Some(s) = self {
80            s.visit(visitor)?;
81        }
82        ControlFlow::Continue(())
83    }
84}
85
86impl<T: VisitMut> VisitMut for Vec<T> {
87    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
88        for v in self {
89            v.visit(visitor)?;
90        }
91        ControlFlow::Continue(())
92    }
93}
94
95impl<T: VisitMut> VisitMut for Box<T> {
96    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
97        T::visit(self, visitor)
98    }
99}
100
/// Implements [`Visit`] and [`VisitMut`] for each listed type as a no-op:
/// the visitor is never consulted and the traversal always continues.
macro_rules! visit_noop {
    ($($leaf:ty),+) => {
        $(
            impl Visit for $leaf {
                fn visit<V: Visitor>(&self, _visitor: &mut V) -> ControlFlow<V::Break> {
                    ControlFlow::Continue(())
                }
            }

            impl VisitMut for $leaf {
                fn visit<V: VisitorMut>(&mut self, _visitor: &mut V) -> ControlFlow<V::Break> {
                    ControlFlow::Continue(())
                }
            }
        )+
    };
}
115
116visit_noop!(u8, u16, u32, u64, i8, i16, i32, i64, char, bool, String);
117
118#[cfg(feature = "bigdecimal")]
119visit_noop!(bigdecimal::BigDecimal);
120
121/// A visitor that can be used to walk an AST tree.
122///
123/// `pre_visit_` methods are invoked before visiting all children of the
124/// node and `post_visit_` methods are invoked after visiting all
125/// children of the node.
126///
127/// # See also
128///
129/// These methods provide a more concise way of visiting nodes of a certain type:
130/// * [visit_relations]
131/// * [visit_expressions]
132/// * [visit_statements]
133///
134/// # Example
135/// ```
136/// # use sqlparser::parser::Parser;
137/// # use sqlparser::dialect::GenericDialect;
138/// # use sqlparser::ast::{Visit, Visitor, ObjectName, Expr};
139/// # use core::ops::ControlFlow;
140/// // A structure that records statements and relations
141/// #[derive(Default)]
142/// struct V {
143///    visited: Vec<String>,
144/// }
145///
146/// // Visit relations and exprs before children are visited (depth first walk)
147/// // Note you can also visit statements and visit exprs after children have been visited
148/// impl Visitor for V {
149///   type Break = ();
150///
151///   fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
152///     self.visited.push(format!("PRE: RELATION: {}", relation));
153///     ControlFlow::Continue(())
154///   }
155///
156///   fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
157///     self.visited.push(format!("PRE: EXPR: {}", expr));
158///     ControlFlow::Continue(())
159///   }
160/// }
161///
162/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
163/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
164///    .unwrap();
165///
166/// // Drive the visitor through the AST
167/// let mut visitor = V::default();
168/// statements.visit(&mut visitor);
169///
170/// // The visitor has visited statements and expressions in pre-traversal order
171/// let expected : Vec<_> = [
172///   "PRE: EXPR: a",
173///   "PRE: RELATION: foo",
174///   "PRE: EXPR: x IN (SELECT y FROM bar)",
175///   "PRE: EXPR: x",
176///   "PRE: EXPR: y",
177///   "PRE: RELATION: bar",
178/// ]
179///   .into_iter().map(|s| s.to_string()).collect();
180///
181/// assert_eq!(visitor.visited, expected);
182/// ```
183pub trait Visitor {
184    /// Type returned when the recursion returns early.
185    type Break;
186
187    /// Invoked for any queries that appear in the AST before visiting children
188    fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
189        ControlFlow::Continue(())
190    }
191
192    /// Invoked for any queries that appear in the AST after visiting children
193    fn post_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
194        ControlFlow::Continue(())
195    }
196
197    /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
198    fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
199        ControlFlow::Continue(())
200    }
201
202    /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
203    fn post_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
204        ControlFlow::Continue(())
205    }
206
207    /// Invoked for any table factors that appear in the AST before visiting children
208    fn pre_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
209        ControlFlow::Continue(())
210    }
211
212    /// Invoked for any table factors that appear in the AST after visiting children
213    fn post_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
214        ControlFlow::Continue(())
215    }
216
217    /// Invoked for any expressions that appear in the AST before visiting children
218    fn pre_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
219        ControlFlow::Continue(())
220    }
221
222    /// Invoked for any expressions that appear in the AST
223    fn post_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
224        ControlFlow::Continue(())
225    }
226
227    /// Invoked for any statements that appear in the AST before visiting children
228    fn pre_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
229        ControlFlow::Continue(())
230    }
231
232    /// Invoked for any statements that appear in the AST after visiting children
233    fn post_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
234        ControlFlow::Continue(())
235    }
236
237    /// Invoked for any Value that appear in the AST before visiting children
238    fn pre_visit_value(&mut self, _value: &Value) -> ControlFlow<Self::Break> {
239        ControlFlow::Continue(())
240    }
241
242    /// Invoked for any Value that appear in the AST after visiting children
243    fn post_visit_value(&mut self, _value: &Value) -> ControlFlow<Self::Break> {
244        ControlFlow::Continue(())
245    }
246}
247
248/// A visitor that can be used to mutate an AST tree.
249///
250/// `pre_visit_` methods are invoked before visiting all children of the
251/// node and `post_visit_` methods are invoked after visiting all
252/// children of the node.
253///
254/// # See also
255///
256/// These methods provide a more concise way of visiting nodes of a certain type:
257/// * [visit_relations_mut]
258/// * [visit_expressions_mut]
259/// * [visit_statements_mut]
260///
261/// # Example
262/// ```
263/// # use sqlparser::parser::Parser;
264/// # use sqlparser::dialect::GenericDialect;
265/// # use sqlparser::ast::{VisitMut, VisitorMut, ObjectName, Expr, Ident};
266/// # use core::ops::ControlFlow;
267///
268/// // A visitor that replaces "to_replace" with "replaced" in all expressions
269/// struct Replacer;
270///
271/// // Visit each expression after its children have been visited
272/// impl VisitorMut for Replacer {
273///   type Break = ();
274///
275///   fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
276///     if let Expr::Identifier(Ident{ value, ..}) = expr {
277///         *value = value.replace("to_replace", "replaced")
278///     }
279///     ControlFlow::Continue(())
280///   }
281/// }
282///
283/// let sql = "SELECT to_replace FROM foo where to_replace IN (SELECT to_replace FROM bar)";
284/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
285///
286/// // Drive the visitor through the AST
287/// statements.visit(&mut Replacer);
288///
289/// assert_eq!(statements[0].to_string(), "SELECT replaced FROM foo WHERE replaced IN (SELECT replaced FROM bar)");
290/// ```
291pub trait VisitorMut {
292    /// Type returned when the recursion returns early.
293    type Break;
294
295    /// Invoked for any queries that appear in the AST before visiting children
296    fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
297        ControlFlow::Continue(())
298    }
299
300    /// Invoked for any queries that appear in the AST after visiting children
301    fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
302        ControlFlow::Continue(())
303    }
304
305    /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
306    fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
307        ControlFlow::Continue(())
308    }
309
310    /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
311    fn post_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
312        ControlFlow::Continue(())
313    }
314
315    /// Invoked for any table factors that appear in the AST before visiting children
316    fn pre_visit_table_factor(
317        &mut self,
318        _table_factor: &mut TableFactor,
319    ) -> ControlFlow<Self::Break> {
320        ControlFlow::Continue(())
321    }
322
323    /// Invoked for any table factors that appear in the AST after visiting children
324    fn post_visit_table_factor(
325        &mut self,
326        _table_factor: &mut TableFactor,
327    ) -> ControlFlow<Self::Break> {
328        ControlFlow::Continue(())
329    }
330
331    /// Invoked for any expressions that appear in the AST before visiting children
332    fn pre_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
333        ControlFlow::Continue(())
334    }
335
336    /// Invoked for any expressions that appear in the AST
337    fn post_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
338        ControlFlow::Continue(())
339    }
340
341    /// Invoked for any statements that appear in the AST before visiting children
342    fn pre_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
343        ControlFlow::Continue(())
344    }
345
346    /// Invoked for any statements that appear in the AST after visiting children
347    fn post_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
348        ControlFlow::Continue(())
349    }
350
351    /// Invoked for any value that appear in the AST before visiting children
352    fn pre_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
353        ControlFlow::Continue(())
354    }
355
356    /// Invoked for any statements that appear in the AST after visiting children
357    fn post_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
358        ControlFlow::Continue(())
359    }
360}
361
362struct RelationVisitor<F>(F);
363
364impl<E, F: FnMut(&ObjectName) -> ControlFlow<E>> Visitor for RelationVisitor<F> {
365    type Break = E;
366
367    fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
368        self.0(relation)
369    }
370}
371
372impl<E, F: FnMut(&mut ObjectName) -> ControlFlow<E>> VisitorMut for RelationVisitor<F> {
373    type Break = E;
374
375    fn post_visit_relation(&mut self, relation: &mut ObjectName) -> ControlFlow<Self::Break> {
376        self.0(relation)
377    }
378}
379
380/// Invokes the provided closure on all relations (e.g. table names) present in `v`
381///
382/// # Example
383/// ```
384/// # use sqlparser::parser::Parser;
385/// # use sqlparser::dialect::GenericDialect;
386/// # use sqlparser::ast::{visit_relations};
387/// # use core::ops::ControlFlow;
388/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
389/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
390///    .unwrap();
391///
392/// // visit statements, capturing relations (table names)
393/// let mut visited = vec![];
394/// visit_relations(&statements, |relation| {
395///   visited.push(format!("RELATION: {}", relation));
396///   ControlFlow::<()>::Continue(())
397/// });
398///
399/// let expected : Vec<_> = [
400///   "RELATION: foo",
401///   "RELATION: bar",
402/// ]
403///   .into_iter().map(|s| s.to_string()).collect();
404///
405/// assert_eq!(visited, expected);
406/// ```
407pub fn visit_relations<V, E, F>(v: &V, f: F) -> ControlFlow<E>
408where
409    V: Visit,
410    F: FnMut(&ObjectName) -> ControlFlow<E>,
411{
412    let mut visitor = RelationVisitor(f);
413    v.visit(&mut visitor)?;
414    ControlFlow::Continue(())
415}
416
417/// Invokes the provided closure with a mutable reference to all relations (e.g. table names)
418/// present in `v`.
419///
420/// When the closure mutates its argument, the new mutated relation will not be visited again.
421///
422/// # Example
423/// ```
424/// # use sqlparser::parser::Parser;
425/// # use sqlparser::dialect::GenericDialect;
426/// # use sqlparser::ast::{ObjectName, ObjectNamePart, Ident, visit_relations_mut};
427/// # use core::ops::ControlFlow;
428/// let sql = "SELECT a FROM foo";
429/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql)
430///    .unwrap();
431///
432/// // visit statements, renaming table foo to bar
433/// visit_relations_mut(&mut statements, |table| {
434///   table.0[0] = ObjectNamePart::Identifier(Ident::new("bar"));
435///   ControlFlow::<()>::Continue(())
436/// });
437///
438/// assert_eq!(statements[0].to_string(), "SELECT a FROM bar");
439/// ```
440pub fn visit_relations_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
441where
442    V: VisitMut,
443    F: FnMut(&mut ObjectName) -> ControlFlow<E>,
444{
445    let mut visitor = RelationVisitor(f);
446    v.visit(&mut visitor)?;
447    ControlFlow::Continue(())
448}
449
450struct ExprVisitor<F>(F);
451
452impl<E, F: FnMut(&Expr) -> ControlFlow<E>> Visitor for ExprVisitor<F> {
453    type Break = E;
454
455    fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
456        self.0(expr)
457    }
458}
459
460impl<E, F: FnMut(&mut Expr) -> ControlFlow<E>> VisitorMut for ExprVisitor<F> {
461    type Break = E;
462
463    fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
464        self.0(expr)
465    }
466}
467
468/// Invokes the provided closure on all expressions (e.g. `1 + 2`) present in `v`
469///
470/// # Example
471/// ```
472/// # use sqlparser::parser::Parser;
473/// # use sqlparser::dialect::GenericDialect;
474/// # use sqlparser::ast::{visit_expressions};
475/// # use core::ops::ControlFlow;
476/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
477/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
478///    .unwrap();
479///
480/// // visit all expressions
481/// let mut visited = vec![];
482/// visit_expressions(&statements, |expr| {
483///   visited.push(format!("EXPR: {}", expr));
484///   ControlFlow::<()>::Continue(())
485/// });
486///
487/// let expected : Vec<_> = [
488///   "EXPR: a",
489///   "EXPR: x IN (SELECT y FROM bar)",
490///   "EXPR: x",
491///   "EXPR: y",
492/// ]
493///   .into_iter().map(|s| s.to_string()).collect();
494///
495/// assert_eq!(visited, expected);
496/// ```
497pub fn visit_expressions<V, E, F>(v: &V, f: F) -> ControlFlow<E>
498where
499    V: Visit,
500    F: FnMut(&Expr) -> ControlFlow<E>,
501{
502    let mut visitor = ExprVisitor(f);
503    v.visit(&mut visitor)?;
504    ControlFlow::Continue(())
505}
506
507/// Invokes the provided closure iteratively with a mutable reference to all expressions
508/// present in `v`.
509///
510/// This performs a depth-first search, so if the closure mutates the expression
511///
512/// # Example
513///
514/// ## Remove all select limits in sub-queries
515/// ```
516/// # use sqlparser::parser::Parser;
517/// # use sqlparser::dialect::GenericDialect;
518/// # use sqlparser::ast::{Expr, visit_expressions_mut, visit_statements_mut};
519/// # use core::ops::ControlFlow;
520/// let sql = "SELECT (SELECT y FROM z LIMIT 9) FROM t LIMIT 3";
521/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
522///
523/// // Remove all select limits in sub-queries
524/// visit_expressions_mut(&mut statements, |expr| {
525///   if let Expr::Subquery(q) = expr {
526///      q.limit_clause = None;
527///   }
528///   ControlFlow::<()>::Continue(())
529/// });
530///
531/// assert_eq!(statements[0].to_string(), "SELECT (SELECT y FROM z) FROM t LIMIT 3");
532/// ```
533///
534/// ## Wrap column name in function call
535///
536/// This demonstrates how to effectively replace an expression with another more complicated one
537/// that references the original. This example avoids unnecessary allocations by using the
538/// [`std::mem`] family of functions.
539///
540/// ```
541/// # use sqlparser::parser::Parser;
542/// # use sqlparser::dialect::GenericDialect;
543/// # use sqlparser::ast::*;
544/// # use core::ops::ControlFlow;
545/// let sql = "SELECT x, y FROM t";
546/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
547///
548/// visit_expressions_mut(&mut statements, |expr| {
549///   if matches!(expr, Expr::Identifier(col_name) if col_name.value == "x") {
550///     let old_expr = std::mem::replace(expr, Expr::value(Value::Null));
551///     *expr = Expr::Function(Function {
552///           name: ObjectName::from(vec![Ident::new("f")]),
553///           uses_odbc_syntax: false,
554///           args: FunctionArguments::List(FunctionArgumentList {
555///               duplicate_treatment: None,
556///               args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(old_expr))],
557///               clauses: vec![],
558///           }),
559///           null_treatment: None,
560///           filter: None,
561///           over: None,
562///           parameters: FunctionArguments::None,
563///           within_group: vec![],
564///      });
565///   }
566///   ControlFlow::<()>::Continue(())
567/// });
568///
569/// assert_eq!(statements[0].to_string(), "SELECT f(x), y FROM t");
570/// ```
571pub fn visit_expressions_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
572where
573    V: VisitMut,
574    F: FnMut(&mut Expr) -> ControlFlow<E>,
575{
576    v.visit(&mut ExprVisitor(f))?;
577    ControlFlow::Continue(())
578}
579
580struct StatementVisitor<F>(F);
581
582impl<E, F: FnMut(&Statement) -> ControlFlow<E>> Visitor for StatementVisitor<F> {
583    type Break = E;
584
585    fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
586        self.0(statement)
587    }
588}
589
590impl<E, F: FnMut(&mut Statement) -> ControlFlow<E>> VisitorMut for StatementVisitor<F> {
591    type Break = E;
592
593    fn post_visit_statement(&mut self, statement: &mut Statement) -> ControlFlow<Self::Break> {
594        self.0(statement)
595    }
596}
597
598/// Invokes the provided closure iteratively with a mutable reference to all statements
599/// present in `v` (e.g. `SELECT`, `CREATE TABLE`, etc).
600///
601/// # Example
602/// ```
603/// # use sqlparser::parser::Parser;
604/// # use sqlparser::dialect::GenericDialect;
605/// # use sqlparser::ast::{visit_statements};
606/// # use core::ops::ControlFlow;
607/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar); CREATE TABLE baz(q int)";
608/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
609///    .unwrap();
610///
611/// // visit all statements
612/// let mut visited = vec![];
613/// visit_statements(&statements, |stmt| {
614///   visited.push(format!("STATEMENT: {}", stmt));
615///   ControlFlow::<()>::Continue(())
616/// });
617///
618/// let expected : Vec<_> = [
619///   "STATEMENT: SELECT a FROM foo WHERE x IN (SELECT y FROM bar)",
620///   "STATEMENT: CREATE TABLE baz (q INT)"
621/// ]
622///   .into_iter().map(|s| s.to_string()).collect();
623///
624/// assert_eq!(visited, expected);
625/// ```
626pub fn visit_statements<V, E, F>(v: &V, f: F) -> ControlFlow<E>
627where
628    V: Visit,
629    F: FnMut(&Statement) -> ControlFlow<E>,
630{
631    let mut visitor = StatementVisitor(f);
632    v.visit(&mut visitor)?;
633    ControlFlow::Continue(())
634}
635
636/// Invokes the provided closure on all statements (e.g. `SELECT`, `CREATE TABLE`, etc) present in `v`
637///
638/// # Example
639/// ```
640/// # use sqlparser::parser::Parser;
641/// # use sqlparser::dialect::GenericDialect;
642/// # use sqlparser::ast::{Statement, visit_statements_mut};
643/// # use core::ops::ControlFlow;
644/// let sql = "SELECT x FROM foo LIMIT 9+$limit; SELECT * FROM t LIMIT f()";
645/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
646///
647/// // Remove all select limits in outer statements (not in sub-queries)
648/// visit_statements_mut(&mut statements, |stmt| {
649///   if let Statement::Query(q) = stmt {
650///      q.limit_clause = None;
651///   }
652///   ControlFlow::<()>::Continue(())
653/// });
654///
655/// assert_eq!(statements[0].to_string(), "SELECT x FROM foo");
656/// assert_eq!(statements[1].to_string(), "SELECT * FROM t");
657/// ```
658pub fn visit_statements_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
659where
660    V: VisitMut,
661    F: FnMut(&mut Statement) -> ControlFlow<E>,
662{
663    v.visit(&mut StatementVisitor(f))?;
664    ControlFlow::Continue(())
665}
666
667#[cfg(test)]
668mod tests {
669    use super::*;
670    use crate::ast::Statement;
671    use crate::dialect::GenericDialect;
672    use crate::parser::Parser;
673    use crate::tokenizer::Tokenizer;
674
675    #[derive(Default)]
676    struct TestVisitor {
677        visited: Vec<String>,
678    }
679
680    impl Visitor for TestVisitor {
681        type Break = ();
682
683        /// Invoked for any queries that appear in the AST before visiting children
684        fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
685            self.visited.push(format!("PRE: QUERY: {query}"));
686            ControlFlow::Continue(())
687        }
688
689        /// Invoked for any queries that appear in the AST after visiting children
690        fn post_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
691            self.visited.push(format!("POST: QUERY: {query}"));
692            ControlFlow::Continue(())
693        }
694
695        fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
696            self.visited.push(format!("PRE: RELATION: {relation}"));
697            ControlFlow::Continue(())
698        }
699
700        fn post_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
701            self.visited.push(format!("POST: RELATION: {relation}"));
702            ControlFlow::Continue(())
703        }
704
705        fn pre_visit_table_factor(
706            &mut self,
707            table_factor: &TableFactor,
708        ) -> ControlFlow<Self::Break> {
709            self.visited
710                .push(format!("PRE: TABLE FACTOR: {table_factor}"));
711            ControlFlow::Continue(())
712        }
713
714        fn post_visit_table_factor(
715            &mut self,
716            table_factor: &TableFactor,
717        ) -> ControlFlow<Self::Break> {
718            self.visited
719                .push(format!("POST: TABLE FACTOR: {table_factor}"));
720            ControlFlow::Continue(())
721        }
722
723        fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
724            self.visited.push(format!("PRE: EXPR: {expr}"));
725            ControlFlow::Continue(())
726        }
727
728        fn post_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
729            self.visited.push(format!("POST: EXPR: {expr}"));
730            ControlFlow::Continue(())
731        }
732
733        fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
734            self.visited.push(format!("PRE: STATEMENT: {statement}"));
735            ControlFlow::Continue(())
736        }
737
738        fn post_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
739            self.visited.push(format!("POST: STATEMENT: {statement}"));
740            ControlFlow::Continue(())
741        }
742    }
743
744    fn do_visit<V: Visitor>(sql: &str, visitor: &mut V) -> Statement {
745        let dialect = GenericDialect {};
746        let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
747        let s = Parser::new(&dialect)
748            .with_tokens(tokens)
749            .parse_statement()
750            .unwrap();
751
752        s.visit(visitor);
753        s
754    }
755
756    #[test]
757    fn test_sql() {
758        let tests = vec![
759            (
760                "SELECT * from table_name as my_table",
761                vec![
762                    "PRE: STATEMENT: SELECT * FROM table_name AS my_table",
763                    "PRE: QUERY: SELECT * FROM table_name AS my_table",
764                    "PRE: TABLE FACTOR: table_name AS my_table",
765                    "PRE: RELATION: table_name",
766                    "POST: RELATION: table_name",
767                    "POST: TABLE FACTOR: table_name AS my_table",
768                    "POST: QUERY: SELECT * FROM table_name AS my_table",
769                    "POST: STATEMENT: SELECT * FROM table_name AS my_table",
770                ],
771            ),
772            (
773                "SELECT * from t1 join t2 on t1.id = t2.t1_id",
774                vec![
775                    "PRE: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
776                    "PRE: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
777                    "PRE: TABLE FACTOR: t1",
778                    "PRE: RELATION: t1",
779                    "POST: RELATION: t1",
780                    "POST: TABLE FACTOR: t1",
781                    "PRE: TABLE FACTOR: t2",
782                    "PRE: RELATION: t2",
783                    "POST: RELATION: t2",
784                    "POST: TABLE FACTOR: t2",
785                    "PRE: EXPR: t1.id = t2.t1_id",
786                    "PRE: EXPR: t1.id",
787                    "POST: EXPR: t1.id",
788                    "PRE: EXPR: t2.t1_id",
789                    "POST: EXPR: t2.t1_id",
790                    "POST: EXPR: t1.id = t2.t1_id",
791                    "POST: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
792                    "POST: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
793                ],
794            ),
795            (
796                "SELECT * from t1 where EXISTS(SELECT column from t2)",
797                vec![
798                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
799                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
800                    "PRE: TABLE FACTOR: t1",
801                    "PRE: RELATION: t1",
802                    "POST: RELATION: t1",
803                    "POST: TABLE FACTOR: t1",
804                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
805                    "PRE: QUERY: SELECT column FROM t2",
806                    "PRE: EXPR: column",
807                    "POST: EXPR: column",
808                    "PRE: TABLE FACTOR: t2",
809                    "PRE: RELATION: t2",
810                    "POST: RELATION: t2",
811                    "POST: TABLE FACTOR: t2",
812                    "POST: QUERY: SELECT column FROM t2",
813                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
814                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
815                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
816                ],
817            ),
818            (
819                "SELECT * from t1 where EXISTS(SELECT column from t2)",
820                vec![
821                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
822                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
823                    "PRE: TABLE FACTOR: t1",
824                    "PRE: RELATION: t1",
825                    "POST: RELATION: t1",
826                    "POST: TABLE FACTOR: t1",
827                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
828                    "PRE: QUERY: SELECT column FROM t2",
829                    "PRE: EXPR: column",
830                    "POST: EXPR: column",
831                    "PRE: TABLE FACTOR: t2",
832                    "PRE: RELATION: t2",
833                    "POST: RELATION: t2",
834                    "POST: TABLE FACTOR: t2",
835                    "POST: QUERY: SELECT column FROM t2",
836                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
837                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
838                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
839                ],
840            ),
841            (
842                "SELECT * from t1 where EXISTS(SELECT column from t2) UNION SELECT * from t3",
843                vec![
844                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
845                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
846                    "PRE: TABLE FACTOR: t1",
847                    "PRE: RELATION: t1",
848                    "POST: RELATION: t1",
849                    "POST: TABLE FACTOR: t1",
850                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
851                    "PRE: QUERY: SELECT column FROM t2",
852                    "PRE: EXPR: column",
853                    "POST: EXPR: column",
854                    "PRE: TABLE FACTOR: t2",
855                    "PRE: RELATION: t2",
856                    "POST: RELATION: t2",
857                    "POST: TABLE FACTOR: t2",
858                    "POST: QUERY: SELECT column FROM t2",
859                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
860                    "PRE: TABLE FACTOR: t3",
861                    "PRE: RELATION: t3",
862                    "POST: RELATION: t3",
863                    "POST: TABLE FACTOR: t3",
864                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
865                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
866                ],
867            ),
868            (
869                concat!(
870                    "SELECT * FROM monthly_sales ",
871                    "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
872                    "ORDER BY EMPID"
873                ),
874                vec![
875                    "PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
876                    "PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
877                    "PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
878                    "PRE: TABLE FACTOR: monthly_sales",
879                    "PRE: RELATION: monthly_sales",
880                    "POST: RELATION: monthly_sales",
881                    "POST: TABLE FACTOR: monthly_sales",
882                    "PRE: EXPR: SUM(a.amount)",
883                    "PRE: EXPR: a.amount",
884                    "POST: EXPR: a.amount",
885                    "POST: EXPR: SUM(a.amount)",
886                    "PRE: EXPR: 'JAN'",
887                    "POST: EXPR: 'JAN'",
888                    "PRE: EXPR: 'FEB'",
889                    "POST: EXPR: 'FEB'",
890                    "PRE: EXPR: 'MAR'",
891                    "POST: EXPR: 'MAR'",
892                    "PRE: EXPR: 'APR'",
893                    "POST: EXPR: 'APR'",
894                    "POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
895                    "PRE: EXPR: EMPID",
896                    "POST: EXPR: EMPID",
897                    "POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
898                    "POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
899                ]
900            ),
901            (
902                "SHOW COLUMNS FROM t1",
903                vec![
904                    "PRE: STATEMENT: SHOW COLUMNS FROM t1",
905                    "PRE: RELATION: t1",
906                    "POST: RELATION: t1",
907                    "POST: STATEMENT: SHOW COLUMNS FROM t1",
908                ],
909            ),
910        ];
911        for (sql, expected) in tests {
912            let mut visitor = TestVisitor::default();
913            let _ = do_visit(sql, &mut visitor);
914            let actual: Vec<_> = visitor.visited.iter().map(|x| x.as_str()).collect();
915            assert_eq!(actual, expected)
916        }
917    }
918
919    struct QuickVisitor; // [`TestVisitor`] is too slow to iterate over thousands of nodes
920
921    impl Visitor for QuickVisitor {
922        type Break = ();
923    }
924
925    #[test]
926    fn overflow() {
927        let cond = (0..1000)
928            .map(|n| format!("X = {}", n))
929            .collect::<Vec<_>>()
930            .join(" OR ");
931        let sql = format!("SELECT x where {0}", cond);
932
933        let dialect = GenericDialect {};
934        let tokens = Tokenizer::new(&dialect, sql.as_str()).tokenize().unwrap();
935        let s = Parser::new(&dialect)
936            .with_tokens(tokens)
937            .parse_statement()
938            .unwrap();
939
940        let mut visitor = QuickVisitor {};
941        s.visit(&mut visitor);
942    }
943}
944
/// Tests for the mutable visitor API ([`VisitMut`] / [`VisitorMut`]).
#[cfg(test)]
mod visit_mut_tests {
    use crate::ast::{Statement, Value, VisitMut, VisitorMut};
    use crate::dialect::GenericDialect;
    use crate::parser::Parser;
    use crate::tokenizer::Tokenizer;
    use core::ops::ControlFlow;

    /// Test visitor that overwrites every [`Value`] handed to
    /// `pre_visit_value` with a numbered `REDACTED_n` single-quoted string.
    #[derive(Default)]
    struct MutatorVisitor {
        /// Count of values rewritten so far; incremented before each
        /// replacement, so the first replacement is `REDACTED_1`.
        index: u64,
    }

    impl VisitorMut for MutatorVisitor {
        type Break = ();

        /// Replaces `value` in place and always continues the traversal.
        fn pre_visit_value(&mut self, value: &mut Value) -> ControlFlow<Self::Break> {
            self.index += 1;
            *value = Value::SingleQuotedString(format!("REDACTED_{}", self.index));
            ControlFlow::Continue(())
        }

        /// Leaves the (already replaced) value untouched and continues.
        fn post_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
            ControlFlow::Continue(())
        }
    }

    /// Parses `sql` with the generic dialect, runs `visitor` over the resulting
    /// statement and returns the statement after the visitor has had the chance
    /// to mutate it.
    ///
    /// # Panics
    ///
    /// Panics if `sql` fails to tokenize or parse, or if the visitor stops the
    /// traversal early by returning `ControlFlow::Break`.
    fn do_visit_mut<V: VisitorMut>(sql: &str, visitor: &mut V) -> Statement {
        let dialect = GenericDialect {};
        let tokens = Tokenizer::new(&dialect, sql)
            .tokenize()
            .expect("test SQL should tokenize");
        let mut statement = Parser::new(&dialect)
            .with_tokens(tokens)
            .parse_statement()
            .expect("test SQL should parse");

        // Check the traversal result instead of dropping it: a statement that
        // was only partially visited would make the caller's assertions
        // misleading. `matches!` keeps this free of extra bounds on `V::Break`.
        let flow = statement.visit(visitor);
        assert!(
            matches!(flow, ControlFlow::Continue(())),
            "visitor stopped the traversal early"
        );
        statement
    }

    /// Each case pairs an input query with the text the statement is expected
    /// to display as after [`MutatorVisitor`] has run over it.
    #[test]
    fn test_value_redact() {
        let tests = vec![
            (
                concat!(
                    "SELECT * FROM monthly_sales ",
                    "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
                    "ORDER BY EMPID"
                ),
                concat!(
                    "SELECT * FROM monthly_sales ",
                    "PIVOT(SUM(a.amount) FOR a.MONTH IN ('REDACTED_1', 'REDACTED_2', 'REDACTED_3', 'REDACTED_4')) AS p (c, d) ",
                    "ORDER BY EMPID"
                ),
            ),
        ];

        for (sql, expected) in tests {
            // A fresh visitor per case so the numbering restarts at 1.
            let mut visitor = MutatorVisitor::default();
            let mutated = do_visit_mut(sql, &mut visitor);
            assert_eq!(mutated.to_string(), expected)
        }
    }
}