#
# Copyright 2017 Tubular Labs, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
try:
from urllib.parse import urlparse, parse_qsl
except ImportError:
from urlparse import urlparse, parse_qsl
try:
from kafka import KafkaProducer
except ImportError:
KAFKA_WRITER_SUPPORT = False
else:
KAFKA_WRITER_SUPPORT = True
try:
import redis
import ujson
except ImportError:
REDIS_WRITER_SUPPORT = False
else:
import bz2
import gzip
import uuid
import zlib
REDIS_WRITER_SUPPORT = True
from functools import partial
from itertools import islice
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from pyspark.sql import types as T
from sparkly.exceptions import WriteError
class SparklyWriter(object):
"""A set of tools to write DataFrames to external storages.
Note:
We don't expect you to use this class directly.
An instance of the class is available on a `DataFrame` via the `write_ext` attribute.
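For example (illustrative output path)::

    df.write_ext.by_url('parquet:///tmp/output?partitionBy=date')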
"""
def __init__(self, df):
self._df = df
self._spark = df.sql_ctx.sparkSession
def by_url(self, url):
"""Write a dataframe to a destination specified by `url`.
The main idea behind the method is to unify the data export interface for different
formats and locations. A generic schema looks like::
format:[protocol:]//host[:port][/location][?configuration]
Supported formats:
- CSV ``csv://``
- Cassandra ``cassandra://``
- Elastic ``elastic://``
- MySQL ``mysql://``
- Parquet ``parquet://``
- Redis ``redis://`` or ``rediss://``
Query string arguments are passed as parameters to the relevant writer.
For instance, the following data export URL::
elastic://localhost:9200/my_index/my_type?parallelism=3&mode=overwrite
&es.write.operation=upsert
Is equivalent to::
df.write_ext.elastic(
host='localhost',
port=9200,
es_index='my_index',
es_type='my_type',
parallelism=3,
mode='overwrite',
options={'es.write.operation': 'upsert'},
)
More examples::
csv:s3://some-s3-bucket/some-s3-key?partitionBy=date,platform
cassandra://cassandra.host/keyspace/table?consistency=ONE&mode=append
parquet:///var/log/?partitionBy=date
elastic://elastic.host/es_index/es_type
mysql://mysql.host/database/table
redis://redis.host/db?keyBy=id
Args:
url (str): Destination URL.
"""
parsed_url = urlparse(url)
parsed_qs = dict(parse_qsl(parsed_url.query))
# Used across all readers
if 'parallelism' in parsed_qs:
parsed_qs['parallelism'] = int(parsed_qs['parallelism'])
try:
resolver = getattr(self, '_resolve_{}'.format(parsed_url.scheme))
except AttributeError:
raise NotImplementedError('Data source is not supported: {}'.format(url))
else:
return resolver(parsed_url, parsed_qs)
def cassandra(self, host, keyspace, table, consistency=None, port=None, mode=None,
parallelism=None, options=None):
"""Write a dataframe to a Cassandra table.
Args:
host (str): Cassandra server host.
keyspace (str): Cassandra keyspace to write to.
table (str): Cassandra table to write to.
consistency (str|None): Write consistency level: ``ONE``, ``QUORUM``, ``ALL``, etc.
port (int|None): Cassandra server port.
mode (str|None): Spark save mode,
http://spark.apache.org/docs/latest/sql-programming-guide.html#save-modes
parallelism (int|None): The max number of parallel tasks that could be executed
during the write stage (see :ref:`controlling-the-load`).
options (dict[str, str]): Additional options to `org.apache.spark.sql.cassandra`
format (see configuration for :ref:`cassandra`).
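Example (host, keyspace and table names below are placeholders)::

    df.write_ext.cassandra(
        host='cassandra.host',
        keyspace='my_keyspace',
        table='my_table',
        consistency='QUORUM',
        mode='append',
        parallelism=8,
    )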
"""
assert self._spark.has_package('datastax:spark-cassandra-connector')
writer_options = {
'format': 'org.apache.spark.sql.cassandra',
'spark.cassandra.connection.host': host,
'keyspace': keyspace,
'table': table,
}
if consistency:
writer_options['spark.cassandra.output.consistency.level'] = consistency
if port:
writer_options['spark.cassandra.connection.port'] = str(port)
return self._basic_write(writer_options, options, parallelism, mode)
def elastic(self, host, es_index, es_type, port=None, mode=None,
parallelism=None, options=None):
"""Write a dataframe into an ElasticSearch index.
Args:
host (str): Elastic server host.
es_index (str): Elastic index.
es_type (str): Elastic type.
port (int|None): Elastic server port.
mode (str|None): Spark save mode,
http://spark.apache.org/docs/latest/sql-programming-guide.html#save-modes
parallelism (int|None): The max number of parallel tasks that could be executed
during the write stage (see :ref:`controlling-the-load`).
options (dict[str, str]): Additional options to `org.elasticsearch.spark.sql` format
(see configuration for :ref:`elastic`).
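Example (host, index and type names below are placeholders)::

    df.write_ext.elastic(
        host='elastic.host',
        es_index='my_index',
        es_type='my_type',
        mode='append',
        options={'es.write.operation': 'upsert'},
    )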
"""
assert self._spark.has_package('org.elasticsearch:elasticsearch-spark')
writer_options = {
'path': '{}/{}'.format(es_index, es_type),
'format': 'org.elasticsearch.spark.sql',
'es.nodes': host,
}
if port:
writer_options['es.port'] = str(port)
return self._basic_write(writer_options, options, parallelism, mode)
def mysql(self, host, database, table, port=None, mode=None, parallelism=None, options=None):
"""Write a dataframe to a MySQL table.
Options should include user and password.
Args:
host (str): MySQL server address.
database (str): Database to connect to.
table (str): Table to write rows to.
port (int|None): MySQL server port.
mode (str|None): Spark save mode,
http://spark.apache.org/docs/latest/sql-programming-guide.html#save-modes
parallelism (int|None): The max number of parallel tasks that could be executed
during the write stage (see :ref:`controlling-the-load`).
options (dict): Additional options for JDBC writer
(see configuration for :ref:`mysql`).
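Example (host, database and credentials below are placeholders)::

    df.write_ext.mysql(
        host='mysql.host',
        database='my_database',
        table='my_table',
        mode='append',
        options={'user': 'root', 'password': 'secret'},
    )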
"""
assert (self._spark.has_jar('mysql-connector-java') or
self._spark.has_package('mysql:mysql-connector-java'))
writer_options = {
'format': 'jdbc',
'driver': 'com.mysql.jdbc.Driver',
'url': 'jdbc:mysql://{host}{port}/{database}'.format(
host=host,
port=':{}'.format(port) if port else '',
database=database,
),
'table': table,
}
return self._basic_write(writer_options, options, parallelism, mode)
def kafka(self,
host,
topic,
key_serializer,
value_serializer,
port=9092,
parallelism=None,
options=None):
"""Writes dataframe to kafka topic.
The schema of the dataframe should conform the pattern:
>>> StructType([
... StructField('key', ...),
... StructField('value', ...),
... ])
Parameters `key_serializer` and `value_serializer` are callables
that take a Python structure as input and return the encoded bytes as output.
Args:
host (str): Kafka host.
topic (str): Topic to write to.
key_serializer (function): Function to serialize key.
value_serializer (function): Function to serialize value.
port (int): Kafka port.
parallelism (int|None): The max number of parallel tasks that could be executed
during the write stage (see :ref:`controlling-the-load`).
options (dict|None): Additional options.
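Example (host and topic are placeholders; the JSON serializers are just
one possible choice)::

    import json

    df.write_ext.kafka(
        host='kafka.host',
        topic='my.topic',
        key_serializer=lambda key: json.dumps(key).encode('utf-8'),
        value_serializer=lambda value: json.dumps(value).encode('utf-8'),
    )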
"""
assert self._spark.has_package('org.apache.spark:spark-streaming-kafka')
if not KAFKA_WRITER_SUPPORT:
raise NotImplementedError('kafka-python package isn\'t available. '
'Use pip install sparkly[kafka] to fix it.')
def write_partition_to_kafka(messages):
producer = KafkaProducer(
bootstrap_servers=['{}:{}'.format(host, port)],
key_serializer=key_serializer,
value_serializer=value_serializer,
)
for message in messages:
as_dict = message.asDict(recursive=True)
result = producer.send(topic, key=as_dict['key'], value=as_dict['value'])
if result.failed():
raise WriteError('Error publishing to kafka: {}'.format(result.exception))
producer.flush()
producer.close()
return messages
rdd = self._df.rdd
if parallelism:
rdd = rdd.coalesce(parallelism)
rdd.mapPartitions(write_partition_to_kafka).count()
def redis(self,
key_by,
key_prefix=None,
key_delimiter='.',
group_by_key=False,
exclude_key_columns=False,
exclude_null_fields=False,
expire=None,
compression=None,
max_pipeline_size=100,
parallelism=None,
mode='overwrite',
host=None,
port=6379,
db=0,
redis_client_init=None):
"""Write a dataframe to Redis as JSON.
Args:
key_by (list[str]): Column names that form the redis key for
each row. The columns are concatenated in the order they
appear.
key_prefix (str|None): Common prefix to add to all keys from
this DataFrame. Useful to namespace DataFrame exports.
key_delimiter (str|.): Characters to delimit different columns
while forming the key.
group_by_key (bool|False): If set, group rows that share the
same redis key together in an array before exporting. By
default if multiple rows share the same redis key, one
will overwrite the other.
exclude_key_columns (bool|False): If set, exclude all columns
that comprise the key from the value being exported to
redis.
exclude_null_fields (bool|False): If set, exclude all fields
of a row that are null from the value being exported to
redis.
expire (int|None): Expire the keys after this number of seconds.
compression (str|None): Compress each Redis entry using this
protocol. Currently bzip2, gzip and zlib are supported.
max_pipeline_size (int|100): Number of writes to pipeline.
parallelism (int|None): The max number of parallel tasks that
could be executed during the write stage
(see :ref:`controlling-the-load`).
mode (str|overwrite):
- ``'append'``: Append to existing data on a key by key
basis.
- ``'ignore'``: Silently ignore if data already exists.
- ``'overwrite'``: Flush all existing data before writing.
host (str|None): Redis host. Either this or redis_client_init
must be provided. See below.
port (int|6379): Port redis is listening to.
db (int|0): Redis db to write to.
redis_client_init (callable|None): Bypass internal redis
client initialization by passing a function that does it,
no arguments required. For example this could be
``redis.StrictRedis.from_url`` with the appropriate url and
``kwargs`` already set through ``functools.partial``.
This option overrides other conflicting arguments.
Raises:
NotImplementedError: if `redis-py` is not installed.
AssertionError: if neither ``host`` nor ``redis_client_init``
is provided.
ValueError: if any of the ``expire``, ``compression``,
``max_pipeline_size`` or ``mode`` options assume an
invalid value.
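Example (host, key prefix and column name below are placeholders)::

    df.write_ext.redis(
        key_by=['user_id'],
        key_prefix='users',
        exclude_key_columns=True,
        expire=3600,
        compression='gzip',
        host='redis.host',
    )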
"""
if not REDIS_WRITER_SUPPORT:
raise NotImplementedError(
'redis package is not available. Use pip install sparkly[redis] to fix it.'
)
assert host or redis_client_init, \
'redis: At least one of host or redis_client_init must be provided.'
if expire is not None and expire < 1:
raise ValueError('redis: expire must be positive')
if compression not in {None, 'bzip2', 'gzip', 'zlib'}:
raise ValueError(
'redis: bzip2, gzip and zlib are the only supported compression codecs.'
)
if max_pipeline_size < 1:
raise ValueError('redis: max pipeline size must be positive')
if mode not in {'append', 'ignore', 'overwrite'}:
raise ValueError(
'redis: only append, ignore and overwrite modes are supported.'
)
key_name = '_sparkly_redis_key_col_{}'.format(uuid.uuid4().hex)
value_name = '_sparkly_redis_value_col_{}'.format(uuid.uuid4().hex)
# Compute the key
df = self._df.withColumn(
key_name,
F.concat_ws(
key_delimiter,
*(
([F.lit(key_prefix).astype(T.StringType())] if key_prefix else []) +
[F.col(col_name).astype(T.StringType()) for col_name in key_by]
)
)
)
if exclude_key_columns:
for col_name in key_by:
df = df.drop(col_name)
if group_by_key:
df = (
df
.withColumn(
value_name,
F.struct(*[col for col in df.columns if col != key_name])
)
.groupBy(key_name)
.agg(F.collect_list(value_name).alias(value_name))
)
# Repartition if needed to achieve specified parallelism
df = df.coalesce(parallelism or df.rdd.getNumPartitions())
def compress(data, protocol=None):
if protocol is None:
return data
elif protocol == 'bzip2':
return bz2.compress(data)
elif protocol == 'zlib':
return zlib.compress(data)
elif protocol == 'gzip':
try:
return gzip.compress(data)
except AttributeError: # py27
compressor = zlib.compressobj(
zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED,
zlib.MAX_WBITS | 16,
)
return compressor.compress(data) + compressor.flush()
else:
raise ValueError('unknown compression protocol {}'.format(protocol))
# Define the function that will write each partition to redis
def _write_redis(partition, redis, expire, compress, max_pipeline_size, mode):
pipeline = redis().pipeline(transaction=False)
add_to_pipeline = partial(pipeline.set, ex=expire, nx=(mode == 'ignore'))
partition = iter(partition)
while True:
rows = islice(partition, max_pipeline_size)
for row in rows:
row = row.asDict(recursive=True)
key = row.pop(key_name)
if group_by_key:
row = row[value_name]
if exclude_null_fields:
for null_field in list(f for f in row.keys() if row[f] is None):
del row[null_field]
data = bytes(ujson.dumps(row).encode('ascii'))
if compress:
data = compress(data)
add_to_pipeline(key, data)
if not len(pipeline):
break
pipeline.execute()
redis_client_init = redis_client_init or partial(redis.StrictRedis, host, port, db=db)
if mode == 'overwrite':
redis_client_init().flushdb()
df.foreachPartition(
partial(
_write_redis,
redis=redis_client_init,
expire=expire,
compress=partial(compress, protocol=compression),
max_pipeline_size=max_pipeline_size,
mode=mode,
)
)
def _basic_write(self, writer_options, additional_options, parallelism, mode):
if mode:
writer_options['mode'] = mode
writer_options.update(additional_options or {})
df = self._df
if parallelism:
df = df.coalesce(parallelism)
# For some reason the native `df.write.jdbc` calls `_jwrite` directly
# so we can't use `df.write.save` for it.
if writer_options['format'] == 'jdbc':
return df.write.jdbc(
url=writer_options.pop('url'),
table=writer_options.pop('table'),
mode=writer_options.pop('mode', None),
properties=writer_options,
)
else:
return df.write.save(**writer_options)
def _resolve_cassandra(self, parsed_url, parsed_qs):
return self.cassandra(
host=parsed_url.netloc,
keyspace=parsed_url.path.split('/')[1],
table=parsed_url.path.split('/')[2],
consistency=parsed_qs.pop('consistency', None),
port=parsed_url.port,
mode=parsed_qs.pop('mode', None),
parallelism=parsed_qs.pop('parallelism', None),
options=parsed_qs,
)
def _resolve_csv(self, parsed_url, parsed_qs):
parallelism = parsed_qs.pop('parallelism', None)
if parallelism:
df = self._df.coalesce(int(parallelism))
else:
df = self._df
if 'partitionBy' in parsed_qs:
parsed_qs['partitionBy'] = parsed_qs.pop('partitionBy').split(',')
df.write.save(
path=parsed_url.path,
format=parsed_url.scheme,
**parsed_qs
)
def _resolve_elastic(self, parsed_url, parsed_qs):
return self.elastic(
host=parsed_url.netloc,
es_index=parsed_url.path.split('/')[1],
es_type=parsed_url.path.split('/')[2],
port=parsed_url.port,
mode=parsed_qs.pop('mode', None),
parallelism=parsed_qs.pop('parallelism', None),
options=parsed_qs,
)
def _resolve_mysql(self, parsed_url, parsed_qs):
return self.mysql(
host=parsed_url.netloc,
database=parsed_url.path.split('/')[1],
table=parsed_url.path.split('/')[2],
port=parsed_url.port,
mode=parsed_qs.pop('mode', None),
parallelism=parsed_qs.pop('parallelism', None),
options=parsed_qs,
)
def _resolve_parquet(self, parsed_url, parsed_qs):
parallelism = parsed_qs.pop('parallelism', None)
if parallelism:
df = self._df.coalesce(int(parallelism))
else:
df = self._df
if 'partitionBy' in parsed_qs:
parsed_qs['partitionBy'] = parsed_qs.pop('partitionBy').split(',')
df.write.save(
path=parsed_url.path,
format='parquet',
**parsed_qs
)
def _resolve_redis(self, parsed_url, parsed_qs):
# Extract all the custom options
try:
key_by = parsed_qs.pop('keyBy')
except KeyError:
raise AssertionError('redis: url must define keyBy columns to construct redis key')
key_by = key_by.split(',')
key_prefix = parsed_qs.pop('keyPrefix', None)
key_delimiter = parsed_qs.pop('keyDelimiter', '.')
def _parse_boolean(parameter, default='false'):
value = parsed_qs.pop(parameter, default).lower()
if value not in {'true', 'false'}:
raise ValueError(
'redis: true and false (default) are the only supported {} values'
.format(parameter)
)
return value == 'true'
group_by_key = _parse_boolean('groupByKey')
exclude_key_columns = _parse_boolean('excludeKeyColumns')
exclude_null_fields = _parse_boolean('excludeNullFields')
try:
expire = int(parsed_qs.pop('expire'))
except KeyError:
expire = None
except (TypeError, ValueError):
raise ValueError('redis: expire must be a base 10, positive integer')
compression = parsed_qs.pop('compression', None)
try:
max_pipeline_size = int(parsed_qs.pop('maxPipelineSize', 100))
except (TypeError, ValueError):
raise ValueError('redis: maxPipelineSize must be a base 10, positive integer')
parallelism = parsed_qs.pop('parallelism', None)
mode = parsed_qs.pop('mode', 'append')
# Reconstruct whatever remains of the original URL
url = parsed_url._replace(query='&'.join('='.join(item) for item in parsed_qs.items())).geturl()
return self.redis(
key_by=key_by,
key_prefix=key_prefix,
key_delimiter=key_delimiter,
group_by_key=group_by_key,
exclude_key_columns=exclude_key_columns,
exclude_null_fields=exclude_null_fields,
expire=expire,
compression=compression,
max_pipeline_size=max_pipeline_size,
parallelism=parallelism,
mode=mode,
# and then let redis-py decode it
redis_client_init=partial(redis.StrictRedis.from_url, url)
)
def _resolve_rediss(self, parsed_url, parsed_qs):
return self._resolve_redis(parsed_url, parsed_qs)
def attach_writer_to_dataframe():
"""A tiny amount of magic to attach write extensions."""
def write_ext(self):
return SparklyWriter(self)
DataFrame.write_ext = property(write_ext)
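# Usage sketch (illustrative): once ``attach_writer_to_dataframe()`` has been
# called during library initialization, any DataFrame exposes the writer:
#
#     df.write_ext.by_url('csv:s3://some-s3-bucket/some-s3-key?mode=overwrite')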