[Python-checkins] bpo-32604: Implement force-closing channels. (gh-6937)

Eric Snow webhook-mailer at python.org
Thu May 17 10:27:19 EDT 2018


https://github.com/python/cpython/commit/3ab0136ac5d6059ce96d4debca89c5f5ab0356f5
commit: 3ab0136ac5d6059ce96d4debca89c5f5ab0356f5
branch: master
author: Eric Snow <ericsnowcurrently at gmail.com>
committer: GitHub <noreply at github.com>
date: 2018-05-17T10:27:09-04:00
summary:
bpo-32604: Implement force-closing channels. (gh-6937)
This will make it easier to clean up channels (e.g. when used in tests).
files:
M Lib/test/test__xxsubinterpreters.py
M Modules/_xxsubinterpretersmodule.c
diff --git a/Lib/test/test__xxsubinterpreters.py b/Lib/test/test__xxsubinterpreters.py
index 118f2e4895fe..f66cc9516926 100644
--- a/Lib/test/test__xxsubinterpreters.py
+++ b/Lib/test/test__xxsubinterpreters.py
@@ -1379,12 +1379,104 @@ def test_close_multiple_times(self):
 with self.assertRaises(interpreters.ChannelClosedError):
 interpreters.channel_close(cid)
 
- def test_close_with_unused_items(self):
+ def test_close_empty(self):
+ # For every (send, recv) combination, closing a channel whose only
+ # item has already been received succeeds, and afterwards both
+ # send and recv raise ChannelClosedError.
+ tests = [
+ (False, False),
+ (True, False),
+ (False, True),
+ (True, True),
+ ]
+ for send, recv in tests:
+ with self.subTest((send, recv)):
+ cid = interpreters.channel_create()
+ interpreters.channel_send(cid, b'spam')
+ interpreters.channel_recv(cid)
+ interpreters.channel_close(cid, send=send, recv=recv)
+
+ with self.assertRaises(interpreters.ChannelClosedError):
+ interpreters.channel_send(cid, b'eggs')
+ with self.assertRaises(interpreters.ChannelClosedError):
+ interpreters.channel_recv(cid)
+
+ def test_close_defaults_with_unused_items(self):
+ # With no send/recv/force arguments, closing a channel that still
+ # holds items raises ChannelNotEmptyError and leaves the channel
+ # usable for both recv and send.
+ cid = interpreters.channel_create()
+ interpreters.channel_send(cid, b'spam')
+ interpreters.channel_send(cid, b'ham')
+
+ with self.assertRaises(interpreters.ChannelNotEmptyError):
+ interpreters.channel_close(cid)
+ interpreters.channel_recv(cid)
+ interpreters.channel_send(cid, b'eggs')
+
+ def test_close_recv_with_unused_items_unforced(self):
+ # close(recv=True) without force is rejected with ChannelNotEmptyError
+ # while items are queued; once all three items are received it succeeds.
 cid = interpreters.channel_create()
 interpreters.channel_send(cid, b'spam')
 interpreters.channel_send(cid, b'ham')
- interpreters.channel_close(cid)
+
+ with self.assertRaises(interpreters.ChannelNotEmptyError):
+ interpreters.channel_close(cid, recv=True)
+ interpreters.channel_recv(cid)
+ interpreters.channel_send(cid, b'eggs')
+ interpreters.channel_recv(cid)
+ interpreters.channel_recv(cid)
+ interpreters.channel_close(cid, recv=True)
+
+ def test_close_send_with_unused_items_unforced(self):
+ # close(send=True) without force on a non-empty channel does not
+ # raise: further sends fail immediately, the queued items can still
+ # be received, and recv raises ChannelClosedError once they are gone.
+ cid = interpreters.channel_create()
+ interpreters.channel_send(cid, b'spam')
+ interpreters.channel_send(cid, b'ham')
+ interpreters.channel_close(cid, send=True)
 
