diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -22,6 +22,7 @@ let cppNamespace = "::mlir::vector"; let hasConstantMaterializer = 1; let dependentDialects = ["arith::ArithmeticDialect"]; + let emitAccessorPrefix = kEmitAccessorPrefix_Both; } // Base class for Vector dialect ops. @@ -63,6 +64,18 @@ "::mlir::vector::CombiningKindAttr::get($0, $_builder.getContext())"; } +def Vector_AffineMapArrayAttr : TypedArrayAttrBase { + let returnType = [{ ::llvm::SmallVector<::mlir::AffineMap, 4> }]; + let convertFromStorage = [{ + llvm::to_vector<4>( + llvm::map_range($_self, [](::Attribute mapAttr) { + return mapAttr.cast<::mlir::AffineMapAttr>().getValue(); + })); + }]; + let constBuilderCall = "$_builder.getAffineMapArrayAttr($0)"; +} + // TODO: Add an attribute to specify a different algebra with operators other // than the current set: {*, +}. def Vector_ContractionOp : @@ -75,7 +88,8 @@ ]>, Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc, Variadic>:$masks, - AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types, + Vector_AffineMapArrayAttr:$indexing_maps, + ArrayAttr:$iterator_types, DefaultValuedAttr:$kind)>, Results<(outs AnyType)> { @@ -223,7 +237,6 @@ } Type getResultType() { return getResult().getType(); } ArrayRef getTraitAttrNames(); - SmallVector getIndexingMaps(); static unsigned getAccOperandIndex() { return 2; } // Returns the bounds of each dimension in the iteration space spanned @@ -240,7 +253,7 @@ std::vector> getContractingDimMap(); std::vector> getBatchDimMap(); - static constexpr StringRef getKindAttrName() { return "kind"; } + static constexpr StringRef getKindAttrStrName() { return "kind"; } static CombiningKind getDefaultKind() { return CombiningKind::ADD; @@ -327,8 +340,8 @@ "CombiningKind":$kind)> ]; let extraClassDeclaration = [{ - static StringRef getKindAttrName() { return "kind"; } - static StringRef getReductionDimsAttrName() { return "reduction_dims"; } + static StringRef getKindAttrStrName() { return "kind"; } + static StringRef getReductionDimsAttrStrName() { return "reduction_dims"; } VectorType getSourceVectorType() { return source().getType().cast(); @@ -474,7 +487,7 @@ ]; let hasFolder = 1; let extraClassDeclaration = [{ - static StringRef getMaskAttrName() { return "mask"; } + static StringRef getMaskAttrStrName() { return "mask"; } VectorType getV1VectorType() { return v1().getType().cast(); } @@ -561,7 +574,7 @@ OpBuilder<(ins "Value":$source, "ValueRange":$position)> ]; let extraClassDeclaration = [{ - static StringRef getPositionAttrName() { return "position"; } + static StringRef getPositionAttrStrName() { return "position"; } VectorType getVectorType() { return vector().getType().cast(); } @@ -754,7 +767,7 @@ OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$position)> ]; let extraClassDeclaration = [{ - static StringRef getPositionAttrName() { return "position"; } + static StringRef getPositionAttrStrName() { return "position"; } Type getSourceType() { return source().getType(); } VectorType getDestVectorType() { return dest().getType().cast(); @@ -873,15 +886,15 @@ "ArrayRef":$offsets, "ArrayRef":$strides)> ]; let extraClassDeclaration = [{ - static StringRef getOffsetsAttrName() { return "offsets"; } - static StringRef getStridesAttrName() { return "strides"; } + static StringRef getOffsetsAttrStrName() { return "offsets"; } + static StringRef getStridesAttrStrName() { return "strides"; } VectorType getSourceVectorType() { return source().getType().cast(); } VectorType getDestVectorType() { return dest().getType().cast(); } - bool hasNonUnitStrides() { + bool hasNonUnitStrides() { return llvm::any_of(strides(), [](Attribute attr) { return attr.cast().getInt() != 1; }); @@ -970,7 +983,7 @@ VectorType getVectorType() { return getResult().getType().cast(); } - static constexpr StringRef getKindAttrName() { + static constexpr StringRef getKindAttrStrName() { return "kind"; } static CombiningKind getDefaultKind() { @@ -1089,11 +1102,11 @@ void getFixedVectorSizes(SmallVectorImpl &results); - static StringRef getFixedVectorSizesAttrName() { + static StringRef getFixedVectorSizesAttrStrName() { return "fixed_vector_sizes"; } - static StringRef getInputShapeAttrName() { return "input_shape"; } - static StringRef getOutputShapeAttrName() { return "output_shape"; } + static StringRef getInputShapeAttrStrName() { return "input_shape"; } + static StringRef getOutputShapeAttrStrName() { return "output_shape"; } }]; let assemblyFormat = [{ @@ -1140,12 +1153,12 @@ "ArrayRef":$sizes, "ArrayRef":$strides)> ]; let extraClassDeclaration = [{ - static StringRef getOffsetsAttrName() { return "offsets"; } - static StringRef getSizesAttrName() { return "sizes"; } - static StringRef getStridesAttrName() { return "strides"; } + static StringRef getOffsetsAttrStrName() { return "offsets"; } + static StringRef getSizesAttrStrName() { return "sizes"; } + static StringRef getStridesAttrStrName() { return "strides"; } VectorType getVectorType(){ return vector().getType().cast(); } void getOffsets(SmallVectorImpl &results); - bool hasNonUnitStrides() { + bool hasNonUnitStrides() { return llvm::any_of(strides(), [](Attribute attr) { return attr.cast().getInt() != 1; }); @@ -2190,7 +2203,7 @@ }]; let extraClassDeclaration = [{ - static StringRef getMaskDimSizesAttrName() { return "mask_dim_sizes"; } + static StringRef getMaskDimSizesAttrStrName() { return "mask_dim_sizes"; } }]; let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)"; let hasVerifier = 1; @@ -2276,7 +2289,7 @@ return result().getType().cast(); } void getTransp(SmallVectorImpl &results); - static StringRef getTranspAttrName() { return "transp"; } + static StringRef getTranspAttrStrName() { return "transp"; } }]; let assemblyFormat = [{ $vector `,` $transp attr-dict `:` type($vector) `to` type($result) @@ -2537,8 +2550,8 @@ CArg<"bool", "true">:$inclusive)> ]; let extraClassDeclaration = [{ - static StringRef getKindAttrName() { return "kind"; } - static StringRef getReductionDimAttrName() { return "reduction_dim"; } + static StringRef getKindAttrStrName() { return "kind"; } + static StringRef getReductionDimAttrStrName() { return "reduction_dim"; } VectorType getSourceType() { return source().getType().cast(); } diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td --- a/mlir/include/mlir/Interfaces/VectorInterfaces.td +++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td @@ -55,7 +55,7 @@ StaticInterfaceMethod< /*desc=*/"Return the `in_bounds` attribute name.", /*retTy=*/"::mlir::StringRef", - /*methodName=*/"getInBoundsAttrName", + /*methodName=*/"getInBoundsAttrStrName", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/ [{ return "in_bounds"; }] @@ -63,7 +63,7 @@ StaticInterfaceMethod< /*desc=*/"Return the `permutation_map` attribute name.", /*retTy=*/"::mlir::StringRef", - /*methodName=*/"getPermutationMapAttrName", + /*methodName=*/"getPermutationMapAttrStrName", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/ [{ return "permutation_map"; }] diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -318,7 +318,7 @@ write.insertSliceOp.destMutable().assign(read.extractSliceOp.source()); } else { newForOp.getResult(initArgNumber) - .replaceAllUsesWith(write.transferWriteOp.getResult(0)); + .replaceAllUsesWith(write.transferWriteOp.getResult()); write.transferWriteOp.sourceMutable().assign( newForOp.getResult(initArgNumber)); } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -347,9 +347,9 @@ for (const auto &en : llvm::enumerate(reductionMask)) if (en.value()) reductionDims.push_back(en.index()); - result.addAttribute(getReductionDimsAttrName(), + result.addAttribute(getReductionDimsAttrStrName(), builder.getI64ArrayAttr(reductionDims)); - result.addAttribute(getKindAttrName(), + result.addAttribute(getKindAttrStrName(), CombiningKindAttr::get(kind, builder.getContext())); } @@ -491,10 +491,10 @@ ArrayRef iteratorTypes) { result.addOperands({lhs, rhs, acc}); result.addTypes(acc.getType()); - result.addAttribute(getIndexingMapsAttrName(), + result.addAttribute(::mlir::getIndexingMapsAttrName(), builder.getAffineMapArrayAttr( AffineMap::inferFromExprList(indexingExprs))); - result.addAttribute(getIteratorTypesAttrName(), + result.addAttribute(::mlir::getIteratorTypesAttrName(), builder.getStrArrayAttr(iteratorTypes)); } @@ -512,9 +512,9 @@ ArrayAttr iteratorTypes, CombiningKind kind) { result.addOperands({lhs, rhs, acc}); result.addTypes(acc.getType()); - result.addAttribute(getIndexingMapsAttrName(), indexingMaps); - result.addAttribute(getIteratorTypesAttrName(), iteratorTypes); - result.addAttribute(ContractionOp::getKindAttrName(), + result.addAttribute(::mlir::getIndexingMapsAttrName(), indexingMaps); + result.addAttribute(::mlir::getIteratorTypesAttrName(), iteratorTypes); + result.addAttribute(ContractionOp::getKindAttrStrName(), CombiningKindAttr::get(kind, builder.getContext())); } @@ -543,8 +543,8 @@ return failure(); result.attributes.assign(dictAttr.getValue().begin(), dictAttr.getValue().end()); - if (!result.attributes.get(ContractionOp::getKindAttrName())) { - result.addAttribute(ContractionOp::getKindAttrName(), + if (!result.attributes.get(ContractionOp::getKindAttrStrName())) { + result.addAttribute(ContractionOp::getKindAttrStrName(), CombiningKindAttr::get(ContractionOp::getDefaultKind(), result.getContext())); } @@ -698,7 +698,7 @@ unsigned numIterators = iterator_types().getValue().size(); for (const auto &it : llvm::enumerate(indexing_maps())) { auto index = it.index(); - auto map = it.value().cast().getValue(); + auto map = it.value(); if (map.getNumSymbols() != 0) return emitOpError("expected indexing map ") << index << " to have no symbols"; @@ -759,9 +759,9 @@ } ArrayRef ContractionOp::getTraitAttrNames() { - static constexpr StringRef names[3] = {getIndexingMapsAttrName(), - getIteratorTypesAttrName(), - ContractionOp::getKindAttrName()}; + static constexpr StringRef names[3] = {::mlir::getIndexingMapsAttrName(), + ::mlir::getIteratorTypesAttrName(), + ContractionOp::getKindAttrStrName()}; return llvm::makeArrayRef(names); } @@ -817,11 +817,11 @@ void ContractionOp::getIterationIndexMap( std::vector> &iterationIndexMap) { - unsigned numMaps = indexing_maps().getValue().size(); + unsigned numMaps = indexing_maps().size(); iterationIndexMap.resize(numMaps); for (const auto &it : llvm::enumerate(indexing_maps())) { auto index = it.index(); - auto map = it.value().cast().getValue(); + auto map = it.value(); for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) { auto dim = map.getResult(i).cast(); iterationIndexMap[index][dim.getPosition()] = i; @@ -841,13 +841,6 @@ getParallelIteratorTypeName(), getContext()); } -SmallVector ContractionOp::getIndexingMaps() { - return llvm::to_vector<4>( - llvm::map_range(indexing_maps().getValue(), [](Attribute mapAttr) { - return mapAttr.cast().getValue(); - })); -} - Optional> ContractionOp::getShapeForUnroll() { SmallVector shape; getIterationBounds(shape); @@ -961,7 +954,7 @@ auto positionAttr = getVectorSubscriptAttr(builder, position); result.addTypes(inferExtractOpResultType(source.getType().cast(), positionAttr)); - result.addAttribute(getPositionAttrName(), positionAttr); + result.addAttribute(getPositionAttrStrName(), positionAttr); } // Convenience builder which assumes the values are constant indices. @@ -1053,7 +1046,7 @@ // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); std::reverse(globalPosition.begin(), globalPosition.end()); - extractOp->setAttr(ExtractOp::getPositionAttrName(), + extractOp->setAttr(ExtractOp::getPositionAttrStrName(), b.getI64ArrayAttr(globalPosition)); return success(); } @@ -1295,7 +1288,7 @@ extractOp.setOperand(source); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); - extractOp->setAttr(ExtractOp::getPositionAttrName(), + extractOp->setAttr(ExtractOp::getPositionAttrStrName(), b.getI64ArrayAttr(extractPos)); return extractOp.getResult(); } @@ -1355,7 +1348,7 @@ SmallVector newPosition = delinearize(newStrides, position); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); - extractOp->setAttr(ExtractOp::getPositionAttrName(), + extractOp->setAttr(ExtractOp::getPositionAttrStrName(), b.getI64ArrayAttr(newPosition)); extractOp.setOperand(shapeCastOp.source()); return extractOp.getResult(); @@ -1396,7 +1389,7 @@ extractOp.vectorMutable().assign(extractStridedSliceOp.vector()); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); - extractOp->setAttr(ExtractOp::getPositionAttrName(), + extractOp->setAttr(ExtractOp::getPositionAttrStrName(), b.getI64ArrayAttr(extractedPos)); return extractOp.getResult(); } @@ -1453,7 +1446,7 @@ op.vectorMutable().assign(insertOp.source()); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(op.getContext()); - op->setAttr(ExtractOp::getPositionAttrName(), + op->setAttr(ExtractOp::getPositionAttrStrName(), b.getI64ArrayAttr(offsetDiffs)); return op.getResult(); } @@ -1736,7 +1729,7 @@ auto shape = llvm::to_vector<4>(v1Type.getShape()); shape[0] = mask.size(); result.addTypes(VectorType::get(shape, v1Type.getElementType())); - result.addAttribute(getMaskAttrName(), maskAttr); + result.addAttribute(getMaskAttrStrName(), maskAttr); } void ShuffleOp::print(OpAsmPrinter &p) { @@ -1784,7 +1777,7 @@ VectorType v1Type, v2Type; if (parser.parseOperand(v1) || parser.parseComma() || parser.parseOperand(v2) || - parser.parseAttribute(attr, ShuffleOp::getMaskAttrName(), + parser.parseAttribute(attr, ShuffleOp::getMaskAttrStrName(), result.attributes) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(v1Type) || parser.parseComma() || @@ -1877,7 +1870,7 @@ result.addOperands({source, dest}); auto positionAttr = getVectorSubscriptAttr(builder, position); result.addTypes(dest.getType()); - result.addAttribute(getPositionAttrName(), positionAttr); + result.addAttribute(getPositionAttrStrName(), positionAttr); } // Convenience builder which assumes the values are constant indices. @@ -1995,8 +1988,8 @@ auto offsetsAttr = getVectorSubscriptAttr(builder, offsets); auto stridesAttr = getVectorSubscriptAttr(builder, strides); result.addTypes(dest.getType()); - result.addAttribute(getOffsetsAttrName(), offsetsAttr); - result.addAttribute(getStridesAttrName(), stridesAttr); + result.addAttribute(getOffsetsAttrStrName(), offsetsAttr); + result.addAttribute(getStridesAttrStrName(), stridesAttr); } // TODO: Should be moved to Tablegen Confined attributes. @@ -2172,9 +2165,9 @@ vLHS.getElementType()) : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType()); - if (!result.attributes.get(OuterProductOp::getKindAttrName())) { + if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) { result.attributes.append( - OuterProductOp::getKindAttrName(), + OuterProductOp::getKindAttrStrName(), CombiningKindAttr::get(OuterProductOp::getDefaultKind(), result.getContext())); } @@ -2322,9 +2315,9 @@ result.addTypes( inferStridedSliceOpResultType(source.getType().cast(), offsetsAttr, sizesAttr, stridesAttr)); - result.addAttribute(getOffsetsAttrName(), offsetsAttr); - result.addAttribute(getSizesAttrName(), sizesAttr); - result.addAttribute(getStridesAttrName(), stridesAttr); + result.addAttribute(getOffsetsAttrStrName(), offsetsAttr); + result.addAttribute(getSizesAttrStrName(), sizesAttr); + result.addAttribute(getStridesAttrStrName(), stridesAttr); } LogicalResult ExtractStridedSliceOp::verify() { @@ -2412,7 +2405,7 @@ op.setOperand(insertOp.source()); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(op.getContext()); - op->setAttr(ExtractStridedSliceOp::getOffsetsAttrName(), + op->setAttr(ExtractStridedSliceOp::getOffsetsAttrStrName(), b.getI64ArrayAttr(offsetDiffs)); return success(); } @@ -2765,7 +2758,7 @@ SmallVector elidedAttrs; elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr()); if (op.permutation_map().isMinorIdentity()) - elidedAttrs.push_back(op.getPermutationMapAttrName()); + elidedAttrs.push_back(op.getPermutationMapAttrStrName()); bool elideInBounds = true; if (auto inBounds = op.in_bounds()) { for (auto attr : *inBounds) { @@ -2776,7 +2769,7 @@ } } if (elideInBounds) - elidedAttrs.push_back(op.getInBoundsAttrName()); + elidedAttrs.push_back(op.getInBoundsAttrStrName()); p.printOptionalAttrDict(op->getAttrs(), elidedAttrs); } @@ -2817,7 +2810,7 @@ VectorType vectorType = types[1].dyn_cast(); if (!vectorType) return parser.emitError(typesLoc, "requires vector type"); - auto permutationAttrName = TransferReadOp::getPermutationMapAttrName(); + auto permutationAttrName = TransferReadOp::getPermutationMapAttrStrName(); Attribute mapAttr = result.attributes.get(permutationAttrName); if (!mapAttr) { auto permMap = getTransferMinorIdentityMap(shapedType, vectorType); @@ -2963,7 +2956,7 @@ return failure(); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(op.getContext()); - op->setAttr(TransferOp::getInBoundsAttrName(), + op->setAttr(TransferOp::getInBoundsAttrStrName(), b.getBoolArrayAttr(newInBounds)); return success(); } @@ -3193,7 +3186,7 @@ ShapedType shapedType = types[1].dyn_cast(); if (!shapedType || !shapedType.isa()) return parser.emitError(typesLoc, "requires memref or ranked tensor type"); - auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName(); + auto permutationAttrName = TransferWriteOp::getPermutationMapAttrStrName(); auto attr = result.attributes.get(permutationAttrName); if (!attr) { auto permMap = getTransferMinorIdentityMap(shapedType, vectorType); @@ -4151,7 +4144,7 @@ result.addOperands(vector); result.addTypes(VectorType::get(transposedShape, vt.getElementType())); - result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp)); + result.addAttribute(getTranspAttrStrName(), builder.getI64ArrayAttr(transp)); } // Eliminates transpose operations, which produce values identical to their diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp @@ -514,7 +514,7 @@ SmallVector bools(xferOp.getTransferRank(), true); auto inBoundsAttr = b.getBoolArrayAttr(bools); if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) { - xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr); + xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr); return success(); } @@ -585,7 +585,7 @@ for (unsigned i = 0, e = returnTypes.size(); i != e; ++i) xferReadOp.setOperand(i, fullPartialIfOp.getResult(i)); - xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr); + xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1050,7 +1050,7 @@ bindDims(rew.getContext(), m, n, k); // LHS must be A(m, k) or A(k, m). Value lhs = op.lhs(); - auto lhsMap = op.indexing_maps()[0].cast().getValue(); + auto lhsMap = op.indexing_maps()[0]; if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx)) lhs = rew.create(loc, lhs, ArrayRef{1, 0}); else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx)) @@ -1058,7 +1058,7 @@ // RHS must be B(k, n) or B(n, k). Value rhs = op.rhs(); - auto rhsMap = op.indexing_maps()[1].cast().getValue(); + auto rhsMap = op.indexing_maps()[1]; if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx)) rhs = rew.create(loc, rhs, ArrayRef{1, 0}); else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx)) @@ -1088,7 +1088,7 @@ mul); // ACC must be C(m, n) or C(n, m). - auto accMap = op.indexing_maps()[2].cast().getValue(); + auto accMap = op.indexing_maps()[2]; if (accMap == AffineMap::get(3, 0, {n, m}, ctx)) mul = rew.create(loc, mul, ArrayRef{1, 0}); else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))