diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -191,6 +191,11 @@
   interleaveComma(types, p);
   return p;
 }
+/// Print the given types as a comma-separated list.
+inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ArrayRef<Type> types) {
+  interleaveComma(types, p);
+  return p;
+}
 
 //===----------------------------------------------------------------------===//
 // OpAsmParser
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -260,10 +260,14 @@
 
   /// Support result type iteration.
   using result_type_iterator = result_range::type_iterator;
-  using result_type_range = iterator_range<result_type_iterator>;
-  result_type_iterator result_type_begin() { return result_begin(); }
-  result_type_iterator result_type_end() { return result_end(); }
-  result_type_range getResultTypes() { return getResults().getTypes(); }
+  using result_type_range = ArrayRef<Type>;
+  result_type_iterator result_type_begin() { return getResultTypes().begin(); }
+  result_type_iterator result_type_end() { return getResultTypes().end(); }
+
+  /// Return the types of all results of this operation. The returned ArrayRef
+  /// does not own its elements and may point into the operation itself, so it
+  /// must not outlive the operation.
+  result_type_range getResultTypes();
 
   //===--------------------------------------------------------------------===//
   // Attributes
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -595,8 +595,10 @@
   ResultRange(Operation *op);
 
   /// Returns the types of the values within this range.
-  using type_iterator = ValueTypeIterator<iterator>;
-  iterator_range<type_iterator> getTypes() const { return {begin(), end()}; }
+  /// This range may cover only a subset of its operation's results; only the
+  /// types of that subset are returned.
+  using type_iterator = ArrayRef<Type>::iterator;
+  ArrayRef<Type> getTypes() const;
 
 private:
   /// See `indexed_accessor_range` for details.
diff --git a/mlir/lib/Analysis/InferTypeOpInterface.cpp b/mlir/lib/Analysis/InferTypeOpInterface.cpp
--- a/mlir/lib/Analysis/InferTypeOpInterface.cpp
+++ b/mlir/lib/Analysis/InferTypeOpInterface.cpp
@@ -53,8 +53,8 @@
           op->getOperands(), op->getAttrs(), op->getRegions(),
           inferedReturnTypes)))
     return failure();
-  SmallVector<Type, 4> resultTypes(op->getResultTypes());
-  if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes))
+  if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes,
+                                         op->getResultTypes()))
     return op->emitOpError(
         "inferred type incompatible with return type of operation");
   return success();
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -652,8 +652,7 @@
 
     Type packedType;
     if (numResults != 0) {
-      packedType = this->lowering.packFunctionResults(
-          llvm::to_vector<4>(op->getResultTypes()));
+      packedType = this->lowering.packFunctionResults(op->getResultTypes());
       if (!packedType)
         return this->matchFailure();
     }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -292,11 +292,11 @@
   p.printOptionalAttrDict(op.getAttrs(), {"callee"});
 
   // Reconstruct the function MLIR function type from operand and result types.
-  SmallVector<Type, 1> resultTypes(op.getResultTypes());
   SmallVector<Type, 8> argTypes(
       llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));
 
-  p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext());
+  p << " : "
+    << FunctionType::get(argTypes, op.getResultTypes(), op.getContext());
 }
 
 // ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1685,9 +1685,8 @@
 
 static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter &printer) {
   SmallVector<Type, 4> argTypes(functionCallOp.getOperandTypes());
-  SmallVector<Type, 1> resultTypes(functionCallOp.getResultTypes());
-  Type functionType =
-      FunctionType::get(argTypes, resultTypes, functionCallOp.getContext());
+  Type functionType = FunctionType::get(
+      argTypes, functionCallOp.getResultTypes(), functionCallOp.getContext());
 
   printer << spirv::FunctionCallOp::getOperationName() << ' '
           << functionCallOp.getAttr(kCallee) << '('
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -1764,12 +1764,9 @@
   auto funcName = op.callee();
   uint32_t resTypeID = 0;
 
-  SmallVector<Type, 1> resultTypes(op.getResultTypes());
-  if (failed(processType(op.getLoc(),
-                         (resultTypes.empty() ? getVoidType() : resultTypes[0]),
-                         resTypeID))) {
+  Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
+  if (failed(processType(op.getLoc(), resultTy, resTypeID)))
     return failure();
-  }
 
   auto funcID = getOrCreateFunctionID(funcName);
   auto funcCallID = getNextID();
@@ -1781,9 +1778,10 @@
     operands.push_back(valueID);
   }
 
