diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -55,6 +55,7 @@ /// ContractionOpLowering, /// ShapeCastOp2DDownCastRewritePattern, /// ShapeCastOp2DUpCastRewritePattern +/// BroadcastOpLowering, /// TransposeOpLowering /// OuterproductOpLowering /// These transformation express higher level vector ops in terms of more diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -126,155 +126,6 @@ namespace { -class VectorBroadcastOpConversion : public ConvertToLLVMPattern { -public: - explicit VectorBroadcastOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(vector::BroadcastOp::getOperationName(), context, - typeConverter) {} - - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto broadcastOp = cast(op); - VectorType dstVectorType = broadcastOp.getVectorType(); - if (typeConverter.convertType(dstVectorType) == nullptr) - return failure(); - // Rewrite when the full vector type can be lowered (which - // implies all 'reduced' types can be lowered too). - auto adaptor = vector::BroadcastOpOperandAdaptor(operands); - VectorType srcVectorType = - broadcastOp.getSourceType().dyn_cast(); - rewriter.replaceOp( - op, expandRanks(adaptor.source(), // source value to be expanded - op->getLoc(), // location of original broadcast - srcVectorType, dstVectorType, rewriter)); - return success(); - } - -private: - // Expands the given source value over all the ranks, as defined - // by the source and destination type (a null source type denotes - // expansion from a scalar value into a vector). 
- // - // TODO(ajcbik): consider replacing this one-pattern lowering - // with a two-pattern lowering using other vector - // ops once all insert/extract/shuffle operations - // are available with lowering implementation. - // - Value expandRanks(Value value, Location loc, VectorType srcVectorType, - VectorType dstVectorType, - ConversionPatternRewriter &rewriter) const { - assert((dstVectorType != nullptr) && "invalid result type in broadcast"); - // Determine rank of source and destination. - int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0; - int64_t dstRank = dstVectorType.getRank(); - int64_t curDim = dstVectorType.getDimSize(0); - if (srcRank < dstRank) - // Duplicate this rank. - return duplicateOneRank(value, loc, srcVectorType, dstVectorType, dstRank, - curDim, rewriter); - // If all trailing dimensions are the same, the broadcast consists of - // simply passing through the source value and we are done. Otherwise, - // any non-matching dimension forces a stretch along this rank. - assert((srcVectorType != nullptr) && (srcRank > 0) && - (srcRank == dstRank) && "invalid rank in broadcast"); - for (int64_t r = 0; r < dstRank; r++) { - if (srcVectorType.getDimSize(r) != dstVectorType.getDimSize(r)) { - return stretchOneRank(value, loc, srcVectorType, dstVectorType, dstRank, - curDim, rewriter); - } - } - return value; - } - - // Picks the best way to duplicate a single rank. For the 1-D case, a - // single insert-elt/shuffle is the most efficient expansion. For higher - // dimensions, however, we need dim x insert-values on a new broadcast - // with one less leading dimension, which will be lowered "recursively" - // to matching LLVM IR. 
- // For example: - // v = broadcast s : f32 to vector<4x2xf32> - // becomes: - // x = broadcast s : f32 to vector<2xf32> - // v = [x,x,x,x] - // becomes: - // x = [s,s] - // v = [x,x,x,x] - Value duplicateOneRank(Value value, Location loc, VectorType srcVectorType, - VectorType dstVectorType, int64_t rank, int64_t dim, - ConversionPatternRewriter &rewriter) const { - Type llvmType = typeConverter.convertType(dstVectorType); - assert((llvmType != nullptr) && "unlowerable vector type"); - if (rank == 1) { - Value undef = rewriter.create(loc, llvmType); - Value expand = insertOne(rewriter, typeConverter, loc, undef, value, - llvmType, rank, 0); - SmallVector zeroValues(dim, 0); - return rewriter.create( - loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues)); - } - Value expand = expandRanks(value, loc, srcVectorType, - reducedVectorTypeFront(dstVectorType), rewriter); - Value result = rewriter.create(loc, llvmType); - for (int64_t d = 0; d < dim; ++d) { - result = insertOne(rewriter, typeConverter, loc, result, expand, llvmType, - rank, d); - } - return result; - } - - // Picks the best way to stretch a single rank. For the 1-D case, a - // single insert-elt/shuffle is the most efficient expansion when at - // a stretch. Otherwise, every dimension needs to be expanded - // individually and individually inserted in the resulting vector. - // For example: - // v = broadcast w : vector<4x1x2xf32> to vector<4x2x2xf32> - // becomes: - // a = broadcast w[0] : vector<1x2xf32> to vector<2x2xf32> - // b = broadcast w[1] : vector<1x2xf32> to vector<2x2xf32> - // c = broadcast w[2] : vector<1x2xf32> to vector<2x2xf32> - // d = broadcast w[3] : vector<1x2xf32> to vector<2x2xf32> - // v = [a,b,c,d] - // becomes: - // x = broadcast w[0][0] : vector<2xf32> to vector <2x2xf32> - // y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32> - // a = [x, y] - // etc. 
- Value stretchOneRank(Value value, Location loc, VectorType srcVectorType, - VectorType dstVectorType, int64_t rank, int64_t dim, - ConversionPatternRewriter &rewriter) const { - Type llvmType = typeConverter.convertType(dstVectorType); - assert((llvmType != nullptr) && "unlowerable vector type"); - Value result = rewriter.create(loc, llvmType); - bool atStretch = dim != srcVectorType.getDimSize(0); - if (rank == 1) { - assert(atStretch); - Type redLlvmType = - typeConverter.convertType(dstVectorType.getElementType()); - Value one = - extractOne(rewriter, typeConverter, loc, value, redLlvmType, rank, 0); - Value expand = insertOne(rewriter, typeConverter, loc, result, one, - llvmType, rank, 0); - SmallVector zeroValues(dim, 0); - return rewriter.create( - loc, expand, result, rewriter.getI32ArrayAttr(zeroValues)); - } - VectorType redSrcType = reducedVectorTypeFront(srcVectorType); - VectorType redDstType = reducedVectorTypeFront(dstVectorType); - Type redLlvmType = typeConverter.convertType(redSrcType); - for (int64_t d = 0; d < dim; ++d) { - int64_t pos = atStretch ? 0 : d; - Value one = extractOne(rewriter, typeConverter, loc, value, redLlvmType, - rank, pos); - Value expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); - result = insertOne(rewriter, typeConverter, loc, result, expand, llvmType, - rank, d); - } - return result; - } -}; - /// Conversion pattern for a vector.matrix_multiply. /// This is lowered directly to the proper llvm.intr.matrix.multiply. 
class VectorMatmulOpConversion : public ConvertToLLVMPattern { @@ -1209,8 +1060,7 @@ VectorInsertStridedSliceOpSameRankRewritePattern, VectorStridedSliceOpConversion>(ctx); patterns - .insert { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::BroadcastOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + VectorType dstType = op.getVectorType(); + VectorType srcType = op.getSourceType().dyn_cast(); + Type eltType = dstType.getElementType(); + + // Determine rank of source and destination. + int64_t srcRank = srcType ? srcType.getRank() : 0; + int64_t dstRank = dstType.getRank(); + + // Duplicate this rank. + // For example: + // %x = broadcast %y : k-D to n-D, k < n + // becomes: + // %b = broadcast %y : k-D to (n-1)-D + // %x = [%b,%b,%b,%b] : n-D + // becomes: + // %b = [%y,%y] : (n-1)-D + // %x = [%b,%b,%b,%b] : n-D + if (srcRank < dstRank) { + // Scalar to any vector can use splat. + if (srcRank == 0) { + rewriter.replaceOpWithNewOp(op, dstType, op.source()); + return success(); + } + // Duplication. + VectorType resType = + VectorType::get(dstType.getShape().drop_front(), eltType); + Value bcst = + rewriter.create(loc, resType, op.source()); + Value zero = rewriter.create(loc, eltType, + rewriter.getZeroAttr(eltType)); + Value result = rewriter.create(loc, dstType, zero); + for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) + result = rewriter.create(loc, bcst, result, d); + rewriter.replaceOp(op, result); + return success(); + } + + // Find non-matching dimension, if any. + assert(srcRank == dstRank); + int64_t m = -1; + for (int64_t r = 0; r < dstRank; r++) + if (srcType.getDimSize(r) != dstType.getDimSize(r)) { + m = r; + break; + } + + // All trailing dimensions are the same. Simply pass through. + if (m == -1) { + rewriter.replaceOp(op, op.source()); + return success(); + } + + // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat. 
+ if (srcRank == 1) { + assert(m == 0); + Value ext = rewriter.create(loc, op.source(), 0); + rewriter.replaceOpWithNewOp(op, dstType, ext); + return success(); + } + + // Any non-matching dimension forces a stretch along this rank. + // For example: + // %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32> + // becomes: + // %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32> + // %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32> + // %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32> + // %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32> + // %x = [%a,%b,%c,%d] + // becomes: + // %u = %y[0][0] : vector<2xf32> + // %a = [%u, %u] : vector<2x2xf32> + // .. + // %x = [%a,%b,%c,%d] + VectorType resType = + VectorType::get(dstType.getShape().drop_front(), eltType); + Value zero = rewriter.create(loc, eltType, + rewriter.getZeroAttr(eltType)); + Value result = rewriter.create(loc, dstType, zero); + if (m == 0) { + // Stretch at start. + Value ext = rewriter.create(loc, op.source(), 0); + Value bcst = rewriter.create(loc, resType, ext); + for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) + result = rewriter.create(loc, bcst, result, d); + } else { + // Stretch not at start. + for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) { + Value ext = rewriter.create(loc, op.source(), d); + Value bcst = rewriter.create(loc, resType, ext); + result = rewriter.create(loc, bcst, result, d); + } + } + rewriter.replaceOp(op, result); + return success(); + } +}; + +/// Progressive lowering of TransposeOp.
/// One: /// %x = vector.transpose %y, [1, 0] /// is replaced by: @@ -1520,7 +1627,7 @@ OwningRewritePatternList &patterns, MLIRContext *context, VectorTransformsOptions parameters) { patterns.insert(context); + ShapeCastOp2DUpCastRewritePattern, BroadcastOpLowering, + TransposeOpLowering, OuterProductOpLowering>(context); patterns.insert(parameters, context); } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -4,201 +4,199 @@ %0 = vector.broadcast %arg0 : f32 to vector<2xf32> return %0 : vector<2xf32> } -// CHECK-LABEL: llvm.func @broadcast_vec1d_from_scalar -// CHECK: llvm.mlir.undef : !llvm<"<2 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>"> -// CHECK: llvm.shufflevector {{.*}}, {{.*}}[0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> -// CHECK: llvm.return {{.*}} : !llvm<"<2 x float>"> +// CHECK-LABEL: llvm.func @broadcast_vec1d_from_scalar( +// CHECK-SAME: %[[A:.*]]: !llvm.float) +// CHECK: %[[T0:.*]] = llvm.mlir.undef : !llvm<"<2 x float>"> +// CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 +// CHECK: %[[T2:.*]] = llvm.insertelement %[[A]], %[[T0]][%[[T1]] : !llvm.i32] : !llvm<"<2 x float>"> +// CHECK: %[[T3:.*]] = llvm.shufflevector %[[T2]], %[[T0]] [0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> +// CHECK: llvm.return %[[T3]] : !llvm<"<2 x float>"> func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> { %0 = vector.broadcast %arg0 : f32 to vector<2x3xf32> return %0 : vector<2x3xf32> } -// CHECK-LABEL: llvm.func @broadcast_vec2d_from_scalar -// CHECK: llvm.mlir.undef : !llvm<"<3 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : 
!llvm<"<3 x float>"> -// CHECK: llvm.shufflevector {{.*}}, {{.*}}[0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> -// CHECK: llvm.mlir.undef : !llvm<"[2 x <3 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[2 x <3 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x <3 x float>]"> -// CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]"> +// CHECK-LABEL: llvm.func @broadcast_vec2d_from_scalar( +// CHECK-SAME: %[[A:.*]]: !llvm.float) +// CHECK: %[[T0:.*]] = llvm.mlir.undef : !llvm<"[2 x <3 x float>]"> +// CHECK: %[[T1:.*]] = llvm.mlir.undef : !llvm<"<3 x float>"> +// CHECK: %[[T2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 +// CHECK: %[[T3:.*]] = llvm.insertelement %[[A]], %[[T1]][%[[T2]] : !llvm.i32] : !llvm<"<3 x float>"> +// CHECK: %[[T4:.*]] = llvm.shufflevector %[[T3]], %[[T3]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> +// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T4]], %[[T0]][0] : !llvm<"[2 x <3 x float>]"> +// CHECK: %[[T6:.*]] = llvm.insertvalue %[[T4]], %[[T5]][1] : !llvm<"[2 x <3 x float>]"> +// CHECK: llvm.return %[[T6]] : !llvm<"[2 x <3 x float>]"> func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> { %0 = vector.broadcast %arg0 : f32 to vector<2x3x4xf32> return %0 : vector<2x3x4xf32> } -// CHECK-LABEL: llvm.func @broadcast_vec3d_from_scalar -// CHECK: llvm.mlir.undef : !llvm<"<4 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>"> -// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> -// CHECK: llvm.mlir.undef : !llvm<"[3 x <4 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <4 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <4 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <4 x float>]"> -// CHECK: 
llvm.mlir.undef : !llvm<"[2 x [3 x <4 x float>]]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[2 x [3 x <4 x float>]]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x [3 x <4 x float>]]"> -// CHECK: llvm.return {{.*}} : !llvm<"[2 x [3 x <4 x float>]]"> +// CHECK-LABEL: llvm.func @broadcast_vec3d_from_scalar( +// CHECK-SAME: %[[A:.*]]: !llvm.float) +// CHECK: %[[T0:.*]] = llvm.mlir.undef : !llvm<"[2 x [3 x <4 x float>]]"> +// CHECK: %[[T1:.*]] = llvm.mlir.undef : !llvm<"<4 x float>"> +// CHECK: %[[T2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 +// CHECK: %[[T3:.*]] = llvm.insertelement %[[A]], %[[T1]][%[[T2]] : !llvm.i32] : !llvm<"<4 x float>"> +// CHECK: %[[T4:.*]] = llvm.shufflevector %[[T3]], %[[T3]] [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> +// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T4]], %[[T0]][0, 0] : !llvm<"[2 x [3 x <4 x float>]]"> +// CHECK: %[[T6:.*]] = llvm.insertvalue %[[T4]], %[[T5]][0, 1] : !llvm<"[2 x [3 x <4 x float>]]"> +// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T4]], %[[T6]][0, 2] : !llvm<"[2 x [3 x <4 x float>]]"> +// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T4]], %[[T7]][1, 0] : !llvm<"[2 x [3 x <4 x float>]]"> +// CHECK: %[[T9:.*]] = llvm.insertvalue %[[T4]], %[[T8]][1, 1] : !llvm<"[2 x [3 x <4 x float>]]"> +// CHECK: %[[T10:.*]] = llvm.insertvalue %[[T4]], %[[T9]][1, 2] : !llvm<"[2 x [3 x <4 x float>]]"> +// CHECK: llvm.return %[[T10]] : !llvm<"[2 x [3 x <4 x float>]]"> func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> { %0 = vector.broadcast %arg0 : vector<2xf32> to vector<2xf32> return %0 : vector<2xf32> } -// CHECK-LABEL: llvm.func @broadcast_vec1d_from_vec1d -// CHECK: llvm.return {{.*}} : !llvm<"<2 x float>"> +// CHECK-LABEL: llvm.func @broadcast_vec1d_from_vec1d( +// CHECK-SAME: %[[A:.*]]: !llvm<"<2 x float>">) +// CHECK: llvm.return %[[A]] : !llvm<"<2 x float>"> func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> { %0 = 
vector.broadcast %arg0 : vector<2xf32> to vector<3x2xf32> return %0 : vector<3x2xf32> } -// CHECK-LABEL: llvm.func @broadcast_vec2d_from_vec1d -// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.return {{.*}} : !llvm<"[3 x <2 x float>]"> +// CHECK-LABEL: llvm.func @broadcast_vec2d_from_vec1d( +// CHECK-SAME: %[[A:.*]]: !llvm<"<2 x float>">) +// CHECK: %[[T0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<3x2xf32>) : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T1:.*]] = llvm.insertvalue %[[A]], %[[T0]][0] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T2:.*]] = llvm.insertvalue %[[A]], %[[T1]][1] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][2] : !llvm<"[3 x <2 x float>]"> +// CHECK: llvm.return %[[T3]] : !llvm<"[3 x <2 x float>]"> func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> { %0 = vector.broadcast %arg0 : vector<2xf32> to vector<4x3x2xf32> return %0 : vector<4x3x2xf32> } -// CHECK-LABEL: llvm.func @broadcast_vec3d_from_vec1d -// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.mlir.undef : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.return {{.*}} : !llvm<"[4 x [3 x <2 x float>]]"> +// 
CHECK-LABEL: llvm.func @broadcast_vec3d_from_vec1d( +// CHECK-SAME: %[[A:.*]]: !llvm<"<2 x float>">) +// CHECK: %[[T0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<3x2xf32>) : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T1:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<4x3x2xf32>) : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: %[[T2:.*]] = llvm.insertvalue %[[A]], %[[T0]][0] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][1] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T4:.*]] = llvm.insertvalue %[[A]], %[[T3]][2] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T4]], %[[T1]][0] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: %[[T6:.*]] = llvm.insertvalue %[[T4]], %[[T5]][1] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T4]], %[[T6]][2] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T4]], %[[T7]][3] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.return %[[T8]] : !llvm<"[4 x [3 x <2 x float>]]"> func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> { %0 = vector.broadcast %arg0 : vector<3x2xf32> to vector<4x3x2xf32> return %0 : vector<4x3x2xf32> } -// CHECK-LABEL: llvm.func @broadcast_vec3d_from_vec2d -// CHECK: llvm.mlir.undef : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.return {{.*}} : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK-LABEL: llvm.func @broadcast_vec3d_from_vec2d( +// CHECK-SAME: %[[A:.*]]: !llvm<"[3 x <2 x float>]">) +// CHECK: %[[T0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<4x3x2xf32>) : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: %[[T1:.*]] = 
llvm.insertvalue %[[A]], %[[T0]][0] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: %[[T2:.*]] = llvm.insertvalue %[[A]], %[[T1]][1] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][2] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: %[[T4:.*]] = llvm.insertvalue %[[A]], %[[T3]][3] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.return %[[T4]] : !llvm<"[4 x [3 x <2 x float>]]"> func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> { %0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32> return %0 : vector<4xf32> } -// CHECK-LABEL: llvm.func @broadcast_stretch -// CHECK: llvm.mlir.undef : !llvm<"<4 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>"> -// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> -// CHECK: llvm.return {{.*}} : !llvm<"<4 x float>"> +// CHECK-LABEL: llvm.func @broadcast_stretch( +// CHECK-SAME: %[[A:.*]]: !llvm<"<1 x float>">) +// CHECK: %[[T0:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64 +// CHECK: %[[T1:.*]] = llvm.extractelement %[[A]][%[[T0]] : !llvm.i64] : !llvm<"<1 x float>"> +// CHECK: %[[T2:.*]] = llvm.mlir.undef : !llvm<"<4 x float>"> +// CHECK: %[[T3:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 +// CHECK: %[[T4:.*]] = llvm.insertelement %[[T1]], %[[T2]][%[[T3]] : !llvm.i32] : !llvm<"<4 x float>"> +// CHECK: %[[T5:.*]] = llvm.shufflevector %[[T4]], %[[T2]] [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> +// CHECK: llvm.return %[[T5]] : !llvm<"<4 x float>"> func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> { %0 = vector.broadcast %arg0 : vector<1x4xf32> to vector<3x4xf32> return %0 : vector<3x4xf32> } -// CHECK-LABEL:
llvm.func @broadcast_stretch_at_start -// CHECK: llvm.mlir.undef : !llvm<"[3 x <4 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <4 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <4 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <4 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <4 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <4 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <4 x float>]"> -// CHECK: llvm.return {{.*}} : !llvm<"[3 x <4 x float>]"> +// CHECK-LABEL: llvm.func @broadcast_stretch_at_start( +// CHECK-SAME: %[[A:.*]]: !llvm<"[1 x <4 x float>]">) +// CHECK: %[[T0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<3x4xf32>) : !llvm<"[3 x <4 x float>]"> +// CHECK: %[[T1:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[1 x <4 x float>]"> +// CHECK: %[[T2:.*]] = llvm.insertvalue %[[T1]], %[[T0]][0] : !llvm<"[3 x <4 x float>]"> +// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][1] : !llvm<"[3 x <4 x float>]"> +// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T1]], %[[T3]][2] : !llvm<"[3 x <4 x float>]"> +// CHECK: llvm.return %[[T4]] : !llvm<"[3 x <4 x float>]"> func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> { %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<4x3xf32> return %0 : vector<4x3xf32> } -// CHECK-LABEL: llvm.func @broadcast_stretch_at_end -// CHECK: llvm.mlir.undef : !llvm<"[4 x <3 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[4 x <1 x float>]"> -// CHECK: llvm.mlir.undef : !llvm<"<3 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>"> -// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x 
float>"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x <3 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"[4 x <1 x float>]"> -// CHECK: llvm.mlir.undef : !llvm<"<3 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>"> -// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x <3 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x <1 x float>]"> -// CHECK: llvm.mlir.undef : !llvm<"<3 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>"> -// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x <3 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <1 x float>]"> -// CHECK: llvm.mlir.undef : !llvm<"<3 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.extractelement {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<1 x float>"> -// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<3 x float>"> -// CHECK: llvm.shufflevector {{.*}}, {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <3 x float>]"> -// CHECK: llvm.return {{.*}} : !llvm<"[4 x <3 x float>]"> +// CHECK-LABEL: llvm.func @broadcast_stretch_at_end( +// CHECK-SAME: %[[A:.*]]: !llvm<"[4 x <1 x 
float>]">) +// CHECK: %[[T0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<4x3xf32>) : !llvm<"[4 x <3 x float>]"> +// CHECK: %[[T1:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[4 x <1 x float>]"> +// CHECK: %[[T2:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64 +// CHECK: %[[T3:.*]] = llvm.extractelement %[[T1]][%[[T2]] : !llvm.i64] : !llvm<"<1 x float>"> +// CHECK: %[[T4:.*]] = llvm.mlir.undef : !llvm<"<3 x float>"> +// CHECK: %[[T5:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 +// CHECK: %[[T6:.*]] = llvm.insertelement %[[T3]], %[[T4]][%[[T5]] : !llvm.i32] : !llvm<"<3 x float>"> +// CHECK: %[[T7:.*]] = llvm.shufflevector %[[T6]], %[[T4]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> +// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T7]], %[[T0]][0] : !llvm<"[4 x <3 x float>]"> +// CHECK: %[[T9:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"[4 x <1 x float>]"> +// CHECK: %[[T10:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64 +// CHECK: %[[T11:.*]] = llvm.extractelement %[[T9]][%[[T10]] : !llvm.i64] : !llvm<"<1 x float>"> +// CHECK: %[[T12:.*]] = llvm.mlir.undef : !llvm<"<3 x float>"> +// CHECK: %[[T13:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 +// CHECK: %[[T14:.*]] = llvm.insertelement %[[T11]], %[[T12]][%[[T13]] : !llvm.i32] : !llvm<"<3 x float>"> +// CHECK: %[[T15:.*]] = llvm.shufflevector %[[T14]], %[[T12]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> +// CHECK: %[[T16:.*]] = llvm.insertvalue %[[T15]], %[[T8]][1] : !llvm<"[4 x <3 x float>]"> +// CHECK: %[[T17:.*]] = llvm.extractvalue %[[A]][2] : !llvm<"[4 x <1 x float>]"> +// CHECK: %[[T18:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64 +// CHECK: %[[T19:.*]] = llvm.extractelement %[[T17]][%[[T18]] : !llvm.i64] : !llvm<"<1 x float>"> +// CHECK: %[[T20:.*]] = llvm.mlir.undef : !llvm<"<3 x float>"> +// CHECK: %[[T21:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 +// CHECK: %[[T22:.*]] = llvm.insertelement %[[T19]], %[[T20]][%[[T21]] : !llvm.i32] : !llvm<"<3 x 
float>"> +// CHECK: %[[T23:.*]] = llvm.shufflevector %[[T22]], %[[T20]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> +// CHECK: %[[T24:.*]] = llvm.insertvalue %[[T23]], %[[T16]][2] : !llvm<"[4 x <3 x float>]"> +// CHECK: %[[T25:.*]] = llvm.extractvalue %[[A]][3] : !llvm<"[4 x <1 x float>]"> +// CHECK: %[[T26:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64 +// CHECK: %[[T27:.*]] = llvm.extractelement %[[T25]][%[[T26]] : !llvm.i64] : !llvm<"<1 x float>"> +// CHECK: %[[T28:.*]] = llvm.mlir.undef : !llvm<"<3 x float>"> +// CHECK: %[[T29:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 +// CHECK: %[[T30:.*]] = llvm.insertelement %[[T27]], %[[T28]][%[[T29]] : !llvm.i32] : !llvm<"<3 x float>"> +// CHECK: %[[T31:.*]] = llvm.shufflevector %[[T30]], %[[T28]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> +// CHECK: %[[T32:.*]] = llvm.insertvalue %[[T31]], %[[T24]][3] : !llvm<"[4 x <3 x float>]"> +// CHECK: llvm.return %[[T32]] : !llvm<"[4 x <3 x float>]"> func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32> { %0 = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32> return %0 : vector<4x3x2xf32> } -// CHECK-LABEL: llvm.func @broadcast_stretch_in_middle -// CHECK: llvm.mlir.undef : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[4 x [1 x <2 x float>]]"> -// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.extractvalue {{.*}}[1] : 
!llvm<"[4 x [1 x <2 x float>]]"> -// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x [1 x <2 x float>]]"> -// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x [1 x <2 x float>]]"> -// CHECK: llvm.mlir.undef : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[1 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[3 x <2 x float>]"> -// CHECK: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x [3 x <2 x float>]]"> -// CHECK: llvm.return {{.*}} : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK-LABEL: llvm.func 
@broadcast_stretch_in_middle( +// CHECK-SAME: %[[A:.*]]: !llvm<"[4 x [1 x <2 x float>]]">) +// CHECK: %[[T0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<4x3x2xf32>) : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: %[[T1:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<3x2xf32>) : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T2:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[4 x [1 x <2 x float>]]"> +// CHECK: %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm<"[1 x <2 x float>]"> +// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T1]][0] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T3]], %[[T4]][1] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T6:.*]] = llvm.insertvalue %[[T3]], %[[T5]][2] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T0]][0] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: %[[T8:.*]] = llvm.extractvalue %[[A]][1] : !llvm<"[4 x [1 x <2 x float>]]"> +// CHECK: %[[T9:.*]] = llvm.extractvalue %[[T8]][0] : !llvm<"[1 x <2 x float>]"> +// CHECK: %[[T10:.*]] = llvm.insertvalue %[[T9]], %[[T1]][0] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T11:.*]] = llvm.insertvalue %[[T9]], %[[T10]][1] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T12:.*]] = llvm.insertvalue %[[T9]], %[[T11]][2] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T13:.*]] = llvm.insertvalue %[[T12]], %[[T7]][1] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: %[[T14:.*]] = llvm.extractvalue %[[A]][2] : !llvm<"[4 x [1 x <2 x float>]]"> +// CHECK: %[[T15:.*]] = llvm.extractvalue %[[T14]][0] : !llvm<"[1 x <2 x float>]"> +// CHECK: %[[T16:.*]] = llvm.insertvalue %[[T15]], %[[T1]][0] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T17:.*]] = llvm.insertvalue %[[T15]], %[[T16]][1] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T18:.*]] = llvm.insertvalue %[[T15]], %[[T17]][2] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T19:.*]] = llvm.insertvalue %[[T18]], %[[T13]][2] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: %[[T20:.*]] = llvm.extractvalue %[[A]][3] 
: !llvm<"[4 x [1 x <2 x float>]]"> +// CHECK: %[[T21:.*]] = llvm.extractvalue %[[T20]][0] : !llvm<"[1 x <2 x float>]"> +// CHECK: %[[T22:.*]] = llvm.insertvalue %[[T21]], %[[T1]][0] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T23:.*]] = llvm.insertvalue %[[T21]], %[[T22]][1] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T24:.*]] = llvm.insertvalue %[[T21]], %[[T23]][2] : !llvm<"[3 x <2 x float>]"> +// CHECK: %[[T25:.*]] = llvm.insertvalue %[[T24]], %[[T19]][3] : !llvm<"[4 x [3 x <2 x float>]]"> +// CHECK: llvm.return %[[T25]] : !llvm<"[4 x [3 x <2 x float>]]"> func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32> { %2 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32> @@ -211,16 +209,16 @@ // CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64 // CHECK: %[[T2:.*]] = llvm.extractelement %[[A]][%[[T1]] : !llvm.i64] : !llvm<"<2 x float>"> // CHECK: %[[T3:.*]] = llvm.mlir.undef : !llvm<"<3 x float>"> -// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: %[[T5:.*]] = llvm.insertelement %[[T2]], %[[T3]][%4 : !llvm.i64] : !llvm<"<3 x float>"> +// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 +// CHECK: %[[T5:.*]] = llvm.insertelement %[[T2]], %[[T3]][%4 : !llvm.i32] : !llvm<"<3 x float>"> // CHECK: %[[T6:.*]] = llvm.shufflevector %[[T5]], %[[T3]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> // CHECK: %[[T7:.*]] = llvm.fmul %[[T6]], %[[B]] : !llvm<"<3 x float>"> // CHECK: %[[T8:.*]] = llvm.insertvalue %[[T7]], %[[T0]][0] : !llvm<"[2 x <3 x float>]"> // CHECK: %[[T9:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64 // CHECK: %[[T10:.*]] = llvm.extractelement %[[A]][%9 : !llvm.i64] : !llvm<"<2 x float>"> // CHECK: %[[T11:.*]] = llvm.mlir.undef : !llvm<"<3 x float>"> -// CHECK: %[[T12:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: %[[T13:.*]] = llvm.insertelement %[[T10]], %[[T11]][%12 : !llvm.i64] : !llvm<"<3 x float>"> +// CHECK: %[[T12:.*]] = 
llvm.mlir.constant(0 : i32) : !llvm.i32 +// CHECK: %[[T13:.*]] = llvm.insertelement %[[T10]], %[[T11]][%12 : !llvm.i32] : !llvm<"<3 x float>"> // CHECK: %[[T14:.*]] = llvm.shufflevector %[[T13]], %[[T11]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> // CHECK: %[[T15:.*]] = llvm.fmul %[[T14]], %[[B]] : !llvm<"<3 x float>"> // CHECK: %[[T16:.*]] = llvm.insertvalue %[[T15]], %[[T8]][1] : !llvm<"[2 x <3 x float>]"> @@ -238,8 +236,8 @@ // CHECK: %[[T1:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64 // CHECK: %[[T2:.*]] = llvm.extractelement %[[A]][%[[T1]] : !llvm.i64] : !llvm<"<2 x float>"> // CHECK: %[[T3:.*]] = llvm.mlir.undef : !llvm<"<3 x float>"> -// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: %[[T5:.*]] = llvm.insertelement %[[T2]], %[[T3]][%[[T4]] : !llvm.i64] : !llvm<"<3 x float>"> +// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 +// CHECK: %[[T5:.*]] = llvm.insertelement %[[T2]], %[[T3]][%[[T4]] : !llvm.i32] : !llvm<"<3 x float>"> // CHECK: %[[T6:.*]] = llvm.shufflevector %[[T5]], %[[T3]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> // CHECK: %[[T7:.*]] = llvm.extractvalue %[[C]][0] : !llvm<"[2 x <3 x float>]"> // CHECK: %[[T8:.*]] = "llvm.intr.fma"(%[[T6]], %[[B]], %[[T7]]) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) @@ -247,8 +245,8 @@ // CHECK: %[[T10:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64 // CHECK: %[[T11:.*]] = llvm.extractelement %[[A]][%[[T10]] : !llvm.i64] : !llvm<"<2 x float>"> // CHECK: %[[T12:.*]] = llvm.mlir.undef : !llvm<"<3 x float>"> -// CHECK: %[[T13:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK: %[[T14:.*]] = llvm.insertelement %[[T11]], %[[T12]][%[[T13]] : !llvm.i64] : !llvm<"<3 x float>"> +// CHECK: %[[T13:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 +// CHECK: %[[T14:.*]] = llvm.insertelement %[[T11]], %[[T12]][%[[T13]] : !llvm.i32] : !llvm<"<3 x float>"> // CHECK: %[[T15:.*]] = llvm.shufflevector 
%[[T14]], %[[T12]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>"> // CHECK: %[[T16:.*]] = llvm.extractvalue %[[C]][1] : !llvm<"[2 x <3 x float>]"> // CHECK: %[[T17:.*]] = "llvm.intr.fma"(%[[T15]], %[[B]], %[[T16]]) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -257,11 +257,11 @@ // CHECK-SAME: %[[B:.*1]]: vector<3xf32> // CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<2x3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> -// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32> +// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xf32> // CHECK: %[[T2:.*]] = mulf %[[T1]], %[[B]] : vector<3xf32> // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> // CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xf32> -// CHECK: %[[T5:.*]] = vector.broadcast %[[T4]] : f32 to vector<3xf32> +// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<3xf32> // CHECK: %[[T6:.*]] = mulf %[[T5]], %[[B]] : vector<3xf32> // CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32> // CHECK: return %[[T7]] : vector<2x3xf32> @@ -278,12 +278,12 @@ // CHECK-SAME: %[[C:.*2]]: vector<2x3xf32> // CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<2x3xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> -// CHECK: %[[T1:.*]] = vector.broadcast %[[T0]] : f32 to vector<3xf32> +// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xf32> // CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xf32> // CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32> // CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> // 
CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2xf32> -// CHECK: %[[T6:.*]] = vector.broadcast %[[T5]] : f32 to vector<3xf32> +// CHECK: %[[T6:.*]] = splat %[[T5]] : vector<3xf32> // CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<2x3xf32> // CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32> // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32> @@ -389,3 +389,173 @@ : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32> return %0 : vector<2x3xf32> } + +// CHECK-LABEL: func @broadcast_vec1d_from_scalar +// CHECK-SAME: %[[A:.*0]]: f32 +// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2xf32> +// CHECK: return %[[T0]] : vector<2xf32> + +func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<2xf32> + return %0 : vector<2xf32> +} + +// CHECK-LABEL: func @broadcast_vec2d_from_scalar +// CHECK-SAME: %[[A:.*0]]: f32 +// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2x3xf32> +// CHECK: return %[[T0]] : vector<2x3xf32> + +func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<2x3xf32> + return %0 : vector<2x3xf32> +} + +// CHECK-LABEL: func @broadcast_vec3d_from_scalar +// CHECK-SAME: %[[A:.*0]]: f32 +// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2x3x4xf32> +// CHECK: return %[[T0]] : vector<2x3x4xf32> + +func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<2x3x4xf32> + return %0 : vector<2x3x4xf32> +} + +// CHECK-LABEL: func @broadcast_vec1d_from_vec1d +// CHECK-SAME: %[[A:.*0]]: vector<2xf32> +// CHECK: return %[[A]] : vector<2xf32> + +func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> { + %0 = vector.broadcast %arg0 : vector<2xf32> to vector<2xf32> + return %0 : vector<2xf32> +} + +// CHECK-LABEL: func @broadcast_vec2d_from_vec1d +// CHECK-SAME: %[[A:.*0]]: vector<2xf32> +// CHECK: %[[C0:.*]] = constant 
dense<0.000000e+00> : vector<3x2xf32> +// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: return %[[T2]] : vector<3x2xf32> + +func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> { + %0 = vector.broadcast %arg0 : vector<2xf32> to vector<3x2xf32> + return %0 : vector<3x2xf32> +} + +// CHECK-LABEL: func @broadcast_vec3d_from_vec1d +// CHECK-SAME: %[[A:.*0]]: vector<2xf32> +// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<3x2xf32> +// CHECK: %[[C1:.*]] = constant dense<0.000000e+00> : vector<4x3x2xf32> +// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C1]] [0] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T4:.*]] = vector.insert %[[T2]], %[[T3]] [1] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T5:.*]] = vector.insert %[[T2]], %[[T4]] [2] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T6:.*]] = vector.insert %[[T2]], %[[T5]] [3] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: return %[[T6]] : vector<4x3x2xf32> + +func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> { + %0 = vector.broadcast %arg0 : vector<2xf32> to vector<4x3x2xf32> + return %0 : vector<4x3x2xf32> +} + +// CHECK-LABEL: func @broadcast_vec3d_from_vec2d +// CHECK-SAME: %[[A:.*0]]: vector<3x2xf32> +// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<4x3x2xf32> +// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T1:.*]] = 
vector.insert %[[A]], %[[T0]] [1] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[A]], %[[T2]] [3] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: return %[[T3]] : vector<4x3x2xf32> + +func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> { + %0 = vector.broadcast %arg0 : vector<3x2xf32> to vector<4x3x2xf32> + return %0 : vector<4x3x2xf32> +} + +// CHECK-LABEL: func @broadcast_stretch +// CHECK-SAME: %[[A:.*0]]: vector<1xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<1xf32> +// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<4xf32> +// CHECK: return %[[T1]] : vector<4xf32> + +func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> { + %0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func @broadcast_stretch_at_start +// CHECK-SAME: %[[A:.*0]]: vector<1x4xf32> +// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<3x4xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<1x4xf32> +// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C0]] [0] : vector<4xf32> into vector<3x4xf32> +// CHECK: %[[T2:.*]] = vector.insert %[[T0]], %[[T1]] [1] : vector<4xf32> into vector<3x4xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T0]], %[[T2]] [2] : vector<4xf32> into vector<3x4xf32> +// CHECK: return %[[T3]] : vector<3x4xf32> + +func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> { + %0 = vector.broadcast %arg0 : vector<1x4xf32> to vector<3x4xf32> + return %0 : vector<3x4xf32> +} + +// CHECK-LABEL: func @broadcast_stretch_at_end +// CHECK-SAME: %[[A:.*0]]: vector<4x1xf32> +// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<4x3xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x1xf32> +// CHECK: %[[T1:.*]] = vector.extract %[[T0]][0] : vector<1xf32> +// CHECK: %[[T2:.*]] = splat %[[T1]] :
vector<3xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<4x3xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<4x1xf32> +// CHECK: %[[T5:.*]] = vector.extract %[[T4]][0] : vector<1xf32> +// CHECK: %[[T6:.*]] = splat %[[T5]] : vector<3xf32> +// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32> +// CHECK: %[[T8:.*]] = vector.extract %[[A]][2] : vector<4x1xf32> +// CHECK: %[[T9:.*]] = vector.extract %[[T8]][0] : vector<1xf32> +// CHECK: %[[T10:.*]] = splat %[[T9]] : vector<3xf32> +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32> +// CHECK: %[[T12:.*]] = vector.extract %[[A]][3] : vector<4x1xf32> +// CHECK: %[[T13:.*]] = vector.extract %[[T12]][0] : vector<1xf32> +// CHECK: %[[T14:.*]] = splat %[[T13]] : vector<3xf32> +// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32> +// CHECK: return %[[T15]] : vector<4x3xf32> + +func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> { + %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<4x3xf32> + return %0 : vector<4x3xf32> +} + +// CHECK-LABEL: func @broadcast_stretch_in_middle +// CHECK-SAME: %[[A:.*0]]: vector<4x1x2xf32> +// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<4x3x2xf32> +// CHECK: %[[C1:.*]] = constant dense<0.000000e+00> : vector<3x2xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x1x2xf32> +// CHECK: %[[T1:.*]] = vector.extract %[[T0]][0] : vector<1x2xf32> +// CHECK: %[[T2:.*]] = vector.insert %[[T1]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T1]], %[[T2]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T4:.*]] = vector.insert %[[T1]], %[[T3]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T6:.*]] = 
vector.extract %[[A]][1] : vector<4x1x2xf32> +// CHECK: %[[T7:.*]] = vector.extract %[[T6]][0] : vector<1x2xf32> +// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T9:.*]] = vector.insert %[[T7]], %[[T8]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T10:.*]] = vector.insert %[[T7]], %[[T9]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T12:.*]] = vector.extract %[[A]][2] : vector<4x1x2xf32> +// CHECK: %[[T13:.*]] = vector.extract %[[T12]][0] : vector<1x2xf32> +// CHECK: %[[T14:.*]] = vector.insert %[[T13]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T15:.*]] = vector.insert %[[T13]], %[[T14]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T16:.*]] = vector.insert %[[T13]], %[[T15]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T17:.*]] = vector.insert %[[T16]], %[[T11]] [2] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T18:.*]] = vector.extract %[[A]][3] : vector<4x1x2xf32> +// CHECK: %[[T19:.*]] = vector.extract %[[T18]][0] : vector<1x2xf32> +// CHECK: %[[T20:.*]] = vector.insert %[[T19]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T21:.*]] = vector.insert %[[T19]], %[[T20]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T22:.*]] = vector.insert %[[T19]], %[[T21]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T23:.*]] = vector.insert %[[T22]], %[[T17]] [3] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: return %[[T23]] : vector<4x3x2xf32> + +func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32> { + %0 = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32> + return %0 : vector<4x3x2xf32> +}