Skip to content

Commit

Permalink
* Rebase changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ag-ramachandran committed Jan 26, 2025
1 parent df63674 commit 96468c9
Show file tree
Hide file tree
Showing 13 changed files with 211 additions and 201 deletions.
17 changes: 0 additions & 17 deletions connector/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -37,28 +37,11 @@
<groupId>com.microsoft.azure.kusto</groupId>
<artifactId>kusto-data</artifactId>
<version>${kusto.sdk.version}</version>
<exclusions>
<exclusion>
<groupId>com.microsoft.azure</groupId>
<artifactId>msal4j</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>com.microsoft.azure</groupId>
<artifactId>msal4j</artifactId>
<version>${msal4j.version}</version>
</dependency>
<dependency>
<groupId>com.microsoft.azure.kusto</groupId>
<artifactId>kusto-ingest</artifactId>
<version>${kusto.sdk.version}</version>
<exclusions>
<exclusion>
<groupId>com.microsoft.azure</groupId>
<artifactId>msal4j</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>com.azure</groupId>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,59 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.microsoft.kusto.spark.authentication

import java.util.concurrent.{CompletableFuture, TimeUnit}
import java.util.function.Consumer
import com.microsoft.aad.msal4j.{DeviceCode, DeviceCodeFlowParameters, IAuthenticationResult}
import com.microsoft.azure.kusto.data.auth
import scala.concurrent.TimeoutException

/**
 * Device-code authentication helper layered on the Kusto SDK's
 * `auth.DeviceAuthTokenProvider`.
 *
 * The class caches, in public mutable fields, the last device code handed out by the
 * identity service, the instant at which that code is considered expired, and the last
 * access token obtained. No synchronization is performed on these fields.
 *
 * @param cluster   cluster URL forwarded to the SDK base class
 * @param authority authority / tenant identifier forwarded to the SDK base class
 */
