Shape: canonicalize `shape.assuming_all` by removing duplicate operands.

This patch moves the `RemoveDuplicateOperandsPattern` template from its later
position in Shape.cpp (the `namespace {` before `RemoveEmptyShapeOperandsPattern`)
to just before `AssumingAllOp::getCanonicalizationPatterns`.

The pattern rebuilds an op from the first occurrence of each operand, in the
original order. It reuses the op's result types and attributes, and it fails to
match when no operand repeats.

The `patterns.add` call in `AssumingAllOp::getCanonicalizationPatterns` is
changed, presumably to register the pattern for `AssumingAllOp`. The template
arguments are not visible in this copy, so confirm this against the source tree.

A new `assuming_all_duplicate_operands` test in canonicalize.mlir expects
`shape.assuming_all %0, %0, %0` to canonicalize to the single
`shape.cstr_broadcastable` result `%0`.

NOTE(review): this copy of the patch looks damaged and will not apply or
compile as shown.
- Line breaks are collapsed.
- Template argument lists appear stripped, e.g. `template` with no parameter
  list, `OpRewritePattern` with no op type, `SmallVector unique;`,
  `patterns.add>(context);`, and `tensor` with no shape in the test.
Regenerate it from the source tree.

NOTE(review): the de-duplication loop calls `llvm::is_contained` on a growing
vector, which is quadratic in the operand count. An `llvm::SetVector` would
make it linear if long operand lists are expected.

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -429,11 +429,36 @@ return success(); } }; + +template +struct RemoveDuplicateOperandsPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // Find unique operands. + SmallVector unique; + for (Value v : op.getOperands()) { + if (!llvm::is_contained(unique, v)) + unique.push_back(v); + } + + // Reduce op to equivalent with unique operands. + if (unique.size() < op.getNumOperands()) { + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), unique, + op->getAttrs()); + return success(); + } + + return failure(); + } +}; } // namespace void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.add(context); + patterns.add>(context); } OpFoldResult AssumingAllOp::fold(ArrayRef operands) { @@ -508,30 +533,6 @@ } namespace { -template -struct RemoveDuplicateOperandsPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - // Find unique operands. - SmallVector unique; - for (Value v : op.getOperands()) { - if (!llvm::is_contained(unique, v)) - unique.push_back(v); - } - - // Reduce op to equivalent with unique operands. 
- if (unique.size() < op.getNumOperands()) { - rewriter.replaceOpWithNewOp(op, op->getResultTypes(), unique, - op->getAttrs()); - return success(); - } - - return failure(); - } -}; - template struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -477,6 +477,19 @@ return %2 : !shape.witness } +// ----- +// `assuming_all` with duplicate operands. +// CHECK-LABEL: func @assuming_all_duplicate_operands +// CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) +func @assuming_all_duplicate_operands(%arg0 : tensor, + %arg1 : tensor) -> !shape.witness { + // CHECK: %[[RES:.*]] = shape.cstr_broadcastable %[[ARG0]], %[[ARG1]] + // CHECK: return %[[RES]] + %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor, tensor + %1 = shape.assuming_all %0, %0, %0 + return %1 : !shape.witness +} + // ----- // `assuming_all` with all `cstr_eq` but disjoint operands cannot be collapsed. // CHECK-LABEL: func @assuming_all_to_cstr_eq