diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1863,7 +1863,9 @@
     if (parser.resolveOperands(outputs, outputTypes, outputsOperandsLoc,
                                result.operands))
       return failure();
-    result.addTypes(outputTypes);
+    for (Type outputType : outputTypes)
+      if (outputType.isa<RankedTensorType>())
+        result.addTypes(outputType);
   }
 
   // Parse attributes.
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -6,6 +6,8 @@
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
 // DISABLED: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
 
+// CHECK-DAG: #[[$id_2d:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$id_1d:.*]] = affine_map<(d0, d1, d2) -> (d1)>
 // CHECK-DAG: #[[$permute_0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
 // CHECK-DAG: #[[$permute_1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
 // CHECK-DAG: #[[$reshape5D01:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
@@ -881,3 +883,61 @@
 }
 // CHECK-LABEL: func @tiled_loop_reduction
 // CHECK: iterators[
+
+// -----
+
+#trait_6 = {
+  indexing_maps = [
+    #id_3d,
+    #id_2d,
+    #id_1d,
+    #id_1d
+  ],
+  iterator_types = ["reduction", "parallel", "reduction"]
+}
+#map_1 = affine_map<(d0, d1, d2)[s0] -> (d0 * 768 + s0 + d1 * 32 + d2)>
+#map_2 = affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>
+#map_3 = affine_map<(d0)[s0] -> (d0 + s0)>
+
+func @tiled_loop_on_buffers(%input_3d: memref<16x24x32xf32>,
+                            %input_2d: memref<16x32xf32>,
+                            %input_1d: memref<24xf32>,
+                            %output: memref<24xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c4 = constant 4 : index
+  %c8 = constant 8 : index
+  %X = memref.dim %input_3d, %c0 : memref<16x24x32xf32>
+  %Y = memref.dim %input_3d, %c1 : memref<16x24x32xf32>
+  %Z = memref.dim %input_3d, %c2 : memref<16x24x32xf32>
+  linalg.tiled_loop (%i, %j, %k) = (%c0, %c0, %c0)
+      to (%X, %Y, %Z) step (%c2, %c4, %c8)
+      ins(%input_3d, %input_2d: memref<16x24x32xf32>, memref<16x32xf32>)
+      outs( %output: memref<24xf32>)
+      iterators["reduction", "parallel", "reduction"] {
+    %sub_3d = memref.subview %input_3d[%i, %j, %k][2, 4, 8][1, 1, 1]
+      : memref<16x24x32xf32> to memref<2x4x8xf32, #map_1>
+    %sub_2d = memref.subview %input_2d[%i, %k][2, 8][1, 1]
+      : memref<16x32xf32> to memref<2x8xf32, #map_2>
+    %sub_1d = memref.subview %input_1d[%j] [4] [1]
+      : memref<24xf32> to memref<4xf32, #map_3>
+    %sub_out = memref.subview %output[%j] [4] [1]
+      : memref<24xf32> to memref<4xf32, #map_3>
+    linalg.generic #trait_6
+      ins(%sub_3d, %sub_2d, %sub_1d
+        : memref<2x4x8xf32, #map_1>,
+          memref<2x8xf32, #map_2>,
+          memref<4xf32, #map_3>)
+      outs(%sub_out : memref<4xf32, #map_3>) {
+    ^bb0(%i3d: f32, %i2d: f32, %i1d: f32, %o: f32):
+      %0 = addf %i3d, %i2d : f32
+      %1 = addf %0, %i1d : f32
+      linalg.yield %1 : f32
+    }
+    linalg.yield
+  }
+  return
+}
+// CHECK-LABEL: func @tiled_loop_on_buffers
+// CHECK: iterators[