diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -598,6 +598,102 @@
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Convert IndexCast ops
+//===----------------------------------------------------------------------===//
+
+/// Returns true iff the type is `index` or `vector<...index>`.
+static bool isIndexOrIndexVector(Type type) {
+  if (type.isa<IndexType>())
+    return true;
+
+  if (auto vectorTy = type.dyn_cast<VectorType>())
+    if (vectorTy.getElementType().isa<IndexType>())
+      return true;
+
+  return false;
+}
+
+/// Converts `CastOp` when its result is `index` or a vector of `index` and its
+/// input converts to a vector type. The slice at offset 0 of the input's last
+/// dimension is extracted, its trailing unit dimension is dropped, and a new
+/// `CastOp` to the original result type is created from it.
+///
+/// Fails to match when the result is not index-typed; reports a match failure
+/// when the converted input type is not a vector.
+template <typename CastOp>
+struct ConvertIndexCastIntToIndex final : OpConversionPattern<CastOp> {
+  using OpConversionPattern<CastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type resultType = op.getType();
+    if (!isIndexOrIndexVector(resultType))
+      return failure();
+
+    Location loc = op.getLoc();
+    Type inType = op.getIn().getType();
+    auto newInTy = this->getTypeConverter()
+                       ->convertType(inType)
+                       .template dyn_cast_or_null<VectorType>();
+    if (!newInTy)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", inType));
+
+    // Discard the high half of the input, truncating the original value.
+    Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
+    extracted = dropTrailingX1Dim(rewriter, loc, extracted);
+    rewriter.replaceOpWithNewOp<CastOp>(op, resultType, extracted);
+    return success();
+  }
+};
+
+/// Converts `CastOp` when its input is `index` or a vector of `index` and its
+/// result converts to a vector type. A `CastOp` to the integer type of width
+/// `getMaxTargetIntBitWidth()` (or a vector of it with the result's shape) is
+/// created first; `ExtensionOp` then widens that value to the original result
+/// type and is left for its own conversion pattern to legalize.
+///
+/// Fails to match when the input is not index-typed; reports a match failure
+/// when the converted result type is not a vector.
+template <typename CastOp, typename ExtensionOp>
+struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
+  using OpConversionPattern<CastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type inType = op.getIn().getType();
+    if (!isIndexOrIndexVector(inType))
+      return failure();
+
+    Location loc = op.getLoc();
+    auto *typeConverter =
+        this->template getTypeConverter<arith::WideIntEmulationConverter>();
+
+    Type resultType = op.getType();
+    auto newTy = typeConverter->convertType(resultType)
+                     .template dyn_cast_or_null<VectorType>();
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", resultType));
+
+    // Emit an index cast over the matching narrow type.
+    Type narrowTy =
+        rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth());
+    if (auto vecTy = resultType.dyn_cast<VectorType>())
+      narrowTy = VectorType::get(vecTy.getShape(), narrowTy);
+
+    // Sign or zero-extend the result. Let the matching conversion pattern
+    // legalize the extension op.
+    Value underlyingVal =
+        rewriter.create<CastOp>(loc, narrowTy, adaptor.getIn());
+    rewriter.replaceOpWithNewOp<ExtensionOp>(op, resultType, underlyingVal);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertSelect
 //===----------------------------------------------------------------------===//
@@ -841,8 +937,7 @@
     // Rewrite this as an bitwise or of `arith.shrui` and sign extension bits.
     // Perform as many ops over the narrow integer type as possible and let the
     // other emulation patterns convert the rest.
