clang 22.0.0git
SemaSPIRV.cpp
Go to the documentation of this file.
1//===- SemaSPIRV.cpp - Semantic Analysis for SPIRV constructs--------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This implements Semantic Analysis for SPIRV constructs.
9//===----------------------------------------------------------------------===//
10
#include "clang/Sema/SemaSPIRV.h"
#include "clang/Basic/TargetBuiltins.h"
#include "clang/Basic/TargetInfo.h"
#include "clang/Sema/Sema.h"
15
16// SPIR-V enumerants. Enums have only the required entries, see SPIR-V specs for
17// values.
18// FIXME: either use the SPIRV-Headers or generate a custom header using the
19// grammar (like done with MLIR).
namespace spirv {
/// Subset of the SPIR-V `StorageClass` enumerants needed by the builtin
/// checks in this file. Numeric values must match the SPIR-V specification.
enum class StorageClass : int {
  Workgroup = 4,
  CrossWorkgroup = 5,
  Function = 7
};
} // namespace spirv
27
28namespace clang {
29
31
32static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
33 assert(TheCall->getNumArgs() > 1);
34 QualType ArgTy0 = TheCall->getArg(0)->getType();
35
36 for (unsigned I = 1, N = TheCall->getNumArgs(); I < N; ++I) {
38 ArgTy0, TheCall->getArg(I)->getType())) {
39 S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector)
40 << TheCall->getDirectCallee() << /*useAllTerminology*/ true
41 << SourceRange(TheCall->getArg(0)->getBeginLoc(),
42 TheCall->getArg(N - 1)->getEndLoc());
43 return true;
44 }
45 }
46 return false;
47}
48
50 Sema *S, CallExpr *TheCall,
52 llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
53 Checks) {
54 unsigned NumArgs = TheCall->getNumArgs();
55 assert(Checks.size() == NumArgs &&
56 "Wrong number of checks for Number of args.");
57 // Apply each check to the corresponding argument
58 for (unsigned I = 0; I < NumArgs; ++I) {
59 Expr *Arg = TheCall->getArg(I);
60 if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
61 return true;
62 }
63 return false;
64}
65
67 int ArgOrdinal,
68 clang::QualType PassedType) {
69 clang::QualType BaseType =
70 PassedType->isVectorType()
71 ? PassedType->castAs<clang::VectorType>()->getElementType()
72 : PassedType;
73 if (!BaseType->isHalfType() && !BaseType->isFloat16Type() &&
74 !BaseType->isFloat32Type())
75 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
76 << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
77 << /* half or float */ 2 << PassedType;
78 return false;
79}
80
82 int ArgOrdinal,
83 clang::QualType PassedType) {
84 if (!PassedType->isHalfType() && !PassedType->isFloat16Type() &&
85 !PassedType->isFloat32Type())
86 return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
87 << ArgOrdinal << /* scalar */ 1 << /* no int */ 0
88 << /* half or float */ 2 << PassedType;
89 return false;
90}
91
92static std::optional<int>
94 ExprResult Arg =
95 SemaRef.DefaultFunctionArrayLvalueConversion(Call->getArg(Argument));
96 if (Arg.isInvalid())
97 return true;
98 Call->setArg(Argument, Arg.get());
99
100 const Expr *IntArg = Arg.get();
102 Expr::EvalResult Eval;
103 Eval.Diag = &Notes;
104 if ((!IntArg->EvaluateAsConstantExpr(Eval, SemaRef.getASTContext())) ||
105 !Eval.Val.isInt() || Eval.Val.getInt().getBitWidth() > 32) {
106 SemaRef.Diag(IntArg->getBeginLoc(), diag::err_spirv_enum_not_int)
107 << 0 << IntArg->getSourceRange();
108 for (const PartialDiagnosticAt &PDiag : Notes)
109 SemaRef.Diag(PDiag.first, PDiag.second);
110 return true;
111 }
112 return {Eval.Val.getInt().getZExtValue()};
113}
114
115static bool checkGenericCastToPtr(Sema &SemaRef, CallExpr *Call) {
116 if (SemaRef.checkArgCount(Call, 2))
117 return true;
118
119 {
120 ExprResult Arg =
121 SemaRef.DefaultFunctionArrayLvalueConversion(Call->getArg(0));
122 if (Arg.isInvalid())
123 return true;
124 Call->setArg(0, Arg.get());
125
126 QualType Ty = Arg.get()->getType();
127 const auto *PtrTy = Ty->getAs<PointerType>();
128 auto AddressSpaceNotInGeneric = [&](LangAS AS) {
129 if (SemaRef.LangOpts.OpenCL)
130 return AS != LangAS::opencl_generic;
131 return AS != LangAS::Default;
132 };
133 if (!PtrTy ||
134 AddressSpaceNotInGeneric(PtrTy->getPointeeType().getAddressSpace())) {
135 SemaRef.Diag(Arg.get()->getBeginLoc(),
136 diag::err_spirv_builtin_generic_cast_invalid_arg)
137 << Call->getSourceRange();
138 return true;
139 }
140 }
141
143 if (std::optional<int> SCInt =
145 SCInt.has_value()) {
146 StorageClass = static_cast<spirv::StorageClass>(SCInt.value());
150 SemaRef.Diag(Call->getArg(1)->getBeginLoc(),
151 diag::err_spirv_enum_not_valid)
152 << 0 << Call->getArg(1)->getSourceRange();
153 return true;
154 }
155 } else {
156 return true;
157 }
158 auto RT = Call->getArg(0)->getType();
159 RT = RT->getPointeeType();
160 auto Qual = RT.getQualifiers();
161 LangAS AddrSpace;
162 switch (StorageClass) {
164 AddrSpace =
166 break;
168 AddrSpace =
170 break;
172 AddrSpace = SemaRef.LangOpts.isSYCL() ? LangAS::sycl_private
174 break;
175 }
176 Qual.setAddressSpace(AddrSpace);
177 Call->setType(SemaRef.getASTContext().getPointerType(
178 SemaRef.getASTContext().getQualifiedType(RT.getUnqualifiedType(), Qual)));
179
180 return false;
181}
182
184 unsigned BuiltinID,
185 CallExpr *TheCall) {
186 if (BuiltinID >= SPIRV::FirstVKBuiltin && BuiltinID <= SPIRV::LastVKBuiltin &&
187 TI.getTriple().getArch() != llvm::Triple::spirv) {
188 SemaRef.Diag(TheCall->getBeginLoc(), diag::err_spirv_invalid_target) << 0;
189 return true;
190 }
191 if (BuiltinID >= SPIRV::FirstCLBuiltin && BuiltinID <= SPIRV::LastTSBuiltin &&
192 TI.getTriple().getArch() != llvm::Triple::spirv32 &&
193 TI.getTriple().getArch() != llvm::Triple::spirv64) {
194 SemaRef.Diag(TheCall->getBeginLoc(), diag::err_spirv_invalid_target) << 1;
195 return true;
196 }
197
198 switch (BuiltinID) {
199 case SPIRV::BI__builtin_spirv_distance: {
200 if (SemaRef.checkArgCount(TheCall, 2))
201 return true;
202
203 ExprResult A = TheCall->getArg(0);
204 QualType ArgTyA = A.get()->getType();
205 auto *VTyA = ArgTyA->getAs<VectorType>();
206 if (VTyA == nullptr) {
208 diag::err_typecheck_convert_incompatible)
209 << ArgTyA
211 << 0 << 0;
212 return true;
213 }
214
215 ExprResult B = TheCall->getArg(1);
216 QualType ArgTyB = B.get()->getType();
217 auto *VTyB = ArgTyB->getAs<VectorType>();
218 if (VTyB == nullptr) {
220 diag::err_typecheck_convert_incompatible)
221 << ArgTyB
223 << 0 << 0;
224 return true;
225 }
226
227 QualType RetTy = VTyA->getElementType();
228 TheCall->setType(RetTy);
229 break;
230 }
231 case SPIRV::BI__builtin_spirv_length: {
232 if (SemaRef.checkArgCount(TheCall, 1))
233 return true;
234 ExprResult A = TheCall->getArg(0);
235 QualType ArgTyA = A.get()->getType();
236 auto *VTy = ArgTyA->getAs<VectorType>();
237 if (VTy == nullptr) {
239 diag::err_typecheck_convert_incompatible)
240 << ArgTyA
242 << 0 << 0;
243 return true;
244 }
245 QualType RetTy = VTy->getElementType();
246 TheCall->setType(RetTy);
247 break;
248 }
249 case SPIRV::BI__builtin_spirv_reflect: {
250 if (SemaRef.checkArgCount(TheCall, 2))
251 return true;
252
253 ExprResult A = TheCall->getArg(0);
254 QualType ArgTyA = A.get()->getType();
255 auto *VTyA = ArgTyA->getAs<VectorType>();
256 if (VTyA == nullptr) {
258 diag::err_typecheck_convert_incompatible)
259 << ArgTyA
261 << 0 << 0;
262 return true;
263 }
264
265 ExprResult B = TheCall->getArg(1);
266 QualType ArgTyB = B.get()->getType();
267 auto *VTyB = ArgTyB->getAs<VectorType>();
268 if (VTyB == nullptr) {
270 diag::err_typecheck_convert_incompatible)
271 << ArgTyB
273 << 0 << 0;
274 return true;
275 }
276
277 QualType RetTy = ArgTyA;
278 TheCall->setType(RetTy);
279 break;
280 }
281 case SPIRV::BI__builtin_spirv_refract: {
282 if (SemaRef.checkArgCount(TheCall, 3))
283 return true;
284
285 llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>
286 ChecksArr[] = {CheckFloatOrHalfRepresentation,
290 llvm::ArrayRef(ChecksArr)))
291 return true;
292 // Check that first two arguments are vectors/scalars of the same type
293 QualType Arg0Type = TheCall->getArg(0)->getType();
295 Arg0Type, TheCall->getArg(1)->getType()))
296 return SemaRef.Diag(TheCall->getBeginLoc(),
297 diag::err_vec_builtin_incompatible_vector)
298 << TheCall->getDirectCallee() << /* first two */ 0
299 << SourceRange(TheCall->getArg(0)->getBeginLoc(),
300 TheCall->getArg(1)->getEndLoc());
301
302 // Check that scalar type of 3rd arg is same as base type of first two args
303 clang::QualType BaseType =
304 Arg0Type->isVectorType()
305 ? Arg0Type->castAs<clang::VectorType>()->getElementType()
306 : Arg0Type;
308 BaseType, TheCall->getArg(2)->getType()))
309 return SemaRef.Diag(TheCall->getBeginLoc(),
310 diag::err_hlsl_builtin_scalar_vector_mismatch)
311 << /* all */ 0 << TheCall->getDirectCallee() << Arg0Type
312 << TheCall->getArg(2)->getType();
313
314 QualType RetTy = TheCall->getArg(0)->getType();
315 TheCall->setType(RetTy);
316 break;
317 }
318 case SPIRV::BI__builtin_spirv_smoothstep: {
319 if (SemaRef.checkArgCount(TheCall, 3))
320 return true;
321
322 // Check if first argument has floating representation
323 ExprResult A = TheCall->getArg(0);
324 QualType ArgTyA = A.get()->getType();
325 if (!ArgTyA->hasFloatingRepresentation()) {
326 SemaRef.Diag(A.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
327 << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
328 << /* fp */ 1 << ArgTyA;
329 return true;
330 }
331
332 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
333 return true;
334
335 QualType RetTy = ArgTyA;
336 TheCall->setType(RetTy);
337 break;
338 }
339 case SPIRV::BI__builtin_spirv_faceforward: {
340 if (SemaRef.checkArgCount(TheCall, 3))
341 return true;
342
343 // Check if first argument has floating representation
344 ExprResult A = TheCall->getArg(0);
345 QualType ArgTyA = A.get()->getType();
346 if (!ArgTyA->hasFloatingRepresentation()) {
347 SemaRef.Diag(A.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
348 << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
349 << /* fp */ 1 << ArgTyA;
350 return true;
351 }
352
353 if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
354 return true;
355
356 QualType RetTy = ArgTyA;
357 TheCall->setType(RetTy);
358 break;
359 }
360 case SPIRV::BI__builtin_spirv_generic_cast_to_ptr_explicit: {
361 return checkGenericCastToPtr(SemaRef, TheCall);
362 }
363 }
364 return false;
365}
366} // namespace clang
SourceLocation Loc
Definition: SemaObjC.cpp:754
This file declares semantic analysis for SPIRV constructs.
Enumerates target-specific builtins in their own namespaces within namespace clang.
APSInt & getInt()
Definition: APValue.h:489
bool isInt() const
Definition: APValue.h:467
QualType getVectorType(QualType VectorType, unsigned NumElts, VectorKind VecKind) const
Return the unique reference to a vector type of the specified element type and size.
QualType getPointerType(QualType T) const
Return the uniqued reference to the type for a pointer to the specified type.
QualType getQualifiedType(SplitQualType split) const
Un-split a SplitQualType.
Definition: ASTContext.h:2442
bool hasSameUnqualifiedType(QualType T1, QualType T2) const
Determine whether the given types are equivalent after cvr-qualifiers have been removed.
Definition: ASTContext.h:2898
PtrTy get() const
Definition: Ownership.h:171
bool isInvalid() const
Definition: Ownership.h:167
CallExpr - Represents a function call (C99 6.5.2.2, C++ [expr.call]).
Definition: Expr.h:2879
Expr * getArg(unsigned Arg)
getArg - Return the specified argument.
Definition: Expr.h:3083
SourceLocation getBeginLoc() const
Definition: Expr.h:3213
FunctionDecl * getDirectCallee()
If the callee is a FunctionDecl, return it. Otherwise return null.
Definition: Expr.h:3062
unsigned getNumArgs() const
getNumArgs - Return the number of actual arguments to this call.
Definition: Expr.h:3070
This represents one expression.
Definition: Expr.h:112
void setType(QualType t)
Definition: Expr.h:145
bool EvaluateAsConstantExpr(EvalResult &Result, const ASTContext &Ctx, ConstantExprKind Kind=ConstantExprKind::Normal) const
Evaluate an expression that is required to be a constant expression.
QualType getType() const
Definition: Expr.h:144
bool isSYCL() const
Definition: LangOptions.h:702
PointerType - C99 6.7.5.1 - Pointer Declarators.
Definition: TypeBase.h:3346
A (possibly-)qualified type.
Definition: TypeBase.h:937
SemaDiagnosticBuilder Diag(SourceLocation Loc, unsigned DiagID, bool DeferHint=false)
Emit a diagnostic.
Definition: SemaBase.cpp:61
Sema & SemaRef
Definition: SemaBase.h:40
bool CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID, CallExpr *TheCall)
Definition: SemaSPIRV.cpp:183
SemaSPIRV(Sema &S)
Definition: SemaSPIRV.cpp:30
Sema - This implements semantic analysis and AST building for C.
Definition: Sema.h:850
ASTContext & Context
Definition: Sema.h:1276
ExprResult DefaultFunctionArrayLvalueConversion(Expr *E, bool Diagnose=true)
Definition: SemaExpr.cpp:748
ASTContext & getASTContext() const
Definition: Sema.h:918
const LangOptions & LangOpts
Definition: Sema.h:1274
bool checkArgCount(CallExpr *Call, unsigned DesiredArgCount)
Checks that a call expression's argument count is the desired number.
Encodes a location in the source.
A trivial tuple used to represent a source range.
SourceLocation getEndLoc() const LLVM_READONLY
Definition: Stmt.cpp:358
SourceRange getSourceRange() const LLVM_READONLY
SourceLocation tokens are not useful in isolation - they are low level value objects created/interpre...
Definition: Stmt.cpp:334
SourceLocation getBeginLoc() const LLVM_READONLY
Definition: Stmt.cpp:346
Exposes information about the current target.
Definition: TargetInfo.h:226
const llvm::Triple & getTriple() const
Returns the target triple of the primary target.
Definition: TargetInfo.h:1288
bool isFloat16Type() const
Definition: TypeBase.h:8945
const T * castAs() const
Member-template castAs<specific type>.
Definition: TypeBase.h:9226
bool isFloat32Type() const
Definition: TypeBase.h:8949
bool isHalfType() const
Definition: TypeBase.h:8940
bool hasFloatingRepresentation() const
Determine whether this type has a floating-point representation of some sort, e.g....
Definition: Type.cpp:2316
bool isVectorType() const
Definition: TypeBase.h:8719
const T * getAs() const
Member-template getAs<specific type>'.
Definition: TypeBase.h:9159
Represents a GCC generic vector type.
Definition: TypeBase.h:4191
#define bool
Definition: gpuintrin.h:32
Defines the clang::TargetInfo interface.
The JSON file list parser is used to communicate input to InstallAPI.
static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
Definition: SemaSPIRV.cpp:81
static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc, int ArgOrdinal, clang::QualType PassedType)
Definition: SemaSPIRV.cpp:66
static bool CheckAllArgTypesAreCorrect(Sema *S, CallExpr *TheCall, llvm::ArrayRef< llvm::function_ref< bool(Sema *, SourceLocation, int, QualType)> > Checks)
Definition: SemaSPIRV.cpp:49
StorageClass
Storage classes.
Definition: Specifiers.h:248
static std::optional< int > processConstant32BitIntArgument(Sema &SemaRef, CallExpr *Call, int Argument)
Definition: SemaSPIRV.cpp:93
static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall)
Definition: SemaSPIRV.cpp:32
LangAS
Defines the address space values used by the address space qualifier of QualType.
Definition: AddressSpaces.h:25
static bool checkGenericCastToPtr(Sema &SemaRef, CallExpr *Call)
Definition: SemaSPIRV.cpp:115
std::pair< SourceLocation, PartialDiagnostic > PartialDiagnosticAt
A partial diagnostic along with the source location where this diagnostic occurs.
@ Generic
not a target-specific vector type
StorageClass
Definition: SemaSPIRV.cpp:21
EvalResult is a struct with detailed info about an evaluated expression.
Definition: Expr.h:645
APValue Val
Val - This is the value the expression can be folded to.
Definition: Expr.h:647
SmallVectorImpl< PartialDiagnosticAt > * Diag
Diag - If this is non-null, it will be filled in with a stack of notes indicating why evaluation fail...
Definition: Expr.h:633