diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -340,16 +340,28 @@
     unsigned alignment = op.getAlignment();
     auto loc = op.getLoc();
 
+    auto srcMemRefType = op.getMemref().getType().cast<MemRefType>();
+    // When we convert to LLVM, the input memref must have been normalized
+    // beforehand. Hence, this call is guaranteed to work.
+    auto [strides, offset] = getStridesAndOffset(srcMemRefType);
+
     MemRefDescriptor memRefDescriptor(memref);
     Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
+    // Skip the GEP if the offset is statically known to be zero.
+    if (offset != 0) {
+      Value offsetVal = ShapedType::isDynamic(offset)
+                            ? memRefDescriptor.offset(rewriter, loc)
+                            : createIndexConstant(rewriter, loc, offset);
+      Type elementType =
+          typeConverter->convertType(srcMemRefType.getElementType());
+      ptr = rewriter.create<LLVM::GEPOp>(loc, ptr.getType(), elementType, ptr,
+                                         offsetVal);
+    }
 
-    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
-    // the asserted memref.alignedPtr isn't used anywhere else, as the real
-    // users like load/store/views always re-extract memref.alignedPtr as they
-    // get lowered.
+    // Emit llvm.assume(memref & (alignment - 1) == 0).
     //
     // This relies on LLVM's CSE optimization (potentially after SROA), since
-    // after CSE all memref.alignedPtr instances get de-duplicated into the same
+    // after CSE all memref instances should get de-duplicated into the same
     // pointer SSA value.
     auto intPtrType =
         getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -130,7 +130,7 @@
 
 // -----
 
-// CHECK-LABEL: func @assume_alignment
+// CHECK-LABEL: func @assume_alignment(
 func.func @assume_alignment(%0 : memref<4x4xf16>) {
   // CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK-NEXT: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
@@ -145,6 +145,22 @@
 
 // -----
 
+// CHECK-LABEL: func @assume_alignment_w_offset
+func.func @assume_alignment_w_offset(%0 : memref<4x4xf16, strided<[?, ?], offset: ?>>) {
+  // CHECK-DAG: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK-DAG: %[[OFFSET:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK-DAG: %[[BUFF_ADDR:.*]] = llvm.getelementptr %[[PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f16
+  // CHECK-DAG: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK-DAG: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : i64
+  // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64
+  // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : i64
+  // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : i64
+  // CHECK-NEXT: "llvm.intr.assume"(%[[CONDITION]]) : (i1) -> ()
+  memref.assume_alignment %0, 16 : memref<4x4xf16, strided<[?, ?], offset: ?>>
+  return
+}
+// -----
+
 // CHECK-LABEL: func @dim_of_unranked
 // CHECK32-LABEL: func @dim_of_unranked
 func.func @dim_of_unranked(%unranked: memref<*xi32>) -> index {