diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -189,11 +189,10 @@
   }];
   code setNonTemporalMetadataCode = [{
     if ($nontemporal) {
-      llvm::Module *module = builder.GetInsertBlock()->getModule();
       llvm::MDNode *metadata = llvm::MDNode::get(
           inst->getContext(), llvm::ConstantAsMetadata::get(
               builder.getInt32(1)));
-      inst->setMetadata(module->getMDKindID("nontemporal"), metadata);
+      inst->setMetadata(llvm::LLVMContext::MD_nontemporal, metadata);
     }
   }];
   code setAccessGroupsMetadataCode = [{
@@ -355,6 +354,10 @@
                    UnitAttr:$nontemporal);
   let results = (outs LLVM_LoadableType:$res);
   string llvmInstName = "Load";
+  let assemblyFormat = [{
+    (`volatile` $volatile_^)? $addr attr-dict `:`
+    custom<LoadType>(type($addr), type($res))
+  }];
   string llvmBuilder = [{
     auto *inst = builder.CreateLoad($_resultType, $addr, $volatile_);
   }] # setAlignmentCode
@@ -365,9 +368,12 @@
      # [{
        $res = inst;
      }];
-  // FIXME: Import attributes.
   string mlirBuilder = [{
-    $res = $_builder.create<LLVM::LoadOp>($_location, $_resultType, $addr);
+    auto *loadInst = cast<llvm::LoadInst>(inst);
+    unsigned alignment = loadInst->getAlign().value();
+    $res = $_builder.create<LLVM::LoadOp>($_location, $_resultType, $addr,
+        alignment, loadInst->isVolatile(),
+        loadInst->hasMetadata(llvm::LLVMContext::MD_nontemporal));
   }];
   let builders = [
     OpBuilder<(ins "Value":$addr, CArg<"unsigned", "0">:$alignment,
@@ -378,9 +384,9 @@
                      "when the pointer type is opaque");
       build($_builder, $_state, type, addr, alignment, isVolatile, isNonTemporal);
     }]>,
-    OpBuilder<(ins "Type":$t, "Value":$addr,
+    OpBuilder<(ins "Type":$type, "Value":$addr,
       CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile,
-      CArg<"bool", "false">:$isNonTemporal)>,];
-  let hasCustomAssemblyFormat = 1;
+      CArg<"bool", "false">:$isNonTemporal)>
+  ];
   let hasVerifier = 1;
 }
@@ -395,6 +401,10 @@
                    OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
                    UnitAttr:$nontemporal);
   string llvmInstName = "Store";
+  let assemblyFormat = [{
+    (`volatile` $volatile_^)? $value `,` $addr attr-dict `:`
+    custom<StoreType>(type($value), type($addr))
+  }];
   string llvmBuilder = [{
     auto *inst = builder.CreateStore($value, $addr, $volatile_);
   }] # setAlignmentCode
@@ -402,16 +412,18 @@
      # setAccessGroupsMetadataCode
      # setAliasScopeMetadataCode
      # setTBAAMetadataCode;
-  // FIXME: Import attributes.
   string mlirBuilder = [{
-    $_op = $_builder.create<LLVM::StoreOp>($_location, $value, $addr);
+    auto *storeInst = cast<llvm::StoreInst>(inst);
+    unsigned alignment = storeInst->getAlign().value();
+    $_op = $_builder.create<LLVM::StoreOp>($_location, $value, $addr,
+        alignment, storeInst->isVolatile(),
+        storeInst->hasMetadata(llvm::LLVMContext::MD_nontemporal));
   }];
   let builders = [
     OpBuilder<(ins "Value":$value, "Value":$addr,
       CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile,
       CArg<"bool", "false">:$isNonTemporal)>
-    ];
-  let hasCustomAssemblyFormat = 1;
+  ];
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -44,8 +44,6 @@
 
 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
 
-static constexpr const char kVolatileAttrName[] = "volatile_";
-static constexpr const char kNonTemporalAttrName[] = "nontemporal";
 static constexpr const char kElemTypeAttrName[] = "elem_type";
 
 #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc"
@@ -665,7 +663,7 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Builder, printer and parser for for LLVM::LoadOp.
+// LoadOp
 //===----------------------------------------------------------------------===//
 
 /// Verifies the given array attribute contains symbol references and checks the
@@ -759,29 +757,13 @@
 
 LogicalResult LoadOp::verify() { return verifyMemOpMetadata(*this); }
 
-void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
+void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
                    Value addr, unsigned alignment, bool isVolatile,
                    bool isNonTemporal) {
-  result.addOperands(addr);
-  result.addTypes(t);
-  if (isVolatile)
-    result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
-  if (isNonTemporal)
-    result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
-  if (alignment != 0)
-    result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
-}
-
-void LoadOp::print(OpAsmPrinter &p) {
-  p << ' ';
-  if (getVolatile_())
-    p << "volatile ";
-  p << getAddr();
-  p.printOptionalAttrDict((*this)->getAttrs(),
-                          {kVolatileAttrName, kElemTypeAttrName});
-  p << " : " << getAddr().getType();
-  if (getAddr().getType().cast<LLVMPointerType>().isOpaque())
-    p << " -> " << getType();
+  build(builder, state, type, addr, /*access_groups=*/nullptr,
+        /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
+        alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
+        isNonTemporal);
 }
 
 // Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return
