diff --git a/src/ast/visitor.rs b/src/ast/visitor.rs index 90e414aae..85f9175ef 100644 --- a/src/ast/visitor.rs +++ b/src/ast/visitor.rs @@ -10,10 +10,22 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Recursive visitors for ast Nodes. See [`Visitor`] for more details. + use crate::ast::{Expr, ObjectName, Statement}; use core::ops::ControlFlow; -/// A type that can be visited by a `visitor` +/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for +/// recursively visiting parsed SQL statements. +/// +/// # Note +/// +/// This trait should be automatically derived for sqlparser AST nodes +/// using the [Visit](sqlparser_derive::Visit) proc macro. +/// +/// ```text +/// #[cfg_attr(feature = "visitor", derive(Visit))] +/// ``` pub trait Visit { fn visit(&self, visitor: &mut V) -> ControlFlow; } @@ -57,8 +69,70 @@ visit_noop!(u8, u16, u32, u64, i8, i16, i32, i64, char, bool, String); #[cfg(feature = "bigdecimal")] visit_noop!(bigdecimal::BigDecimal); -/// A visitor that can be used to walk an AST tree +/// A visitor that can be used to walk an AST tree. +/// +/// `previst_` methods are invoked before visiting all children of the +/// node and `postvisit_` methods are invoked after visiting all +/// children of the node. +/// +/// # See also +/// +/// These methods provide a more concise way of visiting nodes of a certain type: +/// * [visit_relations] +/// * [visit_expressions] +/// * [visit_statements] +/// +/// # Example +/// ``` +/// # use sqlparser::parser::Parser; +/// # use sqlparser::dialect::GenericDialect; +/// # use sqlparser::ast::{Visit, Visitor, ObjectName, Expr}; +/// # use core::ops::ControlFlow; +/// // A structure that records statements and relations +/// #[derive(Default)] +/// struct V { +/// visited: Vec, +/// } +/// +/// // Visit relations and exprs before children are visited (depth first walk) +/// // Note you can also visit statements and visit exprs after children have been visitoed +/// impl Visitor for V { +/// type Break = (); +/// +/// fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow { +/// self.visited.push(format!("PRE: RELATION: {}", relation)); +/// ControlFlow::Continue(()) +/// } +/// +/// fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow { +/// self.visited.push(format!("PRE: EXPR: {}", expr)); +/// ControlFlow::Continue(()) +/// } +/// } +/// +/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)"; +/// let statements = Parser::parse_sql(&GenericDialect{}, sql) +/// .unwrap(); +/// +/// // Drive the visitor through the AST +/// let mut visitor = V::default(); +/// statements.visit(&mut visitor); +/// +/// // The visitor has visited statements and expressions in pre-traversal order +/// let expected : Vec<_> = [ +/// "PRE: EXPR: a", +/// "PRE: RELATION: foo", +/// "PRE: EXPR: x IN (SELECT y FROM bar)", +/// "PRE: EXPR: x", +/// "PRE: EXPR: y", +/// "PRE: RELATION: bar", +/// ] +/// .into_iter().map(|s| s.to_string()).collect(); +/// +/// assert_eq!(visitor.visited, expected); +/// ``` pub trait Visitor { + /// Type returned when the recursion returns early. type Break; /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children @@ -102,7 +176,33 @@ impl ControlFlow> Visitor for RelationVisitor } } -/// Invokes the provided closure on all relations present in v +/// Invokes the provided closure on all relations (e.g. table names) present in `v` +/// +/// # Example +/// ``` +/// # use sqlparser::parser::Parser; +/// # use sqlparser::dialect::GenericDialect; +/// # use sqlparser::ast::{visit_relations}; +/// # use core::ops::ControlFlow; +/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)"; +/// let statements = Parser::parse_sql(&GenericDialect{}, sql) +/// .unwrap(); +/// +/// // visit statements, capturing relations (table names) +/// let mut visited = vec![]; +/// visit_relations(&statements, |relation| { +/// visited.push(format!("RELATION: {}", relation)); +/// ControlFlow::<()>::Continue(()) +/// }); +/// +/// let expected : Vec<_> = [ +/// "RELATION: foo", +/// "RELATION: bar", +/// ] +/// .into_iter().map(|s| s.to_string()).collect(); +/// +/// assert_eq!(visited, expected); +/// ``` pub fn visit_relations(v: &V, f: F) -> ControlFlow where V: Visit, @@ -123,7 +223,35 @@ impl ControlFlow> Visitor for ExprVisitor { } } -/// Invokes the provided closure on all expressions present in v +/// Invokes the provided closure on all expressions (e.g. `1 + 2`) present in `v` +/// +/// # Example +/// ``` +/// # use sqlparser::parser::Parser; +/// # use sqlparser::dialect::GenericDialect; +/// # use sqlparser::ast::{visit_expressions}; +/// # use core::ops::ControlFlow; +/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)"; +/// let statements = Parser::parse_sql(&GenericDialect{}, sql) +/// .unwrap(); +/// +/// // visit all expressions +/// let mut visited = vec![]; +/// visit_expressions(&statements, |expr| { +/// visited.push(format!("EXPR: {}", expr)); +/// ControlFlow::<()>::Continue(()) +/// }); +/// +/// let expected : Vec<_> = [ +/// "EXPR: a", +/// "EXPR: x IN (SELECT y FROM bar)", +/// "EXPR: x", +/// "EXPR: y", +/// ] +/// .into_iter().map(|s| s.to_string()).collect(); +/// +/// assert_eq!(visited, expected); +/// ``` pub fn visit_expressions(v: &V, f: F) -> ControlFlow where V: Visit, @@ -144,7 +272,33 @@ impl ControlFlow> Visitor for StatementVisitor } } -/// Invokes the provided closure on all statements present in v +/// Invokes the provided closure on all statements (e.g. `SELECT`, `CREATE TABLE`, etc) present in `v` +/// +/// # Example +/// ``` +/// # use sqlparser::parser::Parser; +/// # use sqlparser::dialect::GenericDialect; +/// # use sqlparser::ast::{visit_statements}; +/// # use core::ops::ControlFlow; +/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar); CREATE TABLE baz(q int)"; +/// let statements = Parser::parse_sql(&GenericDialect{}, sql) +/// .unwrap(); +/// +/// // visit all statements +/// let mut visited = vec![]; +/// visit_statements(&statements, |stmt| { +/// visited.push(format!("STATEMENT: {}", stmt)); +/// ControlFlow::<()>::Continue(()) +/// }); +/// +/// let expected : Vec<_> = [ +/// "STATEMENT: SELECT a FROM foo WHERE x IN (SELECT y FROM bar)", +/// "STATEMENT: CREATE TABLE baz (q INT)" +/// ] +/// .into_iter().map(|s| s.to_string()).collect(); +/// +/// assert_eq!(visited, expected); +/// ``` pub fn visit_statements(v: &V, f: F) -> ControlFlow where V: Visit,