# Source code for sparkly.testing

# Copyright 2017 Tubular Labs, Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import contextlib
import copy
import difflib
from functools import partial, total_ordering
import importlib
import json
import logging
import math
import numbers
import operator
import os
import pprint
import shutil
import sys
import tempfile
from unittest import TestCase
from unittest.util import safe_repr
import warnings

from pyspark.sql import types as T
import six

from sparkly import SparklySession
from sparkly.exceptions import FixtureError
from sparkly.utils import kafka_get_topics_offsets

if sys.version_info.major == 3:
    from http.client import HTTPConnection
else:
    from httplib import HTTPConnection

# Optional client libraries used by the fixtures below. Each *_SUPPORT flag
# records whether the import succeeded; the Cassandra, MySQL and Kafka
# fixtures check their flag and raise a helpful error when it is False.
try:
    from cassandra.cluster import Cluster
    CASSANDRA_FIXTURES_SUPPORT = True
except ImportError:
    CASSANDRA_FIXTURES_SUPPORT = False

# Prefer PyMySQL, fall back to mysql-connector; both are bound to `connector`.
try:
    import pymysql as connector
    MYSQL_FIXTURES_SUPPORT = True
except ImportError:
    try:
        import mysql.connector as connector
        MYSQL_FIXTURES_SUPPORT = True
    except ImportError:
        MYSQL_FIXTURES_SUPPORT = False

# NOTE(review): `SimpleClient` was removed in newer kafka-python releases;
# with those, this import fails and Kafka fixtures are reported unsupported
# even though KafkaProducer exists — confirm the supported kafka-python range.
try:
    from kafka import KafkaProducer, SimpleClient
    KAFKA_FIXTURES_SUPPORT = True
except ImportError:
    KAFKA_FIXTURES_SUPPORT = False

logger = logging.getLogger(__name__)

# Session shared between test cases by SparklyGlobalSessionTest; SparklyTest
# stops and clears it before starting its own session.
_test_session_cache = None

