diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -11,13 +11,16 @@ include "mlir/Dialect/Transform/IR/TransformDialect.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" include "mlir/Dialect/PDL/IR/PDLTypes.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpBase.td" +def Transform_MemRefAllocOp : Transform_ConcreteOpType<"memref.alloc">; + def MemRefMultiBufferOp : Op { + DeclareOpInterfaceMethods]> { let summary = "Multibuffers an allocation"; let description = [{ Transformation to do multi-buffering/array expansion to remove @@ -33,19 +36,13 @@ }]; let arguments = - (ins PDL_Operation:$target, + (ins Transform_MemRefAllocOp:$target, ConfinedAttr:$factor); let results = (outs PDL_Operation:$transformed); - let assemblyFormat = "$target attr-dict"; - - let extraClassDeclaration = [{ - ::mlir::DiagnosedSilenceableFailure applyToOne( - memref::AllocOp target, - ::mlir::transform::ApplyToEachResultList &results, - ::mlir::transform::TransformState &state); - }]; + let assemblyFormat = + "$target attr-dict `:` functional-type(operands, results)"; } #endif // MEMREF_TRANSFORM_OPS diff --git a/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt --- a/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt @@ -12,6 +12,7 @@ MLIRArithDialect MLIRIR MLIRPDLDialect + MLIRLoopLikeInterface MLIRMemRefDialect MLIRMemRefTransforms MLIRTransformDialect diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Interfaces/LoopLikeInterface.h" using namespace mlir; @@ -21,15 +22,33 @@ // MemRefMultiBufferOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::applyToOne( - memref::AllocOp target, transform::ApplyToEachResultList &results, +DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply( + transform::TransformResults &transformResults, transform::TransformState &state) { - auto newBuffer = memref::multiBuffer(target, getFactor()); - if (failed(newBuffer)) - return emitSilenceableFailure(target->getLoc()) - << "op failed to multibuffer"; + SmallVector results; + ArrayRef payloadOps = state.getPayloadOps(getTarget()); + for (auto *op : payloadOps) { + bool canApplyMultiBuffer = true; + auto target = cast(op); + // Skip allocations not used in a loop. + for (Operation *user : target->getUsers()) { + auto loop = user->getParentOfType(); + if (!loop) { + canApplyMultiBuffer = false; + break; + } + } + if (!canApplyMultiBuffer) + continue; - results.push_back(*newBuffer); + auto newBuffer = memref::multiBuffer(target, getFactor()); + if (failed(newBuffer)) + return emitSilenceableFailure(target->getLoc()) + << "op failed to multibuffer"; + + results.push_back(*newBuffer); + } + transformResults.set(getResult().cast(), results); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir --- a/mlir/test/Dialect/MemRef/transform-ops.mlir +++ b/mlir/test/Dialect/MemRef/transform-ops.mlir @@ -30,26 +30,102 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !pdl.operation - %1 = transform.memref.multibuffer %0 {factor = 2 : i64} + %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !transform.op<"memref.alloc"> + %1 = transform.memref.multibuffer %0 {factor = 2 : i64} : (!transform.op<"memref.alloc">) -> !pdl.operation // Verify that the returned handle is usable. transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation } // ----- -// Trying to use multibuffer on alloc that are used outside of loops is -// going to fail. +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> ((d0 floordiv 4) mod 2)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> + +// CHECK-LABEL: func @multi_buffer_on_affine_loop +func.func @multi_buffer_on_affine_loop(%in: memref<16xf32>) { + // CHECK: %[[A:.*]] = memref.alloc() : memref<2x4xf32> + // expected-remark @below {{transformed}} + %tmp = memref.alloc() : memref<4xf32> + + // CHECK: %[[C0:.*]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + + // CHECK: affine.for %[[IV:.*]] = 0 + affine.for %i0 = 0 to 16 step 4 { + // CHECK: %[[I:.*]] = affine.apply #[[$MAP0]](%[[IV]]) + // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: ?>> + %1 = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> + // CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4xf32, #[[$MAP1]]> to memref<4xf32, strided<[1], offset: ?>> + memref.copy %1, %tmp : memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32> + + "some_use"(%tmp) : (memref<4xf32>) ->() + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !transform.op<"memref.alloc"> + %1 = transform.memref.multibuffer %0 {factor = 2 : i64} : (!transform.op<"memref.alloc">) -> !pdl.operation + // Verify that the returned handle is usable. + transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation +} + +// ----- + +// Trying to use multibuffer on allocs that are used in different loops +// with none dominating the other is going to fail. // Check that we emit a proper error for that. -func.func @multi_buffer_uses_outside_of_loop(%in: memref<16xf32>) { +func.func @multi_buffer_uses_with_no_loop_dominator(%in: memref<16xf32>, %cond: i1) { // expected-error @below {{op failed to multibuffer}} %tmp = memref.alloc() : memref<4xf32> - "some_outside_loop_use"(%tmp) : (memref<4xf32>) -> () + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + scf.if %cond { + scf.for %i0 = %c0 to %c16 step %c4 { + %var = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> + memref.copy %var, %tmp : memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32> + + "some_use"(%tmp) : (memref<4xf32>) ->() + } + } + + scf.for %i0 = %c0 to %c16 step %c4 { + %1 = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> + memref.copy %1, %tmp : memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32> + + "some_use"(%tmp) : (memref<4xf32>) ->() + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !transform.op<"memref.alloc"> + %1 = transform.memref.multibuffer %0 {factor = 2 : i64} : (!transform.op<"memref.alloc">) -> !pdl.operation +} + +// ----- + +// Make sure the multibuffer operation is typed so that it only supports +// memref.alloc. +// Check that we emit an error if we try to match something else. +func.func @multi_buffer_reject_alloca(%in: memref<16xf32>, %cond: i1) { + %tmp = memref.alloca() : memref<4xf32> %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index %c16 = arith.constant 16 : index + scf.if %cond { + scf.for %i0 = %c0 to %c16 step %c4 { + %var = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> + memref.copy %var, %tmp : memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32> + + "some_use"(%tmp) : (memref<4xf32>) ->() + } + } scf.for %i0 = %c0 to %c16 step %c4 { %1 = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> @@ -62,6 +138,50 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !pdl.operation - %1 = transform.memref.multibuffer %0 {factor = 2 : i64} + %0 = transform.structured.match ops{["memref.alloca"]} in %arg1 : (!pdl.operation) -> !transform.op<"memref.alloca"> + // expected-error @below {{'transform.memref.multibuffer' op operand #0 must be Transform IR handle to memref.alloc operations, but got '!transform.op<"memref.alloca">'}} + %1 = transform.memref.multibuffer %0 {factor = 2 : i64} : (!transform.op<"memref.alloca">) -> !pdl.operation +} + +// ----- + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> ((d0 floordiv 4) mod 2)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> + +// CHECK-LABEL: func @multi_buffer_one_alloc_with_use_outside_of_loop +// Make sure we manage to apply multi_buffer to the memref that is used in +// the loop (%tmp) and don't error out for the one that is not (%tmp2). +func.func @multi_buffer_one_alloc_with_use_outside_of_loop(%in: memref<16xf32>) { + // CHECK: %[[A:.*]] = memref.alloc() : memref<2x4xf32> + // expected-remark @below {{transformed}} + %tmp = memref.alloc() : memref<4xf32> + %tmp2 = memref.alloc() : memref<4xf32> + + "some_use_outside_of_loop"(%tmp2) : (memref<4xf32>) -> () + + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[C4:.*]] = arith.constant 4 : index + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + + // CHECK: scf.for %[[IV:.*]] = %[[C0]] + scf.for %i0 = %c0 to %c16 step %c4 { + // CHECK: %[[I:.*]] = affine.apply #[[$MAP0]](%[[IV]]) + // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: ?>> + %1 = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> + // CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4xf32, #[[$MAP1]]> to memref<4xf32, strided<[1], offset: ?>> + memref.copy %1, %tmp : memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32> + + "some_use"(%tmp) : (memref<4xf32>) ->() + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !transform.op<"memref.alloc"> + %1 = transform.memref.multibuffer %0 {factor = 2 : i64} : (!transform.op<"memref.alloc">) -> !pdl.operation + // Verify that the returned handle is usable. + transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -9954,6 +9954,7 @@ ":AffineDialect", ":ArithDialect", ":IR", + ":LoopLikeInterface", ":MemRefDialect", ":MemRefTransformOpsIncGen", ":MemRefTransforms",