diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
@@ -41,6 +41,7 @@
   let useDefaultTypePrinterParser = 1;
   let hasConstantMaterializer = 1;
   let hasOperationAttrVerify = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }

 class Shape_Type<string name, string typeMnemonic> : TypeDef<ShapeDialect, name> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -215,10 +215,10 @@

 // TODO: Canonicalization should be implemented for shapes that can be
 // determined through mixtures of the known dimensions of the inputs.
-OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AnyOp::fold(FoldAdaptor adaptor) {
   // Only the last operand is checked because AnyOp is commutative.
-  if (operands.back())
-    return operands.back();
+  if (adaptor.getInputs().back())
+    return adaptor.getInputs().back();

   return nullptr;
 }
@@ -410,13 +410,14 @@
   return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
 }

-OpFoldResult mlir::shape::AddOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
   // add(x, 0) -> x
   if (matchPattern(getRhs(), m_Zero()))
     return getLhs();

   return constFoldBinaryOp<IntegerAttr>(
-      operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
+      adaptor.getOperands(),
+      [](APInt a, const APInt &b) { return std::move(a) + b; });
 }

 LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
@@ -604,11 +605,11 @@
                RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
 }

-OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) {
   // Iterate in reverse to first handle all constant operands. They are
   // guaranteed to be the tail of the inputs because this is commutative.
-  for (int idx = operands.size() - 1; idx >= 0; idx--) {
-    Attribute a = operands[idx];
+  for (int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
+    Attribute a = adaptor.getInputs()[idx];
     // Cannot fold if any inputs are not constant;
     if (!a)
       return nullptr;
@@ -637,7 +638,7 @@
 // BroadcastOp
 //===----------------------------------------------------------------------===//

-OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
   if (getShapes().size() == 1) {
     // Otherwise, we need a cast which would be a canonicalization, not folding.
     if (getShapes().front().getType() != getType())
@@ -649,12 +650,12 @@
   if (getShapes().size() > 2)
     return nullptr;

-  if (!operands[0] || !operands[1])
+  if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
     return nullptr;
   auto lhsShape = llvm::to_vector<6>(
-      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+      adaptor.getShapes()[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
   auto rhsShape = llvm::to_vector<6>(
-      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
+      adaptor.getShapes()[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
   SmallVector<int64_t, 6> resultShape;

   // If the shapes are not compatible, we can't fold it.
@@ -847,13 +848,13 @@
 // ConcatOp
 //===----------------------------------------------------------------------===//

-OpFoldResult ConcatOp::fold(ArrayRef<Attribute> operands) {
-  if (!operands[0] || !operands[1])
+OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
+  if (!adaptor.getLhs() || !adaptor.getRhs())
     return nullptr;
   auto lhsShape = llvm::to_vector<6>(
-      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+      adaptor.getLhs().cast<DenseIntElementsAttr>().getValues<int64_t>());
   auto rhsShape = llvm::to_vector<6>(
-      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
+      adaptor.getRhs().cast<DenseIntElementsAttr>().getValues<int64_t>());
   SmallVector<int64_t, 6> resultShape;
   resultShape.append(lhsShape.begin(), lhsShape.end());
   resultShape.append(rhsShape.begin(), rhsShape.end());
@@ -903,7 +904,7 @@
   return success();
 }

-OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return getShapeAttr(); }
+OpFoldResult ConstShapeOp::fold(FoldAdaptor) { return getShapeAttr(); }

 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
@@ -966,14 +967,14 @@
   return true;
 }

-OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
   // No broadcasting is needed if all operands but one are scalar.
-  if (hasAtMostSingleNonScalar(operands))
+  if (hasAtMostSingleNonScalar(adaptor.getShapes()))
     return BoolAttr::get(getContext(), true);

   if ([&] {
         SmallVector<SmallVector<int64_t, 6>, 6> extents;
-        for (const auto &operand : operands) {
+        for (const auto &operand : adaptor.getShapes()) {
           if (!operand)
             return false;
           extents.push_back(llvm::to_vector<6>(
@@ -1018,9 +1019,10 @@
   patterns.add<CstrEqEqOps>(context);
 }

-OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
-  if (llvm::all_of(operands,
-                   [&](Attribute a) { return a && a == operands[0]; }))
+OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) {
+  if (llvm::all_of(adaptor.getShapes(), [&](Attribute a) {
+        return a && a == adaptor.getShapes().front();
+      }))
     return BoolAttr::get(getContext(), true);

   // Because a failing witness result here represents an eventual assertion
@@ -1038,7 +1040,7 @@
   build(builder, result, builder.getIndexAttr(value));
 }

-OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return getValueAttr(); }
+OpFoldResult ConstSizeOp::fold(FoldAdaptor) { return getValueAttr(); }

 void ConstSizeOp::getAsmResultNames(
     llvm::function_ref<void(Value, StringRef)> setNameFn) {
@@ -1052,16 +1054,14 @@
 // ConstWitnessOp
 //===----------------------------------------------------------------------===//

-OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) {
-  return getPassingAttr();
-}
+OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { return getPassingAttr(); }

 //===----------------------------------------------------------------------===//
 // CstrRequireOp
 //===----------------------------------------------------------------------===//

-OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
-  return operands[0];
+OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
+  return adaptor.getPred();
 }

 //===----------------------------------------------------------------------===//
@@ -1076,7 +1076,7 @@
   return std::nullopt;
 }

-OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   Type valType = getValue().getType();
   auto valShapedType = valType.dyn_cast<ShapedType>();
   if (!valShapedType || !valShapedType.hasRank())
@@ -1120,11 +1120,11 @@
 // DivOp
 //===----------------------------------------------------------------------===//

-OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
-  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
+OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
+  auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
   if (!lhs)
     return nullptr;
-  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
+  auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>();
   if (!rhs)
     return nullptr;

@@ -1163,14 +1163,14 @@
 // ShapeEqOp
 //===----------------------------------------------------------------------===//

-OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) {
   bool allSame = true;
-  if (!operands.empty() && !operands[0])
+  if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
     return {};
-  for (Attribute operand : operands.drop_front(1)) {
+  for (Attribute operand : adaptor.getShapes().drop_front()) {
     if (!operand)
       return {};
-    allSame = allSame && operand == operands[0];
+    allSame = allSame && operand == adaptor.getShapes().front();
   }
   return BoolAttr::get(getContext(), allSame);
 }
@@ -1179,10 +1179,10 @@
 // IndexToSizeOp
 //===----------------------------------------------------------------------===//

-OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
   // Constant values of both types, `shape.size` and `index`, are represented as
   // `IntegerAttr`s which makes constant folding simple.
-  if (Attribute arg = operands[0])
+  if (Attribute arg = adaptor.getArg())
     return arg;
   return {};
 }
@@ -1196,11 +1196,11 @@
 // FromExtentsOp
 //===----------------------------------------------------------------------===//

-OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
-  if (llvm::any_of(operands, [](Attribute a) { return !a; }))
+OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
+  if (llvm::any_of(adaptor.getExtents(), [](Attribute a) { return !a; }))
     return nullptr;
   SmallVector<int64_t, 6> extents;
-  for (auto attr : operands)
+  for (auto attr : adaptor.getExtents())
     extents.push_back(attr.cast<IntegerAttr>().getInt());
   Builder builder(getContext());
   return builder.getIndexTensorAttr(extents);
@@ -1335,8 +1335,8 @@
   return std::nullopt;
 }

-OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
-  auto elements = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
+  auto elements = adaptor.getShape().dyn_cast_or_null<DenseIntElementsAttr>();
   if (!elements)
     return nullptr;
   std::optional<int64_t> dim = getConstantDim();
@@ -1386,9 +1386,9 @@
   patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
 }

-OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
   // Can always broadcast fewer than two shapes.
-  if (operands.size() < 2) {
+  if (adaptor.getShapes().size() < 2) {
     return BoolAttr::get(getContext(), true);
   }

@@ -1479,8 +1479,8 @@
 // RankOp
 //===----------------------------------------------------------------------===//

-OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
-  auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
+  auto shape = adaptor.getShape().dyn_cast_or_null<DenseIntElementsAttr>();
   if (!shape)
     return {};
   int64_t rank = shape.getNumElements();
@@ -1557,10 +1557,10 @@
 // NumElementsOp
 //===----------------------------------------------------------------------===//

-OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {

   // Fold only when argument constant.
-  Attribute shape = operands[0];
+  Attribute shape = adaptor.getShape();
   if (!shape)
     return {};

@@ -1596,7 +1596,7 @@
 // MaxOp
 //===----------------------------------------------------------------------===//

-OpFoldResult MaxOp::fold(llvm::ArrayRef<Attribute> operands) {
+OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
   // If operands are equal, just propagate one.
   if (getLhs() == getRhs())
     return getLhs();
@@ -1628,7 +1628,7 @@
 // MinOp
 //===----------------------------------------------------------------------===//

-OpFoldResult MinOp::fold(llvm::ArrayRef<Attribute> operands) {
+OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
   // If operands are equal, just propagate one.
   if (getLhs() == getRhs())
     return getLhs();
@@ -1660,11 +1660,11 @@
 // MulOp
 //===----------------------------------------------------------------------===//

-OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
-  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
+OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
+  auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
   if (!lhs)
     return nullptr;
-  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
+  auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>();
   if (!rhs)
     return nullptr;
   APInt folded = lhs.getValue() * rhs.getValue();
@@ -1695,7 +1695,7 @@
 // ShapeOfOp
 //===----------------------------------------------------------------------===//

-OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
+OpFoldResult ShapeOfOp::fold(FoldAdaptor) {
   auto type = getOperand().getType().dyn_cast<ShapedType>();
   if (!type || !type.hasStaticShape())
     return nullptr;
@@ -1805,10 +1805,10 @@
 // SizeToIndexOp
 //===----------------------------------------------------------------------===//

-OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
   // Constant values of both types, `shape.size` and `index`, are represented as
   // `IntegerAttr`s which makes constant folding simple.
-  if (Attribute arg = operands[0])
+  if (Attribute arg = adaptor.getArg())
     return arg;
   return OpFoldResult();
 }
@@ -1847,14 +1847,14 @@
 // SplitAtOp
 //===----------------------------------------------------------------------===//

-LogicalResult SplitAtOp::fold(ArrayRef<Attribute> operands,
+LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
-  if (!operands[0] || !operands[1])
+  if (!adaptor.getOperand() || !adaptor.getIndex())
     return failure();
   auto shapeVec = llvm::to_vector<6>(
-      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+      adaptor.getOperand().cast<DenseIntElementsAttr>().getValues<int64_t>());
   auto shape = llvm::ArrayRef(shapeVec);
-  auto splitPoint = operands[1].cast<IntegerAttr>().getInt();
+  auto splitPoint = adaptor.getIndex().cast<IntegerAttr>().getInt();
   // Verify that the split point is in the correct range.
   // TODO: Constant fold to an "error".
   int64_t rank = shape.size();
@@ -1862,7 +1862,7 @@
     return failure();
   if (splitPoint < 0)
     splitPoint += shape.size();
-  Builder builder(operands[0].getContext());
+  Builder builder(adaptor.getOperand().getContext());
   results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
   results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
   return success();
 }
@@ -1872,12 +1872,12 @@
 // ToExtentTensorOp
 //===----------------------------------------------------------------------===//

-OpFoldResult ToExtentTensorOp::fold(ArrayRef<Attribute> operands) {
-  if (!operands[0])
+OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
+  if (!adaptor.getInput())
     return OpFoldResult();
   Builder builder(getContext());
   auto shape = llvm::to_vector<6>(
-      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+      adaptor.getInput().cast<DenseIntElementsAttr>().getValues<int64_t>());
   auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
                                     builder.getIndexType());
   return DenseIntElementsAttr::get(type, shape);