diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -1026,7 +1026,16 @@
         return $_op.getLibraryCallName();
       }]
     >,
-
+    InterfaceMethod<
+      /*desc=*/[{
+        Return whether the op accesses the iteration indices.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasIndexSemantics",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/""
+    >,
     //===------------------------------------------------------------------===//
     // Linalg generalization hooks.
     //===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -618,5 +618,51 @@
   let hasFolder = 1;
 }
 
+def Linalg_IndexOp : Linalg_Op<"index", [NoSideEffect]>,
+    Arguments<(ins Confined<I64Attr, [IntMinValue<0>]>:$dim)>,
+    Results<(outs Index:$result)> {
+  let summary = "linalg index operation";
+  let description = [{
+    The `linalg.index` operation returns the iteration index of the
+    immediately enclosing linalg structured operation for the iteration
+    dimension `dim`. The `dim` attribute specifies the position of the
+    accessed dimension in the indexing map domain.
+
+    Example:
+
+    ```mlir
+    #map = affine_map<(i, j) -> (i, j)>
+    linalg.generic {indexing_maps = [#map, #map],
+                    iterator_types = ["parallel", "parallel"]}
+        outs(%I, %J : memref<?x?xindex>, memref<?x?xindex>) {
+      ^bb0(%arg0 : index, %arg1 : index):
+      // Access the outer iteration dimension i.
+      %i = linalg.index 0 : index
+      // Access the inner iteration dimension j.
+      %j = linalg.index 1 : index
+      linalg.yield %i, %j : index, index
+    }
+    ```
+
+    This may lower to IR resembling:
+
+    ```mlir
+    %0 = dim %I, %c0 : memref<?x?xindex>
+    %1 = dim %I, %c1 : memref<?x?xindex>
+    scf.for %i = %c0 to %0 step %c1 {
+      scf.for %j = %c0 to %1 step %c1 {
+        store %i, %I[%i, %j] : memref<?x?xindex>
+        store %j, %J[%i, %j] : memref<?x?xindex>
+      }
+    }
+    ```
+  }];
+
+  let builders = [
+    OpBuilder<(ins "int64_t":$dim),
+      [{ build($_builder, $_state, $_builder.getIndexType(), dim); }]>
+  ];
+
+  let assemblyFormat = [{ $dim attr-dict `:` type($result) }];
+}
+
 #endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -35,6 +35,14 @@
       return isa<IndexedGenericOp>(this->getOperation()) ? getNumLoops() : 0;
     }
 
+    // Return whether the op accesses the iteration indices.
+    bool hasIndexSemantics() {
+      Operation *op = this->getOperation();
+      if (op->getNumRegions() == 0 || op->getRegion(0).empty())
+        return false;
+      return !op->getRegion(0).front().getOps<IndexOp>().empty();
+    }
+
     LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b,
         SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
       return cast<InferShapedTypeOpInterface>(getOperation()).reifyReturnTypeShapesPerResultDim(b,
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -106,6 +106,10 @@
   if (isa<CopyOp>(op) || isa<FillOp>(op))
     return failure();
 
+  // TODO: remove once index ops are supported.
+  if (op.hasIndexSemantics())
+    return failure();
+
   auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
   if (!libraryCallName)
     return failure();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2047,6 +2047,21 @@
   return foldMemRefCast(*this);
 }
 
+//===----------------------------------------------------------------------===//
+// IndexOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(IndexOp op) {
+  auto linalgOp = dyn_cast<LinalgOp>(op->getParentOp());
+  if (!linalgOp)
+    return op.emitOpError("expected parent op with LinalgOp interface");
+  if (linalgOp.getNumLoops() <= op.dim())
+    return op.emitOpError("expected dim (")
+           << op.dim() << ") to be lower than the number of loops ("
+           << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
+  return success();
+}
+
 /////// Operations corresponding to library calls defined with Tablegen ////////
 
 template
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -177,6 +177,10 @@
   using OpRewritePattern<GenericOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(GenericOpTy op,
                                 PatternRewriter &rewriter) const override {
+    // TODO: remove once index ops are supported.
+    if (op.hasIndexSemantics())
+      return failure();
+
     SmallVector<AffineMap, 4> indexingMaps = op.getIndexingMaps();
     if (indexingMaps.empty())
       return failure();
@@ -321,6 +325,10 @@
   using OpRewritePattern<GenericOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(GenericOpTy op,
                                 PatternRewriter &rewriter) const override {
+    // TODO: remove once index ops are supported.
+    if (op.hasIndexSemantics())
+      return failure();
+
     if (!op.hasTensorSemantics())
       return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -28,6 +28,10 @@
 /// Implementation of fusion of generic ops and indexed_generic ops.
 static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
                                      unsigned consumerIdx) {
+  // TODO: remove once index ops are supported.
+  if (producer.hasIndexSemantics() || consumer.hasIndexSemantics())
+    return false;
+
   // Producer and consumer must have tensor semantics.
   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
     return false;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -527,7 +527,9 @@
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    if (!isa<LinalgOp>(op))
+    auto linalgOp = dyn_cast<LinalgOp>(op);
+    // TODO: remove hasIndexSemantics check once index ops are supported.
+    if (!linalgOp || linalgOp.hasIndexSemantics())
       return failure();
     if (!linalgOpToLoopsImpl<LoopType>(op, rewriter, interchangeVector))
       return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -246,7 +246,8 @@
 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
     Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  if (!linalgOp)
+  // TODO: remove hasIndexSemantics check once index ops are supported.
+  if (!linalgOp || linalgOp.hasIndexSemantics())
     return failure();
   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();
@@ -314,7 +315,8 @@
 LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  if (!linalgOp)
+  // TODO: remove hasIndexSemantics check once index ops are supported.
+  if (!linalgOp || linalgOp.hasIndexSemantics())
     return failure();
   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();
@@ -407,7 +409,8 @@
 LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  if (!linalgOp)
+  // TODO: remove hasIndexSemantics check once index ops are supported.
+  if (!linalgOp || linalgOp.hasIndexSemantics())
     return failure();
   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();
@@ -465,7 +468,8 @@
 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  if (!linalgOp)
+  // TODO: remove hasIndexSemantics check once index ops are supported.
+  if (!linalgOp || linalgOp.hasIndexSemantics())
     return failure();
   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -402,6 +402,9 @@
   for (Type outputTensorType : linalgOp.getOutputTensorTypes())
     if (!outputTensorType.cast<ShapedType>().hasStaticShape())
       return failure();
+  // TODO: remove once index ops are supported.
+  if (linalgOp.hasIndexSemantics())
+    return failure();
   if (isElementwise(op))
     return success();
   return success(isaContractionOpInterface(linalgOp));
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -566,3 +566,19 @@
 // CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>)
 // CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] [#[[MAP2]]]
 // CHECK: return %[[RESULT_RESHAPE]]
+
+// -----
+
+// CHECK: #{{.+}} = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: @index_op
+func @index_op(%arg0: memref<1x8xindex>) {
+  linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>],
+    iterator_types = ["parallel", "parallel"]}
+    outs(%arg0 : memref<1x8xindex>) {
+  ^bb0(%arg1: index): // no predecessors
+    %0 = linalg.index 1 : index
+    linalg.yield %0 : index
+  }
+  return
+}
diff --git a/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir b/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
--- a/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
@@ -188,3 +188,53 @@
 // CHECK: {{.*}} = index_cast [[j_new]] : index to i32
 // CHECK: linalg.generic
 // CHECK: addf
+
+// -----
+
+#map = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+#id_2d = affine_map<(d0, d1) -> (d0, d1)>
+#pointwise_2d_trait = {
+  indexing_maps = [#id_2d],
+  iterator_types = ["parallel", "parallel"]
+}
+func @index_op(%A: memref<?x?xindex>,
+               %B: memref<?x?xindex>) {
+  linalg.generic #pointwise_2d_trait
+     outs(%B : memref<?x?xindex>) {
+  ^bb0(%arg6: index): // no predecessors
+    %2 = constant 0 : index
+    linalg.yield %2 : index
+  }
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  %c25 = constant 25 : index
+  %c10 = constant 10 : index
+  %0 = memref.dim %A, %c0 : memref<?x?xindex>
+  %1 = memref.dim %A, %c1 : memref<?x?xindex>
+  %2 = memref.dim %B, %c0 : memref<?x?xindex>
+  %3 = memref.dim %B, %c1 : memref<?x?xindex>
+  scf.for %arg2 = %c0 to %0 step %c10 {
+    scf.for %arg3 = %c0 to %1 step %c25 {
+      %4 = memref.subview %A[%arg2, %arg3][%c10, %c25][%c1, %c1] :
+          memref<?x?xindex> to memref<?x?xindex, #map>
+      %5 = memref.subview %B[%arg2, %arg3][%c10, %c25][%c1, %c1] :
+          memref<?x?xindex> to memref<?x?xindex, #map>
+      linalg.generic {
+        indexing_maps = [#id_2d, #id_2d],
+        iterator_types = ["parallel", "parallel"]}
+        ins(%4 : memref<?x?xindex, #map>)
+        outs(%5 : memref<?x?xindex, #map>) {
+      ^bb0(%arg6: index, %arg7: index):
+        %6 = linalg.index 0 : index
+        linalg.yield %6 : index
+      }
+    }
+  }
+  return
+}
+// CHECK-LABEL: func @index_op
+// CHECK: linalg.generic
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK-NOT: scf.for
+// CHECK: linalg.generic
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -652,3 +652,29 @@
 // CHECK-NEXT: return %[[R]] : tensor<1x8xi32>
   return %1 : tensor<1x8xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func @index_op(
+// CHECK-COUNT-2: linalg.generic
+func @index_op(%arg0: tensor<1x8xindex>, %arg1: tensor<1x8xindex>) -> tensor<1x8xindex> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>],
+    iterator_types = ["parallel", "parallel"]}
+    outs(%arg0 : tensor<1x8xindex>) {
+  ^bb0(%a: index): // no predecessors
+    %2 = linalg.index 1 : index
+    linalg.yield %2 : index
+  } -> tensor<1x8xindex>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%0 : tensor<1x8xindex>)
+    outs(%arg1 : tensor<1x8xindex>) {
+  ^bb0(%a: index, %b: index): // no predecessors
+    %2 = linalg.index 0 : index
+    %3 = addi %2, %a : index
+    linalg.yield %3 : index
+  } -> tensor<1x8xindex>
+  return %1 : tensor<1x8xindex>
+}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -24,6 +24,41 @@
 
 // -----
 
+func @index_parent() {
+  // expected-error @+1 {{op expected parent op with LinalgOp interface}}
+  linalg.index 0 : index
+}
+
+// -----
+
+func @index_dim_lower_than_number_of_loops(%arg0: memref<f32>) {
+  // expected-error @+6 {{op expected dim (2) to be lower than the number of loops (0) of the enclosing LinalgOp}}
+  linalg.generic {
+      indexing_maps = [ affine_map<() -> ()> ],
+      iterator_types = []}
+      outs(%arg0 : memref<f32>) {
+    ^bb(%0: f32):
+      linalg.index 2 : index
+      linalg.yield %0 : f32
+  }
+}
+
+// -----
+
+func @index_dim_negative(%arg0: memref<f32>) {
+  // expected-error @+6 {{op attribute 'dim' failed to satisfy constraint: 64-bit signless integer attribute whose minimum value is 0}}
+  linalg.generic {
+      indexing_maps = [ affine_map<() -> ()> ],
+      iterator_types = []}
+      outs(%arg0 : memref<f32>) {
+    ^bb(%0: f32):
+      linalg.index -1 : index
+      linalg.yield %0 : f32
+  }
+}
+
+// -----
+
 func @generic_no_region(%arg0: memref<f32>) {
   // expected-error @+5 {{expected '{' to begin a region}}
   linalg.generic {
diff --git a/mlir/test/Dialect/Linalg/loop-order.mlir b/mlir/test/Dialect/Linalg/loop-order.mlir
--- a/mlir/test/Dialect/Linalg/loop-order.mlir
+++ b/mlir/test/Dialect/Linalg/loop-order.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s -convert-linalg-to-loops="interchange-vector=4,0,3,1,2" | FileCheck --check-prefix=LOOP %s
-// RUN: mlir-opt %s -convert-linalg-to-parallel-loops="interchange-vector=4,0,3,1,2" | FileCheck --check-prefix=PARALLEL %s
-// RUN: mlir-opt %s -convert-linalg-to-affine-loops="interchange-vector=4,0,3,1,2" | FileCheck --check-prefix=AFFINE %s
+// RUN: mlir-opt %s -convert-linalg-to-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=LOOP %s
+// RUN: mlir-opt %s -convert-linalg-to-parallel-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=PARALLEL %s
+// RUN: mlir-opt %s -convert-linalg-to-affine-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=AFFINE %s
 
 func @copy(%input: memref<1x2x3x4x5xf32>, %output: memref<1x2x3x4x5xf32>) {
   linalg.copy(%input, %output): memref<1x2x3x4x5xf32>, memref<1x2x3x4x5xf32>
@@ -22,3 +22,24 @@
 // AFFINE: affine.for %{{.*}} = 0 to 2
 // AFFINE: affine.for %{{.*}} = 0 to 3
 
+// -----
+
+func @index_op(%arg0: memref<4x8xindex>) {
+  linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>],
+    iterator_types = ["parallel", "parallel"]}
+    outs(%arg0 : memref<4x8xindex>) {
+  ^bb0(%arg1: index): // no predecessors
+    %0 = linalg.index 1 : index
+    linalg.yield %0 : index
+  }
+  return
+}
+// LOOP-LABEL: @index_op
+// LOOP: linalg.generic
+
+// PARALLEL-LABEL: @index_op
+// PARALLEL: linalg.generic
+
+// AFFINE-LABEL: @index_op
+// AFFINE: linalg.generic
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -525,6 +525,9 @@
       outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
       attrs = {foo = 1} {
     ^bb(%a: vector<3x4xi4>, %b: f32) :
+      %0 = linalg.index 0 : index
+      %1 = linalg.index 1 : index
+      %2 = linalg.index 2 : index
       linalg.yield %b : f32
   }
   return
@@ -538,6 +541,9 @@
 // CHECK-SAME: outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
 // CHECK-SAME: attrs = {foo = 1 : i64} {
 // CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
+// CHECK: %{{.*}} = linalg.index 0 : index
+// CHECK: %{{.*}} = linalg.index 1 : index
+// CHECK: %{{.*}} = linalg.index 2 : index
 // CHECK: linalg.yield %{{.*}} : f32
 
 func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir
--- a/mlir/test/Dialect/Linalg/tile.mlir
+++ b/mlir/test/Dialect/Linalg/tile.mlir
@@ -377,3 +377,18 @@
 // TILE-234: for
 // TILE-234-NOT: for
 // TILE-234: linalg.generic
+
+// TILE-2-LABEL: func @index_op
+// TILE-2-NOT: for
+// TILE-2: linalg.generic
+func @index_op(%arg0: memref<?x?xindex>) {
+  linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>],
+    iterator_types = ["parallel", "parallel"]}
+    outs(%arg0 : memref<?x?xindex>) {
+  ^bb0(%arg1: index): // no predecessors
+    %0 = linalg.index 1 : index
+    linalg.yield %0 : index
+  }
+  return
+}
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -469,3 +469,19 @@
   } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
   return %0 : tensor<6x?x?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @index_op
+// CHECK: linalg.generic
+func @index_op(%arg0: memref<4x8xindex>) {
+  linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>],
+    iterator_types = ["parallel", "parallel"]}
+    outs(%arg0 : memref<4x8xindex>) {
+  ^bb0(%arg1: index): // no predecessors
+    %0 = linalg.index 1 : index
+    linalg.yield %0 : index
+  }
+  return
+}
diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -126,6 +126,10 @@
   // Save original Linalg ops, we only want to make a pass over those.
   SmallVector<LinalgOp, 8> linalgOps;
   f.walk([&](LinalgOp op) {
+    // TODO: remove hasIndexSemantics check once index ops are supported.
+    if (op.hasIndexSemantics())
+      return;
+
     // TODO: support multi-results.
    if (op->getNumResults() <= 1)
      linalgOps.push_back(op);
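
For context, a minimal usage sketch of the new op, separate from the patch itself: it fills a buffer with the linearized iteration index, the kind of payload that previously required the index block arguments of `linalg.indexed_generic`. The function name, buffer shape, and constant are illustrative assumptions, not taken from the patch:

```mlir
// Fill a 4x8 buffer with the linearized iteration index i * 8 + j.
func @linearized_index_fill(%buffer: memref<4x8xindex>) {
  linalg.generic {
    indexing_maps = [affine_map<(i, j) -> (i, j)>],
    iterator_types = ["parallel", "parallel"]}
    outs(%buffer : memref<4x8xindex>) {
  ^bb0(%arg0: index):
    %c8 = constant 8 : index
    // linalg.index exposes the iteration indices inside the payload without
    // extra block arguments, so transformations such as tiling and fusion
    // only need to rewrite the linalg.index results instead of remapping
    // region arguments.
    %i = linalg.index 0 : index
    %j = linalg.index 1 : index
    %0 = muli %i, %c8 : index
    %1 = addi %0, %j : index
    linalg.yield %1 : index
  }
  return
}
```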