Skip to content
4 changes: 2 additions & 2 deletions src/databricks/sqlalchemy/_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ def post_create_table(self, table):
return " USING DELTA"

def visit_unique_constraint(self, constraint, **kw):
logger.warn("Databricks does not support unique constraints")
logger.warning("Databricks does not support unique constraints")
pass

def visit_check_constraint(self, constraint, **kw):
logger.warn("Databricks does not support check constraints")
logger.warning("This dialect does not support check constraints")
pass

def visit_identity_column(self, identity, **kw):
Expand Down
121 changes: 106 additions & 15 deletions src/databricks/sqlalchemy/_parse.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from typing import List, Optional, Dict
import re

import sqlalchemy
from sqlalchemy.engine import CursorResult
from sqlalchemy.engine.interfaces import ReflectedColumn

"""
This module contains helper functions that can parse the contents
of metadata and exceptions received from DBR. These are mostly just
wrappers around regexes.
"""


def _match_table_not_found_string(message: str) -> bool:
"""Return True if the message contains a substring indicating that a table was not found"""

Expand All @@ -22,9 +25,10 @@ def _match_table_not_found_string(message: str) -> bool:
)


def _describe_table_extended_result_to_dict_list(result: CursorResult) -> List[Dict[str, str]]:
"""Transform the CursorResult of DESCRIBE TABLE EXTENDED into a list of Dictionaries
"""
def _describe_table_extended_result_to_dict_list(
result: CursorResult,
) -> List[Dict[str, str]]:
"""Transform the CursorResult of DESCRIBE TABLE EXTENDED into a list of Dictionaries"""

rows_to_return = []
for row in result:
Expand Down Expand Up @@ -68,22 +72,23 @@ def extract_three_level_identifier_from_constraint_string(input_str: str) -> dic
"""
pat = re.compile(r"REFERENCES\s+(.*?)\s*\(")
matches = pat.findall(input_str)

if not matches:
return None

first_match = matches[0]
parts = first_match.split(".")

def strip_backticks(input:str):
def strip_backticks(input: str):
return input.replace("`", "")

return {
"catalog": strip_backticks(parts[0]),
"catalog": strip_backticks(parts[0]),
"schema": strip_backticks(parts[1]),
"table": strip_backticks(parts[2])
"table": strip_backticks(parts[2]),
}


def _parse_fk_from_constraint_string(constraint_str: str) -> dict:
"""Build a dictionary of foreign key constraint information from a constraint string.

Expand Down Expand Up @@ -133,6 +138,7 @@ def _parse_fk_from_constraint_string(constraint_str: str) -> dict:
"referred_schema": referred_schema,
}