-  if (!resultTypes.empty()) {
+  // Key off the result count rather than `resultTy`, which holds the void
+  // type placeholder when the call produces no result.
+  if (op.getNumResults() != 0)
     valueIDMap[op.getResult(0)] = funcCallID;
-  }
 
   return encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall,
                                operands);
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -500,9 +500,8 @@
 }
 
 FunctionType CallOp::getCalleeType() {
-  SmallVector<Type, 4> resultTypes(getResultTypes());
   SmallVector<Type, 8> argTypes(getOperandTypes());
-  return FunctionType::get(argTypes, resultTypes, getContext());
+  return FunctionType::get(argTypes, getResultTypes(), getContext());
 }
 
 //===----------------------------------------------------------------------===//
@@ -522,8 +521,8 @@
       return matchFailure();
 
     // Replace with a direct call.
-    SmallVector<Type, 8> callResults(indirectCall.getResultTypes());
-    rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn, callResults,
+    rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
+                                        indirectCall.getResultTypes(),
                                         indirectCall.getArgOperands());
     return matchSuccess();
   }
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -551,6 +551,18 @@
   return hasSingleResult ? 1 : resultType.cast<TupleType>().size();
 }
 
+/// Return the types of all results. A null `resultType` means there are no
+/// results, a single result type is stored directly in `resultType`, and
+/// multiple result types are stored as a TupleType.
+auto Operation::getResultTypes() -> result_type_range {
+  if (!resultType)
+    return llvm::None;
+  // The single-result ArrayRef points at the `resultType` member itself.
+  if (hasSingleResult)
+    return resultType;
+  return resultType.cast<TupleType>().getTypes();
+}
+
 void Operation::setSuccessor(Block *block, unsigned index) {
   assert(index < getNumSuccessors());
   getBlockOperands()[index].set(block);
@@ -666,10 +678,9 @@
     }
   }
 
-  SmallVector<Type, 8> resultTypes(getResultTypes());
   unsigned numRegions = getNumRegions();
   auto *newOp =
-      Operation::create(getLoc(), getName(), resultTypes, operands, attrs,
+      Operation::create(getLoc(), getName(), getResultTypes(), operands, attrs,
                         successors, numRegions, hasResizableOperandsList());
 
   // Remember the mapping of any results.
@@ -919,7 +929,7 @@
   auto type = op->getResult(0).getType();
   auto elementType = getElementTypeOrSelf(type);
 
-  for (auto resultType : llvm::drop_begin(op->getResultTypes(), 1)) {
+  for (auto resultType : op->getResultTypes().drop_front(1)) {
     if (getElementTypeOrSelf(resultType) != elementType ||
         failed(verifyCompatibleShape(resultType, type)))
       return op->emitOpError()
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -152,6 +152,13 @@
 ResultRange::ResultRange(Operation *op)
     : ResultRange(op, /*startIndex=*/0, op->getNumResults()) {}
 
+/// Returns the types of the results within this range. A ResultRange is a
+/// (start index, count) window over its operation's results, so the full
+/// result type list must be sliced down to that window.
+ArrayRef<Type> ResultRange::getTypes() const {
+  return getBase()->getResultTypes().slice(getStartIndex(), size());
+}
+
 /// See `indexed_accessor_range` for details.
 OpResult ResultRange::dereference(Operation *op, ptrdiff_t index) {
   return op->getResult(index);
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -37,10 +37,9 @@
     //   - Attributes
     //   - Result Types
     //   - Operands
-    return hash_combine(
-        op->getName(), op->getAttrList().getDictionary(),
-        hash_combine_range(op->result_type_begin(), op->result_type_end()),
-        hash_combine_range(op->operand_begin(), op->operand_end()));
+    return llvm::hash_combine(
+        op->getName(), op->getAttrList().getDictionary(), op->getResultTypes(),
+        llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
   }
   static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
     auto *lhs = const_cast<Operation *>(lhsC);
diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp
--- a/mlir/test/lib/TestDialect/TestPatterns.cpp
+++ b/mlir/test/lib/TestDialect/TestPatterns.cpp
@@ -241,7 +241,7 @@
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     // If the type is I32, change the type to F32.
-    if (!(*op->result_type_begin()).isInteger(32))
+    if (!Type(*op->result_type_begin()).isInteger(32))
       return matchFailure();
     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type());
     return matchSuccess();
@@ -254,7 +254,7 @@
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     // If the type is F32, change the type to F64.
-    if (!(*op->result_type_begin()).isF32())
+    if (!Type(*op->result_type_begin()).isF32())
       return matchFailure();
     rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type());
     return matchSuccess();
@@ -477,8 +477,7 @@
     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
     remappedOperands.push_back(rewriter.getRemappedValue(origOp));
 
-    SmallVector<Type, 1> resultTypes(op.getResultTypes());
-    rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, resultTypes,
+    rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(),
                                                        remappedOperands);
     return matchSuccess();
   }