Skip to content

Commit aff4ed7

Browse files
committed
[mlir][NFC] Update Operation::getResultTypes to use ArrayRef<Type> instead of iterator_range.
Summary: The new internal representation of operation results now allows for accessing the result types to be more efficient. Changing the API to ArrayRef is more efficient and removes the need to explicitly materialize vectors in several places. Differential Revision: https://reviews.llvm.org/D73429
1 parent ce674b1 commit aff4ed7

File tree

13 files changed

+43
-37
lines changed

13 files changed

+43
-37
lines changed

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,10 @@ operator<<(OpAsmPrinter &p,
191191
interleaveComma(types, p);
192192
return p;
193193
}
194+
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ArrayRef<Type> types) {
195+
interleaveComma(types, p);
196+
return p;
197+
}
194198

195199
//===----------------------------------------------------------------------===//
196200
// OpAsmParser

mlir/include/mlir/IR/Operation.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,10 +260,10 @@ class Operation final
260260

261261
/// Support result type iteration.
262262
using result_type_iterator = result_range::type_iterator;
263-
using result_type_range = iterator_range<result_type_iterator>;
264-
result_type_iterator result_type_begin() { return result_begin(); }
265-
result_type_iterator result_type_end() { return result_end(); }
266-
result_type_range getResultTypes() { return getResults().getTypes(); }
263+
using result_type_range = ArrayRef<Type>;
264+
result_type_iterator result_type_begin() { return getResultTypes().begin(); }
265+
result_type_iterator result_type_end() { return getResultTypes().end(); }
266+
result_type_range getResultTypes();
267267

268268
//===--------------------------------------------------------------------===//
269269
// Attributes

mlir/include/mlir/IR/OperationSupport.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -595,8 +595,8 @@ class ResultRange final
595595
ResultRange(Operation *op);
596596

597597
/// Returns the types of the values within this range.
598-
using type_iterator = ValueTypeIterator<iterator>;
599-
iterator_range<type_iterator> getTypes() const { return {begin(), end()}; }
598+
using type_iterator = ArrayRef<Type>::iterator;
599+
ArrayRef<Type> getTypes() const;
600600

601601
private:
602602
/// See `indexed_accessor_range` for details.

mlir/lib/Analysis/InferTypeOpInterface.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
5353
op->getOperands(), op->getAttrs(),
5454
op->getRegions(), inferedReturnTypes)))
5555
return failure();
56-
SmallVector<Type, 4> resultTypes(op->getResultTypes());
57-
if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes))
56+
if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes,
57+
op->getResultTypes()))
5858
return op->emitOpError(
5959
"inferred type incompatible with return type of operation");
6060
return success();

mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -652,8 +652,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
652652

653653
Type packedType;
654654
if (numResults != 0) {
655-
packedType = this->lowering.packFunctionResults(
656-
llvm::to_vector<4>(op->getResultTypes()));
655+
packedType = this->lowering.packFunctionResults(op->getResultTypes());
657656
if (!packedType)
658657
return this->matchFailure();
659658
}

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,11 +292,11 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) {
292292
p.printOptionalAttrDict(op.getAttrs(), {"callee"});
293293

294294
// Reconstruct the function MLIR function type from operand and result types.
295-
SmallVector<Type, 1> resultTypes(op.getResultTypes());
296295
SmallVector<Type, 8> argTypes(
297296
llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));
298297

299-
p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext());
298+
p << " : "
299+
<< FunctionType::get(argTypes, op.getResultTypes(), op.getContext());
300300
}
301301

302302
// <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`

mlir/lib/Dialect/SPIRV/SPIRVOps.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1685,9 +1685,8 @@ static ParseResult parseFunctionCallOp(OpAsmParser &parser,
16851685

16861686
static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter &printer) {
16871687
SmallVector<Type, 4> argTypes(functionCallOp.getOperandTypes());
1688-
SmallVector<Type, 1> resultTypes(functionCallOp.getResultTypes());
1689-
Type functionType =
1690-
FunctionType::get(argTypes, resultTypes, functionCallOp.getContext());
1688+
Type functionType = FunctionType::get(
1689+
argTypes, functionCallOp.getResultTypes(), functionCallOp.getContext());
16911690

16921691
printer << spirv::FunctionCallOp::getOperationName() << ' '
16931692
<< functionCallOp.getAttr(kCallee) << '('

mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1764,12 +1764,9 @@ Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
17641764
auto funcName = op.callee();
17651765
uint32_t resTypeID = 0;
17661766

1767-
SmallVector<Type, 1> resultTypes(op.getResultTypes());
1768-
if (failed(processType(op.getLoc(),
1769-
(resultTypes.empty() ? getVoidType() : resultTypes[0]),
1770-
resTypeID))) {
1767+
Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
1768+
if (failed(processType(op.getLoc(), resultTy, resTypeID)))
17711769
return failure();
1772-
}
17731770

17741771
auto funcID = getOrCreateFunctionID(funcName);
17751772
auto funcCallID = getNextID();
@@ -1781,9 +1778,8 @@ Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
17811778
operands.push_back(valueID);
17821779
}
17831780

1784-
if (!resultTypes.empty()) {
1781+
if (!resultTy.isa<NoneType>())
17851782
valueIDMap[op.getResult(0)] = funcCallID;
1786-
}
17871783

