diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -53,6 +53,10 @@
 namespace mlir {
 namespace tensor {
 
+/// Returns true if `target` is a ranked tensor type that preserves static
+/// information available in the `source` ranked tensor type.
+bool preservesStaticInformation(Type source, Type target);
+
 /// Determines whether tensor::CastOp casts to a more dynamic version of the
 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
 /// implement canonicalization patterns for ops in different dialects that may
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1482,11 +1482,41 @@
     return success();
   }
 };
+
+// Fold a tensor.cast consuming the result of PadTensorOp back into the
+// latter if the cast adds static information.
+struct FoldTargetTensorCast : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padTensorOp,
+                                PatternRewriter &rewriter) const override {
+    if (!padTensorOp.result().hasOneUse())
+      return failure();
+    auto tensorCastOp =
+        dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
+    if (!tensorCastOp)
+      return failure();
+    if (!tensor::preservesStaticInformation(padTensorOp.result().getType(),
+                                            tensorCastOp.dest().getType()))
+      return failure();
+
+    auto replacementOp = rewriter.create<PadTensorOp>(
+        padTensorOp.getLoc(), tensorCastOp.dest().getType(),
+        padTensorOp.source(), padTensorOp.low(), padTensorOp.high(),
+        padTensorOp.static_low(), padTensorOp.static_high());
+    replacementOp.region().takeBody(padTensorOp.region());
+
+    rewriter.replaceOp(padTensorOp, replacementOp.result());
+    rewriter.replaceOp(tensorCastOp, replacementOp.result());
+    return success();
+  }
+};
 } // namespace
 
 void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   results.add<FoldSourceTensorCast>(context);
+  results.add<FoldTargetTensorCast>(context);
 }
 
 /// Return the padding value of the PadTensorOp if it constant. In this context,
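The pattern guards on the new tensor::preservesStaticInformation predicate
declared in the Tensor.h hunk above and defined in the TensorOps.cpp hunk
below. A minimal standalone sketch of its behavior on the shapes the new
tests use (illustration only, not part of the patch; assumes the patch is
applied and the snippet is linked against MLIR):

    #include "mlir/Dialect/Tensor/IR/Tensor.h"
    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/IR/MLIRContext.h"
    #include <cassert>

    using namespace mlir;

    int main() {
      MLIRContext ctx;
      Type f32 = FloatType::getF32(&ctx);
      int64_t dyn = ShapedType::kDynamicSize;
      auto dynType = RankedTensorType::get({dyn, dyn}, f32);  // tensor<?x?xf32>
      auto staticType = RankedTensorType::get({32, 32}, f32); // tensor<32x32xf32>

      // A fully static target preserves (indeed adds) static information, so
      // a cast tensor<?x?xf32> -> tensor<32x32xf32> may fold into the pad...
      assert(tensor::preservesStaticInformation(dynType, staticType));
      // ...whereas a fully dynamic target drops both static sizes.
      assert(!tensor::preservesStaticInformation(staticType, dynType));
      return 0;
    }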
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -31,6 +31,34 @@
 // CastOp
 //===----------------------------------------------------------------------===//
 
+/// Returns true if `target` is a ranked tensor type that preserves static
+/// information available in the `source` ranked tensor type.
+bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
+  auto sourceType = source.dyn_cast<RankedTensorType>();
+  auto targetType = target.dyn_cast<RankedTensorType>();
+
+  // Requires RankedTensorType.
+  if (!sourceType || !targetType)
+    return false;
+
+  // Requires the same element type.
+  if (sourceType.getElementType() != targetType.getElementType())
+    return false;
+
+  // Requires the same rank.
+  if (sourceType.getRank() != targetType.getRank())
+    return false;
+
+  // A static size in `source` must not become dynamic in `target`.
+  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
+    if (!ShapedType::isDynamic(std::get<0>(t)) &&
+        ShapedType::isDynamic(std::get<1>(t)))
+      return false;
+  }
+
+  return true;
+}
+
 /// Determines whether tensor::CastOp casts to a more dynamic version of the
 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
 /// implement canonicalization patterns for ops in different dialects that may
@@ -57,30 +85,10 @@
   if (!castOp)
     return false;
 
-  RankedTensorType sourceType =
-      castOp.source().getType().dyn_cast<RankedTensorType>();
-  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
-
-  // Requires RankedTensorType.
-  if (!sourceType || !resultType)
-    return false;
-
-  // Requires same elemental type.
-  if (sourceType.getElementType() != resultType.getElementType())
-    return false;
-
-  // Requires same rank.
-  if (sourceType.getRank() != resultType.getRank())
-    return false;
-
-  // If cast is towards more static sizes along any dimension, don't fold.
-  for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
-    if (ShapedType::isDynamic(std::get<0>(t)) &&
-        !ShapedType::isDynamic(std::get<1>(t)))
-      return false;
-  }
-
-  return true;
+  // Can fold if the source of the cast has at least as much static
+  // information as its result.
+  return preservesStaticInformation(castOp.getType(),
+                                    castOp.source().getType());
 }
 
 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -696,6 +696,39 @@
 
 // -----
 
+// CHECK-LABEL: @cast_of_pad_more_static
+func @cast_of_pad_more_static(%arg0: tensor<?x?xf32>, %padding: index) -> tensor<32x32xf32> {
+  %cst = constant 0.000000e+00 : f32
+  // CHECK: %[[PAD:.*]] = linalg.pad_tensor
+  // CHECK: tensor<?x?xf32> to tensor<32x32xf32>
+  %padded = linalg.pad_tensor %arg0 low[%padding, %padding] high[0, 0] {
+  ^bb0(%arg1: index, %arg2: index):
+    linalg.yield %cst : f32
+  } : tensor<?x?xf32> to tensor<?x?xf32>
+  // CHECK-NOT: tensor.cast
+  %casted = tensor.cast %padded : tensor<?x?xf32> to tensor<32x32xf32>
+  // CHECK: return %[[PAD]]
+  return %casted : tensor<32x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @cast_of_pad_less_static
+func @cast_of_pad_less_static(%arg0: tensor<32x?x?xf32>, %padding: index) -> tensor<?x32x32xf32> {
+  %cst = constant 0.000000e+00 : f32
+  // CHECK: linalg.pad_tensor
+  %padded = linalg.pad_tensor %arg0 low[%padding, %padding, %padding] high[0, 0, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index):
+    linalg.yield %cst : f32
+  } : tensor<32x?x?xf32> to tensor<32x?x?xf32>
+  // CHECK: %[[CAST:.*]] = tensor.cast
+  %casted = tensor.cast %padded : tensor<32x?x?xf32> to tensor<?x32x32xf32>
+  // CHECK: return %[[CAST]]
+  return %casted : tensor<?x32x32xf32>
+}
+
+// -----
+
 func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
     %arg3 : index) -> tensor<?x?xf32> {
   %c0 = constant 0 : index
diff --git a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
--- a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
+++ b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
@@ -140,8 +140,7 @@
 // CHECK: } else {
 // CHECK: %[[SUBTENSOR:.*]] = tensor.extract_slice %[[ARG0]][%{{.*}}, 4] [%{{.*}}, 1] [1, 1] : tensor<?x5xf32> to tensor<?x1xf32>
 // CHECK: %[[PADTENSOR:.*]] = linalg.pad_tensor %[[SUBTENSOR]] low[0, 0] high[%{{.*}}, 3]
-// CHECK: %[[CAST:.*]] = tensor.cast %[[PADTENSOR]] : tensor<?x4xf32> to tensor<3x4xf32>
-// CHECK: scf.yield %[[CAST]]
+// CHECK: scf.yield %[[PADTENSOR]]
 // CHECK: }
 // CHECK: return %[[RESULT]]
 func @dynamic_high_pad(%arg0 : tensor<?x5xf32>, %h1: index, %pad : f32) -> tensor<3x4xf32> {
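Note the argument order in canFoldIntoConsumerOp above: the cast's result
type is passed as `source` and the cast's source type as `target`, so the
producer-side fold is legal exactly when the cast's source carries at least
as much static information as its result; the new FoldTargetTensorCast asks
the mirrored question about the cast's destination. A dependency-free sketch
of this directionality (illustration only; keepsStaticDims is a made-up
stand-in for the dimension-wise rule, kDynamic for ShapedType::kDynamicSize):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    constexpr int64_t kDynamic = -1; // stand-in for ShapedType::kDynamicSize

    // True if every static dimension of `source` stays static in `target`,
    // mirroring tensor::preservesStaticInformation for equal-rank shapes.
    static bool keepsStaticDims(const std::vector<int64_t> &source,
                                const std::vector<int64_t> &target) {
      if (source.size() != target.size())
        return false;
      for (size_t i = 0; i < source.size(); ++i)
        if (source[i] != kDynamic && target[i] == kDynamic)
          return false; // a static size became dynamic: information lost
      return true;
    }

    int main() {
      std::vector<int64_t> mixed = {32, kDynamic};         // tensor<32x?xf32>
      std::vector<int64_t> dynamic = {kDynamic, kDynamic}; // tensor<?x?xf32>

      // The rule is directional: 32x? -> ?x? loses the static 32, while
      // ?x? -> 32x? loses nothing. Hence only one orientation of each fold
      // is legal, which is what the *_more_static/*_less_static tests check.
      assert(!keepsStaticDims(mixed, dynamic));
      assert(keepsStaticDims(dynamic, mixed));
      return 0;
    }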
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -289,7 +289,6 @@
 // CHECK: else
 // CHECK: tensor.extract_slice
 // CHECK: linalg.pad_tensor
-// CHECK: tensor.cast
 // CHECK: tensor.extract_slice
 // CHECK: tensor.extract_slice
 // CHECK: linalg.generic
diff --git a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
--- a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
+++ b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
@@ -111,8 +111,7 @@
 // TILE1: else
 // TILE1: %[[SLICE:.*]] = tensor.extract_slice %arg0[0, %{{.*}}] [7, %{{.*}}] [1, 1] : tensor<7x9xf32> to tensor<7x?xf32>
 // TILE1: %[[PAD:.*]] = linalg.pad_tensor %[[SLICE]] low[0, 0] high[7, %{{.*}}]
-// TILE1: %[[CAST:.*]] = tensor.cast %[[PAD]] : tensor<14x?xf32> to tensor<14x3xf32>
-// TILE1: scf.yield %[[CAST]] : tensor<14x3xf32>
+// TILE1: scf.yield %[[PAD]] : tensor<14x3xf32>
 // TILE1: %[[R3:.*]] = tensor.insert_slice %[[R2]] into %[[INNER_OUT]][0, %[[IV]]] [14, 3] [1, 1] : tensor<14x3xf32> into tensor<14x15xf32>
 // TILE1: scf.yield %[[R3]] : tensor<14x15xf32>
 // TILE1: return %[[RESULT]] : tensor<14x15xf32>
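All three test updates above show the same effect of the new pattern: the
trailing tensor.cast disappears and scf.yield consumes the linalg.pad_tensor
result directly. For completeness, a hedged sketch of driving these
canonicalizations from C++ rather than via `mlir-opt -canonicalize`;
applyPatternsAndFoldGreedily and getCanonicalizationPatterns are existing
MLIR APIs of this era, while foldPadTensorCasts is a hypothetical wrapper
(illustration only, not part of the patch):

    #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    using namespace mlir;

    // Runs the PadTensorOp canonicalization patterns (now including
    // FoldTargetTensorCast) to a fixpoint on the given function.
    static LogicalResult foldPadTensorCasts(FuncOp func) {
      RewritePatternSet patterns(func.getContext());
      linalg::PadTensorOp::getCanonicalizationPatterns(patterns,
                                                       func.getContext());
      return applyPatternsAndFoldGreedily(func, std::move(patterns));
    }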