diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -118,20 +118,29 @@
     MemRefDescriptor memrefDescriptor(memref);
     Type llvmI64 = this->typeConverter->convertType(rewriter.getI64Type());
     Type llvm2xI32 = this->typeConverter->convertType(VectorType::get(2, i32));
+    Value c32I64 = rewriter.create<LLVM::ConstantOp>(
+        loc, llvmI64, rewriter.getI64IntegerAttr(32));
     Value resource = rewriter.create<LLVM::UndefOp>(loc, llvm4xI32);
     Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
     Value ptrAsInt = rewriter.create<LLVM::PtrToIntOp>(loc, llvmI64, ptr);
-    Value ptrAsInts =
-        rewriter.create<LLVM::BitcastOp>(loc, llvm2xI32, ptrAsInt);
-    for (int64_t i = 0; i < 2; ++i) {
-      Value idxConst = this->createIndexConstant(rewriter, loc, i);
-      Value part =
-          rewriter.create<LLVM::ExtractElementOp>(loc, ptrAsInts, idxConst);
-      resource = rewriter.create<LLVM::InsertElementOp>(
-          loc, llvm4xI32, resource, part, idxConst);
-    }
+    Value lowHalf = rewriter.create<LLVM::TruncOp>(loc, llvmI32, ptrAsInt);
+    resource = rewriter.create<LLVM::InsertElementOp>(
+        loc, llvm4xI32, resource, lowHalf,
+        this->createIndexConstant(rewriter, loc, 0));
+
+    // Bits 48-63 are used both for the stride of the buffer and (on gfx10) for
+    // enabling swizzling. Prevent the high bits of pointers from accidentally
+    // setting those flags.
+    Value highHalfShifted = rewriter.create<LLVM::TruncOp>(
+        loc, llvmI32, rewriter.create<LLVM::LShrOp>(loc, ptrAsInt, c32I64));
+    Value highHalfTruncated = rewriter.create<LLVM::AndOp>(
+        loc, llvmI32, highHalfShifted,
+        createI32Constant(rewriter, loc, 0x0000ffff));
+    resource = rewriter.create<LLVM::InsertElementOp>(
+        loc, llvm4xI32, resource, highHalfTruncated,
+        this->createIndexConstant(rewriter, loc, 1));
     Value numRecords;
     if (memrefType.hasStaticShape()) {
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -3,11 +3,18 @@
 // CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32
 func.func @gpu_gcn_raw_buffer_load_i32(%buf: memref<64xi32>, %idx: i32) -> i32 {
+  // CHECK: %[[ptr:.*]] = llvm.ptrtoint
+  // CHECK: %[[lowHalf:.*]] = llvm.trunc %[[ptr]] : i64 to i32
+  // CHECK: %[[resource_1:.*]] = llvm.insertelement %[[lowHalf]]
+  // CHECK: %[[highHalfI64:.*]] = llvm.lshr %[[ptr]]
+  // CHECK: %[[highHalfI32:.*]] = llvm.trunc %[[highHalfI64]] : i64 to i32
+  // CHECK: %[[highHalf:.*]] = llvm.and %[[highHalfI32]], %{{.*}} : i32
+  // CHECK: %[[resource_2:.*]] = llvm.insertelement %[[highHalf]], %[[resource_1]]
   // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)
-  // CHECK: llvm.insertelement{{.*}}%[[numRecords]]
+  // CHECK: %[[resource_3:.*]] = llvm.insertelement %[[numRecords]], %[[resource_2]]
   // CHECK: %[[word3:.*]] = llvm.mlir.constant(159744 : i32)
   // RDNA: %[[word3:.*]] = llvm.mlir.constant(822243328 : i32)
-  // CHECK: %[[resource:.*]] = llvm.insertelement{{.*}}%[[word3]]
+  // CHECK: %[[resource:.*]] = llvm.insertelement %[[word3]], %[[resource_3]]
   // CHECK: %[[ret:.*]] = rocdl.raw.buffer.load %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i32
   // CHECK: return %[[ret]]
   %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi32>, i32 -> i32