diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -101,6 +101,106 @@
   } while (true);
 }
 
+/// Convert a dense elements attribute to an LLVM IR constant using its raw data
+/// storage if possible. This supports elements attributes of tensor or vector
+/// type and avoids constructing separate objects for individual values of the
+/// innermost dimension. Constants for other dimensions are still constructed
+/// recursively. Returns null if constructing from raw data is not supported for
+/// this attribute, e.g., the element type is not a power-of-two-sized
+/// primitive, the attribute is empty or rank-0, or the raw storage width does
+/// not match the LLVM element type; the caller is expected to fall back to
+/// element-by-element construction in that case. Reports other errors at
+/// `loc`.
+static llvm::Constant *
+convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
+                         llvm::Type *llvmType,
+                         const ModuleTranslation &moduleTranslation) {
+  if (!denseElementsAttr)
+    return nullptr;
+
+  llvm::Type *innermostLLVMType = getInnermostElementType(llvmType);
+  if (!llvm::ConstantDataSequential::isElementTypeCompatible(innermostLLVMType))
+    return nullptr;
+
+  // Rank-0 attributes have no innermost dimension and empty attributes have no
+  // data to slice (and possibly a zero-sized innermost dimension, which would
+  // divide by zero below), so leave both to the generic path.
+  ShapedType type = denseElementsAttr.getType();
+  if (type.getRank() == 0 || type.getNumElements() == 0)
+    return nullptr;
+
+  // The innermost dimension may be that of the vector element type. Only 1-D
+  // vector element types map onto a single LLVM vector of scalars.
+  auto vectorElementType = type.getElementType().dyn_cast<VectorType>();
+  bool hasVectorElementType = static_cast<bool>(vectorElementType);
+  if (hasVectorElementType && vectorElementType.getRank() != 1)
+    return nullptr;
+  int64_t innermostSize = hasVectorElementType
+                              ? vectorElementType.getShape().back()
+                              : type.getShape().back();
+
+  // Compute the shape of all dimensions but the innermost, and the number of
+  // innermost constants ("aggregates") needed to cover all elements.
+  int64_t numAggregates =
+      type.getNumElements() / (hasVectorElementType ? 1 : innermostSize);
+  ArrayRef<int64_t> outerShape = type.getShape();
+  if (!hasVectorElementType)
+    outerShape = outerShape.drop_back();
+
+  // Handle the case of vector splat, LLVM has special support for it. Each
+  // innermost vector repeats the splat value `innermostSize` times.
+  if (denseElementsAttr.isSplat() &&
+      (type.isa<VectorType>() || hasVectorElementType)) {
+    llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
+        innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc,
+        moduleTranslation, /*isTopLevel=*/false);
+    if (!splatValue)
+      return nullptr;
+    llvm::Constant *splatVector =
+        llvm::ConstantDataVector::getSplat(innermostSize, splatValue);
+    SmallVector<llvm::Constant *> constants(numAggregates, splatVector);
+    ArrayRef<llvm::Constant *> constantsRef = constants;
+    return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
+  }
+  if (denseElementsAttr.isSplat())
+    return nullptr;
+
+  // Tensors of scalars become arrays; vectors and tensors of 1-D vectors
+  // become vectors. Other shaped types are not supported.
+  bool innermostIsVector = type.isa<VectorType>() || hasVectorElementType;
+  if (!innermostIsVector && !type.isa<TensorType>())
+    return nullptr;
+
+  // The raw buffer is sliced into innermost constants below, which is only
+  // valid if every stored scalar has exactly the size of the LLVM element
+  // type. Otherwise, the slices would reinterpret the wrong bytes or run past
+  // the end of the buffer, so defer to the generic path.
+  ArrayRef<char> rawData = denseElementsAttr.getRawData();
+  int64_t elementByteSize = innermostLLVMType->getScalarSizeInBits() / 8;
+  int64_t numScalars =
+      type.getNumElements() * (hasVectorElementType ? innermostSize : 1);
+  if (static_cast<int64_t>(rawData.size()) != numScalars * elementByteSize)
+    return nullptr;
+
+  // Create innermost constants from consecutive slices of the raw data and
+  // defer to the default constant creation mechanism for other dimensions.
+  size_t aggregateSize = innermostSize * elementByteSize;
+  SmallVector<llvm::Constant *> constants;
+  constants.reserve(numAggregates);
+  for (int64_t i = 0; i < numAggregates; ++i) {
+    StringRef data(rawData.data() + i * aggregateSize, aggregateSize);
+    constants.push_back(
+        innermostIsVector
+            ? llvm::ConstantDataVector::getRaw(data, innermostSize,
+                                               innermostLLVMType)
+            : llvm::ConstantDataArray::getRaw(data, innermostSize,
+                                              innermostLLVMType));
+  }
+
+  ArrayRef<llvm::Constant *> constantsRef = constants;
+  return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
+}
+
 /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
 /// This currently supports integer, floating point, splat and dense element
 /// attributes and combinations thereof. Also, an array attribute with two
@@ -178,6 +278,14 @@
     }
   }
 
