diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -20,12 +20,13 @@ include "mlir/IR/OpAsmInterface.td" include "mlir/IR/EnumAttr.td" -// Base class for Arith dialect ops. Ops in this dialect have no side +// Base class for Arith dialect ops. Ops in this dialect have no memory // effects and can be applied element-wise to vectors and tensors. class Arith_Op traits = []> : - Op] # - ElementwiseMappable.traits>; + Op, NoMemoryEffect] # + ElementwiseMappable.traits>; // Base class for integer and floating point arithmetic ops. All ops have one // result, require operands and results to be of the same type, and can accept @@ -35,7 +36,7 @@ // Base class for unary arithmetic operations. class Arith_UnaryOp traits = []> : - Arith_ArithOp { + Arith_ArithOp { let assemblyFormat = "$operand attr-dict `:` type($result)"; } @@ -47,7 +48,7 @@ // Base class for ternary arithmetic operations. class Arith_TernaryOp traits = []> : - Arith_ArithOp { + Arith_ArithOp { let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` type($result)"; } @@ -58,6 +59,10 @@ Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>, Results<(outs SignlessIntegerLike:$result)>; +// Base class for integer binary ops with no undefined behavior on any input. +class Arith_TotalIntBinaryOp traits = []> : + Arith_IntBinaryOp; + // Base class for floating point unary operations. class Arith_FloatUnaryOp traits = []> : Arith_UnaryOp traits = []> : Arith_BinaryOp], + !listconcat([Pure, DeclareOpInterfaceMethods], traits)>, Arguments<(ins FloatLike:$lhs, FloatLike:$rhs, DefaultValuedAttr:$fastmath)>, @@ -86,7 +91,7 @@ // result. If either is a shaped type, then the other must be of the same shape.
class Arith_CastOp traits = []> : - Arith_Op]>, Arguments<(ins From:$in)>, Results<(outs To:$out)> { @@ -121,7 +126,7 @@ // and returns a single `BoolLike` result. If the operand type is a vector or // tensor, then the result will be one of `i1` of the same shape. class Arith_CompareOp traits = []> : - Arith_Op]> { let results = (outs BoolLike:$result); @@ -191,7 +196,7 @@ // AddIOp //===----------------------------------------------------------------------===// -def Arith_AddIOp : Arith_IntBinaryOp<"addi", [Commutative]> { +def Arith_AddIOp : Arith_TotalIntBinaryOp<"addi", [Commutative]> { let summary = "integer addition operation"; let description = [{ The `addi` operation takes two operands and returns one result, each of @@ -217,7 +222,7 @@ } -def Arith_AddUICarryOp : Arith_Op<"addui_carry", [Commutative, +def Arith_AddUICarryOp : Arith_Op<"addui_carry", [Pure, Commutative, AllTypesMatch<["lhs", "rhs", "sum"]>]> { let summary = "unsigned integer addition operation returning sum and carry"; let description = [{ @@ -264,7 +269,7 @@ // SubIOp //===----------------------------------------------------------------------===// -def Arith_SubIOp : Arith_IntBinaryOp<"subi"> { +def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> { let summary = "integer subtraction operation"; let hasFolder = 1; let hasCanonicalizer = 1; @@ -274,7 +279,7 @@ // MulIOp //===----------------------------------------------------------------------===// -def Arith_MulIOp : Arith_IntBinaryOp<"muli", [Commutative]> { +def Arith_MulIOp : Arith_TotalIntBinaryOp<"muli", [Commutative]> { let summary = "integer multiplication operation"; let hasFolder = 1; } @@ -283,7 +288,7 @@ // DivUIOp //===----------------------------------------------------------------------===// -def Arith_DivUIOp : Arith_IntBinaryOp<"divui"> { +def Arith_DivUIOp : Arith_IntBinaryOp<"divui", [ConditionallySpeculatable]> { let summary = "unsigned integer division operation"; let description = [{ Unsigned integer division.
Rounds towards zero. Treats the leading bit as @@ -306,6 +311,12 @@ %x = arith.divui %y, %z : tensor<4x?xi8> ``` }]; + + let extraClassDeclaration = [{ + /// Speculatable only when the divisor is a non-zero constant. + Speculation::Speculatability getSpeculatability(); + }]; + let hasFolder = 1; } @@ -313,7 +324,7 @@ // DivSIOp //===----------------------------------------------------------------------===// -def Arith_DivSIOp : Arith_IntBinaryOp<"divsi"> { +def Arith_DivSIOp : Arith_IntBinaryOp<"divsi", [ConditionallySpeculatable]> { let summary = "signed integer division operation"; let description = [{ Signed integer division. Rounds towards zero. Treats the leading bit as @@ -335,6 +346,12 @@ %x = arith.divsi %y, %z : tensor<4x?xi8> ``` }]; + + let extraClassDeclaration = [{ + /// Speculatable only when the divisor is a constant other than 0 or -1. + Speculation::Speculatability getSpeculatability(); + }]; + let hasFolder = 1; } @@ -342,7 +359,8 @@ // CeilDivUIOp //===----------------------------------------------------------------------===// -def Arith_CeilDivUIOp : Arith_IntBinaryOp<"ceildivui"> { +def Arith_CeilDivUIOp : Arith_IntBinaryOp<"ceildivui", + [ConditionallySpeculatable]> { let summary = "unsigned ceil integer division operation"; let description = [{ Unsigned integer division. Rounds towards positive infinity. Treats the @@ -359,6 +377,12 @@ %a = arith.ceildivui %b, %c : i64 ``` }]; + + let extraClassDeclaration = [{ + /// Speculatable only when the divisor is a non-zero constant. + Speculation::Speculatability getSpeculatability(); + }]; + let hasFolder = 1; } @@ -366,7 +390,8 @@ // CeilDivSIOp //===----------------------------------------------------------------------===// -def Arith_CeilDivSIOp : Arith_IntBinaryOp<"ceildivsi"> { +def Arith_CeilDivSIOp : Arith_IntBinaryOp<"ceildivsi", + [ConditionallySpeculatable]> { let summary = "signed ceil integer division operation"; let description = [{ Signed integer division. Rounds towards positive infinity, i.e. `7 / -2 = -3`.
@@ -381,6 +406,12 @@ %a = arith.ceildivsi %b, %c : i64 ``` }]; + + let extraClassDeclaration = [{ + /// Speculatable only when the divisor is a constant other than 0 or -1. + Speculation::Speculatability getSpeculatability(); + }]; + let hasFolder = 1; }
@@ -469,7 +500,7 @@ // AndIOp //===----------------------------------------------------------------------===// -def Arith_AndIOp : Arith_IntBinaryOp<"andi", [Commutative, Idempotent]> { +def Arith_AndIOp : Arith_TotalIntBinaryOp<"andi", [Commutative, Idempotent]> { let summary = "integer binary and"; let description = [{ The `andi` operation takes two operands and returns one result, each of @@ -498,7 +529,7 @@ // OrIOp //===----------------------------------------------------------------------===// -def Arith_OrIOp : Arith_IntBinaryOp<"ori", [Commutative, Idempotent]> { +def Arith_OrIOp : Arith_TotalIntBinaryOp<"ori", [Commutative, Idempotent]> { let summary = "integer binary or"; let description = [{ The `ori` operation takes two operands and returns one result, each of these @@ -527,7 +558,7 @@ // XOrIOp //===----------------------------------------------------------------------===// -def Arith_XOrIOp : Arith_IntBinaryOp<"xori", [Commutative]> { +def Arith_XOrIOp : Arith_TotalIntBinaryOp<"xori", [Commutative]> { let summary = "integer binary xor"; let description = [{ The `xori` operation takes two operands and returns one result, each of @@ -556,7 +587,7 @@ // ShLIOp //===----------------------------------------------------------------------===// -def Arith_ShLIOp : Arith_IntBinaryOp<"shli"> { +def Arith_ShLIOp : Arith_TotalIntBinaryOp<"shli"> { let summary = "integer left-shift"; let description = [{ The `shli` operation shifts an integer value to the left by a variable @@ -577,7 +608,7 @@ // ShRUIOp //===----------------------------------------------------------------------===// -def Arith_ShRUIOp : Arith_IntBinaryOp<"shrui"> { +def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> { let summary = "unsigned integer right-shift"; let description = [{ The `shrui` operation shifts an integer value to the right by a variable @@ -599,7 +630,7 @@ // ShRSIOp //===----------------------------------------------------------------------===// -def Arith_ShRSIOp :
Arith_IntBinaryOp<"shrsi"> { +def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> { let summary = "signed integer right-shift"; let description = [{ The `shrsi` operation shifts an integer value to the right by a variable @@ -740,7 +771,7 @@ // MaxSIOp //===----------------------------------------------------------------------===// -def Arith_MaxSIOp : Arith_IntBinaryOp<"maxsi", [Commutative]> { +def Arith_MaxSIOp : Arith_TotalIntBinaryOp<"maxsi", [Commutative]> { let summary = "signed integer maximum operation"; let hasFolder = 1; } @@ -749,7 +780,7 @@ // MaxUIOp //===----------------------------------------------------------------------===// -def Arith_MaxUIOp : Arith_IntBinaryOp<"maxui", [Commutative]> { +def Arith_MaxUIOp : Arith_TotalIntBinaryOp<"maxui", [Commutative]> { let summary = "unsigned integer maximum operation"; let hasFolder = 1; } @@ -784,7 +815,7 @@ // MinSIOp //===----------------------------------------------------------------------===// -def Arith_MinSIOp : Arith_IntBinaryOp<"minsi", [Commutative]> { +def Arith_MinSIOp : Arith_TotalIntBinaryOp<"minsi", [Commutative]> { let summary = "signed integer minimum operation"; let hasFolder = 1; } @@ -793,7 +824,7 @@ // MinUIOp //===----------------------------------------------------------------------===// -def Arith_MinUIOp : Arith_IntBinaryOp<"minui", [Commutative]> { +def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", [Commutative]> { let summary = "unsigned integer minimum operation"; let hasFolder = 1; } @@ -1250,7 +1281,7 @@ // SelectOp //===----------------------------------------------------------------------===// -def SelectOp : Arith_Op<"select", [ +def SelectOp : Arith_Op<"select", [Pure, AllTypesMatch<["true_value", "false_value", "result"]>, DeclareOpInterfaceMethods, ] # ElementwiseMappable.traits> { diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@
-367,6 +367,12 @@ return div0 ? Attribute() : result; } +Speculation::Speculatability arith::DivUIOp::getSpeculatability() { + // X / 0 is UB, so speculate only when the divisor is a non-zero constant. + return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable + : Speculation::NotSpeculatable; +} + //===----------------------------------------------------------------------===// // DivSIOp //===----------------------------------------------------------------------===// @@ -390,6 +396,18 @@ return overflowOrDiv0 ? Attribute() : result; } +Speculation::Speculatability arith::DivSIOp::getSpeculatability() { + bool mayHaveUB = true; + + APInt constRHS; + // X / 0 and INT_MIN / -1 are both UB, so speculate only when the divisor + // is a constant that is neither 0 nor -1 (all ones). + if (matchPattern(getRhs(), m_ConstantInt(&constRHS))) + mayHaveUB = constRHS.isAllOnes() || constRHS.isZero(); + + return mayHaveUB ? Speculation::NotSpeculatable : Speculation::Speculatable; +} + //===----------------------------------------------------------------------===// // Ceil and floor division folding helpers //===----------------------------------------------------------------------===// @@ -428,6 +446,12 @@ return overflowOrDiv0 ? Attribute() : result; } +Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() { + // X / 0 is UB, so speculate only when the divisor is a non-zero constant. + return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable + : Speculation::NotSpeculatable; +} + //===----------------------------------------------------------------------===// // CeilDivSIOp //===----------------------------------------------------------------------===// @@ -477,6 +501,18 @@ return overflowOrDiv0 ? Attribute() : result; } +Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() { + bool mayHaveUB = true; + + APInt constRHS; + // X / 0 and INT_MIN / -1 are both UB, so speculate only when the divisor + // is a constant that is neither 0 nor -1 (all ones). + if (matchPattern(getRhs(), m_ConstantInt(&constRHS))) + mayHaveUB = constRHS.isAllOnes() || constRHS.isZero(); + + return mayHaveUB ?
Speculation::NotSpeculatable : Speculation::Speculatable; +} + //===----------------------------------------------------------------------===// // FloorDivSIOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir --- a/mlir/test/Transforms/loop-invariant-code-motion.mlir +++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir @@ -696,3 +696,181 @@ return } + +// ----- + +func.func @no_speculate_divui( +// CHECK-LABEL: @no_speculate_divui( + %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: arith.divui + %val = arith.divui %num, %denom : i32 + } + + return +} + +func.func @no_speculate_divsi( +// CHECK-LABEL: @no_speculate_divsi( + %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: arith.divsi + %val = arith.divsi %num, %denom : i32 + } + + return +} + +func.func @no_speculate_ceildivui( +// CHECK-LABEL: @no_speculate_ceildivui( + %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: arith.ceildivui + %val = arith.ceildivui %num, %denom : i32 + } + + return +} + +func.func @no_speculate_ceildivsi( +// CHECK-LABEL: @no_speculate_ceildivsi( + %num: i32, %denom: i32, %lb: index, %ub: index, %step: index) { + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: arith.ceildivsi + %val = arith.ceildivsi %num, %denom : i32 + } + + return +} + +func.func @no_speculate_divui_const(%num: i32, %lb: index, %ub: index, %step: index) { +// CHECK-LABEL: @no_speculate_divui_const( + %c0 = arith.constant 0 : i32 + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: arith.divui + %val = arith.divui %num, %c0 : i32 + } + + return +} + +func.func
@speculate_divui_const( +// CHECK-LABEL: @speculate_divui_const( + %num: i32, %lb: index, %ub: index, %step: index) { + %c5 = arith.constant 5 : i32 +// CHECK: arith.divui +// CHECK: scf.for + scf.for %i = %lb to %ub step %step { + %val = arith.divui %num, %c5 : i32 + } + + return +} + +func.func @no_speculate_ceildivui_const(%num: i32, %lb: index, %ub: index, %step: index) { +// CHECK-LABEL: @no_speculate_ceildivui_const( + %c0 = arith.constant 0 : i32 + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: arith.ceildivui + %val = arith.ceildivui %num, %c0 : i32 + } + + return +} + +func.func @speculate_ceildivui_const( +// CHECK-LABEL: @speculate_ceildivui_const( + %num: i32, %lb: index, %ub: index, %step: index) { + %c5 = arith.constant 5 : i32 +// CHECK: arith.ceildivui +// CHECK: scf.for + scf.for %i = %lb to %ub step %step { + %val = arith.ceildivui %num, %c5 : i32 + } + + return +} + +func.func @no_speculate_divsi_const0( +// CHECK-LABEL: @no_speculate_divsi_const0( + %num: i32, %lb: index, %ub: index, %step: index) { + %c0 = arith.constant 0 : i32 + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: arith.divsi + %val = arith.divsi %num, %c0 : i32 + } + + return +} + +func.func @no_speculate_divsi_const_minus1( +// CHECK-LABEL: @no_speculate_divsi_const_minus1( + %num: i32, %lb: index, %ub: index, %step: index) { + %cm1 = arith.constant -1 : i32 + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: arith.divsi + %val = arith.divsi %num, %cm1 : i32 + } + + return +} + +func.func @speculate_divsi_const( +// CHECK-LABEL: @speculate_divsi_const( + %num: i32, %lb: index, %ub: index, %step: index) { + %c5 = arith.constant 5 : i32 +// CHECK: arith.divsi +// CHECK: scf.for + scf.for %i = %lb to %ub step %step { + %val = arith.divsi %num, %c5 : i32 + } + + return +} + +func.func @no_speculate_ceildivsi_const0( +// CHECK-LABEL: @no_speculate_ceildivsi_const0( + %num: i32,
%lb: index, %ub: index, %step: index) { + %c0 = arith.constant 0 : i32 + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: arith.ceildivsi + %val = arith.ceildivsi %num, %c0 : i32 + } + + return +} + +func.func @no_speculate_ceildivsi_const_minus1( +// CHECK-LABEL: @no_speculate_ceildivsi_const_minus1( + %num: i32, %lb: index, %ub: index, %step: index) { + %cm1 = arith.constant -1 : i32 + scf.for %i = %lb to %ub step %step { +// CHECK: scf.for +// CHECK: arith.ceildivsi + %val = arith.ceildivsi %num, %cm1 : i32 + } + + return +} + +func.func @speculate_ceildivsi_const( +// CHECK-LABEL: @speculate_ceildivsi_const( + %num: i32, %lb: index, %ub: index, %step: index) { + %c5 = arith.constant 5 : i32 +// CHECK: arith.ceildivsi +// CHECK: scf.for + scf.for %i = %lb to %ub step %step { + %val = arith.ceildivsi %num, %c5 : i32 + } + + return +}