diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -497,28 +497,21 @@ return $_op.getBody(); }] >, - InterfaceMethod< - /*desc=*/[{ - Return the iterator types attribute within the current operation. - }], - /*retTy=*/"ArrayAttr", - /*methodName=*/"iterator_types", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return $_op.getIteratorTypes(); - }] - >, InterfaceMethod< /*desc=*/[{ Return iterator types in the current operation. + + Default implementation assumes that the operation has an attribute + `iterator_types`, but it's not always the case. Sometimes iterator types + can be inferred from other parameters and in such cases default + getIteratorTypesArray should be overridden. }], /*retTy=*/"SmallVector", /*methodName=*/"getIteratorTypesArray", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - auto range = $_op.iterator_types().template getAsValueRange(); + auto range = $_op.getIteratorTypes().template getAsValueRange(); return {range.begin(), range.end()}; }] >, @@ -773,9 +766,6 @@ LogicalResult reifyResultShapes(OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes); - // TODO: Remove once prefixing is flipped. - ArrayAttr getIteratorTypes() { return iterator_types(); } - SmallVector getIteratorTypeNames() { return getIteratorTypesArray(); } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -264,7 +264,7 @@ let extraClassDeclaration = structuredOpsBaseDecls # [{ // Implement functions necessary for LinalgStructuredInterface. 
- ArrayAttr getIteratorTypes(); + SmallVector getIteratorTypesArray(); ArrayAttr getIndexingMaps(); std::string getLibraryCallName() { return "op_has_no_registered_library_name"; @@ -334,7 +334,7 @@ let extraClassDeclaration = structuredOpsBaseDecls # [{ // Declare functions necessary for LinalgStructuredInterface. - ArrayAttr getIteratorTypes(); + SmallVector getIteratorTypesArray(); ArrayAttr getIndexingMaps(); std::string getLibraryCallName() { return "op_has_no_registered_library_name"; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1393,11 +1393,9 @@ return success(); } -ArrayAttr MapOp::getIteratorTypes() { +SmallVector MapOp::getIteratorTypesArray() { int64_t rank = getInit().getType().getRank(); - return Builder(getContext()) - .getStrArrayAttr( - SmallVector(rank, getParallelIteratorTypeName())); + return SmallVector(rank, getParallelIteratorTypeName()); } ArrayAttr MapOp::getIndexingMaps() { @@ -1435,13 +1433,13 @@ setNameFn(getResults().front(), "reduced"); } -ArrayAttr ReduceOp::getIteratorTypes() { +SmallVector ReduceOp::getIteratorTypesArray() { int64_t inputRank = getInputs()[0].getType().cast().getRank(); SmallVector iteratorTypes(inputRank, getParallelIteratorTypeName()); for (int64_t reductionDim : getDimensions()) iteratorTypes[reductionDim] = getReductionIteratorTypeName(); - return Builder(getContext()).getStrArrayAttr(iteratorTypes); + return iteratorTypes; } ArrayAttr ReduceOp::getIndexingMaps() { diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -92,11 +92,9 @@ /// Return the loop iterator type. 
SmallVector getLoopIteratorTypes(Operation *op) const { LinalgOpTy concreteOp = cast(op); - return llvm::to_vector( - llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) { - return utils::symbolizeIteratorType( - strAttr.cast().getValue()) - .value(); + return llvm::to_vector(llvm::map_range( + concreteOp.getIteratorTypesArray(), [](StringRef iteratorType) { + return utils::symbolizeIteratorType(iteratorType).value(); })); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -250,7 +250,7 @@ // Fuse producer and consumer into a new generic op. auto fusedOp = rewriter.create( loc, op.getResult(0).getType(), inputOps, outputOps, - rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.iterator_types(), + rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(), /*doc=*/nullptr, /*library_call=*/nullptr); Block &prodBlock = prod.getRegion().front(); Block &consBlock = op.getRegion().front(); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -1857,7 +1857,7 @@ if (op.getNumOutputs() != 1) return failure(); unsigned numTensors = op.getNumInputsAndOutputs(); - unsigned numLoops = op.iterator_types().getValue().size(); + unsigned numLoops = op.getNumLoops(); Merger merger(numTensors, numLoops); if (!findSparseAnnotations(merger, op)) return failure(); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2816,7 +2816,7 @@ return ®ionBuilder; } - 
mlir::ArrayAttr iterator_types() { + mlir::ArrayAttr getIteratorTypes() { return getOperation()->getAttrOfType("iterator_types"); } @@ -2875,7 +2875,7 @@ return ®ionBuilder; } - mlir::ArrayAttr iterator_types() { + mlir::ArrayAttr getIteratorTypes() { return getOperation()->getAttrOfType("iterator_types"); } diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -235,7 +235,7 @@ - !ScalarExpression scalar_arg: value -# IMPL: Test3Op::iterator_types() { +# IMPL: Test3Op::getIteratorTypesArray() { # IMPL-NEXT: int64_t rank = getRank(getOutputOperand(0)); # IMPL: Test3Op::getIndexingMaps() { diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -553,7 +553,7 @@ let extraClassDeclaration = structuredOpsBaseDecls # [{{ // Auto-generated. - ArrayAttr iterator_types(); + SmallVector getIteratorTypesArray(); ArrayAttr getIndexingMaps(); static void regionBuilder(ImplicitLocOpBuilder &b, Block &block, ArrayRef attrs); @@ -587,24 +587,24 @@ }]> )FMT"; -// The iterator_types() method for structured ops. Parameters: +// The getIteratorTypesArray() method for structured ops. Parameters: // {0}: Class name // {1}: Comma interleaved iterator type names. static const char structuredOpIteratorTypesFormat[] = R"FMT( -ArrayAttr {0}::iterator_types() {{ - return Builder(getContext()).getStrArrayAttr(SmallVector{{ {1} }); +SmallVector {0}::getIteratorTypesArray() {{ + return SmallVector{{ {1} }; } )FMT"; -// The iterator_types() method for rank polymorphic structured ops. Parameters: +// The getIteratorTypesArray() method for rank polymorphic structured ops. 
+// Parameters: // {0}: Class name static const char rankPolyStructuredOpIteratorTypesFormat[] = R"FMT( -ArrayAttr {0}::iterator_types() {{ +SmallVector {0}::getIteratorTypesArray() {{ int64_t rank = getRank(getOutputOperand(0)); - return Builder(getContext()).getStrArrayAttr( - SmallVector(rank, getParallelIteratorTypeName())); + return SmallVector(rank, getParallelIteratorTypeName()); } )FMT";