diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -354,9 +354,6 @@ "ArrayRef":$reductionMask, "CombiningKind":$kind)> ]; let extraClassDeclaration = [{ - static StringRef getKindAttrStrName() { return "kind"; } - static StringRef getReductionDimsAttrStrName() { return "reduction_dims"; } - VectorType getSourceVectorType() { return ::llvm::cast(getSource().getType()); } @@ -510,7 +507,6 @@ let hasFolder = 1; let hasCanonicalizer = 1; let extraClassDeclaration = [{ - static StringRef getMaskAttrStrName() { return "mask"; } VectorType getV1VectorType() { return ::llvm::cast(getV1().getType()); } @@ -599,7 +595,6 @@ OpBuilder<(ins "Value":$source, "ValueRange":$position)> ]; let extraClassDeclaration = [{ - static StringRef getPositionAttrStrName() { return "position"; } VectorType getSourceVectorType() { return ::llvm::cast(getVector().getType()); } @@ -723,7 +718,6 @@ OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$position)> ]; let extraClassDeclaration = [{ - static StringRef getPositionAttrStrName() { return "position"; } Type getSourceType() { return getSource().getType(); } VectorType getDestVectorType() { return ::llvm::cast(getDest().getType()); @@ -882,8 +876,6 @@ "ArrayRef":$offsets, "ArrayRef":$strides)> ]; let extraClassDeclaration = [{ - static StringRef getOffsetsAttrStrName() { return "offsets"; } - static StringRef getStridesAttrStrName() { return "strides"; } VectorType getSourceVectorType() { return ::llvm::cast(getSource().getType()); } @@ -981,9 +973,6 @@ VectorType getResultVectorType() { return ::llvm::cast(getResult().getType()); } - static constexpr StringRef getKindAttrStrName() { - return "kind"; - } static CombiningKind getDefaultKind() { return CombiningKind::ADD; } @@ -1099,12 +1088,6 @@ int64_t getNumOutputShapeSizes() { return getOutputShape().size(); } void getFixedVectorSizes(SmallVectorImpl &results); - - static StringRef getFixedVectorSizesAttrStrName() { - return "fixed_vector_sizes"; - } - static StringRef getInputShapeAttrStrName() { return "input_shape"; } - static StringRef getOutputShapeAttrStrName() { return "output_shape"; } }]; let assemblyFormat = [{ @@ -1151,9 +1134,6 @@ "ArrayRef":$sizes, "ArrayRef":$strides)> ]; let extraClassDeclaration = [{ - static StringRef getOffsetsAttrStrName() { return "offsets"; } - static StringRef getSizesAttrStrName() { return "sizes"; } - static StringRef getStridesAttrStrName() { return "strides"; } VectorType getSourceVectorType() { return ::llvm::cast(getVector().getType()); } @@ -2288,9 +2268,6 @@ ``` }]; - let extraClassDeclaration = [{ - static StringRef getMaskDimSizesAttrStrName() { return "mask_dim_sizes"; } - }]; let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)"; let hasVerifier = 1; } @@ -2473,7 +2450,6 @@ return ::llvm::cast(getResult().getType()); } void getTransp(SmallVectorImpl &results); - static StringRef getTranspAttrStrName() { return "transp"; } }]; let assemblyFormat = [{ $vector `,` $transp attr-dict `:` type($vector) `to` type($result) @@ -2738,8 +2714,6 @@ CArg<"bool", "true">:$inclusive)> ]; let extraClassDeclaration = [{ - static StringRef getKindAttrStrName() { return "kind"; } - static StringRef getReductionDimAttrStrName() { return "reduction_dim"; } VectorType getSourceType() { return ::llvm::cast(getSource().getType()); } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1218,8 +1218,7 @@ // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); std::reverse(globalPosition.begin(), globalPosition.end()); - extractOp->setAttr(ExtractOp::getPositionAttrStrName(), - b.getI64ArrayAttr(globalPosition)); + extractOp.setPositionAttr(b.getI64ArrayAttr(globalPosition)); return success(); } @@ -1499,8 +1498,7 @@ // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); extractOp.setOperand(source); - extractOp->setAttr(ExtractOp::getPositionAttrStrName(), - b.getI64ArrayAttr(extractPos)); + extractOp.setPositionAttr(b.getI64ArrayAttr(extractPos)); return extractOp.getResult(); } @@ -1565,8 +1563,7 @@ SmallVector newPosition = delinearize(position, newStrides); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); - extractOp->setAttr(ExtractOp::getPositionAttrStrName(), - b.getI64ArrayAttr(newPosition)); + extractOp.setPositionAttr(b.getI64ArrayAttr(newPosition)); extractOp.setOperand(shapeCastOp.getSource()); return extractOp.getResult(); } @@ -1613,8 +1610,7 @@ extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector()); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); - extractOp->setAttr(ExtractOp::getPositionAttrStrName(), - b.getI64ArrayAttr(extractedPos)); + extractOp.setPositionAttr(b.getI64ArrayAttr(extractedPos)); return extractOp.getResult(); } @@ -1679,8 +1675,7 @@ extractOp.getVectorMutable().assign(insertOp.getSource()); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); - extractOp->setAttr(ExtractOp::getPositionAttrStrName(), - b.getI64ArrayAttr(offsetDiffs)); + extractOp.setPositionAttr(b.getI64ArrayAttr(offsetDiffs)); return extractOp.getResult(); } // If the chunk extracted is disjoint from the chunk inserted, keep @@ -2300,7 +2295,7 @@ result.addOperands({source, dest}); auto positionAttr = getVectorSubscriptAttr(builder, position); result.addTypes(dest.getType()); - result.addAttribute(getPositionAttrStrName(), positionAttr); + result.addAttribute(InsertOp::getPositionAttrName(result.name), positionAttr); } // Convenience builder which assumes the values are constant indices. @@ -2467,8 +2462,10 @@ auto offsetsAttr = getVectorSubscriptAttr(builder, offsets); auto stridesAttr = getVectorSubscriptAttr(builder, strides); result.addTypes(dest.getType()); - result.addAttribute(getOffsetsAttrStrName(), offsetsAttr); - result.addAttribute(getStridesAttrStrName(), stridesAttr); + result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name), + offsetsAttr); + result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name), + stridesAttr); } // TODO: Should be moved to Tablegen ConfinedAttr attributes. @@ -2790,9 +2787,9 @@ scalableDimsRes); } - if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) { + if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) { result.attributes.append( - OuterProductOp::getKindAttrStrName(), + OuterProductOp::getKindAttrName(result.name), CombiningKindAttr::get(result.getContext(), OuterProductOp::getDefaultKind())); } @@ -2951,9 +2948,12 @@ result.addTypes( inferStridedSliceOpResultType(llvm::cast(source.getType()), offsetsAttr, sizesAttr, stridesAttr)); - result.addAttribute(getOffsetsAttrStrName(), offsetsAttr); - result.addAttribute(getSizesAttrStrName(), sizesAttr); - result.addAttribute(getStridesAttrStrName(), stridesAttr); + result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name), + offsetsAttr); + result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name), + sizesAttr); + result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name), + stridesAttr); } LogicalResult ExtractStridedSliceOp::verify() { @@ -3046,8 +3046,7 @@ op.setOperand(insertOp.getSource()); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(op.getContext()); - op->setAttr(ExtractStridedSliceOp::getOffsetsAttrStrName(), - b.getI64ArrayAttr(offsetDiffs)); + op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs)); return success(); } // If the chunk extracted is disjoint from the chunk inserted, keep looking @@ -4973,7 +4972,8 @@ result.addOperands(vector); result.addTypes(VectorType::get(transposedShape, vt.getElementType())); - result.addAttribute(getTranspAttrStrName(), builder.getI64ArrayAttr(transp)); + result.addAttribute(TransposeOp::getTranspAttrName(result.name), + builder.getI64ArrayAttr(transp)); } OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {