diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -1022,4 +1022,49 @@ let verifier = [{ return ::verify(*this); }]; } +def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix", + [NoSideEffect, + TypesMatchWith<"value type matches element type of mma_matrix", + "res", "value", + "$_self.cast<gpu::MMAMatrixType>().getElementType()">]>{ + + let summary = "GPU warp synchronous constant matrix"; + + let description = [{ + The `gpu.subgroup_mma_constant_matrix` creates a `!gpu.mma_matrix` with + constant elements. + + The operation takes a scalar input and returns a `!gpu.mma_matrix` where each + element is equal to the operand constant. The destination mma_matrix type + must have element type equal to the constant type. Since the layout of + `!gpu.mma_matrix` is opaque this only supports setting all the elements to + the same value. + + This op is meant to be used along with `gpu.subgroup_mma_compute`. + + Example: + + ```mlir + %0 = gpu.subgroup_mma_constant_matrix %a : + !gpu.mma_matrix<16x16xf16, "AOp"> + %1 = gpu.subgroup_mma_constant_matrix %b : + !gpu.mma_matrix<16x16xf32, "COp"> + ``` + }]; + + let arguments = (ins AnyTypeOf<[F16, F32]>:$value); + + let results = (outs GPU_MMAMatrix:$res); + + let extraClassDeclaration = [{ + gpu::MMAMatrixType getType() { + return res().getType().cast<gpu::MMAMatrixType>(); + } + }]; + + let assemblyFormat = [{ + $value attr-dict `:` type($res) + }]; +} + #endif // GPU_OPS diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -348,12 +348,52 @@ } }; +/// Convert GPU MMA ConstantMatrixOp to a chain of InsertValueOp. 
+struct WmmaConstantOpToNVVMLowering + : public ConvertOpToLLVMPattern<gpu::SubgroupMmaConstantMatrixOp> { + using ConvertOpToLLVMPattern< + gpu::SubgroupMmaConstantMatrixOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp, + ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(), operands, + rewriter))) + return failure(); + Location loc = subgroupMmaConstantOp.getLoc(); + Value cst = operands[0]; + LLVM::LLVMStructType type = convertMMAToLLVMType( + subgroupMmaConstantOp.getType().cast<gpu::MMAMatrixType>()); + // If the element type is a vector create a vector from the operand. + if (auto vecType = type.getBody()[0].dyn_cast<VectorType>()) { + Value vecCst = rewriter.create<LLVM::UndefOp>(loc, vecType); + for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) { + Value idx = rewriter.create<LLVM::ConstantOp>( + loc, typeConverter->convertType(rewriter.getIntegerType(32)), + rewriter.getI32ArrayAttr(vecEl)); + vecCst = rewriter.create<LLVM::InsertElementOp>(loc, vecType, vecCst, + cst, idx); + } + cst = vecCst; + } + Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, type); + for (size_t i : llvm::seq(size_t(0), type.getBody().size())) { + matrixStruct = rewriter.create<LLVM::InsertValueOp>( + loc, matrixStruct, cst, rewriter.getI32ArrayAttr(i)); + } + rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct); + return success(); + } +}; + } // anonymous namespace namespace mlir { void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { patterns.insert<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering, - WmmaStoreOpToNVVMLowering>(converter); + WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering>( + converter); } } // namespace mlir diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir --- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir @@ -151,3 +151,28 @@ return } } + + +// ----- + +gpu.module @test_module { + +// CHECK-LABEL: func 
@gpu_wmma_constant_op +// CHECK: %[[CST:.+]] = llvm.mlir.constant(1.000000e+00 : f16) : f16 +// CHECK: %[[V0:.+]] = llvm.mlir.undef : vector<2xf16> +// CHECK: %[[C0:.+]] = llvm.mlir.constant([0 : i32]) : i32 +// CHECK: %[[V1:.+]] = llvm.insertelement %[[CST]], %[[V0]][%[[C0]] : i32] : vector<2xf16> +// CHECK: %[[C1:.+]] = llvm.mlir.constant([1 : i32]) : i32 +// CHECK: %[[V2:.+]] = llvm.insertelement %[[CST]], %[[V1]][%[[C1]] : i32] : vector<2xf16> +// CHECK: %[[M0:.+]] = llvm.mlir.undef : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: %[[M1:.+]] = llvm.insertvalue %[[V2]], %[[M0]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: %[[M2:.+]] = llvm.insertvalue %[[V2]], %[[M1]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: %[[M3:.+]] = llvm.insertvalue %[[V2]], %[[M2]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: %[[M4:.+]] = llvm.insertvalue %[[V2]], %[[M3]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: llvm.return %[[M4]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + func @gpu_wmma_constant_op() ->(!gpu.mma_matrix<16x16xf16, "COp">) { + %cst = constant 1.0 : f16 + %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp"> + return %C : !gpu.mma_matrix<16x16xf16, "COp"> + } +} diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir --- a/mlir/test/Dialect/GPU/ops.mlir +++ b/mlir/test/Dialect/GPU/ops.mlir @@ -201,8 +201,12 @@ // CHECK: %[[wg:.*]] = memref.alloca() %i = constant 16 : index // CHECK: %[[i:.*]] = constant 16 : index + %cst = constant 1.000000e+00 : f32 + // CHECK: %[[cst:.*]] = constant 1.000000e+00 : f32 %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp"> // 
CHECK: gpu.subgroup_mma_load_matrix %[[wg]][%[[i]], %[[i]]] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp"> + %1 = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf32, "COp"> + // CHECK: gpu.subgroup_mma_constant_matrix %[[cst]] : !gpu.mma_matrix<16x16xf32, "COp"> return } }