diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp @@ -349,10 +349,18 @@ if (!gpu::GPUDialect::isKernel(funcOp)) return failure(); - // TODO(antiagainst): we are dictating the ABI by ourselves here; it should be - // specified outside. SmallVector argABI; for (auto argIndex : llvm::seq(0, funcOp.getNumArguments())) { + // If the ABI is already specified, use it. + auto abiAttr = funcOp.getArgAttrOfType( + argIndex, spirv::getInterfaceVarABIAttrName()); + if (abiAttr) { + argABI.push_back(abiAttr); + continue; + } + // todo(ravishankarm): Use the "default ABI". Remove this in a follow up + // CL. Staging this to make this easy to revert in case of breakages out of + // tree. Optional sc; if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat()) sc = spirv::StorageClass::StorageBuffer;