Skip to content

Commit ecca1bf

Browse files
ianmcookHyukjinKwon
andcommitted
[SPARK-47365][PYTHON] Add toArrow() DataFrame method to PySpark
### What changes were proposed in this pull request? - Add a PySpark DataFrame method `toArrow()` which returns the contents of the DataFrame as a [PyArrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html), for both local Spark and Spark Connect. - Add a new entry to the **Apache Arrow in PySpark** user guide page describing usage of the `toArrow()` method. - Add a new option to the method `_collect_as_arrow()` to provide more useful output when there are zero records returned. (This keeps the implementation of `toArrow()` simpler.) ### Why are the changes needed? In the Apache Arrow community, we hear from a lot of users who want to return the contents of a PySpark DataFrame as a PyArrow Table. Currently the only documented way to do this is to return the contents as a pandas DataFrame, then use PyArrow (`pa`) to convert that to a PyArrow Table. ```py pa.Table.from_pandas(df.toPandas()) ``` But going through pandas adds significant overhead which is easily avoided since internally `toPandas()` already converts the contents of Spark DataFrame to Arrow format as an intermediate step when `spark.sql.execution.arrow.pyspark.enabled` is `true`. Currently it is also possible to use the experimental `_collect_as_arrow()` method to return the contents of a PySpark DataFrame as a list of PyArrow RecordBatches. This PR adds a new non-experimental method `toArrow()` which returns the more user-friendly PyArrow Table object. This PR also adds a new argument `empty_list_if_zero_records` to the experimental method `_collect_as_arrow()` to control what the method returns in the case when the result data has zero rows. If set to `True` (the default), the existing behavior is preserved, and the method returns an empty Python list. If set to `False`, the method returns returns a length-one list containing an empty Arrow RecordBatch which includes the schema. This is used by `toArrow()` which requires the schema even if the data has zero rows. For Spark Connect, there is already a `SparkSession.client.to_table()` method that returns a PyArrow table. This PR uses that to expose `toArrow()` for Spark Connect. ### Does this PR introduce _any_ user-facing change? - It adds a DataFrame method `toArrow()` to the PySpark SQL DataFrame API. - It adds a new argument `empty_list_if_zero_records` to the experimental DataFrame method `_collect_as_arrow()` with a default value which preserves the method's existing behavior. - It exposes `toArrow()` for Spark Connect, via the existing `SparkSession.client.to_table()` method. - It does not introduce any other user-facing changes. ### How was this patch tested? This adds a new test and a new helper function for the test in `pyspark/sql/tests/test_arrow.py`. ### Was this patch authored or co-authored using generative AI tooling? No Closes apache#45481 from ianmcook/SPARK-47365. Lead-authored-by: Ian Cook <ianmcook@gmail.com> Co-authored-by: Hyukjin Kwon <gurwls223@gmail.com> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 027327d commit ecca1bf

8 files changed

Lines changed: 169 additions & 20 deletions

File tree

examples/src/main/python/sql/arrow.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,22 @@
3333
require_minimum_pyarrow_version()
3434

3535

36+
def dataframe_to_arrow_table_example(spark: SparkSession) -> None:
37+
import pyarrow as pa # noqa: F401
38+
from pyspark.sql.functions import rand
39+
40+
# Create a Spark DataFrame
41+
df = spark.range(100).drop("id").withColumns({"0": rand(), "1": rand(), "2": rand()})
42+
43+
# Convert the Spark DataFrame to a PyArrow Table
44+
table = df.select("*").toArrow()
45+
46+
print(table.schema)
47+
# 0: double not null
48+
# 1: double not null
49+
# 2: double not null
50+
51+
3652
def dataframe_with_arrow_example(spark: SparkSession) -> None:
3753
import numpy as np
3854
import pandas as pd
@@ -302,6 +318,8 @@ def arrow_slen(s): # type: ignore[no-untyped-def]
302318
.appName("Python Arrow-in-Spark example") \
303319
.getOrCreate()
304320

321+
print("Running Arrow conversion example: DataFrame to Table")
322+
dataframe_to_arrow_table_example(spark)
305323
print("Running Pandas to/from conversion example")
306324
dataframe_with_arrow_example(spark)
307325
print("Running pandas_udf example: Series to Frame")

python/docs/source/reference/pyspark.sql/dataframe.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ DataFrame
109109
DataFrame.tail
110110
DataFrame.take
111111
DataFrame.to
112+
DataFrame.toArrow
112113
DataFrame.toDF
113114
DataFrame.toJSON
114115
DataFrame.toLocalIterator

