diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -129,7 +129,9 @@
   // TODO: more fields than add/mul.
   if (!isAddMul<arith::AddFOp, arith::MulFOp>(linalgOp->getRegion(0).front()) &&
       !isAddMul<arith::AddIOp, arith::MulIOp>(linalgOp->getRegion(0).front()) &&
-      !isAddMul<complex::AddOp, complex::MulOp>(linalgOp->getRegion(0).front()))
+      !isAddMul<complex::AddOp, complex::MulOp>(
+          linalgOp->getRegion(0).front()) &&
+      !isAddMul<arith::OrIOp, arith::AndIOp>(linalgOp->getRegion(0).front()))
     return MatchContractionResult::NotAddMul;
   return MatchContractionResult::Success;
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -325,6 +325,9 @@
     bool allComplex = isComplex(arg0) && isComplex(arg1);
     bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
     bool allInteger = isInteger(arg0) && isInteger(arg1);
+    // For i1 operands, add is lowered to `ori` and mul to `andi`.
+    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
+                   arg1.getType().getIntOrFloatBitWidth() == 1;
     if (!allComplex && !allFloatingPoint && !allInteger)
       llvm_unreachable("unsupported non numeric type");
     OpBuilder builder = getBuilder();
@@ -334,18 +337,24 @@
       return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
     if (allFloatingPoint)
       return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
+    if (allBool)
+      return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
     return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
   case BinaryFn::sub:
     if (allComplex)
       return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
     if (allFloatingPoint)
       return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
+    if (allBool)
+      llvm_unreachable("unsupported operation: sub with bools");
     return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
   case BinaryFn::mul:
     if (allComplex)
       return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
     if (allFloatingPoint)
       return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
+    if (allBool)
+      return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
     return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
   case BinaryFn::max_signed:
     assert(!allComplex);