[mlir][linalg] Use arith.ori / arith.andi for i1 add / mul in contractions

RegionBuilderHelper::buildBinaryFn now special-cases operands that are both
i1. For those, BinaryFn::add builds arith::OrIOp and BinaryFn::mul builds
arith::AndIOp instead of arith::AddIOp / arith::MulIOp. The new loops.mlir
test describes this as the "saturating" arith operations for booleans.
isContractionInterfaceImpl accepts an (arith::OrIOp, arith::AndIOp) body as a
valid add/mul pair, so such ops are still recognized as contractions.

BinaryFn::sub on i1 operands is not supported and reaches llvm_unreachable.

New tests in loops.mlir check the scalar body generated for linalg.dot:
- i32 operands lower to arith.muli / arith.addi.
- i1 operands lower to arith.andi / arith.ori.

NOTE(review): the getSingleOpOfType hunk in LinalgInterfaces.cpp only
reformats the template header and is unrelated to this change.

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -59,8 +59,7 @@
 
 /// Return the unique instance of OpType in `block` if it is indeed unique.
 /// Return null if none or more than 1 instances exist.
-template <typename OpType>
-static OpType getSingleOpOfType(Block &block) {
+template <typename OpType> static OpType getSingleOpOfType(Block &block) {
   OpType res = nullptr;
   block.walk([&](OpType op) {
     if (res) {
@@ -129,7 +128,9 @@
   // TODO: more fields than add/mul.
   if (!isAddMul<arith::AddFOp, arith::MulFOp>(linalgOp->getRegion(0).front()) &&
       !isAddMul<arith::AddIOp, arith::MulIOp>(linalgOp->getRegion(0).front()) &&
-      !isAddMul<complex::AddOp, complex::MulOp>(linalgOp->getRegion(0).front()))
+      !isAddMul<complex::AddOp, complex::MulOp>(
+          linalgOp->getRegion(0).front()) &&
+      !isAddMul<arith::OrIOp, arith::AndIOp>(linalgOp->getRegion(0).front()))
     return MatchContractionResult::NotAddMul;
   return MatchContractionResult::Success;
 }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -325,6 +325,8 @@
     bool allComplex = isComplex(arg0) && isComplex(arg1);
     bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
     bool allInteger = isInteger(arg0) && isInteger(arg1);
+    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
+                   arg1.getType().getIntOrFloatBitWidth() == 1;
     if (!allComplex && !allFloatingPoint && !allInteger)
       llvm_unreachable("unsupported non numeric type");
     OpBuilder builder = getBuilder();
@@ -334,18 +336,24 @@
         return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
       if (allFloatingPoint)
         return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
+      if (allBool)
+        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
       return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::sub:
       if (allComplex)
         return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
       if (allFloatingPoint)
         return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
+      if (allBool)
+        llvm_unreachable("unsupported operation: sub with bools");
       return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::mul:
       if (allComplex)
         return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
       if (allFloatingPoint)
         return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
+      if (allBool)
+        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
       return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::max_signed:
       assert(!allComplex);
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -137,3 +137,29 @@
 // CHECKPARALLEL: store %[[res]], %[[C]][] : memref<f32>
+
+
+func.func @dot_int(%arg0: memref<?xi32>, %arg1: memref<?xi32>,
+                   %arg3: memref<i32>) {
+  // Verifies that we use the correct arith operations for integers.
+  linalg.dot ins(%arg0, %arg1 : memref<?xi32>, memref<?xi32>)
+             outs(%arg3 : memref<i32>)
+  return
+}
+// CHECK-LABEL: func @dot_int(
+// CHECK: %[[inc:.*]] = arith.muli {{.*}} : i32
+// CHECK-NEXT: %[[res:.*]] = arith.addi {{.*}}, %[[inc]] : i32
+// CHECK-NEXT: store %[[res]], {{.*}} : memref<i32>
+
+
+func.func @dot_bool(%arg0: memref<?xi1>, %arg1: memref<?xi1>,
+                    %arg3: memref<i1>) {
+  // Verifies that we use the correct (saturating) arith operations for booleans.
+  linalg.dot ins(%arg0, %arg1 : memref<?xi1>, memref<?xi1>)
+             outs(%arg3 : memref<i1>)
+  return
+}
+// CHECK-LABEL: func @dot_bool(
+// CHECK: %[[inc:.*]] = arith.andi {{.*}} : i1
+// CHECK-NEXT: %[[res:.*]] = arith.ori {{.*}}, %[[inc]] : i1
+// CHECK-NEXT: store %[[res]], {{.*}} : memref<i1>
 
 