Add IndexElementsAttr, an ODS constraint for DenseIntElementsAttr values whose
element type is `index`, together with:
  * Builder::getIndexTensorAttr, which builds a 1-D tensor of `index` element
    type (one element per input value) as a DenseIntElementsAttr;
  * a branch in lib/IR/Attributes.cpp that returns an IntegerAttr of the
    element type for index-typed dense elements;
  * a test op (test.indexElementsAttr) plus lit tests that accept
    `tensor<2xindex>` and reject `tensor<2xi32>` with the
    "index elements attribute" diagnostic.

diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -128,6 +128,7 @@
   /// as attributes.
   DenseIntElementsAttr getI32TensorAttr(ArrayRef<int32_t> values);
   DenseIntElementsAttr getI64TensorAttr(ArrayRef<int64_t> values);
+  DenseIntElementsAttr getIndexTensorAttr(ArrayRef<int64_t> values);
 
   ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values);
   ArrayAttr getBoolArrayAttr(ArrayRef<bool> values);
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1218,6 +1218,13 @@
   let convertFromStorage = "$_self";
 }
 
+def IndexElementsAttr
+    : IntElementsAttrBase<CPred<[{$_self.cast<DenseIntElementsAttr>()
+                                      .getType()
+                                      .getElementType()
+                                      .isIndex()}]>,
+                          "index elements attribute">;
+
 class AnyIntElementsAttr<int width> : IntElementsAttrBase<
   CPred<"$_self.cast<DenseIntElementsAttr>().getType()."
         "getElementType().isInteger(" # width # ")">,
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -624,6 +624,8 @@
                            owner.getContext());
     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
   }
+  if (eltTy.isa<IndexType>())
+    return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
     IntElementIterator intIt(owner, index);
     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -130,6 +130,13 @@
       values);
 }
 
+DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
+  return DenseIntElementsAttr::get(
+      RankedTensorType::get(static_cast<int64_t>(values.size()),
+                            getIndexType()),
+      values);
+}
+
 IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
   return IntegerAttr::get(getIntegerType(32), APInt(32, value));
 }
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -454,6 +454,10 @@
   let arguments = (ins I32ElementsAttr:$attr);
 }
 
+def IndexElementsAttrOp : TEST_Op<"indexElementsAttr"> {
+  let arguments = (ins IndexElementsAttr:$attr);
+}
+
 def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let arguments = (ins AnyTensor, AnyTensor);
diff --git a/mlir/test/mlir-tblgen/types.mlir b/mlir/test/mlir-tblgen/types.mlir
--- a/mlir/test/mlir-tblgen/types.mlir
+++ b/mlir/test/mlir-tblgen/types.mlir
@@ -489,3 +489,18 @@
   "test.i32ElementsAttr"() {attr = dense<[1, 2]>:tensor<2xi32>} : () -> ()
   return
 }
+
+// -----
+
+func @elements_attr_index() {
+  "test.indexElementsAttr"() {attr = dense<[1, 2]>:tensor<2xindex>} : () -> ()
+  return
+}
+
+// -----
+
+func @elements_attr_not_index() {
+  // expected-error@+1 {{index elements attribute}}
+  "test.indexElementsAttr"() {attr = dense<[1, 2]>:tensor<2xi32>} : () -> ()
+  return
+}