-
Notifications
You must be signed in to change notification settings - Fork 4.3k
/
Copy pathdb.py
272 lines (226 loc) · 9.05 KB
/
db.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import uuid
import time
import json
import datetime
import threading
import logging
from awscli.compat import collections_abc
from botocore.history import BaseHistoryHandler
from awscli.compat import sqlite3
from awscli.compat import binary_type
# Module-level logger, named after this module, used for debug diagnostics.
LOG = logging.getLogger(__name__)
class DatabaseConnection(object):
    """Thin wrapper around a sqlite3 connection holding the records table.

    On construction the connection is opened, the ``records`` table is
    created if it does not already exist, and an attempt is made to switch
    the journal mode to WAL.
    """

    # Schema of the single table used by this module.
    _CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS records (
id TEXT,
request_id TEXT,
source TEXT,
event_type TEXT,
timestamp INTEGER,
payload TEXT
)"""
    _ENABLE_WAL = 'PRAGMA journal_mode=WAL'

    def __init__(self, db_filename):
        """Open ``db_filename`` and make sure the database is set up.

        The connection is opened with ``check_same_thread=False`` and
        ``isolation_level=None``.
        """
        self._connection = sqlite3.connect(
            db_filename,
            check_same_thread=False,
            isolation_level=None,
        )
        self._ensure_database_setup()

    @property
    def row_factory(self):
        """Row factory currently installed on the wrapped connection."""
        return self._connection.row_factory

    @row_factory.setter
    def row_factory(self, row_factory):
        self._connection.row_factory = row_factory

    def execute(self, query, *parameters):
        """Run ``query`` on the wrapped connection and return its cursor.

        Any extra positional arguments are forwarded unchanged to the
        underlying connection's ``execute``.
        """
        return self._connection.execute(query, *parameters)

    def close(self):
        """Close the wrapped sqlite3 connection."""
        self._connection.close()

    def _ensure_database_setup(self):
        # Table first, then the optional journal-mode tweak.
        self._create_record_table()
        self._try_to_enable_wal()

    def _create_record_table(self):
        self.execute(self._CREATE_TABLE)

    def _try_to_enable_wal(self):
        try:
            self.execute(self._ENABLE_WAL)
        except sqlite3.Error:
            # WAL is purely a performance optimisation, so failing to turn
            # it on is not fatal: some systems ship a sqlite build that was
            # compiled without WAL support.
            LOG.debug('Failed to enable sqlite WAL.')
class PayloadSerializer(json.JSONEncoder):
    """JSON encoder for record payloads containing non-JSON-native values.

    Datetimes are written in ISO format, mutable mappings as plain dicts,
    byte strings as their utf-8 text (or the placeholder
    ``'<Byte sequence>'`` when they are not valid utf-8), and any other
    unsupported object as its ``repr``.
    """

    def _encode_mutable_mapping(self, obj):
        """Return a plain ``dict`` copy of the mapping ``obj``."""
        return dict(obj)

    def _encode_datetime(self, obj):
        """Return the ISO-format string for the datetime ``obj``."""
        return obj.isoformat()

    def _try_decode_bytes(self, obj):
        """Decode ``obj`` as utf-8, or return a placeholder on failure."""
        try:
            obj = obj.decode('utf-8')
        except UnicodeDecodeError:
            obj = '<Byte sequence>'
        return obj

    def _remove_non_unicode_stings(self, obj):
        """Recursively replace byte strings in ``obj`` with text.

        Dicts are rebuilt with scrubbed values (keys are left untouched),
        lists and tuples are rebuilt as lists, and every other value is
        returned unchanged.
        """
        # Test against ``binary_type`` rather than ``str``: it is ``str`` on
        # PY2 and ``bytes`` on PY3, i.e. always the type that actually has a
        # ``decode`` method. Checking ``str`` directly would call ``.decode``
        # on text strings under PY3 and raise AttributeError.
        if isinstance(obj, binary_type):
            obj = self._try_decode_bytes(obj)
        elif isinstance(obj, dict):
            obj = dict((k, self._remove_non_unicode_stings(v)) for k, v
                       in obj.items())
        elif isinstance(obj, (list, tuple)):
            obj = [self._remove_non_unicode_stings(o) for o in obj]
        return obj

    def encode(self, obj):
        """Serialize ``obj`` to JSON, scrubbing undecodable bytes if needed."""
        try:
            return super(PayloadSerializer, self).encode(obj)
        except UnicodeDecodeError:
            # This happens in PY2 when a record payload contains binary data
            # that is not valid utf-8. PY2 does not call ``default`` for such
            # a field because it believes the normal string serialization
            # can handle it; it cannot tell a utf-8 str from a str holding
            # raw bytes, so the failure surfaces here at the top level.
            # PY2 offers no hook into the serialization process to fix this,
            # so we walk the structure ourselves, replace every string that
            # cannot be decoded as utf-8, and try to encode again.
            scrubbed_obj = self._remove_non_unicode_stings(obj)
            return super(PayloadSerializer, self).encode(scrubbed_obj)

    def default(self, obj):
        """Return a JSON-serializable stand-in for an unsupported ``obj``."""
        if isinstance(obj, datetime.datetime):
            return self._encode_datetime(obj)
        elif isinstance(obj, collections_abc.MutableMapping):
            return self._encode_mutable_mapping(obj)
        elif isinstance(obj, binary_type):
            # In PY3 the bytes type differs from str, so this method is
            # called whenever a bytes object is encountered. It is either
            # decoded to a utf-8 string and serialization continues, or it
            # is replaced by a placeholder when it is not valid utf-8.
            return self._try_decode_bytes(obj)
        else:
            return repr(obj)
class DatabaseRecordWriter(object):
    """Writes record dicts as rows of the ``records`` table.

    Inserts are issued while holding an internal lock.
    """

    _WRITE_RECORD = """
