diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -12,7 +12,9 @@
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -20,6 +22,7 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
@@ -2021,6 +2024,162 @@
   }
 };
 
+struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }
+
+  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
+                                  OpFoldResult ofr) {
+    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
+    auto two = builder.create<arith::ConstantIndexOp>(loc, 2);
+
+    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
+    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
+    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
+    return getAsOpFoldResult(plusOne);
+  }
+
+  static RankedTensorType
+  computeOutputShape(OpBuilder &builder, Location loc, Value input,
+                     llvm::SmallVectorImpl<Value> &dynamicSizes) {
+    // Get [N, H, W]
+    auto dims = linalg::getMixedDimensions(builder, loc, input);
+
+    // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
+    // output tensors.
+    dims[2] = halfPlusOne(builder, loc, dims[2]);
+
+    llvm::SmallVector<int64_t> staticSizes;
+    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
+
+    auto elementType =
+        input.getType().cast<RankedTensorType>().getElementType();
+    return RankedTensorType::get(staticSizes, elementType);
+  }
+
+  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
+                                RankedTensorType type,
+                                llvm::ArrayRef<Value> dynamicSizes) {
+    auto emptyTensor =
+        rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
+    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
+    auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
+    auto filledTensor = rewriter
+                            .create<linalg::FillOp>(loc, ValueRange{fillValue},
+                                                    ValueRange{emptyTensor})
+                            .result();
+    return filledTensor;
+  }
+
+  static Value castIndexToFloat(OpBuilder &builder, Location loc,
+                                FloatType type, Value value) {
+    auto integerVal =
+        builder.create<arith::IndexCastUIOp>(loc, builder.getI64Type(), value);
+
+    return builder.create<arith::UIToFPOp>(loc, type, integerVal);
+  }
+
+  static Value createLinalgIndex(OpBuilder &builder, Location loc,
+                                 FloatType type, int64_t index) {
+    auto indexVal = builder.create<linalg::IndexOp>(loc, index);
+    return castIndexToFloat(builder, loc, type, indexVal);
+  }
+
+  template <typename... Args>
+  static llvm::SmallVector<AffineExpr> affineDimsExpr(OpBuilder &builder,
+                                                      Args... args) {
+    return {builder.getAffineDimExpr(args)...};
+  }
+
+  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
+                                PatternRewriter &rewriter) const override {
+    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
+        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
+      return rewriter.notifyMatchFailure(rfft2d,
+                                         "only supports ranked tensors");
+    }
+
+    auto loc = rfft2d.getLoc();
+    auto input = rfft2d.getInput();
+    auto elementType =
+        input.getType().cast<RankedTensorType>().getElementType().cast<FloatType>();
+
+    // Compute the output type and set of dynamic sizes
+    llvm::SmallVector<Value> dynamicSizes;
+    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
+
+    // Iterator types for the linalg.generic implementation
+    llvm::SmallVector<utils::IteratorType> iteratorTypes = {
+        utils::IteratorType::parallel, utils::IteratorType::parallel,
+        utils::IteratorType::parallel, utils::IteratorType::reduction,
+        utils::IteratorType::reduction};
+
+    // Inputs/outputs to the linalg.generic implementation
+    llvm::SmallVector<Value> genericOpInputs = {input};
+    llvm::SmallVector<Value> genericOpOutputs = {
+        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
+        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
+
+    // Indexing maps for input and output tensors
+    auto indexingMaps = AffineMap::inferFromExprList(llvm::ArrayRef{
+        affineDimsExpr(rewriter, 0, 3, 4), affineDimsExpr(rewriter, 0, 1, 2),
+        affineDimsExpr(rewriter, 0, 1, 2)});
+
+    // Width and height dimensions of the original input.
+    auto dimH = linalg::createOrFoldDimOp(rewriter, loc, input, 1);
+    auto dimW = linalg::createOrFoldDimOp(rewriter, loc, input, 2);
+
+    // Constants and dimension sizes
+    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
+    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
+    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
+    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
+
+    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
+      Value valReal = args[0];
+      Value sumReal = args[1];
+      Value sumImag = args[2];
+
+      // Indices for angle computation
+      auto oy = createLinalgIndex(builder, loc, elementType, 1);
+      auto ox = createLinalgIndex(builder, loc, elementType, 2);
+      auto iy = createLinalgIndex(builder, loc, elementType, 3);
+      auto ix = createLinalgIndex(builder, loc, elementType, 4);
+
+      // angle = 2 * pi() * ((iy * oy) / H + (ix * ox) / W)
+      auto iyXoy = builder.create<arith::MulFOp>(loc, iy, oy);
+      auto ixXox = builder.create<arith::MulFOp>(loc, ix, ox);
+      auto yComponent = builder.create<arith::DivFOp>(loc, iyXoy, constH);
+      auto xComponent = builder.create<arith::DivFOp>(loc, ixXox, constW);
+      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
+      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
+
+      // realComponent = valReal * cos(angle)
+      // imagComponent = valReal * sin(angle)
+      auto cosAngle = builder.create<math::CosOp>(loc, angle);
+      auto sinAngle = builder.create<math::SinOp>(loc, angle);
+      auto realComponent =
+          builder.create<arith::MulFOp>(loc, valReal, cosAngle);
+      auto imagComponent =
+          builder.create<arith::MulFOp>(loc, valReal, sinAngle);
+
+      // outReal = sumReal + realComponent
+      // outImag = sumImag - imagComponent
+      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
+      auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);
+
+      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
+    };
+
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
+        indexingMaps, iteratorTypes, buildBody);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgConversionPatterns(
@@ -2083,6 +2242,7 @@
       GatherConverter,
       RescaleConverter,
       ReverseConverter,
+      RFFT2dConverter,
       TableConverter,
       TileConverter,
       TransposeConverter>(patterns->getContext());
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1412,3 +1412,132 @@
   return %0 : tensor<1x12x5x5xf32>
 }
 
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+
+// CHECK-LABEL: @test_static_rfft2d
+// CHECK-SAME: (%[[ARG_0:[0-9a-zA-Z_]*]]:
+func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK: %[[CST_1:.*]] = arith.constant 1 : index
+// CHECK: %[[CST_2:.*]] = arith.constant 2 : index
+// CHECK: %[[CST_8:.*]] = arith.constant 8 : index
+// CHECK: %[[CST_4:.*]] = arith.constant 4 : index
+// CHECK: %[[CST_5:.*]] = arith.constant 5 : index
+// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK: %[[CST_ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAR_1:.*]] = linalg.fill ins(%[[CST_ZERO:.*]] : f32) outs(%[[EMPTY_0:.*]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
+// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK: %[[VAR_3:.*]] = linalg.fill ins(%[[CST_ZERO:.*]]: f32) outs(%[[EMPTY_1:.*]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
+// CHECK: %[[CST_PI:.*]] = arith.constant 6.28318548 : f32
+// CHECK: %[[VAR_5:.*]] = arith.index_castui %[[CST_5:.*]] : index to i64
+// CHECK: %[[VAR_6:.*]] = arith.uitofp %[[VAR_5:.*]] : i64 to f32
+// CHECK: %[[VAR_7:.*]] = arith.index_castui %[[CST_8:.*]] : index to i64
+// CHECK: %[[VAR_8:.*]] = arith.uitofp %[[VAR_7:.*]] : i64 to f32
+// CHECK: linalg.generic {
+// CHECK: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
+// CHECK: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
+// CHECK: ins(%[[ARG_0]] : tensor<5x5x8xf32>)
+// CHECK: outs(%[[VAR_1]], %[[VAR_3]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
+// CHECK: %[[INDEX_1:.*]] = linalg.index 1 : index
+// CHECK: %[[VAR_10:.*]] = arith.index_castui %[[INDEX_1]] : index to i64
+// CHECK: %[[VAR_11:.*]] = arith.uitofp %[[VAR_10]] : i64 to f32
+// CHECK: %[[INDEX_2:.*]] = linalg.index 2 : index
+// CHECK: %[[VAR_13:.*]] = arith.index_castui %[[INDEX_2]] : index to i64
+// CHECK: %[[VAR_14:.*]] = arith.uitofp %[[VAR_13]] : i64 to f32
+// CHECK: %[[INDEX_3:.*]] = linalg.index 3 : index
+// CHECK: %[[VAR_16:.*]] = arith.index_castui %[[INDEX_3]] : index to i64
+// CHECK: %[[VAR_17:.*]] = arith.uitofp %[[VAR_16]] : i64 to f32
+// CHECK: %[[INDEX_4:.*]] = linalg.index 4 : index
+// CHECK: %[[VAR_19:.*]] = arith.index_castui %[[INDEX_4]] : index to i64
+// CHECK: %[[VAR_20:.*]] = arith.uitofp %[[VAR_19]] : i64 to f32
+// CHECK: %[[VAR_21:.*]] = arith.mulf %[[VAR_17]], %[[VAR_11]] : f32
+// CHECK: %[[VAR_22:.*]] = arith.mulf %[[VAR_20]], %[[VAR_14]] : f32
+// CHECK: %[[XCOMP:.*]] = arith.divf %[[VAR_21]], %[[VAR_6]] : f32
+// CHECK: %[[YCOMP:.*]] = arith.divf %[[VAR_22]], %[[VAR_8]] : f32
+// CHECK: %[[VAR_25:.*]] = arith.addf %[[XCOMP]], %[[YCOMP]] : f32
+// CHECK: %[[ALPHA:.*]] = arith.mulf %[[CST_PI]], %[[VAR_25]] : f32
+// CHECK: %[[COS_ALPHA:.*]] = math.cos %[[ALPHA]] : f32
+// CHECK: %[[SIN_ALPHA:.*]] = math.sin %[[ALPHA]] : f32
+// CHECK: %[[REAL_CONTRIB:.*]] = arith.mulf %[[IN]], %[[COS_ALPHA]] : f32
+// CHECK: %[[IMAG_CONTRIB:.*]] = arith.mulf %[[IN]], %[[SIN_ALPHA]] : f32
+// CHECK: %[[OUT_REAL:.*]] = arith.addf %[[OUT_0]], %[[REAL_CONTRIB]] : f32
+// CHECK: %[[OUT_IMAG:.*]] = arith.subf %[[OUT_1]], %[[IMAG_CONTRIB]] : f32
+// CHECK: linalg.yield %[[OUT_REAL]], %[[OUT_IMAG]] : f32, f32
+// CHECK: } -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
+
+  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
+  return %output_real, %output_imag : tensor<5x5x5xf32>, tensor<5x5x5xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+
+// CHECK-LABEL: @test_dynamic_rfft2d
+// CHECK-SAME: (%[[ARG_0:[0-9a-zA-Z_]*]]:
+func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+// CHECK: %[[CST_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG_0]], %[[CST_0]] : tensor<?x?x?xf32>
+// CHECK: %[[CST_1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG_0]], %[[CST_1]] : tensor<?x?x?xf32>
+// CHECK: %[[CST_2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[CST_2]] : tensor<?x?x?xf32>
+// CHECK: %[[CST_1_2:.*]] = arith.constant 1 : index
+// CHECK: %[[CST_2_3:.*]] = arith.constant 2 : index
+// CHECK: %[[VAR_0:.*]] = arith.divui %[[DIM_1]], %[[CST_2_3]] : index
+// CHECK: %[[VAR_1:.*]] = arith.addi %[[VAR_0]], %[[CST_1_2]] : index
+// CHECK: %[[EMPTY_0:.*]] = tensor.empty(%[[DIM]], %[[DIM_0]], %[[VAR_1]]) : tensor<?x?x?xf32>
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAR_3:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY_0]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[EMPTY_1:.*]] = tensor.empty(%[[DIM]], %[[DIM_0]], %[[VAR_1]]) : tensor<?x?x?xf32>
+// CHECK: %[[CST_4:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAR_5:.*]] = linalg.fill ins(%[[CST_4]] : f32) outs(%[[EMPTY_1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[CST_1_5:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_6:.*]] = tensor.dim %[[ARG_0]], %[[CST_1_5]] : tensor<?x?x?xf32>
+// CHECK: %[[CST_2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM_8:.*]] = tensor.dim %[[ARG_0]], %[[CST_2]] : tensor<?x?x?xf32>
+// CHECK: %[[CST_9:.*]] = arith.constant 6.28318548 : f32
+// CHECK: %[[VAR_6:.*]] = arith.index_castui %[[DIM_6]] : index to i64
+// CHECK: %[[VAR_7:.*]] = arith.uitofp %[[VAR_6]] : i64 to f32
+// CHECK: %[[VAR_8:.*]] = arith.index_castui %[[DIM_8]] : index to i64
+// CHECK: %[[VAR_9:.*]] = arith.uitofp %[[VAR_8]] : i64 to f32
+// CHECK: linalg.generic {
+// CHECK: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
+// CHECK: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
+// CHECK: ins(%[[ARG_0]] : tensor<?x?x?xf32>)
+// CHECK: outs(%[[VAR_3]], %[[VAR_5]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
+// CHECK: %[[INDEX_1:.*]] = linalg.index 1 : index
+// CHECK: %[[VAR_12:.*]] = arith.index_castui %[[INDEX_1]] : index to i64
+// CHECK: %[[VAR_13:.*]] = arith.uitofp %[[VAR_12]] : i64 to f32
+// CHECK: %[[INDEX_2:.*]] = linalg.index 2 : index
+// CHECK: %[[VAR_15:.*]] = arith.index_castui %[[INDEX_2]] : index to i64
+// CHECK: %[[VAR_16:.*]] = arith.uitofp %[[VAR_15]] : i64 to f32
+// CHECK: %[[INDEX_3:.*]] = linalg.index 3 : index
+// CHECK: %[[VAR_18:.*]] = arith.index_castui %[[INDEX_3]] : index to i64
+// CHECK: %[[VAR_19:.*]] = arith.uitofp %[[VAR_18]] : i64 to f32
+// CHECK: %[[INDEX_4:.*]] = linalg.index 4 : index
+// CHECK: %[[VAR_21:.*]] = arith.index_castui %[[INDEX_4]] : index to i64
+// CHECK: %[[VAR_22:.*]] = arith.uitofp %[[VAR_21]] : i64 to f32
+// CHECK: %[[VAR_23:.*]] = arith.mulf %[[VAR_19]], %[[VAR_13]] : f32
+// CHECK: %[[VAR_24:.*]] = arith.mulf %[[VAR_22]], %[[VAR_16]] : f32
+// CHECK: %[[XCOMP:.*]] = arith.divf %[[VAR_23]], %[[VAR_7]] : f32
+// CHECK: %[[YCOMP:.*]] = arith.divf %[[VAR_24]], %[[VAR_9]] : f32
+// CHECK: %[[VAR_27:.*]] = arith.addf %[[XCOMP]], %[[YCOMP]] : f32
+// CHECK: %[[ALPHA:.*]] = arith.mulf %[[CST_9]], %[[VAR_27]] : f32
+// CHECK: %[[COS_ALPHA:.*]] = math.cos %[[ALPHA]] : f32
+// CHECK: %[[SIN_ALPHA:.*]] = math.sin %[[ALPHA]] : f32
+// CHECK: %[[REAL_CONTRIB:.*]] = arith.mulf %[[IN]], %[[COS_ALPHA]] : f32
+// CHECK: %[[IMAG_CONTRIB:.*]] = arith.mulf %[[IN]], %[[SIN_ALPHA]] : f32
+// CHECK: %[[OUT_REAL:.*]] = arith.addf %[[OUT_0]], %[[REAL_CONTRIB]] : f32
+// CHECK: %[[OUT_IMAG:.*]] = arith.subf %[[OUT_1]], %[[IMAG_CONTRIB]] : f32
+// CHECK: linalg.yield %[[OUT_REAL]], %[[OUT_IMAG]] : f32, f32
+// CHECK: } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+
+  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return %output_real, %output_imag : tensor<?x?x?xf32>, tensor<?x?x?xf32>
+}
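
The two new cases can be exercised locally the same way as the rest of
tosa-to-linalg.mlir. The file's RUN line sits outside this hunk, so the exact
invocation below is a sketch based on the in-tree convention for these tests,
not a line taken from the patch:

  mlir-opt --split-input-file \
      -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" \
      mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir \
    | FileCheck mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir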