diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -491,6 +491,21 @@
         return $_op.iterator_types();
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the iterator types of the current operation as a vector of
+        StringRefs, i.e. the string values of the `iterator_types` attribute.
+      }],
+      /*retTy=*/"SmallVector<StringRef>",
+      /*methodName=*/"getIteratorTypesArray",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto range = $_op.iterator_types()
+                         .template getAsValueRange<StringAttr>();
+        return {range.begin(), range.end()};
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return true if the indexing map is depending on the current op instance.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -297,8 +297,7 @@
       !indexingMaps.back().isProjectedPermutation())
     return MatchConvolutionResult::NotProjectedPermutations;
 
-  auto iteratorTypesRange =
-      linalgOp.iterator_types().getAsValueRange<StringAttr>();
+  auto iteratorTypesRange = linalgOp.getIteratorTypesArray();
 
   llvm::SmallDenseSet<unsigned> outputDims =
       getPreservedDims(indexingMaps.back());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -438,8 +438,7 @@
       resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
     GenericOp replacementOp = rewriter.create<GenericOp>(
         loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
-        llvm::to_vector<4>(genericOp.getIteratorTypes()
-                               .template getAsValueRange<StringAttr>()));
+        genericOp.getIteratorTypesArray());
     rewriter.inlineRegionBefore(genericOp.getRegion(),
                                 replacementOp.getRegion(),
                                 replacementOp.getRegion().begin());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -467,10 +467,9 @@
                             .isProjectedPermutation();
                       }) &&
          genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
-         llvm::all_of(genericOp.getIteratorTypes(), [](Attribute attr) {
-           return attr.cast<StringAttr>().getValue() ==
-                  getParallelIteratorTypeName();
-         });
+         llvm::all_of(genericOp.getIteratorTypesArray(), [](StringRef it) {
+           return it == getParallelIteratorTypeName();
+         });
 }

 namespace {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -53,8 +53,7 @@
   SmallVector<Value> inputOperands = linalgOp.getInputOperands();
   SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
   SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
-  SmallVector<StringRef> iterators = llvm::to_vector<4>(
-      linalgOp.iterator_types().getAsValueRange<StringAttr>());
+  SmallVector<StringRef> iterators = linalgOp.getIteratorTypesArray();
   SmallVector<RankedTensorType> resultTypes = linalgOp.getOutputTensorTypes();
   SmallVector<Type> types(resultTypes.begin(), resultTypes.end());

diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
@@ -61,9 +61,7 @@
     SmallVector<Value> outputOperands = genericOp.getOutputOperands();
     auto newOp = rewriter.create<GenericOp>(
         loc, genericOp->getResultTypes(), newOperands, outputOperands,
-        newIndexingMaps,
-        llvm::to_vector<4>(genericOp.getIteratorTypes()
-                               .template getAsValueRange<StringAttr>()));
+        newIndexingMaps, genericOp.getIteratorTypesArray());
     rewriter.cloneRegionBefore(genericOp.getRegion(), newOp.getRegion(),
                                newOp.getRegion().begin());
