diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -802,7 +802,8 @@ let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat], "integer/index/float type">:$input); - let results = (outs AnyTypeOf<[AnyVector, AnyStaticShapeTensor]>:$aggregate); + let results = (outs AnyTypeOf<[AnyVectorOfAnyRank, + AnyStaticShapeTensor]>:$aggregate); let builders = [ OpBuilder<(ins "Value":$element, "Type":$aggregateType), diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -302,7 +302,7 @@ Results<(outs AnyType:$dest)> { let summary = "Multi-dimensional reduction operation"; let description = [{ - Reduces an n-D vector into an (n-k)-D vector (or a scalar when k == n) + Reduces an n-D vector into an (n-k)-D vector (or a scalar when k == n) using the given operation (add/mul/min/max for int/fp and and/or/xor for int only). @@ -380,7 +380,7 @@ PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>]>, Arguments<(ins AnyType:$source)>, - Results<(outs AnyVector:$vector)> { + Results<(outs AnyVectorOfAnyRank:$vector)> { let summary = "broadcast operation"; let description = [{ Broadcasts the scalar or k-D vector value in the source operand diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -702,7 +702,7 @@ }; // The Splat operation is lowered to an insertelement + a shufflevector -// operation. Splat to only 1-d vector result types are lowered. +// operation. Splat to only 0-d and 1-d vector result types are lowered. struct SplatOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -710,7 +710,7 @@ matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType resultType = splatOp.getType().dyn_cast(); - if (!resultType || resultType.getRank() != 1) + if (!resultType || resultType.getRank() > 1) return failure(); // First insert it into an undef vector so we can shuffle it. @@ -721,6 +721,14 @@ typeConverter->convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); + // For 0-d vector, we simply do `insertelement`. + if (resultType.getRank() == 0) { + rewriter.replaceOpWithNewOp( + splatOp, vectorType, undef, adaptor.getInput(), zero); + return success(); + } + + // For 1-d vector, we additionally do a `vectorshuffle`. auto v = rewriter.create( splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero); @@ -745,7 +753,7 @@ matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType resultType = splatOp.getType().dyn_cast(); - if (!resultType || resultType.getRank() == 1) + if (!resultType || resultType.getRank() < 2) return failure(); // First insert it into an undef vector so we can shuffle it. diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -546,10 +546,32 @@ VectorType srcType = op.getSourceType().dyn_cast(); Type eltType = dstType.getElementType(); + // Scalar to any vector can use splat. + if (!srcType) { + rewriter.replaceOpWithNewOp(op, dstType, op.source()); + return success(); + } + // Determine rank of source and destination. - int64_t srcRank = srcType ? srcType.getRank() : 0; + int64_t srcRank = srcType.getRank(); int64_t dstRank = dstType.getRank(); + // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat. + if (srcRank <= 1 && dstRank == 1) { + // A `vector` is effectively the same as `vector<1xf32>`. + if (srcRank == 0 && dstType.getDimSize(0) == 1) { + rewriter.replaceOp(op, op.source()); + return success(); + } + Value ext; + if (srcRank == 0) + ext = rewriter.create(loc, op.source()); + else + ext = rewriter.create(loc, op.source(), 0); + rewriter.replaceOpWithNewOp(op, dstType, ext); + return success(); + } + // Duplicate this rank. // For example: // %x = broadcast %y : k-D to n-D, k < n @@ -560,11 +582,6 @@ // %b = [%y,%y] : (n-1)-D // %x = [%b,%b,%b,%b] : n-D if (srcRank < dstRank) { - // Scalar to any vector can use splat. - if (srcRank == 0) { - rewriter.replaceOpWithNewOp(op, dstType, op.source()); - return success(); - } // Duplication. VectorType resType = VectorType::get(dstType.getShape().drop_front(), eltType); @@ -593,14 +610,6 @@ return success(); } - // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat. - if (srcRank == 1) { - assert(m == 0); - Value ext = rewriter.create(loc, op.source(), 0); - rewriter.replaceOpWithNewOp(op, dstType, ext); - return success(); - } - // Any non-matching dimension forces a stretch along this rank. // For example: // %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32> diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -35,6 +35,27 @@ // ----- +func @broadcast_vec0d_from_f32(%arg0: f32) -> vector { + %0 = vector.broadcast %arg0 : f32 to vector + return %0 : vector +} +// CHECK-LABEL: @broadcast_vec0d_from_f32 +// CHECK-SAME: %[[A:.*]]: f32) +// CHECK: %[[T0:.*]] = splat %[[A]] : vector +// CHECK: return %[[T0]] : vector + +// ----- + +func @broadcast_vec0d_from_vec0d(%arg0: vector) -> vector { + %0 = vector.broadcast %arg0 : vector to vector + return %0 : vector +} +// CHECK-LABEL: @broadcast_vec0d_from_vec0d( +// CHECK-SAME: %[[A:.*]]: vector) +// CHECK: return %[[A]] : vector + +// ----- + func @broadcast_vec1d_from_f32(%arg0: f32) -> vector<2xf32> { %0 = vector.broadcast %arg0 : f32 to vector<2xf32> return %0 : vector<2xf32> @@ -89,6 +110,26 @@ // ----- +func @broadcast_vec2d_from_vec0d(%arg0: vector) -> vector<3x2xf32> { + %0 = vector.broadcast %arg0 : vector to vector<3x2xf32> + return %0 : vector<3x2xf32> +} +// CHECK-LABEL: @broadcast_vec2d_from_vec0d( +// CHECK-SAME: %[[A:.*]]: vector) +// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector to vector<1xf32> +// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> +// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>> +// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[T5:.*]] = llvm.extractelement %[[T0]][%[[T4]] : i64] : vector<1xf32> +// CHECK: %[[T6:.*]] = splat %[[T5]] : vector<2xf32> +// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>> +// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T6]], %[[T7]][1] : !llvm.array<3 x vector<2xf32>> +// CHECK: %[[T9:.*]] = llvm.insertvalue %[[T6]], %[[T8]][2] : !llvm.array<3 x vector<2xf32>> +// CHECK: %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T9]] : !llvm.array<3 x vector<2xf32>> to vector<3x2xf32> +// CHECK: return %[[T10]] : vector<3x2xf32> + +// ----- + func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> { %0 = vector.broadcast %arg0 : vector<2xf32> to vector<3x2xf32> return %0 : vector<3x2xf32> diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir @@ -21,6 +21,33 @@ return } +func @broadcast_0d(%a: f32) { + %1 = vector.broadcast %a : f32 to vector + // CHECK: ( 42 ) + vector.print %1: vector + + %2 = vector.broadcast %1 : vector to vector + // CHECK: ( 42 ) + vector.print %2: vector + + %3 = vector.broadcast %1 : vector to vector<1xf32> + // CHECK: ( 42 ) + vector.print %3: vector<1xf32> + + %4 = vector.broadcast %1 : vector to vector<2xf32> + // CHECK: ( 42, 42 ) + vector.print %4: vector<2xf32> + + %5 = vector.broadcast %1 : vector to vector<2x1xf32> + // CHECK: ( ( 42 ), ( 42 ) ) + vector.print %5: vector<2x1xf32> + + %6 = vector.broadcast %1 : vector to vector<2x3xf32> + // CHECK: ( ( 42, 42, 42 ), ( 42, 42, 42 ) ) + vector.print %6: vector<2x3xf32> + return +} + func @entry() { %0 = arith.constant 42.0 : f32 %1 = arith.constant dense<0.0> : vector @@ -30,5 +57,8 @@ %3 = arith.constant dense<42.0> : vector call @print_vector_0d(%3) : (vector) -> () + %4 = arith.constant 42.0 : f32 + call @broadcast_0d(%4) : (f32) -> () + return }