diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -255,23 +255,39 @@ SetVector toProjectOut; for (scf::ForOp loop : loops) { auto ub = loop.upperBound(); - if (isDefinedOutsideOrConstant(outerLimit, ub)) - continue; // Compute a backward slice up to, but not including, `outerLimit`. - SetVector backwardSlice; - getBackwardSlice(ub, &backwardSlice, [&](Operation *op) { - return outerLimit->isProperAncestor(op); + SetVector indexEdges; + indexEdges.insert(ub); + SetVector indexSlice; + getBackwardSlice(ub, &indexSlice, [&](Operation *op) { + // Continue only along the index operands of the for loop. + if (auto forOp = dyn_cast(op)) { + if (!indexEdges.contains(forOp.getInductionVar())) + return false; + indexEdges.insert(forOp.lowerBound()); + indexEdges.insert(forOp.upperBound()); + indexEdges.insert(forOp.step()); + return true; + } + // All supported operations have one result. + assert(op->getNumResults() == 1 && + "expect operations to have one result"); + if (!indexEdges.contains(op->getResult(0))) + return false; + for (Value value : op->getOperands()) + indexEdges.insert(value); + return true; }); - backwardSlice.insert(ub.getDefiningOp()); + indexSlice.insert(ub.getDefiningOp()); // Iterate over all ops in the slice and compose them in the constraints. - for (Operation *op : llvm::reverse(backwardSlice)) { - if (!isa(op)) - return failure(); - if (isa(op)) + for (Operation *op : llvm::reverse(indexSlice)) { + if (isa(op)) continue; - // Ensure there is a + if (!isa(op)) + return failure(); + // Ensure there is an id. auto ensureIdFailed = [&](Value v) { if (constraints.containsId(v)) { unsigned pos; @@ -289,6 +305,8 @@ // All supported ops have 1 result. // TODO: extend when needed. + assert(op->getNumResults() == 1 && + "expect operations to have one result"); toProjectOut.insert(op->getResult(0)); // Compose supported ops. diff --git a/mlir/test/Dialect/Linalg/hoist-padding.mlir b/mlir/test/Dialect/Linalg/hoist-padding.mlir --- a/mlir/test/Dialect/Linalg/hoist-padding.mlir +++ b/mlir/test/Dialect/Linalg/hoist-padding.mlir @@ -168,18 +168,16 @@ // CHECK: scf.for %[[I:[0-9a-z]+]] = // // CHECK: %[[MR8:.*]] = affine.min #[[$MIN_REST8]](%[[I]]) - // CHECK: %[[D0:.*]] = affine.apply #[[$DIV4]](%[[MR8]]) // Init tensor and pack. - // CHECK: %[[INIT_PACKED_A:.*]] = linalg.init_tensor [%[[D0]], 2, 2] : tensor - // CHECK: %[[CAST_INIT_PACKED_A:.*]] = tensor.cast %[[INIT_PACKED_A]] : tensor to tensor + // CHECK: %[[INIT_PACKED_A:.*]] = linalg.init_tensor [2, 2, 2] : tensor<2x2x2xf32> + // CHECK: %[[CAST_INIT_PACKED_A:.*]] = tensor.cast %[[INIT_PACKED_A]] : tensor<2x2x2xf32> to tensor // CHECK: %[[PACKED_A:.*]] = scf.for %[[II:[0-9a-z]+]] = {{.*}} iter_args(%{{.*}} = %[[CAST_INIT_PACKED_A]]) -> (tensor) { // CHECK: scf.for %[[III:[0-9a-z]+]] = // CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, %{{.*}}, 0] [1, 1, 2] [1, 1, 1] : tensor<2xf32> into tensor // - // CHECK: %[[D0_2:.*]] = affine.apply #[[$DIV4]](%[[MR8]]) // Init tensor and pack. - // CHECK: %[[INIT_PACKED_B:.*]] = linalg.init_tensor [%[[D0_2]], 2, 2] : tensor - // CHECK: %[[CAST_INIT_PACKED_B:.*]] = tensor.cast %[[INIT_PACKED_B]] : tensor to tensor + // CHECK: %[[INIT_PACKED_B:.*]] = linalg.init_tensor [2, 2, 2] : tensor<2x2x2xf32> + // CHECK: %[[CAST_INIT_PACKED_B:.*]] = tensor.cast %[[INIT_PACKED_B]] : tensor<2x2x2xf32> to tensor // CHECK: %[[PACKED_B:.*]] = scf.for %[[II_2:[0-9a-z]+]] = {{.*}} iter_args(%{{.*}} = %[[CAST_INIT_PACKED_B]]) -> (tensor) { // CHECK: scf.for %[[III_2:[0-9a-z]+]] = // CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, %{{.*}}, 0] [1, 1, 2] [1, 1, 1] : tensor<2xf32> into tensor diff --git a/mlir/test/Dialect/Linalg/pad-and-hoist.mlir b/mlir/test/Dialect/Linalg/pad-and-hoist.mlir --- a/mlir/test/Dialect/Linalg/pad-and-hoist.mlir +++ b/mlir/test/Dialect/Linalg/pad-and-hoist.mlir @@ -128,7 +128,7 @@ %4 = tensor.extract_slice %arg6[%arg3, %arg5] [%2, %3] [1, 1] : tensor<24x25xf32> to tensor // Packing the second input operand. - // CHECK-DOUBLE: = linalg.init_tensor [%{{.*}}, 12, 6] + // CHECK-DOUBLE: = linalg.init_tensor [3, 12, 6] // CHECK-DOUBLE: = linalg.pad_tensor {{.*}} nofold // CHECK-DOUBLE: scf.for %[[IV2:[0-9a-zA-Z]*]] =