#
# Copyright 2017 Tubular Labs, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import sys
import os
import shutil
from unittest import TestCase
from sparkly import SparklySession
from sparkly.exceptions import FixtureError
from sparkly.utils import kafka_get_topics_offsets
if sys.version_info.major == 3:
    from http.client import HTTPConnection
else:
    from httplib import HTTPConnection

# Optional client libraries. CassandraFixture, MysqlFixture and KafkaFixture check
# the matching *_FIXTURES_SUPPORT flag and raise NotImplementedError when the
# library is missing, so importing this module never requires them.
try:
    from cassandra.cluster import Cluster
    CASSANDRA_FIXTURES_SUPPORT = True
except ImportError:
    CASSANDRA_FIXTURES_SUPPORT = False

try:
    import pymysql as connector
    MYSQL_FIXTURES_SUPPORT = True
except ImportError:
    # Fall back to the official MySQL driver when PyMySQL is not installed.
    try:
        import mysql.connector as connector
        MYSQL_FIXTURES_SUPPORT = True
    except ImportError:
        MYSQL_FIXTURES_SUPPORT = False

try:
    from kafka import KafkaProducer, SimpleClient
    KAFKA_FIXTURES_SUPPORT = True
except ImportError:
    KAFKA_FIXTURES_SUPPORT = False

# Module-level logger (rather than the root logger) so applications can tune
# this module's verbosity independently.
logger = logging.getLogger(__name__)

# Session shared by SparklyGlobalSessionTest-based test cases. SparklyTest stops
# it when found, so both base classes can coexist in one test run.
_test_session_cache = None
class SparklyTest(TestCase):
    """Base test for Spark script tests.

    Creates a session from the `session` attribute once per test case class
    and stops it when the class is done. Fixtures listed in `class_fixtures`
    are set up/torn down once per class, fixtures in `fixtures` around every
    single test.

    Example:

        >>> class MyTestCase(SparklyTest):
        ...     def test(self):
        ...         self.assertDataFrameEqual(
        ...             self.spark.sql('SELECT 1 as one').collect(),
        ...             [{'one': 1}],
        ...         )
    """
    session = SparklySession
    class_fixtures = []
    fixtures = []
    # Show complete diffs when assertions fail.
    maxDiff = None

    @classmethod
    def setUpClass(cls):
        super(SparklyTest, cls).setUpClass()

        # In case the project mixes SparklyTest- and SparklyGlobalSessionTest-based
        # tests: stop the shared global session before starting a dedicated one.
        global _test_session_cache
        if _test_session_cache:
            logger.info('Found a global session, stopping it %r', _test_session_cache)
            _test_session_cache.stop()
            _test_session_cache = None

        cls.spark = cls.session()

        try:
            for fixture in cls.class_fixtures:
                fixture.setup_data()
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises,
            # so release the session here instead of leaking it.
            cls.spark.stop()
            raise

    @classmethod
    def tearDownClass(cls):
        cls.spark.stop()
        super(SparklyTest, cls).tearDownClass()

        # Remove local metastore artifacts from the working directory, if present.
        try:
            shutil.rmtree('metastore_db')
        except OSError:
            pass

        try:
            os.unlink('derby.log')
        except OSError:
            pass

        for fixture in cls.class_fixtures:
            fixture.teardown_data()

    def setUp(self):
        super(SparklyTest, self).setUp()
        for fixture in self.fixtures:
            fixture.setup_data()

    def tearDown(self):
        for fixture in self.fixtures:
            fixture.teardown_data()
        super(SparklyTest, self).tearDown()

    def assertDataFrameEqual(self, actual_df, expected_data, fields=None, ordered=False):
        """Ensure that DataFrame has the right data inside.

        Args:
            actual_df (pyspark.sql.DataFrame|list[pyspark.sql.Row]): Dataframe to test data in.
            expected_data (list[dict]): Expected dataframe rows defined as dicts.
            fields (list[str]): Compare only certain fields.
            ordered (bool): Does order of rows matter?
        """
        is_dataframe = hasattr(actual_df, 'select')

        if fields and is_dataframe:
            actual_df = actual_df.select(*fields)

        actual_rows = actual_df.collect() if hasattr(actual_df, 'collect') else actual_df
        actual_data = [row.asDict(recursive=True) for row in actual_rows]

        if fields and not is_dataframe:
            # Rows were passed in directly, so project them here instead of select().
            actual_data = [{field: row[field] for field in fields} for row in actual_data]

        if ordered:
            self.assertEqual(actual_data, expected_data)
        else:
            # Python 3 names it assertCountEqual, Python 2.7 assertItemsEqual.
            # Resolve it up front so errors raised by the assertion are not masked.
            assert_count_equal = getattr(self, 'assertCountEqual', None) or self.assertItemsEqual
            assert_count_equal(actual_data, expected_data)
