Skip to content

Commit

Permalink
Merge pull request #9318 from jepler/sslsocket-stream-protocol
Browse files Browse the repository at this point in the history
SSLSocket: Add stream protocol
  • Loading branch information
jepler authored Jun 14, 2024
2 parents 03e42a8 + 7c85f6a commit ed5591c
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 8 deletions.
6 changes: 6 additions & 0 deletions shared-bindings/audiomp3/MP3Decoder.c
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@
//| decoder.file = stream
//|
//| If the stream is played with ``loop = True``, the loop will start at the beginning.
//|
//| It is possible to stream an mp3 from a socket, including a secure socket.
//| The MP3Decoder may change the timeout and non-blocking status of the socket.
//| Using a larger decode buffer with a stream can be helpful to avoid data underruns.
//| An ``adafruit_requests`` request must be made with ``headers={"Connection": "close"}`` so
//| that the socket closes when the stream ends.
//| """
//| ...

Expand Down
67 changes: 64 additions & 3 deletions shared-bindings/ssl/SSLSocket.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
#include <string.h>

#include "shared/runtime/context_manager_helpers.h"
#include "py/objtuple.h"
#include "py/mperrno.h"
#include "py/objlist.h"
#include "py/objtuple.h"
#include "py/runtime.h"
#include "py/mperrno.h"
#include "py/stream.h"

#include "shared/netutils/netutils.h"