def build_fk_dict(
fk_name: str, fk_constraint_string: str, schema_name: Optional[str]
) -> dict:
Expand Down Expand Up @@ -172,6 +178,7 @@ def build_fk_dict(

return complete_foreign_key_dict


def _parse_pk_columns_from_constraint_string(constraint_str: str) -> List[str]:
"""Build a list of constrained columns from a constraint string returned by DESCRIBE TABLE EXTENDED

Expand All @@ -188,21 +195,23 @@ def _parse_pk_columns_from_constraint_string(constraint_str: str) -> List[str]:

return _extracted


def build_pk_dict(pk_name: str, pk_constraint_string: str) -> dict:
"""Given a primary key name and a primary key constraint string, return a dictionary
with the following keys:

constrained_columns
A list of string column names that make up the primary key

name
The name of the primary key constraint
"""

constrained_columns = _parse_pk_columns_from_constraint_string(pk_constraint_string)

return {"constrained_columns": constrained_columns, "name": pk_name}



def match_dte_rows_by_value(dte_output: List[Dict[str, str]], match: str) -> List[dict]:
"""Return a list of dictionaries containing only the col_name:data_type pairs where the `data_type`
value contains the match argument.
Expand All @@ -221,9 +230,10 @@ def match_dte_rows_by_value(dte_output: List[Dict[str, str]], match: str) -> Lis
for row_dict in dte_output:
if match in row_dict["data_type"]:
output_rows.append(row_dict)

return output_rows


def get_fk_strings_from_dte_output(dte_output: List[List]) -> List[dict]:
"""If the DESCRIBE TABLE EXTENDED output contains foreign key constraints, return a list of dictionaries,
one dictionary per defined constraint
Expand All @@ -233,8 +243,10 @@ def get_fk_strings_from_dte_output(dte_output: List[List]) -> List[dict]:

return output


def get_pk_strings_from_dte_output(dte_output: List[Dict[str, str]]) -> Optional[List[dict]]:

def get_pk_strings_from_dte_output(
dte_output: List[Dict[str, str]]
) -> Optional[List[dict]]:
"""If the DESCRIBE TABLE EXTENDED output contains primary key constraints, return a list of dictionaries,
one dictionary per defined constraint.

Expand All @@ -244,3 +256,82 @@ def get_pk_strings_from_dte_output(dte_output: List[Dict[str, str]]) -> Optional
output = match_dte_rows_by_value(dte_output, "PRIMARY KEY")

return output


# The keys of this dictionary are the values we expect to see in a
# TGetColumnsRequest's .TYPE_NAME attribute.
# These are enumerated in ttypes.py as class TTypeId.
# TODO: confirm that all types in TTypeId are included here.
GET_COLUMNS_TYPE_MAP = {
"boolean": sqlalchemy.types.Boolean,
"smallint": sqlalchemy.types.SmallInteger,
"int": sqlalchemy.types.Integer,
"bigint": sqlalchemy.types.BigInteger,
"float": sqlalchemy.types.Float,
"double": sqlalchemy.types.Float,
"string": sqlalchemy.types.String,
"varchar": sqlalchemy.types.String,
"char": sqlalchemy.types.String,
"binary": sqlalchemy.types.String,
"array": sqlalchemy.types.String,
"map": sqlalchemy.types.String,
"struct": sqlalchemy.types.String,
"uniontype": sqlalchemy.types.String,
"decimal": sqlalchemy.types.Numeric,
"timestamp": sqlalchemy.types.DateTime,
"date": sqlalchemy.types.Date,
}


def parse_numeric_type_precision_and_scale(type_name_str):
"""Return an intantiated sqlalchemy Numeric() type that preserves the precision and scale indicated
in the output from TGetColumnsRequest.

type_name_str
The value of TGetColumnsReq.TYPE_NAME.

If type_name_str is "DECIMAL(18,5) returns sqlalchemy.types.Numeric(18,5)
"""

pattern = re.compile(r"DECIMAL\((\d+,\d+)\)")
match = re.search(pattern, type_name_str)
precision_and_scale = match.group(1)
precision, scale = tuple(precision_and_scale.split(","))

return sqlalchemy.types.Numeric(int(precision), int(scale))


def parse_column_info_from_tgetcolumnsresponse(thrift_resp_row) -> ReflectedColumn:
"""Returns a dictionary of the ReflectedColumn schema parsed from
a single of the result of a TGetColumnsRequest thrift RPC
"""

pat = re.compile(r"^\w+")
_raw_col_type = re.search(pat, thrift_resp_row.TYPE_NAME).group(0).lower()
_col_type = GET_COLUMNS_TYPE_MAP[_raw_col_type]

if _raw_col_type == "decimal":
final_col_type = parse_numeric_type_precision_and_scale(
thrift_resp_row.TYPE_NAME
)
else:
final_col_type = _col_type

# See comments about autoincrement in test_suite.py
# Since Databricks SQL doesn't currently support inline AUTOINCREMENT declarations
# the autoincrement must be manually declared with an Identity() construct in SQLAlchemy
# Other dialects can perform this extra Identity() step automatically. But that is not
# implemented in the Databricks dialect right now. So autoincrement is currently always False.
# It's not clear what IS_AUTO_INCREMENT in the thrift response actually reflects or whether
# it ever returns a `YES`.

# Per the guidance in SQLAlchemy's docstrings, we prefer to not even include an autoincrement
# key in this dictionary.
this_column = {
"name": thrift_resp_row.COLUMN_NAME,
"type": final_col_type,
"nullable": bool(thrift_resp_row.NULLABLE),
"default": thrift_resp_row.COLUMN_DEF,
}

return this_column
Loading