diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h @@ -26,6 +26,14 @@ MLIRContext *ctx, TypeConverter &converter); +/// Add a pattern to the given pattern list to rewrite terminators to use +/// operands that have been legalized by the conversion framework. This can only +/// be done if the terminator has the ReturnLike trait or implements the +/// BranchOpInterface. Otherwise the pattern will produce an error. Only needed +/// for partial conversions. +void populateTerminatorTypeConversionPattern(OwningRewritePatternList &patterns, + MLIRContext *ctx, + TypeConverter &converter); } // end namespace mlir #endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_ diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -33,7 +33,8 @@ std::unique_ptr<Pass> createStdBufferizePass(); /// Creates an instance of func bufferization pass. -std::unique_ptr<Pass> createFuncBufferizePass(); +std::unique_ptr<Pass> +createFuncBufferizePass(bool allowPartialBufferization = false); /// Creates an instance of tensor constant bufferization pass. 
std::unique_ptr<Pass> createTensorConstantBufferizePass(); diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td @@ -25,7 +25,7 @@ def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> { let summary = "Bufferize func/call/return ops"; let description = [{ - A finalizing bufferize pass that bufferizes std.func and std.call ops. + A bufferize pass that bufferizes std.func and std.call ops. Because this pass updates std.func ops, it must be a module pass. It is useful to keep this pass separate from other bufferizations so that the @@ -39,16 +39,28 @@ a closed universe (and need not implement BranchOpInterface), and so we cannot in general rewrite them. - Note, because this is a "finalizing" bufferize step, it can create - invalid IR because it will not create materializations. To avoid this - situation, the pass must only be run when the only SSA values of - tensor type are: + By default this is a "finalizing" bufferize step and it can create + invalid IR because it will not create materializations. + + A finalizing pass must only be run when the only SSA values of tensor type + are: - block arguments - the result of tensor_load Other values of tensor type should be eliminated by earlier bufferization passes. + + The pass also supports partial bufferization of functions with the + limitations as described above. The only supported terminator in such a + setting is the return operation. Consequently, the partial bufferization + can only be used in the absence of control flow (blocks and branch + operations). }]; let constructor = "mlir::createFuncBufferizePass()"; + let options = [ + Option<"allowPartialBufferization", "allow-partial-bufferization", + "bool", /*default=*/"false", "Allow partial bufferization. 
" + "Value-based operations may remain.">, + ]; } def TensorConstantBufferize : Pass<"tensor-constant-bufferize", "ModuleOp"> { diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp @@ -21,6 +21,10 @@ namespace { struct FuncBufferizePass : public FuncBufferizeBase { + FuncBufferizePass(bool allowPartialBufferization) { + this->allowPartialBufferization = allowPartialBufferization; + } + void runOnOperation() override { auto module = getOperation(); auto *context = &getContext(); @@ -35,6 +39,41 @@ typeConverter.isLegal(&op.getBody()); }); populateCallOpTypeConversionPattern(patterns, context, typeConverter); + target.addDynamicallyLegalOp([&](CallOp op) { + return typeConverter.isLegal(op.getOperandTypes()) && + typeConverter.isLegal(op.getResultTypes()); + }); + + if (allowPartialBufferization) { + populateTerminatorTypeConversionPattern(patterns, context, typeConverter); + target.addLegalOp(); + // Mark terminators as legal if they have the ReturnLike trait or + // implement the BranchOpInterface and have valid types. If they do not + // implement the trait or interface, mark them as illegal no matter what. + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + // If it is not a terminator, ignore it. 
+ if (op->getNextNode()) + return true; + if (op->hasTrait<OpTrait::ReturnLike>()) { + return typeConverter.isLegal(op->getOperandTypes()); + } + if (auto branchOp = dyn_cast<BranchOpInterface>(op)) { + for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) { + auto successorOperands = branchOp.getSuccessorOperands(p); + if (successorOperands.hasValue() && + !typeConverter.isLegal(successorOperands.getValue().getTypes())) + return false; + } + return true; + } + return false; + }); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); + return; + } + populateEliminateBufferizeMaterializationsPatterns(context, typeConverter, patterns); target.addIllegalOp<TensorLoadOp, TensorToMemrefOp>(); @@ -51,6 +90,7 @@ }; } // namespace -std::unique_ptr<Pass> mlir::createFuncBufferizePass() { - return std::make_unique<FuncBufferizePass>(); +std::unique_ptr<Pass> +mlir::createFuncBufferizePass(bool allowPartialBufferization) { + return std::make_unique<FuncBufferizePass>(allowPartialBufferization); } diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; @@ -16,18 +17,16 @@ // Converts the operand and result types of the Standard's CallOp, used together // with the FuncOpSignatureConversion. struct CallOpSignatureConversion : public OpConversionPattern<CallOp> { - CallOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) - : OpConversionPattern(ctx), converter(converter) {} + using OpConversionPattern<CallOp>::OpConversionPattern; /// Hook for derived classes to implement combined matching and rewriting. 
LogicalResult matchAndRewrite(CallOp callOp, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { - FunctionType type = callOp.getCalleeType(); - // Convert the original function results. SmallVector<Type, 8> convertedResults; - if (failed(converter.convertTypes(type.getResults(), convertedResults))) + if (failed(typeConverter->convertTypes(callOp.getResultTypes(), + convertedResults))) return failure(); // Substitute with the new result types from the corresponding FuncType @@ -36,14 +35,67 @@ convertedResults, operands); return success(); } - - /// The type converter to use when rewriting the signature. - TypeConverter &converter; }; } // end anonymous namespace void mlir::populateCallOpTypeConversionPattern( OwningRewritePatternList &patterns, MLIRContext *ctx, TypeConverter &converter) { - patterns.insert<CallOpSignatureConversion>(ctx, converter); + patterns.insert<CallOpSignatureConversion>(converter, ctx); +} + +namespace { +// Only needed to support partial conversion of functions where this pattern +// ensures that the terminator operation arguments match up with the successor +// block arguments. +class TerminatorOpTypeConversion : public ConversionPattern { +public: + TerminatorOpTypeConversion(TypeConverter &typeConverter, MLIRContext *ctx) + : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const final { + // Only look at potential terminators (last op in the block). + if (op->getNextNode()) + return failure(); + if (op->hasTrait<OpTrait::ReturnLike>()) { + // For a return like, all operands go to the results of the parent, so + // rewrite them all. + rewriter.updateRootInPlace( + op, [operands, op]() { op->setOperands(operands); }); + return success(); + } + if (auto branchOp = dyn_cast<BranchOpInterface>(op)) { + // For a branch operation, only some operands go to the target blocks, so + // only rewrite those. 
+ SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end()); + for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors(); + succIdx < succEnd; ++succIdx) { + auto successorOperands = branchOp.getSuccessorOperands(succIdx); + if (!successorOperands) + continue; + for (int idx = successorOperands->getBeginOperandIndex(), + eidx = idx + successorOperands->size(); + idx < eidx; ++idx) { + newOperands[idx] = operands[idx]; + } + } + rewriter.updateRootInPlace( + op, [newOperands, op]() { op->setOperands(newOperands); }); + return success(); + } + // These are fine, so do not produce an error. + if (isa<ModuleTerminatorOp>(op)) + return failure(); + + return op->emitError("terminator does not implement required interfaces"); + } +}; +} // end anonymous namespace + +void mlir::populateTerminatorTypeConversionPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx, + TypeConverter &typeConverter) { + patterns.insert<TerminatorOpTypeConversion>(typeConverter, ctx); } diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp --- a/mlir/lib/Transforms/Bufferize.cpp +++ b/mlir/lib/Transforms/Bufferize.cpp @@ -15,6 +15,13 @@ // BufferizeTypeConverter //===----------------------------------------------------------------------===// +static Value InsertTensorLoad(OpBuilder &builder, TensorType type, + ValueRange inputs, Location loc) { + assert(inputs.size() == 1); + assert(inputs[0].getType().isa<BaseMemRefType>()); + return builder.create<TensorLoadOp>(loc, type, inputs[0]); +} + /// Registers conversions into BufferizeTypeConverter BufferizeTypeConverter::BufferizeTypeConverter() { // Keep all types unchanged. 
@@ -27,12 +34,8 @@ addConversion([](UnrankedTensorType type) -> Type { return UnrankedMemRefType::get(type.getElementType(), 0); }); - addSourceMaterialization([](OpBuilder &builder, TensorType type, - ValueRange inputs, Location loc) -> Value { - assert(inputs.size() == 1); - assert(inputs[0].getType().isa<BaseMemRefType>()); - return builder.create<TensorLoadOp>(loc, type, inputs[0]); - }); + addArgumentMaterialization(InsertTensorLoad); + addSourceMaterialization(InsertTensorLoad); addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); diff --git a/mlir/test/Dialect/Standard/func-bufferize-partial.mlir b/mlir/test/Dialect/Standard/func-bufferize-partial.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Standard/func-bufferize-partial.mlir @@ -0,0 +1,38 @@ +// RUN: mlir-opt %s -func-bufferize="allow-partial-bufferization=true" -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func @block_arguments( +// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> { +// CHECK: %[[T1:.*]] = tensor_load %[[ARG]] : memref<f32> +// CHECK: %[[M1:.*]] = tensor_to_memref %[[T1]] : memref<f32> +// CHECK: br ^bb1(%[[M1]] : memref<f32>) +// CHECK: ^bb1(%[[BBARG:.*]]: memref<f32>): +// CHECK: %[[T2:.*]] = tensor_load %[[BBARG]] : memref<f32> +// CHECK: %[[M2:.*]] = tensor_to_memref %[[T2]] : memref<f32> +// CHECK: return %[[M2]] : memref<f32> +func @block_arguments(%arg0: tensor<f32>) -> tensor<f32> { + br ^bb1(%arg0: tensor<f32>) +^bb1(%bbarg: tensor<f32>): + return %bbarg : tensor<f32> +} + +// CHECK-LABEL: func @partial() +// CHECK-SAME: memref<f32> +func @partial() -> tensor<f32> { + // CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor<f32> + // CHECK-NEXT: %[[MEM:.*]] = tensor_to_memref %[[SRC]] : memref<f32> + %0 = "test.source"() : () -> tensor<f32> + // CHECK-NEXT: return %[[MEM]] : memref<f32> + return %0 : tensor<f32> +} + +// ----- + +func @failed_to_legalize(%arg0: tensor<f32>) -> tensor<f32> { + %0 = constant true + cond_br %0, ^bb1(%arg0: tensor<f32>), 
^bb2(%arg0: tensor<f32>) + ^bb1(%bbarg0: tensor<f32>): + // expected-error @+1 {{terminator does not implement required interfaces}} + "test.terminator"() : () -> () + ^bb2(%bbarg1: tensor<f32>): + return %bbarg1 : tensor<f32> +}