+ with self.assertRaises(interpreters.ChannelClosedError):
+ interpreters.channel_send(cid, b'eggs')
+ interpreters.channel_recv(cid)
+ interpreters.channel_recv(cid)
+ with self.assertRaises(interpreters.ChannelClosedError):
+ interpreters.channel_recv(cid)
+
+ def test_close_both_with_unused_items_unforced(self):
+ # close(recv=True, send=True) without force is rejected with
+ # ChannelNotEmptyError while items are queued; the channel stays usable.
+ # NOTE(review): the final close below passes only recv=True although
+ # the test name refers to both ends -- confirm that is intended.
+ cid = interpreters.channel_create()
+ interpreters.channel_send(cid, b'spam')
+ interpreters.channel_send(cid, b'ham')
+
+ with self.assertRaises(interpreters.ChannelNotEmptyError):
+ interpreters.channel_close(cid, recv=True, send=True)
+ interpreters.channel_recv(cid)
+ interpreters.channel_send(cid, b'eggs')
+ interpreters.channel_recv(cid)
+ interpreters.channel_recv(cid)
+ interpreters.channel_close(cid, recv=True)
+
+ def test_close_recv_with_unused_items_forced(self):
+ # With force=True, close(recv=True) succeeds even though items are
+ # queued; afterwards send and recv both raise ChannelClosedError, so
+ # the queued items are never delivered.
+ cid = interpreters.channel_create()
+ interpreters.channel_send(cid, b'spam')
+ interpreters.channel_send(cid, b'ham')
+ interpreters.channel_close(cid, recv=True, force=True)
+
+ with self.assertRaises(interpreters.ChannelClosedError):
+ interpreters.channel_send(cid, b'eggs')
+ with self.assertRaises(interpreters.ChannelClosedError):
+ interpreters.channel_recv(cid)
+
+ def test_close_send_with_unused_items_forced(self):
+ # With force=True, close(send=True) closes a non-empty channel at
+ # once: unlike the unforced case, even recv raises ChannelClosedError.
+ cid = interpreters.channel_create()
+ interpreters.channel_send(cid, b'spam')
+ interpreters.channel_send(cid, b'ham')
+ interpreters.channel_close(cid, send=True, force=True)
+
+ with self.assertRaises(interpreters.ChannelClosedError):
+ interpreters.channel_send(cid, b'eggs')
+ with self.assertRaises(interpreters.ChannelClosedError):
+ interpreters.channel_recv(cid)
+
+ def test_close_both_with_unused_items_forced(self):
+ # With force=True, closing both ends of a non-empty channel succeeds;
+ # afterwards send and recv both raise ChannelClosedError.
+ cid = interpreters.channel_create()
+ interpreters.channel_send(cid, b'spam')
+ interpreters.channel_send(cid, b'ham')
+ interpreters.channel_close(cid, send=True, recv=True, force=True)
+
+ with self.assertRaises(interpreters.ChannelClosedError):
+ interpreters.channel_send(cid, b'eggs')
 with self.assertRaises(interpreters.ChannelClosedError):
 interpreters.channel_recv(cid)
 
@@ -1403,7 +1495,7 @@ def test_close_by_unassociated_interp(self):
 interp = interpreters.create()
 interpreters.run_string(interp, dedent(f"""
 import _xxsubinterpreters as _interpreters
- _interpreters.channel_close({cid})
+ _interpreters.channel_close({cid}, force=True)
 """))
 with self.assertRaises(interpreters.ChannelClosedError):
 interpreters.channel_recv(cid)
@@ -1416,7 +1508,7 @@ def test_close_used_multiple_times_by_single_user(self):
 interpreters.channel_send(cid, b'spam')
 interpreters.channel_send(cid, b'spam')
 interpreters.channel_recv(cid)
- interpreters.channel_close(cid)
+ interpreters.channel_close(cid, force=True)
 
 with self.assertRaises(interpreters.ChannelClosedError):
 interpreters.channel_send(cid, b'eggs')
