diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -450,34 +450,27 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
-  if (operands.size() == 1)
+  if (shapes().size() == 1 && shapes().front().getType().isa<ShapeType>() ==
+          getType().isa<ShapeType>())
     return shapes().front();
 
   // TODO: Support folding with more than 2 input shapes
   if (shapes().size() > 2)
     return nullptr;
 
-  if (!operands[1])
-    return nullptr;
-
-  auto rhsShape = llvm::to_vector<6>(
-      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
-  if (rhsShape.empty())
-    return shapes()[0];
-
-  if (!operands[0])
+  if (!operands[0] || !operands[1])
     return nullptr;
-
   auto lhsShape = llvm::to_vector<6>(
       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
-  if (lhsShape.empty())
-    return shapes()[1];
-
+  auto rhsShape = llvm::to_vector<6>(
+      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
   SmallVector<int64_t, 6> resultShape;
+
+  // If the shapes are not compatible, we can't fold it.
   // TODO: Fold to an "error".
   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
     return nullptr;
+
   Builder builder(getContext());
   return builder.getIndexTensorAttr(resultShape);
 }
@@ -511,6 +504,35 @@
   }
 };
 
+template <typename OpTy>
+struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value, 8> nonEmptyShapeOperands;
+    auto isEmptyShape = [](Value shape) {
+      auto constShape = shape.getDefiningOp<ConstShapeOp>();
+      if (!constShape)
+        return false;
+      return constShape.shape().size() == 0;
+    };
+    for (Value shape : op->getOperands()) {
+      if (!isEmptyShape(shape))
+        nonEmptyShapeOperands.push_back(shape);
+    }
+
+    // Reduce op to equivalent with non-empty shape operands.
+    if (nonEmptyShapeOperands.size() < op.getNumOperands()) {
+      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
+                                        nonEmptyShapeOperands, op->getAttrs());
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 struct BroadcastForwardSingleOperandPattern
     : public OpRewritePattern<BroadcastOp> {
   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
@@ -518,18 +540,63 @@
 
   LogicalResult matchAndRewrite(BroadcastOp op,
                                 PatternRewriter &rewriter) const override {
     if (op.getNumOperands() == 1) {
-      rewriter.replaceOp(op, op.shapes().front());
-      return success();
+      Value uniqueShapeOperand = op.shapes().front();
+      if (op.getType().isa<ShapeType>() ==
+          uniqueShapeOperand.getType().isa<ShapeType>()) {
+        rewriter.replaceOp(op, uniqueShapeOperand);
+        return success();
+      }
     }
     return failure();
   }
 };
+
+struct BroadcastFoldConstantOperandsPattern
+    : public OpRewritePattern<BroadcastOp> {
+  using OpRewritePattern<BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<int64_t, 8> foldedConstantOperands;
+    SmallVector<Value, 8> newShapeOperands;
+    for (Value shape : op.shapes()) {
+      auto constShape = shape.getDefiningOp<ConstShapeOp>();
+      if (constShape) {
+        SmallVector<int64_t, 8> newFoldedConstantOperands;
+        if (OpTrait::util::getBroadcastedShape(
+                foldedConstantOperands,
+                llvm::to_vector<8>(constShape.shape().getValues<int64_t>()),
+                newFoldedConstantOperands)) {
+          foldedConstantOperands = newFoldedConstantOperands;
+          continue;
+        }
+      }
+      newShapeOperands.push_back(shape);
+    }
+
+    // Need at least two constant operands to fold anything.
+    if (op.getNumOperands() - newShapeOperands.size() < 2)
+      return failure();
+
+    auto foldedConstantOperandsTy = RankedTensorType::get(
+        {static_cast<int64_t>(foldedConstantOperands.size())},
+        rewriter.getIndexType());
+    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
+        op.getLoc(), foldedConstantOperandsTy,
+        rewriter.getIndexTensorAttr(foldedConstantOperands)));
+    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
+                                             newShapeOperands);
+    return success();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
-  patterns.add<BroadcastForwardSingleOperandPattern,
-               RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
+  patterns.add<BroadcastFoldConstantOperandsPattern,
+               BroadcastForwardSingleOperandPattern,
+               RemoveDuplicateOperandsPattern<BroadcastOp>,
+               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -120,6 +120,36 @@
 
 // -----
 
+// All but one of the operands are known empty shapes.
+// CHECK-LABEL: @all_but_one_empty
+// CHECK-SAME: (%[[ARG:.*]]: !shape.shape)
+func @all_but_one_empty(%arg0 : !shape.shape) -> !shape.shape {
+  // CHECK: return %[[ARG]]
+  %0 = shape.const_shape [] : !shape.shape
+  %1 = shape.const_shape [] : tensor<0xindex>
+  %2 = shape.broadcast %0, %arg0, %1, %0 : !shape.shape, !shape.shape,
+      tensor<0xindex>, !shape.shape -> !shape.shape
+  return %2 : !shape.shape
+}
+
+// -----
+
+// Partial folding.
+// CHECK-LABEL: @partial_folding
+// CHECK-SAME: (%[[ARG:.*]]: !shape.shape)
+func @partial_folding(%arg0 : !shape.shape) -> !shape.shape {
+  // CHECK: %[[CST_SHAPE:.*]] = constant dense<[1, 2, 3]> : tensor<3xindex>
+  // CHECK: %[[RESULT:.*]] = shape.broadcast %[[ARG]], %[[CST_SHAPE]] : !shape.shape, tensor<3xindex> -> !shape.shape
+  // CHECK: return %[[RESULT]]
+  %0 = shape.const_shape [2, 1] : !shape.shape
+  %1 = shape.const_shape [1, 2, 3] : tensor<3xindex>
+  %2 = shape.broadcast %0, %arg0, %1, %0 : !shape.shape, !shape.shape,
+      tensor<3xindex>, !shape.shape -> !shape.shape
+  return %2 : !shape.shape
+}
+
+// -----
+
 // Incompatible shapes. No folding.
 // CHECK-LABEL: func @f
 func @f() -> !shape.shape {
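Note on the folding helper (not part of the diff): both the fold hook and BroadcastFoldConstantOperandsPattern lean on OpTrait::util::getBroadcastedShape, declared in mlir/Dialect/Traits.h. The pattern broadcasts the constant operands into one accumulated shape and keeps any operand it cannot combine, which is why the @partial_folding test ends up with a single [1, 2, 3] constant next to %arg0: [2, 1] and [1, 2, 3] broadcast to [1, 2, 3]. A minimal standalone sketch of that accumulation follows; the function checkPartialFoldingConstants is hypothetical and exists purely for illustration.

#include "mlir/Dialect/Traits.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical helper (not part of the patch): combine the constant shapes
// from the @partial_folding test the same way the pattern accumulates them,
// one operand at a time.
static bool checkPartialFoldingConstants() {
  llvm::SmallVector<int64_t, 8> folded; // starts empty, like in the pattern
  llvm::SmallVector<int64_t, 8> result;

  // shape.const_shape [2, 1]: broadcasting with the empty shape is a no-op.
  if (!mlir::OpTrait::util::getBroadcastedShape(folded, {2, 1}, result))
    return false;
  folded = result;

  // shape.const_shape [1, 2, 3]: [2, 1] and [1, 2, 3] broadcast to [1, 2, 3].
  result.clear();
  if (!mlir::OpTrait::util::getBroadcastedShape(folded, {1, 2, 3}, result))
    return false;
  folded = result;

  // getBroadcastedShape reports incompatible extents by returning false; in
  // that case the pattern keeps the constant as a separate operand instead.
  return folded == llvm::SmallVector<int64_t, 8>({1, 2, 3});
}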