diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -32,22 +32,82 @@
   unpacked.push_back(v);
 }
 
-class ConvertForOpTypes : public OpConversionPattern<ForOp> {
+// A CRTP base class that takes care of 1:N type conversion: it converts the
+// result types of the source op, delegates the actual op conversion to the
+// derived class, and materializes the 1:N conversions of the new results.
+template <typename SourceOp, typename ConcretePattern>
+class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using OpConversionPattern<SourceOp>::typeConverter;
+  using OpConversionPattern<SourceOp>::OpConversionPattern;
+  using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor;
+
+  //
+  // Derived classes should provide the following method, which performs the
+  // actual conversion. It returns llvm::None upon failure; upon success it
+  // returns the new op, whose results must be laid out as in `dstTypes`.
+  //
+  // Optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor,
+  //                                    ConversionPatternRewriter &rewriter,
+  //                                    TypeRange dstTypes) const;
+
   LogicalResult
-  matchAndRewrite(ForOp op, OpAdaptor adaptor,
+  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Type> newResultTypes;
+    SmallVector<Type> dstTypes;
     SmallVector<unsigned> offsets;
     offsets.push_back(0);
     // Do the type conversion and record the offsets.
     for (Type type : op.getResultTypes()) {
-      if (failed(typeConverter->convertTypes(type, newResultTypes)))
-        return rewriter.notifyMatchFailure(op, "could not convert result");
-      offsets.push_back(newResultTypes.size());
+      if (failed(typeConverter->convertTypes(type, dstTypes)))
+        return rewriter.notifyMatchFailure(op, "could not convert result type");
+      offsets.push_back(dstTypes.size());
     }
 
+    // Dispatch to the derived class (CRTP) to create the converted op.
+    Optional<SourceOp> newOp =
+        static_cast<const ConcretePattern *>(this)->convertSourceOp(
+            op, adaptor, rewriter, dstTypes);
+
+    if (!newOp)
+      return rewriter.notifyMatchFailure(op, "could not convert operation");
+
+    // Pack the new results back into one value per original result.
+    SmallVector<Value> packedRets;
+    for (unsigned i = 1, e = offsets.size(); i < e; i++) {
+      unsigned start = offsets[i - 1], end = offsets[i];
+      unsigned len = end - start;
+      ValueRange mappedValue = newOp->getResults().slice(start, len);
+      if (len != 1) {
+        // 1 : N type conversion (any N != 1, including N == 0).
+        Type origType = op.getResultTypes()[i - 1];
+        Value mat = typeConverter->materializeSourceConversion(
+            rewriter, op.getLoc(), origType, mappedValue);
+        if (!mat) {
+          return rewriter.notifyMatchFailure(
+              op, "failed to materialize 1:N type conversion");
+        }
+        packedRets.push_back(mat);
+      } else {
+        // 1 : 1 type conversion.
+        packedRets.push_back(mappedValue.front());
+      }
+    }
+
+    rewriter.replaceOp(op, packedRets);
+    return success();
+  }
+};
+
+class ConvertForOpTypes
+    : public Structural1ToNConversionPattern<ForOp, ConvertForOpTypes> {
+public:
+  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
+
+  // The callback required by CRTP.
+  Optional<ForOp> convertSourceOp(ForOp op, OpAdaptor adaptor,
+                                  ConversionPatternRewriter &rewriter,
+                                  TypeRange dstTypes) const {
     // Create a empty new op and inline the regions from the old op.
     //
     // This is a little bit tricky. We have two concerns here:
@@ -67,15 +127,15 @@
 
     // convertRegionTypes already takes care of 1:N conversion.
     if (failed(rewriter.convertRegionTypes(&op.getLoopBody(), *typeConverter)))
-      return failure();
+      return llvm::None;
 
     // Unpacked the iteration arguments.
     SmallVector<Value> flatArgs;
     for (Value arg : adaptor.getInitArgs())
       unpackUnrealizedConversionCast(arg, flatArgs);
 
-    // We can not do clone as the number of result types after conversion might
-    // be different.
+    // We cannot clone the op, as the number of result types after conversion
+    // might be different.
     ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(),
                                          adaptor.getUpperBound(),
                                          adaptor.getStep(), flatArgs);
@@ -89,29 +149,7 @@
     rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
                                 newOp.getLoopBody().end());
 
-    // Pack the return value.
-    SmallVector<Value> packedRets;
-    for (unsigned i = 1, e = offsets.size(); i < e; i++) {
-      unsigned start = offsets[i - 1], end = offsets[i];
-      unsigned len = end - start;
-      ValueRange mappedValue = newOp.getResults().slice(start, len);
-      if (len != 1) {
-        // 1 : N type conversion.
-        Type origType = op.getResultTypes()[i - 1];
-        Value mat = typeConverter->materializeSourceConversion(
-            rewriter, op.getLoc(), origType, mappedValue);
-        if (!mat)
-          return rewriter.notifyMatchFailure(
-              op, "Failed to materialize 1:N type conversion");
-        packedRets.push_back(mat);
-      } else {
-        // 1 : 1 type conversion.
-        packedRets.push_back(mappedValue.front());
-      }
-    }
-
-    rewriter.replaceOp(op, packedRets);
-    return success();
+    return newOp;
   }
 };
 } // namespace