Skip to content

Commit 102b410

Browse files
[CMSE] Clear padding bits of struct/unions/fp16 passed by value
When passing a value of a struct/union type from secure to non-secure state (that is, when returning from a CMSE entry function or passing an argument to a CMSE non-secure call), there is a potential leak of sensitive information via the padding bits in the structure. In the general case it is not possible to ensure those bits are cleared using only standard C/C++. This patch makes the compiler emit code to clear such padding bits. Since type information is lost in LLVM IR, the code generation is done by Clang. For each interesting record type, we build a bitmask in which all the bits corresponding to user-declared members are set. Values of record types are returned by coercing them to an integer. After the coercion, the coerced value is masked (with a bitwise AND) and then returned by the function. In a similar manner, values of record types are passed as arguments by coercing them to an array of integers, and the coerced values themselves are masked. For union types, we effectively clear only the bits that aren't part of any member, since we don't know which member is currently active. The compiler will issue a warning whenever a union is passed to non-secure state. Values of half-precision floating-point types are passed in the least significant bits of a 32-bit register (GPR or FPR) with the most significant bits unspecified. Since this is also a potential leak of sensitive information, this patch also clears those unspecified bits. Differential Revision: https://reviews.llvm.org/D76369
1 parent e770153 commit 102b410

File tree

13 files changed

+842
-3
lines changed

13 files changed

+842
-3
lines changed

clang/include/clang/AST/Decl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3961,6 +3961,11 @@ class RecordDecl : public TagDecl {
39613961
return cast_or_null<RecordDecl>(TagDecl::getDefinition());
39623962
}
39633963

3964+
/// Returns whether this record is a union, or contains (at any nesting level)
3965+
/// a union member. This is used by CMSE to warn about possible information
3966+
/// leaks.
3967+
bool isOrContainsUnion() const;
3968+
39643969
// Iterator access to field members. The field iterator only visits
39653970
// the non-static data members of this class, ignoring any static
39663971
// data members, functions, constructors, destructors, etc.

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3123,6 +3123,10 @@ def warn_weak_identifier_undeclared : Warning<
31233123
def warn_attribute_cmse_entry_static : Warning<
31243124
"'cmse_nonsecure_entry' cannot be applied to functions with internal linkage">,
31253125
InGroup<IgnoredAttributes>;
3126+
def warn_cmse_nonsecure_union : Warning<
3127+
"passing union across security boundary via %select{parameter %1|return value}0 "
3128+
"may leak information">,
3129+
InGroup<DiagGroup<"cmse-union-leak">>;
31263130
def err_attribute_weak_static : Error<
31273131
"weak declaration cannot have internal linkage">;
31283132
def err_attribute_selectany_non_extern_data : Error<

clang/lib/AST/Decl.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4410,6 +4410,21 @@ void RecordDecl::setCapturedRecord() {
44104410
addAttr(CapturedRecordAttr::CreateImplicit(getASTContext()));
44114411
}
44124412

