diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -122,7 +122,7 @@
                             moduleTranslation.lookupValue(op.getPtr())});
   }];
 
-  let assemblyFormat = "$size `,` $ptr attr-dict `:` type($ptr)";
+  let assemblyFormat = "$size `,` $ptr attr-dict `:` qualified(type($ptr))";
 }
 
 def LLVM_LifetimeStartOp : LLVM_LifetimeBaseOp<"lifetime.start"> {
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1232,7 +1232,7 @@
     LLVMFuncOp getFunction(SymbolTableCollection &symbolTable);
   }];
 
-  let assemblyFormat = "$global_name attr-dict `:` type($res)";
+  let assemblyFormat = "$global_name attr-dict `:` qualified(type($res))";
 }
 
 def LLVM_MetadataOp : LLVM_Op<"metadata", [
@@ -1656,7 +1656,7 @@
   let results = (outs LLVM_AnyPointer:$res);
   let builders = [LLVM_OneResultOpBuilder];
 
-  let assemblyFormat = "attr-dict `:` type($res)";
+  let assemblyFormat = "attr-dict `:` qualified(type($res))";
 }
 
 def LLVM_UndefOp : LLVM_Op<"mlir.undef", [Pure]>,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -1,4 +1,4 @@
-//===- LLVMDialect.h - MLIR LLVM dialect types ------------------*- C++ -*-===//
+//===- LLVMTypes.h - MLIR LLVM dialect types --------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -72,75 +72,6 @@
 
 #undef DEFINE_TRIVIAL_LLVM_TYPE
 
-//===----------------------------------------------------------------------===//
-// LLVMPointerType.
-//===----------------------------------------------------------------------===//
-
-/// LLVM dialect pointer type. This type typically represents a reference to an
-/// object in memory. Pointers may be opaque or parameterized by the element
-/// type. Both opaque and non-opaque pointers are additionally parameterized by
-/// the address space.
-class LLVMPointerType
-    : public Type::TypeBase<
-          LLVMPointerType, Type, detail::LLVMPointerTypeStorage,
-          DataLayoutTypeInterface::Trait, SubElementTypeInterface::Trait> {
-public:
-  /// Inherit base constructors.
-  using Base::Base;
-
-  /// Checks if the given type can have a pointer type pointing to it.
-  static bool isValidElementType(Type type);
-
-  /// Gets or creates an instance of LLVM dialect pointer type pointing to an
-  /// object of `pointee` type in the given address space. The pointer type is
-  /// created in the same context as `pointee`. If the pointee is not provided,
-  /// creates an opaque pointer in the given context and address space.
-  static LLVMPointerType get(MLIRContext *context, unsigned addressSpace = 0);
-  static LLVMPointerType get(Type pointee, unsigned addressSpace = 0);
-  static LLVMPointerType
-  getChecked(function_ref<InFlightDiagnostic()> emitError, Type pointee,
-             unsigned addressSpace = 0);
-  static LLVMPointerType
-  getChecked(function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
-             unsigned addressSpace = 0);
-
-  /// Returns the pointed-to type. It may be null if the pointer is opaque.
-  Type getElementType() const;
-
-  /// Returns `true` if this type is the opaque pointer type, i.e., it has no
-  /// pointed-to type.
-  bool isOpaque() const;
-
-  /// Returns the address space of the pointer.
-  unsigned getAddressSpace() const;
-
-  /// Verifies that the type about to be constructed is well-formed.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              Type pointee, unsigned);
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              MLIRContext *context, unsigned) {
-    return success();
-  }
-
-  /// Hooks for DataLayoutTypeInterface. Should not be called directly. Obtain a
-  /// DataLayout instance and query it instead.
-  unsigned getTypeSizeInBits(const DataLayout &dataLayout,
-                             DataLayoutEntryListRef params) const;
-  unsigned getABIAlignment(const DataLayout &dataLayout,
-                           DataLayoutEntryListRef params) const;
-  unsigned getPreferredAlignment(const DataLayout &dataLayout,
-                                 DataLayoutEntryListRef params) const;
-  bool areCompatible(DataLayoutEntryListRef oldLayout,
-                     DataLayoutEntryListRef newLayout) const;
-  LogicalResult verifyEntries(DataLayoutEntryListRef entries,
-                              Location loc) const;
-
-  void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
-                                function_ref<void(Type)> walkTypesFn) const;
-  Type replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                   ArrayRef<Type> replTypes) const;
-};
-
 //===----------------------------------------------------------------------===//
 // LLVMStructType.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -118,4 +118,53 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// LLVMPointerType
