diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -972,9 +972,17 @@ } // namespace void LowerVectorToLLVMPass::runOnModule() { - // Convert to the LLVM IR dialect using the converter defined above. - OwningRewritePatternList patterns; + // Phase 1: progressively lower operations on "slices" (e.g. vector.extract_slices). + // The greedy driver's folding and DCE get rid of all non-leaking tuple ops. + { + OwningRewritePatternList patterns; + populateVectorSlicesLoweringPatterns(patterns, &getContext()); + applyPatternsGreedily(getModule(), patterns); + } + + // Phase 2: convert the remaining vector and std ops to the LLVM IR dialect. LLVMTypeConverter converter(&getContext()); + OwningRewritePatternList patterns; populateVectorToLLVMConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -424,10 +424,11 @@ // CHECK: llvm.call @print_close() : () -> () // CHECK: llvm.call @print_newline() : () -> () - -func @strided_slice(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>, %arg2: vector<4x8x16xf32>) { -// CHECK-LABEL: llvm.func @strided_slice( +func @strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> { %0 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32> + return %0 : vector<2xf32> +} +// CHECK-LABEL: llvm.func @strided_slice1 // CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float // CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>"> // CHECK: llvm.mlir.constant(2 : index) : !llvm.i64 @@ -439,7 +440,11 @@ // CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK: 
llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<2 x float>"> - %1 = vector.strided_slice %arg1 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8xf32> to vector<2x8xf32> +func @strided_slice2(%arg0: vector<4x8xf32>) -> vector<2x8xf32> { + %0 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8xf32> to vector<2x8xf32> + return %0 : vector<2x8xf32> +} +// CHECK-LABEL: llvm.func @strided_slice2 // CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float // CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x8xf32>) : !llvm<"[2 x <8 x float>]"> // CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"[4 x <8 x float>]"> @@ -447,7 +452,11 @@ // CHECK: llvm.extractvalue %{{.*}}[3] : !llvm<"[4 x <8 x float>]"> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"[2 x <8 x float>]"> - %2 = vector.strided_slice %arg1 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32> +func @strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> { + %0 = vector.strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32> + return %0 : vector<2x2xf32> +} +// CHECK-LABEL: llvm.func @strided_slice3 // CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float // CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x2xf32>) : !llvm<"[2 x <2 x float>]"> // @@ -479,17 +488,19 @@ // CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>"> // CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x <2 x float>]"> - return -} - -func @insert_strided_slice(%a: vector<2x2xf32>, %b: vector<4x4xf32>, %c: vector<4x4x4xf32>) { -// CHECK-LABEL: @insert_strided_slice - +func @insert_strided_slice1(%b: vector<4x4xf32>, %c: vector<4x4x4xf32>) -> vector<4x4x4xf32> { + %0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xf32> into vector<4x4x4xf32> + return %0 : vector<4x4x4xf32> +} +// 
CHECK-LABEL: @insert_strided_slice1 // CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x [4 x <4 x float>]]"> // CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [4 x <4 x float>]]"> - %1 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +func @insert_strided_slice2(%a: vector<2x2xf32>, %b: vector<4x4xf32>) -> vector<4x4xf32> { + %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> + return %0 : vector<4x4xf32> +} +// CHECK-LABEL: @insert_strided_slice2 // // Subvector vector<2xf32> @0 into vector<4xf32> @2 // CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <2 x float>]"> @@ -521,6 +532,19 @@ // CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>"> // CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <4 x float>]"> - return +func @extract_strides(%arg0: vector<3x3xf32>) -> vector<1x1xf32> { + %0 = vector.extract_slices %arg0, [2, 2], [1, 1] + : vector<3x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>> + %1 = vector.tuple_get %0, 3 : tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>> // element 3 is the 1x1 slice at offsets [2, 2] + return %1 : vector<1x1xf32> } - +// CHECK-LABEL: extract_strides(%arg0: !llvm<"[3 x <3 x float>]">) +// CHECK: %[[s0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1x1xf32>) : !llvm<"[1 x <1 x float>]"> +// CHECK: %[[s1:.*]] = llvm.extractvalue %arg0[2] : !llvm<"[3 x <3 x float>]"> +// CHECK: %[[s3:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1xf32>) : !llvm<"<1 x float>"> +// CHECK: %[[s4:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 +// CHECK: %[[s5:.*]] = llvm.extractelement %[[s1]][%[[s4]] : !llvm.i64] : !llvm<"<3 x float>"> +// CHECK: %[[s6:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: %[[s7:.*]] = llvm.insertelement %[[s5]], %[[s3]][%[[s6]] : !llvm.i64] : !llvm<"<1 x float>"> +// CHECK: %[[s8:.*]] = llvm.insertvalue %[[s7]], %[[s0]][0] 
: !llvm<"[1 x <1 x float>]"> +// CHECK: llvm.return %[[s8]] : !llvm<"[1 x <1 x float>]">