Index: mlir/include/mlir/Dialect/Affine/EDSC/Builders.h =================================================================== --- mlir/include/mlir/Dialect/Affine/EDSC/Builders.h +++ mlir/include/mlir/Dialect/Affine/EDSC/Builders.h @@ -86,10 +86,16 @@ -/// Comparison operator overloadings. +/// Comparison helpers: `eq`/`ne` compare for (in)equality, `s*` compare +/// integers as signed and `u*` compare integers as unsigned. On floating-point +/// operands the `s*` and `u*` forms emit the same ordered predicates. Value eq(Value lhs, Value rhs); Value ne(Value lhs, Value rhs); -Value operator<(Value lhs, Value rhs); -Value operator<=(Value lhs, Value rhs); -Value operator>(Value lhs, Value rhs); -Value operator>=(Value lhs, Value rhs); +Value slt(Value lhs, Value rhs); +Value sle(Value lhs, Value rhs); +Value sgt(Value lhs, Value rhs); +Value sge(Value lhs, Value rhs); +Value ult(Value lhs, Value rhs); +Value ule(Value lhs, Value rhs); +Value ugt(Value lhs, Value rhs); +Value uge(Value lhs, Value rhs); } // namespace op @@ -179,24 +185,44 @@ return ne(value, e); } template -Value TemplatedIndexedValue::operator<(Value e) { - using op::operator<; - return static_cast(*this) < e; +Value TemplatedIndexedValue::slt(Value e) { + using op::slt; + return slt(static_cast(*this), e); } template -Value TemplatedIndexedValue::operator<=(Value e) { - using op::operator<=; - return static_cast(*this) <= e; +Value TemplatedIndexedValue::sle(Value e) { + using op::sle; + return sle(static_cast(*this), e); } template -Value TemplatedIndexedValue::operator>(Value e) { - using op::operator>; - return static_cast(*this) > e; +Value TemplatedIndexedValue::sgt(Value e) { + using op::sgt; + return sgt(static_cast(*this), e); } template -Value TemplatedIndexedValue::operator>=(Value e) { - using op::operator>=; - return static_cast(*this) >= e; +Value TemplatedIndexedValue::sge(Value e) { + using op::sge; + return sge(static_cast(*this), e); +} +template +Value TemplatedIndexedValue::ult(Value e) { + using op::ult; + return ult(static_cast(*this), e); +} +template +Value TemplatedIndexedValue::ule(Value e) { + using op::ule; + return ule(static_cast(*this), e); +} +template +Value TemplatedIndexedValue::ugt(Value
e) { + using op::ugt; + return ugt(static_cast(*this), e); +} +template +Value TemplatedIndexedValue::uge(Value e) { + using op::uge; + return uge(static_cast(*this), e); } } // namespace edsc Index: mlir/include/mlir/EDSC/Builders.h =================================================================== --- mlir/include/mlir/EDSC/Builders.h +++ mlir/include/mlir/EDSC/Builders.h @@ -440,21 +440,39 @@ /// Comparison operator overloadings. Value eq(Value e); Value ne(Value e); - Value operator<(Value e); - Value operator<=(Value e); - Value operator>(Value e); - Value operator>=(Value e); - Value operator<(TemplatedIndexedValue e) { - return *this < static_cast(e); + Value slt(Value e); + Value sle(Value e); + Value sgt(Value e); + Value sge(Value e); + Value ult(Value e); + Value ule(Value e); + Value ugt(Value e); + Value uge(Value e); + /// Overloads taking another indexed value forward to the `Value` member of + /// the same name; a two-argument call here would only see these members. + Value slt(TemplatedIndexedValue e) { + return slt(static_cast(e)); } - Value operator<=(TemplatedIndexedValue e) { - return *this <= static_cast(e); + Value sle(TemplatedIndexedValue e) { + return sle(static_cast(e)); } - Value operator>(TemplatedIndexedValue e) { - return *this > static_cast(e); + Value sgt(TemplatedIndexedValue e) { + return sgt(static_cast(e)); } - Value operator>=(TemplatedIndexedValue e) { - return *this >= static_cast(e); + Value sge(TemplatedIndexedValue e) { + return sge(static_cast(e)); + } + Value ult(TemplatedIndexedValue e) { + return ult(static_cast(e)); + } + Value ule(TemplatedIndexedValue e) { + return ule(static_cast(e)); + } + Value ugt(TemplatedIndexedValue e) { + return ugt(static_cast(e)); + } + Value uge(TemplatedIndexedValue e) { + return uge(static_cast(e)); } private: Index: mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp =================================================================== --- mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -187,7 +187,7 @@ using namespace
mlir::edsc::op; majorIvsPlusOffsets.push_back(iv + off); if (xferOp.isMaskedDim(leadingRank + idx)) { - Value inBounds = majorIvsPlusOffsets.back() < ub; + Value inBounds = slt(majorIvsPlusOffsets.back(), ub); inBoundsCondition = (inBoundsCondition) ? (inBoundsCondition && inBounds) : inBounds; } @@ -433,16 +433,16 @@ auto i = memRefAccess[memRefDim]; if (loopIndex < 0) { auto N_minus_1 = N - one; - auto select_1 = std_select(i < N, i, N_minus_1); + auto select_1 = std_select(slt(i, N), i, N_minus_1); clippedScalarAccessExprs[memRefDim] = - std_select(i < zero, zero, select_1); + std_select(slt(i, zero), zero, select_1); } else { auto ii = ivs[loopIndex]; auto i_plus_ii = i + ii; auto N_minus_1 = N - one; - auto select_1 = std_select(i_plus_ii < N, i_plus_ii, N_minus_1); + auto select_1 = std_select(slt(i_plus_ii, N), i_plus_ii, N_minus_1); clippedScalarAccessExprs[memRefDim] = - std_select(i_plus_ii < zero, zero, select_1); + std_select(slt(i_plus_ii, zero), zero, select_1); } } Index: mlir/lib/Dialect/Affine/EDSC/Builders.cpp =================================================================== --- mlir/lib/Dialect/Affine/EDSC/Builders.cpp +++ mlir/lib/Dialect/Affine/EDSC/Builders.cpp @@ -251,29 +251,51 @@ ? createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs) : createIComparisonExpr(CmpIPredicate::ne, lhs, rhs); } -Value mlir::edsc::op::operator<(Value lhs, Value rhs) { +Value mlir::edsc::op::slt(Value lhs, Value rhs) { auto type = lhs.getType(); return type.isa() ? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs) - : - // TODO(ntv,zinenko): signed by default, how about unsigned? - createIComparisonExpr(CmpIPredicate::slt, lhs, rhs); + : createIComparisonExpr(CmpIPredicate::slt, lhs, rhs); } -Value mlir::edsc::op::operator<=(Value lhs, Value rhs) { +Value mlir::edsc::op::sle(Value lhs, Value rhs) { auto type = lhs.getType(); return type.isa() ?
createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs) : createIComparisonExpr(CmpIPredicate::sle, lhs, rhs); } -Value mlir::edsc::op::operator>(Value lhs, Value rhs) { +Value mlir::edsc::op::sgt(Value lhs, Value rhs) { auto type = lhs.getType(); return type.isa() ? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs) : createIComparisonExpr(CmpIPredicate::sgt, lhs, rhs); } -Value mlir::edsc::op::operator>=(Value lhs, Value rhs) { +Value mlir::edsc::op::sge(Value lhs, Value rhs) { auto type = lhs.getType(); return type.isa() ? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs) : createIComparisonExpr(CmpIPredicate::sge, lhs, rhs); } +Value mlir::edsc::op::ult(Value lhs, Value rhs) { + auto type = lhs.getType(); + return type.isa() + ? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs) + : createIComparisonExpr(CmpIPredicate::ult, lhs, rhs); +} +Value mlir::edsc::op::ule(Value lhs, Value rhs) { + auto type = lhs.getType(); + return type.isa() + ? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs) + : createIComparisonExpr(CmpIPredicate::ule, lhs, rhs); +} +Value mlir::edsc::op::ugt(Value lhs, Value rhs) { + auto type = lhs.getType(); + return type.isa() + ? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs) + : createIComparisonExpr(CmpIPredicate::ugt, lhs, rhs); +} +Value mlir::edsc::op::uge(Value lhs, Value rhs) { + auto type = lhs.getType(); + return type.isa() + ?
createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs) + : createIComparisonExpr(CmpIPredicate::uge, lhs, rhs); +} Index: mlir/lib/Dialect/Linalg/EDSC/Builders.cpp =================================================================== --- mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -221,8 +221,8 @@ StructuredIndexed I2, StructuredIndexed O) { BinaryPointwiseOpBuilder binOp([](Value a, Value b) -> Value { - using edsc::op::operator>; - return std_select(a > b, a, b); + using edsc::op::sgt; + return std_select(sgt(a, b), a, b); }); return linalg_generic_pointwise(binOp, I1, I2, O); } Index: mlir/lib/Dialect/Linalg/Transforms/Loops.cpp =================================================================== --- mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -308,16 +308,16 @@ continue; } - using edsc::op::operator<; - using edsc::op::operator>=; + using edsc::op::sge; + using edsc::op::slt; using edsc::op::operator||; - Value leftOutOfBound = dim < zeroIndex; + Value leftOutOfBound = slt(dim, zeroIndex); if (conds.empty()) conds.push_back(leftOutOfBound); else conds.push_back(conds.back() || leftOutOfBound); Value rightBound = std_dim(convOp.input(), idx); - conds.push_back(conds.back() || (dim >= rightBound)); + conds.push_back(conds.back() || sge(dim, rightBound)); // When padding is involved, the indices will only be shifted to negative, // so having a max op is enough. @@ -386,8 +386,8 @@ // Emit scalar form. Value lhs = std_load(op.output(), indices.outputs); Value rhs = std_load(op.input(), indices.inputs); - using edsc::op::operator>; - Value maxValue = std_select(lhs > rhs, lhs, rhs); + using edsc::op::sgt; + Value maxValue = std_select(sgt(lhs, rhs), lhs, rhs); std_store(maxValue, op.output(), indices.outputs); } }; @@ -401,8 +401,8 @@ // Emit scalar form.
Value lhs = std_load(op.output(), indices.outputs); Value rhs = std_load(op.input(), indices.inputs); - using edsc::op::operator<; - Value minValue = std_select(lhs < rhs, lhs, rhs); + using edsc::op::slt; + Value minValue = std_select(slt(lhs, rhs), lhs, rhs); std_store(minValue, op.output(), indices.outputs); } }; Index: mlir/test/EDSC/builder-api-test.cpp =================================================================== --- mlir/test/EDSC/builder-api-test.cpp +++ mlir/test/EDSC/builder-api-test.cpp @@ -518,9 +518,9 @@ TEST_FUNC(select_op_i32) { using namespace edsc::op; - auto f32Type = FloatType::getF32(&globalContext()); + auto i32Type = IntegerType::get(32, &globalContext()); auto memrefType = MemRefType::get( - {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0); + {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, i32Type, {}, 0); auto f = makeFunction("select_op", {}, {memrefType}); OpBuilder builder(f.getBody()); @@ -533,6 +533,15 @@ Value &i = ivs[0], &j = ivs[1]; AffineLoopNestBuilder(ivs, {zero, zero}, {one, one}, {1, 1})([&]{ std_select(eq(i, zero), A(zero, zero), A(i, j)); + std_select(ne(i, zero), A(zero, zero), A(i, j)); + std_select(slt(i, zero), A(zero, zero), A(i, j)); + std_select(sle(i, zero), A(zero, zero), A(i, j)); + std_select(sgt(i, zero), A(zero, zero), A(i, j)); + std_select(sge(i, zero), A(zero, zero), A(i, j)); + std_select(ult(i, zero), A(zero, zero), A(i, j)); + std_select(ule(i, zero), A(zero, zero), A(i, j)); + std_select(ugt(i, zero), A(zero, zero), A(i, j)); + std_select(uge(i, zero), A(zero, zero), A(i, j)); }); // CHECK-LABEL: @select_op @@ -542,6 +551,42 @@ // CHECK-DAG: {{.*}} = affine.load // CHECK-DAG: {{.*}} = affine.load // CHECK-NEXT: {{.*}} = select + // CHECK-DAG: {{.*}} = cmpi "ne" + // CHECK-DAG: {{.*}} = affine.load + // CHECK-DAG: {{.*}} = affine.load + // CHECK-NEXT: {{.*}} = select + // CHECK-DAG: {{.*}} = cmpi "slt" + // CHECK-DAG: {{.*}} = affine.load + // CHECK-DAG: {{.*}} = affine.load
+ // CHECK-NEXT: {{.*}} = select + // CHECK-DAG: {{.*}} = cmpi "sle" + // CHECK-DAG: {{.*}} = affine.load + // CHECK-DAG: {{.*}} = affine.load + // CHECK-NEXT: {{.*}} = select + // CHECK-DAG: {{.*}} = cmpi "sgt" + // CHECK-DAG: {{.*}} = affine.load + // CHECK-DAG: {{.*}} = affine.load + // CHECK-NEXT: {{.*}} = select + // CHECK-DAG: {{.*}} = cmpi "sge" + // CHECK-DAG: {{.*}} = affine.load + // CHECK-DAG: {{.*}} = affine.load + // CHECK-NEXT: {{.*}} = select + // CHECK-DAG: {{.*}} = cmpi "ult" + // CHECK-DAG: {{.*}} = affine.load + // CHECK-DAG: {{.*}} = affine.load + // CHECK-NEXT: {{.*}} = select + // CHECK-DAG: {{.*}} = cmpi "ule" + // CHECK-DAG: {{.*}} = affine.load + // CHECK-DAG: {{.*}} = affine.load + // CHECK-NEXT: {{.*}} = select + // CHECK-DAG: {{.*}} = cmpi "ugt" + // CHECK-DAG: {{.*}} = affine.load + // CHECK-DAG: {{.*}} = affine.load + // CHECK-NEXT: {{.*}} = select + // CHECK-DAG: {{.*}} = cmpi "uge" + // CHECK-DAG: {{.*}} = affine.load + // CHECK-DAG: {{.*}} = affine.load + // CHECK-NEXT: {{.*}} = select // clang-format on f.print(llvm::outs()); f.erase(); @@ -565,10 +610,14 @@ using namespace edsc::op; std_select(eq(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); std_select(ne(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); - std_select(B(i, j) >= B(i + one, j), A(zero, zero), A(i, j)); - std_select(B(i, j) <= B(i + one, j), A(zero, zero), A(i, j)); - std_select(B(i, j) < B(i + one, j), A(zero, zero), A(i, j)); - std_select(B(i, j) > B(i + one, j), A(zero, zero), A(i, j)); + std_select(sge(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); + std_select(sle(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); + std_select(slt(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); + std_select(sgt(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); + std_select(uge(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); + std_select(ule(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); + std_select(ult(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); +
std_select(ugt(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); }); // CHECK-LABEL: @select_op @@ -616,6 +665,34 @@ // CHECK-DAG: affine.load // CHECK-DAG: affine.apply // CHECK-NEXT: select + // CHECK-DAG: cmpf "oge" + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.apply + // CHECK-NEXT: select + // CHECK-DAG: cmpf "ole" + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.apply + // CHECK-NEXT: select + // CHECK-DAG: cmpf "olt" + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.apply + // CHECK-NEXT: select + // CHECK-DAG: cmpf "ogt" + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.apply + // CHECK-NEXT: select // clang-format on f.print(llvm::outs()); f.erase();