Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name, name_sequence
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def preprocess(
 13    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
 14) -> t.Callable[[Generator, exp.Expression], str]:
 15    """
 16    Creates a new transform by chaining a sequence of transformations and converts the resulting
 17    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
 18    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
 19
 20    Args:
 21        transforms: sequence of transform functions. These will be called in order.
 22
 23    Returns:
 24        Function that can be used as a generator transform.
 25    """
 26
 27    def _to_sql(self, expression: exp.Expression) -> str:
 28        expression_type = type(expression)
 29
 30        expression = transforms[0](expression)
 31        for transform in transforms[1:]:
 32            expression = transform(expression)
 33
 34        _sql_handler = getattr(self, expression.key + "_sql", None)
 35        if _sql_handler:
 36            return _sql_handler(expression)
 37
 38        transforms_handler = self.TRANSFORMS.get(type(expression))
 39        if transforms_handler:
 40            if expression_type is type(expression):
 41                if isinstance(expression, exp.Func):
 42                    return self.function_fallback_sql(expression)
 43
 44                # Ensures we don't enter an infinite loop. This can happen when the original expression
 45                # has the same type as the final expression and there's no _sql method available for it,
 46                # because then it'd re-enter _to_sql.
 47                raise ValueError(
 48                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
 49                )
 50
 51            return transforms_handler(self, expression)
 52
 53        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
 54
 55    return _to_sql
 56
 57
 58def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
 59    """Unnests GENERATE_SERIES or SEQUENCE table references."""
 60    this = expression.this
 61    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
 62        unnest = exp.Unnest(expressions=[this])
 63        if expression.alias:
 64            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)
 65
 66        return unnest
 67
 68    return expression
 69
 70
 71def unalias_group(expression: exp.Expression) -> exp.Expression:
 72    """
 73    Replace references to select aliases in GROUP BY clauses.
 74
 75    Example:
 76        >>> import sqlglot
 77        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 78        'SELECT a AS b FROM x GROUP BY 1'
 79
 80    Args:
 81        expression: the expression that will be transformed.
 82
 83    Returns:
 84        The transformed expression.
 85    """
 86    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 87        aliased_selects = {
 88            e.alias: i
 89            for i, e in enumerate(expression.parent.expressions, start=1)
 90            if isinstance(e, exp.Alias)
 91        }
 92
 93        for group_by in expression.expressions:
 94            if (
 95                isinstance(group_by, exp.Column)
 96                and not group_by.table
 97                and group_by.name in aliased_selects
 98            ):
 99                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
