Skip to content

Commit b593473

Browse files
feat(aws): support reasoning blocks for claude 3.7 (langchain-ai#7768)
1 parent 4a645ba commit b593473

File tree

5 files changed

+498
-27
lines changed

5 files changed

+498
-27
lines changed

libs/langchain-aws/package.json

+5-5
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,15 @@
3232
"author": "LangChain",
3333
"license": "MIT",
3434
"dependencies": {
35-
"@aws-sdk/client-bedrock-agent-runtime": "^3.749.0",
36-
"@aws-sdk/client-bedrock-runtime": "^3.749.0",
37-
"@aws-sdk/client-kendra": "^3.749.0",
38-
"@aws-sdk/credential-provider-node": "^3.749.0",
35+
"@aws-sdk/client-bedrock-agent-runtime": "^3.755.0",
36+
"@aws-sdk/client-bedrock-runtime": "^3.755.0",
37+
"@aws-sdk/client-kendra": "^3.750.0",
38+
"@aws-sdk/credential-provider-node": "^3.750.0",
3939
"zod": "^3.23.8",
4040
"zod-to-json-schema": "^3.22.5"
4141
},
4242
"peerDependencies": {
43-
"@langchain/core": ">=0.2.21 <0.4.0"
43+
"@langchain/core": ">=0.3.41 <0.4.0"
4444
},
4545
"devDependencies": {
4646
"@aws-sdk/types": "^3.734.0",

libs/langchain-aws/src/common.ts

+209-16
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,22 @@ import type {
2020
ContentBlockDeltaEvent,
2121
ConverseStreamMetadataEvent,
2222
ContentBlockStartEvent,
23+
ReasoningContentBlock,
24+
ReasoningContentBlockDelta,
25+
ReasoningTextBlock,
2326
} from "@aws-sdk/client-bedrock-runtime";
2427
import type { DocumentType as __DocumentType } from "@smithy/types";
2528
import { isLangChainTool } from "@langchain/core/utils/function_calling";
2629
import { zodToJsonSchema } from "zod-to-json-schema";
2730
import { ChatGenerationChunk } from "@langchain/core/outputs";
28-
import { ChatBedrockConverseToolType, BedrockToolChoice } from "./types.js";
31+
import {
32+
ChatBedrockConverseToolType,
33+
BedrockToolChoice,
34+
MessageContentReasoningBlock,
35+
MessageContentReasoningBlockReasoningText,
36+
MessageContentReasoningBlockReasoningTextPartial,
37+
MessageContentReasoningBlockRedacted,
38+
} from "./types.js";
2939

3040
export function extractImageInfo(base64: string): ContentBlock.ImageMember {
3141
// Extract the format from the base64 string
@@ -99,22 +109,34 @@ export function convertToConverseMessages(messages: BaseMessage[]): {
99109
text: castMsg.content,
100110
});
101111
} else if (Array.isArray(castMsg.content)) {
102-
const contentBlocks: ContentBlock[] = castMsg.content.map((block) => {
103-
if (block.type === "text" && block.text !== "") {
104-
return {
105-
text: block.text,
106-
};
107-
} else {
108-
const blockValues = Object.fromEntries(
109-
Object.entries(block).filter(([key]) => key !== "type")
110-
);
111-
throw new Error(
112-
`Unsupported content block type: ${
113-
block.type
114-
} with content of ${JSON.stringify(blockValues, null, 2)}`
115-
);
112+
const concatenatedBlocks = concatenateLangchainReasoningBlocks(
113+
castMsg.content
114+
);
115+
const contentBlocks: ContentBlock[] = concatenatedBlocks.map(
116+
(block) => {
117+
if (block.type === "text" && block.text !== "") {
118+
return {
119+
text: block.text,
120+
};
121+
} else if (block.type === "reasoning_content") {
122+
return {
123+
reasoningContent:
124+
langchainReasoningBlockToBedrockReasoningBlock(
125+
block as MessageContentReasoningBlock
126+
),
127+
};
128+
} else {
129+
const blockValues = Object.fromEntries(
130+
Object.entries(block).filter(([key]) => key !== "type")
131+
);
132+
throw new Error(
133+
`Unsupported content block type: ${
134+
block.type
135+
} with content of ${JSON.stringify(blockValues, null, 2)}`
136+
);
137+
}
116138
}
117-
});
139+
);
118140

119141
assistantMsg.content = [
120142
...(assistantMsg.content ? assistantMsg.content : []),
@@ -411,6 +433,12 @@ export function convertConverseMessageToLangChainMessage(
411433
});
412434
} else if ("text" in c && typeof c.text === "string") {
413435
content.push({ type: "text", text: c.text });
436+
} else if ("reasoningContent" in c) {
437+
content.push(
438+
bedrockReasoningBlockToLangchainReasoningBlock(
439+
c.reasoningContent as ReasoningContentBlock
440+
)
441+
);
414442
} else {
415443
content.push(c);
416444
}
@@ -453,6 +481,17 @@ export function handleConverseStreamContentBlockDelta(
453481
],
454482
}),
455483
});
484+
} else if (contentBlockDelta.delta.reasoningContent) {
485+
return new ChatGenerationChunk({
486+
text: "",
487+
message: new AIMessageChunk({
488+
content: [
489+
bedrockReasoningDeltaToLangchainPartialReasoningBlock(
490+
contentBlockDelta.delta.reasoningContent
491+
),
492+
],
493+
}),
494+
});
456495
} else {
457496
throw new Error(
458497
`Unsupported content block type(s): ${JSON.stringify(
@@ -512,3 +551,157 @@ export function handleConverseStreamMetadata(
512551
}),
513552
});
514553
}
554+
555+
export function bedrockReasoningDeltaToLangchainPartialReasoningBlock(
556+
reasoningContent: ReasoningContentBlockDelta
557+
):
558+
| MessageContentReasoningBlockReasoningTextPartial
559+
| MessageContentReasoningBlockRedacted {
560+
const { text, redactedContent, signature } = reasoningContent;
561+
if (text) {
562+
return {
563+
type: "reasoning_content",
564+
reasoningText: { text },
565+
};
566+
}
567+
if (signature) {
568+
return {
569+
type: "reasoning_content",
570+
reasoningText: { signature },
571+
};
572+
}
573+
if (redactedContent) {
574+
return {
575+
type: "reasoning_content",
576+
redactedContent: Buffer.from(redactedContent).toString("base64"),
577+
};
578+
}
579+
throw new Error("Invalid reasoning content");
580+
}
581+
582+
export function bedrockReasoningBlockToLangchainReasoningBlock(
583+
reasoningContent: ReasoningContentBlock
584+
): MessageContentReasoningBlock {
585+
const { reasoningText, redactedContent } = reasoningContent;
586+
if (reasoningText) {
587+
return {
588+
type: "reasoning_content",
589+
reasoningText: reasoningText as Required<ReasoningTextBlock>,
590+
};
591+
}
592+
593+
if (redactedContent) {
594+
return {
595+
type: "reasoning_content",
596+
redactedContent: Buffer.from(redactedContent).toString("base64"),
597+
};
598+
}
599+
throw new Error("Invalid reasoning content");
600+
}
601+
602+
export function langchainReasoningBlockToBedrockReasoningBlock(
603+
content: MessageContentReasoningBlock
604+
): ReasoningContentBlock {
605+
if (content.type !== "reasoning_content") {
606+
throw new Error("Invalid reasoning content");
607+
}
608+
if ("reasoningText" in content) {
609+
return {
610+
reasoningText: content.reasoningText as ReasoningTextBlock,
611+
};
612+
}
613+
if ("redactedContent" in content) {
614+
return {
615+
redactedContent: Buffer.from(content.redactedContent, "base64"),
616+
};
617+
}
618+
throw new Error("Invalid reasoning content");
619+
}
620+
621+
export function concatenateLangchainReasoningBlocks(
622+
content: Array<MessageContentComplex | MessageContentReasoningBlock>
623+
): MessageContentComplex[] {
624+
const concatenatedBlocks: MessageContentComplex[] = [];
625+
let concatenatedBlock: Partial<MessageContentReasoningBlock> = {};
626+
627+
for (const block of content) {
628+
if (block.type !== "reasoning_content") {
629+
// if it's some other block type, end the current block, but keep it so we preserve order
630+
if (Object.keys(concatenatedBlock).length > 0) {
631+
concatenatedBlocks.push(
632+
concatenatedBlock as MessageContentReasoningBlock
633+
);
634+
concatenatedBlock = {};
635+
}
636+
concatenatedBlocks.push(block);
637+
continue;
638+
}
639+
640+
// non-redacted block
641+
if ("reasoningText" in block && typeof block.reasoningText === "object") {
642+
if ("redactedContent" in concatenatedBlock) {
643+
// new type of block, so end the previous one
644+
concatenatedBlocks.push(
645+
concatenatedBlock as MessageContentReasoningBlock
646+
);
647+
concatenatedBlock = {};
648+
}
649+
const { text, signature } = block.reasoningText as Partial<
650+
MessageContentReasoningBlockReasoningText["reasoningText"]
651+
>;
652+
const { text: prevText, signature: prevSignature } = (
653+
"reasoningText" in concatenatedBlock
654+
? concatenatedBlock.reasoningText
655+
: {}
656+
) as Partial<MessageContentReasoningBlockReasoningText["reasoningText"]>;
657+
658+
concatenatedBlock = {
659+
type: "reasoning_content",
660+
reasoningText: {
661+
...((concatenatedBlock as MessageContentReasoningBlockReasoningText)
662+
.reasoningText ?? {}),
663+
...(prevText !== undefined || text !== undefined
664+
? { text: (prevText ?? "") + (text ?? "") }
665+
: {}),
666+
...(prevSignature !== undefined || signature !== undefined
667+
? { signature: (prevSignature ?? "") + (signature ?? "") }
668+
: {}),
669+
},
670+
};
671+
// if a partial block chunk has a signature, the next one will begin a new reasoning block.
672+
// full blocks always have signatures, so we start one now, anyway
673+
if ("signature" in block.reasoningText) {
674+
concatenatedBlocks.push(
675+
concatenatedBlock as MessageContentReasoningBlock
676+
);
677+
concatenatedBlock = {};
678+
}
679+
}
680+
681+
if ("redactedContent" in block) {
682+
if ("reasoningText" in concatenatedBlock) {
683+
// New type of block, so end the previous one. We should't really hit
684+
// this, as we'll have created a new block upon encountering the
685+
// signature above, but better safe than sorry.
686+
concatenatedBlocks.push(
687+
concatenatedBlock as MessageContentReasoningBlock
688+
);
689+
concatenatedBlock = {};
690+
}
691+
const { redactedContent } = block;
692+
const prevRedactedContent = (
693+
"redactedContent" in concatenatedBlock
694+
? concatenatedBlock.redactedContent!
695+
: ""
696+
) as Partial<MessageContentReasoningBlockRedacted["redactedContent"]>;
697+
concatenatedBlock = {
698+
type: "reasoning_content",
699+
redactedContent: prevRedactedContent + redactedContent,
700+
};
701+
}
702+
}
703+
if (Object.keys(concatenatedBlock).length > 0) {
704+
concatenatedBlocks.push(concatenatedBlock as MessageContentReasoningBlock);
705+
}
706+
return concatenatedBlocks;
707+
}

0 commit comments

Comments
 (0)