[mlir] Enable specifying verify on OpInterface

Summary:
Add method in ODS to specify verification for operations implementing a
OpInterface. Use this with infer type op interface to verify that the
inferred type matches the return type and remove special case in
TestPatterns.

This could also have been achieved by using OpInterfaceMethod but verify
seems pretty common and it is not an arbitrary method that just happened
to be named verifyTrait, so having it be defined in special way seems
appropriate/better documenting.

Differential Revision: https://reviews.llvm.org/D73122
This commit is contained in:
Jacques Pienaar 2020-01-21 09:40:22 -08:00
parent 963f268186
commit 178562fb35
10 changed files with 56 additions and 25 deletions

View File

@ -400,6 +400,10 @@ def OpWithInferTypeInterfaceOp : Op<...
[DeclareOpInterfaceMethods<MyInterface>]> { ... } [DeclareOpInterfaceMethods<MyInterface>]> { ... }
``` ```
A verification method can also be specified on the `OpInterface` by setting
`verify`. Setting `verify` results in the generated trait having a `verifyTrait`
method that is applied to all operations implementing the trait.
### Builder methods ### Builder methods
For each operation, there are a few builders automatically generated based on For each operation, there are a few builders automatically generated based on

View File

@ -72,8 +72,6 @@ private:
Attribute attr; Attribute attr;
}; };
#include "mlir/Analysis/InferTypeOpInterface.h.inc"
namespace detail { namespace detail {
// Helper function to infer return tensor returns types given element and shape // Helper function to infer return tensor returns types given element and shape
// inference function. // inference function.
@ -89,8 +87,14 @@ LogicalResult inferReturnTensorTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands, MLIRContext *context, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions, ArrayRef<NamedAttribute> attributes, RegionRange regions,
SmallVectorImpl<Type> &inferedReturnTypes); SmallVectorImpl<Type> &inferedReturnTypes);
/// Verifies that the inferred result types match the actual result types for
/// the op. Precondition: op implements InferTypeOpInterface.
LogicalResult verifyInferredResultTypes(Operation *op);
} // namespace detail } // namespace detail
#include "mlir/Analysis/InferTypeOpInterface.h.inc"
namespace OpTrait { namespace OpTrait {
/// Tensor type inference trait that constructs a tensor from the infered /// Tensor type inference trait that constructs a tensor from the infered

View File

@ -60,6 +60,10 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
}] }]
>, >,
]; ];
let verify = [{
return detail::verifyInferredResultTypes($_op);
}];
} }
def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> { def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {

View File

@ -1411,8 +1411,12 @@ def ins;
// OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in // OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in
// C++. The purpose to wrap around C++ symbol string with this class is to make // C++. The purpose to wrap around C++ symbol string with this class is to make
// interfaces specified for ops in TableGen less alien and more integrated. // interfaces specified for ops in TableGen less alien and more integrated.
class OpInterfaceTrait<string name> : NativeOpTrait<""> { class OpInterfaceTrait<string name, code verifyBody = [{}]> : NativeOpTrait<""> {
let trait = name # "::Trait"; let trait = name # "::Trait";
// Specify the body of the verification function. `$_op` will be replaced with
// the operation being verified.
code verify = verifyBody;
} }
// This class represents a single, optionally static, interface method. // This class represents a single, optionally static, interface method.

View File

@ -86,6 +86,9 @@ public:
// Return the description of this method if it has one. // Return the description of this method if it has one.
llvm::Optional<StringRef> getDescription() const; llvm::Optional<StringRef> getDescription() const;
// Return the verify method body if it has one.
llvm::Optional<StringRef> getVerify() const;
private: private:
// The TableGen definition of this interface. // The TableGen definition of this interface.
const llvm::Record *def; const llvm::Record *def;

View File

@ -45,3 +45,17 @@ LogicalResult mlir::detail::inferReturnTensorTypes(
} }
return success(); return success();
} }
LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
SmallVector<Type, 4> inferedReturnTypes;
auto retTypeFn = cast<InferTypeOpInterface>(op);
if (failed(retTypeFn.inferReturnTypes(op->getContext(), op->getLoc(),
op->getOperands(), op->getAttrs(),
op->getRegions(), inferedReturnTypes)))
return failure();
SmallVector<Type, 4> resultTypes(op->getResultTypes());
if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes))
return op->emitOpError(
"inferred type incompatible with return type of operation");
return success();
}

