diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -84,10 +84,23 @@ And<[LLVM_AnyStruct.predicate, CPred<"::llvm::cast<::mlir::LLVM::LLVMStructType>($_self).isOpaque()">]>>; +// Type constraint accepting any LLVM target extension type. +def LLVM_AnyTargetExt : Type($_self)">, + "LLVM target extension type">; + +// Type constraint accepting LLVM target extension types with no support for +// memory operations such as alloca, load and store. +def LLVM_NonLoadableTargetExtType : Type< + And<[LLVM_AnyTargetExt.predicate, + CPred<"!::llvm::cast<::mlir::LLVM::LLVMTargetExtType>($_self).supportsMemOps()">] + >>; + // Type constraint accepting any LLVM type that can be loaded or stored, i.e. a -// type that has size (not void, function or opaque struct type). +// type that has size (not void, function, opaque struct type or target +// extension type which does not support memory operations). def LLVM_LoadableType : Type< - Or<[And<[LLVM_PrimitiveType.predicate, Neg]>, + Or<[And<[LLVM_PrimitiveType.predicate, Neg, + Neg]>, LLVM_PointerElementTypeInterface.predicate]>, "LLVM type with size">; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td @@ -225,4 +225,40 @@ }]; } +//===----------------------------------------------------------------------===// +// LLVMTargetExtType +//===----------------------------------------------------------------------===// + +def LLVMTargetExtType : LLVMType<"LLVMTargetExt", "target"> { + let summary = "LLVM target-specific extension type"; + let description = [{ + LLVM dialect target extension type, which are generally unintrospectable + from target-independent optimizations.
+ + Target extension types have a string name, and optionally have type and/or + integer parameters. The exact meaning of any parameters is dependent on the + target. + }]; + + let parameters = (ins StringRefParameter<>:$extTypeName, + OptionalArrayRefParameter<"Type">:$typeParams, + OptionalArrayRefParameter<"unsigned int">:$intParams); + + let assemblyFormat = [{ + `<` $extTypeName (`,` custom($typeParams, $intParams)^ )? `>` + }]; + + let extraClassDeclaration = [{ + enum Property { + /// zeroinitializer is valid for this target extension type. + HasZeroInit = 1U << 0, + /// This type may be used as the value type of a global variable. + CanBeGlobal = 1U << 1, + }; + + bool hasProperty(Property Prop) const; + bool supportsMemOps() const; + }]; +} + #endif // LLVMTYPES_TD diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -266,8 +266,19 @@ } LogicalResult AllocaOp::verify() { - return verifyOpaquePtr(getOperation(), llvm::cast(getType()), - getElemType()); + LLVMPointerType ptrType = llvm::cast(getType()); + if (failed(verifyOpaquePtr(getOperation(), ptrType, getElemType()))) + return failure(); + + Type elemTy = + (ptrType.isOpaque()) ? *getElemType() : ptrType.getElementType(); + // Only certain target extension types can be used in 'alloca'.
+ if (auto targetExtType = dyn_cast(elemTy); + targetExtType && !targetExtType.supportsMemOps()) + return emitOpError() + << "this target extension type cannot be used in alloca"; + + return success(); } //===----------------------------------------------------------------------===// @@ -1832,6 +1843,22 @@ "attribute"); } + if (auto targetExtType = dyn_cast(getType())) { + if (!targetExtType.hasProperty(LLVMTargetExtType::CanBeGlobal)) + return emitOpError() + << "this target extension type cannot be used in a global"; + + if (Attribute value = getValueOrNull()) { + // Only a single, zero integer attribute (=zeroinitializer) is allowed for + // a global value with TargetExtType. + // TODO: Replace with 'zeroinitializer' once there is a dedicated + // zeroinitializer operation in the LLVM dialect. + if (!isa(value) || !isZeroAttribute(value)) + return emitOpError() + << "expected zero value for global with target extension type"; + } + } + if (getLinkage() == Linkage::Common) { if (Attribute value = getValueOrNull()) { if (!isZeroAttribute(value)) { @@ -2288,6 +2315,18 @@ } return success(); } + if (auto targetExtType = dyn_cast(getType())) { + if (!targetExtType.hasProperty(LLVM::LLVMTargetExtType::HasZeroInit)) + return emitOpError() + << "target extension type does not support zero-initializer"; + // Only a single, zero integer attribute (=zeroinitializer) is allowed for a + // global value with TargetExtType.
+ if (!isa(getValue()) || !isZeroAttribute(getValue())) + return emitOpError() + << "only zero-initializer allowed for target extension types"; + + return success(); + } if (!llvm::isa(getValue())) return emitOpError() << "only supports integer, float, string or elements attributes"; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -45,6 +45,7 @@ [&](Type) { return "vec"; }) .Case([&](Type) { return "array"; }) .Case([&](Type) { return "struct"; }) + .Case([&](Type) { return "target"; }) .Default([](Type) -> StringRef { llvm_unreachable("unexpected 'llvm' type kind"); }); @@ -119,7 +120,7 @@ llvm::TypeSwitch(type) .Case( + LLVMScalableVectorType, LLVMFunctionType, LLVMTargetExtType>( [&](auto type) { type.print(printer); }) .Case([&](LLVMStructType structType) { printStructType(printer, structType); @@ -332,6 +333,7 @@ .Case("vec", [&] { return parseVectorType(parser); }) .Case("array", [&] { return LLVMArrayType::parse(parser); }) .Case("struct", [&] { return parseStructType(parser); }) + .Case("target", [&] { return LLVMTargetExtType::parse(parser); }) .Default([&] { parser.emitError(keyLoc) << "unknown LLVM type: " << key; return Type(); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -109,6 +109,59 @@ } } +//===----------------------------------------------------------------------===// +// custom +//===----------------------------------------------------------------------===// + +/// Parses the parameter list for a target extension type. The parameter list +/// contains an optional list of type parameters, followed by an optional list +/// of integer parameters. Type and integer parameters cannot be interleaved in +/// the list.
+/// extTypeParams ::= typeList? | intList? | (typeList "," intList) +/// typeList ::= type ("," type)* +/// intList ::= integer ("," integer)* +static ParseResult +parseExtTypeParams(AsmParser &p, SmallVectorImpl &typeParams, + SmallVectorImpl &intParams) { + bool parseType = true; + auto typeOrIntParser = [&]() -> ParseResult { + unsigned int i; + auto intResult = p.parseOptionalInteger(i); + if (intResult.has_value() && !failed(*intResult)) { + // Successfully parsed an integer. + intParams.push_back(i); + // After the first integer was successfully parsed, no + // more types can be parsed. + parseType = false; + return success(); + } + if (parseType) { + Type t; + if (!parsePrettyLLVMType(p, t)) { + // Successfully parsed a type. + typeParams.push_back(t); + return success(); + } + } + return failure(); + }; + if (p.parseCommaSeparatedList(typeOrIntParser)) { + p.emitError(p.getCurrentLocation(), + "failed to parse parameter list for target extension type"); + return failure(); + } + return success(); +} + +static void printExtTypeParams(AsmPrinter &p, ArrayRef typeParams, + ArrayRef intParams) { + p << typeParams; + if (!typeParams.empty() && !intParams.empty()) + p << ", "; + + p << intParams; +} + //===----------------------------------------------------------------------===// // ODS-Generated Definitions //===----------------------------------------------------------------------===// @@ -721,6 +774,35 @@ emitError, elementType, numElements); } +//===----------------------------------------------------------------------===// +// LLVMTargetExtType. +//===----------------------------------------------------------------------===// + +static constexpr llvm::StringRef kSpirvPrefix = "spirv."; +static constexpr llvm::StringRef kArmSVCount = "aarch64.svcount"; + +bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const { + // See llvm/lib/IR/Type.cpp for reference.
+ uint64_t properties = 0; + + if (getExtTypeName().starts_with(kSpirvPrefix)) + properties |= + (LLVMTargetExtType::HasZeroInit | LLVM::LLVMTargetExtType::CanBeGlobal); + + return (properties & prop) == prop; +} + +bool LLVM::LLVMTargetExtType::supportsMemOps() const { + // See llvm/lib/IR/Type.cpp for reference. + if (getExtTypeName().starts_with(kSpirvPrefix)) + return true; + + if (getExtTypeName() == kArmSVCount) + return true; + + return false; +} + //===----------------------------------------------------------------------===// // Utility functions. //===----------------------------------------------------------------------===// @@ -746,6 +828,7 @@ LLVMTokenType, LLVMFixedVectorType, LLVMScalableVectorType, + LLVMTargetExtType, LLVMVoidType, LLVMX86MMXType >(type)) { @@ -791,6 +874,9 @@ return true; return isCompatible(pointerType.getElementType()); }) + .Case([&](auto extType) { + return llvm::all_of(extType.getTypeParams(), isCompatible); + }) // clang-format off .Case< LLVMFixedVectorType, @@ -974,7 +1060,8 @@ .Default([](Type ty) { assert((llvm::isa(ty)) && + LLVMPointerType, LLVMFunctionType, LLVMTargetExtType>( + ty)) && "unexpected missing support for primitive type"); return llvm::TypeSize::Fixed(0); }); diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -1065,6 +1065,23 @@ return root; } + if (auto *constTargetNone = dyn_cast(constant)) { + LLVMTargetExtType targetExtType = + cast(convertType(constTargetNone->getType())); + assert(targetExtType.hasProperty(LLVMTargetExtType::HasZeroInit) && + "target extension type does not support zero-initialization"); + // As the number of values needed for initialization is target-specific and + // opaque to the compiler, use a single i64 zero-valued attribute to + // represent the 'zeroinitializer', which is the only constant value allowed + // for target extension types
(besides poison and undef). + // TODO: Replace with 'zeroinitializer' once there is a dedicated + // zeroinitializer operation in the LLVM dialect. + return builder + .create(loc, targetExtType, + builder.getI64IntegerAttr(0)) + .getRes(); + } + StringRef error = ""; if (isa(constant)) error = " since blockaddress(...) is unsupported"; diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -342,6 +342,19 @@ return nullptr; return llvm::ConstantStruct::get(structType, {real, imag}); } + if (auto *targetExtType = dyn_cast<::llvm::TargetExtType>(llvmType)) { + // TODO: Replace with 'zeroinitializer' once there is a dedicated + // zeroinitializer operation in the LLVM dialect. + auto intAttr = dyn_cast(attr); + if (!intAttr || intAttr.getInt() != 0) { + emitError(loc, + "Only zero-initialization allowed for target extension type"); + // Bail out like the other error paths here so the caller sees failure. + return nullptr; + } + + return llvm::ConstantTargetNone::get(targetExtType); + } // For integer types, we allow a mismatch in sizes as the index type in // MLIR might have a different size than the index type in the LLVM module. if (auto intAttr = dyn_cast(attr)) diff --git a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp --- a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp +++ b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp @@ -36,7 +36,7 @@ llvm::TypeSwitch(type) .Case( + llvm::ScalableVectorType, llvm::TargetExtType>( [this](auto *type) { return this->translate(type); }) .Default([this](llvm::Type *type) { return translatePrimitiveType(type); @@ -135,6 +135,15 @@ translateType(type->getElementType()), type->getMinNumElements()); } + /// Translates the given target extension type.
+ Type translate(llvm::TargetExtType *type) { + SmallVector typeParams; + translateTypes(type->type_params(), typeParams); + + return LLVM::LLVMTargetExtType::get(&context, type->getName(), typeParams, + type->int_params()); + } + /// Translates a list of types. void translateTypes(ArrayRef types, SmallVectorImpl &result) { diff --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp --- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp +++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp @@ -73,7 +73,7 @@ .Case( + VectorType, LLVM::LLVMTargetExtType>( [this](auto type) { return this->translate(type); }) .Default([](Type t) -> llvm::Type * { llvm_unreachable("unknown LLVM dialect type"); @@ -155,6 +155,14 @@ type.getMinNumElements()); } + /// Translates the given target extension type. + llvm::Type *translate(LLVM::LLVMTargetExtType type) { + SmallVector typeParams; + translateTypes(type.getTypeParams(), typeParams); + return llvm::TargetExtType::get(context, type.getExtTypeName(), typeParams, + type.getIntParams()); + } + /// Translates a list of types.
void translateTypes(ArrayRef types, SmallVectorImpl &result) { diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir --- a/mlir/test/Dialect/LLVMIR/global.mlir +++ b/mlir/test/Dialect/LLVMIR/global.mlir @@ -232,3 +232,16 @@ // CHECK: llvm.mlir.global_dtors {dtors = [@dtor], priorities = [0 : i32]} llvm.mlir.global_dtors { dtors = [@dtor], priorities = [0 : i32]} + +// ----- + +// CHECK: llvm.mlir.global external @target_ext() {addr_space = 0 : i32} : !llvm.target<"spirv.Image", i32, 0> +llvm.mlir.global @target_ext() : !llvm.target<"spirv.Image", i32, 0> + +// CHECK: llvm.mlir.global external @target_ext_init(0 : i64) {addr_space = 0 : i32} : !llvm.target<"spirv.Image", i32, 0> +llvm.mlir.global @target_ext_init(0 : i64) : !llvm.target<"spirv.Image", i32, 0> + +// ----- + +// expected-error @+1 {{expected zero value for global with target extension type}} +llvm.mlir.global @target_fail(1 : i64) : !llvm.target<"spirv.Image", i32, 0> diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1387,3 +1387,39 @@ // expected-error@+1 {{cannot cast pointers of different address spaces, use 'llvm.addrspacecast' instead}} %0 = llvm.bitcast %arg : !llvm.vec<4 x ptr<1>> to !llvm.vec<4 x ptr> } + +// ----- + +func.func @invalid_target_ext_alloca() { + %0 = llvm.mlir.constant(1 : i64) : i64 + // expected-error@+1 {{this target extension type cannot be used in alloca}} + %1 = llvm.alloca %0 x !llvm.target<"no_alloca"> : (i64) -> !llvm.ptr +} + +// ----- + +func.func @invalid_target_ext_load(%arg0 : !llvm.ptr) { + // expected-error@+1 {{result #0 must be LLVM type with size, but got '!llvm.target<"no_load">'}} + %0 = llvm.load %arg0 {alignment = 8 : i64} : !llvm.ptr -> !llvm.target<"no_load"> +} + +// ----- + +func.func @invalid_target_ext_atomic(%arg0 : !llvm.ptr) { + // expected-error@+1 {{unsupported type
'!llvm.target<"spirv.Event">' for atomic access}} + %0 = llvm.load %arg0 atomic monotonic {alignment = 8 : i64} : !llvm.ptr -> !llvm.target<"spirv.Event"> +} + +// ----- + +func.func @invalid_target_ext_constant() { + // expected-error@+1 {{target extension type does not support zero-initializer}} + %0 = llvm.mlir.constant(0 : index) : !llvm.target<"invalid_constant"> +} + +// ----- + +func.func @invalid_target_ext_constant() { + // expected-error@+1 {{only zero-initializer allowed for target extension types}} + %0 = llvm.mlir.constant(42 : index) : !llvm.target<"spirv.Event"> +} diff --git a/mlir/test/Dialect/LLVMIR/types-invalid.mlir b/mlir/test/Dialect/LLVMIR/types-invalid.mlir --- a/mlir/test/Dialect/LLVMIR/types-invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/types-invalid.mlir @@ -158,3 +158,18 @@ // expected-error @below {{cannot use !llvm.vec for built-in primitives, use 'vector' instead}} func.func private @llvm_vector_primitive() -> !llvm.vec<4 x f32> + +// ----- + +func.func private @target_ext_invalid_order() { + // expected-error @+1 {{failed to parse parameter list for target extension type}} + "some.op"() : () -> !llvm.target<"target1", 5, i32, 1> +} + +// ----- + +func.func private @target_ext_no_name() { + // expected-error@below {{expected string}} + // expected-error@below {{failed to parse LLVMTargetExtType parameter 'extTypeName' which is to be a `::llvm::StringRef`}} + "some.op"() : () -> !llvm.target +} diff --git a/mlir/test/Dialect/LLVMIR/types.mlir b/mlir/test/Dialect/LLVMIR/types.mlir --- a/mlir/test/Dialect/LLVMIR/types.mlir +++ b/mlir/test/Dialect/LLVMIR/types.mlir @@ -176,3 +176,20 @@ "some.op"() : () -> !llvm.struct<(i32, f32, !qux)> llvm.return } + +// ----- + +// CHECK-LABEL: ext_target +llvm.func @ext_target() { + // CHECK: !llvm.target<"target1", i32, 1> + %0 = "some.op"() : () -> !llvm.target<"target1", i32, 1> + // CHECK: !llvm.target<"target2"> + %1 = "some.op"() : () -> !llvm.target<"target2"> + // CHECK: !llvm.target<"target3",
i32, i64, f64> + %2 = "some.op"() : () -> !llvm.target<"target3", i32, i64, f64> + // CHECK: !llvm.target<"target4", 1, 0, 42> + %3 = "some.op"() : () -> !llvm.target<"target4", 1, 0, 42> + // CHECK: !llvm.target<"target5", i32, f64, 0, 5> + %4 = "some.op"() : () -> !llvm.target<"target5", i32, f64, 0, 5> + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/Import/target-ext-type.ll b/mlir/test/Target/LLVMIR/Import/target-ext-type.ll new file mode 100644 --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/target-ext-type.ll @@ -0,0 +1,53 @@ +; RUN: mlir-translate -import-llvm %s | FileCheck %s + +; CHECK-LABEL: llvm.mlir.global external @global() {addr_space = 0 : i32} +; CHECK-SAME: !llvm.target<"spirv.DeviceEvent"> +; CHECK-NEXT: %0 = llvm.mlir.constant(0 : i64) : !llvm.target<"spirv.DeviceEvent"> +; CHECK-NEXT: llvm.return %0 : !llvm.target<"spirv.DeviceEvent"> +@global = global target("spirv.DeviceEvent") zeroinitializer + +; CHECK-LABEL: llvm.func spir_kernelcc @func1( +define spir_kernel void @func1( + ; CHECK-SAME: %arg0: !llvm.target<"spirv.Pipe", 0> + target("spirv.Pipe", 0) %a, + ; CHECK-SAME: %arg1: !llvm.target<"spirv.Pipe", 1> + target("spirv.Pipe", 1) %b, + ; CHECK-SAME: %arg2: !llvm.target<"spirv.Image", !llvm.void, 0, 0, 0, 0, 0, 0, 0> + target("spirv.Image", void, 0, 0, 0, 0, 0, 0, 0) %c1, + ; CHECK-SAME: %arg3: !llvm.target<"spirv.Image", i32, 1, 0, 0, 0, 0, 0, 0> + target("spirv.Image", i32, 1, 0, 0, 0, 0, 0, 0) %d1, + ; CHECK-SAME: %arg4: !llvm.target<"spirv.Image", i32, 2, 0, 0, 0, 0, 0, 0> + target("spirv.Image", i32, 2, 0, 0, 0, 0, 0, 0) %e1, + ; CHECK-SAME: %arg5: !llvm.target<"spirv.Image", f16, 1, 0, 1, 0, 0, 0, 0> + target("spirv.Image", half, 1, 0, 1, 0, 0, 0, 0) %f1, + ; CHECK-SAME: %arg6: !llvm.target<"spirv.Image", f32, 5, 0, 0, 0, 0, 0, 0> + target("spirv.Image", float, 5, 0, 0, 0, 0, 0, 0) %g1, + ; CHECK-SAME: %arg7: !llvm.target<"spirv.Image", !llvm.void, 0, 0, 0, 0, 0, 0, 1> + target("spirv.Image", void, 0, 0, 0, 0, 0, 0, 1) %c2,
+ ; CHECK-SAME: %arg8: !llvm.target<"spirv.Image", !llvm.void, 1, 0, 0, 0, 0, 0, 2>) + target("spirv.Image", void, 1, 0, 0, 0, 0, 0, 2) %d3) { +entry: + ret void +} + +; CHECK-LABEL: llvm.func @func2() +; CHECK-SAME: !llvm.target<"spirv.Event"> { +define target("spirv.Event") @func2() { + ; CHECK-NEXT: %0 = llvm.mlir.constant(1 : i32) : i32 + ; CHECK-NEXT: %1 = llvm.mlir.poison : !llvm.target<"spirv.Event"> + ; CHECK-NEXT: %2 = llvm.alloca %0 x !llvm.target<"spirv.Event"> {alignment = 8 : i64} : (i32) -> !llvm.ptr + %mem = alloca target("spirv.Event") + ; CHECK-NEXT: %3 = llvm.load %2 {alignment = 8 : i64} : !llvm.ptr -> !llvm.target<"spirv.Event"> + %val = load target("spirv.Event"), ptr %mem + ; CHECK-NEXT: llvm.return %1 : !llvm.target<"spirv.Event"> + ret target("spirv.Event") poison +} + +; CHECK-LABEL: llvm.func @func3() +define void @func3() { + ; CHECK-NEXT: %0 = llvm.mlir.constant(0 : i64) : !llvm.target<"spirv.DeviceEvent"> + ; CHECK-NEXT: %1 = llvm.freeze %0 : !llvm.target<"spirv.DeviceEvent"> + %val = freeze target("spirv.DeviceEvent") zeroinitializer + ; CHECK-NEXT: llvm.return + ret void +} diff --git a/mlir/test/Target/LLVMIR/llvmir-types.mlir b/mlir/test/Target/LLVMIR/llvmir-types.mlir --- a/mlir/test/Target/LLVMIR/llvmir-types.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-types.mlir @@ -141,6 +141,18 @@ // CHECK: declare <{ { i32 } }> @return_sp_s_i32() llvm.func @return_sp_s_i32() -> !llvm.struct)> +// CHECK: declare target("target-no-param") @return_target_ext_no_param() +llvm.func @return_target_ext_no_param() -> !llvm.target<"target-no-param"> + +// CHECK: declare target("target-type-param", i32, double) @return_target_ext_type_params() +llvm.func @return_target_ext_type_params() -> !llvm.target<"target-type-param", i32, f64> + +// CHECK: declare target("target-int-param", 0, 42) @return_target_ext_int_params() +llvm.func @return_target_ext_int_params() -> !llvm.target<"target-int-param", 0, 42> + +// CHECK: declare target("target-params", i32,
double, 0, 5) @return_target_ext_params() +llvm.func @return_target_ext_params() -> !llvm.target<"target-params", i32, f64, 0, 5> + // ----- // Put structs into a separate split so that we can match their declarations // locally. diff --git a/mlir/test/Target/LLVMIR/target-ext-type.mlir b/mlir/test/Target/LLVMIR/target-ext-type.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Target/LLVMIR/target-ext-type.mlir @@ -0,0 +1,28 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// CHECK: @global = global target("spirv.DeviceEvent") zeroinitializer +llvm.mlir.global external @global() {addr_space = 0 : i32} : !llvm.target<"spirv.DeviceEvent"> { + %0 = llvm.mlir.constant(0 : i64) : !llvm.target<"spirv.DeviceEvent"> + llvm.return %0 : !llvm.target<"spirv.DeviceEvent"> +} + +// CHECK-LABEL: define target("spirv.Event") @func2() { +// CHECK-NEXT: %1 = alloca target("spirv.Event"), align 8 +// CHECK-NEXT: %2 = load target("spirv.Event"), ptr %1, align 8 +// CHECK-NEXT: ret target("spirv.Event") poison +llvm.func @func2() -> !llvm.target<"spirv.Event"> { + %0 = llvm.mlir.constant(1 : i32) : i32 + %1 = llvm.mlir.poison : !llvm.target<"spirv.Event"> + %2 = llvm.alloca %0 x !llvm.target<"spirv.Event"> {alignment = 8 : i64} : (i32) -> !llvm.ptr + %3 = llvm.load %2 {alignment = 8 : i64} : !llvm.ptr -> !llvm.target<"spirv.Event"> + llvm.return %1 : !llvm.target<"spirv.Event"> +} + +// CHECK-LABEL: define void @func3() { +// CHECK-NEXT: %1 = freeze target("spirv.DeviceEvent") zeroinitializer +// CHECK-NEXT: ret void +llvm.func @func3() { + %0 = llvm.mlir.constant(0 : i64) : !llvm.target<"spirv.DeviceEvent"> + %1 = llvm.freeze %0 : !llvm.target<"spirv.DeviceEvent"> + llvm.return +}