diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -693,54 +693,6 @@
   let hasFolder = 1;
 }
 
-//===----------------------------------------------------------------------===//
-// SplatOp
-//===----------------------------------------------------------------------===//
-
-def SplatOp : Std_Op<"splat", [NoSideEffect,
-     TypesMatchWith<"operand type matches element type of result",
-                    "aggregate", "input",
-                    "$_self.cast<ShapedType>().getElementType()">]> {
-  let summary = "splat or broadcast operation";
-  let description = [{
-    Broadcast the operand to all elements of the result vector or tensor. The
-    operand has to be of integer/index/float type. When the result is a tensor,
-    it has to be statically shaped.
-
-    Example:
-
-    ```mlir
-    %s = load %A[%i] : memref<128xf32>
-    %v = splat %s : vector<4xf32>
-    %t = splat %s : tensor<8x16xi32>
-    ```
-
-    TODO: This operation is easy to extend to broadcast to dynamically shaped
-    tensors in the same way dynamically shaped memrefs are handled.
-
-    ```mlir
-    // Broadcasts %s to a 2-d dynamically shaped tensor, with %m, %n binding
-    // to the sizes of the two dynamic dimensions.
-    %m = "foo"() : () -> (index)
-    %n = "bar"() : () -> (index)
-    %t = splat %s [%m, %n] : tensor<?x?xf32>
-    ```
-  }];
-
-  let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
-                                 "integer/index/float type">:$input);
-  let results = (outs AnyTypeOf<[AnyVectorOfAnyRank,
-                                 AnyStaticShapeTensor]>:$aggregate);
-
-  let builders = [
-    OpBuilder<(ins "Value":$element, "Type":$aggregateType),
-    [{ build($_builder, $_state, aggregateType, element); }]>];
-
-  let hasFolder = 1;
-
-  let assemblyFormat = "$input attr-dict `:` type($aggregate)";
-}
-
 //===----------------------------------------------------------------------===//
 // SwitchOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -963,6 +963,53 @@
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// SplatOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_SplatOp : Tensor_Op<"splat", [
+    NoSideEffect,
+    TypesMatchWith<"operand type matches element type of result",
+                   "aggregate", "input",
+                   "$_self.cast<TensorType>().getElementType()">
+  ]> {
+  let summary = "tensor splat or broadcast operation";
+  let description = [{
+    Broadcast the operand to all elements of the result tensor. The operand is
+    required to be of integer/index/float type, and the result tensor must be
+    statically shaped.
+
+    Example:
+
+    ```mlir
+    %s = arith.constant 10.1 : f32
+    %t = tensor.splat %s : tensor<8x16xf32>
+    ```
+
+    TODO: This operation is easy to extend to broadcast to dynamically shaped
+    tensors:
+
+    ```mlir
+    // Broadcasts %s to a 2-d dynamically shaped tensor, with %m, %n binding
+    // to the sizes of the two dynamic dimensions.
+    %m = "foo"() : () -> (index)
+    %n = "bar"() : () -> (index)
+    %t = tensor.splat %s [%m, %n] : tensor<?x?xf32>
+    ```
+  }];
+
+  let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
+                                 "integer/index/float type">:$input);
+  let results = (outs AnyStaticShapeTensor:$aggregate);
+
+  let builders = [
+    OpBuilder<(ins "Value":$element, "Type":$aggregateType),
+    [{ build($_builder, $_state, aggregateType, element); }]>];
+  let assemblyFormat = "$input attr-dict `:` type($aggregate)";
+
+  let hasFolder = 1;
+  let verifier = ?;
+}
 
 //===----------------------------------------------------------------------===//
 // YieldOp
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -2397,6 +2397,42 @@
   let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)";
 }
 
+//===----------------------------------------------------------------------===//
+// SplatOp
+//===----------------------------------------------------------------------===//
+
+def Vector_SplatOp : Vector_Op<"splat", [
+    NoSideEffect,
+    TypesMatchWith<"operand type matches element type of result",
+                   "aggregate", "input",
+                   "$_self.cast<VectorType>().getElementType()">
+  ]> {
+  let summary = "vector splat or broadcast operation";
+  let description = [{
+    Broadcast the operand to all elements of the result vector. The operand is
+    required to be of integer/index/float type.
+
+    Example:
+
+    ```mlir
+    %s = arith.constant 10.1 : f32
+    %t = vector.splat %s : vector<8x16xf32>
+    ```
+  }];
+
+  let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
+                                 "integer/index/float type">:$input);
+  let results = (outs AnyVectorOfAnyRank:$aggregate);
+
+  let builders = [
+    OpBuilder<(ins "Value":$element, "Type":$aggregateType),
+    [{ build($_builder, $_state, aggregateType, element); }]>];
+  let assemblyFormat = "$input attr-dict `:` type($aggregate)";
+
+  let hasFolder = 1;
+  let verifier = ?;
+}
+
 //===----------------------------------------------------------------------===//
 // VectorScaleOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -678,99 +678,6 @@
   using Super::Super;
 };
 
-// The Splat operation is lowered to an insertelement + a shufflevector
-// operation. Splat to only 0-d and 1-d vector result types are lowered.
-struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
-  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
-    if (!resultType || resultType.getRank() > 1)
-      return failure();
-
-    // First insert it into an undef vector so we can shuffle it.
-    auto vectorType = typeConverter->convertType(splatOp.getType());
-    Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
-    auto zero = rewriter.create<LLVM::ConstantOp>(
-        splatOp.getLoc(),
-        typeConverter->convertType(rewriter.getIntegerType(32)),
-        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
-
-    // For 0-d vector, we simply do `insertelement`.
-    if (resultType.getRank() == 0) {
-      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
-          splatOp, vectorType, undef, adaptor.getInput(), zero);
-      return success();
-    }
-
-    // For 1-d vector, we additionally do a `vectorshuffle`.
-    auto v = rewriter.create<LLVM::InsertElementOp>(
-        splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
-
-    int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
-    SmallVector<int32_t> zeroValues(width, 0);
-
-    // Shuffle the value across the desired number of elements.
-    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
-    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
-                                                       zeroAttrs);
-    return success();
-  }
-};
-
-// The Splat operation is lowered to an insertelement + a shufflevector
-// operation. Splat to only 2+-d vector result types are lowered by the
-// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
-struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
-  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
-    if (!resultType || resultType.getRank() <= 1)
-      return failure();
-
-    // First insert it into an undef vector so we can shuffle it.
-    auto loc = splatOp.getLoc();
-    auto vectorTypeInfo =
-        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
-    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
-    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
-    if (!llvmNDVectorTy || !llvm1DVectorTy)
-      return failure();
-
-    // Construct returned value.
-    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
-
-    // Construct a 1-D vector with the splatted value that we insert in all the
-    // places within the returned descriptor.
-    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
-    auto zero = rewriter.create<LLVM::ConstantOp>(
-        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
-        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
-    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
-                                                     adaptor.getInput(), zero);
-
-    // Shuffle the value across the desired number of elements.
-    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
-    SmallVector<int32_t> zeroValues(width, 0);
-    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
-    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);
-
-    // Iterate of linear index, convert to coords space and insert splatted 1-D
-    // vector in each position.
-    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
-      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v,
-                                                  position);
-    });
-    rewriter.replaceOp(splatOp, desc);
-    return success();
-  }
-};
-
 /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
 /// retried until it succeeds in atomically storing a new value into memory.
 ///
@@ -914,8 +821,6 @@
       GenericAtomicRMWOpLowering,
       ReturnOpLowering,
       SelectOpLowering,
-      SplatOpLowering,
-      SplatNdOpLowering,
       SwitchOpLowering>(converter);
   // clang-format on
 }
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -55,16 +55,6 @@
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Converts std.splat to spv.CompositeConstruct.
-class SplatPattern final : public OpConversionPattern<SplatOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(SplatOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
 /// Converts std.br to spv.Branch.
 struct BranchOpPattern final : public OpConversionPattern<BranchOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -178,22 +168,6 @@
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// SplatOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult
-SplatPattern::matchAndRewrite(SplatOp op, OpAdaptor adaptor,
-                              ConversionPatternRewriter &rewriter) const {
-  auto dstVecType = op.getType().dyn_cast<VectorType>();
-  if (!dstVecType || !spirv::CompositeType::isValid(dstVecType))
-    return failure();
-  SmallVector<Value, 4> source(dstVecType.getNumElements(), adaptor.getInput());
-  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType,
-                                                           source);
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // BranchOpPattern
 //===----------------------------------------------------------------------===//
@@ -237,8 +211,8 @@
       spirv::ElementwiseOpPattern,
       spirv::ElementwiseOpPattern,
-      ReturnOpPattern, SelectOpPattern, SplatPattern, BranchOpPattern,
-      CondBranchOpPattern>(typeConverter, context);
+      ReturnOpPattern, SelectOpPattern, BranchOpPattern, CondBranchOpPattern>(
+      typeConverter, context);
 }
 
 void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -778,7 +778,7 @@
     auto elemType = vType.getElementType();
     Value zero = rewriter.create<arith::ConstantOp>(
         loc, elemType, rewriter.getZeroAttr(elemType));
-    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
+    Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
     for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
       Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
       Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
@@ -1062,6 +1062,99 @@
   }
 };
 
+/// The Splat operation is lowered to an insertelement + a shufflevector
+/// operation. Only splats to 0-d and 1-d vector result types are lowered here.
+struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = splatOp.getType().cast<VectorType>();
+    if (resultType.getRank() > 1)
+      return failure();
+
+    // First insert it into an undef vector so we can shuffle it.
+    auto vectorType = typeConverter->convertType(splatOp.getType());
+    Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
+    auto zero = rewriter.create<LLVM::ConstantOp>(
+        splatOp.getLoc(),
+        typeConverter->convertType(rewriter.getIntegerType(32)),
+        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
+
+    // For 0-d vector, we simply do `insertelement`.
+    if (resultType.getRank() == 0) {
+      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
+          splatOp, vectorType, undef, adaptor.input(), zero);
+      return success();
+    }
+
+    // For 1-d vector, we additionally do a `vectorshuffle`.
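+    // For a vector<4xf32> result, for example, this pair of steps produces
+    // roughly the following IR (a sketch; SSA names are illustrative, see the
+    // @splat test in vector-to-llvm.mlir):
+    //   %v = llvm.insertelement %elt, %undef[%zero : i32] : vector<4xf32>
+    //   %s = llvm.shufflevector %v, %undef [0 : i32, 0 : i32, 0 : i32, 0 : i32]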
+    auto v = rewriter.create<LLVM::InsertElementOp>(
+        splatOp.getLoc(), vectorType, undef, adaptor.input(), zero);
+
+    int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
+    SmallVector<int32_t> zeroValues(width, 0);
+
+    // Shuffle the value across the desired number of elements.
+    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
+    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
+                                                       zeroAttrs);
+    return success();
+  }
+};
+
+/// The Splat operation is lowered to an insertelement + a shufflevector
+/// operation. Splats to 2+-d vector result types are lowered by
+/// VectorSplatNdOpLowering; the 0-d and 1-d cases are handled by
+/// VectorSplatOpLowering.
+struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = splatOp.getType();
+    if (resultType.getRank() <= 1)
+      return failure();
+
+    // First insert it into an undef vector so we can shuffle it.
+    auto loc = splatOp.getLoc();
+    auto vectorTypeInfo =
+        LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
+    auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
+    auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
+    if (!llvmNDVectorTy || !llvm1DVectorTy)
+      return failure();
+
+    // Construct returned value.
+    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
+
+    // Construct a 1-D vector with the splatted value that we insert in all the
+    // places within the returned descriptor.
+    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
+    auto zero = rewriter.create<LLVM::ConstantOp>(
+        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
+        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
+    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
+                                                     adaptor.input(), zero);
+
+    // Shuffle the value across the desired number of elements.
+    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
+    SmallVector<int32_t> zeroValues(width, 0);
+    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
+    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);
+
+    // Iterate over the linear index, convert to coords space and insert
+    // splatted 1-D vector in each position.
+    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
+      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v,
+                                                  position);
+    });
+    rewriter.replaceOp(splatOp, desc);
+    return success();
+  }
+};
+
 } // namespace
 
 /// Populate the given list with patterns that convert from Vector to LLVM.
@@ -1085,8 +1178,8 @@
            VectorLoadStoreConversion,
            VectorGatherOpConversion, VectorScatterOpConversion,
-           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion>(
-      converter);
+           VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
+           VectorSplatOpLowering, VectorSplatNdOpLowering>(converter);
   // Transfer ops with rank > 1 are handled by VectorToSCF.
   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -418,7 +418,7 @@
     Location loc = xferOp.getLoc();
     auto bufferType = buffer.getType().dyn_cast<MemRefType>();
     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
-    auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
+    auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.padding());
     b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
 
     return Value();
@@ -848,8 +848,8 @@
     if (auto insertOp = getInsertOp(xferOp))
       return insertOp.dest();
     Location loc = xferOp.getLoc();
-    return rewriter.create<SplatOp>(loc, xferOp.getVectorType(),
-                                    xferOp.padding());
+    return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
+                                            xferOp.padding());
   }
 
   /// If the result of the TransferReadOp has exactly one user, which is a
@@ -1136,7 +1136,8 @@
   static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
     // Initialize vector with padding value.
     Location loc = xferOp.getLoc();
-    return b.create<SplatOp>(loc, xferOp.getVectorType(), xferOp.padding());
+    return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
+                                     xferOp.padding());
   }
 };
 
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -247,6 +247,23 @@
   }
 };
 
+class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType dstVecType = op.getType();
+    if (!spirv::CompositeType::isValid(dstVecType))
+      return failure();
+    SmallVector<Value, 4> source(dstVecType.getNumElements(), adaptor.input());
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType,
+                                                             source);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
@@ -255,6 +272,6 @@
       VectorExtractElementOpConvert, VectorExtractOpConvert,
       VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
       VectorInsertElementOpConvert, VectorInsertOpConvert,
-      VectorInsertStridedSliceOpConvert>(typeConverter,
-                                         patterns.getContext());
+      VectorInsertStridedSliceOpConvert, VectorSplatPattern>(
+      typeConverter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -946,35 +946,6 @@
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// SplatOp
-//===----------------------------------------------------------------------===//
-
-static LogicalResult verify(SplatOp op) {
-  // TODO: we could replace this by a trait.
-  if (op.getOperand().getType() !=
-      op.getType().cast<ShapedType>().getElementType())
-    return op.emitError("operand should be of elemental type of result type");
-
-  return success();
-}
-
-// Constant folding hook for SplatOp.
-OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 1 && "splat takes one operand");
-
-  auto constOperand = operands.front();
-  if (!constOperand || !constOperand.isa<IntegerAttr, FloatAttr>())
-    return {};
-
-  auto shapedType = getType().cast<ShapedType>();
-  assert(shapedType.getElementType() == constOperand.getType() &&
-         "incorrect input attribute type for folding");
-
-  // SplatElementsAttr::get treats single value for second arg as being a
-  // splat.
-  return SplatElementsAttr::get(shapedType, {constOperand});
-}
-
 //===----------------------------------------------------------------------===//
 // SwitchOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1735,6 +1735,25 @@
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// SplatOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "splat takes one operand");
+
+  auto constOperand = operands.front();
+  if (!constOperand || !constOperand.isa<IntegerAttr, FloatAttr>())
+    return {};
+
+  TensorType tensorType = getType();
+  assert(tensorType.getElementType() == constOperand.getType() &&
+         "incorrect input attribute type for folding");
+
+  // SplatElementsAttr::get treats single value for second arg as being a
+  // splat.
+  return SplatElementsAttr::get(tensorType, {constOperand});
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp
--- a/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -8,7 +8,6 @@
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -14,7 +14,6 @@
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
@@ -2477,7 +2476,7 @@
     auto splat = op.vector().getDefiningOp<SplatOp>();
     if (!splat)
      return failure();
-    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
+    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.input());
     return success();
   }
 };
@@ -4272,5 +4271,28 @@
                                            patterns.getContext());
 }
 
+//===----------------------------------------------------------------------===//
+// SplatOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "splat takes one operand");
+
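+  // E.g., with a constant operand, `vector.splat %c : vector<4xf32>` where
+  // `%c = arith.constant 1.0 : f32` folds to the dense splat attribute
+  // `dense<1.000000e+00> : vector<4xf32>` (illustrative; exercised by the
+  // splat_fold test in Dialect/Vector/canonicalize.mlir).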
+  auto constOperand = operands.front();
+  if (!constOperand || !constOperand.isa<IntegerAttr, FloatAttr>())
+    return {};
+
+  VectorType vectorType = getType();
+  assert(vectorType.getElementType() == constOperand.getType() &&
+         "incorrect input attribute type for folding");
+
+  // SplatElementsAttr::get treats single value for second arg as being a
+  // splat.
+  return SplatElementsAttr::get(vectorType, {constOperand});
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Vector/VectorOps.cpp.inc"
diff --git a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
@@ -10,8 +10,8 @@
 // transfer_write ops.
 //
 //===----------------------------------------------------------------------===//
+
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorTransforms.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -17,7 +17,6 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/VectorTransforms.h"
@@ -205,7 +204,7 @@
 
     // Scalar to any vector can use splat.
     if (!srcType) {
-      rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, op.source());
+      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.source());
       return success();
     }
 
@@ -220,7 +219,7 @@
         ext = rewriter.create<vector::ExtractElementOp>(loc, op.source());
       else
         ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
-      rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, ext);
+      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
       return success();
     }
 
@@ -1735,7 +1734,7 @@
     // Create vector load op.
     Operation *loadOp;
    if (read.mask()) {
-      Value fill = rewriter.create<SplatOp>(
+      Value fill = rewriter.create<vector::SplatOp>(
           read.getLoc(), unbroadcastedVectorType, read.padding());
       loadOp = rewriter.create<vector::MaskedLoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(),
@@ -2168,12 +2167,13 @@
   // Add in an offset if requested.
   if (off) {
     Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
-    Value ov = rewriter.create<SplatOp>(loc, indices.getType(), o);
+    Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
     indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
   }
 
   // Construct the vector comparison.
   Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
-  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
+  Value bounds =
+      rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
   return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
                                         bounds);
 }
diff --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -15,7 +15,6 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Builders.h"
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -456,36 +456,6 @@
 
 // -----
 
-// CHECK-LABEL: @splat_0d
-// CHECK-SAME: %[[ARG:.*]]: f32
-func @splat_0d(%a: f32) -> vector<f32> {
-  %v = splat %a : vector<f32>
-  return %v : vector<f32>
-}
-// CHECK-NEXT: %[[UNDEF:[0-9]+]] = llvm.mlir.undef : vector<1xf32>
-// CHECK-NEXT: %[[ZERO:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK-NEXT: %[[V:[0-9]+]] = llvm.insertelement %[[ARG]], %[[UNDEF]][%[[ZERO]] : i32] : vector<1xf32>
-// CHECK-NEXT: llvm.return %[[V]] : vector<1xf32>
-
-// -----
-
-// CHECK-LABEL: @splat
-// CHECK-SAME: %[[A:arg[0-9]+]]: vector<4xf32>
-// CHECK-SAME: %[[ELT:arg[0-9]+]]: f32
-func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {
-  %vb = splat %b : vector<4xf32>
-  %r = arith.mulf %a, %vb : vector<4xf32>
-  return %r : vector<4xf32>
-}
-// CHECK-NEXT: %[[UNDEF:[0-9]+]] = llvm.mlir.undef : vector<4xf32>
-// CHECK-NEXT: %[[ZERO:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK-NEXT: %[[V:[0-9]+]] = llvm.insertelement %[[ELT]], %[[UNDEF]][%[[ZERO]] : i32] : vector<4xf32>
-// CHECK-NEXT: %[[SPLAT:[0-9]+]] = llvm.shufflevector %[[V]], %[[UNDEF]] [0 : i32, 0 : i32, 0 : i32, 0 : i32]
-// CHECK-NEXT: %[[SCALE:[0-9]+]] = llvm.fmul %[[A]], %[[SPLAT]] : vector<4xf32>
-// CHECK-NEXT: llvm.return %[[SCALE]] : vector<4xf32>
-
-// -----
-
 // CHECK-LABEL: func @generic_atomic_rmw
 func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) -> i32 {
   %x = generic_atomic_rmw %I[%i] : memref<10xi32> {
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -921,21 +921,6 @@
 
 // -----
 
-//===----------------------------------------------------------------------===//
-// splat
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func @splat
-// CHECK-SAME: (%[[A:.+]]: f32)
-// CHECK: %[[VAL:.+]] = spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
-// CHECK: spv.ReturnValue %[[VAL]]
-func @splat(%f : f32) -> vector<4xf32> {
-  %splat = splat %f : vector<4xf32>
-  return %splat : vector<4xf32>
-}
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // std.br, std.cond_br
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
@@ -5,17 +5,19 @@
 // CMP32-SAME: %[[ARG:.*]]: index)
 // CMP32: %[[T0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi32>
 // CMP32: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i32
-// CMP32: %[[T2:.*]] = splat %[[T1]] : vector<11xi32>
-// CMP32: %[[T3:.*]] = arith.cmpi slt, %[[T0]], %[[T2]] : vector<11xi32>
-// CMP32: return %[[T3]] : vector<11xi1>
+// CMP32: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi32>
+// CMP32: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<11xi32>, vector<11xi32>
+// CMP32: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<11xi32>
+// CMP32: return %[[T4]] : vector<11xi1>
 
 // CMP64-LABEL: @genbool_var_1d(
 // CMP64-SAME: %[[ARG:.*]]: index)
 // CMP64: %[[T0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : vector<11xi64>
 // CMP64: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i64
-// CMP64: %[[T2:.*]] = splat %[[T1]] : vector<11xi64>
-// CMP64: %[[T3:.*]] = arith.cmpi slt, %[[T0]], %[[T2]] : vector<11xi64>
-// CMP64: return %[[T3]] : vector<11xi1>
+// CMP64: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi64>
+// CMP64: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32] : vector<11xi64>, vector<11xi64>
+// CMP64: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<11xi64>
+// CMP64: return %[[T4]] : vector<11xi1>
 
 func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
   %0 = vector.create_mask %arg0 : vector<11xi1>
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -55,8 +55,9 @@
 }
 // CHECK-LABEL: @broadcast_vec0d_from_f32
 // CHECK-SAME: %[[A:.*]]: f32)
-// CHECK: %[[T0:.*]] = splat %[[A]] : vector<f32>
-// CHECK: return %[[T0]] : vector<f32>
+// CHECK: %[[T0:.*]] = llvm.insertelement %[[A]]
+// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<1xf32> to vector<f32>
+// CHECK: return %[[T1]] : vector<f32>
 
 // -----
 
@@ -76,8 +77,9 @@
 }
 // CHECK-LABEL: @broadcast_vec1d_from_f32
 // CHECK-SAME: %[[A:.*]]: f32)
-// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2xf32>
-// CHECK: return %[[T0]] : vector<2xf32>
+// CHECK: %[[T0:.*]] = llvm.insertelement %[[A]]
+// CHECK: %[[T1:.*]] = llvm.shufflevector %[[T0]]
+// CHECK: return %[[T1]] : vector<2xf32>
 
 // -----
 
@@ -87,8 +89,11 @@
 }
 // CHECK-LABEL: @broadcast_vec1d_from_index
 // CHECK-SAME: %[[A:.*]]: index)
-// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2xindex>
-// CHECK: return %[[T0]] : vector<2xindex>
+// CHECK: %[[A1:.*]] = builtin.unrealized_conversion_cast %[[A]] : index to i64
+// CHECK: %[[T0:.*]] = llvm.insertelement %[[A1]]
+// CHECK: %[[T1:.*]] = llvm.shufflevector %[[T0]]
+// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<2xi64> to vector<2xindex>
+// CHECK: return %[[T2]] : vector<2xindex>
 
 // -----
 
@@ -98,8 +103,12 @@
 }
 // CHECK-LABEL: @broadcast_vec2d_from_scalar(
 // CHECK-SAME: %[[A:.*]]: f32)
-// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2x3xf32>
-// CHECK: return %[[T0]] : vector<2x3xf32>
+// CHECK: %[[T0:.*]] = llvm.insertelement %[[A]]
+// CHECK: %[[T1:.*]] = llvm.shufflevector %[[T0]]
+// CHECK: %[[T2:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
+// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
+// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32>
+// CHECK: return %[[T4]] : vector<2x3xf32>
 
 // -----
 
@@ -109,8 +118,13 @@
 }
 // CHECK-LABEL: @broadcast_vec3d_from_scalar(
 // CHECK-SAME: %[[A:.*]]: f32)
-// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2x3x4xf32>
-// CHECK: return %[[T0]] : vector<2x3x4xf32>
+// CHECK: %[[T0:.*]] = llvm.insertelement %[[A]]
+// CHECK: %[[T1:.*]] = llvm.shufflevector %[[T0]]
+// CHECK: %[[T2:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[0, 0] : !llvm.array<2 x array<3 x vector<4xf32>>>
+// ...
+// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[1, 2] : !llvm.array<2 x array<3 x vector<4xf32>>>
+// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : !llvm.array<2 x array<3 x vector<4xf32>>> to vector<2x3x4xf32>
+// CHECK: return %[[T4]] : vector<2x3x4xf32>
 
 // -----
 
@@ -135,7 +149,8 @@
 // CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
 // CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : i64
 // CHECK: %[[T5:.*]] = llvm.extractelement %[[T0]][%[[T4]] : i64] : vector<1xf32>
-// CHECK: %[[T6:.*]] = splat %[[T5]] : vector<2xf32>
+// CHECK: %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
+// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]
 // CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
 // CHECK: %[[T8:.*]] = llvm.insertvalue %[[T6]], %[[T7]][1] : !llvm.array<3 x vector<2xf32>>
 // CHECK: %[[T9:.*]] = llvm.insertvalue %[[T6]], %[[T8]][2] : !llvm.array<3 x vector<2xf32>>
@@ -228,8 +243,9 @@
 // CHECK-SAME: %[[A:.*]]: vector<1xf32>)
 // CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK: %[[T2:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T1]] : i64] : vector<1xf32>
-// CHECK: %[[T3:.*]] = splat %[[T2]] : vector<4xf32>
-// CHECK: return %[[T3]] : vector<4xf32>
+// CHECK: %[[T3:.*]] = llvm.insertelement %[[T2]]
+// CHECK: %[[T4:.*]] = llvm.shufflevector %[[T3]]
+// CHECK: return %[[T4]] : vector<4xf32>
 
 // -----
 
@@ -263,22 +279,26 @@
 // CHECK: %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm.array<4 x vector<1xf32>>
 // CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK: %[[T5:.*]] = llvm.extractelement %[[T3]]{{\[}}%[[T4]] : i64] : vector<1xf32>
-// CHECK: %[[T6:.*]] = splat %[[T5]] : vector<3xf32>
+// CHECK: %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
+// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]
 // CHECK: %[[T8:.*]] = llvm.insertvalue %[[T6]], %[[T7]][0] : !llvm.array<4 x vector<3xf32>>
 // CHECK: %[[T10:.*]] = llvm.extractvalue %[[T2]][1] : !llvm.array<4 x vector<1xf32>>
 // CHECK: %[[T11:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK: %[[T12:.*]] = llvm.extractelement %[[T10]]{{\[}}%[[T11]] : i64] : vector<1xf32>
-// CHECK: %[[T13:.*]] = splat %[[T12]] : vector<3xf32>
+// CHECK: %[[T13Insert:.*]] = llvm.insertelement %[[T12]]
+// CHECK: %[[T13:.*]] = llvm.shufflevector %[[T13Insert]]
 // CHECK: %[[T14:.*]] = llvm.insertvalue %[[T13]], %[[T8]][1] : !llvm.array<4 x vector<3xf32>>
 // CHECK: %[[T16:.*]] = llvm.extractvalue %[[T2]][2] : !llvm.array<4 x vector<1xf32>>
 // CHECK: %[[T17:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK: %[[T18:.*]] = llvm.extractelement %[[T16]]{{\[}}%[[T17]] : i64] : vector<1xf32>
-// CHECK: %[[T19:.*]] = splat %[[T18]] : vector<3xf32>
+// CHECK: %[[T19Insert:.*]] = llvm.insertelement %[[T18]]
+// CHECK: %[[T19:.*]] = llvm.shufflevector %[[T19Insert]]
 // CHECK: %[[T20:.*]] = llvm.insertvalue %[[T19]], %[[T14]][2] : !llvm.array<4 x vector<3xf32>>
 // CHECK: %[[T22:.*]] = llvm.extractvalue %[[T2]][3] : !llvm.array<4 x vector<1xf32>>
 // CHECK: %[[T23:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK: %[[T24:.*]] = llvm.extractelement %[[T22]]{{\[}}%[[T23]] : i64] : vector<1xf32>
-// CHECK: %[[T25:.*]] = splat %[[T24]] : vector<3xf32>
+// CHECK: %[[T25Insert:.*]] = llvm.insertelement %[[T24]]
+// CHECK: %[[T25:.*]] = llvm.shufflevector %[[T25Insert]]
 // CHECK: %[[T26:.*]] = llvm.insertvalue %[[T25]], %[[T20]][3] : !llvm.array<4 x vector<3xf32>>
 // CHECK: %[[T27:.*]] = builtin.unrealized_conversion_cast %[[T26]] : !llvm.array<4 x vector<3xf32>> to vector<4x3xf32>
 // CHECK: return %[[T27]] : vector<4x3xf32>
 
@@ -332,12 +352,14 @@
 // CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>>
 // CHECK: %[[T3:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK: %[[T4:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T3]] : i64] : vector<2xf32>
-// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<3xf32>
+// CHECK: %[[T5Insert:.*]] = llvm.insertelement %[[T4]]
+// CHECK: %[[T5:.*]] = llvm.shufflevector %[[T5Insert]]
 // CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
 // CHECK: %[[T8:.*]] = llvm.insertvalue %[[T6]], %[[T7]][0] : !llvm.array<2 x vector<3xf32>>
 // CHECK: %[[T9:.*]] = llvm.mlir.constant(1 : i64) : i64
 // CHECK: %[[T10:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T9]] : i64] : vector<2xf32>
-// CHECK: %[[T11:.*]] = splat %[[T10]] : vector<3xf32>
+// CHECK: %[[T11Insert:.*]] = llvm.insertelement %[[T10]]
+// CHECK: %[[T11:.*]] = llvm.shufflevector %[[T11Insert]]
 // CHECK: %[[T12:.*]] = arith.mulf %[[T11]], %[[B]] : vector<3xf32>
 // CHECK: %[[T13:.*]] = llvm.insertvalue %[[T12]], %[[T8]][1] : !llvm.array<2 x vector<3xf32>>
 // CHECK: %[[T14:.*]] = builtin.unrealized_conversion_cast %[[T13]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32>
 
@@ -357,9 +379,10 @@
 // CHECK: %[[T8:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<2x3xindex> to !llvm.array<2 x vector<3xi64>>
 // CHECK: %[[T2:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK: %[[T3:.*]] = llvm.extractelement %[[T1]]{{\[}}%[[T2]] : i64] : vector<2xi64>
-// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : i64 to index
-// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<3xindex>
-// CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xindex>
+// CHECK: %[[T4:.*]] = llvm.insertelement %[[T3]]
+// CHECK: %[[T5:.*]] = llvm.shufflevector %[[T4]]
+// CHECK: %[[T5Cast:.*]] = builtin.unrealized_conversion_cast %[[T5]] : vector<3xi64> to vector<3xindex>
+// CHECK: %[[T6:.*]] = arith.muli %[[T5Cast]], %[[B]] : vector<3xindex>
 // CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T6]] : vector<3xindex> to vector<3xi64>
 // CHECK: %{{.*}} = llvm.insertvalue %[[T7]], %[[T8]][0] : !llvm.array<2 x vector<3xi64>>
 
@@ -378,13 +401,15 @@
 // CHECK: %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T3]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>>
 // CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK: %[[T5:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T4]] : i64] : vector<2xf32>
-// CHECK: %[[T6:.*]] = splat %[[T5]] : vector<3xf32>
+// CHECK: %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
+// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]
 // CHECK: %[[T8:.*]] = llvm.extractvalue %[[T7]][0] : !llvm.array<2 x vector<3xf32>>
 // CHECK: %[[T9:.*]] = "llvm.intr.fmuladd"(%[[T6]], %[[B]], %[[T8]]) : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> vector<3xf32>
 // CHECK: %[[T11:.*]] = llvm.insertvalue %[[T9]], %[[T10]][0] : !llvm.array<2 x vector<3xf32>>
 // CHECK: %[[T12:.*]] = llvm.mlir.constant(1 : i64) : i64
 // CHECK: %[[T13:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T12]] : i64] : vector<2xf32>
-// CHECK: %[[T14:.*]] = splat %[[T13]] : vector<3xf32>
+// CHECK: %[[T14Insert:.*]] = llvm.insertelement %[[T13]]
+// CHECK: %[[T14:.*]] = llvm.shufflevector %[[T14Insert]]
 // CHECK: %[[T16:.*]] = llvm.extractvalue %[[T7]][1] : !llvm.array<2 x vector<3xf32>>
 // CHECK: %[[T17:.*]] = "llvm.intr.fmuladd"(%[[T14]], %[[B]], %[[T16]]) : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> vector<3xf32>
 // CHECK: %[[T18:.*]] = llvm.insertvalue %[[T17]], %[[T11]][1] : !llvm.array<2 x vector<3xf32>>
 
@@ -986,8 +1011,7 @@
 // CHECK-LABEL: @extract_strided_slice3(
 // CHECK-SAME: %[[ARG:.*]]: vector<4x8xf32>)
 // CHECK: %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x8xf32> to !llvm.array<4 x vector<8xf32>>
-// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_2:.*]] = splat %[[VAL_1]] : vector<2x2xf32>
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
 // CHECK: %[[VAL_6:.*]] = builtin.unrealized_conversion_cast %[[VAL_2]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
 // CHECK: %[[T2:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<8xf32>>
 // CHECK: %[[T3:.*]] = llvm.shufflevector %[[T2]], %[[T2]] [2, 3] : vector<8xf32>, vector<8xf32>
 
@@ -1233,17 +1257,19 @@
 //
 // 2. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
 // CHECK: %[[otrunc:.*]] = arith.index_cast %[[BASE]] : index to i32
-// CHECK: %[[offsetVec:.*]] = splat %[[otrunc]] : vector<17xi32>
+// CHECK: %[[offsetVecInsert:.*]] = llvm.insertelement %[[otrunc]]
+// CHECK: %[[offsetVec:.*]] = llvm.shufflevector %[[offsetVecInsert]]
 // CHECK: %[[offsetVec2:.*]] = arith.addi %[[offsetVec]], %[[linearIndex]] : vector<17xi32>
 //
 // 3. Let dim the memref dimension, compute the vector comparison mask:
 //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
 // CHECK: %[[dtrunc:.*]] = arith.index_cast %[[DIM]] : index to i32
-// CHECK: %[[dimVec:.*]] = splat %[[dtrunc]] : vector<17xi32>
+// CHECK: %[[dimVecInsert:.*]] = llvm.insertelement %[[dtrunc]]
+// CHECK: %[[dimVec:.*]] = llvm.shufflevector %[[dimVecInsert]]
 // CHECK: %[[mask:.*]] = arith.cmpi slt, %[[offsetVec2]], %[[dimVec]] : vector<17xi32>
 //
 // 4. Create pass-through vector.
-// CHECK: %[[PASS_THROUGH:.*]] = splat %[[c7]] : vector<17xf32>
+// CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<17xf32>
 //
 // 5. Bitcast to vector form.
 // CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
 
@@ -1262,12 +1288,12 @@
 // CHECK-SAME: vector<17xi32>
 //
 // 2. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-// CHECK: splat %{{.*}} : vector<17xi32>
+// CHECK: llvm.shufflevector %{{.*}} : vector<17xi32>
 // CHECK: arith.addi
 //
 // 3. Let dim the memref dimension, compute the vector comparison mask:
 //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
-// CHECK: splat %{{.*}} : vector<17xi32>
+// CHECK: llvm.shufflevector %{{.*}} : vector<17xi32>
 // CHECK: %[[mask_b:.*]] = arith.cmpi slt, {{.*}} : vector<17xi32>
 //
 // 4. Bitcast to vector form.
 
@@ -1295,8 +1321,7 @@
 }
 // CHECK-LABEL: func @transfer_read_index_1d
 // CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xindex>
-// CHECK: %[[C7:.*]] = arith.constant 7 : index
-// CHECK: %[[SPLAT:.*]] = splat %[[C7]] : vector<17xindex>
+// CHECK: %[[SPLAT:.*]] = arith.constant dense<7> : vector<17xindex>
 // CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[SPLAT]] : vector<17xindex> to vector<17xi64>
 
 // CHECK: %[[loaded:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} :
 
@@ -1321,12 +1346,14 @@
 //
 // Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
 // CHECK: %[[trunc:.*]] = arith.index_cast %[[BASE_1]] : index to i32
-// CHECK: %[[offsetVec:.*]] = splat %[[trunc]] : vector<17xi32>
+// CHECK: %[[offsetVecInsert:.*]] = llvm.insertelement %[[trunc]]
+// CHECK: %[[offsetVec:.*]] = llvm.shufflevector %[[offsetVecInsert]]
 //
 // Let dim the memref dimension, compute the vector comparison mask:
 //    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
 // CHECK: %[[dimtrunc:.*]] = arith.index_cast %[[DIM]] : index to i32
-// CHECK: splat %[[dimtrunc]] : vector<17xi32>
+// CHECK: %[[dimtruncInsert:.*]] = llvm.insertelement %[[dimtrunc]]
+// CHECK: llvm.shufflevector %[[dimtruncInsert]]
 
 // -----
 
@@ -1451,9 +1478,11 @@
 // CHECK-SAME: %[[arg:.*]]: index
 // CHECK: %[[indices:.*]] = arith.constant dense<0> : vector<i32>
 // CHECK: %[[arg_i32:.*]] = arith.index_cast %[[arg]] : index to i32
-// CHECK: %[[bounds:.*]] = splat %[[arg_i32]] : vector<i32>
-// CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<i32>
+// CHECK: %[[bounds:.*]] = llvm.insertelement %[[arg_i32]]
+// CHECK: %[[boundsCast:.*]] = builtin.unrealized_conversion_cast %[[bounds]] : vector<1xi32> to vector<i32>
+// CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[boundsCast]] : vector<i32>
 // CHECK: return %[[result]] : vector<i1>
+
 // -----
 
 func @create_mask_1d(%a : index) -> vector<4xi1> {
 
@@ -1465,7 +1494,8 @@
 // CHECK-SAME: %[[arg:.*]]: index
 // CHECK: %[[indices:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
 // CHECK: %[[arg_i32:.*]] = arith.index_cast %[[arg]] : index to i32
-// CHECK: %[[bounds:.*]] = splat %[[arg_i32]] : vector<4xi32>
+// CHECK: %[[boundsInsert:.*]] = llvm.insertelement %[[arg_i32]]
+// CHECK: %[[bounds:.*]] = llvm.shufflevector %[[boundsInsert]]
 // CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<4xi32>
 // CHECK: return %[[result]] : vector<4xi1>
 
@@ -1728,3 +1758,34 @@
 }
 // CHECK-LABEL: func @compress_store_op_index
 // CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %{{.*}}, %{{.*}}) : (vector<11xi64>, !llvm.ptr<i64>, vector<11xi1>) -> ()
+
+// -----
+
+// CHECK-LABEL: @splat_0d
+// CHECK-SAME: %[[ARG:.*]]: f32
+func @splat_0d(%a: f32) -> vector<f32> {
+  %v = vector.splat %a : vector<f32>
+  return %v : vector<f32>
+}
+// CHECK-NEXT: %[[UNDEF:[0-9]+]] = llvm.mlir.undef : vector<1xf32>
+// CHECK-NEXT: %[[ZERO:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-NEXT: %[[V:[0-9]+]] = llvm.insertelement %[[ARG]], %[[UNDEF]][%[[ZERO]] : i32] : vector<1xf32>
+// CHECK-NEXT: %[[VCAST:[0-9]+]] = builtin.unrealized_conversion_cast %[[V]] : vector<1xf32> to vector<f32>
+// CHECK-NEXT: return %[[VCAST]] : vector<f32>
+
+// -----
+
+// CHECK-LABEL: @splat
+// CHECK-SAME: %[[A:arg[0-9]+]]: vector<4xf32>
+// CHECK-SAME: %[[ELT:arg[0-9]+]]: f32
+func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {
+  %vb = vector.splat %b : vector<4xf32>
+  %r = arith.mulf %a, %vb : vector<4xf32>
+  return %r : vector<4xf32>
+}
+// CHECK-NEXT: %[[UNDEF:[0-9]+]] = llvm.mlir.undef : vector<4xf32>
+// CHECK-NEXT: %[[ZERO:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK-NEXT: %[[V:[0-9]+]] = llvm.insertelement %[[ELT]], %[[UNDEF]][%[[ZERO]] : i32] : vector<4xf32>
+// CHECK-NEXT: %[[SPLAT:[0-9]+]] = llvm.shufflevector %[[V]], %[[UNDEF]] [0 : i32, 0 : i32, 0 : i32, 0 : i32]
+// CHECK-NEXT: %[[SCALE:[0-9]+]] = arith.mulf %[[A]], %[[SPLAT]] : vector<4xf32>
+// CHECK-NEXT: return %[[SCALE]] : vector<4xf32>
diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -168,3 +168,14 @@
   %0 = vector.fma %a, %b, %c: vector<4xf32>
   return %0 : vector<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @splat
+// CHECK-SAME: (%[[A:.+]]: f32)
+// CHECK: %[[VAL:.+]] = spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
+// CHECK: return %[[VAL]]
+func @splat(%f : f32) -> vector<4xf32> {
+  %splat = vector.splat %f : vector<4xf32>
+  return %splat : vector<4xf32>
+}
diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -50,10 +50,3 @@
 ^bb3(%bb3arg : i32):
   return
 }
-
-// CHECK-LABEL: func @vector_splat_0d(
-func @vector_splat_0d(%a: f32) -> vector<f32> {
-  // CHECK: splat %{{.*}} : vector<f32>
-  %0 = splat %a : vector<f32>
-  return %0 : vector<f32>
-}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1178,3 +1178,15 @@
 
   return %0 : tensor<2x3x4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @splat_fold
+func @splat_fold() -> tensor<4xf32> {
+  %c = arith.constant 1.0 : f32
+  %t = tensor.splat %c : tensor<4xf32>
+  return %t : tensor<4xf32>
+
+  // CHECK-NEXT: [[T:%.*]] = arith.constant dense<1.000000e+00> : tensor<4xf32>
+  // CHECK-NEXT: return [[T]] : tensor<4xf32>
+}
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -363,3 +363,18 @@
   return %0 : tensor
 }
 
+// -----
+
+func @invalid_splat(%v : f32) {
+  // expected-error@+1 {{invalid kind of type specified}}
+  tensor.splat %v : memref<8xf32>
+  return
+}
+
+// -----
+
+func @invalid_splat(%v : vector<8xf32>) {
+  // expected-error@+1 {{must be integer/index/float type}}
+  %w = tensor.splat %v : tensor<8xvector<8xf32>>
+  return
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -250,3 +250,13 @@
 
 // -----
 
+// CHECK-LABEL: func @test_splat_op
+// CHECK-SAME: [[S:%arg[0-9]+]]: f32
+func @test_splat_op(%s : f32) {
+  // CHECK: tensor.splat [[S]] : tensor<8xf32>
+  %v = tensor.splat %s : tensor<8xf32>
+
+  // CHECK: tensor.splat [[S]] : tensor<4xf32>
+  %u = "tensor.splat"(%s) : (f32) -> tensor<4xf32>
+  return
+}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -515,7 +515,7 @@
 // CHECK-SAME: %[[A:.*]]: f32
 // CHECK: return %[[A]] : f32
 func @fold_extract_splat(%a : f32) -> f32 {
-  %b = splat %a : vector<1x2x4xf32>
+  %b = vector.splat %a : vector<1x2x4xf32>
   %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32>
   return %r : f32
 }
 
@@ -1121,10 +1121,10 @@
 // -----
 
 // CHECK-LABEL: extract_strided_splat
-// CHECK: %[[B:.*]] = splat %{{.*}} : vector<2x4xf16>
+// CHECK: %[[B:.*]] = vector.splat %{{.*}} : vector<2x4xf16>
 // CHECK-NEXT: return %[[B]] : vector<2x4xf16>
 func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
-  %0 = splat %arg0 : vector<16x4xf16>
+  %0 = vector.splat %arg0 : vector<16x4xf16>
   %1 = vector.extract_strided_slice %0
     {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]}
       : vector<16x4xf16> to vector<2x4xf16>
 
@@ -1242,3 +1242,15 @@
   %1 = vector.extract %0[0] : vector<1x4xf32>
   return %1 : vector<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @splat_fold
+func @splat_fold() -> vector<4xf32> {
+  %c = arith.constant 1.0 : f32
+  %v = vector.splat %c : vector<4xf32>
+  return %v : vector<4xf32>
+
+  // CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
+  // CHECK-NEXT: return [[V]] : vector<4xf32>
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -300,7 +300,7 @@
 func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = splat %f0 : vector<4x3xf32>
+  %vf0 = vector.splat %f0 : vector<4x3xf32>
   // expected-error@+1 {{ requires memref or ranked tensor type}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : vector<4x3xf32>, vector<1x1x2x3xf32>
 }
 
@@ -310,7 +310,7 @@
 func @test_vector.transfer_read(%arg0: memref<4x3xf32>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = splat %f0 : vector<4x3xf32>
+  %vf0 = vector.splat %f0 : vector<4x3xf32>
   // expected-error@+1 {{ requires vector type}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<4x3xf32>, f32
 }
 
@@ -376,7 +376,7 @@
   %c3 = arith.constant 3 : index
   %cst = arith.constant 3.0 : f32
   // expected-note@+1 {{prior use here}}
-  %mask = splat %c1 : vector<3x8x7xi1>
+  %mask = vector.splat %c1 : vector<3x8x7xi1>
   // expected-error@+1 {{expects different type than prior uses: 'vector<3x7xi1>' vs 'vector<3x8x7xi1>'}}
   %0 = vector.transfer_read %arg0[%c3, %c3, %c3], %cst, %mask {permutation_map = affine_map<(d0, d1, d2)->(d0, 0, d2)>} : memref<?x?x?xf32>, vector<3x8x7xf32>
 }
 
@@ -386,7 +386,7 @@
 func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = splat %f0 : vector<4x3xf32>
+  %vf0 = vector.splat %f0 : vector<4x3xf32>
   // expected-error@+1 {{requires source vector element and vector result ranks to match}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<3xf32>
 }
 
@@ -396,7 +396,7 @@
 func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = splat %f0 : vector<6xf32>
+  %vf0 = vector.splat %f0 : vector<6xf32>
   // expected-error@+1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the source}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<?x?xvector<6xf32>>, vector<3xf32>
 }
 
@@ -406,7 +406,7 @@
 func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = splat %f0 : vector<2x3xf32>
+  %vf0 = vector.splat %f0 : vector<2x3xf32>
   // expected-error@+1 {{ expects the optional in_bounds attr of same rank as permutation_map results: affine_map<(d0, d1) -> (d0, d1)>}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {in_bounds = [true], permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
 }
 
@@ -416,7 +416,7 @@
 func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = splat %f0 : vector<2x3xf32>
+  %vf0 = vector.splat %f0 : vector<2x3xf32>
   // expected-error@+1 {{requires broadcast dimensions to be in-bounds}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {in_bounds = [false, true], permutation_map = affine_map<(d0, d1)->(0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
 }
 
@@ -426,8 +426,8 @@
 func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = splat %f0 : vector<2x3xf32>
-  %mask = splat %c1 : vector<2x3xi1>
+  %vf0 = vector.splat %f0 : vector<2x3xf32>
+  %mask = vector.splat %c1 : vector<2x3xi1>
   // expected-error@+1 {{does not support masks with vector element type}}
   %0 = vector.transfer_read %arg0[%c3, %c3], %vf0, %mask {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
 }
 
@@ -446,7 +446,7 @@
 func @test_vector.transfer_write(%arg0: memref<?x?xvector<4x3xf32>>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = splat %f0 : vector<4x3xf32>
+  %vf0 = vector.splat %f0 : vector<4x3xf32>
   // expected-error@+1 {{ requires vector type}}
   vector.transfer_write %arg0, %arg0[%c3, %c3] : memref<?x?xvector<4x3xf32>>, vector<4x3xf32>
 }
 
@@ -456,7 +456,7 @@
 func @test_vector.transfer_write(%arg0: vector<4x3xf32>) {
   %c3 = arith.constant 3 : index
   %f0 = arith.constant 0.0 : f32
-  %vf0 = splat %f0 : vector<4x3xf32>
+  %vf0 = vector.splat %f0 : vector<4x3xf32>
   // expected-error@+1 {{ requires memref or ranked tensor type}}
   vector.transfer_write %arg0, %arg0[%c3, %c3] : vector<4x3xf32>, f32
 }
 
@@ -1480,3 +1480,11 @@
 
   %0 = vector.insert_map %v, %v1[%id] : vector<2x1xf32> into vector<4x32xf32>
 }
+
+// -----
+
+func @invalid_splat(%v : f32) {
+  // expected-error@+1 {{invalid kind of type specified}}
+  vector.splat %v : memref<8xf32>
+  return
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -45,11 +45,11 @@
   %i0 = arith.constant 0 : index
   %i1 = arith.constant 1 : i1
 
-  %vf0 = splat %f0 : vector<4x3xf32>
-  %v0 = splat %c0 : vector<4x3xi32>
-  %vi0 = splat %i0 : vector<4x3xindex>
+  %vf0 = vector.splat %f0 : vector<4x3xf32>
+  %v0 = vector.splat %c0 : vector<4x3xi32>
+  %vi0 = vector.splat %i0 : vector<4x3xindex>
   %m = arith.constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
-  %m2 = splat %i1 : vector<5x4xi1>
+  %m2 = vector.splat %i1 : vector<5x4xi1>
   //
   // CHECK: vector.transfer_read
   %0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d0)>} : memref<?x?xf32>, vector<128xf32>
 
@@ -106,9 +106,9 @@
   %c0 = arith.constant 0 : i32
   %i0 = arith.constant 0 : index
 
-  %vf0 = splat %f0 : vector<4x3xf32>
-  %v0 = splat %c0 : vector<4x3xi32>
-  %vi0 = splat %i0 : vector<4x3xindex>
+  %vf0 = vector.splat %f0 : vector<4x3xf32>
+  %v0 = vector.splat %c0 : vector<4x3xi32>
+  %vi0 = vector.splat %i0 : vector<4x3xindex>
 
   //
   // CHECK: vector.transfer_read
 
@@ -717,3 +717,21 @@
   %0 = vector.vscale
   return %0 : index
 }
+
+// CHECK-LABEL: func @test_splat_op
+// CHECK-SAME: [[S:%arg[0-9]+]]: f32
+func @test_splat_op(%s : f32) {
+  // CHECK: vector.splat [[S]] : vector<8xf32>
+  %v = vector.splat %s : vector<8xf32>
+
+  // CHECK: vector.splat [[S]] : vector<4xf32>
+  %u = "vector.splat"(%s) : (f32) -> vector<4xf32>
+  return
+}
+
+// CHECK-LABEL: func @vector_splat_0d(
+func @vector_splat_0d(%a: f32) -> vector<f32> {
+  // CHECK: vector.splat %{{.*}} : vector<f32>
+  %0 = vector.splat %a : vector<f32>
+  return %0 : vector<f32>
+}
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -268,11 +268,11 @@
 // CHECK-SAME: %[[B:.*1]]: vector<3xf32>
 // CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
 // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32>
-// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xf32>
+// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
 // CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32>
 // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
 // CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xf32>
-// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<3xf32>
+// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32>
 // CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
 // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
 // CHECK: return %[[T7]] : vector<2x3xf32>
@@ -289,12 +289,12 @@
 // CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
 // CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
 // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32>
-// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xf32>
+// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
 // CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xf32>
 // CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
 // CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
 // CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2xf32>
-// CHECK: %[[T6:.*]] = splat %[[T5]] : vector<3xf32>
+// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32>
 // CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<2x3xf32>
 // CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
 // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
@@ -312,11 +312,11 @@
 // CHECK-SAME: %[[B:.*1]]: vector<3xi32>
 // CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
 // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32>
-// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xi32>
+// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
 // CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
 // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
 // CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xi32>
-// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<3xi32>
+// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32>
 // CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
 // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32>
 // CHECK: return %[[T7]] : vector<2x3xi32>
@@ -332,13 +332,13 @@
 // CHECK-SAME: %[[C:.*2]]: vector<2x3xi32>
 // CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
 // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32>
-// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xi32>
+// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
 // CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xi32>
 // CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
 // CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32>
 // CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
 // CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xi32>
-// CHECK: %[[T7:.*]] = splat %[[T6]] : vector<3xi32>
+// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32>
 // CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<2x3xi32>
 // CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32>
 // CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32>
@@ -354,7 +354,7 @@
 // CHECK-LABEL: func @axpy_fp(
 // CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
 // CHECK-SAME: %[[B:.*1]]: f32)
-// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
 // CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32>
 // CHECK: return %[[T1]] : vector<16xf32>
 func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
@@ -366,7 +366,7 @@
 // CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
 // CHECK-SAME: %[[B:.*1]]: f32,
 // CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
-// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
 // CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
 // CHECK: return %[[T1]] : vector<16xf32>
 func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> {
@@ -377,7 +377,7 @@
 // CHECK-LABEL: func @axpy_int(
 // CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
 // CHECK-SAME: %[[B:.*1]]: i32)
-// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xi32>
+// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
 // CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
 // CHECK: return %[[T1]] : vector<16xi32>
 func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
@@ -389,7 +389,7 @@
 // CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
 // CHECK-SAME: %[[B:.*1]]: i32,
 // CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
-// CHECK: %[[T0:.*]] = splat %[[B]] : vector<16xi32>
+// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
 // CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
 // CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32>
 // CHECK: return %[[T2]] : vector<16xi32>
@@ -612,7 +612,7 @@
 // CHECK-LABEL: func @broadcast_vec1d_from_scalar
 // CHECK-SAME: %[[A:.*0]]: f32
-// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2xf32>
+// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32>
 // CHECK: return %[[T0]] : vector<2xf32>

 func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> {
@@ -622,7 +622,7 @@
 // CHECK-LABEL: func @broadcast_vec2d_from_scalar
 // CHECK-SAME: %[[A:.*0]]: f32
-// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2x3xf32>
+// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32>
 // CHECK: return %[[T0]] : vector<2x3xf32>

 func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
@@ -632,7 +632,7 @@
 // CHECK-LABEL: func @broadcast_vec3d_from_scalar
 // CHECK-SAME: %[[A:.*0]]: f32
-// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2x3x4xf32>
+// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32>
 // CHECK: return %[[T0]] : vector<2x3x4xf32>

 func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
@@ -697,7 +697,7 @@
 // CHECK-LABEL: func @broadcast_stretch
 // CHECK-SAME: %[[A:.*0]]: vector<1xf32>
 // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<1xf32>
-// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<4xf32>
+// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32>
 // CHECK: return %[[T1]] : vector<4xf32>

 func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
@@ -723,16 +723,16 @@
 // CHECK-SAME: %[[A:.*0]]: vector<4x1xf32>
 // CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3xf32>
 // CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1xf32>
-// CHECK: %[[T2:.*]] = splat %[[T0]] : vector<3xf32>
+// CHECK: %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32>
 // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<4x3xf32>
 // CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<4x1xf32>
-// CHECK: %[[T6:.*]] = splat %[[T4]] : vector<3xf32>
+// CHECK: %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32>
 // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32>
 // CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<4x1xf32>
-// CHECK: %[[T10:.*]] = splat %[[T8]] : vector<3xf32>
+// CHECK: %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32>
 // CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32>
 // CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : vector<4x1xf32>
-// CHECK: %[[T14:.*]] = splat %[[T12]] : vector<3xf32>
+// CHECK: %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32>
 // CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32>
 // CHECK: return %[[T15]] : vector<4x3xf32>
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -282,19 +282,19 @@
   %c0 = arith.constant 0 : index
   %m = arith.constant 1 : i1

-  %mask0 = splat %m : vector<7x14xi1>
+  %mask0 = vector.splat %m : vector<7x14xi1>
   %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask0 {in_bounds = [true, false, true, true], permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
   // CHECK: %[[MASK0:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1>
   // CHECK: vector.transfer_read {{.*}} %[[MASK0]] {in_bounds = [false, true, true, true], permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32>
   // CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32>

-  %mask1 = splat %m : vector<14x16xi1>
+  %mask1 = vector.splat %m : vector<14x16xi1>
   %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask1 {permutation_map = #map1} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
   // CHECK: %[[MASK1:.*]] = vector.transpose {{.*}} : vector<14x16xi1> to vector<16x14xi1>
   // CHECK: vector.transfer_read {{.*}} %[[MASK1]] {permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<16x14x7x8xf32>
   // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>

-  %mask2 = splat %m : vector<7x14xi1>
+  %mask2 = vector.splat %m : vector<7x14xi1>
   %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, false, true, true], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
   // CHECK: %[[MASK2:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1>
   // CHECK: vector.transfer_read {{.*}} %[[MASK2]] {in_bounds = [false, true, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32>
@@ -333,7 +333,7 @@
   %c0 = arith.constant 0 : index
   %m = arith.constant 1 : i1

-  %mask0 = splat %m : vector<7x14x8x16xi1>
+  %mask0 = vector.splat %m : vector<7x14x8x16xi1>
   vector.transfer_write %v1, %arg0[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, memref<?x?x?x?xf32>
   // CHECK: %[[NEW_MASK0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xi1> to vector<8x14x16x7xi1>
   // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xf32> to vector<8x14x16x7xf32>
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -295,18 +295,6 @@
   return
 }

-// CHECK-LABEL: func @test_splat_op
-// CHECK-SAME: [[S:%arg[0-9]+]]: f32
-func @test_splat_op(%s : f32) {
-  %v = splat %s : vector<8xf32>
-  // CHECK: splat [[S]] : vector<8xf32>
-  %t = splat %s : tensor<8xf32>
-  // CHECK: splat [[S]] : tensor<8xf32>
-  %u = "std.splat"(%s) : (f32) -> vector<4xf32>
-  // CHECK: splat [[S]] : vector<4xf32>
-  return
-}
-
 // CHECK-LABEL: func @tensor_load_store
 func @tensor_load_store(%0 : memref<4x4xi32>, %1 : tensor<4x4xi32>) {
   // CHECK-SAME: (%[[MEMREF:.*]]: memref<4x4xi32>,
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -106,24 +106,8 @@

 // -----

-func @invalid_splat(%v : f32) {
-  splat %v : memref<8xf32>
-  // expected-error@-1 {{must be vector of any type values or statically shaped tensor of any type values}}
-  return
-}
-
-// -----
-
-func @invalid_splat(%v : vector<8xf32>) {
-  %w = splat %v : tensor<8xvector<8xf32>>
-  // expected-error@-1 {{must be integer/index/float type}}
-  return
-}
-
-// -----
-
 func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
-  splat %v : vector<8xf64>
+  vector.splat %v : vector<8xf64>
   // expected-error@-1 {{expects different type than prior uses}}
   return
 }
diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -789,18 +789,6 @@
   return
 }

-// CHECK-LABEL: func @splat_fold
-func @splat_fold() -> (vector<4xf32>, tensor<4xf32>) {
-  %c = arith.constant 1.0 : f32
-  %v = splat %c : vector<4xf32>
-  %t = splat %c : tensor<4xf32>
-  return %v, %t : vector<4xf32>, tensor<4xf32>
-
-  // CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
-  // CHECK-NEXT: [[T:%.*]] = arith.constant dense<1.000000e+00> : tensor<4xf32>
-  // CHECK-NEXT: return [[V]], [[T]] : vector<4xf32>, tensor<4xf32>
-}
-
 // -----

 // CHECK-LABEL: func @subview_scalar_fold
diff --git a/mlir/test/mlir-cpu-runner/utils.mlir b/mlir/test/mlir-cpu-runner/utils.mlir
--- a/mlir/test/mlir-cpu-runner/utils.mlir
+++ b/mlir/test/mlir-cpu-runner/utils.mlir
@@ -56,7 +56,7 @@
 func @vector_splat_2d() {
   %c0 = arith.constant 0 : index
   %f10 = arith.constant 10.0 : f32
-  %vf10 = splat %f10: !vector_type_C
+  %vf10 = vector.splat %f10: !vector_type_C
   %C = memref.alloc() : !matrix_type_CC
   memref.store %vf10, %C[%c0, %c0]: !matrix_type_CC