-
Notifications
You must be signed in to change notification settings - Fork 2.5k
/
Copy pathmulti_route.ts
134 lines (119 loc) · 3.65 KB
/
multi_route.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import { ChainValues } from "@langchain/core/utils/types";
import {
CallbackManagerForChainRun,
Callbacks,
} from "@langchain/core/callbacks/manager";
import { BaseChain, ChainInputs } from "../../chains/base.js";
/**
 * Recursive value shape used for routed inputs (see `Route.nextInputs`).
 * It is a string-keyed object whose values may be strings, numbers,
 * further `Inputs` objects, or arrays of any of those.
 */
type Inputs = {
  [key: string]:
    | string
    | number
    | Inputs
    | string[]
    | number[]
    | Inputs[];
};
/**
 * The routing decision produced by a RouterChain.
 */
export interface Route {
  /** Name of the chain to route to; optional, so it may be absent. */
  destination?: string;
  /** Inputs to pass to whichever chain ends up being run (required). */
  nextInputs: Record<string, Inputs>;
}
/**
 * Constructor fields accepted by the MultiRouteChain class, on top of the
 * base ChainInputs.
 */
export interface MultiRouteChainInput extends ChainInputs {
  /** Chain whose `route` result selects the destination. */
  routerChain: RouterChain;
  /** Candidate destination chains, keyed by destination name. */
  destinationChains: Record<string, BaseChain>;
  /** Chain run when the router returns no destination. */
  defaultChain: BaseChain;
  /**
   * When true, an unknown destination is sent to `defaultChain` instead of
   * raising an error. MultiRouteChain treats an omitted value as false.
   */
  silentErrors?: boolean;
}
/**
 * Abstract base for chains that decide where inputs should be routed.
 * Its outputs are always the two keys `destination` and `next_inputs`.
 */
export abstract class RouterChain extends BaseChain {
  get outputKeys(): string[] {
    return ["destination", "next_inputs"];
  }

  /**
   * Runs this chain through `call` and repackages its raw output as a Route.
   * @param inputs Values forwarded unchanged to `call`.
   * @param callbacks Optional callbacks forwarded unchanged to `call`.
   * @returns The `destination` output and the `next_inputs` output, the
   * latter exposed under the camelCase name `nextInputs`.
   */
  async route(inputs: ChainValues, callbacks?: Callbacks): Promise<Route> {
    const { destination, next_inputs: nextInputs } = await this.call(
      inputs,
      callbacks
    );
    return { destination, nextInputs };
  }
}
/**
 * A chain that asks a router chain where its inputs should go and then
 * runs the selected destination chain, falling back to a default chain.
 *
 * Routing rules applied by `_call`, in order:
 * - falsy destination -> `defaultChain`
 * - destination that is an own key of `destinationChains` -> that chain
 * - unknown destination while `silentErrors` is true -> `defaultChain`
 * - unknown destination otherwise -> an Error is thrown
 */
export class MultiRouteChain extends BaseChain {
  static lc_name() {
    return "MultiRouteChain";
  }

  routerChain: RouterChain;

  destinationChains: { [name: string]: BaseChain };

  defaultChain: BaseChain;

  silentErrors = false;

  constructor(fields: MultiRouteChainInput) {
    super(fields);
    this.routerChain = fields.routerChain;
    this.destinationChains = fields.destinationChains;
    this.defaultChain = fields.defaultChain;
    this.silentErrors = fields.silentErrors ?? this.silentErrors;
  }

  /** The input keys are exactly those of the router chain. */
  get inputKeys(): string[] {
    return this.routerChain.inputKeys;
  }

  /** This chain declares no output keys of its own. */
  get outputKeys(): string[] {
    return [];
  }

  /**
   * Routes `values` through the router chain and runs the chosen chain.
   * @param values Inputs handed to the router chain.
   * @param runManager Optional run manager; child managers are derived from
   * it for the router call and for the chain that is run afterwards.
   * @returns The output of the destination (or default) chain.
   * @throws Error when the executed chain rejects (the message is prefixed
   * with the chain name), or when the destination is unknown and
   * `silentErrors` is false.
   */
  async _call(
    values: ChainValues,
    runManager?: CallbackManagerForChainRun
  ): Promise<ChainValues> {
    const { destination, nextInputs } = await this.routerChain.route(
      values,
      runManager?.getChild()
    );
    await runManager?.handleText(
      `${destination}: ${JSON.stringify(nextInputs)}`
    );

    // Runs one chain with the routed inputs and labels any failure with the
    // name of the chain that produced it.
    const runChain = (
      chain: BaseChain,
      label: string
    ): Promise<ChainValues> =>
      chain.call(nextInputs, runManager?.getChild()).catch((err) => {
        throw new Error(`Error in ${label} chain: ${err}`);
      });

    if (!destination) {
      return runChain(this.defaultChain, "default");
    }

    // Own-property check rather than `in`: `in` also matches inherited
    // members such as "constructor" or "toString", which would make a
    // router-produced name resolve to an Object.prototype member instead of
    // reaching the fallback / "not found" handling below.
    if (
      Object.prototype.hasOwnProperty.call(this.destinationChains, destination)
    ) {
      return runChain(this.destinationChains[destination], destination);
    }

    if (this.silentErrors) {
      return runChain(this.defaultChain, "default");
    }

    throw new Error(
      `Destination ${destination} not found in destination chains with keys ${Object.keys(
        this.destinationChains
      )}`
    );
  }

  _chainType(): string {
    return "multi_route_chain";
  }
}