diff --git a/Modules/_xxsubinterpretersmodule.c b/Modules/_xxsubinterpretersmodule.c
index 5184f6593db1..72387d8da56b 100644
--- a/Modules/_xxsubinterpretersmodule.c
+++ b/Modules/_xxsubinterpretersmodule.c
@@ -306,10 +306,15 @@ _sharedexception_apply(_sharedexception *exc, PyObject *wrapperclass)
 
 /* channel-specific code ****************************************************/
 
+// Which end(s) of a channel an operation applies to.  channel_close()
+// passes (send - recv); assuming send/recv are 0/1 flags, send-only maps
+// to CHANNEL_SEND, recv-only to CHANNEL_RECV, both/neither to CHANNEL_BOTH.
+#define CHANNEL_SEND 1
+#define CHANNEL_BOTH 0
+#define CHANNEL_RECV -1
+
 static PyObject *ChannelError;
 static PyObject *ChannelNotFoundError;
 static PyObject *ChannelClosedError;
 static PyObject *ChannelEmptyError;
+// Set when a close without force finds items still queued.
+static PyObject *ChannelNotEmptyError;
 
 static int
 channel_exceptions_init(PyObject *ns)
@@ -356,6 +361,16 @@ channel_exceptions_init(PyObject *ns)
 return -1;
 }
 
+ // An operation tried to close a non-empty channel.
+ ChannelNotEmptyError = PyErr_NewException(
+ "_xxsubinterpreters.ChannelNotEmptyError", ChannelError, NULL);
+ if (ChannelNotEmptyError == NULL) {
+ return -1;
+ }
+ if (PyDict_SetItemString(ns, "ChannelNotEmptyError", ChannelNotEmptyError) != 0) {
+ return -1;
+ }
+
 return 0;
 }
 
@@ -696,8 +711,11 @@ _channelends_close_interpreter(_channelends *ends, int64_t interp, int which)
 }
 
 static void
-_channelends_close_all(_channelends *ends)
+_channelends_close_all(_channelends *ends, int which, int force)
 {
+ // XXX Handle the ends.
+ // XXX Handle force is True.
+
 // Ensure all the "send"-associated interpreters are closed.
 _channelend *end;
 for (end = ends->send; end != NULL; end = end->next) {
@@ -713,12 +731,16 @@ _channelends_close_all(_channelends *ends)
 /* channels */
 
 struct _channel;
+struct _channel_closing;
+// Deferred-close helpers; defined further down with the rest of the
+// "closing" support code.
+static void _channel_clear_closing(struct _channel *);
+static void _channel_finish_closing(struct _channel *);
 
 typedef struct _channel {
 PyThread_type_lock mutex;
 _channelqueue *queue;
 _channelends *ends;
 int open;
+ // Non-NULL while a send-only close waits for the queue to drain
+ // (set by _channel_set_closing(), consumed by _channel_finish_closing()).
+ struct _channel_closing *closing;
 } _PyChannelState;
 
 static _PyChannelState *
@@ -747,12 +769,14 @@ _channel_new(void)
 return NULL;
 }
 chan->open = 1;
+ chan->closing = NULL;
 return chan;
 }
 
 static void
 _channel_free(_PyChannelState *chan)
 {
+ _channel_clear_closing(chan);
 PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
 _channelqueue_free(chan->queue);
 _channelends_free(chan->ends);
@@ -802,13 +826,20 @@ _channel_next(_PyChannelState *chan, int64_t interp)
 }
 
 data = _channelqueue_get(chan->queue);
+ if (data == NULL && !PyErr_Occurred() && chan->closing != NULL) {
+ chan->open = 0;
+ }
+
 done:
 PyThread_release_lock(chan->mutex);
+ if (chan->queue->count == 0) {
+ _channel_finish_closing(chan);
+ }
 return data;
 }
 
 static int
-_channel_close_interpreter(_PyChannelState *chan, int64_t interp, int which)
+_channel_close_interpreter(_PyChannelState *chan, int64_t interp, int end)
 {
 PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
 
@@ -818,7 +849,7 @@ _channel_close_interpreter(_PyChannelState *chan, int64_t interp, int which)
 goto done;
 }
 
