diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h @@ -54,11 +54,23 @@ void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns, MLIRContext *context); -/// Collect a set of vector contraction transformation patterns -/// that express all vector.contract ops in terms of more elementary -/// extraction and reduction ops. -void populateVectorContractLoweringPatterns(OwningRewritePatternList &patterns, - MLIRContext *context); +/// Collect a set of vector changing transformation patterns that are +/// related to contracting or expanding vector operations: +/// ContractionOpLowering, +/// ShapeCastOp2DDownCastRewritePattern, ShapeCastOp2DUpCastRewritePattern, +/// OuterProductOpLowering +/// These transformations express higher level vector changing ops in +/// terms of more elementary extraction, insertion, reduction, product, +/// and broadcast ops. +void populateVectorChangesLoweringPatterns(OwningRewritePatternList &patterns, + MLIRContext *context); + +// TODO(ajcbik): remove this alias when external references are renamed +inline void +populateVectorContractLoweringPatterns(OwningRewritePatternList &patterns, + MLIRContext *context) { + populateVectorChangesLoweringPatterns(patterns, context); +} /// Returns the integer type required for subscripts in the vector dialect. 
IntegerType getVectorSubscriptType(Builder &builder); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -817,6 +817,7 @@ } }; +// TODO(ajcbik): remove this rule once LinAlg tests are cleaned up class VectorOuterProductOpConversion : public ConvertToLLVMPattern { public: explicit VectorOuterProductOpConversion(MLIRContext *context, @@ -1176,12 +1177,12 @@ } // namespace void LowerVectorToLLVMPass::runOnModule() { - // Perform progressive lowering of operations on "slices" and - // all contraction operations. Also applies folding and DCE. + // Perform progressive lowering of operations on slices and + // vector changing operations. Also applies folding and DCE. { OwningRewritePatternList patterns; populateVectorSlicesLoweringPatterns(patterns, &getContext()); - populateVectorContractLoweringPatterns(patterns, &getContext()); + populateVectorChangesLoweringPatterns(patterns, &getContext()); applyPatternsGreedily(getModule(), patterns); } diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -864,6 +864,53 @@ } }; +/// Progressive lowering of OuterProductOp. +/// One: +/// %x = vector.outerproduct %lhs, %rhs, %acc +/// is replaced by: +/// %z = zero-result +/// %0 = vector.extract %lhs[0] +/// %1 = vector.broadcast %0 +/// %2 = vector.extract %acc[0] +/// %3 = vector.fma %1, %rhs, %2 +/// %4 = vector.insert %3, %z[0] +/// .. 
+/// +/// %x = vector.insert %.., %..[N-1] +/// +class OuterProductOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(vector::OuterProductOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + VectorType rhsType = op.getOperandVectorTypeRHS(); + VectorType resType = op.getVectorType(); + Type eltType = resType.getElementType(); + Value acc = (op.acc().empty()) ? nullptr : op.acc()[0]; + + Value zero = rewriter.create(loc, eltType, + rewriter.getZeroAttr(eltType)); + Value result = rewriter.create(loc, resType, zero); + for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { + auto pos = rewriter.getI64ArrayAttr(d); + Value x = rewriter.create(loc, eltType, op.lhs(), pos); + Value b = rewriter.create(loc, rhsType, x); + Value m; + if (acc != nullptr) { + Value z = rewriter.create(loc, rhsType, acc, pos); + m = rewriter.create(loc, b, op.rhs(), z); + } else { + m = rewriter.create(loc, b, op.rhs()); + } + result = rewriter.create(loc, resType, m, result, pos); + } + rewriter.replaceOp(op, result); + return matchSuccess(); + } +}; + /// Progressive lowering of ContractionOp. 
/// One: /// %x = vector.contract with at least one free/batch dimension @@ -1255,11 +1302,9 @@ patterns.insert(context); } -void mlir::vector::populateVectorContractLoweringPatterns( +void mlir::vector::populateVectorChangesLoweringPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.insert(context); + patterns.insert( + context); } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -204,39 +204,64 @@ %2 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32> return %2 : vector<2x3xf32> } -// CHECK-LABEL: llvm.func @outerproduct -// CHECK: llvm.mlir.undef : !llvm<"[2 x <3 x float>]"> -// CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> -// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<3 x float>"> -// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]"> -// CHECK: llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> -// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<3 x float>"> -// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]"> -// CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]"> +// CHECK-LABEL: llvm.func @outerproduct( +// CHECK-SAME: %[[A:.*]]: !llvm<"<2 x float>">, +// CHECK-SAME: %[[B:.*]]: !llvm<"<3 x float>">) +// CHECK: %[[T0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<2x3xf32>) +// CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64 +// CHECK: %[[T2:.*]] = llvm.extractelement %[[A]][%[[T1]] : !llvm.i64] : !llvm<"<2 x float>"> +// CHECK: %[[T3:.*]] = llvm.mlir.undef : !llvm<"<3 x float>"> +// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: %[[T5:.*]] = llvm.insertelement %[[T2]], %[[T3]][%[[T4]] : !llvm.i64] : !llvm<"<3 x float>"> +// CHECK: %[[T6:.*]] = 
llvm.shufflevector %[[T5]], %[[T3]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> +// CHECK: %[[T7:.*]] = llvm.fmul %[[T6]], %[[B]] : !llvm<"<3 x float>"> +// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T7]], %[[T0]][0] : !llvm<"[2 x <3 x float>]"> +// CHECK: %[[T9:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64 +// CHECK: %[[T10:.*]] = llvm.extractelement %[[A]][%[[T9]] : !llvm.i64] : !llvm<"<2 x float>"> +// CHECK: %[[T11:.*]] = llvm.mlir.undef : !llvm<"<3 x float>"> +// CHECK: %[[T12:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: %[[T13:.*]] = llvm.insertelement %[[T10]], %[[T11]][%[[T12]] : !llvm.i64] : !llvm<"<3 x float>"> +// CHECK: %[[T14:.*]] = llvm.shufflevector %[[T13]], %[[T11]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> +// CHECK: %[[T15:.*]] = llvm.fmul %[[T14]], %[[B]] : !llvm<"<3 x float>"> +// CHECK: %[[T16:.*]] = llvm.insertvalue %[[T15]], %[[T8]][1] : !llvm<"[2 x <3 x float>]"> +// CHECK: llvm.return %[[T16]] : !llvm<"[2 x <3 x float>]"> func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> { %2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32> return %2 : vector<2x3xf32> } -// CHECK-LABEL: llvm.func @outerproduct_add -// CHECK: llvm.mlir.undef : !llvm<"[2 x <3 x float>]"> -// CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]"> -// CHECK: "llvm.intr.fma"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>"> -// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]"> -// CHECK: llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> -// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]"> -// CHECK: "llvm.intr.fma"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) 
-> !llvm<"<3 x float>"> -// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]"> -// CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]"> +// CHECK-LABEL: llvm.func @outerproduct_add( +// CHECK-SAME: %[[A:.*]]: !llvm<"<2 x float>">, +// CHECK-SAME: %[[B:.*]]: !llvm<"<3 x float>">, +// CHECK-SAME: %[[C:.*]]: !llvm<"[2 x <3 x float>]">) +// CHECK: %[[T0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<2x3xf32>) +// CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64 +// CHECK: %[[T2:.*]] = llvm.extractelement %[[A]][%[[T1]] : !llvm.i64] : !llvm<"<2 x float>"> +// CHECK: %[[T3:.*]] = llvm.mlir.undef : !llvm<"<3 x float>"> +// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: %[[T5:.*]] = llvm.insertelement %[[T2]], %[[T3]][%[[T4]] : !llvm.i64] : !llvm<"<3 x float>"> +// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T5]], %[[T3]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> +// CHECK: %[[T7:.*]] = llvm.extractvalue %[[C]][0] : !llvm<"[2 x <3 x float>]"> +// CHECK: %[[T8:.*]] = "llvm.intr.fma"(%[[T6]], %[[B]], %[[T7]]) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) +// CHECK: %[[T9:.*]] = llvm.insertvalue %[[T8]], %[[T0]][0] : !llvm<"[2 x <3 x float>]"> +// CHECK: %[[T10:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64 +// CHECK: %[[T11:.*]] = llvm.extractelement %[[A]][%[[T10]] : !llvm.i64] : !llvm<"<2 x float>"> +// CHECK: %[[T12:.*]] = llvm.mlir.undef : !llvm<"<3 x float>"> +// CHECK: %[[T13:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: %[[T14:.*]] = llvm.insertelement %[[T11]], %[[T12]][%[[T13]] : !llvm.i64] : !llvm<"<3 x float>"> +// CHECK: %[[T15:.*]] = llvm.shufflevector %[[T14]], %[[T12]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> +// CHECK: %[[T16:.*]] = llvm.extractvalue %[[C]][1] : !llvm<"[2 x <3 x float>]"> +// CHECK: %[[T17:.*]] = "llvm.intr.fma"(%[[T15]], %[[B]], %[[T16]]) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, 
!llvm<"<3 x float>">) +// CHECK: %[[T18:.*]] = llvm.insertvalue %[[T17]], %[[T9]][1] : !llvm<"[2 x <3 x float>]"> +// CHECK: llvm.return %[[T18]] : !llvm<"[2 x <3 x float>]"> func @shuffle_1D_direct(%arg0: vector<2xf32>, %arg1: vector<2xf32>) -> vector<2xf32> { %1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<2xf32>, vector<2xf32> return %1 : vector<2xf32> } -// CHECK-LABEL: llvm.func @shuffle_1D_direct +// CHECK-LABEL: llvm.func @shuffle_1D_direct( // CHECK-SAME: %[[A:.*]]: !llvm<"<2 x float>">, -// CHECK-SAME: %[[B:.*]]: !llvm<"<2 x float>"> +// CHECK-SAME: %[[B:.*]]: !llvm<"<2 x float>">) // CHECK: %[[s:.*]] = llvm.shufflevector %[[A]], %[[B]] [0, 1] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> // CHECK: llvm.return %[[s]] : !llvm<"<2 x float>"> @@ -244,9 +269,9 @@ %1 = vector.shuffle %arg0, %arg1 [4, 3, 2, 1, 0] : vector<2xf32>, vector<3xf32> return %1 : vector<5xf32> } -// CHECK-LABEL: llvm.func @shuffle_1D +// CHECK-LABEL: llvm.func @shuffle_1D( // CHECK-SAME: %[[A:.*]]: !llvm<"<2 x float>">, -// CHECK-SAME: %[[B:.*]]: !llvm<"<3 x float>"> +// CHECK-SAME: %[[B:.*]]: !llvm<"<3 x float>">) // CHECK: %[[u0:.*]] = llvm.mlir.undef : !llvm<"<5 x float>"> // CHECK: %[[c2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 // CHECK: %[[e1:.*]] = llvm.extractelement %[[B]][%[[c2]] : !llvm.i64] : !llvm<"<3 x float>"> @@ -274,9 +299,9 @@ %1 = vector.shuffle %a, %b[1, 0, 2] : vector<1x4xf32>, vector<2x4xf32> return %1 : vector<3x4xf32> } -// CHECK-LABEL: llvm.func @shuffle_2D +// CHECK-LABEL: llvm.func @shuffle_2D( // CHECK-SAME: %[[A:.*]]: !llvm<"[1 x <4 x float>]">, -// CHECK-SAME: %[[B:.*]]: !llvm<"[2 x <4 x float>]"> +// CHECK-SAME: %[[B:.*]]: !llvm<"[2 x <4 x float>]">) // CHECK: %[[u0:.*]] = llvm.mlir.undef : !llvm<"[3 x <4 x float>]"> // CHECK: %[[e1:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]"> // CHECK: %[[i1:.*]] = llvm.insertvalue %[[e1]], %[[u0]][0] : !llvm<"[3 x <4 x float>]"> @@ -291,8 +316,8 @@ %1 = vector.extractelement %arg0[%0 : i32]: 
vector<16xf32> return %1 : f32 } -// CHECK-LABEL: llvm.func @extract_element -// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>"> +// CHECK-LABEL: llvm.func @extract_element( +// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>">) // CHECK: %[[c:.*]] = llvm.mlir.constant(15 : i32) : !llvm.i32 // CHECK: %[[x:.*]] = llvm.extractelement %[[A]][%[[c]] : !llvm.i32] : !llvm<"<16 x float>"> // CHECK: llvm.return %[[x]] : !llvm.float @@ -337,9 +362,9 @@ %1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<4xf32> return %1 : vector<4xf32> } -// CHECK-LABEL: llvm.func @insert_element +// CHECK-LABEL: llvm.func @insert_element( // CHECK-SAME: %[[A:.*]]: !llvm.float, -// CHECK-SAME: %[[B:.*]]: !llvm<"<4 x float>"> +// CHECK-SAME: %[[B:.*]]: !llvm<"<4 x float>">) // CHECK: %[[c:.*]] = llvm.mlir.constant(3 : i32) : !llvm.i32 // CHECK: %[[x:.*]] = llvm.insertelement %[[A]], %[[B]][%[[c]] : !llvm.i32] : !llvm<"<4 x float>"> // CHECK: llvm.return %[[x]] : !llvm<"<4 x float>"> @@ -399,8 +424,8 @@ vector.print %arg0 : i32 return } -// CHECK-LABEL: llvm.func @vector_print_scalar_i32 -// CHECK-SAME: %[[A:.*]]: !llvm.i32 +// CHECK-LABEL: llvm.func @vector_print_scalar_i32( +// CHECK-SAME: %[[A:.*]]: !llvm.i32) // CHECK: llvm.call @print_i32(%[[A]]) : (!llvm.i32) -> () // CHECK: llvm.call @print_newline() : () -> () @@ -408,8 +433,8 @@ vector.print %arg0 : i64 return } -// CHECK-LABEL: llvm.func @vector_print_scalar_i64 -// CHECK-SAME: %[[A:.*]]: !llvm.i64 +// CHECK-LABEL: llvm.func @vector_print_scalar_i64( +// CHECK-SAME: %[[A:.*]]: !llvm.i64) // CHECK: llvm.call @print_i64(%[[A]]) : (!llvm.i64) -> () // CHECK: llvm.call @print_newline() : () -> () @@ -417,8 +442,8 @@ vector.print %arg0 : f32 return } -// CHECK-LABEL: llvm.func @vector_print_scalar_f32 -// CHECK-SAME: %[[A:.*]]: !llvm.float +// CHECK-LABEL: llvm.func @vector_print_scalar_f32( +// CHECK-SAME: %[[A:.*]]: !llvm.float) // CHECK: llvm.call @print_f32(%[[A]]) : (!llvm.float) -> () // CHECK: llvm.call @print_newline() : () -> () 
@@ -426,8 +451,8 @@ vector.print %arg0 : f64 return } -// CHECK-LABEL: llvm.func @vector_print_scalar_f64 -// CHECK-SAME: %[[A:.*]]: !llvm.double +// CHECK-LABEL: llvm.func @vector_print_scalar_f64( +// CHECK-SAME: %[[A:.*]]: !llvm.double) // CHECK: llvm.call @print_f64(%[[A]]) : (!llvm.double) -> () // CHECK: llvm.call @print_newline() : () -> () @@ -435,8 +460,8 @@ vector.print %arg0 : vector<2x2xf32> return } -// CHECK-LABEL: llvm.func @vector_print_vector -// CHECK-SAME: %[[A:.*]]: !llvm<"[2 x <2 x float>]"> +// CHECK-LABEL: llvm.func @vector_print_vector( +// CHECK-SAME: %[[A:.*]]: !llvm<"[2 x <2 x float>]">) // CHECK: llvm.call @print_open() : () -> () // CHECK: %[[x0:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[2 x <2 x float>]"> // CHECK: llvm.call @print_open() : () -> () @@ -575,9 +600,9 @@ vector<2x4xf32> into vector<16x4x8xf32> return %0 : vector<16x4x8xf32> } -// CHECK-LABEL: llvm.func @insert_strided_slice3 +// CHECK-LABEL: llvm.func @insert_strided_slice3( // CHECK-SAME: %[[A:.*]]: !llvm<"[2 x <4 x float>]">, -// CHECK-SAME: %[[B:.*]]: !llvm<"[16 x [4 x <8 x float>]]"> +// CHECK-SAME: %[[B:.*]]: !llvm<"[16 x [4 x <8 x float>]]">) // CHECK: %[[s0:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[16 x [4 x <8 x float>]]"> // CHECK: %[[s1:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[2 x <4 x float>]"> // CHECK: %[[s2:.*]] = llvm.extractvalue %[[s0]][0] : !llvm<"[4 x <8 x float>]"> @@ -626,8 +651,8 @@ %1 = vector.tuple_get %0, 3 : tuple, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>> return %1 : vector<1x1xf32> } -// CHECK-LABEL: llvm.func @extract_strides -// CHECK-SAME: %[[A:.*]]: !llvm<"[3 x <3 x float>]"> +// CHECK-LABEL: llvm.func @extract_strides( +// CHECK-SAME: %[[A:.*]]: !llvm<"[3 x <3 x float>]">) // CHECK: %[[s0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1x1xf32>) : !llvm<"[1 x <1 x float>]"> // CHECK: %[[s1:.*]] = llvm.extractvalue %[[A]][2] : !llvm<"[3 x <3 x float>]"> // CHECK: %[[s3:.*]] = llvm.mlir.constant(dense<0.000000e+00> 
: vector<1xf32>) : !llvm<"<1 x float>"> @@ -667,8 +692,8 @@ %0 = vector.reduction "add", %arg0 : vector<16xf32> into f32 return %0 : f32 } -// CHECK-LABEL: llvm.func @reduce_f32 -// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>"> +// CHECK-LABEL: llvm.func @reduce_f32( +// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>">) // CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float // CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]]) // CHECK: llvm.return %[[V]] : !llvm.float @@ -677,8 +702,8 @@ %0 = vector.reduction "add", %arg0 : vector<16xf64> into f64 return %0 : f64 } -// CHECK-LABEL: llvm.func @reduce_f64 -// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x double>"> +// CHECK-LABEL: llvm.func @reduce_f64( +// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x double>">) // CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f64) : !llvm.double // CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]]) // CHECK: llvm.return %[[V]] : !llvm.double @@ -687,8 +712,8 @@ %0 = vector.reduction "add", %arg0 : vector<16xi32> into i32 return %0 : i32 } -// CHECK-LABEL: llvm.func @reduce_i32 -// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x i32>"> +// CHECK-LABEL: llvm.func @reduce_i32( +// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x i32>">) // CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]]) // CHECK: llvm.return %[[V]] : !llvm.i32 @@ -696,8 +721,8 @@ %0 = vector.reduction "add", %arg0 : vector<16xi64> into i64 return %0 : i64 } -// CHECK-LABEL: llvm.func @reduce_i64 -// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x i64>"> +// CHECK-LABEL: llvm.func @reduce_i64( +// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x i64>">) // CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]]) // CHECK: llvm.return %[[V]] : !llvm.i64 diff --git a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-changes-transforms.mlir rename from mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir rename to 
mlir/test/Dialect/VectorOps/vector-changes-transforms.mlir --- a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/VectorOps/vector-changes-transforms.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s +// RUN: mlir-opt %s -test-vector-changes-conversion | FileCheck %s #dotp_accesses = [ affine_map<(i) -> (i)>, @@ -251,6 +251,50 @@ return %0 : f32 } +// CHECK-LABEL: func @outerproduct_noacc +// CHECK-SAME: %[[A:.*0]]: vector<2xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<3xf32> +// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<2x3xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32> +// CHECK: %[[T2:.*]] = mulf %[[T1]], %[[B]] : vector<3xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xf32> +// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32> +// CHECK: %[[T6:.*]] = mulf %[[T5]], %[[B]] : vector<3xf32> +// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32> +// CHECK: return %[[T7]] : vector<2x3xf32> + +func @outerproduct_noacc(%arg0: vector<2xf32>, + %arg1: vector<3xf32>) -> vector<2x3xf32> { + %0 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32> + return %0: vector<2x3xf32> +} + +// CHECK-LABEL: func @outerproduct_acc +// CHECK-SAME: %[[A:.*0]]: vector<2xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<3xf32>, +// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32> +// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<2x3xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> +// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32> +// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xf32> +// CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32> +// CHECK: 
%[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> +// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2xf32> +// CHECK: %[[T6:.*]] = vector.broadcast %[[T5]] : f32 to vector<3xf32> +// CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<2x3xf32> +// CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32> +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32> +// CHECK: return %[[T9]] : vector<2x3xf32> + +func @outerproduct_acc(%arg0: vector<2xf32>, + %arg1: vector<3xf32>, + %arg2: vector<2x3xf32>) -> vector<2x3xf32> { + %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32> + return %0: vector<2x3xf32> +} + // Shape up and downcasts for 2-D vectors, for supporting conversion to // llvm.matrix operations // CHECK-LABEL: func @shape_casts diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -42,11 +42,11 @@ } }; -struct TestVectorContractionConversion - : public FunctionPass { +struct TestVectorChangesConversion + : public FunctionPass { void runOnFunction() override { OwningRewritePatternList patterns; - populateVectorContractLoweringPatterns(patterns, &getContext()); + populateVectorChangesLoweringPatterns(patterns, &getContext()); applyPatternsGreedily(getFunction(), patterns); } }; @@ -63,8 +63,9 @@ "test-vector-slices-conversion", "Test conversion patterns that lower slices ops in the vector dialect"); - PassRegistration contractionPass( - "test-vector-contraction-conversion", - "Test conversion patterns that lower contract ops in the vector dialect"); + PassRegistration changesPass( + "test-vector-changes-conversion", + "Test conversion patterns that lower vector changing ops " + "in the vector dialect"); } } // namespace mlir