diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -960,6 +960,10 @@ if (!isa<linalg::GenericOp>(op)) return failure(); + // TODO: Index vectorization assumes static shape. + if (op.hasIndexSemantics()) + return failure(); + // TODO: 0-d vectors are not supported yet. if (llvm::any_of(op.getIndexingMapsArray(), [](AffineMap map) { return map.isEmpty() || map.getResults().empty(); @@ -1052,15 +1056,15 @@ /// Converts affine.apply Ops to arithmetic operations. static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) { - auto &newIP = linalgOp.getBlock()->front(); OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPointAfter(&newIP); auto toReplace = linalgOp.getBlock()->getOps<AffineApplyOp>(); for (auto op : make_early_inc_range(toReplace)) { - auto expanded = - expandAffineExpr(rewriter, op->getLoc(), op.getAffineMap().getResult(0), - op.getOperands(), ValueRange{}); + rewriter.setInsertionPoint(op); + auto expanded = expandAffineExpr( + rewriter, op->getLoc(), op.getAffineMap().getResult(0), + op.getOperands().take_front(op.getAffineMap().getNumDims()), + op.getOperands().take_back(op.getAffineMap().getNumSymbols())); rewriter.replaceOp(op, expanded); } } diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -301,7 +301,8 @@ ^bb0(%arg1: f32, %arg2: i32): %2 = linalg.index 0 : index %12 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg3) - %3 = arith.index_cast %12 : index to i32 + %13 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%12)[%arg3] + %3 = arith.index_cast %13 : index to i32 linalg.yield %3 : i32 } -> tensor<32xi32> return %1 : tensor<32xi32> @@ -315,7 +316,9 @@ // CHECK: %[[EMPTY:.*]] = tensor.empty() : 
tensor<32xi32> // CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG1]] : index to vector<32xindex> // CHECK: %[[ADDI:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<32xindex> -// CHECK: %[[CAST:.*]] = arith.index_cast %[[ADDI]] : vector<32xindex> to vector<32xi32> +// CHECK: %[[BCAST2:.*]] = vector.broadcast %[[ARG1]] : index to vector<32xindex> +// CHECK: %[[ADDI2:.*]] = arith.addi %[[ADDI]], %[[BCAST2]] : vector<32xindex> +// CHECK: %[[CAST:.*]] = arith.index_cast %[[ADDI2]] : vector<32xindex> to vector<32xi32> // CHECK: vector.transfer_write %[[CAST]], %[[EMPTY]][%[[C0:.*]]] {in_bounds = [true]} : vector<32xi32>, tensor<32xi32> transform.sequence failures(propagate) {