- if (_channelends_close_interpreter(chan->ends, interp, which) != 0) {
+ if (_channelends_close_interpreter(chan->ends, interp, end) != 0) {
 goto done;
 }
 chan->open = _channelends_is_open(chan->ends);
@@ -830,7 +861,7 @@ _channel_close_interpreter(_PyChannelState *chan, int64_t interp, int which)
 }
 
 static int
-_channel_close_all(_PyChannelState *chan)
+_channel_close_all(_PyChannelState *chan, int end, int force)
 {
 int res = -1;
 PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
@@ -840,11 +871,17 @@ _channel_close_all(_PyChannelState *chan)
 goto done;
 }
 
+ if (!force && chan->queue->count > 0) {
+ PyErr_SetString(ChannelNotEmptyError,
+ "may not be closed if not empty (try force=True)");
+ goto done;
+ }
+
 chan->open = 0;
 
 // We *could* also just leave these in place, since we've marked
 // the channel as closed already.
- _channelends_close_all(chan->ends);
+ _channelends_close_all(chan->ends, end, force);
 
 res = 0;
 done:
@@ -889,6 +926,9 @@ _channelref_new(int64_t id, _PyChannelState *chan)
 static void
 _channelref_free(_channelref *ref)
 {
+ // chan->closing->ref may point back at this ref (see
+ // _channel_set_closing()), so drop that record before freeing the ref.
+ if (ref->chan != NULL) {
+ _channel_clear_closing(ref->chan);
+ }
 //_channelref_clear(ref);
 PyMem_Free(ref);
 }
@@ -1009,8 +1049,12 @@ _channels_add(_channels *channels, _PyChannelState *chan)
 return cid;
 }
 
+/* forward */
+static int _channel_set_closing(struct _channelref *, PyThread_type_lock);
+
 static int
-_channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan)
+_channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan,
+ int end, int force)
 {
 int res = -1;
 PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
@@ -1028,14 +1072,35 @@ _channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan)
 PyErr_Format(ChannelClosedError, "channel %d closed", cid);
 goto done;
 }
