diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -238,20 +238,21 @@
   valid = true;
 }
 
+/// Add all index operands of `operation` to `indexEdges`. An index operand is
+/// an operand of type index.
+static void addIndexOperandsToIndexEdges(Operation *operation,
+                                         SetVector<Value> &indexEdges) {
+  for (Value operand : operation->getOperands())
+    if (operand.getType().isIndex())
+      indexEdges.insert(operand);
+}
+
 SetVector<Operation *>
 HoistingAnalysis::getIndexingLoops(PadTensorOp padTensorOp,
                                    tensor::ExtractSliceOp sliceOp) {
   // Set of all values used for index computation.
   SetVector<Value> indexEdges;
 
-  // Helper function that adds all index operands of an operation to
-  // `indexEdges`. An operand is an index operand if it is of index type.
-  auto addIndexOperandsToIndexEdges = [&](Operation *op) {
-    for (Value operand : op->getOperands())
-      if (operand.getType().isIndex())
-        indexEdges.insert(operand);
-  };
-
   // Starting from `padTensorOp` and `sliceOp` walk the use-def edges of index
   // type in `backwardSlice`. Add the index operands of an operation to
   // `indexEdges` if one of its results is an index edge found so far and store
@@ -276,7 +277,7 @@
     // Add the index operands of `padTensorOp` and `sliceOp` to start the
     // exploration of the index computation.
     if (op == padTensorOp || op == sliceOp) {
-      addIndexOperandsToIndexEdges(op);
+      addIndexOperandsToIndexEdges(op, indexEdges);
       continue;
     }
     // Add the index operands of the loop if its induction variable is
@@ -284,7 +285,7 @@
     // `indexingLoops`
     if (auto forOp = dyn_cast<scf::ForOp>(op)) {
       if (indexEdges.contains(forOp.getInductionVar())) {
-        addIndexOperandsToIndexEdges(op);
+        addIndexOperandsToIndexEdges(op, indexEdges);
         indexingLoops.insert(forOp);
         continue;
       }
@@ -293,7 +294,7 @@
     // used for index computation.
     if (llvm::any_of(op->getResults(),
                      [&](Value result) { return indexEdges.contains(result); }))
-      addIndexOperandsToIndexEdges(op);
+      addIndexOperandsToIndexEdges(op, indexEdges);
   }
   return indexingLoops;
 }
@@ -314,6 +315,8 @@
 /// - scf::ForOp are simply skipped.
 /// - AffineApplyOp are composed to replace the result by an equality.
 /// - AffineMinOp are composed by adding each entry as an upper bound.
+/// Additionally, the following terminal operations are handled:
+/// - DimOp and ConstantOp are skipped.
 /// If any other operation is met, return failure.
 // TODO: extend on a per-need basis.
 static LogicalResult
@@ -323,23 +326,60 @@
   SetVector<Value> toProjectOut;
   for (scf::ForOp loop : loops) {
     auto ub = loop.upperBound();
-    if (isDefinedOutsideOrConstant(outerLimit, ub))
-      continue;
-
-    // Compute a backward slice up to, but not including, `outerLimit`.
-    SetVector<Operation *> backwardSlice;
-    getBackwardSlice(ub, &backwardSlice, [&](Operation *op) {
-      return outerLimit->isProperAncestor(op);
+    // Set of all values used for index computation.
+    SetVector<Value> indexEdges;
+    indexEdges.insert(ub);
+
+    // Compute the backward slice `indexSlice` containing the index computation
+    // performed to obtain the upper bound `ub`. Starting from `ub` add the
+    // index operands of an operation to `indexEdges` if one of its results is
+    // an index edge. Otherwise, stop the slice computation. For a loop, check
+    // if its induction variable is an index edge.
+    //
+    // Example:
+    // ```
+    //   %c0 = arith.constant 0
+    //   scf.for %i = %c0 to ...
+    //     scf.for %j = %c0 to ...
+    //       %ub = affine.min #map(%i)
+    //       scf.for %k = %c0 to %ub
+    // ```
+    // After computing the backward slice we obtain:
+    // indexEdges = [%ub, %i, %c0]
+    // indexSlice = [arith.constant 0, scf.for %i, affine.min #map(%i)]
+    SetVector<Operation *> indexSlice;
+    getBackwardSlice(ub, &indexSlice, [&](Operation *op) {
+      // Continue only along the index operands of the ForOp.
+      if (auto forOp = dyn_cast<scf::ForOp>(op)) {
+        // Consider only loops part of the enclosing loops.
+        if (!outerLimit->isAncestor(op))
+          return false;
+        if (!indexEdges.contains(forOp.getInductionVar()))
+          return false;
+        addIndexOperandsToIndexEdges(op, indexEdges);
+        return true;
+      }
+      // All supported index operations have one result.
+      assert(op->getNumResults() == 1 &&
+             "expect operations to have one result");
+      if (!indexEdges.contains(op->getResult(0)))
+        return false;
+      addIndexOperandsToIndexEdges(op, indexEdges);
+      return true;
     });
-    backwardSlice.insert(ub.getDefiningOp());
+    indexSlice.insert(ub.getDefiningOp());
 
     // Iterate over all ops in the slice and compose them in the constraints.
-    for (Operation *op : llvm::reverse(backwardSlice)) {
-      if (!isa<scf::ForOp, AffineApplyOp, AffineMinOp>(op))
-        return failure();
-      if (isa<scf::ForOp>(op))
+    for (Operation *op : llvm::reverse(indexSlice)) {
+      // All ForOps have previously been added to the constraints and ConstantOp
+      // and DimOp are terminals of the index computation.
+      if (isa<scf::ForOp, arith::ConstantOp, tensor::DimOp>(op))
         continue;
-      // Ensure there is a
+      // Check all index computation operations are supported.
+      if (!isa<AffineApplyOp, AffineMinOp>(op))
+        return failure();
+      // Ensure there is an id.
       auto ensureIdFailed = [&](Value v) {
         if (constraints.containsId(v)) {
           unsigned pos;
@@ -357,6 +397,8 @@
 
       // All supported ops have 1 result.
       // TODO: extend when needed.
+      assert(op->getNumResults() == 1 &&
+             "expect operations to have one result");
       toProjectOut.insert(op->getResult(0));
 
       // Compose supported ops.
diff --git a/mlir/test/Dialect/Linalg/hoist-padding.mlir b/mlir/test/Dialect/Linalg/hoist-padding.mlir
--- a/mlir/test/Dialect/Linalg/hoist-padding.mlir
+++ b/mlir/test/Dialect/Linalg/hoist-padding.mlir
@@ -164,18 +164,16 @@
   // CHECK: scf.for %[[I:[0-9a-z]+]] =
   //
   // CHECK:   %[[MR8:.*]] = affine.min #[[$MIN_REST8]](%[[I]])
-  // CHECK:   %[[D0:.*]] = affine.apply #[[$DIV4]](%[[MR8]])
   // Init tensor and pack.
-  // CHECK:   %[[INIT_PACKED_A:.*]] = linalg.init_tensor [%[[D0]], 2, 2] : tensor<?x2x2xf32>
-  // CHECK:   %[[CAST_INIT_PACKED_A:.*]] = tensor.cast %[[INIT_PACKED_A]] : tensor<?x2x2xf32> to tensor<?x?x?xf32>
+  // CHECK:   %[[INIT_PACKED_A:.*]] = linalg.init_tensor [2, 2, 2] : tensor<2x2x2xf32>
+  // CHECK:   %[[CAST_INIT_PACKED_A:.*]] = tensor.cast %[[INIT_PACKED_A]] : tensor<2x2x2xf32> to tensor<?x?x?xf32>
   // CHECK:   %[[PACKED_A:.*]] = scf.for %[[II:[0-9a-z]+]] = {{.*}} iter_args(%{{.*}} = %[[CAST_INIT_PACKED_A]]) -> (tensor<?x?x?xf32>) {
   // CHECK:     scf.for %[[III:[0-9a-z]+]] =
   // CHECK:       tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, %{{.*}}, 0] [1, 1, 2] [1, 1, 1] : tensor<2xf32> into tensor<?x?x?xf32>
   //
-  // CHECK:   %[[D0_2:.*]] = affine.apply #[[$DIV4]](%[[MR8]])
   // Init tensor and pack.
-  // CHECK:   %[[INIT_PACKED_B:.*]] = linalg.init_tensor [%[[D0_2]], 2, 2] : tensor<?x2x2xf32>
-  // CHECK:   %[[CAST_INIT_PACKED_B:.*]] = tensor.cast %[[INIT_PACKED_B]] : tensor<?x2x2xf32> to tensor<?x?x?xf32>
+  // CHECK:   %[[INIT_PACKED_B:.*]] = linalg.init_tensor [2, 2, 2] : tensor<2x2x2xf32>
+  // CHECK:   %[[CAST_INIT_PACKED_B:.*]] = tensor.cast %[[INIT_PACKED_B]] : tensor<2x2x2xf32> to tensor<?x?x?xf32>
   // CHECK:   %[[PACKED_B:.*]] = scf.for %[[II_2:[0-9a-z]+]] = {{.*}} iter_args(%{{.*}} = %[[CAST_INIT_PACKED_B]]) -> (tensor<?x?x?xf32>) {
   // CHECK:     scf.for %[[III_2:[0-9a-z]+]] =
   // CHECK:       tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, %{{.*}}, 0] [1, 1, 2] [1, 1, 1] : tensor<2xf32> into tensor<?x?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/pad-and-hoist.mlir b/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
--- a/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
+++ b/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
@@ -233,7 +233,7 @@
     %4 = tensor.extract_slice %arg6[%arg3, %arg5] [%2, %3] [1, 1] : tensor<24x25xf32> to tensor<?x?xf32>
 
     // Packing the second input operand.
-    // CHECK-DOUBLE:   = linalg.init_tensor [%{{.*}}, 12, 6]
+    // CHECK-DOUBLE:   = linalg.init_tensor [3, 12, 6]
    // CHECK-DOUBLE:   %[[PT1:.*]] = scf.for %[[PIV1:[0-9a-z]+]] =
    // CHECK-DOUBLE:     %[[PIDX1:.*]] = affine.apply #[[DIV6]](%[[PIV1]])
    // CHECK-DOUBLE:     %[[T3:.*]] = tensor.extract_slice %[[ARG1]]
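Note: the following is a minimal, hypothetical MLIR sketch, not part of the patch or its test files; the function name, `#map`, and the constant bounds are invented for illustration. It shows the loop-nest shape the new `indexSlice` walk in the upper-bound folding is built to handle: an inner upper bound `%ub` defined by an `affine.min` of an enclosing induction variable, reached through `scf.for` and `arith.constant` terminals, mirroring the example in the comment added above.

```mlir
// Hypothetical example; names, #map, and bounds are illustrative only.
#map = affine_map<(d0) -> (4, 8 - d0)>

func @upper_bound_slice() {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c8 = arith.constant 8 : index
  scf.for %i = %c0 to %c8 step %c1 {
    scf.for %j = %c0 to %c8 step %c1 {
      // %ub depends on the enclosing induction variable %i. Slicing backward
      // from %ub collects affine.min, scf.for %i, and arith.constant %c0;
      // %j is skipped because it is not an index edge of %ub.
      %ub = affine.min #map(%i)
      scf.for %k = %c0 to %ub step %c1 {
        %sum = arith.addi %i, %k : index
      }
    }
  }
  return
}
```

Composing such transitively defined bounds into the constraint set is what appears to let the packed-tensor sizes in the updated tests become static, e.g. `linalg.init_tensor [2, 2, 2]` and `linalg.init_tensor [3, 12, 6]` instead of a leading dynamic dimension.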