diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -8,7 +8,9 @@
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Transforms.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -18,6 +20,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
@@ -419,6 +422,65 @@
 using SIToFPPattern = IToFPPattern<arith::SIToFPOp, ExtensionKind::Sign>;
 using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
 
+//===----------------------------------------------------------------------===//
+// Index Cast Patterns
+//===----------------------------------------------------------------------===//
+
+// These rely on the `ValueBounds` interface for index values. For example, we
+// can often statically tell index value bounds of loop induction variables.
+
+template <typename CastOp, ExtensionKind Kind>
+struct IndexCastPattern final : NarrowingPattern<CastOp> {
+  using NarrowingPattern<CastOp>::NarrowingPattern;
+
+  LogicalResult matchAndRewrite(CastOp op,
+                                PatternRewriter &rewriter) const override {
+    Value in = op.getIn();
+    // We only support scalar index -> integer casts.
+    if (!isa<IndexType>(in.getType()))
+      return failure();
+
+    // Check the lower bound in both the signed and unsigned cast case. We
+    // conservatively assume that even unsigned casts may be performed on
+    // negative indices.
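+    // Both bounds are computed with the ValueBoundsOpInterface; the UB query
+    // requests a closed (inclusive) bound, so `ub` is the largest value `in`
+    // can actually take.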
+    FailureOr<int64_t> lb = ValueBoundsConstraintSet::computeConstantBound(
+        presburger::BoundType::LB, in);
+    if (failed(lb))
+      return failure();
+
+    FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
+        presburger::BoundType::UB, in, /*dim=*/std::nullopt,
+        /*stopCondition=*/nullptr, /*closedUB=*/true);
+    if (failed(ub))
+      return failure();
+
+    assert(*lb <= *ub && "Invalid bounds");
+    unsigned lbBitsRequired = calculateBitsRequired(APInt(64, *lb), Kind);
+    unsigned ubBitsRequired = calculateBitsRequired(APInt(64, *ub), Kind);
+    unsigned bitsRequired = std::max(lbBitsRequired, ubBitsRequired);
+
+    IntegerType resultTy = cast<IntegerType>(op.getType());
+    if (resultTy.getWidth() <= bitsRequired)
+      return failure();
+
+    FailureOr<Type> narrowTy = this->getNarrowType(bitsRequired, resultTy);
+    if (failed(narrowTy))
+      return failure();
+
+    Value newCast = rewriter.create<CastOp>(op.getLoc(), *narrowTy, op.getIn());
+
+    if (Kind == ExtensionKind::Sign)
+      rewriter.replaceOpWithNewOp<arith::ExtSIOp>(op, resultTy, newCast);
+    else
+      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, resultTy, newCast);
+    return success();
+  }
+};
+using IndexCastSIPattern =
+    IndexCastPattern<arith::IndexCastOp, ExtensionKind::Sign>;
+using IndexCastUIPattern =
+    IndexCastPattern<arith::IndexCastUIOp, ExtensionKind::Zero>;
+
 //===----------------------------------------------------------------------===//
 // Patterns to Commute Extension Ops
 //===----------------------------------------------------------------------===//
@@ -714,8 +776,8 @@
   patterns.add<
-      MinUIPattern, SIToFPPattern, UIToFPPattern>(
-      patterns.getContext(), options);
+      MinUIPattern, SIToFPPattern, UIToFPPattern, IndexCastSIPattern,
+      IndexCastUIPattern>(patterns.getContext(), options);
 }
 } // namespace mlir::arith
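For illustration only (a sketch, not part of the patch): the comment above mentions loop induction variables, whose range the `ValueBounds` interface can compute statically. Assuming `%c0`, `%c128`, `%c1`, and `%buf` are defined elsewhere, these patterns would enable a rewrite along these lines:

```mlir
// Before: %i is known to lie in [0, 127], yet the cast produces a full i64.
scf.for %i = %c0 to %c128 step %c1 {
  %v = arith.index_cast %i : index to i64
  memref.store %v, %buf[%i] : memref<128xi64>
}

// After --arith-int-narrowing (with 8 among the supported bitwidths): the
// cast is performed at i8 and the result is sign-extended back to i64.
scf.for %i = %c0 to %c128 step %c1 {
  %narrow = arith.index_cast %i : index to i8
  %v = arith.extsi %narrow : i8 to i64
  memref.store %v, %buf[%i] : memref<128xi64>
}
```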
diff --git a/mlir/test/Dialect/Linalg/int-narrowing.mlir b/mlir/test/Dialect/Linalg/int-narrowing.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/int-narrowing.mlir
@@ -0,0 +1,147 @@
+// RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,32" \
+// RUN:   --verify-diagnostics %s | FileCheck %s
+
+// Check that we can calculate `linalg.index` value bounds and use them to
+// optimize index casts.
+
+//===----------------------------------------------------------------------===//
+// arith.index_cast
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @linalg_indexcast_dim_0_i8
+// CHECK:      %[[IDX:.+]] = linalg.index 0 : index
+// CHECK-NEXT: %[[INT:.+]] = arith.index_cast %[[IDX]] : index to i8
+// CHECK-NEXT: %[[FP:.+]] = arith.sitofp %[[INT]] : i8 to f16
+// CHECK-NEXT: linalg.yield %[[FP]] : f16
+func.func @linalg_indexcast_dim_0_i8(%arg0: tensor<f16>) -> tensor<128xf16> {
+  %init = tensor.empty() : tensor<128xf16>
+  %res = linalg.generic {
+    indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]
+  }
+    ins(%arg0 : tensor<f16>)
+    outs(%init : tensor<128xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %idx = linalg.index 0 : index
+    %int = arith.index_cast %idx : index to i64
+    %fp = arith.sitofp %int : i64 to f16
+    linalg.yield %fp : f16
+  } -> tensor<128xf16>
+
+  return %res : tensor<128xf16>
+}
+
+// CHECK-LABEL: func @linalg_indexcast_dim_1_i16
+// CHECK:      %[[IDX:.+]] = linalg.index 1 : index
+// CHECK-NEXT: %[[INT:.+]] = arith.index_cast %[[IDX]] : index to i16
+// CHECK-NEXT: %[[FP:.+]] = arith.sitofp %[[INT]] : i16 to f16
+// CHECK-NEXT: linalg.yield %[[FP]] : f16
+func.func @linalg_indexcast_dim_1_i16(%arg0: tensor<f16>, %arg1: tensor<?x20000xf16>) -> tensor<?x20000xf16> {
+  %res = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  }
+    ins(%arg0 : tensor<f16>)
+    outs(%arg1 : tensor<?x20000xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %idx = linalg.index 1 : index
+    %int = arith.index_cast %idx : index to i64
+    %fp = arith.sitofp %int : i64 to f16
+    linalg.yield %fp : f16
+  } -> tensor<?x20000xf16>
+
+  return %res : tensor<?x20000xf16>
+}
+
+// CHECK-LABEL: func @linalg_indexcast_dynamic_dim_i64
+// CHECK:      %[[IDX:.+]] = linalg.index 0 : index
+// CHECK-NEXT: %[[INT:.+]] = arith.index_cast %[[IDX]] : index to i64
+// CHECK-NEXT: %[[FP:.+]] = arith.sitofp %[[INT]] : i64 to f16
+// CHECK-NEXT: linalg.yield %[[FP]] : f16
+func.func @linalg_indexcast_dynamic_dim_i64(%arg0: tensor<f16>, %arg1: tensor<?xf16>) -> tensor<?xf16> {
+  %res = linalg.generic {
+    indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]
+  }
+    ins(%arg0 : tensor<f16>)
+    outs(%arg1 : tensor<?xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %idx = linalg.index 0 : index
+    %int = arith.index_cast %idx : index to i64
+    %fp = arith.sitofp %int : i64 to f16
+    linalg.yield %fp : f16
+  } -> tensor<?xf16>
+
+  return %res : tensor<?xf16>
+}
+
+//===----------------------------------------------------------------------===//
+// arith.index_castui
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @linalg_indexcastui_dim_0_i8
+// CHECK:      %[[IDX:.+]] = linalg.index 0 : index
+// CHECK-NEXT: %[[INT:.+]] = arith.index_castui %[[IDX]] : index to i8
+// CHECK-NEXT: %[[FP:.+]] = arith.uitofp %[[INT]] : i8 to f16
+// CHECK-NEXT: linalg.yield %[[FP]] : f16
+func.func @linalg_indexcastui_dim_0_i8(%arg0: tensor<f16>) -> tensor<256xf16> {
+  %init = tensor.empty() : tensor<256xf16>
+  %res = linalg.generic {
+    indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]
+  }
+    ins(%arg0 : tensor<f16>)
+    outs(%init : tensor<256xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %idx = linalg.index 0 : index
+    %int = arith.index_castui %idx : index to i64
+    %fp = arith.uitofp %int : i64 to f16
+    linalg.yield %fp : f16
+  } -> tensor<256xf16>
+
+  return %res : tensor<256xf16>
+}
+
+// CHECK-LABEL: func @linalg_indexcastui_dim_1_i16
+// CHECK:      %[[IDX:.+]] = linalg.index 1 : index
+// CHECK-NEXT: %[[INT:.+]] = arith.index_castui %[[IDX]] : index to i16
+// CHECK-NEXT: %[[FP:.+]] = arith.uitofp %[[INT]] : i16 to f16
+// CHECK-NEXT: linalg.yield %[[FP]] : f16
+func.func @linalg_indexcastui_dim_1_i16(%arg0: tensor<f16>, %arg1: tensor<?x20000xf16>) -> tensor<?x20000xf16> {
+  %res = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  }
+    ins(%arg0 : tensor<f16>)
+    outs(%arg1 : tensor<?x20000xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %idx = linalg.index 1 : index
+    %int = arith.index_castui %idx : index to i64
+    %fp = arith.uitofp %int : i64 to f16
+    linalg.yield %fp : f16
+  } -> tensor<?x20000xf16>
+
+  return %res : tensor<?x20000xf16>
+}
+
+// CHECK-LABEL: func @linalg_indexcastui_dynamic_dim_i64
+// CHECK:      %[[IDX:.+]] = linalg.index 0 : index
+// CHECK-NEXT: %[[INT:.+]] = arith.index_castui %[[IDX]] : index to i64
+// CHECK-NEXT: %[[FP:.+]] = arith.uitofp %[[INT]] : i64 to f16
+// CHECK-NEXT: linalg.yield %[[FP]] : f16
+func.func @linalg_indexcastui_dynamic_dim_i64(%arg0: tensor<f16>, %arg1: tensor<?xf16>) -> tensor<?xf16> {
+  %res = linalg.generic {
+    indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]
+  }
+    ins(%arg0 : tensor<f16>)
+    outs(%arg1 : tensor<?xf16>) {
+  ^bb0(%in: f16, %out: f16):
+    %idx = linalg.index 0 : index
+    %int = arith.index_castui %idx : index to i64
+    %fp = arith.uitofp %int : i64 to f16
+    linalg.yield %fp : f16
+  } -> tensor<?xf16>
+
+  return %res : tensor<?xf16>
+}