View File

@ -85,3 +85,9 @@ llvm::Optional<StringRef> OpInterface::getDescription() const {
auto value = def->getValueAsString("description"); auto value = def->getValueAsString("description");
return value.empty() ? llvm::Optional<StringRef>() : value; return value.empty() ? llvm::Optional<StringRef>() : value;
} }
// Return the body for this method if it has one.
llvm::Optional<StringRef> OpInterface::getVerify() const {
auto value = def->getValueAsString("verify");
return value.empty() ? llvm::Optional<StringRef>() : value;
}

View File

@ -103,26 +103,6 @@ struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
}; };
return; return;
} }
// Verification check.
// TODO: Move to ops that implement type infer interface.
getFunction().walk([this](Operation *op) -> void {
auto retTypeFn = dyn_cast<InferTypeOpInterface>(op);
if (!retTypeFn)
return;
auto *context = &getContext();
SmallVector<Type, 2> inferedReturnTypes;
if (failed(retTypeFn.inferReturnTypes(
context, op->getLoc(), op->getOperands(), op->getAttrs(),
op->getRegions(), inferedReturnTypes)))
return;
SmallVector<Type, 1> resultTypes(op->getResultTypes());
if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes)) {
op->emitOpError(
"inferred type incompatible with return type of operation");
return;
}
});
} }
}; };
} // end anonymous namespace } // end anonymous namespace

View File

@ -23,7 +23,6 @@ func @testCreateFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xi32>) {
// ----- // -----
// CHECK-LABEL: testReturnTypeOpInterface
func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) { func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
// expected-error@+1 {{incompatible with return type}} // expected-error@+1 {{incompatible with return type}}
%bad = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32> %bad = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
@ -32,7 +31,6 @@ func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
// ----- // -----
// CHECK-LABEL: testReturnTypeOpInterface
func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) { func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) {
// expected-error@+1 {{operand type mismatch}} // expected-error@+1 {{operand type mismatch}}
%bad = "test.op_with_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<*xf32> %bad = "test.op_with_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<*xf32>

View File

@ -12,6 +12,7 @@
#include "DocGenUtilities.h" #include "DocGenUtilities.h"
#include "mlir/Support/STLExtras.h" #include "mlir/Support/STLExtras.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/OpInterfaces.h" #include "mlir/TableGen/OpInterfaces.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
@ -152,6 +153,12 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
// Insert the default implementation for any methods. // Insert the default implementation for any methods.
for (auto &method : interface.getMethods()) { for (auto &method : interface.getMethods()) {
// Flag interface methods named verifyTrait.
if (method.getName() == "verifyTrait")
PrintFatalError(
formatv("'verifyTrait' method cannot be specified as interface "
"method for '{0}'; set 'verify' on OpInterfaceTrait instead",
interfaceName));
auto defaultImpl = method.getDefaultImplementation(); auto defaultImpl = method.getDefaultImplementation();
if (!defaultImpl) if (!defaultImpl)
continue; continue;
@ -162,6 +169,13 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
os << " {\n" << defaultImpl.getValue() << " }\n"; os << " {\n" << defaultImpl.getValue() << " }\n";
} }
tblgen::FmtContext traitCtx;
traitCtx.withOp("op");
if (auto verify = interface.getVerify()) {
os << " static LogicalResult verifyTrait(Operation* op) {\n"
<< tblgen::tgfmt(*verify, &traitCtx) << "\n }\n";
}
os << " };\n"; os << " };\n";
} }