Source code for sparkly.session

#
# Copyright 2017 Tubular Labs, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import signal
import sys

from pyspark import SparkContext
from pyspark.sql import SparkSession

from sparkly.catalog import SparklyCatalog
from sparkly.instant_testing import InstantTesting
from sparkly.reader import SparklyReader
from sparkly.writer import attach_writer_to_dataframe


class SparklySession(SparkSession):
    """Wrapper around SparkSession to simplify definition of options, packages, JARs and UDFs.

    Example::

        from pyspark.sql.types import IntegerType
        import sparkly


        class MySession(sparkly.SparklySession):
            options = {'spark.sql.shuffle.partitions': '2000'}
            repositories = ['http://packages.confluent.io/maven/']
            packages = ['com.databricks:spark-csv_2.10:1.4.0']
            jars = ['../path/to/brickhouse-0.7.1.jar']
            udfs = {
                'collect_max': 'brickhouse.udf.collect.CollectMaxUDAF',
                'my_python_udf': (lambda x: len(x), IntegerType()),
            }


        spark = MySession()
        spark.read_ext.cassandra(...)

        # Alternatively
        spark = MySession.get_or_create()
        spark.read_ext.cassandra(...)

    Attributes:
        options (dict[str,str]): Configuration options that are passed to spark-submit.
            See `the list of possible options
            <https://spark.apache.org/docs/2.1.0/configuration.html#available-properties>`_.
            Note that any options already set through PYSPARK_SUBMIT_ARGS override these.
        repositories (list[str]): List of additional maven repositories for package lookup.
        packages (list[str]): Spark packages that should be installed.
            See https://spark-packages.org/
        jars (list[str]): Full paths to jar files that we want to include in the session.
            E.g. a JDBC connector or a library with UDF functions.
        udfs (dict[str,str|typing.Callable]): UDF functions to register within the session.
            Key is the name of the function; value is either a class name imported
            from a JAR file or a tuple of a python function and its return type.
    """
    options = {}
    packages = []
    jars = []
    udfs = {}
    repositories = []

    _instantiated_session = None

    def __init__(self, additional_options=None):
        os.environ['PYSPARK_PYTHON'] = sys.executable

        submit_args = [
            # Options that were already defined through PYSPARK_SUBMIT_ARGS
            # take precedence over SparklySession's.
            os.environ.get('PYSPARK_SUBMIT_ARGS', '').replace('pyspark-shell', ''),
            self._setup_repositories(),
            self._setup_packages(),
            self._setup_jars(),
            self._setup_options(additional_options),
            'pyspark-shell',
        ]
        os.environ['PYSPARK_SUBMIT_ARGS'] = ' '.join(filter(None, submit_args))

        # If we are in instant testing mode.
        if InstantTesting.is_activated():
            spark_context = InstantTesting.get_context()

            # It's the first run, so we have to create a context and daemonise the process.
            if spark_context is None:
                spark_context = SparkContext()
                if os.fork() == 0:  # Detached process.
                    signal.pause()
                else:
                    InstantTesting.set_context(spark_context)
        else:
            spark_context = SparkContext()

        # Init HiveContext.
        super(SparklySession, self).__init__(spark_context)
        self._setup_udfs()

        self.read_ext = SparklyReader(self)
        self.catalog_ext = SparklyCatalog(self)

        attach_writer_to_dataframe()

        SparklySession._instantiated_session = self
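    # Illustrative sketch (not part of the original source): for the MySession
    # example from the class docstring, __init__ above composes
    # PYSPARK_SUBMIT_ARGS roughly as follows (flags shown one per line here;
    # the actual value is a single space-joined string):
    #
    #     --repositories http://packages.confluent.io/maven/
    #     --packages com.databricks:spark-csv_2.10:1.4.0
    #     --jars ../path/to/brickhouse-0.7.1.jar
    #     --conf "spark.sql.catalogImplementation=hive"
    #     --conf "spark.sql.shuffle.partitions=2000"
    #     pyspark-shell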
    @classmethod
    def get_or_create(cls):
        """Access the instantiated sparkly session.

        If a sparkly session has already been instantiated, return that
        instance; if not, instantiate one and return it. Useful for lazy
        access to the session. Not thread-safe.

        Returns:
            SparklySession (or subclass).
        """
        if SparklySession._instantiated_session is None:
            cls()
        return SparklySession._instantiated_session
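    # Usage sketch (illustrative; MySession is the example subclass from the
    # class docstring): get_or_create() acts as a process-wide singleton accessor.
    #
    #     spark = MySession.get_or_create()  # instantiates MySession on first call
    #     again = MySession.get_or_create()  # returns the same instance
    #     assert spark is again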
    @classmethod
    def stop(cls):
        """Stop the instantiated sparkly session."""
        if SparklySession._instantiated_session is not None:
            SparkSession.stop(SparklySession._instantiated_session)
            SparklySession._instantiated_session = None
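    # Illustrative sketch: stop() clears the singleton, so the next
    # get_or_create() call builds a fresh session.
    #
    #     MySession.stop()                   # stops the underlying SparkSession
    #     spark = MySession.get_or_create()  # instantiates a new session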
    @property
    def builder(self):
        raise NotImplementedError(
            'You do not need a builder for SparklySession. '
            'Just use a regular python constructor. '
            'Please, follow the documentation for more details.'
        )
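    # Note (illustrative): unlike a vanilla SparkSession, accessing builder
    # on an instance raises:
    #
    #     MySession().builder  # -> NotImplementedError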
    def has_package(self, package_prefix):
        """Check if the package is available in the session.

        Args:
            package_prefix (str): E.g. "org.elasticsearch:elasticsearch-spark".

        Returns:
            bool
        """
        return any(package.startswith(package_prefix) for package in self.packages)
    def has_jar(self, jar_name):
        """Check if the jar is available in the session.

        Args:
            jar_name (str): E.g. "mysql-connector-java".

        Returns:
            bool
        """
        return any(jar_name in jar for jar in self.jars)
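    # Usage sketch (illustrative; values follow the MySession example from
    # the class docstring):
    #
    #     spark = MySession.get_or_create()
    #     spark.has_package('com.databricks:spark-csv')  # -> True
    #     spark.has_jar('brickhouse')                    # -> True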
    def _setup_repositories(self):
        if self.repositories:
            return '--repositories {}'.format(','.join(self.repositories))
        else:
            return ''

    def _setup_packages(self):
        if self.packages:
            return '--packages {}'.format(','.join(self.packages))
        else:
            return ''

    def _setup_jars(self):
        if self.jars:
            return '--jars {}'.format(','.join(self.jars))
        else:
            return ''

    def _setup_options(self, additional_options):
        options = {}
        options.update(self.options)
        if additional_options:
            options.update(additional_options)

        if 'spark.sql.catalogImplementation' not in options:
            options['spark.sql.catalogImplementation'] = 'hive'

        # Here we massage conf properties with the intent to pass them to
        # spark-submit; this is convenient as it is unified with the approach
        # we take for repos, packages and jars, and it also handles precedence
        # of conf properties already defined by the user in a very
        # straightforward way (since we always append to PYSPARK_SUBMIT_ARGS).
        return ' '.join('--conf "{}={}"'.format(*o) for o in sorted(options.items()))

    def _setup_udfs(self):
        for name, defn in self.udfs.items():
            if isinstance(defn, str):
                self.sql('create temporary function {} as "{}"'.format(name, defn))
            elif isinstance(defn, tuple):
                self.catalog.registerFunction(name, *defn)
            else:
                raise NotImplementedError('Incorrect UDF definition: {}: {}'.format(name, defn))
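# Illustrative sketch (not part of the original source): given the MySession
# example from the class docstring, _setup_udfs() registers 'collect_max' via
# a CREATE TEMPORARY FUNCTION statement and 'my_python_udf' via the catalog,
# after which both are callable from SQL. The table and column names below
# are hypothetical:
#
#     spark = MySession.get_or_create()
#     spark.sql('SELECT my_python_udf(some_column) FROM some_table')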