diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1092,7 +1092,10 @@ let builders = [ OpBuilder<(ins "Attribute":$value), - [{ build($_builder, $_state, value.getType(), value); }]>]; + [{ build($_builder, $_state, value.getType(), value); }]>, + OpBuilder<(ins "Attribute":$value, "Type":$type), + [{ build($_builder, $_state, type, value); }]>, + ]; let extraClassDeclaration = [{ Attribute getValue() { return (*this)->getAttr("value"); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1940,7 +1940,23 @@ } return success(); } - if (!op.value().isa()) + if (auto complexTy = op.getType().dyn_cast()) { + auto arrayAttr = op.value().dyn_cast(); + if (!complexTy || arrayAttr.size() != 2) + return op.emitOpError() << "expected array with exactly 2 elements for " + "the complex constant"; + auto complexEltTy = complexTy.getElementType(); + if (complexEltTy != arrayAttr[0].getType() || + complexEltTy != arrayAttr[1].getType()) { + return op.emitOpError() + << "requires attribute's element types (" << arrayAttr[0].getType() + << ", " << arrayAttr[1].getType() + << ") to match the element type of the op's return type (" + << complexEltTy << ")"; + } + return success(); + } + if (!op.value().isa()) return op.emitOpError() << "only supports integer, float, string or elements attributes"; return success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -1339,7 +1339,8 @@ // Create a constant scalar value from the splat constant. Value scalarConstant = rewriter.create( - def->getLoc(), constantAttr.getSplatValue()); + def->getLoc(), constantAttr.getSplatValue(), + constantAttr.getType().getElementType()); LinalgOp fusedOp = createLinalgOpOfSameType( linalgOp, rewriter, rewriter.getUnknownLoc(), diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1067,8 +1067,8 @@ p << ' '; p << op.getValue(); - // If the value is a symbol reference, print a trailing type. - if (op.getValue().isa()) + // If the value is a symbol reference or Array, print a trailing type. + if (op.getValue().isa()) p << " : " << op.getType(); } @@ -1079,9 +1079,10 @@ parser.parseAttribute(valueAttr, "value", result.attributes)) return failure(); - // If the attribute is a symbol reference, then we expect a trailing type. + // If the attribute is a symbol reference or array, then we expect a trailing + // type. Type type; - if (!valueAttr.isa()) + if (!valueAttr.isa()) type = valueAttr.getType(); else if (parser.parseColonType(type)) return failure(); @@ -1119,6 +1120,24 @@ return success(); } + if (auto complexTy = type.dyn_cast()) { + auto arrayAttr = value.dyn_cast(); + if (!complexTy || arrayAttr.size() != 2) + return op.emitOpError( + "requires 'value' to be a complex constant, represented as array of " + "two values"); + auto complexEltTy = complexTy.getElementType(); + if (complexEltTy != arrayAttr[0].getType() || + complexEltTy != arrayAttr[1].getType()) { + return op.emitOpError() + << "requires attribute's element types (" << arrayAttr[0].getType() + << ", " << arrayAttr[1].getType() + << ") to match the element type of the op's return type (" + << complexEltTy << ")"; + } + return success(); + } + if (type.isa()) { if (!value.isa()) return op.emitOpError("requires 'value' to be a floating point constant"); @@ -1193,13 +1212,21 @@ if (value.isa()) return type.isa(); // The attribute must have the same type as 'type'. - if (value.getType() != type) + if (!value.getType().isa() && value.getType() != type) return false; // If the type is an integer type, it must be signless. if (IntegerType integerTy = type.dyn_cast()) if (!integerTy.isSignless()) return false; // Finally, check that the attribute kind is handled. + if (auto arrAttr = value.dyn_cast()) { + auto complexTy = type.dyn_cast(); + if (!complexTy) + return false; + auto complexEltTy = complexTy.getElementType(); + return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy && + arrAttr[1].getType() == complexEltTy; + } return value.isa(); } diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -578,6 +578,25 @@ FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); return FloatAttr::get(eltTy, *floatIt); } + if (auto complexTy = eltTy.dyn_cast()) { + auto complexEltTy = complexTy.getElementType(); + ComplexIntElementIterator complexIntIt(owner, index); + if (complexEltTy.isa()) { + auto value = *complexIntIt; + auto real = IntegerAttr::get(complexEltTy, value.real()); + auto imag = IntegerAttr::get(complexEltTy, value.imag()); + return ArrayAttr::get(complexTy.getContext(), + ArrayRef{real, imag}); + } + + ComplexFloatElementIterator complexFloatIt( + complexEltTy.cast().getFloatSemantics(), complexIntIt); + auto value = *complexFloatIt; + auto real = FloatAttr::get(complexEltTy, value.real()); + auto imag = FloatAttr::get(complexEltTy, value.imag()); + return ArrayAttr::get(complexTy.getContext(), + ArrayRef{real, imag}); + } if (owner.isa()) { ArrayRef vals = owner.getRawStringData(); return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -111,7 +111,19 @@ if (!attr) return llvm::UndefValue::get(llvmType); if (llvmType->isStructTy()) { - emitError(loc, "struct types are not supported in constants"); + if (auto arrayAttr = attr.dyn_cast()) { + if (arrayAttr.size() != 2) + return llvm::UndefValue::get(llvmType); + + llvm::StructType *structType = cast<::llvm::StructType>(llvmType); + llvm::Type *elementType = structType->getElementType(0); + llvm::Constant *real = + getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation); + llvm::Constant *imag = + getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation); + return llvm::ConstantStruct::get(structType, {real, imag}); + } + emitError(loc, "struct types are only supported for complex constants"); return nullptr; } // For integer types, we allow a mismatch in sizes as the index type in diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir --- a/mlir/test/Dialect/Standard/invalid.mlir +++ b/mlir/test/Dialect/Standard/invalid.mlir @@ -37,3 +37,35 @@ %0 = constant "" : index return } + +// ----- + +func @complex_constant_wrong_array_attribute_length() { + // expected-error @+1 {{requires 'value' to be a complex constant, represented as array of two values}} + %0 = constant [1.0 : f32] : complex + return +} + +// ----- + +func @complex_constant_wrong_attribute_type() { + // expected-error @+1 {{requires attribute's type ('f32') to match op's return type ('complex')}} + %0 = "std.constant" () {value = 1.0 : f32} : () -> complex + return +} + +// ----- + +func @complex_constant_wrong_element_types() { + // expected-error @+1 {{requires attribute's element types ('f32', 'f32') to match the element type of the op's return type ('f64')}} + %0 = constant [1.0 : f32, -1.0 : f32] : complex + return +} + +// ----- + +func @complex_constant_two_different_element_types() { + // expected-error @+1 {{requires attribute's element types ('f32', 'f64') to match the element type of the op's return type ('f64')}} + %0 = constant [1.0 : f32, -1.0 : f64] : complex + return +} diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir --- a/mlir/test/Dialect/Standard/ops.mlir +++ b/mlir/test/Dialect/Standard/ops.mlir @@ -68,3 +68,15 @@ ^bb3(%bb3arg : i32): return } + +// CHECK-LABEL: func @constant_complex_f32( +func @constant_complex_f32() -> complex { + %result = constant [0.1 : f32, -1.0 : f32] : complex + return %result : complex +} + +// CHECK-LABEL: func @constant_complex_f64( +func @constant_complex_f64() -> complex { + %result = constant [0.1 : f64, -1.0 : f64] : complex + return %result : complex +} diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir --- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir @@ -35,7 +35,7 @@ // ----- llvm.func @no_nested_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> { - // expected-error @+1 {{struct types are not supported in constants}} + // expected-error @+1 {{struct types are only supported for complex constants}} %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> } diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1007,6 +1007,18 @@ llvm.return %1 : !llvm.array<12 x i8> } +llvm.func @complexfpconstant() -> !llvm.struct<(f32, f32)> { + %1 = llvm.mlir.constant([-1.000000e+00 : f32, 0.000000e+00 : f32]) : !llvm.struct<(f32, f32)> + // CHECK: ret { float, float } { float -1.000000e+00, float 0.000000e+00 } + llvm.return %1 : !llvm.struct<(f32, f32)> +} + +llvm.func @complexintconstant() -> !llvm.struct<(i32, i32)> { + %1 = llvm.mlir.constant([-1 : i32, 0 : i32]) : !llvm.struct<(i32, i32)> + // CHECK: ret { i32, i32 } { i32 -1, i32 0 } + llvm.return %1 : !llvm.struct<(i32, i32)> +} + llvm.func @noreach() { // CHECK: unreachable llvm.unreachable