diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -389,12 +389,19 @@
   LogicalResult matchAndRewrite(AssumingAllOp op,
                                 PatternRewriter &rewriter) const override {
     SmallVector<Value, 8> shapes;
-    for (Value v : op.inputs()) {
-      auto cstrEqOp = v.getDefiningOp<CstrEqOp>();
+    for (Value w : op.inputs()) {
+      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
       if (!cstrEqOp)
         return failure();
-      auto range = cstrEqOp.shapes();
-      shapes.append(range.begin(), range.end());
+      // Only merge constraints that share an operand with the shapes gathered
+      // so far. Folding disjoint `cstr_eq`s into one would additionally
+      // require their unrelated operand groups to be equal to each other.
+      bool disjointShapes = llvm::none_of(cstrEqOp.shapes(), [&](Value s) {
+        return llvm::is_contained(shapes, s);
+      });
+      if (!shapes.empty() && !cstrEqOp.shapes().empty() && disjointShapes)
+        return failure();
+      shapes.append(cstrEqOp.shapes().begin(), cstrEqOp.shapes().end());
     }
     rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
     return success();
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -434,7 +434,7 @@
 }
 
 // -----
-// `assuming_all` with all `cstr_eq` can be collapsed.
+// `assuming_all` with all `cstr_eq` and shared operands can be collapsed.
 // CHECK-LABEL: func @assuming_all_to_cstr_eq
 // CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<3xindex>)
 func @assuming_all_to_cstr_eq(%a : !shape.shape, %b : tensor<?xindex>,
@@ -447,6 +447,22 @@
   return %2 : !shape.witness
 }
 
+// -----
+// `assuming_all` with all `cstr_eq` but disjoint operands cannot be collapsed.
+// CHECK-LABEL: func @assuming_all_disjoint_cstr_eq
+// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<3xindex>, %[[D:.*]]: tensor<3xindex>)
+func @assuming_all_disjoint_cstr_eq(%a : !shape.shape, %b : tensor<?xindex>,
+    %c : tensor<3xindex>, %d : tensor<3xindex>) -> !shape.witness {
+  // CHECK: %[[EQ0:.*]] = shape.cstr_eq %[[A]], %[[B]]
+  // CHECK: %[[EQ1:.*]] = shape.cstr_eq %[[C]], %[[D]]
+  // CHECK: %[[RESULT:.*]] = shape.assuming_all %[[EQ0]], %[[EQ1]]
+  // CHECK: return %[[RESULT]]
+  %0 = shape.cstr_eq %a, %b : !shape.shape, tensor<?xindex>
+  %1 = shape.cstr_eq %c, %d : tensor<3xindex>, tensor<3xindex>
+  %2 = shape.assuming_all %0, %1
+  return %2 : !shape.witness
+}
+
 // -----
 // assuming_all with known passing witnesses can be folded
 // CHECK-LABEL: func @f