100
101    return expression
102
103
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The original query gains a ROW_NUMBER() projection partitioned by the DISTINCT ON columns
    (ordered by the query's ORDER BY if present, otherwise by those same columns) and is wrapped
    in a subquery "_t" that is filtered on that row number being 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not on or not isinstance(on, exp.Tuple):
        return expression

    distinct.pop()
    partition_cols = on.expressions

    # Captured before the window projection is added, so the outer query doesn't select it.
    projections = expression.selects
    row_number_alias = find_new_name(expression.named_selects, "_row_number")
    row_number = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    order_by = expression.args.get("order")
    if order_by:
        window_order = order_by.pop()
    else:
        window_order = exp.Order(expressions=[col.copy() for col in partition_cols])
    row_number.set("order", window_order)

    expression.select(exp.alias_(row_number, row_number_alias), copy=False)

    inner = expression.subquery("_t", copy=False)
    outer = exp.select(*projections, copy=False).from_(inner, copy=False)
    return outer.where(exp.column(row_number_alias).eq(1), copy=False)
143
144
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new outer SELECT over the original query (as subquery "_t") filtered by the former
        QUALIFY condition, or the input expression if it's not a SELECT with a QUALIFY clause.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh "_c"-based alias so the outer query can refer to it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Builds the outer projection for an inner one: a column that carries over the
            # identifier's "quoted" flag when an identifier is available, else the bare name.
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # The outer projections are computed here, i.e. before any helper projections are added
        # to the inner query below, so those helpers are not selected by the outer query.
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # For star selects only windows are looked at; otherwise columns are considered too.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Inline aliased projections referenced inside the window
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window in the inner query and refer to it by alias in the filter
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window is the whole QUALIFY condition, so the filter becomes the column
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # A filtered column that isn't projected yet: add a copy to the inner query
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
207
208
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.

    Every `DataType` node found under `expression` has its `DataTypeParam` arguments dropped;
    any other type arguments are kept.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_args = []
        for arg in data_type.expressions:
            if not isinstance(arg, exp.DataTypeParam):
                kept_args.append(arg)

        data_type.set("expressions", kept_args)

    return expression
220
221
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    Only aliases of UNNESTs that sit directly under a FROM or JOIN of the given SELECT are
    considered. A column's `table` part is cleared if it matches one of them; otherwise its `db`
    part is cleared if that one matches.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    source_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            source_aliases.add(unnest.alias)

    if not source_aliases:
        return expression

    for col in expression.find_all(exp.Column):
        if col.table in source_aliases:
            col.set("table", None)
        elif col.db in source_aliases:
            col.set("db", None)

    return expression
240
241
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    An UNNEST in the FROM clause is replaced by a table wrapping EXPLODE (or POSEXPLODE, when the
    unnest has an offset). Each joined UNNEST is removed from the joins and re-added as one
    LATERAL VIEW per (unnested expression, alias column) pair.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression (modified in place).
    """
    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
            this, *expressions = unnest.expressions
            unnest.replace(
                exp.Table(
                    this=udtf(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        joins = expression.args.get("joins") or []

        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # mutating a list while iterating over it would skip the join right after each removal
        # (e.g. the second of two consecutive UNNEST joins would be left untouched).
        for join in list(joins):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                joins.remove(join)

                # zip stops at the shorter input, so an unnest without alias columns yields
                # no lateral views at all.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
282
283
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the value the generated position series starts at. It is also used to
            adjust the array sizes the series positions are compared against.

    Returns:
        A transform that rewrites SELECT expressions in place and returns them; expressions of
        any other type are returned untouched.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names that are already in use, so that generated names don't clash with them
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Picks a name based on `name` that isn't in `names` yet, and reserves it
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ARRAY_SIZE expression per exploded array; also doubles as a "seen any" flag
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")

            # The shared position series. Its "end" is only known after all projections
            # have been processed, so it is filled in at the bottom of this function.
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        # Single alias: reuse it for the exploded value
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # Two aliases: the first names the position, the second the value
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # No alias: wrap the projection in a placeholder alias, then look the
                        # explode up again inside the wrapper
                        # (presumably because alias_ copies its input here -- TODO confirm)
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # Builds IF(ARRAY_SIZE(COALESCE(arg, ARRAY())) = 0, ARRAY(arg[0]), arg)
                        # NOTE(review): presumably so that empty/NULL arrays still yield a row,
                        # as the OUTER variants are expected to -- confirm.
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    # A position column is needed for every explode, since it is what gets
                    # matched against the series below
                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The explode call is replaced by IF(<series pos> = <unnest pos>, <value>)
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # POSEXPLODE also projects the position, placed right after the value
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        # First explode encountered: bring the series into the query once,
                        # either as the FROM source or cross joined to the existing one
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # `size` is rebound here only; the entry appended to `arrays` above
                    # keeps referring to the unadjusted ARRAY_SIZE expression
                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the series position equals the unnest position, or where
                    # the series position is greater than `size` and the unnest position
                    # equals `size`
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series ends at the greatest array size, adjusted for the index offset
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
433
434
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    Applies to two-argument percentile expressions that aren't already wrapped in a WITHIN GROUP:
    the second argument becomes the percentile's only argument and the first one becomes the
    ORDER BY key of the new WITHIN GROUP clause.
    """
    if not isinstance(expression, exp.PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_value = expression.this.pop()
    quantile = expression.expression.pop()
    expression.set("this", quantile)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])
    return exp.WithinGroup(this=expression, expression=ordering)
448
449
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    A `percentile(q) WITHIN GROUP (ORDER BY x)` node is replaced by an `ApproxQuantile` whose
    input is the first ordered expression found and whose quantile is `q`.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not (isinstance(percentile, exp.PERCENTILES) and isinstance(expression.expression, exp.Order)):
        return expression

    quantile = percentile.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    replacement = exp.ApproxQuantile(this=ordered.this, quantile=quantile)
    return expression.replace(replacement)
