diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -27,6 +27,7 @@
 #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Conversion/StandardToLinalg/StandardToLinalg.h"
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -352,6 +352,22 @@
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// StandardToLinalg
+//===----------------------------------------------------------------------===//
+
+def ConvertStandardToLinalg : Pass<"convert-std-to-linalg", "FuncOp"> {
+  let summary = "Convert std ops to linalg";
+  let description = [{
+    Convert ops from the std dialect to linalg.
+
+    Currently, this pass only supports std elementwise ops operating on ranked
+    tensors and converts them to linalg.generic ops with parallel iterators.
+  }];
+  let constructor = "mlir::createConvertStandardToLinalgPass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // StandardToSPIRV
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/StandardToLinalg/StandardToLinalg.h b/mlir/include/mlir/Conversion/StandardToLinalg/StandardToLinalg.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/StandardToLinalg/StandardToLinalg.h
@@ -0,0 +1,27 @@
+//===- StandardToLinalg.h - Conversion utils from std to linalg ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_STANDARDTOLINALG_STANDARDTOLINALG_H_
+#define MLIR_CONVERSION_STANDARDTOLINALG_STANDARDTOLINALG_H_
+
+#include <memory>
+
+namespace mlir {
+
+class MLIRContext;
+class OwningRewritePatternList;
+class Pass;
+
+void populateStandardToLinalgConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+std::unique_ptr<Pass> createConvertStandardToLinalgPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_STANDARDTOLINALG_STANDARDTOLINALG_H_
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-std-elementwise.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-std-elementwise.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-std-elementwise.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -convert-std-to-linalg -linalg-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @main() {
+  %a = constant dense<[1.0, 2.0, 3.0]> : tensor<3xf32>
+  %b = constant dense<[10.0, 20.0, 30.0]> : tensor<3xf32>
+
+  %addf = addf %a, %b : tensor<3xf32>
+  %addf_unranked = tensor_cast %addf : tensor<3xf32> to tensor<*xf32>
+  call @print_memref_f32(%addf_unranked) : (tensor<*xf32>) -> ()
+  // CHECK: Unranked Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [3] strides = [1] data =
+  // CHECK-NEXT: [11, 22, 33]
+
+  return
+}
+
+func @print_memref_f32(%ptr : tensor<*xf32>)
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -17,6 +17,7 @@
 add_subdirectory(ShapeToStandard)
 add_subdirectory(SPIRVToLLVM)
 add_subdirectory(StandardToLLVM)
+add_subdirectory(StandardToLinalg)
 add_subdirectory(StandardToSPIRV)
 add_subdirectory(VectorToROCDL)
 add_subdirectory(VectorToLLVM)
diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -33,6 +33,10 @@
 class NVVMDialect;
 } // end namespace NVVM
 
+namespace linalg {
+class LinalgDialect;
+} // end namespace linalg
+
 namespace pdl_interp {
 class PDLInterpDialect;
 } // end namespace pdl_interp
diff --git a/mlir/lib/Conversion/StandardToLinalg/CMakeLists.txt b/mlir/lib/Conversion/StandardToLinalg/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/StandardToLinalg/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_conversion_library(MLIRStandardToLinalg
+  StandardToLinalg.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/StandardToLinalg
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLinalg
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/StandardToLinalg/StandardToLinalg.cpp b/mlir/lib/Conversion/StandardToLinalg/StandardToLinalg.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/StandardToLinalg/StandardToLinalg.cpp
@@ -0,0 +1,150 @@
+//===- StandardToLinalg.cpp - conversion from Std to Linalg dialect ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/StandardToLinalg/StandardToLinalg.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+static bool isStdElementwiseOpOnRankedTensors(Operation *op) {
+  if (!llvm::all_of(op->getResultTypes(),
+                    [](Type type) { return type.isa<RankedTensorType>(); })) {
+    return false;
+  }
+  if (!llvm::all_of(op->getOperandTypes(),
+                    [](Type type) { return type.isa<RankedTensorType>(); })) {
+    return false;
+  }
+  // TODO: Have an "elementwise" trait that supersedes this.
+  // clang-format off
+  return isa<
+      AbsFOp,
+      AddFOp,
+      AddIOp,
+      AndOp,
+      AtanOp,
+      Atan2Op,
+      CeilFOp,
+      FloorFOp,
+      CmpFOp,
+      CmpIOp,
+      CopySignOp,
+      CosOp,
+      SinOp,
+      DivFOp,
+      ExpOp,
+      Exp2Op,
+      FPExtOp,
+      FPToSIOp,
+      FPToUIOp,
+      FPTruncOp,
+      IndexCastOp,
+      LogOp,
+      Log10Op,
+      Log2Op,
+      MulFOp,
+      MulIOp,
+      NegFOp,
+      OrOp,
+      RemFOp,
+      RsqrtOp,
+      SelectOp,
+      ShiftLeftOp,
+      SignedDivIOp,
+      SignedRemIOp,
+      SignedShiftRightOp,
+      SignExtendIOp,
+      SIToFPOp,
+      SqrtOp,
+      SubFOp,
+      SubIOp,
+      TanhOp,
+      TruncateIOp,
+      UIToFPOp,
+      UnsignedDivIOp,
+      UnsignedRemIOp,
+      UnsignedShiftRightOp,
+      XOrOp,
+      ZeroExtendIOp
+  >(op);
+  // clang-format on
+}
+
+namespace {
+struct ConvertStdElementwiseOpOnRankedTensors : public RewritePattern {
+  ConvertStdElementwiseOpOnRankedTensors()
+      : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const final {
+    if (!isStdElementwiseOpOnRankedTensors(op))
+      return rewriter.notifyMatchFailure(op, "requires std elementwise op");
+
+    auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank();
+    SmallVector<AffineMap, 3> indexingMaps(
+        op->getNumResults() + op->getNumOperands(),
+        rewriter.getMultiDimIdentityMap(rank));
+    SmallVector<StringRef, 3> iteratorTypes(rank,
+                                            getParallelIteratorTypeName());
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        op, /*resultTensorTypes=*/op->getResultTypes(),
+        /*inputs=*/op->getOperands(),
+        /*outputBuffers=*/ValueRange(),
+        /*initTensors=*/ValueRange(),
+        /*indexingMaps=*/indexingMaps,
+        /*iteratorTypes=*/iteratorTypes,
+        /*bodyBuilder=*/
+        [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
+          // Re-create the op in scalar form on the region's block arguments.
+          OperationState state(loc, op->getName());
+          state.addAttributes(op->getAttrs());
+          state.addOperands(regionArgs);
+          auto resultTypes = llvm::to_vector<6>(
+              llvm::map_range(op->getResultTypes(), [](Type type) {
+                return type.cast<TensorType>().getElementType();
+              }));
+          state.addTypes(resultTypes);
+          auto *scalarOp = builder.createOperation(state);
+          builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
+        });
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateStandardToLinalgConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *) {
+  patterns.insert<ConvertStdElementwiseOpOnRankedTensors>();
+}
+
+namespace {
+class ConvertStandardToLinalgPass
+    : public ConvertStandardToLinalgBase<ConvertStandardToLinalgPass> {
+
+  void runOnOperation() final {
+    auto func = getOperation();
+    auto *context = &getContext();
+    ConversionTarget target(*context);
+    OwningRewritePatternList patterns;
+
+    populateStandardToLinalgConversionPatterns(patterns, context);
+    target.markUnknownOpDynamicallyLegal(
+        [](Operation *op) { return !isStdElementwiseOpOnRankedTensors(op); });
+
+    if (failed(applyPartialConversion(func, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createConvertStandardToLinalgPass() {
+  return std::make_unique<ConvertStandardToLinalgPass>();
+}
diff --git a/mlir/test/Conversion/StandardToLinalg/std-to-linalg.mlir b/mlir/test/Conversion/StandardToLinalg/std-to-linalg.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/StandardToLinalg/std-to-linalg.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt -convert-std-to-linalg -split-input-file %s | FileCheck %s
+
+// In-depth checking of the linalg.generic op for a very trivial case.
+// CHECK-LABEL: #map0 = affine_map<() -> ()>
+// CHECK-LABEL: func @addf_rank0
+func @addf_rank0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+  // CHECK: %{{.*}} = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = []} ins(%{{.*}}, %{{.*}} : tensor<f32>, tensor<f32>) {
+  // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+  // CHECK:   %[[YIELD:.*]] = addf %[[LHS]], %[[RHS]] : f32
+  // CHECK:   linalg.yield %[[YIELD]] : f32
+  // CHECK: } -> tensor<f32>
+  %0 = addf %arg0, %arg1 : tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// Check indexing maps and iterator types for the rank > 0 case.
+// CHECK-LABEL: #map0 = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @addf_rank1
+func @addf_rank1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK: linalg.generic{{.*}}indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]
+  %0 = addf %arg0, %arg1 : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Check a unary op.
+// CHECK-LABEL: func @exp
+func @exp(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[SCALAR:.*]]: f32):
+  // CHECK:   %[[YIELD:.*]] = exp %[[SCALAR]] : f32
+  // CHECK:   linalg.yield %[[YIELD]] : f32
+  %0 = exp %arg0 : tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// Check a case with varying operand types.
+// CHECK-LABEL: func @select
+func @select(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[PRED:.*]]: i1, %[[TRUE_VAL:.*]]: i32, %[[FALSE_VAL:.*]]: i32):
+  // CHECK:   select %[[PRED]], %[[TRUE_VAL]], %[[FALSE_VAL]] : i32
+  %0 = select %arg0, %arg1, %arg2 : tensor<i1>, tensor<i32>
+  return %0 : tensor<i32>
+}
+
+// -----
+
+// Spot-check an op that requires copying attributes properly to the created scalar op.
+// CHECK-LABEL: func @cmpf(
+func @cmpf(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<i1> {
+  // CHECK: cmpf "olt", %{{.*}}, %{{.*}} : f32
+  %0 = cmpf "olt", %arg0, %arg1 : tensor<f32>
+  return %0 : tensor<i1>
+}
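
For reference, the rank-1 addf case above lowers to IR of roughly the following shape. This is a hand-written sketch assembled from the CHECK lines in std-to-linalg.mlir, not captured pass output; the SSA and block-argument names are illustrative. Note the three identity indexing maps (one per input plus one for the result, matching the numResults + numOperands count in the pattern) and the all-"parallel" iterator list:

  #map0 = affine_map<(d0) -> (d0)>
  func @addf_rank1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
    // One identity map per input and per result; one parallel dim per rank.
    %0 = linalg.generic {indexing_maps = [#map0, #map0, #map0],
                         iterator_types = ["parallel"]}
        ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>) {
    ^bb0(%lhs: f32, %rhs: f32):
      // The original std op, re-created in scalar form on the block args.
      %sum = addf %lhs, %rhs : f32
      linalg.yield %sum : f32
    } -> tensor<?xf32>
    return %0 : tensor<?xf32>
  }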