diff --git a/influxdb/client.py b/influxdb/client.py index 6c180ce4..b84ccffb 100644 --- a/influxdb/client.py +++ b/influxdb/client.py @@ -763,13 +763,36 @@ def __init__(self, use_udp=use_udp, udp_port=udp_port)) for method in dir(client_base_class): - if method.startswith('_'): + if method.startswith('_') or method in ('switch_database', + 'switch_user'): continue + orig_func = getattr(client_base_class, method) if not callable(orig_func): continue + setattr(self, method, self._make_func(orig_func)) + def switch_database(self, database): + """Change the database of all clients in the cluster. + + :param database: the name of the database to switch to + :type database: str + """ + for client in self.clients + self.bad_clients: + client.switch_database(database) + + def switch_user(self, username, password): + """Change the username of all clients in the cluster. + + :param username: the username to switch to + :type username: str + :param password: the password for the username + :type password: str + """ + for client in self.clients + self.bad_clients: + client.switch_user(username, password) + @staticmethod def from_DSN(dsn, client_base_class=InfluxDBClient, shuffle=True, **kwargs): diff --git a/influxdb/tests/client_test.py b/influxdb/tests/client_test.py index 4f09970d..500320f9 100644 --- a/influxdb/tests/client_test.py +++ b/influxdb/tests/client_test.py @@ -337,12 +337,13 @@ def test_write_points_with_precision_fails(self): cli.write_points_with_precision([]) def test_query(self): - example_response = \ - '{"results": [{"series": [{"measurement": "sdfsdfsdf", ' \ - '"columns": ["time", "value"], "values": ' \ - '[["2009-11-10T23:00:00Z", 0.64]]}]}, {"series": ' \ - '[{"measurement": "cpu_load_short", "columns": ["time", "value"], ' \ + example_response = ( + '{"results": [{"series": [{"measurement": "sdfsdfsdf", ' + '"columns": ["time", "value"], "values": ' + '[["2009-11-10T23:00:00Z", 0.64]]}]}, {"series": ' + '[{"measurement": "cpu_load_short", "columns": ["time", "value"], ' '"values": [["2009-11-10T23:00:00Z", 0.64]]}]}]}' + ) with requests_mock.Mocker() as m: m.register_uri( @@ -804,6 +805,28 @@ def test_recovery(self): self.assertEqual(1, len(cluster.clients)) self.assertEqual(2, len(cluster.bad_clients)) + def test_switch_database(self): + c = InfluxDBClusterClient(hosts=self.hosts, + shuffle=True, + client_base_class=InfluxDBClient) + self.assertEqual(3, len(c.clients)) + self.assertEqual(0, len(c.bad_clients)) + map(lambda x: self.assertEqual(None, x._database), c.clients) + c.switch_database('database') + map(lambda x: self.assertEqual('database', x._database), c.clients) + + def test_switch_user(self): + c = InfluxDBClusterClient(hosts=self.hosts, + shuffle=True, + client_base_class=InfluxDBClient) + self.assertEqual(3, len(c.clients)) + self.assertEqual(0, len(c.bad_clients)) + map(lambda x: self.assertEqual('root', x._username), c.clients) + map(lambda x: self.assertEqual('root', x._password), c.clients) + c.switch_user('username', 'password') + map(lambda x: self.assertEqual('username', x._username), c.clients) + map(lambda x: self.assertEqual('password', x._password), c.clients) + def test_dsn(self): cli = InfluxDBClusterClient.from_DSN(self.dsn_string) self.assertEqual(2, len(cli.clients))