diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -78,6 +78,9 @@ /// `target`. SmallVector processBranchArgs(llvm::BranchInst *br, llvm::BasicBlock *target); + /// Returns the standard type equivalent to be used in attributes for the + /// given LLVM IR dialect type. + Type getStdTypeForAttr(LLVMType type); /// Return `value` as an attribute to attach to a GlobalOp. Attribute getConstantAsAttr(llvm::Constant *value); /// Return `c` as an MLIR Value. This could either be a ConstantOp, or @@ -193,6 +196,65 @@ } } +// We only need integers, floats, doubles, and vectors and tensors thereof for +// attributes. Scalar and vector types are converted to the standard +// equivalents. Array types are converted to ranked tensors; nested array types +// are converted to multi-dimensional tensors or vectors, depending on the +// innermost type being a scalar or a vector. +Type Importer::getStdTypeForAttr(LLVMType type) { + if (!type) + return nullptr; + + if (type.isIntegerTy()) + return b.getIntegerType(type.getUnderlyingType()->getIntegerBitWidth()); + + if (type.getUnderlyingType()->isFloatTy()) + return b.getF32Type(); + + if (type.getUnderlyingType()->isDoubleTy()) + return b.getF64Type(); + + // LLVM vectors can only contain scalars. + if (type.isVectorTy()) { + auto numElements = type.getUnderlyingType()->getVectorElementCount(); + if (numElements.Scalable) + emitError(unknownLoc) << "scalable vectors not supported"; + return VectorType::get(numElements.Min, + getStdTypeForAttr(type.getVectorElementType())); + } + + // LLVM arrays can contain other arrays or vectors. + if (type.isArrayTy()) { + // Recover the nested array shape. + SmallVector shape; + shape.push_back(type.getArrayNumElements()); + while (type.getArrayElementType().isArrayTy()) { + type = type.getArrayElementType(); + shape.push_back(type.getArrayNumElements()); + } + + // If the innermost type is a vector, use the multi-dimensional vector as + // attribute type. + if (type.getArrayElementType().isVectorTy()) { + LLVMType vectorType = type.getArrayElementType(); + auto numElements = + vectorType.getUnderlyingType()->getVectorElementCount(); + if (numElements.Scalable) + emitError(unknownLoc) << "scalable vectors not supported"; + shape.push_back(numElements.Min); + + LLVMType elementType = vectorType.getVectorElementType(); + return VectorType::get(shape, getStdTypeForAttr(elementType)); + } + + // Otherwise use a tensor. + return RankedTensorType::get(shape, + getStdTypeForAttr(type.getArrayElementType())); + } + + llvm_unreachable("no equivalent standard type for typed attributes"); +} + // Get the given constant as an attribute. Not all constants can be represented // as attributes. Attribute Importer::getConstantAsAttr(llvm::Constant *value) { @@ -211,7 +273,59 @@ } if (auto *f = dyn_cast(value)) return b.getSymbolRefAttr(f->getName()); - return Attribute(); + + // Convert constant data to a dense elements attribute. + if (auto *cd = dyn_cast(value)) { + LLVMType type = processType(cd->getElementType()); + auto attrType = getStdTypeForAttr(processType(cd->getType())) + .dyn_cast_or_null(); + assert(attrType); + if (!attrType) + return nullptr; + + if (type.isIntegerTy()) { + SmallVector values; + values.reserve(cd->getNumElements()); + for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i) + values.push_back(cd->getElementAsAPInt(i)); + return DenseElementsAttr::get(attrType, values); + } + + if (type.isFloatTy() || type.isDoubleTy()) { + SmallVector values; + values.reserve(cd->getNumElements()); + for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i) + values.push_back(cd->getElementAsAPFloat(i)); + return DenseElementsAttr::get(attrType, values); + } + + return nullptr; + } + + // Unpack constant aggregates to create dense elements attribute whenever + // possible. Return nullptr (failure) otherwise. + if (auto *ca = dyn_cast(value)) { + auto outerType = getStdTypeForAttr(processType(value->getType())) + .dyn_cast_or_null(); + if (!outerType) + return nullptr; + + SmallVector values; + SmallVector shape; + + for (unsigned i = 0, e = value->getNumOperands(); i < e; ++i) { + auto nested = getConstantAsAttr(value->getAggregateElement(i)) + .dyn_cast_or_null(); + if (!nested) + return nullptr; + + values.append(nested.attr_value_begin(), nested.attr_value_end()); + } + + return DenseElementsAttr::get(outerType, values); + } + + return nullptr; } /// Converts LLVM global variable linkage type into the LLVM dialect predicate. diff --git a/mlir/test/Target/import.ll b/mlir/test/Target/import.ll --- a/mlir/test/Target/import.ll +++ b/mlir/test/Target/import.ll @@ -49,6 +49,21 @@ ; CHECK: llvm.mlir.global external @external() : !llvm.i32 @external = external global i32 +; +; Sequential constants. +; + +; CHECK: llvm.mlir.global internal constant @vector_constant(dense<[1, 2]> : vector<2xi32>) : !llvm<"<2 x i32>"> +@vector_constant = internal constant <2 x i32> +; CHECK: llvm.mlir.global internal constant @array_constant(dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>) : !llvm<"[2 x float]"> +@array_constant = internal constant [2 x float] [float 1., float 2.] +; CHECK: llvm.mlir.global internal constant @nested_array_constant(dense<[{{\[}}1, 2], [3, 4]]> : tensor<2x2xi32>) : !llvm<"[2 x [2 x i32]]"> +@nested_array_constant = internal constant [2 x [2 x i32]] [[2 x i32] [i32 1, i32 2], [2 x i32] [i32 3, i32 4]] +; CHECK: llvm.mlir.global internal constant @nested_array_constant3(dense<[{{\[}}[1, 2], [3, 4]]]> : tensor<1x2x2xi32>) : !llvm<"[1 x [2 x [2 x i32]]]"> +@nested_array_constant3 = internal constant [1 x [2 x [2 x i32]]] [[2 x [2 x i32]] [[2 x i32] [i32 1, i32 2], [2 x i32] [i32 3, i32 4]]] +; CHECK: llvm.mlir.global internal constant @nested_array_vector(dense<[{{\[}}[1, 2], [3, 4]]]> : vector<1x2x2xi32>) : !llvm<"[1 x [2 x <2 x i32>]]"> +@nested_array_vector = internal constant [1 x [2 x <2 x i32>]] [[2 x <2 x i32>] [<2 x i32> , <2 x i32> ]] + ; CHECK: llvm.func @fe(!llvm.i32) -> !llvm.float declare float @fe(i32)