diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -13,6 +13,7 @@ #ifndef LINALG_BASE #define LINALG_BASE +include "mlir/Dialect/Utils/StructuredOpsUtils.td" include "mlir/Dialect/Linalg/IR/LinalgEnums.td" include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" @@ -71,4 +72,10 @@ let assemblyFormat = "`<` $value `>`"; } +def IteratorTypeEnum : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} +def IteratorTypeArrayAttr : TypedArrayAttrBase; + #endif // LINALG_BASE diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -25,6 +25,7 @@ namespace mlir { namespace linalg { +class IteratorTypeAttr; class LinalgOp; namespace detail { diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -193,8 +193,8 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return getNumIterators(getParallelIteratorTypeName(), - $_op.getIteratorTypesArray()); + return llvm::count($_op.getIteratorTypesArray(), + utils::IteratorType::parallel); }] >, InterfaceMethod< @@ -207,7 +207,7 @@ /*methodBody=*/"", /*defaultImplementation=*/[{ return findPositionsOfType($_op.getIteratorTypesArray(), - getParallelIteratorTypeName(), res); + utils::IteratorType::parallel, res); }] >, InterfaceMethod< @@ -219,8 +219,8 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return getNumIterators(getReductionIteratorTypeName(), - $_op.getIteratorTypesArray()); + return llvm::count($_op.getIteratorTypesArray(), + utils::IteratorType::reduction); }] >, InterfaceMethod< @@ -233,33 +233,7 @@ /*methodBody=*/"", /*defaultImplementation=*/[{ return findPositionsOfType($_op.getIteratorTypesArray(), - getReductionIteratorTypeName(), res); - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the number of window loops. - }], - /*retTy=*/"unsigned", - /*methodName=*/"getNumWindowLoops", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return getNumIterators(getWindowIteratorTypeName(), - $_op.getIteratorTypesArray()); - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the dims that are window loops. - }], - /*retTy=*/"void", - /*methodName=*/"getWindowDims", - /*args=*/(ins "SmallVectorImpl &":$res), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return findPositionsOfType($_op.getIteratorTypesArray(), - getWindowIteratorTypeName(), res); + utils::IteratorType::reduction, res); }] >, InterfaceMethod< @@ -271,7 +245,7 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return getNumIterators($_op.getIteratorTypesArray()); + return $_op.getIteratorTypesArray().size(); }] >, InterfaceMethod< @@ -286,7 +260,7 @@ /*defaultImplementation=*/[{ auto iters = $_op.getIteratorTypesArray(); return iters.size() == 1 && - getNumIterators(getReductionIteratorTypeName(), iters) == 1; + llvm::count(iters, utils::IteratorType::reduction) == 1; }]>, //===------------------------------------------------------------------===// // Input and Init arguments handling. @@ -506,12 +480,14 @@ can be infered from other parameters and in such cases default getIteratorTypesArray should be overriden. }], - /*retTy=*/"SmallVector", + /*retTy=*/"SmallVector", /*methodName=*/"getIteratorTypesArray", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - auto range = $_op.getIteratorTypes().template getAsValueRange(); + auto range = $_op.getIteratorTypes() + .template getAsValueRange(); return {range.begin(), range.end()}; }] >, @@ -767,10 +743,6 @@ LogicalResult reifyResultShapes(OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes); - SmallVector getIteratorTypeNames() { - return getIteratorTypesArray(); - } - //========================================================================// // Forwarding functions to access interface methods from the // DestinationStyleOpInterface. diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -163,7 +163,7 @@ let arguments = (ins Variadic:$inputs, Variadic:$outputs, AffineMapArrayAttr:$indexing_maps, - ArrayAttr:$iterator_types, + IteratorTypeArrayAttr:$iterator_types, OptionalAttr:$doc, OptionalAttr:$library_call); let results = (outs Variadic:$result_tensors); @@ -178,22 +178,22 @@ CArg<"ArrayRef", "{}">:$attributes)>, OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, "ValueRange":$outputs, "ArrayRef":$indexingMaps, - "ArrayRef":$iteratorTypes, "StringRef":$doc, + "ArrayRef":$iteratorTypes, "StringRef":$doc, "StringRef":$libraryCall, CArg<"function_ref", "nullptr">, CArg<"ArrayRef", "{}">:$attributes)>, OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers, - "ArrayRef":$indexingMaps, "ArrayRef":$iteratorTypes, + "ArrayRef":$indexingMaps, "ArrayRef":$iteratorTypes, "StringRef":$doc, "StringRef":$libraryCall, CArg<"function_ref", "nullptr">, CArg<"ArrayRef", "{}">:$attributes)>, OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, "ValueRange":$outputs, "ArrayRef":$indexingMaps, - "ArrayRef":$iteratorTypes, + "ArrayRef":$iteratorTypes, CArg<"function_ref", "nullptr">, CArg<"ArrayRef", "{}">:$attributes)>, OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers, - "ArrayRef":$indexingMaps, "ArrayRef":$iteratorTypes, + "ArrayRef":$indexingMaps, "ArrayRef":$iteratorTypes, CArg<"function_ref", "nullptr">, CArg<"ArrayRef", "{}">:$attributes)> ]; @@ -275,7 +275,7 @@ let extraClassDeclaration = structuredOpsBaseDecls # [{ // Implement functions necessary for LinalgStructuredInterface. - SmallVector getIteratorTypesArray(); + SmallVector getIteratorTypesArray(); ArrayAttr getIndexingMaps(); std::string getLibraryCallName() { return "op_has_no_registered_library_name"; @@ -356,7 +356,7 @@ let extraClassDeclaration = structuredOpsBaseDecls # [{ // Declare functions necessary for LinalgStructuredInterface. - SmallVector getIteratorTypesArray(); + SmallVector getIteratorTypesArray(); ArrayAttr getIndexingMaps(); std::string getLibraryCallName() { return "op_has_no_registered_library_name"; @@ -426,7 +426,7 @@ let extraClassDeclaration = structuredOpsBaseDecls # [{ // Declare functions necessary for LinalgStructuredInterface. - SmallVector getIteratorTypesArray(); + SmallVector getIteratorTypesArray(); ArrayAttr getIndexingMaps(); std::string getLibraryCallName() { return "op_has_no_registered_library_name"; @@ -502,7 +502,7 @@ let extraClassDeclaration = structuredOpsBaseDecls # [{ // Declare functions necessary for LinalgStructuredInterface. - SmallVector getIteratorTypesArray(); + SmallVector getIteratorTypesArray(); ArrayAttr getIndexingMaps(); std::string getLibraryCallName() { return "op_has_no_registered_library_name"; diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -42,10 +42,10 @@ bool isElementwise(LinalgOp op); /// Check if iterator type has "parallel" semantics. -bool isParallelIterator(StringRef iteratorType); +bool isParallelIterator(utils::IteratorType iteratorType); /// Check if iterator type has "reduction" semantics. -bool isReductionIterator(StringRef iteratorType); +bool isReductionIterator(utils::IteratorType iteratorType); /// Helper function that creates a memref::DimOp or tensor::DimOp depending on /// the type of `source`. @@ -480,7 +480,8 @@ template struct GenerateLoopNest { static void doit(OpBuilder &b, Location loc, ArrayRef loopRanges, - LinalgOp linalgOp, ArrayRef iteratorTypes, + LinalgOp linalgOp, + ArrayRef iteratorTypes, function_ref bodyBuilderFn, diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h --- a/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h @@ -22,7 +22,8 @@ namespace tosa { // Creates a SmallVector of Stringrefs for N parallel loops -SmallVector getNParallelLoopsAttrs(unsigned nParallelLoops); +SmallVector +getNParallelLoopsAttrs(unsigned nParallelLoops); // Takes a vector of values and condenses them to a vector with no gaps. SmallVector condenseValues(const SmallVector &values); diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -21,7 +21,6 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Location.h" #include "mlir/Support/LLVM.h" -#include "llvm/ADT/StringRef.h" // Pull in all enum type definitions and utility function declarations. #include "mlir/Dialect/Utils/DialectUtilsEnums.h.inc" @@ -48,42 +47,9 @@ /// the reduction. bool isRowMajorBatchMatmul(ArrayAttr indexingMaps); -/// Use to encode that a particular iterator type has parallel semantics. -constexpr StringRef getParallelIteratorTypeName() { return "parallel"; } - -/// Use to encode that a particular iterator type has reduction semantics. -constexpr StringRef getReductionIteratorTypeName() { return "reduction"; } - -/// Use to encode that a particular iterator type has window semantics. -constexpr StringRef getWindowIteratorTypeName() { return "window"; } - -/// Use to encode that a particular iterator type has window semantics. -inline ArrayRef getAllIteratorTypeNames() { - static constexpr StringRef names[3] = {getParallelIteratorTypeName(), - getReductionIteratorTypeName(), - getWindowIteratorTypeName()}; - return llvm::makeArrayRef(names); -} - -/// Returns the iterator of a certain type. -inline unsigned getNumIterators(StringRef name, - ArrayRef iteratorTypes) { - auto names = getAllIteratorTypeNames(); - (void)names; - assert(llvm::is_contained(names, name)); - return llvm::count(iteratorTypes, name); -} - -inline unsigned getNumIterators(ArrayRef iteratorTypes) { - unsigned res = 0; - for (auto n : getAllIteratorTypeNames()) - res += getNumIterators(n, iteratorTypes); - return res; -} - /// Return positions in `iteratorTypes` that match `iteratorTypeName`. -inline void findPositionsOfType(ArrayRef iteratorTypes, - StringRef iteratorTypeName, +inline void findPositionsOfType(ArrayRef iteratorTypes, + utils::IteratorType iteratorTypeName, SmallVectorImpl &res) { for (const auto &en : llvm::enumerate(iteratorTypes)) { if (en.value() == iteratorTypeName) @@ -94,29 +60,28 @@ /// Helper StructuredGenerator class to manipulate and rewrite ops with /// `StructuredOpInterface`. This is templated for now because VectorOps do not /// yet implement the StructuredOpInterface itself. -template +template class StructuredGenerator { public: using MapList = ArrayRef>; struct IteratorType { - IteratorType(StringRef strRef) : strRef(strRef) {} - bool isOfType(StringRef typeName) const { return typeName == strRef; } - StringRef strRef; + IteratorType(IteratorTypeT iter) : iter(iter) {} + bool isOfType(IteratorTypeT expectedIter) const { + return expectedIter == iter; + } + IteratorTypeT iter; }; struct Par : public IteratorType { - Par() : IteratorType(getParallelIteratorTypeName()) {} + Par() : IteratorType(IteratorTypeT::parallel) {} }; struct Red : public IteratorType { - Red() : IteratorType(getReductionIteratorTypeName()) {} - }; - struct Win : public IteratorType { - Win() : IteratorType(getWindowIteratorTypeName()) {} + Red() : IteratorType(IteratorTypeT::reduction) {} }; StructuredGenerator(OpBuilder &builder, StructuredOpInterface op) : builder(builder), ctx(op.getContext()), loc(op.getLoc()), - iterators(op.getIteratorTypeNames()), maps(op.getIndexingMapsArray()), + iterators(op.getIteratorTypesArray()), maps(op.getIndexingMapsArray()), op(op) {} bool iters(ArrayRef its) { @@ -138,7 +103,7 @@ OpBuilder &builder; MLIRContext *ctx; Location loc; - SmallVector iterators; + SmallVector iterators; SmallVector maps; Operation *op; }; diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -269,12 +269,11 @@ return CombiningKind::ADD; } - // Returns iterator types in string format. - SmallVector getIteratorTypeNames() { - return llvm::to_vector( - llvm::map_range(getIteratorTypes(), [](Attribute a) { - return stringifyIteratorType(a.cast().getValue()); - })); + SmallVector getIteratorTypesArray() { + auto range = + getIteratorTypes() + .template getAsValueRange(); + return {range.begin(), range.end()}; } }]; diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -791,12 +791,12 @@ SmallVector srcExprs; SmallVector dstExprs; - SmallVector iteratorTypes; + SmallVector iteratorTypes; for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) { srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext())); - iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName() - : getParallelIteratorTypeName()); + iteratorTypes.push_back(axis == i ? utils::IteratorType::reduction + : utils::IteratorType::parallel); if (axis != i) dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext())); } @@ -1383,7 +1383,8 @@ auto inputMap = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, inputExprs, builder.getContext()); auto resultMap = rewriter.getMultiDimIdentityMap(resultTy.getRank()); - SmallVector iterators(4, getParallelIteratorTypeName()); + SmallVector iterators(4, + utils::IteratorType::parallel); Value empty = builder.create( resultTy.getShape(), resultTy.getElementType(), outputDynSize); @@ -2083,9 +2084,9 @@ // We need to reduce along the arg-max axis, with parallel operations along // the rest. - SmallVector iteratorTypes; - iteratorTypes.resize(inputTy.getRank(), getParallelIteratorTypeName()); - iteratorTypes[axis] = getReductionIteratorTypeName(); + SmallVector iteratorTypes; + iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel); + iteratorTypes[axis] = utils::IteratorType::reduction; SmallVector srcExprs; SmallVector dstExprs; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -321,7 +321,7 @@ if (inputExprWalker.unConvolvedDims.count(outputDim) && !filterDims.count(outputDim)) { // Batch dimension. - if (iteratorTypes[outputDim] != getParallelIteratorTypeName()) + if (iteratorTypes[outputDim] != utils::IteratorType::parallel) return MatchConvolutionResult::OutputDimsNotParallel; allLoopDims.insert(outputDim); continue; @@ -329,7 +329,7 @@ if (inputExprWalker.convolvedDims.count(outputDim) && !filterDims.count(outputDim)) { // Output image Loop dimension. - if (iteratorTypes[outputDim] != getParallelIteratorTypeName()) + if (iteratorTypes[outputDim] != utils::IteratorType::parallel) return MatchConvolutionResult::OutputDimsNotParallel; allLoopDims.insert(outputDim); continue; @@ -338,7 +338,7 @@ !inputExprWalker.unConvolvedDims.count(outputDim) && filterDims.count(outputDim)) { // Output channel dimension. - if (iteratorTypes[outputDim] != getParallelIteratorTypeName()) + if (iteratorTypes[outputDim] != utils::IteratorType::parallel) return MatchConvolutionResult::OutputDimsNotParallel; allLoopDims.insert(outputDim); continue; @@ -346,7 +346,7 @@ if (inputExprWalker.unConvolvedDims.count(outputDim) && filterDims.count(outputDim)) { // Depth multiplier. - if (iteratorTypes[outputDim] != getParallelIteratorTypeName()) + if (iteratorTypes[outputDim] != utils::IteratorType::parallel) return MatchConvolutionResult::OutputDimsNotParallel; allLoopDims.insert(outputDim); continue; @@ -364,7 +364,7 @@ if (inputExprWalker.convolvedDims.count(filterDim) && !outputDims.count(filterDim)) { // Filter loop dimension. - if (iteratorTypes[filterDim] != getReductionIteratorTypeName()) + if (iteratorTypes[filterDim] != utils::IteratorType::reduction) return MatchConvolutionResult::NonOutputDimNotReduction; if (allLoopDims.count(filterDim)) return MatchConvolutionResult::NonConvolutionLoop; @@ -374,7 +374,7 @@ if (inputExprWalker.unConvolvedDims.count(filterDim) && !outputDims.count(filterDim)) { // Input channel dimension. - if (iteratorTypes[filterDim] != getReductionIteratorTypeName()) + if (iteratorTypes[filterDim] != utils::IteratorType::reduction) return MatchConvolutionResult::NonOutputDimNotReduction; if (allLoopDims.count(filterDim)) return MatchConvolutionResult::NonConvolutionLoop; @@ -619,15 +619,6 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { LinalgOp linalgOp = cast(op); - // Check all iterator types are known. - auto iteratorTypesRange = linalgOp.getIteratorTypesArray(); - for (StringRef iteratorType : iteratorTypesRange) { - if (!llvm::is_contained(getAllIteratorTypeNames(), iteratorType) || - !utils::symbolizeIteratorType(iteratorType).has_value()) - return op->emitOpError("unexpected iterator_type (") - << iteratorType << ")"; - } - // Before checking indexing maps, we need to make sure the attributes // referenced by it are valid. if (linalgOp.hasDynamicIndexingMaps()) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -705,12 +705,17 @@ void GenericOp::build( OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, - ArrayRef iteratorTypes, StringRef doc, StringRef libraryCall, + ArrayRef iteratorTypes, StringRef doc, + StringRef libraryCall, function_ref bodyBuild, ArrayRef attributes) { build(builder, result, resultTensorTypes, inputs, outputs, builder.getAffineMapArrayAttr(indexingMaps), - builder.getStrArrayAttr(iteratorTypes), + builder.getArrayAttr(llvm::to_vector(llvm::map_range( + iteratorTypes, + [&](utils::IteratorType iter) -> mlir::Attribute { + return IteratorTypeAttr::get(builder.getContext(), iter); + }))), doc.empty() ? StringAttr() : builder.getStringAttr(doc), libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall), bodyBuild, attributes); @@ -719,7 +724,8 @@ void GenericOp::build( OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, - ArrayRef iteratorTypes, StringRef doc, StringRef libraryCall, + ArrayRef iteratorTypes, StringRef doc, + StringRef libraryCall, function_ref bodyBuild, ArrayRef attributes) { build(builder, result, TypeRange{}, inputs, outputs, indexingMaps, @@ -729,7 +735,7 @@ void GenericOp::build( OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, - ArrayRef iteratorTypes, + ArrayRef iteratorTypes, function_ref bodyBuild, ArrayRef attributes) { build(builder, result, inputs, outputs, indexingMaps, iteratorTypes, @@ -740,7 +746,7 @@ void GenericOp::build( OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, - ArrayRef iteratorTypes, + ArrayRef iteratorTypes, function_ref bodyBuild, ArrayRef attributes) { build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps, @@ -758,9 +764,29 @@ llvm::StringSet<> genericAttrNamesSet; genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end()); SmallVector genericAttrs; - for (auto attr : (*this)->getAttrs()) - if (genericAttrNamesSet.count(attr.getName().strref()) > 0) + for (auto attr : (*this)->getAttrs()) { + if (attr.getName() == getIteratorTypesAttrName()) { + auto iteratorTypes = + attr.getValue() + .cast() + .getAsValueRange(); + // Convert IteratorType enums into the string representation. This is + // needed, because tests still use the old format when 'iterator_types' + // attribute is represented as an array of strings. + // TODO: Remove this conversion once tests are fixed. + SmallVector iteratorTypeNames = + llvm::to_vector(llvm::map_range( + iteratorTypes, [&](utils::IteratorType t) -> Attribute { + return StringAttr::get(getContext(), stringifyIteratorType(t)); + })); + + genericAttrs.emplace_back( + getIteratorTypesAttrName(), + ArrayAttr::get(getContext(), iteratorTypeNames)); + } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) { genericAttrs.push_back(attr); + } + } if (!genericAttrs.empty()) { auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs); p << genericDictAttr; @@ -805,6 +831,28 @@ result.attributes.assign(dictAttr.getValue().begin(), dictAttr.getValue().end()); + // Convert array of string into an array of IteratyType enums. This is needed, + // because tests still use the old format when 'iterator_types' attribute is + // represented as an array of strings. + // TODO: Remove this conversion once tests are fixed. + ArrayAttr iteratorTypes = + result.attributes.get(getIteratorTypesAttrName(result.name)) + .cast(); + + SmallVector iteratorTypeAttrs; + + for (StringRef s : iteratorTypes.getAsValueRange()) { + auto maybeIteratorType = utils::symbolizeIteratorType(s); + if (!maybeIteratorType.has_value()) + return parser.emitError(parser.getCurrentLocation()) + << "unexpected iterator_type (" << s << ")"; + + iteratorTypeAttrs.push_back( + IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value())); + } + result.attributes.set(getIteratorTypesAttrName(result.name), + parser.getBuilder().getArrayAttr(iteratorTypeAttrs)); + // Parsing is shared with named ops, except for the region. SmallVector inputTypes, outputTypes; if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) @@ -1418,9 +1466,9 @@ return success(); } -SmallVector MapOp::getIteratorTypesArray() { +SmallVector MapOp::getIteratorTypesArray() { int64_t rank = getInit().getType().getRank(); - return SmallVector(rank, getParallelIteratorTypeName()); + return SmallVector(rank, utils::IteratorType::parallel); } ArrayAttr MapOp::getIndexingMaps() { @@ -1476,12 +1524,12 @@ inputs, inits, bodyBuild); } -SmallVector ReduceOp::getIteratorTypesArray() { +SmallVector ReduceOp::getIteratorTypesArray() { int64_t inputRank = getInputs()[0].getType().cast().getRank(); - SmallVector iteratorTypes(inputRank, - getParallelIteratorTypeName()); + SmallVector iteratorTypes(inputRank, + utils::IteratorType::parallel); for (int64_t reductionDim : getDimensions()) - iteratorTypes[reductionDim] = getReductionIteratorTypeName(); + iteratorTypes[reductionDim] = utils::IteratorType::reduction; return iteratorTypes; } @@ -1753,9 +1801,9 @@ return success(); } -SmallVector TransposeOp::getIteratorTypesArray() { +SmallVector TransposeOp::getIteratorTypesArray() { int64_t rank = getInit().getType().getRank(); - return SmallVector(rank, getParallelIteratorTypeName()); + return SmallVector(rank, utils::IteratorType::parallel); } ArrayAttr TransposeOp::getIndexingMaps() { @@ -1891,9 +1939,9 @@ return success(); } -SmallVector BroadcastOp::getIteratorTypesArray() { +SmallVector BroadcastOp::getIteratorTypesArray() { int64_t rank = getInit().getType().getRank(); - return SmallVector(rank, getParallelIteratorTypeName()); + return SmallVector(rank, utils::IteratorType::parallel); } ArrayAttr BroadcastOp::getIndexingMaps() { diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -470,10 +470,9 @@ .getValue() .isProjectedPermutation(); }) && - genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() > 0 && - llvm::all_of(genericOp.getIteratorTypesArray(), [](StringRef it) { - return it == getParallelIteratorTypeName(); - }); + genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() > + 0 && + llvm::all_of(genericOp.getIteratorTypesArray(), isParallelIterator); } namespace { @@ -783,8 +782,8 @@ } // The iterator types of the expanded op are all parallel. - SmallVector iteratorTypes(expansionInfo.getExpandedOpNumDims(), - getParallelIteratorTypeName()); + SmallVector iteratorTypes( + expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel); TypeRange resultTypes = ValueRange(outputs).getTypes(); auto fusedOp = @@ -1083,7 +1082,8 @@ continue; // Check that all folded iterator types are all parallel or all reductions. - StringRef startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]]; + utils::IteratorType startIteratorType = + iteratorTypes[foldedIterationSpaceDims[0]]; if (!isParallelIterator(startIteratorType) && !isReductionIterator(startIteratorType)) continue; @@ -1235,10 +1235,10 @@ /// Get the iterator types for the collapsed operation given the original /// iterator types and collapsed dimensions. -static SmallVector -getCollapsedOpIteratorTypes(ArrayRef iteratorTypes, +static SmallVector +getCollapsedOpIteratorTypes(ArrayRef iteratorTypes, const CollapsingInfo &collapsingInfo) { - SmallVector collapsedIteratorTypes; + SmallVector collapsedIteratorTypes; for (ReassociationIndicesRef foldedIterDims : collapsingInfo.getCollapsedOpToOrigOpMapping()) { assert(!foldedIterDims.empty() && @@ -1246,8 +1246,7 @@ // Just pick the iterator type of the first folded dim. Pre-condition checks // expected to have checked that iterator types of all folded dimensions are // the same. - collapsedIteratorTypes.push_back( - iteratorTypes[foldedIterDims[0]].cast().getValue()); + collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]); } return collapsedIteratorTypes; } @@ -1406,8 +1405,8 @@ } // Get the iterator types for the operand. - SmallVector iteratorTypes = getCollapsedOpIteratorTypes( - genericOp.getIteratorTypes().getValue(), collapsingInfo); + SmallVector iteratorTypes = getCollapsedOpIteratorTypes( + genericOp.getIteratorTypesArray(), collapsingInfo); // Get the indexing maps. auto indexingMaps = llvm::to_vector( diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp @@ -91,8 +91,8 @@ SmallVector indexingMaps( op->getNumResults() + op->getNumOperands(), rewriter.getMultiDimIdentityMap(rank)); - SmallVector iteratorTypes(rank, - getParallelIteratorTypeName()); + SmallVector iteratorTypes( + rank, utils::IteratorType::parallel); auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op); rewriter.replaceOpWithNewOp( op, /*resultTensorTypes=*/op->getResultTypes(), diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -53,7 +53,7 @@ SmallVector inputs = linalgOp.getDpsInputOperands(); SmallVector outputs = linalgOp.getDpsInitOperands(); SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); - SmallVector iterators = linalgOp.getIteratorTypesArray(); + SmallVector iterators = linalgOp.getIteratorTypesArray(); SmallVector resultTypes = linalgOp.hasTensorSemantics() ? TypeRange(ValueRange(outputs)) : TypeRange{}; diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -162,13 +162,13 @@ newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr, op.getContext())); - SmallVector newIteratorTypes; + SmallVector newIteratorTypes; for (auto &it : llvm::enumerate(op.getIteratorTypesArray())) { if (insertSplitDimension == it.index() && !control.innerParallel) - newIteratorTypes.push_back(getParallelIteratorTypeName()); + newIteratorTypes.push_back(utils::IteratorType::parallel); newIteratorTypes.push_back(it.value()); if (insertSplitDimension == it.index() && control.innerParallel) - newIteratorTypes.push_back(getParallelIteratorTypeName()); + newIteratorTypes.push_back(utils::IteratorType::parallel); } // Create the new op matching the original op with an extra parallel // dimension. @@ -182,14 +182,14 @@ // from the previous op. unsigned intermRank = newOutputShape.size(); AffineMap inputMap = b.getMultiDimIdentityMap(intermRank); - SmallVector reductionIteratorTypes; + SmallVector reductionIteratorTypes; SmallVector exprs; for (unsigned i : llvm::seq(0, intermRank)) { if (insertSplitDimension == i) { - reductionIteratorTypes.push_back(getReductionIteratorTypeName()); + reductionIteratorTypes.push_back(utils::IteratorType::reduction); } else { exprs.push_back(b.getAffineDimExpr(i)); - reductionIteratorTypes.push_back(getParallelIteratorTypeName()); + reductionIteratorTypes.push_back(utils::IteratorType::parallel); } } AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext()); @@ -367,7 +367,7 @@ // dimension. auto iteratorTypes = op.getIteratorTypesArray(); iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos, - getParallelIteratorTypeName()); + utils::IteratorType::parallel); GenericOp genericOp = b.create(loc, ValueRange(newOutputs).getTypes(), newInputs, newOutputs, newMaps, iteratorTypes); @@ -394,10 +394,10 @@ AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1); SmallVector indexingMaps = { map, map.dropResult(insertSplitDimension)}; - SmallVector reductionIteratorTypes( - originalOutputType.getRank() + 1, getParallelIteratorTypeName()); + SmallVector reductionIteratorTypes( + originalOutputType.getRank() + 1, utils::IteratorType::parallel); reductionIteratorTypes[insertSplitDimension] = - getReductionIteratorTypeName(); + utils::IteratorType::reduction; // clang-format off auto reductionOp = b.create( diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -431,7 +431,7 @@ auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges( b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes); - SmallVector iteratorTypes; + SmallVector iteratorTypes; for (const auto &attr : enumerate(op.getIteratorTypesArray())) { if (loopIndexToRangeIndex.count(attr.index())) iteratorTypes.push_back(attr.value()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -88,10 +88,7 @@ /// Return the loop iterator type. SmallVector getLoopIteratorTypes(Operation *op) const { LinalgOpTy concreteOp = cast(op); - return llvm::to_vector(llvm::map_range( - concreteOp.getIteratorTypesArray(), [](StringRef iteratorType) { - return utils::symbolizeIteratorType(iteratorType).value(); - })); + return concreteOp.getIteratorTypesArray(); } /// Return the iteration domain range. @@ -339,8 +336,9 @@ // Step3. create a generic op where the reduction dimension is replaced by a // parallel dimension of the size of reduction. - SmallVector newIteratorTypes = linalgOp.getIteratorTypesArray(); - newIteratorTypes[reductionDims[0]] = getParallelIteratorTypeName(); + SmallVector newIteratorTypes = + linalgOp.getIteratorTypesArray(); + newIteratorTypes[reductionDims[0]] = utils::IteratorType::parallel; SmallVector newMaps = linalgOp.getIndexingMapsArray(); newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr, linalgOp.getContext()); @@ -366,14 +364,14 @@ int64_t intermRank = partialReduce[0].getType().cast().getRank(); AffineMap inputMap = b.getMultiDimIdentityMap(intermRank); - SmallVector reductionIteratorTypes; + SmallVector reductionIteratorTypes; SmallVector exprs; for (int64_t i : llvm::seq(0, intermRank)) { if (dimToMerge == i) { - reductionIteratorTypes.push_back(getReductionIteratorTypeName()); + reductionIteratorTypes.push_back(utils::IteratorType::reduction); } else { exprs.push_back(b.getAffineDimExpr(i)); - reductionIteratorTypes.push_back(getParallelIteratorTypeName()); + reductionIteratorTypes.push_back(utils::IteratorType::parallel); } } AffineMap outputMap = diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -297,8 +297,10 @@ return vectorizeCopy(rewriter, copyOp); } -static SmallVector getNParallelLoopsAttrs(unsigned nParallelLoops) { - return SmallVector(nParallelLoops, getParallelIteratorTypeName()); +static SmallVector +getNParallelLoopsAttrs(unsigned nParallelLoops) { + return SmallVector(nParallelLoops, + utils::IteratorType::parallel); } /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp (to diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1420,11 +1420,12 @@ /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}} /// ``` /// kw is unrolled, w is unrolled iff dilationW > 1. -struct Conv1DGenerator : public StructuredGenerator { +struct Conv1DGenerator + : public StructuredGenerator { Conv1DGenerator(OpBuilder &builder, LinalgOp linalgOp, int strideW, int dilationW) - : StructuredGenerator(builder, linalgOp), strideW(strideW), - dilationW(dilationW) { + : StructuredGenerator(builder, linalgOp), + strideW(strideW), dilationW(dilationW) { // Determine whether `linalgOp` can be generated with this generator if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1) return; diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -186,12 +186,12 @@ return hasOnlyScalarElementwiseOp(op->getRegion(0)); } -bool isParallelIterator(StringRef iteratorType) { - return iteratorType == getParallelIteratorTypeName(); +bool isParallelIterator(utils::IteratorType iteratorType) { + return iteratorType == utils::IteratorType::parallel; } -bool isReductionIterator(StringRef iteratorType) { - return iteratorType == getReductionIteratorTypeName(); +bool isReductionIterator(utils::IteratorType iteratorType) { + return iteratorType == utils::IteratorType::reduction; } /// Helper function that creates a memref::DimOp or tensor::DimOp depending on @@ -422,15 +422,13 @@ b.getContext())), AffineMap::getMultiDimIdentityMap(transposeVector.size(), b.getContext())}; - SmallVector iteratorTypes(transposeVector.size(), - getParallelIteratorTypeName()); + SmallVector iteratorTypes(transposeVector.size(), + utils::IteratorType::parallel); // Create a GenericOp to transpose `inputTensor` into `outputTensor`. - auto transposeOp = b.create( - loc, resultTensorType, inputTensor, outputTensor, - b.getAffineMapArrayAttr(indexingMaps), b.getStrArrayAttr(iteratorTypes), - /*doc=*/nullptr, - /*library_call=*/nullptr); + auto transposeOp = + b.create(loc, resultTensorType, inputTensor, outputTensor, + indexingMaps, iteratorTypes); Region &body = transposeOp.getRegion(); body.push_back(new Block()); body.front().addArguments({elementType, elementType}, {loc, loc}); @@ -452,8 +450,8 @@ AffineMap id = AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext()); - SmallVector iteratorTypes(memrefTypeTo.getRank(), - getParallelIteratorTypeName()); + SmallVector iteratorTypes(memrefTypeTo.getRank(), + utils::IteratorType::parallel); return b.create( loc, /*inputs=*/from, @@ -469,7 +467,7 @@ template <> void GenerateLoopNest::doit( OpBuilder &b, Location loc, ArrayRef loopRanges, LinalgOp linalgOp, - ArrayRef iteratorTypes, + ArrayRef iteratorTypes, function_ref bodyBuilderFn, @@ -513,7 +511,7 @@ template <> void GenerateLoopNest::doit( OpBuilder &b, Location loc, ArrayRef loopRanges, LinalgOp linalgOp, - ArrayRef iteratorTypes, + ArrayRef iteratorTypes, function_ref bodyBuilderFn, @@ -564,7 +562,7 @@ // exceeds 10. static void generateParallelLoopNest( OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs, - ValueRange steps, ArrayRef iteratorTypes, + ValueRange steps, ArrayRef iteratorTypes, ArrayRef procInfo, function_ref bodyBuilderFn, SmallVectorImpl &ivStorage) { @@ -679,7 +677,7 @@ template <> void GenerateLoopNest::doit( OpBuilder &b, Location loc, ArrayRef loopRanges, LinalgOp linalgOp, - ArrayRef iteratorTypes, + ArrayRef iteratorTypes, function_ref bodyBuilderFn, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -178,7 +178,8 @@ /// as we use adj matrix for the graph. /// The sorted result will put the first Reduction iterator to the /// latest possible index. -static bool topSortOptimal(unsigned n, ArrayRef iteratorTypes, +static bool topSortOptimal(unsigned n, + ArrayRef iteratorTypes, std::vector &topSort, std::vector &inDegree, std::vector> &adjM) { diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp --- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -15,9 +15,10 @@ using namespace mlir; using namespace mlir::tosa; -SmallVector +SmallVector mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) { - return SmallVector(nParallelLoops, getParallelIteratorTypeName()); + return SmallVector(nParallelLoops, + utils::IteratorType::parallel); } SmallVector diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1518,27 +1518,14 @@ } namespace { -struct IteratorType { - IteratorType(StringRef strRef) : strRef(strRef) {} - bool isOfType(Attribute attr) const { - auto sAttr = attr.dyn_cast(); - return sAttr && sAttr.getValue() == strRef; - } - StringRef strRef; -}; -struct Par : public IteratorType { - Par() : IteratorType(getParallelIteratorTypeName()) {} -}; -struct Red : public IteratorType { - Red() : IteratorType(getReductionIteratorTypeName()) {} -}; /// Generate a vector implementation for matmat, matvec and tmatvec. /// This unrolls outer-products along the reduction dimension. struct UnrolledOuterProductGenerator - : public StructuredGenerator { + : public StructuredGenerator { UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op) - : StructuredGenerator(builder, op), + : StructuredGenerator( + builder, op), kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()), res(op.getAcc()), lhsType(op.getLhsType()) {} @@ -2719,8 +2706,10 @@ } else { MemRefLayoutAttrInterface updatedLayout; if (auto strided = layout.dyn_cast()) { - auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop)); - updatedLayout = StridedLayoutAttr::get(strided.getContext(), strided.getOffset(), strides); + auto strides = + llvm::to_vector(strided.getStrides().drop_back(dimsToDrop)); + updatedLayout = StridedLayoutAttr::get(strided.getContext(), + strided.getOffset(), strides); } else { AffineMap map = srcType.getLayout().getAffineMap(); int numSymbols = map.getNumSymbols(); diff --git a/mlir/test/Dialect/Linalg/conv-interface-invalid.mlir b/mlir/test/Dialect/Linalg/conv-interface-invalid.mlir --- a/mlir/test/Dialect/Linalg/conv-interface-invalid.mlir +++ b/mlir/test/Dialect/Linalg/conv-interface-invalid.mlir @@ -17,7 +17,7 @@ // expected-error @+1 {{expected op with 2 inputs and 1 output}} %0 = test.linalg_conv_op { indexing_maps = [#map, #map], - iterator_types = ["parallel"]} + iterator_types = [#test.iterator_type]} ins(%arg0 : tensor) outs(%arg1 : tensor) { ^bb0(%arg2 : f32, %arg3 : f32): linalg.yield %arg3 : f32 @@ -34,7 +34,8 @@ indexing_maps = [affine_map<(d0, d1) -> (d0 * 2)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} + iterator_types = [#test.iterator_type, + #test.iterator_type]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): @@ -52,7 +53,8 @@ indexing_maps = [affine_map<(d0, d1) -> (d0 + d1, d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} + iterator_types = [#test.iterator_type, + #test.iterator_type]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): @@ -70,7 +72,8 @@ indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1 + d0)>, affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} + iterator_types = [#test.iterator_type, + #test.iterator_type]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): @@ -88,7 +91,8 @@ indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0 + d1)>], - iterator_types = ["parallel", "parallel"]} + iterator_types = [#test.iterator_type, + #test.iterator_type]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): @@ -108,7 +112,8 @@ indexing_maps = [affine_map<(d0, d1) -> (d0 + d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} + iterator_types = [#test.iterator_type, + #test.iterator_type]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): @@ -127,7 +132,9 @@ indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d1)>, affine_map<(d0, d1, d2) -> (d1)>, affine_map<(d0, d1, d2) -> (d0, d2)>], - iterator_types = ["parallel", "reduction", "parallel"]} + iterator_types = [#test.iterator_type, + #test.iterator_type, + #test.iterator_type]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): @@ -146,7 +153,9 @@ indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d1)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], - iterator_types = ["parallel", "reduction", "reduction"]} + iterator_types = [#test.iterator_type, + #test.iterator_type, + #test.iterator_type]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): @@ -165,7 +174,9 @@ indexing_maps = [affine_map<(d0, d1, d2) -> (d0 + d1, d2)>, affine_map<(d0, d1, d2) -> (d1)>, affine_map<(d0, d1, d2) -> (d0)>], - iterator_types = ["parallel", "reduction", "reduction"]} + iterator_types = [#test.iterator_type, + #test.iterator_type, + #test.iterator_type]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): @@ -184,7 +195,8 @@ indexing_maps = [affine_map<(d0, d1) -> (d0 + d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "parallel"]} + iterator_types = [#test.iterator_type, + #test.iterator_type]} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32): diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -96,7 +96,7 @@ // ----- func.func @generic_wrong_iterator(%arg0: memref<1xi32>) { - // expected-error @+1 {{op unexpected iterator_type (random)}} + // expected-error @+4 {{unexpected iterator_type (random)}} linalg.generic { indexing_maps = [ affine_map<(i) -> (i)> ], iterator_types = ["random"]} diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir --- a/mlir/test/Dialect/Linalg/transform-op-match.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir @@ -59,13 +59,19 @@ ^bb1(%arg1: !pdl.operation): %match_attr = transform.structured.match ops{["linalg.generic"]} - attributes{iterator_types = ["parallel", "parallel", "parallel"]} + attributes{iterator_types = [ + #linalg.iterator_type, + #linalg.iterator_type, + #linalg.iterator_type]} in %arg1 transform.test_print_remark_at_operand %match_attr, "matched complex attr" : !pdl.operation transform.test_consume_operand %match_attr %no_match = transform.structured.match - attributes{iterator_types = ["parallel", "parallel", "reduction"]} + attributes{iterator_types = [ + #linalg.iterator_type, + #linalg.iterator_type, + #linalg.iterator_type]} in %arg1 // expected-remark @below {{0}} transform.test_print_number_of_associated_payload_ir_ops %no_match diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -23,13 +23,16 @@ mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs -typedefs-dialect=test) add_public_tablegen_target(MLIRTestTypeDefIncGen) +set(LLVM_TARGET_DEFINITIONS TestEnumDefs.td) +mlir_tablegen(TestOpEnums.h.inc -gen-enum-decls) +mlir_tablegen(TestOpEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRTestEnumDefIncGen) + set(LLVM_TARGET_DEFINITIONS TestOps.td) mlir_tablegen(TestOps.h.inc -gen-op-decls) mlir_tablegen(TestOps.cpp.inc -gen-op-defs) mlir_tablegen(TestOpsDialect.h.inc -gen-dialect-decls -dialect=test) mlir_tablegen(TestOpsDialect.cpp.inc -gen-dialect-defs -dialect=test) -mlir_tablegen(TestOpEnums.h.inc -gen-enum-decls) -mlir_tablegen(TestOpEnums.cpp.inc -gen-enum-defs) mlir_tablegen(TestPatterns.inc -gen-rewriters) add_public_tablegen_target(MLIRTestOpsIncGen) @@ -46,6 +49,7 @@ DEPENDS MLIRTestAttrDefIncGen + MLIRTestEnumDefIncGen MLIRTestInterfaceIncGen MLIRTestTypeDefIncGen MLIRTestOpsIncGen diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -15,6 +15,8 @@ // To get the test dialect definition. include "TestDialect.td" +include "TestEnumDefs.td" +include "mlir/Dialect/Utils/StructuredOpsUtils.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/IR/EnumAttr.td" @@ -277,13 +279,6 @@ "array_of_ints", "int32_t">; // An array of enum attributes. -def TestSimpleEnum : I32EnumAttr<"SimpleEnum", "", [ - I32EnumAttrCase<"a", 0>, - I32EnumAttrCase<"b", 1> - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "::test"; -} def TestSimpleEnumAttr : EnumAttr { let assemblyFormat = "`` $value"; } @@ -297,4 +292,14 @@ let assemblyFormat = "`<` $a (`>`) : (`,` ` ` custom($b)^ `>`)?"; } +def Test_IteratorTypeEnum + : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def Test_IteratorTypeArrayAttr + : TypedArrayAttrBase; + + #endif // TEST_ATTRDEFS diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h --- a/mlir/test/lib/Dialect/Test/TestAttributes.h +++ b/mlir/test/lib/Dialect/Test/TestAttributes.h @@ -17,6 +17,7 @@ #include #include "TestTraits.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/test/lib/Dialect/Test/TestEnumDefs.td b/mlir/test/lib/Dialect/Test/TestEnumDefs.td new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestEnumDefs.td @@ -0,0 +1,97 @@ +//===-- TestEnumDefs.td - Test dialect enum definitions ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// TableGen enum definitions for Test dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_ENUMDEFS +#define TEST_ENUMDEFS + +include "mlir/IR/EnumAttr.td" + +def I32Case5: I32EnumAttrCase<"case5", 5>; +def I32Case10: I32EnumAttrCase<"case10", 10>; + +def SomeI32Enum: I32EnumAttr< + "SomeI32Enum", "", [I32Case5, I32Case10]>; + +def I64Case5: I64EnumAttrCase<"case5", 5>; +def I64Case10: I64EnumAttrCase<"case10", 10>; + +def SomeI64Enum: I64EnumAttr< + "SomeI64Enum", "", [I64Case5, I64Case10]>; + +//===----------------------------------------------------------------------===// +// Test Enum +//===----------------------------------------------------------------------===// + +// Define the C++ enum. +def TestEnum + : I32EnumAttr<"TestEnum", "a test enum", [ + I32EnumAttrCase<"First", 0, "first">, + I32EnumAttrCase<"Second", 1, "second">, + I32EnumAttrCase<"Third", 2, "third">, + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "test"; +} + +def TestSimpleEnum : I32EnumAttr<"SimpleEnum", "", [ + I32EnumAttrCase<"a", 0>, + I32EnumAttrCase<"b", 1> + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::test"; +} + +//===----------------------------------------------------------------------===// +// Test Bit Enum +//===----------------------------------------------------------------------===// + +// Define the C++ enum. +def TestBitEnum + : I32BitEnumAttr<"TestBitEnum", "a test bit enum", [ + I32BitEnumAttrCaseBit<"Read", 0, "read">, + I32BitEnumAttrCaseBit<"Write", 1, "write">, + I32BitEnumAttrCaseBit<"Execute", 2, "execute">, + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "test"; + let separator = ", "; +} + +// Define an enum with a different separator +def TestBitEnumVerticalBar + : I32BitEnumAttr<"TestBitEnumVerticalBar", "another test bit enum", [ + I32BitEnumAttrCaseBit<"User", 0, "user">, + I32BitEnumAttrCaseBit<"Group", 1, "group">, + I32BitEnumAttrCaseBit<"Other", 2, "other">, + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "test"; + let separator = " | "; +} + +//===----------------------------------------------------------------------===// +// Test Patterns (Multi-result Ops) + +def MultiResultOpKind1: I64EnumAttrCase<"kind1", 1>; +def MultiResultOpKind2: I64EnumAttrCase<"kind2", 2>; +def MultiResultOpKind3: I64EnumAttrCase<"kind3", 3>; +def MultiResultOpKind4: I64EnumAttrCase<"kind4", 4>; +def MultiResultOpKind5: I64EnumAttrCase<"kind5", 5>; +def MultiResultOpKind6: I64EnumAttrCase<"kind6", 6>; + +def MultiResultOpEnum: I64EnumAttr< + "MultiResultOpEnum", "Multi-result op kinds", [ + MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3, + MultiResultOpKind4, MultiResultOpKind5, MultiResultOpKind6 + ]>; + +#endif // TEST_ENUMDEFS diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -202,23 +202,11 @@ ); } -def I32Case5: I32EnumAttrCase<"case5", 5>; -def I32Case10: I32EnumAttrCase<"case10", 10>; - -def SomeI32Enum: I32EnumAttr< - "SomeI32Enum", "", [I32Case5, I32Case10]>; - def I32EnumAttrOp : TEST_Op<"i32_enum_attr"> { let arguments = (ins SomeI32Enum:$attr); let results = (outs I32:$val); } -def I64Case5: I64EnumAttrCase<"case5", 5>; -def I64Case10: I64EnumAttrCase<"case10", 10>; - -def SomeI64Enum: I64EnumAttr< - "SomeI64Enum", "", [I64Case5, I64Case10]>; - def I64EnumAttrOp : TEST_Op<"i64_enum_attr"> { let arguments = (ins SomeI64Enum:$attr); let results = (outs I32:$val); @@ -319,17 +307,6 @@ // Test Enum Attributes //===----------------------------------------------------------------------===// -// Define the C++ enum. -def TestEnum - : I32EnumAttr<"TestEnum", "a test enum", [ - I32EnumAttrCase<"First", 0, "first">, - I32EnumAttrCase<"Second", 1, "second">, - I32EnumAttrCase<"Third", 2, "third">, - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "test"; -} - // Define the enum attribute. def TestEnumAttr : EnumAttr; @@ -351,18 +328,6 @@ // Test Bit Enum Attributes //===----------------------------------------------------------------------===// -// Define the C++ enum. -def TestBitEnum - : I32BitEnumAttr<"TestBitEnum", "a test bit enum", [ - I32BitEnumAttrCaseBit<"Read", 0, "read">, - I32BitEnumAttrCaseBit<"Write", 1, "write">, - I32BitEnumAttrCaseBit<"Execute", 2, "execute">, - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "test"; - let separator = ", "; -} - // Define the enum attribute. def TestBitEnumAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; @@ -374,18 +339,6 @@ let assemblyFormat = "$value (`tag` $tag^)? attr-dict"; } -// Define an enum with a different separator -def TestBitEnumVerticalBar - : I32BitEnumAttr<"TestBitEnumVerticalBar", "another test bit enum", [ - I32BitEnumAttrCaseBit<"User", 0, "user">, - I32BitEnumAttrCaseBit<"Group", 1, "group">, - I32BitEnumAttrCaseBit<"Other", 2, "other">, - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "test"; - let separator = " | "; -} - def TestBitEnumVerticalBarAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; @@ -1392,22 +1345,6 @@ def : Pat<(I32EnumAttrOp I32Case5), (I32EnumAttrOp I32Case10)>; def : Pat<(I64EnumAttrOp I64Case5), (I64EnumAttrOp I64Case10)>; -//===----------------------------------------------------------------------===// -// Test Patterns (Multi-result Ops) - -def MultiResultOpKind1: I64EnumAttrCase<"kind1", 1>; -def MultiResultOpKind2: I64EnumAttrCase<"kind2", 2>; -def MultiResultOpKind3: I64EnumAttrCase<"kind3", 3>; -def MultiResultOpKind4: I64EnumAttrCase<"kind4", 4>; -def MultiResultOpKind5: I64EnumAttrCase<"kind5", 5>; -def MultiResultOpKind6: I64EnumAttrCase<"kind6", 6>; - -def MultiResultOpEnum: I64EnumAttr< - "MultiResultOpEnum", "Multi-result op kinds", [ - MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3, - MultiResultOpKind4, MultiResultOpKind5, MultiResultOpKind6 - ]>; - def ThreeResultOp : TEST_Op<"three_result"> { let arguments = (ins MultiResultOpEnum:$kind); let results = (outs I32:$result1, F32:$result2, F32:$result3); @@ -2824,8 +2761,10 @@ return ®ionBuilder; } - mlir::ArrayAttr getIteratorTypes() { - return getOperation()->getAttrOfType("iterator_types"); + llvm::SmallVector getIteratorTypesArray() { + auto attrs = getOperation()->getAttrOfType("iterator_types"); + auto range = attrs.getAsValueRange(); + return {range.begin(), range.end()}; } mlir::ArrayAttr getIndexingMaps() { @@ -2884,8 +2823,10 @@ return ®ionBuilder; } - mlir::ArrayAttr getIteratorTypes() { - return getOperation()->getAttrOfType("iterator_types"); + llvm::SmallVector getIteratorTypesArray() { + auto attrs = getOperation()->getAttrOfType("iterator_types"); + auto range = attrs.getAsValueRange(); + return {range.begin(), range.end()}; } mlir::ArrayAttr getIndexingMaps() { diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -553,7 +553,7 @@ let extraClassDeclaration = structuredOpsBaseDecls # [{{ // Auto-generated. - SmallVector getIteratorTypesArray(); + SmallVector getIteratorTypesArray(); ArrayAttr getIndexingMaps(); static void regionBuilder(ImplicitLocOpBuilder &b, Block &block, ArrayRef attrs); @@ -597,8 +597,8 @@ // {1}: Comma interleaved iterator type names. static const char structuredOpIteratorTypesFormat[] = R"FMT( -SmallVector {0}::getIteratorTypesArray() {{ - return SmallVector{{ {1} }; +SmallVector {0}::getIteratorTypesArray() {{ + return SmallVector{{ {1} }; } )FMT"; @@ -607,9 +607,9 @@ // {0}: Class name static const char rankPolyStructuredOpIteratorTypesFormat[] = R"FMT( -SmallVector {0}::getIteratorTypesArray() {{ +SmallVector {0}::getIteratorTypesArray() {{ int64_t rank = getRank(getDpsInitOperand(0)); - return SmallVector(rank, getParallelIteratorTypeName()); + return SmallVector(rank, utils::IteratorType::parallel); } )FMT"; @@ -812,10 +812,10 @@ [&](LinalgIteratorTypeDef it) { switch (it) { case LinalgIteratorTypeDef::parallel: - ss << "getParallelIteratorTypeName()"; + ss << "utils::IteratorType::parallel"; break; case LinalgIteratorTypeDef::reduction: - ss << "getReductionIteratorTypeName()"; + ss << "utils::IteratorType::reduction"; break; } }); diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -131,14 +131,6 @@ ], "lib/Dialect/Test/TestOpsDialect.cpp.inc", ), - ( - ["-gen-enum-decls"], - "lib/Dialect/Test/TestOpEnums.h.inc", - ), - ( - ["-gen-enum-defs"], - "lib/Dialect/Test/TestOpEnums.cpp.inc", - ), ( ["-gen-rewriters"], "lib/Dialect/Test/TestPatterns.inc", @@ -211,6 +203,27 @@ ], ) +gentbl_cc_library( + name = "TestEnumDefsIncGen", + strip_include_prefix = "lib/Dialect/Test", + tbl_outs = [ + ( + ["-gen-enum-decls"], + "lib/Dialect/Test/TestOpEnums.h.inc", + ), + ( + ["-gen-enum-defs"], + "lib/Dialect/Test/TestOpEnums.cpp.inc", + ), + ], + tblgen = "//mlir:mlir-tblgen", + td_file = "lib/Dialect/Test/TestEnumDefs.td", + test = True, + deps = [ + ":TestOpTdFiles", + ], +) + gentbl_cc_library( name = "TestTypeDefsIncGen", strip_include_prefix = "lib/Dialect/Test", @@ -318,6 +331,7 @@ ], deps = [ ":TestAttrDefsIncGen", + ":TestEnumDefsIncGen", ":TestInterfacesIncGen", ":TestOpsIncGen", ":TestTypeDefsIncGen", @@ -330,6 +344,7 @@ "//mlir:DerivedAttributeOpInterface", "//mlir:DestinationStyleOpInterface", "//mlir:Dialect", + "//mlir:DialectUtils", "//mlir:FuncDialect", "//mlir:FuncTransforms", "//mlir:IR",