INSERT INTO records(
id, request_id, source, event_type, timestamp, payload)
VALUES (?,?,?,?,?,?) """

    def __init__(self, connection):
        """Store ``connection`` and create the lock guarding inserts."""
        self._lock = threading.Lock()
        self._connection = connection

    def close(self):
        """Close the underlying connection."""
        self._connection.close()

    def write_record(self, record):
        """Convert ``record`` to a row tuple and insert it."""
        # The row is built before the lock is taken; only the insert itself
        # runs under the lock.
        row = self._create_db_record(record)
        with self._lock:
            self._connection.execute(self._WRITE_RECORD, row)

    def _create_db_record(self, record):
        """Return the tuple of column values for ``record``.

        The payload is serialized to JSON with ``PayloadSerializer``;
        ``request_id`` is optional and becomes ``None`` when absent.
        """
        kind = record['event_type']
        payload_json = json.dumps(record['payload'], cls=PayloadSerializer)
        return (
            record['command_id'],
            record.get('request_id'),
            record['source'],
            kind,
            record['timestamp'],
            payload_json,
        )
class DatabaseRecordReader(object):
    """Reads rows of the ``records`` table back as dictionaries.

    Constructing a reader installs its own row factory on the supplied
    connection, so every row produced through that connection is a dict
    keyed by column name.
    """

    _ORDERING = 'ORDER BY timestamp'
    # All rows sharing the id of the row with the greatest timestamp.
    _GET_LAST_ID_RECORDS = """
