[mlir][shape] Canonicalize shape.broadcast with duplicate operands

Share a RemoveDuplicateOperandsPattern between shape.broadcast and
shape.is_broadcastable, drop the "at least 2 input shapes" verifier checks of
both ops, and fold shape.is_broadcastable with fewer than two operands to
`true`.

Review fixes relative to the first revision:
* BroadcastOp::fold only folds a single operand to itself when the operand
  type equals the result type; otherwise folding would silently change the
  type of the replaced value (e.g. tensor<3xindex> -> tensor<?xindex>).
* BroadcastOp::fold no longer reads operands[1] for a zero-operand broadcast,
  which the verifier now admits.
* Added @broadcast_on_single_operand covering the type-mismatch case.

NOTE(review): reconstructed from a copy whose line breaks and `<...>` template
arguments had been stripped; indentation of context lines may need adjusting
(e.g. `git apply --ignore-whitespace`) before this applies cleanly.

diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -89,6 +89,7 @@
   ];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 
   let verifier = [{ return ::verify(*this); }];
 }
@@ -277,10 +278,10 @@
     };
   }];
 
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 
   let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
-  let verifier = [{ return ::verify(*this); }];
 }
 
 def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -354,13 +354,21 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
-  if (!operands[1])
-    return nullptr;
+  if (shapes().size() == 1) {
+    // Only fold if no cast to the result type is needed.
+    if (shapes().front().getType() != getType())
+      return nullptr;
+    return shapes().front();
+  }
 
   // TODO: Support folding with more than 2 input shapes
   if (shapes().size() > 2)
     return nullptr;
 
+  // Zero operands are admitted by the verifier; there is nothing to fold.
+  if (shapes().size() != 2 || !operands[1])
+    return nullptr;
+
   auto rhsShape = llvm::to_vector<6>(
       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
   if (rhsShape.empty())
@@ -384,13 +392,40 @@
 }
 
 static LogicalResult verify(BroadcastOp op) {
-  // Ensure that AssumingAllOp contains at least one operand
-  if (op.getNumOperands() < 2)
-    return op.emitOpError("required at least 2 input shapes");
-
   return verifyShapeOrExtentTensorOp(op);
 }
 
+namespace {
+template <typename OpTy>
+struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Find unique operands.
+    SmallVector<Value, 2> unique;
+    for (Value v : op.getOperands()) {
+      if (!llvm::is_contained(unique, v))
+        unique.push_back(v);
+    }
+
+    // Reduce op to equivalent with unique operands.
+    if (unique.size() < op.getNumOperands()) {
+      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
+                                        op.getAttrs());
+      return success();
+    }
+
+    return failure();
+  }
+};
+} // namespace
+
+void BroadcastOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ConcatOp
 //===----------------------------------------------------------------------===//
@@ -772,49 +807,18 @@
 // IsBroadcastableOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(IsBroadcastableOp op) {
-  // Ensure that AssumingAllOp contains at least one operand
-  if (op.getNumOperands() < 2)
-    return op.emitOpError("required at least 2 input shapes");
-  return success();
+void IsBroadcastableOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
 }
 
-namespace {
-struct IsBroadcastableCanonicalizationPattern
-    : public OpRewritePattern<IsBroadcastableOp> {
-  using OpRewritePattern<IsBroadcastableOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(IsBroadcastableOp op,
-                                PatternRewriter &rewriter) const override {
-    // Find unique operands.
-    SmallVector<Value, 2> unique;
-    for (Value v : op.getOperands()) {
-      if (!llvm::is_contained(unique, v))
-        unique.push_back(v);
-    }
-
-    // Can always broadcast fewer than two shapes.
-    if (unique.size() < 2) {
-      rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op,
-                                                    rewriter.getBoolAttr(true));
-      return success();
-    }
-
-    // Reduce op to equivalent with unique operands.
-    if (unique.size() < op.getNumOperands()) {
-      rewriter.replaceOpWithNewOp<IsBroadcastableOp>(op, rewriter.getI1Type(),
-                                                     unique);
-      return success();
-    }
-
-    return failure();
+OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
+  // Can always broadcast fewer than two shapes.
+  if (operands.size() < 2) {
+    return BoolAttr::get(getContext(), true);
   }
-};
-} // namespace
 
-void IsBroadcastableOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert<IsBroadcastableCanonicalizationPattern>(context);
+  return nullptr;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1088,9 +1088,46 @@
 // CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: !shape.shape)
 func @is_broadcastable_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape)
     -> i1 {
-  // CHECK: %[[RES:.*]] = shape.is_broadcastable %[[A]], %[[B]]
+  // CHECK: %[[RES:.*]] = shape.is_broadcastable %[[A]], %[[B]] :
   // CHECK: return %[[RES]]
   %0 = shape.is_broadcastable %a, %b, %a, %a, %a, %b : !shape.shape,
       !shape.shape, !shape.shape, !shape.shape, !shape.shape, !shape.shape
   return %0 : i1
 }
+
+// -----
+
+// CHECK-LABEL: @broadcast_on_same_shape
+// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape)
+func @broadcast_on_same_shape(%shape : !shape.shape) -> !shape.shape {
+  // CHECK-NOT: broadcast
+  // CHECK: return %[[SHAPE]]
+  %0 = shape.broadcast %shape, %shape, %shape : !shape.shape, !shape.shape,
+      !shape.shape -> !shape.shape
+  return %0 : !shape.shape
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_on_duplicate_shapes
+// CHECK-SAME: (%[[A:.*]]: !shape.shape, %[[B:.*]]: !shape.shape)
+func @broadcast_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape)
+    -> !shape.shape {
+  // CHECK: %[[RES:.*]] = shape.broadcast %[[A]], %[[B]] :
+  // CHECK: return %[[RES]]
+  %0 = shape.broadcast %a, %b, %a, %a, %a, %b : !shape.shape, !shape.shape,
+      !shape.shape, !shape.shape, !shape.shape, !shape.shape -> !shape.shape
+  return %0 : !shape.shape
+}
+
+// -----
+
+// A single operand must not be folded in when its type differs from the result.
+// CHECK-LABEL: @broadcast_on_single_operand
+// CHECK-SAME: (%[[A:.*]]: tensor<3xindex>)
+func @broadcast_on_single_operand(%a : tensor<3xindex>) -> tensor<?xindex> {
+  // CHECK: %[[RES:.*]] = shape.broadcast %[[A]]
+  // CHECK: return %[[RES]]
+  %0 = shape.broadcast %a : tensor<3xindex> -> tensor<?xindex>
+  return %0 : tensor<?xindex>
+}
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -249,18 +249,8 @@
 
 // -----
 
-func @fn(%arg: !shape.shape) -> i1 {
-  // expected-error@+1 {{required at least 2 input shapes}}
-  %0 = shape.is_broadcastable %arg : !shape.shape
-  return %0 : i1
-}
-
-// -----
-
 func @fn(%arg: !shape.shape) -> !shape.witness {
   // expected-error@+1 {{required at least 2 input shapes}}
   %0 = shape.cstr_broadcastable %arg : !shape.shape
   return %0 : !shape.witness
 }
-
-