diff --git a/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h b/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
--- a/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
+++ b/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
@@ -93,6 +93,18 @@
 getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                const LdMatrixParams &params);
 
+/// Returns whether the `vector.transfer_read` instruction can be interpreted
+/// as a warp-level cooperative matrix load operation. This function is meant
+/// to be used to establish whether `op` is part of a chain of such warp-level
+/// operations.
+bool canLowerToWarpMatrixOperation(vector::TransferReadOp op);
+
+/// Returns whether the `vector.transfer_write` instruction can be interpreted
+/// as a warp-level cooperative matrix store operation. This function is meant
+/// to be used to establish whether `op` is part of a chain of such warp-level
+/// operations.
+bool canLowerToWarpMatrixOperation(vector::TransferWriteOp op);
+
 } // namespace nvgpu
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -119,10 +119,9 @@
          permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
 }
 
-// Return the stide for the dimension 0 of |type| if it is a memref and has a
-// constant stride.
-static std::optional<int64_t>
-getMemrefConstantHorizontalStride(ShapedType type) {
+// Return the stride for the second-to-last dimension of |type| if it is a
+// memref and has a constant stride.
+static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
   auto memrefType = dyn_cast<MemRefType>(type);
   if (!memrefType)
     return false;
@@ -141,35 +140,27 @@
 }
 
 // Return true if the transfer op can be converted to a MMA matrix load.
-static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
-                                              bool useNvGpu) {
+static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
   if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
       readOp.getVectorType().getRank() != 2)
     return false;
-  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
+  if (!getStaticallyKnownRowStride(readOp.getShapedType()))
     return false;
 
   // Only allow integer types if the signedness can be inferred.
-  if (!useNvGpu && readOp.getVectorType().getElementType().isInteger(8))
+  if (readOp.getVectorType().getElementType().isInteger(8))
    if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
                                 !isa<arith::ExtUIOp>(*readOp->user_begin())))
       return false;
 
   AffineMap map = readOp.getPermutationMap();
-
   MLIRContext *ctx = readOp.getContext();
   AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
   AffineExpr zero = getAffineConstantExpr(0, ctx);
   auto broadcastInnerDim =
       AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
-
-  if (!useNvGpu) {
-    bool result = map.isMinorIdentity() || map == broadcastInnerDim ||
-                  isTransposeMatrixLoadMap(map);
-    return result;
-  }
-
-  return true;
+  return map.isMinorIdentity() || map == broadcastInnerDim ||
+         isTransposeMatrixLoadMap(map);
 }
 
 // Return true if the transfer op can be converted to a MMA matrix store.
@@ -182,7 +173,7 @@
   if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
       writeOp.getVectorType().getRank() != 2)
     return false;
-  if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
+  if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
     return false;
   // TODO: Support transpose once it is added to GPU dialect ops.
   if (!writeOp.getPermutationMap().isMinorIdentity())
@@ -281,9 +272,11 @@
   if (isa<scf::ForOp, scf::YieldOp>(op))
     return true;
   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
-    return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
+    return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
+                    : transferReadSupportsMMAMatrixType(transferRead);
   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
-    return transferWriteSupportsMMAMatrixType(transferWrite);
+    return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
+                    : transferWriteSupportsMMAMatrixType(transferWrite);
   if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
     return useNvGpu &&
            extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
@@ -366,9 +359,14 @@
     // chain. MMA matrix are stored in an opaque type so they cannot be used
     // by all operations.
     if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
-          return !supportsMMaMatrixType(op, useNvGpu);
+          if (!supportsMMaMatrixType(op, useNvGpu)) {
+            LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
+            return true;
+          }
+          return false;
         }))
       return;
+
     opToConvert.insert(dependentOps.begin(), dependentOps.end());
   });
   // Sort the operations so that we can convert them in topological order.
@@ -531,10 +529,11 @@
   rewriter.setInsertionPoint(op);
 
   assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
-  assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
+  assert(transferReadSupportsMMAMatrixType(op) &&
+         "expected convertible operation");
 
   std::optional<int64_t> stride =
-      getMemrefConstantHorizontalStride(op.getShapedType());
+      getStaticallyKnownRowStride(op.getShapedType());
   if (!stride.has_value()) {
     LLVM_DEBUG(DBGS() << "no stride\n");
     return rewriter.notifyMatchFailure(op, "no stride");
@@ -585,7 +584,7 @@
   assert(transferWriteSupportsMMAMatrixType(op));
 
   std::optional<int64_t> stride =
-      getMemrefConstantHorizontalStride(op.getShapedType());
+      getStaticallyKnownRowStride(op.getShapedType());
   if (!stride.has_value()) {
     LLVM_DEBUG(DBGS() << "no stride\n");
     return rewriter.notifyMatchFailure(op, "no stride");
@@ -1289,7 +1288,8 @@
               return op->emitError() << "unhandled vector to mma type: " << *op;
             })
             .failed()) {
-      return op->emitError() << "Failed to convert op " << *op;
+      return op->emitOpError()
+             << "failed to convert op during vector-to-nvgpu conversion";
     }
   }
   return success();
@@ -1312,10 +1312,11 @@
       return signalPassFailure();
 
     IRRewriter rewriter(&getContext());
-    if (useNvGpu.getValue()) {
+    if (useNvGpu) {
       if (failed(
              convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
         return signalPassFailure();
+      return;
     }
     (void)convertVectorToMMAOps(rewriter, getOperation());
   }
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -272,3 +272,54 @@
 
   return failure();
 }
+
+bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
+  if (op.getMask() || op.hasOutOfBoundsDim())
+    return false;
+  VectorType type = op.getType();
+  // The result type should be 2D. Note that it is possible to expand support
+  // so that we are robust to extra unit dimensions that failed to fold, but
+  // that would significantly increase downstream code complexity in the
+  // conversion step. For now, we rely on other patterns to ensure canonical
+  // 2D form is used when targeting the `nvgpu.mma.sync` lowering path.
+  if (!type.hasStaticShape() || type.getRank() != 2)
+    return false;
+
+  // Currently we can't support reads on tensor types because we need stride
+  // information to ensure correctness of downstream assumptions. It is
+  // possible to enable this if the caller can assert that the tensor will be
+  // lowered in a particular manner.
+  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
+  if (!sourceType)
+    return false;
+
+  // Check that the last dimension of the read is contiguous. Note that it is
+  // possible to expand support for this by scalarizing all the loads during
+  // conversion.
+  auto [strides, offset] = mlir::getStridesAndOffset(sourceType);
+  return strides.back() == 1;
+}
+
+bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
+  if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)
+    return false;
+  VectorType type = op.getVectorType();
+  if (!type.hasStaticShape() || type.getRank() != 2)
+    return false;
+  // TODO: Currently we rely on lowering to a `vector.store` operation. We
+  // could support the transposed write case by lowering to scalarized
+  // `memref.store` operations.
+  if (!op.getPermutationMap().isMinorIdentity())
+    return false;
+  // Currently we can't support writes to tensor types because we need stride
+  // information to ensure correctness of downstream assumptions.
+  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
+  if (!sourceType)
+    return false;
+
+  // Check that the last dimension of the target memref is contiguous. Note
+  // that it is possible to expand support for this by scalarizing all the
+  // stores during conversion.
+  auto [strides, offset] = mlir::getStridesAndOffset(sourceType);
+  return strides.back() == 1;
+}
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir
@@ -47,7 +47,7 @@
   // CHECK: nvgpu.ldmatrix %arg0[[[m_coord]], [[k_coord]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi8, #gpu.address_space<workgroup>> -> vector<4x4xi8>
 
   // Verify that the operandB load is lowered to scalar load to be able
-  // to transpose at 8-bit granularity. ldmatrix can only transpose at 
+  // to transpose at 8-bit granularity. ldmatrix can only transpose at
   // 16-bit granularity.
 
   // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB0_map]]()[{{%.+}}]
@@ -282,7 +282,7 @@
   // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$strided_map]]
   // CHECK-DAG: [[fragmentB:%.+]] = nvgpu.ldmatrix %arg1[[[c0]], [[c0]], [[k_coord]], [[n_coord]]] {numTiles = 4 : i32, transpose = true}
   %B = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = #map_b} : memref<4x1x32x32xf16, #gpu.address_space<workgroup>>, vector<16x16xf16>
-  
+
   // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_map]]
   // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]]
   // CHECK-DAG: [[fragmentC:%.*]] = nvgpu.ldmatrix %arg2[[[c0]], [[m_coord]], [[n_coord]]] {numTiles = 4 : i32, transpose = false}
@@ -713,3 +713,125 @@
   vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xi32>, memref<128x128xi32>
   return
 }
+
+// -----
+
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+!smem_type = memref<20x20xf16, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
+
+// This test case is identical to the m16n8k16 test case, but it tests that a
+// row dimension with an unknown stride is handled correctly.
+
+// CHECK-DAG: [[$strided_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
+// CHECK-DAG: [[$contiguous_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
+// CHECK-LABEL: func @strided_memref_read_write
+func.func @strided_memref_read_write(%arg0: !smem_type,
+                                     %arg1: !smem_type,
+                                     %arg2: !smem_type) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+
+  // CHECK-DAG: [[m_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK: nvgpu.ldmatrix %arg0[[[m_coord]], [[k_coord]]] {numTiles = 4 : i32, transpose = false}
+  // CHECK-DAG: [[n_coord:%.+]] = affine.apply [[$contiguous_map]]
+  // CHECK-DAG: [[k_coord:%.+]] = affine.apply [[$strided_map]]
+  // CHECK: nvgpu.ldmatrix %arg1[[[k_coord]], [[n_coord]]] {numTiles = 2 : i32, transpose = true}
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : !smem_type, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : !smem_type, vector<8x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : !smem_type, vector<16x8xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+    %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, !smem_type
+  return
+}
+
+// -----
+
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d0, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+!smem_type = memref<20x20x20xf16, strided<[?, ?, 1], offset: ?>, #gpu.address_space<workgroup>>
+
+// CHECK-LABEL: func @unsupported_non_2d_load_store
+func.func @unsupported_non_2d_load_store(%arg0: !smem_type,
+                                         %arg1: !smem_type,
+                                         %arg2: !smem_type) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+
+  // CHECK-NOT: nvgpu.ldmatrix
+  // CHECK-NOT: nvgpu.mma
+  %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : !smem_type, vector<1x16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true, true]} : !smem_type, vector<8x1x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : !smem_type, vector<1x16x8xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+    %A, %B, %C : vector<1x16x16xf16>, vector<8x1x16xf16> into vector<1x16x8xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x16x8xf16>, !smem_type
+  return
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+!smem_type = memref<20x20xf16, strided<[?, ?], offset: ?>, #gpu.address_space<workgroup>>
+
+// CHECK-LABEL: func @unsupported_fully_dynamic_strides
+func.func @unsupported_fully_dynamic_strides(%arg0: !smem_type,
+                                             %arg1: !smem_type,
+                                             %arg2: !smem_type) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+
+  // CHECK-NOT: nvgpu.ldmatrix
+  // CHECK-NOT: nvgpu.mma
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : !smem_type, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : !smem_type, vector<8x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : !smem_type, vector<16x8xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+    %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, !smem_type
+  return
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+
+!smem_type = memref<20x20xf16, strided<[?, 1], offset: ?>, #gpu.address_space<workgroup>>
+
+// CHECK-LABEL: func @unsupported_transposed_store
+func.func @unsupported_transposed_store(%arg0: !smem_type,
+                                        %arg1: !smem_type,
+                                        %arg2: !smem_type) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+
+  // CHECK-NOT: nvgpu.ldmatrix
+  // CHECK-NOT: nvgpu.mma
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : !smem_type, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : !smem_type, vector<8x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : !smem_type, vector<16x8xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+    %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1)->(d1, d0)>} : vector<16x8xf16>, !smem_type
+  return
+}
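Reviewer note (not part of the patch): the sketch below illustrates how downstream code might call the new `nvgpu::canLowerToWarpMatrixOperation` predicates declared in MMAUtils.h to pre-filter transfer ops before attempting the `nvgpu.mma.sync` path. The helper name `collectWarpCompatibleTransfers` and the surrounding walk are hypothetical; within this patch the predicates are only invoked from `supportsMMaMatrixType` in VectorToGPU.cpp.

  // Hypothetical usage sketch; mirrors the useNvGpu gating added to
  // supportsMMaMatrixType above.
  #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
  #include "mlir/Dialect/Vector/IR/VectorOps.h"

  using namespace mlir;

  // Collect the vector transfer ops that the warp-level (mma.sync) lowering
  // can handle, so a driver can decide whether a whole contraction chain is
  // convertible before rewriting anything.
  static SmallVector<Operation *>
  collectWarpCompatibleTransfers(Operation *root) {
    SmallVector<Operation *> compatible;
    root->walk([&](Operation *op) {
      if (auto read = dyn_cast<vector::TransferReadOp>(op)) {
        if (nvgpu::canLowerToWarpMatrixOperation(read))
          compatible.push_back(op);
        return;
      }
      if (auto write = dyn_cast<vector::TransferWriteOp>(op))
        if (nvgpu::canLowerToWarpMatrixOperation(write))
          compatible.push_back(op);
    });
    return compatible;
  }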