[mlir] Support empty element lists in DenseArrayAttr

Allow DenseArrayAttr to hold zero elements in both of its textual forms.

* Parsing, generic form (AttributeParser.cpp): after the element type has
  been parsed, the parser checks whether the next token is `]`. If so it
  builds the attribute with Dense{I8,I16,I32,I64,F32,F64}ArrayAttr::get(
  context, {}) instead of calling parseWithoutBraces, so `[:i32]` is
  accepted.
* Parsing, custom form: DenseArrayAttr<T>::parse returns get(context, {})
  when `[` is immediately followed by `]`.
* Construction: DenseArrayAttr<T>::get builds a rank-1 vector type of
  length N for a non-empty array and a rank-0 vector type (empty shape)
  for an empty one. denseArrayAttrEltTypeBuilder<T>::getShapedType now
  takes the whole shape as ArrayRef<int64_t> so both cases can be passed.
* Printing: the generic printer writes the separating space after the
  element type only when the attribute's shaped type has non-zero rank,
  so an empty array prints as `[:i8]` rather than `[:i8 ]`.
* Tests: attribute.mlir round-trips `[:f32]`, `[:f64]`, `[:i8]`, `[:i16]`,
  `[:i32]` and `[:i64]`; the test dialect's dense-array op (TestOps.td)
  gains a DenseI32ArrayAttr argument `emptyattr`, exercised with
  `emptyattr = []`.

NOTE(review): emptiness is encoded purely through the rank of the shaped
type. A rank-0 shaped type presumably reports getNumElements() == 1 (the
product over an empty shape) while the raw data holds zero bytes; confirm
that no consumer sizes its iteration from the shaped type rather than from
the raw data.

diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1875,24 +1875,26 @@
     typeElision = AttrTypeElision::Must;
     switch (denseArrayAttr.getElementType()) {
     case DenseArrayBaseAttr::EltType::I8:
-      os << "[:i8 ";
+      os << "[:i8";
       break;
     case DenseArrayBaseAttr::EltType::I16:
-      os << "[:i16 ";
+      os << "[:i16";
       break;
     case DenseArrayBaseAttr::EltType::I32:
-      os << "[:i32 ";
+      os << "[:i32";
       break;
     case DenseArrayBaseAttr::EltType::I64:
-      os << "[:i64 ";
+      os << "[:i64";
       break;
     case DenseArrayBaseAttr::EltType::F32:
-      os << "[:f32 ";
+      os << "[:f32";
       break;
     case DenseArrayBaseAttr::EltType::F64:
-      os << "[:f64 ";
+      os << "[:f64";
       break;
     }
+    if (denseArrayAttr.getType().cast<ShapedType>().getRank())
+      os << " ";
     denseArrayAttr.printWithoutBraces(os);
     os << "]";
   } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -838,6 +838,9 @@
 Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
   if (parser.parseLSquare())
     return {};
+  // Handle empty list case.
+  if (succeeded(parser.parseOptionalRSquare()))
+    return get(parser.getContext(), {});
   Attribute result = parseWithoutBraces(parser, odsType);
   if (parser.parseRSquare())
     return {};
@@ -860,42 +863,48 @@
 template <>
 struct denseArrayAttrEltTypeBuilder<int8_t> {
   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8;
-  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+  static ShapedType getShapedType(MLIRContext *context,
+                                  ArrayRef<int64_t> shape) {
     return VectorType::get(shape, IntegerType::get(context, 8));
   }
 };
 template <>
 struct denseArrayAttrEltTypeBuilder<int16_t> {
   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I16;
-  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+  static ShapedType getShapedType(MLIRContext *context,
+                                  ArrayRef<int64_t> shape) {
     return VectorType::get(shape, IntegerType::get(context, 16));
   }
 };
 template <>
 struct denseArrayAttrEltTypeBuilder<int32_t> {
   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I32;
-  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+  static ShapedType getShapedType(MLIRContext *context,
+                                  ArrayRef<int64_t> shape) {
     return VectorType::get(shape, IntegerType::get(context, 32));
   }
 };
 template <>
 struct denseArrayAttrEltTypeBuilder<int64_t> {
   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I64;
-  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+  static ShapedType getShapedType(MLIRContext *context,
+                                  ArrayRef<int64_t> shape) {
     return VectorType::get(shape, IntegerType::get(context, 64));
   }
 };
 template <>
 struct denseArrayAttrEltTypeBuilder<float> {
   constexpr static auto eltType = DenseArrayBaseAttr::EltType::F32;
-  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+  static ShapedType getShapedType(MLIRContext *context,
+                                  ArrayRef<int64_t> shape) {
     return VectorType::get(shape, Float32Type::get(context));
   }
 };
 template <>
 struct denseArrayAttrEltTypeBuilder<double> {
   constexpr static auto eltType = DenseArrayBaseAttr::EltType::F64;
-  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+  static ShapedType getShapedType(MLIRContext *context,
+                                  ArrayRef<int64_t> shape) {
     return VectorType::get(shape, Float64Type::get(context));
   }
 };
