diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -16,6 +16,8 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+std::unique_ptr<OperationPass<FuncOp>> createConvertElementwiseToLinalgPass();
+
 std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass();
 
 std::unique_ptr<Pass> createLinalgFusionOfTensorOpsPass();
@@ -48,6 +50,11 @@
 /// buffers instead.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass();
 
+/// Populate patterns that convert `ElementwiseMappable` ops to linalg
+/// parallel loops.
+void populateElementwiseToLinalgConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
 /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
 /// producer (consumer) generic operation by expanding the dimensionality of the
 /// loop in the generic op.
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -11,6 +11,17 @@
 
 include "mlir/Pass/PassBase.td"
 
+def ConvertElementwiseToLinalg : FunctionPass<"convert-elementwise-to-linalg"> {
+  let summary = "Convert ElementwiseMappable ops to linalg";
+  let description = [{
+    Convert ops with the `ElementwiseMappable` trait to linalg parallel loops.
+
+    This pass only converts ops that operate on ranked tensors.
+  }];
+  let constructor = "mlir::createConvertElementwiseToLinalgPass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
   let summary = "Remove unit-extent dimension in Linalg ops on tensors";
   let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @main() {
+  %a = constant dense<[1.0, 2.0, 3.0]> : tensor<3xf32>
+  %b = constant dense<[10.0, 20.0, 30.0]> : tensor<3xf32>
+
+  %addf = addf %a, %b : tensor<3xf32>
+  %addf_unranked = tensor_cast %addf : tensor<3xf32> to tensor<*xf32>
+  call @print_memref_f32(%addf_unranked) : (tensor<*xf32>) -> ()
+  // CHECK: Unranked Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [3] strides = [1] data =
+  // CHECK-NEXT: [11, 22, 33]
+
+  return
+}
+
+func @print_memref_f32(%ptr : tensor<*xf32>)
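
A sketch of what the integration test's `addf` lowers to immediately after
-convert-elementwise-to-linalg, before the bufferization passes run. The
generic-op form mirrors the CHECK lines in the lit test further below; the
value names come from @main above, but the exact form is illustrative, not
verbatim pass output:

  #map = affine_map<(d0) -> (d0)>
  // One identity indexing map per operand and result, one "parallel"
  // iterator per dimension; the scalar addf is recreated in the body.
  %addf = linalg.generic {indexing_maps = [#map, #map, #map],
                          iterator_types = ["parallel"]}
      ins(%a, %b : tensor<3xf32>, tensor<3xf32>) {
  ^bb0(%lhs: f32, %rhs: f32):
    %sum = addf %lhs, %rhs : f32
    linalg.yield %sum : f32
  } -> tensor<3xf32>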
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
   Bufferize.cpp
   CodegenStrategy.cpp
   DropUnitDims.cpp
+  ElementwiseToLinalg.cpp
   Fusion.cpp
   FusionOnTensors.cpp
   Hoisting.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -0,0 +1,98 @@
+//===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
+  if (!op->hasTrait<OpTrait::ElementwiseMappable>())
+    return false;
+
+  // TODO: The conversion pattern can be made to work for `any_of` here, but
+  // it's more complex as it requires tracking which operands are scalars.
+  return llvm::all_of(op->getOperandTypes(),
+                      [](Type type) { return type.isa<RankedTensorType>(); });
+}
+
+namespace {
+struct ConvertStdElementwiseOpOnRankedTensors : public RewritePattern {
+  ConvertStdElementwiseOpOnRankedTensors()
+      : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const final {
+    if (!isElementwiseMappableOpOnRankedTensors(op))
+      return rewriter.notifyMatchFailure(
+          op, "requires elementwise op on ranked tensors");
+
+    auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank();
+    SmallVector<AffineMap, 3> indexingMaps(
+        op->getNumResults() + op->getNumOperands(),
+        rewriter.getMultiDimIdentityMap(rank));
+    SmallVector<StringRef, 6> iteratorTypes(rank,
+                                            getParallelIteratorTypeName());
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        op, /*resultTensorTypes=*/op->getResultTypes(),
+        /*inputs=*/op->getOperands(),
+        /*outputBuffers=*/ValueRange(),
+        /*initTensors=*/ValueRange(),
+        /*indexingMaps=*/indexingMaps,
+        /*iteratorTypes=*/iteratorTypes,
+        /*bodyBuilder=*/
+        [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
+          OperationState state(loc, op->getName());
+          state.addAttributes(op->getAttrs());
+          state.addOperands(regionArgs);
+          auto resultTypes = llvm::to_vector<6>(
+              llvm::map_range(op->getResultTypes(), [](Type type) {
+                return type.cast<TensorType>().getElementType();
+              }));
+          state.addTypes(resultTypes);
+          auto *scalarOp = builder.createOperation(state);
+          builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
+        });
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateElementwiseToLinalgConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *) {
+  patterns.insert<ConvertStdElementwiseOpOnRankedTensors>();
+}
+
+namespace {
+class ConvertElementwiseToLinalgPass
+    : public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> {
+
+  void runOnFunction() final {
+    auto func = getOperation();
+    auto *context = &getContext();
+    ConversionTarget target(*context);
+    OwningRewritePatternList patterns;
+
+    populateElementwiseToLinalgConversionPatterns(patterns, context);
+    target.markUnknownOpDynamicallyLegal([](Operation *op) {
+      return !isElementwiseMappableOpOnRankedTensors(op);
+    });
+
+    if (failed(applyPartialConversion(func, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createConvertElementwiseToLinalgPass() {
+  return std::make_unique<ConvertElementwiseToLinalgPass>();
+}
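
Legality in this pass is dynamic: any op that is not an ElementwiseMappable
op on ranked tensors stays legal, so the partial conversion leaves it
untouched. A hypothetical sketch (the %t/%s/%u values are made-up names):

  // Rewritten: ElementwiseMappable, and every operand is a ranked tensor.
  %0 = addf %t0, %t1 : tensor<4xf32>
  // Left as-is: scalar operands are not ranked tensors.
  %1 = addf %s0, %s1 : f32
  // Left as-is: unranked tensors fail the llvm::all_of ranked-tensor check.
  %2 = addf %u0, %u1 : tensor<*xf32>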
diff --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt -convert-elementwise-to-linalg -split-input-file %s | FileCheck %s
+
+// In-depth checking of the linalg.generic op for a very trivial case.
+// CHECK: #map = affine_map<() -> ()>
+// CHECK-LABEL: func @addf_rank0
+func @addf_rank0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+  // CHECK: %{{.*}} = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%{{.*}}, %{{.*}} : tensor<f32>, tensor<f32>) {
+  // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+  // CHECK: %[[YIELD:.*]] = addf %[[LHS]], %[[RHS]] : f32
+  // CHECK: linalg.yield %[[YIELD]] : f32
+  // CHECK: } -> tensor<f32>
+  %0 = addf %arg0, %arg1 : tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// Check indexing maps and iterator types for the rank > 0 case.
+// CHECK: #map = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @addf_rank1
+func @addf_rank1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK: linalg.generic{{.*}}indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]
+  %0 = addf %arg0, %arg1 : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Check a unary op.
+// CHECK-LABEL: func @exp
+func @exp(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[SCALAR:.*]]: f32):
+  // CHECK: %[[YIELD:.*]] = exp %[[SCALAR]] : f32
+  // CHECK: linalg.yield %[[YIELD]] : f32
+  %0 = exp %arg0 : tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// Check a case with varying operand types.
+// CHECK-LABEL: func @select
+func @select(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[PRED:.*]]: i1, %[[TRUE_VAL:.*]]: i32, %[[FALSE_VAL:.*]]: i32):
+  // CHECK: select %[[PRED]], %[[TRUE_VAL]], %[[FALSE_VAL]] : i32
+  %0 = select %arg0, %arg1, %arg2 : tensor<i1>, tensor<i32>
+  return %0 : tensor<i32>
+}
+
+// -----
+
+// Spot-check an op that requires copying attributes properly to the created scalar op.
+// CHECK-LABEL: func @cmpf(
+func @cmpf(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<i1> {
+  // CHECK: cmpf "olt", %{{.*}}, %{{.*}} : f32
+  %0 = cmpf "olt", %arg0, %arg1 : tensor<f32>
+  return %0 : tensor<i1>
+}
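
For the varying-operand-type case, the block arguments follow each operand's
element type while every indexing map stays the identity. A sketch of what
@select above should convert to, extrapolated from the pattern and the CHECK
lines rather than taken from verbatim mlir-opt output:

  #map = affine_map<() -> ()>
  func @select(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
    // Four identity maps: one per operand plus one for the result.
    %0 = linalg.generic {indexing_maps = [#map, #map, #map, #map],
                         iterator_types = []}
        ins(%arg0, %arg1, %arg2 : tensor<i1>, tensor<i32>, tensor<i32>) {
    ^bb0(%pred: i1, %true_val: i32, %false_val: i32):
      %1 = select %pred, %true_val, %false_val : i32
      linalg.yield %1 : i32
    } -> tensor<i32>
    return %0 : tensor<i32>
  }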