Index: mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
===================================================================
--- mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -17,6 +17,7 @@
 #include "../PassDetail.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
@@ -122,6 +123,8 @@
 }
 
 static bool supportsMMaMatrixType(Operation *op) {
+  if (isa<scf::ForOp, scf::YieldOp>(op))
+    return true;
   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
     return transferReadSupportsMMAMatrixType(transferRead);
   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
@@ -324,6 +327,74 @@
   valueMapping[op.getResult()] = matrix;
 }
 
+// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
+// updated and needs to be updated separately for the loop to be correct.
+static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
+                                               ValueRange newIterOperands) {
+  // Create a new loop before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(loop);
+  auto operands = llvm::to_vector<4>(loop.getIterOperands());
+  operands.append(newIterOperands.begin(), newIterOperands.end());
+  scf::ForOp newLoop =
+      b.create<scf::ForOp>(loop.getLoc(), loop.lowerBound(),
+                           loop.upperBound(), loop.step(), operands);
+  newLoop.getBody()->erase();
+  newLoop.getLoopBody().getBlocks().splice(
+      newLoop.getLoopBody().getBlocks().begin(),
+      loop.getLoopBody().getBlocks());
+  for (auto operand : newIterOperands)
+    newLoop.getBody()->addArgument(operand.getType());
+
+  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
+                                                  loop.getNumResults())))
+    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+  loop.erase();
+  return newLoop;
+}
+
+static void convertForOp(scf::ForOp op,
+                         llvm::DenseMap<Value, Value> &valueMapping) {
+  SmallVector<Value, 4> newOperands;
+  SmallVector<std::pair<size_t, size_t>> argMapping;
+  for (auto operand : llvm::enumerate(op.getIterOperands())) {
+    auto it = valueMapping.find(operand.value());
+    if (it == valueMapping.end())
+      continue;
+    argMapping.push_back(std::make_pair(
+        operand.index(), op.getNumIterOperands() + newOperands.size()));
+    newOperands.push_back(it->second);
+  }
+  OpBuilder b(op);
+  scf::ForOp newForOp = replaceForOpWithNewSignature(b, op, newOperands);
+  Block &loopBody = *newForOp.getBody();
+  for (auto mapping : argMapping) {
+    valueMapping[newForOp.getResult(mapping.first)] =
+        newForOp.getResult(mapping.second);
+    valueMapping[loopBody.getArgument(mapping.first +
                                      newForOp.getNumInductionVars())] =
+        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
+  }
+}
+
+static void convertYieldOp(scf::YieldOp op,
+                           llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder b(op);
+  auto loop = cast<scf::ForOp>(op->getParentOp());
+  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
+  for (auto operand : llvm::enumerate(op.getOperands())) {
+    auto it = valueMapping.find(operand.value());
+    if (it == valueMapping.end())
+      continue;
+    // Replace the yield of the old value with the for op argument to make it
+    // easier to remove the dead code.
+    yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
+    yieldOperands.push_back(it->second);
+  }
+  b.create<scf::YieldOp>(op.getLoc(), yieldOperands);
+  op.erase();
+}
+
 namespace mlir {
 void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
@@ -343,6 +414,10 @@
       convertContractOp(contractOp, valueMapping);
     } else if (auto constantOp = dyn_cast<ConstantOp>(op)) {
       convertConstantOp(constantOp, valueMapping);
+    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
+      convertForOp(forOp, valueMapping);
+    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
+      convertYieldOp(yieldOp, valueMapping);
     }
   }
 }
Index: mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
===================================================================
--- mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -41,9 +41,15 @@
   return
 }
 
-// Negative test until scf.for support is added.
 // CHECK-LABEL: func @matmul_loop
-//       CHECK: vector.contract
+//       CHECK: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK: %[[ACC:.+]] = scf.for {{.*}} iter_args(%[[ACC1:.+]] = %[[C]]) -> (!gpu.mma_matrix<16x16xf16, "COp">) {
+//   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+//   CHECK-DAG:   %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+//  CHECK-NEXT:   %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[ACC1]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+//  CHECK-NEXT:   scf.yield %[[D]] : !gpu.mma_matrix<16x16xf16, "COp">
+//  CHECK-NEXT: }
+//  CHECK-NEXT: gpu.subgroup_mma_store_matrix %[[ACC]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 128 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf16>
 func @matmul_loop(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<128x128xf16>) {
   %c0 = constant 0 : index
   %c128 = constant 128 : index
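Note: the test hunk above truncates the body of @matmul_loop after its first two constants. For reference, below is a minimal sketch of the kind of input IR the new scf.for/scf.yield handling is exercised on, inferred from the CHECK lines; the step constant, value names, and contract indexing maps are assumptions for illustration and may differ from the checked-in test.

// Sketch of an input for @matmul_loop: a 16x16x16 f16 matmul whose
// accumulator is carried across iterations of an scf.for. Loop bounds and
// maps are inferred from the CHECK lines, not copied from the actual test.
func @matmul_loop(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>,
                  %arg2: memref<128x128xf16>) {
  %c0 = constant 0 : index
  %c128 = constant 128 : index
  %c32 = constant 32 : index
  %cst = constant 0.000000e+00 : f16
  // The accumulator is read once, threaded through the loop as an iter_arg
  // (the operand convertForOp maps to a "COp" mma_matrix), and written back
  // after the loop.
  %init = vector.transfer_read %arg2[%c0, %c0], %cst
      : memref<128x128xf16>, vector<16x16xf16>
  %res = scf.for %i = %c0 to %c128 step %c32
      iter_args(%acc = %init) -> (vector<16x16xf16>) {
    %a = vector.transfer_read %arg0[%c0, %i], %cst
        : memref<128x128xf16>, vector<16x16xf16>
    %b = vector.transfer_read %arg1[%i, %c0], %cst
        : memref<128x128xf16>, vector<16x16xf16>
    %d = vector.contract {
        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                         affine_map<(d0, d1, d2) -> (d2, d1)>,
                         affine_map<(d0, d1, d2) -> (d0, d1)>],
        iterator_types = ["parallel", "parallel", "reduction"]}
        %a, %b, %acc
      : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    // convertYieldOp rewrites this yield: the mapped mma_matrix value is
    // appended as a new yield operand, and the old operand is replaced by
    // the loop's own iter operand so the dead vector ops become trivially
    // removable.
    scf.yield %d : vector<16x16xf16>
  }
  vector.transfer_write %res, %arg2[%c0, %c0]
      : vector<16x16xf16>, memref<128x128xf16>
  return
}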