
I would like a code review of my recursive Python implementation of a flatten-array function.

Write a piece of functioning code that will flatten an array of arbitrarily nested arrays of integers into a flat array of integers. For example, [[1,2,[3]],4] -> [1,2,3,4].

I'm particularly looking for feedback on the following:

  • Is the use of a TypeError exception justified?
  • Did I miss any valid edge cases in my tests?

Here is my solution, including unit tests:

def flatten(input_array):
    result_array = []
    for element in input_array:
        if isinstance(element, int):
            result_array.append(element)
        elif isinstance(element, list):
            result_array += flatten(element)
    return result_array

It passes all of the following tests:

from io import StringIO
import sys
# custom assert function to handle tests
# input: count {List} - keeps track of how many tests pass and how many run in total,
# in the form of a two-item list, i.e., [0, 0]
# input: name {String} - describes the test
# input: test {Function} - performs a set of operations and returns a boolean
# indicating whether the test passed
# output: {None}
def expect(count, name, test):
    if (count is None or not isinstance(count, list) or len(count) != 2):
        count = [0, 0]
    else:
        count[1] += 1
    result = 'false'
    error_msg = None
    try:
        if test():
            result = ' true'
            count[0] += 1
    except Exception as err:
        error_msg = str(err)
    print(' ' + (str(count[1]) + ') ') + result + ' : ' + name)
    if error_msg is not None:
        print(' ' + error_msg + '\n')
# code for capturing print output
#
# directions: capture_print function returns a list of all elements that were
# printed using print with the function that it is given. Note that
# the function given to capture_print must be fed using lambda.
class Capturing(list):
    def __enter__(self):
        self._stdout = sys.stdout
        sys.stdout = self._stringio = StringIO()
        return self
    def __exit__(self, *args):
        self.extend(self._stringio.getvalue().splitlines())
        sys.stdout = self._stdout
def capture_print(to_run):
    with Capturing() as output:
        pass
    with Capturing(output) as output:  # note the constructor argument
        to_run()
    return output
test_count = [0, 0]  # [passed, total]; must be defined before the first expect() call
def test():
    results = flatten([1, [2, 3, [4]], 5, [[6]]])
    return (len(results) == 6 and
            results[0] == 1 and
            results[1] == 2 and
            results[2] == 3 and
            results[3] == 4 and
            results[4] == 5 and
            results[5] == 6)
expect(test_count, 'should return [1,2,3,4,5,6] output for [1, [2, 3, [4]], 5, [[6]]] input', test)
def test():
    results = flatten([])
    return len(results) == 0
expect(test_count, 'should return [] output for [] input', test)
def test():
    results = flatten([1, [2, 3, [4], []], [], 5, [[], [6]]])
    return (len(results) == 6 and
            results[0] == 1 and
            results[1] == 2 and
            results[2] == 3 and
            results[3] == 4 and
            results[4] == 5 and
            results[5] == 6)
expect(test_count, 'should return [1,2,3,4,5,6] output for [1, [2, 3, [4], []], [], 5, [[], [6]]] input (note the empty arrays)', test)
print('PASSED: ' + str(test_count[0]) + ' / ' + str(test_count[1]) + '\n\n')
200_success
asked Aug 9, 2018 at 0:36
  • In terms of your flatten function itself: you could simplify it by checking only whether the element is a list (and if so, extending recursively), and appending it otherwise. That's unnecessary for this task, but it would make the function more general instead of limiting it to integers. Commented Aug 9, 2018 at 7:09
  • You asked "Is usage of TypeError exception justified?" but I don't see any use of TypeError in your code, nor in your tests. Can you explain what you had in mind? Commented Aug 9, 2018 at 7:34
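The first comment's suggestion can be sketched as follows (an illustration, not code from the thread): recurse only into lists and append every other element unchanged, so non-integer values are kept instead of being silently dropped.

```python
def flatten(input_array):
    """Recursively flatten nested lists; anything that is not a list is kept as-is."""
    result_array = []
    for element in input_array:
        if isinstance(element, list):
            # Nested list: flatten it and splice its items in.
            result_array.extend(flatten(element))
        else:
            # Leaf value (int, float, str, tuple, ...): keep it unchanged.
            result_array.append(element)
    return result_array


print(flatten([1, [2, 3, [4]], 5, [[6]]]))  # [1, 2, 3, 4, 5, 6]
print(flatten([1, (2, 3), [4.5], 6]))       # [1, (2, 3), 4.5, 6]
```

Note that tuples and other non-list iterables are still kept as single values rather than being flattened.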

3 Answers


Your code looks fine; however, to improve it you should use a proper test framework such as pytest or unittest. To demonstrate, here is your code using pytest, with the tests rewritten properly (you don't need to check every item individually):

def flatten(input_array):
    result_array = []
    for element in input_array:
        if isinstance(element, int):
            result_array.append(element)
        elif isinstance(element, list):
            result_array += flatten(element)
    return result_array
def test01():
    results = flatten([1, [2, 3, [4]], 5, [[6]]])
    assert results == [1, 2, 3, 4, 5, 6]
def test02():
    results = flatten([1, [2, 3, [4], []], [], 5, [[], [6]]])
    assert results == [1, 2, 3, 4, 5, 6]

And here are the results:

C:\PycharmProjects\codereview\tests>pytest scratch_14.py 
======================== test session starts ========================
platform win32 -- Python 3.7.0, pytest-3.6.2, py-1.5.4, pluggy-0.6.0
rootdir: C:\PycharmProjects\codereview\tests, inifile:
plugins: cov-2.5.1, celery-4.2.0
collected 2 items
scratch_14.py .. [100%]
===================== 2 passed in 0.09 seconds ======================

This is much easier to set up, and it takes less code to check that the solution is correct.

You asked: "Is usage of TypeError exception justified?" I don't actually see any code that references TypeError. Did you forget to include it? Or are you referring to the use of isinstance? If so, that code is fine.
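If the TypeError question was about rejecting unsupported elements (an assumption; the posted code never raises one), a strict variant could raise TypeError instead of silently skipping them. The sketch below uses a hypothetical `flatten_strict` together with the standard-library unittest module, and adds a few edge cases the original tests don't cover (nested empty lists, deep nesting, non-integer elements):

```python
import unittest


def flatten_strict(input_array):
    """Flatten nested lists of ints, raising TypeError for any other element.

    Hypothetical variant: the original flatten silently skips such elements.
    """
    result_array = []
    for element in input_array:
        if isinstance(element, list):
            result_array.extend(flatten_strict(element))
        elif isinstance(element, int):
            result_array.append(element)
        else:
            # Fail loudly instead of dropping the value.
            raise TypeError(
                "expected int or list, got {}".format(type(element).__name__)
            )
    return result_array


class FlattenStrictTests(unittest.TestCase):
    def test_nested(self):
        self.assertEqual(flatten_strict([1, [2, 3, [4]], 5, [[6]]]), [1, 2, 3, 4, 5, 6])

    def test_empty_lists(self):
        self.assertEqual(flatten_strict([]), [])
        self.assertEqual(flatten_strict([[], [[]]]), [])

    def test_deep_nesting(self):
        self.assertEqual(flatten_strict([[[[[7]]]]]), [7])

    def test_rejects_other_types(self):
        with self.assertRaises(TypeError):
            flatten_strict([1, (2, 3)])  # tuple at the top level
        with self.assertRaises(TypeError):
            flatten_strict([1, [2.5]])   # float inside a nested list
```

Saved as a file such as test_flatten.py (a hypothetical name), it runs with `python -m unittest test_flatten`; pytest also collects unittest.TestCase classes, so `pytest test_flatten.py` works as well. One more edge case worth a test: bool is a subclass of int, so `isinstance(True, int)` is True and `True` would be accepted as an integer by both versions.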
Hope this helps!

answered Aug 9, 2018 at 4:46

Your function only deals with ints and lists. While that may be fine in the context of the question, it doesn't feel Pythonic at all, as it disregards any other kind of iterable and any other type of data:

>>> flatten([1, (2, 3), [4.5], 6])
[1, 6]

Instead, you could make use of the iterator protocol to have a generic flatten function:

def flatten(iterable):
    try:
        iterator = iter(iterable)
    except TypeError:
        yield iterable
    else:
        for element in iterator:
            yield from flatten(element)

Usage being:

>>> list(flatten([1, (2, 3), [4.5], 6]))
[1, 2, 3, 4.5, 6]

However, there are two potential issues with this approach:

  • you may not like that flatten is now a generator; if so, turn it into a helper function and wrap it with a call to list:

    def _flatten_generator(iterable):
        # previous code
    def flatten(iterable):
        return list(_flatten_generator(iterable))
    
  • you won't be able to handle strings at all: iterating over a one-character string yields that same string again, so the recursion never ends and you will run into a:

    RecursionError: maximum recursion depth exceeded while calling a Python object
    

    So you may want to add an explicit check for str at the beginning of the function.
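A minimal sketch of that explicit str check, built on the generator from this answer. Treating bytes as atoms as well is an added assumption, not part of the answer: bytes don't cause the infinite recursion (iterating them yields ints), but splitting b"ab" into integers is rarely what you want.

```python
def flatten(iterable):
    """Lazily flatten arbitrarily nested iterables, treating str/bytes as atoms."""
    # A str is iterable, but each character is itself a one-character str,
    # so recursing into it would never terminate. bytes are kept whole too.
    if isinstance(iterable, (str, bytes)):
        yield iterable
        return
    try:
        iterator = iter(iterable)
    except TypeError:
        # Not iterable (int, float, None, ...): it is a leaf value.
        yield iterable
    else:
        for element in iterator:
            yield from flatten(element)
```

For example, `list(flatten([1, (2, 3), [4.5], "abc", [["de"]], 6]))` gives `[1, 2, 3, 4.5, 'abc', 'de', 6]`.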

answered Aug 9, 2018 at 8:32

Adding to @C. Harley's answer, just a note on string concatenation: it works, but when you are putting in several variables it's better to use string formatting, changing

print('PASSED: ' + str(test_count[0]) + ' / ' + str(test_count[1]) + '\n\n')

to

print('PASSED: {} / {}\n\n'.format(test_count[0], test_count[1]))

It also saves you from calling str() on each value.
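For comparison, a short sketch of the same summary line built three ways: with str.format as above, with the two-item list unpacked into the format arguments, and with an f-string (Python 3.6+). The `test_count` values here are placeholders.

```python
test_count = [2, 3]  # placeholder values: [passed, total]

# str.format with explicit indexing, as suggested above
with_format = 'PASSED: {} / {}\n\n'.format(test_count[0], test_count[1])

# str.format with the list unpacked into positional arguments
with_unpacking = 'PASSED: {} / {}\n\n'.format(*test_count)

# f-string (Python 3.6+): the expressions are evaluated inline
with_fstring = f'PASSED: {test_count[0]} / {test_count[1]}\n\n'

print(with_fstring, end='')
```

All three produce the same string; none of them needs an explicit str() call.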

answered Aug 9, 2018 at 5:48