diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -138,6 +138,59 @@
   static LLVMType getVectorTy(LLVMType elementType, unsigned numElements);
   static LLVMType getVoidTy(LLVMDialect *dialect);
 
+  /// Creates a new identified LLVM struct type in the LLVM context owned by
+  /// `dialect`. A missing or empty `name` yields an unnamed identified struct.
+  /// The body is set only if `elements` is non-empty; otherwise the struct is
+  /// left opaque and its body can be supplied later with `setStructTyBody`.
+  static LLVMType createStructTy(LLVMDialect *dialect,
+                                 ArrayRef<LLVMType> elements,
+                                 Optional<StringRef> name,
+                                 bool isPacked = false);
+
+  /// Creates a new opaque identified struct type (no body set).
+  static LLVMType createStructTy(LLVMDialect *dialect,
+                                 Optional<StringRef> name) {
+    return createStructTy(dialect, llvm::None, name);
+  }
+
+  /// Like the `dialect`-taking overload, but obtains the dialect from the
+  /// first element; `elements` must therefore be non-empty.
+  static LLVMType createStructTy(ArrayRef<LLVMType> elements,
+                                 Optional<StringRef> name,
+                                 bool isPacked = false) {
+    assert(!elements.empty() &&
+           "This method may not be invoked with an empty list");
+    LLVMType ele0 = elements.front();
+    return createStructTy(&ele0.getDialect(), elements, name, isPacked);
+  }
+
+  /// Variadic convenience form: creates a non-packed identified struct named
+  /// `name` whose body holds the given (at least one) element types.
+  template <typename... Args>
+  static std::enable_if_t<llvm::are_base_of<LLVMType, Args...>::value,
+                          LLVMType>
+  createStructTy(StringRef name, LLVMType elt1, Args... elts) {
+    SmallVector<LLVMType, 8> fields({elt1, elts...});
+    Optional<StringRef> optName(name);
+    return createStructTy(&elt1.getDialect(), fields, optName);
+  }
+
+  /// Sets the body of `structType`, which must wrap an opaque identified
+  /// `llvm::StructType` (e.g. one made by the opaque `createStructTy`
+  /// overload), and returns `structType`.
+  static LLVMType setStructTyBody(LLVMType structType,
+                                  ArrayRef<LLVMType> elements,
+                                  bool isPacked = false);
+
+  /// Variadic convenience form of `setStructTyBody`; the body is not packed.
+  template <typename... Args>
+  static std::enable_if_t<llvm::are_base_of<LLVMType, Args...>::value,
+                          LLVMType>
+  setStructTyBody(LLVMType structType, LLVMType elt1, Args... elts) {
+    SmallVector<LLVMType, 8> fields({elt1, elts...});
+    return setStructTyBody(structType, fields);
+  }
+
 private:
   friend LLVMDialect;
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1641,6 +1641,45 @@
                                  isPacked);
   });
 }
+
+/// Returns the underlying LLVM IR types of `elements`, in order.
+static SmallVector<llvm::Type *, 8>
+toUnderlyingTypes(ArrayRef<LLVMType> elements) {
+  SmallVector<llvm::Type *, 8> llvmElements;
+  llvmElements.reserve(elements.size());
+  for (LLVMType elt : elements)
+    llvmElements.push_back(elt.getUnderlyingType());
+  return llvmElements;
+}
+
+LLVMType LLVMType::createStructTy(LLVMDialect *dialect,
+                                  ArrayRef<LLVMType> elements,
+                                  Optional<StringRef> name, bool isPacked) {
+  StringRef sr = name.hasValue() ? *name : "";
+  SmallVector<llvm::Type *, 8> llvmElements(toUnderlyingTypes(elements));
+  // Lock access to the dialect as this modifies the LLVM context.
+  return getLocked(dialect, [=] {
+    auto *rv = llvm::StructType::create(dialect->getLLVMContext(), sr);
+    // With no elements the struct stays opaque; its body can then be set
+    // later through setStructTyBody.
+    if (!llvmElements.empty())
+      rv->setBody(llvmElements, isPacked);
+    return rv;
+  });
+}
+
+LLVMType LLVMType::setStructTyBody(LLVMType structType,
+                                   ArrayRef<LLVMType> elements, bool isPacked) {
+  llvm::StructType *st =
+      llvm::cast<llvm::StructType>(structType.getUnderlyingType());
+  SmallVector<llvm::Type *, 8> llvmElements(toUnderlyingTypes(elements));
+  // Lock access to the dialect as this modifies the LLVM context.
+  return getLocked(&structType.getDialect(), [=] {
+    st->setBody(llvmElements, isPacked);
+    return st;
+  });
+}
+
 LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) {
   // Lock access to the dialect as this may modify the LLVM context.
   return getLocked(&elementType.getDialect(), [=] {
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -382,3 +382,12 @@
   %0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32)
   llvm.return %0 : (!llvm<"{ float, float, float, float, float, float, float, float }">, !llvm.i32)
 }
+
+// -----
+
+// FIXME: the LLVM-IR dialect should parse recursive (self-referential) struct types
+// expected-error@+1 {{expected end of string}}
+llvm.func @recursive_type(%a : !llvm<"%a = type { %a* }">) ->
+               !llvm<"%a = type { %a* }"> {
+  llvm.return %a : !llvm<"%a = type { %a* }">
+}