Skip to content

Commit 62e228f

Browse files
committed
[Matrix] Add info about number of operations to remarks.
This patch updates the remark to also include a summary of the number of
vector operations generated for each matrix expression.

Reviewers: anemet, Gerolf, thegameg, hfinkel, andrew.w.kaylor, LuoYuanke

Reviewed By: anemet

Differential Revision: https://reviews.llvm.org/D72480
1 parent 3a5acdc commit 62e228f

File tree

2 files changed

+118
-17
lines changed

2 files changed

+118
-17
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 109 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,30 @@ class LowerMatrixIntrinsics {
141141
const TargetTransformInfo &TTI;
142142
OptimizationRemarkEmitter &ORE;
143143

144+
/// Estimates of the number of operations (loads, stores, compute) required
/// to lower a matrix operation. Used to summarize lowering cost in remarks.
struct OpInfoTy {
  /// Number of stores emitted to generate this matrix.
  unsigned NumStores = 0;
  /// Number of loads emitted to generate this matrix.
  unsigned NumLoads = 0;
  /// Number of compute operations emitted to generate this matrix.
  unsigned NumComputeOps = 0;

  /// Accumulate the counts from \p RHS into this instance and return it.
  OpInfoTy &operator+=(const OpInfoTy &RHS) {
    NumLoads += RHS.NumLoads;
    NumComputeOps += RHS.NumComputeOps;
    NumStores += RHS.NumStores;
    return *this;
  }
};
160+
144161
/// Wrapper class representing a matrix as a set of column vectors.
145162
/// All column vectors must have the same vector type.
146163
class ColumnMatrixTy {
147164
SmallVector<Value *, 16> Columns;
148165

166+
OpInfoTy OpInfo;
167+
149168
public:
150169
ColumnMatrixTy() : Columns() {}
151170
ColumnMatrixTy(ArrayRef<Value *> Cols)
@@ -167,6 +186,10 @@ class LowerMatrixIntrinsics {
167186

168187
void addColumn(Value *V) { Columns.push_back(V); }
169188

189+
/// Return the vector type of a single column. All columns share this type
/// (class invariant), so the type of the first column is returned.
VectorType *getColumnTy() {
  // Guard against calling this on a matrix with no columns; Columns[0]
  // would otherwise read out of bounds.
  assert(!Columns.empty() && "Cannot get column type of an empty matrix");
  return cast<VectorType>(Columns[0]->getType());
}
192+
170193
iterator_range<SmallVector<Value *, 8>::iterator> columns() {
171194
return make_range(Columns.begin(), Columns.end());
172195
}
@@ -177,6 +200,29 @@ class LowerMatrixIntrinsics {
177200
return Columns.size() == 1 ? Columns[0]
178201
: concatenateVectors(Builder, Columns);
179202
}
203+
204+
ColumnMatrixTy &addNumLoads(unsigned N) {
205+
OpInfo.NumLoads += N;
206+
return *this;
207+
}
208+
209+
void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
210+
211+
ColumnMatrixTy &addNumStores(unsigned N) {
212+
OpInfo.NumStores += N;
213+
return *this;
214+
}
215+
216+
ColumnMatrixTy &addNumComputeOps(unsigned N) {
217+
OpInfo.NumComputeOps += N;
218+
return *this;
219+
}
220+
221+
unsigned getNumStores() const { return OpInfo.NumStores; }
222+
unsigned getNumLoads() const { return OpInfo.NumLoads; }
223+
unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
224+
225+
const OpInfoTy &getOpInfo() const { return OpInfo; }
180226
};
181227

182228
struct ShapeInfo {
@@ -224,6 +270,20 @@ class LowerMatrixIntrinsics {
224270
OptimizationRemarkEmitter &ORE)
225271
: Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), ORE(ORE) {}
226272

273+
/// Return the estimated number of vector ops required for an operation on
/// \p VT, which must be a vector type.
unsigned getNumOps(Type *VT) {
  assert(isa<VectorType>(VT) && "Expected vector type");
  auto *VecTy = cast<VectorType>(VT);
  // Delegate to the scalar-type overload using the element type and count.
  return getNumOps(VecTy->getScalarType(), VecTy->getNumElements());
}
278+
279+
//
280+
/// Return the estimated number of vector ops required for an operation on
281+
/// \p VT * N.
282+
unsigned getNumOps(Type *ST, unsigned N) {
283+
return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
284+
double(TTI.getRegisterBitWidth(true)));
285+
}
286+
227287
/// Return the set of column vectors that a matrix value is lowered to.
228288
///
229289
/// If we lowered \p MatrixVal, just return the cache result column matrix.
@@ -582,7 +642,10 @@ class LowerMatrixIntrinsics {
582642
Result.addColumn(Column);
583643
}
584644

585-
finalizeLowering(Inst, Result, Builder);
645+
finalizeLowering(Inst,
646+
Result.addNumLoads(getNumOps(Result.getColumnTy()) *
647+
Result.getNumColumns()),
648+
Builder);
586649
}
587650

