diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -225,4 +225,44 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// LLVMTargetExtType
+//===----------------------------------------------------------------------===//
+
+def LLVMTargetExtType : LLVMType<"LLVMTargetExt", "target"> {
+  let summary = "LLVM target-specific extension type";
+  let description = [{
+    LLVM dialect target extension type, which is generally unintrospectable
+    from target-independent optimizations.
+
+    Target extension types have a string name, and optionally have type and/or
+    integer parameters. The exact meaning of any parameters is dependent on the
+    target.
+  }];
+
+  let parameters = (ins StringRefParameter<>:$extTypeName,
+                        OptionalArrayRefParameter<"Type">:$typeParams,
+                        OptionalArrayRefParameter<"unsigned int">:$intParams);
+
+  let assemblyFormat = [{
+    `<` $extTypeName (`,` custom<ExtTypeParams>($typeParams, $intParams)^ )? `>`
+  }];
+
+  let extraClassDeclaration = [{
+    enum Property {
+      /// zeroinitializer is valid for this target extension type.
+      HasZeroInit = 1U << 0,
+      /// This type may be used as the value type of a global variable.
+      CanBeGlobal = 1U << 1,
+    };
+
+    /// Returns true if this target extension type has the given property.
+    bool hasProperty(Property prop) const;
+
+    /// Returns true if this target extension type can be used as the element
+    /// type of an 'alloca'.
+    bool supportsAlloca() const;
+  }];
+}
+
 #endif // LLVMTYPES_TD
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -266,8 +266,22 @@
 }
 
 LogicalResult AllocaOp::verify() {
-  return verifyOpaquePtr(getOperation(), llvm::cast<LLVMPointerType>(getType()),
-                         getElemType());
+  LLVMPointerType ptrType = llvm::cast<LLVMPointerType>(getType());
+  if (auto verifyElemTy =
+          verifyOpaquePtr(getOperation(), ptrType, getElemType());
+      failed(verifyElemTy)) {
+    return verifyElemTy;
+  }
+
+  Type elemTy =
+      (ptrType.isOpaque()) ? *getElemType() : ptrType.getElementType();
+  // Only certain target extension types can be used in 'alloca'.
+  if (auto targetExtType = dyn_cast<LLVMTargetExtType>(elemTy);
+      targetExtType && !targetExtType.supportsAlloca()) {
+    return emitOpError()
+           << "this target extension type cannot be used in alloca";
+  }
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1832,6 +1846,21 @@
                          "attribute");
   }
 
+  if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
+    if (!targetExtType.hasProperty(LLVMTargetExtType::CanBeGlobal)) {
+      return emitOpError()
+             << "this target extension type cannot be used in a global";
+    }
+    if (Attribute value = getValueOrNull()) {
+      // Only a single, zero integer attribute (=zeroinitializer) is allowed as
+      // a value for a global with TargetExtType.
+      if (!isa<IntegerAttr>(value) || !isZeroAttribute(value)) {
+        return emitOpError()
+               << "expected zero value for global with target extension type";
+      }
+    }
+  }
+
   if (getLinkage() == Linkage::Common) {
     if (Attribute value = getValueOrNull()) {
       if (!isZeroAttribute(value)) {
@@ -2288,6 +2317,19 @@
     }
     return success();
   }
+  if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
+    if (!targetExtType.hasProperty(LLVM::LLVMTargetExtType::HasZeroInit)) {
+      return emitOpError()
+             << "target extension type does not support zero-initializer";
+    }
+    // Only a single, zero integer attribute (=zeroinitializer) is allowed as
+    // the value of a constant with TargetExtType.
+    if (!isa<IntegerAttr>(getValue()) || !isZeroAttribute(getValue())) {
+      return emitOpError()
+             << "only zero-initializer allowed for target extension types";
+    }
+    return success();
+  }
   if (!llvm::isa<IntegerAttr, FloatAttr, StringAttr, ElementsAttr>(getValue()))
     return emitOpError()
            << "only supports integer, float, string or elements attributes";
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -45,6 +45,7 @@
           [&](Type) { return "vec"; })
       .Case<LLVMArrayType>([&](Type) { return "array"; })
       .Case<LLVMStructType>([&](Type) { return "struct"; })
