Skip to content

Commit

Permalink
Single channel handler to manage tasks (#76)
Browse files Browse the repository at this point in the history
* Move timeout to MQTTTask

* Combine all task handlers into one with a list of tasks

* remove tasks on eventLoop

* Add testMultipleTasks

* swift format

* Remove task from list, before completing it

* Fixed docker-compose file
  • Loading branch information
adam-fowler authored Nov 1, 2021
1 parent 483a171 commit 75f32f5
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 132 deletions.
26 changes: 26 additions & 0 deletions Sources/MQTTNIO/ChannelHandlers/MQTTEncoderHandler.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import Logging
import NIO

/// Handler encoding MQTT Messages into ByteBuffers
final class MQTTEncodeHandler: ChannelOutboundHandler {
public typealias OutboundIn = MQTTPacket
public typealias OutboundOut = ByteBuffer

let client: MQTTClient

init(client: MQTTClient) {
self.client = client
}

func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
let message = unwrapOutboundIn(data)
self.client.logger.trace("MQTT Out", metadata: ["mqtt_message": .string("\(message)"), "mqtt_packet_id": .string("\(message.packetId)")])
var bb = context.channel.allocator.buffer(capacity: 0)
do {
try message.write(version: self.client.configuration.version, to: &bb)
context.write(wrapOutboundOut(bb), promise: promise)
} catch {
promise?.fail(error)
}
}
}
Original file line number Diff line number Diff line change
@@ -1,30 +1,6 @@
import Logging
import NIO

/// Handler encoding MQTT Messages into ByteBuffers
final class MQTTEncodeHandler: ChannelOutboundHandler {
public typealias OutboundIn = MQTTPacket
public typealias OutboundOut = ByteBuffer

let client: MQTTClient

init(client: MQTTClient) {
self.client = client
}

func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
let message = unwrapOutboundIn(data)
self.client.logger.trace("MQTT Out", metadata: ["mqtt_message": .string("\(message)"), "mqtt_packet_id": .string("\(message.packetId)")])
var bb = context.channel.allocator.buffer(capacity: 0)
do {
try message.write(version: self.client.configuration.version, to: &bb)
context.write(wrapOutboundOut(bb), promise: promise)
} catch {
promise?.fail(error)
}
}
}

/// Decode ByteBuffers into MQTT Messages
struct ByteToMQTTMessageDecoder: ByteToMessageDecoder {
typealias InboundOut = MQTTPacket
Expand Down
90 changes: 90 additions & 0 deletions Sources/MQTTNIO/ChannelHandlers/MQTTTaskHandler.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import NIO

final class MQTTTaskHandler: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = MQTTPacket

var eventLoop: EventLoop!

init() {
self.eventLoop = nil
self.tasks = []
}

func addTask(_ task: MQTTTask) -> EventLoopFuture<Void> {
return self.eventLoop.submit {
self.tasks.append(task)
}
}

func _removeTask(_ task: MQTTTask) {
self.tasks.removeAll { $0 === task }
}

func removeTask(_ task: MQTTTask) {
if self.eventLoop.inEventLoop {
self._removeTask(task)
} else {
self.eventLoop.execute {
self._removeTask(task)
}
}
}

func handlerAdded(context: ChannelHandlerContext) {
self.eventLoop = context.eventLoop
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let response = self.unwrapInboundIn(data)
for task in self.tasks {
do {
if try task.checkInbound(response) {
self.removeTask(task)
task.succeed(response)
return
}
} catch {
self.removeTask(task)
task.fail(error)
return
}
}
}

func channelInactive(context: ChannelHandlerContext) {
self.tasks.forEach { $0.fail(MQTTError.serverClosedConnection) }
self.tasks.removeAll()
}

func errorCaught(context: ChannelHandlerContext, error: Error) {
self.tasks.forEach { $0.fail(error) }
self.tasks.removeAll()
}

var tasks: [MQTTTask]
}

