diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -262,14 +262,21 @@ ### Region Signature Conversion -From the perspective of type conversion, the entry block to a region is often -special. The types of the entry block arguments are often tied semantically to -details on the operation, e.g. FuncOp, AffineForOp, etc. Given this, the -conversion of the types for this block must be done explicitly via a conversion -pattern. To convert the signature of a region entry block, a custom hook on the -ConversionPatternRewriter must be invoked `applySignatureConversion`. A -signature conversion, `TypeConverter::SignatureConversion`, can be built -programmatically: +From the perspective of type conversion, the types of block arguments are a bit +special. Throughout the conversion process, blocks may move between regions of +different operations. Given this, the conversion of the types for blocks must be +done explicitly via a conversion pattern. To convert the types of block +arguments within a Region, a custom hook on the `ConversionPatternRewriter` must +be invoked: `convertRegionTypes`. This hook uses a provided type converter to +apply type conversions to all blocks within the region, and all blocks that move +into that region. This hook also takes an optional +`TypeConverter::SignatureConversion` parameter that applies a custom conversion +to the entry block of the region. The types of the entry block arguments are +often tied semantically to details on the operation, e.g. FuncOp, AffineForOp, +etc. To convert the signature of just the region entry block, and not any other +blocks within the region, the `applySignatureConversion` hook may be used +instead.
A signature conversion, `TypeConverter::SignatureConversion`, can be +built programmatically: ```c++ class SignatureConversion { @@ -293,5 +300,6 @@ }; ``` -The `TypeConverter` provides several default utilities for signature conversion: -`convertSignatureArg`/`convertBlockSignature`. +The `TypeConverter` provides several default utilities for signature conversion +and legality checking: +`convertSignatureArgs`/`convertBlockSignature`/`isLegal(Region *|Type)`. diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md --- a/mlir/docs/Tutorials/Toy/Ch-6.md +++ b/mlir/docs/Tutorials/Toy/Ch-6.md @@ -106,8 +106,7 @@ ```c++ mlir::ModuleOp module = getOperation(); - if (mlir::failed(mlir::applyFullConversion(module, target, patterns, - &typeConverter))) + if (mlir::failed(mlir::applyFullConversion(module, target, patterns))) signalPassFailure(); ``` diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -203,7 +203,7 @@ // We want to completely lower to LLVM, so we use a `FullConversion`. This // ensures that only legal operations will remain after the conversion. auto module = getOperation(); - if (failed(applyFullConversion(module, target, patterns, &typeConverter))) + if (failed(applyFullConversion(module, target, patterns))) signalPassFailure(); } diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -203,7 +203,7 @@ // We want to completely lower to LLVM, so we use a `FullConversion`. This // ensures that only legal operations will remain after the conversion. 
auto module = getOperation(); - if (failed(applyFullConversion(module, target, patterns, &typeConverter))) + if (failed(applyFullConversion(module, target, patterns))) signalPassFailure(); } diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -99,7 +99,7 @@ /// pattern. Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context); - /// This contructor is used when a pattern may match against multiple + /// This constructor is used when a pattern may match against multiple /// different types of operations. The `benefit` is the expected benefit of /// matching this pattern. `MatchAnyOpTypeTag` is just a tag to ensure that /// the "match any" behavior is what the user actually desired, @@ -163,28 +163,27 @@ ArrayRef getGeneratedOps() const { return generatedOps; } protected: - /// Patterns must specify the root operation name they match against, and can - /// also specify the benefit of the pattern matching. + /// Construct a rewrite pattern with a certain benefit that matches the + /// operation with the given root name. RewritePattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context) : Pattern(rootName, benefit, context) {} - /// Patterns must specify the root operation name they match against, and can - /// also specify the benefit of the pattern matching. `MatchAnyOpTypeTag` - /// is just a tag to ensure that the "match any" behavior is what the user - /// actually desired, `MatchAnyOpTypeTag()` should always be supplied here. + /// Construct a rewrite pattern with a certain benefit that matches any + /// operation type. `MatchAnyOpTypeTag` is just a tag to ensure that the + /// "match any" behavior is what the user actually desired, + /// `MatchAnyOpTypeTag()` should always be supplied here. 
RewritePattern(PatternBenefit benefit, MatchAnyOpTypeTag tag) : Pattern(benefit, tag) {} - /// Patterns must specify the root operation name they match against, and can - /// also specify the benefit of the pattern matching. They can also specify - /// the names of operations that may be generated during a successful rewrite. + /// Construct a rewrite pattern with a certain benefit that matches the + /// operation with the given root name. `generatedNames` contains the names of + /// operations that may be generated during a successful rewrite. RewritePattern(StringRef rootName, ArrayRef generatedNames, PatternBenefit benefit, MLIRContext *context); - /// Patterns must specify the root operation name they match against, and can - /// also specify the benefit of the pattern matching. They can also specify - /// the names of operations that may be generated during a successful rewrite. - /// `MatchAnyOpTypeTag` is just a tag to ensure that the "match any" - /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should - /// always be supplied here. + /// Construct a rewrite pattern that may match any operation type. + /// `generatedNames` contains the names of operations that may be generated + /// during a successful rewrite. `MatchAnyOpTypeTag` is just a tag to ensure + /// that the "match any" behavior is what the user actually desired, + /// `MatchAnyOpTypeTag()` should always be supplied here. RewritePattern(ArrayRef generatedNames, PatternBenefit benefit, MLIRContext *context, MatchAnyOpTypeTag tag); diff --git a/mlir/include/mlir/Support/LogicalResult.h b/mlir/include/mlir/Support/LogicalResult.h --- a/mlir/include/mlir/Support/LogicalResult.h +++ b/mlir/include/mlir/Support/LogicalResult.h @@ -10,11 +10,12 @@ #define MLIR_SUPPORT_LOGICAL_RESULT_H #include "mlir/Support/LLVM.h" +#include "llvm/ADT/Optional.h" namespace mlir { -// Values that can be used to signal success/failure. 
This should be used in -// conjunction with the utility functions below. +/// Values that can be used to signal success/failure. This should be used in +/// conjunction with the utility functions below. struct LogicalResult { enum ResultEnum { Success, Failure } value; LogicalResult(ResultEnum v) : value(v) {} @@ -46,6 +47,28 @@ return result.value == LogicalResult::Failure; } +/// This class provides support for representing a failure result, or a valid +/// value of type `T`. This allows for integrating with LogicalResult, while +/// also providing a value on the success path. +template class LLVM_NODISCARD FailureOr : public Optional { +public: + /// Allow constructing from a LogicalResult. The result *must* be a failure. + /// Success results should use a proper instance of type `T`. + FailureOr(LogicalResult result) { + assert(failed(result) && + "success should be constructed with an instance of 'T'"); + } + FailureOr() : FailureOr(failure()) {} + FailureOr(T &&y) : Optional(std::forward(y)) {} + + operator LogicalResult() const { return success(this->hasValue()); } + +private: + /// Hide the bool conversion as it easily creates confusion. + using Optional::operator bool; + using Optional::hasValue; +}; + } // namespace mlir #endif // MLIR_SUPPORT_LOGICAL_RESULT_H diff --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h --- a/mlir/include/mlir/Transforms/BufferPlacement.h +++ b/mlir/include/mlir/Transforms/BufferPlacement.h @@ -141,12 +141,14 @@ else newResultTypes.push_back(convertedType); } + if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter, + &conversion))) + return failure(); // Update the signature of the function. 
rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(), newResultTypes)); - rewriter.applySignatureConversion(&funcOp.getBody(), conversion); }); return success(); } diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -160,6 +160,9 @@ /// Return true if the given operation has legal operand and result types. bool isLegal(Operation *op); + /// Return true if the types of block arguments within the region are legal. + bool isLegal(Region *region); + /// Return true if the inputs and outputs of the given function type are /// legal. bool isSignatureLegal(FunctionType ty); @@ -268,16 +271,15 @@ // Conversion Patterns //===----------------------------------------------------------------------===// -/// Base class for the conversion patterns that require type changes. Specific -/// conversions must derive this class and implement least one `rewrite` method. -/// NOTE: These conversion patterns can only be used with the 'apply*' methods -/// below. +/// Base class for the conversion patterns. This pattern class enables type +/// conversions, and other uses specific to the conversion framework. As such, +/// patterns of this type can only be used with the 'apply*' methods below. class ConversionPattern : public RewritePattern { public: /// Hook for derived classes to implement rewriting. `op` is the (first) - /// operation matched by the pattern, `operands` is a list of rewritten values - /// that are passed to this operation, `rewriter` can be used to emit the new - /// operations. This function should not fail. If some specific cases of + /// operation matched by the pattern, `operands` is a list of the rewritten + /// operand values that are passed to `op`, `rewriter` can be used to emit the + /// new operations. This function should not fail. 
If some specific cases of /// the operation are not supported, these cases should not be matched. virtual void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { @@ -298,8 +300,32 @@ LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final; + /// Return the type converter held by this pattern, or nullptr if the pattern + /// does not require type conversion. + TypeConverter *getTypeConverter() const { return typeConverter; } + protected: + /// See `RewritePattern::RewritePattern` for information on the other + /// available constructors. using RewritePattern::RewritePattern; + /// Construct a conversion pattern that matches an operation with the given + /// root name. This constructor allows for providing a type converter to use + /// within the pattern. + ConversionPattern(StringRef rootName, PatternBenefit benefit, + TypeConverter &typeConverter, MLIRContext *ctx) + : RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {} + /// Construct a conversion pattern that matches any operation type. This + /// constructor allows for providing a type converter to use within the + /// pattern. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any" + /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should + /// always be supplied here. + ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter, + MatchAnyOpTypeTag tag) + : RewritePattern(benefit, tag), typeConverter(&typeConverter) {} + +protected: + /// An optional type converter for use by this pattern. 
+ TypeConverter *typeConverter; private: using RewritePattern::rewrite; @@ -312,6 +338,10 @@ struct OpConversionPattern : public ConversionPattern { OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) : ConversionPattern(SourceOp::getOperationName(), benefit, context) {} + OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context, + PatternBenefit benefit = 1) + : ConversionPattern(SourceOp::getOperationName(), benefit, typeConverter, + context) {} /// Wrappers around the ConversionPattern methods that pass the derived op /// type. @@ -367,7 +397,7 @@ /// hooks. class ConversionPatternRewriter final : public PatternRewriter { public: - ConversionPatternRewriter(MLIRContext *ctx, TypeConverter *converter); + ConversionPatternRewriter(MLIRContext *ctx); ~ConversionPatternRewriter() override; /// Apply a signature conversion to the entry block of the given region. This @@ -377,6 +407,15 @@ applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion); + /// Convert the types of block arguments within the given region. This + /// replaces each block with a new block containing the updated signature. The + /// entry block may have a special conversion if `entryConversion` is + /// provided. On success, the new entry block to the region is returned for + /// convenience. Otherwise, failure is returned. + FailureOr convertRegionTypes( + Region *region, TypeConverter &converter, + TypeConverter::SignatureConversion *entryConversion = nullptr); + /// Replace all the uses of the block argument `from` with value `to`. void replaceUsesOfBlockArgument(BlockArgument from, Value to); @@ -721,36 +760,30 @@ /// Apply a partial conversion on the given operations and all nested /// operations. This method converts as many operations to the target as /// possible, ignoring operations that failed to legalize. This method only -/// returns failure if there ops explicitly marked as illegal. 
If `converter` is -/// provided, the signatures of blocks and regions are also converted. -/// If an `unconvertedOps` set is provided, all operations that are found not -/// to be legalizable to the given `target` are placed within that set. (Note -/// that if there is an op explicitly marked as illegal, the conversion -/// terminates and the `unconvertedOps` set will not necessarily be complete.) +/// returns failure if there are ops explicitly marked as illegal. If an +/// `unconvertedOps` set is provided, all operations that are found not to be +/// legalizable to the given `target` are placed within that set. (Note that if +/// there is an op explicitly marked as illegal, the conversion terminates and +/// the `unconvertedOps` set will not necessarily be complete.) LLVM_NODISCARD LogicalResult applyPartialConversion(ArrayRef ops, ConversionTarget &target, const OwningRewritePatternList &patterns, - TypeConverter *converter = nullptr, DenseSet *unconvertedOps = nullptr); LLVM_NODISCARD LogicalResult applyPartialConversion(Operation *op, ConversionTarget &target, const OwningRewritePatternList &patterns, - TypeConverter *converter = nullptr, DenseSet *unconvertedOps = nullptr); /// Apply a complete conversion on the given operations, and all nested /// operations. This method returns failure if the conversion of any operation /// fails, or if there are unreachable blocks in any of the regions nested -/// within 'ops'. If 'converter' is provided, the signatures of blocks and -/// regions are also converted. +/// within 'ops'.
LLVM_NODISCARD LogicalResult applyFullConversion(ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns, - TypeConverter *converter = nullptr); + const OwningRewritePatternList &patterns); LLVM_NODISCARD LogicalResult applyFullConversion(Operation *op, ConversionTarget &target, - const OwningRewritePatternList &patterns, - TypeConverter *converter = nullptr); + const OwningRewritePatternList &patterns); /// Apply an analysis conversion on the given operations, and all nested /// operations. This method analyzes which operations would be successfully @@ -759,17 +792,15 @@ /// provided 'convertedOps' set; note that no actual rewrites are applied to the /// operations on success and only pre-existing operations are added to the set. /// This method only returns failure if there are unreachable blocks in any of -/// the regions nested within 'ops', or if a type conversion failed. If -/// 'converter' is provided, the signatures of blocks and regions are also -/// considered for conversion. -LLVM_NODISCARD LogicalResult applyAnalysisConversion( - ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns, - DenseSet &convertedOps, TypeConverter *converter = nullptr); -LLVM_NODISCARD LogicalResult applyAnalysisConversion( - Operation *op, ConversionTarget &target, - const OwningRewritePatternList &patterns, - DenseSet &convertedOps, TypeConverter *converter = nullptr); +/// the regions nested within 'ops'. 
+LLVM_NODISCARD LogicalResult +applyAnalysisConversion(ArrayRef ops, ConversionTarget &target, + const OwningRewritePatternList &patterns, + DenseSet &convertedOps); +LLVM_NODISCARD LogicalResult +applyAnalysisConversion(Operation *op, ConversionTarget &target, + const OwningRewritePatternList &patterns, + DenseSet &convertedOps); } // end namespace mlir #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_ diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp --- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp +++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp @@ -179,10 +179,7 @@ target.addLegalDialect(); target.addLegalDialect(); target.addIllegalDialect(); - target.addDynamicallyLegalOp( - [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); - if (failed(applyPartialConversion(getOperation(), target, patterns, - &converter))) { + if (failed(applyPartialConversion(getOperation(), target, patterns))) { signalPassFailure(); } } diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h @@ -145,8 +145,9 @@ // Move the region to the new function, update the entry block signature. 
rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(), llvmFuncOp.end()); - rewriter.applySignatureConversion(&llvmFuncOp.getBody(), - signatureConversion); + if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), typeConverter, + &signatureConversion))) + return failure(); rewriter.eraseOp(gpuFuncOp); return success(); diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -133,7 +133,7 @@ target.addLegalDialect(); // TODO(csigg): Remove once we support replacing non-root ops. target.addLegalOp(); - if (failed(applyPartialConversion(m, target, patterns, &converter))) + if (failed(applyPartialConversion(m, target, patterns))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -67,7 +67,7 @@ target.addLegalDialect(); // TODO(whchung): Remove once we support replacing non-root ops. 
target.addLegalOp(); - if (failed(applyPartialConversion(m, target, patterns, &converter))) + if (failed(applyPartialConversion(m, target, patterns))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp @@ -164,8 +164,11 @@ TypeConverter::SignatureConversion signatureConverter( body->getNumArguments()); signatureConverter.remapInput(0, newIndVar); - body = rewriter.applySignatureConversion(&forOp.getLoopBody(), - signatureConverter); + FailureOr newBody = rewriter.convertRegionTypes( + &forOp.getLoopBody(), typeConverter, &signatureConverter); + if (failed(newBody)) + return failure(); + body = *newBody; // Delete the loop terminator. rewriter.eraseOp(body->getTerminator()); @@ -356,9 +359,12 @@ continue; newFuncOp.setAttr(namedAttr.first, namedAttr.second); } + rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); - rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter); + if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, + &signatureConverter))) + return nullptr; rewriter.eraseOp(funcOp); spirv::setABIAttrs(newFuncOp, entryPointInfo, argABIInfo); diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp @@ -61,10 +61,8 @@ populateGPUToSPIRVPatterns(context, typeConverter, patterns); populateStandardToSPIRVPatterns(context, typeConverter, patterns); - if (failed(applyFullConversion(kernelModules, *target, patterns, - &typeConverter))) { + if (failed(applyFullConversion(kernelModules, *target, patterns))) return signalPassFailure(); - } } std::unique_ptr> 
mlir::createConvertGPUToSPIRVPass() { diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -383,10 +383,8 @@ populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext()); LLVMConversionTarget target(getContext()); - target.addDynamicallyLegalOp( - [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); target.addLegalOp(); - if (failed(applyFullConversion(module, target, patterns, &converter))) + if (failed(applyFullConversion(module, target, patterns))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp @@ -36,8 +36,10 @@ // Allow builtin ops. target->addLegalOp(); - target->addDynamicallyLegalOp( - [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); }); + target->addDynamicallyLegalOp([&](FuncOp op) { + return typeConverter.isSignatureLegal(op.getType()) && + typeConverter.isLegal(&op.getBody()); + }); if (failed(applyFullConversion(module, *target, patterns))) return signalPassFailure(); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp @@ -44,8 +44,7 @@ ConversionTarget target(getContext()); target.addIllegalDialect(); target.addLegalDialect(); - - if (failed(applyPartialConversion(module, target, patterns, &converter))) + if (failed(applyPartialConversion(module, target, patterns))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp 
b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -82,7 +82,8 @@ target.addLegalDialect(); target.addLegalOp(); target.addDynamicallyLegalOp([&](FuncOp op) { - return typeConverter.isSignatureLegal(op.getType()); + return typeConverter.isSignatureLegal(op.getType()) && + typeConverter.isLegal(&op.getBody()); }); // Setup conversion patterns. @@ -92,7 +93,7 @@ // Apply conversion. auto module = getOperation(); - if (failed(applyFullConversion(module, target, patterns, &typeConverter))) + if (failed(applyFullConversion(module, target, patterns))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -398,7 +398,7 @@ MLIRContext *context, LLVMTypeConverter &typeConverter_, PatternBenefit benefit) - : ConversionPattern(rootOpName, benefit, context), + : ConversionPattern(rootOpName, benefit, typeConverter_, context), typeConverter(typeConverter_) {} /*============================================================================*/ @@ -1038,8 +1038,9 @@ attributes); rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); - // Tell the rewriter to convert the region signature. 
- rewriter.applySignatureConversion(&newFuncOp.getBody(), result); + if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, + &result))) + return nullptr; return newFuncOp; } @@ -1059,6 +1060,9 @@ auto funcOp = cast(op); auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); + if (!newFuncOp) + return failure(); + if (emitWrappers || funcOp.getAttrOfType(kEmitIfaceAttrName)) { if (newFuncOp.isExternal()) wrapExternalFunction(rewriter, op->getLoc(), typeConverter, funcOp, @@ -1095,6 +1099,8 @@ getMemRefArgIndicesAndTypes(funcOp.getType(), promotedArgsInfo); auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); + if (!newFuncOp) + return failure(); if (newFuncOp.getBody().empty()) { rewriter.eraseOp(op); return success(); @@ -3172,7 +3178,7 @@ emitCWrappers, useAlignedAlloc); LLVMConversionTarget target(getContext()); - if (failed(applyPartialConversion(m, target, patterns, &typeConverter))) + if (failed(applyPartialConversion(m, target, patterns))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1195,10 +1195,7 @@ populateStdToLLVMConversionPatterns(converter, patterns); LLVMConversionTarget target(getContext()); - target.addDynamicallyLegalOp( - [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); - if (failed(applyPartialConversion(getOperation(), target, patterns, - &converter))) { + if (failed(applyPartialConversion(getOperation(), target, patterns))) { signalPassFailure(); } } diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp --- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -173,10 +173,8 @@ LLVMConversionTarget target(getContext()); 
target.addLegalDialect(); - if (failed(applyPartialConversion(getOperation(), target, patterns, - &converter))) { + if (failed(applyPartialConversion(getOperation(), target, patterns))) signalPassFailure(); - } } std::unique_ptr> diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp @@ -136,19 +136,19 @@ target.addDynamicallyLegalOp([&](FuncOp funcOp) { return converter.isSignatureLegal(funcOp.getType()) && llvm::none_of(funcOp.getType().getResults(), - [&](Type type) { return type.isa(); }); + [&](Type type) { return type.isa(); }) && + converter.isLegal(&funcOp.getBody()); }); // Walk over all the functions to apply buffer assignment. - getOperation().walk([&](FuncOp function) { + getOperation().walk([&](FuncOp function) -> WalkResult { OwningRewritePatternList patterns; BufferAssignmentPlacer placer(function); populateConvertLinalgOnTensorsToBuffersPattern(&context, &placer, &converter, &patterns); // Applying full conversion - return WalkResult( - applyFullConversion(function, target, patterns, &converter)); + return applyFullConversion(function, target, patterns); }); } }; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -489,7 +489,9 @@ rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); - rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter); + if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, + &signatureConverter))) + return failure(); rewriter.eraseOp(funcOp); return success(); } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp --- 
a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -201,12 +201,14 @@ } signatureConverter.remapInput(argType.index(), replacement); } + if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), typeConverter, + &signatureConverter))) + return failure(); // Creates a new function with the update signature. rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(rewriter.getFunctionType( signatureConverter.getConvertedTypes(), llvm::None)); - rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter); }); return success(); } @@ -237,10 +239,8 @@ return op->getDialect()->getNamespace() == spirv::SPIRVDialect::getDialectNamespace(); }); - if (failed( - applyPartialConversion(module, target, patterns, &typeConverter))) { + if (failed(applyPartialConversion(module, target, patterns))) return signalPassFailure(); - } // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point // attributes. diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -98,7 +98,7 @@ } //===----------------------------------------------------------------------===// -// Multi-Level Value Mapper +// ConversionValueMapping //===----------------------------------------------------------------------===// namespace { @@ -140,9 +140,7 @@ /// types and extracting the block that contains the old illegal types to allow /// for undoing pending rewrites in the case of failure. struct ArgConverter { - ArgConverter(TypeConverter *typeConverter, PatternRewriter &rewriter) - : loc(rewriter.getUnknownLoc()), typeConverter(typeConverter), - rewriter(rewriter) {} + ArgConverter(PatternRewriter &rewriter) : rewriter(rewriter) {} /// This structure contains the information pertaining to an argument that has /// been converted. 
@@ -166,7 +164,8 @@ /// This structure contains information pertaining to a block that has had its /// signature converted. struct ConvertedBlockInfo { - ConvertedBlockInfo(Block *origBlock) : origBlock(origBlock) {} + ConvertedBlockInfo(Block *origBlock, TypeConverter &converter) + : origBlock(origBlock), converter(&converter) {} /// The original block that was requested to have its signature converted. Block *origBlock; @@ -174,11 +173,26 @@ /// The conversion information for each of the arguments. The information is /// None if the argument was dropped during conversion. SmallVector, 1> argInfo; + + /// The type converter used to convert the arguments. + TypeConverter *converter; }; /// Return if the signature of the given block has already been converted. bool hasBeenConverted(Block *block) const { - return conversionInfo.count(block); + return conversionInfo.count(block) || convertedBlocks.count(block); + } + + /// Set the type converter to use for the given region. + void setConverter(Region *region, TypeConverter *typeConverter) { + assert(typeConverter && "expected valid type converter"); + regionToConverter[region] = typeConverter; + } + + /// Return the type converter to use for the given region, or null if there + /// isn't one. + TypeConverter *getConverter(Region *region) { + return regionToConverter.lookup(region); } //===--------------------------------------------------------------------===// @@ -204,32 +218,39 @@ //===--------------------------------------------------------------------===// /// Attempt to convert the signature of the given block, if successful a new - /// block is returned containing the new arguments. On failure, nullptr is - /// returned. - Block *convertSignature(Block *block, ConversionValueMapping &mapping); + /// block is returned containing the new arguments. Returns `block` if it did + /// not require conversion. 
+ FailureOr convertSignature(Block *block, TypeConverter &converter, + ConversionValueMapping &mapping); /// Apply the given signature conversion on the given block. The new block - /// containing the updated signature is returned. + /// containing the updated signature is returned. If no conversions were + /// necessary, e.g. if the block has no arguments, `block` is returned. + /// `converter` is used to generate any necessary cast operations that + /// translate between the origin argument types and those specified in the + /// signature conversion. Block *applySignatureConversion( - Block *block, TypeConverter::SignatureConversion &signatureConversion, + Block *block, TypeConverter &converter, + TypeConverter::SignatureConversion &signatureConversion, ConversionValueMapping &mapping); /// Insert a new conversion into the cache. void insertConversion(Block *newBlock, ConvertedBlockInfo &&info); - /// A collection of blocks that have had their arguments converted. + /// A collection of blocks that have had their arguments converted. This is a + /// map from the new replacement block, back to the original block. llvm::MapVector conversionInfo; + /// The set of original blocks that were converted. + DenseSet convertedBlocks; + /// A mapping from valid regions, to those containing the original blocks of a /// conversion. DenseMap> regionMapping; - /// An instance of the unknown location that is used when materializing - /// conversions. - Location loc; - - /// The type converter to use when changing types. - TypeConverter *typeConverter; + /// A mapping of regions to type converters that should be used when + /// converting the arguments of blocks within that region. + DenseMap regionToConverter; /// The pattern rewriter to use when materializing conversions. 
PatternRewriter &rewriter; @@ -240,6 +261,9 @@ // Rewrite Application void ArgConverter::notifyOpRemoved(Operation *op) { + if (conversionInfo.empty()) + return; + for (Region ®ion : op->getRegions()) { for (Block &block : region) { // Drop any rewrites from within. @@ -277,6 +301,7 @@ origBlock->moveBefore(block); block->erase(); + convertedBlocks.erase(origBlock); conversionInfo.erase(it); } @@ -305,8 +330,8 @@ // persist in the IR after conversion. if (!origArg.use_empty()) { rewriter.setInsertionPointToStart(newBlock); - Value newArg = typeConverter->materializeConversion( - rewriter, loc, origArg.getType(), llvm::None); + Value newArg = blockInfo.converter->materializeConversion( + rewriter, origArg.getLoc(), origArg.getType(), llvm::None); assert(newArg && "Couldn't materialize a block argument after 1->0 conversion"); origArg.replaceAllUsesWith(newArg); @@ -333,15 +358,23 @@ //===----------------------------------------------------------------------===// // Conversion -Block *ArgConverter::convertSignature(Block *block, - ConversionValueMapping &mapping) { - if (auto conversion = typeConverter->convertBlockSignature(block)) - return applySignatureConversion(block, *conversion, mapping); - return nullptr; +FailureOr +ArgConverter::convertSignature(Block *block, TypeConverter &converter, + ConversionValueMapping &mapping) { + // Check if the block was already converted. If the block is detached, + // conservatively assume it is going to be deleted. + if (hasBeenConverted(block) || !block->getParent()) + return block; + + // Try to convert the signature for the block with the provided converter. 
+ if (auto conversion = converter.convertBlockSignature(block)) + return applySignatureConversion(block, converter, *conversion, mapping); + return failure(); } Block *ArgConverter::applySignatureConversion( - Block *block, TypeConverter::SignatureConversion &signatureConversion, + Block *block, TypeConverter &converter, + TypeConverter::SignatureConversion &signatureConversion, ConversionValueMapping &mapping) { // If no arguments are being changed or added, there is nothing to do. unsigned origArgCount = block->getNumArguments(); @@ -359,7 +392,7 @@ // Remap each of the original arguments as determined by the signature // conversion. - ConvertedBlockInfo info(block); + ConvertedBlockInfo info(block, converter); info.argInfo.resize(origArgCount); OpBuilder::InsertionGuard guard(rewriter); @@ -384,10 +417,8 @@ // to pack the new values. For 1->1 mappings, if there is no materialization // provided, use the argument directly instead. auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size); - Value newArg; - if (typeConverter) - newArg = typeConverter->materializeConversion( - rewriter, loc, origArg.getType(), replArgs); + Value newArg = converter.materializeConversion(rewriter, origArg.getLoc(), + origArg.getType(), replArgs); if (!newArg) { assert(replArgs.size() == 1 && "couldn't materialize the result of 1->N conversion"); @@ -414,6 +445,7 @@ // Move the original block to the mapped region and emplace the conversion. mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(), info.origBlock->getIterator()); + convertedBlocks.insert(info.origBlock); conversionInfo.insert({newBlock, std::move(info)}); } @@ -548,9 +580,8 @@ }; }; - ConversionPatternRewriterImpl(PatternRewriter &rewriter, - TypeConverter *converter) - : argConverter(converter, rewriter) {} + ConversionPatternRewriterImpl(PatternRewriter &rewriter) + : argConverter(rewriter) {} /// Return the current state of the rewriter. 
RewriterState getCurrentState(); @@ -575,13 +606,20 @@ void applyRewrites(); /// Convert the signature of the given block. - LogicalResult convertBlockSignature(Block *block); + FailureOr convertBlockSignature( + Block *block, TypeConverter &converter, + TypeConverter::SignatureConversion *conversion = nullptr); /// Apply a signature conversion on the given region. Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion); + /// Convert the types of block arguments within the given region. + FailureOr + convertRegionTypes(Region *region, TypeConverter &converter, + TypeConverter::SignatureConversion *entryConversion); + /// PatternRewriter hook for replacing the results of an operation. void replaceOp(Operation *op, ValueRange newValues); @@ -654,6 +692,10 @@ /// A logger used to emit diagnostics during the conversion process. llvm::ScopedPrinter logger{llvm::dbgs()}; #endif + + /// A default type converter, used when block conversions do not have one + /// explicitly provided. + TypeConverter defaultTypeConverter; }; } // end namespace detail } // end namespace mlir @@ -791,7 +833,7 @@ // If this operation defines any regions, drop any pending argument // rewrites. - if (argConverter.typeConverter && repl.op->getNumRegions()) + if (repl.op->getNumRegions()) argConverter.notifyOpRemoved(repl.op); } @@ -826,34 +868,45 @@ eraseDanglingBlocks(); } -LogicalResult -ConversionPatternRewriterImpl::convertBlockSignature(Block *block) { - // Check to see if this block should not be converted: - // * There is no type converter. - // * The block has already been converted. - // * This is an entry block, these are converted explicitly via patterns. - if (!argConverter.typeConverter || argConverter.hasBeenConverted(block) || - !block->getParent() || block->isEntryBlock()) - return success(); - - // Otherwise, try to convert the block signature. 
- Block *newBlock = argConverter.convertSignature(block, mapping); - if (newBlock) - blockActions.push_back(BlockAction::getTypeConversion(newBlock)); - return success(newBlock); +FailureOr ConversionPatternRewriterImpl::convertBlockSignature( + Block *block, TypeConverter &converter, + TypeConverter::SignatureConversion *conversion) { + FailureOr result = + conversion ? argConverter.applySignatureConversion(block, converter, + *conversion, mapping) + : argConverter.convertSignature(block, converter, mapping); + if (Block *newBlock = result.getValue()) { + if (newBlock != block) + blockActions.push_back(BlockAction::getTypeConversion(newBlock)); + } + return result; } Block *ConversionPatternRewriterImpl::applySignatureConversion( Region *region, TypeConverter::SignatureConversion &conversion) { if (!region->empty()) { - Block *newEntry = argConverter.applySignatureConversion( - ®ion->front(), conversion, mapping); - blockActions.push_back(BlockAction::getTypeConversion(newEntry)); - return newEntry; + return *convertBlockSignature(®ion->front(), defaultTypeConverter, + &conversion); } return nullptr; } +FailureOr ConversionPatternRewriterImpl::convertRegionTypes( + Region *region, TypeConverter &converter, + TypeConverter::SignatureConversion *entryConversion) { + argConverter.setConverter(region, &converter); + if (region->empty()) + return nullptr; + + // Convert the arguments of each block within the region. 
+ FailureOr newEntry = + convertBlockSignature(®ion->front(), converter, entryConversion); + for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) + if (failed(convertBlockSignature(&block, converter))) + return failure(); + return newEntry; +} + void ConversionPatternRewriterImpl::replaceOp(Operation *op, ValueRange newValues) { assert(newValues.size() == op->getNumResults()); @@ -938,10 +991,9 @@ // ConversionPatternRewriter //===----------------------------------------------------------------------===// -ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx, - TypeConverter *converter) +ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx) : PatternRewriter(ctx), - impl(new detail::ConversionPatternRewriterImpl(*this, converter)) {} + impl(new detail::ConversionPatternRewriterImpl(*this)) {} ConversionPatternRewriter::~ConversionPatternRewriter() {} /// PatternRewriter hook for replacing the results of an operation. @@ -979,12 +1031,17 @@ block->getParent()->getBlocks().remove(block); } -/// Apply a signature conversion to the entry block of the given region. Block *ConversionPatternRewriter::applySignatureConversion( Region *region, TypeConverter::SignatureConversion &conversion) { return impl->applySignatureConversion(region, conversion); } +FailureOr ConversionPatternRewriter::convertRegionTypes( + Region *region, TypeConverter &converter, + TypeConverter::SignatureConversion *entryConversion) { + return impl->convertRegionTypes(region, converter, entryConversion); +} + void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, Value to) { LLVM_DEBUG({ @@ -1163,6 +1220,20 @@ ConversionPatternRewriter &rewriter, RewriterState &curState); + /// Legalizes the actions registered during the execution of a pattern. 
+ LogicalResult legalizePatternBlockActions(Operation *op, + ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &impl, + RewriterState &state, + RewriterState &newState); + LogicalResult legalizePatternCreatedOperations( + ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, + RewriterState &state, RewriterState &newState); + LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &impl, + RewriterState &state, + RewriterState &newState); + /// Build an optimistic legalization graph given the provided patterns. This /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with /// patterns for operations that are not directly legal, but may be @@ -1402,50 +1473,29 @@ LogicalResult OperationLegalizer::legalizePatternResult( Operation *op, const RewritePattern &pattern, ConversionPatternRewriter &rewriter, RewriterState &curState) { - auto &rewriterImpl = rewriter.getImpl(); + auto &impl = rewriter.getImpl(); #ifndef NDEBUG - assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); + assert(impl.pendingRootUpdates.empty() && "dangling root updates"); #endif - // If the pattern moved or created any blocks, try to legalize their types. - // This ensures that the types of the block arguments are legal for the region - // they were moved into. - for (unsigned i = curState.numBlockActions, - e = rewriterImpl.blockActions.size(); - i != e; ++i) { - auto &action = rewriterImpl.blockActions[i]; - if (action.kind == - ConversionPatternRewriterImpl::BlockActionKind::TypeConversion || - action.kind == ConversionPatternRewriterImpl::BlockActionKind::Erase) - continue; - - // Convert the block signature. 
- if (failed(rewriterImpl.convertBlockSignature(action.block))) { - LLVM_DEBUG(logFailure(rewriterImpl.logger, - "failed to convert types of moved block")); - return failure(); - } - } - // Check all of the replacements to ensure that the pattern actually replaced // the root operation. We also mark any other replaced ops as 'dead' so that // we don't try to legalize them later. bool replacedRoot = false; - for (unsigned i = curState.numReplacements, - e = rewriterImpl.replacements.size(); + for (unsigned i = curState.numReplacements, e = impl.replacements.size(); i != e; ++i) { - Operation *replacedOp = rewriterImpl.replacements[i].op; + Operation *replacedOp = impl.replacements[i].op; if (replacedOp == op) replacedRoot = true; else - rewriterImpl.ignoredOps.insert(replacedOp); + impl.ignoredOps.insert(replacedOp); } // Check that the root was either updated or replace. auto updatedRootInPlace = [&] { return llvm::any_of( - llvm::drop_begin(rewriterImpl.rootUpdates, curState.numRootUpdates), + llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates), [op](auto &state) { return state.getOperation() == op; }); }; (void)replacedRoot; @@ -1453,32 +1503,99 @@ assert((replacedRoot || updatedRootInPlace()) && "expected pattern to replace the root operation"); - // Recursively legalize each of the operations updated in place. - for (unsigned i = curState.numRootUpdates, - e = rewriterImpl.rootUpdates.size(); - i != e; ++i) { - auto &state = rewriterImpl.rootUpdates[i]; - if (failed(legalize(state.getOperation(), rewriter))) { - LLVM_DEBUG(logFailure(rewriterImpl.logger, - "operation updated in-place '{0}' was illegal", - op->getName())); + // Legalize each of the actions registered during application. 
+ RewriterState newState = impl.getCurrentState(); + if (failed(legalizePatternBlockActions(op, rewriter, impl, curState, + newState)) || + failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) || + failed(legalizePatternCreatedOperations(rewriter, impl, curState, + newState))) { + return failure(); + } + + LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully")); + return success(); +} + +LogicalResult OperationLegalizer::legalizePatternBlockActions( + Operation *op, ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &impl, RewriterState &state, + RewriterState &newState) { + SmallPtrSet operationsToIgnore; + + // If the pattern moved or created any blocks, make sure the types of block + // arguments get legalized. + for (int i = state.numBlockActions, e = newState.numBlockActions; i != e; + ++i) { + auto &action = impl.blockActions[i]; + if (action.kind == + ConversionPatternRewriterImpl::BlockActionKind::TypeConversion || + action.kind == ConversionPatternRewriterImpl::BlockActionKind::Erase) + continue; + // Only check blocks outside of the current operation. + Operation *parentOp = action.block->getParentOp(); + if (!parentOp || parentOp == op || action.block->getNumArguments() == 0) + continue; + + // If the region of the block has a type converter, try to convert the block + // directly. + if (auto *converter = + impl.argConverter.getConverter(action.block->getParent())) { + if (failed(impl.convertBlockSignature(action.block, *converter))) { + LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved " + "block")); + return failure(); + } + continue; + } + + // Otherwise, check that this operation isn't one generated by this pattern. + // This is because we will attempt to legalize the parent operation, and + // blocks in regions created by this pattern will already be legalized later + // on. If we haven't built the set yet, build it now. 
+ if (operationsToIgnore.empty()) { + auto createdOps = ArrayRef(impl.createdOps) + .drop_front(state.numCreatedOps); + operationsToIgnore.insert(createdOps.begin(), createdOps.end()); + } + + // If this operation should be considered for re-legalization, try it. + if (operationsToIgnore.insert(parentOp).second && + failed(legalize(parentOp, rewriter))) { + LLVM_DEBUG(logFailure( + impl.logger, "operation '{0}'({1}) became illegal after block action", + parentOp->getName(), parentOp)); return failure(); } } - - // Recursively legalize each of the new operations. - for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size(); - i != e; ++i) { - Operation *op = rewriterImpl.createdOps[i]; + return success(); +} +LogicalResult OperationLegalizer::legalizePatternCreatedOperations( + ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, + RewriterState &state, RewriterState &newState) { + for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) { + Operation *op = impl.createdOps[i]; if (failed(legalize(op, rewriter))) { - LLVM_DEBUG(logFailure(rewriterImpl.logger, + LLVM_DEBUG(logFailure(impl.logger, "generated operation '{0}'({1}) was illegal", op->getName(), op)); return failure(); } } - - LLVM_DEBUG(logSuccess(rewriterImpl.logger, "pattern applied successfully")); + return success(); +} +LogicalResult OperationLegalizer::legalizePatternRootUpdates( + ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, + RewriterState &state, RewriterState &newState) { + for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) { + Operation *op = impl.rootUpdates[i].getOperation(); + if (failed(legalize(op, rewriter))) { + LLVM_DEBUG(logFailure(impl.logger, + "operation updated in-place '{0}' was illegal", + op->getName())); + return failure(); + } + } return success(); } @@ -1699,17 +1816,12 @@ : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {} /// Converts the given 
operations to the conversion target. - LogicalResult convertOperations(ArrayRef ops, - TypeConverter *typeConverter); + LogicalResult convertOperations(ArrayRef ops); private: /// Converts an operation with the given rewriter. LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op); - /// Converts the type signatures of the blocks nested within 'op'. - LogicalResult convertBlockSignatures(ConversionPatternRewriter &rewriter, - Operation *op); - /// The legalizer to use when converting operations. OperationLegalizer opLegalizer; @@ -1724,21 +1836,6 @@ }; } // end anonymous namespace -LogicalResult -OperationConverter::convertBlockSignatures(ConversionPatternRewriter &rewriter, - Operation *op) { - // Check to see if type signatures need to be converted. - if (!rewriter.getImpl().argConverter.typeConverter) - return success(); - - for (auto ®ion : op->getRegions()) { - for (auto &block : llvm::make_early_inc_range(region)) - if (failed(rewriter.getImpl().convertBlockSignature(&block))) - return failure(); - } - return success(); -} - LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, Operation *op) { // Legalize the given operation. @@ -1759,24 +1856,16 @@ if (trackedOps) trackedOps->insert(op); } - } else { + } else if (mode == OpConversionMode::Analysis) { // Analysis conversions don't fail if any operations fail to legalize, // they are only interested in the operations that were successfully // legalized. - if (mode == OpConversionMode::Analysis) - trackedOps->insert(op); - - // If legalization succeeded, convert the types any of the blocks within - // this operation. 
- if (failed(convertBlockSignatures(rewriter, op))) - return failure(); + trackedOps->insert(op); } return success(); } -LogicalResult -OperationConverter::convertOperations(ArrayRef ops, - TypeConverter *typeConverter) { +LogicalResult OperationConverter::convertOperations(ArrayRef ops) { if (ops.empty()) return success(); ConversionTarget &target = opLegalizer.getTarget(); @@ -1792,7 +1881,7 @@ } // Convert each operation and discard rewrites on failure. - ConversionPatternRewriter rewriter(ops.front()->getContext(), typeConverter); + ConversionPatternRewriter rewriter(ops.front()->getContext()); for (auto *op : toConvert) if (failed(convert(rewriter, op))) return rewriter.getImpl().discardRewrites(), failure(); @@ -1913,6 +2002,13 @@ return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes()); } +/// Return true if the types of block arguments within the region are legal. +bool TypeConverter::isLegal(Region *region) { + return llvm::all_of(*region, [this](Block &block) { + return isLegal(block.getArgumentTypes()); + }); +} + /// Return true if the inputs and outputs of the given function type are /// legal. bool TypeConverter::isSignatureLegal(FunctionType ty) { @@ -1969,7 +2065,7 @@ namespace { struct FuncOpSignatureConversion : public OpConversionPattern { FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) - : OpConversionPattern(ctx), converter(converter) {} + : OpConversionPattern(converter, ctx) {} /// Hook for derived classes to implement combined matching and rewriting. LogicalResult @@ -1979,22 +2075,20 @@ // Convert the original function types. 
TypeConverter::SignatureConversion result(type.getNumInputs()); - SmallVector convertedResults; - if (failed(converter.convertSignatureArgs(type.getInputs(), result)) || - failed(converter.convertTypes(type.getResults(), convertedResults))) + SmallVector newResults; + if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) || + failed(typeConverter->convertTypes(type.getResults(), newResults)) || + failed(rewriter.convertRegionTypes(&funcOp.getBody(), *typeConverter, + &result))) return failure(); // Update the function signature in-place. rewriter.updateRootInPlace(funcOp, [&] { - funcOp.setType(FunctionType::get(result.getConvertedTypes(), - convertedResults, funcOp.getContext())); - rewriter.applySignatureConversion(&funcOp.getBody(), result); + funcOp.setType(FunctionType::get(result.getConvertedTypes(), newResults, + funcOp.getContext())); }); return success(); } - - /// The type converter to use when rewriting the signature. - TypeConverter &converter; }; } // end anonymous namespace @@ -2128,27 +2222,26 @@ /// Apply a partial conversion on the given operations and all nested /// operations. This method converts as many operations to the target as /// possible, ignoring operations that failed to legalize. This method only -/// returns failure if there ops explicitly marked as illegal. If `converter` is -/// provided, the signatures of blocks and regions are also converted. +/// returns failure if there ops explicitly marked as illegal. /// If an `unconvertedOps` set is provided, all operations that are found not /// to be legalizable to the given `target` are placed within that set. (Note /// that if there is an op explicitly marked as illegal, the conversion /// terminates and the `unconvertedOps` set will not necessarily be complete.) 
-LogicalResult mlir::applyPartialConversion( - ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns, TypeConverter *converter, - DenseSet *unconvertedOps) { +LogicalResult +mlir::applyPartialConversion(ArrayRef ops, + ConversionTarget &target, + const OwningRewritePatternList &patterns, + DenseSet *unconvertedOps) { OperationConverter opConverter(target, patterns, OpConversionMode::Partial, unconvertedOps); - return opConverter.convertOperations(ops, converter); + return opConverter.convertOperations(ops); } LogicalResult mlir::applyPartialConversion(Operation *op, ConversionTarget &target, const OwningRewritePatternList &patterns, - TypeConverter *converter, DenseSet *unconvertedOps) { return applyPartialConversion(llvm::makeArrayRef(op), target, patterns, - converter, unconvertedOps); + unconvertedOps); } /// Apply a complete conversion on the given operations, and all nested @@ -2156,17 +2249,14 @@ /// operation fails. LogicalResult mlir::applyFullConversion(ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns, - TypeConverter *converter) { + const OwningRewritePatternList &patterns) { OperationConverter opConverter(target, patterns, OpConversionMode::Full); - return opConverter.convertOperations(ops, converter); + return opConverter.convertOperations(ops); } LogicalResult mlir::applyFullConversion(Operation *op, ConversionTarget &target, - const OwningRewritePatternList &patterns, - TypeConverter *converter) { - return applyFullConversion(llvm::makeArrayRef(op), target, patterns, - converter); + const OwningRewritePatternList &patterns) { + return applyFullConversion(llvm::makeArrayRef(op), target, patterns); } /// Apply an analysis conversion on the given operations, and all nested @@ -2175,19 +2265,19 @@ /// were found to be legalizable to the given 'target' are placed within the /// provided 'convertedOps' set; note that no actual rewrites are applied to the /// operations on success and only 
pre-existing operations are added to the set. -LogicalResult mlir::applyAnalysisConversion( - ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns, - DenseSet &convertedOps, TypeConverter *converter) { +LogicalResult +mlir::applyAnalysisConversion(ArrayRef ops, + ConversionTarget &target, + const OwningRewritePatternList &patterns, + DenseSet &convertedOps) { OperationConverter opConverter(target, patterns, OpConversionMode::Analysis, &convertedOps); - return opConverter.convertOperations(ops, converter); + return opConverter.convertOperations(ops); } LogicalResult mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target, const OwningRewritePatternList &patterns, - DenseSet &convertedOps, - TypeConverter *converter) { + DenseSet &convertedOps) { return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns, - convertedOps, converter); + convertedOps); } diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -153,15 +153,11 @@ // CHECK-LABEL: @create_block func @create_block() { - // expected-remark@+1 {{op 'test.container' is not legalizable}} - "test.container"() ({ - // Check that we created a block with arguments. - // CHECK-NOT: test.create_block - // CHECK: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32): - // CHECK: test.finish - "test.create_block"() : () -> () - "test.finish"() : () -> () - }) : () -> () + // Check that we created a block with arguments. + // CHECK-NOT: test.create_block + // CHECK: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32): + "test.create_block"() : () -> () + // expected-remark@+1 {{op 'std.return' is not legalizable}} return } @@ -212,15 +208,12 @@ // CHECK-LABEL: @create_illegal_block func @create_illegal_block() { - // expected-remark@+1 {{op 'test.container' is not legalizable}} - "test.container"() ({ - // Check that we can undo block creation, i.e. 
that the block was removed. - // CHECK: test.create_illegal_block - // CHECK-NOT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32): - // expected-remark@+1 {{op 'test.create_illegal_block' is not legalizable}} - "test.create_illegal_block"() : () -> () - "test.finish"() : () -> () - }) : () -> () + // Check that we can undo block creation, i.e. that the block was removed. + // CHECK: test.create_illegal_block + // CHECK-NOT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32): + // expected-remark@+1 {{op 'test.create_illegal_block' is not legalizable}} + "test.create_illegal_block"() : () -> () + // expected-remark@+1 {{op 'std.return' is not legalizable}} return } diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -304,8 +304,7 @@ /// This patterns erases a region operation that has had a type conversion. struct TestDropOpSignatureConversion : public ConversionPattern { TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) - : ConversionPattern("test.drop_region_op", 1, ctx), converter(converter) { - } + : ConversionPattern("test.drop_region_op", 1, converter, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { @@ -313,19 +312,17 @@ Block *entry = ®ion.front(); // Convert the original entry arguments. + TypeConverter &converter = *getTypeConverter(); TypeConverter::SignatureConversion result(entry->getNumArguments()); - if (failed( - converter.convertSignatureArgs(entry->getArgumentTypes(), result))) + if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), + result)) || + failed(rewriter.convertRegionTypes(®ion, converter, &result))) return failure(); // Convert the region signature and just drop the operation. 
- rewriter.applySignatureConversion(®ion, result); rewriter.eraseOp(op); return success(); } - - /// The type converter to use when rewriting the signature. - TypeConverter &converter; }; /// This pattern simply updates the operands of the given operation. struct TestPassthroughInvalidOp : public ConversionPattern { @@ -568,8 +565,10 @@ return llvm::none_of(op.getOperandTypes(), [](Type type) { return type.isF32(); }); }); - target.addDynamicallyLegalOp( - [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); + target.addDynamicallyLegalOp([&](FuncOp op) { + return converter.isSignatureLegal(op.getType()) && + converter.isLegal(&op.getBody()); + }); // Expect the type_producer/type_consumer operations to only operate on f64. target.addDynamicallyLegalOp( @@ -591,7 +590,7 @@ // Handle a partial conversion. if (mode == ConversionMode::Partial) { DenseSet unlegalizedOps; - (void)applyPartialConversion(getOperation(), target, patterns, &converter, + (void)applyPartialConversion(getOperation(), target, patterns, &unlegalizedOps); // Emit remarks for each legalizable operation. for (auto *op : unlegalizedOps) @@ -606,7 +605,7 @@ return (bool)op->getAttrOfType("test.dynamically_legal"); }); - (void)applyFullConversion(getOperation(), target, patterns, &converter); + (void)applyFullConversion(getOperation(), target, patterns); return; } @@ -616,7 +615,7 @@ // Analyze the convertible operations. DenseSet legalizedOps; if (failed(applyAnalysisConversion(getOperation(), target, patterns, - legalizedOps, &converter))) + legalizedOps))) return signalPassFailure(); // Emit remarks for each legalizable operation. 
diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp --- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp +++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -1,4 +1,4 @@ -//===- TestBufferPlacement.cpp - Test for buffer placement 0----*- C++ -*-===// +//===- TestBufferPlacement.cpp - Test for buffer placement ------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -140,7 +140,8 @@ // Mark the function whose arguments are in tensor-type illegal. target.addDynamicallyLegalOp([&](FuncOp funcOp) { - return converter.isSignatureLegal(funcOp.getType()); + return converter.isSignatureLegal(funcOp.getType()) && + converter.isLegal(&funcOp.getBody()); }); // Walk over all the functions to apply buffer assignment. @@ -151,7 +152,7 @@ &context, &placer, &converter, &patterns); // Applying full conversion - return applyFullConversion(function, target, patterns, &converter); + return applyFullConversion(function, target, patterns); }); }; };