Skip to content

Fix the bug where ConnectionPool cannot be used with multiprocessing #314

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions couchdb/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from base64 import b64encode
from datetime import datetime
import os
import errno
import socket
import time
Expand Down Expand Up @@ -482,17 +483,16 @@ class ConnectionPool(object):
def __init__(self, timeout, disable_ssl_verification=False):
self.timeout = timeout
self.disable_ssl_verification = disable_ssl_verification
self.conns = {} # HTTP connections keyed by (scheme, host)
self.conns = {} # HTTP connections keyed by (os.getpid(), scheme, host)
self.lock = Lock()

def get(self, url):

scheme, host = util.urlsplit(url, 'http', False)[:2]

# Try to reuse an existing connection.
self.lock.acquire()
try:
conns = self.conns.setdefault((scheme, host), [])
conns = self.conns.setdefault((os.getpid(), scheme, host), [])
if conns:
conn = conns.pop(-1)
else:
Expand Down Expand Up @@ -520,7 +520,7 @@ def release(self, url, conn):
scheme, host = util.urlsplit(url, 'http', False)[:2]
self.lock.acquire()
try:
self.conns.setdefault((scheme, host), []).append(conn)
self.conns.setdefault((os.getpid(), scheme, host), []).append(conn)
finally:
self.lock.release()

Expand Down
35 changes: 33 additions & 2 deletions couchdb/tests/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
# you should have received as part of this distribution.

from datetime import datetime
import functools
import multiprocessing
import os
import os.path
import shutil
Expand Down Expand Up @@ -481,7 +483,9 @@ def test_changes_releases_conn(self):
# that the HTTP connection made it to the pool.
list(self.db.changes(feed='continuous', timeout=0))
scheme, netloc = util.urlsplit(client.DEFAULT_BASE_URL)[:2]
self.assertTrue(self.db.resource.session.connection_pool.conns[(scheme, netloc)])
current_pid = os.getpid()
key = (current_pid, scheme, netloc)
self.assertTrue(self.db.resource.session.connection_pool.conns[key])

def test_changes_releases_conn_when_lastseq(self):
# Consume a changes feed, stopping at the 'last_seq' item, i.e. don't
Expand All @@ -490,8 +494,10 @@ def test_changes_releases_conn_when_lastseq(self):
for obj in self.db.changes(feed='continuous', timeout=0):
if 'last_seq' in obj:
break
current_pid = os.getpid()
scheme, netloc = util.urlsplit(client.DEFAULT_BASE_URL)[:2]
self.assertTrue(self.db.resource.session.connection_pool.conns[(scheme, netloc)])
key = (current_pid, scheme, netloc)
self.assertTrue(self.db.resource.session.connection_pool.conns[key])

def test_changes_conn_usable(self):
# Consume a changes feed to get a used connection in the pool.
Expand Down Expand Up @@ -838,8 +844,33 @@ def test_startkey(self):
def test_nullkeys(self):
self.assertEqual(len(list(self.db.iterview('test/nulls', 10))), self.num_docs)


def _get_by_id(db, result, id):
result.append(db[id])


class TestConcurrent(testutil.TempDatabaseMixin, unittest.TestCase):
def test_concurrent_get(self):
self.db.save({'_id': 'foo', 'value': 'hello'})
self.db.save({'_id': 'bar', 'value': 'world'})
processes = []
result = multiprocessing.Manager().list()
for id in ('foo', 'bar'):
process = multiprocessing.Process(target=functools.partial(_get_by_id, self.db, result),
args=(id,))
processes.append(process)
process.start()

for process in processes:
process.join()

self.assertEqual(len(result), 2)
self.assertEqual(set(['hello', 'world']), set([r['value'] for r in result]))


def suite():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(TestConcurrent, 'test'))
suite.addTest(unittest.makeSuite(ServerTestCase, 'test'))
suite.addTest(unittest.makeSuite(DatabaseTestCase, 'test'))
suite.addTest(unittest.makeSuite(ViewTestCase, 'test'))
Expand Down