python/docs/source/user_guide/sql/arrow_pandas.rst

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,20 @@ is installed and available on all cluster nodes.
3939
You can install it using pip or conda from the conda-forge channel. See PyArrow
4040
`installation <https://arrow.apache.org/docs/python/install.html>`_ for details.
4141

42+
Conversion to Arrow Table
43+
-------------------------
44+
45+
You can call :meth:`DataFrame.toArrow` to convert a Spark DataFrame to a PyArrow Table.
46+
47+
.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
48+
:language: python
49+
:lines: 37-49
50+
:dedent: 4
51+
52+
Note that :meth:`DataFrame.toArrow` results in the collection of all records in the DataFrame to
53+
the driver program and should be done on a small subset of the data. Not all Spark data types are
54+
currently supported and an error can be raised if a column has an unsupported type.
55+
4256
Enabling for Conversion to/from Pandas
4357
--------------------------------------
4458

@@ -53,7 +67,7 @@ This can be controlled by ``spark.sql.execution.arrow.pyspark.fallback.enabled``
5367

5468
.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
5569
:language: python
56-
:lines: 37-52
70+
:lines: 53-68
5771
:dedent: 4
5872

5973
Using the above optimizations with Arrow will produce the same results as when Arrow is not
@@ -90,7 +104,7 @@ specify the type hints of ``pandas.Series`` and ``pandas.DataFrame`` as below:
90104

91105
.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
92106
:language: python
93-
:lines: 56-80
107+
:lines: 72-96
94108
:dedent: 4
95109

96110
In the following sections, it describes the combinations of the supported type hints. For simplicity,
@@ -113,7 +127,7 @@ The following example shows how to create this Pandas UDF that computes the prod
113127

114128
.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
115129
:language: python
116-
:lines: 84-114
130+
:lines: 100-130
117131
:dedent: 4
118132

119133
For detailed usage, please see :func:`pandas_udf`.
@@ -152,7 +166,7 @@ The following example shows how to create this Pandas UDF:
152166

153167
.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
154168
:language: python
155-
:lines: 118-140
169+
:lines: 134-156
156170
:dedent: 4
157171

158172
For detailed usage, please see :func:`pandas_udf`.
@@ -174,7 +188,7 @@ The following example shows how to create this Pandas UDF:
174188

175189
.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
176190
:language: python
177-
:lines: 144-167
191+
:lines: 160-183
178192
:dedent: 4
179193

180194
For detailed usage, please see :func:`pandas_udf`.
@@ -205,7 +219,7 @@ and window operations:
205219

206220
.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
207221
:language: python
208-
:lines: 171-212
222+
:lines: 187-228
209223
:dedent: 4
210224

211225
.. currentmodule:: pyspark.sql.functions
@@ -270,7 +284,7 @@ in the group.
270284

271285
.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
272286
:language: python
273-
:lines: 216-234
287+
:lines: 232-250
274288
:dedent: 4
275289

276290
For detailed usage, please see please see :meth:`GroupedData.applyInPandas`
@@ -288,7 +302,7 @@ The following example shows how to use :meth:`DataFrame.mapInPandas`:
288302

289303
.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
290304
:language: python
291-
:lines: 238-249
305+
:lines: 254-265
292306
:dedent: 4
293307

294308
For detailed usage, please see :meth:`DataFrame.mapInPandas`.
@@ -327,7 +341,7 @@ The following example shows how to use ``DataFrame.groupby().cogroup().applyInPa
327341

328342
.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
329343
:language: python
330-
:lines: 253-275
344+
:lines: 269-291
331345
:dedent: 4
332346

333347

@@ -349,7 +363,7 @@ Here's an example that demonstrates the usage of both a default, pickled Python
349363

350364
.. literalinclude:: ../../../../../examples/src/main/python/sql/arrow.py
351365
:language: python
352-
:lines: 279-297
366+
:lines: 295-313
353367
:dedent: 4
354368

355369
Compared to the default, pickled Python UDFs, Arrow Python UDFs provide a more coherent type coercion mechanism. UDF
@@ -421,9 +435,12 @@ be verified by the user.
421435
Setting Arrow ``self_destruct`` for memory savings
422436
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
423437