@@ -797,105 +779,85 @@
   return llvmTy.getElementType();
 }
 
-// <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type
-//                 (`->` type)?
-ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::UnresolvedOperand addr;
-  Type type;
+/// Parses the LoadOp type either using the typed or opaque pointer format.
+// TODO: Drop once the typed pointer assembly format is not needed anymore.
+static ParseResult parseLoadType(OpAsmParser &parser, Type &type,
+                                 Type &elementType) {
   SMLoc trailingTypeLoc;
-
-  if (succeeded(parser.parseOptionalKeyword("volatile")))
-    result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
-
-  if (parser.parseOperand(addr) ||
-      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
-      parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
-      parser.resolveOperand(addr, type, result.operands))
+  if (parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
     return failure();
 
-  std::optional<Type> elemTy =
+  std::optional<Type> pointerElementType =
       getLoadStoreElementType(parser, type, trailingTypeLoc);
-  if (!elemTy)
+  if (!pointerElementType)
     return failure();
-  if (*elemTy) {
-    result.addTypes(*elemTy);
+  if (*pointerElementType) {
+    elementType = *pointerElementType;
     return success();
   }
 
-  Type trailingType;
-  if (parser.parseArrow() || parser.parseType(trailingType))
+  if (parser.parseArrow() || parser.parseType(elementType))
     return failure();
-  result.addTypes(trailingType);
   return success();
 }
 
+/// Prints the LoadOp type either using the typed or opaque pointer format.
+// TODO: Drop once the typed pointer assembly format is not needed anymore.
+static void printLoadType(OpAsmPrinter &printer, Operation *op, Type type,
+                          Type elementType) {
+  printer << type;
+  auto pointerType = cast<LLVMPointerType>(type);
+  if (pointerType.isOpaque())
+    printer << " -> " << elementType;
+}
+
 //===----------------------------------------------------------------------===//
-// Builder, printer and parser for LLVM::StoreOp.
+// StoreOp
 //===----------------------------------------------------------------------===//
 
 LogicalResult StoreOp::verify() { return verifyMemOpMetadata(*this); }
 
-void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
+void StoreOp::build(OpBuilder &builder, OperationState &state, Value value,
                     Value addr, unsigned alignment, bool isVolatile,
                     bool isNonTemporal) {
-  result.addOperands({value, addr});
-  result.addTypes({});
-  if (isVolatile)
-    result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
-  if (isNonTemporal)
-    result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
-  if (alignment != 0)
-    result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
+  build(builder, state, value, addr, /*access_groups=*/nullptr,
+        /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
+        alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
+        isNonTemporal);
 }
 
-void StoreOp::print(OpAsmPrinter &p) {
-  p << ' ';
-  if (getVolatile_())
-    p << "volatile ";
-  p << getValue() << ", " << getAddr();
-  p.printOptionalAttrDict((*this)->getAttrs(), {kVolatileAttrName});
-  p << " : ";
-  if (getAddr().getType().cast<LLVMPointerType>().isOpaque())
-    p << getValue().getType() << ", ";
-  p << getAddr().getType();
-}
-
-// <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use
-//                 attribute-dict? `:` type (`,` type)?
-ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::UnresolvedOperand addr, value;
-  Type type;
+/// Parses the StoreOp type either using the typed or opaque pointer format.
+// TODO: Drop once the typed pointer assembly format is not needed anymore.
+static ParseResult parseStoreType(OpAsmParser &parser, Type &elementType,
+                                  Type &type) {
   SMLoc trailingTypeLoc;
-
-  if (succeeded(parser.parseOptionalKeyword("volatile")))
-    result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
-
-  if (parser.parseOperand(value) || parser.parseComma() ||
-      parser.parseOperand(addr) ||
-      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
-      parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
+  if (parser.getCurrentLocation(&trailingTypeLoc) ||
+      parser.parseType(elementType))
     return failure();
 
-  Type operandType;
-  if (succeeded(parser.parseOptionalComma())) {
-    operandType = type;
-    if (parser.parseType(type))
-      return failure();
-  } else {
-    std::optional<Type> maybeOperandType =
-        getLoadStoreElementType(parser, type, trailingTypeLoc);
-    if (!maybeOperandType)
-      return failure();
-    operandType = *maybeOperandType;
-  }
+  if (succeeded(parser.parseOptionalComma()))
+    return parser.parseType(type);
 
-  if (parser.resolveOperand(value, operandType, result.operands) ||
-      parser.resolveOperand(addr, type, result.operands))
+  // Extract the element type from the pointer type.
+  type = elementType;
+  std::optional<Type> pointerElementType =
+      getLoadStoreElementType(parser, type, trailingTypeLoc);
+  if (!pointerElementType)
     return failure();
-
+  elementType = *pointerElementType;
   return success();
 }
 
+/// Prints the StoreOp type either using the typed or opaque pointer format.
+// TODO: Drop once the typed pointer assembly format is not needed anymore.
+static void printStoreType(OpAsmPrinter &printer, Operation *op,
+                           Type elementType, Type type) {
+  auto pointerType = cast<LLVMPointerType>(type);
+  if (pointerType.isOpaque())
+    printer << elementType << ", ";
+  printer << type;
+}
+
 //===----------------------------------------------------------------------===//
 // CallOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -251,7 +251,7 @@
 ; CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
 ; CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
 define half @extract_element(ptr %vec, i32 %idx) {
-  ; CHECK: %[[V1:.+]] = llvm.load %[[VEC]] : !llvm.ptr -> vector<4xf16>
+  ; CHECK: %[[V1:.+]] = llvm.load %[[VEC]] {{.*}} : !llvm.ptr -> vector<4xf16>
   ; CHECK: %[[V2:.+]] = llvm.extractelement %[[V1]][%[[IDX]] : i32] : vector<4xf16>
   ; CHECK: llvm.return %[[V2]]
   %1 = load <4 x half>, ptr %vec
@@ -266,7 +266,7 @@
 ; CHECK-SAME: %[[VAL:[a-zA-Z0-9]+]]
 ; CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
 define <4 x half> @insert_element(ptr %vec, half %val, i32 %idx) {
-  ; CHECK: %[[V1:.+]] = llvm.load %[[VEC]] : !llvm.ptr -> vector<4xf16>
+  ; CHECK: %[[V1:.+]] = llvm.load %[[VEC]] {{.*}} : !llvm.ptr -> vector<4xf16>
   ; CHECK: %[[V2:.+]] = llvm.insertelement %[[VAL]], %[[V1]][%[[IDX]] : i32] : vector<4xf16>
   ; CHECK: llvm.return %[[V2]]
   %1 = load <4 x half>, ptr %vec
@@ -352,13 +352,20 @@
 ; CHECK-LABEL: @load_store
 ; CHECK-SAME: %[[PTR:[a-zA-Z0-9]+]]
 define void @load_store(ptr %ptr) {
-  ; CHECK: %[[V1:[0-9]+]] = llvm.load %[[PTR]] : !llvm.ptr -> f64
-  ; CHECK: llvm.store %[[V1]], %[[PTR]] : f64, !llvm.ptr
+  ; CHECK: %[[V1:[0-9]+]] = llvm.load %[[PTR]] {alignment = 8 : i64} : !llvm.ptr -> f64
+  ; CHECK: %[[V2:[0-9]+]] = llvm.load volatile %[[PTR]] {alignment = 16 : i64, nontemporal} : !llvm.ptr -> f64
   %1 = load double, ptr %ptr
+  %2 = load volatile double, ptr %ptr, align 16, !nontemporal !0
+
+  ; CHECK: llvm.store %[[V1]], %[[PTR]] {alignment = 8 : i64} : f64, !llvm.ptr
+  ; CHECK: llvm.store volatile %[[V2]], %[[PTR]] {alignment = 16 : i64, nontemporal} : f64, !llvm.ptr
   store double %1, ptr %ptr
+  store volatile double %2, ptr %ptr, align 16, !nontemporal !0
   ret void
 }
 
+!0 = !{i32 1}
+
 ; // -----
 
 ; CHECK-LABEL: @atomic_rmw
diff --git a/mlir/test/Target/LLVMIR/Import/metadata-loop.ll b/mlir/test/Target/LLVMIR/Import/metadata-loop.ll
--- a/mlir/test/Target/LLVMIR/Import/metadata-loop.ll
+++ b/mlir/test/Target/LLVMIR/Import/metadata-loop.ll
@@ -8,13 +8,13 @@
 ; CHECK: }
 
 ; CHECK-LABEL: llvm.func @access_group
 ; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
 define void @access_group(ptr %arg1) {
-  ; CHECK: llvm.load %[[ARG1]] {access_groups = [@__llvm_global_metadata::@[[$GROUP0]], @__llvm_global_metadata::@[[$GROUP1]]]}
+  ; CHECK: llvm.load %[[ARG1]] {access_groups = [@__llvm_global_metadata::@[[$GROUP0]], @__llvm_global_metadata::@[[$GROUP1]]]
   %1 = load i32, ptr %arg1, !llvm.access.group !0
-  ; CHECK: llvm.load %[[ARG1]] {access_groups = [@__llvm_global_metadata::@[[$GROUP2]], @__llvm_global_metadata::@[[$GROUP0]]]}
+  ; CHECK: llvm.load %[[ARG1]] {access_groups = [@__llvm_global_metadata::@[[$GROUP2]], @__llvm_global_metadata::@[[$GROUP0]]]
   %2 = load i32, ptr %arg1, !llvm.access.group !1
-  ; CHECK: llvm.load %[[ARG1]] {access_groups = [@__llvm_global_metadata::@[[$GROUP3]]]}
+  ; CHECK: llvm.load %[[ARG1]] {access_groups = [@__llvm_global_metadata::@[[$GROUP3]]]
   %3 = load i32, ptr %arg1, !llvm.access.group !2
   ret void
 }