diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h @@ -26,6 +26,13 @@ MLIRContext *ctx, TypeConverter &converter); +/// Add a pattern to the given pattern list to rewrite branch operations and +/// `return` to use operands that have been legalized by the conversion +/// framework. This can only be done if the branch operation implements the +/// BranchOpInterface. Only needed for partial conversions. +void populateBranchOpInterfaceAndReturnOpTypeConversionPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx, + TypeConverter &converter); } // end namespace mlir #endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_ diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td @@ -25,28 +25,26 @@ def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> { let summary = "Bufferize func/call/return ops"; let description = [{ - A finalizing bufferize pass that bufferizes std.func and std.call ops. + A bufferize pass that bufferizes std.func and std.call ops. Because this pass updates std.func ops, it must be a module pass. It is useful to keep this pass separate from other bufferizations so that the other ones can be run at function-level in parallel. - This pass must be done atomically for two reasons: - 1. This pass changes func op signatures, which requires atomically updating - calls as well throughout the entire module. - 2. This pass changes the type of block arguments, which requires that all - successor arguments of predecessors be converted. 
Terminators are not - a closed universe (and need not implement BranchOpInterface), and so we - cannot in general rewrite them. + This pass must be done atomically because it changes func op signatures, + which requires atomically updating calls as well throughout the entire + module. - Note, because this is a "finalizing" bufferize step, it can create - invalid IR because it will not create materializations. To avoid this - situation, the pass must only be run when the only SSA values of - tensor type are: - - block arguments - - the result of tensor_load - Other values of tensor type should be eliminated by earlier - bufferization passes. + This pass also changes the type of block arguments, which requires that all + successor arguments of predecessors be converted. This is achieved by + rewriting terminators based on the information provided by the + `BranchOpInterface`. + As this pass rewrites function operations, it also rewrites the + corresponding return operations. Other return-like operations that + implement the `ReturnLike` trait are not rewritten in general, as they + require that the corresponding parent operation is also rewritten. + Finally, this pass fails for unknown terminators, as we cannot decide + whether they need rewriting. }]; let constructor = "mlir::createFuncBufferizePass()"; } diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -46,6 +46,10 @@ createPromoteBuffersToStackPass(unsigned maxAllocSizeInBytes = 1024, unsigned bitwidthOfIndexType = 64); +/// Creates a pass that finalizes a partial bufferization by removing remaining +/// tensor_load and tensor_to_memref operations. +std::unique_ptr createFinalizingBufferizePass(); + /// Creates a pass that converts memref function results to out-params. 
std::unique_ptr createBufferResultsToOutParamsPass(); diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -290,6 +290,22 @@ ]; } +def FinalizingBufferize : FunctionPass<"finalizing-bufferize"> { + let summary = "Finalize a partial bufferization"; + let description = [{ + A bufferize pass that finalizes a partial bufferization by removing + remaining `tensor_load` and `tensor_to_memref` operations. + + The removal of those operations is only possible if the operations only + exist in pairs, i.e., all uses of `tensor_load` operations are + `tensor_to_memref` operations. + + This pass will fail if not all operations can be removed or if any operation + with tensor typed operands remains. + }]; + let constructor = "mlir::createFinalizingBufferizePass()"; +} + def LocationSnapshot : Pass<"snapshot-op-locations"> { let summary = "Generate new locations from the current IR"; let description = [{ diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp @@ -21,6 +21,8 @@ namespace { struct FuncBufferizePass : public FuncBufferizeBase { + using FuncBufferizeBase::FuncBufferizeBase; + void runOnOperation() override { auto module = getOperation(); auto *context = &getContext(); @@ -35,14 +37,42 @@ typeConverter.isLegal(&op.getBody()); }); populateCallOpTypeConversionPattern(patterns, context, typeConverter); - populateEliminateBufferizeMaterializationsPatterns(context, typeConverter, - patterns); - target.addIllegalOp(); + target.addDynamicallyLegalOp( + [&](CallOp op) { return typeConverter.isLegal(op); }); - // If all result types are legal, and all block arguments are legal (ensured - // by func conversion above), then all types in the program 
are legal. + populateBranchOpInterfaceAndReturnOpTypeConversionPattern(patterns, context, + typeConverter); + target.addLegalOp(); + target.addDynamicallyLegalOp( + [&](ReturnOp op) { return typeConverter.isLegal(op); }); + // Mark terminators as legal if they have the ReturnLike trait or + // implement the BranchOpInterface and have valid types. If they do not + // implement the trait or interface, mark them as illegal no matter what. target.markUnknownOpDynamicallyLegal([&](Operation *op) { - return typeConverter.isLegal(op->getResultTypes()); + // If it is not a terminator, ignore it. + if (op->isKnownNonTerminator()) + return true; + // If it is not the last operation in the block, also ignore it. We do + // this to handle unknown operations, as well. + Block *block = op->getBlock(); + if (!block || &block->back() != op) + return true; + // ReturnLike operations have to be legalized with their parent. For + // return this is handled, for other ops they remain as is. + if (op->hasTrait()) + return true; + // All successor operands of branch like operations must be rewritten. + if (auto branchOp = dyn_cast(op)) { + for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) { + auto successorOperands = branchOp.getSuccessorOperands(p); + if (successorOperands.hasValue() && + !typeConverter.isLegal(successorOperands.getValue().getTypes())) + return false; + } + return true; + } + return false; }); if (failed(applyFullConversion(module, target, std::move(patterns)))) diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp @@ -13,21 +13,19 @@ using namespace mlir; namespace { -// Converts the operand and result types of the Standard's CallOp, used together -// with the FuncOpSignatureConversion. 
+/// Converts the operand and result types of the Standard's CallOp, used +/// together with the FuncOpSignatureConversion. struct CallOpSignatureConversion : public OpConversionPattern { - CallOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) - : OpConversionPattern(ctx), converter(converter) {} + using OpConversionPattern::OpConversionPattern; /// Hook for derived classes to implement combined matching and rewriting. LogicalResult matchAndRewrite(CallOp callOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - FunctionType type = callOp.getCalleeType(); - // Convert the original function results. SmallVector convertedResults; - if (failed(converter.convertTypes(type.getResults(), convertedResults))) + if (failed(typeConverter->convertTypes(callOp.getResultTypes(), + convertedResults))) return failure(); // Substitute with the new result types from the corresponding FuncType @@ -36,14 +34,77 @@ convertedResults, operands); return success(); } - - /// The type converter to use when rewriting the signature. - TypeConverter &converter; }; } // end anonymous namespace void mlir::populateCallOpTypeConversionPattern( OwningRewritePatternList &patterns, MLIRContext *ctx, TypeConverter &converter) { - patterns.insert(ctx, converter); + patterns.insert(converter, ctx); +} + +namespace { +/// Only needed to support partial conversion of functions where this pattern +/// ensures that the branch operation arguments match up with the successor +/// block arguments. 
+class BranchOpInterfaceTypeConversion : public ConversionPattern { +public: + BranchOpInterfaceTypeConversion(TypeConverter &typeConverter, + MLIRContext *ctx) + : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto branchOp = dyn_cast(op); + if (!branchOp) + return failure(); + + // For a branch operation, only some operands go to the target blocks, so + // only rewrite those. + SmallVector newOperands(op->operand_begin(), op->operand_end()); + for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors(); + succIdx < succEnd; ++succIdx) { + auto successorOperands = branchOp.getSuccessorOperands(succIdx); + if (!successorOperands) + continue; + for (int idx = successorOperands->getBeginOperandIndex(), + eidx = idx + successorOperands->size(); + idx < eidx; ++idx) { + newOperands[idx] = operands[idx]; + } + } + rewriter.updateRootInPlace( + op, [newOperands, op]() { op->setOperands(newOperands); }); + return success(); + } +}; +} // end anonymous namespace + +namespace { +/// Only needed to support partial conversion of functions where this pattern +/// ensures that the operands of the `return` operation match up with the +/// results of the parent operation. +class ReturnOpTypeConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ReturnOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + // For a return, all operands go to the results of the parent, so + // rewrite them all. 
+ Operation *operation = op.getOperation(); + rewriter.updateRootInPlace( + op, [operands, operation]() { operation->setOperands(operands); }); + return success(); + } +}; +} // end anonymous namespace + +void mlir::populateBranchOpInterfaceAndReturnOpTypeConversionPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx, + TypeConverter &typeConverter) { + patterns.insert( + typeConverter, ctx); } diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp --- a/mlir/lib/Transforms/Bufferize.cpp +++ b/mlir/lib/Transforms/Bufferize.cpp @@ -7,7 +7,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Transforms/Bufferize.h" +#include "PassDetail.h" #include "mlir/IR/Operation.h" +#include "mlir/Transforms/Passes.h" using namespace mlir; @@ -15,6 +17,13 @@ // BufferizeTypeConverter //===----------------------------------------------------------------------===// +static Value materializeTensorLoad(OpBuilder &builder, TensorType type, + ValueRange inputs, Location loc) { + assert(inputs.size() == 1); + assert(inputs[0].getType().isa()); + return builder.create(loc, type, inputs[0]); +} + /// Registers conversions into BufferizeTypeConverter BufferizeTypeConverter::BufferizeTypeConverter() { // Keep all types unchanged. 
@@ -27,12 +36,8 @@ addConversion([](UnrankedTensorType type) -> Type { return UnrankedMemRefType::get(type.getElementType(), 0); }); - addSourceMaterialization([](OpBuilder &builder, TensorType type, - ValueRange inputs, Location loc) -> Value { - assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); - return builder.create(loc, type, inputs[0]); - }); + addArgumentMaterialization(materializeTensorLoad); + addSourceMaterialization(materializeTensorLoad); addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -83,3 +88,37 @@ patterns.insert( typeConverter, context); } + +namespace { +struct FinalizingBufferizePass + : public FinalizingBufferizeBase { + using FinalizingBufferizeBase< + FinalizingBufferizePass>::FinalizingBufferizeBase; + + void runOnFunction() override { + auto func = getFunction(); + auto *context = &getContext(); + + BufferizeTypeConverter typeConverter; + OwningRewritePatternList patterns; + ConversionTarget target(*context); + + populateEliminateBufferizeMaterializationsPatterns(context, typeConverter, + patterns); + target.addIllegalOp(); + + // If all result types are legal, and all block arguments are legal (ensured + // by func conversion above), then all types in the program are legal. 
+ target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return typeConverter.isLegal(op->getResultTypes()); + }); + + if (failed(applyFullConversion(func, target, std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr mlir::createFinalizingBufferizePass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/Standard/func-bufferize-partial.mlir b/mlir/test/Dialect/Standard/func-bufferize-partial.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Standard/func-bufferize-partial.mlir @@ -0,0 +1,59 @@ +// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics --debug-only=dialect-conversion | FileCheck %s + +// CHECK-LABEL: func @block_arguments( +// CHECK-SAME: %[[ARG:.*]]: memref) -> memref { +// CHECK: %[[T1:.*]] = tensor_load %[[ARG]] : memref +// CHECK: %[[M1:.*]] = tensor_to_memref %[[T1]] : memref +// CHECK: br ^bb1(%[[M1]] : memref) +// CHECK: ^bb1(%[[BBARG:.*]]: memref): +// CHECK: %[[T2:.*]] = tensor_load %[[BBARG]] : memref +// CHECK: %[[M2:.*]] = tensor_to_memref %[[T2]] : memref +// CHECK: return %[[M2]] : memref +func @block_arguments(%arg0: tensor) -> tensor { + br ^bb1(%arg0: tensor) +^bb1(%bbarg: tensor): + return %bbarg : tensor +} + +// CHECK-LABEL: func @partial() +// CHECK-SAME: memref +func @partial() -> tensor { + // CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor + // CHECK-NEXT: %[[MEM:.*]] = tensor_to_memref %[[SRC]] : memref + %0 = "test.source"() : () -> tensor + // CHECK-NEXT: return %[[MEM]] : memref + return %0 : tensor +} + +// CHECK-LABEL: func @region_op +// CHECK-SAME: (%[[ARG0:.*]]: i1) -> memref +func @region_op(%arg0: i1) -> tensor { + // CHECK-NEXT: %[[IF:.*]] = scf.if %[[ARG0]] -> (tensor) + %0 = scf.if %arg0 -> (tensor) { + // CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor + %1 = "test.source"() : () -> tensor + // CHECK-NEXT: scf.yield %[[SRC]] : tensor + scf.yield %1 : tensor + // CHECK-NEXT: else + } else { + // CHECK-NEXT: 
%[[OSRC:.*]] = "test.other_source"() : () -> tensor + %1 = "test.other_source"() : () -> tensor + // CHECK-NEXT: scf.yield %[[OSRC]] : tensor + scf.yield %1 : tensor + } + // CHECK: %[[MEM:.*]] = tensor_to_memref %[[IF]] : memref + // CHECK: return %[[MEM]] : memref + return %0 : tensor +} + +// ----- + +func @failed_to_legalize(%arg0: tensor) -> tensor { + %0 = constant true + cond_br %0, ^bb1(%arg0: tensor), ^bb2(%arg0: tensor) + ^bb1(%bbarg0: tensor): + // expected-error @+1 {{failed to legalize operation 'test.terminator'}} + "test.terminator"() : () -> () + ^bb2(%bbarg1: tensor): + return %bbarg1 : tensor +} diff --git a/mlir/test/Dialect/Standard/func-bufferize.mlir b/mlir/test/Dialect/Standard/func-bufferize.mlir --- a/mlir/test/Dialect/Standard/func-bufferize.mlir +++ b/mlir/test/Dialect/Standard/func-bufferize.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s +// RUN: mlir-opt %s -func-bufferize -finalizing-bufferize -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func @identity( // CHECK-SAME: %[[ARG:.*]]: memref) -> memref {