diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
@@ -14,6 +14,29 @@
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def ApplyUnrollVectorsSubgroupMmaOp : Op<Transform_Dialect,
+    "apply_patterns.gpu.unroll_vectors_subgroup_mma",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Unrolls contractions to the target `m`, `n`, and `k` native vector size,
+    along with other vector operations based on expected usage. `transfer_read`
+    ops unroll based on the extract slice shape introduced by unrolling the
+    contractions, while elementwise and `transfer_write` ops unroll to the shape of
+    the C matrix (`m x n`).
+
+    This operation applies to pure vector operations and should be applied before
+    lowering to subgroup_mma ops.
+  }];
+
+  let arguments = (ins I64Attr:$m,
+                       I64Attr:$n,
+                       I64Attr:$k);
+
+  let assemblyFormat = [{
+    `[` $m `,` $n `,` $k `]` attr-dict
+  }];
+}
+
 def EliminateBarriersOp :
   Op<Transform_Dialect, "apply_patterns.gpu.eliminate_barriers",
      [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -46,6 +47,149 @@
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 #define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
 
+//===----------------------------------------------------------------------===//
+// ApplyUnrollVectorsSubgroupMmaOp
+//===----------------------------------------------------------------------===//
+
+/// Pick an unrolling order that allows the tensor core operations to reuse the
+/// LHS register: reduction dimensions are unrolled outermost, then the
+/// parallel dimensions that index the LHS, then the remaining parallel
+/// dimensions.
+static std::optional<SmallVector<int64_t>>
+gpuMmaUnrollOrder(vector::ContractionOp contract) {
+  SmallVector<int64_t> order;
+  // First make reduction the outer dimensions.
+  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
+    if (vector::isReductionIterator(iter)) {
+      order.push_back(index);
+    }
+  }
+
+  llvm::SmallDenseSet<int64_t> dims;
+  for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) {
+    dims.insert(expr.cast<AffineDimExpr>().getPosition());
+  }
+  // Then parallel dimensions that are part of Lhs as we want to re-use Lhs.
+  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
+    if (vector::isParallelIterator(iter) && dims.count(index)) {
+      order.push_back(index);
+    }
+  }
+  // Then the remaining parallel loops.
+  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) {
+    if (vector::isParallelIterator(iter) && !dims.count(index)) {
+      order.push_back(index);
+    }
+  }
+  return order;
+}
+
+/// Returns the result type shared by all `vector.extract_strided_slice` users
+/// of `op`. Returns a null type when `op` has no users, and failure when a
+/// user is not an `extract_strided_slice` or the users' slice types differ.
+static FailureOr<VectorType> getCommonExtractSliceType(Operation *op) {
+  VectorType sliceType;
+  for (Operation *user : op->getUsers()) {
+    auto extract = dyn_cast<vector::ExtractStridedSliceOp>(user);
+    if (!extract)
+      return failure();
+    auto extractType = extract.getResult().getType().cast<VectorType>();
+    if (sliceType && sliceType != extractType)
+      return failure();
+    sliceType = extractType;
+  }
+  return sliceType;
+}
+
+/// Returns the target vector size for the target operation based on the native
+/// vector size specified with `m`, `n`, and `k`:
+///   * `vector.contract` unrolls to `[1, ..., 1, m, n, k]`, which assumes the
+///     three innermost iterators are the m, n and k dimensions, in that order.
+///   * `vector.transfer_write` unrolls to `[1, ..., 1, m, n]`.
+///   * `vector.transfer_read` unrolls to the shape of the
+///     `vector.extract_strided_slice` ops that consume it.
+///   * Single-result elementwise ops producing a vector of rank >= 2 unroll to
+///     the shape of the `vector.extract_strided_slice` ops that consume them,
+///     or to `[1, ..., 1, m, n]` when they have no users.
+/// Returns std::nullopt when the op should not be unrolled: unsupported ops or
+/// ranks, and reads/elementwise ops with a user that is not an
+/// `extract_strided_slice` or with slice users of differing types.
+static std::optional<SmallVector<int64_t>>
+getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
+  if (auto contract = dyn_cast<vector::ContractionOp>(op)) {
+    int64_t contractRank = contract.getIteratorTypes().size();
+    if (contractRank < 3)
+      return std::nullopt;
+    SmallVector<int64_t> nativeSize(contractRank - 3, 1);
+    nativeSize.append({m, n, k});
+    return nativeSize;
+  }
+  if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
+    int64_t writeRank = writeOp.getVectorType().getRank();
+    if (writeRank < 2)
+      return std::nullopt;
+    SmallVector<int64_t> nativeSize(writeRank - 2, 1);
+    nativeSize.append({m, n});
+    return nativeSize;
+  }
+  if (isa<vector::TransferReadOp>(op)) {
+    // Transfer read ops may need different shapes based on how they are being
+    // used. For simplicity just match the shape used by the extract strided op.
+    // A read without users has no slice shape to match; do not unroll it.
+    FailureOr<VectorType> sliceType = getCommonExtractSliceType(op);
+    if (failed(sliceType) || !*sliceType)
+      return std::nullopt;
+    return llvm::to_vector(sliceType->getShape());
+  }
+  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
+    if (auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>()) {
+      // TODO: The condition for unrolling elementwise should be restricted
+      // only to operations that need unrolling (connected to the contract).
+      if (vecType.getRank() < 2)
+        return std::nullopt;
+
+      // First check whether there is a slice to infer the shape from. This is
+      // required for cases where the accumulator type differs from the input
+      // types, in which case we will see an `arith.ext_` between the contract
+      // and transfer_read which needs to be unrolled.
+      FailureOr<VectorType> sliceType = getCommonExtractSliceType(op);
+      if (failed(sliceType))
+        return std::nullopt;
+      if (*sliceType)
+        return llvm::to_vector(sliceType->getShape());
+
+      // Else (the op has no users) unroll for trailing elementwise.
+      SmallVector<int64_t> nativeSize(vecType.getRank() - 2, 1);
+      // Map elementwise ops to the output shape.
+      nativeSize.append({m, n});
+      return nativeSize;
+    }
+  }
+  return std::nullopt;
+}
+
+void transform::ApplyUnrollVectorsSubgroupMmaOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  auto unrollOrder = [](Operation *op) -> std::optional<SmallVector<int64_t>> {
+    auto contract = dyn_cast<vector::ContractionOp>(op);
+    if (!contract)
+      return std::nullopt;
+    return gpuMmaUnrollOrder(contract);
+  };
+
+  int64_t m = getM();
+  int64_t n = getN();
+  int64_t k = getK();
+  auto nativeShapeFn =
+      [m, n, k](Operation *op) -> std::optional<SmallVector<int64_t>> {
+    return getSubgroupMmaNativeVectorSize(op, m, n, k);
+  };
+  vector::populateVectorUnrollPatterns(
+      patterns, vector::UnrollVectorOptions()
+                    .setNativeShapeFn(nativeShapeFn)
+                    .setUnrollTraversalOrderFn(unrollOrder));
+}
+
 //===----------------------------------------------------------------------===//
 // EliminateBarriersOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/GPU/subgroup-mma-vector-unroll.mlir b/mlir/test/Dialect/GPU/subgroup-mma-vector-unroll.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/GPU/subgroup-mma-vector-unroll.mlir
