diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -16,6 +16,8 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+std::unique_ptr<OperationPass<FuncOp>> createConvertElementwiseToLinalgPass();
+
 std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass();
 
 std::unique_ptr<Pass> createLinalgFusionOfTensorOpsPass();
@@ -48,6 +50,11 @@
 /// buffers instead.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass();
 
+/// Populate patterns that convert `ElementwiseMappable` ops to linalg
+/// parallel loops.
+void populateElementwiseToLinalgConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
 /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
 /// producer (consumer) generic operation by expanding the dimensionality of the
 /// loop in the generic op.
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -11,6 +11,17 @@
 
 include "mlir/Pass/PassBase.td"
 
+def ConvertElementwiseToLinalg : FunctionPass<"convert-elementwise-to-linalg"> {
+  let summary = "Convert ElementwiseMappable ops to linalg";
+  let description = [{
+    Convert ops with the `ElementwiseMappable` trait to linalg parallel loops.
+
+    This pass only converts ops that operate on ranked tensors.
+  }];
+  let constructor = "mlir::createConvertElementwiseToLinalgPass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
   let summary = "Remove unit-extent dimension in Linalg ops on tensors";
   let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";
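Note for reviewers: to make the new pass's behavior concrete, here is a rough before/after sketch. It is hand-written, not pass output; the SSA names and the exact `linalg.generic` pretty-printing are illustrative, and the FileCheck test added below is authoritative.

```mlir
// Before: an ElementwiseMappable std op applied to ranked tensors.
func @example(%a: tensor<8xf32>, %b: tensor<8xf32>) -> tensor<8xf32> {
  %sum = addf %a, %b : tensor<8xf32>
  return %sum : tensor<8xf32>
}

// After -convert-elementwise-to-linalg: the op is rewritten into a
// linalg.generic with identity indexing maps and all-"parallel" iterators,
// whose region applies the scalarized op and yields the result.
#map = affine_map<(d0) -> (d0)>
func @example(%a: tensor<8xf32>, %b: tensor<8xf32>) -> tensor<8xf32> {
  %sum = linalg.generic {indexing_maps = [#map, #map, #map],
                         iterator_types = ["parallel"]}
      ins(%a, %b : tensor<8xf32>, tensor<8xf32>) {
  ^bb0(%lhs: f32, %rhs: f32):
    %0 = addf %lhs, %rhs : f32
    linalg.yield %0 : f32
  } -> tensor<8xf32>
  return %sum : tensor<8xf32>
}
```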
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -67,6 +67,11 @@
   let hasFolder = 1;
 }
 
+// Base class for arithmetic cast operations.
+class ArithmeticCastOp<string mnemonic, list<OpTrait> traits = []> :
+    CastOp<mnemonic, !listconcat(traits, [ElementwiseMappable])> {
+}
+
 // Base class for unary ops. Requires single operand and result. Individual
 // classes will have `operand` accessor.
 class UnaryOp<string mnemonic, list<OpTrait> traits = []> :
@@ -88,7 +93,8 @@
 class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
     UnaryOpSameOperandAndResultType<mnemonic,
-      !listconcat(traits, [DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
+      !listconcat(traits, [DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
+                           ElementwiseMappable])>,
     Arguments<(ins FloatLike:$operand)>;
 
@@ -96,7 +102,9 @@
 // Base class for standard arithmetic operations. Requires operands and
 // results to be of the same type, but does not constrain them to specific
 // types. Individual classes will have `lhs` and `rhs` accessor to operands.
 class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
     Op<StandardOps_Dialect, mnemonic,
-       !listconcat(traits, [NoSideEffect, SameOperandsAndResultType])> {
+       !listconcat(traits, [NoSideEffect,
+                            SameOperandsAndResultType,
+                            ElementwiseMappable])> {
 
   let results = (outs AnyType);
@@ -1152,10 +1160,10 @@
 }
 
 def CmpFOp : Std_Op<"cmpf",
-    [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape,
-     TypesMatchWith<
+    [NoSideEffect, SameTypeOperands,
+     SameOperandsAndResultShape, TypesMatchWith<
        "result type has i1 element type and same shape as operands",
-       "lhs", "result", "getI1SameShape($_self)">]> {
+       "lhs", "result", "getI1SameShape($_self)">, ElementwiseMappable]> {
   let summary = "floating-point comparison operation";
   let description = [{
     The `cmpf` operation compares its two operands according to the float
@@ -1236,10 +1244,10 @@
 }
 
 def CmpIOp : Std_Op<"cmpi",
-    [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape,
-     TypesMatchWith<
+    [NoSideEffect, SameTypeOperands,
+     SameOperandsAndResultShape, TypesMatchWith<
        "result type has i1 element type and same shape as operands",
-       "lhs", "result", "getI1SameShape($_self)">]> {
+       "lhs", "result", "getI1SameShape($_self)">, ElementwiseMappable]> {
   let summary = "integer comparison operation";
   let description = [{
     The `cmpi` operation is a generic comparison for integer-like types. Its two
@@ -1926,7 +1934,7 @@
 // FPExtOp
 //===----------------------------------------------------------------------===//
 
-def FPExtOp : CastOp<"fpext">, Arguments<(ins AnyType:$in)> {
+def FPExtOp : ArithmeticCastOp<"fpext">, Arguments<(ins AnyType:$in)> {
   let summary = "cast from floating-point to wider floating-point";
   let description = [{
     Cast a floating-point value to a larger floating-point-typed value.
@@ -1947,7 +1955,7 @@
 // FPToSIOp
 //===----------------------------------------------------------------------===//
 
-def FPToSIOp : CastOp<"fptosi">, Arguments<(ins AnyType:$in)> {
+def FPToSIOp : ArithmeticCastOp<"fptosi">, Arguments<(ins AnyType:$in)> {
   let summary = "cast from floating-point type to integer type";
   let description = [{
     Cast from a value interpreted as floating-point to the nearest (rounding
@@ -1967,7 +1975,7 @@
 // FPToUIOp
 //===----------------------------------------------------------------------===//
 
-def FPToUIOp : CastOp<"fptoui">, Arguments<(ins AnyType:$in)> {
+def FPToUIOp : ArithmeticCastOp<"fptoui">, Arguments<(ins AnyType:$in)> {
   let summary = "cast from floating-point type to integer type";
   let description = [{
     Cast from a value interpreted as floating-point to the nearest (rounding
@@ -1987,7 +1995,7 @@
 // FPTruncOp
 //===----------------------------------------------------------------------===//
 
-def FPTruncOp : CastOp<"fptrunc">, Arguments<(ins AnyType:$in)> {
+def FPTruncOp : ArithmeticCastOp<"fptrunc">, Arguments<(ins AnyType:$in)> {
   let summary = "cast from floating-point to narrower floating-point";
   let description = [{
     Truncate a floating-point value to a smaller floating-point-typed value.
@@ -2039,7 +2047,7 @@
 // IndexCastOp
 //===----------------------------------------------------------------------===//
 
-def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
+def IndexCastOp : ArithmeticCastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
   let summary = "cast between index and integer types";
   let description = [{
     Casts between integer scalars and 'index' scalars. Index is an integer of
@@ -2622,7 +2630,8 @@
 //===----------------------------------------------------------------------===//
 
 def SelectOp : Std_Op<"select", [NoSideEffect,
-    AllTypesMatch<["true_value", "false_value", "result"]>]> {
+    AllTypesMatch<["true_value", "false_value", "result"]>,
+    ElementwiseMappable]> {
   let summary = "select operation";
   let description = [{
     The `select` operation chooses one value based on a binary condition
@@ -2796,7 +2805,7 @@
 //===----------------------------------------------------------------------===//
 
 def SignExtendIOp : Std_Op<"sexti",
-    [NoSideEffect, SameOperandsAndResultShape]> {
+    [NoSideEffect, SameOperandsAndResultShape, ElementwiseMappable]> {
   let summary = "integer sign extension operation";
   let description = [{
     The integer sign extension operation takes an integer input of
@@ -2838,7 +2847,7 @@
 // SIToFPOp
 //===----------------------------------------------------------------------===//
 
-def SIToFPOp : CastOp<"sitofp">, Arguments<(ins AnyType:$in)> {
+def SIToFPOp : ArithmeticCastOp<"sitofp">, Arguments<(ins AnyType:$in)> {
   let summary = "cast from integer type to floating-point";
   let description = [{
     Cast from a value interpreted as signed or vector of signed integers to the
@@ -3637,7 +3646,9 @@
 // TruncateIOp
 //===----------------------------------------------------------------------===//
 
-def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> {
+def TruncateIOp : Std_Op<"trunci", [NoSideEffect,
+                                    SameOperandsAndResultShape,
+                                    ElementwiseMappable]> {
   let summary = "integer truncation operation";
   let description = [{
     The integer truncation operation takes an integer input of
@@ -3677,7 +3688,7 @@
 // UIToFPOp
 //===----------------------------------------------------------------------===//
 
-def UIToFPOp : CastOp<"uitofp">, Arguments<(ins AnyType:$in)> {
+def UIToFPOp : ArithmeticCastOp<"uitofp">, Arguments<(ins AnyType:$in)> {
   let summary = "cast from unsigned integer type to floating-point";
   let description = [{
     Cast from a value interpreted as unsigned integer or vector of unsigned
@@ -3904,7 +3915,9 @@
 // ZeroExtendIOp
 //===----------------------------------------------------------------------===//
 
-def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, SameOperandsAndResultShape]> {
+def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect,
+                                     SameOperandsAndResultShape,
+                                     ElementwiseMappable]> {
   let summary = "integer zero extension operation";
   let description = [{
     The integer zero extension operation takes an integer input of
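The practical effect of tagging the std ops above is that `ElementwiseMappable` now documents and verifies the scalar/vector/tensor interchangeability these ops already supported. A small sketch (hypothetical function; all three forms verify):

```mlir
func @addf_forms(%a: f32, %b: f32,
                 %va: vector<4xf32>, %vb: vector<4xf32>,
                 %ta: tensor<?xf32>, %tb: tensor<?xf32>) {
  // Scalar form.
  %s = addf %a, %b : f32
  // Systematic vectorization: the same op with vector operands.
  %v = addf %va, %vb : vector<4xf32>
  // Systematic tensorization: the same op with tensor operands.
  %t = addf %ta, %tb : tensor<?xf32>
  return
}
```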
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1750,6 +1750,9 @@
 // Op can be safely normalized in the presence of MemRefs with
 // non-identity maps.
 def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">;
+// Op can be systematically interconverted between scalar and vector/tensor
+// form.
+def ElementwiseMappable : NativeOpTrait<"ElementwiseMappable">;
 // Op's regions have a single block with the specified terminator.
 class SingleBlockImplicitTerminator<string op>
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -423,6 +423,7 @@
 LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
 LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
 LogicalResult verifyNoRegionArguments(Operation *op);
+LogicalResult verifyElementwiseMappable(Operation *op);
 } // namespace impl
 
 /// Helper class for implementing traits. Clients are not expected to interact
@@ -1304,6 +1305,93 @@
 struct MemRefsNormalizable
     : public TraitBase<ConcreteType, MemRefsNormalizable> {};
 
+/// This trait tags scalar ops that also can be applied to vectors/tensors,
+/// with their semantics on vectors/tensors being elementwise application.
+///
+/// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy
+/// this trait. In particular, broadcasting behavior is not allowed. This
+/// trait describes a set of invariants that allow systematic
+/// vectorization/tensorization, and the reverse, scalarization. The
+/// properties needed for this also can be used to implement a number of
+/// transformations/analyses/interfaces.
+///
+/// An `ElementwiseMappable` op must satisfy the following properties:
+///
+/// 1. If any result is a vector (resp. tensor), then at least one operand
+///    must be a vector (resp. tensor).
+/// 2. If any operand is a vector (resp. tensor), then there must be at least
+///    one result, and all results must be vectors (resp. tensors).
+/// 3. The static types of all vector (resp. tensor) operands and results
+///    must have the same shape.
+/// 4. In the case of tensor operands, the dynamic shapes of all tensor
+///    operands must be the same; otherwise the op has undefined behavior.
+/// 5. ("systematic scalarization" property) If an op has vector/tensor
+///    operands/results, then the same op, with the operand/result types
+///    changed to their corresponding element type, shall be a verifier-valid
+///    op.
+/// 6. The semantics of the op on vectors (resp. tensors) shall be the same
+///    as applying the scalarized version of the op for each corresponding
+///    element of the vector (resp. tensor) operands in parallel.
+/// 7. ("systematic vectorization/tensorization" property) If an op has
+///    scalar operands/results, the op shall remain verifier-valid if all
+///    scalar operands are replaced with vectors/tensors of the same shape
+///    and corresponding element types.
+///
+/// Together, these properties provide an easy way for scalar operations to
+/// conveniently generalize their behavior to vectors/tensors, and systematize
+/// conversion between these forms.
+///
+/// Examples:
+/// ```
+/// %scalar = "std.addf"(%a, %b) : (f32, f32) -> f32
+/// // Applying the systematic vectorization/tensorization property, this op
+/// // must also be valid:
+/// %tensor = "std.addf"(%a_tensor, %b_tensor)
+///           : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+///
+/// // These properties generalize well to the cases of non-scalar operands.
+/// %select_scalar_pred = "std.select"(%pred, %true_val, %false_val)
+///                       : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+/// // Applying the systematic vectorization/tensorization property, this
+/// // op must also be valid:
+/// %select_tensor_pred = "std.select"(%pred_tensor, %true_val, %false_val)
+///                       : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
+///                           -> tensor<?xf32>
+/// // Applying the systematic scalarization property, this op must also
+/// // be valid.
+/// %select_scalar = "std.select"(%pred, %true_val_scalar, %false_val_scalar)
+///                  : (i1, f32, f32) -> f32
+/// ```
+///
+/// TODO: Avoid hardcoding vector/tensor, and generalize this trait to any
+/// type implementing a new "ElementwiseMappableTypeInterface" that describes
+/// types for which it makes sense to apply a scalar function to each element.
+///
+/// Rationale:
+/// - 1. and 2. guarantee a well-defined iteration space for 6.
+///   - These also exclude the cases of 0 non-scalar operands or 0 non-scalar
+///     results, which complicate a generic definition of the iteration space.
+/// - 3. guarantees that folding can be done across scalars/vectors/tensors
+///   with the same pattern, as otherwise lots of special handling of type
+///   mismatches would be needed.
+/// - 4. guarantees that no error handling cases need to be considered.
+///   - Higher-level dialects should reify any needed guards / error handling
+///     code before lowering to an ElementwiseMappable op.
+/// - 5. and 6. allow defining the semantics on vectors/tensors via the scalar
+///   semantics and provide a constructive procedure for IR transformations
+///   to e.g. create scalar loop bodies from tensor ops.
+/// - 7. provides the reverse of 5., which when chained together allows
+///   reasoning about the relationship between the tensor and vector case.
+///   Additionally, it permits reasoning about promoting scalars to
+///   vectors/tensors via broadcasting in cases like `%select_scalar_pred`
+///   above.
+template <typename ConcreteType>
+struct ElementwiseMappable
+    : public TraitBase<ConcreteType, ElementwiseMappable> {
+  static LogicalResult verifyTrait(Operation *op) {
+    return ::mlir::OpTrait::impl::verifyElementwiseMappable(op);
+  }
+};
+
 } // end namespace OpTrait
 
 //===----------------------------------------------------------------------===//
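One subtlety in the properties above deserves emphasis: per property 4, mismatched *dynamic* tensor shapes are undefined behavior, not a verifier error. A hypothetical sketch:

```mlir
func @dynamic_shapes(%a: tensor<?xf32>, %b: tensor<?xf32>) -> tensor<?xf32> {
  // Verifies: both operands have the identical static type tensor<?xf32>.
  // If the runtime sizes of %a and %b disagree, the behavior is undefined;
  // higher-level dialects are expected to reify shape guards before lowering
  // to an ElementwiseMappable op.
  %0 = addf %a, %b : tensor<?xf32>
  return %0 : tensor<?xf32>
}
```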
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @main() {
+  %a = constant dense<[1.0, 2.0, 3.0]> : tensor<3xf32>
+  %b = constant dense<[10.0, 20.0, 30.0]> : tensor<3xf32>
+
+  %addf = addf %a, %b : tensor<3xf32>
+  %addf_unranked = tensor_cast %addf : tensor<3xf32> to tensor<*xf32>
+  call @print_memref_f32(%addf_unranked) : (tensor<*xf32>) -> ()
+  // CHECK: Unranked Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [3] strides = [1] data =
+  // CHECK-NEXT: [11, 22, 33]
+
+  return
+}
+
+func @print_memref_f32(%ptr : tensor<*xf32>)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
   Bufferize.cpp
   CodegenStrategy.cpp
   DropUnitDims.cpp
+  ElementwiseToLinalg.cpp
   Fusion.cpp
   FusionOnTensors.cpp
   Hoisting.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -0,0 +1,98 @@
+//===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
+  if (!op->hasTrait<OpTrait::ElementwiseMappable>())
+    return false;
+
+  // TODO: The conversion pattern can be made to work for `any_of` here, but
+  // it's more complex as it requires tracking which operands are scalars.
+  return llvm::all_of(op->getOperandTypes(),
+                      [](Type type) { return type.isa<RankedTensorType>(); });
+}
+
+namespace {
+struct ConvertStdElementwiseOpOnRankedTensors : public RewritePattern {
+  ConvertStdElementwiseOpOnRankedTensors()
+      : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const final {
+    if (!isElementwiseMappableOpOnRankedTensors(op))
+      return rewriter.notifyMatchFailure(
+          op, "requires elementwise op on ranked tensors");
+
+    auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank();
+    SmallVector<AffineMap, 3> indexingMaps(
+        op->getNumResults() + op->getNumOperands(),
+        rewriter.getMultiDimIdentityMap(rank));
+    SmallVector<StringRef, 6> iteratorTypes(rank,
+                                            getParallelIteratorTypeName());
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        op, /*resultTensorTypes=*/op->getResultTypes(),
+        /*inputs=*/op->getOperands(),
+        /*outputBuffers=*/ValueRange(),
+        /*initTensors=*/ValueRange(),
+        /*indexingMaps=*/indexingMaps,
+        /*iteratorTypes=*/iteratorTypes,
+        /*bodyBuilder=*/
+        [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
+          OperationState state(loc, op->getName());
+          state.addAttributes(op->getAttrs());
+          state.addOperands(regionArgs);
+          auto resultTypes = llvm::to_vector<6>(
+              llvm::map_range(op->getResultTypes(), [](Type type) {
+                return type.cast<TensorType>().getElementType();
+              }));
+          state.addTypes(resultTypes);
+          auto *scalarOp = builder.createOperation(state);
+          builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
+        });
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateElementwiseToLinalgConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *) {
+  patterns.insert<ConvertStdElementwiseOpOnRankedTensors>();
+}
+
+namespace {
+class ConvertElementwiseToLinalgPass
+    : public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> {
+  void runOnFunction() final {
+    auto func = getOperation();
+    auto *context = &getContext();
+    ConversionTarget target(*context);
+    OwningRewritePatternList patterns;
+
+    populateElementwiseToLinalgConversionPatterns(patterns, context);
+    target.markUnknownOpDynamicallyLegal([](Operation *op) {
+      return !isElementwiseMappableOpOnRankedTensors(op);
+    });
+
+    if (failed(applyPartialConversion(func, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createConvertElementwiseToLinalgPass() {
+  return std::make_unique<ConvertElementwiseToLinalgPass>();
+}
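Regarding the `llvm::all_of` guard in `isElementwiseMappableOpOnRankedTensors`: ops mixing scalar and ranked-tensor operands are intentionally left untouched for now (see the TODO above). A hypothetical illustration of what is and is not converted:

```mlir
func @select_forms(%pred: i1, %preds: tensor<4xi1>,
                   %t: tensor<4xf32>, %f: tensor<4xf32>) {
  // Not converted: %pred is a scalar, so not all operands are ranked tensors.
  %0 = select %pred, %t, %f : tensor<4xf32>
  // Converted: every operand is a ranked tensor.
  %1 = select %preds, %t, %f : tensor<4xi1>, tensor<4xf32>
  return
}
```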
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1068,6 +1068,57 @@
   return success();
 }
 
+/// Checks if two ShapedTypes are the same, ignoring the element type.
+static bool areSameShapedTypeIgnoringElementType(ShapedType a, ShapedType b) {
+  if (a.getTypeID() != b.getTypeID())
+    return false;
+  if (!a.hasRank())
+    return !b.hasRank();
+  return a.getShape() == b.getShape();
+}
+
+LogicalResult OpTrait::impl::verifyElementwiseMappable(Operation *op) {
+  auto isMappableType = [](Type type) {
+    return type.isa<VectorType>() || type.isa<TensorType>();
+  };
+  auto resultMappableTypes = llvm::to_vector<1>(
+      llvm::make_filter_range(op->getResultTypes(), isMappableType));
+  auto operandMappableTypes = llvm::to_vector<2>(
+      llvm::make_filter_range(op->getOperandTypes(), isMappableType));
+
+  // If the op only has scalar operand/result types, then we have nothing to
+  // check.
+  if (resultMappableTypes.empty() && operandMappableTypes.empty())
+    return success();
+
+  if (!resultMappableTypes.empty() && operandMappableTypes.empty())
+    return op->emitOpError("if a result is non-scalar, then at least one "
+                           "operand must be non-scalar");
+
+  assert(!operandMappableTypes.empty());
+
+  if (resultMappableTypes.empty())
+    return op->emitOpError("if an operand is non-scalar, then there must be "
+                           "at least one non-scalar result");
+
+  if (resultMappableTypes.size() != op->getNumResults())
+    return op->emitOpError(
+        "if an operand is non-scalar, then all results must be non-scalar");
+
+  auto mustMatchType = operandMappableTypes[0].cast<ShapedType>();
+  for (auto type :
+       llvm::concat<Type>(resultMappableTypes, operandMappableTypes)) {
+    if (!areSameShapedTypeIgnoringElementType(type.cast<ShapedType>(),
+                                              mustMatchType)) {
+      return op->emitOpError() << "all non-scalar operands/results must have "
+                                  "the same shape and base type: found "
+                               << type << " and " << mustMatchType;
+    }
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // BinaryOp implementation
 //===----------------------------------------------------------------------===//
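Because `areSameShapedTypeIgnoringElementType` deliberately ignores element types, cast-like ops that change element type while preserving shape pass the trait verifier, which is exactly what the new `ArithmeticCastOp` base relies on. A small sketch:

```mlir
func @casts(%a: vector<2xf16>, %b: vector<2xf32>) {
  // Same shape, different element types: accepted by
  // verifyElementwiseMappable.
  %0 = fpext %a : vector<2xf16> to vector<2xf32>
  %1 = fptrunc %b : vector<2xf32> to vector<2xf16>
  return
}
```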
diff --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt -convert-elementwise-to-linalg -split-input-file %s | FileCheck %s
+
+// In-depth checking of the linalg.generic op for a very trivial case.
+// CHECK: #map = affine_map<() -> ()>
+// CHECK-LABEL: func @addf_rank0
+func @addf_rank0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+  // CHECK: %{{.*}} = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%{{.*}}, %{{.*}} : tensor<f32>, tensor<f32>) {
+  // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+  // CHECK:   %[[YIELD:.*]] = addf %[[LHS]], %[[RHS]] : f32
+  // CHECK:   linalg.yield %[[YIELD]] : f32
+  // CHECK: } -> tensor<f32>
+  %0 = addf %arg0, %arg1 : tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// Check indexing maps and iterator types for the rank > 0 case.
+// CHECK: #map = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @addf_rank1
+func @addf_rank1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK: linalg.generic{{.*}}indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]
+  %0 = addf %arg0, %arg1 : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Check a unary op.
+// CHECK-LABEL: func @exp
+func @exp(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[SCALAR:.*]]: f32):
+  // CHECK:   %[[YIELD:.*]] = exp %[[SCALAR]] : f32
+  // CHECK:   linalg.yield %[[YIELD]] : f32
+  %0 = exp %arg0 : tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// Check a case with varying operand types.
+// CHECK-LABEL: func @select
+func @select(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[PRED:.*]]: i1, %[[TRUE_VAL:.*]]: i32, %[[FALSE_VAL:.*]]: i32):
+  // CHECK:   select %[[PRED]], %[[TRUE_VAL]], %[[FALSE_VAL]] : i32
+  %0 = select %arg0, %arg1, %arg2 : tensor<i1>, tensor<i32>
+  return %0 : tensor<i32>
+}
+
+// -----
+
+// Spot-check an op that requires copying attributes properly to the created
+// scalar op.
+// CHECK-LABEL: func @cmpf(
+func @cmpf(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<i1> {
+  // CHECK: cmpf "olt", %{{.*}}, %{{.*}} : f32
+  %0 = cmpf "olt", %arg0, %arg1 : tensor<f32>
+  return %0 : tensor<i1>
+}
diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -2,7 +2,7 @@
 
 // CHECK-LABEL: test_index_cast_shape_error
 func @test_index_cast_shape_error(%arg0 : tensor<index>) -> tensor<2xi64> {
-  // expected-error @+1 {{requires the same shape for all operands and results}}
+  // expected-error @+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<2xi64>' and 'tensor<index>'}}
   %0 = index_cast %arg0 : tensor<index> to tensor<2xi64>
   return %0 : tensor<2xi64>
 }
@@ -11,7 +11,7 @@
 
 // CHECK-LABEL: test_index_cast_tensor_error
 func @test_index_cast_tensor_error(%arg0 : tensor<index>) -> i64 {
-  // expected-error @+1 {{requires the same shape for all operands and results}}
+  // expected-error @+1 {{if an operand is non-scalar, then there must be at least one non-scalar result}}
   %0 = index_cast %arg0 : tensor<index> to i64
   return %0 : i64
 }
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -267,7 +267,7 @@
 
 func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) {
 ^bb0(%cond : vector<12xi1>, %t : vector<42xi32>, %f : vector<42xi32>):
-  // expected-error@+1 {{expected condition type to have the same shape as the result type, expected 'vector<42xi1>', but got 'vector<12xi1>'}}
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<42xi32>' and 'vector<12xi1>'}}
   %r = "std.select"(%cond, %t, %f) : (vector<12xi1>, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
 }
 
@@ -275,7 +275,7 @@
 
 func @func_with_ops(tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) {
 ^bb0(%cond : tensor<12xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
-  // expected-error@+1 {{expected condition type to have the same shape as the result type, expected 'tensor<42xi1>', but got 'tensor<12xi1>'}}
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<42xi32>' and 'tensor<12xi1>'}}
   %r = "std.select"(%cond, %t, %f) : (tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
 }
 
@@ -685,7 +685,7 @@
 // -----
 
 func @fpext_vec(%arg0 : vector<2xf16>) {
-  // expected-error@+1 {{requires the same shape for all operands and results}}
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<3xf32>' and 'vector<2xf16>'}}
   %0 = fpext %arg0 : vector<2xf16> to vector<3xf32>
   return
 }
@@ -757,7 +757,7 @@
 // -----
 
 func @fptrunc_vec(%arg0 : vector<2xf16>) {
-  // expected-error@+1 {{requires the same shape for all operands and results}}
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<3xf32>' and 'vector<2xf16>'}}
   %0 = fptrunc %arg0 : vector<2xf16> to vector<3xf32>
   return
 }
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -166,6 +166,84 @@
 
 // -----
 
+func @failedElementwiseMappable_different_rankedness(%arg0: tensor<?xf32>, %arg1: tensor<*xf32>) {
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<*xf32>' and 'tensor<?xf32>'}}
+  %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<*xf32>) -> tensor<*xf32>
+}
+
+// -----
+
+func @failedElementwiseMappable_different_rank(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) {
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<?x?xf32>' and 'tensor<?xf32>'}}
+  %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?xf32>
+}
+
+// -----
+
+func @failedElementwiseMappable_different_shape(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<5xf32>' and 'tensor<?xf32>'}}
+  %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<5xf32>) -> tensor<?xf32>
+}
+
+// -----
+
+func @failedElementwiseMappable_different_base_type(%arg0: vector<2xf32>, %arg1: tensor<2xf32>) {
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<2xf32>' and 'vector<2xf32>'}}
+  %0 = "test.elementwise_mappable"(%arg0, %arg1) : (vector<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+  return
+}
+
+// -----
+
+func @failedElementwiseMappable_non_scalar_output(%arg0: vector<2xf32>) {
+  // expected-error@+1 {{if an operand is non-scalar, then there must be at least one non-scalar result}}
+  %0 = "test.elementwise_mappable"(%arg0) : (vector<2xf32>) -> f32
+  return
+}
+
+// -----
+
+func @failedElementwiseMappable_non_scalar_result_all_scalar_input(%arg0: f32) {
+  // expected-error@+1 {{if a result is non-scalar, then at least one operand must be non-scalar}}
+  %0 = "test.elementwise_mappable"(%arg0) : (f32) -> tensor<?xf32>
+  return
+}
+
+// -----
+
+func @failedElementwiseMappable_mixed_scalar_non_scalar_results(%arg0: tensor<10xf32>) {
+  // expected-error@+1 {{if an operand is non-scalar, then all results must be non-scalar}}
+  %0, %1 = "test.elementwise_mappable"(%arg0) : (tensor<10xf32>) -> (f32, tensor<10xf32>)
+  return
+}
+
+// -----
+
+func @failedElementwiseMappable_zero_results(%arg0: tensor<10xf32>) {
+  // expected-error@+1 {{if an operand is non-scalar, then there must be at least one non-scalar result}}
+  "test.elementwise_mappable"(%arg0) : (tensor<10xf32>) -> ()
+  return
+}
+
+// -----
+
+func @failedElementwiseMappable_zero_operands() {
+  // expected-error@+1 {{if a result is non-scalar, then at least one operand must be non-scalar}}
+  "test.elementwise_mappable"() : () -> (tensor<6xf32>)
+  return
+}
+
+// -----
+
+func @succeededElementwiseMappable(%arg0: vector<2xf32>) {
+  // Check that varying element types are allowed.
+  // CHECK: test.elementwise_mappable
+  %0 = "test.elementwise_mappable"(%arg0) : (vector<2xf32>) -> vector<2xf16>
+  return
+}
+
+// -----
+
 func @failedHasParent_wrong_parent() {
   "some.op"() ({
    // expected-error@+1 {{'test.child' op expects parent op 'test.parent'}}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -370,6 +370,12 @@
   let results = (outs Variadic<AnyType>);
 }
 
+def ElementwiseMappableOp : TEST_Op<"elementwise_mappable",
+    [ElementwiseMappable]> {
+  let arguments = (ins Variadic<AnyType>);
+  let results = (outs Variadic<AnyType>);
+}
+
 def ArgAndResHaveFixedElementTypesOp :
     TEST_Op<"arg_and_res_have_fixed_element_types",
       [PredOpTrait<"fixed type combination",