class SparklyGlobalSessionTest(SparklyTest):
    """Base test case that keeps a single instance for the given session class across all tests.

    Integration tests are slow, especially when you have to start/stop Spark context
    for each test case. This class allows you to reuse Spark session across multiple test cases.

    The session is stored module-wide and reused by every test case whose
    `session` attribute is exactly the same class. A test case with a different
    session class stops the previous global session and starts its own.
    """
    @classmethod
    def setUpClass(cls):
        # Deliberately skip SparklyTest.setUpClass (it would stop the global
        # session), but keep TestCase/mixin setUpClass in the chain.
        super(SparklyTest, cls).setUpClass()

        global _test_session_cache
        if _test_session_cache and type(_test_session_cache) is cls.session:
            logger.info('Reusing the global session for %r', cls.session)
            spark = _test_session_cache
        else:
            if _test_session_cache:
                logger.info('Stopping the previous global session %r', _test_session_cache)
                _test_session_cache.stop()

            logger.info('Starting the new global session for %r', cls.session)
            spark = _test_session_cache = cls.session()

        cls.spark = spark

        for fixture in cls.class_fixtures:
            fixture.setup_data()

    @classmethod
    def tearDownClass(cls):
        # The session stays alive for the next test case; only cached tables are
        # dropped so they do not carry over.
        cls.spark.catalog.clearCache()

        for fixture in cls.class_fixtures:
            fixture.teardown_data()

        # Skip SparklyTest.tearDownClass (it would stop the shared session).
        super(SparklyTest, cls).tearDownClass()
class Fixture(object):
    """Base class for fixtures.

    Fixture is a term borrowed from Django tests,
    it's data loaded into database for integration testing.

    Subclasses implement `setup_data` and `teardown_data`. A fixture can also be
    used as a context manager: data is loaded on enter and removed on exit.
    """

    def setup_data(self):
        """Method called to load data into database."""
        raise NotImplementedError()

    def teardown_data(self):
        """Method called to remove data from database which was loaded by `setup_data`."""
        raise NotImplementedError()

    def __enter__(self):
        self.setup_data()
        # Return the fixture so `with fixture as f:` binds something useful.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so exceptions raised inside the `with` block propagate.
        self.teardown_data()

    @classmethod
    def read_file(cls, path):
        """Return the whole content of the file at `path` as a string."""
        with open(path) as f:
            return f.read()
class CassandraFixture(Fixture):
    """Fixture to load data into cassandra.

    Notes:
        * Depends on cassandra-driver.

    Examples:

        >>> class MyTestCase(SparklyTest):
        ...     fixtures = [
        ...         CassandraFixture(
        ...             'cassandra.host',
        ...             absolute_path(__file__, 'resources', 'setup.cql'),
        ...             absolute_path(__file__, 'resources', 'teardown.cql'),
        ...         )
        ...     ]
        ...

        >>> class MyTestCase(SparklyTest):
        ...     data = CassandraFixture(
        ...         'cassandra.host',
        ...         absolute_path(__file__, 'resources', 'setup.cql'),
        ...         absolute_path(__file__, 'resources', 'teardown.cql'),
        ...     )
        ...     def setUp(self):
        ...         self.data.setup_data()
        ...     def tearDown(self):
        ...         self.data.teardown_data()
        ...

        >>> def test():
        ...     fixture = CassandraFixture(...)
        ...     with fixture:
        ...         test_stuff()
        ...
    """

    def __init__(self, host, setup_file, teardown_file):
        """Constructor.

        Args:
            host (str): Cassandra contact point.
            setup_file (str): Path to a CQL file executed by `setup_data`.
            teardown_file (str): Path to a CQL file executed by `teardown_data`.

        Raises:
            NotImplementedError: cassandra-driver is not installed.
        """
        if not CASSANDRA_FIXTURES_SUPPORT:
            raise NotImplementedError('cassandra-driver package isn\'t available. '
                                      'Use pip install sparkly[test] to fix it.')
        self.host = host
        self.setup_file = setup_file
        self.teardown_file = teardown_file

    def _execute(self, statements):
        """Execute `;`-separated CQL statements one by one, skipping empty ones."""
        cluster = Cluster([self.host])
        try:
            session = cluster.connect()
            # NOTE: naive split - a `;` inside a string literal or comment
            # would break the statement apart.
            for statement in statements.split(';'):
                statement = statement.strip()
                if statement:
                    session.execute(statement)
        finally:
            # Close the sessions and connections opened by this cluster object.
            cluster.shutdown()

    def setup_data(self):
        self._execute(self.read_file(self.setup_file))

    def teardown_data(self):
        self._execute(self.read_file(self.teardown_file))