+//===----------------------------------------------------------------------===//
+
+def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
+    DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
+      "areCompatible", "verifyEntries"]>,
+    DeclareTypeInterfaceMethods<SubElementTypeInterface>]> {
+  let summary = "LLVM pointer type";
+  let description = [{
+    The `!llvm.ptr` type is an LLVM pointer type. This type typically represents
+    a reference to an object in memory. Pointers may be opaque or parameterized
+    by the element type. Both opaque and non-opaque pointers are additionally
+    parameterized by the address space.
+
+    Example:
+
+    ```mlir
+    !llvm.ptr<i8>
+    !llvm.ptr
+    ```
+  }];
+
+  let parameters = (ins DefaultValuedParameter<"Type", "Type()">:$elementType,
+                        DefaultValuedParameter<"unsigned", "0">:$addressSpace);
+  let assemblyFormat = [{
+    (`<` custom<Pointer>($elementType, $addressSpace)^ `>`)?
+  }];
+
+  let genVerifyDecl = 1;
+
+  let builders = [
+    TypeBuilderWithInferredContext<(ins "Type":$elementType,
+                                        CArg<"unsigned", "0">:$addressSpace)>,
+    TypeBuilder<(ins CArg<"unsigned", "0">:$addressSpace), [{
+      return $_get($_ctxt, Type(), addressSpace);
+    }]>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Returns `true` if this type is the opaque pointer type, i.e., it has no
+    /// pointed-to type.
+    bool isOpaque() const { return !getElementType(); }
+
+    /// Checks if the given type can have a pointer type pointing to it.
+    static bool isValidElementType(Type type);
+  }];
+}
+
 #endif // LLVMTYPES_TD
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -168,7 +168,7 @@
   let hasCustomAssemblyFormat = 1;
 }
 
