Skip to content

Commit 8d32704

Browse files
committed
Some incomplete instance tag support.
(Not sure what's missing, I wrote that code a few weeks ago)
1 parent 4bf97e0 commit 8d32704

3 files changed

Lines changed: 183 additions & 115 deletions

File tree

src/potr/context.py

Lines changed: 131 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -61,23 +61,36 @@ def callable(x):
6161
OFFER_REJECTED = 2
6262
OFFER_ACCEPTED = 3
6363

64+
INSTAG_MASTER = 0
65+
INSTAG_BEST = 1
66+
INSTAG_RECENT = 2
67+
INSTAG_RECENT_RECEIVED = 3
68+
INSTAG_RECENT_SENT = 4
69+
MIN_VALID_INSTAG = 0x100
70+
71+
SENT = False
72+
RECEIVED = True
73+
74+
6475
class Context(object):
65-
__slots__ = ['user', 'policy', 'crypto', 'tagOffer', 'lastSend',
66-
'lastMessage', 'mayRetransmit', 'fragment', 'fragmentInfo', 'state',
67-
'inject', 'trust', 'peer', 'trustName']
76+
__slots__ = ['user', 'policy', 'crypto', 'tagOffer', 'lastSent',
77+
'lastMessage', 'mayRetransmit', 'fragment', 'state',
78+
'inject', 'peer', 'trustName', 'master', 'lastRecv',
79+
'recentChild', 'recentRcvdChild', 'recentSentChild']
6880

69-
def __init__(self, account, peername):
81+
def __init__(self, account, peername, instag=INSTAG_MASTER):
7082
self.user = account
7183
self.peer = peername
7284
self.policy = {}
7385
self.crypto = crypt.CryptEngine(self)
74-
self.discardFragment()
7586
self.tagOffer = OFFER_NOTSENT
7687
self.mayRetransmit = 0
77-
self.lastSend = 0
88+
self.lastSent = 0
89+
self.lastRecv = 0
7890
self.lastMessage = None
7991
self.state = STATE_PLAINTEXT
8092
self.trustName = self.peer
93+
self.fragment = FragmentAccumulator()
8194

8295
def getPolicy(self, key):
8396
raise NotImplementedError
@@ -88,51 +101,6 @@ def inject(self, msg, appdata=None):
88101
def policyOtrEnabled(self):
89102
return self.getPolicy('ALLOW_V2') or self.getPolicy('ALLOW_V1')
90103

91-
def discardFragment(self):
92-
self.fragmentInfo = (0, 0)
93-
self.fragment = []
94-
95-
def fragmentAccumulate(self, message):
96-
'''Accumulate a fragmented message. Returns None if the fragment is
97-
to be ignored, returns a string if the message is ready for further
98-
processing'''
99-
100-
params = message.split(b',')
101-
if len(params) < 5 or not params[1].isdigit() or not params[2].isdigit():
102-
logger.warning('invalid formed fragmented message: %r', params)
103-
return None
104-
105-
106-
K, N = self.fragmentInfo
107-
108-
k = int(params[1])
109-
n = int(params[2])
110-
fragData = params[3]
111-
112-
logger.debug(params)
113-
114-
if n >= k == 1:
115-
# first fragment
116-
self.discardFragment()
117-
self.fragmentInfo = (k,n)
118-
self.fragment.append(fragData)
119-
elif N == n >= k > 1 and k == K+1:
120-
# accumulate
121-
self.fragmentInfo = (k,n)
122-
self.fragment.append(fragData)
123-
else:
124-
# bad, discard
125-
self.discardFragment()
126-
logger.warning('invalid fragmented message: %r', params)
127-
return None
128-
129-
if n == k > 0:
130-
assembled = b''.join(self.fragment)
131-
self.discardFragment()
132-
return assembled
133-
134-
return None
135-
136104
def removeFingerprint(self, fingerprint):
137105
self.user.removeFingerprint(self.trustName, fingerprint)
138106

@@ -163,6 +131,15 @@ def getCurrentTrust(self):
163131
return None
164132
return self.getTrust(self.crypto.theirPubkey.cfingerprint(), None)
165133

134+
def updateRecent(self, direction):
135+
self.master.recentChild = self
136+
if direction == SENT:
137+
self.lastSent = time()
138+
self.master.recentSentChild = self
139+
else:
140+
self.lastRecv = time()
141+
self.master.recentRcvdChild = self
142+
166143
def receiveMessage(self, messageData, appdata=None):
167144
IGN = None, []
168145

@@ -176,6 +153,8 @@ def receiveMessage(self, messageData, appdata=None):
176153
return IGN
177154

178155
logger.debug(repr(message))
156+
157+
self.updateRecent(RECEIVED)
179158

180159
if self.getPolicy('SEND_TAG'):
181160
if isinstance(message, basestring):
@@ -218,7 +197,7 @@ def receiveMessage(self, messageData, appdata=None):
218197
try:
219198
plaintext, tlvs = self.crypto.handleDataMessage(message)
220199
self.processTLVs(tlvs, appdata=appdata)
221-
if plaintext and self.lastSend < time() - HEARTBEAT_INTERVAL:
200+
if plaintext and self.lastSent < time() - HEARTBEAT_INTERVAL:
222201
self.sendInternal(b'', appdata=appdata)
223202
return plaintext or None, tlvs
224203
except crypt.InvalidParameterError:
@@ -242,7 +221,7 @@ def sendInternal(self, msg, tlvs=[], appdata=None):
242221

243222
def sendMessage(self, sendPolicy, msg, flags=0, tlvs=[], appdata=None):
244223
if self.policyOtrEnabled():
245-
self.lastSend = time()
224+
self.updateRecent(SENT)
246225

247226
if isinstance(msg, proto.OTRMessage):
248227
# we want to send a protocol message (probably internal)
@@ -270,7 +249,7 @@ def processOutgoingMessage(self, msg, flags, tlvs=[]):
270249
if self.getPolicy('REQUIRE_ENCRYPTION'):
271250
if not isinstance(self.parse(msg), proto.Query):
272251
self.lastMessage = msg
273-
self.lastSend = time()
252+
self.updateRecent(SENT)
274253
self.mayRetransmit = 2
275254
# TODO notify
276255
msg = self.user.getDefaultQueryMessage(self.getPolicy)
@@ -286,7 +265,7 @@ def processOutgoingMessage(self, msg, flags, tlvs=[]):
286265
return msg
287266
if self.state == STATE_ENCRYPTED:
288267
msg = self.crypto.createDataMessage(msg, flags, tlvs)
289-
self.lastSend = time()
268+
self.updateRecent(SENT)
290269
return msg
291270
if self.state == STATE_FINISHED:
292271
raise NotEncryptedError(EXC_FINISHED)
@@ -391,51 +370,7 @@ def authStartV2(self, appdata=None):
391370
self.crypto.startAKE(appdata=appdata)
392371

393372
def parse(self, message):
394-
otrTagPos = message.find(proto.OTRTAG)
395-
if otrTagPos == -1:
396-
if proto.MESSAGE_TAG_BASE in message:
397-
return proto.TaggedPlaintext.parse(message)
398-
else:
399-
return message
400-
401-
indexBase = otrTagPos + len(proto.OTRTAG)
402-
compare = message[indexBase]
403-
404-
if compare == b','[0]:
405-
message = self.fragmentAccumulate(message[indexBase:])
406-
if message is None:
407-
return None
408-
else:
409-
return self.parse(message)
410-
else:
411-
self.discardFragment()
412-
413-
hasq = compare == b'?'[0]
414-
hasv = compare == b'v'[0]
415-
if hasq or hasv:
416-
hasv |= len(message) > indexBase+1 and \
417-
message[indexBase+1] == b'v'[0]
418-
if hasv:
419-
end = message.find(b'?', indexBase+1)
420-
else:
421-
end = indexBase+1
422-
payload = message[indexBase:end]
423-
return proto.Query.parse(payload)
424-
425-
if compare == b':'[0] and len(message) > indexBase + 4:
426-
infoTag = base64.b64decode(message[indexBase+1:indexBase+5])
427-
classInfo = struct.unpack(b'!HB', infoTag)
428-
cls = proto.messageClasses.get(classInfo, None)
429-
if cls is None:
430-
return message
431-
logger.debug('{user} got msg {typ!r}' \
432-
.format(user=self.user.name, typ=cls))
433-
return cls.parsePayload(message[indexBase+5:])
434-
435-
if message[indexBase:indexBase+7] == b' Error:':
436-
return proto.Error(message[indexBase+7:])
437-
438-
return message
373+
return proto.OTRMessage.parse(message, self)
439374

440375
def maxMessageSize(self, appdata=None):
441376
"""Return the max message size for this context."""
@@ -496,12 +431,49 @@ def savePrivkey(self):
496431
def saveTrusts(self):
497432
raise NotImplementedError
498433

499-
def getContext(self, uid, newCtxCb=None):
434+
def getContext(self, uid, instag=INSTAG_MASTER, newCtxCb=None):
500435
if uid not in self.ctxs:
501-
self.ctxs[uid] = self.contextclass(self, uid)
436+
# no master context found, create on first
437+
newctx = self.contextclass(self, uid, instag=INSTAG_MASTER)
438+
439+
newctx.master = newctx
440+
newctx.recentChild = newctx
441+
newctx.recentRcvdChild = newctx
442+
newctx.recentSentChild = newctx
443+
444+
self.ctxs[uid] = { INSTAG_MASTER:newctx }
502445
if callable(newCtxCb):
503-
newCtxCb(self.ctxs[uid])
504-
return self.ctxs[uid]
446+
newCtxCb(newctx)
447+
448+
master = self.ctxs[uid][INSTAG_MASTER]
449+
450+
if instag == INSTAG_MASTER:
451+
return master
452+
453+
elif instag >= MIN_VALID_INSTAG:
454+
if instag not in self.ctxs[uid]:
455+
# no instance context found, create
456+
ctx = self.contextclass(self, uid, instag=instag)
457+
ctx.master = self.ctxs[uid][INSTAG_MASTER]
458+
self.ctxs[uid][instag] = ctx
459+
if callable(newCtxCb):
460+
newCtxCb(ctx)
461+
else:
462+
ctx = self.ctxs[uid][instag]
463+
else:
464+
if instag == INSTAG_RECENT:
465+
ctx = master.recentChild
466+
elif instag == INSTAG_RECENT_RECEIVED:
467+
ctx = master.recentRcvdChild
468+
elif instag == INSTAG_RECENT_SENT:
469+
ctx = master.recentSentChild
470+
elif instag == INSTAG_BEST:
471+
ctx = max(self.ctxs[uid].values(), key=contextMetric)
472+
else:
473+
raise ValueError(
474+
'unknown meta instance tag {tag!r}'.format(tag=instag))
475+
476+
return ctx
505477

506478
def getDefaultQueryMessage(self, policy):
507479
v = '2' if policy('ALLOW_V2') else ''
@@ -523,6 +495,61 @@ def removeFingerprint(self, key, fingerprint):
523495
if key in self.trusts and fingerprint in self.trusts[key]:
524496
del self.trusts[key][fingerprint]
525497

498+
def contextMetric(ctx):
499+
return ctx.state << 65 | int(bool(ctx.getCurrentTrust())) << 64 | ctx.lastRecv
500+
501+
class FragmentAccumulator(object):
502+
def __init__(self):
503+
self.discard()
504+
505+
def discard(self):
506+
self.n = 0
507+
self.k = 0
508+
self.fragments = []
509+
510+
def process(self, message):
511+
'''Accumulate a fragmented message. Returns None if the fragment is
512+
to be ignored, returns a string if the message is ready for further
513+
processing'''
514+
515+
params = message.split(b',', 4)
516+
if len(params) == 1:
517+
# not fragmented
518+
return message
519+
520+
if len(params) != 5 or not params[1].isdigit() or not params[2].isdigit():
521+
logger.warning('invalid formed fragmented message: %r', params)
522+
return None
523+
524+
525+
K, N = self.k, self.n
526+
527+
k = int(params[1])
528+
n = int(params[2])
529+
fragData = params[3]
530+
531+
if n >= k == 1:
532+
# first fragment
533+
self.n = n
534+
self.k = k
535+
self.fragments = [fragData]
536+
elif N == n >= k > 1 and k == K+1:
537+
# accumulate
538+
self.k = k
539+
self.fragments.append(fragData)
540+
else:
541+
# bad, discard
542+
self.discard()
543+
logger.warning('invalid fragmented message: %r', params)
544+
return None
545+
546+
if n == k > 0:
547+
assembled = b''.join(self.fragments)
548+
self.discard()
549+
return assembled
550+
551+
return None
552+
526553
class NotEncryptedError(RuntimeError):
527554
pass
528555
class UnencryptedMessage(RuntimeError):

src/potr/proto.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,48 @@ def __eq__(self, other):
107107
return False
108108
return True
109109

110+
@staticmethod
111+
def parse(data, ctx):
112+
otrTagPos = data.find(OTRTAG)
113+
if otrTagPos == -1:
114+
return TaggedPlaintext.parse(data)
115+
116+
indexBase = otrTagPos + len(OTRTAG)
117+
compare = data[indexBase]
118+
119+
if compare == b','[0]:
120+
data = ctx.fragment.process(data[indexBase:])
121+
if data is None:
122+
return None
123+
return OTRMessage.parse(data, ctx)
124+
else:
125+
ctx.fragment.discard()
126+
127+
hasq = compare == b'?'[0]
128+
hasv = compare == b'v'[0]
129+
if hasq or hasv:
130+
hasv |= len(data) > indexBase+1 and \
131+
data[indexBase+1] == b'v'[0]
132+
if hasv:
133+
end = data.find(b'?', indexBase+1)
134+
else:
135+
end = indexBase+1
136+
payload = data[indexBase:end]
137+
return Query.parse(payload)
138+
139+
if compare == b':'[0] and len(data) > indexBase + 4:
140+
infoTag = base64.b64decode(data[indexBase+1:indexBase+5])
141+
classInfo = struct.unpack(b'!HB', infoTag)
142+
cls = messageClasses.get(classInfo, None)
143+
if cls is None:
144+
return data
145+
return cls.parsePayload(data[indexBase+5:])
146+
147+
if data[indexBase:indexBase+7] == b' Error:':
148+
return Error(data[indexBase+7:])
149+
150+
return data
151+
110152
def __neq__(self, other):
111153
return not self.__eq__(other)
112154

@@ -178,8 +220,7 @@ def __repr__(self):
178220
def parse(cls, data):
179221
tagPos = data.find(MESSAGE_TAG_BASE)
180222
if tagPos < 0:
181-
raise TypeError(
182-
'this is not a tagged plaintext ({0!r:.20})'.format(data))
223+
return data
183224

184225
tags = [ data[i:i+8] for i in range(tagPos, len(data), 8) ]
185226
versions = set([ version for version, tag in MESSAGE_TAGS.items() if tag

0 commit comments

Comments
 (0)