

Python user-defined table functions (UDTFs)

Important

This feature is in Public Preview in Databricks Runtime 14.3 LTS and above.

A user-defined table function (UDTF) allows you to register functions that return tables instead of scalar values. Unlike scalar functions that return a single result value from each call, each UDTF is invoked in a SQL statement's FROM clause and returns an entire table as output.

Each UDTF call can accept zero or more arguments. These arguments can be scalar expressions or table arguments representing entire input tables.

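UDTFs can be registered in two ways:

  • Unity Catalog: Register the UDTF as a governed object in Unity Catalog.
  • Session-scoped: Register the UDTF to the local SparkSession, isolated to the current notebook or job.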

Tip

Databricks recommends registering UDTFs in Unity Catalog to take advantage of centralized governance that makes it easier to securely share and reuse functions across users and teams.

Basic UDTF syntax

Apache Spark implements Python UDTFs as Python classes with a mandatory eval method that uses yield to emit output rows.

To use your class as a UDTF, you must import the PySpark udtf function. Databricks recommends using this function as a decorator and explicitly specifying field names and types using the returnType option (unless the class defines an analyze method as described in a later section).

The following UDTF creates a table using a fixed list of two integer arguments:

from pyspark.sql.functions import lit, udtf

@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
    def eval(self, x: int, y: int):
        yield x + y, x - y

GetSumDiff(lit(1), lit(2)).show()
+----+-----+
| sum| diff|
+----+-----+
|   3|   -1|
+----+-----+
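The class that @udtf decorates is an ordinary Python class, so a handler's eval logic can be exercised in plain Python before it is wrapped for Spark. This minimal sketch uses a hypothetical undecorated handler named GetSumDiffHandler and collects the rows that eval yields for a single call with no table argument.

```python
class GetSumDiffHandler:
    """Plain-Python handler; wrap it with pyspark's udtf() to use it in Spark."""

    def eval(self, x: int, y: int):
        # Each yielded tuple becomes one output row: (sum, diff).
        yield x + y, x - y


# Mimic one UDTF call: create an instance, call eval once, collect the rows.
rows = list(GetSumDiffHandler().eval(1, 2))
print(rows)  # [(3, -1)]
```

The udtf function can also be called directly instead of being used as a decorator: udtf(GetSumDiffHandler, returnType="sum: int, diff: int") produces the same UDTF as the decorated GetSumDiff class above.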

Register a UDTF

To register a session-scoped UDTF for use in SQL queries, use spark.udtf.register(). Provide a name for the SQL function and the Python UDTF class.

spark.udtf.register("get_sum_diff", GetSumDiff)

Call a registered UDTF

Once registered, you can use the UDTF in SQL with either the %sql magic command or the spark.sql() function:

spark.udtf.register("get_sum_diff", GetSumDiff)
spark.sql("SELECT * FROM get_sum_diff(1,2);").show()
%sql
SELECT * FROM get_sum_diff(1,2);

Upgrade a session-scoped UDTF to Unity Catalog

Important

Registering Python UDTFs in Unity Catalog is in Public Preview. Unity Catalog UDTFs require Databricks Runtime version 17.1 and above. See Requirements.

You can upgrade a session-scoped UDTF to Unity Catalog to take advantage of centralized governance and make it easier to securely share and reuse functions across users and teams.

To upgrade a session-scoped UDTF to Unity Catalog, use SQL DDL with the CREATE OR REPLACE FUNCTION statement. The following example shows how to convert the GetSumDiff UDTF from a session-scoped function to a Unity Catalog function:

CREATE OR REPLACE FUNCTION get_sum_diff(x INT, y INT)
RETURNS TABLE (sum INT, diff INT)
LANGUAGE PYTHON
HANDLER 'GetSumDiff'
AS $$
class GetSumDiff:
    def eval(self, x: int, y: int):
        yield x + y, x - y
$$;

SELECT * FROM get_sum_diff(10, 3);
+-----+------+
| sum | diff |
+-----+------+
| 13  | 7    |
+-----+------+

For more information about Unity Catalog UDTFs, see Python user-defined table functions (UDTFs) in Unity Catalog.

Use Apache Arrow

If your UDTF receives a small amount of data as input but outputs a large table, Databricks recommends using Apache Arrow. You can enable it by specifying the useArrow parameter when declaring the UDTF:

@udtf(returnType="c1: int, c2: int", useArrow=True)
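A sketch of the small-input, large-output shape that this recommendation targets, assuming a hypothetical handler named SquaresUpTo: it receives one integer and emits one row per value up to that integer. The handler itself is plain Python; Arrow is enabled only when wrapping it, for example with udtf(SquaresUpTo, returnType="value: int, square: int", useArrow=True).

```python
class SquaresUpTo:
    """Hypothetical handler: one small input value fans out into many rows."""

    def eval(self, n: int):
        # Emit one (value, value * value) row for every value from 1 through n.
        for i in range(1, n + 1):
            yield i, i * i


rows = list(SquaresUpTo().eval(4))
print(rows)  # [(1, 1), (2, 4), (3, 9), (4, 16)]
```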

Variable argument lists - *args and **kwargs

You can use Python *args or **kwargs syntax and implement logic to handle an unspecified number of input values.

The following example returns the same result while explicitly checking the input length and types for the arguments:

from pyspark.sql.functions import lit, udtf

@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
    def eval(self, *args):
        # Expect exactly two integer arguments.
        assert len(args) == 2
        assert all(isinstance(arg, int) for arg in args)
        x, y = args
        yield x + y, x - y

GetSumDiff(lit(1), lit(2)).show()

Here is the same example, but using keyword arguments:

@udtf(returnType="sum: int, diff: int")
class GetSumDiff:
    def eval(self, **kwargs):
        x = kwargs["x"]
        y = kwargs["y"]
        yield x + y, x - y

GetSumDiff(x=lit(1), y=lit(2)).show()
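The *args form also suits UDTFs that genuinely take a variable number of values. The hypothetical SumAll handler below accepts any number of integer arguments and yields their count and total. It raises a TypeError instead of using assert, because Python skips assert statements when it runs with the -O flag.

```python
class SumAll:
    """Hypothetical handler that accepts any number of integer arguments."""

    def eval(self, *args):
        # Validate explicitly; assert statements are stripped under python -O.
        for arg in args:
            if not isinstance(arg, int):
                raise TypeError(f"expected int arguments, got {type(arg).__name__}")
        # One output row: how many values were passed, and their sum.
        yield len(args), sum(args)


print(list(SumAll().eval(1, 2, 3)))  # [(3, 6)]
print(list(SumAll().eval()))         # [(0, 0)]
```

Wrapped with udtf(SumAll, returnType="n: int, total: int"), the same class could be called with any number of column arguments.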

Define a static schema at registration time

The UDTF returns rows with an output schema comprising an ordered sequence of column names and types. If the UDTF schema should always remain the same for all queries, you can specify a static, fixed schema with the returnType argument of the @udtf decorator. The schema must be either a StructType:

StructType().add("c1", StringType())

Or a DDL string representing a struct type:

c1: string

Compute a dynamic schema at function call time

UDTFs can also compute the output schema programmatically for each call depending on the values of the input arguments. To do this, define a static method called analyze that accepts zero or more parameters that correspond to the arguments provided to the specific UDTF call.

Each argument of the analyze method is an instance of the AnalyzeArgument class, which contains the following fields:

AnalyzeArgument field | Description
dataType | The type of the input argument as a DataType. For input table arguments, this is a StructType representing the table's columns.
value | The value of the input argument as an Optional[Any]. This is None for table arguments and for scalar arguments that are not constant.
isTable | Whether the input argument is a table, as a bool.
isConstantExpression | Whether the input argument is a constant-foldable expression, as a bool.

The analyze method returns an instance of the AnalyzeResult class, which includes the result table's schema as a StructType plus some optional fields. If the UDTF accepts an input table argument, then the AnalyzeResult can also include a requested way to partition and order the rows of the input table across several UDTF calls, as described later.

AnalyzeResult field | Description
schema | The schema of the result table as a StructType.
withSinglePartition | Whether to send all input rows to the same UDTF class instance, as a bool.
partitionBy | A sequence of PartitioningColumn values. If non-empty, all rows with each unique combination of values of the partitioning expressions are consumed by a separate instance of the UDTF class.
orderBy | A sequence of OrderingColumn values. If non-empty, this specifies an ordering of rows within each partition.
select | A sequence of SelectedColumn values. If non-empty, these are expressions that the UDTF asks Catalyst to evaluate against the columns in the input TABLE argument. The UDTF receives one input attribute for each name in the list, in the order listed.

This analyze example returns one output column for each distinct word in the input string argument.

from pyspark.sql.functions import lit, udtf
from pyspark.sql.types import StructType, IntegerType
from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult


@udtf
class MyUDTF:
  @staticmethod
  def analyze(text: AnalyzeArgument) -> AnalyzeResult:
    schema = StructType()
    for index, word in enumerate(sorted(list(set(text.value.split(" "))))):
      schema = schema.add(f"word_{index}", IntegerType())
    return AnalyzeResult(schema=schema)

  def eval(self, text: str):
    counts = {}
    for word in text.split(" "):
      if word not in counts:
        counts[word] = 0
      counts[word] += 1
    result = []
    for word in sorted(list(set(text.split(" ")))):
      result.append(counts[word])
    yield result

MyUDTF(lit("hello world")).columns
['word_0', 'word_1']

Forward state to future eval calls

The analyze method can serve as a convenient place to perform initialization and then forward the results to future eval method invocations for the same UDTF call.

To do so, create a subclass of AnalyzeResult and return an instance of the subclass from the analyze method. Then, add an additional argument to the __init__ method to accept that instance.

This analyze example returns a constant output schema, but adds custom information in the result metadata to be consumed by future __init__ method calls:

from dataclasses import dataclass

from pyspark.sql.functions import lit, udtf
from pyspark.sql.types import IntegerType, Row, StringType, StructType
from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult

@dataclass
class AnalyzeResultWithBuffer(AnalyzeResult):
    buffer: str = ""

@udtf
class TestUDTF:
  def __init__(self, analyze_result=None):
    self._total = 0
    if analyze_result is not None:
      self._buffer = analyze_result.buffer
    else:
      self._buffer = ""

  @staticmethod
  def analyze(argument, _) -> AnalyzeResult:
    if (
      argument.value is None
      or argument.isTable
      or not isinstance(argument.value, str)
      or len(argument.value) == 0
    ):
      raise Exception("The first argument must be a non-empty string")
    assert argument.dataType == StringType()
    assert not argument.isTable
    return AnalyzeResultWithBuffer(
      schema=StructType()
        .add("total", IntegerType())
        .add("buffer", StringType()),
      withSinglePartition=True,
      buffer=argument.value,
    )

  def eval(self, argument, row: Row):
    self._total += 1

  def terminate(self):
    yield self._total, self._buffer

spark.udtf.register("test_udtf", TestUDTF)

spark.sql(
  """
  WITH t AS (
    SELECT id FROM range(1, 21)
  )
  SELECT total, buffer
  FROM test_udtf("abc", TABLE(t))
  """
).show()
+-----+------+
|total|buffer|
+-----+------+
|   20|   abc|
+-----+------+

Yield output rows

The eval method runs once for each row of the input table argument (or just once if no table argument is provided), followed by one invocation of the terminate method at the end. Either method outputs zero or more rows that conform to the result schema by yielding tuples, lists, or pyspark.sql.Row objects.

This example returns a row by providing a tuple of three elements:

def eval(self, x, y, z):
  yield (x, y, z)

You can also omit the parentheses:

def eval(self, x, y, z):
  yield x, y, z

Add a trailing comma to return a row with only one column:

def eval(self, x, y, z):
  yield x,

You can also yield a pyspark.sql.Row object.

def eval(self, x, y, z):
  from pyspark.sql.types import Row
  yield Row(x, y, z)

This example yields output rows from the terminate method using a Python list. You can store state inside the class from earlier steps in the UDTF evaluation for this purpose.

def terminate(self):
  yield [self.x, self.y, self.z]
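To see how eval and terminate combine, the sketch below drives a hypothetical WordLengths handler in the order described above: one eval call per input row, then a single terminate call. eval yields one row per word, and terminate yields a final row built from state kept on the instance. The run_handler helper is a local stand-in for Spark, not a PySpark API.

```python
class WordLengths:
    """Hypothetical handler with result schema "word: string, length: int"."""

    def __init__(self):
        self.total_chars = 0

    def eval(self, text: str):
        # Zero or more rows per input: one row for each word in the text.
        for word in text.split():
            self.total_chars += len(word)
            yield word, len(word)

    def terminate(self):
        # One final row, built from state accumulated across all eval calls.
        yield "<total>", self.total_chars


def run_handler(handler_cls, inputs):
    """Local stand-in for Spark: eval once per input row, then terminate once."""
    handler = handler_cls()
    rows = []
    for args in inputs:
        rows.extend(handler.eval(*args))
    rows.extend(handler.terminate())
    return rows


print(run_handler(WordLengths, [("hello world",), ("hi",)]))
# [('hello', 5), ('world', 5), ('hi', 2), ('<total>', 12)]
```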

Pass scalar arguments to a UDTF

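You can pass scalar arguments to a UDTF as constant expressions comprising literal values or functions based on them. Arguments can be passed by position or by name. For example, the following query passes the first argument by position and the second by name: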

SELECT * FROM get_sum_diff(1, y => 2)
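Named arguments are matched to the parameter names of the eval method. In plain Python terms, the binding looks like the calls in this sketch, which reuses the eval logic of GetSumDiff in a hypothetical undecorated class named GetSumDiffHandler.

```python
class GetSumDiffHandler:
    def eval(self, x: int, y: int):
        yield x + y, x - y


# First value bound by position (x), second bound by name (y).
print(list(GetSumDiffHandler().eval(1, y=2)))    # [(3, -1)]
# Both bound by name; keyword order does not matter in Python.
print(list(GetSumDiffHandler().eval(y=2, x=1)))  # [(3, -1)]
```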

Pass table arguments to a UDTF

Python UDTFs can accept an input table as an argument in addition to scalar input arguments. A single UDTF can also accept a table argument and multiple scalar arguments.

Any SQL query can then provide an input table using the TABLE keyword followed by parentheses surrounding a table identifier, like TABLE(t). Alternatively, you can pass a table subquery, like TABLE(SELECT a, b, c FROM t) or TABLE(SELECT t1.a, t2.b FROM t1 INNER JOIN t2 USING (key)).

The input table argument is then represented as a pyspark.sql.Row argument to the eval method, with one call to the eval method for each row in the input table. You can access the columns of each row by name, for example row["id"]. The following example explicitly imports the PySpark Row type and then filters the passed table on the id field:

from pyspark.sql.functions import udtf
from pyspark.sql.types import Row

@udtf(returnType="id: int")
class FilterUDTF:
    def eval(self, row: Row):
        if row["id"] > 5:
            yield row["id"],

spark.udtf.register("filter_udtf", FilterUDTF)

To query the function, use the TABLE SQL keyword:

SELECT * FROM filter_udtf(TABLE(SELECT * FROM range(10)));
+---+
| id|
+---+
|  6|
|  7|
|  8|
|  9|
+---+
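Because eval receives one row per call, per-row logic can be checked without Spark by feeding it mapping objects that support the same row["column"] lookups. This sketch assumes a hypothetical SplitWords handler over an input table with id and text columns, and uses plain dicts as stand-ins for pyspark.sql.Row objects.

```python
class SplitWords:
    """Hypothetical handler with result schema "id: int, word: string"."""

    def eval(self, row):
        # row["text"] works on a pyspark.sql.Row and on the dict stand-ins below.
        for word in row["text"].split():
            yield row["id"], word


# Plain dicts stand in for the rows of TABLE(...); Spark would pass Row objects.
table = [{"id": 1, "text": "a b"}, {"id": 2, "text": "c"}]
handler = SplitWords()
rows = [out for row in table for out in handler.eval(row)]
print(rows)  # [(1, 'a'), (1, 'b'), (2, 'c')]
```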

Specify a partitioning of the input rows from function calls

When calling a UDTF with a table argument, any SQL query can partition the input table across several UDTF calls based on the values of one or more input table columns.

To specify a partition, use the PARTITION BY clause in the function call after the TABLE argument. This guarantees that all input rows with each unique combination of values of the partitioning columns are consumed by exactly one instance of the UDTF class.

Note that in addition to simple column references, the PARTITION BY clause also accepts arbitrary expressions based on input table columns. For example, you can specify the LENGTH of a string, extract a month from a date, or concatenate two values.

It is also possible to specify WITH SINGLE PARTITION instead of PARTITION BY to request only one partition wherein all input rows must be consumed by exactly one instance of the UDTF class.

Within each partition, you can optionally specify a required ordering of the input rows as the UDTF's eval method consumes them. To do so, provide an ORDER BY clause after the PARTITION BY or WITH SINGLE PARTITION clause described above.

For example, consider the following UDTF:

from pyspark.sql.functions import udtf
from pyspark.sql.types import Row

@udtf(returnType="a: string, b: int")
class FilterUDTF:
  def __init__(self):
    self.key = ""
    self.max = 0

  def eval(self, row: Row):
    self.key = row["a"]
    self.max = max(self.max, row["b"])

  def terminate(self):
    yield self.key, self.max

spark.udtf.register("filter_udtf", FilterUDTF)

You can specify partitioning options when calling the UDTF over the input table in multiple ways:

-- Create an input table with some example values.
DROP TABLE IF EXISTS values_table;
CREATE TABLE values_table (a STRING, b INT);
INSERT INTO values_table VALUES ('abc', 2), ('abc', 4), ('def', 6), ('def', 8);
SELECT * FROM values_table;
+-------+----+
|     a |  b |
+-------+----+
| "abc" | 2  |
| "abc" | 4  |
| "def" | 6  |
| "def" | 8  |
+-------+----+

-- Query the UDTF with the input table as an argument and a directive to partition the input
-- rows such that all rows with each unique value in the `a` column are processed by the same
-- instance of the UDTF class. Within each partition, the rows are ordered by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY a ORDER BY b) ORDER BY 1;
+-------+----+
|     a |  b |
+-------+----+
| "abc" | 4  |
| "def" | 8  |
+-------+----+

-- Query the UDTF with the input table as an argument and a directive to partition the input
-- rows such that all rows with each unique result of evaluating the "LENGTH(a)" expression are
-- processed by the same instance of the UDTF class. Within each partition, the rows are ordered
-- by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY LENGTH(a) ORDER BY b) ORDER BY 1;
+-------+---+
|     a | b |
+-------+---+
| "def" | 8 |
+-------+---+

-- Query the UDTF with the input table as an argument and a directive to consider all the input
-- rows in one single partition such that exactly one instance of the UDTF class consumes all of
-- the input rows. Within the partition, the rows are ordered by the `b` column.
SELECT * FROM filter_udtf(TABLE(values_table) WITH SINGLE PARTITION ORDER BY b) ORDER BY 1;
+-------+---+
|     a | b |
+-------+---+
| "def" | 8 |
+-------+---+
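The partitioning semantics can be sketched in plain Python: group the input rows by the partitioning expression, sort each group by the ordering column, and give every group its own handler instance. The run_partitioned helper below is a hypothetical local stand-in for Spark (not a PySpark API), and MaxPerKey repeats the logic of the FilterUDTF class above without the decorator. The helper sorts partitions by key only to keep its output deterministic; Spark makes no such promise across partitions, which is why the queries above end with ORDER BY 1.

```python
from itertools import groupby
from operator import itemgetter


class MaxPerKey:
    """Same logic as FilterUDTF above, without the @udtf decorator."""

    def __init__(self):
        self.key = ""
        self.max = 0

    def eval(self, row):
        self.key = row["a"]
        self.max = max(self.max, row["b"])

    def terminate(self):
        yield self.key, self.max


def run_partitioned(handler_cls, rows, partition_by, order_by):
    """Local stand-in: one fresh handler instance per partition, rows sorted by order_by."""
    output = []
    # groupby only merges adjacent items, so sort by the partitioning key first.
    for _, group in groupby(sorted(rows, key=partition_by), key=partition_by):
        handler = handler_cls()  # new instance, so no state leaks between partitions
        for row in sorted(group, key=order_by):
            result = handler.eval(row)
            if result is not None:  # eval may yield rows or, as here, return None
                output.extend(result)
        output.extend(handler.terminate())
    return output


values_table = [
    {"a": "abc", "b": 2}, {"a": "abc", "b": 4},
    {"a": "def", "b": 6}, {"a": "def", "b": 8},
]

# PARTITION BY a ORDER BY b
print(run_partitioned(MaxPerKey, values_table, itemgetter("a"), itemgetter("b")))
# [('abc', 4), ('def', 8)]

# PARTITION BY LENGTH(a) ORDER BY b: both keys have length 3, so there is one partition.
print(run_partitioned(MaxPerKey, values_table, lambda r: len(r["a"]), itemgetter("b")))
# [('def', 8)]
```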

Specify a partitioning of the input rows from the analyze method

Note that for each of the above ways of partitioning the input table when calling UDTFs in SQL queries, there is a corresponding way for the UDTF's analyze method to specify the same partitioning method automatically instead.

  • Instead of calling a UDTF as SELECT * FROM udtf(TABLE(t) PARTITION BY a), you can update the analyze method to set the field partitionBy=[PartitioningColumn("a")] and simply call the function using SELECT * FROM udtf(TABLE(t)).
  • By the same token, instead of specifying TABLE(t) WITH SINGLE PARTITION ORDER BY b in the SQL query, you can make analyze set the fields withSinglePartition=True and orderBy=[OrderingColumn("b")] and then just pass TABLE(t).
  • Instead of passing TABLE(SELECT a FROM t) in the SQL query, you can make analyze set select=[SelectedColumn("a")] and then just pass TABLE(t).

In the following example, analyze returns a constant output schema, selects the date column and the length of the word column from the input table, and specifies that the input table is partitioned across several UDTF calls based on the month of the date column, with rows ordered by date within each partition:

from pyspark.sql.types import DateType, IntegerType, StructType
from pyspark.sql.udtf import (
  AnalyzeResult,
  OrderingColumn,
  PartitioningColumn,
  SelectedColumn,
)

# Define this static method inside the UDTF class.
@staticmethod
def analyze(*args) -> AnalyzeResult:
  """
  The input table will be partitioned across several UDTF calls based on the month
  of each row's `date` value. The rows within each partition will arrive ordered by
  the `date` column. The UDTF will only receive the `date` column and the length of
  the `word` column (aliased as `length_word`) from the input table.
  """
  assert len(args) == 1, "This function accepts one argument only"
  assert args[0].isTable, "Only table arguments are supported"
  return AnalyzeResult(
    schema=StructType()
      .add("month", DateType())
      .add("longest_word", IntegerType()),
    partitionBy=[
      PartitioningColumn("extract(month from date)")],
    orderBy=[
      OrderingColumn("date")],
    select=[
      SelectedColumn("date"),
      SelectedColumn(
        name="length(word)",
        alias="length_word")])