Skip to content

[PECOBLR-314] add thrift protocol version handling #292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/DBSQLClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
const session = new DBSQLSession({
handle: definedOrError(response.sessionHandle),
context: this,
serverProtocolVersion: response.serverProtocolVersion,
});
this.sessions.add(session);
return session;
Expand Down
77 changes: 59 additions & 18 deletions lib/DBSQLSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import {
TSparkDirectResults,
TSparkArrowTypes,
TSparkParameter,
TProtocolVersion,
TExecuteStatementReq,
} from '../thrift/TCLIService_types';
import IDBSQLSession, {
ExecuteStatementOptions,
Expand All @@ -29,7 +31,7 @@ import IOperation from './contracts/IOperation';
import DBSQLOperation from './DBSQLOperation';
import Status from './dto/Status';
import InfoValue from './dto/InfoValue';
import { definedOrError, LZ4 } from './utils';
import { definedOrError, LZ4, ProtocolVersion } from './utils';
import CloseableCollection from './utils/CloseableCollection';
import { LogLevel } from './contracts/IDBSQLLogger';
import HiveDriverError from './errors/HiveDriverError';
Expand Down Expand Up @@ -74,13 +76,16 @@ function getDirectResultsOptions(maxRows: number | bigint | Int64 | null | undef
};
}

