diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -194,7 +194,7 @@ /*methodBody=*/"", /*defaultImplementation=*/[{ return getNumIterators(getParallelIteratorTypeName(), - $_op.iterator_types()); + $_op.getIteratorTypesArray()); }] >, InterfaceMethod< @@ -206,7 +206,7 @@ /*args=*/(ins "SmallVectorImpl &":$res), /*methodBody=*/"", /*defaultImplementation=*/[{ - return findPositionsOfType($_op.iterator_types(), + return findPositionsOfType($_op.getIteratorTypesArray(), getParallelIteratorTypeName(), res); }] >, @@ -220,7 +220,7 @@ /*methodBody=*/"", /*defaultImplementation=*/[{ return getNumIterators(getReductionIteratorTypeName(), - $_op.iterator_types()); + $_op.getIteratorTypesArray()); }] >, InterfaceMethod< @@ -232,7 +232,7 @@ /*args=*/(ins "SmallVectorImpl &":$res), /*methodBody=*/"", /*defaultImplementation=*/[{ - return findPositionsOfType($_op.iterator_types(), + return findPositionsOfType($_op.getIteratorTypesArray(), getReductionIteratorTypeName(), res); }] >, @@ -246,7 +246,7 @@ /*methodBody=*/"", /*defaultImplementation=*/[{ return getNumIterators(getWindowIteratorTypeName(), - $_op.iterator_types()); + $_op.getIteratorTypesArray()); }] >, InterfaceMethod< @@ -258,7 +258,7 @@ /*args=*/(ins "SmallVectorImpl &":$res), /*methodBody=*/"", /*defaultImplementation=*/[{ - return findPositionsOfType($_op.iterator_types(), + return findPositionsOfType($_op.getIteratorTypesArray(), getWindowIteratorTypeName(), res); }] >, @@ -271,7 +271,7 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return getNumIterators($_op.iterator_types()); + return getNumIterators($_op.getIteratorTypesArray()); }] >, InterfaceMethod< @@ -284,7 +284,7 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - auto iters = $_op.iterator_types(); + auto iters = $_op.getIteratorTypesArray(); return iters.size() == 1 && getNumIterators(getReductionIteratorTypeName(), iters) == 1; }]>, @@ -759,7 +759,7 @@ ArrayAttr getIteratorTypes() { return iterator_types(); } SmallVector getIteratorTypeNames() { - return llvm::to_vector(getIteratorTypes().getAsValueRange()); + return getIteratorTypesArray(); } //========================================================================// diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -45,11 +45,11 @@ /// `[0, permutation.size())`. bool isPermutation(ArrayRef permutation); -/// Check if `attr` has "parallel" iterator type semantics. -bool isParallelIterator(Attribute attr); +/// Check if iterator type has "parallel" semantics. +bool isParallelIterator(StringRef iteratorType); -/// Check if `attr` has "reduction" iterator type semantics. -bool isReductionIterator(Attribute attr); +/// Check if iterator type has "reduction" semantics. +bool isReductionIterator(StringRef iteratorType); /// Helper function that creates a memref::DimOp or tensor::DimOp depending on /// the type of `source`. @@ -488,7 +488,7 @@ template struct GenerateLoopNest { static void doit(OpBuilder &b, Location loc, ArrayRef loopRanges, - LinalgOp linalgOp, ArrayRef iteratorTypes, + LinalgOp linalgOp, ArrayRef iteratorTypes, function_ref bodyBuilderFn, diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -97,16 +97,15 @@ } /// Returns the iterator of a certain type. -inline unsigned getNumIterators(StringRef name, ArrayAttr iteratorTypes) { +inline unsigned getNumIterators(StringRef name, + ArrayRef iteratorTypes) { auto names = getAllIteratorTypeNames(); (void)names; assert(llvm::is_contained(names, name)); - return llvm::count_if(iteratorTypes, [name](Attribute a) { - return a.cast().getValue() == name; - }); + return llvm::count(iteratorTypes, name); } -inline unsigned getNumIterators(ArrayAttr iteratorTypes) { +inline unsigned getNumIterators(ArrayRef iteratorTypes) { unsigned res = 0; for (auto n : getAllIteratorTypeNames()) res += getNumIterators(n, iteratorTypes); @@ -114,11 +113,10 @@ } /// Return positions in `iteratorTypes` that match `iteratorTypeName`. -inline void findPositionsOfType(ArrayAttr iteratorTypes, +inline void findPositionsOfType(ArrayRef iteratorTypes, StringRef iteratorTypeName, SmallVectorImpl &res) { - for (const auto &en : - llvm::enumerate(iteratorTypes.getAsValueRange())) { + for (const auto &en : llvm::enumerate(iteratorTypes)) { if (en.value() == iteratorTypeName) res.push_back(en.index()); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -297,7 +297,7 @@ !indexingMaps.back().isProjectedPermutation()) return MatchConvolutionResult::NotProjectedPermutations; - auto iteratorTypesRange = linalgOp.getIteratorTypesArray(); + auto iteratorTypes = linalgOp.getIteratorTypesArray(); llvm::SmallDenseSet outputDims = getPreservedDims(indexingMaps.back()); @@ -321,8 +321,7 @@ if (inputExprWalker.unConvolvedDims.count(outputDim) && !filterDims.count(outputDim)) { // Batch dimension. - if (*std::next(iteratorTypesRange.begin(), outputDim) != - getParallelIteratorTypeName()) + if (iteratorTypes[outputDim] != getParallelIteratorTypeName()) return MatchConvolutionResult::OutputDimsNotParallel; allLoopDims.insert(outputDim); continue; @@ -330,8 +329,7 @@ if (inputExprWalker.convolvedDims.count(outputDim) && !filterDims.count(outputDim)) { // Output image Loop dimension. - if (*std::next(iteratorTypesRange.begin(), outputDim) != - getParallelIteratorTypeName()) + if (iteratorTypes[outputDim] != getParallelIteratorTypeName()) return MatchConvolutionResult::OutputDimsNotParallel; allLoopDims.insert(outputDim); continue; @@ -340,8 +338,7 @@ !inputExprWalker.unConvolvedDims.count(outputDim) && filterDims.count(outputDim)) { // Output channel dimension. - if (*std::next(iteratorTypesRange.begin(), outputDim) != - getParallelIteratorTypeName()) + if (iteratorTypes[outputDim] != getParallelIteratorTypeName()) return MatchConvolutionResult::OutputDimsNotParallel; allLoopDims.insert(outputDim); continue; @@ -349,8 +346,7 @@ if (inputExprWalker.unConvolvedDims.count(outputDim) && filterDims.count(outputDim)) { // Depth multiplier. - if (*std::next(iteratorTypesRange.begin(), outputDim) != - getParallelIteratorTypeName()) + if (iteratorTypes[outputDim] != getParallelIteratorTypeName()) return MatchConvolutionResult::OutputDimsNotParallel; allLoopDims.insert(outputDim); continue; @@ -368,8 +364,7 @@ if (inputExprWalker.convolvedDims.count(filterDim) && !outputDims.count(filterDim)) { // Filter loop dimension. - if (*std::next(iteratorTypesRange.begin(), filterDim) != - getReductionIteratorTypeName()) + if (iteratorTypes[filterDim] != getReductionIteratorTypeName()) return MatchConvolutionResult::NonOutputDimNotReduction; if (allLoopDims.count(filterDim)) return MatchConvolutionResult::NonConvolutionLoop; @@ -379,8 +374,7 @@ if (inputExprWalker.unConvolvedDims.count(filterDim) && !outputDims.count(filterDim)) { // Input channel dimension. - if (*std::next(iteratorTypesRange.begin(), filterDim) != - getReductionIteratorTypeName()) + if (iteratorTypes[filterDim] != getReductionIteratorTypeName()) return MatchConvolutionResult::NonOutputDimNotReduction; if (allLoopDims.count(filterDim)) return MatchConvolutionResult::NonConvolutionLoop; @@ -634,8 +628,7 @@ LinalgOp linalgOp = cast(op); // Check all iterator types are known. - auto iteratorTypesRange = - linalgOp.iterator_types().getAsValueRange(); + auto iteratorTypesRange = linalgOp.getIteratorTypesArray(); for (StringRef iteratorType : iteratorTypesRange) { if (!llvm::is_contained(getAllIteratorTypeNames(), iteratorType) || !utils::symbolizeIteratorType(iteratorType).has_value()) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1055,17 +1055,12 @@ } // Compute all the loops with the reduction iterator types. - SmallVector reductionDims; - for (const auto &iteratorType : - llvm::enumerate(genericOp.getIteratorTypes())) { - if (isReductionIterator(iteratorType.value())) { - reductionDims.push_back(iteratorType.index()); - } - } + SmallVector reductionDims; + genericOp.getReductionDims(reductionDims); llvm::SmallDenseSet processedIterationDims; AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand); - auto iteratorTypes = genericOp.getIteratorTypes().getValue(); + auto iteratorTypes = genericOp.getIteratorTypesArray(); SmallVector iterationSpaceReassociation; for (ReassociationIndicesRef foldedRangeDims : reassociation) { assert(!foldedRangeDims.empty() && "unexpected empty reassociation"); @@ -1085,7 +1080,7 @@ continue; // Check that all folded iterator types are all parallel or all reductions. - Attribute startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]]; + StringRef startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]]; if (!isParallelIterator(startIteratorType) && !isReductionIterator(startIteratorType)) continue; diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -433,12 +433,9 @@ // Search the number of outer parallel loops to separate them from possible // inner reduction dimensions. - SmallVector iterTypes = - llvm::to_vector<6>(consumerOp.iterator_types().getAsRange()); + SmallVector iterTypes = consumerOp.getIteratorTypesArray(); applyPermutationToVector(iterTypes, tileInterchange); - auto *it = find_if(iterTypes, [&](StringAttr iterType) { - return !isParallelIterator(iterType); - }); + auto *it = find_if_not(iterTypes, isParallelIterator); int64_t split = std::distance(iterTypes.begin(), it); // Helper to fuse the producers greedily using a queue of fusion candidates. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -217,7 +217,7 @@ "expected linalg op with buffer semantics"); auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc()); - auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue()); + auto iteratorTypes = linalgOp.getIteratorTypesArray(); SmallVector allIvs; GenerateLoopNest::doit( diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -215,10 +215,10 @@ newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr, op.getContext())); SmallVector newIteratorTypes; - for (auto &it : llvm::enumerate(op.iterator_types())) { + for (auto &it : llvm::enumerate(op.getIteratorTypesArray())) { if (insertSplitDimension == it.index() && !control.innerParallel) newIteratorTypes.push_back(getParallelIteratorTypeName()); - newIteratorTypes.push_back(it.value().cast().getValue()); + newIteratorTypes.push_back(it.value()); if (insertSplitDimension == it.index() && control.innerParallel) newIteratorTypes.push_back(getParallelIteratorTypeName()); } @@ -413,8 +413,7 @@ // Step 4. Create the new op matching the original op with an extra parallel // dimension. - SmallVector iteratorTypes = - llvm::to_vector<4>(op.getIteratorTypes().getAsValueRange()); + auto iteratorTypes = op.getIteratorTypesArray(); iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos, getParallelIteratorTypeName()); GenericOp genericOp = diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -427,9 +427,8 @@ auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges( b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes); - SmallVector iteratorTypes; - for (const auto &attr : - enumerate(op.iterator_types().cast().getValue())) { + SmallVector iteratorTypes; + for (const auto &attr : enumerate(op.getIteratorTypesArray())) { if (loopIndexToRangeIndex.count(attr.index())) iteratorTypes.push_back(attr.value()); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -190,14 +190,8 @@ } static SmallVector getReductionMask(LinalgOp linalgOp) { - unsigned idx = 0; - SmallVector reductionMask(linalgOp.iterator_types().size(), false); - for (auto attr : linalgOp.iterator_types()) { - if (isReductionIterator(attr)) - reductionMask[idx] = true; - ++idx; - } - return reductionMask; + return llvm::to_vector( + llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator)); } /// Build a vector.transfer_write of `value` into `outputOperand` at indices set @@ -540,7 +534,7 @@ // TODO: probably need some extra checks for reduction followed by consumer // ops that may not commute (e.g. linear reduction + non-linear instructions). static LogicalResult reductionPreconditions(LinalgOp op) { - if (llvm::none_of(op.iterator_types(), isReductionIterator)) { + if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) { LDBG("reduction precondition failed: no reduction iterator"); return failure(); } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -199,14 +199,12 @@ return count(indexCounts, 1) == static_cast(permutation.size()); } -bool isParallelIterator(Attribute attr) { - auto strAttr = attr.dyn_cast_or_null(); - return strAttr && strAttr.getValue() == getParallelIteratorTypeName(); +bool isParallelIterator(StringRef iteratorType) { + return iteratorType == getParallelIteratorTypeName(); } -bool isReductionIterator(Attribute attr) { - auto strAttr = attr.dyn_cast_or_null(); - return strAttr && strAttr.getValue() == getReductionIteratorTypeName(); +bool isReductionIterator(StringRef iteratorType) { + return iteratorType == getReductionIteratorTypeName(); } /// Helper function that creates a memref::DimOp or tensor::DimOp depending on @@ -484,7 +482,7 @@ template <> void GenerateLoopNest::doit( OpBuilder &b, Location loc, ArrayRef loopRanges, LinalgOp linalgOp, - ArrayRef iteratorTypes, + ArrayRef iteratorTypes, function_ref bodyBuilderFn, @@ -527,7 +525,7 @@ template <> void GenerateLoopNest::doit( OpBuilder &b, Location loc, ArrayRef loopRanges, LinalgOp linalgOp, - ArrayRef iteratorTypes, + ArrayRef iteratorTypes, function_ref bodyBuilderFn, @@ -577,7 +575,7 @@ // exceeds 10. static void generateParallelLoopNest( OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs, - ValueRange steps, ArrayRef iteratorTypes, + ValueRange steps, ArrayRef iteratorTypes, ArrayRef procInfo, function_ref bodyBuilderFn, SmallVectorImpl &ivStorage) { @@ -692,7 +690,7 @@ template <> void GenerateLoopNest::doit( OpBuilder &b, Location loc, ArrayRef loopRanges, LinalgOp linalgOp, - ArrayRef iteratorTypes, + ArrayRef iteratorTypes, function_ref bodyBuilderFn, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -214,7 +214,7 @@ /// as we use adj matrix for the graph. /// The sorted result will put the first Reduction iterator to the /// latest possible index. -static bool topSortOptimal(unsigned n, ArrayRef iteratorTypes, +static bool topSortOptimal(unsigned n, ArrayRef iteratorTypes, std::vector &topSort, std::vector &inDegree, std::vector> &adjM) { @@ -289,7 +289,7 @@ unsigned n = op.getNumLoops(); std::vector> adjM(n, std::vector(n, false)); std::vector inDegree(n, 0); // in-degree of each node. - auto iteratorTypes = op.iterator_types().getValue(); + auto iteratorTypes = op.getIteratorTypesArray(); // Iterate over the indexing maps of every tensor in the tensor expression. for (OpOperand *t : op.getInputAndOutputOperands()) { // Skip tensor during cycle resolution. @@ -361,7 +361,7 @@ // An all-dense annotated "sparse" output tensor becomes a linearized random // access 1-dim memref. Also admissible since insertions cannot occur. bool allDense = true; - auto iteratorTypes = op.iterator_types().getValue(); + auto iteratorTypes = op.getIteratorTypesArray(); unsigned numLoops = iteratorTypes.size(); for (unsigned i = 0; i < numLoops; i++) if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) || @@ -1299,7 +1299,7 @@ unsigned fb = indices.find_first(); unsigned tensor = merger.tensor(fb); assert(idx == merger.index(fb)); - auto iteratorTypes = op.iterator_types().getValue(); + auto iteratorTypes = op.getIteratorTypesArray(); bool isReduction = linalg::isReductionIterator(iteratorTypes[idx]); bool isSparse = merger.isDimLevelType(fb, DimLvlType::kCompressed) || merger.isDimLevelType(fb, DimLvlType::kSingleton);