424-
Since Spark 3.2, the Spark configuration ``spark.sql.execution.arrow.pyspark.selfDestruct.enabled`` can be used to enable PyArrow's ``self_destruct`` feature, which can save memory when creating a Pandas DataFrame via ``toPandas`` by freeing Arrow-allocated memory while building the Pandas DataFrame.
425-
This option is experimental, and some operations may fail on the resulting Pandas DataFrame due to immutable backing arrays.
426-
Typically, you would see the error ``ValueError: buffer source array is read-only``.
427-
Newer versions of Pandas may fix these errors by improving support for such cases.
428-
You can work around this error by copying the column(s) beforehand.
429-
Additionally, this conversion may be slower because it is single-threaded.
438+
Since Spark 3.2, the Spark configuration ``spark.sql.execution.arrow.pyspark.selfDestruct.enabled``
439+
can be used to enable PyArrow's ``self_destruct`` feature, which can save memory when creating a
440+
Pandas DataFrame via ``toPandas`` by freeing Arrow-allocated memory while building the Pandas
441+
DataFrame. This option can also save memory when creating a PyArrow Table via ``toArrow``.
442+
This option is experimental. When used with ``toPandas``, some operations may fail on the resulting
443+
Pandas DataFrame due to immutable backing arrays. Typically, you would see the error
444+
``ValueError: buffer source array is read-only``. Newer versions of Pandas may fix these errors by
445+
improving support for such cases. You can work around this error by copying the column(s)
446+
beforehand. Additionally, this conversion may be slower because it is single-threaded.

python/pyspark/sql/classic/dataframe.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474

7575
if TYPE_CHECKING:
7676
from py4j.java_gateway import JavaObject
77+
import pyarrow as pa
7778
from pyspark.core.rdd import RDD
7879
from pyspark.core.context import SparkContext
7980
from pyspark._typing import PrimitiveType
@@ -1825,6 +1826,9 @@ def mapInArrow(
18251826
) -> ParentDataFrame:
18261827
return PandasMapOpsMixin.mapInArrow(self, func, schema, barrier, profile)
18271828

1829+
def toArrow(self) -> "pa.Table":
1830+
return PandasConversionMixin.toArrow(self)
1831+
18281832
def toPandas(self) -> "PandasDataFrameLike":
18291833
return PandasConversionMixin.toPandas(self)
18301834

python/pyspark/sql/connect/dataframe.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1768,6 +1768,10 @@ def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]:
17681768
assert table is not None
17691769
return (table, schema)
17701770

1771+
def toArrow(self) -> "pa.Table":
1772+
table, _ = self._to_table()
1773+
return table
1774+
17711775
def toPandas(self) -> "PandasDataFrameLike":
17721776
query = self._plan.to_proto(self._session.client)
17731777
return self._session.client.to_pandas(query, self._plan.observations)

