diff --git a/core/http/transport.hpp b/core/http/transport.hpp index 8a19e9daa..5d0b5a186 100644 --- a/core/http/transport.hpp +++ b/core/http/transport.hpp @@ -61,8 +61,14 @@ namespace Lambda::HTTP::Transport { virtual bool awaitNext() = 0; virtual HTTP::Request nextRequest() = 0; virtual void respond(const HTTP::Response& response) = 0; + + virtual std::vector readRaw() = 0; + virtual std::vector readRaw(size_t expectedSize) = 0; + virtual void writeRaw(const std::vector& data) = 0; + virtual void reset() noexcept = 0; virtual bool hasPartialData() const noexcept = 0; + virtual void close() = 0; TransportFlags flags; }; diff --git a/core/http/transport_impl.hpp b/core/http/transport_impl.hpp index ce4007584..8df17910b 100644 --- a/core/http/transport_impl.hpp +++ b/core/http/transport_impl.hpp @@ -37,8 +37,14 @@ namespace Lambda::HTTP::Transport { bool awaitNext(); HTTP::Request nextRequest(); void respond(const HTTP::Response& response); + + std::vector readRaw(); + std::vector readRaw(size_t expectedSize); + void writeRaw(const std::vector& data); + void reset() noexcept; bool hasPartialData() const noexcept; + void close(); }; }; diff --git a/core/http/transport_v1.cpp b/core/http/transport_v1.cpp index 1ec137242..ecdcd2c59 100644 --- a/core/http/transport_v1.cpp +++ b/core/http/transport_v1.cpp @@ -24,26 +24,6 @@ static const std::initializer_list compressibleTypes = { "text", "html", "json", "xml" }; -const Network::ConnectionInfo& TransportContextV1::conninfo() const noexcept { - return this->m_conn.info(); -} - -const TransportOptions& TransportContextV1::options() const noexcept { - return this->m_topts; -} - -Network::TCP::Connection& TransportContextV1::tcpconn() const noexcept { - return this->m_conn; -} - -const ContentEncodings& TransportContextV1::getEnconding() const noexcept { - return this->m_compress; -} - -bool TransportContextV1::ok() const noexcept { - return this->m_conn.active(); -} - TransportContextV1::TransportContextV1( Network::TCP::Connection& connInit, const TransportOptions& optsInit @@ -52,7 +32,7 @@ TransportContextV1::TransportContextV1( bool TransportContextV1::awaitNext() { if (this->m_next != nullptr) { - throw Lambda::Error("awaitNext() cannot receive more requests until the previous one is processed"); + throw Lambda::Error("awaitNext() cancelec: the previous requests was not processed yet"); } auto headerEnded = this->m_readbuff.end(); @@ -206,7 +186,7 @@ bool TransportContextV1::awaitNext() { HTTP::Request TransportContextV1::nextRequest() { if (this->m_next == nullptr) { - throw Lambda::Error("nextRequest() canceled: no requests pending"); + throw Lambda::Error("nextRequest() canceled: no requests pending. Use awaitNext() to read more requests"); } const auto tempNext = std::move(*this->m_next); @@ -219,6 +199,10 @@ HTTP::Request TransportContextV1::nextRequest() { void TransportContextV1::respond(const Response& response) { + if (this->m_next != nullptr) { + throw Lambda::Error("respond() canceled: Before responding to a request one must be read with nextRequest() call first"); + } + auto applyEncoding = ContentEncodings::None; auto responseHeaders = response.headers; @@ -296,10 +280,61 @@ void TransportContextV1::respond(const Response& response) { if (bodySize) this->m_conn.write(responseBody); } +const Network::ConnectionInfo& TransportContextV1::conninfo() const noexcept { + return this->m_conn.info(); +} + +const TransportOptions& TransportContextV1::options() const noexcept { + return this->m_topts; +} + +Network::TCP::Connection& TransportContextV1::tcpconn() const noexcept { + return this->m_conn; +} + +const ContentEncodings& TransportContextV1::getEnconding() const noexcept { + return this->m_compress; +} + +bool TransportContextV1::ok() const noexcept { + return this->m_conn.active(); +} + void TransportContextV1::reset() noexcept { + this->m_readbuff.clear(); + + if (this->m_next != nullptr) { + delete this->m_next; + this->m_next = nullptr; + } } bool TransportContextV1::hasPartialData() const noexcept { return this->m_readbuff.size() > 0; } + +void TransportContextV1::close() { + this->m_conn.end(); +} + +void TransportContextV1::writeRaw(const std::vector& data) { + this->m_conn.write(data); +} + +std::vector TransportContextV1::readRaw() { + return this->readRaw(Network::TCP::Connection::ReadChunkSize); +} + +std::vector TransportContextV1::readRaw(size_t expectedSize) { + + auto bufferHave = std::move(this->m_readbuff); + this->m_readbuff = {}; + + if (bufferHave.size() < expectedSize) { + auto bufferFetched = this->m_conn.read(expectedSize - bufferHave.size()); + bufferHave.insert(bufferHave.end(), bufferFetched.begin(), bufferFetched.end()); + } + + return bufferHave; +} diff --git a/core/sse/sse.hpp b/core/sse/sse.hpp index 721d7016c..c3cd1860f 100644 --- a/core/sse/sse.hpp +++ b/core/sse/sse.hpp @@ -19,7 +19,7 @@ namespace Lambda::SSE { class Writer { private: - Network::TCP::Connection& m_conn; + HTTP::Transport::TransportContext& transport; public: Writer(HTTP::Transport::TransportContext& tctx, const HTTP::Request initRequest); diff --git a/core/sse/writer.cpp b/core/sse/writer.cpp index ef6e2c0ce..7175b4178 100644 --- a/core/sse/writer.cpp +++ b/core/sse/writer.cpp @@ -5,7 +5,7 @@ using namespace Lambda; using namespace Lambda::Network; using namespace Lambda::SSE; -Writer::Writer(HTTP::Transport::TransportContext& tctx, const HTTP::Request initRequest) : m_conn(tctx.tcpconn()) { +Writer::Writer(HTTP::Transport::TransportContext& tctx, const HTTP::Request initRequest) : transport(tctx) { tctx.flags.autocompress = false; tctx.flags.forceContentLength = false; @@ -15,8 +15,7 @@ Writer::Writer(HTTP::Transport::TransportContext& tctx, const HTTP::Request init auto upgradeResponse = HTTP::Response(200, { { "connection", "keep-alive" }, { "cache-control", "no-cache" }, - { "content-type", "text/event-stream; charset=UTF-8" }, - { "pragma", "no-cache" }, + { "content-type", "text/event-stream; charset=UTF-8" } }); if (originHeader.size()) { @@ -28,7 +27,7 @@ Writer::Writer(HTTP::Transport::TransportContext& tctx, const HTTP::Request init void Writer::push(const EventMessage& event) { - if (!this->m_conn.active()) { + if (!this->transport.ok()) { throw Lambda::Error("SSE listener disconnected"); } @@ -63,17 +62,21 @@ void Writer::push(const EventMessage& event) { serializedMessage.insert(serializedMessage.end(), lineSeparator.begin(), lineSeparator.end()); try { - this->m_conn.write(serializedMessage); + this->transport.writeRaw(serializedMessage); } catch(...) { - this->m_conn.end(); + this->transport.tcpconn().end(); } } bool Writer::connected() const noexcept { - return this->m_conn.active(); + return this->transport.ok(); } void Writer::close() { - this->push({ "", "close" }); - this->m_conn.end(); + + EventMessage closeEvent; + closeEvent.event = "close"; + + this->push(closeEvent); + this->transport.tcpconn().end(); } diff --git a/core/websocket/context.cpp b/core/websocket/context.cpp index ee593c56a..41f5f9062 100644 --- a/core/websocket/context.cpp +++ b/core/websocket/context.cpp @@ -23,7 +23,7 @@ static const std::string wsMagicString = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; static const time_t sockRcvTimeout = 100; WebsocketContext::WebsocketContext(HTTP::Transport::TransportContext& tctx, const HTTP::Request initRequest) - : conn(tctx.tcpconn()), topts(tctx.options()) { + : transport(tctx), topts(tctx.options()) { auto headerUpgrade = Strings::toLowerCase(initRequest.headers.get("Upgrade")); auto headerWsKey = initRequest.headers.get("Sec-WebSocket-Key"); @@ -49,8 +49,8 @@ WebsocketContext::WebsocketContext(HTTP::Transport::TransportContext& tctx, cons tctx.respond(handshakeReponse); tctx.reset(); - this->conn.flags.closeOnTimeout = false; - this->conn.setTimeouts(sockRcvTimeout, Network::SetTimeoutsDirection::Receive); + this->transport.tcpconn().flags.closeOnTimeout = false; + this->transport.tcpconn().setTimeouts(sockRcvTimeout, Network::SetTimeoutsDirection::Receive); this->m_reader = std::async(&WebsocketContext::asyncWorker, this); } @@ -83,7 +83,7 @@ void WebsocketContext::close(Websocket::CloseReason reason) { closeMessageBuff.insert(closeMessageBuff.end(), closeReasonBuff.begin(), closeReasonBuff.end()); - this->conn.write(closeMessageBuff); + this->transport.writeRaw(closeMessageBuff); if (this->m_reader.valid()) { try { this->m_reader.get(); } diff --git a/core/websocket/transport.cpp b/core/websocket/transport.cpp index cd26c38dd..58b82e424 100644 --- a/core/websocket/transport.cpp +++ b/core/websocket/transport.cpp @@ -10,6 +10,7 @@ static const std::string wsPingString = "ping/lambda/ws"; // these values are used for both pings and actual receive timeouts static const time_t wsActTimeout = 5000; static const unsigned short wsMaxSkippedAttempts = 3; +static const size_t wsReadChunk = 256; static const std::initializer_list supportedWsOpcodes = { OpCode::Binary, @@ -22,7 +23,7 @@ static const std::initializer_list supportedWsOpcodes = { void WebsocketContext::sendMessage(const Websocket::Message& msg) { auto writeBuff = serializeMessage(msg); - this->conn.write(writeBuff); + this->transport.writeRaw(writeBuff); } FrameHeader Transport::parseFrameHeader(const std::vector& buffer) { @@ -115,7 +116,7 @@ void WebsocketContext::asyncWorker() { auto lastPingResponse = std::chrono::steady_clock::now(); auto pingWindow = std::chrono::milliseconds(wsMaxSkippedAttempts * wsActTimeout); - while (this->conn.active() && !this->m_stopped) { + while (this->transport.ok() && !this->m_stopped) { // send ping or terminate websocket if there is no response if ((lastPing - lastPingResponse) > pingWindow) { @@ -131,13 +132,13 @@ void WebsocketContext::asyncWorker() { wsPingString.size() }); - this->conn.write(pingHeader); - this->conn.write(std::vector(wsPingString.begin(), wsPingString.end())); + this->transport.writeRaw(pingHeader); + this->transport.writeRaw(std::vector(wsPingString.begin(), wsPingString.end())); lastPing = std::chrono::steady_clock::now(); } - auto nextChunk = this->conn.read(); + auto nextChunk = this->transport.readRaw(wsReadChunk); if (!nextChunk.size()) continue; downloadBuff.insert(downloadBuff.end(), nextChunk.begin(), nextChunk.end()); @@ -145,7 +146,7 @@ void WebsocketContext::asyncWorker() { if (downloadBuff.size() > this->topts.maxRequestSize) { this->close(CloseReason::MessageTooBig); - throw std::runtime_error("expected frame size too large"); + throw std::runtime_error("Expected frame size too large"); } auto frameHeader = parseFrameHeader(downloadBuff); @@ -174,7 +175,7 @@ void WebsocketContext::asyncWorker() { if (frameHeader.payloadSize + frameHeader.payloadSize < downloadBuff.size()) { auto expectedSize = frameHeader.payloadSize - payloadBuff.size(); - auto payloadChunk = this->conn.read(expectedSize); + auto payloadChunk = this->transport.readRaw(expectedSize); if (payloadChunk.size() < expectedSize) { this->close(CloseReason::ProtocolError); @@ -225,8 +226,8 @@ void WebsocketContext::asyncWorker() { frameHeader.payloadSize }); - this->conn.write(pongHeader); - this->conn.write(payloadBuff); + this->transport.writeRaw(pongHeader); + this->transport.writeRaw(payloadBuff); } break; diff --git a/core/websocket/websocket.hpp b/core/websocket/websocket.hpp index 58b9b3e0f..7e50c8119 100644 --- a/core/websocket/websocket.hpp +++ b/core/websocket/websocket.hpp @@ -45,7 +45,7 @@ namespace Lambda::Websocket { class WebsocketContext { private: - Network::TCP::Connection& conn; + HTTP::Transport::TransportContext& transport; const HTTP::Transport::TransportOptions& topts; std::future m_reader; std::queue m_queue;