diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -235,6 +235,19 @@ }]; } +// Memory op base class that also provides code attaching !nontemporal metadata. +class MemoryOpWithAlignmentAndAttributes : MemoryOpWithAlignmentBase { + code setNonTemporalMetadataCode = [{ + if ($nontemporal) { + // Emitted as `!nontemporal !{i32 1}` using the fixed metadata kind ID. + llvm::MDNode *metadata = llvm::MDNode::get( + inst->getContext(), llvm::ConstantAsMetadata::get( + builder.getInt32(1))); + inst->setMetadata(llvm::LLVMContext::MD_nontemporal, metadata); + } + }]; +} + // Memory-related operations. def LLVM_AllocaOp : MemoryOpWithAlignmentBase, @@ -266,52 +279,49 @@ }]; } def LLVM_LoadOp : - MemoryOpWithAlignmentBase, + MemoryOpWithAlignmentAndAttributes, LLVM_OneResultOp<"load">, - Arguments<(ins LLVM_Type:$addr, OptionalAttr<I64Attr>:$alignment)> { + Arguments<(ins LLVM_Type:$addr, + OptionalAttr<I64Attr>:$alignment, + UnitAttr:$volatile_, + UnitAttr:$nontemporal)> { string llvmBuilder = [{ - auto *inst = builder.CreateLoad($addr); - }] # setAlignmentCode # [{ + auto *inst = builder.CreateLoad($addr, $volatile_); + }] # setAlignmentCode # setNonTemporalMetadataCode # [{ $res = inst; }]; let builders = [OpBuilder< - "OpBuilder &b, OperationState &result, Value addr, unsigned alignment = 0", + "OpBuilder &b, OperationState &result, Value addr, " + "unsigned alignment = 0, bool isVolatile = false, " + "bool isNonTemporal = false", [{ auto type = addr.getType().cast<LLVMType>().getPointerElementTy(); - build(b, result, type, addr, alignment); + build(b, result, type, addr, alignment, isVolatile, isNonTemporal); }]>, OpBuilder< "OpBuilder &b, OperationState &result, Type t, Value addr, " - "unsigned alignment = 0", - [{ - if (alignment == 0) - return build(b, result, t, addr, IntegerAttr()); - build(b, result, t, addr, b.getI64IntegerAttr(alignment)); - }]>]; + "unsigned
alignment = 0, bool isVolatile = false, " + "bool isNonTemporal = false">]; let parser = [{ return parseLoadOp(parser, result); }]; let printer = [{ printLoadOp(p, *this); }]; let verifier = alignmentVerifierCode; } def LLVM_StoreOp : - MemoryOpWithAlignmentBase, + MemoryOpWithAlignmentAndAttributes, LLVM_ZeroResultOp<"store">, Arguments<(ins LLVM_Type:$value, LLVM_Type:$addr, - OptionalAttr<I64Attr>:$alignment)> { + OptionalAttr<I64Attr>:$alignment, + UnitAttr:$volatile_, + UnitAttr:$nontemporal)> { string llvmBuilder = [{ - auto *inst = builder.CreateStore($value, $addr); - }] # setAlignmentCode; - let builders = [ - OpBuilder< + auto *inst = builder.CreateStore($value, $addr, $volatile_); + }] # setAlignmentCode # setNonTemporalMetadataCode; + let builders = [OpBuilder< "OpBuilder &b, OperationState &result, Value value, Value addr, " - "unsigned alignment = 0", - [{ - if (alignment == 0) - return build(b, result, ArrayRef<Type>{}, value, addr, IntegerAttr()); - build(b, result, ArrayRef<Type>{}, value, addr, - b.getI64IntegerAttr(alignment)); - }] - >]; + "unsigned alignment = 0, bool isVolatile = false, " + "bool isNonTemporal = false"> + ]; let parser = [{ return parseStoreOp(parser, result); }]; let printer = [{ printStoreOp(p, *this); }]; let verifier = alignmentVerifierCode; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -31,6 +31,9 @@ using namespace mlir; using namespace mlir::LLVM; +static constexpr const char kVolatileAttrName[] = "volatile_"; +static constexpr const char kNonTemporalAttrName[] = "nontemporal"; + #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc" //===----------------------------------------------------------------------===// @@ -178,12 +181,28 @@ } //===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::LoadOp.
+// Builder, printer and parser for LLVM::LoadOp. //===----------------------------------------------------------------------===// +void LoadOp::build(OpBuilder &builder, OperationState &result, Type t, + Value addr, unsigned alignment, bool isVolatile, + bool isNonTemporal) { + result.addOperands(addr); + result.addTypes(t); + if (isVolatile) + result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); + if (isNonTemporal) + result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); + if (alignment != 0) + result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); +} + static void printLoadOp(OpAsmPrinter &p, LoadOp &op) { - p << op.getOperationName() << ' ' << op.addr(); - p.printOptionalAttrDict(op.getAttrs()); + p << op.getOperationName() << ' '; + if (op.volatile_()) + p << "volatile "; + p << op.addr(); + p.printOptionalAttrDict(op.getAttrs(), {kVolatileAttrName}); p << " : " << op.addr().getType(); } @@ -201,12 +220,15 @@ return llvmTy.getPointerElementTy(); } -// <operation> ::= `llvm.load` ssa-use attribute-dict? `:` type +// <operation> ::= `llvm.load` `volatile`? ssa-use attribute-dict? `:` type static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType addr; Type type; llvm::SMLoc trailingTypeLoc; + if (succeeded(parser.parseOptionalKeyword("volatile"))) + result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); + if (parser.parseOperand(addr) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || @@ -220,21 +242,41 @@ } //===----------------------------------------------------------------------===// -// Printing/parsing for LLVM::StoreOp. +// Builder, printer and parser for LLVM::StoreOp.
//===----------------------------------------------------------------------===// +void StoreOp::build(OpBuilder &builder, OperationState &result, Value value, + Value addr, unsigned alignment, bool isVolatile, + bool isNonTemporal) { + result.addOperands({value, addr}); + result.addTypes(ArrayRef<Type>{}); + if (isVolatile) + result.addAttribute(kVolatileAttrName, builder.getUnitAttr()); + if (isNonTemporal) + result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr()); + if (alignment != 0) + result.addAttribute("alignment", builder.getI64IntegerAttr(alignment)); +} + static void printStoreOp(OpAsmPrinter &p, StoreOp &op) { - p << op.getOperationName() << ' ' << op.value() << ", " << op.addr(); - p.printOptionalAttrDict(op.getAttrs()); + p << op.getOperationName() << ' '; + if (op.volatile_()) + p << "volatile "; + p << op.value() << ", " << op.addr(); + p.printOptionalAttrDict(op.getAttrs(), {kVolatileAttrName}); p << " : " << op.addr().getType(); } -// <operation> ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type +// <operation> ::= `llvm.store` `volatile`? ssa-use `,` ssa-use +// attribute-dict?
`:` type static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType addr, value; Type type; llvm::SMLoc trailingTypeLoc; + if (succeeded(parser.parseOptionalKeyword("volatile"))) + result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr()); + if (parser.parseOperand(value) || parser.parseComma() || parser.parseOperand(addr) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir --- a/mlir/test/Target/llvmir.mlir +++ b/mlir/test/Target/llvmir.mlir @@ -1266,3 +1266,32 @@ } // CHECK: ![[NODE]] = !{!"branch_weights", i32 5, i32 10} + +// ----- + +llvm.func @volatile_store_and_load() { + %val = llvm.mlir.constant(5 : i32) : !llvm.i32 + %size = llvm.mlir.constant(1 : i64) : !llvm.i64 + %0 = llvm.alloca %size x !llvm.i32 : (!llvm.i64) -> (!llvm<"i32*">) + // CHECK: store volatile i32 5, i32* %{{.*}} + llvm.store volatile %val, %0 : !llvm<"i32*"> + // CHECK: %{{.*}} = load volatile i32, i32* %{{.*}} + %1 = llvm.load volatile %0 : !llvm<"i32*"> + llvm.return +} + +// ----- + +// Check that nontemporal attribute is exported as metadata node. +llvm.func @nontemporal_store_and_load() { + %val = llvm.mlir.constant(5 : i32) : !llvm.i32 + %size = llvm.mlir.constant(1 : i64) : !llvm.i64 + %0 = llvm.alloca %size x !llvm.i32 : (!llvm.i64) -> (!llvm<"i32*">) + // CHECK: !nontemporal ![[NODE:[0-9]+]] + llvm.store %val, %0 {nontemporal} : !llvm<"i32*"> + // CHECK: !nontemporal ![[NODE]] + %1 = llvm.load %0 {nontemporal} : !llvm<"i32*"> + llvm.return +} + +// CHECK: ![[NODE]] = !{i32 1}