17881784
return encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall,
17891785
operands);

mlir/lib/Dialect/StandardOps/Ops.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -500,9 +500,8 @@ static LogicalResult verify(CallOp op) {
500500
}
501501

502502
FunctionType CallOp::getCalleeType() {
503-
SmallVector<Type, 4> resultTypes(getResultTypes());
504503
SmallVector<Type, 8> argTypes(getOperandTypes());
505-
return FunctionType::get(argTypes, resultTypes, getContext());
504+
return FunctionType::get(argTypes, getResultTypes(), getContext());
506505
}
507506

508507
//===----------------------------------------------------------------------===//
@@ -522,8 +521,8 @@ struct SimplifyIndirectCallWithKnownCallee
522521
return matchFailure();
523522

524523
// Replace with a direct call.
525-
SmallVector<Type, 8> callResults(indirectCall.getResultTypes());
526-
rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn, callResults,
524+
rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
525+
indirectCall.getResultTypes(),
527526
indirectCall.getArgOperands());
528527
return matchSuccess();
529528
}

mlir/lib/IR/Operation.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,14 @@ unsigned Operation::getNumResults() {
551551
return hasSingleResult ? 1 : resultType.cast<TupleType>().size();
552552
}
553553

554+
auto Operation::getResultTypes() -> result_type_range {
555+
if (!resultType)
556+
return llvm::None;
557+
if (hasSingleResult)
558+
return resultType;
559+
return resultType.cast<TupleType>().getTypes();
560+
}
561+
554562
void Operation::setSuccessor(Block *block, unsigned index) {
555563
assert(index < getNumSuccessors());
556564
getBlockOperands()[index].set(block);
@@ -666,10 +674,9 @@ Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) {
666674
}
667675
}
668676

669-
SmallVector<Type, 8> resultTypes(getResultTypes());
670677
unsigned numRegions = getNumRegions();
671678
auto *newOp =
672-
Operation::create(getLoc(), getName(), resultTypes, operands, attrs,
679+
Operation::create(getLoc(), getName(), getResultTypes(), operands, attrs,
673680
successors, numRegions, hasResizableOperandsList());
674681

675682
// Remember the mapping of any results.
@@ -919,7 +926,7 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
919926

920927
auto type = op->getResult(0).getType();
921928
auto elementType = getElementTypeOrSelf(type);
922-
for (auto resultType : llvm::drop_begin(op->getResultTypes(), 1)) {
929+
for (auto resultType : op->getResultTypes().drop_front(1)) {
923930
if (getElementTypeOrSelf(resultType) != elementType ||
924931
failed(verifyCompatibleShape(resultType, type)))
925932
return op->emitOpError()

mlir/lib/IR/OperationSupport.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,10 @@ OperandRange::OperandRange(Operation *op)
152152
ResultRange::ResultRange(Operation *op)
153153
: ResultRange(op, /*startIndex=*/0, op->getNumResults()) {}
154154

155+
ArrayRef<Type> ResultRange::getTypes() const {
156+
return getBase()->getResultTypes();
157+
}
158+
155159
/// See `indexed_accessor_range` for details.
156160
OpResult ResultRange::dereference(Operation *op, ptrdiff_t index) {
157161
return op->getResult(index);

mlir/lib/Transforms/CSE.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,9 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
3737
// - Attributes
3838
// - Result Types
3939
// - Operands
40-
return hash_combine(
41-
op->getName(), op->getAttrList().getDictionary(),
42-
hash_combine_range(op->result_type_begin(), op->result_type_end()),
43-
hash_combine_range(op->operand_begin(), op->operand_end()));
40+
return llvm::hash_combine(
41+
op->getName(), op->getAttrList().getDictionary(), op->getResultTypes(),
42+
llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
4443
}
4544
static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
4645
auto *lhs = const_cast<Operation *>(lhsC);

mlir/test/lib/TestDialect/TestPatterns.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ struct TestChangeProducerTypeI32ToF32 : public ConversionPattern {
241241
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
242242
ConversionPatternRewriter &rewriter) const final {
243243
// If the type is I32, change the type to F32.
244-
if (!(*op->result_type_begin()).isInteger(32))
244+
if (!Type(*op->result_type_begin()).isInteger(32))
245245
return matchFailure();
246246
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
247247
return matchSuccess();
@@ -254,7 +254,7 @@ struct TestChangeProducerTypeF32ToF64 : public ConversionPattern {
254254
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
255255
ConversionPatternRewriter &rewriter) const final {
256256
// If the type is F32, change the type to F64.
257-
if (!(*op->result_type_begin()).isF32())
257+
if (!Type(*op->result_type_begin()).isF32())
258258
return matchFailure();
259259
rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
260260
return matchSuccess();
@@ -477,8 +477,7 @@ struct OneVResOneVOperandOp1Converter
477477
remappedOperands.push_back(rewriter.getRemappedValue(origOp));
478478
remappedOperands.push_back(rewriter.getRemappedValue(origOp));
479479

480-
SmallVector<Type, 1> resultTypes(op.getResultTypes());
481-
rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, resultTypes,
480+
rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
482481
remappedOperands);
483482
return matchSuccess();
484483
}

0 commit comments

Comments
 (0)