diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -748,6 +748,30 @@
   }];
 }
 
+def ForeachThreadToGpuAndTranslationInfo :
+    Op<Transform_Dialect, "structured.foreach_thread_to_gpu_and_translation_info",
+       [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+        TransformEachOpTrait, TransformOpInterface]> {
+  let description = [{
+    Targets a `gpu.launch` op and rewrites all `scf.foreach_thread` ops nested
+    under it into distributed `gpu.thread_id` indexing. The `workgroup_size`
+    attribute describes the workgroup size of the enclosing `gpu.launch`; loop
+    dimensions that use fewer threads than the corresponding workgroup
+    dimension are predicated with an `scf.if`, and a `gpu.barrier` is inserted
+    after each distributed loop.
+
+    Only bufferized `scf.foreach_thread` ops of rank <= 3 with static thread
+    counts are supported.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$workgroup_size);
+  let results = (outs PDL_Operation:$result);
+
+  let assemblyFormat = "$target attr-dict";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target,
+        ::llvm::SmallVectorImpl<::mlir::Operation *> &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformEachOpTrait, TransformOpInterface]> {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -8,6 +8,8 @@
 
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
 
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+
 #include "mlir/AsmParser/AsmParser.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
@@ -21,6 +23,7 @@
 #include "mlir/Parser/Parser.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringSet.h"
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -1027,6 +1030,178 @@
   modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// ForeachThreadToGpuAndTranslationInfo
+//===----------------------------------------------------------------------===//
+
+/// Permutes `vals` by `perm` such that `result[perm[i]] == vals[i]`. Fails if
+/// `perm` is not a valid permutation of the same size as `vals`.
+template <typename T>
+static FailureOr<SmallVector<T>> permute(const SmallVector<T> &vals,
+                                         ArrayRef<int64_t> perm) {
+  if (vals.size() != perm.size()) return failure();
+  SmallVector<T> result(vals.size());
+  SmallVector<bool> seen(vals.size());
+  for (auto [idx, val] : llvm::zip(perm, vals)) {
+    // Already seen, invalid thread_dim_mapping.
+    if (seen[idx]) return failure();
+    result[idx] = val;
+    seen[idx] = true;
+  }
+  // Some not seen, invalid thread_dim_mapping.
+  if (!llvm::all_of(seen, [](bool b) { return b; })) return failure();
+  return result;
+}
+
+/// Permutes `values` by the `thread_dim_mapping` attribute of
+/// `foreachThreadOp`, if any.
+template <typename T>
+static FailureOr<SmallVector<T>> getValuesPermutedByThreadMapping(
+    scf::ForeachThreadOp foreachThreadOp, const SmallVector<T> &values) {
+  // Apply mapping permutation if specified.
+  auto mapping = foreachThreadOp.getThreadDimMapping();
+  if (mapping && !mapping.empty()) {
+    auto maybePermuted = permute(values, extractFromI64ArrayAttr(mapping));
+    if (failed(maybePermuted))
+      return foreachThreadOp->emitError("invalid permutation");
+    return *maybePermuted;
+  }
+  return values;
+}
+
+/// Returns the thread indices of `foreachThreadOp`, padded to 3 entries and
+/// permuted by `thread_dim_mapping`.
+static FailureOr<SmallVector<Value>> getThreadIndices(
+    OpBuilder &b, scf::ForeachThreadOp foreachThreadOp) {
+  SmallVector<Value> threadIndices = foreachThreadOp.getThreadIndices();
+  threadIndices.resize(3, Value());
+  return getValuesPermutedByThreadMapping(foreachThreadOp, threadIndices);
+}
+
+/// Returns the number of threads of `foreachThreadOp`, padded to 3 entries
+/// and permuted by `thread_dim_mapping`.
+static FailureOr<SmallVector<OpFoldResult>> getNumThreads(
+    OpBuilder &b, scf::ForeachThreadOp foreachThreadOp) {
+  SmallVector<OpFoldResult> threadCount = foreachThreadOp.getNumThreads();
+  threadCount.resize(3, b.getIndexAttr(1));
+  return getValuesPermutedByThreadMapping(foreachThreadOp, threadCount);
+}
+
+/// Rewrites one bufferized scf.foreach_thread to gpu.thread_id indexing.
+/// Dimensions that use fewer threads than `globalWorkgroupSizes` provides
+/// are predicated with an scf.if.
+static FailureOr<SmallVector<OpFoldResult>> rewriteForeachThreadToGpu(
+    scf::ForeachThreadOp foreachThreadOp,
+    const SmallVector<int64_t> &globalWorkgroupSizes, RewriterBase &rewriter,
+    bool syncAfterDistribute) {
+  if (foreachThreadOp.getNumResults() > 0)
+    return foreachThreadOp->emitError(
+        "only bufferized scf.foreach_thread lowers to gpu.thread");
+  if (foreachThreadOp.getNumThreads().size() > 3)
+    return foreachThreadOp->emitError(
+        "scf.foreach_thread with rank > 3 does not lower to gpu.thread");
+
+  auto maybeWorkgroupSizes = getNumThreads(rewriter, foreachThreadOp);
+  if (failed(maybeWorkgroupSizes) ||
+      llvm::any_of(*maybeWorkgroupSizes, [](OpFoldResult ofr) {
+        return !getConstantIntValue(ofr).has_value();
+      }))
+    return foreachThreadOp->emitError("unsupported dynamic workgroup size");
+
+  SmallVector<int64_t> workgroupSizes = llvm::to_vector(llvm::map_range(
+      *maybeWorkgroupSizes,
+      [](OpFoldResult ofr) { return getConstantIntValue(ofr).value(); }));
+
+  // Step 1. Create the gpu.thread ops.
+  Location loc = foreachThreadOp.getLoc();
+  IndexType indexType = rewriter.getIndexType();
+
+  SmallVector<gpu::Dimension> gpuDims{gpu::Dimension::x, gpu::Dimension::y,
+                                      gpu::Dimension::z};
+  SmallVector<Value> threadOps;
+  for (int64_t idx : llvm::seq<int64_t>(0, workgroupSizes.size())) {
+    threadOps.push_back(
+        rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpuDims[idx]));
+  }
+
+  // Step 2. Maybe create conditionals to predicate the region.
+  Value predicate;
+  for (auto [threadId, workgroupSize, globalWorkgroupSize] :
+       llvm::zip(threadOps, workgroupSizes, globalWorkgroupSizes)) {
+    if (workgroupSize > globalWorkgroupSize) {
+      return foreachThreadOp.emitOpError("workgroup size overflow: ")
+             << workgroupSize << " > " << globalWorkgroupSize;
+    }
+    if (workgroupSize == globalWorkgroupSize) continue;
+    Value tmpPredicate = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ult, threadId,
+        rewriter.create<arith::ConstantIndexOp>(loc, workgroupSize));
+    predicate =
+        predicate ? rewriter.create<arith::AndIOp>(loc, predicate, tmpPredicate)
+                  : tmpPredicate;
+  }
+
+  // Step 3. Move the body of foreachThreadOp.
+  // Erase the terminator first, it will not be used.
+  rewriter.eraseOp(foreachThreadOp.getTerminator());
+  Block *targetBlock;
+  Block::iterator insertionPoint;
+  if (predicate) {
+    // Step 3.a. If predicated, move the body to the start of the scf.if.
+    auto ifOp =
+        rewriter.create<scf::IfOp>(loc, predicate, /*withElseRegion=*/false);
+    targetBlock = ifOp.thenBlock();
+    insertionPoint = ifOp.thenBlock()->begin();
+  } else {
+    // Step 3.b. Otherwise, move the body inline just before foreachThreadOp.
+    targetBlock = foreachThreadOp->getBlock();
+    insertionPoint = Block::iterator(foreachThreadOp);
+  }
+  Block &sourceBlock = foreachThreadOp.getRegion().front();
+  targetBlock->getOperations().splice(insertionPoint,
+                                      sourceBlock.getOperations());
+
+  // Step 4. RAUW thread indices to thread ops.
+  SmallVector<Value> threadIndices =
+      *getThreadIndices(rewriter, foreachThreadOp);
+  for (auto it : llvm::zip(threadIndices, threadOps)) {
+    if (!std::get<0>(it)) continue;
+    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+  }
+
+  // Step 5. syncthreads.
+  if (syncAfterDistribute) rewriter.create<gpu::BarrierOp>(loc);
+
+  // Step 6. Erase old op.
+  rewriter.eraseOp(foreachThreadOp);
+
+  return *maybeWorkgroupSizes;
+}
+
+DiagnosedSilenceableFailure
+transform::ForeachThreadToGpuAndTranslationInfo::applyToOne(
+    Operation *target, SmallVectorImpl<Operation *> &results,
+    transform::TransformState &state) {
+  SmallVector<int64_t> workgroupSize =
+      extractFromI64ArrayAttr(getWorkgroupSize());
+
+  gpu::LaunchOp kernelLaunch = dyn_cast<gpu::LaunchOp>(target);
+  if (!kernelLaunch) {
+    target->emitError("target must be a gpu.launch op");
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  workgroupSize.resize(/*size=*/3, /*value=*/1);
+
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+
+  auto walkResult = target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
+    rewriter.setInsertionPoint(foreachThreadOp);
+    if (failed(rewriteForeachThreadToGpu(foreachThreadOp, workgroupSize,
+                                         rewriter,
+                                         /*syncAfterDistribute=*/true)))
+      return WalkResult::interrupt();
+    return WalkResult::advance();
+  });
+
+  if (walkResult.wasInterrupted())
+    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+
+  // TODO: Assign the workgroup size to the kernel launch.
+
+  results.assign({target});
+  return DiagnosedSilenceableFailure(success());
+}
+
 //===----------------------------------------------------------------------===//
 // TileToForeachThreadOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-gpu.mlir b/mlir/test/Dialect/Linalg/transform-gpu.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-gpu.mlir
