diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -30,6 +30,7 @@
   let useDefaultAttributePrinterParser = 1;
   let hasConstantMaterializer = 1;
   let dependentDialects = ["arith::ArithDialect"];
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 
 // Base class for Vector dialect ops.
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -308,7 +308,7 @@
         builder.getI64ArrayAttr(reductionDims));
 }
 
-OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
   // Single parallel dim, this is a noop.
   if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
     return getSource();
@@ -1035,13 +1035,13 @@
   return success();
 }
 
-OpFoldResult vector::ExtractElementOp::fold(ArrayRef<Attribute> operands) {
-  // Skip the 0-D vector here now.
-  if (operands.size() < 2)
+OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
+  // Skip 0-D vectors (no position operand); splat folding needs no constant.
+  if (!getPosition())
     return {};
 
-  Attribute src = operands[0];
-  Attribute pos = operands[1];
+  Attribute src = adaptor.getVector();
+  Attribute pos = adaptor.getPosition();
 
   // Fold extractelement (splat X) -> X.
   if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
@@ -1587,7 +1587,7 @@
   return Value();
 }
 
-OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
+OpFoldResult ExtractOp::fold(FoldAdaptor) {
   if (getPosition().empty())
     return getVector();
   if (succeeded(foldExtractOpFromExtractChain(*this)))
@@ -1918,15 +1918,15 @@
   llvm_unreachable("unexpected vector.broadcast op error");
 }
 
-OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
   if (getSourceType() == getVectorType())
     return getSource();
-  if (!operands[0])
+  if (!adaptor.getSource())
     return {};
   auto vectorType = getVectorType();
-  if (operands[0].isa<IntegerAttr, FloatAttr>())
-    return DenseElementsAttr::get(vectorType, operands[0]);
-  if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
+  if (adaptor.getSource().isa<IntegerAttr, FloatAttr>())
+    return DenseElementsAttr::get(vectorType, adaptor.getSource());
+  if (auto attr = adaptor.getSource().dyn_cast<SplatElementsAttr>())
     return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
   return {};
 }
@@ -2034,7 +2034,7 @@
   });
 }
 
-OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
   VectorType v1Type = getV1VectorType();
   // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
   // but must be a canonicalization into a vector.broadcast.
@@ -2051,7 +2051,7 @@
                        getV2VectorType().getDimSize(0)))
     return getV2();
 
-  Attribute lhs = operands.front(), rhs = operands.back();
+  Attribute lhs = adaptor.getV1(), rhs = adaptor.getV2();
   if (!lhs || !rhs)
     return {};
 
@@ -2154,14 +2154,14 @@
   return success();
 }
 
-OpFoldResult vector::InsertElementOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
   // Skip the 0-D vector here.
-  if (operands.size() < 3)
+  if (!getPosition())
     return {};
 
-  Attribute src = operands[0];
-  Attribute dst = operands[1];
-  Attribute pos = operands[2];
+  Attribute src = adaptor.getSource();
+  Attribute dst = adaptor.getDest();
+  Attribute pos = adaptor.getPosition();
   if (!src || !dst || !pos)
     return {};
 
@@ -2335,7 +2335,7 @@
 // Eliminates insert operations that produce values identical to their source
 // value. This happens when the source and destination vectors have identical
 // sizes.
-OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   if (getPosition().empty())
     return getSource();
   return {};
@@ -2621,7 +2621,7 @@
               InsertStridedSliceConstantFolder>(context);
 }
 
-OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
   if (getSourceVectorType() == getDestVectorType())
     return getSource();
   return {};
@@ -2929,7 +2929,7 @@
   return failure();
 }
 
-OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
   if (getVectorType() == getResult().getType())
     return getVector();
   if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
@@ -3564,7 +3564,7 @@
   return {};
 }
 
-OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
+OpFoldResult TransferReadOp::fold(FoldAdaptor) {
   if (Value vec = foldRAW(*this))
     return vec;
   /// transfer_read(memrefcast) -> transfer_read
@@ -4039,9 +4039,9 @@
   return success();
 }
 
-LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
+LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
-  if (succeeded(foldReadInitWrite(*this, operands, results)))
+  if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
     return success();
   if (succeeded(foldWAR(*this, results)))
     return success();
@@ -4346,7 +4346,7 @@
   return success();
 }
 
-OpFoldResult LoadOp::fold(ArrayRef<Attribute>) {
+OpFoldResult LoadOp::fold(FoldAdaptor) {
   if (succeeded(memref::foldMemRefCast(*this)))
     return getResult();
   return OpFoldResult();
@@ -4379,7 +4379,7 @@
   return success();
 }
 
-LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
+LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
   return memref::foldMemRefCast(*this);
 }
@@ -4432,7 +4432,7 @@
   results.add<MaskedLoadFolder>(context);
 }
 
-OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) {
+OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
   if (succeeded(memref::foldMemRefCast(*this)))
     return getResult();
   return OpFoldResult();
@@ -4483,7 +4483,7 @@
   results.add<MaskedStoreFolder>(context);
 }
 
-LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
+LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
                                   SmallVectorImpl<OpFoldResult> &results) {
   return memref::foldMemRefCast(*this);
 }
@@ -4754,7 +4754,7 @@
   return success();
 }
 
-OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
   // No-op shape cast.
   if (getSource().getType() == getResult().getType())
     return getSource();
@@ -4888,7 +4888,7 @@
   return success();
 }
 
-OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
   // Nop cast.
   if (getSource().getType() == getResult().getType())
     return getSource();
@@ -4902,7 +4902,7 @@
     return getResult();
   }
 
-  Attribute sourceConstant = operands.front();
+  Attribute sourceConstant = adaptor.getSource();
   if (!sourceConstant)
     return {};
 
@@ -4995,9 +4995,9 @@
   result.addAttribute(getTranspAttrStrName(), builder.getI64ArrayAttr(transp));
 }
 
-OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   // Eliminate splat constant transpose ops.
-  if (auto attr = operands.front().dyn_cast_or_null<DenseElementsAttr>())
+  if (auto attr = adaptor.getVector().dyn_cast_or_null<DenseElementsAttr>())
     if (attr.isSplat())
       return attr.reshape(getResultType());
 
@@ -5495,8 +5495,8 @@
 // SplatOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
-  auto constOperand = operands.front();
+OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
+  auto constOperand = adaptor.getInput();
   if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
     return {};
 