// NOTE(review): this copy of the patch has lost its line breaks and every
// <...> template/type argument (e.g. `SmallVector flatArgValues;`,
// `builder.create(...)`, `!llvm.ptr`), so it cannot be applied or compiled
// as-is; restore it from the original change before use.
//
// Summary: kernel operands of gpu.launch_func are no longer routed through
// typeConverter.promoteOperands; the trailing numKernelOperands values of
// `operands` are used directly and flattened:
//   - non-struct operands are passed through unchanged;
//   - each non-array field j of a struct operand is extracted (presumably via
//     llvm.extractvalue) with index attribute [j] and must be an integer,
//     float, double or pointer (asserted);
//   - each element k of an array field j is extracted with indices [j, k].
// The flattened values are stored into one alloca'd struct whose field types
// are the flattened types, and each field's address (converted to
// llvmPointerType, presumably via a bitcast) is stored into an alloca'd array
// of flatArgTypes.size() pointers.
//
// NOTE(review): every struct-typed operand is flattened, not only memref
// descriptors; confirm the kernel-side lowering expects struct parameters
// expanded the same way. flatArgValues.reserve(arguments.size()) is only a
// lower bound once struct operands are expanded.
//
// The test now passes a float, a pointer and a memref-descriptor-shaped
// struct, and checks the alloca'd struct type and the mgpuLaunchKernel call.
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp --- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp @@ -257,35 +257,67 @@ OpBuilder &builder) const { auto loc = launchOp.getLoc(); auto numKernelOperands = launchOp.getNumKernelOperands(); - auto arguments = typeConverter.promoteOperands( - loc, launchOp.getOperands().take_back(numKernelOperands), - operands.take_back(numKernelOperands), builder); - auto numArguments = arguments.size(); - SmallVector argumentTypes; - argumentTypes.reserve(numArguments); - for (auto argument : arguments) - argumentTypes.push_back(argument.getType().cast()); - auto structType = LLVM::LLVMType::createStructTy(argumentTypes, StringRef()); - auto one = builder.create(loc, llvmInt32Type, - builder.getI32IntegerAttr(1)); - auto structPtr = builder.create( + auto arguments = operands.take_back(numKernelOperands); + + // Flatten struct-typed args (e.g. memref descriptors) into scalar values. 
+ SmallVector flatArgValues; + flatArgValues.reserve(arguments.size()); + for (auto argument : arguments) { + auto llvmArgType = argument.getType().cast(); + if (!llvmArgType.isStructTy()) { + flatArgValues.push_back(argument); + continue; + } + for (int32_t j = 0, ej = llvmArgType.getStructNumElements(); j < ej; ++j) { + auto elemType = llvmArgType.getStructElementType(j); + if (elemType.isArrayTy()) { + auto arrayElemType = elemType.getArrayElementType(); + for (int32_t k = 0, ek = elemType.getArrayNumElements(); k < ek; ++k) { + Value elem = builder.create( + loc, arrayElemType, argument, builder.getI32ArrayAttr({j, k})); + flatArgValues.push_back(elem); + } + continue; + } + assert((elemType.isIntegerTy() || elemType.isFloatTy() || + elemType.isDoubleTy() || elemType.isPointerTy()) && + "expected scalar type"); + Value elem = builder.create( + loc, elemType, argument, builder.getI32ArrayAttr(j)); + flatArgValues.push_back(elem); + } + } + + // Collect each flattened arg's LLVM type; these become the struct fields. + SmallVector flatArgTypes; + flatArgTypes.reserve(flatArgValues.size()); + for (auto argument : flatArgValues) { + auto llvmArgType = argument.getType().cast(); + flatArgTypes.push_back(llvmArgType); + } + + // Spill each arg to a stack struct and put its field address in an array. 
+ auto structType = LLVM::LLVMType::createStructTy(flatArgTypes, StringRef()); + Value one = builder.create(loc, llvmInt32Type, + builder.getI32IntegerAttr(1)); + Value structPtr = builder.create( loc, structType.getPointerTo(), one, /*alignment=*/0); - auto arraySize = builder.create( - loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments)); - auto arrayPtr = builder.create(loc, llvmPointerPointerType, - arraySize, /*alignment=*/0); - auto zero = builder.create(loc, llvmInt32Type, - builder.getI32IntegerAttr(0)); - for (auto en : llvm::enumerate(arguments)) { - auto index = builder.create( + Value arraySize = builder.create( + loc, llvmInt32Type, builder.getI32IntegerAttr(flatArgTypes.size())); + Value arrayPtr = builder.create(loc, llvmPointerPointerType, + arraySize, /*alignment=*/0); + Value zero = builder.create(loc, llvmInt32Type, + builder.getI32IntegerAttr(0)); + for (auto &en : llvm::enumerate(flatArgValues)) { + Value index = builder.create( loc, llvmInt32Type, builder.getI32IntegerAttr(en.index())); - auto fieldPtr = + Value fieldPtr = builder.create(loc, structType.getPointerTo(), structPtr, - ArrayRef{zero, index.getResult()}); + ArrayRef{zero, index}); builder.create(loc, en.value(), fieldPtr); - auto elementPtr = builder.create(loc, llvmPointerPointerType, - arrayPtr, index.getResult()); - auto casted = + Value elementPtr = builder.create(loc, llvmPointerPointerType, + arrayPtr, index); + Value casted = builder.create(loc, llvmPointerType, fieldPtr); builder.create(loc, casted, elementPtr); } diff --git a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir --- a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir +++ b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir @@ -13,23 +13,36 @@ } } - llvm.func @foo() { - %0 = "op"() : () -> (!llvm.float) - %1 = "op"() : () -> (!llvm.ptr) - %cst = 
llvm.mlir.constant(8 : index) : !llvm.i64 - + llvm.func @foo(%f: !llvm.float, %f_ptr: !llvm.ptr, + %memref : !llvm.struct< + (ptr, ptr, i64, array<1 x i64>, array<1 x i64>) + >) { // CHECK: %[[addressof:.*]] = llvm.mlir.addressof @[[global]] // CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) // CHECK: %[[binary:.*]] = llvm.getelementptr %[[addressof]][%[[c0]], %[[c0]]] // CHECK-SAME: -> !llvm.ptr + // CHECK: %[[module:.*]] = llvm.call @mgpuModuleLoad(%[[binary]]) : (!llvm.ptr) -> !llvm.ptr // CHECK: %[[func:.*]] = llvm.call @mgpuModuleGetFunction(%[[module]], {{.*}}) : (!llvm.ptr, !llvm.ptr) -> !llvm.ptr // CHECK: llvm.call @mgpuStreamCreate - // CHECK: llvm.call @mgpuLaunchKernel - // CHECK: llvm.call @mgpuStreamSynchronize - "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel = @kernel_module::@kernel } - : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.float, !llvm.ptr) -> () + // CHECK: llvm.alloca {{.*}} -> !llvm.ptr, ptr, ptr, i64, i64, i64)>> + + // CHECK: llvm.call @mgpuLaunchKernel({{.*}}) : (!llvm.ptr, !llvm.i64, + // CHECK-SAME: !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, + // CHECK-SAME: !llvm.i32, !llvm.ptr, !llvm.ptr>, + // CHECK-SAME: !llvm.ptr>) -> !llvm.void + + // CHECK: llvm.call @mgpuStreamSynchronize + %cst = llvm.mlir.constant(8 : index) : !llvm.i64 + "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %f, %f_ptr, %memref) { + kernel = @kernel_module::@kernel + } : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, + !llvm.float, !llvm.ptr, + !llvm.struct< + (ptr, ptr, i64, array<1 x i64>, array<1 x i64>) + >) -> () llvm.return }