diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td --- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td @@ -19,8 +19,8 @@ def OneShotBufferizeOp : Op, - DeclareOpInterfaceMethods]> { + [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface, + DeclareOpInterfaceMethods]> { let description = [{ Indicates that the given `target` op should be bufferized with One-Shot Bufferize. The bufferization can be configured with various attributes that @@ -28,10 +28,8 @@ `one-shot-bufferize` pass. More information can be found in the pass documentation. - If `target_is_module` is set, `target` must be a module. In that case the - `target` handle can be reused by other transform ops. When bufferizing other - ops, the `target` handled is freed after bufferization and can no longer be - used. + The targeted ops must be modules or functions. This is because there is + always a single, bufferized replacement op for such targets. Note: Only ops that implement `BufferizableOpInterface` are bufferized. All other ops are ignored if `allow_unknown_ops`. If `allow_unknown_ops` is @@ -39,24 +37,28 @@ Many ops implement `BufferizableOpInterface` via an external model. These external models must be registered when applying this transform op; otherwise, said ops would be considered non-bufferizable. + + #### Return modes + + This operation consumes the `target` handle and produces the `transformed` + handle. Targets must be modules if `bufferize_function_boundaries` is set.
}]; let arguments = ( - ins PDL_Operation:$target, + ins TransformHandleTypeInterface:$target, OptionalAttr:$function_boundary_type_conversion, DefaultValuedAttr:$allow_return_allocs, DefaultValuedAttr:$allow_unknown_ops, DefaultValuedAttr:$bufferize_function_boundaries, DefaultValuedAttr:$create_deallocs, - DefaultValuedAttr:$target_is_module, DefaultValuedAttr:$test_analysis_only, DefaultValuedAttr:$print_conflicts); - let results = (outs); + let results = (outs TransformHandleTypeInterface:$transformed); let assemblyFormat = [{ (`layout` `{` $function_boundary_type_conversion^ `}`)? - $target attr-dict + $target attr-dict `:` functional-type($target, results) }]; } diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/IR/FunctionInterfaces.h" using namespace mlir; using namespace mlir::bufferization; @@ -41,12 +42,12 @@ ArrayRef payloadOps = state.getPayloadOps(getTarget()); for (Operation *target : payloadOps) { + if (!isa(target)) + return emitSilenceableError() << "expected module or function target"; auto moduleOp = dyn_cast(target); - if (getTargetIsModule() && !moduleOp) - return emitSilenceableError() << "expected ModuleOp target"; if (options.bufferizeFunctionBoundaries) { if (!moduleOp) - return emitSilenceableError() << "expected ModuleOp target"; + return emitSilenceableError() << "expected module target"; if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options))) return emitSilenceableError() << "bufferization failed"; } else { @@ -55,21 +56,12 @@ } } + // This transform op is currently restricted to 
ModuleOps and function ops. + // Such ops are modified in-place. + transformResults.set(getTransformed().cast(), payloadOps); return DiagnosedSilenceableFailure::success(); } -void transform::OneShotBufferizeOp::getEffects( - SmallVectorImpl &effects) { - // Handles that are not modules are not longer usable. - if (!getTargetIsModule()) { - consumesHandle(getTarget(), effects); - } else { - onlyReadsHandle(getTarget(), effects); - } - - modifiesPayload(effects); -} - //===----------------------------------------------------------------------===// // EmptyTensorToAllocTensorOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir @@ -5,8 +5,7 @@ transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation - transform.bufferization.one_shot_bufferize %0 - {target_is_module = false} + %1 = transform.bufferization.one_shot_bufferize %0 : (!pdl.operation) -> !pdl.operation } // CHECK-LABEL: func @test_function( @@ -34,8 +33,8 @@ transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation - transform.bufferization.one_shot_bufferize %0 - {target_is_module = false, test_analysis_only = true} + %1 = transform.bufferization.one_shot_bufferize %0 + {test_analysis_only = true} : (!pdl.operation) -> !pdl.operation } // CHECK-LABEL: func @test_function_analysis( @@ -58,7 +57,7 @@ ^bb0(%arg1: !pdl.operation): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation // expected-error @+1 {{bufferization failed}} - 
transform.bufferization.one_shot_bufferize %0 {target_is_module = false} + %1 = transform.bufferization.one_shot_bufferize %0 : (!pdl.operation) -> !pdl.operation } func.func @test_unknown_op_failure() -> (tensor) { @@ -69,12 +68,10 @@ // ----- -// Test One-Shot Bufferize transform failure with a module op. - transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): // %arg1 is the module - transform.bufferization.one_shot_bufferize %arg1 + %0 = transform.bufferization.one_shot_bufferize %arg1 : (!pdl.operation) -> !pdl.operation } module { @@ -103,9 +100,8 @@ transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): - transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %arg1 { - target_is_module = true, - bufferize_function_boundaries = true } + %0 = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %arg1 + { bufferize_function_boundaries = true } : (!pdl.operation) -> !pdl.operation } // CHECK: func.func @matmul( diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir --- a/mlir/test/Dialect/LLVM/transform-e2e.mlir +++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir @@ -18,10 +18,11 @@ %1, %loops:3 = transform.structured.tile %0 [2, 2, 2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) %2 = get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation transform.structured.vectorize %2 - transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op - {bufferize_function_boundaries = true} + %b = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} + %module_op {bufferize_function_boundaries = true} + : (!pdl.operation) -> !pdl.operation - %f = transform.structured.match ops{["func.func"]} in %module_op + %f = transform.structured.match ops{["func.func"]} in %b : (!pdl.operation) -> !pdl.operation // TODO: group these lower-level controls into various properly named vector diff --git 
a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir --- a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir @@ -61,7 +61,7 @@ %1 = transform.get_result %0[0] : (!pdl.operation) -> !transform.any_value %2 = transform.structured.bufferize_to_allocation %1 // Make sure that One-Shot Bufferize can bufferize the rest. - transform.bufferization.one_shot_bufferize %arg1 + %3 = transform.bufferization.one_shot_bufferize %arg1 : (!pdl.operation) -> !pdl.operation } // ----- @@ -108,7 +108,7 @@ %1 = test_produce_value_handle_to_argument_of_parent_block %0, 0 : (!pdl.operation) -> !transform.any_value %2 = transform.structured.bufferize_to_allocation %1 {memory_space = 4} // Make sure that One-Shot Bufferize can bufferize the rest. - transform.bufferization.one_shot_bufferize %arg1 + %3 = transform.bufferization.one_shot_bufferize %arg1 : (!pdl.operation) -> !pdl.operation } // ----- diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -20,16 +20,18 @@ : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) %2 = get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation transform.structured.vectorize %2 - transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op - {bufferize_function_boundaries = true, allow_return_allocs = true} + %b = transform.bufferization.one_shot_bufferize + layout{IdentityLayoutMap} %module_op + {bufferize_function_boundaries = true, allow_return_allocs = true} + : (!pdl.operation) -> !pdl.operation - %f = transform.structured.match ops{["func.func"]} in %module_op + %f = transform.structured.match ops{["func.func"]} in %b : (!pdl.operation) -> 
!pdl.operation // TODO: group these lower-level controls into various properly named vector // lowering TD macros. %func = transform.vector.lower_contraction %f - lowering_strategy = "outerproduct" + lowering_strategy = "outerproduct" : (!pdl.operation) -> !pdl.operation %func_2 = transform.vector.apply_transfer_permutation_patterns %func