Index: mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -0,0 +1,44 @@
+//===- VectorDistribution.h - Vector distribution patterns -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_
+#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+namespace mlir {
+class RewritePatternSet;
+namespace vector {
+
+struct WarpExecuteOnLane0LoweringOptions {
+  /// Lambda function to let users allocate the memory needed for the lowering
+  /// of WarpExecuteOnLane0Op.
+  /// The function must return an allocation that the lowering can use as
+  /// temporary memory. The allocation needs to match the shape of the type
+  /// (the type may be VectorType or a scalar) and be available to the current
+  /// warp. If several warps run in parallel, the allocation needs to be split
+  /// so that each warp has its own allocation.
+  using WarpAllocationFn =
+      std::function<Value(Location, OpBuilder &, WarpExecuteOnLane0Op, Type)>;
+  WarpAllocationFn warpAllocationFn = nullptr;
+
+  /// Lambda function to let users emit an operation that synchronizes all the
+  /// threads within a warp. After this operation, all threads can see any
+  /// memory written before the operation.
+  using WarpSynchronizationFn =
+      std::function<void(Location, OpBuilder &, WarpExecuteOnLane0Op)>;
+  WarpSynchronizationFn warpSynchronizationFn = nullptr;
+};
+
+void populateWarpExecuteOnLane0OpToScfForPattern(
+    RewritePatternSet &patterns,
+    const WarpExecuteOnLane0LoweringOptions &options);
+
+} // namespace vector
+} // namespace mlir
+#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_
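For reference, a minimal sketch of how a client of this header might wire up the two hooks before applying the pattern. This is not part of the patch: `configureAndPopulate` is a hypothetical helper, and the plain `memref.alloc` allocator is a placeholder only (the test pass later in this patch shows a realistic implementation backed by `memref.global` and `gpu.barrier`).

```cpp
// Hypothetical client code, for illustration only.
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"

using namespace mlir;

void configureAndPopulate(RewritePatternSet &patterns) {
  vector::WarpExecuteOnLane0LoweringOptions options;
  options.warpAllocationFn = [](Location loc, OpBuilder &builder,
                                vector::WarpExecuteOnLane0Op warpOp,
                                Type type) -> Value {
    // Shape the scratch buffer after the transferred vector (or scalar).
    MemRefType memrefType;
    if (auto vecType = type.dyn_cast<VectorType>())
      memrefType =
          MemRefType::get(vecType.getShape(), vecType.getElementType());
    else
      memrefType = MemRefType::get({1}, type);
    // Placeholder: a real lowering must return warp-visible (e.g. workgroup)
    // memory, with a separate buffer per warp.
    return builder.create<memref::AllocOp>(loc, memrefType);
  };
  options.warpSynchronizationFn = [](Location loc, OpBuilder &builder,
                                     vector::WarpExecuteOnLane0Op warpOp) {
    builder.create<gpu::BarrierOp>(loc);
  };
  vector::populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
}
```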
Index: mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
===================================================================
--- mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRVectorTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  VectorDistribute.cpp
   VectorDropLeadUnitDim.cpp
   VectorInsertExtractStridedSliceRewritePatterns.cpp
   VectorMultiDimReductionTransforms.cpp
Index: mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -0,0 +1,158 @@
+//===- VectorDistribute.cpp - patterns to do vector distribution ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+static LogicalResult
+rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+                      const WarpExecuteOnLane0LoweringOptions &options) {
+  assert(warpOp.getBodyRegion().hasOneBlock() &&
+         "expected WarpOp with single block");
+  Block *warpOpBody = &warpOp.getBodyRegion().front();
+  Location loc = warpOp.getLoc();
+
+  // Passed all checks. Start rewriting.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(warpOp);
+
+  // Create scf.if op.
+  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  Value isLane0 = rewriter.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
+  auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
+                                         /*withElseRegion=*/false);
+  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
+
+  // Store vectors that are defined outside of warpOp into the scratch pad
+  // buffer.
+  SmallVector<Value> bbArgReplacements;
+  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
+    Value val = it.value();
+    Value bbArg = warpOpBody->getArgument(it.index());
+
+    rewriter.setInsertionPoint(ifOp);
+    Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
+                                            bbArg.getType());
+
+    // Store arg vector into buffer.
+    rewriter.setInsertionPoint(ifOp);
+    auto vectorType = val.getType().cast<VectorType>();
+    int64_t storeSize = vectorType.getShape()[0];
+    Value storeOffset = rewriter.create<arith::MulIOp>(
+        loc, warpOp.getLaneid(),
+        rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
+    rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);
+
+    // Load bbArg vector from buffer.
+    rewriter.setInsertionPointToStart(ifOp.thenBlock());
+    auto bbArgType = bbArg.getType().cast<VectorType>();
+    Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
+    bbArgReplacements.push_back(loadOp);
+  }
+
+  // Insert sync after all the stores and before all the loads.
+  if (!warpOp.getArgs().empty()) {
+    rewriter.setInsertionPoint(ifOp);
+    options.warpSynchronizationFn(warpOp->getLoc(), rewriter, warpOp);
+  }
+
+  // Move body of warpOp to ifOp.
+  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
+
+  // Rewrite terminator and compute replacements of WarpOp results.
+  SmallVector<Value> replacements;
+  auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
+  Location yieldLoc = yieldOp.getLoc();
+  for (const auto &it : llvm::enumerate(yieldOp.operands())) {
+    Value val = it.value();
+    Type resultType = warpOp->getResultTypes()[it.index()];
+    rewriter.setInsertionPoint(ifOp);
+    Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
+                                            val.getType());
+
+    // Store yielded value into buffer.
+    rewriter.setInsertionPoint(yieldOp);
+    if (val.getType().isa<VectorType>())
+      rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
+    else
+      rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);
+
+    // Load value from buffer (after warpOp).
+    rewriter.setInsertionPointAfter(ifOp);
+    if (resultType == val.getType()) {
+      // Result type and yielded value type are the same. This is a broadcast.
+      // E.g.:
+      // %r = vector_ext.warp_execute_on_lane_0(...) -> (f32) {
+      //   vector_ext.yield %cst : f32
+      // }
+      // Both types are f32. The constant %cst is broadcasted to all lanes.
+      // This is described in more detail in the documentation of the op.
+      Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
+      replacements.push_back(loadOp);
+    } else {
+      auto loadedVectorType = resultType.cast<VectorType>();
+      int64_t loadSize = loadedVectorType.getShape()[0];
+
+      // loadOffset = laneid * loadSize
+      Value loadOffset = rewriter.create<arith::MulIOp>(
+          loc, warpOp.getLaneid(),
+          rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
+      Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
+                                                     buffer, loadOffset);
+      replacements.push_back(loadOp);
+    }
+  }
+
+  // Insert sync after all the stores and before all the loads.
+  if (!yieldOp.operands().empty()) {
+    rewriter.setInsertionPointAfter(ifOp);
+    options.warpSynchronizationFn(warpOp->getLoc(), rewriter, warpOp);
+  }
+
+  // Delete terminator and add empty scf.yield.
+  rewriter.eraseOp(yieldOp);
+  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
+  rewriter.create<scf::YieldOp>(yieldLoc);
+
+  // Compute replacements for WarpOp results.
+  rewriter.replaceOp(warpOp, replacements);
+
+  return success();
+}
+
+namespace {
+
+struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  WarpOpToScfForPattern(MLIRContext *context,
+                        const WarpExecuteOnLane0LoweringOptions &options,
+                        PatternBenefit benefit = 1)
+      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    return rewriteWarpOpToScfFor(rewriter, warpOp, options);
+  }
+
+private:
+  const WarpExecuteOnLane0LoweringOptions &options;
+};
+
+} // namespace
+
+void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
+    RewritePatternSet &patterns,
+    const WarpExecuteOnLane0LoweringOptions &options) {
+  patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
+}
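To make the offset arithmetic concrete: each lane owns a contiguous slice of the shared buffer, so `storeOffset`/`loadOffset` are simply `laneid * size`. A standalone illustration (not part of the patch), using the shapes from the unit test below:

```cpp
#include <cstdint>

// Illustration only: with warp_execute_on_lane_0(...)[32] yielding a result
// distributed as vector<2xf32> per lane from a vector<64xf32> produced by the
// body, lane l stores/loads its slice at element offset l * 2.
constexpr int64_t kWarpSize = 32;
constexpr int64_t kDistributedSize = 2; // per-lane elements (vector<2xf32>)
constexpr int64_t kSequentialSize =
    kWarpSize * kDistributedSize;       // body-side elements (vector<64xf32>)

constexpr int64_t sliceOffset(int64_t laneId) {
  return laneId * kDistributedSize; // mirrors loadOffset = laneid * loadSize
}

static_assert(sliceOffset(kWarpSize - 1) + kDistributedSize == kSequentialSize,
              "the last lane's slice ends exactly at the buffer size");
```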
Index: mlir/test/Dialect/Vector/vector-warp-distribute.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | FileCheck %s --check-prefix=CHECK-SCF-IF
+
+// CHECK-SCF-IF-DAG: memref.global "private" @__shared_32xf32 : memref<32xf32, 3>
+// CHECK-SCF-IF-DAG: memref.global "private" @__shared_64xf32 : memref<64xf32, 3>
+// CHECK-SCF-IF-DAG: memref.global "private" @__shared_128xf32 : memref<128xf32, 3>
+// CHECK-SCF-IF-DAG: memref.global "private" @__shared_256xf32 : memref<256xf32, 3>
+
+// CHECK-SCF-IF-LABEL: func @rewrite_warp_op_to_scf_if(
+// CHECK-SCF-IF-SAME:   %[[laneid:.*]]: index,
+// CHECK-SCF-IF-SAME:   %[[v0:.*]]: vector<4xf32>, %[[v1:.*]]: vector<8xf32>)
+func.func @rewrite_warp_op_to_scf_if(%laneid: index,
+                                     %v0: vector<4xf32>, %v1: vector<8xf32>) {
+// CHECK-SCF-IF-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-SCF-IF-DAG: %[[c2:.*]] = arith.constant 2 : index
+// CHECK-SCF-IF-DAG: %[[c4:.*]] = arith.constant 4 : index
+// CHECK-SCF-IF-DAG: %[[c8:.*]] = arith.constant 8 : index
+// CHECK-SCF-IF: %[[is_lane_0:.*]] = arith.cmpi eq, %[[laneid]], %[[c0]]
+
+// CHECK-SCF-IF: %[[buffer_v0:.*]] = memref.get_global @__shared_128xf32
+// CHECK-SCF-IF: %[[s0:.*]] = arith.muli %[[laneid]], %[[c4]]
+// CHECK-SCF-IF: vector.store %[[v0]], %[[buffer_v0]][%[[s0]]]
+// CHECK-SCF-IF: %[[buffer_v1:.*]] = memref.get_global @__shared_256xf32
+// CHECK-SCF-IF: %[[s1:.*]] = arith.muli %[[laneid]], %[[c8]]
+// CHECK-SCF-IF: vector.store %[[v1]], %[[buffer_v1]][%[[s1]]]
+
+// CHECK-SCF-IF-DAG: gpu.barrier
+// CHECK-SCF-IF-DAG: %[[buffer_def_0:.*]] = memref.get_global @__shared_32xf32
+// CHECK-SCF-IF-DAG: %[[buffer_def_1:.*]] = memref.get_global @__shared_64xf32
+
+// CHECK-SCF-IF: scf.if %[[is_lane_0]] {
+  %r:2 = vector.warp_execute_on_lane_0(%laneid)[32]
+      args(%v0, %v1 : vector<4xf32>, vector<8xf32>) -> (vector<1xf32>, vector<2xf32>) {
+    ^bb0(%arg0: vector<128xf32>, %arg1: vector<256xf32>):
+// CHECK-SCF-IF: %[[arg1:.*]] = vector.load %[[buffer_v1]][%[[c0]]] : memref<256xf32, 3>, vector<256xf32>
+// CHECK-SCF-IF: %[[arg0:.*]] = vector.load %[[buffer_v0]][%[[c0]]] : memref<128xf32, 3>, vector<128xf32>
+// CHECK-SCF-IF: %[[def_0:.*]] = "some_def"(%[[arg0]]) : (vector<128xf32>) -> vector<32xf32>
+// CHECK-SCF-IF: %[[def_1:.*]] = "some_def"(%[[arg1]]) : (vector<256xf32>) -> vector<64xf32>
+    %2 = "some_def"(%arg0) : (vector<128xf32>) -> vector<32xf32>
+    %3 = "some_def"(%arg1) : (vector<256xf32>) -> vector<64xf32>
+// CHECK-SCF-IF: vector.store %[[def_0]], %[[buffer_def_0]][%[[c0]]]
+// CHECK-SCF-IF: vector.store %[[def_1]], %[[buffer_def_1]][%[[c0]]]
+    vector.yield %2, %3 : vector<32xf32>, vector<64xf32>
+  }
+// CHECK-SCF-IF: }
+// CHECK-SCF-IF: gpu.barrier
+// CHECK-SCF-IF: %[[o1:.*]] = arith.muli %[[laneid]], %[[c2]]
+// CHECK-SCF-IF: %[[r1:.*]] = vector.load %[[buffer_def_1]][%[[o1]]] : memref<64xf32, 3>, vector<2xf32>
+// CHECK-SCF-IF: %[[r0:.*]] = vector.load %[[buffer_def_0]][%[[laneid]]] : memref<32xf32, 3>, vector<1xf32>
+// CHECK-SCF-IF: "some_use"(%[[r0]]) : (vector<1xf32>) -> ()
+// CHECK-SCF-IF: "some_use"(%[[r1]]) : (vector<2xf32>) -> ()
+  "some_use"(%r#0) : (vector<1xf32>) -> ()
+  "some_use"(%r#1) : (vector<2xf32>) -> ()
+  return
+}
Index: mlir/test/Integration/Dialect/Vector/GPU/CUDA/lit.local.cfg
===================================================================
--- /dev/null
+++ mlir/test/Integration/Dialect/Vector/GPU/CUDA/lit.local.cfg
@@ -0,0 +1,2 @@
+if not config.enable_cuda_runner:
+  config.unsupported = True
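The `@__shared_*` symbols checked above come from the test pass's allocator (`allocateGlobalSharedMemory`, further down in this patch), which names each global after the buffer's shape. A minimal sketch of that naming scheme for the 1-D cases in this test; `sharedSymbolName` is a hypothetical helper, not code from the patch:

```cpp
#include <cstdint>
#include <string>

// Illustration only: buffers are shaped after the *sequential* (body-side)
// type, and the symbol name encodes that shape.
std::string sharedSymbolName(int64_t numElements, const std::string &elemType) {
  return "__shared_" + std::to_string(numElements) + "x" + elemType;
}

// args: vector<4xf32> and vector<8xf32> per lane, 32 lanes
//   -> the body sees vector<128xf32> / vector<256xf32>
//   -> @__shared_128xf32 and @__shared_256xf32.
// results: vector<32xf32> / vector<64xf32> yielded by the body
//   -> @__shared_32xf32 and @__shared_64xf32,
//      read back per lane as vector<1xf32> / vector<2xf32>.
```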
Index: mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir
===================================================================
--- /dev/null
+++ mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir
@@ -0,0 +1,56 @@
+// Run the test cases without distributing ops, to test the default lowering.
+// Everything runs on the same thread.
+// RUN: mlir-opt %s -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if -canonicalize | \
+// RUN: mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \
+// RUN:   -gpu-kernel-outlining \
+// RUN:   -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm,reconcile-unrealized-casts,gpu-to-cubin)' \
+// RUN:   -gpu-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_cuda_runtime%shlibext \
+// RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
+// RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | \
+// RUN: FileCheck %s

+func.func @gpu_func(%arg1: memref<32xf32>, %arg2: memref<32xf32>) {
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  gpu.launch blocks(%arg3, %arg4, %arg5)
+      in (%arg9 = %c1, %arg10 = %c1, %arg11 = %c1)
+      threads(%arg6, %arg7, %arg8) in (%arg12 = %c32, %arg13 = %c1, %arg14 = %c1) {
+    vector.warp_execute_on_lane_0(%arg6)[32] {
+      %0 = vector.transfer_read %arg1[%c0], %cst {in_bounds = [true]} : memref<32xf32>, vector<32xf32>
+      %1 = vector.transfer_read %arg2[%c0], %cst {in_bounds = [true]} : memref<32xf32>, vector<32xf32>
+      %2 = arith.addf %0, %1 : vector<32xf32>
+      vector.transfer_write %2, %arg1[%c0] {in_bounds = [true]} : vector<32xf32>, memref<32xf32>
+    }
+    gpu.terminator
+  }
+  return
+}
+func.func @main() {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %0 = memref.alloc() : memref<32xf32>
+  %1 = memref.alloc() : memref<32xf32>
+  %cst_1 = arith.constant dense<[
+    0.0,  1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,
+    8.0,  9.0,  10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
+    16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0,
+    24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0]> : vector<32xf32>
+  %cst_2 = arith.constant dense<2.000000e+00> : vector<32xf32>
+  // Init the buffers.
+  vector.transfer_write %cst_1, %0[%c0] {in_bounds = [true]} : vector<32xf32>, memref<32xf32>
+  vector.transfer_write %cst_2, %1[%c0] {in_bounds = [true]} : vector<32xf32>, memref<32xf32>
+  %3 = memref.cast %0 : memref<32xf32> to memref<*xf32>
+  gpu.host_register %3 : memref<*xf32>
+  %5 = memref.cast %1 : memref<32xf32> to memref<*xf32>
+  gpu.host_register %5 : memref<*xf32>
+  call @gpu_func(%0, %1) : (memref<32xf32>, memref<32xf32>) -> ()
+  // Each output element i should be i + 2.0.
+  %6 = vector.transfer_read %0[%c0], %cst : memref<32xf32>, vector<32xf32>
+  vector.print %6 : vector<32xf32>
+  return
+}
+
+// CHECK: ( 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33 )
Index: mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
===================================================================
--- mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -11,12 +11,14 @@
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -700,6 +702,90 @@
   }
 };
 
+/// Allocate shared memory for a single warp to test lowering of
+/// WarpExecuteOnLane0Op.
+static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
+                                        WarpExecuteOnLane0Op warpOp,
+                                        Type type) {
+  static constexpr int64_t kSharedMemorySpace = 3;
+  // Compute type of shared memory buffer.
+  MemRefType memrefType;
+  if (auto vectorType = type.dyn_cast<VectorType>()) {
+    memrefType =
+        MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
+                        kSharedMemorySpace);
+  } else {
+    memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
+  }
+
+  // Get symbol table holding all shared memory globals.
+  ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
+  SymbolTable symbolTable(moduleOp);
+
+  // Create a pretty name.
+  SmallString<64> buf;
+  llvm::raw_svector_ostream os(buf);
+  interleave(memrefType.getShape(), os, "x");
+  os << "x" << memrefType.getElementType();
+  std::string symbolName = (Twine("__shared_") + os.str()).str();
+
+  auto ip = builder.saveInsertionPoint();
+  builder.setInsertionPoint(moduleOp);
+  auto global = builder.create<memref::GlobalOp>(
+      loc,
+      /*sym_name=*/symbolName,
+      /*sym_visibility=*/builder.getStringAttr("private"),
+      /*type=*/memrefType,
+      /*initial_value=*/Attribute(),
+      /*constant=*/false,
+      /*alignment=*/IntegerAttr());
+  symbolTable.insert(global);
+  // The symbol table inserts at the end of the module, but globals are a bit
+  // nicer if they are at the beginning.
+  global->moveBefore(&moduleOp.front());
+
+  builder.restoreInsertionPoint(ip);
+  return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
+}
+
+struct TestVectorDistribution
+    : public PassWrapper<TestVectorDistribution, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect>();
+  }
+
+  StringRef getArgument() const final { return "test-vector-warp-distribute"; }
+  StringRef getDescription() const final {
+    return "Test vector warp distribute transformation and lowering patterns";
+  }
+  TestVectorDistribution() = default;
+  TestVectorDistribution(const TestVectorDistribution &pass)
+      : PassWrapper(pass) {}
+
+  Option<bool> warpOpToSCF{
+      *this, "rewrite-warp-ops-to-scf-if",
+      llvm::cl::desc("Lower vector.warp_execute_on_lane_0 to scf.if op"),
+      llvm::cl::init(false)};
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    WarpExecuteOnLane0LoweringOptions options;
+    options.warpAllocationFn = allocateGlobalSharedMemory;
+    options.warpSynchronizationFn = [](Location loc, OpBuilder &builder,
+                                       WarpExecuteOnLane0Op warpOp) {
+      builder.create<gpu::BarrierOp>(loc);
+    };
+    // Test one pattern in isolation.
+    if (warpOpToSCF) {
+      populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
+      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+      return;
+    }
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -736,6 +822,8 @@
   PassRegistration();
 
   PassRegistration();
+
+  PassRegistration<TestVectorDistribution>();
 }
 } // namespace test
 } // namespace mlir
Index: utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
===================================================================
--- utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -586,6 +586,7 @@
         "//mlir:Affine",
         "//mlir:Analysis",
         "//mlir:FuncDialect",
+        "//mlir:GPUDialect",
        "//mlir:LLVMDialect",
        "//mlir:LinalgOps",
        "//mlir:LinalgTransforms",