We do constant traffic with our Redshift tables, so I created a wrapper class that allows custom SQL to be run (or a default generic statement), and that can do a "safe load": it first copies the data to a dev database to make sure the load will work before it truncates the prod database. Let me know what I can improve!
import logging
import sys

import psycopg2 as ppg2


class RedshiftBase(object):
    def __init__(self,
                 s3_credentials,
                 redshift_db_credentials,
                 table_name=None,
                 schema_name=None,
                 manifest_url=None,
                 unload_url=None,
                 dev_db_credentials=None,
                 sql_stmt=None,
                 safe_load=False,
                 truncate=False):
        """
        This class automates the copy of data from an S3 file to a Redshift
        database. Most of the methods are static and can be accessed outside
        the class. Run the 'execute' method to run the process.
        :param s3_credentials: A dictionary containing the access and
            secret access keys. Keys must match the example:
                S3_INFO = {
                    'aws_access_key_id': S3_ACCESS,
                    'aws_secret_access_key': S3_SECRET,
                    'region_name': 'us-west-2'
                }
        :param redshift_db_credentials: A dictionary containing the host,
            port, database name, username, and password. Keys must match
            the example:
                REDSHIFT_POSTGRES_INFO = {
                    'host': REDSHIFT_HOST,
                    'port': REDSHIFT_PORT,
                    'database': REDSHIFT_DATABASE_DEV,
                    'user': REDSHIFT_USER,
                    'password': REDSHIFT_PASS
                }
        :param table_name: The Redshift table name. Must include the schema
            if that is required for database access. Ex: 'schema.table'.
        :param schema_name: The schema name associated with the desired table.
        :param manifest_url: The location of the manifest file on S3.
        :param unload_url: In the case of an unload operation, this specifies
            the location on S3 where the files will be unloaded to.
        :param dev_db_credentials: Credentials for the development database
            used when safe_load is True. Same keys as redshift_db_credentials.
        :param sql_stmt: A SQL statement given as a single string. This
            statement will be used instead of the default. Ex:
                '''
                SELECT *
                FROM table
                WHERE conditions < parameters
                ORDER BY field DESC
                '''
        :param safe_load: If True, triggers a test load to the specified
            development database during the 'execute' method. Useful for
            making sure the data will load correctly before truncating the
            production database.
        :param truncate: If True, the production table will be truncated
            before the copy step.
        :return: None
        """
        if schema_name:
            self.table_name = schema_name + '.' + table_name
        else:
            self.table_name = table_name
        self.manifest_url = manifest_url
        self.unload_url = unload_url
        self.s3_credentials = s3_credentials
        self.prod_db_credentials = redshift_db_credentials
        self.dev_db_credentials = dev_db_credentials
        self.sql_stmt = sql_stmt
        self.safe_load = safe_load
        self.truncate = truncate
    def __repr__(self):
        return ('Table: {}\nManifest URL: {}\nUnload URL: {}\nS3 Credentials: '
                '{}\nDev DB Credentials: {}\nProd DB Credentials: {}\nSafe '
                'Load: {}\nTruncate: {}'.format(
                    self.table_name,
                    self.manifest_url,
                    self.unload_url,
                    self.s3_credentials,
                    self.dev_db_credentials,
                    self.prod_db_credentials,
                    self.safe_load,
                    self.truncate
                ))
class RedshiftLoad(RedshiftBase):
    @staticmethod
    def copy_to_db(database_credentials,
                   table_name,
                   manifest_url,
                   s3_credentials,
                   sql_stmt,
                   truncate=False):
        """
        Copies data from a file on S3 to a Redshift table. Data must be
        properly formatted, in the right column order, etc.
        :param database_credentials: A dictionary containing the host, port,
            database name, username, and password. Keys must match the
            example:
                REDSHIFT_POSTGRES_INFO = {
                    'host': REDSHIFT_HOST,
                    'port': REDSHIFT_PORT,
                    'database': REDSHIFT_DATABASE_DEV,
                    'user': REDSHIFT_USER,
                    'password': REDSHIFT_PASS
                }
        :param table_name: The Redshift table name. Must include the schema
            if that is required for database access. Ex: 'schema.table'.
        :param manifest_url: The location of the file on the S3 server.
        :param s3_credentials: A dictionary containing the access and
            secret access keys. Keys must match the example:
                S3_INFO = {
                    'aws_access_key_id': S3_ACCESS,
                    'aws_secret_access_key': S3_SECRET,
                    'region_name': 'us-west-2'
                }
        :param truncate: If True, the table will be truncated before the
            load.
        :return: None
        """
        s3_access = s3_credentials['aws_access_key_id']
        s3_secret = s3_credentials['aws_secret_access_key']
        logging.info('Accessing {}'.format(table_name))
        try:
            with ppg2.connect(**database_credentials) as conn:
                cur = conn.cursor()
                if truncate:
                    RedshiftLoad.truncate_table(table_name, cur)
                load = '''
                    COPY {}
                    from '{}'
                    credentials 'aws_access_key_id={};aws_secret_access_key={}'
                    delimiter '|'
                    gzip
                    trimblanks
                    truncatecolumns
                    acceptinvchars
                    timeformat 'auto'
                    dateformat 'auto'
                    manifest
                '''.format(
                    table_name,
                    manifest_url,
                    s3_access,
                    s3_secret)
                if sql_stmt:
                    logging.info('Executing custom SQL load statement.')
                    cur.execute(sql_stmt)
                else:
                    logging.info('Executing default SQL load statement.')
                    logging.info('Loading {}'.format(table_name))
                    cur.execute(load)
                conn.commit()
        except ppg2.Error as e:
            logging.critical('Error occurred during db load: {}'.format(e))
            sys.exit(1)

    @staticmethod
    def truncate_table(table, cursor):
        """
        Truncates a table given the schema and table names."""
        trunc_stmt = '''
            truncate table {}
        '''.format(table)
        cursor.execute(trunc_stmt)
    def execute(self):
        if self.safe_load:
            logging.info('Test load triggered, connecting now to {}.'.format(
                self.table_name
            ))
            self.copy_to_db(self.dev_db_credentials,
                            self.table_name,
                            self.manifest_url,
                            self.s3_credentials,
                            self.sql_stmt,
                            self.truncate)
            logging.info('Load to the development database was a success.')
        logging.info('Commencing load operation to {}.'.format(
            self.table_name))
        self.copy_to_db(self.prod_db_credentials,
                        self.table_name,
                        self.manifest_url,
                        self.s3_credentials,
                        self.sql_stmt,
                        self.truncate)
        logging.info('Load to the production database was a success.')
class RedshiftUnload(RedshiftBase):
    @staticmethod
    def unload_to_s3(database_credentials,
                     table_name,
                     unload_url,
                     s3_credentials,
                     sql_stmt):
        """
        Unloads the data from a Redshift table into a specified location on
        S3. The default UNLOAD statement (i.e. if you don't pass in a SQL
        statement) defaults to allowing files to be overwritten if the
        unload_url is not an empty directory.
        :param database_credentials: A dictionary containing the host, port,
            database name, username, and password. Keys must match the
            example:
                REDSHIFT_POSTGRES_INFO = {
                    'host': REDSHIFT_HOST,
                    'port': REDSHIFT_PORT,
                    'database': REDSHIFT_DATABASE_DEV,
                    'user': REDSHIFT_USER,
                    'password': REDSHIFT_PASS
                }
        :param table_name: The Redshift table name. Must include the schema
            if that is required for database access. Ex: 'schema.table'
        :param s3_credentials: A dictionary containing the access and
            secret access keys. Keys must match the example:
                S3_INFO = {
                    'aws_access_key_id': S3_ACCESS,
                    'aws_secret_access_key': S3_SECRET,
                    'region_name': 'us-west-2'
                }
        :return: None
        """
        s3_access = s3_credentials['aws_access_key_id']
        s3_secret = s3_credentials['aws_secret_access_key']
        try:
            with ppg2.connect(**database_credentials) as conn:
                cur = conn.cursor()
                unload = '''
                    UNLOAD ('SELECT * FROM {}')
                    TO '{}'
                    CREDENTIALS 'aws_access_key_id={};aws_secret_access_key={}'
                    MANIFEST
                    DELIMITER '|'
                    GZIP
                    ALLOWOVERWRITE
                '''.format(
                    table_name,
                    unload_url,
                    s3_access,
                    s3_secret)
                if sql_stmt:
                    logging.info('Executing custom SQL unload statement.')
                    cur.execute(sql_stmt)
                else:
                    logging.info('Executing default SQL unload statement.')
                    logging.info('Unloading from {} (Will be None if custom '
                                 'SQL was used'.format(table_name))
                    cur.execute(unload)
                conn.commit()
        except ppg2.Error as e:
            logging.critical('Error occurred during unload: {}'.format(e))
            sys.exit(1)
    def execute(self):
        self.unload_to_s3(self.prod_db_credentials,
                          self.table_name,
                          self.unload_url,
                          self.s3_credentials,
                          self.sql_stmt)
        logging.info('Unload was a success.')
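
Typical usage looks something like this (credential dicts shaped like the docstring examples; DEV_POSTGRES_INFO and the S3 path here are placeholders):

    load = RedshiftLoad(
        s3_credentials=S3_INFO,
        redshift_db_credentials=REDSHIFT_POSTGRES_INFO,
        table_name='my_table',
        schema_name='my_schema',
        manifest_url='s3://my-bucket/path/manifest',
        dev_db_credentials=DEV_POSTGRES_INFO,
        safe_load=True,
        truncate=True)
    load.execute()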
1 Answer
so I created a wrapper class that allows custom SQL to be run
So... I don't know if there is a way to parameterize your queries to Redshift, or how this class is called, but hopefully not from user input - right now it is completely vulnerable to SQL injection, since you interpolate your table_name variable (and others) directly into the query.
I should note the documentation for executing queries says:
Warning Never, never, NEVER use Python string concatenation (+) or string parameters interpolation (%) to pass variables to a SQL query string. Not even at gunpoint.
I would recommend reading their documentation on parameterized queries:
    load = '''
        COPY %(table_name)s
        from '%(manifest_url)s'
        credentials 'aws_access_key_id=%(s3_access)s;aws_secret_access_key=%(s3_secret)s'
        delimiter '|'
        gzip
        trimblanks
        truncatecolumns
        acceptinvchars
        timeformat 'auto'
        dateformat 'auto'
        manifest
    '''

    params = {'table_name': table_name,
              'manifest_url': manifest_url,
              's3_access': s3_credentials['aws_access_key_id'],
              's3_secret': s3_credentials['aws_secret_access_key']}

    cur.execute(load, params)
You can remove the temporary variable assignment for s3 in this case too.
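
One caveat before adopting this wholesale: placeholders substitute values, not identifiers, so %(table_name)s would be rendered as a quoted string literal, and COPY 'my_table' is not valid SQL. If your driver really is psycopg2 and it is recent enough to have the sql module, a sketch of identifier-safe composition (assuming a bare, unqualified table name) might look like:

    from psycopg2 import sql

    load = sql.SQL('''
        COPY {table}
        FROM %(manifest_url)s
        CREDENTIALS %(creds)s
        DELIMITER '|'
        GZIP
        MANIFEST
    ''').format(table=sql.Identifier(table_name))
    # 'schema.table' names need their parts composed as separate Identifiers

    cur.execute(load, {
        'manifest_url': manifest_url,
        'creds': 'aws_access_key_id={};aws_secret_access_key={}'.format(
            s3_credentials['aws_access_key_id'],
            s3_credentials['aws_secret_access_key'])})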
Also, I am not a fan of how you have structured these methods so that each one either executes its own carefully constructed query... or whatever SQL string happens to come in (which, if you are literally allowing arbitrary SQL, I guess negates the parameterization worry above...?).
If you must support this, I would create a dedicated "execute arbitrary SQL" method instead of effectively splitting each of your load/unload methods in two. Right now, each method with this argument does totally different things depending on a single optional parameter -- and when it is provided, most of the code in the method becomes irrelevant.
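
Something like this (the name run_sql is my own) would let copy_to_db and unload_to_s3 drop the sql_stmt branch entirely - a minimal sketch in the same style as your code:

    @staticmethod
    def run_sql(database_credentials, sql_stmt):
        """Execute an arbitrary SQL statement and commit."""
        try:
            with ppg2.connect(**database_credentials) as conn:
                cur = conn.cursor()
                cur.execute(sql_stmt)
                conn.commit()
        except ppg2.Error as e:
            logging.critical('Error running custom statement: {}'.format(e))
            sys.exit(1)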
Your doc strings could use some cleanup:
The default UNLOAD statement (i.e. if you don't pass in a SQL statement) defaults to allowing files to be overwritten if the unload_url is not an empty directory.
I don't really understand what this is saying.
Truncates a table given the schema and table names."""
This method seems to truncate a table given just a table name. I think what you meant to say was "Truncates a table given the fully qualified table name." You also have no logging in this method, nor any error handling. I don't know if this matters but given your meticulous logging elsewhere it seems useful.
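
For example (keeping your string formatting for the moment, the injection caveat above notwithstanding):

    @staticmethod
    def truncate_table(table, cursor):
        """Truncates a table given the fully qualified table name."""
        logging.info('Truncating {}'.format(table))
        cursor.execute('truncate table {}'.format(table))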
Relatedly, the docstring for copy_to_db is missing its sql_stmt parameter. This suggests to me that the parameter was hacked in (see the recommendation above for a dedicated "execute arbitrary SQL" method), since it made it into neither the docstring nor the natural flow of the method.
For multiline strings using format, I really like specifying the variable names:

    logging.info('Unloading from {table} (Will be None if custom '
                 'SQL was used'.format(table=table_name))

is much easier to read than

    logging.info('Unloading from {} (Will be None if custom '
                 'SQL was used'.format(table_name))

This is especially true if you have many {} placeholders within one string: reordering parameters, adding another parameter, and so on all stay correctly coupled to the arguments. When your string is 10+ lines long this matters even more.
Last, I'm not really seeing what you gain from subtyping here. There may be more context, but I am not sure why you need three classes - it seems mostly to generate more docstrings and share the __repr__ method. Couldn't an execute_load and an execute_unload method live on the same base class, without any of the other stuff?
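
Roughly - a sketch, with the static helpers moving across unchanged:

    class RedshiftTable(object):
        """One class instead of three; __init__, copy_to_db, truncate_table
        and unload_to_s3 all move in here unchanged."""

        def execute_load(self):
            self.copy_to_db(self.prod_db_credentials, self.table_name,
                            self.manifest_url, self.s3_credentials,
                            self.sql_stmt, self.truncate)

        def execute_unload(self):
            self.unload_to_s3(self.prod_db_credentials, self.table_name,
                              self.unload_url, self.s3_credentials,
                              self.sql_stmt)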
Related to that, the normal convention in Python is to prefix private variables with an underscore (_). It looks like you only set these attributes in your __init__ method, which might make them private - though you may be using or setting them elsewhere.
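
For example:

    class Example(object):
        def __init__(self, table_name):
            self._table_name = table_name  # leading underscore signals "internal detail"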
- Excellent feedback, thank you! I'm only a year into my career, and just got a job at a start-up doing data engineering, so I'm trying to learn good habits and practices. All of the code is run internally, but the SQL injection stuff is very good to know. Looks like I have some tidying up to do! — flybonzai, Mar 12, 2016 at 3:14
- @flybonzai make sure to read through the documentation for that - I'm not 100% sure you are using that exact library to do your connections, it sure looks like it, but make sure to verify that it works the way it seems to work on your system ;) — enderland, Mar 12, 2016 at 3:15
- @flybonzai also, if this is primarily for your own use, SQL injection is probably not as big of a problem. — enderland, Mar 12, 2016 at 3:29
- It's all for internal ETL stuff, no web access :) — flybonzai, Mar 12, 2016 at 3:32
- I'm working through your suggestions now and had a question about private variables. When you are only setting them in the __init__ method, should you consider them private? Python is my first language, so private variables are a bit of a foreign construct to me. When should you use the double-underscore form? — flybonzai, Mar 14, 2016 at 15:56