+  // Try using raw elements data if possible.
+  if (llvm::Constant *result =
+          convertDenseElementsAttr(loc, attr.dyn_cast<DenseElementsAttr>(),
+                                   llvmType, moduleTranslation)) {
+    return result;
+  }
+
+  // Fall back to element-by-element construction otherwise.
   if (auto elementsAttr = attr.dyn_cast<ElementsAttr>()) {
     assert(elementsAttr.getType().hasStaticShape());
     assert(!elementsAttr.getType().getShape().empty() &&
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -50,6 +50,36 @@
   llvm.return %gepinit : !llvm.ptr<i32>
 }
 
+// CHECK{LITERAL}: @dense_float_vector = internal global <3 x float> <float 1.000000e+00, float 2.000000e+00, float 3.000000e+00>
+llvm.mlir.global internal @dense_float_vector(dense<[1.0, 2.0, 3.0]> : vector<3xf32>) : vector<3xf32>
+
+// CHECK{LITERAL}: @splat_float_vector = internal global <3 x float> <float 4.200000e+01, float 4.200000e+01, float 4.200000e+01>
+llvm.mlir.global internal @splat_float_vector(dense<42.0> : vector<3xf32>) : vector<3xf32>
+
+// CHECK{LITERAL}: @dense_double_vector = internal global <3 x double> <double 1.000000e+00, double 2.000000e+00, double 3.000000e+00>
+llvm.mlir.global internal @dense_double_vector(dense<[1.0, 2.0, 3.0]> : vector<3xf64>) : vector<3xf64>
+
+// CHECK{LITERAL}: @splat_double_vector = internal global <3 x double> <double 4.200000e+01, double 4.200000e+01, double 4.200000e+01>
+llvm.mlir.global internal @splat_double_vector(dense<42.0> : vector<3xf64>) : vector<3xf64>
+
+// CHECK{LITERAL}: @dense_i64_vector = internal global <3 x i64> <i64 1, i64 2, i64 3>
+llvm.mlir.global internal @dense_i64_vector(dense<[1, 2, 3]> : vector<3xi64>) : vector<3xi64>
+
+// CHECK{LITERAL}: @splat_i64_vector = internal global <3 x i64> <i64 42, i64 42, i64 42>
+llvm.mlir.global internal @splat_i64_vector(dense<42> : vector<3xi64>) : vector<3xi64>
+
+// CHECK{LITERAL}: @dense_float_vector_2d = internal global [2 x <2 x float>] [<2 x float> <float 1.000000e+00, float 2.000000e+00>, <2 x float> <float 3.000000e+00, float 4.000000e+00>]
+llvm.mlir.global internal @dense_float_vector_2d(dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>) : !llvm.array<2 x vector<2xf32>>
+
+// CHECK{LITERAL}: @splat_float_vector_2d = internal global [2 x <2 x float>] [<2 x float> <float 4.200000e+01, float 4.200000e+01>, <2 x float> <float 4.200000e+01, float 4.200000e+01>]
+llvm.mlir.global internal @splat_float_vector_2d(dense<42.0> : vector<2x2xf32>) : !llvm.array<2 x vector<2xf32>>
+
+// CHECK{LITERAL}: @dense_float_vector_3d = internal global [2 x [2 x <2 x float>]] [[2 x <2 x float>] [<2 x float> <float 1.000000e+00, float 2.000000e+00>, <2 x float> <float 3.000000e+00, float 4.000000e+00>], [2 x <2 x float>] [<2 x float> <float 5.000000e+00, float 6.000000e+00>, <2 x float> <float 7.000000e+00, float 8.000000e+00>]]
+llvm.mlir.global internal @dense_float_vector_3d(dense<[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]> : vector<2x2x2xf32>) : !llvm.array<2 x !llvm.array<2 x vector<2xf32>>>
+
+// CHECK{LITERAL}: @splat_float_vector_3d = internal global [2 x [2 x <2 x float>]] [[2 x <2 x float>] [<2 x float> <float 4.200000e+01, float 4.200000e+01>, <2 x float> <float 4.200000e+01, float 4.200000e+01>], [2 x <2 x float>] [<2 x float> <float 4.200000e+01, float 4.200000e+01>, <2 x float> <float 4.200000e+01, float 4.200000e+01>]]
+llvm.mlir.global internal @splat_float_vector_3d(dense<42.0> : vector<2x2x2xf32>) : !llvm.array<2 x !llvm.array<2 x vector<2xf32>>>
+
 //
 // Linkage attribute.
 //
@@ -67,7 +97,7 @@
 // CHECK: @common = common global i32 0
 llvm.mlir.global common @common(0 : i32) : i32
 // CHECK: @appending = appending global [3 x i32] [i32 1, i32 2, i32 3]
-llvm.mlir.global appending @appending(dense<[1,2,3]> : vector<3xi32>) : !llvm.array<3xi32>
+llvm.mlir.global appending @appending(dense<[1,2,3]> : tensor<3xi32>) : !llvm.array<3xi32>
 // CHECK: @extern_weak = extern_weak global i32
 llvm.mlir.global extern_weak @extern_weak() : i32
 // CHECK: @linkonce_odr = linkonce_odr global i32 42