Skip to content

Commit

Permalink
add prefetch time in logic
Browse files Browse the repository at this point in the history
  • Loading branch information
peze committed Jan 1, 2025
1 parent 3d2787b commit 6f08be7
Show file tree
Hide file tree
Showing 14 changed files with 728 additions and 453 deletions.
48 changes: 22 additions & 26 deletions src/providers/ecs_ram_role.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import Credentials from '../credentials'
import CredentialsProvider from '../credentials_provider'
import { Request, doRequest } from './http'
import Session from './session'
import { parseUTC } from './time'
import { Session, SessionCredentialProvider, STALE_TIME } from './session'

const PREFETCH_TIME = 60 * 60;
const defaultMetadataTokenDuration = 21600; // 6 hours

export default class ECSRAMRoleCredentialsProvider implements CredentialsProvider {
export default class ECSRAMRoleCredentialsProvider extends SessionCredentialProvider implements CredentialsProvider {
private readonly roleName: string
private readonly disableIMDSv1: boolean
// for sts
private session: Session
private expirationTimestamp: number
// for refresher
private checker: NodeJS.Timeout
// for mock
private doRequest = doRequest;
private readonly readTimeout: number;
Expand All @@ -22,34 +20,33 @@ export default class ECSRAMRoleCredentialsProvider implements CredentialsProvide
}

constructor(builder: ECSRAMRoleCredentialsProviderBuilder) {
super(STALE_TIME, PREFETCH_TIME);
this.refresher = this.getCredentialsInternal;
this.roleName = builder.roleName;
this.disableIMDSv1 = builder.disableIMDSv1;
this.readTimeout = builder.readTimeout;
this.connectTimeout = builder.connectTimeout;
this.checker = this.checkCredentialsUpdateAsynchronously();
}

async getCredentials(): Promise<Credentials> {
if (!this.session || this.needUpdateCredential()) {
const session = await this.getCredentialsInternal();
const expirationTime = parseUTC(session.expiration);
this.session = session;
this.expirationTimestamp = expirationTime / 1000;
}

return Credentials.builder()
.withAccessKeyId(this.session.accessKeyId)
.withAccessKeySecret(this.session.accessKeySecret)
.withSecurityToken(this.session.securityToken)
.withProviderName(this.getProviderName())
.build();
checkCredentialsUpdateAsynchronously(): NodeJS.Timeout {
return setTimeout(async () => {
try {
await this.getCredentials();
} catch(err) {
console.error('CheckCredentialsUpdateAsynchronously Error:', err);
} finally {
this.checker = this.checkCredentialsUpdateAsynchronously();
}
}, 1000 * 60);
}

private needUpdateCredential(): boolean {
if (!this.expirationTimestamp) {
return true
close(): void {
if (this.checker != null) {
clearTimeout(this.checker);
this.checker = null;
}

return this.expirationTimestamp - (Date.now() / 1000) <= 180;
}

private async getMetadataToken(): Promise<string> {
Expand Down Expand Up @@ -139,7 +136,6 @@ export default class ECSRAMRoleCredentialsProvider implements CredentialsProvide

const request = builder.build();
const response = await this.doRequest(request);

if (response.statusCode !== 200) {
throw new Error(`get sts token failed, httpStatus: ${response.statusCode}, message = ${response.body.toString()}`);
}
Expand Down
36 changes: 4 additions & 32 deletions src/providers/oidc_role_arn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@ import { promisify } from 'util';

import Credentials from '../credentials';
import CredentialsProvider from '../credentials_provider';
import Session from './session';
import { Session, SessionCredentialProvider, STALE_TIME } from './session';
import * as utils from '../util/utils';
import { doRequest, Request } from './http';
import { parseUTC } from './time';

const readFileAsync = promisify(readFile);

Expand Down Expand Up @@ -143,7 +142,7 @@ class OIDCRoleArnCredentialsProviderBuilder {
}
}

export default class OIDCRoleArnCredentialsProvider implements CredentialsProvider {
export default class OIDCRoleArnCredentialsProvider extends SessionCredentialProvider implements CredentialsProvider {
private readonly roleArn: string;
private readonly oidcProviderArn: string;
private readonly oidcTokenFilePath: string;
Expand All @@ -156,15 +155,15 @@ export default class OIDCRoleArnCredentialsProvider implements CredentialsProvid
private readonly readTimeout: number;
private readonly connectTimeout: number;

private session: Session;
expirationTimestamp: number;
lastUpdateTimestamp: number;

static builder() {
return new OIDCRoleArnCredentialsProviderBuilder();
}

constructor(builder: OIDCRoleArnCredentialsProviderBuilder) {
super(STALE_TIME);
this.refresher = this.getCredentialsInternal;
this.roleArn = builder.roleArn;
this.oidcProviderArn = builder.oidcProviderArn;
this.oidcTokenFilePath = builder.oidcTokenFilePath;
Expand All @@ -178,25 +177,6 @@ export default class OIDCRoleArnCredentialsProvider implements CredentialsProvid
this.doRequest = doRequest;
}

async getCredentials(): Promise<Credentials> {
if (!this.session || this.needUpdateCredential()) {
const session = await this.getCredentialsInternal();
// UTC time: 2015-04-09T11:52:19Z
const expirationTime = parseUTC(session.expiration)

this.expirationTimestamp = Math.floor(expirationTime / 1000);
this.lastUpdateTimestamp = Date.now();
this.session = session
}

return Credentials.builder()
.withAccessKeyId(this.session.accessKeyId)
.withAccessKeySecret(this.session.accessKeySecret)
.withSecurityToken(this.session.securityToken)
.withProviderName(this.getProviderName())
.build();
}

getProviderName(): string {
return 'oidc_role_arn';
}
Expand Down Expand Up @@ -255,12 +235,4 @@ export default class OIDCRoleArnCredentialsProvider implements CredentialsProvid

return new Session(AccessKeyId, AccessKeySecret, SecurityToken, Expiration);
}

needUpdateCredential(): boolean {
if (!this.expirationTimestamp) {
return true
}

return this.expirationTimestamp - Date.now() / 1000 <= 180
}
}
43 changes: 7 additions & 36 deletions src/providers/ram_role_arn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ import * as utils from '../util/utils';
import Credentials from '../credentials';
import CredentialsProvider from '../credentials_provider'
import { doRequest, Request } from './http';
import { parseUTC } from './time';
import Session from './session';
import { Session, SessionCredentialProvider, STALE_TIME } from './session';

const log = debug('sign');

Expand Down Expand Up @@ -139,7 +138,7 @@ function encode(str: string): string {
.replace(/\*/g, '%2A');
}

export default class RAMRoleARNCredentialsProvider implements CredentialsProvider {
export default class RAMRoleARNCredentialsProvider extends SessionCredentialProvider implements CredentialsProvider {
private readonly credentialsProvider: CredentialsProvider;
private readonly stsEndpoint: string;
private readonly roleSessionName: string;
Expand All @@ -153,15 +152,15 @@ export default class RAMRoleARNCredentialsProvider implements CredentialsProvide
// used for mock
private doRequest = doRequest;

private session: Session;
private lastUpdateTimestamp: number;
private expirationTimestamp: any;

static builder(): RAMRoleARNCredentialsProviderBuilder {
return new RAMRoleARNCredentialsProviderBuilder();
}

constructor(builder: RAMRoleARNCredentialsProviderBuilder) {
super(STALE_TIME);
this.refresher = this.getCredentialsInternal;
this.credentialsProvider = builder.credentialsProvider;
this.stsEndpoint = builder.stsEndpoint;
this.roleSessionName = builder.roleSessionName;
Expand All @@ -173,7 +172,8 @@ export default class RAMRoleARNCredentialsProvider implements CredentialsProvide
this.connectTimeout = builder.connectTimeout;
}

private async getCredentialsInternal(credentials: Credentials): Promise<Session> {
private async getCredentialsInternal(): Promise<Session> {
const credentials = await this.credentialsProvider.getCredentials();
const method = 'POST';
const builder = Request.builder().withMethod(method).withProtocol('https').withHost(this.stsEndpoint).withReadTimeout(this.readTimeout || 10000).withConnectTimeout(this.connectTimeout || 5000);

Expand Down Expand Up @@ -274,36 +274,7 @@ export default class RAMRoleARNCredentialsProvider implements CredentialsProvide
return new Session(AccessKeyId, AccessKeySecret, SecurityToken, Expiration);
}

async getCredentials(): Promise<Credentials> {
if (!this.session || this.needUpdateCredential()) {
// 获取前置凭证
const previousCredentials = await this.credentialsProvider.getCredentials();
const session = await this.getCredentialsInternal(previousCredentials);
// UTC time: 2015-04-09T11:52:19Z
const expirationTime = parseUTC(session.expiration)

this.expirationTimestamp = Math.floor(expirationTime / 1000);
this.lastUpdateTimestamp = Date.now();
this.session = session
}

return Credentials.builder()
.withAccessKeyId(this.session.accessKeyId)
.withAccessKeySecret(this.session.accessKeySecret)
.withSecurityToken(this.session.securityToken)
.withProviderName(`${this.getProviderName()}/${this.credentialsProvider.getProviderName()}`)
.build();
}

needUpdateCredential(): boolean {
if (!this.expirationTimestamp) {
return true
}

return this.expirationTimestamp - Date.now() / 1000 <= 180
}

getProviderName(): string {
return 'ram_role_arn';
return `ram_role_arn/${this.credentialsProvider.getProviderName()}`;
}
}
131 changes: 130 additions & 1 deletion src/providers/session.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
export default class Session {
import { parseUTC } from './time'
import { getRandomInt } from '../util/utils'
import CredentialsProvider from '../credentials_provider'
import Credentials from '../credentials'

export const STALE_TIME = 15 * 60;

export class Session {
accessKeyId: string;
accessKeySecret: string;
securityToken: string;
Expand All @@ -11,3 +18,125 @@ export default class Session {
this.expiration = expiration;
}
}

export declare type SessionRefresher = () => Promise<Session>;

export class SessionCredentialProvider implements CredentialsProvider {
private expirationTimestamp: number;
private session: Session;
private refreshFaliure: number;
private readonly staleTime: number;
private readonly prefetchTime: number;
private staleTimestamp: number;
private prefetchTimestamp: number;
refresher: SessionRefresher;

constructor(staleTime: number = 0, prefetchTime: number = 0) {
this.staleTime = staleTime || STALE_TIME;
if(prefetchTime) {
this.prefetchTime = prefetchTime;
this.prefetchTimestamp = Date.now() + (prefetchTime * 1000);
}
this.refreshFaliure = 0;
}

async getCredentials(): Promise<Credentials> {
this.session = await this.getSession();

return Credentials.builder()
.withAccessKeyId(this.session.accessKeyId)
.withAccessKeySecret(this.session.accessKeySecret)
.withSecurityToken(this.session.securityToken)
.withProviderName(this.getProviderName())
.build();
}

refreshTimestamp() {
this.staleTimestamp = this.expirationTimestamp - this.staleTime;
if(this.prefetchTimestamp) {
this.prefetchTimestamp = (Date.now() + (this.prefetchTime * 1000)) / 1000;
}
this.refreshFaliure = 0;
}

maxStaleFailureJitter(): number {
const exponentialBackoffMillis = (1 << (this.refreshFaliure - 1));
return exponentialBackoffMillis > 10 ? exponentialBackoffMillis : 10;
}

jitterTime(time: number, jitterStart: number, jitterEnd: number): number {
const jitterRange = jitterEnd - jitterStart;
const jitterAmount = Math.abs(Math.floor(Math.random() * jitterRange));
return time + jitterStart + jitterAmount;
}

async refreshSession(): Promise<void> {
try {
const session = await this.refresher();
this.refreshFaliure = 0;
const now = Date.now() / 1000;
const oldSessionAvailable = this.staleTimestamp > now;
const oldSession = this.session;
this.expirationTimestamp = parseUTC(session.expiration) / 1000;
this.session = session;
this.refreshTimestamp();
// 过期时间大于15分钟,不用管
if (this.staleTimestamp > now) {
return;
}
// 不足或等于15分钟,但未过期,下次会再次刷新
if (now < (this.staleTimestamp + this.staleTime)) {
this.expirationTimestamp = now + this.staleTime;
}
// 已过期,看缓存,缓存若大于15分钟,返回缓存,若小于15分钟,则根据策略判断是立刻重试还是稍后重试
if (now > (this.staleTimestamp + this.staleTime)) {
if(oldSessionAvailable) {
this.session = oldSession;
this.expirationTimestamp = parseUTC(oldSession.expiration) / 1000;
this.refreshTimestamp();
return;
}
const waitUntilNextRefresh = 50 + getRandomInt(20);
this.expirationTimestamp = now + waitUntilNextRefresh + this.staleTime;
}
} catch(err) {
if (!this.session) {
throw err;
}
const now = Date.now() / 1000;
if (now < this.staleTimestamp) {
return;
}
this.refreshFaliure++;
this.expirationTimestamp = this.jitterTime(now, 1, this.maxStaleFailureJitter()) + this.staleTime;
}
}
async getSession(): Promise<Session> {
if (this.needUpdateCredential() || this.shouldPrefetchCredential()) {
await this.refreshSession();
this.refreshTimestamp();
}
return this.session;
}

needUpdateCredential(): boolean {
if (!this.session || !this.expirationTimestamp) {
return true;
}

return (Date.now() / 1000) >= this.staleTimestamp;
}

shouldPrefetchCredential(): boolean {
if (!this.prefetchTimestamp) {
return false;
}

return this.expirationTimestamp - (Date.now() / 1000) <= this.prefetchTime;
}

getProviderName(): string {
return 'session';
}
}

Loading

0 comments on commit 6f08be7

Please sign in to comment.