diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -372,6 +372,33 @@ }]; } +def Vector_FMAOp : + Op]>, + Arguments<(ins VectorOf<[F32, F64]>:$lhs, + VectorOf<[F32, F64]>:$rhs, + VectorOf<[F32, F64]>:$acc)>, + Results<(outs VectorOf<[F32, F64]>:$result)> { + let summary = "vector fused multiply-add"; + let description = [{ + Multiply-add expressions operate on n-D f32 or f64 vectors and lower + to the llvm.fmuladd.* intrinsic. + + Example: + + %3 = vector.fma %0, %1, %2: vector<8x16xf32> + }]; + // Fully specified by traits. + let verifier = ?; + let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs)"; + let builders = [OpBuilder< + "Builder *b, OperationState &result, Value lhs, Value rhs, Value acc", + "build(b, result, lhs.getType(), lhs, rhs, acc);">]; + let extraClassDeclaration = [{ + VectorType getVectorType() { return lhs().getType().cast(); } + }]; +} + def Vector_InsertElementOp : Vector_Op<"insertelement", [NoSideEffect, PredOpTrait<"source operand and result have same element type", diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -410,6 +410,27 @@ } }; +class VectorFMAOpConversion : public LLVMOpLowering { +public: + explicit VectorFMAOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::FMAOp::getOperationName(), context, + typeConverter) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto adaptor = vector::FMAOpOperandAdaptor(operands); + vector::FMAOp fmaOp = cast(op); + VectorType vType = fmaOp.getVectorType(); + if 
(vType.getRank() != 1) + return matchFailure(); + rewriter.replaceOpWithNewOp(op, adaptor.lhs(), + adaptor.rhs(), adaptor.acc()); + return matchSuccess(); + } +}; + class VectorInsertElementOpConversion : public LLVMOpLowering { public: explicit VectorInsertElementOpConversion(MLIRContext *context, @@ -502,6 +523,34 @@ } }; +// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. +class VectorFMAOpRewritePattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(FMAOp op, + PatternRewriter &rewriter) const override { + auto vType = op.getVectorType(); + if (vType.getRank() < 2) + return matchFailure(); + + auto loc = op.getLoc(); + auto elemType = vType.getElementType(); + Value zero = rewriter.create(loc, elemType, + rewriter.getZeroAttr(elemType)); + Value desc = rewriter.create(loc, vType, zero); + for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { + Value extrLHS = rewriter.create(loc, op.lhs(), i); + Value extrRHS = rewriter.create(loc, op.rhs(), i); + Value extrACC = rewriter.create(loc, op.acc(), i); + Value fma = rewriter.create(loc, extrLHS, extrRHS, extrACC); + desc = rewriter.create(loc, fma, desc, i); + } + rewriter.replaceOp(op, desc); + return matchSuccess(); + } +}; + // When ranks are different, InsertStridedSlice needs to extract a properly // ranked vector from the destination vector into which to insert. 
This pattern // only takes care of this part and forwards the rest of the conversion to @@ -955,14 +1004,16 @@ void mlir::populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { MLIRContext *ctx = converter.getDialect()->getContext(); - patterns.insert(ctx); patterns.insert(ctx, converter); + VectorFMAOpConversion, VectorInsertElementOpConversion, + VectorInsertOpConversion, VectorOuterProductOpConversion, + VectorTypeCastOpConversion, VectorPrintOpConversion>( + ctx, converter); } namespace { diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1,612 +1,18 @@ // RUN: mlir-opt %s -convert-vector-to-llvm | FileCheck %s -func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> { - %0 = vector.broadcast %arg0 : f32 to vector<2xf32> - return %0 : vector<2xf32> +// CHECK-LABEL: llvm.func @vector_fma +func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) + -> (vector<8xf32>, vector<2x4xf32>) +{ + // CHECK: llvm.intr.fmuladd{{.*}}: (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) -> !llvm<"<8 x float>"> + %0 = vector.fma %a, %a, %a : vector<8xf32> + // CHECK-COUNT-3: llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <4 x float>]"> + // CHECK: llvm.intr.fmuladd{{.*}} : (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>"> + // CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[2 x <4 x float>]"> + // CHECK-COUNT-3: llvm.extractvalue {{.*}}[1] : !llvm<"[2 x <4 x float>]"> + // CHECK: llvm.intr.fmuladd{{.*}} : (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>"> + // CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x <4 x float>]"> + %1 = vector.fma %b, %b, %b : vector<2x4xf32> + return %0, %1: vector<8xf32>, vector<2x4xf32> } -// 
CHECK-LABEL: llvm.func @broadcast_vec1d_from_scalar -// CHECK: llvm.mlir.undef : !llvm<"<2 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>"> -// CHECK: llvm.shufflevector {{.*}}, {{.*}}[0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> -// CHECK: llvm.return {{.*}} : !llvm<"<2 x float>"> -func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> { - %0 = vector.broadcast %arg0 : f32 to vector<2x3xf32> - return %0 : vector<2x3xf32> -} -// CHECK-LABEL: llvm.func @broadcast_vec2d_from_scalar -// CHECK: llvm.mlir.undef : !llvm<"<3 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>"> -// CHECK: llvm.shufflevector {{.*}}, {{.*}}[0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> -// CHECK: llvm.mlir.undef : !llvm<"[2 x <3 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[2 x <3 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x <3 x float>]"> -// CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]"> - -func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> { - %0 = vector.broadcast %arg0 : f32 to vector<2x3x4xf32> - return %0 : vector<2x3x4xf32> -} -// CHECK-LABEL: llvm.func @broadcast_vec3d_from_scalar -// CHECK: llvm.mlir.undef : !llvm<"<4 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>"> -// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> -// CHECK: llvm.mlir.undef : !llvm<"[3 x <4 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <4 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <4 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <4 x 
float>]"> -// CHECK: llvm.mlir.undef : !llvm<"[2 x [3 x <4 x float>]]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[2 x [3 x <4 x float>]]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x [3 x <4 x float>]]"> -// CHECK: llvm.return {{.*}} : !llvm<"[2 x [3 x <4 x float>]]"> - -func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> { - %0 = vector.broadcast %arg0 : vector<2xf32> to vector<2xf32> - return %0 : vector<2xf32> -} -// CHECK-LABEL: llvm.func @broadcast_vec1d_from_vec1d -// CHECK: llvm.return {{.*}} : !llvm<"<2 x float>"> - -func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> { - %0 = vector.broadcast %arg0 : vector<2xf32> to vector<3x2xf32> - return %0 : vector<3x2xf32> -} -// CHECK-LABEL: llvm.func @broadcast_vec2d_from_vec1d -// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.return {{.*}} : !llvm<"[3 x <2 x float>]"> - -func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> { - %0 = vector.broadcast %arg0 : vector<2xf32> to vector<4x3x2xf32> - return %0 : vector<4x3x2xf32> -} -// CHECK-LABEL: llvm.func @broadcast_vec3d_from_vec1d -// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.mlir.undef : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: 
llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.return {{.*}} : !llvm<"[4 x [3 x <2 x float>]]"> - -func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> { - %0 = vector.broadcast %arg0 : vector<3x2xf32> to vector<4x3x2xf32> - return %0 : vector<4x3x2xf32> -} -// CHECK-LABEL: llvm.func @broadcast_vec3d_from_vec2d -// CHECK: llvm.mlir.undef : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.return {{.*}} : !llvm<"[4 x [3 x <2 x float>]]"> - -func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> { - %0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32> - return %0 : vector<4xf32> -} -// CHECK-LABEL: llvm.func @broadcast_stretch -// CHECK: llvm.mlir.undef : !llvm<"<4 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>"> -// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> -// CHECK: llvm.return {{.*}} : !llvm<"<4 x float>"> - -func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> { - %0 = vector.broadcast %arg0 : vector<1x4xf32> to vector<3x4xf32> - return %0 : vector<3x4xf32> -} -// CHECK-LABEL: llvm.func @broadcast_stretch_at_start -// CHECK: llvm.mlir.undef : !llvm<"[3 x <4 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <4 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <4 x float>]"> -// CHECK: 
llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <4 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <4 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <4 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <4 x float>]"> -// CHECK: llvm.return {{.*}} : !llvm<"[3 x <4 x float>]"> - -func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> { - %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<4x3xf32> - return %0 : vector<4x3xf32> -} -// CHECK-LABEL: llvm.func @broadcast_stretch_at_end -// CHECK: llvm.mlir.undef : !llvm<"[4 x <3 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[4 x <1 x float>]"> -// CHECK: llvm.mlir.undef : !llvm<"<3 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>"> -// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <3 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"[4 x <1 x float>]"> -// CHECK: llvm.mlir.undef : !llvm<"<3 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>"> -// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x <3 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x <1 x float>]"> -// CHECK: llvm.mlir.undef : !llvm<"<3 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.extractelement 
{{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>"> -// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x <3 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <1 x float>]"> -// CHECK: llvm.mlir.undef : !llvm<"<3 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>"> -// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <3 x float>]"> -// CHECK: llvm.return {{.*}} : !llvm<"[4 x <3 x float>]"> - -func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32> { - %0 = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32> - return %0 : vector<4x3x2xf32> -} -// CHECK-LABEL: llvm.func @broadcast_stretch_in_middle -// CHECK: llvm.mlir.undef : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[4 x [1 x <2 x float>]]"> -// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: 
llvm.extractvalue {{.*}}[1] : !llvm<"[4 x [1 x <2 x float>]]"> -// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x [1 x <2 x float>]]"> -// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x [1 x <2 x float>]]"> -// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.return {{.*}} : !llvm<"[4 x [3 x <2 x float>]]"> - -func 
@outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32> { - %2 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32> - return %2 : vector<2x3xf32> -} -// CHECK-LABEL: llvm.func @outerproduct -// CHECK: llvm.mlir.undef : !llvm<"[2 x <3 x float>]"> -// CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> -// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<3 x float>"> -// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]"> -// CHECK: llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> -// CHECK: llvm.fmul {{.*}}, {{.*}} : !llvm<"<3 x float>"> -// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]"> -// CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]"> - -func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> { - %2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32> - return %2 : vector<2x3xf32> -} -// CHECK-LABEL: llvm.func @outerproduct_add -// CHECK: llvm.mlir.undef : !llvm<"[2 x <3 x float>]"> -// CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]"> -// CHECK: "llvm.intr.fmuladd"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>"> -// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]"> -// CHECK: llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> -// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]"> -// CHECK: "llvm.intr.fmuladd"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>"> -// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]"> -// CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]"> - -func @shuffle_1D_direct(%arg0: vector<2xf32>, 
%arg1: vector<2xf32>) -> vector<2xf32> { - %1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<2xf32>, vector<2xf32> - return %1 : vector<2xf32> -} -// CHECK-LABEL: llvm.func @shuffle_1D_direct -// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm<"<2 x float>"> -// CHECK-SAME: %[[B:arg[0-9]+]]: !llvm<"<2 x float>"> -// CHECK: %[[s:.*]] = llvm.shufflevector %[[A]], %[[B]] [0, 1] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> -// CHECK: llvm.return %[[s]] : !llvm<"<2 x float>"> - -func @shuffle_1D(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<5xf32> { - %1 = vector.shuffle %arg0, %arg1 [4, 3, 2, 1, 0] : vector<2xf32>, vector<3xf32> - return %1 : vector<5xf32> -} -// CHECK-LABEL: llvm.func @shuffle_1D -// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm<"<2 x float>"> -// CHECK-SAME: %[[B:arg[0-9]+]]: !llvm<"<3 x float>"> -// CHECK: %[[u0:.*]] = llvm.mlir.undef : !llvm<"<5 x float>"> -// CHECK: %[[c2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK: %[[e1:.*]] = llvm.extractelement %[[B]][%[[c2]] : !llvm.i64] : !llvm<"<3 x float>"> -// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: %[[i1:.*]] = llvm.insertelement %[[e1]], %[[u0]][%[[c0]] : !llvm.i64] : !llvm<"<5 x float>"> -// CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: %[[e2:.*]] = llvm.extractelement %[[B]][%[[c1]] : !llvm.i64] : !llvm<"<3 x float>"> -// CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: %[[i2:.*]] = llvm.insertelement %[[e2]], %[[i1]][%[[c1]] : !llvm.i64] : !llvm<"<5 x float>"> -// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: %[[e3:.*]] = llvm.extractelement %[[B]][%[[c0]] : !llvm.i64] : !llvm<"<3 x float>"> -// CHECK: %[[c2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK: %[[i3:.*]] = llvm.insertelement %[[e3]], %[[i2]][%[[c2]] : !llvm.i64] : !llvm<"<5 x float>"> -// CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: %[[e4:.*]] = llvm.extractelement %[[A]][%[[c1]] : 
!llvm.i64] : !llvm<"<2 x float>"> -// CHECK: %[[c3:.*]] = llvm.mlir.constant(3 : index) : !llvm.i64 -// CHECK: %[[i4:.*]] = llvm.insertelement %[[e4]], %[[i3]][%[[c3]] : !llvm.i64] : !llvm<"<5 x float>"> -// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: %[[e5:.*]] = llvm.extractelement %[[A]][%[[c0]] : !llvm.i64] : !llvm<"<2 x float>"> -// CHECK: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64 -// CHECK: %[[i5:.*]] = llvm.insertelement %[[e5]], %[[i4]][%[[c4]] : !llvm.i64] : !llvm<"<5 x float>"> -// CHECK: llvm.return %[[i5]] : !llvm<"<5 x float>"> - -func @shuffle_2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32> { - %1 = vector.shuffle %a, %b[1, 0, 2] : vector<1x4xf32>, vector<2x4xf32> - return %1 : vector<3x4xf32> -} -// CHECK-LABEL: llvm.func @shuffle_2D -// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm<"[1 x <4 x float>]"> -// CHECK-SAME: %[[B:arg[0-9]+]]: !llvm<"[2 x <4 x float>]"> -// CHECK: %[[u0:.*]] = llvm.mlir.undef : !llvm<"[3 x <4 x float>]"> -// CHECK: %[[e1:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]"> -// CHECK: %[[i1:.*]] = llvm.insertvalue %[[e1]], %[[u0]][0] : !llvm<"[3 x <4 x float>]"> -// CHECK: %[[e2:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[1 x <4 x float>]"> -// CHECK: %[[i2:.*]] = llvm.insertvalue %[[e2]], %[[i1]][1] : !llvm<"[3 x <4 x float>]"> -// CHECK: %[[e3:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]"> -// CHECK: %[[i3:.*]] = llvm.insertvalue %[[e3]], %[[i2]][2] : !llvm<"[3 x <4 x float>]"> -// CHECK: llvm.return %[[i3]] : !llvm<"[3 x <4 x float>]"> - -func @extract_element(%arg0: vector<16xf32>) -> f32 { - %0 = constant 15 : i32 - %1 = vector.extractelement %arg0[%0 : i32]: vector<16xf32> - return %1 : f32 -} -// CHECK-LABEL: llvm.func @extract_element -// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm<"<16 x float>"> -// CHECK: %[[c:.*]] = llvm.mlir.constant(15 : i32) : !llvm.i32 -// CHECK: %[[x:.*]] = llvm.extractelement %[[A]][%[[c]] : !llvm.i32] : !llvm<"<16 x 
float>"> -// CHECK: llvm.return %[[x]] : !llvm.float - -func @extract_element_from_vec_1d(%arg0: vector<16xf32>) -> f32 { - %0 = vector.extract %arg0[15]: vector<16xf32> - return %0 : f32 -} -// CHECK-LABEL: llvm.func @extract_element_from_vec_1d -// CHECK: llvm.mlir.constant(15 : i64) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<16 x float>"> -// CHECK: llvm.return {{.*}} : !llvm.float - -func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> { - %0 = vector.extract %arg0[0]: vector<4x3x16xf32> - return %0 : vector<3x16xf32> -} -// CHECK-LABEL: llvm.func @extract_vec_2d_from_vec_3d -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[4 x [3 x <16 x float>]]"> -// CHECK: llvm.return {{.*}} : !llvm<"[3 x <16 x float>]"> - -func @extract_vec_1d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<16xf32> { - %0 = vector.extract %arg0[0, 0]: vector<4x3x16xf32> - return %0 : vector<16xf32> -} -// CHECK-LABEL: llvm.func @extract_vec_1d_from_vec_3d -// CHECK: llvm.extractvalue {{.*}}[0, 0] : !llvm<"[4 x [3 x <16 x float>]]"> -// CHECK: llvm.return {{.*}} : !llvm<"<16 x float>"> - -func @extract_element_from_vec_3d(%arg0: vector<4x3x16xf32>) -> f32 { - %0 = vector.extract %arg0[0, 0, 0]: vector<4x3x16xf32> - return %0 : f32 -} -// CHECK-LABEL: llvm.func @extract_element_from_vec_3d -// CHECK: llvm.extractvalue {{.*}}[0, 0] : !llvm<"[4 x [3 x <16 x float>]]"> -// CHECK: llvm.mlir.constant(0 : i64) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<16 x float>"> -// CHECK: llvm.return {{.*}} : !llvm.float - -func @insert_element(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> { - %0 = constant 3 : i32 - %1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<4xf32> - return %1 : vector<4xf32> -} -// CHECK-LABEL: llvm.func @insert_element -// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm.float -// CHECK-SAME: %[[B:arg[0-9]+]]: !llvm<"<4 x float>"> -// CHECK: %[[c:.*]] = llvm.mlir.constant(3 : i32) : 
!llvm.i32 -// CHECK: %[[x:.*]] = llvm.insertelement %[[A]], %[[B]][%[[c]] : !llvm.i32] : !llvm<"<4 x float>"> -// CHECK: llvm.return %[[x]] : !llvm<"<4 x float>"> - -func @insert_element_into_vec_1d(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> { - %0 = vector.insert %arg0, %arg1[3] : f32 into vector<4xf32> - return %0 : vector<4xf32> -} -// CHECK-LABEL: llvm.func @insert_element_into_vec_1d -// CHECK: llvm.mlir.constant(3 : i64) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>"> -// CHECK: llvm.return {{.*}} : !llvm<"<4 x float>"> - -func @insert_vec_2d_into_vec_3d(%arg0: vector<8x16xf32>, %arg1: vector<4x8x16xf32>) -> vector<4x8x16xf32> { - %0 = vector.insert %arg0, %arg1[3] : vector<8x16xf32> into vector<4x8x16xf32> - return %0 : vector<4x8x16xf32> -} -// CHECK-LABEL: llvm.func @insert_vec_2d_into_vec_3d -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x [8 x <16 x float>]]"> -// CHECK: llvm.return {{.*}} : !llvm<"[4 x [8 x <16 x float>]]"> - -func @insert_vec_1d_into_vec_3d(%arg0: vector<16xf32>, %arg1: vector<4x8x16xf32>) -> vector<4x8x16xf32> { - %0 = vector.insert %arg0, %arg1[3, 7] : vector<16xf32> into vector<4x8x16xf32> - return %0 : vector<4x8x16xf32> -} -// CHECK-LABEL: llvm.func @insert_vec_1d_into_vec_3d -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3, 7] : !llvm<"[4 x [8 x <16 x float>]]"> -// CHECK: llvm.return {{.*}} : !llvm<"[4 x [8 x <16 x float>]]"> - -func @insert_element_into_vec_3d(%arg0: f32, %arg1: vector<4x8x16xf32>) -> vector<4x8x16xf32> { - %0 = vector.insert %arg0, %arg1[3, 7, 15] : f32 into vector<4x8x16xf32> - return %0 : vector<4x8x16xf32> -} -// CHECK-LABEL: llvm.func @insert_element_into_vec_3d -// CHECK: llvm.extractvalue {{.*}}[3, 7] : !llvm<"[4 x [8 x <16 x float>]]"> -// CHECK: llvm.mlir.constant(15 : i64) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<16 x float>"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3, 7] : !llvm<"[4 x 
[8 x <16 x float>]]"> -// CHECK: llvm.return {{.*}} : !llvm<"[4 x [8 x <16 x float>]]"> - -func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref> { - %0 = vector.type_cast %arg0: memref<8x8x8xf32> to memref> - return %0 : memref> -} -// CHECK-LABEL: llvm.func @vector_type_cast -// CHECK: llvm.mlir.undef : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }"> -// CHECK: %[[allocated:.*]] = llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> -// CHECK: %[[allocatedBit:.*]] = llvm.bitcast %[[allocated]] : !llvm<"float*"> to !llvm<"[8 x [8 x <8 x float>]]*"> -// CHECK: llvm.insertvalue %[[allocatedBit]], {{.*}}[0] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }"> -// CHECK: %[[aligned:.*]] = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> -// CHECK: %[[alignedBit:.*]] = llvm.bitcast %[[aligned]] : !llvm<"float*"> to !llvm<"[8 x [8 x <8 x float>]]*"> -// CHECK: llvm.insertvalue %[[alignedBit]], {{.*}}[1] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }"> -// CHECK: llvm.mlir.constant(0 : index -// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }"> - -func @vector_print_scalar(%arg0: f32) { - vector.print %arg0 : f32 - return -} -// CHECK-LABEL: llvm.func @vector_print_scalar -// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm.float -// CHECK: llvm.call @print_f32(%[[A]]) : (!llvm.float) -> () -// CHECK: llvm.call @print_newline() : () -> () - -func @vector_print_vector(%arg0: vector<2x2xf32>) { - vector.print %arg0 : vector<2x2xf32> - return -} -// CHECK-LABEL: llvm.func @vector_print_vector -// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm<"[2 x <2 x float>]"> -// CHECK: llvm.call @print_open() : () -> () -// CHECK: %[[x0:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[2 x <2 x float>]"> -// CHECK: llvm.call @print_open() : () -> () -// CHECK: %[[x1:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: 
%[[x2:.*]] = llvm.extractelement %[[x0]][%[[x1]] : !llvm.i64] : !llvm<"<2 x float>"> -// CHECK: llvm.call @print_f32(%[[x2]]) : (!llvm.float) -> () -// CHECK: llvm.call @print_comma() : () -> () -// CHECK: %[[x3:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: %[[x4:.*]] = llvm.extractelement %[[x0]][%[[x3]] : !llvm.i64] : !llvm<"<2 x float>"> -// CHECK: llvm.call @print_f32(%[[x4]]) : (!llvm.float) -> () -// CHECK: llvm.call @print_close() : () -> () -// CHECK: llvm.call @print_comma() : () -> () -// CHECK: %[[x5:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"[2 x <2 x float>]"> -// CHECK: llvm.call @print_open() : () -> () -// CHECK: %[[x6:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: %[[x7:.*]] = llvm.extractelement %[[x5]][%[[x6]] : !llvm.i64] : !llvm<"<2 x float>"> -// CHECK: llvm.call @print_f32(%[[x7]]) : (!llvm.float) -> () -// CHECK: llvm.call @print_comma() : () -> () -// CHECK: %[[x8:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: %[[x9:.*]] = llvm.extractelement %[[x5]][%[[x8]] : !llvm.i64] : !llvm<"<2 x float>"> -// CHECK: llvm.call @print_f32(%[[x9]]) : (!llvm.float) -> () -// CHECK: llvm.call @print_close() : () -> () -// CHECK: llvm.call @print_close() : () -> () -// CHECK: llvm.call @print_newline() : () -> () - -func @strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> { - %0 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32> - return %0 : vector<2xf32> -} -// CHECK-LABEL: llvm.func @strided_slice1 -// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float -// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>"> -// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK: llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<4 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<2 x float>"> -// CHECK: llvm.mlir.constant(3 : 
index) : !llvm.i64 -// CHECK: llvm.extractelement %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<4 x float>"> -// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<2 x float>"> - -func @strided_slice2(%arg0: vector<4x8xf32>) -> vector<2x8xf32> { - %0 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8xf32> to vector<2x8xf32> - return %0 : vector<2x8xf32> -} -// CHECK-LABEL: llvm.func @strided_slice2 -// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float -// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x8xf32>) : !llvm<"[2 x <8 x float>]"> -// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"[4 x <8 x float>]"> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"[2 x <8 x float>]"> -// CHECK: llvm.extractvalue %{{.*}}[3] : !llvm<"[4 x <8 x float>]"> -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"[2 x <8 x float>]"> - -func @strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> { - %0 = vector.strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32> - return %0 : vector<2x2xf32> -} -// CHECK-LABEL: llvm.func @strided_slice3 -// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float -// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x2xf32>) : !llvm<"[2 x <2 x float>]"> -// -// Subvector vector<8xf32> @2 -// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x <8 x float>]"> -// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float -// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>"> -// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>"> -// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64 -// CHECK: 
llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>"> -// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[2 x <2 x float>]"> -// -// Subvector vector<8xf32> @3 -// CHECK: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <8 x float>]"> -// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float -// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>"> -// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>"> -// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<8 x float>"> -// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x <2 x float>]"> - -func @insert_strided_slice1(%b: vector<4x4xf32>, %c: vector<4x4x4xf32>) -> vector<4x4x4xf32> { - %0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xf32> into vector<4x4x4xf32> - return %0 : vector<4x4x4xf32> -} -// CHECK-LABEL: llvm.func @insert_strided_slice1 -// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x [4 x <4 x float>]]"> -// CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [4 x <4 x float>]]"> - -func @insert_strided_slice2(%a: vector<2x2xf32>, %b: vector<4x4xf32>) -> vector<4x4xf32> { - %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> - return %0 : vector<4x4xf32> -} -// CHECK-LABEL: llvm.func @insert_strided_slice2 -// -// Subvector vector<2xf32> @0 into vector<4xf32> @2 -// CHECK: 
llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <2 x float>]"> -// CHECK-NEXT: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x <4 x float>]"> -// Element @0 -> element @2 -// CHECK-NEXT: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>"> -// CHECK-NEXT: llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>"> -// Element @1 -> element @3 -// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>"> -// CHECK-NEXT: llvm.mlir.constant(3 : index) : !llvm.i64 -// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>"> -// CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x <4 x float>]"> -// -// Subvector vector<2xf32> @1 into vector<4xf32> @3 -// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"[2 x <2 x float>]"> -// CHECK-NEXT: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <4 x float>]"> -// Element @0 -> element @2 -// CHECK-NEXT: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>"> -// CHECK-NEXT: llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>"> -// Element @1 -> element @3 -// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>"> -// CHECK-NEXT: llvm.mlir.constant(3 : index) : !llvm.i64 -// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>"> -// CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <4 x float>]"> - -func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf32>) -> vector<16x4x8xf32> { - %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 2], strides = [1, 1]}: - vector<2x4xf32> 
into vector<16x4x8xf32> - return %0 : vector<16x4x8xf32> -} -// CHECK-LABEL: llvm.func @insert_strided_slice3 -// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm<"[2 x <4 x float>]"> -// CHECK-SAME: %[[B:arg[0-9]+]]: !llvm<"[16 x [4 x <8 x float>]]"> -// CHECK: %[[s0:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[16 x [4 x <8 x float>]]"> -// CHECK: %[[s1:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[2 x <4 x float>]"> -// CHECK: %[[s2:.*]] = llvm.extractvalue %[[s0]][0] : !llvm<"[4 x <8 x float>]"> -// CHECK: %[[s3:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: %[[s4:.*]] = llvm.extractelement %[[s1]][%[[s3]] : !llvm.i64] : !llvm<"<4 x float>"> -// CHECK: %[[s5:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK: %[[s6:.*]] = llvm.insertelement %[[s4]], %[[s2]][%[[s5]] : !llvm.i64] : !llvm<"<8 x float>"> -// CHECK: %[[s7:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: %[[s8:.*]] = llvm.extractelement %[[s1]][%[[s7]] : !llvm.i64] : !llvm<"<4 x float>"> -// CHECK: %[[s9:.*]] = llvm.mlir.constant(3 : index) : !llvm.i64 -// CHECK: %[[s10:.*]] = llvm.insertelement %[[s8]], %[[s6]][%[[s9]] : !llvm.i64] : !llvm<"<8 x float>"> -// CHECK: %[[s11:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK: %[[s12:.*]] = llvm.extractelement %[[s1]][%[[s11]] : !llvm.i64] : !llvm<"<4 x float>"> -// CHECK: %[[s13:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64 -// CHECK: %[[s14:.*]] = llvm.insertelement %[[s12]], %[[s10]][%[[s13]] : !llvm.i64] : !llvm<"<8 x float>"> -// CHECK: %[[s15:.*]] = llvm.mlir.constant(3 : index) : !llvm.i64 -// CHECK: %[[s16:.*]] = llvm.extractelement %[[s1]][%[[s15]] : !llvm.i64] : !llvm<"<4 x float>"> -// CHECK: %[[s17:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64 -// CHECK: %[[s18:.*]] = llvm.insertelement %[[s16]], %[[s14]][%[[s17]] : !llvm.i64] : !llvm<"<8 x float>"> -// CHECK: %[[s19:.*]] = llvm.insertvalue %[[s18]], %[[s0]][0] : !llvm<"[4 x <8 x float>]"> -// CHECK: %[[s20:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"[2 x <4 x 
float>]"> -// CHECK: %[[s21:.*]] = llvm.extractvalue %[[s0]][1] : !llvm<"[4 x <8 x float>]"> -// CHECK: %[[s22:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: %[[s23:.*]] = llvm.extractelement %[[s20]][%[[s22]] : !llvm.i64] : !llvm<"<4 x float>"> -// CHECK: %[[s24:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK: %[[s25:.*]] = llvm.insertelement %[[s23]], %[[s21]][%[[s24]] : !llvm.i64] : !llvm<"<8 x float>"> -// CHECK: %[[s26:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: %[[s27:.*]] = llvm.extractelement %[[s20]][%[[s26]] : !llvm.i64] : !llvm<"<4 x float>"> -// CHECK: %[[s28:.*]] = llvm.mlir.constant(3 : index) : !llvm.i64 -// CHECK: %[[s29:.*]] = llvm.insertelement %[[s27]], %[[s25]][%[[s28]] : !llvm.i64] : !llvm<"<8 x float>"> -// CHECK: %[[s30:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK: %[[s31:.*]] = llvm.extractelement %[[s20]][%[[s30]] : !llvm.i64] : !llvm<"<4 x float>"> -// CHECK: %[[s32:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64 -// CHECK: %[[s33:.*]] = llvm.insertelement %[[s31]], %[[s29]][%[[s32]] : !llvm.i64] : !llvm<"<8 x float>"> -// CHECK: %[[s34:.*]] = llvm.mlir.constant(3 : index) : !llvm.i64 -// CHECK: %[[s35:.*]] = llvm.extractelement %[[s20]][%[[s34]] : !llvm.i64] : !llvm<"<4 x float>"> -// CHECK: %[[s36:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64 -// CHECK: %[[s37:.*]] = llvm.insertelement %[[s35]], %[[s33]][%[[s36]] : !llvm.i64] : !llvm<"<8 x float>"> -// CHECK: %[[s38:.*]] = llvm.insertvalue %[[s37]], %[[s19]][1] : !llvm<"[4 x <8 x float>]"> -// CHECK: %[[s39:.*]] = llvm.insertvalue %[[s38]], %[[B]][0] : !llvm<"[16 x [4 x <8 x float>]]"> -// CHECK: llvm.return %[[s39]] : !llvm<"[16 x [4 x <8 x float>]]"> - -func @extract_strides(%arg0: vector<3x3xf32>) -> vector<1x1xf32> { - %0 = vector.extract_slices %arg0, [2, 2], [1, 1] - : vector<3x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>> - %1 = vector.tuple_get %0, 3 : tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>,
vector<1x1xf32>> - return %1 : vector<1x1xf32> -} -// CHECK-LABEL: llvm.func @extract_strides -// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm<"[3 x <3 x float>]"> -// CHECK: %[[s0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1x1xf32>) : !llvm<"[1 x <1 x float>]"> -// CHECK: %[[s1:.*]] = llvm.extractvalue %[[A]][2] : !llvm<"[3 x <3 x float>]"> -// CHECK: %[[s3:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1xf32>) : !llvm<"<1 x float>"> -// CHECK: %[[s4:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 -// CHECK: %[[s5:.*]] = llvm.extractelement %[[s1]][%[[s4]] : !llvm.i64] : !llvm<"<3 x float>"> -// CHECK: %[[s6:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: %[[s7:.*]] = llvm.insertelement %[[s5]], %[[s3]][%[[s6]] : !llvm.i64] : !llvm<"<1 x float>"> -// CHECK: %[[s8:.*]] = llvm.insertvalue %[[s7]], %[[s0]][0] : !llvm<"[1 x <1 x float>]"> -// CHECK: llvm.return %[[s8]] : !llvm<"[1 x <1 x float>]"> diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -233,3 +233,12 @@ return %1 : vector<2x3x4xf32> } + +// CHECK-LABEL: @vector_fma +func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>) { + // CHECK: vector.fma %{{.*}} : vector<8xf32> + vector.fma %a, %a, %a : vector<8xf32> + // CHECK: vector.fma %{{.*}} : vector<8x4xf32> + vector.fma %b, %b, %b : vector<8x4xf32> + return +}