Source code for sparkly.functions

#
# Copyright 2017 Tubular Labs, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from collections import defaultdict
from functools import reduce

from pyspark.sql import Column
from pyspark.sql import functions as F


def multijoin(dfs, on=None, how=None, coalesce=None):
    """Join multiple dataframes.

    Args:
        dfs (list[pyspark.sql.DataFrame]): dataframes to join.
        on: same as ``pyspark.sql.DataFrame.join``.
        how: same as ``pyspark.sql.DataFrame.join``.
        coalesce (list[str]): column names to disambiguate by coalescing
            across the input dataframes. A column must be of the same type
            across all dataframes that define it; if different types appear,
            coalesce makes a best-effort attempt to merge them. The selected
            value is the first non-null one in order of appearance of the
            dataframes in the input list. Default is None - don't coalesce
            any ambiguous columns.

    Returns:
        pyspark.sql.DataFrame, or None if the provided dataframe list is
        empty.

    Example:
        Assume we have two DataFrames, the first is
        ``first = [{'id': 1, 'value': None}, {'id': 2, 'value': 2}]``
        and the second is
        ``second = [{'id': 1, 'value': 1}, {'id': 2, 'value': 22}]``.
        Then collecting the ``DataFrame`` produced by
        ``multijoin([first, second], on='id', how='inner', coalesce=['value'])``
        yields ``[{'id': 1, 'value': 1}, {'id': 2, 'value': 2}]``.
    """
    if not dfs:
        return None

    # Go over the input dataframes and rename each to-be-resolved
    # column to ensure name uniqueness
    coalesce = set(coalesce or [])
    renamed_columns = defaultdict(list)
    for idx, df in enumerate(dfs):
        for col in df.columns:
            if col in coalesce:
                disambiguation = '__{}_{}'.format(idx, col)
                df = df.withColumnRenamed(col, disambiguation)
                renamed_columns[col].append(disambiguation)
        dfs[idx] = df

    # Join the dataframes
    joined_df = reduce(lambda x, y: x.join(y, on=on, how=how), dfs)

    # And coalesce the would-have-been-ambiguities
    for col, disambiguations in renamed_columns.items():
        joined_df = joined_df.withColumn(col, F.coalesce(*disambiguations))
        for disambiguation in disambiguations:
            joined_df = joined_df.drop(disambiguation)

    return joined_df
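
# Illustrative usage for multijoin: a minimal sketch, not part of the
# library. It assumes an active SparkSession bound to ``spark``; the data
# mirrors the docstring example, and row order in ``show()`` may vary.
first = spark.createDataFrame([(1, None), (2, 2)], ['id', 'value'])
second = spark.createDataFrame([(1, 1), (2, 22)], ['id', 'value'])

# Inner-join on 'id' and resolve the ambiguous 'value' column by taking
# the first non-null value in dataframe order.
multijoin([first, second], on='id', how='inner', coalesce=['value']).show()
# +---+-----+
# | id|value|
# +---+-----+
# |  1|    1|
# |  2|    2|
# +---+-----+
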
def switch_case(switch, case=None, default=None, **additional_cases):
    """Switch/case style column generation.

    Args:
        switch (str, pyspark.sql.Column): column to "switch" on; its values
            are compared against the defined cases.
        case (dict): case statements. When a key matches the value of the
            column in a specific row, the respective value will be assigned
            to the new column for that row. This is useful when your case
            condition constants are not strings.
        default: default value to be used when the value of the switch
            column doesn't match any keys.
        additional_cases: additional "case" statements, kwargs style. Same
            semantics as ``case`` above. If both define the same key,
            ``case`` takes precedence.

    Returns:
        pyspark.sql.Column

    Example:
        ``switch_case('state', CA='California', NY='New York', default='Other')``
        is equivalent to

        >>> F.when(
        ...     F.col('state') == 'CA', 'California'
        ... ).when(
        ...     F.col('state') == 'NY', 'New York'
        ... ).otherwise('Other')
    """
    if not isinstance(switch, Column):
        switch = F.col(switch)

    def _column_or_lit(x):
        return F.lit(x) if not isinstance(x, Column) else x

    def _execute_case(accumulator, case):
        # transform the case to a pyspark.sql.functions.when statement,
        # then chain it to existing when statements
        condition_constant, assigned_value = case
        when_args = (switch == F.lit(condition_constant),
                     _column_or_lit(assigned_value))
        return accumulator.when(*when_args)

    cases = case or {}
    # on key conflicts, ``case`` wins over the kwargs-style cases
    for conflict in set(cases.keys()) & set(additional_cases.keys()):
        del additional_cases[conflict]
    cases = list(cases.items()) + list(additional_cases.items())

    default = _column_or_lit(default)
    if not cases:
        return default

    result = reduce(_execute_case, cases, F).otherwise(default)
    return result
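
# Illustrative usage for switch_case: a minimal sketch, not part of the
# library, reusing the hypothetical SparkSession ``spark`` from the sketch
# above.
states = spark.createDataFrame([('CA',), ('NY',), ('TX',)], ['state'])
states = states.withColumn(
    'state_name',
    switch_case('state', CA='California', NY='New York', default='Other'),
)
# -> 'California', 'New York', 'Other'

# The ``case`` dict form covers condition constants that are not valid
# Python identifiers, e.g. integers (hypothetical data again).
codes = spark.createDataFrame([(0,), (1,), (7,)], ['code'])
codes = codes.withColumn(
    'label',
    switch_case('code', case={0: 'zero', 1: 'one'}, default='many'),
)
# -> 'zero', 'one', 'many'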