@@ -0,0 +1,134 @@
+// RUN: mlir-opt %s \
+// RUN:   -test-transform-dialect-interpreter \
+// RUN:   -canonicalize \
+// RUN:   -convert-scf-to-cf \
+// RUN:   -convert-cf-to-llvm \
+// RUN:   -convert-vector-to-llvm \
+// RUN:   -convert-arith-to-llvm \
+// RUN:   -gpu-kernel-outlining \
+// RUN:   -convert-math-to-llvm \
+// RUN:   -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm,reconcile-unrealized-casts,gpu-to-cubin)' \
+// RUN:   -gpu-to-llvm \
+// RUN:   -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_lib_dir/libmlir_cuda_runtime.%shlibext \
+// RUN:   -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils.%shlibext \
+// RUN:   -shared-libs=%mlir_lib_dir/libmlir_runner_utils.%shlibext | \
+// RUN: FileCheck %s
+
+!type = memref<2 x 32 x f32>
+
+func.func @saxpy2d(%x: !type, %y: !type, %alpha : f32, %stream : !gpu.async.token) -> !type {
+  %grid1 = arith.constant 1 : index
+  %grid2 = arith.constant 1 : index
+  %grid3 = arith.constant 1 : index
+  %cta1 = arith.constant 2 : index
+  %cta2 = arith.constant 32 : index
+  %cta3 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %m = arith.constant 2 : index
+  %n = arith.constant 32 : index
+  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %grid1, %arg10 = %grid2, %arg11 = %grid3)
+             threads(%arg6, %arg7, %arg8) in (%arg12 = %cta1, %arg13 = %cta2, %arg14 = %cta3)
+  {
+    scf.foreach_thread (%i, %j) in (%n, %m) {
+      %4 = memref.load %x[%j, %i] : !type
+      %5 = memref.load %y[%j, %i] : !type
+      %6 = math.fma %alpha, %4, %5 : f32
+      memref.store %6, %y[%j, %i] : !type
+    } {thread_dim_mapping = [0, 1, 2]}
+    gpu.terminator
+  }
+
+  return %y : !type
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0
+    transform.structured.foreach_thread_to_gpu_and_translation_info %funcop { workgroup_size = [32, 2, 1] }
+  }
+}
+
+func.func private @getn() -> index {
+  %n = arith.constant 2 : index
+  return %n : index
+}
+
+func.func private @getm() -> index {
+  %m = arith.constant 32 : index
+  return %m : index
+}
+
+func.func @main() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 4.000000e+00 : f32
+  %cst_0 = arith.constant 1.000000e+00 : f32
+  %alpha = arith.constant 0.2 : f32
+  %n = call @getn() : () -> index
+  %m = call @getm() : () -> index
+  %x = memref.alloc() {alignment = 64 : i64} : !type
+  %y = memref.alloc() {alignment = 64 : i64} : !type
+
+  // Initialize x to 1.0 and y[i][j] to i + j.
+  scf.for %arg0 = %c0 to %n step %c1 {
+    scf.for %arg1 = %c0 to %m step %c1 {
+      %l1 = arith.index_cast %arg1 : index to i32
+      %l2 = arith.sitofp %l1 : i32 to f32
+      %l3 = arith.index_cast %arg0 : index to i32
+      %l4 = arith.sitofp %l3 : i32 to f32
+      %l5 = arith.addf %l4, %l2 : f32
+      memref.store %cst_0, %x[%arg0, %arg1] : !type
+      memref.store %l5, %y[%arg0, %arg1] : !type
+    }
+  }
+
+  %stream = gpu.wait async
+  %dx = call @alloccopyme(%x, %stream) : (!type, !gpu.async.token) -> !type
+  %dy = call @alloccopyme(%y, %stream) : (!type, !gpu.async.token) -> !type
+
+  %res = call @saxpy2d(%dx, %dy, %alpha, %stream) : (!type, !type, f32, !gpu.async.token) -> !type
+
+  call @copyme(%dy, %y, %stream) : (!type, !type, !gpu.async.token) -> ()
+  gpu.wait [%stream]
+  call @printme(%y) : (!type) -> ()
+
+  memref.dealloc %x : !type
+  memref.dealloc %y : !type
+  return
+}
+
+func.func private @allocme(%stream : !gpu.async.token) -> !type {
+  %r, %asyncToken = gpu.alloc async [%stream] () : !type
+  return %r : !type
+}
+
+func.func private @copyme(%dest : !type, %src : !type, %stream : !gpu.async.token) -> () {
+  %4 = gpu.memcpy async [%stream] %src, %dest : !type, !type
+  return
+}
+
+func.func private @alloccopyme(%ptr : !type, %stream : !gpu.async.token) -> !type {
+  %dev = call @allocme(%stream) : (!gpu.async.token) -> !type
+  call @copyme(%ptr, %dev, %stream) : (!type, !type, !gpu.async.token) -> ()
+  return %dev : !type
+}
+
+func.func private @printme(%ptr : !type) {
+  %1 = memref.cast %ptr : !type to memref<*xf32>
+  call @printMemrefF32(%1) : (memref<*xf32>) -> ()
+  return
+}
+
+func.func private @printMemrefF32(%ptr : memref<*xf32>)
+
+// CHECK: 0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2, 9.2, 10.2, 11.2, 12.2, 13.2, 14.2, 15.2, 16.2, 17.2, 18.2, 19.2, 20.2, 21.2, 22.2, 23.2, 24.2, 25.2, 26.2, 27.2, 28.2, 29.2, 30.2, 31.2
+// CHECK: 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2, 9.2, 10.2, 11.2, 12.2, 13.2, 14.2, 15.2, 16.2, 17.2, 18.2, 19.2, 20.2, 21.2, 22.2, 23.2, 24.2, 25.2, 26.2, 27.2, 28.2, 29.2, 30.2, 31.2, 32.2
\ No newline at end of file