diff --git a/mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h b/mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/IndexToLLVM/IndexToLLVM.h
@@ -0,0 +1,30 @@
+//===- IndexToLLVM.h - Index to LLVM dialect conversion ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_INDEXTOLLVM_INDEXTOLLVM_H
+#define MLIR_CONVERSION_INDEXTOLLVM_INDEXTOLLVM_H
+
+#include <memory>
+
+namespace mlir {
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_INDEXTOLLVMCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace index {
+/// Populate `patterns` with the patterns that lower Index dialect operations
+/// to LLVM dialect operations, using `converter` for type conversion.
+void populateIndexToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                           RewritePatternSet &patterns);
+} // namespace index
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_INDEXTOLLVM_INDEXTOLLVM_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -29,6 +29,7 @@
 #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
+#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
 #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -439,6 +439,29 @@
   let dependentDialects = ["LLVM::LLVMDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// IndexToLLVMConversionPass
+//===----------------------------------------------------------------------===//
+
+def IndexToLLVMConversionPass : Pass<"convert-index-to-llvm"> {
+  let summary = "Lower the `index` dialect to the `llvm` dialect.";
+  let description = [{
+    This pass lowers Index dialect operations to LLVM dialect operations.
+    Operation conversions are 1-to-1 except for the exotic divides: `ceildivs`,
+    `ceildivu`, and `floordivs`, which expand to a series of LLVM operations.
+    Importantly, the index bitwidth should be correctly set to the target
+    pointer width via `index-bitwidth`.
+  }];
+
+  let dependentDialects = ["::mlir::LLVM::LLVMDialect"];
+
+  let options = [
+    Option<"indexBitwidth", "index-bitwidth", "unsigned",
+           /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
+           "Bitwidth of the index type, 0 to use size of machine word">,
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -17,6 +17,7 @@
 add_subdirectory(GPUToROCDL)
 add_subdirectory(GPUToSPIRV)
 add_subdirectory(GPUToVulkan)
+add_subdirectory(IndexToLLVM)
 add_subdirectory(LinalgToLLVM)
 add_subdirectory(LinalgToSPIRV)
 add_subdirectory(LinalgToStandard)
diff --git a/mlir/lib/Conversion/IndexToLLVM/CMakeLists.txt b/mlir/lib/Conversion/IndexToLLVM/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/IndexToLLVM/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_conversion_library(MLIRIndexToLLVM
+  IndexToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/IndexToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIndexDialect
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  )
diff --git a/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp
@@ -0,0 +1,343 @@
+//===- IndexToLLVM.cpp - Index to LLVM dialect conversion -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Index/IR/IndexAttrs.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace index;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCeilDivS
+//===----------------------------------------------------------------------===//
+
+/// Convert `ceildivs(n, m)` into `z = n / m` (truncating) and then
+/// `z * m != n && (n < 0) == (m < 0) ? z + 1 : z`.
+///
+/// Both operands of an `llvm.select` are always evaluated, so every `sdiv`
+/// emitted here must be well-defined on its own. Only `n / m` is computed and,
+/// apart from `m == 0`, it overflows only when the result is unrepresentable.
+struct ConvertIndexCeilDivS : mlir::ConvertOpToLLVMPattern<CeilDivSOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value n = adaptor.getLhs();
+    Value m = adaptor.getRhs();
+    Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
+    Value one = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
+
+    // Compute the truncated quotient and check whether it is exact.
+    Value quotient = rewriter.create<LLVM::SDivOp>(loc, n, m);
+    Value product = rewriter.create<LLVM::MulOp>(loc, quotient, m);
+    Value notExact =
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, product, n);
+
+    // The exact quotient is positive iff `n` and `m` have the same sign. A
+    // zero `n` always divides exactly, so its sign never matters.
+    Value nNeg =
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, n, zero);
+    Value mNeg =
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, m, zero);
+    Value sameSign =
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, nNeg, mNeg);
+
+    // Truncation rounded a positive inexact quotient down; round it up.
+    Value cmp = rewriter.create<LLVM::AndOp>(loc, notExact, sameSign);
+    Value plusOne = rewriter.create<LLVM::AddOp>(loc, quotient, one);
+    rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, plusOne, quotient);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCeilDivU
+//===----------------------------------------------------------------------===//
+
+/// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`.
+struct ConvertIndexCeilDivU : mlir::ConvertOpToLLVMPattern<CeilDivUOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value n = adaptor.getLhs();
+    Value m = adaptor.getRhs();
+    Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
+    Value one = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
+
+    // Compute the non-zero result.
+    Value minusOne = rewriter.create<LLVM::SubOp>(loc, n, one);
+    Value quotient = rewriter.create<LLVM::UDivOp>(loc, minusOne, m);
+    Value plusOne = rewriter.create<LLVM::AddOp>(loc, quotient, one);
+
+    // Pick the result.
+    Value cmp =
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, n, zero);
+    rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, zero, plusOne);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexFloorDivS
+//===----------------------------------------------------------------------===//
+
+/// Convert `floordivs(n, m)` into `z = n / m` (truncating) and then
+/// `z * m != n && (n < 0) != (m < 0) ? z - 1 : z`.
+///
+/// Both operands of an `llvm.select` are always evaluated, so every `sdiv`
+/// emitted here must be well-defined on its own. Only `n / m` is computed and,
+/// apart from `m == 0`, it overflows only when the result is unrepresentable.
+struct ConvertIndexFloorDivS : mlir::ConvertOpToLLVMPattern<FloorDivSOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value n = adaptor.getLhs();
+    Value m = adaptor.getRhs();
+    Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
+    Value one = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
+
+    // Compute the truncated quotient and check whether it is exact.
+    Value quotient = rewriter.create<LLVM::SDivOp>(loc, n, m);
+    Value product = rewriter.create<LLVM::MulOp>(loc, quotient, m);
+    Value notExact =
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, product, n);
+
+    // The exact quotient is negative iff `n` and `m` have different signs. A
+    // zero `n` always divides exactly, so its sign never matters.
+    Value nNeg =
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, n, zero);
+    Value mNeg =
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, m, zero);
+    Value diffSign =
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, nNeg, mNeg);
+
+    // Truncation rounded a negative inexact quotient up; round it down.
+    Value cmp = rewriter.create<LLVM::AndOp>(loc, notExact, diffSign);
+    Value minusOne = rewriter.create<LLVM::SubOp>(loc, quotient, one);
+    rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, minusOne, quotient);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCast
+//===----------------------------------------------------------------------===//
+
+/// Convert a cast op. If the materialized index type is the same as the other
+/// type, fold away the op. Otherwise, truncate or extend the op as appropriate.
+/// Signed casts sign extend when the result bitwidth is larger. Unsigned casts
+/// zero extend when the result bitwidth is larger.
+template <typename CastOp, typename ExtOp>
+struct ConvertIndexCast : public mlir::ConvertOpToLLVMPattern<CastOp> {
+  using mlir::ConvertOpToLLVMPattern<CastOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type in = adaptor.getInput().getType();
+    Type out = this->getTypeConverter()->convertType(op.getType());
+    if (in == out)
+      rewriter.replaceOp(op, adaptor.getInput());
+    else if (in.getIntOrFloatBitWidth() > out.getIntOrFloatBitWidth())
+      rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, out, adaptor.getInput());
+    else
+      rewriter.replaceOpWithNewOp<ExtOp>(op, out, adaptor.getInput());
+    return success();
+  }
+};
+
+using ConvertIndexCastS = ConvertIndexCast<CastSOp, LLVM::SExtOp>;
+using ConvertIndexCastU = ConvertIndexCast<CastUOp, LLVM::ZExtOp>;
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCmp
+//===----------------------------------------------------------------------===//
+
+/// Assert that the LLVM comparison enum lines up with index's enum.
+static constexpr bool checkPredicates(LLVM::ICmpPredicate lhs,
+                                      IndexCmpPredicate rhs) {
+  return static_cast<uint64_t>(lhs) == static_cast<uint64_t>(rhs);
+}
+
+static_assert(
+    LLVM::getMaxEnumValForICmpPredicate() ==
+            getMaxEnumValForIndexCmpPredicate() &&
+        checkPredicates(LLVM::ICmpPredicate::eq, IndexCmpPredicate::EQ) &&
+        checkPredicates(LLVM::ICmpPredicate::ne, IndexCmpPredicate::NE) &&
+        checkPredicates(LLVM::ICmpPredicate::sge, IndexCmpPredicate::SGE) &&
+        checkPredicates(LLVM::ICmpPredicate::sgt, IndexCmpPredicate::SGT) &&
+        checkPredicates(LLVM::ICmpPredicate::sle, IndexCmpPredicate::SLE) &&
+        checkPredicates(LLVM::ICmpPredicate::slt, IndexCmpPredicate::SLT) &&
+        checkPredicates(LLVM::ICmpPredicate::uge, IndexCmpPredicate::UGE) &&
+        checkPredicates(LLVM::ICmpPredicate::ugt, IndexCmpPredicate::UGT) &&
+        checkPredicates(LLVM::ICmpPredicate::ule, IndexCmpPredicate::ULE) &&
+        checkPredicates(LLVM::ICmpPredicate::ult, IndexCmpPredicate::ULT),
+    "LLVM ICmpPredicate mismatches IndexCmpPredicate");
+
+struct ConvertIndexCmp : public mlir::ConvertOpToLLVMPattern<CmpOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The LLVM enum has the same values as the index predicate enums.
+    rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
+        op, *LLVM::symbolizeICmpPredicate(static_cast<uint32_t>(op.getPred())),
+        adaptor.getLhs(), adaptor.getRhs());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexSizeOf
+//===----------------------------------------------------------------------===//
+
+/// Lower `index.sizeof` to a constant with the value of the index bitwidth.
+struct ConvertIndexSizeOf : public mlir::ConvertOpToLLVMPattern<SizeOfOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
+        op, getTypeConverter()->getIndexType(),
+        getTypeConverter()->getIndexTypeBitwidth());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexConstant
+//===----------------------------------------------------------------------===//
+
+/// Convert an index constant. Truncate the value as appropriate.
+struct ConvertIndexConstant : public mlir::ConvertOpToLLVMPattern<ConstantOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type type = getTypeConverter()->getIndexType();
+    APInt value = op.getValue().trunc(type.getIntOrFloatBitWidth());
+    rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
+        op, type, IntegerAttr::get(type, value));
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Trivial Conversions
+//===----------------------------------------------------------------------===//
+
+using ConvertIndexAdd = mlir::OneToOneConvertToLLVMPattern<AddOp, LLVM::AddOp>;
+using ConvertIndexSub = mlir::OneToOneConvertToLLVMPattern<SubOp, LLVM::SubOp>;
+using ConvertIndexMul = mlir::OneToOneConvertToLLVMPattern<MulOp, LLVM::MulOp>;
+using ConvertIndexDivS =
+    mlir::OneToOneConvertToLLVMPattern<DivSOp, LLVM::SDivOp>;
+using ConvertIndexDivU =
+    mlir::OneToOneConvertToLLVMPattern<DivUOp, LLVM::UDivOp>;
+using ConvertIndexRemS =
+    mlir::OneToOneConvertToLLVMPattern<RemSOp, LLVM::SRemOp>;
+using ConvertIndexRemU =
+    mlir::OneToOneConvertToLLVMPattern<RemUOp, LLVM::URemOp>;
+using ConvertIndexMaxS =
+    mlir::OneToOneConvertToLLVMPattern<MaxSOp, LLVM::SMaxOp>;
+using ConvertIndexMaxU =
+    mlir::OneToOneConvertToLLVMPattern<MaxUOp, LLVM::UMaxOp>;
+using ConvertIndexBoolConstant =
+    mlir::OneToOneConvertToLLVMPattern<BoolConstantOp, LLVM::ConstantOp>;
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+void index::populateIndexToLLVMConversionPatterns(
+    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
+  patterns.insert<
+      // clang-format off
+    ConvertIndexAdd,
+    ConvertIndexSub,
+    ConvertIndexMul,
+    ConvertIndexDivS,
+    ConvertIndexDivU,
+    ConvertIndexRemS,
+    ConvertIndexRemU,
+    ConvertIndexMaxS,
+    ConvertIndexMaxU,
+    ConvertIndexCeilDivS,
+    ConvertIndexCeilDivU,
+    ConvertIndexFloorDivS,
+    ConvertIndexCastS,
+    ConvertIndexCastU,
+    ConvertIndexCmp,
+    ConvertIndexSizeOf,
+    ConvertIndexConstant,
+    ConvertIndexBoolConstant
+      // clang-format on
+      >(typeConverter);
+}
+
+//===----------------------------------------------------------------------===//
+// ODS-Generated Definitions
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+#define GEN_PASS_DEF_INDEXTOLLVMCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct IndexToLLVMPass
+    : public impl::IndexToLLVMConversionPassBase<IndexToLLVMPass> {
+  using Base::Base;
+
+  void runOnOperation() override;
+};
+} // namespace
+
+void IndexToLLVMPass::runOnOperation() {
+  // Configure dialect conversion.
+  mlir::ConversionTarget target(getContext());
+  target.addIllegalDialect<IndexDialect>();
+  target.addLegalDialect<LLVM::LLVMDialect>();
+
+  // Set LLVM lowering options.
+  mlir::LowerToLLVMOptions options(&getContext());
+  if (indexBitwidth != mlir::kDeriveIndexBitwidthFromDataLayout)
+    options.overrideIndexBitwidth(indexBitwidth);
+  mlir::LLVMTypeConverter typeConverter(&getContext(), options);
+
+  // Populate patterns and run the conversion.
+  mlir::RewritePatternSet patterns(&getContext());
+  populateIndexToLLVMConversionPatterns(typeConverter, patterns);
+
+  if (failed(mlir::applyPartialConversion(getOperation(), target,
+                                          std::move(patterns))))
+    return signalPassFailure();
+}
diff --git a/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir b/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir
@@ -0,0 +1,166 @@
+// RUN: mlir-opt %s -convert-index-to-llvm | FileCheck %s
+// RUN: mlir-opt %s -convert-index-to-llvm=index-bitwidth=32 | FileCheck %s --check-prefix=INDEX32
+// RUN: mlir-opt %s -convert-index-to-llvm=index-bitwidth=64 | FileCheck %s --check-prefix=INDEX64
+
+// CHECK-LABEL: @trivial_ops
+func.func @trivial_ops(%a: index, %b: index) {
+  // CHECK: llvm.add
+  %0 = index.add %a, %b
+  // CHECK: llvm.sub
+  %1 = index.sub %a, %b
+  // CHECK: llvm.mul
+  %2 = index.mul %a, %b
+  // CHECK: llvm.sdiv
+  %3 = index.divs %a, %b
+  // CHECK: llvm.udiv
+  %4 = index.divu %a, %b
+  // CHECK: llvm.srem
+  %5 = index.rems %a, %b
+  // CHECK: llvm.urem
+  %6 = index.remu %a, %b
+  // CHECK: llvm.intr.smax
+  %7 = index.maxs %a, %b
+  // CHECK: llvm.intr.umax
+  %8 = index.maxu %a, %b
+  // CHECK: llvm.mlir.constant(true
+  %9 = index.bool.constant true
+  return
+}
+
+// CHECK-LABEL: @ceildivs
+// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
+func.func @ceildivs(%n: index, %m: index) -> index {
+  // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
+  // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 :
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 :
+
+  // CHECK: %[[QUOTIENT:.*]] = llvm.sdiv %[[N]], %[[M]]
+  // CHECK: %[[PRODUCT:.*]] = llvm.mul %[[QUOTIENT]], %[[M]]
+  // CHECK: %[[NOT_EXACT:.*]] = llvm.icmp "ne" %[[PRODUCT]], %[[N]]
+
+  // CHECK: %[[N_NEG:.*]] = llvm.icmp "slt" %[[N]], %[[ZERO]]
+  // CHECK: %[[M_NEG:.*]] = llvm.icmp "slt" %[[M]], %[[ZERO]]
+  // CHECK: %[[SAME_SIGN:.*]] = llvm.icmp "eq" %[[N_NEG]], %[[M_NEG]]
+
+  // CHECK: %[[CMP:.*]] = llvm.and %[[NOT_EXACT]], %[[SAME_SIGN]]
+  // CHECK: %[[PLUS_ONE:.*]] = llvm.add %[[QUOTIENT]], %[[ONE]]
+  // CHECK: %[[RESULT:.*]] = llvm.select %[[CMP]], %[[PLUS_ONE]], %[[QUOTIENT]]
+  %result = index.ceildivs %n, %m
+
+  // CHECK: %[[RESULTI:.*]] = builtin.unrealized_conversion_cast %[[RESULT]]
+  // CHECK: return %[[RESULTI]]
+  return %result : index
+}
+
+// CHECK-LABEL: @ceildivu
+// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
+func.func @ceildivu(%n: index, %m: index) -> index {
+  // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
+  // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 :
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 :
+
+  // CHECK: %[[MINUS_ONE:.*]] = llvm.sub %[[N]], %[[ONE]]
+  // CHECK: %[[QUOTIENT:.*]] = llvm.udiv %[[MINUS_ONE]], %[[M]]
+  // CHECK: %[[PLUS_ONE:.*]] = llvm.add %[[QUOTIENT]], %[[ONE]]
+
+  // CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[N]], %[[ZERO]]
+  // CHECK: %[[RESULT:.*]] = llvm.select %[[CMP]], %[[ZERO]], %[[PLUS_ONE]]
+  %result = index.ceildivu %n, %m
+
+  // CHECK: %[[RESULTI:.*]] = builtin.unrealized_conversion_cast %[[RESULT]]
+  // CHECK: return %[[RESULTI]]
+  return %result : index
+}
+
+// CHECK-LABEL: @floordivs
+// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
+func.func @floordivs(%n: index, %m: index) -> index {
+  // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
+  // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 :
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 :
+
+  // CHECK: %[[QUOTIENT:.*]] = llvm.sdiv %[[N]], %[[M]]
+  // CHECK: %[[PRODUCT:.*]] = llvm.mul %[[QUOTIENT]], %[[M]]
+  // CHECK: %[[NOT_EXACT:.*]] = llvm.icmp "ne" %[[PRODUCT]], %[[N]]
+
+  // CHECK: %[[N_NEG:.*]] = llvm.icmp "slt" %[[N]], %[[ZERO]]
+  // CHECK: %[[M_NEG:.*]] = llvm.icmp "slt" %[[M]], %[[ZERO]]
+  // CHECK: %[[DIFF_SIGN:.*]] = llvm.icmp "ne" %[[N_NEG]], %[[M_NEG]]
+
+  // CHECK: %[[CMP:.*]] = llvm.and %[[NOT_EXACT]], %[[DIFF_SIGN]]
+  // CHECK: %[[MINUS_ONE:.*]] = llvm.sub %[[QUOTIENT]], %[[ONE]]
+  // CHECK: %[[RESULT:.*]] = llvm.select %[[CMP]], %[[MINUS_ONE]], %[[QUOTIENT]]
+  %result = index.floordivs %n, %m
+
+  // CHECK: %[[RESULTI:.*]] = builtin.unrealized_conversion_cast %[[RESULT]]
+  // CHECK: return %[[RESULTI]]
+  return %result : index
+}
+
+// INDEX32-LABEL: @index_cast_from
+// INDEX64-LABEL: @index_cast_from
+// INDEX32-SAME: %[[AI:.*]]: index
+// INDEX64-SAME: %[[AI:.*]]: index
+func.func @index_cast_from(%a: index) -> (i64, i32, i64, i32) {
+  // INDEX32: %[[A:.*]] = builtin.unrealized_conversion_cast %[[AI]] : index to i32
+  // INDEX64: %[[A:.*]] = builtin.unrealized_conversion_cast %[[AI]] : index to i64
+
+  // INDEX32: %[[V0:.*]] = llvm.sext %[[A]] : i32 to i64
+  %0 = index.casts %a : index to i64
+  // INDEX64: %[[V1:.*]] = llvm.trunc %[[A]] : i64 to i32
+  %1 = index.casts %a : index to i32
+  // INDEX32: %[[V2:.*]] = llvm.zext %[[A]] : i32 to i64
+  %2 = index.castu %a : index to i64
+  // INDEX64: %[[V3:.*]] = llvm.trunc %[[A]] : i64 to i32
+  %3 = index.castu %a : index to i32
+
+  // INDEX32: return %[[V0]], %[[A]], %[[V2]], %[[A]]
+  // INDEX64: return %[[A]], %[[V1]], %[[A]], %[[V3]]
+  return %0, %1, %2, %3 : i64, i32, i64, i32
+}
+
+// INDEX32-LABEL: @index_cast_to
+// INDEX64-LABEL: @index_cast_to
+// INDEX32-SAME: %[[A:.*]]: i32, %[[B:.*]]: i64
+// INDEX64-SAME: %[[A:.*]]: i32, %[[B:.*]]: i64
+func.func @index_cast_to(%a: i32, %b: i64) -> (index, index, index, index) {
+  // INDEX64: %[[V0:.*]] = llvm.sext %[[A]] : i32 to i64
+  %0 = index.casts %a : i32 to index
+  // INDEX32: %[[V1:.*]] = llvm.trunc %[[B]] : i64 to i32
+  %1 = index.casts %b : i64 to index
+  // INDEX64: %[[V2:.*]] = llvm.zext %[[A]] : i32 to i64
+  %2 = index.castu %a : i32 to index
+  // INDEX32: %[[V3:.*]] = llvm.trunc %[[B]] : i64 to i32
+  %3 = index.castu %b : i64 to index
+  return %0, %1, %2, %3 : index, index, index, index
+}
+
+// INDEX32-LABEL: @index_sizeof
+// INDEX64-LABEL: @index_sizeof
+func.func @index_sizeof() {
+  // INDEX32-NEXT: llvm.mlir.constant(32 : i32)
+  // INDEX64-NEXT: llvm.mlir.constant(64 : i64)
+  %0 = index.sizeof
+  return
+}
+
+// INDEX32-LABEL: @index_constant
+// INDEX64-LABEL: @index_constant
+func.func @index_constant() {
+  // INDEX32: llvm.mlir.constant(-2100000000 : i32) : i32
+  // INDEX64: llvm.mlir.constant(-2100000000 : i64) : i64
+  %0 = index.constant -2100000000
+  // INDEX32: llvm.mlir.constant(2100000000 : i32) : i32
+  // INDEX64: llvm.mlir.constant(2100000000 : i64) : i64
+  %1 = index.constant 2100000000
+  // INDEX32: llvm.mlir.constant(1294967296 : i32) : i32
+  // INDEX64: llvm.mlir.constant(-3000000000 : i64) : i64
+  %2 = index.constant -3000000000
+  // INDEX32: llvm.mlir.constant(-1294967296 : i32) : i32
+  // INDEX64: llvm.mlir.constant(3000000000 : i64) : i64
+  %3 = index.constant 3000000000
+  return
+}