10
10
//
11
11
// TODO:
12
12
// * Implement multiply & add fusion
13
- // * Add remark, summarizing the available matrix optimization opportunities.
13
+ // * Add remark, summarizing the available matrix optimization opportunities
14
+ // (WIP).
14
15
//
15
16
// ===----------------------------------------------------------------------===//
16
17
17
18
#include " llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
18
19
#include " llvm/ADT/GraphTraits.h"
19
20
#include " llvm/ADT/PostOrderIterator.h"
20
21
#include " llvm/ADT/SmallVector.h"
22
+ #include " llvm/Analysis/OptimizationRemarkEmitter.h"
21
23
#include " llvm/Analysis/TargetTransformInfo.h"
24
+ #include " llvm/Analysis/ValueTracking.h"
22
25
#include " llvm/Analysis/VectorUtils.h"
23
26
#include " llvm/IR/CFG.h"
24
27
#include " llvm/IR/DataLayout.h"
@@ -136,6 +139,7 @@ class LowerMatrixIntrinsics {
136
139
Function &Func;
137
140
const DataLayout &DL;
138
141
const TargetTransformInfo &TTI;
142
+ OptimizationRemarkEmitter &ORE;
139
143
140
144
// / Wrapper class representing a matrix as a set of column vectors.
141
145
// / All column vectors must have the same vector type.
@@ -213,11 +217,12 @@ class LowerMatrixIntrinsics {
213
217
SmallVector<Instruction *, 16 > ToRemove;
214
218
215
219
// / Map from instructions to their produced column matrix.
216
- DenseMap <Value *, ColumnMatrixTy> Inst2ColumnMatrix;
220
+ MapVector <Value *, ColumnMatrixTy> Inst2ColumnMatrix;
217
221
218
222
public:
219
- LowerMatrixIntrinsics (Function &F, TargetTransformInfo &TTI)
220
- : Func(F), DL(F.getParent()->getDataLayout ()), TTI(TTI) {}
223
/// Construct the lowering helper for \p F. Stores references to \p F, to the
/// data layout obtained from F's parent via getDataLayout(), to \p TTI and to
/// \p ORE in the corresponding members; no other work is done here.
+ LowerMatrixIntrinsics (Function &F, TargetTransformInfo &TTI,
224
+ OptimizationRemarkEmitter &ORE)
225
+ : Func(F), DL(F.getParent()->getDataLayout ()), TTI(TTI), ORE(ORE) {}
221
226
222
227
// / Return the set of column vectors that a matrix value is lowered to.
223
228
// /
@@ -509,6 +514,9 @@ class LowerMatrixIntrinsics {
509
514
}
510
515
}
511
516
517
+ RemarkGenerator RemarkGen (Inst2ColumnMatrix, ORE, DL);
518
+ RemarkGen.emitRemarks ();
519
+
512
520
for (Instruction *Inst : reverse (ToRemove))
513
521
Inst->eraseFromParent ();
514
522
@@ -599,6 +607,7 @@ class LowerMatrixIntrinsics {
599
607
Shape.NumRows , VType->getElementType (), Builder);
600
608
createColumnStore (C.value (), GEP, VType->getElementType (), Builder);
601
609
}
610
+ Inst2ColumnMatrix[Inst] = ColumnMatrixTy ();
602
611
603
612
ToRemove.push_back (Inst);
604
613
}
@@ -844,13 +853,301 @@ class LowerMatrixIntrinsics {
844
853
finalizeLowering (Inst, Result, Builder);
845
854
return true ;
846
855
}
856
+
857
+ // / Helper to linearize a matrix expression tree into a string. Currently
858
+ // / matrix expressions are linarized by starting at an expression leaf and
859
+ // / linearizing bottom up.
860
+ struct ExprLinearizer {
861
+ unsigned LengthToBreak = 100 ;
862
+ std::string Str;
863
+ raw_string_ostream Stream;
864
+ unsigned LineLength = 0 ;
865
+ const DataLayout &DL;
866
+
867
+ // / Mapping from instructions to column matrixes. It is used to identify
868
+ // / matrix instructions.
869
+ const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix;
870
+
871
+ // / Used to keep track of sub-expressions that get reused while linearizing
872
+ // / the expression. Re-used sub-expressions are marked as (reused).
873
+ SmallPtrSet<Value *, 8 > ReusedExprs;
874
+
875
+ ExprLinearizer (const DataLayout &DL,
876
+ const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix)
877
+ : Str(), Stream(Str), DL(DL), Inst2ColumnMatrix(Inst2ColumnMatrix) {}
878
+
879
+ void indent (unsigned N) {
880
+ LineLength += N;
881
+ for (unsigned i = 0 ; i < N; i++)
882
+ Stream << " " ;
883
+ }
884
+
885
+ void lineBreak () {
886
+ Stream << " \n " ;
887
+ LineLength = 0 ;
888
+ }
889
+
890
+ void maybeIndent (unsigned Indent) {
891
+ if (LineLength >= LengthToBreak)
892
+ lineBreak ();
893
+
894
+ if (LineLength == 0 )
895
+ indent (Indent);
896
+ }
897
+
898
+ void write (const std::string &S) {
899
+ LineLength += S.size ();
900
+ Stream << S;
901
+ }
902
+
903
+ Value *getUnderlyingObjectThroughLoads (Value *V) {
904
+ if (Value *Ptr = getPointerOperand (V))
905
+ return getUnderlyingObjectThroughLoads (Ptr);
906
+ else if (V->getType ()->isPointerTy ())
907
+ return GetUnderlyingObject (V, DL);
908
+ return V;
909
+ }
910
+
911
+ // / Returns true if \p V is a matrix value.
912
+ bool isMatrix (Value *V) const {
913
+ return Inst2ColumnMatrix.find (V) != Inst2ColumnMatrix.end ();
914
+ }
915
+
916
+ // / If \p V is a matrix value, print its shape as as NumRows x NumColumns to
917
+ // / \p SS.
918
+ void prettyPrintMatrixType (Value *V, raw_string_ostream &SS) {
919
+ auto M = Inst2ColumnMatrix.find (V);
920
+ if (M == Inst2ColumnMatrix.end ())
921
+ SS << " unknown" ;
922
+ else {
923
+ SS << M->second .getNumRows ();
924
+ SS << " x" ;
925
+ SS << M->second .getNumColumns ();
926
+ }
927
+ }
928
+
929
+ // / Write the called function name. Handles calls to llvm.matrix.*
930
+ // / specially: we write the name, followed by the dimensions of the input
931
+ // / matrixes, followed by the scalar type name.
932
+ void writeFnName (CallInst *CI) {
933
+ if (!CI->getCalledFunction ())
934
+ write (" <no called fn>" );
935
+ else {
936
+ StringRef Name = CI->getCalledFunction ()->getName ();
937
+ if (!Name.startswith (" llvm.matrix" )) {
938
+ write (Name);
939
+ return ;
940
+ }
941
+ IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
942
+ write (StringRef (Intrinsic::getName (II->getIntrinsicID (), {}))
943
+ .drop_front (StringRef (" llvm.matrix." ).size ()));
944
+ write (" ." );
945
+ std::string Tmp = " " ;
946
+ raw_string_ostream SS (Tmp);
947
+
948
+ switch (II->getIntrinsicID ()) {
949
+ case Intrinsic::matrix_multiply:
950
+ prettyPrintMatrixType (II->getOperand (0 ), SS);
951
+ SS << " ." ;
952
+ prettyPrintMatrixType (II->getOperand (1 ), SS);
953
+ SS << " ." << *II->getType ()->getScalarType ();
954
+ break ;
955
+ case Intrinsic::matrix_transpose:
956
+ prettyPrintMatrixType (II->getOperand (0 ), SS);
957
+ SS << " ." << *II->getType ()->getScalarType ();
958
+ break ;
959
+ case Intrinsic::matrix_columnwise_load:
960
+ prettyPrintMatrixType (II, SS);
961
+ SS << " ." << *II->getType ()->getScalarType ();
962
+ break ;
963
+ case Intrinsic::matrix_columnwise_store:
964
+ prettyPrintMatrixType (II->getOperand (0 ), SS);
965
+ SS << " ." << *II->getOperand (0 )->getType ()->getScalarType ();
966
+ break ;
967
+ default :
968
+ llvm_unreachable (" Unhandled case" );
969
+ }
970
+ SS.flush ();
971
+ write (Tmp);
972
+ }
973
+ }
974
+
975
+ unsigned getNumShapeArgs (CallInst *CI) const {
976
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
977
+ switch (II->getIntrinsicID ()) {
978
+ case Intrinsic::matrix_multiply:
979
+ return 3 ;
980
+ case Intrinsic::matrix_transpose:
981
+ case Intrinsic::matrix_columnwise_load:
982
+ case Intrinsic::matrix_columnwise_store:
983
+ return 2 ;
984
+ default :
985
+ return 0 ;
986
+ }
987
+ }
988
+ return 0 ;
989
+ }
990
+
991
+ // / Special printing for values: for pointers, we print if they refer to an
992
+ // / (function) external address or a stack address, for other values we
993
+ // / either print the constant or "scalar"/"matrix" for other values.
994
+ void write (Value *V) {
995
+ V = getUnderlyingObjectThroughLoads (V);
996
+ if (V->getType ()->isPointerTy ()) {
997
+ if (isa<AllocaInst>(V)) {
998
+ Stream << " stack addr" ;
999
+ LineLength += StringRef (" stack addr" ).size ();
1000
+ } else {
1001
+ Stream << " addr" ;
1002
+ LineLength += StringRef (" addr" ).size ();
1003
+ }
1004
+ if (!V->getName ().empty ()) {
1005
+ Stream << " %" << V->getName () << " " ;
1006
+ LineLength += V->getName ().size () + 2 ;
1007
+ }
1008
+ return ;
1009
+ }
1010
+
1011
+ std::string Tmp;
1012
+ raw_string_ostream TmpStream (Tmp);
1013
+
1014
+ if (auto *CI = dyn_cast<ConstantInt>(V))
1015
+ TmpStream << CI->getValue ();
1016
+ else if (isa<Constant>(V))
1017
+ TmpStream << " constant" ;
1018
+ else {
1019
+ if (isMatrix (V))
1020
+ TmpStream << " matrix" ;
1021
+ else
1022
+ TmpStream << " scalar" ;
1023
+ }
1024
+ TmpStream.flush ();
1025
+ Tmp = StringRef (Tmp).trim ();
1026
+ LineLength += Tmp.size ();
1027
+ Stream << Tmp;
1028
+ }
1029
+
1030
+ // / Linearize expression \p Expr starting at an indentation of \p Indent.
1031
+ // / Expressions that are re-used multiple times are prefixed with (reused)
1032
+ // / at the re-used root instruction.
1033
+ void linearizeExpr (Value *Expr, unsigned Indent, bool ParentReused) {
1034
+ auto *I = cast<Instruction>(Expr);
1035
+ maybeIndent (Indent);
1036
+ SmallVector<Value *, 8 > Ops;
1037
+
1038
+ bool Reused = !ReusedExprs.insert (Expr).second ;
1039
+ if (Reused && !ParentReused)
1040
+ write (" (reused) " );
1041
+
1042
+ if (auto *CI = dyn_cast<CallInst>(I)) {
1043
+ writeFnName (CI);
1044
+
1045
+ Ops.append (CallSite (CI).arg_begin (),
1046
+ CallSite (CI).arg_end () - getNumShapeArgs (CI));
1047
+ } else if (isa<BitCastInst>(Expr)) {
1048
+ // Special case bitcasts, which are used to materialize matrixes from
1049
+ // non-matrix ops.
1050
+ write (" matrix" );
1051
+ return ;
1052
+ } else {
1053
+ Ops.append (I->value_op_begin (), I->value_op_end ());
1054
+ write (std::string (I->getOpcodeName ()));
1055
+ }
1056
+
1057
+ write (std::string (" (" ));
1058
+
1059
+ unsigned NumOpsToBreak = 1 ;
1060
+ if (match (Expr, m_Intrinsic<Intrinsic::matrix_columnwise_load>()))
1061
+ NumOpsToBreak = 2 ;
1062
+
1063
+ for (Value *Op : Ops) {
1064
+ if (Ops.size () > NumOpsToBreak)
1065
+ lineBreak ();
1066
+
1067
+ maybeIndent (Indent + 1 );
1068
+ if (isMatrix (Op))
1069
+ linearizeExpr (Op, Indent + 1 , Reused);
1070
+ else
1071
+ write (Op);
1072
+ if (Op != Ops.back ())
1073
+ write (" , " );
1074
+ }
1075
+
1076
+ write (" )" );
1077
+ }
1078
+
1079
+ const std::string &getResult () {
1080
+ Stream.flush ();
1081
+ return Str;
1082
+ }
1083
+ };
1084
+
1085
+ // / Generate remarks for matrix operations in a function. To generate remarks
1086
+ // / for matrix expressions, the following approach is used:
1087
+ // / 1. Collect leafs of matrix expressions (done in
1088
+ // / RemarkGenerator::getExpressionLeaves). Leaves are lowered matrix
1089
+ // / instructions without other matrix users (like stores).
1090
+ // /
1091
+ // / 2. For each leaf, create a remark containing a linearizied version of the
1092
+ // / matrix expression.
1093
+ // /
1094
+ // / TODO:
1095
+ // / * Summarize number of vector instructions generated for each expression.
1096
+ // / * Account for shared sub-expressions.
1097
+ // / * Propagate matrix remarks up the inlining chain.
1098
+ struct RemarkGenerator {
1099
+ const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix;
1100
+ OptimizationRemarkEmitter &ORE;
1101
+ const DataLayout &DL;
1102
+
1103
+ RemarkGenerator (const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix,
1104
+ OptimizationRemarkEmitter &ORE, const DataLayout &DL)
1105
+ : Inst2ColumnMatrix(Inst2ColumnMatrix), ORE(ORE), DL(DL) {}
1106
+
1107
+ // / Return all leafs of matrix expressions. Those are instructions in
1108
+ // / Inst2ColumnMatrix returing void. Currently that should only include
1109
+ // / stores.
1110
+ SmallVector<Value *, 4 > getExpressionLeaves () {
1111
+ SmallVector<Value *, 4 > Leaves;
1112
+ for (auto &KV : Inst2ColumnMatrix)
1113
+ if (KV.first ->getType ()->isVoidTy ())
1114
+ Leaves.push_back (KV.first );
1115
+
1116
+ return Leaves;
1117
+ }
1118
+
1119
+ void emitRemarks () {
1120
+ if (!ORE.allowExtraAnalysis (DEBUG_TYPE))
1121
+ return ;
1122
+
1123
+ // Find leafs of matrix expressions.
1124
+ auto Leaves = getExpressionLeaves ();
1125
+
1126
+ // Generate remarks for each leaf.
1127
+ for (auto *L : Leaves) {
1128
+ OptimizationRemark Rem (DEBUG_TYPE, " matrix-lowered" ,
1129
+ cast<Instruction>(L)->getDebugLoc (),
1130
+ cast<Instruction>(L)->getParent ());
1131
+ Rem << " Lowered matrix expression " ;
1132
+ Rem << (" \n " + linearize (L, DL));
1133
+ ORE.emit (Rem);
1134
+ }
1135
+ }
1136
+
1137
+ std::string linearize (Value *L, const DataLayout &DL) {
1138
+ ExprLinearizer Lin (DL, Inst2ColumnMatrix);
1139
+ Lin.linearizeExpr (L, 0 , false );
1140
+ return Lin.getResult ();
1141
+ }
1142
+ };
847
1143
};
848
1144
} // namespace
849
1145
850
1146
PreservedAnalyses LowerMatrixIntrinsicsPass::run (Function &F,
851
1147
FunctionAnalysisManager &AM) {
852
1148
auto &TTI = AM.getResult <TargetIRAnalysis>(F);
853
- LowerMatrixIntrinsics LMT (F, TTI);
1149
+ auto &ORE = AM.getResult <OptimizationRemarkEmitterAnalysis>(F);
1150
+ LowerMatrixIntrinsics LMT (F, TTI, ORE);
854
1151
if (LMT.Visit ()) {
855
1152
PreservedAnalyses PA;
856
1153
PA.preserveSet <CFGAnalyses>();
@@ -871,14 +1168,16 @@ class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
871
1168
}
872
1169
873
1170
// Legacy pass-manager entry point. Fetches the analysis results for F, hands
// them to a LowerMatrixIntrinsics instance and returns whatever Visit()
// reports. NOTE(review): this span is a rendered diff; the '-' lines are the
// previous TTI-only setup and the '+' lines the version that also passes the
// OptimizationRemarkEmitter.
bool runOnFunction (Function &F) override {
874
- auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI (F);
875
- LowerMatrixIntrinsics LMT (F, *TTI);
1171
+ auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI (F);
1172
+ auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE ();
1173
+ LowerMatrixIntrinsics LMT (F, TTI, ORE);
876
1174
bool C = LMT.Visit ();
877
1175
return C;
878
1176
}
879
1177
880
1178
// Declares what this pass needs and keeps: it requires the two wrapper passes
// that runOnFunction() queries through getAnalysis<> (TargetTransformInfo and
// OptimizationRemarkEmitter) and marks the CFG as preserved.
void getAnalysisUsage (AnalysisUsage &AU) const override {
881
1179
AU.addRequired <TargetTransformInfoWrapperPass>();
1180
+ AU.addRequired <OptimizationRemarkEmitterWrapperPass>();
882
1181
AU.setPreservesCFG ();
883
1182
}
884
1183
};
@@ -888,6 +1187,7 @@ static const char pass_name[] = "Lower the matrix intrinsics";
888
1187
char LowerMatrixIntrinsicsLegacyPass::ID = 0 ;
889
1188
// Registration of the legacy pass. The dependency entry names
// OptimizationRemarkEmitterWrapperPass, the same pass that getAnalysisUsage()
// requests via addRequired<> and runOnFunction() reads via getAnalysis<>.
INITIALIZE_PASS_BEGIN (LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
890
1189
false , false )
1190
+ INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
891
1191
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
892
1192
false , false )
893
1193
0 commit comments