diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -9,6 +9,7 @@
 // This file implements the linalg dialect Vectorization transformations.
 //
 //===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/Utils.h"
 
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -1084,6 +1085,27 @@
                                            vectorizeNDExtract)))
     return failure();
   LDBG("Vectorize generic by broadcasting to the canonical vector shape\n");
+
+  // Convert affine.apply to arithmetic operations before trying to
+  // vectorize.
+  SmallVector<Value> affineApplyOpsToDelete;
+  auto oldIP = rewriter.saveInsertionPoint();
+  auto &newIP = linalgOp.getBlock()->front();
+  rewriter.setInsertionPointAfter(&newIP);
+  auto toReplace = linalgOp.getBlock()->getOps<AffineApplyOp>();
+
+  for (auto op : toReplace) {
+    auto expanded = expandAffineExpr(rewriter, op->getLoc(),
+                                     op.getAffineMap().getResult(0),
+                                     op.getOperands(), ValueRange{});
+    op.replaceAllUsesWith(expanded);
+    affineApplyOpsToDelete.push_back(op);
+  }
+  for (auto op : affineApplyOpsToDelete) {
+    rewriter.eraseOp(op.getDefiningOp());
+  }
+  rewriter.restoreInsertionPoint(oldIP);
+
   // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted to
   // 'OpBuilder' when it is passed over to some methods like
   // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we erase an op
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -164,7 +164,7 @@
     return false;
   for (Operation &op : r.front()) {
     if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
-              linalg::YieldOp, linalg::IndexOp>(op) ||
+              linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
          OpTrait::hasElementwiseMappableTraits(&op)) ||
         llvm::any_of(op.getResultTypes(),
                      [](Type type) { return !type.isIntOrIndexOrFloat(); }))
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -2004,3 +2004,39 @@
 // CHECK-LABEL: @wrong_reduction_detection
 // CHECK: vector.broadcast
 // CHECK: vector.transfer_write
+
+// -----
+
+// Regression test: %12 was considered as not vectorizable despite there being
+// a simple arithmetic representation that can be used instead.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+func.func @affine_apply(%arg0: tensor<128x12x32xf32>, %arg3: index) -> tensor<128x12x32xf32> {
+  %0 = tensor.empty() : tensor<128x12x32xf32>
+  %1 = linalg.generic {indexing_maps = [#map0, #map0],
+                       iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%arg0 : tensor<128x12x32xf32>)
+    outs(%0 : tensor<128x12x32xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    %2 = linalg.index 2 : index
+    %12 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg3)
+    %3 = arith.index_cast %12 : index to i32
+    %4 = arith.uitofp %3 : i32 to f32
+    %5 = arith.mulf %4, %arg1 : f32
+    linalg.yield %5 : f32
+  } -> tensor<128x12x32xf32>
+  return %1 : tensor<128x12x32xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+}
+
+// CHECK-LABEL: @affine_apply
+// CHECK: vector.transfer_read
+// CHECK: vector.broadcast
+// CHECK: vector.transfer_write