4413+
bool RecordDecl::isOrContainsUnion() const {
4414+
if (isUnion())
4415+
return true;
4416+
4417+
if (const RecordDecl *Def = getDefinition()) {
4418+
for (const FieldDecl *FD : Def->fields()) {
4419+
const RecordType *RT = FD->getType()->getAs<RecordType>();
4420+
if (RT && RT->getDecl()->isOrContainsUnion())
4421+
return true;
4422+
}
4423+
}
4424+
4425+
return false;
4426+
}
4427+
44134428
RecordDecl::field_iterator RecordDecl::field_begin() const {
44144429
if (hasExternalLexicalStorage() && !hasLoadedFieldsFromExternalStorage())
44154430
LoadFieldsFromExternalStorage();

clang/lib/CodeGen/CGCall.cpp

Lines changed: 242 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "CGBlocks.h"
1717
#include "CGCXXABI.h"
1818
#include "CGCleanup.h"
19+
#include "CGRecordLayout.h"
1920
#include "CodeGenFunction.h"
2021
#include "CodeGenModule.h"
2122
#include "TargetInfo.h"
@@ -2871,6 +2872,213 @@ static llvm::StoreInst *findDominatingStoreToReturnValue(CodeGenFunction &CGF) {
28712872
return store;
28722873
}
28732874

2875+
// Helper functions for EmitCMSEClearRecord
2876+
2877+
// Set the bits corresponding to a field having width `BitWidth` and located at
2878+
// offset `BitOffset` (from the least significant bit) within a storage unit of
2879+
// `Bits.size()` bytes. Each element of `Bits` corresponds to one target byte.
2880+
// Use little-endian layout, i.e.`Bits[0]` is the LSB.
2881+
static void setBitRange(SmallVectorImpl<uint64_t> &Bits, int BitOffset,
2882+
int BitWidth, int CharWidth) {
2883+
assert(CharWidth <= 64);
2884+
assert(static_cast<unsigned>(BitWidth) <= Bits.size() * CharWidth);
2885+
2886+
int Pos = 0;
2887+
if (BitOffset >= CharWidth) {
2888+
Pos += BitOffset / CharWidth;
2889+
BitOffset = BitOffset % CharWidth;
2890+
}
2891+
2892+
const uint64_t Used = (uint64_t(1) << CharWidth) - 1;
2893+
if (BitOffset + BitWidth >= CharWidth) {
2894+
Bits[Pos++] |= (Used << BitOffset) & Used;
2895+
BitWidth -= CharWidth - BitOffset;
2896+
BitOffset = 0;
2897+
}
2898+
2899+
while (BitWidth >= CharWidth) {
2900+
Bits[Pos++] = Used;
2901+
BitWidth -= CharWidth;
2902+
}
2903+
2904+
if (BitWidth > 0)
2905+
Bits[Pos++] |= (Used >> (CharWidth - BitWidth)) << BitOffset;
2906+
}
2907+
2908+
// Set the bits corresponding to a field having width `BitWidth` and located at
2909+
// offset `BitOffset` (from the least significant bit) within a storage unit of
2910+
// `StorageSize` bytes, located at `StorageOffset` in `Bits`. Each element of
2911+
// `Bits` corresponds to one target byte. Use target endian layout.
2912+
static void setBitRange(SmallVectorImpl<uint64_t> &Bits, int StorageOffset,
2913+
int StorageSize, int BitOffset, int BitWidth,
2914+
int CharWidth, bool BigEndian) {
2915+
2916+
SmallVector<uint64_t, 8> TmpBits(StorageSize);
2917+
setBitRange(TmpBits, BitOffset, BitWidth, CharWidth);
2918+
2919+
if (BigEndian)
2920+
std::reverse(TmpBits.begin(), TmpBits.end());
2921+
2922+
for (uint64_t V : TmpBits)
2923+
Bits[StorageOffset++] |= V;
2924+
}
2925+
2926+
static void setUsedBits(CodeGenModule &, QualType, int,
2927+
SmallVectorImpl<uint64_t> &);
2928+
2929+
// Set the bits in `Bits`, which correspond to the value representations of
2930+
// the actual members of the record type `RTy`. Note that this function does
2931+
// not handle base classes, virtual tables, etc, since they cannot happen in
2932+
// CMSE function arguments or return. The bit mask corresponds to the target
2933+
// memory layout, i.e. it's endian dependent.
2934+
static void setUsedBits(CodeGenModule &CGM, const RecordType *RTy, int Offset,
2935+
SmallVectorImpl<uint64_t> &Bits) {
2936+
ASTContext &Context = CGM.getContext();
2937+
int CharWidth = Context.getCharWidth();
2938+
const RecordDecl *RD = RTy->getDecl()->getDefinition();
2939+
const ASTRecordLayout &ASTLayout = Context.getASTRecordLayout(RD);
2940+
const CGRecordLayout &Layout = CGM.getTypes().getCGRecordLayout(RD);
2941+
2942+
int Idx = 0;
2943+
for (auto I = RD->field_begin(), E = RD->field_end(); I != E; ++I, ++Idx) {
2944+
const FieldDecl *F = *I;
2945+
2946+
if (F->isUnnamedBitfield() || F->isZeroLengthBitField(Context) ||
2947+
F->getType()->isIncompleteArrayType())
2948+
continue;
2949+
2950+
if (F->isBitField()) {
2951+
const CGBitFieldInfo &BFI = Layout.getBitFieldInfo(F);
2952+
setBitRange(Bits, Offset + BFI.StorageOffset.getQuantity(),
2953+
BFI.StorageSize / CharWidth, BFI.Offset,
2954+
BFI.Size, CharWidth,
2955+
CGM.getDataLayout().isBigEndian());
2956+
continue;
2957+
}
2958+
2959+
setUsedBits(CGM, F->getType(),
2960+
Offset + ASTLayout.getFieldOffset(Idx) / CharWidth, Bits);
2961+
}
2962+
}
2963+
2964+
// Set the bits in `Bits`, which correspond to the value representations of
2965+
// the elements of an array type `ATy`.
2966+
static void setUsedBits(CodeGenModule &CGM, const ConstantArrayType *ATy,
2967+
int Offset, SmallVectorImpl<uint64_t> &Bits) {
2968+
const ASTContext &Context = CGM.getContext();
2969+
2970+
QualType ETy = Context.getBaseElementType(ATy);
2971+
int Size = Context.getTypeSizeInChars(ETy).getQuantity();
2972+
SmallVector<uint64_t, 4> TmpBits(Size);
2973+
setUsedBits(CGM, ETy, 0, TmpBits);
2974+
2975+
for (int I = 0, N = Context.getConstantArrayElementCount(ATy); I < N; ++I) {
2976+
auto Src = TmpBits.begin();
2977+
auto Dst = Bits.begin() + Offset + I * Size;
2978+
for (int J = 0; J < Size; ++J)
2979+
*Dst++ |= *Src++;
2980+
}
2981+
}
2982+
2983+
// Set the bits in `Bits`, which correspond to the value representations of
2984+
// the type `QTy`.
2985+
static void setUsedBits(CodeGenModule &CGM, QualType QTy, int Offset,
2986+
SmallVectorImpl<uint64_t> &Bits) {
2987+
if (const auto *RTy = QTy->getAs<RecordType>())
2988+
return setUsedBits(CGM, RTy, Offset, Bits);
2989+
2990+
ASTContext &Context = CGM.getContext();
2991+
if (const auto *ATy = Context.getAsConstantArrayType(QTy))
2992+
return setUsedBits(CGM, ATy, Offset, Bits);
2993+
2994+
int Size = Context.getTypeSizeInChars(QTy).getQuantity();
2995+
if (Size <= 0)
2996+
return;
2997+
2998+
std::fill_n(Bits.begin() + Offset, Size,
2999+
(uint64_t(1) << Context.getCharWidth()) - 1);
3000+
}
3001+
3002+
static uint64_t buildMultiCharMask(const SmallVectorImpl<uint64_t> &Bits,
3003+
int Pos, int Size, int CharWidth,
3004+
bool BigEndian) {
3005+
assert(Size > 0);
3006+
uint64_t Mask = 0;
3007+
if (BigEndian) {
3008+
for (auto P = Bits.begin() + Pos, E = Bits.begin() + Pos + Size; P != E;
3009+
++P)
3010+
Mask = (Mask << CharWidth) | *P;
3011+
} else {
3012+
auto P = Bits.begin() + Pos + Size, End = Bits.begin() + Pos;
3013+
do
3014+
Mask = (Mask << CharWidth) | *--P;
3015+
while (P != End);
3016+
}
3017+
return Mask;
3018+
}
3019+
3020+
// Emit code to clear the bits in a record, which aren't a part of any user
3021+
// declared member, when the record is a function return.
3022+
llvm::Value *CodeGenFunction::EmitCMSEClearRecord(llvm::Value *Src,
3023+
llvm::IntegerType *ITy,
3024+
QualType QTy) {
3025+
assert(Src->getType() == ITy);
3026+
assert(ITy->getScalarSizeInBits() <= 64);
3027+
3028+
const llvm::DataLayout &DataLayout = CGM.getDataLayout();
3029+
int Size = DataLayout.getTypeStoreSize(ITy);
3030+
SmallVector<uint64_t, 4> Bits(Size);
3031+
setUsedBits(CGM, QTy->getAs<RecordType>(), 0, Bits);
3032+
3033+
int CharWidth = CGM.getContext().getCharWidth();
3034+
uint64_t Mask =
3035+
buildMultiCharMask(Bits, 0, Size, CharWidth, DataLayout.isBigEndian());
3036+
3037+
return Builder.CreateAnd(Src, Mask, "cmse.clear");
3038+
}
3039+
3040+
// Emit code to clear the bits in a record, which aren't a part of any user
3041+
// declared member, when the record is a function argument.
3042+
llvm::Value *CodeGenFunction::EmitCMSEClearRecord(llvm::Value *Src,
3043+
llvm::ArrayType *ATy,
3044+
QualType QTy) {
3045+
const llvm::DataLayout &DataLayout = CGM.getDataLayout();
3046+
int Size = DataLayout.getTypeStoreSize(ATy);
3047+
SmallVector<uint64_t, 16> Bits(Size);
3048+
setUsedBits(CGM, QTy->getAs<RecordType>(), 0, Bits);
3049+
3050+
// Clear each element of the LLVM array.
3051+
int CharWidth = CGM.getContext().getCharWidth();
3052+
int CharsPerElt =
3053+
ATy->getArrayElementType()->getScalarSizeInBits() / CharWidth;
3054+
int MaskIndex = 0;
3055+
llvm::Value *R = llvm::UndefValue::get(ATy);
3056+
for (int I = 0, N = ATy->getArrayNumElements(); I != N; ++I) {
3057+
uint64_t Mask = buildMultiCharMask(Bits, MaskIndex, CharsPerElt, CharWidth,
3058+
DataLayout.isBigEndian());
3059+
MaskIndex += CharsPerElt;
3060+
llvm::Value *T0 = Builder.CreateExtractValue(Src, I);
3061+
llvm::Value *T1 = Builder.CreateAnd(T0, Mask, "cmse.clear");
3062+
R = Builder.CreateInsertValue(R, T1, I);
3063+
}
3064+
3065+
return R;
3066+
}
3067+
3068+
// Emit code to clear the padding bits when returning or passing as an argument
3069+
// a 16-bit floating-point value.
3070+
llvm::Value *CodeGenFunction::EmitCMSEClearFP16(llvm::Value *Src) {
3071+
llvm::Type *RetTy = Src->getType();
3072+
assert(RetTy->isFloatTy() ||
3073+
RetTy->isIntegerTy() && RetTy->getIntegerBitWidth() == 32);
3074+
if (RetTy->isFloatTy()) {
3075+
llvm::Value *T0 = Builder.CreateBitCast(Src, Builder.getIntNTy(32));
3076+
llvm::Value *T1 = Builder.CreateAnd(T0, 0xffff, "cmse.clear");
3077+
return Builder.CreateBitCast(T1, RetTy);
3078+
}
3079+
return Builder.CreateAnd(Src, 0xffff, "cmse.clear");
3080+
}
3081+
28743082
void CodeGenFunction::EmitFunctionEpilog(const CGFunctionInfo &FI,
28753083
bool EmitRetDbgLoc,
28763084
SourceLocation EndLoc) {
@@ -3037,6 +3245,21 @@ void CodeGenFunction::EmitFunctionEpilog(const CGFunctionInfo &FI,
30373245

30383246
llvm::Instruction *Ret;
30393247
if (RV) {
3248+
if (CurFuncDecl && CurFuncDecl->hasAttr<CmseNSEntryAttr>()) {
3249+
// For certain return types, clear padding bits, as they may reveal
3250+
// sensitive information.
3251+
const Type *RTy = RetTy.getCanonicalType().getTypePtr();
3252+
if (RTy->isFloat16Type() || RTy->isHalfType()) {
3253+
// 16-bit floating-point types are passed in a 32-bit integer or float,
3254+
// with unspecified upper bits.
3255+
RV = EmitCMSEClearFP16(RV);
3256+
} else {
3257+
// Small struct/union types are passed as integers.
3258+
auto *ITy = dyn_cast<llvm::IntegerType>(RV->getType());
3259+
if (ITy != nullptr && isa<RecordType>(RetTy.getCanonicalType()))
3260+
RV = EmitCMSEClearRecord(RV, ITy, RetTy);
3261+
}
3262+
}
30403263
EmitReturnValueCheck(RV);
30413264
Ret = Builder.CreateRet(RV);
30423265
} else {
@@ -4332,8 +4555,25 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
43324555
} else {
43334556
// In the simple case, just pass the coerced loaded value.
43344557
assert(NumIRArgs == 1);
4335-
IRCallArgs[FirstIRArg] =
4336-
CreateCoercedLoad(Src, ArgInfo.getCoerceToType(), *this);
4558+
llvm::Value *Load =
4559+
CreateCoercedLoad(Src, ArgInfo.getCoerceToType(), *this);
4560+
4561+
if (CallInfo.isCmseNSCall()) {
4562+
// For certain parameter types, clear padding bits, as they may reveal
4563+
// sensitive information.
4564+
const Type *PTy = I->Ty.getCanonicalType().getTypePtr();
4565+
// 16-bit floating-point types are passed in a 32-bit integer or
4566+
// float, with unspecified upper bits.
4567+
if (PTy->isFloat16Type() || PTy->isHalfType()) {
4568+
Load = EmitCMSEClearFP16(Load);
4569+
} else {
4570+
// Small struct/union types are passed as integer arrays.
4571+
auto *ATy = dyn_cast<llvm::ArrayType>(Load->getType());
4572+
if (ATy != nullptr && isa<RecordType>(I->Ty.getCanonicalType()))
4573+
Load = EmitCMSEClearRecord(Load, ATy, I->Ty);
4574+
}
4575+
}
4576+
IRCallArgs[FirstIRArg] = Load;
43374577
}
43384578

43394579
break;

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3877,6 +3877,11 @@ class CodeGenFunction : public CodeGenTypeCache {
38773877
llvm::Value *EmitARMCDEBuiltinExpr(unsigned BuiltinID, const CallExpr *E,
38783878
ReturnValueSlot ReturnValue,
38793879
llvm::Triple::ArchType Arch);
3880+
llvm::Value *EmitCMSEClearRecord(llvm::Value *V, llvm::IntegerType *ITy,
3881+
QualType RTy);
3882+
llvm::Value *EmitCMSEClearRecord(llvm::Value *V, llvm::ArrayType *ATy,
3883+
QualType RTy);
3884+
llvm::Value *EmitCMSEClearFP16(llvm::Value *V);
38803885

38813886
llvm::Value *EmitCommonNeonBuiltinExpr(unsigned BuiltinID,
38823887
unsigned LLVMIntrinsic,

clang/lib/Sema/SemaDeclAttr.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1998,7 +1998,8 @@ static void handleCmseNSEntryAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
19981998
return;
19991999
}
20002000

2001-
if (cast<FunctionDecl>(D)->getStorageClass() == SC_Static) {
2001+
const auto *FD = cast<FunctionDecl>(D);
2002+
if (!FD->isExternallyVisible()) {
20022003
S.Diag(AL.getLoc(), diag::warn_attribute_cmse_entry_static);
20032004
return;
20042005
}

clang/lib/Sema/SemaExpr.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6583,6 +6583,18 @@ ExprResult Sema::BuildResolvedCallExpr(Expr *Fn, NamedDecl *NDecl,
65836583
if (NDecl)
65846584
DiagnoseSentinelCalls(NDecl, LParenLoc, Args);
65856585

6586+
// Warn for unions passing across security boundary (CMSE).
6587+
if (FuncT != nullptr && FuncT->getCmseNSCallAttr()) {
6588+
for (unsigned i = 0, e = Args.size(); i != e; i++) {
6589+
if (const auto *RT =
6590+
dyn_cast<RecordType>(Args[i]->getType().getCanonicalType())) {
6591+
if (RT->getDecl()->isOrContainsUnion())
6592+
Diag(Args[i]->getBeginLoc(), diag::warn_cmse_nonsecure_union)
6593+
<< 0 << i;
6594+
}
6595+
}
6596+
}
6597+
65866598
// Do special checking on direct calls to functions.
65876599
if (FDecl) {
65886600
if (CheckFunctionCall(FDecl, TheCall, Proto))

clang/lib/Sema/SemaStmt.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3631,6 +3631,12 @@ StmtResult Sema::BuildReturnStmt(SourceLocation ReturnLoc, Expr *RetValExp) {
36313631
if (isa<CXXBoolLiteralExpr>(RetValExp))
36323632
Diag(ReturnLoc, diag::warn_main_returns_bool_literal)
36333633
<< RetValExp->getSourceRange();
3634+
if (FD->hasAttr<CmseNSEntryAttr>() && RetValExp) {
3635+
if (const auto *RT = dyn_cast<RecordType>(FnRetType.getCanonicalType())) {
3636+
if (RT->getDecl()->isOrContainsUnion())
3637+
Diag(RetValExp->getBeginLoc(), diag::warn_cmse_nonsecure_union) << 1;
3638+
}
3639+
}
36343640
} else if (ObjCMethodDecl *MD = getCurMethodDecl()) {
36353641
FnRetType = MD->getReturnType();
36363642
isObjCMethod = true;

0 commit comments

Comments
 (0)