diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -612,7 +612,7 @@
   let convertFromStorage = [{
     llvm::to_vector<4>(
       llvm::map_range($_self.getAsRange(),
-      [](IntegerAttr attr) { return attr.getInt(); }));
+      [](mlir::IntegerAttr attr) { return attr.getInt(); }));
   }];
   let constBuilderCall = "$_builder.getI64ArrayAttr($0)";
 }
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -16,16 +16,36 @@
     return decorator_builder
 
 
+@register_attribute_builder("AffineMapAttr")
+def _affineMapAttr(x, context):
+    return AffineMapAttr.get(x)
+
+
 @register_attribute_builder("BoolAttr")
 def _boolAttr(x, context):
     return BoolAttr.get(x, context=context)
 
 
+@register_attribute_builder("DictionaryAttr")
+def _dictAttr(x, context):
+    return DictAttr.get(x, context=context)
+
+
 @register_attribute_builder("IndexAttr")
 def _indexAttr(x, context):
     return IntegerAttr.get(IndexType.get(context=context), x)
 
 
+@register_attribute_builder("I1Attr")
+def _i1Attr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(1, context=context), x)
+
+
+@register_attribute_builder("I8Attr")
+def _i8Attr(x, context):
+    return IntegerAttr.get(IntegerType.get_signless(8, context=context), x)
+
+
 @register_attribute_builder("I16Attr")
 def _i16Attr(x, context):
     return IntegerAttr.get(IntegerType.get_signless(16, context=context), x)
@@ -41,6 +61,16 @@
     return IntegerAttr.get(IntegerType.get_signless(64, context=context), x)
 
 
+@register_attribute_builder("SI1Attr")
+def _si1Attr(x, context):
+    return IntegerAttr.get(IntegerType.get_signed(1, context=context), x)
+
+
+@register_attribute_builder("SI8Attr")
+def _si8Attr(x, context):
+    return IntegerAttr.get(IntegerType.get_signed(8, context=context), x)
+
+
 @register_attribute_builder("SI16Attr")
 def _si16Attr(x, context):
     return IntegerAttr.get(IntegerType.get_signed(16, context=context), x)
@@ -51,6 +81,36 @@
     return IntegerAttr.get(IntegerType.get_signed(32, context=context), x)
 
 
+@register_attribute_builder("SI64Attr")
+def _si64Attr(x, context):
+    return IntegerAttr.get(IntegerType.get_signed(64, context=context), x)
+
+
+@register_attribute_builder("UI1Attr")
+def _ui1Attr(x, context):
+    return IntegerAttr.get(IntegerType.get_unsigned(1, context=context), x)
+
+
+@register_attribute_builder("UI8Attr")
+def _ui8Attr(x, context):
+    return IntegerAttr.get(IntegerType.get_unsigned(8, context=context), x)
+
+
+@register_attribute_builder("UI16Attr")
+def _ui16Attr(x, context):
+    return IntegerAttr.get(IntegerType.get_unsigned(16, context=context), x)
+
+
+@register_attribute_builder("UI32Attr")
+def _ui32Attr(x, context):
+    return IntegerAttr.get(IntegerType.get_unsigned(32, context=context), x)
+
+
+@register_attribute_builder("UI64Attr")
+def _ui64Attr(x, context):
+    return IntegerAttr.get(IntegerType.get_unsigned(64, context=context), x)
+
+
 @register_attribute_builder("F32Attr")
 def _f32Attr(x, context):
     return FloatAttr.get_f32(x, context=context)
@@ -84,11 +144,39 @@
     return FlatSymbolRefAttr.get(x, context=context)
 
 
+@register_attribute_builder("UnitAttr")
+def _unitAttr(x, context):
+    if x:
+        return UnitAttr.get(context=context)
+    else:
+        return None
+
+
 @register_attribute_builder("ArrayAttr")
 def _arrayAttr(x, context):
     return ArrayAttr.get(x, context=context)
 
 
+@register_attribute_builder("AffineMapArrayAttr")
+def _affineMapArrayAttr(x, context):
+    return ArrayAttr.get([_affineMapAttr(v, context) for v in x])
+
+
+@register_attribute_builder("BoolArrayAttr")
+def _boolArrayAttr(x, context):
+    return ArrayAttr.get([_boolAttr(v, context) for v in x])
+
+
+@register_attribute_builder("DictArrayAttr")
+def _dictArrayAttr(x, context):
+    return ArrayAttr.get([_dictAttr(v, context) for v in x])
+
+
+@register_attribute_builder("FlatSymbolRefArrayAttr")
+def _flatSymbolRefArrayAttr(x, context):
+    return ArrayAttr.get([_flatSymbolRefAttr(v, context) for v in x])
+
+
 @register_attribute_builder("I32ArrayAttr")
 def _i32ArrayAttr(x, context):
     return ArrayAttr.get([_i32Attr(v, context) for v in x])
