diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -401,22 +401,22 @@
   ss << "=r,";
   for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
     ss << "r,";
-  // The final two operands are for the sparsity metadata and sparsity selector.
-  ss << "r,r";
+  // The final operand is for the sparsity metadata.
+  // The sparsity selector appears as a direct literal.
+  ss << "r";
   ss.flush();
   return str;
 }
 
 /// Returns the string for the `mma.sp.sync` instruction that corresponds to
-/// the give parameters. Note that this function doesn't do any validation,
+/// the given parameters. Note that this function doesn't do any validation,
 /// it's expected that the provided parameters correspond to a valid
 /// instruction.
-static std::string
-buildMmaSparseAsmString(const std::array<int64_t, 3> &shape, unsigned matASize,
-                        unsigned matBSize, unsigned matCSize,
-                        NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
-                        NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
-                        std::optional<NVVM::MMAIntOverflow> overflow) {
+static std::string buildMmaSparseAsmString(
+    const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
+    unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
+    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
+    std::optional<NVVM::MMAIntOverflow> overflow, unsigned metadataSelector) {
   auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
     return NVVM::stringifyMMATypes(ptxType);
   };
@@ -442,7 +442,8 @@
     ss << "},";
   }
   ss << "$" << asmArgIdx++ << ",";
-  ss << "$" << asmArgIdx++ << ";";
+  assert(metadataSelector <= 1 && "sparsity selector must be 0 or 1");
+  ss << "0x" << metadataSelector << ";";
   ss.flush();
   return asmStr;
 }
@@ -459,22 +460,21 @@
   auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                   LLVM::AsmDialect::AD_ATT);
 
-  std::string asmStr = buildMmaSparseAsmString(
-      shape, unpackedAData.size(), unpackedB.size(), unpackedC.size(), ptxTypeA,
-      ptxTypeB, ptxTypeC, ptxTypeD, overflow);
-  std::string constraintStr = buildMmaSparseAsmConstraintString(
-      unpackedAData.size(), unpackedB.size(), unpackedC.size());
+  const unsigned matASize = unpackedAData.size();
+  const unsigned matBSize = unpackedB.size();
+  const unsigned matCSize = unpackedC.size();
 
-  Value selectorVal = rewriter.create<LLVM::ConstantOp>(
-      loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(metadataSelector));
+  std::string asmStr = buildMmaSparseAsmString(
+      shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
+      ptxTypeD, overflow, metadataSelector);
+  std::string constraintStr =
+      buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
 
   SmallVector<Value> asmVals;
-  asmVals.reserve(unpackedAData.size() + unpackedB.size() + unpackedC.size() +
-                  2);
+  asmVals.reserve(matASize + matBSize + matCSize + 1);
   for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
     llvm::append_range(asmVals, args);
   asmVals.push_back(indexData);
-  asmVals.push_back(selectorVal);
 
   return rewriter.create<LLVM::InlineAsmOp>(
       loc, /*resultTypes=*/intrinsicResultType,
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -333,12 +333,11 @@
   // CHECK-NOT llvm.extractvalue
   // CHECK: %[[sparseMetadata:.+]] = llvm.bitcast %{{.+}} : vector<2xi16> to i32
-  // CHECK: %[[sparseSelector:.+]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK: %[[d:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
-  // CHECK-SAME: "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {$0,$1},{$2,$3,$4,$5},{$6,$7,$8,$9},{$10,$11},$12,$13;"
-  // CHECK-SAME: "=r,=r,r,r,r,r,r,r,r,r,r,r,r,r"
-  // CHECK-SAME: %[[sparseMetadata]], %[[sparseSelector]] :
+  // CHECK-SAME: "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {$0,$1},{$2,$3,$4,$5},{$6,$7,$8,$9},{$10,$11},$12,0x0;"
+  // CHECK-SAME: "=r,=r,r,r,r,r,r,r,r,r,r,r,r"
+  // CHECK-SAME: %[[sparseMetadata]] :
   // CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 
   %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 32]} :
@@ -372,12 +371,11 @@
   // CHECK-NOT llvm.extractvalue
   // CHECK: %[[sparseMetadata:.+]] = llvm.bitcast %{{.+}} : vector<2xi16> to i32
-  // CHECK: %[[sparseSelector:.+]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK: %[[d:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
-  // CHECK-SAME: "mma.sp.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {$0,$1},{$2,$3},{$4,$5},{$6,$7},$8,$9;"
-  // CHECK-SAME: "=r,=r,r,r,r,r,r,r,r,r"
-  // CHECK-SAME: %[[sparseMetadata]], %[[sparseSelector]] :
+  // CHECK-SAME: "mma.sp.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {$0,$1},{$2,$3},{$4,$5},{$6,$7},$8,0x0;"
+  // CHECK-SAME: "=r,=r,r,r,r,r,r,r,r"
+  // CHECK-SAME: %[[sparseMetadata]] :
   // CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 
   %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 16]} :
@@ -413,12 +411,11 @@
   // CHECK-NOT llvm.extractvalue
   // CHECK: %[[sparseMetadata:.+]] = llvm.bitcast %{{.+}} : vector<2xi16> to i32
-  // CHECK: %[[sparseSelector:.+]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK: %[[d:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
-  // CHECK-SAME: "mma.sp.sync.aligned.m16n8k64.row.col.satfinite.s32.s8.s8.s32
-  // CHECK-SAME: "=r,=r,=r,=r,r,r,r,r,r,r,r,r,r,r,r,r,r,r"
-  // CHECK-SAME: %[[sparseMetadata]], %[[sparseSelector]] :
+  // CHECK-SAME: "mma.sp.sync.aligned.m16n8k64.row.col.satfinite.s32.s8.s8.s32 {$0,$1,$2,$3},{$4,$5,$6,$7},{$8,$9,$10,$11},{$12,$13,$14,$15},$16,0x0;"
+  // CHECK-SAME: "=r,=r,=r,=r,r,r,r,r,r,r,r,r,r,r,r,r,r"
+  // CHECK-SAME: %[[sparseMetadata]] :
   // CHECK-SAME: -> !llvm.struct<(i32, i32, i32, i32)
 
   %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 64]} :