class DeviceAuthentication(val cluster: String, val authority: String)
    extends auth.DeviceAuthTokenProvider(cluster, authority, null) {

  /** Device code most recently delivered to the callback; `None` until the first flow starts. */
  var currentDeviceCode: Option[DeviceCode] = None

  /** Epoch-millisecond instant after which the cached device code is treated as stale. */
  var expiresAt: Option[Long] = None

  /** Upper bound, in milliseconds, that `acquireNewAccessToken` waits for the flow to finish. */
  val NewDeviceCodeFetchTimeout = 60L * 1000L

  /** Token returned by the last successful refresh; `None` before any refresh. */
  var currentToken: Option[String] = None

  /**
   * Starts a device-code flow and blocks until it completes or
   * `NewDeviceCodeFetchTimeout` milliseconds elapse.
   */
  override def acquireNewAccessToken(): IAuthenticationResult = {
    val pendingResult = acquireNewAccessTokenAsync()
    pendingResult.get(NewDeviceCodeFetchTimeout, TimeUnit.MILLISECONDS)
  }

  /**
   * Starts a device-code flow without blocking.
   *
   * When the device code arrives it is cached, its expiry is recorded and its message is
   * printed to standard output so the user can complete the sign-in.
   */
  def acquireNewAccessTokenAsync(): CompletableFuture[IAuthenticationResult] = {
    val onDeviceCode: DeviceCode => Unit = { issuedCode =>
      this.currentDeviceCode = Some(issuedCode)
      // NOTE(review): assumes expiresIn() is expressed in seconds - confirm against msal4j.
      this.expiresAt = Some(System.currentTimeMillis + (issuedCode.expiresIn() * 1000))
      println(issuedCode.message())
    }
    val callback: Consumer[DeviceCode] = toJavaConsumer(onDeviceCode)

    val flowParameters: DeviceCodeFlowParameters =
      DeviceCodeFlowParameters.builder(scopes, callback).build
    clientApplication.acquireToken(flowParameters)
  }

  /** Adapts a Scala function returning `Unit` to a `java.util.function.Consumer`. */
  implicit def toJavaConsumer[T](f: Function1[T, Unit]): Consumer[T] = new Consumer[T] {
    override def accept(value: T): Unit = f(value)
  }

  /**
   * Fetches a token via `acquireAccessToken()` (not defined in this class) when no device
   * code has been seen yet or the recorded expiry instant has passed.
   */
  def refreshIfNeeded(): Unit = {
    // Short-circuit matters: expiresAt is only dereferenced once a device code exists.
    val mustRefresh =
      currentDeviceCode.isEmpty || expiresAt.get <= System.currentTimeMillis
    if (mustRefresh) {
      currentToken = Some(acquireAccessToken())
    }
  }

  /** Refreshes if required, then returns the user-facing message of the cached device code. */
  def getDeviceCodeMessage: String = {
    refreshIfNeeded()
    val cachedCode = this.currentDeviceCode.get
    cachedCode.message()
  }

  /** Refreshes if required, then returns the cached device code. */
  def getDeviceCode: DeviceCode = {
    refreshIfNeeded()
    this.currentDeviceCode.get
  }

  /** Refreshes if required, then returns the cached access token. */
  def acquireToken(): String = {
    refreshIfNeeded()
    currentToken.get
  }
}
//// Copyright (c) Microsoft Corporation. All rights reserved.
//// Licensed under the MIT License.
//
//package com.microsoft.kusto.spark.authentication
//
//import com.azure.core.credential.TokenRequestContext
//import com.azure.identity.DeviceCodeCredentialBuilder
//import com.microsoft.azure.kusto.data.auth
//
//import java.time.Duration
//import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`
//
//class DeviceAuthentication(val cluster: String, val authority: String)
// extends auth.DeviceAuthTokenProvider(cluster, authority, null) {
// private val newDeviceCodeFetchTimeout = 60L * 1000L
// private var expiresAt: Option[Long] = Some(0L)
// private var currentToken: Option[String] = None
//
// def acquireToken(): String = {
// refreshIfNeeded()
// currentToken.get
// }
//
// private def refreshIfNeeded(): Unit = {
// if (isRefreshNeeded) {
// val tokenCredential =
// acquireNewAccessToken.getToken(new TokenRequestContext().addScopes(scopes.toSeq: _*))
// val tokenCredentialValue =
// tokenCredential.blockOptional(Duration.ofMillis(newDeviceCodeFetchTimeout))
// tokenCredentialValue.ifPresent(token => {
// currentToken = Some(token.getToken)
// expiresAt = Some(token.getExpiresAt.toEpochSecond * 1000)
// })
// }
// }
//
// private def acquireNewAccessToken = {
// val deviceCodeFlowParams = new DeviceCodeCredentialBuilder()
// super.createTokenCredential(deviceCodeFlowParams)
// }
//
// private def isRefreshNeeded: Boolean = {
// expiresAt.isEmpty || expiresAt.get < System.currentTimeMillis()
// }
//}
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class ExtendedKustoClient(
maybeCrp: Option[ClientRequestProperties],
retryConfig: Option[RetryConfig] = None): KustoOperationResult = {
KDSU.retryApplyFunction(
() => dmClient.execute(ExtendedKustoClient.DefaultDb, command, maybeCrp.orNull),
() => dmClient.executeMgmt(ExtendedKustoClient.DefaultDb, command, maybeCrp.orNull),
retryConfig.getOrElse(this.retryConfig),
"Execute DM command with retries")
}
Expand Down Expand Up @@ -553,8 +553,12 @@ class ExtendedKustoClient(
crp: ClientRequestProperties,
retryConfig: Option[RetryConfig] = None): KustoOperationResult = {
// TODO - CID should reflect retries
val isMgmtCommand = command.startsWith(".")
KDSU.retryApplyFunction(
() => engineClient.execute(database, command, crp),
() =>
if (isMgmtCommand) {
engineClient.executeMgmt(database, command, crp)
} else { engineClient.executeQuery(database, command, crp) },
retryConfig.getOrElse(this.retryConfig),
"Execute engine command with retries")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,9 +315,9 @@ object KustoDataSourceUtils {
logWarn(
"parseSourceParameters",
"No authentication method was supplied - using device code authentication. The token should last for one hour")
val deviceCodeProvider = new DeviceAuthentication(clusterUrl, authorityId)
val accessToken = deviceCodeProvider.acquireToken()
authentication = KustoAccessTokenAuthentication(accessToken)
// val deviceCodeProvider = new DeviceAuthentication(clusterUrl, authorityId)
// val accessToken = deviceCodeProvider.acquireToken()
// authentication = KustoAccessTokenAuthentication(accessToken)
}
}
(authentication, keyVaultAuthentication)
Expand Down Expand Up @@ -722,7 +722,7 @@ object KustoDataSourceUtils {
val statusCol = "Status"
val statusCheck: () => Option[KustoResultSetTable] = () => {
try {
Some(client.execute(database, operationsShowCommand).getPrimaryResults)
Some(client.executeMgmt(database, operationsShowCommand).getPrimaryResults)
} catch {
case e: DataServiceException =>
if (e.isPermanent) {
Expand Down Expand Up @@ -852,7 +852,7 @@ object KustoDataSourceUtils {
query: String,
database: String,
crp: ClientRequestProperties): Int = {
val res = client.execute(database, generateCountQuery(query), crp).getPrimaryResults
val res = client.executeQuery(database, generateCountQuery(query), crp).getPrimaryResults
res.next()
res.getInt(0)
}
Expand All @@ -866,7 +866,9 @@ object KustoDataSourceUtils {
val estimationResult: util.List[AnyRef] = Await.result(
Future {
val res =
client.execute(database, generateEstimateRowsCountQuery(query), crp).getPrimaryResults
client
.executeQuery(database, generateEstimateRowsCountQuery(query), crp)
.getPrimaryResults
res.next()
res.getCurrentRow
},
Expand All @@ -887,7 +889,8 @@ object KustoDataSourceUtils {
if (estimatedCount == 0) {
Await.result(
Future {
val res = client.execute(database, generateCountQuery(query), crp).getPrimaryResults
val res =
client.executeQuery(database, generateCountQuery(query), crp).getPrimaryResults
res.next()
res.getInt(0)
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,6 @@ class ExtendedKustoClientTests extends AnyFlatSpec with Matchers {
WriteOptions(writeMode = WriteMode.Queued),
null,
true)
verify(stubbedClient.engineClient, times(0)).execute(any(), any(), any())
verify(stubbedClient.engineClient, times(0)).executeMgmt(any(), any(), any())
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.


package com.microsoft.kusto.spark

import com.microsoft.azure.kusto.data.ClientFactory
Expand All @@ -22,7 +21,8 @@ class KustoAuthenticationTestE2E extends AnyFlatSpec {
.appName("KustoSink")
.master(f"local[2]")
.getOrCreate()
private lazy val kustoConnectionOptions: KustoConnectionOptions = KustoTestUtils.getSystemTestOptions()
private lazy val kustoConnectionOptions: KustoConnectionOptions =
KustoTestUtils.getSystemTestOptions()

val keyVaultAppId: String = System.getProperty(KustoSinkOptions.KEY_VAULT_APP_ID)
val keyVaultAppKey: String = System.getProperty(KustoSinkOptions.KEY_VAULT_APP_KEY)
Expand All @@ -38,19 +38,25 @@ class KustoAuthenticationTestE2E extends AnyFlatSpec {
val prefix = "keyVaultAuthentication"
val table = KustoQueryUtils.simplifyName(s"${prefix}_${UUID.randomUUID()}")
val engineKcsb = ConnectionStringBuilder.createWithAadAccessTokenAuthentication(
kustoConnectionOptions.cluster,kustoConnectionOptions.accessToken)
kustoConnectionOptions.cluster,
kustoConnectionOptions.accessToken)
val kustoAdminClient = ClientFactory.createClient(engineKcsb)

val df = rows.toDF("name", "value")
val conf: Map[String, String] = Map(
KustoSinkOptions.KEY_VAULT_URI -> keyVaultUri,
KustoSinkOptions.KEY_VAULT_APP_ID -> (if (keyVaultAppId == null) "" else keyVaultAppId),
KustoSinkOptions.KEY_VAULT_APP_KEY -> (if (keyVaultAppKey == null) {""} else keyVaultAppKey),
KustoSinkOptions.KEY_VAULT_APP_KEY -> (if (keyVaultAppKey == null) { "" }
else keyVaultAppKey),
KustoSinkOptions.KUSTO_TABLE_CREATE_OPTIONS -> SinkTableCreationMode.CreateIfNotExist.toString)

df.write.kusto(kustoConnectionOptions.cluster, kustoConnectionOptions.database, table, conf)

val dfResult = spark.read.kusto(kustoConnectionOptions.cluster, kustoConnectionOptions.database, table, conf)
val dfResult = spark.read.kusto(
kustoConnectionOptions.cluster,
kustoConnectionOptions.database,
table,
conf)
val result = dfResult.select("name", "value").rdd.collect().sortBy(x => x.getInt(1))
val orig = df.select("name", "value").rdd.collect().sortBy(x => x.getInt(1))

Expand All @@ -75,41 +81,45 @@ class KustoAuthenticationTestE2E extends AnyFlatSpec {

df.write.kusto(kustoConnectionOptions.cluster, kustoConnectionOptions.database, table, conf)

val dfResult = spark.read.kusto(kustoConnectionOptions.cluster, kustoConnectionOptions.database, table, conf)
val dfResult = spark.read.kusto(
kustoConnectionOptions.cluster,
kustoConnectionOptions.database,
table,
conf)
val result = dfResult.select("name", "value").rdd.collect().sortBy(x => x.getInt(1))
val orig = df.select("name", "value").rdd.collect().sortBy(x => x.getInt(1))

assert(result.diff(orig).isEmpty)
}

"deviceAuthentication" should "use aad device authentication" taggedAs KustoE2E in {
import spark.implicits._
val expectedNumberOfRows = 1000
val timeoutMs: Int = 8 * 60 * 1000 // 8 minutes

val rows: immutable.IndexedSeq[(String, Int)] =
(1 to expectedNumberOfRows).map(v => (s"row-$v", v))
val prefix = "deviceAuthentication"
val table = KustoQueryUtils.simplifyName(s"${prefix}_${UUID.randomUUID()}")

val deviceAuth = new com.microsoft.kusto.spark.authentication.DeviceAuthentication(
kustoConnectionOptions.cluster,
kustoConnectionOptions.tenantId)
val token = deviceAuth.acquireToken()
val engineKcsb = ConnectionStringBuilder.createWithAadAccessTokenAuthentication(
kustoConnectionOptions.cluster,
token)
val kustoAdminClient = ClientFactory.createClient(engineKcsb)
val df = rows.toDF("name", "value")
val conf: Map[String, String] = Map(
KustoSinkOptions.KUSTO_TABLE_CREATE_OPTIONS -> SinkTableCreationMode.CreateIfNotExist.toString)
df.write.kusto(kustoConnectionOptions.cluster, kustoConnectionOptions.database, table, conf)
KustoTestUtils.validateResultsAndCleanup(
kustoAdminClient,
table,
kustoConnectionOptions.database,
expectedNumberOfRows,
timeoutMs,
tableCleanupPrefix = prefix)
}
// "deviceAuthentication" should "use aad device authentication" taggedAs KustoE2E in {
// import spark.implicits._
// val expectedNumberOfRows = 1000
// val timeoutMs: Int = 8 * 60 * 1000 // 8 minutes
//
// val rows: immutable.IndexedSeq[(String, Int)] =
// (1 to expectedNumberOfRows).map(v => (s"row-$v", v))
// val prefix = "deviceAuthentication"
// val table = KustoQueryUtils.simplifyName(s"${prefix}_${UUID.randomUUID()}")
//
// val deviceAuth = new com.microsoft.kusto.spark.authentication.DeviceAuthentication(
// kustoConnectionOptions.cluster,
// kustoConnectionOptions.tenantId)
// val token = deviceAuth.acquireToken()
// val engineKcsb = ConnectionStringBuilder.createWithAadAccessTokenAuthentication(
// kustoConnectionOptions.cluster,
// token)
// val kustoAdminClient = ClientFactory.createClient(engineKcsb)
// val df = rows.toDF("name", "value")
// val conf: Map[String, String] = Map(
// KustoSinkOptions.KUSTO_TABLE_CREATE_OPTIONS -> SinkTableCreationMode.CreateIfNotExist.toString)
// df.write.kusto(kustoConnectionOptions.cluster, kustoConnectionOptions.database, table, conf)
// KustoTestUtils.validateResultsAndCleanup(
// kustoAdminClient,
// table,
// kustoConnectionOptions.database,
// expectedNumberOfRows,
// timeoutMs,
// tableCleanupPrefix = prefix)
// }
}
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.


package com.microsoft.kusto.spark

import com.microsoft.azure.kusto.data.ClientFactory
import com.microsoft.azure.kusto.data.auth.ConnectionStringBuilder
import com.microsoft.kusto.spark.KustoTestUtils.getSystemTestOptions
import com.microsoft.kusto.spark.datasink.KustoSinkOptions
import com.microsoft.kusto.spark.datasource.{KustoResponseDeserializer, KustoSourceOptions, TransientStorageCredentials, TransientStorageParameters}
import com.microsoft.kusto.spark.datasource.{
KustoResponseDeserializer,
KustoSourceOptions,
TransientStorageCredentials,
TransientStorageParameters
}
import com.microsoft.kusto.spark.sql.extension.SparkExtension._

import java.util.concurrent.atomic.AtomicInteger
import com.microsoft.kusto.spark.utils.KustoQueryUtils.getQuerySchemaQuery
import com.microsoft.kusto.spark.utils.{CslCommandsGenerator, KustoBlobStorageUtils, KustoQueryUtils, KustoDataSourceUtils => KDSU}
import com.microsoft.kusto.spark.utils.{
CslCommandsGenerator,
KustoBlobStorageUtils,
KustoQueryUtils,
KustoDataSourceUtils => KDSU
}
import org.apache.spark.SparkContext
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.scalatest.BeforeAndAfterAll
Expand Down Expand Up @@ -94,7 +103,7 @@ class KustoBlobAccessE2E extends AnyFlatSpec with BeforeAndAfterAll {
val myTable = updateKustoTable()
val schema = KustoResponseDeserializer(
kustoAdminClient
.execute(kustoTestConnectionOptions.database, getQuerySchemaQuery(myTable))
.executeMgmt(kustoTestConnectionOptions.database, getQuerySchemaQuery(myTable))
.getPrimaryResults).getSchema

val firstColumn =
Expand Down Expand Up @@ -131,7 +140,7 @@ class KustoBlobAccessE2E extends AnyFlatSpec with BeforeAndAfterAll {
Some(partitionPredicate))

val blobs = kustoAdminClient
.execute(kustoTestConnectionOptions.database, exportCommand)
.executeMgmt(kustoTestConnectionOptions.database, exportCommand)
.getPrimaryResults
.getData
.asScala
Expand Down
Loading

0 comments on commit 96468c9

Please sign in to comment.