diff --git a/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp --- a/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -121,9 +121,10 @@ if (op.offsets().getValue().empty()) return failure(); - int64_t rankDiff = dstType.getRank() - srcType.getRank(); - assert(rankDiff >= 0); - if (rankDiff != 0) + int64_t srcRank = srcType.getRank(); + int64_t dstRank = dstType.getRank(); + assert(dstRank >= srcRank); + if (dstRank != srcRank) return failure(); if (srcType == dstType) { @@ -139,6 +140,34 @@ auto loc = op.getLoc(); Value res = op.dest(); + + if (srcRank == 1) { + int nSrc = srcType.getShape().front(); + int nDest = dstType.getShape().front(); + // 1. Scale source to destType so we can shufflevector them together. + SmallVector offsets(nDest, 0); + for (int64_t i = 0; i < nSrc; ++i) + offsets[i] = i; + Value scaledSource = + rewriter.create(loc, op.source(), op.source(), offsets); + + // 2. Create a mask where we take the value from scaledSource or dest + // depending on the offset. + offsets.clear(); + for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) { + if (i < offset || i >= e || (i - offset) % stride != 0) + offsets.push_back(nDest + i); + else + offsets.push_back((i - offset) / stride); + } + + // 3. Replace with a ShuffleOp. + rewriter.replaceOpWithNewOp(op, scaledSource, op.dest(), + offsets); + + return success(); + } + // For each slice of the source vector along the most major dimension. 
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; off += stride, ++idx) { diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1451,7 +1451,10 @@ Value v2, ArrayRef mask) { result.addOperands({v1, v2}); auto maskAttr = getVectorSubscriptAttr(builder, mask); - result.addTypes(v1.getType()); + auto v1Type = v1.getType().cast(); + auto shape = llvm::to_vector<4>(v1Type.getShape()); + shape[0] = mask.size(); + result.addTypes(VectorType::get(shape, v1Type.getElementType())); result.addAttribute(getMaskAttrName(), maskAttr); } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -900,46 +900,24 @@ %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> return %0 : vector<4x4xf32> } + // CHECK-LABEL: @insert_strided_slice2 // // Subvector vector<2xf32> @0 into vector<4xf32> @2 -// CHECK: unrealized_conversion_cast %{{.*}} : vector<4x4xf32> to !llvm.array<4 x vector<4xf32>> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm.array<2 x vector<2xf32>> -// CHECK-NEXT: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x vector<4xf32>> +// CHECK: %[[V2_0:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.array<2 x vector<2xf32>> +// CHECK: %[[V4_0:.*]] = llvm.extractvalue {{.*}}[2] : !llvm.array<4 x vector<4xf32>> // Element @0 -> element @2 -// CHECK-NEXT: arith.constant 0 : index -// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : index to i64 -// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : i64] : vector<2xf32> -// CHECK-NEXT: arith.constant 2 : index -// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : index to i64 -// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i64] : 
vector<4xf32> -// Element @1 -> element @3 -// CHECK-NEXT: arith.constant 1 : index -// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : index to i64 -// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : i64] : vector<2xf32> -// CHECK-NEXT: arith.constant 3 : index -// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : index to i64 -// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i64] : vector<4xf32> -// CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x vector<4xf32>> +// CHECK: %[[R4_0:.*]] = llvm.shufflevector %[[V2_0]], %[[V2_0]] [0, 1, 0, 0] : vector<2xf32>, vector<2xf32> +// CHECK: %[[R4_1:.*]] = llvm.shufflevector %[[R4_0]], %[[V4_0]] [4, 5, 0, 1] : vector<4xf32>, vector<4xf32> +// CHECK: llvm.insertvalue %[[R4_1]], {{.*}}[2] : !llvm.array<4 x vector<4xf32>> // // Subvector vector<2xf32> @1 into vector<4xf32> @3 -// CHECK: llvm.extractvalue {{.*}}[1] : !llvm.array<2 x vector<2xf32>> -// CHECK-NEXT: llvm.extractvalue {{.*}}[3] : !llvm.array<4 x vector<4xf32>> +// CHECK: %[[V2_1:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.array<2 x vector<2xf32>> +// CHECK: %[[V4_3:.*]] = llvm.extractvalue {{.*}}[3] : !llvm.array<4 x vector<4xf32>> // Element @0 -> element @2 -// CHECK-NEXT: arith.constant 0 : index -// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : index to i64 -// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : i64] : vector<2xf32> -// CHECK-NEXT: arith.constant 2 : index -// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : index to i64 -// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i64] : vector<4xf32> -// Element @1 -> element @3 -// CHECK-NEXT: arith.constant 1 : index -// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : index to i64 -// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : i64] : vector<2xf32> -// CHECK-NEXT: arith.constant 3 : index -// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : index to i64 -// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i64] : vector<4xf32> -// CHECK-NEXT: llvm.insertvalue 
{{.*}}, {{.*}}[3] : !llvm.array<4 x vector<4xf32>> +// CHECK: %[[R4_2:.*]] = llvm.shufflevector %[[V2_1]], %[[V2_1]] [0, 1, 0, 0] : vector<2xf32>, vector<2xf32> +// CHECK: %[[R4_3:.*]] = llvm.shufflevector %[[R4_2]], %[[V4_3]] [4, 5, 0, 1] : vector<4xf32>, vector<4xf32> +// CHECK: llvm.insertvalue %[[R4_3]], {{.*}}[3] : !llvm.array<4 x vector<4xf32>> // ----- @@ -948,69 +926,18 @@ vector<2x4xf32> into vector<16x4x8xf32> return %0 : vector<16x4x8xf32> } -// CHECK-LABEL: @insert_strided_slice3( -// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>, -// CHECK-SAME: %[[B:.*]]: vector<16x4x8xf32>) -// CHECK-DAG: %[[s2:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<16x4x8xf32> to !llvm.array<16 x array<4 x vector<8xf32>>> -// CHECK-DAG: %[[s4:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>> -// CHECK: %[[s3:.*]] = llvm.extractvalue %[[s2]][0] : !llvm.array<16 x array<4 x vector<8xf32>>> -// CHECK: %[[s5:.*]] = llvm.extractvalue %[[s4]][0] : !llvm.array<2 x vector<4xf32>> -// CHECK: %[[s7:.*]] = llvm.extractvalue %[[s2]][0, 0] : !llvm.array<16 x array<4 x vector<8xf32>>> -// CHECK: %[[s8:.*]] = arith.constant 0 : index -// CHECK: %[[s9:.*]] = builtin.unrealized_conversion_cast %[[s8]] : index to i64 -// CHECK: %[[s10:.*]] = llvm.extractelement %[[s5]]{{\[}}%[[s9]] : i64] : vector<4xf32> -// CHECK: %[[s11:.*]] = arith.constant 2 : index -// CHECK: %[[s12:.*]] = builtin.unrealized_conversion_cast %[[s11]] : index to i64 -// CHECK: %[[s13:.*]] = llvm.insertelement %[[s10]], %[[s7]]{{\[}}%[[s12]] : i64] : vector<8xf32> -// CHECK: %[[s14:.*]] = arith.constant 1 : index -// CHECK: %[[s15:.*]] = builtin.unrealized_conversion_cast %[[s14]] : index to i64 -// CHECK: %[[s16:.*]] = llvm.extractelement %[[s5]]{{\[}}%[[s15]] : i64] : vector<4xf32> -// CHECK: %[[s17:.*]] = arith.constant 3 : index -// CHECK: %[[s18:.*]] = builtin.unrealized_conversion_cast %[[s17]] : index to i64 -// CHECK: %[[s19:.*]] = llvm.insertelement %[[s16]], 
%[[s13]]{{\[}}%[[s18]] : i64] : vector<8xf32> -// CHECK: %[[s20:.*]] = arith.constant 2 : index -// CHECK: %[[s21:.*]] = builtin.unrealized_conversion_cast %[[s20]] : index to i64 -// CHECK: %[[s22:.*]] = llvm.extractelement %[[s5]]{{\[}}%[[s21]] : i64] : vector<4xf32> -// CHECK: %[[s23:.*]] = arith.constant 4 : index -// CHECK: %[[s24:.*]] = builtin.unrealized_conversion_cast %[[s23]] : index to i64 -// CHECK: %[[s25:.*]] = llvm.insertelement %[[s22]], %[[s19]]{{\[}}%[[s24]] : i64] : vector<8xf32> -// CHECK: %[[s26:.*]] = arith.constant 3 : index -// CHECK: %[[s27:.*]] = builtin.unrealized_conversion_cast %[[s26]] : index to i64 -// CHECK: %[[s28:.*]] = llvm.extractelement %[[s5]]{{\[}}%[[s27]] : i64] : vector<4xf32> -// CHECK: %[[s29:.*]] = arith.constant 5 : index -// CHECK: %[[s30:.*]] = builtin.unrealized_conversion_cast %[[s29]] : index to i64 -// CHECK: %[[s31:.*]] = llvm.insertelement %[[s28]], %[[s25]]{{\[}}%[[s30]] : i64] : vector<8xf32> -// CHECK: %[[s32:.*]] = llvm.insertvalue %[[s31]], %[[s3]][0] : !llvm.array<4 x vector<8xf32>> -// CHECK: %[[s34:.*]] = llvm.extractvalue %[[s4]][1] : !llvm.array<2 x vector<4xf32>> -// CHECK: %[[s36:.*]] = llvm.extractvalue %[[s2]][0, 1] : !llvm.array<16 x array<4 x vector<8xf32>>> -// CHECK: %[[s37:.*]] = arith.constant 0 : index -// CHECK: %[[s38:.*]] = builtin.unrealized_conversion_cast %[[s37]] : index to i64 -// CHECK: %[[s39:.*]] = llvm.extractelement %[[s34]]{{\[}}%[[s38]] : i64] : vector<4xf32> -// CHECK: %[[s40:.*]] = arith.constant 2 : index -// CHECK: %[[s41:.*]] = builtin.unrealized_conversion_cast %[[s40]] : index to i64 -// CHECK: %[[s42:.*]] = llvm.insertelement %[[s39]], %[[s36]]{{\[}}%[[s41]] : i64] : vector<8xf32> -// CHECK: %[[s43:.*]] = arith.constant 1 : index -// CHECK: %[[s44:.*]] = builtin.unrealized_conversion_cast %[[s43]] : index to i64 -// CHECK: %[[s45:.*]] = llvm.extractelement %[[s34]]{{\[}}%[[s44]] : i64] : vector<4xf32> -// CHECK: %[[s46:.*]] = arith.constant 3 : index -// CHECK: 
%[[s47:.*]] = builtin.unrealized_conversion_cast %[[s46]] : index to i64 -// CHECK: %[[s48:.*]] = llvm.insertelement %[[s45]], %[[s42]]{{\[}}%[[s47]] : i64] : vector<8xf32> -// CHECK: %[[s49:.*]] = arith.constant 2 : index -// CHECK: %[[s50:.*]] = builtin.unrealized_conversion_cast %[[s49]] : index to i64 -// CHECK: %[[s51:.*]] = llvm.extractelement %[[s34]]{{\[}}%[[s50]] : i64] : vector<4xf32> -// CHECK: %[[s52:.*]] = arith.constant 4 : index -// CHECK: %[[s53:.*]] = builtin.unrealized_conversion_cast %[[s52]] : index to i64 -// CHECK: %[[s54:.*]] = llvm.insertelement %[[s51]], %[[s48]]{{\[}}%[[s53]] : i64] : vector<8xf32> -// CHECK: %[[s55:.*]] = arith.constant 3 : index -// CHECK: %[[s56:.*]] = builtin.unrealized_conversion_cast %[[s55]] : index to i64 -// CHECK: %[[s57:.*]] = llvm.extractelement %[[s34]]{{\[}}%[[s56]] : i64] : vector<4xf32> -// CHECK: %[[s58:.*]] = arith.constant 5 : index -// CHECK: %[[s59:.*]] = builtin.unrealized_conversion_cast %[[s58]] : index to i64 -// CHECK: %[[s60:.*]] = llvm.insertelement %[[s57]], %[[s54]]{{\[}}%[[s59]] : i64] : vector<8xf32> -// CHECK: %[[s61:.*]] = llvm.insertvalue %[[s60]], %[[s32]][1] : !llvm.array<4 x vector<8xf32>> -// CHECK: %[[s63:.*]] = llvm.insertvalue %[[s61]], %[[s2]][0] : !llvm.array<16 x array<4 x vector<8xf32>>> -// CHECK: %[[s64:.*]] = builtin.unrealized_conversion_cast %[[s63]] : !llvm.array<16 x array<4 x vector<8xf32>>> to vector<16x4x8xf32> -// CHECK: return %[[s64]] : vector<16x4x8xf32> +// CHECK-LABEL: func @insert_strided_slice3 +// CHECK: %[[V4_0:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.array<2 x vector<4xf32>> +// CHECK: %[[V4_0_0:.*]] = llvm.extractvalue {{.*}}[0, 0] : !llvm.array<16 x array<4 x vector<8xf32>>> +// CHECK: %[[R8_0:.*]] = llvm.shufflevector %[[V4_0]], %[[V4_0]] [0, 1, 2, 3, 0, 0, 0, 0] : vector<4xf32>, vector<4xf32> +// CHECK: %[[R8_1:.*]] = llvm.shufflevector %[[R8_0]], %[[V4_0_0]] [8, 9, 0, 1, 2, 3, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: llvm.insertvalue 
%[[R8_1]], {{.*}}[0] : !llvm.array<4 x vector<8xf32>> + +// CHECK: %[[V4_1:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.array<2 x vector<4xf32>> +// CHECK: %[[V4_0_1:.*]] = llvm.extractvalue {{.*}}[0, 1] : !llvm.array<16 x array<4 x vector<8xf32>>> +// CHECK: %[[R8_2:.*]] = llvm.shufflevector %[[V4_1]], %[[V4_1]] [0, 1, 2, 3, 0, 0, 0, 0] : vector<4xf32>, vector<4xf32> +// CHECK: %[[R8_3:.*]] = llvm.shufflevector %[[R8_2]], %[[V4_0_1]] [8, 9, 0, 1, 2, 3, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: llvm.insertvalue %[[R8_3]], {{.*}}[1] : !llvm.array<4 x vector<8xf32>> // -----