[complex] Give complex.number attributes a ComplexType

NumberAttr::verify now requires the attribute type to be a ComplexType whose
element type is a FloatType, and checks the `real`/`imag` float semantics
against that element type. print() emits the element type inside the angle
brackets, and parse() wraps the parsed element type in ComplexType, so the
attribute's trailing type now prints as complex<fN> instead of fN.

diff --git a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
--- a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
@@ -48,11 +48,16 @@
     ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
     ::llvm::APFloat real, ::llvm::APFloat imag, ::mlir::Type type) {
 
-  if (!type.isa<FloatType>())
+  if (!type.isa<ComplexType>())
+    return emitError() << "complex attribute must be a complex type.";
+
+  Type elementType = type.cast<ComplexType>().getElementType();
+  if (!elementType.isa<FloatType>())
     return emitError()
-           << "element of the complex attribute must be float like type.";
+           << "element type of the complex attribute must be float like type.";
 
-  const auto &typeFloatSemantics = type.cast<FloatType>().getFloatSemantics();
+  const auto &typeFloatSemantics =
+      elementType.cast<FloatType>().getFloatSemantics();
   if (&real.getSemantics() != &typeFloatSemantics)
     return emitError()
            << "type doesn't match the type implied by its `real` value";
@@ -64,7 +69,8 @@
 }
 
 void complex::NumberAttr::print(AsmPrinter &printer) const {
-  printer << "<:" << getType() << " " << getReal() << ", " << getImag() << ">";
+  printer << "<:" << getType().cast<ComplexType>().getElementType() << " "
+          << getReal() << ", " << getImag() << ">";
 }
 
 Attribute complex::NumberAttr::parse(AsmParser &parser, Type odsType) {
@@ -82,5 +88,6 @@
   APFloat imagFloat(imag);
   imagFloat.convert(type.cast<FloatType>().getFloatSemantics(),
                     APFloat::rmNearestTiesToEven, &unused);
-  return NumberAttr::get(parser.getContext(), realFloat, imagFloat, type);
+  return NumberAttr::get(parser.getContext(), realFloat, imagFloat,
+                         ComplexType::get(type));
 }
diff --git a/mlir/test/Dialect/Complex/attribute.mlir b/mlir/test/Dialect/Complex/attribute.mlir
--- a/mlir/test/Dialect/Complex/attribute.mlir
+++ b/mlir/test/Dialect/Complex/attribute.mlir
@@ -2,7 +2,7 @@
 
 func.func @number_attr_f64() {
   "test.number_attr"() {
-    // CHECK: attr = #complex.number<:f64 1.000000e+00, 0.000000e+00> : f64
+    // CHECK: attr = #complex.number<:f64 1.000000e+00, 0.000000e+00> : complex<f64>
     attr = #complex.number<:f64 1.0, 0.0>
   } : () -> ()
 
@@ -11,7 +11,7 @@
 
 func.func @number_attr_f32() {
   "test.number_attr"() {
-    // CHECK: attr = #complex.number<:f32 1.000000e+00, 0.000000e+00> : f32
+    // CHECK: attr = #complex.number<:f32 1.000000e+00, 0.000000e+00> : complex<f32>
     attr = #complex.number<:f32 1.0, 0.0>
   } : () -> ()
 