diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -32,22 +32,82 @@
   unpacked.push_back(v);
 }
 
-class ConvertForOpTypes : public OpConversionPattern<ForOp> {
+// CRTP base for patterns that rewrite a structural op whose result types may
+// undergo 1:N conversion. It converts the result types, lets the derived class
+// `ConcreteTy` build the replacement op through
+//
+//   Optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor,
+//                                      ConversionPatternRewriter &rewriter,
+//                                      TypeRange dstTypes) const;
+//
+// (returning llvm::None on failure), and then packs each run of new results
+// back into one value of the original result type, materializing a source
+// conversion whenever an original result does not map to exactly one value.
+template <typename SourceOp, typename ConcreteTy>
+class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using OpConversionPattern<SourceOp>::typeConverter;
+  using OpConversionPattern<SourceOp>::OpConversionPattern;
+  using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor;
+
   LogicalResult
-  matchAndRewrite(ForOp op, OpAdaptor adaptor,
+  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Type> newResultTypes;
+    SmallVector<Type> dstTypes;
     SmallVector<unsigned> offsets;
     offsets.push_back(0);
     // Do the type conversion and record the offsets.
     for (Type type : op.getResultTypes()) {
-      if (failed(typeConverter->convertTypes(type, newResultTypes)))
-        return rewriter.notifyMatchFailure(op, "could not convert result");
-      offsets.push_back(newResultTypes.size());
+      if (failed(typeConverter->convertTypes(type, dstTypes)))
+        return rewriter.notifyMatchFailure(op, "could not convert result type");
+      offsets.push_back(dstTypes.size());
     }
 
+    // Let the derived class build the converted operation.
+    Optional<SourceOp> newOp =
+        static_cast<const ConcreteTy *>(this)->convertSourceOp(
+            op, adaptor, rewriter, dstTypes);
+
+    if (!newOp)
+      return rewriter.notifyMatchFailure(op, "could not convert operation");
+
+    // Pack the return values: the new results in [offsets[i - 1], offsets[i])
+    // together replace result i - 1 of the original op.
+    SmallVector<Value> packedRets;
+    for (unsigned i = 1, e = offsets.size(); i < e; ++i) {
+      unsigned start = offsets[i - 1], end = offsets[i];
+      unsigned len = end - start;
+      ValueRange mappedValue = newOp->getResults().slice(start, len);
+      if (len != 1) {
+        // 1 : N type conversion.
+        Type origType = op.getResultTypes()[i - 1];
+        Value mat = typeConverter->materializeSourceConversion(
+            rewriter, op.getLoc(), origType, mappedValue);
+        if (!mat)
+          return rewriter.notifyMatchFailure(
+              op, "Failed to materialize 1:N type conversion");
+        packedRets.push_back(mat);
+      } else {
+        // 1 : 1 type conversion.
+        packedRets.push_back(mappedValue.front());
+      }
+    }
+
+    rewriter.replaceOp(op, packedRets);
+    return success();
+  }
+};
+
+class ConvertForOpTypes
+    : public Structural1ToNConversionPattern<ForOp, ConvertForOpTypes> {
+public:
+  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
+
+  // Creates the converted scf.for from `op`. Returns llvm::None if the loop
+  // body's region types cannot be converted.
+  Optional<ForOp> convertSourceOp(ForOp op, OpAdaptor adaptor,
+                                  ConversionPatternRewriter &rewriter,
+                                  TypeRange dstTypes) const {
     // Create a empty new op and inline the regions from the old op.
     //
     // This is a little bit tricky. We have two concerns here:
@@ -58,24 +118,25 @@
     // PR47938 tracks this issue, but it seems hard to fix. Instead, we need
     // to clone the op.
     //
-    // 2. We need to resue the original region instead of cloning it, otherwise
-    // the dialect conversion framework thinks that we just inserted all the
-    // cloned child ops. But what we want is to "take" the child regions and let
-    // the dialect conversion framework continue recursively into ops inside
-    // those regions (which are already in its worklist; inlining them into the
-    // new op's regions doesn't remove the child ops from the worklist).
+    // 2. We need to reuse the original region instead of cloning it,
+    // otherwise the dialect conversion framework thinks that we just inserted
+    // all the cloned child ops. But what we want is to "take" the child
+    // regions and let the dialect conversion framework continue recursively
+    // into ops inside those regions (which are already in its worklist;
+    // inlining them into the new op's regions doesn't remove the child ops
+    // from the worklist).
 
     // convertRegionTypes already takes care of 1:N conversion.
     if (failed(rewriter.convertRegionTypes(&op.getLoopBody(), *typeConverter)))
-      return failure();
+      return llvm::None;
 
     // Unpacked the iteration arguments.
     SmallVector<Value> flatArgs;
     for (Value arg : adaptor.getInitArgs())
       unpackUnrealizedConversionCast(arg, flatArgs);
 
-    // We can not do clone as the number of result types after conversion might
-    // be different.
+    // We cannot clone the op, as the number of result types after conversion
+    // might be different.
     ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(),
                                          adaptor.getUpperBound(),
                                          adaptor.getStep(), flatArgs);
@@ -89,29 +150,7 @@
     rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
                                 newOp.getLoopBody().end());
 
-    // Pack the return value.
-    SmallVector<Value> packedRets;
-    for (unsigned i = 1, e = offsets.size(); i < e; i++) {
-      unsigned start = offsets[i - 1], end = offsets[i];
-      unsigned len = end - start;
-      ValueRange mappedValue = newOp.getResults().slice(start, len);
-      if (len != 1) {
-        // 1 : N type conversion.
-        Type origType = op.getResultTypes()[i - 1];
-        Value mat = typeConverter->materializeSourceConversion(
-            rewriter, op.getLoc(), origType, mappedValue);
-        if (!mat)
-          return rewriter.notifyMatchFailure(
-              op, "Failed to materialize 1:N type conversion");
-        packedRets.push_back(mat);
-      } else {
-        // 1 : 1 type conversion.
-        packedRets.push_back(mappedValue.front());
-      }
-    }
-
-    rewriter.replaceOp(op, packedRets);
-    return success();
+    return newOp;
   }
 };
 } // namespace