mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
- This file was added.
//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to GPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"

#include "../PassDetail.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

// Return true if the contract op can be converted to MMA matmul.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
  if (llvm::size(contract.masks()) != 0)
    return false;

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.iterator_types().getValue();
  if (!(isParallelIterator(iteratorTypes[0]) &&
        isParallelIterator(iteratorTypes[1]) &&
        isReductionIterator(iteratorTypes[2])))
    return false;

  // The contract needs to represent a matmul to be able to convert to
  // MMAMatrix matmul.
  if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
    return false;

  // Check that the size match what is natively supported.
  VectorType lhsType = contract.lhs().getType().cast<VectorType>();
nicolasvasilache: nit: matches
  VectorType rhsType = contract.rhs().getType().cast<VectorType>();
  VectorType accType = contract.acc().getType().cast<VectorType>();
  std::tuple<int, int, int> dim(lhsType.getDimSize(0), rhsType.getDimSize(1),
                                lhsType.getDimSize(1));
  if (lhsType.getElementType().isInteger(8) &&
      rhsType.getElementType().isInteger(8) &&
      accType.getElementType().isInteger(32) &&
      (dim == std::make_tuple(8, 8, 32) || dim == std::make_tuple(16, 16, 32) ||
       dim == std::make_tuple(16, 8, 32)))
    return true;

  if (lhsType.getElementType().isF16() && rhsType.getElementType().isF16() &&
      (accType.getElementType().isF16() || accType.getElementType().isF32()) &&
      (dim == std::make_tuple(8, 8, 16) || dim == std::make_tuple(16, 16, 16) ||
       dim == std::make_tuple(16, 8, 16)))
    return true;

  return false;
}

// Return the stride for dimension 0 of |type| if it is a memref and has a
// constant stride.
static llvm::Optional<int64_t>
getMemrefConstantHorizontalStride(ShapedType type) {
  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType)
    return llvm::None;
  int64_t offset = 0;
  SmallVector<int64_t, 2> strides;
  if (failed(getStridesAndOffset(memrefType, strides, offset)))
    return llvm::None;
  if (strides[0] == ShapedType::kDynamicStrideOrOffset)
    return llvm::None;
  return strides[0];
}

// Return true if the transfer op can be converted to an MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
  if (readOp.mask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if(!getMemrefConstantHorizontalStride(readOp.getShapedType()))
Lint: Pre-merge checks
clang-format: please reformat the code
```
- if(!getMemrefConstantHorizontalStride(readOp.getShapedType()))
+ if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
```
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!readOp.permutation_map().isMinorIdentity())
    return false;
  return true;
}

// Return true if the transfer op can be converted to an MMA matrix store.
static bool transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
Lint: Pre-merge checks
clang-format: please reformat the code
```
-static bool transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
+static bool
+transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
```
  if (writeOp.mask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  if(!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
Lint: Pre-merge checks
clang-format: please reformat the code
```
- if(!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
+ if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
```
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!writeOp.permutation_map().isMinorIdentity())
    return false;
  return true;
}
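
// Return true if this op can be converted to operate on MMA matrix type.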
static bool supportsMMaMatrixType(Operation *op) {
nicolasvasilache: Nit: Analyze?
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
    return transferReadSupportsMMAMatrixType(transferRead);
  }
nicolasvasilache: nit: trivial braces here and below
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
nicolasvasilache: You should be able to only walk the vector::ContractionOp and save some logic below.
    return transferWriteSupportsMMAMatrixType(transferWrite);
  }
  if (auto contract = dyn_cast<vector::ContractionOp>(op)) {
    return contractSupportsMMAMatrixType(contract);
nicolasvasilache:
```
auto hasVectorDest = [](Operation *op) {
  return op->getNumResults() == 0 ||
         llvm::any_of(op->getResultTypes(),
                      [](Type t) { return t.isa<VectorType>(); });
};
```
Also, do you want "any_of" or "all_of" behavior? I'd also lift it at the top of the function to avoid interleaving logic and simplify the reading.
  }
  return false;
}
// Analyze slice of operations based on convert op to figure out if the whole
// slice can be converted to MMA operations.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
  auto hasVectorDest = [](Operation *op) {
    return op->getNumResults() == 0 ||
           llvm::any_of(op->getResultTypes(),
nicolasvasilache:
```
if (llvm::any_of(dependentOp, [](){ return !supports...}))
  return;
```
                        [](Type t) { return t.isa<VectorType>(); });
  };
  SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
      return;
    SetVector<Operation *> dependentOps =
        getSlice(contract, hasVectorDest, hasVectorDest);
    // If any instruction cannot use MMA matrix type, drop the whole chain. MMA
    // matrices are stored in an opaque type so they cannot be used by all
    // operations.
    if (llvm::any_of(dependentOps,
                     [](Operation *op) { return !supportsMMaMatrixType(op); }))
      return;
    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
nicolasvasilache: Can a proper subset of the vector matmul lowering pattern be exposed and plugged here (maybe with some extra lambda)?
ThomasRaoux: I couldn't get to anything that makes sense. The matmul lowering is trying to put the contract in the form (k, m), (k, n), (m, n) while this code is transforming it to (m, k), (k, n), (m, n), so most of the logic is different. The only thing that I could move into a common function was the dim binding, but it tends to make the code more complicated as the matmul lowering also handles vector * mat.
nicolasvasilache: That's sad, these are all so close. I guess we are reaching the point where we want to cast… Not for this CL obviously, thanks for trying!
ThomasRaoux: Correct, we can reduce the number of cases here and in matmul lowering, I'll try to do it in the next CL.
  return opToConvert;
}
namespace {

// Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted
// to MMA matmul.
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();

    // Set up the parallel/reduction structure in right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.iterator_types().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
    if (!(isParallelIterator(iteratorTypes[0]) &&
          isParallelIterator(iteratorTypes[1]) &&
          isReductionIterator(iteratorTypes[2])))
      return failure();
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      // This is the classical row-major matmul, nothing to do.
      return failure();
    }
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return failure();
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res,
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
        op.iterator_types());
nicolasvasilache: We should already have a pattern that rewrites vector.transfer + vector.transpose into the proper memref.transpose + vector.transfer; can we reuse it?
ThomasRaoux: I can't find this pattern but I'm not sure how it helps for this case. Ideally I want a transfer with transpose that can directly be lowered to a mma.load op. How would the memref.transpose help this case?
nicolasvasilache: I forgot part of the thinking but IIRC, the idea was that since the mma.load seems richer, it may be possible to fold these transposes into the vector.transfer indexing logic and propagate that to the mma ops, rather than perform actual transposes. Not clear whether this is really possible, I may look a little more into it in the future.
ThomasRaoux: I'm a bit confused, the pattern does merge the transpose into the transfer_read indexing logic. I don't understand why we would want a memref.transpose. If there is a pattern already doing transpose + transfer_read -> transfer_read with affine map, I can use it, but I couldn't find any.
nicolasvasilache: You're right, I confirmed with @bkramer offline that this actually did not land, I just reviewed it a few months back. I imagine we will want to revive the vector.transfer + transpose -> vector.transfer + strided_memref?
ThomasRaoux: Sounds good, thanks for checking, I can help move those patterns to a more generic place when it makes sense.
    return success();
  }
};
// Merge transpose op into the transfer read op. Transposes are not supported on
// MMA types but MMA load can transpose the matrix when loading.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto transferReadOp = op.vector().getDefiningOp<vector::TransferReadOp>();
    if (!transferReadOp)
      return failure();
    if (transferReadOp.mask() || transferReadOp.hasOutOfBoundsDim())
      return failure();
    SmallVector<int64_t, 2> perm;
    op.getTransp(perm);
    SmallVector<unsigned, 2> permU;
    for (int64_t o : perm)
      permU.push_back(unsigned(o));
    AffineMap permutationMap =
        AffineMap::getPermutationMap(permU, op.getContext());
    AffineMap newMap = permutationMap.compose(transferReadOp.permutation_map());
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        op, op.getType(), transferReadOp.source(), transferReadOp.indices(),
        newMap, transferReadOp.padding(), transferReadOp.mask(),
        transferReadOp.in_boundsAttr());
    return success();
  }
};

} // namespace

// MMA types have different layout based on how they are used in matmul ops.
// Figure the right layout to use by looking at Transfer op uses.
// TODO: Change the GPU dialect to abstract the layout at this level and
// only care about it during lowering to NVVM.
static const char* inferFragType(vector::TransferReadOp op) {
Lint: Pre-merge checks
clang-format: please reformat the code
```
-static const char* inferFragType(vector::TransferReadOp op) {
+static const char *inferFragType(vector::TransferReadOp op) {
```
nicolasvasilache: +1 I wonder whether this has a relation to my rambling about transpose and vector.transfer?
ThomasRaoux: I don't think it is directly related. The transfer op should already have the transpose indexing merged at this point.
  for (Operation *users : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(users);
    if(!contract)
Lint: Pre-merge checks
clang-format: please reformat the code
```
- if(!contract)
+ if (!contract)
```
      continue;
    if(contract.lhs() == op.getResult())
Lint: Pre-merge checks
clang-format: please reformat the code
```
- if(contract.lhs() == op.getResult())
+ if (contract.lhs() == op.getResult())
```
      return "AOp";
    if(contract.rhs() == op.getResult())
Lint: Pre-merge checks
clang-format: please reformat the code
```
- if(contract.rhs() == op.getResult())
+ if (contract.rhs() == op.getResult())
```
      return "BOp";
  }
  return "COp";
}
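
// Convert a vector.transfer_read to a gpu.subgroup_mma_load_matrix op and
// record the resulting MMA matrix in the value mapping.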
static void convertTransferReadOp(vector::TransferReadOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  assert(transferReadSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
nicolasvasilache: You need to check for static stride (and bail if not). In the general case of dynamic stride…
ThomasRaoux: Good point, I moved this into a helper function.
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  const char *fragType = inferFragType(op);
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
Lint: Pre-merge checks
clang-format: please reformat the code
```
- gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
-     op.getVectorType().getShape(),
-     op.getVectorType().getElementType(), fragType);
+ gpu::MMAMatrixType type =
+     gpu::MMAMatrixType::get(op.getVectorType().getShape(),
+                             op.getVectorType().getElementType(), fragType);
```
      op.getVectorType().getShape(),
      op.getVectorType().getElementType(), fragType);
  OpBuilder b(op);
  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.source(), op.indices(), b.getIndexAttr(*stride));
  valueMapping[op.getResult()] = load;
}
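
// Convert a vector.transfer_write to a gpu.subgroup_mma_store_matrix op using
// the MMA matrix previously created for the stored vector.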
static void convertTransferWriteOp(vector::TransferWriteOp op,
                                   llvm::DenseMap<Value, Value> &valueMapping) {
  assert(transferWriteSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  OpBuilder b(op);
nicolasvasilache: You need to check for static stride (and bail if not). In the general case of dynamic stride…
ThomasRaoux: ditto
  Value matrix = valueMapping.find(op.vector())->second;
  b.create<gpu::SubgroupMmaStoreMatrixOp>(op.getLoc(), matrix, op.source(),
Lint: Pre-merge checks
clang-format: please reformat the code
```
- b.create<gpu::SubgroupMmaStoreMatrixOp>(op.getLoc(), matrix, op.source(),
-                                          op.indices(), b.getIndexAttr(*stride));
+ b.create<gpu::SubgroupMmaStoreMatrixOp>(
+     op.getLoc(), matrix, op.source(), op.indices(), b.getIndexAttr(*stride));
```
                                          op.indices(), b.getIndexAttr(*stride));
  op.erase();
}
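
// Convert a vector.contract to a gpu.subgroup_mma_compute op using the MMA
// matrices created for its operands; the accumulator provides the result type.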
static void convertContractOp(vector::ContractionOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Value A = valueMapping.find(op.lhs())->second;
Lint: Pre-merge checks
clang-tidy: warning: invalid case style for variable 'A' [readability-identifier-naming]
  Value B = valueMapping.find(op.rhs())->second;
Lint: Pre-merge checks
clang-tidy: warning: invalid case style for variable 'B' [readability-identifier-naming]
  Value C = valueMapping.find(op.acc())->second;
Lint: Pre-merge checks
clang-tidy: warning: invalid case style for variable 'C' [readability-identifier-naming]
  Value matmul =
      b.create<gpu::SubgroupMmaComputeOp>(op.getLoc(), C.getType(), A, B, C);
  valueMapping[op.getResult()] = matmul;
}

namespace mlir {
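
// Patterns putting vector.contract and vector.transfer ops into a form that
// maps directly onto MMA operations.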
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
      patterns.getContext());
}
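
// Convert the slices of supported vector ops in the function to GPU MMA ops,
// tracking the mapping from vector values to MMA matrices as we go.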
void convertVectorToMMAOps(FuncOp funcOp) {
  SetVector<Operation *> ops = getOpToConvert(funcOp);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      convertTransferReadOp(transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      convertTransferWriteOp(transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      convertContractOp(contractOp, valueMapping);
    }
  }
}

} // namespace mlir

namespace {
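
// Pass driver: prepare the vector ops with the rewrite patterns above, then
// convert the supported slices to MMA ops.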
struct ConvertVectorToGPUPass
    : public ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
  void runOnFunction() override {
    RewritePatternSet patterns(getFunction().getContext());
    populatePrepareVectorToMMAPatterns(patterns);
nicolasvasilache: I would split anything related to scf.for + scf.yield in a separate CL and discuss there; in particular there are quite some rewrite and canonicalization patterns that may simplify some of this code.
ThomasRaoux: I removed all the code related to scf, I'll send a separate patch for it once this one lands.
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
    convertVectorToMMAOps(getFunction());
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass() {
  return std::make_unique<ConvertVectorToGPUPass>();
}
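
A minimal sketch of driving the new conversion programmatically, for reference. The helper name runVectorToGPU is hypothetical, and the snippet only assumes the createConvertVectorToGPUPass() declaration that this patch adds to mlir/Conversion/VectorToGPU/VectorToGPU.h; the matching mlir-opt registration (presumably a convert-vector-to-gpu entry in the conversion Passes.td, not shown in this changeset) is also an assumption.
```
// Hypothetical driver, not part of this patch: run the conversion over every
// function in a module using the factory defined above.
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

mlir::LogicalResult runVectorToGPU(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // ConvertVectorToGPUPass is a function pass, so nest it under FuncOp.
  pm.addNestedPass<mlir::FuncOp>(mlir::createConvertVectorToGPUPass());
  return pm.run(module);
}
```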