diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h --- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h +++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h @@ -215,22 +215,66 @@ ArrayRef<StringRef> generatedNames = {}) : OneToNConversionPattern(typeConverter, SourceOp::getOperationName(), benefit, context, generatedNames) {} + /// Generic adaptor around the root op of this pattern using the converted + /// operands. Importantly, each operand is represented as a *range* of values, + /// namely the N values each original operand gets converted to. Concretely, + /// this makes the result type of the accessor functions of the adaptor class + /// be a `ValueRange`. + class OpAdaptor + : public SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> { + public: + using RangeT = ArrayRef<ValueRange>; + using BaseT = typename SourceOp::template GenericAdaptor<RangeT>; + + OpAdaptor(const OneToNTypeMapping *operandMapping, + const OneToNTypeMapping *resultMapping, + const ValueRange *convertedOperands, RangeT values, + DictionaryAttr attrs = nullptr, RegionRange regions = {}) + : BaseT(values, attrs, regions), operandMapping(operandMapping), + resultMapping(resultMapping), convertedOperands(convertedOperands) {} + + /// Get the type mapping of the original operands to the converted operands. + const OneToNTypeMapping &getOperandMapping() const { + return *operandMapping; + } + + /// Get the type mapping of the original results to the converted results. + const OneToNTypeMapping &getResultMapping() const { return *resultMapping; } + + /// Get a flat range of all converted operands. Unlike `getOperands`, which + /// returns an `ArrayRef` with one `ValueRange` for each original operand, + /// this function returns a `ValueRange` that contains all converted + /// operands irrespective of which operand they originated from.
+ ValueRange getFlatOperands() const { return *convertedOperands; } + + private: + const OneToNTypeMapping *operandMapping; + const OneToNTypeMapping *resultMapping; + const ValueRange *convertedOperands; + }; using OneToNConversionPattern::matchAndRewrite; /// Overload that derived classes have to override for their op type. - virtual LogicalResult matchAndRewrite(SourceOp op, - OneToNPatternRewriter &rewriter, - const OneToNTypeMapping &operandMapping, - const OneToNTypeMapping &resultMapping, - ValueRange convertedOperands) const = 0; + virtual LogicalResult + matchAndRewrite(SourceOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const = 0; LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter, const OneToNTypeMapping &operandMapping, const OneToNTypeMapping &resultMapping, ValueRange convertedOperands) const final { - return matchAndRewrite(cast<SourceOp>(op), rewriter, operandMapping, - resultMapping, convertedOperands); + // Wrap converted operands and type mappings into an adaptor. + SmallVector<ValueRange> valueRanges; + for (unsigned i = 0, e = op->getNumOperands(); i < e; ++i) { + auto values = operandMapping.getConvertedValues(convertedOperands, i); + valueRanges.push_back(values); + } + OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands, + valueRanges, op->getAttrDictionary(), op->getRegions()); + + // Call overload implemented by the derived class.
+ return matchAndRewrite(cast<SourceOp>(op), adaptor, rewriter); } }; diff --git a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp --- a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp +++ b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp @@ -27,21 +27,21 @@ public: using OneToNOpConversionPattern::OneToNOpConversionPattern; - LogicalResult matchAndRewrite(CallOp op, OneToNPatternRewriter &rewriter, - const OneToNTypeMapping &operandMapping, - const OneToNTypeMapping &resultMapping, - ValueRange convertedOperands) const override { + LogicalResult + matchAndRewrite(CallOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { Location loc = op->getLoc(); + const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); // Nothing to do if the op doesn't have any non-identity conversions for its // operands or results. - if (!operandMapping.hasNonIdentityConversion() && + if (!adaptor.getOperandMapping().hasNonIdentityConversion() && !resultMapping.hasNonIdentityConversion()) return failure(); // Create new CallOp. auto newOp = rewriter.create<CallOp>(loc, resultMapping.getConvertedTypes(), - convertedOperands); + adaptor.getFlatOperands()); newOp->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newOp->getResults(), resultMapping); @@ -54,10 +54,8 @@ using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult - matchAndRewrite(FuncOp op, OneToNPatternRewriter &rewriter, - const OneToNTypeMapping & /*operandMapping*/, - const OneToNTypeMapping & /*resultMapping*/, - ValueRange /*convertedOperands*/) const override { + matchAndRewrite(FuncOp op, OpAdaptor /*adaptor*/, + OneToNPatternRewriter &rewriter) const override { auto *typeConverter = getTypeConverter(); // Construct mapping for function arguments.
@@ -99,16 +97,16 @@ public: using OneToNOpConversionPattern::OneToNOpConversionPattern; - LogicalResult matchAndRewrite(ReturnOp op, OneToNPatternRewriter &rewriter, - const OneToNTypeMapping &operandMapping, - const OneToNTypeMapping & /*resultMapping*/, - ValueRange convertedOperands) const override { + LogicalResult + matchAndRewrite(ReturnOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { // Nothing to do if there is no non-identity conversion. - if (!operandMapping.hasNonIdentityConversion()) + if (!adaptor.getOperandMapping().hasNonIdentityConversion()) return failure(); // Convert operands. - rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); }); + rewriter.updateRootInPlace( + op, [&] { op->setOperands(adaptor.getFlatOperands()); }); return success(); } diff --git a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp --- a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp @@ -25,11 +25,10 @@ using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult - matchAndRewrite(IfOp op, OneToNPatternRewriter &rewriter, - const OneToNTypeMapping & /*operandMapping*/, - const OneToNTypeMapping &resultMapping, - const ValueRange /*convertedOperands*/) const override { + matchAndRewrite(IfOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { Location loc = op->getLoc(); + const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); // Nothing to do if there is no non-identity conversion.
if (!resultMapping.hasNonIdentityConversion()) @@ -62,12 +61,13 @@ using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult - matchAndRewrite(WhileOp op, OneToNPatternRewriter &rewriter, - const OneToNTypeMapping &operandMapping, - const OneToNTypeMapping &resultMapping, - const ValueRange convertedOperands) const override { + matchAndRewrite(WhileOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { Location loc = op->getLoc(); + const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping(); + const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); + // Nothing to do if the op doesn't have any non-identity conversions for its // operands or results. if (!operandMapping.hasNonIdentityConversion() && @@ -77,8 +77,8 @@ // Create new WhileOp. TypeRange convertedResultTypes = resultMapping.getConvertedTypes(); - auto newOp = - rewriter.create<WhileOp>(loc, convertedResultTypes, convertedOperands); + auto newOp = rewriter.create<WhileOp>(loc, convertedResultTypes, + adaptor.getFlatOperands()); newOp->setAttrs(op->getAttrs()); // Update block signatures. @@ -106,16 +106,15 @@ using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult - matchAndRewrite(YieldOp op, OneToNPatternRewriter &rewriter, - const OneToNTypeMapping &operandMapping, - const OneToNTypeMapping & /*resultMapping*/, - const ValueRange convertedOperands) const override { + matchAndRewrite(YieldOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { // Nothing to do if there is no non-identity conversion. - if (!operandMapping.hasNonIdentityConversion()) + if (!adaptor.getOperandMapping().hasNonIdentityConversion()) return failure(); // Convert operands.
- rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); }); + rewriter.updateRootInPlace( + op, [&] { op->setOperands(adaptor.getFlatOperands()); }); return success(); } @@ -127,16 +126,15 @@ using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult - matchAndRewrite(ConditionOp op, OneToNPatternRewriter &rewriter, - const OneToNTypeMapping &operandMapping, - const OneToNTypeMapping & /*resultMapping*/, - const ValueRange convertedOperands) const override { + matchAndRewrite(ConditionOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { // Nothing to do if there is no non-identity conversion. - if (!operandMapping.hasNonIdentityConversion()) + if (!adaptor.getOperandMapping().hasNonIdentityConversion()) return failure(); // Convert operands. - rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); }); + rewriter.updateRootInPlace( + op, [&] { op->setOperands(adaptor.getFlatOperands()); }); return success(); } diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp --- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp @@ -77,13 +77,12 @@ using OneToNOpConversionPattern< ::test::MakeTupleOp>::OneToNOpConversionPattern; - LogicalResult matchAndRewrite(::test::MakeTupleOp op, - OneToNPatternRewriter &rewriter, - const OneToNTypeMapping &operandMapping, - const OneToNTypeMapping &resultMapping, - ValueRange convertedOperands) const override { + LogicalResult + matchAndRewrite(::test::MakeTupleOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { // Simply replace the current op with the converted operands.
- rewriter.replaceOp(op, convertedOperands, resultMapping); + rewriter.replaceOp(op, adaptor.getFlatOperands(), + adaptor.getResultMapping()); return success(); } }; @@ -99,11 +98,9 @@ using OneToNOpConversionPattern< ::test::GetTupleElementOp>::OneToNOpConversionPattern; - LogicalResult matchAndRewrite(::test::GetTupleElementOp op, - OneToNPatternRewriter &rewriter, - const OneToNTypeMapping &operandMapping, - const OneToNTypeMapping &resultMapping, - ValueRange convertedOperands) const override { + LogicalResult + matchAndRewrite(::test::GetTupleElementOp op, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { // Construct mapping for tuple element types. auto stateType = op->getOperand(0).getType().cast<TupleType>(); TypeRange originalElementTypes = stateType.getTypes(); @@ -113,16 +110,17 @@ return failure(); // Compute converted operands corresponding to original input tuple. - ValueRange convertedTuple = - operandMapping.getConvertedValues(convertedOperands, 0); + assert(adaptor.getOperands().size() == 1 && + "expected 'get_tuple_element' to have one operand"); + ValueRange convertedTuple = adaptor.getOperands()[0]; - // Got those converted operands that correspond to the index-th element of + // Get those converted operands that correspond to the index-th element of // the original input tuple. size_t index = op.getIndex(); ValueRange extractedElement = elementMapping.getConvertedValues(convertedTuple, index); - rewriter.replaceOp(op, extractedElement, resultMapping); + rewriter.replaceOp(op, extractedElement, adaptor.getResultMapping()); return success(); }