Expand Down Expand Up @@ -247,9 +248,69 @@ static const mp_rom_map_elem_t ssl_sslsocket_locals_dict_table[] = {

static MP_DEFINE_CONST_DICT(ssl_sslsocket_locals_dict, ssl_sslsocket_locals_dict_table);

typedef mp_uint_t (*readwrite_func)(ssl_sslsocket_obj_t *, const uint8_t *, mp_uint_t);

static mp_int_t readwrite_common(mp_obj_t self_in, readwrite_func fn, const uint8_t *buf, size_t size, int *errorcode) {
ssl_sslsocket_obj_t *self = MP_OBJ_TO_PTR(self_in);
mp_int_t ret = -EIO;
nlr_buf_t nlr;
if (nlr_push(&nlr) == 0) {
ret = fn(self, buf, size);
nlr_pop();
} else {
mp_obj_t exc = MP_OBJ_FROM_PTR(nlr.ret_val);
if (nlr_push(&nlr) == 0) {
ret = -mp_obj_get_int(mp_load_attr(exc, MP_QSTR_errno));
nlr_pop();
}
}
if (ret < 0) {
*errorcode = -ret;
return MP_STREAM_ERROR;
}
return ret;
}

static mp_uint_t sslsocket_read(mp_obj_t self_in, void *buf, mp_uint_t size, int *errorcode) {
return readwrite_common(self_in, (readwrite_func)common_hal_ssl_sslsocket_recv_into, buf, size, errorcode);
}

static mp_uint_t sslsocket_write(mp_obj_t self_in, const void *buf, mp_uint_t size, int *errorcode) {
return readwrite_common(self_in, common_hal_ssl_sslsocket_send, buf, size, errorcode);
}

static mp_uint_t sslsocket_ioctl(mp_obj_t self_in, mp_uint_t request, mp_uint_t arg, int *errcode) {
ssl_sslsocket_obj_t *self = MP_OBJ_TO_PTR(self_in);
mp_uint_t ret;
if (request == MP_STREAM_POLL) {
mp_uint_t flags = arg;
ret = 0;
if ((flags & MP_STREAM_POLL_RD) && common_hal_ssl_sslsocket_readable(self) > 0) {
ret |= MP_STREAM_POLL_RD;
}
if ((flags & MP_STREAM_POLL_WR) && common_hal_ssl_sslsocket_writable(self)) {
ret |= MP_STREAM_POLL_WR;
}
} else {
*errcode = MP_EINVAL;
ret = MP_STREAM_ERROR;
}
return ret;
}


static const mp_stream_p_t sslsocket_stream_p = {
.read = sslsocket_read,
.write = sslsocket_write,
.ioctl = sslsocket_ioctl,
.is_text = false,
};


MP_DEFINE_CONST_OBJ_TYPE(
ssl_sslsocket_type,
MP_QSTR_SSLSocket,
MP_TYPE_FLAG_NONE,
locals_dict, &ssl_sslsocket_locals_dict
locals_dict, &ssl_sslsocket_locals_dict,
protocol, &sslsocket_stream_p
);
6 changes: 4 additions & 2 deletions shared-bindings/ssl/SSLSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ void common_hal_ssl_sslsocket_close(ssl_sslsocket_obj_t *self);
void common_hal_ssl_sslsocket_connect(ssl_sslsocket_obj_t *self, mp_obj_t addr);
bool common_hal_ssl_sslsocket_get_closed(ssl_sslsocket_obj_t *self);
bool common_hal_ssl_sslsocket_get_connected(ssl_sslsocket_obj_t *self);
bool common_hal_ssl_sslsocket_readable(ssl_sslsocket_obj_t *self);
bool common_hal_ssl_sslsocket_writable(ssl_sslsocket_obj_t *self);
void common_hal_ssl_sslsocket_listen(ssl_sslsocket_obj_t *self, int backlog);
mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t *buf, uint32_t len);
mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t *buf, uint32_t len);
mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t *buf, mp_uint_t len);
mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t *buf, mp_uint_t len);
void common_hal_ssl_sslsocket_settimeout(ssl_sslsocket_obj_t *self, mp_obj_t timeout_obj);
void common_hal_ssl_sslsocket_setsockopt(ssl_sslsocket_obj_t *self, mp_obj_t level, mp_obj_t optname, mp_obj_t optval);
22 changes: 21 additions & 1 deletion shared-module/audiomp3/MP3Decoder.c
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,18 @@ static off_t stream_lseek(void *stream, off_t offset, int whence) {
#define INPUT_BUFFER_CONSUME(i, n) ((i).read_off += (n))
#define INPUT_BUFFER_CLEAR(i) ((i).read_off = (i).write_off = 0)

static void stream_set_blocking(audiomp3_mp3file_obj_t *self, bool block_ok) {
if (!self->settimeout_args[0]) {
return;
}
if (block_ok == self->block_ok) {
return;
}
self->block_ok = block_ok;
self->settimeout_args[2] = block_ok ? mp_const_none : mp_obj_new_int(0);
mp_call_method_n_kw(1, 0, self->settimeout_args);
}

/** Fill the input buffer unconditionally.
*
* Returns true if the input buffer contains any useful data,
Expand All @@ -110,6 +122,8 @@ static bool mp3file_update_inbuf_always(audiomp3_mp3file_obj_t *self, bool block
return INPUT_BUFFER_AVAILABLE(self->inbuf) > 0;
}

stream_set_blocking(self, block_ok);

// We didn't previously reach EOF and we have input buffer space available

// Move the unconsumed portion of the buffer to the start
Expand All @@ -119,7 +133,7 @@ static bool mp3file_update_inbuf_always(audiomp3_mp3file_obj_t *self, bool block
self->inbuf.read_off = 0;
}

for (size_t to_read; !self->eof && (to_read = INPUT_BUFFER_SPACE(self->inbuf)) > 0 && (block_ok || stream_readable(self->stream));) {
for (size_t to_read; !self->eof && (to_read = INPUT_BUFFER_SPACE(self->inbuf)) > 0;) {
uint8_t *write_ptr = self->inbuf.buf + self->inbuf.write_off;
ssize_t n_read = stream_read(self->stream, write_ptr, to_read);

Expand Down Expand Up @@ -328,9 +342,14 @@ void common_hal_audiomp3_mp3file_set_file(audiomp3_mp3file_obj_t *self, mp_obj_t
background_callback_prevent();

self->stream = stream;
mp_load_method_maybe(stream, MP_QSTR_settimeout, self->settimeout_args);

INPUT_BUFFER_CLEAR(self->inbuf);
self->eof = 0;

self->block_ok = false;
stream_set_blocking(self, true);

self->other_channel = -1;
mp3file_update_inbuf_half(self, true);
mp3file_find_sync_word(self, true);
Expand Down Expand Up @@ -365,6 +384,7 @@ void common_hal_audiomp3_mp3file_deinit(audiomp3_mp3file_obj_t *self) {
self->pcm_buffer[0] = NULL;
self->pcm_buffer[1] = NULL;
self->stream = mp_const_none;
self->settimeout_args[0] = MP_OBJ_NULL;
self->samples_decoded = 0;
}

Expand Down
2 changes: 2 additions & 0 deletions shared-module/audiomp3/MP3Decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ typedef struct {
uint8_t buffer_index;
uint8_t channel_count;
bool eof;
bool block_ok;
mp_obj_t settimeout_args[3];

int8_t other_channel;
int8_t other_buffer_index;
Expand Down
49 changes: 47 additions & 2 deletions shared-module/ssl/SSLSocket.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

#include "mbedtls/version.h"

#define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR)

#if defined(MBEDTLS_ERROR_C)
#include "../../lib/mbedtls_errors/mp_mbedtls_errors.c"
#endif
Expand Down Expand Up @@ -220,6 +222,7 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
o->base.type = &ssl_sslsocket_type;
o->ssl_context = self;
o->sock_obj = socket;
o->poll_mask = 0;

mp_load_method(socket, MP_QSTR_accept, o->accept_args);
mp_load_method(socket, MP_QSTR_bind, o->bind_args);
Expand Down Expand Up @@ -330,7 +333,8 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
}
}

mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t *buf, uint32_t len) {
mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t *buf, mp_uint_t len) {
self->poll_mask = 0;
int ret = mbedtls_ssl_read(&self->ssl, buf, len);
DEBUG_PRINT("recv_into mbedtls_ssl_read() -> %d\n", ret);
if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
Expand All @@ -342,17 +346,24 @@ mp_uint_t common_hal_ssl_sslsocket_recv_into(ssl_sslsocket_obj_t *self, uint8_t
DEBUG_PRINT("returning %d\n", ret);
return ret;
}
if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
self->poll_mask = MP_STREAM_POLL_WR;
}
DEBUG_PRINT("raising errno [error case] %d\n", ret);
mbedtls_raise_error(ret);
}

mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t *buf, uint32_t len) {
mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t *buf, mp_uint_t len) {
self->poll_mask = 0;
int ret = mbedtls_ssl_write(&self->ssl, buf, len);
DEBUG_PRINT("send mbedtls_ssl_write() -> %d\n", ret);
if (ret >= 0) {
DEBUG_PRINT("returning %d\n", ret);
return ret;
}
if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
self->poll_mask = MP_STREAM_POLL_RD;
}
DEBUG_PRINT("raising errno [error case] %d\n", ret);
mbedtls_raise_error(ret);
}
Expand Down Expand Up @@ -448,3 +459,37 @@ void common_hal_ssl_sslsocket_setsockopt(ssl_sslsocket_obj_t *self, mp_obj_t lev
void common_hal_ssl_sslsocket_settimeout(ssl_sslsocket_obj_t *self, mp_obj_t timeout_obj) {
ssl_socket_settimeout(self, timeout_obj);
}

static bool poll_common(ssl_sslsocket_obj_t *self, uintptr_t arg) {
// Take into account that the library might have buffered data already
int has_pending = 0;
if (arg & MP_STREAM_POLL_RD) {
has_pending = mbedtls_ssl_check_pending(&self->ssl);
if (has_pending) {
// Shortcut if we only need to read and we have buffered data, no need to go to the underlying socket
return true;
}
}

// If the library signaled us that it needs reading or writing, only
// check that direction
if (self->poll_mask && (arg & MP_STREAM_POLL_RDWR)) {
arg = (arg & ~MP_STREAM_POLL_RDWR) | self->poll_mask;
}

// If direction the library needed is available, return a fake
// result to the caller so that it reenters a read or a write to
// allow the handshake to progress
const mp_stream_p_t *stream_p = mp_get_stream_raise(self->sock_obj, MP_STREAM_OP_IOCTL);
int errcode;
mp_int_t ret = stream_p->ioctl(self->sock_obj, MP_STREAM_POLL, arg, &errcode);
return ret != 0;
}

bool common_hal_ssl_sslsocket_readable(ssl_sslsocket_obj_t *self) {
return poll_common(self, MP_STREAM_POLL_RD);
}

bool common_hal_ssl_sslsocket_writable(ssl_sslsocket_obj_t *self) {
return poll_common(self, MP_STREAM_POLL_WR);
}
1 change: 1 addition & 0 deletions shared-module/ssl/SSLSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ typedef struct ssl_sslsocket_obj {
mbedtls_x509_crt cacert;
mbedtls_x509_crt cert;
mbedtls_pk_context pkey;
uintptr_t poll_mask;
bool closed;
mp_obj_t accept_args[2];
mp_obj_t bind_args[3];
Expand Down

0 comments on commit ed5591c

Please sign in to comment.