diff --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
--- a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
@@ -89,6 +89,14 @@
   /// Returns the (LLVM) pointer type this descriptor contains.
   LLVM::LLVMPointerType getElementPtrType();
 
+  /// Builds IR for getting the start address of the buffer represented
+  /// by this memref:
+  /// `memref.alignedPtr + memref.offset * sizeof(type.getElementType())`.
+  /// \note there is no setter for this one since it is derived from alignedPtr
+  /// and offset.
+  Value bufferPtr(OpBuilder &builder, Location loc,
+                  LLVMTypeConverter &converter, MemRefType type);
+
   /// Builds IR populating a MemRef descriptor structure from a list of
   /// individual values composing that descriptor, in the following order:
   /// - allocated pointer;
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -199,6 +199,28 @@
       .cast<LLVM::LLVMPointerType>();
 }
 
+Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc,
+                                  LLVMTypeConverter &converter,
+                                  MemRefType type) {
+  // When we convert to LLVM, the input memref must have been normalized
+  // beforehand. Hence, this call is guaranteed to work.
+  auto [strides, offsetCst] = getStridesAndOffset(type);
+
+  Value ptr = alignedPtr(builder, loc);
+  // Skip if offset is zero.
+  if (offsetCst != 0) {
+    Type indexType = converter.getIndexType();
+    Value offsetVal =
+        ShapedType::isDynamic(offsetCst)
+            ? offset(builder, loc)
+            : createIndexAttrConstant(builder, loc, indexType, offsetCst);
+    Type elementType = converter.convertType(type.getElementType());
+    ptr = builder.create<LLVM::GEPOp>(loc, ptr.getType(), elementType, ptr,
+                                      offsetVal);
+  }
+  return ptr;
+}
+
 /// Creates a MemRef descriptor structure from a list of individual values
 /// composing that descriptor, in the following order:
 /// - allocated pointer;
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -72,14 +72,14 @@
   auto [strides, offset] = getStridesAndOffset(type);
 
   MemRefDescriptor memRefDescriptor(memRefDesc);
-  Value base = memRefDescriptor.alignedPtr(rewriter, loc);
+  // Use a canonical representation of the start address so that later
+  // optimizations have a longer sequence of instructions to CSE.
+  // If we don't do that, we would sprinkle the memref.offset at various
+  // positions in the different address computations.
+  Value base =
+      memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type);
 
   Value index;
-  if (offset != 0) // Skip if offset is zero.
-    index = ShapedType::isDynamic(offset)
-                ? memRefDescriptor.offset(rewriter, loc)
-                : createIndexConstant(rewriter, loc, offset);
-
   for (int i = 0, e = indices.size(); i < e; ++i) {
     Value increment = indices[i];
     if (strides[i] != 1) { // Skip if stride is 1.
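
For illustration (value names such as %desc, %aligned, %offset are made up for
this note and are not part of the patch): on a memref with a dynamic offset and
f32 elements, bufferPtr is expected to boil down to the aligned pointer advanced
by the dynamic offset, with the GEP typed on the converted element type, roughly:

  // Illustrative LLVM-dialect IR; the actual values come from the descriptor.
  %aligned = llvm.extractvalue %desc[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  %offset = llvm.extractvalue %desc[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  %buffer = llvm.getelementptr %aligned[%offset] : (!llvm.ptr, i64) -> !llvm.ptr, f32

When the offset is statically 0, the GEP is skipped and the aligned pointer is
returned unchanged.
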
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -332,6 +332,8 @@
     : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
   using ConvertOpToLLVMPattern<
       memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
+  explicit AssumeAlignmentOpLowering(LLVMTypeConverter &converter)
+      : ConvertOpToLLVMPattern(converter) {}
 
   LogicalResult
   matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
@@ -341,28 +343,15 @@
     auto loc = op.getLoc();
     auto srcMemRefType = op.getMemref().getType().cast<MemRefType>();
 
-    // When we convert to LLVM, the input memref must have been normalized
-    // beforehand. Hence, this call is guaranteed to work.
-    auto [strides, offset] = getStridesAndOffset(srcMemRefType);
-
-    MemRefDescriptor memRefDescriptor(memref);
-    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
-    // Skip if offset is zero.
-    if (offset != 0) {
-      Value offsetVal = ShapedType::isDynamic(offset)
-                            ? memRefDescriptor.offset(rewriter, loc)
-                            : createIndexConstant(rewriter, loc, offset);
-      Type elementType =
-          typeConverter->convertType(srcMemRefType.getElementType());
-      ptr = rewriter.create<LLVM::GEPOp>(loc, ptr.getType(), elementType, ptr,
-                                         offsetVal);
-    }
+    Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
+                                     rewriter);
 
     // Emit llvm.assume(memref & (alignment - 1) == 0).
     //
     // This relies on LLVM's CSE optimization (potentially after SROA), since
     // after CSE all memref instances should get de-duplicated into the same
     // pointer SSA value.
+    MemRefDescriptor memRefDescriptor(memref);
     auto intPtrType =
         getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
     Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -668,3 +668,32 @@
   %1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>> into memref<64xf32, strided<[1], offset: ?>>
   return %1 : memref<64xf32, strided<[1], offset: ?>>
 }
+
+// -----
+
+// Check that the address of %arg0 appears with the same value
+// in both the llvm.assume and as base of the load.
+// This is to make sure that later CSEs and alignment propagation
+// will be able to do their job easily.
+
+// CHECK-LABEL: func @load_and_assume(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>,
+// CHECK: %[[DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[ALIGNED_PTR:.*]] = llvm.extractvalue %[[DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[OFFSET:.*]] = llvm.extractvalue %[[DESC]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[BUFF_ADDR:.*]] = llvm.getelementptr %[[ALIGNED_PTR]][%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: %[[INT_TO_PTR:.*]] = llvm.ptrtoint %[[BUFF_ADDR]] : !llvm.ptr to i64
+// CHECK: %[[AND:.*]] = llvm.and %[[INT_TO_PTR]], {{.*}} : i64
+// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[AND]], {{.*}} : i64
+// CHECK: "llvm.intr.assume"(%[[CMP]]) : (i1) -> ()
+// CHECK: %[[LD_ADDR:.*]] = llvm.getelementptr %[[BUFF_ADDR]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: %[[VAL:.*]] = llvm.load %[[LD_ADDR]] : !llvm.ptr -> f32
+// CHECK: return %[[VAL]] : f32
+func.func @load_and_assume(
+    %arg0: memref<?x?xf32, strided<[?, ?], offset: ?>>,
+    %i0: index, %i1: index)
+    -> f32 {
+  memref.assume_alignment %arg0, 16 : memref<?x?xf32, strided<[?, ?], offset: ?>>
+  %2 = memref.load %arg0[%i0, %i1] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+  func.return %2 : f32
+}
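
As a sketch of why sharing the start address matters (the SSA names below are
illustrative, not part of the test): before this change, the assume_alignment
lowering built its own aligned-pointer-plus-offset GEP while getStridedElementPtr
folded the offset into the load's index arithmetic, so the two addresses were
distinct values that nothing forced back together:

  // Illustrative IR for the old lowerings, assuming a dynamic offset.
  %start = llvm.getelementptr %aligned[%offset] : (!llvm.ptr, i64) -> !llvm.ptr, f32
  // ptrtoint/and/icmp on %start feed "llvm.intr.assume"
  %idx = llvm.add %offset, %linear : i64
  %addr = llvm.getelementptr %aligned[%idx] : (!llvm.ptr, i64) -> !llvm.ptr, f32
  %val = llvm.load %addr : !llvm.ptr -> f32

With bufferPtr used by both lowerings, the assume and the load GEP are anchored
on the same %[[BUFF_ADDR]] value, which is exactly what the CHECK lines above
pin down.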