diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -727,6 +727,10 @@ // TODO: Remove once prefixing is flipped. ArrayAttr getIteratorTypes() { return iterator_types(); } + SmallVector<StringRef> getIteratorTypeNames() { + return llvm::to_vector(getIteratorTypes().getAsValueRange<StringAttr>()); + } + //========================================================================// // Forwarding functions to access interface methods from the // DestinationStyleOpInterface. diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -45,6 +45,12 @@ /// `[0, permutation.size())`. bool isPermutation(ArrayRef<int64_t> permutation); +/// Check if `attr` has "parallel" iterator type semantics. +bool isParallelIterator(Attribute attr); + +/// Check if `attr` has "reduction" iterator type semantics. +bool isReductionIterator(Attribute attr); + /// Helper function that creates a memref::DimOp or tensor::DimOp depending on /// the type of `source`. Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim); diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -78,24 +78,12 @@ /// Use to encode that a particular iterator type has parallel semantics. 
constexpr StringRef getParallelIteratorTypeName() { return "parallel"; } -inline bool isParallelIterator(Attribute attr) { - auto strAttr = attr.dyn_cast_or_null<StringAttr>(); - return strAttr && strAttr.getValue() == getParallelIteratorTypeName(); -} /// Use to encode that a particular iterator type has reduction semantics. constexpr StringRef getReductionIteratorTypeName() { return "reduction"; } -inline bool isReductionIterator(Attribute attr) { - auto strAttr = attr.dyn_cast_or_null<StringAttr>(); - return strAttr && strAttr.getValue() == getReductionIteratorTypeName(); -} /// Use to encode that a particular iterator type has window semantics. constexpr StringRef getWindowIteratorTypeName() { return "window"; } -inline bool isWindowIterator(Attribute attr) { - auto strAttr = attr.dyn_cast_or_null<StringAttr>(); - return strAttr && strAttr.getValue() == getWindowIteratorTypeName(); -} /// Use to encode that a particular iterator type has window semantics. inline ArrayRef<StringRef> getAllIteratorTypeNames() { @@ -122,19 +110,6 @@ return res; } -/// Typed representation for loop type strings. -enum class IteratorType { Parallel, Reduction }; - -inline StringRef toString(IteratorType t) { - switch (t) { - case IteratorType::Parallel: - return getParallelIteratorTypeName(); - case IteratorType::Reduction: - return getReductionIteratorTypeName(); - } - llvm_unreachable("Unsupported IteratorType"); -} - /// Helper StructuredGenerator class to manipulate and rewrite ops with /// `StructuredOpInterface`. This is templated for now because VectorOps do not /// yet implement the StructuredOpInterface itself. 
@@ -145,10 +120,7 @@ struct IteratorType { IteratorType(StringRef strRef) : strRef(strRef) {} - bool isOfType(Attribute attr) const { - auto sAttr = attr.dyn_cast<StringAttr>(); - return sAttr && sAttr.getValue() == strRef; - } + bool isOfType(StringRef typeName) const { return typeName == strRef; } StringRef strRef; }; struct Par : public IteratorType { @@ -163,7 +135,7 @@ StructuredGenerator(OpBuilder &builder, StructuredOpInterface op) : builder(builder), ctx(op.getContext()), loc(op.getLoc()), - iterators(op.getIteratorTypes()), maps(op.getIndexingMapsArray()), + iterators(op.getIteratorTypeNames()), maps(op.getIndexingMapsArray()), op(op) {} bool iters(ArrayRef<IteratorType> its) { @@ -185,7 +157,7 @@ OpBuilder &builder; MLIRContext *ctx; Location loc; - ArrayAttr iterators; + SmallVector<StringRef> iterators; SmallVector<AffineMap, 4> maps; Operation *op; }; diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -185,6 +185,17 @@ /// corresponding arith operation. Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value v2); + +/// Returns true if `attr` has "parallel" iterator type semantics. +inline bool isParallelIterator(Attribute attr) { + return attr.cast<IteratorTypeAttr>().getValue() == IteratorType::parallel; +} + +/// Returns true if `attr` has "reduction" iterator type semantics. 
+inline bool isReductionIterator(Attribute attr) { + return attr.cast<IteratorTypeAttr>().getValue() == IteratorType::reduction; +} + } // namespace vector } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -63,6 +63,21 @@ let assemblyFormat = "`<` $value `>`"; } +def IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [ + I32EnumAttrCase<"parallel", 0>, + I32EnumAttrCase<"reduction", 1> + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::vector"; +} + +def IteratorTypeEnum : EnumAttr<Vector_Dialect, IteratorType, "iterator_type"> { + let assemblyFormat = "`<` $value `>`"; +} + +def IteratorTypeArrayAttr : TypedArrayAttrBase<IteratorTypeEnum, "Iterator type should be an enum.">; + // TODO: Add an attribute to specify a different algebra with operators other // than the current set: {*, +}. def Vector_ContractionOp : @@ -76,7 +91,7 @@ Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc, Variadic<VectorOf<[I1]>>:$masks, ArrayAttr:$indexing_maps, - ArrayAttr:$iterator_types, + IteratorTypeArrayAttr:$iterator_types, DefaultValuedAttr<Vector_CombiningKindAttr, "CombiningKind::ADD">:$kind)>, Results<(outs AnyType)> { @@ -201,7 +216,7 @@ "ArrayAttr":$indexingMaps, "ArrayAttr":$iteratorTypes)>, OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc, "ArrayRef<ArrayRef<AffineExpr>>":$indexingExprs, - "ArrayRef<StringRef>":$iteratorTypes)>, + "ArrayRef<IteratorType>":$iteratorTypes)>, OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc, "ArrayAttr":$indexingMaps, "ArrayAttr":$iteratorTypes, "CombiningKind":$kind)> @@ -249,6 +264,14 @@ static CombiningKind getDefaultKind() { return CombiningKind::ADD; } + + // Returns iterator types in string format. 
+ SmallVector<StringRef> getIteratorTypeNames() { + return llvm::to_vector( + llvm::map_range(getIteratorTypes(), [](Attribute a) { + return stringifyIteratorType(a.cast<IteratorTypeAttr>().getValue()); + })); + } }]; let hasCanonicalizer = 1; diff --git a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h --- a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h +++ b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h @@ -71,7 +71,7 @@ VectorType fragmentType; bool isAccum; int64_t numTiles; - IteratorType contiguousDimType; + vector::IteratorType contiguousDimType; NVVM::MMALayout targetLayout; }; diff --git a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp --- a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp +++ b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp @@ -217,10 +217,10 @@ params.targetLayout = NVVM::MMALayout::col; } ArrayRef<int64_t> shape = type.vectorType.getShape(); - params.contiguousDimType = - transpose ? IteratorType::Parallel : IteratorType::Reduction; + params.contiguousDimType = transpose ? vector::IteratorType::parallel + : vector::IteratorType::reduction; - if (params.contiguousDimType == IteratorType::Reduction) { + if (params.contiguousDimType == vector::IteratorType::reduction) { params.numTiles = (shape[0] / kNumRowsPerTile) * ((shape[1] * elType.getIntOrFloatBitWidth()) / 128); } else { @@ -250,7 +250,7 @@ }; // This case corresponds to row-major A|C or col-major B operands. - if (params.contiguousDimType == IteratorType::Reduction) { + if (params.contiguousDimType == vector::IteratorType::reduction) { AffineExpr row = d0 % (operandShape[0]); AffineExpr col = d0.floorDiv(operandShape[0]) * (kElementsPer128b); return makeMap({row, col}); } // This case Corresponds to col-major A|C or row-major B operands. The // operandShape given is already pre-transposed (e.g. 8x16 = KxN). 
- if (params.contiguousDimType == IteratorType::Parallel) { + if (params.contiguousDimType == vector::IteratorType::parallel) { const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128; // Threads are assigned in groups of 8 first across columns, then to // rows. This is transpose of what `ldmatrix` expects, but when @@ -293,9 +293,9 @@ SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray(); if (iteratorTypes.size() != 3) return failure(); - if (!(isParallelIterator(iteratorTypes[0]) && - isParallelIterator(iteratorTypes[1]) && - isReductionIterator(iteratorTypes[2]))) + if (!(vector::isParallelIterator(iteratorTypes[0]) && + vector::isParallelIterator(iteratorTypes[1]) && + vector::isReductionIterator(iteratorTypes[2]))) return failure(); // The canonical form is "TNT" = A row-major, B col-major, C row-major. diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -74,9 +74,9 @@ AffineExpr m, n, k; bindDims(contract.getContext(), m, n, k); auto iteratorTypes = contract.getIteratorTypes().getValue(); - if (!(isParallelIterator(iteratorTypes[0]) && - isParallelIterator(iteratorTypes[1]) && - isReductionIterator(iteratorTypes[2]))) + if (!(vector::isParallelIterator(iteratorTypes[0]) && + vector::isParallelIterator(iteratorTypes[1]) && + vector::isReductionIterator(iteratorTypes[2]))) return false; // The contract needs to represent a matmul to be able to convert to @@ -296,9 +296,9 @@ static constexpr std::array<int64_t, 2> perm = {1, 0}; auto iteratorTypes = op.getIteratorTypes().getValue(); SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray(); - if (!(isParallelIterator(iteratorTypes[0]) && - isParallelIterator(iteratorTypes[1]) && - isReductionIterator(iteratorTypes[2]))) + if (!(vector::isParallelIterator(iteratorTypes[0]) && + vector::isParallelIterator(iteratorTypes[1]) && + 
vector::isReductionIterator(iteratorTypes[2]))) return failure(); // // Two outer parallel, one inner reduction (matmat flavor). diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1488,13 +1488,14 @@ // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f} Value conv1dSliceAsContraction(OpBuilder &b, Location loc, Value lhs, Value rhs, Value res) { - StringRef par = Par().strRef, red = Red().strRef; + vector::IteratorType par = vector::IteratorType::parallel; + vector::IteratorType red = vector::IteratorType::reduction; AffineExpr n, w, f, c; bindDims(ctx, n, w, f, c); return builder.create<vector::ContractionOp>( loc, lhs, rhs, res, /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}}, - /*iteratorTypes=*/ArrayRef<StringRef>{par, par, par, red}); + /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red}); } /// Generate a vector implementation for: diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -199,6 +199,16 @@ return count(indexCounts, 1) == static_cast<int64_t>(permutation.size()); } +bool isParallelIterator(Attribute attr) { + auto strAttr = attr.dyn_cast_or_null<StringAttr>(); + return strAttr && strAttr.getValue() == getParallelIteratorTypeName(); +} + +bool isReductionIterator(Attribute attr) { + auto strAttr = attr.dyn_cast_or_null<StringAttr>(); + return strAttr && strAttr.getValue() == getReductionIteratorTypeName(); +} + /// Helper function that creates a memref::DimOp or tensor::DimOp depending on /// the type of `source`. 
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -350,7 +350,7 @@ if (isMaterializing(lhs->get())) { unsigned nest = 0; for (unsigned i = 0; i < numLoops; i++) { - if (isReductionIterator(iteratorTypes[topSort[i]])) + if (linalg::isReductionIterator(iteratorTypes[topSort[i]])) break; // terminate at first reduction nest++; } @@ -1234,7 +1234,7 @@ unsigned tensor = merger.tensor(fb); assert(idx == merger.index(fb)); auto iteratorTypes = op.iterator_types().getValue(); - bool isReduction = isReductionIterator(iteratorTypes[idx]); + bool isReduction = linalg::isReductionIterator(iteratorTypes[idx]); bool isSparse = merger.isDim(fb, Dim::kSparse); bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) && denseUnitStrides(merger, op, idx); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -455,14 +455,18 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result, Value lhs, Value rhs, Value acc, ArrayRef<ArrayRef<AffineExpr>> indexingExprs, - ArrayRef<StringRef> iteratorTypes) { + ArrayRef<IteratorType> iteratorTypes) { result.addOperands({lhs, rhs, acc}); result.addTypes(acc.getType()); result.addAttribute(::mlir::getIndexingMapsAttrName(), builder.getAffineMapArrayAttr( AffineMap::inferFromExprList(indexingExprs))); - result.addAttribute(::mlir::getIteratorTypesAttrName(), - builder.getStrArrayAttr(iteratorTypes)); + result.addAttribute( + ::mlir::getIteratorTypesAttrName(), + builder.getArrayAttr(llvm::to_vector(llvm::map_range( + iteratorTypes, [&](IteratorType t) -> mlir::Attribute { + return IteratorTypeAttr::get(builder.getContext(), 
t); + })))); } void vector::ContractionOp::build(OpBuilder &builder, OperationState &result, @@ -510,6 +514,27 @@ return failure(); result.attributes.assign(dictAttr.getValue().begin(), dictAttr.getValue().end()); + + // Convert array of strings into an array of IteratorType enums. This is needed, + // because tests still use the old format when 'iterator_types' attribute is + // represented as an array of strings. + // TODO: Remove this conversion once tests are fixed. + ArrayAttr iteratorTypes = + result.attributes.get("iterator_types").cast<ArrayAttr>(); + + SmallVector<Attribute> iteratorTypeAttrs; + + for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) { + auto maybeIteratorType = symbolizeIteratorType(s); + if (!maybeIteratorType.hasValue()) + return parser.emitError(loc) << "unexpected iterator_type (" << s << ")"; + + iteratorTypeAttrs.push_back(IteratorTypeAttr::get( + parser.getContext(), maybeIteratorType.getValue())); + } + result.attributes.set("iterator_types", + parser.getBuilder().getArrayAttr(iteratorTypeAttrs)); + if (!result.attributes.get(ContractionOp::getKindAttrStrName())) { result.addAttribute( ContractionOp::getKindAttrStrName(), @@ -538,9 +563,26 @@ llvm::StringSet<> traitAttrsSet; traitAttrsSet.insert(attrNames.begin(), attrNames.end()); SmallVector<NamedAttribute> attrs; - for (auto attr : (*this)->getAttrs()) - if (traitAttrsSet.count(attr.getName().strref()) > 0) + for (auto attr : (*this)->getAttrs()) { + if (attr.getName() == getIteratorTypesAttrName()) { + auto iteratorTypes = + attr.getValue() + .cast<ArrayAttr>() + .getAsValueRange<IteratorTypeAttr, IteratorType>(); + // Convert IteratorType enums into the string representation. This is + // needed, because tests still use the old format when 'iterator_types' + // attribute is represented as an array of strings. + // TODO: Remove this conversion once tests are fixed. 
+ SmallVector<Attribute> iteratorTypeNames = llvm::to_vector( + llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute { + return StringAttr::get(getContext(), stringifyIteratorType(t)); + })); + + attrs.emplace_back(getIteratorTypesAttrName(), + ArrayAttr::get(getContext(), iteratorTypeNames)); + } else if (traitAttrsSet.count(attr.getName().strref()) > 0) attrs.push_back(attr); + } auto dictAttr = DictionaryAttr::get(getContext(), attrs); p << " " << dictAttr << " " << getLhs() << ", "; @@ -746,11 +788,11 @@ static std::vector<std::pair<int64_t, int64_t>> getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes, - StringRef targetIteratorTypeName, MLIRContext *context) { + IteratorType targetIteratorType, MLIRContext *context) { std::vector<std::pair<int64_t, int64_t>> dimMap; for (const auto &it : llvm::enumerate(iteratorTypes)) { - auto iteratorTypeName = it.value().cast<StringAttr>().getValue(); - if (iteratorTypeName != targetIteratorTypeName) + auto iteratorType = it.value().cast<IteratorTypeAttr>().getValue(); + if (iteratorType != targetIteratorType) continue; // Search lhs/rhs map results for 'targetExpr'. auto targetExpr = getAffineDimExpr(it.index(), context); @@ -771,8 +813,8 @@ for (const auto &it : llvm::enumerate(getIteratorTypes())) { // Search lhs/rhs map results for 'targetExpr'. auto targetExpr = getAffineDimExpr(it.index(), getContext()); - auto iteratorTypeName = it.value().cast<StringAttr>().getValue(); - if (iteratorTypeName == getReductionIteratorTypeName()) { + auto iteratorType = it.value().cast<IteratorTypeAttr>().getValue(); + if (iteratorType == IteratorType::reduction) { // Get reduction dim size from lhs shape (same size in rhsShape). 
int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr); assert(lhsDimIndex >= 0); @@ -803,14 +845,14 @@ std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() { SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray()); - return getDimMap(indexingMaps, getIteratorTypes(), - getReductionIteratorTypeName(), getContext()); + return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction, + getContext()); } std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() { SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray()); - return getDimMap(indexingMaps, getIteratorTypes(), - getParallelIteratorTypeName(), getContext()); + return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel, + getContext()); } Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -22,6 +22,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" @@ -986,13 +987,13 @@ SmallVector<bool> reductionMask = reduceOp.getReductionMask(); auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size()); SmallVector<AffineExpr> exprs; - SmallVector<StringRef> iteratorTypes; + SmallVector<vector::IteratorType> iteratorTypes; for (const auto &isReduceDim : llvm::enumerate(reductionMask)) { if (!isReduceDim.value()) { - iteratorTypes.push_back(getParallelIteratorTypeName()); + iteratorTypes.push_back(vector::IteratorType::parallel); exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index())); } else { - iteratorTypes.push_back(getReductionIteratorTypeName()); + iteratorTypes.push_back(vector::IteratorType::reduction); } } auto dstMap = 
AffineMap::get(/*dimCount=*/reductionMask.size(), @@ -1000,7 +1001,10 @@ rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>( reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(), rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}), - rewriter.getStrArrayAttr(iteratorTypes)); + rewriter.getArrayAttr(llvm::to_vector(llvm::map_range( + iteratorTypes, [&](IteratorType t) -> mlir::Attribute { + return IteratorTypeAttr::get(rewriter.getContext(), t); + })))); return success(); } };