diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -18,6 +18,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/TilingInterface.td" include "mlir/Interfaces/ViewLikeInterface.td" +include "mlir/IR/OpAsmInterface.td" class Tensor_Op traits = []> : Op; @@ -46,7 +47,9 @@ //===----------------------------------------------------------------------===// def Tensor_CastOp : Tensor_Op<"cast", [ - DeclareOpInterfaceMethods, NoSideEffect + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + NoSideEffect ]> { let summary = "tensor cast operation"; let description = [{ @@ -82,7 +85,10 @@ // DimOp //===----------------------------------------------------------------------===// -def Tensor_DimOp : Tensor_Op<"dim", [NoSideEffect, ShapedDimOpInterface]> { +def Tensor_DimOp : Tensor_Op<"dim", [ + DeclareOpInterfaceMethods, + NoSideEffect, + ShapedDimOpInterface]> { let summary = "dimension index operation"; let description = [{ The `tensor.dim` operation takes a tensor and a dimension operand of type @@ -199,11 +205,12 @@ // ExtractOp //===----------------------------------------------------------------------===// -def Tensor_ExtractOp : Tensor_Op<"extract", - [NoSideEffect, - TypesMatchWith<"result type matches element type of tensor", - "tensor", "result", - "$_self.cast().getElementType()">]> { +def Tensor_ExtractOp : Tensor_Op<"extract", [ + DeclareOpInterfaceMethods, + NoSideEffect, + TypesMatchWith<"result type matches element type of tensor", + "tensor", "result", + "$_self.cast().getElementType()">]> { let summary = "element extraction operation"; let description = [{ The `tensor.extract` op reads a tensor and returns one @@ -242,8 +249,10 @@ //===----------------------------------------------------------------------===// def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", [ - NoSideEffect, AttrSizedOperandSegments, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, + AttrSizedOperandSegments, + NoSideEffect, OffsetSizeAndStrideOpInterface ]> { let summary = "extract slice operation"; @@ -436,6 +445,7 @@ //===----------------------------------------------------------------------===// def Tensor_FromElementsOp : Tensor_Op<"from_elements", [ + DeclareOpInterfaceMethods, NoSideEffect, TypesMatchWith<"operand types match result element type", "result", "elements", "SmallVector(" @@ -481,6 +491,7 @@ //===----------------------------------------------------------------------===// def Tensor_GatherOp : Tensor_Op<"gather", [ + DeclareOpInterfaceMethods, NoSideEffect ]> { let summary = "gather a subset of a tensor at specified indices"; @@ -618,10 +629,11 @@ // GenerateOp //===----------------------------------------------------------------------===// -def Tensor_GenerateOp : Tensor_Op<"generate", - [RecursiveSideEffects, - DeclareOpInterfaceMethods, - SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> { +def Tensor_GenerateOp : Tensor_Op<"generate", [ + DeclareOpInterfaceMethods, + RecursiveSideEffects, + DeclareOpInterfaceMethods, + SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> { let summary = "Creates a dynamically sized tensor from elements"; let description = [{ This operation creates a dynamically sized tensor with elements of any type. @@ -664,14 +676,15 @@ // InsertOp //===----------------------------------------------------------------------===// -def Tensor_InsertOp : Tensor_Op<"insert", - [NoSideEffect, - TypesMatchWith<"result type matches type of dest", - "dest", "result", - "$_self.cast()">, - TypesMatchWith<"scalar type matches element type of dest", - "dest", "scalar", - "$_self.cast().getElementType()">]> { +def Tensor_InsertOp : Tensor_Op<"insert", [ + DeclareOpInterfaceMethods, + NoSideEffect, + TypesMatchWith<"result type matches type of dest", + "dest", "result", + "$_self.cast()">, + TypesMatchWith<"scalar type matches element type of dest", + "dest", "scalar", + "$_self.cast().getElementType()">]> { let summary = "element insertion operation"; let description = [{ The `tensor.insert` op writes a tensor into a tensor `dest`as specified by @@ -717,8 +730,11 @@ //===----------------------------------------------------------------------===// def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [ - NoSideEffect, AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, + AttrSizedOperandSegments, + NoSideEffect, + OffsetSizeAndStrideOpInterface, TypesMatchWith<"expected result type to match dest type", "dest", "result", "$_self"> ]> { @@ -854,7 +870,9 @@ // RankOp //===----------------------------------------------------------------------===// -def Tensor_RankOp : Tensor_Op<"rank", [NoSideEffect]> { +def Tensor_RankOp : Tensor_Op<"rank", [ + DeclareOpInterfaceMethods, + NoSideEffect]> { let summary = "rank operation"; let description = [{ The `tensor.rank` operation takes a tensor operand and returns its rank. @@ -878,7 +896,9 @@ // ReshapeOp //===----------------------------------------------------------------------===// -def Tensor_ReshapeOp: Tensor_Op<"reshape", [NoSideEffect]> { +def Tensor_ReshapeOp: Tensor_Op<"reshape", [ + DeclareOpInterfaceMethods, + NoSideEffect]> { let summary = "tensor reshape operation"; let description = [{ The `reshape` operation converts a tensor from one type to an equivalent @@ -941,7 +961,9 @@ //===----------------------------------------------------------------------===// class Tensor_ReassociativeReshapeOp traits = []> : - Tensor_Op, + Tensor_Op, + NoSideEffect])>, Arguments<(ins AnyTensor:$src, IndexListArrayAttr:$reassociation)>, Results<(outs AnyTensor:$result)> { @@ -1091,7 +1113,10 @@ // PadOp //===----------------------------------------------------------------------===// -def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect, +def Tensor_PadOp : Tensor_Op<"pad", [ + DeclareOpInterfaceMethods, + AttrSizedOperandSegments, + NoSideEffect, SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> { let summary = "tensor pad operation"; let description = [{ @@ -1433,6 +1458,7 @@ //===----------------------------------------------------------------------===// def Tensor_ScatterOp : Tensor_Op<"scatter", [ + DeclareOpInterfaceMethods, NoSideEffect ]> { let summary = @@ -1573,6 +1599,7 @@ //===----------------------------------------------------------------------===// def Tensor_SplatOp : Tensor_Op<"splat", [ + DeclareOpInterfaceMethods, NoSideEffect, TypesMatchWith<"operand type matches element type of result", "aggregate", "input", diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -58,6 +58,10 @@ // CastOp //===----------------------------------------------------------------------===// +void CastOp::getAsmResultNames(function_ref setNameFn) { + setNameFn(getResult(), "cast"); +} + /// Returns true if `target` is a ranked tensor type that preserves static /// information available in the `source` ranked tensor type. bool mlir::tensor::preservesStaticInformation(Type source, Type target) { @@ -307,6 +311,10 @@ // DimOp //===----------------------------------------------------------------------===// +void DimOp::getAsmResultNames(function_ref setNameFn) { + setNameFn(getResult(), "dim"); +} + void DimOp::build(OpBuilder &builder, OperationState &result, Value source, int64_t index) { auto loc = result.location; @@ -697,6 +705,11 @@ // ExtractOp //===----------------------------------------------------------------------===// +void ExtractOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "extracted"); +} + LogicalResult ExtractOp::verify() { // Verify the # indices match if we have a ranked type. if (auto tensorType = getTensor().getType().dyn_cast()) @@ -756,6 +769,11 @@ // FromElementsOp //===----------------------------------------------------------------------===// +void FromElementsOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "from_elements"); +} + void FromElementsOp::build(OpBuilder &builder, OperationState &result, Type resultType, ValueRange elements) { result.addOperands(elements); @@ -828,6 +846,11 @@ // GatherOp //===----------------------------------------------------------------------===// +void GatherOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "gather"); +} + /// Return the inferred result type for a gatherOp where: /// - sourceType is the type of the source tensor gathered from /// - indicesType is the type of the indices used to gather @@ -911,6 +934,11 @@ // InsertOp //===----------------------------------------------------------------------===// +void InsertOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "inserted"); +} + LogicalResult InsertOp::verify() { // Verify the # indices match if we have a ranked type. if (auto destType = getDest().getType().dyn_cast()) @@ -933,6 +961,11 @@ // GenerateOp //===----------------------------------------------------------------------===// +void GenerateOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "generated"); +} + LogicalResult GenerateOp::reifyResultShapes( OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { reifiedReturnShapes.resize(1, SmallVector(getType().getRank())); @@ -1116,6 +1149,10 @@ // RankOp //===----------------------------------------------------------------------===// +void RankOp::getAsmResultNames(function_ref setNameFn) { + setNameFn(getResult(), "rank"); +} + OpFoldResult RankOp::fold(ArrayRef operands) { // Constant fold rank when the rank of the operand is known. auto type = getOperand().getType(); @@ -1129,6 +1166,11 @@ // ReshapeOp //===----------------------------------------------------------------------===// +void ReshapeOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "reshape"); +} + static int64_t getNumElements(ShapedType type) { int64_t numElements = 1; for (auto dim : type.getShape()) @@ -1170,6 +1212,16 @@ // Reassociative reshape ops //===----------------------------------------------------------------------===// +void CollapseShapeOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "collapsed"); +} + +void ExpandShapeOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "expanded"); +} + SmallVector CollapseShapeOp::getReassociationMaps() { return getSymbolLessAffineMaps(getReassociationExprs()); } @@ -1369,6 +1421,11 @@ // ExtractSliceOp //===----------------------------------------------------------------------===// +void ExtractSliceOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "extracted_slice"); +} + /// An extract_slice result type can be inferred, when it is not /// rank-reduced, from the source type and the static representation of /// offsets, sizes and strides. Special sentinels encode the dynamic case. @@ -1865,6 +1922,11 @@ // InsertSliceOp //===----------------------------------------------------------------------===// +void InsertSliceOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "inserted_slice"); +} + // Build a InsertSliceOp with mixed static and dynamic entries. void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, Value dest, ArrayRef offsets, @@ -2218,6 +2280,10 @@ // PadOp //===----------------------------------------------------------------------===// +void PadOp::getAsmResultNames(function_ref setNameFn) { + setNameFn(getResult(), "padded"); +} + // TODO: Replace custom directive with AllTypesMatch as soon as it // supports optional types. void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand, @@ -2725,6 +2791,11 @@ // ScatterOp //===----------------------------------------------------------------------===// +void ScatterOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "scatter"); +} + LogicalResult ScatterOp::verify() { int64_t destRank = getDestType().getRank(); ArrayRef scatterDims = getScatterDims(); @@ -2761,6 +2832,11 @@ // SplatOp //===----------------------------------------------------------------------===// +void SplatOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "splat"); +} + OpFoldResult SplatOp::fold(ArrayRef operands) { auto constOperand = operands.front(); if (!constOperand.isa_and_nonnull()) diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir @@ -595,7 +595,7 @@ // CHECK: ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index): // CHECK: tensor.yield %cst : f32 // CHECK: } : tensor<2x?x?x3xf32> to tensor<2x?x?x3xf32> - // CHECK: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} ins(%[[PADDED]], %arg1 : tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>) outs(%22 : tensor<2x?x?x3x5xf32>) -> tensor<2x?x?x3x5xf32> + // CHECK: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} ins(%[[PADDED]], %arg1 : tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>) outs(%{{.*}} : tensor<2x?x?x3x5xf32>) -> tensor<2x?x?x3x5xf32> // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[CONV]] {{\[}}[0], [1], [2], [3, 4]] %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [1, 2, 3, 4], dilation = [2, 1], stride = [1, 2]} : (tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x?x?x15xf32> return diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -604,8 +604,8 @@ // CHECK-LABEL: @test_reshape_downrank_6D_dyn func.func @test_reshape_downrank_6D_dyn(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor { - // CHECK: tensor.collapse_shape %arg0 {{\[}}[0, 1, 2, 3, 4, 5]] - // CHECK: tensor.expand_shape %0 {{\[}}[0, 1, 2]] + // CHECK: tensor.collapse_shape {{.*}}[0, 1, 2, 3, 4, 5] + // CHECK: tensor.expand_shape {{.*}}[0, 1, 2] %0 = "tosa.reshape"(%arg0) {new_shape = [-1, 5, 77]} : (tensor<1x2x?x5x7x11xf32>) -> tensor return %0 : tensor } diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir --- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir +++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir @@ -15,7 +15,7 @@ // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]] // CHECK: %[[C2:.+]] = arith.constant 2 : index // CHECK: %[[SUB:.+]] = arith.subi %[[DIM]], %[[C2]] - // CHECK: %2 = tensor.extract_slice %arg0[2] [%[[SUB]]] [1] + // CHECK: tensor.extract_slice %arg0[2] [%[[SUB]]] [1] %0 = "tosa.slice"(%arg0) {start = [2], size = [-1]} : (tensor) -> (tensor) return %0 : tensor } diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir --- a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir @@ -5,13 +5,13 @@ // CHECK: linalg.generic // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] - // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor<16x4x64xf32>, tensor<4x64x32xf32>) - // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<16x32x4xf32>) { + // CHECK-SAME: ins(%{{[a-zA-Z0-9_]*}}, %{{[a-zA-Z0-9_]*}} : tensor<16x4x64xf32>, tensor<4x64x32xf32>) + // CHECK-SAME: outs(%{{[a-zA-Z0-9_]*}} : tensor<16x32x4xf32>) { // CHECK: linalg.generic // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] - // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<16x32x4xf32>) - // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<16x32xf32>) { + // CHECK-SAME: ins(%{{[a-zA-Z0-9_]*}} : tensor<16x32x4xf32>) + // CHECK-SAME: outs(%{{[a-zA-Z0-9_]*}} : tensor<16x32xf32>) { %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>) outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32> return %0: tensor<16x32xf32> diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -2,13 +2,13 @@ // CHECK-LABEL: func @cast( func.func @cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor) { - // CHECK: tensor.cast %arg0 : tensor<*xf32> to tensor + // CHECK: tensor.cast %{{.*}} : tensor<*xf32> to tensor %0 = tensor.cast %arg0 : tensor<*xf32> to tensor - // CHECK: tensor.cast %arg1 : tensor<4x4xf32> to tensor<*xf32> + // CHECK: tensor.cast %{{.*}} : tensor<4x4xf32> to tensor<*xf32> %1 = tensor.cast %arg1 : tensor<4x4xf32> to tensor<*xf32> - // CHECK: tensor.cast %arg2 : tensor to tensor<4x?xf32> + // CHECK: tensor.cast %{{.*}} : tensor to tensor<4x?xf32> %2 = tensor.cast %arg2 : tensor to tensor<4x?xf32> - // CHECK: tensor.cast %2 : tensor<4x?xf32> to tensor + // CHECK: tensor.cast %{{.*}} : tensor<4x?xf32> to tensor %3 = tensor.cast %2 : tensor<4x?xf32> to tensor return }