From d0db505b95b2156d2ddba0ceb2fb11e95799319f Mon Sep 17 00:00:00 2001 From: Osip Fatkullin Date: Tue, 10 Sep 2024 17:22:03 +0200 Subject: [PATCH] KTOR-7284 Deprecate witTestApplication and withApplication --- .../io/ktor/server/auth/jwt/JWTAuthTest.kt | 503 ++++++++---------- .../src/io/ktor/server/testing/TestEngine.kt | 25 +- .../testing/TestApplicationEngineTest.kt | 354 ++++++------ 3 files changed, 416 insertions(+), 466 deletions(-) diff --git a/ktor-server/ktor-server-plugins/ktor-server-auth-jwt/jvm/test/io/ktor/server/auth/jwt/JWTAuthTest.kt b/ktor-server/ktor-server-plugins/ktor-server-auth-jwt/jvm/test/io/ktor/server/auth/jwt/JWTAuthTest.kt index bf329b839b2..34fed54db4d 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-auth-jwt/jvm/test/io/ktor/server/auth/jwt/JWTAuthTest.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-auth-jwt/jvm/test/io/ktor/server/auth/jwt/JWTAuthTest.kt @@ -7,9 +7,10 @@ package io.ktor.server.auth.jwt import com.auth0.jwk.* import com.auth0.jwt.* import com.auth0.jwt.algorithms.* +import io.ktor.client.request.* +import io.ktor.client.statement.* import io.ktor.http.* import io.ktor.http.auth.* -import io.ktor.server.application.* import io.ktor.server.auth.* import io.ktor.server.response.* import io.ktor.server.routing.* @@ -23,387 +24,326 @@ import kotlin.test.* class JWTAuthTest { @Test - fun testJwtNoAuth() { - withApplication { - application.configureServerJwt() + fun testJwtNoAuth() = testApplication { + configureServerJwt() - val response = handleRequest { - uri = "/" - } + val response = client.request("/") - verifyResponseUnauthorized(response) - } + verifyResponseUnauthorized(response) } @Test - fun testJwtNoAuthCustomChallengeNoToken() { - withApplication { - application.configureServerJwt { - challenge { _, _ -> - call.respond(UnauthorizedResponse(HttpAuthHeader.basicAuthChallenge("custom1", Charsets.UTF_8))) - } + fun testJwtNoAuthCustomChallengeNoToken() = testApplication { + configureServerJwt { + challenge { _, _ -> + call.respond(UnauthorizedResponse(HttpAuthHeader.basicAuthChallenge("custom1", Charsets.UTF_8))) } + } - val response = handleRequest { - uri = "/" - } + val response = client.request("/") - verifyResponseUnauthorized(response) - assertEquals("Basic realm=custom1, charset=UTF-8", response.response.headers[HttpHeaders.WWWAuthenticate]) - } + verifyResponseUnauthorized(response) + assertEquals("Basic realm=custom1, charset=UTF-8", response.headers[HttpHeaders.WWWAuthenticate]) } @Test - fun testJwtMultipleNoAuthCustomChallengeNoToken() { - withApplication { - application.configureServerJwt { - challenge { _, _ -> - call.respond(UnauthorizedResponse(HttpAuthHeader.basicAuthChallenge("custom1", Charsets.UTF_8))) - } + fun testJwtMultipleNoAuthCustomChallengeNoToken() = testApplication { + configureServerJwt { + challenge { _, _ -> + call.respond(UnauthorizedResponse(HttpAuthHeader.basicAuthChallenge("custom1", Charsets.UTF_8))) } + } - val response = handleRequest { - uri = "/" - } + val response = client.request("/") - verifyResponseUnauthorized(response) - assertEquals("Basic realm=custom1, charset=UTF-8", response.response.headers[HttpHeaders.WWWAuthenticate]) - } + verifyResponseUnauthorized(response) + assertEquals("Basic realm=custom1, charset=UTF-8", response.headers[HttpHeaders.WWWAuthenticate]) } @Test - fun testJwtWithMultipleConfigurations() { + fun testJwtWithMultipleConfigurations() = testApplication { val validated = mutableSetOf() var currentPrincipal: (JWTCredential) -> Any? = { null } - withApplication { - application.install(Authentication) { - jwt(name = "first") { - realm = "realm1" - verifier(issuer, audience, algorithm) - validate { validated.add("1"); currentPrincipal(it) } - challenge { _, _ -> - call.respond(UnauthorizedResponse(HttpAuthHeader.basicAuthChallenge("custom1", Charsets.UTF_8))) - } - } - jwt(name = "second") { - realm = "realm2" - verifier(issuer, audience, algorithm) - validate { validated.add("2"); currentPrincipal(it) } - challenge { _, _ -> - call.respond(UnauthorizedResponse(HttpAuthHeader.basicAuthChallenge("custom2", Charsets.UTF_8))) - } + install(Authentication) { + jwt(name = "first") { + realm = "realm1" + verifier(issuer, audience, algorithm) + validate { validated.add("1"); currentPrincipal(it) } + challenge { _, _ -> + call.respond(UnauthorizedResponse(HttpAuthHeader.basicAuthChallenge("custom1", Charsets.UTF_8))) } } - - application.routing { - authenticate("first", "second") { - get("/") { - val principal = call.authentication.principal()!! - call.respondText("Secret info, ${principal.audience}") - } + jwt(name = "second") { + realm = "realm2" + verifier(issuer, audience, algorithm) + validate { validated.add("2"); currentPrincipal(it) } + challenge { _, _ -> + call.respond(UnauthorizedResponse(HttpAuthHeader.basicAuthChallenge("custom2", Charsets.UTF_8))) } } + } - val token = getToken() - handleRequestWithToken(token).let { call -> - verifyResponseUnauthorized(call) - assertEquals( - "Basic realm=custom1, charset=UTF-8", - call.response.headers[HttpHeaders.WWWAuthenticate] - ) + routing { + authenticate("first", "second") { + get("/") { + val principal = call.authentication.principal()!! + call.respondText("Secret info, ${principal.audience}") + } } - assertEquals(setOf("1", "2"), validated) + } - currentPrincipal = { JWTPrincipal(it.payload) } - validated.clear() + val token = getToken() + handleRequestWithToken(token).let { response -> + verifyResponseUnauthorized(response) + assertEquals( + "Basic realm=custom1, charset=UTF-8", + response.headers[HttpHeaders.WWWAuthenticate] + ) + } + assertEquals(setOf("1", "2"), validated) - handleRequestWithToken(token).let { call -> - assertEquals(HttpStatusCode.OK, call.response.status()) + currentPrincipal = { JWTPrincipal(it.payload) } + validated.clear() - assertEquals( - "Secret info, [$audience]", - call.response.content - ) + handleRequestWithToken(token).let { response -> + assertEquals(HttpStatusCode.OK, response.status) - assertNull(call.response.headers[HttpHeaders.WWWAuthenticate]) - } + assertEquals( + "Secret info, [$audience]", + response.bodyAsText() + ) - assertEquals(setOf("1"), validated) + assertNull(response.headers[HttpHeaders.WWWAuthenticate]) } + + assertEquals(setOf("1"), validated) } @Test - fun testJwtSuccess() { - withApplication { - application.configureServerJwt() + fun testJwtSuccess() = testApplication { + configureServerJwt() - val token = getToken() + val token = getToken() - val response = handleRequestWithToken(token) + val response = handleRequestWithToken(token) - assertEquals(HttpStatusCode.OK, response.response.status()) - assertNotNull(response.response.content) - } + assertEquals(HttpStatusCode.OK, response.status) + assertTrue(response.bodyAsText().isNotEmpty()) } @Test - fun testJwtSuccessWithCustomScheme() { - withApplication { - application.configureServerJwt { - authSchemes("Bearer", "Token") - } + fun testJwtSuccessWithCustomScheme() = testApplication { + configureServerJwt { + authSchemes("Bearer", "Token") + } - val token = getToken(scheme = "Token") + val token = getToken(scheme = "Token") - val response = handleRequestWithToken(token) + val response = handleRequestWithToken(token) - assertEquals(HttpStatusCode.OK, response.response.status()) - assertNotNull(response.response.content) - } + assertEquals(HttpStatusCode.OK, response.status) + assertTrue(response.bodyAsText().isNotEmpty()) } @Test - fun testJwtSuccessWithCustomSchemeWithDifferentCases() { - withApplication { - application.configureServerJwt { - authSchemes("Bearer", "tokEN") - } + fun testJwtSuccessWithCustomSchemeWithDifferentCases() = testApplication { + configureServerJwt { + authSchemes("Bearer", "tokEN") + } - val token = getToken(scheme = "TOKen") + val token = getToken(scheme = "TOKen") - val response = handleRequestWithToken(token) + val response = handleRequestWithToken(token) - assertEquals(HttpStatusCode.OK, response.response.status()) - assertNotNull(response.response.content) - } + assertEquals(HttpStatusCode.OK, response.status) + assertTrue(response.bodyAsText().isNotEmpty()) } @Test - fun testJwtAlgorithmMismatch() { - withApplication { - application.configureServerJwt() - val token = JWT.create().withAudience(audience).withIssuer(issuer).sign(Algorithm.HMAC256("false")) - val response = handleRequestWithToken(token) - verifyResponseUnauthorized(response) - } + fun testJwtAlgorithmMismatch() = testApplication { + configureServerJwt() + + val token = JWT.create().withAudience(audience).withIssuer(issuer).sign(Algorithm.HMAC256("false")) + val response = handleRequestWithToken(token) + verifyResponseUnauthorized(response) } @Test - fun testJwtAudienceMismatch() { - withApplication { - application.configureServerJwt() - val token = JWT.create().withAudience("wrong").withIssuer(issuer).sign(algorithm) - val response = handleRequestWithToken(token) - verifyResponseUnauthorized(response) - } + fun testJwtAudienceMismatch() = testApplication { + configureServerJwt() + val token = JWT.create().withAudience("wrong").withIssuer(issuer).sign(algorithm) + val response = handleRequestWithToken(token) + verifyResponseUnauthorized(response) } @Test - fun testJwtIssuerMismatch() { - withApplication { - application.configureServerJwt() - val token = JWT.create().withAudience(audience).withIssuer("wrong").sign(algorithm) - val response = handleRequestWithToken(token) - verifyResponseUnauthorized(response) - } + fun testJwtIssuerMismatch() = testApplication { + configureServerJwt() + val token = JWT.create().withAudience(audience).withIssuer("wrong").sign(algorithm) + val response = handleRequestWithToken(token) + verifyResponseUnauthorized(response) } @Test - fun testJwkNoAuth() { - withApplication { - application.configureServerJwk() + fun testJwkNoAuth() = testApplication { + configureServerJwk() - val response = handleRequest { - uri = "/" - } + val response = client.request("/") - verifyResponseUnauthorized(response) - } + verifyResponseUnauthorized(response) } @Test - fun testJwkSuccess() { - withApplication { - application.configureServerJwk(mock = true) + fun testJwkSuccess() = testApplication { + configureServerJwk(mock = true) - val token = getJwkToken() + val token = getJwkToken() - val response = handleRequestWithToken(token) + val response = handleRequestWithToken(token) - assertEquals(HttpStatusCode.OK, response.response.status()) - assertNotNull(response.response.content) - } + assertEquals(HttpStatusCode.OK, response.status) + assertTrue(response.bodyAsText().isNotEmpty()) } @Test - fun testJwkSuccessNoIssuer() { - withApplication { - application.configureServerJwkNoIssuer(mock = true) + fun testJwkSuccessNoIssuer() = testApplication { + configureServerJwkNoIssuer(mock = true) - val token = getJwkToken() + val token = getJwkToken() - val response = handleRequestWithToken(token) + val response = handleRequestWithToken(token) - assertEquals(HttpStatusCode.OK, response.response.status()) - assertNotNull(response.response.content) - } + assertEquals(HttpStatusCode.OK, response.status) + assertTrue(response.bodyAsText().isNotEmpty()) } @Test - fun testJwkSuccessWithLeeway() { - withApplication { - application.configureServerJwtWithLeeway(mock = true) + fun testJwkSuccessWithLeeway() = testApplication { + configureServerJwtWithLeeway(mock = true) - val token = getJwkToken() + val token = getJwkToken() - val response = handleRequestWithToken(token) + val response = handleRequestWithToken(token) - assertEquals(HttpStatusCode.OK, response.response.status()) - assertNotNull(response.response.content) - } + assertEquals(HttpStatusCode.OK, response.status) + assertTrue(response.bodyAsText().isNotEmpty()) } @Test - fun testJwtAuthSchemeMismatch() { - withApplication { - application.configureServerJwt() - val token = getToken().removePrefix("Bearer ") - val response = handleRequestWithToken(token) - verifyResponseUnauthorized(response) - } + fun testJwtAuthSchemeMismatch() = testApplication { + configureServerJwt() + val token = getToken().removePrefix("Bearer ") + val response = handleRequestWithToken(token) + verifyResponseUnauthorized(response) } @Test - fun testJwtAuthSchemeMismatch2() { - withApplication { - application.configureServerJwt() - val token = getToken("Token") - val response = handleRequestWithToken(token) - verifyResponseUnauthorized(response) - } + fun testJwtAuthSchemeMismatch2() = testApplication { + configureServerJwt() + val token = getToken("Token") + val response = handleRequestWithToken(token) + verifyResponseUnauthorized(response) } @Test - fun testJwtAuthSchemeMistake() { - withApplication { - application.configureServerJwt() - val token = getToken().replace("Bearer", "Bearer:") - val response = handleRequestWithToken(token) - verifyResponseBadRequest(response) - } + fun testJwtAuthSchemeMistake() = testApplication { + configureServerJwt() + val token = getToken().replace("Bearer", "Bearer:") + val response = handleRequestWithToken(token) + verifyResponseBadRequest(response) } @Test - fun testJwtBlobPatternMismatch() { - withApplication { - application.configureServerJwt() - val token = getToken().let { - val i = it.length - 2 - it.replaceRange(i..i + 1, " ") - } - val response = handleRequestWithToken(token) - verifyResponseUnauthorized(response) + fun testJwtBlobPatternMismatch() = testApplication { + configureServerJwt() + val token = getToken().let { + val i = it.length - 2 + it.replaceRange(i..i + 1, " ") } + val response = handleRequestWithToken(token) + verifyResponseUnauthorized(response) } @Test - fun testJwkAuthSchemeMismatch() { - withApplication { - application.configureServerJwk(mock = true) - val token = getJwkToken(false) - val response = handleRequestWithToken(token) - verifyResponseUnauthorized(response) - } + fun testJwkAuthSchemeMismatch() = testApplication { + configureServerJwk(mock = true) + val token = getJwkToken(false) + val response = handleRequestWithToken(token) + verifyResponseUnauthorized(response) } @Test - fun testJwkAuthSchemeMistake() { - withApplication { - application.configureServerJwk(mock = true) - val token = getJwkToken(true).replace("Bearer", "Bearer:") - val response = handleRequestWithToken(token) - verifyResponseBadRequest(response) - } + fun testJwkAuthSchemeMistake() = testApplication { + configureServerJwk(mock = true) + val token = getJwkToken(true).replace("Bearer", "Bearer:") + val response = handleRequestWithToken(token) + verifyResponseBadRequest(response) } @Test - fun testJwkBlobPatternMismatch() { - withApplication { - application.configureServerJwk(mock = true) - val token = getJwkToken(true).let { - val i = it.length - 2 - it.replaceRange(i..i + 1, " ") - } - val response = handleRequestWithToken(token) - verifyResponseUnauthorized(response) + fun testJwkBlobPatternMismatch() = testApplication { + configureServerJwk(mock = true) + val token = getJwkToken(true).let { + val i = it.length - 2 + it.replaceRange(i..i + 1, " ") } + val response = handleRequestWithToken(token) + verifyResponseUnauthorized(response) } @Test - fun testJwkAlgorithmMismatch() { - withApplication { - application.configureServerJwk(mock = true) - val token = JWT.create().withAudience(audience).withIssuer(issuer).sign(Algorithm.HMAC256("false")) - val response = handleRequestWithToken(token) - verifyResponseUnauthorized(response) - } + fun testJwkAlgorithmMismatch() = testApplication { + configureServerJwk(mock = true) + val token = JWT.create().withAudience(audience).withIssuer(issuer).sign(Algorithm.HMAC256("false")) + val response = handleRequestWithToken(token) + verifyResponseUnauthorized(response) } @Test - fun testJwkAudienceMismatch() { - withApplication { - application.configureServerJwk(mock = true) - val token = JWT.create().withAudience("wrong").withIssuer(issuer).sign(algorithm) - val response = handleRequestWithToken(token) - verifyResponseUnauthorized(response) - } + fun testJwkAudienceMismatch() = testApplication { + configureServerJwk(mock = true) + val token = JWT.create().withAudience("wrong").withIssuer(issuer).sign(algorithm) + val response = handleRequestWithToken(token) + verifyResponseUnauthorized(response) } @Test - fun testJwkIssuerMismatch() { - withApplication { - application.configureServerJwk(mock = true) - val token = JWT.create().withAudience(audience).withIssuer("wrong").sign(algorithm) - val response = handleRequestWithToken(token) - verifyResponseUnauthorized(response) - } + fun testJwkIssuerMismatch() = testApplication { + configureServerJwk(mock = true) + val token = JWT.create().withAudience(audience).withIssuer("wrong").sign(algorithm) + val response = handleRequestWithToken(token) + verifyResponseUnauthorized(response) } @Test - fun testJwkKidMismatch() { - withApplication { - application.configureServerJwk(mock = true) + fun testJwkKidMismatch() = testApplication { + configureServerJwk(mock = true) - val token = "Bearer " + JWT.create() - .withAudience(audience) - .withIssuer(issuer) - .withKeyId("wrong") - .sign(jwkAlgorithm) + val token = "Bearer " + JWT.create() + .withAudience(audience) + .withIssuer(issuer) + .withKeyId("wrong") + .sign(jwkAlgorithm) - val response = handleRequestWithToken(token) - verifyResponseUnauthorized(response) - } + val response = handleRequestWithToken(token) + verifyResponseUnauthorized(response) } @Test - fun testJwkInvalidToken() { - withApplication { - application.configureServerJwk(mock = true) - val token = "Bearer wrong" - val response = handleRequestWithToken(token) - verifyResponseUnauthorized(response) - } + fun testJwkInvalidToken() = testApplication { + configureServerJwk(mock = true) + val token = "Bearer wrong" + val response = handleRequestWithToken(token) + verifyResponseUnauthorized(response) } @Test - fun testJwkInvalidTokenCustomChallenge() { - withApplication { - application.configureServerJwk(mock = true, challenge = true) - val token = "Bearer wrong" - val response = handleRequestWithToken(token) - verifyResponseForbidden(response) - } + fun testJwkInvalidTokenCustomChallenge() = testApplication { + configureServerJwk(mock = true, challenge = true) + val token = "Bearer wrong" + val response = handleRequestWithToken(token) + verifyResponseForbidden(response) } @Test @@ -429,8 +369,8 @@ class JWTAuthTest { } @Test - fun authHeaderFromCookie(): Unit = withApplication { - application.configureServer { + fun authHeaderFromCookie(): Unit = testApplication { + configureServer { jwt { this@jwt.realm = this@JWTAuthTest.realm authHeader { call -> @@ -445,38 +385,37 @@ class JWTAuthTest { val token = getToken() - val response = handleRequest { - uri = "/" - addHeader(HttpHeaders.Cookie, "JWT=${token.encodeURLParameter()}") + val response = client.request("/") { + header(HttpHeaders.Cookie, "JWT=${token.encodeURLParameter()}") } - assertEquals(HttpStatusCode.OK, response.response.status()) - assertNotNull(response.response.content) + assertEquals(HttpStatusCode.OK, response.status) + assertTrue(response.bodyAsText().isNotEmpty()) } - private fun verifyResponseUnauthorized(response: TestApplicationCall) { - assertEquals(HttpStatusCode.Unauthorized, response.response.status()) - assertNull(response.response.content) + private suspend fun verifyResponseUnauthorized(response: HttpResponse) { + assertEquals(HttpStatusCode.Unauthorized, response.status) + assertTrue(response.bodyAsText().isEmpty()) } - private fun verifyResponseBadRequest(response: TestApplicationCall) { - assertEquals(HttpStatusCode.BadRequest, response.response.status()) - assertNull(response.response.content) + private suspend fun verifyResponseBadRequest(response: HttpResponse) { + assertEquals(HttpStatusCode.BadRequest, response.status) + assertTrue(response.bodyAsText().isEmpty()) } - private fun verifyResponseForbidden(response: TestApplicationCall) { - assertEquals(HttpStatusCode.Forbidden, response.response.status()) - assertNull(response.response.content) + private suspend fun verifyResponseForbidden(response: HttpResponse) { + assertEquals(HttpStatusCode.Forbidden, response.status) + assertTrue(response.bodyAsText().isEmpty()) } - private fun TestApplicationEngine.handleRequestWithToken(token: String): TestApplicationCall { - return handleRequest { - uri = "/" - addHeader(HttpHeaders.Authorization, token) + private suspend fun ApplicationTestBuilder.handleRequestWithToken(token: String): HttpResponse { + return client.request("/") { + header(HttpHeaders.Authorization, token) } } - private fun Application.configureServerJwk(mock: Boolean = false, challenge: Boolean = false) = configureServer { + private fun ApplicationTestBuilder.configureServerJwk(mock: Boolean = false, challenge: Boolean = false) = + configureServer { jwt { this@jwt.realm = this@JWTAuthTest.realm if (mock) { @@ -505,7 +444,7 @@ class JWTAuthTest { } } - private fun Application.configureServerJwkNoIssuer(mock: Boolean = false) = configureServer { + private fun ApplicationTestBuilder.configureServerJwkNoIssuer(mock: Boolean = false) = configureServer { jwt { this@jwt.realm = this@JWTAuthTest.realm if (mock) { @@ -523,7 +462,7 @@ class JWTAuthTest { } } - private fun Application.configureServerJwtWithLeeway(mock: Boolean = false) = configureServer { + private fun ApplicationTestBuilder.configureServerJwtWithLeeway(mock: Boolean = false) = configureServer { jwt { this@jwt.realm = this@JWTAuthTest.realm if (mock) { @@ -544,7 +483,7 @@ class JWTAuthTest { } } - private fun Application.configureServerJwt(extra: JWTAuthenticationProvider.Config.() -> Unit = {}) = + private fun ApplicationTestBuilder.configureServerJwt(extra: JWTAuthenticationProvider.Config.() -> Unit = {}) = configureServer { jwt { this@jwt.realm = this@JWTAuthTest.realm @@ -559,7 +498,7 @@ class JWTAuthTest { } } - private fun Application.configureServer(authBlock: (AuthenticationConfig.() -> Unit)) { + private fun ApplicationTestBuilder.configureServer(authBlock: (AuthenticationConfig.() -> Unit)) { install(Authentication) { authBlock(this) } diff --git a/ktor-server/ktor-server-test-host/common/src/io/ktor/server/testing/TestEngine.kt b/ktor-server/ktor-server-test-host/common/src/io/ktor/server/testing/TestEngine.kt index d3d25c2ef6b..db5ecabda79 100644 --- a/ktor-server/ktor-server-test-host/common/src/io/ktor/server/testing/TestEngine.kt +++ b/ktor-server/ktor-server-test-host/common/src/io/ktor/server/testing/TestEngine.kt @@ -2,8 +2,6 @@ * Copyright 2014-2019 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. */ -@file:Suppress("DEPRECATION") - package io.ktor.server.testing import io.ktor.events.* @@ -41,7 +39,10 @@ public fun TestApplicationEngine.handleRequest( /** * Starts a test application engine, passes it to the [test] function, and stops it. */ -@Deprecated("Please use new `testApplication` API: https://ktor.io/docs/migrating-2.html#testing-api") +@Deprecated( + "Use new `testApplication` API: https://ktor.io/docs/migrating-2.html#testing-api", + level = DeprecationLevel.ERROR, +) public fun withApplication( environment: ApplicationEnvironment = createTestEnvironment(), configure: TestApplicationEngine.Configuration.() -> Unit = {}, @@ -62,16 +63,24 @@ public fun withApplication( /** * Starts a test application engine, passes it to the [test] function, and stops it. */ -@Deprecated("Please use new `testApplication` API: https://ktor.io/docs/migrating-2.html#testing-api") +@Deprecated( + "Use new `testApplication` API: https://ktor.io/docs/migrating-2.html#testing-api", + level = DeprecationLevel.ERROR, +) public fun withTestApplication(test: TestApplicationEngine.() -> R): R { + @Suppress("DEPRECATION_ERROR") return withApplication(createTestEnvironment(), test = test) } /** * Starts a test application engine, passes it to the [test] function, and stops it. */ -@Deprecated("Please use new `testApplication` API: https://ktor.io/docs/migrating-2.html#testing-api") +@Deprecated( + "Use new `testApplication` API: https://ktor.io/docs/migrating-2.html#testing-api", + level = DeprecationLevel.ERROR, +) public fun withTestApplication(moduleFunction: Application.() -> Unit, test: TestApplicationEngine.() -> R): R { + @Suppress("DEPRECATION_ERROR") return withApplication(createTestEnvironment()) { moduleFunction(application) test() @@ -81,12 +90,16 @@ public fun withTestApplication(moduleFunction: Application.() -> Unit, test: /** * Starts a test application engine, passes it to the [test] function, and stops it. */ -@Deprecated("Please use new `testApplication` API: https://ktor.io/docs/migrating-2.html#testing-api") +@Deprecated( + "Use new `testApplication` API: https://ktor.io/docs/migrating-2.html#testing-api", + level = DeprecationLevel.ERROR, +) public fun withTestApplication( moduleFunction: Application.() -> Unit, configure: TestApplicationEngine.Configuration.() -> Unit = {}, test: TestApplicationEngine.() -> R ): R { + @Suppress("DEPRECATION_ERROR") return withApplication(createTestEnvironment(), configure) { moduleFunction(application) test() diff --git a/ktor-server/ktor-server-tests/jvm/test/io/ktor/server/testing/TestApplicationEngineTest.kt b/ktor-server/ktor-server-tests/jvm/test/io/ktor/server/testing/TestApplicationEngineTest.kt index 102ec0a49f8..2573c25498d 100644 --- a/ktor-server/ktor-server-tests/jvm/test/io/ktor/server/testing/TestApplicationEngineTest.kt +++ b/ktor-server/ktor-server-tests/jvm/test/io/ktor/server/testing/TestApplicationEngineTest.kt @@ -7,7 +7,6 @@ package io.ktor.server.testing import io.ktor.client.request.* import io.ktor.http.* import io.ktor.http.content.* -import io.ktor.server.application.* import io.ktor.server.plugins.calllogging.* import io.ktor.server.request.* import io.ktor.server.response.* @@ -18,7 +17,7 @@ import io.ktor.utils.io.charsets.* import io.ktor.utils.io.core.* import kotlinx.coroutines.* import kotlinx.io.* -import kotlinx.serialization.Serializable +import kotlinx.serialization.* import java.util.concurrent.atomic.* import kotlin.coroutines.* import kotlin.system.* @@ -26,8 +25,9 @@ import kotlin.test.* import kotlin.text.Charsets class TestApplicationEngineTest { + @Test - fun testCustomDispatcher() { + fun testCustomDispatcher() = testApplication { @OptIn(InternalCoroutinesApi::class) fun CoroutineDispatcher.withDelay(delay: Delay): CoroutineDispatcher = object : CoroutineDispatcher(), Delay by delay { @@ -41,247 +41,245 @@ class TestApplicationEngineTest { val delayLog = arrayListOf() val delayTime = 10_000L - withTestApplication( - moduleFunction = { - routing { - get("/") { - delay(delayTime) - delay(delayTime) - call.respondText("OK") - } - } - }, - configure = { - @OptIn(InternalCoroutinesApi::class) - dispatcher = Dispatchers.Unconfined.withDelay( - object : Delay { - override fun scheduleResumeAfterDelay( - timeMillis: Long, - continuation: CancellableContinuation - ) { - // Run immediately and log it - delayLog += "Delay($timeMillis)" - continuation.resume(Unit) - } - } - ) + routing { + get("/") { + delay(delayTime) + delay(delayTime) + call.respondText("OK") } - ) { - val elapsedTime = measureTimeMillis { - handleRequest(HttpMethod.Get, "/").let { call -> - assertTrue(call.response.status()!!.isSuccess()) + } + + engine { + @OptIn(InternalCoroutinesApi::class) + dispatcher = Dispatchers.Unconfined.withDelay( + object : Delay { + override fun scheduleResumeAfterDelay( + timeMillis: Long, + continuation: CancellableContinuation + ) { + // Run immediately and log it + delayLog += "Delay($timeMillis)" + continuation.resume(Unit) + } } + ) + } + + val engine = startApplicationAndGetEngine() + val elapsedTime = measureTimeMillis { + engine.handleRequest(HttpMethod.Get, "/").let { call -> + assertTrue(call.response.status()!!.isSuccess()) } - assertEquals(listOf("Delay($delayTime)", "Delay($delayTime)"), delayLog) - assertTrue { elapsedTime < (delayTime * 2) } } + assertEquals(listOf("Delay($delayTime)", "Delay($delayTime)"), delayLog) + assertTrue { elapsedTime < (delayTime * 2) } } @Test - fun testExceptionHandle() { - withTestApplication { - application.install(CallLogging) - application.routing { - get("/") { - error("Handle me") - } - } + fun testExceptionHandle() = testApplication { + install(CallLogging) - assertFails { - handleRequest(HttpMethod.Get, "/") { - } + routing { + get("/") { + error("Handle me") } } + + val engine = startApplicationAndGetEngine() + assertFails { + engine.handleRequest(HttpMethod.Get, "/") + } } @Test - fun testResponseAwait() { - withTestApplication { - application.install(RoutingRoot) { - get("/good") { - call.respond(HttpStatusCode.OK, "The Response") - } - get("/broken") { - delay(100) - call.respond(HttpStatusCode.OK, "The Response") - } - get("/fail") { - error("Handle me") - } + fun testResponseAwait() = testApplication { + install(RoutingRoot) { + get("/good") { + call.respond(HttpStatusCode.OK, "The Response") } - - with(handleRequest(HttpMethod.Get, "/good")) { - assertEquals(HttpStatusCode.OK, response.status()) - assertEquals("The Response", response.content) + get("/broken") { + delay(100) + call.respond(HttpStatusCode.OK, "The Response") } - - with(handleRequest(HttpMethod.Get, "/broken")) { - assertEquals(HttpStatusCode.OK, response.status()) - assertEquals("The Response", response.content) + get("/fail") { + error("Handle me") } + } - assertFailsWith { - handleRequest(HttpMethod.Get, "/fail") - } + val engine = startApplicationAndGetEngine() + with(engine.handleRequest(HttpMethod.Get, "/good")) { + assertEquals(HttpStatusCode.OK, response.status()) + assertEquals("The Response", response.content) + } + + with(engine.handleRequest(HttpMethod.Get, "/broken")) { + assertEquals(HttpStatusCode.OK, response.status()) + assertEquals("The Response", response.content) + } + + assertFailsWith { + engine.handleRequest(HttpMethod.Get, "/fail") } } @Test - fun testResponseAwaitWithCustomPort() { - withTestApplication { - application.install(RoutingRoot) { - port(7070) { - get("/good") { - call.respond(HttpStatusCode.OK, "The Response") - } + fun testResponseAwaitWithCustomPort() = testApplication { + install(RoutingRoot) { + port(7070) { + get("/good") { + call.respond(HttpStatusCode.OK, "The Response") } } + } - with(handleRequest(HttpMethod.Get, "/good") { port = 7070 }) { - assertEquals(HttpStatusCode.OK, response.status()) - assertEquals("The Response", response.content) - } + val engine = startApplicationAndGetEngine() + with(engine.handleRequest(HttpMethod.Get, "/good") { port = 7070 }) { + assertEquals(HttpStatusCode.OK, response.status()) + assertEquals("The Response", response.content) + } - with(handleRequest(HttpMethod.Get, "/good") { addHeader(HttpHeaders.Host, "localhost:7070") }) { - assertEquals(HttpStatusCode.OK, response.status()) - assertEquals("The Response", response.content) - } + with(engine.handleRequest(HttpMethod.Get, "/good") { addHeader(HttpHeaders.Host, "localhost:7070") }) { + assertEquals(HttpStatusCode.OK, response.status()) + assertEquals("The Response", response.content) } } @Test - fun testHookRequests() { + fun testHookRequests() = testApplication { val numberOfRequestsProcessed = AtomicInteger(0) val numberOfResponsesProcessed = AtomicInteger(0) - val dummyApplication: Application.() -> Unit = { - routing { - get("/") { - call.respond(HttpStatusCode.NoContent) - } - } - } - val expectedNumberOfCalls = 1 - withTestApplication(dummyApplication) { - // Injecting the hooks and checking they are invoked only once - hookRequests( - processRequest = { setup -> - numberOfRequestsProcessed.incrementAndGet() - setup() - }, - processResponse = { numberOfResponsesProcessed.incrementAndGet() } - ) { - handleRequest(HttpMethod.Get, "/").apply { - assertEquals(expectedNumberOfCalls, numberOfRequestsProcessed.get()) - assertEquals(expectedNumberOfCalls, numberOfResponsesProcessed.get()) - } + routing { + get("/") { + call.respond(HttpStatusCode.NoContent) } + } - // Outside hookRequests scope original processors are restored - // so further requests should not increment the counters - handleRequest(HttpMethod.Get, "/").apply { + val engine = startApplicationAndGetEngine() + + // Injecting the hooks and checking they are invoked only once + engine.hookRequests( + processRequest = { setup -> + numberOfRequestsProcessed.incrementAndGet() + setup() + }, + processResponse = { numberOfResponsesProcessed.incrementAndGet() } + ) { + engine.handleRequest(HttpMethod.Get, "/").apply { assertEquals(expectedNumberOfCalls, numberOfRequestsProcessed.get()) assertEquals(expectedNumberOfCalls, numberOfResponsesProcessed.get()) } } + + // Outside hookRequests scope original processors are restored + // so further requests should not increment the counters + engine.handleRequest(HttpMethod.Get, "/").apply { + assertEquals(expectedNumberOfCalls, numberOfRequestsProcessed.get()) + assertEquals(expectedNumberOfCalls, numberOfResponsesProcessed.get()) + } } @Test - fun testCookiesSession() { + fun testCookiesSession() = testApplication { + @Serializable data class CountSession(val count: Int) - withTestApplication { - application.install(Sessions) { - cookie("MY_SESSION") - } - application.routing { - get("/") { - val session = call.sessions.getOrSet { CountSession(0) } - call.sessions.set(session.copy(count = session.count + 1)) - call.respond(HttpStatusCode.OK, "${session.count}") - } - } + install(Sessions) { + cookie("MY_SESSION") + } - fun doRequestAndCheckResponse(expected: String) { - handleRequest(HttpMethod.Get, "/").apply { assertEquals(expected, response.content) } + routing { + get("/") { + val session = call.sessions.getOrSet { CountSession(0) } + call.sessions.set(session.copy(count = session.count + 1)) + call.respond(HttpStatusCode.OK, "${session.count}") } + } - // By defaul it doesn't preserve cookies - doRequestAndCheckResponse("0") - doRequestAndCheckResponse("0") + val engine = startApplicationAndGetEngine() - // Inside a cookiesSession block cookies are preserved. - cookiesSession { - doRequestAndCheckResponse("0") - doRequestAndCheckResponse("1") - } + fun doRequestAndCheckResponse(expected: String) { + engine.handleRequest(HttpMethod.Get, "/").apply { assertEquals(expected, response.content) } + } - // Starting another cookiesSession block, doesn't preserve cookies from previous blocks. - cookiesSession { - doRequestAndCheckResponse("0") - doRequestAndCheckResponse("1") - doRequestAndCheckResponse("2") - } + // By default it doesn't preserve cookies + doRequestAndCheckResponse("0") + doRequestAndCheckResponse("0") + + // Inside a cookiesSession block cookies are preserved. + engine.cookiesSession { + doRequestAndCheckResponse("0") + doRequestAndCheckResponse("1") + } + + // Starting another cookiesSession block doesn't preserve cookies from previous blocks. + engine.cookiesSession { + doRequestAndCheckResponse("0") + doRequestAndCheckResponse("1") + doRequestAndCheckResponse("2") } } @Test - fun accessNotExistingRouteTest() { - withTestApplication { - application.routing { - get("/exist") { - call.respondText("Routing exist") - } + fun accessNotExistingRouteTest() = testApplication { + routing { + get("/exist") { + call.respondText("Routing exist") } + } - val client = client.config { expectSuccess = false } - runBlocking { - val notExistingResponse = client.get("/notExist") - assertEquals(HttpStatusCode.NotFound, notExistingResponse.status) + val client = client.config { expectSuccess = false } - val existingResponse = client.get("/exist") - assertEquals(HttpStatusCode.OK, existingResponse.status) - } - } + val notExistingResponse = client.get("/notExist") + assertEquals(HttpStatusCode.NotFound, notExistingResponse.status) + + val existingResponse = client.get("/exist") + assertEquals(HttpStatusCode.OK, existingResponse.status) } @Test - fun testMultipart() { - withTestApplication { - application.routing { - post("/multipart") { - call.receiveMultipart().readPart() - call.respond(HttpStatusCode.OK, "OK") - } + fun testMultipart() = testApplication { + routing { + post("/multipart") { + call.receiveMultipart().readPart() + call.respond(HttpStatusCode.OK, "OK") } + } - val boundary = "***bbb***" - val multipart = listOf( - PartData.FileItem( - { ByteReadChannel("BODY".toByteArray()) }, - {}, - headersOf( - HttpHeaders.ContentDisposition, - ContentDisposition.File - .withParameter(ContentDisposition.Parameters.Name, "file") - .withParameter(ContentDisposition.Parameters.FileName, "test.jpg") - .toString() - ) + val boundary = "***bbb***" + val multipart = listOf( + PartData.FileItem( + { ByteReadChannel("BODY".toByteArray()) }, + {}, + headersOf( + HttpHeaders.ContentDisposition, + ContentDisposition.File + .withParameter(ContentDisposition.Parameters.Name, "file") + .withParameter(ContentDisposition.Parameters.FileName, "test.jpg") + .toString() ) ) + ) - val response = handleRequest(method = HttpMethod.Post, uri = "/multipart") { - addHeader( - HttpHeaders.ContentType, - ContentType.MultiPart.FormData.withParameter("boundary", boundary).toString() - ) - bodyChannel = buildMultipart(boundary, multipart) - } - assertEquals(HttpStatusCode.OK, response.response.status()) - } + val engine = startApplicationAndGetEngine() + val response = engine.handleRequest(HttpMethod.Post, "/multipart") { + addHeader( + HttpHeaders.ContentType, + ContentType.MultiPart.FormData.withParameter("boundary", boundary).toString() + ) + bodyChannel = buildMultipart(boundary, multipart) + }.response + assertEquals(HttpStatusCode.OK, response.status()) + } + + private suspend fun ApplicationTestBuilder.startApplicationAndGetEngine(): TestApplicationEngine { + var engine: TestApplicationEngine? = null + application { engine = this.engine as TestApplicationEngine } + startApplication() + return checkNotNull(engine) } }