diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h deleted file mode 100644 --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h +++ /dev/null @@ -1,231 +0,0 @@ -//===- Builders.h - MLIR Declarative Linalg Builders ------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Provides intuitive composable interfaces for building structured MLIR -// snippets in a declarative fashion. -// -//===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_ -#define MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_ - -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/EDSC/Builders.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/Builders.h" - -namespace mlir { -class AffineForOp; -class BlockArgument; - -namespace scf { -class ParallelOp; -} // namespace scf - -namespace edsc { -inline void defaultRegionBuilder(ValueRange args) {} - -/// Build a `linalg.generic` op with the specified `inputs`, `outputs`, -/// `resultTensorsTypes` and `region`. -/// -/// `otherValues` and `otherAttributes` may be passed and will be appended as -/// operands and attributes respectively. -/// -/// Prerequisites: -/// ============= -/// -/// 1. `inputs` may contain StructuredIndexed that capture either buffer or -/// tensor values. -/// 2. `outputs` may contain StructuredIndexed that capture either buffer or -/// tensor values. In the future this will be extended with ranked shape values. -/// 4. `resultTensorTypes` may contain return tensor types. 
-Operation *makeGenericLinalgOp( - ArrayRef iteratorTypes, ArrayRef inputs, - ArrayRef outputs, TypeRange resultTensorTypes, - function_ref regionBuilder = defaultRegionBuilder, - ArrayRef otherValues = {}, ArrayRef otherAttributes = {}); - -namespace ops { -using edsc::StructuredIndexed; - -//===----------------------------------------------------------------------===// -// EDSC builders for linalg generic operations. -//===----------------------------------------------------------------------===// - -/// Build the body of a region to compute a scalar multiply, under the current -/// ScopedContext, at the current insert point. -void mulRegionBuilder(ValueRange args); - -/// Build the body of a region to compute a scalar multiply-accumulate, under -/// the current ScopedContext, at the current insert point. -void macRegionBuilder(ValueRange args); - -/// TODO: In the future we should tie these implementations to something in -/// Tablegen that generates the proper interfaces and the proper sugared named -/// ops. - -/// Build a linalg.pointwise, under the current ScopedContext, at the current -/// insert point, that computes: -/// ``` -/// (i0, ..., in) = (par, ..., par) -/// | -/// | O...(some_subset...(i0, ..., in)) = -/// | some_pointwise_func...(I...(some_other_subset...(i0, ..., in))) -/// ``` -/// -/// This is a very generic entry point that can be configured in many ways to -/// build a perfect loop nest of parallel loops with arbitrarily complex -/// innermost loop code and whatever (explicit) broadcast semantics. -/// -/// This can be used with both out-of-place and in-place semantics. -/// The client is responsible for ensuring the region operations are compatible -/// with in-place semantics and parallelism. - -/// Unary pointwise operation (with broadcast) entry point. 
-using UnaryPointwiseOpBuilder = function_ref; -Operation *linalg_generic_pointwise(UnaryPointwiseOpBuilder unaryOp, - StructuredIndexed I, StructuredIndexed O); - -/// Build a linalg.pointwise with all `parallel` iterators and a region that -/// computes `O = tanh(I)`. The client is responsible for specifying the proper -/// indexings when creating the StructuredIndexed. -Operation *linalg_generic_pointwise_tanh(StructuredIndexed I, - StructuredIndexed O); - -/// Binary pointwise operation (with broadcast) entry point. -using BinaryPointwiseOpBuilder = function_ref; -Operation *linalg_generic_pointwise(BinaryPointwiseOpBuilder binaryOp, - StructuredIndexed I1, StructuredIndexed I2, - StructuredIndexed O); - -/// Build a linalg.pointwise with all `parallel` iterators and a region that -/// computes `O = I1 + I2`. The client is responsible for specifying the proper -/// indexings when creating the StructuredIndexed. -Operation *linalg_generic_pointwise_add(StructuredIndexed I1, - StructuredIndexed I2, - StructuredIndexed O); - -/// Build a linalg.pointwise with all `parallel` iterators and a region that -/// computes `O = max(I1, I2)`. The client is responsible for specifying the -/// proper indexings when creating the StructuredIndexed. -Operation *linalg_generic_pointwise_max(StructuredIndexed I1, - StructuredIndexed I2, - StructuredIndexed O); - -// TODO: Implement more useful pointwise operations on a per-need basis. 
- -using MatmulRegionBuilder = function_ref; - -/// Build a linalg.generic, under the current ScopedContext, at the current -/// insert point, that computes: -/// ``` -/// (m, n, k) = (par, par, seq) -/// | -/// | C(m, n) += A(m, k) * B(k, n) -/// ``` -Operation * -linalg_generic_matmul(Value vA, Value vB, Value vC, - MatmulRegionBuilder regionBuilder = macRegionBuilder); - -/// Build a linalg.generic, under the current ScopedContext, at the current -/// insert point, that computes: -/// ``` -/// (m, n, k) = (par, par, seq) -/// | -/// | D(m, n) = C(m, n) + sum_k(A(m, k) * B(k, n)) -/// ``` -/// and returns the tensor `D`. -Operation * -linalg_generic_matmul(Value vA, Value vB, Value vC, RankedTensorType tD, - MatmulRegionBuilder regionBuilder = macRegionBuilder); - -template -Operation * -linalg_generic_matmul(Container values, - MatmulRegionBuilder regionBuilder = macRegionBuilder) { - assert(values.size() == 3 && "Expected exactly 3 values"); - return linalg_generic_matmul(values[0], values[1], values[2], regionBuilder); -} - -/// Build a linalg.generic, under the current ScopedContext, at the current -/// insert point, that computes: -/// ``` -/// (batch, f, [h, w, ...], [kh, kw, ...], c) = -/// | (par, par, [par, par, ...], [red, red, ...], red) -/// | -/// | O(batch, [h, w, ...], f) += -/// | I(batch, -/// | [ -/// | stride[0] * h + dilations[0] * kh, -/// | stride[1] * w + dilations[1] * kw, ... -/// ], -/// | c) -/// | * -/// | W([kh, kw, ...], c, f) -/// ``` -/// If `dilations` or `strides` are left empty, the default value of `1` is used -/// along each relevant dimension. -/// -/// For now `...` must be empty (i.e. only 2-D convolutions are supported). -/// -// TODO: Extend convolution rank with some template magic. 
-Operation *linalg_generic_conv_nhwc(Value vI, Value vW, Value vO, - ArrayRef strides = {}, - ArrayRef dilations = {}); - -template -Operation *linalg_generic_conv_nhwc(Container values, - ArrayRef strides = {}, - ArrayRef dilations = {}) { - assert(values.size() == 3 && "Expected exactly 3 values"); - return linalg_generic_conv_nhwc(values[0], values[1], values[2], strides, - dilations); -} - -/// Build a linalg.generic, under the current ScopedContext, at the current -/// insert point, that computes: -/// ``` -/// (batch, dm, c, [h, w, ...], [kh, kw, ...]) = -/// | (par, par, par, [par, par, ...], [red, red, ...]) -/// | -/// | O(batch, [h, w, ...], c * depth_multiplier) += -/// | I(batch, -/// | [ -/// | stride[0] * h + dilations[0] * kh, -/// | stride[1] * w + dilations[1] * kw, ... -/// ], -/// | c) -/// | * -/// | W([kh, kw, ...], c, depth_multiplier) -/// ``` -/// If `dilations` or `strides` are left empty, the default value of `1` is used -/// along each relevant dimension. -/// -/// For now `...` must be empty (i.e. only 2-D convolutions are supported). -/// -// TODO: Extend convolution rank with some template magic. 
-Operation *linalg_generic_dilated_conv_nhwc(Value vI, Value vW, Value vO, - int depth_multiplier = 1, - ArrayRef strides = {}, - ArrayRef dilations = {}); - -template -Operation *linalg_generic_dilated_conv_nhwc(Container values, - int depth_multiplier, - ArrayRef strides = {}, - ArrayRef dilations = {}) { - assert(values.size() == 3 && "Expected exactly 3 values"); - return linalg_generic_dilated_conv_nhwc(values[0], values[1], values[2], - depth_multiplier, strides, dilations); -} - -} // namespace ops -} // namespace edsc -} // namespace mlir - -#endif // MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h b/mlir/include/mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h --- a/mlir/include/mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h @@ -1,26 +0,0 @@ -//===- FoldedIntrinsics.h - MLIR EDSC Intrinsics for Linalg -----*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_LINALG_EDSC_FOLDEDINTRINSICS_H_ -#define MLIR_DIALECT_LINALG_EDSC_FOLDEDINTRINSICS_H_ - -#include "mlir/Dialect/Linalg/EDSC/Builders.h" -#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" - -#include "mlir/Transforms/FoldUtils.h" - -namespace mlir { -namespace edsc { -namespace intrinsics { -} // namespace intrinsics -} // namespace edsc -} // namespace mlir - -#endif // MLIR_DIALECT_LINALG_EDSC_FOLDEDINTRINSICS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h deleted file mode 100644 --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h +++ /dev/null @@ -1,33 +0,0 @@ -//===- Intrinsics.h - MLIR EDSC Intrinsics for Linalg -----------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_ -#define MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_ - -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" -#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" - -namespace mlir { -namespace edsc { -namespace intrinsics { - -using linalg_copy = OperationBuilder; -using linalg_dot = OperationBuilder; -using linalg_fill = OperationBuilder; -using linalg_init_tensor = ValueBuilder; -using linalg_matmul = OperationBuilder; -using linalg_matvec = OperationBuilder; -using linalg_vecmat = OperationBuilder; -using linalg_range = ValueBuilder; -using linalg_reshape = ValueBuilder; -using linalg_yield = OperationBuilder; - -} // namespace intrinsics -} // namespace edsc -} // namespace mlir - -#endif // MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -48,7 +48,8 @@ constexpr const static ::llvm::StringLiteral kInplaceableAttrName = "linalg.inplaceable"; - using RegionBuilderFunType = llvm::function_ref; + using RegionBuilderFunType = + llvm::function_ref; RegionBuilderFunType getRegionBuilder(StringRef name) { return namedStructuredOpRegionBuilders.lookup(name); } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -17,6 +17,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/ViewLikeInterface.h" diff --git 
a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -1233,7 +1233,7 @@ Returns a null function if this named op does not define a region builder. }], - /*retTy=*/"std::function", + /*retTy=*/"std::function", /*methodName=*/"getRegionBuilder", (ins), [{ return ConcreteOp::getRegionBuilder(); }] diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc @@ -1,7 +1,7 @@ ods_def implements_interface : def matmul_column_major(A: f32(K, M), B: f32(N, K)) -> (C: f32(N, M)) { - C(n, m) = std_addf(C(n, m), std_mulf(A(k, m), B(n, k))); + C(n, m) = AddFOp(C(n, m), MulFOp(A(k, m), B(n, k))); } ods_def @@ -9,146 +9,146 @@ def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) { // TODO: ideally something closer to // C(m, n) += cast(A(m, k)) * cast(B(k, n)) - C(m, n) = std_addi(C(m, n), std_muli(std_sexti32(A(m, k)), std_sexti32(B(k, n)))); + C(m, n) = AddIOp(C(m, n), MulIOp(SignExtendIOp32(A(m, k)), SignExtendIOp32(B(k, n)))); } ods_def implements_interface : def matmul_i16_i16_i32(A: i16(M, K), B: i16(K, N)) -> (C: i32(M, N)) { - C(m, n) = std_addi(C(m, n), std_muli(std_sexti32(A(m, k)), std_sexti32(B(k, n)))); + C(m, n) = AddIOp(C(m, n), MulIOp(SignExtendIOp32(A(m, k)), SignExtendIOp32(B(k, n)))); } ods_def implements_interface : def matmul_i32_i32_i32(A: i32(M, K), B: i32(K, N)) -> (C: i32(M, N)) { - C(m, n) = std_addi(C(m, n), std_muli(A(m, k), B(k, n))); + C(m, n) = AddIOp(C(m, n), MulIOp(A(m, k), B(k, n))); } ods_def implements_interface : def matvec_i8_i8_i32(A: i8(M, N), y: i8(N)) -> (x: i32(M)) { - x(m) = std_addi(x(m), 
std_muli(std_sexti32(A(m, n)), std_sexti32(y(n)))); + x(m) = AddIOp(x(m), MulIOp(SignExtendIOp32(A(m, n)), SignExtendIOp32(y(n)))); } ods_def implements_interface : def matvec_i16_i16_i32(A: i16(M, N), y: i16(N)) -> (x: i32(M)) { - x(m) = std_addi(x(m), std_muli(std_sexti32(A(m, n)), std_sexti32(y(n)))); + x(m) = AddIOp(x(m), MulIOp(SignExtendIOp32(A(m, n)), SignExtendIOp32(y(n)))); } ods_def implements_interface : def matvec_i32_i32_i32(A: i32(M, N), y: i32(N)) -> (x: i32(M)) { - x(m) = std_addi(x(m), std_muli(A(m, n), y(n))); + x(m) = AddIOp(x(m), MulIOp(A(m, n), y(n))); } ods_def implements_interface : def vecmat_i8_i8_i32(y: i8(M), A: i8(M, N)) -> (x: i32(N)) { - x(n) = std_addi(x(n), std_muli(std_sexti32(y(m)), std_sexti32(A(m, n)))); + x(n) = AddIOp(x(n), MulIOp(SignExtendIOp32(y(m)), SignExtendIOp32(A(m, n)))); } ods_def implements_interface : def vecmat_i16_i16_i32(y: i16(M), A: i16(M, N)) -> (x: i32(N)) { - x(n) = std_addi(x(n), std_muli(std_sexti32(y(m)), std_sexti32(A(m, n)))); + x(n) = AddIOp(x(n), MulIOp(SignExtendIOp32(y(m)), SignExtendIOp32(A(m, n)))); } ods_def implements_interface : def vecmat_i32_i32_i32(y: i32(M), A: i32(M, N)) -> (x: i32(N)) { - x(n) = std_addi(x(n), std_muli(y(m), A(m, n))); + x(n) = AddIOp(x(n), MulIOp(y(m), A(m, n))); } ods_def implements_interface : def dot_i8_i8_i32(A: i8(M), B: i8(M)) -> (C: i32()) { - C() = std_addi(C(), std_muli(std_sexti32(A(m)), std_sexti32(B(m)))); + C() = AddIOp(C(), MulIOp(SignExtendIOp32(A(m)), SignExtendIOp32(B(m)))); } ods_def implements_interface : def dot_i16_i16_i32(A: i16(M), B: i16(M)) -> (C: i32()) { - C() = std_addi(C(), std_muli(std_sexti32(A(m)), std_sexti32(B(m)))); + C() = AddIOp(C(), MulIOp(SignExtendIOp32(A(m)), SignExtendIOp32(B(m)))); } ods_def implements_interface : def dot_i32_i32_i32(A: i32(M), B: i32(M)) -> (C: i32()) { - C() = std_addi(C(), std_muli(A(m), B(m))); + C() = AddIOp(C(), MulIOp(A(m), B(m))); } ods_def implements_interface : def batch_matmul_i8_i8_i32(A: i8(Batch, 
M, K), B: i8(Batch, K, N)) -> (C: i32(Batch, M, N)) { C(b, m, n) = - std_addi(C(b, m, n), std_muli(std_sexti32(A(b, m, k)), std_sexti32(B(b, k, n)))); + AddIOp(C(b, m, n), MulIOp(SignExtendIOp32(A(b, m, k)), SignExtendIOp32(B(b, k, n)))); } ods_def implements_interface : def batch_matmul_i16_i16_i32(A: i16(Batch, M, K), B: i16(Batch, K, N)) -> (C: i32(Batch, M, N)) { C(b, m, n) = - std_addi(C(b, m, n), std_muli(std_sexti32(A(b, m, k)), std_sexti32(B(b, k, n)))); + AddIOp(C(b, m, n), MulIOp(SignExtendIOp32(A(b, m, k)), SignExtendIOp32(B(b, k, n)))); } ods_def implements_interface : def batch_matmul_i32_i32_i32(A: i32(Batch, M, K), B: i32(Batch, K, N)) -> (C: i32(Batch, M, N)) { - C(b, m, n) = std_addi(C(b, m, n), std_muli(A(b, m, k), B(b, k, n))); + C(b, m, n) = AddIOp(C(b, m, n), MulIOp(A(b, m, k), B(b, k, n))); } ods_def: def conv_1d(I: f32(W), K: f32(KW)) -> (O: f32(W)) { - O(w) = std_addf(O(w), std_mulf(I(w + kw), K(kw))); + O(w) = AddFOp(O(w), MulFOp(I(w + kw), K(kw))); } ods_def: def conv_1d_nwc(I: f32(N, W, C), K: f32(F, KW, C)) -> (O: f32(N, W, F)) { - O(n, w, f) = std_addf(O(n, w, f), std_mulf(I(n, w + kw, c), K(f, kw, c))); + O(n, w, f) = AddFOp(O(n, w, f), MulFOp(I(n, w + kw, c), K(f, kw, c))); } ods_def: def conv_1d_ncw(I: f32(N, C, W), K: f32(F, C, KW)) -> (O: f32(N, F, W)) { - O(n, f, w) = std_addf(O(n, f, w), std_mulf(I(n, c, w + kw), K(f, c, kw))); + O(n, f, w) = AddFOp(O(n, f, w), MulFOp(I(n, c, w + kw), K(f, c, kw))); } ods_def: def conv_2d(I: f32(H, W), K: f32(KH, KW)) -> (O: f32(H, W)) { - O(h, w) = std_addf(O(h, w), std_mulf(I(h + kh, w + kw), K(kh, kw))); + O(h, w) = AddFOp(O(h, w), MulFOp(I(h + kh, w + kw), K(kh, kw))); } ods_def: def conv_2d_nhwc(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F)) { - O(n, h, w, f) = std_addf( - O(n, h, w, f), std_mulf(I(n, h + kh, w + kw, c), K(f, kh, kw, c))); + O(n, h, w, f) = AddFOp( + O(n, h, w, f), MulFOp(I(n, h + kh, w + kw, c), K(f, kh, kw, c))); } ods_def: def conv_2d_nchw(I: f32(N, C, 
H, W), K: f32(F, C, KH, KW)) -> (O: f32(N, F, H, W)) { - O(n, f, h, w) = std_addf( - O(n, f, h, w), std_mulf(I(n, c, h + kh, w + kw), K(f, c, kh, kw))); + O(n, f, h, w) = AddFOp( + O(n, f, h, w), MulFOp(I(n, c, h + kh, w + kw), K(f, c, kh, kw))); } ods_def: def conv_3d(I: f32(D, H, W), K: f32(KD, KH, KW)) -> (O: f32(D, H, W)) { - O(d, h, w) = std_addf( - O(d, h, w), std_mulf(I(d + kd, h + kh, w + kw), K(kd, kh, kw))); + O(d, h, w) = AddFOp( + O(d, h, w), MulFOp(I(d + kd, h + kh, w + kw), K(kd, kh, kw))); } ods_def: def conv_3d_ndhwc(I: f32(N, D, H, W, C), K: f32(F, KD, KH, KW, C)) -> (O: f32(N, D, H, W, F)) { - O(n, d, h, w, f) = std_addf( + O(n, d, h, w, f) = AddFOp( O(n, d, h, w, f), - std_mulf(I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c))); + MulFOp(I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c))); } ods_def: def conv_3d_ncdhw(I: f32(N, C, D, H, W), K: f32(F, C, KD, KH, KW)) -> (O: f32(N, F, D, H, W)) { - O(n, f, d, h, w) = std_addf( + O(n, f, d, h, w) = AddFOp( O(n, f, d, h, w), - std_mulf(I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw))); + MulFOp(I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw))); } ods_def: @@ -162,9 +162,9 @@ `F` and generates output `O` using the following computation: ``` - O(n, oh, ow, ci, co) = std_addf( + O(n, oh, ow, ci, co) = AddFOp( O(n, oh, ow, ci, co), - std_mulf(I(n, oh * strides[0] + kh, ow * strides[1] + kw, ci), + MulFOp(I(n, oh * strides[0] + kh, ow * strides[1] + kw, ci), K(kh, kw, ci, co))); ``` @@ -184,9 +184,9 @@ Linalg reshape op which collapses `CI` and `CO` into one dimension. 
""" { - O(n, oh, ow, ci, co) = std_addf( + O(n, oh, ow, ci, co) = AddFOp( O(n, oh, ow, ci, co), - std_mulf(I(n, oh * strides[0] + kh, ow * strides[1] + kw, ci), + MulFOp(I(n, oh * strides[0] + kh, ow * strides[1] + kw, ci), K(kh, kw, ci, co))); } @@ -201,9 +201,9 @@ `F` and generates output `O` using the following computation: ``` -O(n, oh, ow, c) = std_addf( +O(n, oh, ow, c) = AddFOp( O(n, oh, ow, c), - std_mulf(I(n, oh * strides[0] + kh, ow * strides[1] + kw, c), + MulFOp(I(n, oh * strides[0] + kh, ow * strides[1] + kw, c), K(kh, kw, c))); ``` @@ -221,9 +221,9 @@ Note: this op only supports channel multiplier == 1. """ { - O(n, oh, ow, c) = std_addf( + O(n, oh, ow, c) = AddFOp( O(n, oh, ow, c), - std_mulf(I(n, oh * strides[0] + kh, ow * strides[1] + kw, c), + MulFOp(I(n, oh * strides[0] + kh, ow * strides[1] + kw, c), K(kh, kw, c))); } @@ -239,9 +239,9 @@ order of (`N`, `W`, `F`, `KW`, `C`). """ { - O(n, w, f) = std_addf( + O(n, w, f) = AddFOp( O(n, w, f), - std_mulf(I(n, w * strides[0] + kw * dilations[0], c), K(kw, c, f))); + MulFOp(I(n, w * strides[0] + kw * dilations[0], c), K(kw, c, f))); } ods_def: @@ -256,9 +256,9 @@ order of (`N`, `F`, `W`, `KW`, `C`). """ { - O(n, f, w) = std_addf( + O(n, f, w) = AddFOp( O(n, f, w), - std_mulf(I(n, c, w * strides[0] + kw * dilations[0]), K(kw, c, f))); + MulFOp(I(n, c, w * strides[0] + kw * dilations[0]), K(kw, c, f))); } ods_def: @@ -273,8 +273,8 @@ order of (`N`, `H`, `W`, `F`, `KH`, `KW`, `C`). """ { - O(n, h, w, f) = std_addf( - O(n, h, w, f), std_mulf(I(n, h * strides[0] + kh * dilations[0], + O(n, h, w, f) = AddFOp( + O(n, h, w, f), MulFOp(I(n, h * strides[0] + kh * dilations[0], w * strides[1] + kw * dilations[1], c), K(kh, kw, c, f))); } @@ -293,8 +293,8 @@ order of (`N`, `F`, `H`, `W`, `KH`, `KW`, `C`). 
""" { - O(n, f, h, w) = std_addf( - O(n, f, h, w), std_mulf(I(n, c, h * strides[0] + kh * dilations[0], + O(n, f, h, w) = AddFOp( + O(n, f, h, w), MulFOp(I(n, c, h * strides[0] + kh * dilations[0], w * strides[1] + kw * dilations[1]), K(kh, kw, c, f))); } @@ -313,8 +313,8 @@ order of (`N`, `D`, `H`, `W`, `F`, `KD`, `KH`, `KW`, `C`). """ { - O(n, d, h, w, f) = std_addf( - O(n, d, h, w, f), std_mulf(I(n, d * strides[0] + kd * dilations[0], + O(n, d, h, w, f) = AddFOp( + O(n, d, h, w, f), MulFOp(I(n, d * strides[0] + kd * dilations[0], h * strides[1] + kh * dilations[1], w * strides[2] + kw * dilations[2], c), K(kd, kh, kw, c, f))); @@ -334,8 +334,8 @@ order of (`N`, `F`, `D`, `H`, `W`, `KD`, `KH`, `KW`, `C`). """ { - O(n, f, d, h, w) = std_addf( - O(n, f, d, h, w), std_mulf(I(n, c, d * strides[0] + kd * dilations[0], + O(n, f, d, h, w) = AddFOp( + O(n, f, d, h, w), MulFOp(I(n, c, d * strides[0] + kd * dilations[0], h * strides[1] + kh * dilations[1], w * strides[2] + kw * dilations[2]), K(kd, kh, kw, c, f))); @@ -347,7 +347,7 @@ -> (O: f32(N, OH, OW, C)) attr(strides: 2xi64, dilations: 2xi64) { - O(n, oh, ow, c) = std_addf(O(n, oh, ow, c), + O(n, oh, ow, c) = AddFOp(O(n, oh, ow, c), I(n, oh * strides[0] + kh * dilations[0], ow * strides[1] + kw * dilations[1], c)); } @@ -359,7 +359,7 @@ attr(strides: 2xi64, dilations: 2xi64) { O(n, oh, ow, c) = - std_select(std_cmpi_sgt(I(n, oh * strides[0] + kh * dilations[0], + SelectOp(CmpIOpSGT(I(n, oh * strides[0] + kh * dilations[0], ow * strides[1] + kw * dilations[1], c), O(n, oh, ow, c)), I(n, oh * strides[0] + kh * dilations[0], @@ -374,7 +374,7 @@ attr(strides: 2xi64, dilations: 2xi64) { O(n, oh, ow, c) = - std_select(std_cmpi_sgt(I(n, oh * strides[0] + kh * dilations[0], + SelectOp(CmpIOpSGT(I(n, oh * strides[0] + kh * dilations[0], ow * strides[1] + kw * dilations[1], c), O(n, oh, ow, c)), I(n, oh * strides[0] + kh * dilations[0], @@ -389,7 +389,7 @@ attr(strides: 2xi64, dilations: 2xi64) { O(n, oh, ow, c) = - 
std_select(std_cmpi_sgt(I(n, oh * strides[0] + kh * dilations[0], + SelectOp(CmpIOpSGT(I(n, oh * strides[0] + kh * dilations[0], ow * strides[1] + kw * dilations[1], c), O(n, oh, ow, c)), I(n, oh * strides[0] + kh * dilations[0], @@ -404,7 +404,7 @@ attr(strides: 2xi64, dilations: 2xi64) { O(n, oh, ow, c) = - std_select(std_cmpf_ogt(I(n, oh * strides[0] + kh * dilations[0], + SelectOp(CmpFOpOGT(I(n, oh * strides[0] + kh * dilations[0], ow * strides[1] + kw * dilations[1], c), O(n, oh, ow, c)), I(n, oh * strides[0] + kh * dilations[0], @@ -419,7 +419,7 @@ attr(strides: 2xi64, dilations: 2xi64) { O(n, oh, ow, c) = - std_select(std_cmpf_olt(I(n, oh * strides[0] + kh * dilations[0], + SelectOp(CmpFOpOLT(I(n, oh * strides[0] + kh * dilations[0], ow * strides[1] + kw * dilations[1], c), O(n, oh, ow, c)), I(n, oh * strides[0] + kh * dilations[0], diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -15,7 +15,6 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" @@ -41,7 +40,7 @@ // TOFO: allow an extra ValueRange to specify an indexing and allow // non-hyperrectangular shapes. using LoopRangeBuilder = - std::function(OpBuilder &, Location)>; + std::function(ImplicitLocOpBuilder)>; /// Provide a very simple inference procedure to build the loop ranges from the /// op and its operands. 
This only works with permutation affine maps and diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -160,8 +160,10 @@ Value getSource() { return input();} Value getTarget() { return output(); } - static void regionBuilder(Block &block, ValueRange captures); - static std::function + static void regionBuilder( + ImplicitLocOpBuilder &b, Block &block, ValueRange captures); + static std::function< + void(ImplicitLocOpBuilder &b, Block &block, ValueRange captures)> getRegionBuilder() { return ®ionBuilder; } @@ -205,8 +207,10 @@ extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)}); } - static void regionBuilder(Block &block, ValueRange captures); - static std::function + static void regionBuilder( + ImplicitLocOpBuilder &b, Block &block, ValueRange captures); + static std::function< + void(ImplicitLocOpBuilder &b, Block &block, ValueRange captures)> getRegionBuilder() { return ®ionBuilder; } @@ -295,8 +299,9 @@ return padding().getValue().getValue({i, 1}); } - static std::function getRegionBuilder() - { + static std::function< + void(ImplicitLocOpBuilder &b, Block &block, ValueRange captures)> + getRegionBuilder() { return nullptr; } }]; @@ -543,7 +548,9 @@ library_call()->str() : "op_has_no_registered_library_name"; } - static std::function getRegionBuilder() { + static std::function< + void(ImplicitLocOpBuilder &b, Block &block, ValueRange captures)> + getRegionBuilder() { return nullptr; } }]; @@ -551,7 +558,6 @@ let parser = [{ return ::parseGenericOp(parser, result); }]; } -/// Index-free GenericOp. 
def GenericOp : GenericOpBase<"generic"> { let description = [{ Generic Linalg op form where the key properties of the computation are diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h @@ -13,6 +13,7 @@ #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Types.h" #include "llvm/ADT/StringMap.h" diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -11,7 +11,6 @@ #include "mlir/Dialect/Affine/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" -#include "mlir/Dialect/Linalg/EDSC/Builders.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/MemRef/EDSC/Intrinsics.h" #include "mlir/Dialect/SCF/SCF.h" diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -8,7 +8,6 @@ #include "mlir-c/Dialect/Linalg.h" #include "mlir/CAPI/Registration.h" -#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" using namespace mlir; @@ -38,12 +37,11 @@ for (auto t : linalgOp.getShapedOperandTypes()) argTypes.push_back(getElementTypeOrSelf(t)); - OpBuilder b(op->getContext()); + ImplicitLocOpBuilder b(op->getLoc(), op->getContext()); Region &region = op->getRegion(0); Block *body = b.createBlock(&region, /*insertPt=*/{}, argTypes); b.setInsertionPointToStart(body); - mlir::edsc::ScopedContext scope(b, op->getLoc()); - fun(*body, captures); + fun(b, *body, captures); } MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect) diff --git 
a/mlir/lib/Dialect/Linalg/CMakeLists.txt b/mlir/lib/Dialect/Linalg/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/CMakeLists.txt @@ -1,5 +1,4 @@ add_subdirectory(Analysis) -add_subdirectory(EDSC) add_subdirectory(IR) add_subdirectory(Transforms) add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ /dev/null @@ -1,255 +0,0 @@ -//===- Builders.cpp - MLIR Declarative Linalg Builders --------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/IR/Builders.h" -#include "mlir/Dialect/Affine/EDSC/Intrinsics.h" -#include "mlir/Dialect/Linalg/EDSC/Builders.h" -#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" -#include "mlir/Dialect/Math/EDSC/Intrinsics.h" -#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/IR/AffineExpr.h" - -using namespace mlir; -using namespace mlir::edsc; -using namespace mlir::edsc::intrinsics; -using namespace mlir::linalg; -using namespace mlir::scf; - -Operation *mlir::edsc::makeGenericLinalgOp( - ArrayRef iteratorTypes, ArrayRef inputs, - ArrayRef outputs, TypeRange resultTensorTypes, - function_ref regionBuilder, ArrayRef otherValues, - ArrayRef otherAttributes) { - // Build maps - SmallVector, 4> exprsList; - exprsList.reserve(inputs.size() + outputs.size()); - - for (auto container : {inputs, outputs}) - for (const StructuredIndexed &s : container) - exprsList.emplace_back(s.getExprs().begin(), s.getExprs().end()); - auto maps = AffineMap::inferFromExprList(exprsList); - - SmallVector inputValues, outputValues; - 
inputValues.reserve(inputs.size()); - outputValues.reserve(outputs.size()); - std::copy(inputs.begin(), inputs.end(), std::back_inserter(inputValues)); - std::copy(outputs.begin(), outputs.end(), std::back_inserter(outputValues)); - - auto iteratorStrTypes = - llvm::to_vector<8>(llvm::map_range(iteratorTypes, toString)); - // clang-format off - auto *op = - edsc::ScopedContext::getBuilderRef() - .create( - edsc::ScopedContext::getLocation(), - resultTensorTypes, - inputValues, - outputValues, - maps, - iteratorStrTypes, - ""/*doc*/, - ""/*library_call*/) - .getOperation(); - // clang-format on - - using namespace edsc; - SmallVector blockTypes; - blockTypes.reserve(inputs.size() + outputs.size()); - for (auto container : {inputs, outputs}) - for (const StructuredIndexed &s : container) - blockTypes.push_back(getElementTypeOrSelf(s.getType())); - - assert(op->getNumRegions() == 1); - assert(op->getRegion(0).empty()); - OpBuilder opBuilder(op); - ScopedContext scope(opBuilder, op->getLoc()); - buildInNewBlock(op->getRegion(0), blockTypes, regionBuilder); - assert(llvm::hasSingleElement(op->getRegion(0))); - return op; -} - -void mlir::edsc::ops::mulRegionBuilder(ValueRange args) { - using edsc::op::operator+; - using edsc::op::operator*; - assert(args.size() == 2 && "expected 2 block arguments"); - Value a(args[0]), b(args[1]); - linalg_yield(a * b); -} - -void mlir::edsc::ops::macRegionBuilder(ValueRange args) { - using edsc::op::operator+; - using edsc::op::operator*; - assert(args.size() == 3 && "expected 3 block arguments"); - Value a(args[0]), b(args[1]), c(args[2]); - linalg_yield(c + a * b); -} - -Operation *mlir::edsc::ops::linalg_generic_pointwise( - UnaryPointwiseOpBuilder unaryOp, StructuredIndexed I, StructuredIndexed O) { - SmallVector iterTypes(O.getExprs().size(), - IteratorType::Parallel); - auto fun = [&unaryOp](ValueRange args) { - assert(!args.empty() && "expected >= 1 block arguments"); - Value a(args[0]); - linalg_yield(unaryOp(a)); - }; - if 
(O.getType().isa()) - return makeGenericLinalgOp(iterTypes, /*inputs=*/{I}, /*outputs=*/{O}, - /*resultTensorTypes=*/{O}, fun); - return makeGenericLinalgOp(iterTypes, /*inputs=*/{I}, /*outputs=*/{O}, - /*resultTensorTypes=*/{}, fun); -} - -Operation *mlir::edsc::ops::linalg_generic_pointwise_tanh(StructuredIndexed I, - StructuredIndexed O) { - UnaryPointwiseOpBuilder unOp([](Value a) -> Value { return math_tanh(a); }); - return linalg_generic_pointwise(unOp, I, O); -} - -/// Binary pointwise operation (with broadcast) entry point. -Operation *mlir::edsc::ops::linalg_generic_pointwise( - BinaryPointwiseOpBuilder binaryOp, StructuredIndexed I1, - StructuredIndexed I2, StructuredIndexed O) { - SmallVector iterTypes(O.getExprs().size(), - IteratorType::Parallel); - auto fun = [&binaryOp](ValueRange args) { - assert(args.size() >= 2 && "expected >= 2 block arguments"); - Value a(args[0]), b(args[1]); - linalg_yield(binaryOp(a, b)); - }; - if (O.getType().isa()) - return makeGenericLinalgOp(iterTypes, /*inputs=*/{I1, I2}, /*outputs=*/{O}, - /*resultTensorTypes=*/{O}, fun); - return makeGenericLinalgOp(iterTypes, /*inputs=*/{I1, I2}, - /*outputs=*/{O}, /*resultTensorTypes=*/{}, fun); -} - -Operation *mlir::edsc::ops::linalg_generic_pointwise_add(StructuredIndexed I1, - StructuredIndexed I2, - StructuredIndexed O) { - using edsc::op::operator+; - BinaryPointwiseOpBuilder binOp( - [](Value a, Value b) -> Value { return a + b; }); - return linalg_generic_pointwise(binOp, I1, I2, O); -} - -Operation *mlir::edsc::ops::linalg_generic_pointwise_max(StructuredIndexed I1, - StructuredIndexed I2, - StructuredIndexed O) { - BinaryPointwiseOpBuilder binOp([](Value a, Value b) -> Value { - using edsc::op::sgt; - return std_select(sgt(a, b), a, b); - }); - return linalg_generic_pointwise(binOp, I1, I2, O); -} - -Operation * -mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, Value vC, - MatmulRegionBuilder regionBuilder) { - // clang-format off - AffineExpr m, n, k; - 
bindDims(ScopedContext::getContext(), m, n, k); - StructuredIndexed A(vA), B(vB), C(vC); - return makeGenericLinalgOp( - {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction}, - /*inputs=*/{A({m, k}), B({k, n})}, - /*outputs=*/{C({m, n})}, - /*resultTensorTypes=*/{}, - regionBuilder); - // clang-format on -} - -Operation * -mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, Value vC, - RankedTensorType tD, - MatmulRegionBuilder regionBuilder) { - // clang-format off - AffineExpr m, n, k; - bindDims(ScopedContext::getContext(), m, n, k); - StructuredIndexed A(vA), B(vB), C(vC), D(tD); - return makeGenericLinalgOp( - {IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction}, - /*inputs=*/{A({m, k}), B({k, n})}, - /*outputs=*/{C({m, n})}, - /*resultTensorTypes=*/{D({m, n})}, - regionBuilder); - // clang-format on -} - -Operation *mlir::edsc::ops::linalg_generic_conv_nhwc(Value vI, Value vW, - Value vO, - ArrayRef strides, - ArrayRef dilations) { - MLIRContext *ctx = ScopedContext::getContext(); - // TODO: some template magic to make everything rank-polymorphic. - assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm"); - assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm"); - - // Some short names. - auto par = IteratorType::Parallel; - auto red = IteratorType::Reduction; - auto s = strides; - auto d = dilations; - - AffineExpr b, f, h, w, kh, kw, c; - bindDims(ctx, b, f, h, w, kh, kw, c); - unsigned numDims = c.cast().getPosition() + 1; - StructuredIndexed I(vI), W(vW), O(vO); - // clang-format off - return makeGenericLinalgOp( - {par, par, par, par, red, red, red}, - /*inputs=*/{ - I({b, - // Roundtrip to flattened form to serve as canonicalization and ensure - // consistent ordering of subexpressions. 
- simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0), - simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0), - c}), - W({kh, kw, c, f}) }, - /*outputs=*/{ O({b, h, w, f}) }, - /*resultTensorTypes=*/{}, - macRegionBuilder); - // clang-format on -} - -Operation *mlir::edsc::ops::linalg_generic_dilated_conv_nhwc( - Value vI, Value vW, Value vO, int depth_multiplier, ArrayRef strides, - ArrayRef dilations) { - MLIRContext *ctx = ScopedContext::getContext(); - // TODO: some template magic to make everything rank-polymorphic. - assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm"); - assert((strides.empty() || strides.size() == 2) && "only 2-D conv atm"); - - // Some short names. - auto par = IteratorType::Parallel; - auto red = IteratorType::Reduction; - auto s = strides; - auto d = dilations; - - // clang-format off - AffineExpr b, dm, c, h, w, kh, kw; - bindDims(ctx, b, dm, c, h, w, kh, kw); - unsigned numDims = kw.cast().getPosition() + 1; - StructuredIndexed I(vI), W(vW), O(vO); - return makeGenericLinalgOp( - {par, par, par, par, par, red, red}, - /*inputs=*/{ - I({b, - // Roundtrip to flattened form to serve as canonicalization and ensure - // consistent ordering of subexpressions. 
- simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0), - simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0), - c}), - W({kh, kw, c, dm})}, - /*outputs=*/{ - O({b, h, w, simplifyAffineExpr(c * depth_multiplier + dm, numDims, 0)})}, - /*resultTensorTypes=*/{}, - macRegionBuilder); - // clang-format on -} diff --git a/mlir/lib/Dialect/Linalg/EDSC/CMakeLists.txt b/mlir/lib/Dialect/Linalg/EDSC/CMakeLists.txt deleted file mode 100644 --- a/mlir/lib/Dialect/Linalg/EDSC/CMakeLists.txt +++ /dev/null @@ -1,17 +0,0 @@ -add_mlir_dialect_library(MLIRLinalgEDSC - Builders.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg - - LINK_LIBS PUBLIC - MLIREDSC - MLIRIR - MLIRAffine - MLIRAffineEDSC - MLIRLinalg - MLIRMath - MLIRMemRef - MLIRSCF - MLIRStandard - ) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -13,7 +13,6 @@ #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -340,10 +339,10 @@ //===----------------------------------------------------------------------===// // CopyOp //===----------------------------------------------------------------------===// -void CopyOp::regionBuilder(Block &block, ValueRange captures) { - using namespace edsc::intrinsics; +void CopyOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, + ValueRange captures) { assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args"); - (linalg_yield(block.getArgument(0))); + b.create(block.getArgument(0)); } void CopyOp::build(OpBuilder &builder, OperationState &result, Value input, @@ -420,10 +419,10 @@ //===----------------------------------------------------------------------===// // 
FillOp //===----------------------------------------------------------------------===// -void FillOp::regionBuilder(Block &block, ValueRange captures) { - using namespace edsc::intrinsics; +void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, + ValueRange captures) { assert(captures.size() == 1 && "FillOp regionBuilder expects 1 capture"); - (linalg_yield(captures)); + b.create(captures); } void FillOp::build(OpBuilder &builder, OperationState &result, Value output, @@ -2769,8 +2768,8 @@ } opBuilder.setInsertionPointToStart(body); - mlir::edsc::ScopedContext scope(opBuilder, opBuilder.getUnknownLoc()); - NamedStructuredOpType::regionBuilder(*body, captures); + ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder); + NamedStructuredOpType::regionBuilder(b, *body, captures); // indexing_maps is an auto-generated method. diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -27,11 +27,9 @@ MLIRAffineUtils MLIRAnalysis MLIRComplex - MLIREDSC MLIRIR MLIRMemRef MLIRLinalgAnalysis - MLIRLinalgEDSC MLIRLinalg MLIRLinalgUtils MLIRSCF diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SmallVector.h" @@ -60,8 +61,9 @@ namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputs(), indexingMaps, iterators, [&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) { - edsc::ScopedContext scope(bodyBuilder, loc); - 
regionBuilder(*bodyBuilder.getBlock(), /*captures=*/{}); + ImplicitLocOpBuilder b(loc, bodyBuilder); + regionBuilder(b, *bodyBuilder.getBlock(), + /*captures=*/{}); }); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -11,17 +11,13 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" -#include "mlir/Dialect/Affine/EDSC/Intrinsics.h" #include "mlir/Dialect/Complex/IR/Complex.h" -#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h" #include "mlir/Dialect/SCF/SCF.h" -#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" @@ -29,11 +25,10 @@ #include "mlir/Support/LLVM.h" #include "mlir/Transforms/FoldUtils.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/CommandLine.h" using namespace mlir; -using namespace mlir::edsc; -using namespace mlir::edsc::intrinsics; using namespace mlir::linalg; using namespace mlir::scf; @@ -80,10 +75,11 @@ /// no call back to do so is provided. The default is to allocate a /// memref<..xi8> and return a view to get a memref type of shape /// boundingSubViewSize. 
-static Optional defaultAllocBufferCallBack( - const LinalgPromotionOptions &options, OpBuilder &builder, - memref::SubViewOp subView, ArrayRef boundingSubViewSize, - bool dynamicBuffers, Optional alignment, DataLayout &layout) { +static Optional +defaultAllocBufferCallBack(const LinalgPromotionOptions &options, + OpBuilder &builder, memref::SubViewOp subView, + ArrayRef boundingSubViewSize, + Optional alignment, DataLayout &layout) { ShapedType viewType = subView.getType(); ImplicitLocOpBuilder b(subView.getLoc(), builder); auto zero = b.createOrFold(0); @@ -108,10 +104,10 @@ static LogicalResult defaultDeallocBufferCallBack(const LinalgPromotionOptions &options, OpBuilder &b, Value fullLocalView) { - auto viewOp = fullLocalView.getDefiningOp(); - assert(viewOp && "expected full local view to be a ViewOp"); - if (!options.useAlloca) - memref_dealloc(viewOp.source()); + if (!options.useAlloca) { + auto viewOp = cast(fullLocalView.getDefiningOp()); + b.create(viewOp.source().getLoc(), viewOp.source()); + } return success(); } @@ -125,7 +121,7 @@ LinalgOpInstancePromotionOptions(LinalgOp op, const LinalgPromotionOptions &options); /// SubViews to promote. - MapVector subViews; + MapVector subViews; /// True if the full view should be used for the promoted buffer. DenseMap useFullTileBuffers; @@ -138,6 +134,7 @@ /// Allow the use of dynamically-sized buffers. bool dynamicBuffers; + /// Alignment of promoted buffer. 
Optional alignment; }; @@ -148,12 +145,12 @@ : subViews(), dynamicBuffers(options.dynamicBuffers), alignment(options.alignment) { assert(linalgOp.hasBufferSemantics() && "revisit usage of shaped operand"); - unsigned nBuffers = linalgOp.getNumShapedOperands(); + int64_t nBuffers = linalgOp.getNumShapedOperands(); auto vUseFullTileBuffers = options.useFullTileBuffers.getValueOr(llvm::SmallBitVector()); vUseFullTileBuffers.resize(nBuffers, options.useFullTileBuffersDefault); - for (unsigned idx = 0; idx != nBuffers; ++idx) { + for (int64_t idx = 0; idx != nBuffers; ++idx) { if (options.operandsToPromote && !options.operandsToPromote->count(idx)) continue; auto *op = linalgOp.getShapedOperand(idx).getDefiningOp(); @@ -163,24 +160,30 @@ } } - allocationFn = (options.allocationFn - ? *(options.allocationFn) - : [&](OpBuilder &builder, memref::SubViewOp subViewOp, - ArrayRef boundingSubViewSize, - DataLayout &layout) -> Optional { - return defaultAllocBufferCallBack(options, builder, subViewOp, - boundingSubViewSize, dynamicBuffers, - alignment, layout); - }); - deallocationFn = - (options.deallocationFn - ? *(options.deallocationFn) - : [&](OpBuilder &b, Value buffer) { - return defaultDeallocBufferCallBack(options, b, buffer); - }); - auto defaultCopyCallBack = [&](OpBuilder &builder, Value src, - Value dst) -> LogicalResult { - linalg_copy(src, dst); + if (options.allocationFn) { + allocationFn = *options.allocationFn; + } else { + allocationFn = [&](OpBuilder &b, memref::SubViewOp subViewOp, + ArrayRef boundingSubViewSize, + DataLayout &layout) -> Optional { + return defaultAllocBufferCallBack(options, b, subViewOp, + boundingSubViewSize, alignment, layout); + }; + } + + if (options.deallocationFn) { + deallocationFn = *options.deallocationFn; + } else { + deallocationFn = [&](OpBuilder &b, Value buffer) { + return defaultDeallocBufferCallBack(options, b, buffer); + }; + } + + // Save the loc because `linalgOp` goes out of scope. 
+ Location loc = linalgOp.getLoc(); + auto defaultCopyCallBack = [loc](OpBuilder &b, Value src, + Value dst) -> LogicalResult { + b.create(loc, src, dst); return success(); }; copyInFn = (options.copyInFn ? *(options.copyInFn) : defaultCopyCallBack); @@ -207,7 +210,6 @@ Optional mlir::linalg::promoteSubviewAsNewBuffer( OpBuilder &b, Location loc, memref::SubViewOp subView, AllocBufferCallbackFn allocationFn, DataLayout &layout) { - ScopedContext scopedContext(b, loc); auto viewType = subView.getType(); auto rank = viewType.getRank(); SmallVector fullSizes; @@ -223,7 +225,8 @@ (!sizeAttr) ? rangeValue.size : b.create(loc, sizeAttr); LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n"); fullSizes.push_back(size); - partialSizes.push_back(memref_dim(subView, en.index()).value); + partialSizes.push_back( + b.createOrFold(loc, subView, en.index())); } SmallVector dynSizes(fullSizes.size(), -1); // If a callback is not specified, then use the default implementation for @@ -238,20 +241,19 @@ return PromotionInfo{*fullLocalView, partialLocalView}; } -static Optional> -promoteSubViews(OpBuilder &b, Location loc, +static Optional> +promoteSubViews(ImplicitLocOpBuilder &b, LinalgOpInstancePromotionOptions options, DataLayout &layout) { if (options.subViews.empty()) return {}; - ScopedContext scope(b, loc); - MapVector promotionInfoMap; + MapVector promotionInfoMap; for (auto v : options.subViews) { memref::SubViewOp subView = cast(v.second.getDefiningOp()); Optional promotionInfo = promoteSubviewAsNewBuffer( - b, loc, subView, options.allocationFn, layout); + b, b.getLoc(), subView, options.allocationFn, layout); if (!promotionInfo) return {}; promotionInfoMap[v.first] = *promotionInfo; @@ -259,23 +261,27 @@ // Only fill the buffer if the full local view is used if (!options.useFullTileBuffers[v.second]) continue; - Value fillVal; - if (auto t = subView.getType().getElementType().dyn_cast()) { - fillVal = std_constant(FloatAttr::get(t, 0.0)); - } else if (auto t 
= - subView.getType().getElementType().dyn_cast()) { - fillVal = std_constant_int(0, t); - } else if (auto t = - subView.getType().getElementType().dyn_cast()) { - if (auto et = t.getElementType().dyn_cast()) - fillVal = std_constant(FloatAttr::get(et, 0.0)); - else if (auto et = t.getElementType().cast()) - fillVal = std_constant_int(0, et); - fillVal = b.create(loc, t, fillVal, fillVal); - } else { + Type subviewEltType = subView.getType().getElementType(); + Value fillVal = + llvm::TypeSwitch(subviewEltType) + .Case([&](FloatType t) { + return b.create(FloatAttr::get(t, 0.0)); + }) + .Case([&](IntegerType t) { + return b.create(IntegerAttr::get(t, 0)); + }) + .Case([&](ComplexType t) { + Value tmp; + if (auto et = t.getElementType().dyn_cast()) + tmp = b.create(FloatAttr::get(et, 0.0)); + else if (auto et = t.getElementType().cast()) + tmp = b.create(IntegerAttr::get(et, 0)); + return b.create(t, tmp, tmp); + }) + .Default([](auto) { return Value(); }); + if (!fillVal) return {}; - } - linalg_fill(promotionInfo->fullLocalView, fillVal); + b.create(promotionInfo->fullLocalView, fillVal); } // Copy data into the promoted buffers. Use callback if provided. @@ -292,7 +298,7 @@ } static Optional -promoteSubViews(OpBuilder &b, LinalgOp op, +promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op, LinalgOpInstancePromotionOptions options, DataLayout &layout) { assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); @@ -303,8 +309,7 @@ } // 1. Promote the specified views and use them in the new op. - auto loc = op.getLoc(); - auto promotedBuffersAndViews = promoteSubViews(b, loc, options, layout); + auto promotedBuffersAndViews = promoteSubViews(b, options, layout); if (!promotedBuffersAndViews || promotedBuffersAndViews->size() != options.subViews.size()) return {}; @@ -336,7 +341,6 @@ OpBuilder::InsertionGuard guard(b); b.setInsertionPointAfter(op); - ScopedContext scope(b, loc); // 3. 
Emit write-back for the promoted output views: copy the partial view. for (auto viewAndPartialLocalView : writebackViews) { if (failed(options.copyOutFn(b, viewAndPartialLocalView.second, @@ -372,10 +376,11 @@ } Optional -mlir::linalg::promoteSubViews(OpBuilder &b, LinalgOp linalgOp, +mlir::linalg::promoteSubViews(OpBuilder &builder, LinalgOp linalgOp, LinalgPromotionOptions options) { LinalgOpInstancePromotionOptions linalgOptions(linalgOp, options); auto layout = DataLayout::closest(linalgOp); + ImplicitLocOpBuilder b(linalgOp.getLoc(), builder); return ::promoteSubViews(b, linalgOp, linalgOptions, layout); } @@ -388,14 +393,14 @@ } void runOnFunction() override { - getFunction().walk([this](LinalgOp op) { + getFunction().walk([&](LinalgOp op) { auto options = LinalgPromotionOptions() .setDynamicBuffers(dynamicBuffers) .setUseAlloca(useAlloca); if (failed(promoteSubviewsPrecondition(op, options))) return; LLVM_DEBUG(llvm::dbgs() << "Promote: " << *(op.getOperation()) << "\n"); - OpBuilder b(op); + ImplicitLocOpBuilder b(op.getLoc(), op); promoteSubViews(b, op, options); }); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -12,7 +12,6 @@ #include "PassDetail.h" #include "mlir/Dialect/Affine/EDSC/Intrinsics.h" -#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" diff --git a/mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt @@ -6,9 +6,8 @@ LINK_LIBS PUBLIC MLIRAffine - MLIREDSC + MLIRAffineEDSC MLIRIR - MLIRLinalgEDSC MLIRLinalg MLIRSCF MLIRPass diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc 
b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -23,15 +23,16 @@ // IMPL-NEXT: map2 = simplifyAffineMap(map2); // IMPL-NEXT: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 }); // -// IMPL: void Test1Op::regionBuilder(Block &block, ValueRange captures) { +// IMPL: void Test1Op::regionBuilder(ImplicitLocOpBuilder &b, +// IMPL: Block &block, ValueRange captures) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); -// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]); -// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]); -// IMPL: (linalg_yield(ValueRange{ [[e]] })); +// IMPL: Value [[d:.*]] = b.create([[a]], [[b]]); +// IMPL: Value [[e:.*]] = b.create([[c]], [[d]]); +// IMPL: b.create(ValueRange{ [[e]] }); // ods_def : def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) { - C(m) = std_addf(C(m), std_mulf(A(m, k), B(k))); + C(m) = AddFOp(C(m), MulFOp(A(m, k), B(k))); } // ODS-LABEL: def Test2Op : LinalgStructuredBase_Op<"test2", [ @@ -47,15 +48,16 @@ // IMPL: AffineMap::get(3, 3, {d2, d1}, context) // IMPL: AffineMap::get(3, 3, {d0, d1}, context) // -// IMPL: Test2Op::regionBuilder(Block &block, ValueRange captures) { +// IMPL: Test2Op::regionBuilder(ImplicitLocOpBuilder &b, +// IMPL: Block &block, ValueRange captures) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); -// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]); -// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]); -// IMPL: (linalg_yield(ValueRange{ [[e]] })); +// IMPL: Value [[d:.*]] = b.create([[a]], [[b]]); +// IMPL: Value [[e:.*]] = b.create([[c]], [[d]]); +// IMPL: b.create(ValueRange{ [[e]] }); // ods_def : def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) { - C(m, n) = std_addf(C(m, n), std_mulf(A(m, k), B(k, n))); + C(m, n) = AddFOp(C(m, n), MulFOp(A(m, k), B(k, n))); } // ODS-LABEL: def Test3Op : LinalgStructuredBase_Op<"test3", [ 
@@ -71,15 +73,16 @@ // IMPL: AffineMap::get(4, 4, {d3, d2}, context) // IMPL: AffineMap::get(4, 4, {d0, d1, d2}, context) // -// IMPL: Test3Op::regionBuilder(Block &block, ValueRange captures) { +// IMPL: Test3Op::regionBuilder(ImplicitLocOpBuilder &b, +// IMPL: Block &block, ValueRange captures) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); -// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]); -// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]); -// IMPL: (linalg_yield(ValueRange{ [[e]] })); +// IMPL: Value [[d:.*]] = b.create([[a]], [[b]]); +// IMPL: Value [[e:.*]] = b.create([[c]], [[d]]); +// IMPL: b.create(ValueRange{ [[e]] }); // ods_def : def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) { - C(b, m, n) = std_addf(C(b, m, n), std_mulf(A(b, m, k), B(k, n))); + C(b, m, n) = AddFOp(C(b, m, n), MulFOp(A(b, m, k), B(k, n))); } // Test attribute definitions @@ -115,7 +118,7 @@ array_attr : f32[], optional_attr? : f32 ) { - C(b, m, n) = std_addf(C(b, m, n), std_mulf(A(b, m, k), B(k, n))); + C(b, m, n) = AddFOp(C(b, m, n), MulFOp(A(b, m, k), B(k, n))); } // Test attribute usage in affine expressions @@ -136,8 +139,8 @@ ods_def: def test5(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F)) attr(strides: 2xi32) { - O(n, h, w, f) = std_addf( - std_mulf(std_addf(I(n, h * strides[0] + kh, w * strides[1] + kw, c), + O(n, h, w, f) = AddFOp( + MulFOp(AddFOp(I(n, h * strides[0] + kh, w * strides[1] + kw, c), I(n, h * strides[0] + kh, w * strides[1] + kw, c)), K(f, kh, kw, c))); } @@ -159,7 +162,7 @@ It has one output. """ { - C(m) = std_addf(C(m), std_mulf(A(m, k), B(k))); + C(m) = AddFOp(C(m), MulFOp(A(m, k), B(k))); } // Test attribute builder @@ -174,19 +177,20 @@ def test7(A: f32(M, K), B: f32(K)) -> (C: f32(M)) attr(attr_a: f32, attr_b: 4xi32) { - C(m) = std_addf(C(m), std_mulf(A(m, k), B(k))); + C(m) = AddFOp(C(m), MulFOp(A(m, k), B(k))); } // Test output arg order. 
-// IMPL-LABEL: void Test8Op::regionBuilder(Block &block, ValueRange captures) { +// IMPL-LABEL: void Test8Op::regionBuilder(ImplicitLocOpBuilder &b, +// IMPL: Block &block, ValueRange captures) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); -// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]); -// IMPL: Value [[e:.*]] = std_subf([[d]], [[c]]); -// IMPL: (linalg_yield(ValueRange{ [[e]] })); +// IMPL: Value [[d:.*]] = b.create([[a]], [[b]]); +// IMPL: Value [[e:.*]] = b.create([[d]], [[c]]); +// IMPL: b.create(ValueRange{ [[e]] }); ods_def: def test8(A: f32(M, K), B: f32(K)) -> (C: f32(M)) { - C(m) = std_subf(std_mulf(A(m, k), B(k)), C(m)); + C(m) = SubFOp(MulFOp(A(m, k), B(k)), C(m)); } // Test shape-only operand. @@ -194,10 +198,11 @@ // IMPL: auto map0 = AffineMap::get(2, 2, {d0, d1}, context); // IMPL: auto map1 = AffineMap::get(2, 2, {d1}, context); // IMPL: auto map2 = AffineMap::get(2, 2, {d0}, context); -// IMPL-LABEL: void Test9Op::regionBuilder(Block &block, ValueRange captures) { +// IMPL-LABEL: void Test9Op::regionBuilder(ImplicitLocOpBuilder &b, +// IMPL: Block &block, ValueRange captures) { // IMPL: Value [[a:.*]](args[0]), [[c:.*]](args[2]); ods_def: def test9(A: f32(M, K), B: f32(K)) -> (C: f32(M)) { - C(m) = std_addf(C(m), A(m, k)); + C(m) = AddFOp(C(m), A(m, k)); } diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -75,8 +75,8 @@ # ODS-NEXT: TypeRange(inputs), # ODS-NEXT: TypeRange(outputs) -# IMPL-LABEL: void Test1Op::regionBuilder -# IMPL-SAME: (Block &block, ValueRange captures) +# IMPL-LABEL: void Test1Op::regionBuilder( +# IMPL: ImplicitLocOpBuilder &b, Block &block, ValueRange captures) # IMPL: Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64"); # IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = 
helper.cast(block.getArgument(0).getType(), [[VAL0]]); # IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1); @@ -133,5 +133,6 @@ # IMPL: "affine_map<(d0, d1)[s0, s1] -> (d1, d0)>" # IMPL: "affine_map<(d0, d1)[s0, s1] -> (d0, d1)>" -# IMPL: void Test2Op::regionBuilder(Block &block, ValueRange captures) +# IMPL: void Test2Op::regionBuilder( +# IMPL: ImplicitLocOpBuilder &b, Block &block, ValueRange captures) # IMPL: yields.push_back(block.getArgument(0)); diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -72,6 +72,30 @@ using namespace mlir; +//===----------------------------------------------------------------------===// +// Special "op aliases" substitutions. +//===----------------------------------------------------------------------===// + +/// Perform substitutions of known special ops. +/// This is a poor man's way of achieving "op aliases": i.e. giving an op a +/// name. +/// This is hacky and temporary until migration to the python opdsl is complete. +static void substituteOpAliases(std::string &expressionsStr) { + for (auto kvp : SmallVector>{ + {"b.create(", "b.create(CmpIPredicate::sgt, "}, + {"b.create(", "b.create(CmpFPredicate::OGT, "}, + {"b.create(", "b.create(CmpFPredicate::OLT, "}, + {"b.create(", + "b.create(b.getI32Type(), "}, + }) { + size_t pos = 0; + while ((pos = expressionsStr.find(kvp.first, pos)) != std::string::npos) { + expressionsStr.replace(pos, kvp.first.size(), kvp.second); + pos += kvp.second.size(); + } + } +} + //===----------------------------------------------------------------------===// // Lexer //===----------------------------------------------------------------------===// @@ -1941,8 +1965,10 @@ // Auto-generated. 
ArrayAttr iterator_types(); ArrayAttr indexing_maps(); - static void regionBuilder(Block &block, ValueRange captures); - static std::function getRegionBuilder() {{ + static void regionBuilder(ImplicitLocOpBuilder &b, + Block &block, ValueRange captures); + static std::function getRegionBuilder() {{ return regionBuilder; } @@ -2325,7 +2351,7 @@ printExpr(subExprsStringStream, *e); }); subExprsStringStream.flush(); - const char *tensorExprFmt = "\n Value _{0} = {1}({2});"; + const char *tensorExprFmt = "\n Value _{0} = b.create<{1}>({2});"; os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->operationName, subExprs); subExprsMap[pTensorExpr] = count; @@ -2333,13 +2359,12 @@ }; const char *regionBuilderFmt = R"FMT( - void {0}::regionBuilder(Block &block, ValueRange captures) { - using namespace edsc; - using namespace intrinsics; + void {0}::regionBuilder(ImplicitLocOpBuilder &b, + Block &block, ValueRange captures) { auto args = block.getArguments(); Value {1}; {2} - (linalg_yield(ValueRange{ {3} })); + b.create(ValueRange{ {3} }); })FMT"; std::string valueHandleStr; @@ -2358,6 +2383,8 @@ if (e.kind == Expression::Kind::TensorExpr) printExpr(expressionStringStream, e); }); + expressionStringStream.flush(); + substituteOpAliases(expressionsStr); std::string yieldStr; llvm::raw_string_ostream yieldStringStream(yieldStr); @@ -2367,7 +2394,6 @@ }); valueHandleStringStream.flush(); - expressionStringStream.flush(); yieldStringStream.flush(); os << llvm::formatv(regionBuilderFmt, cppOpName, valueHandleStr, diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -509,8 +509,11 @@ // Auto-generated. 
ArrayAttr iterator_types(); ArrayAttr indexing_maps(); - static void regionBuilder(Block &block, ValueRange captures); - static std::function getRegionBuilder() {{ + static void regionBuilder( + ImplicitLocOpBuilder &b, Block &block, ValueRange captures); + static std::function< + void(ImplicitLocOpBuilder &b, Block &, ValueRange)> + getRegionBuilder() {{ return regionBuilder; } @@ -755,7 +758,8 @@ // {1}: Number of args // {2}: Statements static const char structuredOpRegionBuilderFormat[] = R"FMT( -void {0}::regionBuilder(Block &block, ValueRange captures) {{ +void {0}::regionBuilder( + ImplicitLocOpBuilder &b, Block &block, ValueRange captures) {{ assert({1} > 0 && block.getNumArguments() == {1} && "{0} regionBuilder expects {1} (>=0) args"); RegionBuilderHelper helper(block.getArgument(0).getContext(), block);