+      .Case<LLVMTargetExtType>([&](Type) { return "target"; })
       .Default([](Type) -> StringRef {
         llvm_unreachable("unexpected 'llvm' type kind");
       });
@@ -119,7 +120,7 @@
 
   llvm::TypeSwitch<Type>(type)
       .Case<LLVMPointerType, LLVMArrayType, LLVMFixedVectorType,
-            LLVMScalableVectorType, LLVMFunctionType>(
+            LLVMScalableVectorType, LLVMFunctionType, LLVMTargetExtType>(
           [&](auto type) { type.print(printer); })
       .Case([&](LLVMStructType structType) {
         printStructType(printer, structType);
@@ -332,6 +333,7 @@
       .Case("vec", [&] { return parseVectorType(parser); })
       .Case("array", [&] { return LLVMArrayType::parse(parser); })
       .Case("struct", [&] { return parseStructType(parser); })
+      .Case("target", [&] { return LLVMTargetExtType::parse(parser); })
       .Default([&] {
         parser.emitError(keyLoc) << "unknown LLVM type: " << key;
         return Type();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -109,6 +109,60 @@
   }
 }
 
+//===----------------------------------------------------------------------===//
+// custom<ExtTypeParams>
+//===----------------------------------------------------------------------===//
+
+/// Parses the parameter list of a target extension type: a non-empty,
+/// comma-separated list of type parameters followed by integer parameters,
+/// e.g. `i32, f64, 0, 5`. Once the first integer has been parsed, no further
+/// types are accepted. Any malformed element makes the whole parse fail.
+static ParseResult
+parseExtTypeParams(AsmParser &p, SmallVectorImpl<Type> &typeParams,
+                   SmallVectorImpl<unsigned int> &intParams) {
+  bool parseType = true;
+  auto typeOrIntParser = [&]() -> ParseResult {
+    unsigned int i = 0;
+    OptionalParseResult intResult = p.parseOptionalInteger(i);
+    if (intResult.has_value()) {
+      // An integer token was present; propagate a failure (e.g. a value that
+      // does not fit into 'unsigned int') instead of using 'i'.
+      if (failed(*intResult))
+        return failure();
+      intParams.push_back(i);
+      // After the first integer was successfully parsed, no more types can be
+      // parsed.
+      parseType = false;
+      return success();
+    }
+    // Type parameters are only allowed before the first integer parameter.
+    if (!parseType)
+      return failure();
+    Type t;
+    if (parsePrettyLLVMType(p, t))
+      return failure();
+    typeParams.push_back(t);
+    return success();
+  };
+  // [type | integer] (`,` [type | integer])*
+  if (failed(p.parseCommaSeparatedList(typeOrIntParser)))
+    return p.emitError(p.getCurrentLocation(),
+                       "failed to parse parameter list for target extension "
+                       "type");
+  return success();
+}
+
+/// Prints the parameters of a target extension type: all type parameters
+/// first, followed by all integer parameters, separated by commas.
+static void printExtTypeParams(AsmPrinter &p, ArrayRef<Type> typeParams,
+                               ArrayRef<unsigned int> intParams) {
+  p << typeParams;
+  if (!typeParams.empty() && !intParams.empty())
+    p << ", ";
+
+  p << intParams;
+}
+
 //===----------------------------------------------------------------------===//
 // ODS-Generated Definitions
 //===----------------------------------------------------------------------===//
@@ -721,6 +775,32 @@
       emitError, elementType, numElements);
 }
 
+//===----------------------------------------------------------------------===//
+// LLVMTargetExtType.
+//===----------------------------------------------------------------------===//
+
+bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const {
+  // See llvm/lib/IR/Type.cpp for reference.
+  uint64_t properties = 0;
+
+  if (getExtTypeName().starts_with("spirv."))
+    properties |=
+        (LLVMTargetExtType::HasZeroInit | LLVM::LLVMTargetExtType::CanBeGlobal);
+
+  return (properties & prop) == prop;
+}
+
+bool LLVM::LLVMTargetExtType::supportsAlloca() const {
+  // See llvm/lib/IR/Type.cpp for reference.
+  if (getExtTypeName().starts_with("spirv."))
+    return true;
+
+  if (getExtTypeName() == "aarch64.svcount")
+    return true;
+
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // Utility functions.
 //===----------------------------------------------------------------------===//
@@ -746,6 +826,7 @@
       LLVMTokenType,
       LLVMFixedVectorType,
       LLVMScalableVectorType,
+      LLVMTargetExtType,
       LLVMVoidType,
       LLVMX86MMXType
     >(type)) {
@@ -791,6 +872,9 @@
               return true;
             return isCompatible(pointerType.getElementType());
           })
