Skip to content

Commit 508ef17

Browse files
authored
[HLSL][RootSignature] Introduce HLSLFrontendAction to implement rootsig-define (#154639)
This pr implements the functionality of `rootsig-define` as described [here](https://github.com/llvm/wg-hlsl/blob/main/proposals/0029-root-signature-driver-options.md#option--rootsig-define). This is accomplished by: - Defining the `fdx-rootsignature-define`, and `rootsig-define` alias, driver options. It simply specifies the name of a macro that will expand to a `LiteralString` to be interpreted as a root signature. - Introduces a new general frontend action wrapper, `HLSLFrontendAction`. This class allows us to introduce `HLSL` specific behaviour on the underlying action (primarily `ASTFrontendAction`). Which will be further extended, or modularly wrapped, when considering future DXC options. - Using `HLSLFrontendAction` we can add a new `PPCallback` that will eagerly parse the root signature specified with `rootsig-define` and push it as a `TopLevelDecl` to `Sema`. This occurs when the macro has been lexed. - Since the root signature is parsed early, before any function declarations, we can then simply attach it to the entry function once it is encountered. Overwriting any applicable root signature attrs. Resolves #150274 ##### Implementation considerations To implement this feature, note that: 1. We need access to all defined macros. These are created as part of the first `Lex` in `Parser::Initialize` after `PP->EnterMainSourceFile` 2. `RootSignatureDecl` must be added to `Sema` before `Consumer->HandleTranslationUnit` is invoked in `ParseAST` Therefore, we can't handle the root signature in `HLSLFrontendAction::ExecuteAction` before (from 1.) or after (from 2.) invoking the underlying `ASTFrontendAction`. This means we could alternatively: - Manually handle this case [here](https://github.com/llvm/llvm-project/blob/ac8f0bb070c9071742b6f6ce03bebc9d87217830/clang/lib/Parse/ParseAST.cpp#L168) before parsing the first top level decl. - Hook into when we [return the entry function decl](https://github.com/llvm/llvm-project/blob/ac8f0bb070c9071742b6f6ce03bebc9d87217830/clang/lib/Parse/Parser.cpp#L1190) and then parse the root signature and override its `RootSignatureAttr`. The proposed solution handles this in the most modular way which should work on any `FrontendAction` that might use the `Parser` without invoking `ParseAST`, and, is not subject to needing to call the hook in multiple different places of function declarators.
1 parent 1d3c302 commit 508ef17

File tree

16 files changed

+284
-22
lines changed

16 files changed

+284
-22
lines changed

clang/include/clang/Basic/LangOptions.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,10 @@ class LangOptions : public LangOptionsBase {
552552
llvm::dxbc::RootSignatureVersion HLSLRootSigVer =
553553
llvm::dxbc::RootSignatureVersion::V1_1;
554554

555+
/// The HLSL root signature that will be used to overide the root signature
556+
/// used for the shader entry point.
557+
std::string HLSLRootSigOverride;
558+
555559
// Indicates if the wasm-opt binary must be ignored in the case of a
556560
// WebAssembly target.
557561
bool NoWasmOpt = false;

clang/include/clang/Driver/Options.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9436,6 +9436,18 @@ def dxc_rootsig_ver :
94369436
Alias<fdx_rootsignature_version>,
94379437
Group<dxc_Group>,
94389438
Visibility<[DXCOption]>;
9439+
def fdx_rootsignature_define :
9440+
Joined<["-"], "fdx-rootsignature-define=">,
9441+
Group<dxc_Group>,
9442+
Visibility<[ClangOption, CC1Option]>,
9443+
MarshallingInfoString<LangOpts<"HLSLRootSigOverride">, "\"\"">,
9444+
HelpText<"Override entry function root signature with root signature at "
9445+
"given macro name.">;
9446+
def dxc_rootsig_define :
9447+
Separate<["-"], "rootsig-define">,
9448+
Alias<fdx_rootsignature_define>,
9449+
Group<dxc_Group>,
9450+
Visibility<[DXCOption]>;
94399451
def hlsl_entrypoint : Option<["-"], "hlsl-entry", KIND_SEPARATE>,
94409452
Group<dxc_Group>,
94419453
Visibility<[ClangOption, CC1Option]>,

clang/include/clang/Frontend/FrontendActions.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,18 @@ class GetDependenciesByModuleNameAction : public PreprocessOnlyAction {
329329
: ModuleName(ModuleName) {}
330330
};
331331

332+
//===----------------------------------------------------------------------===//
333+
// HLSL Specific Actions
334+
//===----------------------------------------------------------------------===//
335+
336+
class HLSLFrontendAction : public WrapperFrontendAction {
337+
protected:
338+
void ExecuteAction() override;
339+
340+
public:
341+
HLSLFrontendAction(std::unique_ptr<FrontendAction> WrappedAction);
342+
};
343+
332344
} // end namespace clang
333345

334346
#endif

clang/include/clang/Parse/ParseHLSLRootSignature.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,10 @@ class RootSignatureParser {
236236
RootSignatureToken CurToken;
237237
};
238238

239+
IdentifierInfo *ParseHLSLRootSignature(Sema &Actions,
240+
llvm::dxbc::RootSignatureVersion Version,
241+
StringLiteral *Signature);
242+
239243
} // namespace hlsl
240244
} // namespace clang
241245

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,10 @@ class SemaHLSL : public SemaBase {
153153
ActOnFinishRootSignatureDecl(SourceLocation Loc, IdentifierInfo *DeclIdent,
154154
ArrayRef<hlsl::RootSignatureElement> Elements);
155155

156+
void SetRootSignatureOverride(IdentifierInfo *DeclIdent) {
157+
RootSigOverrideIdent = DeclIdent;
158+
}
159+
156160
// Returns true if any RootSignatureElement is invalid and a diagnostic was
157161
// produced
158162
bool
@@ -221,6 +225,8 @@ class SemaHLSL : public SemaBase {
221225

222226
uint32_t ImplicitBindingNextOrderID = 0;
223227

228+
IdentifierInfo *RootSigOverrideIdent = nullptr;
229+
224230
private:
225231
void collectResourceBindingsOnVarDecl(VarDecl *D);
226232
void collectResourceBindingsOnUserRecordDecl(const VarDecl *VD,

clang/lib/CodeGen/CodeGenModule.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7534,6 +7534,9 @@ void CodeGenModule::EmitTopLevelDecl(Decl *D) {
75347534
getContext().getCanonicalTagType(cast<EnumDecl>(D)));
75357535
break;
75367536

7537+
case Decl::HLSLRootSignature:
7538+
// Will be handled by attached function
7539+
break;
75377540
case Decl::HLSLBuffer:
75387541
getHLSLRuntime().addBuffer(cast<HLSLBufferDecl>(D));
75397542
break;

clang/lib/Driver/ToolChains/Clang.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3801,6 +3801,7 @@ static void RenderHLSLOptions(const ArgList &Args, ArgStringList &CmdArgs,
38013801
options::OPT_disable_llvm_passes,
38023802
options::OPT_fnative_half_type,
38033803
options::OPT_hlsl_entrypoint,
3804+
options::OPT_fdx_rootsignature_define,
38043805
options::OPT_fdx_rootsignature_version};
38053806
if (!types::isHLSL(InputType))
38063807
return;

clang/lib/Driver/ToolChains/HLSL.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,13 @@ HLSLToolChain::TranslateArgs(const DerivedArgList &Args, StringRef BoundArch,
351351
A->claim();
352352
continue;
353353
}
354+
if (A->getOption().getID() == options::OPT_dxc_rootsig_define) {
355+
DAL->AddJoinedArg(nullptr,
356+
Opts.getOption(options::OPT_fdx_rootsignature_define),
357+
A->getValue());
358+
A->claim();
359+
continue;
360+
}
354361
if (A->getOption().getID() == options::OPT__SLASH_O) {
355362
StringRef OStr = A->getValue();
356363
if (OStr == "d") {

clang/lib/Frontend/CompilerInvocation.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,10 @@ static bool FixupInvocation(CompilerInvocation &Invocation,
640640
Diags.Report(diag::err_drv_argument_not_allowed_with)
641641
<< "-fdx-rootsignature-version" << GetInputKindName(IK);
642642

643+
if (Args.hasArg(OPT_fdx_rootsignature_define) && !LangOpts.HLSL)
644+
Diags.Report(diag::err_drv_argument_not_allowed_with)
645+
<< "-fdx-rootsignature-define" << GetInputKindName(IK);
646+
643647
if (Args.hasArg(OPT_fgpu_allow_device_init) && !LangOpts.HIP)
644648
Diags.Report(diag::warn_ignored_hip_only_option)
645649
<< Args.getLastArg(OPT_fgpu_allow_device_init)->getAsString(Args);

clang/lib/Frontend/FrontendActions.cpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "clang/Lex/HeaderSearch.h"
2323
#include "clang/Lex/Preprocessor.h"
2424
#include "clang/Lex/PreprocessorOptions.h"
25+
#include "clang/Parse/ParseHLSLRootSignature.h"
2526
#include "clang/Sema/TemplateInstCallback.h"
2627
#include "clang/Serialization/ASTReader.h"
2728
#include "clang/Serialization/ASTWriter.h"
@@ -1241,3 +1242,85 @@ void GetDependenciesByModuleNameAction::ExecuteAction() {
12411242
PPCallbacks *CB = PP.getPPCallbacks();
12421243
CB->moduleImport(SourceLocation(), Path, ModResult);
12431244
}
1245+
1246+
//===----------------------------------------------------------------------===//
1247+
// HLSL Specific Actions
1248+
//===----------------------------------------------------------------------===//
1249+
1250+
class InjectRootSignatureCallback : public PPCallbacks {
1251+
private:
1252+
Sema &Actions;
1253+
StringRef RootSigName;
1254+
llvm::dxbc::RootSignatureVersion Version;
1255+
1256+
std::optional<StringLiteral *> processStringLiteral(ArrayRef<Token> Tokens) {
1257+
for (Token Tok : Tokens)
1258+
if (!tok::isStringLiteral(Tok.getKind()))
1259+
return std::nullopt;
1260+
1261+
ExprResult StringResult = Actions.ActOnUnevaluatedStringLiteral(Tokens);
1262+
if (StringResult.isInvalid())
1263+
return std::nullopt;
1264+
1265+
if (auto Signature = dyn_cast<StringLiteral>(StringResult.get()))
1266+
return Signature;
1267+
1268+
return std::nullopt;
1269+
}
1270+
1271+
public:
1272+
void MacroDefined(const Token &MacroNameTok,
1273+
const MacroDirective *MD) override {
1274+
if (RootSigName != MacroNameTok.getIdentifierInfo()->getName())
1275+
return;
1276+
1277+
const MacroInfo *MI = MD->getMacroInfo();
1278+
auto Signature = processStringLiteral(MI->tokens());
1279+
if (!Signature.has_value()) {
1280+
Actions.getDiagnostics().Report(MI->getDefinitionLoc(),
1281+
diag::err_expected_string_literal)
1282+
<< /*in attributes...*/ 4 << "RootSignature";
1283+
return;
1284+
}
1285+
1286+
IdentifierInfo *DeclIdent =
1287+
hlsl::ParseHLSLRootSignature(Actions, Version, *Signature);
1288+
Actions.HLSL().SetRootSignatureOverride(DeclIdent);
1289+
}
1290+
1291+
InjectRootSignatureCallback(Sema &Actions, StringRef RootSigName,
1292+
llvm::dxbc::RootSignatureVersion Version)
1293+
: PPCallbacks(), Actions(Actions), RootSigName(RootSigName),
1294+
Version(Version) {}
1295+
};
1296+
1297+
void HLSLFrontendAction::ExecuteAction() {
1298+
// Pre-requisites to invoke
1299+
CompilerInstance &CI = getCompilerInstance();
1300+
if (!CI.hasASTContext() || !CI.hasPreprocessor())
1301+
return WrapperFrontendAction::ExecuteAction();
1302+
1303+
// InjectRootSignatureCallback requires access to invoke Sema to lookup/
1304+
// register a root signature declaration. The wrapped action is required to
1305+
// account for this by only creating a Sema if one doesn't already exist
1306+
// (like we have done, and, ASTFrontendAction::ExecuteAction)
1307+
if (!CI.hasSema())
1308+
CI.createSema(getTranslationUnitKind(),
1309+
/*CodeCompleteConsumer=*/nullptr);
1310+
Sema &S = CI.getSema();
1311+
1312+
// Register HLSL specific callbacks
1313+
auto LangOpts = CI.getLangOpts();
1314+
auto MacroCallback = std::make_unique<InjectRootSignatureCallback>(
1315+
S, LangOpts.HLSLRootSigOverride, LangOpts.HLSLRootSigVer);
1316+
1317+
Preprocessor &PP = CI.getPreprocessor();
1318+
PP.addPPCallbacks(std::move(MacroCallback));
1319+
1320+
// Invoke as normal
1321+
WrapperFrontendAction::ExecuteAction();
1322+
}
1323+
1324+
HLSLFrontendAction::HLSLFrontendAction(
1325+
std::unique_ptr<FrontendAction> WrappedAction)
1326+
: WrapperFrontendAction(std::move(WrappedAction)) {}

0 commit comments

Comments
 (0)