diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -249,6 +249,30 @@
 // Patterns to Commute Extension Ops
 //===----------------------------------------------------------------------===//
 
+/// Commutes an integer extension with `vector.broadcast`:
+///   broadcast(ext(x)) --> ext(broadcast(x))
+/// so that the broadcast is performed on the narrower element type.
+struct ExtensionOverBroadcast final : NarrowingPattern<vector::BroadcastOp> {
+  using NarrowingPattern::NarrowingPattern;
+
+  LogicalResult matchAndRewrite(vector::BroadcastOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<ExtensionOp> ext =
+        ExtensionOp::from(op.getSource().getDefiningOp());
+    if (failed(ext))
+      return failure();
+
+    VectorType origTy = op.getResultVectorType();
+    // Same result shape, but with the narrow (pre-extension) element type.
+    VectorType newTy =
+        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
+    Value newBroadcast =
+        rewriter.create<vector::BroadcastOp>(op.getLoc(), newTy, ext->getIn());
+    ext->recreateAndReplace(rewriter, op, newBroadcast);
+    return success();
+  }
+};
+
 struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
   using NarrowingPattern::NarrowingPattern;
 
@@ -421,6 +445,80 @@
   }
 };
 
+/// Commutes an integer extension with `vector.shape_cast`:
+///   shape_cast(ext(x)) --> ext(shape_cast(x))
+/// so that the cast is performed on the narrower element type.
+struct ExtensionOverShapeCast final : NarrowingPattern<vector::ShapeCastOp> {
+  using NarrowingPattern::NarrowingPattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<ExtensionOp> ext =
+        ExtensionOp::from(op.getSource().getDefiningOp());
+    if (failed(ext))
+      return failure();
+
+    VectorType origTy = op.getResultVectorType();
+    // Same result shape, but with the narrow (pre-extension) element type.
+    VectorType newTy =
+        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
+    Value newCast =
+        rewriter.create<vector::ShapeCastOp>(op.getLoc(), newTy, ext->getIn());
+    ext->recreateAndReplace(rewriter, op, newCast);
+    return success();
+  }
+};
+
+/// Commutes an integer extension with `vector.transpose`:
+///   transpose(ext(x)) --> ext(transpose(x))
+/// reusing the original permutation on the narrower element type.
+struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
+  using NarrowingPattern::NarrowingPattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<ExtensionOp> ext =
+        ExtensionOp::from(op.getVector().getDefiningOp());
+    if (failed(ext))
+      return failure();
+
+    VectorType origTy = op.getResultVectorType();
+    // Same result shape, but with the narrow (pre-extension) element type.
+    VectorType newTy =
+        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
+    Value newTranspose = rewriter.create<vector::TransposeOp>(
+        op.getLoc(), newTy, ext->getIn(), op.getTransp());
+    ext->recreateAndReplace(rewriter, op, newTranspose);
+    return success();
+  }
+};
+
+/// Commutes an integer extension with `vector.flat_transpose`:
+///   flat_transpose(ext(x)) --> ext(flat_transpose(x))
+/// reusing the original `rows`/`columns` attributes on the narrower type.
+struct ExtensionOverFlatTranspose final
+    : NarrowingPattern<vector::FlatTransposeOp> {
+  using NarrowingPattern::NarrowingPattern;
+
+  LogicalResult matchAndRewrite(vector::FlatTransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<ExtensionOp> ext =
+        ExtensionOp::from(op.getMatrix().getDefiningOp());
+    if (failed(ext))
+      return failure();
+
+    VectorType origTy = op.getType();
+    // Same result shape, but with the narrow (pre-extension) element type.
+    VectorType newTy =
+        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
+    Value newTranspose = rewriter.create<vector::FlatTransposeOp>(
+        op.getLoc(), newTy, ext->getIn(), op.getRowsAttr(),
+        op.getColumnsAttr());
+    ext->recreateAndReplace(rewriter, op, newTranspose);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Pass Definitions
 //===----------------------------------------------------------------------===//
@@ -449,9 +547,11 @@
     RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) {
   // Add commute patterns with a higher benefit. This is to expose more
   // optimization opportunities to narrowing patterns.
-  patterns.add<ExtensionOverExtract, ExtensionOverExtractElement,
-               ExtensionOverExtractStridedSlice, ExtensionOverInsert,
-               ExtensionOverInsertElement, ExtensionOverInsertStridedSlice>(
+  patterns.add<ExtensionOverBroadcast, ExtensionOverExtract,
+               ExtensionOverExtractElement, ExtensionOverExtractStridedSlice,
+               ExtensionOverInsert, ExtensionOverInsertElement,
+               ExtensionOverInsertStridedSlice, ExtensionOverShapeCast,
+               ExtensionOverTranspose, ExtensionOverFlatTranspose>(
       patterns.getContext(), options, PatternBenefit(2));
 
   patterns.add<SIToFPPattern, UIToFPPattern>(patterns.getContext(), options);
diff --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir
--- a/mlir/test/Dialect/Arith/int-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-narrowing.mlir
@@ -442,3 +442,95 @@
   %e = vector.insert_strided_slice %d, %cst {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi32> into vector<2x3xi32>
   return %e : vector<2x3xi32>
 }
+
+// Extensions commute over `vector.broadcast` (scalar and vector sources).
+// CHECK-LABEL: func.func @extsi_over_broadcast_3xi16
+// CHECK-SAME:    (%[[ARG:.+]]: i16)
+// CHECK-NEXT:    %[[BCST:.+]] = vector.broadcast %[[ARG]] : i16 to vector<3xi16>
+// CHECK-NEXT:    %[[RET:.+]] = arith.extsi %[[BCST]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
+func.func @extsi_over_broadcast_3xi16(%a: i16) -> vector<3xi32> {
+  %b = arith.extsi %a : i16 to i32
+  %r = vector.broadcast %b : i32 to vector<3xi32>
+  return %r : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_broadcast_2x3xi16
+// CHECK-SAME:    (%[[ARG:.+]]: vector<3xi16>)
+// CHECK-NEXT:    %[[BCST:.+]] = vector.broadcast %[[ARG]] : vector<3xi16> to vector<2x3xi16>
+// CHECK-NEXT:    %[[RET:.+]] = arith.extui %[[BCST]] : vector<2x3xi16> to vector<2x3xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<2x3xi32>
+func.func @extui_over_broadcast_2x3xi16(%a: vector<3xi16>) -> vector<2x3xi32> {
+  %b = arith.extui %a : vector<3xi16> to vector<3xi32>
+  %r = vector.broadcast %b : vector<3xi32> to vector<2x3xi32>
+  return %r : vector<2x3xi32>
+}
+
+// Extensions commute over `vector.shape_cast`.
+// CHECK-LABEL: func.func @extsi_over_shape_cast_2x3xi16
+// CHECK-SAME:    (%[[ARG:.+]]: vector<2x3xi16>)
+// CHECK-NEXT:    %[[CAST:.+]] = vector.shape_cast %[[ARG]] : vector<2x3xi16> to vector<3x2xi16>
+// CHECK-NEXT:    %[[RET:.+]] = arith.extsi %[[CAST]] : vector<3x2xi16> to vector<3x2xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<3x2xi32>
+func.func @extsi_over_shape_cast_2x3xi16(%a: vector<2x3xi16>) -> vector<3x2xi32> {
+  %b = arith.extsi %a : vector<2x3xi16> to vector<2x3xi32>
+  %r = vector.shape_cast %b : vector<2x3xi32> to vector<3x2xi32>
+  return %r : vector<3x2xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_shape_cast_5x2x3xi16
+// CHECK-SAME:    (%[[ARG:.+]]: vector<5x2x3xi16>)
+// CHECK-NEXT:    %[[CAST:.+]] = vector.shape_cast %[[ARG]] : vector<5x2x3xi16> to vector<2x3x5xi16>
+// CHECK-NEXT:    %[[RET:.+]] = arith.extui %[[CAST]] : vector<2x3x5xi16> to vector<2x3x5xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<2x3x5xi32>
+func.func @extui_over_shape_cast_5x2x3xi16(%a: vector<5x2x3xi16>) -> vector<2x3x5xi32> {
+  %b = arith.extui %a : vector<5x2x3xi16> to vector<5x2x3xi32>
+  %r = vector.shape_cast %b : vector<5x2x3xi32> to vector<2x3x5xi32>
+  return %r : vector<2x3x5xi32>
+}
+
+// Extensions commute over `vector.transpose`, keeping the permutation.
+// CHECK-LABEL: func.func @extsi_over_transpose_2x3xi16
+// CHECK-SAME:    (%[[ARG:.+]]: vector<2x3xi16>)
+// CHECK-NEXT:    %[[TRAN:.+]] = vector.transpose %[[ARG]], [1, 0] : vector<2x3xi16> to vector<3x2xi16>
+// CHECK-NEXT:    %[[RET:.+]] = arith.extsi %[[TRAN]] : vector<3x2xi16> to vector<3x2xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<3x2xi32>
+func.func @extsi_over_transpose_2x3xi16(%a: vector<2x3xi16>) -> vector<3x2xi32> {
+  %b = arith.extsi %a : vector<2x3xi16> to vector<2x3xi32>
+  %r = vector.transpose %b, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
+  return %r : vector<3x2xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_transpose_5x2x3xi16
+// CHECK-SAME:    (%[[ARG:.+]]: vector<5x2x3xi16>)
+// CHECK-NEXT:    %[[TRAN:.+]] = vector.transpose %[[ARG]], [1, 2, 0] : vector<5x2x3xi16> to vector<2x3x5xi16>
+// CHECK-NEXT:    %[[RET:.+]] = arith.extui %[[TRAN]] : vector<2x3x5xi16> to vector<2x3x5xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<2x3x5xi32>
+func.func @extui_over_transpose_5x2x3xi16(%a: vector<5x2x3xi16>) -> vector<2x3x5xi32> {
+  %b = arith.extui %a : vector<5x2x3xi16> to vector<5x2x3xi32>
+  %r = vector.transpose %b, [1, 2, 0] : vector<5x2x3xi32> to vector<2x3x5xi32>
+  return %r : vector<2x3x5xi32>
+}
+
+// Extensions commute over `vector.flat_transpose`, keeping rows/columns.
+// CHECK-LABEL: func.func @extsi_over_flat_transpose_16xi16
+// CHECK-SAME:    (%[[ARG:.+]]: vector<16xi16>)
+// CHECK-NEXT:    %[[TRAN:.+]] = vector.flat_transpose %[[ARG]] {columns = 4 : i32, rows = 4 : i32} : vector<16xi16> -> vector<16xi16>
+// CHECK-NEXT:    %[[RET:.+]] = arith.extsi %[[TRAN]] : vector<16xi16> to vector<16xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<16xi32>
+func.func @extsi_over_flat_transpose_16xi16(%a: vector<16xi16>) -> vector<16xi32> {
+  %b = arith.extsi %a : vector<16xi16> to vector<16xi32>
+  %r = vector.flat_transpose %b {columns = 4 : i32, rows = 4 : i32} : vector<16xi32> -> vector<16xi32>
+  return %r : vector<16xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_flat_transpose_16xi16
+// CHECK-SAME:    (%[[ARG:.+]]: vector<16xi16>)
+// CHECK-NEXT:    %[[TRAN:.+]] = vector.flat_transpose %[[ARG]] {columns = 8 : i32, rows = 2 : i32} : vector<16xi16> -> vector<16xi16>
+// CHECK-NEXT:    %[[RET:.+]] = arith.extui %[[TRAN]] : vector<16xi16> to vector<16xi32>
+// CHECK-NEXT:    return %[[RET]] : vector<16xi32>
+func.func @extui_over_flat_transpose_16xi16(%a: vector<16xi16>) -> vector<16xi32> {
+  %b = arith.extui %a : vector<16xi16> to vector<16xi32>
+  %r = vector.flat_transpose %b {columns = 8 : i32, rows = 2 : i32} : vector<16xi32> -> vector<16xi32>
+  return %r : vector<16xi32>
+}