# sqlglot.optimizer.simplify
1from __future__ import annotations 2 3import datetime 4import functools 5import itertools 6import typing as t 7from collections import deque 8from decimal import Decimal 9 10import sqlglot 11from sqlglot import Dialect, exp 12from sqlglot.helper import first, merge_ranges, while_changing 13from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope 14 15if t.TYPE_CHECKING: 16 from sqlglot.dialects.dialect import DialectType 17 18 DateTruncBinaryTransform = t.Callable[ 19 [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression] 20 ] 21 22# Final means that an expression should not be simplified 23FINAL = "final" 24 25 26class UnsupportedUnit(Exception): 27 pass 28 29 30def simplify( 31 expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None 32): 33 """ 34 Rewrite sqlglot AST to simplify expressions. 35 36 Example: 37 >>> import sqlglot 38 >>> expression = sqlglot.parse_one("TRUE AND TRUE") 39 >>> simplify(expression).sql() 40 'TRUE' 41 42 Args: 43 expression (sqlglot.Expression): expression to simplify 44 constant_propagation: whether the constant propagation rule should be used 45 46 Returns: 47 sqlglot.Expression: simplified expression 48 """ 49 50 dialect = Dialect.get_or_raise(dialect) 51 52 def _simplify(expression, root=True): 53 if expression.meta.get(FINAL): 54 return expression 55 56 # group by expressions cannot be simplified, for example 57 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 58 # the projection must exactly match the group by key 59 group = expression.args.get("group") 60 61 if group and hasattr(expression, "selects"): 62 groups = set(group.expressions) 63 group.meta[FINAL] = True 64 65 for e in expression.selects: 66 for node in e.walk(): 67 if node in groups: 68 e.meta[FINAL] = True 69 break 70 71 having = expression.args.get("having") 72 if having: 73 for node in having.walk(): 74 if node in groups: 75 having.meta[FINAL] = True 76 break 77 78 # Pre-order transformations 79 node = 
expression 80 node = rewrite_between(node) 81 node = uniq_sort(node, root) 82 node = absorb_and_eliminate(node, root) 83 node = simplify_concat(node) 84 node = simplify_conditionals(node) 85 86 if constant_propagation: 87 node = propagate_constants(node, root) 88 89 exp.replace_children(node, lambda e: _simplify(e, False)) 90 91 # Post-order transformations 92 node = simplify_not(node) 93 node = flatten(node) 94 node = simplify_connectors(node, root) 95 node = remove_complements(node, root) 96 node = simplify_coalesce(node) 97 node.parent = expression.parent 98 node = simplify_literals(node, root) 99 node = simplify_equality(node) 100 node = simplify_parens(node) 101 node = simplify_datetrunc(node, dialect) 102 node = sort_comparison(node) 103 node = simplify_startswith(node) 104 105 if root: 106 expression.replace(node) 107 return node 108 109 expression = while_changing(expression, _simplify) 110 remove_where_true(expression) 111 return expression 112 113 114def catch(*exceptions): 115 """Decorator that ignores a simplification function if any of `exceptions` are raised""" 116 117 def decorator(func): 118 def wrapped(expression, *args, **kwargs): 119 try: 120 return func(expression, *args, **kwargs) 121 except exceptions: 122 return expression 123 124 return wrapped 125 126 return decorator 127 128 129def rewrite_between(expression: exp.Expression) -> exp.Expression: 130 """Rewrite x between y and z to x >= y AND x <= z. 131 132 This is done because comparison simplification is only done on lt/lte/gt/gte. 
133 """ 134 if isinstance(expression, exp.Between): 135 negate = isinstance(expression.parent, exp.Not) 136 137 expression = exp.and_( 138 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 139 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), 140 copy=False, 141 ) 142 143 if negate: 144 expression = exp.paren(expression, copy=False) 145 146 return expression 147 148 149COMPLEMENT_COMPARISONS = { 150 exp.LT: exp.GTE, 151 exp.GT: exp.LTE, 152 exp.LTE: exp.GT, 153 exp.GTE: exp.LT, 154 exp.EQ: exp.NEQ, 155 exp.NEQ: exp.EQ, 156} 157 158 159def simplify_not(expression): 160 """ 161 Demorgan's Law 162 NOT (x OR y) -> NOT x AND NOT y 163 NOT (x AND y) -> NOT x OR NOT y 164 """ 165 if isinstance(expression, exp.Not): 166 this = expression.this 167 if is_null(this): 168 return exp.null() 169 if this.__class__ in COMPLEMENT_COMPARISONS: 170 return COMPLEMENT_COMPARISONS[this.__class__]( 171 this=this.this, expression=this.expression 172 ) 173 if isinstance(this, exp.Paren): 174 condition = this.unnest() 175 if isinstance(condition, exp.And): 176 return exp.paren( 177 exp.or_( 178 exp.not_(condition.left, copy=False), 179 exp.not_(condition.right, copy=False), 180 copy=False, 181 ) 182 ) 183 if isinstance(condition, exp.Or): 184 return exp.paren( 185 exp.and_( 186 exp.not_(condition.left, copy=False), 187 exp.not_(condition.right, copy=False), 188 copy=False, 189 ) 190 ) 191 if is_null(condition): 192 return exp.null() 193 if always_true(this): 194 return exp.false() 195 if is_false(this): 196 return exp.true() 197 if isinstance(this, exp.Not): 198 # double negation 199 # NOT NOT x -> x 200 return this.this 201 return expression 202 203 204def flatten(expression): 205 """ 206 A AND (B AND C) -> A AND B AND C 207 A OR (B OR C) -> A OR B OR C 208 """ 209 if isinstance(expression, exp.Connector): 210 for node in expression.args.values(): 211 child = node.unnest() 212 if isinstance(child, expression.__class__): 213 node.replace(child) 
214 return expression 215 216 217def simplify_connectors(expression, root=True): 218 def _simplify_connectors(expression, left, right): 219 if left == right: 220 return left 221 if isinstance(expression, exp.And): 222 if is_false(left) or is_false(right): 223 return exp.false() 224 if is_null(left) or is_null(right): 225 return exp.null() 226 if always_true(left) and always_true(right): 227 return exp.true() 228 if always_true(left): 229 return right 230 if always_true(right): 231 return left 232 return _simplify_comparison(expression, left, right) 233 elif isinstance(expression, exp.Or): 234 if always_true(left) or always_true(right): 235 return exp.true() 236 if is_false(left) and is_false(right): 237 return exp.false() 238 if ( 239 (is_null(left) and is_null(right)) 240 or (is_null(left) and is_false(right)) 241 or (is_false(left) and is_null(right)) 242 ): 243 return exp.null() 244 if is_false(left): 245 return right 246 if is_false(right): 247 return left 248 return _simplify_comparison(expression, left, right, or_=True) 249 250 if isinstance(expression, exp.Connector): 251 return _flat_simplify(expression, _simplify_connectors, root) 252 return expression 253 254 255LT_LTE = (exp.LT, exp.LTE) 256GT_GTE = (exp.GT, exp.GTE) 257 258COMPARISONS = ( 259 *LT_LTE, 260 *GT_GTE, 261 exp.EQ, 262 exp.NEQ, 263 exp.Is, 264) 265 266INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 267 exp.LT: exp.GT, 268 exp.GT: exp.LT, 269 exp.LTE: exp.GTE, 270 exp.GTE: exp.LTE, 271} 272 273NONDETERMINISTIC = (exp.Rand, exp.Randn) 274 275 276def _simplify_comparison(expression, left, right, or_=False): 277 if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): 278 ll, lr = left.args.values() 279 rl, rr = right.args.values() 280 281 largs = {ll, lr} 282 rargs = {rl, rr} 283 284 matching = largs & rargs 285 columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)} 286 287 if matching and columns: 288 try: 289 l = 
first(largs - columns) 290 r = first(rargs - columns) 291 except StopIteration: 292 return expression 293 294 if l.is_number and r.is_number: 295 l = float(l.name) 296 r = float(r.name) 297 elif l.is_string and r.is_string: 298 l = l.name 299 r = r.name 300 else: 301 l = extract_date(l) 302 if not l: 303 return None 304 r = extract_date(r) 305 if not r: 306 return None 307 308 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): 309 if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): 310 return left if (av > bv if or_ else av <= bv) else right 311 if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): 312 return left if (av < bv if or_ else av >= bv) else right 313 314 # we can't ever shortcut to true because the column could be null 315 if not or_: 316 if isinstance(a, exp.LT) and isinstance(b, GT_GTE): 317 if av <= bv: 318 return exp.false() 319 elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): 320 if av >= bv: 321 return exp.false() 322 elif isinstance(a, exp.EQ): 323 if isinstance(b, exp.LT): 324 return exp.false() if av >= bv else a 325 if isinstance(b, exp.LTE): 326 return exp.false() if av > bv else a 327 if isinstance(b, exp.GT): 328 return exp.false() if av <= bv else a 329 if isinstance(b, exp.GTE): 330 return exp.false() if av < bv else a 331 if isinstance(b, exp.NEQ): 332 return exp.false() if av == bv else a 333 return None 334 335 336def remove_complements(expression, root=True): 337 """ 338 Removing complements. 339 340 A AND NOT A -> FALSE 341 A OR NOT A -> TRUE 342 """ 343 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 344 complement = exp.false() if isinstance(expression, exp.And) else exp.true() 345 346 for a, b in itertools.permutations(expression.flatten(), 2): 347 if is_complement(a, b): 348 return complement 349 return expression 350 351 352def uniq_sort(expression, root=True): 353 """ 354 Uniq and sort a connector. 
355 356 C AND A AND B AND B -> A AND B AND C 357 """ 358 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 359 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ 360 flattened = tuple(expression.flatten()) 361 deduped = {gen(e): e for e in flattened} 362 arr = tuple(deduped.items()) 363 364 # check if the operands are already sorted, if not sort them 365 # A AND C AND B -> A AND B AND C 366 for i, (sql, e) in enumerate(arr[1:]): 367 if sql < arr[i][0]: 368 expression = result_func(*(e for _, e in sorted(arr)), copy=False) 369 break 370 else: 371 # we didn't have to sort but maybe we need to dedup 372 if len(deduped) < len(flattened): 373 expression = result_func(*deduped.values(), copy=False) 374 375 return expression 376 377 378def absorb_and_eliminate(expression, root=True): 379 """ 380 absorption: 381 A AND (A OR B) -> A 382 A OR (A AND B) -> A 383 A AND (NOT A OR B) -> A AND B 384 A OR (NOT A AND B) -> A OR B 385 elimination: 386 (A AND B) OR (A AND NOT B) -> A 387 (A OR B) AND (A OR NOT B) -> A 388 """ 389 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 390 kind = exp.Or if isinstance(expression, exp.And) else exp.And 391 392 for a, b in itertools.permutations(expression.flatten(), 2): 393 if isinstance(a, kind): 394 aa, ab = a.unnest_operands() 395 396 # absorb 397 if is_complement(b, aa): 398 aa.replace(exp.true() if kind == exp.And else exp.false()) 399 elif is_complement(b, ab): 400 ab.replace(exp.true() if kind == exp.And else exp.false()) 401 elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()): 402 a.replace(exp.false() if kind == exp.And else exp.true()) 403 elif isinstance(b, kind): 404 # eliminate 405 rhs = b.unnest_operands() 406 ba, bb = rhs 407 408 if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)): 409 a.replace(aa) 410 b.replace(aa) 411 elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)): 412 a.replace(ab) 
413 b.replace(ab) 414 415 return expression 416 417 418def propagate_constants(expression, root=True): 419 """ 420 Propagate constants for conjunctions in DNF: 421 422 SELECT * FROM t WHERE a = b AND b = 5 becomes 423 SELECT * FROM t WHERE a = 5 AND b = 5 424 425 Reference: https://www.sqlite.org/optoverview.html 426 """ 427 428 if ( 429 isinstance(expression, exp.And) 430 and (root or not expression.same_parent) 431 and sqlglot.optimizer.normalize.normalized(expression, dnf=True) 432 ): 433 constant_mapping = {} 434 for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)): 435 if isinstance(expr, exp.EQ): 436 l, r = expr.left, expr.right 437 438 # TODO: create a helper that can be used to detect nested literal expressions such 439 # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too 440 if isinstance(l, exp.Column) and isinstance(r, exp.Literal): 441 constant_mapping[l] = (id(l), r) 442 443 if constant_mapping: 444 for column in find_all_in_scope(expression, exp.Column): 445 parent = column.parent 446 column_id, constant = constant_mapping.get(column) or (None, None) 447 if ( 448 column_id is not None 449 and id(column) != column_id 450 and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) 451 ): 452 column.replace(constant.copy()) 453 454 return expression 455 456 457INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 458 exp.DateAdd: exp.Sub, 459 exp.DateSub: exp.Add, 460 exp.DatetimeAdd: exp.Sub, 461 exp.DatetimeSub: exp.Add, 462} 463 464INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 465 **INVERSE_DATE_OPS, 466 exp.Add: exp.Sub, 467 exp.Sub: exp.Add, 468} 469 470 471def _is_number(expression: exp.Expression) -> bool: 472 return expression.is_number 473 474 475def _is_interval(expression: exp.Expression) -> bool: 476 return isinstance(expression, exp.Interval) and extract_interval(expression) is not None 477 478 
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
    """
    Use the subtraction and addition properties of equality to simplify expressions:

        x + 1 = 3 becomes x = 2

    There are two binary operations in the above expression: + and =
    Here's how we reference all the operands in the code below:

          l     r
        x + 1 = 3
        a   b

    The constant operand is only moved across the comparison when it is the right
    operand of the arithmetic, or when the arithmetic is an addition (commutative):
    1 - x = 3 must not become x = 3 + 1, so such expressions are left unchanged.
    """
    if isinstance(expression, COMPARISONS):
        l, r = expression.left, expression.right

        if l.__class__ not in INVERSE_OPS:
            return expression

        if r.is_number:
            a_predicate = _is_number
            b_predicate = _is_number
        elif _is_date_literal(r):
            a_predicate = _is_date_literal
            b_predicate = _is_interval
        else:
            return expression

        if l.__class__ in INVERSE_DATE_OPS:
            l = t.cast(exp.IntervalOp, l)
            a = l.this
            b = l.interval()
        else:
            l = t.cast(exp.Binary, l)
            a, b = l.left, l.right

        if not a_predicate(a) and b_predicate(b):
            pass
        elif isinstance(l, exp.Add) and not a_predicate(b) and b_predicate(a):
            # Only addition is commutative: 1 + x = 3 -> x = 3 - 1 is valid, whereas
            # swapping the operands of a subtraction (or of an interval-based date
            # operator) would change the result.
            a, b = b, a
        else:
            return expression

        return expression.__class__(
            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
        )
    return expression


def simplify_literals(expression, root=True):
    """
    Fold literal operands:
    - non-connector binary chains are folded pairwise via _simplify_binary,
    - unary minus applied to a numeric literal becomes a (negated) numeric literal,
    - date +/- interval functions with literal operands are evaluated.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    if type(expression) in INVERSE_DATE_OPS:
        return _simplify_binary(expression, expression.this, expression.interval()) or expression

    return expression


# Operators for which a NULL operand does not make the result NULL; never folded to NULL.
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)


def _simplify_binary(expression, a, b):
    """
    Try to fold the operand pair (a, b) of the binary `expression`.

    Returns the folded node, or None when the pair cannot be folded.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, NULL_OK):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        num_a = int(a.name) if a.is_int else Decimal(a.name)
        num_b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(num_a + num_b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(num_a * num_b)

        # We only simplify Sub, Div if a and b have the same parent because they're not associative
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
                return None
            # Division by zero errors or yields NULL depending on the engine (and raises
            # decimal.DivisionByZero/InvalidOperation here), so leave it unfolded.
            if num_b == 0:
                return None
            return exp.Literal.number(num_a / num_b)

        boolean = eval_boolean(expression, num_a, num_b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif _is_date_literal(a) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
                return date_literal(a + b)
            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and _is_date_literal(b):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)
    elif _is_date_literal(a) and _is_date_literal(b):
        if isinstance(expression, exp.Predicate):
            a, b = extract_date(a), extract_date(b)
            boolean = eval_boolean(expression, a, b)
            if boolean:
                return boolean

    return None


def simplify_parens(expression):
    """Drop parentheses that operator precedence makes redundant (never around a SELECT)."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or (
            not isinstance(this, exp.Binary)
            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
        )
        or (isinstance(this, exp.Predicate) and not parent_is_predicate)
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression


def _is_nonnull_constant(expression: exp.Expression) -> bool:
    """Whether `expression` is a constant in exp.NONNULL_CONSTANTS or a date literal."""
    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)


def _is_constant(expression: exp.Expression) -> bool:
    """Whether `expression` is a constant in exp.CONSTANTS or a date literal."""
    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)


def simplify_coalesce(expression):
    """
    COALESCE(x) -> x (also when the first argument is a non-null constant), and

        COALESCE(x, c1) <op> c2
        -> (NOT x IS NULL AND COALESCE(x, c1) <op> c2) OR (x IS NULL AND c1 <op> c2)

    so that the constant branch can be folded by later rules.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_on_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_on_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # Keep the operands in their original order: for <, <=, >, >= swapping them would
    # invert the comparison (5 < COALESCE(x, 3) must become 5 < 3, not 3 < 5).
    if coalesce_on_left:
        constant_comparison = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        constant_comparison = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                constant_comparison,
                copy=False,
            ),
            copy=False,
        )
    )


CONCATS = (exp.Concat, exp.DPipe)


def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    # Runs of adjacent string literals are merged into one literal; other args are kept.
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args

    return concat_type(expressions=new_args, **args)


def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known."""
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot: always-false branches are popped from "ifs" below, and
        # removing items from the list being iterated would skip the following branch.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression


def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Reduces a prefix check to either TRUE or FALSE if both the string and the
    prefix are statically known.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if (
        isinstance(expression, exp.StartsWith)
        and expression.this.is_string
        and expression.expression.is_string
    ):
        return exp.convert(expression.name.startswith(expression.expression.name))

    return expression


DateRange = t.Tuple[datetime.date, datetime.date]


def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
    """
    Get the date range for a DATE_TRUNC equality comparison:

    Example:
        _datetrunc_range(date(2021, 1, 1), "year", dialect) == (date(2021, 1, 1), date(2022, 1, 1))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    floor = date_floor(date, unit, dialect)

    if date != floor:
        # This will always be False, except for NULL values.
        return None

    return floor, floor + interval(unit)


def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
    """Get the logical expression for a date range"""
    return exp.and_(
        left >= date_literal(drange[0]),
        left < date_literal(drange[1]),
        copy=False,
    )


def _datetrunc_eq(
    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
    """DATE_TRUNC(unit, left) = date  ->  left >= start AND left < end (None if unaligned)."""
    drange = _datetrunc_range(date, unit, dialect)
    if not drange:
        return None

    return _datetrunc_eq_expression(left, drange)


def _datetrunc_neq(
    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
    """
    DATE_TRUNC(unit, left) <> date  ->  (left < start OR left >= end).

    Returns None when `date` is not aligned to `unit`, leaving the comparison untouched.
    """
    drange = _datetrunc_range(date, unit, dialect)
    if not drange:
        return None

    # The complement of [start, end) is a disjunction (AND-ing the two bounds can never be
    # satisfied). Parenthesize so the OR keeps its meaning inside an enclosing AND.
    return exp.paren(
        exp.or_(
            left < date_literal(drange[0]),
            left >= date_literal(drange[1]),
            copy=False,
        ),
        copy=False,
    )


# Each lambda receives (truncated operand, comparison date, unit, dialect).
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    # trunc(l) < dt  ->  l < dt when dt is aligned, else l < start of the next unit after dt
    exp.LT: lambda l, dt, u, d: l
    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
    # trunc(l) > dt  ->  l >= start of the unit following the one containing dt
    exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
    # trunc(l) <= dt  ->  l < start of the unit following the one containing dt
    exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
    # trunc(l) >= dt  ->  l >= dt rounded up to a unit boundary
    exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)


def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Whether `left` is a DATE_TRUNC/TIMESTAMP_TRUNC and `right` a date literal."""
    return isinstance(left, DATETRUNCS) and _is_date_literal(right)


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
    comparison = expression.__class__

    if isinstance(expression, DATETRUNCS):
        # Truncating a date literal can be evaluated directly.
        date = extract_date(expression.this)
        if date and expression.unit:
            return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
    elif comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if not _is_datetrunc_predicate(l, r):
            return expression

        l = t.cast(exp.DateTrunc, l)
        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            l = t.cast(exp.DateTrunc, l)
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                drange = _datetrunc_range(date, unit, dialect)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            # Parenthesize so the OR keeps its meaning when the IN sits inside an AND.
            return exp.paren(
                exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False),
                copy=False,
            )

    return expression


def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """
    Normalize the operand order of a comparison: columns go to the left and constants to
    the right; otherwise operands are ordered by their pseudo-SQL text (see gen).
    """
    if expression.__class__ in COMPLEMENT_COMPARISONS:
        l, r = expression.this, expression.expression
        l_column = isinstance(l, exp.Column)
        r_column = isinstance(r, exp.Column)
        l_const = _is_constant(l)
        r_const = _is_constant(r)

        if (l_column and not r_column) or (r_const and not l_const):
            return expression
        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
                this=r, expression=l
            )
    return expression


# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}


def remove_where_true(expression):
    """Drop WHERE clauses and join conditions that are statically always true.

    Every WHERE whose condition is always true is removed. A join whose ON condition
    is always true, that has no USING / method, and whose (side, kind) pair is in
    ``JOINS`` is rewritten into a CROSS JOIN. The tree is mutated in place.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        ):
            join.args["on"].pop()
            join.set("side", None)
            join.set("kind", "CROSS")


def _is_zero(expression):
    """Return True if ``expression`` is a numeric literal whose value is zero."""
    if not (isinstance(expression, exp.Literal) and expression.is_number):
        return False
    try:
        return Decimal(expression.this) == 0
    except (ArithmeticError, TypeError):
        # decimal.InvalidOperation (an ArithmeticError) for text Decimal cannot parse
        return False


def always_true(expression):
    """Return a truthy value if ``expression`` is a literal that is statically true.

    TRUE and any non-zero literal count as true. A numeric zero literal does not,
    so that e.g. ``WHERE 0`` is not mistaken for a no-op filter and removed.
    """
    return (isinstance(expression, exp.Boolean) and expression.this) or (
        isinstance(expression, exp.Literal) and not _is_zero(expression)
    )


def always_false(expression):
    """Return True if ``expression`` is the FALSE literal or NULL."""
    return is_false(expression) or is_null(expression)


def is_complement(a, b):
    """Return True if ``b`` is ``NOT a``."""
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    """Return True if ``a`` is exactly the FALSE boolean literal."""
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    """Return True if ``a`` is exactly a NULL node."""
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """Evaluate the comparison node type of ``expression`` on Python values ``a`` and ``b``.

    Returns a TRUE/FALSE literal, or None for unsupported node types.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Coerce a date, datetime or ISO-format string to a ``date``; None if unparsable."""
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except ValueError:
        return None


def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Coerce a date, datetime or ISO-format string to a ``datetime``; None if unparsable."""
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError:
        return None


def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert ``value`` to a date (for DATE) or datetime (other temporal types).

    Returns None for falsy values and non-temporal target types.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None


def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate a cast of a literal (possibly nested) to a date/datetime.

    Handles ``Cast`` and format-less ``TsOrDsToDate``; returns None otherwise.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)


def _is_date_literal(expression: exp.Expression) -> bool:
    """Return True if ``expression`` statically evaluates to a date/datetime."""
    return extract_date(expression) is not None


def extract_interval(expression):
    """Convert an INTERVAL-like node (integer name + unit) to a relativedelta, or None."""
    try:
        n = int(expression.name)
        unit = expression.text("unit").lower()
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None


def date_literal(date):
    """Build ``CAST('<date>' AS DATETIME|DATE)`` depending on the Python type of ``date``."""
    return exp.cast(
        exp.Literal.string(date),
        (
            exp.DataType.Type.DATETIME
            if isinstance(date, datetime.datetime)
            else exp.DataType.Type.DATE
        ),
    )


def interval(unit: str, n: int = 1):
    """Return a ``relativedelta`` of ``n`` ``unit``s (requires python-dateutil).

    Raises:
        UnsupportedUnit: if ``unit`` is not a supported unit name.
    """
    from dateutil.relativedelta import relativedelta

    if unit == "year":
        return relativedelta(years=1 * n)
    if unit == "quarter":
        return relativedelta(months=3 * n)
    if unit == "month":
        return relativedelta(months=1 * n)
    if unit == "week":
        return relativedelta(weeks=1 * n)
    if unit == "day":
        return relativedelta(days=1 * n)
    if unit == "hour":
        return relativedelta(hours=1 * n)
    if unit == "minute":
        return relativedelta(minutes=1 * n)
    if unit == "second":
        return relativedelta(seconds=1 * n)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Truncate ``d`` down to the start of its enclosing ``unit`` period.

    Raises:
        UnsupportedUnit: if ``unit`` is not year/quarter/month/week/day.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10.
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday; WEEK_OFFSET shifts the week start from Monday
        # (assumed e.g. -1 for Sunday-start weeks — confirm against Dialect). The
        # modulo keeps the step back within [0, 6] so the floor never skips a whole
        # week or lands in the future.
        return d - datetime.timedelta(days=(d.weekday() - dialect.WEEK_OFFSET) % 7)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round ``d`` up to the start of the next ``unit`` period (unchanged if already aligned)."""
    floor = date_floor(d, unit, dialect)

    if floor == d:
        return d

    return floor + interval(unit)


def boolean_literal(condition):
    """Return the TRUE or FALSE literal for a Python truth value."""
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier, root=True):
    """Pairwise-simplify the operands of a flattened chain of ``expression``'s type.

    Each operand is tried against the remaining ones; when ``simplifier`` returns a
    new node, the pair is replaced by that node and processing restarts from it.
    If any operands were merged, the chain is rebuilt left-deep from the survivors.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    # Safe despite mutating the deque: we break out immediately.
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression


def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.
    """
    return Gen().gen(expression)


class Gen:
    """Minimal stack-based pseudo-SQL generator backing :func:`gen`.

    ``stack`` holds pending items (expressions, lists, or strings), pushed in reverse
    so that popping produces left-to-right output; strings are collected in ``sqls``.
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Drop the separator pushed before the first element. Only do so if
                # something was pushed, otherwise we would pop an unrelated item.
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push ``:key value`` pairs for the non-None args from ``arg_index`` onward.

        Returns True if anything was pushed.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
Common base class for all non-exit exceptions.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: dialect (name, class or instance) resolved via ``Dialect.get_or_raise``;
            it is passed to the DATE_TRUNC simplification rule.

    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Nodes flagged FINAL (see the GROUP BY handling below) are left untouched.
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains a GROUP BY key anywhere inside it.
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same for HAVING, which may also reference the grouping keys verbatim.
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into children (root=False) before the post-order rules run.
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The rules above may return a fresh, parentless node; restore the original
        # parent link because later rules (e.g. simplify_parens) inspect `.parent`.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    # NOTE(review): `while_changing` presumably re-applies `_simplify` until the tree
    # stops changing — confirm against sqlglot.helper.
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
- constant_propagation: whether the constant propagation rule should be used
Returns:
sqlglot.Expression: simplified expression
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    The decorated function must take the expression being simplified as its first
    argument. If it raises one of ``exceptions``, that expression is returned
    unchanged, which makes the rule best-effort.

    Args:
        *exceptions: exception types to swallow.

    Returns:
        A decorator wrapping a simplification function.
    """

    def decorator(func):
        # Preserve the rule's __name__/__doc__ (and __wrapped__) so introspection and
        # documentation tools show the real rule rather than this wrapper.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of `exceptions` are raised.
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand ``x BETWEEN y AND z`` into ``x >= y AND x <= z``.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is
    lowered into those. Non-BETWEEN nodes are returned unchanged.
    """
    if not isinstance(expression, exp.Between):
        return expression

    # Check for a NOT parent before building the replacement (which has no parent).
    negated = isinstance(expression.parent, exp.Not)

    subject = expression.this
    lower = exp.GTE(this=subject.copy(), expression=expression.args["low"])
    upper = exp.LTE(this=subject.copy(), expression=expression.args["high"])
    conjunction = exp.and_(lower, upper, copy=False)

    # NOT binds tighter than AND, so a negated BETWEEN needs explicit parentheses.
    return exp.paren(conjunction, copy=False) if negated else conjunction
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Push NOT inward and fold it over literals.

    De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y
    Also: NOT NULL -> NULL, NOT <comparison> -> complementary comparison,
    NOT TRUE -> FALSE, NOT FALSE -> TRUE, NOT NOT x -> x.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    complement = COMPLEMENT_COMPARISONS.get(operand.__class__)
    if complement:
        return complement(this=operand.this, expression=operand.expression)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # Flip the connector and negate both sides.
            flipped = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return exp.paren(
                flipped(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                )
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation: NOT NOT x -> x
        return operand.this
    return expression
De Morgan's laws:
NOT (x OR y) -> NOT x AND NOT y
NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Splice parenthesized children of the same connector type into the parent.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for child in expression.args.values():
        unwrapped = child.unnest()
        if isinstance(unwrapped, connector_type):
            child.replace(unwrapped)
    return expression
A AND (B AND C) -> A AND B AND C
A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """Fold AND/OR chains whose operands are literals, duplicates or comparable predicates.

    Each operand pair of the flattened chain is reduced via ``_flat_simplify``.
    """

    def _reduce_and(expression, left, right):
        # FALSE and NULL absorb an AND; a known-true operand is its identity.
        if is_false(left) or is_false(right):
            return exp.false()
        if is_null(left) or is_null(right):
            return exp.null()
        left_true, right_true = always_true(left), always_true(right)
        if left_true and right_true:
            return exp.true()
        if left_true:
            return right
        if right_true:
            return left
        return _simplify_comparison(expression, left, right)

    def _reduce_or(expression, left, right):
        # TRUE absorbs an OR; FALSE is its identity.
        if always_true(left) or always_true(right):
            return exp.true()
        left_false, right_false = is_false(left), is_false(right)
        if left_false and right_false:
            return exp.false()
        # Both sides NULL/FALSE but not both FALSE => at least one NULL => NULL.
        if always_false(left) and always_false(right):
            return exp.null()
        if left_false:
            return right
        if right_false:
            return left
        return _simplify_comparison(expression, left, right, or_=True)

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            return _reduce_and(expression, left, right)
        if isinstance(expression, exp.Or):
            return _reduce_or(expression, left, right)
        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_connectors, root)
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    # Only the top of a same-type chain is processed (unless this is the root).
    if not root and expression.same_parent:
        return expression

    pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(first, second) for first, second in pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression
Removing complements.
A AND NOT A -> FALSE
A OR NOT A -> TRUE
def uniq_sort(expression, root=True):
    """
    Deduplicate and sort the operands of a connector by their pseudo-SQL.

    C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # Keys are pseudo-SQL strings, so structurally identical operands collapse. Key
    # order follows first occurrence; the stored node is the last occurrence.
    by_sql = {gen(operand): operand for operand in operands}
    keys = list(by_sql)

    if any(later < earlier for earlier, later in zip(keys, keys[1:])):
        # A AND C AND B -> A AND B AND C
        return build(*(by_sql[key] for key in sorted(keys)), copy=False)
    if len(by_sql) < len(operands):
        # Already sorted, but duplicates were dropped.
        return build(*by_sql.values(), copy=False)
    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Rewrites are done in place by replacing sub-nodes with TRUE/FALSE or a shared
    operand; later passes (connector simplification, dedup) finish the cleanup.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the opposite connector: the type of nested groups we look into.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                # If `b` complements one operand of `a`, replace that operand with the
                # identity of `kind` (FALSE for OR, TRUE for AND) so it drops out.
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                # If `b`'s operands are a strict subset of `a`'s, `a` is redundant:
                # replace it with the identity of the outer connector.
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # (X op Y) and (X op NOT Y) share X: both collapse to X.
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
absorption:
A AND (A OR B) -> A
A OR (A AND B) -> A
A AND (NOT A OR B) -> A AND B
A OR (NOT A AND B) -> A OR B
elimination:
(A AND B) OR (A AND NOT B) -> A
(A OR B) AND (A OR NOT B) -> A
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # column -> (id of the defining column node, literal it is equated to).
        # A later `col = literal` overwrites an earlier one for the same column.
        constant_mapping = {}
        # IF(...) subtrees are pruned; presumably because equalities inside them are
        # conditional rather than top-level conjuncts — confirm.
        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                # Skip the defining occurrence itself (same object id) and `col IS NULL`.
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5
becomes
SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
def wrapped(expression, *args, **kwargs):
    """Closure produced by the `catch` decorator around a simplification rule.

    `func` and `exceptions` are free variables bound by the enclosing
    `catch(*exceptions)` / `decorator(func)` calls.
    """
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        # Best-effort rule: on an expected failure, leave the expression unchanged.
        return expression
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: `+` and `=`. Here's how we reference all the operands in the code below:
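  l     r
x + 1 = 3
a   b

(`l` is the left side of `=` (here `x + 1`) and `r` is the right side (`3`); `a` and `b` are the operands of `+` (`x` and `1`).)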
def simplify_literals(expression, root=True):
    """Fold constant sub-expressions.

    - Non-connector binary nodes are reduced pairwise via ``_simplify_binary``.
    - Negation of a numeric literal is folded into the literal (``-(5)`` -> ``-5``).
    - Inverse date operations are folded against their interval when possible.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and expression.this.is_number:
        digits = expression.this.name
        flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
        return exp.Literal.number(flipped)

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
    """Remove parentheses that do not affect evaluation order.

    Returns the unwrapped inner expression when the parentheses are redundant,
    otherwise the Paren node unchanged. Subqueries (SELECT) always keep theirs.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if not isinstance(this, exp.Select) and (
        # Parent is not an operator at all (or is itself a paren): no precedence issue.
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        # Inner node is atomic-ish (not binary), except NOT/IS under a predicate.
        or (
            not isinstance(this, exp.Binary)
            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
        )
        or (isinstance(this, exp.Predicate) and not parent_is_predicate)
        # Associative / higher-precedence arithmetic: (a + b) + c, (a * b) * c, (a * b) + c.
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression
def simplify_coalesce(expression):
    """Simplify COALESCE calls and comparisons against a COALESCE.

    - COALESCE(x) or COALESCE(<non-null constant>, ...) becomes its first argument
      (except inside a Hint).
    - ``COALESCE(x, ..., c, ...) <cmp> k`` with constants ``c`` and ``k`` is expanded to
      ``(NOT x IS NULL AND COALESCE(x, ...) <cmp> k) OR (x IS NULL AND c <cmp> k)``, with
      arguments from the first constant onward dropped; later passes fold it further.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    # NOTE(review): if `_is_constant` also accepts NULL, an explicit NULL argument would
    # be picked here and the arguments after it dropped — confirm this is intended.
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Mutates the COALESCE in place: keep only the arguments before the constant.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)
    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        separator_node, *operands = expression.expressions
        separator = separator_node.name
        extra_args = {}
    else:
        # Any other CONCATS member is rebuilt as a plain Concat.
        operands = expression.expressions
        separator = ""
        extra_args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    merged = []
    runs = itertools.groupby(operands or expression.flatten(), key=lambda node: node.is_string)
    for all_strings, run in runs:
        if all_strings:
            merged.append(exp.Literal.string(separator.join(node.name for node in run)))
        else:
            merged.extend(run)

    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if is_concat_ws:
        return exp.ConcatWs(expressions=[separator_node] + merged, **extra_args)
    return exp.Concat(expressions=merged, **extra_args)
Reduces all groups that contain string literals by concatenating them.
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    - CASE: branches whose condition is always false (FALSE/NULL) are removed; the
      first always-true branch short-circuits to its THEN value; if every branch is
      removed, the ELSE value (or NULL) is returned. ``CASE x WHEN v`` is first
      rewritten to ``CASE WHEN x = v``.
    - IF (not nested directly inside a CASE): folded to the true/false branch.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot: `case.pop()` below removes the branch from the live
        # `ifs` list, which would otherwise make the loop skip the following branch.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies expressions like IF, CASE if their condition is statically known.
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Fold STARTSWITH(string, prefix) to TRUE or FALSE when both arguments are
    string literals; any other expression is returned untouched.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    prefix = expression.expression
    if not (expression.this.is_string and prefix.is_string):
        return expression

    return exp.convert(expression.name.startswith(prefix.name))
Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.
Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
def wrapped(expression, *args, **kwargs):
    """Closure produced by the `catch` decorator around a simplification rule.

    `func` and `exceptions` are free variables bound by the enclosing
    `catch(*exceptions)` / `decorator(func)` calls.
    """
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        # Best-effort rule: on an expected failure, leave the expression unchanged.
        return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put a comparison into a canonical operand order.

    Columns go on the left and constants on the right. Otherwise, operands are
    ordered by pseudo-SQL. When swapping, the comparison operator is inverted
    (e.g. ``<`` becomes ``>``).
    """
    comparison_type = expression.__class__
    if comparison_type not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_constant = _is_constant(left)
    right_is_constant = _is_constant(right)

    # Already canonical: a column on the left, or a constant on the right.
    if (left_is_column and not right_is_column) or (right_is_constant and not left_is_constant):
        return expression

    should_swap = (
        (right_is_column and not left_is_column)
        or (left_is_constant and not right_is_constant)
        or gen(left) > gen(right)
    )
    if not should_swap:
        return expression

    flipped_type = INVERSE_COMPARISONS.get(comparison_type, comparison_type)
    return flipped_type(this=right, expression=left)
def remove_where_true(expression):
    """Strip always-true WHERE clauses and turn always-true joins into CROSS JOINs.

    Mutates ``expression`` in place; returns None.
    """
    for where_clause in expression.find_all(exp.Where):
        if always_true(where_clause.this):
            where_clause.pop()

    for join in expression.find_all(exp.Join):
        on_clause = join.args.get("on")
        # Only (side, kind) pairs listed in JOINS behave like a CROSS JOIN when
        # their condition is trivially true (e.g. LEFT JOIN does not).
        convertible = (
            always_true(on_clause)
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        )
        if not convertible:
            continue
        on_clause.pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """Apply the comparison denoted by ``expression``'s node type to ``a`` and ``b``.

    Returns a TRUE/FALSE literal, or None if the node type is not a supported comparison.
    """
    comparisons = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Coerce ``value`` to a ``datetime.date``.

    Datetimes are truncated to their date, dates are returned as-is, and anything
    else is parsed with ``datetime.fromisoformat``; unparsable strings yield None.
    """
    if isinstance(value, datetime.date):
        # datetime is a subclass of date, so check for it before returning as-is.
        return value.date() if isinstance(value, datetime.datetime) else value
    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed.date()
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Coerce ``value`` to a ``datetime.datetime``.

    Datetimes pass through, plain dates become midnight of that day, and anything
    else is parsed with ``datetime.fromisoformat``; unparsable strings yield None.
    """
    if isinstance(value, datetime.date):
        if isinstance(value, datetime.datetime):
            return value
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        parsed = None
    return parsed
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert ``value`` according to the temporal target type ``to``.

    Returns a ``date`` for DATE, a ``datetime`` for the other temporal types, and
    None for falsy values or non-temporal targets.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate a cast of a literal to a date/datetime.

    Handles ``CAST('<literal>' AS <type>)`` and format-less ``TsOrDsToDate``, including
    nested casts of those. Returns None when the value cannot be determined statically.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # Without an explicit format, TsOrDsToDate is evaluated as a cast to DATE.
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested cast: evaluate the inner one first.
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def interval(unit: str, n: int = 1):
    """Return a ``relativedelta`` spanning ``n`` units of ``unit`` (python-dateutil).

    Raises:
        UnsupportedUnit: if ``unit`` is not one of year/quarter/month/week/day/hour/
            minute/second.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, multiplier per unit)
    spec = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }.get(unit)

    if spec is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, multiplier = spec
    return relativedelta(**{keyword: multiplier * n})
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Truncate ``d`` down to the start of its enclosing ``unit`` period.

    Args:
        d: the date to truncate.
        unit: one of "year", "quarter", "month", "week", "day".
        dialect: supplies ``WEEK_OFFSET`` for week truncation.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10.
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday; WEEK_OFFSET shifts the week start from Monday
        # (assumed e.g. -1 for Sunday-start weeks — confirm against Dialect). The
        # modulo keeps the step back within [0, 6] so the floor never skips a whole
        # week (a Sunday in a Sunday-start week) or lands in the future.
        return d - datetime.timedelta(days=(d.weekday() - dialect.WEEK_OFFSET) % 7)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def gen(expression: t.Any) -> str:
    """Produce a cheap pseudo-SQL string for ``expression``.

    The string is only used as a key to sort and deduplicate expressions during
    optimization. Running the real SQL generator for that would be expensive, so a
    minimal generator (``Gen``) is used instead.
    """
    generator = Gen()
    return generator.gen(expression)
Simple pseudo sql generator for quickly generating sortable and uniq strings.
Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.
class Gen:
    """Stack-based pseudo-SQL generator backing ``gen``.

    The tree is walked iteratively rather than recursively.
    ``self.stack`` holds pending work items: expressions, lists, and string
    fragments. ``self.sqls`` collects the emitted fragments.

    Because the stack is LIFO, every handler pushes its pieces in *reverse*
    order: the last item pushed is the first one emitted.

    Dispatch rules for an expression with key ``k``:
    - If a ``k_sql`` method exists, it renders the expression.
    - Otherwise, ``exp.Func`` nodes use ``_function``.
    - Anything else falls back to its upper-cased key followed by its
      non-None args.
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Return the pseudo-SQL string for ``expression``.

        Rendering rules for popped stack items:
        - Expressions go to their handler.
        - Exact lists are emitted comma-separated, skipping None entries.
        - Any other non-None value is emitted via ``str``.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                self._push_separated(node, ",")
            elif node is not None:
                self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            # Quoted identifiers keep their case; unquoted ones are upper-cased.
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        self._push_separated(e.parts, ".")

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # Upper-case, consistent with ILIKE and every other keyword operator here.
        self._binary(e, " LIKE ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        # Safe even when `parts` is empty: nothing pushed above is popped off.
        self._push_separated(e.parts, ".")

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        # Arg values may include None; the list branch of `gen` skips them.
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push ``node``'s non-None args as a list of ``[":key", value]`` pairs.

        Args:
            node: expression whose args are rendered.
            arg_index: number of leading ``arg_types`` keys to skip. The
                calling handler typically renders those args itself.

        Returns:
            True if at least one arg was pushed, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False

    def _push_separated(self, items: t.Sequence[t.Any], sep: str) -> None:
        """Push the non-None ``items`` so they are emitted in order, joined by ``sep``.

        Nothing is pushed when ``items`` is empty or contains only None. In
        that case the caller's surrounding stack entries are left untouched.
        """
        pushed = False
        for item in reversed(items):
            if item is not None:
                self.stack.extend((item, sep))
                pushed = True
        if pushed:
            # Drop the separator pushed after the first item. Otherwise it
            # would be emitted before that item.
            self.stack.pop()
def gen(self, expression: exp.Expression) -> str:
    """Return the pseudo-SQL string for ``expression``.

    The tree is walked iteratively with an explicit LIFO stack, so handlers
    push their pieces in reverse order.

    Rendering rules for popped stack items:
    - Expressions dispatch to a ``<key>_sql`` handler when one exists.
    - Otherwise, ``exp.Func`` nodes use the generic function form.
    - Any other expression is rendered as its upper-cased key plus its args.
    - Exact lists are emitted comma-separated, skipping None entries.
    - Any other non-None value is emitted via ``str``.
    """
    self.stack = [expression]
    self.sqls.clear()

    while self.stack:
        node = self.stack.pop()

        if isinstance(node, exp.Expression):
            exp_handler_name = f"{node.key}_sql"

            if hasattr(self, exp_handler_name):
                getattr(self, exp_handler_name)(node)
            elif isinstance(node, exp.Func):
                self._function(node)
            else:
                key = node.key.upper()
                self.stack.append(f"{key} " if self._args(node) else key)
        elif type(node) is list:
            # Push non-None items in reverse so they pop in order,
            # each followed by ",".
            pushed = False
            for n in reversed(node):
                if n is not None:
                    self.stack.extend((n, ","))
                    pushed = True
            # Drop the separator that would otherwise precede the first item.
            # Only do this if something was pushed. A list holding only None
            # must not pop an unrelated entry (e.g. the ")" pushed just before a
            # function's argument list), nor pop from an empty stack.
            if pushed:
                self.stack.pop()
        elif node is not None:
            self.sqls.append(str(node))

    return "".join(self.sqls)
def anonymous_sql(self, e: exp.Anonymous) -> None:
    """Push an anonymous function call so it is emitted as ``NAME(args)``.

    Name rules:
    - A plain string name is upper-cased.
    - An identifier name is upper-cased unless it is quoted; quoted names
      are kept verbatim inside double quotes.

    Any other type of ``e.this`` raises ValueError.
    """
    func = e.this
    if isinstance(func, exp.Identifier):
        raw = func.this
        name = f'"{raw}"' if func.quoted else raw.upper()
    elif isinstance(func, str):
        name = func.upper()
    else:
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{func.__class__.__name__}'."
        )

    # LIFO stack: the name pops first, then "(", the argument list, then ")".
    self.stack.extend((")", e.expressions, "(", name))