- Value elemZero = - createScalarOrSplatConstant(rewriter, loc, narrowTy, 0); + Value elemZero = createScalarOrSplatConstant(rewriter, loc, narrowTy, 0); Value signBit = rewriter.create( loc, arith::CmpIPredicate::slt, lhsElem1, elemZero); signBit = dropTrailingX1Dim(rewriter, loc, signBit); @@ -862,7 +941,8 @@ rewriter.create(loc, allSign, numNonSignExtBits); // Use original arguments to create the right shift. - Value shrui = rewriter.create(loc, op.getLhs(), op.getRhs()); + Value shrui = + rewriter.create(loc, op.getLhs(), op.getRhs()); Value shrsi = rewriter.create(loc, shrui, signBits); // Handle shifting by zero. This is necessary when the `signBits` shift is @@ -870,7 +950,8 @@ Value isNoop = rewriter.create(loc, arith::CmpIPredicate::eq, rhsElem0, elemZero); isNoop = dropTrailingX1Dim(rewriter, loc, isNoop); - rewriter.replaceOpWithNewOp(op, isNoop, op.getLhs(), shrsi); + rewriter.replaceOpWithNewOp(op, isNoop, op.getLhs(), + shrsi); return success(); } @@ -1045,6 +1126,11 @@ ConvertBitwiseBinary, ConvertBitwiseBinary, ConvertBitwiseBinary, // Extension and truncation ops. - ConvertExtSI, ConvertExtUI, ConvertTruncI>(typeConverter, - patterns.getContext()); + ConvertExtSI, ConvertExtUI, ConvertTruncI, + // Cast ops. 
+      ConvertIndexCastIntToIndex<arith::IndexCastOp>,
+      ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
+      ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
+      ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>>(
+      typeConverter, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -365,6 +365,106 @@
     return %r : vector<3xi64>
 }
 
