@@ -61,23 +61,36 @@ def callable(x):
6161OFFER_REJECTED = 2
6262OFFER_ACCEPTED = 3
6363
64+ INSTAG_MASTER = 0
65+ INSTAG_BEST = 1
66+ INSTAG_RECENT = 2
67+ INSTAG_RECENT_RECEIVED = 3
68+ INSTAG_RECENT_SENT = 4
69+ MIN_VALID_INSTAG = 0x100
70+
71+ SENT = False
72+ RECEIVED = True
73+
74+
6475class Context (object ):
65- __slots__ = ['user' , 'policy' , 'crypto' , 'tagOffer' , 'lastSend' ,
66- 'lastMessage' , 'mayRetransmit' , 'fragment' , 'fragmentInfo' , 'state' ,
67- 'inject' , 'trust' , 'peer' , 'trustName' ]
76+ __slots__ = ['user' , 'policy' , 'crypto' , 'tagOffer' , 'lastSent' ,
77+ 'lastMessage' , 'mayRetransmit' , 'fragment' , 'state' ,
78+ 'inject' , 'peer' , 'trustName' , 'master' , 'lastRecv' ,
79+ 'recentChild' , 'recentRcvdChild' , 'recentSentChild' ]
6880
69- def __init__ (self , account , peername ):
81+ def __init__ (self , account , peername , instag = INSTAG_MASTER ):
7082 self .user = account
7183 self .peer = peername
7284 self .policy = {}
7385 self .crypto = crypt .CryptEngine (self )
74- self .discardFragment ()
7586 self .tagOffer = OFFER_NOTSENT
7687 self .mayRetransmit = 0
77- self .lastSend = 0
88+ self .lastSent = 0
89+ self .lastRecv = 0
7890 self .lastMessage = None
7991 self .state = STATE_PLAINTEXT
8092 self .trustName = self .peer
93+ self .fragment = FragmentAccumulator ()
8194
8295 def getPolicy (self , key ):
8396 raise NotImplementedError
@@ -88,51 +101,6 @@ def inject(self, msg, appdata=None):
88101 def policyOtrEnabled (self ):
89102 return self .getPolicy ('ALLOW_V2' ) or self .getPolicy ('ALLOW_V1' )
90103
91- def discardFragment (self ):
92- self .fragmentInfo = (0 , 0 )
93- self .fragment = []
94-
95- def fragmentAccumulate (self , message ):
96- '''Accumulate a fragmented message. Returns None if the fragment is
97- to be ignored, returns a string if the message is ready for further
98- processing'''
99-
100- params = message .split (b',' )
101- if len (params ) < 5 or not params [1 ].isdigit () or not params [2 ].isdigit ():
102- logger .warning ('invalid formed fragmented message: %r' , params )
103- return None
104-
105-
106- K , N = self .fragmentInfo
107-
108- k = int (params [1 ])
109- n = int (params [2 ])
110- fragData = params [3 ]
111-
112- logger .debug (params )
113-
114- if n >= k == 1 :
115- # first fragment
116- self .discardFragment ()
117- self .fragmentInfo = (k ,n )
118- self .fragment .append (fragData )
119- elif N == n >= k > 1 and k == K + 1 :
120- # accumulate
121- self .fragmentInfo = (k ,n )
122- self .fragment .append (fragData )
123- else :
124- # bad, discard
125- self .discardFragment ()
126- logger .warning ('invalid fragmented message: %r' , params )
127- return None
128-
129- if n == k > 0 :
130- assembled = b'' .join (self .fragment )
131- self .discardFragment ()
132- return assembled
133-
134- return None
135-
136104 def removeFingerprint (self , fingerprint ):
137105 self .user .removeFingerprint (self .trustName , fingerprint )
138106
@@ -163,6 +131,15 @@ def getCurrentTrust(self):
163131 return None
164132 return self .getTrust (self .crypto .theirPubkey .cfingerprint (), None )
165133
134+ def updateRecent (self , direction ):
135+ self .master .recentChild = self
136+ if direction == SENT :
137+ self .lastSent = time ()
138+ self .master .recentSentChild = self
139+ else :
140+ self .lastRecv = time ()
141+ self .master .recentRcvdChild = self
142+
166143 def receiveMessage (self , messageData , appdata = None ):
167144 IGN = None , []
168145
@@ -176,6 +153,8 @@ def receiveMessage(self, messageData, appdata=None):
176153 return IGN
177154
178155 logger .debug (repr (message ))
156+
157+ self .updateRecent (RECEIVED )
179158
180159 if self .getPolicy ('SEND_TAG' ):
181160 if isinstance (message , basestring ):
@@ -218,7 +197,7 @@ def receiveMessage(self, messageData, appdata=None):
218197 try :
219198 plaintext , tlvs = self .crypto .handleDataMessage (message )
220199 self .processTLVs (tlvs , appdata = appdata )
221- if plaintext and self .lastSend < time () - HEARTBEAT_INTERVAL :
200+ if plaintext and self .lastSent < time () - HEARTBEAT_INTERVAL :
222201 self .sendInternal (b'' , appdata = appdata )
223202 return plaintext or None , tlvs
224203 except crypt .InvalidParameterError :
@@ -242,7 +221,7 @@ def sendInternal(self, msg, tlvs=[], appdata=None):
242221
243222 def sendMessage (self , sendPolicy , msg , flags = 0 , tlvs = [], appdata = None ):
244223 if self .policyOtrEnabled ():
245- self .lastSend = time ( )
224+ self .updateRecent ( SENT )
246225
247226 if isinstance (msg , proto .OTRMessage ):
248227 # we want to send a protocol message (probably internal)
@@ -270,7 +249,7 @@ def processOutgoingMessage(self, msg, flags, tlvs=[]):
270249 if self .getPolicy ('REQUIRE_ENCRYPTION' ):
271250 if not isinstance (self .parse (msg ), proto .Query ):
272251 self .lastMessage = msg
273- self .lastSend = time ( )
252+ self .updateRecent ( SENT )
274253 self .mayRetransmit = 2
275254 # TODO notify
276255 msg = self .user .getDefaultQueryMessage (self .getPolicy )
@@ -286,7 +265,7 @@ def processOutgoingMessage(self, msg, flags, tlvs=[]):
286265 return msg
287266 if self .state == STATE_ENCRYPTED :
288267 msg = self .crypto .createDataMessage (msg , flags , tlvs )
289- self .lastSend = time ( )
268+ self .updateRecent ( SENT )
290269 return msg
291270 if self .state == STATE_FINISHED :
292271 raise NotEncryptedError (EXC_FINISHED )
@@ -391,51 +370,7 @@ def authStartV2(self, appdata=None):
391370 self .crypto .startAKE (appdata = appdata )
392371
393372 def parse (self , message ):
394- otrTagPos = message .find (proto .OTRTAG )
395- if otrTagPos == - 1 :
396- if proto .MESSAGE_TAG_BASE in message :
397- return proto .TaggedPlaintext .parse (message )
398- else :
399- return message
400-
401- indexBase = otrTagPos + len (proto .OTRTAG )
402- compare = message [indexBase ]
403-
404- if compare == b',' [0 ]:
405- message = self .fragmentAccumulate (message [indexBase :])
406- if message is None :
407- return None
408- else :
409- return self .parse (message )
410- else :
411- self .discardFragment ()
412-
413- hasq = compare == b'?' [0 ]
414- hasv = compare == b'v' [0 ]
415- if hasq or hasv :
416- hasv |= len (message ) > indexBase + 1 and \
417- message [indexBase + 1 ] == b'v' [0 ]
418- if hasv :
419- end = message .find (b'?' , indexBase + 1 )
420- else :
421- end = indexBase + 1
422- payload = message [indexBase :end ]
423- return proto .Query .parse (payload )
424-
425- if compare == b':' [0 ] and len (message ) > indexBase + 4 :
426- infoTag = base64 .b64decode (message [indexBase + 1 :indexBase + 5 ])
427- classInfo = struct .unpack (b'!HB' , infoTag )
428- cls = proto .messageClasses .get (classInfo , None )
429- if cls is None :
430- return message
431- logger .debug ('{user} got msg {typ!r}' \
432- .format (user = self .user .name , typ = cls ))
433- return cls .parsePayload (message [indexBase + 5 :])
434-
435- if message [indexBase :indexBase + 7 ] == b' Error:' :
436- return proto .Error (message [indexBase + 7 :])
437-
438- return message
373+ return proto .OTRMessage .parse (message , self )
439374
440375 def maxMessageSize (self , appdata = None ):
441376 """Return the max message size for this context."""
@@ -496,12 +431,49 @@ def savePrivkey(self):
496431 def saveTrusts (self ):
497432 raise NotImplementedError
498433
499- def getContext (self , uid , newCtxCb = None ):
434+ def getContext (self , uid , instag = INSTAG_MASTER , newCtxCb = None ):
500435 if uid not in self .ctxs :
501- self .ctxs [uid ] = self .contextclass (self , uid )
436+ # no master context found, create on first
437+ newctx = self .contextclass (self , uid , instag = INSTAG_MASTER )
438+
439+ newctx .master = newctx
440+ newctx .recentChild = newctx
441+ newctx .recentRcvdChild = newctx
442+ newctx .recentSentChild = newctx
443+
444+ self .ctxs [uid ] = { INSTAG_MASTER :newctx }
502445 if callable (newCtxCb ):
503- newCtxCb (self .ctxs [uid ])
504- return self .ctxs [uid ]
446+ newCtxCb (newctx )
447+
448+ master = self .ctxs [uid ][INSTAG_MASTER ]
449+
450+ if instag == INSTAG_MASTER :
451+ return master
452+
453+ elif instag >= MIN_VALID_INSTAG :
454+ if instag not in self .ctxs [uid ]:
455+ # no instance context found, create
456+ ctx = self .contextclass (self , uid , instag = instag )
457+ ctx .master = self .ctxs [uid ][INSTAG_MASTER ]
458+ self .ctxs [uid ][instag ] = ctx
459+ if callable (newCtxCb ):
460+ newCtxCb (ctx )
461+ else :
462+ ctx = self .ctxs [uid ][instag ]
463+ else :
464+ if instag == INSTAG_RECENT :
465+ ctx = master .recentChild
466+ elif instag == INSTAG_RECENT_RECEIVED :
467+ ctx = master .recentRcvdChild
468+ elif instag == INSTAG_RECENT_SENT :
469+ ctx = master .recentSentChild
470+ elif instag == INSTAG_BEST :
471+ ctx = max (self .ctxs [uid ].values (), key = contextMetric )
472+ else :
473+ raise ValueError (
474+ 'unknown meta instance tag {tag!r}' .format (tag = instag ))
475+
476+ return ctx
505477
506478 def getDefaultQueryMessage (self , policy ):
507479 v = '2' if policy ('ALLOW_V2' ) else ''
@@ -523,6 +495,61 @@ def removeFingerprint(self, key, fingerprint):
523495 if key in self .trusts and fingerprint in self .trusts [key ]:
524496 del self .trusts [key ][fingerprint ]
525497
498+ def contextMetric (ctx ):
499+ return ctx .state << 65 | int (bool (ctx .getCurrentTrust ())) << 64 | ctx .lastRecv
500+
501+ class FragmentAccumulator (object ):
502+ def __init__ (self ):
503+ self .discard ()
504+
505+ def discard (self ):
506+ self .n = 0
507+ self .k = 0
508+ self .fragments = []
509+
510+ def process (self , message ):
511+ '''Accumulate a fragmented message. Returns None if the fragment is
512+ to be ignored, returns a string if the message is ready for further
513+ processing'''
514+
515+ params = message .split (b',' , 4 )
516+ if len (params ) == 1 :
517+ # not fragmented
518+ return message
519+
520+ if len (params ) != 5 or not params [1 ].isdigit () or not params [2 ].isdigit ():
521+ logger .warning ('invalid formed fragmented message: %r' , params )
522+ return None
523+
524+
525+ K , N = self .k , self .n
526+
527+ k = int (params [1 ])
528+ n = int (params [2 ])
529+ fragData = params [3 ]
530+
531+ if n >= k == 1 :
532+ # first fragment
533+ self .n = n
534+ self .k = k
535+ self .fragments = [fragData ]
536+ elif N == n >= k > 1 and k == K + 1 :
537+ # accumulate
538+ self .k = k
539+ self .fragments .append (fragData )
540+ else :
541+ # bad, discard
542+ self .discard ()
543+ logger .warning ('invalid fragmented message: %r' , params )
544+ return None
545+
546+ if n == k > 0 :
547+ assembled = b'' .join (self .fragments )
548+ self .discard ()
549+ return assembled
550+
551+ return None
552+
526553class NotEncryptedError (RuntimeError ):
527554 pass
528555class UnencryptedMessage (RuntimeError ):
0 commit comments