@@ -99,6 +187,16 @@
     return ArrayAttr.get([_i64Attr(v, context) for v in x])
 
 
+@register_attribute_builder("I64SmallVectorArrayAttr")
+def _i64SmallVectorArrayAttr(x, context):
+    return _i64ArrayAttr(x, context=context)
+
+
+@register_attribute_builder("IndexListArrayAttr")
+def _indexListArrayAttr(x, context):
+    return ArrayAttr.get([_i64ArrayAttr(v, context) for v in x])
+
+
 @register_attribute_builder("F32ArrayAttr")
 def _f32ArrayAttr(x, context):
     return ArrayAttr.get([_f32Attr(v, context) for v in x])
@@ -109,6 +207,41 @@
     return ArrayAttr.get([_f64Attr(v, context) for v in x])
 
 
+@register_attribute_builder("StrArrayAttr")
+def _strArrayAttr(x, context):
+    return ArrayAttr.get([_stringAttr(v, context) for v in x])
+
+
+@register_attribute_builder("SymbolRefArrayAttr")
+def _symbolRefArrayAttr(x, context):
+    return ArrayAttr.get([_symbolRefAttr(v, context) for v in x])
+
+
+@register_attribute_builder("DenseF32ArrayAttr")
+def _denseF32ArrayAttr(x, context):
+    return DenseF32ArrayAttr.get(x, context=context)
+
+
+@register_attribute_builder("DenseF64ArrayAttr")
+def _denseF64ArrayAttr(x, context):
+    return DenseF64ArrayAttr.get(x, context=context)
+
+
+@register_attribute_builder("DenseI8ArrayAttr")
+def _denseI8ArrayAttr(x, context):
+    return DenseI8ArrayAttr.get(x, context=context)
+
+
+@register_attribute_builder("DenseI16ArrayAttr")
+def _denseI16ArrayAttr(x, context):
+    return DenseI16ArrayAttr.get(x, context=context)
+
+
+@register_attribute_builder("DenseI32ArrayAttr")
+def _denseI32ArrayAttr(x, context):
+    return DenseI32ArrayAttr.get(x, context=context)
+
+
 @register_attribute_builder("DenseI64ArrayAttr")
 def _denseI64ArrayAttr(x, context):
     return DenseI64ArrayAttr.get(x, context=context)
@@ -132,6 +265,32 @@
 try:
     import numpy as np
 