@@ -905,8 +914,9 @@
 template <typename T>
 DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context,
                                          ArrayRef<T> content) {
-  auto shapedType =
-      denseArrayAttrEltTypeBuilder<T>::getShapedType(context, content.size());
+  auto size = static_cast<int64_t>(content.size());
+  auto shapedType = denseArrayAttrEltTypeBuilder<T>::getShapedType(
+      context, size ? ArrayRef<int64_t>{size} : ArrayRef<int64_t>{});
   auto eltType = denseArrayAttrEltTypeBuilder<T>::eltType;
   auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
                                  content.size() * sizeof(T));
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -844,19 +844,34 @@
     return {};
   CustomAsmParser parser(*this);
   Attribute result;
+  // Check for empty list.
+  bool isEmptyList = getToken().is(Token::r_square);
+
   if (auto intType = type.dyn_cast<IntegerType>()) {
     switch (type.getIntOrFloatBitWidth()) {
     case 8:
-      result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{});
+      if (isEmptyList)
+        result = DenseI8ArrayAttr::get(parser.getContext(), {});
+      else
+        result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{});
       break;
     case 16:
-      result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{});
+      if (isEmptyList)
+        result = DenseI16ArrayAttr::get(parser.getContext(), {});
+      else
+        result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{});
       break;
     case 32:
-      result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{});
+      if (isEmptyList)
+        result = DenseI32ArrayAttr::get(parser.getContext(), {});
+      else
+        result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{});
       break;
     case 64:
-      result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
+      if (isEmptyList)
+        result = DenseI64ArrayAttr::get(parser.getContext(), {});
+      else
+        result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
       break;
     default:
       emitError(typeLoc, "expected i8, i16, i32, or i64 but got: ") << type;
@@ -865,10 +880,16 @@
   } else if (auto floatType = type.dyn_cast<FloatType>()) {
     switch (type.getIntOrFloatBitWidth()) {
     case 32:
-      result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{});
+      if (isEmptyList)
+        result = DenseF32ArrayAttr::get(parser.getContext(), {});
+      else
+        result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{});
       break;
     case 64:
-      result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{});
+      if (isEmptyList)
+        result = DenseF64ArrayAttr::get(parser.getContext(), {});
+      else
+        result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{});
       break;
     default:
       emitError(typeLoc, "expected f32 or f64 but got: ") << type;
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -521,7 +521,19 @@
 //===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: func @dense_array_attr
-func.func @dense_array_attr() attributes{ 
+func.func @dense_array_attr() attributes{
+// CHECK-SAME: emptyf32attr = [:f32],
+  emptyf32attr = [:f32],
+// CHECK-SAME: emptyf64attr = [:f64],
+  emptyf64attr = [:f64],
+// CHECK-SAME: emptyi16attr = [:i16],
+  emptyi16attr = [:i16],
+// CHECK-SAME: emptyi32attr = [:i32],
+  emptyi32attr = [:i32],
+// CHECK-SAME: emptyi64attr = [:i64],
+  emptyi64attr = [:i64],
+// CHECK-SAME: emptyi8attr = [:i8],
+  emptyi8attr = [:i8],
 // CHECK-SAME: f32attr = [:f32 1.024000e+03, 4.530000e+02, -6.435000e+03],
   f32attr = [:f32 1024., 453., -6435.],
 // CHECK-SAME: f64attr = [:f64 -1.420000e+02],
@@ -549,6 +561,8 @@
   f32attr = [1024., 453., -6435.]
 // CHECK-SAME: f64attr = [-1.420000e+02]
   f64attr = [-142.]
+// CHECK-SAME: emptyattr = []
+  emptyattr = []
   return
 }
 
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -277,11 +277,13 @@
     DenseI32ArrayAttr:$i32attr,
     DenseI64ArrayAttr:$i64attr,
     DenseF32ArrayAttr:$f32attr,
-    DenseF64ArrayAttr:$f64attr
+    DenseF64ArrayAttr:$f64attr,
+    DenseI32ArrayAttr:$emptyattr
   );
   let assemblyFormat = [{
    `i8attr` `=` $i8attr `i16attr` `=` $i16attr `i32attr` `=` $i32attr
    `i64attr` `=` $i64attr `f32attr` `=` $f32attr `f64attr` `=` $f64attr
+   `emptyattr` `=` $emptyattr
    attr-dict
   }];
 }