/// If packet reaches this handler then it was never dealt with by a task
final class MQTTUnhandledPacketHandler: ChannelInboundHandler {
typealias InboundIn = MQTTPacket
let client: MQTTClient

init(client: MQTTClient) {
self.client = client
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
// we only send response to v5 server
guard self.client.configuration.version == .v5_0 else { return }
guard let connection = client.connection else { return }
let response = self.unwrapInboundIn(data)
switch response.type {
case .PUBREC:
_ = connection.sendMessageNoWait(MQTTPubAckPacket(type: .PUBREL, packetId: response.packetId, reason: .packetIdentifierNotFound))
case .PUBREL:
_ = connection.sendMessageNoWait(MQTTPubAckPacket(type: .PUBCOMP, packetId: response.packetId, reason: .packetIdentifierNotFound))
default:
break
}
}
}
19 changes: 9 additions & 10 deletions Sources/MQTTNIO/MQTTConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@ import NIOWebSocket
final class MQTTConnection {
let channel: Channel
let timeout: TimeAmount?
let unhandledHandler: MQTTUnhandledPacketHandler
let taskHandler: MQTTTaskHandler

private init(channel: Channel, timeout: TimeAmount?, unhandledHandler: MQTTUnhandledPacketHandler) {
private init(channel: Channel, timeout: TimeAmount?, taskHandler: MQTTTaskHandler) {
self.channel = channel
self.timeout = timeout
self.unhandledHandler = unhandledHandler
self.taskHandler = taskHandler
}

static func create(client: MQTTClient, pingInterval: TimeAmount) -> EventLoopFuture<MQTTConnection> {
let unhandledHandler = MQTTUnhandledPacketHandler(client: client)
return self.createBootstrap(client: client, pingInterval: pingInterval, unhandledHandler: unhandledHandler)
.map { MQTTConnection(channel: $0, timeout: client.configuration.timeout, unhandledHandler: unhandledHandler) }
let taskHandler = MQTTTaskHandler()
return self.createBootstrap(client: client, pingInterval: pingInterval, taskHandler: taskHandler)
.map { MQTTConnection(channel: $0, timeout: client.configuration.timeout, taskHandler: taskHandler) }
}

static func createBootstrap(client: MQTTClient, pingInterval: TimeAmount, unhandledHandler: MQTTUnhandledPacketHandler) -> EventLoopFuture<Channel> {
static func createBootstrap(client: MQTTClient, pingInterval: TimeAmount, taskHandler: MQTTTaskHandler) -> EventLoopFuture<Channel> {
let eventLoop = client.eventLoopGroup.next()
let channelPromise = eventLoop.makePromise(of: Channel.self)
do {
Expand All @@ -41,7 +41,7 @@ final class MQTTConnection {
// Work out what handlers to add
var handlers: [ChannelHandler] = [
ByteToMessageHandler(ByteToMQTTMessageDecoder(client: client)),
unhandledHandler,
taskHandler,
MQTTEncodeHandler(client: client),
]
if !client.configuration.disablePing {
Expand Down Expand Up @@ -179,9 +179,8 @@ final class MQTTConnection {

func sendMessage(_ message: MQTTPacket, checkInbound: @escaping (MQTTPacket) throws -> Bool) -> EventLoopFuture<MQTTPacket> {
let task = MQTTTask(on: channel.eventLoop, timeout: self.timeout, checkInbound: checkInbound)
let taskHandler = MQTTTaskHandler(task: task, channel: channel)

self.channel.pipeline.addHandler(taskHandler, position: .before(self.unhandledHandler))
self.taskHandler.addTask(task)
.flatMap {
self.channel.writeAndFlush(message)
}
Expand Down
100 changes: 13 additions & 87 deletions Sources/MQTTNIO/MQTTTask.swift
Original file line number Diff line number Diff line change
@@ -1,108 +1,34 @@

import NIO

/// Class encapsulating a single task
final class MQTTTask {
let promise: EventLoopPromise<MQTTPacket>
let checkInbound: (MQTTPacket) throws -> Bool
let timeout: TimeAmount?
let timeoutTask: Scheduled<Void>?

init(on eventLoop: EventLoop, timeout: TimeAmount?, checkInbound: @escaping (MQTTPacket) throws -> Bool) {
self.promise = eventLoop.makePromise(of: MQTTPacket.self)
let promise = eventLoop.makePromise(of: MQTTPacket.self)
self.promise = promise
self.checkInbound = checkInbound
self.timeout = timeout
}

func succeed(_ response: MQTTPacket) {
self.promise.succeed(response)
}

func fail(_ error: Error) {
self.promise.fail(error)
}
}

final class MQTTTaskHandler: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = MQTTPacket

let task: MQTTTask
let channel: Channel
var timeoutTask: Scheduled<Void>?

init(task: MQTTTask, channel: Channel) {
self.task = task
self.channel = channel
self.timeoutTask = nil
}

public func handlerAdded(context: ChannelHandlerContext) {
self.addTimeoutTask()
}

public func handlerRemoved(context: ChannelHandlerContext) {
self.timeoutTask?.cancel()
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let response = self.unwrapInboundIn(data)
do {
if try self.task.checkInbound(response) {
self.channel.pipeline.removeHandler(self).whenSuccess { _ in
self.timeoutTask?.cancel()
self.task.succeed(response)
}
} else {
context.fireChannelRead(data)
}
} catch {
self.errorCaught(context: context, error: error)
}
}

func channelInactive(context: ChannelHandlerContext) {
self.task.fail(MQTTError.serverClosedConnection)
}

func errorCaught(context: ChannelHandlerContext, error: Error) {
self.timeoutTask?.cancel()
self.channel.pipeline.removeHandler(self).whenSuccess { _ in
self.task.fail(error)
}
}

func addTimeoutTask() {
if let timeout = task.timeout {
self.timeoutTask = self.channel.eventLoop.scheduleTask(in: timeout) {
self.channel.pipeline.removeHandler(self).whenSuccess { _ in
self.task.fail(MQTTError.timeout)
}
if let timeout = timeout {
self.timeoutTask = eventLoop.scheduleTask(in: timeout) {
promise.fail(MQTTError.timeout)
}
} else {
self.timeoutTask = nil
}
}
}

/// If packet reaches this handler then it was never dealt with by a task
final class MQTTUnhandledPacketHandler: ChannelInboundHandler {
typealias InboundIn = MQTTPacket
let client: MQTTClient

init(client: MQTTClient) {
self.client = client
func succeed(_ response: MQTTPacket) {
self.timeoutTask?.cancel()
self.promise.succeed(response)
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
// we only send response to v5 server
guard self.client.configuration.version == .v5_0 else { return }
guard let connection = client.connection else { return }
let response = self.unwrapInboundIn(data)
switch response.type {
case .PUBREC:
_ = connection.sendMessageNoWait(MQTTPubAckPacket(type: .PUBREL, packetId: response.packetId, reason: .packetIdentifierNotFound))
case .PUBREL:
_ = connection.sendMessageNoWait(MQTTPubAckPacket(type: .PUBCOMP, packetId: response.packetId, reason: .packetIdentifierNotFound))
default:
break
}
func fail(_ error: Error) {
self.timeoutTask?.cancel()
self.promise.fail(error)
}
}
8 changes: 4 additions & 4 deletions Tests/MQTTNIOTests/MQTTNIOTests+async.swift
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ final class AsyncMQTTNIOTests: XCTestCase {

let client = self.createClient(identifier: "testAsyncSequencePublishListener+async", version: .v5_0)
let client2 = self.createClient(identifier: "testAsyncSequencePublishListener+async2", version: .v5_0)
let payloadString = "Hello"

self.XCTRunAsyncAndBlock {
try await client.connect()
Expand All @@ -127,16 +126,17 @@ final class AsyncMQTTNIOTests: XCTestCase {
case .success(let publish):
var buffer = publish.payload
let string = buffer.readString(length: buffer.readableBytes)
XCTAssertEqual(string, payloadString)
print("Received: \(string ?? "nothing")")
expectation.fulfill()

case .failure(let error):
XCTFail("\(error)")
}
}
finishExpectation.fulfill()
}
try await client.publish(to: "TestSubject", payload: ByteBufferAllocator().buffer(string: payloadString), qos: .atLeastOnce)
try await client.publish(to: "TestSubject", payload: ByteBufferAllocator().buffer(string: payloadString), qos: .atLeastOnce)
try await client.publish(to: "TestSubject", payload: ByteBufferAllocator().buffer(string: "Hello"), qos: .atLeastOnce)
try await client.publish(to: "TestSubject", payload: ByteBufferAllocator().buffer(string: "Goodbye"), qos: .atLeastOnce)
try await client.disconnect()

self.wait(for: [expectation], timeout: 5.0)
Expand Down
14 changes: 13 additions & 1 deletion Tests/MQTTNIOTests/MQTTNIOTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,17 @@ final class MQTTNIOTests: XCTestCase {
try client.disconnect().wait()
}

func testMultipleTasks() throws {
let client = self.createClient(identifier: "testMultipleTasks")
defer { XCTAssertNoThrow(try client.syncShutdownGracefully()) }
_ = try client.connect().wait()
let publishFutures = (0..<16).map { client.publish(to: "test/multiple", payload: ByteBuffer(integer: $0), qos: .exactlyOnce) }
_ = client.ping()
try EventLoopFuture.andAllComplete(publishFutures, on: client.eventLoopGroup.next()).wait()
XCTAssertEqual(client.connection?.taskHandler.tasks.count, 0)
try client.disconnect().wait()
}

func testMQTTSubscribe() throws {
let client = self.createClient(identifier: "testMQTTSubscribe")
defer { XCTAssertNoThrow(try client.syncShutdownGracefully()) }
Expand Down Expand Up @@ -385,11 +396,12 @@ final class MQTTNIOTests: XCTestCase {
eventLoopGroupProvider: .createNew,
logger: self.logger
)
defer { XCTAssertNoThrow(try client.syncShutdownGracefully()) }

_ = try client.connect().wait()
_ = try client.subscribe(to: [.init(topicFilter: "#", qos: .exactlyOnce)]).wait()
Thread.sleep(forTimeInterval: 5)
try client.disconnect().wait()
try client.syncShutdownGracefully()
}

func testRawIPConnect() throws {
Expand Down
Loading

0 comments on commit 75f32f5

Please sign in to comment.