diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
 #include "mlir/Support/LLVM.h"
 
 #include "llvm/ADT/SetVector.h"
@@ -30,6 +31,67 @@
 using namespace mlir;
 using namespace mlir::LLVM;
 
+/// Builds a constant of LLVM type `type` out of the elemental constants in
+/// `constants`, interpreted as the row-major flattening of a value of shape
+/// `shape`. Each dimension of `shape` corresponds to one level of nested LLVM
+/// array or vector type, outermost first. The consumed elements are dropped
+/// from the front of `constants`. Emits an error at `loc` and returns nullptr
+/// if the nesting depth or the sizes of `type` do not match `shape`.
+static llvm::Constant *
+buildSequentialConstant(ArrayRef<llvm::Constant *> &constants,
+                        ArrayRef<int64_t> shape, llvm::Type *type,
+                        Location loc) {
+  if (shape.empty()) {
+    // All dimensions have been unwrapped, so `type` must be the element type
+    // of the constants rather than yet another array or vector.
+    if (isa<llvm::SequentialType>(type)) {
+      emitError(loc) << "expected as many nested sequential LLVM types as "
+                        "the attribute has dimensions";
+      return nullptr;
+    }
+    llvm::Constant *result = constants.front();
+    constants = constants.drop_front();
+    return result;
+  }
+
+  auto *sequentialType = dyn_cast<llvm::SequentialType>(type);
+  if (!sequentialType) {
+    emitError(loc) << "expected sequential LLVM types wrapping a scalar";
+    return nullptr;
+  }
+  // The aggregate below is rebuilt with `shape.front()` elements, so a size
+  // mismatch would yield a constant whose type differs from `type`.
+  if (sequentialType->getNumElements() !=
+      static_cast<uint64_t>(shape.front())) {
+    emitError(loc) << "expected sequential LLVM type with " << shape.front()
+                   << " elements, got " << sequentialType->getNumElements();
+    return nullptr;
+  }
+
+  llvm::Type *elementType = sequentialType->getElementType();
+  SmallVector<llvm::Constant *, 8> nested;
+  nested.reserve(shape.front());
+  for (int64_t i = 0; i < shape.front(); ++i) {
+    nested.push_back(buildSequentialConstant(constants, shape.drop_front(),
+                                             elementType, loc));
+    if (!nested.back())
+      return nullptr;
+  }
+
+  if (shape.size() == 1 && type->isVectorTy())
+    return llvm::ConstantVector::get(nested);
+  return llvm::ConstantArray::get(
+      llvm::ArrayType::get(elementType, shape.front()), nested);
+}
+
+/// Strips all nested LLVM array and vector levels from `type` and returns the
+/// remaining element type, or `type` itself if it is neither.
+static llvm::Type *getInnermostElementType(llvm::Type *type) {
+  while (isa<llvm::SequentialType>(type))
+    type = type->getSequentialElementType();
+  return type;
+}
+
 /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
 /// This currently supports integer, floating point, splat and dense element
 /// attributes and combinations thereof.  In case of error, report it to `loc`
@@ -39,6 +101,10 @@
                                                    Location loc) {
   if (!attr)
     return llvm::UndefValue::get(llvmType);
+  if (llvmType->isStructTy()) {
+    emitError(loc, "struct types are not supported in constants");
+    return nullptr;
+  }
   if (auto intAttr = attr.dyn_cast<IntegerAttr>())
     return llvm::ConstantInt::get(llvmType, intAttr.getValue());
   if (auto floatAttr = attr.dyn_cast<FloatAttr>())
@@ -57,6 +123,8 @@
         isa<llvm::SequentialType>(elementType) ? splatAttr
                                                : splatAttr.getSplatValue(),
         loc);
+    if (!child)
+      return nullptr;
     if (llvmType->isVectorTy())
       return llvm::ConstantVector::getSplat(numElements, child);
     if (llvmType->isArrayTy()) {
@@ -65,24 +133,29 @@
       return llvm::ConstantArray::get(arrayType, constants);
     }
   }
