diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -903,6 +903,57 @@
 }];
 }
+def ApplyPatternsOp : Op<Transform_Dialect, "structured.apply_patterns",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformEachOpTrait, TransformOpInterface]> {
+  let description = [{
+    Greedily applies the patterns selected by this op's unit attributes to the
+    target operation.
+
+    The target must have the IsolatedFromAbove trait because the greedy
+    pattern rewrite driver asserts on anything else.
+
+    Returns a handle to the same IsolatedFromAbove op, whose contents have
+    been modified, so that further transforms can be chained on it.
+
+    The following additive attributes can be set; they add patterns in an
+    unspecified order:
+      - canonicalization: adds the canonicalization patterns of all loaded
+        dialects and registered ops.
+      - rank_reducing: adds patterns that result in rank-reducing behavior on
+        subset-based operations.
+      - vector_to_gpu: adds patterns that prepare vector dialect ops for
+        lowering to the gpu dialect, and additionally runs the vector-to-MMA
+        conversion on the target after the greedy rewrite.
+
+    Return modes:
+    =============
+    This operation applies the selected patterns greedily. It must target an
+    operation that is isolated from above; otherwise the transform definitely
+    fails.
+
+    If the greedy pattern application fails to converge, the transformation
+    definitely fails.
+
+    Otherwise the transformation succeeds and the result handle points to the
+    (modified) target operation.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                       UnitAttr:$canonicalization,
+                       UnitAttr:$rank_reducing,
+                       UnitAttr:$vector_to_gpu);
+  let results = (outs PDL_Operation:$result);
+
+  let assemblyFormat = "$target attr-dict";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target,
+        ::llvm::SmallVectorImpl<::mlir::Operation *> &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformEachOpTrait, TransformOpInterface]> {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/raw_ostream.h"
+#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -1507,6 +1508,61 @@
   return DiagnosedSilenceableFailure(success());
 }
 
+//===----------------------------------------------------------------------===//
+// ApplyPatternsOp
+//===----------------------------------------------------------------------===//
+
+// Patterns that fold away unit dimensions in linalg ops and cast away leading
+// unit dimensions in vector ops.
+static void addRankReducingPatterns(RewritePatternSet &patterns) {
+  vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+  linalg::populateFoldUnitExtentDimsPatterns(patterns);
+}
+
+// Canonicalization patterns of every loaded dialect and every registered op.
+static void addAllRegisteredCanonicalizationPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  for (Dialect *dialect : ctx->getLoadedDialects())
+    dialect->getCanonicalizationPatterns(patterns);
+  for (RegisteredOperationName op : ctx->getRegisteredOperations())
+    op.getCanonicalizationPatterns(patterns, ctx);
+}
+
+// Patterns preparing vector ops for the conversion to GPU MMA ops.
+static void addVectorToGPU(RewritePatternSet &patterns) {
+  populatePrepareVectorToMMAPatterns(patterns, /*useNvGpu=*/false);
+}
+
+// Non-pattern follow-up that converts the prepared vector ops to
+// gpu.subgroup_mma_* ops.
+static void postVectorToGPU(Operation *target) {
+  (void)convertVectorToMMAOps(target);
+}
+
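+// Implementation of transform::ApplyPatternsOp: check that the target is
+// isolated from above, assemble the pattern sets selected by the unit
+// attributes, run them through the greedy rewrite driver, and, when
+// vector_to_gpu is set, follow up with convertVectorToMMAOps on the same
+// target. Illustrative transform IR driving this op (the result handle may be
+// left unbound when unused, as in the test below):
+//
+//   %0 = transform.structured.apply_patterns %target { canonicalization }
+//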
+DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
+    Operation *target, SmallVectorImpl<Operation *> &results,
+    transform::TransformState &state) {
+  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+    target->emitOpError(
+        "applies only to isolated-from-above targets because it needs to "
+        "apply patterns greedily");
+    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+  }
+
+  // Gather the requested pattern sets.
+  MLIRContext *ctx = target->getContext();
+  RewritePatternSet patterns(ctx);
+  if (getVectorToGpu())
+    addVectorToGPU(patterns);
+  if (getCanonicalization())
+    addAllRegisteredCanonicalizationPatterns(patterns);
+  if (getRankReducing())
+    addRankReducingPatterns(patterns);
+
+  // Apply the patterns greedily within the isolated-from-above target.
+  GreedyRewriteConfig config;
+  LogicalResult result =
+      applyPatternsAndFoldGreedily(target, std::move(patterns), config);
+  if (failed(result))
+    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+
+  // vector_to_gpu additionally requires a non-pattern post-processing step.
+  if (getVectorToGpu())
+    postVectorToGPU(target);
+
+  results.assign({target});
+  return DiagnosedSilenceableFailure(success());
+}
 
 //===----------------------------------------------------------------------===//
 // TileToForeachThreadOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-gpu.mlir b/mlir/test/Dialect/Linalg/transform-gpu.mlir
--- a/mlir/test/Dialect/Linalg/transform-gpu.mlir
+++ b/mlir/test/Dialect/Linalg/transform-gpu.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file -canonicalize -cse %s | FileCheck %s
 
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
@@ -53,19 +53,19 @@
    %c12 = arith.constant 12 : index
    %c9 = arith.constant 9 : index
    %c7 = arith.constant 7 : index
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[C9:.*]] = arith.constant 9 : index
+// CHECK:   %[[C7:.*]] = arith.constant 7 : index
 // CHECK:   gpu.launch
 // CHECK:   %[[TIDX:.*]] = gpu.thread_id x
 // CHECK:   %[[TIDY:.*]] = gpu.thread_id y
-// CHECK:   %[[C9:.*]] = arith.constant 9 : index
 // CHECK:   arith.cmpi ult, %[[TIDX]], %[[C9]] : index
-// CHECK:   %[[C7:.*]] = arith.constant 7 : index
 // CHECK:   arith.cmpi ult, %[[TIDY]], %[[C7]] : index
 // CHECK:   memref.load %[[ARGX]][%[[TIDY]], %[[TIDX]]]
 // CHECK:   memref.load %[[ARGY]][%[[TIDY]], %[[TIDX]]]
 // CHECK:   gpu.barrier
 // CHECK:   %[[TIDX2:.*]] = gpu.thread_id x
 // CHECK:   %[[TIDY2:.*]] = gpu.thread_id y
-// CHECK:   %[[C1:.*]] = arith.constant 1 : index
 // CHECK:   arith.cmpi ult, %[[TIDY2]], %[[C1]] : index
 // CHECK:   memref.load %[[ARGT]][%[[TIDX2]]]
 // CHECK:   gpu.barrier
@@ -135,3 +135,81 @@
     transform.structured.map_nested_foreach_thread_to_gpu_threads %gpuLaunch { blockDim = [32, 4, 1] }
   }
 }
+
+// -----
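+// Checks the full pipeline on a matmul that has already been tiled with two
+// levels of scf.foreach_thread: the outer level is mapped to gpu.launch
+// blocks, the inner level to threads, the remaining 4x4 linalg.matmul is
+// vectorized, and apply_patterns { vector_to_gpu } lowers the resulting
+// vector ops to gpu.subgroup_mma_* operations.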
+// CHECK-LABEL: func.func @tiled_matmul(
+// CHECK-SAME:    %[[INX:[0-9a-z]+]]: memref<8192x4096xf16>
+// CHECK-SAME:    %[[INY:[0-9a-z]+]]: memref<4096x8192xf16>
+// CHECK-SAME:    %[[OUT:[0-9a-z]+]]: memref<8192x8192xf16>
+func.func @tiled_matmul(%arg0: memref<8192x4096xf16>, %arg1: memref<4096x8192xf16>, %arg2: memref<8192x8192xf16>) -> memref<8192x8192xf16> {
+  %c0 = arith.constant 0 : index
+  %c8192 = arith.constant 8192 : index
+  %c256 = arith.constant 256 : index
+  %c1 = arith.constant 1 : index
+// CHECK:   gpu.launch
+// CHECK:   %[[BLKX:.*]] = gpu.block_id x
+// CHECK:   %[[BLKY:.*]] = gpu.block_id y
+// CHECK:   %[[BIDINX:[0-9a-z]+]] = memref.subview %[[INX]][%[[BLKX]], 0]
+// CHECK:   %[[BIDINY:[0-9a-z]+]] = memref.subview %[[INY]][0, %[[BLKY]]]
+// CHECK:   %[[BIDOUT:[0-9a-z]+]] = memref.subview %[[OUT]][%[[BLKX]], %[[BLKY]]]
+// CHECK:   %[[TIDX:.*]] = gpu.thread_id x
+// CHECK:   %[[TIDY:.*]] = gpu.thread_id y
+// CHECK:   %[[TIDINX:[0-9a-z]+]] = memref.subview %[[BIDINX]][%[[TIDX]], 0]
+// CHECK:   %[[TIDINY:[0-9a-z]+]] = memref.subview %[[BIDINY]][0, %[[TIDY]]]
+// CHECK:   %[[TIDOUT:[0-9a-z]+]] = memref.subview %[[BIDOUT]][%[[TIDX]], %[[TIDY]]]
+// CHECK:   scf.for %[[IV:[0-9a-z]+]]
+// CHECK:     %[[TILEDINX:[0-9a-z]+]] = memref.subview %[[TIDINX]][0, %[[IV]]]
+// CHECK:     %[[TILEDINY:[0-9a-z]+]] = memref.subview %[[TIDINY]][%[[IV]], 0]
+// CHECK:     %[[SUBINX:[0-9a-z]+]] = gpu.subgroup_mma_load_matrix %[[TILEDINX]]
+// CHECK:     %[[SUBINY:[0-9a-z]+]] = gpu.subgroup_mma_load_matrix %[[TILEDINY]]
+// CHECK:     %[[SUBOUT:[0-9a-z]+]] = gpu.subgroup_mma_load_matrix %[[TIDOUT]]
+// CHECK:     %[[RES:[0-9a-z]+]] = gpu.subgroup_mma_compute %[[SUBINX]], %[[SUBINY]], %[[SUBOUT]] : !gpu.mma_matrix<4x4xf16, "AOp">, !gpu.mma_matrix<4x4xf16, "BOp"> -> !gpu.mma_matrix<4x4xf16, "COp">
+// CHECK:     gpu.subgroup_mma_store_matrix %[[RES]], %[[TIDOUT]]
+  scf.foreach_thread (%i, %j) in (%c256, %c256) {
+    %0 = memref.subview %arg0[%i, 0] [256, 4096] [1, 1] : memref<8192x4096xf16> to memref<256x4096xf16, strided<[4096, 1], offset: ?>>
+    %1 = memref.subview %arg1[0, %j] [4096, 256] [1, 1] : memref<4096x8192xf16> to memref<4096x256xf16, strided<[8192, 1], offset: ?>>
+    %2 = memref.subview %arg2[%i, %j] [256, 256] [1, 1] : memref<8192x8192xf16> to memref<256x256xf16, strided<[8192, 1], offset: ?>>
+    %c4 = arith.constant 4 : index
+    scf.foreach_thread (%ti, %tj) in (%c4, %c4) {
+      %3 = memref.subview %0[%ti, 0] [4, 4096] [1, 1] : memref<256x4096xf16, strided<[4096, 1], offset: ?>> to memref<4x4096xf16, strided<[4096, 1], offset: ?>>
+      %4 = memref.subview %1[0, %tj] [4096, 4] [1, 1] : memref<4096x256xf16, strided<[8192, 1], offset: ?>> to memref<4096x4xf16, strided<[8192, 1], offset: ?>>
+      %5 = memref.subview %2[%ti, %tj] [4, 4] [1, 1] : memref<256x256xf16, strided<[8192, 1], offset: ?>> to memref<4x4xf16, strided<[8192, 1], offset: ?>>
+      %c4096 = arith.constant 4096 : index
+      scf.for %arg7 = %c0 to %c4096 step %c4 {
+        %6 = memref.subview %3[0, %arg7] [4, 4] [1, 1] : memref<4x4096xf16, strided<[4096, 1], offset: ?>> to memref<4x4xf16, strided<[4096, 1], offset: ?>>
+        %7 = memref.subview %4[%arg7, 0] [4, 4] [1, 1] : memref<4096x4xf16, strided<[8192, 1], offset: ?>> to memref<4x4xf16, strided<[8192, 1], offset: ?>>
+        linalg.matmul ins(%6, %7 : memref<4x4xf16, strided<[4096, 1], offset: ?>>, memref<4x4xf16, strided<[8192, 1], offset: ?>>) outs(%5 : memref<4x4xf16, strided<[8192, 1], offset: ?>>)
+      }
+    }
+  }
+  return %arg2 : memref<8192x8192xf16>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %funcop = transform.structured.match ops{["func.func"]} in %arg1
+
+    // This op transforms as follows:
+    //   scf.foreach_thread(i, j) --> gpu.launch (blockIdx.x, blockIdx.y)
+    //     scf.foreach_thread(ti, tj)
+    //       body
+    %gpuLaunch = transform.structured.map_nested_foreach_thread_to_gpu_blocks %funcop { generate_gpu_launch }
+
+    // This op transforms as follows:
+    //   gpu.launch (...)
+    //     i = blockIdx.x
+    //     j = blockIdx.y
+    //     scf.foreach_thread(ti, tj) --> (threadIdx.x, threadIdx.y)
+    //       body
+    %gpuLaunch2 = transform.structured.map_nested_foreach_thread_to_gpu_threads %gpuLaunch { blockDim = [4, 4, 1] }
+
+    // linalg.matmul -> vector dialect.
+    %isolatedFunc = get_closest_isolated_parent %gpuLaunch2
+    %vectorizedFunc = transform.structured.vectorize %isolatedFunc
+
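+    // Vectorization rewrites the remaining 4x4 linalg.matmul into
+    // vector.transfer_read / vector.contract / vector.transfer_write ops;
+    // the vector_to_gpu patterns plus the convertVectorToMMAOps post-pass
+    // then lower those to the gpu.subgroup_mma_* ops checked above.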
+    // vector dialect -> GPU warp synchronous matrix multiply accumulate.
+    transform.structured.apply_patterns %vectorizedFunc { vector_to_gpu }
+  }
+}
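+
+// Note: the RUN line runs -canonicalize -cse after the transform interpreter
+// so that redundant constants and subviews are folded away before FileCheck
+// inspects the output.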