diff --git a/couchdb/http.py b/couchdb/http.py index 4a6750e9..9e21fbd1 100644 --- a/couchdb/http.py +++ b/couchdb/http.py @@ -13,6 +13,7 @@ from base64 import b64encode from datetime import datetime +import os import errno import socket import time @@ -482,17 +483,16 @@ class ConnectionPool(object): def __init__(self, timeout, disable_ssl_verification=False): self.timeout = timeout self.disable_ssl_verification = disable_ssl_verification - self.conns = {} # HTTP connections keyed by (scheme, host) + self.conns = {} # HTTP connections keyed by (os.getpid(), scheme, host) self.lock = Lock() def get(self, url): - scheme, host = util.urlsplit(url, 'http', False)[:2] # Try to reuse an existing connection. self.lock.acquire() try: - conns = self.conns.setdefault((scheme, host), []) + conns = self.conns.setdefault((os.getpid(), scheme, host), []) if conns: conn = conns.pop(-1) else: @@ -520,7 +520,7 @@ def release(self, url, conn): scheme, host = util.urlsplit(url, 'http', False)[:2] self.lock.acquire() try: - self.conns.setdefault((scheme, host), []).append(conn) + self.conns.setdefault((os.getpid(), scheme, host), []).append(conn) finally: self.lock.release() diff --git a/couchdb/tests/client.py b/couchdb/tests/client.py index 5aec3939..dc354d50 100644 --- a/couchdb/tests/client.py +++ b/couchdb/tests/client.py @@ -7,6 +7,8 @@ # you should have received as part of this distribution. from datetime import datetime +import functools +import multiprocessing import os import os.path import shutil @@ -481,7 +483,9 @@ def test_changes_releases_conn(self): # that the HTTP connection made it to the pool. list(self.db.changes(feed='continuous', timeout=0)) scheme, netloc = util.urlsplit(client.DEFAULT_BASE_URL)[:2] - self.assertTrue(self.db.resource.session.connection_pool.conns[(scheme, netloc)]) + current_pid = os.getpid() + key = (current_pid, scheme, netloc) + self.assertTrue(self.db.resource.session.connection_pool.conns[key]) def test_changes_releases_conn_when_lastseq(self): # Consume a changes feed, stopping at the 'last_seq' item, i.e. don't @@ -490,8 +494,10 @@ def test_changes_releases_conn_when_lastseq(self): for obj in self.db.changes(feed='continuous', timeout=0): if 'last_seq' in obj: break + current_pid = os.getpid() scheme, netloc = util.urlsplit(client.DEFAULT_BASE_URL)[:2] - self.assertTrue(self.db.resource.session.connection_pool.conns[(scheme, netloc)]) + key = (current_pid, scheme, netloc) + self.assertTrue(self.db.resource.session.connection_pool.conns[key]) def test_changes_conn_usable(self): # Consume a changes feed to get a used connection in the pool. @@ -838,8 +844,33 @@ def test_startkey(self): def test_nullkeys(self): self.assertEqual(len(list(self.db.iterview('test/nulls', 10))), self.num_docs) + +def _get_by_id(db, result, id): + result.append(db[id]) + + +class TestConcurrent(testutil.TempDatabaseMixin, unittest.TestCase): + def test_concurrent_get(self): + self.db.save({'_id': 'foo', 'value': 'hello'}) + self.db.save({'_id': 'bar', 'value': 'world'}) + processes = [] + result = multiprocessing.Manager().list() + for id in ('foo', 'bar'): + process = multiprocessing.Process(target=functools.partial(_get_by_id, self.db, result), + args=(id,)) + processes.append(process) + process.start() + + for process in processes: + process.join() + + self.assertEqual(len(result), 2) + self.assertEqual(set(['hello', 'world']), set([r['value'] for r in result])) + + def suite(): suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(TestConcurrent, 'test')) suite.addTest(unittest.makeSuite(ServerTestCase, 'test')) suite.addTest(unittest.makeSuite(DatabaseTestCase, 'test')) suite.addTest(unittest.makeSuite(ViewTestCase, 'test'))