From 8d32704ed1b0c11fd55fd69b81f05c5576b1682a Mon Sep 17 00:00:00 2001 From: Kjell Braden Date: Thu, 21 Mar 2013 19:46:25 +0100 Subject: [PATCH] Some incomplete instance tag support. (Not sure what's missing, I wrote that code a few weeks ago) --- src/potr/context.py | 235 ++++++++++++++++++++++++-------------------- src/potr/proto.py | 45 ++++++++- tests/testBasic.py | 18 ++-- 3 files changed, 183 insertions(+), 115 deletions(-) diff --git a/src/potr/context.py b/src/potr/context.py index 6bc191f..56c1b86 100644 --- a/src/potr/context.py +++ b/src/potr/context.py @@ -61,23 +61,36 @@ def callable(x): OFFER_REJECTED = 2 OFFER_ACCEPTED = 3 +INSTAG_MASTER = 0 +INSTAG_BEST = 1 +INSTAG_RECENT = 2 +INSTAG_RECENT_RECEIVED = 3 +INSTAG_RECENT_SENT = 4 +MIN_VALID_INSTAG = 0x100 + +SENT = False +RECEIVED = True + + class Context(object): - __slots__ = ['user', 'policy', 'crypto', 'tagOffer', 'lastSend', - 'lastMessage', 'mayRetransmit', 'fragment', 'fragmentInfo', 'state', - 'inject', 'trust', 'peer', 'trustName'] + __slots__ = ['user', 'policy', 'crypto', 'tagOffer', 'lastSent', + 'lastMessage', 'mayRetransmit', 'fragment', 'state', + 'inject', 'peer', 'trustName', 'master', 'lastRecv', + 'recentChild', 'recentRcvdChild', 'recentSentChild'] - def __init__(self, account, peername): + def __init__(self, account, peername, instag=INSTAG_MASTER): self.user = account self.peer = peername self.policy = {} self.crypto = crypt.CryptEngine(self) - self.discardFragment() self.tagOffer = OFFER_NOTSENT self.mayRetransmit = 0 - self.lastSend = 0 + self.lastSent = 0 + self.lastRecv = 0 self.lastMessage = None self.state = STATE_PLAINTEXT self.trustName = self.peer + self.fragment = FragmentAccumulator() def getPolicy(self, key): raise NotImplementedError @@ -88,51 +101,6 @@ def inject(self, msg, appdata=None): def policyOtrEnabled(self): return self.getPolicy('ALLOW_V2') or self.getPolicy('ALLOW_V1') - def discardFragment(self): - self.fragmentInfo = (0, 0) - self.fragment = [] - - def fragmentAccumulate(self, message): - '''Accumulate a fragmented message. Returns None if the fragment is - to be ignored, returns a string if the message is ready for further - processing''' - - params = message.split(b',') - if len(params) < 5 or not params[1].isdigit() or not params[2].isdigit(): - logger.warning('invalid formed fragmented message: %r', params) - return None - - - K, N = self.fragmentInfo - - k = int(params[1]) - n = int(params[2]) - fragData = params[3] - - logger.debug(params) - - if n >= k == 1: - # first fragment - self.discardFragment() - self.fragmentInfo = (k,n) - self.fragment.append(fragData) - elif N == n >= k > 1 and k == K+1: - # accumulate - self.fragmentInfo = (k,n) - self.fragment.append(fragData) - else: - # bad, discard - self.discardFragment() - logger.warning('invalid fragmented message: %r', params) - return None - - if n == k > 0: - assembled = b''.join(self.fragment) - self.discardFragment() - return assembled - - return None - def removeFingerprint(self, fingerprint): self.user.removeFingerprint(self.trustName, fingerprint) @@ -163,6 +131,15 @@ def getCurrentTrust(self): return None return self.getTrust(self.crypto.theirPubkey.cfingerprint(), None) + def updateRecent(self, direction): + self.master.recentChild = self + if direction == SENT: + self.lastSent = time() + self.master.recentSentChild = self + else: + self.lastRecv = time() + self.master.recentRcvdChild = self + def receiveMessage(self, messageData, appdata=None): IGN = None, [] @@ -176,6 +153,8 @@ def receiveMessage(self, messageData, appdata=None): return IGN logger.debug(repr(message)) + + self.updateRecent(RECEIVED) if self.getPolicy('SEND_TAG'): if isinstance(message, basestring): @@ -218,7 +197,7 @@ def receiveMessage(self, messageData, appdata=None): try: plaintext, tlvs = self.crypto.handleDataMessage(message) self.processTLVs(tlvs, appdata=appdata) - if plaintext and self.lastSend < time() - HEARTBEAT_INTERVAL: + if plaintext and self.lastSent < time() - HEARTBEAT_INTERVAL: self.sendInternal(b'', appdata=appdata) return plaintext or None, tlvs except crypt.InvalidParameterError: @@ -242,7 +221,7 @@ def sendInternal(self, msg, tlvs=[], appdata=None): def sendMessage(self, sendPolicy, msg, flags=0, tlvs=[], appdata=None): if self.policyOtrEnabled(): - self.lastSend = time() + self.updateRecent(SENT) if isinstance(msg, proto.OTRMessage): # we want to send a protocol message (probably internal) @@ -270,7 +249,7 @@ def processOutgoingMessage(self, msg, flags, tlvs=[]): if self.getPolicy('REQUIRE_ENCRYPTION'): if not isinstance(self.parse(msg), proto.Query): self.lastMessage = msg - self.lastSend = time() + self.updateRecent(SENT) self.mayRetransmit = 2 # TODO notify msg = self.user.getDefaultQueryMessage(self.getPolicy) @@ -286,7 +265,7 @@ def processOutgoingMessage(self, msg, flags, tlvs=[]): return msg if self.state == STATE_ENCRYPTED: msg = self.crypto.createDataMessage(msg, flags, tlvs) - self.lastSend = time() + self.updateRecent(SENT) return msg if self.state == STATE_FINISHED: raise NotEncryptedError(EXC_FINISHED) @@ -391,51 +370,7 @@ def authStartV2(self, appdata=None): self.crypto.startAKE(appdata=appdata) def parse(self, message): - otrTagPos = message.find(proto.OTRTAG) - if otrTagPos == -1: - if proto.MESSAGE_TAG_BASE in message: - return proto.TaggedPlaintext.parse(message) - else: - return message - - indexBase = otrTagPos + len(proto.OTRTAG) - compare = message[indexBase] - - if compare == b','[0]: - message = self.fragmentAccumulate(message[indexBase:]) - if message is None: - return None - else: - return self.parse(message) - else: - self.discardFragment() - - hasq = compare == b'?'[0] - hasv = compare == b'v'[0] - if hasq or hasv: - hasv |= len(message) > indexBase+1 and \ - message[indexBase+1] == b'v'[0] - if hasv: - end = message.find(b'?', indexBase+1) - else: - end = indexBase+1 - payload = message[indexBase:end] - return proto.Query.parse(payload) - - if compare == b':'[0] and len(message) > indexBase + 4: - infoTag = base64.b64decode(message[indexBase+1:indexBase+5]) - classInfo = struct.unpack(b'!HB', infoTag) - cls = proto.messageClasses.get(classInfo, None) - if cls is None: - return message - logger.debug('{user} got msg {typ!r}' \ - .format(user=self.user.name, typ=cls)) - return cls.parsePayload(message[indexBase+5:]) - - if message[indexBase:indexBase+7] == b' Error:': - return proto.Error(message[indexBase+7:]) - - return message + return proto.OTRMessage.parse(message, self) def maxMessageSize(self, appdata=None): """Return the max message size for this context.""" @@ -496,12 +431,49 @@ def savePrivkey(self): def saveTrusts(self): raise NotImplementedError - def getContext(self, uid, newCtxCb=None): + def getContext(self, uid, instag=INSTAG_MASTER, newCtxCb=None): if uid not in self.ctxs: - self.ctxs[uid] = self.contextclass(self, uid) + # no master context found, create on first + newctx = self.contextclass(self, uid, instag=INSTAG_MASTER) + + newctx.master = newctx + newctx.recentChild = newctx + newctx.recentRcvdChild = newctx + newctx.recentSentChild = newctx + + self.ctxs[uid] = { INSTAG_MASTER:newctx } if callable(newCtxCb): - newCtxCb(self.ctxs[uid]) - return self.ctxs[uid] + newCtxCb(newctx) + + master = self.ctxs[uid][INSTAG_MASTER] + + if instag == INSTAG_MASTER: + return master + + elif instag >= MIN_VALID_INSTAG: + if instag not in self.ctxs[uid]: + # no instance context found, create + ctx = self.contextclass(self, uid, instag=instag) + ctx.master = self.ctxs[uid][INSTAG_MASTER] + self.ctxs[uid][instag] = ctx + if callable(newCtxCb): + newCtxCb(ctx) + else: + ctx = self.ctxs[uid][instag] + else: + if instag == INSTAG_RECENT: + ctx = master.recentChild + elif instag == INSTAG_RECENT_RECEIVED: + ctx = master.recentRcvdChild + elif instag == INSTAG_RECENT_SENT: + ctx = master.recentSentChild + elif instag == INSTAG_BEST: + ctx = max(self.ctxs[uid].values(), key=contextMetric) + else: + raise ValueError( + 'unknown meta instance tag {tag!r}'.format(tag=instag)) + + return ctx def getDefaultQueryMessage(self, policy): v = '2' if policy('ALLOW_V2') else '' @@ -523,6 +495,61 @@ def removeFingerprint(self, key, fingerprint): if key in self.trusts and fingerprint in self.trusts[key]: del self.trusts[key][fingerprint] +def contextMetric(ctx): + return ctx.state << 65 | int(bool(ctx.getCurrentTrust())) << 64 | ctx.lastRecv + +class FragmentAccumulator(object): + def __init__(self): + self.discard() + + def discard(self): + self.n = 0 + self.k = 0 + self.fragments = [] + + def process(self, message): + '''Accumulate a fragmented message. Returns None if the fragment is + to be ignored, returns a string if the message is ready for further + processing''' + + params = message.split(b',', 4) + if len(params) == 1: + # not fragmented + return message + + if len(params) != 5 or not params[1].isdigit() or not params[2].isdigit(): + logger.warning('invalid formed fragmented message: %r', params) + return None + + + K, N = self.k, self.n + + k = int(params[1]) + n = int(params[2]) + fragData = params[3] + + if n >= k == 1: + # first fragment + self.n = n + self.k = k + self.fragments = [fragData] + elif N == n >= k > 1 and k == K+1: + # accumulate + self.k = k + self.fragments.append(fragData) + else: + # bad, discard + self.discard() + logger.warning('invalid fragmented message: %r', params) + return None + + if n == k > 0: + assembled = b''.join(self.fragments) + self.discard() + return assembled + + return None + class NotEncryptedError(RuntimeError): pass class UnencryptedMessage(RuntimeError): diff --git a/src/potr/proto.py b/src/potr/proto.py index 9716f2c..036d3eb 100644 --- a/src/potr/proto.py +++ b/src/potr/proto.py @@ -107,6 +107,48 @@ def __eq__(self, other): return False return True + @staticmethod + def parse(data, ctx): + otrTagPos = data.find(OTRTAG) + if otrTagPos == -1: + return TaggedPlaintext.parse(data) + + indexBase = otrTagPos + len(OTRTAG) + compare = data[indexBase] + + if compare == b','[0]: + data = ctx.fragment.process(data[indexBase:]) + if data is None: + return None + return OTRMessage.parse(data, ctx) + else: + ctx.fragment.discard() + + hasq = compare == b'?'[0] + hasv = compare == b'v'[0] + if hasq or hasv: + hasv |= len(data) > indexBase+1 and \ + data[indexBase+1] == b'v'[0] + if hasv: + end = data.find(b'?', indexBase+1) + else: + end = indexBase+1 + payload = data[indexBase:end] + return Query.parse(payload) + + if compare == b':'[0] and len(data) > indexBase + 4: + infoTag = base64.b64decode(data[indexBase+1:indexBase+5]) + classInfo = struct.unpack(b'!HB', infoTag) + cls = messageClasses.get(classInfo, None) + if cls is None: + return data + return cls.parsePayload(data[indexBase+5:]) + + if data[indexBase:indexBase+7] == b' Error:': + return Error(data[indexBase+7:]) + + return data + def __neq__(self, other): return not self.__eq__(other) @@ -178,8 +220,7 @@ def __repr__(self): def parse(cls, data): tagPos = data.find(MESSAGE_TAG_BASE) if tagPos < 0: - raise TypeError( - 'this is not a tagged plaintext ({0!r:.20})'.format(data)) + return data tags = [ data[i:i+8] for i in range(tagPos, len(data), 8) ] versions = set([ version for version, tag in MESSAGE_TAGS.items() if tag diff --git a/tests/testBasic.py b/tests/testBasic.py index e2261ac..2fd9c80 100644 --- a/tests/testBasic.py +++ b/tests/testBasic.py @@ -5,9 +5,18 @@ import unittest import base64 from potr import proto +from potr import utils class ProtoTest(unittest.TestCase): + def testLongToBytes(self): + self.assertEqual(b'\xde\xad\xbe\xef', + utils.long_to_bytes(0xdeadbeef)) + self.assertEqual(b'\0\0\0\0\0\0\xde\xad\xbe\xef', + utils.long_to_bytes(0xdeadbeef, 10)) + self.assertEqual(b'', utils.long_to_bytes(0x00)) + self.assertEqual(b'\0\0\0\0\0\0\0\0\0\0', utils.long_to_bytes(0x00, 10)) + def testPackData(self): self.assertEqual(b'\0\0\0\0', proto.pack_data(b'')) self.assertEqual(b'\0\0\0\x0afoobarbazx', proto.pack_data(b'foobarbazx')) @@ -85,15 +94,6 @@ def testQuery(self): b'\x20\x09\x20\x20\x09\x09\x09\x09\x20\x09\x20\x09\x20\x09\x20\x20', set([])) - # untagged - self.assertRaises(TypeError, - lambda: proto.TaggedPlaintext.parse(b'Foobarbaz?')) - - # only the version tag without base - self.assertRaises(TypeError, - lambda: proto.TaggedPlaintext.parse(b'Foobarbaz!' - + b'\x20\x09\x20\x09\x20\x20\x09\x20')) - def testGenericMsg(self): msg = base64.b64encode(proto.pack_data(b'foo')) self.assertEqual(b'foo', proto.DHKey.parsePayload(msg).gy)