diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1472,6 +1472,7 @@
 def GPU_ElementwiseOpDivU : I32EnumAttrCase<"DIVU", 10, "divu">;
 def GPU_ElementwiseOpNEGF : I32EnumAttrCase<"NEGATEF", 11, "negatef">;
 def GPU_ElementwiseOpNEGS : I32EnumAttrCase<"NEGATES", 12, "negates">;
+def GPU_ElementwiseOpEXTF : I32EnumAttrCase<"EXTF", 13, "extf">;
 
 def MMAElementWise : I32EnumAttr<"MMAElementwiseOp",
   "elementwise operation to apply to mma matrix", [
@@ -1487,7 +1488,8 @@
     GPU_ElementwiseOpDivS,
     GPU_ElementwiseOpDivU,
     GPU_ElementwiseOpNEGF,
-    GPU_ElementwiseOpNEGS
+    GPU_ElementwiseOpNEGS,
+    GPU_ElementwiseOpEXTF
   ]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::gpu";
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -61,6 +61,9 @@
   case gpu::MMAElementwiseOp::NEGATES:
     builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
     return true;
+  case gpu::MMAElementwiseOp::EXTF:
+    builder.replaceOpWithNewOp<spirv::FConvertOp>(op, coopType, operands);
+    return true;
   default:
     break;
   }
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -214,6 +214,13 @@
   });
 }
 
+// Return true if the given ext op is only used by vector transfer write ops.
+static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) {
+  return llvm::all_of(extOp->getUsers(), [](Operation *op) {
+    return isa<vector::TransferWriteOp>(op);
+  });
+}
+
 /// Return the MMA elementwise enum associated with `op` if it is supported.
 /// Return `std::nullopt` otherwise.
 static std::optional<gpu::MMAElementwiseOp>
@@ -242,6 +249,8 @@
     return gpu::MMAElementwiseOp::DIVU;
   if (isa<arith::NegFOp>(op))
     return gpu::MMAElementwiseOp::NEGATEF;
+  if (isa<arith::ExtFOp>(op))
+    return gpu::MMAElementwiseOp::EXTF;
   return std::nullopt;
 }
 
@@ -297,6 +306,8 @@
     return integerExtendSupportsMMAMatrixType(signedExtend);
   if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
     return integerExtendSupportsMMAMatrixType(unsignedExtend);
+  if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
+    return fpExtendSupportsMMAMatrixType(fpExtend);
   return elementwiseSupportsMMAMatrixType(op);
 }
 
@@ -1203,8 +1214,17 @@
       return rewriter.notifyMatchFailure(op, "no mapping");
     matrixOperands.push_back(it->second);
   }
+  auto resultType = matrixOperands[0].getType().cast<gpu::MMAMatrixType>();
+  if (opType == gpu::MMAElementwiseOp::EXTF) {
+    // The floating point extension case has a different result type.
+    auto vectorType = op->getResultTypes()[0].cast<VectorType>();
+    resultType = gpu::MMAMatrixType::get(resultType.getShape(),
+                                         vectorType.getElementType(),
+                                         resultType.getOperand());
+  }
+
   Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
-      op->getLoc(), matrixOperands[0].getType(), matrixOperands, opType);
+      op->getLoc(), resultType, matrixOperands, opType);
   valueMapping[op->getResult(0)] = newOp;
   return success();
 }
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
@@ -142,6 +142,8 @@
       %D = gpu.subgroup_mma_elementwise negatef %C : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
       // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup>
       %E = gpu.subgroup_mma_elementwise divf %D, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.NV.coopmatrix<16x16xf16, Subgroup> to !spirv.NV.coopmatrix<16x16xf32, Subgroup>
+      %F = gpu.subgroup_mma_elementwise extf %E : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
       // CHECK: spirv.Return
       gpu.return
     }
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -437,4 +437,41 @@
   %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %Ae, %Be, %C : vector<16x32xi32>, vector<16x32xi32> into vector<16x16xi32>
   vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32>
   return
-}
\ No newline at end of file
+}
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @cast_f16_to_f32_write
+// CHECK: %[[COMPUTE:.+]] = gpu.subgroup_mma_compute
+// CHECK: %[[EXT:.+]] = gpu.subgroup_mma_elementwise extf %[[COMPUTE]] : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[EXT]]
+func.func @cast_f16_to_f32_write(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>, %arg3: memref<16x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  %cast = arith.extf %D : vector<16x16xf16> to vector<16x16xf32>
+  vector.transfer_write %cast, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32>
+  return
+}
+
+// CHECK-LABEL: func @cast_f16_to_f32_non_write
+// CHECK: arith.extf %{{.+}} : vector<16x16xf16> to vector<16x16xf32>
+func.func @cast_f16_to_f32_non_write(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>, %arg3: memref<16x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  %cast = arith.extf %D : vector<16x16xf16> to vector<16x16xf32>
+  %neg = arith.negf %cast : vector<16x16xf32>
+  vector.transfer_write %neg, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32>
+  return
+}