diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_TOSA_IR_TOSAOPS_H
 
 #include "mlir/Dialect/Traits.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -270,6 +270,34 @@
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: rfft2d
+//===----------------------------------------------------------------------===//
+def Tosa_RFFT2dOp : Tosa_Op<"rfft2d", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    Pure]> {
+  let summary = "Performs RFFT2D operation on the input.";
+
+  let description = [{
+    Performs a batched 2D real-valued Fast Fourier Transform over the input where
+    the input tensor consists of real values producing complex valued output. The
+    complex output values will be split into the output_real and output_imag
+    tensor arguments. RFFT2D takes advantage of Hermitian symmetry to only
+    calculate the first half of the final output axis. Imaginary values with
+    locations (0,0), (0,W/2), (H/2,0) and (H/2,W/2) are zero.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor3D:$input
+  );
+
+  let results = (outs
+    Tosa_Tensor3D:$output_real,
+    Tosa_Tensor3D:$output_imag
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: transpose_conv2d
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -387,6 +387,36 @@
   return success();
 }
 
+/// Infers the result shapes for RFFT2D. For an input of shape [N, H, W],
+/// both `output_real` and `output_imag` get shape [N, H, W / 2 + 1].
+/// Dynamic input dimensions stay dynamic in the results; an unranked input
+/// fails inference.
+LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ShapeAdaptor inputShape = operands.getShape(0);
+
+  if (!inputShape.hasRank())
+    return failure();
+
+  llvm::SmallVector<int64_t> outputShape;
+  outputShape.resize(3, ShapedType::kDynamic);
+  outputShape[0] = inputShape.getDimSize(0);
+  outputShape[1] = inputShape.getDimSize(1);
+  int64_t inWidth = inputShape.getDimSize(2);
+
+  // Hermitian symmetry means only the first W / 2 + 1 columns of the last
+  // axis are produced. Note that we can support this calculation
+  // symbolically in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
+  if (inWidth != ShapedType::kDynamic)
+    outputShape[2] = inWidth / 2 + 1;
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
+
 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -72,6 +72,13 @@
   return %0 : tensor<1x32x32x8xf32>
 }
 
+// -----
+// CHECK-LABEL: rfft2d
+func.func @test_rfft2d(%arg0: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {
+  %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>)
+  return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32>
+}
+
 // -----
 // CHECK-LABEL: transpose_conv2d
 func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1190,2 +1190,29 @@
   return
 }
+
+// -----
+
+// CHECK-LABEL: @test_static_rfft2d
+func.func @test_static_rfft2d(%arg0: tensor<5x2x8xf32>) -> () {
+  // CHECK: -> (tensor<5x2x5xf32>, tensor<5x2x5xf32>)
+  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x2x8xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_dynamic_batch_rfft2d
+func.func @test_dynamic_batch_rfft2d(%arg0 : tensor<?x2x4xf32>) -> () {
+  // CHECK: -> (tensor<?x2x3xf32>, tensor<?x2x3xf32>)
+  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<?x2x4xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_dynamic_width_rfft2d
+func.func @test_dynamic_width_rfft2d(%arg0 : tensor<5x2x?xf32>) -> () {
+  // CHECK: -> (tensor<5x2x?xf32>, tensor<5x2x?xf32>)
+  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x2x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return
+}