Please use GitHub pull requests for new patches. Avoid migrating existing patches. Phabricator shutdown timeline
Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Show First 20 Lines • Show All 379 Lines • ▼ Show 20 Lines | LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite( | ||||
ArrayRef<int64_t> inShape = inShapeType.getShape(); | ArrayRef<int64_t> inShape = inShapeType.getShape(); | ||||
ArrayRef<int64_t> kShape = kShapeType.getShape(); | ArrayRef<int64_t> kShape = kShapeType.getShape(); | ||||
if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape()) | if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape()) | ||||
return failure(); | return failure(); | ||||
SmallVector<AffineExpr, 4> mapping; | SmallVector<AffineExpr, 4> mapping; | ||||
// Fail to apply when the size of not vectorized dimension is not 1 or | SmallVector<int64_t, 4> vectorDims; | ||||
// when the size of vectorized dimension is not dimSize. | // Fail to apply when the size of not vectorized dimension is not 1. | ||||
for (unsigned i = 0; i < N; i++) { | for (unsigned i = 0; i < N; i++) { | ||||
if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1)) | if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1)) | ||||
return failure(); | return failure(); | ||||
if (mask[i] && (inShape[i] != tileSize || kShape[i] != tileSize)) | |||||
if (mask[i] && inShape[i] != kShape[i]) | |||||
return failure(); | return failure(); | ||||
if (mask[i]) | if (mask[i]) { | ||||
mapping.push_back(getAffineDimExpr(i, context)); | mapping.push_back(getAffineDimExpr(i, context)); | ||||
vectorDims.push_back(inShape[i]); | |||||
} | |||||
} | } | ||||
Value input = op.getInput(0); | Value input = op.getInput(0); | ||||
Value kernel = op.getInput(1); | Value kernel = op.getInput(1); | ||||
Value output = op.getOutputBuffer(0); | Value output = op.getOutputBuffer(0); | ||||
unsigned rank = inShapeType.getRank(); | unsigned rank = inShapeType.getRank(); | ||||
unsigned numDims = mapping.size(); | unsigned numDims = mapping.size(); | ||||
Type elemType = inShapeType.getElementType(); | Type elemType = inShapeType.getElementType(); | ||||
auto map = AffineMap::get(rank, 0, mapping, context); | auto map = AffineMap::get(rank, 0, mapping, context); | ||||
SmallVector<Value, 4> zeros(rank, std_constant_index(0)); | SmallVector<Value, 4> zeros(rank, std_constant_index(0)); | ||||
auto vecType = | auto vecType = VectorType::get(vectorDims, elemType); | ||||
VectorType::get(SmallVector<int64_t, 4>(numDims, tileSize), elemType); | |||||
auto inputVec = vector_transfer_read(vecType, input, zeros, map); | auto inputVec = vector_transfer_read(vecType, input, zeros, map); | ||||
auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map); | auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map); | ||||
auto acc = std_constant(elemType, rewriter.getZeroAttr(elemType)); | auto acc = std_constant(elemType, rewriter.getZeroAttr(elemType)); | ||||
std::array<AffineMap, 3> indexingMaps{ | std::array<AffineMap, 3> indexingMaps{ | ||||
AffineMap::getMultiDimIdentityMap(numDims, context), | AffineMap::getMultiDimIdentityMap(numDims, context), | ||||
Show All 18 Lines | |||||
/// conversion into corresponding pattern lists. | /// conversion into corresponding pattern lists. | ||||
template <typename ConvOp, unsigned N> | template <typename ConvOp, unsigned N> | ||||
static void | static void | ||||
populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns, | populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns, | ||||
OwningRewritePatternList &promotionPatterns, | OwningRewritePatternList &promotionPatterns, | ||||
OwningRewritePatternList &vectorizationPatterns, | OwningRewritePatternList &vectorizationPatterns, | ||||
ArrayRef<int64_t> tileSizes, | ArrayRef<int64_t> tileSizes, | ||||
MLIRContext *context) { | MLIRContext *context) { | ||||
if (tileSizes.size() < N) | |||||
return; | |||||
constexpr static StringRef kTiledMarker = "TILED"; | constexpr static StringRef kTiledMarker = "TILED"; | ||||
constexpr static StringRef kPromotedMarker = "PROMOTED"; | constexpr static StringRef kPromotedMarker = "PROMOTED"; | ||||
tilingPatterns.insert<LinalgTilingPattern<ConvOp>>( | tilingPatterns.insert<LinalgTilingPattern<ConvOp>>( | ||||
context, LinalgTilingOptions().setTileSizes(tileSizes), | context, LinalgTilingOptions().setTileSizes(tileSizes), | ||||
LinalgMarker({}, Identifier::get(kTiledMarker, context))); | LinalgMarker({}, Identifier::get(kTiledMarker, context))); | ||||
promotionPatterns.insert<LinalgPromotionPattern<ConvOp>>( | promotionPatterns.insert<LinalgPromotionPattern<ConvOp>>( | ||||
context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), | context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), | ||||
LinalgMarker(Identifier::get(kTiledMarker, context), | LinalgMarker(Identifier::get(kTiledMarker, context), | ||||
Identifier::get(kPromotedMarker, context))); | Identifier::get(kPromotedMarker, context))); | ||||
SmallVector<bool, 4> mask(N); | SmallVector<bool, 4> mask(N); | ||||
int offset = tileSizes.size() - N; | int offset = tileSizes.size() - N; | ||||
std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(), | std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(), | ||||
[](int64_t i) -> bool { return i != ConvOpConst::noTile; }); | [](int64_t i) -> bool { return i > 1; }); | ||||
vectorizationPatterns.insert<ConvOpVectorization<ConvOp, N>>(context, mask); | vectorizationPatterns.insert<ConvOpVectorization<ConvOp, N>>(context, mask); | ||||
} | } | ||||
void mlir::linalg::populateConvVectorizationPatterns( | void mlir::linalg::populateConvVectorizationPatterns( | ||||
MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns) { | MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns, | ||||
const int64_t tileSize = ConvOpConst::tileSize; | ArrayRef<int64_t> tileSizes) { | ||||
const int64_t noTile = ConvOpConst::noTile; | |||||
auto makeTileSizes = [&](unsigned numNoTile, unsigned numTile) { | |||||
SmallVector<int64_t, 10> result(numNoTile, noTile); | |||||
result.append(numTile, tileSize); | |||||
return result; | |||||
}; | |||||
OwningRewritePatternList tiling, promotion, vectorization; | OwningRewritePatternList tiling, promotion, vectorization; | ||||
populateVectorizationPatterns<ConvWOp, 1>( | populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization, | ||||
tiling, promotion, vectorization, | tileSizes, context); | ||||
makeTileSizes(/*numNoTile=*/1, /*numTile*/ 1), context); | |||||
populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization, | populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization, | ||||
makeTileSizes(3, 2), context); | tileSizes, context); | ||||
populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization, | populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization, | ||||
makeTileSizes(3, 2), context); | tileSizes, context); | ||||
populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization, | populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization, | ||||
makeTileSizes(2, 2), context); | tileSizes, context); | ||||
populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization, | populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization, | ||||
makeTileSizes(4, 3), context); | tileSizes, context); | ||||
populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization, | populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization, | ||||
makeTileSizes(4, 3), context); | tileSizes, context); | ||||
populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization, | populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization, | ||||
makeTileSizes(3, 3), context); | tileSizes, context); | ||||
populateVectorizationPatterns<ConvNDHWCOp, 5>( | populateVectorizationPatterns<ConvNDHWCOp, 5>( | ||||
tiling, promotion, vectorization, makeTileSizes(5, 4), context); | tiling, promotion, vectorization, tileSizes, context); | ||||
populateVectorizationPatterns<ConvNCDHWOp, 5>( | populateVectorizationPatterns<ConvNCDHWOp, 5>( | ||||
tiling, promotion, vectorization, makeTileSizes(5, 4), context); | tiling, promotion, vectorization, tileSizes, context); | ||||
patterns.push_back(std::move(tiling)); | patterns.push_back(std::move(tiling)); | ||||
patterns.push_back(std::move(promotion)); | patterns.push_back(std::move(promotion)); | ||||
patterns.push_back(std::move(vectorization)); | patterns.push_back(std::move(vectorization)); | ||||
} | } |