diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -725,6 +725,14 @@ result_type_range getResultTypes() { return this->getOperation()->getResultTypes(); } + + /// Sets the types of all results of this op to `types`, in order. + /// + /// Requires that `types` has exactly one element per result; this simply + /// forwards to Operation::setResultTypes. + void setResultTypes(TypeRange types) { + this->getOperation()->setResultTypes(types); + } }; } // end namespace detail diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -292,6 +292,12 @@ result_type_iterator result_type_end() { return getResultTypes().end(); } result_type_range getResultTypes(); + /// Sets the type of the i-th result to the i-th element of `types`. + /// + /// Requires that `types` has exactly one element per result. Implemented via + /// ResultRange::setTypes, which asserts this precondition. + void setResultTypes(TypeRange types) { getResults().setTypes(types); } + //===--------------------------------------------------------------------===// // Attributes //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -735,6 +735,11 @@ using type_range = ArrayRef; type_range getTypes() const; auto getType() const { return getTypes(); } + /// Sets the type of the i-th value in this range to `types[i]`. + /// + /// Requires that `types` has exactly as many elements as this range; the + /// size check is an `assert`, so it is not enforced in NDEBUG builds. + void setTypes(TypeRange types); private: /// See `llvm::indexed_accessor_range` for details. 
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -31,8 +31,7 @@ newResultTypes.push_back(newType); } rewriter.updateRootInPlace(op, [&] { - for (auto t : llvm::zip(op.getResults(), newResultTypes)) - std::get<0>(t).setType(std::get<1>(t)); + op.setResultTypes(newResultTypes); auto bodyArgs = op.getBody()->getArguments(); for (auto t : llvm::zip(llvm::drop_begin(bodyArgs, 1), newResultTypes)) std::get<0>(t).setType(std::get<1>(t)); @@ -49,6 +48,16 @@ LogicalResult matchAndRewrite(IfOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { + // TODO: Generalize this to any type conversion, not just 1:1. + // + // That needs something more sophisticated that tracks which types convert + // to which other types and performs the appropriate materializations. + // For example, one result type could convert to 0 types and another to 2 + // types: newResultTypes would then have the right size, so setResultTypes + // would not crash, but the SSA values would silently get the wrong types! + // These edge cases are also why we cannot safely use the + // TypeConverter::convertTypes helper here. Until then, this pattern only + // handles 1:1 conversions (see the notifyMatchFailure below). 
SmallVector newResultTypes; for (auto type : op.getResultTypes()) { Type newType = typeConverter->convertType(type); @@ -56,10 +65,7 @@ return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); newResultTypes.push_back(newType); } - rewriter.updateRootInPlace(op, [&] { - for (auto t : llvm::zip(op.getResults(), newResultTypes)) - std::get<0>(t).setType(std::get<1>(t)); - }); + rewriter.updateRootInPlace(op, [&] { op.setResultTypes(newResultTypes); }); return success(); } }; diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -478,6 +478,12 @@ return getBase()->getResultTypes().slice(getStartIndex(), size()); } +void ResultRange::setTypes(TypeRange types) { + assert(types.size() == size() && "mismatch in number of types!"); + for (auto t : llvm::zip(*this, types)) + std::get<0>(t).setType(std::get<1>(t)); +} + /// See `llvm::indexed_accessor_range` for details. OpResult ResultRange::dereference(Operation *op, ptrdiff_t index) { return op->getResult(index);