diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -20,8 +20,13 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#define DEBUG_TYPE "nvgpu-to-nvvm" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define DBGSE() (llvm::dbgs()) + namespace mlir { #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS #include "mlir/Conversion/Passes.h.inc" @@ -980,9 +985,14 @@ }; int ex4LSB = 4; - Value strideDim = makeConst((layout << 3) >> ex4LSB); int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0); - Value leadDim = makeConst((sizeN * layout) >> ex4LSB); + uint64_t strideDimVal = (layout << 3) >> ex4LSB; + uint64_t leadDimVal = (sizeN * layout) >> ex4LSB; + uint64_t offsetVal = 0; + + Value strideDim = makeConst(strideDimVal); + Value leadDim = makeConst(leadDimVal); + Value baseAddr = getStridedElementPtr( op->getLoc(), cast(op.getTensor().getType()), adaptor.getTensor(), {}, rewriter); @@ -996,7 +1006,7 @@ // // [62,64) swizzle type dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit); // // [49,52) base_offset - dsc = insertBit(dsc, makeConst(0), startOffsetBit); + dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit); // // [32,46) stride dsc = insertBit(dsc, strideDim, startStrideBit); // // [16,30) leading dimension @@ -1004,6 +1014,14 @@ // // [0,14) start_address dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit); + LLVM_DEBUG(DBGS() << "Generating wgmma.descriptor: " + << "leading_off:" << leadDimVal << "\t" + << "stride_off :" << strideDimVal << "\t" + << "base_offset:" << offsetVal << "\t" + << "layout_type:" << swizzle << " (" + << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) + << ")\n start_addr : " << baseAddr << "\n"); + rewriter.replaceOp(op, dsc); return success(); }