Index: mlir/include/mlir/Conversion/Passes.h
===================================================================
--- mlir/include/mlir/Conversion/Passes.h
+++ mlir/include/mlir/Conversion/Passes.h
@@ -36,6 +36,7 @@
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
 #include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
 #include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
+#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
Index: mlir/include/mlir/Conversion/Passes.td
===================================================================
--- mlir/include/mlir/Conversion/Passes.td
+++ mlir/include/mlir/Conversion/Passes.td
@@ -503,6 +503,20 @@
   let constructor = "tosa::createTosaToStandard()";
 }
 
+//===----------------------------------------------------------------------===//
+// VectorToGPU
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToGPU : FunctionPass<"convert-vector-to-gpu"> {
+  let summary = "Lower the operations from the vector dialect into the GPU "
+                "dialect";
+  let constructor = "mlir::createConvertVectorToGPUPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect",
+    "gpu::GPUDialect"
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // VectorToSCF
 //===----------------------------------------------------------------------===//
Index: mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
@@ -0,0 +1,26 @@
+#ifndef MLIR_INCLUDE_MLIR_CONVERSION_VECTORTOGPU_VECTORTOGPU_H_
+#define MLIR_INCLUDE_MLIR_CONVERSION_VECTORTOGPU_VECTORTOGPU_H_
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class MLIRContext;
+class Pass;
+class FuncOp;
+class RewritePatternSet;
+
+/// Patterns to transform vector ops into a canonical form to convert to MMA
+/// matrix operations.
+void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns);
+
+/// Convert vector ops to MMA matrix operations. This will convert a slice of
+/// operations that can be legally converted to MMA operations. The rest of the
+/// vector operations are left untouched.
+void convertVectorToMMAOps(FuncOp funcOp);
+
+/// Pass to test conversion from vector to GPU ops.
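+///
+/// Illustrative usage (the flag name comes from the ConvertVectorToGPU entry
+/// in Passes.td; the input file name is only an example):
+///   mlir-opt input.mlir -convert-vector-to-gpu
+/// The pass first applies the prepare patterns, then convertVectorToMMAOps.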
+std::unique_ptr<Pass> createConvertVectorToGPUPass();
+
+} // namespace mlir
+
+#endif // MLIR_INCLUDE_MLIR_CONVERSION_VECTORTOGPU_VECTORTOGPU_H_
Index: mlir/lib/Conversion/CMakeLists.txt
===================================================================
--- mlir/lib/Conversion/CMakeLists.txt
+++ mlir/lib/Conversion/CMakeLists.txt
@@ -27,5 +27,6 @@
 add_subdirectory(TosaToStandard)
 add_subdirectory(VectorToROCDL)
 add_subdirectory(VectorToLLVM)
+add_subdirectory(VectorToGPU)
 add_subdirectory(VectorToSCF)
 add_subdirectory(VectorToSPIRV)
Index: mlir/lib/Conversion/VectorToGPU/CMakeLists.txt
===================================================================
--- /dev/null
+++ mlir/lib/Conversion/VectorToGPU/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_conversion_library(MLIRVectorToGPU
+  VectorToGPU.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToGPU
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRGPU
+  MLIRLLVMIR
+  MLIRMemRef
+  MLIRTransforms
+  )
Index: mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -0,0 +1,415 @@
+//===- VectorToGPU.cpp - Convert vector to GPU dialect ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of vector operations to GPU dialect ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include <type_traits>
+
+#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
+
+#include "../PassDetail.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Utils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+// Return true if the contract op can be converted to an MMA matmul.
+static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
+  if (llvm::size(contract.masks()) != 0)
+    return false;
+
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+  AffineExpr m, n, k;
+  bindDims(contract.getContext(), m, n, k);
+  auto iteratorTypes = contract.iterator_types().getValue();
+  if (!(isParallelIterator(iteratorTypes[0]) &&
+        isParallelIterator(iteratorTypes[1]) &&
+        isReductionIterator(iteratorTypes[2])))
+    return false;
+
+  // The contract needs to represent a matmul to be able to convert to an
+  // MMAMatrix matmul.
+  if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
+    return false;
+
+  // Check that the sizes match what is natively supported.
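+  // For example (illustrative), an f16 contract with lhs and rhs of type
+  // vector<16x16xf16> and an f16 accumulator yields (m, n, k) = (16, 16, 16),
+  // which is one of the accepted shapes below.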
+  VectorType lhsType = contract.lhs().getType().cast<VectorType>();
+  VectorType rhsType = contract.rhs().getType().cast<VectorType>();
+  VectorType accType = contract.acc().getType().cast<VectorType>();
+
+  std::tuple<int64_t, int64_t, int64_t> dim(lhsType.getDimSize(0),
+                                            rhsType.getDimSize(1),
+                                            lhsType.getDimSize(1));
+  if (lhsType.getElementType().isInteger(8) &&
+      rhsType.getElementType().isInteger(8) &&
+      accType.getElementType().isInteger(32) &&
+      (dim == std::make_tuple(8, 8, 32) || dim == std::make_tuple(16, 16, 32) ||
+       dim == std::make_tuple(16, 8, 32)))
+    return true;
+
+  if (lhsType.getElementType().isF16() && rhsType.getElementType().isF16() &&
+      (accType.getElementType().isF16() || accType.getElementType().isF32()) &&
+      (dim == std::make_tuple(8, 8, 16) || dim == std::make_tuple(16, 16, 16) ||
+       dim == std::make_tuple(16, 8, 16)))
+    return true;
+  return false;
+}
+
+// Return true if the transfer op can be converted to an MMA matrix load.
+static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
+  if (readOp.mask() || readOp.hasOutOfBoundsDim() ||
+      readOp.getVectorType().getRank() != 2)
+    return false;
+
+  // TODO: Support transpose once it is added to GPU dialect ops.
+  if (!readOp.permutation_map().isMinorIdentity())
+    return false;
+  return true;
+}
+
+// Return true if the transfer op can be converted to an MMA matrix store.
+static bool transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
+  if (writeOp.mask() || writeOp.hasOutOfBoundsDim() ||
+      writeOp.getVectorType().getRank() != 2)
+    return false;
+
+  // TODO: Support transpose once it is added to GPU dialect ops.
+  if (!writeOp.permutation_map().isMinorIdentity())
+    return false;
+  return true;
+}
+
+static bool supportsMMaMatrixType(Operation *op) {
+  if (isa<scf::ForOp, scf::YieldOp>(op))
+    return true;
+  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
+    return transferReadSupportsMMAMatrixType(transferRead);
+  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
+    return transferWriteSupportsMMAMatrixType(transferWrite);
+  if (auto contract = dyn_cast<vector::ContractionOp>(op))
+    return contractSupportsMMAMatrixType(contract);
+  return false;
+}
+
+// Analyze the slice of operations rooted at each contract op to figure out if
+// the whole slice can be converted to MMA operations.
+static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
+  SetVector<Operation *> opToConvert;
+  op->walk([&](Operation *op) {
+    auto contract = dyn_cast<vector::ContractionOp>(op);
+    if (contract == nullptr || opToConvert.contains(op))
+      return;
+    auto hasVectorDest = [](Operation *op) {
+      for (auto resultType : op->getResultTypes()) {
+        if (resultType.isa<VectorType>())
+          return true;
+      }
+      if (op->getNumResults() == 0)
+        return true;
+      return false;
+    };
+    auto dependentOps = getSlice(op, hasVectorDest, hasVectorDest);
+    for (auto *dependentOp : dependentOps) {
+      // If any instruction cannot use the MMA matrix type, drop the whole
+      // chain. MMA matrices are stored in an opaque type so they cannot be
+      // used by all operations.
+      if (!supportsMMaMatrixType(dependentOp))
+        return;
+    }
+    opToConvert.insert(dependentOps.begin(), dependentOps.end());
+  });
+  return opToConvert;
+}
+
+namespace {
+
+// Transform a contract into (m, k) x (k, n) x (m, n) form so that it can be
+// converted to an MMA matmul.
+struct PrepareContractToGPUMMA
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
+
+    // Set up the parallel/reduction structure in the right form.
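+    // As an illustration, the canonical form targeted here uses the maps
+    //   lhs: (m, n, k) -> (m, k), rhs: (m, n, k) -> (k, n),
+    //   acc: (m, n, k) -> (m, n);
+    // any other supported combination is rewritten into that form below by
+    // transposing and/or swapping lhs and rhs.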
+    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+    AffineExpr m, n, k;
+    bindDims(rewriter.getContext(), m, n, k);
+    static constexpr std::array<int64_t, 2> perm = {1, 0};
+    auto iteratorTypes = op.iterator_types().getValue();
+    SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
+    if (!(isParallelIterator(iteratorTypes[0]) &&
+          isParallelIterator(iteratorTypes[1]) &&
+          isReductionIterator(iteratorTypes[2])))
+      return failure();
+    //
+    // Two outer parallel, one inner reduction (matmat flavor).
+    //
+    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
+      // This is the classical row-major matmul, nothing to do.
+      return failure();
+    }
+    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
+      std::swap(rhs, lhs);
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
+      std::swap(rhs, lhs);
+      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
+      std::swap(lhs, rhs);
+      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
+      std::swap(lhs, rhs);
+    } else {
+      return failure();
+    }
+    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+        op, lhs, rhs, res,
+        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
+        op.iterator_types());
+    return success();
+  }
+};
+
+// Merge transpose op into the transfer read op. Transposes are not supported
+// on MMA types, but the MMA load can transpose the matrix when loading.
+struct CombineTransferReadOpTranspose final
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto transferReadOp = op.vector().getDefiningOp<vector::TransferReadOp>();
+    if (!transferReadOp)
+      return failure();
+    if (transferReadOp.mask() || transferReadOp.hasOutOfBoundsDim())
+      return failure();
+    SmallVector<int64_t, 2> perm;
+    op.getTransp(perm);
+    SmallVector<unsigned, 4> permU;
+    for (int64_t o : perm)
+      permU.push_back(unsigned(o));
+    AffineMap permutationMap =
+        AffineMap::getPermutationMap(permU, op.getContext());
+    AffineMap newMap =
+        permutationMap.compose(transferReadOp.permutation_map());
+    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+        op, op.getType(), transferReadOp.source(), transferReadOp.indices(),
+        newMap, transferReadOp.padding(), transferReadOp.mask(),
+        transferReadOp.in_boundsAttr());
+    return success();
+  }
+};
+
+} // namespace
+
+// MMA types have different layouts based on how they are used in matmul ops.
+// Figure out the right layout to use by looking at the transfer op uses.
+// TODO: Change the GPU dialect to abstract the layout at this level and
+// only care about it during lowering to NVVM.
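+//
+// For example (illustrative): a vector.transfer_read feeding the lhs of a
+// vector.contract is tagged "AOp", one feeding the rhs "BOp", and one feeding
+// the accumulator "COp" (which is also the default when no contract user is
+// found).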
+static MMAFragType inferFragType(vector::TransferReadOp op) {
+  for (Operation *user : op->getUsers()) {
+    auto contract = dyn_cast<vector::ContractionOp>(user);
+    if (!contract)
+      continue;
+    if (contract.lhs() == op.getResult())
+      return MMAFragType::AOp;
+    if (contract.rhs() == op.getResult())
+      return MMAFragType::BOp;
+    if (contract.acc() == op.getResult())
+      return MMAFragType::COp;
+  }
+  return MMAFragType::COp;
+}
+
+static void convertTransferReadOp(vector::TransferReadOp op,
+                                  llvm::DenseMap<Value, Value> &valueMapping) {
+  assert(transferReadSupportsMMAMatrixType(op));
+  auto memrefType = op.getShapedType().cast<MemRefType>();
+  int64_t offset = 0;
+  SmallVector<int64_t, 2> strides;
+  if (failed(getStridesAndOffset(memrefType, strides, offset)))
+    return;
+  int64_t stride = strides[0];
+  MMAFragType fragType = inferFragType(op);
+  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
+      op.getVectorType().getShape()[0], op.getVectorType().getShape()[1],
+      op.getVectorType().getElementType(), fragType);
+  OpBuilder b(op);
+  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
+      op.getLoc(), type, op.source(), op.indices(), b.getIndexAttr(stride));
+  valueMapping[op.getResult()] = load;
+}
+
+static void convertTransferWriteOp(vector::TransferWriteOp op,
+                                   llvm::DenseMap<Value, Value> &valueMapping) {
+  assert(transferWriteSupportsMMAMatrixType(op));
+  auto memrefType = op.getShapedType().cast<MemRefType>();
+  int64_t offset = 0;
+  SmallVector<int64_t, 2> strides;
+  if (failed(getStridesAndOffset(memrefType, strides, offset)))
+    return;
+  int64_t stride = strides[0];
+  OpBuilder b(op);
+  Value matrix = valueMapping.find(op.vector())->second;
+  b.create<gpu::SubgroupMmaStoreMatrixOp>(
+      op.getLoc(), matrix, op.source(), op.indices(), b.getIndexAttr(stride));
+  op.erase();
+}
+
+static void convertContractOp(vector::ContractionOp op,
+                              llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder b(op);
+  Value A = valueMapping.find(op.lhs())->second;
+  Value B = valueMapping.find(op.rhs())->second;
+  Value C = valueMapping.find(op.acc())->second;
+  Value matmul =
+      b.create<gpu::SubgroupMmaComputeOp>(op.getLoc(), C.getType(), A, B, C);
+  valueMapping[op.getResult()] = matmul;
+}
+
+// Replace the ForOp with a new one that has extra operands. The loop body
+// region stays unchanged and the yield op needs to be updated separately to
+// match the new ForOp signature.
+static scf::ForOp replaceForWithNewOperand(OpBuilder &b, scf::ForOp loop,
+                                           ValueRange newIterOperands) {
+  // Create a new loop before the existing one, with the extra operands.
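+  // Illustrative sketch: a loop carrying iter_args(%v : vector<16x16xf16>)
+  // that is given one extra !gpu.mma_matrix<16x16xf16, "COp"> operand is
+  // rebuilt to carry both; the old vector operand is expected to become dead
+  // and to be cleaned up later (e.g. by canonicalization).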
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(loop);
+  auto operands = llvm::to_vector<4>(loop.getIterOperands());
+  operands.append(newIterOperands.begin(), newIterOperands.end());
+  scf::ForOp newLoop =
+      b.create<scf::ForOp>(loop.getLoc(), loop.lowerBound(), loop.upperBound(),
+                           loop.step(), operands);
+  newLoop.getBody()->erase();
+  newLoop.getLoopBody().getBlocks().splice(
+      newLoop.getLoopBody().getBlocks().begin(),
+      loop.getLoopBody().getBlocks());
+  for (auto operand : newIterOperands)
+    newLoop.getBody()->addArgument(operand.getType());
+
+  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
+                                                  loop.getNumResults())))
+    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+  loop.erase();
+  return newLoop;
+}
+
+static void convertForOp(scf::ForOp op,
+                         llvm::DenseMap<Value, Value> &valueMapping) {
+  SmallVector<Value> newOperands;
+  SmallVector<std::pair<size_t, size_t>> argMapping;
+  for (auto operand : llvm::enumerate(op.getIterOperands())) {
+    auto it = valueMapping.find(operand.value());
+    if (it == valueMapping.end())
+      continue;
+    argMapping.push_back(std::make_pair(
+        operand.index(), op.getNumIterOperands() + newOperands.size()));
+    newOperands.push_back(it->second);
+  }
+  OpBuilder b(op);
+  // Add new operands for the new MMA matrix values and leave the previous
+  // operands behind as dead code.
+  auto newForOp = replaceForWithNewOperand(b, op, newOperands);
+  Block &loopBody = *newForOp.getBody();
+  for (auto mapping : argMapping) {
+    valueMapping[newForOp.getResult(mapping.first)] =
+        newForOp.getResult(mapping.second);
+    valueMapping[loopBody.getArgument(mapping.first +
+                                      newForOp.getNumInductionVars())] =
+        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
+  }
+}
+
+static void convertYieldOp(scf::YieldOp op,
+                           llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder b(op);
+  auto loop = cast<scf::ForOp>(op->getParentOp());
+  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
+  for (auto operand : llvm::enumerate(op.getOperands())) {
+    auto it = valueMapping.find(operand.value());
+    if (it == valueMapping.end())
+      continue;
+    // Replace the yield of the old value with the for op argument to make it
+    // easier to remove the dead code.
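+    // Illustrative sketch: for `scf.yield %vec`, the vector operand is
+    // replaced by the loop's original init value (so it is trivially dead),
+    // and the mapped MMA matrix value is appended as an extra yield operand.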
+    yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
+    yieldOperands.push_back(it->second);
+  }
+  b.create<scf::YieldOp>(op.getLoc(), yieldOperands);
+  op.erase();
+}
+
+namespace mlir {
+
+void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
+  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
+      patterns.getContext());
+}
+
+void convertVectorToMMAOps(FuncOp funcOp) {
+  SetVector<Operation *> ops = getOpToConvert(funcOp);
+  llvm::DenseMap<Value, Value> valueMapping;
+  for (Operation *op : ops) {
+    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
+      convertTransferReadOp(transferRead, valueMapping);
+    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
+      convertTransferWriteOp(transferWrite, valueMapping);
+    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
+      convertContractOp(contractOp, valueMapping);
+    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
+      convertForOp(forOp, valueMapping);
+    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
+      convertYieldOp(yieldOp, valueMapping);
+    }
+  }
+}
+
+} // namespace mlir
+
+namespace {
+
+struct ConvertVectorToGPUPass
+    : public ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
+  void runOnFunction() override {
+    RewritePatternSet patterns(getFunction().getContext());
+    populatePrepareVectorToMMAPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+
+    convertVectorToMMAOps(getFunction());
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass() {
+  return std::make_unique<ConvertVectorToGPUPass>();
+}
Index: mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
===================================================================
--- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1067,10 +1067,12 @@
                      "destination memorySpace of kGenericMemorySpace, "
                      "kGlobalMemorySpace or kSharedMemorySpace only allowed");
 
-  if (srcMatrixType.getMMAFragType() != MMAFragType::DOp)
-    return op.emitError(
-        "expected the operand matrix being stored to have 'DOp' operand type");
-
+  // Allow COp matrices to be stored as well. This assumes COp and DOp are
+  // laid out the same way; without that we cannot chain MMA matmul ops.
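+  // For example (illustrative), this accepts the store emitted by the
+  // vector-to-GPU conversion:
+  //   gpu.subgroup_mma_store_matrix %m, %buf[%i, %j] {leadDimension = 16 : index}
+  //       : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>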
+  if (srcMatrixType.getMMAFragType() != MMAFragType::COp &&
+      srcMatrixType.getMMAFragType() != MMAFragType::DOp)
+    return op.emitError("expected the operand matrix being stored to have "
+                        "'COp' or 'DOp' operand type");
   return success();
 }
Index: mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
===================================================================
--- /dev/null
+++ mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt %s -convert-vector-to-gpu -canonicalize -split-input-file | FileCheck %s
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @matmul
+// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+func @matmul(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>) {
+  %cst_0 = constant dense<0.000000e+00> : vector<16x16xf16>
+  %c0 = constant 0 : index
+  %cst = constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  return
+}
+
+
+// CHECK-LABEL: func @matmul_loop
+// CHECK: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[ACC:.+]] = scf.for {{.*}} iter_args(%[[ACC1:.+]] = %[[C]]) -> (!gpu.mma_matrix<16x16xf16, "COp">) {
+// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[ACC1]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: scf.yield %[[D]] : !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[ACC]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 128 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf16>
+func @matmul_loop(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<128x128xf16>) {
+  %c0 = constant 0 : index
+  %c128 = constant 128 : index
+  %c32 = constant 32 : index
+  %cst = constant 0.000000e+00 : f16
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<128x128xf16>, vector<16x16xf16>
+  %14 = scf.for %arg17 = %c0 to %c128 step %c32 iter_args(%arg18 = %C) -> (vector<16x16xf16>) {
+    %17 = vector.transfer_read %arg0[%c0, %arg17], %cst {in_bounds = [true, true]} : memref<128x128xf16>, vector<16x16xf16>
+    %18 = vector.transfer_read %arg1[%arg17, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<128x128xf16>, vector<16x16xf16>
+    %19 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %17, %18, %arg18 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+    scf.yield %19 : vector<16x16xf16>
+  }
+  vector.transfer_write %14, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<128x128xf16>
+  return
+}
Index: mlir/test/Dialect/GPU/invalid.mlir
===================================================================
--- mlir/test/Dialect/GPU/invalid.mlir
+++ mlir/test/Dialect/GPU/invalid.mlir
@@ -551,7 +551,7 @@
   %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
   %i = constant 16 : index
   %j = constant 16 : index
-  // expected-error @+1 {{expected the operand matrix being stored to have 'DOp' operand type}}
+  // expected-error @+1 {{expected the operand matrix being stored to have 'COp' or 'DOp' operand type}}
   gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "AOp">, memref<32x32xf16, 3>
   return
 }