diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -282,8 +282,11 @@
                    # setAliasAnalysisMetadataCode;
 }
 
-def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2], [],
-    /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1> {
+def LLVM_MemsetOp : LLVM_ZeroResultIntrOp<"memset", [0, 2],
+    [DeclareOpInterfaceMethods<PromotableMemOpInterface>,
+     DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
+     DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>],
+    /*requiresAccessGroup=*/1, /*requiresAliasAnalysis=*/1> {
   dag args = (ins Arg<LLVM_AnyPointer,"",[MemWrite]>:$dst,
                   I8:$val, AnySignlessInteger:$len, I1Attr:$isVolatile);
   // Append the alias attributes defined by LLVM_IntrOpBase.
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -103,16 +103,32 @@
       "bool", "loadsFrom",
       (ins "const ::mlir::MemorySlot &":$slot)
     >,
+    InterfaceMethod<[{
+        Gets whether this operation stores to the specified slot.
+
+        No IR mutation is allowed in this method.
+      }],
+      "bool", "storesTo",
+      (ins "const ::mlir::MemorySlot &":$slot)
+    >,
     InterfaceMethod<[{
         Gets the value stored to the provided memory slot, or returns a null
         value if this operation does not store to this slot. An operation
         storing a value to a slot must always be able to provide the value it
-        stores. This method is only called on operations that use the slot.
-
-        No IR mutation is allowed in this method.
+        stores. This method is only called once per slot promotion, and only
+        on operations that store to the slot according to the `storesTo` method.
+        The returned value must dominate all operations dominated by the storing
+        operation.
+
+        If IR must be mutated to extract a concrete value being stored, mutation
+        must happen through the provided rewriter. The rewriter is located
+        immediately after the memory operation on call. No IR deletion is
+        allowed in this method. IR mutations must not introduce new uses of the
+        memory slot. Existing control flow must not be modified.
       }],
       "::mlir::Value", "getStored",
-      (ins "const ::mlir::MemorySlot &":$slot)
+      (ins "const ::mlir::MemorySlot &":$slot,
+           "::mlir::RewriterBase &":$rewriter)
     >,
     InterfaceMethod<[{
        Checks that this operation can be promoted to no longer use the provided
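Note (illustration only, not part of the patch): with the interface split, a promotable memory operation first answers a pure query and only materializes its stored value when promotion actually happens. For a hypothetical op that simply stores its value operand to an address, the pair would look like the sketch below, mirroring the LLVM::StoreOp implementation further down; the rewriter is already positioned right after the operation, IR may be created but not deleted, and no new use of the slot pointer may be introduced.

bool MyStoreOp::storesTo(const MemorySlot &slot) {
  // Pure query: no IR mutation is allowed here.
  return getAddr() == slot.ptr;
}

Value MyStoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
  // Only called when storesTo(slot) returned true, so a value must always be
  // produced; any IR needed to build it is created through `rewriter`.
  return getValue();
}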
}], "::mlir::Value", "getStored", - (ins "const ::mlir::MemorySlot &":$slot) + (ins "const ::mlir::MemorySlot &":$slot, + "::mlir::RewriterBase &":$rewriter) >, InterfaceMethod<[{ Checks that this operation can be promoted to no longer use the provided diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -13,10 +13,13 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; @@ -119,12 +122,20 @@ return getAddr() == slot.ptr; } -Value LLVM::LoadOp::getStored(const MemorySlot &slot) { return {}; } +bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; } + +Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) { + llvm_unreachable("getStored should not be called on LoadOp"); +} bool LLVM::StoreOp::loadsFrom(const MemorySlot &slot) { return false; } -Value LLVM::StoreOp::getStored(const MemorySlot &slot) { - return getAddr() == slot.ptr ? getValue() : Value(); +bool LLVM::StoreOp::storesTo(const MemorySlot &slot) { + return getAddr() == slot.ptr; +} + +Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) { + return getValue(); } bool LLVM::LoadOp::canUsesBeRemoved( @@ -310,7 +321,8 @@ for (IntegerAttr index : llvm::drop_begin(indices)) { // Ensure the structure of the type being indexed can be reasoned about. // This includes rejecting any potential typed pointer. - auto destructurable = llvm::dyn_cast(selectedType); + auto destructurable = + llvm::dyn_cast(selectedType); if (!destructurable) return {}; @@ -369,7 +381,8 @@ DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot, DenseMap &subslots, RewriterBase &rewriter) { - IntegerAttr firstLevelIndex = llvm::dyn_cast_if_present(getIndices()[1]); + IntegerAttr firstLevelIndex = + llvm::dyn_cast_if_present(getIndices()[1]); const MemorySlot &newSlot = subslots.at(firstLevelIndex); ArrayRef remainingIndices = getRawConstantIndices().slice(2); @@ -400,6 +413,180 @@ return DeletionKind::Keep; } +//===----------------------------------------------------------------------===// +// Interfaces for memset +//===----------------------------------------------------------------------===// + +/// Returns the length of the given memset in bytes if it can be known at +/// compile-time on a best-effort basis, nothing otherwise. +static std::optional getStaticMemsetLen(LLVM::MemsetOp op) { + IntegerAttr::ValueType memsetLen; + if (!matchPattern(op.getLen(), m_ConstantInt(&memsetLen))) + return {}; + + if (memsetLen.getBitWidth() > 64) + return {}; + + return memsetLen.getZExtValue(); +} + +/// Returns whether one can be sure the memset does not write outside of the +/// bounds of the given slot, on a best-effort basis. 
+/// Returns whether one can be sure the memset does not write outside of the
+/// bounds of the given slot, on a best-effort basis.
+static bool definitelyWritesOnlyWithinSlot(LLVM::MemsetOp op,
+                                           const MemorySlot &slot,
+                                           DataLayout &dataLayout) {
+  if (!isa<LLVM::LLVMPointerType>(slot.ptr.getType()) ||
+      op.getDst() != slot.ptr)
+    return false;
+
+  std::optional<uint64_t> memsetLen = getStaticMemsetLen(op);
+  return memsetLen && *memsetLen <= dataLayout.getTypeSize(slot.elemType);
+}
+
+bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) { return false; }
+
+bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
+  return getDst() == slot.ptr;
+}
+
+Value LLVM::MemsetOp::getStored(const MemorySlot &slot,
+                                RewriterBase &rewriter) {
+  // TODO: Support non-integer types.
+  return TypeSwitch<Type, Value>(slot.elemType)
+      .Case([&](IntegerType intType) -> Value {
+        if (intType.getWidth() == 8)
+          return getVal();
+
+        assert(intType.getWidth() % 8 == 0);
+
+        // Build the memset integer by repeatedly shifting the value and or-ing
+        // it with the previous value.
+        uint64_t coveredBits = 8;
+        Value currentValue =
+            rewriter.create<LLVM::ZExtOp>(getLoc(), intType, getVal());
+        while (coveredBits < intType.getWidth()) {
+          Value shiftBy =
+              rewriter.create<LLVM::ConstantOp>(getLoc(), intType, coveredBits);
+          Value shifted =
+              rewriter.create<LLVM::ShlOp>(getLoc(), currentValue, shiftBy);
+          currentValue =
+              rewriter.create<LLVM::OrOp>(getLoc(), currentValue, shifted);
+          coveredBits *= 2;
+        }
+
+        return currentValue;
+      })
+      .Default([](Type) -> Value {
+        llvm_unreachable(
+            "getStored should not be called on memset to unsupported type");
+      });
+}
+
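Note (sketch only, with hypothetical value names): for a 4-byte memset of a value %v promoted into an i32 slot, the loop above materializes IR along the following lines, doubling the covered width at each step; this matches the CHECK lines of the mem2reg tests added further down.

%b8  = llvm.zext %v : i8 to i32      // value replicated over bits 0..7
%c8  = llvm.mlir.constant(8 : i32) : i32
%s8  = llvm.shl %b8, %c8 : i32
%b16 = llvm.or %b8, %s8 : i32        // bits 0..15
%c16 = llvm.mlir.constant(16 : i32) : i32
%s16 = llvm.shl %b16, %c16 : i32
%b32 = llvm.or %b16, %s16 : i32      // bits 0..31, the definition the loads will use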
+bool LLVM::MemsetOp::canUsesBeRemoved(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+  // TODO: Support non-integer types.
+  bool canConvertType =
+      TypeSwitch<Type, bool>(slot.elemType)
+          .Case([](IntegerType intType) {
+            return intType.getWidth() % 8 == 0 && intType.getWidth() > 0;
+          })
+          .Default([](Type) { return false; });
+  if (!canConvertType)
+    return false;
+
+  if (getIsVolatile())
+    return false;
+
+  DataLayout layout = DataLayout::closest(*this);
+  return getStaticMemsetLen(*this) == layout.getTypeSize(slot.elemType);
+}
+
+DeletionKind LLVM::MemsetOp::removeBlockingUses(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    RewriterBase &rewriter, Value reachingDefinition) {
+  return DeletionKind::Delete;
+}
+
+LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
+    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+  DataLayout dataLayout = DataLayout::closest(*this);
+  return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
+}
+
+bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
+                               SmallPtrSetImpl<Attribute> &usedIndices,
+                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+  if (&slot.elemType.getDialect() != getOperation()->getDialect())
+    return false;
+
+  if (getIsVolatile())
+    return false;
+
+  if (!slot.elemType.cast<DestructurableTypeInterface>()
+           .getSubelementIndexMap())
+    return false;
+
+  DataLayout dataLayout = DataLayout::closest(*this);
+  return definitelyWritesOnlyWithinSlot(*this, slot, dataLayout);
+}
+
+DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
+                                    DenseMap<Attribute, MemorySlot> &subslots,
+                                    RewriterBase &rewriter) {
+  std::optional<DenseMap<Attribute, Type>> types =
+      slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
+
+  IntegerAttr memsetLenAttr;
+  bool successfulMatch =
+      matchPattern(getLen(), m_Constant<IntegerAttr>(&memsetLenAttr));
+  (void)successfulMatch;
+  assert(successfulMatch);
+
+  bool packed = false;
+  if (auto structType = dyn_cast<LLVM::LLVMStructType>(slot.elemType))
+    packed = structType.isPacked();
+
+  Type i32 = IntegerType::get(getContext(), 32);
+  DataLayout dataLayout = DataLayout::closest(*this);
+  uint64_t memsetLen = memsetLenAttr.getValue().getZExtValue();
+  uint64_t covered = 0;
+  for (size_t i = 0; i < types->size(); i++) {
+    // Create indices on the fly to get elements in the right order.
+    Attribute index = IntegerAttr::get(i32, i);
+    Type elemType = types->at(index);
+    uint64_t typeSize = dataLayout.getTypeSize(elemType);
+
+    if (!packed)
+      covered =
+          llvm::alignTo(covered, dataLayout.getTypeABIAlignment(elemType));
+
+    if (covered >= memsetLen)
+      break;
+
+    // If this subslot is used, apply a new memset to it.
+    // Otherwise, only compute its offset within the original memset.
+    if (subslots.contains(index)) {
+      uint64_t newMemsetSize = std::min(memsetLen - covered, typeSize);
+
+      Value newMemsetSizeValue =
+          rewriter
+              .create<LLVM::ConstantOp>(
+                  getLen().getLoc(),
+                  IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize))
+              .getResult();
+
+      rewriter.create<LLVM::MemsetOp>(getLoc(), subslots.at(index).ptr,
+                                      getVal(), newMemsetSizeValue,
+                                      getIsVolatile());
+    }
+
+    covered += typeSize;
+  }
+
+  return DeletionKind::Delete;
+}
+
 //===----------------------------------------------------------------------===//
 // Interfaces for destructurable types
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
 
 using namespace mlir;
 
@@ -160,7 +161,12 @@
   return getMemRef() == slot.ptr;
 }
 
-Value memref::LoadOp::getStored(const MemorySlot &slot) { return {}; }
+bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
+
+Value memref::LoadOp::getStored(const MemorySlot &slot,
+                                RewriterBase &rewriter) {
+  llvm_unreachable("getStored should not be called on LoadOp");
+}
 
 bool memref::LoadOp::canUsesBeRemoved(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
@@ -222,9 +228,12 @@
 
 bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
 
-Value memref::StoreOp::getStored(const MemorySlot &slot) {
-  if (getMemRef() != slot.ptr)
-    return {};
+bool memref::StoreOp::storesTo(const MemorySlot &slot) {
+  return getMemRef() == slot.ptr;
+}
+
+Value memref::StoreOp::getStored(const MemorySlot &slot,
+                                 RewriterBase &rewriter) {
   return getValue();
 }
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -172,12 +172,13 @@
   /// Computes the reaching definition for all the operations that require
   /// promotion. `reachingDef` is the value the slot should contain at the
   /// beginning of the block. This method returns the reached definition at the
-  /// end of the block.
+  /// end of the block. This method must only be called at most once per block.
   Value computeReachingDefInBlock(Block *block, Value reachingDef);
 
   /// Computes the reaching definition for all the operations that require
   /// promotion. `reachingDef` corresponds to the initial value the
   /// slot will contain before any write, typically a poison value.
+  /// This method must only be called at most once per region.
   void computeReachingDefInRegion(Region *region, Value reachingDef);
 
   /// Removes the blocking uses of the slot, in topological order.
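Note: the "at most once" wording matters now that getStored is allowed to create IR; asking the same store twice would materialize its definition (for instance the memset splat above) a second time. A sketch of the call pattern this implies, mirroring the computeReachingDefInBlock change further down and using only calls the patch itself uses:

// Snapshot the block's operations first, since getStored may insert new
// operations through the rewriter while the block is being walked.
SmallVector<Operation *> blockOps;
for (Operation &op : block->getOperations())
  blockOps.push_back(&op);
for (Operation *op : blockOps) {
  if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
    if (memOp.storesTo(slot)) {
      rewriter.setInsertionPointAfter(memOp);
      // Each storing operation is asked for its definition exactly once.
      reachingDef = memOp.getStored(slot, rewriter);
    }
  }
}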
@@ -326,7 +327,7 @@
 
       // If we store to the slot, further loads will see that value.
       // Because we did not meet any load before, the value is not live-in.
-      if (memOp.getStored(slot))
+      if (memOp.storesTo(slot))
         break;
     }
   }
@@ -365,7 +366,7 @@
   SmallPtrSet<Block *, 16> definingBlocks;
   for (Operation *user : slot.ptr.getUsers())
     if (auto storeOp = dyn_cast<PromotableMemOpInterface>(user))
-      if (storeOp.getStored(slot))
+      if (storeOp.storesTo(slot))
         definingBlocks.insert(user->getBlock());
 
   idfCalculator.setDefiningBlocks(definingBlocks);
@@ -416,13 +417,21 @@
 
 Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
                                                     Value reachingDef) {
-  for (Operation &op : block->getOperations()) {
+  SmallVector<Operation *> blockOps;
+  for (Operation &op : block->getOperations())
+    blockOps.push_back(&op);
+  for (Operation *op : blockOps) {
     if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
       if (info.userToBlockingUses.contains(memOp))
         reachingDefs.insert({memOp, reachingDef});
 
-      if (Value stored = memOp.getStored(slot))
+      if (memOp.storesTo(slot)) {
+        rewriter.setInsertionPointAfter(memOp);
+        Value stored = memOp.getStored(slot, rewriter);
+        assert(stored && "a memory operation storing to a slot must provide a "
+                         "new definition of the slot");
         reachingDef = stored;
+      }
     }
   }
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
@@ -0,0 +1,145 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg{region-simplify=false}))" --split-input-file | FileCheck %s
+
+// CHECK-LABEL: llvm.func @basic_memset
+llvm.func @basic_memset() -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  %memset_len = llvm.mlir.constant(4 : i32) : i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  // CHECK-NOT: "llvm.intr.memset"
+  // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
+  // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
+  // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+  // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
+  // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+  // CHECK-NOT: "llvm.intr.memset"
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  // CHECK: llvm.return %[[VALUE_32]] : i32
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @allow_dynamic_value_memset
+// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
+llvm.func @allow_dynamic_value_memset(%memset_value: i8) -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_len = llvm.mlir.constant(4 : i32) : i32
+  // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  // CHECK-NOT: "llvm.intr.memset"
+  // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
+  // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
+  // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+  // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
+  // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+  // CHECK-NOT: "llvm.intr.memset"
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  // CHECK: llvm.return %[[VALUE_32]] : i32
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @exotic_target_memset
+llvm.func @exotic_target_memset() -> i40 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i40 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  %memset_len = llvm.mlir.constant(5 : i32) : i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
+  // CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
+  // CHECK-DAG: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  // CHECK-NOT: "llvm.intr.memset"
+  // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i40
+  // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
+  // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+  // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
+  // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+  // CHECK: %[[SHIFTED_COMPL:.*]] = llvm.shl %[[VALUE_32]], %[[C32]]
+  // CHECK: %[[VALUE_COMPL:.*]] = llvm.or %[[VALUE_32]], %[[SHIFTED_COMPL]]
+  // CHECK-NOT: "llvm.intr.memset"
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i40
+  // CHECK: llvm.return %[[VALUE_COMPL]] : i40
+  llvm.return %2 : i40
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @no_volatile_memset
+llvm.func @no_volatile_memset() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  %memset_len = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = true}>
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = true}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @no_partial_memset
+llvm.func @no_partial_memset() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(2 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  %memset_len = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @no_overflowing_memset
+llvm.func @no_overflowing_memset() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(6 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  %memset_len = llvm.mlir.constant(6 : i32) : i32
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @only_byte_aligned_integers_memset
+llvm.func @only_byte_aligned_integers_memset() -> i10 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i10
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(2 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i10 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  %memset_len = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i10
+  llvm.return %2 : i10
+}
diff --git a/mlir/test/Dialect/LLVMIR/sroa-intrinsics.mlir b/mlir/test/Dialect/LLVMIR/sroa-intrinsics.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/sroa-intrinsics.mlir
@@ -0,0 +1,237 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(sroa))" --split-input-file | FileCheck %s
+
+// CHECK-LABEL: llvm.func @memset
+llvm.func @memset() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // After SROA, only one i32 will actually be used, so only 4 bytes will be set.
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // 16 bytes means it will span over the first 4 i32 entries
+  %memset_len = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @memset_partial
+llvm.func @memset_partial() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // After SROA, only the second i32 will actually be used. As the memset writes up
+  // to half of it, only 2 bytes will be set.
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(2 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // 6 bytes means it will span over the first i32 and half of the second i32.
+  %memset_len = llvm.mlir.constant(6 : i32) : i32
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @memset_full
+llvm.func @memset_full() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // After SROA, only one i32 will actually be used, so only 4 bytes will be set.
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // 40 bytes means it will span over the entire array
+  %memset_len = llvm.mlir.constant(40 : i32) : i32
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @memset_too_much
+llvm.func @memset_too_much() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x !llvm.array<10 x i32>
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(41 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // 41 bytes means it will span over the entire array, and then some
+  %memset_len = llvm.mlir.constant(41 : i32) : i32
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @memset_no_volatile
+llvm.func @memset_no_volatile() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x !llvm.array<10 x i32>
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(16 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  %memset_len = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = true}>
"llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = true}> + "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = true}> : (!llvm.ptr, i8, i32) -> () + %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32> + %3 = llvm.load %2 : !llvm.ptr -> i32 + llvm.return %3 : i32 +} + +// ----- + +// CHECK-LABEL: llvm.func @indirect_memset +llvm.func @indirect_memset() -> i32 { + // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32 + // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8 + // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32 + %0 = llvm.mlir.constant(1 : i32) : i32 + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32)> : (i32) -> !llvm.ptr + %memset_value = llvm.mlir.constant(42 : i8) : i8 + // This memset will only cover the selected element. + %memset_len = llvm.mlir.constant(4 : i32) : i32 + %2 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)> + // CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}> + "llvm.intr.memset"(%2, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> () + %3 = llvm.load %2 : !llvm.ptr -> i32 + llvm.return %3 : i32 +} + +// ----- + +// CHECK-LABEL: llvm.func @invalid_indirect_memset +llvm.func @invalid_indirect_memset() -> i32 { + // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x !llvm.struct<"foo", (i32, i32)> + // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8 + // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(6 : i32) : i32 + %0 = llvm.mlir.constant(1 : i32) : i32 + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32)> : (i32) -> !llvm.ptr + %memset_value = llvm.mlir.constant(42 : i8) : i8 + // This memset will go slightly beyond one of the elements. + %memset_len = llvm.mlir.constant(6 : i32) : i32 + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] + %2 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)> + // CHECK: "llvm.intr.memset"(%[[GEP]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}> + "llvm.intr.memset"(%2, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> () + %3 = llvm.load %2 : !llvm.ptr -> i32 + llvm.return %3 : i32 +} + +// ----- + +// CHECK-LABEL: llvm.func @memset_double_use +llvm.func @memset_double_use() -> i32 { + // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-DAG: %[[ALLOCA_INT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32 + // CHECK-DAG: %[[ALLOCA_FLOAT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x f32 + // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8 + // After SROA, only one i32 will be actually used, so only 4 bytes will be set. + // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32 + %0 = llvm.mlir.constant(1 : i32) : i32 + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr + %memset_value = llvm.mlir.constant(42 : i8) : i8 + // 8 bytes means it will span over the two i32 entries. + %memset_len = llvm.mlir.constant(8 : i32) : i32 + // We expect two generated memset, one for each field. 
+  // CHECK-NOT: "llvm.intr.memset"
+  // CHECK-DAG: "llvm.intr.memset"(%[[ALLOCA_INT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  // CHECK-DAG: "llvm.intr.memset"(%[[ALLOCA_FLOAT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  // CHECK-NOT: "llvm.intr.memset"
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f32)>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  %4 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f32)>
+  %5 = llvm.load %4 : !llvm.ptr -> f32
+  // We use this exotic bitcast to use the f32 easily. Semantics do not matter here.
+  %6 = llvm.bitcast %5 : f32 to i32
+  %7 = llvm.add %3, %6 : i32
+  llvm.return %7 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @memset_considers_alignment
+llvm.func @memset_considers_alignment() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA_INT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // After SROA, only 32-bit fields will actually be used, so only 4 bytes will be set.
+  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i8, i32, f32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // 8 bytes means it will span over the i8 and the i32 entry.
+  // Because of padding, the f32 entry will not be touched.
+  %memset_len = llvm.mlir.constant(8 : i32) : i32
+  // Even though both 32-bit fields are used, only one memset should be generated,
+  // as the f32 field is not touched by the initial memset.
+  // CHECK-NOT: "llvm.intr.memset"
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA_INT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  // CHECK-NOT: "llvm.intr.memset"
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i8, i32, f32)>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  %4 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i8, i32, f32)>
+  %5 = llvm.load %4 : !llvm.ptr -> f32
+  // We use this exotic bitcast to use the f32 easily. Semantics do not matter here.
+  %6 = llvm.bitcast %5 : f32 to i32
+  %7 = llvm.add %3, %6 : i32
+  llvm.return %7 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @memset_considers_packing
+llvm.func @memset_considers_packing() -> i32 {
+  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-DAG: %[[ALLOCA_INT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK-DAG: %[[ALLOCA_FLOAT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x f32
+  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
+  // After SROA, only the 32-bit fields will actually be used.
+  // CHECK-DAG: %[[MEMSET_LEN_WHOLE:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK-DAG: %[[MEMSET_LEN_PARTIAL:.*]] = llvm.mlir.constant(3 : i32) : i32
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", packed (i8, i32, f32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %memset_value = llvm.mlir.constant(42 : i8) : i8
+  // 8 bytes means it will span over all the fields, because there is no padding as the struct is packed.
+  %memset_len = llvm.mlir.constant(8 : i32) : i32
+  // Now all fields are touched by the memset.
+  // CHECK-NOT: "llvm.intr.memset"
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA_INT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN_WHOLE]]) <{isVolatile = false}>
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA_FLOAT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN_PARTIAL]]) <{isVolatile = false}>
+  // CHECK-NOT: "llvm.intr.memset"
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i8, i32, f32)>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  %4 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i8, i32, f32)>
+  %5 = llvm.load %4 : !llvm.ptr -> f32
+  // We use this exotic bitcast to use the f32 easily. Semantics do not matter here.
+  %6 = llvm.bitcast %5 : f32 to i32
+  %7 = llvm.add %3, %6 : i32
+  llvm.return %7 : i32
+}
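Note (illustration only, not a test from this patch): the two passes compose. Assuming a pipeline along the lines of "builtin.module(llvm.func(sroa,mem2reg))", a memset over a struct is first split into one memset per surviving field by SROA, and each per-field memset is then folded into its splat value by mem2reg, so the function below is expected to reduce to returning the i32 splat of 42 with the alloca and all memsets removed.

llvm.func @sroa_then_mem2reg() -> i32 {
  %0 = llvm.mlir.constant(1 : i32) : i32
  %1 = llvm.alloca %0 x !llvm.struct<"pair", (i32, i32)> : (i32) -> !llvm.ptr
  %val = llvm.mlir.constant(42 : i8) : i8
  %len = llvm.mlir.constant(8 : i32) : i32
  "llvm.intr.memset"(%1, %val, %len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"pair", (i32, i32)>
  %3 = llvm.load %2 : !llvm.ptr -> i32
  llvm.return %3 : i32
}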