462
463
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Only CTEs of a recursive WITH that don't declare columns yet are touched. For set operations
    the names are taken from the left-hand query. Projections without an output name get one
    from a "_c_"-prefixed sequence that is shared across the WITH clause.
    """
    if not isinstance(expression, exp.With) or not expression.recursive:
        return expression

    fresh_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        body = cte.this
        if isinstance(body, exp.SetOperation):
            body = body.this

        column_names = []
        for projection in body.selects:
            column_names.append(exp.to_identifier(projection.alias_or_name or fresh_name()))

        table_alias.set("columns", column_names)

    return expression
481
482
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Applies to CAST / TRY_CAST nodes whose operand is named "epoch" (case-insensitively) and
    whose target type is one of `exp.DataType.TEMPORAL_TYPES`.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        epoch_literal = exp.Literal.string("1970-01-01 00:00:00")
        expression.this.replace(epoch_literal)

    return expression
493
494
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Every SEMI / ANTI join that has an ON condition is removed from the SELECT and replaced by
    a `[NOT] EXISTS (SELECT 1 FROM <joined source> WHERE <on condition>)` predicate that is
    added to the WHERE clause. Joins without an ON condition are left as they are.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression (modified in place).
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot of the joins: join.pop() detaches the join from this SELECT,
        # and if that removal happens in place on the "joins" list, iterating over the live
        # list would skip the join that follows each converted one.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
510
511
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    The right-hand query is a copy of the input taken before the input's LIMIT is cleared, and
    has its WITH clause dropped; the input itself becomes the left-hand query.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_joins = []
    for position, join in enumerate(expression.args.get("joins") or []):
        if join.side == "FULL":
            full_joins.append((position, join))

    if len(full_joins) != 1:
        return expression

    position, left_join = full_joins[0]

    # The copy has to be taken before the original is modified below.
    right_query = expression.copy()
    expression.set("limit", None)

    left_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)
536
537
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    CTEs that were nested inside another CTE are inserted right before that CTE in the
    top-level WITH; all other nested CTEs are appended at its end.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    hoisted = expression.args.get("with")

    for nested in expression.find_all(exp.With):
        if nested.parent is expression:
            continue

        if not hoisted:
            # No top-level WITH yet: the first nested one is promoted as a whole.
            hoisted = nested.pop()
            expression.set("with", hoisted)
            continue

        if nested.recursive:
            hoisted.set("recursive", True)

        # Has to be looked up before the nested WITH is detached from the tree.
        enclosing_cte = nested.find_ancestor(exp.CTE)
        nested.pop()

        ctes = hoisted.expressions
        if enclosing_cte:
            at = ctes.index(enclosing_cte)
            ctes[at:at] = nested.expressions
            hoisted.set("expressions", ctes)
        else:
            hoisted.set("expressions", ctes + nested.expressions)

    return expression
575
576
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    Every node of the tree is handed to the optimizer's `ensure_bools` helper together with a
    callback that replaces a node by `<node> <> 0` when the node is a number, is of an unknown
    or numeric type (subquery predicates excluded), or is a column without a type.
    """
    # Imported under an alias, since the helper shares this function's name.
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _to_comparison(node: exp.Expression) -> None:
        numeric_like = node.is_number or (
            not isinstance(node, exp.SubqueryPredicate)
            and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
        )
        if numeric_like or (isinstance(node, exp.Column) and not node.type):
            node.replace(node.neq(0))

    for current in expression.walk():
        _canonicalize_ensure_bools(current, _to_comparison)

    return expression
596
597
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """
    Removes the qualifiers of every column found in the expression.

    All parts of a column except the last one (i.e. the table, db, catalog args) are popped.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression (modified in place).
    """
    for col in expression.find_all(exp.Column):
        qualifiers = col.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
605
606
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Removes unique column constraints from a CREATE statement.

    For each `UniqueColumnConstraint` found, its parent node is popped from the tree.

    Args:
        expression: the CREATE expression that will be transformed.

    Returns:
        The transformed expression (modified in place).
    """
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        owner = unique.parent
        if not owner:
            continue
        owner.pop()

    return expression
614
615
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Maps the creation of temporary tables to CREATE TEMPORARY VIEW.

    Args:
        expression: the CREATE expression that will be transformed.
        tmp_storage_provider: called with the expression when it creates a temporary table
            without a defining query (i.e. it's not a CTAS); its result is returned as is.
            Defaults to the identity function.

    Returns:
        A new CREATE TEMPORARY VIEW for temporary CTAS statements, the result of
        `tmp_storage_provider` for other temporary tables, and the input expression otherwise.
    """
    assert isinstance(expression, exp.Create)

    props = expression.args.get("properties")
    candidates = props.expressions if props else []
    is_temporary = any(isinstance(candidate, exp.TemporaryProperty) for candidate in candidates)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    query = expression.expression
    if not query:
        return tmp_storage_provider(expression)

    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=query,
    )
638
639
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Column names are matched case-insensitively. Only CREATE TABLE / VIEW statements that
    define a schema are affected.
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if not (partitioned_by and partitioned_by.this):
        return expression
    if isinstance(partitioned_by.this, exp.Schema):
        return expression

    table_schema = expression.this
    partition_names = {name.name.upper() for name in partitioned_by.this.expressions}

    partition_columns = []
    for column_def in table_schema.expressions:
        if column_def.name.upper() in partition_names:
            partition_columns.append(column_def)

    remaining = []
    for column_def in table_schema.expressions:
        if column_def not in partition_columns:
            remaining.append(column_def)

    table_schema.set("expressions", remaining)
    partitioned_by.replace(
        exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns))
    )
    expression.set("this", table_schema)

    return expression
