diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -30,6 +30,18 @@ using edsc::op::operator+; +/// Templated load/store aliases: `load` and `store` resolve to affine_load and +/// affine_store when IndexedValueTy is AffineIndexedValue, and to std_load and +/// std_store for any other IndexedValueTy (e.g. StdIndexedValue). +template +using load = typename std::conditional< + std::is_same::value, affine_load, + std_load>::type; +template +using store = typename std::conditional< + std::is_same::value, affine_store, + std_store>::type; + static SmallVector makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map, @@ -335,21 +347,22 @@ void emitScalarImplementation(ArrayRef allIvs, PoolingMaxOp op) { auto indices = getInputAndOutputIndices(allIvs, op); // Emit scalar form. - Value lhs = std_load(op.output(), indices.outputs); - Value rhs = std_load(op.input(), indices.inputs); + Value lhs = load(op.output(), indices.outputs); + Value rhs = load(op.input(), indices.inputs); using edsc::op::sgt; Value maxValue = std_select(sgt(lhs, rhs), lhs, rhs); - std_store(maxValue, op.output(), indices.outputs); + store(maxValue, op.output(), indices.outputs); } + template void emitScalarImplementation(ArrayRef allIvs, PoolingMinOp op) { auto indices = getInputAndOutputIndices(allIvs, op); // Emit scalar form. 
- Value lhs = std_load(op.output(), indices.outputs); - Value rhs = std_load(op.input(), indices.inputs); + Value lhs = load(op.output(), indices.outputs); + Value rhs = load(op.input(), indices.inputs); using edsc::op::slt; Value minValue = std_select(slt(lhs, rhs), lhs, rhs); - std_store(minValue, op.output(), indices.outputs); + store(minValue, op.output(), indices.outputs); } template void emitScalarImplementation(ArrayRef allIvs, PoolingSumOp op) { diff --git a/mlir/test/Dialect/Linalg/affine.mlir b/mlir/test/Dialect/Linalg/affine.mlir --- a/mlir/test/Dialect/Linalg/affine.mlir +++ b/mlir/test/Dialect/Linalg/affine.mlir @@ -123,3 +123,27 @@ // CHECK: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32 // CHECK: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 // CHECK: affine.store %[[res]], %[[mC]][%[[b]], %[[m]], %[[n]]] : memref + +// CHECK-LABEL: func @pooling_max_min +func @pooling_max_min(%arg0: memref, + %arg1: memref, + %arg2: memref) { + linalg.pooling_max(%arg0, %arg1, %arg2) { strides = [2, 1] }: + memref, memref, memref + linalg.pooling_min(%arg0, %arg1, %arg2) { strides = [2, 1] }: + memref, memref, memref + return +} +// Basic test that the max and then the min pooling body use affine.load and +// affine.store; loops.mlir covers the rest of the lowering. +// CHECK: affine.load +// CHECK-NEXT: affine.load +// CHECK-NEXT: cmpf +// CHECK-NEXT: select +// CHECK-NEXT: affine.store +// The min pooling body. +// CHECK: affine.load +// CHECK-NEXT: affine.load +// CHECK-NEXT: cmpf +// CHECK-NEXT: select +// CHECK-NEXT: affine.store