diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1466,6 +1466,62 @@
   }
 };
 
+// We typically should not lower general shape cast operations into data
+// movement instructions, since the assumption is that these casts are
+// optimized away during progressive lowering. For completeness, however,
+// we fall back to a reference implementation that moves all elements
+// into the right place if we get here.
+class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
+public:
+  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto sourceVectorType = op.getSourceVectorType();
+    auto resultVectorType = op.getResultVectorType();
+    // 2D->1D and 1D->2D casts have dedicated, better lowerings; skip them.
+    int64_t srcRank = sourceVectorType.getRank();
+    int64_t resRank = resultVectorType.getRank();
+    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
+      return failure();
+    // Compute number of elements involved in the reshape.
+    int64_t numElts = sourceVectorType.getNumElements();
+    // Replace with data movement operations:
+    //    x[0,0,0] = y[0,0]
+    //    x[0,0,1] = y[0,1]
+    //    x[0,1,0] = y[0,2]
+    // etc., incrementing the two index vectors "row-major"
+    // within the source and result shape.
+    SmallVector<int64_t, 4> srcIdx(srcRank);
+    SmallVector<int64_t, 4> resIdx(resRank);
+    Value result = rewriter.create<ConstantOp>(
+        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
+    for (int64_t i = 0; i < numElts; i++) {
+      if (i != 0) {
+        incIdx(srcIdx, sourceVectorType, srcRank - 1);
+        incIdx(resIdx, resultVectorType, resRank - 1);
+      }
+      Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
+      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+private:
+  // Advances `idx` to the next position within the shape of `tp` in
+  // row-major order, starting at dimension `r` and carrying into outer
+  // dimensions on wrap-around; must not be called on the last position.
+  static void incIdx(SmallVectorImpl<int64_t> &idx, VectorType tp, int64_t r) {
+    assert(0 <= r && r < tp.getRank());
+    if (++idx[r] == tp.getDimSize(r)) {
+      idx[r] = 0;
+      incIdx(idx, tp, r - 1);
+    }
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -1864,7 +1920,8 @@
                   ConstantMaskOpLowering, OuterProductOpLowering,
                   ShapeCastOp2DDownCastRewritePattern,
-                  ShapeCastOp2DUpCastRewritePattern>(context);
+                  ShapeCastOp2DUpCastRewritePattern,
+                  ShapeCastOpRewritePattern>(context);
   patterns.insert
 }
-
 // CHECK-LABEL: func @nop_shape_cast
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>
 // CHECK: return %[[A]] : vector<16xf32>
@@ -378,6 +377,72 @@
   return %r0, %1 : vector<4xf32>, vector<2x2xf32>
 }
 
+// CHECK-LABEL: func @shape_cast_2d2d
+// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
+// CHECK: %[[C:.*]] = constant dense<0.000000e+00> : vector<2x3xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<3x2xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : f32 into vector<2x3xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<3x2xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<3x2xf32>
+// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32>
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<3x2xf32>
+// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<3x2xf32>
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32>
+// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : vector<3x2xf32>
+// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32>
+// CHECK: return %[[T11]] : vector<2x3xf32>
+
+func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
+  %s = vector.shape_cast %arg0 : vector<3x2xf32> to vector<2x3xf32>
+  return %s : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @shape_cast_3d1d
+// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
+// CHECK: %[[C:.*]] = constant dense<0.000000e+00> : vector<6xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0, 0] : vector<1x3x2xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0] : f32 into vector<6xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 0, 1] : vector<1x3x2xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : f32 into vector<6xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1, 0] : vector<1x3x2xf32>
+// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : f32 into vector<6xf32>
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][0, 1, 1] : vector<1x3x2xf32>
+// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [3] : f32 into vector<6xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2, 0] : vector<1x3x2xf32>
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [4] : f32 into vector<6xf32>
+// CHECK: %[[T10:.*]] = vector.extract %[[A]][0, 2, 1] : vector<1x3x2xf32>
+// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [5] : f32 into vector<6xf32>
+// CHECK: return %[[T11]] : vector<6xf32>
+
+func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
+  %s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
+  return %s : vector<6xf32>
+}
+
+// CHECK-LABEL: func @shape_cast_1d3d
+// CHECK-SAME: %[[A:.*]]: vector<6xf32>
+// CHECK: %[[C:.*]] = constant dense<0.000000e+00> : vector<2x1x3xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<6xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0, 0] : f32 into vector<2x1x3xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[A]][1] : vector<6xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 0, 1] : f32 into vector<2x1x3xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][2] : vector<6xf32>
+// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 0, 2] : f32 into vector<2x1x3xf32>
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][3] : vector<6xf32>
+// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0, 0] : f32 into vector<2x1x3xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[A]][4] : vector<6xf32>
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 0, 1] : f32 into vector<2x1x3xf32>
+// CHECK: %[[T10:.*]] = vector.extract %[[A]][5] : vector<6xf32>
+// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 0, 2] : f32 into vector<2x1x3xf32>
+// CHECK: return %[[T11]] : vector<2x1x3xf32>
+
+func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
+  %s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>
+  return %s : vector<2x1x3xf32>
+}
+
 // MATRIX-LABEL: func @matmul
 // MATRIX-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
 // MATRIX-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,