-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathCloudFetchResultHandler.ts
100 lines (81 loc) · 3.28 KB
/
CloudFetchResultHandler.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import fetch, { RequestInfo, RequestInit, Request } from 'node-fetch';
import { TGetResultSetMetadataResp, TRowSet, TSparkArrowResultLink } from '../../thrift/TCLIService_types';
import HiveDriverError from '../errors/HiveDriverError';
import IClientContext from '../contracts/IClientContext';
import IResultsProvider, { ResultsProviderFetchNextOptions } from './IResultsProvider';
import { ArrowBatch } from './utils';
import { LZ4 } from '../utils';
/**
 * Results provider that turns CloudFetch result links into downloaded Arrow batches.
 *
 * Links are read from the upstream `source` provider and queued in `pendingLinks`.
 * Downloads are started in link order, with no more than `cloudFetchConcurrentDownloads`
 * (taken from the client config) queued in `downloadTasks` at any time, and their results
 * are returned in that same order.
 */
export default class CloudFetchResultHandler implements IResultsProvider<ArrowBatch> {
  private readonly context: IClientContext;

  private readonly source: IResultsProvider<TRowSet | undefined>;

  private readonly isLZ4Compressed: boolean;

  /** Links received from `source` whose download has not been started yet. */
  private pendingLinks: Array<TSparkArrowResultLink> = [];

  /** Downloads that were already started, in the order their results must be returned. */
  private downloadTasks: Array<Promise<ArrowBatch>> = [];

  /**
   * @param context client context used to read the config and the connection provider
   * @param source upstream provider whose row sets carry `resultLinks`
   * @param metadata result set metadata; only `lz4Compressed` is used (a missing value means `false`)
   * @throws HiveDriverError when the result is LZ4 compressed but the `lz4` module is not available
   */
  constructor(
    context: IClientContext,
    source: IResultsProvider<TRowSet | undefined>,
    { lz4Compressed }: TGetResultSetMetadataResp,
  ) {
    this.context = context;
    this.source = source;
    this.isLZ4Compressed = lz4Compressed ?? false;

    if (this.isLZ4Compressed && !LZ4) {
      throw new HiveDriverError('Cannot handle LZ4 compressed result: module `lz4` not installed');
    }
  }

  /**
   * @returns `true` while links are still queued or downloads are still outstanding;
   *   otherwise whatever the upstream `source` reports
   */
  public async hasMore(): Promise<boolean> {
    if (this.pendingLinks.length > 0 || this.downloadTasks.length > 0) {
      return true;
    }
    return this.source.hasMore();
  }

  /**
   * Pulls the next row set from `source`, queues its result links, starts as many downloads
   * as there are free slots, and returns the oldest started download.
   *
   * @param options passed through unchanged to `source.fetchNext`
   * @returns the next downloaded batch (LZ4-decoded when the result is compressed), or an empty
   *   batch with `rowCount: 0` when no download is queued
   * @throws whatever the awaited download rejected with (expired link, HTTP error, fetch failure)
   */
  public async fetchNext(options: ResultsProviderFetchNextOptions): Promise<ArrowBatch> {
    const data = await this.source.fetchNext(options);

    data?.resultLinks?.forEach((link) => {
      this.pendingLinks.push(link);
    });

    const clientConfig = this.context.getConfig();
    const freeTaskSlotsCount = clientConfig.cloudFetchConcurrentDownloads - this.downloadTasks.length;

    if (freeTaskSlotsCount > 0) {
      const links = this.pendingLinks.splice(0, freeTaskSlotsCount);
      const tasks = links.map((link) => this.startDownload(link));
      this.downloadTasks.push(...tasks);
    }

    // Only the head of the queue is awaited; the remaining downloads keep running in background
    const batch = await this.downloadTasks.shift();
    if (!batch) {
      return {
        batches: [],
        rowCount: 0,
      };
    }

    if (this.isLZ4Compressed) {
      // The constructor already verified that `LZ4` is available when compression is enabled
      batch.batches = batch.batches.map((buffer) => LZ4!.decode(buffer));
    }

    return batch;
  }

  /**
   * Starts downloading a link and returns the download promise.
   *
   * A download may fail while it is still waiting in `downloadTasks`, i.e. before any caller
   * awaits it. Without a rejection handler Node.js would report that as an unhandled rejection.
   * The no-op handler below only marks the promise as handled: the original promise is what
   * gets returned, so the error is still thrown by the `fetchNext` call that awaits this task.
   */
  private startDownload(link: TSparkArrowResultLink): Promise<ArrowBatch> {
    const task = this.downloadLink(link);
    task.catch(() => undefined);
    return task;
  }

  /**
   * Downloads the file a single link points to.
   *
   * @param link result link; its `fileLink` is requested with `httpHeaders` as request headers
   * @returns a batch holding the response body as one `Buffer` and the link's row count
   * @throws Error when the link's `expiryTime` is not in the future, or when the HTTP response is not OK
   */
  private async downloadLink(link: TSparkArrowResultLink): Promise<ArrowBatch> {
    if (Date.now() >= link.expiryTime.toNumber()) {
      throw new Error('CloudFetch link has expired');
    }

    const response = await this.fetch(link.fileLink, { headers: link.httpHeaders });
    if (!response.ok) {
      throw new Error(`CloudFetch HTTP error ${response.status} ${response.statusText}`);
    }

    const result = await response.arrayBuffer();
    return {
      batches: [Buffer.from(result)],
      rowCount: link.rowCount.toNumber(true),
    };
  }

  /**
   * Performs an HTTP request using the agent and the retry policy of the connection provider.
   *
   * @param url request target
   * @param init extra request options; they override the `agent` default when both define it
   * @returns the response of the attempt accepted by the retry policy
   */
  private async fetch(url: RequestInfo, init?: RequestInit) {
    const connectionProvider = await this.context.getConnectionProvider();
    const agent = await connectionProvider.getAgent();
    const retryPolicy = await connectionProvider.getRetryPolicy();

    const requestConfig: RequestInit = { agent, ...init };
    const result = await retryPolicy.invokeWithRetry(() => {
      // A new `Request` is built every time the retry policy invokes this callback
      const request = new Request(url, requestConfig);
      return fetch(request).then((response) => ({ request, response }));
    });

    return result.response;
  }
}