diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -65,6 +65,22 @@ SmallVector loops; }; +/// Helper template class for building loop.for and affine.loop nests from +/// ranges. +template +class GenericLoopNestRangeBuilder { +public: + GenericLoopNestRangeBuilder(ArrayRef ivs, + ArrayRef ranges); + void operator()(std::function fun = nullptr) { (*builder)(fun); } + +private: + typedef typename std::conditional::value, + AffineLoopNestBuilder, + LoopNestRangeBuilder>::type BuilderType; + std::unique_ptr builder; +}; + } // namespace edsc namespace linalg { diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -6,7 +6,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" @@ -423,7 +422,7 @@ invertedMap, getViewSizes(linalgOp)); assert(loopRanges.size() == allIvs.size()); - LoopNestRangeBuilder(allPIvs, loopRanges)([&] { + GenericLoopNestRangeBuilder(allPIvs, loopRanges)([&] { auto allIvValues = extractValues(allIvs); LinalgScopedEmitter::emitScalarImplementation( allIvValues, linalgOp); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -97,6 +97,29 @@ return ValueHandle::null(); } +template <> +GenericLoopNestRangeBuilder::GenericLoopNestRangeBuilder( + ArrayRef ivs, ArrayRef ranges) { + builder = std::make_unique(ivs, ranges); +} + +template <> +GenericLoopNestRangeBuilder::GenericLoopNestRangeBuilder( + ArrayRef ivs, ArrayRef ranges) { + SmallVector lbs; + SmallVector ubs; + SmallVector steps; + for (const Value range : ranges) { + assert(range.getType() && "expected linalg.range type"); + assert(range.getDefiningOp() && "need operations to extract range parts"); + RangeOp rangeOp = cast(range.getDefiningOp()); + lbs.emplace_back(ValueHandle(rangeOp.min())); + ubs.emplace_back(ValueHandle(rangeOp.max())); + steps.emplace_back(ValueHandle(rangeOp.step())); + } + builder = std::make_unique(ivs, lbs, ubs, steps); +} + static Value emitOrFoldComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef operandsRef, diff --git a/mlir/test/Dialect/Linalg/affine.mlir b/mlir/test/Dialect/Linalg/affine.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/affine.mlir @@ -0,0 +1,55 @@ +// RUN: mlir-opt %s -convert-linalg-to-affine-loops | FileCheck %s + +// Test that we can lower all the way to LLVM without crashing, don't check results here. +// RUN: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1 + +// CHECK-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1) +// CHECK-DAG: #[[strided3D:.*]] = (d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2) + +// CHECK-DAG: #[[stride2Dilation1:.*]] = (d0, d1) -> (d0 * 2 + d1) + +func @matmul(%arg0: memref, %M: index, %N: index, %K: index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %A = view %arg0[%c0][%M, %K] : memref to memref + %B = view %arg0[%c0][%K, %N] : memref to memref + %C = view %arg0[%c0][%M, %N] : memref to memref + linalg.matmul(%A, %B, %C) : memref, memref, memref + return +} + +// CHECK-LABEL: func @matmul(%{{.*}}: memref, +// CHECK-SAME: [[M:arg[0-9]+]]: index +// CHECK-SAME: [[N:arg[0-9]+]]: index +// CHECK-SAME: [[K:arg[0-9]+]]: index +// CHECK: %[[A:.*]] = std.view %{{.*}}[{{.*}}] : memref to memref +// CHECK: %[[B:.*]] = std.view %{{.*}}[{{.*}}] : memref to memref +// CHECK: %[[C:.*]] = std.view %{{.*}}[{{.*}}] : memref to memref +// CHECK: affine.for %{{.*}} = 0 to %{{.*}} { +// CHECK: affine.for %{{.*}} = 0 to %{{.*}} { +// CHECK: affine.for %{{.*}} = 0 to %{{.*}} { +// CHECK-DAG: %[[a:.*]] = affine.load %[[A]][%{{.*}}, %{{.*}}] : memref +// CHECK-DAG: %[[b:.*]] = affine.load %[[B]][%{{.*}}, %{{.*}}] : memref +// CHECK-DAG: %[[inc:.*]] = mulf %[[a]], %[[b]] : f32 +// CHECK-DAG: %[[c:.*]] = affine.load %[[C]][%{{.*}}, %{{.*}}] : memref +// CHECK-DAG: %[[res:.*]] = addf %[[c]], %[[inc]] : f32 +// CHECK: affine.store %[[res]], %[[C]][%{{.*}}, %{{.*}}] : memref + +func @conv_view3(%arg0: memref, %arg1: memref, %arg2: memref) { + linalg.conv(%arg0, %arg1, %arg2) {strides = [2]}: memref, memref, memref + return +} + +// CHECK-LABEL: func @conv_view3( +// CHECK: %{{.*}}: memref, %{{.*}}: memref, %{{.*}}: memref) { +// CHECK: %[[Z0:.*]] = dim %arg0, 0 : memref +// CHECK: %[[Q:.*]] = dim %arg0, 1 : memref +// CHECK: %[[K:.*]] = dim %arg0, 2 : memref +// CHECK: %[[B:.*]] = dim %arg1, 0 : memref +// CHECK: %[[X0:.*]] = dim %arg2, 1 : memref +// CHECK: affine.for %{{.*}} = 0 to %[[B]] { +// CHECK: affine.for %{{.*}} = 0 to %[[X0]] { +// CHECK: affine.for %{{.*}} = 0 to %[[K]] { +// CHECK: affine.for %{{.*}} = 0 to %[[Q]] { +// CHECK: affine.for %{{.*}} = 0 to %[[Z0]] { +// CHECK: %[[SUM:.*]] = affine.apply #[[stride2Dilation1]](%{{.*}}, %{{.*}})