[Python-checkins] r86948 - in python/branches/py3k: Doc/library/itertools.rst Lib/test/test_itertools.py Modules/itertoolsmodule.c

raymond.hettinger python-checkins at python.org
Fri Dec 3 03:09:34 CET 2010


Author: raymond.hettinger
Date: Fri Dec 3 03:09:34 2010
New Revision: 86948
Log:
Simplify the signature of itertools.accumulate() to match numpy by dropping the start argument. Handle a one-item iterable the same way as min()/max(): the single item is returned unchanged.
Modified:
 python/branches/py3k/Doc/library/itertools.rst
 python/branches/py3k/Lib/test/test_itertools.py
 python/branches/py3k/Modules/itertoolsmodule.c
Modified: python/branches/py3k/Doc/library/itertools.rst
==============================================================================
--- python/branches/py3k/Doc/library/itertools.rst	(original)
+++ python/branches/py3k/Doc/library/itertools.rst	Fri Dec 3 03:09:34 2010
@@ -90,13 +90,15 @@
 parameter (which defaults to :const:`0`). Elements may be any addable type
 including :class:`Decimal` or :class:`Fraction`. Equivalent to::
 
- def accumulate(iterable, start=0):
+ def accumulate(iterable):
 'Return running totals'
- # accumulate([1,2,3,4,5]) --> 1 3 6 10 15
- total = start
- for element in iterable:
- total += element
- yield total
+ # accumulate([1,2,3,4,5]) --> 1 3 6 10 15
+ it = iter(iterable)
+ total = next(it)
+ yield total
+ for element in it:
+ total += element
+ yield total
 
 .. versionadded:: 3.2
 
Modified: python/branches/py3k/Lib/test/test_itertools.py
==============================================================================
--- python/branches/py3k/Lib/test/test_itertools.py	(original)
+++ python/branches/py3k/Lib/test/test_itertools.py	Fri Dec 3 03:09:34 2010
@@ -59,18 +59,18 @@
 
 def test_accumulate(self):
 self.assertEqual(list(accumulate(range(10))), # one positional arg
- [0, 1, 3, 6, 10, 15, 21, 28, 36, 45])
- self.assertEqual(list(accumulate(range(10), 100)), # two positional args
- [100, 101, 103, 106, 110, 115, 121, 128, 136, 145])
- self.assertEqual(list(accumulate(iterable=range(10), start=100)), # kw args
- [100, 101, 103, 106, 110, 115, 121, 128, 136, 145])
+ [0, 1, 3, 6, 10, 15, 21, 28, 36, 45])
+ self.assertEqual(list(accumulate(iterable=range(10))), # kw arg
+ [0, 1, 3, 6, 10, 15, 21, 28, 36, 45])
 for typ in int, complex, Decimal, Fraction: # multiple types
- self.assertEqual(list(accumulate(range(10), typ(0))),
+ self.assertEqual(
+ list(accumulate(map(typ, range(10)))),
 list(map(typ, [0, 1, 3, 6, 10, 15, 21, 28, 36, 45])))
 self.assertEqual(list(accumulate([])), []) # empty iterable
- self.assertRaises(TypeError, accumulate, range(10), 0, 5) # too many args
+ self.assertEqual(list(accumulate([7])), [7]) # iterable of length one
+ self.assertRaises(TypeError, accumulate, range(10), 5) # too many args
 self.assertRaises(TypeError, accumulate) # too few args
- self.assertRaises(TypeError, accumulate, range(10), x=7) # unexpected kwd args
+ self.assertRaises(TypeError, accumulate, x=range(10)) # unexpected kwd arg
 self.assertRaises(TypeError, list, accumulate([1, []])) # args that don't add
 
 def test_chain(self):
Modified: python/branches/py3k/Modules/itertoolsmodule.c
==============================================================================
--- python/branches/py3k/Modules/itertoolsmodule.c	(original)
+++ python/branches/py3k/Modules/itertoolsmodule.c	Fri Dec 3 03:09:34 2010
@@ -2597,41 +2597,27 @@
 static PyObject *
 accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
 {
- static char *kwargs[] = {"iterable", "start", NULL};
+ static char *kwargs[] = {"iterable", NULL};
 PyObject *iterable;
 PyObject *it;
- PyObject *start = NULL;
 accumulateobject *lz;
 
- if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:accumulate",
-				 kwargs, &iterable, &start))
-	return NULL;
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:accumulate", kwargs, &iterable))
+ return NULL;
 
 /* Get iterator. */
 it = PyObject_GetIter(iterable);
 if (it == NULL)
 return NULL;
 
- /* Default start value */
- if (start == NULL) {
-	 start = PyLong_FromLong(0);
-	 if (start == NULL) {
-	 Py_DECREF(it);
-	 return NULL;
-	 }
- } else {
- Py_INCREF(start);
- }
-
 /* create accumulateobject structure */
 lz = (accumulateobject *)type->tp_alloc(type, 0);
 if (lz == NULL) {
 Py_DECREF(it);
-	 Py_DECREF(start);
- return NULL;
+ return NULL;
 }
 
- lz->total = start;
+ lz->total = NULL;
 lz->it = it;
 return (PyObject *)lz;
 }
@@ -2661,11 +2647,17 @@
 val = PyIter_Next(lz->it);
 if (val == NULL)
 return NULL;
- 
+ 
+ if (lz->total == NULL) {
+ Py_INCREF(val);
+ lz->total = val;
+ return lz->total;
+ }
+ 
 newtotal = PyNumber_Add(lz->total, val);
-	Py_DECREF(val);
+ Py_DECREF(val);
 if (newtotal == NULL)
-	 return NULL;
+ return NULL;
 
 oldtotal = lz->total;
 lz->total = newtotal;
@@ -2676,7 +2668,7 @@
 }
 
 PyDoc_STRVAR(accumulate_doc,
-"accumulate(iterable, start=0) --> accumulate object\n\
+"accumulate(iterable) --> accumulate object\n\
 \n\
 Return series of accumulated sums.");
 


More information about the Python-checkins mailing list

AltStyle によって変換されたページ (->オリジナル) /