Patch summary (free-form text ahead of the first `diff --git` header).

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp, contiguous-access step ("2. Handle contiguous access"):
  * The transfer read is now built with an explicit permutation map. The map starts as the minor identity map over min(dstRank, srcRank) results, where dstRank is the rank of the result type and srcRank the rank of the extracted-from tensor. When dstRank > srcRank, one constant-0 result is inserted at position 0 per missing dimension, e.g. (d0, d1) --> (0, d0, d1).
  * The in-bounds vector is sized from dstRank; every entry stays `true`.
  * The "Vectorised as contiguous load" LDBG message moves to the start of the step.
  * rankDiff is declared int64_t so the rank difference is not narrowed.

mlir/test/Dialect/Linalg/vectorization.mlir:
  * Adds @vectorize_nd_tensor_extract_with_tensor_extract: a rank-3 linalg.generic (tensor<1x1x4xf32>) that extracts from the rank-2 tensor<257x24xf32>, using a row index that itself comes from a tensor.extract on tensor<1x20xi32>. The CHECK lines expect a vector.gather for the index load, then a vector.transfer_read of vector<1x4xf32> that is broadcast to vector<1x1x4xf32>.
  * Closing braces of the transform.sequence blocks are left unindented.

NOTE(review): in the first hunk the context line `auto zero = rewriter.create(` is cut off at the hunk boundary and its op type is not visible in this patch; restore it from the base file before applying.

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -952,6 +952,7 @@ } // 2. Handle contiguous access. + LDBG("Vectorised as contiguous load: " << extractOp); SmallVector<Value> transferReadIdxs; auto resTrailingDim = resultType.getShape().back(); auto zero = rewriter.create( @@ -985,12 +986,29 @@ } // `tensor.extract_element` is always in-bounds, hence the following holds. - SmallVector<bool> inBounds(resultType.getRank(), true); + auto dstRank = resultType.getRank(); + SmallVector<bool> inBounds(dstRank, true); + + // Create a permutation map for a transfer read + auto srcRank = extractOp.getTensor().getType().getRank(); + auto permutationMap = AffineMap::getMinorIdentityMap( + srcRank, std::min(dstRank, srcRank), rewriter.getContext()); + + int64_t rankDiff = dstRank - srcRank; + // When dstRank > srcRank, extend the map with 0. 
For example, for dstRank = + // 3, srcRank = 2, the following map created above: + // (d0, d1) --> (d0, d1) + // is extended as: + // (d0, d1) --> (0, d0, d1) + while (rankDiff > 0) { + permutationMap = permutationMap.insertResult( + mlir::getAffineConstantExpr(0, rewriter.getContext()), 0); + rankDiff--; + } auto transferReadOp = rewriter.create<vector::TransferReadOp>( - loc, resultType, extractOp.getTensor(), transferReadIdxs, inBounds); - - LDBG("Vectorised as contiguous load: " << extractOp); + loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap, + inBounds); return VectorizationResult{VectorizationStatus::NewOp, transferReadOp}; } diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -1834,6 +1834,71 @@ // ----- +func.func @vectorize_nd_tensor_extract_with_tensor_extract(%input_1: tensor<1x20xi32>, %input_2: tensor<257x24xf32>, %arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) -> tensor<1x1x4xf32> { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %output = tensor.empty() : tensor<1x1x4xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%output : tensor<1x1x4xf32>) { + ^bb0(%out: f32): + %13 = linalg.index 0 : index + %14 = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>(%arg0, %13, %arg2) + %15 = linalg.index 2 : index + %16 = linalg.index 1 : index + %17 = affine.apply affine_map<(d0, d1, d2, d3) -> (d0 + d1 * 24 + d2 + d3)>(%arg1, %16, %15, %arg3) + %extracted_0 = tensor.extract %input_1[%c0, %14] : tensor<1x20xi32> + %18 = arith.index_cast %extracted_0 : i32 to index + %19 = arith.maxsi %18, %c0 : index + %20 = arith.minsi %19, %c256 : index + %extracted_1 = tensor.extract %input_2[%20, %17] : tensor<257x24xf32> + linalg.yield %extracted_1 : f32 + } -> tensor<1x1x4xf32> + 
return %1 : tensor<1x1x4xf32> +} + +// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_tensor_extract( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x20xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<257x24xf32>, +// CHECK-SAME: -> tensor<1x1x4xf32> { +// CHECK: %[[VAL_6:.*]] = arith.constant dense<0> : vector<1x1x4xindex> +// CHECK: %[[VAL_7:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> +// CHECK: %[[VAL_8:.*]] = arith.constant dense<true> : vector<1x1x4xi1> +// CHECK: %[[VAL_9:.*]] = arith.constant dense<0> : vector<1x1x4xi32> +// CHECK: %[[VAL_10:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_11:.*]] = arith.constant dense<256> : vector<1x1x4xindex> +// CHECK: %[[VAL_12:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_13:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_14:.*]] = tensor.empty() : tensor<1x1x4xf32> +// CHECK: %[[VAL_15:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex> +// CHECK: %[[VAL_16:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex> +// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : vector<1x1x4xindex> +// CHECK: %[[VAL_18:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex> +// CHECK: %[[VAL_19:.*]] = vector.broadcast %[[VAL_7]] : vector<4xindex> to vector<1x1x4xindex> +// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_18]], %[[VAL_19]] : vector<1x1x4xindex> +// CHECK: %[[VAL_21:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex> +// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_21]] : vector<1x1x4xindex> +// CHECK: %[[VAL_23:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_10]], %[[VAL_10]]] {{\[}}%[[VAL_17]]], %[[VAL_8]], %[[VAL_9]] : tensor<1x20xi32>, vector<1x1x4xindex>, vector<1x1x4xi1>, vector<1x1x4xi32> into vector<1x1x4xi32> +// CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : vector<1x1x4xi32> to vector<1x1x4xindex> +// CHECK: %[[VAL_25:.*]] = arith.maxsi %[[VAL_24]], %[[VAL_6]] : vector<1x1x4xindex> +// CHECK: %[[VAL_26:.*]] = arith.minsi 
%[[VAL_25]], %[[VAL_11]] : vector<1x1x4xindex> +// CHECK: %[[VAL_27:.*]] = vector.shape_cast %[[VAL_26]] : vector<1x1x4xindex> to vector<4xindex> +// CHECK: %[[VAL_28:.*]] = vector.extractelement %[[VAL_27]]{{\[}}%[[VAL_12]] : i32] : vector<4xindex> +// CHECK: %[[VAL_29:.*]] = vector.shape_cast %[[VAL_22]] : vector<1x1x4xindex> to vector<4xindex> +// CHECK: %[[VAL_30:.*]] = vector.extractelement %[[VAL_29]]{{\[}}%[[VAL_12]] : i32] : vector<4xindex> +// CHECK: %[[VAL_31:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_28]], %[[VAL_30]]], %[[VAL_13]] {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32> +// CHECK: %[[VAL_32:.*]] = vector.broadcast %[[VAL_31]] : vector<1x4xf32> to vector<1x1x4xf32> +// CHECK: %[[VAL_33:.*]] = vector.transfer_write %[[VAL_32]], %[[VAL_14]]{{\[}}%[[VAL_10]], %[[VAL_10]], %[[VAL_10]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, tensor<1x1x4xf32> +// CHECK: return %[[VAL_33]] : tensor<1x1x4xf32> +// CHECK: } + +transform.sequence failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation + %2 = transform.structured.vectorize %1 { vectorize_nd_extract } +} + +// ----- + func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<80x16xf32>, %arg0: index, %extracted_slice : tensor<1x3xf32>) -> tensor<1x3xf32> { %c79 = arith.constant 79 : index %1 = linalg.generic { @@ -1918,7 +1983,7 @@ ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation transform.structured.masked_vectorize %0 vector_sizes [1, 4] { vectorize_nd_extract } - } +} // -----