diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp --- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp @@ -306,27 +306,35 @@ } }; -struct ExtensionOverInsert final : NarrowingPattern { - using NarrowingPattern::NarrowingPattern; - - LogicalResult matchAndRewrite(vector::InsertOp op, - PatternRewriter &rewriter) const override { +/// Base pattern for narrowing patterns over vector insertion ops. +template +struct ExtensionOverInsertionPattern : NarrowingPattern { + using NarrowingPattern::NarrowingPattern; + + /// Derived classes must provide a function to create the matching insertion +  /// op based on the original op and new arguments. +  virtual InsertionOp createInsertionOp(PatternRewriter &rewriter, + InsertionOp origInsert, + Value narrowValue, + Value narrowDest) const = 0; + + LogicalResult matchAndRewrite(InsertionOp op, + PatternRewriter &rewriter) const final { FailureOr ext = ExtensionOp::from(op.getSource().getDefiningOp()); if (failed(ext)) return failure(); - FailureOr newInsert = - createNarrowInsert(op, rewriter, *ext); + FailureOr newInsert = createNarrowInsert(op, rewriter, *ext); if (failed(newInsert)) return failure(); ext->recreateAndReplace(rewriter, op, *newInsert); return success(); } - FailureOr createNarrowInsert(vector::InsertOp op, - PatternRewriter &rewriter, - ExtensionOp insValue) const { + FailureOr createNarrowInsert(InsertionOp op, + PatternRewriter &rewriter, + ExtensionOp insValue) const { // Calculate the operand and result bitwidths. We can only apply narrowing // when the inserted source value and destination vector require fewer bits // than the result. Because the source and destination may have different @@ -337,6 +345,8 @@ if (failed(origBitsRequired)) return failure(); + // TODO: We could relax this check by disregarding bitwidth requirements of + // elements that we know will be replaced by the insertion.
FailureOr destBitsRequired = calculateBitsRequired(op.getDest(), insValue.getKind()); if (failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired) @@ -352,12 +362,13 @@ // both the source and the destination values. unsigned newInsertionBits = std::max(*destBitsRequired, *insertedBitsRequired); - FailureOr newVecTy = getNarrowType(newInsertionBits, op.getType()); + FailureOr newVecTy = + this->getNarrowType(newInsertionBits, op.getType()); if (failed(newVecTy) || *newVecTy == op.getType()) return failure(); FailureOr newInsertedValueTy = - getNarrowType(newInsertionBits, insValue.getType()); + this->getNarrowType(newInsertionBits, insValue.getType()); if (failed(newInsertedValueTy)) return failure(); @@ -366,8 +377,47 @@ loc, *newInsertedValueTy, insValue.getResult()); Value narrowDest = rewriter.createOrFold(loc, *newVecTy, op.getDest()); - return rewriter.create(loc, narrowValue, narrowDest, - op.getPosition()); + return createInsertionOp(rewriter, op, narrowValue, narrowDest); + } +}; + +struct ExtensionOverInsert final + : ExtensionOverInsertionPattern { + using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern; + + vector::InsertOp createInsertionOp(PatternRewriter &rewriter, + vector::InsertOp origInsert, + Value narrowValue, + Value narrowDest) const override { + return rewriter.create( + origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition()); + } +}; + +struct ExtensionOverInsertElement final + : ExtensionOverInsertionPattern { + using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern; + + vector::InsertElementOp createInsertionOp(PatternRewriter &rewriter, + vector::InsertElementOp origInsert, + Value narrowValue, + Value narrowDest) const override { + return rewriter.create( + origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition()); + } +}; + +struct ExtensionOverInsertStridedSlice final + : ExtensionOverInsertionPattern { + using
ExtensionOverInsertionPattern::ExtensionOverInsertionPattern; + + vector::InsertStridedSliceOp + createInsertionOp(PatternRewriter &rewriter, + vector::InsertStridedSliceOp origInsert, Value narrowValue, + Value narrowDest) const override { + return rewriter.create( + origInsert.getLoc(), narrowValue, narrowDest, origInsert.getOffsets(), + origInsert.getStrides()); } }; @@ -400,7 +450,8 @@ // Add commute patterns with a higher benefit. This is to expose more // optimization opportunities to narrowing patterns. patterns.add( + ExtensionOverExtractStridedSlice, ExtensionOverInsert, + ExtensionOverInsertElement, ExtensionOverInsertStridedSlice>( patterns.getContext(), options, PatternBenefit(2)); patterns.add(patterns.getContext(), options); diff --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir --- a/mlir/test/Dialect/Arith/int-narrowing.mlir +++ b/mlir/test/Dialect/Arith/int-narrowing.mlir @@ -328,3 +328,117 @@ %e = vector.insert %d, %cst [1] : i32 into vector<3xi32> return %e : vector<3xi32> } + +// CHECK-LABEL: func.func @extsi_over_insertelement_3xi16 +// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16, %[[POS:.+]]: i32) +// CHECK-NEXT: %[[INS:.+]] = vector.insertelement %[[ARG1]], %[[ARG0]][%[[POS]] : i32] : vector<3xi16> +// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32> +// CHECK-NEXT: return %[[RET]] : vector<3xi32> +func.func @extsi_over_insertelement_3xi16(%a: vector<3xi16>, %b: i16, %pos: i32) -> vector<3xi32> { + %c = arith.extsi %a : vector<3xi16> to vector<3xi32> + %d = arith.extsi %b : i16 to i32 + %e = vector.insertelement %d, %c[%pos : i32] : vector<3xi32> + return %e : vector<3xi32> +} + +// CHECK-LABEL: func.func @extui_over_insertelement_3xi16 +// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16, %[[POS:.+]]: i32) +// CHECK-NEXT: %[[INS:.+]] = vector.insertelement %[[ARG1]], %[[ARG0]][%[[POS]] : i32] : vector<3xi16> +// CHECK-NEXT: %[[RET:.+]] =
arith.extui %[[INS]] : vector<3xi16> to vector<3xi32> +// CHECK-NEXT: return %[[RET]] : vector<3xi32> +func.func @extui_over_insertelement_3xi16(%a: vector<3xi16>, %b: i16, %pos: i32) -> vector<3xi32> { + %c = arith.extui %a : vector<3xi16> to vector<3xi32> + %d = arith.extui %b : i16 to i32 + %e = vector.insertelement %d, %c[%pos : i32] : vector<3xi32> + return %e : vector<3xi32> +} + +// CHECK-LABEL: func.func @extsi_over_insertelement_3xi16_cst_i16 +// CHECK-SAME: (%[[ARG:.+]]: i8, %[[POS:.+]]: i32) +// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[-1, 128, 0]> : vector<3xi16> +// CHECK-NEXT: %[[SRCE:.+]] = arith.extsi %[[ARG]] : i8 to i32 +// CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16 +// CHECK-NEXT: %[[INS:.+]] = vector.insertelement %[[SRCT]], %[[CST]][%[[POS]] : i32] : vector<3xi16> +// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32> +// CHECK-NEXT: return %[[RET]] : vector<3xi32> +func.func @extsi_over_insertelement_3xi16_cst_i16(%a: i8, %pos: i32) -> vector<3xi32> { + %cst = arith.constant dense<[-1, 128, 0]> : vector<3xi32> + %d = arith.extsi %a : i8 to i32 + %e = vector.insertelement %d, %cst[%pos : i32] : vector<3xi32> + return %e : vector<3xi32> +} + +// CHECK-LABEL: func.func @extui_over_insertelement_3xi16_cst_i16 +// CHECK-SAME: (%[[ARG:.+]]: i8, %[[POS:.+]]: i32) +// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[1, 256, 0]> : vector<3xi16> +// CHECK-NEXT: %[[SRCE:.+]] = arith.extui %[[ARG]] : i8 to i32 +// CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16 +// CHECK-NEXT: %[[INS:.+]] = vector.insertelement %[[SRCT]], %[[CST]][%[[POS]] : i32] : vector<3xi16> +// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32> +// CHECK-NEXT: return %[[RET]] : vector<3xi32> +func.func @extui_over_insertelement_3xi16_cst_i16(%a: i8, %pos: i32) -> vector<3xi32> { + %cst = arith.constant dense<[1, 256, 0]> : vector<3xi32> + %d = arith.extui %a : i8 to i32 + %e =
vector.insertelement %d, %cst[%pos : i32] : vector<3xi32> + return %e : vector<3xi32> +} + +// CHECK-LABEL: func.func @extsi_over_insert_strided_slice_1d +// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: vector<2xi16>) +// CHECK-NEXT: %[[INS:.+]] = vector.insert_strided_slice %[[ARG1]], %[[ARG0]] +// CHECK-SAME: {offsets = [1], strides = [1]} : vector<2xi16> into vector<3xi16> +// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32> +// CHECK-NEXT: return %[[RET]] : vector<3xi32> +func.func @extsi_over_insert_strided_slice_1d(%a: vector<3xi16>, %b: vector<2xi16>) -> vector<3xi32> { + %c = arith.extsi %a : vector<3xi16> to vector<3xi32> + %d = arith.extsi %b : vector<2xi16> to vector<2xi32> + %e = vector.insert_strided_slice %d, %c {offsets = [1], strides = [1]} : vector<2xi32> into vector<3xi32> + return %e : vector<3xi32> +} + +// CHECK-LABEL: func.func @extui_over_insert_strided_slice_1d +// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: vector<2xi16>) +// CHECK-NEXT: %[[INS:.+]] = vector.insert_strided_slice %[[ARG1]], %[[ARG0]] +// CHECK-SAME: {offsets = [1], strides = [1]} : vector<2xi16> into vector<3xi16> +// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32> +// CHECK-NEXT: return %[[RET]] : vector<3xi32> +func.func @extui_over_insert_strided_slice_1d(%a: vector<3xi16>, %b: vector<2xi16>) -> vector<3xi32> { + %c = arith.extui %a : vector<3xi16> to vector<3xi32> + %d = arith.extui %b : vector<2xi16> to vector<2xi32> + %e = vector.insert_strided_slice %d, %c {offsets = [1], strides = [1]} : vector<2xi32> into vector<3xi32> + return %e : vector<3xi32> +} + +// CHECK-LABEL: func.func @extsi_over_insert_strided_slice_cst_2d +// CHECK-SAME: (%[[ARG:.+]]: vector<1x2xi8>) +// CHECK-NEXT: %[[CST:.+]] = arith.constant +// CHECK-SAME{LITERAL}: dense<[[-1, 128, 0], [-129, 42, 1337]]> : vector<2x3xi16> +// CHECK-NEXT: %[[SRCE:.+]] = arith.extsi %[[ARG]] : vector<1x2xi8> to vector<1x2xi32> +//
CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : vector<1x2xi32> to vector<1x2xi16> +// CHECK-NEXT: %[[INS:.+]] = vector.insert_strided_slice %[[SRCT]], %[[CST]] +// CHECK-SAME: {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi16> into vector<2x3xi16> +// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<2x3xi16> to vector<2x3xi32> +// CHECK-NEXT: return %[[RET]] : vector<2x3xi32> +func.func @extsi_over_insert_strided_slice_cst_2d(%a: vector<1x2xi8>) -> vector<2x3xi32> { + %cst = arith.constant dense<[[-1, 128, 0], [-129, 42, 1337]]> : vector<2x3xi32> + %d = arith.extsi %a : vector<1x2xi8> to vector<1x2xi32> + %e = vector.insert_strided_slice %d, %cst {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi32> into vector<2x3xi32> + return %e : vector<2x3xi32> +} + +// CHECK-LABEL: func.func @extui_over_insert_strided_slice_cst_2d +// CHECK-SAME: (%[[ARG:.+]]: vector<1x2xi8>) +// CHECK-NEXT: %[[CST:.+]] = arith.constant +// CHECK-SAME{LITERAL}: dense<[[1, 128, 0], [256, 42, 1337]]> : vector<2x3xi16> +// CHECK-NEXT: %[[SRCE:.+]] = arith.extui %[[ARG]] : vector<1x2xi8> to vector<1x2xi32> +// CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : vector<1x2xi32> to vector<1x2xi16> +// CHECK-NEXT: %[[INS:.+]] = vector.insert_strided_slice %[[SRCT]], %[[CST]] +// CHECK-SAME: {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi16> into vector<2x3xi16> +// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<2x3xi16> to vector<2x3xi32> +// CHECK-NEXT: return %[[RET]] : vector<2x3xi32> +func.func @extui_over_insert_strided_slice_cst_2d(%a: vector<1x2xi8>) -> vector<2x3xi32> { + %cst = arith.constant dense<[[1, 128, 0], [256, 42, 1337]]> : vector<2x3xi32> + %d = arith.extui %a : vector<1x2xi8> to vector<1x2xi32> + %e = vector.insert_strided_slice %d, %cst {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi32> into vector<2x3xi32> + return %e : vector<2x3xi32> +}