diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h @@ -26,6 +26,12 @@ MLIRContext *ctx, TypeConverter &converter); +/// Add a pattern to the given pattern list to rewrite the return op to use +/// operands that have been legalized by the conversion framework. Only needed +/// for partial conversions. +void populateReturnOpTypeConversionPattern(OwningRewritePatternList &patterns, + MLIRContext *ctx, + TypeConverter &converter); } // end namespace mlir #endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_ diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -33,7 +33,8 @@ std::unique_ptr createStdBufferizePass(); /// Creates an instance of func bufferization pass. -std::unique_ptr createFuncBufferizePass(); +std::unique_ptr +createFuncBufferizePass(bool allowPartialBufferization = false); /// Creates an instance of tensor constant bufferization pass. std::unique_ptr createTensorConstantBufferizePass(); diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td @@ -25,7 +25,7 @@ def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> { let summary = "Bufferize func/call/return ops"; let description = [{ - A finalizing bufferize pass that bufferizes std.func and std.call ops. + A bufferize pass that bufferizes std.func and std.call ops. Because this pass updates std.func ops, it must be a module pass. 
It is useful to keep this pass separate from other bufferizations so that the @@ -39,16 +39,28 @@ a closed universe (and need not implement BranchOpInterface), and so we cannot in general rewrite them. - Note, because this is a "finalizing" bufferize step, it can create - invalid IR because it will not create materializations. To avoid this - situation, the pass must only be run when the only SSA values of - tensor type are: + By default this is a "finalizing" bufferize step and it can create + invalid IR because it will not create materializations. + + A finalizing pass must only be run when the only SSA values of tensor type + are: - block arguments - the result of tensor_load Other values of tensor type should be eliminated by earlier bufferization passes. + + The pass also supports partial bufferization of functions with the + limitations described above. The only supported terminator in such a + setting is the return operation. Consequently, the partial bufferization + can only be used in the absence of control flow (blocks and branch + operations). }]; let constructor = "mlir::createFuncBufferizePass()"; + let options = [ + Option<"allowPartialBufferization", "allow-partial-bufferization", + "bool", /*default=*/"false", "Allow partial bufferization. 
" + "Value-based operations may remain.">, + ]; } def TensorConstantBufferize : Pass<"tensor-constant-bufferize", "ModuleOp"> { diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp @@ -21,6 +21,10 @@ namespace { struct FuncBufferizePass : public FuncBufferizeBase { + FuncBufferizePass(bool allowPartialBufferization) { + this->allowPartialBufferization = allowPartialBufferization; + } + void runOnOperation() override { auto module = getOperation(); auto *context = &getContext(); @@ -35,6 +39,23 @@ typeConverter.isLegal(&op.getBody()); }); populateCallOpTypeConversionPattern(patterns, context, typeConverter); + target.addDynamicallyLegalOp([&](CallOp op) { + return typeConverter.isLegal(op.getOperandTypes()) && + typeConverter.isLegal(op.getResultTypes()); + }); + + if (allowPartialBufferization) { + populateReturnOpTypeConversionPattern(patterns, context, typeConverter); + target.addLegalOp(); + target.addDynamicallyLegalOp([&](ReturnOp op) { + return typeConverter.isLegal(op.getOperandTypes()); + }); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); + return; + } + populateEliminateBufferizeMaterializationsPatterns(context, typeConverter, patterns); target.addIllegalOp(); @@ -51,6 +72,7 @@ }; } // namespace -std::unique_ptr mlir::createFuncBufferizePass() { - return std::make_unique(); +std::unique_ptr +mlir::createFuncBufferizePass(bool allowPartialBufferization) { + return std::make_unique(allowPartialBufferization); } diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp @@ -16,18 +16,16 @@ // Converts 
the operand and result types of the Standard's CallOp, used together // with the FuncOpSignatureConversion. struct CallOpSignatureConversion : public OpConversionPattern { - CallOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) - : OpConversionPattern(ctx), converter(converter) {} + using OpConversionPattern::OpConversionPattern; /// Hook for derived classes to implement combined matching and rewriting. LogicalResult matchAndRewrite(CallOp callOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - FunctionType type = callOp.getCalleeType(); - // Convert the original function results. SmallVector convertedResults; - if (failed(converter.convertTypes(type.getResults(), convertedResults))) + if (failed(typeConverter->convertTypes(callOp.getResultTypes(), + convertedResults))) return failure(); // Substitute with the new result types from the corresponding FuncType @@ -36,14 +34,33 @@ convertedResults, operands); return success(); } - - /// The type converter to use when rewriting the signature. - TypeConverter &converter; }; } // end anonymous namespace void mlir::populateCallOpTypeConversionPattern( OwningRewritePatternList &patterns, MLIRContext *ctx, TypeConverter &converter) { - patterns.insert(ctx, converter); + patterns.insert(converter, ctx); +} + +namespace { +// Only needed to support partial conversion of functions where this pattern +// ensures that the return matches up with the new function signature. 
+class ReturnOpTypeConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ReturnOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + rewriter.replaceOpWithNewOp(op, operands); + return success(); + } +}; +} // end anonymous namespace + +void mlir::populateReturnOpTypeConversionPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx, + TypeConverter &typeConverter) { + patterns.insert(typeConverter, ctx); } diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp --- a/mlir/lib/Transforms/Bufferize.cpp +++ b/mlir/lib/Transforms/Bufferize.cpp @@ -15,6 +15,13 @@ // BufferizeTypeConverter //===----------------------------------------------------------------------===// +static Value InsertTensorLoad(OpBuilder &builder, TensorType type, + ValueRange inputs, Location loc) { + assert(inputs.size() == 1); + assert(inputs[0].getType().isa()); + return builder.create(loc, type, inputs[0]); +} + /// Registers conversions into BufferizeTypeConverter BufferizeTypeConverter::BufferizeTypeConverter() { // Keep all types unchanged. 
@@ -27,12 +34,8 @@ addConversion([](UnrankedTensorType type) -> Type { return UnrankedMemRefType::get(type.getElementType(), 0); }); - addSourceMaterialization([](OpBuilder &builder, TensorType type, - ValueRange inputs, Location loc) -> Value { - assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); - return builder.create(loc, type, inputs[0]); - }); + addArgumentMaterialization(InsertTensorLoad); + addSourceMaterialization(InsertTensorLoad); addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); diff --git a/mlir/test/Dialect/Standard/func-bufferize-partial.mlir b/mlir/test/Dialect/Standard/func-bufferize-partial.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Standard/func-bufferize-partial.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-opt %s -func-bufferize="allow-partial-bufferization=true" -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func @partial() +// CHECK-SAME: memref +func @partial() -> tensor { + // CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor + // CHECK-NEXT: %[[MEM:.*]] = tensor_to_memref %[[SRC]] : memref + %0 = "test.source"() : () -> tensor + // CHECK-NEXT: return %[[MEM]] : memref + return %0 : tensor +}