class ElasticFixture(Fixture):
    """Fixture for elastic integration tests.

    Examples:

        >>> class MyTestCase(SparklyTest):
        ...     fixtures = [
        ...         ElasticFixture(
        ...             'elastic.host',
        ...             'es_index',
        ...             'es_type',
        ...             '/path/to/mapping.json',
        ...             '/path/to/data.json',
        ...         )
        ...     ]
        ...
    """

    def __init__(self, host, es_index, es_type, mapping=None, data=None, port=None):
        """Constructor.

        Args:
            host (str): Elasticsearch host.
            es_index (str): Index created by `setup_data` (when `mapping` is given)
                and deleted by `teardown_data`.
            es_type (str): Document type the mapping is applied to.
            mapping (str|None): Path to a JSON mapping file.
            data (str|None): Path to a file posted as-is to the `_bulk` endpoint.
            port (int|None): HTTP port, 9200 when not given.
        """
        self.host = host
        self.port = port or 9200
        self.es_index = es_index
        self.es_type = es_type
        self.mapping = mapping
        self.data = data

    def setup_data(self):
        if self.mapping:
            # Create the index with a single shard, then apply the mapping.
            self._request(
                'PUT',
                '/{}'.format(self.es_index),
                json.dumps({
                    'settings': {
                        'index': {
                            'number_of_shards': 1,
                            'number_of_replicas': 1,
                        }
                    }
                }),
            )
            self._request(
                'PUT',
                '/{}/_mapping/{}'.format(self.es_index, self.es_type),
                self.read_file(self.mapping),
            )

        if self.data:
            self._request(
                'POST',
                '/_bulk',
                self.read_file(self.data),
            )
            # Make the bulk-loaded documents visible to searches right away.
            self._request(
                'POST',
                '/_refresh',
            )

    def teardown_data(self):
        self._request(
            'DELETE',
            '/{}'.format(self.es_index),
        )

    def _request(self, method, url, body=None):
        """Send one HTTP request to Elasticsearch.

        Raises:
            FixtureError: the response status is not 2xx.
        """
        connection = HTTPConnection(self.host, port=self.port)
        try:
            connection.request(method, url, body)
            response = connection.getresponse()
            # `status` is available on both httplib (py2) and http.client (py3)
            # responses; `code` is deprecated in py3.
            code = response.status
            payload = response.read()
        finally:
            connection.close()

        if not 200 <= code < 300:
            raise FixtureError('{}: {}'.format(code, payload))
class MysqlFixture(Fixture):
    """Fixture for mysql integration tests.

    Notes:
        * depends on PyMySql lib.

    Examples:

        >>> class MyTestCase(SparklyTest):
        ...     fixtures = [
        ...         MysqlFixture('mysql.host', 'user', 'password', '/path/to/data.sql')
        ...     ]
        ...     def test(self):
        ...         pass
        ...
    """

    def __init__(self, host, user, password=None, data=None, teardown=None):
        """Constructor.

        Args:
            host (str): MySQL host.
            user (str): MySQL user.
            password (str|None): MySQL password.
            data (str|None): Path to SQL executed by `setup_data`; skipped when None.
            teardown (str|None): Path to SQL executed by `teardown_data`; skipped when None.

        Raises:
            NotImplementedError: neither PyMySQL nor mysql-connector is installed.
        """
        if not MYSQL_FIXTURES_SUPPORT:
            raise NotImplementedError('PyMySQL package isn\'t available. '
                                      'Use pip install sparkly[test] to fix it.')
        self.host = host
        self.user = user
        self.password = password
        self.data = data
        self.teardown = teardown

    def _execute(self, statements):
        """Execute `statements` in one cursor call and commit."""
        ctx = connector.connect(
            user=self.user,
            password=self.password,
            host=self.host,
        )
        try:
            cursor = ctx.cursor()
            try:
                cursor.execute(statements)
                ctx.commit()
            finally:
                cursor.close()
        finally:
            # Always release the connection, even when the SQL fails.
            ctx.close()

    def setup_data(self):
        if self.data:
            self._execute(self.read_file(self.data))

    def teardown_data(self):
        if self.teardown:
            self._execute(self.read_file(self.teardown))