SELECT * FROM records
WHERE id =
(SELECT id FROM records WHERE timestamp =
(SELECT max(timestamp) FROM records)) %s;""" % _ORDERING
    _GET_RECORDS_BY_ID = 'SELECT * from records where id = ? %s' % _ORDERING
    # One row per id that has both a CLI_ARGUMENTS and a CLI_RC record,
    # newest first. String values use single quotes: SQLite treats
    # double-quoted tokens as identifiers and only falls back to string
    # literals through a legacy quirk that can be disabled at build time.
    # The join is written against the real columns instead of the result
    # aliases so that it is standard SQL.
    _GET_ALL_RECORDS = (
        'SELECT a.id AS id_a, '
        ' b.id AS id_b, '
        ' a.timestamp as timestamp, '
        ' a.payload AS args, '
        ' b.payload AS rc '
        'FROM records a, records b '
        "where a.event_type = 'CLI_ARGUMENTS' AND "
        " b.event_type = 'CLI_RC' AND "
        ' a.id = b.id '
        '%s DESC' % _ORDERING
    )

    def __init__(self, connection):
        """Store ``connection`` and install the dict row factory on it."""
        self._connection = connection
        self._connection.row_factory = self._row_factory

    def close(self):
        """Close the underlying connection."""
        self._connection.close()

    def _row_factory(self, cursor, row):
        """Build a dict for ``row`` keyed by the cursor's column names.

        A column named ``payload`` is decoded from JSON; all other values
        are passed through unchanged.
        """
        d = {}
        for idx, col in enumerate(cursor.description):
            name = col[0]
            val = row[idx]
            if name == 'payload':
                val = json.loads(val)
            d[name] = val
        return d

    def iter_latest_records(self):
        """Yield the rows sharing the id of the newest row, by timestamp."""
        cursor = self._connection.execute(self._GET_LAST_ID_RECORDS)
        for row in cursor:
            yield row

    def iter_records(self, record_id):
        """Yield the rows whose id equals ``record_id``, by timestamp."""
        cursor = self._connection.execute(self._GET_RECORDS_BY_ID, [record_id])
        for row in cursor:
            yield row

    def iter_all_records(self):
        """Yield one joined arguments/return-code row per id, newest first.

        Each row has the keys ``id_a``, ``id_b``, ``timestamp``, ``args``
        and ``rc``; ``args`` and ``rc`` are the raw stored payload text
        because only a column literally named ``payload`` is JSON-decoded.
        """
        cursor = self._connection.execute(self._GET_ALL_RECORDS)
        for row in cursor:
            yield row
class RecordBuilder(object):
    """Builds record dicts for history events.

    Every record built by one instance carries the same ``command_id``.
    Events belonging to the request lifecycle additionally carry a
    ``request_id`` that is stored per thread and renewed each time an
    ``API_CALL`` event is seen.
    """

    _START_OF_REQUEST_LIFECYCLE_EVENT = 'API_CALL'
    _REQUEST_LIFECYCLE_EVENTS = {
        'API_CALL', 'HTTP_REQUEST', 'HTTP_RESPONSE', 'PARSED_RESPONSE',
    }

    def __init__(self):
        # The identifier is generated lazily on first use.
        self._identifier = None
        # Thread-local storage for the current request id.
        self._locals = threading.local()

    def _get_identifier(self):
        """Return this builder's identifier, generating it on first call."""
        if self._identifier is None:
            self._identifier = str(uuid.uuid4())
        return self._identifier

    def _start_http_lifecycle(self):
        """Store a fresh request id for the calling thread."""
        self._locals.request_id = str(uuid.uuid4())

    def _get_current_thread_request_id(self):
        """Return the calling thread's request id, or ``None`` if unset."""
        return getattr(self._locals, 'request_id', None)

    def _get_request_id(self, event_type):
        """Return the request id that applies to ``event_type``, if any."""
        if event_type == self._START_OF_REQUEST_LIFECYCLE_EVENT:
            self._start_http_lifecycle()
        if event_type not in self._REQUEST_LIFECYCLE_EVENTS:
            return None
        return self._get_current_thread_request_id()

    def build_record(self, event_type, payload, source):
        """Return a record dict describing a single event.

        The record always has ``command_id``, ``event_type``, ``payload``,
        ``source`` and ``timestamp`` (milliseconds, as an int); a
        ``request_id`` key is added only when one applies to the event.
        """
        record = dict(
            command_id=self._get_identifier(),
            event_type=event_type,
            payload=payload,
            source=source,
            timestamp=int(time.time() * 1000),
        )
        current_request = self._get_request_id(event_type)
        if current_request:
            record['request_id'] = current_request
        return record
class DatabaseHistoryHandler(BaseHistoryHandler):
    """History handler that stores each emitted event through a writer."""

    def __init__(self, writer, record_builder):
        """Keep the builder used to create records and the writer for them."""
        self._record_builder = record_builder
        self._writer = writer

    def emit(self, event_type, payload, source):
        """Build a record for the event and pass it to the writer."""
        self._writer.write_record(
            self._record_builder.build_record(event_type, payload, source))