diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -16,7 +16,12 @@ add_mlir_doc(LLVMOps LLVMOps Dialects/ -gen-op-doc) -add_mlir_interface(LLVMOpsInterfaces) +set(LLVM_TARGET_DEFINITIONS LLVMOpsInterfaces.td) +mlir_tablegen(LLVMOpsInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(LLVMOpsInterfaces.cpp.inc -gen-op-interface-defs) +mlir_tablegen(LLVMTypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(LLVMTypeInterfaces.cpp.inc -gen-type-interface-defs) +add_public_tablegen_target(MLIRLLVMOpsInterfacesIncGen) set(LLVM_TARGET_DEFINITIONS LLVMOps.td) mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -108,10 +108,18 @@ And<[LLVM_AnyStruct.predicate, CPred<"$_self.cast<::mlir::LLVM::LLVMStructType>().isOpaque()">]>>; -// Type constraint accepting any LLVM type that can be loaded or stored, i.e. a -// type that has size (not void, function or opaque struct type). +// Type constraint accepting types that implement the pointer element type +// interface (PointerElementTypeInterface). +def LLVM_PointerElementType : Type< + CPred<"$_self.isa<::mlir::LLVM::PointerElementTypeInterface>()">, + "LLVM-compatible pointer element type">; + +// Type constraint accepting any type that can be loaded or stored, i.e. a type +// that has size: an LLVM type other than void, function or opaque struct, or a +// type implementing the pointer element type interface. def LLVM_LoadableType : Type< - And<[LLVM_PrimitiveType.predicate, Neg<LLVM_OpaqueStruct.predicate>]>, + Or<[And<[LLVM_PrimitiveType.predicate, Neg<LLVM_OpaqueStruct.predicate>]>, + LLVM_PointerElementType.predicate]>, "LLVM type with size">; // Type constraint accepting any LLVM aggregate type, i.e. structure or array.
# NOTE(review): The hunks below (LLVMOps.td through the Bazel overlay) switch the
# llvm.load result constraint from LLVM_Type to LLVM_LoadableType, declare the
# LLVM::PointerElementTypeInterface type interface (getSizeInBytes, defaulting to
# DataLayout::getTypeSize), generate and include its .h.inc/.cpp.inc files, and
# make LLVMPointerType::isValidElementType defer to a type check for types that
# are not LLVM-compatible (presumably an isa on PointerElementTypeInterface --
# confirm against the original patch). The test dialect attaches the interface
# to SimpleAType via the PtrElementModel external model.
# NOTE(review): In this copy, line breaks are collapsed and several
# angle-bracketed arguments appear to be missing -- e.g. `OptionalAttr:$alignment`,
# `!type.isa()`, `ExternalModel {}`, `addTypes();`, `attachInterface(*getContext())`
# and the bare `!llvm.ptr` in the new test. These hunks cannot be applied as-is;
# restore them from the original patch first.
# NOTE(review): The LLVMOpsInterfaces.td hunk adds two consecutive blank lines
# before `#endif // LLVM_OPS_INTERFACES`; one is enough.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -331,7 +331,7 @@ OptionalAttr:$access_groups, OptionalAttr:$alignment, UnitAttr:$volatile_, UnitAttr:$nontemporal); - let results = (outs LLVM_Type:$res); + let results = (outs LLVM_LoadableType:$res); string llvmBuilder = [{ auto *inst = builder.CreateLoad( $addr->getType()->getPointerElementType(), $addr, $volatile_); diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td @@ -23,8 +23,38 @@ let cppNamespace = "::mlir::LLVM"; let methods = [ - InterfaceMethod<"Get fastmath flags", "::mlir::LLVM::FastmathFlags", "fastmathFlags">, + InterfaceMethod<"Get fastmath flags", "::mlir::LLVM::FastmathFlags", + "fastmathFlags">, ]; } +//===----------------------------------------------------------------------===// +// LLVM dialect type interfaces. +//===----------------------------------------------------------------------===// + +// An interface for LLVM pointer element types. +def LLVM_PointerElementTypeInterface + : TypeInterface<"PointerElementTypeInterface"> { + let cppNamespace = "::mlir::LLVM"; + + let description = [{ + An interface for types that are allowed as elements of LLVM pointer type. + Such types must have a size. 
+ }]; + + let methods = [ + InterfaceMethod< + /*description=*/"Returns the size of the type in bytes.", + /*retTy=*/"unsigned", + /*methodName=*/"getSizeInBytes", + /*args=*/(ins "const DataLayout &":$dataLayout), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return dataLayout.getTypeSize($_type); + }] + > + ]; +} + + #endif // LLVM_OPS_INTERFACES diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -36,6 +36,13 @@ struct LLVMStructTypeStorage; struct LLVMTypeAndSizeStorage; } // namespace detail +} // namespace LLVM +} // namespace mlir + +#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.h.inc" + +namespace mlir { +namespace LLVM { //===----------------------------------------------------------------------===// // Trivial types. diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -120,8 +120,9 @@ //===----------------------------------------------------------------------===// bool LLVMPointerType::isValidElementType(Type type) { - return !type.isa(); + return isCompatibleType(type) ? 
!type.isa() + : type.isa(); } LLVMPointerType LLVMPointerType::get(Type pointee, unsigned addressSpace) { @@ -607,3 +608,5 @@ return llvm::TypeSize::Fixed(0); }); } + +#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc" diff --git a/mlir/test/Dialect/LLVMIR/types.mlir b/mlir/test/Dialect/LLVMIR/types.mlir --- a/mlir/test/Dialect/LLVMIR/types.mlir +++ b/mlir/test/Dialect/LLVMIR/types.mlir @@ -176,6 +176,14 @@ return } +// CHECK-LABEL: @ptr_elem_interface +// CHECK-COUNT-3: !llvm.ptr +func @ptr_elem_interface(%arg0: !llvm.ptr) { + %0 = llvm.load %arg0 : !llvm.ptr + llvm.store %0, %arg0 : !llvm.ptr + return +} + // ----- // Check that type aliases can be used inside LLVM dialect types. Note that diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -62,6 +62,7 @@ MLIRIR MLIRInferTypeOpInterface MLIRLinalgTransforms + MLIRLLVMIR MLIRPass MLIRReduce MLIRStandard diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -13,6 +13,7 @@ #include "TestTypes.h" #include "TestDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Types.h" @@ -222,11 +223,19 @@ // TestDialect //===----------------------------------------------------------------------===// +namespace { + +struct PtrElementModel + : public LLVM::PointerElementTypeInterface::ExternalModel {}; +} // namespace + void TestDialect::registerTypes() { addTypes(); + SimpleAType::attachInterface(*getContext()); } static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel 
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2441,6 +2441,14 @@ ["-gen-op-interface-defs"], "include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc", ), + ( + ["-gen-type-interface-decls"], + "include/mlir/Dialect/LLVMIR/LLVMTypeInterfaces.h.inc", + ), + ( + ["-gen-type-interface-defs"], + "include/mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc", + ), ], tblgen = ":mlir-tblgen", td_file = "include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -227,6 +227,7 @@ "//mlir:Dialect", "//mlir:IR", "//mlir:InferTypeOpInterface", + "//mlir:LLVMDialect", "//mlir:Pass", "//mlir:Reducer", "//mlir:SideEffects",