diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -533,7 +533,7 @@ /// : (!llvm.ptr, i64) -> !llvm.ptr /// `sizeBytes` = llvm.ptrtoint %gep : !llvm.ptr to i64 void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, - ArrayRef dynamicSizes, + ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl &sizes, SmallVectorImpl &strides, diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp --- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp @@ -373,19 +373,19 @@ return failure(); auto loc = allocOp.getLoc(); + auto adaptor = gpu::AllocOpAdaptor(operands, allocOp->getAttrDictionary()); // Get shape of the memref as values: static sizes are constant // values and dynamic sizes are passed to 'alloc' as operands. SmallVector shape; SmallVector strides; Value sizeBytes; - getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, shape, strides, - sizeBytes); + getMemRefDescriptorSizes(loc, memRefType, adaptor.dynamicSizes(), rewriter, + shape, strides, sizeBytes); // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor. Type elementPtrType = this->getElementPtrType(memRefType); - auto adaptor = gpu::AllocOpAdaptor(operands, allocOp->getAttrDictionary()); auto stream = adaptor.asyncDependencies().front(); Value allocatedPtr = allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0); diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -1093,11 +1093,14 @@ } void ConvertToLLVMPattern::getMemRefDescriptorSizes( - Location loc, MemRefType memRefType, ArrayRef dynamicSizes, + Location loc, MemRefType memRefType, ValueRange dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl &sizes, SmallVectorImpl &strides, Value &sizeBytes) const { assert(isConvertibleAndHasIdentityMaps(memRefType) && "layout maps must have been normalized away"); + assert(count(memRefType.getShape(), ShapedType::kDynamicSize) == + static_cast(dynamicSizes.size()) && + "dynamicSizes size doesn't match dynamic sizes count in memref shape"); sizes.reserve(memRefType.getRank()); unsigned dynamicIndex = 0; @@ -4092,8 +4095,7 @@ continue; } if (auto memrefType = operand.getType().dyn_cast()) { - MemRefDescriptor::unpack(builder, loc, llvmOperand, - operand.getType().cast(), + MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType, promotedOperands); continue; } diff --git a/mlir/test/Conversion/GPUCommon/lower-alloc-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-alloc-to-gpu-runtime-calls.mlir --- a/mlir/test/Conversion/GPUCommon/lower-alloc-to-gpu-runtime-calls.mlir +++ b/mlir/test/Conversion/GPUCommon/lower-alloc-to-gpu-runtime-calls.mlir @@ -1,16 +1,19 @@ // RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s module attributes {gpu.container_module} { - func @main() { + // CHECK-LABEL: llvm.func @main + // CHECK-SAME: %[[size:.*]]: i64 + func @main(%size : index) { // CHECK: %[[stream:.*]] = llvm.call @mgpuStreamCreate() %0 = gpu.wait async - // CHECK: %[[size_bytes:.*]] = llvm.ptrtoint + // CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}}[%[[size]]] + // CHECK: %[[size_bytes:.*]] = llvm.ptrtoint %[[gep]] // CHECK: llvm.call @mgpuMemAlloc(%[[size_bytes]], %[[stream]]) - %1, %2 = gpu.alloc async [%0] () : memref<13xf32> + %1, %2 = gpu.alloc async [%0] (%size) : memref // CHECK: %[[float_ptr:.*]] = llvm.extractvalue {{.*}}[0] // CHECK: %[[void_ptr:.*]] = llvm.bitcast %[[float_ptr]] // CHECK: llvm.call @mgpuMemFree(%[[void_ptr]], %[[stream]]) - %3 = gpu.dealloc async [%2] %1 : memref<13xf32> + %3 = gpu.dealloc async [%2] %1 : memref // CHECK: llvm.call @mgpuStreamSynchronize(%[[stream]]) // CHECK: llvm.call @mgpuStreamDestroy(%[[stream]]) gpu.wait [%3]