+ else if (!force && end == CHANNEL_SEND && ref->chan->closing != NULL) {
+ PyErr_Format(ChannelClosedError, "channel %d closed", cid);
+ goto done;
+ }
 else {
- if (_channel_close_all(ref->chan) != 0) {
+ if (_channel_close_all(ref->chan, end, force) != 0) {
+ if (end == CHANNEL_SEND &&
+ PyErr_ExceptionMatches(ChannelNotEmptyError)) {
+ if (ref->chan->closing != NULL) {
+ PyErr_Format(ChannelClosedError, "channel %d closed", cid);
+ goto done;
+ }
+ // Mark the channel as closing and return. The channel
+ // will be cleaned up in _channel_next().
+ PyErr_Clear();
+ if (_channel_set_closing(ref, channels->mutex) != 0) {
+ goto done;
+ }
+ if (pchan != NULL) {
+ *pchan = ref->chan;
+ }
+ res = 0;
+ }
 goto done;
 }
 if (pchan != NULL) {
 *pchan = ref->chan;
 }
- else {
+ else {
 _channel_free(ref->chan);
 }
 ref->chan = NULL;
@@ -1161,6 +1226,60 @@ _channels_list_all(_channels *channels, int64_t *count)
 return cids;
 }
 
+/* support for closing non-empty channels */
+
+struct _channel_closing {
+ // Channels-table entry whose "chan" is reset to NULL when the
+ // deferred close completes in _channel_finish_closing().
+ struct _channelref *ref;
+};
+
+// Mark the channel behind "ref" as closing: new sends are refused and the
+// channel is torn down by _channel_finish_closing() once its queue drains.
+// Returns 0 on success (including when ref no longer has a channel) and
+// -1 with an exception set on failure.
+static int
+_channel_set_closing(struct _channelref *ref, PyThread_type_lock mutex)
+{
+ // "mutex" is the channels-table lock, which the caller (_channels_close())
+ // already holds; it is not needed here.
+ (void)mutex;
+ struct _channel *chan = ref->chan;
+ if (chan == NULL) {
+ // already closed
+ return 0;
+ }
+ int res = -1;
+ PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
+ if (chan->closing != NULL) {
+ PyErr_SetString(ChannelClosedError, "channel closed");
+ goto done;
+ }
+ chan->closing = PyMem_NEW(struct _channel_closing, 1);
+ if (chan->closing == NULL) {
+ // PyMem_NEW() does not set an exception; without this the caller
+ // would return NULL with no error set.
+ PyErr_NoMemory();
+ goto done;
+ }
+ chan->closing->ref = ref;
+
+ res = 0;
+done:
+ PyThread_release_lock(chan->mutex);
+ return res;
+}
+
+// Discard chan's pending-close record, if there is one.  The record is
+// detached under chan->mutex; freeing a NULL record is a harmless no-op.
+static void
+_channel_clear_closing(struct _channel *chan)
+{
+ PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
+ struct _channel_closing *pending = chan->closing;
+ chan->closing = NULL;
+ PyMem_Free(pending);
+ PyThread_release_lock(chan->mutex);
+}
+
+// Complete a deferred close once the channel's queue has drained:
+// detach the channel from its channels-table ref and free it.
+// Does nothing if no close is pending.
+static void
+_channel_finish_closing(struct _channel *chan)
+{
+ // Take the pending-close record under chan->mutex, as the other
+ // "closing" helpers do, so concurrent callers cannot both observe it.
+ PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
+ struct _channel_closing *closing = chan->closing;
+ chan->closing = NULL;
+ PyThread_release_lock(chan->mutex);
+ if (closing == NULL) {
+ return;
+ }
+ _channelref *ref = closing->ref;
+ PyMem_Free(closing);
+ // Do the things that would have been done in _channels_close().
+ // NOTE(review): ref->chan is written without the channels-table lock;
+ // confirm the caller of _channel_next() holds it here.
+ ref->chan = NULL;
+ _channel_free(chan);
+}
+
 /* "high"-level channel-related functions */
 
 static int64_t
@@ -1207,6 +1326,12 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj)
 }
 // Past this point we are responsible for releasing the mutex.
 
+ if (chan->closing != NULL) {
+ PyErr_Format(ChannelClosedError, "channel %d closed", id);
+ PyThread_release_lock(mutex);
+ return -1;
+ }
+
 // Convert the object to cross-interpreter data.
 _PyCrossInterpreterData *data = PyMem_NEW(_PyCrossInterpreterData, 1);
 if (data == NULL) {
@@ -1290,16 +1415,13 @@ _channel_drop(_channels *channels, int64_t id, int send, int recv)
 }
 
 static int
-_channel_close(_channels *channels, int64_t id)
+_channel_close(_channels *channels, int64_t id, int end, int force)
 {
+ // Close channel "id" in the given table.  "end" is CHANNEL_SEND,
+ // CHANNEL_RECV or CHANNEL_BOTH; without "force", a non-empty channel
+ // is not closed immediately (see _channels_close()).
- return _channels_close(channels, id, NULL);
+ return _channels_close(channels, id, NULL, end, force);
 }
 
 /* ChannelID class */
 
-#define CHANNEL_SEND 1
-#define CHANNEL_RECV -1
-
 static PyTypeObject ChannelIDtype;
 
 typedef struct channelid {
@@ -2555,15 +2677,8 @@ channel_close(PyObject *self, PyObject *args, PyObject *kwds)
 if (cid < 0) {
 return NULL;
 }
- if (send == 0 && recv == 0) {
- send = 1;
- recv = 1;
- }
-
- // XXX Handle the ends.
- // XXX Handle force is True.
 
- if (_channel_close(&_globals.channels, cid) != 0) {
+ if (_channel_close(&_globals.channels, cid, send-recv, force) != 0) {
 return NULL;
 }
 Py_RETURN_NONE;


More information about the Python-checkins mailing list

AltStyle によって変換されたページ (->オリジナル) /