588651
/// Lowers llvm.matrix.columnwise.load.
@@ -607,7 +670,8 @@ class LowerMatrixIntrinsics {
607670
Shape.NumRows, VType->getElementType(), Builder);
608671
createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
609672
}
610-
Inst2ColumnMatrix[Inst] = ColumnMatrixTy();
673+
Inst2ColumnMatrix[Inst] = ColumnMatrixTy().addNumStores(
674+
getNumOps(LM.getColumnTy()) * LM.getNumColumns());
611675

612676
ToRemove.push_back(Inst);
613677
}
@@ -668,8 +732,9 @@ class LowerMatrixIntrinsics {
668732
}
669733

670734
Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
671-
IRBuilder<> &Builder, bool AllowContraction) {
672-
735+
IRBuilder<> &Builder, bool AllowContraction,
736+
unsigned &NumComputeOps) {
737+
NumComputeOps += getNumOps(A->getType());
673738
if (!Sum)
674739
return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
675740

@@ -681,10 +746,12 @@ class LowerMatrixIntrinsics {
681746
Func.getParent(), Intrinsic::fmuladd, A->getType());
682747
return Builder.CreateCall(FMulAdd, {A, B, Sum});
683748
}
749+
NumComputeOps += getNumOps(A->getType());
684750
Value *Mul = Builder.CreateFMul(A, B);
685751
return Builder.CreateFAdd(Sum, Mul);
686752
}
687753

754+
NumComputeOps += getNumOps(A->getType());
688755
Value *Mul = Builder.CreateMul(A, B);
689756
return Builder.CreateAdd(Sum, Mul);
690757
}
@@ -738,6 +805,7 @@ class LowerMatrixIntrinsics {
738805

739806
bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
740807
MatMul->hasAllowContract());
808+
unsigned NumComputeOps = 0;
741809
// Multiply columns from the first operand with scalars from the second
742810
// operand. Then move along the K axes and accumulate the columns. With
743811
// this the adds can be vectorized without reassociation.
@@ -754,11 +822,12 @@ class LowerMatrixIntrinsics {
754822
Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
755823
Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
756824
Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
757-
Builder, AllowContract);
825+
Builder, AllowContract, NumComputeOps);
758826
}
759827
Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
760828
}
761829
}
830+
Result.addNumComputeOps(NumComputeOps);
762831
finalizeLowering(MatMul, Result, Builder);
763832
}
764833

@@ -788,7 +857,13 @@ class LowerMatrixIntrinsics {
788857
Result.addColumn(ResultColumn);
789858
}
790859

791-
finalizeLowering(Inst, Result, Builder);
860+
// TODO: Improve estimate of operations needed for transposes. Currently we
861+
// just count the insertelement/extractelement instructions, but do not
862+
// account for later simplifications/combines.
863+
finalizeLowering(
864+
Inst,
865+
Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns),
866+
Builder);
792867
}
793868

