diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2994,6 +2994,8 @@
     /// The result of a tensor_cast is always a tensor.
     TensorType getType() { return getResult().getType().cast<TensorType>(); }
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3137,6 +3137,94 @@
   return impl::foldCastOp(*this);
 }
 
+/// Compute a TensorType that has the joined shape knowledge of the two
+/// given TensorTypes. The element types need to match. An unranked type
+/// contributes no shape knowledge, so the other type is returned unchanged.
+/// Returns a null TensorType if either input is null, if the ranks differ, or
+/// if a static dimension size conflicts (the shapes cannot be joined).
+static TensorType joinShapes(TensorType one, TensorType two) {
+  // Propagate a failed (null) join so that nested joinShapes calls are safe.
+  if (!one || !two)
+    return {};
+  assert(one.getElementType() == two.getElementType());
+
+  if (!one.hasRank())
+    return two;
+  if (!two.hasRank())
+    return one;
+
+  int64_t rank = one.getRank();
+  if (rank != two.getRank())
+    return {};
+
+  SmallVector<int64_t, 4> join;
+  join.reserve(rank);
+  for (int64_t i = 0; i < rank; ++i) {
+    // A dynamic size on either side defers to the other side's size.
+    if (one.isDynamicDim(i)) {
+      join.push_back(two.getDimSize(i));
+      continue;
+    }
+    if (two.isDynamicDim(i)) {
+      join.push_back(one.getDimSize(i));
+      continue;
+    }
+    if (one.getDimSize(i) != two.getDimSize(i))
+      return {};
+    join.push_back(one.getDimSize(i));
+  }
+  return RankedTensorType::get(join, one.getElementType());
+}
+
+namespace {
+
+/// Replaces chains of two tensor_cast operations by a single tensor_cast
+/// operation if doing so does not remove runtime constraints.
+struct ChainedTensorCast : public OpRewritePattern<TensorCastOp> {
+  using OpRewritePattern<TensorCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorCastOp tensorCast,
+                                PatternRewriter &rewriter) const final {
+    auto tensorCastOperand =
+        tensorCast.getOperand().getDefiningOp<TensorCastOp>();
+
+    if (!tensorCastOperand)
+      return failure();
+
+    auto sourceType =
+        tensorCastOperand.getOperand().getType().cast<TensorType>();
+    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
+    auto resultType = tensorCast.getType().cast<TensorType>();
+
+    // We can remove the intermediate cast if joining all three produces the
+    // same result as just joining the source and result shapes.
+    auto firstJoin =
+        joinShapes(joinShapes(sourceType, intermediateType), resultType);
+
+    // The join might not exist if the cast sequence would fail at runtime.
+    if (!firstJoin)
+      return failure();
+
+    // newJoin always exists when firstJoin does, but it may carry less shape
+    // information. If so, dropping the intermediate cast would also drop the
+    // runtime checks it performs, so the chain must be kept.
+    auto newJoin = joinShapes(sourceType, resultType);
+    if (firstJoin != newJoin)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<TensorCastOp>(tensorCast, resultType,
+                                              tensorCastOperand.getOperand());
+    return success();
+  }
+};
+
+} // namespace
+
+void TensorCastOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ChainedTensorCast>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Helpers for Tensor[Load|Store]Op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1062,3 +1062,55 @@
   return %0 : tensor<3x?x?x7x?xindex>
 }
 
+// -----
+
+// The 4x? intermediate adds nothing beyond the 4x8 result: fold to one cast.
+// CHECK-LABEL: @tensor_cast_chain_ok
+// CHECK-SAME: %[[IN:.*]]: tensor<*xi32>
+func @tensor_cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> {
+  // CHECK-NEXT: %[[RES:.*]] = tensor_cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32>
+  %0 = tensor_cast %input : tensor<*xi32> to tensor<4x?xi32>
+  %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<4x8xi32>
+  // CHECK-NEXT: return %[[RES]]
+  return %1 : tensor<4x8xi32>
+}
+
+// -----
+
+// The chain becomes a 4 -> 4 identity cast, which then folds away.
+// CHECK-LABEL: @tensor_cast_chain_regain
+// CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
+func @tensor_cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> {
+  %0 = tensor_cast %input : tensor<4xi32> to tensor<?xi32>
+  %1 = tensor_cast %0 : tensor<?xi32> to tensor<4xi32>
+  // CHECK-NEXT: return %[[IN]]
+  return %1 : tensor<4xi32>
+}
+
+// -----
+
+// Dropping the 4x? intermediate would lose the runtime check that dim 0 is 4.
+// CHECK-LABEL: @tensor_cast_chain_keep
+// CHECK-SAME: %[[IN:.*]]: tensor<?x?xi32>
+func @tensor_cast_chain_keep(%input: tensor<?x?xi32>) -> tensor<?x8xi32> {
+  // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]]
+  %0 = tensor_cast %input : tensor<?x?xi32> to tensor<4x?xi32>
+  // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]]
+  %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<?x8xi32>
+  // CHECK-NEXT: return %[[C2]]
+  return %1 : tensor<?x8xi32>
+}
+
+// -----
+
+// A 4x8 -> 8x4 chain always fails at runtime, so both casts are kept.
+// CHECK-LABEL: @tensor_cast_chain_invalid
+// CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32>
+func @tensor_cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
+  // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]]
+  %0 = tensor_cast %input : tensor<4x8xi32> to tensor<?x?xi32>
+  // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]]
+  %1 = tensor_cast %0 : tensor<?x?xi32> to tensor<8x4xi32>
+  // CHECK-NEXT: return %[[C2]]
+  return %1 : tensor<8x4xi32>
+}