Skip to content

Commit 78489c6

Browse files
committed
Cloud fetch queue and integration
Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com>
1 parent 01b7a8d commit 78489c6

5 files changed

Lines changed: 518 additions & 125 deletions

File tree

src/databricks/sql/client.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,10 @@ def read(self) -> Optional[OAuthToken]:
153153
# _use_arrow_native_timestamps
154154
# Databricks runtime will return native Arrow types for timestamps instead of Arrow strings
155155
# (True by default)
156+
# use_cloud_fetch
157+
# Enable use of cloud fetch to extract large query results in parallel via cloud storage
158+
# max_download_threads
159+
# Number of threads for handling cloud fetch downloads. Defaults to 10
156160

157161
if access_token:
158162
access_token_kv = {"access_token": access_token}
@@ -189,6 +193,8 @@ def read(self) -> Optional[OAuthToken]:
189193
self._session_handle = self.thrift_backend.open_session(
190194
session_configuration, catalog, schema
191195
)
196+
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", False)
197+
self.max_download_threads = kwargs.get("max_download_threads", 10)
192198
self.open = True
193199
logger.info("Successfully opened session " + str(self.get_session_id_hex()))
194200
self._cursors = [] # type: List[Cursor]
@@ -497,6 +503,7 @@ def execute(
497503
max_bytes=self.buffer_size_bytes,
498504
lz4_compression=self.connection.lz4_compression,
499505
cursor=self,
506+
use_cloud_fetch=self.connection.use_cloud_fetch,
500507
)
501508
self.active_result_set = ResultSet(
502509
self.connection,
@@ -804,6 +811,7 @@ def __init__(
804811
self.description = execute_response.description
805812
self._arrow_schema_bytes = execute_response.arrow_schema_bytes
806813
self._next_row_index = 0
814+
self.results = None
807815

808816
if execute_response.arrow_queue:
809817
# In this case the server has taken the fast path and returned an initial batch of
@@ -822,6 +830,7 @@ def __iter__(self):
822830
break
823831

824832
def _fill_results_buffer(self):
833+
# At initialization or if the server does not have cloud fetch result links available
825834
results, has_more_rows = self.thrift_backend.fetch_results(
826835
op_handle=self.command_id,
827836
max_rows=self.arraysize,

src/databricks/sql/thrift_backend.py

Lines changed: 29 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import time
66
import uuid
77
import threading
8-
import lz4.frame
98
from ssl import CERT_NONE, CERT_REQUIRED, create_default_context
109
from typing import List, Union
1110

@@ -31,6 +30,10 @@
3130
_bound,
3231
RequestErrorInfo,
3332
NoRetryReason,
33+
ResultSetQueueFactory,
34+
convert_arrow_based_set_to_arrow_table,
35+
convert_decimals_in_arrow_table,
36+
convert_column_based_set_to_arrow_table,
3437
)
3538

3639
logger = logging.getLogger(__name__)
@@ -67,7 +70,6 @@
6770
class ThriftBackend:
6871
CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE
6972
ERROR_OP_STATE = ttypes.TOperationState.ERROR_STATE
70-
BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]
7173

7274
def __init__(
7375
self,
@@ -558,108 +560,19 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti
558560
(
559561
arrow_table,
560562
num_rows,
561-
) = ThriftBackend._convert_column_based_set_to_arrow_table(
563+
) = convert_column_based_set_to_arrow_table(
562564
t_row_set.columns, description
563565
)
564566
elif t_row_set.arrowBatches is not None:
565567
(
566568
arrow_table,
567569
num_rows,
568-
) = ThriftBackend._convert_arrow_based_set_to_arrow_table(
570+
) = convert_arrow_based_set_to_arrow_table(
569571
t_row_set.arrowBatches, lz4_compressed, schema_bytes
570572
)
571573
else:
572574
raise OperationalError("Unsupported TRowSet instance {}".format(t_row_set))
573-
return self._convert_decimals_in_arrow_table(arrow_table, description), num_rows
574-
575-
@staticmethod
576-
def _convert_decimals_in_arrow_table(table, description):
577-
for (i, col) in enumerate(table.itercolumns()):
578-
if description[i][1] == "decimal":
579-
decimal_col = col.to_pandas().apply(
580-
lambda v: v if v is None else Decimal(v)
581-
)
582-
precision, scale = description[i][4], description[i][5]
583-
assert scale is not None
584-
assert precision is not None
585-
# Spark limits decimal to a maximum scale of 38,
586-
# so 128 is guaranteed to be big enough
587-
dtype = pyarrow.decimal128(precision, scale)
588-
col_data = pyarrow.array(decimal_col, type=dtype)
589-
field = table.field(i).with_type(dtype)
590-
table = table.set_column(i, field, col_data)
591-
return table
592-
593-
@staticmethod
594-
def _convert_arrow_based_set_to_arrow_table(
595-
arrow_batches, lz4_compressed, schema_bytes
596-
):
597-
ba = bytearray()
598-
ba += schema_bytes
599-
n_rows = 0
600-
if lz4_compressed:
601-
for arrow_batch in arrow_batches:
602-
n_rows += arrow_batch.rowCount
603-
ba += lz4.frame.decompress(arrow_batch.batch)
604-
else:
605-
for arrow_batch in arrow_batches:
606-
n_rows += arrow_batch.rowCount
607-
ba += arrow_batch.batch
608-
arrow_table = pyarrow.ipc.open_stream(ba).read_all()
609-
return arrow_table, n_rows
610-
611-
@staticmethod
612-
def _convert_column_based_set_to_arrow_table(columns, description):
613-
arrow_table = pyarrow.Table.from_arrays(
614-
[ThriftBackend._convert_column_to_arrow_array(c) for c in columns],
615-
# Only use the column names from the schema, the types are determined by the
616-
# physical types used in column based set, as they can differ from the
617-
# mapping used in _hive_schema_to_arrow_schema.
618-
names=[c[0] for c in description],
619-
)
620-
return arrow_table, arrow_table.num_rows
621-
622-
@staticmethod
623-
def _convert_column_to_arrow_array(t_col):
624-
"""
625-
Return a pyarrow array from the values in a TColumn instance.
626-
Note that ColumnBasedSet has no native support for complex types, so they will be converted
627-
to strings server-side.
628-
"""
629-
field_name_to_arrow_type = {
630-
"boolVal": pyarrow.bool_(),
631-
"byteVal": pyarrow.int8(),
632-
"i16Val": pyarrow.int16(),
633-
"i32Val": pyarrow.int32(),
634-
"i64Val": pyarrow.int64(),
635-
"doubleVal": pyarrow.float64(),
636-
"stringVal": pyarrow.string(),
637-
"binaryVal": pyarrow.binary(),
638-
}
639-
for field in field_name_to_arrow_type.keys():
640-
wrapper = getattr(t_col, field)
641-
if wrapper:
642-
return ThriftBackend._create_arrow_array(
643-
wrapper, field_name_to_arrow_type[field]
644-
)
645-
646-
raise OperationalError("Empty TColumn instance {}".format(t_col))
647-
648-
@staticmethod
649-
def _create_arrow_array(t_col_value_wrapper, arrow_type):
650-
result = t_col_value_wrapper.values
651-
nulls = t_col_value_wrapper.nulls # bitfield describing which values are null
652-
assert isinstance(nulls, bytes)
653-
654-
# The number of bits in nulls can be both larger or smaller than the number of
655-
# elements in result, so take the minimum of both to iterate over.
656-
length = min(len(result), len(nulls) * 8)
657-
658-
for i in range(length):
659-
if nulls[i >> 3] & ThriftBackend.BIT_MASKS[i & 0x7]:
660-
result[i] = None
661-
662-
return pyarrow.array(result, type=arrow_type)
575+
return convert_decimals_in_arrow_table(arrow_table, description), num_rows
663576

664577
def _get_metadata_resp(self, op_handle):
665578
req = ttypes.TGetResultSetMetadataReq(operationHandle=op_handle)
@@ -752,6 +665,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
752665
if t_result_set_metadata_resp.resultFormat not in [
753666
ttypes.TSparkRowSetType.ARROW_BASED_SET,
754667
ttypes.TSparkRowSetType.COLUMN_BASED_SET,
668+
ttypes.TSparkRowSetType.URL_BASED_SET,
755669
]:
756670
raise OperationalError(
757671
"Expected results to be in Arrow or column based format, "
@@ -783,13 +697,16 @@ def _results_message_to_execute_response(self, resp, operation_state):
783697
assert direct_results.resultSet.results.startRowOffset == 0
784698
assert direct_results.resultSetMetadata
785699

786-
arrow_results, n_rows = self._create_arrow_table(
787-
direct_results.resultSet.results,
788-
lz4_compressed,
789-
schema_bytes,
790-
description,
791-
)
792-
arrow_queue_opt = ArrowQueue(arrow_results, n_rows, 0)
700+
if direct_results.resultSet.results.resultLinks is None:
701+
arrow_results, n_rows = self._create_arrow_table(
702+
direct_results.resultSet.results,
703+
lz4_compressed,
704+
schema_bytes,
705+
description,
706+
)
707+
arrow_queue_opt = ArrowQueue(arrow_results, n_rows, 0)
708+
else:
709+
arrow_queue_opt = None
793710
else:
794711
arrow_queue_opt = None
795712
return ExecuteResponse(
@@ -843,7 +760,7 @@ def _check_direct_results_for_error(t_spark_direct_results):
843760
)
844761

845762
def execute_command(
846-
self, operation, session_handle, max_rows, max_bytes, lz4_compression, cursor
763+
self, operation, session_handle, max_rows, max_bytes, lz4_compression, cursor, use_cloud_fetch=False
847764
):
848765
assert session_handle is not None
849766

@@ -864,7 +781,7 @@ def execute_command(
864781
),
865782
canReadArrowResult=True,
866783
canDecompressLZ4Result=lz4_compression,
867-
canDownloadResult=False,
784+
canDownloadResult=use_cloud_fetch,
868785
confOverlay={
869786
# We want to receive proper Timestamp arrow types.
870787
"spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
@@ -993,6 +910,7 @@ def fetch_results(
993910
maxRows=max_rows,
994911
maxBytes=max_bytes,
995912
orientation=ttypes.TFetchOrientation.FETCH_NEXT,
913+
includeResultSetMetadata=True,
996914
)
997915

998916
resp = self.make_request(self._client.FetchResults, req)
@@ -1002,12 +920,16 @@ def fetch_results(
1002920
expected_row_start_offset, resp.results.startRowOffset
1003921
)
1004922
)
1005-
arrow_results, n_rows = self._create_arrow_table(
1006-
resp.results, lz4_compressed, arrow_schema_bytes, description
923+
924+
queue = ResultSetQueueFactory.build_queue(
925+
row_set_type=resp.resultSetMetadata.resultFormat,
926+
t_row_set=resp.results,
927+
arrow_schema_bytes=arrow_schema_bytes,
928+
lz4_compressed=lz4_compressed,
929+
description=description,
1007930
)
1008-
arrow_queue = ArrowQueue(arrow_results, n_rows)
1009931

1010-
return arrow_queue, resp.hasMoreRows
932+
return queue, resp.hasMoreRows
1011933

1012934
def close_command(self, op_handle):
1013935
req = ttypes.TCloseOperationReq(operationHandle=op_handle)

0 commit comments

Comments
 (0)