diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -646,6 +646,91 @@
   }
 };
 
+/// Decomposes ShapeCastOp on tuple-of-vectors to multiple ShapeCastOps, each
+/// on vector types.
+struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                     PatternRewriter &rewriter) const override {
+    // Check if 'shapeCastOp' has tuple source/result type.
+    auto sourceTupleType =
+        shapeCastOp.source().getType().dyn_cast_or_null<TupleType>();
+    auto resultTupleType =
+        shapeCastOp.result().getType().dyn_cast_or_null<TupleType>();
+    if (!sourceTupleType || !resultTupleType)
+      return matchFailure();
+    assert(sourceTupleType.size() == resultTupleType.size());
+
+    // Create single-vector ShapeCastOp for each source tuple element.
+    Location loc = shapeCastOp.getLoc();
+    SmallVector<Value, 8> resultElements;
+    resultElements.reserve(resultTupleType.size());
+    for (unsigned i = 0, e = sourceTupleType.size(); i < e; ++i) {
+      auto sourceElement = rewriter.create<vector::TupleGetOp>(
+          loc, sourceTupleType.getType(i), shapeCastOp.source(),
+          rewriter.getI64IntegerAttr(i));
+      resultElements.push_back(rewriter.create<vector::ShapeCastOp>(
+          loc, resultTupleType.getType(i), sourceElement));
+    }
+
+    // Replace 'shapeCastOp' with tuple of 'resultElements'.
+    rewriter.replaceOpWithNewOp<vector::TupleOp>(shapeCastOp, resultTupleType,
+                                                 resultElements);
+    return matchSuccess();
+  }
+};
+
+/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
+///
+/// Example:
+///
+/// The following MLIR with cancelling ShapeCastOps:
+///
+///   %0 = source : vector<5x4x2xf32>
+///   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
+///   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
+///   %3 = user %2 : vector<5x4x2xf32>
+///
+/// Should canonicalize to the following:
+///
+///   %0 = source : vector<5x4x2xf32>
+///   %1 = user %0 : vector<5x4x2xf32>
+///
+/// Tuple-typed casts are not folded here; see ShapeCastOpDecomposer.
+struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                     PatternRewriter &rewriter) const override {
+    // Check if 'shapeCastOp' has vector source/result type.
+    auto sourceVectorType =
+        shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
+    auto resultVectorType =
+        shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
+    if (!sourceVectorType || !resultVectorType)
+      return matchFailure();
+
+    // Check if shape cast op source operand is also a shape cast op.
+    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
+        shapeCastOp.source().getDefiningOp());
+    if (!sourceShapeCastOp)
+      return matchFailure();
+    auto operandSourceVectorType =
+        sourceShapeCastOp.source().getType().cast<VectorType>();
+    auto operandResultVectorType =
+        sourceShapeCastOp.result().getType().cast<VectorType>();
+
+    // Check if shape cast operations invert each other.
+    if (operandSourceVectorType != resultVectorType ||
+        operandResultVectorType != sourceVectorType)
+      return matchFailure();
+
+    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
+    return matchSuccess();
+  }
+};
+
 // Patter rewrite which forward tuple elements to their users.
 // User(TupleGetOp(ExtractSlicesOp(InsertSlicesOp(TupleOp(Producer)))))
 // -> User(Producer)
@@ -784,8 +869,8 @@
 // TODO(andydavis) Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.insert( - context); + patterns.insert(context); } void mlir::vector::populateVectorSlicesLoweringPatterns( diff --git a/mlir/test/Dialect/VectorOps/vector-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-transforms.mlir --- a/mlir/test/Dialect/VectorOps/vector-transforms.mlir +++ b/mlir/test/Dialect/VectorOps/vector-transforms.mlir @@ -346,3 +346,62 @@ return } + +// Test that ShapeCastOp on tuple of vectors, decomposes to multiple +// ShapeCastOps on vectors. +// CHECK-LABEL: func @shape_cast_decomposition +// CHECK: %[[V0:.*]] = vector.shape_cast %{{.*}} : vector<5x4x2xf32> to vector<20x2xf32> +// CHECK-NEXT: %[[V1:.*]] = vector.shape_cast %{{.*}} : vector<3x4x2xf32> to vector<12x2xf32> +// CHECK-NEXT: return %[[V0]], %[[V1]] : vector<20x2xf32>, vector<12x2xf32> + +func @shape_cast_decomposition(%arg0 : vector<5x4x2xf32>, + %arg1 : vector<3x4x2xf32>) + -> (vector<20x2xf32>, vector<12x2xf32>) { + %0 = vector.tuple %arg0, %arg1 : vector<5x4x2xf32>, vector<3x4x2xf32> + %1 = vector.shape_cast %0 : tuple, vector<3x4x2xf32>> to + tuple, vector<12x2xf32>> + %2 = vector.tuple_get %1, 0 : tuple, vector<12x2xf32>> + %3 = vector.tuple_get %1, 1 : tuple, vector<12x2xf32>> + return %2, %3 : vector<20x2xf32>, vector<12x2xf32> +} + +// Test that cancelling ShapeCastOps are canonicalized away. +// EX: +// +// The following MLIR with cancelling ShapeCastOps: +// +// %0 = source : vector<5x4x2xf32> +// %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32> +// %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32> +// %3 = user %2 : vector<5x4x2xf32> +// +// Should canonicalize to the following: +// +// +// %0 = source : vector<5x4x2xf32> +// %1 = user %0 : vector<5x4x2xf32> +// + +// ShapeCastOps on vectors. 
+// CHECK-LABEL: func @shape_cast_fold
+// CHECK-SAME: (%[[A0:.*]]: vector<5x4x2xf32>, %[[A1:.*]]: vector<3x4x2xf32>)
+// CHECK: return %[[A0]], %[[A1]] : vector<5x4x2xf32>, vector<3x4x2xf32>
+
+func @shape_cast_fold(%arg0 : vector<5x4x2xf32>, %arg1 : vector<3x4x2xf32>)
+                      -> (vector<5x4x2xf32>, vector<3x4x2xf32>) {
+  %0 = vector.tuple %arg0, %arg1 : vector<5x4x2xf32>, vector<3x4x2xf32>
+
+  %1 = vector.shape_cast %0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
+                              tuple<vector<20x2xf32>, vector<12x2xf32>>
+
+  %2 = vector.tuple_get %1, 0 : tuple<vector<20x2xf32>, vector<12x2xf32>>
+  %3 = vector.tuple_get %1, 1 : tuple<vector<20x2xf32>, vector<12x2xf32>>
+
+  %4 = vector.tuple %2, %3 : vector<20x2xf32>, vector<12x2xf32>
+  %5 = vector.shape_cast %4 : tuple<vector<20x2xf32>, vector<12x2xf32>> to
+                              tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
+
+  %6 = vector.tuple_get %5, 0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
+  %7 = vector.tuple_get %5, 1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
+  return %6, %7 : vector<5x4x2xf32>, vector<3x4x2xf32>
+}