+    @register_attribute_builder("F64ElementsAttr")
+    def _f64ElementsAttr(x, context):
+        # The buffer dtype must match the f64 element type; an integer dtype
+        # would have its raw bits reinterpreted as doubles.
+        return DenseElementsAttr.get(
+            np.array(x, dtype=np.float64),
+            type=F64Type.get(context=context),
+            context=context,
+        )
+
+    @register_attribute_builder("I32ElementsAttr")
+    def _i32ElementsAttr(x, context):
+        return DenseElementsAttr.get(
+            np.array(x, dtype=np.int32),
+            type=IntegerType.get_signed(32, context=context),
+            context=context,
+        )
+
+    @register_attribute_builder("I64ElementsAttr")
+    def _i64ElementsAttr(x, context):
+        return DenseElementsAttr.get(
+            np.array(x, dtype=np.int64),
+            type=IntegerType.get_signed(64, context=context),
+            context=context,
+        )
+
     @register_attribute_builder("IndexElementsAttr")
     def _indexElementsAttr(x, context):
         return DenseElementsAttr.get(
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -140,23 +140,76 @@
 def attrBuilder():
     with Context() as ctx, Location.unknown():
         ctx.allow_unregistered_dialects = True
+        # CHECK: python_test.attributes_op
         op = test.AttributesOp(
-            x_bool=True,
-            x_i16=1,
-            x_i32=2,
-            x_i64=3,
-            x_si16=-1,
-            x_si32=-2,
-            x_f32=1.5,
-            x_f64=2.5,
-            x_str="x_str",
-            x_i32_array=[1, 2, 3],
-            x_i64_array=[4, 5, 6],
-            x_f32_array=[1.5, -2.5, 3.5],
-            x_f64_array=[4.5, 5.5, -6.5],
-            x_i64_dense=[1, 2, 3, 4, 5, 6],
+            # CHECK-DAG: x_affinemap = affine_map<() -> (2)>
+            x_affinemap=AffineMap.get_constant(2),
+            # CHECK-DAG: x_affinemaparr = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+            x_affinemaparr=[AffineMap.get_identity(3)],
+            # CHECK-DAG: x_arr = [true, "x"]
+            x_arr=[BoolAttr.get(True), StringAttr.get("x")],
+            x_boolarr=[False, True],  # CHECK-DAG: x_boolarr = [false, true]
+            x_bool=True,  # CHECK-DAG: x_bool = true
+            x_dboolarr=[True, False],  # CHECK-DAG: x_dboolarr = array
+            # CHECK-DAG: x_df32arr = array
+            x_df32arr=[23, 24],
+            # CHECK-DAG: x_df64arr = array
+            x_df64arr=[25, 26],
+            x_di16arr=[21, 22],  # CHECK-DAG: x_di16arr = array
+            x_di32arr=[0, 1],  # CHECK-DAG: x_di32arr = array
+            # CHECK-DAG: x_di64arr = array
+            x_di64arr=[1, 2],
+            x_di8arr=[2, 3],  # CHECK-DAG: x_di8arr = array
+            # CHECK-DAG: x_dictarr = [{a = false}]
+            x_dictarr=[{"a": BoolAttr.get(False)}],
+            x_dict={"b": BoolAttr.get(True)},  # CHECK-DAG: x_dict = {b = true}
+            x_f32=-2.25,  # CHECK-DAG: x_f32 = -2.250000e+00 : f32
+            # CHECK-DAG: x_f32arr = [2.000000e+00 : f32, 3.000000e+00 : f32]
+            x_f32arr=[2.0, 3.0],
+            x_f64=4.25,  # CHECK-DAG: x_f64 = 4.250000e+00 : f64
+            x_f64arr=[4.0, 8.0],  # CHECK-DAG: x_f64arr = [4.000000e+00, 8.000000e+00]
+            # CHECK-DAG: x_f64elems = dense<[8.000000e+00, 1.600000e+01]> : tensor<2xf64>
+            x_f64elems=[8.0, 16.0],
+            # CHECK-DAG: x_flatsymrefarr = [@symbol1, @symbol2]
+            x_flatsymrefarr=["symbol1", "symbol2"],
+            x_flatsymref="symbol3",  # CHECK-DAG: x_flatsymref = @symbol3
+            x_i1=0,  # CHECK-DAG: x_i1 = false
+            x_i16=42,  # CHECK-DAG: x_i16 = 42 : i16
+            x_i32=6,  # CHECK-DAG: x_i32 = 6 : i32
+            x_i32arr=[4, 5],  # CHECK-DAG: x_i32arr = [4 : i32, 5 : i32]
+            x_i32elems=[5, 6],  # CHECK-DAG: x_i32elems = dense<[5, 6]> : tensor<2xsi32>
+            x_i64=9,  # CHECK-DAG: x_i64 = 9 : i64
+            x_i64arr=[7, 8],  # CHECK-DAG: x_i64arr = [7, 8]
+            x_i64elems=[8, 9],  # CHECK-DAG: x_i64elems = dense<[8, 9]> : tensor<2xsi64>
+            x_i64svecarr=[10, 11],  # CHECK-DAG: x_i64svecarr = [10, 11]
+            x_i8=11,  # CHECK-DAG: x_i8 = 11 : i8
+            x_idx=10,  # CHECK-DAG: x_idx = 10 : index
+            # CHECK-DAG: x_idxelems = dense<[11, 12]> : tensor<2xindex>
+            x_idxelems=[11, 12],
+            # CHECK-DAG: x_idxlistarr = [{{\[}}13], [14, 15]]
+            x_idxlistarr=[[13], [14, 15]],
+            x_si1=-1,  # CHECK-DAG: x_si1 = -1 : si1
+            x_si16=-2,  # CHECK-DAG: x_si16 = -2 : si16
+            x_si32=-3,  # CHECK-DAG: x_si32 = -3 : si32
+            x_si64=-123,  # CHECK-DAG: x_si64 = -123 : si64
+            x_si8=-4,  # CHECK-DAG: x_si8 = -4 : si8
+            x_strarr=["hello", "world"],  # CHECK-DAG: x_strarr = ["hello", "world"]
+            x_str="hello world!",  # CHECK-DAG: x_str = "hello world!"
+            # CHECK-DAG: x_symrefarr = [@flatsym, @deep::@sym]
+            x_symrefarr=["flatsym", ["deep", "sym"]],
+            x_symref=["deep", "sym2"],  # CHECK-DAG: x_symref = @deep::@sym2
+            x_sym="symbol",  # CHECK-DAG: x_sym = "symbol"
+            x_typearr=[F32Type.get()],  # CHECK-DAG: x_typearr = [f32]
+            x_type=F64Type.get(),  # CHECK-DAG: x_type = f64
+            x_ui1=1,  # CHECK-DAG: x_ui1 = 1 : ui1
+            x_ui16=2,  # CHECK-DAG: x_ui16 = 2 : ui16
+            x_ui32=3,  # CHECK-DAG: x_ui32 = 3 : ui32
+            x_ui64=4,  # CHECK-DAG: x_ui64 = 4 : ui64
+            x_ui8=5,  # CHECK-DAG: x_ui8 = 5 : ui8
+            x_unit=True,  # CHECK-DAG: x_unit
         )
-        print(op)
+        op.verify()
+        op.print(use_local_scope=True)
 
 
 # CHECK-LABEL: TEST: inferReturnTypes
@@ -247,7 +300,6 @@
 
         module = Module.create()
         with InsertionPoint(module.body):
-
             op1 = test.OptionalOperandOp()
             # CHECK: op1.input is None: True
             print(f"op1.input is None: {op1.input is None}")
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -57,20 +57,60 @@
 }
 
 def AttributesOp : TestOp<"attributes_op"> {
-  let arguments = (ins BoolAttr:$x_bool,
-                       I16Attr: $x_i16,
-                       I32Attr: $x_i32,
-                       I64Attr: $x_i64,
-                       SI16Attr: $x_si16,
-                       SI32Attr: $x_si32,
-                       F32Attr: $x_f32,
-                       F64Attr: $x_f64,
-                       StrAttr: $x_str,
-                       I32ArrayAttr: $x_i32_array,
-                       I64ArrayAttr: $x_i64_array,
-                       F32ArrayAttr: $x_f32_array,
-                       F64ArrayAttr: $x_f64_array,
-                       DenseI64ArrayAttr: $x_i64_dense);
+  let arguments = (ins
+    AffineMapArrayAttr:$x_affinemaparr,
+    AffineMapAttr:$x_affinemap,
+    ArrayAttr:$x_arr,
+    BoolArrayAttr:$x_boolarr,
+    BoolAttr:$x_bool,
+    DenseBoolArrayAttr:$x_dboolarr,
+    DenseF32ArrayAttr:$x_df32arr,
+    DenseF64ArrayAttr:$x_df64arr,
+    DenseI16ArrayAttr:$x_di16arr,
+    DenseI32ArrayAttr:$x_di32arr,
+    DenseI64ArrayAttr:$x_di64arr,
+    DenseI8ArrayAttr:$x_di8arr,
+    DictArrayAttr:$x_dictarr,
+    DictionaryAttr:$x_dict,
+    F32ArrayAttr:$x_f32arr,
+    F32Attr:$x_f32,
+    F64ArrayAttr:$x_f64arr,
+    F64Attr:$x_f64,
+    F64ElementsAttr:$x_f64elems,
+    FlatSymbolRefArrayAttr:$x_flatsymrefarr,
+    FlatSymbolRefAttr:$x_flatsymref,
+    I16Attr:$x_i16,
+    I1Attr:$x_i1,
+    I32ArrayAttr:$x_i32arr,
+    I32Attr:$x_i32,
+    I32ElementsAttr:$x_i32elems,
+    I64ArrayAttr:$x_i64arr,
+    I64Attr:$x_i64,
+    I64ElementsAttr:$x_i64elems,
+    I64SmallVectorArrayAttr:$x_i64svecarr,
+    I8Attr:$x_i8,
+    IndexAttr:$x_idx,
+    IndexElementsAttr:$x_idxelems,
+    IndexListArrayAttr:$x_idxlistarr,
+    SI16Attr:$x_si16,
+    SI1Attr:$x_si1,
+    SI32Attr:$x_si32,
+    SI64Attr:$x_si64,
+    SI8Attr:$x_si8,
+    StrArrayAttr:$x_strarr,
+    StrAttr:$x_str,
+    SymbolNameAttr:$x_sym,
+    SymbolRefArrayAttr:$x_symrefarr,
+    SymbolRefAttr:$x_symref,
+    TypeArrayAttr:$x_typearr,
+    TypeAttr:$x_type,
+    UI16Attr:$x_ui16,
+    UI1Attr:$x_ui1,
+    UI32Attr:$x_ui32,
+    UI64Attr:$x_ui64,
+    UI8Attr:$x_ui8,
+    UnitAttr:$x_unit
+  );
 }
 
 def PropertyOp : TestOp<"property_op"> {