diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -9,6 +9,7 @@
 // This file implements the linalg dialect Vectorization transformations.
 //
 //===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/Utils.h"
 
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -1057,6 +1058,21 @@
   return success();
 }
 
+/// Converts affine.apply Ops to arithmetic operations.
+static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
+  OpBuilder::InsertionGuard g(rewriter);
+  auto toReplace = linalgOp.getBlock()->getOps<AffineApplyOp>();
+
+  for (auto op : make_early_inc_range(toReplace)) {
+    rewriter.setInsertionPoint(op);
+    auto expanded = expandAffineExpr(
+        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
+        op.getOperands().take_front(op.getAffineMap().getNumDims()),
+        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
+    rewriter.replaceOp(op, expanded);
+  }
+}
+
 /// Emit a suitable vector form for a Linalg op. If provided, `inputVectorSizes`
 /// are used to vectorize this operation. `inputVectorSizes` must match the rank
 /// of the iteration space of the operation and the input vector sizes must be
@@ -1093,6 +1109,10 @@
                                            vectorizeNDExtract)))
     return failure();
   LDBG("Vectorize generic by broadcasting to the canonical vector shape\n");
+
+  // Pre-process before proceeding.
+  convertAffineApply(rewriter, linalgOp);
+
   // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted to
   // 'OpBuilder' when it is passed over to some methods like
   // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we erase an op
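convertAffineApply runs as a pre-processing step on the linalg.generic body: expandAffineExpr rewrites each affine.apply result into plain arith operations, so the existing elementwise vectorization path can handle them. As a rough, illustrative sketch of that rewrite (the value names here are made up, not taken from the patch), a single map with one dim and one symbol such as

  %r = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%i)[%j]

would be expanded to roughly

  %r = arith.addi %i, %j : index

which then broadcasts and vectorizes like any other elementwise arith op, as the CHECK lines in the new test below show.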
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -164,7 +164,7 @@
     return false;
   for (Operation &op : r.front()) {
     if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
-              linalg::YieldOp, linalg::IndexOp>(op) ||
+              linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
           OpTrait::hasElementwiseMappableTraits(&op)) ||
         llvm::any_of(op.getResultTypes(),
                      [](Type type) { return !type.isIntOrIndexOrFloat(); }))
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -290,6 +290,51 @@
 // -----
 
+#map0 = affine_map<(d0) -> (d0)>
+
+func.func @vectorize_affine_apply(%arg0: tensor<5xf32>, %arg3: index) -> tensor<5xi32> {
+  %0 = tensor.empty() : tensor<5xi32>
+  %1 = linalg.generic {indexing_maps = [#map0, #map0],
+                       iterator_types = ["parallel"]}
+    ins(%arg0 : tensor<5xf32>)
+    outs(%0 : tensor<5xi32>) {
+  ^bb0(%arg1: f32, %arg2: i32):
+    %2 = linalg.index 0 : index
+    %11 = affine.apply affine_map<() -> (123)>()
+    %12 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %11)
+    %13 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%12)[%arg3]
+    %14 = affine.apply affine_map<(d0) -> (d0 + 1)>(%13)
+    %15 = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>(%13, %14, %12)
+    %3 = arith.index_cast %15 : index to i32
+    linalg.yield %3 : i32
+  } -> tensor<5xi32>
+  return %1 : tensor<5xi32>
+}
+
+// CHECK-LABEL: func.func @vectorize_affine_apply
+// CHECK-SAME:  %arg0: tensor<5xf32>
+// CHECK-SAME:  %[[ARG1:.*]]: index
+// CHECK: %[[CST:.*]] = arith.constant dense<[123, 124, 125, 126, 127]> : vector<5xindex>
+// CHECK: %[[CST_0:.*]] = arith.constant dense<1> : vector<5xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<5xi32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG1]] : index to vector<5xindex>
+// CHECK: %[[ADDI_1:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<5xindex>
+// CHECK: %[[ADDI_2:.*]] = arith.addi %[[ADDI_1]], %[[CST_0]] : vector<5xindex>
+// CHECK: %[[ADDI_3:.*]] = arith.addi %[[ADDI_1]], %[[ADDI_2]] : vector<5xindex>
+// CHECK: %[[ADDI_4:.*]] = arith.addi %[[ADDI_3]], %[[CST]] : vector<5xindex>
+// CHECK: %[[CAST:.*]] = arith.index_cast %[[ADDI_4]] : vector<5xindex> to vector<5xi32>
+// CHECK: vector.transfer_write %[[CAST]], %[[EMPTY]][%[[C0:.*]]] {in_bounds = [true]} : vector<5xi32>, tensor<5xi32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+}
+
+// -----
+
 // CHECK-LABEL: func @test_vectorize_fill
 func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
   // CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
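The new test case is driven by the transform dialect interpreter. Assuming vectorization.mlir keeps its existing RUN line (the exact flags are not shown in this patch), it can be exercised locally with an invocation along these lines:

  // RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s

or simply by pointing llvm-lit at mlir/test/Dialect/Linalg/vectorization.mlir.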