sqlglot.transforms
"""
AST-level transforms over `exp.Expression` trees.

Every public function here takes an `exp.Expression` and returns an `exp.Expression`: either the
input (often mutated in place) or a newly built node. `preprocess` chains such functions into a
callable that can be used as a generator transform.
"""

from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name, name_sequence

if t.TYPE_CHECKING:
    from sqlglot.generator import Generator


def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    SQL for the transformed expression is produced by, in order of preference:
        1. the generator's `<expression.key>_sql` method, if it exists;
        2. the `Generator.TRANSFORMS` entry for the expression's type, but only when the
           transforms changed the expression's type. If the type is unchanged, an `exp.Func` is
           rendered with `function_fallback_sql` and anything else raises a ValueError, because
           dispatching through TRANSFORMS again could re-enter this same transform forever;
        3. otherwise a ValueError ("Unsupported expression type ...") is raised.

    Args:
        transforms: sequence of transform functions. These will be called in order. It must be
            non-empty, since the first transform is accessed as `transforms[0]`.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remember the original type to detect transforms that leave the type unchanged
        expression_type = type(expression)

        expression = transforms[0](expression)
        for transform in transforms[1:]:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql


def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Unnests GENERATE_SERIES or SEQUENCE table references.

    A table node whose `this` is an `exp.GenerateSeries` is replaced by `UNNEST(<series>)`. If the
    table had an alias, that alias is passed as the column list of a new "_u" table alias. Any
    other expression is returned unchanged.
    """
    # `.this` is read before the type check; it is only used when the check passes
    this = expression.this
    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
        unnest = exp.Unnest(expressions=[this])
        if expression.alias:
            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)

        return unnest

    return expression


def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        # alias name -> 1-based projection position (the last projection wins on duplicate aliases)
        aliased_selects = {
            e.alias: i
            for i, e in enumerate(expression.parent.expressions, start=1)
            if isinstance(e, exp.Alias)
        }

        for group_by in expression.expressions:
            # Only unqualified columns can refer to a projection alias
            if (
                isinstance(group_by, exp.Column)
                and not group_by.table
                and group_by.name in aliased_selects
            ):
                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))

    return expression


def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The DISTINCT ON columns become the PARTITION BY of a ROW_NUMBER() window, ordered by the
    query's ORDER BY (or by the DISTINCT ON columns when there is none), and the outer query keeps
    only rows whose row number is 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        # Popping the DISTINCT node removes it from the query; its ON tuple gives the partition keys
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        # NOTE(review): the outer query below reuses these same projection nodes (copy=False), so a
        # qualified or computed projection (e.g. `x.a`, `a + 1 AS b`) is re-evaluated against the
        # `_t` subquery, where `x` / `a` may not be visible -- confirm / consider aliasing.
        outer_selects = expression.selects
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            # The ORDER BY decides which row of each partition is kept, so it moves into the window
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression


def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        taken = set(expression.named_selects)
        # Give every unnamed projection a fresh "_c..." alias so the outer query can reference it
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Re-reference a projection by name, keeping the original identifier's quoting
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # The outer projections are built before any helper projections are added below
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        # alias -> aliased expression, used to inline aliases referenced inside window functions
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With SELECT *, every column is already projected, so only windows need to be pulled out
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window under a fresh alias; the outer filter refers to that alias
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window is the entire QUALIFY condition, so the filter becomes the alias itself
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Columns referenced only in QUALIFY must be exposed by the subquery
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression


def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    # Drops every exp.DataTypeParam (e.g. the 10 in DECIMAL(10)) from all data types in the tree
    for node in expression.find_all(exp.DataType):
        node.set(
            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
        )

    return expression


def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    # Local import, presumably to avoid an import cycle with the optimizer package
    from sqlglot.optimizer.scope import find_all_in_scope

    if isinstance(expression, exp.Select):
        # Aliases of UNNESTs used as sources (FROM / JOIN) within this SELECT's own scope
        unnest_aliases = {
            unnest.alias
            for unnest in find_all_in_scope(expression, exp.Unnest)
            if isinstance(unnest.parent, (exp.From, exp.Join))
        }
        if unnest_aliases:
            # NOTE: find_all walks the whole subtree, unlike the scope-limited search above
            for column in expression.find_all(exp.Column):
                # Drop whichever qualifier (table, otherwise db) names an unnest alias
                if column.table in unnest_aliases:
                    column.set("table", None)
                elif column.db in unnest_aliases:
                    column.set("db", None)

    return expression


def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """Convert cross join unnest into lateral view explode."""
    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            # UNNEST ... WITH OFFSET maps to POSEXPLODE, a plain UNNEST to EXPLODE
            udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
            # The first array is the UDTF argument; any further arrays become extra expressions
            this, *expressions = unnest.expressions
            unnest.replace(
                exp.Table(
                    this=udtf(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        for join in expression.args.get("joins") or []:
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                # NOTE(review): this removes from the same list the enclosing `for` iterates over,
                # so the join right after a removed one is skipped (e.g. two consecutive UNNEST
                # joins). Iterating over a copy would avoid that.
                expression.args["joins"].remove(join)

                # One LATERAL VIEW per unnested array, paired positionally with the alias columns.
                # Without an alias, zip() yields nothing, so the removed join is not replaced.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression


def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Each [POS]EXPLODE projection becomes a CROSS JOIN UNNEST ... WITH OFFSET. A single position
    series (UNNEST(GENERATE_SERIES(index_offset, ...))) is also cross-joined. The WHERE clause and
    IF(...) projections use it to line up elements of different arrays by position, and positions
    past an array's end produce NULL for that array.

    Args:
        index_offset: start of the generated position series; it also shifts the computed last
            position. Presumably the base (0 or 1) of array positions in the target dialect --
            TODO confirm.

    Returns:
        A transform function for SELECT expressions.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already used by projections / sources, to keep generated aliases unique
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Reserve and return a fresh name derived from `name`
                name = find_new_name(names, name)
                names.add(name)
                return name

            # ARRAY_SIZE(...) of every exploded array; the series must span the largest
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # POSEXPLODE(...) AS (pos, col)
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # alias_ copied `select`, so re-locate the EXPLODE inside the copy
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # Replace an empty (or NULL, via COALESCE) array with a one-element array
                        # holding a safe element access, so the row is kept. Presumably the
                        # element evaluates to NULL -- confirm per dialect.
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    # The UNNEST offset always needs a name, even for a plain EXPLODE
                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # IF(series_pos = elem_pos, elem): no ELSE branch, so non-matching rows give NULL
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also yields the position, projected right after the element
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The position series is joined (or used as FROM) only once, for the first explode
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # `size` becomes the array's last position when positions don't start at 1
                    if index_offset != 1:
                        size = size - 1

                    # Keep the row where the series position matches this element's offset or,
                    # once the series runs past this array's last position, pair it with the last
                    # element (whose projected value is NULL via the IF above)
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series ends at the last position of the longest exploded array
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest


def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    if (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    ):
        # The original first argument becomes the ORDER BY key, and the second argument becomes
        # the function's only argument
        column = expression.this.pop()
        expression.set("this", expression.expression.pop())
        order = exp.Order(expressions=[exp.Ordered(this=column)])
        expression = exp.WithinGroup(this=expression, expression=order)

    return expression


def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        # The percentile's argument is the quantile; the first ORDER BY key is the input value
        quantile = expression.this.this
        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
        # NOTE: the result is always an ApproxQuantile, whatever the original percentile type
        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression


def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if isinstance(expression, exp.With) and expression.recursive:
        # Fallback names (prefixed "_c_") for projections without an output name
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            # CTEs that already declare a column list are left untouched
            if not cte.args["alias"].columns:
                query = cte.this
                # For set operations (e.g. UNION), the left operand's projections name the columns
                if isinstance(query, exp.SetOperation):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression


def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    # Matches CAST/TRY_CAST('epoch' AS <temporal type>), comparing 'epoch' case-insensitively
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression


def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead."""
    if isinstance(expression, exp.Select):
        for join in expression.args.get("joins") or []:
            on = join.args.get("on")
            # Only SEMI/ANTI joins that have an ON condition are rewritten:
            #   SEMI JOIN t ON c  ->  WHERE EXISTS (SELECT 1 FROM t WHERE c)
            #   ANTI JOIN t ON c  ->  WHERE NOT EXISTS (SELECT 1 FROM t WHERE c)
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                # NOTE(review): pop() detaches the join while the joins list is being iterated;
                # confirm it cannot cause the following join to be skipped.
                join.pop()
                expression.where(exists, copy=False)

    return expression


def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            # The copy is taken before any edit, so it still has the LIMIT and the join at `index`
            expression_copy = expression.copy()
            # NOTE(review): LIMIT is removed from the LEFT query only; the RIGHT copy keeps it --
            # confirm this is intended.
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]
            full_outer_join.set("side", "left")
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side

            # NOTE(review): matched rows appear in both halves; this relies on exp.union's default
            # (believed to be a distinct UNION) to deduplicate them, which would also collapse rows
            # that were genuine duplicates in the FULL OUTER JOIN result -- confirm.
            return exp.union(expression, expression_copy, copy=False)

    return expression


def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        # Skip the top-level WITH itself
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: hoist this nested one as-is
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            # If any moved WITH is recursive, the merged top-level WITH must be too
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Insert right before the CTE that contained them, so they're defined before use
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                # Re-set the list, presumably so the inserted nodes get the new parent
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression


def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    # NOTE: this local import shadows the enclosing function's own name inside its body
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        # Numeric literals, numeric- or unknown-typed nodes (other than subquery predicates) and
        # untyped columns are rewritten to `<node> <> 0`
        if (
            node.is_number
            or (
                not isinstance(node, exp.SubqueryPredicate)
                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            )
            or (isinstance(node, exp.Column) and not node.type)
        ):
            node.replace(node.neq(0))

    # canonicalize.ensure_bools presumably picks which of a node's children act as conditions and
    # applies _ensure_bool to them -- confirm against sqlglot.optimizer.canonicalize
    for node in expression.walk():
        ensure_bools(node, _ensure_bool)

    return expression


def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Strips every qualifier (table, db, catalog) from all column references in the tree."""
    for column in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args
        for part in column.parts[:-1]:
            part.pop()

    return expression


def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Removes UNIQUE column constraints from a CREATE statement.

    The parent of each exp.UniqueColumnConstraint is popped, presumably the wrapping column
    constraint node -- confirm.
    """
    assert isinstance(expression, exp.Create)
    for constraint in expression.find_all(exp.UniqueColumnConstraint):
        if constraint.parent:
            constraint.parent.pop()

    return expression


def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Rewrites temporary CREATE TABLE statements.

    A temporary CREATE TABLE ... AS <query> (temporary = has an exp.TemporaryProperty) becomes a
    CREATE TEMPORARY VIEW with the same name and query. Other arguments of the original CREATE,
    including its properties, are not carried over. A temporary CREATE TABLE without a query is
    passed to `tmp_storage_provider` (identity by default). Anything else is returned unchanged.
    """
    assert isinstance(expression, exp.Create)
    properties = expression.args.get("properties")
    temporary = any(
        isinstance(prop, exp.TemporaryProperty)
        for prop in (properties.expressions if properties else [])
    )

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.kind == "TABLE" and temporary:
        if expression.expression:
            return exp.Create(
                kind="TEMPORARY VIEW",
                this=expression.this,
                expression=expression.expression,
            )
        return tmp_storage_provider(expression)

    return expression


def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.kind in {"TABLE", "VIEW"}

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Partition column names are matched case-insensitively against the schema columns
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return expression


def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When PARTITIONED BY holds fully typed column definitions, they are appended to the table's
    schema and PARTITIONED BY is left with just their names, as a tuple.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    # Only rewrite when every partition entry is a column definition with a type
    if (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    ):
        prop_this = exp.Tuple(
            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
        )
        schema = expression.this
        for e in prop.this.expressions:
            schema.append("expressions", e)
        prop.set("this", prop_this)

    return expression


def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        # Key/value (exp.PropertyEQ) arguments become `<value> AS <key>`; others are kept as-is
        expression.set(
            "expressions",
            [
                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
                for e in expression.expressions
            ],
        )

    return expression


def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
        The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        # Join marks only make sense with both a WHERE clause and joined tables
        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Detach the predicate from WHERE; it becomes (part of) the new join's ON condition
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # `col` is the last loop variable above; all marked columns share its table (asserted)
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        if query_from.alias_or_name in new_joins:
            # The FROM table became the outer-joined side, so promote an unmarked table to FROM
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        # NOTE(review): this replaces *all* joins with the new LEFT joins, so an old join that had
        # no marked column (e.g. a third comma-joined table) appears to be dropped -- confirm.
        query.set("joins", list(new_joins.values()))

        # Remove the WHERE clause if every predicate was moved into a join
        if not where.this:
            where.pop()

    return expression
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function.

    SQL for the transformed expression is produced by, in order of preference:
        1. the generator's `<expression.key>_sql` method, if it exists;
        2. the `Generator.TRANSFORMS` entry for the expression's type, but only when the
           transforms changed the expression's type. If the type is unchanged, an `exp.Func` is
           rendered with `function_fallback_sql` and anything else raises a ValueError, because
           dispatching through TRANSFORMS again could re-enter this same transform forever;
        3. otherwise a ValueError ("Unsupported expression type ...") is raised.

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence leaves the expression unchanged.

    Returns:
        Function that can be used as a generator transform. It raises ValueError when no way of
        generating SQL for the transformed expression is found.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remember the original type to detect transforms that leave the type unchanged
        expression_type = type(expression)

        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
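Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate Generator.TRANSFORMS function. The Generator.TRANSFORMS function is only used when the transformations changed the expression's type; otherwise a Func is rendered with the generator's function_fallback_sql, and any other expression raises a ValueError to avoid re-entering the same transform forever.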
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Unnests GENERATE_SERIES or SEQUENCE table references.

    A table node wrapping an `exp.GenerateSeries` is rewritten to `UNNEST(<series>)`. If the table
    was aliased, its alias is passed as the column list of a new "_u" table alias. Every other
    expression is returned untouched.
    """
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    table_alias = expression.alias
    unnested = exp.Unnest(expressions=[series])
    return (
        exp.alias_(unnested, alias="_u", table=[table_alias], copy=False)
        if table_alias
        else unnested
    )
Unnests GENERATE_SERIES or SEQUENCE table references.
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Every unqualified GROUP BY column whose name matches a projection alias of the enclosing
    SELECT is replaced by that projection's 1-based position. When several projections share
    an alias, the last one wins.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Group):
        return expression

    select = expression.parent
    if not isinstance(select, exp.Select):
        return expression

    position_of: t.Dict[str, int] = {}
    for position, projection in enumerate(select.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_of[projection.alias] = position

    for key in expression.expressions:
        # Qualified columns (e.g. x.b) can never refer to a projection alias
        if not isinstance(key, exp.Column) or key.table:
            continue

        position = position_of.get(key.name)
        if position is not None:
            key.replace(exp.Literal.number(position))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    The DISTINCT ON columns become the PARTITION BY of a ROW_NUMBER() window, ordered by the
    query's ORDER BY (or by the DISTINCT ON columns when there is none), and the outer query keeps
    only rows whose row number is 1. Every inner projection is aliased, and the outer query
    selects those aliases. As a result, qualified or computed projections (e.g. `x.a`,
    `a + 1 AS b`) are not re-evaluated against the `_t` subquery, where their source tables and
    columns are not visible.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        # Popping the DISTINCT node removes it from the query; its ON tuple gives the partition keys
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            # The ORDER BY decides which row of each partition survives, so it moves into the window
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        # Snapshot the original projections before the window projection is appended
        projections = list(expression.selects)
        expression.select(exp.alias_(window, row_number), copy=False)

        outer_selects: t.List[exp.Expression] = []
        taken_names = {row_number}
        for select in projections:
            if select.is_star:
                # A star can't be aliased; select everything from the subquery instead
                outer_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                # Alias by the projection's output name (or a fresh "_col" name), keeping a bare
                # column's quoting so the outer reference resolves to the same name
                alias = find_new_name(taken_names, select.output_name or "_col")
                quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                select = select.replace(exp.alias_(select, alias, quoted=quoted))

            taken_names.add(select.output_name)
            # Reference only the alias; copy so the identifier node isn't shared by both queries
            outer_selects.append(select.args["alias"].copy())

        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        `SELECT <original projections> FROM (<expression>) AS _t WHERE <qualify filter>` when
        `expression` is a SELECT with a QUALIFY clause, otherwise `expression` unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh `_c*` alias so the outer query can reference it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        # Outer projection for `select`: a column reference that keeps the quoting of the
        # alias/identifier when one is available, otherwise the bare name.
        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        # Detach QUALIFY from the query; its condition becomes the outer WHERE filter.
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With SELECT *, columns are presumably already exposed by the subquery, so only
        # window functions need to be hoisted into projections.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        # NOTE(review): the filter tree is mutated (replace) while this find_all generator is
        # still traversing it; the result depends on that traversal order.
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                # Inline projection aliases used inside the window, since the window is moved
                # into the same SELECT that defines those aliases.
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            # NOTE(review): `expr` is not copied, so it may end up shared with
                            # the aliased projection it came from — consider `expr.copy()`.
                            column.replace(expr)

                # Project the window under a fresh `_w*` alias and filter on that alias outside.
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                # If the window is the whole QUALIFY condition, the filter is just the column.
                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # Expose any referenced column that the inner SELECT doesn't project yet.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.

    Args:
        expression: the expression whose DataType nodes are stripped of their type parameters.

    Returns:
        The same expression, modified in place.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_args = list(
            filter(lambda arg: not isinstance(arg, exp.DataTypeParam), data_type.expressions)
        )
        data_type.set("expressions", kept_args)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    Only UNNESTs that appear directly in a FROM or JOIN of this SELECT's scope are considered.
    Columns qualified with one of their aliases (as table, or as db) lose that qualifier.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if not unnest_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in unnest_aliases:
            column.set("table", None)
        elif column.db in unnest_aliases:
            column.set("db", None)

    return expression
Remove references to unnest table aliases, added by the optimizer's qualify_columns step.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    An UNNEST in the FROM clause becomes an EXPLODE (or POSEXPLODE when it has an offset) table.
    Each UNNEST join is removed and replaced by one LATERAL VIEW per (array, alias column) pair.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
            this, *expressions = unnest.expressions
            unnest.replace(
                exp.Table(
                    this=udtf(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        joins = expression.args.get("joins") or []
        # Iterate over a snapshot: joins are removed from `joins` inside the loop, and mutating
        # the list being iterated would skip the join right after each removed one (e.g. two
        # consecutive CROSS JOIN UNNESTs).
        for join in list(joins):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                joins.remove(join)

                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode.
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Each [POS]EXPLODE projection is replaced by `IF(<series pos> = <unnest pos>, <unnest value>)`.
    A single position series (an UNNEST of GENERATE_SERIES) is shared by all exploded arrays.
    Each array is CROSS JOINed as an UNNEST with an offset, and a WHERE condition pairs the rows
    by position.

    Args:
        index_offset: start value of the generated position series; also used to adjust the
            array sizes and the series end.

    Returns:
        A transform that rewrites SELECT expressions and returns any other expression unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            # Pick a name not in `names`, reserve it there, and return it.
            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            # ARRAY_SIZE(...) of every exploded array; used to compute the series' end.
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # The shared position series. Its "end" is only set after the loop, once every
            # array size is known.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # POSEXPLODE(...) AS (pos, value): first alias is the position.
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # Unaliased: wrap in an empty alias (named below). The explode node is
                        # looked up again, presumably because alias_ copied `select` here.
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # For an empty (or NULL, via COALESCE) array, unnest a one-element array
                        # holding a "safe" element access instead, so the row is not dropped.
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Replace the [POS]EXPLODE call with:
                    # IF(<series>.<series_alias> = <unnest>.<pos_alias>, <unnest>.<explode_alias>)
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also yields the position: add it as a projection right
                        # after the value projection.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # Attach the shared series only once, for the first exploded array.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Turn the size into the last position value when positions aren't 1-based
                    # (assumption: the unnest offset starts at index_offset — TODO confirm).
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions match. For series positions past this array's
                    # last position, keep only the row at that last position; there the IF above
                    # does not match, so this array's value is NULL.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must cover the longest exploded array.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
Convert explode/posexplode into unnest.
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    For a percentile that isn't already wrapped in WITHIN GROUP and has a second argument, the
    first argument becomes the ORDER BY column and the second one becomes the percentile's `this`.
    """
    needs_within_group = (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    )
    if not needs_within_group:
        return expression

    ordered_column = expression.this.pop()
    expression.set("this", expression.expression.pop())
    order_by = exp.Order(expressions=[exp.Ordered(this=ordered_column)])
    return exp.WithinGroup(this=expression, expression=order_by)
Transforms percentiles by adding a WITHIN GROUP clause to them.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    `PERCENTILE(q) WITHIN GROUP (ORDER BY x)` is replaced by `ApproxQuantile(this=x, quantile=q)`.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, exp.PERCENTILES):
        return expression
    if not isinstance(expression.expression, exp.Order):
        return expression

    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    replacement = exp.ApproxQuantile(this=ordered.this, quantile=percentile.this)
    return expression.replace(replacement)
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Only CTEs without explicit column names are touched. Projections without a name get
    generated names from the `_c_` sequence.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    fallback_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this
        # For a set operation, the column names come from its left operand's projections.
        if isinstance(query, exp.SetOperation):
            query = query.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or fallback_name())
            for projection in query.selects
        ]
        table_alias.set("columns", column_names)

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Applies to CAST/TRY_CAST of the (case-insensitive) value 'epoch' to a temporal type.
    """
    is_epoch_cast = (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    )
    if is_epoch_cast:
        unix_epoch = exp.Literal.string("1970-01-01 00:00:00")
        expression.this.replace(unix_epoch)

    return expression
Replace 'epoch' in casts by the equivalent date literal.
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    `SEMI JOIN t ON cond` becomes `WHERE EXISTS (SELECT 1 FROM t WHERE cond)`, and an ANTI join
    becomes the corresponding `NOT EXISTS`. Joins without an ON condition are left untouched.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: `join.pop()` below detaches the join from the query and may
        # shrink the join list in place, which would skip the join right after a removed one.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A UNION of the LEFT-join and RIGHT-join variants of the query, or `expression`
        unchanged if it is not a SELECT with exactly one FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()

            # ORDER BY / LIMIT / OFFSET stay only on the RIGHT (last) branch, where they apply
            # to the whole UNION. Left on the LEFT branch, they would reorder/truncate only
            # that branch, or be a syntax error before UNION in many dialects.
            for modifier in ("order", "limit", "offset"):
                expression.set(modifier, None)

            index, full_outer_join = full_outer_joins[0]
            full_outer_join.set("side", "left")
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side

            # NOTE(review): UNION (not UNION ALL) also collapses duplicate rows that the
            # FULL OUTER JOIN itself would have returned.
            return exp.union(expression, expression_copy, copy=False)

    return expression
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

    Args:
        expression: the root expression whose nested WITH clauses are hoisted.

    Returns:
        The same expression, modified in place.
    """
    top_level_with = expression.args.get("with")
    # NOTE(review): WITH nodes are popped and re-attached while find_all is still traversing
    # the tree, so the outcome depends on that traversal order.
    for inner_with in expression.find_all(exp.With):
        # Skip the WITH that is already attached to the root.
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: promote this nested one wholesale.
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            # A recursive nested WITH makes the merged top-level WITH recursive.
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Splice the nested CTEs in right before the CTE that contained them,
                # presumably so they are defined before they are used.
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                # Re-set the same list, presumably so set() re-parents the spliced-in CTEs.
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    Numeric literals, expressions of unknown or numeric type (except subquery predicates), and
    untyped columns found in boolean positions are rewritten as `<node> <> 0`.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as canonicalize_ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        needs_comparison = node.is_number
        if not needs_comparison and not isinstance(node, exp.SubqueryPredicate):
            needs_comparison = node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
        if not needs_comparison and isinstance(node, exp.Column):
            needs_comparison = not node.type

        if needs_comparison:
            node.replace(node.neq(0))

    for node in expression.walk():
        canonicalize_ensure_bools(node, _ensure_bool)

    return expression
Converts numeric values used in conditions into explicit boolean expressions.
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Rewrite temporary-table CREATE statements.

    A `CREATE TEMPORARY TABLE ... AS <query>` becomes `CREATE TEMPORARY VIEW ... AS <query>`.
    A temporary CREATE TABLE without a query is handed to `tmp_storage_provider`, which by
    default returns it unchanged. Any other CREATE is returned as-is.

    Args:
        expression: a Create expression (asserted).
        tmp_storage_provider: callback applied to temporary CREATE TABLEs without a query.

    Returns:
        The rewritten or original expression.
    """
    assert isinstance(expression, exp.Create)
    properties = expression.args.get("properties")
    property_list = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in property_list)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    if not expression.expression:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Column names are matched case-insensitively.
    """
    assert isinstance(expression, exp.Create)
    if not isinstance(expression.this, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    schema = expression.this
    partition_names = {name_expr.name.upper() for name_expr in prop.this.expressions}
    partition_columns = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining_columns = [e for e in schema.expressions if e not in partition_columns]

    schema.set("expressions", remaining_columns)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns)))
    expression.set("this", schema)
    return expression
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3. When PARTITIONED BY holds typed
    column definitions, they are appended to the table schema and PARTITIONED BY is reduced to a
    tuple of their names.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and isinstance(prop.this, exp.Schema)):
        return expression

    column_defs = prop.this.expressions
    if not all(isinstance(e, exp.ColumnDef) and e.kind for e in column_defs):
        return expression

    partition_names = exp.Tuple(
        expressions=[exp.to_identifier(column_def.this) for column_def in column_defs]
    )
    schema = expression.this
    for column_def in column_defs:
        schema.append("expressions", column_def)
    prop.set("this", partition_names)

    return expression
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

    Each `key = value` (PropertyEQ) argument becomes `value AS key`; other arguments are kept.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    converted = []
    for arg in expression.expressions:
        if isinstance(arg, exp.PropertyEQ):
            arg = exp.alias_(arg.expression, arg.this)
        converted.append(arg)

    expression.set("expressions", converted)
    return expression
Converts struct arguments to aliases, e.g. STRUCT(1 AS y).
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
        The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        # Join marks only matter for queries with both a WHERE clause and joins.
        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Detach the marked predicate from WHERE; it becomes (part of) a join condition.
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # `col` is the last marked column of the loop above; the assert guarantees every
            # marked column shares its table. That table becomes the LEFT JOIN's right side.
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # The FROM table itself became a LEFT JOIN target, so promote an unmarked join to FROM.
        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            # NOTE(review): `only_old_joins` is a set, so with several candidates the pick
            # below is arbitrary — confirm this is acceptable.
            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        # NOTE(review): only the rebuilt LEFT joins are kept here; joins without any (+)
        # predicate are not carried over — verify this is intended.
        query.set("joins", list(new_joins.values()))

        # Drop the WHERE clause if every predicate moved into join conditions.
        if not where.this:
            where.pop()

    return expression
Remove join marks from an AST. This rule assumes that all marked columns are qualified.
If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.
For example,
SELECT * FROM a, b WHERE a.id = b.id(+) -- ... is converted to
SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this
Arguments:
- expression: The AST to remove join marks from.
Returns:
The AST with join marks removed.