diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1040,6 +1040,34 @@
 def I32ElementsAttr : IntElementsAttr<32>;
 def I64ElementsAttr : IntElementsAttr<64>;
 
+// A `width`-bit integer elements attribute. The attribute should be ranked and
+// have a shape as specified in `dims`.
+class RankedIntElementsAttr<int width, list<int> dims> :
+    IntElementsAttr<width> {
+  // Check that this has the specified shape. `And` evaluates its operands in
+  // order, so the cast below only runs after the base IntElementsAttr
+  // predicate has accepted the attribute.
+  let predicate = And<[
+    IntElementsAttr<width>.predicate,
+    CPred<"$_self.cast<DenseIntElementsAttr>().getType().getShape() == "
+          "ArrayRef<int64_t>({" # StrJoinInt<dims>.result # "})">]>;
+
+  let description = width # "-bit int elements attribute of shape [" #
+                    StrJoinInt<dims>.result # "]";
+
+  // `DenseIntElementsAttr::get` is the static `get` inherited from
+  // DenseElementsAttr and returns a DenseElementsAttr, so cast the result to
+  // the DenseIntElementsAttr type that the shape check above casts to.
+  let constBuilderCall = "DenseIntElementsAttr::get("
+    "RankedTensorType::get({" # StrJoinInt<dims>.result #
+    "}, $_builder.getIntegerType(" # width # ")), makeArrayRef($0))"
+    ".cast<DenseIntElementsAttr>()";
+}
+
+// Shorthands for 32- and 64-bit ranked integer elements attributes.
+class RankedI32ElementsAttr<list<int> dims> : RankedIntElementsAttr<32, dims>;
+class RankedI64ElementsAttr<list<int> dims> : RankedIntElementsAttr<64, dims>;
+
 class FloatElementsAttr<int width> : ElementsAttrBase<
   CPred<"$_self.isa<DenseFPElementsAttr>() &&"
       "$_self.cast<DenseElementsAttr>().getType()."
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -243,3 +243,56 @@
 
 // expected-error @+1 {{referencing to a 'FuncOp' symbol}}
 "test.symbol_ref_attr"() {symbol = @foo} : () -> ()
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test IntElementsAttr
+//===----------------------------------------------------------------------===//
+
+func @correct_type_pass() {
+  "test.int_elements_attr"() {
+    // CHECK: matrix_i64_attr = dense<6> : tensor<4x8xi64>
+    // CHECK: vector_i32_attr = dense<5> : tensor<2xi32>
+    matrix_i64_attr = dense<6> : tensor<4x8xi64>,
+    vector_i32_attr = dense<5> : tensor<2xi32>
+  } : () -> ()
+  return
+}
+
+// -----
+
+func @wrong_element_type_fail() {
+  // vector_i32_attr has the required shape [2] but i64 elements.
+  // expected-error @+1 {{failed to satisfy constraint: 32-bit int elements attribute of shape [2]}}
+  "test.int_elements_attr"() {
+    matrix_i64_attr = dense<6> : tensor<4x8xi64>,
+    vector_i32_attr = dense<5> : tensor<2xi64>
+  } : () -> ()
+  return
+}
+
+// -----
+
+func @wrong_shape_fail() {
+  // matrix_i64_attr has i64 elements but shape [4] instead of [4, 8].
+  // expected-error @+1 {{failed to satisfy constraint: 64-bit int elements attribute of shape [4, 8]}}
+  "test.int_elements_attr"() {
+    matrix_i64_attr = dense<6> : tensor<4xi64>,
+    vector_i32_attr = dense<5> : tensor<2xi32>
+  } : () -> ()
+  return
+}
+
+// -----
+
+func @wrong_rank_fail() {
+  // vector_i32_attr is a rank-0 tensor, so it cannot have shape [2].
+  // expected-error @+1 {{failed to satisfy constraint: 32-bit int elements attribute of shape [2]}}
+  "test.int_elements_attr"() {
+    matrix_i64_attr = dense<6> : tensor<4x8xi64>,
+    vector_i32_attr = dense<5> : tensor<i32>
+  } : () -> ()
+  return
+}
+
diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -204,6 +204,14 @@
     ConstantAttr<RankedF32ElementsAttr<[2]>, "{5.0f, 6.0f}">:$f32attr,
     $f64attr)>;
 
+// Op exercised by the "Test IntElementsAttr" cases in test/IR/attribute.mlir.
+def IntElementsAttrOp : TEST_Op<"int_elements_attr"> {
+  let arguments = (ins
+      RankedI32ElementsAttr<[2]>:$vector_i32_attr,
+      RankedI64ElementsAttr<[4, 8]>:$matrix_i64_attr
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // Test Attribute Constraints
 //===----------------------------------------------------------------------===//