diff --git a/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h b/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
--- a/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
+++ b/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
@@ -12,6 +12,7 @@
 #include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
+class LogicalResult;
 class MLIRContext;
 class Pass;
 class RewritePatternSet;
@@ -29,13 +30,14 @@
 /// Convert vector ops to MMA matrix operations nested under `rootOp`. This
 /// will convert the slice of operations that can be legally converted to MMA
 /// operations. The rest of the vector operations are left untouched.
-void convertVectorToMMAOps(Operation *rootOp);
+LogicalResult convertVectorToMMAOps(RewriterBase &rewriter, Operation *rootOp);
 
 /// Convert vector ops nested under `rootOp` to vector and GPU operations
 /// compatible with the `nvvm.mma.sync` lowering path. This will convert a
 /// slice of operations that can be legally lowered on this path while the
 /// rest of the vector operations are left untouched.
-LogicalResult convertVectorToNVVMCompatibleMMASync(Operation *rootOp);
+LogicalResult convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
+                                                   Operation *rootOp);
 
 /// Convert from vector to GPU ops.
 std::unique_ptr<Pass> createConvertVectorToGPUPass(bool useNvGpu = false);
diff --git a/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h b/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
--- a/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
+++ b/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
@@ -69,7 +69,7 @@
 /// please see NVIDIA's PTX documentation:
 /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
 FailureOr<AffineMap>
-getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
+getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
                                   const WarpMatrixInfo &fragmentType);
 
 /// Encapsulates the parameters needed to lower a `nvgpu.ldmatrix` operation to
@@ -90,7 +90,7 @@
 /// to two results representing offsets within the matrix operand that should
 /// be the pointer locations a thread should pass to the ldmatrix instruction.
 FailureOr<AffineMap>
-getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
+getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                const LdMatrixParams &params);
 
 /// Transform `vector.contract` into (m,k)x(n,k)x(m,n) form so that it can be
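Note on the new entry points: both conversions now thread a caller-provided RewriterBase through every IR mutation instead of constructing local OpBuilders. A minimal sketch of driving them from custom code, mirroring the updated pass body at the end of this patch (the helper name runVectorToGpuConversion is hypothetical):

    #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
    #include "mlir/IR/PatternMatch.h"

    static void runVectorToGpuConversion(mlir::Operation *rootOp, bool useNvGpu) {
      mlir::IRRewriter rewriter(rootOp->getContext());
      if (useNvGpu) {
        // The nvgpu path emits an error and fails on unhandled ops.
        if (mlir::failed(
                mlir::convertVectorToNVVMCompatibleMMASync(rewriter, rootOp)))
          return;
      }
      // Best-effort path: ops that fail to convert are left untouched, which is
      // why the in-tree pass also discards this result.
      (void)mlir::convertVectorToMMAOps(rewriter, rootOp);
    }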
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -26,11 +26,19 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Region.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#define DEBUG_TYPE "vector-to-gpu"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTVECTORTOGPU
 #include "mlir/Conversion/Passes.h.inc"
@@ -45,7 +53,7 @@
 /// the `offsetMap` has dimension placeholders, those should be provided in
 /// `dimValues`.
 template <typename TransferOpType>
-static void getXferIndices(OpBuilder &b, TransferOpType xferOp,
+static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
                            AffineMap offsetMap, ArrayRef<Value> dimValues,
                            SmallVector<Value, 4> &indices) {
   indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
@@ -56,9 +64,9 @@
       Value prevIdx = indices[dim.getPosition()];
       SmallVector<Value, 4> dims(dimValues.begin(), dimValues.end());
       dims.push_back(prevIdx);
-      AffineExpr d0 = b.getAffineDimExpr(offsetMap.getNumDims());
+      AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
       indices[dim.getPosition()] = makeComposedAffineApply(
-          b, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
+          rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
       continue;
     }
   }
@@ -94,8 +102,10 @@
 
 // Return true if the given map represents a transposed matrix load,
 // i.e. (d0, d1, ...) -> (dn-1, dn-2).
-static bool isTransposeMatrixLoadMap(OpBuilder &b, AffineMap permutationMap) {
-  MLIRContext *ctx = b.getContext();
+static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
+  MLIRContext *ctx = permutationMap.getContext();
+  // Local OpBuilder is fine here, we just build attributes.
+  OpBuilder b(ctx);
   auto nDim = permutationMap.getNumDims();
   AffineExpr zero = b.getAffineConstantExpr(0);
   if (nDim < 2) {
@@ -147,15 +157,16 @@
     return false;
 
   AffineMap map = readOp.getPermutationMap();
-  OpBuilder b(readOp.getContext());
-  AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
-  AffineExpr zero = b.getAffineConstantExpr(0);
-  auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim},
-                                          readOp.getContext());
+
+  MLIRContext *ctx = readOp.getContext();
+  AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
+  AffineExpr zero = getAffineConstantExpr(0, ctx);
+  auto broadcastInnerDim =
+      AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
 
   if (!useNvGpu) {
     bool result = map.isMinorIdentity() || map == broadcastInnerDim ||
-                  isTransposeMatrixLoadMap(b, map);
+                  isTransposeMatrixLoadMap(map);
     return result;
   }
 
@@ -379,14 +390,13 @@
     if (!(vector::isParallelIterator(iteratorTypes[0]) &&
           vector::isParallelIterator(iteratorTypes[1]) &&
          vector::isReductionIterator(iteratorTypes[2])))
-      return failure();
+      return rewriter.notifyMatchFailure(op, "not a gemm contraction");
     //
     // Two outer parallel, one inner reduction (matmat flavor).
     //
-    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
-      // This is the classical row-major matmul, nothing to do.
-      return failure();
-    }
+    // This is the classical row-major matmul, nothing to do.
+    if (maps == infer({{m, k}, {k, n}, {m, n}}))
+      return rewriter.notifyMatchFailure(op, "contraction already prepared");
     if (maps == infer({{m, k}, {n, k}, {m, n}})) {
       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
@@ -407,7 +417,8 @@
     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
       std::swap(lhs, rhs);
     } else {
-      return failure();
+      // TODO: llvm_unreachable ?
+      return rewriter.notifyMatchFailure(op, "unexpected contraction case");
     }
     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
         op, lhs, rhs, res,
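For intuition about the contraction hunk above: the pattern compares the op's indexing maps against combinations built by a file-local `infer` helper (not shown in this patch). A sketch of the canonical (m, k) x (k, n) -> (m, n) maps that it treats as "already prepared", built directly with the public AffineMap API (the helper name canonicalMatmulMaps is ours):

    #include "mlir/IR/AffineExpr.h"
    #include "mlir/IR/AffineMap.h"
    #include "mlir/IR/MLIRContext.h"
    #include "llvm/ADT/SmallVector.h"

    // Row-major matmul indexing maps: A(m, k), B(k, n), C(m, n).
    static llvm::SmallVector<mlir::AffineMap, 3>
    canonicalMatmulMaps(mlir::MLIRContext *ctx) {
      mlir::AffineExpr m, n, k;
      mlir::bindDims(ctx, m, n, k);
      auto make = [&](llvm::ArrayRef<mlir::AffineExpr> results) {
        return mlir::AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, results,
                                    ctx);
      };
      return {make({m, k}), make({k, n}), make({m, n})};
    }

Every other supported combination is rewritten into this form with vector.transpose or operand swaps, so the GPU lowering only ever sees one flavor.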
@@ -440,14 +451,15 @@
 
   auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
   if (!transferReadOp)
-    return failure();
+    return rewriter.notifyMatchFailure(op, "no transfer read");
 
   // TODO: support 0-d corner case.
   if (transferReadOp.getTransferRank() == 0)
-    return failure();
+    return rewriter.notifyMatchFailure(op, "0-D transfer read");
 
   if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
-    return failure();
+    return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
+
   SmallVector<int64_t, 2> perm;
   op.getTransp(perm);
   SmallVector<unsigned, 2> permU;
@@ -498,17 +510,24 @@
   return "COp";
 }
 
-static void convertTransferReadOp(vector::TransferReadOp op,
-                                  llvm::DenseMap<Value, Value> &valueMapping) {
+static LogicalResult
+convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
+                      llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
   assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
+
   std::optional<int64_t> stride =
       getMemrefConstantHorizontalStride(op.getShapedType());
+  if (!stride.has_value()) {
+    LLVM_DEBUG(DBGS() << "no stride\n");
+    return rewriter.notifyMatchFailure(op, "no stride");
+  }
+
   AffineMap map = op.getPermutationMap();
-  OpBuilder b(op);
-  bool isTranspose = isTransposeMatrixLoadMap(b, map);
+  bool isTranspose = isTransposeMatrixLoadMap(map);
 
   // Handle broadcast by setting the stride to 0.
   if (auto cstExpr =
@@ -516,7 +535,7 @@
     assert(cstExpr.getValue() == 0);
     stride = 0;
   }
-  assert(stride);
+
   Value mappingResult = op.getResult();
   auto elType = op.getVectorType().getElementType();
   const char *fragType = inferFragType(op);
@@ -533,24 +552,47 @@
   }
   gpu::MMAMatrixType type =
       gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
-  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
+  Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
       op.getLoc(), type, op.getSource(), op.getIndices(),
-      b.getIndexAttr(*stride), isTranspose ? b.getUnitAttr() : UnitAttr());
+      rewriter.getIndexAttr(*stride),
+      isTranspose ? rewriter.getUnitAttr() : UnitAttr());
   valueMapping[mappingResult] = load;
+
+  LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
+  return success();
 }
 
-static void convertTransferWriteOp(vector::TransferWriteOp op,
-                                   llvm::DenseMap<Value, Value> &valueMapping) {
+static LogicalResult
+convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
+                       llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   assert(transferWriteSupportsMMAMatrixType(op));
   std::optional<int64_t> stride =
       getMemrefConstantHorizontalStride(op.getShapedType());
-  assert(stride);
-  OpBuilder b(op);
-  Value matrix = valueMapping.find(op.getVector())->second;
-  b.create<gpu::SubgroupMmaStoreMatrixOp>(
+  if (!stride.has_value()) {
+    LLVM_DEBUG(DBGS() << "no stride\n");
+    return rewriter.notifyMatchFailure(op, "no stride");
+  }
+
+  auto it = valueMapping.find(op.getVector());
+  if (it == valueMapping.end()) {
+    LLVM_DEBUG(DBGS() << "no mapping\n");
+    return rewriter.notifyMatchFailure(op, "no mapping");
+  }
+
+  Value matrix = it->second;
+  auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
       op.getLoc(), matrix, op.getSource(), op.getIndices(),
-      b.getIndexAttr(*stride), /*transpose=*/UnitAttr());
-  op.erase();
+      rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
+  (void)store;
+
+  LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
+
+  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+  rewriter.eraseOp(op);
+  return success();
+}
 
 /// Returns the vector type which represents a matrix fragment.
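Both conversions above now bail out gracefully when getMemrefConstantHorizontalStride fails, instead of asserting. That helper's implementation is not part of this patch; a sketch of how such a constant "leading dimension" stride can be derived, assuming the standard strided-layout utilities (the name constantRowStride is ours):

    #include "mlir/IR/BuiltinTypes.h"
    #include "llvm/ADT/SmallVector.h"
    #include <optional>

    // Sketch only: constant stride of the second-to-last dimension of a
    // rank >= 2 memref; std::nullopt when the layout has dynamic strides.
    static std::optional<int64_t> constantRowStride(mlir::MemRefType type) {
      llvm::SmallVector<int64_t> strides;
      int64_t offset;
      if (mlir::failed(mlir::getStridesAndOffset(type, strides, offset)) ||
          strides.size() < 2)
        return std::nullopt;
      int64_t stride = strides[strides.size() - 2];
      // Pre-LLVM-16 spellings used kDynamicStrideOrOffset instead.
      if (stride == mlir::ShapedType::kDynamic)
        return std::nullopt;
      return stride;
    }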
@@ -566,24 +608,33 @@
 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
 static LogicalResult
-convertConstantOpMmaSync(arith::ConstantOp op,
+convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
                          llvm::DenseMap<Value, Value> &valueMapping) {
-  OpBuilder b(op);
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
       nvgpu::getWarpMatrixInfo(op);
-  if (failed(warpMatrixInfo))
-    return failure();
+  if (failed(warpMatrixInfo)) {
+    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
+    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
+  }
 
   FailureOr<nvgpu::FragmentElementInfo> regInfo =
       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
-  if (failed(regInfo))
-    return failure();
+  if (failed(regInfo)) {
+    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
+    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
+  }
 
   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
   auto dense = op.getValue().dyn_cast<SplatElementsAttr>();
-  if (!dense)
-    return failure();
-  Value result = b.create<arith::ConstantOp>(
+  if (!dense) {
+    LLVM_DEBUG(DBGS() << "not a splat\n");
+    return rewriter.notifyMatchFailure(op, "not a splat");
+  }
+
+  Value result = rewriter.create<arith::ConstantOp>(
       op.getLoc(), vectorType,
       DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
   valueMapping[op.getResult()] = result;
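A recurring idiom in this patch: bare failure() and emitError() are replaced by rewriter.notifyMatchFailure, which still returns failure() but only records the message as a diagnostic remark (visible under debugging/listeners) rather than a hard error. A toy pattern showing the idiom under the same API (the pattern RejectNonSplat is hypothetical):

    #include "mlir/Dialect/Arith/IR/Arith.h"
    #include "mlir/IR/PatternMatch.h"

    // Toy example only: the string is diagnostic-only, the returned value is a
    // plain failure() as far as the driver is concerned.
    struct RejectNonSplat : public mlir::OpRewritePattern<mlir::arith::ConstantOp> {
      using OpRewritePattern::OpRewritePattern;
      mlir::LogicalResult
      matchAndRewrite(mlir::arith::ConstantOp op,
                      mlir::PatternRewriter &rewriter) const override {
        if (!op.getValue().isa<mlir::SplatElementsAttr>())
          return rewriter.notifyMatchFailure(op, "not a splat");
        return mlir::failure(); // nothing to rewrite in this sketch
      }
    };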
@@ -591,43 +642,54 @@
 }
 
 static LogicalResult
-creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder,
+creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   Location loc = op->getLoc();
 
   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
       nvgpu::getWarpMatrixInfo(op);
-  if (failed(warpMatrixInfo))
-    return failure();
+  if (failed(warpMatrixInfo)) {
+    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
+    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
+  }
 
   FailureOr<nvgpu::FragmentElementInfo> regInfo =
       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
-  if (failed(regInfo))
-    return failure();
+  if (failed(regInfo)) {
+    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
+    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
+  }
 
   FailureOr<nvgpu::LdMatrixParams> params = nvgpu::getLdMatrixParams(
       *warpMatrixInfo,
       /*transpose=*/!op.getPermutationMap().isMinorIdentity());
   if (failed(params)) {
-    return op->emitError()
-           << "failed to convert vector.transfer_read to ldmatrix; this op "
-              "likely "
-              "should not be converted to a nvgpu.ldmatrix call.";
+    LLVM_DEBUG(
+        DBGS()
+        << "failed to convert vector.transfer_read to ldmatrix. "
+        << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
+    return rewriter.notifyMatchFailure(
+        op, "failed to convert vector.transfer_read to ldmatrix; this op "
+            "likely should not be converted to a nvgpu.ldmatrix call.");
   }
 
   // Adjust the load offset.
-  auto laneId = builder.create<gpu::LaneIdOp>(loc);
+  auto laneId = rewriter.create<gpu::LaneIdOp>(loc);
   FailureOr<AffineMap> offsets =
-      nvgpu::getLaneIdToLdMatrixMatrixCoord(loc, builder, *params);
-  if (failed(offsets))
-    return failure();
+      nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
+  if (failed(offsets)) {
+    LLVM_DEBUG(DBGS() << "no offsets\n");
+    return rewriter.notifyMatchFailure(op, "no offsets");
+  }
 
   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
 
   SmallVector<Value, 4> indices;
-  getXferIndices<vector::TransferReadOp>(builder, op, *offsets, {laneId},
+  getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
                                          indices);
-  nvgpu::LdMatrixOp newOp = builder.create<nvgpu::LdMatrixOp>(
+
+  nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
       loc, vectorType, op.getSource(), indices,
       !op.getPermutationMap().isMinorIdentity(), params->numTiles);
   valueMapping[op] = newOp->getResult(0);
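For readers unfamiliar with the coordinate helpers: getLaneIdToLdMatrixMatrixCoord returns an AffineMap from the single laneId dimension to (row, col) offsets, which getXferIndices then composes onto the transfer indices. Purely illustrative sketch of such a map; the concrete expressions below are invented and are NOT the real mapping, which depends on fragment type and transposition:

    #include "mlir/IR/AffineExpr.h"
    #include "mlir/IR/AffineMap.h"

    // Invented example of a laneId -> (row, col) map of the right shape.
    static mlir::AffineMap exampleLaneIdToCoord(mlir::MLIRContext *ctx) {
      mlir::AffineExpr lane = mlir::getAffineDimExpr(0, ctx);
      return mlir::AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
                                  {lane % 16, lane.floorDiv(16) * 8}, ctx);
    }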
@@ -635,32 +697,36 @@
 }
 
 static LogicalResult
-createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
+createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                        llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   Location loc = op.getLoc();
   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
       nvgpu::getWarpMatrixInfo(op);
   if (failed(warpMatrixInfo))
-    return failure();
+    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
   FailureOr<nvgpu::FragmentElementInfo> regInfo =
       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
   if (failed(regInfo)) {
-    op->emitError() << "Failed to deduce register fragment type during "
-                       "conversion to distributed non-ldmatrix compatible load";
-    return failure();
+    return rewriter.notifyMatchFailure(
+        op, "Failed to deduce register fragment type during "
+            "conversion to distributed non-ldmatrix compatible load");
   }
 
-  Value laneId = builder.create<gpu::LaneIdOp>(loc);
+  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
   SmallVector<Value, 4> elements;
 
   // This is the individual element type.
   Type loadedElType = regInfo->registerLLVMType;
   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
 
-  Value fill = builder.create<arith::ConstantOp>(
+  Value fill = rewriter.create<arith::ConstantOp>(
       op.getLoc(), vectorType.getElementType(),
-      builder.getZeroAttr(vectorType.getElementType()));
-  Value result = builder.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
+      rewriter.getZeroAttr(vectorType.getElementType()));
+  Value result =
+      rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
 
   bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
 
@@ -673,20 +739,21 @@
     for (int i = 0; i < vectorType.getShape()[0]; i++) {
       FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
-          op.getLoc(), builder, *warpMatrixInfo);
+          rewriter, op.getLoc(), *warpMatrixInfo);
       if (failed(coords))
-        return failure();
-      Value logicalValueId = builder.create<arith::ConstantOp>(
-          loc, builder.getIndexType(),
-          builder.getIndexAttr(i * regInfo->elementsPerRegister));
+        return rewriter.notifyMatchFailure(op, "no coords");
+
+      Value logicalValueId = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIndexType(),
+          rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
       SmallVector<Value, 4> newIndices;
       getXferIndices<vector::TransferReadOp>(
-          builder, op, *coords, {laneId, logicalValueId}, newIndices);
+          rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
 
-      Value el = builder.create<vector::LoadOp>(loc, loadedElType,
-                                                op.getSource(), newIndices);
-      result = builder.create<vector::InsertOp>(loc, el, result,
-                                                builder.getI64ArrayAttr(i));
+      Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
+                                                 op.getSource(), newIndices);
+      result = rewriter.create<vector::InsertOp>(loc, el, result,
+                                                 rewriter.getI64ArrayAttr(i));
     }
   } else {
     if (auto vecType = loadedElType.dyn_cast<VectorType>()) {
@@ -696,21 +763,21 @@
       for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
            innerIdx++) {
 
-        Value logicalValueId = builder.create<arith::ConstantOp>(
-            loc, builder.getIndexType(),
-            builder.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
+        Value logicalValueId = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getIndexType(),
+            rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
         FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
-            op.getLoc(), builder, *warpMatrixInfo);
+            rewriter, op.getLoc(), *warpMatrixInfo);
         if (failed(coords))
-          return failure();
+          return rewriter.notifyMatchFailure(op, "no coords");
 
         SmallVector<Value, 4> newIndices;
         getXferIndices<vector::TransferReadOp>(
-            builder, op, *coords, {laneId, logicalValueId}, newIndices);
-        Value el = builder.create<memref::LoadOp>(op.getLoc(), loadedElType,
-                                                  op.getSource(), newIndices);
-        result = builder.create<vector::InsertOp>(
-            op.getLoc(), el, result, builder.getI64ArrayAttr({i, innerIdx}));
+            rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
+        Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
+                                                   op.getSource(), newIndices);
+        result = rewriter.create<vector::InsertOp>(
+            op.getLoc(), el, result, rewriter.getI64ArrayAttr({i, innerIdx}));
       }
     }
   }
@@ -733,14 +800,15 @@
 /// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
 /// used when converting to `nvgpu.mma.sync` operations.
 static LogicalResult
-convertTransferReadToLoads(vector::TransferReadOp op,
+convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                            llvm::DenseMap<Value, Value> &valueMapping) {
-  OpBuilder b(op);
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
       nvgpu::getWarpMatrixInfo(op);
   if (failed(warpMatrixInfo))
-    return failure();
+    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
 
   bool isLdMatrixCompatible =
       isSharedMemory(op.getSource().getType().cast<MemRefType>()) &&
@@ -758,46 +826,54 @@
     isLdMatrixCompatible = false;
 
   if (!isLdMatrixCompatible)
-    return createNonLdMatrixLoads(op, b, valueMapping);
+    return createNonLdMatrixLoads(rewriter, op, valueMapping);
 
-  return creatLdMatrixCompatibleLoads(op, b, valueMapping);
+  return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
 }
 
 static LogicalResult
-convertTransferWriteToStores(vector::TransferWriteOp op,
+convertTransferWriteToStores(RewriterBase &rewriter,
+                             vector::TransferWriteOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
-  OpBuilder b(op);
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   Location loc = op->getLoc();
-  Value matrix = valueMapping.find(op.getVector())->second;
+  auto it = valueMapping.find(op.getVector());
+  if (it == valueMapping.end())
+    return rewriter.notifyMatchFailure(op, "no mapping");
+  Value matrix = it->second;
 
   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
       nvgpu::getWarpMatrixInfo(op);
   if (failed(warpMatrixInfo))
-    return failure();
+    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
   FailureOr<nvgpu::FragmentElementInfo> regInfo =
       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
   if (failed(regInfo))
-    return failure();
+    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
 
   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
-  Value laneId = b.create<gpu::LaneIdOp>(loc);
+  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
 
   for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
-    Value logicalValueId = b.create<arith::ConstantOp>(
-        loc, b.getIndexType(),
-        b.getIndexAttr(i * regInfo->elementsPerRegister));
+    Value logicalValueId = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIndexType(),
+        rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
     FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
-        op.getLoc(), b, *warpMatrixInfo);
+        rewriter, op.getLoc(), *warpMatrixInfo);
     if (failed(coords))
-      return failure();
+      return rewriter.notifyMatchFailure(op, "no coords");
 
-    Value el = b.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
+    Value el =
+        rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
     SmallVector<Value, 4> newIndices;
     getXferIndices<vector::TransferWriteOp>(
-        b, op, *coords, {laneId, logicalValueId}, newIndices);
-    b.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
+        rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
+    rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
   }
-  op->erase();
+
+  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+  rewriter.eraseOp(op);
   return success();
 }
@@ -808,35 +884,37 @@
 }
 
 static LogicalResult
-convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
+convertExtractStridedSlice(RewriterBase &rewriter,
+                           vector::ExtractStridedSliceOp op,
                            llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
 
-  OpBuilder b(op);
   Location loc = op->getLoc();
 
   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
       nvgpu::getWarpMatrixInfo(op);
   if (failed(warpMatrixInfo))
-    return failure();
+    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
 
   FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
   if (failed(mmaSyncFragmentInfo))
-    return failure();
+    return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
 
   // Find the vector.transfer_read whose result vector is being sliced.
   auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
   if (!transferReadOp)
-    return failure();
+    return rewriter.notifyMatchFailure(op, "no transfer read");
 
   warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
   if (failed(warpMatrixInfo))
-    return failure();
+    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
 
   FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
   if (failed(ldFragmentInfo))
-    return failure();
+    return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
 
   assert(
       (mmaSyncFragmentInfo->elementsPerRegister ==
@@ -849,7 +927,10 @@
   std::array<int64_t, 2> sliceShape = {
       mmaSyncFragmentInfo->numRegistersPerFragment,
       mmaSyncFragmentInfo->elementsPerRegister};
-  auto sourceVector = valueMapping.find(transferReadOp)->second;
+  auto it = valueMapping.find(transferReadOp);
+  if (it == valueMapping.end())
+    return rewriter.notifyMatchFailure(op, "no mapping");
+  auto sourceVector = it->second;
 
   // offset and sizes at warp-level of ownership.
   SmallVector<int64_t> offsets;
@@ -871,86 +952,114 @@
   else if (offsets[1])
     sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
 
-  Value newOp = b.create<vector::ExtractStridedSliceOp>(
+  Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
       loc, sourceVector, sliceOffset, sliceShape, strides);
 
   valueMapping[op] = newOp;
   return success();
 }
 
-static void convertContractOp(vector::ContractionOp op,
-                              llvm::DenseMap<Value, Value> &valueMapping) {
-  OpBuilder b(op);
-  Value opA = valueMapping.find(op.getLhs())->second;
-  Value opB = valueMapping.find(op.getRhs())->second;
-  Value opC = valueMapping.find(op.getAcc())->second;
-  Value matmul = b.create<gpu::SubgroupMmaComputeOp>(
+static LogicalResult
+convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
+                  llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
+  auto itA = valueMapping.find(op.getLhs());
+  auto itB = valueMapping.find(op.getRhs());
+  auto itC = valueMapping.find(op.getAcc());
+  if (itA == valueMapping.end() || itB == valueMapping.end() ||
+      itC == valueMapping.end())
+    return rewriter.notifyMatchFailure(op, "no mapping");
+  Value opA = itA->second, opB = itB->second, opC = itC->second;
+  Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
       op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
       /*b_transpose=*/UnitAttr());
   valueMapping[op.getResult()] = matmul;
+  return success();
 }
 
 static LogicalResult
-convertContractOpToMmaSync(vector::ContractionOp op,
+convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
                            llvm::DenseMap<Value, Value> &valueMapping) {
-  OpBuilder b(op);
-  Value opA = valueMapping.find(op.getLhs())->second;
-  Value opB = valueMapping.find(op.getRhs())->second;
-  Value opC = valueMapping.find(op.getAcc())->second;
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
+  auto itA = valueMapping.find(op.getLhs());
+  auto itB = valueMapping.find(op.getRhs());
+  auto itC = valueMapping.find(op.getAcc());
+  if (itA == valueMapping.end() || itB == valueMapping.end() ||
+      itC == valueMapping.end())
+    return rewriter.notifyMatchFailure(op, "no mapping");
+  Value opA = itA->second, opB = itB->second, opC = itC->second;
   int64_t m = op.getLhs().getType().cast<VectorType>().getShape()[0];
   int64_t n = op.getRhs().getType().cast<VectorType>().getShape()[0];
   int64_t k = op.getLhs().getType().cast<VectorType>().getShape()[1];
-  Value matmul = b.create<nvgpu::MmaSyncOp>(op.getLoc(), opA, opB, opC,
-                                            b.getI64ArrayAttr({m, n, k}));
+  Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
      op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
   valueMapping[op.getResult()] = matmul;
   return success();
 }
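On the (m, n, k) derivation above: after the "prepare" patterns earlier in this file, contractions are in (m, k) x (n, k) -> (m, n) form, so both k and n can be read off dimension 1 of the lhs and dimension 0 of the rhs respectively. A hypothetical helper restating that derivation for clarity:

    #include "mlir/IR/BuiltinTypes.h"
    #include <array>

    // With lhs : vector<MxK> and rhs : vector<NxK> (rhs is K-contiguous after
    // the prepare patterns), the mma.sync shape attribute is [M, N, K].
    static std::array<int64_t, 3> mmaSyncShape(mlir::VectorType lhs,
                                               mlir::VectorType rhs) {
      return {lhs.getDimSize(0), rhs.getDimSize(0), lhs.getDimSize(1)};
    }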
 
 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
-static void convertConstantOp(arith::ConstantOp op,
-                              llvm::DenseMap<Value, Value> &valueMapping) {
+static LogicalResult
+convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
+                  llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   assert(constantSupportsMMAMatrixType(op));
-  OpBuilder b(op);
+
   auto splat =
       op.getValue().cast<SplatElementsAttr>().getSplatValue<TypedAttr>();
   auto scalarConstant =
-      b.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
+      rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
   const char *fragType = inferFragType(op);
   auto vecType = op.getType().cast<VectorType>();
   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
       vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
-  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
-                                                           scalarConstant);
+  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
      op.getLoc(), type, scalarConstant);
   valueMapping[op.getResult()] = matrix;
+  return success();
 }
 
 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
-static void convertBroadcastOp(vector::BroadcastOp op,
-                               llvm::DenseMap<Value, Value> &valueMapping) {
+static LogicalResult
+convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
+                   llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   assert(broadcastSupportsMMAMatrixType(op));
-  OpBuilder b(op);
+
   const char *fragType = inferFragType(op);
   auto vecType = op.getVectorType();
   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
       vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
-  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
-                                                           op.getSource());
+  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
      op.getLoc(), type, op.getSource());
   valueMapping[op.getResult()] = matrix;
+  return success();
 }
 
 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
 // updated and needs to be updated separately for the loop to be correct.
-static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
+static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
+                                               scf::ForOp loop,
                                                ValueRange newIterOperands) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(loop);
+
   // Create a new loop before the existing one, with the extra operands.
-  OpBuilder::InsertionGuard g(b);
-  b.setInsertionPoint(loop);
+  rewriter.setInsertionPoint(loop);
   auto operands = llvm::to_vector<4>(loop.getIterOperands());
   operands.append(newIterOperands.begin(), newIterOperands.end());
-  scf::ForOp newLoop =
-      b.create<scf::ForOp>(loop.getLoc(), loop.getLowerBound(),
-                           loop.getUpperBound(), loop.getStep(), operands);
+  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
      operands);
   newLoop.getBody()->erase();
+
   newLoop.getLoopBody().getBlocks().splice(
       newLoop.getLoopBody().getBlocks().begin(),
       loop.getLoopBody().getBlocks());
@@ -959,25 +1068,35 @@
   for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                   loop.getNumResults())))
-    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
-  loop.erase();
+    rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
+
+  LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
+  LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
+  LLVM_DEBUG(DBGS() << "erase: " << loop);
+
+  rewriter.eraseOp(loop);
   return newLoop;
 }
 
-static void convertForOp(scf::ForOp op,
-                         llvm::DenseMap<Value, Value> &valueMapping) {
+static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
+                                  llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   SmallVector<Value> newOperands;
   SmallVector<std::pair<size_t, size_t>> argMapping;
   for (const auto &operand : llvm::enumerate(op.getIterOperands())) {
     auto it = valueMapping.find(operand.value());
-    if (it == valueMapping.end())
+    if (it == valueMapping.end()) {
+      LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
       continue;
+    }
     argMapping.push_back(std::make_pair(
         operand.index(), op.getNumIterOperands() + newOperands.size()));
     newOperands.push_back(it->second);
   }
-  OpBuilder b(op);
-  scf::ForOp newForOp = replaceForOpWithNewSignature(b, op, newOperands);
+
+  scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
   Block &loopBody = *newForOp.getBody();
   for (auto mapping : argMapping) {
     valueMapping[newForOp.getResult(mapping.first)] =
@@ -986,11 +1105,17 @@
         newForOp.getNumInductionVars())] =
         loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
   }
+
+  LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
+  return success();
 }
 
-static void convertYieldOp(scf::YieldOp op,
-                           llvm::DenseMap<Value, Value> &valueMapping) {
-  OpBuilder b(op);
+static LogicalResult
+convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
+               llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   auto loop = cast<scf::ForOp>(op->getParentOp());
   auto yieldOperands = llvm::to_vector<4>(op.getOperands());
   for (const auto &operand : llvm::enumerate(op.getOperands())) {
@@ -1002,20 +1127,32 @@
     yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
     yieldOperands.push_back(it->second);
   }
-  b.create<scf::YieldOp>(op.getLoc(), yieldOperands);
-  op.erase();
+  rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
+
+  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+  rewriter.eraseOp(op);
+  return success();
 }
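The LLVM_DEBUG/DBGS tracing sprinkled through these helpers compiles away in release builds and is gated on the debug type at runtime; it can be enabled with `mlir-opt ... -debug-only=vector-to-gpu`. A minimal sketch of the same idiom (the function traceConversionStep is hypothetical):

    #include "llvm/Support/Debug.h"

    #define DEBUG_TYPE "vector-to-gpu"
    #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

    static void traceConversionStep(int step) {
      // Prints only in asserts-enabled builds with -debug-only=vector-to-gpu.
      LLVM_DEBUG(DBGS() << "step " << step << "\n");
    }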
 
 /// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
-static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType,
-                                 llvm::DenseMap<Value, Value> &valueMapping) {
-  OpBuilder b(op);
+static LogicalResult
+convertElementwiseOp(RewriterBase &rewriter, Operation *op,
+                     gpu::MMAElementwiseOp opType,
+                     llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
   SmallVector<Value> matrixOperands;
-  for (Value operand : op->getOperands())
-    matrixOperands.push_back(valueMapping.find(operand)->second);
-  Value newOp = b.create<gpu::SubgroupMmaElementwiseOp>(
+  for (Value operand : op->getOperands()) {
+    auto it = valueMapping.find(operand);
+    if (it == valueMapping.end())
+      return rewriter.notifyMatchFailure(op, "no mapping");
+    matrixOperands.push_back(it->second);
+  }
+  Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
       op->getLoc(), matrixOperands[0].getType(), matrixOperands, opType);
   valueMapping[op->getResult(0)] = newOp;
+  return success();
 }
 
 void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
@@ -1030,67 +1167,75 @@
                                               patterns.getContext());
 }
 
-void mlir::convertVectorToMMAOps(Operation *rootOp) {
+LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
+                                          Operation *rootOp) {
   SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
   llvm::DenseMap<Value, Value> valueMapping;
+
+  auto globalRes = LogicalResult::success();
   for (Operation *op : ops) {
+    LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
+    // Apparently callers do not want to early exit on failure here.
+    auto res = LogicalResult::success();
     if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
-      convertTransferReadOp(transferRead, valueMapping);
+      res = convertTransferReadOp(rewriter, transferRead, valueMapping);
     } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
-      convertTransferWriteOp(transferWrite, valueMapping);
+      res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
-      convertContractOp(contractOp, valueMapping);
+      res = convertContractOp(rewriter, contractOp, valueMapping);
    } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
-      convertConstantOp(constantOp, valueMapping);
+      res = convertConstantOp(rewriter, constantOp, valueMapping);
    } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
-      convertBroadcastOp(broadcastOp, valueMapping);
+      res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
-      convertForOp(forOp, valueMapping);
-    } else if (auto yiledOp = dyn_cast<scf::YieldOp>(op)) {
-      convertYieldOp(yiledOp, valueMapping);
+      res = convertForOp(rewriter, forOp, valueMapping);
+    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
+      res = convertYieldOp(rewriter, yieldOp, valueMapping);
    } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
-      convertElementwiseOp(op, *elementwiseType, valueMapping);
+      res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
    }
+    if (failed(res))
+      globalRes = failure();
   }
+  return globalRes;
 }
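The nvgpu-compatible driver below dispatches with llvm::TypeSwitch, folding each case's LogicalResult into a single value. For readers unfamiliar with it, a toy sketch of the mechanism (the function dispatchExample is ours; the op types in the cases are arbitrary):

    #include "llvm/ADT/TypeSwitch.h"
    #include "mlir/Dialect/SCF/IR/SCF.h"
    #include "mlir/IR/Operation.h"

    // The result type is fixed to LogicalResult; each Case deduces the op type
    // from its lambda parameter; Default handles everything else.
    static mlir::LogicalResult dispatchExample(mlir::Operation *op) {
      return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(op)
          .Case([](mlir::scf::ForOp forOp) { return mlir::success(); })
          .Case([](mlir::scf::YieldOp yieldOp) { return mlir::success(); })
          .Default([](mlir::Operation *) { return mlir::failure(); });
    }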
 
-LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) {
+LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
+                                                         Operation *rootOp) {
   SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
   llvm::DenseMap<Value, Value> valueMapping;
   for (Operation *op : ops) {
     if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
             .Case([&](vector::TransferReadOp transferReadOp) {
-              return convertTransferReadToLoads(transferReadOp, valueMapping);
+              return convertTransferReadToLoads(rewriter, transferReadOp,
+                                                valueMapping);
             })
             .Case([&](vector::TransferWriteOp transferWriteOp) {
-              return convertTransferWriteToStores(transferWriteOp,
-                                                  valueMapping);
+              return convertTransferWriteToStores(rewriter, transferWriteOp,
+                                                  valueMapping);
             })
             .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
-              return convertExtractStridedSlice(extractStridedSliceOp,
-                                                valueMapping);
+              return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
+                                                valueMapping);
             })
             .Case([&](vector::ContractionOp contractionOp) {
-              return convertContractOpToMmaSync(contractionOp, valueMapping);
+              return convertContractOpToMmaSync(rewriter, contractionOp,
+                                                valueMapping);
             })
             .Case([&](scf::ForOp forOp) {
-              convertForOp(forOp, valueMapping);
-              return success();
+              return convertForOp(rewriter, forOp, valueMapping);
             })
             .Case([&](scf::YieldOp yieldOp) {
-              convertYieldOp(yieldOp, valueMapping);
-              return success();
+              return convertYieldOp(rewriter, yieldOp, valueMapping);
             })
             .Case([&](arith::ConstantOp constOp) {
-              return convertConstantOpMmaSync(constOp, valueMapping);
+              return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
             })
             .Default([&](Operation *op) {
-              op->emitError() << "unhandled vector to mma type: " << *op;
-              return failure();
+              return op->emitError() << "unhandled vector to mma type: " << *op;
             })
             .failed()) {
-      op->emitError() << "Failed to convert op " << *op;
-      return failure();
+      return op->emitError() << "Failed to convert op " << *op;
     }
   }
   return success();
@@ -1112,12 +1257,13 @@
           applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     return signalPassFailure();
 
+  IRRewriter rewriter(&getContext());
   if (useNvGpu.getValue()) {
-    if (failed(convertVectorToNVVMCompatibleMMASync(getOperation())))
+    if (failed(
+            convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
       return signalPassFailure();
   }
-
-  (void)convertVectorToMMAOps(getOperation());
+  (void)convertVectorToMMAOps(rewriter, getOperation());
 }
 };
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -170,7 +170,7 @@
 }
 
 FailureOr<AffineMap>
-nvgpu::getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
+nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
                                          const WarpMatrixInfo &fragmentType) {
   Type elementType = fragmentType.vectorType.getElementType();
   ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
@@ -235,7 +235,7 @@
 }
 
 FailureOr<AffineMap>
-nvgpu::getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
+nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                       const LdMatrixParams &params) {
   // One thread per 128b row.
   const int bitsPerElement = static_cast<int>(
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-gpu),canonicalize)" | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-gpu),canonicalize)" --split-input-file | FileCheck %s
 
 #map0 = affine_map<(d0, d1) -> (d1, d0)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
 #map4 = affine_map<(d0) -> (d0, 0)>
 #map5 = affine_map<(d0, d1) -> (d0, d1)>
@@ -25,6 +25,15 @@
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_cst
 // CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f16
 // CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
@@ -43,6 +52,15 @@
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_broadcast
 // CHECK-SAME: (%{{.*}}: memref<16x16xf16>, %{{.*}}: memref<16x16xf16>, %{{.*}}: memref<16x16xf16>, %[[F:.*]]: f16)
 // CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_constant_matrix %[[F]] : !gpu.mma_matrix<16x16xf16, "COp">
@@ -61,6 +79,15 @@
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_loop
 // CHECK: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
 // CHECK: %[[ACC:.+]] = scf.for {{.*}} iter_args(%[[ACC1:.+]] = %[[C]]) -> (!gpu.mma_matrix<16x16xf16, "COp">) {
@@ -86,6 +113,15 @@
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_fused_elementwise
 // CHECK-DAG: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f16
 // CHECK-DAG: %[[CST_1:.+]] = arith.constant 1.000000e+00 : f16
@@ -109,6 +145,15 @@
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_fused_broadcast
 // CHECK-DAG: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f16
 // CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
@@ -134,6 +179,15 @@
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_3Dmemref
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
@@ -153,6 +207,15 @@
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_memref_strided
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 32 : index} : memref<2x16x16xf16, #{{.*}}> -> !gpu.mma_matrix<16x16xf16, "AOp">
@@ -172,6 +235,15 @@
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_transposed
 // CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
 // CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
@@ -190,6 +262,15 @@
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_transposed_broadcasted_1d
 // CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}] {leadDimension = 0 : index, transpose} : memref<16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
 // CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}] {leadDimension = 0 : index} : memref<16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
@@ -208,6 +289,15 @@
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
+
 // CHECK-LABEL: func @matmul_transposed_broadcasted_2d
 // CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}] {leadDimension = 0 : index, transpose} : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
 // CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}] {leadDimension = 0 : index} : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">