661
662
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When PARTITIONED BY holds a schema made up exclusively of typed column definitions, those
    definitions are appended to the CREATE statement's own schema and the property is reduced
    to a tuple of the column identifiers.
    """
    assert isinstance(expression, exp.Create)

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if not partitioned_by:
        return expression

    partition_schema = partitioned_by.this
    if not partition_schema or not isinstance(partition_schema, exp.Schema):
        return expression

    column_defs = partition_schema.expressions
    if not all(isinstance(cd, exp.ColumnDef) and cd.kind for cd in column_defs):
        return expression

    partition_names = exp.Tuple(expressions=[exp.to_identifier(cd.this) for cd in column_defs])

    target = expression.this
    for column_def in column_defs:
        target.append("expressions", column_def)

    partitioned_by.set("this", partition_names)
    return expression
686
687
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

    Each `PropertyEQ` argument of a struct is replaced by its value aliased with its key; all
    other arguments are kept as they are.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    fields = []
    for field in expression.expressions:
        if isinstance(field, exp.PropertyEQ):
            field = exp.alias_(field.expression, field.this)
        fields.append(field)

    expression.set("expressions", fields)
    return expression
700
701
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.

    Raises:
        AssertionError: if a marked column is not part of a binary predicate, if the marker
            appears on both sides of a predicate, if a marked column is not qualified, if the
            marked columns of a predicate belong to more than one table, or if no table is
            left to be used in the FROM clause.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        # Only queries that have both a WHERE clause and joins are rewritten
        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            # Closest enclosing predicate (or SELECT, if there's none in between)
            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # The predicate is taken out of its current place, to be used as a join condition
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            # Clear the marks and collect the tables the marked columns belong to
            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # NOTE(review): `col` is the variable of the loop above, i.e. the last marked
            # column; the assertion above guarantees all marked columns share its table.
            # The source to join is that table's existing join, or the FROM source if the
            # table has no join of its own.
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                # Several predicates for the same source are AND-ed into a single ON condition
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # The FROM source itself ended up being LEFT JOINed, so another source has to take
        # its place in the FROM clause
        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            # NOTE(review): only_old_joins is a set, so when it holds several names the one
            # picked here presumably isn't deterministic -- confirm.
            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        # The query's joins are replaced wholesale by the ones built above
        query.set("joins", list(new_joins.values()))

        # Drop the WHERE clause if no condition is left in it
        if not where.this:
            where.pop()

    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is accepted, in which case the expression is left untouched before the
            SQL handler lookup.

    Returns:
        Function that can be used as a generator transform. The returned function raises a
        `ValueError` when no handler can be found for the transformed expression, or when
        dispatching through `Generator.TRANSFORMS` would re-enter this same transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remember the incoming type so that we can detect a transform chain that did not
        # change the expression's type (see the infinite loop guard below).
        expression_type = type(expression)

        # A single loop (instead of special-casing transforms[0]) also supports an empty list
        for transform in transforms:
            expression = transform(expression)

        # Prefer the generator's dedicated "<key>_sql" method, if there is one
        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.

def unnest_generate_series( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Unnests GENERATE_SERIES or SEQUENCE table references.

    A `Table` whose source is a `GenerateSeries` node is replaced by an `Unnest` of that
    node. If the table was aliased, the unnest gets the table alias "_u" and the original
    alias becomes its single column alias. Anything else is returned unchanged.
    """
    series = expression.this
    if not isinstance(expression, exp.Table) or not isinstance(series, exp.GenerateSeries):
        return expression

    unnested = exp.Unnest(expressions=[series])
    if not expression.alias:
        return unnested

    return exp.alias_(unnested, alias="_u", table=[expression.alias], copy=False)

Unnests GENERATE_SERIES or SEQUENCE table references.

def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 72def unalias_group(expression: exp.Expression) -> exp.Expression:
 73    """
 74    Replace references to select aliases in GROUP BY clauses.
 75
 76    Example:
 77        >>> import sqlglot
 78        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
 79        'SELECT a AS b FROM x GROUP BY 1'
 80
 81    Args:
 82        expression: the expression that will be transformed.
 83
 84    Returns:
 85        The transformed expression.
 86    """
 87    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
 88        aliased_selects = {
 89            e.alias: i
 90            for i, e in enumerate(expression.parent.expressions, start=1)
 91            if isinstance(e, exp.Alias)
 92        }
 93
 94        for group_by in expression.expressions:
 95            if (
 96                isinstance(group_by, exp.Column)
 97                and not group_by.table
 98                and group_by.name in aliased_selects
 99            ):
