Many a time, I've needed a permutations-with-replacement function.
So, I've written one:
from sys import setrecursionlimit
setrecursionlimit(10 ** 9)
def permutations_with_replacement(n: int, m: int, cur=None):
    if cur is None:
        cur = []
    if n == 0:
        # All n positions are filled, so emit the finished sequence.
        yield cur
        return
    for i in range(1, m + 1):
        # cur + [i] builds a new list, so sibling branches never share state.
        yield from permutations_with_replacement(n - 1, m, cur + [i])
if __name__ == '__main__':
    n = int(input("Please enter 'N': "))
    m = int(input("Please enter 'M': "))
    for i in permutations_with_replacement(n, m):
        print(*i)
There's a better way to do this using itertools.product, but there's no fun in that!
from itertools import product
def permutations_with_replacement(n, m):
    # product accepts any iterable, so the range needs no list() wrapper.
    yield from product(range(1, m + 1), repeat=n)
if __name__ == '__main__':
    n = int(input("Please enter 'N': "))
    m = int(input("Please enter 'M': "))
    for i in permutations_with_replacement(n, m):
        print(*i)
I don't like the way I've implemented cur in my code. Is there a better way to do that?
2 Answers
First of all, this is very nicely written code.
Before getting to the cur implementation, a few notes:
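- I do not think n and m should be defined directly in the if __name__ == '__main__' block. There they are module-level names, and the function's parameters of the same name shadow them. Moving them into a main() function avoids this.
- I added a type annotation for the return value of the function. Since it's a generator, it should be Iterator[List[int]].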
You said you do not like the way cur is implemented in your code, so I thought of a way to take it out of the function signature.
I created a recursive inner function and moved the cur variable out of its parameters. This is how it turned out:
from typing import Iterator, List
def permutations_with_replacement(n: int, m: int) -> Iterator[List[int]]:
    cur = []
    def permutations_with_replacement_rec(n_rec: int, m_rec: int) -> Iterator[List[int]]:
        nonlocal cur
        if n_rec == 0:
            # Yield a copy: cur is modified by pop() after the consumer receives it.
            yield cur.copy()
            return
        for i in range(1, m_rec + 1):
            cur = cur + [i]
            yield from permutations_with_replacement_rec(n_rec - 1, m_rec)
            cur.pop()
    yield from permutations_with_replacement_rec(n, m)
Notice that I needed to add a pop call, since cur now keeps its elements after the recursive call returns.
I also needed the nonlocal statement so that the inner function can rebind the cur variable of the enclosing function.
The full code (with tests):
from itertools import product
from typing import Iterator, List
from sys import setrecursionlimit
setrecursionlimit(10 ** 9)
def permutations_with_replacement(n: int, m: int) -> Iterator[List[int]]:
    cur = []
    def permutations_with_replacement_rec(n_rec: int, m_rec: int) -> Iterator[List[int]]:
        nonlocal cur
        if n_rec == 0:
            # Yield a copy: cur is modified by pop() after the consumer receives it.
            yield cur.copy()
            return
        for i in range(1, m_rec + 1):
            cur = cur + [i]
            yield from permutations_with_replacement_rec(n_rec - 1, m_rec)
            cur.pop()
    yield from permutations_with_replacement_rec(n, m)
def test_permutations_with_replacement():
    n = 3
    m = 4
    # Compare against itertools.product as the reference implementation.
    assert set(product(range(1, m + 1), repeat=n)) == set(tuple(i) for i in permutations_with_replacement(n, m))
def main():
    n = int(input("Please enter 'N': "))
    m = int(input("Please enter 'M': "))
    for i in permutations_with_replacement(n, m):
        print(*i)
if __name__ == '__main__':
    main()
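The pop and nonlocal points above can also be seen in a smaller variant. The following is only a sketch, not the code from this answer: the names permutations_with_replacement_buf, buf and rec are hypothetical. It assumes a single shared list that is mutated with append and pop rather than rebound, so no nonlocal declaration is needed, and it yields a tuple snapshot so that later pops cannot change what the caller already received.

```python
from typing import Iterator, List, Tuple


def permutations_with_replacement_buf(n: int, m: int) -> Iterator[Tuple[int, ...]]:
    """Hypothetical variant: one shared buffer, mutated in place."""
    buf: List[int] = []

    def rec(remaining: int) -> Iterator[Tuple[int, ...]]:
        if remaining == 0:
            # Snapshot the buffer; later pops must not alter what the caller got.
            yield tuple(buf)
            return
        for i in range(1, m + 1):
            buf.append(i)  # mutation only, so no `nonlocal` declaration is needed
            yield from rec(remaining - 1)
            buf.pop()      # undo the choice before trying the next value

    yield from rec(n)


print(list(permutations_with_replacement_buf(2, 2)))
# [(1, 1), (1, 2), (2, 1), (2, 2)]
```

Yielding tuples here is a design choice for the sketch: it makes the results safe to collect with list(), at the cost of one copy per result.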
I know it's a bit late, but I had this idea for solving the problem iteratively. I'm leaving it here for anyone who, like me, is searching for an iterative answer.
It is difficult for me to explain in English, but the idea is to update the current permutation starting from the end, one index at a time. We keep updating until the initial permutation is encountered again, at which point we stop by setting hasNext=False.
from typing import Iterator, List
def permutations_with_replacement(n: int, m: int) -> Iterator[List[int]]:
    cur = [1]*n
    hasNext = True
    while hasNext:
        # Yield a copy: cur is updated in place for the next permutation.
        yield cur.copy()
        i = n-1
        while hasNext:
            cur[i] += 1
            if cur[i] > m:
                # This position overflowed: reset it and carry to the left.
                cur[i] = 1
                i -= 1
                if i < 0:
                    # Carried past the first index, so we are back at the start.
                    hasNext = False
            else:
                break
if __name__ == '__main__':
    n = int(input("Please enter 'N': "))
    m = int(input("Please enter 'M': "))
    for i in permutations_with_replacement(n, m):
        print(*i)
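The end-first update described above behaves like an odometer, which is the same as counting in base m. As a rough cross-check of that reading (a different technique, not the method of this answer), the k-th sequence can be computed directly from the base-m digits of k. The helper name kth_tuple is hypothetical, and the sketch assumes m >= 1.

```python
from itertools import product
from typing import List


def kth_tuple(k: int, n: int, m: int) -> List[int]:
    """Return the k-th (0-based) length-n sequence over 1..m in odometer order."""
    digits = [1] * n
    for pos in range(n - 1, -1, -1):  # fill from the last index backwards
        k, r = divmod(k, m)
        digits[pos] = r + 1           # shift digit 0..m-1 to value 1..m
    return digits


n, m = 2, 3
result = [kth_tuple(k, n, m) for k in range(m ** n)]
print(result)
# [[1, 1], [1, 2], [1, 3], [2, 1], [2, 2], [2, 3], [3, 1], [3, 2], [3, 3]]

# The order matches itertools.product over 1..m.
assert [tuple(t) for t in result] == list(product(range(1, m + 1), repeat=n))
```

This also makes the total count explicit: there are m ** n sequences, one for each k in range(m ** n).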