diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1092,7 +1092,10 @@ let builders = [ OpBuilder<(ins "Attribute":$value), - [{ build($_builder, $_state, value.getType(), value); }]>]; + [{ build($_builder, $_state, value.getType(), value); }]>, + OpBuilder<(ins "Attribute":$value, "Type":$type), + [{ build($_builder, $_state, type, value); }]>, + ]; let extraClassDeclaration = [{ Attribute getValue() { return (*this)->getAttr("value"); } diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -326,11 +326,15 @@ /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`. /// This currently supports integer, floating point, splat and dense element -/// attributes and combinations thereof. In case of error, report it to `loc` -/// and return nullptr. +/// attributes and combinations thereof. Also, an array attribute with two +/// elements is supported to represent a complex constant. In case of error, +/// report it to `loc` and return nullptr. +/// `isTopLevel` is false for the recursive calls made on nested elements; +/// struct (complex) constants are only accepted when it is true. llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr, Location loc, - const ModuleTranslation &moduleTranslation); + const ModuleTranslation &moduleTranslation, + bool isTopLevel = true); /// Creates a call to an LLVM IR intrinsic function with the given arguments.
llvm::Value *createIntrinsicCall(llvm::IRBuilderBase &builder, diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1962,7 +1962,30 @@ } return success(); } - if (!op.value().isa()) + if (auto structType = op.getType().dyn_cast()) { + if (structType.getBody().size() != 2 || + structType.getBody()[0] != structType.getBody()[1]) { + return op.emitOpError() << "expected struct type with two elements of the " + "same type, the type of a complex constant"; + } + + auto arrayAttr = op.value().dyn_cast(); + if (!arrayAttr || arrayAttr.size() != 2 || + arrayAttr[0].getType() != arrayAttr[1].getType()) { + return op.emitOpError() << "expected array attribute with two elements, " + "representing a complex constant"; + } + + Type elementType = structType.getBody()[0]; + if (!elementType + .isa()) { + return op.emitOpError() + << "expected struct element types to be floating point type or " + "integer type"; + } + return success(); + } + if (!op.value().isa()) return op.emitOpError() << "only supports integer, float, string or elements attributes"; return success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -1200,7 +1200,8 @@ // Create a constant scalar value from the splat constant.
Value scalarConstant = rewriter.create( - def->getLoc(), constantAttr.getSplatValue()); + def->getLoc(), constantAttr.getSplatValue(), + constantAttr.getType().getElementType()); auto fusedOp = rewriter.create( rewriter.getUnknownLoc(), genericOp->getResultTypes(), diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1067,8 +1067,8 @@ p << ' '; p << op.getValue(); - // If the value is a symbol reference, print a trailing type. - if (op.getValue().isa()) + // If the value is a symbol reference or array, print a trailing type. + if (op.getValue().isa()) p << " : " << op.getType(); } @@ -1079,9 +1079,10 @@ parser.parseAttribute(valueAttr, "value", result.attributes)) return failure(); - // If the attribute is a symbol reference, then we expect a trailing type. + // If the attribute is a symbol reference or array, then we expect a trailing + // type. Type type; - if (!valueAttr.isa()) + if (!valueAttr.isa()) type = valueAttr.getType(); else if (parser.parseColonType(type)) return failure(); @@ -1119,6 +1120,24 @@ return success(); } + if (auto complexTy = type.dyn_cast()) { + auto arrayAttr = value.dyn_cast(); + if (!arrayAttr || arrayAttr.size() != 2) + return op.emitOpError( + "requires 'value' to be a complex constant, represented as array of " + "two values"); + auto complexEltTy = complexTy.getElementType(); + if (complexEltTy != arrayAttr[0].getType() || + complexEltTy != arrayAttr[1].getType()) { + return op.emitOpError() + << "requires attribute's element types (" << arrayAttr[0].getType() + << ", " << arrayAttr[1].getType() + << ") to match the element type of the op's return type (" + << complexEltTy << ")"; + } + return success(); + } + if (type.isa()) { if (!value.isa()) return op.emitOpError("requires 'value' to be a floating point constant"); @@ -1193,13 +1212,23 @@ if (value.isa()) return type.isa(); + // An ArrayAttr (complex constant) does not carry 'type' itself, so check it + // here; every other attribute must still match 'type' exactly below. + if (auto arrAttr = value.dyn_cast()) { + auto complexTy = type.dyn_cast(); + if (!complexTy) + return false; + auto complexEltTy = complexTy.getElementType(); + return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy && + arrAttr[1].getType() == complexEltTy; + } // The attribute
must have the same type as 'type'. if (value.getType() != type) return false; // If the type is an integer type, it must be signless. if (IntegerType integerTy = type.dyn_cast()) if (!integerTy.isSignless()) return false; // Finally, check that the attribute kind is handled. return value.isa(); } diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -578,6 +578,25 @@ FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); return FloatAttr::get(eltTy, *floatIt); } + if (auto complexTy = eltTy.dyn_cast()) { + auto complexEltTy = complexTy.getElementType(); + ComplexIntElementIterator complexIntIt(owner, index); + if (complexEltTy.isa()) { + auto value = *complexIntIt; + auto real = IntegerAttr::get(complexEltTy, value.real()); + auto imag = IntegerAttr::get(complexEltTy, value.imag()); + return ArrayAttr::get(complexTy.getContext(), + ArrayRef{real, imag}); + } + + ComplexFloatElementIterator complexFloatIt( + complexEltTy.cast().getFloatSemantics(), complexIntIt); + auto value = *complexFloatIt; + auto real = FloatAttr::get(complexEltTy, value.real()); + auto imag = FloatAttr::get(complexEltTy, value.imag()); + return ArrayAttr::get(complexTy.getContext(), + ArrayRef{real, imag}); + } if (owner.isa()) { ArrayRef vals = owner.getRawStringData(); return StringAttr::get(owner.isSplat() ?
vals.front() : vals[index], eltTy); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -103,16 +103,39 @@ /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`. /// This currently supports integer, floating point, splat and dense element -/// attributes and combinations thereof. In case of error, report it to `loc` -/// and return nullptr. +/// attributes and combinations thereof. Also, an array attribute with two +/// elements is supported to represent a complex constant. In case of error, +/// report it to `loc` and return nullptr. llvm::Constant *mlir::LLVM::detail::getLLVMConstant( llvm::Type *llvmType, Attribute attr, Location loc, - const ModuleTranslation &moduleTranslation) { + const ModuleTranslation &moduleTranslation, bool isTopLevel) { if (!attr) return llvm::UndefValue::get(llvmType); - if (llvmType->isStructTy()) { - emitError(loc, "struct types are not supported in constants"); - return nullptr; + if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) { + if (!isTopLevel) { + emitError(loc, "nested struct types are not supported in constants"); + return nullptr; + } + // Only a complex constant is supported: a two-element array attribute for + // a struct of two identical element types. Diagnose anything else here. + auto arrayAttr = attr.dyn_cast<ArrayAttr>(); + if (!arrayAttr || arrayAttr.size() != 2 || + structType->getNumElements() != 2 || + structType->getElementType(0) != structType->getElementType(1)) { + emitError(loc, "expected a complex constant: a two-element array " + "attribute for a struct of two identical element types"); + return nullptr; + } + llvm::Type *elementType = structType->getElementType(0); + llvm::Constant *real = getLLVMConstant(elementType, arrayAttr[0], loc, + moduleTranslation, false); + if (!real) + return nullptr; + llvm::Constant *imag = getLLVMConstant(elementType, arrayAttr[1], loc, + moduleTranslation, false); + if (!imag) + return nullptr; + return llvm::ConstantStruct::get(structType, {real, imag}); } // For integer types, we allow a mismatch in sizes as the index type in // MLIR might have a different size than the index type in the LLVM module.
@@ -120,8 +143,15 @@ return llvm::ConstantInt::get( llvmType, intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth())); - if (auto floatAttr = attr.dyn_cast()) + if (auto floatAttr = attr.dyn_cast()) { + if (llvmType != + llvm::Type::getFloatingPointTy(llvmType->getContext(), + floatAttr.getValue().getSemantics())) { + emitError(loc, "FloatAttr does not match expected type of the constant"); + return nullptr; + } return llvm::ConstantFP::get(llvmType, floatAttr.getValue()); + } if (auto funcAttr = attr.dyn_cast()) return llvm::ConstantExpr::getBitCast( moduleTranslation.lookupFunction(funcAttr.getValue()), llvmType); @@ -144,7 +174,7 @@ llvm::Constant *child = getLLVMConstant( elementType, elementTypeSequential ? splatAttr : splatAttr.getSplatValue(), loc, - moduleTranslation); + moduleTranslation, false); if (!child) return nullptr; if (llvmType->isVectorTy()) @@ -169,7 +199,7 @@ llvm::Type *innermostType = getInnermostElementType(llvmType); for (auto n : elementsAttr.getValues()) { constants.push_back( - getLLVMConstant(innermostType, n, loc, moduleTranslation)); + getLLVMConstant(innermostType, n, loc, moduleTranslation, false)); if (!constants.back()) return nullptr; } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -265,6 +265,54 @@ // ----- +llvm.func @array_attribute_one_element() -> !llvm.struct<(f64, f64)> { + // expected-error @+1 {{expected array attribute with two elements, representing a complex constant}} + %0 = llvm.mlir.constant([1.0 : f64]) : !llvm.struct<(f64, f64)> + llvm.return %0 : !llvm.struct<(f64, f64)> +} + +// ----- + +llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> { + // expected-error @+1 {{expected array attribute with two elements, representing a complex constant}} + %0 = llvm.mlir.constant([1.0 : f64, 1.0 : f32]) : !llvm.struct<(f64, f64)> + llvm.return %0 :
!llvm.struct<(f64, f64)> +} + +// ----- + +llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> { + // expected-error @+1 {{expected array attribute with two elements, representing a complex constant}} + %0 = llvm.mlir.constant(1.0 : f64) : !llvm.struct<(f64, f64)> + llvm.return %0 : !llvm.struct<(f64, f64)> +} + +// ----- + +llvm.func @struct_one_element() -> !llvm.struct<(f64)> { + // expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}} + %0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64)> + llvm.return %0 : !llvm.struct<(f64)> +} + +// ----- + +llvm.func @struct_two_different_elements() -> !llvm.struct<(f64, f32)> { + // expected-error @+1 {{expected struct type with two elements of the same type, the type of a complex constant}} + %0 = llvm.mlir.constant([1.0 : f64, 1.0 : f64]) : !llvm.struct<(f64, f32)> + llvm.return %0 : !llvm.struct<(f64, f32)> +} + +// ----- + +llvm.func @struct_wrong_element_types() -> !llvm.struct<(!llvm.array<2 x f64>, !llvm.array<2 x f64>)> { + // expected-error @+1 {{expected struct element types to be floating point type or integer type}} + %0 = llvm.mlir.constant([dense<[1.0, 1.0]> : tensor<2xf64>, dense<[1.0, 1.0]> : tensor<2xf64>]) : !llvm.struct<(!llvm.array<2 x f64>, !llvm.array<2 x f64>)> + llvm.return %0 : !llvm.struct<(!llvm.array<2 x f64>, !llvm.array<2 x f64>)> +} + +// ----- + func @insertvalue_non_llvm_type(%a : i32, %b : i32) { // expected-error@+1 {{expected LLVM IR Dialect type}} llvm.insertvalue %a, %b[0] : tensor<*xi32> diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir --- a/mlir/test/Dialect/Standard/invalid.mlir +++ b/mlir/test/Dialect/Standard/invalid.mlir @@ -37,3 +37,43 @@ %0 = constant "" : index return } + +// ----- + +func @complex_constant_wrong_array_attribute_length() { + // expected-error @+1 {{requires 'value' to be a complex constant, represented as array of two
values}} + %0 = constant [1.0 : f32] : complex + return +} + +// ----- + +func @complex_constant_wrong_attribute_type() { + // expected-error @+1 {{requires attribute's type ('f32') to match op's return type ('complex')}} + %0 = "std.constant" () {value = 1.0 : f32} : () -> complex + return +} + +// ----- + +func @complex_constant_wrong_element_types() { + // expected-error @+1 {{requires attribute's element types ('f32', 'f32') to match the element type of the op's return type ('f64')}} + %0 = constant [1.0 : f32, -1.0 : f32] : complex + return +} + +// ----- + +func @complex_constant_two_different_element_types() { + // expected-error @+1 {{requires attribute's element types ('f32', 'f64') to match the element type of the op's return type ('f64')}} + %0 = constant [1.0 : f32, -1.0 : f64] : complex + return +} + +// ----- + +func @complex_constant_non_array_attribute() { + // expected-error @+1 {{requires 'value' to be a complex constant, represented as array of two values}} + %0 = "std.constant" () {value = "foo"} : () -> complex<f32> + return +} diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir --- a/mlir/test/Dialect/Standard/ops.mlir +++ b/mlir/test/Dialect/Standard/ops.mlir @@ -68,3 +68,15 @@ ^bb3(%bb3arg : i32): return } + +// CHECK-LABEL: func @constant_complex_f32( +func @constant_complex_f32() -> complex { + %result = constant [0.1 : f32, -1.0 : f32] : complex + return %result : complex +} + +// CHECK-LABEL: func @constant_complex_f64( +func @constant_complex_f64() -> complex { + %result = constant [0.1 : f64, -1.0 : f64] : complex + return %result : complex +} diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir --- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir @@ -35,13 +35,21 @@ // ----- llvm.func @no_nested_struct() -> !llvm.array<2 x array<2 x array<2 x
struct<(i32)>>>> { - // expected-error @+1 {{struct types are not supported in constants}} + // expected-error @+1 {{nested struct types are not supported in constants}} %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> } // ----- +llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> { + // expected-error @+1 {{FloatAttr does not match expected type of the constant}} + %0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)> + llvm.return %0 : !llvm.struct<(f64, f64)> +} + +// ----- + // expected-error @+1 {{unsupported constant value}} llvm.mlir.global internal constant @test([2.5, 7.4]) : !llvm.array<2 x f64> @@ -63,4 +71,6 @@ // ----- // expected-error @+1 {{expected arrays within 'passthrough' to contain two strings}} -llvm.func @passthrough_wrong_type() attributes {passthrough = [[42, 42]]} +llvm.func @passthrough_wrong_type() attributes { + passthrough = [[ 42, 42 ]] +} diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1016,6 +1016,18 @@ llvm.return %1 : !llvm.array<12 x i8> } +llvm.func @complexfpconstant() -> !llvm.struct<(f32, f32)> { + %1 = llvm.mlir.constant([-1.000000e+00 : f32, 0.000000e+00 : f32]) : !llvm.struct<(f32, f32)> + // CHECK: ret { float, float } { float -1.000000e+00, float 0.000000e+00 } + llvm.return %1 : !llvm.struct<(f32, f32)> +} + +llvm.func @complexintconstant() -> !llvm.struct<(i32, i32)> { + %1 = llvm.mlir.constant([-1 : i32, 0 : i32]) : !llvm.struct<(i32, i32)> + // CHECK: ret { i32, i32 } { i32 -1, i32 0 } + llvm.return %1 : !llvm.struct<(i32, i32)> +} + llvm.func @noreach() { // CHECK: unreachable llvm.unreachable