7
7
# you should have received as part of this distribution.
8
8
9
9
from datetime import datetime
10
+ import functools
11
+ import multiprocessing
10
12
import os
11
13
import os .path
12
14
import shutil
19
21
from couchdb .tests import testutil
20
22
21
23
24
+ def _current_pid ():
25
+ return os .getpid ()
26
+
27
+
22
28
class ServerTestCase (testutil .TempDatabaseMixin , unittest .TestCase ):
23
29
24
30
def test_init_with_resource (self ):
@@ -481,7 +487,9 @@ def test_changes_releases_conn(self):
481
487
# that the HTTP connection made it to the pool.
482
488
list (self .db .changes (feed = 'continuous' , timeout = 0 ))
483
489
scheme , netloc = util .urlsplit (client .DEFAULT_BASE_URL )[:2 ]
484
- self .assertTrue (self .db .resource .session .connection_pool .conns [(scheme , netloc )])
490
+ current_pid = _current_pid ()
491
+ key = (current_pid , scheme , netloc )
492
+ self .assertTrue (self .db .resource .session .connection_pool .conns [key ])
485
493
486
494
def test_changes_releases_conn_when_lastseq (self ):
487
495
# Consume a changes feed, stopping at the 'last_seq' item, i.e. don't
@@ -490,8 +498,10 @@ def test_changes_releases_conn_when_lastseq(self):
490
498
for obj in self .db .changes (feed = 'continuous' , timeout = 0 ):
491
499
if 'last_seq' in obj :
492
500
break
501
+ current_pid = _current_pid ()
493
502
scheme , netloc = util .urlsplit (client .DEFAULT_BASE_URL )[:2 ]
494
- self .assertTrue (self .db .resource .session .connection_pool .conns [(scheme , netloc )])
503
+ key = (current_pid , scheme , netloc )
504
+ self .assertTrue (self .db .resource .session .connection_pool .conns [key ])
495
505
496
506
def test_changes_conn_usable (self ):
497
507
# Consume a changes feed to get a used connection in the pool.
@@ -838,8 +848,33 @@ def test_startkey(self):
838
848
def test_nullkeys (self ):
839
849
self .assertEqual (len (list (self .db .iterview ('test/nulls' , 10 ))), self .num_docs )
840
850
851
+
852
+ def _get_by_id (db , result , id ):
853
+ result .append (db [id ])
854
+
855
+
856
+ class TestConcurrent (testutil .TempDatabaseMixin , unittest .TestCase ):
857
+ def test_concurrent_get (self ):
858
+ self .db .save ({'_id' : 'foo' , 'value' : 'hello' })
859
+ self .db .save ({'_id' : 'bar' , 'value' : 'world' })
860
+ processes = []
861
+ result = multiprocessing .Manager ().list ()
862
+ for id in ('foo' , 'bar' ):
863
+ process = multiprocessing .Process (target = functools .partial (_get_by_id , self .db , result ),
864
+ args = (id ,))
865
+ processes .append (process )
866
+ process .start ()
867
+
868
+ for process in processes :
869
+ process .join ()
870
+
871
+ self .assertEqual (len (result ), 2 )
872
+ self .assertEqual (set (['hello' , 'world' ]), set ([r ['value' ] for r in result ]))
873
+
874
+
841
875
def suite ():
842
876
suite = unittest .TestSuite ()
877
+ suite .addTest (unittest .makeSuite (TestConcurrent , 'test' ))
843
878
suite .addTest (unittest .makeSuite (ServerTestCase , 'test' ))
844
879
suite .addTest (unittest .makeSuite (DatabaseTestCase , 'test' ))
845
880
suite .addTest (unittest .makeSuite (ViewTestCase , 'test' ))
0 commit comments