+// Index casts from a wide integer to index: only element 0 of the emulated
+// value (the low half) is kept and cast.
+// CHECK-LABEL: func @index_cast_int_to_index_scalar
+// CHECK-SAME:    ([[ARG:%.+]]: vector<2xi32>) -> index
+// CHECK-NEXT:    [[EXT:%.+]] = vector.extract [[ARG]][0] : vector<2xi32>
+// CHECK-NEXT:    [[RES:%.+]] = arith.index_cast [[EXT]] : i32 to index
+// CHECK-NEXT:    return [[RES]] : index
+func.func @index_cast_int_to_index_scalar(%a : i64) -> index {
+    %r = arith.index_cast %a : i64 to index
+    return %r : index
+}
+
+// CHECK-LABEL: func @index_cast_int_to_index_vector
+// CHECK-SAME:    ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xindex>
+// CHECK-NEXT:    [[EXT:%.+]] = vector.extract_strided_slice [[ARG]] {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
+// CHECK-NEXT:    [[SHAPE:%.+]] = vector.shape_cast [[EXT]] : vector<3x1xi32> to vector<3xi32>
+// CHECK-NEXT:    [[RES:%.+]] = arith.index_cast [[SHAPE]] : vector<3xi32> to vector<3xindex>
+// CHECK-NEXT:    return [[RES]] : vector<3xindex>
+func.func @index_cast_int_to_index_vector(%a : vector<3xi64>) -> vector<3xindex> {
+    %r = arith.index_cast %a : vector<3xi64> to vector<3xindex>
+    return %r : vector<3xindex>
+}
+
+// CHECK-LABEL: func @index_castui_int_to_index_scalar
+// CHECK-SAME:    ([[ARG:%.+]]: vector<2xi32>) -> index
+// CHECK-NEXT:    [[EXT:%.+]] = vector.extract [[ARG]][0] : vector<2xi32>
+// CHECK-NEXT:    [[RES:%.+]] = arith.index_castui [[EXT]] : i32 to index
+// CHECK-NEXT:    return [[RES]] : index
+func.func @index_castui_int_to_index_scalar(%a : i64) -> index {
+    %r = arith.index_castui %a : i64 to index
+    return %r : index
+}
+
+// CHECK-LABEL: func @index_castui_int_to_index_vector
+// CHECK-SAME:    ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xindex>
+// CHECK-NEXT:    [[EXT:%.+]] = vector.extract_strided_slice [[ARG]] {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
+// CHECK-NEXT:    [[SHAPE:%.+]] = vector.shape_cast [[EXT]] : vector<3x1xi32> to vector<3xi32>
+// CHECK-NEXT:    [[RES:%.+]] = arith.index_castui [[SHAPE]] : vector<3xi32> to vector<3xindex>
+// CHECK-NEXT:    return [[RES]] : vector<3xindex>
+func.func @index_castui_int_to_index_vector(%a : vector<3xi64>) -> vector<3xindex> {
+    %r = arith.index_castui %a : vector<3xi64> to vector<3xindex>
+    return %r : vector<3xindex>
+}
+
+// Index casts from index to a wide integer: the index is cast to i32, then
+// sign-extended (index_cast) or zero-extended (index_castui) to the wide value.
+// CHECK-LABEL: func @index_cast_index_to_int_scalar
+// CHECK-SAME:    ([[ARG:%.+]]: index) -> vector<2xi32>
+// CHECK-NEXT:    [[CAST:%.+]] = arith.index_cast [[ARG]] : index to i32
+// CHECK-NEXT:    [[C0I32:%.+]] = arith.constant 0 : i32
+// CHECK-NEXT:    [[NEG:%.+]] = arith.cmpi slt, [[CAST]], [[C0I32]] : i32
+// CHECK-NEXT:    [[EXT:%.+]] = arith.extsi [[NEG]] : i1 to i32
+// CHECK-NEXT:    [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT:    [[INS0:%.+]] = vector.insert [[CAST]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK-NEXT:    [[INS1:%.+]] = vector.insert [[EXT]], [[INS0]] [1] : i32 into vector<2xi32>
+// CHECK-NEXT:    return [[INS1]] : vector<2xi32>
+func.func @index_cast_index_to_int_scalar(%a : index) -> i64 {
+    %r = arith.index_cast %a : index to i64
+    return %r : i64
+}
+
+// CHECK-LABEL: func @index_cast_index_to_int_vector
+// CHECK-SAME:    ([[ARG:%.+]]: vector<3xindex>) -> vector<3x2xi32>
+// CHECK-NEXT:    arith.index_cast [[ARG]] : vector<3xindex> to vector<3xi32>
+// CHECK-NEXT:    vector.shape_cast
+// CHECK-NEXT:    arith.constant dense<0> : vector<3x1xi32>
+// CHECK-NEXT:    arith.cmpi slt
+// CHECK-NEXT:    arith.extsi
+// CHECK-NEXT:    arith.constant dense<0> : vector<3x2xi32>
+// CHECK-NEXT:    vector.insert_strided_slice
+// CHECK-NEXT:    vector.insert_strided_slice
+// CHECK-NEXT:    return {{%.+}} : vector<3x2xi32>
+func.func @index_cast_index_to_int_vector(%a : vector<3xindex>) -> vector<3xi64> {
+    %r = arith.index_cast %a : vector<3xindex> to vector<3xi64>
+    return %r : vector<3xi64>
+}
+
+// CHECK-LABEL: func @index_castui_index_to_int_scalar
+// CHECK-SAME:    ([[ARG:%.+]]: index) -> vector<2xi32>
+// CHECK-NEXT:    [[CAST:%.+]] = arith.index_castui [[ARG]] : index to i32
+// CHECK-NEXT:    [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK-NEXT:    [[RES:%.+]] = vector.insert [[CAST]], [[VZ]] [0] : i32 into vector<2xi32>
+// CHECK-NEXT:    return [[RES]] : vector<2xi32>
+func.func @index_castui_index_to_int_scalar(%a : index) -> i64 {
+    %r = arith.index_castui %a : index to i64
+    return %r : i64
+}
+
+// CHECK-LABEL: func @index_castui_index_to_int_vector
+// CHECK-SAME:    ([[ARG:%.+]]: vector<3xindex>) -> vector<3x2xi32>
+// CHECK-NEXT:    [[CAST:%.+]] = arith.index_castui [[ARG]] : vector<3xindex> to vector<3xi32>
+// CHECK-NEXT:    [[SHAPE:%.+]] = vector.shape_cast [[CAST]] : vector<3xi32> to vector<3x1xi32>
+// CHECK-NEXT:    [[CST:%.+]] = arith.constant dense<0> : vector<3x2xi32>
+// CHECK-NEXT:    [[RES:%.+]] = vector.insert_strided_slice [[SHAPE]], [[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<3x1xi32> into vector<3x2xi32>
+// CHECK-NEXT:    return [[RES]] : vector<3x2xi32>
+func.func @index_castui_index_to_int_vector(%a : vector<3xindex>) -> vector<3xi64> {
+    %r = arith.index_castui %a : vector<3xindex> to vector<3xi64>
+    return %r : vector<3xi64>
+}
+
 // CHECK-LABEL: func @trunci_scalar1
 // CHECK-SAME:    ([[ARG:%.+]]: vector<2xi32>) -> i32
 // CHECK-NEXT:    [[EXT:%.+]] = vector.extract [[ARG]][0] : vector<2xi32>