
We do constant traffic with our Redshift tables, so I created a wrapper class that allows custom SQL to be run (or a default generic statement), and can perform a safe_load, where it first copies the data to a dev database to make sure the load will work before it truncates the prod table. Let me know what I can improve!

import logging
import sys

import psycopg2 as ppg2  # assumed: the ppg2 alias suggests psycopg2


class RedshiftBase(object):
    def __init__(self,
                 s3_credentials,
                 redshift_db_credentials,
                 table_name=None,
                 schema_name=None,
                 manifest_url=None,
                 unload_url=None,
                 dev_db_credentials=None,
                 sql_stmt=None,
                 safe_load=False,
                 truncate=False):
 """
 This class automates the copy of data from an S3 file to a Redshift
 database. Most of the methods are static, and can be accessed outside
 the class. Run the 'execute' method to run the process.
 :param table_name: The Redshift table name. Must include the schema if that
 is required for database access. Ex: 'schema.table'.
 :param s3_credentials: A dictionary containing the access and
 secret access keys. Keys must match the example:
 S3_INFO = {
 'aws_access_key_id': S3_ACCESS,
 'aws_secret_access_key': S3_SECRET,
 'region_name': 'us-west-2'
 }
 :param redshift_db_credentials: A dictionary containing the host, port,
 database name, username, and password. Keys must match example:
 REDSHIFT_POSTGRES_INFO = {
 'host': REDSHIFT_HOST,
 'port': REDSHIFT_PORT,
 'database': REDSHIFT_DATABASE_DEV,
 'user': REDSHIFT_USER,
 'password': REDSHIFT_PASS
 }
 :param schema_name: The schema name associated with the desired table.
 :param unload_url: In the case of an unload operation, this specifies
 the location on S3 where the files will be unloaded to.
 :param manifest_url: The location of the file on S3.
 :param sql_stmt: A SQL statement given as a single string. This is the
 statement that will be used instead of the default.
 Ex:
 '''
 SELECT *
 FROM table
 WHERE conditions < parameters
 ORDER BY field DESC
 '''
 :param safe_load: If True will trigger a test load to a specified
 development database during the 'execute' method. Useful for making
 sure the data will correctly load before truncating the production
 database.
 :param truncate: If 'True', the production table will be truncated
 before the copy step.
 :return: None
 """
        if schema_name:
            self.table_name = schema_name + '.' + table_name
        else:
            self.table_name = table_name
        self.manifest_url = manifest_url
        self.unload_url = unload_url
        self.s3_credentials = s3_credentials
        self.prod_db_credentials = redshift_db_credentials
        self.dev_db_credentials = dev_db_credentials
        self.sql_stmt = sql_stmt
        self.safe_load = safe_load
        self.truncate = truncate

    def __repr__(self):
        return ('Table: {}\nManifest URL: {}\nUnload URL: {}\nS3 Credentials: '
                '{}\nDev DB Credentials: {}\nProd DB Credentials: {}\nSafe '
                'Load: {}\nTruncate: {}'.format(
                    self.table_name,
                    self.manifest_url,
                    self.unload_url,
                    self.s3_credentials,
                    self.dev_db_credentials,
                    self.prod_db_credentials,
                    self.safe_load,
                    self.truncate
                ))

class RedshiftLoad(RedshiftBase):
    @staticmethod
    def copy_to_db(database_credentials,
                   table_name,
                   manifest_url,
                   s3_credentials,
                   sql_stmt,
                   truncate=False):
 """
 Copies data from a file on S3 to a Redshift table. Data must be
 properly formatted and in the right order, etc...
 :param database_credentials: A dictionary containing the host, port,
 database name, username, and password. Keys must match example:
 REDSHIFT_POSTGRES_INFO = {
 'host': REDSHIFT_HOST,
 'port': REDSHIFT_PORT,
 'database': REDSHIFT_DATABASE_DEV,
 'user': REDSHIFT_USER,
 'password': REDSHIFT_PASS
 }
 :param table_name: The Redshift table name. Must include the schema if that
 is required for database access. Ex: 'schema.table'.
 :param manifest_url: The location of the file on the S3 server.
 :param s3_credentials: A dictionary containing the access and
 secret access keys. Keys must match the example:
 S3_INFO = {
 'aws_access_key_id': S3_ACCESS,
 'aws_secret_access_key': S3_SECRET,
 'region_name': 'us-west-2'
 }
 :param truncate: If 'True', will cause the table to be truncated before
 the load.
 :return: None
 """
        s3_access = s3_credentials['aws_access_key_id']
        s3_secret = s3_credentials['aws_secret_access_key']
        logging.info('Accessing {}'.format(table_name))
        try:
            with ppg2.connect(**database_credentials) as conn:
                cur = conn.cursor()
                if truncate:
                    RedshiftLoad.truncate_table(table_name, cur)
                load = '''
                    COPY {}
                    from '{}'
                    credentials 'aws_access_key_id={};aws_secret_access_key={}'
                    delimiter '|'
                    gzip
                    trimblanks
                    truncatecolumns
                    acceptinvchars
                    timeformat 'auto'
                    dateformat 'auto'
                    manifest
                '''.format(
                    table_name,
                    manifest_url,
                    s3_access,
                    s3_secret)
                if sql_stmt:
                    logging.info('Executing custom SQL load statement.')
                    cur.execute(sql_stmt)
                else:
                    logging.info('Executing default SQL load statement.')
                    logging.info('Loading to {}'.format(table_name))
                    cur.execute(load)
                conn.commit()
        except ppg2.Error as e:
            logging.critical('Error occurred during db load: {}'.format(e))
            sys.exit(1)

    @staticmethod
    def truncate_table(table, cursor):
        """
        Truncates a table given the schema and table names."""
        trunc_stmt = '''
            truncate table {}
        '''.format(table)
        cursor.execute(trunc_stmt)

    def execute(self):
        if self.safe_load:
            logging.info('Test load triggered, connecting now to {}.'.format(
                self.table_name
            ))
            self.copy_to_db(self.dev_db_credentials,
                            self.table_name,
                            self.manifest_url,
                            self.s3_credentials,
                            self.sql_stmt,
                            self.truncate)
            logging.info('Load to the development database was a success.')
        logging.info('Commencing load operation to {}.'.format(
            self.table_name))
        self.copy_to_db(self.prod_db_credentials,
                        self.table_name,
                        self.manifest_url,
                        self.s3_credentials,
                        self.sql_stmt,
                        self.truncate)
        logging.info('Load to the production database was a success.')

class RedshiftUnload(RedshiftBase):
    @staticmethod
    def unload_to_s3(database_credentials,
                     table_name,
                     unload_url,
                     s3_credentials,
                     sql_stmt):
 """
 Unloads the data from a Redshift table into a specified location in S3.
 The default UNLOAD statement (i.e. if you don't pass in a sql statement
 ) defaults to allowing files to be overwritten if the unload_url is not
 and empty directory.
 :param database_credentials: A dictionary containing the host, port,
 database name, username, and password. Keys must match example:
 REDSHIFT_POSTGRES_INFO = {
 'host': REDSHIFT_HOST,
 'port': REDSHIFT_PORT,
 'database': REDSHIFT_DATABASE_DEV,
 'user': REDSHIFT_USER,
 'password': REDSHIFT_PASS
 }
 :param table_name: The Redshift table name. Must include the schema if that
 is required for database access. Ex: 'schema.table'
 :param s3_credentials: A dictionary containing the access and
 secret access keys. Keys must match the example:
 S3_INFO = {
 'aws_access_key_id': S3_ACCESS,
 'aws_secret_access_key': S3_SECRET,
 'region_name': 'us-west-2'
 }
 :return: None
 """
        s3_access = s3_credentials['aws_access_key_id']
        s3_secret = s3_credentials['aws_secret_access_key']
        try:
            with ppg2.connect(**database_credentials) as conn:
                cur = conn.cursor()
                unload = '''
                    UNLOAD ('SELECT * FROM {}')
                    TO '{}'
                    CREDENTIALS 'aws_access_key_id={};aws_secret_access_key={}'
                    MANIFEST
                    DELIMITER '|'
                    GZIP
                    ALLOWOVERWRITE
                '''.format(
                    table_name,
                    unload_url,
                    s3_access,
                    s3_secret)
                if sql_stmt:
                    logging.info('Executing custom SQL unload statement.')
                    cur.execute(sql_stmt)
                else:
                    logging.info('Executing default SQL unload statement.')
                    logging.info('Unloading from {} (Will be None if custom '
                                 'SQL was used'.format(table_name))
                    cur.execute(unload)
                conn.commit()
        except ppg2.Error as e:
            logging.critical('Error occurred during unload: {}'.format(e))
            sys.exit(1)

    def execute(self):
        self.unload_to_s3(self.prod_db_credentials,
                          self.table_name,
                          self.unload_url,
                          self.s3_credentials,
                          self.sql_stmt)
        logging.info('Unload was a success.')
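
For reference, a minimal usage sketch (the credential dictionaries, bucket, and table names here are placeholders, following the shapes described in the docstrings):

# Hypothetical usage; S3_INFO / REDSHIFT_POSTGRES_INFO / DEV_POSTGRES_INFO
# are assumed to exist with the key layouts documented above.
load = RedshiftLoad(s3_credentials=S3_INFO,
                    redshift_db_credentials=REDSHIFT_POSTGRES_INFO,
                    dev_db_credentials=DEV_POSTGRES_INFO,
                    table_name='events',
                    schema_name='analytics',
                    manifest_url='s3://my-bucket/events.manifest',
                    safe_load=True,
                    truncate=True)
load.execute()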

1 Answer

so I created a wrapper class that allows custom SQL to be run

So... I don't know if there is a way to parameterize your queries to Redshift, or how this class is called, but hopefully not from user input - right now it is completely vulnerable to SQL injection, since you interpolate your table_name variable (and others) directly into the query string.

I should note the documentation for executing queries says:

Warning Never, never, NEVER use Python string concatenation (+) or string parameters interpolation (%) to pass variables to a SQL query string. Not even at gunpoint.

I would recommend reading their documentation on parameterized queries:

load = '''
    COPY %(table_name)s
    from '%(manifest_url)s'
    credentials 'aws_access_key_id=%(s3_access)s;aws_secret_access_key=%(s3_secret)s'
    delimiter '|'
    gzip
    trimblanks
    truncatecolumns
    acceptinvchars
    timeformat 'auto'
    dateformat 'auto'
    manifest
'''
params = {'table_name': table_name,
          'manifest_url': manifest_url,
          's3_access': s3_credentials['aws_access_key_id'],
          's3_secret': s3_credentials['aws_secret_access_key']}
cur.execute(load, params)

You can remove the temporary variable assignment for s3 in this case too.
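
One caveat worth knowing: psycopg2 cannot bind an identifier such as a table name as an ordinary query parameter (it would quote the value as a string literal), and a placeholder nested inside an already-quoted literal, like the credentials string above, will not expand cleanly. Newer psycopg2 releases ship a sql module for composing identifiers safely; a rough sketch under that assumption (schema and table names here are hypothetical):

# Sketch only: psycopg2 >= 2.7 provides psycopg2.sql for safe SQL
# composition; passing two parts to Identifier for a dotted name
# needs psycopg2 >= 2.8.
from psycopg2 import sql

copy_stmt = sql.SQL('''
    COPY {table}
    FROM %(manifest_url)s
    CREDENTIALS %(credentials)s
    DELIMITER '|'
    GZIP
    MANIFEST
''').format(table=sql.Identifier('my_schema', 'my_table'))

cur.execute(copy_stmt, {
    'manifest_url': manifest_url,
    'credentials': 'aws_access_key_id={};aws_secret_access_key={}'.format(
        s3_credentials['aws_access_key_id'],
        s3_credentials['aws_secret_access_key']),
})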

Also, I am not a fan of how you have structured these methods so that each one executes either its precisely constructed default query or whatever arbitrary SQL string comes in (which, I guess, would negate the SQL injection parameterization worry if you are literally allowing that...?).

If you must support this, I would create a dedicated "execute arbitrary SQL" method instead of effectively splitting every load/unload method into two separate methods. Right now, each method with this argument does totally different things depending on a single optional parameter - when it is provided, most of the code in the method is suddenly irrelevant.

Your doc strings could use some cleanup:

The default UNLOAD statement (i.e. if you don't pass in a sql statement ) defaults to allowing files to be overwritten if the unload_url is not and empty directory.

I don't really understand what this is saying.

Truncates a table given the schema and table names."""

This method seems to truncate a table given just a table name. I think what you meant to say was "Truncates a table given the fully qualified table name." You also have no logging in this method, nor any error handling. I don't know if this matters, but given your meticulous logging elsewhere it seems worth adding.
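
Something along these lines would bring it in line with the rest of the class (a sketch only; the table name is still interpolated directly, per the injection discussion above):

@staticmethod
def truncate_table(table, cursor):
    """Truncates a table given the fully qualified table name."""
    # Note: table is still formatted into the statement directly;
    # see the parameterization discussion above.
    logging.info('Truncating {table}'.format(table=table))
    cursor.execute('truncate table {}'.format(table))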

Your docstring for sql_stmt is missing in the copy_to_db method. This suggests to me that the parameter was hacked in (see the recommendation above for a dedicated "execute arbitrary SQL" method), since it did not make it into the docstring, nor does it fit into the program flow naturally.

For multiline strings using format I really like naming the placeholders:

logging.info('Unloading from {table} (Will be None if custom '
             'SQL was used'.format(table=table_name))

is much easier to read than

logging.info('Unloading from {} (Will be None if custom '
             'SQL was used'.format(table_name))

This is especially true if you have many {} within the string. With named placeholders, changing the order of your parameters, adding another parameter, and so on no longer silently couples the string to the argument order. When your string is 10+ lines long this matters even more.

Last, I'm not really seeing what you are gaining from subtyping here. There may be more context, but I am not really sure why you need to do this - it seems to just generate a lot more docstrings and let you share the __repr__ method. Do you need three classes? Couldn't an execute_load and execute_unload be used on the same base without any of the other stuff? Roughly, something like the sketch below.
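
A sketch of that single-class shape (the names are only suggestions, and the bodies are stubbed out):

class RedshiftTable(object):
    def __init__(self, s3_credentials, db_credentials, table_name):
        self.s3_credentials = s3_credentials
        self.db_credentials = db_credentials
        self.table_name = table_name

    def execute_load(self, manifest_url, truncate=False):
        """Build and run the COPY statement shown above."""
        pass

    def execute_unload(self, unload_url):
        """Build and run the UNLOAD statement shown above."""
        pass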

Related to that, the normal convention in Python is to prefix private variables with _. It looks like you only set these attributes through your __init__ method, which might make them private - though you may be using or setting them elsewhere.
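
For illustration, a toy example of the underscore convention (a hypothetical class, not from the code above):

class Example(object):
    def __init__(self, table_name):
        self._table_name = table_name  # internal by convention; not enforced

    @property
    def table_name(self):
        # read-only public view of the "private" attribute
        return self._table_name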

  • Excellent feedback, thank you! I'm only a year into my career, and just got a job at a start-up doing data engineering, so I'm trying to learn good habits and practices. All of the code is run internally, but the SQL injection stuff is very good to know. Looks like I have some tidying up to do!
  • @flybonzai make sure to read through the documentation for that - I'm not 100% sure you are using that exact library to do your connections; it sure looks like it, but verify that it works the way it seems to work on your system ;)
  • @flybonzai also, if this is primarily for your own use, SQL injection is not as big of a problem.
  • It's all for internal ETL stuff, no web access :)
  • I'm working through your suggestions now, and had a question about private variables. When you only set them in the __init__ method, should you consider them private? Python is my first language, so private variables are a bit of a foreign construct to me. When should you use the double-underscore form?
