Skip to content

Use correct scopes for OAuth U2M and M2M flows #228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lib/DBSQLClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import Status from './dto/Status';
import HiveDriverError from './errors/HiveDriverError';
import { buildUserAgentString, definedOrError } from './utils';
import PlainHttpAuthentication from './connection/auth/PlainHttpAuthentication';
import DatabricksOAuth from './connection/auth/DatabricksOAuth';
import DatabricksOAuth, { OAuthFlow } from './connection/auth/DatabricksOAuth';
import IDBSQLLogger, { LogLevel } from './contracts/IDBSQLLogger';
import DBSQLLogger from './DBSQLLogger';
import CloseableCollection from './utils/CloseableCollection';
Expand Down Expand Up @@ -125,6 +125,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
});
case 'databricks-oauth':
return new DatabricksOAuth({
flow: options.oauthClientSecret === undefined ? OAuthFlow.U2M : OAuthFlow.M2M,
host: options.host,
persistence: options.persistence,
azureTenantId: options.azureTenantId,
Expand Down
85 changes: 59 additions & 26 deletions lib/connection/auth/DatabricksOAuth/OAuthManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@ import HiveDriverError from '../../../errors/HiveDriverError';
import { LogLevel } from '../../../contracts/IDBSQLLogger';
import OAuthToken from './OAuthToken';
import AuthorizationCode from './AuthorizationCode';
import { OAuthScope, OAuthScopes } from './OAuthScope';
import { OAuthScope, OAuthScopes, scopeDelimiter } from './OAuthScope';
import IClientContext from '../../../contracts/IClientContext';

export enum OAuthFlow {
U2M = 'U2M',
M2M = 'M2M',
}

export interface OAuthManagerOptions {
flow: OAuthFlow;
host: string;
callbackPorts?: Array<number>;
clientId?: string;
Expand Down Expand Up @@ -47,9 +53,7 @@ export default abstract class OAuthManager {

protected abstract getCallbackPorts(): Array<number>;

protected getScopes(requestedScopes: OAuthScopes): OAuthScopes {
return requestedScopes;
}
protected abstract getScopes(requestedScopes: OAuthScopes): OAuthScopes;

protected async getClient(): Promise<BaseClient> {
// Obtain http agent each time when we need an OAuth client
Expand Down Expand Up @@ -113,17 +117,11 @@ export default abstract class OAuthManager {
if (!accessToken || !refreshToken) {
throw new Error('Failed to refresh token: invalid response');
}
return new OAuthToken(accessToken, refreshToken);
return new OAuthToken(accessToken, refreshToken, token.scopes);
}

private async refreshAccessTokenM2M(): Promise<OAuthToken> {
const { access_token: accessToken, refresh_token: refreshToken } = await this.getTokenM2M();

if (!accessToken) {
throw new Error('Failed to fetch access token');
}

return new OAuthToken(accessToken, refreshToken);
private async refreshAccessTokenM2M(token: OAuthToken): Promise<OAuthToken> {
return this.getTokenM2M(token.scopes ?? []);
}

public async refreshAccessToken(token: OAuthToken): Promise<OAuthToken> {
Expand All @@ -137,10 +135,16 @@ export default abstract class OAuthManager {
throw error;
}

return this.options.clientSecret === undefined ? this.refreshAccessTokenU2M(token) : this.refreshAccessTokenM2M();
switch (this.options.flow) {
case OAuthFlow.U2M:
return this.refreshAccessTokenU2M(token);
case OAuthFlow.M2M:
return this.refreshAccessTokenM2M(token);
// no default
}
}

private async getTokenU2M(scopes: OAuthScopes) {
private async getTokenU2M(scopes: OAuthScopes): Promise<OAuthToken> {
const client = await this.getClient();

const authCode = new AuthorizationCode({
Expand All @@ -153,37 +157,47 @@ export default abstract class OAuthManager {

const { code, verifier, redirectUri } = await authCode.fetch(mappedScopes);

return client.grant({
const { access_token: accessToken, refresh_token: refreshToken } = await client.grant({
grant_type: 'authorization_code',
code,
code_verifier: verifier,
redirect_uri: redirectUri,
});

if (!accessToken) {
throw new Error('Failed to fetch access token');
}
return new OAuthToken(accessToken, refreshToken, mappedScopes);
}

private async getTokenM2M() {
private async getTokenM2M(scopes: OAuthScopes): Promise<OAuthToken> {
const client = await this.getClient();

const mappedScopes = this.getScopes(scopes);

// M2M flow doesn't really support token refreshing, and refresh should not be available
// in response. Each time access token expires, client can just acquire a new one using
// client secret. Here we explicitly return access token only as a sign that we're not going
// to use refresh token for M2M flow anywhere later
const { access_token: accessToken } = await client.grant({
grant_type: 'client_credentials',
scope: 'all-apis', // this is the only allowed scope for M2M flow
scope: mappedScopes.join(scopeDelimiter),
});
return { access_token: accessToken, refresh_token: undefined };
}

public async getToken(scopes: OAuthScopes): Promise<OAuthToken> {
const { access_token: accessToken, refresh_token: refreshToken } =
this.options.clientSecret === undefined ? await this.getTokenU2M(scopes) : await this.getTokenM2M();

if (!accessToken) {
throw new Error('Failed to fetch access token');
}
return new OAuthToken(accessToken, undefined, mappedScopes);
}

return new OAuthToken(accessToken, refreshToken);
public async getToken(scopes: OAuthScopes): Promise<OAuthToken> {
switch (this.options.flow) {
case OAuthFlow.U2M:
return this.getTokenU2M(scopes);
case OAuthFlow.M2M:
return this.getTokenM2M(scopes);
// no default
}
}

public static getManager(options: OAuthManagerOptions): OAuthManager {
Expand Down Expand Up @@ -245,6 +259,14 @@ export class DatabricksOAuthManager extends OAuthManager {
protected getCallbackPorts(): Array<number> {
return this.options.callbackPorts ?? DatabricksOAuthManager.defaultCallbackPorts;
}

protected getScopes(requestedScopes: OAuthScopes): OAuthScopes {
if (this.options.flow === OAuthFlow.M2M) {
// this is the only allowed scope for M2M flow
return [OAuthScope.allAPIs];
}
return requestedScopes;
}
}

export class AzureOAuthManager extends OAuthManager {
Expand Down Expand Up @@ -273,7 +295,18 @@ export class AzureOAuthManager extends OAuthManager {
protected getScopes(requestedScopes: OAuthScopes): OAuthScopes {
// There is no corresponding scopes in Azure, instead, access control will be delegated to Databricks
const tenantId = this.options.azureTenantId ?? AzureOAuthManager.datatricksAzureApp;
const azureScopes = [`${tenantId}/user_impersonation`];

const azureScopes = [];

switch (this.options.flow) {
case OAuthFlow.U2M:
azureScopes.push(`${tenantId}/user_impersonation`);
break;
case OAuthFlow.M2M:
azureScopes.push(`${tenantId}/.default`);
break;
// no default
}

if (requestedScopes.includes(OAuthScope.offlineAccess)) {
azureScopes.push(OAuthScope.offlineAccess);
Expand Down
1 change: 1 addition & 0 deletions lib/connection/auth/DatabricksOAuth/OAuthScope.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
export enum OAuthScope {
offlineAccess = 'offline_access',
SQL = 'sql',
allAPIs = 'all-apis',
}

export type OAuthScopes = Array<string>;
Expand Down
11 changes: 10 additions & 1 deletion lib/connection/auth/DatabricksOAuth/OAuthToken.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import { OAuthScopes } from './OAuthScope';

export default class OAuthToken {
private readonly _accessToken: string;

private readonly _refreshToken?: string;

private readonly _scopes?: OAuthScopes;

private _expirationTime?: number;

constructor(accessToken: string, refreshToken?: string) {
constructor(accessToken: string, refreshToken?: string, scopes?: OAuthScopes) {
this._accessToken = accessToken;
this._refreshToken = refreshToken;
this._scopes = scopes;
}

get accessToken(): string {
Expand All @@ -18,6 +23,10 @@ export default class OAuthToken {
return this._refreshToken;
}

get scopes(): OAuthScopes | undefined {
return this._scopes;
}

get expirationTime(): number {
// This token has already been verified, and we are just parsing it.
// If it has been tampered with, it will be rejected on the server side.
Expand Down
4 changes: 3 additions & 1 deletion lib/connection/auth/DatabricksOAuth/index.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import { HeadersInit } from 'node-fetch';
import IAuthentication from '../../contracts/IAuthentication';
import OAuthPersistence, { OAuthPersistenceCache } from './OAuthPersistence';
import OAuthManager, { OAuthManagerOptions } from './OAuthManager';
import OAuthManager, { OAuthManagerOptions, OAuthFlow } from './OAuthManager';
import { OAuthScopes, defaultOAuthScopes } from './OAuthScope';
import IClientContext from '../../../contracts/IClientContext';

export { OAuthFlow };

interface DatabricksOAuthOptions extends OAuthManagerOptions {
scopes?: OAuthScopes;
persistence?: OAuthPersistence;
Expand Down
1 change: 1 addition & 0 deletions tests/unit/DBSQLClient.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ describe('DBSQLClient.initAuthProvider', () => {
authType: 'databricks-oauth',
// host is used when creating OAuth manager, so make it look like a real AWS instance
host: 'example.dev.databricks.com',
oauthClientSecret: 'test-secret',
});

expect(provider).to.be.instanceOf(DatabricksOAuth);
Expand Down
Loading