+
   if (auto elementsAttr = attr.dyn_cast<ElementsAttr>()) {
-    auto *sequentialType = cast<llvm::SequentialType>(llvmType);
-    auto elementType = sequentialType->getElementType();
-    uint64_t numElements = sequentialType->getNumElements();
+    assert(elementsAttr.getType().hasStaticShape());
+    assert(elementsAttr.getNumElements() != 0 &&
+           "unexpected empty elements attribute");
+    assert(!elementsAttr.getType().getShape().empty() &&
+           "unexpected empty elements attribute shape");
+
     SmallVector<llvm::Constant *, 8> constants;
-    constants.reserve(numElements);
+    constants.reserve(elementsAttr.getNumElements());
+    llvm::Type *innermostType = getInnermostElementType(llvmType);
     for (auto n : elementsAttr.getValues<Attribute>()) {
-      constants.push_back(getLLVMConstant(elementType, n, loc));
+      constants.push_back(getLLVMConstant(innermostType, n, loc));
       if (!constants.back())
         return nullptr;
     }
-    if (llvmType->isVectorTy())
-      return llvm::ConstantVector::get(constants);
-    if (llvmType->isArrayTy()) {
-      auto arrayType = llvm::ArrayType::get(elementType, numElements);
-      return llvm::ConstantArray::get(arrayType, constants);
-    }
+    ArrayRef<llvm::Constant *> constantsRef = constants;
+    llvm::Constant *result = buildSequentialConstant(
+        constantsRef, elementsAttr.getType().getShape(), llvmType, loc);
+    assert(constantsRef.empty() && "did not consume all elemental constants");
+    return result;
   }
+
   if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
     return llvm::ConstantDataArray::get(
         llvmModule->getContext(), ArrayRef<char>{stringAttr.getValue().data(),
diff --git a/mlir/test/Target/llvmir-invalid.mlir b/mlir/test/Target/llvmir-invalid.mlir
--- a/mlir/test/Target/llvmir-invalid.mlir
+++ b/mlir/test/Target/llvmir-invalid.mlir
@@ -1,6 +1,14 @@
-// RUN: mlir-translate -verify-diagnostics -mlir-to-llvmir %s
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
 
 // expected-error @+1 {{unsupported module-level operation}}
 func @foo() {
   llvm.return
 }
+
+// -----
+
+llvm.func @no_nested_struct() -> !llvm<"[2 x [2 x [2 x {i32}]]]"> {
+  // expected-error @+1 {{struct types are not supported in constants}}
+  %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm<"[2 x [2 x [2 x {i32}]]]">
+  llvm.return %0 : !llvm<"[2 x [2 x [2 x {i32}]]]">
+}
diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir
--- a/mlir/test/Target/llvmir.mlir
+++ b/mlir/test/Target/llvmir.mlir
@@ -1067,3 +1067,22 @@
   // CHECK: ret i32* null
   llvm.return %0 : !llvm<"i32*">
 }
+
+// Check that dense elements attributes are exported properly in constants.
+// CHECK-LABEL: @elements_constant_3d_vector
+llvm.func @elements_constant_3d_vector() -> !llvm<"[2 x [2 x <2 x i32>]]"> {
+  // CHECK: ret [2 x [2 x <2 x i32>]]
+  // CHECK-SAME: {{\[}}[2 x <2 x i32>] [<2 x i32> <i32 1, i32 2>, <2 x i32> <i32 3, i32 4>],
+  // CHECK-SAME: [2 x <2 x i32>] [<2 x i32> <i32 42, i32 43>, <2 x i32> <i32 44, i32 45>]]
+  %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : vector<2x2x2xi32>) : !llvm<"[2 x [2 x <2 x i32>]]">
+  llvm.return %0 : !llvm<"[2 x [2 x <2 x i32>]]">
+}
+
+// CHECK-LABEL: @elements_constant_3d_array
+llvm.func @elements_constant_3d_array() -> !llvm<"[2 x [2 x [2 x i32]]]"> {
+  // CHECK: ret [2 x [2 x [2 x i32]]]
+  // CHECK-SAME: {{\[}}[2 x [2 x i32]] {{\[}}[2 x i32] [i32 1, i32 2], [2 x i32] [i32 3, i32 4]],
+  // CHECK-SAME: [2 x [2 x i32]] {{\[}}[2 x i32] [i32 42, i32 43], [2 x i32] [i32 44, i32 45]]]
+  %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm<"[2 x [2 x [2 x i32]]]">
+  llvm.return %0 : !llvm<"[2 x [2 x [2 x i32]]]">
+}