diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -85,6 +85,9 @@
       });
     }
 
+    // Return both static and dynamic sizes as a list of `OpFoldResult`.
+    SmallVector<OpFoldResult> getMixedSizes();
+
     // Return the Value of the dynamic size of the tensor at dimension
     // `idx`. Asserts that the shape is dynamic at that `idx.
     Value getDynamicSize(unsigned idx) {
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -85,6 +85,28 @@
 /// ```
 bool canFoldIntoConsumerOp(CastOp castOp);
 
+/// Determines whether the tensor::CastOp casts to a more static version of the
+/// source tensor. This is useful for folding the cast into its producing op
+/// and for implementing canonicalization patterns with the `tensor.cast` op as
+/// the root, while the producer may come from a different dialect. Returns
+/// true when all of the following conditions are met:
+/// 1. source and result are ranked tensors with the same element type and rank.
+/// 2. the result type has more static information than the source.
+///
+/// Example:
+/// ```mlir
+///   %1 = producer ... : tensor<?x?xf32>
+///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
+/// ```
+///
+/// can be canonicalized to:
+///
+/// ```mlir
+///   %2 = producer ... : tensor<8x16xf32>
+/// ```
+/// Not all ops can be canonicalized this way, but for those that can, this
+/// method provides a check that the canonicalization is worthwhile.
+bool canFoldIntoProducerOp(CastOp castOp);
+
 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
 /// that can be folded.
 LogicalResult foldTensorCast(Operation *op);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1068,6 +1068,21 @@
   return RankedTensorType::get(staticSizes, elementType, encoding);
 }
 
+SmallVector<OpFoldResult> InitTensorOp::getMixedSizes() {
+  SmallVector<OpFoldResult> mixedSizes;
+  mixedSizes.reserve(getType().getRank());
+  unsigned dynamicValIndex = 0;
+  for (Attribute attr : static_sizes()) {
+    auto intAttr = attr.cast<IntegerAttr>();
+    if (!ShapedType::isDynamic(intAttr.getInt())) {
+      mixedSizes.push_back(intAttr);
+      continue;
+    }
+    mixedSizes.push_back(sizes()[dynamicValIndex++]);
+  }
+  return mixedSizes;
+}
+
 namespace {
 /// Change the type of the result of a `linalg.init_tensor` by making the result
 /// type statically sized along dimension that in the original operation where
@@ -1189,11 +1204,86 @@
     return success();
   }
 };
+
+/// Canonicalize
+///
+/// ```mlir
+///   %0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
+/// ```
+///
+/// into
+///
+/// ```mlir
+///   %0 = linalg.init_tensor [4, %d1] : tensor<4x?xf32>
+/// ```
+///
+/// This assumes the input program is correct in terms of its shape, so it is
+/// safe to assume that `%d0` is in fact 4. If that were not the case, the
+/// input program would be wrong to begin with and its behavior undefined
+/// anyway (i.e. this optimization can still trigger without violating program
+/// semantics).
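+///
+/// The same fold applies when the cast target is fully static. As a sketch
+/// (assuming the input program guarantees that `%d0` is 4 and `%d1` is 8):
+///
+/// ```mlir
+///   %0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
+/// ```
+///
+/// folds to
+///
+/// ```mlir
+///   %0 = linalg.init_tensor [4, 8] : tensor<4x8xf32>
+/// ```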
+struct FoldInitTensorWithTensorCastOp
+    : public OpRewritePattern<tensor::CastOp> {
+  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::CastOp castOp,
+                                PatternRewriter &rewriter) const override {
+    if (!canFoldIntoProducerOp(castOp))
+      return failure();
+    auto producer = castOp.source().getDefiningOp<InitTensorOp>();
+    if (!producer)
+      return failure();
+
+    auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
+    ArrayRef<int64_t> resultShape = resultType.getShape();
+    SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
+    SmallVector<OpFoldResult> newMixedSizes;
+    newMixedSizes.reserve(currMixedSizes.size());
+    assert(resultShape.size() == currMixedSizes.size() &&
+           "mismatch in result shape and sizes of init_tensor op");
+    for (auto it : llvm::zip(resultShape, currMixedSizes)) {
+      int64_t newDim = std::get<0>(it);
+      OpFoldResult currDim = std::get<1>(it);
+      // Case 1: The init tensor dim is static. Check that the tensor cast
+      // result dim matches.
+      if (auto attr = currDim.dyn_cast<Attribute>()) {
+        if (ShapedType::isDynamic(newDim) ||
+            newDim != attr.cast<IntegerAttr>().getInt()) {
+          // Something is off, the cast result shape cannot be more dynamic
+          // than the init tensor result shape (enforced by
+          // `canFoldIntoProducerOp`). Abort for now.
+          return rewriter.notifyMatchFailure(
+              producer, "mismatch in static value of shape of init "
+                        "tensor result and cast result");
+        }
+        newMixedSizes.push_back(attr);
+        continue;
+      }
+
+      // Case 2: The tensor cast shape is static, but the init tensor result
+      // shape is dynamic.
+      if (!ShapedType::isDynamic(newDim)) {
+        newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
+        continue;
+      }
+
+      // Case 3: Both the tensor cast shape and the init tensor result shape
+      // are dynamic. Use the dynamic value from the init tensor op.
+      newMixedSizes.push_back(currDim);
+    }
+
+    rewriter.replaceOpWithNewOp<InitTensorOp>(castOp, newMixedSizes,
+                                              resultType.getElementType());
+    return success();
+  }
+};
+
 } // namespace
 
 void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
-  results.add<FoldInitTensorWithTensorReshapeOp,
-              ReplaceStaticShapeDims>(context);
+  results.add<FoldInitTensorWithTensorCastOp,
+              FoldInitTensorWithTensorReshapeOp,
+              ReplaceStaticShapeDims>(context);
 }
@@ -1604,7 +1694,7 @@
   }
 };
 
-struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
+struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern<LinalgOp> {
   using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
 
   LogicalResult matchAndRewrite(LinalgOp op,
@@ -1660,6 +1750,63 @@
   }
 };
 
+/// Fold a LinalgOp with its `tensor.cast` consumer if the `tensor.cast` result
+/// is more static than the LinalgOp result.
+struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
+  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::CastOp castOp,
+                                PatternRewriter &rewriter) const override {
+    if (!tensor::canFoldIntoProducerOp(castOp))
+      return failure();
+    auto linalgOp = castOp.source().getDefiningOp<LinalgOp>();
+    if (!linalgOp)
+      return failure();
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(linalgOp);
+
+    Location loc = linalgOp.getLoc();
+    OpResult resultValue = castOp.source().cast<OpResult>();
+    unsigned resultNumber = resultValue.getResultNumber();
+    auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
+    // Replace the `outs` for the result with a `tensor.cast`. This cast is
+    // now going from a more dynamic shape to a less dynamic shape. If the
+    // producer of this cast, i.e. the producer of the `outs` operand, is also
+    // an operation that folds with a tensor.cast consumer (like this pattern),
+    // the cast will continue to propagate as far up the stack as it can go.
+    OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
+    Value newOperand =
+        rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
+    SmallVector<Value> newOperands = linalgOp.getInputOperands();
+    SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
+    outputOperands[resultNumber] = newOperand;
+    newOperands.append(outputOperands.begin(), outputOperands.end());
+
+    SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
+                                  linalgOp->result_type_end());
+    resultTypes[resultNumber] = resultType;
+    Operation *newOp = linalgOp.clone(rewriter, loc, resultTypes, newOperands);
+
+    if (!resultValue.hasOneUse()) {
+      SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
+      // Create a tensor.cast operation back to the original type.
+      Value castBack = rewriter.create<tensor::CastOp>(
+          loc, resultValue.getType(), newOp->getResult(resultNumber));
+      results[resultNumber] = castBack;
+      // Replace all uses except the use in the cast op that is matched by the
+      // pattern. Note that this cast is from a more static shape to a more
+      // dynamic shape. These casts are expected to be pulled into their
+      // consumers.
+      rewriter.replaceOpWithIf(linalgOp, results,
+                               [&castOp](OpOperand &use) -> bool {
+                                 return use.getOwner() != castOp.getOperation();
+                               });
+    }
+    rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
+    return success();
+  }
+};
+
 } // namespace
 
 #define LINALGOP_FOLDERS(XXX)                                                  \
@@ -1680,7 +1827,8 @@
 
 void LinalgDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
-  results.add<FoldTensorCastOp>(getContext());
+  results.add<FoldTensorCastProducerOp, FoldTensorCastConsumerOp>(
+      getContext());
 }
 
 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -98,6 +98,33 @@
                                     castOp.source().getType());
 }
 
+/// Determines whether the tensor::CastOp casts to a more static version of the
+/// source tensor. This is useful for folding the cast into its producing op
+/// and for implementing canonicalization patterns with the `tensor.cast` op as
+/// the root, while the producer may come from a different dialect. Returns
+/// true when all of the following conditions are met:
+/// 1. source and result are ranked tensors with the same element type and rank.
+/// 2. the result type has more static information than the source.
+///
+/// Example:
+/// ```mlir
+///   %1 = producer ... : tensor<?x?xf32>
+///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
+/// ```
+///
+/// can be canonicalized to:
+///
+/// ```mlir
+///   %2 = producer ... : tensor<8x16xf32>
+/// ```
+/// Not all ops can be canonicalized this way, but for those that can, this
+/// method provides a check that the canonicalization is worthwhile.
+bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
+  if (!castOp)
+    return false;
+  return preservesStaticInformation(castOp.source().getType(),
+                                    castOp.getType());
+}
+
 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
 /// that can be folded.
 LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -244,6 +244,17 @@
 
 // -----
 
+func @fold_init_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> {
+  %0 = linalg.init_tensor [%arg0, 12] : tensor<?x12xf32>
+  %1 = tensor.cast %0 : tensor<?x12xf32> to tensor<1x12xf32>
+  return %1 : tensor<1x12xf32>
+}
+// CHECK: func @fold_init_tensor_with_cast(%[[ARG0:.+]]: index)
+// CHECK:   %[[T0:.+]] = linalg.init_tensor [1, 12] : tensor<1x12xf32>
+// CHECK:   return %[[T0]] : tensor<1x12xf32>
+
+// -----
+
 #accesses = [
   affine_map<(i, j) -> (i, j)>
 ]
@@ -747,3 +758,23 @@
   %2 = tensor.insert_slice %pad into %1 [0, 0, 256] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
   return %2: tensor<8x384x384xf32>
 }
+
+// -----
+
+func @fold_linalgop_with_cast_consumer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?x?xf32>) -> (tensor<4x8xf32>, tensor<?x?xf32>) {
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
+  return %1, %0 : tensor<4x8xf32>, tensor<?x?xf32>
+}
+// CHECK: func @fold_linalgop_with_cast_consumer(
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+// CHECK:   %[[OUT_CAST:.+]] = tensor.cast %[[ARG2]] : tensor<?x?xf32> to tensor<4x8xf32>
+// CHECK:   %[[MATMUL:.+]] = linalg.matmul
+// CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]] :
+// CHECK-SAME:     outs(%[[OUT_CAST]] :
+// CHECK:   %[[RESULT_CAST:.+]] = tensor.cast %[[MATMUL]]
+// CHECK:   return %[[MATMUL]], %[[RESULT_CAST]]
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -47,7 +47,8 @@
 // CHECK: scf.for %[[I:[0-9a-z]*]]
 // CHECK:   %[[sizeA0:.*]] = affine.min #[[BOUND2_MAP]](%[[I]])[%[[dA0]]]
 // CHECK:   %[[stA:.*]] = tensor.extract_slice %[[A]][%[[I]], 0] [%[[sizeA0]], %[[dA1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK-NEXT:   scf.for %[[J:[0-9a-z]*]]
+// CHECK:   %[[castA:.*]] = tensor.cast %[[stA]] : tensor<?x?xf32> to tensor<2x?xf32>
+// CHECK:   scf.for %[[J:[0-9a-z]*]]
 // CHECK-NEXT:     scf.for %[[K:[0-9a-z]*]] {{.*}} iter_args(%[[RES:[0-9a-z]*]]
 // CHECK-DAG:       %[[stB1:.*]] = tensor.extract_slice %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1] : tensor<?x?xf32> to tensor<4x3xf32>
 // CHECK-DAG:       %[[stF:.*]] = tensor.extract_slice %[[RES]][%[[I]], %[[J]]] [2, 3] [1, 1] : tensor<?x?xf32> to tensor<2x3xf32>
@@ -56,9 +57,10 @@
 // CHECK:       %[[sizeB1:.*]] = affine.min #[[BOUND4_MAP]](%[[K]])[%[[dB1]]]
 // CHECK:       %[[stB2:.*]] = tensor.extract_slice %[[B]][0, %[[K]]] [%[[dB0]], %[[sizeB1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
 // CHECK:       %[[stC:.*]] = tensor.extract_slice %[[C]][%[[I]], %[[K]]] [%[[sizeA0]], %[[sizeB1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK:       %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[stC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK:       %[[CAST:.*]] = tensor.cast %[[stD]] : tensor<?x?xf32> to tensor<2x4xf32>
-// CHECK-NEXT:  %[[stG:.*]] = linalg.matmul ins(%[[CAST]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32>
+// CHECK-DAG:   %[[castB:.+]] = tensor.cast %[[stB2]] : tensor<?x?xf32> to tensor<?x4xf32>
+// CHECK-DAG:   %[[castC:.+]] = tensor.cast %[[stC]] : tensor<?x?xf32> to tensor<2x4xf32>
+// CHECK:       %[[stD:.*]] = linalg.matmul ins(%[[castA]], %[[castB]] : tensor<2x?xf32>, tensor<?x4xf32>) outs(%[[castC]] : tensor<2x4xf32>) -> tensor<2x4xf32>
+// CHECK-NEXT:  %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32>
 // CHECK-NEXT:  tensor.insert_slice %[[stG]] into %[[RES]][%[[I]], %[[J]]]
 
 // -----
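For reference, the canonical form that the new `fold_linalgop_with_cast_consumer` CHECK lines describe, written out as complete IR. This is a sketch reconstructed from the test expectations; the SSA value names are illustrative rather than what the canonicalizer prints.

```mlir
func @fold_linalgop_with_cast_consumer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
    %arg2 : tensor<?x?xf32>) -> (tensor<4x8xf32>, tensor<?x?xf32>) {
  // The cast has been folded into the `outs` operand, so the matmul itself
  // now produces the static tensor<4x8xf32> type.
  %out_cast = tensor.cast %arg2 : tensor<?x?xf32> to tensor<4x8xf32>
  %matmul = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%out_cast : tensor<4x8xf32>) -> tensor<4x8xf32>
  // The remaining dynamic use is served by a cast back to the original type,
  // which downstream consumers are expected to fold away in turn.
  %result_cast = tensor.cast %matmul : tensor<4x8xf32> to tensor<?x?xf32>
  return %matmul, %result_cast : tensor<4x8xf32>, tensor<?x?xf32>
}
```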