Thanks in advance if you are reading this code.
I recently submitted this code as part of an interview (took about 4-5 hours). Unfortunately, they did not like the code and I received a form rejection email without any sort of feedback. However, I am committed to improving my code and I would like to learn from my mistakes. The code below works. You should be able to run it yourself. It takes about 2 minutes to run. Access to the database is there. It is a test database, but I do not maintain it. It is perfectly fine to have the username and password there.
What the code does: The code accesses an API and a database, each of which represents a different user platform. It then matches up people with the same first and last name across the two sources and extracts whether each person was active within the last 30 days on the database side and on the API side. There was a space constraint for this assignment, which is why I used generators. There is some stuff I didn't mention, but this is the meat of the assignment. Please let me know if any additional clarification is required.
I thought that I had done a pretty good job, but apparently not. Please let me know if you have any feedback (positive and critical) on this code and how it could be improved (assuming it does what it's supposed to do). I would really like to be able to take my rejection and turn it into a learning opportunity. Thanks again.
If you feel that you need to contact me, let me know and we can work it out.
import time
import requests
import pymysql
from datetime import datetime, date
import json
#
# HELPER FUNCTIONS
#
def database_endpoint_iterator(database_config, database_query, size):
    """Generator function that connects to a database and iterates over the data.

    Parameters:
        database_config (dict): Configuration details for database.
        database_query (str): Query specifying what information to extract from the database.
        size (int): Number of rows to fetch each time. Controls how much data is loaded at one time into memory.
    """
    connection = pymysql.connect(**database_config)
    cursor = connection.cursor(pymysql.cursors.DictCursor)
    cursor.execute(database_query)

    while True:
        rows = cursor.fetchmany(size)
        if not rows:
            break
        for row in rows:
            yield row

    connection.close()

def api_endpoint_iterator(endpoint_url, page_size):
    """Generator function that queries a REST API and iterates over paginated data.

    Parameters:
        endpoint_url (str): REST API url.
        page_size (int): Number of pages to fetch each time. Controls how much data is loaded at one time into memory.
    """
    page = 1
    total_pages = 1
    users_left_over = []

    while True:
        users = users_left_over

        # fetches correct amount of pages at one time
        for _ in range(page_size):
            payload = {
                'page': page
            }
            r = requests.get(endpoint_url, params=payload)
            r_json = r.json()
            total_pages = r_json['total_pages']
            users += r_json['users']
            if page > total_pages:
                break
            page += 1

        # users are only sorted by last name, this ensures that users are sorted by last name and first name
        users.sort(key=lambda user: (user['lastname'], user['firstname']))

        # handles situations where users with the same last name span multiple pages
        for index, user in enumerate(users):
            if user['lastname'] == users[-1]['lastname']:
                users_left_over = users[index:]
                break
            yield user

        if page > total_pages:
            break

    # gets any users that were left over due to same last names spanning multiple pages
    for user in users_left_over:
        yield user

def compare(user1, user2):
    """Compares two users using their first name and last name.

    Returns:
        0 if users have the same first name and last name
        1 if user1 comes alphabetically after user2
        -1 if user1 comes alphabetically before user2
    """
    user1_str = user1['lastname'] + ' ' + user1['firstname']
    user2_str = user2['lastname'] + ' ' + user2['firstname']

    if user1_str < user2_str:
        return -1
    elif user1_str > user2_str:
        return 1
    else:
        return 0

def is_active(user):
    """Determines if a user is active.

    Returns:
        True if the user was active within the last 30 days, otherwise False.
    """
    today = "2017-02-02"
    today = datetime.strptime(today, "%Y-%m-%d")
    last_active = datetime.strptime(str(user['last_active_date']), "%Y-%m-%d")
    return (today - last_active).days <= 30

def create_user_dict(user_internal, user_external):
    """Creates a combined data set from an internal user and external user.

    Returns:
        A dictionary of relevant data for the users.
    """
    user = {'firstname': user_internal['firstname'],
            'lastname': user_internal['lastname'],
            'specialty': user_internal['specialty'].lower(),
            'practice_location': user_external['practice_location'],
            'platform_registered_on': user_internal['platform_registered_on'].lower(),
            'internal_classification': user_internal['classification'].lower(),
            'external_classification': user_external['user_type_classification'].lower(),
            'is_active_internal_platform': is_active(user_internal),
            'is_active_external_platform': is_active(user_external)}
    return user

#
# CONFIGURATION
#
start_time = time.time()
row_size = 5000 # configuration variable for how many rows from the database are loaded into memory
page_size = 1 # configuration variable for how many pages from the api are loaded into memory
warehouse_sample_user_count = 10
warehouse_sample = {'users': []}
total_matches = 0
# rest api url
endpoint_url = 'http://de-tech-challenge-api.herokuapp.com/api/v1/users'
# database configuration
database_config = {'host': 'candidate-coding-challenge.dox.pub',
                   'user': 'de_candidate',
                   'password': 'P8MWmPPBLhhLX79n',
                   'port': 3316,
                   'database': 'data_engineer'}
database_query = "SELECT * FROM user ORDER BY lastname, firstname;"
#
# MAIN PROGRAM
#
# set up the data iterators using the function generators
users_internal_source = database_endpoint_iterator(database_config, database_query, row_size)
users_external_source = api_endpoint_iterator(endpoint_url, page_size)
# get a user from each data source
user_internal = next(users_internal_source)
user_external = next(users_external_source)
# compare each user in one data source to the other, stop when there is no more data
while True:
    try:
        if compare(user_internal, user_external) == 0:
            total_matches += 1
            if warehouse_sample_user_count > 0:
                warehouse_sample['users'].append(create_user_dict(user_internal, user_external))
                warehouse_sample_user_count -= 1
            user_internal = next(users_internal_source)
            user_external = next(users_external_source)
        elif compare(user_internal, user_external) < 0:
            user_internal = next(users_internal_source)
        else:
            user_external = next(users_external_source)
    except StopIteration:
        break

# sample user data in json for the warehouse
warehouse_sample = json.dumps(warehouse_sample, indent = 4)
# sql for the design of a table that would house the results, this is just for printing to the output.txt file
sql_ddl = '''CREATE TABLE user_active_status (
    id INT NOT NULL AUTO_INCREMENT,
    first_name VARCHAR(50),
    last_name VARCHAR(50),
    specialty VARCHAR(50),
    practice_location VARCHAR(50),
    platform_registered_on VARCHAR(25),
    internal_classification VARCHAR(50),
    external_classification VARCHAR(50),
    is_active_internal_platform TINYINT(1),
    is_active_external_platform TINYINT(1),
    PRIMARY KEY (id)
);'''
end_time = time.time()
elapsed_time = round(end_time - start_time)
#
# OUTPUT
#
# generate the output.txt file
with open("output.txt", "w") as f:
f.write("Elapsed Time: " + str(int(elapsed_time / 60)) + ' minutes, ' + str(elapsed_time % 60) + ' seconds\n\n')
f.write("Total Matches: " + str(total_matches) + "\n\n")
f.write("Sample Output:\n" + warehouse_sample + "\n\n")
f.write("SQL DDL:\n")
f.write(sql_ddl)
2 Answers
First impression is the code is well-documented and is easy to read, especially given the context of it being an interview assignment. But there are definitely places where it can be improved, so let's start with the low-hanging fruit: execution time performance and memory consumption.
requests.Session

All API calls are to the same host, so we can take advantage of this and make all calls via the same requests.Session object for better performance. From the requests documentation on Session Objects:

The Session object allows you to persist certain parameters across requests. It also persists cookies across all requests made from the Session instance, and will use urllib3's connection pooling. So if you're making several requests to the same host, the underlying TCP connection will be reused, which can result in a significant performance increase (see HTTP persistent connection).
Example:

with requests.Session() as session:
    for page_number in range(1, num_pages + 1):
        # ...
        json_response = session.get(url, params=params).json()
I tested this on a refactored version of your code, and this change alone almost halved the total execution time.
Memory footprint

Your code uses generators which is great for memory efficiency, but can we do better? Let's look at a memory trace of your code using the "Pretty top" recipe from tracemalloc:
Top 10 lines
#1: json/decoder.py:353: 494.7 KiB
obj, end = self.scan_once(s, idx)
#2: pymysql/connections.py:1211: 202.8 KiB
return tuple(row)
#3: requests/models.py:828: 168.7 KiB
self._content = b''.join(self.iter_content(CONTENT_CHUNK_SIZE)) or b''
#4: ./old_db.py:100: 67.5 KiB
users.sort(key=lambda user: (user['lastname'], user['firstname']))
#5: <frozen importlib._bootstrap_external>:580: 57.7 KiB
#6: python3.8/abc.py:102: 13.5 KiB
return _abc_subclasscheck(cls, subclass)
#7: urllib3/poolmanager.py:297: 6.4 KiB
base_pool_kwargs = self.connection_pool_kw.copy()
#8: ./old_db.py:92: 6.0 KiB
users += r_json['users']
#9: urllib3/poolmanager.py:167: 5.1 KiB
self.key_fn_by_scheme = key_fn_by_scheme.copy()
#10: python3.8/re.py:310: 5.0 KiB
_cache[type(pattern), pattern, flags] = p
686 other: 290.4 KiB
Total allocated size: 1317.8 KiB
Shown above are the 10 lines allocating the most memory. It might not be immediately obvious, but the fairly high memory usages in #1, #2, and #4 can all be attributed to using a Python dictionary as a storage container for each database/API record. Basically, using a dictionary in this way is expensive and unnecessary since we're never really adding/removing/changing fields in one of these dictionaries once we've read it into memory.
The memory hotspots:

- Using pymysql.cursors.DictCursor to return each row in the query results as a dictionary, combined with the fact that we're making batched fetches of size=5000 rows at a time -- that's not a small number of dictionaries to hold in memory at one time. Plus, through testing I determined that there is virtually no difference in speed (execution time) between fetching in batches from the database versus retrieving rows one at a time using the unbuffered pymysql.cursors.SSCursor, so SSCursor is probably the better choice here
- Reading in, accumulating, and sorting dictionaries in api_endpoint_iterator
Side note: #3 above can actually be eliminated by merging the following two lines into one, since we never use r again after calling json() on it:

# Before
r = requests.get(endpoint_url, params=payload)
r_json = r.json()

# After
r_json = requests.get(endpoint_url, params=payload).json()
A better alternative in this case is to use a NamedTuple to represent each record. NamedTuples are immutable, have a smaller memory footprint than dictionaries, are sortable like regular tuples, and are the preferred option when you know all of your fields and their types in advance.
Having something like the following gives us a nice, expressive, compact type that also makes the code easier to read:
from typing import NamedTuple

class ExternalUser(NamedTuple):
    last_name: str
    first_name: str
    user_id: int
    last_active_date: str
    practice_location: str
    specialty: str
    user_type_classification: str
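
Because the fields are declared in (last_name, first_name, ...) order, a plain list.sort() already gives the ordering the matching logic needs; a quick illustration with made-up values:

users = [
    ExternalUser("Smith", "Jane", 2, "2017-01-15", "Boston", "cardiology", "contributor"),
    ExternalUser("Smith", "Alan", 1, "2017-01-20", "Chicago", "oncology", "member"),
]
users.sort()  # sorts field by field, i.e. by last_name, then first_name, ...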
At the end of this review is a refactored version of the code which uses NamedTuples. Here's a preview of what its memory trace looks like:
Top 10 lines
#1: <frozen importlib._bootstrap_external>:580: 54.0 KiB
#2: python3.8/abc.py:102: 12.8 KiB
return _abc_subclasscheck(cls, subclass)
#3: urllib3/poolmanager.py:297: 12.5 KiB
base_pool_kwargs = self.connection_pool_kw.copy()
#4: json/decoder.py:353: 5.0 KiB
obj, end = self.scan_once(s, idx)
#5: pymysql/converters.py:299: 4.5 KiB
return datetime.date(*[ int(x) for x in obj.split('-', 2) ])
#6: json/encoder.py:202: 4.2 KiB
return ''.join(chunks)
#7: ./new_db.py:201: 3.5 KiB
return {
#8: pymysql/connections.py:1206: 3.1 KiB
data = data.decode(encoding)
#9: python3.8/_strptime.py:170: 2.8 KiB
class TimeRE(dict):
#10: python3.8/_strptime.py:30: 2.7 KiB
class LocaleTime(object):
641 other: 276.6 KiB
Total allocated size: 381.5 KiB
Context managers

It's not provided out of the box by the pymysql module, but you should use a context manager for the database connection to ensure that the connection is always closed, even after an unexpected program halt due to an exception.

Right now if your program were to encounter an exception anywhere in between connection = pymysql.connect(...) and connection.close(), the connection might not be closed safely.
Here's an example of how you could make your own context manager for the connection:
import pymysql
from typing import Dict, Any, Iterator
from contextlib import contextmanager

@contextmanager
def database_connection(
    config: Dict[str, Any]
) -> Iterator[pymysql.connections.Connection]:
    connection = pymysql.connect(**config)
    try:
        yield connection
    finally:
        connection.close()

# Example usage
with database_connection(config) as connection:
    # Note: context managers for cursors __are__ provided by pymysql
    with connection.cursor(pymysql.cursors.SSCursor) as cursor:
        cursor.execute(query)
        # ...
Type hints
Consider using type hints to:
- improve code readability
- increase confidence in code correctness with the help of a static type checker like mypy
For example, the method that provides a stream of external users from the API has some fairly dense logic in it, but with type hints we can just look at the method signature to guess what it's doing or what to expect from it:
def api_records(api_url: str) -> Iterator[ExternalUser]:
    # ...
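
And as a small, contrived illustration of the second point (count_matches is invented just for this example, and it reuses the NamedTuples defined in the refactored version below):

from typing import Iterator, Tuple

def count_matches(pairs: Iterator[Tuple[InternalUser, ExternalUser]]) -> int:
    return sum(1 for _ in pairs)

# Passing a stream of ExternalUser records where pairs were declared:
# a checker like mypy flags the incompatible argument type without running anything.
count_matches(api_records(API_URL))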
Generator of matching pairs
At the top level of code execution, there's some logic where we iterate over both internal and external users in order to find all matching pairs, where a matching pair is an internal user record and an external user record with the same first and last name.
It would be cleaner to go one step further with generators and extract this logic into its own method that returns a generator. In other words, we could have two input streams (internal and external user records) and our output would then be a stream of matching pairs of internal and external user records:
def matching_users(
    internal_users: Iterator[InternalUser],
    external_users: Iterator[ExternalUser],
) -> Iterator[Tuple[InternalUser, ExternalUser]]:
    # ...
This is a nicer abstraction to work with; the client gets direct access to all the matching pairs, and can iterate over them to get the total number of matches and/or save a subset of the matches to a report.
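For example, a caller could do something along these lines (this mirrors what the refactored version at the end of the review does):

total_matches = 0
sample = []

for internal_user, external_user in matching_users(internal_users, external_users):
    total_matches += 1
    if len(sample) < WAREHOUSE_SAMPLE_USER_COUNT:
        sample.append(create_user_dict(internal_user, external_user, is_active))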
Refactored version
Below is the refactored version with the above suggestions incorporated:
#!/usr/bin/env python3
from __future__ import annotations
import time
import requests
import datetime
import json
import pymysql
from typing import (
    NamedTuple,
    TypeVar,
    Dict,
    List,
    Iterator,
    Callable,
    Any,
    Tuple,
)
from collections import OrderedDict
from functools import partial
from contextlib import contextmanager
from textwrap import dedent
T = TypeVar("T")
class Config(NamedTuple):
    host: str
    user: str
    password: str
    port: int
    database: str

class InternalUser(NamedTuple):
    last_name: str
    first_name: str
    user_id: int
    last_active_date: datetime.date
    platform_registered_on: str
    practice_id: int
    specialty: str
    classification: str

class ExternalUser(NamedTuple):
    last_name: str
    first_name: str
    user_id: int
    last_active_date: str
    practice_location: str
    specialty: str
    user_type_classification: str

@contextmanager
def database_connection(
    config: Config,
) -> Iterator[pymysql.connections.Connection]:
    connection = pymysql.connect(
        host=config.host,
        user=config.user,
        password=config.password,
        port=config.port,
        database=config.database,
    )
    try:
        yield connection
    finally:
        connection.close()

def database_records(
    config: Config, query: str, record_type: Callable[..., T]
) -> Iterator[T]:
    with database_connection(config) as connection:
        with connection.cursor(pymysql.cursors.SSCursor) as cursor:
            cursor.execute(query)
            for row in cursor:
                yield record_type(*row)

def api_records(api_url: str) -> Iterator[ExternalUser]:
    def load_users(
        storage: OrderedDict[str, List[ExternalUser]],
        users: List[Dict[str, Any]],
    ) -> None:
        for user in users:
            ext_user = ExternalUser(
                last_name=user["lastname"],
                first_name=user["firstname"],
                user_id=user["id"],
                last_active_date=user["last_active_date"],
                practice_location=user["practice_location"],
                specialty=user["specialty"],
                user_type_classification=user["user_type_classification"],
            )
            storage.setdefault(ext_user.last_name, []).append(ext_user)

    def available_sorted_users(
        storage: OrderedDict[str, List[ExternalUser]], remaining: bool = False
    ) -> Iterator[ExternalUser]:
        threshold = 0 if remaining else 1
        while len(storage) > threshold:
            _, user_list = storage.popitem(last=False)
            user_list.sort()
            yield from user_list

    user_dict: OrderedDict[str, List[ExternalUser]] = OrderedDict()

    with requests.Session() as session:
        params = {"page": 1}
        json_response = session.get(api_url, params=params).json()
        total_pages = json_response["total_pages"]
        load_users(user_dict, json_response["users"])
        yield from available_sorted_users(user_dict)

        for current_page in range(2, total_pages + 1):
            params = {"page": current_page}
            json_response = session.get(api_url, params=params).json()
            load_users(user_dict, json_response["users"])
            yield from available_sorted_users(user_dict)

        yield from available_sorted_users(user_dict, remaining=True)

def matching_users(
    internal_users: Iterator[InternalUser],
    external_users: Iterator[ExternalUser],
) -> Iterator[Tuple[InternalUser, ExternalUser]]:
    internal_user = next(internal_users, None)
    external_user = next(external_users, None)

    while internal_user and external_user:
        internal_name = (internal_user.last_name, internal_user.first_name)
        external_name = (external_user.last_name, external_user.first_name)

        if internal_name == external_name:
            yield (internal_user, external_user)
            internal_user = next(internal_users, None)
            external_user = next(external_users, None)
        elif internal_name < external_name:
            internal_user = next(internal_users, None)
        else:
            external_user = next(external_users, None)

def active_recently(
    current_date: datetime.date, num_days: int, last_active_date: datetime.date
) -> bool:
    return (current_date - last_active_date).days <= num_days

def create_user_dict(
    internal_user: InternalUser,
    external_user: ExternalUser,
    is_active: Callable[[datetime.date], bool],
) -> Dict[str, Any]:
    internal_user_is_active = is_active(internal_user.last_active_date)

    external_user_last_active_date = datetime.datetime.strptime(
        external_user.last_active_date, "%Y-%m-%d"
    ).date()
    external_user_is_active = is_active(external_user_last_active_date)

    return {
        "firstname": internal_user.first_name,
        "lastname": internal_user.last_name,
        "specialty": internal_user.specialty,
        "practice_location": external_user.practice_location,
        "platform_registered_on": internal_user.platform_registered_on,
        "internal_classification": internal_user.classification,
        "external_classification": external_user.user_type_classification,
        "is_active_internal_platform": internal_user_is_active,
        "is_active_external_platform": external_user_is_active,
    }

if __name__ == "__main__":
    start_time = time.time()

    CURRENT_DATE = datetime.date(2017, 2, 2)
    is_active = partial(active_recently, CURRENT_DATE, 30)
    WAREHOUSE_SAMPLE_USER_COUNT = 10
    warehouse_samples = []

    API_URL = "http://de-tech-challenge-api.herokuapp.com/api/v1/users"
    DB_CONFIG = Config(
        host="candidate-coding-challenge.dox.pub",
        user="de_candidate",
        password="P8MWmPPBLhhLX79n",
        port=3316,
        database="data_engineer",
    )
    DB_QUERY = """
        SELECT lastname
            ,firstname
            ,id
            ,last_active_date
            ,platform_registered_on
            ,practice_id
            ,specialty
            ,classification
        FROM user
        ORDER BY lastname, firstname
    """

    internal_users = database_records(DB_CONFIG, DB_QUERY, InternalUser)
    external_users = api_records(API_URL)

    users_in_both_systems = matching_users(internal_users, external_users)

    for i, (internal_user, external_user) in enumerate(users_in_both_systems):
        if i < WAREHOUSE_SAMPLE_USER_COUNT:
            warehouse_samples.append(
                create_user_dict(internal_user, external_user, is_active)
            )

    # At the end of the for loop, `i` is the "index number"
    # of the last match => `i + 1` is the total number of matches
    total_matches = i + 1

    warehouse_sample = json.dumps({"users": warehouse_samples}, indent=4)

    SQL_DDL = dedent(
        """
        CREATE TABLE user_active_status (
            id INT NOT NULL AUTO_INCREMENT,
            first_name VARCHAR(50),
            last_name VARCHAR(50),
            specialty VARCHAR(50),
            practice_location VARCHAR(50),
            platform_registered_on VARCHAR(25),
            internal_classification VARCHAR(50),
            external_classification VARCHAR(50),
            is_active_internal_platform TINYINT(1),
            is_active_external_platform TINYINT(1),
            PRIMARY KEY (id)
        );
        """
    ).strip()

    end_time = time.time()
    elapsed_time = round(end_time - start_time)
    minutes = int(elapsed_time / 60)
    seconds = elapsed_time % 60

    with open("output.txt", "w") as f:
        f.write(f"Elapsed Time: {minutes} minutes, {seconds} seconds\n\n")
        f.write(f"Total Matches: {total_matches}\n\n")
        f.write(f"Sample Matches:\n{warehouse_sample}\n\n")
        f.write(f"SQL DDL:\n{SQL_DDL}\n")

I would keep the configuration in a config file. This also prevents stuff like:
# database configuration
database_config = {'host': 'candidate-coding-challenge.dox.pub',
                   'user': 'de_candidate',
                   'password': 'P8MWmPPBLhhLX79n',
                   'port': 3316,
                   'database': 'data_engineer'}
Where you could accidentally upload your password. The way I do this is by adding:
folder/
    .gitignore
    main.py
    config/
        config.yaml
        config.yaml-template
Here the config.yaml would be added to the .gitignore, and all non-sensitive info could already be filled out in the config.yaml-template.
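
For example, a minimal loader might look like this (the key names and the use of PyYAML are just one possible choice, not something from the original assignment):

# config/config.yaml (real values git-ignored, template committed as config.yaml-template):
#
# database:
#   host: candidate-coding-challenge.dox.pub
#   user: de_candidate
#   password: <fill in locally>
#   port: 3316
#   database: data_engineer
# api:
#   endpoint_url: http://de-tech-challenge-api.herokuapp.com/api/v1/users

import yaml  # PyYAML

with open("config/config.yaml") as f:
    config = yaml.safe_load(f)

database_config = config["database"]
endpoint_url = config["api"]["endpoint_url"]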
I would also not have your file run on import. You can do this with a simple structure like:
def main():
    # do stuff

if __name__ == '__main__':
    main()
Furthermore, api_endpoint_iterator is a very long function; I would try to split it into smaller functions which are easier to test.
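
For instance, the page-fetching part could be pulled out into its own tiny function (the name here is made up), which is then trivial to test in isolation from the sorting/buffering logic:

def fetch_page(session, endpoint_url, page):
    """Return the parsed JSON payload for a single page of the users API."""
    return session.get(endpoint_url, params={'page': page}).json()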
Lastly, you explain what each part of the code does using:
#
# Section description
#
This might work for shorter assignments, but I prefer to split the code into files so you can find everything more easily:
folder/
    .gitignore
    main.py
    config/
        config.yaml
        config.yaml-template
    utils/
        helper_functions.py
    core/
        main_functions.py
Comments:

- Daniel R: "I up-voted, but it won't display since my reputation isn't high enough. Thank you very much for the feedback, it's incredibly helpful. If there are any other ways to improve the code I would be very appreciative. Do you have any good examples of how to organize code... like using config, utils, and core directories? Thanks again." (Feb 6, 2020)
- Nathan: "@DanielR I made this as an example for my friends: gitlab.com/nathan_vanthof/space_invaders" (Feb 7, 2020)