794869
/// Lower load instructions, if shape information is available.
@@ -850,7 +925,10 @@ class LowerMatrixIntrinsics {
850925
Result.addColumn(
851926
BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C)));
852927

853-
finalizeLowering(Inst, Result, Builder);
928+
finalizeLowering(Inst,
929+
Result.addNumComputeOps(getNumOps(Result.getColumnTy()) *
930+
Result.getNumColumns()),
931+
Builder);
854932
return true;
855933
}
856934

@@ -1116,6 +1194,23 @@ class LowerMatrixIntrinsics {
11161194
return Leaves;
11171195
}
11181196

1197+
/// Calculate the number of exclusive and shared op counts for expression
1198+
/// starting at \p V. Expressions used multiple times are counted once.
1199+
OpInfoTy sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs) {
1200+
auto CM = Inst2ColumnMatrix.find(Root);
1201+
if (CM == Inst2ColumnMatrix.end())
1202+
return {};
1203+
1204+
// Already counted this expression. Stop.
1205+
if (!ReusedExprs.insert(Root).second)
1206+
return {};
1207+
1208+
OpInfoTy Count = CM->second.getOpInfo();
1209+
for (Value *Op : cast<Instruction>(Root)->operand_values())
1210+
Count += sumOpInfos(Op, ReusedExprs);
1211+
return Count;
1212+
}
1213+
11191214
void emitRemarks() {
11201215
if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
11211216
return;
@@ -1125,10 +1220,16 @@ class LowerMatrixIntrinsics {
11251220

11261221
// Generate remarks for each leaf.
11271222
for (auto *L : Leaves) {
1223+
SmallPtrSet<Value *, 8> ReusedExprs;
1224+
auto Counts = sumOpInfos(L, ReusedExprs);
11281225
OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered",
11291226
cast<Instruction>(L)->getDebugLoc(),
11301227
cast<Instruction>(L)->getParent());
1131-
Rem << "Lowered matrix expression ";
1228+
Rem << "Lowered with ";
1229+
Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
1230+
<< ore::NV("NumLoads", Counts.NumLoads) << " loads, "
1231+
<< ore::NV("NumComputeOps", Counts.NumComputeOps) << " compute ops";
1232+
11321233
Rem << ("\n" + linearize(L, DL));
11331234
ORE.emit(Rem);
11341235
}

llvm/test/Transforms/LowerMatrixIntrinsics/remarks.ll

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
44
target triple = "aarch64-apple-ios"
55

6-
; CHECK-LABEL: remark: test.h:40:20: Lowered matrix expression
6+
; CHECK-LABEL: remark: test.h:40:20: Lowered with 6 stores, 6 loads, 24 compute ops
77
; CHECK-NEXT: store(
88
; CHECK-NEXT: transpose.2x6.double(load(addr %A)),
99
; CHECK-NEXT: addr %B)
@@ -17,7 +17,7 @@ define void @transpose(<12 x double>* %A, <12 x double>* %B) !dbg !23 {
1717
declare <12 x double> @llvm.matrix.transpose.v12f64.v12f64(<12 x double>, i32, i32)
1818

1919

20-
; CHECK-LABEL: remark: test.h:50:20: Lowered matrix expression
20+
; CHECK-LABEL: remark: test.h:50:20: Lowered with 2 stores, 12 loads, 22 compute ops
2121
; CHECK-NEXT: store(
2222
; CHECK-NEXT: multiply.2x6.6x2.double(
2323
; CHECK-NEXT: load(addr %A),
@@ -33,7 +33,7 @@ define void @multiply(<12 x double>* %A, <12 x double>* %B, <4 x double>* %C) !d
3333

3434
declare <4 x double> @llvm.matrix.multiply(<12 x double>, <12 x double>, i32, i32, i32)
3535

36-
; CHECK-LABEL: remark: test.h:60:20: Lowered matrix expression
36+
; CHECK-LABEL: remark: test.h:60:20: Lowered with 6 stores, 6 loads, 0 compute ops
3737
; CHECK-NEXT: store(
3838
; CHECK-NEXT: columnwise.load.3x3.double(addr %A, 5),
3939
; CHECK-NEXT: addr %B)
@@ -45,7 +45,7 @@ define void @columnwise.load(<9 x double>* %A, <9 x double>* %B) !dbg !27 {
4545

4646
declare <9 x double> @llvm.matrix.columnwise.load(<9 x double>*, i32, i32, i32)
4747

48-
; CHECK-LABEL: remark: test.h:70:20: Lowered matrix expression
48+
; CHECK-LABEL: remark: test.h:70:20: Lowered with 6 stores, 6 loads, 0 compute ops
4949
; CHECK-NEXT: columnwise.store.3x3.double(
5050
; CHECK-NEXT: columnwise.load.3x3.double(addr %A, 5),
5151
; CHECK-NEXT: addr %B,
@@ -58,7 +58,7 @@ define void @columnwise.store(<9 x double>* %A, <9 x double>* %B) !dbg !29 {
5858

5959
declare void @llvm.matrix.columnwise.store(<9 x double>, <9 x double>*, i32, i32, i32)
6060

61-
; CHECK-LABEL: remark: test.h:80:20: Lowered matrix expression
61+
; CHECK-LABEL: remark: test.h:80:20: Lowered with 6 stores, 6 loads, 12 compute ops
6262
; CHECK-NEXT: columnwise.store.3x3.double(
6363
; CHECK-NEXT: fmul(
6464
; CHECK-NEXT: fadd(
@@ -76,7 +76,7 @@ define void @binaryops(<9 x double>* %A, <9 x double>* %B) !dbg !31 {
7676
ret void
7777
}
7878

79-
; CHECK-LABEL: remark: test.h:90:20: Lowered matrix expression
79+
; CHECK-LABEL: remark: test.h:90:20: Lowered with 6 stores, 6 loads, 12 compute ops
8080
; CHECK-NEXT: columnwise.store.3x3.double(
8181
; CHECK-NEXT: fmul(
8282
; CHECK-NEXT: fadd(
@@ -85,7 +85,7 @@ define void @binaryops(<9 x double>* %A, <9 x double>* %B) !dbg !31 {
8585
; CHECK-NEXT: (reused) columnwise.load.3x3.double(addr %A, 5)),
8686
; CHECK-NEXT: addr %B,
8787
; CHECK-NEXT: 10)
88-
; CHECK-NEXT: remark: test.h:90:20: Lowered matrix expression
88+
; CHECK-NEXT: remark: test.h:90:20: Lowered with 2 stores, 12 loads, 22 compute ops
8989
; CHECK-NEXT: store(
9090
; CHECK-NEXT: multiply.2x6.6x2.double(
9191
; CHECK-NEXT: load(addr %C),
@@ -106,7 +106,7 @@ define void @multiple_expressions(<9 x double>* %A, <9 x double>* %B, <12 x doub
106106
ret void
107107
}
108108

109-
; CHECK-LABEL: remark: test.h:100:20: Lowered matrix expression
109+
; CHECK-LABEL: remark: test.h:100:20: Lowered with 6 stores, 6 loads, 12 compute ops
110110
; CHECK-NEXT: columnwise.store.3x3.double(
111111
; CHECK-NEXT: fmul(
112112
; CHECK-NEXT: fadd(
@@ -124,7 +124,7 @@ define void @stackaddresses(<9 x double>* %A) !dbg !35 {
124124
ret void
125125
}
126126

127-
; CHECK-LABEL: remark: test.h:30:20: Lowered matrix expression
127+
; CHECK-LABEL: remark: test.h:30:20: Lowered with 10 stores, 9 loads, 30 compute ops
128128
; CHECK-NEXT: store(
129129
; CHECK-NEXT: transpose.5x3.double(load(addr %A)),
130130
; CHECK-NEXT: stack addr %s1)

0 commit comments

Comments
 (0)