function getArrowOptions(config: ClientConfig): {
function getArrowOptions(
config: ClientConfig,
serverProtocolVersion: TProtocolVersion | undefined | null,
): {
canReadArrowResult: boolean;
useArrowNativeTypes?: TSparkArrowTypes;
} {
const { arrowEnabled = true, useArrowNativeTypes = true } = config;

if (!arrowEnabled) {
if (!arrowEnabled || !ProtocolVersion.supportsArrowMetadata(serverProtocolVersion)) {
return {
canReadArrowResult: false,
};
Expand Down Expand Up @@ -136,6 +141,7 @@ function getQueryParameters(
interface DBSQLSessionConstructorOptions {
handle: TSessionHandle;
context: IClientContext;
serverProtocolVersion?: TProtocolVersion;
}

export default class DBSQLSession implements IDBSQLSession {
Expand All @@ -145,14 +151,28 @@ export default class DBSQLSession implements IDBSQLSession {

private isOpen = true;

private serverProtocolVersion?: TProtocolVersion;

public onClose?: () => void;

private operations = new CloseableCollection<DBSQLOperation>();

constructor({ handle, context }: DBSQLSessionConstructorOptions) {
/**
* Helper method to determine if runAsync should be set for metadata operations
* @private
* @returns true if supported by protocol version, undefined otherwise
*/
private getRunAsyncForMetadataOperations(): boolean | undefined {
return ProtocolVersion.supportsAsyncMetadataOperations(this.serverProtocolVersion) ? true : undefined;
}

constructor({ handle, context, serverProtocolVersion }: DBSQLSessionConstructorOptions) {
this.sessionHandle = handle;
this.context = context;
// Get the server protocol version from the provided parameter (from TOpenSessionResp)
this.serverProtocolVersion = serverProtocolVersion;
this.context.getLogger().log(LogLevel.debug, `Session created with id: ${this.id}`);
this.context.getLogger().log(LogLevel.debug, `Server protocol version: ${this.serverProtocolVersion}`);
}

public get id() {
Expand Down Expand Up @@ -193,17 +213,29 @@ export default class DBSQLSession implements IDBSQLSession {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();
const operationPromise = driver.executeStatement({

const request = new TExecuteStatementReq({
sessionHandle: this.sessionHandle,
statement,
queryTimeout: options.queryTimeout ? numberToInt64(options.queryTimeout) : undefined,
runAsync: true,
...getDirectResultsOptions(options.maxRows, clientConfig),
...getArrowOptions(clientConfig),
canDownloadResult: options.useCloudFetch ?? clientConfig.useCloudFetch,
parameters: getQueryParameters(options.namedParameters, options.ordinalParameters),
canDecompressLZ4Result: (options.useLZ4Compression ?? clientConfig.useLZ4Compression) && Boolean(LZ4),
...getArrowOptions(clientConfig, this.serverProtocolVersion),
});

if (ProtocolVersion.supportsParameterizedQueries(this.serverProtocolVersion)) {
request.parameters = getQueryParameters(options.namedParameters, options.ordinalParameters);
}

if (ProtocolVersion.supportsArrowCompression(this.serverProtocolVersion)) {
request.canDecompressLZ4Result = (options.useLZ4Compression ?? clientConfig.useLZ4Compression) && Boolean(LZ4);
}

if (ProtocolVersion.supportsCloudFetch(this.serverProtocolVersion)) {
request.canDownloadResult = options.useCloudFetch ?? clientConfig.useCloudFetch;
}

const operationPromise = driver.executeStatement(request);
const response = await this.handleResponse(operationPromise);
const operation = this.createOperation(response);

Expand Down Expand Up @@ -352,9 +384,10 @@ export default class DBSQLSession implements IDBSQLSession {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();

const operationPromise = driver.getTypeInfo({
sessionHandle: this.sessionHandle,
runAsync: true,
runAsync: this.getRunAsyncForMetadataOperations(),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
Expand All @@ -371,9 +404,10 @@ export default class DBSQLSession implements IDBSQLSession {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();

const operationPromise = driver.getCatalogs({
sessionHandle: this.sessionHandle,
runAsync: true,
runAsync: this.getRunAsyncForMetadataOperations(),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
Expand All @@ -390,11 +424,12 @@ export default class DBSQLSession implements IDBSQLSession {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();

const operationPromise = driver.getSchemas({
sessionHandle: this.sessionHandle,
catalogName: request.catalogName,
schemaName: request.schemaName,
runAsync: true,
runAsync: this.getRunAsyncForMetadataOperations(),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
Expand All @@ -411,13 +446,14 @@ export default class DBSQLSession implements IDBSQLSession {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();

const operationPromise = driver.getTables({
sessionHandle: this.sessionHandle,
catalogName: request.catalogName,
schemaName: request.schemaName,
tableName: request.tableName,
tableTypes: request.tableTypes,
runAsync: true,
runAsync: this.getRunAsyncForMetadataOperations(),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
Expand All @@ -434,9 +470,10 @@ export default class DBSQLSession implements IDBSQLSession {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();

const operationPromise = driver.getTableTypes({
sessionHandle: this.sessionHandle,
runAsync: true,
runAsync: this.getRunAsyncForMetadataOperations(),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
Expand All @@ -453,13 +490,14 @@ export default class DBSQLSession implements IDBSQLSession {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();

const operationPromise = driver.getColumns({
sessionHandle: this.sessionHandle,
catalogName: request.catalogName,
schemaName: request.schemaName,
tableName: request.tableName,
columnName: request.columnName,
runAsync: true,
runAsync: this.getRunAsyncForMetadataOperations(),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
Expand All @@ -476,12 +514,13 @@ export default class DBSQLSession implements IDBSQLSession {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();

const operationPromise = driver.getFunctions({
sessionHandle: this.sessionHandle,
catalogName: request.catalogName,
schemaName: request.schemaName,
functionName: request.functionName,
runAsync: true,
runAsync: this.getRunAsyncForMetadataOperations(),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
Expand All @@ -492,12 +531,13 @@ export default class DBSQLSession implements IDBSQLSession {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();

const operationPromise = driver.getPrimaryKeys({
sessionHandle: this.sessionHandle,
catalogName: request.catalogName,
schemaName: request.schemaName,
tableName: request.tableName,
runAsync: true,
runAsync: this.getRunAsyncForMetadataOperations(),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
Expand All @@ -514,6 +554,7 @@ export default class DBSQLSession implements IDBSQLSession {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();

const operationPromise = driver.getCrossReference({
sessionHandle: this.sessionHandle,
parentCatalogName: request.parentCatalogName,
Expand All @@ -522,7 +563,7 @@ export default class DBSQLSession implements IDBSQLSession {
foreignCatalogName: request.foreignCatalogName,
foreignSchemaName: request.foreignSchemaName,
foreignTableName: request.foreignTableName,
runAsync: true,
runAsync: this.getRunAsyncForMetadataOperations(),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
Expand Down
3 changes: 2 additions & 1 deletion lib/utils/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ import definedOrError from './definedOrError';
import buildUserAgentString from './buildUserAgentString';
import formatProgress, { ProgressUpdateTransformer } from './formatProgress';
import LZ4 from './lz4';
import * as ProtocolVersion from './protocolVersion';

export { definedOrError, buildUserAgentString, formatProgress, ProgressUpdateTransformer, LZ4 };
export { definedOrError, buildUserAgentString, formatProgress, ProgressUpdateTransformer, LZ4, ProtocolVersion };
95 changes: 95 additions & 0 deletions lib/utils/protocolVersion.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import { TProtocolVersion } from '../../thrift/TCLIService_types';

/**
* Protocol version information from Thrift TCLIService
* Each version adds certain features to the Spark/Hive API
*
* Databricks only supports SPARK_CLI_SERVICE_PROTOCOL_V1 (0xA501) or higher
*/

/**
* Check if the current protocol version supports a specific feature
* @param serverProtocolVersion The protocol version received from server in TOpenSessionResp
* @param requiredVersion The minimum protocol version required for a feature
* @returns boolean indicating if the feature is supported
*/
export function isFeatureSupported(
serverProtocolVersion: TProtocolVersion | undefined | null,
requiredVersion: TProtocolVersion,
): boolean {
if (serverProtocolVersion === undefined || serverProtocolVersion === null) {
return false;
}

return serverProtocolVersion >= requiredVersion;
}

/**
* Check if parameterized queries are supported
* (Requires SPARK_CLI_SERVICE_PROTOCOL_V8 or higher)
* @param serverProtocolVersion The protocol version from server
* @returns boolean indicating if parameterized queries are supported
*/
export function supportsParameterizedQueries(serverProtocolVersion: TProtocolVersion | undefined | null): boolean {
return isFeatureSupported(serverProtocolVersion, TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8);
}

/**
* Check if async metadata operations are supported
* (Requires SPARK_CLI_SERVICE_PROTOCOL_V6 or higher)
* @param serverProtocolVersion The protocol version from server
* @returns boolean indicating if async metadata operations are supported
*/
export function supportsAsyncMetadataOperations(serverProtocolVersion: TProtocolVersion | undefined | null): boolean {
return isFeatureSupported(serverProtocolVersion, TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6);
}

/**
* Check if result persistence mode is supported
* (Requires SPARK_CLI_SERVICE_PROTOCOL_V7 or higher)
* @param serverProtocolVersion The protocol version from server
* @returns boolean indicating if result persistence mode is supported
*/
export function supportsResultPersistenceMode(serverProtocolVersion: TProtocolVersion | undefined | null): boolean {
return isFeatureSupported(serverProtocolVersion, TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7);
}

/**
* Check if Arrow compression is supported
* (Requires SPARK_CLI_SERVICE_PROTOCOL_V6 or higher)
* @param serverProtocolVersion The protocol version from server
* @returns boolean indicating if compressed Arrow batches are supported
*/
export function supportsArrowCompression(serverProtocolVersion: TProtocolVersion | undefined | null): boolean {
return isFeatureSupported(serverProtocolVersion, TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6);
}

/**
* Check if Arrow metadata is supported
* (Requires SPARK_CLI_SERVICE_PROTOCOL_V5 or higher)
* @param serverProtocolVersion The protocol version from server
* @returns boolean indicating if Arrow metadata is supported
*/
export function supportsArrowMetadata(serverProtocolVersion: TProtocolVersion | undefined | null): boolean {
return isFeatureSupported(serverProtocolVersion, TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V5);
}

/**
* Check if multiple catalogs are supported
* (Requires SPARK_CLI_SERVICE_PROTOCOL_V4 or higher)
* @param serverProtocolVersion The protocol version from server
* @returns boolean indicating if multiple catalogs are supported
*/
export function supportsMultipleCatalogs(serverProtocolVersion: TProtocolVersion | undefined | null): boolean {
return isFeatureSupported(serverProtocolVersion, TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4);
}

/**
* Check if cloud object storage fetching is supported
* (Requires SPARK_CLI_SERVICE_PROTOCOL_V3 or higher)
* @param serverProtocolVersion The protocol version from server
* @returns boolean indicating if cloud fetching is supported
*/
export function supportsCloudFetch(serverProtocolVersion: TProtocolVersion | undefined | null): boolean {
return isFeatureSupported(serverProtocolVersion, TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V3);
}
Loading
Loading