diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -628,12 +628,45 @@
     return success();
   }
 };
+
+template <typename OpTy>
+struct CanonicalizeCastExtentTensorOperandsPattern
+    : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Canonicalize operands.
+    bool anyChange = false;
+    auto canonicalizeOperand = [&](Value operand) {
+      if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
+        // Only eliminate the cast if it holds no shape information.
+        bool isInformationLosingCast =
+            castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
+        if (isInformationLosingCast) {
+          anyChange = true;
+          return castOp.source();
+        }
+      }
+      return operand;
+    };
+    auto newOperands = llvm::to_vector<8>(
+        llvm::map_range(op.getOperands(), canonicalizeOperand));
+
+    // Rewrite op if any change required.
+    if (!anyChange)
+      return failure();
+    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
+    return success();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
-  patterns.add<RemoveDuplicateOperandsPattern<BroadcastOp>,
+  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
+               RemoveDuplicateOperandsPattern<BroadcastOp>,
                RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
 }
 
@@ -716,7 +749,8 @@
   // Canonicalization patterns have overlap with the considerations during
   // folding in case additional shape information is inferred at some point that
   // does not result in folding.
-  patterns.add<CstrBroadcastableEqOps,
+  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
+               CstrBroadcastableEqOps,
                RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
                RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
 }
@@ -1188,7 +1222,7 @@
 // ```
 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
 // ```
-struct ShapeOfCastedExtentTensor : public OpRewritePattern<tensor::CastOp> {
+struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(tensor::CastOp op,
@@ -1214,7 +1248,7 @@
 
 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
-  patterns.add<ShapeOfCastedExtentTensor, ShapeOfWithTensor>(context);
+  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1115,8 +1115,8 @@
 // CHECK-LABEL: @fold_index_cast_on_index
 func @fold_index_cast_on_index(%arg: index) -> index {
   // CHECK-NOT: size_to_index
-  %casted = shape.size_to_index %arg : index
-  return %casted : index
+  %0 = shape.size_to_index %arg : index
+  return %0 : index
 }
 
 // -----
 
@@ -1125,8 +1125,8 @@
 // CHECK-LABEL: @fold_to_extent_tensor_on_tensor
 func @fold_to_extent_tensor_on_tensor(%arg: tensor<?xindex>) -> tensor<?xindex> {
   // CHECK-NOT: to_extent_tensor
-  %casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<?xindex>
-  return %casted : tensor<?xindex>
+  %0 = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<?xindex>
+  return %0 : tensor<?xindex>
 }
 
 // -----
 
@@ -1264,9 +1264,9 @@
 
 // -----
 
-// CHECK-LABEL: @casted_extent_tensor
+// CHECK-LABEL: @cast_extent_tensor
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
-func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
+func @cast_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
   // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<?xindex>
   // CHECK: return %[[RESULT]] : tensor<?xindex>
   %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
@@ -1276,9 +1276,9 @@
 
 // -----
 
-// CHECK-LABEL: @casted_extent_tensor
+// CHECK-LABEL: @cast_extent_tensor
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<3xindex>
-func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<3xindex> {
+func @cast_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<3xindex> {
   // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<3xindex>
   // CHECK: return %[[RESULT]] : tensor<3xindex>
   %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
@@ -1288,8 +1288,8 @@
 
 // -----
 
-// CHECK-LABEL: @casted_extent_tensor
-func @casted_extent_tensor(%arg : tensor<?x?x?x?xf32>) -> tensor<3xindex> {
+// CHECK-LABEL: @cast_extent_tensor
+func @cast_extent_tensor(%arg : tensor<?x?x?x?xf32>) -> tensor<3xindex> {
   // CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
   %0 = shape.shape_of %arg : tensor<?x?x?x?xf32> -> tensor<?xindex>
   %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
@@ -1298,8 +1298,8 @@
 
 // -----
 
-// CHECK-LABEL: @casted_extent_tensor
-func @casted_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> {
+// CHECK-LABEL: @cast_extent_tensor
+func @cast_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> {
   // CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
   %0 = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
   %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
@@ -1335,3 +1335,21 @@
   %2 = shape.cstr_broadcastable %0, %1: tensor<2xindex>, tensor<1xindex>
   "use"(%2) : (!shape.witness) -> ()
 }
+
+// -----
+
+// CHECK-LABEL: @cast_extent_tensor_operands
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<3xindex>)
+func @cast_extent_tensor_operands(%arg0 : tensor<?xindex>,
+    %arg1 : tensor<3xindex>) -> (!shape.witness, tensor<?xindex>) {
+  // CHECK: %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?xindex> to tensor<3xindex>
+  // CHECK: %[[WIT:.*]] = shape.cstr_broadcastable %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
+  // CHECK: %[[RES:.*]] = shape.broadcast %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
+  // CHECK: return %[[WIT]], %[[RES]]
+  %0 = tensor.cast %arg0 : tensor<?xindex> to tensor<3xindex>
+  %1 = tensor.cast %arg1 : tensor<3xindex> to tensor<?xindex>
+  %2 = shape.cstr_broadcastable %0, %1 : tensor<3xindex>, tensor<?xindex>
+  %3 = shape.broadcast %0, %1 : tensor<3xindex>, tensor<?xindex>
+      -> tensor<?xindex>
+  return %2, %3 : !shape.witness, tensor<?xindex>
+}
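
Note (reviewer aside, not part of the patch): a minimal before/after sketch of what the new
CanonicalizeCastExtentTensorOperandsPattern does; the values %arg and %other below are
hypothetical, not taken from the test suite.

  // Before: the cast erases the static extent of %arg, so its result type
  // tensor<?xindex> has a dynamic dim 0 and the cast is information-losing.
  %0 = tensor.cast %arg : tensor<2xindex> to tensor<?xindex>
  %w = shape.cstr_broadcastable %0, %other : tensor<?xindex>, tensor<2xindex>

  // After: the op is rebuilt directly on the cast's source. A cast in the
  // opposite direction (tensor<?xindex> to tensor<2xindex>) adds shape
  // information and is left in place.
  %w = shape.cstr_broadcastable %arg, %other : tensor<2xindex>, tensor<2xindex>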