diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -470,34 +470,26 @@
 //===----------------------------------------------------------------------===//

 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
-  if (operands.size() == 1)
+  if (shapes().size() == 1)
     return shapes().front();

   // TODO: Support folding with more than 2 input shapes
   if (shapes().size() > 2)
     return nullptr;

-  if (!operands[1])
-    return nullptr;
-
-  auto rhsShape = llvm::to_vector<6>(
-      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
-  if (rhsShape.empty())
-    return shapes()[0];
-
-  if (!operands[0])
+  if (!operands[0] || !operands[1])
     return nullptr;
-
   auto lhsShape = llvm::to_vector<6>(
       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
-  if (lhsShape.empty())
-    return shapes()[1];
-
+  auto rhsShape = llvm::to_vector<6>(
+      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
   SmallVector<int64_t, 6> resultShape;
+
   // If the shapes are not compatible, we can't fold it.
   // TODO: Fold to an "error".
   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
     return nullptr;
+
   Builder builder(getContext());
   return builder.getIndexTensorAttr(resultShape);
 }
@@ -531,6 +523,31 @@
   }
 };

+template <typename OpTy>
+struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    auto isPotentiallyNonEmptyShape = [](Value shape) {
+      if (auto constShape = shape.getDefiningOp<ConstShapeOp>())
+        return constShape.shape().size() != 0;
+      return true;
+    };
+    auto newOperands = llvm::to_vector<8>(
+        llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
+
+    // Reduce op to equivalent without empty shape operands.
+    if (newOperands.size() < op.getNumOperands()) {
+      rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
+                                        op->getAttrs());
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 struct BroadcastForwardSingleOperandPattern
     : public OpRewritePattern<BroadcastOp> {
   using OpRewritePattern<BroadcastOp>::OpRewritePattern;
@@ -538,18 +555,59 @@
   LogicalResult matchAndRewrite(BroadcastOp op,
                                 PatternRewriter &rewriter) const override {
     if (op.getNumOperands() == 1) {
-      rewriter.replaceOp(op, op.shapes().front());
+      Value uniqueShapeOperand = op.shapes().front();
+      rewriter.replaceOp(op, uniqueShapeOperand);
       return success();
     }
     return failure();
   }
 };
+
+struct BroadcastFoldConstantOperandsPattern
+    : public OpRewritePattern<BroadcastOp> {
+  using OpRewritePattern<BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<int64_t, 8> foldedConstantShape;
+    SmallVector<Value, 8> newShapeOperands;
+    for (Value shape : op.shapes()) {
+      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
+        SmallVector<int64_t, 8> newFoldedConstantShape;
+        if (OpTrait::util::getBroadcastedShape(
+                foldedConstantShape,
+                llvm::to_vector<8>(constShape.shape().getValues<int64_t>()),
+                newFoldedConstantShape)) {
+          foldedConstantShape = newFoldedConstantShape;
+          continue;
+        }
+      }
+      newShapeOperands.push_back(shape);
+    }
+
+    // Need at least two constant operands to fold anything.
+    if (op.getNumOperands() - newShapeOperands.size() < 2)
+      return failure();
+
+    auto foldedConstantOperandsTy = RankedTensorType::get(
+        {static_cast<int64_t>(foldedConstantShape.size())},
+        rewriter.getIndexType());
+    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
+        op.getLoc(), foldedConstantOperandsTy,
+        rewriter.getIndexTensorAttr(foldedConstantShape)));
+    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
+                                             newShapeOperands);
+    return success();
+  }
+};
 } // namespace

 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
-  patterns.add<BroadcastForwardSingleOperandPattern,
-               RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
+  patterns.add<BroadcastFoldConstantOperandsPattern,
+               BroadcastForwardSingleOperandPattern,
+               RemoveDuplicateOperandsPattern<BroadcastOp>,
+               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
 }

 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -120,6 +120,36 @@

 // -----

+// All but one operands are known empty shapes.
+// CHECK-LABEL: @all_but_one_empty
+// CHECK-SAME: (%[[ARG:.*]]: !shape.shape)
+func @all_but_one_empty(%arg0 : !shape.shape) -> !shape.shape {
+  // CHECK: return %[[ARG]]
+  %0 = shape.const_shape [] : !shape.shape
+  %1 = shape.const_shape [] : tensor<0xindex>
+  %2 = shape.broadcast %0, %arg0, %1, %0 : !shape.shape, !shape.shape,
+      tensor<0xindex>, !shape.shape -> !shape.shape
+  return %2 : !shape.shape
+}
+
+// -----
+
+// Partial folding.
+// CHECK-LABEL: @partial_folding
+// CHECK-SAME: (%[[ARG:.*]]: !shape.shape)
+func @partial_folding(%arg0 : !shape.shape) -> !shape.shape {
+  // CHECK: %[[CST_SHAPE:.*]] = constant dense<[1, 2, 3]> : tensor<3xindex>
+  // CHECK: %[[RESULT:.*]] = shape.broadcast %[[ARG]], %[[CST_SHAPE]] : !shape.shape, tensor<3xindex> -> !shape.shape
+  // CHECK: return %[[RESULT]]
+  %0 = shape.const_shape [2, 1] : !shape.shape
+  %1 = shape.const_shape [1, 2, 3] : tensor<3xindex>
+  %2 = shape.broadcast %0, %arg0, %1, %0 : !shape.shape, !shape.shape,
+      tensor<3xindex>, !shape.shape -> !shape.shape
+  return %2 : !shape.shape
+}
+
+// -----
+
 // Incompatible shapes. No folding.
 // CHECK-LABEL: func @f
 func @f() -> !shape.shape {
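For reference, a minimal before/after sketch of what the new BroadcastFoldConstantOperandsPattern does on its own (illustrative IR only; %arg0 stands for an arbitrary dynamic shape). The constant operands are broadcast at pattern-application time and re-emitted as a single trailing shape.const_shape of extent-tensor type; subsequent canonicalization turns that const_shape into the "constant dense<[1, 2, 3]>" observed by the @partial_folding CHECK lines above:

  // Before: two constant shape operands interleaved with a dynamic one.
  %0 = shape.const_shape [2, 1] : !shape.shape
  %1 = shape.const_shape [1, 2, 3] : tensor<3xindex>
  %2 = shape.broadcast %0, %arg0, %1 : !shape.shape, !shape.shape, tensor<3xindex> -> !shape.shape

  // After applying only this pattern: [2, 1] and [1, 2, 3] broadcast to [1, 2, 3],
  // appended as one constant operand behind the remaining dynamic operand.
  %3 = shape.const_shape [1, 2, 3] : tensor<3xindex>
  %4 = shape.broadcast %arg0, %3 : !shape.shape, tensor<3xindex> -> !shape.shape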