[docs]class SparklyTest(TestCase): """Base test for spark scrip tests. Initialize and shut down Session specified in `session` attribute. Example: >>> from pyspark.sql import types as T >>> class MyTestCase(SparklyTest): ... def test(self): ... self.assertRowsEqual( ... self.spark.sql('SELECT 1 as one').collect(), ... [T.Row(one=1)], ... ) """ session = SparklySession class_fixtures = [] fixtures = [] maxDiff = None # (str|None) import the function/class to be tested programmatically # by specifying the path to it here, e.g. 'module_a.submodule_b.my_func' test_target = None @classmethod def setup_session(cls): return cls.session({ # Use in-memory hive metastore (faster tests). 'spark.hadoop.javax.jdo.option.ConnectionURL': 'jdbc:derby:memory:databaseName=metastore_db;create=true', 'spark.hadoop.javax.jdo.option.ConnectionDriverName': 'org.apache.derby.jdbc.EmbeddedDriver', # Isolate the warehouse inside of a random temporary directory (no side effects). 'spark.sql.warehouse.dir': tempfile.mkdtemp(suffix='sparkly'), # Reduce number of shuffle partitions (faster tests). 'spark.sql.shuffle.partitions': '4', }) @classmethod def _init_session(cls): # In case if project has a mix of SparklyTest and SparklyGlobalContextTest-based tests global _test_session_cache if _test_session_cache:'Found a global session, stopping it %r', _test_session_cache) _test_session_cache.stop() _test_session_cache = None cls.spark = cls.setup_session() @classmethod def setUpClass(cls): super(SparklyTest, cls).setUpClass() cls._init_session() for fixture in cls.class_fixtures: fixture.setup_data() # HACK: When pyspark.sql.functions.udf is used as a decorator # it is evaluated on import time; this has the side effect of # creating a spark context if one doesn't exist, messing up # with this class creating its own for test purposes. As a # result, any transformations to be tested are imported here # programmatically after the test class initialization if # the user wishes. 
if not cls.test_target: return test_module_path, test_target = cls.test_target.rsplit('.', 1) test_module = importlib.import_module(test_module_path) setattr(sys.modules[cls.__module__], test_target, getattr(test_module, test_target)) @classmethod def tearDownClass(cls): cls.spark.stop() super(SparklyTest, cls).tearDownClass() for fixture in cls.class_fixtures: fixture.teardown_data() def setUp(self): for fixture in self.fixtures: fixture.setup_data() def tearDown(self): for fixture in self.fixtures: fixture.teardown_data()
[docs] def assertDataFrameEqual(self, actual_df, expected_data, fields=None, ordered=False): """Ensure that DataFrame has the right data inside. ``assertDataFrameEqual`` is being deprecated. Please use ``assertRowsEqual`` instead. Args: actual_df (pyspark.sql.DataFrame|list[pyspark.sql.Row]): Dataframe to test data in. expected_data (list[dict]): Expected dataframe rows defined as dicts. fields (list[str]): Compare only certain fields. ordered (bool): Does order of rows matter? """ warnings.warn( 'assertDataFrameEqual is being deprecated. Please use assertRowsEqual instead.', DeprecationWarning, ) if fields: actual_df =*fields) actual_rows = actual_df.collect() if hasattr(actual_df, 'collect') else actual_df actual_data = [row.asDict(recursive=True) for row in actual_rows] return self.assertRowsEqual( actual_data, expected_data, ignore_order=not ordered, ignore_order_depth=1, atol=0, rtol=0, equal_nan=False, ignore_nullability=False, )
[docs] def assertRowsEqual(self, first, second, msg=None, # ordering parameters ignore_order=True, ignore_order_depth=None, # float comparison parameters atol=0, rtol=1e-07, equal_nan=True, # DataType comparison parameters ignore_nullability=True): """Assert equal on steroids. Extend this classic function signature to work better with comparisons involving rows, datatypes, dictionaries, lists and floats by: - ignoring the order of lists and datatypes recursively, - comparing floats within a given tolerance, - assuming NaNs are equal, - ignoring the nullability requirements of datatypes (since Spark can be inaccurate when inferring it), - providing better diffs for rows and datatypes. Float comparisons are inspired by NumPy's ``assert_allclose``. The main formula used is ``| float1 - float2 | <= atol + rtol * float2``. Args: first: see ``unittest.TestCase.assertEqual``. second: see ``unittest.TestCase.assertEqual``. msg: see ``unittest.TestCase.assertEqual``. ignore_order (bool|True): ignore the order in lists and datatypes (rows, dicts are inherently orderless). ignore_order_depth (int|None): if ignore_order is true, do ignore order up to this level of nested lists or datatypes (exclusive). Setting this to 0 or None means ignore order infinitely, 1 means ignore order only at the top level, 2 will ignore order within lists of lists and so on. Default is ignore order arbitrarily deep. atol (int, float|0): Absolute tolerance in float comparisons. rtol (int, float|1e-07): Relative tolerance in float comparisons. equal_nan (bool|True): If set, NaNs will compare equal. ignore_nullability (bool|True): If set, ignore all nullability fields in dataTypes. This includes ``containsNull`` in arrays, ``valueContainsNull`` in maps and ``nullable`` in struct fields. Returns: None iff the two objects are equal. Raises AssertionError: iff the two objects are not equal. See ``unittest.TestCase.assertEqual`` for details. 
""" # Our approach here is to redefine the 5 container objects that # this function expects to work with - floats, dataTypes, rows, # dicts and lists - to introduce generic ordering and extend # the meaning of equality where applicable. We can then change # all such objects to our custom containers, provide new asserters # for some, and then feed them to the vanilla assertEqual. # Define our custom containers def cast_to_test_friendly_container(value, ignore_order_depth): if isinstance(value, float): return Float(value) if isinstance(value, T.DataType): return DataType(value, ignore_order_depth) if isinstance(value, T.Row): return Row(value, ignore_order_depth) if isinstance(value, dict): return Dict(value, ignore_order_depth) if isinstance(value, list): return List(value, ignore_order_depth) return value @total_ordering class Float(object): def __init__(self, f): self._f = f def __eq__(self, other): return other is not None and ( (equal_nan and math.isnan(self._f) and math.isnan(other._f)) or abs(self._f - other._f) <= atol + rtol * abs(other._f) ) def __lt__(self, other): return self._f != other._f and self._f < other._f def __repr__(self): return repr(self._f) @total_ordering class DataType(object): def __init__(self, dt, ignore_order_depth=0): # update recursively all T.StructTypes to define their # fields in sorted order def _sort_structs(dt, ignore_order_depth): if ignore_order_depth == 0: return dt if dt.typeName() == 'array': return T.ArrayType( elementType=_sort_structs(dt.elementType, ignore_order_depth), containsNull=ignore_nullability or dt.containsNull, ) if dt.typeName() == 'map': return T.MapType( keyType=_sort_structs(dt.keyType, ignore_order_depth), valueType=_sort_structs(dt.valueType, ignore_order_depth), valueContainsNull=ignore_nullability or dt.valueContainsNull, ) if dt.typeName() == 'struct': return T.StructType([ _sort_structs(f, ignore_order_depth - 1) for f in sorted(dt.fields, key=lambda f: ]) if dt.typeName() == 'structf': return 
T.StructField(, _sort_structs(dt.dataType, ignore_order_depth), nullable=ignore_nullability or dt.nullable, metadata=dt.metadata, ) return dt self._dt = _sort_structs(dt, ignore_order_depth) def __eq__(self, other): return self._dt == other._dt def __ne__(self, other): # Only needed for Py27... return self._dt != other._dt def __lt__(self, other): return other is not None and repr(self._dt) < repr(other._dt) def __repr__(self): return repr(self._dt) def pretty_repr(self): # useful to get a nice diff later return pprint.pformat(self._dt.jsonValue()).splitlines() @total_ordering class Row(collections.OrderedDict): def __init__(self, row, ignore_order_depth=0): super(Row, self).__init__( (field, cast_to_test_friendly_container(row[field], ignore_order_depth)) # Rows currently store their fields in order either # way but we ensure this is the case here too for field in sorted(row.__fields__) ) def __lt__(self, other): return other is not None and ( List(zip(self.keys(), self.values())) < List(zip(other.keys(), other.values())) ) def __repr__(self): return 'Row({})'.format(', '.join(['{!r}={!r}'.format(*i) for i in self.items()])) @total_ordering class Dict(collections.OrderedDict): def __init__(self, dictionary, ignore_order_depth=0): super(Dict, self).__init__( sorted([ ( cast_to_test_friendly_container(k, ignore_order_depth), cast_to_test_friendly_container(v, ignore_order_depth), ) for k, v in dictionary.items() ]) ) def __lt__(self, other): return other is not None and ( List(zip(self.keys(), self.values())) < List(zip(other.keys(), other.values())) ) def __repr__(self): return '{{{}}}'.format(', '.join(['{!r}: {!r}'.format(*i) for i in self.items()])) @total_ordering class List(list): def __init__(self, sequence, ignore_order_depth=0): if ignore_order_depth == 0: _sort_or_pass = lambda l: l else: ignore_order_depth = ignore_order_depth - 1 _sort_or_pass = sorted super(List, self).__init__( _sort_or_pass([ cast_to_test_friendly_container(v, ignore_order_depth) for 
v in sequence ]) ) def __lt__(self, other): # None is not a nice value to compare to when trying to # order things, as TypeErrors are raised. Instead, we # transform it to a tuple - if the first entry doesn't # match (is it None?) we don't need to compare further def _neutralize_none(entry): if isinstance(entry, tuple): return entry is None, tuple(_neutralize_none(e) for e in entry) return entry is None, entry return [_neutralize_none(e) for e in self] < [_neutralize_none(e) for e in other] # Define new equality asserters for floats, rows and datatypes def assert_float_equal(self, float1, float2, msg=None): if not isinstance(float1, Float): float1 = Float(float1) if not isinstance(float2, Float): float2 = Float(float2) if float1 == float2: return if not atol and not rtol: standard_msg = '{} != {}'.format(float1, float2) else: standard_msg = ( '{} != {} within absolute tolerance {} and relative tolerance {}' .format(float1, float2, atol, rtol) ), standard_msg)) def assert_row_equal(self, row1, row2, msg=None): self.assertEqual(Row(row1), Row(row2), msg) def assert_datatype_equal(self, dt1, dt2, msg=None): if not isinstance(dt1, DataType): dt1 = DataType(dt1) if not isinstance(dt2, DataType): dt2 = DataType(dt2) if dt1 != dt2: standard_msg = '{} != {}'.format(safe_repr(dt1, True), safe_repr(dt2, True)) diff = '\n' + '\n'.join(difflib.ndiff(dt1.pretty_repr(), dt2.pretty_repr())) standard_msg = self._truncateMessage(standard_msg, diff), standard_msg)) # Create a context manager to temporarily register our asserters, # then restore them to defaults after this function is finished # since they might depend on specific parameters provided here # (e.g., atol/rtol for floats) @contextlib.contextmanager def temp_add_type_equality_func(self, typeobj, function): old_asserter = self._type_equality_funcs.get(typeobj) self.addTypeEqualityFunc(typeobj, function) yield self.addTypeEqualityFunc(typeobj, old_asserter) temp_add_type_equality_func = partial(temp_add_type_equality_func, 
self) # Register equality asserters with temp_add_type_equality_func(float, partial(assert_float_equal, self)), \ temp_add_type_equality_func(Float, partial(assert_float_equal, self)), \ temp_add_type_equality_func(DataType, partial(assert_datatype_equal, self)), \ temp_add_type_equality_func(T.DataType, partial(assert_datatype_equal, self)), \ temp_add_type_equality_func(Row, self.assertDictEqual), \ temp_add_type_equality_func(T.Row, partial(assert_row_equal, self)), \ temp_add_type_equality_func(Dict, self.assertDictEqual), \ temp_add_type_equality_func(List, self.assertListEqual): # And finally (phew!) run the actual comparisons ignore_order_depth = ignore_order_depth or -1 if ignore_order else 0 first = cast_to_test_friendly_container(first, ignore_order_depth) second = cast_to_test_friendly_container(second, ignore_order_depth) self.assertEqual(first, second, msg)
[docs]class SparklyGlobalSessionTest(SparklyTest): """Base test case that keeps a single instance for the given session class across all tests. Integration tests are slow, especially when you have to start/stop Spark context for each test case. This class allows you to reuse Spark session across multiple test cases. """ @classmethod def _init_session(cls): global _test_session_cache if _test_session_cache and cls.session == type(_test_session_cache):'Reusing the global session for %r', cls.session) spark = _test_session_cache else: if _test_session_cache:'Stopping the previous global session %r', _test_session_cache) _test_session_cache.stop()'Starting the new global session for %r', cls.session) spark = _test_session_cache = cls.setup_session() cls.spark = spark @classmethod def tearDownClass(cls): cls.spark.catalog.clearCache() for fixture in cls.class_fixtures: fixture.teardown_data()
[docs]class Fixture(object): """Base class for fixtures. Fixture is a term borrowed from Django tests, it's data loaded into database for integration testing. """
[docs] def setup_data(self): """Method called to load data into database.""" raise NotImplementedError()
[docs] def teardown_data(self): """Method called to remove data from database which was loaded by `setup_data`.""" raise NotImplementedError()
def __enter__(self): self.setup_data() def __exit__(self, exc_type, exc_val, exc_tb): self.teardown_data() @classmethod def read_file(cls, path): with open(path) as f: data = return data
[docs]class CassandraFixture(Fixture): """Fixture to load data into cassandra. Notes: * Depends on cassandra-driver. Examples: >>> class MyTestCase(SparklyTest): ... fixtures = [ ... CassandraFixture( ... '', ... absolute_path(__file__, 'resources', 'setup.cql'), ... absolute_path(__file__, 'resources', 'teardown.cql'), ... ) ... ] ... >>> class MyTestCase(SparklyTest): ... data = CassandraFixture( ... '', ... absolute_path(__file__, 'resources', 'setup.cql'), ... absolute_path(__file__, 'resources', 'teardown.cql'), ... ) ... def setUp(self): ... data.setup_data() ... def tearDown(self): ... data.teardown_data() ... >>> def test(): ... fixture = CassandraFixture(...) ... with fixture: ... test_stuff() ... """ def __init__(self, host, setup_file, teardown_file): if not CASSANDRA_FIXTURES_SUPPORT: raise NotImplementedError('cassandra-driver package isn\'t available. ' 'Use pip install sparkly[test] to fix it.') = host self.setup_file = setup_file self.teardown_file = teardown_file def _execute(self, statements): cluster = Cluster([]) session = cluster.connect() for statement in statements.split(';'): if bool(statement.strip()): session.execute(statement.strip()) def setup_data(self): self._execute(self.read_file(self.setup_file)) def teardown_data(self): self._execute(self.read_file(self.teardown_file))
[docs]class ElasticFixture(Fixture): """Fixture for elastic integration tests. Examples: >>> class MyTestCase(SparklyTest): ... fixtures = [ ... ElasticFixture( ... '', ... 'es_index', ... 'es_type', ... '/path/to/mapping.json', ... '/path/to/data.json', ... ) ... ] ... """ def __init__(self, host, es_index, es_type, mapping=None, data=None, port=None): = host self.port = port or 9200 self.es_index = es_index self.es_type = es_type self.mapping = mapping = data def setup_data(self): if self.mapping: self._request( 'PUT', '/{}'.format(self.es_index), json.dumps({ 'settings': { 'index': { 'number_of_shards': 1, 'number_of_replicas': 1, } } }), ) self._request( 'PUT', '/{}/_mapping/{}'.format(self.es_index, self.es_type), self.read_file(self.mapping), ) if self._request( 'POST', '/_bulk', self.read_file(, ) self._request( 'POST', '/_refresh', ) def teardown_data(self): self._request( 'DELETE', '/{}'.format(self.es_index), ) def _request(self, method, url, body=None): connection = HTTPConnection(, port=self.port) connection.request(method, url, body) response = connection.getresponse() if sys.version_info.major == 3: code = response.code else: code = response.status if code != 200: raise FixtureError('{}: {}'.format(code,
[docs]class MysqlFixture(Fixture): """Fixture for mysql integration tests. Notes: * depends on PyMySql lib. Examples: >>> class MyTestCase(SparklyTest): ... fixtures = [ ... MysqlFixture('', 'user', 'password', '/path/to/data.sql') ... ] ... def test(self): ... pass ... """ def __init__(self, host, user, password=None, data=None, teardown=None): if not MYSQL_FIXTURES_SUPPORT: raise NotImplementedError('PyMySQL package isn\'t available. ' 'Use pip install sparkly[test] to fix it.') = host self.user = user self.password = password = data self.teardown = teardown def _execute(self, statements): ctx = connector.connect( user=self.user, password=self.password,, ) cursor = ctx.cursor() cursor.execute(statements) ctx.commit() cursor.close() ctx.close() def setup_data(self): self._execute(self.read_file( def teardown_data(self): self._execute(self.read_file(self.teardown))
[docs]class KafkaFixture(Fixture): """Fixture for kafka integration tests. Notes: * depends on kafka-python lib. * json file should contain array of dicts: [{'key': ..., 'value': ...}] Examples: >>> class MyTestCase(SparklySession): ... fixtures = [ ... KafkaFixture( ... '', 'topic', ... key_serializer=..., value_serializer=..., ... data='/path/to/data.json', ... ) ... ] """ def __init__(self, host, port=9092, topic=None, key_serializer=None, value_serializer=None, data=None): """Constructor. Args: host (str): Kafka host. port (int): Kafka port. topic (str): Kafka topic. key_serializer (function): Converts python data structure to bytes, applied to message key. value_serializer (function): Converts python data structure to bytes, applied to message value. data (str): Path to json file with data. """ if not KAFKA_FIXTURES_SUPPORT: raise NotImplementedError('kafka-python package isn\'t available. ' 'Use pip install sparkly[test] to fix it.') = host self.port = port self.topic = topic self.key_serializer = key_serializer self.value_serializer = value_serializer = data def _publish_data(self, data): producer = KafkaProducer(bootstrap_servers='kafka.docker', key_serializer=self.key_serializer, value_serializer=self.value_serializer) for item in data: producer.send(self.topic, key=item['key'], value=item['value']) producer.flush() producer.close() def setup_data(self): data = [json.loads(item) for item in self.read_file('\n')] self._publish_data(data) def teardown_data(self): pass
[docs]class KafkaWatcher: """Context manager that tracks Kafka data published to a topic Provides access to the new items that were written to a kafka topic by code running within this context. NOTE: This is mainly useful in integration test cases and may produce unexpected results in production environments, since there are no guarantees about who else may be publishing to a kafka topic. Usage: my_deserializer = lambda item: json.loads(item.decode('utf-8')) kafka_watcher = KafkaWatcher( my_sparkly_session, expected_output_dataframe_schema, my_deserializer, my_deserializer, '', 'my_kafka_topic', ) with kafka_watcher: # do stuff that publishes messages to 'my_kafka_topic' self.assertEqual(kafka_watcher.count, expected_number_of_new_messages) self.assertDataFrameEqual(kafka_watcher.df, expected_df) """ def __init__( self, spark, df_schema, key_deserializer, value_deserializer, host, topic, port=9092, ): """Initialize context manager Parameters `key_deserializer` and `value_deserializer` are callables which get bytes as input and should return python structures as output. 
Args: spark (SparklySession): currently active SparklySession df_schema (pyspark.sql.types.StructType): schema of dataframe to be generated key_deserializer (function): function used to deserialize the key value_deserializer (function): function used to deserialize the value host (basestring): host or ip address of the kafka server to connect to topic (basestring): Kafka topic to monitor port (int): port number of the Kafka server to connect to """ self.spark = spark self.topic = topic self.df_schema = df_schema self.key_deser, self.val_deser = key_deserializer, value_deserializer, self.port = host, port self._df = None self.count = 0 kafka_client = SimpleClient(host) kafka_client.ensure_topic_exists(topic) def __enter__(self): self._df = None self.count = 0 self.pre_offsets = kafka_get_topics_offsets( topic=self.topic,, port=self.port, ) def __exit__(self, e_type, e_value, e_trace): self.post_offsets = kafka_get_topics_offsets( topic=self.topic,, port=self.port, ) self.count = sum([ post[2] - pre[2] for pre, post in zip(self.pre_offsets, self.post_offsets) ]) @property def df(self): if not self.count: return None if not self._df: offset_ranges = [ [pre[0], pre[2], post[2]] for pre, post in zip(self.pre_offsets, self.post_offsets) ] self._df = self.spark.read_ext.kafka( topic=self.topic, offset_ranges=offset_ranges, schema=self.df_schema, key_deserializer=self.key_deser, value_deserializer=self.val_deser,, port=self.port, ) return self._df