diff --git a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt --- a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt @@ -1,3 +1,7 @@ +set(LLVM_TARGET_DEFINITIONS ShapeToStandard.td) +mlir_tablegen(ShapeToStandard.cpp.inc -gen-rewriters) +add_public_tablegen_target(ShapeToStandardIncGen) + add_mlir_conversion_library(MLIRShapeToStandard ConvertShapeConstraints.cpp ShapeToStandard.cpp @@ -7,6 +11,7 @@ DEPENDS MLIRConversionPassIncGen + ShapeToStandardIncGen LINK_COMPONENTS Core diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -569,6 +569,9 @@ } }; +/// Import the TableGen-generated Shape-to-Standard rewrite patterns. +#include "ShapeToStandard.cpp.inc" + } // namespace namespace { @@ -585,7 +588,7 @@ MLIRContext &ctx = getContext(); ConversionTarget target(ctx); target.addLegalDialect(); - target.addLegalOp(); + target.addLegalOp(); // Setup conversion patterns. OwningRewritePatternList patterns; @@ -600,6 +603,7 @@ void mlir::populateShapeToStandardConversionPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { // clang-format off + populateWithGenerated(ctx, patterns); patterns.insert< AnyOpConversion, BinaryOpConversion, diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td @@ -0,0 +1,27 @@ +//==-- ShapeToStandard.td - Shape to Standard Patterns -------*- tablegen -*==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines Patterns to lower Shape ops to Std. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_SHAPETOSTANDARD_TD +#define MLIR_CONVERSION_SHAPETOSTANDARD_TD + +include "mlir/Dialect/Shape/IR/ShapeOps.td" + +def BroadcastableStringAttr : NativeCodeCall<[{ + $_builder.getStringAttr("required broadcastable shapes") +}]>; + +def : Pat<(Shape_CstrBroadcastableOp $LHS, $RHS), + (Shape_CstrRequireOp + (Shape_IsBroadcastableOp $LHS, $RHS), + (BroadcastableStringAttr))>; + +#endif // MLIR_CONVERSION_SHAPETOSTANDARD_TD diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir --- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir +++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir @@ -420,3 +420,42 @@ // CHECK: } // CHECK: return %[[ALL_RESULT]] : i1 // CHECK: } + +// ----- + +func @broadcast(%a : tensor, %b : tensor) -> !shape.witness { + %0 = shape.cstr_broadcastable %a, %b : tensor, tensor + return %0 : !shape.witness +} + +// CHECK-LABEL: func @broadcast( +// CHECK-SAME: %[[LHS:.*]]: tensor, +// CHECK-SAME: %[[RHS:.*]]: tensor) -> !shape.witness { +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[C1:.*]] = constant 1 : index +// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor +// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor +// CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index +// CHECK: %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index +// CHECK: %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index +// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor to tensor +// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : 
tensor to tensor +// CHECK: %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor +// CHECK: %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor +// CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index +// CHECK: %[[TRUE:.*]] = constant true +// CHECK: %[[ALL_RESULT:.*]] = scf.for %[[VAL_16:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) { +// CHECK: %[[LARGER_EXTENT:.*]] = extract_element %[[LARGER_SHAPE]]{{\[}}%[[VAL_16]]] : tensor +// CHECK: %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[C1]] : index +// CHECK: %[[LHS_EXTENT_INDEX:.*]] = subi %[[VAL_16]], %[[RANK_DIFF]] : index +// CHECK: %[[SMALLER_EXTENT:.*]] = extract_element %[[SMALLER_SHAPE]]{{\[}}%[[LHS_EXTENT_INDEX]]] : tensor +// CHECK: %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[SMALLER_EXTENT]], %[[C1]] : index +// CHECK: %[[EXTENTS_ARE_EQUAL:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index +// CHECK: %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1 +// CHECK: %[[OR_EXTENTS_ARE_EQUAL:.*]] = or %[[EITHER_EXTENT_IS_ONE]], %[[EXTENTS_ARE_EQUAL]] : i1 +// CHECK: %[[NEW_ALL_SO_FAR:.*]] = and %[[ALL_SO_FAR]], %[[OR_EXTENTS_ARE_EQUAL]] : i1 +// CHECK: scf.yield %[[NEW_ALL_SO_FAR]] : i1 +// CHECK: } +// CHECK: %[[RESULT:.*]] = shape.cstr_require %[[ALL_RESULT]], "required broadcastable shapes" +// CHECK: return %[[RESULT]] : !shape.witness +// CHECK: }