diff --git a/pymysql/tests/base.py b/pymysql/tests/base.py index 740157b1..e54afee5 100644 --- a/pymysql/tests/base.py +++ b/pymysql/tests/base.py @@ -40,15 +40,33 @@ def mysql_server_is(self, conn, version_tuple): ) return server_version_tuple >= version_tuple - def setUp(self): - self.connections = [] - for params in self.databases: - self.connections.append(pymysql.connect(**params)) - self.addCleanup(self._teardown_connections) + _connections = None + + @property + def connections(self): + if self._connections is None: + self._connections = [] + for params in self.databases: + self._connections.append(pymysql.connect(**params)) + self.addCleanup(self._teardown_connections) + return self._connections + + def connect(self, **params): + p = self.databases[0].copy() + p.update(params) + conn = pymysql.connect(**p) + @self.addCleanup + def teardown(): + if conn.open: + conn.close() + return conn def _teardown_connections(self): - for connection in self.connections: - connection.close() + if self._connections: + for connection in self._connections: + if connection.open: + connection.close() + self._connections = None def safe_create_table(self, connection, tablename, ddl, cleanup=True): """create a table. diff --git a/pymysql/tests/test_SSCursor.py b/pymysql/tests/test_SSCursor.py index e6d6cf53..77eeefa6 100644 --- a/pymysql/tests/test_SSCursor.py +++ b/pymysql/tests/test_SSCursor.py @@ -3,17 +3,19 @@ try: from pymysql.tests import base import pymysql.cursors + from pymysql.constants import CLIENT except Exception: # For local testing from top-level directory, without installing sys.path.append('../pymysql') from pymysql.tests import base import pymysql.cursors + from pymysql.constants import CLIENT class TestSSCursor(base.PyMySQLTestCase): def test_SSCursor(self): affected_rows = 18446744073709551615 - conn = self.connections[0] + conn = self.connect(client_flag=CLIENT.MULTI_STATEMENTS) data = [ ('America', '', 'America/Jamaica'), ('America', '', 'America/Los_Angeles'), @@ -30,10 +32,10 @@ def test_SSCursor(self): cursor = conn.cursor(pymysql.cursors.SSCursor) # Create table - cursor.execute(('CREATE TABLE tz_data (' + cursor.execute('CREATE TABLE tz_data (' 'region VARCHAR(64),' 'zone VARCHAR(64),' - 'name VARCHAR(64))')) + 'name VARCHAR(64))') conn.begin() # Test INSERT @@ -100,7 +102,7 @@ def test_SSCursor(self): self.assertFalse(cursor.nextset()) finally: - cursor.execute('DROP TABLE tz_data') + cursor.execute('DROP TABLE IF EXISTS tz_data') cursor.close() __all__ = ["TestSSCursor"] diff --git a/pymysql/tests/test_connection.py b/pymysql/tests/test_connection.py index 518b6fe7..1fe908ce 100644 --- a/pymysql/tests/test_connection.py +++ b/pymysql/tests/test_connection.py @@ -5,6 +5,7 @@ import pymysql from pymysql.tests import base from pymysql._compat import text_type +from pymysql.constants import CLIENT class TempUser: @@ -411,7 +412,7 @@ def test_connection_gone_away(self): http://dev.mysql.com/doc/refman/5.0/en/gone-away.html http://dev.mysql.com/doc/refman/5.0/en/error-messages-client.html#error_cr_server_gone_error """ - con = self.connections[0] + con = self.connect() cur = con.cursor() cur.execute("SET wait_timeout=1") time.sleep(2) @@ -422,10 +423,9 @@ def test_connection_gone_away(self): self.assertIn(cm.exception.args[0], (2006, 2013)) def test_init_command(self): - conn = pymysql.connect( + conn = self.connect( init_command='SELECT "bar"; SELECT "baz"', - **self.databases[0] - ) + client_flag=CLIENT.MULTI_STATEMENTS) c = conn.cursor() c.execute('select "foobar";') self.assertEqual(('foobar',), c.fetchone()) @@ -434,22 +434,21 @@ def test_init_command(self): conn.ping(reconnect=False) def test_read_default_group(self): - conn = pymysql.connect( + conn = self.connect( read_default_group='client', - **self.databases[0] ) self.assertTrue(conn.open) def test_context(self): with self.assertRaises(ValueError): - c = pymysql.connect(**self.databases[0]) + c = self.connect() with c as cur: cur.execute('create table test ( a int )') c.begin() cur.execute('insert into test values ((1))') raise ValueError('pseudo abort') c.commit() - c = pymysql.connect(**self.databases[0]) + c = self.connect() with c as cur: cur.execute('select count(*) from test') self.assertEqual(0, cur.fetchone()[0]) @@ -460,31 +459,31 @@ def test_context(self): cur.execute('drop table test') def test_set_charset(self): - c = pymysql.connect(**self.databases[0]) + c = self.connect() c.set_charset('utf8') # TODO validate setting here def test_defer_connect(self): import socket - for db in self.databases: - d = db.copy() + + d = self.databases[0].copy() + try: + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.connect(d['unix_socket']) + except KeyError: + sock = socket.create_connection( + (d.get('host', 'localhost'), d.get('port', 3306))) + for k in ['unix_socket', 'host', 'port']: try: - sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sock.connect(d['unix_socket']) + del d[k] except KeyError: - sock = socket.create_connection( - (d.get('host', 'localhost'), d.get('port', 3306))) - for k in ['unix_socket', 'host', 'port']: - try: - del d[k] - except KeyError: - pass - - c = pymysql.connect(defer_connect=True, **d) - self.assertFalse(c.open) - c.connect(sock) - c.close() - sock.close() + pass + + c = pymysql.connect(defer_connect=True, **d) + self.assertFalse(c.open) + c.connect(sock) + c.close() + sock.close() @unittest2.skipUnless(sys.version_info[0:2] >= (3,2), "required py-3.2") def test_no_delay_warning(self): @@ -560,7 +559,9 @@ def test_escape_list_item(self): self.assertEqual(con.escape([Foo()], mapping), "(bar)") def test_previous_cursor_not_closed(self): - con = self.connections[0] + con = self.connect( + init_command='SELECT "bar"; SELECT "baz"', + client_flag=CLIENT.MULTI_STATEMENTS) cur1 = con.cursor() cur1.execute("SELECT 1; SELECT 2") cur2 = con.cursor() @@ -568,7 +569,7 @@ def test_previous_cursor_not_closed(self): self.assertEqual(cur2.fetchone()[0], 3) def test_commit_during_multi_result(self): - con = self.connections[0] + con = self.connect(client_flag=CLIENT.MULTI_STATEMENTS) cur = con.cursor() cur.execute("SELECT 1; SELECT 2") con.commit() diff --git a/pymysql/tests/test_nextset.py b/pymysql/tests/test_nextset.py index cdb6754f..593243e4 100644 --- a/pymysql/tests/test_nextset.py +++ b/pymysql/tests/test_nextset.py @@ -2,16 +2,16 @@ from pymysql.tests import base from pymysql import util +from pymysql.constants import CLIENT class TestNextset(base.PyMySQLTestCase): - def setUp(self): - super(TestNextset, self).setUp() - self.con = self.connections[0] - def test_nextset(self): - cur = self.con.cursor() + con = self.connect( + init_command='SELECT "bar"; SELECT "baz"', + client_flag=CLIENT.MULTI_STATEMENTS) + cur = con.cursor() cur.execute("SELECT 1; SELECT 2;") self.assertEqual([(1,)], list(cur)) @@ -22,7 +22,7 @@ def test_nextset(self): self.assertIsNone(cur.nextset()) def test_skip_nextset(self): - cur = self.con.cursor() + cur = self.connect(client_flag=CLIENT.MULTI_STATEMENTS).cursor() cur.execute("SELECT 1; SELECT 2;") self.assertEqual([(1,)], list(cur)) @@ -30,7 +30,7 @@ def test_skip_nextset(self): self.assertEqual([(42,)], list(cur)) def test_ok_and_next(self): - cur = self.con.cursor() + cur = self.connect(client_flag=CLIENT.MULTI_STATEMENTS).cursor() cur.execute("SELECT 1; commit; SELECT 2;") self.assertEqual([(1,)], list(cur)) self.assertTrue(cur.nextset()) @@ -40,8 +40,9 @@ def test_ok_and_next(self): @unittest2.expectedFailure def test_multi_cursor(self): - cur1 = self.con.cursor() - cur2 = self.con.cursor() + con = self.connect(client_flag=CLIENT.MULTI_STATEMENTS) + cur1 = con.cursor() + cur2 = con.cursor() cur1.execute("SELECT 1; SELECT 2;") cur2.execute("SELECT 42") @@ -56,7 +57,10 @@ def test_multi_cursor(self): self.assertIsNone(cur1.nextset()) def test_multi_statement_warnings(self): - cursor = self.con.cursor() + con = self.connect( + init_command='SELECT "bar"; SELECT "baz"', + client_flag=CLIENT.MULTI_STATEMENTS) + cursor = con.cursor() try: cursor.execute('DROP TABLE IF EXISTS a; '