[mlir][bufferization] Add function_boundary_type_conversion to the
one_shot_bufferize transform op

Move parseLayoutMapOption out of Bufferize.cpp (where it was a file-local
helper) into BufferizationOptions as a static member, so that both the
bufferization pass and the transform op can translate the textual layout map
option into BufferizationOptions::LayoutMapOption. An unrecognized string is
reported as a fatal error.

Review notes:
- Template argument lists appear to have been stripped from some lines of
  this patch (e.g. "DefaultValuedAttr:$...", "DefaultValuedStrAttr:$...",
  "ArrayRef payloadOps", ".cast()", "tensor"). TODO: restore them from the
  target files before applying.
- Leading whitespace of the hunk lines was reconstructed; apply with
  --ignore-whitespace or regenerate the patch if context does not match.

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -309,6 +309,12 @@
   /// Add a analysis state initializer that initializes the specified
   /// dialect-specific analysis state.
   void addDialectStateInitializer(StringRef name, const DialectStateInitFn &fn);
+
+  /// Parse a layout map option from its textual form. Accepted values are
+  /// "fully-dynamic-layout-map", "identity-layout-map" and "infer-layout-map";
+  /// any other string is a fatal error.
+  static BufferizationOptions::LayoutMapOption
+  parseLayoutMapOption(const std::string &s);
 };
 
 /// Specify fine-grain relationship between buffers to enable more analysis.
diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
--- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
@@ -48,7 +48,8 @@
       DefaultValuedAttr:$create_deallocs,
       DefaultValuedAttr:$target_is_module,
       DefaultValuedAttr:$test_analysis_only,
-      DefaultValuedAttr:$print_conflicts);
+      DefaultValuedAttr:$print_conflicts,
+      DefaultValuedStrAttr:$function_boundary_type_conversion);
 
   let results = (outs);
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -264,6 +264,20 @@
                                              memorySpace);
 }
 
+// Parse the textual form of a layout map option.
+BufferizationOptions::LayoutMapOption
+BufferizationOptions::parseLayoutMapOption(const std::string &s) {
+  if (s == "fully-dynamic-layout-map")
+    return BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap;
+  if (s == "identity-layout-map")
+    return BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
+  if (s == "infer-layout-map")
+    return BufferizationOptions::LayoutMapOption::InferLayoutMap;
+  // The option string is user input (pass option or op attribute), so an
+  // unknown value must fail deterministically, not be assumed unreachable.
+  llvm::report_fatal_error("invalid layout map option");
+}
+
 // Default constructor for BufferizationOptions.
 BufferizationOptions::BufferizationOptions()
     : unknownTypeConverterFn(defaultUnknownTypeConverter) {}
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -34,6 +34,8 @@
   options.createDeallocs = getCreateDeallocs();
   options.testAnalysisOnly = getTestAnalysisOnly();
   options.printConflicts = getPrintConflicts();
+  options.functionBoundaryTypeConversion =
+      options.parseLayoutMapOption(getFunctionBoundaryTypeConversion().str());
 
   ArrayRef payloadOps = state.getPayloadOps(getTarget());
   for (Operation *target : payloadOps) {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -163,17 +163,6 @@
   }
 };
 
-static BufferizationOptions::LayoutMapOption
-parseLayoutMapOption(const std::string &s) {
-  if (s == "fully-dynamic-layout-map")
-    return BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap;
-  if (s == "identity-layout-map")
-    return BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
-  if (s == "infer-layout-map")
-    return BufferizationOptions::LayoutMapOption::InferLayoutMap;
-  llvm_unreachable("invalid layout map option");
-}
-
 static OneShotBufferizationOptions::AnalysisHeuristic
 parseHeuristicOption(const std::string &s) {
   if (s == "bottom-up")
@@ -208,7 +197,7 @@
       opt.copyBeforeWrite = copyBeforeWrite;
       opt.createDeallocs = createDeallocs;
       opt.functionBoundaryTypeConversion =
-          parseLayoutMapOption(functionBoundaryTypeConversion);
+          opt.parseLayoutMapOption(functionBoundaryTypeConversion);
       if (mustInferMemorySpace)
         opt.defaultMemorySpace = None;
       opt.printConflicts = printConflicts;
@@ -217,7 +206,7 @@
 
       // Configure type converter.
       BufferizationOptions::LayoutMapOption unknownTypeConversionOption =
-          parseLayoutMapOption(unknownTypeConversion);
+          opt.parseLayoutMapOption(unknownTypeConversion);
       opt.unknownTypeConverterFn = [=](Value value, unsigned memorySpace,
                                        const BufferizationOptions &options) {
         auto tensorType = value.getType().cast();
diff --git a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
@@ -96,3 +96,26 @@
     return %0 : tensor
   }
 }
+
+// -----
+
+// Test we use identity layout at function boundaries.
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+  transform.bufferization.one_shot_bufferize %arg1 {
+    target_is_module = true,
+    bufferize_function_boundaries = true,
+    function_boundary_type_conversion = "identity-layout-map" }
+}
+
+// CHECK: func.func @matmul(
+// CHECK-SAME: %[[A:.*]]: memref<12x9xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<9x6xf32>,
+// CHECK-SAME: %[[C:.*]]: memref<12x6xf32>) -> memref<12x6xf32> {
+func.func @matmul(%A: tensor<12x9xf32>, %B: tensor<9x6xf32>, %C: tensor<12x6xf32>) -> tensor<12x6xf32> {
+  // CHECK: linalg.matmul ins(%[[A]], %[[B]] : memref<12x9xf32>, memref<9x6xf32>) outs(%[[C]] : memref<12x6xf32>)
+  %D = linalg.matmul ins(%A, %B: tensor<12x9xf32>, tensor<9x6xf32>) outs(%C: tensor<12x6xf32>) -> tensor<12x6xf32>
+  // CHECK: return %[[C]] : memref<12x6xf32>
+  return %D : tensor<12x6xf32>
+}