diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -802,7 +802,8 @@ let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat], "integer/index/float type">:$input); - let results = (outs AnyTypeOf<[AnyVector, AnyStaticShapeTensor]>:$aggregate); + let results = (outs AnyTypeOf<[AnyVectorOfAnyRank, + AnyStaticShapeTensor]>:$aggregate); let builders = [ OpBuilder<(ins "Value":$element, "Type":$aggregateType), diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -702,7 +702,7 @@ }; // The Splat operation is lowered to an insertelement + a shufflevector -// operation. Splat to only 1-d vector result types are lowered. +// operation. Only splats to 0-d and 1-d vector result types are lowered. struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> { using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern; @@ -710,7 +710,7 @@ matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType resultType = splatOp.getType().dyn_cast<VectorType>(); - if (!resultType || resultType.getRank() != 1) + if (!resultType || resultType.getRank() > 1) return failure(); // First insert it into an undef vector so we can shuffle it. @@ -721,6 +721,14 @@ typeConverter->convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); + // For a 0-d vector, a single `insertelement` at index 0 is enough. + if (resultType.getRank() == 0) { + rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( + splatOp, vectorType, undef, adaptor.getInput(), zero); + return success(); + } + + // For a 1-d vector, we additionally broadcast lane 0 with a `shufflevector`.
auto v = rewriter.create<LLVM::InsertElementOp>( splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero); @@ -745,7 +753,7 @@ matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType resultType = splatOp.getType().dyn_cast<VectorType>(); - if (!resultType || resultType.getRank() == 1) + if (!resultType || resultType.getRank() <= 1) return failure(); // First insert it into an undef vector so we can shuffle it. diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir --- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir +++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir @@ -454,6 +454,21 @@ br ^bb1 } +// ----- + +// CHECK-LABEL: @splat_0d +// CHECK-SAME: %[[ARG:.*]]: f32 +func @splat_0d(%a: f32) -> vector<f32> { + %v = splat %a : vector<f32> + return %v : vector<f32> +} +// CHECK-NEXT: %[[UNDEF:[0-9]+]] = llvm.mlir.undef : vector<1xf32> +// CHECK-NEXT: %[[ZERO:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK-NEXT: %[[V:[0-9]+]] = llvm.insertelement %[[ARG]], %[[UNDEF]][%[[ZERO]] : i32] : vector<1xf32> +// CHECK-NEXT: llvm.return %[[V]] : vector<1xf32> + +// ----- + // CHECK-LABEL: @splat // CHECK-SAME: %[[A:arg[0-9]+]]: vector<4xf32> // CHECK-SAME: %[[ELT:arg[0-9]+]]: f32 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir @@ -21,6 +21,13 @@ return } +func @splat_0d(%a: f32) { + %1 = splat %a : vector<f32> + // CHECK: ( 42 ) + vector.print %1: vector<f32> + return +} + func @entry() { %0 = arith.constant 42.0 : f32 %1 = arith.constant dense<0.0> : vector<f32> @@ -30,5 +37,8 @@ %3 = arith.constant dense<42.0> : vector<f32> call @print_vector_0d(%3) : (vector<f32>) -> () + %4 = arith.constant 42.0 : f32 + call @splat_0d(%4) : (f32) -> () + return }