@@ -0,0 +1,99 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @matmul
+func.func @matmul(%lhs: memref<32x32xf32>, %rhs: memref<32x32xf32>, %out: memref<32x32xf32>) {
+  %c8 = arith.constant 8 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.000000e+00> : vector<16x16xf32>
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %3 = gpu.thread_id x
+  %4 = gpu.thread_id y
+  %5 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%4]
+  %6 = affine.apply affine_map<()[s0] -> ((s0 floordiv 32) * 16)>()[%3]
+  // CHECK: scf.for {{.*}} -> (vector<16x16xf32>) {
+  // CHECK-COUNT-2: vector.transfer_read {{.*}} vector<16x8xf32>
+  // CHECK-COUNT-2: vector.transfer_read {{.*}} vector<8x16xf32>
+  // CHECK-COUNT-2: vector.contract {{.*}} vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
+  // CHECK: scf.yield {{.*}} : vector<16x16xf32>
+  // CHECK: }
+  %7 = scf.for %arg0 = %c0 to %c32 step %c16 iter_args(%arg1 = %cst) -> (vector<16x16xf32>) {
+    %10 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%5]
+    %11 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%arg0]
+    %12 = vector.transfer_read %lhs[%10, %11], %cst_0 {in_bounds = [true, true]} : memref<32x32xf32>, vector<16x16xf32>
+    %16 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%6]
+    %17 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%arg0]
+    %18 = vector.transfer_read %rhs[%17, %16], %cst_0 {in_bounds = [true, true]} : memref<32x32xf32>, vector<16x16xf32>
+    %22 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %12, %18, %arg1 : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
+    scf.yield %22 : vector<16x16xf32>
+  }
+  %8 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%5]
+  %9 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%6]
+  vector.transfer_write %7, %out[%8, %9] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32>
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%func_op: !transform.op<"func.func">):
+  transform.apply_patterns to %func_op {
+    transform.apply_patterns.gpu.unroll_vectors_subgroup_mma [16, 16, 8]
+  } : !transform.op<"func.func">
+}
+
+// -----
+
+// CHECK-LABEL: func.func @gathered_matmul
+func.func @gathered_matmul(%lhs: memref<32x32xf32>, %rhs: memref<32x32xf32>, %out: memref<32x32xf32>) {
+  %c8 = arith.constant 8 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.000000e+00> : vector<16x16xf32>
+  %cst_mask = arith.constant dense<true> : vector<4x4xi1>
+  %cst_pt = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %cst_1 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+  %cst_2 = arith.constant dense<1> : vector<4x4xindex>
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32>
+  %3 = gpu.thread_id x
+  %4 = gpu.thread_id y
+  %5 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%4]
+  %6 = affine.apply affine_map<()[s0] -> ((s0 floordiv 32) * 16)>()[%3]
+  // CHECK: scf.for {{.*}} -> (vector<16x16xf32>) {
+  // CHECK: arith.addi {{.*}} : vector<4xindex>
+  // CHECK: vector.gather {{.*}} : memref<32x32xf32>, vector<4x4xindex>, vector<4x4xi1>, vector<4x4xf32> into vector<4x4xf32>
+  // CHECK-COUNT-8: vector.transfer_read {{.*}} vector<8x4xf32>
+  // CHECK-COUNT-4: vector.transfer_read {{.*}} vector<4x16xf32>
+  // CHECK-COUNT-8: vector.contract {{.*}} vector<8x4xf32>, vector<4x16xf32> into vector<8x16xf32>
+  // CHECK: scf.yield {{.*}} : vector<16x16xf32>
+  // CHECK: }
+  %7 = scf.for %arg0 = %c0 to %c32 step %c16 iter_args(%arg1 = %cst) -> (vector<16x16xf32>) {
+    %10 = vector.broadcast %arg0 : index to vector<4xindex>
+    %11 = arith.addi %10, %cst_1 : vector<4xindex>
+    %12 = vector.broadcast %11 : vector<4xindex> to vector<4x4xindex>
+    %13 = arith.addi %12, %cst_2 : vector<4x4xindex>
+    %14 = vector.gather %lhs[%c0, %c0] [%13], %cst_mask, %cst_pt : memref<32x32xf32>, vector<4x4xindex>, vector<4x4xi1>, vector<4x4xf32> into vector<4x4xf32>
+    vector.transfer_write %14, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, memref<32x32xf32>
+    gpu.barrier
+    %15 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%5]
+    %16 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%arg0]
+    %17 = vector.transfer_read %alloc[%15, %16], %cst_0 {in_bounds = [true, true]} : memref<32x32xf32>, vector<16x16xf32>
+    %18 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%6]
+    %19 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%arg0]
+    %20 = vector.transfer_read %rhs[%19, %18], %cst_0 {in_bounds = [true, true]} : memref<32x32xf32>, vector<16x16xf32>
+    %21 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %17, %20, %arg1 : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
+    scf.yield %21 : vector<16x16xf32>
+  }
+  %8 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%5]
+  %9 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%6]
+  vector.transfer_write %7, %out[%8, %9] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32>
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%func_op: !transform.op<"func.func">):
+  transform.apply_patterns to %func_op {
+    transform.apply_patterns.gpu.unroll_vectors_subgroup_mma [8, 16, 4]
+  } : !transform.op<"func.func">
+}