diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1497,7 +1497,7 @@ let assemblyFormat = "`<` $value `>`"; } -def NVVM_WgmmaMmaSyncOp : NVVM_Op<"wgmma.mma_async", +def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async", [DeclareOpInterfaceMethods, PredOpTrait<"input struct and result struct must be the same type", TCresIsSameAsOpBase<0, 0>>,]> diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -708,7 +708,7 @@ return success(); } -LogicalResult NVVM::WgmmaMmaSyncOp::verify() { +LogicalResult NVVM::WgmmaMmaAsyncOp::verify() { Value outValue = getResults(); auto stype = dyn_cast(outValue.getType()); if (!stype) @@ -889,7 +889,7 @@ return success(); } -std::string NVVM::WgmmaMmaSyncOp::getPtx() { +std::string NVVM::WgmmaMmaAsyncOp::getPtx() { int m = getShape().getM(), n = getShape().getN(), k = getShape().getK(); bool isF16 = getTypeA() == mlir::NVVM::MMATypes::f16 || @@ -952,7 +952,7 @@ return ptx; } -void NVVM::WgmmaMmaSyncOp::getAsmValues( +void NVVM::WgmmaMmaAsyncOp::getAsmValues( RewriterBase &rewriter, llvm::SmallVectorImpl> &asmValues) {