55import time
66import uuid
77import threading
8- import lz4 .frame
98from ssl import CERT_NONE , CERT_REQUIRED , create_default_context
109from typing import List , Union
1110
3130 _bound ,
3231 RequestErrorInfo ,
3332 NoRetryReason ,
33+ ResultSetQueueFactory ,
34+ convert_arrow_based_set_to_arrow_table ,
35+ convert_decimals_in_arrow_table ,
36+ convert_column_based_set_to_arrow_table ,
3437)
3538
3639logger = logging .getLogger (__name__ )
6770class ThriftBackend :
6871 CLOSED_OP_STATE = ttypes .TOperationState .CLOSED_STATE
6972 ERROR_OP_STATE = ttypes .TOperationState .ERROR_STATE
70- BIT_MASKS = [1 , 2 , 4 , 8 , 16 , 32 , 64 , 128 ]
7173
7274 def __init__ (
7375 self ,
@@ -558,108 +560,19 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti
558560 (
559561 arrow_table ,
560562 num_rows ,
561- ) = ThriftBackend . _convert_column_based_set_to_arrow_table (
563+ ) = convert_column_based_set_to_arrow_table (
562564 t_row_set .columns , description
563565 )
564566 elif t_row_set .arrowBatches is not None :
565567 (
566568 arrow_table ,
567569 num_rows ,
568- ) = ThriftBackend . _convert_arrow_based_set_to_arrow_table (
570+ ) = convert_arrow_based_set_to_arrow_table (
569571 t_row_set .arrowBatches , lz4_compressed , schema_bytes
570572 )
571573 else :
572574 raise OperationalError ("Unsupported TRowSet instance {}" .format (t_row_set ))
573- return self ._convert_decimals_in_arrow_table (arrow_table , description ), num_rows
574-
575- @staticmethod
576- def _convert_decimals_in_arrow_table (table , description ):
577- for (i , col ) in enumerate (table .itercolumns ()):
578- if description [i ][1 ] == "decimal" :
579- decimal_col = col .to_pandas ().apply (
580- lambda v : v if v is None else Decimal (v )
581- )
582- precision , scale = description [i ][4 ], description [i ][5 ]
583- assert scale is not None
584- assert precision is not None
585- # Spark limits decimal to a maximum scale of 38,
586- # so 128 is guaranteed to be big enough
587- dtype = pyarrow .decimal128 (precision , scale )
588- col_data = pyarrow .array (decimal_col , type = dtype )
589- field = table .field (i ).with_type (dtype )
590- table = table .set_column (i , field , col_data )
591- return table
592-
593- @staticmethod
594- def _convert_arrow_based_set_to_arrow_table (
595- arrow_batches , lz4_compressed , schema_bytes
596- ):
597- ba = bytearray ()
598- ba += schema_bytes
599- n_rows = 0
600- if lz4_compressed :
601- for arrow_batch in arrow_batches :
602- n_rows += arrow_batch .rowCount
603- ba += lz4 .frame .decompress (arrow_batch .batch )
604- else :
605- for arrow_batch in arrow_batches :
606- n_rows += arrow_batch .rowCount
607- ba += arrow_batch .batch
608- arrow_table = pyarrow .ipc .open_stream (ba ).read_all ()
609- return arrow_table , n_rows
610-
611- @staticmethod
612- def _convert_column_based_set_to_arrow_table (columns , description ):
613- arrow_table = pyarrow .Table .from_arrays (
614- [ThriftBackend ._convert_column_to_arrow_array (c ) for c in columns ],
615- # Only use the column names from the schema, the types are determined by the
616- # physical types used in column based set, as they can differ from the
617- # mapping used in _hive_schema_to_arrow_schema.
618- names = [c [0 ] for c in description ],
619- )
620- return arrow_table , arrow_table .num_rows
621-
622- @staticmethod
623- def _convert_column_to_arrow_array (t_col ):
624- """
625- Return a pyarrow array from the values in a TColumn instance.
626- Note that ColumnBasedSet has no native support for complex types, so they will be converted
627- to strings server-side.
628- """
629- field_name_to_arrow_type = {
630- "boolVal" : pyarrow .bool_ (),
631- "byteVal" : pyarrow .int8 (),
632- "i16Val" : pyarrow .int16 (),
633- "i32Val" : pyarrow .int32 (),
634- "i64Val" : pyarrow .int64 (),
635- "doubleVal" : pyarrow .float64 (),
636- "stringVal" : pyarrow .string (),
637- "binaryVal" : pyarrow .binary (),
638- }
639- for field in field_name_to_arrow_type .keys ():
640- wrapper = getattr (t_col , field )
641- if wrapper :
642- return ThriftBackend ._create_arrow_array (
643- wrapper , field_name_to_arrow_type [field ]
644- )
645-
646- raise OperationalError ("Empty TColumn instance {}" .format (t_col ))
647-
648- @staticmethod
649- def _create_arrow_array (t_col_value_wrapper , arrow_type ):
650- result = t_col_value_wrapper .values
651- nulls = t_col_value_wrapper .nulls # bitfield describing which values are null
652- assert isinstance (nulls , bytes )
653-
654- # The number of bits in nulls can be both larger or smaller than the number of
655- # elements in result, so take the minimum of both to iterate over.
656- length = min (len (result ), len (nulls ) * 8 )
657-
658- for i in range (length ):
659- if nulls [i >> 3 ] & ThriftBackend .BIT_MASKS [i & 0x7 ]:
660- result [i ] = None
661-
662- return pyarrow .array (result , type = arrow_type )
575+ return convert_decimals_in_arrow_table (arrow_table , description ), num_rows
663576
664577 def _get_metadata_resp (self , op_handle ):
665578 req = ttypes .TGetResultSetMetadataReq (operationHandle = op_handle )
@@ -752,6 +665,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
752665 if t_result_set_metadata_resp .resultFormat not in [
753666 ttypes .TSparkRowSetType .ARROW_BASED_SET ,
754667 ttypes .TSparkRowSetType .COLUMN_BASED_SET ,
668+ ttypes .TSparkRowSetType .URL_BASED_SET ,
755669 ]:
756670 raise OperationalError (
757671 "Expected results to be in Arrow or column based format, "
@@ -783,13 +697,16 @@ def _results_message_to_execute_response(self, resp, operation_state):
783697 assert direct_results .resultSet .results .startRowOffset == 0
784698 assert direct_results .resultSetMetadata
785699
786- arrow_results , n_rows = self ._create_arrow_table (
787- direct_results .resultSet .results ,
788- lz4_compressed ,
789- schema_bytes ,
790- description ,
791- )
792- arrow_queue_opt = ArrowQueue (arrow_results , n_rows , 0 )
700+ if direct_results .resultSet .results .resultLinks is None :
701+ arrow_results , n_rows = self ._create_arrow_table (
702+ direct_results .resultSet .results ,
703+ lz4_compressed ,
704+ schema_bytes ,
705+ description ,
706+ )
707+ arrow_queue_opt = ArrowQueue (arrow_results , n_rows , 0 )
708+ else :
709+ arrow_queue_opt = None
793710 else :
794711 arrow_queue_opt = None
795712 return ExecuteResponse (
@@ -843,7 +760,7 @@ def _check_direct_results_for_error(t_spark_direct_results):
843760 )
844761
845762 def execute_command (
846- self , operation , session_handle , max_rows , max_bytes , lz4_compression , cursor
763+ self , operation , session_handle , max_rows , max_bytes , lz4_compression , cursor , use_cloud_fetch = False
847764 ):
848765 assert session_handle is not None
849766
@@ -864,7 +781,7 @@ def execute_command(
864781 ),
865782 canReadArrowResult = True ,
866783 canDecompressLZ4Result = lz4_compression ,
867- canDownloadResult = False ,
784+ canDownloadResult = use_cloud_fetch ,
868785 confOverlay = {
869786 # We want to receive proper Timestamp arrow types.
870787 "spark.thriftserver.arrowBasedRowSet.timestampAsString" : "false"
@@ -993,6 +910,7 @@ def fetch_results(
993910 maxRows = max_rows ,
994911 maxBytes = max_bytes ,
995912 orientation = ttypes .TFetchOrientation .FETCH_NEXT ,
913+ includeResultSetMetadata = True ,
996914 )
997915
998916 resp = self .make_request (self ._client .FetchResults , req )
@@ -1002,12 +920,16 @@ def fetch_results(
1002920 expected_row_start_offset , resp .results .startRowOffset
1003921 )
1004922 )
1005- arrow_results , n_rows = self ._create_arrow_table (
1006- resp .results , lz4_compressed , arrow_schema_bytes , description
923+
924+ queue = ResultSetQueueFactory .build_queue (
925+ row_set_type = resp .resultSetMetadata .resultFormat ,
926+ t_row_set = resp .results ,
927+ arrow_schema_bytes = arrow_schema_bytes ,
928+ lz4_compressed = lz4_compressed ,
929+ description = description ,
1007930 )
1008- arrow_queue = ArrowQueue (arrow_results , n_rows )
1009931
1010- return arrow_queue , resp .hasMoreRows
932+ return queue , resp .hasMoreRows
1011933
1012934 def close_command (self , op_handle ):
1013935 req = ttypes .TCloseOperationReq (operationHandle = op_handle )
0 commit comments