diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -121,7 +121,7 @@ Results<(outs LLVM_Type:$res)>, Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> { ... let extraClassDefinition = [{ - const char* $cppClass::getPtx() { return \"mbarrier.arrive.expect_tx.b64 %0, [%1], %2;\"; } + std::string $cppClass::getPtx() { return std::string(\"mbarrier.arrive.expect_tx.b64 %0, [%1], %2;\"); } }\]; } ``` @@ -160,7 +160,7 @@ >, InterfaceMethod< /*desc=*/[{ Returns PTX code. }], - /*retType=*/"const char*", + /*retType=*/"std::string", /*methodName=*/"getPtx" >, InterfaceMethod< @@ -377,7 +377,7 @@ Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> { let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands) `->` type($res)"; let extraClassDefinition = [{ - const char* $cppClass::getPtx() { return "mbarrier.arrive.expect_tx.b64 %0, [%1], %2;"; } + std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 %0, [%1], %2;"); } }]; } @@ -387,7 +387,7 @@ Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount)> { let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands) `->` type($res)"; let extraClassDefinition = [{ - const char* $cppClass::getPtx() { return "mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;"; } + std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;"); } }]; } @@ -397,12 +397,12 @@ Arguments<(ins LLVM_i64ptr_any:$addr, LLVM_Type:$token)> { let assemblyFormat = "$addr `,` $token attr-dict `:` type(operands) `->` type($res)"; let extraClassDefinition = [{ - const char* $cppClass::getPtx() { - return "{\n\t" + std::string $cppClass::getPtx() { + return std::string("{\n\t" ".reg .pred P1; \n\t" "mbarrier.try_wait.parity.b64 P1, [%1], %2; \n\t" "selp.b32 %0, 1, 0, P1; \n\t" - "}"; + "}"); } }]; } @@ -413,12 +413,12 
@@ Arguments<(ins LLVM_i64ptr_shared:$addr, LLVM_Type:$token)> { let assemblyFormat = "$addr `,` $token attr-dict `:` type(operands) `->` type($res)"; let extraClassDefinition = [{ - const char* $cppClass::getPtx() { - return "{\n\t" + std::string $cppClass::getPtx() { + return std::string("{\n\t" ".reg .pred P1; \n\t" "mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t" "selp.b32 %0, 1, 0, P1; \n\t" - "}"; + "}"); } }]; } @@ -567,11 +567,11 @@ } }]; let extraClassDefinition = [{ - const char* $cppClass::getPtx() { + std::string $cppClass::getPtx() { if(getModifier() == NVVM::LoadCacheModifierKind::CG) - return "cp.async.cg.shared.global [%0], [%1], %2, %3;\n"; + return std::string("cp.async.cg.shared.global [%0], [%1], %2, %3;\n"); if(getModifier() == NVVM::LoadCacheModifierKind::CA) - return "cp.async.ca.shared.global [%0], [%1], %2, %3;\n"; + return std::string("cp.async.ca.shared.global [%0], [%1], %2, %3;\n"); llvm_unreachable("unsupported cache modifier"); } }]; diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp --- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp +++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp @@ -33,6 +33,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/raw_ostream.h" +#include #include #define DEBUG_TYPE "nvvm-to-llvm" @@ -53,7 +54,7 @@ class PtxBuilder { Operation *op; PatternRewriter &rewriter; - const char *asmStr; + std::string asmStr; SmallVector asmVals; std::string asmConstraints; bool sideEffects; @@ -85,9 +86,10 @@ } public: - PtxBuilder(Operation *op, PatternRewriter &rewriter, const char *ptxAsm, + PtxBuilder(Operation *op, PatternRewriter &rewriter, std::string ptxAsm, bool sideEffects = false) - : op(op), rewriter(rewriter), asmStr(ptxAsm), sideEffects(sideEffects) {} + : op(op), rewriter(rewriter), asmStr(std::move(ptxAsm)), + sideEffects(sideEffects) {} void insertValue(Value v, PTXRegisterMod itype = 
PTXRegisterMod::Read) { llvm::raw_string_ostream ss(asmConstraints); @@ -116,10 +118,13 @@ asmConstraints[asmConstraints.size() - 1] == ',') asmConstraints.pop_back(); + // The PTX strings name operands with %, but LLVM inline assembly expects $. Replace all % with $ + std::replace(asmStr.begin(), asmStr.end(), '%', '$'); + return rewriter.create( op->getLoc(), resultType, /*operands=*/asmVals, - /*asm_string=*/asmStr, + /*asm_string=*/llvm::StringRef(asmStr), /*constraints=*/asmConstraints.data(), /*has_side_effects=*/sideEffects, /*is_align_stack=*/false, diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -2,28 +2,28 @@ // CHECK-LABEL : @init_mbarrier_arrive_expect_tx llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32) -> i64 { - //CHECK : llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;", "=l,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i64 + //CHECK : llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 $0, [$1], $2;", "=l,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i64 %res = nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32 -> i64 llvm.return %res : i64 } // CHECK-LABEL : @init_mbarrier_arrive_expect_tx_generic llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32)-> i64 { - // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 %0, [%1], %2;", "=l,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i64 + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 $0, [$1], $2;", "=l,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i64 %res = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32 -> i64 llvm.return %res : i64 } // CHECK-LABEL : 
@init_mbarrier_try_wait.parity.shared llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %token : i32) -> i32 { - // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \0A\09selp.b32 %0, 1, 0, P1; \0A\09}", "=r,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i32 + // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.shared.b64 P1, [$1], $2; \0A\09selp.b32 $0, 1, 0, P1; \0A\09}", "=r,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i32 %res = nvvm.mbarrier.try_wait.parity.shared %barrier, %token : !llvm.ptr<3>, i32 -> i32 llvm.return %res : i32 } // CHECK-LABEL : @init_mbarrier_try_wait.parity llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %token : i32) -> i32{ - // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.b64 P1, [%1], %2; \0A\09selp.b32 %0, 1, 0, P1; \0A\09}", "=r,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i32 + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.b64 P1, [$1], $2; \0A\09selp.b32 $0, 1, 0, P1; \0A\09}", "=r,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i32 %res = nvvm.mbarrier.try_wait.parity %barrier, %token : !llvm.ptr, i32 -> i32 llvm.return %res : i32 } @@ -39,9 +39,9 @@ // CHECK-LABEL : @async_cp_zfill func.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32) { - // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [%0], [%1], %2, %3;\0A", "r,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32) -> !llvm.void + // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32) -> !llvm.void nvvm.cp.async.shared.global %dst, %src, 16, cache = cg, %cpSize : 
!llvm.ptr<3>, !llvm.ptr<1>, i32 - // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "cp.async.ca.shared.global [%0], [%1], %2, %3;\0A", "r,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32) -> !llvm.void + // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "cp.async.ca.shared.global [$0], [$1], $2, $3;\0A", "r,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32) -> !llvm.void nvvm.cp.async.shared.global %dst, %src, 4, cache = ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32 return }