sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

    from sqlglot.optimizer.annotate_types import TypeAnnotator

    AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]

logger = logging.getLogger("sqlglot")

UNESCAPED_SEQUENCES = {
    "\\a": "\a",
    "\\b": "\b",
    "\\f": "\f",
    "\\n": "\n",
    "\\r": "\r",
    "\\t": "\t",
    "\\v": "\v",
    "\\\\": "\\",
}


def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
    return lambda self, e: self._annotate_with_type(e, data_type)


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MATERIALIZE = "materialize"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    RISINGWAVE = "risingwave"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
        klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
        klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.jsonpath_tokenizer_class = klass.__dict__.get(
            "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                **UNESCAPED_SEQUENCES,
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "athena", "presto", "trino"):
            klass.generator_class.TRY_SUPPORTED = False
            klass.generator_class.SUPPORTS_UESCAPE = False

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if enum not in ("", "doris", "mysql"):
            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass


class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    SUPPORTS_COLUMN_JOIN_MARKS = False
    """Whether the old-style outer join (+) syntax is supported."""

    COPY_PARAMS_ARE_CSV = True
    """Whether COPY statement parameters are separated by comma or whitespace."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.

    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    HEX_LOWERCASE = False
    """Whether the `HEX` function returns a lowercase hexadecimal string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    FORCE_EARLY_ALIAS_REF_EXPANSION = False
    """
    Whether alias reference expansion (_expand_alias_refs()) should run before column
    qualification (_qualify_columns()).

    For example:
        WITH data AS (
            SELECT
                1 AS id,
                2 AS my_id
        )
        SELECT
            id AS my_id
        FROM
            data
        WHERE
            my_id = 1
        GROUP BY
            my_id
        HAVING
            my_id = 1

    In most dialects, "my_id" would refer to "data.my_id" (which is resolved in
    _qualify_columns()) across the query, except:
        - BigQuery, which will forward the alias to the GROUP BY + HAVING clauses, i.e. it
          resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
        - ClickHouse, which will forward the alias across the query, i.e. it resolves to
          "WHERE id = 1 GROUP BY id HAVING id = 1"
    """

    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
    """Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""

    SUPPORTS_ORDER_BY_ALL = False
    """
    Whether ORDER BY ALL is supported (expands to all the selected columns), as in DuckDB and Spark3/Databricks.
    """

    HAS_DISTINCT_ARRAY_CONSTRUCTORS = False
    """
    Whether the ARRAY constructor is context-sensitive, i.e. in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3),
    as the former is of type INT[] while the latter is of type SUPER.
    """

    SUPPORTS_FIXED_SIZE_ARRAYS = False
    """
    Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts, e.g. in DuckDB. In
    dialects which don't support fixed-size arrays, such as Snowflake, this should be interpreted as a
    subscript/index operator.
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    jsonpath_tokenizer_class = JSONPathTokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}
    INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {}
    INVERSE_FORMAT_TRIE: t.Dict = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    DATE_PART_MAPPING = {
        "Y": "YEAR",
        "YY": "YEAR",
        "YYY": "YEAR",
        "YYYY": "YEAR",
        "YR": "YEAR",
        "YEARS": "YEAR",
        "YRS": "YEAR",
        "MM": "MONTH",
        "MON": "MONTH",
        "MONS": "MONTH",
        "MONTHS": "MONTH",
        "D": "DAY",
        "DD": "DAY",
        "DAYS": "DAY",
        "DAYOFMONTH": "DAY",
        "DAY OF WEEK": "DAYOFWEEK",
        "WEEKDAY": "DAYOFWEEK",
        "DOW": "DAYOFWEEK",
        "DW": "DAYOFWEEK",
        "WEEKDAY_ISO": "DAYOFWEEKISO",
        "DOW_ISO": "DAYOFWEEKISO",
        "DW_ISO": "DAYOFWEEKISO",
        "DAY OF YEAR": "DAYOFYEAR",
        "DOY": "DAYOFYEAR",
        "DY": "DAYOFYEAR",
        "W": "WEEK",
        "WK": "WEEK",
        "WEEKOFYEAR": "WEEK",
        "WOY": "WEEK",
        "WY": "WEEK",
        "WEEK_ISO": "WEEKISO",
        "WEEKOFYEARISO": "WEEKISO",
        "WEEKOFYEAR_ISO": "WEEKISO",
        "Q": "QUARTER",
        "QTR": "QUARTER",
        "QTRS": "QUARTER",
        "QUARTERS": "QUARTER",
        "H": "HOUR",
        "HH": "HOUR",
        "HR": "HOUR",
        "HOURS": "HOUR",
        "HRS": "HOUR",
        "M": "MINUTE",
        "MI": "MINUTE",
        "MIN": "MINUTE",
        "MINUTES": "MINUTE",
        "MINS": "MINUTE",
        "S": "SECOND",
        "SEC": "SECOND",
        "SECONDS": "SECOND",
        "SECS": "SECOND",
        "MS": "MILLISECOND",
        "MSEC": "MILLISECOND",
        "MSECS": "MILLISECOND",
        "MSECOND": "MILLISECOND",
        "MSECONDS": "MILLISECOND",
        "MILLISEC": "MILLISECOND",
        "MILLISECS": "MILLISECOND",
        "MILLISECON": "MILLISECOND",
        "MILLISECONDS": "MILLISECOND",
        "US": "MICROSECOND",
        "USEC": "MICROSECOND",
        "USECS": "MICROSECOND",
        "MICROSEC": "MICROSECOND",
        "MICROSECS": "MICROSECOND",
        "USECOND": "MICROSECOND",
        "USECONDS": "MICROSECOND",
        "MICROSECONDS": "MICROSECOND",
        "NS": "NANOSECOND",
        "NSEC": "NANOSECOND",
        "NANOSEC": "NANOSECOND",
        "NSECOND": "NANOSECOND",
        "NSECONDS": "NANOSECOND",
        "NANOSECS": "NANOSECOND",
        "EPOCH_SECOND": "EPOCH",
        "EPOCH_SECONDS": "EPOCH",
        "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
        "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
        "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
        "TZH": "TIMEZONE_HOUR",
        "TZM": "TIMEZONE_MINUTE",
        "DEC": "DECADE",
        "DECS": "DECADE",
        "DECADES": "DECADE",
        "MIL": "MILLENIUM",
        "MILS": "MILLENIUM",
        "MILLENIA": "MILLENIUM",
        "C": "CENTURY",
        "CENT": "CENTURY",
        "CENTS": "CENTURY",
        "CENTURIES": "CENTURY",
    }

    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.Datetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Div,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DatetimeDiff,
            exp.DateDiff,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Levenshtein,
            exp.Sign,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.JSON: {
            exp.ParseJSON,
        },
        exp.DataType.Type.TIME: {
            exp.Time,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
            exp.Quarter,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    ANNOTATORS: AnnotatorsType = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Bracket: lambda self, e: self._annotate_bracket(e),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Div: lambda self, e: self._annotate_div(e),
        exp.Dot: lambda self, e: self._annotate_dot(e),
        exp.Explode: lambda self, e: self._annotate_explode(e),
        exp.Extract: lambda self, e: self._annotate_extract(e),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<DATE>")
        ),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_map(e),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
        exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Struct: lambda self, e: self._annotate_struct(e),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.Timestamp: lambda self, e: self._annotate_with_type(
            e,
            exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
        ),
        exp.ToMap: lambda self, e: self._annotate_to_map(e),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Unnest: lambda self, e: self._annotate_unnest(e),
        exp.VarMap: lambda self, e: self._annotate_map(e),
    }

    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_strings = dialect.split(",")
                kv_pairs = (kv.split("=") for kv in kv_strings)
                kwargs = {}
                for pair in kv_pairs:
                    key = pair[0].strip()
                    value: t.Union[bool, str, None] = None

                    if len(pair) == 1:
                        # Default initialize standalone settings to True
                        value = True
                    elif len(pair) == 2:
                        value = pair[1].strip()

                        # Coerce the value to boolean if it matches the truthy/falsy values below
                        value_lower = value.lower()
                        if value_lower in ("true", "1"):
                            value = True
                        elif value_lower in ("false", "0"):
                            value = False

                    kwargs[key] = value

            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v1 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.pop("normalization_strategy", None)

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

        self.settings = kwargs

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system; for example, they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"
            try:
                return parse_json_path(path_text, self)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        return self.tokenizer_class(dialect=self)

    @property
    def jsonpath_tokenizer(self) -> JSONPathTokenizer:
        return self.jsonpath_tokenizer_class(dialect=self)

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)


DialectType = t.Union[str, Dialect, t.Type[Dialect], None]


def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]:
    return lambda self, expression: self.func(name, *flatten(expression.args.values()))


def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str:
    if expression.args.get("accuracy"):
        self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
    return self.func("APPROX_COUNT_DISTINCT", expression.this)


def if_sql(
    name: str = "IF", false_value: t.Optional[exp.Expression | str] = None
) -> t.Callable[[Generator, exp.If], str]:
    def _if_sql(self: Generator, expression: exp.If) -> str:
        return self.func(
            name,
            expression.this,
            expression.args.get("true"),
            expression.args.get("false") or false_value,
        )

    return _if_sql


def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
    this = expression.this
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, exp.DataType.Type.JSON))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
    return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]"


def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
    elem = seq_get(expression.expressions, 0)
    if isinstance(elem, exp.Expression) and elem.find(exp.Query):
        return self.func("ARRAY", elem)
    return inline_array_sql(self, expression)


def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
    return self.like_sql(
        exp.Like(
            this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression)
        )
    )


def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str:
    zone = self.sql(expression, "this")
    return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"


def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
    if expression.args.get("recursive"):
        self.unsupported("Recursive CTEs are unsupported")
        expression.args["recursive"] = False
    return self.with_sql(expression)


def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
    n = self.sql(expression, "this")
    d = self.sql(expression, "expression")
    return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"


def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
    self.unsupported("TABLESAMPLE unsupported")
    return self.sql(expression.this)


def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
    self.unsupported("PIVOT unsupported")
    return ""


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
    return self.cast_sql(expression)


def no_comment_column_constraint_sql(
    self: Generator, expression: exp.CommentColumnConstraint
) -> str:
    self.unsupported("CommentColumnConstraint unsupported")
    return ""


def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
    self.unsupported("MAP_FROM_ENTRIES unsupported")
    return ""


def str_position_sql(
    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
) -> str:
    this = self.sql(expression, "this")
    substr = self.sql(expression, "substr")
    position = self.sql(expression, "position")
    instance = expression.args.get("instance") if generate_instance else None
    position_offset = ""

    if position:
        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
        this = self.func("SUBSTR", this, position)
        position_offset = f" + {position} - 1"

    return self.func("STRPOS", this, substr, instance) + position_offset


def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
    return (
        f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
    )


def var_map_sql(
    self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP"
) -> str:
    keys = expression.args["keys"]
    values = expression.args["values"]

    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        self.unsupported("Cannot convert array columns into map.")
        return self.func(map_func_name, keys, values)

    args = []
    for key, value in zip(keys.expressions, values.expressions):
        args.append(self.sql(key))
        args.append(self.sql(value))

    return self.func(map_func_name, *args)


def build_formatted_time(
    exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
    """Helper used for time expressions.

    Args:
        exp_class: the expression class to instantiate.
        dialect: target sql dialect.
        default: the default format, True being time.

    Returns:
        A callable that can be used to return the appropriately formatted time expression.
    """

    def _builder(args: t.List):
        return exp_class(
            this=seq_get(args, 0),
            format=Dialect[dialect].format_time(
                seq_get(args, 1)
                or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
            ),
        )

    return _builder


def time_format(
    dialect: DialectType = None,
) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]:
    def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]:
        """
        Returns the time format for a given expression, unless it's equivalent
        to the default time format of the dialect of interest.
        """
        time_format = self.format_time(expression)
        return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None

    return _time_format


def build_date_delta(
    exp_class: t.Type[E],
    unit_mapping: t.Optional[t.Dict[str, str]] = None,
    default_unit: t.Optional[str] = "DAY",
) -> t.Callable[[t.List], E]:
    def _builder(args: t.List) -> E:
        unit_based = len(args) == 3
        this = args[2] if unit_based else seq_get(args, 0)
        unit = None
        if unit_based or default_unit:
            unit = args[0] if unit_based else exp.Literal.string(default_unit)
            unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
        return exp_class(this=this, expression=seq_get(args, 1), unit=unit)

    return _builder


def build_date_delta_with_interval(
    expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
    def _builder(args: t.List) -> t.Optional[E]:
        if len(args) < 2:
            return None

        interval = args[1]

        if not isinstance(interval, exp.Interval):
            raise ParseError(f"INTERVAL expression expected but got '{interval}'")

        expression = interval.this
        if expression and expression.is_string:
            expression = exp.Literal.number(expression.this)

        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))

    return _builder


def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
        args = [unit_to_str(expression), expression.this]
        if zone:
            args.append(expression.args.get("zone"))
        return self.func("DATE_TRUNC", *args)

    return _timestamptrunc_sql


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    zone = expression.args.get("zone")
    if not zone:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if zone.name.lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=zone,
            )
        )
    return self.func("TIMESTAMP", expression.this, zone)


def no_time_sql(self: Generator, expression: exp.Time) -> str:
    # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME)
    this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ)
    expr = exp.cast(
        exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME
    )
    return self.sql(expr)


def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
    this = expression.this
    expr = expression.expression

    if expr.name.lower() in TIMEZONES:
        # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP)
        this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ)
        this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP)
        return self.sql(this)

    this = exp.cast(this, exp.DataType.Type.DATE)
    expr = exp.cast(expr, exp.DataType.Type.TIME)

    return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))


# Used for Presto and DuckDB, which use functions that don't support charset and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars and not collation:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None


def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, (exp.Var, exp.Placeholder)):
        return unit
    return exp.Var(this=default) if default else None


@t.overload
def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var:
    pass


@t.overload
def map_date_part(
    part: t.Optional[exp.Expression], dialect: DialectType = Dialect
) -> t.Optional[exp.Expression]:
    pass


def map_date_part(part, dialect: DialectType = Dialect):
    mapped = (
        Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None
    )
    return exp.var(mapped) if mapped else part


def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        when.transform(
            lambda node: (
                exp.column(node.this)
                if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                else node
            ),
            copy=False,
        )

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder


def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))


def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )


def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    def _builder(dtype: exp.DataType) -> exp.DataType:
        if dtype.expressions or precision is None:
            return dtype

        params = f"{precision}{f', {scale}' if scale is not None else ''}"
        return exp.DataType.build(f"DECIMAL({params})")

    return _builder


def build_timestamp_from_parts(args: t.List) -> exp.Func:
    if len(args) == 2:
        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
        # so we parse this into Anonymous for now instead of introducing complexity
        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)

    return exp.TimestampFromParts.from_arg_list(args)


def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
    return self.func(f"SHA{expression.text('length') or '256'}", expression.this)


def sequence_sql(self: Generator, expression: exp.GenerateSeries):
    start = expression.args["start"]
    end = expression.args["end"]
    step = expression.args.get("step")

    if isinstance(start, exp.Cast):
        target_type = start.to
    elif isinstance(end, exp.Cast):
        target_type = end.to
    else:
        target_type = None

    if target_type and target_type.is_type("timestamp"):
        if target_type is start.to:
            end = exp.cast(end, target_type)
        else:
            start = exp.cast(start, target_type)

    return self.func("SEQUENCE", start, end, step)
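A minimal usage sketch of the registry defined above, assuming only that sqlglot is installed (the printed output is illustrative):

from sqlglot.dialects.dialect import Dialect

# Extra key-value pairs after the dialect name become instance settings,
# parsed by the key-value loop inside Dialect.get_or_raise.
duckdb = Dialect.get_or_raise("duckdb")
mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

# Each Dialect instance bundles the dialect's tokenizer, parser and generator.
ast = duckdb.parse("SELECT 1 AS x")[0]
print(duckdb.generate(ast))  # SELECT 1 AS x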
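A short sketch of how NormalizationStrategy drives Dialect.normalize_identifier, assuming sqlglot is installed (Postgres defaults to LOWERCASE, Snowflake to UPPERCASE):

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect

ident = exp.to_identifier("FoO")  # unquoted, so it is subject to normalization

print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name)   # foo
print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name)  # FOO

# Quoted identifiers are left untouched unless the strategy is CASE_INSENSITIVE.
quoted = exp.to_identifier("FoO", quoted=True)
print(Dialect.get_or_raise("postgres").normalize_identifier(quoted).name)  # FoO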
81class NormalizationStrategy(str, AutoName): 82 """Specifies the strategy according to which identifiers should be normalized.""" 83 84 LOWERCASE = auto() 85 """Unquoted identifiers are lowercased.""" 86 87 UPPERCASE = auto() 88 """Unquoted identifiers are uppercased.""" 89 90 CASE_SENSITIVE = auto() 91 """Always case-sensitive, regardless of quotes.""" 92 93 CASE_INSENSITIVE = auto() 94 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.

- LOWERCASE: Unquoted identifiers are lowercased.
- UPPERCASE: Unquoted identifiers are uppercased.
- CASE_SENSITIVE: Always case-sensitive, regardless of quotes.
- CASE_INSENSITIVE: Always case-insensitive, regardless of quotes.
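A minimal sketch of overriding the strategy on a `Dialect` instance; `Dialect.__init__` (shown further below) uppercases string values and coerces them to enum members.

    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    d = Dialect(normalization_strategy="case_insensitive")
    print(d.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE)  # True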
215class Dialect(metaclass=_Dialect): 216 INDEX_OFFSET = 0 217 """The base index offset for arrays.""" 218 219 WEEK_OFFSET = 0 220 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 221 222 UNNEST_COLUMN_ONLY = False 223 """Whether `UNNEST` table aliases are treated as column aliases.""" 224 225 ALIAS_POST_TABLESAMPLE = False 226 """Whether the table alias comes after tablesample.""" 227 228 TABLESAMPLE_SIZE_IS_PERCENT = False 229 """Whether a size in the table sample clause represents percentage.""" 230 231 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 232 """Specifies the strategy according to which identifiers should be normalized.""" 233 234 IDENTIFIERS_CAN_START_WITH_DIGIT = False 235 """Whether an unquoted identifier can start with a digit.""" 236 237 DPIPE_IS_STRING_CONCAT = True 238 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 239 240 STRICT_STRING_CONCAT = False 241 """Whether `CONCAT`'s arguments must be strings.""" 242 243 SUPPORTS_USER_DEFINED_TYPES = True 244 """Whether user-defined data types are supported.""" 245 246 SUPPORTS_SEMI_ANTI_JOIN = True 247 """Whether `SEMI` or `ANTI` joins are supported.""" 248 249 SUPPORTS_COLUMN_JOIN_MARKS = False 250 """Whether the old-style outer join (+) syntax is supported.""" 251 252 COPY_PARAMS_ARE_CSV = True 253 """Separator of COPY statement parameters.""" 254 255 NORMALIZE_FUNCTIONS: bool | str = "upper" 256 """ 257 Determines how function names are going to be normalized. 258 Possible values: 259 "upper" or True: Convert names to uppercase. 260 "lower": Convert names to lowercase. 261 False: Disables function name normalization. 262 """ 263 264 LOG_BASE_FIRST: t.Optional[bool] = True 265 """ 266 Whether the base comes first in the `LOG` function. 267 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 268 """ 269 270 NULL_ORDERING = "nulls_are_small" 271 """ 272 Default `NULL` ordering method to use if not explicitly set. 273 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 274 """ 275 276 TYPED_DIVISION = False 277 """ 278 Whether the behavior of `a / b` depends on the types of `a` and `b`. 279 False means `a / b` is always float division. 280 True means `a / b` is integer division if both `a` and `b` are integers. 281 """ 282 283 SAFE_DIVISION = False 284 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 285 286 CONCAT_COALESCE = False 287 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 288 289 HEX_LOWERCASE = False 290 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 291 292 DATE_FORMAT = "'%Y-%m-%d'" 293 DATEINT_FORMAT = "'%Y%m%d'" 294 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 295 296 TIME_MAPPING: t.Dict[str, str] = {} 297 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 298 299 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 300 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 301 FORMAT_MAPPING: t.Dict[str, str] = {} 302 """ 303 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 304 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 
305 """ 306 307 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 308 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 309 310 PSEUDOCOLUMNS: t.Set[str] = set() 311 """ 312 Columns that are auto-generated by the engine corresponding to this dialect. 313 For example, such columns may be excluded from `SELECT *` queries. 314 """ 315 316 PREFER_CTE_ALIAS_COLUMN = False 317 """ 318 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 319 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 320 any projection aliases in the subquery. 321 322 For example, 323 WITH y(c) AS ( 324 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 325 ) SELECT c FROM y; 326 327 will be rewritten as 328 329 WITH y(c) AS ( 330 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 331 ) SELECT c FROM y; 332 """ 333 334 COPY_PARAMS_ARE_CSV = True 335 """ 336 Whether COPY statement parameters are separated by comma or whitespace 337 """ 338 339 FORCE_EARLY_ALIAS_REF_EXPANSION = False 340 """ 341 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 342 343 For example: 344 WITH data AS ( 345 SELECT 346 1 AS id, 347 2 AS my_id 348 ) 349 SELECT 350 id AS my_id 351 FROM 352 data 353 WHERE 354 my_id = 1 355 GROUP BY 356 my_id, 357 HAVING 358 my_id = 1 359 360 In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except: 361 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 362 - Clickhouse, which will forward the alias across the query i.e it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1" 363 """ 364 365 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 366 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 367 368 SUPPORTS_ORDER_BY_ALL = False 369 """ 370 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 371 """ 372 373 HAS_DISTINCT_ARRAY_CONSTRUCTORS = False 374 """ 375 Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) 376 as the former is of type INT[] vs the latter which is SUPER 377 """ 378 379 SUPPORTS_FIXED_SIZE_ARRAYS = False 380 """ 381 Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. in DuckDB. 
In 382 dialects which don't support fixed size arrays such as Snowflake, this should be interpreted as a subscript/index operator 383 """ 384 385 # --- Autofilled --- 386 387 tokenizer_class = Tokenizer 388 jsonpath_tokenizer_class = JSONPathTokenizer 389 parser_class = Parser 390 generator_class = Generator 391 392 # A trie of the time_mapping keys 393 TIME_TRIE: t.Dict = {} 394 FORMAT_TRIE: t.Dict = {} 395 396 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 397 INVERSE_TIME_TRIE: t.Dict = {} 398 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 399 INVERSE_FORMAT_TRIE: t.Dict = {} 400 401 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 402 403 # Delimiters for string literals and identifiers 404 QUOTE_START = "'" 405 QUOTE_END = "'" 406 IDENTIFIER_START = '"' 407 IDENTIFIER_END = '"' 408 409 # Delimiters for bit, hex, byte and unicode literals 410 BIT_START: t.Optional[str] = None 411 BIT_END: t.Optional[str] = None 412 HEX_START: t.Optional[str] = None 413 HEX_END: t.Optional[str] = None 414 BYTE_START: t.Optional[str] = None 415 BYTE_END: t.Optional[str] = None 416 UNICODE_START: t.Optional[str] = None 417 UNICODE_END: t.Optional[str] = None 418 419 DATE_PART_MAPPING = { 420 "Y": "YEAR", 421 "YY": "YEAR", 422 "YYY": "YEAR", 423 "YYYY": "YEAR", 424 "YR": "YEAR", 425 "YEARS": "YEAR", 426 "YRS": "YEAR", 427 "MM": "MONTH", 428 "MON": "MONTH", 429 "MONS": "MONTH", 430 "MONTHS": "MONTH", 431 "D": "DAY", 432 "DD": "DAY", 433 "DAYS": "DAY", 434 "DAYOFMONTH": "DAY", 435 "DAY OF WEEK": "DAYOFWEEK", 436 "WEEKDAY": "DAYOFWEEK", 437 "DOW": "DAYOFWEEK", 438 "DW": "DAYOFWEEK", 439 "WEEKDAY_ISO": "DAYOFWEEKISO", 440 "DOW_ISO": "DAYOFWEEKISO", 441 "DW_ISO": "DAYOFWEEKISO", 442 "DAY OF YEAR": "DAYOFYEAR", 443 "DOY": "DAYOFYEAR", 444 "DY": "DAYOFYEAR", 445 "W": "WEEK", 446 "WK": "WEEK", 447 "WEEKOFYEAR": "WEEK", 448 "WOY": "WEEK", 449 "WY": "WEEK", 450 "WEEK_ISO": "WEEKISO", 451 "WEEKOFYEARISO": "WEEKISO", 452 "WEEKOFYEAR_ISO": "WEEKISO", 453 "Q": "QUARTER", 454 "QTR": "QUARTER", 455 "QTRS": "QUARTER", 456 "QUARTERS": "QUARTER", 457 "H": "HOUR", 458 "HH": "HOUR", 459 "HR": "HOUR", 460 "HOURS": "HOUR", 461 "HRS": "HOUR", 462 "M": "MINUTE", 463 "MI": "MINUTE", 464 "MIN": "MINUTE", 465 "MINUTES": "MINUTE", 466 "MINS": "MINUTE", 467 "S": "SECOND", 468 "SEC": "SECOND", 469 "SECONDS": "SECOND", 470 "SECS": "SECOND", 471 "MS": "MILLISECOND", 472 "MSEC": "MILLISECOND", 473 "MSECS": "MILLISECOND", 474 "MSECOND": "MILLISECOND", 475 "MSECONDS": "MILLISECOND", 476 "MILLISEC": "MILLISECOND", 477 "MILLISECS": "MILLISECOND", 478 "MILLISECON": "MILLISECOND", 479 "MILLISECONDS": "MILLISECOND", 480 "US": "MICROSECOND", 481 "USEC": "MICROSECOND", 482 "USECS": "MICROSECOND", 483 "MICROSEC": "MICROSECOND", 484 "MICROSECS": "MICROSECOND", 485 "USECOND": "MICROSECOND", 486 "USECONDS": "MICROSECOND", 487 "MICROSECONDS": "MICROSECOND", 488 "NS": "NANOSECOND", 489 "NSEC": "NANOSECOND", 490 "NANOSEC": "NANOSECOND", 491 "NSECOND": "NANOSECOND", 492 "NSECONDS": "NANOSECOND", 493 "NANOSECS": "NANOSECOND", 494 "EPOCH_SECOND": "EPOCH", 495 "EPOCH_SECONDS": "EPOCH", 496 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 497 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 498 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 499 "TZH": "TIMEZONE_HOUR", 500 "TZM": "TIMEZONE_MINUTE", 501 "DEC": "DECADE", 502 "DECS": "DECADE", 503 "DECADES": "DECADE", 504 "MIL": "MILLENIUM", 505 "MILS": "MILLENIUM", 506 "MILLENIA": "MILLENIUM", 507 "C": "CENTURY", 508 "CENT": "CENTURY", 509 "CENTS": "CENTURY", 510 "CENTURIES": "CENTURY", 511 } 512 513 TYPE_TO_EXPRESSIONS: 
t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 514 exp.DataType.Type.BIGINT: { 515 exp.ApproxDistinct, 516 exp.ArraySize, 517 exp.Count, 518 exp.Length, 519 }, 520 exp.DataType.Type.BOOLEAN: { 521 exp.Between, 522 exp.Boolean, 523 exp.In, 524 exp.RegexpLike, 525 }, 526 exp.DataType.Type.DATE: { 527 exp.CurrentDate, 528 exp.Date, 529 exp.DateFromParts, 530 exp.DateStrToDate, 531 exp.DiToDate, 532 exp.StrToDate, 533 exp.TimeStrToDate, 534 exp.TsOrDsToDate, 535 }, 536 exp.DataType.Type.DATETIME: { 537 exp.CurrentDatetime, 538 exp.Datetime, 539 exp.DatetimeAdd, 540 exp.DatetimeSub, 541 }, 542 exp.DataType.Type.DOUBLE: { 543 exp.ApproxQuantile, 544 exp.Avg, 545 exp.Div, 546 exp.Exp, 547 exp.Ln, 548 exp.Log, 549 exp.Pow, 550 exp.Quantile, 551 exp.Round, 552 exp.SafeDivide, 553 exp.Sqrt, 554 exp.Stddev, 555 exp.StddevPop, 556 exp.StddevSamp, 557 exp.Variance, 558 exp.VariancePop, 559 }, 560 exp.DataType.Type.INT: { 561 exp.Ceil, 562 exp.DatetimeDiff, 563 exp.DateDiff, 564 exp.TimestampDiff, 565 exp.TimeDiff, 566 exp.DateToDi, 567 exp.Levenshtein, 568 exp.Sign, 569 exp.StrPosition, 570 exp.TsOrDiToDi, 571 }, 572 exp.DataType.Type.JSON: { 573 exp.ParseJSON, 574 }, 575 exp.DataType.Type.TIME: { 576 exp.Time, 577 }, 578 exp.DataType.Type.TIMESTAMP: { 579 exp.CurrentTime, 580 exp.CurrentTimestamp, 581 exp.StrToTime, 582 exp.TimeAdd, 583 exp.TimeStrToTime, 584 exp.TimeSub, 585 exp.TimestampAdd, 586 exp.TimestampSub, 587 exp.UnixToTime, 588 }, 589 exp.DataType.Type.TINYINT: { 590 exp.Day, 591 exp.Month, 592 exp.Week, 593 exp.Year, 594 exp.Quarter, 595 }, 596 exp.DataType.Type.VARCHAR: { 597 exp.ArrayConcat, 598 exp.Concat, 599 exp.ConcatWs, 600 exp.DateToDateStr, 601 exp.GroupConcat, 602 exp.Initcap, 603 exp.Lower, 604 exp.Substring, 605 exp.TimeToStr, 606 exp.TimeToTimeStr, 607 exp.Trim, 608 exp.TsOrDsToDateStr, 609 exp.UnixToStr, 610 exp.UnixToTimeStr, 611 exp.Upper, 612 }, 613 } 614 615 ANNOTATORS: AnnotatorsType = { 616 **{ 617 expr_type: lambda self, e: self._annotate_unary(e) 618 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 619 }, 620 **{ 621 expr_type: lambda self, e: self._annotate_binary(e) 622 for expr_type in subclasses(exp.__name__, exp.Binary) 623 }, 624 **{ 625 expr_type: _annotate_with_type_lambda(data_type) 626 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 627 for expr_type in expressions 628 }, 629 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 630 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 631 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 632 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 633 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 634 exp.Bracket: lambda self, e: self._annotate_bracket(e), 635 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 636 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 637 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 638 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 639 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 640 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 641 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 642 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 643 exp.Div: lambda self, e: self._annotate_div(e), 644 exp.Dot: lambda self, e: self._annotate_dot(e), 645 exp.Explode: lambda self, 
e: self._annotate_explode(e), 646 exp.Extract: lambda self, e: self._annotate_extract(e), 647 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 648 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 649 e, exp.DataType.build("ARRAY<DATE>") 650 ), 651 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 652 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 653 exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), 654 exp.Literal: lambda self, e: self._annotate_literal(e), 655 exp.Map: lambda self, e: self._annotate_map(e), 656 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 657 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 658 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 659 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 660 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 661 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 662 exp.Struct: lambda self, e: self._annotate_struct(e), 663 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 664 exp.Timestamp: lambda self, e: self._annotate_with_type( 665 e, 666 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 667 ), 668 exp.ToMap: lambda self, e: self._annotate_to_map(e), 669 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 670 exp.Unnest: lambda self, e: self._annotate_unnest(e), 671 exp.VarMap: lambda self, e: self._annotate_map(e), 672 } 673 674 @classmethod 675 def get_or_raise(cls, dialect: DialectType) -> Dialect: 676 """ 677 Look up a dialect in the global dialect registry and return it if it exists. 678 679 Args: 680 dialect: The target dialect. If this is a string, it can be optionally followed by 681 additional key-value pairs that are separated by commas and are used to specify 682 dialect settings, such as whether the dialect's identifiers are case-sensitive. 683 684 Example: 685 >>> dialect = dialect_class = get_or_raise("duckdb") 686 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 687 688 Returns: 689 The corresponding Dialect instance. 690 """ 691 692 if not dialect: 693 return cls() 694 if isinstance(dialect, _Dialect): 695 return dialect() 696 if isinstance(dialect, Dialect): 697 return dialect 698 if isinstance(dialect, str): 699 try: 700 dialect_name, *kv_strings = dialect.split(",") 701 kv_pairs = (kv.split("=") for kv in kv_strings) 702 kwargs = {} 703 for pair in kv_pairs: 704 key = pair[0].strip() 705 value: t.Union[bool | str | None] = None 706 707 if len(pair) == 1: 708 # Default initialize standalone settings to True 709 value = True 710 elif len(pair) == 2: 711 value = pair[1].strip() 712 713 # Coerce the value to boolean if it matches to the truthy/falsy values below 714 value_lower = value.lower() 715 if value_lower in ("true", "1"): 716 value = True 717 elif value_lower in ("false", "0"): 718 value = False 719 720 kwargs[key] = value 721 722 except ValueError: 723 raise ValueError( 724 f"Invalid dialect format: '{dialect}'. " 725 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 
726 ) 727 728 result = cls.get(dialect_name.strip()) 729 if not result: 730 from difflib import get_close_matches 731 732 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 733 if similar: 734 similar = f" Did you mean {similar}?" 735 736 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 737 738 return result(**kwargs) 739 740 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 741 742 @classmethod 743 def format_time( 744 cls, expression: t.Optional[str | exp.Expression] 745 ) -> t.Optional[exp.Expression]: 746 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 747 if isinstance(expression, str): 748 return exp.Literal.string( 749 # the time formats are quoted 750 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 751 ) 752 753 if expression and expression.is_string: 754 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 755 756 return expression 757 758 def __init__(self, **kwargs) -> None: 759 normalization_strategy = kwargs.pop("normalization_strategy", None) 760 761 if normalization_strategy is None: 762 self.normalization_strategy = self.NORMALIZATION_STRATEGY 763 else: 764 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 765 766 self.settings = kwargs 767 768 def __eq__(self, other: t.Any) -> bool: 769 # Does not currently take dialect state into account 770 return type(self) == other 771 772 def __hash__(self) -> int: 773 # Does not currently take dialect state into account 774 return hash(type(self)) 775 776 def normalize_identifier(self, expression: E) -> E: 777 """ 778 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 779 780 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 781 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 782 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 783 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 784 785 There are also dialects like Spark, which are case-insensitive even when quotes are 786 present, and dialects like MySQL, whose resolution rules match those employed by the 787 underlying operating system, for example they may always be case-sensitive in Linux. 788 789 Finally, the normalization behavior of some engines can even be controlled through flags, 790 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 791 792 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 793 that it can analyze queries in the optimizer and successfully capture their semantics. 
794 """ 795 if ( 796 isinstance(expression, exp.Identifier) 797 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 798 and ( 799 not expression.quoted 800 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 801 ) 802 ): 803 expression.set( 804 "this", 805 ( 806 expression.this.upper() 807 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 808 else expression.this.lower() 809 ), 810 ) 811 812 return expression 813 814 def case_sensitive(self, text: str) -> bool: 815 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 816 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 817 return False 818 819 unsafe = ( 820 str.islower 821 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 822 else str.isupper 823 ) 824 return any(unsafe(char) for char in text) 825 826 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 827 """Checks if text can be identified given an identify option. 828 829 Args: 830 text: The text to check. 831 identify: 832 `"always"` or `True`: Always returns `True`. 833 `"safe"`: Only returns `True` if the identifier is case-insensitive. 834 835 Returns: 836 Whether the given text can be identified. 837 """ 838 if identify is True or identify == "always": 839 return True 840 841 if identify == "safe": 842 return not self.case_sensitive(text) 843 844 return False 845 846 def quote_identifier(self, expression: E, identify: bool = True) -> E: 847 """ 848 Adds quotes to a given identifier. 849 850 Args: 851 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 852 identify: If set to `False`, the quotes will only be added if the identifier is deemed 853 "unsafe", with respect to its characters and this dialect's normalization strategy. 854 """ 855 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 856 name = expression.this 857 expression.set( 858 "quoted", 859 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 860 ) 861 862 return expression 863 864 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 865 if isinstance(path, exp.Literal): 866 path_text = path.name 867 if path.is_number: 868 path_text = f"[{path_text}]" 869 try: 870 return parse_json_path(path_text, self) 871 except ParseError as e: 872 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 873 874 return path 875 876 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 877 return self.parser(**opts).parse(self.tokenize(sql), sql) 878 879 def parse_into( 880 self, expression_type: exp.IntoType, sql: str, **opts 881 ) -> t.List[t.Optional[exp.Expression]]: 882 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 883 884 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 885 return self.generator(**opts).generate(expression, copy=copy) 886 887 def transpile(self, sql: str, **opts) -> t.List[str]: 888 return [ 889 self.generate(expression, copy=False, **opts) if expression else "" 890 for expression in self.parse(sql) 891 ] 892 893 def tokenize(self, sql: str) -> t.List[Token]: 894 return self.tokenizer.tokenize(sql) 895 896 @property 897 def tokenizer(self) -> Tokenizer: 898 return self.tokenizer_class(dialect=self) 899 900 @property 901 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 902 return self.jsonpath_tokenizer_class(dialect=self) 903 904 def parser(self, **opts) -> Parser: 905 return self.parser_class(dialect=self, **opts) 906 907 def generator(self, **opts) -> Generator: 908 return self.generator_class(dialect=self, **opts)
758 def __init__(self, **kwargs) -> None: 759 normalization_strategy = kwargs.pop("normalization_strategy", None) 760 761 if normalization_strategy is None: 762 self.normalization_strategy = self.NORMALIZATION_STRATEGY 763 else: 764 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 765 766 self.settings = kwargs
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized.
Possible values:
- `"upper"` or `True`: Convert names to uppercase.
- `"lower"`: Convert names to lowercase.
- `False`: Disables function name normalization.
Whether the base comes first in the `LOG` function. Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`).
Default `NULL` ordering method to use if not explicitly set. Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`.
Whether the behavior of `a / b` depends on the types of `a` and `b`. `False` means `a / b` is always float division; `True` means `a / b` is integer division if both `a` and `b` are integers.
A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python `strftime` formats.
Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).
Columns that are auto-generated by the engine corresponding to this dialect. For example, such columns may be excluded from `SELECT *` queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example,

    WITH y(c) AS (
        SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
    ) SELECT c FROM y;

will be rewritten as

    WITH y(c) AS (
        SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
    ) SELECT c FROM y;
Whether alias reference expansion (`_expand_alias_refs()`) should run before column qualification (`_qualify_columns()`).

For example:

    WITH data AS (
        SELECT
            1 AS id,
            2 AS my_id
    )
    SELECT
        id AS my_id
    FROM
        data
    WHERE
        my_id = 1
    GROUP BY
        my_id
    HAVING
        my_id = 1

In most dialects "my_id" would refer to "data.my_id" across the query (this resolution happens in `_qualify_columns()`), except:
- BigQuery, which forwards the alias to the GROUP BY and HAVING clauses, i.e. it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
- ClickHouse, which forwards the alias across the whole query, i.e. it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
Whether alias reference expansion before qualification should only happen for the GROUP BY clause.
Whether `ORDER BY ALL` is supported (it expands to all the selected columns), as in DuckDB and Spark3/Databricks.
Whether the `ARRAY` constructor is context-sensitive, i.e. in Redshift `ARRAY[1, 2, 3] != ARRAY(1, 2, 3)`, as the former is of type `INT[]` while the latter is of type `SUPER`.
Whether expressions such as `x::INT[5]` should be parsed as fixed-size array defs/casts, e.g. in DuckDB. In dialects which don't support fixed-size arrays, such as Snowflake, this should be interpreted as a subscript/index operator.
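These settings are what the parser and generator consult when transpiling between dialects. A quick illustration with two built-in dialects: `||` is string concatenation in Postgres (`DPIPE_IS_STRING_CONCAT`) but not in MySQL, so sqlglot rewrites it.

    import sqlglot

    print(sqlglot.transpile("SELECT 'a' || 'b'", read="postgres", write="mysql")[0])
    # SELECT CONCAT('a', 'b')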
674 @classmethod 675 def get_or_raise(cls, dialect: DialectType) -> Dialect: 676 """ 677 Look up a dialect in the global dialect registry and return it if it exists. 678 679 Args: 680 dialect: The target dialect. If this is a string, it can be optionally followed by 681 additional key-value pairs that are separated by commas and are used to specify 682 dialect settings, such as whether the dialect's identifiers are case-sensitive. 683 684 Example: 685 >>> dialect = dialect_class = get_or_raise("duckdb") 686 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 687 688 Returns: 689 The corresponding Dialect instance. 690 """ 691 692 if not dialect: 693 return cls() 694 if isinstance(dialect, _Dialect): 695 return dialect() 696 if isinstance(dialect, Dialect): 697 return dialect 698 if isinstance(dialect, str): 699 try: 700 dialect_name, *kv_strings = dialect.split(",") 701 kv_pairs = (kv.split("=") for kv in kv_strings) 702 kwargs = {} 703 for pair in kv_pairs: 704 key = pair[0].strip() 705 value: t.Union[bool | str | None] = None 706 707 if len(pair) == 1: 708 # Default initialize standalone settings to True 709 value = True 710 elif len(pair) == 2: 711 value = pair[1].strip() 712 713 # Coerce the value to boolean if it matches to the truthy/falsy values below 714 value_lower = value.lower() 715 if value_lower in ("true", "1"): 716 value = True 717 elif value_lower in ("false", "0"): 718 value = False 719 720 kwargs[key] = value 721 722 except ValueError: 723 raise ValueError( 724 f"Invalid dialect format: '{dialect}'. " 725 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 726 ) 727 728 result = cls.get(dialect_name.strip()) 729 if not result: 730 from difflib import get_close_matches 731 732 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 733 if similar: 734 similar = f" Did you mean {similar}?" 735 736 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 737 738 return result(**kwargs) 739 740 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:

    >>> dialect = dialect_class = get_or_raise("duckdb")
    >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
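A short usage sketch of the settings syntax accepted by `get_or_raise`; standalone keys default to `True`, and `"true"`/`"false"`/`"1"`/`"0"` values are coerced to booleans, per the parsing loop above.

    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    duckdb = Dialect.get_or_raise("duckdb")
    mysql = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")
    print(mysql.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE)  # True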
742 @classmethod 743 def format_time( 744 cls, expression: t.Optional[str | exp.Expression] 745 ) -> t.Optional[exp.Expression]: 746 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 747 if isinstance(expression, str): 748 return exp.Literal.string( 749 # the time formats are quoted 750 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 751 ) 752 753 if expression and expression.is_string: 754 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 755 756 return expression
Converts a time format in this dialect to its equivalent Python `strftime` format.
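A minimal sketch, assuming Hive's Java-style `TIME_MAPPING` (e.g. `yyyy` -> `%Y`); note that raw strings are expected to carry their surrounding quotes.

    from sqlglot.dialects.dialect import Dialect

    fmt = Dialect["hive"].format_time("'yyyy-MM-dd'")
    print(fmt.sql())  # '%Y-%m-%d'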
776 def normalize_identifier(self, expression: E) -> E: 777 """ 778 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 779 780 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 781 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 782 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 783 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 784 785 There are also dialects like Spark, which are case-insensitive even when quotes are 786 present, and dialects like MySQL, whose resolution rules match those employed by the 787 underlying operating system, for example they may always be case-sensitive in Linux. 788 789 Finally, the normalization behavior of some engines can even be controlled through flags, 790 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 791 792 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 793 that it can analyze queries in the optimizer and successfully capture their semantics. 794 """ 795 if ( 796 isinstance(expression, exp.Identifier) 797 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 798 and ( 799 not expression.quoted 800 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 801 ) 802 ): 803 expression.set( 804 "this", 805 ( 806 expression.this.upper() 807 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 808 else expression.this.lower() 809 ), 810 ) 811 812 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
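A minimal sketch using two built-in dialects whose strategies the docstring describes (Postgres lowercases, Snowflake uppercases):

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    print(Dialect.get_or_raise("postgres").normalize_identifier(exp.to_identifier("FoO")).name)   # foo
    print(Dialect.get_or_raise("snowflake").normalize_identifier(exp.to_identifier("FoO")).name)  # FOO

    # Quoted identifiers are case-sensitive and are left untouched:
    quoted = exp.to_identifier("FoO", quoted=True)
    print(Dialect.get_or_raise("postgres").normalize_identifier(quoted).name)  # FoO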
814 def case_sensitive(self, text: str) -> bool: 815 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 816 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 817 return False 818 819 unsafe = ( 820 str.islower 821 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 822 else str.isupper 823 ) 824 return any(unsafe(char) for char in text)
Checks if text contains any case sensitive characters, based on the dialect's rules.
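Under the default `LOWERCASE` strategy, any uppercase character makes the text case-sensitive, since it would not survive normalization:

    from sqlglot.dialects.dialect import Dialect

    d = Dialect()
    print(d.case_sensitive("foo"))  # False
    print(d.case_sensitive("Foo"))  # True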
826 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 827 """Checks if text can be identified given an identify option. 828 829 Args: 830 text: The text to check. 831 identify: 832 `"always"` or `True`: Always returns `True`. 833 `"safe"`: Only returns `True` if the identifier is case-insensitive. 834 835 Returns: 836 Whether the given text can be identified. 837 """ 838 if identify is True or identify == "always": 839 return True 840 841 if identify == "safe": 842 return not self.case_sensitive(text) 843 844 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify: `"always"` or `True`: Always returns `True`. `"safe"`: Only returns `True` if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
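A short sketch against the default dialect (`LOWERCASE` normalization):

    from sqlglot.dialects.dialect import Dialect

    d = Dialect()
    print(d.can_identify("foo", identify="safe"))    # True: case-insensitive
    print(d.can_identify("Foo", identify="safe"))    # False: mixed case is unsafe
    print(d.can_identify("Foo", identify="always"))  # True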
846 def quote_identifier(self, expression: E, identify: bool = True) -> E: 847 """ 848 Adds quotes to a given identifier. 849 850 Args: 851 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 852 identify: If set to `False`, the quotes will only be added if the identifier is deemed 853 "unsafe", with respect to its characters and this dialect's normalization strategy. 854 """ 855 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 856 name = expression.this 857 expression.set( 858 "quoted", 859 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 860 ) 861 862 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
- identify: If set to `False`, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
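A short sketch; with `identify=False`, quoting only kicks in for "unsafe" names:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    d = Dialect()
    print(d.quote_identifier(exp.to_identifier("foo")).sql())                     # "foo"
    print(d.quote_identifier(exp.to_identifier("foo"), identify=False).sql())     # foo
    print(d.quote_identifier(exp.to_identifier("my col"), identify=False).sql())  # "my col"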
864 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 865 if isinstance(path, exp.Literal): 866 path_text = path.name 867 if path.is_number: 868 path_text = f"[{path_text}]" 869 try: 870 return parse_json_path(path_text, self) 871 except ParseError as e: 872 logger.warning(f"Invalid JSON path syntax. {str(e)}") 873 874 return path
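A minimal sketch of `to_json_path`: literal paths are parsed into an `exp.JSONPath` tree, number literals are first wrapped as subscripts, and invalid syntax logs a warning and falls back to the original node.

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    path = Dialect().to_json_path(exp.Literal.string("$.a[0].b"))
    print(repr(path))  # an exp.JSONPath with root/key/subscript/key segments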
924def if_sql( 925 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 926) -> t.Callable[[Generator, exp.If], str]: 927 def _if_sql(self: Generator, expression: exp.If) -> str: 928 return self.func( 929 name, 930 expression.this, 931 expression.args.get("true"), 932 expression.args.get("false") or false_value, 933 ) 934 935 return _if_sql
938def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 939 this = expression.this 940 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 941 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 942 943 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
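The helpers above (and most of the `*_sql` functions below) are meant to be plugged into a dialect generator's `TRANSFORMS` map. A hedged sketch with a hypothetical dialect, mirroring how the built-in dialects wire them up:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect, arrow_json_extract_sql, if_sql
    from sqlglot.generator import Generator as BaseGenerator

    class MyDialect(Dialect):  # hypothetical, for illustration only
        class Generator(BaseGenerator):
            TRANSFORMS = {
                **BaseGenerator.TRANSFORMS,
                exp.If: if_sql(name="IIF"),  # render exp.If as IIF(...)
                exp.JSONExtract: arrow_json_extract_sql,        # ->
                exp.JSONExtractScalar: arrow_json_extract_sql,  # ->>
            }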
1009def str_position_sql( 1010 self: Generator, expression: exp.StrPosition, generate_instance: bool = False 1011) -> str: 1012 this = self.sql(expression, "this") 1013 substr = self.sql(expression, "substr") 1014 position = self.sql(expression, "position") 1015 instance = expression.args.get("instance") if generate_instance else None 1016 position_offset = "" 1017 1018 if position: 1019 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1020 this = self.func("SUBSTR", this, position) 1021 position_offset = f" + {position} - 1" 1022 1023 return self.func("STRPOS", this, substr, instance) + position_offset
1032def var_map_sql( 1033 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1034) -> str: 1035 keys = expression.args["keys"] 1036 values = expression.args["values"] 1037 1038 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1039 self.unsupported("Cannot convert array columns into map.") 1040 return self.func(map_func_name, keys, values) 1041 1042 args = [] 1043 for key, value in zip(keys.expressions, values.expressions): 1044 args.append(self.sql(key)) 1045 args.append(self.sql(value)) 1046 1047 return self.func(map_func_name, *args)
1050def build_formatted_time( 1051 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1052) -> t.Callable[[t.List], E]: 1053 """Helper used for time expressions. 1054 1055 Args: 1056 exp_class: the expression class to instantiate. 1057 dialect: target sql dialect. 1058 default: the default format, True being time. 1059 1060 Returns: 1061 A callable that can be used to return the appropriately formatted time expression. 1062 """ 1063 1064 def _builder(args: t.List): 1065 return exp_class( 1066 this=seq_get(args, 0), 1067 format=Dialect[dialect].format_time( 1068 seq_get(args, 1) 1069 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1070 ), 1071 ) 1072 1073 return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format; `True` means the dialect's default `TIME_FORMAT` is used.
Returns:
A callable that can be used to return the appropriately formatted time expression.
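A hedged sketch of the intended call site, a dialect parser's `FUNCTIONS` registry (the entry below is hypothetical):

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_formatted_time

    FUNCTIONS = {
        # Parse STR_TO_DATE(<this>, <format>) into exp.StrToDate, translating
        # the format string through MySQL's TIME_MAPPING.
        "STR_TO_DATE": build_formatted_time(exp.StrToDate, "mysql"),
    }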
1076def time_format( 1077 dialect: DialectType = None, 1078) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1079 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1080 """ 1081 Returns the time format for a given expression, unless it's equivalent 1082 to the default time format of the dialect of interest. 1083 """ 1084 time_format = self.format_time(expression) 1085 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1086 1087 return _time_format
1090def build_date_delta( 1091 exp_class: t.Type[E], 1092 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1093 default_unit: t.Optional[str] = "DAY", 1094) -> t.Callable[[t.List], E]: 1095 def _builder(args: t.List) -> E: 1096 unit_based = len(args) == 3 1097 this = args[2] if unit_based else seq_get(args, 0) 1098 unit = None 1099 if unit_based or default_unit: 1100 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1101 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1102 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1103 1104 return _builder
1107def build_date_delta_with_interval( 1108 expression_class: t.Type[E], 1109) -> t.Callable[[t.List], t.Optional[E]]: 1110 def _builder(args: t.List) -> t.Optional[E]: 1111 if len(args) < 2: 1112 return None 1113 1114 interval = args[1] 1115 1116 if not isinstance(interval, exp.Interval): 1117 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1118 1119 expression = interval.this 1120 if expression and expression.is_string: 1121 expression = exp.Literal.number(expression.this) 1122 1123 return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval)) 1124 1125 return _builder
1128def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1129 unit = seq_get(args, 0) 1130 this = seq_get(args, 1) 1131 1132 if isinstance(this, exp.Cast) and this.is_type("date"): 1133 return exp.DateTrunc(unit=unit, this=this) 1134 return exp.TimestampTrunc(this=this, unit=unit)
1137def date_add_interval_sql( 1138 data_type: str, kind: str 1139) -> t.Callable[[Generator, exp.Expression], str]: 1140 def func(self: Generator, expression: exp.Expression) -> str: 1141 this = self.sql(expression, "this") 1142 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1143 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1144 1145 return func
1148def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1149 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1150 args = [unit_to_str(expression), expression.this] 1151 if zone: 1152 args.append(expression.args.get("zone")) 1153 return self.func("DATE_TRUNC", *args) 1154 1155 return _timestamptrunc_sql
1158def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1159 zone = expression.args.get("zone") 1160 if not zone: 1161 from sqlglot.optimizer.annotate_types import annotate_types 1162 1163 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1164 return self.sql(exp.cast(expression.this, target_type)) 1165 if zone.name.lower() in TIMEZONES: 1166 return self.sql( 1167 exp.AtTimeZone( 1168 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1169 zone=zone, 1170 ) 1171 ) 1172 return self.func("TIMESTAMP", expression.this, zone)
1175def no_time_sql(self: Generator, expression: exp.Time) -> str: 1176 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1177 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1178 expr = exp.cast( 1179 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1180 ) 1181 return self.sql(expr)
1184def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1185 this = expression.this 1186 expr = expression.expression 1187 1188 if expr.name.lower() in TIMEZONES: 1189 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1190 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1191 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1192 return self.sql(this) 1193 1194 this = exp.cast(this, exp.DataType.Type.DATE) 1195 expr = exp.cast(expr, exp.DataType.Type.TIME) 1196 1197 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
1238def encode_decode_sql( 1239 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1240) -> str: 1241 charset = expression.args.get("charset") 1242 if charset and charset.name.lower() != "utf-8": 1243 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1244 1245 return self.func(name, expression.this, expression.args.get("replace") if replace else None)
1258def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1259 cond = expression.this 1260 1261 if isinstance(expression.this, exp.Distinct): 1262 cond = expression.this.expressions[0] 1263 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1264 1265 return self.func("sum", exp.func("if", cond, 1, 0))
1268def trim_sql(self: Generator, expression: exp.Trim) -> str: 1269 target = self.sql(expression, "this") 1270 trim_type = self.sql(expression, "position") 1271 remove_chars = self.sql(expression, "expression") 1272 collation = self.sql(expression, "collation") 1273 1274 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1275 if not remove_chars and not collation: 1276 return self.trim_sql(expression) 1277 1278 trim_type = f"{trim_type} " if trim_type else "" 1279 remove_chars = f"{remove_chars} " if remove_chars else "" 1280 from_part = "FROM " if trim_type or remove_chars else "" 1281 collation = f" COLLATE {collation}" if collation else "" 1282 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
1303def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 1304 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 1305 if bad_args: 1306 self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") 1307 1308 return self.func( 1309 "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") 1310 )
1313def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 1314 bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers"))) 1315 if bad_args: 1316 self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") 1317 1318 return self.func( 1319 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 1320 )
1323def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1324 names = [] 1325 for agg in aggregations: 1326 if isinstance(agg, exp.Alias): 1327 names.append(agg.alias) 1328 else: 1329 """ 1330 This case corresponds to aggregations without aliases being used as suffixes 1331 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1332 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1333 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1334 """ 1335 agg_all_unquoted = agg.transform( 1336 lambda node: ( 1337 exp.Identifier(this=node.name, quoted=False) 1338 if isinstance(node, exp.Identifier) 1339 else node 1340 ) 1341 ) 1342 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1343 1344 return names
1384def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1385 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1386 if expression.args.get("count"): 1387 self.unsupported(f"Only two arguments are supported in function {name}.") 1388 1389 return self.func(name, expression.this, expression.expression) 1390 1391 return _arg_max_or_min_sql
1394def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1395 this = expression.this.copy() 1396 1397 return_type = expression.return_type 1398 if return_type.is_type(exp.DataType.Type.DATE): 1399 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1400 # can truncate timestamp strings, because some dialects can't cast them to DATE 1401 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1402 1403 expression.this.replace(exp.cast(this, return_type)) 1404 return expression
1407def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1408 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1409 if cast and isinstance(expression, exp.TsOrDsAdd): 1410 expression = ts_or_ds_add_cast(expression) 1411 1412 return self.func( 1413 name, 1414 unit_to_var(expression), 1415 expression.expression, 1416 expression.this, 1417 ) 1418 1419 return _delta_sql
1422def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1423 unit = expression.args.get("unit") 1424 1425 if isinstance(unit, exp.Placeholder): 1426 return unit 1427 if unit: 1428 return exp.Literal.string(unit.name) 1429 return exp.Literal.string(default) if default else None
1459def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1460 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1461 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1462 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1463 1464 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1467def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1468 """Remove table refs from columns in when statements.""" 1469 alias = expression.this.args.get("alias") 1470 1471 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1472 return self.dialect.normalize_identifier(identifier).name if identifier else None 1473 1474 targets = {normalize(expression.this.this)} 1475 1476 if alias: 1477 targets.add(normalize(alias.this)) 1478 1479 for when in expression.expressions: 1480 when.transform( 1481 lambda node: ( 1482 exp.column(node.this) 1483 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1484 else node 1485 ), 1486 copy=False, 1487 ) 1488 1489 return self.merge_sql(expression)
Remove table refs from columns in when statements.
1492def build_json_extract_path( 1493 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1494) -> t.Callable[[t.List], F]: 1495 def _builder(args: t.List) -> F: 1496 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1497 for arg in args[1:]: 1498 if not isinstance(arg, exp.Literal): 1499 # We use the fallback parser because we can't really transpile non-literals safely 1500 return expr_type.from_arg_list(args) 1501 1502 text = arg.name 1503 if is_int(text): 1504 index = int(text) 1505 segments.append( 1506 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1507 ) 1508 else: 1509 segments.append(exp.JSONPathKey(this=text)) 1510 1511 # This is done to avoid failing in the expression validator due to the arg count 1512 del args[2:] 1513 return expr_type( 1514 this=seq_get(args, 0), 1515 expression=exp.JSONPath(expressions=segments), 1516 only_json_types=arrow_req_json_type, 1517 ) 1518 1519 return _builder
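A minimal sketch: literal arguments become `exp.JSONPath` segments, while any non-literal argument makes the builder fall back to `from_arg_list`.

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_json_extract_path

    builder = build_json_extract_path(exp.JSONExtract)
    node = builder([exp.column("data"), exp.Literal.string("a"), exp.Literal.number(0)])
    print(node.expression)  # a JSONPath equivalent to $.a[0]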
1522def json_extract_segments( 1523 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1524) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1525 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1526 path = expression.expression 1527 if not isinstance(path, exp.JSONPath): 1528 return rename_func(name)(self, expression) 1529 1530 segments = [] 1531 for segment in path.expressions: 1532 path = self.sql(segment) 1533 if path: 1534 if isinstance(segment, exp.JSONPathPart) and ( 1535 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1536 ): 1537 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1538 1539 segments.append(path) 1540 1541 if op: 1542 return f" {op} ".join([self.sql(expression.this), *segments]) 1543 return self.func(name, expression.this, *segments) 1544 1545 return _json_extract_segments
1555def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1556 cond = expression.expression 1557 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1558 alias = cond.expressions[0] 1559 cond = cond.this 1560 elif isinstance(cond, exp.Predicate): 1561 alias = "_u" 1562 else: 1563 self.unsupported("Unsupported filter condition") 1564 return "" 1565 1566 unnest = exp.Unnest(expressions=[expression.this]) 1567 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1568 return self.sql(exp.Array(expressions=[filtered]))
1580def build_default_decimal_type( 1581 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1582) -> t.Callable[[exp.DataType], exp.DataType]: 1583 def _builder(dtype: exp.DataType) -> exp.DataType: 1584 if dtype.expressions or precision is None: 1585 return dtype 1586 1587 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1588 return exp.DataType.build(f"DECIMAL({params})") 1589 1590 return _builder
1593def build_timestamp_from_parts(args: t.List) -> exp.Func: 1594 if len(args) == 2: 1595 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1596 # so we parse this into Anonymous for now instead of introducing complexity 1597 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1598 1599 return exp.TimestampFromParts.from_arg_list(args)
1606def sequence_sql(self: Generator, expression: exp.GenerateSeries): 1607 start = expression.args["start"] 1608 end = expression.args["end"] 1609 step = expression.args.get("step") 1610 1611 if isinstance(start, exp.Cast): 1612 target_type = start.to 1613 elif isinstance(end, exp.Cast): 1614 target_type = end.to 1615 else: 1616 target_type = None 1617 1618 if target_type and target_type.is_type("timestamp"): 1619 if target_type is start.to: 1620 end = exp.cast(end, target_type) 1621 else: 1622 start = exp.cast(start, target_type) 1623 1624 return self.func("SEQUENCE", start, end, step)