[mlir][tensor] Migrate Tensor dialect folders to the FoldAdaptor API

What this patch does:
  * TensorBase.td: adds `let useFoldAPI = kEmitFoldAdaptorFolder;` to the
    Tensor dialect definition.
  * TensorOps.cpp: changes the fold hooks of DimOp, ExtractOp,
    FromElementsOp, InsertOp, RankOp, ExpandShapeOp, CollapseShapeOp,
    ExtractSliceOp, InsertSliceOp, PadOp and SplatOp from taking an
    `ArrayRef` of operand attributes to taking a `FoldAdaptor`. It also
    replaces positional operand lookups with the adaptor's named accessors:
      DimOp:           operands[1]                  => adaptor.getIndex()
      ExtractOp:       operands.front()             => adaptor.getTensor()
                       llvm::drop_begin(operands,1) => adaptor.getIndices()
      FromElementsOp:  operands                     => adaptor.getElements()
      InsertOp:        operands[0], operands[1]     => adaptor.getScalar(),
                                                       adaptor.getDest()
      ExtractSliceOp:  operands[0]                  => adaptor.getSource()
      SplatOp:         operands.front()             => adaptor.getInput()
      ExpandShapeOp /
      CollapseShapeOp: operands                     => adaptor.getOperands()
  * InsertSliceOp and PadOp take the adaptor as an unnamed parameter, so
    their folders do not read it.
  * The rest of each touched folder is unchanged context.

NOTE(review): this copy of the patch looks damaged in transit.
  * Line breaks were collapsed onto a few long lines, so patch tools will
    not find the hunk headers.
  * Angle-bracketed template arguments appear to have been stripped, e.g.
    `ArrayRef operands`, `dyn_cast_or_null()`, `SmallVector indices` and
    `foldReshapeOp(*this, ...)`. The added code as shown would not compile.
  * Regenerate the patch from the source tree instead of hand-repairing it.
NOTE(review): in the visible part of the RankOp hunk, `adaptor` is never
read; the body calls getOperand() instead. If the rest of that function
also ignores it, consider leaving the parameter unnamed, as InsertSliceOp
and PadOp do. This may avoid unused-parameter warnings.

Index: mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td =================================================================== --- mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td +++ mlir/include/mlir/Dialect/Tensor/IR/TensorBase.td @@ -47,6 +47,7 @@ let hasCanonicalizer = 1; let hasConstantMaterializer = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; let dependentDialects = [ "AffineDialect", "arith::ArithDialect", Index: mlir/lib/Dialect/Tensor/IR/TensorOps.cpp =================================================================== --- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -417,9 +417,9 @@ return success(); } -OpFoldResult DimOp::fold(ArrayRef operands) { +OpFoldResult DimOp::fold(FoldAdaptor adaptor) { // All forms of folding require a known index. - auto index = operands[1].dyn_cast_or_null(); + auto index = adaptor.getIndex().dyn_cast_or_null(); if (!index) return {}; @@ -762,16 +762,16 @@ return success(); } -OpFoldResult ExtractOp::fold(ArrayRef operands) { +OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) { // If this is a splat elements attribute, simply return the value. All of // the elements of a splat attribute are the same. - if (Attribute tensor = operands.front()) + if (Attribute tensor = adaptor.getTensor()) if (auto splatTensor = tensor.dyn_cast()) return splatTensor.getSplatValue(); // Collect the constant indices into the tensor. SmallVector indices; - for (Attribute indice : llvm::drop_begin(operands, 1)) { + for (Attribute indice : adaptor.getIndices()) { if (!indice || !indice.isa()) return {}; indices.push_back(indice.cast().getInt()); @@ -799,7 +799,7 @@ } // If this is an elements attribute, query the value at the given indices. 
- if (Attribute tensor = operands.front()) { + if (Attribute tensor = adaptor.getTensor()) { auto elementsAttr = tensor.dyn_cast(); if (elementsAttr && elementsAttr.isValidIndex(indices)) return elementsAttr.getValues()[indices]; @@ -836,9 +836,9 @@ build(builder, result, resultType, elements); } -OpFoldResult FromElementsOp::fold(ArrayRef operands) { - if (!llvm::is_contained(operands, nullptr)) - return DenseElementsAttr::get(getType(), operands); +OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) { + if (!llvm::is_contained(adaptor.getElements(), nullptr)) + return DenseElementsAttr::get(getType(), adaptor.getElements()); return {}; } @@ -995,9 +995,9 @@ return success(); } -OpFoldResult InsertOp::fold(ArrayRef operands) { - Attribute scalar = operands[0]; - Attribute dest = operands[1]; +OpFoldResult InsertOp::fold(FoldAdaptor adaptor) { + Attribute scalar = adaptor.getScalar(); + Attribute dest = adaptor.getDest(); if (scalar && dest) if (auto splatDest = dest.dyn_cast()) if (scalar == splatDest.getSplatValue()) @@ -1177,7 +1177,7 @@ setNameFn(getResult(), "rank"); } -OpFoldResult RankOp::fold(ArrayRef operands) { +OpFoldResult RankOp::fold(FoldAdaptor adaptor) { // Constant fold rank when the rank of the operand is known. 
auto type = getOperand().getType(); auto shapedType = type.dyn_cast(); @@ -1557,12 +1557,14 @@ context); } -OpFoldResult ExpandShapeOp::fold(ArrayRef operands) { - return foldReshapeOp(*this, operands); +OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) { + return foldReshapeOp(*this, + adaptor.getOperands()); } -OpFoldResult CollapseShapeOp::fold(ArrayRef operands) { - return foldReshapeOp(*this, operands); +OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) { + return foldReshapeOp(*this, + adaptor.getOperands()); } //===----------------------------------------------------------------------===// @@ -2049,8 +2051,8 @@ return {}; } -OpFoldResult ExtractSliceOp::fold(ArrayRef operands) { - if (auto splat = operands[0].dyn_cast_or_null()) { +OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) { + if (auto splat = adaptor.getSource().dyn_cast_or_null()) { auto resultType = getResult().getType().cast(); if (resultType.hasStaticShape()) return splat.resizeSplat(resultType); @@ -2196,7 +2198,7 @@ return extractOp.getSource(); } -OpFoldResult InsertSliceOp::fold(ArrayRef) { +OpFoldResult InsertSliceOp::fold(FoldAdaptor) { if (getSourceType().hasStaticShape() && getType().hasStaticShape() && getSourceType() == getType() && succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType()))) @@ -2868,7 +2870,7 @@ return padValue; } -OpFoldResult PadOp::fold(ArrayRef) { +OpFoldResult PadOp::fold(FoldAdaptor) { if (getResultType().hasStaticShape() && getResultType() == getSourceType() && !getNofold()) return getSource(); @@ -3003,8 +3005,8 @@ setNameFn(getResult(), "splat"); } -OpFoldResult SplatOp::fold(ArrayRef operands) { - auto constOperand = operands.front(); +OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { + auto constOperand = adaptor.getInput(); if (!constOperand.isa_and_nonnull()) return {};