diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -48,37 +48,6 @@ /// the reduction. bool isRowMajorBatchMatmul(ArrayAttr indexingMaps); -/// Attribute name for the AffineArrayAttr which encodes the relationship -/// between a structured op iterators' and its operands. -constexpr StringRef getIndexingMapsAttrName() { return "indexing_maps"; } - -/// Attribute name for the StrArrayAttr which encodes the type of a structured -/// op's iterators. -constexpr StringRef getIteratorTypesAttrName() { return "iterator_types"; } - -/// Attribute name for the StrArrayAttr which encodes the distribution type for -/// `linalg.tiled_loop`. -constexpr StringRef getDistributionTypesAttrName() { - return "distribution_types"; -} - -/// Attribute name for the StringAttr which encodes an optional documentation -/// string of the structured op. -constexpr StringRef getDocAttrName() { return "doc"; } - -/// Attribute name for the StrArrayAttr which encodes the external library -/// function that implements the structured op. -constexpr StringRef getLibraryCallAttrName() { return "library_call"; } - -/// Attribute name for the StrArrayAttr which encodes the value of strides. -constexpr StringRef getStridesAttrName() { return "strides"; } - -/// Attribute name for the StrArrayAttr which encodes the value of dilations. -constexpr StringRef getDilationsAttrName() { return "dilations"; } - -/// Attribute name for the StrArrayAttr which encodes the value of paddings. -constexpr StringRef getPaddingAttrName() { return "padding"; } - /// Use to encode that a particular iterator type has parallel semantics.
constexpr StringRef getParallelIteratorTypeName() { return "parallel"; } diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -244,7 +244,7 @@ return getOperand(4).getType().cast<VectorType>(); } Type getResultType() { return getResult().getType(); } - ArrayRef<StringRef> getTraitAttrNames(); + SmallVector<StringRef> getTraitAttrNames(); static unsigned getAccOperandIndex() { return 2; } llvm::SmallVector<::mlir::AffineMap, 4> getIndexingMapsArray() { @@ -265,8 +265,6 @@ std::vector<std::pair<int64_t, int64_t>> getContractingDimMap(); std::vector<std::pair<int64_t, int64_t>> getBatchDimMap(); - static constexpr StringRef getKindAttrStrName() { return "kind"; } - static CombiningKind getDefaultKind() { return CombiningKind::ADD; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -73,8 +73,8 @@ m = m.compose(permutationMap); newIndexingMaps.push_back(m); } - genericOp->setAttr(getIndexingMapsAttrName(), - rewriter.getAffineMapArrayAttr(newIndexingMaps)); + genericOp.setIndexingMapsAttr( + rewriter.getAffineMapArrayAttr(newIndexingMaps)); // 3. Compute the interchanged iterator types. ArrayRef<Attribute> itTypes = genericOp.getIteratorTypes().getValue(); @@ -83,8 +83,7 @@ SmallVector<int64_t> permutation(interchangeVector.begin(), interchangeVector.end()); applyPermutationToVector(itTypesVector, permutation); - genericOp->setAttr(getIteratorTypesAttrName(), - ArrayAttr::get(context, itTypesVector)); + genericOp.setIteratorTypesAttr(rewriter.getArrayAttr(itTypesVector)); // 4. Transform the index operations by applying the permutation map.
if (genericOp.hasIndexSemantics()) { diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -508,11 +508,11 @@ ArrayRef<IteratorType> iteratorTypes) { result.addOperands({lhs, rhs, acc}); result.addTypes(acc.getType()); - result.addAttribute(::mlir::getIndexingMapsAttrName(), + result.addAttribute(getIndexingMapsAttrName(result.name), builder.getAffineMapArrayAttr( AffineMap::inferFromExprList(indexingExprs))); result.addAttribute( - ::mlir::getIteratorTypesAttrName(), + getIteratorTypesAttrName(result.name), builder.getArrayAttr(llvm::to_vector(llvm::map_range( iteratorTypes, [&](IteratorType t) -> mlir::Attribute { return IteratorTypeAttr::get(builder.getContext(), t); @@ -533,9 +533,9 @@ ArrayAttr iteratorTypes, CombiningKind kind) { result.addOperands({lhs, rhs, acc}); result.addTypes(acc.getType()); - result.addAttribute(::mlir::getIndexingMapsAttrName(), indexingMaps); - result.addAttribute(::mlir::getIteratorTypesAttrName(), iteratorTypes); - result.addAttribute(ContractionOp::getKindAttrStrName(), + result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps); + result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes); + result.addAttribute(getKindAttrName(result.name), CombiningKindAttr::get(builder.getContext(), kind)); } @@ -570,7 +570,13 @@ // represented as an array of strings. // TODO: Remove this conversion once tests are fixed.
ArrayAttr iteratorTypes = - result.attributes.get("iterator_types").cast<ArrayAttr>(); + result.attributes.get(getIteratorTypesAttrName(result.name)) + .dyn_cast_or_null<ArrayAttr>(); + // Diagnose a missing or non-array attribute instead of asserting in cast<>. + if (!iteratorTypes) + return parser.emitError(loc) + << "expected " << getIteratorTypesAttrName(result.name) + << " array attribute"; SmallVector<Attribute> iteratorTypeAttrs; @@ -579,15 +585,15 @@ if (!maybeIteratorType.has_value()) return parser.emitError(loc) << "unexpected iterator_type (" << s << ")"; - iteratorTypeAttrs.push_back(IteratorTypeAttr::get( - parser.getContext(), maybeIteratorType.value())); + iteratorTypeAttrs.push_back( + IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value())); } - result.attributes.set("iterator_types", + result.attributes.set(getIteratorTypesAttrName(result.name), parser.getBuilder().getArrayAttr(iteratorTypeAttrs)); - if (!result.attributes.get(ContractionOp::getKindAttrStrName())) { + if (!result.attributes.get(getKindAttrName(result.name))) { result.addAttribute( - ContractionOp::getKindAttrStrName(), + getKindAttrName(result.name), CombiningKindAttr::get(result.getContext(), ContractionOp::getDefaultKind())); } @@ -822,11 +828,9 @@ return success(); } -ArrayRef<StringRef> ContractionOp::getTraitAttrNames() { - static constexpr StringRef names[3] = {::mlir::getIndexingMapsAttrName(), - ::mlir::getIteratorTypesAttrName(), - ContractionOp::getKindAttrStrName()}; - return llvm::makeArrayRef(names); +SmallVector<StringRef> ContractionOp::getTraitAttrNames() { + return SmallVector<StringRef>{getIndexingMapsAttrName(), + getIteratorTypesAttrName(), getKindAttrName()}; } static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {