diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td
@@ -47,6 +47,7 @@
 
   let hasCanonicalizer = 1;
   let hasConstantMaterializer = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
   let dependentDialects = [
     "AffineDialect",
     "arith::ArithDialect",
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -418,9 +418,9 @@
   return success();
 }
 
-OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   // All forms of folding require a known index.
-  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
+  auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
   if (!index)
     return {};
 
@@ -763,16 +763,16 @@
   return success();
 }
 
-OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
   // If this is a splat elements attribute, simply return the value. All of
   // the elements of a splat attribute are the same.
-  if (Attribute tensor = operands.front())
+  if (Attribute tensor = adaptor.getTensor())
     if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
       return splatTensor.getSplatValue<Attribute>();
 
   // Collect the constant indices into the tensor.
   SmallVector<uint64_t, 8> indices;
-  for (Attribute indice : llvm::drop_begin(operands, 1)) {
+  for (Attribute indice : adaptor.getIndices()) {
     if (!indice || !indice.isa<IntegerAttr>())
       return {};
     indices.push_back(indice.cast<IntegerAttr>().getInt());
@@ -800,7 +800,7 @@
   }
 
   // If this is an elements attribute, query the value at the given indices.
-  if (Attribute tensor = operands.front()) {
+  if (Attribute tensor = adaptor.getTensor()) {
     auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
     if (elementsAttr && elementsAttr.isValidIndex(indices))
       return elementsAttr.getValues<Attribute>()[indices];
@@ -837,9 +837,9 @@
   build(builder, result, resultType, elements);
 }
 
-OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
-  if (!llvm::is_contained(operands, nullptr))
-    return DenseElementsAttr::get(getType(), operands);
+OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
+  if (!llvm::is_contained(adaptor.getElements(), nullptr))
+    return DenseElementsAttr::get(getType(), adaptor.getElements());
   return {};
 }
 
@@ -996,9 +996,9 @@
   return success();
 }
 
-OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
-  Attribute scalar = operands[0];
-  Attribute dest = operands[1];
+OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
+  Attribute scalar = adaptor.getScalar();
+  Attribute dest = adaptor.getDest();
   if (scalar && dest)
     if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
       if (scalar == splatDest.getSplatValue<Attribute>())
@@ -1178,7 +1178,7 @@
   setNameFn(getResult(), "rank");
 }
 
-OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
   // Constant fold rank when the rank of the operand is known.
   auto type = getOperand().getType();
   auto shapedType = type.dyn_cast<ShapedType>();
@@ -1558,12 +1558,14 @@
       context);
 }
 
-OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
-  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
+OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
+  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
+                                                       adaptor.getOperands());
 }
 
-OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
-  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
+OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
+  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
+                                                       adaptor.getOperands());
 }
 
 //===----------------------------------------------------------------------===//
@@ -2050,8 +2052,8 @@
   return {};
 }
 
-OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
-  if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
+OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
+  if (auto splat = adaptor.getSource().dyn_cast_or_null<SplatElementsAttr>()) {
     auto resultType = getResult().getType().cast<ShapedType>();
     if (resultType.hasStaticShape())
       return splat.resizeSplat(resultType);
@@ -2197,7 +2199,7 @@
   return extractOp.getSource();
 }
 
-OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
+OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
   if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
       getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
@@ -2869,7 +2871,7 @@
   return padValue;
 }
 
-OpFoldResult PadOp::fold(ArrayRef<Attribute>) {
+OpFoldResult PadOp::fold(FoldAdaptor) {
   if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
       !getNofold())
     return getSource();
@@ -3004,8 +3006,8 @@
   setNameFn(getResult(), "splat");
 }
 
-OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
-  auto constOperand = operands.front();
+OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
+  auto constOperand = adaptor.getInput();
   if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
     return {};
 