// Patch summary: renames ModuleImport::getStdTypeForAttr to
// getBuiltinTypeForAttr and rewrites ModuleImport::getConstantAsAttr on top of
// new static helpers (isScalarType, getVectorTypeForAttr,
// getScalarConstantAsAttr, getSequenceConstantAsAttrs). Splat constant vectors
// and zero aggregates are converted to SplatElementsAttr, and
// global-variables.ll gains tests for zero, bool and fp128 constants.
// NOTE(review): in this copy the patch's line breaks are collapsed and text
// enclosed in `<...>` that starts with a letter appears to have been stripped,
// e.g. template arguments (`type.isa()`, `SmallVector shape`, `dyn_cast(type)`)
// and IR literals (the initializers of @nested_array_vector,
// @quad_float_constant and @quad_float_splat_constant). As shown it neither
// applies nor compiles; regenerate it from the original change instead of
// hand-editing this text.
// NOTE(review): in the new getConstantAsAttr, the ConstantAggregateZero branch
// recomputes the shape instead of calling the getConstantShape lambda, and the
// work-list loop declares an inner `constAggregate` that shadows the outer one.
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -255,11 +255,14 @@ /// DictionaryAttr for the LLVM dialect. DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs, OpBuilder &builder); - /// Returns the builtin type equivalent to be used in attributes for the given - /// LLVM IR dialect type. - Type getStdTypeForAttr(Type type); - /// Returns `value` as an attribute to attach to a GlobalOp. - Attribute getConstantAsAttr(llvm::Constant *value); + /// Returns the builtin type equivalent to the given LLVM dialect type or + /// nullptr if there is no equivalent. The returned type can be used to create + /// an attribute for a GlobalOp or a ConstantOp. + Type getBuiltinTypeForAttr(Type type); + /// Returns `constant` as an attribute to attach to a GlobalOp or ConstantOp + /// or nullptr if it is not convertible. It supports functions (as symbol + /// references), integer and float scalars, shaped types thereof, and strings. + Attribute getConstantAsAttr(llvm::Constant *constant); /// Returns the topologically sorted set of transitive dependencies needed to /// convert the given constant. SetVector getConstantsToConvert(llvm::Constant *constant); diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -621,144 +621,193 @@ iface->setAttr(iface.getFastmathAttrName(), attr); } -// We only need integers, floats, doubles, and vectors and tensors thereof for -// attributes. Scalar and vector types are converted to the standard -// equivalents. Array types are converted to ranked tensors; nested array types -// are converted to multi-dimensional tensors or vectors, depending on the -// innermost type being a scalar or a vector.
-Type ModuleImport::getStdTypeForAttr(Type type) { - if (!type) - return nullptr; +/// Returns true if `type` is a scalar integer or floating-point type. +static bool isScalarType(Type type) { + return type.isa(); +} - if (type.isa()) - return type; +/// Returns the builtin vector type equivalent to the LLVM-compatible vector +/// `type` if its elements are integers or floats, or nullptr otherwise (an +/// error is emitted for scalable vectors). If provided, `arrayShape` is +/// prepended to the vector shape so the result matches an array of vectors. +static Type getVectorTypeForAttr(Type type, ArrayRef arrayShape = {}) { + if (!LLVM::isCompatibleVectorType(type)) + return {}; - // LLVM vectors can only contain scalars. - if (LLVM::isCompatibleVectorType(type)) { - llvm::ElementCount numElements = LLVM::getVectorNumElements(type); - if (numElements.isScalable()) { - emitError(UnknownLoc::get(context)) << "scalable vectors not supported"; - return nullptr; - } - Type elementType = getStdTypeForAttr(LLVM::getVectorElementType(type)); - if (!elementType) - return nullptr; - return VectorType::get(numElements.getKnownMinValue(), elementType); + llvm::ElementCount numElements = LLVM::getVectorNumElements(type); + if (numElements.isScalable()) { + emitError(UnknownLoc::get(type.getContext())) + << "scalable vectors not supported"; + return {}; } - // LLVM arrays can contain other arrays or vectors. - if (auto arrayType = type.dyn_cast()) { - // Recover the nested array shape. - SmallVector shape; - shape.push_back(arrayType.getNumElements()); - while (arrayType.getElementType().isa()) { - arrayType = arrayType.getElementType().cast(); - shape.push_back(arrayType.getNumElements()); - } + // An LLVM dialect vector can only contain scalars. + Type elementType = LLVM::getVectorElementType(type); + if (!isScalarType(elementType)) + return {}; - // If the innermost type is a vector, use the multi-dimensional vector as - // attribute type.
- if (LLVM::isCompatibleVectorType(arrayType.getElementType())) { - llvm::ElementCount numElements = - LLVM::getVectorNumElements(arrayType.getElementType()); - if (numElements.isScalable()) { - emitError(UnknownLoc::get(context)) << "scalable vectors not supported"; - return nullptr; - } - shape.push_back(numElements.getKnownMinValue()); + SmallVector shape(arrayShape.begin(), arrayShape.end()); + shape.push_back(numElements.getKnownMinValue()); + return VectorType::get(shape, elementType); +} - Type elementType = getStdTypeForAttr( - LLVM::getVectorElementType(arrayType.getElementType())); - if (!elementType) - return nullptr; - return VectorType::get(shape, elementType); - } +Type ModuleImport::getBuiltinTypeForAttr(Type type) { + if (!type) + return {}; - // Otherwise use a tensor. - Type elementType = getStdTypeForAttr(arrayType.getElementType()); - if (!elementType) - return nullptr; - return RankedTensorType::get(shape, elementType); - } + // Return builtin integer and floating-point types as is. + if (isScalarType(type)) + return type; + + // Return builtin vectors of integer and floating-point types as is. + if (Type vectorType = getVectorTypeForAttr(type)) + return vectorType; - return nullptr; + // Multi-dimensional array types are converted to tensors or vectors, + // depending on the innermost type being a scalar or a vector. + SmallVector arrayShape; + while (auto arrayType = dyn_cast(type)) { + arrayShape.push_back(arrayType.getNumElements()); + type = arrayType.getElementType(); + } + if (isScalarType(type)) + return RankedTensorType::get(arrayShape, type); + return getVectorTypeForAttr(type, arrayShape); } -// Get the given constant as an attribute. Not all constants can be represented -// as attributes. -Attribute ModuleImport::getConstantAsAttr(llvm::Constant *value) { - if (auto *ci = dyn_cast(value)) +/// Returns an integer or float attribute for the provided scalar constant +/// `constScalar` or nullptr if the conversion fails.
+static Attribute getScalarConstantAsAttr(OpBuilder &builder, + llvm::Constant *constScalar) { + MLIRContext *context = builder.getContext(); + + // Convert scalar integers. + if (auto *constInt = dyn_cast(constScalar)) { return builder.getIntegerAttr( - IntegerType::get(context, ci->getType()->getBitWidth()), - ci->getValue()); - if (auto *c = dyn_cast(value)) - if (c->isString()) - return builder.getStringAttr(c->getAsString()); - if (auto *c = dyn_cast(value)) { - llvm::Type *type = c->getType(); - FloatType floatTy; - if (type->isBFloatTy()) - floatTy = FloatType::getBF16(context); - else - floatTy = detail::getFloatType(context, type->getScalarSizeInBits()); - assert(floatTy && "unsupported floating point type"); - return builder.getFloatAttr(floatTy, c->getValueAPF()); + IntegerType::get(context, constInt->getType()->getBitWidth()), + constInt->getValue()); } - if (auto *f = dyn_cast(value)) - return SymbolRefAttr::get(builder.getContext(), f->getName()); - - // Convert constant data to a dense elements attribute. - if (auto *cd = dyn_cast(value)) { - Type type = convertType(cd->getElementType()); - auto attrType = getStdTypeForAttr(convertType(cd->getType())) - .dyn_cast_or_null(); - if (!attrType) - return nullptr; - - if (type.isa()) { - SmallVector values; - values.reserve(cd->getNumElements()); - for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i) - values.push_back(cd->getElementAsAPInt(i)); - return DenseElementsAttr::get(attrType, values); - } - if (type.isa()) { - SmallVector values; - values.reserve(cd->getNumElements()); - for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i) - values.push_back(cd->getElementAsAPFloat(i)); - return DenseElementsAttr::get(attrType, values); + // Convert scalar floats. + if (auto *constFloat = dyn_cast(constScalar)) { + llvm::Type *type = constFloat->getType(); + FloatType floatType = + type->isBFloatTy() + ?
FloatType::getBF16(context) + : LLVM::detail::getFloatType(context, type->getScalarSizeInBits()); + if (!floatType) { + emitError(UnknownLoc::get(builder.getContext())) + << "unexpected floating-point type"; + return {}; } + return builder.getFloatAttr(floatType, constFloat->getValueAPF()); + } + return {}; +} - return nullptr; +/// Returns an integer or float attribute for every element of the constant +/// sequence `constSequence`; an entry is nullptr if its conversion fails. +static SmallVector +getSequenceConstantAsAttrs(OpBuilder &builder, + llvm::ConstantDataSequential *constSequence) { + SmallVector elementAttrs; + elementAttrs.reserve(constSequence->getNumElements()); + for (auto idx : llvm::seq(0, constSequence->getNumElements())) { + llvm::Constant *constElement = constSequence->getElementAsConstant(idx); + elementAttrs.push_back(getScalarConstantAsAttr(builder, constElement)); } + return elementAttrs; +} + +Attribute ModuleImport::getConstantAsAttr(llvm::Constant *constant) { + // Convert scalar constants. + if (Attribute scalarAttr = getScalarConstantAsAttr(builder, constant)) + return scalarAttr; + + // Convert function references. + if (auto *func = dyn_cast(constant)) + return SymbolRefAttr::get(builder.getContext(), func->getName()); + + // Returns the static shape of the provided type if possible. + auto getConstantShape = [&](llvm::Type *type) { + return getBuiltinTypeForAttr(convertType(type)) + .dyn_cast_or_null(); + }; - // Unpack constant aggregates to create dense elements attribute whenever - // possible. Return nullptr (failure) otherwise.
- if (isa(value)) { - auto outerType = getStdTypeForAttr(convertType(value->getType())) - .dyn_cast_or_null(); - if (!outerType) - return nullptr; - - SmallVector values; - SmallVector shape; - - for (unsigned i = 0, e = value->getNumOperands(); i < e; ++i) { - auto nested = getConstantAsAttr(value->getAggregateElement(i)) - .dyn_cast_or_null(); - if (!nested) - return nullptr; - - values.append(nested.value_begin(), - nested.value_end()); + // Convert one-dimensional constant arrays or vectors that store 1/2/4/8-byte + // integer or half/bfloat/float/double values. + if (auto *constArray = dyn_cast(constant)) { + if (constArray->isString()) + return builder.getStringAttr(constArray->getAsString()); + auto shape = getConstantShape(constArray->getType()); + if (!shape) + return {}; + // Convert splat constants to splat elements attributes. + auto *constVector = dyn_cast(constant); + if (constVector && constVector->isSplat()) { + // A vector is guaranteed to have at least size one. + Attribute splatAttr = getScalarConstantAsAttr( + builder, constVector->getElementAsConstant(0)); + return SplatElementsAttr::get(shape, splatAttr); } + // Convert non-splat constants to dense elements attributes. + SmallVector elementAttrs = + getSequenceConstantAsAttrs(builder, constArray); + return DenseElementsAttr::get(shape, elementAttrs); + } - return DenseElementsAttr::get(outerType, values); + // Convert multi-dimensional constant aggregates that store all kinds of + // integer and floating-point types. + if (auto *constAggregate = dyn_cast(constant)) { + auto shape = getConstantShape(constAggregate->getType()); + if (!shape) + return {}; + // Collect the aggregate elements in depth-first order.
+ SmallVector elementAttrs; + SmallVector workList = {constAggregate}; + while (!workList.empty()) { + llvm::Constant *current = workList.pop_back_val(); + // Append any nested aggregates in reverse order to ensure the head + // element of the nested aggregates is at the back of the work list. + if (auto *constAggregate = dyn_cast(current)) { + for (auto idx : + reverse(llvm::seq(0, constAggregate->getNumOperands()))) + workList.push_back(constAggregate->getAggregateElement(idx)); + continue; + } + // Append the elements of nested constant arrays or vectors that store + // 1/2/4/8-byte integer or half/bfloat/float/double values. + if (auto *constArray = dyn_cast(current)) { + SmallVector attrs = + getSequenceConstantAsAttrs(builder, constArray); + elementAttrs.append(attrs.begin(), attrs.end()); + continue; + } + // Append nested scalar constants that store all kinds of integer and + // floating-point types. + if (Attribute scalarAttr = getScalarConstantAsAttr(builder, current)) { + elementAttrs.push_back(scalarAttr); + continue; + } + // Bail if the aggregate contains an unsupported constant type such as a + // constant expression. + return {}; + } + return DenseElementsAttr::get(shape, elementAttrs); } - return nullptr; + // Convert zero aggregates. + if (auto *constZero = dyn_cast(constant)) { + auto shape = getBuiltinTypeForAttr(convertType(constZero->getType())) + .dyn_cast_or_null(); + if (!shape) + return {}; + // Convert zero aggregates with a static shape to splat elements attributes.
+ Attribute splatAttr = builder.getZeroAttr(shape.getElementType()); + assert(splatAttr && "expected non-null zero attribute for scalar types"); + return SplatElementsAttr::get(shape, splatAttr); + } + return {}; } LogicalResult ModuleImport::convertGlobal(llvm::GlobalVariable *globalVar) { diff --git a/mlir/test/Target/LLVMIR/Import/global-variables.ll b/mlir/test/Target/LLVMIR/Import/global-variables.ll --- a/mlir/test/Target/LLVMIR/Import/global-variables.ll +++ b/mlir/test/Target/LLVMIR/Import/global-variables.ll @@ -168,20 +168,58 @@ @array_constant = internal constant [2 x float] [float 1., float 2.] ; CHECK: llvm.mlir.global internal constant @nested_array_constant -; CHECK-SAME: (dense<[{{\[}}1, 2], [3, 4]]> : tensor<2x2xi32>) -; CHECK-SAME: {addr_space = 0 : i32, dso_local} : !llvm.array<2 x array<2 x i32>> +; CHECK-SAME-LITERAL: (dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>) +; CHECK-SAME-LITERAL: {addr_space = 0 : i32, dso_local} : !llvm.array<2 x array<2 x i32>> @nested_array_constant = internal constant [2 x [2 x i32]] [[2 x i32] [i32 1, i32 2], [2 x i32] [i32 3, i32 4]] ; CHECK: llvm.mlir.global internal constant @nested_array_constant3 -; CHECK-SAME: (dense<[{{\[}}[1, 2], [3, 4]]]> : tensor<1x2x2xi32>) -; CHECK-SAME: {addr_space = 0 : i32, dso_local} : !llvm.array<1 x array<2 x array<2 x i32>>> +; CHECK-SAME-LITERAL: (dense<[[[1, 2], [3, 4]]]> : tensor<1x2x2xi32>) +; CHECK-SAME-LITERAL: {addr_space = 0 : i32, dso_local} : !llvm.array<1 x array<2 x array<2 x i32>>> @nested_array_constant3 = internal constant [1 x [2 x [2 x i32]]] [[2 x [2 x i32]] [[2 x i32] [i32 1, i32 2], [2 x i32] [i32 3, i32 4]]] ; CHECK: llvm.mlir.global internal constant @nested_array_vector -; CHECK-SAME: (dense<[{{\[}}[1, 2], [3, 4]]]> : vector<1x2x2xi32>) -; CHECK-SAME: {addr_space = 0 : i32, dso_local} : !llvm.array<1 x array<2 x vector<2xi32>>> +; CHECK-SAME-LITERAL: (dense<[[[1, 2], [3, 4]]]> : vector<1x2x2xi32>) +; CHECK-SAME-LITERAL: {addr_space = 0 : i32, dso_local} :
!llvm.array<1 x array<2 x vector<2xi32>>> @nested_array_vector = internal constant [1 x [2 x <2 x i32>]] [[2 x <2 x i32>] [<2 x i32> , <2 x i32> ]] +; CHECK: llvm.mlir.global internal constant @vector_constant_zero +; CHECK-SAME: (dense<0> : vector<2xi24>) +; CHECK-SAME: {addr_space = 0 : i32, dso_local} : vector<2xi24> +@vector_constant_zero = internal constant <2 x i24> zeroinitializer + +; CHECK: llvm.mlir.global internal constant @array_constant_zero +; CHECK-SAME: (dense<0.000000e+00> : tensor<2xbf16>) +; CHECK-SAME: {addr_space = 0 : i32, dso_local} : !llvm.array<2 x bf16> +@array_constant_zero = internal constant [2 x bfloat] zeroinitializer + +; CHECK: llvm.mlir.global internal constant @nested_array_constant3_zero +; CHECK-SAME: (dense<0> : tensor<1x2x2xi32>) +; CHECK-SAME: {addr_space = 0 : i32, dso_local} : !llvm.array<1 x array<2 x array<2 x i32>>> +@nested_array_constant3_zero = internal constant [1 x [2 x [2 x i32]]] zeroinitializer + +; CHECK: llvm.mlir.global internal constant @nested_array_vector_zero +; CHECK-SAME: (dense<0> : vector<1x2x2xi32>) +; CHECK-SAME: {addr_space = 0 : i32, dso_local} : !llvm.array<1 x array<2 x vector<2xi32>>> +@nested_array_vector_zero = internal constant [1 x [2 x <2 x i32>]] zeroinitializer + +; CHECK: llvm.mlir.global internal constant @nested_bool_array_constant +; CHECK-SAME-LITERAL: (dense<[[true, false]]> : tensor<1x2xi1>) +; CHECK-SAME-LITERAL: {addr_space = 0 : i32, dso_local} : !llvm.array<1 x array<2 x i1>> +@nested_bool_array_constant = internal constant [1 x [2 x i1]] [[2 x i1] [i1 1, i1 0]] + +; CHECK: llvm.mlir.global internal constant @quad_float_constant +; CHECK-SAME: dense<[ +; CHECK-SAME: 529.340000000000031832314562052488327 +; CHECK-SAME: 529.340000000001850821718107908964157 +; CHECK-SAME: ]> : vector<2xf128>) +; CHECK-SAME: {addr_space = 0 : i32, dso_local} : vector<2xf128> +@quad_float_constant = internal constant <2 x fp128> + +; CHECK: llvm.mlir.global internal constant
@quad_float_splat_constant +; CHECK-SAME: dense<529.340000000000031832314562052488327> : vector<2xf128>) +; CHECK-SAME: {addr_space = 0 : i32, dso_local} : vector<2xf128> +@quad_float_splat_constant = internal constant <2 x fp128> + ; // ----- ; CHECK: llvm.mlir.global_ctors {ctors = [@foo, @bar], priorities = [0 : i32, 42 : i32]}