2626
2727#include <stdio.h>
2828#include <stdint.h>
29+ #include <string.h>
30+ #include <errno.h>
2931
3032#include "py/nlr.h"
3133#include "py/obj.h"
@@ -38,11 +40,13 @@ enum { FRAME_HEADER, FRAME_OPT, PAYLOAD };
3840typedef struct _mp_obj_websocket_t {
3941 mp_obj_base_t base ;
4042 mp_obj_t sock ;
41- uint32_t mask ;
43+ uint32_t msg_sz ;
44+ byte mask [4 ];
4245 byte state ;
4346 byte to_recv ;
4447 byte mask_pos ;
45- byte buf [4 ];
48+ byte buf_pos ;
49+ byte buf [6 ];
4650} mp_obj_websocket_t ;
4751
4852STATIC mp_obj_t websocket_make_new (const mp_obj_type_t * type , size_t n_args , size_t n_kw , const mp_obj_t * args ) {
@@ -53,9 +57,95 @@ STATIC mp_obj_t websocket_make_new(const mp_obj_type_t *type, size_t n_args, siz
5357 o -> state = FRAME_HEADER ;
5458 o -> to_recv = 2 ;
5559 o -> mask_pos = 0 ;
60+ o -> buf_pos = 0 ;
5661 return o ;
5762}
5863
64+ STATIC mp_uint_t websocket_read (mp_obj_t self_in , void * buf , mp_uint_t size , int * errcode ) {
65+ mp_obj_websocket_t * self = self_in ;
66+ const mp_stream_p_t * stream_p = mp_get_stream_raise (self -> sock , MP_STREAM_OP_READ );
67+ while (1 ) {
68+ if (self -> to_recv != 0 ) {
69+ mp_uint_t out_sz = stream_p -> read (self -> sock , self -> buf + self -> buf_pos , self -> to_recv , errcode );
70+ if (out_sz == MP_STREAM_ERROR ) {
71+ return out_sz ;
72+ }
73+ self -> buf_pos += out_sz ;
74+ self -> to_recv -= out_sz ;
75+ if (self -> to_recv != 0 ) {
76+ * errcode = EAGAIN ;
77+ return MP_STREAM_ERROR ;
78+ }
79+ }
80+
81+ switch (self -> state ) {
82+ case FRAME_HEADER : {
83+ assert (self -> buf [0 ] & 0x80 );
84+ int to_recv = 0 ;
85+ size_t sz = self -> buf [1 ] & 0x7f ;
86+ if (sz == 126 ) {
87+ // Msg size is next 2 bytes
88+ to_recv += 2 ;
89+ } else if (sz == 127 ) {
90+ // Msg size is next 2 bytes
91+ assert (0 );
92+ }
93+ if (self -> buf [1 ] & 0x80 ) {
94+ // Next 4 bytes is mask
95+ to_recv += 4 ;
96+ }
97+
98+ self -> buf_pos = 0 ;
99+ self -> to_recv = to_recv ;
100+ self -> msg_sz = sz ; // May be overriden by FRAME_OPT
101+ if (to_recv != 0 ) {
102+ self -> state = FRAME_OPT ;
103+ } else {
104+ self -> state = PAYLOAD ;
105+ }
106+ continue ;
107+ }
108+
109+ case FRAME_OPT : {
110+ if ((self -> buf_pos & 3 ) == 2 ) {
111+ // First two bytes are message length
112+ self -> msg_sz = (self -> buf [0 ] << 8 ) | self -> buf [1 ];
113+ }
114+ if (self -> buf_pos >= 4 ) {
115+ // Last 4 bytes is mask
116+ memcpy (self -> mask , self -> buf + self -> buf_pos - 4 , 4 );
117+ }
118+ self -> buf_pos = 0 ;
119+ self -> state = PAYLOAD ;
120+ continue ;
121+ }
122+
123+ case PAYLOAD : {
124+ size_t sz = MIN (size , self -> msg_sz );
125+ mp_uint_t out_sz = stream_p -> read (self -> sock , buf , sz , errcode );
126+ if (out_sz == MP_STREAM_ERROR ) {
127+ return out_sz ;
128+ }
129+
130+ sz = out_sz ;
131+ for (byte * p = buf ; sz -- ; p ++ ) {
132+ * p ^= self -> mask [self -> mask_pos ++ & 3 ];
133+ }
134+
135+ self -> msg_sz -= out_sz ;
136+ if (self -> msg_sz == 0 ) {
137+ self -> state = FRAME_HEADER ;
138+ self -> to_recv = 2 ;
139+ self -> mask_pos = 0 ;
140+ self -> buf_pos = 0 ;
141+ }
142+ return out_sz ;
143+ }
144+
145+ }
146+ }
147+ }
148+
59149STATIC mp_uint_t websocket_write (mp_obj_t self_in , const void * buf , mp_uint_t size , int * errcode ) {
60150 mp_obj_websocket_t * self = self_in ;
61151 assert (size < 126 );
@@ -69,12 +159,13 @@ STATIC mp_uint_t websocket_write(mp_obj_t self_in, const void *buf, mp_uint_t si
69159}
70160
71161STATIC const mp_map_elem_t websocket_locals_dict_table [] = {
162+ { MP_OBJ_NEW_QSTR (MP_QSTR_read ), (mp_obj_t )& mp_stream_read_obj },
72163 { MP_OBJ_NEW_QSTR (MP_QSTR_write ), (mp_obj_t )& mp_stream_write_obj },
73164};
74165STATIC MP_DEFINE_CONST_DICT (websocket_locals_dict , websocket_locals_dict_table );
75166
76167STATIC const mp_stream_p_t websocket_stream_p = {
77- // .read = websocket_read,
168+ .read = websocket_read ,
78169 .write = websocket_write ,
79170};
80171
0 commit comments