diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -43,7 +43,8 @@
 ///   2. Pad op does not have a constant padding value.
 ///   3. There is no immediately enclosing scf::ForOp.
 ///   4. The backward slice from the pad op to the scf::ForOp to hoist above
-///      contains an unknown op with a region.
+///      contains an unknown op with non-index type operands, a region, or a
+///      memory effect.
 ///   5. The backward slice from the pad op to the scf::ForOp to hoist above is
 ///      empty.
 ///   6. The source tensor of pad op is not defined by an extract slice op.
@@ -80,7 +81,8 @@
   /// operands consumed by `padTensorOp` and `sliceOp` and drops the operations
   /// not part of this index computation. Afterwards, the filtered
   /// `backwardSlice` contains only the loops whose induction variable is used,
-  /// directly or indirectly, to index the padded tensor.
+  /// directly or indirectly, to index the padded tensor. The method returns
+  /// failure if the filtered backward slice contains an unexpected operation.
   ///
   /// Example:
   /// ```
@@ -96,8 +98,8 @@
   /// ```
   /// dropNonIndexDependencies(%padded_slice, %slice)
   /// removes [scf.for %k, linalg.fill(%cst, %arg1)] from backwardSlice.
-  void dropNonIndexDependencies(PadTensorOp padTensorOp,
-                                tensor::ExtractSliceOp sliceOp);
+  LogicalResult dropNonIndexDependencies(PadTensorOp padTensorOp,
+                                         tensor::ExtractSliceOp sliceOp);
 
   /// Encodes whether the analysis is valid and hoisting can proceed.
   bool valid;
@@ -209,18 +211,8 @@
   // Remove all ops in the backward slice that are not used to index the padded
   // tensor. In particular, keep `padTensorOp`, `sliceOp`, and the loop and
   // affine operations used for the index computation.
-  dropNonIndexDependencies(padTensorOp, sliceOp);
-
-  // Check if an op has a region it is either `padTensorOp`, a scf::ForOp, or a
-  // LinalgOp.
-  for (Operation *op : backwardSlice) {
-    if (op != padTensorOp && op->getNumRegions() > 0 &&
-        !isa<scf::ForOp, LinalgOp>(op)) {
-      LLVM_DEBUG(DBGS() << "Unsupported op with region: " << *op
-                        << " -> skip\n");
-      return;
-    }
-  }
+  if (failed(dropNonIndexDependencies(padTensorOp, sliceOp)))
+    return;
 
   // Add only the loops part of the filtered `backwardSlice` to the packing
   // loops. All other loops are not used to index the padded data and
@@ -239,8 +231,9 @@
   valid = true;
 }
 
