Index: mlir/include/mlir/Dialect/GPU/IR/GPUOps.td =================================================================== --- mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1102,7 +1102,7 @@ determined using `indices`. The matrix being loaded into is the result. The `leadDimension` attribute specifies the leading dimension size of the source matrix which eventually allows the lowering to determine the size of each - row. + row. If the `transpose` attribute is present then the op does a transposed load. This op is often meant to be used along with `gpu.subgroup_mma_store_matrix` and `gpu.subgroup_mma_compute`. @@ -1117,7 +1117,8 @@ let arguments = (ins Arg<GPU_MMAMemRef, "", [MemRead]>:$srcMemref, Variadic<Index>:$indices, - IndexAttr:$leadDimension); + IndexAttr:$leadDimension, + OptionalAttr<UnitAttr>:$transpose); let results = (outs GPU_MMAMatrix:$res); Index: mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp =================================================================== --- mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -77,6 +77,10 @@ if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) return failure(); + // TODO: Support transposed mma loads. + if (subgroupMmaLoadMatrixOp.getTranspose()) + return rewriter.notifyMatchFailure(op, "transposed load unsupported"); + // Get the shape of the MMAMatrix type being returned. The shape will // choose which intrinsic this op will be lowered to. 
gpu::MMAMatrixType retType = Index: mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp =================================================================== --- mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp +++ mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp @@ -87,10 +87,12 @@ auto i32Type = rewriter.getI32Type(); auto strideValue = rewriter.create<spirv::ConstantOp>( loc, i32Type, IntegerAttr::get(i32Type, stride)); - auto coloumnMajor = rewriter.create<spirv::ConstantOp>( - loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); + bool useColMajor = + static_cast<bool>(subgroupMmaLoadMatrixOp.getTranspose()); + auto columnMajor = rewriter.create<spirv::ConstantOp>( + loc, rewriter.getI1Type(), rewriter.getBoolAttr(useColMajor)); rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixLoadOp>( subgroupMmaLoadMatrixOp, coopType, bufferPtr, strideValue, columnMajor, spirv::MemoryAccessAttr()); return success(); } Index: mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp =================================================================== --- mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -92,6 +92,19 @@ return true; } +// Return true if the given map represents a transposed matrix load, +// i.e. (d0, d1, ...) -> (dn-1, dn-2). +static bool isTransposeMatrixLoadMap(OpBuilder &b, AffineMap permutationMap) { + unsigned nDim = permutationMap.getNumDims(); + if (nDim < 2) + return false; + + AffineExpr innerDim = b.getAffineDimExpr(nDim - 1); + AffineExpr outerDim = b.getAffineDimExpr(nDim - 2); + return permutationMap == + AffineMap::get(nDim, 0, {innerDim, outerDim}, b.getContext()); +} + // Return the stide for the dimension 0 of |type| if it is a memref and has a // constant stride. static llvm::Optional<int64_t> @@ -129,9 +142,9 @@ readOp.getContext()); if (!useNvGpu) { - // TODO: Support transpose once it is added to GPU dialect ops. - // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1). 
- return map.isMinorIdentity() || map == broadcastInnerDim; + // Supported maps: minor identity, inner-dim broadcast, and transpose. + return map.isMinorIdentity() || map == broadcastInnerDim || + isTransposeMatrixLoadMap(b, map); } return true; @@ -445,9 +458,10 @@ gpu::MMAMatrixType::get(op.getVectorType().getShape(), op.getVectorType().getElementType(), fragType); OpBuilder b(op); + bool isTranspose = isTransposeMatrixLoadMap(b, map); Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>( op.getLoc(), type, op.getSource(), op.getIndices(), - b.getIndexAttr(*stride)); + b.getIndexAttr(*stride), isTranspose ? b.getUnitAttr() : UnitAttr()); valueMapping[op.getResult()] = load; } Index: mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir =================================================================== --- mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir +++ mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir @@ -5,6 +5,7 @@ #map2 = affine_map<(d0, d1, d2) -> (d1, d2)> #map3 = affine_map<(d0, d1, d2) -> (d0, d1)> #map4 = affine_map<(d0) -> (d0, 0)> +#map5 = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @matmul // CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp"> @@ -170,3 +171,21 @@ vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16> return } + +// CHECK-LABEL: func @matmul_transposed +// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp"> +// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp"> +// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp"> +// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], 
%[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp"> +// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16> +func.func @matmul_transposed(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>) { + %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16> + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> + %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map5, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> + %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> + vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16> + return +}