diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -183,6 +183,36 @@
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: fft2d
+//===----------------------------------------------------------------------===//
+def Tosa_FFT2dOp : Tosa_Op<"fft2d", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    Pure]> {
+  let summary = "Performs FFT2D operation on the input.";
+
+  let description = [{
+    Performs a batched complex 2D Fast Fourier Transform over the input. The
+    complex input values are constructed from the corresponding values in the
+    input_real and input_imag tensors. The resulting values in the output are
+    split into the output_real and output_imag tensors. No normalization is
+    applied on either the forward or inverse versions of the operation.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor3D:$input_real,
+    Tosa_Tensor3D:$input_imag,
+
+    BoolAttr:$inverse
+  );
+
+  let results = (outs
+    Tosa_Tensor3D:$output_real,
+    Tosa_Tensor3D:$output_imag
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: fully_connected
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -409,6 +409,18 @@
 
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+
+  return success();
+}
+
+LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  // Each output takes the shape of its matching input: output_real from
+  // input_real (operand 0) and output_imag from input_imag (operand 1).
+  inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(0)));
+  inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(1)));
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -51,6 +51,13 @@
   return %2 : tensor<1x4x4x8xf32>
 }
 
+// -----
+// CHECK-LABEL: fft2d
+func.func @test_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
+  %0, %1 = "tosa.fft2d"(%arg0, %arg1) {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
+  return %0, %1 : tensor<1x4x8xf32>, tensor<1x4x8xf32>
+}
+
 // -----
 // CHECK-LABEL: fully_connected
 func.func @test_fully_connected(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>, %arg2: tensor<28xf32>) -> tensor<14x28xf32> {
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1216,3 +1216,21 @@
   %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x2x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: @test_static_fft2d
+func.func @test_static_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> () {
+  // CHECK: -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
+  %output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_dynamic_batch_fft2d
+func.func @test_dynamic_batch_fft2d(%arg0: tensor<?x4x8xf32>, %arg1: tensor<?x4x8xf32>) -> () {
+  // CHECK: -> (tensor<?x4x8xf32>, tensor<?x4x8xf32>)
+  %output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse = false} : (tensor<?x4x8xf32>, tensor<?x4x8xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return
+}