diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -149,10 +149,29 @@ if (isa(op) && elementTy.isa()) return rewriter.create(loc, resultTypes, args); - // tosa::LogicalrightShiftOp + // tosa::LogicalRightShiftOp if (isa(op) && elementTy.isa()) return rewriter.create(loc, resultTypes, args); + // tosa::LogicalAnd + if (isa(op) && elementTy.isInteger(1)) + return rewriter.create(loc, resultTypes, args); + + // tosa::LogicalNot + if (isa(op) && elementTy.isInteger(1)) { + auto one = rewriter.create( + loc, rewriter.getIntegerAttr(elementTy, 1)); + return rewriter.create(loc, resultTypes, args[0], one); + } + + // tosa::LogicalOr + if (isa(op) && elementTy.isInteger(1)) + return rewriter.create(loc, resultTypes, args); + + // tosa::LogicalXor + if (isa(op) && elementTy.isInteger(1)) + return rewriter.create(loc, resultTypes, args); + // tosa::PowOp if (isa(op) && elementTy.isa()) return rewriter.create(loc, resultTypes, args); @@ -869,6 +888,10 @@ PointwiseConverter, PointwiseConverter, PointwiseConverter, + PointwiseConverter, + PointwiseConverter, + PointwiseConverter, + PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -260,6 +260,30 @@ // ----- +// CHECK-LABEL: @test_bool +func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () { + // CHECK: linalg.generic + // CHECK: and + %0 = "tosa.logical_and"(%arg0, %arg1) : (tensor<1xi1>, tensor<1xi1>) -> tensor<1xi1> + + // CHECK: linalg.generic + // CHECK: or + %1 = "tosa.logical_or"(%arg0, %arg1) : (tensor<1xi1>, tensor<1xi1>) -> tensor<1xi1> + + // CHECK: linalg.generic + // CHECK: xor + %2 = "tosa.logical_xor"(%arg0, %arg1) : (tensor<1xi1>, tensor<1xi1>) -> tensor<1xi1> + + // CHECK: linalg.generic + // CHECK: constant true + // CHECK: xor + %3 = "tosa.logical_not"(%arg0) : (tensor<1xi1>) -> tensor<1xi1> + + return +} + +// ----- + // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: @test_reshape_downrank func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {