From e4fd251825d741350746c474881f8243b843d0d6 Mon Sep 17 00:00:00 2001 From: Osip Fatkullin Date: Wed, 29 Jan 2025 16:18:22 +0100 Subject: [PATCH 1/2] KTOR-6970 Darwin, Java, JS: Propagate Sec-WebSocket-Protocol header (#4633) --- .../ktor/client/engine/js/JsClientEngine.kt | 7 +++- .../client/engine/js/WasmJsClientEngine.kt | 7 +++- .../engine/darwin/KtorNSURLSessionDelegate.kt | 13 ++++--- .../darwin/internal/DarwinWebsocketSession.kt | 27 ++++++++----- .../client/engine/java/JavaHttpWebSocket.kt | 38 ++++++++++++++----- .../io/ktor/client/tests/WebSocketTest.kt | 19 ++++++++++ 6 files changed, 83 insertions(+), 28 deletions(-) diff --git a/ktor-client/ktor-client-core/js/src/io/ktor/client/engine/js/JsClientEngine.kt b/ktor-client/ktor-client-core/js/src/io/ktor/client/engine/js/JsClientEngine.kt index 87b52e43dc..84c2edc257 100644 --- a/ktor-client/ktor-client-core/js/src/io/ktor/client/engine/js/JsClientEngine.kt +++ b/ktor-client/ktor-client-core/js/src/io/ktor/client/engine/js/JsClientEngine.kt @@ -73,7 +73,7 @@ internal class JsClientEngine( headers: Headers ): WebSocket { val protocolHeaderNames = headers.names().filter { headerName -> - headerName.equals("sec-websocket-protocol", true) + headerName.equals(HttpHeaders.SecWebSocketProtocol, ignoreCase = true) } val protocols = protocolHeaderNames.mapNotNull { headers.getAll(it) }.flatten().toTypedArray() return when { @@ -108,10 +108,13 @@ internal class JsClientEngine( throw cause } + val protocol = socket.protocol.takeIf { it.isNotEmpty() } + val headers = if (protocol != null) headersOf(HttpHeaders.SecWebSocketProtocol, protocol) else Headers.Empty + return HttpResponseData( HttpStatusCode.SwitchingProtocols, requestTime, - Headers.Empty, + headers, HttpProtocolVersion.HTTP_1_1, session, callContext diff --git a/ktor-client/ktor-client-core/wasmJs/src/io/ktor/client/engine/js/WasmJsClientEngine.kt b/ktor-client/ktor-client-core/wasmJs/src/io/ktor/client/engine/js/WasmJsClientEngine.kt index de007d750b..6435738a8a 100644 --- a/ktor-client/ktor-client-core/wasmJs/src/io/ktor/client/engine/js/WasmJsClientEngine.kt +++ b/ktor-client/ktor-client-core/wasmJs/src/io/ktor/client/engine/js/WasmJsClientEngine.kt @@ -82,7 +82,7 @@ internal class JsClientEngine( headers: Headers ): WebSocket { val protocolHeaderNames = headers.names().filter { headerName -> - headerName.equals("sec-websocket-protocol", true) + headerName.equals(HttpHeaders.SecWebSocketProtocol, ignoreCase = true) } val protocols = protocolHeaderNames.mapNotNull { headers.getAll(it) }.flatten().toTypedArray() return when { @@ -116,10 +116,13 @@ internal class JsClientEngine( val session = JsWebSocketSession(callContext, socket) + val protocol = socket.protocol.takeIf { it.isNotEmpty() } + val headers = if (protocol != null) headersOf(HttpHeaders.SecWebSocketProtocol, protocol) else Headers.Empty + return HttpResponseData( HttpStatusCode.SwitchingProtocols, requestTime, - Headers.Empty, + headers, HttpProtocolVersion.HTTP_1_1, session, callContext diff --git a/ktor-client/ktor-client-darwin/darwin/src/io/ktor/client/engine/darwin/KtorNSURLSessionDelegate.kt b/ktor-client/ktor-client-darwin/darwin/src/io/ktor/client/engine/darwin/KtorNSURLSessionDelegate.kt index 182133eb3c..e5daebab7f 100644 --- a/ktor-client/ktor-client-darwin/darwin/src/io/ktor/client/engine/darwin/KtorNSURLSessionDelegate.kt +++ b/ktor-client/ktor-client-darwin/darwin/src/io/ktor/client/engine/darwin/KtorNSURLSessionDelegate.kt @@ -1,5 +1,5 @@ /* - * Copyright 2014-2022 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + * Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. */ package io.ktor.client.engine.darwin @@ -7,11 +7,12 @@ package io.ktor.client.engine.darwin import io.ktor.client.engine.darwin.internal.* import io.ktor.client.request.* import io.ktor.util.collections.* -import kotlinx.cinterop.* -import kotlinx.coroutines.* +import kotlinx.cinterop.UnsafeNumber +import kotlinx.coroutines.CompletableDeferred import platform.Foundation.* -import platform.darwin.* -import kotlin.coroutines.* +import platform.darwin.NSObject +import kotlin.collections.set +import kotlin.coroutines.CoroutineContext private const val HTTP_REQUESTS_INITIAL_CAPACITY = 32 private const val WS_REQUESTS_INITIAL_CAPACITY = 16 @@ -77,7 +78,7 @@ public class KtorNSURLSessionDelegate( didOpenWithProtocol: String? ) { val wsSession = webSocketSessions[webSocketTask] ?: return - wsSession.didOpen() + wsSession.didOpen(didOpenWithProtocol) } override fun URLSession( diff --git a/ktor-client/ktor-client-darwin/darwin/src/io/ktor/client/engine/darwin/internal/DarwinWebsocketSession.kt b/ktor-client/ktor-client-darwin/darwin/src/io/ktor/client/engine/darwin/internal/DarwinWebsocketSession.kt index 943735e5ab..0bb5f343e2 100644 --- a/ktor-client/ktor-client-darwin/darwin/src/io/ktor/client/engine/darwin/internal/DarwinWebsocketSession.kt +++ b/ktor-client/ktor-client-darwin/darwin/src/io/ktor/client/engine/darwin/internal/DarwinWebsocketSession.kt @@ -1,5 +1,5 @@ /* - * Copyright 2014-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + * Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. */ package io.ktor.client.engine.darwin.internal @@ -10,13 +10,20 @@ import io.ktor.http.* import io.ktor.util.date.* import io.ktor.utils.io.core.* import io.ktor.websocket.* -import kotlinx.cinterop.* +import kotlinx.cinterop.ExperimentalForeignApi +import kotlinx.cinterop.UnsafeNumber +import kotlinx.cinterop.convert import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* -import kotlinx.io.* +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.ReceiveChannel +import kotlinx.coroutines.channels.SendChannel +import kotlinx.coroutines.channels.consumeEach +import kotlinx.io.readByteArray import platform.Foundation.* -import platform.darwin.* -import kotlin.coroutines.* +import platform.darwin.NSInteger +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.resume +import kotlin.coroutines.resumeWithException @OptIn(UnsafeNumber::class, ExperimentalForeignApi::class) internal class DarwinWebsocketSession( @@ -157,11 +164,13 @@ internal class DarwinWebsocketSession( coroutineContext.cancel() } - fun didOpen() { + fun didOpen(protocol: String?) { + val headers = if (protocol != null) headersOf(HttpHeaders.SecWebSocketProtocol, protocol) else Headers.Empty + val response = HttpResponseData( task.getStatusCode()?.let { HttpStatusCode.fromValue(it) } ?: HttpStatusCode.SwitchingProtocols, requestTime, - Headers.Empty, + headers, HttpProtocolVersion.HTTP_1_1, this, coroutineContext @@ -177,7 +186,7 @@ internal class DarwinWebsocketSession( // KTOR-7363 We want to proceed with the request if we get 401 Unauthorized status code if (task.getStatusCode() == HttpStatusCode.Unauthorized.value) { - didOpen() + didOpen(protocol = null) socketJob.complete() return } diff --git a/ktor-client/ktor-client-java/jvm/src/io/ktor/client/engine/java/JavaHttpWebSocket.kt b/ktor-client/ktor-client-java/jvm/src/io/ktor/client/engine/java/JavaHttpWebSocket.kt index a4d7ab9fb9..5e9ce10c6c 100644 --- a/ktor-client/ktor-client-java/jvm/src/io/ktor/client/engine/java/JavaHttpWebSocket.kt +++ b/ktor-client/ktor-client-java/jvm/src/io/ktor/client/engine/java/JavaHttpWebSocket.kt @@ -1,6 +1,6 @@ /* -* Copyright 2014-2021 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. -*/ + * Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ package io.ktor.client.engine.java @@ -15,14 +15,18 @@ import io.ktor.utils.io.* import io.ktor.utils.io.core.* import io.ktor.websocket.* import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* -import kotlinx.coroutines.future.* +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.ReceiveChannel +import kotlinx.coroutines.channels.SendChannel +import kotlinx.coroutines.channels.consumeEach +import kotlinx.coroutines.future.asCompletableFuture +import kotlinx.coroutines.future.await import java.net.http.* -import java.nio.* -import java.time.* +import java.nio.ByteBuffer +import java.time.Duration import java.util.* -import java.util.concurrent.* -import kotlin.coroutines.* +import java.util.concurrent.CompletionStage +import kotlin.coroutines.CoroutineContext import kotlin.text.String import kotlin.text.toByteArray @@ -92,9 +96,11 @@ internal class JavaHttpWebSocket( FrameType.TEXT -> { webSocket.sendText(String(frame.data), frame.fin).await() } + FrameType.BINARY -> { webSocket.sendBinary(frame.buffer, frame.fin).await() } + FrameType.CLOSE -> { val data = buildPacket { writeFully(frame.data) } val code = data.readShort().toInt() @@ -103,9 +109,11 @@ internal class JavaHttpWebSocket( socketJob.complete() return@launch } + FrameType.PING -> { webSocket.sendPing(frame.buffer).await() } + FrameType.PONG -> { webSocket.sendPong(frame.buffer).await() } @@ -153,11 +161,15 @@ internal class JavaHttpWebSocket( } var status = HttpStatusCode.SwitchingProtocols + var headers: Headers try { webSocket = builder.buildAsync(requestData.url.toURI(), this).await() + val protocol = webSocket.subprotocol?.takeIf { it.isNotEmpty() } + headers = if (protocol != null) headersOf(HttpHeaders.SecWebSocketProtocol, protocol) else Headers.Empty } catch (cause: WebSocketHandshakeException) { if (cause.response.statusCode() == HttpStatusCode.Unauthorized.value) { status = HttpStatusCode.Unauthorized + headers = headersOf(cause.response.headers().map()) } else { throw cause } @@ -166,7 +178,7 @@ internal class JavaHttpWebSocket( return HttpResponseData( status, requestTime, - Headers.Empty, + headers, HttpProtocolVersion.HTTP_1_1, this, callContext @@ -217,3 +229,11 @@ internal class JavaHttpWebSocket( socketJob.cancel() } } + +private fun headersOf(map: Map>): Headers = object : Headers { + override val caseInsensitiveName: Boolean = true + override fun getAll(name: String): List? = map[name] + override fun names(): Set = map.keys + override fun entries(): Set>> = map.entries + override fun isEmpty(): Boolean = map.isEmpty() +} diff --git a/ktor-client/ktor-client-tests/common/test/io/ktor/client/tests/WebSocketTest.kt b/ktor-client/ktor-client-tests/common/test/io/ktor/client/tests/WebSocketTest.kt index 8371f15df4..2896ec86c0 100644 --- a/ktor-client/ktor-client-tests/common/test/io/ktor/client/tests/WebSocketTest.kt +++ b/ktor-client/ktor-client-tests/common/test/io/ktor/client/tests/WebSocketTest.kt @@ -311,6 +311,25 @@ class WebSocketTest : ClientLoader() { } } + @Test + fun testResponseContainsSecWebsocketProtocolHeader() = clientTests(except(ENGINES_WITHOUT_WS)) { + config { + install(WebSockets) + } + + test { client -> + val session = client.webSocketSession("$TEST_WEBSOCKET_SERVER/websockets/sub-protocol") { + header(HttpHeaders.SecWebSocketProtocol, "test-protocol") + } + + try { + assertEquals(session.call.response.headers[HttpHeaders.SecWebSocketProtocol], "test-protocol") + } finally { + session.close() + } + } + } + @Test fun testIncomingOverflow() = clientTests(except(ENGINES_WITHOUT_WS)) { config { From 1c4f5b9a94e52082097f9beffaba51b9301a7360 Mon Sep 17 00:00:00 2001 From: Bruce Hamilton Date: Tue, 28 Jan 2025 14:13:51 +0100 Subject: [PATCH 2/2] KTOR-8105 Fix concurrent flush attempts corrupting segment pool --- .../jvm/src/io/ktor/network/sockets/NIOSocketImpl.kt | 1 - .../test/io/ktor/network/sockets/tests/ServerSocketTest.kt | 7 ++++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ktor-network/jvm/src/io/ktor/network/sockets/NIOSocketImpl.kt b/ktor-network/jvm/src/io/ktor/network/sockets/NIOSocketImpl.kt index 9a2edbbc10..1d87c544ba 100644 --- a/ktor-network/jvm/src/io/ktor/network/sockets/NIOSocketImpl.kt +++ b/ktor-network/jvm/src/io/ktor/network/sockets/NIOSocketImpl.kt @@ -62,7 +62,6 @@ internal abstract class NIOSocketImpl( override fun close() { if (!closeFlag.compareAndSet(false, true)) return - readerJob.get()?.channel?.close() writerJob.get()?.cancel() checkChannels() } diff --git a/ktor-network/jvm/test/io/ktor/network/sockets/tests/ServerSocketTest.kt b/ktor-network/jvm/test/io/ktor/network/sockets/tests/ServerSocketTest.kt index 501d911555..b74ecb41c0 100644 --- a/ktor-network/jvm/test/io/ktor/network/sockets/tests/ServerSocketTest.kt +++ b/ktor-network/jvm/test/io/ktor/network/sockets/tests/ServerSocketTest.kt @@ -6,6 +6,7 @@ package io.ktor.network.sockets.tests import io.ktor.network.selector.* import io.ktor.network.sockets.* +import io.ktor.util.cio.* import io.ktor.utils.io.* import kotlinx.coroutines.* import kotlinx.coroutines.debug.junit5.CoroutinesTimeout @@ -15,7 +16,6 @@ import java.nio.channels.ClosedChannelException import java.util.concurrent.CancellationException import java.util.concurrent.CountDownLatch import java.util.concurrent.Executors -import kotlin.concurrent.Volatile import kotlin.concurrent.thread import kotlin.coroutines.CoroutineContext import kotlin.test.AfterTest @@ -90,8 +90,9 @@ class ServerSocketTest : CoroutineScope { @Test fun testWrite() { val server = server { client -> - val channel = client.openWriteChannel(true) - channel.writeStringUtf8("123") + client.openWriteChannel(true).use { + writeStringUtf8("123") + } } client { socket ->