pyspark.sql.GroupedData.applyInPandas

GroupedData.applyInPandas(func, schema)

Maps each group of the current DataFrame using a pandas UDF and returns the result as a DataFrame.

The function can take one of two forms: it can take a pandas.DataFrame and return a pandas.DataFrame, or it can take an iterator of pandas.DataFrame and yield pandas.DataFrame. Alternatively, each form can take a tuple of grouping keys as the first argument in addition to the input type above. For each group, all columns are passed together as a pandas.DataFrame or an iterator of pandas.DataFrame, and the returned pandas.DataFrame or iterator of pandas.DataFrame is combined into a new DataFrame.

The schema should be a StructType describing the schema of the returned pandas.DataFrame. The column labels of the returned pandas.DataFrame must either match the field names in the defined schema if they are strings, or match the field data types by position if they are not strings, e.g. integer indices. The length of the returned pandas.DataFrame can be arbitrary.
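
For instance, here is a minimal sketch of positional matching (the helper max_v and the field name v_max are illustrative, and row order in the result may vary): a pandas.DataFrame built from a list of tuples has integer column labels, so its columns are matched to the schema fields by position rather than by name.

>>> import pandas as pd
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
...     ("id", "v"))
>>> def max_v(pdf):
...     # The returned pandas.DataFrame has integer column labels
...     # (0 and 1), so they are matched to the schema fields
...     # "id" and "v_max" by position, not by name.
...     return pd.DataFrame([(pdf.id.iloc[0], pdf.v.max())])
...
>>> df.groupby("id").applyInPandas(
...     max_v, schema="id long, v_max double").show()
+---+-----+
| id|v_max|
+---+-----+
|  1|  2.0|
|  2| 10.0|
+---+-----+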

New in version 3.0.0.

Changed in version 3.4.0: Support Spark Connect.

Changed in version 4.1.0: Added support for the iterator of pandas.DataFrame API.

Parameters
func : function

a Python native function that either takes a pandas.DataFrame and outputs a pandas.DataFrame or takes an iterator of pandas.DataFrame and yields pandas.DataFrame. Additionally, each form can take a tuple of grouping keys as the first argument, with the pandas.DataFrame or iterator of pandas.DataFrame as the second argument.

schema : pyspark.sql.types.DataType or str

the return type of the func in PySpark. The value can be either a pyspark.sql.types.DataType object or a DDL-formatted type string.
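
For instance, the DDL string "id long, v double" used in the examples below is equivalent to the following StructType (a minimal sketch; either form can be passed as schema):

>>> from pyspark.sql.types import StructType, StructField, LongType, DoubleType
>>> schema = StructType([
...     StructField("id", LongType()),
...     StructField("v", DoubleType()),
... ])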

Notes

This function requires a full shuffle. If using the pandas.DataFrame API, all data of a group will be loaded into memory, so the user should be aware of the potential OOM risk if data is skewed and certain groups are too large to fit in memory. The iterator of pandas.DataFrame API can be used to mitigate this.

Examples

>>> import pandas as pd  
>>> from pyspark.sql.functions import ceil
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
...     ("id", "v"))  
>>> def normalize(pdf):
...     v = pdf.v
...     return pdf.assign(v=(v - v.mean()) / v.std())
...
>>> df.groupby("id").applyInPandas(
...     normalize, schema="id long, v double").show()  
+---+-------------------+
| id|                  v|
+---+-------------------+
|  1|-0.7071067811865475|
|  1| 0.7071067811865475|
|  2|-0.8320502943378437|
|  2|-0.2773500981126146|
|  2| 1.1094003924504583|
+---+-------------------+

Alternatively, the user can pass a function that takes two arguments. In this case, the grouping key(s) will be passed as the first argument and the data will be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy data types, e.g., numpy.int32 and numpy.float64. The data will still be passed in as a pandas.DataFrame containing all columns from the original Spark DataFrame. This is useful when the user does not want to hardcode grouping key(s) in the function.

>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
...     ("id", "v"))  
>>> def mean_func(key, pdf):
...     # key is a tuple of one numpy.int64, which is the value
...     # of 'id' for the current group
...     return pd.DataFrame([key + (pdf.v.mean(),)])
...
>>> df.groupby('id').applyInPandas(
...     mean_func, schema="id long, v double").show()  
+---+---+
| id|  v|
+---+---+
|  1|1.5|
|  2|6.0|
+---+---+
>>> def sum_func(key, pdf):
...     # key is a tuple of two numpy.int64s, which is the values
...     # of 'id' and 'ceil(df.v / 2)' for the current group
...     return pd.DataFrame([key + (pdf.v.sum(),)])
...
>>> df.groupby(df.id, ceil(df.v / 2)).applyInPandas(
...     sum_func, schema="id long, `ceil(v / 2)` long, v double").show()  
+---+-----------+----+
| id|ceil(v / 2)|   v|
+---+-----------+----+
|  2|          5|10.0|
|  1|          1| 3.0|
|  2|          3| 5.0|
|  2|          2| 3.0|
+---+-----------+----+

The function can also take and return an iterator of pandas.DataFrame using type hints.

>>> from typing import Iterator  
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
...     ("id", "v"))  
>>> def filter_func(
...     batches: Iterator[pd.DataFrame]
... ) -> Iterator[pd.DataFrame]:  
...     for batch in batches:
...         # Process and yield each batch independently
...         filtered = batch[batch['v'] > 2.0]
...         if not filtered.empty:
...             yield filtered[['v']]
...
>>> df.groupby("id").applyInPandas(
...     filter_func, schema="v double").show()  
+----+
|   v|
+----+
| 3.0|
| 5.0|
|10.0|
+----+

Alternatively, the user can pass a function that takes two arguments. In this case, the grouping key(s) will be passed as the first argument and the data will be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy data types. The data will still be passed in as an iterator of pandas.DataFrame.

>>> from typing import Iterator, Tuple, Any  
>>> def transform_func(
...     key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
... ) -> Iterator[pd.DataFrame]:  
...     for batch in batches:
...         # Yield transformed results for each batch
...         result = batch.assign(id=key[0], v_doubled=batch['v'] * 2)
...         yield result[['id', 'v_doubled']]
...
>>> df.groupby("id").applyInPandas(
...     transform_func, schema="id long, v_doubled double").show()  
+---+----------+
| id| v_doubled|
+---+----------+
|  1|       2.0|
|  1|       4.0|
|  2|       6.0|
|  2|      10.0|
|  2|      20.0|
+---+----------+