100                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
101
102    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The query is wrapped in a subquery named "_t" that additionally projects a ROW_NUMBER()
    window partitioned by the DISTINCT ON columns, and the outer query keeps only the rows
    where that number equals 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    distinct_on = distinct.args.get("on")
    if not distinct_on or not isinstance(distinct_on, exp.Tuple):
        return expression

    # Detach the DISTINCT clause; its ON columns become the window's partition
    partition_cols = distinct.pop().args["on"].expressions
    projections = expression.selects
    row_number_alias = find_new_name(expression.named_selects, "_row_number")
    row_number = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    # Reuse the query's ORDER BY when present, otherwise order by the partition columns
    order = expression.args.get("order")
    if order:
        window_order = order.pop()
    else:
        window_order = exp.Order(expressions=[col.copy() for col in partition_cols])
    row_number.set("order", window_order)

    expression.select(exp.alias_(row_number, row_number_alias), copy=False)

    outer_query = exp.select(*projections, copy=False)
    outer_query = outer_query.from_(expression.subquery("_t", copy=False), copy=False)
    return outer_query.where(exp.column(row_number_alias).eq(1), copy=False)

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if not isinstance(expression, exp.Select) or not expression.args.get("qualify"):
        return expression

    # Every projection needs a name so that the outer query can refer to it
    used_names = set(expression.named_selects)
    for projection in expression.selects:
        if projection.alias_or_name:
            continue

        generated = find_new_name(used_names, "_c")
        projection.replace(exp.alias_(projection, generated))
        used_names.add(generated)

    def _outer_projection(projection: exp.Expression) -> str | exp.Column:
        # Preserve the quoting of the inner projection's identifier in the outer reference
        name = projection.alias_or_name
        identifier = projection.args.get("alias") or projection.this
        if isinstance(identifier, exp.Identifier):
            return exp.column(name, quoted=identifier.args.get("quoted"))
        return name

    outer_query = exp.select(*[_outer_projection(p) for p in expression.selects])
    condition = expression.args["qualify"].pop().this
    aliased_expressions = {
        p.alias: p.this for p in expression.selects if isinstance(p, exp.Alias)
    }

    # With a star projection every column is already selected, so only windows are moved
    if expression.is_star:
        candidate_types = exp.Window
    else:
        candidate_types = (exp.Window, exp.Column)

    for candidate in condition.find_all(candidate_types):
        if isinstance(candidate, exp.Window):
            if aliased_expressions:
                # Inline projection aliases referenced inside the window
                for reference in candidate.find_all(exp.Column):
                    replacement = aliased_expressions.get(reference.name)
                    if replacement:
                        reference.replace(replacement)

            window_alias = find_new_name(expression.named_selects, "_w")
            expression.select(exp.alias_(candidate, window_alias), copy=False)
            window_reference = exp.column(window_alias)

            if isinstance(candidate.parent, exp.Qualify):
                # The window is the whole QUALIFY condition
                condition = window_reference
            else:
                candidate.replace(window_reference)
        elif candidate.name not in expression.named_selects:
            expression.select(candidate.copy(), copy=False)

    subquery = expression.subquery(alias="_t", copy=False)
    return outer_query.from_(subquery, copy=False).where(condition, copy=False)

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions
    by dropping every `DataTypeParam` child of each `DataType` node found.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = [arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)]
        data_type.set("expressions", kept)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    Only UNNESTs that sit directly under a FROM or JOIN of the given SELECT's scope are
    considered. Columns whose table (or, failing that, db) part matches one of their aliases
    have that part cleared.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if not unnest_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in unnest_aliases:
            column.set("table", None)
        elif column.db in unnest_aliases:
            column.set("db", None)

    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    For a SELECT, an UNNEST in the FROM clause is replaced by a table wrapping EXPLODE
    (or POSEXPLODE when the unnest has an offset), and every join whose source is an UNNEST
    is removed and re-added as one LATERAL VIEW per (unnested expression, alias column) pair.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression (modified in place).
    """
    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
            this, *expressions = unnest.expressions
            unnest.replace(
                exp.Table(
                    this=udtf(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        # Iterate over a snapshot: matching joins are removed from the live list below, and
        # removing from a list while iterating over it skips the element after each removal
        # (e.g. the second of two consecutive UNNEST joins would be left untouched).
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                # NOTE(review): when the unnest has no alias, zip() yields nothing, so the join
                # is dropped without any lateral being added -- confirm this is intended.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode.

def explode_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the start value of the generated position series; presumably the array
            index base of the target dialect (TODO confirm against the callers).

    Returns:
        A transform that rewrites every projection of a SELECT containing an EXPLODE into a
        cross join against an UNNEST of the exploded array, aligned with a generated series
        of positions. Other expressions are returned unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if not isinstance(expression, exp.Select):
            return expression

        from sqlglot.optimizer.scope import Scope

        taken_select_names = set(expression.named_selects)
        taken_source_names = {name for name, _ in Scope(expression).references}

        def new_name(names: t.Set[str], name: str) -> str:
            # Reserve a fresh name so that later calls can't hand it out again
            name = find_new_name(names, name)
            names.add(name)
            return name

        arrays: t.List[exp.Condition] = []
        series_alias = new_name(taken_select_names, "pos")
        series = exp.alias_(
            exp.Unnest(expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]),
            new_name(taken_source_names, "_u"),
            table=[series_alias],
        )

        def series_column(series_table: t.Any) -> exp.Column:
            # Fresh reference to the generated series' position column
            return exp.column(series_alias, table=series_table)

        def positions_match(series_table: t.Any, position: t.Any, source: str) -> exp.Expression:
            # <series>.<pos> = <unnest source>.<pos>
            return series_column(series_table).eq(exp.column(position, table=source))

        # we use list here because expression.selects is mutated inside the loop
        for select in list(expression.selects):
            explode = select.find(exp.Explode)
            if not explode:
                continue

            pos_alias = ""
            explode_alias = ""

            if isinstance(select, exp.Alias):
                explode_alias = select.args["alias"]
                alias = select
            elif isinstance(select, exp.Aliases):
                pos_alias = select.aliases[0]
                explode_alias = select.aliases[1]
                alias = select.replace(exp.alias_(select.this, "", copy=False))
            else:
                # alias_ copies the projection here, so the explode node has to be looked up again
                alias = select.replace(exp.alias_(select, ""))
                explode = alias.find(exp.Explode)
                assert explode

            is_posexplode = isinstance(explode, exp.Posexplode)
            explode_arg = explode.this

            if isinstance(explode, exp.ExplodeOuter):
                bracket = explode_arg[0]
                bracket.set("safe", True)
                bracket.set("offset", True)
                explode_arg = exp.func(
                    "IF",
                    exp.func("ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())).eq(0),
                    exp.array(bracket, copy=False),
                    explode_arg,
                )

            # This ensures that we won't use [POS]EXPLODE's argument as a new selection
            if isinstance(explode_arg, exp.Column):
                taken_select_names.add(explode_arg.output_name)

            unnest_source_alias = new_name(taken_source_names, "_u")

            if not explode_alias:
                explode_alias = new_name(taken_select_names, "col")

                if is_posexplode:
                    pos_alias = new_name(taken_select_names, "pos")

            if not pos_alias:
                pos_alias = new_name(taken_select_names, "pos")

            alias.set("alias", exp.to_identifier(explode_alias))

            series_table_alias = series.args["alias"].this

            # The exploded value is only emitted on the row whose series position matches
            explode.replace(
                exp.If(
                    this=positions_match(series_table_alias, pos_alias, unnest_source_alias),
                    true=exp.column(explode_alias, table=unnest_source_alias),
                )
            )

            if is_posexplode:
                # POSEXPLODE also projects the position, right after the exploded value
                expressions = expression.expressions
                insert_at = expressions.index(alias) + 1
                position_projection = exp.If(
                    this=positions_match(series_table_alias, pos_alias, unnest_source_alias),
                    true=exp.column(pos_alias, table=unnest_source_alias),
                ).as_(pos_alias)
                expressions.insert(insert_at, position_projection)
                expression.set("expressions", expressions)

            if not arrays:
                # The series source is attached once, when the first explode is found
                if expression.args.get("from"):
                    expression.join(series, copy=False, join_type="CROSS")
                else:
                    expression.from_(series, copy=False)

            size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
            arrays.append(size)

            # trino doesn't support left join unnest with on conditions
            # if it did, this would be much simpler
            unnest_source = exp.alias_(
                exp.Unnest(
                    expressions=[explode_arg.copy()],
                    offset=exp.to_identifier(pos_alias),
                ),
                unnest_source_alias,
                table=[explode_alias],
            )
            expression.join(unnest_source, join_type="CROSS", copy=False)

            if index_offset != 1:
                size = size - 1

            matched = positions_match(series_table_alias, pos_alias, unnest_source_alias)
            past_end = (series_column(series_table_alias) > size).and_(
                exp.column(pos_alias, table=unnest_source_alias).eq(size)
            )
            expression.where(matched.or_(past_end), copy=False)

        if arrays:
            # The series has to cover the longest of the exploded arrays
            end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

            if index_offset != 1:
                end = end - (1 - index_offset)
            series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest

Convert explode/posexplode into unnest.

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    Applies to two-argument percentile functions that aren't already wrapped in a
    WITHIN GROUP: the first argument moves into the ORDER BY of the new WITHIN GROUP and the
    second argument becomes the function's only argument.
    """
    if not isinstance(expression, exp.PERCENTILES):
        return expression
    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_value = expression.this.pop()
    remaining_arg = expression.expression.pop()
    expression.set("this", remaining_arg)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])
    return exp.WithinGroup(this=expression, expression=ordering)

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    A WITHIN GROUP that wraps a percentile function and an ORDER BY is replaced by an
    `ApproxQuantile` built from the ordered value and the percentile's argument.
    """
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    quantile = expression.this.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=quantile)
    return expression.replace(approx_quantile)

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Only CTEs of a recursive WITH that don't declare columns yet are touched. For set
    operations the names are taken from the left-hand query, and unnamed projections get
    generated "_c_"-prefixed names.
    """
    if not isinstance(expression, exp.With) or not expression.recursive:
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.SetOperation):
            query = query.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or next_name())
            for projection in query.selects
        ]
        table_alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Applies to CAST / TRY_CAST expressions whose operand is named "epoch" (case-insensitive)
    and whose target type is temporal; the operand becomes the string '1970-01-01 00:00:00'.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    if expression.name.lower() == "epoch" and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Every SEMI / ANTI join of a SELECT that has an ON condition is removed and replaced by
    an `EXISTS (SELECT 1 FROM <join source> WHERE <on>)` predicate (negated for ANTI joins)
    that is added to the query's WHERE clause.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression (modified in place).
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: join.pop() below detaches the join from the query while we
        # loop, and if the live list is mutated the join following each removal is skipped.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_outer_joins = []
    for position, join in enumerate(expression.args.get("joins") or []):
        if join.side == "FULL":
            full_outer_joins.append((position, join))

    if len(full_outer_joins) != 1:
        return expression

    # The copy is taken before the limit is cleared, so only the right-hand query keeps it
    right_query = expression.copy()
    expression.set("limit", None)

    position, left_join = full_outer_joins[0]
    left_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")

    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: promote the first nested one as-is
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
            continue

        if inner_with.recursive:
            top_level_with.set("recursive", True)

        parent_cte = inner_with.find_ancestor(exp.CTE)
        inner_with.pop()
        hoisted_ctes = inner_with.expressions

        if parent_cte:
            # CTEs nested inside another CTE are placed right before that CTE
            top_level_ctes = top_level_with.expressions
            insert_at = top_level_ctes.index(parent_cte)
            top_level_ctes[insert_at:insert_at] = hoisted_ctes
            top_level_with.set("expressions", top_level_ctes)
        else:
            top_level_with.set("expressions", top_level_with.expressions + hoisted_ctes)

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    Every node of the tree is handed to the optimizer's `ensure_bools` helper together with
    a callback that rewrites a qualifying node `x` as `x <> 0`.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        if node.is_number:
            needs_comparison = True
        elif not isinstance(node, exp.SubqueryPredicate) and node.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        ):
            needs_comparison = True
        else:
            # Columns without type information are compared against 0 as well
            needs_comparison = isinstance(node, exp.Column) and not node.type

        if needs_comparison:
            node.replace(node.neq(0))

    for node in expression.walk():
        ensure_bools(node, _ensure_bool)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Remove every qualifier from the columns in the expression, keeping only their last part."""
    for column in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
def remove_unique_constraints( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Remove the parent node of every `UniqueColumnConstraint` found in a CREATE expression.

    Asserts that the given expression is an `exp.Create`.
    """
    assert isinstance(expression, exp.Create)

    for unique_constraint in expression.find_all(exp.UniqueColumnConstraint):
        owner = unique_constraint.parent
        if owner:
            owner.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.Expression, tmp_storage_provider: Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression] = <function <lambda>>) -> sqlglot.expressions.Expression:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Map CREATE TABLE statements that carry a temporary property to CREATE TEMPORARY VIEW.

    Args:
        expression: the CREATE expression to transform (asserted to be an `exp.Create`).
        tmp_storage_provider: applied to temporary CREATE TABLE statements that have no
            defining query; by default the expression is returned as-is.

    Returns:
        A new CREATE TEMPORARY VIEW for a temporary CTAS, the result of `tmp_storage_provider`
        for a temporary table without a query, and the original expression otherwise.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_list = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in property_list)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if not expression.expression:
        return tmp_storage_provider(expression)

    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Args:
        expression: The expression to transform; it must be an `exp.Create`.

    Returns:
        The same expression, modified in place when the rewrite applies.
    """
    assert isinstance(expression, exp.Create)

    table_schema = expression.this
    if not isinstance(table_schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if not partitioned_by or not partitioned_by.this:
        return expression

    # Nothing to do if the property already holds a schema
    if isinstance(partitioned_by.this, exp.Schema):
        return expression

    # Names are matched case-insensitively
    partition_names = {item.name.upper() for item in partitioned_by.this.expressions}
    partition_defs = [
        definition
        for definition in table_schema.expressions
        if definition.name.upper() in partition_names
    ]
    remaining_defs = [
        definition for definition in table_schema.expressions if definition not in partition_defs
    ]

    table_schema.set("expressions", remaining_defs)
    partitioned_by.replace(
        exp.PartitionedByProperty(this=exp.Schema(expressions=partition_defs))
    )
    expression.set("this", table_schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When the PARTITIONED BY property holds a schema made only of typed column
    definitions, those definitions are appended to the table's schema and the
    property is rewritten to a tuple of the column identifiers.

    Args:
        expression: The expression to transform; it must be an `exp.Create`.

    Returns:
        The same expression, modified in place when the rewrite applies.
    """
    assert isinstance(expression, exp.Create)

    partitioned_by = expression.find(exp.PartitionedByProperty)
    if not partitioned_by:
        return expression

    partition_schema = partitioned_by.this
    if not partition_schema or not isinstance(partition_schema, exp.Schema):
        return expression

    # Only rewrite when every partition entry is a column definition with a type
    column_defs = partition_schema.expressions
    if not all(
        isinstance(definition, exp.ColumnDef) and definition.kind for definition in column_defs
    ):
        return expression

    partition_names = exp.Tuple(
        expressions=[exp.to_identifier(definition.this) for definition in column_defs]
    )

    table_schema = expression.this
    for definition in column_defs:
        table_schema.append("expressions", definition)

    partitioned_by.set("this", partition_names)

    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

    Args:
        expression: The expression to transform. Anything other than an `exp.Struct`
            is returned unchanged.

    Returns:
        The same expression, with each `exp.PropertyEQ` argument replaced by an alias
        built from its value and its key.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    aliased_args = []
    for arg in expression.expressions:
        if isinstance(arg, exp.PropertyEQ):
            arg = exp.alias_(arg.expression, arg.this)
        aliased_args.append(arg)

    expression.set("expressions", aliased_args)
    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def eliminate_join_marks( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Scopes that contain no marked columns are left untouched. Joins that are not rebuilt
    from a marked predicate are kept as they are and placed after the new LEFT JOINs.

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.

    Raises:
        AssertionError: If a marked column is not part of a binary predicate, is not
            qualified with a table, if marks appear on both sides of a predicate or span
            several tables within it, or if no table is left to use in the FROM clause.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            # Marks are cleared as predicates are processed, so columns that shared
            # a predicate with an earlier marked column are skipped here
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Detach the predicate from the WHERE clause; it becomes the join condition
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # Exactly one table was marked (asserted above). If it is not one of the
            # existing joins, the FROM clause's source is used as the join target.
            (marked_table,) = marked_column_tables
            join_this = old_joins.get(marked_table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # Without marked columns there is nothing to rewrite; overwriting "joins" with
        # an empty list here would silently drop the query's existing joins
        if not new_joins:
            continue

        from_name = query_from.alias_or_name
        if from_name in new_joins:
            # The FROM table became the target of a LEFT JOIN, so another table has to
            # take its place. Pick the first untouched join, in query order, so that
            # the output is deterministic.
            new_from_name = next((name for name in old_joins if name not in new_joins), None)
            assert (
                new_from_name is not None
            ), "Cannot determine which table to use in the new FROM clause"

            query.set("from", exp.From(this=old_joins[new_from_name].this))
            from_name = new_from_name

        # Preserve the joins that were neither rebuilt nor promoted to the FROM clause
        for name, join in old_joins.items():
            if name not in new_joins and name != from_name:
                new_joins[name] = join

        query.set("joins", list(new_joins.values()))

        # Drop the WHERE clause if every predicate was moved into a join condition
        if not where.this:
            where.pop()

    return expression

Remove join marks from an AST. This rule assumes that all marked columns are qualified. If this does not hold for a query, consider running sqlglot.optimizer.qualify first.

For example, `SELECT * FROM a, b WHERE a.id = b.id(+)` is converted to `SELECT * FROM a LEFT JOIN b ON a.id = b.id`.

Arguments:
  • expression: The AST to remove join marks from.
Returns:

The AST with join marks removed.