Build the EDSC boolean operators from dedicated logical ops.

operator&& previously multiplied its two i1 operands and operator|| was
expanded as !(!lhs && !rhs). Both now create a single op from their operands,
and operator|| gains the same i1 operand assertions operator&& already had.
Two FileCheck tests cover the emitted `and` / `or` operations.

NOTE(review): the explicit <AndOp> / <OrOp> template arguments were missing
from the collapsed patch text and are restored here to match the `and` / `or`
ops the new tests check for -- confirm against the op class names in the tree.

diff --git a/mlir/lib/Dialect/AffineOps/EDSC/Builders.cpp b/mlir/lib/Dialect/AffineOps/EDSC/Builders.cpp
--- a/mlir/lib/Dialect/AffineOps/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/AffineOps/EDSC/Builders.cpp
@@ -209,11 +209,13 @@
 ValueHandle mlir::edsc::op::operator&&(ValueHandle lhs, ValueHandle rhs) {
   assert(lhs.getType().isInteger(1) && "expected boolean expression on LHS");
   assert(rhs.getType().isInteger(1) && "expected boolean expression on RHS");
-  return lhs * rhs;
+  return ValueHandle::create<AndOp>(lhs, rhs);
 }
 
 ValueHandle mlir::edsc::op::operator||(ValueHandle lhs, ValueHandle rhs) {
-  return !(!lhs && !rhs);
+  assert(lhs.getType().isInteger(1) && "expected boolean expression on LHS");
+  assert(rhs.getType().isInteger(1) && "expected boolean expression on RHS");
+  return ValueHandle::create<OrOp>(lhs, rhs);
 }
 
 static ValueHandle createIComparisonExpr(CmpIPredicate predicate,
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -480,6 +480,44 @@
   f.erase();
 }
 
+TEST_FUNC(operator_or) {
+  auto i1Type = IntegerType::get(/*width=*/1, &globalContext());
+  auto f = makeFunction("operator_or", {}, {i1Type, i1Type});
+
+  OpBuilder builder(f.getBody());
+  ScopedContext scope(builder, f.getLoc());
+
+  using op::operator||;
+  ValueHandle lhs(f.getArgument(0));
+  ValueHandle rhs(f.getArgument(1));
+  lhs || rhs;
+
+  // CHECK-LABEL: @operator_or
+  //       CHECK: [[ARG0:%.*]]: i1, [[ARG1:%.*]]: i1
+  //       CHECK: or [[ARG0]], [[ARG1]]
+  f.print(llvm::outs());
+  f.erase();
+}
+
+TEST_FUNC(operator_and) {
+  auto i1Type = IntegerType::get(/*width=*/1, &globalContext());
+  auto f = makeFunction("operator_and", {}, {i1Type, i1Type});
+
+  OpBuilder builder(f.getBody());
+  ScopedContext scope(builder, f.getLoc());
+
+  using op::operator&&;
+  ValueHandle lhs(f.getArgument(0));
+  ValueHandle rhs(f.getArgument(1));
+  lhs &&rhs;
+
+  // CHECK-LABEL: @operator_and
+  //       CHECK: [[ARG0:%.*]]: i1, [[ARG1:%.*]]: i1
+  //       CHECK: and [[ARG0]], [[ARG1]]
+  f.print(llvm::outs());
+  f.erase();
+}
+
 TEST_FUNC(select_op_i32) {
   using namespace edsc::op;
   auto indexType = IndexType::get(&globalContext());