diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp @@ -61,7 +61,7 @@ static Value extractLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value input, int64_t lastOffset) { - llvm::ArrayRef shape = input.getType().cast().getShape(); + ArrayRef shape = input.getType().cast().getShape(); assert(lastOffset < shape.back() && "Offset out of bounds"); // Scalarize the result in case of 1D vectors. @@ -87,13 +87,45 @@ extractLastDimSlice(rewriter, loc, input, 1)}; } +// Performs a vector shape cast to drop the trailing x1 dimension. If the +// `input` is a scalar, this is a noop. +static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter, + Location loc, Value input) { + auto vecTy = input.getType().dyn_cast(); + if (!vecTy) + return input; + + // Shape cast to drop the last x1 dimension. + ArrayRef shape = vecTy.getShape(); + assert(shape.size() >= 2 && "Expected vector with at list two dims"); + assert(shape.back() == 1 && "Expected the last vector dim to be x1"); + + auto newVecTy = VectorType::get(shape.drop_back(), vecTy.getElementType()); + return rewriter.create(loc, newVecTy, input); +} + +// Performs a vector shape cast to append an x1 dimension. If the +// `input` is a scalar, this is a noop. +static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc, + Value input) { + auto vecTy = input.getType().dyn_cast(); + if (!vecTy) + return input; + + // Add a trailing x1 dim. + auto newShape = llvm::to_vector(vecTy.getShape()); + newShape.push_back(1); + auto newTy = VectorType::get(newShape, vecTy.getElementType()); + return rewriter.create(loc, newTy, input); +} + // Inserts the `source` vector slice into the `dest` vector at offset // `lastOffset` in the last dimension. 
`source` can be a scalar when `dest` is a // 1D vector. static Value insertLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value source, Value dest, int64_t lastOffset) { - llvm::ArrayRef shape = dest.getType().cast().getShape(); + ArrayRef shape = dest.getType().cast().getShape(); assert(lastOffset < shape.back() && "Offset out of bounds"); // Handle scalar source. @@ -116,7 +148,7 @@ static Value constructResultVector(ConversionPatternRewriter &rewriter, Location loc, VectorType resultType, ValueRange resultComponents) { - llvm::ArrayRef resultShape = resultType.getShape(); + ArrayRef resultShape = resultType.getShape(); assert(!resultShape.empty() && "Result expected to have dimentions"); assert(resultShape.back() == static_cast(resultComponents.size()) && "Wrong number of result components"); @@ -227,6 +259,108 @@ } }; +//===----------------------------------------------------------------------===// +// ConvertExtSI +//===----------------------------------------------------------------------===// + +struct ConvertExtSI final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + + // Check if the input type is legal for this target. + TypeConverter &typeConverter = *getTypeConverter(); + if (!typeConverter.isLegal(op.getIn().getType())) + return rewriter.notifyMatchFailure(loc, + "unsupported extension input type"); + + auto newTy = typeConverter.convertType(op.getType()).cast(); + Type newResultComponentTy = reduceInnermostDim(newTy); + + // Sign-extend the input value to determine the low half of the result. + // Then, check if the low half is negative, and sign-extend the comparison + // result to get the high half. 
+ Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn()); + Value extended = rewriter.createOrFold( + loc, newResultComponentTy, newOperand); + Value operandZeroCst = rewriter.create( + loc, rewriter.getZeroAttr(newResultComponentTy)); + Value signBit = rewriter.create( + loc, arith::CmpIPredicate::slt, extended, operandZeroCst); + Value signValue = + rewriter.create(loc, newResultComponentTy, signBit); + + Value resultVec = + constructResultVector(rewriter, loc, newTy, {extended, signValue}); + rewriter.replaceOp(op, resultVec); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ConvertExtUI +//===----------------------------------------------------------------------===// + +struct ConvertExtUI final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + + // Check if the input type is legal for this target. + TypeConverter &typeConverter = *getTypeConverter(); + if (!typeConverter.isLegal(op.getIn().getType())) + return rewriter.notifyMatchFailure(loc, + "unsupported extension input type"); + + auto newTy = typeConverter.convertType(op.getType()).cast(); + Type newResultComponentTy = reduceInnermostDim(newTy); + + // Zero-extend the input value to determine the low half of the result. + // The high half is always zero. 
+ Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn()); + Value extended = rewriter.createOrFold( + loc, newResultComponentTy, newOperand); + Value zeroCst = rewriter.create( + op->getLoc(), rewriter.getZeroAttr(newTy)); + Value newRes = insertLastDimSlice(rewriter, loc, extended, zeroCst, 0); + rewriter.replaceOp(op, newRes); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ConvertTruncI +//===----------------------------------------------------------------------===// + +struct ConvertTruncI final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + // Check if the result type is legal for this target. Currently, we do not + // support truncation to types wider than supported by the target. + if (!getTypeConverter()->isLegal(op.getType())) + return rewriter.notifyMatchFailure(loc, + "unsupported truncation result type"); + + // Discard the high half of the input. Truncate the low half, if necessary. + Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0); + extracted = dropTrailingX1Dim(rewriter, loc, extracted); + Value truncated = + rewriter.createOrFold(loc, op.getType(), extracted); + rewriter.replaceOp(op, truncated); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// @@ -334,6 +468,12 @@ populateReturnOpTypeConversionPattern(patterns, typeConverter); // Populate `arith.*` conversion patterns. - patterns.add(typeConverter, - patterns.getContext()); + patterns.add< + // Misc ops. + ConvertConstant, + // Binary ops. + ConvertAddI, + // Extension and truncation ops. 
+ ConvertExtSI, ConvertExtUI, ConvertTruncI>(typeConverter, + patterns.getContext()); } diff --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir --- a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir +++ b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir @@ -111,3 +111,97 @@ %x = arith.addi %a, %b : vector<4xi64> return %x : vector<4xi64> } + +// CHECK-LABEL: func @extsi_scalar +// CHECK-SAME: ([[ARG:%.+]]: i16) -> vector<2xi32> +// CHECK-NEXT: [[EXT:%.+]] = arith.extsi [[ARG]] : i16 to i32 +// CHECK-NEXT: [[SZ:%.+]] = arith.constant 0 : i32 +// CHECK-NEXT: [[SB:%.+]] = arith.cmpi slt, [[EXT]], [[SZ]] : i32 +// CHECK-NEXT: [[SV:%.+]] = arith.extsi [[SB]] : i1 to i32 +// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32> +// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[EXT]], [[VZ]] [0] : i32 into vector<2xi32> +// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[SV]], [[INS0]] [1] : i32 into vector<2xi32> +// CHECK: return [[INS1]] : vector<2xi32> +func.func @extsi_scalar(%a : i16) -> i64 { + %r = arith.extsi %a : i16 to i64 + return %r : i64 +} + +// CHECK-LABEL: func @extsi_vector +// CHECK-SAME: ([[ARG:%.+]]: vector<3xi16>) -> vector<3x2xi32> +// CHECK-NEXT: [[SHAPE:%.+]] = vector.shape_cast [[ARG]] : vector<3xi16> to vector<3x1xi16> +// CHECK-NEXT: [[EXT:%.+]] = arith.extsi [[SHAPE]] : vector<3x1xi16> to vector<3x1xi32> +// CHECK-NEXT: [[CSTE:%.+]] = arith.constant dense<0> : vector<3x1xi32> +// CHECK-NEXT: [[CMP:%.+]] = arith.cmpi slt, [[EXT]], [[CSTE]] : vector<3x1xi32> +// CHECK-NEXT: [[HIGH:%.+]] = arith.extsi [[CMP]] : vector<3x1xi1> to vector<3x1xi32> +// CHECK-NEXT: [[CSTZ:%.+]] = arith.constant dense<0> : vector<3x2xi32> +// CHECK-NEXT: [[INS0:%.+]] = vector.insert_strided_slice [[EXT]], [[CSTZ]] {offsets = [0, 0], strides = [1, 1]} : vector<3x1xi32> into vector<3x2xi32> +// CHECK-NEXT: [[INS1:%.+]] = vector.insert_strided_slice [[HIGH]], [[INS0]] {offsets = [0, 1], strides = [1, 1]} 
: vector<3x1xi32> into vector<3x2xi32> +// CHECK-NEXT: return [[INS1]] : vector<3x2xi32> +func.func @extsi_vector(%a : vector<3xi16>) -> vector<3xi64> { + %r = arith.extsi %a : vector<3xi16> to vector<3xi64> + return %r : vector<3xi64> +} + +// CHECK-LABEL: func @extui_scalar1 +// CHECK-SAME: ([[ARG:%.+]]: i16) -> vector<2xi32> +// CHECK-NEXT: [[EXT:%.+]] = arith.extui [[ARG]] : i16 to i32 +// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32> +// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[EXT]], [[VZ]] [0] : i32 into vector<2xi32> +// CHECK: return [[INS0]] : vector<2xi32> +func.func @extui_scalar1(%a : i16) -> i64 { + %r = arith.extui %a : i16 to i64 + return %r : i64 +} + +// CHECK-LABEL: func @extui_scalar2 +// CHECK-SAME: ([[ARG:%.+]]: i32) -> vector<2xi32> +// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32> +// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[ARG]], [[VZ]] [0] : i32 into vector<2xi32> +// CHECK: return [[INS0]] : vector<2xi32> +func.func @extui_scalar2(%a : i32) -> i64 { + %r = arith.extui %a : i32 to i64 + return %r : i64 +} + +// CHECK-LABEL: func @extui_vector +// CHECK-SAME: ([[ARG:%.+]]: vector<3xi16>) -> vector<3x2xi32> +// CHECK-NEXT: [[SHAPE:%.+]] = vector.shape_cast [[ARG]] : vector<3xi16> to vector<3x1xi16> +// CHECK-NEXT: [[EXT:%.+]] = arith.extui [[SHAPE]] : vector<3x1xi16> to vector<3x1xi32> +// CHECK-NEXT: [[CST:%.+]] = arith.constant dense<0> : vector<3x2xi32> +// CHECK-NEXT: [[INS0:%.+]] = vector.insert_strided_slice [[EXT]], [[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<3x1xi32> into vector<3x2xi32> +// CHECK: return [[INS0]] : vector<3x2xi32> +func.func @extui_vector(%a : vector<3xi16>) -> vector<3xi64> { + %r = arith.extui %a : vector<3xi16> to vector<3xi64> + return %r : vector<3xi64> +} + +// CHECK-LABEL: func @trunci_scalar1 +// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> i32 +// CHECK-NEXT: [[EXT:%.+]] = vector.extract [[ARG]][0] : vector<2xi32> +// CHECK-NEXT: return [[EXT]] : i32 
+func.func @trunci_scalar1(%a : i64) -> i32 { + %b = arith.trunci %a : i64 to i32 + return %b : i32 +} + +// CHECK-LABEL: func @trunci_scalar2 +// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> i16 +// CHECK-NEXT: [[EXTR:%.+]] = vector.extract [[ARG]][0] : vector<2xi32> +// CHECK-NEXT: [[TRNC:%.+]] = arith.trunci [[EXTR]] : i32 to i16 +// CHECK-NEXT: return [[TRNC]] : i16 +func.func @trunci_scalar2(%a : i64) -> i16 { + %b = arith.trunci %a : i64 to i16 + return %b : i16 +} + +// CHECK-LABEL: func @trunci_vector +// CHECK-SAME: ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xi16> +// CHECK-NEXT: [[EXTR:%.+]] = vector.extract_strided_slice [[ARG]] {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32> +// CHECK-NEXT: [[SHAPE:%.+]] = vector.shape_cast [[EXTR]] : vector<3x1xi32> to vector<3xi32> +// CHECK-NEXT: [[TRNC:%.+]] = arith.trunci [[SHAPE]] : vector<3xi32> to vector<3xi16> +// CHECK-NEXT: return [[TRNC]] : vector<3xi16> +func.func @trunci_vector(%a : vector<3xi64>) -> vector<3xi16> { + %b = arith.trunci %a : vector<3xi64> to vector<3xi16> + return %b : vector<3xi16> +}