diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -255,23 +255,44 @@
   SetVector<Value> toProjectOut;
   for (scf::ForOp loop : loops) {
     auto ub = loop.upperBound();
-    if (isDefinedOutsideOrConstant(outerLimit, ub))
-      continue;
-    // Compute a backward slice up to, but not including, `outerLimit`.
-    SetVector<Operation *> backwardSlice;
-    getBackwardSlice(ub, &backwardSlice, [&](Operation *op) {
-      return outerLimit->isProperAncestor(op);
+    // Compute the index computation slice for the upper bound `ub`.
+    SetVector<Value> indexEdges;
+    indexEdges.insert(ub);
+    SetVector<Operation *> indexSlice;
+    getBackwardSlice(ub, &indexSlice, [&](Operation *op) {
+      // Continue only along the index operands of the ForOp.
+      if (auto forOp = dyn_cast<scf::ForOp>(op)) {
+        if (!outerLimit->isAncestor(op))
+          return false;
+        if (!indexEdges.contains(forOp.getInductionVar()))
+          return false;
+        indexEdges.insert(forOp.lowerBound());
+        indexEdges.insert(forOp.upperBound());
+        indexEdges.insert(forOp.step());
+        return true;
+      }
+      // All supported index operations have one result.
+      assert(op->getNumResults() == 1 &&
+             "expect operations to have one result");
+      if (!indexEdges.contains(op->getResult(0)))
+        return false;
+      for (Value value : op->getOperands())
+        indexEdges.insert(value);
+      return true;
     });
-    backwardSlice.insert(ub.getDefiningOp());
+    indexSlice.insert(ub.getDefiningOp());
 
     // Iterate over all ops in the slice and compose them in the constraints.
-    for (Operation *op : llvm::reverse(backwardSlice)) {
-      if (!isa<AffineApplyOp, AffineMinOp, scf::ForOp>(op))
-        return failure();
-      if (isa<scf::ForOp>(op))
+    for (Operation *op : llvm::reverse(indexSlice)) {
+      // All ForOps have previously been added to the constraints and ConstantOp
+      // and DimOp are terminals of the index computation.
+      if (isa<scf::ForOp, arith::ConstantOp, tensor::DimOp>(op))
         continue;
 
-      // Ensure there is a
+      // Check all index computation operations are supported.
+      if (!isa<AffineApplyOp, AffineMinOp>(op))
+        return failure();
+      // Ensure there is an id.
       auto ensureIdFailed = [&](Value v) {
         if (constraints.containsId(v)) {
           unsigned pos;
@@ -289,6 +310,8 @@
       // All supported ops have 1 result.
       // TODO: extend when needed.
+      assert(op->getNumResults() == 1 &&
+             "expect operations to have one result");
       toProjectOut.insert(op->getResult(0));
 
       // Compose supported ops.
diff --git a/mlir/test/Dialect/Linalg/hoist-padding.mlir b/mlir/test/Dialect/Linalg/hoist-padding.mlir
--- a/mlir/test/Dialect/Linalg/hoist-padding.mlir
+++ b/mlir/test/Dialect/Linalg/hoist-padding.mlir
@@ -168,18 +168,16 @@
   // CHECK: scf.for %[[I:[0-9a-z]+]] =
   //
   // CHECK: %[[MR8:.*]] = affine.min #[[$MIN_REST8]](%[[I]])
-  // CHECK: %[[D0:.*]] = affine.apply #[[$DIV4]](%[[MR8]])
   // Init tensor and pack.
-  // CHECK: %[[INIT_PACKED_A:.*]] = linalg.init_tensor [%[[D0]], 2, 2] : tensor
-  // CHECK: %[[CAST_INIT_PACKED_A:.*]] = tensor.cast %[[INIT_PACKED_A]] : tensor to tensor
+  // CHECK: %[[INIT_PACKED_A:.*]] = linalg.init_tensor [2, 2, 2] : tensor<2x2x2xf32>
+  // CHECK: %[[CAST_INIT_PACKED_A:.*]] = tensor.cast %[[INIT_PACKED_A]] : tensor<2x2x2xf32> to tensor
   // CHECK: %[[PACKED_A:.*]] = scf.for %[[II:[0-9a-z]+]] = {{.*}} iter_args(%{{.*}} = %[[CAST_INIT_PACKED_A]]) -> (tensor) {
   // CHECK: scf.for %[[III:[0-9a-z]+]] =
   // CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, %{{.*}}, 0] [1, 1, 2] [1, 1, 1] : tensor<2xf32> into tensor
   //
-  // CHECK: %[[D0_2:.*]] = affine.apply #[[$DIV4]](%[[MR8]])
   // Init tensor and pack.
-  // CHECK: %[[INIT_PACKED_B:.*]] = linalg.init_tensor [%[[D0_2]], 2, 2] : tensor
-  // CHECK: %[[CAST_INIT_PACKED_B:.*]] = tensor.cast %[[INIT_PACKED_B]] : tensor to tensor
+  // CHECK: %[[INIT_PACKED_B:.*]] = linalg.init_tensor [2, 2, 2] : tensor<2x2x2xf32>
+  // CHECK: %[[CAST_INIT_PACKED_B:.*]] = tensor.cast %[[INIT_PACKED_B]] : tensor<2x2x2xf32> to tensor
   // CHECK: %[[PACKED_B:.*]] = scf.for %[[II_2:[0-9a-z]+]] = {{.*}} iter_args(%{{.*}} = %[[CAST_INIT_PACKED_B]]) -> (tensor) {
   // CHECK: scf.for %[[III_2:[0-9a-z]+]] =
   // CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, %{{.*}}, 0] [1, 1, 2] [1, 1, 1] : tensor<2xf32> into tensor
diff --git a/mlir/test/Dialect/Linalg/pad-and-hoist.mlir b/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
--- a/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
+++ b/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
@@ -236,7 +236,7 @@
     %4 = tensor.extract_slice %arg6[%arg3, %arg5] [%2, %3] [1, 1] : tensor<24x25xf32> to tensor<?x?xf32>
 
     // Packing the second input operand.
-    // CHECK-DOUBLE: = linalg.init_tensor [%{{.*}}, 12, 6]
+    // CHECK-DOUBLE: = linalg.init_tensor [3, 12, 6]
     // CHECK-DOUBLE: %[[PT1:.*]] = scf.for %[[PIV1:[0-9a-z]+]] =
     // CHECK-DOUBLE: %[[PIDX1:.*]] = affine.apply #[[DIV6]](%[[PIV1]])
     // CHECK-DOUBLE: %[[T3:.*]] = tensor.extract_slice %[[ARG1]]