#
# Copyright 2017 Tubular Labs, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
from sparkly.catalog import SparklyCatalog
from sparkly.reader import SparklyReader
from sparkly.writer import attach_writer_to_dataframe
class SparklySession(SparkSession):
    """Wrapper around SparkSession to simplify definition of options, packages, JARs and UDFs.

    The session is created with ``spark.sql.catalogImplementation`` set to ``hive``;
    this can be overridden through ``options`` or ``additional_options``.

    Example::

        from pyspark.sql.types import IntegerType
        import sparkly


        class MySession(sparkly.SparklySession):
            options = {'spark.sql.shuffle.partitions': '2000'}
            repositories = ['http://packages.confluent.io/maven/']
            packages = ['com.databricks:spark-csv_2.10:1.4.0']
            jars = ['../path/to/brickhouse-0.7.1.jar']
            udfs = {
                'collect_max': 'brickhouse.udf.collect.CollectMaxUDAF',
                'my_python_udf': (lambda x: len(x), IntegerType()),
            }


        spark = MySession()
        spark.read_ext.cassandra(...)

    Attributes:
        options (dict[str,str]): Configuration options that are passed to SparkConf.
            See `the list of possible options
            <https://spark.apache.org/docs/2.1.0/configuration.html#available-properties>`_.
        repositories (list[str]): List of additional maven repositories for package lookup.
        packages (list[str]): Spark packages that should be installed.
            See https://spark-packages.org/
        jars (list[str]): Full paths to jar files that we want to include to the session.
            E.g. a JDBC connector or a library with UDF functions.
        udfs (dict[str,str|tuple]): Register UDF functions within the session.
            Key - a name of the function,
            Value - either a class name imported from a JAR file
            or a tuple with python function and its return type.
    """
    # Class-level defaults are only read (never mutated), so sharing them
    # between subclasses is safe; subclasses override them by reassignment.
    options = {}
    packages = []
    jars = []
    udfs = {}
    repositories = []

    def __init__(self, additional_options=None):
        """Build the SparkContext/SparkSession from the class-level configuration.

        Side effects: sets the ``PYSPARK_PYTHON`` and ``PYSPARK_SUBMIT_ARGS``
        environment variables, creates a new ``SparkContext``, registers ``udfs``
        and attaches the sparkly writer to ``DataFrame``.

        Args:
            additional_options (dict[str,str]|None): Extra SparkConf options for this
                instance. A key also present in ``options`` takes this value.
        """
        # Make executors use the same interpreter as the driver.
        os.environ['PYSPARK_PYTHON'] = sys.executable

        # PySpark reads PYSPARK_SUBMIT_ARGS when it launches the JVM, so this must be
        # set before the SparkContext is created. Empty fragments are dropped.
        submit_args = [
            self._setup_repositories(),
            self._setup_packages(),
            self._setup_jars(),
            'pyspark-shell',
        ]
        os.environ['PYSPARK_SUBMIT_ARGS'] = ' '.join(filter(None, submit_args))

        # Init SparkContext. The hive default is set first so user options can override it.
        spark_conf = SparkConf()
        spark_conf.set('spark.sql.catalogImplementation', 'hive')
        spark_conf.setAll(self._setup_options(additional_options))
        spark_context = SparkContext(conf=spark_conf)

        # Init the SparkSession on top of the configured context.
        super(SparklySession, self).__init__(spark_context)

        self._setup_udfs()

        self.read_ext = SparklyReader(self)
        self.catalog_ext = SparklyCatalog(self)

        attach_writer_to_dataframe()

    @property
    def builder(self):
        """Disabled: SparklySession is configured through class attributes instead."""
        raise NotImplementedError(
            'You do not need a builder for SparklySession. '
            'Just use a regular python constructor. '
            'Please, follow the documentation for more details.'
        )

    def has_package(self, package_prefix):
        """Check if the package is available in the session.

        Args:
            package_prefix (str): E.g. "org.elasticsearch:elasticsearch-spark".

        Returns:
            bool: True if any entry of ``packages`` starts with ``package_prefix``.
        """
        return any(package.startswith(package_prefix) for package in self.packages)

    def has_jar(self, jar_name):
        """Check if the jar is available in the session.

        Args:
            jar_name (str): E.g. "mysql-connector-java".

        Returns:
            bool: True if ``jar_name`` is a substring of any entry of ``jars``.
        """
        return any(jar_name in jar for jar in self.jars)

    def _setup_repositories(self):
        """Return the ``--repositories`` submit argument, or '' when none are configured."""
        if self.repositories:
            return '--repositories {}'.format(','.join(self.repositories))
        else:
            return ''

    def _setup_packages(self):
        """Return the ``--packages`` submit argument, or '' when none are configured."""
        if self.packages:
            return '--packages {}'.format(','.join(self.packages))
        else:
            return ''

    def _setup_jars(self):
        """Return the ``--jars`` submit argument, or '' when none are configured."""
        if self.jars:
            return '--jars {}'.format(','.join(self.jars))
        else:
            return ''

    def _setup_options(self, additional_options):
        """Merge class-level and per-instance Spark options.

        Args:
            additional_options (dict[str,str]|None): Options that take precedence over
                ``self.options`` when both define the same key.

        Returns:
            list[tuple[str,str]]: (key, value) pairs sorted by key, one pair per key.
        """
        # Merge through a dict so a key present in both sources keeps the
        # additional_options value. SparkConf.setAll applies pairs in order, so
        # leaving duplicates in a sorted list would let the lexicographically
        # larger value win instead of the per-instance override.
        options = dict(self.options)
        options.update(additional_options or {})
        return sorted(options.items())

    def _setup_udfs(self):
        """Register every entry of ``udfs`` in the session.

        String values are registered as temporary functions backed by a JVM class
        (via SQL); tuple values are passed to ``catalog.registerFunction`` as
        ``(function, return_type)``.

        Raises:
            NotImplementedError: If a UDF definition is neither a str nor a tuple.
        """
        for name, defn in self.udfs.items():
            if isinstance(defn, str):
                self.sql('create temporary function {} as "{}"'.format(name, defn))
            elif isinstance(defn, tuple):
                self.catalog.registerFunction(name, *defn)
            else:
                raise NotImplementedError('Incorrect UDF definition: {}: {}'.format(name, defn))