Please use GitHub pull requests for new patches. Phabricator shutdown timeline
Changeset View
Standalone View
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===// | //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===// | ||||
// | // | ||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||||
// See https://llvm.org/LICENSE.txt for license information. | // See https://llvm.org/LICENSE.txt for license information. | ||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||||
// | // | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// | // | ||||
// This file implements the linalg dialect Vectorization transformations. | // This file implements the linalg dialect Vectorization transformations. | ||||
// | // | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
#include "mlir/Dialect/Affine/Utils.h" | |||||
#include "mlir/Analysis/SliceAnalysis.h" | #include "mlir/Analysis/SliceAnalysis.h" | ||||
#include "mlir/Dialect/Affine/IR/AffineOps.h" | #include "mlir/Dialect/Affine/IR/AffineOps.h" | ||||
#include "mlir/Dialect/Arith/IR/Arith.h" | #include "mlir/Dialect/Arith/IR/Arith.h" | ||||
#include "mlir/Dialect/Func/IR/FuncOps.h" | #include "mlir/Dialect/Func/IR/FuncOps.h" | ||||
#include "mlir/Dialect/Linalg/IR/Linalg.h" | #include "mlir/Dialect/Linalg/IR/Linalg.h" | ||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" | ||||
#include "mlir/Dialect/Linalg/Utils/Utils.h" | #include "mlir/Dialect/Linalg/Utils/Utils.h" | ||||
▲ Show 20 Lines • Show All 1,023 Lines • ▼ Show 20 Lines | mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp, | ||||
} | } | ||||
if (failed(reductionPreconditions(linalgOp))) { | if (failed(reductionPreconditions(linalgOp))) { | ||||
LDBG("precondition failed: reduction preconditions\n"); | LDBG("precondition failed: reduction preconditions\n"); | ||||
return failure(); | return failure(); | ||||
} | } | ||||
return success(); | return success(); | ||||
} | } | ||||
/// Converts affine.apply Ops to arithmetic operations. | |||||
dcaballe: nit: `//` -> `///` and `.` at the end per coding guidelines. | |||||
/// Every `affine.apply` found in the block returned by `linalgOp.getBlock()`
/// is expanded at its own position via `expandAffineExpr`, and the original
/// op is then replaced by the expanded value through `rewriter`. The
/// rewriter's insertion point is restored when this function returns.
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
  // Restores the caller's insertion point on scope exit.
  OpBuilder::InsertionGuard g(rewriter);
  auto toReplace = linalgOp.getBlock()->getOps<AffineApplyOp>();

  // The visited op is replaced inside the loop, so the iterator has to be
  // advanced before the body runs.
  for (auto op : make_early_inc_range(toReplace)) {
    // Emit the expansion right before the op it replaces. Inserting at a fixed
    // location near the top of the block could place the new ops ahead of the
    // definitions of the operands they consume.
    rewriter.setInsertionPoint(op);

    // The operand list of affine.apply holds the map's dimension operands
    // first, followed by its symbol operands; hand each group to
    // `expandAffineExpr` separately so symbol expressions resolve correctly.
    AffineMap map = op.getAffineMap();
    auto operands = op.getOperands();
    auto expanded = expandAffineExpr(rewriter, op->getLoc(), map.getResult(0),
                                     operands.take_front(map.getNumDims()),
                                     operands.take_back(map.getNumSymbols()));
    rewriter.replaceOp(op, expanded);
  }
}
/// Emit a suitable vector form for a Linalg op. If provided, `inputVectorSizes` | /// Emit a suitable vector form for a Linalg op. If provided, `inputVectorSizes` | ||||
/// are used to vectorize this operation. `inputVectorSizes` must match the rank | /// are used to vectorize this operation. `inputVectorSizes` must match the rank | ||||
/// of the iteration space of the operation and the sizes must be smaller than | /// of the iteration space of the operation and the sizes must be smaller than | ||||
/// or equal to their counterpart iteration space sizes, if static. | /// or equal to their counterpart iteration space sizes, if static. | ||||
/// `inputVectorShapes` also allows the vectorization of operations with dynamic | /// `inputVectorShapes` also allows the vectorization of operations with dynamic | ||||
/// shapes. | /// shapes. | ||||
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp, | LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp, | ||||
ArrayRef<int64_t> inputVectorSizes, | ArrayRef<int64_t> inputVectorSizes, | ||||
Show All 20 Lines | LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp, | ||||
FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, linalgOp); | FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, linalgOp); | ||||
if (succeeded(convOr)) { | if (succeeded(convOr)) { | ||||
llvm::append_range(results, (*convOr)->getResults()); | llvm::append_range(results, (*convOr)->getResults()); | ||||
} else { | } else { | ||||
if (failed(vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes, | if (failed(vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes, | ||||
vectorizeNDExtract))) | vectorizeNDExtract))) | ||||
return failure(); | return failure(); | ||||
LDBG("Vectorize generic by broadcasting to the canonical vector shape\n"); | LDBG("Vectorize generic by broadcasting to the canonical vector shape\n"); | ||||
// Pre-process before proceeding. | |||||
Not Done ReplyInline Actionsnit: ii typo dcaballe: nit: `ii` typo | |||||
convertAffineApply(rewriter, linalgOp); | |||||
// TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted to | // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted to | ||||
// 'OpBuilder' when it is passed over to some methods like | // 'OpBuilder' when it is passed over to some methods like | ||||
// 'vectorizeAsLinalgGeneric'. This is highly problematic: if we erase an op | // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we erase an op | ||||
// within these methods, the actual rewriter won't be notified and we will | // within these methods, the actual rewriter won't be notified and we will | ||||
Not Done ReplyInline ActionsWhen we have to change the insertion point, we use OpBuilder::Guard using RAII. I'd suggest that you move this to a utility function and then you can do RAII using the scope of the function. You can look at other examples in MLIR. Just search for OpBuilder::Guard. This IR change could be part of a larger set of "linalgOp pre-processing" transformations that happens right before vectorization starts but after we know we can vectorize the op. dcaballe: When we have to change the insertion point, we use `OpBuilder::Guard` using RAII. I'd suggest… | |||||
I like this idea :) Just to double-check - that set is yet to be created, right? awarzynski: >This IR change could be part of a larger set of "linalgOp pre-processing" transformations that… | |||||
// end up with read-after-free issues! | // end up with read-after-free issues! | ||||
if (failed(vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results))) | if (failed(vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results))) | ||||
return failure(); | return failure(); | ||||
} | } | ||||
if (!results.empty()) | if (!results.empty()) | ||||
rewriter.replaceOp(linalgOp, results); | rewriter.replaceOp(linalgOp, results); | ||||
else | else | ||||
rewriter.eraseOp(linalgOp); | rewriter.eraseOp(linalgOp); | ||||
return success(); | return success(); | ||||
} | } | ||||
Not Done ReplyInline ActionsCouldn't we just do rewriter.replaceOp(op, expanded) and avoid the manual U-D chain update? dcaballe: Couldn't we just do `rewriter.replaceOp(op, expanded)` and avoid the manual U-D chain update? | |||||
Perhaps I'm being daft, but things go horribly wrong when I do that. And I assume that that's because rewriter.replaceOp invalidates the iterators in the surrounding for loop. awarzynski: Perhaps I'm being daft, but things go horribly wrong when I do that. And I assume that that's… | |||||
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter, | LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter, | ||||
memref::CopyOp copyOp) { | memref::CopyOp copyOp) { | ||||
auto srcType = copyOp.getSource().getType().cast<MemRefType>(); | auto srcType = copyOp.getSource().getType().cast<MemRefType>(); | ||||
auto dstType = copyOp.getTarget().getType().cast<MemRefType>(); | auto dstType = copyOp.getTarget().getType().cast<MemRefType>(); | ||||
if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) | if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) | ||||
return failure(); | return failure(); | ||||
▲ Show 20 Lines • Show All 1,366 Lines • Show Last 20 Lines |
nit: // -> /// and . at the end per coding guidelines.