-void HoistingAnalysis::dropNonIndexDependencies(
-    PadTensorOp padTensorOp, tensor::ExtractSliceOp sliceOp) {
+LogicalResult
+HoistingAnalysis::dropNonIndexDependencies(PadTensorOp padTensorOp,
+                                           tensor::ExtractSliceOp sliceOp) {
   // Set of all values used for index computation.
   SetVector<Value> indexEdges;
 
@@ -289,7 +282,7 @@
     // Add the index operands of the loop if its induction variable is
     // used for index computation.
     if (auto forOp = dyn_cast<scf::ForOp>(op)) {
-      if (indexEdges.contains(forOp.getInductionVar())) {
+      if (!hasIndexResult(op) && indexEdges.contains(forOp.getInductionVar())) {
         addIndexOperandsToIndexEdges(op);
         continue;
       }
@@ -298,6 +291,21 @@
     // used for index computation.
     if (hasIndexResult(op)) {
      addIndexOperandsToIndexEdges(op);
+      // Check that all operands of the remaining operations have index type.
+      if (llvm::any_of(op->getOperandTypes(),
+                       [](Type type) { return !type.isIndex(); })) {
+        LLVM_DEBUG(DBGS() << "Unsupported op with non-index type operands: "
+                          << *op << " -> skip\n");
+        return failure();
+      }
+      // Check that the remaining operations have no regions or memory effects.
+      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
+      bool hasMemoryEffect = effectInterface && !effectInterface.hasNoEffect();
+      if (hasMemoryEffect || op->getNumRegions() != 0) {
+        LLVM_DEBUG(DBGS() << "Unsupported op with region or memory effect: "
+                          << *op << " -> skip\n");
+        return failure();
+      }
       continue;
     }
     // Remove all other operations not used by the index computation except for
@@ -305,6 +313,7 @@
     if (!isa<arith::ConstantOp>(op))
       backwardSlice.remove(op);
   }
+  return success();
 }
 
 SmallVector<Value>
@@ -416,18 +425,13 @@
     if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
       if (bvm.lookupOrDefault(sliceOp.source()) == packedTensor)
         continue;
-    auto effects = dyn_cast<MemoryEffectOpInterface>(op);
-    bool hasNoEffects = !effects || effects.hasNoEffect();
-    if (hasNoEffects &&
-        (op->getNumRegions() == 0 || isa<PadTensorOp>(op))) {
+    // Clone all operations except loops.
+    auto forOp = dyn_cast<scf::ForOp>(op);
+    if (!forOp) {
       b.clone(*op, bvm);
       continue;
     }
-    // TODO: support more cases as they appear.
-    auto forOp = dyn_cast<scf::ForOp>(op);
-    assert(forOp && llvm::is_contained(analysis.packingLoops, forOp) &&
-           "expect an scf::ForOp that is a packing loop");
-
+    // Create a packing loop that takes `packedTensor` as an iteration argument.
     auto clonedForOp = b.create<scf::ForOp>(
         loc, bvm.lookupOrDefault(forOp.lowerBound()),
         bvm.lookupOrDefault(forOp.upperBound()),
diff --git a/mlir/test/Dialect/Linalg/pad-and-hoist.mlir b/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
--- a/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
+++ b/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
@@ -436,3 +436,154 @@
   return %0 : tensor<24x25xf32>
 }
 
+// -----
+
+#map0 = affine_map<(d0) -> (5, -d0 + 24)>
+#map1 = affine_map<(d0) -> (7, -d0 + 25)>
+#map2 = affine_map<(d0) -> (-d0 + 5)>
+#map3 = affine_map<(d0) -> (-d0 + 7)>
+
+// CHECK: unexpected_operation
+// CHECK-DOUBLE: unexpected_operation
+// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]*]]: memref<?xindex>
+// CHECK-SAME: %[[ARG4:[0-9a-zA-Z]*]]: i32
+func @unexpected_operation(%arg0: tensor<24x12xf32>,
+                           %arg1: tensor<12x25xf32>,
+                           %arg2: tensor<24x25xf32>,
+                           %arg3: memref<?xindex>,
+                           %arg4: i32) -> tensor<24x25xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c5 = arith.constant 5 : index
+  %c7 = arith.constant 7 : index
+  %c6 = arith.constant 6 : index
+  %c24 = arith.constant 24 : index
+  %c25 = arith.constant 25 : index
+  %c12 = arith.constant 12 : index
+  %c0 = arith.constant 0 : index
+
+  // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  %0 = scf.for %arg5 = %c0 to %c24 step %c5 iter_args(%arg6 = %arg2) -> (tensor<24x25xf32>) {
+
+    // CHECK-NEXT: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+    %1 = scf.for %arg7 = %c0 to %c25 step %c7 iter_args(%arg8 = %arg6) -> (tensor<24x25xf32>) {
+
+      // CHECK-NEXT: scf.for %[[IV2:[0-9a-zA-Z]*]] =
+      %2 = scf.for %arg9 = %c0 to %c12 step %c6 iter_args(%arg10 = %arg8) -> (tensor<24x25xf32>) {
+        %3 = affine.min #map0(%arg5)
+        %4 = tensor.extract_slice %arg0[%arg5, %arg9] [%3, 6] [1, 1] : tensor<24x12xf32> to tensor<?x6xf32>
+        %5 = affine.min #map1(%arg7)
+        %6 = tensor.extract_slice %arg1[%arg9, %arg7] [6, %5] [1, 1] : tensor<12x25xf32> to tensor<6x?xf32>
+        %7 = tensor.extract_slice %arg10[%arg5, %arg7] [%3, %5] [1, 1] : tensor<24x25xf32> to tensor<?x?xf32>
+        %8 = affine.apply #map2(%3)
+
+        // Check we cannot hoist due to an unexpected operation with a memory effect.
+        // CHECK: %[[IDX0:.*]] = memref.load %[[ARG3]]
+        // CHECK: %[[T0:.*]] = linalg.pad_tensor {{.*}}, %[[IDX0]]
+        %9 = memref.load %arg3[%c0] : memref<?xindex>
+        %10 = linalg.pad_tensor %4 nofold low[%c0, %c0] high[%8, %9] {
+        ^bb0(%arg11: index, %arg12: index):  // no predecessors
+          linalg.yield %cst : f32
+        } : tensor<?x6xf32> to tensor<5x6xf32>
+        %11 = affine.apply #map3(%5)
+
+        // Check we cannot hoist due to an unexpected operation with a non-index operand.
+        // CHECK: %[[IDX1:.*]] = arith.index_cast %[[ARG4]]
+        // CHECK: %[[T1:.*]] = linalg.pad_tensor {{.*}}[%[[IDX1]]
+        %12 = arith.index_cast %arg4 : i32 to index
+        %13 = linalg.pad_tensor %6 nofold low[%c0, %c0] high[%12, %11] {
+        ^bb0(%arg11: index, %arg12: index):  // no predecessors
+          linalg.yield %cst : f32
+        } : tensor<6x?xf32> to tensor<6x7xf32>
+        %14 = linalg.pad_tensor %7 low[%c0, %c0] high[%8, %11] {
+        ^bb0(%arg11: index, %arg12: index):  // no predecessors
+          linalg.yield %cst : f32
+        } : tensor<?x?xf32> to tensor<5x7xf32>
+
+        // CHECK: = linalg.matmul ins(%[[T0]], %[[T1]]
+        %15 = linalg.matmul ins(%10, %13 : tensor<5x6xf32>, tensor<6x7xf32>) outs(%14 : tensor<5x7xf32>) -> tensor<5x7xf32>
+        %16 = tensor.extract_slice %15[0, 0] [%3, %5] [1, 1] : tensor<5x7xf32> to tensor<?x?xf32>
+        %17 = tensor.insert_slice %16 into %arg10[%arg5, %arg7] [%3, %5] [1, 1] : tensor<?x?xf32> into tensor<24x25xf32>
+        scf.yield %17 : tensor<24x25xf32>
+      }
+      scf.yield %2 : tensor<24x25xf32>
+    }
+    scf.yield %1 : tensor<24x25xf32>
+  }
+  return %0 : tensor<24x25xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0) -> (5, -d0 + 24)>
+#map1 = affine_map<(d0) -> (7, -d0 + 25)>
+#map2 = affine_map<(d0) -> (-d0 + 5)>
+#map3 = affine_map<(d0) -> (-d0 + 7)>
+
+// CHECK: unexpected_loop
+// CHECK-DOUBLE: unexpected_loop
+// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]*]]: index
+func @unexpected_loop(%arg0: tensor<24x12xf32>,
+                      %arg1: tensor<12x25xf32>,
+                      %arg2: tensor<24x25xf32>,
+                      %arg3: index) -> tensor<24x25xf32> {
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c25 = arith.constant 25 : index
+  %c24 = arith.constant 24 : index
+  %c6 = arith.constant 6 : index
+  %c7 = arith.constant 7 : index
+  %c5 = arith.constant 5 : index
+  %cst = arith.constant 0.000000e+00 : f32
+
+  // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  %0 = scf.for %arg4 = %c0 to %c24 step %c5 iter_args(%arg5 = %arg2) -> (tensor<24x25xf32>) {
+
+    // CHECK-NEXT: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+    %1 = scf.for %arg6 = %c0 to %c25 step %c7 iter_args(%arg7 = %arg5) -> (tensor<24x25xf32>) {
+
+      // Check the padding of the second input operand is hoisted.
+      // CHECK: = linalg.pad_tensor
+
+      // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] =
+      %2 = scf.for %arg8 = %c0 to %c12 step %c6 iter_args(%arg9 = %arg7) -> (tensor<24x25xf32>) {
+        %3 = affine.min #map0(%arg4)
+        %4 = tensor.extract_slice %arg0[%arg4, %arg8] [%3, 6] [1, 1] : tensor<24x12xf32> to tensor<?x6xf32>
+        %5 = affine.min #map1(%arg6)
+        %6 = tensor.extract_slice %arg1[%arg8, %arg6] [6, %5] [1, 1] : tensor<12x25xf32> to tensor<6x?xf32>
+        %7 = tensor.extract_slice %arg9[%arg4, %arg6] [%3, %5] [1, 1] : tensor<24x25xf32> to tensor<?x?xf32>
+        %8 = affine.apply #map2(%3)
+
+        // Check we cannot hoist due to an unexpected operation that has a region.
+        // CHECK: %[[IDX0:.*]] = scf.for {{.*}} step %[[ARG3]]
+        // CHECK: %[[T0:.*]] = linalg.pad_tensor {{.*}}, %[[IDX0]]
+        %9 = scf.for %arg10 = %c0 to %c24 step %arg3 iter_args(%arg11 = %c0) -> (index) {
+          %17 = arith.addi %arg3, %arg11 : index
+          scf.yield %17 : index
+        }
+        %10 = linalg.pad_tensor %4 nofold low[%c0, %c0] high[%8, %9] {
+        ^bb0(%arg10: index, %arg11: index):  // no predecessors
+          linalg.yield %cst : f32
+        } : tensor<?x6xf32> to tensor<5x6xf32>
+        %11 = affine.apply #map3(%5)
+        %12 = linalg.pad_tensor %6 nofold low[%c0, %c0] high[%c0, %11] {
+        ^bb0(%arg10: index, %arg11: index):  // no predecessors
+          linalg.yield %cst : f32
+        } : tensor<6x?xf32> to tensor<6x7xf32>
+        %13 = linalg.pad_tensor %7 low[%c0, %c0] high[%8, %11] {
+        ^bb0(%arg10: index, %arg11: index):  // no predecessors
+          linalg.yield %cst : f32
+        } : tensor<?x?xf32> to tensor<5x7xf32>
+
+        // CHECK: = linalg.matmul ins(%[[T0]]
+        %14 = linalg.matmul ins(%10, %12 : tensor<5x6xf32>, tensor<6x7xf32>) outs(%13 : tensor<5x7xf32>) -> tensor<5x7xf32>
+        %15 = tensor.extract_slice %14[0, 0] [%3, %5] [1, 1] : tensor<5x7xf32> to tensor<?x?xf32>
+        %16 = tensor.insert_slice %15 into %arg9[%arg4, %arg6] [%3, %5] [1, 1] : tensor<?x?xf32> into tensor<24x25xf32>
+        scf.yield %16 : tensor<24x25xf32>
+      }
+      scf.yield %2 : tensor<24x25xf32>
+    }
+    scf.yield %1 : tensor<24x25xf32>
+  }
+  return %0 : tensor<24x25xf32>
+}