diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -14,6 +14,8 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/SetVector.h" +#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc" + namespace mlir { class OpBuilder; @@ -187,12 +189,6 @@ using UnknownTypeConverterFn = std::function; - enum class LayoutMapOption : int8_t { - InferLayoutMap = 0, - IdentityLayoutMap = 1, - FullyDynamicLayoutMap = 2 - }; - BufferizationOptions(); /// Try to cast the given op to BufferizableOpInterface if the op is allow @@ -585,6 +581,10 @@ } // namespace bufferization } // namespace mlir +//===----------------------------------------------------------------------===// +// Bufferization Interfaces +//===----------------------------------------------------------------------===// + #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h.inc" #endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_ diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td @@ -0,0 +1,27 @@ +//===- BufferizationEnums.td - Bufferization enums ---------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the definition file for enums used in Bufferization. 
+// +//===----------------------------------------------------------------------===// + +#ifndef BUFFERIZATION_ENUMS +#define BUFFERIZATION_ENUMS + +include "mlir/IR/EnumAttr.td" + +def LayoutMapOption : I32EnumAttr<"LayoutMapOption", + "option for map layout", [ + I32EnumAttrCase<"InferLayoutMap", 0>, + I32EnumAttrCase<"IdentityLayoutMap", 1>, + I32EnumAttrCase<"FullyDynamicLayoutMap", 2> +]> { + let cppNamespace = "::mlir::bufferization"; +} + +#endif // BUFFERIZATION_ENUMS diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt @@ -2,3 +2,9 @@ add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc) add_mlir_interface(AllocationOpInterface) add_mlir_interface(BufferizableOpInterface) + +set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td) +mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls) +mlir_tablegen(BufferizationEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRBufferizationEnumsIncGen) +add_dependencies(mlir-headers MLIRBufferizationEnumsIncGen) diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h --- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h +++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/OpImplementation.h" diff --git 
a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td --- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td @@ -9,6 +9,7 @@ #ifndef BUFFERIZATION_TRANSFORM_OPS #define BUFFERIZATION_TRANSFORM_OPS +include "mlir/Dialect/Bufferization/IR/BufferizationEnums.td" include "mlir/Dialect/Transform/IR/TransformDialect.td" include "mlir/Dialect/Transform/IR/TransformEffects.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" @@ -42,6 +43,7 @@ let arguments = ( ins PDL_Operation:$target, + OptionalAttr<LayoutMapOption>:$function_boundary_type_conversion, DefaultValuedAttr:$allow_return_allocs, DefaultValuedAttr:$allow_unknown_ops, DefaultValuedAttr:$bufferize_function_boundaries, @@ -52,7 +54,10 @@ let results = (outs); - let assemblyFormat = "$target attr-dict"; + let assemblyFormat = [{ + (`layout` `{` $function_boundary_type_conversion^ `}`)? 
+ $target attr-dict + }]; } #endif // BUFFERIZATION_TRANSFORM_OPS diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt @@ -10,6 +10,7 @@ DEPENDS MLIRAllocationOpInterfaceIncGen MLIRBufferizationOpsIncGen + MLIRBufferizationEnumsIncGen LINK_LIBS PUBLIC MLIRAffineDialect diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -34,6 +34,9 @@ options.createDeallocs = getCreateDeallocs(); options.testAnalysisOnly = getTestAnalysisOnly(); options.printConflicts = getPrintConflicts(); + if (getFunctionBoundaryTypeConversion().has_value()) + options.functionBoundaryTypeConversion = + *getFunctionBoundaryTypeConversion(); ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget()); for (Operation *target : payloadOps) { @@ -94,6 +97,8 @@ #define GET_OP_CLASSES #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" +#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc" + void mlir::bufferization::registerTransformDialectExtension( DialectRegistry &registry) { registry.addExtensions(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -163,14 +163,13 @@ } }; -static BufferizationOptions::LayoutMapOption -parseLayoutMapOption(const std::string &s) { +static LayoutMapOption parseLayoutMapOption(const std::string &s) { if (s == "fully-dynamic-layout-map") - return 
BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap; + return LayoutMapOption::FullyDynamicLayoutMap; if (s == "identity-layout-map") - return BufferizationOptions::LayoutMapOption::IdentityLayoutMap; + return LayoutMapOption::IdentityLayoutMap; if (s == "infer-layout-map") - return BufferizationOptions::LayoutMapOption::InferLayoutMap; + return LayoutMapOption::InferLayoutMap; llvm_unreachable("invalid layout map option"); } @@ -216,19 +215,17 @@ opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries; // Configure type converter. - BufferizationOptions::LayoutMapOption unknownTypeConversionOption = + LayoutMapOption unknownTypeConversionOption = parseLayoutMapOption(unknownTypeConversion); opt.unknownTypeConverterFn = [=](Value value, unsigned memorySpace, const BufferizationOptions &options) { auto tensorType = value.getType().cast(); - if (unknownTypeConversionOption == - BufferizationOptions::LayoutMapOption::IdentityLayoutMap) + if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap) return bufferization::getMemRefTypeWithStaticIdentityLayout( tensorType, memorySpace); - assert( - unknownTypeConversionOption == - BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap && - "invalid layout map option"); + assert(unknownTypeConversionOption == + LayoutMapOption::FullyDynamicLayoutMap && + "invalid layout map option"); return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace); }; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt @@ -18,6 +18,7 @@ DEPENDS MLIRBufferizationPassIncGen + MLIRBufferizationEnumsIncGen LINK_LIBS PUBLIC MLIRBufferizationDialect diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp 
b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -69,7 +69,7 @@ BaseMemRefType memrefType; if (options.functionBoundaryTypeConversion == - BufferizationOptions::LayoutMapOption::IdentityLayoutMap) { + LayoutMapOption::IdentityLayoutMap) { memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType); } else { // Note: Layout maps on function parameters cannot be inferred. The best we @@ -471,7 +471,7 @@ BaseMemRefType resultType; if (options.functionBoundaryTypeConversion == - BufferizationOptions::LayoutMapOption::IdentityLayoutMap) { + LayoutMapOption::IdentityLayoutMap) { resultType = getMemRefTypeWithStaticIdentityLayout(tensorType); } else { // Note: If `InferLayoutMap`, cast are later folded away. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -423,7 +423,7 @@ return failure(); // Change buffer return types to more precise layout maps. if (options.functionBoundaryTypeConversion == - BufferizationOptions::LayoutMapOption::InferLayoutMap) + LayoutMapOption::InferLayoutMap) foldMemRefCasts(funcOp); } diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp --- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp +++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp @@ -32,8 +32,7 @@ // TODO(springerm): To spot memory leaks more easily, returning dense allocs // should be disallowed. 
options.allowReturnAllocs = true; - options.functionBoundaryTypeConversion = - BufferizationOptions::LayoutMapOption::IdentityLayoutMap; + options.functionBoundaryTypeConversion = LayoutMapOption::IdentityLayoutMap; options.unknownTypeConverterFn = [](Value value, unsigned memorySpace, const BufferizationOptions &options) { return getMemRefTypeWithStaticIdentityLayout( diff --git a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir @@ -96,3 +96,25 @@ return %0 : tensor } } + +// ----- + +// Test we use identity layout at function boundaries. + +transform.sequence failures(propagate) { + ^bb0(%arg1: !pdl.operation): + transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %arg1 { + target_is_module = true, + bufferize_function_boundaries = true } +} + +// CHECK: func.func @matmul( +// CHECK-SAME: %[[A:.*]]: memref<12x9xf32>, +// CHECK-SAME: %[[B:.*]]: memref<9x6xf32>, +// CHECK-SAME: %[[C:.*]]: memref<12x6xf32>) -> memref<12x6xf32> { +func.func @matmul(%A: tensor<12x9xf32>, %B: tensor<9x6xf32>, %C: tensor<12x6xf32>) -> tensor<12x6xf32> { + // CHECK: linalg.matmul ins(%[[A]], %[[B]] : memref<12x9xf32>, memref<9x6xf32>) outs(%[[C]] : memref<12x6xf32>) + %D = linalg.matmul ins(%A, %B: tensor<12x9xf32>, tensor<9x6xf32>) outs(%C: tensor<12x6xf32>) -> tensor<12x6xf32> + // CHECK: return %[[C]] : memref<12x6xf32> + return %D : tensor<12x6xf32> +}