+          .Case<LLVMTargetExtType>([&](auto extType) {
+            return llvm::all_of(extType.getTypeParams(), isCompatible);
+          })
           // clang-format off
           .Case<
               LLVMFixedVectorType,
@@ -974,7 +1058,8 @@
       .Default([](Type ty) {
         assert((llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
                           LLVMTokenType, LLVMStructType, LLVMArrayType,
-                          LLVMPointerType, LLVMFunctionType>(ty)) &&
+                          LLVMPointerType, LLVMFunctionType, LLVMTargetExtType>(
+                   ty)) &&
                "unexpected missing support for primitive type");
         return llvm::TypeSize::Fixed(0);
       });
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1065,6 +1065,21 @@
     return root;
   }
 
+  if (auto *constTargetNone = dyn_cast<llvm::ConstantTargetNone>(constant)) {
+    LLVMTargetExtType targetExtType =
+        cast<LLVMTargetExtType>(convertType(constTargetNone->getType()));
+    assert(targetExtType.hasProperty(LLVMTargetExtType::HasZeroInit) &&
+           "target extension type does not support zero-initialization");
+    // As the number of values needed for initialization is target-specific and
+    // opaque to the compiler, use a single i64 zero-valued attribute to
+    // represent the 'zeroinitializer', which is the only constant value allowed
+    // for target extension types (besides poison and undef).
+    return builder
+        .create<ConstantOp>(loc, targetExtType,
+                            builder.getI64IntegerAttr(0))
+        .getRes();
+  }
+
   StringRef error = "";
   if (isa<llvm::BlockAddress>(constant))
     error = " since blockaddress(...) is unsupported";
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -342,6 +342,17 @@
       return nullptr;
     return llvm::ConstantStruct::get(structType, {real, imag});
   }
+  // Target extension types can only be zero-initialized; the MLIR side
+  // represents 'zeroinitializer' as a single zero-valued integer attribute.
+  if (auto *targetExtType = dyn_cast<::llvm::TargetExtType>(llvmType)) {
+    if (auto intAttr = dyn_cast<IntegerAttr>(attr);
+        !intAttr || intAttr.getInt() != 0) {
+      emitError(loc,
+                "Only zero-initialization allowed for target extension type");
+      return nullptr;
+    }
+    return llvm::ConstantTargetNone::get(targetExtType);
+  }
   // For integer types, we allow a mismatch in sizes as the index type in
   // MLIR might have a different size than the index type in the LLVM module.
   if (auto intAttr = dyn_cast<IntegerAttr>(attr))
diff --git a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
--- a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
@@ -36,7 +36,7 @@
         llvm::TypeSwitch<llvm::Type *, Type>(type)
             .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
                   llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
