diff --git a/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
@@ -1,5 +1,9 @@
-add_mlir_dialect(NVGPU nvgpu)
-add_mlir_doc(NVGPU NVGPU Dialects/ -gen-dialect-doc)
+add_subdirectory(IR)
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name NVGPU)
+mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix NVGPU)
+mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix NVGPU)
+add_public_tablegen_target(MLIRNVGPUPassIncGen)
 
 
-set(LLVM_TARGET_DEFINITIONS NVGPU.td)
+add_mlir_doc(Passes NVGPUPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/NVGPU/IR/CMakeLists.txt
copy from mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
copy to mlir/include/mlir/Dialect/NVGPU/IR/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/CMakeLists.txt
@@ -1,5 +1,2 @@
 add_mlir_dialect(NVGPU nvgpu)
 add_mlir_doc(NVGPU NVGPU Dialects/ -gen-dialect-doc)
-
-
-set(LLVM_TARGET_DEFINITIONS NVGPU.td)
diff --git a/mlir/include/mlir/Dialect/NVGPU/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
rename from mlir/include/mlir/Dialect/NVGPU/NVGPU.td
rename to mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
diff --git a/mlir/include/mlir/Dialect/NVGPU/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
rename from mlir/include/mlir/Dialect/NVGPU/NVGPUDialect.h
rename to mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
--- a/mlir/include/mlir/Dialect/NVGPU/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -32,9 +32,9 @@
 } // namespace nvgpu
 } // namespace mlir
 
-#include "mlir/Dialect/NVGPU/NVGPUDialect.h.inc"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h.inc"
 
 #define GET_OP_CLASSES
-#include "mlir/Dialect/NVGPU/NVGPU.h.inc"
+#include "mlir/Dialect/NVGPU/IR/NVGPU.h.inc"
 
 #endif // MLIR_DIALECT_NVGPU_NVGPUDIALECT_H_
diff --git a/mlir/include/mlir/Dialect/NVGPU/Passes.h b/mlir/include/mlir/Dialect/NVGPU/Passes.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/NVGPU/Passes.h
@@ -0,0 +1,35 @@
+//===- Passes.h - NVGPU pass entry points ------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes that expose pass constructors.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_NVGPU_PASSES_H_
+#define MLIR_DIALECT_NVGPU_PASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace nvgpu {
+
+/// Create a pass to optimize shared memory reads and writes.
+std::unique_ptr<Pass> createOptimizeSharedMemoryPass();
+
+} // namespace nvgpu
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generate the code for registering passes.
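+/// For example, a tool can call `registerNVGPUPasses()` during initialization
+/// (as the InitAllPasses.h change below does) to make
+/// `-nvgpu-optimize-shared-memory` available on the command line.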
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/NVGPU/Passes.h.inc"
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_NVGPU_PASSES_H_
diff --git a/mlir/include/mlir/Dialect/NVGPU/Passes.td b/mlir/include/mlir/Dialect/NVGPU/Passes.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/NVGPU/Passes.td
@@ -0,0 +1,22 @@
+//===-- Passes.td - NvGpu pass definition file ------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_NVGPU_PASSES_TD_
+#define MLIR_DIALECT_NVGPU_PASSES_TD_
+
+include "mlir/Pass/PassBase.td"
+
+def OptimizeSharedMemory : Pass<"nvgpu-optimize-shared-memory"> {
+  let summary = "Optimizes accesses to shared memory memrefs in order to reduce bank conflicts.";
+  let constructor = "mlir::nvgpu::createOptimizeSharedMemoryPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect", "vector::VectorDialect"
+  ];
+}
+
+#endif // MLIR_DIALECT_NVGPU_PASSES_TD_
diff --git a/mlir/include/mlir/Dialect/NVGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/NVGPU/Transforms/Transforms.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/NVGPU/Transforms/Transforms.h
@@ -0,0 +1,47 @@
+//===- Transforms.h - NVGPU Dialect transformations ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares functions that assist transformations for the nvgpu
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_NVGPU_TRANSFORMS_TRANSFORMS_H_
+#define MLIR_DIALECT_NVGPU_TRANSFORMS_TRANSFORMS_H_
+
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace nvgpu {
+
+/// Optimizes vectorized accesses to a shared memory buffer specified by
+/// memrefValue. This transformation assumes the following:
+/// 1) All relevant accesses to `memrefValue` are contained within `parentOp`.
+/// 2) The function will fail precondition checks if any subviews are
+///    taken of `memrefValue`. All reads/writes to `memrefValue` should occur
+///    through `memrefValue` directly.
+///
+/// Shared memory bank conflicts occur when multiple threads attempt to read or
+/// write locations assigned to the same shared memory bank. For `2^N` byte
+/// vectorized accesses, we need to be concerned with conflicts among threads
+/// identified as `(tid) -> tid.floordiv(2^{7-N})`. As such, this transformation
+/// changes any indexed memory access (vector.load, memref.load, nvgpu.ldmatrix,
+/// etc.) so that the final dimension's index value is permuted to
+/// `newColIndex = oldColIndex % vectorSize +
+/// perm[rowIndex](oldColIndex/vectorSize, rowIndex)`, where `rowIndex` is the
+/// index of the second-to-last dimension and `perm[rowIndex]` is a permutation
+/// function that depends on the row index. The permutation function is chosen
+/// to ensure that sequential distributed+vectorized reads/writes down a single
+/// dimension of the memref have minimal conflicts.
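+///
+/// For example, with the default 128-bit access size, a `memref<128x32xf16, 3>`
+/// has a vector size of 8 f16 elements, and the rewritten column index becomes
+/// `col ^ ((row & 6) << 2)`; this corresponds to the
+/// `arith.andi`/`arith.shli`/`arith.xori` sequence checked in the
+/// optimize-shared-memory.mlir test added in this change.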
+mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp, + Value memrefValue); + +} // namespace nvgpu +} // namespace mlir + +#endif // MLIR_DIALECT_NVGPU_TRANSFORMS_TRANSFORMS_H_ diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -40,7 +40,7 @@ #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/NVGPU/NVGPUDialect.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/PDL/IR/PDL.h" diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -24,6 +24,7 @@ #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/NVGPU/Passes.h" #include "mlir/Dialect/Quant/Passes.h" #include "mlir/Dialect/SCF/Passes.h" #include "mlir/Dialect/SPIRV/Transforms/Passes.h" @@ -64,6 +65,7 @@ registerGpuSerializeToCubinPass(); registerGpuSerializeToHsacoPass(); registerLinalgPasses(); + registerNVGPUPasses(); registerSparseTensorPasses(); LLVM::registerLLVMPasses(); memref::registerMemRefPasses(); diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -12,7 +12,7 @@ #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" -#include "mlir/Dialect/NVGPU/NVGPUDialect.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" using namespace mlir; diff --git a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp --- a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp +++ b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp @@ -13,7 +13,7 @@ #include "NvGpuSupport.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "mlir/Dialect/NVGPU/NVGPUDialect.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" namespace mlir { diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -20,7 +20,7 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/NVGPU/NVGPUDialect.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" diff --git a/mlir/lib/Dialect/NVGPU/CMakeLists.txt b/mlir/lib/Dialect/NVGPU/CMakeLists.txt --- a/mlir/lib/Dialect/NVGPU/CMakeLists.txt +++ b/mlir/lib/Dialect/NVGPU/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -10,7 +10,7 @@ // 
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/NVGPU/NVGPUDialect.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -21,13 +21,13 @@
 using namespace mlir;
 using namespace mlir::nvgpu;
 
-#include "mlir/Dialect/NVGPU/NVGPUDialect.cpp.inc"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"
 
 void nvgpu::NVGPUDialect::initialize() {
   addTypes<DeviceAsyncTokenType>();
   addOperations<
 #define GET_OP_LIST
-#include "mlir/Dialect/NVGPU/NVGPU.cpp.inc"
+#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
       >();
 }
 
@@ -88,4 +88,4 @@
 }
 
 #define GET_OP_CLASSES
-#include "mlir/Dialect/NVGPU/NVGPU.cpp.inc"
+#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/NVGPU/Transforms/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/NVGPU/Transforms/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_dialect_library(MLIRNVGPUTransforms
+  OptimizeSharedMemory.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/NVGPU
+
+  DEPENDS
+  MLIRNVGPUPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArithmeticDialect
+  MLIRGPUOps
+  MLIRIR
+  MLIRMemRefDialect
+  MLIRPass
+  MLIRTransforms
+  MLIRVectorDialect
+  MLIRVectorUtils
+)
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
@@ -0,0 +1,269 @@
+//===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements transforms to optimize accesses to shared memory.
+//
+//===----------------------------------------------------------------------===//
+#include "PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/NVGPU/Passes.h"
+#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace mlir;
+using namespace mlir::nvgpu;
+
+/// The size of a shared memory line according to NV documentation.
+constexpr int64_t kSharedMemoryLineSizeBytes = 128;
+/// We optimize for 128-bit accesses, but this can be made an argument in the
+/// future.
+constexpr int64_t kDefaultVectorSizeBits = 128;
+
+/// Uses `srcIndexValue` to permute `tgtIndexValue` via
+///   `result = xor(floordiv(srcIdxVal, permuteEveryN),
+///                 floordiv(tgtIdxVal, vectorSize))
+///             + tgtIdxVal % vectorSize`
+/// This is done using an optimized sequence of `arith` operations.
+static Value permuteVectorOffset(OpBuilder &b, Location loc,
+                                 ArrayRef<Value> indices, MemRefType memrefTy,
+                                 int64_t srcDim, int64_t tgtDim) {
+  // Adjust the src index to change how often the permutation changes
+  // if necessary.
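+  // For example (with the constants above): a 32xf16 inner dimension occupies
+  // 32*16/8 = 64 bytes per row, so `permuteEveryN` below evaluates to
+  // 128/64 = 2 and the permutation pattern only advances every second row.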
+  Value src = indices[srcDim];
+
+  // We only want to permute every N iterations of the target dim where N is
+  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
+  const int64_t permuteEveryN = std::max<int64_t>(
+      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
+                                        memrefTy.getElementTypeBitWidth()) /
+                                       8));
+
+  // clang-format off
+  // Index bit representation (b0 = least significant bit) for dim(1)
+  // of a `memref<?x?xDT>` is as follows:
+  //   N := log2(128/elementSizeBits)
+  //   M := log2(dimSize(1))
+  // then
+  //   bits[0:N] = sub-vector element offset
+  //   bits[N:M] = vector index
+  // clang-format on
+  int64_t N =
+      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
+  int64_t M = llvm::Log2_64(memrefTy.getDimSize(tgtDim));
+
+  // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
+  int64_t mask = (1 << (M - N)) - 1;
+  if (permuteEveryN > 1)
+    mask = mask << llvm::Log2_64(permuteEveryN);
+  Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
+  srcBits = b.create<arith::AndIOp>(loc, src, srcBits);
+
+  // Use the src bits to permute the target bits b[N:M] containing the
+  // vector offset.
+  if (permuteEveryN > 1) {
+    int64_t shlBits = N - llvm::Log2_64(permuteEveryN);
+    if (shlBits > 0) {
+      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
+      srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
+    } else if (shlBits < 0) {
+      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
+      srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
+    }
+  } else {
+    Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, N);
+    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
+  }
+
+  Value permutedVectorIdx =
+      b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
+  return permutedVectorIdx;
+}
+
+static void transformIndices(OpBuilder &builder, Location loc,
+                             SmallVector<Value, 4> &indices,
+                             MemRefType memrefTy, int64_t srcDim,
+                             int64_t tgtDim) {
+  indices[tgtDim] =
+      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
+}
+
+Operation::operand_range getIndices(Operation *op) {
+  if (auto ldmatrixOp = dyn_cast<LdMatrixOp>(op))
+    return ldmatrixOp.indices();
+  if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op))
+    return copyOp.dstIndices();
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.indices();
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return storeOp.indices();
+  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
+    return vectorReadOp.getIndices();
+  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
+    return vectorStoreOp.getIndices();
+  llvm_unreachable("unsupported op type");
+}
+
+void setIndices(Operation *op, ArrayRef<Value> indices) {
+  if (auto ldmatrixOp = dyn_cast<LdMatrixOp>(op))
+    return ldmatrixOp.indicesMutable().assign(indices);
+  if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op))
+    return copyOp.dstIndicesMutable().assign(indices);
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.indicesMutable().assign(indices);
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return storeOp.indicesMutable().assign(indices);
+  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
+    return vectorReadOp.getIndicesMutable().assign(indices);
+  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
+    return vectorStoreOp.getIndicesMutable().assign(indices);
+  llvm_unreachable("unsupported op type");
+}
+
+/// Return all operations within `parentOp` that read from or write to
+/// `shmMemRef`.
+static LogicalResult
+getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
+                      SmallVector<Operation *, 16> &readOps,
+                      SmallVector<Operation *, 16> &writeOps) {
+  parentOp->walk([&](Operation *op) {
+    MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
+    if (!iface)
+      return;
+    Optional<MemoryEffects::EffectInstance> effect =
+        iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
+    if (effect) {
+      readOps.push_back(op);
+      return;
+    }
+    effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
+    if (effect)
+      writeOps.push_back(op);
+  });
+
+  // Restrict to a supported set of ops. We also require at least 2D access,
+  // although this could be relaxed.
+  if (llvm::any_of(readOps, [](Operation *op) {
+        return !isa<memref::LoadOp, vector::LoadOp, nvgpu::LdMatrixOp>(op) ||
+               getIndices(op).size() < 2;
+      }))
+    return failure();
+  if (llvm::any_of(writeOps, [](Operation *op) {
+        return !isa<memref::StoreOp, vector::StoreOp, nvgpu::DeviceAsyncCopyOp>(
+                   op) ||
+               getIndices(op).size() < 2;
+      }))
+    return failure();
+
+  return success();
+}
+
+mlir::LogicalResult
+mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
+                                                Value memrefValue) {
+  auto memRefType = memrefValue.getType().dyn_cast<MemRefType>();
+  if (!memRefType || memRefType.getMemorySpaceAsInt() !=
+                         gpu::GPUDialect::getWorkgroupAddressSpace())
+    return failure();
+
+  // Abort if the given value has any sub-views; we do not do any alias
+  // analysis.
+  bool hasSubView = false;
+  parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
+  if (hasSubView)
+    return failure();
+
+  // Check if this is necessary given the assumption of 128b accesses:
+  // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
+  const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
+  const int64_t rowsPerLine =
+      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
+      rowSize;
+  const int64_t threadGroupSize =
+      1 << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
+  if (rowsPerLine >= threadGroupSize)
+    return failure();
+
+  // Get sets of operations within the function that read/write to shared
+  // memory.
+  SmallVector<Operation *, 16> shmReadOps;
+  SmallVector<Operation *, 16> shmWriteOps;
+  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
+                                   shmWriteOps)))
+    return failure();
+
+  if (shmReadOps.empty() || shmWriteOps.empty())
+    return failure();
+
+  OpBuilder builder(parentOp->getContext());
+
+  int64_t tgtDim = memRefType.getRank() - 1;
+  int64_t srcDim = memRefType.getRank() - 2;
+
+  // Transform indices for the ops writing to shared memory.
+  while (!shmWriteOps.empty()) {
+    Operation *shmWriteOp = shmWriteOps.back();
+    shmWriteOps.pop_back();
+    builder.setInsertionPoint(shmWriteOp);
+
+    auto indices = getIndices(shmWriteOp);
+    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
+    transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
+                     memRefType, srcDim, tgtDim);
+    setIndices(shmWriteOp, transformedIndices);
+  }
+
+  // Transform indices for the ops reading from shared memory.
+  while (!shmReadOps.empty()) {
+    Operation *shmReadOp = shmReadOps.back();
+    shmReadOps.pop_back();
+    builder.setInsertionPoint(shmReadOp);
+
+    auto indices = getIndices(shmReadOp);
+    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
+    transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
+                     memRefType, srcDim, tgtDim);
+    setIndices(shmReadOp, transformedIndices);
+  }
+
+  return success();
+}
+
+namespace {
+class OptimizeSharedMemoryPass
+    : public OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
+public:
+  OptimizeSharedMemoryPass() = default;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    SmallVector<memref::AllocOp> shmAllocOps;
+    op->walk([&](memref::AllocOp allocOp) {
+      if (allocOp.memref().getType().cast<MemRefType>().getMemorySpaceAsInt() !=
+          gpu::GPUDialect::getWorkgroupAddressSpace())
+        return;
+      shmAllocOps.push_back(allocOp);
+    });
+    for (auto allocOp : shmAllocOps) {
+      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
+                                                    allocOp.memref())))
+        return;
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::nvgpu::createOptimizeSharedMemoryPass() {
+  return std::make_unique<OptimizeSharedMemoryPass>();
+}
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/PassDetail.h b/mlir/lib/Dialect/NVGPU/Transforms/PassDetail.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/NVGPU/Transforms/PassDetail.h
@@ -0,0 +1,33 @@
+//===- PassDetail.h - NVGPU Pass class details ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef DIALECT_NVGPU_TRANSFORMS_PASSDETAIL_H_
+#define DIALECT_NVGPU_TRANSFORMS_PASSDETAIL_H_
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+namespace arith {
+class ArithmeticDialect;
+} // namespace arith
+
+namespace memref {
+class MemRefDialect;
+} // namespace memref
+
+namespace vector {
+class VectorDialect;
+} // namespace vector
+
+#define GEN_PASS_CLASSES
+#include "mlir/Dialect/NVGPU/Passes.h.inc"
+
+} // namespace mlir
+
+#endif // DIALECT_NVGPU_TRANSFORMS_PASSDETAIL_H_
diff --git a/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir b/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir
@@ -0,0 +1,240 @@
+// RUN: mlir-opt %s -split-input-file --pass-pipeline='func.func(nvgpu-optimize-shared-memory)' | FileCheck %s
+
+// CHECK: @optimize_128x32xf16_32x128xf16([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
+func.func @optimize_128x32xf16_32x128xf16(%arg0: memref<128x128xf16>,
+                                          %ldRow: index, %ldCol: index,
+                                          %stRow: index, %stCol: index,
+                                          %fragRow: index, %fragCol :index)
+    -> (vector<4x2xf16>, vector<4x2xf16>) {
+  // CHECK: [[shm:%.+]] = memref.alloc
+  // CHECK: [[shmB:%.+]] = memref.alloc
+  %shm = memref.alloc() : memref<128x32xf16, 3>
+  %shmB = memref.alloc() : memref<32x128xf16, 3>
+
+  // CHECK: [[c6:%.+]] = arith.constant 6 : index
+  // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c6]]
+  // CHECK: [[c2:%.+]] = arith.constant 2 : index
+  // CHECK: [[xorBits:%.+]] = arith.shli [[src_bits]], [[c2]]
+  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
+  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]],
[[ldCol]]], [[shm]][[[stRow]], [[stColPerm]]] + %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 8 + : memref<128x128xf16> to memref<128x32xf16, 3> + %1 = nvgpu.device_async_create_group %0 + nvgpu.device_async_wait %1 { numGroups = 1 : i32} + + // CHECK: [[c6:%.+]] = arith.constant 6 : index + // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]] + // CHECK: [[c2:%.+]] = arith.constant 2 : index + // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]] + // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]] + // CHECK: nvgpu.ldmatrix [[shm]][[[fragRow]], [[fragColPerm]]] + %mat = nvgpu.ldmatrix %shm[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false} + : memref<128x32xf16, 3> -> vector<4x2xf16> + + // CHECK: [[c15:%.+]] = arith.constant 15 : index + // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c15]] + // CHECK: [[c3:%.+]] = arith.constant 3 : index + // CHECK: [[xorBits:%.+]] = arith.shli [[src_bits]], [[c3]] + // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]] + // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shmB]][[[stRow]], [[stColPerm]]] + %2 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shmB[%stRow, %stCol], 8 + : memref<128x128xf16> to memref<32x128xf16, 3> + %3 = nvgpu.device_async_create_group %0 + nvgpu.device_async_wait %1 { numGroups = 1 : i32} + + // CHECK: [[c15:%.+]] = arith.constant 15 : index + // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c15]] + // CHECK: [[c3:%.+]] = arith.constant 3 : index + // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c3]] + // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]] + // CHECK: nvgpu.ldmatrix [[shmB]][[[fragRow]], [[fragColPerm]]] + %matB = nvgpu.ldmatrix %shmB[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false} + : memref<32x128xf16, 3> -> vector<4x2xf16> + + return %mat, %matB: vector<4x2xf16>, vector<4x2xf16> +} + + +// ----- + +// CHECK: @optimize_64x16xf32_16x64xf32([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index) +func.func @optimize_64x16xf32_16x64xf32(%arg0: memref<128x128xf32>, + %ldRow: index, %ldCol: index, + %stRow: index, %stCol: index, + %fragRow: index, %fragCol :index) + -> (vector<4x1xf32>, vector<4x1xf32>, f32, vector<4xf32>, f32) { + // CHECK: [[shm:%.+]] = memref.alloc + // CHECK: [[shmB:%.+]] = memref.alloc + %shm = memref.alloc() : memref<64x16xf32, 3> + %shmB = memref.alloc() : memref<16x64xf32, 3> + + // CHECK: [[c6:%.+]] = arith.constant 6 : index + // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c6]] + // CHECK: [[c1:%.+]] = arith.constant 1 : index + // CHECK: [[xorBits:%.+]] = arith.shli [[src_bits]], [[c1]] + // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]] + // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stColPerm]]] + %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 8 + : memref<128x128xf32> to memref<64x16xf32, 3> + %1 = nvgpu.device_async_create_group %0 + nvgpu.device_async_wait %1 { numGroups = 1 : i32} + + // CHECK: [[c6:%.+]] = arith.constant 6 : index + // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]] + // CHECK: [[c1:%.+]] = arith.constant 1 : index + // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]] + // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]] + // CHECK: nvgpu.ldmatrix [[shm]][[[fragRow]], [[fragColPerm]]] + %mat = 
nvgpu.ldmatrix %shm[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false} + : memref<64x16xf32, 3> -> vector<4x1xf32> + + // CHECK: [[c6:%.+]] = arith.constant 6 : index + // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]] + // CHECK: [[c1:%.+]] = arith.constant 1 : index + // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]] + // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]] + // CHECK: memref.load [[shm]][[[fragRow]], [[fragColPerm]]] + %elem = memref.load %shm[%fragRow, %fragCol] : memref<64x16xf32, 3> + + // Verify vector operations. + + // CHECK: [[c6:%.+]] = arith.constant 6 : index + // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]] + // CHECK: [[c1:%.+]] = arith.constant 1 : index + // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]] + // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]] + // CHECK: vector.load [[shm]][[[fragRow]], [[fragColPerm]]] + %elem2 = vector.load %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>, vector<4xf32> + + // CHECK: [[c6:%.+]] = arith.constant 6 : index + // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]] + // CHECK: [[c1:%.+]] = arith.constant 1 : index + // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]] + // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]] + // CHECK: vector.store %{{.+}}, [[shm]][[[fragRow]], [[fragColPerm]]] + vector.store %elem2, %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>, vector<4xf32> + + // CHECK: [[c6:%.+]] = arith.constant 6 : index + // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]] + // CHECK: [[c1:%.+]] = arith.constant 1 : index + // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]] + // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]] + // CHECK: memref.store %{{.+}}, [[shm]][[[fragRow]], [[fragColPerm]]] + memref.store %elem, %shm[%fragRow, %fragCol] : memref<64x16xf32, 3> + + // Verify 16x64xf32 memory size. 
+ + // CHECK: [[c15:%.+]] = arith.constant 15 : index + // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c15]] + // CHECK: [[c2:%.+]] = arith.constant 2 : index + // CHECK: [[xorBits:%.+]] = arith.shli [[src_bits]], [[c2]] + // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]] + // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shmB]][[[stRow]], [[stColPerm]]] + %2 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shmB[%stRow, %stCol], 8 + : memref<128x128xf32> to memref<16x64xf32, 3> + %3 = nvgpu.device_async_create_group %0 + nvgpu.device_async_wait %1 { numGroups = 1 : i32} + + // CHECK: [[c15:%.+]] = arith.constant 15 : index + // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c15]] + // CHECK: [[c2:%.+]] = arith.constant 2 : index + // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]] + // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]] + // CHECK: nvgpu.ldmatrix [[shmB]][[[fragRow]], [[fragColPerm]]] + %matB = nvgpu.ldmatrix %shmB[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false} + : memref<16x64xf32, 3> -> vector<4x1xf32> + + // CHECK: [[c15:%.+]] = arith.constant 15 : index + // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c15]] + // CHECK: [[c2:%.+]] = arith.constant 2 : index + // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]] + // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]] + // CHECK: memref.load [[shmB]][[[fragRow]], [[fragColPerm]]] + %elemB = memref.load %shmB[%fragRow, %fragCol] : memref<16x64xf32, 3> + + return %mat, %matB, %elem, %elem2, %elemB: vector<4x1xf32>, vector<4x1xf32>, f32, vector<4xf32>, f32 +} + + +// ----- + +// Small column edge cases + +// CHECK: @small_column_size_f64([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index) +func.func @small_column_size_f64(%arg0: memref<32x32xf64>, + %ldRow: index, %ldCol: index, + %stRow: index, %stCol: index, + %fragRow: index, %fragCol :index) + -> f64 { + // CHECK: [[shm:%.+]] = memref.alloc + %shm = memref.alloc() : memref<32x4xf64, 3> + + // CHECK: [[c4:%.+]] = arith.constant 4 : index + // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c4]] + // CHECK: [[c1:%.+]] = arith.constant 1 : index + // CHECK: [[xorBits:%.+]] = arith.shrui [[src_bits]], [[c1]] + // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]] + // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stColPerm]]] + %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 8 + : memref<32x32xf64> to memref<32x4xf64, 3> + %1 = nvgpu.device_async_create_group %0 + nvgpu.device_async_wait %1 { numGroups = 1 : i32} + + // CHECK: [[c6:%.+]] = arith.constant 4 : index + // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]] + // CHECK: [[c1:%.+]] = arith.constant 1 : index + // CHECK: [[xorBits:%.+]] = arith.shrui [[srcBits]], [[c1]] + // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]] + // CHECK: memref.load [[shm]][[[fragRow]], [[fragColPerm]]] + %el = memref.load %shm[%fragRow, %fragCol] : memref<32x4xf64, 3> + + return %el: f64 +} + +// CHECK: @too_small_column_size_f16([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index) +func.func @too_small_column_size_f16(%arg0: memref<128x128xf16>, + %ldRow: index, %ldCol: index, + %stRow: index, %stCol: index, + 
%fragRow: index, %fragCol :index) + -> vector<1x2xf16> { + // CHECK: [[shm:%.+]] = memref.alloc + %shm = memref.alloc() : memref<128x8xf16, 3> + + // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stCol]]] + %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 8 + : memref<128x128xf16> to memref<128x8xf16, 3> + %1 = nvgpu.device_async_create_group %0 + nvgpu.device_async_wait %1 { numGroups = 1 : i32} + + // CHECK: nvgpu.ldmatrix [[shm]][[[fragRow]], [[fragCol]]] + %mat = nvgpu.ldmatrix %shm[%fragRow, %fragCol] {numTiles = 1 : i32, transpose = false} + : memref<128x8xf16, 3> -> vector<1x2xf16> + + return %mat: vector<1x2xf16> +} + +// ----- + +// CHECK: @abort_if_subview([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index) +func.func @abort_if_subview(%arg0: memref<128x128xf16>, + %ldRow: index, %ldCol: index, + %stRow: index, %stCol: index, + %fragRow: index, %fragCol :index) + -> vector<1x2xf16> { + // CHECK: [[shm:%.+]] = memref.alloc + %shm = memref.alloc() : memref<128x32xf16, 3> + // CHECK: [[shmView:%.+]] = memref.subview + %shmView = memref.subview %shm[0, 0][64, 32][1, 1] : memref<128x32xf16, 3> to memref<64x32xf16, 3> + + // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stCol]]] + %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 8 + : memref<128x128xf16> to memref<128x32xf16, 3> + %1 = nvgpu.device_async_create_group %0 + nvgpu.device_async_wait %1 { numGroups = 1 : i32} + + // CHECK: nvgpu.ldmatrix [[shmView]][[[fragRow]], [[fragCol]]] + %mat = nvgpu.ldmatrix %shmView[%fragRow, %fragCol] {numTiles = 1 : i32, transpose = false} + : memref<64x32xf16, 3> -> vector<1x2xf16> + + return %mat: vector<1x2xf16> +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2143,7 +2143,7 @@ td_library( name = "NVGPUTdFiles", - srcs = ["include/mlir/Dialect/NVGPU/NVGPU.td"], + srcs = ["include/mlir/Dialect/NVGPU/IR/NVGPU.td"], includes = ["include"], deps = [ ":SideEffectInterfacesTdFiles", @@ -2159,22 +2159,22 @@ "-gen-dialect-decls", "-dialect=nvgpu", ], - "include/mlir/Dialect/NVGPU/NVGPUDialect.h.inc", + "include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h.inc", ), ( [ "-gen-dialect-defs", "-dialect=nvgpu", ], - "include/mlir/Dialect/NVGPU/NVGPUDialect.cpp.inc", + "include/mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc", ), ( ["-gen-op-decls"], - "include/mlir/Dialect/NVGPU/NVGPU.h.inc", + "include/mlir/Dialect/NVGPU/IR/NVGPU.h.inc", ), ( ["-gen-op-defs"], - "include/mlir/Dialect/NVGPU/NVGPU.cpp.inc", + "include/mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc", ), ( ["-gen-op-doc"], @@ -2182,20 +2182,66 @@ ), ], tblgen = ":mlir-tblgen", - td_file = "include/mlir/Dialect/NVGPU/NVGPU.td", + td_file = "include/mlir/Dialect/NVGPU/IR/NVGPU.td", deps = [":NVGPUTdFiles"], ) +gentbl_cc_library( + name = "NVGPUPassIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=NVGPU", + ], + "include/mlir/Dialect/NVGPU/Passes.h.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/NVGPU/Passes.td", + deps = [":PassBaseTdFiles"], +) + cc_library( name = "NVGPUDialect", srcs = ["lib/Dialect/NVGPU/IR/NVGPUDialect.cpp"], - hdrs = 
["include/mlir/Dialect/NVGPU/NVGPUDialect.h"], + hdrs = ["include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h"], includes = ["include"], deps = [ ":GPUDialect", ":IR", ":NVGPUIncGen", + ":NVGPUPassIncGen", + ":SideEffectInterfaces", + "//llvm:Core", + "//llvm:Support", + ], +) + +cc_library( + name = "NVGPUTransforms", + srcs = [ + "lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp", + "lib/Dialect/NVGPU/Transforms/PassDetail.h", + ], + hdrs = [ + "include/mlir/Dialect/NVGPU/Transforms/Transforms.h", + "include/mlir/Dialect/NVGPU/Passes.h", + ], + includes = ["include"], + deps = [ + ":FuncDialect", + ":AffineDialect", + ":ArithmeticDialect", + ":GPUDialect", + ":MemRefDialect", + ":NVGPUDialect", + ":Pass", ":SideEffectInterfaces", + ":Support", + ":Transforms", + ":VectorDialect", "//llvm:Core", "//llvm:Support", ], @@ -6215,6 +6261,8 @@ ":MemRefToSPIRV", ":MemRefTransforms", ":NVGPUDialect", + ":NVGPUPassIncGen", + ":NVGPUTransforms", ":NVGPUToNVVM", ":NVVMDialect", ":OpenACCDialect",