class KafkaFixture(Fixture):
    """Fixture for kafka integration tests.

    Notes:
        * depends on kafka-python lib.
        * the data file should contain one JSON object per line,
          each with 'key' and 'value' fields: {"key": ..., "value": ...}

    Examples:

        >>> class MyTestCase(SparklyTest):
        ...     fixtures = [
        ...         KafkaFixture(
        ...             'kafka.host', topic='topic',
        ...             key_serializer=..., value_serializer=...,
        ...             data='/path/to/data.json',
        ...         )
        ...     ]
    """

    def __init__(self, host, port=9092, topic=None,
                 key_serializer=None, value_serializer=None,
                 data=None):
        """Constructor.

        Args:
            host (str): Kafka host.
            port (int): Kafka port.
            topic (str): Kafka topic.
            key_serializer (function): Converts python data structure to bytes,
                applied to message key.
            value_serializer (function): Converts python data structure to bytes,
                applied to message value.
            data (str): Path to a file with one JSON object per line.

        Raises:
            NotImplementedError: kafka-python is not installed.
        """
        if not KAFKA_FIXTURES_SUPPORT:
            raise NotImplementedError('kafka-python package isn\'t available. '
                                      'Use pip install sparkly[test] to fix it.')
        self.host = host
        self.port = port
        self.topic = topic
        self.key_serializer = key_serializer
        self.value_serializer = value_serializer
        self.data = data

    def _publish_data(self, data):
        """Send every item's 'key'/'value' to `self.topic` and wait for delivery."""
        producer = KafkaProducer(
            bootstrap_servers='{}:{}'.format(self.host, self.port),
            key_serializer=self.key_serializer,
            value_serializer=self.value_serializer,
        )
        try:
            for item in data:
                producer.send(self.topic, key=item['key'], value=item['value'])
            producer.flush()
        finally:
            producer.close()

    def setup_data(self):
        if not self.data:
            return
        lines = self.read_file(self.data).splitlines()
        # Blank lines (e.g. a trailing newline or an empty file) are ignored.
        data = [json.loads(line) for line in lines if line.strip()]
        self._publish_data(data)

    def teardown_data(self):
        # Published messages are not removed from the topic.
        pass
class KafkaWatcher(object):
    """Context manager that tracks Kafka data published to a topic

    Provides access to the new items that were written to a kafka topic by code running
    within this context.

    NOTE: This is mainly useful in integration test cases and may produce unexpected results in
    production environments, since there are no guarantees about who else may be publishing to
    a kafka topic.

    Usage:
        my_deserializer = lambda item: json.loads(item.decode('utf-8'))
        kafka_watcher = KafkaWatcher(
            my_sparkly_session,
            expected_output_dataframe_schema,
            my_deserializer,
            my_deserializer,
            'my.kafkaserver.net',
            'my_kafka_topic',
        )
        with kafka_watcher:
            # do stuff that publishes messages to 'my_kafka_topic'
        self.assertEqual(kafka_watcher.count, expected_number_of_new_messages)
        self.assertDataFrameEqual(kafka_watcher.df, expected_df)
    """

    def __init__(
        self,
        spark,
        df_schema,
        key_deserializer,
        value_deserializer,
        host,
        topic,
        port=9092,
    ):
        """Initialize context manager

        Parameters `key_deserializer` and `value_deserializer` are callables
        which get bytes as input and should return python structures as output.

        Args:
            spark (SparklySession): currently active SparklySession
            df_schema (pyspark.sql.types.StructType): schema of dataframe to be generated
            key_deserializer (function): function used to deserialize the key
            value_deserializer (function): function used to deserialize the value
            host (str): host or ip address of the kafka server to connect to
            topic (str): Kafka topic to monitor
            port (int): port number of the Kafka server to connect to
        """
        self.spark = spark
        self.topic = topic
        self.df_schema = df_schema
        self.key_deser, self.val_deser = key_deserializer, value_deserializer
        self.host, self.port = host, port
        self._df = None
        self.count = 0

        # Make sure the topic exists so offsets can be read in __enter__.
        kafka_client = SimpleClient('{}:{}'.format(host, port))
        try:
            kafka_client.ensure_topic_exists(topic)
        finally:
            kafka_client.close()

    def __enter__(self):
        self._df = None
        self.count = 0
        self.pre_offsets = kafka_get_topics_offsets(
            topic=self.topic,
            host=self.host,
            port=self.port,
        )
        return self

    def __exit__(self, e_type, e_value, e_trace):
        self.post_offsets = kafka_get_topics_offsets(
            topic=self.topic,
            host=self.host,
            port=self.port,
        )
        # NOTE(review): assumes both offset lists list partitions in the same order
        # and that index 2 holds the offset - TODO confirm kafka_get_topics_offsets.
        self.count = sum(
            post[2] - pre[2]
            for pre, post in zip(self.pre_offsets, self.post_offsets)
        )

    @property
    def df(self):
        """DataFrame of messages published within the context, or None if there were none.

        Read lazily on first access and cached afterwards.
        """
        if not self.count:
            return None
        if self._df is None:
            offset_ranges = [
                [pre[0], pre[2], post[2]]
                for pre, post in zip(self.pre_offsets, self.post_offsets)
            ]
            self._df = self.spark.read_ext.kafka(
                topic=self.topic,
                offset_ranges=offset_ranges,
                schema=self.df_schema,
                key_deserializer=self.key_deser,
                value_deserializer=self.val_deser,
                host=self.host,
                port=self.port,
            )
        return self._df