diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -470,6 +470,19 @@
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts std.splat to spv.CompositeConstruct.
+///
+/// Only vector results are handled. The result type is mapped through the
+/// pattern's type converter so it agrees with the converted scalar operand.
+class SplatPattern final : public OpConversionPattern<SplatOp> {
+public:
+  using OpConversionPattern<SplatOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SplatOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts memref.store to spv.Store on integers.
 class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
 public:
@@ -1127,6 +1140,28 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// SplatOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+SplatPattern::matchAndRewrite(SplatOp op, ArrayRef<Value> operands,
+                              ConversionPatternRewriter &rewriter) const {
+  // Convert the result type rather than using it verbatim: the operand has
+  // already been converted, and spv.CompositeConstruct requires each
+  // constituent to match the result's element type.
+  Type dstType = getTypeConverter()->convertType(op.getType());
+  auto dstVecType = dstType.dyn_cast_or_null<VectorType>();
+  if (!dstVecType || !spirv::CompositeType::isValid(dstVecType))
+    return failure();
+  SplatOp::Adaptor adaptor(operands);
+  // Replicate the scalar once per vector lane.
+  SmallVector<Value, 4> source(dstVecType.getNumElements(), adaptor.input());
+  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType,
+                                                           source);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // StoreOp
 //===----------------------------------------------------------------------===//
@@ -1332,7 +1367,7 @@
       AllocOpPattern, DeallocOpPattern,
       IntLoadOpPattern, IntStoreOpPattern,
       LoadOpPattern, StoreOpPattern,
-      ReturnOpPattern, SelectOpPattern,
+      ReturnOpPattern, SelectOpPattern, SplatPattern,
 
       // Type cast patterns
       UIToFPI1Pattern, ZeroExtendI1Pattern, TruncI1Pattern,
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -1249,3 +1249,18 @@
 // CHECK: spv.ReturnValue %[[VAL]]
   return %extract : i32
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// splat
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @splat
+//  CHECK-SAME: (%[[SCALAR:.+]]: f32)
+//       CHECK: %[[VECTOR:.+]] = spv.CompositeConstruct %[[SCALAR]], %[[SCALAR]], %[[SCALAR]], %[[SCALAR]] : vector<4xf32>
+//       CHECK: spv.ReturnValue %[[VECTOR]]
+func @splat(%scalar : f32) -> vector<4xf32> {
+  %vector = splat %scalar : vector<4xf32>
+  return %vector : vector<4xf32>
+}