diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp --- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp @@ -32,10 +32,16 @@ if (!complexTy || arrAttr.size() != 2) return false; auto complexEltTy = complexTy.getElementType(); - auto re = arrAttr[0].dyn_cast(); - auto im = arrAttr[1].dyn_cast(); - return re && im && re.getType() == complexEltTy && - im.getType() == complexEltTy; + if (auto fre = arrAttr[0].dyn_cast()) { + auto im = arrAttr[1].dyn_cast(); + return im && fre.getType() == complexEltTy && + im.getType() == complexEltTy; + } + if (auto ire = arrAttr[0].dyn_cast()) { + auto im = arrAttr[1].dyn_cast(); + return im && ire.getType() == complexEltTy && + im.getType() == complexEltTy; + } } return false; } diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -891,9 +891,44 @@ ArrayRef values) { assert(hasSameElementsOrSplat(type, values)); + Type eltType = type.getElementType(); + + // Take care complex type case first. + if (auto complexType = eltType.dyn_cast()) { + if (complexType.getElementType().isIntOrIndex()) { + SmallVector> complexValues; + complexValues.reserve(values.size()); + for (Attribute attr : values) { + assert(attr.isa() && + "expected ArrayAttr for complex"); + auto arrayAttr = attr.cast(); + assert(arrayAttr.size() == 2 && "expected 2 element for complex"); + auto attr0 = arrayAttr[0]; + auto attr1 = arrayAttr[1]; + complexValues.push_back( + std::complex(attr0.cast().getValue(), + attr1.cast().getValue())); + } + return DenseElementsAttr::get(type, complexValues); + } + // Must be float. + SmallVector> complexValues; + complexValues.reserve(values.size()); + for (Attribute attr : values) { + assert(attr.isa() && "expected ArrayAttr for complex"); + auto arrayAttr = attr.cast(); + assert(arrayAttr.size() == 2 && "expected 2 element for complex"); + auto attr0 = arrayAttr[0]; + auto attr1 = arrayAttr[1]; + complexValues.push_back( + std::complex(attr0.cast().getValue(), + attr1.cast().getValue())); + } + return DenseElementsAttr::get(type, complexValues); + } + // If the element type is not based on int/float/index, assume it is a string // type. - Type eltType = type.getElementType(); if (!eltType.isIntOrIndexOrFloat()) { SmallVector stringValues; stringValues.reserve(values.size()); diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -207,6 +207,34 @@ // ----- +// CHECK-LABEL: func.func @extract_from_elements_complex_i() -> tensor<3xcomplex> { +// CHECK-NEXT: %cst = arith.constant dense<[(1,2), (3,2), (1,2)]> : tensor<3xcomplex> +// CHECK-NEXT: return %cst : tensor<3xcomplex> +func.func @extract_from_elements_complex_i() -> tensor<3xcomplex> { + %c1 = arith.constant dense<(1, 2)> : tensor> + %complex1 = tensor.extract %c1[] : tensor> + %c2 = arith.constant dense<(3, 2)> : tensor> + %complex2 = tensor.extract %c2[] : tensor> + %tensor = tensor.from_elements %complex1, %complex2, %complex1 : tensor<3xcomplex> + return %tensor : tensor<3xcomplex> +} + +// ----- + +// CHECK-LABEL: func.func @extract_from_elements_complex_f() -> tensor<3xcomplex> { +// CHECK-NEXT: %cst = arith.constant dense<[(1.200000e+00,2.300000e+00), (3.200000e+00,2.100000e+00), (1.200000e+00,2.300000e+00)]> : tensor<3xcomplex> +// CHECK-NEXT: return %cst : tensor<3xcomplex> +func.func @extract_from_elements_complex_f() -> tensor<3xcomplex> { + %c1 = arith.constant dense<(1.2, 2.3)> : tensor> + %complex1 = tensor.extract %c1[] : tensor> + %c2 = arith.constant dense<(3.2, 2.1)> : tensor> + %complex2 = tensor.extract %c2[] : tensor> + %tensor = tensor.from_elements %complex1, %complex2, %complex1 : tensor<3xcomplex> + return %tensor : tensor<3xcomplex> +} + +// ----- + // Ensure the optimization doesn't segfault from bad constants // CHECK-LABEL: func @extract_negative_from_tensor.from_elements func.func @extract_negative_from_tensor.from_elements(%element : index) -> index {