-def NVVM_SyncWarpOp : 
+def NVVM_SyncWarpOp :
   NVVM_Op<"bar.warp.sync">,
   Arguments<(ins LLVM_Type:$mask)> {
   string llvmBuilder = [{
@@ -534,9 +534,9 @@
   );
 }
 
-// Returns a list of operation suffixes corresponding to possible b1 
-// multiply-and-accumulate operations for all fragments which have a 
-// b1 type. For all other fragments, the list returned holds a list 
+// Returns a list of operation suffixes corresponding to possible b1
+// multiply-and-accumulate operations for all fragments which have a
+// b1 type. For all other fragments, the list returned holds a list
 // containing the empty string.
class NVVM_MMA_B1OPS frags> { list ret = !cond( @@ -555,7 +555,7 @@ # "_" # ALayout # "_" # BLayout # !if(Satfinite, "_satfinite", "") - # signature; + # signature; } /// Helper to create the mapping between the configuration and the mma.sync @@ -572,13 +572,13 @@ "if (layoutA == \"" # layoutA # "\" && layoutB == \"" # layoutB # "\" && " " m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k # " && \"" # op[0].ptx_elt_type # "\" == eltypeA && \"" - # op[1].ptx_elt_type # "\" == eltypeB && " + # op[1].ptx_elt_type # "\" == eltypeB && " # " \"" # op[2].ptx_elt_type # "\" == eltypeC && " # " \"" # op[3].ptx_elt_type # "\" == eltypeD " # " && (sat.has_value() ? " # sat # " == static_cast(*sat) : true)" # !if(!ne(b1op, ""), " && (b1Op.has_value() ? MMAB1Op::" # b1op # " == b1Op.value() : true)", "") # ")\n" # " return " # - MMA_SYNC_NAME.id # ";", + MMA_SYNC_NAME.id # ";", "") // if supported ) // b1op ) // sat @@ -586,7 +586,7 @@ ) // layoutA ); // all_mma_sync_ops list>> f1 = !foldl([[[""]]], - !foldl([[[[""]]]], cond0, acc, el, + !foldl([[[[""]]]], cond0, acc, el, !listconcat(acc, el)), acc1, el1, !listconcat(acc1, el1)); list> f2 = !foldl([[""]], f1, acc1, el1, !listconcat(acc1, el1)); @@ -776,7 +776,10 @@ ``` }]; - let assemblyFormat = "$ptr `,` $stride `,` $args attr-dict `:` type($ptr) `,` type($args)"; + let assemblyFormat = [{ + $ptr `,` $stride `,` $args attr-dict `:` qualified(type($ptr)) `,` + type($args) + }]; let hasVerifier = 1; } @@ -884,32 +887,32 @@ let description = [{ The `nvvm.mma.sync` operation collectively performs the operation - `D = matmul(A, B) + C` using all threads in a warp. + `D = matmul(A, B) + C` using all threads in a warp. All the threads in the warp must execute the same `mma.sync` operation. For each possible multiplicand PTX data type, there are one or more possible instruction shapes given as "mMnNkK". The below table describes the posssibilities - as well as the types required for the operands. 
Note that the data type for - C (the accumulator) and D (the result) can vary independently when there are + as well as the types required for the operands. Note that the data type for + C (the accumulator) and D (the result) can vary independently when there are multiple possibilities in the "C/D Type" column. When an optional attribute cannot be immediately inferred from the types of the operands and the result during parsing or validation, an error will be raised. - `b1Op` is only relevant when the binary (b1) type is given to + `b1Op` is only relevant when the binary (b1) type is given to `multiplicandDataType`. It specifies how the multiply-and-acumulate is performed and is either `xor_popc` or `and_poc`. The default is `xor_popc`. - `intOverflowBehavior` is only relevant when the `multiplicandType` attribute + `intOverflowBehavior` is only relevant when the `multiplicandType` attribute is one of `u8, s8, u4, s4`, this attribute describes how overflow is handled in the accumulator. When the attribute is `satfinite`, the accumulator values are clamped in the int32 range on overflow. This is the default behavior. - Alternatively, accumulator behavior `wrapped` can also be specified, in + Alternatively, accumulator behavior `wrapped` can also be specified, in which case overflow wraps from one end of the range to the other. - `layoutA` and `layoutB` are required and should generally be set to + `layoutA` and `layoutB` are required and should generally be set to `#nvvm.mma_layout` and `#nvvm.mma_layout` respectively, but other combinations are possible for certain layouts according to the table below. 
@@ -938,12 +941,12 @@
     Example:
 
     ```mlir
-    %128 = nvvm.mma.sync A[%120, %121, %122, %123] 
-                         B[%124, %125] 
-                         C[%126, %127] 
-                         {layoutA = #nvvm.mma_layout<row>, 
-                          layoutB = #nvvm.mma_layout<col>, 
-                          shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} 
+    %128 = nvvm.mma.sync A[%120, %121, %122, %123]
+                         B[%124, %125]
+                         C[%126, %127]
+                         {layoutA = #nvvm.mma_layout<row>,
+                          layoutB = #nvvm.mma_layout<col>,
+                          shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}}
         : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
            -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
     ```
@@ -951,7 +954,7 @@
 
   let results = (outs LLVM_AnyStruct:$res);
   let arguments = (ins NVVM_MMAShapeAttr:$shape,
-                   OptionalAttr<MMAB1OpAttr>:$b1Op, 
+                   OptionalAttr<MMAB1OpAttr>:$b1Op,
                    OptionalAttr<MMAIntOverflowAttr>:$intOverflowBehavior,
                    MMALayoutAttr:$layoutA,
                    MMALayoutAttr:$layoutB,
@@ -959,12 +962,12 @@
                    OptionalAttr<MMATypesAttr>:$multiplicandBPtxType,
                    Variadic<LLVM_Type>:$operandA,
                    Variadic<LLVM_Type>:$operandB,
-                   Variadic<LLVM_Type>:$operandC); 
+                   Variadic<LLVM_Type>:$operandC);
 
   let extraClassDeclaration = !strconcat([{
       static llvm::Intrinsic::ID getIntrinsicID(
             int64_t m, int64_t n, uint64_t k,
-            llvm::Optional<MMAB1Op> b1Op, 
+            llvm::Optional<MMAB1Op> b1Op,
             llvm::Optional<MMAIntOverflow> sat,
             mlir::NVVM::MMALayout layoutAEnum, mlir::NVVM::MMALayout layoutBEnum,
             mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
@@ -988,7 +991,7 @@
   }]);
 
   let builders = [
-    OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA, 
+    OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
       "ValueRange":$operandB, "ValueRange":$operandC,
       "ArrayRef<int64_t>":$shape, "Optional<MMAB1Op>":$b1Op,
       "Optional<MMAIntOverflow>":$intOverflow,
@@ -999,12 +1002,12 @@
   string llvmBuilder = [{
     auto operands = moduleTranslation.lookupValues(opInst.getOperands());
     auto intId = mlir::NVVM::MmaOp::getIntrinsicID(
-        $shape.getM(), $shape.getN(), $shape.getK(), 
+        $shape.getM(), $shape.getN(), $shape.getK(),
         $b1Op, $intOverflowBehavior,
         $layoutA, $layoutB,
-        $multiplicandAPtxType.value(), 
+        $multiplicandAPtxType.value(),
         $multiplicandBPtxType.value(),
-        op.accumPtxType(), 
+        op.accumPtxType(),
         op.resultPtxType());
 
     $res = createIntrinsicCall(
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2573,7 +2573,6 @@
            LLVMTokenType,
            LLVMLabelType,
            LLVMMetadataType,
-           LLVMPointerType,
            LLVMFixedVectorType,
            LLVMScalableVectorType,
            LLVMStructType>();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -124,20 +124,8 @@
 
   printer << getTypeKeyword(type);
 
-  if (auto ptrType = type.dyn_cast<LLVMPointerType>()) {
-    if (ptrType.isOpaque()) {
-      if (ptrType.getAddressSpace() != 0)
-        printer << '<' << ptrType.getAddressSpace() << '>';
-      return;
-    }
-
-    printer << '<';
-    dispatchPrint(printer, ptrType.getElementType());
-    if (ptrType.getAddressSpace() != 0)
-      printer << ", " << ptrType.getAddressSpace();
-    printer << '>';
-    return;
-  }
+  if (auto ptrType = type.dyn_cast<LLVMPointerType>())
+    return ptrType.print(printer);
 
   if (auto arrayType = type.dyn_cast<LLVMArrayType>())
     return arrayType.print(printer);
@@ -164,37 +152,6 @@
 
 static ParseResult dispatchParse(AsmParser &parser, Type &type);
 
-/// Parses an LLVM dialect pointer type.
-///   llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
-///               | `ptr` (`<` integer `>`)?
-static LLVMPointerType parsePointerType(AsmParser &parser) {
-  SMLoc loc = parser.getCurrentLocation();
-  Type elementType;
-  if (parser.parseOptionalLess()) {
-    return parser.getChecked<LLVMPointerType>(loc, parser.getContext(),
-                                              /*addressSpace=*/0);
-  }
-
-  unsigned addressSpace = 0;
-  OptionalParseResult opr = parser.parseOptionalInteger(addressSpace);
-  if (opr.has_value()) {
-    if (failed(*opr) || parser.parseGreater())
-      return LLVMPointerType();
-    return parser.getChecked<LLVMPointerType>(loc, parser.getContext(),
-                                              addressSpace);
-  }
-
-  if (dispatchParse(parser, elementType))
-    return LLVMPointerType();
-
-  if (succeeded(parser.parseOptionalComma()) &&
-      failed(parser.parseInteger(addressSpace)))
-    return LLVMPointerType();
-  if (failed(parser.parseGreater()))
-    return LLVMPointerType();
-  return parser.getChecked<LLVMPointerType>(loc, elementType, addressSpace);
-}
-
 /// Parses an LLVM dialect vector type.
 ///   llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>`
 /// Supports both fixed and scalable vectors.
@@ -391,7 +348,7 @@
     .Case("label", [&] { return LLVMLabelType::get(ctx); })
     .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
     .Case("func", [&] { return LLVMFunctionType::parse(parser); })
-    .Case("ptr", [&] { return parsePointerType(parser); })
+    .Case("ptr", [&] { return LLVMPointerType::parse(parser); })
     .Case("vec", [&] { return parseVectorType(parser); })
     .Case("array", [&] { return LLVMArrayType::parse(parser); })
     .Case("struct", [&] { return parseStructType(parser); })
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -1,3 +1,4 @@
+//===- LLVMTypes.cpp - MLIR LLVM dialect types ------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -75,6 +76,41 @@
   p << ')';
 }
 
+//===----------------------------------------------------------------------===//
+// custom<Pointer>
+//===----------------------------------------------------------------------===//
+
+static ParseResult parsePointer(AsmParser &p, FailureOr<Type> &elementType,
+                                FailureOr<unsigned> &addressSpace) {
+  addressSpace = 0;
+  // Opaque pointer: the body is only an address space, e.g. `!llvm.ptr<3>`.
+  OptionalParseResult result = p.parseOptionalInteger(*addressSpace);
+  if (result.has_value()) {
+    if (failed(result.value()))
+      return failure();
+    elementType = Type();
+    return success();
+  }
+
+  if (parsePrettyLLVMType(p, elementType))
+    return failure();
+  if (succeeded(p.parseOptionalComma()))
+    return p.parseInteger(*addressSpace);
+
+  return success();
+}
+
+static void printPointer(AsmPrinter &p, Type elementType,
+                         unsigned addressSpace) {
+  if (elementType)
+    printPrettyLLVMType(p, elementType);
+  if (addressSpace != 0) {
+    if (elementType)
+      p << ", ";
+    p << addressSpace;
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // ODS-Generated Definitions
 //===----------------------------------------------------------------------===//
@@ -228,7 +264,7 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Pointer type.
+// LLVMPointerType
 //===----------------------------------------------------------------------===//
 
 bool LLVMPointerType::isValidElementType(Type type) {
@@ -246,32 +282,6 @@
   return Base::get(pointee.getContext(), pointee, addressSpace);
 }
 
-LLVMPointerType LLVMPointerType::get(MLIRContext *context,
-                                     unsigned addressSpace) {
-  return Base::get(context, Type(), addressSpace);
-}
-
-LLVMPointerType
-LLVMPointerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
-                            Type pointee, unsigned addressSpace) {
-  return Base::getChecked(emitError, pointee.getContext(), pointee,
-                          addressSpace);
-}
-
-LLVMPointerType
-LLVMPointerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
-                            MLIRContext *context, unsigned addressSpace) {
-  return Base::getChecked(emitError, context, Type(), addressSpace);
-}
-
-Type LLVMPointerType::getElementType() const { return getImpl()->pointeeType; }
-
-bool LLVMPointerType::isOpaque() const { return !getImpl()->pointeeType; }
-
-unsigned LLVMPointerType::getAddressSpace() const {
-  return getImpl()->addressSpace;
-}
-
 LogicalResult
 LLVMPointerType::verify(function_ref<InFlightDiagnostic()> emitError,
                         Type pointee, unsigned) {
@@ -280,6 +290,9 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// DataLayoutTypeInterface
+
 constexpr const static unsigned kDefaultPointerSizeBits = 64;
 constexpr const static unsigned kDefaultPointerAlignment = 8;
 
@@ -426,6 +439,9 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// SubElementTypeInterface
+
 void LLVMPointerType::walkImmediateSubElements(
     function_ref<void(Attribute)> walkAttrsFn,
     function_ref<void(Type)> walkTypesFn) const {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
--- a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
+++ b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
@@ -321,33 +321,6 @@
   unsigned identifiedBodySizeAndFlags = 0;
 };
 
-//===----------------------------------------------------------------------===//
-// LLVMPointerTypeStorage.
-//===----------------------------------------------------------------------===//
-
-/// Storage type for LLVM dialect pointer types. These are uniqued by a pair of
-/// element type and address space. The element type may be null indicating that
-/// the pointer is opaque.
-struct LLVMPointerTypeStorage : public TypeStorage {
-  using KeyTy = std::tuple<Type, unsigned>;
-
-  LLVMPointerTypeStorage(const KeyTy &key)
-      : pointeeType(std::get<0>(key)), addressSpace(std::get<1>(key)) {}
-
-  static LLVMPointerTypeStorage *construct(TypeStorageAllocator &allocator,
-                                           const KeyTy &key) {
-    return new (allocator.allocate<LLVMPointerTypeStorage>())
-        LLVMPointerTypeStorage(key);
-  }
-
-  bool operator==(const KeyTy &key) const {
-    return std::make_tuple(pointeeType, addressSpace) == key;
-  }
-
-  Type pointeeType;
-  unsigned addressSpace;
-};
-
 //===----------------------------------------------------------------------===//
 // LLVMTypeAndSizeStorage.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -495,7 +495,7 @@
 // -----
 
 func.func @null_non_llvm_type() {
-  // expected-error@+1 {{custom op 'llvm.mlir.null' invalid kind of type specified}}
+  // expected-error@+1 {{'llvm.mlir.null' op result #0 must be LLVM pointer type, but got 'i32'}}
   llvm.mlir.null : i32
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/layout.mlir b/mlir/test/Dialect/LLVMIR/layout.mlir
--- a/mlir/test/Dialect/LLVMIR/layout.mlir
+++ b/mlir/test/Dialect/LLVMIR/layout.mlir
@@ -82,12 +82,12 @@
     // CHECK: size = 8
     "test.data_layout_query"() : () -> !llvm.ptr<i8, 5>
     // CHECK: alignment = 4
-    // CHECK: bitsize = 32 
+    // CHECK: bitsize = 32
     // CHECK: preferred = 8
     // CHECK: size = 4
     "test.data_layout_query"() : () -> !llvm.ptr<3>
     // CHECK: alignment = 8
-    // CHECK: bitsize = 32 
+    // CHECK: bitsize = 32
     // CHECK: preferred = 8
     // CHECK: size = 4
     "test.data_layout_query"() : () -> !llvm.ptr<4>