[Python-checkins] cpython: Add optional *func* argument to itertools.accumulate().

raymond.hettinger python-checkins at python.org
Mon Mar 28 03:52:18 CEST 2011


http://hg.python.org/cpython/rev/79ccefd30a37
changeset: 69021:79ccefd30a37
user: Raymond Hettinger <python at rcn.com>
date: Sun Mar 27 18:52:10 2011 -0700
summary:
 Add optional *func* argument to itertools.accumulate().
files:
 Doc/library/itertools.rst | 33 ++++++++++++++++++++++---
 Lib/test/test_itertools.py | 12 ++++++++-
 Misc/NEWS | 3 ++
 Modules/itertoolsmodule.c | 18 +++++++++++---
 4 files changed, 56 insertions(+), 10 deletions(-)
diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst
--- a/Doc/library/itertools.rst
+++ b/Doc/library/itertools.rst
@@ -46,7 +46,7 @@
 ==================== ============================ ================================================= =============================================================
 Iterator Arguments Results Example
 ==================== ============================ ================================================= =============================================================
-:func:`accumulate` p p0, p0+p1, p0+p1+p2, ... ``accumulate([1,2,3,4,5]) --> 1 3 6 10 15``
+:func:`accumulate` p [,func] p0, p0+p1, p0+p1+p2, ... ``accumulate([1,2,3,4,5]) --> 1 3 6 10 15``
 :func:`chain` p, q, ... p0, p1, ... plast, q0, q1, ... ``chain('ABC', 'DEF') --> A B C D E F``
 :func:`compress` data, selectors (d[0] if s[0]), (d[1] if s[1]), ... ``compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F``
 :func:`dropwhile` pred, seq seq[n], seq[n+1], starting when pred fails ``dropwhile(lambda x: x<5, [1,4,6,4,1]) --> 6 4 1``
@@ -84,23 +84,46 @@
 streams of infinite length, so they should only be accessed by functions or
 loops that truncate the stream.
 
-.. function:: accumulate(iterable)
+.. function:: accumulate(iterable[, func])
 
 Make an iterator that returns accumulated sums. Elements may be any addable
- type including :class:`Decimal` or :class:`Fraction`. Equivalent to::
+ type including :class:`Decimal` or :class:`Fraction`. If the optional
+ *func* argument is supplied, it should be a function of two arguments
+ and it will be used instead of addition.
 
- def accumulate(iterable):
+ Equivalent to::
+
+ def accumulate(iterable, func=operator.add):
 'Return running totals'
 # accumulate([1,2,3,4,5]) --> 1 3 6 10 15
+ # accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
 it = iter(iterable)
 total = next(it)
 yield total
 for element in it:
- total = total + element
+ total = func(total, element)
 yield total
 
+ Uses for the *func* argument include :func:`min` for a running minimum,
+ :func:`max` for a running maximum, and :func:`operator.mul` for a running
+ product::
+
+ >>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8]
+ >>> list(accumulate(data, operator.mul)) # running product
+ [3, 12, 72, 144, 144, 1296, 0, 0, 0, 0]
+ >>> list(accumulate(data, max)) # running maximum
+ [3, 4, 6, 6, 6, 9, 9, 9, 9, 9]
+
+ # Amortize a 5% loan of 1000 with 4 annual payments of 90
+ >>> cashflows = [1000, -90, -90, -90, -90]
+ >>> list(accumulate(cashflows, lambda bal, pmt: bal*1.05 + pmt))
+ [1000, 960.0, 918.0, 873.9000000000001, 827.5950000000001]
+
 .. versionadded:: 3.2
 
+ .. versionchanged:: 3.3
+ Added the optional *func* parameter.
+
 .. function:: chain(*iterables)
 
 Make an iterator that returns elements from the first iterable until it is
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -69,11 +69,21 @@
 self.assertEqual(list(accumulate('abc')), ['a', 'ab', 'abc']) # works with non-numeric
 self.assertEqual(list(accumulate([])), []) # empty iterable
 self.assertEqual(list(accumulate([7])), [7]) # iterable of length one
- self.assertRaises(TypeError, accumulate, range(10), 5) # too many args
+ self.assertRaises(TypeError, accumulate, range(10), 5, 6) # too many args
 self.assertRaises(TypeError, accumulate) # too few args
 self.assertRaises(TypeError, accumulate, x=range(10)) # unexpected kwd arg
 self.assertRaises(TypeError, list, accumulate([1, []])) # args that don't add
 
+ s = [2, 8, 9, 5, 7, 0, 3, 4, 1, 6]
+ self.assertEqual(list(accumulate(s, min)),
+ [2, 2, 2, 2, 2, 0, 0, 0, 0, 0])
+ self.assertEqual(list(accumulate(s, max)),
+ [2, 8, 9, 9, 9, 9, 9, 9, 9, 9])
+ self.assertEqual(list(accumulate(s, operator.mul)),
+ [2, 16, 144, 720, 5040, 0, 0, 0, 0, 0])
+ with self.assertRaises(TypeError):
+ list(accumulate(s, chr)) # unary-operation
+
 def test_chain(self):
 
 def chain2(*iterables):
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -89,6 +89,9 @@
 
 - Issue #11696: Fix ID generation in msilib.
 
+- itertools.accumulate now supports an optional *func* argument for
+ a user-supplied binary function.
+
 - Issue #11692: Remove unnecessary demo functions in subprocess module.
 
 - Issue #9696: Fix exception incorrectly raised by xdrlib.Packer.pack_int when
diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c
--- a/Modules/itertoolsmodule.c
+++ b/Modules/itertoolsmodule.c
@@ -2590,6 +2590,7 @@
 PyObject_HEAD
 PyObject *total;
 PyObject *it;
+ PyObject *binop;
 } accumulateobject;
 
 static PyTypeObject accumulate_type;
@@ -2597,12 +2598,14 @@
 static PyObject *
 accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
 {
- static char *kwargs[] = {"iterable", NULL};
+ static char *kwargs[] = {"iterable", "func", NULL};
 PyObject *iterable;
 PyObject *it;
+ PyObject *binop = NULL;
 accumulateobject *lz;
 
- if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:accumulate", kwargs, &iterable))
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:accumulate",
+ kwargs, &iterable, &binop))
 return NULL;
 
 /* Get iterator. */
@@ -2617,6 +2620,8 @@
 return NULL;
 }
 
+ Py_XINCREF(binop);
+ lz->binop = binop;
 lz->total = NULL;
 lz->it = it;
 return (PyObject *)lz;
@@ -2626,6 +2631,7 @@
 accumulate_dealloc(accumulateobject *lz)
 {
 PyObject_GC_UnTrack(lz);
+ Py_XDECREF(lz->binop);
 Py_XDECREF(lz->total);
 Py_XDECREF(lz->it);
 Py_TYPE(lz)->tp_free(lz);
@@ -2634,6 +2640,7 @@
 static int
 accumulate_traverse(accumulateobject *lz, visitproc visit, void *arg)
 {
+ Py_VISIT(lz->binop);
 Py_VISIT(lz->it);
 Py_VISIT(lz->total);
 return 0;
@@ -2653,8 +2660,11 @@
 lz->total = val;
 return lz->total;
 }
- 
- newtotal = PyNumber_Add(lz->total, val);
+
+ if (lz->binop == NULL) 
+ newtotal = PyNumber_Add(lz->total, val);
+ else
+ newtotal = PyObject_CallFunctionObjArgs(lz->binop, lz->total, val, NULL);
 Py_DECREF(val);
 if (newtotal == NULL)
 return NULL;
-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list

Page converted by AltStyle (-> original) /