diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -70,7 +70,7 @@
     if (!isa<IntegerType>(elemTy))
       return failure();
 
-    auto newElemTy = IntegerType::get(origTy.getContext(), bitsRequired);
+    auto newElemTy = IntegerType::get(origTy.getContext(), *bestBitwidth);
     if (newElemTy == elemTy)
       return failure();
 
@@ -100,11 +100,58 @@
 enum class ExtensionKind { Sign, Zero };
 
+ExtensionKind getExtensionKind(Operation *op) {
+  assert(op);
+  assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && "Not an extension op");
+  return isa<arith::ExtSIOp>(op) ? ExtensionKind::Sign : ExtensionKind::Zero;
+}
+
+/// Returns the integer bitwidth required to represent `value`.
+unsigned calculateBitsRequired(const APInt &value,
+                               ExtensionKind lookThroughExtension) {
+  // For unsigned values, we only need the active bits. As a special case, zero
+  // requires one bit.
+  if (lookThroughExtension == ExtensionKind::Zero)
+    return std::max(value.getActiveBits(), 1u);
+
+  // If a signed value is nonnegative, we need one extra bit for the sign.
+  if (value.isNonNegative())
+    return value.getActiveBits() + 1;
+
+  // For the signed min, we need all the bits.
+  if (value.isMinSignedValue())
+    return value.getBitWidth();
+
+  // For negative values, we need all the non-sign bits and one extra bit for
+  // the sign.
+  return value.getBitWidth() - value.getNumSignBits() + 1;
+}
+
 /// Returns the integer bitwidth required to represent `value`.
 /// Looks through either sign- or zero-extension as specified by
 /// `lookThroughExtension`.
 FailureOr<unsigned> calculateBitsRequired(Value value,
                                           ExtensionKind lookThroughExtension) {
+  // Handle constants.
+  if (TypedAttr attr; matchPattern(value, m_Constant(&attr))) {
+    if (auto intAttr = dyn_cast<IntegerAttr>(attr))
+      return calculateBitsRequired(intAttr.getValue(), lookThroughExtension);
+
+    if (auto elemsAttr = dyn_cast<DenseElementsAttr>(attr)) {
+      if (elemsAttr.getElementType().isIntOrIndex()) {
+        if (elemsAttr.isSplat())
+          return calculateBitsRequired(elemsAttr.getSplatValue<APInt>(),
+                                       lookThroughExtension);
+
+        unsigned maxBits = 1;
+        for (const APInt &elemValue : elemsAttr.getValues<APInt>())
+          maxBits = std::max(
+              maxBits, calculateBitsRequired(elemValue, lookThroughExtension));
+        return maxBits;
+      }
+    }
+  }
+
   if (lookThroughExtension == ExtensionKind::Sign) {
     if (auto sext = value.getDefiningOp<arith::ExtSIOp>())
       return calculateBitsRequired(sext.getIn().getType());
@@ -150,8 +197,8 @@
 // Patterns to Commute Extension Ops
 //===----------------------------------------------------------------------===//
 
-struct ExtensionOverExtract final : OpRewritePattern<vector::ExtractOp> {
-  using OpRewritePattern::OpRewritePattern;
+struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
+  using NarrowingPattern::NarrowingPattern;
 
   LogicalResult matchAndRewrite(vector::ExtractOp op,
                                 PatternRewriter &rewriter) const override {
@@ -172,8 +219,8 @@
 };
 
 struct ExtensionOverExtractElement final
-    : OpRewritePattern<vector::ExtractElementOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : NarrowingPattern<vector::ExtractElementOp> {
+  using NarrowingPattern::NarrowingPattern;
 
   LogicalResult matchAndRewrite(vector::ExtractElementOp op,
                                 PatternRewriter &rewriter) const override {
@@ -194,8 +241,8 @@
 };
 
 struct ExtensionOverExtractStridedSlice final
-    : OpRewritePattern<vector::ExtractStridedSliceOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : NarrowingPattern<vector::ExtractStridedSliceOp> {
+  using NarrowingPattern::NarrowingPattern;
 
   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
                                 PatternRewriter &rewriter) const override {
@@ -220,6 +267,72 @@
   }
 };
 
+struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
+  using NarrowingPattern::NarrowingPattern;
+
+  LogicalResult matchAndRewrite(vector::InsertOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *def = op.getSource().getDefiningOp();
+    if (!def)
+      return failure();
+
+    return TypeSwitch<Operation *, LogicalResult>(def)
+        .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
+          FailureOr<vector::InsertOp> newInsert =
+              createNarrowInsert(op, rewriter, extOp);
+          if (failed(newInsert))
+            return failure();
+          rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
+                                                       *newInsert);
+          return success();
+        })
+        .Default(failure());
+  }
+
+  FailureOr<vector::InsertOp> createNarrowInsert(vector::InsertOp op,
+                                                 PatternRewriter &rewriter,
+                                                 Operation *insValue) const {
+    assert((isa<arith::ExtSIOp, arith::ExtUIOp>(insValue)));
+
+    FailureOr<unsigned> origBitsRequired = calculateBitsRequired(op.getType());
+    if (failed(origBitsRequired))
+      return failure();
+
+    ExtensionKind kind = getExtensionKind(insValue);
+    FailureOr<unsigned> destBitsRequired =
+        calculateBitsRequired(op.getDest(), kind);
+    if (failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired)
+      return failure();
+
+    FailureOr<unsigned> insertedBitsRequired =
+        calculateBitsRequired(insValue->getOperands().front(), kind);
+    if (failed(insertedBitsRequired) ||
+        *insertedBitsRequired >= *origBitsRequired)
+      return failure();
+
+    // Find a narrower element type that satisfies the bitwidth requirements of
+    // both the source and the destination values.
+    unsigned newInsertionBits =
+        std::max(*destBitsRequired, *insertedBitsRequired);
+    FailureOr<Type> newVecTy = getNarrowType(newInsertionBits, op.getType());
+    if (failed(newVecTy) || *newVecTy == op.getType())
+      return failure();
+
+    FailureOr<Type> newInsertedValueTy =
+        getNarrowType(newInsertionBits, insValue->getResultTypes().front());
+    if (failed(newInsertedValueTy))
+      return failure();
+
+    Location loc = op.getLoc();
+    Value narrowValue = rewriter.createOrFold<arith::TruncIOp>(
+        loc, *newInsertedValueTy, insValue->getResult(0));
+    Value narrowDest =
+        rewriter.createOrFold<arith::TruncIOp>(loc, *newVecTy, op.getDest());
+    return rewriter.create<vector::InsertOp>(loc, narrowValue, narrowDest,
+                                             op.getPosition());
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Pass Definitions
 //===----------------------------------------------------------------------===//
@@ -249,8 +362,8 @@
   // Add commute patterns with a higher benefit. This is to expose more
   // optimization opportunities to narrowing patterns.
   patterns.add<ExtensionOverExtract, ExtensionOverExtractElement,
-               ExtensionOverExtractStridedSlice>(patterns.getContext(),
-                                                 PatternBenefit(2));
+               ExtensionOverExtractStridedSlice, ExtensionOverInsert>(
+      patterns.getContext(), options, PatternBenefit(2));
   patterns.add(patterns.getContext(), options);
 }
diff --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir
--- a/mlir/test/Dialect/Arith/int-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-narrowing.mlir
@@ -235,3 +235,96 @@
     {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi32> to vector<1x2xi32>
   return %c : vector<1x2xi32>
 }
+
+// CHECK-LABEL: func.func @extsi_over_insert_3xi16
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16)
+// CHECK-NEXT: %[[INS:.+]] = vector.insert %[[ARG1]], %[[ARG0]] [1] : i16 into vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
+func.func @extsi_over_insert_3xi16(%a: vector<3xi16>, %b: i16) -> vector<3xi32> {
+  %c = arith.extsi %a : vector<3xi16> to vector<3xi32>
+  %d = arith.extsi %b : i16 to i32
+  %e = vector.insert %d, %c [1] : i32 into vector<3xi32>
+  return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_insert_3xi16
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16)
+// CHECK-NEXT: %[[INS:.+]] = vector.insert %[[ARG1]], %[[ARG0]] [1] : i16 into vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
+func.func @extui_over_insert_3xi16(%a: vector<3xi16>, %b: i16) -> vector<3xi32> {
+  %c = arith.extui %a : vector<3xi16> to vector<3xi32>
+  %d = arith.extui %b : i16 to i32
+  %e = vector.insert %d, %c [1] : i32 into vector<3xi32>
+  return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extsi_over_insert_3xi16_cst_0
+// CHECK-SAME: (%[[ARG:.+]]: i16)
+// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<0> : vector<3xi16>
+// CHECK-NEXT: %[[INS:.+]] = vector.insert %[[ARG]], %[[CST]] [1] : i16 into vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
+func.func @extsi_over_insert_3xi16_cst_0(%a: i16) -> vector<3xi32> {
+  %cst = arith.constant dense<0> : vector<3xi32>
+  %d = arith.extsi %a : i16 to i32
+  %e = vector.insert %d, %cst [1] : i32 into vector<3xi32>
+  return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extsi_over_insert_3xi8_cst
+// CHECK-SAME: (%[[ARG:.+]]: i8)
+// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[-1, 127, -128]> : vector<3xi8>
+// CHECK-NEXT: %[[INS:.+]] = vector.insert %[[ARG]], %[[CST]] [1] : i8 into vector<3xi8>
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi8> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
+func.func @extsi_over_insert_3xi8_cst(%a: i8) -> vector<3xi32> {
+  %cst = arith.constant dense<[-1, 127, -128]> : vector<3xi32>
+  %d = arith.extsi %a : i8 to i32
+  %e = vector.insert %d, %cst [1] : i32 into vector<3xi32>
+  return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_insert_3xi8_cst
+// CHECK-SAME: (%[[ARG:.+]]: i8)
+// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[1, 127, -1]> : vector<3xi8>
+// CHECK-NEXT: %[[INS:.+]] = vector.insert %[[ARG]], %[[CST]] [1] : i8 into vector<3xi8>
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi8> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
+func.func @extui_over_insert_3xi8_cst(%a: i8) -> vector<3xi32> {
+  %cst = arith.constant dense<[1, 127, 255]> : vector<3xi32>
+  %d = arith.extui %a : i8 to i32
+  %e = vector.insert %d, %cst [1] : i32 into vector<3xi32>
+  return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extsi_over_insert_3xi16_cst_i16
+// CHECK-SAME: (%[[ARG:.+]]: i8)
+// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[-1, 128, 0]> : vector<3xi16>
+// CHECK-NEXT: %[[SRCE:.+]] = arith.extsi %[[ARG]] : i8 to i32
+// CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16
+// CHECK-NEXT: %[[INS:.+]] = vector.insert %[[SRCT]], %[[CST]] [1] : i16 into vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
+func.func @extsi_over_insert_3xi16_cst_i16(%a: i8) -> vector<3xi32> {
+  %cst = arith.constant dense<[-1, 128, 0]> : vector<3xi32>
+  %d = arith.extsi %a : i8 to i32
+  %e = vector.insert %d, %cst [1] : i32 into vector<3xi32>
+  return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_insert_3xi16_cst_i16
+// CHECK-SAME: (%[[ARG:.+]]: i8)
+// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[1, 256, 0]> : vector<3xi16>
+// CHECK-NEXT: %[[SRCE:.+]] = arith.extui %[[ARG]] : i8 to i32
+// CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16
+// CHECK-NEXT: %[[INS:.+]] = vector.insert %[[SRCT]], %[[CST]] [1] : i16 into vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
func.func @extui_over_insert_3xi16_cst_i16(%a: i8) -> vector<3xi32> {
+  %cst = arith.constant dense<[1, 256, 0]> : vector<3xi32>
+  %d = arith.extui %a : i8 to i32
+  %e = vector.insert %d, %cst [1] : i32 into vector<3xi32>
+  return %e : vector<3xi32>
+}
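
Not part of the patch: an illustrative before/after of the rewrite performed by the new ExtensionOverInsert pattern, distilled from the @extsi_over_insert_3xi16 test above (the function name @example is only for illustration).

// Before: both operands are widened first, so the insert happens on vector<3xi32>.
func.func @example(%a: vector<3xi16>, %b: i16) -> vector<3xi32> {
  %c = arith.extsi %a : vector<3xi16> to vector<3xi32>
  %d = arith.extsi %b : i16 to i32
  %e = vector.insert %d, %c [1] : i32 into vector<3xi32>
  return %e : vector<3xi32>
}

// After commuting the extension over the insert: the insert stays at i16 and a
// single extsi widens the result, exposing further narrowing opportunities.
func.func @example(%a: vector<3xi16>, %b: i16) -> vector<3xi32> {
  %ins = vector.insert %b, %a [1] : i16 into vector<3xi16>
  %ret = arith.extsi %ins : vector<3xi16> to vector<3xi32>
  return %ret : vector<3xi32>
}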