diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -435,6 +435,15 @@
     /// Return the dimensions of the source that are dropped in the
     /// result when the result is rank-reduced.
     llvm::SmallBitVector getDroppedDims();
+
+    /// Given a `value`, asserted to be of RankedTensorType, build an
+    /// ExtractSliceOp that results in a rank-reducing extract to the desired
+    /// tensor shape and return the new value created.
+    /// If the shape of `value` is already the `desiredShape`, just return
+    /// `value`.
+    /// If the shape of `value` cannot be rank-reduced to `desiredShape`, fail.
+    static FailureOr<Value> rankReduceIfNeeded(
+        OpBuilder &b, Location loc, Value value, ArrayRef<int64_t> desiredShape);
   }];

   let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -299,7 +300,14 @@

   // Replace the extract op.
   Operation *fusedOp = tiledProducer->getDefiningOp();
-  rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber));
+  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
+      rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber),
+      sliceOpToTile->getResult(0)
+          .getType()
+          .cast<RankedTensorType>()
+          .getShape());
+  assert(succeeded(maybeRankReduced) && "unexpected shape");
+  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
   return fusedOp;
 }

@@ -399,7 +407,14 @@

   // Replace the extract op.
   Operation *fusedOp = tiledProducer->getDefiningOp();
-  rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber));
+  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
+      rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber),
+      sliceOpToTile->getResult(0)
+          .getType()
+          .cast<RankedTensorType>()
+          .getShape());
+  assert(succeeded(maybeRankReduced) && "unexpected shape");
+  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);

   // Replace the use in containingOp.
   rewriter.updateRootInPlace(containingOp, [&]() {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -17,7 +17,9 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Support/MathExtras.h"
@@ -1754,6 +1756,23 @@
   return droppedDims;
 }

+FailureOr<Value>
+ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
+                                   ArrayRef<int64_t> desiredShape) {
+  auto sourceTensorType = value.getType().dyn_cast<RankedTensorType>();
+  assert(sourceTensorType && "not a ranked tensor type");
+  auto sourceShape = sourceTensorType.getShape();
+  if (sourceShape.equals(desiredShape))
+    return value;
+  auto maybeRankReductionMask =
+      mlir::computeRankReductionMask(sourceShape, desiredShape);
+  if (!maybeRankReductionMask)
+    return failure();
+  return createCanonicalRankReducingExtractSliceOp(
+      b, loc, value,
+      RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
+}
+
 LogicalResult ExtractSliceOp::reifyResultShapes(
     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
   reifiedReturnShapes.resize(1);
@@ -2375,7 +2394,6 @@
         insertSliceOp, cast, insertSliceOp.getDest(),
         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
         insertSliceOp.getMixedStrides());
-    cast.getDefiningOp()->getParentOfType<ModuleOp>().dump();
     return success();
   }
 };
@@ -2475,8 +2493,7 @@
   SmallVector<int64_t> inferredShape;
   for (auto i : llvm::seq<int64_t>(0, rank)) {
-    if (sourceType.isDynamicDim(i) ||
-        staticLow[i] == ShapedType::kDynamic ||
+    if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
         staticHigh[i] == ShapedType::kDynamic) {
       inferredShape.push_back(resultShape.empty() ?
                                 ShapedType::kDynamic : resultShape[i]);
@@ -2525,8 +2542,7 @@
   // This will grow staticLow and staticHigh with 1 value. If the config is
   // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
   // value as well.
-  dispatchIndexOpFoldResults(low, dynamicLow, staticLow,
-                             ShapedType::kDynamic);
+  dispatchIndexOpFoldResults(low, dynamicLow, staticLow, ShapedType::kDynamic);
   dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh,
                              ShapedType::kDynamic);
   if (!resultType) {
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -96,6 +96,52 @@

 // -----

+module {
+  func.func @foo(%0: tensor<f32>) -> tensor<f32> {
+    return %0: tensor<f32>
+  }
+
+  // CHECK-LABEL: func.func @fuse_tileable_op_rank_reducing
+  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
+  // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<?xf32>
+  func.func @fuse_tileable_op_rank_reducing(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
+    %cst = arith.constant 4.200000e+01 : f32
+    %c0 = arith.constant 0 : index
+    %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
+    %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
+
+    // CHECK: scf.foreach_thread {{.*}} -> (tensor<?xf32>) {
+    %2 = scf.foreach_thread (%arg3) in (%d0) shared_outs(%o = %0) -> (tensor<?xf32>) {
+      %5 = tensor.extract_slice %o[%arg3] [1] [1] : tensor<?xf32> to tensor<f32>
+
+      // CHECK: tensor.extract_slice %{{.*}}[%{{.*}}] [1] [1] : tensor<?xf32> to tensor<1xf32>
+      // CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<1xf32>) -> tensor<1xf32>
+      // CHECK: tensor.extract_slice %{{.*}}[0] [1] [1] : tensor<1xf32> to tensor<f32>
+      // CHECK: func.call @foo(%{{.*}}) : (tensor<f32>) -> tensor<f32>
+      %7 = func.call @foo(%5) : (tensor<f32>) -> tensor<f32>
+
+      scf.foreach_thread.perform_concurrently {
+        // CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [1] [1] : tensor<f32> into tensor<?xf32>
+        tensor.parallel_insert_slice %7 into %o[%arg3] [1] [1] : tensor<f32> into tensor<?xf32>
+      }
+    }
+    // CHECK: }
+    func.return %2 : tensor<?xf32>
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.fill"]} in %arg1
+    %1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1
+
+    // linalg.fill is tileable. The op is tiled and fused.
+    transform.structured.fuse_into_containing_op %0 into %1
+  }
+}
+
+// -----
+
 #map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
 #map1 = affine_map<(d0)[s0] -> (d0 * s0)>
 #map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>