I would like a code review of my recursive implementation of a Python function that flattens an array.
The task: write a piece of functioning code that flattens an array of arbitrarily nested arrays of integers into a flat array of integers, e.g. [[1,2,[3]],4] -> [1,2,3,4].
I'm particularly looking for feedback on the following:
- Is the use of the TypeError exception justified?
- Did I miss any valid edge cases in my tests?
Here is my solution, including unit tests:
def flatten(input_array):
    result_array = []
    for element in input_array:
        if isinstance(element, int):
            result_array.append(element)
        elif isinstance(element, list):
            result_array += flatten(element)
    return result_array
It passes all of the following tests:
from io import StringIO
import sys

# custom assert function to handle tests
# input: count {List} - keeps track of how many tests pass and how many in total,
#        in the form of a two-item list, i.e., [0, 0]
# input: name {String} - describes the test
# input: test {Function} - performs a set of operations and returns a boolean
#        indicating whether the test passed
# output: {None}
def expect(count, name, test):
    if (count is None or not isinstance(count, list) or len(count) != 2):
        count = [0, 0]
    else:
        count[1] += 1

    result = 'false'
    error_msg = None
    try:
        if test():
            result = ' true'
            count[0] += 1
    except Exception as err:
        error_msg = str(err)

    print(' ' + (str(count[1]) + ') ') + result + ' : ' + name)
    if error_msg is not None:
        print(' ' + error_msg + '\n')
# code for capturing print output
#
# directions: capture_print function returns a list of all elements that were
#             printed using print with the function that it is given. Note that
#             the function given to capture_print must be fed using lambda.
class Capturing(list):
    def __enter__(self):
        self._stdout = sys.stdout
        sys.stdout = self._stringio = StringIO()
        return self

    def __exit__(self, *args):
        self.extend(self._stringio.getvalue().splitlines())
        sys.stdout = self._stdout


def capture_print(to_run):
    with Capturing() as output:
        pass
    with Capturing(output) as output:  # note the constructor argument
        to_run()
    return output
# shared counter for all tests: [passed, total]
test_count = [0, 0]

def test():
    results = flatten([1, [2, 3, [4]], 5, [[6]]])
    return (len(results) == 6 and
            results[0] == 1 and
            results[1] == 2 and
            results[2] == 3 and
            results[3] == 4 and
            results[4] == 5 and
            results[5] == 6)

expect(test_count, 'should return [1,2,3,4,5,6] output for [1, [2, 3, [4]], 5, [[6]]] input', test)

def test():
    results = flatten([])
    return len(results) == 0

expect(test_count, 'should return [] output for [] input', test)

def test():
    results = flatten([1, [2, 3, [4], []], [], 5, [[], [6]]])
    return (len(results) == 6 and
            results[0] == 1 and
            results[1] == 2 and
            results[2] == 3 and
            results[3] == 4 and
            results[4] == 5 and
            results[5] == 6)

expect(test_count, 'should return [1,2,3,4,5,6] output for [1, [2, 3, [4], []], [], 5, [[], [6]]] input (note the empty arrays)', test)

print('PASSED: ' + str(test_count[0]) + ' / ' + str(test_count[1]) + '\n\n')
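On the question of missed edge cases: the tests above only feed in ints and lists, so they never exercise the path where an element is silently skipped. The sketch below is a minimal set of plain assert checks against the `flatten` function above (copied in so the snippet runs on its own) that pins down a few of those behaviors. Whether each behavior is actually desired is an assumption left to you.

```python
import sys


def flatten(input_array):
    # Recursive implementation from the question, unchanged
    result_array = []
    for element in input_array:
        if isinstance(element, int):
            result_array.append(element)
        elif isinstance(element, list):
            result_array += flatten(element)
    return result_array


# Lists containing only (nested) empty lists flatten to an empty list
assert flatten([[[]], [], [[[]]]]) == []

# Elements that are neither int nor list are silently dropped
# (here a tuple and a float), with no error raised
assert flatten([1, (2, 3), [4.5], 6]) == [1, 6]

# bool is a subclass of int, so True/False pass the isinstance check
assert flatten([True, [False]]) == [True, False]

# Nesting deeper than the interpreter's recursion limit fails
deep = []
for _ in range(sys.getrecursionlimit() + 100):
    deep = [deep]
try:
    flatten(deep)
    deep_nesting_failed = False
except RecursionError:
    deep_nesting_failed = True
assert deep_nesting_failed
```

If silently dropping tuples and floats is not acceptable for the task, that case is a good argument for raising an exception instead.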
3 Answers
Your code looks fine; however, to improve it, you should use a proper test framework such as pytest or unittest. To demonstrate, here is your code using pytest, with the tests written properly (you don't need to check every item individually; just compare the whole list):
def flatten(input_array):
    result_array = []
    for element in input_array:
        if isinstance(element, int):
            result_array.append(element)
        elif isinstance(element, list):
            result_array += flatten(element)
    return result_array


def test01():
    results = flatten([1, [2, 3, [4]], 5, [[6]]])
    assert results == [1, 2, 3, 4, 5, 6]


def test02():
    results = flatten([1, [2, 3, [4], []], [], 5, [[], [6]]])
    assert results == [1, 2, 3, 4, 5, 6]
And here are the results:
C:\PycharmProjects\codereview\tests>pytest scratch_14.py
======================== test session starts ========================
platform win32 -- Python 3.7.0, pytest-3.6.2, py-1.5.4, pluggy-0.6.0
rootdir: C:\PycharmProjects\codereview\tests, inifile:
plugins: cov-2.5.1, celery-4.2.0
collected 2 items
scratch_14.py .. [100%]
===================== 2 passed in 0.09 seconds ======================
This is much easier to set up and requires less code to validate whether the solution is correct.
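Since unittest ships with the standard library, here is a rough equivalent of the same tests written as a unittest.TestCase, in case installing pytest is not an option. The class and method names are made up for illustration, and the empty-input case from the original tests (which the pytest version above dropped) is included again. It can be run with `python -m unittest <file>`, and pytest will collect it as well.

```python
import unittest


def flatten(input_array):
    # Recursive implementation from the question, unchanged
    result_array = []
    for element in input_array:
        if isinstance(element, int):
            result_array.append(element)
        elif isinstance(element, list):
            result_array += flatten(element)
    return result_array


class FlattenTests(unittest.TestCase):
    def test_nested(self):
        self.assertEqual(flatten([1, [2, 3, [4]], 5, [[6]]]), [1, 2, 3, 4, 5, 6])

    def test_empty_input(self):
        self.assertEqual(flatten([]), [])

    def test_empty_sublists(self):
        # Empty lists at any depth should contribute nothing
        self.assertEqual(flatten([1, [2, 3, [4], []], [], 5, [[], [6]]]),
                         [1, 2, 3, 4, 5, 6])
```

assertEqual prints both lists on failure, which replaces the hand-written ' true'/'false' reporting of the custom expect function.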
You asked: "Is the use of the TypeError exception justified?"
I don't actually see any code that references TypeError. Did you forget to include it? Or are you referring to the use of isinstance? If so, that code is fine.
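If you do want a TypeError to be justified, one natural place for it is rejecting unsupported elements instead of silently dropping them, as the current isinstance chain does. A minimal sketch, using the hypothetical name `flatten_strict` so it doesn't clash with your version:

```python
def flatten_strict(input_array):
    """Flatten nested lists of ints, rejecting any other element type.

    Hypothetical variant: raises TypeError instead of silently
    skipping elements that are neither int nor list.
    """
    result_array = []
    for element in input_array:
        if isinstance(element, int):
            result_array.append(element)
        elif isinstance(element, list):
            result_array.extend(flatten_strict(element))
        else:
            # Fail loudly so bad input is noticed by the caller
            raise TypeError(
                "expected int or list, got {}".format(type(element).__name__))
    return result_array
```

With this, `flatten_strict([[1, 2, [3]], 4])` returns `[1, 2, 3, 4]`, while `flatten_strict([1, (2, 3)])` raises `TypeError: expected int or list, got tuple`. Note that bool is a subclass of int, so `True` still passes the int check.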
Hope this helps!
Your function only deals with ints and lists. While that may be fine in the context of the question, it doesn't feel Pythonic at all, as it disregards any other kind of iterable and any other type of data:
>>> flatten([1, (2, 3), [4.5], 6])
[1, 6]
Instead, you could make use of the iterator protocol to write a generic flatten function:
def flatten(iterable):
    try:
        iterator = iter(iterable)
    except TypeError:
        yield iterable
    else:
        for element in iterator:
            yield from flatten(element)
Usage being:
>>> list(flatten([1, (2, 3), [4.5], 6]))
[1, 2, 3, 4.5, 6]
However, there are two potential issues with this approach:
- You may not like that flatten is now a generator. If so, turn it into a helper function and wrap it with a call to list:

    def _flatten_generator(iterable):
        ...  # previous code

    def flatten(iterable):
        return list(_flatten_generator(iterable))

- You won't be able to handle strings at all: individual characters are still strings, so you will run into a:

    RecursionError: maximum recursion depth exceeded while calling a Python object

  So you may want to add an explicit check for str at the beginning of the function.
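Putting both suggestions together, a sketch of the generator with an explicit str check at the top and the list-returning wrapper might look like this. Treating strings as atomic values is an assumption about what you want; you could raise an error for them instead.

```python
def _flatten_generator(iterable):
    # Strings are iterable, but iterating over "a" yields "a" again, which
    # would recurse forever, so treat them as leaf values instead
    if isinstance(iterable, str):
        yield iterable
        return
    try:
        iterator = iter(iterable)
    except TypeError:
        # Not iterable (int, float, ...): a leaf value
        yield iterable
    else:
        for element in iterator:
            yield from _flatten_generator(element)


def flatten(iterable):
    # Materialize the generator so callers get a plain list back
    return list(_flatten_generator(iterable))
```

For example, `flatten([1, "ab", (2, [3.5])])` returns `[1, 'ab', 2, 3.5]`.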
Adding to @C. Harley's answer, just a note on string concatenation: it works fine, but when you are putting in several variables, it's better to use string formatting. Change
print('PASSED: ' + str(test_count[0]) + ' / ' + str(test_count[1]) + '\n\n')
to
print('PASSED: {} / {}\n\n'.format(test_count[0], test_count[1]))
It also saves you from calling str() each time.
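On Python 3.6+, f-strings are another option that puts the expressions directly inside the literal. A small comparison of the three styles, using hypothetical counter values in place of the real test_count:

```python
test_count = [3, 3]  # hypothetical [passed, total] values for illustration

# Concatenation: every number must be converted with str()
concatenated = 'PASSED: ' + str(test_count[0]) + ' / ' + str(test_count[1]) + '\n\n'

# str.format: the {} placeholders do the conversion
formatted = 'PASSED: {} / {}\n\n'.format(test_count[0], test_count[1])

# f-string (Python 3.6+): the expressions sit inside the literal
f_string = f'PASSED: {test_count[0]} / {test_count[1]}\n\n'

# All three produce the same text
assert concatenated == formatted == f_string
```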
TypeError in your code, nor in your tests... Can you explain what you had in mind?