diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -9,6 +9,7 @@ // This file implements the linalg dialect Vectorization transformations. // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/Utils.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -1084,6 +1085,16 @@ vectorizeNDExtract))) return failure(); LDBG("Vectorize generic by broadcasting to the canonical vector shape\n"); + + for (Operation &op : linalgOp.getBlock()->getOperations()) { + if (auto affineApply = dyn_cast(op)) { + auto expanded = expandAffineExpr(rewriter, affineApply->getLoc(), + affineApply.getAffineMap().getResult(0), + affineApply.getOperands(), ValueRange{}); + affineApply.replaceAllUsesWith(expanded); + } + } + // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted to // 'OpBuilder' when it is passed over to some methods like // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we erase an op diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -164,7 +164,7 @@ return false; for (Operation &op : r.front()) { if (!(isa(op) || + linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) || OpTrait::hasElementwiseMappableTraits(&op)) || llvm::any_of(op.getResultTypes(), [](Type type) { return !type.isIntOrIndexOrFloat(); }))