diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -606,12 +606,45 @@
     return success();
   }
 };
+
+template <typename OpTy>
+struct CanonicalizeCastedExtentTensorOperandsPattern
+    : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Canonicalize operands.
+    bool anyChange = false;
+    auto canonicalizeOperand = [&](Value operand) {
+      if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
+        // Only eliminate the cast if it holds no shape information.
+        bool isInformationLosingCast =
+            castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
+        if (isInformationLosingCast) {
+          anyChange = true;
+          return castOp.source();
+        }
+      }
+      return operand;
+    };
+    auto newOperands = llvm::to_vector<8>(
+        llvm::map_range(op.getOperands(), canonicalizeOperand));
+
+    // Rewrite op if any change is required.
+    if (!anyChange)
+      return failure();
+    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
+    return success();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
-  patterns.add<RemoveDuplicateOperandsPattern<BroadcastOp>,
+  patterns.add<CanonicalizeCastedExtentTensorOperandsPattern<BroadcastOp>,
+               RemoveDuplicateOperandsPattern<BroadcastOp>,
                RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -694,9 +727,11 @@
   // Canonicalization patterns have overlap with the considerations during
   // folding in case additional shape information is inferred at some point that
   // does not result in folding.
-  patterns.add<CstrBroadcastableEqOps,
-               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
-               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
+  patterns
+      .add<CanonicalizeCastedExtentTensorOperandsPattern<CstrBroadcastableOp>,
+           CstrBroadcastableEqOps,
+           RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
+           RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
 }
 
 // Return true if there is exactly one attribute not representing a scalar
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1325,3 +1325,21 @@
   %2 = shape.cstr_broadcastable %0, %1: tensor<2xindex>, tensor<1xindex>
   "use"(%2) : (!shape.witness) -> ()
 }
+
+// -----
+
+// CHECK-LABEL: @casted_extent_tensor_operands
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<3xindex>)
+func @casted_extent_tensor_operands(%arg0 : tensor<?xindex>,
+    %arg1 : tensor<3xindex>) -> (!shape.witness, tensor<?xindex>) {
+  // CHECK: %[[CASTED_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?xindex> to tensor<3xindex>
+  // CHECK: %[[WIT:.*]] = shape.cstr_broadcastable %[[CASTED_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
+  // CHECK: %[[RES:.*]] = shape.broadcast %[[CASTED_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
+  // CHECK: return %[[WIT]], %[[RES]]
+  %0 = tensor.cast %arg0 : tensor<?xindex> to tensor<3xindex>
+  %1 = tensor.cast %arg1 : tensor<3xindex> to tensor<?xindex>
+  %2 = shape.cstr_broadcastable %0, %1 : tensor<3xindex>, tensor<?xindex>
+  %3 = shape.broadcast %0, %1 : tensor<3xindex>, tensor<?xindex>
+      -> tensor<?xindex>
+  return %2, %3 : !shape.witness, tensor<?xindex>
+}
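
Note on the intended effect (an illustrative sketch inferred from the CHECK lines above, not part of the patch): the pattern bypasses a tensor.cast operand only when the cast's result type is dynamic, i.e. when the cast loses extent information. Assuming the test case is run through the file's usual canonicalize RUN line, the rewrite looks roughly like this:

  // Before: %1 wraps %arg1 in an information-losing cast.
  %1 = tensor.cast %arg1 : tensor<3xindex> to tensor<?xindex>
  %2 = shape.cstr_broadcastable %0, %1 : tensor<3xindex>, tensor<?xindex>

  // After: the cast is bypassed and the op is rebuilt on %arg1 directly,
  // while %0 survives because casting to tensor<3xindex> adds information.
  %2 = shape.cstr_broadcastable %0, %arg1 : tensor<3xindex>, tensor<3xindex>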