diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1352,9 +1352,9 @@ }; /// Progressive lowering of ExtractStridedSliceOp to either: -/// 1. extractelement + insertelement for the 1-D case -/// 2. extract + optional strided_slice + insert for the n-D case. -class VectorStridedSliceOpConversion +/// 1. express single offset extract as direct shuffle. +/// 2. extract + lower rank strided_slice + insert for the n-D case. +class VectorExtractStridedSliceOpConversion : public OpRewritePattern<ExtractStridedSliceOp> { public: using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; @@ -1374,21 +1374,34 @@ auto loc = op.getLoc(); auto elemType = dstType.getElementType(); assert(elemType.isSignlessIntOrIndexOrFloat()); + + // Single offset can be more efficiently shuffled. + if (op.offsets().getValue().size() == 1) { + SmallVector<int64_t, 4> offsets; + offsets.reserve(size); + for (int64_t off = offset, e = offset + size * stride; off < e; + off += stride) + offsets.push_back(off); + rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(), + op.vector(), + rewriter.getI64ArrayAttr(offsets)); + return success(); + } + + // Extract/insert on a lower ranked extract strided slice op. 
Value zero = rewriter.create<ConstantOp>(loc, elemType, rewriter.getZeroAttr(elemType)); Value res = rewriter.create<SplatOp>(loc, dstType, zero); for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; off += stride, ++idx) { - Value extracted = extractOne(rewriter, loc, op.vector(), off); - if (op.offsets().getValue().size() > 1) { - extracted = rewriter.create<ExtractStridedSliceOp>( - loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1), - getI64SubArray(op.sizes(), /* dropFront=*/1), - getI64SubArray(op.strides(), /* dropFront=*/1)); - } + Value one = extractOne(rewriter, loc, op.vector(), off); + Value extracted = rewriter.create<ExtractStridedSliceOp>( + loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1), + getI64SubArray(op.sizes(), /* dropFront=*/1), + getI64SubArray(op.strides(), /* dropFront=*/1)); res = insertOne(rewriter, loc, extracted, res, idx); } - rewriter.replaceOp(op, {res}); + rewriter.replaceOp(op, res); return success(); } /// This pattern creates recursive ExtractStridedSliceOp, but the recursion is @@ -1407,7 +1420,7 @@ patterns.insert<VectorFMAOpNDRewritePattern, VectorInsertStridedSliceOpDifferentRankRewritePattern, VectorInsertStridedSliceOpSameRankRewritePattern, - VectorStridedSliceOpConversion>(ctx); + VectorExtractStridedSliceOpConversion>(ctx); patterns.insert<VectorReductionOpConversion>( ctx, converter, reassociateFPReductions); patterns diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -512,65 +512,38 @@ %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32> return %0 : vector<2xf32> } -// CHECK-LABEL: llvm.func @extract_strided_slice1 -// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float -// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm.vec<2 x float> -// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK: llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i64] : !llvm.vec<4 x float> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement 
%{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm.vec<2 x float> -// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64 -// CHECK: llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i64] : !llvm.vec<4 x float> -// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm.vec<2 x float> +// CHECK-LABEL: llvm.func @extract_strided_slice1( +// CHECK-SAME: %[[A:.*]]: !llvm.vec<4 x float>) +// CHECK: %[[T0:.*]] = llvm.shufflevector %[[A]], %[[A]] [2, 3] : !llvm.vec<4 x float>, !llvm.vec<4 x float> +// CHECK: llvm.return %[[T0]] : !llvm.vec<2 x float> func @extract_strided_slice2(%arg0: vector<4x8xf32>) -> vector<2x8xf32> { %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8xf32> to vector<2x8xf32> return %0 : vector<2x8xf32> } -// CHECK-LABEL: llvm.func @extract_strided_slice2 -// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float -// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x8xf32>) : !llvm.array<2 x vec<8 x float>> -// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.array<4 x vec<8 x float>> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.array<2 x vec<8 x float>> -// CHECK: llvm.extractvalue %{{.*}}[3] : !llvm.array<4 x vec<8 x float>> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.array<2 x vec<8 x float>> +// CHECK-LABEL: llvm.func @extract_strided_slice2( +// CHECK-SAME: %[[A:.*]]: !llvm.array<4 x vec<8 x float>>) +// CHECK: %[[T0:.*]] = llvm.mlir.undef : !llvm.array<2 x vec<8 x float>> +// CHECK: %[[T1:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vec<8 x float>> +// CHECK: %[[T2:.*]] = llvm.insertvalue %[[T1]], %[[T0]][0] : !llvm.array<2 x vec<8 x float>> +// CHECK: %[[T3:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vec<8 x float>> +// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T2]][1] : !llvm.array<2 x vec<8 x float>> +// CHECK: llvm.return %[[T4]] : !llvm.array<2 x vec<8 x float>> func 
@extract_strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> { %0 = vector.extract_strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32> return %0 : vector<2x2xf32> } -// CHECK-LABEL: llvm.func @extract_strided_slice3 -// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float -// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x2xf32>) : !llvm.array<2 x vec<2 x float>> -// -// Subvector vector<8xf32> @2 -// CHECK: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x vec<8 x float>> -// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float -// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm.vec<2 x float> -// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<8 x float> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<2 x float> -// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<8 x float> -// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<2 x float> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm.array<2 x vec<2 x float>> -// -// Subvector vector<8xf32> @3 -// CHECK: llvm.extractvalue {{.*}}[3] : !llvm.array<4 x vec<8 x float>> -// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float -// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm.vec<2 x float> -// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<8 x float> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<2 x float> -// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : 
!llvm.vec<8 x float> -// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm.vec<2 x float> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm.array<2 x vec<2 x float>> +// CHECK-LABEL: llvm.func @extract_strided_slice3( +// CHECK-SAME: %[[A:.*]]: !llvm.array<4 x vec<8 x float>>) +// CHECK: %[[T1:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<2x2xf32>) : !llvm.array<2 x vec<2 x float>> +// CHECK: %[[T2:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vec<8 x float>> +// CHECK: %[[T3:.*]] = llvm.shufflevector %[[T2]], %[[T2]] [2, 3] : !llvm.vec<8 x float>, !llvm.vec<8 x float> +// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T1]][0] : !llvm.array<2 x vec<2 x float>> +// CHECK: %[[T5:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vec<8 x float>> +// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T5]], %[[T5]] [2, 3] : !llvm.vec<8 x float>, !llvm.vec<8 x float> +// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T4]][1] : !llvm.array<2 x vec<2 x float>> +// CHECK: llvm.return %[[T7]] : !llvm.array<2 x vec<2 x float>> func @insert_strided_slice1(%b: vector<4x4xf32>, %c: vector<4x4x4xf32>) -> vector<4x4x4xf32> { %0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xf32> into vector<4x4x4xf32> @@ -674,15 +647,11 @@ } // CHECK-LABEL: llvm.func @extract_strides( // CHECK-SAME: %[[A:.*]]: !llvm.array<3 x vec<3 x float>>) -// CHECK: %[[s0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1x1xf32>) : !llvm.array<1 x vec<1 x float>> -// CHECK: %[[s1:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<3 x vec<3 x float>> -// CHECK: %[[s3:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1xf32>) : !llvm.vec<1 x float> -// CHECK: %[[s4:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK: %[[s5:.*]] = llvm.extractelement %[[s1]][%[[s4]] : !llvm.i64] : !llvm.vec<3 x float> -// CHECK: %[[s6:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 
-// CHECK: %[[s7:.*]] = llvm.insertelement %[[s5]], %[[s3]][%[[s6]] : !llvm.i64] : !llvm.vec<1 x float> -// CHECK: %[[s8:.*]] = llvm.insertvalue %[[s7]], %[[s0]][0] : !llvm.array<1 x vec<1 x float>> -// CHECK: llvm.return %[[s8]] : !llvm.array<1 x vec<1 x float>> +// CHECK: %[[T1:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1x1xf32>) : !llvm.array<1 x vec<1 x float>> +// CHECK: %[[T2:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<3 x vec<3 x float>> +// CHECK: %[[T3:.*]] = llvm.shufflevector %[[T2]], %[[T2]] [2] : !llvm.vec<3 x float>, !llvm.vec<3 x float> +// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T1]][0] : !llvm.array<1 x vec<1 x float>> +// CHECK: llvm.return %[[T4]] : !llvm.array<1 x vec<1 x float>> // CHECK-LABEL: llvm.func @vector_fma( // CHECK-SAME: %[[A:.*]]: !llvm.vec<8 x float>, %[[B:.*]]: !llvm.array<2 x vec<4 x float>>)