python/pyspark/sql/dataframe.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
if TYPE_CHECKING:
4646
from py4j.java_gateway import JavaObject
47+
import pyarrow as pa
4748
from pyspark.core.context import SparkContext
4849
from pyspark.core.rdd import RDD
4950
from pyspark._typing import PrimitiveType
@@ -1200,6 +1201,7 @@ def collect(self) -> List[Row]:
12001201
DataFrame.take : Returns the first `n` rows.
12011202
DataFrame.head : Returns the first `n` rows.
12021203
DataFrame.toPandas : Returns the data as a pandas DataFrame.
1204+
DataFrame.toArrow : Returns the data as a PyArrow Table.
12031205
12041206
Notes
12051207
-----
@@ -6213,6 +6215,34 @@ def mapInArrow(
62136215
"""
62146216
...
62156217

6218+
@dispatch_df_method
6219+
def toArrow(self) -> "pa.Table":
6220+
"""
6221+
Returns the contents of this :class:`DataFrame` as PyArrow ``pyarrow.Table``.
6222+
6223+
This is only available if PyArrow is installed and available.
6224+
6225+
.. versionadded:: 4.0.0
6226+
6227+
Notes
6228+
-----
6229+
This method should only be used if the resulting PyArrow ``pyarrow.Table`` is
6230+
expected to be small, as all the data is loaded into the driver's memory.
6231+
6232+
This API is a developer API.
6233+
6234+
Examples
6235+
--------
6236+
>>> df.toArrow() # doctest: +SKIP
6237+
pyarrow.Table
6238+
age: int64
6239+
name: string
6240+
----
6241+
age: [[2,5]]
6242+
name: [["Alice","Bob"]]
6243+
"""
6244+
...
6245+
62166246
def toPandas(self) -> "PandasDataFrameLike":
62176247
"""
62186248
Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.

python/pyspark/sql/pandas/conversion.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,15 +225,48 @@ def toPandas(self) -> "PandasDataFrameLike":
225225
else:
226226
return pdf
227227

228-
def _collect_as_arrow(self, split_batches: bool = False) -> List["pa.RecordBatch"]:
228+
def toArrow(self) -> "pa.Table":
229+
from pyspark.sql.dataframe import DataFrame
230+
231+
assert isinstance(self, DataFrame)
232+
233+
jconf = self.sparkSession._jconf
234+
235+
from pyspark.sql.pandas.types import to_arrow_schema
236+
from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
237+
238+
require_minimum_pyarrow_version()
239+
to_arrow_schema(self.schema)
240+
241+
import pyarrow as pa
242+
243+
self_destruct = jconf.arrowPySparkSelfDestructEnabled()
244+
batches = self._collect_as_arrow(
245+
split_batches=self_destruct, empty_list_if_zero_records=False
246+
)
247+
table = pa.Table.from_batches(batches)
248+
# Ensure only the table has a reference to the batches, so that
249+
# self_destruct (if enabled) is effective
250+
del batches
251+
return table
252+
253+
def _collect_as_arrow(
254+
self,
255+
split_batches: bool = False,
256+
empty_list_if_zero_records: bool = True,
257+
) -> List["pa.RecordBatch"]:
229258
"""
230-
Returns all records as a list of ArrowRecordBatches, pyarrow must be installed
259+
Returns all records as a list of Arrow RecordBatches. PyArrow must be installed
231260
and available on driver and worker Python environments.
232261
This is an experimental feature.
233262
234263
:param split_batches: split batches such that each column is in its own allocation, so
235264
that the selfDestruct optimization is effective; default False.
236265
266+
:param empty_list_if_zero_records: If True (the default), returns an empty list if the
267+
result has 0 records. Otherwise, returns a list of length 1 containing an empty
268+
Arrow RecordBatch which includes the schema.
269+
237270
.. note:: Experimental.
238271
"""
239272
from pyspark.sql.dataframe import DataFrame
@@ -282,8 +315,15 @@ def _collect_as_arrow(self, split_batches: bool = False) -> List["pa.RecordBatch
282315
batches = results[:-1]
283316
batch_order = results[-1]
284317

285-
# Re-order the batch list using the correct order
286-
return [batches[i] for i in batch_order]
318+
if len(batches) or empty_list_if_zero_records:
319+
# Re-order the batch list using the correct order
320+
return [batches[i] for i in batch_order]
321+
else:
322+
from pyspark.sql.pandas.types import to_arrow_schema
323+
324+
schema = to_arrow_schema(self.schema)
325+
empty_arrays = [pa.array([], type=field.type) for field in schema]
326+
return [pa.RecordBatch.from_arrays(empty_arrays, schema=schema)]
287327

288328

289329
class SparkConversionMixin:

python/pyspark/sql/tests/test_arrow.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,35 @@ def create_pandas_data_frame(self):
179179
data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
180180
return pd.DataFrame(data=data_dict)
181181

182+
def create_arrow_table(self):
183+
import pyarrow as pa
184+
import pyarrow.compute as pc
185+
186+
data_dict = {}
187+
for j, name in enumerate(self.schema.names):
188+
data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
189+
t = pa.Table.from_pydict(data_dict)
190+
# convert these to Arrow types
191+
new_schema = t.schema.set(
192+
t.schema.get_field_index("2_int_t"), pa.field("2_int_t", pa.int32())
193+
)
194+
new_schema = new_schema.set(
195+
new_schema.get_field_index("4_float_t"), pa.field("4_float_t", pa.float32())
196+
)
197+
new_schema = new_schema.set(
198+
new_schema.get_field_index("6_decimal_t"),
199+
pa.field("6_decimal_t", pa.decimal128(38, 18)),
200+
)
201+
t = t.cast(new_schema)
202+
# convert timestamp to local timezone
203+
timezone = self.spark.conf.get("spark.sql.session.timeZone")
204+
t = t.set_column(
205+
t.schema.get_field_index("8_timestamp_t"),
206+
"8_timestamp_t",
207+
pc.assume_timezone(t["8_timestamp_t"], timezone),
208+
)
209+
return t
210+
182211
@property
183212
def create_np_arrs(self):
184213
import numpy as np
@@ -339,6 +368,12 @@ def test_pandas_round_trip(self):
339368
pdf_arrow = df.toPandas()
340369
assert_frame_equal(pdf_arrow, pdf)
341370

371+
def test_arrow_round_trip(self):
372+
t_in = self.create_arrow_table()
373+
df = self.spark.createDataFrame(self.data, schema=self.schema)
374+
t_out = df.toArrow()
375+
self.assertTrue(t_out.equals(t_in))
376+
342377
def test_pandas_self_destruct(self):
343378
import pyarrow as pa
344379

0 commit comments

Comments
 (0)