diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h @@ -26,6 +26,14 @@ MLIRContext *ctx, TypeConverter &converter); +/// Add a pattern to the given pattern list to rewrite terminators to use +/// operands that have been legalized by the conversion framework. This can only +/// be done if the terminator has the ReturnLike trait or implements the +/// BranchOpInterface. Otherwise the pattern will produce an error. Only needed +/// for partial conversions. +void populateTerminatorTypeConversionPattern(OwningRewritePatternList &patterns, + MLIRContext *ctx, + TypeConverter &converter); } // end namespace mlir #endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_ diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td @@ -25,28 +25,21 @@ def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> { let summary = "Bufferize func/call/return ops"; let description = [{ - A finalizing bufferize pass that bufferizes std.func and std.call ops. + A bufferize pass that bufferizes std.func and std.call ops. Because this pass updates std.func ops, it must be a module pass. It is useful to keep this pass separate from other bufferizations so that the other ones can be run at function-level in parallel. - This pass must be done atomically for two reasons: - 1. This pass changes func op signatures, which requires atomically updating - calls as well throughout the entire module. - 2. This pass changes the type of block arguments, which requires that all - successor arguments of predecessors be converted. 
Terminators are not - a closed universe (and need not implement BranchOpInterface), and so we - cannot in general rewrite them. + This pass must be done atomically because it changes func op signatures, + which requires atomically updating calls as well throughout the entire + module. - Note, because this is a "finalizing" bufferize step, it can create - invalid IR because it will not create materializations. To avoid this - situation, the pass must only be run when the only SSA values of - tensor type are: - - block arguments - - the result of tensor_load - Other values of tensor type should be eliminated by earlier - bufferization passes. + This pass also changes the type of block arguments, which requires that all + successor arguments of predecessors be converted. This is achieved by + rewriting terminators based on the information provided by the `ReturnLike` + trait and `BranchOpInterface`. Consequently, using this pass requires that + all terminators implement one of those interfaces. }]; let constructor = "mlir::createFuncBufferizePass()"; } diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -46,6 +46,10 @@ createPromoteBuffersToStackPass(unsigned maxAllocSizeInBytes = 1024, unsigned bitwidthOfIndexType = 64); +/// Creates a pass that finalizes a partial bufferization by removing remaining +/// tensor_load and tensor_to_memref operations. +std::unique_ptr createFinalizingBufferizePass(); + /// Creates a pass that converts memref function results to out-params. 
std::unique_ptr createBufferResultsToOutParamsPass(); diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -602,4 +602,21 @@ }]; let constructor = "mlir::createSymbolDCEPass()"; } + +def FinalizingBufferize : FunctionPass<"finalizing-bufferize"> { + let summary = "Finalize a partial bufferization"; + let description = [{ + A bufferize pass that finalizes a partial bufferization by removing + remaining `tensor_load` and `tensor_to_memref` operations. + + The removal of those operations is only possible if the operations only + exist in pairs, i.e., all uses of `tensor_load` operations are + `tensor_to_memref` operations. + + This pass will fail if not all operations can be removed or if any operation + with tensor typed operands remains. + }]; + let constructor = "mlir::createFinalizingBufferizePass()"; +} + #endif // MLIR_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp @@ -21,6 +21,8 @@ namespace { struct FuncBufferizePass : public FuncBufferizeBase { + using FuncBufferizeBase::FuncBufferizeBase; + void runOnOperation() override { auto module = getOperation(); auto *context = &getContext(); @@ -35,14 +37,34 @@ typeConverter.isLegal(&op.getBody()); }); populateCallOpTypeConversionPattern(patterns, context, typeConverter); - populateEliminateBufferizeMaterializationsPatterns(context, typeConverter, - patterns); - target.addIllegalOp(); + target.addDynamicallyLegalOp([&](CallOp op) { + return typeConverter.isLegal(op.getOperandTypes()) && + typeConverter.isLegal(op.getResultTypes()); + }); - // If all result types are legal, and all block arguments are legal (ensured - // by func conversion above), then all types 
in the program are legal. + populateTerminatorTypeConversionPattern(patterns, context, typeConverter); + target.addLegalOp(); + // Mark terminators as legal if they have the ReturnLike trait or + // implement the BranchOpInterface and have valid types. If they do not + // implement the trait or interface, mark them as illegal no matter what. target.markUnknownOpDynamicallyLegal([&](Operation *op) { - return typeConverter.isLegal(op->getResultTypes()); + // If it is not a terminator, ignore it. + if (op->getNextNode()) + return true; + if (op->hasTrait()) { + return typeConverter.isLegal(op->getOperandTypes()); + } + if (auto branchOp = dyn_cast(op)) { + for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) { + auto successorOperands = branchOp.getSuccessorOperands(p); + if (successorOperands.hasValue() && + !typeConverter.isLegal(successorOperands.getValue().getTypes())) + return false; + } + return true; + } + return false; }); if (failed(applyFullConversion(module, target, std::move(patterns)))) diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp @@ -16,18 +16,16 @@ // Converts the operand and result types of the Standard's CallOp, used together // with the FuncOpSignatureConversion. struct CallOpSignatureConversion : public OpConversionPattern { - CallOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) - : OpConversionPattern(ctx), converter(converter) {} + using OpConversionPattern::OpConversionPattern; /// Hook for derived classes to implement combined matching and rewriting. LogicalResult matchAndRewrite(CallOp callOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - FunctionType type = callOp.getCalleeType(); - // Convert the original function results. 
SmallVector convertedResults; - if (failed(converter.convertTypes(type.getResults(), convertedResults))) + if (failed(typeConverter->convertTypes(callOp.getResultTypes(), + convertedResults))) return failure(); // Substitute with the new result types from the corresponding FuncType @@ -36,14 +34,65 @@ convertedResults, operands); return success(); } - - /// The type converter to use when rewriting the signature. - TypeConverter &converter; }; } // end anonymous namespace void mlir::populateCallOpTypeConversionPattern( OwningRewritePatternList &patterns, MLIRContext *ctx, TypeConverter &converter) { - patterns.insert(ctx, converter); + patterns.insert(converter, ctx); +} + +namespace { +// Only needed to support partial conversion of functions where this pattern +// ensures that the terminator operation arguments match up with the successor +// block arguments. +class TerminatorOpTypeConversion : public ConversionPattern { +public: + TerminatorOpTypeConversion(TypeConverter &typeConverter, MLIRContext *ctx) + : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + // Only look at potential terminators (last op in the block). + if (op->getNextNode()) + return failure(); + if (op->hasTrait()) { + // For a return like, all operands go to the results of the parent, so + // rewrite them all. + rewriter.updateRootInPlace( + op, [operands, op]() { op->setOperands(operands); }); + return success(); + } + if (auto branchOp = dyn_cast(op)) { + // For a branch operation, only some operands go to the target blocks, so + // only rewrite those.
+ SmallVector newOperands(op->operand_begin(), op->operand_end()); + for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors(); + succIdx < succEnd; ++succIdx) { + auto successorOperands = branchOp.getSuccessorOperands(succIdx); + if (!successorOperands) + continue; + for (int idx = successorOperands->getBeginOperandIndex(), + eidx = idx + successorOperands->size(); + idx < eidx; ++idx) { + newOperands[idx] = operands[idx]; + } + } + rewriter.updateRootInPlace( + op, [newOperands, op]() { op->setOperands(newOperands); }); + return success(); + } + + return rewriter.notifyMatchFailure( + op, "terminator does not implement required interfaces"); + } +}; +} // end anonymous namespace + +void mlir::populateTerminatorTypeConversionPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx, + TypeConverter &typeConverter) { + patterns.insert(typeConverter, ctx); } diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp --- a/mlir/lib/Transforms/Bufferize.cpp +++ b/mlir/lib/Transforms/Bufferize.cpp @@ -7,7 +7,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Transforms/Bufferize.h" +#include "PassDetail.h" #include "mlir/IR/Operation.h" +#include "mlir/Transforms/Passes.h" using namespace mlir; @@ -15,6 +17,13 @@ // BufferizeTypeConverter //===----------------------------------------------------------------------===// +static Value InsertTensorLoad(OpBuilder &builder, TensorType type, + ValueRange inputs, Location loc) { + assert(inputs.size() == 1); + assert(inputs[0].getType().isa()); + return builder.create(loc, type, inputs[0]); +} + /// Registers conversions into BufferizeTypeConverter BufferizeTypeConverter::BufferizeTypeConverter() { // Keep all types unchanged. 
@@ -27,12 +36,8 @@ addConversion([](UnrankedTensorType type) -> Type { return UnrankedMemRefType::get(type.getElementType(), 0); }); - addSourceMaterialization([](OpBuilder &builder, TensorType type, - ValueRange inputs, Location loc) -> Value { - assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); - return builder.create(loc, type, inputs[0]); - }); + addArgumentMaterialization(InsertTensorLoad); + addSourceMaterialization(InsertTensorLoad); addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -83,3 +88,37 @@ patterns.insert( typeConverter, context); } + +namespace { +struct FinalizingBufferizePass + : public FinalizingBufferizeBase { + using FinalizingBufferizeBase< + FinalizingBufferizePass>::FinalizingBufferizeBase; + + void runOnFunction() override { + auto func = getFunction(); + auto *context = &getContext(); + + BufferizeTypeConverter typeConverter; + OwningRewritePatternList patterns; + ConversionTarget target(*context); + + populateEliminateBufferizeMaterializationsPatterns(context, typeConverter, + patterns); + target.addIllegalOp(); + + // If all result types are legal, and all block arguments are legal (not + // checked here; expected to be ensured by running func-bufferize + // beforehand), then all types in the program are legal.
+ target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return typeConverter.isLegal(op->getResultTypes()); + }); + + if (failed(applyFullConversion(func, target, std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr mlir::createFinalizingBufferizePass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/Standard/func-bufferize-partial.mlir b/mlir/test/Dialect/Standard/func-bufferize-partial.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Standard/func-bufferize-partial.mlir @@ -0,0 +1,38 @@ +// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func @block_arguments( +// CHECK-SAME: %[[ARG:.*]]: memref) -> memref { +// CHECK: %[[T1:.*]] = tensor_load %[[ARG]] : memref +// CHECK: %[[M1:.*]] = tensor_to_memref %[[T1]] : memref +// CHECK: br ^bb1(%[[M1]] : memref) +// CHECK: ^bb1(%[[BBARG:.*]]: memref): +// CHECK: %[[T2:.*]] = tensor_load %[[BBARG]] : memref +// CHECK: %[[M2:.*]] = tensor_to_memref %[[T2]] : memref +// CHECK: return %[[M2]] : memref +func @block_arguments(%arg0: tensor) -> tensor { + br ^bb1(%arg0: tensor) +^bb1(%bbarg: tensor): + return %bbarg : tensor +} + +// CHECK-LABEL: func @partial() +// CHECK-SAME: memref +func @partial() -> tensor { + // CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor + // CHECK-NEXT: %[[MEM:.*]] = tensor_to_memref %[[SRC]] : memref + %0 = "test.source"() : () -> tensor + // CHECK-NEXT: return %[[MEM]] : memref + return %0 : tensor +} + +// ----- + +func @failed_to_legalize(%arg0: tensor) -> tensor { + %0 = constant true + cond_br %0, ^bb1(%arg0: tensor), ^bb2(%arg0: tensor) + ^bb1(%bbarg0: tensor): + // expected-error @+1 {{failed to legalize operation 'test.terminator'}} + "test.terminator"() : () -> () + ^bb2(%bbarg1: tensor): + return %bbarg1 : tensor +} diff --git a/mlir/test/Dialect/Standard/func-bufferize.mlir b/mlir/test/Dialect/Standard/func-bufferize.mlir --- 
a/mlir/test/Dialect/Standard/func-bufferize.mlir +++ b/mlir/test/Dialect/Standard/func-bufferize.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s +// RUN: mlir-opt %s -func-bufferize --finalizing-bufferize -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func @identity( // CHECK-SAME: %[[ARG:.*]]: memref) -> memref {