-                  llvm::ScalableVectorType>(
+                  llvm::ScalableVectorType, llvm::TargetExtType>(
                 [this](auto *type) { return this->translate(type); })
             .Default([this](llvm::Type *type) {
               return translatePrimitiveType(type);
@@ -135,6 +135,15 @@
         translateType(type->getElementType()), type->getMinNumElements());
   }
 
+  /// Translates the given target extension type.
+  Type translate(llvm::TargetExtType *type) {
+    SmallVector<Type> typeParams;
+    translateTypes(type->type_params(), typeParams);
+
+    return LLVM::LLVMTargetExtType::get(&context, type->getName(), typeParams,
+                                        type->int_params());
+  }
+
   /// Translates a list of types.
   void translateTypes(ArrayRef<llvm::Type *> types,
                       SmallVectorImpl<Type> &result) {
diff --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
--- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
@@ -73,7 +73,7 @@
             .Case<LLVM::LLVMArrayType, IntegerType, LLVM::LLVMFunctionType,
                   LLVM::LLVMPointerType, LLVM::LLVMStructType,
                   LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType,
-                  VectorType>(
+                  VectorType, LLVM::LLVMTargetExtType>(
                 [this](auto type) { return this->translate(type); })
             .Default([](Type t) -> llvm::Type * {
               llvm_unreachable("unknown LLVM dialect type");
@@ -155,6 +155,14 @@
                                          type.getMinNumElements());
   }
 
+  /// Translates the given target extension type.
+  llvm::Type *translate(LLVM::LLVMTargetExtType type) {
+    SmallVector<llvm::Type *> typeParams;
+    translateTypes(type.getTypeParams(), typeParams);
+    return llvm::TargetExtType::get(context, type.getExtTypeName(), typeParams,
+                                    type.getIntParams());
+  }
+
   /// Translates a list of types.
   void translateTypes(ArrayRef<Type> types,
                       SmallVectorImpl<llvm::Type *> &result) {
diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir
--- a/mlir/test/Dialect/LLVMIR/global.mlir
+++ b/mlir/test/Dialect/LLVMIR/global.mlir
@@ -232,3 +232,16 @@
 
 // CHECK: llvm.mlir.global_dtors {dtors = [@dtor], priorities = [0 : i32]}
 llvm.mlir.global_dtors { dtors = [@dtor], priorities = [0 : i32]}
+
+// -----
+
+// CHECK: llvm.mlir.global external @target_ext() {addr_space = 0 : i32} : !llvm.target<"spirv.Image", i32, 0>
+llvm.mlir.global @target_ext() : !llvm.target<"spirv.Image", i32, 0>
+
+// CHECK: llvm.mlir.global external @target_ext_init(0 : i64) {addr_space = 0 : i32} : !llvm.target<"spirv.Image", i32, 0>
+llvm.mlir.global @target_ext_init(0 : i64) : !llvm.target<"spirv.Image", i32, 0>
+
+// -----
+
+// expected-error @+1 {{expected zero value for global with target extension type}}
+llvm.mlir.global @target_fail(1 : i64) : !llvm.target<"spirv.Image", i32, 0>
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1387,3 +1387,30 @@
   // expected-error@+1 {{cannot cast pointers of different address spaces, use 'llvm.addrspacecast' instead}}
   %0 = llvm.bitcast %arg : !llvm.vec<4 x ptr<1>> to !llvm.vec<4 x ptr>
 }
+
+// -----
+
+func.func @invalid_target_ext_alloca() {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  // expected-error@+1 {{this target extension type cannot be used in alloca}}
+  %1 = llvm.alloca %0 x !llvm.target<"no_alloca"> : (i64) -> !llvm.ptr
+}
+
+// -----
+
+func.func @invalid_target_ext_constant() {
+  // expected-error@+1 {{target extension type does not support zero-initializer}}
+  %0 = llvm.mlir.constant(0 : index) : !llvm.target<"invalid_constant">
+}
+
+// -----
+
+func.func @invalid_target_ext_constant_nonzero() {
+  // expected-error@+1 {{only zero-initializer allowed for target extension types}}
+  %0 = llvm.mlir.constant(42 : index) : !llvm.target<"spirv.Event">
+}
+
+// -----
+
+// expected-error@+1 {{failed to parse parameter list for target extension type}}
+llvm.func @invalid_target_ext_params() -> !llvm.target<"target", 1, i32>
diff --git a/mlir/test/Dialect/LLVMIR/types.mlir b/mlir/test/Dialect/LLVMIR/types.mlir
--- a/mlir/test/Dialect/LLVMIR/types.mlir
+++ b/mlir/test/Dialect/LLVMIR/types.mlir
@@ -176,3 +176,20 @@
   "some.op"() : () -> !llvm.struct<(i32, f32, !qux)>
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: ext_target
+llvm.func @ext_target() {
+  // CHECK: !llvm.target<"target1", i32, 1>
+  %0 = "some.op"() : () -> !llvm.target<"target1", i32, 1>
+  // CHECK: !llvm.target<"target2">
+  %1 = "some.op"() : () -> !llvm.target<"target2">
+  // CHECK: !llvm.target<"target3", i32, i64, f64>
+  %2 = "some.op"() : () -> !llvm.target<"target3", i32, i64, f64>
+  // CHECK: !llvm.target<"target4", 1, 0, 42>
+  %3 = "some.op"() : () -> !llvm.target<"target4", 1, 0, 42>
+  // CHECK: !llvm.target<"target5", i32, f64, 0, 5>
+  %4 = "some.op"() : () -> !llvm.target<"target5", i32, f64, 0, 5>
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/Import/target-ext-type.ll b/mlir/test/Target/LLVMIR/Import/target-ext-type.ll
new file mode 100644
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/target-ext-type.ll
@@ -0,0 +1,24 @@
+; RUN: mlir-translate -import-llvm %s | FileCheck %s
+
+; CHECK: llvm.func spir_kernelcc @foo(%arg0: !llvm.target<"spirv.Pipe", 0>
+; CHECK-SAME: %arg1: !llvm.target<"spirv.Pipe", 1>
+; CHECK-SAME: %arg2: !llvm.target<"spirv.Image", !llvm.void, 0, 0, 0, 0, 0, 0, 0>
+; CHECK-SAME: %arg3: !llvm.target<"spirv.Image", i32, 1, 0, 0, 0, 0, 0, 0>
+; CHECK-SAME: %arg4: !llvm.target<"spirv.Image", i32, 2, 0, 0, 0, 0, 0, 0>
+; CHECK-SAME: %arg5: !llvm.target<"spirv.Image", f16, 1, 0, 1, 0, 0, 0, 0>
+; CHECK-SAME: %arg6: !llvm.target<"spirv.Image", f32, 5, 0, 0, 0, 0, 0, 0>
+; CHECK-SAME: %arg7: !llvm.target<"spirv.Image", !llvm.void, 0, 0, 0, 0, 0, 0, 1>
+; CHECK-SAME: %arg8: !llvm.target<"spirv.Image", !llvm.void, 1, 0, 0, 0, 0, 0, 2>)
+define spir_kernel void @foo(
+  target("spirv.Pipe", 0) %a,
+  target("spirv.Pipe", 1) %b,
+  target("spirv.Image", void, 0, 0, 0, 0, 0, 0, 0) %c1,
+  target("spirv.Image", i32, 1, 0, 0, 0, 0, 0, 0) %d1,
+  target("spirv.Image", i32, 2, 0, 0, 0, 0, 0, 0) %e1,
+  target("spirv.Image", half, 1, 0, 1, 0, 0, 0, 0) %f1,
+  target("spirv.Image", float, 5, 0, 0, 0, 0, 0, 0) %g1,
+  target("spirv.Image", void, 0, 0, 0, 0, 0, 0, 1) %c2,
+  target("spirv.Image", void, 1, 0, 0, 0, 0, 0, 2) %d3) {
+entry:
+  ret void
+}
diff --git a/mlir/test/Target/LLVMIR/llvmir-types.mlir b/mlir/test/Target/LLVMIR/llvmir-types.mlir
--- a/mlir/test/Target/LLVMIR/llvmir-types.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-types.mlir
@@ -141,6 +141,18 @@
 // CHECK: declare <{ { i32 } }> @return_sp_s_i32()
 llvm.func @return_sp_s_i32() -> !llvm.struct<packed (struct<(i32)>)>
 
+// CHECK: declare target("target-no-param") @return_target_ext_no_param()
+llvm.func @return_target_ext_no_param() -> !llvm.target<"target-no-param">
+
+// CHECK: declare target("target-type-param", i32, double) @return_target_ext_type_params()
+llvm.func @return_target_ext_type_params() -> !llvm.target<"target-type-param", i32, f64>
+
+// CHECK: declare target("target-int-param", 0, 42) @return_target_ext_int_params()
+llvm.func @return_target_ext_int_params() -> !llvm.target<"target-int-param", 0, 42>
+
+// CHECK: declare target("target-params", i32, double, 0, 5) @return_target_ext_params()
+llvm.func @return_target_ext_params() -> !llvm.target<"target-params", i32, f64, 0, 5>
+
 // -----
 // Put structs into a separate split so that we can match their declarations
 // locally.