@@ -141,11 +141,30 @@ class LowerMatrixIntrinsics {
141
141
const TargetTransformInfo &TTI;
142
142
OptimizationRemarkEmitter &ORE;
143
143
144
/// Estimated operation counts (loads, stores and compute operations) needed
/// to lower a matrix operation.
struct OpInfoTy {
  /// Number of stores emitted to generate this matrix.
  unsigned NumStores = 0;
  /// Number of loads emitted to generate this matrix.
  unsigned NumLoads = 0;
  /// Number of compute operations emitted to generate this matrix.
  unsigned NumComputeOps = 0;

  /// Add every counter of \p RHS to the matching counter of this object and
  /// return a reference to this object.
  OpInfoTy &operator+=(const OpInfoTy &RHS) {
    // The three counters are independent of each other.
    NumComputeOps = NumComputeOps + RHS.NumComputeOps;
    NumLoads = NumLoads + RHS.NumLoads;
    NumStores = NumStores + RHS.NumStores;
    return *this;
  }
};
160
+
144
161
/// Wrapper class representing a matrix as a set of column vectors.
145
162
/// All column vectors must have the same vector type.
146
163
class ColumnMatrixTy {
147
164
SmallVector<Value *, 16 > Columns;
148
165
166
+ OpInfoTy OpInfo;
167
+
149
168
public:
150
169
ColumnMatrixTy () : Columns() {}
151
170
ColumnMatrixTy (ArrayRef<Value *> Cols)
@@ -167,6 +186,10 @@ class LowerMatrixIntrinsics {
167
186
168
187
/// Append \p V as the new last column vector of this matrix.
void addColumn(Value *V) {
  Columns.push_back(V);
}
169
188
189
+ VectorType *getColumnTy () {
190
+ return cast<VectorType>(Columns[0 ]->getType ());
191
+ }
192
+
170
193
/// Return a mutable range over all column vectors of this matrix.
///
/// The iterator type is spelled through SmallVectorImpl so that it does not
/// repeat (and cannot disagree with) the inline size of the Columns member.
iterator_range<SmallVectorImpl<Value *>::iterator> columns() {
  return make_range(Columns.begin(), Columns.end());
}
@@ -177,6 +200,29 @@ class LowerMatrixIntrinsics {
177
200
return Columns.size () == 1 ? Columns[0 ]
178
201
: concatenateVectors (Builder, Columns);
179
202
}
203
+
204
+ ColumnMatrixTy &addNumLoads (unsigned N) {
205
+ OpInfo.NumLoads += N;
206
+ return *this ;
207
+ }
208
+
209
+ void setNumLoads (unsigned N) { OpInfo.NumLoads = N; }
210
+
211
+ ColumnMatrixTy &addNumStores (unsigned N) {
212
+ OpInfo.NumStores += N;
213
+ return *this ;
214
+ }
215
+
216
+ ColumnMatrixTy &addNumComputeOps (unsigned N) {
217
+ OpInfo.NumComputeOps += N;
218
+ return *this ;
219
+ }
220
+
221
+ unsigned getNumStores () const { return OpInfo.NumStores ; }
222
+ unsigned getNumLoads () const { return OpInfo.NumLoads ; }
223
+ unsigned getNumComputeOps () const { return OpInfo.NumComputeOps ; }
224
+
225
+ const OpInfoTy &getOpInfo () const { return OpInfo; }
180
226
};
181
227
182
228
struct ShapeInfo {
@@ -224,6 +270,20 @@ class LowerMatrixIntrinsics {
224
270
OptimizationRemarkEmitter &ORE)
225
271
: Func(F), DL(F.getParent()->getDataLayout ()), TTI(TTI), ORE(ORE) {}
226
272
273
+ unsigned getNumOps (Type *VT) {
274
+ assert (isa<VectorType>(VT) && " Expected vector type" );
275
+ return getNumOps (VT->getScalarType (),
276
+ cast<VectorType>(VT)->getNumElements ());
277
+ }
278
+
279
+ //
280
+ // / Return the estimated number of vector ops required for an operation on
281
+ // / \p VT * N.
282
+ unsigned getNumOps (Type *ST, unsigned N) {
283
+ return std::ceil ((ST->getPrimitiveSizeInBits () * N).getFixedSize () /
284
+ double (TTI.getRegisterBitWidth (true )));
285
+ }
286
+
227
287
/// Return the set of column vectors that a matrix value is lowered to.
228
288
///
229
289
/// If we lowered \p MatrixVal, just return the cached result column matrix.
@@ -582,7 +642,10 @@ class LowerMatrixIntrinsics {
582
642
Result.addColumn (Column);
583
643
}
584
644
585
- finalizeLowering (Inst, Result, Builder);
645
+ finalizeLowering (Inst,
646
+ Result.addNumLoads (getNumOps (Result.getColumnTy ()) *
647
+ Result.getNumColumns ()),
648
+ Builder);
586
649
}
587
650
588
651
/// Lowers llvm.matrix.columnwise.load.
@@ -607,7 +670,8 @@ class LowerMatrixIntrinsics {
607
670
Shape.NumRows , VType->getElementType (), Builder);
608
671
createColumnStore (C.value (), GEP, VType->getElementType (), Builder);
609
672
}
610
- Inst2ColumnMatrix[Inst] = ColumnMatrixTy ();
673
+ Inst2ColumnMatrix[Inst] = ColumnMatrixTy ().addNumStores (
674
+ getNumOps (LM.getColumnTy ()) * LM.getNumColumns ());
611
675
612
676
ToRemove.push_back (Inst);
613
677
}
@@ -668,8 +732,9 @@ class LowerMatrixIntrinsics {
668
732
}
669
733
670
734
Value *createMulAdd (Value *Sum, Value *A, Value *B, bool UseFPOp,
671
- IRBuilder<> &Builder, bool AllowContraction) {
672
-
735
+ IRBuilder<> &Builder, bool AllowContraction,
736
+ unsigned &NumComputeOps) {
737
+ NumComputeOps += getNumOps (A->getType ());
673
738
if (!Sum)
674
739
return UseFPOp ? Builder.CreateFMul (A, B) : Builder.CreateMul (A, B);
675
740
@@ -681,10 +746,12 @@ class LowerMatrixIntrinsics {
681
746
Func.getParent (), Intrinsic::fmuladd, A->getType ());
682
747
return Builder.CreateCall (FMulAdd, {A, B, Sum});
683
748
}
749
+ NumComputeOps += getNumOps (A->getType ());
684
750
Value *Mul = Builder.CreateFMul (A, B);
685
751
return Builder.CreateFAdd (Sum, Mul);
686
752
}
687
753
754
+ NumComputeOps += getNumOps (A->getType ());
688
755
Value *Mul = Builder.CreateMul (A, B);
689
756
return Builder.CreateAdd (Sum, Mul);
690
757
}
@@ -738,6 +805,7 @@ class LowerMatrixIntrinsics {
738
805
739
806
bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
740
807
MatMul->hasAllowContract ());
808
+ unsigned NumComputeOps = 0 ;
741
809
// Multiply columns from the first operand with scalars from the second
742
810
// operand. Then move along the K axes and accumulate the columns. With
743
811
// this the adds can be vectorized without reassociation.
@@ -754,11 +822,12 @@ class LowerMatrixIntrinsics {
754
822
Value *RH = Builder.CreateExtractElement (Rhs.getColumn (J), K);
755
823
Value *Splat = Builder.CreateVectorSplat (BlockSize, RH, " splat" );
756
824
Sum = createMulAdd (Sum, L, Splat, EltType->isFloatingPointTy (),
757
- Builder, AllowContract);
825
+ Builder, AllowContract, NumComputeOps );
758
826
}
759
827
Result.setColumn (J, insertVector (Result.getColumn (J), I, Sum, Builder));
760
828
}
761
829
}
830
+ Result.addNumComputeOps (NumComputeOps);
762
831
finalizeLowering (MatMul, Result, Builder);
763
832
}
764
833
@@ -788,7 +857,13 @@ class LowerMatrixIntrinsics {
788
857
Result.addColumn (ResultColumn);
789
858
}
790
859
791
- finalizeLowering (Inst, Result, Builder);
860
+ // TODO: Improve estimate of operations needed for transposes. Currently we
861
+ // just count the insertelement/extractelement instructions, but do not
862
+ // account for later simplifications/combines.
863
+ finalizeLowering (
864
+ Inst,
865
+ Result.addNumComputeOps (2 * ArgShape.NumRows * ArgShape.NumColumns ),
866
+ Builder);
792
867
}
793
868
794
869
/// Lower load instructions, if shape information is available.
@@ -850,7 +925,10 @@ class LowerMatrixIntrinsics {
850
925
Result.addColumn (
851
926
BuildColumnOp (LoweredLhs.getColumn (C), LoweredRhs.getColumn (C)));
852
927
853
- finalizeLowering (Inst, Result, Builder);
928
+ finalizeLowering (Inst,
929
+ Result.addNumComputeOps (getNumOps (Result.getColumnTy ()) *
930
+ Result.getNumColumns ()),
931
+ Builder);
854
932
return true ;
855
933
}
856
934
@@ -1116,6 +1194,23 @@ class LowerMatrixIntrinsics {
1116
1194
return Leaves;
1117
1195
}
1118
1196
1197
+ // / Calculate the number of exclusive and shared op counts for expression
1198
+ // / starting at \p V. Expressions used multiple times are counted once.
1199
+ OpInfoTy sumOpInfos (Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs) {
1200
+ auto CM = Inst2ColumnMatrix.find (Root);
1201
+ if (CM == Inst2ColumnMatrix.end ())
1202
+ return {};
1203
+
1204
+ // Already counted this expression. Stop.
1205
+ if (!ReusedExprs.insert (Root).second )
1206
+ return {};
1207
+
1208
+ OpInfoTy Count = CM->second .getOpInfo ();
1209
+ for (Value *Op : cast<Instruction>(Root)->operand_values ())
1210
+ Count += sumOpInfos (Op, ReusedExprs);
1211
+ return Count;
1212
+ }
1213
+
1119
1214
void emitRemarks () {
1120
1215
if (!ORE.allowExtraAnalysis (DEBUG_TYPE))
1121
1216
return ;
@@ -1125,10 +1220,16 @@ class LowerMatrixIntrinsics {
1125
1220
1126
1221
// Generate remarks for each leaf.
1127
1222
for (auto *L : Leaves) {
1223
+ SmallPtrSet<Value *, 8 > ReusedExprs;
1224
+ auto Counts = sumOpInfos (L, ReusedExprs);
1128
1225
OptimizationRemark Rem (DEBUG_TYPE, " matrix-lowered" ,
1129
1226
cast<Instruction>(L)->getDebugLoc (),
1130
1227
cast<Instruction>(L)->getParent ());
1131
- Rem << " Lowered matrix expression " ;
1228
+ Rem << " Lowered with " ;
1229
+ Rem << ore::NV (" NumStores" , Counts.NumStores ) << " stores, "
1230
+ << ore::NV (" NumLoads" , Counts.NumLoads ) << " loads, "
1231
+ << ore::NV (" NumComputeOps" , Counts.NumComputeOps ) << " compute ops" ;
1232
+
1132
1233
Rem << (" \n " + linearize (L, DL));
1133
1234
ORE.emit (Rem);
1134
1235
}
0 commit comments